package tberg.murphy.sequence;

import edu.berkeley.cs.nlp.ocular.preprocessing.Cropper;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import tberg.murphy.arrays.a;
import tberg.murphy.math.m;
import tberg.murphy.threading.BetterThreader;
import tberg.murphy.tuple.Pair;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward.class */
public class ForwardBackward {
    /** Rescaling constant (e^100) used by the probability-space ("scaling") forward-backward variants. */
    public static final double SCALE = Math.exp(100.0d);
    /** Reciprocal of {@link #SCALE}. */
    public static final double INVSCALE = 1.0d / SCALE;
    /**
     * Natural log of {@link #SCALE} (approximately 100). NodeMarginalsScaling multiplies a per-position
     * scale value by this to recover a natural-log quantity.
     */
    public static final double LOG_SCALE = Math.log(SCALE);

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$Lattice.class */
    /**
     * Read-only description of a collection of sequence lattices. For each sequence {@code d},
     * position {@code t} and state {@code s} it exposes a node potential and the edges allowed out of
     * that node together with their potentials. Potentials come in log-space and probability-space
     * variants.
     */
    public interface Lattice {

        /** Returns how many sequences the lattice contains. */
        int numSequences();

        /** Returns the number of positions in sequence {@code d}. */
        int sequenceLength(int d);

        /** Returns the number of states available at position {@code t} of sequence {@code d}. */
        int numStates(int d, int t);

        /**
         * Returns the states connected to state {@code s} at position {@code t} of sequence {@code d}.
         * With {@code backward == false} these are forward edges whose targets are states at position
         * {@code t + 1}; {@code true} presumably selects the opposite direction.
         */
        int[] allowedEdges(int d, int t, int s, boolean backward);

        /** Returns the log potential of state {@code s} at position {@code t} of sequence {@code d}. */
        double nodeLogPotential(int d, int t, int s);

        /** Returns log potentials for the edges of {@link #allowedEdges}, in the same order. */
        double[] allowedEdgesLogPotentials(int d, int t, int s, boolean backward);

        /** Probability-space counterpart of {@link #nodeLogPotential} (presumably its exponential). */
        double nodePotential(int d, int t, int s);

        /** Probability-space counterpart of {@link #allowedEdgesLogPotentials} (presumably same ordering). */
        double[] allowedEdgesPotentials(int d, int t, int s, boolean backward);
    }

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$NodeMarginals.class */
    /**
     * Read-only view of per-position (node) marginals for a collection of sequences.
     */
    public interface NodeMarginals {

        /** Returns how many sequences are covered. */
        int numSequences();

        /** Returns the number of positions in sequence {@code d}. */
        int sequenceLength(int d);

        /**
         * Returns the size of the state range for sequence {@code d}. In the implementations alongside
         * this interface it is the length of every {@link #nodeCondProbs} row of that sequence.
         */
        int numStates(int d);

        /** Returns the marginals at position {@code t} of sequence {@code d}, indexed by state. */
        double[] nodeCondProbs(int d, int t);

        /** Returns the log marginal probability (log partition value) of sequence {@code d}. */
        double sequenceLogMarginalProb(int d);

        /** Returns the combined log marginal probability over all sequences. */
        double logMarginalProb();

        /** Returns an iterator over ((sequence, position), state) to marginal entries. */
        Iterator<Pair<Pair<Pair<Integer, Integer>, Integer>, Double>> getNodeMarginalsIterator();

        /** Returns an estimate of the memory held by the stored marginals. */
        double estimateMemoryUsage();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$NodeMarginalsLogSpace.class */
    /**
     * Node marginals computed from a log-space forward-backward pass.
     *
     * <p>Lattice states are mapped through a {@link StationaryStateProjector} onto the index range
     * {@code [0, rangeSize(d))}. Marginals of states that project to the same index are combined:
     * summed in normal mode, or maxed in max mode.
     */
    public static class NodeMarginalsLogSpace implements NodeMarginals {
        Lattice lattice;
        // nodeCondProbs[d][t][p]: marginal of projected state p at position t of sequence d (null until computed).
        double[][][] nodeCondProbs;
        // Per-sequence log partition value (log-sum, or max in max mode, of betas[0][s] over start states).
        double[] sequenceLogMarginalProbs;
        StationaryStateProjector stateProjector;

        /**
         * Iterates over every ((sequence, position), projectedState) to marginal entry in row-major order.
         * Sequences or positions that have no entries are skipped.
         */
        private class NodeMarginalsIterator implements Iterator<Pair<Pair<Pair<Integer, Integer>, Integer>, Double>> {
            // Cursor at the next entry to return; d == numSequences() means the iterator is exhausted.
            int d;
            int t;
            int s;

            private NodeMarginalsIterator() {
                this.d = 0;
                this.t = 0;
                this.s = 0;
                skipToValid();
            }

            /** Moves the cursor forward until it points at an existing entry or past the last sequence. */
            private void skipToValid() {
                while (this.d < numSequences()) {
                    if (this.t >= sequenceLength(this.d)) {
                        this.d++;
                        this.t = 0;
                        this.s = 0;
                    } else if (this.s >= numStates(this.d)) {
                        this.t++;
                        this.s = 0;
                    } else {
                        return;
                    }
                }
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.d < numSequences();
            }

