/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.dependency.nnparser;

import com.hankcs.hanlp.dependency.nnparser.Dependency;
import com.hankcs.hanlp.dependency.nnparser.action.Action;
import com.hankcs.hanlp.dependency.nnparser.action.ActionFactory;
import com.hankcs.hanlp.dependency.nnparser.util.std;
import java.util.ArrayList;
import java.util.List;

/**
 * Configuration of a stack-based (shift / left-arc / right-arc) transition parser over one
 * sentence.
 *
 * <p>A state holds a stack of token indices, a pointer to the next unread token in the buffer,
 * and the partial dependency tree built so far (heads, relation ids and per-head child
 * bookkeeping). New states are derived from an existing one with {@link #shift},
 * {@link #left_arc} and {@link #right_arc}. Each derived state records the action that produced
 * it and a link to the state it was derived from.
 */
public class State {
    /** Sentinel cost marking an unreachable cell in the table built by {@link #cost}. */
    private static final int UNREACHABLE = 1024;

    /** Token indices on the stack; the last element is the top. */
    List<Integer> stack;
    /** Index of the next token to shift; the buffer is empty once it equals {@code ref.size()}. */
    int buffer;
    /** State this one was derived from by a transition, or {@code null}. */
    State previous;
    /** Sentence being parsed; this class only reads its {@code size()}. */
    Dependency ref;
    /** Score slot; reset and copied here but never computed by this class. */
    double score;
    /** Action that produced this state from {@link #previous}. */
    Action last_action;
    /** Top element of the stack, or -1 when the stack is empty. */
    int top0;
    /** Element just below the top of the stack, or -1 when fewer than two are stacked. */
    int top1;
    /** Head index assigned to each token, or -1 while the token is unattached. */
    List<Integer> heads;
    /** Relation id of each token's incoming arc (0 until an arc is built). */
    List<Integer> deprels;
    /** Number of left dependents attached to each token. */
    List<Integer> nr_left_children;
    /** Number of right dependents attached to each token. */
    List<Integer> nr_right_children;
    /** Left-most dependent of each token, or -1. */
    List<Integer> left_most_child;
    /** Right-most dependent of each token, or -1. */
    List<Integer> right_most_child;
    /** Second left-most dependent of each token, or -1. */
    List<Integer> left_2nd_most_child;
    /** Second right-most dependent of each token, or -1. */
    List<Integer> right_2nd_most_child;

    /** Creates an empty shell whose fields are expected to be filled later, e.g. via {@link #copy}. */
    public State() {
    }

    /**
     * Creates the initial state for {@code ref}: empty stack, buffer at token 0, no arcs.
     *
     * @param ref sentence to parse; its {@code size()} fixes the length of every per-token list
     */
    public State(Dependency ref) {
        this.ref = ref;
        int L = ref.size();
        this.stack = new ArrayList<Integer>(L);
        // Allocate every per-token list before clear() touches it.
        this.heads = std.create(L, -1);
        this.deprels = std.create(L, 0);
        this.nr_left_children = std.create(L, 0);
        this.nr_right_children = std.create(L, 0);
        this.left_most_child = std.create(L, -1);
        this.right_most_child = std.create(L, -1);
        this.left_2nd_most_child = std.create(L, -1);
        this.right_2nd_most_child = std.create(L, -1);
        this.clear();
    }

    /**
     * Resets this state to the initial configuration.
     *
     * <p>Zeroes the score, drops the predecessor link, empties the stack and moves the buffer back
     * to token 0. Every per-token list is refilled with its default value; lists that were never
     * allocated ({@code null}) are skipped.
     */
    void clear() {
        this.score = 0.0;
        this.previous = null;
        this.top0 = -1;
        this.top1 = -1;
        this.buffer = 0;
        this.stack.clear();
        fill_if_present(this.heads, -1);
        fill_if_present(this.deprels, 0);
        fill_if_present(this.nr_left_children, 0);
        fill_if_present(this.nr_right_children, 0);
        fill_if_present(this.left_most_child, -1);
        fill_if_present(this.right_most_child, -1);
        fill_if_present(this.left_2nd_most_child, -1);
        fill_if_present(this.right_2nd_most_child, -1);
    }

