package edu.berkeley.cs.nlp.ocular.main;

import edu.berkeley.cs.nlp.ocular.data.textreader.BasicTextReader;
import edu.berkeley.cs.nlp.ocular.data.textreader.BlacklistCharacterSetTextReader;
import edu.berkeley.cs.nlp.ocular.data.textreader.CharIndexer;
import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import edu.berkeley.cs.nlp.ocular.data.textreader.ConvertLongSTextReader;
import edu.berkeley.cs.nlp.ocular.data.textreader.RemoveAllDiacriticsTextReader;
import edu.berkeley.cs.nlp.ocular.data.textreader.ReplaceSomeTextReader;
import edu.berkeley.cs.nlp.ocular.data.textreader.TextReader;
import edu.berkeley.cs.nlp.ocular.data.textreader.WhitelistCharacterSetTextReader;
import edu.berkeley.cs.nlp.ocular.lm.BasicCodeSwitchLanguageModel;
import edu.berkeley.cs.nlp.ocular.lm.CodeSwitchLanguageModel;
import edu.berkeley.cs.nlp.ocular.lm.CorpusCounter;
import edu.berkeley.cs.nlp.ocular.lm.LanguageModel;
import edu.berkeley.cs.nlp.ocular.lm.NgramLanguageModel;
import edu.berkeley.cs.nlp.ocular.lm.SingleLanguageModel;
import edu.berkeley.cs.nlp.ocular.util.FileUtil;
import edu.berkeley.cs.nlp.ocular.util.StringHelper;
import edu.berkeley.cs.nlp.ocular.util.Tuple2;
import edu.berkeley.cs.nlp.ocular.util.Tuple3;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.jocl.CL;
import tberg.murphy.fig.Option;
import tberg.murphy.fileio.f;
import tberg.murphy.indexer.HashMapIndexer;
import tberg.murphy.indexer.Indexer;

/* loaded from: input_file:main/ocular_2.12-0.3-SNAPSHOT.jar:edu/berkeley/cs/nlp/ocular/main/InitializeLanguageModel.class */
/**
 * Command-line tool that trains a character n-gram language model from plain-text corpora and
 * writes it, Java-serialized and gzip-compressed, to {@link #outputLmPath}.
 *
 * <p>One {@link NgramLanguageModel} (Kneser-Ney) is trained per language named in
 * {@link #inputTextPath}; the sub-models and their priors are combined into a
 * {@link BasicCodeSwitchLanguageModel}. The class also exposes {@link #readLM},
 * {@link #readCodeSwitchLM} and {@link #writeLM} for loading and saving serialized models.
 */
public class InitializeLanguageModel extends OcularRunnable {

    @Option(gloss = "Output LM file path.")
    public static String outputLmPath = null;

    @Option(gloss = "Path to the text files (or directory hierarchies) for training the LM.  For each entry, the entire directory will be recursively searched for any files that do not start with `.`.  For a multilingual (code-switching) model, give multiple comma-separated files with language names: \"english->texts/english/,spanish->texts/spanish/,french->texts/french/\".  Be sure to wrap the whole string with \"quotes\".)")
    public static String inputTextPath = null;

    @Option(gloss = "Number of times the character must be seen in order to be included.")
    public static int minCharCount = 10;

    @Option(gloss = "Prior probability of each language; ignore for uniform priors. Give multiple comma-separated language/prior pairs: \"english->0.7,spanish->0.2,french->0.1\". Be sure to wrap the whole string with \"quotes\". (Only relevant if multiple languages used.)  Default: Uniform priors.")
    public static String languagePriors = null;

    @Option(gloss = "Prior probability of sticking with the same language when moving between words in a code-switch model transition model. (Only relevant if multiple languages used.)")
    public static double pKeepSameLanguage = 0.999999d;

    @Option(gloss = "Paths to Alternate Spelling Replacement files. If just a simple path is given, the replacements will be applied to all languages.  For language-specific replacements, give multiple comma-separated language/path pairs: \"english->rules/en.txt,spanish->rules/sp.txt,french->rules/fr.txt\". Be sure to wrap the whole string with \"quotes\". Any languages for which no replacements are need can be safely ignored.")
    public static String alternateSpellingReplacementPaths = null;

