package common;

import java.io.Serializable;

/* loaded from: input_file:common/Bernoulli.class */
/**
 * A Bernoulli distribution over {@code boolean} outcomes whose single parameter (the probability
 * of {@code true}) can be re-estimated from observed samples by maximum likelihood.
 *
 * <p>Usage: call {@link #collectStats(boolean)} once per observation, then {@link #updateParams()}
 * to replace the parameter with its maximum-likelihood estimate and reset the counters. The
 * counters are {@code transient}, so only the parameter itself is serialized.
 */
public class Bernoulli implements Serializable {
    /** Probability that the outcome is {@code true}; always within [0, 1]. */
    double probTrue;
    /** Number of samples collected since the last {@link #updateParams()}. */
    transient int totalCount;
    /** Number of collected samples that were {@code true}. */
    transient int numTrue;

    /** Creates a fair distribution ({@code P(true) = 0.5}). */
    public Bernoulli() {
        this.probTrue = 0.5d;
    }

    /**
     * Creates a distribution with the given probability of {@code true}.
     *
     * @param d probability of {@code true}; must lie in [0, 1]
     * @throws IllegalArgumentException if {@code d} is outside [0, 1] or is NaN
     */
    public Bernoulli(double d) {
        // Written as a negated range check so that NaN (for which every comparison is false)
        // is rejected as well.
        if (!(d >= 0.0d && d <= 1.0d)) {
            throw new IllegalArgumentException("Illegal probability: " + d);
        }
        this.probTrue = d;
    }

    /**
     * Returns the probability of the given outcome.
     *
     * @param z the outcome
     * @return {@code P(true)} if {@code z}, otherwise {@code 1 - P(true)}
     */
    public double getProb(boolean z) {
        return z ? this.probTrue : 1.0d - this.probTrue;
    }

    /**
     * Returns the natural log of the probability of the given outcome.
     *
     * @param z the outcome
     * @return {@code ln(getProb(z))}; {@code -Infinity} if that probability is 0
     */
    public double getLogProb(boolean z) {
        return Math.log(getProb(z));
    }

    /**
     * Records one observed outcome for the next {@link #updateParams()} call.
     *
     * @param z the observed outcome
     */
    public void collectStats(boolean z) {
        this.totalCount++;
        if (z) {
            this.numTrue++;
        }
    }

    /**
     * Replaces {@code P(true)} with its maximum-likelihood estimate from the collected samples and
     * clears the counters. If no samples were collected, the parameter is left unchanged.
     *
     * @return the increase in log-likelihood of the collected samples (new minus old); 0 when no
     *     samples were collected. It is {@code +Infinity} if the old parameter gave the observed
     *     data zero probability.
     */
    public double updateParams() {
        double before = logLikelihood(this.numTrue, this.totalCount, this.probTrue);
        if (this.totalCount > 0) {
            // Cast before dividing: int/int division would truncate the estimate to 0 or 1.
            this.probTrue = (double) this.numTrue / this.totalCount;
        }
        double after = logLikelihood(this.numTrue, this.totalCount, this.probTrue);
        this.totalCount = 0;
        this.numTrue = 0;
        return after - before;
    }

    /**
     * Log-likelihood of {@code numTrue} successes out of {@code total} trials under success
     * probability {@code p}. Uses the convention {@code 0 * ln(0) = 0}, so that an outcome that
     * never occurred contributes nothing. Without it, a probability of 0 or 1 would yield
     * {@code 0 * -Infinity = NaN}.
     */
    private static double logLikelihood(int numTrue, int total, double p) {
        int numFalse = total - numTrue;
        double trueTerm = numTrue == 0 ? 0.0d : numTrue * Math.log(p);
        double falseTerm = numFalse == 0 ? 0.0d : numFalse * Math.log(1.0d - p);
        return trueTerm + falseTerm;
    }
}
