package ve;

import common.TupleIterator;
import common.Util;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/* loaded from: input_file:ve/MultiArray.class */
/**
 * A dense multi-dimensional array of {@code double} values stored in row-major order.
 *
 * <p>Dimension {@code d} has {@code bases[d]} possible index values; the last dimension varies
 * fastest in the flattened value storage. Besides element access, the class provides pointwise
 * products of several arrays over a shared index space and marginalization (summing out
 * dimensions).
 */
public class MultiArray {
    /** Number of index values of each dimension. */
    private final Integer[] bases;
    /** Stride of each dimension in {@link #values}; the last dimension has stride 1. */
    private final int[] indexMultipliers;
    /** Flattened cell values in row-major order. */
    private final double[] values;

    /**
     * A cursor into this array's values that follows an odometer-style walk over a (possibly
     * larger) index space described by a {@link DimMapping}.
     *
     * <p>The walker enumerates the larger index space by repeatedly incrementing one digit and
     * resetting every digit to its right to zero; {@link #updateForIncrement(int)} moves the cursor
     * by an amount precomputed for that step, so the flat index never has to be recomputed.
     */
    public class TrackingPointer {
        private int curIndex;
        /** changeAmounts[d]: cursor delta when digit d increments and all later digits reset. */
        private final int[] changeAmounts;

        private TrackingPointer(DimMapping dimMapping) {
            this.curIndex = 0;
            int[] indicesInSub = dimMapping.indicesInSub();
            this.changeAmounts = new int[indicesInSub.length];
            // Accumulated (negative) offset that undoes the roll-over of every mapped dimension to
            // the right of the current one: those dimensions drop from base-1 back to 0.
            int resetOffset = 0;
            for (int superDim = indicesInSub.length - 1; superDim >= 0; superDim--) {
                int subDim = indicesInSub[superDim];
                if (subDim == -1) {
                    // Dimension absent from this array: only the roll-over of later dims moves us.
                    this.changeAmounts[superDim] = resetOffset;
                } else {
                    int stride = MultiArray.this.indexMultipliers[subDim];
                    this.changeAmounts[superDim] = stride + resetOffset;
                    resetOffset -= (MultiArray.this.bases[subDim].intValue() - 1) * stride;
                }
            }
        }

        /**
         * Moves the cursor to account for digit {@code i} of the walked index space having been
         * incremented (with every digit to its right reset to zero).
         *
         * @param i the incremented digit; digits at or beyond the mapping's length are ignored
         */
        public void updateForIncrement(int i) {
            if (i < this.changeAmounts.length) {
                this.curIndex += this.changeAmounts[i];
            }
        }

        /** Returns the value of the cell the cursor currently points at. */
        public double getValue() {
            return MultiArray.this.values[this.curIndex];
        }

        /** Adds {@code d} to the value of the cell the cursor currently points at. */
        public void addToValue(double d) {
            MultiArray.this.values[this.curIndex] += d;
        }
    }

    /**
     * Sums out the listed dimensions of {@code multiArray}.
     *
     * @param multiArray the array to marginalize; not modified
     * @param list indices of the dimensions to sum out
     * @return a new array over the remaining dimensions, in their original order
     */
    public static MultiArray marginalize(MultiArray multiArray, List<Integer> list) {
        boolean[] summedOut = new boolean[multiArray.bases.length];
        for (int dim : list) {
            summedOut[dim] = true;
        }
        return marginalize(multiArray, summedOut);
    }

    /**
     * Sums out every dimension {@code d} of {@code multiArray} for which {@code zArr[d]} is true.
     *
     * @param multiArray the array to marginalize; not modified
     * @param zArr one flag per dimension; {@code true} means "sum this dimension out"
     * @return a new array over the remaining dimensions, in their original order
     */
    public static MultiArray marginalize(MultiArray multiArray, boolean[] zArr) {
        ArrayList<Integer> keptBases = new ArrayList<>();
        ArrayList<Integer> keptDims = new ArrayList<>();
        for (int d = 0; d < multiArray.bases.length; d++) {
            if (!zArr[d]) {
                keptBases.add(multiArray.bases[d]);
                keptDims.add(d);
            }
        }
        MultiArray result = new MultiArray(keptBases, 0.0d);
        // Walk the source in row-major order; the pointer tracks the matching result cell.
        TrackingPointer target = result.trackingPointer(new DimMapping(keptDims));
        int[] digits = new int[multiArray.bases.length];
        for (double value : multiArray.values) {
            target.addToValue(value);
            advance(digits, multiArray.bases, target);
        }
        return result;
    }