    @Option(gloss = "Automatically insert \"long s\" characters into the language model training data?")
    public static boolean insertLongS = false;

    @Option(gloss = "Remove diacritics?")
    public static boolean removeDiacritics = false;

    @Option(gloss = "Treat backslashes in text as escape characters?")
    public static boolean escapes = false;

    @Option(gloss = "A set of valid characters. If a character with a diacritic is found but not in this set, the diacritic will be dropped. Other excluded characters will simply be dropped. Ignore to allow all characters.")
    public static Set<String> explicitCharacterSet = null;

    @Option(gloss = "LM character n-gram length. If just one language is used, or if all languages should use the same value, just give an integer.  If languages can have different values, give them as comma-separated language/integer pairs: \"english->6,spanish->4,french->4\"; be sure to wrap the whole string with \"quotes\".")
    public static String charNgramLength = "6";

    @Option(gloss = "Exponent on LM scores.")
    public static double lmPower = 4.0d;

    @Option(gloss = "Number of characters to use for training the LM.  Use 0 to indicate that the full training data should be used.  Default: Use all documents in full.")
    public static long lmCharCount = 0;

    /** Language name used when {@link #inputTextPath} contains no {@code LANGUAGE->} prefix. */
    private static final String DEFAULT_LANGUAGE_NAME = "NoLanguageNameGiven";

    /**
     * Orders (count, charIndex) pairs by descending count; ties are broken by descending
     * character index.
     */
    private static final Comparator<Tuple2<Integer, Integer>> BY_COUNT_THEN_INDEX_DESCENDING =
            new Comparator<Tuple2<Integer, Integer>>() {
                @Override
                public int compare(Tuple2<Integer, Integer> a, Tuple2<Integer, Integer> b) {
                    int byCount = b._1.compareTo(a._1);
                    return byCount != 0 ? byCount : b._2.compareTo(a._2);
                }
            };

    public static void main(String[] args) {
        System.out.println("InitializeLanguageModel");
        OcularRunnable main = new InitializeLanguageModel();
        main.doMain(main, args);
    }

    /**
     * Checks that the mandatory options are present.
     *
     * @throws IllegalArgumentException if {@code -outputLmPath} or {@code -inputTextPath} is unset
     */
    @Override
    protected void validateOptions() {
        if (outputLmPath == null) {
            throw new IllegalArgumentException("-outputLmPath not set");
        }
        if (inputTextPath == null) {
            throw new IllegalArgumentException("-inputTextPath not set");
        }
    }

    /**
     * Trains one sub-model per language, prints a summary of the configuration, and writes the
     * combined code-switch model to {@link #outputLmPath}.
     *
     * @param args remaining command-line arguments (not used)
     */
    @Override
    public void run(List<String> args) {
        Tuple2<Indexer<String>, List<Tuple3<Tuple2<String, TextReader>, Double, Integer>>> pathsReadersAndPriors =
                makePathsReadersAndPriors();
        Indexer<String> langIndexer = pathsReadersAndPriors._1;
        List<Tuple3<Tuple2<String, TextReader>, Double, Integer>> perLanguage = pathsReadersAndPriors._2;

        CharIndexer charIndexer = new CharIndexer();
        List<Tuple2<SingleLanguageModel, Double>> subLMs = makeMultipleSubLMs(perLanguage, charIndexer, langIndexer);
        charIndexer.lock();

        System.out.println("pKeepSameLanguage = " + pKeepSameLanguage);

        // Priors are normalized here for display only; the model below receives the raw values.
        double priorTotal = 0.0d;
        for (Tuple2<SingleLanguageModel, Double> subLM : subLMs) {
            priorTotal += subLM._2;
        }
        List<String> priorDescriptions = new ArrayList<>();
        List<String> ngramDescriptions = new ArrayList<>();
        for (int i = 0; i < langIndexer.size(); i++) {
            String language = langIndexer.getObject(i);
            priorDescriptions.add(language + " -> " + (subLMs.get(i)._2 / priorTotal));
            ngramDescriptions.add(language + " -> " + perLanguage.get(i)._3);
        }
        System.out.println("Language priors: " + joinWithCommas(priorDescriptions));
        System.out.println("Char ngram lengths: " + joinWithCommas(ngramDescriptions));

        List<String> allChars = new ArrayList<>();
        for (String c : charIndexer.getObjects()) {
            allChars.add(c);
        }
        Collections.sort(allChars);
        System.out.println("ALL POSSIBLE CHARACTERS: " + allChars);

        BasicCodeSwitchLanguageModel codeSwitchLM =
                new BasicCodeSwitchLanguageModel(subLMs, charIndexer, langIndexer, pKeepSameLanguage);
        System.out.println("writing LM to " + outputLmPath);
        writeLM(codeSwitchLM, outputLmPath);
    }

