package edu.berkeley.cs.nlp.ocular.main;

import edu.berkeley.cs.nlp.ocular.data.Document;
import edu.berkeley.cs.nlp.ocular.data.TextAndLineImagesLoader;
import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import edu.berkeley.cs.nlp.ocular.eval.Evaluator;
import edu.berkeley.cs.nlp.ocular.font.Font;
import edu.berkeley.cs.nlp.ocular.image.ImageUtils;
import edu.berkeley.cs.nlp.ocular.image.Visualizer;
import edu.berkeley.cs.nlp.ocular.lm.SingleLanguageModel;
import edu.berkeley.cs.nlp.ocular.model.CharacterTemplate;
import edu.berkeley.cs.nlp.ocular.model.em.BeamingSemiMarkovDP;
import edu.berkeley.cs.nlp.ocular.model.em.CUDAInnerLoop;
import edu.berkeley.cs.nlp.ocular.model.em.DefaultInnerLoop;
import edu.berkeley.cs.nlp.ocular.model.em.DenseBigramTransitionModel;
import edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop;
import edu.berkeley.cs.nlp.ocular.model.em.JOCLInnerLoop;
import edu.berkeley.cs.nlp.ocular.model.emission.CachingEmissionModel;
import edu.berkeley.cs.nlp.ocular.model.emission.CachingEmissionModelExplicitOffset;
import edu.berkeley.cs.nlp.ocular.model.emission.EmissionModel;
import edu.berkeley.cs.nlp.ocular.model.transition.CharacterNgramTransitionModel;
import edu.berkeley.cs.nlp.ocular.model.transition.CharacterNgramTransitionModelMarkovOffset;
import edu.berkeley.cs.nlp.ocular.model.transition.SparseTransitionModel;
import edu.berkeley.cs.nlp.ocular.util.Tuple2;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import tberg.murphy.arrays.a;
import tberg.murphy.fig.Execution;
import tberg.murphy.fig.Option;
import tberg.murphy.fileio.f;
import tberg.murphy.indexer.Indexer;
import tberg.murphy.threading.BetterThreader;

/* loaded from: input_file:main/ocular_2.12-0.3-SNAPSHOT.jar:edu/berkeley/cs/nlp/ocular/main/ExperimentsMain.class */
public class ExperimentsMain implements Runnable {

    // ---------------------------------------------------------------------
    // Command-line options. Each field is annotated with @Option and the
    // instance is handed to Execution.run(...) in main(); the descriptions
    // below are derived from how the fields are used in this class.
    // ---------------------------------------------------------------------

    /** Input path handed to {@code TextAndLineImagesLoader.loadDocuments} to locate the documents. */
    @Option(gloss = "")
    public static String inputPath = "/Users/tberg/Dropbox/ocr_data/old_bailey_test_list.txt";

    /** Path, relative to {@code Execution.getVirtualExecDir()}, under which all output files are written. */
    @Option(gloss = "")
    public static String outputRelPath = "";

    /** Path of the initial font, read with {@code InitializeFont.readFont} once per document. */
    @Option(gloss = "")
    public static String fontPath = "/Users/tberg/Dropbox/corpora/ocr_data/fonts/init.fontser";

    /** Directory holding the language-model files; joined with {@link #lmBaseName} using "/". */
    @Option(gloss = "")
    public static String lmDir = "/Users/tberg/Dropbox/corpora/ocr_data/lms/";

    /**
     * Base file name of the language model. The file read is {@code <lmBaseName>.lmser}, or
     * {@code <lmBaseName>_longs.lmser} for documents whose {@code useLongS()} returns true.
     */
    @Option(gloss = "")
    public static String lmBaseName = "nyt";

    /** Minimum padding width passed to the emission model constructor. */
    @Option(gloss = "")
    public static int paddingMinWidth = 1;

    /** Maximum padding width passed to the emission model constructor. */
    @Option(gloss = "")
    public static int paddingMaxWidth = 5;

