package ve;

import blog.BasicVar;
import blog.BayesNetVar;
import blog.Evidence;
import blog.InferenceEngine;
import blog.Model;
import blog.Query;
import common.Util;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

/* loaded from: input_file:ve/VarElimEngine.class */
/**
 * Inference engine that answers queries by variable elimination over the factors of a
 * {@link MarkovNet} built from the model.
 *
 * <p>Each observed evidence variable is turned into a delta factor. Every variable that does not
 * appear in a query is then summed out, one at a time. The next variable is chosen greedily as the
 * one whose current neighbours in the factor-interaction graph have the smallest total log range
 * size. Each query's posterior is obtained as follows:
 * <ol>
 *   <li>multiply the surviving factors;</li>
 *   <li>sum out everything except the query's own variables;</li>
 *   <li>normalize.</li>
 * </ol>
 *
 * <p>Only basic random variables are supported, both as evidence and in queries. Symbol evidence is
 * rejected with {@link IllegalArgumentException}.
 */
public class VarElimEngine extends InferenceEngine {
    /** Markov network built from the model in the constructor; its factors and rvs drive elimination. */
    protected MarkovNet net;

    /** One delta factor per observed evidence variable; empty until {@link #setEvidence} is called. */
    protected List<Factor> evidenceFactors;

    /**
     * A node of the interaction graph used to pick the elimination order.
     *
     * <p>Initially, two variables are neighbours when some factor mentions both. When a variable is
     * eliminated, its neighbours are connected to each other and it is detached from them.
     */
    public static class VarNode {
        /** The variable this node represents. */
        public BasicVar var;

        /** Natural log of the size of the variable's type range. */
        public double logRangeSize;

        /** Nodes currently adjacent to this one. */
        public Set<VarNode> neighbors = new HashSet<>();

        /**
         * Sum of {@code logRangeSize} over {@link #neighbors}, i.e. the log of the product of the
         * neighbours' range sizes. Starts at 0 = log(1), the log of the empty product.
         */
        public double logNeighborsRangeSize = 0.0d;

        /**
         * Creates an unconnected node for the given variable.
         *
         * @param basicVar the variable; its type's range size determines {@link #logRangeSize}
         */
        public VarNode(BasicVar basicVar) {
            this.var = basicVar;
            this.logRangeSize = Math.log(basicVar.getType().range().size());
        }

        /**
         * Adds {@code varNode} as a neighbour. {@link #logNeighborsRangeSize} is updated only when
         * the node was not already a neighbour.
         *
         * @param varNode the node to connect to
         */
        public void addNeighbor(VarNode varNode) {
            if (this.neighbors.add(varNode)) {
                this.logNeighborsRangeSize += varNode.logRangeSize;
            }
        }

        /**
         * Removes {@code varNode} from the neighbours. {@link #logNeighborsRangeSize} is updated only
         * when it actually was a neighbour.
         *
         * @param varNode the node to disconnect from
         */
        public void removeNeighbor(VarNode varNode) {
            if (this.neighbors.remove(varNode)) {
                this.logNeighborsRangeSize -= varNode.logRangeSize;
            }
        }
    }

    /**
     * Creates an engine for the given model and builds its Markov network.
     *
     * @param model the model to perform inference on
     * @param properties configuration properties; currently unused
     */
    public VarElimEngine(Model model, Properties properties) {
        super(model);
        this.evidenceFactors = Collections.emptyList();
        this.net = new MarkovNet(model);
    }

    /**
     * Records the evidence and converts each observation into a delta factor.
     *
     * @param evidence the evidence to condition on
     * @throws IllegalArgumentException if the evidence has Skolem constants (symbol evidence), or
     *     observes a variable that is not a {@link BasicVar}
     */
    @Override // blog.InferenceEngine
    public void setEvidence(Evidence evidence) {
        super.setEvidence(evidence);
        if (!evidence.getSkolemConstants().isEmpty()) {
            throw new IllegalArgumentException("VarElimEngine doesn't handle symbol evidence.");
        }
        this.evidenceFactors = new ArrayList<>();
        for (BayesNetVar observed : evidence.getEvidenceVars()) {
            if (!(observed instanceof BasicVar)) {
                throw new IllegalArgumentException("Can't handle non-basic evidence variable: " + observed);
            }
            this.evidenceFactors.add(Factor.delta(
                    Collections.singletonList((BasicVar) observed),
                    Collections.singletonList(evidence.getObservedValue(observed))));
        }
    }

