package edu.berkeley.cs.nlp.ocular.train;

import edu.berkeley.cs.nlp.ocular.data.Document;
import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import edu.berkeley.cs.nlp.ocular.eval.EvalPrinter;
import edu.berkeley.cs.nlp.ocular.eval.Evaluator;
import edu.berkeley.cs.nlp.ocular.eval.ModelTranscriptions;
import edu.berkeley.cs.nlp.ocular.eval.MultiDocumentTranscriber;
import edu.berkeley.cs.nlp.ocular.eval.SingleDocumentEvaluatorAndOutputPrinter;
import edu.berkeley.cs.nlp.ocular.font.Font;
import edu.berkeley.cs.nlp.ocular.gsm.BasicGlyphSubstitutionModel;
import edu.berkeley.cs.nlp.ocular.gsm.GlyphSubstitutionModel;
import edu.berkeley.cs.nlp.ocular.lm.BasicCodeSwitchLanguageModel;
import edu.berkeley.cs.nlp.ocular.lm.CodeSwitchLanguageModel;
import edu.berkeley.cs.nlp.ocular.lm.CorpusCounter;
import edu.berkeley.cs.nlp.ocular.lm.InterpolatingSingleLanguageModel;
import edu.berkeley.cs.nlp.ocular.lm.NgramLanguageModel;
import edu.berkeley.cs.nlp.ocular.lm.SingleLanguageModel;
import edu.berkeley.cs.nlp.ocular.main.FonttrainTranscribeShared;
import edu.berkeley.cs.nlp.ocular.main.InitializeFont;
import edu.berkeley.cs.nlp.ocular.main.InitializeGlyphSubstitutionModel;
import edu.berkeley.cs.nlp.ocular.main.InitializeLanguageModel;
import edu.berkeley.cs.nlp.ocular.model.CharacterTemplate;
import edu.berkeley.cs.nlp.ocular.model.DecodeState;
import edu.berkeley.cs.nlp.ocular.model.DecoderEM;
import edu.berkeley.cs.nlp.ocular.model.TransitionStateType;
import edu.berkeley.cs.nlp.ocular.model.em.DenseBigramTransitionModel;
import edu.berkeley.cs.nlp.ocular.model.transition.SparseTransitionModel;
import edu.berkeley.cs.nlp.ocular.util.CollectionHelper;
import edu.berkeley.cs.nlp.ocular.util.StringHelper;
import edu.berkeley.cs.nlp.ocular.util.Tuple2;
import edu.berkeley.cs.nlp.ocular.util.Tuple3;
import java.io.File;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import tberg.murphy.indexer.Indexer;
import tberg.murphy.threading.BetterThreader;

/* loaded from: input_file:main/ocular_2.12-0.3-SNAPSHOT.jar:edu/berkeley/cs/nlp/ocular/train/FontTrainer.class */
/**
 * Drives iterative (EM-style) training of a {@link Font} over a list of documents,
 * optionally re-estimating the language model and the glyph substitution model after
 * every batch of documents.
 *
 * <p>{@link #trainFont} runs the requested number of iterations (always at least one) and
 * writes the final models; {@link #doFontTrainPass} performs one iteration over all
 * documents.
 */
public class FontTrainer {
    /** Pattern of the wall-clock timestamps appended to progress messages. */
    private static final String TIMESTAMP_PATTERN = "yyyy/MM/dd HH:mm:ss";

