package fojt;

import blog.ArgSpecQuery;
import blog.Evidence;
import blog.FuncAppTerm;
import blog.LogicalVar;
import blog.Model;
import blog.Query;
import blog.Term;
import blog.ValueEvidenceStatement;
import fove.Constraint;
import fove.Parfactor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.TreeMap;
import test.Statistics;

/* loaded from: input_file:fojt/LiftedDynamicJTEngine.class */
/**
 * Dynamic (time-sliced) variant of {@link LiftedJTEngine}.
 *
 * <p>At construction the input model is split, using the slices returned by
 * {@code Model.detectSlices()} (keys 1, 2 and 3), into two models, each served by its own
 * {@link LiftedJTEngine}:
 * <ul>
 *   <li>{@code m1} / {@code ljt0}: the model without slices 2 and 3, plus the interface parfactor
 *       {@code i1} as out-clique; used for time step 1.</li>
 *   <li>{@code mt} / {@code ljtt}: the model without slice 1, plus {@code i1} as in-clique and
 *       {@code i2} as out-clique; reused for every time step greater than 1.</li>
 * </ul>
 *
 * <p>Forward messages ({@code alpha}) are carried from one step to the next. Queries whose
 * queried time is earlier than their time step are answered by stepping backward with beta
 * messages. Queries whose queried time is later are answered by stepping forward.
 */
public class LiftedDynamicJTEngine extends LiftedJTEngine {
    /** Time step currently processed by {@link #answerQueries()}; starts at 1. */
    private int t;
    /** Last time step processed by {@link #answerQueries()}; raised by setQueries / setMaxTime. */
    private int maxTime;
    /** Value of the {@code archClass} property (default 0), passed to {@code passMessages}. */
    private int a;
    /** Engine over {@link #m1}; used for time step 1. */
    private LiftedJTEngine ljt0;
    /** Engine over {@link #mt}; reset and reused for every time step greater than 1. */
    private LiftedJTEngine ljtt;
    /** Interface parfactor over the slice-2 terms whose name contains "1". */
    private Parfactor i1;
    /** Counterpart of {@link #i1} over the slice-3 terms obtained by replacing "1" with "2". */
    private Parfactor i2;
    /** Copy of the input model restricted to the initial time step. */
    private Model m1;
    /** Copy of the input model used as the repeating temporal model. */
    private Model mt;
    /** Evidence grouped by the time step parsed out of each statement. */
    private HashMap<Integer, Evidence> evidences;
    /** Queries grouped by their time step. */
    private HashMap<Integer, List<Query>> queries;
    /** Most recent backward message computed by {@link #backwardPass(int, boolean)}. */
    private Message betaT;
    /** {@code alpha[k]} holds the forward message computed at time step {@code k + 1}. */
    private Message[] alpha;
    /** Accumulated nanoseconds spent answering queries. */
    private long queryTime;

    /**
     * Splits {@code model} into the initial and temporal models and builds one junction tree for
     * each.
     *
     * @param model the model whose parfactors are partitioned by {@code detectSlices()}
     * @param properties engine properties; {@code archClass} (default "0") selects the value
     *     passed to {@code passMessages}
     */
    public LiftedDynamicJTEngine(Model model, Properties properties) {
        super(model, properties);
        this.evidences = new HashMap<>();
        this.queries = new HashMap<>();
        this.queryTime = 0L;
        this.m1 = model.deepCopy();
        this.mt = model.deepCopy();
        Statistics statistics = Statistics.getInstance();
        long splitStart = System.nanoTime();

        HashMap<Integer, Set<Parfactor>> slices = model.detectSlices();
        Set<Parfactor> slice1 = slices.get(1);
        Set<Parfactor> slice2 = slices.get(2);
        Set<Parfactor> slice3 = slices.get(3);
        createInterfaceParfactors(slice2, slice3);

        // Initial model: drop slices 2 and 3 (and slice-3 functions); i1 is its out-clique.
        this.m1.removeParfactors(slice3);
        this.m1.removeParfactors(slice2);
        this.m1.addParfactor(this.i1);
        this.m1.removeFunctions(model.getFunctionsFromTimeSlice(slice3));
        this.ljt0 = new LiftedJTEngine(this.m1, properties);
        this.ljt0.setOutClique(this.i1);

        // Temporal model: drop slice 1 plus the functions whose terms occur only in slice 1.
        this.mt.removeParfactors(slice1);
        this.mt.addParfactor(this.i1);
        this.mt.addParfactor(this.i2);
        Set<Term> slice1OnlyTerms = new LinkedHashSet<>();
        for (Parfactor parfactor : slice1) {
            slice1OnlyTerms.addAll(parfactor.dimTerms());
        }
        Set<Term> slice2Terms = new LinkedHashSet<>();
        for (Parfactor parfactor : slice2) {
            slice2Terms.addAll(parfactor.dimTerms());
        }
        slice1OnlyTerms.removeAll(slice2Terms);
        this.mt.removeFunctions(model.getFunctionsFromTerms(slice1OnlyTerms));
        this.ljtt = new LiftedJTEngine(this.mt, properties);
        this.ljtt.setInClique(this.i1);
        this.ljtt.setOutClique(this.i2);
        statistics.addSplitTime(System.nanoTime() - splitStart);

        this.t = 1;
        this.maxTime = 1;
        this.a = Integer.parseUnsignedInt(properties.getProperty("archClass", "0"));
        this.ljt0.passMessages(this.a);
        this.ljtt.passMessages(this.a);
        this.ljt0.resetJTree();
        this.ljtt.resetJTree();
    }