    /**
     * Answers every registered query by variable elimination and stores each result via
     * {@code Query.setPosterior}.
     *
     * <p>All non-query variables are summed out of the model and evidence factors. Each query's
     * posterior is the normalized marginal, over its variables, of the product of what remains. If
     * there are no queries, the resulting zero-ary potential is printed. The elapsed time in
     * nanoseconds is always printed to standard output.
     *
     * @throws IllegalArgumentException if a query involves a variable that is not a {@link BasicVar}
     */
    @Override // blog.InferenceEngine
    public void answerQueries() {
        long startTime = System.nanoTime();

        Collection<Factor> factors = new ArrayList<>(this.net.factors);
        factors.addAll(this.evidenceFactors);

        // Also validates the queries, so it runs before any graph work.
        List<BasicVar> toEliminate = nonQueryVars();
        Map<BasicVar, VarNode> nodes = buildInteractionGraph(factors);

        while (!toEliminate.isEmpty()) {
            BasicVar var = chooseVarToElim(toEliminate, factors, nodes);
            if (Util.verbose()) {
                System.out.println("Summing out " + var);
            }
            factors = eliminate(var, factors, nodes);
            toEliminate.remove(var);
        }

        Factor joint = productOf(factors);
        for (Object element : this.queries) {
            Query query = (Query) element;
            Collection<? extends BayesNetVar> queryVars = query.getVariables();
            List<BasicVar> toSum = new ArrayList<>(joint.getRandomVars());
            toSum.removeAll(queryVars);
            Factor posterior = Factor.sumOut(joint, toSum);
            posterior.normalize();
            query.setPosterior(posterior);
        }
        long endTime = System.nanoTime();

        if (this.queries.isEmpty()) {
            System.out.println("Resulting zero-ary potential:");
            joint.getPotential().print(System.out);
        }
        System.out.println("\n**TIME**" + (endTime - startTime));
    }

    /**
     * Sums the variables in {@code list} out of the product of the given factors.
     *
     * <p>Variables are eliminated in list order. For each one, the factors mentioning it are
     * multiplied pairwise and the variable is summed out. The product of all remaining factors is
     * returned; it is not normalized. The input collection is copied, not modified.
     *
     * @param collection the factors to combine
     * @param list the variables to sum out, in elimination order
     * @return the product of the factors left after elimination
     * @throws java.util.NoSuchElementException if {@code collection} is empty
     */
    public static Factor computeMarginal(Collection<Factor> collection, List<BasicVar> list) {
        List<Factor> factors = new ArrayList<>(collection);
        for (BasicVar var : list) {
            List<Factor> untouched = new ArrayList<>(factors.size());
            List<Factor> touching = new ArrayList<>(factors.size());
            for (Factor f : factors) {
                if (f.inScope(var)) {
                    touching.add(f);
                } else {
                    untouched.add(f);
                }
            }
            if (!touching.isEmpty()) {
                List<BasicVar> sumVars = new ArrayList<>(1);
                sumVars.add(var);
                untouched.add(Factor.sumOut(productOf(touching), sumVars));
            }
            factors = untouched;
        }
        return productOf(factors);
    }

    /**
     * Returns every network variable that no query mentions, in the network's iteration order.
     *
     * @throws IllegalArgumentException if a query involves a variable that is not a {@link BasicVar}
     */
    private List<BasicVar> nonQueryVars() {
        List<BasicVar> vars = new ArrayList<>(this.net.rvs);
        for (Object element : this.queries) {
            Query query = (Query) element;
            Collection<? extends BayesNetVar> queryVars = query.getVariables();
            for (BayesNetVar v : queryVars) {
                if (!(v instanceof BasicVar)) {
                    throw new IllegalArgumentException("VarElimEngine can't handle query " + query
                            + " involving non-basic random variable " + v);
                }
            }
            vars.removeAll(queryVars);
        }
        return vars;
    }

