package edu.berkeley.cs.nlp.ocular.gsm;

import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import edu.berkeley.cs.nlp.ocular.gsm.GlyphChar;
import edu.berkeley.cs.nlp.ocular.model.DecodeState;
import edu.berkeley.cs.nlp.ocular.model.TransitionStateType;
import edu.berkeley.cs.nlp.ocular.model.transition.SparseTransitionModel;
import edu.berkeley.cs.nlp.ocular.preprocessing.Cropper;
import edu.berkeley.cs.nlp.ocular.util.ArrayHelper;
import edu.berkeley.cs.nlp.ocular.util.CollectionHelper;
import edu.berkeley.cs.nlp.ocular.util.FileHelper;
import java.util.List;
import java.util.Map;
import java.util.Set;
import tberg.murphy.indexer.Indexer;

/* loaded from: input_file:main/ocular_2.12-0.3-SNAPSHOT.jar:edu/berkeley/cs/nlp/ocular/gsm/BasicGlyphSubstitutionModel.class */
/**
 * Glyph substitution model backed by a dense probability table indexed as
 * {@code probs[language][lmChar][glyph]}.
 *
 * <p>The first {@code numChars} glyph columns correspond to characters of the
 * character indexer; a glyph whose type is not {@code NORMAL_CHAR} is mapped to
 * column {@code numChars + glyphType.ordinal()}.
 */
public class BasicGlyphSubstitutionModel implements GlyphSubstitutionModel {
    private static final long serialVersionUID = -8473038413268727114L;

    /** Counts that do not exceed this threshold are treated as absent. */
    private static final double COUNT_EPSILON = 1.0E-9d;

    /**
     * Value stored in cells that are disallowed or carry no usable count.
     *
     * <p>NOTE(review): the reference to {@code Cropper.VERT_GROW_RATIO} (a
     * preprocessing constant) looks like a decompiler constant substitution for
     * the literal {@code 0.0} - confirm against the original sources and replace
     * it with {@code 0.0} so this model no longer depends on the cropper.
     */
    private static final double EMPTY_CELL_VALUE = Cropper.VERT_GROW_RATIO;

    private final Indexer<String> langIndexer;
    private final Indexer<String> charIndexer;
    private final int numChars;
    private final double[][][] probs;
    private final double gsmPower;

    /* loaded from: input_file:main/ocular_2.12-0.3-SNAPSHOT.jar:edu/berkeley/cs/nlp/ocular/gsm/BasicGlyphSubstitutionModel$BasicGlyphSubstitutionModelFactory.class */
    /**
     * Creates count matrices and turns them into
     * {@link BasicGlyphSubstitutionModel} instances. All matrices handled here
     * are indexed as {@code [language][lmChar][glyph]} with dimensions
     * {@code numLanguages x numChars x numGlyphs}.
     */
    public static class BasicGlyphSubstitutionModelFactory {
        private final double gsmSmoothingCount;
        private final double elisionSmoothingCountMultiplier;
        private final Indexer<String> langIndexer;
        private final Indexer<String> charIndexer;
        /** One set of character indices per language, indexed by language. */
        private final Set<Integer>[] activeCharacterSets;
        private final Set<Integer> canBeReplaced;
        private final Set<Integer> canBeDoubled;
        private final Set<Integer> validSubstitutionChars;
        private final Set<Integer> canBeElided;
        private final Map<Integer, Integer> addTilde;
        private final Map<Integer, Integer> diacriticDisregardMap;
        /** Index of "s" in the character indexer, or -1 when it is not present. */
        private final int sCharIndex;
        private final int longsCharIndex;
        /** Index of "f" in the character indexer, or -1 when it is not present. */
        private final int fCharIndex;
        /** Index of "l" in the character indexer, or -1 when it is not present. */
        private final int lCharIndex;
        private final int hyphenCharIndex;
        private final int spaceCharIndex;
        private final int numLanguages;
        private final int numChars;
        /**
         * Number of glyph columns: {@code numChars + GlyphType.values().length - 1}.
         * NOTE(review): the {@code - 1} presumably accounts for NORMAL_CHAR, which
         * uses the character's own column instead of an extra one - confirm that
         * NORMAL_CHAR has the highest ordinal.
         */
        private final int numGlyphs;
        /** Glyph column of {@code GlyphType.ELISION_TILDE}. */
        public final int GLYPH_ELISION_TILDE;
        /** Glyph column of {@code GlyphType.TILDE_ELIDED}. */
        public final int GLYPH_TILDE_ELIDED;
        /** Glyph column of {@code GlyphType.FIRST_ELIDED}. */
        public final int GLYPH_FIRST_ELIDED;
        /** Glyph column of {@code GlyphType.DOUBLED}. */
        public final int GLYPH_DOUBLED;
        /** Glyph column of {@code GlyphType.ELIDED}. */
        public final int GLYPH_ELIDED;
        private final double gsmPower;
        private final int minCountsForEvalGsm;
        private final String outputPath;

