package aprove.Framework.Haskell.Typing;

import aprove.Framework.Haskell.BasicTerms.Apply;
import aprove.Framework.Haskell.BasicTerms.BasicTerm;
import aprove.Framework.Haskell.BasicTerms.Cons;
import aprove.Framework.Haskell.BasicTerms.Var;
import aprove.Framework.Haskell.Collectors.FreeVarSymCollector;
import aprove.Framework.Haskell.Declarations.DataDecl;
import aprove.Framework.Haskell.Declarations.PatDecl;
import aprove.Framework.Haskell.Declarations.SynTypeDecl;
import aprove.Framework.Haskell.Expressions.HaskellExp;
import aprove.Framework.Haskell.Expressions.LetExp;
import aprove.Framework.Haskell.Expressions.TypeExp;
import aprove.Framework.Haskell.Function;
import aprove.Framework.Haskell.HaskellError;
import aprove.Framework.Haskell.HaskellNamedSym;
import aprove.Framework.Haskell.HaskellObject;
import aprove.Framework.Haskell.HaskellRule;
import aprove.Framework.Haskell.HaskellSubstitution;
import aprove.Framework.Haskell.HaskellSym;
import aprove.Framework.Haskell.HaskellVisitor;
import aprove.Framework.Haskell.Literals.CharLit;
import aprove.Framework.Haskell.Literals.FloatLit;
import aprove.Framework.Haskell.Literals.IntegerLit;
import aprove.Framework.Haskell.Modules.EntityFrame;
import aprove.Framework.Haskell.Modules.Group;
import aprove.Framework.Haskell.Modules.HaskellEntity;
import aprove.Framework.Haskell.Modules.Module;
import aprove.Framework.Haskell.Modules.Modules;
import aprove.Framework.Haskell.Modules.Prelude;
import aprove.Framework.Haskell.Patterns.PlusPat;
import aprove.Framework.Haskell.Substitutors.VarSubstitutor;
import aprove.Framework.Utility.Copy;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import java.util.Vector;

/* loaded from: input_file:aprove/Framework/Haskell/Typing/TypeInferenceVisitor.class */
/**
 * Visitor that computes type schemas for Haskell entities and expressions.
 *
 * <p>Type schemas are kept on the stack inherited from {@link OmegaVisitor} ({@code push},
 * {@code peek}, {@code pop}). Substitutions produced by unification are accumulated in
 * {@link #currentRefine} and only applied to the whole visitor state when
 * {@link #delayedRefine()} is called.
 *
 * <p>NOTE(review): the exact contracts of {@code NoQuanStack}, {@code Assumptions} and
 * {@code ClassConstraintGraph} are not visible in this file; comments that rely on them are
 * phrased as assumptions.
 */
public class TypeInferenceVisitor extends OmegaVisitor {
    /** Entity sorts accepted by {@link #filterEntities(Set)} and {@link #guardValue(HaskellEntity)}. */
    Set<HaskellEntity.Sort> VALUESOF;
    /** Class constraint graph used to reduce, check and solve class constraints. */
    ClassConstraintGraph ccg;
    /** Entity frames of the rules currently being visited (pushed/popped around each rule). */
    Stack<EntityFrame> arguments;
    /** Candidate types tried, in order, when defaulting ambiguous constraints. */
    List<Cons> defaultList;
    /** Every object annotated by {@link #leave(HaskellObject)}; their type terms are rewritten on refinement. */
    Vector<HaskellObject> typeAnnos;
    /** Accumulated wall-clock milliseconds spent in the second pass of {@link #localGroup(Group, boolean)}. */
    private static long forTimeSum = 0;
    /** Presumably tracks the type variables that must not be quantified in the current scopes. */
    NoQuanStack noQuanStack;
    /** One constraint set per active {@link #localGrouping(List, boolean)} scope. */
    Stack<Set<ClassConstraint>> constraintStack;
    /** Pending substitution that has not been applied yet, or {@code null} if there is none. */
    HaskellSubstitution currentRefine;

