/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.wayang.tensorflow.model;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import org.apache.wayang.basic.model.DLModel;
import org.apache.wayang.basic.model.op.*;
import org.apache.wayang.basic.model.op.nn.CrossEntropyLoss;
import org.apache.wayang.basic.model.op.nn.Linear;
import org.apache.wayang.basic.model.op.nn.Sigmoid;
import org.apache.wayang.basic.model.optimizer.GradientDescent;
import org.apache.wayang.basic.model.optimizer.Optimizer;
import org.junit.jupiter.api.Test;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.IntNdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;

/**
 * Smoke test for {@link TensorflowModel}: trains a small Linear -> Sigmoid -> Linear classifier
 * on six hand-written four-feature samples labelled with three classes, then predicts on the
 * same samples and checks the shape and range of the arg-max labels.
 */
class TensorflowModelTest {

    /**
     * Builds the network, a cross-entropy criterion, a gradient-descent optimizer and a
     * mean-of-matches accuracy op, trains for 100 epochs with batch size 6, and verifies that
     * the prediction yields one in-range class label per input row.
     */
    @Test
    void test() {
        // Shared dimensions: kept in one place so the data, the layers and the loss agree.
        final int numFeatures = 4;
        final int hiddenUnits = 64;
        final int numClasses = 3;
        final int epochs = 100;
        final int batchSize = 6;

        FloatNdArray x = NdArrays.ofFloats(Shape.of(6, numFeatures))
                .set(NdArrays.vectorOf(5.1f, 3.5f, 1.4f, 0.2f), 0)
                .set(NdArrays.vectorOf(4.9f, 3.0f, 1.4f, 0.2f), 1)
                .set(NdArrays.vectorOf(6.9f, 3.1f, 4.9f, 1.5f), 2)
                .set(NdArrays.vectorOf(5.5f, 2.3f, 4.0f, 1.3f), 3)
                .set(NdArrays.vectorOf(5.8f, 2.7f, 5.1f, 1.9f), 4)
                .set(NdArrays.vectorOf(6.7f, 3.3f, 5.7f, 2.5f), 5)
                ;
        IntNdArray y = NdArrays.vectorOf(0, 0, 1, 1, 2, 2);

        Input features = new Input(null, Input.Type.FEATURES, Op.DType.FLOAT32);
        Input labels = new Input(null, Input.Type.LABEL, Op.DType.INT32);

        DLModel model = new DLModel.Builder()
                .layer(features)
                .layer(new Linear(numFeatures, hiddenUnits, true))
                .layer(new Sigmoid())
                .layer(new Linear(hiddenUnits, numClasses, true))
                .build();

        Op criterion = new CrossEntropyLoss(numClasses);
        criterion.with(model.getOut(), labels);

        // Accuracy = mean over the batch of (argmax(logits, axis 1) == label), cast to float.
        Op acc = new Mean(0);
        acc.with(new Cast(Op.DType.FLOAT32).with(new Eq().with(
                new ArgMax(1).with(model.getOut()),
                labels
        )));

        Optimizer optimizer = new GradientDescent(0.02f);

        try (TensorflowModel tfModel = new TensorflowModel(model, criterion, optimizer, acc)) {
            System.out.println(tfModel.getOut().getName());
            tfModel.train(x, y, epochs, batchSize);

            // TFloat32 is AutoCloseable and backed by native memory; close it once read.
            // NOTE(review): assumes predict() hands ownership of the tensor to the caller.
            try (TFloat32 predicted = tfModel.predict(x)) {
                Ops tf = Ops.create();
                org.tensorflow.op.math.ArgMax<TInt32> argMax =
                        tf.math.argMax(tf.constantOf(predicted), tf.constant(1), TInt32.class);
                final TInt32 predictedLabels = argMax.asTensor();

                final long rows = predictedLabels.shape().size(0);
                assertEquals(x.shape().size(0), rows, "expected one predicted label per input row");

                System.out.print("[ ");
                for (int i = 0; i < rows; i++) {
                    final int label = predictedLabels.getInt(i);
                    assertTrue(label >= 0 && label < numClasses,
                            "predicted label out of range [0, " + numClasses + "): " + label);
                    System.out.print(label + " ");
                }
                System.out.println("]");
            }
        }
        System.out.println();
    }
}