        /**
         * Creates a factory and derives the character sets, maps and special
         * character indices it needs from {@code charIndexer} via {@link Charset}.
         *
         * @param gsmSmoothingCount base smoothing count for allowed cells
         * @param elisionSmoothingCountMultiplier factor applied to the base
         *        smoothing count for some cells (see {@link #getSmoothingValue})
         * @param langIndexer indexer of languages; its size is the number of languages
         * @param charIndexer indexer of characters; its size is the number of characters
         * @param activeCharacterSets per-language sets of character indices
         * @param gsmPower exponent handed to every model this factory creates
         * @param minCountsForEvalGsm default minimum count used by
         *        {@link #makeForEval(double[][][], int, int)}
         * @param outputPath base directory used to build GSM printout file paths
         */
        public BasicGlyphSubstitutionModelFactory(double gsmSmoothingCount, double elisionSmoothingCountMultiplier, Indexer<String> langIndexer, Indexer<String> charIndexer, Set<Integer>[] activeCharacterSets, double gsmPower, int minCountsForEvalGsm, String outputPath) {
            this.gsmSmoothingCount = gsmSmoothingCount;
            this.elisionSmoothingCountMultiplier = elisionSmoothingCountMultiplier;
            this.langIndexer = langIndexer;
            this.charIndexer = charIndexer;
            this.activeCharacterSets = activeCharacterSets;
            this.gsmPower = gsmPower;
            this.minCountsForEvalGsm = minCountsForEvalGsm;

            // The order of the indexer calls below is kept as in the original code
            // in case any of them touches the indexer's contents.
            this.canBeReplaced = Charset.makeCanBeReplacedSet(charIndexer);
            this.canBeDoubled = Charset.makeValidDoublableSet(charIndexer);
            this.validSubstitutionChars = Charset.makeValidSubstitutionCharsSet(charIndexer);
            this.canBeElided = Charset.makeCanBeElidedSet(charIndexer);
            this.addTilde = Charset.makeAddTildeMap(charIndexer);
            this.diacriticDisregardMap = Charset.makeDiacriticDisregardMap(charIndexer);

            this.sCharIndex = charIndexer.contains("s") ? charIndexer.getIndex("s") : -1;
            this.longsCharIndex = charIndexer.getIndex(Charset.LONG_S);
            this.fCharIndex = charIndexer.contains("f") ? charIndexer.getIndex("f") : -1;
            this.lCharIndex = charIndexer.contains("l") ? charIndexer.getIndex("l") : -1;
            this.hyphenCharIndex = charIndexer.getIndex(Charset.HYPHEN);
            this.spaceCharIndex = charIndexer.getIndex(Charset.SPACE);

            this.numLanguages = langIndexer.size();
            this.numChars = charIndexer.size();
            this.numGlyphs = (this.numChars + GlyphChar.GlyphType.values().length) - 1;

            this.GLYPH_ELISION_TILDE = this.numChars + GlyphChar.GlyphType.ELISION_TILDE.ordinal();
            this.GLYPH_TILDE_ELIDED = this.numChars + GlyphChar.GlyphType.TILDE_ELIDED.ordinal();
            this.GLYPH_FIRST_ELIDED = this.numChars + GlyphChar.GlyphType.FIRST_ELIDED.ordinal();
            this.GLYPH_DOUBLED = this.numChars + GlyphChar.GlyphType.DOUBLED.ordinal();
            this.GLYPH_ELIDED = this.numChars + GlyphChar.GlyphType.ELIDED.ordinal();
            this.outputPath = outputPath;
        }