    /**
     * Trains the models for up to {@code i} iterations and writes the results.
     *
     * @param list documents to train on
     * @param font initial font
     * @param codeSwitchLanguageModel initial language model
     * @param glyphSubstitutionModel initial glyph substitution model
     * @param trainingRestarter may be {@code null}; when present it supplies the number of the
     *        last completed iteration and the models to resume from
     * @param str path the font is written to, or {@code null} to neither update nor write it
     * @param str2 path the language model is written to, or {@code null} to neither retrain nor write it
     * @param str3 path the glyph substitution model is written to, or {@code null} to neither re-estimate nor write it
     * @param decoderEM decoder used for the E-step
     * @param basicGlyphSubstitutionModelFactory collects counts for, and builds, new glyph substitution models
     * @param singleDocumentEvaluatorAndOutputPrinter evaluates and prints each document's transcription
     * @param i number of iterations (logged as {@code numEMIters}); one pass is run even if this is below 1
     * @param i2 batch size in documents (logged as {@code updateDocBatchSize})
     * @param z logged as {@code noUpdateIfBatchTooSmall}; see {@link #isBatchComplete}
     * @param z2 logged as {@code writeIntermediateModelsToTemp}; when true, per-batch models are written
     *        to paths produced by {@code ModelPathMaker} instead of the final output paths
     * @param i3 second constructor argument of the {@code BetterThreader} used for the font
     *        parameter update (presumably the thread count; confirm against BetterThreader)
     * @param str4 input path handed to the evaluator; aggregate evaluation files are only written
     *        when this names a directory
     * @param str5 output path under which intermediate models and evaluation files are placed
     * @param set output formats handed to the evaluator
     * @param multiDocumentTranscriber transcriber invoked on evaluation iterations
     * @param i4 evaluation frequency in iterations; 0 disables periodic evaluation
     * @param z3 whether to also run {@code multiDocumentTranscriber} after completed batches
     * @param z4 whether a document whose processing throws a {@link RuntimeException} is skipped
     *        instead of aborting training
     * @return the trained font, language model and glyph substitution model
     */
    public Tuple3<Font, CodeSwitchLanguageModel, GlyphSubstitutionModel> trainFont(List<Document> list, Font font, CodeSwitchLanguageModel codeSwitchLanguageModel, GlyphSubstitutionModel glyphSubstitutionModel, TrainingRestarter trainingRestarter, String str, String str2, String str3, DecoderEM decoderEM, BasicGlyphSubstitutionModel.BasicGlyphSubstitutionModelFactory basicGlyphSubstitutionModelFactory, SingleDocumentEvaluatorAndOutputPrinter singleDocumentEvaluatorAndOutputPrinter, int i, int i2, boolean z, boolean z2, int i3, String str4, String str5, Set<FonttrainTranscribeShared.OutputFormat> set, MultiDocumentTranscriber multiDocumentTranscriber, int i4, boolean z3, boolean z4) {
        System.out.println("trainFont(numEMIters=" + i + ", updateDocBatchSize=" + i2 + ", noUpdateIfBatchTooSmall=" + z + ", writeIntermediateModelsToTemp=" + z2 + ")");
        int numDocs = list.size();

        int lastCompletedIteration = 0;
        if (trainingRestarter != null) {
            Tuple2<Integer, Tuple3<Font, CodeSwitchLanguageModel, GlyphSubstitutionModel>> restart = trainingRestarter.getRestartModels(font, codeSwitchLanguageModel, glyphSubstitutionModel, str2 != null, str3 != null, str5, i, numDocs, i2, z);
            lastCompletedIteration = restart._1.intValue();
            font = restart._2._1;
            codeSwitchLanguageModel = restart._2._2;
            glyphSubstitutionModel = restart._2._3;
        }

        // "|| iter == 1" guarantees one pass even when fewer than one iteration is requested.
        for (int iter = lastCompletedIteration + 1; iter <= i || iter == 1; iter++) {
            System.out.println("Training iteration: " + iter + "    " + timestamp());
            Tuple3<Font, CodeSwitchLanguageModel, GlyphSubstitutionModel> updated = doFontTrainPass(iter, list, font, codeSwitchLanguageModel, glyphSubstitutionModel, str, str2, str3, decoderEM, basicGlyphSubstitutionModelFactory, singleDocumentEvaluatorAndOutputPrinter, i, i2, z, z2, i3, str4, str5, set, multiDocumentTranscriber, i4, z3, z4);
            font = updated._1;
            codeSwitchLanguageModel = updated._2;
            glyphSubstitutionModel = updated._3;
        }

        System.out.println("Training completed; saving models.");
        if (str != null) {
            System.out.println("Writing trained font to " + str);
            InitializeFont.writeFont(font, str);
        }
        if (str2 != null) {
            System.out.println("Writing trained lm to " + str2);
            InitializeLanguageModel.writeLM(codeSwitchLanguageModel, str2);
        }
        if (str3 != null) {
            System.out.println("Writing trained gsm to " + str3);
            InitializeGlyphSubstitutionModel.writeGSM(glyphSubstitutionModel, str3);
        }
        return Tuple3.Tuple3(font, codeSwitchLanguageModel, glyphSubstitutionModel);
    }