    /**
     * When true, {@code CachingEmissionModelExplicitOffset} and
     * {@code CharacterNgramTransitionModelMarkovOffset} are used; otherwise
     * {@code CachingEmissionModel} and {@code CharacterNgramTransitionModel}.
     */
    @Option(gloss = "")
    public static boolean markovVerticalOffset = true;

    /** Beam size passed to {@code BeamingSemiMarkovDP.decode}. */
    @Option(gloss = "")
    public static int beamSize = 10;

    /**
     * Number of decode iterations run per document. Template parameters are re-estimated after
     * every iteration except the last one.
     */
    @Option(gloss = "")
    public static int numEMIters = 4;

    /** Selects which {@code EmissionCacheInnerLoop} implementation is constructed. */
    @Option(gloss = "")
    public static EmissionCacheInnerLoopType emissionEngine = EmissionCacheInnerLoopType.DEFAULT;

    /** Device id passed to {@code CUDAInnerLoop}; only read when {@link #emissionEngine} is CUDA. */
    @Option(gloss = "")
    public static int cudaDeviceID = 0;

    /** Thread count given to the {@code BetterThreader} instances that accumulate counts and update parameters. */
    @Option(gloss = "")
    public static int numMstepThreads = 8;

    /** Thread count passed to the selected {@code EmissionCacheInnerLoop} constructor. */
    @Option(gloss = "")
    public static int numEmissionCacheThreads = 8;

    /** Thread count passed to {@code BeamingSemiMarkovDP.decode}. */
    @Option(gloss = "")
    public static int numDecodeThreads = 4;

    // NOTE: printTranscription produces no output at all (not even the text
    // transcription) unless at least one of the three flags below is true.

    /** When true, the rendered images are shown with {@code ImageUtils.display} after each decode. */
    @Option(gloss = "")
    public static boolean popupVisuals = false;

    /** When true, overlay/alphabet/original/probs PNG images are written after each decode. */
    @Option(gloss = "")
    public static boolean writeVisuals = false;

    /**
     * When true and the document provides diplomatic text lines, each transcription is scored
     * against them and a combined report is printed at the end of the run.
     */
    @Option(gloss = "")
    public static boolean evaluate = false;

    /* loaded from: input_file:main/ocular_2.12-0.3-SNAPSHOT.jar:edu/berkeley/cs/nlp/ocular/main/ExperimentsMain$EmissionCacheInnerLoopType.class */
    /**
     * Choice of emission-cache inner loop implementation, selected through the
     * {@code emissionEngine} option and resolved in {@code run()}.
     */
    public enum EmissionCacheInnerLoopType {
        /** {@code run()} constructs a {@code DefaultInnerLoop}. */
        DEFAULT,
        /** {@code run()} constructs a {@code JOCLInnerLoop}. */
        OPENCL,
        /** {@code run()} constructs a {@code CUDAInnerLoop} using the {@code cudaDeviceID} option. */
        CUDA
    }

    /**
     * Program entry point: hands the command-line arguments and a fresh
     * {@code ExperimentsMain} instance to {@code Execution.run}.
     *
     * @param strArr command-line arguments, forwarded unchanged
     */
    public static void main(String[] strArr) {
        final ExperimentsMain experiment = new ExperimentsMain();
        Execution.run(strArr, experiment);
    }

