package blog;

import blog.DependencyModel;
import common.Util;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Properties;
import java.util.Set;

/* loaded from: input_file:blog/GenericProposer.class */
/**
 * Proposer that, on each call to {@link #proposeNextState}, picks one instantiated
 * non-evidence variable uniformly at random, resamples it from its CPD, instantiates
 * whatever the new value requires, and then uninstantiates variables that became barren.
 *
 * <p>The initial world is produced by an {@link LWSampler}, which is re-run until it yields
 * a world with positive weight.
 */
public class GenericProposer implements Proposer {
    /** Model whose worlds this proposer operates on. */
    protected Model model;
    /** Sampler used to construct the initial world (an {@link LWSampler}). */
    protected Sampler initialStateSampler;
    /** Queries registered through {@link #addQueries(List)}. */
    protected List queries;
    /** Observed variables; never chosen for resampling and never uninstantiated as barren. */
    protected Set evidenceVars;
    /** Number of distinct elements of {@link #evidenceVars} that are {@link BasicVar}s. */
    protected int numBasicEvidenceVars;
    /** Variables referenced by the queries; never uninstantiated as barren. */
    protected Set queryVars;
    /** Log-probability accumulated for the forward move of the current proposal. */
    protected double logProbForward;
    /** Log-probability accumulated for the reverse move of the current proposal. */
    protected double logProbBackward;
    /** All evidence added so far; {@code null} until {@link #initialize} has been called. */
    protected Evidence evidence = null;
    /** Number of times {@link #constructInitialState()} has run. */
    protected int numTrials = 0;
    /** Initial-world samples drawn across all trials. */
    protected int totalNumInitialStateTries = 0;
    /** Initial-world samples drawn during the most recent trial. */
    protected int numInitialStateTriesThisTrial = 0;

    /**
     * Result of {@link #pickVarToSample}: the chosen variable and the number of candidates
     * it was chosen among.
     */
    // NOTE(review): the decompiler reports this class was originally protected.
    public class PickVarToSampleResult {
        /** Variable chosen for resampling, or {@code null} if there was nothing to choose. */
        public VarWithDistrib varToSample;
        /** Number of candidate variables the choice was made among. */
        public int numberOfChoices;

        /**
         * @param varWithDistrib the chosen variable (may be {@code null})
         * @param i              the number of candidates
         */
        public PickVarToSampleResult(VarWithDistrib varWithDistrib, int i) {
            this.varToSample = varWithDistrib;
            this.numberOfChoices = i;
        }
    }

    /**
     * Creates a proposer for the given model.
     *
     * @param model      the model to propose worlds for
     * @param properties configuration passed through to the initial-state {@link LWSampler}
     */
    public GenericProposer(Model model, Properties properties) {
        this.model = model;
        this.initialStateSampler = new LWSampler(model, properties);
    }

    /**
     * Resets all evidence and query bookkeeping, registers the given evidence and queries,
     * and builds an initial world.
     *
     * @param evidence the evidence to condition on
     * @param list     the queries to answer
     * @return a diff whose saved world is empty and whose current world is the initial world
     */
    @Override // blog.Proposer
    public PartialWorldDiff initialize(Evidence evidence, List list) {
        this.evidence = new Evidence();
        this.evidenceVars = new HashSet();
        this.numBasicEvidenceVars = 0;
        this.queries = new LinkedList();
        this.queryVars = new LinkedHashSet();
        add(evidence);
        addQueries(list);
        return constructInitialState();
    }

    /**
     * Adds more evidence. Each observed variable is recorded in {@link #evidenceVars}.
     * {@link #numBasicEvidenceVars} is incremented once for each {@link BasicVar} that was
     * not already recorded. Must be called after {@link #initialize}.
     *
     * @param evidence the evidence to add
     */
    public void add(Evidence evidence) {
        this.evidence.addAll(evidence);
        // Insert the variables individually (not the collection object itself), so that
        // removeAll/contains against evidenceVars see real variables.
        for (Object var : evidence.getEvidenceVars()) {
            // Set.add returns false for an already-recorded variable, so repeated calls
            // never count the same basic variable twice.
            if (this.evidenceVars.add(var) && var instanceof BasicVar) {
                this.numBasicEvidenceVars++;
            }
        }
    }

