package common;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;

/**
 * Integer distribution over the non-negative integers that mixes a {@link Multinomial}
 * component, used for values {@code 0..k-1}, with a {@link Geometric} tail shifted to start
 * at {@code k}.
 *
 * <p>With probability {@code lambda} a value comes from the multinomial; otherwise it is
 * {@code k} plus a draw from the geometric. Parameters are re-estimated from the counts
 * gathered by {@link #collectStats(int)} when {@link #updateParams()} is called.
 *
 * <p>NOTE(review): field declarations and non-private member signatures are deliberately left
 * unchanged. The class has no explicit {@code serialVersionUID}, so altering them would change
 * the computed one and break reading previously serialized instances.
 */
public class MultinomialWithTail implements Serializable, IntegerDist {
    /** Number of values ({@code 0..k-1}) routed to the multinomial component. */
    int k;
    /** Component used for values below {@code k}. */
    Multinomial multinomial;
    /** Tail component; a value {@code i >= k} is scored by it at {@code i - k}. */
    Geometric geometric;
    /** Probability of drawing from the multinomial rather than from the tail. */
    double lambda;
    /** Cached {@code Math.log(lambda)}; rebuilt by {@link #cacheParams()}, not serialized. */
    transient double logLambda;
    /** Cached {@code Math.log(1 - lambda)}; rebuilt by {@link #cacheParams()}, not serialized. */
    transient double logOneMinusLambda;
    /** Observations passed to {@link #collectStats(int)} since the last {@link #updateParams()}. */
    transient int totalCount;
    /** How many of those observations were below {@code k}. */
    transient int multinomialCount;

    /**
     * Creates a distribution whose multinomial part covers {@code k} values. It builds the
     * components with {@code new Multinomial(i)} and {@code new Geometric()}, and uses an even
     * mixing weight ({@code lambda = 0.5}).
     *
     * @param i the number of values {@code k} handled by the multinomial component
     */
    public MultinomialWithTail(int i) {
        this.k = i;
        this.multinomial = new Multinomial(i);
        this.geometric = new Geometric();
        this.lambda = 0.5d;
        cacheParams();
    }

    /**
     * Creates a distribution with explicit component parameters.
     *
     * @param dArr parameters passed to {@code new Multinomial(dArr)} (presumably per-value
     *     probabilities — confirm against {@code Multinomial}); its length becomes {@code k}
     * @param d the mixing weight {@code lambda}, i.e. the probability of drawing from the
     *     multinomial
     * @param d2 the parameter passed to {@code new Geometric(d2)} for the tail component
     * @throws IllegalArgumentException if {@code d} is outside {@code [0, 1]} or is NaN
     */
    public MultinomialWithTail(double[] dArr, double d, double d2) {
        this.k = dArr.length;
        this.multinomial = new Multinomial(dArr);
        this.geometric = new Geometric(d2);
        // Negated range check so NaN is rejected too; "d < 0 || d > 1" is false for NaN.
        if (!(d >= 0.0d && d <= 1.0d)) {
            throw new IllegalArgumentException("Mixing parameter lambda must be between 0 and 1");
        }
        this.lambda = d;
        cacheParams();
    }

    /**
     * Returns the natural-log probability of {@code i}.
     *
     * @param i the value to score
     * @return {@code log(lambda) + multinomial.getLogProb(i)} when {@code 0 <= i < k};
     *     {@code log(1 - lambda) + geometric.getLogProb(i - k)} when {@code i >= k};
     *     {@link Double#NEGATIVE_INFINITY} when {@code i < 0}, since this distribution never
     *     produces negative values
     */
    @Override // common.IntegerDist
    public double getLogProb(int i) {
        if (i < 0) {
            // Probability 0 means log-probability -infinity; returning 0.0 would mean probability 1.
            return Double.NEGATIVE_INFINITY;
        }
        if (i < this.k) {
            return this.logLambda + this.multinomial.getLogProb(i);
        }
        return this.logOneMinusLambda + this.geometric.getLogProb(i - this.k);
    }

    /**
     * Returns the probability of {@code i}, computed as {@code Math.exp(getLogProb(i))}.
     *
     * @param i the value to score
     * @return the probability of {@code i}; 0 for negative {@code i}
     */
    @Override // common.IntegerDist
    public double getProb(int i) {
        return Math.exp(getLogProb(i));
    }

    /**
     * Records one observation for the next {@link #updateParams()}. Values below {@code k} are
     * forwarded to the multinomial. Values {@code >= k} are forwarded to the geometric as
     * {@code i - k}.
     *
     * @param i the observed value
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public void collectStats(int i) {
        if (i < 0) {
            throw new IllegalArgumentException("MultinomialWithTail can't generate a negative number.");
        }
        this.totalCount++;
        if (i >= this.k) {
            this.geometric.collectStats(i - this.k);
        } else {
            this.multinomialCount++;
            this.multinomial.collectStats(i);
        }
    }

    /**
     * Draws a value. With probability {@code lambda} it returns a multinomial sample;
     * otherwise it returns {@code k} plus a geometric sample. Assumes {@code Util.random()}
     * returns a uniform double in {@code [0, 1)} — confirm.
     *
     * @return a sampled non-negative integer
     */
    @Override // common.IntegerDist
    public int sample() {
        return Util.random() < this.lambda ? this.multinomial.sample() : this.geometric.sample() + this.k;
    }

    /**
     * Re-estimates {@code lambda} as the fraction of collected observations below {@code k}
     * (only if at least one observation was collected). It then updates both components and
     * resets the counts.
     *
     * @return the change in the mixing-weight log-likelihood of the collected counts, plus the
     *     values returned by {@code multinomial.updateParams()} and
     *     {@code geometric.updateParams()} (presumably their own log-likelihood changes — confirm)
     */
    public double updateParams() {
        double oldMixingLogLik = mixingLogLikelihood();
        if (this.totalCount > 0) {
            // Cast before dividing: int / int truncates and would force lambda to exactly 0 or 1.
            this.lambda = (double) this.multinomialCount / this.totalCount;
            cacheParams();
        }
        // Evaluated left to right: new mixing term, then multinomial, then geometric update.
        double improvement = (mixingLogLikelihood() - oldMixingLogLik)
                + this.multinomial.updateParams()
                + this.geometric.updateParams();
        this.totalCount = 0;
        this.multinomialCount = 0;
        return improvement;
    }

    /** Log-likelihood of the current counts under the current mixing weight {@code lambda}. */
    private double mixingLogLikelihood() {
        int tailCount = this.totalCount - this.multinomialCount;
        return weightedLog(this.multinomialCount, this.logLambda)
                + weightedLog(tailCount, this.logOneMinusLambda);
    }

    /**
     * Returns {@code count * logValue}, treating a zero count as contributing 0. This matters
     * when {@code lambda} is 0 or 1: the cached log is -infinity, and 0 * -infinity is NaN in
     * IEEE arithmetic.
     */
    private static double weightedLog(int count, double logValue) {
        return count == 0 ? 0.0d : count * logValue;
    }

    /** Restores the transient cached logs after default deserialization. */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        cacheParams();
    }

    /** Recomputes {@link #logLambda} and {@link #logOneMinusLambda} from {@link #lambda}. */
    void cacheParams() {
        this.logLambda = Math.log(this.lambda);
        this.logOneMinusLambda = Math.log(1.0d - this.lambda);
    }
}