    /**
     * Runs one training iteration over all documents.
     *
     * <p>For every document the E-step is computed, glyph substitution counts are accumulated
     * and the transcription is evaluated and printed. Whenever {@link #isBatchComplete} reports
     * a finished batch, the models whose output path is non-null are updated and written, and
     * the accumulated statistics are cleared.
     *
     * <p>NOTE(review): the per-batch document counter is incremented before the {@code try},
     * so a skipped (failed) document still counts toward the batch, and because the
     * batch-completion check sits inside the {@code try}, a failure on the document that would
     * have completed a batch triggers no update at that point. Confirm this is intended.
     *
     * @param i number of the current iteration
     * @param i2 total number of iterations (used for progress output and to detect the final iteration)
     * @param i3 batch size in documents
     * @param i4 passed to {@code BetterThreader} for the font parameter update
     * @param i5 evaluation frequency in iterations; 0 disables periodic evaluation
     *        (previously 0 caused an {@link ArithmeticException})
     * @param z passed through to {@link #isBatchComplete}
     * @param z2 whether per-batch models go to {@code ModelPathMaker} paths rather than the final paths
     * @param z3 whether to transcribe with {@code multiDocumentTranscriber} after completed batches
     * @param z4 whether to skip documents that throw a {@link RuntimeException}
     * @return the font and the (possibly re-estimated) language and glyph substitution models;
     *         remaining parameters are as for {@link #trainFont}
     */
    public Tuple3<Font, CodeSwitchLanguageModel, GlyphSubstitutionModel> doFontTrainPass(int i, List<Document> list, Font font, CodeSwitchLanguageModel codeSwitchLanguageModel, GlyphSubstitutionModel glyphSubstitutionModel, String str, String str2, String str3, DecoderEM decoderEM, BasicGlyphSubstitutionModel.BasicGlyphSubstitutionModelFactory basicGlyphSubstitutionModelFactory, SingleDocumentEvaluatorAndOutputPrinter singleDocumentEvaluatorAndOutputPrinter, int i2, int i3, boolean z, boolean z2, int i4, String str4, String str5, Set<FonttrainTranscribeShared.OutputFormat> set, MultiDocumentTranscriber multiDocumentTranscriber, int i5, boolean z3, boolean z4) {
        Indexer<String> charIndexer = codeSwitchLanguageModel.getCharacterIndexer();
        int numDocs = list.size();
        CharacterTemplate[] templates = loadTemplates(font, charIndexer);
        // Left raw: the element type is dictated by EvalPrinter.printEvaluation, whose
        // generic signature is not visible from this file.
        List diplomaticEvals = new ArrayList();
        List normalizedEvals = new ArrayList();
        DenseBigramTransitionModel transitionModel = new DenseBigramTransitionModel(codeSwitchLanguageModel);
        clearTemplates(templates);
        List<DecodeState[][]> batchDecodeStates = new ArrayList<DecodeState[][]>();
        double[][][] gsmCounts = basicGlyphSubstitutionModelFactory.initializeNewCountsMatrix();

        double iterationLogProb = 0.0d;
        double batchLogProb = 0.0d;
        int completedBatches = 0;
        int docsInBatch = 0;
        for (int docNum = 0; docNum < numDocs; docNum++) {
            docsInBatch++;
            Document document = list.get(docNum);
            System.out.println("Training iteration " + i + " of " + i2 + ", document " + (docNum + 1) + " of " + numDocs + ":  " + document.baseName() + "    " + timestamp());
            document.loadDiplomaticTextLines();
            document.loadNormalizedText();
            try {
                Tuple2<DecodeState[][], Double> eStep = decoderEM.computeEStep(document, true, codeSwitchLanguageModel, glyphSubstitutionModel, templates, transitionModel);
                DecodeState[][] decodeStates = eStep._1;
                double jointLogProb = eStep._2.doubleValue();
                iterationLogProb += jointLogProb;
                batchLogProb += jointLogProb;
                batchDecodeStates.add(decodeStates);
                basicGlyphSubstitutionModelFactory.incrementCounts(gsmCounts, makeFullViterbiStateSeq(decodeStates, charIndexer));

                Tuple2<Map<String, Evaluator.EvalSuffStats>, Map<String, Evaluator.EvalSuffStats>> evals = singleDocumentEvaluatorAndOutputPrinter.evaluateAndPrintTranscription(i, 0, document, decodeStates, str4, str5, set, codeSwitchLanguageModel);
                if (evals._1 != null) {
                    diplomaticEvals.add(Tuple2.Tuple2(document.baseName(), evals._1));
                }
                if (evals._2 != null) {
                    normalizedEvals.add(Tuple2.Tuple2(document.baseName(), evals._2));
                }

                if (isBatchComplete(numDocs, docNum, docsInBatch, i3, z)) {
                    completedBatches++;
                    if (str != null) {
                        updateFontParameters(templates, i4);
                        String fontPath = z2 ? ModelPathMaker.makeFontPath(str5, i, completedBatches) : str;
                        System.out.println("Writing updated font to " + fontPath);
                        InitializeFont.writeFont(font, fontPath);
                    }
                    if (str2 != null) {
                        codeSwitchLanguageModel = reestimateLM(batchDecodeStates, codeSwitchLanguageModel, 0.5d);
                        String lmPath = z2 ? ModelPathMaker.makeLmPath(str5, i, completedBatches) : str2;
                        System.out.println("Writing updated lm to " + lmPath);
                        InitializeLanguageModel.writeLM(codeSwitchLanguageModel, lmPath);
                        // The transition model is built from the LM, so it must follow the new one.
                        transitionModel = new DenseBigramTransitionModel(codeSwitchLanguageModel);
                    }
                    if (str3 != null) {
                        System.out.println("Estimating parameters of a new Glyph Substitution Model.  Iter: " + i + ", batch: " + completedBatches);
                        glyphSubstitutionModel = basicGlyphSubstitutionModelFactory.make(gsmCounts, i, completedBatches);
                        String gsmPath = z2 ? ModelPathMaker.makeGsmPath(str5, i, completedBatches) : str3;
                        System.out.println("Writing updated gsm to " + gsmPath);
                        InitializeGlyphSubstitutionModel.writeGSM(glyphSubstitutionModel, gsmPath);
                    }

                    System.out.println("Clearing font parameter statistics.");
                    clearTemplates(templates);
                    batchDecodeStates = new ArrayList<DecodeState[][]>();
                    gsmCounts = basicGlyphSubstitutionModelFactory.initializeNewCountsMatrix();
                    System.out.println("Completed Batch: Iteration " + i + ", batch " + completedBatches + ": avg joint log prob: " + (batchLogProb / docsInBatch) + "    " + timestamp());

                    // Skipped after the very last document of the final iteration, which is
                    // covered by the end-of-iteration evaluation below.
                    boolean endOfTraining = (i == i2) && (docNum + 1 == numDocs);
                    if (z3 && isEvalIteration(i, i2, i5) && !endOfTraining) {
                        multiDocumentTranscriber.transcribe(i, completedBatches, font, codeSwitchLanguageModel, glyphSubstitutionModel);
                    }
                    batchLogProb = 0.0d;
                    docsInBatch = 0;
                }
            } catch (RuntimeException e) {
                if (!z4) {
                    throw e;
                }
                System.err.println("DOCUMENT FAILED! Skipping " + document.baseName());
                e.printStackTrace();
            }
        }

        System.out.println("Iteration " + i + " avg joint log prob: " + (iterationLogProb / numDocs));
        File inputPath = new File(str4);
        if (inputPath.isDirectory()) {
            String evalPrefix = str5 + "/all_transcriptions/" + inputPath.getName() + "/eval_iter-" + i;
            if (!diplomaticEvals.isEmpty()) {
                EvalPrinter.printEvaluation(diplomaticEvals, evalPrefix + "_diplomatic.txt");
            }
            if (!normalizedEvals.isEmpty()) {
                EvalPrinter.printEvaluation(normalizedEvals, evalPrefix + "_normalized.txt");
            }
        }
        if (isEvalIteration(i, i2, i5)) {
            System.out.println("Evaluating dev data at the end of iteration " + i + "    " + timestamp());
            multiDocumentTranscriber.transcribe(i, 0, font, codeSwitchLanguageModel, glyphSubstitutionModel);
        }
        return Tuple3.Tuple3(font, codeSwitchLanguageModel, glyphSubstitutionModel);
    }