        /**
         * Builds a model from a freshly initialized count matrix, i.e. from the
         * smoothing values alone.
         *
         * @return the model obtained by normalizing the smoothing values
         */
        public GlyphSubstitutionModel uniform() {
            return make(initializeNewCountsMatrix(), 0, 0);
        }

        /**
         * Allocates a new count matrix and fills every cell with
         * {@link #getSmoothingValue(int, int, int)}.
         *
         * @return a new {@code [numLanguages][numChars][numGlyphs]} matrix
         */
        public double[][][] initializeNewCountsMatrix() {
            double[][][] counts = new double[this.numLanguages][this.numChars][this.numGlyphs];
            for (int language = 0; language < this.numLanguages; language++) {
                for (int lmChar = 0; lmChar < this.numChars; lmChar++) {
                    double[] row = counts[language][lmChar];
                    for (int glyph = 0; glyph < this.numGlyphs; glyph++) {
                        row[glyph] = getSmoothingValue(language, lmChar, glyph);
                    }
                }
            }
            return counts;
        }

        /**
         * Returns the initial (smoothing) count for one cell of the count matrix.
         *
         * <p>The rules are checked in this order; the first that applies wins:
         * <ol>
         *   <li>{@code lmChar} is neither in the language's active character set
         *       nor the hyphen: empty value.</li>
         *   <li>{@code glyph} is one of the special glyph columns: the elision
         *       count (base count times multiplier) for ELISION_TILDE when
         *       {@code addTilde} has an entry for {@code lmChar}, and for
         *       TILDE_ELIDED / FIRST_ELIDED when {@code lmChar} can be elided; the
         *       base count for DOUBLED / ELIDED when {@code lmChar} can be doubled /
         *       elided; otherwise the empty value.</li>
         *   <li>{@code glyph} equals the entry of {@code diacriticDisregardMap}
         *       for {@code lmChar}: the elision count.</li>
         *   <li>"s" rendered as long-s: base count; "s" rendered as "f" or "l":
         *       empty value.</li>
         *   <li>hyphen rendered as space: base count.</li>
         *   <li>{@code lmChar} can be replaced, {@code glyph} is a valid
         *       substitution character active in the language, or
         *       {@code glyph == lmChar}: base count.</li>
         *   <li>Anything else: empty value.</li>
         * </ol>
         *
         * @param language language index (row of {@code activeCharacterSets})
         * @param lmChar character index
         * @param glyph glyph column index
         * @return the smoothing value for the cell
         */
        public double getSmoothingValue(int language, int lmChar, int glyph) {
            Set<Integer> activeChars = this.activeCharacterSets[language];
            if (!activeChars.contains(Integer.valueOf(lmChar)) && lmChar != this.hyphenCharIndex) {
                return EMPTY_CELL_VALUE;
            }

            double baseCount = this.gsmSmoothingCount;
            double elisionCount = this.gsmSmoothingCount * this.elisionSmoothingCountMultiplier;

            // Special (non-character) glyph columns.
            if (glyph == this.GLYPH_ELISION_TILDE) {
                return this.addTilde.get(Integer.valueOf(lmChar)) == null ? EMPTY_CELL_VALUE : elisionCount;
            }
            if (glyph == this.GLYPH_TILDE_ELIDED) {
                return this.canBeElided.contains(Integer.valueOf(lmChar)) ? elisionCount : EMPTY_CELL_VALUE;
            }
            if (glyph == this.GLYPH_FIRST_ELIDED) {
                return this.canBeElided.contains(Integer.valueOf(lmChar)) ? elisionCount : EMPTY_CELL_VALUE;
            }
            if (glyph == this.GLYPH_DOUBLED) {
                return this.canBeDoubled.contains(Integer.valueOf(lmChar)) ? baseCount : EMPTY_CELL_VALUE;
            }
            if (glyph == this.GLYPH_ELIDED) {
                return this.canBeElided.contains(Integer.valueOf(lmChar)) ? baseCount : EMPTY_CELL_VALUE;
            }

            // Ordinary character columns.
            Integer diacriticFree = this.diacriticDisregardMap.get(Integer.valueOf(lmChar));
            if (diacriticFree != null && diacriticFree.equals(Integer.valueOf(glyph))) {
                return elisionCount;
            }
            if (lmChar == this.sCharIndex && glyph == this.longsCharIndex) {
                return baseCount;
            }
            if (lmChar == this.sCharIndex && (glyph == this.fCharIndex || glyph == this.lCharIndex)) {
                return EMPTY_CELL_VALUE;
            }
            if (lmChar == this.hyphenCharIndex && glyph == this.spaceCharIndex) {
                return baseCount;
            }
            boolean allowedSubstitution = this.canBeReplaced.contains(Integer.valueOf(lmChar))
                    && this.validSubstitutionChars.contains(Integer.valueOf(glyph))
                    && activeChars.contains(Integer.valueOf(glyph));
            if (allowedSubstitution || lmChar == glyph) {
                return baseCount;
            }
            return EMPTY_CELL_VALUE;
        }