    /**
     * Groups the not-yet-parsed value evidence of {@code evidence} by time step.
     *
     * <p>The time step is stripped from each statement. Step 1 is checked against the initial
     * model, later steps against the temporal model. Statements that fail the check are removed
     * from {@code evidence} and not recorded. Statements without a positive time step are only
     * reported.
     */
    @Override // fove.LiftedVarElim, blog.InferenceEngine
    public void setEvidence(Evidence evidence) {
        Iterator<?> it = evidence.getNotParsedValueEvidence().iterator();
        while (it.hasNext()) {
            ValueEvidenceStatement statement = (ValueEvidenceStatement) it.next();
            Integer time = parseEvidenceStatement(statement);
            int step = time.intValue();
            if (step < 1) {
                // Stored under a key <= 0 it would never be read by answerQueries, so skip it.
                System.out.println("Not assigned to a timestep: " + statement);
                continue;
            }
            Model sliceModel = step == 1 ? this.m1 : this.mt;
            if (!statement.checkTypesAndScope(sliceModel)) {
                System.out.println("check failed");
                System.out.println("Vars " + statement + " subexpr: " + statement.getLeftSide().getSubExprs());
                it.remove();
                // Do not register an empty, uncompiled Evidence for a rejected statement.
                continue;
            }
            Evidence stepEvidence = new Evidence();
            stepEvidence.addValueEvidence(statement);
            stepEvidence.compile();
            stepEvidence.setTime(time);

            Evidence existing = this.evidences.get(time);
            if (existing == null) {
                this.evidences.put(time, stepEvidence);
            } else {
                existing.addAll(stepEvidence);
            }
        }
    }

    /**
     * Parses the time information of each {@link ArgSpecQuery} in {@code list} and records the
     * query under its time step, raising {@link #maxTime} when needed.
     *
     * <p>Queries that fail the type/scope check are removed from {@code list} and not recorded.
     * The parameter stays raw to remain override-compatible with the superclass.
     */
    @Override // fove.LiftedVarElim, blog.InferenceEngine
    @SuppressWarnings("rawtypes")
    public void setQueries(List list) {
        Iterator<?> it = list.iterator();
        while (it.hasNext()) {
            ArgSpecQuery query = (ArgSpecQuery) it.next();
            parseQuery(query);
            int queriedTime = query.getQueriedTime().intValue();
            if (queriedTime < 1) {
                System.out.println("Not assigned to a timestep: " + query);
            } else {
                Model sliceModel = queriedTime == 1 ? this.m1 : this.mt;
                if (!query.checkTypesAndScope(sliceModel)) {
                    System.out.println("check failed");
                    System.out.println("Vars " + query + " subexpr: " + query.argSpec().getSubExprs());
                    it.remove();
                    // A rejected query is not compiled; do not schedule it for answering.
                    continue;
                }
                query.compile();
            }
            Integer timeStep = query.getTimeStep();
            addToBucket(this.queries, timeStep, query);
            if (timeStep.intValue() > this.maxTime) {
                this.maxTime = timeStep.intValue();
            }
        }
    }