    /** Sets every element of {@code list} to {@code value}; does nothing if {@code list} is null. */
    private static void fill_if_present(List<Integer> list, int value) {
        if (list == null) {
            return;
        }
        for (int i = 0; i < list.size(); ++i) {
            list.set(i, value);
        }
    }

    /** @return true when at least one token is left in the buffer */
    boolean can_shift() {
        return !this.buffer_empty();
    }

    /** @return true when the stack holds at least two tokens */
    boolean can_left_arc() {
        return this.stack_size() >= 2;
    }

    /** @return true when the stack holds at least two tokens */
    boolean can_right_arc() {
        return this.stack_size() >= 2;
    }

    /**
     * Makes this state an independent copy of {@code source}.
     *
     * <p>Scalar fields and references ({@code ref}, {@code previous}, {@code last_action}) are
     * copied as-is. Every list is duplicated, so a transition applied to this state does not
     * mutate {@code source}. Sharing the lists would leave {@code source}'s stack out of sync with
     * its {@code top0}/{@code top1} and corrupt the state later stored in {@link #previous}.
     *
     * @param source state to copy from
     */
    void copy(State source) {
        this.ref = source.ref;
        this.score = source.score;
        this.previous = source.previous;
        this.buffer = source.buffer;
        this.top0 = source.top0;
        this.top1 = source.top1;
        this.last_action = source.last_action;
        this.stack = copy_of(source.stack);
        this.heads = copy_of(source.heads);
        this.deprels = copy_of(source.deprels);
        this.left_most_child = copy_of(source.left_most_child);
        this.right_most_child = copy_of(source.right_most_child);
        this.left_2nd_most_child = copy_of(source.left_2nd_most_child);
        this.right_2nd_most_child = copy_of(source.right_2nd_most_child);
        this.nr_left_children = copy_of(source.nr_left_children);
        this.nr_right_children = copy_of(source.nr_right_children);
    }

    /** Returns a fresh mutable copy of {@code list}, or null if {@code list} is null. */
    private static List<Integer> copy_of(List<Integer> list) {
        return list == null ? null : new ArrayList<Integer>(list);
    }

    /** Re-reads {@link #top0} and {@link #top1} from the stack (-1 for missing positions). */
    void refresh_stack_information() {
        int sz = this.stack.size();
        this.top0 = sz >= 1 ? this.stack.get(sz - 1) : -1;
        this.top1 = sz >= 2 ? this.stack.get(sz - 2) : -1;
    }

    /**
     * Turns this state into the result of shifting from {@code source}.
     *
     * <p>The next buffer token is pushed onto the stack and the buffer advances by one.
     *
     * @param source state to transition from
     * @return false (leaving this state untouched) when {@code source}'s buffer is empty
     */
    boolean shift(State source) {
        if (!source.can_shift()) {
            return false;
        }
        this.copy(source);
        this.stack.add(this.buffer);
        this.refresh_stack_information();
        ++this.buffer;
        this.last_action = ActionFactory.make_shift();
        this.previous = source;
        return true;
    }

    /**
     * Turns this state into the result of a left-arc from {@code source}.
     *
     * <p>The stack top ({@code top0}) becomes the head of the element below it ({@code top1}).
     * {@code top1} is removed from the stack and {@code top0}'s left-child bookkeeping is updated.
     *
     * @param source state to transition from
     * @param deprel relation id to put on the new arc
     * @return false (leaving this state untouched) when {@code source} has fewer than two stacked tokens
     */
    boolean left_arc(State source, int deprel) {
        if (!source.can_left_arc()) {
            return false;
        }
        this.copy(source);
        // top0/top1 still describe source's stack here: drop the old top, then overwrite
        // top1's slot with top0 so that top0 stays on top.
        this.stack.remove(this.stack.size() - 1);
        this.stack.set(this.stack.size() - 1, this.top0);
        this.heads.set(this.top1, this.top0);
        this.deprels.set(this.top1, deprel);
        if (-1 == this.left_most_child.get(this.top0)) {
            this.left_most_child.set(this.top0, this.top1);
        } else if (this.top1 < this.left_most_child.get(this.top0)) {
            this.left_2nd_most_child.set(this.top0, this.left_most_child.get(this.top0));
            this.left_most_child.set(this.top0, this.top1);
        } else if (this.top1 < this.left_2nd_most_child.get(this.top0)) {
            this.left_2nd_most_child.set(this.top0, this.top1);
        }
        this.nr_left_children.set(this.top0, this.nr_left_children.get(this.top0) + 1);
        this.refresh_stack_information();
        this.last_action = ActionFactory.make_left_arc(deprel);
        this.previous = source;
        return true;
    }

