package org.apache.hadoop.hive.ql.optimizer.calcite.translator;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSort;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;

/**
 * Reshapes an optimized Calcite plan before it is converted back into a Hive AST.
 *
 * <p>The plan is walked top-down. Wherever a node and its parent (or child) fail one of the
 * {@code valid*} checks in this class, a {@link HiveProject} built from
 * {@code HiveCalciteUtil.getProjsFromBelowAsInputRef} is inserted between them. That project is
 * presumably an identity projection acting as a derived table. Finally the top-level select is
 * renamed to match the expected result schema.
 *
 * <p>The rewrite mutates the given plan in place via {@link RelNode#replaceInput}. The root
 * itself may also be replaced, so callers must use the node returned by
 * {@link #convertOpTree(RelNode, List)}.
 */
public class PlanModifierForASTConv {
    private static final Log LOG = LogFactory.getLog(PlanModifierForASTConv.class);

    /**
     * Prepares an optimized plan for AST conversion.
     *
     * <p>The steps are:
     * <ol>
     *   <li>If the root is neither a {@link Project} nor a {@link Sort}, wrap it in a derived
     *       table.</li>
     *   <li>Walk the tree and insert derived tables wherever a node/parent combination is
     *       rejected by the validity checks.</li>
     *   <li>Call {@code PlanModifierUtil.fixTopOBSchema} on the top-level select.</li>
     *   <li>Rename the top-level select's output columns after {@code resultSchema}.</li>
     * </ol>
     *
     * @param rel root of the optimized plan; nodes below it are modified in place
     * @param resultSchema expected result columns, in output order
     * @return the root of the rewritten plan, which may differ from {@code rel}
     * @throws CalciteSemanticException if the top-level select's column count differs from
     *     {@code resultSchema.size()}
     */
    public static RelNode convertOpTree(RelNode rel, List<FieldSchema> resultSchema)
            throws CalciteSemanticException {
        RelNode newTopNode = rel;
        if (LOG.isDebugEnabled()) {
            LOG.debug("Original plan for PlanModifier\n " + RelOptUtil.toString(newTopNode));
        }

        if (!(newTopNode instanceof Project) && !(newTopNode instanceof Sort)) {
            newTopNode = introduceDerivedTable(newTopNode);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Plan after top-level introduceDerivedTable\n " + RelOptUtil.toString(newTopNode));
            }
        }

