package fove;

import blog.ArgSpec;
import blog.ArgSpecQuery;
import blog.AtomicFormula;
import blog.BasicVar;
import blog.BayesNetVar;
import blog.ConjFormula;
import blog.Evidence;
import blog.Formula;
import blog.FormulaQuery;
import blog.FuncAppTerm;
import blog.Function;
import blog.InferenceEngine;
import blog.LogicalVar;
import blog.Model;
import blog.NonRandomFunction;
import blog.Query;
import blog.RandFuncAppVar;
import blog.Term;
import blog.Type;
import common.Util;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import test.Statistics;
import ve.Factor;
import ve.Potential;

/* loaded from: input_file:fove/LiftedVarElim.class */
/**
 * Inference engine that answers queries by lifted variable elimination over parfactors.
 *
 * <p>The model is converted once into a {@link ParMarkovNet}. Evidence is encoded as delta
 * parfactors: unary evidence that shares a function and an observed value is grouped into a
 * single parfactor over a fresh logical variable. Each query is answered by repeatedly
 * applying {@link LiftedInfOperator}s chosen by a {@link LiftedInfOperatorOracle} until no
 * eliminable dimension remains. The remaining parfactors are then multiplied and normalized.
 *
 * <p>Engine properties read: {@code ondemand} (defer shattering on the query terms until
 * elimination first reaches one of them) and {@code parameterised} (use an overlap-based
 * check instead of exact membership when verifying that only query terms remain; the
 * intended semantics of that mode should be confirmed).
 */
public class LiftedVarElim extends InferenceEngine {
    /** Parfactor representation of the model; built once in the constructor. */
    protected ParMarkovNet pmn;

    /** Engine properties; the keys "ondemand" and "parameterised" are read. */
    private Properties props;

    /**
     * Whether queries are answered with {@link #liftedElimOnDemand}. The constructor
     * overwrites this initial value with the "ondemand" property, which yields
     * {@code false} when the property is absent.
     */
    boolean ondemand = true;

    /** Parfactors encoding the current evidence; empty until evidence is set. */
    protected List<Parfactor> evidenceParfactors = Collections.emptyList();

    /** Each query, in query order, mapped to the function-application terms it asks about. */
    protected Map<Query, Set<FuncAppTerm>> conjQueries = Collections.emptyMap();

    /**
     * Default oracle. It greedily picks the valid operator with the smallest
     * {@code logCost()} and keeps the first one examined on ties.
     *
     * <p>NOTE(review): a {@link CountConversion} whose counted logical variable is named "Y"
     * is never preferred over an already-chosen operator. It can only be picked when it is the
     * first operator examined. This looks like a leftover debugging heuristic. It is kept
     * unchanged here; confirm it before relying on it.
     */
    private static final LiftedInfOperatorOracle GREEDY_OPERATOR_ORACLE = new LiftedInfOperatorOracle() {
        @Override
        public LiftedInfOperator nextOperator(Set<Parfactor> set, ElimTester elimTester) {
            Collection<LiftedInfOperator> validOps = LiftedInfOperator.validOps(set, elimTester);
            if (validOps.isEmpty()) {
                throw new IllegalStateException("Not done with elimination, and no lifted inference operators can be applied.");
            }
            LiftedInfOperator best = null;
            double bestLogCost = 0.0d;
            for (LiftedInfOperator candidate : validOps) {
                double logCost = candidate.logCost();
                // The "Y" test is evaluated lazily, only once a cheaper candidate is found.
                if (best == null || (logCost < bestLogCost && !isCountConversionOverY(candidate))) {
                    best = candidate;
                    bestLogCost = logCost;
                }
            }
            return best;
        }
    };

    /** True if {@code op} is a count conversion over a logical variable named "Y". */
    private static boolean isCountConversionOverY(LiftedInfOperator op) {
        return op instanceof CountConversion
                && ((CountConversion) op).getCountedLogvar().getName().equals("Y");
    }

    /**
     * Eliminates every dimension of {@code set} that {@code elimTester} selects. Uses the
     * greedy oracle and an empty query-term collection, without tracing.
     *
     * @param set parfactors to operate on; its contents are replaced in place
     * @param elimTester decides which dimension terms must be eliminated
     */
    public static void liftedElim(Set<Parfactor> set, ElimTester elimTester) {
        liftedElim(set, elimTester, new LinkedList<Term>(), GREEDY_OPERATOR_ORACLE, false);
    }