    /**
     * Processes time steps {@code t..maxTime} in order. For each step it sets the recorded
     * evidence (empty if none) on the responsible engine and answers the step's queries.
     */
    @Override // fojt.LiftedJTEngine, fove.LiftedVarElim, blog.InferenceEngine
    public void answerQueries() {
        Statistics statistics = Statistics.getInstance();
        this.alpha = new Message[this.maxTime];
        List<Query> noQueries = new ArrayList<>();
        Evidence noEvidence = new Evidence();
        while (this.t <= this.maxTime) {
            Integer step = Integer.valueOf(this.t);
            LiftedJTEngine engine;
            if (this.t == 1) {
                engine = this.ljt0;
            } else {
                long start = System.nanoTime();
                this.ljtt.resetJTree();
                this.ljtt.addAlphaT(this.alpha[this.t - 2]);
                this.queryTime += System.nanoTime() - start;
                engine = this.ljtt;
            }
            Evidence stepEvidence = this.evidences.get(step);
            engine.setEvidence(stepEvidence == null ? noEvidence : stepEvidence);
            List<Query> stepQueries = this.queries.get(step);
            distributeAndAnswerQueries(stepQueries == null ? noQueries : stepQueries);
            this.t++;
        }
        // NOTE(review): query time is reported through addSplitTime; confirm this is intended.
        statistics.addSplitTime(this.queryTime);
    }

    /** Computes the out-message over {@link #i2} and renames it from slice "2" to slice "1". */
    private Message forwardPass() {
        Message message = this.ljtt.calculateAlphaT(this.i2);
        message.renameParfactors("2", "1", this.mt);
        return message;
    }

    /**
     * Answers the queries registered for the current step {@link #t}.
     *
     * <p>Queries are split by queried time versus time step:
     * <ul>
     *   <li>equal: answered at the current step, after which {@code alpha[t - 1]} is computed;</li>
     *   <li>earlier: answered after stepping backward, latest queried time first;</li>
     *   <li>later: answered after stepping forward, earliest queried time first.</li>
     * </ul>
     */
    private void distributeAndAnswerQueries(List<Query> list) {
        int current = this.t;
        List<Query> filtering = new ArrayList<>();
        TreeMap<Integer, List<Query>> backward =
                new TreeMap<Integer, List<Query>>(Collections.<Integer>reverseOrder());
        TreeMap<Integer, List<Query>> forward = new TreeMap<>();
        for (Query q : list) {
            ArgSpecQuery query = (ArgSpecQuery) q;
            // Compare values, not Integer references (== is wrong outside the Integer cache).
            int queriedTime = query.getQueriedTime().intValue();
            int timeStep = query.getTimeStep().intValue();
            if (queriedTime == timeStep) {
                filtering.add(query);
            } else if (queriedTime < timeStep) {
                addToBucket(backward, query.getQueriedTime(), query);
            } else {
                addToBucket(forward, query.getQueriedTime(), query);
            }
        }

        long start = System.nanoTime();
        if (current == 1) {
            this.ljt0.setQueries(filtering);
            this.ljt0.answerQueries();
            this.alpha[current - 1] = this.ljt0.calculateAlphaT(this.i1);
        } else if (current > 1) {
            this.ljtt.setQueries(filtering);
            this.ljtt.answerQueries();
            this.alpha[current - 1] = forwardPass();
        } else {
            System.out.println("something went wrong");
        }

        // Backward: walk down from the current step, one beta message per step.
        for (Map.Entry<Integer, List<Query>> entry : backward.entrySet()) {
            int target = entry.getKey().intValue();
            while (target != current && target > 0) {
                current--;
                backwardPass(current, target == current);
            }
            if (current == 1) {
                this.ljt0.setQueries(entry.getValue());
                this.ljt0.answerQueries();
            } else if (current > 1) {
                this.ljtt.setQueries(entry.getValue());
                this.ljtt.answerQueries();
            }
        }

        // Forward: walk up from t, reusing stored alpha messages when available.
        int ahead = this.t;
        for (Map.Entry<Integer, List<Query>> entry : forward.entrySet()) {
            int target = entry.getKey().intValue();
            while (target != ahead && target > 0) {
                ahead++;
                Message message = ahead - 1 > this.maxTime ? null : this.alpha[ahead - 2];
                if (message == null) {
                    message = forwardPass();
                }
                this.ljtt.resetJTree();
                this.ljtt.addAlphaT(message);
                // The target step is calibrated by answerQueries itself.
                if (target != ahead) {
                    this.ljtt.passMessages(this.a);
                }
            }
            this.ljtt.setQueries(entry.getValue());
            this.ljtt.answerQueries();
        }
        this.queryTime += System.nanoTime() - start;
    }