    /**
     * Parses the language-keyed options ({@link #inputTextPath}, {@link #languagePriors},
     * {@link #charNgramLength}, {@link #alternateSpellingReplacementPaths}) and builds a text
     * reader for every language.
     *
     * @return a language indexer, plus a list aligned with that indexer whose entries are
     *     ((text path, text reader), prior, character n-gram length)
     * @throws IllegalArgumentException if a comma-separated part is not of the form
     *     {@code LANGUAGE->VALUE}
     * @throws RuntimeException if the language sets of the options disagree, or a replacement
     *     rules file does not exist
     */
    public Tuple2<Indexer<String>, List<Tuple3<Tuple2<String, TextReader>, Double, Integer>>> makePathsReadersAndPriors() {
        String pathSpec = inputTextPath.contains("->") ? inputTextPath : DEFAULT_LANGUAGE_NAME + "->" + inputTextPath;
        // A HashMap (not insertion-ordered) is kept deliberately: its iteration order fixes the
        // language indices, and changing it would renumber languages in newly trained models.
        Map<String, String> languagePathMap = new HashMap<>();
        for (String part : pathSpec.split(",")) {
            String[] pair = splitLanguagePair(part, "lmPath", "PATH");
            languagePathMap.put(pair[0], pair[1]);
        }
        Set<String> languages = languagePathMap.keySet();

        Map<String, Double> languagePriorMap = new HashMap<>();
        if (languagePriors == null || languagePriors.isEmpty()) {
            for (String language : languages) {
                languagePriorMap.put(language, 1.0d);
            }
        } else {
            for (String part : languagePriors.split(",")) {
                String[] pair = splitLanguagePair(part, "languagePriors", "PRIOR");
                languagePriorMap.put(pair[0], Double.parseDouble(pair[1]));
            }
            if (!languages.equals(languagePriorMap.keySet())) {
                throw new RuntimeException("-inputTextPath and -languagePriors do not have the same set of languages: " + languages + " vs " + languagePriorMap.keySet());
            }
        }

        Map<String, Integer> ngramLengthMap = new HashMap<>();
        if (Pattern.matches("^\\d+$", charNgramLength)) {
            int sharedLength = Integer.parseInt(charNgramLength);
            for (String language : languages) {
                ngramLengthMap.put(language, sharedLength);
            }
        } else {
            for (String part : charNgramLength.split(",")) {
                String[] pair = splitLanguagePair(part, "charNgramLength", "LENGTH");
                ngramLengthMap.put(pair[0], Integer.parseInt(pair[1]));
            }
            if (!languages.equals(ngramLengthMap.keySet())) {
                throw new RuntimeException("-inputTextPath and -charNgramLength do not have the same set of languages: " + languages + " vs " + ngramLengthMap.keySet());
            }
        }

        Map<String, String> replacementPathMap = new HashMap<>();
        if (alternateSpellingReplacementPaths != null && !alternateSpellingReplacementPaths.isEmpty()) {
            if (alternateSpellingReplacementPaths.contains("->")) {
                for (String part : alternateSpellingReplacementPaths.split(",")) {
                    String[] pair = splitLanguagePair(part, "alternateSpellingReplacementPaths", "PATH");
                    if (!languages.contains(pair[0])) {
                        throw new RuntimeException("Language '" + pair[0] + "' appears in the alternateSpellingReplacementPaths argument but not in inputTextPath (" + languages + ")");
                    }
                    replacementPathMap.put(pair[0], pair[1]);
                }
            } else {
                // A bare path applies the same rules to every language.
                for (String language : languages) {
                    replacementPathMap.put(language, alternateSpellingReplacementPaths);
                }
            }
        }

        List<Tuple3<Tuple2<String, TextReader>, Double, Integer>> perLanguage = new ArrayList<>();
        Indexer<String> langIndexer = new HashMapIndexer<String>();
        for (String language : languages) {
            String textPath = languagePathMap.get(language);
            Double prior = languagePriorMap.get(language);
            Integer ngramLength = ngramLengthMap.get(language);
            String replacementPath = replacementPathMap.get(language);
            System.out.println("For language '" + language + "', using text in " + textPath + ", prior=" + prior + (replacementPath != null ? ", alternate spelling replacement rules in " + replacementPath : ""));
            TextReader reader = makeTextReader(replacementPath);
            // Indexing in the same loop keeps langIndexer positions aligned with perLanguage.
            langIndexer.getIndex(language);
            perLanguage.add(Tuple3.Tuple3(Tuple2.Tuple2(textPath, reader), prior, ngramLength));
        }
        return Tuple2.Tuple2(langIndexer, perLanguage);
    }

