package edu.berkeley.cs.nlp.ocular.lm;

import edu.berkeley.cs.nlp.ocular.data.textreader.CharIndexer;
import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import edu.berkeley.cs.nlp.ocular.data.textreader.TextReader;
import edu.berkeley.cs.nlp.ocular.util.ArrayHelper;
import edu.berkeley.cs.nlp.ocular.util.CollectionHelper;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import tberg.murphy.indexer.Indexer;

/* loaded from: input_file:main/ocular_2.12-0.3-SNAPSHOT.jar:edu/berkeley/cs/nlp/ocular/lm/NgramLanguageModel.class */
/**
 * Character n-gram language model backed by one {@link CountDbBig} per n-gram order.
 *
 * <p>Probabilities are computed by {@link NgramCounts} using the smoothing method selected by
 * {@link LMType}, and every returned probability is raised to {@link #getLmPower()}.
 */
public class NgramLanguageModel implements SingleLanguageModel {
    private static final long serialVersionUID = 873286328149782L;

    /** Indexer mapping characters (as strings) to the integer ids used in contexts and n-grams. */
    private Indexer<String> charIndexer;
    /** One count database per order; the array length defines {@link #maxOrder}. */
    private CountDbBig[] countDbs;
    /** Highest n-gram order; equal to {@code countDbs.length} and always greater than zero. */
    private int maxOrder;
    /** Smoothing method used by {@link #getCharNgramProb(int[], int)}. */
    private LMType type;
    /** Exponent applied to every probability returned by {@link #getCharNgramProb(int[], int)}. */
    private double lmPower;
    /** Known contexts (long-encoded keys), used for membership tests in {@link #containsContext(int[])}. */
    private Set<LongArrWrapper> allContextsSet;
    /** The same contexts as {@link #allContextsSet}, decoded to int arrays. Not read within this class. */
    private List<int[]> allContexts;
    /** Character ids reported by {@link #getActiveCharacters()}; non-null after construction. */
    private Set<Integer> activeCharacters;

    /**
     * Smoothing method used to turn n-gram counts into probabilities.
     */
    public enum LMType {
        /** Maximum-likelihood estimate, computed by {@code NgramCounts.getTokenMle()}. */
        MLE,
        /** Absolute discounting, computed by {@code NgramCounts.getAbsoluteDiscounting()}. */
        ABS_DISC,
        /** Kneser-Ney smoothing, computed by {@code NgramCounts.getKneserNey()}. */
        KNESER_NEY
    }

    /**
     * Returns the set of active character ids.
     *
     * @return the set passed to the constructor; it is returned directly, not copied
     */
    @Override
    public Set<Integer> getActiveCharacters() {
        return this.activeCharacters;
    }

    /**
     * Creates a model over pre-computed n-gram counts and indexes every known context.
     *
     * @param charIndexer      indexer mapping characters to ids
     * @param countDbs         one count database per order; its length becomes {@link #getMaxOrder()}
     * @param activeCharacters character ids reported by {@link #getActiveCharacters()}; stored as-is
     * @param type             smoothing method used by {@link #getCharNgramProb(int[], int)}
     * @param lmPower          exponent applied to every returned probability
     * @throws RuntimeException if {@code countDbs} is empty or {@code activeCharacters} is null
     */
    public NgramLanguageModel(Indexer<String> charIndexer, CountDbBig[] countDbs, Set<Integer> activeCharacters, LMType type, double lmPower) {
        this.charIndexer = charIndexer;
        this.countDbs = countDbs;
        this.maxOrder = countDbs.length;
        if (this.maxOrder <= 0) {
            throw new RuntimeException("maxOrder must be greater than zero.");
        }
        // Validate before the (potentially large) context scan below so bad input fails fast.
        if (activeCharacters == null) {
            throw new RuntimeException("activeCharacters is null!");
        }
        this.type = type;
        this.lmPower = lmPower;
        this.activeCharacters = activeCharacters;
        this.allContextsSet = new HashSet<>();
        this.allContexts = new ArrayList<>();
        // The last count db is skipped: shrinkContext never yields contexts longer than
        // maxOrder - 1 (assumes countDbs[i] holds (i+1)-grams -- TODO confirm).
        // A key is a context only if it has a positive HISTORY_TYPE_INDEX count.
        for (int idx = 0; idx < this.maxOrder - 1; idx++) {
            CountDbBig db = countDbs[idx];
            for (long[] key : db.getKeys()) {
                if (key != null && db.getCount(key, CountType.HISTORY_TYPE_INDEX) > 0) {
                    this.allContextsSet.add(new LongArrWrapper(key));
                    this.allContexts.add(LongNgram.convertToIntArr(key));
                }
            }
        }
    }

    /**
     * Convenience overload of {@link #buildFromText(List, int, int, LMType, double, TextReader)}
     * for a single input.
     *
     * @param input      single input handed to {@code CorpusCounter.countRecursive}
     *                   (presumably a file or directory path -- confirm)
     * @param countLimit limit passed through to {@code CorpusCounter.countRecursive}
     * @param maxLmOrder value passed to the {@code CorpusCounter} constructor
     *                   (presumably the maximum n-gram order -- confirm)
     * @param type       smoothing method for the resulting model
     * @param lmPower    exponent applied to every probability of the resulting model
     * @param textReader reader used to turn the input into characters
     * @return the newly built model
     */
    public static NgramLanguageModel buildFromText(String input, int countLimit, int maxLmOrder, LMType type, double lmPower, TextReader textReader) {
        return buildFromText((List<String>) CollectionHelper.makeList(input), countLimit, maxLmOrder, type, lmPower, textReader);
    }

