package tberg.murphy.floatsequence;

import edu.berkeley.cs.nlp.ocular.preprocessing.Cropper;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import tberg.murphy.arrays.a;
import tberg.murphy.math.m;
import tberg.murphy.threading.BetterThreader;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward.class */
public class ForwardBackward {
    /** Scaling constant: e^20, narrowed to {@code float}. */
    public static final float SCALE = (float) Math.exp(20.0d);
    /** Reciprocal of {@link #SCALE}. */
    public static final float INVSCALE = 1.0f / SCALE;
    /**
     * Natural logarithm of {@link #SCALE} (approximately 20). The scaling-based marginals multiply a
     * per-sequence scale value by this constant when forming a sequence log marginal.
     */
    public static final float LOG_SCALE = (float) Math.log(SCALE);

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$Lattice.class */
    /**
     * Description of a collection of sequences over which the marginals classes in this file operate.
     *
     * <p>Throughout the visible callers the first {@code int} argument is a sequence index, the second
     * a position within that sequence, and the third a state index at that position.
     */
    public interface Lattice {
        /** @return the number of sequences; callers size their per-sequence arrays with this value */
        int numSequences();

        /**
         * @param i sequence index
         * @return the number of positions in sequence {@code i}
         */
        int sequenceLength(int i);

        /**
         * @param i sequence index
         * @param i2 position within the sequence
         * @return the number of states at that position
         */
        int numStates(int i, int i2);

        /**
         * @param i sequence index
         * @param i2 position within the sequence
         * @param i3 state index
         * @return the log-domain node potential; callers subtract it from log-domain scores
         */
        float nodeLogPotential(int i, int i2, int i3);

        /**
         * Log-domain potentials of the edges reported by {@link #allowedEdges}; callers index the two
         * arrays in parallel.
         *
         * @param i sequence index
         * @param i2 position within the sequence
         * @param i3 state index
         * @param z edge-direction flag; callers asking for forward edges pass {@code false}
         *     (the meaning of {@code true} is not shown here; presumably backward edges, TODO confirm)
         */
        float[] allowedEdgesLogPotentials(int i, int i2, int i3, boolean z);

        /**
         * @param i sequence index
         * @param i2 position within the sequence
         * @param i3 state index
         * @return the node potential; callers divide by it and compare it against a threshold
         */
        float nodePotential(int i, int i2, int i3);

        /**
         * Presumably the non-log counterpart of {@link #allowedEdgesLogPotentials}; it has no caller
         * in the code visible here, TODO confirm.
         */
        float[] allowedEdgesPotentials(int i, int i2, int i3, boolean z);

        /**
         * State indices reachable by an edge from the given state. With {@code z == false} callers
         * treat the returned values as state indices at position {@code i2 + 1}.
         *
         * @param i sequence index
         * @param i2 position within the sequence
         * @param i3 state index
         * @param z edge-direction flag; callers asking for forward edges pass {@code false}
         */
        int[] allowedEdges(int i, int i2, int i3, boolean z);
    }

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NodeMarginals.class */
    /** Read-only view of per-node values and per-sequence log marginals for a set of sequences. */
    public interface NodeMarginals {
        /**
         * @param i sequence index
         * @param i2 position within the sequence
         * @return the per-state values stored for that position, indexed by state
         */
        float[] nodeCondProbs(int i, int i2);

        /**
         * @param i sequence index
         * @return the log marginal stored for sequence {@code i}
         */
        float sequenceLogMarginalProb(int i);

        /** @return an aggregate of the per-sequence log marginals (the implementations here use {@code a.sum}) */
        float logMarginalProb();

        /** @return the number of sequences */
        int numSequences();

        /**
         * @param i sequence index
         * @return the number of positions in sequence {@code i}
         */
        int sequenceLength(int i);

        /**
         * @param i sequence index
         * @return the number of states per position for sequence {@code i}; the implementations in
         *     this file return the state projector's range size
         */
        int numStates(int i);

        /**
         * @return an iterator over entries shaped {@code (((sequence, position), state), value)}, where
         *     the value is taken from {@link #nodeCondProbs(int, int)}
         */
        Iterator<Pair<Pair<Pair<Integer, Integer>, Integer>, Float>> getNodeMarginalsIterator();

        /**
         * @return a memory estimate; the implementations in this file return
         *     {@code 8 * (number of stored elements) / 1e9}
         */
        double estimateMemoryUsage();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NodeMarginalsLogSpace.class */
    /**
     * {@link NodeMarginals} implementation fed with log-domain score tables.
     *
     * <p>Per-sequence results are filled in by
     * {@link #incrementExpectedCounts(float[][], float[][], int, boolean)}; accessors for a sequence
     * that has not been filled in yet will fail on the missing (null) entry.
     */
    public static class NodeMarginalsLogSpace implements NodeMarginals {
        Lattice lattice;
        /** Indexed {@code [sequence][position][projected state]}; a sequence's slot is null until computed. */
        float[][][] nodeCondProbs;
        float[] sequenceLogMarginalProbs;
        StationaryStateProjector stateProjector;

        /**
         * Walks every {@code (sequence, position, state)} triple in order, state index varying fastest.
         * Assumes every sequence has at least one position and at least one state.
         */
        private class NodeMarginalsIterator implements Iterator<Pair<Pair<Pair<Integer, Integer>, Integer>, Float>> {
            int d;
            int t;
            int s;
            /** Values for the current {@code (d, t)}; null until the first call to {@link #next()}. */
            float[] nodeCondProbs;

            private NodeMarginalsIterator() {
                this.nodeCondProbs = null;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                if (this.nodeCondProbs == null) {
                    // Nothing has been returned yet. The cursor fields still hold their defaults, so the
                    // "cursor is on the last element" test below would wrongly report exhaustion when
                    // the whole structure holds exactly one element.
                    return NodeMarginalsLogSpace.this.numSequences() > 0;
                }
                boolean onLastSequence = this.d == NodeMarginalsLogSpace.this.numSequences() - 1;
                boolean onLastPosition = this.t == NodeMarginalsLogSpace.this.sequenceLength(this.d) - 1;
                boolean onLastState = this.s == NodeMarginalsLogSpace.this.numStates(this.d) - 1;
                return !(onLastSequence && onLastPosition && onLastState);
            }

