package common;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;

/* loaded from: input_file:common/Multinomial.class */
/**
 * A discrete (categorical / multinomial) distribution over the integers {@code 0 .. size()-1}.
 *
 * <p>{@code pi[k]} is the probability of outcome {@code k}. Observations can be accumulated with
 * {@link #collectStats(int)} / {@link #collectAggrStats(int, int)}. {@link #updateParams()} then
 * replaces {@code pi} with the relative frequencies of those observations (the maximum-likelihood
 * estimate) and clears the accumulators.
 *
 * <p>Only {@code pi} is serialized. The accumulators are transient, and {@code counts} is
 * re-created on deserialization.
 */
public class Multinomial implements Serializable, IntegerDist {
    /** Outcome probabilities; {@code pi[k]} is the probability of outcome {@code k}. */
    double[] pi;
    /** Number of observations collected since the last {@link #updateParams()}. */
    transient int totalCount;
    /** Per-outcome observation counts collected since the last {@link #updateParams()}. */
    transient int[] counts;

    /**
     * Creates a uniform distribution over {@code i} outcomes.
     *
     * @param i number of outcomes
     * @throws NegativeArraySizeException if {@code i} is negative
     */
    public Multinomial(int i) {
        this.pi = new double[i];
        for (int k = 0; k < i; k++) {
            this.pi[k] = 1.0d / i;
        }
        this.counts = new int[i];
    }

    /**
     * Creates a distribution with the given outcome probabilities.
     *
     * <p>The array is copied, so later changes to {@code dArr} do not affect this instance.
     *
     * @param dArr probability of each outcome; each must lie in [0, 1] and they must sum to 1
     *     within 1e-9
     * @throws IllegalArgumentException if an entry is outside [0, 1] or NaN, or the entries do not
     *     sum to 1
     */
    public Multinomial(double[] dArr) {
        this.pi = dArr.clone();
        double sum = 0.0d;
        // Validate the copy we actually store.
        for (int i = 0; i < this.pi.length; i++) {
            double p = this.pi[i];
            // Written as !(in range) so that NaN, which fails every comparison, is rejected too.
            if (!(p >= 0.0d && p <= 1.0d)) {
                throw new IllegalArgumentException("Probability " + p + " for element " + i + " is not valid.");
            }
            sum += p;
        }
        if (Math.abs(sum - 1.0d) > 1.0E-9d) {
            throw new IllegalArgumentException("Probabilities sum to " + sum + " rather than 1.0.");
        }
        this.counts = new int[this.pi.length];
    }

    /** Returns the number of outcomes. */
    public int size() {
        return this.pi.length;
    }

    /** Returns the probability of outcome {@code i}. */
    @Override // common.IntegerDist
    public double getProb(int i) {
        return this.pi[i];
    }

    /** Returns the natural log of the probability of outcome {@code i}; -Infinity if it is 0. */
    @Override // common.IntegerDist
    public double getLogProb(int i) {
        return Math.log(this.pi[i]);
    }

    /**
     * Records one observation of outcome {@code i}.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code i} is not in {@code [0, size())}
     */
    public void collectStats(int i) {
        // Touch the array first so a bad index throws before totalCount is changed.
        this.counts[i]++;
        this.totalCount++;
    }

    /**
     * Records {@code i2} observations of outcome {@code i}.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code i} is not in {@code [0, size())}
     */
    public void collectAggrStats(int i, int i2) {
        // Touch the array first so a bad index throws before totalCount is changed.
        this.counts[i] += i2;
        this.totalCount += i2;
    }

    /**
     * Re-estimates {@code pi} from the collected counts, then clears the counts.
     *
     * <p>If no observations were collected, {@code pi} is left unchanged and 0 is returned.
     *
     * @return the log-likelihood of the collected observations under the new parameters minus
     *     their log-likelihood under the old parameters. Outcomes with zero count contribute
     *     nothing (0 * log 0 is taken as 0).
     */
    public double updateParams() {
        double oldLogLik = 0.0d;
        double newLogLik = 0.0d;
        if (this.totalCount > 0) {
            for (int i = 0; i < this.counts.length; i++) {
                int count = this.counts[i];
                // Cast before dividing: int / int would truncate every probability to 0 or 1.
                double newP = (double) count / this.totalCount;
                // Skip zero counts: 0 * Math.log(0) is 0 * -Infinity == NaN in Java.
                if (count != 0) {
                    oldLogLik += count * Math.log(this.pi[i]);
                    newLogLik += count * Math.log(newP);
                }
                this.pi[i] = newP;
            }
        }
        this.totalCount = 0;
        for (int i = 0; i < this.counts.length; i++) {
            this.counts[i] = 0;
        }
        return newLogLik - oldLogLik;
    }

    /**
     * Draws an outcome by inverse-CDF sampling over {@code pi}.
     *
     * <p>Assumes {@code Util.random()} returns a uniform double in [0, 1) — TODO confirm. The
     * final {@code return} covers the case where rounding leaves the cumulative sum just below
     * the drawn value.
     */
    @Override // common.IntegerDist
    public int sample() {
        double u = Util.random();
        double cumulative = 0.0d;
        for (int i = 0; i < this.pi.length; i++) {
            cumulative += this.pi[i];
            if (u < cumulative) {
                return i;
            }
        }
        return this.pi.length - 1;
    }

    /** Restores {@code pi}, then re-creates the transient count array (totalCount defaults to 0). */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        this.counts = new int[this.pi.length];
    }
}
