/*
 * Decompiled with CFR 0.152.
 */
package ghidra.app.plugin.assembler.sleigh.expr;

import ghidra.app.plugin.assembler.sleigh.expr.AbstractBinaryExpressionSolver;
import ghidra.app.plugin.assembler.sleigh.expr.DefaultSolverHint;
import ghidra.app.plugin.assembler.sleigh.expr.MaskedLong;
import ghidra.app.plugin.assembler.sleigh.expr.NeedsBackfillException;
import ghidra.app.plugin.assembler.sleigh.expr.SolverException;
import ghidra.app.plugin.assembler.sleigh.expr.SolverHint;
import ghidra.app.plugin.assembler.sleigh.expr.match.ExpressionMatcher;
import ghidra.app.plugin.assembler.sleigh.sem.AbstractAssemblyResolutionFactory;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolution;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolvedError;
import ghidra.app.plugin.assembler.sleigh.sem.AssemblyResolvedPatterns;
import ghidra.app.plugin.assembler.sleigh.util.DbgTimer;
import ghidra.app.plugin.processors.sleigh.expression.BinaryExpression;
import ghidra.app.plugin.processors.sleigh.expression.ConstantValue;
import ghidra.app.plugin.processors.sleigh.expression.LeftShiftExpression;
import ghidra.app.plugin.processors.sleigh.expression.OrExpression;
import ghidra.app.plugin.processors.sleigh.expression.PatternExpression;
import ghidra.app.plugin.processors.sleigh.expression.PatternValue;
import ghidra.app.plugin.processors.sleigh.expression.RightShiftExpression;
import ghidra.app.plugin.processors.sleigh.expression.SubExpression;
import ghidra.util.Msg;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

/**
 * Solves expressions of the form {@code A | B}.
 *
 * <p>
 * {@link #compute(MaskedLong, MaskedLong)} and {@link #computeLeft(MaskedLong, MaskedLong)} supply
 * the OR arithmetic to the base class. For the two-sided case, {@link #solveTwoSided} tries, in
 * order: field catenation, circular shift, and the "not equals constant" idiom matched by
 * {@link Matchers#neqConst}.
 */