    /**
     * Computes the pointwise product of two arrays.
     *
     * <p>The result's dimensions are all of {@code multiArray}'s dimensions followed by the
     * dimensions of {@code multiArray2} that {@code dimMapping.indicesInSuper()} maps to positions
     * at or beyond {@code multiArray}'s rank. Each of those is placed at exactly the position the
     * mapping assigns it. Positions below that rank are treated as shared with {@code multiArray}.
     *
     * @param multiArray first factor; its dimensions become the leading result dimensions
     * @param multiArray2 second factor
     * @param dimMapping maps each dimension of {@code multiArray2} to a result dimension
     * @return a new array holding the products
     * @throws IllegalArgumentException if the new dimensions leave a result position unassigned
     */
    public static MultiArray pointwiseProduct(MultiArray multiArray, MultiArray multiArray2, DimMapping dimMapping) {
        int firstRank = multiArray.bases.length;
        int[] indicesInSuper = dimMapping.indicesInSuper();
        int resultRank = firstRank;
        for (int j = 0; j < multiArray2.bases.length; j++) {
            resultRank = Math.max(resultRank, indicesInSuper[j] + 1);
        }
        Integer[] resultBases = Arrays.copyOf(multiArray.bases, resultRank);
        for (int j = 0; j < multiArray2.bases.length; j++) {
            if (indicesInSuper[j] >= firstRank) {
                // Place by mapped position (not append order) so bases agree with the pointer.
                resultBases[indicesInSuper[j]] = multiArray2.bases[j];
            }
        }
        for (int d = firstRank; d < resultRank; d++) {
            if (resultBases[d] == null) {
                throw new IllegalArgumentException("dimMapping leaves result dimension " + d + " unassigned");
            }
        }
        MultiArray result = new MultiArray(Arrays.asList(resultBases));
        TrackingPointer first = multiArray.trackingPointer(DimMapping.identityDimMapping(firstRank));
        TrackingPointer second = multiArray2.trackingPointer(dimMapping);
        int[] digits = new int[result.bases.length];
        for (int cell = 0; cell < result.values.length; cell++) {
            result.values[cell] = first.getValue() * second.getValue();
            advance(digits, result.bases, first, second);
        }
        return result;
    }

    /**
     * Computes the pointwise product of any number of arrays over the index space {@code list2}.
     *
     * @param list the factors; if empty, every cell of the result is 1.0
     * @param list2 the number of index values of each result dimension
     * @param dimMappingArr one mapping per factor, relating that factor's dimensions to the
     *     result's
     * @return a new array holding the products
     */
    public static MultiArray pointwiseProduct(List<MultiArray> list, List<Integer> list2, DimMapping[] dimMappingArr) {
        MultiArray result = new MultiArray(list2);
        if (list.isEmpty()) {
            Arrays.fill(result.values, 1.0d);
            return result;
        }
        TrackingPointer[] pointers = new TrackingPointer[list.size()];
        int k = 0;
        for (MultiArray factor : list) {
            pointers[k] = factor.trackingPointer(dimMappingArr[k]);
            k++;
        }
        int[] digits = new int[result.bases.length];
        for (int cell = 0; cell < result.values.length; cell++) {
            double product = pointers[0].getValue();
            for (int f = 1; f < pointers.length; f++) {
                product *= pointers[f].getValue();
            }
            result.values[cell] = product;
            // Exactly one odometer step per cell (the previous inline loop lacked a break).
            advance(digits, result.bases, pointers);
        }
        return result;
    }

    /**
     * Creates a zero-filled array with the given dimension sizes.
     *
     * @param list the number of index values of each dimension (must not contain null)
     * @throws ArithmeticException if the total number of cells overflows an {@code int}
     */
    public MultiArray(List<Integer> list) {
        this.bases = new Integer[list.size()];
        int d = 0;
        for (int base : list) {
            this.bases[d++] = base;
        }
        this.indexMultipliers = new int[this.bases.length];
        int stride = 1;
        for (int dim = this.bases.length - 1; dim >= 0; dim--) {
            this.indexMultipliers[dim] = stride;
            // multiplyExact: a silently wrapped size would make every stride wrong.
            stride = Math.multiplyExact(stride, this.bases[dim].intValue());
        }
        this.values = new double[stride];
    }