    /**
     * Decides whether the models should be updated after the current document.
     *
     * <p>On the last document the batch is complete unless {@code z} is set and the batch holds
     * fewer than {@code i4} documents. On any other document the batch is complete when it holds
     * exactly {@code i4} documents and at least {@code i4} documents remain (so a short tail is
     * merged into the final batch).
     *
     * @param i total number of documents
     * @param i2 zero-based index of the current document
     * @param i3 number of documents in the current batch so far
     * @param i4 configured batch size
     * @param z when true, a final batch smaller than {@code i4} is not considered complete
     * @return whether the current batch is complete
     */
    public static boolean isBatchComplete(int i, int i2, int i3, int i4, boolean z) {
        int docsSeen = i2 + 1;
        if (docsSeen == i) {
            return !z || i3 >= i4;
        }
        int docsRemaining = i - docsSeen;
        return docsRemaining >= i4 && i3 == i4;
    }

    /**
     * Collects the font's template for every character in the indexer, in index order.
     *
     * @param font font to read templates from
     * @param indexer character indexer
     * @return array whose element {@code k} is the template of the character with index {@code k}
     * @throws RuntimeException if the font has no template for some indexed character
     */
    public static CharacterTemplate[] loadTemplates(Font font, Indexer<String> indexer) {
        int numChars = indexer.size();
        CharacterTemplate[] templates = new CharacterTemplate[numChars];
        for (int c = 0; c < numChars; c++) {
            String character = indexer.getObject(c);
            CharacterTemplate template = font.get(character);
            if (template == null) {
                throw new RuntimeException("No template found for character '" + character + "' (" + StringHelper.toUnicode(character) + ")");
            }
            templates[c] = template;
        }
        return templates;
    }