    /**
     * Computes the beta message for step {@code i} from the temporal tree and rebuilds the tree
     * for that step.
     *
     * @param i the step being moved back to
     * @param z {@code true} if {@code i} is the queried step, in which case messages are not
     *     passed here
     */
    private void backwardPass(int i, boolean z) {
        this.betaT = this.ljtt.calculateBetaT(this.i1, this.alpha[i - 1]);
        if (i != 1) {
            this.betaT.renameParfactors("1", "2", this.mt);
        }
        reinstantiateJTree(i, z);
    }

    /**
     * Resets the engine responsible for step {@code i} and loads {@link #betaT}, the incoming
     * alpha (for {@code i > 1}) and the step's evidence. Messages are passed unless {@code z}.
     */
    private void reinstantiateJTree(int i, boolean z) {
        System.out.println("reinstantiating fojt for " + i);
        LiftedJTEngine engine = i == 1 ? this.ljt0 : this.ljtt;
        engine.resetJTree();
        engine.addBetaT(this.betaT);
        if (i != 1) {
            Message incoming = this.alpha[i - 2];
            engine.addAlphaT(incoming);
            // Log the message actually added (previously logged alpha[t - 2]).
            System.out.println("Alpha Message added: " + incoming);
        }
        addEvidence(i, engine);
        if (!z) {
            engine.passMessages(this.a);
        }
    }

    /** Sets the evidence recorded for step {@code i} (empty if none) on {@code liftedJTEngine}. */
    private void addEvidence(int i, LiftedJTEngine liftedJTEngine) {
        Evidence stepEvidence = this.evidences.get(Integer.valueOf(i));
        liftedJTEngine.setEvidence(stepEvidence == null ? new Evidence() : stepEvidence);
    }

    /**
     * Removes every argument of type {@code Timestep} from the statement's left side.
     *
     * <p>The time is parsed from the argument's text after its first character. Every argument is
     * cast to {@link FuncAppTerm}.
     *
     * @return the parsed time (the last one if several), or 0 if there is none
     */
    private Integer parseEvidenceStatement(ValueEvidenceStatement valueEvidenceStatement) {
        int time = 0;
        List<Object> subExprs = new ArrayList<Object>(valueEvidenceStatement.getLeftSide().getSubExprs());
        for (Iterator<Object> it = subExprs.iterator(); it.hasNext(); ) {
            FuncAppTerm term = (FuncAppTerm) it.next();
            if (term.getType().getName().equals("Timestep")) {
                time = Integer.parseInt(term.toString().substring(1));
                it.remove();
            }
        }
        FuncAppTerm[] remaining = subExprs.toArray(new FuncAppTerm[subExprs.size()]);
        ((FuncAppTerm) valueEvidenceStatement.getLeftSide()).setSubExprs(remaining);
        return Integer.valueOf(time);
    }

    /**
     * Removes {@code Timestep} arguments from the query's term.
     *
     * <p>The first one sets the time step while it is still -1; later ones set the queried time,
     * with values below 1 replaced by 2.
     */
    private void parseQuery(ArgSpecQuery argSpecQuery) {
        List<Object> subExprs = new ArrayList<Object>(argSpecQuery.argSpec().getSubExprs());
        for (Iterator<Object> it = subExprs.iterator(); it.hasNext(); ) {
            FuncAppTerm term = (FuncAppTerm) it.next();
            if (!term.getType().getName().equals("Timestep")) {
                continue;
            }
            int time = Integer.parseInt(term.toString().substring(1));
            it.remove();
            if (argSpecQuery.getTimeStep().intValue() == -1) {
                argSpecQuery.setTimeStep(Integer.valueOf(time));
            } else {
                // NOTE(review): the reason for mapping values < 1 to 2 is not visible here.
                argSpecQuery.setQueriedTime(Integer.valueOf(time < 1 ? 2 : time));
            }
        }
        FuncAppTerm[] remaining = subExprs.toArray(new FuncAppTerm[subExprs.size()]);
        ((FuncAppTerm) argSpecQuery.argSpec()).setSubExprs(remaining);
    }

