package blog;

import common.Histogram;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import ve.Factor;
import ve.Potential;

/* loaded from: input_file:blog/ArgSpecQuery.class */
/**
 * Query for the distribution of values of a single {@link ArgSpec}.
 *
 * <p>Each call to {@link #updateStats} adds a sample's weight to the value the
 * argument spec evaluates to in that world; results are reported as each
 * value's weight divided by the histogram's total weight.
 *
 * <p>Output files:
 * <ul>
 *   <li>{@link #logResults} writes one line per observed value to a per-value
 *       file named {@code <Main.outputPath()>-trial<N>.<value>.data}.</li>
 *   <li>When {@code Main.histOut()} is non-null, the final normalized histogram
 *       of each trial is written to {@code <Main.histOut()>-trial<N>.data}.</li>
 * </ul>
 */
public class ArgSpecQuery extends AbstractQuery {

    /** Orders entries by ascending numeric value of their element (elements must be {@link Number}s). */
    private static final Comparator<Histogram.Entry> BY_NUMERIC_VALUE =
        new Comparator<Histogram.Entry>() {
            @Override
            public int compare(Histogram.Entry a, Histogram.Entry b) {
                // Double.compare is a total order (NaN included); chained '<'/'>' tests are not,
                // and an inconsistent comparator can make Collections.sort throw.
                return Double.compare(((Number) a.getElement()).doubleValue(),
                                      ((Number) b.getElement()).doubleValue());
            }
        };

    /** Orders entries from largest to smallest weight. */
    private static final Comparator<Histogram.Entry> BY_WEIGHT_DESCENDING =
        new Comparator<Histogram.Entry>() {
            @Override
            public int compare(Histogram.Entry a, Histogram.Entry b) {
                return Double.compare(b.getWeight(), a.getWeight());
            }
        };

    /**
     * Orders entries whose element's string form is an integer numerically, placing them before
     * all other entries, which are ordered lexicographically by string form. Previously any
     * non-integer element caused a NumberFormatException on the final sample.
     */
    private static final Comparator<Histogram.Entry> BY_INTEGER_ELEMENT =
        new Comparator<Histogram.Entry>() {
            @Override
            public int compare(Histogram.Entry a, Histogram.Entry b) {
                String textA = String.valueOf(a.getElement());
                String textB = String.valueOf(b.getElement());
                Integer intA = parseIntOrNull(textA);
                Integer intB = parseIntOrNull(textB);
                if (intA != null && intB != null) {
                    return intA.compareTo(intB);
                }
                if (intA != null) {
                    return -1;
                }
                if (intB != null) {
                    return 1;
                }
                return textA.compareTo(textB);
            }
        };

    /** The term or formula whose value is queried; may be replaced by {@link #checkTypesAndScope}. */
    protected ArgSpec argSpec;
    /** Variable for {@link #argSpec}; set by {@link #compile()} or {@link #setVariable}, null before. */
    protected BayesNetVar variable;
    /** "From" time step shown by {@link #printResultsWithTimes}; -1 when unset. */
    protected Integer timeStep = -1;
    /** Queried time step shown by {@link #printResultsWithTimes}; -1 when unset. */
    protected Integer queriedTime = -1;
    /** Accumulated (unnormalized) weight of each observed value. */
    protected Histogram histogram = new Histogram();
    /** Index of the current trial; advanced by {@link #zeroOut()} and used in output file names. */
    protected int trialNum = 0;
    /** Per-value streams opened by {@link #logResults} for the current trial (value -> PrintStream). */
    protected Map outputFiles = new HashMap();
    /** Receives the current trial's final histogram; null when {@code Main.histOut()} is null. */
    protected PrintStream outputFile;

    /**
     * Creates a query for {@code argSpec}. If {@code Main.histOut()} is set, opens the histogram
     * output file for trial 0.
     *
     * @param argSpec the term or formula to query
     */
    public ArgSpecQuery(ArgSpec argSpec) {
        this.argSpec = argSpec;
        if (Main.histOut() != null) {
            this.outputFile = Main.filePrintStream(Main.histOut() + "-trial" + this.trialNum + ".data");
        }
    }

    /** Returns the queried argument spec. */
    public ArgSpec argSpec() {
        return this.argSpec;
    }

