/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.postag;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.concurrent.atomic.AtomicInteger;
import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.ml.BeamSearch;
import opennlp.tools.ml.EventModelSequenceTrainer;
import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.SequenceTrainer;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.model.MaxentModel;
import opennlp.tools.ml.model.SequenceClassificationModel;
import opennlp.tools.ngram.NGramModel;
import opennlp.tools.postag.MutableTagDictionary;
import opennlp.tools.postag.POSContextGenerator;
import opennlp.tools.postag.POSDictionary;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSSample;
import opennlp.tools.postag.POSSampleEventStream;
import opennlp.tools.postag.POSSampleSequenceStream;
import opennlp.tools.postag.POSTagger;
import opennlp.tools.postag.POSTaggerFactory;
import opennlp.tools.postag.TagDictionary;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.Sequence;
import opennlp.tools.util.SequenceValidator;
import opennlp.tools.util.StringList;
import opennlp.tools.util.StringUtil;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.featuregen.StringPattern;
import opennlp.tools.util.model.ModelType;

/**
 * Part-of-speech tagger that delegates sequence decoding to a
 * {@link SequenceClassificationModel} obtained from a {@link POSModel}.
 *
 * <p>The most recently decoded sequence is kept in an instance field so that
 * {@link #probs()} can report its probabilities. Because of that mutable
 * state, sharing one instance between threads without external
 * synchronization looks unsafe.
 */
public class POSTaggerME implements POSTagger {
    /** Beam size used when neither the model manifest nor the training parameters specify one. */
    public static final int DEFAULT_BEAM_SIZE = 3;

    /** Key under which the beam size is looked up in the model manifest and in the training settings. */
    private static final String BEAM_SIZE_PARAMETER = "BeamSize";

    private POSModel modelPackage;
    protected POSContextGenerator contextGen;
    protected TagDictionary tagDictionary;
    protected Dictionary ngramDictionary;
    protected boolean useClosedClassTagsFilter = false;
    /** Beam size; also the number of sequences requested by the {@code topKSequences} methods. */
    protected int size;
    /** Result of the most recent {@code tag(...)} call that decodes a single best sequence. */
    private Sequence bestSequence;
    private SequenceClassificationModel<String> model;
    private SequenceValidator<String> sequenceValidator;

    /**
     * Creates a tagger with an explicit beam size and cache size.
     *
     * <p>If the model package provides a sequence model it is used directly;
     * otherwise a {@link BeamSearch} is built around the package's event model.
     *
     * @param model the model package supplying the factory and the classification model
     * @param beamSize the beam size; also passed to the factory when requesting the context generator
     * @param cacheSize value handed to the {@link BeamSearch} constructor as its third argument
     */
    @Deprecated
    public POSTaggerME(POSModel model, int beamSize, int cacheSize) {
        POSTaggerFactory factory = model.getFactory();
        this.modelPackage = model;
        this.contextGen = factory.getPOSContextGenerator(beamSize);
        this.tagDictionary = factory.getTagDictionary();
        this.size = beamSize;
        this.sequenceValidator = factory.getSequenceValidator();

        // Look the sequence model up once and fall back to beam search over the event model.
        SequenceClassificationModel<String> sequenceModel = model.getPosSequenceModel();
        if (sequenceModel != null) {
            this.model = sequenceModel;
        } else {
            this.model = new BeamSearch<String>(beamSize, model.getPosModel(), cacheSize);
        }
    }

    /**
     * Creates a tagger whose beam size is read from the model manifest
     * (falling back to {@link #DEFAULT_BEAM_SIZE}) and whose cache size is 0.
     *
     * @param model the model package supplying the factory and the classification model
     * @throws NumberFormatException if the manifest's beam size entry is not a valid integer
     */
    public POSTaggerME(POSModel model) {
        this(model, resolveBeamSize(model), 0);
    }