    /**
     * Counts n-grams over all inputs and builds a model from the resulting counts.
     *
     * <p>The space character is always added to the active set. The character indexer is locked
     * before the model is created, and the counter's statistics are printed.
     *
     * @param inputs     inputs handed one by one to {@code CorpusCounter.countRecursive}
     *                   (presumably file or directory paths -- confirm)
     * @param countLimit limit passed through to {@code CorpusCounter.countRecursive}
     * @param maxLmOrder value passed to the {@code CorpusCounter} constructor
     *                   (presumably the maximum n-gram order -- confirm)
     * @param type       smoothing method for the resulting model
     * @param lmPower    exponent applied to every probability of the resulting model
     * @param textReader reader used to turn each input into characters
     * @return the newly built model
     */
    public static NgramLanguageModel buildFromText(List<String> inputs, int countLimit, int maxLmOrder, LMType type, double lmPower, TextReader textReader) {
        CorpusCounter counter = new CorpusCounter(maxLmOrder);
        // The counter's own set is handed to the model (not copied).
        Set<Integer> activeCharacters = counter.getActiveCharacters();
        CharIndexer charIndexer = new CharIndexer();
        for (String input : inputs) {
            counter.countRecursive(input, countLimit, charIndexer, textReader);
        }
        // Space is always active, whether or not the corpus produced one.
        activeCharacters.add(charIndexer.getIndex(Charset.SPACE));
        charIndexer.lock();
        counter.printStats(-1);
        return new NgramLanguageModel(charIndexer, counter.getCounts(), activeCharacters, type, lmPower);
    }

    /**
     * Debug helper: sums {@link #getCharNgramProb(int[], int)} over every character id for the
     * given context and prints the total to stdout.
     *
     * <p>Each term is raised to {@code lmPower}, so the total is only expected to be 1 when
     * {@code lmPower == 1}.
     *
     * @param context context to check, as character ids
     */
    public void checkNormalizes(int[] context) {
        double total = 0.0d;
        for (int c = 0; c < this.charIndexer.size(); c++) {
            total += getCharNgramProb(context, c);
        }
        System.out.println("Total prob for context " + LongNgram.toString(context, this.charIndexer) + ": " + total);
    }

    /** @return the indexer mapping characters to ids */
    @Override
    public Indexer<String> getCharacterIndexer() {
        return this.charIndexer;
    }

    /** @return the highest n-gram order, i.e. the number of count databases */
    @Override
    public int getMaxOrder() {
        return this.maxOrder;
    }

    /** @return the exponent applied to every probability returned by this model */
    public double getLmPower() {
        return this.lmPower;
    }

    /**
     * Shortens a context to the longest suffix the model knows about.
     *
     * <p>The context is first cut to at most {@code maxOrder - 1} trailing characters. Leading
     * characters are then dropped until the remainder is a known context. The empty context
     * always counts as known, so the loop terminates.
     *
     * @param iArr context as character ids; not modified
     * @return the shortened context (possibly {@code iArr} itself, possibly empty)
     */
    @Override
    public int[] shrinkContext(int[] iArr) {
        int[] context = iArr;
        if (context.length > this.maxOrder - 1) {
            context = ArrayHelper.takeRight(context, this.maxOrder - 1);
        }
        while (context.length > 0 && !containsContext(context)) {
            context = ArrayHelper.takeRight(context, context.length - 1);
        }
        return context;
    }

    /**
     * Reports whether a context was seen during construction.
     *
     * @param iArr context as character ids
     * @return {@code true} for the empty context; otherwise whether it is a known context
     */
    @Override
    public boolean containsContext(int[] iArr) {
        if (iArr.length == 0) {
            return true;
        }
        return this.allContextsSet.contains(new LongArrWrapper(LongNgram.convertToLong(iArr)));
    }

    /**
     * Returns the smoothed probability of character {@code c} following {@code context}, raised
     * to {@code lmPower}.
     *
     * @param context preceding characters as ids; not modified
     * @param c       id of the predicted character
     * @return {@code P(c | context) ^ lmPower} under the configured {@link LMType}
     * @throws RuntimeException if the configured type is not handled
     */
    @Override
    public double getCharNgramProb(int[] context, int c) {
        // Full n-gram = context followed by the predicted character.
        int[] ngram = new int[context.length + 1];
        System.arraycopy(context, 0, ngram, 0, context.length);
        ngram[context.length] = c;
        NgramCounts counts = new NgramCounts(NgramWrapper.getNew(ngram, 0, ngram.length), this.countDbs);
        double prob;
        switch (this.type) {
            case MLE:
                prob = counts.getTokenMle();
                break;
            case ABS_DISC:
                prob = counts.getAbsoluteDiscounting();
                break;
            case KNESER_NEY:
                prob = counts.getKneserNey();
                break;
            default:
                throw new RuntimeException("Bad type: " + this.type);
        }
        return Math.pow(prob, this.lmPower);
    }
}
