package tberg.murphy.sequence;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import tberg.murphy.arrays.a;
import tberg.murphy.tuple.Pair;
import tberg.murphy.util.GeneralPriorityQueue;

/* loaded from: input_file:lib/murphy.jar:tberg/murphy/sequence/SparseSemiMarkovDP.class */
/**
 * Beam-pruned best-path decoding for a semi-Markov model whose transition structure is given
 * sparsely: each {@link TransitionState} enumerates its own successors, and each state emits a
 * segment whose allowed widths and scores come from an {@link EmissionModel}.
 *
 * <p>The forward pass keeps, for every sequence position, a beam of at most {@code beamSize}
 * transition states that end a segment at that position, each with its best score and a
 * back-pointer. The best complete path is then recovered by following back-pointers from the
 * final position.
 */
public class SparseSemiMarkovDP {

    /**
     * Beam entry: a transition state plus the best score found for a segment of that state ending
     * at the beam's position, and a back-pointer describing how that score was reached.
     *
     * <p>Equality and hashing delegate to the wrapped transition state, so beam lookups
     * ({@code containsKey}/{@code getObject}) identify entries by transition state alone.
     */
    public static class BeamState {
        private final TransitionState transState;

        /** Best accumulated score for this state at its position; negative infinity until set. */
        public double score = Double.NEGATIVE_INFINITY;

        /**
         * (start position of the segment ending here, predecessor transition state). The
         * predecessor is {@code null} when the segment was emitted directly from a start state.
         */
        public Pair<Integer, TransitionState> backPointer = null;

        public BeamState(TransitionState transitionState) {
            this.transState = transitionState;
        }

        @Override
        public int hashCode() {
            return transState.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            return obj instanceof BeamState && transState.equals(((BeamState) obj).transState);
        }
    }

    /** Scores segments of the sequence being decoded. */
    public interface EmissionModel {
        /** Number of positions in the sequence being decoded. */
        int sequenceLength();

        /** Segment widths the given emission state may produce. */
        int[] allowedWidths(int emissionStateIndex);

        /**
         * Score of emission state {@code emissionStateIndex} producing the segment that starts at
         * {@code start} and spans {@code width} positions. Negative infinity forbids the segment.
         */
        double score(int start, int emissionStateIndex, int width);
    }

    /** Supplies the states a path may begin in. */
    public interface TransitionModel {
        /** Start states paired with their start scores (negative infinity disables a state). */
        Collection<Pair<TransitionState, Double>> startStates();
    }

    /** A node of the sparse transition structure. */
    public interface TransitionState {
        /** Index of the emission state passed to {@link EmissionModel} for this state's segment. */
        int getEmissionStateIndex();

        /** Successor states paired with the score added when transitioning to them. */
        Collection<Pair<TransitionState, Double>> forwardTransitions();

        /** Score added when a path ends in this state at the last sequence position. */
        double endScore();
    }

    /**
     * Finds the highest-scoring segmentation that survives beam pruning.
     *
     * @param emissionModel segment scores and allowed widths
     * @param transitionModel start states and their scores
     * @param beamSize maximum number of transition states kept per sequence position; must be >= 1
     * @return pair of (transition state of each segment, width of each segment), in sequence order
     * @throws IllegalArgumentException if {@code beamSize < 1}
     * @throws IllegalStateException if no finite-scoring path covers the whole sequence
     */
    public static Pair<TransitionState[], int[]> decode(EmissionModel emissionModel, TransitionModel transitionModel, int beamSize) {
        if (beamSize < 1) {
            // A non-positive beam evicts every state, so no path could ever be recovered.
            throw new IllegalArgumentException("beamSize must be at least 1, got " + beamSize);
        }
        return followBackpointers(doForwardPass(emissionModel, transitionModel, beamSize), emissionModel);
    }

    /**
     * Fills one beam per position 0..sequenceLength. Beam {@code p} holds states whose segment
     * ends at position {@code p}. Beams are processed left to right, so a beam is complete before
     * it is expanded.
     */
    private static GeneralPriorityQueue<BeamState>[] doForwardPass(EmissionModel emissionModel, TransitionModel transitionModel, int beamSize) {
        int length = emissionModel.sequenceLength();
        @SuppressWarnings("unchecked") // generic array creation; every slot is filled below
        GeneralPriorityQueue<BeamState>[] beams = new GeneralPriorityQueue[length + 1];
        for (int pos = 0; pos <= length; pos++) {
            beams[pos] = new GeneralPriorityQueue<>();
        }

        // Position 0: segments are emitted directly from the start states.
        for (Pair<TransitionState, Double> start : transitionModel.startStates()) {
            double startScore = start.getSecond().doubleValue();
            if (startScore != Double.NEGATIVE_INFINITY) {
                extendSegments(beams, emissionModel, 0, start.getFirst(), startScore, null, beamSize);
            }
        }

        // Later positions: extend every surviving state through each of its forward transitions.
        for (int pos = 1; pos <= length; pos++) {
            for (BeamState prev : beams[pos].getObjects()) {
                for (Pair<TransitionState, Double> transition : prev.transState.forwardTransitions()) {
                    double baseScore = prev.score + transition.getSecond().doubleValue();
                    extendSegments(beams, emissionModel, pos, transition.getFirst(), baseScore, prev.transState, beamSize);
                }
            }
        }
        return beams;
    }

