/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.parse;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.antlr.runtime.TokenRewriteStream;
import org.apache.calcite.sql.SqlKind;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.ErrorMsg;
import org.apache.hadoop.hive.ql.QueryState;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.metadata.HiveUtils;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.parse.ASTErrorUtils;
import org.apache.hadoop.hive.ql.parse.ASTNode;
import org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer;
import org.apache.hadoop.hive.ql.parse.RewriteSemanticAnalyzer;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.UnparseTranslator;
import org.apache.hadoop.hive.ql.parse.rewrite.MergeStatement;
import org.apache.hadoop.hive.ql.parse.rewrite.RewriterFactory;

/**
 * Semantic analyzer for {@code MERGE} statements.
 *
 * <p>Reads the target, source, ON clause, optional query hint and WHEN clauses out of the MERGE
 * parse tree, collects them into a {@link MergeStatement}, validates the WHEN-clause count
 * constraints, and then passes the statement to {@code rewriteAndAnalyze} and
 * {@code updateOutputs} (both inherited).
 */
public class MergeSemanticAnalyzer
extends RewriteSemanticAnalyzer<MergeStatement> {
    /*
     * Parse-tree token type ids used by this analyzer. The decompiled source used these as bare
     * integer literals; each is named after the HiveParser constant it is presumed to correspond
     * to (TOK_MATCHED/TOK_NOT_MATCHED and TOK_TABREF/TOK_TABNAME are corroborated by the error
     * messages in this class). The numeric values are unchanged.
     * NOTE(review): prefer referencing HiveParser.* directly so that a regenerated grammar cannot
     * silently desynchronize these values.
     */
    private static final int DOT = 16;
    private static final int IDENTIFIER = 24;
    private static final int QUERY_HINT = 427;
    private static final int TOK_DEFAULT_VALUE = 986;
    private static final int TOK_DELETE = 987;
    private static final int TOK_FUNCTION = 1035;
    private static final int TOK_INSERT = 1055;
    private static final int TOK_MATCHED = 1090;
    private static final int TOK_MERGE = 1091;
    private static final int TOK_NOT_MATCHED = 1097;
    private static final int TOK_SUBQUERY = 1236;
    private static final int TOK_TABCOLNAME = 1247;
    private static final int TOK_TABLE_OR_COL = 1270;
    private static final int TOK_TABNAME = 1273;
    private static final int TOK_TABREF = 1274;
    private static final int TOK_UPDATE = 1304;

    /** Name of the TokenRewriteStream program into which INSERT VALUES translations are applied. */
    private static final String MERGE_INSERT_VALUES_PROGRAM = "MERGE_INSERT_VALUES_PROGRAM";
    /** Number of WHEN MATCHED ... UPDATE clauses seen by the most recent {@code analyze} call. */
    private int numWhenMatchedUpdateClauses;
    /** Number of WHEN MATCHED ... DELETE clauses seen by the most recent {@code analyze} call. */
    private int numWhenMatchedDeleteClauses;
    /** Back-quotes identifiers in subtrees before their text is extracted; created per analyze call. */
    private IdentifierQuoter quotedIdentifierHelper;

    MergeSemanticAnalyzer(QueryState queryState, RewriterFactory<MergeStatement> rewriterFactory) throws SemanticException {
        super(queryState, rewriterFactory);
        queryState.setSqlKind(SqlKind.MERGE);
    }

    /** The MERGE target table is the first child of the MERGE node. */
    @Override
    protected ASTNode getTargetTableNode(ASTNode tree) {
        return (ASTNode)tree.getChild(0);
    }

    /**
     * Analyzes a MERGE parse tree and drives its rewrite.
     *
     * <p>Child layout of {@code tree}: 0 = target, 1 = source, 2 = ON clause, 3 = optional query
     * hint, followed by one or more WHEN clauses.
     *
     * @param tree the MERGE node; must be of type TOK_MERGE
     * @param targetTable the resolved target table
     * @param targetNameNode the target table name node (possibly aliased)
     * @throws SemanticException if there is no WHEN clause, more than one WHEN MATCHED UPDATE or
     *     DELETE clause, or both of those without a predicate on the first one
     * @throws RuntimeException if {@code tree} is not a MERGE node
     */
    @Override
    public void analyze(ASTNode tree, Table targetTable, ASTNode targetNameNode) throws SemanticException {
        this.quotedIdentifierHelper = new IdentifierQuoter(this.ctx.getTokenRewriteStream());
        if (tree.getToken().getType() != TOK_MERGE) {
            throw new RuntimeException("Asked to parse token " + tree.getName() + " in MergeSemanticAnalyzer");
        }
        ASTNode source = (ASTNode)tree.getChild(1);
        String targetAlias = this.getSimpleTableName(targetNameNode);
        String sourceName = this.getSimpleTableName(source);
        ASTNode onClause = (ASTNode)tree.getChild(2);
        String onClauseAsText = this.getMatchedText(onClause);
        MergeStatement.MergeStatementBuilder mergeStatementBuilder = MergeStatement
            .withTarget(targetTable, this.getFullTableNameForSQL(targetNameNode), targetAlias)
            .sourceName(sourceName)
            .sourceAlias(this.getSourceAlias(source, sourceName))
            .onClauseAsText(onClauseAsText);

        // Child 3 is either a query hint or the first WHEN clause. It may be absent altogether
        // (getChild returns null), in which case findWhenClauses reports the missing WHEN clause
        // instead of this line failing with a NullPointerException.
        ASTNode qHint = (ASTNode)tree.getChild(3);
        boolean hasHint = qHint != null && qHint.getType() == QUERY_HINT;
        int whenClauseBegins = hasHint ? 4 : 3;
        List<ASTNode> whenClauses = this.findWhenClauses(tree, whenClauseBegins);
        if (hasHint) {
            mergeStatementBuilder.hintStr(String.format(" /*+ %s */ ", qHint.getText()));
        }

        // Predicate of the first WHEN MATCHED clause; it is handed to the second matched clause
        // (if any) when that clause is built.
        String extraPredicate = null;
        int numInsertClauses = 0;
        this.numWhenMatchedUpdateClauses = 0;
        this.numWhenMatchedDeleteClauses = 0;
        for (ASTNode whenClause : whenClauses) {
            ASTNode operation = this.getWhenClauseOperation(whenClause);
            switch (operation.getType()) {
                case TOK_INSERT: {
                    ++numInsertClauses;
                    OnClauseAnalyzer oca = new OnClauseAnalyzer(onClause, targetTable, targetAlias, this.conf, onClauseAsText);
                    oca.analyze();
                    String onClausePredicate = oca.getPredicate();
                    mergeStatementBuilder
                        .addWhenClause(this.handleInsert(whenClause, onClausePredicate, targetTable))
                        .onClausePredicate(onClausePredicate);
                    break;
                }
                case TOK_UPDATE: {
                    ++this.numWhenMatchedUpdateClauses;
                    MergeStatement.UpdateClause updateClause = this.handleUpdate(whenClause, targetTable, extraPredicate);
                    mergeStatementBuilder.addWhenClause(updateClause);
                    if (this.numWhenMatchedUpdateClauses + this.numWhenMatchedDeleteClauses == 1) {
                        extraPredicate = updateClause.getExtraPredicate();
                    }
                    break;
                }
                case TOK_DELETE: {
                    ++this.numWhenMatchedDeleteClauses;
                    MergeStatement.DeleteClause deleteClause = this.handleDelete(whenClause, extraPredicate);
                    mergeStatementBuilder.addWhenClause(deleteClause);
                    if (this.numWhenMatchedUpdateClauses + this.numWhenMatchedDeleteClauses == 1) {
                        extraPredicate = deleteClause.getExtraPredicate();
                    }
                    break;
                }
                default: {
                    // Report the operation's type: that is the value the switch rejected.
                    throw new IllegalStateException("Unexpected WHEN clause type: " + operation.getType() + MergeSemanticAnalyzer.addParseInfo(operation));
                }
            }
            if (this.numWhenMatchedDeleteClauses > 1) {
                throw new SemanticException(ErrorMsg.MERGE_TOO_MANY_DELETE, new String[]{this.ctx.getCmd()});
            }
            if (this.numWhenMatchedUpdateClauses > 1) {
                throw new SemanticException(ErrorMsg.MERGE_TOO_MANY_UPDATE, new String[]{this.ctx.getCmd()});
            }
            assert (numInsertClauses < 2) : "too many Insert clauses";
        }
        // With both an UPDATE and a DELETE matched clause, the first one must carry a predicate.
        if (this.numWhenMatchedDeleteClauses + this.numWhenMatchedUpdateClauses == 2 && extraPredicate == null) {
            throw new SemanticException(ErrorMsg.MERGE_PREDIACTE_REQUIRED, new String[]{this.ctx.getCmd()});
        }
        String subQueryAlias = this.isAliased(targetNameNode) ? targetAlias : targetTable.getTTable().getTableName();
        this.rewriteAndAnalyze(mergeStatementBuilder.build(), subQueryAlias);
        this.updateOutputs(targetTable);
    }

    /**
     * Returns the SQL text used to refer to the MERGE source.
     *
     * <p>For a subquery source, this is the subquery's own text. Otherwise it is the full table
     * name, followed by {@code sourceName} when the table reference is aliased.
     */
    private String getSourceAlias(ASTNode source, String sourceName) throws SemanticException {
        if (source.getType() == TOK_SUBQUERY) {
            return this.getMatchedText(source);
        }
        String sourceAlias = this.getFullTableNameForSQL(source);
        if (this.isAliased(source)) {
            sourceAlias = String.format("%s %s", sourceAlias, sourceName);
        }
        return sourceAlias;
    }

    /**
     * Builds the update clause for a WHEN MATCHED ... UPDATE SET ... clause.
     *
     * <p>For each non-partition column assigned in the SET clause, the right-hand side SQL text is
     * recorded. A TOK_TABLE_OR_COL node whose only child is TOK_DEFAULT_VALUE is first rewritten in
     * the token stream to the column's default-constraint value.
     *
     * @param whenMatchedUpdateClause the TOK_MATCHED node
     * @param targetTable the MERGE target
     * @param deleteExtraPredicate predicate of an earlier WHEN MATCHED clause, or null
     * @return the update clause carrying this clause's predicate and the column-to-expression map
     */
    private MergeStatement.UpdateClause handleUpdate(ASTNode whenMatchedUpdateClause, Table targetTable, String deleteExtraPredicate) throws SemanticException {
        assert (whenMatchedUpdateClause.getType() == TOK_MATCHED);
        ASTNode updateNode = this.getWhenClauseOperation(whenMatchedUpdateClause);
        assert (updateNode.getType() == TOK_UPDATE);
        HashMap<String, String> newValuesMap = new HashMap<>(targetTable.getCols().size() + targetTable.getPartCols().size());
        ASTNode setClause = (ASTNode)updateNode.getChild(0);
        Map<String, ASTNode> setColsExprs = this.collectSetColumnsAndExpressions(setClause, null, targetTable);
        List<FieldSchema> nonPartCols = targetTable.getCols();
        Map<String, String> colNameToDefaultConstraint = this.getColNameToDefaultValueMap(targetTable);
        for (FieldSchema fs : nonPartCols) {
            String name = fs.getName();
            ASTNode setColExpr = setColsExprs.get(name);
            if (!setColsExprs.containsKey(name)) {
                continue;
            }
            if (setColExpr.getType() == TOK_TABLE_OR_COL && setColExpr.getChildCount() == 1 && setColExpr.getChild(0).getType() == TOK_DEFAULT_VALUE) {
                UnparseTranslator defaultValueTranslator = new UnparseTranslator((Configuration)this.conf);
                defaultValueTranslator.enable();
                defaultValueTranslator.addDefaultValueTranslation(setColExpr, colNameToDefaultConstraint.get(name));
                defaultValueTranslator.applyTranslations(this.ctx.getTokenRewriteStream());
            }
            String rhsExp = this.getMatchedText(setColExpr);
            // getMatchedText() includes one token past the expression, so for "SET a=5, b=8" the
            // text extracted for "5" can end with the following ','. Strip it.
            if (!rhsExp.isEmpty()) {
                char last = rhsExp.charAt(rhsExp.length() - 1);
                if (last == ',' || last == '\n') {
                    rhsExp = rhsExp.substring(0, rhsExp.length() - 1);
                }
            }
            newValuesMap.put(name, rhsExp);
        }
        String extraPredicate = this.getWhenClausePredicate(whenMatchedUpdateClause);
        this.setUpAccessControlInfoForUpdate(targetTable, setColsExprs);
        return new MergeStatement.UpdateClause(extraPredicate, deleteExtraPredicate, newValuesMap);
    }

    /**
     * Builds the delete clause for a WHEN MATCHED ... DELETE clause.
     *
     * @param whenMatchedDeleteClause the TOK_MATCHED node
     * @param updateExtraPredicate predicate of an earlier WHEN MATCHED clause, or null
     */
    protected MergeStatement.DeleteClause handleDelete(ASTNode whenMatchedDeleteClause, String updateExtraPredicate) {
        assert (whenMatchedDeleteClause.getType() == TOK_MATCHED);
        String extraPredicate = this.getWhenClausePredicate(whenMatchedDeleteClause);
        return new MergeStatement.DeleteClause(extraPredicate, updateExtraPredicate);
    }

    /** Formats the source position of {@code n} for inclusion in error messages. */
    private static String addParseInfo(ASTNode n) {
        return " at " + ASTErrorUtils.renderPosition(n);
    }

    /**
     * Collects the WHEN clauses, which are the children of {@code tree} from index {@code start}
     * onward.
     *
     * @throws SemanticException if there are none
     */
    private List<ASTNode> findWhenClauses(ASTNode tree, int start) throws SemanticException {
        assert (tree.getType() == TOK_MERGE);
        ArrayList<ASTNode> whenClauses = new ArrayList<>();
        for (int idx = start; idx < tree.getChildCount(); ++idx) {
            ASTNode whenClause = (ASTNode)tree.getChild(idx);
            assert (whenClause.getType() == TOK_MATCHED || whenClause.getType() == TOK_NOT_MATCHED) : "Unexpected node type found: " + whenClause.getType() + MergeSemanticAnalyzer.addParseInfo(whenClause);
            whenClauses.add(whenClause);
        }
        if (whenClauses.isEmpty()) {
            throw new SemanticException("Must have at least 1 WHEN clause in MERGE statement");
        }
        return whenClauses;
    }

    /**
     * Returns the operation node (INSERT/UPDATE/DELETE), which is the first child of a WHEN clause.
     *
     * @throws RuntimeException (from raiseWrongType) if {@code whenClause} is not a matched/not-matched node
     */
    protected ASTNode getWhenClauseOperation(ASTNode whenClause) {
        if (whenClause.getType() != TOK_MATCHED && whenClause.getType() != TOK_NOT_MATCHED) {
            throw MergeSemanticAnalyzer.raiseWrongType("Expected TOK_MATCHED|TOK_NOT_MATCHED", whenClause);
        }
        return (ASTNode)whenClause.getChild(0);
    }

    /**
     * Returns the text of the optional {@code AND <predicate>} of a WHEN clause (its second
     * child), or null when the clause has no predicate.
     */
    private String getWhenClausePredicate(ASTNode whenClause) {
        if (whenClause.getType() != TOK_MATCHED && whenClause.getType() != TOK_NOT_MATCHED) {
            throw MergeSemanticAnalyzer.raiseWrongType("Expected TOK_MATCHED|TOK_NOT_MATCHED", whenClause);
        }
        if (whenClause.getChildCount() == 2) {
            return this.getMatchedText((ASTNode)whenClause.getChild(1));
        }
        return null;
    }

    /**
     * Builds the insert clause for a WHEN NOT MATCHED ... INSERT clause.
     *
     * @param whenNotMatchedClause the TOK_NOT_MATCHED node
     * @param onClausePredicate predicate derived from the ON clause by OnClauseAnalyzer
     * @param targetTable the MERGE target
     * @return the insert clause with the optional explicit column list, the value expressions,
     *     and the predicates
     * @throws SemanticException if an explicit column list does not match the number of values
     */
    private MergeStatement.InsertClause handleInsert(ASTNode whenNotMatchedClause, String onClausePredicate, Table targetTable) throws SemanticException {
        ASTNode whenClauseOperation = this.getWhenClauseOperation(whenNotMatchedClause);
        assert (whenNotMatchedClause.getType() == TOK_NOT_MATCHED);
        assert (whenClauseOperation.getType() == TOK_INSERT);
        // Typed (not raw) so the stream pipelines yield Node rather than Object and the casts compile.
        List<Node> children = whenClauseOperation.getChildren();
        ASTNode valuesNode = (ASTNode)children.stream()
            .filter(n -> ((ASTNode)n).getType() == TOK_FUNCTION)
            .findFirst()
            .get();
        ASTNode columnListNode = (ASTNode)children.stream()
            .filter(n -> ((ASTNode)n).getType() == TOK_TABCOLNAME)
            .findFirst()
            .orElse(null);
        // Values are read from child index 1 onward (see the loop below), so child 0 is not
        // counted as a value.
        int valueCount = valuesNode.getChildCount() - 1;
        ArrayList<String> columnNames = null;
        if (columnListNode != null) {
            if (columnListNode.getChildCount() != valueCount) {
                throw new SemanticException(String.format("Column schema must have the same length as values (%d vs %d)", columnListNode.getChildCount(), valueCount));
            }
            columnNames = new ArrayList<>(valuesNode.getChildCount());
            for (int i = 0; i < columnListNode.getChildCount(); ++i) {
                ASTNode columnNameNode = (ASTNode)columnListNode.getChild(i);
                String columnName = this.ctx.getTokenRewriteStream().toString(columnNameNode.getTokenStartIndex(), columnNameNode.getTokenStopIndex()).trim();
                columnNames.add(columnName);
            }
        }
        ArrayList<String> values = new ArrayList<>(valuesNode.getChildCount());
        // Identifier-unescape translations are applied to the dedicated rewrite program, and the
        // value texts below are read back from that same program.
        UnparseTranslator unparseTranslator = HiveUtils.collectUnescapeIdentifierTranslations(valuesNode);
        unparseTranslator.applyTranslations(this.ctx.getTokenRewriteStream(), MERGE_INSERT_VALUES_PROGRAM);
        List<String> targetSchema = this.processTableColumnNames(columnListNode, targetTable.getFullyQualifiedName());
        List<String> defaultConstraints = this.getDefaultConstraints(targetTable, targetSchema);
        for (int i = 1; i < valuesNode.getChildCount(); ++i) {
            ASTNode valueNode = (ASTNode)valuesNode.getChild(i);
            String value;
            if (valueNode.getType() == TOK_TABLE_OR_COL && valueNode.getChild(0).getType() == TOK_DEFAULT_VALUE) {
                // Substitute the column's default constraint, or the literal NULL when it has none.
                value = ObjectUtils.defaultIfNull(defaultConstraints.get(i - 1), "NULL");
            } else {
                value = this.ctx.getTokenRewriteStream().toString(MERGE_INSERT_VALUES_PROGRAM, valueNode.getTokenStartIndex(), valueNode.getTokenStopIndex()).trim();
            }
            values.add(value);
        }
        String extraPredicate = this.getWhenClausePredicate(whenNotMatchedClause);
        return new MergeStatement.InsertClause(columnNames, values, onClausePredicate, extraPredicate);
    }

    /** Whether the same output may be written more than once; controlled by SPLIT_UPDATE. */
    @Override
    protected boolean allowOutputMultipleTimes() {
        return this.conf.getBoolVar(HiveConf.ConfVars.SPLIT_UPDATE);
    }

    /** Column statistics are collected only when the MERGE has no WHEN MATCHED UPDATE/DELETE clause. */
    @Override
    protected boolean enableColumnStatsCollecting() {
        return this.numWhenMatchedUpdateClauses == 0 && this.numWhenMatchedDeleteClauses == 0;
    }

    /**
     * Returns the trimmed source text of {@code n}, after back-quoting the identifiers in its
     * subtree, or null for a null node.
     *
     * <p>The extracted range extends one token past {@code n}'s stop token; callers that need the
     * exact expression strip a trailing separator (see handleUpdate).
     */
    protected String getMatchedText(ASTNode n) {
        if (n == null) {
            return null;
        }
        this.quotedIdentifierHelper.visit(n);
        return this.ctx.getTokenRewriteStream().toString(n.getTokenStartIndex(), n.getTokenStopIndex() + 1).trim();
    }

    /**
     * Reports whether a table reference carries an alias.
     *
     * <p>A plain table name is never aliased, and a subquery always is. For a table reference,
     * the result comes from findTabRefIdxs.
     */
    protected boolean isAliased(ASTNode n) {
        switch (n.getType()) {
            case TOK_TABREF: {
                return MergeSemanticAnalyzer.findTabRefIdxs(n)[0] != 0;
            }
            case TOK_TABNAME: {
                return false;
            }
            case TOK_SUBQUERY: {
                assert (n.getChildCount() > 1) : "Expected Derived Table to be aliased";
                return true;
            }
        }
        throw MergeSemanticAnalyzer.raiseWrongType("TOK_TABREF|TOK_TABNAME", n);
    }

    /** Returns the simple table name of {@code n}, unparsed (quoted as needed) according to the configuration. */
    protected String getSimpleTableName(ASTNode n) throws SemanticException {
        return HiveUtils.unparseIdentifier(MergeSemanticAnalyzer.getSimpleTableNameBase(n), (Configuration)this.conf);
    }

    /**
     * Wraps every Identifier token in a visited subtree in backticks by editing the shared
     * TokenRewriteStream in place.
     *
     * <p>Because the stream edits are not idempotent, each Identifier node is remembered by
     * identity and quoted at most once, however many times its subtree is visited.
     */
    private static final class IdentifierQuoter {
        private final TokenRewriteStream trs;
        private final IdentityHashMap<ASTNode, ASTNode> visitedNodes = new IdentityHashMap<>();

        IdentifierQuoter(TokenRewriteStream trs) {
            if (trs == null) {
                throw new IllegalArgumentException("Must have a TokenRewriteStream");
            }
            this.trs = trs;
        }

        private void visit(ASTNode n) {
            if (n.getType() == IDENTIFIER) {
                if (this.visitedNodes.containsKey(n)) {
                    return;
                }
                this.visitedNodes.put(n, n);
                this.trs.insertBefore(n.getToken(), "`");
                this.trs.insertAfter(n.getToken(), "`");
            }
            if (n.getChildCount() <= 0) {
                return;
            }
            for (Node c : n.getChildren()) {
                this.visit((ASTNode)c);
            }
        }
    }

    /**
     * Finds the target-table columns referenced by the MERGE ON clause and produces an
     * {@code <target>.<col> IS NULL AND ...} predicate over them.
     *
     * <p>Qualified references ({@code t.col}) are attributed to {@code t}. Unqualified references
     * are attributed to the target when they name one of its (regular or partition) columns.
     * {@code db.table.col} references and more than two distinct table names are rejected with
     * IllegalArgumentException.
     * NOTE(review): presumably the predicate identifies joined rows with no target match, for
     * the WHEN NOT MATCHED leg; confirm against MergeStatement's rewriter.
     */
    private static final class OnClauseAnalyzer {
        private final ASTNode onClause;
        /** Lower-cased table name -> column names referenced through it. */
        private final Map<String, List<String>> table2column = new HashMap<>();
        private final List<String> unresolvedColumns = new ArrayList<>();
        private final List<FieldSchema> allTargetTableColumns = new ArrayList<>();
        private final Set<String> tableNamesFound = new HashSet<>();
        private final String targetTableNameInSourceQuery;
        private final HiveConf conf;
        private final String onClauseAsString;

        OnClauseAnalyzer(ASTNode onClause, Table targetTable, String targetTableNameInSourceQuery, HiveConf conf, String onClauseAsString) {
            this.onClause = onClause;
            this.allTargetTableColumns.addAll(targetTable.getCols());
            this.allTargetTableColumns.addAll(targetTable.getPartCols());
            this.targetTableNameInSourceQuery = BaseSemanticAnalyzer.unescapeIdentifier(targetTableNameInSourceQuery);
            this.conf = conf;
            this.onClauseAsString = onClauseAsString;
        }

        /** Records every TOK_TABLE_OR_COL reference in the subtree rooted at {@code n}. */
        private void visit(ASTNode n) {
            if (n.getType() == TOK_TABLE_OR_COL) {
                ASTNode parent = (ASTNode)n.getParent();
                if (parent != null && parent.getType() == DOT) {
                    // "table.col": child 0 of n is the table, child 1 of the DOT is the column.
                    if (parent.getParent() != null && parent.getParent().getType() == DOT) {
                        throw new IllegalArgumentException("Found unexpected db.table.col reference in " + this.onClauseAsString);
                    }
                    this.addColumn2Table(n.getChild(0).getText(), parent.getChild(1).getText());
                } else {
                    this.unresolvedColumns.add(n.getChild(0).getText());
                }
            }
            if (n.getChildCount() == 0) {
                return;
            }
            for (Node child : n.getChildren()) {
                this.visit((ASTNode)child);
            }
        }

        private void analyze() {
            this.visit(this.onClause);
            if (this.tableNamesFound.size() > 2) {
                throw new IllegalArgumentException("Found > 2 table refs in ON clause.  Found " + this.tableNamesFound + " in " + this.onClauseAsString);
            }
            this.handleUnresolvedColumns();
            if (this.tableNamesFound.size() > 2) {
                throw new IllegalArgumentException("Found > 2 table refs in ON clause (incl unresolved).  Found " + this.tableNamesFound + " in " + this.onClauseAsString);
            }
        }

        /** Attributes each unqualified column to the target table if the target has a column of that name (case-insensitive). */
        private void handleUnresolvedColumns() {
            for (String c : this.unresolvedColumns) {
                boolean isTargetColumn = this.allTargetTableColumns.stream().anyMatch(fs -> c.equalsIgnoreCase(fs.getName()));
                if (isTargetColumn) {
                    this.addColumn2Table(this.targetTableNameInSourceQuery.toLowerCase(), c);
                }
            }
        }

        private void addColumn2Table(String tableName, String columnName) {
            String key = tableName.toLowerCase();
            this.tableNamesFound.add(key);
            this.table2column.computeIfAbsent(key, k -> new ArrayList<>()).add(columnName);
        }

        /**
         * Returns {@code <target>.<col1> IS NULL AND <target>.<col2> IS NULL ...} for every target
         * column referenced in the ON clause.
         *
         * @throws IllegalArgumentException if the ON clause references no target column
         */
        private String getPredicate() {
            List<String> targetCols = this.table2column.get(this.targetTableNameInSourceQuery.toLowerCase());
            if (targetCols == null) {
                throw new IllegalArgumentException(ErrorMsg.INVALID_TABLE_IN_ON_CLAUSE_OF_MERGE.format(new String[]{this.targetTableNameInSourceQuery, this.onClauseAsString}));
            }
            String quotedTarget = HiveUtils.unparseIdentifier(this.targetTableNameInSourceQuery, (Configuration)this.conf);
            StringBuilder sb = new StringBuilder();
            for (String col : targetCols) {
                if (sb.length() > 0) {
                    sb.append(" AND ");
                }
                sb.append(quotedTarget).append(".").append(HiveUtils.unparseIdentifier(col, (Configuration)this.conf)).append(" IS NULL");
            }
            return sb.toString();
        }
    }
}