    /**
     * Registers queries and records every variable they reference in {@link #queryVars}.
     *
     * @param list list of {@link Query} objects
     */
    public void addQueries(List list) {
        this.queries.addAll(list);
        for (Object query : list) {
            this.queryVars.addAll(((Query) query).getVariables());
        }
    }

    /**
     * Repeatedly samples from {@link #initialStateSampler} until a world with positive
     * weight is obtained, updating the trial/try counters along the way.
     *
     * @return a diff from an empty world (with the same id types) to the accepted world
     */
    protected PartialWorldDiff constructInitialState() {
        PartialWorld latestWorld;
        this.numTrials++;
        this.numInitialStateTriesThisTrial = 0;
        this.initialStateSampler.initialize(this.evidence, this.queries);
        while (true) {
            this.initialStateSampler.nextSample();
            latestWorld = this.initialStateSampler.getLatestWorld();
            this.totalNumInitialStateTries++;
            this.numInitialStateTriesThisTrial++;
            if (this.initialStateSampler.getLatestWeight() > 0.0d) {
                break;
            }
            if (Util.verbose()) {
                System.out.println(this.numInitialStateTriesThisTrial + "th initial world rejected.");
            }
        }
        if (Util.verbose()) {
            System.out.println("Probability of " + this.numInitialStateTriesThisTrial + "th initial state = " + this.initialStateSampler.getLatestWeight());
            latestWorld.print(System.out);
        }
        return new PartialWorldDiff(new DefaultPartialWorld(latestWorld.getIdTypes()), latestWorld);
    }

    /**
     * Picks uniformly among the instantiated variables of {@code partialWorld} that are not
     * evidence variables.
     *
     * @param partialWorld the current world
     * @return the chosen variable (as returned by {@code Util.uniformSample}) and the number
     *         of candidates
     */
    protected PickVarToSampleResult pickVarToSample(PartialWorld partialWorld) {
        HashSet candidates = new HashSet(partialWorld.getInstantiatedVars());
        candidates.removeAll(this.evidenceVars);
        return new PickVarToSampleResult((VarWithDistrib) Util.uniformSample(candidates), candidates.size());
    }