    /**
     * Splits one comma-separated option part of the form {@code LANGUAGE->VALUE}.
     *
     * @param part the raw part, e.g. {@code "english->texts/en/"}
     * @param argName option name used in the error message
     * @param valueLabel placeholder for VALUE used in the error message
     * @return a two-element array {language, value}, both trimmed
     * @throws IllegalArgumentException if the part does not contain exactly one {@code ->}
     */
    private static String[] splitLanguagePair(String part, String argName, String valueLabel) {
        String[] pieces = part.split("->");
        if (pieces.length != 2) {
            throw new IllegalArgumentException("malformed " + argName + " argument: comma-separated part must be of the form \"LANGUAGE->" + valueLabel + "\", was: " + part);
        }
        return new String[] { pieces[0].trim(), pieces[1].trim() };
    }

    /**
     * Builds the text-reader chain configured by the static options.
     *
     * @param replacementRulesPath alternate-spelling rules file to apply last, or {@code null}
     * @return the outermost reader of the chain
     */
    private TextReader makeTextReader(String replacementRulesPath) {
        TextReader reader = new BlacklistCharacterSetTextReader(Charset.BANNED_CHARS, new BasicTextReader(escapes));
        if (explicitCharacterSet != null && !explicitCharacterSet.isEmpty()) {
            reader = new WhitelistCharacterSetTextReader(explicitCharacterSet, reader);
        }
        if (removeDiacritics) {
            reader = new RemoveAllDiacriticsTextReader(reader);
        }
        if (insertLongS) {
            reader = new ConvertLongSTextReader(reader);
        }
        if (replacementRulesPath != null) {
            reader = handleReplacementRulesOption(reader, replacementRulesPath);
        }
        return reader;
    }

    /**
     * Wraps {@code textReader} in a {@link ReplaceSomeTextReader} using rules loaded from a file,
     * printing each loaded rule.
     *
     * @throws RuntimeException if the rules file does not exist
     */
    private TextReader handleReplacementRulesOption(TextReader textReader, String replacementsPath) {
        if (!new File(replacementsPath).exists()) {
            throw new RuntimeException("replacementsFile [" + replacementsPath + "] does not exist");
        }
        List<Tuple2<Tuple2<List<String>, List<String>>, Integer>> rules = ReplaceSomeTextReader.loadRulesFromFile(replacementsPath);
        for (Tuple2<Tuple2<List<String>, List<String>>, Integer> rule : rules) {
            System.out.println("    " + rule);
        }
        return new ReplaceSomeTextReader(rules, textReader);
    }