    /** Reads the beam size from the model manifest, defaulting to {@link #DEFAULT_BEAM_SIZE}. */
    private static int resolveBeamSize(POSModel model) {
        String beamSizeString = model.getManifestProperty(BEAM_SIZE_PARAMETER);
        if (beamSizeString == null) {
            return DEFAULT_BEAM_SIZE;
        }
        return Integer.parseInt(beamSizeString);
    }

    /**
     * Returns the number of outcomes of the underlying model.
     *
     * @return the length of the model's outcome array
     */
    @Deprecated
    public int getNumTags() {
        return this.model.getOutcomes().length;
    }

    /**
     * Returns the outcomes of the underlying model.
     *
     * @return the outcome array as reported by the model
     */
    public String[] getAllPosTags() {
        return this.model.getOutcomes();
    }

    /**
     * Tags a sentence given as a list of tokens and remembers the decoded sequence.
     *
     * @param sentence the tokens of the sentence
     * @return the outcomes of the best sequence
     */
    @Override
    @Deprecated
    public List<String> tag(List<String> sentence) {
        String[] tokens = sentence.toArray(new String[sentence.size()]);
        this.bestSequence = this.model.bestSequence(tokens, null, this.contextGen, this.sequenceValidator);
        return this.bestSequence.getOutcomes();
    }

    /**
     * Tags a sentence without additional context.
     *
     * @param sentence the tokens of the sentence
     * @return one tag per decoded outcome
     */
    @Override
    public String[] tag(String[] sentence) {
        return this.tag(sentence, null);
    }

    /**
     * Tags a sentence and remembers the decoded sequence for {@link #probs()}.
     *
     * @param sentence the tokens of the sentence
     * @param additionaContext extra context forwarded unchanged to the model; may be {@code null}
     * @return the outcomes of the best sequence as an array
     */
    @Override
    public String[] tag(String[] sentence, Object[] additionaContext) {
        this.bestSequence = this.model.bestSequence(sentence, additionaContext, this.contextGen, this.sequenceValidator);
        List<String> outcomes = this.bestSequence.getOutcomes();
        return outcomes.toArray(new String[outcomes.size()]);
    }

    /**
     * Returns the outcomes of up to {@code numTaggings} sequences for a sentence.
     * This method does not update the sequence reported by {@link #probs()}.
     *
     * @param numTaggings the number of sequences requested from the model
     * @param sentence the tokens of the sentence
     * @return one tag array per sequence returned by the model
     */
    public String[][] tag(int numTaggings, String[] sentence) {
        Sequence[] bestSequences = this.model.bestSequences(numTaggings, sentence, null, this.contextGen, this.sequenceValidator);
        String[][] tags = new String[bestSequences.length][];
        for (int si = 0; si < tags.length; ++si) {
            List<String> outcomes = bestSequences[si].getOutcomes();
            tags[si] = outcomes.toArray(new String[outcomes.size()]);
        }
        return tags;
    }

    /**
     * Requests {@code size} sequences for a sentence given as a list of tokens.
     *
     * @param sentence the tokens of the sentence
     * @return the sequences returned by the model
     */
    @Override
    @Deprecated
    public Sequence[] topKSequences(List<String> sentence) {
        String[] tokens = sentence.toArray(new String[sentence.size()]);
        return this.model.bestSequences(this.size, tokens, null, this.contextGen, this.sequenceValidator);
    }

    /**
     * Requests {@code size} sequences for a sentence without additional context.
     *
     * @param sentence the tokens of the sentence
     * @return the sequences returned by the model
     */
    @Override
    public Sequence[] topKSequences(String[] sentence) {
        return this.topKSequences(sentence, null);
    }

    /**
     * Requests {@code size} sequences for a sentence.
     *
     * @param sentence the tokens of the sentence
     * @param additionaContext extra context forwarded unchanged to the model; may be {@code null}
     * @return the sequences returned by the model
     */
    @Override
    public Sequence[] topKSequences(String[] sentence, Object[] additionaContext) {
        return this.model.bestSequences(this.size, sentence, additionaContext, this.contextGen, this.sequenceValidator);
    }