    /**
     * Runs the experiment over every document found at {@link #inputPath}.
     *
     * <p>For each document the language model and the initial font are loaded, the emission
     * cache is built, and then {@link #numEMIters} iterations are executed. Every iteration
     * decodes the document (e-step) and prints the transcription; every iteration except the
     * last also re-estimates the character templates (m-step) and rebuilds the emission cache.
     * After all documents, the collected evaluations are reported when {@link #evaluate} is set,
     * followed by the accumulated emission-cache time and the overall wall-clock time.
     *
     * @throws NoDocumentsFoundException if no documents are loaded from {@link #inputPath}
     */
    @Override
    public void run() {
        final long startNanos = System.nanoTime();
        long emissionCacheNanos = 0;

        // Pick the emission-cache engine. An unmatched (e.g. null) option value leaves this
        // null and it is passed on unchanged, exactly as before.
        EmissionCacheInnerLoop emissionCacheInnerLoop = null;
        if (emissionEngine == EmissionCacheInnerLoopType.DEFAULT) {
            emissionCacheInnerLoop = new DefaultInnerLoop(numEmissionCacheThreads);
        } else if (emissionEngine == EmissionCacheInnerLoopType.OPENCL) {
            emissionCacheInnerLoop = new JOCLInnerLoop(numEmissionCacheThreads);
        } else if (emissionEngine == EmissionCacheInnerLoopType.CUDA) {
            emissionCacheInnerLoop = new CUDAInnerLoop(numEmissionCacheThreads, cudaDeviceID);
        }

        // Final-iteration evaluations, one entry per document (filled by printTranscription).
        List<Tuple2<String, Map<String, Evaluator.EvalSuffStats>>> allEvals = new ArrayList<>();

        // NOTE(review): the meaning of the constant 30 is defined by loadDocuments and is not
        // visible here (presumably a line height) - confirm before changing.
        List<Document> documents = TextAndLineImagesLoader.loadDocuments(inputPath, 30);
        if (documents.isEmpty()) {
            throw new NoDocumentsFoundException();
        }

        for (Document document : documents) {
            System.out.println("Loading LM..");
            boolean useLongS = ((TextAndLineImagesLoader.TextAndLineImagesDocument) document).useLongS();
            String lmSuffix = useLongS ? "_longs.lmser" : ".lmser";
            SingleLanguageModel lm = LMTrainMain.readLM(lmDir + "/" + lmBaseName + lmSuffix);
            Indexer<String> charIndexer = lm.getCharacterIndexer();

            // The font is re-read for every document, so each document starts from the
            // initial templates rather than from the previous document's estimates.
            System.out.println("Loading font initializer..");
            Font font = InitializeFont.readFont(fontPath);
            final CharacterTemplate[] templates = new CharacterTemplate[charIndexer.size()];
            for (int c = 0; c < templates.length; c++) {
                templates[c] = font.get(charIndexer.getObject(c));
            }
            System.out.println("Characters: " + charIndexer.getObjects());
            System.out.println("Num characters: " + charIndexer.size());

            ImageUtils.PixelType[][][] lineImages = document.loadLineImages();
            String[][] goldTextLines = document.loadDiplomaticTextLines();

            final EmissionModel emissionModel = markovVerticalOffset
                    ? new CachingEmissionModelExplicitOffset(templates, charIndexer, lineImages, paddingMinWidth, paddingMaxWidth, emissionCacheInnerLoop)
                    : new CachingEmissionModel(templates, charIndexer, lineImages, paddingMinWidth, paddingMaxWidth, emissionCacheInnerLoop);
            SparseTransitionModel forwardTransitionModel = markovVerticalOffset
                    ? new CharacterNgramTransitionModelMarkovOffset(lm)
                    : new CharacterNgramTransitionModel(lm);
            DenseBigramTransitionModel backwardTransitionModel = new DenseBigramTransitionModel(lm);

            long cacheStart = System.nanoTime();
            emissionModel.rebuildCache();
            emissionCacheNanos += System.nanoTime() - cacheStart;

            for (int iter = 0; iter < numEMIters; iter++) {
                // ---- e-step: decode the document with the current templates ----
                System.out.println("Iteration " + iter + " e-step");
                long decodeStart = System.nanoTime();
                BeamingSemiMarkovDP dp = new BeamingSemiMarkovDP(emissionModel, forwardTransitionModel, backwardTransitionModel);
                Tuple2<Tuple2<SparseTransitionModel.TransitionState[][], int[][]>, Double> decoded = dp.decode(beamSize, numDecodeThreads);
                double score = decoded._2.doubleValue();
                final SparseTransitionModel.TransitionState[][] decodeStates = decoded._1._1;
                final int[][] decodeWidths = decoded._1._2;
                System.out.println("Compute marginals and decode: " + ((System.nanoTime() - decodeStart) / 1000000) + "ms");

                // Kept as in the original: three explicit GC requests after decoding.
                System.gc();
                System.gc();
                System.gc();

                System.out.println("Iteration " + iter + ": " + score);
                printTranscription(iter, document, allEvals, lineImages, goldTextLines, decodeStates, decodeWidths, charIndexer, templates, emissionModel);

                // ---- m-step: skipped after the final decode ----
                if (iter < numEMIters - 1) {
                    long updateStart = System.nanoTime();

                    // Templates may be null for characters the font does not provide.
                    for (int c = 0; c < templates.length; c++) {
                        if (templates[c] != null) {
                            templates[c].clearCounts();
                        }
                    }

                    // Accumulate counts, one task per sequence of the emission model.
                    BetterThreader countThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() {
                        @Override
                        public void call(Integer num, Object obj) {
                            int line = num.intValue();
                            emissionModel.incrementCounts(line, decodeStates[line], decodeWidths[line]);
                        }
                    }, numMstepThreads);
                    for (int line = 0; line < emissionModel.numSequences(); line++) {
                        countThreader.addFunctionArgument(Integer.valueOf(line));
                    }
                    countThreader.run();

                    // Update parameters, one task per character template.
                    BetterThreader updateThreader = new BetterThreader(new BetterThreader.Function<Integer, Object>() {
                        @Override
                        public void call(Integer num, Object obj) {
                            CharacterTemplate template = templates[num.intValue()];
                            if (template != null) {
                                template.updateParameters();
                            }
                        }
                    }, numMstepThreads);
                    for (int c = 0; c < templates.length; c++) {
                        updateThreader.addFunctionArgument(Integer.valueOf(c));
                    }
                    updateThreader.run();
                    System.out.println("Update parameters: " + ((System.nanoTime() - updateStart) / 1000000) + "ms");

                    // The templates changed, so the cached emissions must be recomputed.
                    long rebuildStart = System.nanoTime();
                    emissionModel.rebuildCache();
                    emissionCacheNanos += System.nanoTime() - rebuildStart;
                }
            }
        }