    /**
     * Prints the estimated distribution, one {@code "\t<probability>\t<value>"} line per value.
     * Numeric queries are sorted by value; others by decreasing probability.
     *
     * @param printStream destination stream
     */
    @Override
    public void printResults(PrintStream printStream) {
        printStream.println("Distribution of values for " + this.argSpec);
        printDistribution(printStream);
    }

    /**
     * Like {@link #printResults}, but the header also names {@link #queriedTime} and
     * {@link #timeStep}.
     *
     * @param printStream destination stream
     */
    public void printResultsWithTimes(PrintStream printStream) {
        printStream.println("Distribution of values for " + this.argSpec + " for timestep: " + this.queriedTime + " from time step " + this.timeStep);
        printDistribution(printStream);
    }

    /** Prints every histogram entry as {@code "\t<weight/total>\t<value>"} in display order. */
    private void printDistribution(PrintStream printStream) {
        List<Histogram.Entry> entries = snapshotEntries();
        Collections.sort(entries, this.argSpec.isNumeric() ? BY_NUMERIC_VALUE : BY_WEIGHT_DESCENDING);
        for (Histogram.Entry entry : entries) {
            printStream.println("\t" + (entry.getWeight() / this.histogram.getTotalWeight()) + "\t" + entry.getElement());
        }
    }

    /**
     * Logs the current normalized weight of each value to that value's per-trial file. When
     * {@code i} equals {@code Main.numSamples()} and {@code Main.histOut()} is set, also writes the
     * full histogram, ordered by integer value, to {@link #outputFile}.
     *
     * @param i sample count at which results are logged
     */
    @Override
    public void logResults(int i) {
        List<Histogram.Entry> entries = snapshotEntries();
        for (Histogram.Entry entry : entries) {
            getOutputFile(entry.getElement()).println("\t" + i + "\t" + (entry.getWeight() / this.histogram.getTotalWeight()));
        }
        if (i != Main.numSamples() || Main.histOut() == null || this.outputFile == null) {
            return;
        }
        Collections.sort(entries, BY_INTEGER_ELEMENT);
        for (Histogram.Entry entry : entries) {
            this.outputFile.println("\t" + entry.getElement() + "\t" + (entry.getWeight() / this.histogram.getTotalWeight()));
        }
        // Make sure the final histogram reaches disk even if the stream is not auto-flushing.
        this.outputFile.flush();
    }

    /**
     * Returns the single variable this query is about.
     *
     * @throws IllegalStateException if {@link #compile()} has not set the variable yet
     */
    @Override
    public Collection<? extends BayesNetVar> getVariables() {
        if (this.variable == null) {
            throw new IllegalStateException("Query has not yet been compiled.");
        }
        return Collections.singleton(this.variable);
    }

    /**
     * Type- and scope-checks the argument spec against {@code model}. A {@link Term} is replaced by
     * the in-scope term returned from {@code getTermInScope}; other specs are checked in place.
     *
     * @return false if checking fails (for a term: no in-scope term was found)
     */
    @Override
    public boolean checkTypesAndScope(Model model) {
        if (!(this.argSpec instanceof Term)) {
            return this.argSpec.checkTypesAndScope(model, Collections.EMPTY_MAP);
        }
        Term termInScope = ((Term) this.argSpec).getTermInScope(model, Collections.EMPTY_MAP);
        if (termInScope == null) {
            return false;
        }
        this.argSpec = termInScope;
        return true;
    }

    /**
     * Compiles the argument spec; on a zero result caches its variable.
     *
     * @return the value returned by {@code ArgSpec.compile} (presumably an error count; 0 is success)
     */
    @Override
    public int compile() {
        int errors = this.argSpec.compile(new LinkedHashSet());
        if (errors == 0) {
            this.variable = this.argSpec.getVariable();
        }
        return errors;
    }

    /**
     * Adds {@code d} to the weight of the value the argument spec takes in {@code partialWorld}.
     */
    @Override
    public void updateStats(PartialWorld partialWorld, double d) {
        this.histogram.increaseWeight(this.argSpec.evaluate(partialWorld), d);
    }