    /**
     * Builds {@link #i1} and {@link #i2} as uniform parfactors.
     *
     * <p>{@code i1} covers the terms of {@code set} whose name contains "1". Its logical
     * variables are those of {@code set} whose name occurs in such a term, and its constraint
     * combines all constraints of {@code set}. {@code i2} has the same variables and constraint,
     * over the matching terms of {@code set2} found by {@link #findTermInNextTimeSlice}.
     */
    private void createInterfaceParfactors(Set<Parfactor> set, Set<Parfactor> set2) {
        List<Term> interfaceTerms = new ArrayList<>();
        List<LogicalVar> interfaceVars = new ArrayList<>();
        Constraint constraint = Constraint.EMPTY;
        for (Parfactor parfactor : set) {
            for (Term term : new LinkedHashSet<Term>(parfactor.dimTerms())) {
                if (term.toString().contains("1") && !interfaceTerms.contains(term)) {
                    interfaceTerms.add(term);
                }
            }
            for (LogicalVar logicalVar : parfactor.logicalVars()) {
                if (interfaceVars.contains(logicalVar)) {
                    continue;
                }
                String name = logicalVar.toString();
                for (Term term : interfaceTerms) {
                    if (term.toString().contains(name)) {
                        interfaceVars.add(logicalVar);
                        break;
                    }
                }
            }
            constraint = new Constraint(parfactor.constraint(), constraint);
        }
        this.i1 = Parfactor.uniformDistribution(interfaceVars, constraint, interfaceTerms);

        List<Term> nextTerms = new ArrayList<>();
        for (Term term : interfaceTerms) {
            nextTerms.add(findTermInNextTimeSlice(term, set2));
        }
        this.i2 = Parfactor.uniformDistribution(interfaceVars, constraint, nextTerms);
    }

    /**
     * Finds the term of {@code set} whose string equals {@code term}'s string with every "1"
     * replaced by "2".
     *
     * @return the matching term, or {@code null} if none exists
     */
    private Term findTermInNextTimeSlice(Term term, Set<Parfactor> set) {
        String wanted = term.toString().replace("1", "2");
        for (Parfactor parfactor : set) {
            for (Term candidate : new LinkedHashSet<Term>(parfactor.dimTerms())) {
                if (candidate.toString().equals(wanted)) {
                    return candidate;
                }
            }
        }
        return null;
    }

    /** Returns the temporal model used for time steps greater than 1. */
    public Model getTimeModel() {
        return this.mt;
    }

    /** Delegates to the temporal engine. */
    @Override // fojt.LiftedJTEngine, fove.LiftedVarElim
    public void storeStats() {
        this.ljtt.storeStats();
    }

    /** Delegates to the temporal engine. */
    @Override // fojt.LiftedJTEngine
    public void printFOJT() {
        this.ljtt.printFOJT();
    }

    /** Overrides the last time step processed by {@link #answerQueries()}. */
    public void setMaxTime(int i) {
        this.maxTime = i;
    }

    /** Returns the junction tree of the initial-step engine. */
    public FOJT getFirstFojt() {
        return this.ljt0.getFOJT();
    }

    /** Returns the junction tree of the temporal engine. */
    public FOJT getTemporalFojt() {
        return this.ljtt.getFOJT();
    }

    /** Returns the interface parfactor {@code i1}. */
    public Parfactor getInParfactor() {
        return this.i1;
    }

    /** Returns the interface parfactor {@code i2}. */
    public Parfactor getOutParfactor() {
        return this.i2;
    }

    /** Appends {@code query} to the list stored under {@code key}, creating it if absent. */
    private static void addToBucket(Map<Integer, List<Query>> buckets, Integer key, Query query) {
        List<Query> bucket = buckets.get(key);
        if (bucket == null) {
            bucket = new ArrayList<>();
            buckets.put(key, bucket);
        }
        bucket.add(query);
    }
}