    /**
     * Proposes a new world by modifying {@code partialWorldDiff} in place. The steps are:
     * <ol>
     *   <li>resample one uniformly chosen non-evidence variable;</li>
     *   <li>instantiate what its children need;</li>
     *   <li>uninstantiate variables that became barren (excluding evidence and query
     *       variables), cascading to parents left without children.</li>
     * </ol>
     *
     * @param partialWorldDiff the world to modify
     * @return {@code logProbBackward - logProbForward}, i.e. the log of the ratio of reverse
     *         to forward proposal probabilities; {@code 0.0} if there was no variable to
     *         sample (the world is unchanged)
     * @throws IllegalStateException if {@link #initialize} has not been called
     */
    @Override // blog.Proposer
    public double proposeNextState(PartialWorldDiff partialWorldDiff) {
        if (this.evidence == null) {
            throw new IllegalStateException("initialize() has not been called on proposer.");
        }
        this.logProbForward = 0.0d;
        this.logProbBackward = 0.0d;
        PickVarToSampleResult pickVarToSample = pickVarToSample(partialWorldDiff);
        if (pickVarToSample.varToSample == null) {
            // Nothing changed, so forward and reverse moves coincide: log-ratio is log(1) = 0.
            return 0.0d;
        }
        if (Util.verbose()) {
            System.out.println("  sampling " + pickVarToSample.varToSample);
        }
        // Forward probability of choosing this particular variable.
        this.logProbForward += -Math.log(pickVarToSample.numberOfChoices);
        sampleValue(pickVarToSample.varToSample, partialWorldDiff);

        // Uninstantiate newly barren variables. Removing a variable may leave its parents
        // childless, so those parents are queued for the same treatment.
        LinkedList barrenQueue = new LinkedList(partialWorldDiff.getNewlyBarrenVars());
        while (!barrenQueue.isEmpty()) {
            BayesNetVar barrenVar = (BayesNetVar) barrenQueue.removeFirst();
            if (!this.evidenceVars.contains(barrenVar) && !this.queryVars.contains(barrenVar)) {
                // Capture the parents before the variable is uninstantiated.
                Set<BayesNetVar> parents = partialWorldDiff.getBayesNet().getParents(barrenVar);
                if (barrenVar instanceof VarWithDistrib) {
                    // The reverse move would have to regenerate the old value.
                    this.logProbBackward += partialWorldDiff.getSaved().getLogProbOfValue(barrenVar);
                    partialWorldDiff.setValue((VarWithDistrib) barrenVar, null);
                }
                for (BayesNetVar parent : parents) {
                    if (partialWorldDiff.getBayesNet().getChildren(parent).isEmpty()) {
                        barrenQueue.addLast(parent);
                    }
                }
            }
        }
        // Reverse probability of choosing the same variable among the non-evidence
        // instantiated variables of the proposed world.
        this.logProbBackward += -Math.log(partialWorldDiff.getInstantiatedVars().size() - this.numBasicEvidenceVars);
        return this.logProbBackward - this.logProbForward;
    }

    /**
     * Prints the number of initial-world attempts in the latest trial, and, if any trial has
     * run, the average number of attempts per trial.
     */
    @Override // blog.Proposer
    public void printStats() {
        System.out.println("===== " + getClass().getName() + " Stats ====");
        System.out.println("Initial world attempts: " + this.numInitialStateTriesThisTrial);
        if (this.numTrials > 0) {
            // Floating-point division so the average is not truncated.
            System.out.println("\tRunning average (for trials so far): " + ((double) this.totalNumInitialStateTries / this.numTrials));
        }
    }

    /** Does nothing; this proposer keeps no acceptance statistics. */
    @Override // blog.Proposer
    public void updateStats(boolean z) {
    }

    /**
     * Resamples {@code varWithDistrib} from its CPD and instantiates whatever its children
     * then need. Updates both proposal log-probabilities along the way.
     *
     * @param varWithDistrib the variable to resample
     * @param partialWorld   the world to modify
     */
    private void sampleValue(VarWithDistrib varWithDistrib, PartialWorld partialWorld) {
        // Snapshot the children before the value changes.
        Set children = partialWorld.getBayesNet().getChildren(varWithDistrib);
        DependencyModel.Distrib distrib = varWithDistrib.getDistrib(new DefaultEvalContext(partialWorld, true));
        // Reverse move: probability of regenerating the current (old) value.
        this.logProbBackward += Math.log(distrib.getCPD().getProb(distrib.getArgValues(), partialWorld.getValue(varWithDistrib)));
        Object sampleVal = distrib.getCPD().sampleVal(distrib.getArgValues(), varWithDistrib.getType());
        partialWorld.setValue(varWithDistrib, sampleVal);
        // Forward move: probability of the value just drawn.
        this.logProbForward += Math.log(distrib.getCPD().getProb(distrib.getArgValues(), sampleVal));
        // Let each child instantiate what it needs; the context accumulates the
        // log-probability of any values it had to sample.
        InstantiatingEvalContext instantiatingEvalContext = new InstantiatingEvalContext(partialWorld);
        for (Object child : children) {
            ((BayesNetVar) child).ensureDetAndSupported(instantiatingEvalContext);
        }
        this.logProbForward += instantiatingEvalContext.getLogProbability();
    }
}