            @Override // java.util.Iterator
            public Pair<Pair<Pair<Integer, Integer>, Integer>, Float> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                if (this.nodeCondProbs == null) {
                    // First call: position the cursor on (0, 0, 0).
                    this.d = 0;
                    this.t = 0;
                    this.s = 0;
                    this.nodeCondProbs = NodeMarginalsLogSpace.this.nodeCondProbs(0, 0);
                } else if (this.s == NodeMarginalsLogSpace.this.numStates(this.d) - 1) {
                    // Last state of this position: move to the next position, or to the next sequence.
                    this.s = 0;
                    if (this.t == NodeMarginalsLogSpace.this.sequenceLength(this.d) - 1) {
                        this.t = 0;
                        this.d++;
                    } else {
                        this.t++;
                    }
                    this.nodeCondProbs = NodeMarginalsLogSpace.this.nodeCondProbs(this.d, this.t);
                } else {
                    this.s++;
                }
                return Pair.makePair(Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.t)), Integer.valueOf(this.s)), Float.valueOf(this.nodeCondProbs[this.s]));
            }

            /** Removal is not supported. */
            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        }

        /**
         * @param lattice source of sequence, position and state counts and of node log potentials
         * @param stationaryStateProjector maps lattice states to the indices used in the stored arrays
         */
        public NodeMarginalsLogSpace(Lattice lattice, StationaryStateProjector stationaryStateProjector) {
            this.lattice = lattice;
            this.stateProjector = stationaryStateProjector;
            this.sequenceLogMarginalProbs = new float[lattice.numSequences()];
            // Only the outer dimension is allocated here; each sequence's table is created on demand.
            this.nodeCondProbs = new float[lattice.numSequences()][][];
        }

        /**
         * Computes and stores the log marginal and the per-position, per-projected-state values for
         * one sequence, replacing anything previously stored for it.
         *
         * @param fArr log-domain scores indexed {@code [position][state]}
         *     (presumably forward scores, TODO confirm against the caller)
         * @param fArr2 log-domain scores indexed {@code [position][state]}; the position-0 row is
         *     combined into the sequence log marginal (presumably backward scores, TODO confirm)
         * @param i sequence index
         * @param z when true, values are combined with {@code max}; when false, with log-add / sum
         */
        public void incrementExpectedCounts(float[][] fArr, float[][] fArr2, int i, boolean z) {
            this.sequenceLogMarginalProbs[i] = Float.NEGATIVE_INFINITY;
            int numStartStates = this.lattice.numStates(i, 0);
            for (int s = 0; s < numStartStates; s++) {
                if (z) {
                    this.sequenceLogMarginalProbs[i] = Math.max(this.sequenceLogMarginalProbs[i], fArr2[0][s]);
                } else {
                    this.sequenceLogMarginalProbs[i] = m.logAdd(this.sequenceLogMarginalProbs[i], fArr2[0][s]);
                }
            }

            int length = this.lattice.sequenceLength(i);
            this.nodeCondProbs[i] = new float[length][];
            for (int t = 0; t < length; t++) {
                float[] row = new float[this.stateProjector.rangeSize(i)];
                this.nodeCondProbs[i][t] = row;
                if (z) {
                    // Identity element for max; in sum mode the zero-initialised array is already right.
                    Arrays.fill(row, Float.NEGATIVE_INFINITY);
                }
                int numStates = this.lattice.numStates(i, t);
                for (int s = 0; s < numStates; s++) {
                    int project = this.stateProjector.project(i, t, s);
                    if (fArr[t][s] != Float.NEGATIVE_INFINITY && fArr2[t][s] != Float.NEGATIVE_INFINITY) {
                        // The node log potential is subtracted once before the two tables are added.
                        float logRatio = ((fArr[t][s] - this.lattice.nodeLogPotential(i, t, s)) + fArr2[t][s]) - this.sequenceLogMarginalProbs[i];
                        if (z) {
                            row[project] = Math.max(row[project], (float) Math.exp(logRatio));
                        } else {
                            // Several lattice states may project to the same index, hence the accumulation.
                            row[project] = (float) (row[project] + Math.exp(logRatio));
                        }
                    }
                }
            }
        }

        /**
         * Counts the stored elements of every computed sequence and returns
         * {@code 8 * count / 1e9}.
         *
         * <p>NOTE(review): the factor 8 is kept for compatibility although the elements are 4-byte
         * {@code float}s; confirm whether callers expect this figure before changing it.
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public double estimateMemoryUsage() {
            double d = 0.0d;
            for (int i = 0; i < numSequences(); i++) {
                if (this.nodeCondProbs[i] != null) {
                    for (int i2 = 0; i2 < sequenceLength(i); i2++) {
                        d += this.nodeCondProbs[i][i2].length;
                    }
                }
            }
            return (8.0d * d) / 1.0E9d;
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public float sequenceLogMarginalProb(int i) {
            return this.sequenceLogMarginalProbs[i];
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public float logMarginalProb() {
            return a.sum(this.sequenceLogMarginalProbs);
        }

        /** Returns the stored array itself, not a copy. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public float[] nodeCondProbs(int i, int i2) {
            return this.nodeCondProbs[i][i2];
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public int numSequences() {
            return this.lattice.numSequences();
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public int sequenceLength(int i) {
            return this.lattice.sequenceLength(i);
        }

        /** Returns the projector's range size for sequence {@code i}, i.e. the length of each stored row. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public int numStates(int i) {
            return this.stateProjector.rangeSize(i);
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public Iterator<Pair<Pair<Pair<Integer, Integer>, Integer>, Float>> getNodeMarginalsIterator() {
            return new NodeMarginalsIterator();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NodeMarginalsScaling.class */
    /**
     * {@link NodeMarginals} implementation fed with scaled (non-log) score tables and per-position
     * scale values.
     *
     * <p>Per-sequence results are filled in by
     * {@link #incrementExpectedCounts(float[][], float[], float[][], float[], int, boolean)};
     * accessors for a sequence that has not been filled in yet will fail on the missing (null) entry.
     */
    public static class NodeMarginalsScaling implements NodeMarginals {
        Lattice lattice;
        /** Indexed {@code [sequence][position][projected state]}; a sequence's slot is null until computed. */
        float[][][] nodeCondProbs;
        float[] sequenceLogMarginalProbs;
        StationaryStateProjector stateProjector;

        /**
         * Walks every {@code (sequence, position, state)} triple in order, state index varying fastest.
         * Assumes every sequence has at least one position and at least one state.
         */
        private class NodeMarginalsIterator implements Iterator<Pair<Pair<Pair<Integer, Integer>, Integer>, Float>> {
            int d;
            int t;
            int s;
            /** Values for the current {@code (d, t)}; null until the first call to {@link #next()}. */
            float[] nodeCondProbs;

            private NodeMarginalsIterator() {
                this.nodeCondProbs = null;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                if (this.nodeCondProbs == null) {
                    // Nothing has been returned yet. The cursor fields still hold their defaults, so the
                    // "cursor is on the last element" test below would wrongly report exhaustion when
                    // the whole structure holds exactly one element.
                    return NodeMarginalsScaling.this.numSequences() > 0;
                }
                boolean onLastSequence = this.d == NodeMarginalsScaling.this.numSequences() - 1;
                boolean onLastPosition = this.t == NodeMarginalsScaling.this.sequenceLength(this.d) - 1;
                boolean onLastState = this.s == NodeMarginalsScaling.this.numStates(this.d) - 1;
                return !(onLastSequence && onLastPosition && onLastState);
            }

            @Override // java.util.Iterator
            public Pair<Pair<Pair<Integer, Integer>, Integer>, Float> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                if (this.nodeCondProbs == null) {
                    // First call: position the cursor on (0, 0, 0).
                    this.d = 0;
                    this.t = 0;
                    this.s = 0;
                    this.nodeCondProbs = NodeMarginalsScaling.this.nodeCondProbs(0, 0);
                } else if (this.s == NodeMarginalsScaling.this.numStates(this.d) - 1) {
                    // Last state of this position: move to the next position, or to the next sequence.
                    this.s = 0;
                    if (this.t == NodeMarginalsScaling.this.sequenceLength(this.d) - 1) {
                        this.t = 0;
                        this.d++;
                    } else {
                        this.t++;
                    }
                    this.nodeCondProbs = NodeMarginalsScaling.this.nodeCondProbs(this.d, this.t);
                } else {
                    this.s++;
                }
                return Pair.makePair(Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.t)), Integer.valueOf(this.s)), Float.valueOf(this.nodeCondProbs[this.s]));
            }

            /** Removal is not supported. */
            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        }

        /**
         * @param lattice source of sequence, position and state counts and of node potentials
         * @param stationaryStateProjector maps lattice states to the indices used in the stored arrays
         */
        public NodeMarginalsScaling(Lattice lattice, StationaryStateProjector stationaryStateProjector) {
            this.lattice = lattice;
            this.stateProjector = stationaryStateProjector;
            this.sequenceLogMarginalProbs = new float[lattice.numSequences()];
            // Only the outer dimension is allocated here; each sequence's table is created on demand.
            this.nodeCondProbs = new float[lattice.numSequences()][][];
        }

        /**
         * Computes and stores the log marginal and the per-position, per-projected-state values for
         * one sequence, replacing anything previously stored for it.
         *
         * <p>NOTE(review): the comparisons against {@code Cropper.VERT_GROW_RATIO} look like a
         * decompiler substitution for a plain zero literal; the constant's value is not visible here,
         * so the expressions are kept unchanged. Confirm against the original source.
         *
         * @param fArr scaled scores indexed {@code [position][state]}
         *     (presumably forward scores, TODO confirm against the caller)
         * @param fArr2 per-position scale values that accompany {@code fArr}
         * @param fArr3 scaled scores indexed {@code [position][state]}; the position-0 row is combined
         *     into the sequence marginal (presumably backward scores, TODO confirm)
         * @param fArr4 per-position scale values that accompany {@code fArr3}
         * @param i sequence index
         * @param z when true, values are combined with {@code max}; when false, they are summed
         */
        public void incrementExpectedCounts(float[][] fArr, float[] fArr2, float[][] fArr3, float[] fArr4, int i, boolean z) {
            // Combine the position-0 row of fArr3 over every start state whose node potential passes
            // the threshold.
            float startTotal = z ? Float.NEGATIVE_INFINITY : 0.0f;
            float startScale = fArr4[0];
            for (int s = 0; s < this.lattice.numStates(i, 0); s++) {
                if (this.lattice.nodePotential(i, 0, s) > Cropper.VERT_GROW_RATIO) {
                    startTotal = z ? Math.max(startTotal, fArr3[0][s]) : startTotal + fArr3[0][s];
                }
            }
            // The scale value is converted to log units with LOG_SCALE and added to log(startTotal).
            this.sequenceLogMarginalProbs[i] = (startScale * ForwardBackward.LOG_SCALE) + ((float) Math.log(startTotal));

            this.nodeCondProbs[i] = new float[this.lattice.sequenceLength(i)][];
            for (int t = 0; t < this.lattice.sequenceLength(i); t++) {
                float[] row = new float[this.stateProjector.rangeSize(i)];
                this.nodeCondProbs[i][t] = row;
                if (z) {
                    // Identity element for max; in sum mode the zero-initialised array is already right.
                    Arrays.fill(row, Float.NEGATIVE_INFINITY);
                }
                int numStates = this.lattice.numStates(i, t);
                for (int s = 0; s < numStates; s++) {
                    // Factor derived from this position's combined scale relative to the start scale.
                    float scaleFactor = ForwardBackward.getScaleFactor((fArr2[t] + fArr4[t]) - startScale);
                    int project = this.stateProjector.project(i, t, s);
                    if (fArr[t][s] != Cropper.VERT_GROW_RATIO && fArr3[t][s] != Cropper.VERT_GROW_RATIO) {
                        float value = (fArr[t][s] / this.lattice.nodePotential(i, t, s)) * (fArr3[t][s] / startTotal) * scaleFactor;
                        if (z) {
                            row[project] = Math.max(row[project], value);
                        } else {
                            // Several lattice states may project to the same index, hence the accumulation.
                            row[project] = row[project] + value;
                        }
                    }
                }
            }
        }

        /**
         * Counts the stored elements of every computed sequence and returns
         * {@code 8 * count / 1e9}.
         *
         * <p>NOTE(review): the factor 8 is kept for compatibility although the elements are 4-byte
         * {@code float}s; confirm whether callers expect this figure before changing it.
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public double estimateMemoryUsage() {
            double d = 0.0d;
            for (int i = 0; i < numSequences(); i++) {
                if (this.nodeCondProbs[i] != null) {
                    for (int i2 = 0; i2 < sequenceLength(i); i2++) {
                        d += this.nodeCondProbs[i][i2].length;
                    }
                }
            }
            return (8.0d * d) / 1.0E9d;
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public float sequenceLogMarginalProb(int i) {
            return this.sequenceLogMarginalProbs[i];
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public float logMarginalProb() {
            return a.sum(this.sequenceLogMarginalProbs);
        }

        /** Returns the stored array itself, not a copy. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public float[] nodeCondProbs(int i, int i2) {
            return this.nodeCondProbs[i][i2];
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public int numSequences() {
            return this.lattice.numSequences();
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public int sequenceLength(int i) {
            return this.lattice.sequenceLength(i);
        }

        /** Returns the projector's range size for sequence {@code i}, i.e. the length of each stored row. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public int numStates(int i) {
            return this.stateProjector.rangeSize(i);
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NodeMarginals
        public Iterator<Pair<Pair<Pair<Integer, Integer>, Integer>, Float>> getNodeMarginalsIterator() {
            return new NodeMarginalsIterator();
        }
    }

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NonStationaryEdgeMarginals.class */
    /**
     * Read-only view of start-node, end-node and forward-edge values plus per-sequence log marginals
     * for a set of sequences.
     */
    public interface NonStationaryEdgeMarginals {
        /**
         * @param i sequence index
         * @return values for the states at position 0 of sequence {@code i}, indexed by state
         */
        float[] startNodeCondProbs(int i);

        /**
         * @param i sequence index
         * @return values for the states at the last position of sequence {@code i}, indexed by state
         */
        float[] endNodeCondProbs(int i);

        /**
         * @param i sequence index
         * @param i2 position within the sequence
         * @param i3 state index at that position
         * @return the state indices reachable by a forward edge from that state
         */
        int[] allowedForwardEdges(int i, int i2, int i3);

        /**
         * @param i sequence index
         * @param i2 position within the sequence
         * @param i3 state index at that position
         * @return one value per forward edge, parallel to {@link #allowedForwardEdges(int, int, int)}
         */
        float[] allowedForwardEdgesExpectedCounts(int i, int i2, int i3);

        /**
         * @param i sequence index
         * @return the log marginal stored for sequence {@code i}
         */
        float sequenceLogMarginalProb(int i);

        /** @return an aggregate of the per-sequence log marginals (the implementations here use {@code a.sum}) */
        float logMarginalProb();

        /** @return the number of sequences */
        int numSequences();

        /**
         * @param i sequence index
         * @return the number of positions in sequence {@code i}
         */
        int sequenceLength(int i);

        /**
         * @param i sequence index
         * @param i2 position within the sequence
         * @return the number of states at that position
         */
        int numStates(int i, int i2);

        /**
         * @return an iterator over entries shaped
         *     {@code (((sequence, position), (fromState, toState)), value)}
         */
        Iterator<Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Float>> getEdgeMarginalsIterator();

        /** @return an iterator over entries shaped {@code ((sequence, state), value)} for start nodes */
        Iterator<Pair<Pair<Integer, Integer>, Float>> getStartMarginalsIterator();

        /** @return an iterator over entries shaped {@code ((sequence, state), value)} for end nodes */
        Iterator<Pair<Pair<Integer, Integer>, Float>> getEndMarginalsIterator();

        /**
         * @return a memory estimate; the log-space implementation in this file returns
         *     {@code 8 * (number of stored elements) / 1e9}
         */
        double estimateMemoryUsage();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NonStationaryEdgeMarginalsLogSpace.class */
    public static class NonStationaryEdgeMarginalsLogSpace implements NonStationaryEdgeMarginals {
        Lattice lattice;
        float[] sequenceLogMarginalProbs;
        float[][] startNodeCondProbs;
        float[][] endNodeCondProbs;
        float[][][] allAlphas;
        float[][][] allBetas;

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NonStationaryEdgeMarginalsLogSpace$EdgeMarginalsIterator.class */
        private class EdgeMarginalsIterator implements Iterator<Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Float>> {
            int d = 0;
            int t = 0;
            int s1 = 0;
            int s2i = 0;
            float[] edgeCondProbs;

            public EdgeMarginalsIterator() {
                this.edgeCondProbs = null;
                this.edgeCondProbs = NonStationaryEdgeMarginalsLogSpace.this.allowedForwardEdgesExpectedCounts(this.d, this.t, this.s1);
                while (this.d < NonStationaryEdgeMarginalsLogSpace.this.numSequences() && this.edgeCondProbs.length == 0) {
                    advance();
                }
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.d < NonStationaryEdgeMarginalsLogSpace.this.numSequences();
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Float> next() {
                Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Float> makePair = Pair.makePair(Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.t)), Pair.makePair(Integer.valueOf(this.s1), Integer.valueOf(NonStationaryEdgeMarginalsLogSpace.this.allowedForwardEdges(this.d, this.t, this.s1)[this.s2i]))), Float.valueOf(this.edgeCondProbs[this.s2i]));
                advance();
                while (this.d < NonStationaryEdgeMarginalsLogSpace.this.numSequences() && this.edgeCondProbs.length == 0) {
                    advance();
                }
                return makePair;
            }

            private void advance() {
                if (this.s2i < NonStationaryEdgeMarginalsLogSpace.this.allowedForwardEdges(this.d, this.t, this.s1).length - 1) {
                    this.s2i++;
                    return;
                }
                this.s2i = 0;
                if (this.s1 >= NonStationaryEdgeMarginalsLogSpace.this.numStates(this.d, this.t) - 1) {
                    this.s1 = 0;
                    if (this.t >= NonStationaryEdgeMarginalsLogSpace.this.sequenceLength(this.d) - 2) {
                        this.t = 0;
                        this.d++;
                    } else {
                        this.t++;
                    }
                } else {
                    this.s1++;
                }
                if (this.d < NonStationaryEdgeMarginalsLogSpace.this.numSequences()) {
                    this.edgeCondProbs = NonStationaryEdgeMarginalsLogSpace.this.allowedForwardEdgesExpectedCounts(this.d, this.t, this.s1);
                }
            }

            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NonStationaryEdgeMarginalsLogSpace$EndMarginalsIterator.class */
        private class EndMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Float>> {
            int d;
            int s;
            float[] endCondProbs;

            private EndMarginalsIterator() {
                this.endCondProbs = null;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return (this.d == NonStationaryEdgeMarginalsLogSpace.this.numSequences() - 1 && this.s == NonStationaryEdgeMarginalsLogSpace.this.numStates(this.d, NonStationaryEdgeMarginalsLogSpace.this.sequenceLength(this.d) - 1) - 1) ? false : true;
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Float> next() {
                if (this.endCondProbs == null) {
                    this.d = 0;
                    this.s = 0;
                    this.endCondProbs = NonStationaryEdgeMarginalsLogSpace.this.endNodeCondProbs(this.d);
                } else if (this.s == NonStationaryEdgeMarginalsLogSpace.this.numStates(this.d, NonStationaryEdgeMarginalsLogSpace.this.sequenceLength(this.d) - 1) - 1) {
                    this.s = 0;
                    this.d++;
                    this.endCondProbs = NonStationaryEdgeMarginalsLogSpace.this.endNodeCondProbs(this.d);
                } else {
                    this.s++;
                }
                return Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)), Float.valueOf(this.endCondProbs[this.s]));
            }

            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NonStationaryEdgeMarginalsLogSpace$StartMarginalsIterator.class */
        private class StartMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Float>> {
            int d;
            int s;
            float[] startCondProbs;

            private StartMarginalsIterator() {
                this.startCondProbs = null;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return (this.d == NonStationaryEdgeMarginalsLogSpace.this.numSequences() - 1 && this.s == NonStationaryEdgeMarginalsLogSpace.this.numStates(this.d, 0) - 1) ? false : true;
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Float> next() {
                if (this.startCondProbs == null) {
                    this.d = 0;
                    this.s = 0;
                    this.startCondProbs = NonStationaryEdgeMarginalsLogSpace.this.startNodeCondProbs(this.d);
                } else if (this.s == NonStationaryEdgeMarginalsLogSpace.this.numStates(this.d, 0) - 1) {
                    this.s = 0;
                    this.d++;
                    this.startCondProbs = NonStationaryEdgeMarginalsLogSpace.this.startNodeCondProbs(this.d);
                } else {
                    this.s++;
                }
                return Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)), Float.valueOf(this.startCondProbs[this.s]));
            }

            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /* JADX WARN: Type inference failed for: r1v12, types: [float[][], float[][][]] */
        /* JADX WARN: Type inference failed for: r1v15, types: [float[][], float[][][]] */
        /* JADX WARN: Type inference failed for: r1v6, types: [float[], float[][]] */
        /* JADX WARN: Type inference failed for: r1v9, types: [float[], float[][]] */
        public NonStationaryEdgeMarginalsLogSpace(Lattice lattice) {
            this.lattice = lattice;
            this.sequenceLogMarginalProbs = new float[lattice.numSequences()];
            this.startNodeCondProbs = new float[lattice.numSequences()];
            this.endNodeCondProbs = new float[lattice.numSequences()];
            this.allAlphas = new float[lattice.numSequences()];
            this.allBetas = new float[lattice.numSequences()];
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public int[] allowedForwardEdges(int i, int i2, int i3) {
            return this.lattice.allowedEdges(i, i2, i3, false);
        }

        public void incrementExpectedCounts(float[][] fArr, float[][] fArr2, int i) {
            this.allAlphas[i] = fArr;
            this.allBetas[i] = fArr2;
            this.sequenceLogMarginalProbs[i] = Float.NEGATIVE_INFINITY;
            for (int i2 = 0; i2 < this.lattice.numStates(i, 0); i2++) {
                this.sequenceLogMarginalProbs[i] = m.logAdd(this.sequenceLogMarginalProbs[i], fArr2[0][i2]);
            }
            this.startNodeCondProbs[i] = new float[this.lattice.numStates(i, 0)];
            for (int i3 = 0; i3 < this.lattice.numStates(i, 0); i3++) {
                this.startNodeCondProbs[i][i3] = (float) Math.exp(fArr2[0][i3] - this.sequenceLogMarginalProbs[i]);
            }
            this.endNodeCondProbs[i] = new float[this.lattice.numStates(i, this.lattice.sequenceLength(i) - 1)];
            for (int i4 = 0; i4 < this.lattice.numStates(i, this.lattice.sequenceLength(i) - 1); i4++) {
                this.endNodeCondProbs[i][i4] = (float) Math.exp(fArr[this.lattice.sequenceLength(i) - 1][i4] - this.sequenceLogMarginalProbs[i]);
            }
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public double estimateMemoryUsage() {
            double d = 0.0d;
            for (int i = 0; i < numSequences(); i++) {
                if (this.allAlphas[i] != null) {
                    for (int i2 = 0; i2 < sequenceLength(i); i2++) {
                        d += this.allAlphas[i][i2].length;
                    }
                }
            }
            for (int i3 = 0; i3 < numSequences(); i3++) {
                if (this.allBetas[i3] != null) {
                    for (int i4 = 0; i4 < sequenceLength(i3); i4++) {
                        d += this.allBetas[i3][i4].length;
                    }
                }
            }
            return (8.0d * d) / 1.0E9d;
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public float[] allowedForwardEdgesExpectedCounts(int i, int i2, int i3) {
            int[] allowedEdges = this.lattice.allowedEdges(i, i2, i3, false);
            float[] fArr = new float[allowedEdges.length];
            float[] allowedEdgesLogPotentials = this.lattice.allowedEdgesLogPotentials(i, i2, i3, false);
            for (int i4 = 0; i4 < allowedEdges.length; i4++) {
                int i5 = allowedEdges[i4];
                fArr[i4] = (float) (fArr[r1] + Math.exp(((this.allAlphas[i][i2][i3] + allowedEdgesLogPotentials[i4]) + this.allBetas[i][i2 + 1][i5]) - this.sequenceLogMarginalProbs[i]));
            }
            return fArr;
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public float[] startNodeCondProbs(int i) {
            return this.startNodeCondProbs[i];
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public float[] endNodeCondProbs(int i) {
            return this.endNodeCondProbs[i];
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public float sequenceLogMarginalProb(int i) {
            return this.sequenceLogMarginalProbs[i];
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public float logMarginalProb() {
            return a.sum(this.sequenceLogMarginalProbs);
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public int numSequences() {
            return this.lattice.numSequences();
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public int sequenceLength(int i) {
            return this.lattice.sequenceLength(i);
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public int numStates(int i, int i2) {
            return this.lattice.numStates(i, i2);
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Float>> getEdgeMarginalsIterator() {
            return new EdgeMarginalsIterator();
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Float>> getStartMarginalsIterator() {
            return new StartMarginalsIterator();
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Float>> getEndMarginalsIterator() {
            return new EndMarginalsIterator();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NonStationaryEdgeMarginalsScaling.class */
    public static class NonStationaryEdgeMarginalsScaling implements NonStationaryEdgeMarginals {
        Lattice lattice;
        float[] sequenceLogMarginalProbs;
        float[] sequenceMarginalProbs;
        float[] sequenceMarginalProbLogScales;
        float[][] startNodeCondProbs;
        float[][] endNodeCondProbs;
        float[][][] allAlphas;
        float[][][] allBetas;
        float[][] allAlphaLogScales;
        float[][] allBetaLogScales;

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NonStationaryEdgeMarginalsScaling$EdgeMarginalsIterator.class */
        private class EdgeMarginalsIterator implements Iterator<Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Float>> {
            int d = 0;
            int t = 0;
            int s1 = 0;
            int s2i = 0;
            float[] edgeCondProbs;

            public EdgeMarginalsIterator() {
                this.edgeCondProbs = null;
                this.edgeCondProbs = NonStationaryEdgeMarginalsScaling.this.allowedForwardEdgesExpectedCounts(this.d, this.t, this.s1);
                while (this.d < NonStationaryEdgeMarginalsScaling.this.numSequences() && this.edgeCondProbs.length == 0) {
                    advance();
                }
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.d < NonStationaryEdgeMarginalsScaling.this.numSequences();
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Float> next() {
                Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Float> makePair = Pair.makePair(Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.t)), Pair.makePair(Integer.valueOf(this.s1), Integer.valueOf(NonStationaryEdgeMarginalsScaling.this.allowedForwardEdges(this.d, this.t, this.s1)[this.s2i]))), Float.valueOf(this.edgeCondProbs[this.s2i]));
                advance();
                while (this.d < NonStationaryEdgeMarginalsScaling.this.numSequences() && this.edgeCondProbs.length == 0) {
                    advance();
                }
                return makePair;
            }

            /**
             * Moves the cursor one step: next edge slot, else next source state, else next position,
             * else next sequence. The cached expected counts are reloaded whenever the source state
             * changes and the cursor is still inside the lattice.
             *
             * <p>Position {@code t} only ranges up to {@code sequenceLength(d) - 2} because the
             * expected-count lookup reads position {@code t + 1}. For the same reason, sequences with
             * fewer than two positions are skipped when rolling over to the next sequence.
             */
            private void advance() {
                final NonStationaryEdgeMarginalsScaling outer = NonStationaryEdgeMarginalsScaling.this;
                if (this.s2i < outer.allowedForwardEdges(this.d, this.t, this.s1).length - 1) {
                    this.s2i++;
                    return;
                }
                this.s2i = 0;
                if (this.s1 < outer.numStates(this.d, this.t) - 1) {
                    this.s1++;
                } else {
                    this.s1 = 0;
                    if (this.t < outer.sequenceLength(this.d) - 2) {
                        this.t++;
                    } else {
                        this.t = 0;
                        this.d++;
                        // A sequence without a second position has no forward edges to report.
                        while (this.d < outer.numSequences() && outer.sequenceLength(this.d) < 2) {
                            this.d++;
                        }
                    }
                }
                if (this.d < outer.numSequences()) {
                    this.edgeCondProbs = outer.allowedForwardEdgesExpectedCounts(this.d, this.t, this.s1);
                }
            }

            /**
             * Removal is not supported: the iterator reads values computed on demand from the
             * enclosing marginals object and has nothing to delete.
             *
             * @throws UnsupportedOperationException always
             */
            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        }

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NonStationaryEdgeMarginalsScaling$EndMarginalsIterator.class */
        private class EndMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Float>> {
            /** Sequence index of the element the next call to {@code next()} returns. */
            int d;
            /** State index, at the last position of sequence {@code d}, of that element. */
            int s;
            /** End-node probabilities of sequence {@code d}; {@code null} until first read. */
            float[] endCondProbs;

            /**
             * Iterates over ((sequence, state), probability) for the last position of every sequence,
             * in sequence order and then state order.
             *
             * <p>The cursor always points at the element to return next, so an untouched iterator is
             * distinguishable from one that has just returned the final element.
             */
            private EndMarginalsIterator() {
                this.d = 0;
                this.s = 0;
                this.endCondProbs = null;
            }

            /** Returns the number of states at the last position of the given sequence. */
            private int endStateCount(int sequence) {
                final NonStationaryEdgeMarginalsScaling outer = NonStationaryEdgeMarginalsScaling.this;
                return outer.numStates(sequence, outer.sequenceLength(sequence) - 1);
            }

            /**
             * Reports whether another element exists, rolling the cursor over to the next sequence
             * that still has a state to report. Calling it repeatedly has no further effect.
             */
            @Override // java.util.Iterator
            public boolean hasNext() {
                final int sequenceCount = NonStationaryEdgeMarginalsScaling.this.numSequences();
                while (this.d < sequenceCount && this.s >= endStateCount(this.d)) {
                    this.d++;
                    this.s = 0;
                    this.endCondProbs = null;
                }
                return this.d < sequenceCount;
            }

            /**
             * Returns the element under the cursor and moves the cursor one state forward.
             *
             * @throws java.util.NoSuchElementException if no element is left
             */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Float> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                if (this.endCondProbs == null) {
                    this.endCondProbs = NonStationaryEdgeMarginalsScaling.this.endNodeCondProbs(this.d);
                }
                final Pair<Pair<Integer, Integer>, Float> current = Pair.makePair(
                        Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)),
                        Float.valueOf(this.endCondProbs[this.s]));
                this.s++;
                return current;
            }

            /**
             * Removal is not supported.
             *
             * @throws UnsupportedOperationException always
             */
            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        }

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$NonStationaryEdgeMarginalsScaling$StartMarginalsIterator.class */
        private class StartMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Float>> {
            /** Sequence index of the element the next call to {@code next()} returns. */
            int d;
            /** State index, at position 0 of sequence {@code d}, of that element. */
            int s;
            /** Start-node probabilities of sequence {@code d}; {@code null} until first read. */
            float[] startCondProbs;

            /**
             * Iterates over ((sequence, state), probability) for position 0 of every sequence, in
             * sequence order and then state order.
             *
             * <p>The cursor always points at the element to return next, so an untouched iterator is
             * distinguishable from one that has just returned the final element.
             */
            private StartMarginalsIterator() {
                this.d = 0;
                this.s = 0;
                this.startCondProbs = null;
            }

            /**
             * Reports whether another element exists, rolling the cursor over to the next sequence
             * that still has a state to report. Calling it repeatedly has no further effect.
             */
            @Override // java.util.Iterator
            public boolean hasNext() {
                final NonStationaryEdgeMarginalsScaling outer = NonStationaryEdgeMarginalsScaling.this;
                final int sequenceCount = outer.numSequences();
                while (this.d < sequenceCount && this.s >= outer.numStates(this.d, 0)) {
                    this.d++;
                    this.s = 0;
                    this.startCondProbs = null;
                }
                return this.d < sequenceCount;
            }

            /**
             * Returns the element under the cursor and moves the cursor one state forward.
             *
             * @throws java.util.NoSuchElementException if no element is left
             */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Float> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                if (this.startCondProbs == null) {
                    this.startCondProbs = NonStationaryEdgeMarginalsScaling.this.startNodeCondProbs(this.d);
                }
                final Pair<Pair<Integer, Integer>, Float> current = Pair.makePair(
                        Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)),
                        Float.valueOf(this.startCondProbs[this.s]));
                this.s++;
                return current;
            }

            /**
             * Removal is not supported.
             *
             * @throws UnsupportedOperationException always
             */
            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        }

        /* JADX WARN: Type inference failed for: r1v12, types: [float[], float[][]] */
        /* JADX WARN: Type inference failed for: r1v15, types: [float[], float[][]] */
        /* JADX WARN: Type inference failed for: r1v18, types: [float[][], float[][][]] */
        /* JADX WARN: Type inference failed for: r1v21, types: [float[], float[][]] */
        /* JADX WARN: Type inference failed for: r1v24, types: [float[], float[][]] */
        public NonStationaryEdgeMarginalsScaling(Lattice lattice) {
            this.lattice = lattice;
            this.sequenceLogMarginalProbs = new float[lattice.numSequences()];
            this.sequenceMarginalProbs = new float[lattice.numSequences()];
            this.sequenceMarginalProbLogScales = new float[lattice.numSequences()];
            this.startNodeCondProbs = new float[lattice.numSequences()];
            this.endNodeCondProbs = new float[lattice.numSequences()];
            this.allAlphas = new float[lattice.numSequences()];
            this.allAlphaLogScales = new float[lattice.numSequences()];
            this.allBetaLogScales = new float[lattice.numSequences()];
        }

        /**
         * Looks up the allowed edges of one state by delegating to the lattice.
         *
         * @param i  sequence index
         * @param i2 position within the sequence
         * @param i3 source state at that position
         * @return whatever {@code lattice.allowedEdges(i, i2, i3, false)} returns
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public int[] allowedForwardEdges(int i, int i2, int i3) {
            // false is the flag value this class passes for every "forward" lookup.
            final boolean directionFlag = false;
            final int[] destinations = this.lattice.allowedEdges(i, i2, i3, directionFlag);
            return destinations;
        }

        /**
         * Stores the scaled forward/backward tables of one sequence and derives that sequence's
         * marginal probability plus its start-node and end-node probabilities.
         *
         * @param fArr  values kept as {@code allAlphas[i]}, indexed [position][state]
         * @param fArr2 values kept as {@code allAlphaLogScales[i]}, indexed by position
         * @param fArr3 values kept as {@code allBetas[i]}, indexed [position][state]
         * @param fArr4 values kept as {@code allBetaLogScales[i]}, indexed by position
         * @param i     sequence index
         */
        public void incrementExpectedCounts(float[][] fArr, float[] fArr2, float[][] fArr3, float[] fArr4, int i) {
            final float[][] alphas = fArr;
            final float[] alphaLogScales = fArr2;
            final float[][] betas = fArr3;
            final float[] betaLogScales = fArr4;

            // Keep references to the caller's tables; edge counts are computed from them on demand.
            this.allAlphas[i] = alphas;
            this.allAlphaLogScales[i] = alphaLogScales;
            this.allBetas[i] = betas;
            this.allBetaLogScales[i] = betaLogScales;

            // Marginal probability: sum of the position-0 betas over the states that pass the
            // node-potential threshold, kept together with the scale recorded for position 0.
            this.sequenceMarginalProbs[i] = 0.0f;
            this.sequenceMarginalProbLogScales[i] = betaLogScales[0];
            int state = 0;
            while (state < this.lattice.numStates(i, 0)) {
                // NOTE(review): the Cropper constant looks like a decompiler substitution for a
                // literal threshold - confirm against the original source.
                if (this.lattice.nodePotential(i, 0, state) > Cropper.VERT_GROW_RATIO) {
                    this.sequenceMarginalProbs[i] += betas[0][state];
                }
                state++;
            }
            final float scaledLogPart = this.sequenceMarginalProbLogScales[i] * ForwardBackward.LOG_SCALE;
            this.sequenceLogMarginalProbs[i] = scaledLogPart + ((float) Math.log(this.sequenceMarginalProbs[i]));

            // Start-node probabilities from the position-0 betas.
            this.startNodeCondProbs[i] = new float[this.lattice.numStates(i, 0)];
            final float startRescale = ForwardBackward.getScaleFactor(betaLogScales[0] - this.sequenceMarginalProbLogScales[i]);
            for (int startState = 0; startState < this.lattice.numStates(i, 0); startState++) {
                final float ratio = betas[0][startState] / this.sequenceMarginalProbs[i];
                this.startNodeCondProbs[i][startState] = ratio * startRescale;
            }

            // End-node probabilities from the alphas of the last position.
            this.endNodeCondProbs[i] = new float[this.lattice.numStates(i, this.lattice.sequenceLength(i) - 1)];
            final float endRescale = ForwardBackward.getScaleFactor(alphaLogScales[this.lattice.sequenceLength(i) - 1] - this.sequenceMarginalProbLogScales[i]);
            for (int endState = 0; endState < this.lattice.numStates(i, this.lattice.sequenceLength(i) - 1); endState++) {
                final float ratio = alphas[this.lattice.sequenceLength(i) - 1][endState] / this.sequenceMarginalProbs[i];
                this.endNodeCondProbs[i][endState] = ratio * endRescale;
            }
        }

        /**
         * Estimates the size of the stored alpha and beta tables.
         *
         * @return the number of stored table entries, multiplied by 8 and divided by 1e9
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public double estimateMemoryUsage() {
            double storedEntries = 0.0d;
            int seq = 0;
            while (seq < numSequences()) {
                if (this.allAlphas[seq] != null) {
                    final float[][] alphaRows = this.allAlphas[seq];
                    for (int pos = 0; pos < sequenceLength(seq); pos++) {
                        storedEntries += alphaRows[pos].length;
                    }
                }
                seq++;
            }
            seq = 0;
            while (seq < numSequences()) {
                if (this.allBetas[seq] != null) {
                    final float[][] betaRows = this.allBetas[seq];
                    for (int pos = 0; pos < sequenceLength(seq); pos++) {
                        storedEntries += betaRows[pos].length;
                    }
                }
                seq++;
            }
            // NOTE(review): the tables hold floats (4 bytes each) but the factor here is 8 -
            // confirm whether the factor is intentional before relying on the absolute number.
            return (8.0d * storedEntries) / 1.0E9d;
        }

        /**
         * Computes, on demand, one value per allowed edge leaving a state.
         *
         * <p>Each value is {@code alpha[i2][i3] * potential * beta[i2 + 1][destination]}, divided by
         * the sequence's marginal probability and multiplied by a scale factor derived from the
         * stored alpha, beta and marginal log scales.
         *
         * @param i  sequence index
         * @param i2 position of the source state; position {@code i2 + 1} is read as well
         * @param i3 source state
         * @return a new array parallel to the lattice's allowed-edge list for that state
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public float[] allowedForwardEdgesExpectedCounts(int i, int i2, int i3) {
            final int[] destinations = this.lattice.allowedEdges(i, i2, i3, false);
            final float[] counts = new float[destinations.length];
            final float[] potentials = this.lattice.allowedEdgesPotentials(i, i2, i3, false);
            final float rescale = ForwardBackward.getScaleFactor((this.allAlphaLogScales[i][i2] + this.allBetaLogScales[i][i2 + 1]) - this.sequenceMarginalProbLogScales[i]);
            int slot = 0;
            while (slot < destinations.length) {
                counts[slot] += (((this.allAlphas[i][i2][i3] * potentials[slot]) * this.allBetas[i][i2 + 1][destinations[slot]]) / this.sequenceMarginalProbs[i]) * rescale;
                slot++;
            }
            return counts;
        }

        /**
         * Returns the start-node probabilities stored for one sequence.
         *
         * @param i sequence index
         * @return the stored array itself, not a copy
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public float[] startNodeCondProbs(int i) {
            final float[] stored = this.startNodeCondProbs[i];
            return stored;
        }

        /**
         * Returns the end-node probabilities stored for one sequence.
         *
         * @param i sequence index
         * @return the stored array itself, not a copy
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public float[] endNodeCondProbs(int i) {
            final float[] stored = this.endNodeCondProbs[i];
            return stored;
        }

        /**
         * Returns the log marginal probability stored for one sequence.
         *
         * @param i sequence index
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public float sequenceLogMarginalProb(int i) {
            final float stored = this.sequenceLogMarginalProbs[i];
            return stored;
        }

        /** Returns {@code a.sum} of the stored per-sequence log marginal probabilities. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public float logMarginalProb() {
            final float[] perSequence = this.sequenceLogMarginalProbs;
            return a.sum(perSequence);
        }

        /** Returns the number of sequences reported by the underlying lattice. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public int numSequences() {
            final int sequenceCount = this.lattice.numSequences();
            return sequenceCount;
        }

        /**
         * Returns the length the underlying lattice reports for one sequence.
         *
         * @param i sequence index
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public int sequenceLength(int i) {
            final int length = this.lattice.sequenceLength(i);
            return length;
        }

        /**
         * Returns the number of states the underlying lattice reports for one position.
         *
         * @param i  sequence index
         * @param i2 position within the sequence
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public int numStates(int i, int i2) {
            final int stateCount = this.lattice.numStates(i, i2);
            return stateCount;
        }

        /**
         * Creates a fresh iterator over the edge marginals of all sequences.
         *
         * @return an iterator of ((sequence, position), (source state, destination state)) paired
         *         with the value computed for that edge
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Float>> getEdgeMarginalsIterator() {
            final Iterator<Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Float>> edges = new EdgeMarginalsIterator();
            return edges;
        }

        /**
         * Creates a fresh iterator over the start-node probabilities of all sequences.
         *
         * @return an iterator of (sequence, state) paired with the stored probability
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Float>> getStartMarginalsIterator() {
            final Iterator<Pair<Pair<Integer, Integer>, Float>> starts = new StartMarginalsIterator();
            return starts;
        }

        /**
         * Creates a fresh iterator over the end-node probabilities of all sequences.
         *
         * @return an iterator of (sequence, state) paired with the stored probability
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Float>> getEndMarginalsIterator() {
            final Iterator<Pair<Pair<Integer, Integer>, Float>> ends = new EndMarginalsIterator();
            return ends;
        }
    }

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryEdgeMarginals.class */
    /**
     * Read access to edge, start-node and end-node marginals of a lattice whose state set and
     * allowed edges are addressed per sequence rather than per position.
     */
    public interface StationaryEdgeMarginals {
        /** Returns the number of sequences covered. */
        int numSequences();

        /**
         * Returns the number of states of one sequence.
         *
         * @param i sequence index
         */
        int numStates(int i);

        /**
         * Returns the log marginal probability of one sequence.
         *
         * @param i sequence index
         */
        float sequenceLogMarginalProb(int i);

        /** Returns the log marginal probability aggregated over all sequences. */
        float logMarginalProb();

        /**
         * Returns the start-node probabilities of one sequence, indexed by state.
         *
         * @param i sequence index
         */
        float[] startNodeCondProbs(int i);

        /**
         * Returns the end-node probabilities of one sequence, indexed by state.
         *
         * @param i sequence index
         */
        float[] endNodeCondProbs(int i);

        /**
         * Returns the allowed edges leaving one state.
         *
         * @param i  sequence index
         * @param i2 source state
         */
        int[] allowedForwardEdges(int i, int i2);

        /**
         * Returns the expected counts of the edges leaving one state, parallel to
         * {@link #allowedForwardEdges(int, int)}.
         *
         * @param i  sequence index
         * @param i2 source state
         */
        float[] allowedForwardEdgesExpectedCounts(int i, int i2);

        /** Returns an iterator of (sequence, state) paired with the start-node probability. */
        Iterator<Pair<Pair<Integer, Integer>, Float>> getStartMarginalsIterator();

        /** Returns an iterator of (sequence, state) paired with the end-node probability. */
        Iterator<Pair<Pair<Integer, Integer>, Float>> getEndMarginalsIterator();

        /**
         * Returns an iterator of (sequence, (source state, destination state)) paired with the
         * expected count of that edge.
         */
        Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Float>> getEdgeMarginalsIterator();

        /** Returns an estimate of the memory held by the stored values. */
        double estimateMemoryUsage();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryEdgeMarginalsLogSpace.class */
    /**
     * {@link StationaryEdgeMarginals} computed from log-space forward/backward tables. All values
     * are materialised per sequence by {@code incrementExpectedCounts} and served from arrays.
     */
    public static class StationaryEdgeMarginalsLogSpace implements StationaryEdgeMarginals {
        StationaryLattice lattice;
        /** Log marginal probability per sequence. */
        float[] sequenceLogMarginalProbs;
        /** Expected edge counts, indexed [sequence][source state][slot in the allowed-edge list]. */
        float[][][] allowedForwardEdgesExpectedCounts;
        /** Start-node probabilities, indexed [sequence][state]. */
        float[][] startNodeCondProbs;
        /** End-node probabilities, indexed [sequence][state]. */
        float[][] endNodeCondProbs;

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryEdgeMarginalsLogSpace$EdgeMarginalsIterator.class */
        /** Iterates over (sequence, (source state, destination state)) paired with the expected count. */
        private class EdgeMarginalsIterator implements Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Float>> {
            int d = 0;
            int s1 = 0;
            int s2i = 0;
            /** Expected counts of the edges leaving state {@code s1} of sequence {@code d}. */
            float[] edgeCondProbs;

            /**
             * Positions the cursor on the first edge that exists. With no sequences at all the
             * iterator is left exhausted instead of failing on the first lookup.
             */
            public EdgeMarginalsIterator() {
                this.edgeCondProbs = null;
                if (this.d >= StationaryEdgeMarginalsLogSpace.this.numSequences()) {
                    return;
                }
                this.edgeCondProbs = StationaryEdgeMarginalsLogSpace.this.allowedForwardEdgesExpectedCounts(this.d, this.s1);
                skipStatesWithoutEdges();
            }

            /** Advances past source states whose allowed-edge list is empty. */
            private void skipStatesWithoutEdges() {
                while (this.d < StationaryEdgeMarginalsLogSpace.this.numSequences() && this.edgeCondProbs.length == 0) {
                    advance();
                }
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.d < StationaryEdgeMarginalsLogSpace.this.numSequences();
            }

            /**
             * Returns the edge under the cursor and moves on to the next existing edge.
             *
             * @throws java.util.NoSuchElementException if no edge is left
             */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Pair<Integer, Integer>>, Float> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                final int destinationState = StationaryEdgeMarginalsLogSpace.this.allowedForwardEdges(this.d, this.s1)[this.s2i];
                final Pair<Pair<Integer, Pair<Integer, Integer>>, Float> current = Pair.makePair(
                        Pair.makePair(Integer.valueOf(this.d), Pair.makePair(Integer.valueOf(this.s1), Integer.valueOf(destinationState))),
                        Float.valueOf(this.edgeCondProbs[this.s2i]));
                advance();
                skipStatesWithoutEdges();
                return current;
            }

            /** Moves to the next edge slot, else the next source state, else the next sequence. */
            private void advance() {
                if (this.s2i < StationaryEdgeMarginalsLogSpace.this.allowedForwardEdges(this.d, this.s1).length - 1) {
                    this.s2i++;
                    return;
                }
                this.s2i = 0;
                if (this.s1 >= StationaryEdgeMarginalsLogSpace.this.numStates(this.d) - 1) {
                    this.s1 = 0;
                    this.d++;
                } else {
                    this.s1++;
                }
                if (this.d < StationaryEdgeMarginalsLogSpace.this.numSequences()) {
                    this.edgeCondProbs = StationaryEdgeMarginalsLogSpace.this.allowedForwardEdgesExpectedCounts(this.d, this.s1);
                }
            }

            /**
             * Removal is not supported.
             *
             * @throws UnsupportedOperationException always
             */
            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        }

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryEdgeMarginalsLogSpace$EndMarginalsIterator.class */
        /** Iterates over (sequence, state) paired with the end-node probability. */
        private class EndMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Float>> {
            /** Sequence index of the element the next call to {@code next()} returns. */
            int d;
            /** State index of that element. */
            int s;
            /** End-node probabilities of sequence {@code d}; {@code null} until first read. */
            float[] endCondProbs;

            private EndMarginalsIterator() {
                this.d = 0;
                this.s = 0;
                this.endCondProbs = null;
            }

            /**
             * Reports whether another element exists, rolling the cursor over to the next sequence
             * that still has a state to report. The cursor points at the element to return next, so
             * a fresh iterator over a single one-state sequence correctly reports one element.
             */
            @Override // java.util.Iterator
            public boolean hasNext() {
                final int sequenceCount = StationaryEdgeMarginalsLogSpace.this.numSequences();
                while (this.d < sequenceCount && this.s >= StationaryEdgeMarginalsLogSpace.this.numStates(this.d)) {
                    this.d++;
                    this.s = 0;
                    this.endCondProbs = null;
                }
                return this.d < sequenceCount;
            }

            /**
             * Returns the element under the cursor and moves the cursor one state forward.
             *
             * @throws java.util.NoSuchElementException if no element is left
             */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Float> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                if (this.endCondProbs == null) {
                    this.endCondProbs = StationaryEdgeMarginalsLogSpace.this.endNodeCondProbs(this.d);
                }
                final Pair<Pair<Integer, Integer>, Float> current = Pair.makePair(
                        Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)),
                        Float.valueOf(this.endCondProbs[this.s]));
                this.s++;
                return current;
            }

            /**
             * Removal is not supported.
             *
             * @throws UnsupportedOperationException always
             */
            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        }

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryEdgeMarginalsLogSpace$StartMarginalsIterator.class */
        /** Iterates over (sequence, state) paired with the start-node probability. */
        private class StartMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Float>> {
            /** Sequence index of the element the next call to {@code next()} returns. */
            int d;
            /** State index of that element. */
            int s;
            /** Start-node probabilities of sequence {@code d}; {@code null} until first read. */
            float[] startCondProbs;

            private StartMarginalsIterator() {
                this.d = 0;
                this.s = 0;
                this.startCondProbs = null;
            }

            /**
             * Reports whether another element exists, rolling the cursor over to the next sequence
             * that still has a state to report. The cursor points at the element to return next, so
             * a fresh iterator over a single one-state sequence correctly reports one element.
             */
            @Override // java.util.Iterator
            public boolean hasNext() {
                final int sequenceCount = StationaryEdgeMarginalsLogSpace.this.numSequences();
                while (this.d < sequenceCount && this.s >= StationaryEdgeMarginalsLogSpace.this.numStates(this.d)) {
                    this.d++;
                    this.s = 0;
                    this.startCondProbs = null;
                }
                return this.d < sequenceCount;
            }

            /**
             * Returns the element under the cursor and moves the cursor one state forward.
             *
             * @throws java.util.NoSuchElementException if no element is left
             */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Float> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                if (this.startCondProbs == null) {
                    this.startCondProbs = StationaryEdgeMarginalsLogSpace.this.startNodeCondProbs(this.d);
                }
                final Pair<Pair<Integer, Integer>, Float> current = Pair.makePair(
                        Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)),
                        Float.valueOf(this.startCondProbs[this.s]));
                this.s++;
                return current;
            }

            /**
             * Removal is not supported.
             *
             * @throws UnsupportedOperationException always
             */
            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        }

        /**
         * Allocates one storage slot per sequence; the inner arrays are created by
         * {@code incrementExpectedCounts}.
         *
         * @param stationaryLattice the lattice the marginals refer to
         */
        public StationaryEdgeMarginalsLogSpace(StationaryLattice stationaryLattice) {
            final int sequenceCount = stationaryLattice.numSequences();
            this.lattice = stationaryLattice;
            this.sequenceLogMarginalProbs = new float[sequenceCount];
            // Inner dimensions are left open to match the declared field types.
            this.allowedForwardEdgesExpectedCounts = new float[sequenceCount][][];
            this.startNodeCondProbs = new float[sequenceCount][];
            this.endNodeCondProbs = new float[sequenceCount][];
        }

        /**
         * Returns {@code lattice.allowedEdges(i, i2, false)}.
         *
         * @param i  sequence index
         * @param i2 source state
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public int[] allowedForwardEdges(int i, int i2) {
            return this.lattice.allowedEdges(i, i2, false);
        }

        /**
         * Computes and stores the marginals of one sequence from its log-space tables.
         *
         * <p>The log marginal probability is the log-sum of {@code fArr2[0]}. Expected edge counts
         * accumulate {@code exp(fArr[t][s1] + edgeLogPotential + fArr2[t + 1][s2] - logMarginal)}
         * over all positions {@code t} that have a successor.
         *
         * @param fArr  log scores indexed [position][state]; the last row yields the end-node
         *              probabilities (presumably the forward/alpha table - confirm with the caller)
         * @param fArr2 log scores indexed [position][state]; row 0 yields the marginal and the
         *              start-node probabilities (presumably the backward/beta table - confirm)
         * @param i     sequence index
         */
        public void incrementExpectedCounts(float[][] fArr, float[][] fArr2, int i) {
            final int stateCount = this.lattice.numStates(i);
            final int lastPosition = this.lattice.sequenceLength(i) - 1;

            this.sequenceLogMarginalProbs[i] = Float.NEGATIVE_INFINITY;
            for (int state = 0; state < stateCount; state++) {
                this.sequenceLogMarginalProbs[i] = m.logAdd(this.sequenceLogMarginalProbs[i], fArr2[0][state]);
            }
            final float logMarginal = this.sequenceLogMarginalProbs[i];

            // One zero-initialised accumulator per allowed edge of every source state.
            final float[][] edgeCounts = new float[stateCount][];
            for (int source = 0; source < stateCount; source++) {
                edgeCounts[source] = new float[this.lattice.allowedEdges(i, source, false).length];
            }
            this.allowedForwardEdgesExpectedCounts[i] = edgeCounts;

            for (int position = 0; position < lastPosition; position++) {
                for (int source = 0; source < stateCount; source++) {
                    final int[] destinations = this.lattice.allowedEdges(i, source, false);
                    final float[] edgeLogPotentials = this.lattice.allowedEdgesLogPotentials(i, source, false);
                    final float[] sourceCounts = edgeCounts[source];
                    for (int slot = 0; slot < destinations.length; slot++) {
                        final float logEdgePosterior = ((fArr[position][source] + edgeLogPotentials[slot]) + fArr2[position + 1][destinations[slot]]) - logMarginal;
                        // The sum is formed in double precision and narrowed once, like a float += double.
                        sourceCounts[slot] = (float) (sourceCounts[slot] + Math.exp(logEdgePosterior));
                    }
                }
            }

            this.startNodeCondProbs[i] = new float[stateCount];
            for (int state = 0; state < stateCount; state++) {
                this.startNodeCondProbs[i][state] = (float) Math.exp(fArr2[0][state] - logMarginal);
            }

            this.endNodeCondProbs[i] = new float[stateCount];
            for (int state = 0; state < stateCount; state++) {
                this.endNodeCondProbs[i][state] = (float) Math.exp(fArr[lastPosition][state] - logMarginal);
            }
        }

        /**
         * Estimates the size of the stored expected counts.
         *
         * @return the number of stored expected-count entries, multiplied by 8 and divided by 1e9
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public double estimateMemoryUsage() {
            double storedEntries = 0.0d;
            for (int seq = 0; seq < numSequences(); seq++) {
                if (this.allowedForwardEdgesExpectedCounts[seq] != null) {
                    for (int state = 0; state < numStates(seq); state++) {
                        storedEntries += this.allowedForwardEdgesExpectedCounts[seq][state].length;
                    }
                }
            }
            // NOTE(review): entries are floats (4 bytes each) but the factor here is 8 - confirm
            // whether that is intentional before relying on the absolute number.
            return (8.0d * storedEntries) / 1.0E9d;
        }

        /** Returns the stored expected counts of the edges leaving state {@code i2} of sequence {@code i}. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public float[] allowedForwardEdgesExpectedCounts(int i, int i2) {
            return this.allowedForwardEdgesExpectedCounts[i][i2];
        }

        /** Returns the stored start-node probabilities of sequence {@code i} (not a copy). */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public float[] startNodeCondProbs(int i) {
            return this.startNodeCondProbs[i];
        }

        /** Returns the stored end-node probabilities of sequence {@code i} (not a copy). */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public float[] endNodeCondProbs(int i) {
            return this.endNodeCondProbs[i];
        }

        /** Returns the stored log marginal probability of sequence {@code i}. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public float sequenceLogMarginalProb(int i) {
            return this.sequenceLogMarginalProbs[i];
        }

        /** Returns {@code a.sum} of the stored per-sequence log marginal probabilities. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public float logMarginalProb() {
            return a.sum(this.sequenceLogMarginalProbs);
        }

        /** Returns the number of sequences reported by the lattice. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public int numSequences() {
            return this.lattice.numSequences();
        }

        /** Returns the number of states the lattice reports for sequence {@code i}. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public int numStates(int i) {
            return this.lattice.numStates(i);
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Float>> getEdgeMarginalsIterator() {
            return new EdgeMarginalsIterator();
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Float>> getStartMarginalsIterator() {
            return new StartMarginalsIterator();
        }

        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Float>> getEndMarginalsIterator() {
            return new EndMarginalsIterator();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryEdgeMarginalsScaling.class */
    public static class StationaryEdgeMarginalsScaling implements StationaryEdgeMarginals {
        StationaryLattice lattice;
        float[] sequenceLogMarginalProbs;
        float[][][] allowedForwardEdgesExpectedCounts;
        float[][] startNodeCondProbs;
        float[][] endNodeCondProbs;

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryEdgeMarginalsScaling$EdgeMarginalsIterator.class */
        /** Iterates over (sequence, (source state, destination state)) paired with the expected count. */
        private class EdgeMarginalsIterator implements Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Float>> {
            int d = 0;
            int s1 = 0;
            int s2i = 0;
            /** Expected counts of the edges leaving state {@code s1} of sequence {@code d}. */
            float[] edgeCondProbs;

            /**
             * Positions the cursor on the first edge that exists. With no sequences at all the
             * iterator is left exhausted instead of failing on the first lookup.
             */
            public EdgeMarginalsIterator() {
                this.edgeCondProbs = null;
                if (this.d >= StationaryEdgeMarginalsScaling.this.numSequences()) {
                    return;
                }
                this.edgeCondProbs = StationaryEdgeMarginalsScaling.this.allowedForwardEdgesExpectedCounts(this.d, this.s1);
                skipStatesWithoutEdges();
            }

            /** Advances past source states whose allowed-edge list is empty. */
            private void skipStatesWithoutEdges() {
                while (this.d < StationaryEdgeMarginalsScaling.this.numSequences() && this.edgeCondProbs.length == 0) {
                    advance();
                }
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.d < StationaryEdgeMarginalsScaling.this.numSequences();
            }

            /**
             * Returns the edge under the cursor and moves on to the next existing edge.
             *
             * @throws java.util.NoSuchElementException if no edge is left
             */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Pair<Integer, Integer>>, Float> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                final int destinationState = StationaryEdgeMarginalsScaling.this.allowedForwardEdges(this.d, this.s1)[this.s2i];
                final Pair<Pair<Integer, Pair<Integer, Integer>>, Float> current = Pair.makePair(
                        Pair.makePair(Integer.valueOf(this.d), Pair.makePair(Integer.valueOf(this.s1), Integer.valueOf(destinationState))),
                        Float.valueOf(this.edgeCondProbs[this.s2i]));
                advance();
                skipStatesWithoutEdges();
                return current;
            }

            /** Moves to the next edge slot, else the next source state, else the next sequence. */
            private void advance() {
                if (this.s2i < StationaryEdgeMarginalsScaling.this.allowedForwardEdges(this.d, this.s1).length - 1) {
                    this.s2i++;
                    return;
                }
                this.s2i = 0;
                if (this.s1 >= StationaryEdgeMarginalsScaling.this.numStates(this.d) - 1) {
                    this.s1 = 0;
                    this.d++;
                } else {
                    this.s1++;
                }
                if (this.d < StationaryEdgeMarginalsScaling.this.numSequences()) {
                    this.edgeCondProbs = StationaryEdgeMarginalsScaling.this.allowedForwardEdgesExpectedCounts(this.d, this.s1);
                }
            }

            /**
             * Removal is not supported.
             *
             * @throws UnsupportedOperationException always
             */
            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        }

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryEdgeMarginalsScaling$EndMarginalsIterator.class */
        /** Iterates over (sequence, state) paired with the end-node probability. */
        private class EndMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Float>> {
            /** Sequence index of the element the next call to {@code next()} returns. */
            int d;
            /** State index of that element. */
            int s;
            /** End-node probabilities of sequence {@code d}; {@code null} until first read. */
            float[] endCondProbs;

            private EndMarginalsIterator() {
                this.d = 0;
                this.s = 0;
                this.endCondProbs = null;
            }

            /**
             * Reports whether another element exists, rolling the cursor over to the next sequence
             * that still has a state to report. The cursor points at the element to return next, so
             * a fresh iterator over a single one-state sequence correctly reports one element.
             */
            @Override // java.util.Iterator
            public boolean hasNext() {
                final int sequenceCount = StationaryEdgeMarginalsScaling.this.numSequences();
                while (this.d < sequenceCount && this.s >= StationaryEdgeMarginalsScaling.this.numStates(this.d)) {
                    this.d++;
                    this.s = 0;
                    this.endCondProbs = null;
                }
                return this.d < sequenceCount;
            }

            /**
             * Returns the element under the cursor and moves the cursor one state forward.
             *
             * @throws java.util.NoSuchElementException if no element is left
             */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Float> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                if (this.endCondProbs == null) {
                    this.endCondProbs = StationaryEdgeMarginalsScaling.this.endNodeCondProbs(this.d);
                }
                final Pair<Pair<Integer, Integer>, Float> current = Pair.makePair(
                        Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)),
                        Float.valueOf(this.endCondProbs[this.s]));
                this.s++;
                return current;
            }

            /**
             * Removal is not supported.
             *
             * @throws UnsupportedOperationException always
             */
            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        }

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryEdgeMarginalsScaling$StartMarginalsIterator.class */
        /** Iterates over (sequence, state) paired with the start-node probability. */
        private class StartMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Float>> {
            /** Sequence index of the element the next call to {@code next()} returns. */
            int d;
            /** State index of that element. */
            int s;
            /** Start-node probabilities of sequence {@code d}; {@code null} until first read. */
            float[] startCondProbs;

            private StartMarginalsIterator() {
                this.d = 0;
                this.s = 0;
                this.startCondProbs = null;
            }

            /**
             * Reports whether another element exists, rolling the cursor over to the next sequence
             * that still has a state to report. The cursor points at the element to return next, so
             * a fresh iterator over a single one-state sequence correctly reports one element.
             */
            @Override // java.util.Iterator
            public boolean hasNext() {
                final int sequenceCount = StationaryEdgeMarginalsScaling.this.numSequences();
                while (this.d < sequenceCount && this.s >= StationaryEdgeMarginalsScaling.this.numStates(this.d)) {
                    this.d++;
                    this.s = 0;
                    this.startCondProbs = null;
                }
                return this.d < sequenceCount;
            }

            /**
             * Returns the element under the cursor and moves the cursor one state forward.
             *
             * @throws java.util.NoSuchElementException if no element is left
             */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Float> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                if (this.startCondProbs == null) {
                    this.startCondProbs = StationaryEdgeMarginalsScaling.this.startNodeCondProbs(this.d);
                }
                final Pair<Pair<Integer, Integer>, Float> current = Pair.makePair(
                        Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)),
                        Float.valueOf(this.startCondProbs[this.s]));
                this.s++;
                return current;
            }

            /**
             * Removal is not supported.
             *
             * @throws UnsupportedOperationException always
             */
            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        }

        /* JADX WARN: Type inference failed for: r1v12, types: [float[], float[][]] */
        /* JADX WARN: Type inference failed for: r1v6, types: [float[][], float[][][]] */
        /* JADX WARN: Type inference failed for: r1v9, types: [float[], float[][]] */
        public StationaryEdgeMarginalsScaling(StationaryLattice stationaryLattice) {
            this.lattice = stationaryLattice;
            this.sequenceLogMarginalProbs = new float[stationaryLattice.numSequences()];
            this.allowedForwardEdgesExpectedCounts = new float[stationaryLattice.numSequences()];
            this.startNodeCondProbs = new float[stationaryLattice.numSequences()];
            this.endNodeCondProbs = new float[stationaryLattice.numSequences()];
        }

        /**
         * Looks up the allowed edges of one state by delegating to the lattice.
         *
         * @param i  sequence index
         * @param i2 source state
         * @return whatever {@code lattice.allowedEdges(i, i2, false)} returns
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public int[] allowedForwardEdges(int i, int i2) {
            // false is the flag value this class passes for every "forward" lookup.
            final boolean directionFlag = false;
            final int[] destinations = this.lattice.allowedEdges(i, i2, directionFlag);
            return destinations;
        }

        /**
         * Stores the edge and boundary marginals of one sequence, computed from its scaled
         * forward and backward tables.
         *
         * <p>Writes, for sequence {@code i}: the log marginal probability, the expected count
         * of every allowed forward edge summed over all positions, and the conditional state
         * distributions at the first and last positions. Any previously stored rows for this
         * sequence are replaced.
         *
         * @param fArr  forward table, indexed [position][state]
         * @param fArr2 forward scale exponents, one per position (powers of {@code SCALE})
         * @param fArr3 backward table, indexed [position][state]
         * @param fArr4 backward scale exponents, one per position
         * @param i     index of the sequence being processed
         */
        public void incrementExpectedCounts(float[][] fArr, float[] fArr2, float[][] fArr3, float[] fArr4, int i) {
            // Normaliser: backward mass at position 0 over the states whose node potential
            // there exceeds the threshold.
            // NOTE(review): Cropper.VERT_GROW_RATIO looks like a decompiler-substituted
            // literal (presumably 0) - confirm.
            float normalizer = 0.0f;
            final float normalizerExponent = fArr4[0];
            for (int state = 0; state < this.lattice.numStates(i); state++) {
                if (this.lattice.nodePotential(i, 0, state) > Cropper.VERT_GROW_RATIO) {
                    normalizer += fArr3[0][state];
                }
            }
            this.sequenceLogMarginalProbs[i] = (normalizerExponent * ForwardBackward.LOG_SCALE) + ((float) Math.log(normalizer));

            // One count slot per allowed forward edge of every state. The row array is
            // two-dimensional, so it is allocated with an open second dimension.
            this.allowedForwardEdgesExpectedCounts[i] = new float[this.lattice.numStates(i)][];
            final float[][] edgeCounts = this.allowedForwardEdgesExpectedCounts[i];
            for (int state = 0; state < this.lattice.numStates(i); state++) {
                edgeCounts[state] = new float[this.lattice.allowedEdges(i, state, false).length];
            }

            // Accumulate forward * edge * backward / normaliser for every adjacent position pair.
            for (int pos = 0; pos < this.lattice.sequenceLength(i) - 1; pos++) {
                // Undo the difference between the scale exponents of the factors and the normaliser.
                final float rescale = ForwardBackward.getScaleFactor((fArr2[pos] + fArr4[pos + 1]) - normalizerExponent);
                final int numStates = this.lattice.numStates(i);
                for (int from = 0; from < numStates; from++) {
                    final int[] targets = this.lattice.allowedEdges(i, from, false);
                    final float[] potentials = this.lattice.allowedEdgesPotentials(i, from, false);
                    final float[] counts = edgeCounts[from];
                    for (int e = 0; e < targets.length; e++) {
                        final int to = targets[e];
                        final float potential = potentials[e];
                        counts[e] = counts[e] + ((((fArr[pos][from] * potential) * fArr3[pos + 1][to]) / normalizer) * rescale);
                    }
                }
            }

            // Conditional distribution over states at the first position (from the backward table).
            this.startNodeCondProbs[i] = new float[this.lattice.numStates(i)];
            final float startRescale = ForwardBackward.getScaleFactor(fArr4[0] - normalizerExponent);
            for (int state = 0; state < this.lattice.numStates(i); state++) {
                this.startNodeCondProbs[i][state] = (fArr3[0][state] / normalizer) * startRescale;
            }

            // Conditional distribution over states at the last position (from the forward table).
            this.endNodeCondProbs[i] = new float[this.lattice.numStates(i)];
            final float endRescale = ForwardBackward.getScaleFactor(fArr2[this.lattice.sequenceLength(i) - 1] - normalizerExponent);
            for (int state = 0; state < this.lattice.numStates(i); state++) {
                this.endNodeCondProbs[i][state] = (fArr[this.lattice.sequenceLength(i) - 1][state] / normalizer) * endRescale;
            }
        }

        /**
         * Estimates the size of the stored edge-count tables.
         *
         * <p>Counts every allocated edge slot (sequences whose row is still {@code null} are
         * skipped), multiplies by 8 and divides by 1e9.
         * NOTE(review): the factor 8 is applied although the slots are {@code float}s;
         * confirm whether that is intended.
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public double estimateMemoryUsage() {
            double slotCount = 0.0d;
            for (int seq = 0; seq < numSequences(); seq++) {
                if (allowedForwardEdgesExpectedCounts[seq] == null) {
                    continue;
                }
                for (int state = 0; state < numStates(seq); state++) {
                    slotCount += allowedForwardEdgesExpectedCounts[seq][state].length;
                }
            }
            final double scaled = 8.0d * slotCount;
            return scaled / 1.0E9d;
        }

        /**
         * Returns the stored expected-count array for state {@code i2} of sequence {@code i}.
         * The internal array itself is returned, not a copy.
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public float[] allowedForwardEdgesExpectedCounts(int i, int i2) {
            final float[][] countsForSequence = this.allowedForwardEdgesExpectedCounts[i];
            return countsForSequence[i2];
        }

        /**
         * Returns the stored first-position state distribution of sequence {@code i}
         * (the internal array, not a copy).
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public float[] startNodeCondProbs(int i) {
            final float[] startProbs = this.startNodeCondProbs[i];
            return startProbs;
        }

        /**
         * Returns the stored last-position state distribution of sequence {@code i}
         * (the internal array, not a copy).
         */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public float[] endNodeCondProbs(int i) {
            final float[] endProbs = this.endNodeCondProbs[i];
            return endProbs;
        }

        /** Returns the stored log marginal probability of sequence {@code i}. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public float sequenceLogMarginalProb(int i) {
            final float logProb = sequenceLogMarginalProbs[i];
            return logProb;
        }

        /** Returns {@code a.sum} applied to the per-sequence log marginal probabilities. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public float logMarginalProb() {
            final float total = a.sum(sequenceLogMarginalProbs);
            return total;
        }

        /** Returns the number of sequences reported by the underlying lattice. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public int numSequences() {
            final int sequenceCount = lattice.numSequences();
            return sequenceCount;
        }

        /** Returns the number of states the underlying lattice reports for sequence {@code i}. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public int numStates(int i) {
            final int stateCount = lattice.numStates(i);
            return stateCount;
        }

        /** Returns a fresh {@code EdgeMarginalsIterator} on each call. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Float>> getEdgeMarginalsIterator() {
            final Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Float>> edgeIterator = new EdgeMarginalsIterator();
            return edgeIterator;
        }

        /** Returns a fresh {@code StartMarginalsIterator} on each call. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Float>> getStartMarginalsIterator() {
            final Iterator<Pair<Pair<Integer, Integer>, Float>> startIterator = new StartMarginalsIterator();
            return startIterator;
        }

        /** Returns a fresh {@code EndMarginalsIterator} on each call. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.StationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Float>> getEndMarginalsIterator() {
            final Iterator<Pair<Pair<Integer, Integer>, Float>> endIterator = new EndMarginalsIterator();
            return endIterator;
        }
    }

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryLattice.class */
    /**
     * A lattice whose state count and edge structure are queried without a position
     * argument. {@code StationaryLatticeWrapper} adapts it to {@code Lattice} by dropping
     * the position from those queries.
     *
     * <p>Argument roles below are taken from how this file calls the methods: the first
     * {@code int} is always the sequence index.
     */
    public interface StationaryLattice {
        /** Number of sequences; callers iterate sequence indices from 0 up to this value. */
        int numSequences();

        /** Length (number of positions) of sequence {@code i}. */
        int sequenceLength(int i);

        /** Number of states of sequence {@code i}. */
        int numStates(int i);

        /** Log-space node potential for (sequence {@code i}, position {@code i2}, state {@code i3}). */
        float nodeLogPotential(int i, int i2, int i3);

        /**
         * Log-space potentials of the edges of state {@code i2} in sequence {@code i};
         * callers index this array in parallel with {@link #allowedEdges}. {@code z} is the
         * direction flag (see {@link #allowedEdges}).
         */
        float[] allowedEdgesLogPotentials(int i, int i2, boolean z);

        /** Node potential for (sequence {@code i}, position {@code i2}, state {@code i3}). */
        float nodePotential(int i, int i2, int i3);

        /**
         * Potentials of the edges of state {@code i2} in sequence {@code i}; callers index
         * this array in parallel with {@link #allowedEdges}. {@code z} is the direction flag.
         */
        float[] allowedEdgesPotentials(int i, int i2, boolean z);

        /**
         * State indices connected to state {@code i2} in sequence {@code i}. This file passes
         * {@code z == false} from forward passes and {@code z == true} from backward passes
         * and from Viterbi back-tracing.
         */
        int[] allowedEdges(int i, int i2, boolean z);
    }

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryLatticeWrapper.class */
    /**
     * Adapter that exposes a {@link StationaryLattice} through the {@link Lattice} interface.
     *
     * <p>State-count and edge queries discard the position argument ({@code i2}) before
     * delegating; node-potential queries are forwarded with all three arguments.
     */
    public static class StationaryLatticeWrapper implements Lattice {
        /** The wrapped lattice that every call is delegated to. */
        StationaryLattice lattice;

        /** Wraps {@code stationaryLattice}; no copy is made. */
        public StationaryLatticeWrapper(StationaryLattice stationaryLattice) {
            lattice = stationaryLattice;
        }

        /** Delegates unchanged. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.Lattice
        public int numSequences() {
            return lattice.numSequences();
        }

        /** Delegates unchanged. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.Lattice
        public int sequenceLength(int i) {
            return lattice.sequenceLength(i);
        }

        /** Ignores the position {@code i2}; only the sequence index is forwarded. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.Lattice
        public int numStates(int i, int i2) {
            return lattice.numStates(i);
        }

        /** Delegates with sequence, position and state. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.Lattice
        public float nodeLogPotential(int i, int i2, int i3) {
            return lattice.nodeLogPotential(i, i2, i3);
        }

        /** Ignores the position {@code i2}; forwards sequence, state and direction flag. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.Lattice
        public float[] allowedEdgesLogPotentials(int i, int i2, int i3, boolean z) {
            return lattice.allowedEdgesLogPotentials(i, i3, z);
        }

        /** Delegates with sequence, position and state. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.Lattice
        public float nodePotential(int i, int i2, int i3) {
            return lattice.nodePotential(i, i2, i3);
        }

        /** Ignores the position {@code i2}; forwards sequence, state and direction flag. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.Lattice
        public float[] allowedEdgesPotentials(int i, int i2, int i3, boolean z) {
            return lattice.allowedEdgesPotentials(i, i3, z);
        }

        /** Ignores the position {@code i2}; forwards sequence, state and direction flag. */
        @Override // tberg.murphy.floatsequence.ForwardBackward.Lattice
        public int[] allowedEdges(int i, int i2, int i3, boolean z) {
            return lattice.allowedEdges(i, i3, z);
        }
    }

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/floatsequence/ForwardBackward$StationaryStateProjector.class */
    /**
     * Maps lattice states to another set of state indices. In this part of the file it is
     * used by {@code extractViterbiPath}, which replaces every state on the decoded path
     * with {@code project(sequence, position, state)}.
     */
    public interface StationaryStateProjector {
        // NOTE(review): not called in this part of the file; presumably the number of
        // source states for (sequence i, position i2) - confirm against implementations.
        int domainSize(int i, int i2);

        // NOTE(review): not called in this part of the file; presumably the number of
        // projected states for sequence i - confirm against implementations.
        int rangeSize(int i);

        /** Projects state {@code i3} at position {@code i2} of sequence {@code i}. */
        int project(int i, int i2, int i3);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Converts a scale exponent into the corresponding multiplicative factor.
     *
     * <p>Exponents 1, 2, 3 and -1, -2, -3 are answered with explicit products of
     * {@code SCALE} / {@code INVSCALE}; any other value falls through to
     * {@code Math.pow(SCALE, f)}. An exponent equal to {@code Cropper.VERT_GROW_RATIO}
     * yields 1.
     * NOTE(review): that constant looks like a decompiler-substituted literal
     * (presumably 0) - confirm.
     *
     * @param f scale exponent
     * @return the factor for that exponent
     */
    public static float getScaleFactor(float f) {
        if (f == Cropper.VERT_GROW_RATIO) {
            return 1.0f;
        }
        if (f == 1.0f) {
            return SCALE;
        }
        if (f == 2.0f) {
            return SCALE * SCALE;
        }
        if (f == 3.0f) {
            return SCALE * SCALE * SCALE;
        }
        if (f == -1.0f) {
            return 1.0f * INVSCALE;
        }
        if (f == -2.0f) {
            return 1.0f * INVSCALE * INVSCALE;
        }
        if (f == -3.0f) {
            return 1.0f * INVSCALE * INVSCALE * INVSCALE;
        }
        // General case, including fractional and larger-magnitude exponents.
        return (float) Math.pow(SCALE, f);
    }

    /**
     * Runs log-space forward and backward passes over every sequence of a stationary
     * lattice and collects node marginals and, unless {@code z} is set, edge marginals.
     *
     * <p>One task per sequence index is handed to a {@code BetterThreader}.
     *
     * @param stationaryLattice        lattice to process
     * @param stationaryStateProjector passed to the node-marginals accumulator
     * @param z                        forwarded to both passes (where it selects max instead
     *                                 of log-add) and to the node accumulator; when
     *                                 {@code true} no edge marginals are built and the
     *                                 second pair component is {@code null}
     * @param i                        forwarded to the {@code BetterThreader} constructor
     *                                 (presumably the thread count - confirm)
     * @return pair of (node marginals, edge marginals or {@code null})
     */
    public static Pair<NodeMarginals, StationaryEdgeMarginals> computeMarginalsLogSpace(final StationaryLattice stationaryLattice, StationaryStateProjector stationaryStateProjector, final boolean z, int i) {
        // The wrapper only holds a reference to the lattice, so one instance serves all tasks.
        final Lattice wrappedLattice = new StationaryLatticeWrapper(stationaryLattice);
        final NodeMarginalsLogSpace nodeMarginals = new NodeMarginalsLogSpace(wrappedLattice, stationaryStateProjector);
        final StationaryEdgeMarginalsLogSpace edgeMarginals = z ? null : new StationaryEdgeMarginalsLogSpace(stationaryLattice);
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.floatsequence.ForwardBackward.1
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                final int sequence = num.intValue();
                final float[][] forward = ForwardBackward.doPassLogSpace(wrappedLattice, false, z, sequence);
                final float[][] backward = ForwardBackward.doPassLogSpace(wrappedLattice, true, z, sequence);
                nodeMarginals.incrementExpectedCounts(forward, backward, sequence, z);
                if (!z) {
                    // edgeMarginals is non-null exactly when z is false.
                    edgeMarginals.incrementExpectedCounts(forward, backward, sequence);
                }
            }
        }, i);
        for (int sequence = 0; sequence < stationaryLattice.numSequences(); sequence++) {
            betterThreader.addFunctionArgument(Integer.valueOf(sequence));
        }
        betterThreader.run();
        return Pair.makePair(nodeMarginals, edgeMarginals);
    }

    /**
     * Runs log-space forward and backward passes over every sequence of a lattice and
     * collects node marginals and, unless {@code z} is set, non-stationary edge marginals.
     *
     * <p>One task per sequence index is handed to a {@code BetterThreader}.
     *
     * @param lattice                  lattice to process
     * @param stationaryStateProjector passed to the node-marginals accumulator
     * @param z                        forwarded to both passes (where it selects max instead
     *                                 of log-add) and to the node accumulator; when
     *                                 {@code true} no edge marginals are built and the
     *                                 second pair component is {@code null}
     * @param i                        forwarded to the {@code BetterThreader} constructor
     *                                 (presumably the thread count - confirm)
     * @return pair of (node marginals, edge marginals or {@code null})
     */
    public static Pair<NodeMarginals, NonStationaryEdgeMarginals> computeMarginalsLogSpace(final Lattice lattice, StationaryStateProjector stationaryStateProjector, final boolean z, int i) {
        final NodeMarginalsLogSpace nodeMarginals = new NodeMarginalsLogSpace(lattice, stationaryStateProjector);
        final NonStationaryEdgeMarginalsLogSpace edgeMarginals = z ? null : new NonStationaryEdgeMarginalsLogSpace(lattice);
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.floatsequence.ForwardBackward.2
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                final int sequence = num.intValue();
                final float[][] forward = ForwardBackward.doPassLogSpace(lattice, false, z, sequence);
                final float[][] backward = ForwardBackward.doPassLogSpace(lattice, true, z, sequence);
                nodeMarginals.incrementExpectedCounts(forward, backward, sequence, z);
                if (!z) {
                    // edgeMarginals is non-null exactly when z is false.
                    edgeMarginals.incrementExpectedCounts(forward, backward, sequence);
                }
            }
        }, i);
        for (int sequence = 0; sequence < lattice.numSequences(); sequence++) {
            betterThreader.addFunctionArgument(Integer.valueOf(sequence));
        }
        betterThreader.run();
        return Pair.makePair(nodeMarginals, edgeMarginals);
    }

    /**
     * Runs scaled forward and backward passes over every sequence of a stationary lattice
     * and collects node marginals and, unless {@code z} is set, edge marginals.
     *
     * <p>One task per sequence index is handed to a {@code BetterThreader}.
     *
     * @param stationaryLattice        lattice to process
     * @param stationaryStateProjector passed to the node-marginals accumulator
     * @param z                        forwarded to both passes (where it selects max instead
     *                                 of sum) and to the node accumulator; when {@code true}
     *                                 no edge marginals are built and the second pair
     *                                 component is {@code null}
     * @param i                        forwarded to the {@code BetterThreader} constructor
     *                                 (presumably the thread count - confirm)
     * @return pair of (node marginals, edge marginals or {@code null})
     */
    public static Pair<NodeMarginals, StationaryEdgeMarginals> computeMarginalsScaling(final StationaryLattice stationaryLattice, StationaryStateProjector stationaryStateProjector, final boolean z, int i) {
        // The wrapper only holds a reference to the lattice, so one instance serves all tasks.
        final Lattice wrappedLattice = new StationaryLatticeWrapper(stationaryLattice);
        final NodeMarginalsScaling nodeMarginals = new NodeMarginalsScaling(wrappedLattice, stationaryStateProjector);
        final StationaryEdgeMarginalsScaling edgeMarginals = z ? null : new StationaryEdgeMarginalsScaling(stationaryLattice);
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.floatsequence.ForwardBackward.3
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                final int sequence = num.intValue();
                // Each pass yields (table, per-position scale exponents).
                final Pair<float[][], float[]> forwardPass = ForwardBackward.doPassScaling(wrappedLattice, false, z, sequence);
                final float[][] forward = forwardPass.getFirst();
                final float[] forwardExponents = forwardPass.getSecond();
                final Pair<float[][], float[]> backwardPass = ForwardBackward.doPassScaling(wrappedLattice, true, z, sequence);
                final float[][] backward = backwardPass.getFirst();
                final float[] backwardExponents = backwardPass.getSecond();
                nodeMarginals.incrementExpectedCounts(forward, forwardExponents, backward, backwardExponents, sequence, z);
                if (!z) {
                    // edgeMarginals is non-null exactly when z is false.
                    edgeMarginals.incrementExpectedCounts(forward, forwardExponents, backward, backwardExponents, sequence);
                }
            }
        }, i);
        for (int sequence = 0; sequence < stationaryLattice.numSequences(); sequence++) {
            betterThreader.addFunctionArgument(Integer.valueOf(sequence));
        }
        betterThreader.run();
        return Pair.makePair(nodeMarginals, edgeMarginals);
    }

    /**
     * Runs scaled forward and backward passes over every sequence of a lattice and collects
     * node marginals and, unless {@code z} is set, non-stationary edge marginals.
     *
     * <p>One task per sequence index is handed to a {@code BetterThreader}.
     *
     * @param lattice                  lattice to process
     * @param stationaryStateProjector passed to the node-marginals accumulator
     * @param z                        forwarded to both passes (where it selects max instead
     *                                 of sum) and to the node accumulator; when {@code true}
     *                                 no edge marginals are built and the second pair
     *                                 component is {@code null}
     * @param i                        forwarded to the {@code BetterThreader} constructor
     *                                 (presumably the thread count - confirm)
     * @return pair of (node marginals, edge marginals or {@code null})
     */
    public static Pair<NodeMarginals, NonStationaryEdgeMarginals> computeMarginalsScaling(final Lattice lattice, StationaryStateProjector stationaryStateProjector, final boolean z, int i) {
        final NodeMarginalsScaling nodeMarginals = new NodeMarginalsScaling(lattice, stationaryStateProjector);
        final NonStationaryEdgeMarginalsScaling edgeMarginals = z ? null : new NonStationaryEdgeMarginalsScaling(lattice);
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.floatsequence.ForwardBackward.4
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                final int sequence = num.intValue();
                // Each pass yields (table, per-position scale exponents).
                final Pair<float[][], float[]> forwardPass = ForwardBackward.doPassScaling(lattice, false, z, sequence);
                final float[][] forward = forwardPass.getFirst();
                final float[] forwardExponents = forwardPass.getSecond();
                final Pair<float[][], float[]> backwardPass = ForwardBackward.doPassScaling(lattice, true, z, sequence);
                final float[][] backward = backwardPass.getFirst();
                final float[] backwardExponents = backwardPass.getSecond();
                nodeMarginals.incrementExpectedCounts(forward, forwardExponents, backward, backwardExponents, sequence, z);
                if (!z) {
                    // edgeMarginals is non-null exactly when z is false.
                    edgeMarginals.incrementExpectedCounts(forward, forwardExponents, backward, backwardExponents, sequence);
                }
            }
        }, i);
        for (int sequence = 0; sequence < lattice.numSequences(); sequence++) {
            betterThreader.addFunctionArgument(Integer.valueOf(sequence));
        }
        betterThreader.run();
        return Pair.makePair(nodeMarginals, edgeMarginals);
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [int[], int[][]] */
    public static int[][] computeViterbiPathsScaling(final Lattice lattice, int i) {
        final ?? r0 = new int[lattice.numSequences()];
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.floatsequence.ForwardBackward.5
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                r0[num.intValue()] = ForwardBackward.extractViterbiPath(Lattice.this, null, (float[][]) ForwardBackward.doPassScaling(Lattice.this, false, true, num.intValue()).getFirst(), true, num.intValue());
            }
        }, i);
        for (int i2 = 0; i2 < lattice.numSequences(); i2++) {
            betterThreader.addFunctionArgument(Integer.valueOf(i2));
        }
        betterThreader.run();
        return r0;
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [int[], int[][]] */
    public static int[][] computeViterbiPathsScaling(final Lattice lattice, final StationaryStateProjector stationaryStateProjector, int i) {
        final ?? r0 = new int[lattice.numSequences()];
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.floatsequence.ForwardBackward.6
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                r0[num.intValue()] = ForwardBackward.extractViterbiPath(Lattice.this, stationaryStateProjector, (float[][]) ForwardBackward.doPassScaling(Lattice.this, false, true, num.intValue()).getFirst(), true, num.intValue());
            }
        }, i);
        for (int i2 = 0; i2 < lattice.numSequences(); i2++) {
            betterThreader.addFunctionArgument(Integer.valueOf(i2));
        }
        betterThreader.run();
        return r0;
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [int[], int[][]] */
    public static int[][] computeViterbiPathsLogSpace(final Lattice lattice, int i) {
        final ?? r0 = new int[lattice.numSequences()];
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.floatsequence.ForwardBackward.7
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                r0[num.intValue()] = ForwardBackward.extractViterbiPath(Lattice.this, null, ForwardBackward.doPassLogSpace(Lattice.this, false, true, num.intValue()), false, num.intValue());
            }
        }, i);
        for (int i2 = 0; i2 < lattice.numSequences(); i2++) {
            betterThreader.addFunctionArgument(Integer.valueOf(i2));
        }
        betterThreader.run();
        return r0;
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [int[], int[][]] */
    public static int[][] computeViterbiPathsLogSpace(final Lattice lattice, final StationaryStateProjector stationaryStateProjector, int i) {
        final ?? r0 = new int[lattice.numSequences()];
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.floatsequence.ForwardBackward.8
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                r0[num.intValue()] = ForwardBackward.extractViterbiPath(Lattice.this, stationaryStateProjector, ForwardBackward.doPassLogSpace(Lattice.this, false, true, num.intValue()), false, num.intValue());
            }
        }, i);
        for (int i2 = 0; i2 < lattice.numSequences(); i2++) {
            betterThreader.addFunctionArgument(Integer.valueOf(i2));
        }
        betterThreader.run();
        return r0;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Traces the best state path of one sequence backwards through a forward max-pass table.
     *
     * <p>The last state is the argmax of the final row. Each earlier state is the allowed
     * predecessor of the already chosen next state that maximises the table score combined
     * with the edge weight: a product of potentials when {@code z} is {@code true}, a sum of
     * log potentials otherwise. A zero-length sequence yields an empty path instead of an
     * out-of-bounds access.
     *
     * @param lattice                  lattice the table was computed on
     * @param stationaryStateProjector if non-null, every state of the finished path is
     *                                 replaced by its projection
     * @param fArr                     forward table, indexed [position][state]
     * @param z                        {@code true} for scaled (multiplicative) scores,
     *                                 {@code false} for log-space (additive) scores
     * @param i                        sequence index
     * @return the decoded (and possibly projected) state at every position
     */
    public static int[] extractViterbiPath(Lattice lattice, StationaryStateProjector stationaryStateProjector, float[][] fArr, boolean z, int i) {
        final int length = lattice.sequenceLength(i);
        final int[] path = new int[length];
        if (length == 0) {
            // Nothing to decode; previously this case indexed position -1.
            return path;
        }
        path[length - 1] = a.argmax(fArr[length - 1]);
        for (int pos = length - 2; pos >= 0; pos--) {
            final int nextState = path[pos + 1];
            // Direction flag true: edges are looked up from the later position's state.
            final int[] predecessors = lattice.allowedEdges(i, pos + 1, nextState, true);
            final float[] weights = z
                    ? lattice.allowedEdgesPotentials(i, pos + 1, nextState, true)
                    : lattice.allowedEdgesLogPotentials(i, pos + 1, nextState, true);
            // Stays -1 when there is no predecessor whose score exceeds negative infinity.
            int bestState = -1;
            float bestScore = Float.NEGATIVE_INFINITY;
            for (int e = 0; e < predecessors.length; e++) {
                final int candidate = predecessors[e];
                final float score = z ? fArr[pos][candidate] * weights[e] : fArr[pos][candidate] + weights[e];
                if (score > bestScore) {
                    bestScore = score;
                    bestState = candidate;
                }
            }
            path[pos] = bestState;
        }
        if (stationaryStateProjector != null) {
            for (int pos = 0; pos < path.length; pos++) {
                path[pos] = stationaryStateProjector.project(i, pos, path[pos]);
            }
        }
        return path;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v2, types: [float[], float[][]] */
    public static float[][] doPassLogSpace(Lattice lattice, boolean z, boolean z2, int i) {
        ?? r0 = new float[lattice.sequenceLength(i)];
        for (int i2 = 0; i2 < lattice.sequenceLength(i); i2++) {
            r0[i2] = new float[lattice.numStates(i, i2)];
        }
        int[] enumerate = z ? a.enumerate(lattice.sequenceLength(i), 0) : a.enumerate(0, lattice.sequenceLength(i));
        for (int i3 = 0; i3 < enumerate.length; i3++) {
            int i4 = enumerate[i3];
            Arrays.fill(r0[i4], Float.NEGATIVE_INFINITY);
            if (i3 == 0) {
                int numStates = lattice.numStates(i, i4);
                for (int i5 = 0; i5 < numStates; i5++) {
                    r0[i4][i5] = lattice.nodeLogPotential(i, i4, i5);
                }
            } else {
                int i6 = enumerate[i3 - 1];
                int numStates2 = lattice.numStates(i, i6);
                for (int i7 = 0; i7 < numStates2; i7++) {
                    char c = r0[i6][i7];
                    int[] allowedEdges = lattice.allowedEdges(i, i6, i7, z);
                    float[] allowedEdgesLogPotentials = lattice.allowedEdgesLogPotentials(i, i6, i7, z);
                    float[] fArr = r0[i4];
                    for (int i8 = 0; i8 < allowedEdges.length; i8++) {
                        int i9 = allowedEdges[i8];
                        float f = allowedEdgesLogPotentials[i8];
                        if (z2) {
                            fArr[i9] = Math.max((float) fArr[i9], c + f);
                        } else {
                            fArr[i9] = m.logAdd((float) fArr[i9], c + f);
                        }
                    }
                }
                for (int i10 = 0; i10 < lattice.numStates(i, i4); i10++) {
                    float[] fArr2 = r0[i4];
                    int i11 = i10;
                    fArr2[i11] = fArr2[i11] + lattice.nodeLogPotential(i, i4, i10);
                }
            }
        }
        return r0;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Computes one dynamic-programming table for a single sequence in probability space,
     * rescaling each position by a power of {@code SCALE} to keep values in range.
     *
     * <p>Positions are visited in the order produced by {@code a.enumerate}. After a
     * position's row is filled, its largest entry is brought to at most {@code SCALE} and,
     * if positive, to at least {@code 1 / SCALE}, by dividing the whole row by the matching
     * power of {@code SCALE}. The second result holds, per position, the running sum of the
     * exponents applied so far along the visiting order.
     *
     * @param lattice lattice to run on
     * @param z       direction flag: selects {@code a.enumerate(length, 0)} instead of
     *                {@code a.enumerate(0, length)} (presumably descending positions -
     *                confirm) and is passed on to the lattice's edge queries
     * @param z2      {@code true} combines incoming scores with {@code Math.max},
     *                {@code false} sums them
     * @param i       sequence index
     * @return pair of (table indexed [position][state], cumulative scale exponents per position)
     */
    public static Pair<float[][], float[]> doPassScaling(Lattice lattice, boolean z, boolean z2, int i) {
        final int length = lattice.sequenceLength(i);
        final float[] scaleExponents = new float[length];
        final float[][] table = new float[length][];
        for (int pos = 0; pos < length; pos++) {
            table[pos] = new float[lattice.numStates(i, pos)];
        }
        final int[] order = z ? a.enumerate(length, 0) : a.enumerate(0, length);
        for (int step = 0; step < order.length; step++) {
            final int pos = order[step];
            final float[] current = table[pos];
            Arrays.fill(current, 0.0f);
            float rowMax = Float.NEGATIVE_INFINITY;
            if (step == 0) {
                final int numStates = lattice.numStates(i, pos);
                for (int state = 0; state < numStates; state++) {
                    final float potential = lattice.nodePotential(i, pos, state);
                    current[state] = potential;
                    if (potential > rowMax) {
                        rowMax = potential;
                    }
                }
            } else {
                final int prevPos = order[step - 1];
                final float[] previous = table[prevPos];
                final int prevStates = lattice.numStates(i, prevPos);
                for (int from = 0; from < prevStates; from++) {
                    final float fromScore = previous[from];
                    final int[] targets = lattice.allowedEdges(i, prevPos, from, z);
                    final float[] edgePotentials = lattice.allowedEdgesPotentials(i, prevPos, from, z);
                    for (int e = 0; e < targets.length; e++) {
                        final int to = targets[e];
                        final float candidate = fromScore * edgePotentials[e];
                        if (z2) {
                            current[to] = Math.max(current[to], candidate);
                        } else {
                            current[to] = current[to] + candidate;
                        }
                    }
                }
                for (int state = 0; state < lattice.numStates(i, pos); state++) {
                    current[state] = current[state] * lattice.nodePotential(i, pos, state);
                    final float value = current[state];
                    if (value > rowMax) {
                        rowMax = value;
                    }
                }
            }

            // Find the power of SCALE that brings the row maximum back into range.
            int exponent = 0;
            float divisor = 1.0f;
            // An infinite maximum can never be scaled down; without the extra check this
            // loop would not terminate.
            while (rowMax > SCALE && rowMax != Float.POSITIVE_INFINITY) {
                rowMax /= SCALE;
                divisor *= SCALE;
                exponent++;
            }
            // NOTE(review): Cropper.VERT_GROW_RATIO looks like a decompiler-substituted
            // literal (presumably 0) - confirm.
            while (rowMax > Cropper.VERT_GROW_RATIO && rowMax < 1.0d / SCALE) {
                rowMax *= SCALE;
                divisor /= SCALE;
                exponent--;
            }
            if (exponent != 0) {
                for (int state = 0; state < lattice.numStates(i, pos); state++) {
                    current[state] = current[state] / divisor;
                }
            }

            // Exponents accumulate along the visiting order.
            if (step == 0) {
                scaleExponents[pos] = exponent;
            } else {
                scaleExponents[pos] = scaleExponents[order[step - 1]] + exponent;
            }
        }
        return Pair.makePair(table, scaleExponents);
    }
}