    /**
     * Tries every allowed width for a segment of {@code next} starting at {@code start}. Each
     * finite-scoring segment that fits in the sequence is offered to the beam at its end position.
     *
     * <p>NOTE(review): widths are used as given. This presumably assumes they are positive, since
     * a zero width would add to the beam currently being iterated.
     */
    private static void extendSegments(GeneralPriorityQueue<BeamState>[] beams, EmissionModel emissionModel, int start, TransitionState next, double baseScore, TransitionState predecessor, int beamSize) {
        int emissionIndex = next.getEmissionStateIndex();
        for (int width : emissionModel.allowedWidths(emissionIndex)) {
            int end = start + width;
            if (end < beams.length) {
                double score = baseScore + emissionModel.score(start, emissionIndex, width);
                if (score != Double.NEGATIVE_INFINITY) {
                    addToBeam(beams[end], next, score, Pair.makePair(Integer.valueOf(start), predecessor), beamSize);
                }
            }
        }
    }

    /**
     * Offers {@code state} with {@code score} to {@code beam} and trims the beam to
     * {@code beamSize}.
     *
     * <p>Priorities are negated scores. The trim loop evicts via {@code removeFirst()}, so the
     * queue head is treated as the lowest-scoring entry. NOTE(review): this assumes
     * GeneralPriorityQueue's head is its highest-priority element; confirm against that class.
     */
    private static void addToBeam(GeneralPriorityQueue<BeamState> beam, TransitionState state, double score, Pair<Integer, TransitionState> backPointer, int beamSize) {
        double priority = -score;
        // Admit while the beam has room; once full, only admit candidates that beat the worst entry.
        boolean admit = beam.isEmpty() || beam.size() < beamSize || priority < beam.getPriority();
        if (!admit) {
            return;
        }

        BeamState key = new BeamState(state);
        if (beam.containsKey(key)) {
            if (!(beam.getObject(key).score < score)) {
                // Existing entry is at least as good. Leave its priority and back-pointer intact
                // so the stored priority keeps matching the stored score.
                return;
            }
            beam.decreasePriority(key, priority);
        } else {
            beam.setPriority(key, priority);
        }

        BeamState stored = beam.getObject(key);
        if (stored.score < score) {
            stored.score = score;
            stored.backPointer = backPointer;
        }
        while (beam.size() > beamSize) {
            beam.removeFirst();
        }
    }

    /**
     * Picks the best final state (beam score plus end score) and walks back-pointers to position 0.
     *
     * @throws IllegalStateException if no state with a finite total score ends at the last position
     */
    private static Pair<TransitionState[], int[]> followBackpointers(GeneralPriorityQueue<BeamState>[] beams, EmissionModel emissionModel) {
        int length = emissionModel.sequenceLength();

        TransitionState bestFinal = null;
        double bestTotal = Double.NEGATIVE_INFINITY;
        for (BeamState candidate : beams[length].getObjects()) {
            double total = candidate.score + candidate.transState.endScore();
            if (total > bestTotal) {
                bestTotal = total;
                bestFinal = candidate.transState;
            }
        }
        if (bestFinal == null) {
            // Without this check the lookup below fails with a NullPointerException in BeamState.hashCode().
            throw new IllegalStateException("No path with a finite score reaches the end of the sequence (length " + length + ")");
        }

        List<TransitionState> states = new ArrayList<>();
        List<Integer> widths = new ArrayList<>();
        int end = length;
        TransitionState current = bestFinal;
        do {
            Pair<Integer, TransitionState> back = beams[end].getObject(new BeamState(current)).backPointer;
            int start = back.getFirst().intValue();
            states.add(current);
            widths.add(Integer.valueOf(end - start));
            end = start;
            current = back.getSecond();
        } while (end != 0);

        // Back-pointers were followed from the end, so restore sequence order.
        Collections.reverse(states);
        Collections.reverse(widths);
        int[] widthArray = new int[widths.size()];
        for (int k = 0; k < widthArray.length; k++) {
            widthArray[k] = widths.get(k).intValue();
        }
        return Pair.makePair(states.toArray(new TransitionState[0]), widthArray);
    }
}
