package aprove.Complexity.LowerBounds.EquationalUnification;

import aprove.Complexity.LowerBounds.BasicStructures.Equation;
import aprove.Complexity.LowerBounds.EquationalUnification.EquationalUnificationRule;
import aprove.Complexity.LowerBounds.EquationalUnification.UnificationProblem;
import aprove.Complexity.LowerBounds.GeneratorEquations.TermGenerator;
import aprove.Complexity.LowerBounds.Types.TrsTypes;
import aprove.Complexity.LowerBounds.Util.Renaming.RenamingCentral;
import aprove.DPFramework.BasicStructures.TRSSubstitution;
import aprove.DPFramework.BasicStructures.TRSTerm;
import aprove.DPFramework.BasicStructures.TRSVariable;
import aprove.Framework.BasicStructures.FunctionSymbol;
import aprove.Framework.Utility.GenericStructures.BidirectionalMap;
import aprove.Framework.Utility.GenericStructures.CollectionMap;
import aprove.Strategies.Abortions.AbortionException;
import immutables.Immutable.ImmutableCreator;
import immutables.Immutable.ImmutableMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/* loaded from: input_file:aprove/Complexity/LowerBounds/EquationalUnification/EquationalUnifier.class */
/**
 * Unifies (and matches) terms by repeatedly rewriting {@link UnificationProblem}s with a fixed,
 * ordered list of {@link EquationalUnificationRule}s.
 *
 * <p>The search works in rounds. In each round, every pending problem is advanced by one rule
 * application. A problem to which no rule applies is asked for its solution via
 * {@code getSolution()}. The first non-null solution found is returned.
 *
 * <p>NOTE(review): there is no bound on the number of rounds. Termination depends entirely on the
 * rules.
 */
public class EquationalUnifier {
    /**
     * Rules in the order in which {@link #applyOneRule} tries them.
     */
    private final EquationalUnificationRule[] rules;

    /**
     * For every equation, maps its left root symbol to its right root symbol and vice versa.
     * NOTE(review): this field is populated in the constructor but never read in this class;
     * it is a candidate for removal.
     */
    private final CollectionMap<FunctionSymbol, FunctionSymbol> eqSymbols = new CollectionMap<>();

    /**
     * Source of fresh constants used by {@link #match} to freeze variables.
     */
    private final RenamingCentral renamingCentral;

    /**
     * Creates a unifier for the given equations.
     *
     * @param iterable the equations; passed to {@code UnifyModuloEquation} and also iterated here
     *     to record root symbols
     * @param renamingCentral provider of fresh constants, used by {@link #match} and passed to
     *     {@code UnifyModuloEquation}
     * @param trsTypes type information passed to {@code UnifyModuloEquation}
     * @param termGenerator generator passed to {@code EqualRootSymbols}
     */
    public EquationalUnifier(Iterable<Equation> iterable, RenamingCentral renamingCentral, TrsTypes trsTypes, TermGenerator termGenerator) {
        this.renamingCentral = renamingCentral;
        this.rules = new EquationalUnificationRule[] {
            new ATermIsVariable(),
            new EqualRootSymbols(termGenerator),
            new EqualTerms(),
            new UnifyModuloEquation(iterable, renamingCentral, trsTypes),
            new UnifyPolynomials()
        };
        for (Equation equation : iterable) {
            FunctionSymbol leftRootSymbol = equation.getLeftRootSymbol();
            FunctionSymbol rightRootSymbol = equation.getRightRootSymbol();
            this.eqSymbols.getNotNullAndAdd(leftRootSymbol).add(rightRootSymbol);
            this.eqSymbols.getNotNullAndAdd(rightRootSymbol).add(leftRootSymbol);
        }
    }

    /**
     * Advances a problem by one rule application.
     *
     * @param unificationProblem the problem to advance; it is not modified (a clone is used)
     * @return the successor problems, or {@code null} if no rule applies or a rule reported
     *     {@link EquationalUnificationRule.NoUnifierException}
     */
    private Set<UnificationProblem> oneStep(UnificationProblem unificationProblem) throws AbortionException {
        // applyOneRule removes the entry it rewrites, so operate on a copy and keep the
        // caller's problem intact.
        UnificationProblem base = unificationProblem.m146clone();
        try {
            Optional<Set<EquationalUnificationRule.Result>> applied = applyOneRule(base);
            if (!applied.isPresent()) {
                return null;
            }
            Set<UnificationProblem> successors = new LinkedHashSet<>();
            Set<EquationalUnificationRule.Result> results = applied.get();
            if (results.isEmpty()) {
                // The rule consumed the entry without producing anything new.
                successors.add(base);
                return successors;
            }
            for (EquationalUnificationRule.Result result : results) {
                // Every result is an independent alternative. Derive each one from the same base
                // so that one alternative's refinement never leaks into another.
                UnificationProblem branch = base;
                if (result.needsRefinement()) {
                    TRSSubstitution refinement = result.getRefinement();
                    branch = base.m146clone();
                    Iterator<UnificationProblem.Entry> entries = branch.iterator();
                    while (entries.hasNext()) {
                        entries.next().applySubstitution(refinement);
                    }
                }
                // NOTE(review): if union() mutates its receiver, non-refined alternatives sharing
                // `base` would interfere with each other — confirm union() returns a fresh problem.
                if (result.getNewProblem().isPresent()) {
                    successors.add(branch.union(result.getNewProblem().get()));
                } else {
                    successors.add(branch);
                }
            }
            return successors;
        } catch (EquationalUnificationRule.NoUnifierException e) {
            // This branch has no unifier; report it the same way as "no rule applicable".
            return null;
        }
    }