        /**
         * Adds one count per qualifying decode state to {@code counts}.
         *
         * <p>A {@code TMPL} state with a non-negative language index increments
         * {@code counts[language][lmCharIndex][glyph]}. A {@code RMRGN_HPHN_INIT}
         * state with a non-negative language index whose glyph's template
         * character is the space increments the hyphen row,
         * {@code counts[language][hyphenCharIndex][glyph]}. All other states are
         * ignored.
         *
         * @param counts count matrix to update in place
         * @param decodeStates decode states whose transition states are counted
         */
        public void incrementCounts(double[][][] counts, List<DecodeState> decodeStates) {
            for (DecodeState decodeState : decodeStates) {
                SparseTransitionModel.TransitionState ts = decodeState.ts;
                TransitionStateType type = ts.getType();
                if (type == TransitionStateType.TMPL) {
                    int language = ts.getLanguageIndex();
                    if (language >= 0) {
                        int lmChar = ts.getLmCharIndex();
                        int glyph = glyphIndex(ts.getGlyphChar());
                        counts[language][lmChar][glyph] += 1.0d;
                    }
                } else if (type == TransitionStateType.RMRGN_HPHN_INIT) {
                    int language = ts.getLanguageIndex();
                    if (language >= 0) {
                        GlyphChar glyphChar = ts.getGlyphChar();
                        if (glyphChar.templateCharIndex == this.spaceCharIndex) {
                            counts[language][this.hyphenCharIndex][glyphIndex(glyphChar)] += 1.0d;
                        }
                    }
                }
            }
        }

        /** Maps a glyph to its column in this factory's matrices. */
        private int glyphIndex(GlyphChar glyphChar) {
            return BasicGlyphSubstitutionModel.glyphColumnIndex(glyphChar, this.numChars);
        }

        /**
         * Writes the normalized form of {@code rowCounts} into {@code rowProbs}:
         * each of the first {@code numGlyphs} entries greater than the epsilon is
         * divided by the sum of {@code rowCounts}; every other entry becomes the
         * empty value.
         */
        private void normalizeRow(double[] rowCounts, double[] rowProbs) {
            double total = ArrayHelper.sum(rowCounts);
            for (int glyph = 0; glyph < this.numGlyphs; glyph++) {
                double count = rowCounts[glyph];
                rowProbs[glyph] = count > COUNT_EPSILON ? count / total : EMPTY_CELL_VALUE;
            }
        }

        /**
         * Builds a model by normalizing every {@code [language][lmChar]} row of
         * {@code counts} over its glyphs.
         *
         * @param counts count matrix; not modified
         * @param iter unused by this method
         * @param batchId unused by this method
         * @return a new model holding the normalized probabilities
         */
        public BasicGlyphSubstitutionModel make(double[][][] counts, int iter, int batchId) {
            double[][][] normalized = new double[this.numLanguages][this.numChars][this.numGlyphs];
            // Each row is normalized exactly once; the previous implementation
            // wrapped this in an extra loop over all characters that recomputed
            // identical values numChars times.
            for (int language = 0; language < this.numLanguages; language++) {
                for (int lmChar = 0; lmChar < this.numChars; lmChar++) {
                    normalizeRow(counts[language][lmChar], normalized[language][lmChar]);
                }
            }
            return new BasicGlyphSubstitutionModel(normalized, this.gsmPower, this.langIndexer, this.charIndexer);
        }