    /**
     * Same as {@link #liftedElim(Set, ElimTester)} but with a caller-supplied operator oracle.
     */
    public static void liftedElim(Set<Parfactor> set, ElimTester elimTester, LiftedInfOperatorOracle liftedInfOperatorOracle) {
        liftedElim(set, elimTester, new LinkedList<Term>(), liftedInfOperatorOracle, false);
    }

    /**
     * Shatters {@code set} against the given query terms. Equivalent to
     * {@code shatter(set, collection, true)}.
     */
    public static void shatter(Set<Parfactor> set, Collection<? extends Term> collection) {
        shatter(set, collection, true);
    }

    /**
     * Rebuilds {@code set} in place through a {@link ShatteredParfactorBag}. When {@code z}
     * is true, the bag is first split on the terms in {@code collection}. Every resulting
     * parfactor is converted to constraint normal form, and each piece is simplified and
     * added back to {@code set}.
     *
     * @param set parfactors to shatter; cleared and refilled
     * @param collection query terms to split on
     * @param z whether to split on {@code collection}
     */
    public static void shatter(Set<Parfactor> set, Collection<? extends Term> collection, boolean z) {
        ShatteredParfactorBag bag = new ShatteredParfactorBag(set);
        if (z) {
            bag.splitOnQueryTerms(collection);
        }
        set.clear();
        for (Parfactor shattered : bag.parfactors()) {
            for (Parfactor normalized : shattered.makeConstraintsNormalForm()) {
                set.add(normalized.simplify());
            }
        }
    }

    /** Greedy-oracle variant of the 5-argument {@code liftedElim}; {@code z} enables tracing. */
    private static void liftedElim(Set<Parfactor> set, ElimTester elimTester, Collection<? extends Term> collection, boolean z) {
        liftedElim(set, elimTester, collection, GREEDY_OPERATOR_ORACLE, z);
    }

    /**
     * Validates {@code set}, shatters it on {@code collection}, and then eliminates every
     * dimension selected by {@code elimTester} using operators from the oracle.
     *
     * @param z if true, trace the shattered set, each applied operator, and the final set to stdout
     */
    public static void liftedElim(Set<Parfactor> set, ElimTester elimTester, Collection<? extends Term> collection, LiftedInfOperatorOracle liftedInfOperatorOracle, boolean z) {
        liftedElim(set, elimTester, collection, liftedInfOperatorOracle, true, z);
    }

    /** Greedy-oracle variant of the 5-argument {@code liftedElimOnDemand}. */
    public static void liftedElimOnDemand(Set<Parfactor> set, ElimTester elimTester, Collection<? extends Term> collection, boolean z) {
        liftedElimOnDemand(set, elimTester, collection, GREEDY_OPERATOR_ORACLE, z);
    }

    /**
     * On-demand elimination in two phases.
     *
     * <p>Phase 1 eliminates everything, without shattering, until the oracle proposes a
     * {@link SummingOut} whose variable overlaps any term in {@code collection}. That
     * operator is not applied. Instead, {@code set} is shattered on the query terms and
     * phase 1 ends. If no such operator is ever proposed, no shattering happens.
     *
     * <p>Phase 2 eliminates whatever {@code elimTester} still selects.
     *
     * @param z if true, trace each applied operator and the final set to stdout
     */
    private static void liftedElimOnDemand(Set<Parfactor> set, ElimTester elimTester, Collection<? extends Term> collection, LiftedInfOperatorOracle liftedInfOperatorOracle, boolean z) {
        failIfUnsupported(set);
        GroundQuery eliminateEverything = new GroundQuery(Collections.emptySet());
        while (hasAnyToElim(set, eliminateEverything)) {
            LiftedInfOperator op = liftedInfOperatorOracle.nextOperator(set, eliminateEverything);
            // Every query term is checked, not just the first. Otherwise another conjunct of
            // a conjunctive query could be summed out before shattering, and an empty query
            // set would throw NoSuchElementException.
            if (sumsOutQueryTerm(op, collection)) {
                shatter(set, collection);
                break;
            }
            applyLiftedOperator(op, z);
        }
        eliminateAll(set, elimTester, liftedInfOperatorOracle, z);
        if (z) {
            printFinalParfactors(set);
        }
    }