            @Override // java.util.Iterator
            public Pair<Pair<Pair<Integer, Integer>, Integer>, Double> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                double marginal = NodeMarginalsLogSpace.this.nodeCondProbs(this.d, this.t)[this.s];
                Pair<Pair<Pair<Integer, Integer>, Integer>, Double> result = Pair.makePair(Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.t)), Integer.valueOf(this.s)), Double.valueOf(marginal));
                this.s++;
                skipToValid();
                return result;
            }

            /** Removal is not supported; calls are silently ignored (unchanged behavior). */
            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /**
         * @param lattice the lattice the marginals are computed over
         * @param stationaryStateProjector maps lattice states onto the output state range
         */
        public NodeMarginalsLogSpace(Lattice lattice, StationaryStateProjector stationaryStateProjector) {
            this.lattice = lattice;
            this.stateProjector = stationaryStateProjector;
            this.sequenceLogMarginalProbs = new double[lattice.numSequences()];
            // Rows per sequence are allocated lazily in incrementExpectedCounts.
            this.nodeCondProbs = new double[lattice.numSequences()][][];
        }

        /**
         * Computes and stores the node marginals of sequence {@code d} from its forward/backward log tables.
         * Despite the name, it overwrites (rather than adds to) any previous results for {@code d}.
         *
         * @param alphas forward log scores, {@code alphas[t][s]}
         * @param betas backward log scores, {@code betas[t][s]}
         * @param d sequence index
         * @param useMax if true, combine with max instead of log-sum / sum (max-product mode)
         */
        public void incrementExpectedCounts(double[][] alphas, double[][] betas, int d, boolean useMax) {
            double logZ = Double.NEGATIVE_INFINITY;
            for (int s = 0; s < this.lattice.numStates(d, 0); s++) {
                logZ = useMax ? Math.max(logZ, betas[0][s]) : m.logAdd(logZ, betas[0][s]);
            }
            this.sequenceLogMarginalProbs[d] = logZ;

            int length = this.lattice.sequenceLength(d);
            double[][] seqProbs = new double[length][];
            for (int t = 0; t < length; t++) {
                double[] probs = new double[this.stateProjector.rangeSize(d)];
                if (useMax) {
                    Arrays.fill(probs, Double.NEGATIVE_INFINITY);
                }
                int numStates = this.lattice.numStates(d, t);
                for (int s = 0; s < numStates; s++) {
                    int p = this.stateProjector.project(d, t, s);
                    if (alphas[t][s] == Double.NEGATIVE_INFINITY || betas[t][s] == Double.NEGATIVE_INFINITY) {
                        continue;
                    }
                    // The node log potential is subtracted once, which assumes alpha and beta both include it.
                    double prob = Math.exp(alphas[t][s] - this.lattice.nodeLogPotential(d, t, s) + betas[t][s] - logZ);
                    probs[p] = useMax ? Math.max(probs[p], prob) : probs[p] + prob;
                }
                seqProbs[t] = probs;
            }
            this.nodeCondProbs[d] = seqProbs;
        }

        /** Returns the approximate size of the stored marginals in gigabytes (8 bytes per double, 1e9 bytes/GB). */
        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public double estimateMemoryUsage() {
            double entries = 0.0d;
            for (int d = 0; d < numSequences(); d++) {
                if (this.nodeCondProbs[d] != null) {
                    for (int t = 0; t < sequenceLength(d); t++) {
                        entries += this.nodeCondProbs[d][t].length;
                    }
                }
            }
            return (8.0d * entries) / 1.0E9d;
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public double sequenceLogMarginalProb(int i) {
            return this.sequenceLogMarginalProbs[i];
        }

        /** Returns the sum of the per-sequence log marginals. */
        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public double logMarginalProb() {
            return a.sum(this.sequenceLogMarginalProbs);
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public double[] nodeCondProbs(int i, int i2) {
            return this.nodeCondProbs[i][i2];
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public int numSequences() {
            return this.lattice.numSequences();
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public int sequenceLength(int i) {
            return this.lattice.sequenceLength(i);
        }

        /** Returns the projected state range size, i.e. the length of each marginal row. */
        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public int numStates(int i) {
            return this.stateProjector.rangeSize(i);
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public Iterator<Pair<Pair<Pair<Integer, Integer>, Integer>, Double>> getNodeMarginalsIterator() {
            return new NodeMarginalsIterator();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$NodeMarginalsScaling.class */
    /**
     * Node marginals computed from a probability-space forward-backward pass that uses periodic
     * rescaling by {@link ForwardBackward#SCALE} to avoid underflow.
     *
     * <p>Lattice states are mapped through a {@link StationaryStateProjector} onto
     * {@code [0, rangeSize(d))}. Marginals of states that share an index are summed, or maxed in max mode.
     */
    public static class NodeMarginalsScaling implements NodeMarginals {
        Lattice lattice;
        // nodeCondProbs[d][t][p]: marginal of projected state p at position t of sequence d (null until computed).
        double[][][] nodeCondProbs;
        // Per-sequence log partition value, converted back to natural-log units.
        double[] sequenceLogMarginalProbs;
        StationaryStateProjector stateProjector;

        /**
         * Iterates over every ((sequence, position), projectedState) to marginal entry in row-major order.
         * Sequences or positions that have no entries are skipped.
         */
        private class NodeMarginalsIterator implements Iterator<Pair<Pair<Pair<Integer, Integer>, Integer>, Double>> {
            // Cursor at the next entry to return; d == numSequences() means the iterator is exhausted.
            int d;
            int t;
            int s;

            private NodeMarginalsIterator() {
                this.d = 0;
                this.t = 0;
                this.s = 0;
                skipToValid();
            }

            /** Moves the cursor forward until it points at an existing entry or past the last sequence. */
            private void skipToValid() {
                while (this.d < numSequences()) {
                    if (this.t >= sequenceLength(this.d)) {
                        this.d++;
                        this.t = 0;
                        this.s = 0;
                    } else if (this.s >= numStates(this.d)) {
                        this.t++;
                        this.s = 0;
                    } else {
                        return;
                    }
                }
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.d < numSequences();
            }

            @Override // java.util.Iterator
            public Pair<Pair<Pair<Integer, Integer>, Integer>, Double> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                double marginal = NodeMarginalsScaling.this.nodeCondProbs(this.d, this.t)[this.s];
                Pair<Pair<Pair<Integer, Integer>, Integer>, Double> result = Pair.makePair(Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.t)), Integer.valueOf(this.s)), Double.valueOf(marginal));
                this.s++;
                skipToValid();
                return result;
            }

            /** Removal is not supported; calls are silently ignored (unchanged behavior). */
            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /**
         * @param lattice the lattice the marginals are computed over
         * @param stationaryStateProjector maps lattice states onto the output state range
         */
        public NodeMarginalsScaling(Lattice lattice, StationaryStateProjector stationaryStateProjector) {
            this.lattice = lattice;
            this.stateProjector = stationaryStateProjector;
            this.sequenceLogMarginalProbs = new double[lattice.numSequences()];
            // Rows per sequence are allocated lazily in incrementExpectedCounts.
            this.nodeCondProbs = new double[lattice.numSequences()][][];
        }

        /**
         * Computes and stores the node marginals of sequence {@code d} from scaled forward/backward tables.
         * Despite the name, it overwrites (rather than adds to) any previous results for {@code d}.
         *
         * @param alphas scaled forward scores, {@code alphas[t][s]}
         * @param alphaLogScales per-position scale values for {@code alphas}; multiplied by
         *     {@link ForwardBackward#LOG_SCALE} they give natural-log units
         * @param betas scaled backward scores, {@code betas[t][s]}
         * @param betaLogScales per-position scale values for {@code betas}, same units as {@code alphaLogScales}
         * @param d sequence index
         * @param useMax if true, combine with max instead of sum (max-product mode)
         */
        public void incrementExpectedCounts(double[][] alphas, double[] alphaLogScales, double[][] betas, double[] betaLogScales, int d, boolean useMax) {
            // 0.0 replaces the decompiler's Cropper.VERT_GROW_RATIO, a constant inlined only because its value is 0.0.
            double partition = useMax ? Double.NEGATIVE_INFINITY : 0.0d;
            double startBetaScale = betaLogScales[0];
            for (int s = 0; s < this.lattice.numStates(d, 0); s++) {
                // Start states with zero node potential are excluded from the partition value.
                if (this.lattice.nodePotential(d, 0, s) > 0.0d) {
                    partition = useMax ? Math.max(partition, betas[0][s]) : partition + betas[0][s];
                }
            }
            this.sequenceLogMarginalProbs[d] = (startBetaScale * ForwardBackward.LOG_SCALE) + Math.log(partition);

            int length = this.lattice.sequenceLength(d);
            double[][] seqProbs = new double[length][];
            for (int t = 0; t < length; t++) {
                double[] probs = new double[this.stateProjector.rangeSize(d)];
                if (useMax) {
                    Arrays.fill(probs, Double.NEGATIVE_INFINITY);
                }
                // Depends only on t: brings alpha*beta at t back to the scale of the start-position betas.
                double scaleFactor = ForwardBackward.getScaleFactor((alphaLogScales[t] + betaLogScales[t]) - startBetaScale);
                int numStates = this.lattice.numStates(d, t);
                for (int s = 0; s < numStates; s++) {
                    int p = this.stateProjector.project(d, t, s);
                    if (alphas[t][s] == 0.0d || betas[t][s] == 0.0d) {
                        continue;
                    }
                    // Dividing by the node potential assumes both alpha and beta include it.
                    double prob = (alphas[t][s] / this.lattice.nodePotential(d, t, s)) * (betas[t][s] / partition) * scaleFactor;
                    probs[p] = useMax ? Math.max(probs[p], prob) : probs[p] + prob;
                }
                seqProbs[t] = probs;
            }
            this.nodeCondProbs[d] = seqProbs;
        }

        /** Returns the approximate size of the stored marginals in gigabytes (8 bytes per double, 1e9 bytes/GB). */
        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public double estimateMemoryUsage() {
            double entries = 0.0d;
            for (int d = 0; d < numSequences(); d++) {
                if (this.nodeCondProbs[d] != null) {
                    for (int t = 0; t < sequenceLength(d); t++) {
                        entries += this.nodeCondProbs[d][t].length;
                    }
                }
            }
            return (8.0d * entries) / 1.0E9d;
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public double sequenceLogMarginalProb(int i) {
            return this.sequenceLogMarginalProbs[i];
        }

        /** Returns the sum of the per-sequence log marginals. */
        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public double logMarginalProb() {
            return a.sum(this.sequenceLogMarginalProbs);
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public double[] nodeCondProbs(int i, int i2) {
            return this.nodeCondProbs[i][i2];
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public int numSequences() {
            return this.lattice.numSequences();
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public int sequenceLength(int i) {
            return this.lattice.sequenceLength(i);
        }

        /** Returns the projected state range size, i.e. the length of each marginal row. */
        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public int numStates(int i) {
            return this.stateProjector.rangeSize(i);
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NodeMarginals
        public Iterator<Pair<Pair<Pair<Integer, Integer>, Integer>, Double>> getNodeMarginalsIterator() {
            return new NodeMarginalsIterator();
        }
    }

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$NonStationaryEdgeMarginals.class */
    /**
     * Read-only view of start-node, end-node and forward-edge marginals for a collection of sequences
     * whose state spaces may vary by position.
     */
    public interface NonStationaryEdgeMarginals {

        /** Returns how many sequences are covered. */
        int numSequences();

        /** Returns the number of positions in sequence {@code d}. */
        int sequenceLength(int d);

        /** Returns the number of states at position {@code t} of sequence {@code d}. */
        int numStates(int d, int t);

        /** Returns marginals of the states at the first position of sequence {@code d}, indexed by state. */
        double[] startNodeCondProbs(int d);

        /** Returns marginals of the states at the last position of sequence {@code d}, indexed by state. */
        double[] endNodeCondProbs(int d);

        /** Returns the target states (at position {@code t + 1}) of the forward edges out of ({@code d}, {@code t}, {@code s}). */
        int[] allowedForwardEdges(int d, int t, int s);

        /** Returns expected counts of the edges from {@link #allowedForwardEdges}, in the same order. */
        double[] allowedForwardEdgesExpectedCounts(int d, int t, int s);

        /** Returns the log marginal probability (log partition value) of sequence {@code d}. */
        double sequenceLogMarginalProb(int d);

        /** Returns the combined log marginal probability over all sequences. */
        double logMarginalProb();

        /** Returns an iterator over ((sequence, position), (sourceState, targetState)) to expected-count entries. */
        Iterator<Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Double>> getEdgeMarginalsIterator();

        /** Returns an iterator over (sequence, state) to start-position marginal entries. */
        Iterator<Pair<Pair<Integer, Integer>, Double>> getStartMarginalsIterator();

        /** Returns an iterator over (sequence, state) to end-position marginal entries. */
        Iterator<Pair<Pair<Integer, Integer>, Double>> getEndMarginalsIterator();

        /** Returns an estimate of the memory held by the stored tables. */
        double estimateMemoryUsage();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$NonStationaryEdgeMarginalsLogSpace.class */
    /**
     * Start-node, end-node and forward-edge marginals from a log-space forward-backward pass.
     *
     * <p>{@link #incrementExpectedCounts} keeps each sequence's alpha/beta log tables. Edge expected
     * counts are recomputed from them on demand by {@link #allowedForwardEdgesExpectedCounts}.
     */
    public static class NonStationaryEdgeMarginalsLogSpace implements NonStationaryEdgeMarginals {
        Lattice lattice;
        // Per-sequence log partition value: log-sum over start states of betas[0][s].
        double[] sequenceLogMarginalProbs;
        // startNodeCondProbs[d][s]: marginal of state s at the first position of sequence d.
        double[][] startNodeCondProbs;
        // endNodeCondProbs[d][s]: marginal of state s at the last position of sequence d.
        double[][] endNodeCondProbs;
        // Forward/backward log tables per sequence (null until computed), kept for lazy edge counts.
        double[][][] allAlphas;
        double[][][] allBetas;

        /**
         * Iterates over every allowed forward edge as ((sequence, position), (sourceState, targetState))
         * to expected count. Sequences with fewer than two positions have no edges and are skipped.
         */
        private class EdgeMarginalsIterator implements Iterator<Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Double>> {
            // Cursor at the next edge to return; d == numSequences() means the iterator is exhausted.
            int d = 0;
            int t = 0;
            int s1 = 0;
            int s2i = 0;
            // Forward edges out of (d, t, s1) and their expected counts; targets == null means "not loaded".
            int[] targets;
            double[] edgeCondProbs;

            public EdgeMarginalsIterator() {
                settle();
            }

            /** Advances the cursor until it points at an existing edge, or past the last sequence. */
            private void settle() {
                while (this.d < numSequences()) {
                    if (this.t >= sequenceLength(this.d) - 1) {
                        // No transitions left in this sequence (also covers sequences of length 0 or 1).
                        this.d++;
                        this.t = 0;
                        this.s1 = 0;
                        this.s2i = 0;
                        this.targets = null;
                    } else if (this.s1 >= numStates(this.d, this.t)) {
                        this.t++;
                        this.s1 = 0;
                        this.s2i = 0;
                        this.targets = null;
                    } else {
                        if (this.targets == null) {
                            this.targets = allowedForwardEdges(this.d, this.t, this.s1);
                            this.edgeCondProbs = allowedForwardEdgesExpectedCounts(this.d, this.t, this.s1);
                        }
                        if (this.s2i < this.targets.length) {
                            return;
                        }
                        this.s1++;
                        this.s2i = 0;
                        this.targets = null;
                    }
                }
                this.targets = null;
                this.edgeCondProbs = null;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.d < numSequences();
            }

            @Override // java.util.Iterator
            public Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Double> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Double> result = Pair.makePair(Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.t)), Pair.makePair(Integer.valueOf(this.s1), Integer.valueOf(this.targets[this.s2i]))), Double.valueOf(this.edgeCondProbs[this.s2i]));
                this.s2i++;
                settle();
                return result;
            }

            /** Removal is not supported; calls are silently ignored (unchanged behavior). */
            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /**
         * Iterates over (sequence, state) to marginal for either the first or the last position of every
         * sequence. Sequences of length 0 are skipped.
         */
        private class BoundaryMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Double>> {
            // true: last position (end marginals); false: first position (start marginals).
            private final boolean atEnd;
            // Cursor at the next entry to return; d == numSequences() means the iterator is exhausted.
            private int d = 0;
            private int s = 0;

            BoundaryMarginalsIterator(boolean atEnd) {
                this.atEnd = atEnd;
                settle();
            }

            /** Returns the number of states at the boundary position of sequence {@code seq}. */
            private int boundaryStates(int seq) {
                int length = sequenceLength(seq);
                if (length == 0) {
                    return 0;
                }
                return numStates(seq, this.atEnd ? length - 1 : 0);
            }

            /** Advances the cursor past sequences whose boundary entries are exhausted or absent. */
            private void settle() {
                while (this.d < numSequences() && this.s >= boundaryStates(this.d)) {
                    this.d++;
                    this.s = 0;
                }
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.d < numSequences();
            }

            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Double> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                double[] probs = this.atEnd ? endNodeCondProbs(this.d) : startNodeCondProbs(this.d);
                Pair<Pair<Integer, Integer>, Double> result = Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)), Double.valueOf(probs[this.s]));
                this.s++;
                settle();
                return result;
            }

            /** Removal is not supported; calls are silently ignored (unchanged behavior). */
            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /** @param lattice the lattice the marginals are computed over */
        public NonStationaryEdgeMarginalsLogSpace(Lattice lattice) {
            this.lattice = lattice;
            int n = lattice.numSequences();
            this.sequenceLogMarginalProbs = new double[n];
            // Per-sequence rows are filled in by incrementExpectedCounts.
            this.startNodeCondProbs = new double[n][];
            this.endNodeCondProbs = new double[n][];
            this.allAlphas = new double[n][][];
            this.allBetas = new double[n][][];
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public int[] allowedForwardEdges(int i, int i2, int i3) {
            return this.lattice.allowedEdges(i, i2, i3, false);
        }

        /**
         * Stores the alpha/beta log tables of sequence {@code d} and computes its log partition value and
         * start/end node marginals. Despite the name, it overwrites previous results for {@code d}. The
         * tables are kept by reference, not copied.
         *
         * @param alphas forward log scores, {@code alphas[t][s]}
         * @param betas backward log scores, {@code betas[t][s]}
         * @param d sequence index
         */
        public void incrementExpectedCounts(double[][] alphas, double[][] betas, int d) {
            this.allAlphas[d] = alphas;
            this.allBetas[d] = betas;

            int numStartStates = this.lattice.numStates(d, 0);
            double logZ = Double.NEGATIVE_INFINITY;
            for (int s = 0; s < numStartStates; s++) {
                logZ = m.logAdd(logZ, betas[0][s]);
            }
            this.sequenceLogMarginalProbs[d] = logZ;

            double[] start = new double[numStartStates];
            for (int s = 0; s < numStartStates; s++) {
                start[s] = Math.exp(betas[0][s] - logZ);
            }
            this.startNodeCondProbs[d] = start;

            int last = this.lattice.sequenceLength(d) - 1;
            int numEndStates = this.lattice.numStates(d, last);
            double[] end = new double[numEndStates];
            for (int s = 0; s < numEndStates; s++) {
                end[s] = Math.exp(alphas[last][s] - logZ);
            }
            this.endNodeCondProbs[d] = end;
        }

        /** Returns the approximate size of the stored alpha/beta tables in gigabytes (8 bytes per double). */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double estimateMemoryUsage() {
            double entries = countEntries(this.allAlphas) + countEntries(this.allBetas);
            return (8.0d * entries) / 1.0E9d;
        }

        /** Counts the doubles in the first sequenceLength(d) rows of each non-null per-sequence table. */
        private double countEntries(double[][][] perSequence) {
            double entries = 0.0d;
            for (int d = 0; d < numSequences(); d++) {
                if (perSequence[d] != null) {
                    for (int t = 0; t < sequenceLength(d); t++) {
                        entries += perSequence[d][t].length;
                    }
                }
            }
            return entries;
        }

        /**
         * Returns expected counts of the forward edges out of state {@code s} at position {@code t} of
         * sequence {@code d}, aligned with {@link #allowedForwardEdges}. Each count is
         * exp(alpha[t][s] + edgeLogPotential + beta[t + 1][target] - logZ).
         */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double[] allowedForwardEdgesExpectedCounts(int d, int t, int s) {
            int[] targets = this.lattice.allowedEdges(d, t, s, false);
            double[] edgeLogPotentials = this.lattice.allowedEdgesLogPotentials(d, t, s, false);
            double alpha = this.allAlphas[d][t][s];
            double[] nextBetas = this.allBetas[d][t + 1];
            double logZ = this.sequenceLogMarginalProbs[d];
            double[] counts = new double[targets.length];
            for (int k = 0; k < targets.length; k++) {
                counts[k] = Math.exp(alpha + edgeLogPotentials[k] + nextBetas[targets[k]] - logZ);
            }
            return counts;
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double[] startNodeCondProbs(int i) {
            return this.startNodeCondProbs[i];
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double[] endNodeCondProbs(int i) {
            return this.endNodeCondProbs[i];
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double sequenceLogMarginalProb(int i) {
            return this.sequenceLogMarginalProbs[i];
        }

        /** Returns the sum of the per-sequence log marginals. */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double logMarginalProb() {
            return a.sum(this.sequenceLogMarginalProbs);
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public int numSequences() {
            return this.lattice.numSequences();
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public int sequenceLength(int i) {
            return this.lattice.sequenceLength(i);
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public int numStates(int i, int i2) {
            return this.lattice.numStates(i, i2);
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Double>> getEdgeMarginalsIterator() {
            return new EdgeMarginalsIterator();
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Double>> getStartMarginalsIterator() {
            return new BoundaryMarginalsIterator(false);
        }

        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Double>> getEndMarginalsIterator() {
            return new BoundaryMarginalsIterator(true);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$NonStationaryEdgeMarginalsScaling.class */
    public static class NonStationaryEdgeMarginalsScaling implements NonStationaryEdgeMarginals {
        // Lattice the marginals are computed over.
        Lattice lattice;
        // Per-sequence log marginal probability, presumably in natural-log units as in
        // NonStationaryEdgeMarginalsLogSpace — TODO confirm against this class's incrementExpectedCounts.
        double[] sequenceLogMarginalProbs;
        // NOTE(review): by name, a scaled (probability-space) per-sequence marginal together with its
        // log-scale exponent; the pairing is an assumption — confirm where these are filled in.
        double[] sequenceMarginalProbs;
        double[] sequenceMarginalProbLogScales;
        // Per-sequence node marginals at the first / last position, indexed by state (by name).
        double[][] startNodeCondProbs;
        double[][] endNodeCondProbs;
        // Scaled forward/backward tables per sequence, presumably kept for lazy edge-count computation.
        double[][][] allAlphas;
        double[][][] allBetas;
        // Per-sequence, per-position scale values for the tables above (by name; units not visible here).
        double[][] allAlphaLogScales;
        double[][] allBetaLogScales;

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$NonStationaryEdgeMarginalsScaling$EdgeMarginalsIterator.class */
        private class EdgeMarginalsIterator implements Iterator<Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Double>> {
            int d = 0;
            int t = 0;
            int s1 = 0;
            int s2i = 0;
            double[] edgeCondProbs;

            public EdgeMarginalsIterator() {
                this.edgeCondProbs = null;
                this.edgeCondProbs = NonStationaryEdgeMarginalsScaling.this.allowedForwardEdgesExpectedCounts(this.d, this.t, this.s1);
                while (this.d < NonStationaryEdgeMarginalsScaling.this.numSequences() && this.edgeCondProbs.length == 0) {
                    advance();
                }
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.d < NonStationaryEdgeMarginalsScaling.this.numSequences();
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Double> next() {
                Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Double> makePair = Pair.makePair(Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.t)), Pair.makePair(Integer.valueOf(this.s1), Integer.valueOf(NonStationaryEdgeMarginalsScaling.this.allowedForwardEdges(this.d, this.t, this.s1)[this.s2i]))), Double.valueOf(this.edgeCondProbs[this.s2i]));
                advance();
                while (this.d < NonStationaryEdgeMarginalsScaling.this.numSequences() && this.edgeCondProbs.length == 0) {
                    advance();
                }
                return makePair;
            }

            /**
             * Steps the cursor one position in (d, t, s1, s2i) lexicographic order. When the step
             * lands inside a valid sequence, it reloads {@code edgeCondProbs} for the new (d, t, s1).
             */
            private void advance() {
                final NonStationaryEdgeMarginalsScaling outer = NonStationaryEdgeMarginalsScaling.this;
                // Innermost: move to the next allowed successor of (t, s1) if there is one.
                if (this.s2i + 1 < outer.allowedForwardEdges(this.d, this.t, this.s1).length) {
                    this.s2i++;
                    return;
                }
                this.s2i = 0;
                if (this.s1 + 1 < outer.numStates(this.d, this.t)) {
                    // Next source state at the same position.
                    this.s1++;
                } else {
                    this.s1 = 0;
                    if (this.t + 2 < outer.sequenceLength(this.d)) {
                        // Edges join t and t + 1, so t stops at sequenceLength - 2.
                        this.t++;
                    } else {
                        this.t = 0;
                        this.d++;
                    }
                }
                if (this.d < outer.numSequences()) {
                    this.edgeCondProbs = outer.allowedForwardEdgesExpectedCounts(this.d, this.t, this.s1);
                }
            }

            /** Removal is not supported. The call has no effect and throws nothing. */
            @Override // java.util.Iterator
            public void remove() {
                // intentionally empty
            }
        }

        /**
         * Iterates over {@code ((d, s), p)} for every state {@code s} at the last position of every
         * sequence {@code d}, where {@code p} comes from {@code endNodeCondProbs(d)}. Sequences are
         * visited in index order, and states in increasing order within each sequence.
         */
        private class EndMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Double>> {
            int d;
            int s;
            /** Row of sequence {@code d}; {@code null} until {@link #next()} is first called. */
            double[] endCondProbs;

            private EndMarginalsIterator() {
                this.endCondProbs = null;
            }

            /** Index of the last state at the final position of sequence {@code seq}. */
            private int lastState(int seq) {
                return numStates(seq, sequenceLength(seq) - 1) - 1;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                if (this.endCondProbs == null) {
                    // Nothing returned yet: a first element exists iff there is a sequence. Treating
                    // this like "positioned at (0, 0)" would end a one-state lattice before it starts.
                    return numSequences() > 0;
                }
                return !(this.d == numSequences() - 1 && this.s == lastState(this.d));
            }

            /**
             * Returns the next {@code ((d, s), p)} entry.
             *
             * @throws java.util.NoSuchElementException if {@link #hasNext()} is false
             */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Double> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                if (this.endCondProbs == null) {
                    this.d = 0;
                    this.s = 0;
                    this.endCondProbs = endNodeCondProbs(this.d);
                } else if (this.s == lastState(this.d)) {
                    this.s = 0;
                    this.d++;
                    this.endCondProbs = endNodeCondProbs(this.d);
                } else {
                    this.s++;
                }
                return Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)), Double.valueOf(this.endCondProbs[this.s]));
            }

            /** Removal is not supported. The call has no effect. */
            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /**
         * Iterates over {@code ((d, s), p)} for every state {@code s} at position 0 of every
         * sequence {@code d}, where {@code p} comes from {@code startNodeCondProbs(d)}. Sequences
         * are visited in index order, and states in increasing order within each sequence.
         */
        private class StartMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Double>> {
            int d;
            int s;
            /** Row of sequence {@code d}; {@code null} until {@link #next()} is first called. */
            double[] startCondProbs;

            private StartMarginalsIterator() {
                this.startCondProbs = null;
            }

            /** Index of the last state at position 0 of sequence {@code seq}. */
            private int lastState(int seq) {
                return numStates(seq, 0) - 1;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                if (this.startCondProbs == null) {
                    // Nothing returned yet: a first element exists iff there is a sequence. Treating
                    // this like "positioned at (0, 0)" would end a one-state lattice before it starts.
                    return numSequences() > 0;
                }
                return !(this.d == numSequences() - 1 && this.s == lastState(this.d));
            }

            /**
             * Returns the next {@code ((d, s), p)} entry.
             *
             * @throws java.util.NoSuchElementException if {@link #hasNext()} is false
             */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Double> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                if (this.startCondProbs == null) {
                    this.d = 0;
                    this.s = 0;
                    this.startCondProbs = startNodeCondProbs(this.d);
                } else if (this.s == lastState(this.d)) {
                    this.s = 0;
                    this.d++;
                    this.startCondProbs = startNodeCondProbs(this.d);
                } else {
                    this.s++;
                }
                return Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)), Double.valueOf(this.startCondProbs[this.s]));
            }

            /** Removal is not supported. The call has no effect. */
            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /* JADX WARN: Type inference failed for: r1v12, types: [double[], double[][]] */
        /* JADX WARN: Type inference failed for: r1v15, types: [double[], double[][]] */
        /* JADX WARN: Type inference failed for: r1v18, types: [double[][], double[][][]] */
        /* JADX WARN: Type inference failed for: r1v21, types: [double[], double[][]] */
        /* JADX WARN: Type inference failed for: r1v24, types: [double[], double[][]] */
        public NonStationaryEdgeMarginalsScaling(Lattice lattice) {
            this.lattice = lattice;
            this.sequenceLogMarginalProbs = new double[lattice.numSequences()];
            this.sequenceMarginalProbs = new double[lattice.numSequences()];
            this.sequenceMarginalProbLogScales = new double[lattice.numSequences()];
            this.startNodeCondProbs = new double[lattice.numSequences()];
            this.endNodeCondProbs = new double[lattice.numSequences()];
            this.allAlphas = new double[lattice.numSequences()];
            this.allAlphaLogScales = new double[lattice.numSequences()];
            this.allBetaLogScales = new double[lattice.numSequences()];
        }

        /**
         * Returns the states at position {@code i2 + 1} that state {@code i3} at position
         * {@code i2} of sequence {@code i} may move to, exactly as the lattice reports them
         * (queried with its boolean flag set to {@code false}).
         */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public int[] allowedForwardEdges(int i, int i2, int i3) {
            final Lattice source = this.lattice;
            return source.allowedEdges(i, i2, i3, false);
        }

        /**
         * Stores the scaled forward/backward tables of sequence {@code i}. From them it derives
         * the sequence's marginal probability and the start/end node conditional probabilities.
         *
         * <p>Despite its name, this overwrites the entries for {@code i} rather than accumulating.
         *
         * @param dArr  scaled forward values, indexed [position][state]; kept as {@code allAlphas[i]}
         * @param dArr2 per-position log-scale exponents of {@code dArr}; kept as {@code allAlphaLogScales[i]}
         * @param dArr3 scaled backward values, indexed [position][state]; kept as {@code allBetas[i]}
         * @param dArr4 per-position log-scale exponents of {@code dArr3}; kept as {@code allBetaLogScales[i]}
         * @param i     index of the sequence being recorded
         */
        public void incrementExpectedCounts(double[][] dArr, double[] dArr2, double[][] dArr3, double[] dArr4, int i) {
            this.allAlphas[i] = dArr;
            this.allAlphaLogScales[i] = dArr2;
            this.allBetas[i] = dArr3;
            this.allBetaLogScales[i] = dArr4;

            // The marginal is kept as (mass, logScale): log Z = logScale * LOG_SCALE + log(mass).
            final double logScale = dArr4[0];
            final int firstStates = this.lattice.numStates(i, 0);
            double mass = 0.0d;
            for (int state = 0; state < firstStates; ++state) {
                // NOTE(review): Cropper.VERT_GROW_RATIO looks like a decompiler substitution for the literal 0.0; confirm.
                if (this.lattice.nodePotential(i, 0, state) > Cropper.VERT_GROW_RATIO) {
                    mass += dArr3[0][state];
                }
            }
            this.sequenceMarginalProbLogScales[i] = logScale;
            this.sequenceMarginalProbs[i] = mass;
            this.sequenceLogMarginalProbs[i] = (logScale * ForwardBackward.LOG_SCALE) + Math.log(mass);

            final double startScale = ForwardBackward.getScaleFactor(dArr4[0] - logScale);
            final double[] startProbs = new double[firstStates];
            for (int state = 0; state < firstStates; ++state) {
                startProbs[state] = (dArr3[0][state] / mass) * startScale;
            }
            this.startNodeCondProbs[i] = startProbs;

            final int last = this.lattice.sequenceLength(i) - 1;
            final int lastStates = this.lattice.numStates(i, last);
            final double endScale = ForwardBackward.getScaleFactor(dArr2[last] - logScale);
            final double[] endProbs = new double[lastStates];
            for (int state = 0; state < lastStates; ++state) {
                endProbs[state] = (dArr[last][state] / mass) * endScale;
            }
            this.endNodeCondProbs[i] = endProbs;
        }

        /**
         * Estimates the memory held by the stored forward and backward tables, in gigabytes. It
         * counts 8 bytes per double and skips sequences whose tables have not been recorded.
         */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double estimateMemoryUsage() {
            long cells = 0L;
            for (int seq = 0; seq < numSequences(); seq++) {
                final int length = sequenceLength(seq);
                final double[][] alphas = this.allAlphas[seq];
                if (alphas != null) {
                    for (int pos = 0; pos < length; pos++) {
                        cells += alphas[pos].length;
                    }
                }
                final double[][] betas = this.allBetas[seq];
                if (betas != null) {
                    for (int pos = 0; pos < length; pos++) {
                        cells += betas[pos].length;
                    }
                }
            }
            return (8.0d * cells) / 1.0E9d;
        }

        /**
         * Computes the posterior expected count of each allowed edge out of state {@code i3} at
         * position {@code i2} of sequence {@code i}. Each count is
         * alpha[i2][i3] * edgePotential * beta[i2 + 1][s2] / Z, rescaled by the combined
         * log-scale exponents.
         *
         * @return a fresh array parallel to {@code allowedForwardEdges(i, i2, i3)}
         */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double[] allowedForwardEdgesExpectedCounts(int i, int i2, int i3) {
            final int[] successors = this.lattice.allowedEdges(i, i2, i3, false);
            final double[] counts = new double[successors.length];
            final double[] potentials = this.lattice.allowedEdgesPotentials(i, i2, i3, false);
            final double exponent = (this.allAlphaLogScales[i][i2] + this.allBetaLogScales[i][i2 + 1]) - this.sequenceMarginalProbLogScales[i];
            final double scale = ForwardBackward.getScaleFactor(exponent);
            final double alpha = this.allAlphas[i][i2][i3];
            final double[] nextBetas = this.allBetas[i][i2 + 1];
            final double z = this.sequenceMarginalProbs[i];
            int k = 0;
            for (final int successor : successors) {
                counts[k] += (((alpha * potentials[k]) * nextBetas[successor]) / z) * scale;
                k++;
            }
            return counts;
        }

        /**
         * Returns the conditional state probabilities at position 0 of sequence {@code i}, as
         * stored by {@code incrementExpectedCounts}. The internal array is returned, not a copy.
         */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double[] startNodeCondProbs(int i) {
            final double[][] perSequence = this.startNodeCondProbs;
            return perSequence[i];
        }

        /**
         * Returns the conditional state probabilities at the last position of sequence {@code i},
         * as stored by {@code incrementExpectedCounts}. The internal array is returned, not a copy.
         */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double[] endNodeCondProbs(int i) {
            final double[][] perSequence = this.endNodeCondProbs;
            return perSequence[i];
        }

        /** Returns the log marginal probability recorded for sequence {@code i}. */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double sequenceLogMarginalProb(int i) {
            final double[] logProbs = this.sequenceLogMarginalProbs;
            return logProbs[i];
        }

        /** Returns {@code a.sum} over the per-sequence log marginal probabilities. */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public double logMarginalProb() {
            final double[] perSequence = this.sequenceLogMarginalProbs;
            return a.sum(perSequence);
        }

        /** Returns the number of sequences, as reported by the underlying lattice. */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public int numSequences() {
            final Lattice source = this.lattice;
            return source.numSequences();
        }

        /** Returns the length of sequence {@code i}, as reported by the underlying lattice. */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public int sequenceLength(int i) {
            final Lattice source = this.lattice;
            return source.sequenceLength(i);
        }

        /**
         * Returns the number of states at position {@code i2} of sequence {@code i}, as reported
         * by the underlying lattice.
         */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public int numStates(int i, int i2) {
            final Lattice source = this.lattice;
            return source.numStates(i, i2);
        }

        /** Returns a fresh iterator over the expected counts of all allowed forward edges. */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Pair<Integer, Integer>, Pair<Integer, Integer>>, Double>> getEdgeMarginalsIterator() {
            final EdgeMarginalsIterator edges = new EdgeMarginalsIterator();
            return edges;
        }

        /** Returns a fresh iterator over the start-position conditional probabilities. */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Double>> getStartMarginalsIterator() {
            final StartMarginalsIterator starts = new StartMarginalsIterator();
            return starts;
        }

        /** Returns a fresh iterator over the end-position conditional probabilities. */
        @Override // tberg.murphy.sequence.ForwardBackward.NonStationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Double>> getEndMarginalsIterator() {
            final EndMarginalsIterator ends = new EndMarginalsIterator();
            return ends;
        }
    }

    /**
     * Edge, start and end marginals of a stationary lattice. Its allowed edges are addressed by
     * (sequence, state) only, with no position argument.
     *
     * <p>Parameter convention: {@code i} is a sequence index and {@code i2} a state index.
     */
    public interface StationaryEdgeMarginals {
        /** Number of sequences covered. */
        int numSequences();

        /** Number of states of sequence {@code i}. */
        int numStates(int i);

        /** Conditional probability of each state at the first position of sequence {@code i}. */
        double[] startNodeCondProbs(int i);

        /** Conditional probability of each state at the last position of sequence {@code i}. */
        double[] endNodeCondProbs(int i);

        /** States that state {@code i2} of sequence {@code i} may transition to. */
        int[] allowedForwardEdges(int i, int i2);

        /**
         * Expected count of each edge listed by {@link #allowedForwardEdges(int, int)} for the
         * same arguments (parallel array).
         */
        double[] allowedForwardEdgesExpectedCounts(int i, int i2);

        /** Log marginal probability of sequence {@code i}. */
        double sequenceLogMarginalProb(int i);

        /** Log marginal probability over all sequences. */
        double logMarginalProb();

        /** Iterates over {@code ((sequence, (fromState, toState)), expectedCount)}. */
        Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Double>> getEdgeMarginalsIterator();

        /** Iterates over {@code ((sequence, state), startProbability)}. */
        Iterator<Pair<Pair<Integer, Integer>, Double>> getStartMarginalsIterator();

        /** Iterates over {@code ((sequence, state), endProbability)}. */
        Iterator<Pair<Pair<Integer, Integer>, Double>> getEndMarginalsIterator();

        /** Approximate memory held by the stored tables, in gigabytes. */
        double estimateMemoryUsage();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$StationaryEdgeMarginalsLogSpace.class */
    public static class StationaryEdgeMarginalsLogSpace implements StationaryEdgeMarginals {
        StationaryLattice lattice;
        double[] sequenceLogMarginalProbs;
        double[][][] allowedForwardEdgesExpectedCounts;
        double[][] startNodeCondProbs;
        double[][] endNodeCondProbs;

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$StationaryEdgeMarginalsLogSpace$EdgeMarginalsIterator.class */
        private class EdgeMarginalsIterator implements Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Double>> {
            int d = 0;
            int s1 = 0;
            int s2i = 0;
            double[] edgeCondProbs;

            public EdgeMarginalsIterator() {
                this.edgeCondProbs = null;
                this.edgeCondProbs = StationaryEdgeMarginalsLogSpace.this.allowedForwardEdgesExpectedCounts(this.d, this.s1);
                while (this.d < StationaryEdgeMarginalsLogSpace.this.numSequences() && this.edgeCondProbs.length == 0) {
                    advance();
                }
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.d < StationaryEdgeMarginalsLogSpace.this.numSequences();
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Pair<Integer, Integer>>, Double> next() {
                Pair<Pair<Integer, Pair<Integer, Integer>>, Double> makePair = Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Pair.makePair(Integer.valueOf(this.s1), Integer.valueOf(StationaryEdgeMarginalsLogSpace.this.allowedForwardEdges(this.d, this.s1)[this.s2i]))), Double.valueOf(this.edgeCondProbs[this.s2i]));
                advance();
                while (this.d < StationaryEdgeMarginalsLogSpace.this.numSequences() && this.edgeCondProbs.length == 0) {
                    advance();
                }
                return makePair;
            }

            private void advance() {
                if (this.s2i < StationaryEdgeMarginalsLogSpace.this.allowedForwardEdges(this.d, this.s1).length - 1) {
                    this.s2i++;
                    return;
                }
                this.s2i = 0;
                if (this.s1 >= StationaryEdgeMarginalsLogSpace.this.numStates(this.d) - 1) {
                    this.s1 = 0;
                    this.d++;
                } else {
                    this.s1++;
                }
                if (this.d < StationaryEdgeMarginalsLogSpace.this.numSequences()) {
                    this.edgeCondProbs = StationaryEdgeMarginalsLogSpace.this.allowedForwardEdgesExpectedCounts(this.d, this.s1);
                }
            }

            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$StationaryEdgeMarginalsLogSpace$EndMarginalsIterator.class */
        private class EndMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Double>> {
            int d;
            int s;
            double[] endCondProbs;

            private EndMarginalsIterator() {
                this.endCondProbs = null;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return (this.d == StationaryEdgeMarginalsLogSpace.this.numSequences() - 1 && this.s == StationaryEdgeMarginalsLogSpace.this.numStates(this.d) - 1) ? false : true;
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Double> next() {
                if (this.endCondProbs == null) {
                    this.d = 0;
                    this.s = 0;
                    this.endCondProbs = StationaryEdgeMarginalsLogSpace.this.endNodeCondProbs(this.d);
                } else if (this.s == StationaryEdgeMarginalsLogSpace.this.numStates(this.d) - 1) {
                    this.s = 0;
                    this.d++;
                    this.endCondProbs = StationaryEdgeMarginalsLogSpace.this.endNodeCondProbs(this.d);
                } else {
                    this.s++;
                }
                return Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)), Double.valueOf(this.endCondProbs[this.s]));
            }

            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$StationaryEdgeMarginalsLogSpace$StartMarginalsIterator.class */
        private class StartMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Double>> {
            int d;
            int s;
            double[] startCondProbs;

            private StartMarginalsIterator() {
                this.startCondProbs = null;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return (this.d == StationaryEdgeMarginalsLogSpace.this.numSequences() - 1 && this.s == StationaryEdgeMarginalsLogSpace.this.numStates(this.d) - 1) ? false : true;
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Double> next() {
                if (this.startCondProbs == null) {
                    this.d = 0;
                    this.s = 0;
                    this.startCondProbs = StationaryEdgeMarginalsLogSpace.this.startNodeCondProbs(this.d);
                } else if (this.s == StationaryEdgeMarginalsLogSpace.this.numStates(this.d) - 1) {
                    this.s = 0;
                    this.d++;
                    this.startCondProbs = StationaryEdgeMarginalsLogSpace.this.startNodeCondProbs(this.d);
                } else {
                    this.s++;
                }
                return Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)), Double.valueOf(this.startCondProbs[this.s]));
            }

            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /* JADX WARN: Type inference failed for: r1v12, types: [double[], double[][]] */
        /* JADX WARN: Type inference failed for: r1v6, types: [double[][], double[][][]] */
        /* JADX WARN: Type inference failed for: r1v9, types: [double[], double[][]] */
        public StationaryEdgeMarginalsLogSpace(StationaryLattice stationaryLattice) {
            this.lattice = stationaryLattice;
            this.sequenceLogMarginalProbs = new double[stationaryLattice.numSequences()];
            this.allowedForwardEdgesExpectedCounts = new double[stationaryLattice.numSequences()];
            this.startNodeCondProbs = new double[stationaryLattice.numSequences()];
            this.endNodeCondProbs = new double[stationaryLattice.numSequences()];
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public int[] allowedForwardEdges(int i, int i2) {
            return this.lattice.allowedEdges(i, i2, false);
        }

        /* JADX WARN: Multi-variable type inference failed */
        public void incrementExpectedCounts(double[][] dArr, double[][] dArr2, int i) {
            this.sequenceLogMarginalProbs[i] = Double.NEGATIVE_INFINITY;
            for (int i2 = 0; i2 < this.lattice.numStates(i); i2++) {
                this.sequenceLogMarginalProbs[i] = m.logAdd(this.sequenceLogMarginalProbs[i], dArr2[0][i2]);
            }
            this.allowedForwardEdgesExpectedCounts[i] = new double[this.lattice.numStates(i)];
            for (int i3 = 0; i3 < this.lattice.numStates(i); i3++) {
                this.allowedForwardEdgesExpectedCounts[i][i3] = new double[this.lattice.allowedEdges(i, i3, false).length];
            }
            for (int i4 = 0; i4 < this.lattice.sequenceLength(i) - 1; i4++) {
                int numStates = this.lattice.numStates(i);
                for (int i5 = 0; i5 < numStates; i5++) {
                    int[] allowedEdges = this.lattice.allowedEdges(i, i5, false);
                    double[] allowedEdgesLogPotentials = this.lattice.allowedEdgesLogPotentials(i, i5, false);
                    for (int i6 = 0; i6 < allowedEdges.length; i6++) {
                        int i7 = allowedEdges[i6];
                        double d = allowedEdgesLogPotentials[i6];
                        double[] dArr3 = this.allowedForwardEdgesExpectedCounts[i][i5];
                        int i8 = i6;
                        dArr3[i8] = dArr3[i8] + Math.exp(((dArr[i4][i5] + d) + dArr2[i4 + 1][i7]) - this.sequenceLogMarginalProbs[i]);
                    }
                }
            }
            this.startNodeCondProbs[i] = new double[this.lattice.numStates(i)];
            for (int i9 = 0; i9 < this.lattice.numStates(i); i9++) {
                this.startNodeCondProbs[i][i9] = Math.exp(dArr2[0][i9] - this.sequenceLogMarginalProbs[i]);
            }
            this.endNodeCondProbs[i] = new double[this.lattice.numStates(i)];
            for (int i10 = 0; i10 < this.lattice.numStates(i); i10++) {
                this.endNodeCondProbs[i][i10] = Math.exp(dArr[this.lattice.sequenceLength(i) - 1][i10] - this.sequenceLogMarginalProbs[i]);
            }
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public double estimateMemoryUsage() {
            double d = 0.0d;
            for (int i = 0; i < numSequences(); i++) {
                if (this.allowedForwardEdgesExpectedCounts[i] != null) {
                    for (int i2 = 0; i2 < numStates(i); i2++) {
                        d += this.allowedForwardEdgesExpectedCounts[i][i2].length;
                    }
                }
            }
            return (8.0d * d) / 1.0E9d;
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public double[] allowedForwardEdgesExpectedCounts(int i, int i2) {
            return this.allowedForwardEdgesExpectedCounts[i][i2];
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public double[] startNodeCondProbs(int i) {
            return this.startNodeCondProbs[i];
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public double[] endNodeCondProbs(int i) {
            return this.endNodeCondProbs[i];
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public double sequenceLogMarginalProb(int i) {
            return this.sequenceLogMarginalProbs[i];
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public double logMarginalProb() {
            return a.sum(this.sequenceLogMarginalProbs);
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public int numSequences() {
            return this.lattice.numSequences();
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public int numStates(int i) {
            return this.lattice.numStates(i);
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Double>> getEdgeMarginalsIterator() {
            return new EdgeMarginalsIterator();
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Double>> getStartMarginalsIterator() {
            return new StartMarginalsIterator();
        }

        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public Iterator<Pair<Pair<Integer, Integer>, Double>> getEndMarginalsIterator() {
            return new EndMarginalsIterator();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$StationaryEdgeMarginalsScaling.class */
    public static class StationaryEdgeMarginalsScaling implements StationaryEdgeMarginals {
        StationaryLattice lattice;
        double[] sequenceLogMarginalProbs;
        double[][][] allowedForwardEdgesExpectedCounts;
        double[][] startNodeCondProbs;
        double[][] endNodeCondProbs;

        /**
         * Walks {@code ((d, (s1, s2)), expectedCount)} over every allowed forward edge of every
         * sequence. Source states with no allowed successors are skipped.
         */
        private class EdgeMarginalsIterator implements Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Double>> {
            int d = 0;
            int s1 = 0;
            int s2i = 0;
            double[] edgeCondProbs = null;

            public EdgeMarginalsIterator() {
                // With zero sequences there is no row 0 to read, so the iterator starts exhausted.
                if (this.d < numSequences()) {
                    this.edgeCondProbs = allowedForwardEdgesExpectedCounts(this.d, this.s1);
                    skipEmptyRows();
                }
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                return this.d < numSequences();
            }

            /**
             * Returns the current edge and moves to the next one.
             *
             * @throws java.util.NoSuchElementException if {@link #hasNext()} is false
             */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Pair<Integer, Integer>>, Double> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                final int s2 = allowedForwardEdges(this.d, this.s1)[this.s2i];
                final Pair<Pair<Integer, Pair<Integer, Integer>>, Double> current = Pair.makePair(
                        Pair.makePair(Integer.valueOf(this.d), Pair.makePair(Integer.valueOf(this.s1), Integer.valueOf(s2))),
                        Double.valueOf(this.edgeCondProbs[this.s2i]));
                advance();
                skipEmptyRows();
                return current;
            }

            /** Advances past source states whose expected-count row is empty. */
            private void skipEmptyRows() {
                while (this.d < numSequences() && this.edgeCondProbs.length == 0) {
                    advance();
                }
            }

            /** Steps one position in (d, s1, s2i) order and reloads the row when still in range. */
            private void advance() {
                if (this.s2i < allowedForwardEdges(this.d, this.s1).length - 1) {
                    this.s2i++;
                    return;
                }
                this.s2i = 0;
                if (this.s1 >= numStates(this.d) - 1) {
                    this.s1 = 0;
                    this.d++;
                } else {
                    this.s1++;
                }
                if (this.d < numSequences()) {
                    this.edgeCondProbs = allowedForwardEdgesExpectedCounts(this.d, this.s1);
                }
            }

            /** Removal is not supported. The call has no effect. */
            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /**
         * Iterates over {@code ((d, s), p)} with {@code p} taken from {@code endNodeCondProbs(d)}.
         * Sequences are visited in index order, and states in increasing order.
         */
        private class EndMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Double>> {
            int d;
            int s;
            /** Row of sequence {@code d}; {@code null} until {@link #next()} is first called. */
            double[] endCondProbs;

            private EndMarginalsIterator() {
                this.endCondProbs = null;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                if (this.endCondProbs == null) {
                    // Nothing returned yet: a first element exists iff there is a sequence. Treating
                    // this like "positioned at (0, 0)" would end a one-state lattice before it starts.
                    return numSequences() > 0;
                }
                return !(this.d == numSequences() - 1 && this.s == numStates(this.d) - 1);
            }

            /** @throws java.util.NoSuchElementException if {@link #hasNext()} is false */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Double> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                if (this.endCondProbs == null) {
                    this.d = 0;
                    this.s = 0;
                    this.endCondProbs = endNodeCondProbs(this.d);
                } else if (this.s == numStates(this.d) - 1) {
                    this.s = 0;
                    this.d++;
                    this.endCondProbs = endNodeCondProbs(this.d);
                } else {
                    this.s++;
                }
                return Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)), Double.valueOf(this.endCondProbs[this.s]));
            }

            /** Removal is not supported. The call has no effect. */
            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /**
         * Iterates over {@code ((d, s), p)} with {@code p} taken from {@code startNodeCondProbs(d)}.
         * Sequences are visited in index order, and states in increasing order.
         */
        private class StartMarginalsIterator implements Iterator<Pair<Pair<Integer, Integer>, Double>> {
            int d;
            int s;
            /** Row of sequence {@code d}; {@code null} until {@link #next()} is first called. */
            double[] startCondProbs;

            private StartMarginalsIterator() {
                this.startCondProbs = null;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                if (this.startCondProbs == null) {
                    // Nothing returned yet: a first element exists iff there is a sequence. Treating
                    // this like "positioned at (0, 0)" would end a one-state lattice before it starts.
                    return numSequences() > 0;
                }
                return !(this.d == numSequences() - 1 && this.s == numStates(this.d) - 1);
            }

            /** @throws java.util.NoSuchElementException if {@link #hasNext()} is false */
            @Override // java.util.Iterator
            public Pair<Pair<Integer, Integer>, Double> next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }
                if (this.startCondProbs == null) {
                    this.d = 0;
                    this.s = 0;
                    this.startCondProbs = startNodeCondProbs(this.d);
                } else if (this.s == numStates(this.d) - 1) {
                    this.s = 0;
                    this.d++;
                    this.startCondProbs = startNodeCondProbs(this.d);
                } else {
                    this.s++;
                }
                return Pair.makePair(Pair.makePair(Integer.valueOf(this.d), Integer.valueOf(this.s)), Double.valueOf(this.startCondProbs[this.s]));
            }

            /** Removal is not supported. The call has no effect. */
            @Override // java.util.Iterator
            public void remove() {
            }
        }

        /* JADX WARN: Type inference failed for: r1v12, types: [double[], double[][]] */
        /* JADX WARN: Type inference failed for: r1v6, types: [double[][], double[][][]] */
        /* JADX WARN: Type inference failed for: r1v9, types: [double[], double[][]] */
        public StationaryEdgeMarginalsScaling(StationaryLattice stationaryLattice) {
            this.lattice = stationaryLattice;
            this.sequenceLogMarginalProbs = new double[stationaryLattice.numSequences()];
            this.allowedForwardEdgesExpectedCounts = new double[stationaryLattice.numSequences()];
            this.startNodeCondProbs = new double[stationaryLattice.numSequences()];
            this.endNodeCondProbs = new double[stationaryLattice.numSequences()];
        }

        /**
         * Returns the states that state {@code i2} of sequence {@code i} may transition to, as
         * reported by the lattice (queried with its boolean flag set to {@code false}).
         */
        @Override // tberg.murphy.sequence.ForwardBackward.StationaryEdgeMarginals
        public int[] allowedForwardEdges(int i, int i2) {
            final StationaryLattice source = this.lattice;
            return source.allowedEdges(i, i2, false);
        }

        /**
         * Computes and stores the edge and boundary-node marginals of sequence {@code d}.
         *
         * <p>The inputs are the results of a forward and a backward
         * {@link ForwardBackward#doPassScaling} pass. Each pass represents a value as
         * {@code scores[t][s] * SCALE^scales[t]}. The scale exponents are recombined through
         * {@link ForwardBackward#getScaleFactor} before normalizing by the sequence's marginal
         * probability.
         *
         * <p>Only the slots for sequence {@code d} are written, so calls for distinct sequences do not
         * touch the same array elements.
         *
         * @param alphas forward-pass scores, indexed [position][state]
         * @param alphaScales forward-pass scale exponents, one per position
         * @param betas backward-pass scores, indexed [position][state]
         * @param betaScales backward-pass scale exponents, one per position
         * @param d index of the sequence to process
         */
        public void incrementExpectedCounts(double[][] alphas, double[] alphaScales, double[][] betas, double[] betaScales, int d) {
            final int numStates = this.lattice.numStates(d);
            final int length = this.lattice.sequenceLength(d);

            // Sum the backward scores at position 0 over states with a positive start potential. This
            // gives the (scaled) marginal probability of the sequence; assumes the backward pass ends at
            // position 0.
            double marginal = 0.0d;
            final double marginalScale = betaScales[0];
            for (int s = 0; s < numStates; s++) {
                if (this.lattice.nodePotential(d, 0, s) > 0.0d) {
                    marginal += betas[0][s];
                }
            }
            this.sequenceLogMarginalProbs[d] = (marginalScale * ForwardBackward.LOG_SCALE) + Math.log(marginal);

            // For each position t, this factor brings alpha[t] * beta[t + 1] back to the scale of the
            // marginal.
            final double[] edgeScales = new double[length - 1];
            for (int t = 0; t < length - 1; t++) {
                edgeScales[t] = ForwardBackward.getScaleFactor((alphaScales[t] + betaScales[t + 1]) - marginalScale);
            }

            final double[][] edgeCounts = new double[numStates][];
            for (int s = 0; s < numStates; s++) {
                // Stationary lattice: the edges out of s are the same at every position. Fetch them once.
                // Contributions to each count are still accumulated in increasing t order.
                final int[] targets = this.lattice.allowedEdges(d, s, false);
                final double[] potentials = this.lattice.allowedEdgesPotentials(d, s, false);
                final double[] counts = new double[targets.length];
                for (int t = 0; t < length - 1; t++) {
                    final double alpha = alphas[t][s];
                    final double[] nextBetas = betas[t + 1];
                    for (int e = 0; e < targets.length; e++) {
                        counts[e] += (((alpha * potentials[e]) * nextBetas[targets[e]]) / marginal) * edgeScales[t];
                    }
                }
                edgeCounts[s] = counts;
            }
            this.allowedForwardEdgesExpectedCounts[d] = edgeCounts;

            // P(state at position 0).
            final double[] start = new double[numStates];
            final double startScale = ForwardBackward.getScaleFactor(betaScales[0] - marginalScale);
            for (int s = 0; s < numStates; s++) {
                start[s] = (betas[0][s] / marginal) * startScale;
            }
            this.startNodeCondProbs[d] = start;

            // P(state at the last position).
            final double[] end = new double[numStates];
            final double endScale = ForwardBackward.getScaleFactor(alphaScales[length - 1] - marginalScale);
            for (int s = 0; s < numStates; s++) {
                end[s] = (alphas[length - 1][s] / marginal) * endScale;
            }
            this.endNodeCondProbs[d] = end;
        }

        /**
         * Estimates the size of the stored edge-count tables in gigabytes (10^9 bytes), assuming 8 bytes
         * per double. Sequences not yet processed (null tables) contribute nothing.
         */
        @Override
        public double estimateMemoryUsage() {
            double entryCount = 0.0d;
            for (int seq = 0; seq < numSequences(); seq++) {
                final double[][] perState = this.allowedForwardEdgesExpectedCounts[seq];
                if (perState == null) {
                    continue;
                }
                for (int state = 0; state < numStates(seq); state++) {
                    entryCount += perState[state].length;
                }
            }
            final double bytes = 8.0d * entryCount;
            return bytes / 1.0E9d;
        }

        /**
         * Returns the expected counts of the forward edges out of {@code state} in sequence
         * {@code sequence}. The array is parallel to {@link #allowedForwardEdges(int, int)}.
         */
        @Override
        public double[] allowedForwardEdgesExpectedCounts(int sequence, int state) {
            final double[][] perState = this.allowedForwardEdgesExpectedCounts[sequence];
            return perState[state];
        }

        /** Returns the per-state marginal probabilities at position 0 of sequence {@code sequence}. */
        @Override
        public double[] startNodeCondProbs(int sequence) {
            final double[] probs = this.startNodeCondProbs[sequence];
            return probs;
        }

        /** Returns the per-state marginal probabilities at the last position of sequence {@code sequence}. */
        @Override
        public double[] endNodeCondProbs(int sequence) {
            final double[] probs = this.endNodeCondProbs[sequence];
            return probs;
        }

        /** Returns the log marginal probability stored for sequence {@code sequence}. */
        @Override
        public double sequenceLogMarginalProb(int sequence) {
            final double[] logProbs = this.sequenceLogMarginalProbs;
            return logProbs[sequence];
        }

        /** Returns the per-sequence log marginal probabilities combined with {@code a.sum}. */
        @Override
        public double logMarginalProb() {
            final double[] perSequence = this.sequenceLogMarginalProbs;
            return a.sum(perSequence);
        }

        /** Returns the number of sequences in the underlying lattice. */
        @Override
        public int numSequences() {
            final StationaryLattice source = this.lattice;
            return source.numSequences();
        }

        /** Returns the number of states of sequence {@code sequence} in the underlying lattice. */
        @Override
        public int numStates(int sequence) {
            final StationaryLattice source = this.lattice;
            return source.numStates(sequence);
        }

        /** Returns a new iterator over the stored edge marginals. */
        @Override
        public Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Double>> getEdgeMarginalsIterator() {
            final Iterator<Pair<Pair<Integer, Pair<Integer, Integer>>, Double>> iterator = new EdgeMarginalsIterator();
            return iterator;
        }

        /** Returns a new iterator over the stored start-node marginals. */
        @Override
        public Iterator<Pair<Pair<Integer, Integer>, Double>> getStartMarginalsIterator() {
            final Iterator<Pair<Pair<Integer, Integer>, Double>> iterator = new StartMarginalsIterator();
            return iterator;
        }

        /** Returns a new iterator over the stored end-node marginals. */
        @Override
        public Iterator<Pair<Pair<Integer, Integer>, Double>> getEndMarginalsIterator() {
            final Iterator<Pair<Pair<Integer, Integer>, Double>> iterator = new EndMarginalsIterator();
            return iterator;
        }
    }

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$StationaryLattice.class */
    /**
     * A lattice whose transition structure does not depend on position.
     *
     * <p>The state count, allowed edges and edge potentials are indexed by (sequence, state) only.
     * Node potentials still take a position. The log-space passes use the {@code *LogPotential(s)}
     * methods and the scaling passes use the linear {@code *Potential(s)} methods. Presumably each is
     * the exp/log of the other — TODO confirm against implementations.
     */
    public interface StationaryLattice {
        /** Returns the number of independent sequences. */
        int numSequences();

        /** Returns the number of positions in sequence {@code sequence}. */
        int sequenceLength(int sequence);

        /** Returns the number of states, shared by every position of sequence {@code sequence}. */
        int numStates(int sequence);

        /** Returns the log potential of {@code state} at {@code position} in sequence {@code sequence}. */
        double nodeLogPotential(int sequence, int position, int state);

        /** Returns log potentials parallel to {@link #allowedEdges(int, int, boolean)}. */
        double[] allowedEdgesLogPotentials(int sequence, int state, boolean backward);

        /** Returns the linear potential of {@code state} at {@code position} in sequence {@code sequence}. */
        double nodePotential(int sequence, int position, int state);

        /** Returns linear potentials parallel to {@link #allowedEdges(int, int, boolean)}. */
        double[] allowedEdgesPotentials(int sequence, int state, boolean backward);

        /**
         * Returns the states connected to {@code state} by an edge.
         *
         * <p>With {@code backward == false} these are successor states at the next position. With
         * {@code backward == true} they are predecessor states at the previous position.
         */
        int[] allowedEdges(int sequence, int state, boolean backward);
    }

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$StationaryLatticeWrapper.class */
    /**
     * Adapts a {@link StationaryLattice} to the general {@link Lattice} interface.
     *
     * <p>The position argument is dropped for state counts, edges and edge potentials. It is forwarded
     * for node potentials. Holds no state beyond the wrapped lattice.
     */
    public static class StationaryLatticeWrapper implements Lattice {
        StationaryLattice lattice;

        /** Wraps {@code stationaryLattice}; every call is delegated to it. */
        public StationaryLatticeWrapper(StationaryLattice stationaryLattice) {
            this.lattice = stationaryLattice;
        }

        @Override
        public int numSequences() {
            return this.lattice.numSequences();
        }

        @Override
        public int sequenceLength(int sequence) {
            return this.lattice.sequenceLength(sequence);
        }

        /** The state count is the same at every position, so {@code position} is ignored. */
        @Override
        public int numStates(int sequence, int position) {
            return this.lattice.numStates(sequence);
        }

        @Override
        public double nodeLogPotential(int sequence, int position, int state) {
            return this.lattice.nodeLogPotential(sequence, position, state);
        }

        /** Edge log potentials do not depend on {@code position}. */
        @Override
        public double[] allowedEdgesLogPotentials(int sequence, int position, int state, boolean backward) {
            return this.lattice.allowedEdgesLogPotentials(sequence, state, backward);
        }

        @Override
        public double nodePotential(int sequence, int position, int state) {
            return this.lattice.nodePotential(sequence, position, state);
        }

        /** Edge potentials do not depend on {@code position}. */
        @Override
        public double[] allowedEdgesPotentials(int sequence, int position, int state, boolean backward) {
            return this.lattice.allowedEdgesPotentials(sequence, state, backward);
        }

        /** Allowed edges do not depend on {@code position}. */
        @Override
        public int[] allowedEdges(int sequence, int position, int state, boolean backward) {
            return this.lattice.allowedEdges(sequence, state, backward);
        }
    }

    /* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/ForwardBackward$StationaryStateProjector.class */
    /**
     * Maps lattice states onto a projected state space.
     *
     * <p>In this part of the file only {@link #project(int, int, int)} is called, to rewrite the states
     * of an extracted Viterbi path. The meaning of the size methods is inferred from their names —
     * TODO confirm against implementations.
     */
    public interface StationaryStateProjector {
        /** Presumably the number of source states for (sequence, position) — TODO confirm. */
        int domainSize(int sequence, int position);

        /** Presumably the number of projected states for {@code sequence} — TODO confirm. */
        int rangeSize(int sequence);

        /** Maps {@code state} at {@code position} of sequence {@code sequence} to its projected state. */
        int project(int sequence, int position, int state);
    }

    /**
     * Returns {@code SCALE^exponent}, the multiplicative factor for a scale exponent produced by
     * {@link #doPassScaling}.
     *
     * <p>Integer exponents in [-3, 3] are computed by repeated multiplication, presumably to avoid
     * {@link Math#pow} cost and rounding. Any other exponent falls back to {@code Math.pow}. The
     * decompiler notes this method was originally private.
     *
     * @param exponent the (normally integral) power of {@link #SCALE}
     * @return {@code SCALE^exponent}
     */
    public static double getScaleFactor(double exponent) {
        // Literal 0.0: the decompiled form referenced an unrelated OCR constant with that value.
        if (exponent == 0.0d) {
            return 1.0d;
        } else if (exponent == 1.0d) {
            return SCALE;
        } else if (exponent == 2.0d) {
            return SCALE * SCALE;
        } else if (exponent == 3.0d) {
            return SCALE * SCALE * SCALE;
        } else if (exponent == -1.0d) {
            return INVSCALE;
        } else if (exponent == -2.0d) {
            return INVSCALE * INVSCALE;
        } else if (exponent == -3.0d) {
            return INVSCALE * INVSCALE * INVSCALE;
        }
        return Math.pow(SCALE, exponent);
    }

    /**
     * Runs log-space forward and backward passes on every sequence of a stationary lattice and
     * accumulates node (and optionally edge) marginals.
     *
     * @param stationaryLattice the lattice to process
     * @param stationaryStateProjector handed to the node-marginal accumulator
     * @param viterbi if true, passes use max instead of log-sum and no edge marginals are computed
     *     (the second element of the result is {@code null})
     * @param numThreads passed to {@link BetterThreader} (presumably the worker-thread count)
     * @return node marginals paired with edge marginals, the latter {@code null} when {@code viterbi}
     */
    public static Pair<NodeMarginals, StationaryEdgeMarginals> computeMarginalsLogSpace(final StationaryLattice stationaryLattice, StationaryStateProjector stationaryStateProjector, final boolean viterbi, int numThreads) {
        // The wrapper only delegates, so a single instance can be shared by all workers.
        final Lattice wrappedLattice = new StationaryLatticeWrapper(stationaryLattice);
        final NodeMarginalsLogSpace nodeMarginals = new NodeMarginalsLogSpace(wrappedLattice, stationaryStateProjector);
        final StationaryEdgeMarginalsLogSpace edgeMarginals = viterbi ? null : new StationaryEdgeMarginalsLogSpace(stationaryLattice);
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() {
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                final int d = num.intValue();
                final double[][] forward = ForwardBackward.doPassLogSpace(wrappedLattice, false, viterbi, d);
                final double[][] backward = ForwardBackward.doPassLogSpace(wrappedLattice, true, viterbi, d);
                nodeMarginals.incrementExpectedCounts(forward, backward, d, viterbi);
                if (edgeMarginals != null) {
                    edgeMarginals.incrementExpectedCounts(forward, backward, d);
                }
            }
        }, numThreads);
        for (int d = 0; d < stationaryLattice.numSequences(); d++) {
            betterThreader.addFunctionArgument(Integer.valueOf(d));
        }
        betterThreader.run();
        return Pair.makePair(nodeMarginals, edgeMarginals);
    }

    /**
     * Runs log-space forward and backward passes on every sequence of a general lattice and
     * accumulates node (and optionally edge) marginals.
     *
     * @param lattice the lattice to process
     * @param stationaryStateProjector handed to the node-marginal accumulator
     * @param viterbi if true, passes use max instead of log-sum and no edge marginals are computed
     *     (the second element of the result is {@code null})
     * @param numThreads passed to {@link BetterThreader} (presumably the worker-thread count)
     * @return node marginals paired with edge marginals, the latter {@code null} when {@code viterbi}
     */
    public static Pair<NodeMarginals, NonStationaryEdgeMarginals> computeMarginalsLogSpace(final Lattice lattice, StationaryStateProjector stationaryStateProjector, final boolean viterbi, int numThreads) {
        final NodeMarginalsLogSpace nodeMarginals = new NodeMarginalsLogSpace(lattice, stationaryStateProjector);
        final NonStationaryEdgeMarginalsLogSpace edgeMarginals = viterbi ? null : new NonStationaryEdgeMarginalsLogSpace(lattice);
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() {
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                final int d = num.intValue();
                final double[][] forward = ForwardBackward.doPassLogSpace(lattice, false, viterbi, d);
                final double[][] backward = ForwardBackward.doPassLogSpace(lattice, true, viterbi, d);
                nodeMarginals.incrementExpectedCounts(forward, backward, d, viterbi);
                if (edgeMarginals != null) {
                    edgeMarginals.incrementExpectedCounts(forward, backward, d);
                }
            }
        }, numThreads);
        for (int d = 0; d < lattice.numSequences(); d++) {
            betterThreader.addFunctionArgument(Integer.valueOf(d));
        }
        betterThreader.run();
        return Pair.makePair(nodeMarginals, edgeMarginals);
    }

    /**
     * Runs scaled linear-space forward and backward passes on every sequence of a stationary lattice
     * and accumulates node (and optionally edge) marginals.
     *
     * @param stationaryLattice the lattice to process
     * @param stationaryStateProjector handed to the node-marginal accumulator
     * @param viterbi if true, passes use max instead of sum and no edge marginals are computed (the
     *     second element of the result is {@code null})
     * @param numThreads passed to {@link BetterThreader} (presumably the worker-thread count)
     * @return node marginals paired with edge marginals, the latter {@code null} when {@code viterbi}
     */
    public static Pair<NodeMarginals, StationaryEdgeMarginals> computeMarginalsScaling(final StationaryLattice stationaryLattice, StationaryStateProjector stationaryStateProjector, final boolean viterbi, int numThreads) {
        // The wrapper only delegates, so a single instance can be shared by all workers.
        final Lattice wrappedLattice = new StationaryLatticeWrapper(stationaryLattice);
        final NodeMarginalsScaling nodeMarginals = new NodeMarginalsScaling(wrappedLattice, stationaryStateProjector);
        final StationaryEdgeMarginalsScaling edgeMarginals = viterbi ? null : new StationaryEdgeMarginalsScaling(stationaryLattice);
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() {
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                final int d = num.intValue();
                final Pair<double[][], double[]> forward = ForwardBackward.doPassScaling(wrappedLattice, false, viterbi, d);
                final double[][] alphas = forward.getFirst();
                final double[] alphaScales = forward.getSecond();
                final Pair<double[][], double[]> backward = ForwardBackward.doPassScaling(wrappedLattice, true, viterbi, d);
                final double[][] betas = backward.getFirst();
                final double[] betaScales = backward.getSecond();
                nodeMarginals.incrementExpectedCounts(alphas, alphaScales, betas, betaScales, d, viterbi);
                if (edgeMarginals != null) {
                    edgeMarginals.incrementExpectedCounts(alphas, alphaScales, betas, betaScales, d);
                }
            }
        }, numThreads);
        for (int d = 0; d < stationaryLattice.numSequences(); d++) {
            betterThreader.addFunctionArgument(Integer.valueOf(d));
        }
        betterThreader.run();
        return Pair.makePair(nodeMarginals, edgeMarginals);
    }

    /**
     * Runs scaled linear-space forward and backward passes on every sequence of a general lattice and
     * accumulates node (and optionally edge) marginals.
     *
     * @param lattice the lattice to process
     * @param stationaryStateProjector handed to the node-marginal accumulator
     * @param viterbi if true, passes use max instead of sum and no edge marginals are computed (the
     *     second element of the result is {@code null})
     * @param numThreads passed to {@link BetterThreader} (presumably the worker-thread count)
     * @return node marginals paired with edge marginals, the latter {@code null} when {@code viterbi}
     */
    public static Pair<NodeMarginals, NonStationaryEdgeMarginals> computeMarginalsScaling(final Lattice lattice, StationaryStateProjector stationaryStateProjector, final boolean viterbi, int numThreads) {
        final NodeMarginalsScaling nodeMarginals = new NodeMarginalsScaling(lattice, stationaryStateProjector);
        final NonStationaryEdgeMarginalsScaling edgeMarginals = viterbi ? null : new NonStationaryEdgeMarginalsScaling(lattice);
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() {
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                final int d = num.intValue();
                final Pair<double[][], double[]> forward = ForwardBackward.doPassScaling(lattice, false, viterbi, d);
                final double[][] alphas = forward.getFirst();
                final double[] alphaScales = forward.getSecond();
                final Pair<double[][], double[]> backward = ForwardBackward.doPassScaling(lattice, true, viterbi, d);
                final double[][] betas = backward.getFirst();
                final double[] betaScales = backward.getSecond();
                nodeMarginals.incrementExpectedCounts(alphas, alphaScales, betas, betaScales, d, viterbi);
                if (edgeMarginals != null) {
                    edgeMarginals.incrementExpectedCounts(alphas, alphaScales, betas, betaScales, d);
                }
            }
        }, numThreads);
        for (int d = 0; d < lattice.numSequences(); d++) {
            betterThreader.addFunctionArgument(Integer.valueOf(d));
        }
        betterThreader.run();
        return Pair.makePair(nodeMarginals, edgeMarginals);
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [int[], int[][]] */
    public static int[][] computeViterbiPathsScaling(final Lattice lattice, int i) {
        final ?? r0 = new int[lattice.numSequences()];
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.sequence.ForwardBackward.5
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                r0[num.intValue()] = ForwardBackward.extractViterbiPath(Lattice.this, null, (double[][]) ForwardBackward.doPassScaling(Lattice.this, false, true, num.intValue()).getFirst(), true, num.intValue());
            }
        }, i);
        for (int i2 = 0; i2 < lattice.numSequences(); i2++) {
            betterThreader.addFunctionArgument(Integer.valueOf(i2));
        }
        betterThreader.run();
        return r0;
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [int[], int[][]] */
    public static int[][] computeViterbiPathsScaling(final Lattice lattice, final StationaryStateProjector stationaryStateProjector, int i) {
        final ?? r0 = new int[lattice.numSequences()];
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.sequence.ForwardBackward.6
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                r0[num.intValue()] = ForwardBackward.extractViterbiPath(Lattice.this, stationaryStateProjector, (double[][]) ForwardBackward.doPassScaling(Lattice.this, false, true, num.intValue()).getFirst(), true, num.intValue());
            }
        }, i);
        for (int i2 = 0; i2 < lattice.numSequences(); i2++) {
            betterThreader.addFunctionArgument(Integer.valueOf(i2));
        }
        betterThreader.run();
        return r0;
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [int[], int[][]] */
    public static int[][] computeViterbiPathsLogSpace(final Lattice lattice, int i) {
        final ?? r0 = new int[lattice.numSequences()];
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.sequence.ForwardBackward.7
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                r0[num.intValue()] = ForwardBackward.extractViterbiPath(Lattice.this, null, ForwardBackward.doPassLogSpace(Lattice.this, false, true, num.intValue()), false, num.intValue());
            }
        }, i);
        for (int i2 = 0; i2 < lattice.numSequences(); i2++) {
            betterThreader.addFunctionArgument(Integer.valueOf(i2));
        }
        betterThreader.run();
        return r0;
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [int[], int[][]] */
    public static int[][] computeViterbiPathsLogSpace(final Lattice lattice, final StationaryStateProjector stationaryStateProjector, int i) {
        final ?? r0 = new int[lattice.numSequences()];
        BetterThreader betterThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() { // from class: tberg.murphy.sequence.ForwardBackward.8
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                r0[num.intValue()] = ForwardBackward.extractViterbiPath(Lattice.this, stationaryStateProjector, ForwardBackward.doPassLogSpace(Lattice.this, false, true, num.intValue()), false, num.intValue());
            }
        }, i);
        for (int i2 = 0; i2 < lattice.numSequences(); i2++) {
            betterThreader.addFunctionArgument(Integer.valueOf(i2));
        }
        betterThreader.run();
        return r0;
    }

    /**
     * Backtracks the best state path of sequence {@code d} from forward max-product/max-sum scores.
     *
     * <p>The last state is the argmax of the final column. Each earlier state is the predecessor
     * maximizing {@code score * edgePotential} (scaling) or {@code score + edgeLogPotential} (log
     * space). Ties keep the first predecessor found. The decompiler notes this method was originally
     * private.
     *
     * @param lattice the lattice that produced {@code scores}
     * @param stationaryStateProjector if non-null, every path state is mapped through it
     * @param scores forward Viterbi scores, indexed [position][state]
     * @param scaling true if {@code scores} are linear (scaled) values, false if log values
     * @param d index of the sequence
     * @return the best path, one state per position (empty for an empty sequence)
     */
    public static int[] extractViterbiPath(Lattice lattice, StationaryStateProjector stationaryStateProjector, double[][] scores, boolean scaling, int d) {
        final int length = lattice.sequenceLength(d);
        final int[] path = new int[length];
        if (length == 0) {
            // Nothing to decode; previously this indexed path[-1].
            return path;
        }
        path[length - 1] = a.argmax(scores[length - 1]);
        for (int t = length - 2; t >= 0; t--) {
            final int next = path[t + 1];
            // Backward edges of `next` at t + 1 are its predecessor states at t.
            final int[] predecessors = lattice.allowedEdges(d, t + 1, next, true);
            final double[] weights = scaling
                    ? lattice.allowedEdgesPotentials(d, t + 1, next, true)
                    : lattice.allowedEdgesLogPotentials(d, t + 1, next, true);
            int best = -1;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int e = 0; e < predecessors.length; e++) {
                final int prev = predecessors[e];
                final double candidate = scaling ? scores[t][prev] * weights[e] : scores[t][prev] + weights[e];
                if (candidate > bestScore) {
                    bestScore = candidate;
                    best = prev;
                }
            }
            path[t] = best;
        }
        if (stationaryStateProjector != null) {
            for (int t = 0; t < length; t++) {
                path[t] = stationaryStateProjector.project(d, t, path[t]);
            }
        }
        return path;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v2, types: [double[], double[][]] */
    public static double[][] doPassLogSpace(Lattice lattice, boolean z, boolean z2, int i) {
        ?? r0 = new double[lattice.sequenceLength(i)];
        for (int i2 = 0; i2 < lattice.sequenceLength(i); i2++) {
            r0[i2] = new double[lattice.numStates(i, i2)];
        }
        int[] enumerate = z ? a.enumerate(lattice.sequenceLength(i), 0) : a.enumerate(0, lattice.sequenceLength(i));
        for (int i3 = 0; i3 < enumerate.length; i3++) {
            int i4 = enumerate[i3];
            Arrays.fill(r0[i4], Double.NEGATIVE_INFINITY);
            if (i3 == 0) {
                int numStates = lattice.numStates(i, i4);
                for (int i5 = 0; i5 < numStates; i5++) {
                    r0[i4][i5] = lattice.nodeLogPotential(i, i4, i5);
                }
            } else {
                int i6 = enumerate[i3 - 1];
                int numStates2 = lattice.numStates(i, i6);
                for (int i7 = 0; i7 < numStates2; i7++) {
                    long j = r0[i6][i7];
                    int[] allowedEdges = lattice.allowedEdges(i, i6, i7, z);
                    double[] allowedEdgesLogPotentials = lattice.allowedEdgesLogPotentials(i, i6, i7, z);
                    double[] dArr = r0[i4];
                    for (int i8 = 0; i8 < allowedEdges.length; i8++) {
                        int i9 = allowedEdges[i8];
                        double d = allowedEdgesLogPotentials[i8];
                        if (z2) {
                            dArr[i9] = Math.max(dArr[i9], j + d);
                        } else {
                            dArr[i9] = m.logAdd(dArr[i9], j + d);
                        }
                    }
                }
                for (int i10 = 0; i10 < lattice.numStates(i, i4); i10++) {
                    double[] dArr2 = r0[i4];
                    int i11 = i10;
                    dArr2[i11] = dArr2[i11] + lattice.nodeLogPotential(i, i4, i10);
                }
            }
        }
        return r0;
    }

    /**
     * Runs one scaled linear-space dynamic-programming pass over sequence {@code d}.
     *
     * <p>Positions are visited in the order given by {@code a.enumerate}: reversed when
     * {@code backward} is true. After each position is computed, its column is rescaled by a power of
     * {@link #SCALE} so that its maximum lies in {@code [1/SCALE, SCALE]} (when positive). The true
     * value is {@code scores[t][s] * SCALE^exponents[t]}, where {@code exponents[t]} accumulates the
     * rescalings along the visit order. The decompiler notes this method was originally private.
     *
     * @param lattice the lattice to run on
     * @param backward direction flag, also passed to the lattice's edge queries
     * @param viterbi if true combine with max, otherwise with sum
     * @param d index of the sequence
     * @return (scores indexed [position][state], cumulative scale exponents indexed [position])
     */
    public static Pair<double[][], double[]> doPassScaling(Lattice lattice, boolean backward, boolean viterbi, int d) {
        final int length = lattice.sequenceLength(d);
        final double[] exponents = new double[length];
        final double[][] scores = new double[length][];
        for (int t = 0; t < length; t++) {
            scores[t] = new double[lattice.numStates(d, t)];
        }
        final int[] order = backward ? a.enumerate(length, 0) : a.enumerate(0, length);
        for (int k = 0; k < order.length; k++) {
            final int t = order[k];
            final double[] current = scores[t];
            Arrays.fill(current, 0.0d);
            final int numStates = lattice.numStates(d, t);
            double max = Double.NEGATIVE_INFINITY;
            if (k == 0) {
                for (int s = 0; s < numStates; s++) {
                    final double potential = lattice.nodePotential(d, t, s);
                    current[s] = potential;
                    if (potential > max) {
                        max = potential;
                    }
                }
            } else {
                final int prevT = order[k - 1];
                final int numPrevStates = lattice.numStates(d, prevT);
                for (int prevS = 0; prevS < numPrevStates; prevS++) {
                    // Must be double: the decompiled `long` would not compile (and would truncate).
                    final double prevScore = scores[prevT][prevS];
                    final int[] edges = lattice.allowedEdges(d, prevT, prevS, backward);
                    final double[] edgePotentials = lattice.allowedEdgesPotentials(d, prevT, prevS, backward);
                    for (int e = 0; e < edges.length; e++) {
                        final int s = edges[e];
                        final double candidate = prevScore * edgePotentials[e];
                        current[s] = viterbi ? Math.max(current[s], candidate) : current[s] + candidate;
                    }
                }
                for (int s = 0; s < numStates; s++) {
                    current[s] *= lattice.nodePotential(d, t, s);
                    if (current[s] > max) {
                        max = current[s];
                    }
                }
            }
            int exponent = 0;
            double factor = 1.0d;
            // Guard against +Infinity: Infinity / SCALE == Infinity, so the loop would never end.
            while (max > SCALE && !Double.isInfinite(max)) {
                max /= SCALE;
                factor *= SCALE;
                exponent++;
            }
            while (max > 0.0d && max < INVSCALE) {
                max *= SCALE;
                factor /= SCALE;
                exponent--;
            }
            if (exponent != 0) {
                for (int s = 0; s < numStates; s++) {
                    current[s] /= factor;
                }
            }
            exponents[t] = (k == 0) ? exponent : exponents[order[k - 1]] + exponent;
        }
        return Pair.makePair(scores, exponents);
    }
}
