/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.transform;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.StatFunctions;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Random;
import java.util.logging.Logger;

/**
 * Command-line tool that removes tokens of words that are strongly associated with a
 * particular document label.
 *
 * <p>For every word, the proportion of each label's tokens that are that word is computed.
 * A gamma distribution is fitted to those per-label proportions by matching mean and
 * variance, and its {@code 1 - threshold} quantile is used as a cutoff. In a label where the
 * word's proportion exceeds the cutoff, each token of that word is kept independently with
 * probability {@code cutoff / proportion}; all other tokens are kept.
 */
public class DownsampleLabelWords {
    private static Logger logger = MalletLogger.getLogger(DownsampleLabelWords.class.getName());
    static CommandOption.File inputFile = new CommandOption.File(DownsampleLabelWords.class, "input", "FILE", true, null, "Read the instance list from this file. This should be a Mallet instance list preserving feature sequence and a class label.", null);
    static CommandOption.File outputFile = new CommandOption.File(DownsampleLabelWords.class, "output", "FILE", true, null, "Write pruned instance list to this file.", null);
    static CommandOption.File reportFile = new CommandOption.File(DownsampleLabelWords.class, "report-file", "FILE", true, new File("removed_words.tsv"), "Write a tab-delimited report on words that were removed to this file", null);
    static CommandOption.Integer verboseInstances = new CommandOption.Integer(DownsampleLabelWords.class, "show", "INTEGER", false, 0, "Display the first [this number] instances, showing any deletions. This option is intended to help you feel confident that you know what this process is doing.", null);
    static CommandOption.Double samplingThreshold = new CommandOption.Double(DownsampleLabelWords.class, "threshold", "NUMBER", true, 0.05, "Threshold value for deciding whether a word is over-represented. Lower values will remove fewer tokens, higher values will remove more. The default should be a good choice for most applications.", null);
    static CommandOption.Integer randomSeed = new CommandOption.Integer(DownsampleLabelWords.class, "random-seed", "INTEGER", true, 0, "The random seed for randomly selecting subset of input words. Use this option if you need to repeat a process exactly.", null);

    /**
     * Reads a labelled instance list and randomly drops tokens of over-represented words.
     * Writes the downsampled list to {@code --output}, and a per-(word, label) removal
     * report to {@code --report-file}.
     *
     * @param args command-line options; see the {@link CommandOption} fields of this class
     * @throws FileNotFoundException if the report file cannot be created
     * @throws IOException if writing the report fails
     * @throws IllegalArgumentException if the input instances do not carry class labels
     */
    public static void main(String[] args) throws FileNotFoundException, IOException {
        CommandOption.setSummary(DownsampleLabelWords.class, "A tool for removing words that are strongly associated with a particular document label.");
        CommandOption.process(DownsampleLabelWords.class, args);
        if (args.length == 0) {
            CommandOption.getList(DownsampleLabelWords.class).printUsage(false);
            System.exit(-1);
        }
        Random random = randomSeed.wasInvoked() ? new Random(randomSeed.value) : new Random();
        InstanceList instances = InstanceList.load(inputFile.value);
        Alphabet alphabet = instances.getDataAlphabet();
        if (!(instances.getTargetAlphabet() instanceof LabelAlphabet)) {
            throw new IllegalArgumentException("Input instances must have class labels (a LabelAlphabet target alphabet): " + inputFile.value);
        }
        LabelAlphabet labelAlphabet = (LabelAlphabet) instances.getTargetAlphabet();
        int numWords = alphabet.size();
        int numLabels = labelAlphabet.size();
        System.out.format("%d words, %d labels\n", numWords, numLabels);

        // Pass 1: token counts per (word, label), and total token count per label.
        int[][] wordLabelCounts = new int[numWords][numLabels];
        int[] labelCounts = new int[numLabels];
        for (Instance instance : instances) {
            int label = ((Label) instance.getTarget()).getIndex();
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            labelCounts[label] += tokens.size();
            for (int position = 0; position < tokens.size(); ++position) {
                wordLabelCounts[tokens.getIndexAtPosition(position)][label]++;
            }
        }

        double[][] labelWordProbs = computeKeepProbabilities(wordLabelCounts, labelCounts, samplingThreshold.value);

        // Pass 2: sample each token independently; one nextDouble() call per token, in order.
        InstanceList downsampledInstances = new InstanceList(instances.getPipe());
        int instanceCounter = 0;
        int inputTokens = 0;
        int outputTokens = 0;
        int[][] wordLabelRemovalCounts = new int[numWords][numLabels];
        for (Instance instance : instances) {
            int label = ((Label) instance.getTarget()).getIndex();
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            // Non-null only for the first --show instances; removed words are shown in [brackets].
            StringBuilder instanceDisplay = instanceCounter < verboseInstances.value ? new StringBuilder() : null;
            int[] sampledWords = new int[tokens.size()];
            int actualLength = 0;
            for (int position = 0; position < tokens.size(); ++position) {
                int type = tokens.getIndexAtPosition(position);
                ++inputTokens;
                if (random.nextDouble() < labelWordProbs[label][type]) {
                    sampledWords[actualLength++] = type;
                    ++outputTokens;
                    if (instanceDisplay != null) {
                        instanceDisplay.append(alphabet.lookupObject(type) + " ");
                    }
                } else {
                    wordLabelRemovalCounts[type][label]++;
                    if (instanceDisplay != null) {
                        instanceDisplay.append("[" + alphabet.lookupObject(type) + "] ");
                    }
                }
            }
            if (instanceDisplay != null) {
                logger.info(instanceDisplay.toString());
            }
            FeatureSequence downsampledFS = new FeatureSequence(alphabet, sampledWords, actualLength);
            downsampledInstances.add(new Instance(downsampledFS, instance.getTarget(), instance.getName(), instance.getSource()));
            ++instanceCounter;
        }
        if (reportFile.value != null) {
            writeReport(reportFile.value, alphabet, labelAlphabet, wordLabelRemovalCounts);
        }
        logger.info("reduced " + inputTokens + " to " + outputTokens + " tokens");
        downsampledInstances.save(outputFile.value);
    }