    /**
     * Creates an array with the given dimension sizes and every cell set to {@code d}.
     *
     * @param list the number of index values of each dimension
     * @param d the initial value of every cell
     */
    public MultiArray(List<Integer> list, double d) {
        this(list);
        Arrays.fill(this.values, d);
    }

    /** Creates an independent deep copy of {@code multiArray}. */
    public MultiArray(MultiArray multiArray) {
        this.bases = multiArray.bases.clone();
        this.indexMultipliers = multiArray.indexMultipliers.clone();
        this.values = multiArray.values.clone();
    }

    /** Returns an independent deep copy of this array. */
    public MultiArray copy() {
        return new MultiArray(this);
    }

    /**
     * Returns a {@link TupleIterator} built from one list {@code [0, 1, ..., base-1]} per
     * dimension.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public TupleIterator allArgumentIterator() {
        // Raw list kept deliberately: TupleIterator's exact generic constructor signature is not
        // visible here, and a raw argument is accepted by any parameterization.
        LinkedList ranges = new LinkedList();
        for (Integer boxedBase : this.bases) {
            int base = boxedBase.intValue();
            List<Integer> range = new ArrayList<>(base);
            for (int v = 0; v < base; v++) {
                range.add(v);
            }
            ranges.add(range);
        }
        return new TupleIterator(ranges);
    }

    /** Returns the value at the given index (one component per dimension). */
    public double getValue(List<Integer> list) {
        return this.values[indexIntoValues(list)];
    }

    /** Returns the value at the given index (one component per dimension). */
    public double getValue(int[] iArr) {
        return this.values[indexIntoValues(iArr)];
    }

    /** Sets the value at the given index (one component per dimension). */
    public void setValue(List<Integer> list, double d) {
        this.values[indexIntoValues(list)] = d;
    }

    /** Sets the value at the given index (one component per dimension). */
    public void setValue(int[] iArr, double d) {
        this.values[indexIntoValues(iArr)] = d;
    }

    /**
     * Copies values, in row-major order, from {@code dArr} into this array. Copying stops at the
     * shorter of the two lengths.
     */
    public void setValues(double[] dArr) {
        System.arraycopy(dArr, 0, this.values, 0, Math.min(this.values.length, dArr.length));
    }

    /** Returns the total number of cells. */
    public int size() {
        return this.values.length;
    }

    /** Replaces every value {@code v} with {@code Math.pow(v, d)}. */
    public void pow(double d) {
        for (int i = 0; i < this.values.length; i++) {
            this.values[i] = Math.pow(this.values[i], d);
        }
    }

    /**
     * Divides every value by the sum of all values. If that sum is zero, the result contains NaN
     * or infinite entries.
     */
    public void normalize() {
        double total = 0.0d;
        for (double v : this.values) {
            total += v;
        }
        for (int i = 0; i < this.values.length; i++) {
            this.values[i] /= total;
        }
    }