    /**
     * Tries the rules in order. For each rule, it tries every entry of the problem in iteration
     * order. On the first rule application that yields a present result, the corresponding entry
     * is removed from {@code unificationProblem} (which is therefore mutated) and that result is
     * returned.
     *
     * @param unificationProblem the problem to rewrite; mutated on success
     * @return the rule's results, or {@link Optional#empty()} if no rule applies to any entry
     * @throws EquationalUnificationRule.NoUnifierException propagated from a rule
     */
    private Optional<Set<EquationalUnificationRule.Result>> applyOneRule(UnificationProblem unificationProblem) throws EquationalUnificationRule.NoUnifierException {
        for (EquationalUnificationRule rule : this.rules) {
            Iterator<UnificationProblem.Entry> it = unificationProblem.iterator();
            while (it.hasNext()) {
                UnificationProblem.Entry entry = it.next();
                Optional<Set<EquationalUnificationRule.Result>> outcome = rule.apply(entry.getS(), entry.getT(), unificationProblem);
                if (outcome.isPresent()) {
                    it.remove();
                    return outcome;
                }
            }
        }
        return Optional.empty();
    }

    /**
     * Runs the round-based search over {@code set}.
     *
     * <p>{@code set} is mutated: processed problems are removed and their successors are added.
     *
     * @param set the pending problems
     * @return the first solution found, or {@code null} if a round makes no progress and no
     *     irreducible problem produced a solution
     */
    private TRSSubstitution unify(Set<UnificationProblem> set) throws AbortionException {
        boolean progressed;
        do {
            progressed = false;
            Set<UnificationProblem> successors = new LinkedHashSet<>();
            Iterator<UnificationProblem> it = set.iterator();
            while (it.hasNext()) {
                UnificationProblem problem = it.next();
                Set<UnificationProblem> next = oneStep(problem);
                it.remove();
                if (next == null) {
                    TRSSubstitution solution = problem.getSolution();
                    if (solution != null) {
                        return solution;
                    }
                } else {
                    progressed = true;
                    successors.addAll(next);
                }
            }
            set.addAll(successors);
        } while (progressed);
        return null;
    }

    /**
     * Searches for a unifier of the two terms using this unifier's rules.
     *
     * @return a unifying substitution, or {@code null} if none was found
     */
    public TRSSubstitution unify(TRSTerm tRSTerm, TRSTerm tRSTerm2) throws AbortionException {
        Set<UnificationProblem> pending = new LinkedHashSet<>();
        pending.add(new UnificationProblem(tRSTerm, tRSTerm2));
        return unify(pending);
    }

    /**
     * Matches {@code tRSTerm} against {@code tRSTerm2}.
     *
     * <p>The variables of {@code tRSTerm2} are replaced by fresh constants before unification, so
     * they are treated as fixed. The constants are mapped back to the original variables in the
     * returned substitution.
     *
     * @return the matching substitution, or {@code null} if unification fails
     */
    public TRSSubstitution match(TRSTerm tRSTerm, TRSTerm tRSTerm2) throws AbortionException {
        BidirectionalMap<TRSTerm, TRSTerm> frozen = this.renamingCentral.mapVariablesToFreshConstants(tRSTerm2.getVariables());
        TRSSubstitution unifier = unify(tRSTerm, tRSTerm2.replaceAll(frozen.getLRMap()));
        if (unifier == null) {
            return null;
        }
        Map<TRSVariable, TRSTerm> thawed = new LinkedHashMap<>();
        for (Map.Entry<TRSVariable, ? extends TRSTerm> entry : unifier.toMap().entrySet()) {
            // Undo the freezing: replace the fresh constants with the original variables.
            thawed.put(entry.getKey(), entry.getValue().replaceAll(frozen.getRLMap()));
        }
        return TRSSubstitution.create(ImmutableCreator.create(thawed));
    }

    /**
     * Returns whether {@link #match} finds a substitution for the two terms.
     *
     * <p>NOTE(review): this method declares no {@code throws} clause although {@link #match}
     * declares {@code AbortionException}. It only compiles if that exception is unchecked —
     * confirm.
     */
    public boolean matches(TRSTerm tRSTerm, TRSTerm tRSTerm2) {
        return match(tRSTerm, tRSTerm2) != null;
    }
}