    /**
     * Turns this state into the result of a right-arc from {@code source}.
     *
     * <p>The element below the top ({@code top1}) becomes the head of the stack top ({@code top0}).
     * {@code top0} is popped and {@code top1}'s right-child bookkeeping is updated.
     *
     * @param source state to transition from
     * @param deprel relation id to put on the new arc
     * @return false (leaving this state untouched) when {@code source} has fewer than two stacked tokens
     */
    boolean right_arc(State source, int deprel) {
        if (!source.can_right_arc()) {
            return false;
        }
        this.copy(source);
        this.stack.remove(this.stack.size() - 1);
        this.heads.set(this.top0, this.top1);
        this.deprels.set(this.top0, deprel);
        if (-1 == this.right_most_child.get(this.top1)) {
            this.right_most_child.set(this.top1, this.top0);
        } else if (this.right_most_child.get(this.top1) < this.top0) {
            this.right_2nd_most_child.set(this.top1, this.right_most_child.get(this.top1));
            this.right_most_child.set(this.top1, this.top0);
        } else if (this.right_2nd_most_child.get(this.top1) < this.top0) {
            this.right_2nd_most_child.set(this.top1, this.top0);
        }
        this.nr_right_children.set(this.top1, this.nr_right_children.get(this.top1) + 1);
        this.refresh_stack_information();
        this.last_action = ActionFactory.make_right_arc(deprel);
        this.previous = source;
        return true;
    }