    /** Clears the accumulated counts of every non-null template. */
    private void clearTemplates(CharacterTemplate[] characterTemplateArr) {
        for (CharacterTemplate template : characterTemplateArr) {
            if (template != null) {
                template.clearCounts();
            }
        }
    }

    /**
     * Calls {@code updateParameters()} on every non-null template via a {@code BetterThreader}
     * and prints the elapsed time.
     *
     * @param characterTemplateArr templates to update
     * @param i second constructor argument of {@code BetterThreader} (presumably the thread count)
     */
    private void updateFontParameters(final CharacterTemplate[] characterTemplateArr, int i) {
        long startNanos = System.nanoTime();
        BetterThreader.Function<Integer, Object> updateOne = new BetterThreader.Function<Integer, Object>() {
            @Override // tberg.murphy.threading.BetterThreader.Function
            public void call(Integer num, Object obj) {
                CharacterTemplate template = characterTemplateArr[num.intValue()];
                if (template != null) {
                    template.updateParameters();
                }
            }
        };
        BetterThreader threader = new BetterThreader(updateOne, i);
        for (int c = 0; c < characterTemplateArr.length; c++) {
            threader.addFunctionArgument(Integer.valueOf(c));
        }
        threader.run();
        System.out.println("Update font parameters: " + ((System.nanoTime() - startNanos) / 1000000) + "ms");
    }