    /** True if {@code op} sums out a variable that overlaps at least one query term. */
    private static boolean sumsOutQueryTerm(LiftedInfOperator op, Collection<? extends Term> queryTerms) {
        if (!(op instanceof SummingOut)) {
            return false;
        }
        SummingOut summingOut = (SummingOut) op;
        for (Term queryTerm : queryTerms) {
            if (queryTerm.makeOverlapSubst(summingOut.var()) != null) {
                return true;
            }
        }
        return false;
    }

    /** Greedy-oracle variant of the 6-argument {@code liftedElim}. */
    public static void liftedElim(Set<Parfactor> set, ElimTester elimTester, Collection<? extends Term> collection, boolean z, boolean z2) {
        liftedElim(set, elimTester, collection, GREEDY_OPERATOR_ORACLE, z, z2);
    }

    /**
     * Validates {@code set}, shatters it (splitting on {@code collection} only when
     * {@code z} is true), and then eliminates every dimension selected by
     * {@code elimTester}.
     *
     * @param z whether to split on the query terms while shattering
     * @param z2 if true, trace the shattered set, each applied operator, and the final set
     */
    public static void liftedElim(Set<Parfactor> set, ElimTester elimTester, Collection<? extends Term> collection, LiftedInfOperatorOracle liftedInfOperatorOracle, boolean z, boolean z2) {
        failIfUnsupported(set);
        shatter(set, collection, z);
        if (z2) {
            printParfactorsAfterSplitting(set);
        }
        eliminateAll(set, elimTester, liftedInfOperatorOracle, z2);
        if (z2) {
            printFinalParfactors(set);
        }
    }

    /** Reports a fatal error if {@link #checkParfactors} finds an unsupported parfactor. */
    private static void failIfUnsupported(Set<Parfactor> set) {
        String problem = checkParfactors(set);
        if (problem != null) {
            Util.fatalErrorWithoutStack(problem);
        }
    }

    /** Applies oracle-chosen operators until {@code elimTester} selects nothing in {@code set}. */
    private static void eliminateAll(Set<Parfactor> set, ElimTester elimTester, LiftedInfOperatorOracle oracle, boolean verbose) {
        while (hasAnyToElim(set, elimTester)) {
            applyLiftedOperator(oracle.nextOperator(set, elimTester), verbose);
        }
    }

    /** Runs {@code op}, announcing it on stdout first when {@code verbose}. */
    private static void applyLiftedOperator(LiftedInfOperator op, boolean verbose) {
        if (verbose) {
            System.out.println();
            System.out.println("Applying op: " + op);
        }
        op.operate();
    }

    /** Trace output printed right after shattering. */
    private static void printParfactorsAfterSplitting(Set<Parfactor> set) {
        System.out.println("\n After splitting: \n");
        System.out.println("----------------------");
        for (Parfactor parfactor : set) {
            parfactor.print(System.out);
            System.out.println("*");
        }
        System.out.println("----------------------");
    }

    /** Trace output printed once elimination has finished. */
    private static void printFinalParfactors(Set<Parfactor> set) {
        System.out.println();
        System.out.println("Final set of Parfactors:\n");
        for (Parfactor parfactor : set) {
            parfactor.print(System.out);
            System.out.println();
        }
    }

    /**
     * Returns a description of the first reason some parfactor cannot be handled, or
     * {@code null} if all are supported. A parfactor is rejected when one of its logical
     * variables ranges over a type with non-empty POPs or a type without a guaranteed finite
     * domain, or when one of its dimension terms fails {@link #checkParfactorTerm}.
     */
    private static String checkParfactors(Collection<Parfactor> collection) {
        for (Parfactor parfactor : collection) {
            for (LogicalVar logicalVar : parfactor.logicalVars()) {
                Type type = logicalVar.getType();
                if (!type.getPOPs().isEmpty()) {
                    return "Can't handle parfactor quantifying over type " + type + ", which has unknown objects.";
                }
                if (!type.hasFiniteGuaranteed()) {
                    return "Can't handle parfactor quantifying over type " + type + ", which has infinitely many objects.";
                }
            }
            for (Term term : parfactor.dimTerms()) {
                String problem = checkParfactorTerm(term);
                if (problem != null) {
                    return "Can't handle parfactor with a dimension defined by " + term + " " + problem + ".";
                }
            }
        }
        return null;
    }