        /**
         * Same as {@link #makeForEval(double[][][], int, int, double)} using the
         * factory's configured {@code minCountsForEvalGsm} as the minimum count.
         */
        public BasicGlyphSubstitutionModel makeForEval(double[][][] counts, int iter, int batchId) {
            return makeForEval(counts, iter, batchId, this.minCountsForEvalGsm);
        }

        /**
         * Builds a model for evaluation from {@code counts}.
         *
         * <p>When {@code minCountsForEval < 1} a message is printed and the result
         * of {@link #make(double[][][], int, int)} is returned. Otherwise
         * {@code gsmSmoothingCount} is subtracted from every cell, cells whose
         * remainder is below the epsilon or below
         * {@code minCountsForEval - epsilon} are zeroed, and each row is
         * normalized.
         *
         * @param counts count matrix; not modified
         * @param iter iteration number, used in the printed message
         * @param batchId batch number, used in the printed message
         * @param minCountsForEval minimum de-smoothed count a cell needs to be kept
         * @return a new model holding the resulting probabilities
         */
        public BasicGlyphSubstitutionModel makeForEval(double[][][] counts, int iter, int batchId, double minCountsForEval) {
            if (minCountsForEval < 1.0d) {
                System.out.println("Estimating parameters of a new Glyph Substitution Model.  Iter: " + iter + ", batch: " + batchId);
                return make(counts, iter, batchId);
            }

            double[][][] evalProbs = new double[this.numLanguages][this.numChars][this.numGlyphs];
            for (int language = 0; language < this.numLanguages; language++) {
                for (int lmChar = 0; lmChar < this.numChars; lmChar++) {
                    double[] rowCounts = counts[language][lmChar];
                    double[] retained = new double[this.numGlyphs];
                    for (int glyph = 0; glyph < this.numGlyphs; glyph++) {
                        // NOTE(review): getSmoothingValue seeds some cells with
                        // gsmSmoothingCount * elisionSmoothingCountMultiplier (or the
                        // empty value), yet gsmSmoothingCount is subtracted from every
                        // cell here - confirm this is intended.
                        double trueCount = rowCounts[glyph] - this.gsmSmoothingCount;
                        boolean tooSmall = trueCount < COUNT_EPSILON || trueCount < minCountsForEval - COUNT_EPSILON;
                        retained[glyph] = tooSmall ? 0.0d : trueCount;
                    }
                    normalizeRow(retained, evalProbs[language][lmChar]);
                }
            }
            return new BasicGlyphSubstitutionModel(evalProbs, this.gsmPower, this.langIndexer, this.charIndexer);
        }

        /**
         * Writes a tab-separated dump of selected cells to {@code outputFilenamePrefix + ".tsv"}.
         *
         * <p>A cell is written when its count exceeds {@code gsmSmoothingCount},
         * or when its character is one of a small fixed set and its glyph is
         * either in that set or a special (non-character) glyph column. This
         * method is not called from within this class.
         *
         * @param numLanguages number of languages to iterate over
         * @param numChars number of characters to iterate over
         * @param numGlyphs number of glyph columns to iterate over
         * @param counts matrix providing the "count" column
         * @param probs matrix providing the "prob" and "minProb" columns
         * @param iter iteration number, used in the printed message
         * @param batchId batch number, used in the printed message
         * @param outputFilenamePrefix output path without the ".tsv" extension
         */
        private void printGsmProbs3(int numLanguages, int numChars, int numGlyphs, double[][][] counts, double[][][] probs, int iter, int batchId, String outputFilenamePrefix) {
            Set<?> charsToPrint = CollectionHelper.setUnion(CollectionHelper.makeSet(Charset.SPACE, Charset.HYPHEN, "a", "b", "c", "d", Charset.LONG_S));
            StringBuilder out = new StringBuilder();
            out.append("language\tlmChar\tglyph\tcount\tminProb\tprob\n");
            for (int language = 0; language < numLanguages; language++) {
                String languageName = this.langIndexer.getObject(language);
                for (int lmChar = 0; lmChar < numChars; lmChar++) {
                    String lmCharName = this.charIndexer.getObject(lmChar);
                    double minProb = ArrayHelper.min(probs[language][lmChar]);
                    for (int glyph = 0; glyph < numGlyphs; glyph++) {
                        boolean specialGlyph = glyph >= numChars;
                        String glyphName = specialGlyph
                                ? GlyphChar.GlyphType.values()[glyph - numChars].toString()
                                : this.charIndexer.getObject(glyph);
                        double prob = probs[language][lmChar][glyph];
                        double count = counts[language][lmChar][glyph];
                        boolean observed = count > this.gsmSmoothingCount;
                        if (observed || (charsToPrint.contains(lmCharName) && (charsToPrint.contains(glyphName) || specialGlyph))) {
                            out.append(languageName).append("\t");
                            out.append(lmCharName).append("\t");
                            out.append(glyphName).append("\t");
                            out.append(count).append("\t");
                            out.append(minProb).append("\t");
                            out.append(prob).append("\t");
                            out.append("\n");
                        }
                    }
                }
            }
            String outputFilename = outputFilenamePrefix + ".tsv";
            System.out.println("Writing info about newly-trained GSM on iteration " + iter + ", batch " + batchId + " out to [" + outputFilename + "]");
            FileHelper.writeString(outputFilename, out.toString());
        }