    /** Returns true if every value is exactly equal ({@code ==}) to {@code d}. */
    public boolean isConstant(double d) {
        for (double v : this.values) {
            if (v != d) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns true if {@code multiArray} has the same dimension sizes as this array and every
     * pair of corresponding values passes {@code Util.withinTol}.
     */
    public boolean withinTol(MultiArray multiArray) {
        if (!Arrays.equals(this.bases, multiArray.bases)) {
            return false;
        }
        for (int i = 0; i < this.values.length; i++) {
            if (!Util.withinTol(this.values[i], multiArray.values[i])) {
                return false;
            }
        }
        return true;
    }

    /** Prints one line per cell (numeric index labels, then the value) to {@code printStream}. */
    public void print(PrintStream printStream) {
        print(printStream, null);
    }

    /**
     * Prints one tab-separated line per cell, in row-major order: a label for each index
     * component followed by the cell's value.
     *
     * @param printStream destination
     * @param list per-dimension label lists indexed by index value, or null to print numbers
     */
    public void print(PrintStream printStream, List<List<String>> list) {
        int[] digits = new int[this.bases.length];
        String[] labels = new String[this.bases.length];
        for (int d = 0; d < this.bases.length; d++) {
            labels[d] = label(list, d, 0);
        }
        for (int cell = 0; cell < this.values.length; cell++) {
            printEntry(printStream, labels, cell);
            int changed = advance(digits, this.bases);
            if (changed >= 0) {
                // The incremented digit and every reset digit to its right need new labels.
                for (int d = changed; d < labels.length; d++) {
                    labels[d] = label(list, d, digits[d]);
                }
            }
        }
    }

    /** Returns the label for index value {@code value} of dimension {@code dim}. */
    private static String label(List<List<String>> names, int dim, int value) {
        return names == null ? String.valueOf(value) : names.get(dim).get(value);
    }

    private void printEntry(PrintStream printStream, String[] strArr, int i) {
        for (String str : strArr) {
            printStream.print(str);
            printStream.print('\t');
        }
        printStream.println(this.values[i]);
    }

    /**
     * Advances {@code digits} to the next index in row-major order over {@code bases}. It
     * increments the rightmost digit that is below its maximum, zeroes every digit to its right,
     * and notifies each pointer of the increment.
     *
     * @return the incremented digit, or -1 if every digit was already at its maximum (nothing is
     *     changed and no pointer is notified in that case)
     */
    private static int advance(int[] digits, Integer[] bases, TrackingPointer... pointers) {
        for (int d = digits.length - 1; d >= 0; d--) {
            if (digits[d] < bases[d].intValue() - 1) {
                digits[d]++;
                Arrays.fill(digits, d + 1, digits.length, 0);
                for (TrackingPointer pointer : pointers) {
                    pointer.updateForIncrement(d);
                }
                return d;
            }
        }
        return -1;
    }

    /** Flat offset of the given index; components are not range-checked against the bases. */
    private int indexIntoValues(List<Integer> list) {
        int offset = 0;
        int dim = 0;
        for (int component : list) {
            offset += this.indexMultipliers[dim++] * component;
        }
        return offset;
    }

    /** Flat offset of the given index; components are not range-checked against the bases. */
    private int indexIntoValues(int[] iArr) {
        int offset = 0;
        for (int dim = 0; dim < iArr.length; dim++) {
            offset += this.indexMultipliers[dim] * iArr[dim];
        }
        return offset;
    }

    private TrackingPointer trackingPointer(DimMapping dimMapping) {
        return new TrackingPointer(dimMapping);
    }

    /**
     * Self-test harness. It exercises indexing, the binary and n-ary products and
     * marginalization, and prints a pass/fail line for each check to standard output.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        MultiArray grid = new MultiArray(Arrays.asList(2, 4, 6));
        for (int a = 0; a < 2; a++) {
            for (int b = 0; b < 4; b++) {
                for (int c = 0; c < 6; c++) {
                    grid.setValue(new int[] {a, b, c}, (a * 1000) + (b * 100) + (c * 10));
                }
            }
        }
        for (int a = 0; a < 2; a++) {
            for (int b = 0; b < 4; b++) {
                for (int c = 0; c < 6; c++) {
                    int[] idx = {a, b, c};
                    if (grid.getValue(idx) != (a * 1000) + (b * 100) + (c * 10)) {
                        System.out.println("invalid value at: ");
                        // Arrays.toString: printing the int[] itself only shows its identity hash.
                        System.out.println(Arrays.toString(idx));
                    }
                }
            }
        }
        System.out.println("Indexing works");

        MultiArray f = new MultiArray(Arrays.asList(3, 2));
        MultiArray g = new MultiArray(Arrays.asList(2, 2));
        int[] i00 = {0, 0};
        int[] i01 = {0, 1};
        int[] i10 = {1, 0};
        int[] i11 = {1, 1};
        int[] i20 = {2, 0};
        int[] i21 = {2, 1};
        f.setValue(i00, 0.5d);
        f.setValue(i01, 0.8d);
        f.setValue(i10, 0.1d);
        f.setValue(i11, 0.0d);
        f.setValue(i20, 0.3d);
        f.setValue(i21, 0.9d);
        g.setValue(i00, 0.5d);
        g.setValue(i01, 0.7d);
        g.setValue(i10, 0.1d);
        g.setValue(i11, 0.2d);
        List<Integer> firstOnly = Collections.singletonList(0);
        List<Integer> secondOnly = Collections.singletonList(1);
        MultiArray fg = pointwiseProduct(f, g, new DimMapping(new int[] {1, 2}));
        int[] j000 = {0, 0, 0};
        int[] j001 = {0, 0, 1};
        int[] j010 = {0, 1, 0};
        int[] j011 = {0, 1, 1};
        int[] j100 = {1, 0, 0};
        int[] j101 = {1, 0, 1};
        int[] j110 = {1, 1, 0};
        int[] j111 = {1, 1, 1};
        int[] j200 = {2, 0, 0};
        int[] j201 = {2, 0, 1};
        int[] j210 = {2, 1, 0};
        int[] j211 = {2, 1, 1};
        if (fg.getValue(j000) == 0.25d && fg.getValue(j001) == 0.35d && fg.getValue(j010) == 0.08000000000000002d && fg.getValue(j011) == 0.16000000000000003d && fg.getValue(j100) == 0.05d && fg.getValue(j101) == 0.06999999999999999d && fg.getValue(j110) == 0.0d && fg.getValue(j111) == 0.0d && fg.getValue(j200) == 0.15d && fg.getValue(j201) == 0.21d && fg.getValue(j210) == 0.09000000000000001d && fg.getValue(j211) == 0.18000000000000002d) {
            System.out.println("Pointwise Product works");
        } else {
            System.out.println("pointwise product failed!");
        }

        List<Integer> dims322 = Arrays.asList(3, 2, 2);
        DimMapping[] mappings = {new DimMapping(new int[] {0, 1}), new DimMapping(new int[] {1, 2})};
        List<MultiArray> factors = new ArrayList<>(2);
        factors.add(f);
        factors.add(g);
        if (Arrays.equals(fg.values, pointwiseProduct(factors, dims322, mappings).values)) {
            System.out.println("n-ary and binary products same");
        } else {
            System.out.println("n-ary and binary products different!");
        }

        MultiArray m4 = new MultiArray(dims322);
        m4.setValue(j000, 0.25d);
        m4.setValue(j001, 0.35d);
        m4.setValue(j010, 0.08d);
        m4.setValue(j011, 0.16d);
        m4.setValue(j100, 0.05d);
        m4.setValue(j101, 0.07d);
        m4.setValue(j110, 0.0d);
        m4.setValue(j111, 0.0d);
        m4.setValue(j200, 0.15d);
        m4.setValue(j201, 0.21d);
        m4.setValue(j210, 0.09d);
        m4.setValue(j211, 0.18d);
        MultiArray m5 = new MultiArray(Arrays.asList(2, 2, 2));
        m5.setValue(j000, 0.56d);
        m5.setValue(j001, 0.23d);
        m5.setValue(j010, 0.35d);
        m5.setValue(j011, 0.24d);
        m5.setValue(j100, 0.0d);
        m5.setValue(j101, 0.87d);
        m5.setValue(j110, 0.0d);
        m5.setValue(j111, 0.9d);

        MultiArray shared = pointwiseProduct(m4, m5, new DimMapping(new int[] {3, 2, 1}));
        if (shared.getValue(new int[] {0, 0, 0, 0}) == 0.14d && shared.getValue(new int[] {2, 1, 1, 0}) == 0.043199999999999995d) {
            System.out.println("Product works");
        } else {
            System.out.println("Product Failed!");
        }

        MultiArray disjoint = pointwiseProduct(m4, m5, new DimMapping(new int[] {3, 4, 5}));
        if (disjoint.getValue(new int[] {2, 0, 1, 1, 0, 1}) == 0.1827d && disjoint.getValue(new int[] {0, 1, 0, 1, 1, 1}) == 0.07200000000000001d) {
            System.out.println("null product works");
        } else {
            System.out.println("null product failed");
        }

        MultiArray overMiddle = marginalize(m4, Arrays.asList(0, 2));
        System.out.print("Should be around 1.08: ");
        System.out.println(overMiddle.getValue(firstOnly));
        System.out.print("Should be around 0.51: ");
        System.out.println(overMiddle.getValue(secondOnly));

        MultiArray withoutMiddle = marginalize(m4, Collections.singletonList(1));
        if (withoutMiddle.getValue(i00) == 0.33d && withoutMiddle.getValue(i01) == 0.51d && withoutMiddle.getValue(i10) == 0.05d && withoutMiddle.getValue(i11) == 0.07d && withoutMiddle.getValue(i20) == 0.24d && withoutMiddle.getValue(i21) == 0.39d) {
            System.out.println("Marginalization works");
        } else {
            System.out.println("Marginalization Failed!");
        }
    }
}