    /**
     * Trains one Kneser-Ney character n-gram model per language and registers extra characters
     * (long s, ligature components, diacritic-stripped and tilde-decorated forms) in the
     * character indexer, which is locked before returning.
     *
     * @param pathsReadersAndPriors per-language ((path, reader), prior, n-gram length), aligned
     *     with {@code langIndexer}
     * @param charIndexer character indexer, populated as a side effect
     * @param langIndexer language indexer
     * @return (sub-model, prior) pairs in language-index order
     */
    private List<Tuple2<SingleLanguageModel, Double>> makeMultipleSubLMs(List<Tuple3<Tuple2<String, TextReader>, Double, Integer>> pathsReadersAndPriors, Indexer<String> charIndexer, Indexer<String> langIndexer) {
        List<Tuple2<SingleLanguageModel, Double>> subLMs = new ArrayList<>();
        long charsToTake = lmCharCount > 0 ? lmCharCount : Long.MAX_VALUE;
        for (int langIdx = 0; langIdx < langIndexer.size(); langIdx++) {
            Tuple3<Tuple2<String, TextReader>, Double, Integer> config = pathsReadersAndPriors.get(langIdx);
            String language = langIndexer.getObject(langIdx);
            String textPath = config._1._1;
            TextReader textReader = config._1._2;
            System.out.println(language + " text reader: " + textReader);

            CorpusCounter counter = new CorpusCounter(config._3);
            long numCharsUsed = 0;
            for (List<String> fileChars : readFileChars(textPath, textReader, charsToTake)) {
                counter.countChars(fileChars, charIndexer, 0);
                numCharsUsed += fileChars.size();
            }
            System.out.println("  using " + numCharsUsed + " characters for " + language + " read from " + textPath);

            Set<Integer> activeChars = counter.getActiveCharacters();
            List<Tuple2<Integer, Integer>> countsAndChars = new ArrayList<>();
            for (Map.Entry<Integer, Integer> unigram : counter.getUnigramCounts().entrySet()) {
                countsAndChars.add(Tuple2.Tuple2(unigram.getValue(), unigram.getKey()));
            }
            Collections.sort(countsAndChars, BY_COUNT_THEN_INDEX_DESCENDING);
            for (Tuple2<Integer, Integer> countAndChar : countsAndChars) {
                int count = countAndChar._1;
                int charIdx = countAndChar._2;
                String note = "";
                if (count < minCharCount) {
                    activeChars.remove(countAndChar._2);
                    note = "[skipped due to count < " + minCharCount + "]";
                }
                String c = charIndexer.getObject(charIdx);
                System.out.println("    " + count + "  " + c + "   " + StringHelper.toUnicode(c) + "   " + note);
            }

            // Space and the universal punctuation are always allowed, regardless of counts.
            activeChars.add(charIndexer.getIndex(Charset.SPACE));
            System.out.println("Including 'universal punctuation' chars: " + Charset.UNIV_PUNC);
            for (String punc : Charset.UNIV_PUNC) {
                activeChars.add(charIndexer.getIndex(punc));
            }

            List<String> activeCharStrings = new ArrayList<>();
            for (Integer activeIdx : activeChars) {
                activeCharStrings.add(charIndexer.getObject(activeIdx));
            }
            Collections.sort(activeCharStrings);
            System.out.println(language + ": " + activeCharStrings);

            // Use the filtered/augmented local set rather than re-querying the counter, so the
            // minCharCount filtering and the added space/punctuation are what the model sees.
            SingleLanguageModel subLM = new NgramLanguageModel(charIndexer, counter.getCounts(), activeChars, NgramLanguageModel.LMType.KNESER_NEY, lmPower);
            subLMs.add(Tuple2.Tuple2(subLM, config._2));
        }

        charIndexer.getIndex(Charset.LONG_S);
        for (Map.Entry<String, String> ligature : Charset.LIGATURES.entrySet()) {
            List<String> ligatureChars = Charset.readNormalizeCharacters(ligature.getKey());
            if (ligatureChars.size() > 1) {
                throw new RuntimeException("Ligature [" + ligature.getKey() + "] has more than one character: " + ligatureChars);
            }
            charIndexer.getIndex(ligatureChars.get(0));
            for (String component : Charset.readNormalizeCharacters(ligature.getValue())) {
                charIndexer.getIndex(component);
            }
        }

        // Iterate over a snapshot: the loop body adds entries to the indexer, which would break
        // iteration if getObjects() returned a live view.
        List<String> indexedChars = new ArrayList<>();
        for (String c : charIndexer.getObjects()) {
            indexedChars.add(c);
        }
        for (String c : indexedChars) {
            String baseChar = Charset.removeAnyDiacriticFromChar(c);
            if (Charset.CHARS_THAT_CAN_BE_DECORATED_WITH_AN_ELISION_TILDE.contains(c)) {
                charIndexer.getIndex(Charset.addTilde(c));
            }
            if (Charset.CHARS_THAT_CAN_BE_DECORATED_WITH_AN_ELISION_TILDE.contains(baseChar)) {
                charIndexer.getIndex(Charset.addTilde(baseChar));
            }
            charIndexer.getIndex(baseChar);
        }
        charIndexer.lock();
        return subLMs;
    }