        /**
         * Builds the extension-less path of a GSM printout under
         * {@code outputPath}; the iteration and batch suffixes are added only for
         * positive values. This method is not called from within this class.
         */
        private String gsmPrintoutFilepath(int iter, int batchId) {
            StringBuilder path = new StringBuilder(this.outputPath + "/gsm/newGSM");
            if (iter > 0) {
                path.append("_iter-").append(iter);
            }
            if (batchId > 0) {
                path.append("_batch-").append(batchId);
            }
            return path.toString();
        }
    }

    /**
     * Creates a model over an existing probability table.
     *
     * @param probs table indexed as {@code [language][lmChar][glyph]}; stored by
     *        reference, not copied
     * @param gsmPower exponent applied to every probability returned by
     *        {@link #glyphProb(int, int, GlyphChar)}
     * @param langIndexer indexer of languages
     * @param charIndexer indexer of characters; its size determines where the
     *        special glyph columns start
     */
    public BasicGlyphSubstitutionModel(double[][][] probs, double gsmPower, Indexer<String> langIndexer, Indexer<String> charIndexer) {
        this.langIndexer = langIndexer;
        this.charIndexer = charIndexer;
        this.numChars = charIndexer.size();
        this.probs = probs;
        this.gsmPower = gsmPower;
    }

    /**
     * Maps a glyph to its column: the template character index for
     * {@code NORMAL_CHAR} glyphs, otherwise {@code numChars} plus the ordinal of
     * the glyph type.
     */
    private static int glyphColumnIndex(GlyphChar glyphChar, int numChars) {
        GlyphChar.GlyphType glyphType = glyphChar.glyphType;
        if (glyphType == GlyphChar.GlyphType.NORMAL_CHAR) {
            return glyphChar.templateCharIndex;
        }
        return numChars + glyphType.ordinal();
    }

    /**
     * Returns the stored probability of {@code glyphChar} for the given language
     * and character, raised to the power {@code gsmPower}.
     *
     * @param language language index
     * @param lmChar character index
     * @param glyphChar glyph whose probability is requested
     * @return {@code probs[language][lmChar][glyph] ^ gsmPower}
     */
    @Override // edu.berkeley.cs.nlp.ocular.gsm.GlyphSubstitutionModel
    public double glyphProb(int language, int lmChar, GlyphChar glyphChar) {
        int glyph = glyphColumnIndex(glyphChar, this.numChars);
        return Math.pow(this.probs[language][lmChar][glyph], this.gsmPower);
    }

    /** Returns the language indexer this model was created with. */
    public Indexer<String> getLanguageIndexer() {
        return this.langIndexer;
    }

    /** Returns the character indexer this model was created with. */
    public Indexer<String> getCharacterIndexer() {
        return this.charIndexer;
    }

    /** Returns the language indexer; same value as {@link #getLanguageIndexer()}. */
    public Indexer<String> getLangIndexer() {
        return this.langIndexer;
    }

    /** Returns the character indexer; same value as {@link #getCharacterIndexer()}. */
    public Indexer<String> getCharIndexer() {
        return this.charIndexer;
    }
}