    /**
     * Computes, for every (label, word) pair, the probability of keeping a token of that word
     * in a document with that label.
     *
     * @param wordLabelCounts token counts indexed {@code [word][label]}
     * @param labelCounts total token count per label
     * @param samplingThreshold upper-tail probability; the cutoff is the gamma quantile at
     *     {@code 1 - samplingThreshold}
     * @return keep probabilities indexed {@code [label][word]}, each in (0, 1]
     */
    private static double[][] computeKeepProbabilities(int[][] wordLabelCounts, int[] labelCounts, double samplingThreshold) {
        int numWords = wordLabelCounts.length;
        int numLabels = labelCounts.length;
        double[][] labelWordProbs = new double[numLabels][numWords];
        for (int word = 0; word < numWords; ++word) {
            double[] proportions = new double[numLabels];
            double sum = 0.0;
            for (int label = 0; label < numLabels; ++label) {
                // A positive word count implies labelCounts[label] > 0, so this never divides by zero.
                if (wordLabelCounts[word][label] > 0) {
                    proportions[label] = (double) wordLabelCounts[word][label] / (double) labelCounts[label];
                    sum += proportions[label];
                }
            }
            double mean = sum / (double) numLabels;
            // Population variance over ALL labels, zero proportions included exactly once.
            // (The previous closed-form version added the zero labels' contribution twice.)
            double variance = 0.0;
            for (double proportion : proportions) {
                double diff = proportion - mean;
                variance += diff * diff;
            }
            variance /= (double) numLabels;

            if (!(mean > 0.0) || !(variance > 0.0)) {
                // Word absent, or equally frequent in every label: the gamma fit is degenerate
                // and no label can exceed the mean, so keep every token.
                for (int label = 0; label < numLabels; ++label) {
                    labelWordProbs[label][word] = 1.0;
                }
                continue;
            }
            // Method-of-moments gamma fit: mean = shape * scale, variance = shape * scale^2.
            double shape = mean * mean / variance;
            double scale = variance / mean;
            double threshold = StatFunctions.gammaInverseCDF(1.0 - samplingThreshold, shape, scale);
            for (int label = 0; label < numLabels; ++label) {
                labelWordProbs[label][word] = proportions[label] > threshold ? threshold / proportions[label] : 1.0;
            }
        }
        return labelWordProbs;
    }

    /**
     * Writes a UTF-8, tab-separated report with one line per (word, label) pair that had at
     * least one token removed.
     *
     * @throws IOException if the file cannot be opened or any write fails
     */
    private static void writeReport(File file, Alphabet alphabet, LabelAlphabet labelAlphabet, int[][] wordLabelRemovalCounts) throws IOException {
        PrintWriter reportWriter = new PrintWriter(file, "UTF-8");
        try {
            reportWriter.println("Word\tLabel\tCount");
            for (int word = 0; word < wordLabelRemovalCounts.length; ++word) {
                for (int label = 0; label < wordLabelRemovalCounts[word].length; ++label) {
                    if (wordLabelRemovalCounts[word][label] <= 0) continue;
                    reportWriter.format("%s\t%s\t%d\n", alphabet.lookupObject(word), labelAlphabet.lookupObject(label), wordLabelRemovalCounts[word][label]);
                }
            }
            // PrintWriter swallows I/O errors; checkError() flushes and reports them.
            if (reportWriter.checkError()) {
                throw new IOException("Error writing report file " + file);
            }
        } finally {
            reportWriter.close();
        }
    }
}