    /**
     * Writes the probabilities of the most recently tagged sequence into {@code probs}.
     * A single-best {@code tag(...)} call must have been made first; otherwise no
     * sequence is stored and this call fails with a {@link NullPointerException}.
     *
     * @param probs the array handed to the stored sequence to be filled
     */
    public void probs(double[] probs) {
        this.bestSequence.getProbs(probs);
    }

    /**
     * Returns the probabilities of the most recently tagged sequence.
     * A single-best {@code tag(...)} call must have been made first; otherwise no
     * sequence is stored and this call fails with a {@link NullPointerException}.
     *
     * @return the probabilities reported by the stored sequence
     */
    public double[] probs() {
        return this.bestSequence.getProbs();
    }

    /**
     * Tags a whitespace-delimited sentence and renders it as {@code token/tag} pairs
     * separated by single spaces.
     *
     * @param sentence the sentence; it is split with a default {@link StringTokenizer}
     * @return the tagged sentence, trimmed
     */
    @Override
    @Deprecated
    public String tag(String sentence) {
        List<String> toks = new ArrayList<String>();
        StringTokenizer st = new StringTokenizer(sentence);
        while (st.hasMoreTokens()) {
            toks.add(st.nextToken());
        }
        List<String> tags = this.tag(toks);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < tags.size(); ++i) {
            sb.append(toks.get(i)).append("/").append(tags.get(i)).append(" ");
        }
        return sb.toString().trim();
    }

    /**
     * Returns all outcomes for the token at {@code index}, ordered by descending probability.
     *
     * @param words the tokens of the sentence
     * @param tags the tags passed to the context generator alongside the tokens
     * @param index the position of the token to evaluate
     * @return the outcomes, most probable first
     * @throws UnsupportedOperationException if the model package has no event model
     */
    public String[] getOrderedTags(List<String> words, List<String> tags, int index) {
        return this.getOrderedTags(words, tags, index, null);
    }

    /**
     * Returns all outcomes for the token at {@code index}, ordered by descending
     * probability, and optionally reports the matching probabilities.
     *
     * <p>Every outcome appears exactly once. Outcomes with equal probability are
     * returned in ascending outcome-index order.
     *
     * @param words the tokens of the sentence
     * @param tags the tags passed to the context generator alongside the tokens
     * @param index the position of the token to evaluate
     * @param tprobs if non-null, receives the probability of each returned outcome at
     *        the same position; it must be at least as long as the number of outcomes
     * @return the outcomes, most probable first
     * @throws UnsupportedOperationException if the model package has no event model
     */
    public String[] getOrderedTags(List<String> words, List<String> tags, int index, double[] tprobs) {
        MaxentModel posModel = this.modelPackage.getPosModel();
        if (posModel == null) {
            throw new UnsupportedOperationException("This method can only be called if the classifcation model is an event model!");
        }

        String[] wordArray = words.toArray(new String[words.size()]);
        String[] tagArray = tags.toArray(new String[tags.size()]);
        double[] probs = posModel.eval(this.contextGen.getContext(index, wordArray, tagArray, (Object[]) null));

        String[] orderedTags = new String[probs.length];
        // Track picked outcomes separately instead of overwriting their probability with 0.0:
        // with the overwrite, an already-picked outcome is indistinguishable from a genuine
        // zero-probability outcome and could be emitted again in its place.
        boolean[] emitted = new boolean[probs.length];
        for (int rank = 0; rank < probs.length; ++rank) {
            int best = -1;
            for (int ti = 0; ti < probs.length; ++ti) {
                if (emitted[ti]) {
                    continue;
                }
                // Strict comparison keeps the lowest index among equal probabilities.
                if (best < 0 || probs[ti] > probs[best]) {
                    best = ti;
                }
            }
            emitted[best] = true;
            orderedTags[rank] = posModel.getOutcome(best);
            if (tprobs != null) {
                tprobs[rank] = probs[best];
            }
        }
        return orderedTags;
    }

    /**
     * Trains a model with the trainer type selected by the training parameters.
     *
     * <p>Event and event-sequence trainers produce an event model that is packaged
     * together with the beam size; sequence trainers produce a sequence model that
     * is packaged without one.
     *
     * @param languageCode the language code stored in the resulting model
     * @param samples the training samples
     * @param trainParams the training settings; an optional {@code "BeamSize"} entry
     *        overrides {@link #DEFAULT_BEAM_SIZE}
     * @param posFactory the factory supplying the context generator and stored in the model
     * @return the trained model package
     * @throws IOException declared for the sample streams and trainers used here
     * @throws IllegalArgumentException if the trainer type is not one of the three handled types
     * @throws NumberFormatException if the {@code "BeamSize"} setting is not a valid integer
     */
    public static POSModel train(String languageCode, ObjectStream<POSSample> samples, TrainingParameters trainParams, POSTaggerFactory posFactory) throws IOException {
        String beamSizeString = trainParams.getSettings().get(BEAM_SIZE_PARAMETER);
        int beamSize = DEFAULT_BEAM_SIZE;
        if (beamSizeString != null) {
            beamSize = Integer.parseInt(beamSizeString);
        }
        POSContextGenerator contextGenerator = posFactory.getPOSContextGenerator();
        HashMap<String, String> manifestInfoEntries = new HashMap<String, String>();
        TrainerFactory.TrainerType trainerType = TrainerFactory.getTrainerType(trainParams.getSettings());
        MaxentModel posModel = null;
        SequenceClassificationModel<String> seqPosModel = null;
        if (TrainerFactory.TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) {
            POSSampleEventStream es = new POSSampleEventStream(samples, contextGenerator);
            EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams.getSettings(), manifestInfoEntries);
            posModel = trainer.train(es);
        } else if (TrainerFactory.TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) {
            POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
            EventModelSequenceTrainer trainer = TrainerFactory.getEventModelSequenceTrainer(trainParams.getSettings(), manifestInfoEntries);
            posModel = trainer.train(ss);
        } else if (TrainerFactory.TrainerType.SEQUENCE_TRAINER.equals(trainerType)) {
            SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(trainParams.getSettings(), manifestInfoEntries);
            POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator);
            seqPosModel = trainer.train(ss);
        } else {
            throw new IllegalArgumentException("Trainer type is not supported: " + trainerType);
        }
        if (posModel != null) {
            return new POSModel(languageCode, posModel, beamSize, manifestInfoEntries, posFactory);
        }
        return new POSModel(languageCode, seqPosModel, manifestInfoEntries, posFactory);
    }

    /**
     * Trains a model using a factory built from the given dictionaries.
     *
     * @param languageCode the language code stored in the resulting model
     * @param samples the training samples
     * @param trainParams the training settings
     * @param tagDictionary the tag dictionary handed to the factory
     * @param ngramDictionary the n-gram dictionary handed to the factory
     * @return the trained model package
     * @throws IOException if training fails with an I/O error
     */
    public static POSModel train(String languageCode, ObjectStream<POSSample> samples, TrainingParameters trainParams, POSDictionary tagDictionary, Dictionary ngramDictionary) throws IOException {
        return POSTaggerME.train(languageCode, samples, trainParams, new POSTaggerFactory(ngramDictionary, tagDictionary));
    }

    /**
     * Trains a model from explicit algorithm, cutoff and iteration values by
     * translating them into {@link TrainingParameters}.
     *
     * @param languageCode the language code stored in the resulting model
     * @param samples the training samples
     * @param modelType its string form is stored as the {@code "Algorithm"} setting
     * @param tagDictionary the tag dictionary handed to the factory
     * @param ngramDictionary the n-gram dictionary handed to the factory
     * @param cutoff stored as the {@code "Cutoff"} setting
     * @param iterations stored as the {@code "Iterations"} setting
     * @return the trained model package
     * @throws IOException if training fails with an I/O error
     */
    @Deprecated
    public static POSModel train(String languageCode, ObjectStream<POSSample> samples, ModelType modelType, POSDictionary tagDictionary, Dictionary ngramDictionary, int cutoff, int iterations) throws IOException {
        TrainingParameters params = new TrainingParameters();
        params.put("Algorithm", modelType.toString());
        params.put("Iterations", Integer.toString(iterations));
        params.put("Cutoff", Integer.toString(cutoff));
        return POSTaggerME.train(languageCode, samples, params, tagDictionary, ngramDictionary);
    }

    /**
     * Builds a dictionary from the 1-grams of all non-empty sample sentences.
     *
     * @param samples the samples whose sentences are added to the n-gram model
     * @param cutoff the lower bound passed to the n-gram model's cutoff step
     * @return the dictionary produced by the n-gram model
     * @throws IOException if reading the samples fails
     */
    public static Dictionary buildNGramDictionary(ObjectStream<POSSample> samples, int cutoff) throws IOException {
        NGramModel ngramModel = new NGramModel();
        POSSample sample;
        while ((sample = samples.read()) != null) {
            String[] words = sample.getSentence();
            if (words.length > 0) {
                ngramModel.add(new StringList(words), 1, 1);
            }
        }
        ngramModel.cutoff(cutoff, Integer.MAX_VALUE);
        return ngramModel.toDictionary(true);
    }

    /**
     * Extends a tag dictionary with word/tag pairs observed in the samples.
     *
     * <p>Tokens containing a digit are skipped. For every other token the
     * occurrences of each tag are counted; tags already present in the dictionary
     * for that word start at {@code cutoff}, so they survive the cutoff filter.
     * Finally every word is written back with the tags whose count reached
     * {@code cutoff}. Progress and elapsed time are printed to {@code System.out}.
     *
     * @param samples the samples to read words and tags from
     * @param dict the dictionary to update
     * @param cutoff the minimum count a tag needs to be stored for a word
     * @throws IOException if reading the samples fails
     */
    public static void populatePOSDictionary(ObjectStream<POSSample> samples, MutableTagDictionary dict, int cutoff) throws IOException {
        System.out.println("Expanding POS Dictionary ...");
        long start = System.nanoTime();

        // word -> (tag -> number of observations)
        Map<String, Map<String, AtomicInteger>> newEntries = new HashMap<String, Map<String, AtomicInteger>>();
        POSSample sample;
        while ((sample = samples.read()) != null) {
            String[] words = sample.getSentence();
            String[] tags = sample.getTags();
            for (int i = 0; i < words.length; ++i) {
                if (StringPattern.recognize(words[i]).containsDigit()) {
                    continue;
                }
                String word = dict.isCaseSensitive() ? words[i] : StringUtil.toLowerCase(words[i]);

                Map<String, AtomicInteger> tagCounts = newEntries.get(word);
                if (tagCounts == null) {
                    tagCounts = new HashMap<String, AtomicInteger>();
                    newEntries.put(word, tagCounts);
                }

                // Seed tags the dictionary already knows at the cutoff so they are never dropped.
                String[] dictTags = dict.getTags(word);
                if (dictTags != null) {
                    for (String tag : dictTags) {
                        if (!tagCounts.containsKey(tag)) {
                            tagCounts.put(tag, new AtomicInteger(cutoff));
                        }
                    }
                }

                AtomicInteger count = tagCounts.get(tags[i]);
                if (count == null) {
                    tagCounts.put(tags[i], new AtomicInteger(1));
                } else {
                    count.incrementAndGet();
                }
            }
        }

        for (Map.Entry<String, Map<String, AtomicInteger>> wordEntry : newEntries.entrySet()) {
            List<String> tagsForWord = new ArrayList<String>();
            for (Map.Entry<String, AtomicInteger> tagEntry : wordEntry.getValue().entrySet()) {
                if (tagEntry.getValue().get() >= cutoff) {
                    tagsForWord.add(tagEntry.getKey());
                }
            }
            if (!tagsForWord.isEmpty()) {
                dict.put(wordEntry.getKey(), tagsForWord.toArray(new String[tagsForWord.size()]));
            }
        }
        System.out.println("... finished expanding POS Dictionary. [" + (System.nanoTime() - start) / 1000000L + "ms]");
    }
}