    /**
     * Returns why {@code term} cannot be a parfactor dimension, or {@code null} if it can.
     * Accepted terms are a {@link FuncAppTerm}, or a {@link CountingTerm} (checked through
     * its single sub-term), that applies a function which is not a {@link NonRandomFunction}
     * and whose arguments all pass {@link #checkParfactorTermArg}.
     */
    private static String checkParfactorTerm(Term term) {
        FuncAppTerm application;
        if (term instanceof FuncAppTerm) {
            application = (FuncAppTerm) term;
        } else if (term instanceof CountingTerm) {
            application = ((CountingTerm) term).singleSubTerm();
        } else {
            return "of class " + term.getClass().getName();
        }
        Function function = application.getFunction();
        if (function instanceof NonRandomFunction) {
            return "which is application of nonrandom function " + function;
        }
        Type[] argTypes = function.getArgTypes();
        Term[] args = application.getArgs();
        for (int i = 0; i < args.length; i++) {
            String problem = checkParfactorTermArg(args[i], argTypes[i]);
            if (problem != null) {
                return "whose argument " + args[i] + " " + problem;
            }
        }
        return null;
    }

    /**
     * Returns why {@code term} cannot be an argument of a dimension term, or {@code null}.
     * Accepted arguments are a logical variable, or a variable-free, non-random function
     * application that equals {@code type}'s canonical term for the object it denotes.
     */
    private static String checkParfactorTermArg(Term term, Type type) {
        if (term instanceof LogicalVar) {
            return null;
        }
        if (!(term instanceof FuncAppTerm)) {
            return "is not a logical variable or function application";
        }
        if (!term.getFreeVars().isEmpty()) {
            return "contains a nested logical variable";
        }
        Object value = term.getValueIfNonRandom();
        if (value == null) {
            return "is random";
        }
        if (term.equals(type.getCanonicalTerm(value))) {
            return null;
        }
        return "is not the canonical term for the object it denotes";
    }