    /**
     * Creates a visitor with empty stacks and no default list.
     *
     * @param prelude the prelude, forwarded to the superclass
     * @param assumptions the type assumptions, forwarded to the superclass
     * @param classConstraintGraph the class constraint graph to use
     */
    public TypeInferenceVisitor(Prelude prelude, Assumptions assumptions, ClassConstraintGraph classConstraintGraph) {
        super(assumptions, prelude);
        this.VALUESOF = EnumSet.of(HaskellEntity.Sort.VAR, HaskellEntity.Sort.IVAR, HaskellEntity.Sort.PATDECL);
        this.typeAnnos = new Vector<>();
        this.noQuanStack = new NoQuanStack();
        this.constraintStack = new Stack<>();
        this.ccg = classConstraintGraph;
        this.arguments = new Stack<>();
        this.defaultList = null;
    }

    /**
     * Returns a fresh instance of the entity's type schema: of its own type if it has one,
     * otherwise of the current assumption for it.
     */
    @Override // aprove.Framework.Haskell.Typing.OmegaVisitor
    public TypeSchema getTypeSchema(HaskellEntity haskellEntity) {
        if (haskellEntity.getType() != null) {
            return ((TypeSchema) haskellEntity.getType()).getFreshInstance();
        }
        return getAssumptionFor(haskellEntity).getFreshInstance();
    }

    /**
     * Returns a fresh instance of the assumed type schema for the entity. If no assumption
     * exists yet, one consisting of a fresh type variable is created and registered first.
     */
    public TypeSchema getAssumptionFor(HaskellEntity haskellEntity) {
        TypeSchema assumed = this.assumptions.getTypeSchemaFor(haskellEntity);
        if (assumed == null) {
            assumed = TypeSchema.create(Var.createFreshVar());
            this.assumptions.pushAssumption(haskellEntity, assumed);
        }
        return assumed.getFreshInstance();
    }

    /**
     * Unifies two type terms after bringing them up to date with the pending substitution.
     *
     * @param basicTerm first term
     * @param basicTerm2 second term
     * @param haskellObject object the error is reported for when unification fails
     * @return the unifier, or {@code null} if the terms are not unifiable (an error is
     *         reported in that case)
     */
    @Override // aprove.Framework.Haskell.Typing.OmegaVisitor
    public HaskellSubstitution mgu(BasicTerm basicTerm, BasicTerm basicTerm2, HaskellObject haskellObject) {
        BasicTerm left;
        BasicTerm right;
        if (this.currentRefine == null) {
            // Nothing pending: work on deep copies of the inputs.
            left = (BasicTerm) Copy.deep(basicTerm);
            right = (BasicTerm) Copy.deep(basicTerm2);
        } else {
            left = this.currentRefine.applyTo(basicTerm);
            right = this.currentRefine.applyTo(basicTerm2);
        }
        HaskellSubstitution unifier = BasicTerm.Tools.mgu(left, right);
        if (unifier == null) {
            System.err.println("mgu1:" + left);
            System.err.println("mgu2:" + right);
            HaskellSym.showee(new Apply(left, right));
            HaskellError.output(haskellObject, "types are not unifiable");
        }
        return unifier;
    }

