/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.r;

import org.apache.spark.SparkException;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.UnaryTransformer;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.r.LDAWrapper;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReadable;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import scala.Array$;
import scala.Predef$;
import scala.collection.ArrayOps$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/**
 * Singleton companion ({@code MODULE$}) for {@link LDAWrapper}.
 *
 * <p>Builds and fits an LDA {@link Pipeline} over either a string column or a vector column.
 * A string column is first tokenized, stop-word filtered and count-vectorized. This class also
 * exposes the {@link MLReadable} entry points for loading a saved {@link LDAWrapper}. It lives in
 * {@code org.apache.spark.ml.r}, so it is presumably driven from the SparkR bindings.
 *
 * <p>NOTE(review): this file was produced by a decompiler (see header) from what appears to be
 * Scala. The {@code $}-suffixed names and the static {@code MLReadable.$init$} / {@code load$}
 * calls are artifacts of that encoding and are kept as-is.
 */
public final class LDAWrapper$
implements MLReadable<LDAWrapper> {
    /** The single instance of this companion. */
    public static final LDAWrapper$ MODULE$ = new LDAWrapper$();
    // Intermediate column names are random UIDs, presumably so they do not clash with
    // user-supplied column names. They are generated once, when this class is initialized.
    private static final String TOKENIZER_COL;
    private static final String STOPWORDS_REMOVER_COL;
    private static final String COUNT_VECTOR_COL;

    static {
        MLReadable.$init$(MODULE$);
        TOKENIZER_COL = String.valueOf(Identifiable$.MODULE$.randomUID("rawTokens"));
        STOPWORDS_REMOVER_COL = String.valueOf(Identifiable$.MODULE$.randomUID("tokens"));
        COUNT_VECTOR_COL = String.valueOf(Identifiable$.MODULE$.randomUID("features"));
    }

    /** Returns the name of the column written by the regex tokenizer stage. */
    public String TOKENIZER_COL() {
        return TOKENIZER_COL;
    }

    /** Returns the name of the column written by the stop-words remover stage. */
    public String STOPWORDS_REMOVER_COL() {
        return STOPWORDS_REMOVER_COL;
    }

    /** Returns the name of the term-count vector column consumed by LDA for string input. */
    public String COUNT_VECTOR_COL() {
        return COUNT_VECTOR_COL;
    }

    /**
     * Builds the text pre-processing stages applied to a string features column.
     *
     * @param features name of the input string column
     * @param customizedStopWords extra stop words, appended after the remover's current stop words
     * @param maxVocabSize value passed to {@link CountVectorizer#setVocabSize}
     * @return three stages, in order: {@link RegexTokenizer}, {@link StopWordsRemover} and
     *     {@link CountVectorizer}. The last one writes {@link #COUNT_VECTOR_COL()}.
     */
    private PipelineStage[] getPreStages(String features, String[] customizedStopWords, int maxVocabSize) {
        // The tokenizer's column setters are inherited from UnaryTransformer, so they are called
        // as separate statements instead of being chained behind casts.
        RegexTokenizer tokenizer = new RegexTokenizer();
        tokenizer.setInputCol(features);
        tokenizer.setOutputCol(this.TOKENIZER_COL());

        StopWordsRemover stopWordsRemover = new StopWordsRemover()
                .setInputCol(this.TOKENIZER_COL())
                .setOutputCol(this.STOPWORDS_REMOVER_COL());
        // Keep the remover's existing stop words and append the caller's custom ones after them.
        String[] baseStopWords = stopWordsRemover.getStopWords();
        String[] allStopWords = new String[baseStopWords.length + customizedStopWords.length];
        System.arraycopy(baseStopWords, 0, allStopWords, 0, baseStopWords.length);
        System.arraycopy(customizedStopWords, 0, allStopWords, baseStopWords.length, customizedStopWords.length);
        stopWordsRemover.setStopWords(allStopWords);

        CountVectorizer countVectorizer = new CountVectorizer()
                .setVocabSize(maxVocabSize)
                .setInputCol(this.STOPWORDS_REMOVER_COL())
                .setOutputCol(this.COUNT_VECTOR_COL());
        return new PipelineStage[]{tokenizer, stopWordsRemover, countVectorizer};
    }

    /**
     * Fits an LDA model, adding text pre-processing when the features column holds strings.
     *
     * <p>A string column is run through the stages from {@link #getPreStages} before LDA. A
     * vector column is fed to LDA directly. The sentinel value {@code -1} leaves the
     * corresponding concentration parameter unset on the {@link LDA} estimator.
     *
     * @param data training data
     * @param features name of the features column (string or vector type)
     * @param k value passed to {@link LDA#setK}
     * @param maxIter value passed to {@link LDA#setMaxIter}
     * @param optimizer value passed to {@link LDA#setOptimizer}
     * @param subsamplingRate value passed to {@link LDA#setSubsamplingRate}
     * @param topicConcentration topic concentration, or {@code -1} to leave it unset
     * @param docConcentration either a single value ({@code -1} leaves it unset) or an array of
     *     any other length, which is passed to {@code setDocConcentration(double[])} as-is
     * @param customizedStopWords extra stop words (only used for string input)
     * @param maxVocabSize vocabulary size limit (only used for string input)
     * @return a wrapper holding:
     *     <ul>
     *       <li>the fitted pipeline;</li>
     *       <li>the LDA log-likelihood and log-perplexity, computed on {@code data} after the
     *           pre-processing stages;</li>
     *       <li>the vocabulary, which is empty for vector input.</li>
     *     </ul>
     * @throws SparkException if the features column is neither a string nor a vector column.
     *     NOTE(review): SparkException is a checked exception in Java but is not declared here,
     *     presumably because the original source is Scala. Declaring it would change this
     *     signature for Java callers.
     */
    public LDAWrapper fit(Dataset<Row> data, String features, int k, int maxIter, String optimizer, double subsamplingRate, double topicConcentration, double[] docConcentration, String[] customizedStopWords, int maxVocabSize) {
        LDA lda = new LDA().setK(k).setMaxIter(maxIter).setSubsamplingRate(subsamplingRate).setOptimizer(optimizer);

        StructField featureSchema = data.schema().apply(features);
        DataType featureType = featureSchema.dataType();
        boolean textInput = featureType instanceof StringType;

        // Declared as PipelineStage[] (not LDA[]) so the array really has that component type.
        PipelineStage[] stages;
        if (textInput) {
            PipelineStage[] preStages = this.getPreStages(features, customizedStopWords, maxVocabSize);
            stages = new PipelineStage[preStages.length + 1];
            System.arraycopy(preStages, 0, stages, 0, preStages.length);
            stages[preStages.length] = lda.setFeaturesCol(this.COUNT_VECTOR_COL());
        } else if (featureType instanceof VectorUDT) {
            stages = new PipelineStage[]{lda.setFeaturesCol(features)};
        } else {
            throw new SparkException("Unsupported input features type of " + featureType.typeName() + "," + " only String type and Vector type are supported now.");
        }

        // -1 is a sentinel meaning "leave the estimator's setting untouched". The estimator is
        // mutated in place, and it is the same instance already placed in `stages`.
        if (topicConcentration != -1.0) {
            lda.setTopicConcentration(topicConcentration);
        }
        if (docConcentration.length == 1) {
            if (docConcentration[0] != -1.0) {
                lda.setDocConcentration(docConcentration[0]);
            }
        } else {
            lda.setDocConcentration(docConcentration);
        }

        Pipeline pipeline = new Pipeline().setStages(stages);
        PipelineModel model = pipeline.fit(data);
        Transformer[] fittedStages = model.stages();

        // For string input the CountVectorizerModel is the last pre-processing stage, right
        // before the LDA model.
        String[] vocabulary = textInput
                ? ((CountVectorizerModel) fittedStages[fittedStages.length - 2]).vocabulary()
                : new String[0];

        LDAModel ldaModel = (LDAModel) fittedStages[fittedStages.length - 1];

        // Re-apply every fitted stage except the LDA model, so the metrics below are computed
        // on the same features the model was trained on.
        Transformer[] preprocessingStages = new Transformer[fittedStages.length - 1];
        System.arraycopy(fittedStages, 0, preprocessingStages, 0, preprocessingStages.length);
        PipelineModel preprocessor = new PipelineModel(Identifiable$.MODULE$.randomUID(pipeline.uid()), preprocessingStages);
        Dataset<Row> preprocessedData = preprocessor.transform(data);

        return new LDAWrapper(model, ldaModel.logLikelihood(preprocessedData), ldaModel.logPerplexity(preprocessedData), vocabulary);
    }

    /** Returns a reader for a saved {@link LDAWrapper}. */
    @Override
    public MLReader<LDAWrapper> read() {
        return new LDAWrapper.LDAWrapperReader();
    }

    /** Loads a saved {@link LDAWrapper} from {@code path} via the MLReadable default implementation. */
    @Override
    public LDAWrapper load(String path) {
        return (LDAWrapper)MLReadable.load$(this, path);
    }

    private LDAWrapper$() {
    }
}

