/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFOptimizableByLabelLikelihood;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerEvaluator;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.ExpGain;
import cc.mallet.types.FeatureInducer;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.GradientGain;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.LabelVector;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Trains a {@link CRF} by maximizing the label likelihood with L-BFGS
 * ({@link LimitedMemoryBFGS}) over a {@link CRFOptimizableByLabelLikelihood}.
 * <p>
 * Prior settings (Gaussian variance, and the hyperbolic prior flag, slope and
 * sharpness) are copied into the optimizable whenever it is (re)built, so set
 * them before the first call to a {@code train} method. The optimizable is
 * cached and rebuilt only when the training set or the CRF's weight structure
 * changes. The optimizer is rebuilt whenever the optimizable is.
 */
public class CRFTrainerByLabelLikelihood
extends TransducerTrainer
implements TransducerTrainer.ByOptimization {
    private static final Logger logger = MalletLogger.getLogger(CRFTrainerByLabelLikelihood.class.getName());
    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0;
    /** Clustered feature induction skips label pairs with fewer error positions than this. */
    private static final int MIN_CLUSTER_ERROR_INSTANCES = 20;
    CRF crf;
    CRFOptimizableByLabelLikelihood ocrf;
    Optimizer opt;
    int iterationCount = 0;
    /** Result of the most recent {@code train} call that ran at least one iteration. */
    boolean converged;
    boolean usingHyperbolicPrior = false;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
    double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
    boolean useSparseWeights = true;
    boolean useNoWeights = false;
    private transient boolean useSomeUnsupportedTrick = true;
    private int cachedValueWeightsStamp = -1;
    private int cachedGradientWeightsStamp = -1;
    private int cachedWeightsStructureStamp = -1;
    public boolean printGradient = false;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 1;
    static final int NULL_INTEGER = -1;

    /**
     * Creates a trainer for the given CRF.
     *
     * @param crf the model whose weights will be trained
     */
    public CRFTrainerByLabelLikelihood(CRF crf) {
        this.crf = crf;
    }

    @Override
    public Transducer getTransducer() {
        return this.crf;
    }

    /** @return the CRF being trained */
    public CRF getCRF() {
        return this.crf;
    }

    /** @return the current optimizer, or {@code null} if none has been built yet */
    @Override
    public Optimizer getOptimizer() {
        return this.opt;
    }

    /** @return whether the most recent training run reported convergence */
    public boolean isConverged() {
        return this.converged;
    }

    @Override
    public boolean isFinishedTraining() {
        return this.converged;
    }

    /** @return the number of optimizer iterations completed so far */
    @Override
    public int getIteration() {
        return this.iterationCount;
    }

    /**
     * When {@code true}, the trainer does not (re)dimension the CRF's weights
     * when the weight structure changes.
     *
     * @param flag whether to skip weight dimensioning
     */
    public void setAddNoFactors(boolean flag) {
        this.useNoWeights = flag;
    }

    /**
     * Returns the likelihood optimizable for {@code trainingSet}, rebuilding it if needed.
     * <p>
     * If the CRF's weight structure changed since the last call, the weights are
     * re-dimensioned first (sparsely from the training set or densely, depending on
     * {@link #setUseSparseWeights}) unless {@link #setAddNoFactors} disabled that.
     * In that case the cached optimizable is discarded. A new optimizable (with the
     * current prior settings) is built when none is cached or it was built for a
     * different training set. Rebuilding the optimizable also discards the optimizer.
     *
     * @param trainingSet the instances to compute the likelihood over
     * @return the (possibly cached) optimizable
     */
    public CRFOptimizableByLabelLikelihood getOptimizableCRF(InstanceList trainingSet) {
        if (this.cachedWeightsStructureStamp != this.crf.weightsStructureChangeStamp) {
            if (!this.useNoWeights) {
                if (this.useSparseWeights) {
                    // NOTE(review): semantics of the "unsupported trick" flag live in CRF; confirm there.
                    this.crf.setWeightsDimensionAsIn(trainingSet, this.useSomeUnsupportedTrick);
                } else {
                    this.crf.setWeightsDimensionDensely();
                }
            }
            this.ocrf = null;
            this.cachedWeightsStructureStamp = this.crf.weightsStructureChangeStamp;
        }
        if (this.ocrf == null || this.ocrf.trainingSet != trainingSet) {
            this.ocrf = new CRFOptimizableByLabelLikelihood(this.crf, trainingSet);
            this.ocrf.setGaussianPriorVariance(this.gaussianPriorVariance);
            this.ocrf.setHyperbolicPriorSharpness(this.hyperbolicPriorSharpness);
            this.ocrf.setHyperbolicPriorSlope(this.hyperbolicPriorSlope);
            this.ocrf.setUseHyperbolicPrior(this.usingHyperbolicPrior);
            this.opt = null;
        }
        return this.ocrf;
    }

    /**
     * Returns an L-BFGS optimizer bound to the current optimizable for
     * {@code trainingSet}, creating one if none exists or the optimizable changed.
     *
     * @param trainingSet the instances to optimize over
     * @return the (possibly cached) optimizer
     */
    public Optimizer getOptimizer(InstanceList trainingSet) {
        this.getOptimizableCRF(trainingSet);
        if (this.opt == null || this.ocrf != this.opt.getOptimizable()) {
            this.opt = new LimitedMemoryBFGS(this.ocrf);
        }
        return this.opt;
    }

    /**
     * Trains until convergence (bounded only by {@link Integer#MAX_VALUE} iterations).
     *
     * @param training the training instances
     * @return whether training converged
     */
    public boolean trainIncremental(InstanceList training) {
        return this.train(training, Integer.MAX_VALUE);
    }

    /**
     * Runs up to {@code numIterations} single-step optimizer iterations, running the
     * registered evaluators after each one and stopping early on convergence.
     * <p>
     * Any exception thrown by the optimizer or the evaluators is logged and treated
     * as convergence, so training stops instead of propagating the exception.
     * The result is also recorded for {@link #isConverged()} and
     * {@link #isFinishedTraining()}.
     *
     * @param trainingSet the training instances (must be non-empty)
     * @param numIterations the maximum number of iterations; {@code <= 0} does nothing
     * @return whether training converged; {@code false} if {@code numIterations <= 0}
     */
    @Override
    public boolean train(InstanceList trainingSet, int numIterations) {
        if (numIterations <= 0) {
            return false;
        }
        assert (trainingSet.size() > 0);
        // Both calls reuse the cached optimizable/optimizer when still valid.
        this.getOptimizableCRF(trainingSet);
        this.getOptimizer(trainingSet);
        boolean converged = false;
        logger.info("CRF about to train with " + numIterations + " iterations");
        for (int i = 0; i < numIterations; ++i) {
            try {
                converged = this.opt.optimize(1);
                ++this.iterationCount;
                logger.info("CRF finished one iteration of maximizer, i=" + i);
                this.runEvaluators();
            } catch (Exception e) {
                logger.log(Level.WARNING, "Exception during CRF training iteration " + i, e);
                logger.info("Catching exception; saying converged.");
                converged = true;
            }
            if (converged) {
                logger.info("CRF training has converged, i=" + i);
                break;
            }
        }
        // Publish the result; the local shadows the field, which was previously never set.
        this.converged = converged;
        return converged;
    }

    /**
     * Trains in rounds, one per entry of {@code trainingProportions}, each on that
     * fraction of {@code training}. Fractions below 1.0 use a split seeded with 1.
     *
     * @param training the full training set
     * @param numIterationsPerProportion the maximum number of iterations per round
     * @param trainingProportions fractions of the data in (0, 1], one per round
     * @return whether the last round converged
     */
    public boolean train(InstanceList training, int numIterationsPerProportion, double[] trainingProportions) {
        assert (trainingProportions.length > 0);
        boolean converged = false;
        for (double proportion : trainingProportions) {
            assert (proportion <= 1.0);
            logger.info("Training on " + 100.0 * proportion + "% of the data this round.");
            InstanceList roundData = proportion == 1.0
                    ? training
                    : training.split(new Random(1L), new double[]{proportion, 1.0 - proportion})[0];
            converged = this.train(roundData, numIterationsPerProportion);
        }
        return converged;
    }

    /**
     * Same as the full overload with {@code gainName = "exp"}.
     */
    public boolean trainWithFeatureInduction(InstanceList trainingData, InstanceList validationData, InstanceList testingData, TransducerEvaluator eval, int numIterations, int numIterationsBetweenFeatureInductions, int numFeatureInductions, int numFeaturesPerFeatureInduction, double trueLabelProbThreshold, boolean clusteredFeatureInduction, double[] trainingProportions) {
        return this.trainWithFeatureInduction(trainingData, validationData, testingData, eval, numIterations, numIterationsBetweenFeatureInductions, numFeatureInductions, numFeaturesPerFeatureInduction, trueLabelProbThreshold, clusteredFeatureInduction, trainingProportions, "exp");
    }

    /**
     * Alternates training with feature induction, then trains for the remaining iterations.
     * <p>
     * In each of {@code numFeatureInductions} rounds:
     * <ol>
     * <li>train for {@code numIterationsBetweenFeatureInductions} (skipped in round 0);</li>
     * <li>collect token positions whose true-label lattice probability is below
     *     {@code trueLabelProbThreshold};</li>
     * <li>induce new features from those errors, either globally or per
     *     (previous predicted label, predicted label) pair when
     *     {@code clusteredFeatureInduction} is set;</li>
     * <li>add the induced features to {@code trainingData} and {@code testingData}.</li>
     * </ol>
     * NOTE(review): induced features are not applied to {@code validationData}, and
     * {@code eval} is not used here. Confirm whether that is intended.
     *
     * @param trainingProportions optional per-round fraction of training data to use; may be null
     * @param gainName feature-ranking gain: {@code "exp"}, {@code "grad"} or {@code "info"}
     * @return the result of the final training run
     * @throws IllegalArgumentException if {@code gainName} is not recognized (raised
     *         when the first feature inducer is built)
     */
    public boolean trainWithFeatureInduction(InstanceList trainingData, InstanceList validationData, InstanceList testingData, TransducerEvaluator eval, int numIterations, int numIterationsBetweenFeatureInductions, int numFeatureInductions, int numFeaturesPerFeatureInduction, double trueLabelProbThreshold, boolean clusteredFeatureInduction, double[] trainingProportions, String gainName) {
        int trainingIteration = 0;
        int numLabels = this.crf.outputAlphabet.size();
        this.crf.globalFeatureSelection = trainingData.getFeatureSelection();
        if (this.crf.globalFeatureSelection == null) {
            this.crf.globalFeatureSelection = new FeatureSelection(trainingData.getDataAlphabet());
            trainingData.setFeatureSelection(this.crf.globalFeatureSelection);
        }
        if (validationData != null) {
            validationData.setFeatureSelection(this.crf.globalFeatureSelection);
        }
        if (testingData != null) {
            testingData.setFeatureSelection(this.crf.globalFeatureSelection);
        }
        for (int featureInductionIteration = 0; featureInductionIteration < numFeatureInductions; ++featureInductionIteration) {
            logger.info("Feature induction iteration " + featureInductionIteration);
            InstanceList theTrainingData = trainingData;
            if (trainingProportions != null && featureInductionIteration < trainingProportions.length) {
                double proportion = trainingProportions[featureInductionIteration];
                logger.info("Training on " + 100.0 * proportion + "% of the data this round.");
                InstanceList[] sampledTrainingData = trainingData.split(new Random(1L), new double[]{proportion, 1.0 - proportion});
                theTrainingData = sampledTrainingData[0];
                theTrainingData.setFeatureSelection(this.crf.globalFeatureSelection);
                logger.info("  which is " + theTrainingData.size() + " instances");
            }
            // Don't train until some features have been induced.
            if (featureInductionIteration != 0) {
                this.train(theTrainingData, numIterationsBetweenFeatureInductions);
            }
            trainingIteration += numIterationsBetweenFeatureInductions;
            logger.info("Starting feature induction with " + this.crf.inputAlphabet.size() + " features.");

            InstanceList errorInstances = new InstanceList(trainingData.getDataAlphabet(), trainingData.getTargetAlphabet());
            errorInstances.setFeatureSelection(this.crf.globalFeatureSelection);
            ArrayList<LabelVector> errorLabelVectors = new ArrayList<LabelVector>();
            // Errors bucketed by [previous predicted label][predicted label].
            InstanceList[][] clusteredErrorInstances = new InstanceList[numLabels][numLabels];
            @SuppressWarnings("unchecked") // generic array creation; only LabelVectors are stored
            ArrayList<LabelVector>[][] clusteredErrorLabelVectors = new ArrayList[numLabels][numLabels];
            for (int i = 0; i < numLabels; ++i) {
                for (int j = 0; j < numLabels; ++j) {
                    clusteredErrorInstances[i][j] = new InstanceList(trainingData.getDataAlphabet(), trainingData.getTargetAlphabet());
                    clusteredErrorInstances[i][j].setFeatureSelection(this.crf.globalFeatureSelection);
                    clusteredErrorLabelVectors[i][j] = new ArrayList<LabelVector>();
                }
            }

            for (int i = 0; i < theTrainingData.size(); ++i) {
                logger.info("instance=" + i);
                Instance instance = (Instance)theTrainingData.get(i);
                Sequence input = (Sequence)instance.getData();
                Sequence trueOutput = (Sequence)instance.getTarget();
                assert (input.size() == trueOutput.size());
                SumLattice lattice = this.crf.sumLatticeFactory.newSumLattice((Transducer)this.crf, input, (Sequence)null, (Transducer.Incrementor)null, (LabelAlphabet)theTrainingData.getTargetAlphabet());
                int prevLabelIndex = 0;
                for (int j = 0; j < trueOutput.size(); ++j) {
                    Label label = ((LabelSequence)trueOutput).getLabelAtPosition(j);
                    assert (label != null);
                    LabelVector latticeLabeling = lattice.getLabelingAtPosition(j);
                    double trueLabelProb = latticeLabeling.value(label.getIndex());
                    int labelIndex = latticeLabeling.getBestIndex();
                    if (trueLabelProb < trueLabelProbThreshold) {
                        logger.info("Adding error: instance=" + i + " position=" + j + " prtrue=" + trueLabelProb + (label == latticeLabeling.getBestLabel() ? "  " : " *") + " truelabel=" + label + " predlabel=" + latticeLabeling.getBestLabel() + " fv=" + ((FeatureVector)input.get(j)).toString(true));
                        errorInstances.add(input.get(j), label, null, null);
                        errorLabelVectors.add(latticeLabeling);
                        clusteredErrorInstances[prevLabelIndex][labelIndex].add(input.get(j), label, null, null);
                        clusteredErrorLabelVectors[prevLabelIndex][labelIndex].add(latticeLabeling);
                    }
                    prevLabelIndex = labelIndex;
                }
            }
            logger.info("Error instance list size = " + errorInstances.size());

            if (clusteredFeatureInduction) {
                FeatureInducer[][] klfi = new FeatureInducer[numLabels][numLabels];
                for (int i = 0; i < numLabels; ++i) {
                    for (int j = 0; j < numLabels; ++j) {
                        logger.info("Doing feature induction for " + this.crf.outputAlphabet.lookupObject(i) + " -> " + this.crf.outputAlphabet.lookupObject(j) + " with " + clusteredErrorInstances[i][j].size() + " instances");
                        if (clusteredErrorInstances[i][j].size() < MIN_CLUSTER_ERROR_INSTANCES) {
                            logger.info("..skipping because only " + clusteredErrorInstances[i][j].size() + " instances.");
                            continue;
                        }
                        ArrayList<LabelVector> clusterVectors = clusteredErrorLabelVectors[i][j];
                        LabelVector[] lvs = clusterVectors.toArray(new LabelVector[clusterVectors.size()]);
                        klfi[i][j] = new FeatureInducer(this.newGainFactory(gainName, lvs), clusteredErrorInstances[i][j], numFeaturesPerFeatureInduction, 2 * numFeaturesPerFeatureInduction, 2 * numFeaturesPerFeatureInduction);
                        this.crf.featureInducers.add(klfi[i][j]);
                    }
                }
                for (int i = 0; i < numLabels; ++i) {
                    for (int j = 0; j < numLabels; ++j) {
                        logger.info("Adding new induced features for " + this.crf.outputAlphabet.lookupObject(i) + " -> " + this.crf.outputAlphabet.lookupObject(j));
                        if (klfi[i][j] == null) {
                            logger.info("...skipping because no features induced.");
                            continue;
                        }
                        klfi[i][j].induceFeaturesFor(trainingData, false, false);
                        if (testingData != null) {
                            klfi[i][j].induceFeaturesFor(testingData, false, false);
                        }
                    }
                }
            } else {
                LabelVector[] lvs = errorLabelVectors.toArray(new LabelVector[errorLabelVectors.size()]);
                FeatureInducer klfi = new FeatureInducer(this.newGainFactory(gainName, lvs), errorInstances, numFeaturesPerFeatureInduction, 2 * numFeaturesPerFeatureInduction, 2 * numFeaturesPerFeatureInduction);
                this.crf.featureInducers.add(klfi);
                klfi.induceFeaturesFor(trainingData, false, false);
                if (testingData != null) {
                    klfi.induceFeaturesFor(testingData, false, false);
                }
                logger.info("CRF4 FeatureSelection now includes " + this.crf.globalFeatureSelection.cardinality() + " features");
            }
        }
        return this.train(trainingData, numIterations - trainingIteration);
    }

    /**
     * Builds the feature-ranking factory named by {@code gainName}.
     *
     * @param gainName {@code "exp"}, {@code "grad"} or {@code "info"}
     * @param lvs lattice label distributions at the error positions (used by "exp" and "grad")
     * @return the matching factory
     * @throws IllegalArgumentException if {@code gainName} is null or not recognized
     */
    private RankedFeatureVector.Factory newGainFactory(String gainName, LabelVector[] lvs) {
        if ("exp".equals(gainName)) {
            return new ExpGain.Factory(lvs, this.gaussianPriorVariance);
        }
        if ("grad".equals(gainName)) {
            return new GradientGain.Factory(lvs);
        }
        if ("info".equals(gainName)) {
            return new InfoGain.Factory();
        }
        throw new IllegalArgumentException("Unknown gain name \"" + gainName + "\"; expected \"exp\", \"grad\" or \"info\"");
    }

    /** @param f whether the optimizable should use the hyperbolic prior */
    public void setUseHyperbolicPrior(boolean f) {
        this.usingHyperbolicPrior = f;
    }

    /** @param p hyperbolic prior slope passed to the optimizable */
    public void setHyperbolicPriorSlope(double p) {
        this.hyperbolicPriorSlope = p;
    }

    /** @param p hyperbolic prior sharpness passed to the optimizable */
    public void setHyperbolicPriorSharpness(double p) {
        this.hyperbolicPriorSharpness = p;
    }

    /** @return the hyperbolic prior slope */
    public double getUseHyperbolicPriorSlope() {
        return this.hyperbolicPriorSlope;
    }

    /** @return the hyperbolic prior sharpness */
    public double getUseHyperbolicPriorSharpness() {
        return this.hyperbolicPriorSharpness;
    }

    /** @param p Gaussian prior variance, also used by the "exp" feature-induction gain */
    public void setGaussianPriorVariance(double p) {
        this.gaussianPriorVariance = p;
    }

    /** @return the Gaussian prior variance */
    public double getGaussianPriorVariance() {
        return this.gaussianPriorVariance;
    }

    /** @param b whether to dimension weights sparsely from the training data (otherwise densely) */
    public void setUseSparseWeights(boolean b) {
        this.useSparseWeights = b;
    }

    /** @return whether weights are dimensioned sparsely from the training data */
    public boolean getUseSparseWeights() {
        return this.useSparseWeights;
    }

    /** @param b flag forwarded to {@code CRF.setWeightsDimensionAsIn} when using sparse weights */
    public void setUseSomeUnsupportedTrick(boolean b) {
        this.useSomeUnsupportedTrick = b;
    }

    /**
     * Serialization hook. It writes the settings and cache stamps, then deliberately
     * throws, because serializing this trainer is not supported yet.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeBoolean(this.usingHyperbolicPrior);
        out.writeDouble(this.gaussianPriorVariance);
        out.writeDouble(this.hyperbolicPriorSlope);
        out.writeDouble(this.hyperbolicPriorSharpness);
        out.writeInt(this.cachedGradientWeightsStamp);
        out.writeInt(this.cachedValueWeightsStamp);
        out.writeInt(this.cachedWeightsStructureStamp);
        out.writeBoolean(this.printGradient);
        out.writeBoolean(this.useSparseWeights);
        throw new IllegalStateException("Implementation not yet complete.");
    }

    /**
     * Deserialization hook. It reads fields in exactly the order {@link #writeObject}
     * writes them, then deliberately throws, because deserialization is not supported yet.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.readInt(); // serial version; only CURRENT_SERIAL_VERSION exists
        this.usingHyperbolicPrior = in.readBoolean();
        this.gaussianPriorVariance = in.readDouble();
        this.hyperbolicPriorSlope = in.readDouble();
        this.hyperbolicPriorSharpness = in.readDouble();
        // These three ints must be consumed, or every later field is read from the wrong bytes.
        this.cachedGradientWeightsStamp = in.readInt();
        this.cachedValueWeightsStamp = in.readInt();
        this.cachedWeightsStructureStamp = in.readInt();
        this.printGradient = in.readBoolean();
        this.useSparseWeights = in.readBoolean();
        throw new IllegalStateException("Implementation not yet complete.");
    }
}