    /** True if {@code elimTester} wants any dimension term of any parfactor eliminated. */
    private static boolean hasAnyToElim(Collection<? extends Parfactor> collection, ElimTester elimTester) {
        for (Parfactor parfactor : collection) {
            Constraint constraint = parfactor.constraint();
            for (Term dimTerm : parfactor.dimTerms()) {
                if (elimTester.shouldEliminate(dimTerm, constraint)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Builds the parfactor network for {@code model}. The network is printed to stdout
     * when {@code Util.verbose()} is set.
     *
     * @param properties engine properties; must be non-null, since "ondemand" is read here
     */
    public LiftedVarElim(Model model, Properties properties) {
        super(model);
        this.pmn = new ParMarkovNet(model);
        this.props = properties;
        this.ondemand = Boolean.parseBoolean(this.props.getProperty("ondemand"));
        if (Util.verbose()) {
            this.pmn.print(System.out);
        }
    }

    /**
     * Registers {@code evidence} with the superclass and rebuilds {@link #evidenceParfactors}.
     * Observations of one-argument random functions that share a function name and observed
     * value are grouped into one delta parfactor. Every other observation becomes its own
     * ground delta parfactor.
     *
     * @throws IllegalArgumentException if the evidence has symbol (Skolem) evidence or an
     *     evidence variable that is not a {@link RandFuncAppVar}
     */
    public void setEvidenceGroups(Evidence evidence) {
        super.setEvidence(evidence);
        if (!evidence.getSkolemConstants().isEmpty()) {
            throw new IllegalArgumentException("LiftedVarElim doesn't handle symbol evidence.");
        }
        // Linked maps give the grouped evidence parfactors a deterministic, first-seen order.
        HashMap<String, Set<Term>> argsByGroup = new LinkedHashMap<>();
        HashMap<String, Term> exampleTermByGroup = new LinkedHashMap<>();
        HashMap<String, Object> valueByGroup = new LinkedHashMap<>();
        this.evidenceParfactors = new ArrayList<>();
        for (BayesNetVar bayesNetVar : evidence.getEvidenceVars()) {
            if (!(bayesNetVar instanceof RandFuncAppVar)) {
                throw new IllegalArgumentException("Can't handle evidence variable: " + bayesNetVar);
            }
            Object observedValue = evidence.getObservedValue(bayesNetVar);
            FuncAppTerm canonicalTerm = ((RandFuncAppVar) bayesNetVar).getCanonicalTerm();
            if (canonicalTerm.getArgs().length == 1) {
                recordEvidence(canonicalTerm, observedValue, argsByGroup, exampleTermByGroup, valueByGroup);
            } else {
                this.evidenceParfactors.add(Parfactor.delta(Collections.emptyList(), Constraint.EMPTY, Collections.singletonList(canonicalTerm), Collections.singletonList(observedValue)));
            }
        }
        this.evidenceParfactors.addAll(groupEvidInPfs(argsByGroup, exampleTermByGroup, valueByGroup));
    }

    /**
     * Adds one unary observation to the evidence groups. The group key is
     * {@code functionName + "_" + value}.
     *
     * <p>NOTE(review): string keys can collide, for example function "f_a" with value "b"
     * versus function "f" with value "a_b".
     *
     * @param hashMap group key to the set of first arguments observed
     * @param hashMap2 group key to the first term seen (supplies the function)
     * @param hashMap3 group key to the observed value
     */
    public void recordEvidence(FuncAppTerm funcAppTerm, Object obj, HashMap<String, Set<Term>> hashMap, HashMap<String, Term> hashMap2, HashMap<String, Object> hashMap3) {
        String key = funcAppTerm.getFunction().getName() + "_" + obj.toString();
        if (!hashMap.containsKey(key)) {
            hashMap.put(key, new HashSet<Term>());
        }
        hashMap.get(key).add(funcAppTerm.getArgs()[0]);
        if (!hashMap2.containsKey(key)) {
            hashMap2.put(key, funcAppTerm);
        }
        if (!hashMap3.containsKey(key)) {
            hashMap3.put(key, obj);
        }
    }

    /**
     * Turns the groups built by {@link #recordEvidence} into delta parfactors. Each group
     * becomes {@code f(X) = value}, where X is a fresh logical variable of f's argument type
     * constrained to the group's observed arguments.
     */
    public ArrayList<Parfactor> groupEvidInPfs(HashMap<String, Set<Term>> hashMap, HashMap<String, Term> hashMap2, HashMap<String, Object> hashMap3) {
        ArrayList<Parfactor> result = new ArrayList<>();
        for (Map.Entry<String, Term> group : hashMap2.entrySet()) {
            String key = group.getKey();
            FuncAppTerm example = (FuncAppTerm) group.getValue();
            Function function = example.getFunction();
            LogicalVar var = LogicalVar.createVar(function.getArgTypes()[0]);
            FuncAppTerm pattern = new FuncAppTerm(function, var);
            // Raw maps: passed to the Constraint constructor exactly as before.
            HashMap observedArgs = new HashMap();
            observedArgs.put(var, hashMap.get(key));
            Constraint observedConstraint = new Constraint(observedArgs);
            HashMap allowed = new HashMap();
            allowed.put(var, observedConstraint.allowedConstants(var));
            result.add(Parfactor.delta(Collections.singletonList(var), new Constraint(allowed), Collections.singletonList(pattern), Collections.singletonList(hashMap3.get(key))));
        }
        return result;
    }

    /** Delegates to {@link #setEvidenceGroups}. */
    @Override
    public void setEvidence(Evidence evidence) {
        setEvidenceGroups(evidence);
    }

    /**
     * Records the queries and extracts each query's terms. A query with exactly one term is
     * rejected (fatal error) if that term fails {@link #checkParfactorTerm}.
     */
    @Override
    public void setQueries(List list) {
        super.setQueries(list);
        this.conjQueries = new LinkedHashMap<>();
        for (Object element : list) {
            Query query = (Query) element;
            Set<FuncAppTerm> queryTerms = getQueryTerms(query);
            if (queryTerms.size() == 1) {
                FuncAppTerm onlyTerm = queryTerms.iterator().next();
                String problem = checkParfactorTerm(onlyTerm);
                if (problem != null) {
                    Util.fatalErrorWithoutStack("LiftedVarElim engine can't handle query on term " + onlyTerm + " " + problem + ".");
                }
            }
            this.conjQueries.put(query, queryTerms);
        }
    }

    /**
     * Answers every recorded query on a fresh copy of the model and evidence parfactors, and
     * adds the elimination time to {@link Statistics}. With no queries, everything is
     * eliminated and the resulting zero-ary potential is printed.
     */
    @Override
    public void answerQueries() {
        Statistics statistics = Statistics.getInstance();
        if (this.conjQueries.isEmpty()) {
            reportZeroAryPotential(statistics);
            return;
        }
        for (Map.Entry<Query, Set<FuncAppTerm>> entry : this.conjQueries.entrySet()) {
            Set<Parfactor> parfactors = newWorkingParfactorSet();
            GroundQuery groundQuery = new GroundQuery(entry.getValue());
            long start = System.nanoTime();
            if (this.ondemand) {
                liftedElimOnDemand(parfactors, groundQuery, entry.getValue(), Util.verbose());
            } else {
                liftedElim(parfactors, groundQuery, entry.getValue(), Util.verbose());
            }
            long end = System.nanoTime();
            recordAnswer(entry.getKey(), entry.getValue(), parfactors);
            statistics.addSplitTime(end - start);
        }
    }

    /** A fresh, insertion-ordered set holding the model parfactors followed by the evidence. */
    private Set<Parfactor> newWorkingParfactorSet() {
        Set<Parfactor> parfactors = new LinkedHashSet<Parfactor>(this.pmn.getParfactors());
        parfactors.addAll(this.evidenceParfactors);
        return parfactors;
    }

    /** Eliminates every dimension, checks that nothing is left, and prints the product. */
    private void reportZeroAryPotential(Statistics statistics) {
        long start = System.nanoTime();
        Set<Parfactor> parfactors = newWorkingParfactorSet();
        liftedElim(parfactors, new GroundQuery(Collections.emptySet()), Collections.<Term>emptySet(), Util.verbose());
        long end = System.nanoTime();
        for (Parfactor parfactor : parfactors) {
            Parfactor simplified = parfactor.simplify();
            if (!simplified.logicalVars().isEmpty()) {
                Util.fatalError("Parfactor " + simplified + " still contains logical variables. Simplified: " + simplified.simplify());
            }
            if (!simplified.dimTerms().isEmpty()) {
                Util.fatalError("Parfactor " + simplified + " still contains a non-query term.");
            }
        }
        Parfactor product = Parfactor.multiply(new ArrayList<Parfactor>(parfactors));
        // Parfactor.multiply can return null (recordAnswer guards for it too).
        if (product == null) {
            System.out.println("No parfactors remained after elimination; there is no zero-ary potential to report.");
        } else {
            System.out.println("Resulting zero-ary potential:");
            product.potential().print(System.out);
        }
        System.out.println("\n**TIME**" + (end - start));
        statistics.addSplitTime(end - start);
    }

    /**
     * Extracts the query's terms: the term itself for a {@link FuncAppTerm} query, or the
     * atomic conjuncts (recursively) for a {@link ConjFormula} query. Any other query is a
     * fatal error.
     */
    private Set<FuncAppTerm> getQueryTerms(Query query) {
        if (!(query instanceof ArgSpecQuery)) {
            Util.fatalErrorWithoutStack("LiftedVarElim engine can't handle query of class " + query.getClass().getName());
        }
        ArgSpec argSpec = ((ArgSpecQuery) query).argSpec();
        Set<FuncAppTerm> terms = new LinkedHashSet<>();
        if (argSpec instanceof ConjFormula) {
            addConjuncts(new ArrayList<Formula>(((ConjFormula) argSpec).getConjuncts()), terms);
        } else if (argSpec instanceof FuncAppTerm) {
            terms.add((FuncAppTerm) argSpec);
        } else {
            Util.fatalErrorWithoutStack("With LiftedVarElim engine, queries must be function applications, not " + query);
        }
        return terms;
    }

    /**
     * Adds the function-application terms of the atomic formulas in {@code list} to
     * {@code set}, level by level, descending into nested conjunctions. Anything else is
     * skipped with a warning. {@code list} is not modified.
     */
    private void addConjuncts(List<Formula> list, Set<FuncAppTerm> set) {
        List<Formula> nested = new ArrayList<>();
        for (Formula conjunct : list) {
            if (conjunct instanceof AtomicFormula) {
                Object term = ((AtomicFormula) conjunct).getTerm();
                if (term instanceof FuncAppTerm) {
                    set.add((FuncAppTerm) term);
                } else {
                    // Previously an unchecked cast that threw ClassCastException.
                    System.out.println("Warning: I am ignoring something from your conjunctive query: " + conjunct);
                }
            } else if (conjunct instanceof ConjFormula) {
                nested.addAll(((ConjFormula) conjunct).getConjuncts());
            } else {
                System.out.println("Warning: I am ignoring something from your conjunctive query: " + conjunct);
            }
        }
        if (!nested.isEmpty()) {
            addConjuncts(nested, set);
        }
    }

    /**
     * Checks that only query terms remain in {@code set2}, multiplies the simplified
     * parfactors, normalizes the product, and stores it as the posterior of {@code query}.
     * A {@link FormulaQuery} receives it as a joint posterior. If the product is
     * {@code null}, a message is printed and nothing is stored.
     */
    private void recordAnswer(Query query, Set<FuncAppTerm> set, Set<Parfactor> set2) {
        boolean parameterised = Boolean.parseBoolean(this.props.getProperty("parameterised"));
        List<Parfactor> remaining = new ArrayList<>();
        for (Parfactor parfactor : set2) {
            Parfactor simplified = parfactor.simplify();
            if (!simplified.logicalVars().isEmpty()) {
                Util.fatalError("Parfactor " + simplified + " still contains logical variables. Simplified: " + simplified.simplify());
            }
            if (parameterised) {
                checkParameterisedDimTerms(simplified, set);
            } else if (!set.containsAll(simplified.dimTerms())) {
                Util.fatalError("Parfactor " + simplified + " still contains a non-query term.");
            }
            remaining.add(simplified);
        }
        Parfactor product = Parfactor.multiply(remaining);
        if (product == null) {
            System.out.println("No cluster contained the query term or something went wrong during solving.\n Query term: " + set.toString());
            return;
        }
        Potential normalized = product.potential().copy();
        normalized.normalize();
        ArrayList factorVars = new ArrayList();
        for (Term dimTerm : product.dimTerms()) {
            factorVars.add((BasicVar) dimTerm.getVariable());
        }
        if (query instanceof FormulaQuery) {
            ((FormulaQuery) query).setJointPosterior(new Factor(factorVars, normalized));
        } else {
            query.setPosterior(new Factor(factorVars, normalized));
        }
    }

    /**
     * Check used in "parameterised" mode. For every function-application sub-expression of
     * the parfactor's dimension terms, and every query term with the same function, each
     * argument pair must give an overlap substitution consistent with the parfactor's
     * constraint. Otherwise a fatal error is reported.
     *
     * <p>This replaces a loop that never terminated and that compared a {@code Function}
     * with a {@code FuncAppTerm}, so the check could never match.
     *
     * <p>NOTE(review): {@code makeOverlapSubst} returns {@code null} when terms cannot
     * overlap (it is compared with {@code null} in {@code sumsOutQueryTerm}). Confirm that
     * {@code Constraint.consistent} accepts {@code null}.
     */
    private static void checkParameterisedDimTerms(Parfactor parfactor, Set<FuncAppTerm> queryTerms) {
        for (Term dimTerm : parfactor.dimTerms()) {
            for (Object subExpr : dimTerm.getSubExprs()) {
                if (!(subExpr instanceof FuncAppTerm)) {
                    continue;
                }
                FuncAppTerm subTerm = (FuncAppTerm) subExpr;
                for (FuncAppTerm queryTerm : queryTerms) {
                    if (!queryTerm.getFunction().equals(subTerm.getFunction())) {
                        continue;
                    }
                    Term[] subArgs = subTerm.getArgs();
                    Term[] queryArgs = queryTerm.getArgs();
                    for (int i = 0; i < subArgs.length; i++) {
                        if (!parfactor.constraint().consistent(queryArgs[i].makeOverlapSubst(subArgs[i]))) {
                            Util.fatalError("Parfactor " + parfactor + " still contains a non-query term.");
                        }
                    }
                }
            }
        }
    }

    /** Records (-1, -1) as the junction-tree statistics; this engine builds no junction tree. */
    public void storeStats() {
        Statistics.getInstance().setJtreeValues(-1, -1);
    }

    /** The parfactor network built from the model. */
    public ParMarkovNet getPMN() {
        return this.pmn;
    }

    /** The model this engine was constructed with. */
    public Model getModel() {
        return this.model;
    }
}