        // The cast selects the private (RelNode, RelNode) overload; the root has no parent.
        convertOpTree(newTopNode, (RelNode) null);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Plan after nested convertOpTree\n " + RelOptUtil.toString(newTopNode));
        }

        PlanModifierUtil.fixTopOBSchema(newTopNode, HiveCalciteUtil.getTopLevelSelect(newTopNode), resultSchema, true);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Plan after fixTopOBSchema\n " + RelOptUtil.toString(newTopNode));
        }

        // Look the top-level select up again, in case fixTopOBSchema changed the top of the tree.
        Pair<RelNode, RelNode> topSelParentPair = HiveCalciteUtil.getTopLevelSelect(newTopNode);
        RelNode finalTopNode = renameTopLevelSelectInResultSchema(newTopNode, topSelParentPair, resultSchema);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Final plan after modifier\n " + RelOptUtil.toString(finalTopNode));
        }
        return finalTopNode;
    }

    /**
     * Follows single-input nodes down from {@code rel} to a {@link HiveTableScan} and returns
     * its table alias.
     *
     * @param rel node to start from; may be {@code null}
     * @return the alias, or {@code null} in any of these cases: {@code rel} is {@code null};
     *     the walk reaches a {@link Project}; the walk reaches a node that does not have
     *     exactly one input
     */
    private static String getTblAlias(RelNode rel) {
        if (rel == null) {
            return null;
        }
        if (rel instanceof HiveTableScan) {
            return ((HiveTableScan) rel).getTableAlias();
        }
        if (rel instanceof Project || rel.getInputs().size() != 1) {
            return null;
        }
        return getTblAlias(rel.getInput(0));
    }

    /**
     * Recursively visits {@code rel} and its descendants, inserting derived tables wherever
     * {@code rel}'s placement relative to {@code parent} (or to its own child) is rejected by the
     * corresponding validity check.
     *
     * @param rel the node being visited
     * @param parent the node whose input {@code rel} is; {@code null} for the root
     * @throws RuntimeException if a {@link HepRelVertex}, {@link MultiJoin} or {@link RelSubset}
     *     is encountered (the tree is apparently expected to be free of them by now)
     */
    private static void convertOpTree(RelNode rel, RelNode parent) {
        // Node whose inputs are walked below. It differs from rel only when an empty aggregate
        // has been replaced by a copy. Descending through the detached original would make any
        // derived table introduced for its child land on a node no longer in the plan.
        RelNode nodeToDescend = rel;

        if (rel instanceof HepRelVertex) {
            throw new RuntimeException("Found HepRelVertex");
        } else if (rel instanceof Join) {
            Join join = (Join) rel;
            if (!validJoinParent(rel, parent)) {
                introduceDerivedTable(rel, parent);
            }
            // Both sides resolve to the same table alias (e.g. a self-join). Wrap the left side,
            // presumably so the two sides end up with distinct aliases in the generated AST.
            String leftAlias = getTblAlias(join.getLeft());
            if (leftAlias != null && leftAlias.equalsIgnoreCase(getTblAlias(join.getRight()))) {
                introduceDerivedTable(join.getLeft(), rel);
            }
        } else if (rel instanceof MultiJoin) {
            throw new RuntimeException("Found MultiJoin");
        } else if (rel instanceof RelSubset) {
            throw new RuntimeException("Found RelSubset");
        } else if (rel instanceof SetOp) {
            if (!validSetopParent(rel, parent)) {
                introduceDerivedTable(rel, parent);
            }
            SetOp setOp = (SetOp) rel;
            // getInputs() is iterated as a snapshot; replaceInput below installs a new input list.
            for (RelNode input : setOp.getInputs()) {
                if (!validSetopChild(input)) {
                    introduceDerivedTable(input, setOp);
                }
            }
        } else if (rel instanceof SingleRel) {
            if (rel instanceof Filter) {
                if (!validFilterParent(rel, parent)) {
                    introduceDerivedTable(rel, parent);
                }
            } else if (rel instanceof HiveSort) {
                HiveSort sort = (HiveSort) rel;
                if (!validSortParent(rel, parent)) {
                    introduceDerivedTable(rel, parent);
                }
                if (!validSortChild(sort)) {
                    introduceDerivedTable(sort.getInput(), rel);
                }
            } else if (rel instanceof HiveAggregate) {
                RelNode newParent = parent;
                if (!validGBParent(rel, parent)) {
                    newParent = introduceDerivedTable(rel, parent);
                }
                if (isEmptyGrpAggr(rel)) {
                    nodeToDescend = replaceEmptyGroupAggr(rel, newParent);
                }
            }
        }

        List<RelNode> childNodes = nodeToDescend.getInputs();
        if (childNodes != null) {
            for (RelNode child : childNodes) {
                convertOpTree(child, nodeToDescend);
            }
        }
    }

    /**
     * Replaces the top-level select with an equivalent {@link HiveProject} whose output column
     * names are taken from {@code resultSchema}. A single leading underscore is stripped from
     * each name.
     *
     * @param rootRel current root of the plan
     * @param topSelParentPair (parent, top-level select) as returned by
     *     {@code HiveCalciteUtil.getTopLevelSelect}; the select must be a {@link HiveProject}
     * @param resultSchema expected result columns, in output order
     * @return the new root: the replacement project if the select was the root, otherwise
     *     {@code rootRel}
     * @throws CalciteSemanticException if the select's expression count differs from
     *     {@code resultSchema.size()}
     */
    private static RelNode renameTopLevelSelectInResultSchema(RelNode rootRel,
            Pair<RelNode, RelNode> topSelParentPair, List<FieldSchema> resultSchema)
            throws CalciteSemanticException {
        RelNode parentOfOriginalProj = topSelParentPair.getKey();
        HiveProject originalProj = (HiveProject) topSelParentPair.getValue();
        List<RexNode> projExps = originalProj.getChildExps();
        if (resultSchema.size() != projExps.size()) {
            LOG.error(PlanModifierUtil.generateInvalidSchemaMessage(originalProj, resultSchema, 0));
            throw new CalciteSemanticException("Result Schema didn't match Optimized Op Tree Schema");
        }

        List<String> newSelAliases = new ArrayList<String>(resultSchema.size());
        for (FieldSchema field : resultSchema) {
            String colAlias = field.getName();
            if (colAlias.startsWith("_")) {
                colAlias = colAlias.substring(1);
            }
            newSelAliases.add(colAlias);
        }

        HiveProject replacementProj = HiveProject.create(originalProj.getInput(), projExps, newSelAliases);
        if (rootRel == originalProj) {
            return replacementProj;
        }
        // NOTE(review): assumes the select is input 0 of its parent, as the original code did.
        parentOfOriginalProj.replaceInput(0, replacementProj);
        return rootRel;
    }

    /**
     * Builds a {@link HiveProject} over {@code rel} from
     * {@code HiveCalciteUtil.getProjsFromBelowAsInputRef(rel)}, keeping {@code rel}'s row type
     * and collations.
     *
     * @param rel node to wrap
     * @return the new project; it is not attached to any parent
     */
    private static RelNode introduceDerivedTable(RelNode rel) {
        return HiveProject.create(rel.getCluster(), rel,
                HiveCalciteUtil.getProjsFromBelowAsInputRef(rel), rel.getRowType(), rel.getCollationList());
    }

    /**
     * Wraps {@code rel} in a derived table and installs it in place of {@code rel} among
     * {@code parent}'s inputs.
     *
     * @param rel child to wrap; located among {@code parent}'s inputs by reference identity
     * @param parent current parent of {@code rel}
     * @return the inserted project
     * @throws RuntimeException if {@code rel} is not an input of {@code parent}
     */
    private static RelNode introduceDerivedTable(RelNode rel, RelNode parent) {
        List<RelNode> inputs = parent.getInputs();
        int pos = -1;
        for (int i = 0; i < inputs.size(); i++) {
            // Reference identity on purpose: the exact node instance must be replaced.
            if (inputs.get(i) == rel) {
                pos = i;
                break;
            }
        }
        if (pos == -1) {
            throw new RuntimeException("Couldn't find child node in parent's inputs");
        }

        RelNode derivedTable = introduceDerivedTable(rel);
        parent.replaceInput(pos, derivedTable);
        return derivedTable;
    }

    /** A join may not be the right input of another join, nor an input of a set operation. */
    private static boolean validJoinParent(RelNode joinNode, RelNode parent) {
        if (parent instanceof Join) {
            return ((Join) parent).getRight() != joinNode;
        }
        return !(parent instanceof SetOp);
    }

    /** A filter may not sit directly under another filter, a join or a set operation. */
    private static boolean validFilterParent(RelNode filterNode, RelNode parent) {
        return !(parent instanceof Filter || parent instanceof Join || parent instanceof SetOp);
    }

    /**
     * An aggregate may not sit directly under a join, a set operation or another aggregate.
     * An aggregate with an empty group set may also not sit under a filter.
     */
    private static boolean validGBParent(RelNode gbNode, RelNode parent) {
        return !(parent instanceof Join
                || parent instanceof SetOp
                || parent instanceof Aggregate
                || (parent instanceof Filter && ((Aggregate) gbNode).getGroupSet().isEmpty()));
    }

    /**
     * A sort is valid when it is the root, or when its parent is a project, a sort, or a node
     * accepted by {@code HiveCalciteUtil.orderRelNode}.
     */
    private static boolean validSortParent(RelNode sortNode, RelNode parent) {
        return parent == null
                || parent instanceof Project
                || parent instanceof Sort
                || HiveCalciteUtil.orderRelNode(parent);
    }

    /**
     * A sort's input is valid when it is a project, or when the sort satisfies
     * {@code HiveCalciteUtil.limitRelNode} and its input satisfies
     * {@code HiveCalciteUtil.orderRelNode}.
     */
    private static boolean validSortChild(HiveSort sortNode) {
        RelNode child = sortNode.getInput();
        return (HiveCalciteUtil.limitRelNode(sortNode) && HiveCalciteUtil.orderRelNode(child))
                || child instanceof Project;
    }

    /** A set operation is valid only as the root or directly under a project. */
    private static boolean validSetopParent(RelNode setop, RelNode parent) {
        return parent == null || parent instanceof Project;
    }

    /** Every input of a set operation must be a project. */
    private static boolean validSetopChild(RelNode setopChild) {
        return setopChild instanceof Project;
    }

    /** Returns true if the aggregate has neither grouping keys nor aggregate calls. */
    private static boolean isEmptyGrpAggr(RelNode gbNode) {
        Aggregate aggregate = (Aggregate) gbNode;
        return aggregate.getGroupSet().isEmpty() && aggregate.getAggCallList().isEmpty();
    }

    /**
     * Replaces an aggregate that has no grouping keys and no aggregate calls with a copy that
     * carries a single dummy {@code count} call. The copy is wrapped in a derived table and
     * installed as input 0 of {@code parent}.
     *
     * @param rel the empty {@link HiveAggregate}
     * @param parent its current parent; every expression of it must be a constant
     * @return the new aggregate copy, which is now the live node in the plan
     * @throws RuntimeException if {@code parent} has a non-constant expression
     */
    private static RelNode replaceEmptyGroupAggr(RelNode rel, RelNode parent) {
        for (RexNode exp : parent.getChildExps()) {
            boolean isConstant = (Boolean) exp.accept(new HiveCalciteUtil.ConstantFinder());
            if (!isConstant) {
                throw new RuntimeException("We expect " + parent.toString() + " to contain only constants. However, " + exp.toString() + " is " + exp.getKind());
            }
        }

        HiveAggregate oldAggRel = (HiveAggregate) rel;
        RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        RelDataType longType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, typeFactory);
        RelDataType intType = TypeConverter.convert(TypeInfoFactory.intTypeInfo, typeFactory);
        // Dummy count over input column 0 so the aggregate produces at least one column.
        // NOTE(review): index 0 is assumed to be a valid argument column -- confirm.
        AggregateCall dummyCall = new AggregateCall(
                SqlFunctionConverter.getCalciteAggFn("count", ImmutableList.of(intType), longType),
                false, ImmutableList.of(0), longType, null);
        Aggregate newAggRel = oldAggRel.copy(oldAggRel.getTraitSet(), oldAggRel.getInput(),
                oldAggRel.indicator, oldAggRel.getGroupSet(), oldAggRel.getGroupSets(),
                ImmutableList.of(dummyCall));
        // NOTE(review): assumes the aggregate is input 0 of parent. Parents that are joins, set
        // operations or aggregates are wrapped by validGBParent before this point.
        parent.replaceInput(0, introduceDerivedTable(newAggRel));
        return newAggRel;
    }
}
