package fove;

import blog.FuncAppTerm;
import blog.LogicalVar;
import blog.Substitution;
import blog.Term;
import common.Util;
import fove.Constraint;
import fove.Parfactor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import test.Statistics;

/* loaded from: input_file:fove/SummingOut.class */
/**
 * Lifted inference operator that eliminates a target term: the parfactors
 * containing an occurrence of the term are multiplied together, the term's
 * dimension is summed out of the product, and the result is put back into the
 * working set of parfactors.
 *
 * <p>Instances are created only through {@link #opFactory(Set, ElimTester)}.
 */
public class SummingOut extends LiftedInfOperator {
    /** Working set of parfactors; modified in place by {@link #operate()}. */
    private final Set<Parfactor> parfactors;

    /**
     * Pointers to the term occurrences to be summed out. The first entry is
     * used as the representative occurrence throughout this class.
     */
    private final List<Parfactor.TermPtr> targetPtrs;

    /** Parfactors to multiply; built lazily by {@link #initPfsToMultiply()}, {@code null} until then. */
    private List<Parfactor> pfsToMultiply = null;

    private SummingOut(Set<Parfactor> set, List<Parfactor.TermPtr> list) {
        this.parfactors = set;
        this.targetPtrs = list;
    }

    /**
     * Returns the term referenced by the first target pointer.
     *
     * @return the representative term being summed out
     */
    public Term var() {
        return this.targetPtrs.get(0).term();
    }

    /**
     * Computes the cost of this operation as the sum, over the distinct
     * dimension terms of the parfactors to be multiplied, of the natural log of
     * the size of each term's type range.
     *
     * @return the accumulated log cost
     */
    @Override
    public double logCost() {
        initPfsToMultiply();

        // Collect each dimension term once, even if several parfactors share it.
        Set<Term> distinctTerms = new HashSet<Term>();
        for (Parfactor pf : this.pfsToMultiply) {
            distinctTerms.addAll(pf.dimTerms());
        }

        double cost = 0.0d;
        for (Term dimTerm : distinctTerms) {
            cost += Math.log(dimTerm.getType().range().size());
        }
        return cost;
    }

    /**
     * Performs the summing-out: removes the target parfactors from the working
     * set, multiplies their substituted versions, sums out the target term,
     * exponentiates on every logical variable reported as unused, and adds the
     * result back to the working set if it still has dimension terms.
     *
     * @throws IllegalStateException if the target term cannot be found among
     *         the dimension terms of the product
     */
    @Override
    public void operate() {
        Statistics statistics = Statistics.getInstance();
        statistics.incrOps();
        statistics.incrOpsSO();

        for (Parfactor.TermPtr target : this.targetPtrs) {
            this.parfactors.remove(target.parfactor());
        }

        initPfsToMultiply();
        Parfactor product = Parfactor.multiply(this.pfsToMultiply);

        // Look the target term up in the first (substituted) parfactor, then
        // locate the same term among the product's dimensions.
        Term targetTerm = this.pfsToMultiply.get(0).dimTerms().get(this.targetPtrs.get(0).index());
        int targetIndex = product.dimTerms().indexOf(targetTerm);
        if (targetIndex == -1) {
            throw new IllegalStateException("Target term " + targetTerm + " is not in product " + product);
        }

        Parfactor result = product.sumOut(targetIndex);
        // The set of unused variables is taken once from the parfactor obtained
        // right after summing out; 'result' is then reassigned inside the loop.
        for (LogicalVar unusedVar : result.getUnusedVars()) {
            if (Util.verbose()) {
                System.out.println("\tApplying op: Exponentiation on logvar " + unusedVar + ", in parfactor:" + result);
            }
            result = result.exponentiate(unusedVar);
        }

        if (result.dimTerms().size() > 0) {
            this.parfactors.add(result);
        }
    }

    /**
     * Renders this operator as {@code SummingOut(<term> : <constraint>)} using
     * the first target pointer's term and its parfactor's constraint.
     */
    @Override
    public String toString() {
        Parfactor.TermPtr first = this.targetPtrs.get(0);
        StringBuilder text = new StringBuilder();
        text.append("SummingOut(");
        text.append(first.term());
        text.append(" : ");
        text.append(first.parfactor().constraint());
        text.append(")");
        return text.toString();
    }

    /**
     * Lazily builds {@link #pfsToMultiply}. For every target pointer, a
     * substitution is constructed from the argument variables of that
     * pointer's term and the argument variables of the first target's term,
     * and applied to the pointer's parfactor.
     *
     * <p>NOTE(review): presumably this renames each parfactor's variables so
     * that all target occurrences become the same term; confirm against
     * {@code Substitution} and {@code Parfactor.applySubstitution}.
     */
    private void initPfsToMultiply() {
        if (this.pfsToMultiply != null) {
            return;
        }

        // Copy of the first target's argument variables, viewed as terms, used
        // as the replacement list of every substitution below.
        List<LogicalVar> firstArgVars = getArgVars(this.targetPtrs.get(0).term());
        List<Term> replacements = new ArrayList<Term>(firstArgVars.size());
        for (LogicalVar argVar : firstArgVars) {
            replacements.add(argVar);
        }

        this.pfsToMultiply = new ArrayList<Parfactor>(this.targetPtrs.size());
        for (Parfactor.TermPtr target : this.targetPtrs) {
            Substitution renaming = new Substitution(getArgVars(target.term()), replacements);
            this.pfsToMultiply.add(target.parfactor().applySubstitution(renaming));
        }
    }