    /**
     * Adds, per language, the number of template-type states in {@code list} whose language
     * model character is not a space. States with a negative language index are ignored.
     * Not called from within this class.
     *
     * @param iArr per-language counters, indexed by language index; updated in place
     * @param list decode states to count
     * @param indexer character indexer used to look up the space character
     */
    private void incrementLmCounts(int[] iArr, List<DecodeState> list, Indexer<String> indexer) {
        int spaceIndex = indexer.getIndex(Charset.SPACE);
        for (DecodeState state : list) {
            SparseTransitionModel.TransitionState ts = state.ts;
            int language = ts.getLanguageIndex();
            if (language >= 0 && ts.getType() == TransitionStateType.TMPL && ts.getLmCharIndex() != spaceIndex) {
                iArr[language]++;
            }
        }
    }

    /**
     * Builds a new code-switch language model from the transcriptions of the current batch.
     *
     * <p>For each language with at least one transcribed character, the existing model (or its
     * first sub-model, if it is already an interpolation) is interpolated with a Kneser-Ney
     * n-gram model estimated from the transcribed text; languages without text keep their model.
     *
     * @param list decode states of the documents in the batch
     * @param codeSwitchLanguageModel current language model
     * @param d interpolation weight of the newly estimated model; the existing model gets {@code 1 - d}
     * @return the re-estimated language model
     */
    private CodeSwitchLanguageModel reestimateLM(List<DecodeState[][]> list, CodeSwitchLanguageModel codeSwitchLanguageModel, double d) {
        long startNanos = System.nanoTime();
        Indexer<String> charIndexer = codeSwitchLanguageModel.getCharacterIndexer();
        Indexer<String> langIndexer = codeSwitchLanguageModel.getLanguageIndexer();
        int numLanguages = langIndexer.size();
        System.out.println("Retraining LM");
        List<List<List<String>>> textByLanguage = separateTranscriptionsByLanguage(list, charIndexer, langIndexer);

        // Left raw: the element type is dictated by the BasicCodeSwitchLanguageModel
        // constructor, whose generic signature is not visible from this file.
        List modelsAndWeights = new ArrayList();
        for (int lang = 0; lang < numLanguages; lang++) {
            SingleLanguageModel baseModel = codeSwitchLanguageModel.get(lang);
            if (baseModel instanceof InterpolatingSingleLanguageModel) {
                // Interpolate against sub-model 0 rather than stacking interpolations.
                baseModel = ((InterpolatingSingleLanguageModel) baseModel).getSubModel(0);
            }

            CorpusCounter counter = new CorpusCounter(baseModel.getMaxOrder());
            int numChars = 0;
            for (List<String> run : textByLanguage.get(lang)) {
                counter.countChars(run, charIndexer, 0);
                numChars += run.size();
            }
            System.out.println("  found " + numChars + " characters for " + langIndexer.getObject(lang) + " read from transcription output");

            SingleLanguageModel updatedModel;
            if (numChars > 0) {
                double lmPower = baseModel instanceof NgramLanguageModel ? ((NgramLanguageModel) baseModel).getLmPower() : 4.0d;
                NgramLanguageModel estimated = new NgramLanguageModel(charIndexer, counter.getCounts(), baseModel.getActiveCharacters(), NgramLanguageModel.LMType.KNESER_NEY, lmPower);
                updatedModel = new InterpolatingSingleLanguageModel(CollectionHelper.makeList(Tuple2.Tuple2(baseModel, Double.valueOf(1.0d - d)), Tuple2.Tuple2(estimated, Double.valueOf(d))));
                System.out.println("  using new interpolated lm for " + langIndexer.getObject(lang));
            } else {
                System.out.println("  using original lm for " + langIndexer.getObject(lang));
                updatedModel = baseModel;
            }
            // Weight is the character count plus one, so languages without text keep a non-zero weight.
            modelsAndWeights.add(Tuple2.Tuple2(updatedModel, Double.valueOf(numChars + 1.0d)));
        }

        BasicCodeSwitchLanguageModel newLm = new BasicCodeSwitchLanguageModel(modelsAndWeights, charIndexer, langIndexer, codeSwitchLanguageModel.getProbKeepSameLanguage());
        System.out.println("New LM: " + ((System.nanoTime() - startNanos) / 1000000) + "ms");
        return newLm;
    }