    /**
     * Computes the cost of this state against a gold tree.
     *
     * <p>The result has two parts:
     * <ul>
     *   <li>A penalty for arcs already built on tokens left of the buffer: 2 per wrong head, and 1
     *       per correct head with a wrong relation id.</li>
     *   <li>The final entry of a dynamic-programming table. The table spans the stack
     *       ({@code sigma_l}) and the buffer tokens that still need the stack ({@code sigma_r}),
     *       charging 2 per arc that disagrees with {@code gold_heads}. NOTE(review): this reads as
     *       a dynamic-oracle loss; confirm against the reference implementation.</li>
     * </ul>
     *
     * <p>Precondition: the stack is non-empty (its top seeds {@code sigma_r}).
     *
     * @param gold_heads   gold head index per token, negative for tokens without a head
     * @param gold_deprels gold relation id per token
     * @return the DP cost plus the penalty for arcs already built
     */
    int cost(List<Integer> gold_heads, List<Integer> gold_deprels) {
        int n = gold_heads.size();

        // Gold children of every node. Each slot must exist before children are appended:
        // new ArrayList(n) only reserves capacity, so get(h) on it would throw.
        List<List<Integer>> tree = new ArrayList<List<Integer>>(n);
        for (int i = 0; i < n; ++i) {
            tree.add(new ArrayList<Integer>());
        }
        for (int i = 0; i < n; ++i) {
            int h = gold_heads.get(i);
            if (h >= 0) {
                tree.get(h).add(i);
            }
        }

        List<Integer> sigma_l = this.stack;
        List<Integer> sigma_r = new ArrayList<Integer>();
        sigma_r.add(this.stack.get(this.stack.size() - 1));
        boolean[] sigma_l_mask = new boolean[n];
        boolean[] sigma_r_mask = new boolean[n];
        for (int s : sigma_l) {
            sigma_l_mask[s] = true;
        }
        // A buffer token joins sigma_r when its gold head lies before the buffer, or when one of
        // its gold children is already in sigma_l / sigma_r.
        for (int i = this.buffer; i < this.ref.size(); ++i) {
            if (gold_heads.get(i) < this.buffer
                    || has_marked_child(tree.get(i), sigma_l_mask, sigma_r_mask)) {
                sigma_r.add(i);
                sigma_r_mask[i] = true;
            }
        }

        int len_l = sigma_l.size();
        int len_r = sigma_r.size();
        int[][][] T = new int[len_l][len_r][len_l + len_r - 1];
        for (int[][] plane : T) {
            for (int[] row : plane) {
                for (int k = 0; k < row.length; ++k) {
                    row[k] = UNREACHABLE;
                }
            }
        }
        T[0][0][len_l - 1] = 0;
        // Fill the table anti-diagonal by anti-diagonal (d = i + j), extending the span either
        // one step left into sigma_l (i + 1) or one step right into sigma_r (j + 1).
        for (int d = 0; d < len_l + len_r - 1; ++d) {
            for (int j = Math.max(0, d - len_l + 1); j < Math.min(d + 1, len_r); ++j) {
                int i = d - j;
                if (i < len_l - 1) {
                    int i_1_rank = len_l - i - 2;
                    relax_span(T[i][j], T[i + 1][j], sigma_l.get(i_1_rank), i_1_rank,
                            sigma_l, sigma_r, i, j, gold_heads);
                }
                if (j < len_r - 1) {
                    relax_span(T[i][j], T[i][j + 1], sigma_r.get(j + 1), len_l + j,
                            sigma_l, sigma_r, i, j, gold_heads);
                }
            }
        }

        int penalty = 0;
        for (int i = 0; i < this.buffer; ++i) {
            int head = this.heads.get(i);
            if (head == -1) {
                continue;
            }
            // Compare unboxed ints: == / != on two Integer objects tests identity, not value.
            if (head != gold_heads.get(i).intValue()) {
                penalty += 2;
            } else if (this.deprels.get(i).intValue() != gold_deprels.get(i).intValue()) {
                penalty += 1;
            }
        }
        return T[len_l - 1][len_r - 1][0] + penalty;
    }

    /** @return true when any index in {@code children} is set in either mask */
    private static boolean has_marked_child(List<Integer> children, boolean[] mask_a, boolean[] mask_b) {
        for (int child : children) {
            if (mask_a[child] || mask_b[child]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Relaxes the cells of {@code to} from {@code from} after adding {@code node} (at
     * {@code node_rank}) to the current span.
     *
     * <p>Each candidate head {@code h} in the span is tried: sigma_l ranks
     * {@code [len_l - i - 1, len_l)} and sigma_r ranks {@code [1, j]}.
     */
    private static void relax_span(int[] from, int[] to, int node, int node_rank,
                                   List<Integer> sigma_l, List<Integer> sigma_r,
                                   int i, int j, List<Integer> gold_heads) {
        int len_l = sigma_l.size();
        for (int rank = len_l - i - 1; rank < len_l; ++rank) {
            relax(from, to, sigma_l.get(rank), rank, node, node_rank, gold_heads);
        }
        for (int rank = 1; rank <= j; ++rank) {
            relax(from, to, sigma_r.get(rank), len_l + rank - 1, node, node_rank, gold_heads);
        }
    }

    /**
     * Tries both arc directions between span head {@code h} and the new {@code node}.
     *
     * <p>Keeps the cheaper cost in {@code to}; an arc that disagrees with the gold tree costs 2.
     */
    private static void relax(int[] from, int[] to, int h, int h_rank, int node, int node_rank,
                              List<Integer> gold_heads) {
        int base = from[h_rank];
        to[h_rank] = Math.min(to[h_rank], base + (gold_heads.get(node).intValue() == h ? 0 : 2));
        to[node_rank] = Math.min(to[node_rank], base + (gold_heads.get(h).intValue() == node ? 0 : 2));
    }

    /** @return true when every token of {@code ref} has been shifted */
    boolean buffer_empty() {
        return this.buffer == this.ref.size();
    }

    /** @return number of tokens currently on the stack */
    int stack_size() {
        return this.stack.size();
    }
}