public class OrExpressionSolver
extends AbstractBinaryExpressionSolver<OrExpression> {
    /** Shared matchers for the special forms recognized by this solver. */
    protected static final Matchers MATCHERS = new Matchers();

    public OrExpressionSolver() {
        super(OrExpression.class);
    }

    /**
     * Computes {@code lval | rval}.
     */
    @Override
    public MaskedLong compute(MaskedLong lval, MaskedLong rval) {
        return lval.or(rval);
    }

    /**
     * Computes a left operand that, OR'd with {@code rval}, yields {@code goal}.
     *
     * @throws SolverException if {@link MaskedLong#invOr} reports no solution
     */
    @Override
    public MaskedLong computeLeft(MaskedLong rval, MaskedLong goal) throws SolverException {
        return goal.invOr(rval);
    }

    /**
     * Attempts to solve {@code exp} as a catenation of fields.
     *
     * <p>
     * The OR tree is flattened into components keyed by their constant bit shift (see
     * {@link #collectComponentsOr}). Each component is taken to occupy the bits from its own shift
     * up to (exclusive) the next component's shift, with bit 64 closing the last one. The matching
     * slice of {@code goal} is solved against each component, and the per-field solutions are
     * combined.
     *
     * @return the combined solution, or the first error solution produced by a field
     * @throws SolverException if the expression is not a catenation this solver can handle, if a
     *             field solution is a backfill, or if field solutions conflict
     */
    protected AssemblyResolution tryCatenationExpression(AbstractAssemblyResolutionFactory<?, ?> factory, OrExpression exp, MaskedLong goal, Map<String, Long> vals, AssemblyResolvedPatterns cur, Set<SolverHint> hints, String description) throws SolverException {
        // TreeMap iterates the components in ascending order of shift.
        TreeMap<Long, PatternExpression> fields = new TreeMap<Long, PatternExpression>();
        this.collectComponentsOr(exp, 0L, fields, vals, cur);
        // Guarantee a component starting at bit 0, and an end marker at bit 64.
        fields.computeIfAbsent(0L, __ -> new ConstantValue(0L));
        fields.put(64L, new ConstantValue(0L)); // placeholder: only its key (the upper bound) is used
        long lo = 0L;
        PatternExpression fieldExp = null;
        AssemblyResolvedPatterns result = factory.nop(description);
        try (DbgTimer.DbgCtx dc = this.dbg.start("Trying solution of field catenation")) {
            this.dbg.println("Original: " + goal + ":= " + exp);
            for (Map.Entry<Long, PatternExpression> ent : fields.entrySet()) {
                long hi = ent.getKey();
                if (hi == 0L) {
                    // First component: nothing below it to solve yet.
                    fieldExp = ent.getValue();
                    continue;
                }
                this.dbg.println("Part(" + hi + ":" + lo + "]:= " + fieldExp);
                // Isolate goal bits [lo, hi) and move them down to bit 0.
                MaskedLong part = goal.shiftLeft(64L - hi).shiftRightPositional(64L - hi + lo);
                this.dbg.println("Solving: " + part + ":= " + fieldExp);
                AssemblyResolution sol = this.solver.solve(factory, fieldExp, part, vals, cur, hints, description + " with shift " + lo);
                if (sol.isError()) {
                    return sol;
                }
                if (sol.isBackfill()) {
                    throw new SolverException("Field catenation solver cannot combine backfill solutions");
                }
                result = result.combine((AssemblyResolvedPatterns) sol);
                if (result == null) {
                    throw new SolverException("Solutions to individual fields produced conflict");
                }
                lo = hi;
                fieldExp = ent.getValue();
            }
        }
        return result;
    }

    /**
     * Attempts to solve {@code exp} as a circular shift.
     *
     * <p>
     * Recognized shapes are {@code (f << a) | (f >> b)} and {@code (f >> a) | (f << b)}, where one of
     * {@code a}, {@code b} is {@code C - g}, the other is {@code g}, and {@code C} evaluates to a
     * fully-defined value that is taken as the rotation size.
     *
     * @throws SolverException if {@code exp} does not have one of these shapes, or if solving fails
     */
    protected AssemblyResolution tryCircularShiftExpression(AbstractAssemblyResolutionFactory<?, ?> factory, OrExpression exp, MaskedLong goal, Map<String, Long> vals, AssemblyResolvedPatterns cur, Set<SolverHint> hints, String description) throws SolverException {
        // 0 when the left operand is the left shift, 1 when it is the right shift.
        int leftdir;
        if (exp.getLeft() instanceof LeftShiftExpression && exp.getRight() instanceof RightShiftExpression) {
            leftdir = 0;
        } else if (exp.getLeft() instanceof RightShiftExpression && exp.getRight() instanceof LeftShiftExpression) {
            leftdir = 1;
        } else {
            throw new SolverException("Not a circular shift");
        }
        BinaryExpression left = (BinaryExpression) exp.getLeft();
        BinaryExpression right = (BinaryExpression) exp.getRight();
        PatternExpression expValu1 = left.getLeft();
        PatternExpression expValu2 = right.getLeft();
        if (!expValu1.equals(expValu2)) {
            throw new SolverException("Not a circular shift");
        }
        PatternExpression s1 = left.getRight();
        PatternExpression s2 = right.getRight();
        PatternExpression expShift;
        int dir;
        MaskedLong cc = this.matchConstantMinus(s1, s2, vals, cur);
        if (cc != null) {
            // Left operand shifts by (C - g), right operand by g.
            expShift = ((SubExpression) s1).getRight();
            dir = 1 - leftdir;
        } else {
            cc = this.matchConstantMinus(s2, s1, vals, cur);
            if (cc == null) {
                throw new SolverException("Not a circular shift (or of known size)");
            }
            // Left operand shifts by g, right operand by (C - g).
            expShift = ((SubExpression) s2).getRight();
            dir = leftdir;
        }
        int size = (int) cc.longValue();
        this.dbg.println("Identified circular shift: value:= " + expValu1 + ", shift:= " + expShift + ", size:= " + size + ", dir:= " + (dir == 1 ? "right" : "left"));
        return this.solveLeftCircularShift(factory, expValu1, expShift, size, dir, goal, vals, cur, hints, description);
    }

    /**
     * Checks whether {@code candidate} has the form {@code C - other}, where {@code C} evaluates to a
     * fully-defined value.
     *
     * @return the value of {@code C}, or {@code null} if {@code candidate} does not have that form
     * @throws NeedsBackfillException if evaluating {@code C} requires backfill
     */
    private MaskedLong matchConstantMinus(PatternExpression candidate, PatternExpression other, Map<String, Long> vals, AssemblyResolvedPatterns cur) throws NeedsBackfillException {
        if (!(candidate instanceof SubExpression)) {
            return null;
        }
        SubExpression sub = (SubExpression) candidate;
        if (!sub.getRight().equals(other)) {
            return null;
        }
        MaskedLong c = this.solver.getValue(sub.getLeft(), vals, cur);
        if (c == null || !c.isFullyDefined()) {
            return null;
        }
        return c;
    }

    /**
     * Solves {@code goal := rotate(f, g)}, where {@code f} is {@code expValue} and {@code g} is
     * {@code expShift}.
     *
     * <p>
     * Only fully-defined values count as known; partially-defined ones are treated as unknown.
     * <ul>
     * <li>If {@code f} is known, {@code g} is solved for.</li>
     * <li>If {@code g} is known, {@code f} is solved for.</li>
     * <li>If neither is known, each shift amount in {@code [0, size)} is tried in turn. This
     * guessing does not nest.</li>
     * </ul>
     *
     * @param size the rotation width in bits
     * @param dir 0 for left, 1 for right
     * @throws SolverException if no solution is found, or if guessing is already in progress
     */
    protected AssemblyResolution solveLeftCircularShift(AbstractAssemblyResolutionFactory<?, ?> factory, PatternExpression expValue, PatternExpression expShift, int size, int dir, MaskedLong goal, Map<String, Long> vals, AssemblyResolvedPatterns cur, Set<SolverHint> hints, String description) throws NeedsBackfillException, SolverException {
        MaskedLong valValue = this.solver.getValue(expValue, vals, cur);
        MaskedLong valShift = this.solver.getValue(expShift, vals, cur);
        if (valValue != null && !valValue.isFullyDefined()) {
            if (!valValue.isFullyUndefined()) {
                this.dbg.println("Partially-defined f for left circular shift solver: " + valValue);
            }
            valValue = null;
        }
        // Mirror of the check above: discard anything that is not fully defined.
        if (valShift != null && !valShift.isFullyDefined()) {
            if (!valShift.isFullyUndefined()) {
                this.dbg.println("Partially-defined g for left circular shift solver: " + valShift);
            }
            valShift = null;
        }
        if (valValue != null && valShift != null) {
            throw new AssertionError("Should not have constants when solving special forms");
        }
        if (valValue != null) {
            return this.solver.solve(factory, expShift, this.computeCircShiftG(valValue, size, dir, goal), vals, cur, hints, description);
        }
        if (valShift != null) {
            return this.solver.solve(factory, expValue, this.computeCircShiftF(valShift, size, dir, goal), vals, cur, hints, description);
        }
        if (hints.contains(DefaultSolverHint.GUESSING_CIRCULAR_SHIFT_AMOUNT)) {
            throw new SolverException("Already guessing circular shift amount. Try to express a double-shift as a shift by sum.");
        }
        Set<SolverHint> hintsWithCircularShift = SolverHint.with(hints, DefaultSolverHint.GUESSING_CIRCULAR_SHIFT_AMOUNT);
        for (int shift = 0; shift < size; ++shift) {
            try {
                MaskedLong reqShift = MaskedLong.fromLong(shift);
                MaskedLong reqValue = this.computeCircShiftF(reqShift, size, dir, goal);
                AssemblyResolution resValue = this.solver.solve(factory, expValue, reqValue, vals, cur, hintsWithCircularShift, description);
                if (resValue.isError()) {
                    AssemblyResolvedError err = (AssemblyResolvedError) resValue;
                    throw new SolverException("Solving f failed: " + err.getError());
                }
                AssemblyResolution resShift = this.solver.solve(factory, expShift, reqShift, vals, cur, hints, description);
                if (resShift.isError()) {
                    AssemblyResolvedError err = (AssemblyResolvedError) resShift;
                    throw new SolverException("Solving g failed: " + err.getError());
                }
                AssemblyResolvedPatterns solValue = (AssemblyResolvedPatterns) resValue;
                AssemblyResolvedPatterns solShift = (AssemblyResolvedPatterns) resShift;
                AssemblyResolvedPatterns sol = solValue.combine(solShift);
                if (sol == null) {
                    throw new SolverException("value and shift solutions conflict for shift=" + shift);
                }
                return sol;
            }
            catch (SolverException | UnsupportedOperationException e) {
                // This guess failed; record it and try the next shift amount.
                Msg.trace(this, "Shift of " + shift + " resulted in " + e);
            }
        }
        throw new SolverException("Could not solve circular shift with variable bits and shift amount");
    }

    /**
     * Finds a shift amount {@code g} such that rotating {@code fval} by {@code g} in direction
     * {@code dir} agrees with {@code goal}.
     *
     * @return the smallest such {@code g} in {@code [0, size)}
     * @throws SolverException if no amount in that range works
     */
    protected MaskedLong computeCircShiftG(MaskedLong fval, int size, int dir, MaskedLong goal) throws SolverException {
        for (int i = 0; i < size; ++i) {
            if (fval.shiftCircular(i, size, dir).agrees(goal)) {
                return MaskedLong.fromLong(i);
            }
        }
        throw new SolverException("Cannot solve for the circular shift amount");
    }

    /**
     * Computes the value {@code f} that, rotated by {@code gval} in direction {@code dir}, gives
     * {@code goal}. It does this by rotating {@code goal} the opposite way.
     */
    protected MaskedLong computeCircShiftF(MaskedLong gval, int size, int dir, MaskedLong goal) {
        return goal.shiftCircular(gval, size, 1 - dir);
    }

    /**
     * Tries, in order, field catenation, circular shift, and the "not equals constant" idiom.
     *
     * <p>
     * Any exception from the first two strategies is logged and then ignored. Note that this
     * includes {@link NeedsBackfillException}.
     *
     * @throws SolverException if no strategy applies
     */
    @Override
    protected AssemblyResolution solveTwoSided(AbstractAssemblyResolutionFactory<?, ?> factory, OrExpression exp, MaskedLong goal, Map<String, Long> vals, AssemblyResolvedPatterns cur, Set<SolverHint> hints, String description) throws NeedsBackfillException, SolverException {
        // NOTE(review): the broad catches below may hide a backfill requirement; confirm intent.
        try {
            return this.tryCatenationExpression(factory, exp, goal, vals, cur, hints, description);
        }
        catch (Exception e) {
            this.dbg.println("while solving: " + goal + "=:" + exp);
            this.dbg.println(e.getMessage());
        }
        try {
            return this.tryCircularShiftExpression(factory, exp, goal, vals, cur, hints, description);
        }
        catch (Exception e) {
            this.dbg.println("while solving: " + goal + "=:" + exp);
            this.dbg.println(e.getMessage());
        }
        Map<ExpressionMatcher<?>, PatternExpression> match = MATCHERS.neqConst.match(exp);
        if (match != null) {
            long value = MATCHERS.val.get(match).getValue();
            PatternValue field = MATCHERS.fld.get(match);
            // Solve field == value; then either keep that solution or forbid it, depending on goal.
            AssemblyResolution solution = this.solver.solve(factory, field, MaskedLong.fromLong(value), vals, cur, hints, description);
            // NOTE(review): confirm fromMaskAndValue(0L, 1L) is the intended "result is 0" goal;
            // the "result is 1" check below uses (1L, 1L).
            if (goal.equals(MaskedLong.fromMaskAndValue(0L, 1L))) {
                return solution;
            }
            if (goal.equals(MaskedLong.fromMaskAndValue(1L, 1L))) {
                if (solution.isError()) {
                    // field can never equal value, so "not equals" holds unconditionally
                    return factory.nop(description);
                }
                if (solution.isBackfill()) {
                    throw new AssertionError();
                }
                AssemblyResolvedPatterns forbidden = (AssemblyResolvedPatterns) solution;
                forbidden = forbidden.withDescription("Solved 'not equals'");
                AssemblyResolvedPatterns rp = factory.nop(description);
                return rp.withForbids(Set.of(forbidden));
            }
        }
        throw new SolverException("Could not solve two-sided OR");
    }

    /**
     * Flattens an OR/shift tree into {@code components}, keyed by each leaf's accumulated constant
     * shift.
     *
     * @throws SolverException if a shift amount is not a fully-defined constant, if a leaf falls
     *             outside bits {@code [0, 64)}, or if two leaves share the same shift
     */
    void collectComponents(PatternExpression exp, long shift, Map<Long, PatternExpression> components, Map<String, Long> vals, AssemblyResolvedPatterns cur) throws SolverException {
        if (exp instanceof OrExpression) {
            this.collectComponentsOr((OrExpression) exp, shift, components, vals, cur);
        } else if (exp instanceof LeftShiftExpression) {
            this.collectComponentsLeft((LeftShiftExpression) exp, shift, components, vals, cur);
        } else if (exp instanceof RightShiftExpression) {
            this.collectComponentsRight((RightShiftExpression) exp, shift, components, vals, cur);
        } else {
            // Explicit check (not an assert) so the invariant holds even with assertions disabled.
            if (shift < 0L || shift >= 64L) {
                throw new SolverException("Field at shift " + shift + " lies outside a 64-bit value");
            }
            PatternExpression conflict = components.put(shift, exp);
            if (conflict != null) {
                throw new SolverException("Two 'fields' at the same shift indicates conflict");
            }
        }
    }

    /** Collects the components of both operands of {@code exp} at the same shift. */
    void collectComponentsOr(OrExpression exp, long shift, Map<Long, PatternExpression> components, Map<String, Long> vals, AssemblyResolvedPatterns cur) throws SolverException {
        this.collectComponents(exp.getLeft(), shift, components, vals, cur);
        this.collectComponents(exp.getRight(), shift, components, vals, cur);
    }

    /** Collects the shifted operand of a left shift, with the shift increased by the constant amount. */
    void collectComponentsLeft(LeftShiftExpression exp, long shift, Map<Long, PatternExpression> components, Map<String, Long> vals, AssemblyResolvedPatterns cur) throws SolverException {
        long adj = this.requireConstantShift(exp.getRight(), vals, cur);
        this.collectComponents(exp.getLeft(), shift + adj, components, vals, cur);
    }

    /** Collects the shifted operand of a right shift, with the shift decreased by the constant amount. */
    void collectComponentsRight(RightShiftExpression exp, long shift, Map<Long, PatternExpression> components, Map<String, Long> vals, AssemblyResolvedPatterns cur) throws SolverException {
        long adj = this.requireConstantShift(exp.getRight(), vals, cur);
        this.collectComponents(exp.getLeft(), shift - adj, components, vals, cur);
    }

    /**
     * Evaluates a shift amount, which must be a fully-defined constant.
     *
     * @throws SolverException if the amount needs backfill, or is not fully defined
     */
    private long requireConstantShift(PatternExpression shiftExp, Map<String, Long> vals, AssemblyResolvedPatterns cur) throws SolverException {
        MaskedLong adj;
        try {
            adj = this.solver.getValue(shiftExp, vals, cur);
        }
        catch (NeedsBackfillException e) {
            throw new SolverException("Variable shifts break field catenation solver", e);
        }
        if (adj == null || !adj.isFullyDefined()) {
            throw new SolverException("Variable shifts break field catenation solver");
        }
        return adj.val;
    }

    /**
     * Expression matchers for the special forms handled by this solver.
     *
     * <p>
     * {@link #neqConst} matches
     * {@code (((fld - val) >> size) & 1) | (((val - fld) >> size) & 1)}, with either operand order
     * of the subtractions as built below. It binds {@link #fld}, {@link #val}, and {@link #size}.
     */
    protected static class Matchers
    implements ExpressionMatcher.Context {
        /** The constant that the field is compared against. */
        protected ExpressionMatcher<ConstantValue> val = this.var(ConstantValue.class);
        /** The constant shift amount; it is also passed to {@code fldSz} for {@link #fld}. */
        protected ExpressionMatcher<ConstantValue> size = this.var(ConstantValue.class);
        /** The field operand, built by {@code fldSz(size)}. */
        protected ExpressionMatcher<PatternValue> fld = this.fldSz(this.size);
        /** The "not equals constant" idiom. */
        protected ExpressionMatcher<?> neqConst = this.or(this.and(this.shr(this.sub(this.opnd(this.fld), this.val), this.size), this.cv(1L)), this.and(this.shr(this.sub(this.val, this.opnd(this.fld)), this.size), this.cv(1L)));

        protected Matchers() {
        }
    }
}