    /**
     * Builds one {@link VarNode} per network variable. Two nodes are connected whenever some factor
     * mentions both variables.
     */
    private Map<BasicVar, VarNode> buildInteractionGraph(Collection<Factor> factors) {
        Map<BasicVar, VarNode> nodes = new HashMap<>();
        for (BasicVar var : this.net.rvs) {
            nodes.put(var, new VarNode(var));
        }
        for (Factor f : factors) {
            for (BasicVar a : f.getRandomVars()) {
                VarNode node = nodes.get(a);
                for (BasicVar b : f.getRandomVars()) {
                    if (!b.equals(a)) {
                        node.addNeighbor(nodes.get(b));
                    }
                }
            }
        }
        return nodes;
    }

    /**
     * Sums {@code var} out of the factors that mention it.
     *
     * <p>The returned collection holds the untouched factors followed by the new factor. When any
     * factor was touched, the interaction graph is also updated.
     */
    private static Collection<Factor> eliminate(
            BasicVar var, Collection<Factor> factors, Map<BasicVar, VarNode> nodes) {
        Collection<Factor> untouched = new ArrayList<>(factors.size());
        List<Factor> touching = new ArrayList<>(factors.size());
        for (Factor f : factors) {
            if (f.inScope(var)) {
                touching.add(f);
            } else {
                untouched.add(f);
            }
        }
        if (!touching.isEmpty()) {
            Factor product = Factor.multiply(touching);
            List<BasicVar> sumVars = new ArrayList<>(1);
            sumVars.add(var);
            untouched.add(Factor.sumOut(product, sumVars));
            connectNeighbors(nodes.get(var));
        }
        return untouched;
    }

    /**
     * Detaches {@code eliminated} from each of its neighbours and connects those neighbours
     * pairwise. This mirrors the new factor that now spans all of them.
     */
    private static void connectNeighbors(VarNode eliminated) {
        for (VarNode n : eliminated.neighbors) {
            n.removeNeighbor(eliminated);
            for (VarNode m : eliminated.neighbors) {
                if (m != n) {
                    n.addNeighbor(m);
                }
            }
        }
    }

    /**
     * Multiplies the factors pairwise, left to right, in iteration order.
     *
     * @throws java.util.NoSuchElementException if {@code factors} is empty
     */
    private static Factor productOf(Iterable<Factor> factors) {
        Iterator<Factor> it = factors.iterator();
        Factor product = it.next();
        while (it.hasNext()) {
            product = Factor.multiply(product, it.next());
        }
        return product;
    }

    /**
     * Picks the candidate with the smallest {@code logNeighborsRangeSize}. On ties, the earliest
     * candidate in iteration order wins.
     *
     * <p>If no candidate's weight compares below +Infinity (e.g. all weights are NaN), this falls
     * back to the first candidate. NaN weights arise from a zero-size range, since log(0) is
     * -Infinity. Without the fallback the method returned null, and the caller's elimination loop
     * could never make progress.
     *
     * @param candidates variables still to eliminate
     * @param factors current factors (not used by this heuristic)
     * @param nodes interaction-graph nodes keyed by variable
     * @return the chosen variable, or null only if {@code candidates} is empty
     */
    private BasicVar chooseVarToElim(
            Collection<BasicVar> candidates, Collection<Factor> factors, Map<BasicVar, VarNode> nodes) {
        BasicVar best = null;
        double bestWeight = Double.POSITIVE_INFINITY;
        for (BasicVar candidate : candidates) {
            double weight = nodes.get(candidate).logNeighborsRangeSize;
            if (weight < bestWeight) {
                best = candidate;
                bestWeight = weight;
            }
        }
        if (best == null && !candidates.isEmpty()) {
            best = candidates.iterator().next();
        }
        return best;
    }
}