    /**
     * Replaces the histogram with the exact posterior in {@code factor}: each guaranteed object of
     * the factor's only dimension type gets the potential's value for it as its weight.
     *
     * @throws IllegalArgumentException if {@code factor} is not over exactly one random variable
     */
    @Override
    public void setPosterior(Factor factor) {
        // Checking "!= 1" (not "> 1") also rejects empty factors, which would otherwise fail
        // with an opaque IndexOutOfBoundsException at getDims().get(0).
        if (factor.getRandomVars().size() != 1) {
            throw new IllegalArgumentException("Answer to query on " + this.variable + " should be factor on that variable alone, not " + factor.getRandomVars());
        }
        Potential potential = factor.getPotential();
        Type type = potential.getDims().get(0);
        this.histogram.clear();
        for (Object obj : type.getGuaranteedObjects()) {
            this.histogram.increaseWeight(obj, potential.getValue(Collections.singletonList(obj)));
        }
    }

    /**
     * Starts a new trial. Advances {@link #trialNum}, closes the previous trial's per-value streams,
     * clears the histogram, and (unless {@code trialNum} reached {@code Main.numTrials()}) closes
     * the current histogram file and opens the next trial's one.
     */
    @Override
    public void zeroOut() {
        this.trialNum++;
        if (this.outputFile != null && this.trialNum != Main.numTrials()) {
            // The old stream was previously abandoned unclosed, leaking a descriptor per trial.
            closeStream(this.outputFile);
            this.outputFile = Main.filePrintStream(Main.histOut() + "-trial" + this.trialNum + ".data");
        }
        for (Object stream : this.outputFiles.values()) {
            closeStream((PrintStream) stream);
        }
        this.outputFiles = new HashMap();
        this.histogram.clear();
    }

    /** Prints a note that variance is not computed for this query type. */
    @Override
    public void printVarianceResults(PrintStream printStream) {
        printStream.println("\tVariance of " + this.argSpec + " results is not computed.");
    }

    /** Returns (opening on first use) the current trial's per-value log stream for {@code obj}. */
    private PrintStream getOutputFile(Object obj) {
        PrintStream printStream = (PrintStream) this.outputFiles.get(obj);
        if (printStream == null) {
            printStream = Main.filePrintStream(Main.outputPath() + "-trial" + this.trialNum + "." + obj + ".data");
            this.outputFiles.put(obj, printStream);
        }
        return printStream;
    }

    /** Copies the histogram's entries into a fresh, sortable list. */
    @SuppressWarnings("unchecked") // entrySet()'s element type is not visible here; elements are Histogram.Entry
    private List<Histogram.Entry> snapshotEntries() {
        return new ArrayList(this.histogram.entrySet());
    }

    /** Parses {@code text} as an int, returning null if it is not an integer literal. */
    private static Integer parseIntOrNull(String text) {
        try {
            return Integer.valueOf(text);
        } catch (NumberFormatException ignored) {
            // Not an integer: the caller falls back to lexicographic ordering.
            return null;
        }
    }

    /**
     * Closes a stream this query opened. Main.filePrintStream is not visible here, so the standard
     * streams are skipped defensively in case it can ever return one of them.
     */
    private static void closeStream(PrintStream stream) {
        if (stream != null && stream != System.out && stream != System.err) {
            stream.close();
        }
    }

    /** Returns the accumulated histogram (live, not a copy). */
    public Histogram getHistogram() {
        return this.histogram;
    }

    /** Returns the source location of the argument spec. */
    @Override
    public Object getLocation() {
        return this.argSpec.getLocation();
    }

    /** Returns the variable's string form once compiled, else the argument spec's. */
    public String toString() {
        return this.variable == null ? this.argSpec.toString() : this.variable.toString();
    }

    /** Sets the variable this query is about. */
    public void setVariable(BayesNetVar bayesNetVar) {
        this.variable = bayesNetVar;
    }

    /** Replaces the queried argument spec. */
    public void setArg(ArgSpec argSpec) {
        this.argSpec = argSpec;
    }

    /** Sets the "from" time step shown by {@link #printResultsWithTimes}. */
    public void setTimeStep(Integer num) {
        this.timeStep = num;
    }

    /** Sets the queried time step shown by {@link #printResultsWithTimes}. */
    public void setQueriedTime(Integer num) {
        this.queriedTime = num;
    }

    /** Returns the "from" time step (-1 if unset). */
    public Integer getTimeStep() {
        return this.timeStep;
    }

    /** Returns the queried time step (-1 if unset). */
    public Integer getQueriedTime() {
        return this.queriedTime;
    }
}