    /**
     * Splits the normalized transcriptions into maximal same-language runs of characters.
     * Consecutive spaces inside a run are collapsed to one.
     *
     * @param list decode states of the documents
     * @param indexer character indexer
     * @param indexer2 language indexer
     * @return for each language index, the list of character runs attributed to that language
     */
    private List<List<List<String>>> separateTranscriptionsByLanguage(List<DecodeState[][]> list, Indexer<String> indexer, Indexer<String> indexer2) {
        int numLanguages = indexer2.size();
        List<List<List<String>>> runsByLanguage = new ArrayList<List<List<String>>>();
        for (int lang = 0; lang < numLanguages; lang++) {
            runsByLanguage.add(new ArrayList<List<String>>());
        }

        for (DecodeState[][] docStates : list) {
            List<Tuple2<String, String>> charsWithLang = new ModelTranscriptions(docStates, indexer, indexer2).getViterbiNormalizedCharLangRunning();
            String currentLang = null;
            List<String> currentRun = new ArrayList<String>();
            for (Tuple2<String, String> charWithLang : charsWithLang) {
                String lang = charWithLang._2;
                if (!equalsNullSafe(lang, currentLang)) {
                    // Language changed: close the run collected so far.
                    if (!currentRun.isEmpty()) {
                        runsByLanguage.get(languageSlot(currentLang, numLanguages, indexer2)).add(currentRun);
                    }
                    currentRun = new ArrayList<String>();
                    currentLang = lang;
                }
                String character = charWithLang._1;
                boolean repeatedSpace = Charset.SPACE.equals(character) && !currentRun.isEmpty() && Charset.SPACE.equals(CollectionHelper.last(currentRun));
                if (!repeatedSpace) {
                    currentRun.add(character);
                }
            }
            if (!currentRun.isEmpty()) {
                runsByLanguage.get(languageSlot(currentLang, numLanguages, indexer2)).add(currentRun);
            }
        }
        return runsByLanguage;
    }

    /** Language index for a run; a run without a language goes to slot 0 when there is only one language. */
    private static int languageSlot(String language, int numLanguages, Indexer<String> languageIndexer) {
        return (language == null && numLanguages == 1) ? 0 : languageIndexer.getIndex(language);
    }

    /** Null-tolerant {@code equals}: two nulls are equal, a null never equals a non-null. */
    private <A> boolean equalsNullSafe(A a, A a2) {
        if (a == null) {
            return a2 == null;
        }
        return a2 != null && a.equals(a2);
    }

    /**
     * Flattens the per-line decode states into one sequence, dropping any state whose glyph
     * template character is a hyphen when the previously kept state on the same line was also
     * a hyphen.
     *
     * @param decodeStateArr decode states, indexed by line and then by position
     * @param indexer character indexer used to resolve template character indices
     * @return the kept states in line order
     */
    public static List<DecodeState> makeFullViterbiStateSeq(DecodeState[][] decodeStateArr, Indexer<String> indexer) {
        List<DecodeState> fullSeq = new ArrayList<DecodeState>();
        for (int line = 0; line < decodeStateArr.length; line++) {
            // Character of the last state kept on this line; null while nothing has been kept.
            String previousKept = null;
            for (DecodeState state : decodeStateArr[line]) {
                String character = indexer.getObject(state.ts.getGlyphChar().templateCharIndex);
                boolean repeatedHyphen = Charset.HYPHEN.equals(previousKept) && Charset.HYPHEN.equals(character);
                if (!repeatedHyphen) {
                    previousKept = character;
                    fullSeq.add(state);
                }
            }
        }
        return fullSeq;
    }

    /**
     * Whether evaluation is due in iteration {@code iter}: on every {@code evalFreq}-th
     * iteration and on the final one. An {@code evalFreq} of 0 means "final iteration only"
     * instead of failing with a division by zero.
     */
    private static boolean isEvalIteration(int iter, int numIters, int evalFreq) {
        return (evalFreq != 0 && iter % evalFreq == 0) || iter == numIters;
    }

    /** Current wall-clock time formatted for progress messages. */
    private static String timestamp() {
        return new SimpleDateFormat(TIMESTAMP_PATTERN).format(Calendar.getInstance().getTime());
    }
}