        if (!allEvals.isEmpty() && evaluate) {
            printEvaluation(allEvals);
        }
        System.out.println("Emission cache time: " + (emissionCacheNanos / 1.0E9d) + "s");
        System.out.println("Overall time: " + ((System.nanoTime() - startNanos) / 1.0E9d) + "s");
    }

    /**
     * Builds a report containing every document's evaluation followed by a total obtained by
     * incrementing one {@code EvalSuffStats} per metric key with each document's statistics.
     * The report is written to {@code out.txt} under
     * {@code Execution.getVirtualExecDir()/outputRelPath} and echoed to standard output.
     *
     * @param list one (document name, metric key to statistics) pair per evaluated document
     */
    public static void printEvaluation(List<Tuple2<String, Map<String, Evaluator.EvalSuffStats>>> list) {
        // Running totals per metric key, created lazily on first sight of a key.
        Map<String, Evaluator.EvalSuffStats> totals = new HashMap<>();
        StringBuilder report = new StringBuilder();
        report.append("All evals:\n");
        for (Tuple2<String, Map<String, Evaluator.EvalSuffStats>> docEval : list) {
            String docName = docEval._1;
            Map<String, Evaluator.EvalSuffStats> docStats = docEval._2;
            report.append("Document: " + docName + "\n");
            report.append(Evaluator.renderEval(docStats) + "\n");
            for (Map.Entry<String, Evaluator.EvalSuffStats> entry : docStats.entrySet()) {
                Evaluator.EvalSuffStats total = totals.get(entry.getKey());
                if (total == null) {
                    total = new Evaluator.EvalSuffStats();
                    totals.put(entry.getKey(), total);
                }
                total.increment(entry.getValue());
            }
        }
        // NOTE(review): "Marco-avg" looks like a typo for "Macro-avg"; the text is left
        // untouched because it is part of the emitted report and may be parsed downstream.
        report.append("\nMarco-avg total eval:\n");
        report.append(Evaluator.renderEval(totals) + "\n");

        String reportText = report.toString();
        f.writeString(Execution.getVirtualExecDir() + "/" + outputRelPath + "/out.txt", reportText);
        System.out.println();
        System.out.println(reportText);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v14, types: [double[][], double[][][]] */
    private static void printTranscription(int i, Document document, List<Tuple2<String, Map<String, Evaluator.EvalSuffStats>>> list, ImageUtils.PixelType[][][] pixelTypeArr, String[][] strArr, SparseTransitionModel.TransitionState[][] transitionStateArr, int[][] iArr, Indexer<String> indexer, CharacterTemplate[] characterTemplateArr, EmissionModel emissionModel) {
        if (evaluate || writeVisuals || popupVisuals) {
            List[] listArr = new List[pixelTypeArr.length];
            List[] listArr2 = new List[pixelTypeArr.length];
            List[] listArr3 = new List[pixelTypeArr.length];
            for (int i2 = 0; i2 < transitionStateArr.length; i2++) {
                listArr[i2] = new ArrayList();
                listArr2[i2] = new ArrayList();
                listArr3[i2] = new ArrayList();
                int i3 = 0;
                for (int i4 = 0; i4 < transitionStateArr[i2].length; i4++) {
                    int i5 = transitionStateArr[i2][i4].getGlyphChar().templateCharIndex;
                    int i6 = iArr[i2][i4];
                    int exposure = emissionModel.getExposure(i2, i3, transitionStateArr[i2][i4], i6);
                    int offset = emissionModel.getOffset(i2, i3, transitionStateArr[i2][i4], i6);
                    int padWidth = emissionModel.getPadWidth(i2, i3, transitionStateArr[i2][i4], i6);
                    if (i5 == indexer.getIndex(Charset.SPACE)) {
                        for (int i7 = i3; i7 < i3 + (i6 - padWidth); i7++) {
                            listArr[i2].add(Integer.valueOf(i7));
                        }
                    }
                    for (int i8 = i3 + (i6 - padWidth); i8 < i3 + i6; i8++) {
                        listArr[i2].add(Integer.valueOf(i8));
                    }
                    if (listArr2[i2].isEmpty() || !Charset.HYPHEN.equals(listArr2[i2].get(listArr2[i2].size() - 1)) || !Charset.HYPHEN.equals(indexer.getObject(i5))) {
                        listArr2[i2].add(indexer.getObject(i5));
                    }
                    for (double[] dArr : a.toDouble(characterTemplateArr[i5].blackProbs(exposure, offset, i6 - padWidth))) {
                        listArr3[i2].add(dArr);
                    }
                    for (double[] dArr2 : a.toDouble(characterTemplateArr[indexer.getIndex(Charset.SPACE)].blackProbs(exposure, offset, padWidth))) {
                        listArr3[i2].add(dArr2);
                    }
                    i3 += i6;
                }
            }
            ?? r0 = new double[pixelTypeArr.length];
            for (int i9 = 0; i9 < transitionStateArr.length; i9++) {
                r0[i9] = new double[listArr3[i9].size()];
                for (int i10 = 0; i10 < listArr3[i9].size(); i10++) {
                    r0[i9][i10] = (double[]) listArr3[i9].get(i10);
                }
            }
            ArrayList arrayList = new ArrayList();
            for (int i11 = 0; i11 < indexer.size(); i11++) {
                if (i11 != indexer.getIndex(Charset.SPACE)) {
                    int i12 = -1;
                    double d = Double.NEGATIVE_INFINITY;
                    for (int templateMinWidth = characterTemplateArr[i11].templateMinWidth(); templateMinWidth <= characterTemplateArr[i11].templateMaxWidth(); templateMinWidth++) {
                        double widthLogProb = characterTemplateArr[i11].widthLogProb(templateMinWidth);
                        if (widthLogProb >= d) {
                            d = widthLogProb;
                            i12 = templateMinWidth;
                        }
                    }
                    double[][] dArr3 = a.toDouble(characterTemplateArr[i11].blackProbs(0, 0, i12));
                    for (double[] dArr4 : dArr3) {
                        arrayList.add(dArr4);
                    }
                    for (int i13 = 0; i13 < 5; i13++) {
                        arrayList.add(new double[dArr3[0].length]);
                    }
                }
            }
            double[][][] dArr5 = new double[1][arrayList.size()];
            for (int i14 = 0; i14 < arrayList.size(); i14++) {
                dArr5[0][i14] = (double[]) arrayList.get(i14);
            }
            if (strArr == null || !evaluate) {
                StringBuffer stringBuffer = new StringBuffer();
                for (List list2 : listArr2) {
                    Iterator it = list2.iterator();
                    while (it.hasNext()) {
                        stringBuffer.append((String) it.next());
                    }
                    stringBuffer.append("\n");
                    stringBuffer.append("\n");
                }
                System.out.println(stringBuffer.toString());
                new File(Execution.getVirtualExecDir() + "/" + outputRelPath + "/" + document.baseName()).mkdirs();
                f.writeString(Execution.getVirtualExecDir() + "/" + outputRelPath + "/" + document.baseName() + "/out-" + i + ".txt", stringBuffer.toString());
            } else {
                List[] listArr4 = new List[strArr.length];
                for (int i15 = 0; i15 < strArr.length; i15++) {
                    listArr4[i15] = new ArrayList();
                    for (int i16 = 0; i16 < strArr[i15].length; i16++) {
                        listArr4[i15].add(strArr[i15][i16]);
                    }
                }
                StringBuffer stringBuffer2 = new StringBuffer();
                for (int i17 = 0; i17 < listArr2.length; i17++) {
                    Iterator it2 = listArr2[i17].iterator();
                    while (it2.hasNext()) {
                        stringBuffer2.append((String) it2.next());
                    }
                    stringBuffer2.append("\n");
                    Iterator it3 = listArr4[i17].iterator();
                    while (it3.hasNext()) {
                        stringBuffer2.append((String) it3.next());
                    }
                    stringBuffer2.append("\n");
                    stringBuffer2.append("\n");
                }
                Map<String, Evaluator.EvalSuffStats> unsegmentedEval = Evaluator.getUnsegmentedEval(listArr2, listArr4, true);
                if (i == numEMIters - 1) {
                    list.add(Tuple2.Tuple2(document.baseName(), unsegmentedEval));
                }
                System.out.println(stringBuffer2.toString() + Evaluator.renderEval(unsegmentedEval));
                new File(Execution.getVirtualExecDir() + "/" + outputRelPath + "/" + document.baseName()).mkdirs();
                f.writeString(Execution.getVirtualExecDir() + "/" + outputRelPath + "/" + document.baseName() + "/out-" + i + ".txt", stringBuffer2.toString() + Evaluator.renderEval(unsegmentedEval));
            }
            if (writeVisuals) {
                new File(Execution.getVirtualExecDir() + "/" + outputRelPath + "/" + document.baseName()).mkdirs();
                f.writeImage(Execution.getVirtualExecDir() + "/" + outputRelPath + "/" + document.baseName() + "/overlay-" + i + ".png", Visualizer.renderOverlay(pixelTypeArr, r0, listArr));
                f.writeImage(Execution.getVirtualExecDir() + "/" + outputRelPath + "/" + document.baseName() + "/alphabet-" + i + ".png", Visualizer.renderBlackProbs(dArr5));
                f.writeImage(Execution.getVirtualExecDir() + "/" + outputRelPath + "/" + document.baseName() + "/original-" + i + ".png", Visualizer.renderObservations(pixelTypeArr));
                f.writeImage(Execution.getVirtualExecDir() + "/" + outputRelPath + "/" + document.baseName() + "/probs-" + i + ".png", Visualizer.renderBlackProbsAndSegmentation(r0, listArr));
            }
            if (popupVisuals) {
                ImageUtils.display(Visualizer.renderOverlay(pixelTypeArr, r0, listArr));
                ImageUtils.display(Visualizer.renderBlackProbs(dArr5));
                ImageUtils.display(Visualizer.renderObservations(pixelTypeArr));
                ImageUtils.display(Visualizer.renderBlackProbsAndSegmentation(r0, listArr));
            }
        }
    }
}