    /**
     * Reads characters from every file found recursively under {@code path}, one list per file.
     *
     * <p>Empty lines are skipped and a space is appended to every other line before it is passed
     * to {@code textReader}. Reading stops after the first non-empty line at which the running
     * total reaches {@code charsToTake}; the characters of the file being read at that point are
     * still returned.
     *
     * @param path file or directory to read
     * @param textReader converts a line into a list of characters
     * @param charsToTake character budget ({@link Long#MAX_VALUE} for "no limit")
     * @return per-file character lists
     */
    private List<List<String>> readFileChars(String path, TextReader textReader, long charsToTake) {
        List<List<String>> allFileChars = new ArrayList<>();
        long numChars = 0;
        for (File file : FileUtil.recursiveFiles(path)) {
            List<String> fileChars = new ArrayList<>();
            boolean limitReached = false;
            for (String line : f.readLines(file.getPath())) {
                if (line.isEmpty()) {
                    continue;
                }
                for (String c : textReader.readCharacters(line + Charset.SPACE)) {
                    // NOTE(review): the return value of normalizeChar is discarded, so characters
                    // are stored as returned by the reader; confirm whether
                    // fileChars.add(Charset.normalizeChar(c)) was intended.
                    Charset.normalizeChar(c);
                    fileChars.add(c);
                    numChars++;
                }
                if (numChars >= charsToTake) {
                    limitReached = true;
                    break;
                }
            }
            // Keep the partially read file instead of silently dropping it when the limit hits.
            allFileChars.add(fileChars);
            if (limitReached) {
                break;
            }
        }
        return allFileChars;
    }

    /**
     * Loads a gzip-compressed, Java-serialized {@link LanguageModel}.
     *
     * <p>SECURITY NOTE: this uses Java native deserialization; only load model files from
     * trusted sources.
     *
     * @param lmPath path of the serialized model
     * @return the deserialized model
     * @throws RuntimeException if the file does not exist, cannot be read, or does not contain
     *     a known class
     */
    public static LanguageModel readLM(String lmPath) {
        File file = new File(lmPath);
        if (!file.exists()) {
            throw new RuntimeException("Serialized LanguageModel file " + lmPath + " not found");
        }
        try (FileInputStream fileIn = new FileInputStream(file);
             ObjectInputStream in = new ObjectInputStream(new GZIPInputStream(fileIn))) {
            return (LanguageModel) in.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Loads a serialized model via {@link #readLM} and casts it to
     * {@link CodeSwitchLanguageModel}.
     */
    public static CodeSwitchLanguageModel readCodeSwitchLM(String lmPath) {
        return (CodeSwitchLanguageModel) readLM(lmPath);
    }

    /**
     * Serializes {@code codeSwitchLanguageModel} to {@code lmPath} as a gzip-compressed Java
     * object stream, creating parent directories as needed.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised while writing or closing
     */
    public static void writeLM(CodeSwitchLanguageModel codeSwitchLanguageModel, String lmPath) {
        File parent = new File(lmPath).getAbsoluteFile().getParentFile();
        if (parent != null) {
            parent.mkdirs();
        }
        // Closing the ObjectOutputStream finishes the gzip stream (writes its trailer); the
        // resources are closed in reverse order, and the file stream is closed even if gzip
        // setup fails.
        try (FileOutputStream fileOut = new FileOutputStream(lmPath);
             ObjectOutputStream out = new ObjectOutputStream(new GZIPOutputStream(fileOut))) {
            out.writeObject(codeSwitchLanguageModel);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Joins {@code parts} with {@code ", "}. */
    private static String joinWithCommas(List<String> parts) {
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < parts.size(); i++) {
            if (i > 0) {
                joined.append(", ");
            }
            joined.append(parts.get(i));
        }
        return joined.toString();
    }
}