    /**
     * Creates the summing-out operators applicable to the given parfactors.
     * Every dimension term of every parfactor is considered, skipping term
     * occurrences already recorded as overlapping a previously examined
     * candidate and terms rejected by {@code elimTester}.
     *
     * @param set        the working set of parfactors
     * @param elimTester decides whether a term (under its parfactor's
     *                   constraint) should be eliminated
     * @return the operators that could be constructed; possibly empty
     */
    public static Collection<LiftedInfOperator> opFactory(Set<Parfactor> set, ElimTester elimTester) {
        List<LiftedInfOperator> ops = new ArrayList<LiftedInfOperator>();
        // Term occurrences found to overlap some candidate already tried.
        Set<Parfactor.TermPtr> covered = new HashSet<Parfactor.TermPtr>();

        for (Parfactor parfactor : set) {
            List<? extends Term> dimTerms = parfactor.dimTerms();
            Constraint constraint = parfactor.constraint();
            for (int i = 0; i < dimTerms.size(); i++) {
                Parfactor.TermPtr candidate = parfactor.termPtr(i);
                if (covered.contains(candidate)) {
                    continue;
                }
                if (!elimTester.shouldEliminate(candidate.term(), constraint)) {
                    continue;
                }
                SummingOut op = tryMakeOpForTerm(set, candidate, covered);
                if (op != null) {
                    ops.add(op);
                }
            }
        }
        return ops;
    }

    /**
     * Tries to build a summing-out operator for the term occurrence
     * {@code termPtr}. Returns {@code null} when any of the checks below
     * fails. Overlapping occurrences discovered along the way are added to
     * {@code set2}, including on paths that end up returning {@code null}.
     *
     * @param set     the working set of parfactors
     * @param termPtr the candidate term occurrence
     * @param set2    accumulator of term occurrences that overlap the candidate
     * @return a new operator, or {@code null} if the term cannot be summed out
     */
    private static SummingOut tryMakeOpForTerm(Set<Parfactor> set, Parfactor.TermPtr termPtr, Set<Parfactor.TermPtr> set2) {
        Parfactor owner = termPtr.parfactor();
        int ownIndex = termPtr.index();
        Term term = termPtr.term();
        Constraint constraint = owner.constraint();

        // The term's argument variables must include all of the owning
        // parfactor's logical variables.
        if (!getArgVars(term).containsAll(owner.logicalVars())) {
            return null;
        }

        // No other dimension term of the owning parfactor may overlap the target.
        List<? extends Term> ownDimTerms = owner.dimTerms();
        for (int i = 0; i < ownDimTerms.size(); i++) {
            if (i != ownIndex && Constraint.getOverlap(term, constraint, ownDimTerms.get(i), constraint) != null) {
                set2.add(owner.termPtr(i));
                return null;
            }
        }

        List<Parfactor.TermPtr> targets = new ArrayList<Parfactor.TermPtr>();
        targets.add(termPtr);

        for (Parfactor other : set) {
            if (other == owner) {
                continue;
            }
            Constraint otherConstraint = other.constraint();
            List<? extends Term> otherDimTerms = other.dimTerms();
            // Set once an overlapping term has been accepted in 'other' while
            // other.nslogicalVars() is non-empty; any further overlap in the
            // same parfactor then aborts the attempt.
            boolean acceptedWithNsVars = false;

            for (int j = 0; j < otherDimTerms.size(); j++) {
                Constraint.Overlap overlap = Constraint.getOverlap(term, constraint, otherDimTerms.get(j), otherConstraint);
                if (overlap == null) {
                    continue;
                }
                Parfactor.TermPtr otherPtr = other.termPtr(j);
                set2.add(otherPtr);

                // The checks run in this order and stop at the first failure.
                if (acceptedWithNsVars) {
                    return null;
                }
                if (!overlap.isFull()) {
                    return null;
                }
                if (!getArgVars(otherPtr.term()).containsAll(other.nslogicalVars())) {
                    return null;
                }
                if (!otherPtr.term().getSubstResult(overlap.theta()).equals(term.getSubstResult(overlap.theta()))) {
                    return null;
                }

                if (other.nslogicalVars().size() > 0) {
                    acceptedWithNsVars = true;
                }
                targets.add(otherPtr);
            }
        }
        return new SummingOut(set, targets);
    }

    /**
     * Returns the logical variables that appear directly as arguments of the
     * given term, in argument order. For a {@code CountingTerm} the result is
     * computed from its single sub-term.
     *
     * @param term a {@code FuncAppTerm} or {@code CountingTerm}
     * @return a new list of the term's logical-variable arguments
     * @throws IllegalArgumentException if the term is of any other class
     */
    private static List<LogicalVar> getArgVars(Term term) {
        if (term instanceof FuncAppTerm) {
            List<LogicalVar> argVars = new ArrayList<LogicalVar>();
            for (Term arg : ((FuncAppTerm) term).getArgs()) {
                if (arg instanceof LogicalVar) {
                    argVars.add((LogicalVar) arg);
                }
            }
            return argVars;
        }
        if (term instanceof CountingTerm) {
            return getArgVars(((CountingTerm) term).singleSubTerm());
        }
        throw new IllegalArgumentException("Can't get argument variables in term of class " + term.getClass());
    }
}