    /**
     * Tells whether the group is "restricted": it contains a pattern declaration, or a
     * variable bound to a simple-pattern function that has no declared type.
     *
     * <p>NOTE(review): this looks like the declaration-group condition of Haskell's
     * monomorphism restriction - confirm.
     */
    public boolean restrictedGroup(Set<HaskellEntity> set) {
        for (HaskellEntity entity : set) {
            if (entity.getSort() == HaskellEntity.Sort.PATDECL) {
                return true;
            }
            boolean untypedSimpleBinding = entity.getSort() == HaskellEntity.Sort.VAR
                    && (entity.getValue() instanceof Function)
                    && ((Function) entity.getValue()).isSimplePattern()
                    && entity.getType() == null;
            if (untypedSimpleBinding) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tries to resolve one group of ambiguous constraints by defaulting.
     *
     * <p>Defaulting is only attempted when every constraint is in the prelude and at least one
     * is a numeric subclass. The candidates of {@link #defaultList} are tried in order; the
     * first one that solves <em>all</em> constraints is chosen.
     *
     * @param set the constraints of one ambiguous group
     * @return a substitution binding the type of the first iterated constraint to a deep copy
     *         of the chosen default type, or {@code null} if defaulting is not applicable or no
     *         candidate solves all constraints
     */
    public HaskellSubstitution solveAmbiguousConstraints(Set<ClassConstraint> set) {
        boolean hasNumericClass = false;
        boolean allInPrelude = true;
        for (ClassConstraint constraint : set) {
            allInPrelude = allInPrelude && constraint.isInPrelude();
            if (constraint.isNumSubClass(this.ccg)) {
                hasNumericClass = true;
            }
        }
        if (!hasNumericClass || !allInPrelude) {
            return null;
        }
        for (Cons candidate : this.defaultList) {
            // A candidate is only acceptable if it satisfies every constraint of the group;
            // otherwise the next entry of the default list has to be tried.
            boolean solvesAll = true;
            for (ClassConstraint constraint : set) {
                if (!constraint.solvedBy(candidate, this.ccg.getRules())) {
                    solvesAll = false;
                    break;
                }
            }
            if (solvesAll) {
                return new HaskellSubstitution((Var) set.iterator().next().getType(), (BasicTerm) Copy.deep(candidate));
            }
        }
        return null;
    }

    /**
     * Defaults every group of ambiguous constraints, reporting an error for each group that
     * cannot be defaulted.
     *
     * <p>The passed set is consumed: groups are removed from it as they are processed.
     *
     * @param haskellObject object the errors are reported for
     * @param set groups of ambiguous constraints
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void solveAllAmbiguousConstraints(HaskellObject haskellObject, Set<Set<ClassConstraint>> set) {
        delayedRefine();
        while (!set.isEmpty()) {
            Set<ClassConstraint> current = set.iterator().next();
            set.remove(current);
            HaskellSubstitution defaulting = solveAmbiguousConstraints(current);
            if (defaulting == null) {
                HaskellError.output(haskellObject, " " + current + " Constraint contains ambiguous type variable");
                continue;
            }
            directRefine(defaulting);
            // Rebuild the remaining groups with the chosen default applied. Raw sets are used
            // because the static return type of ClassConstraint.apply is not visible here.
            Set rewrittenGroups = new HashSet();
            for (Set<ClassConstraint> pending : set) {
                Set rewritten = new HashSet();
                for (ClassConstraint constraint : pending) {
                    rewritten.add(constraint.apply(defaulting));
                }
                rewrittenGroups.add(rewritten);
            }
            set = rewrittenGroups;
        }
    }

    /**
     * Sets the default list, falling back to the prelude's default list when {@code list} is
     * {@code null}.
     */
    public void setDefaultList(List<Cons> list) {
        this.defaultList = list;
        if (this.defaultList == null) {
            this.defaultList = this.prelude.getDefaultList();
        }
    }

    /**
     * Infers the type schemas of all entities of one binding group and registers them as
     * assumptions.
     *
     * <p>Prelude groups and already loaded groups are skipped unless {@code z} is set. The
     * work is done in two passes: the first infers a schema per entity and collects its
     * constraints, the second matches the schemas against declared types (if any) and
     * quantifies them.
     *
     * @param group the binding group
     * @param z when {@code true}, the group is processed even if it is a prelude or already
     *          loaded group
     */
    public void localGroup(Group group, boolean z) {
        delayedRefine();
        if (!z && (group.isPreludeGroup() || group.isAlreadyLoadedGroup())) {
            return;
        }
        if (group.isMultiGroup()) {
            HaskellError.output((HaskellObject) group.iterator().next(), "mutual recursive block " + group + " overlaps module borders");
        }
        setDefaultList(group.getGroupModule().getDefaultList());
        boolean restrictedGroup = restrictedGroup(group);

        // Pass 1: infer a schema for each entity and record it as the current assumption.
        for (Object member : group) {
            HaskellEntity entity = (HaskellEntity) member;
            TypeSchema inferred;
            if (entity.getValue() != null) {
                this.noQuanStack.pushNewGroup();
                this.noQuanStack.addHoToPeekGroup(getAssumptionFor(entity));
                entity.visit(this);
                push(getAssumptionFor(entity));
                push(massMgu(2, entity));
                delayedRefine();
                TypeSchema top = peek();
                this.noQuanStack.popGroup();
                if (restrictedGroup) {
                    // Restricted groups hand all their constraints to the enclosing scope.
                    this.constraintStack.peek().addAll(top.getConstraints());
                } else {
                    reduce(top.getConstraints());
                    this.constraintStack.peek().addAll(removeSurroundingConstraints(top, this.noQuanStack.unitedGroups()));
                    solveAllAmbiguousConstraints(entity.getValue(), top.ambiguousConstraints(true));
                    reduce(top.getConstraints());
                    this.ccg.checkConstraints(top.getConstraints(), entity);
                }
                inferred = pop();
            } else {
                // No value to infer from: fall back to the entity's own type.
                inferred = (TypeSchema) entity.getType();
            }
            this.assumptions.pushAssumption(entity, inferred);
        }

        delayedRefine();
        reduce(this.constraintStack.peek());
        delayedRefine();
        if (restrictedGroup) {
            for (ClassConstraint constraint : this.constraintStack.peek()) {
                this.noQuanStack.addHoToPeekGroup(constraint);
            }
        }

        // Pass 2: specialize against declared types and quantify.
        long start = System.currentTimeMillis();
        for (Object member : group) {
            HaskellEntity entity = (HaskellEntity) member;
            delayedRefine();
            TypeSchema schema = this.assumptions.getTypeSchemaFor(entity);
            reduce(schema.getConstraints());
            if (entity.getType() != null) {
                push(schema);
                specializeWith(entity, (TypeSchema) entity.getType());
                schema = pop();
            }
            // Called for its effect on the schema; the boolean result is not needed.
            schema.autoQuantor(this.noQuanStack.unitedGroups());
            this.assumptions.pushAssumption(entity, schema);
            if (!restrictedGroup) {
                this.constraintStack.peek().addAll(getFreeConstraints(schema));
            }
        }
        delayedRefine();
        forTimeSum += System.currentTimeMillis() - start;
    }

    /**
     * Processes the given groups in order inside a new constraint scope.
     *
     * @param list the groups to process
     * @param z forwarded to {@link #localGroup(Group, boolean)}
     * @return the constraints collected in the scope opened by this call
     */
    public Set<ClassConstraint> localGrouping(List<Group> list, boolean z) {
        this.constraintStack.push(new HashSet<ClassConstraint>());
        this.noQuanStack.pushNewGroup();
        for (Group group : list) {
            localGroup(group, z);
        }
        this.noQuanStack.popGroup();
        return this.constraintStack.pop();
    }

    /**
     * Returns the constraints of the schema that mention at least one symbol which is not
     * quantified by the schema. The schema itself is not modified.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public Set<ClassConstraint> getFreeConstraints(TypeSchema typeSchema) {
        Set<ClassConstraint> free = new HashSet<ClassConstraint>();
        for (ClassConstraint constraint : typeSchema.getConstraints()) {
            // Raw set: the element type expected by FreeVarSymCollector is not visible here.
            HashSet symbols = new HashSet();
            constraint.visit(new FreeVarSymCollector(symbols));
            symbols.removeAll(typeSchema.getQuantor());
            if (!symbols.isEmpty()) {
                free.add(constraint);
            }
        }
        return free;
    }

    /**
     * Removes from the schema, and returns, every constraint that mentions at least one symbol
     * and whose symbols are all contained in {@code set}.
     *
     * @param typeSchema schema whose constraint set is modified in place
     * @param set symbols of the surrounding scopes
     * @return the removed constraints
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public Set<ClassConstraint> removeSurroundingConstraints(TypeSchema typeSchema, Set<HaskellSym> set) {
        Set<ClassConstraint> removed = new HashSet<ClassConstraint>();
        Iterator<ClassConstraint> it = typeSchema.getConstraints().iterator();
        while (it.hasNext()) {
            ClassConstraint constraint = it.next();
            // Raw set: the element type expected by FreeVarSymCollector is not visible here.
            HashSet symbols = new HashSet();
            constraint.visit(new FreeVarSymCollector(symbols));
            if (symbols.isEmpty()) {
                continue;
            }
            symbols.removeAll(set);
            if (symbols.isEmpty()) {
                removed.add(constraint);
                it.remove();
            }
        }
        return removed;
    }

    /** Returns the entities of {@code set} whose sort is contained in {@link #VALUESOF}. */
    public Set<HaskellEntity> filterEntities(Set<HaskellEntity> set) {
        Set<HaskellEntity> accepted = new HashSet<HaskellEntity>();
        for (HaskellEntity entity : set) {
            if (this.VALUESOF.contains(entity.getSort())) {
                accepted.add(entity);
            }
        }
        return accepted;
    }

    /**
     * Builds a schema consisting of a fresh type variable constrained by the named type class.
     *
     * @param str name of the type class, resolved through the prelude
     */
    public TypeSchema buildConstraintedTyVar(String str) {
        Var typeVar = Var.createFreshVar();
        Set<ClassConstraint> constraints = new HashSet<ClassConstraint>();
        constraints.add(new ClassConstraint(this.prelude.createSymbolRef(str, HaskellEntity.Sort.TYCLASS), typeVar));
        return TypeSchema.create(constraints, typeVar);
    }

    /** Character literals get the type constructor {@code Char}. */
    @Override // aprove.Framework.Haskell.HaskellVisitor
    public HaskellObject caseCharLit(CharLit charLit) {
        push(TypeSchema.create(new Cons(this.prelude.createSymbolRef("Char", HaskellEntity.Sort.TYCONS))));
        return leave(charLit);
    }

    /** Float literals get a fresh type variable constrained by {@code Fractional}. */
    @Override // aprove.Framework.Haskell.HaskellVisitor
    public HaskellObject caseFloatLit(FloatLit floatLit) {
        push(buildConstraintedTyVar("Fractional"));
        return leave(floatLit);
    }

    /** Integer literals get a fresh type variable constrained by {@code Num}. */
    @Override // aprove.Framework.Haskell.HaskellVisitor
    public HaskellObject caseIntegerLit(IntegerLit integerLit) {
        push(buildConstraintedTyVar("Num"));
        return leave(integerLit);
    }

    /**
     * Plus patterns: a fresh {@code Integral}-constrained variable is pushed and combined with
     * the schemas already on the stack via {@code massMgu(3, ...)}.
     */
    @Override // aprove.Framework.Haskell.HaskellVisitor
    public HaskellObject casePlusPat(PlusPat plusPat) {
        push(buildConstraintedTyVar("Integral"));
        push(massMgu(3, plusPat));
        return leave(plusPat);
    }

    /**
     * Matches the schema on top of the stack against a type signature. The stack depth is
     * unchanged afterwards.
     *
     * <p>On success the matching substitution is applied and the top schema receives a deep
     * copy of the signature's constraints. Errors are reported if the match fails, or if a
     * quantified variable of the fresh signature copy also occurs in
     * {@code noQuanStack.unitedGroups()}.
     *
     * @param haskellObject object the mismatch error is reported for
     * @param typeSchema the signature to specialize to
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void specializeWith(HaskellObject haskellObject, TypeSchema typeSchema) {
        delayedRefine();
        TypeSchema inferred = pop();
        push(inferred);
        TypeSchema signature = typeSchema.getFreshCopy();
        Set<Object> signatureVars = new HashSet<Object>(signature.getQuantor());
        HaskellSubstitution match = inferred.match(signature, this.ccg);
        if (match == null) {
            HaskellSym.showee(new Apply(inferred, signature));
            HaskellError.output(haskellObject, ("         " + inferred + " II-->|>-- " + signature) + "  infered type " + inferred + " is not general enough for signature " + typeSchema);
        } else {
            directRefine(match);
            inferred.setConstraints((Set) Copy.deepCol(typeSchema.getConstraints()));
        }
        signatureVars.retainAll(this.noQuanStack.unitedGroups());
        if (!signatureVars.isEmpty()) {
            HaskellError.output((HaskellObject) signatureVars.iterator().next(), "infered type " + inferred + " is not general enough for signature " + typeSchema);
        }
    }

    /** Specializes the schema on top of the stack to the annotation of the type expression. */
    @Override // aprove.Framework.Haskell.HaskellVisitor
    public HaskellObject caseTypeExp(TypeExp typeExp) {
        specializeWith(typeExp, typeExp.getTypeSchema());
        return leave(typeExp);
    }

    /** Modules are returned unchanged. */
    @Override // aprove.Framework.Haskell.HaskellVisitor
    public HaskellObject caseModule(Module module) {
        return module;
    }

    /**
     * Infers the type schema of a single expression.
     *
     * @param haskellExp the expression to type
     * @param module module whose default list is used
     * @return the reduced, checked and quantified type schema of the expression
     */
    public HaskellObject forTerm(HaskellExp haskellExp, Module module) {
        setDefaultList(module.getDefaultList());
        haskellExp.visit(this);
        TypeSchema schema = pop();
        // NOTE(review): the result is discarded here, unlike in forModules where it is passed
        // to solveAllAmbiguousConstraints - confirm whether the call is needed for a side effect.
        schema.ambiguousConstraints(false);
        delayedRefine();
        reduce(schema.getConstraints());
        this.ccg.checkConstraints(schema.getConstraints(), haskellExp);
        schema.autoQuantor();
        return schema;
    }

    /** Types the prelude's tuple groups, but only if a main module is present. */
    public void forTuples(Modules modules) {
        if (modules.getMainModule() != null) {
            localGrouping(modules.getPrelude().getTupleGroups(), true);
        }
    }

    /**
     * Types all groups of the given modules, then defaults remaining ambiguous constraints and
     * quantifies the schema of every entity outside prelude and already loaded groups.
     */
    public void forModules(Modules modules) {
        forTuples(modules);
        List<Group> groups = modules.getGroups();
        localGrouping(groups, false);
        for (Group group : groups) {
            setDefaultList(group.getGroupModule().getDefaultList());
            if (group.isPreludeGroup() || group.isAlreadyLoadedGroup()) {
                continue;
            }
            for (Object member : group) {
                HaskellEntity entity = (HaskellEntity) member;
                delayedRefine();
                TypeSchema schema = this.assumptions.getTypeSchemaFor(entity);
                solveAllAmbiguousConstraints(entity.getValue(), schema.ambiguousConstraints(false));
                reduce(schema.getConstraints());
                schema.autoQuantor();
                this.assumptions.pushAssumption(entity, schema);
            }
        }
    }

    /**
     * Types the groups of a let expression and pushes a schema that carries the constraints
     * collected from them on a fresh type variable.
     */
    @Override // aprove.Framework.Haskell.HaskellVisitor
    public void icaseLetExp(LetExp letExp) {
        push(TypeSchema.create(localGrouping(letExp.getGroups(), false), Var.createFreshVar()));
    }

    /**
     * Pops the top schema, adds the constraints of the schema below it (which is popped as
     * well), and pushes the merged schema back.
     */
    @Override // aprove.Framework.Haskell.HaskellVisitor
    public HaskellObject caseLetExp(LetExp letExp) {
        delayedRefine();
        TypeSchema result = pop();
        result.getConstraints().addAll(pop().getConstraints());
        push(result);
        return leave(letExp);
    }

    /**
     * Entering a rule: pushes its entity frame and opens a new no-quantification group
     * containing the assumptions of all entities collected in that frame.
     */
    @Override // aprove.Framework.Haskell.HaskellVisitor
    public void fcaseHaskellRule(HaskellRule haskellRule) {
        this.arguments.push(haskellRule.getEntityFrame());
        this.noQuanStack.pushNewGroup();
        Iterator<HaskellEntity> it = haskellRule.getEntityFrame().getCollectedEntities().iterator();
        while (it.hasNext()) {
            this.noQuanStack.addHoToPeekGroup(getAssumptionFor(it.next()));
        }
    }

    /**
     * Leaving a rule: pushes the arrow schema built for the rule's pattern count and closes
     * what {@link #fcaseHaskellRule(HaskellRule)} opened.
     */
    @Override // aprove.Framework.Haskell.Typing.OmegaVisitor, aprove.Framework.Haskell.HaskellVisitor
    public HaskellObject caseHaskellRule(HaskellRule haskellRule) {
        push(toArrow(haskellRule.getPatterns().size()));
        this.arguments.pop();
        this.noQuanStack.popGroup();
        return leave(haskellRule);
    }

    /** Entities are returned unchanged. */
    @Override // aprove.Framework.Haskell.HaskellVisitor
    public HaskellObject caseEntity(HaskellEntity haskellEntity) {
        return haskellEntity;
    }

    /** Delegates to the prelude. */
    @Override // aprove.Framework.Haskell.Typing.OmegaVisitor
    public HaskellType buildArrow(HaskellType haskellType, HaskellType haskellType2) {
        return this.prelude.buildArrow(haskellType, haskellType2);
    }

    /** Delegates to the prelude. */
    @Override // aprove.Framework.Haskell.Typing.OmegaVisitor
    public TypeSchema getBoolTypeSchema() {
        return this.prelude.getBoolTypeSchema();
    }

    /**
     * Lets another visitor walk this visitor's own state: whatever the superclass exposes, all
     * constraints of all open scopes, and the type term of every annotated object (which is
     * replaced by the walk result).
     *
     * @return this visitor
     */
    @Override // aprove.Framework.Haskell.Typing.OmegaVisitor, aprove.Framework.Haskell.HaskellVisitor, aprove.Framework.Haskell.HaskellObject
    public HaskellObject visit(HaskellVisitor haskellVisitor) {
        super.visit(haskellVisitor);
        for (Set<ClassConstraint> scope : this.constraintStack) {
            for (ClassConstraint constraint : scope) {
                walk(constraint, haskellVisitor);
            }
        }
        for (HaskellObject annotated : this.typeAnnos) {
            annotated.setTypeTerm((HaskellType) walk(annotated.getTypeTerm(), haskellVisitor));
        }
        return this;
    }

    /** Applies the substitution immediately to the no-quantification stack and to this visitor's state. */
    public void directRefine(HaskellSubstitution haskellSubstitution) {
        this.noQuanStack.apply(haskellSubstitution);
        visit(new VarSubstitutor(haskellSubstitution));
    }

    /** Applies and clears the pending substitution, if there is one. */
    public void delayedRefine() {
        if (this.currentRefine != null) {
            directRefine(this.currentRefine);
            this.currentRefine = null;
        }
    }

    /**
     * Records a substitution for later application by combining it with the pending one.
     *
     * @return the pending substitution after the update
     */
    @Override // aprove.Framework.Haskell.Typing.OmegaVisitor
    public HaskellSubstitution refine(HaskellSubstitution haskellSubstitution) {
        if (this.currentRefine == null) {
            this.currentRefine = haskellSubstitution;
        } else {
            this.currentRefine = this.currentRefine.combineWith(haskellSubstitution);
        }
        return this.currentRefine;
    }

    /** Reduces the constraint set via the class constraint graph. */
    @Override // aprove.Framework.Haskell.Typing.OmegaVisitor
    public void reduce(Set<ClassConstraint> set) {
        this.ccg.reduce(set);
    }

    /**
     * Annotates the object with the matrix of the schema on top of the stack and remembers it
     * so that later refinements also update the annotation.
     */
    @Override // aprove.Framework.Haskell.Typing.OmegaVisitor
    public HaskellObject leave(HaskellObject haskellObject) {
        this.typeAnnos.add(haskellObject);
        haskellObject.setTypeTerm(peek().getMatrix());
        return haskellObject;
    }

    @Override // aprove.Framework.Haskell.HaskellVisitor
    public boolean guardEntity(HaskellEntity haskellEntity) {
        return true;
    }

    @Override // aprove.Framework.Haskell.HaskellVisitor
    public boolean guardEntities(Module module) {
        return false;
    }

    /** Accepts only entities whose sort is in {@link #VALUESOF}. */
    @Override // aprove.Framework.Haskell.HaskellVisitor
    public boolean guardValue(HaskellEntity haskellEntity) {
        return this.VALUESOF.contains(haskellEntity.getSort());
    }

    @Override // aprove.Framework.Haskell.HaskellVisitor
    public boolean guardLetFrame(LetExp letExp) {
        return false;
    }

    @Override // aprove.Framework.Haskell.HaskellVisitor
    public boolean guardType(HaskellEntity haskellEntity) {
        return false;
    }

    @Override // aprove.Framework.Haskell.HaskellVisitor
    public boolean guardMember(HaskellEntity haskellEntity) {
        return true;
    }

    @Override // aprove.Framework.Haskell.HaskellVisitor
    public boolean guardHaskellNamedSym(HaskellNamedSym haskellNamedSym) {
        return false;
    }

    @Override // aprove.Framework.Haskell.HaskellVisitor
    public boolean guardDefType(SynTypeDecl synTypeDecl) {
        return false;
    }

    @Override // aprove.Framework.Haskell.HaskellVisitor
    public boolean guardConss(DataDecl dataDecl) {
        return false;
    }

    @Override // aprove.Framework.Haskell.HaskellVisitor
    public boolean guardTypeTypeExp(TypeExp typeExp) {
        return false;
    }

    @Override // aprove.Framework.Haskell.HaskellVisitor
    public boolean guardPatDecl(PatDecl patDecl) {
        return false;
    }
}
