/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeSet;
import java.util.function.IntPredicate;
import java.util.stream.Collectors;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rel.rules.ImmutableAggregateExpandWithinDistinctRule;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlInternalOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.IntPair;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;
import shaded.com.google.common.collect.ArrayListMultimap;
import shaded.com.google.common.collect.ImmutableList;

@Value.Enclosing
public class AggregateExpandWithinDistinctRule
extends RelRule<Config> {
    /**
     * Creates the rule from the given configuration.
     *
     * @param config rule configuration, passed unchanged to the {@link RelRule} constructor
     */
    protected AggregateExpandWithinDistinctRule(Config config) {
        super(config);
    }

    /**
     * Returns whether {@code aggregate} is a candidate for this rule.
     *
     * <p>All three conditions must hold, and they are checked in this order:
     * <ol>
     *   <li>at least one aggregate call has non-null {@code distinctKeys};
     *   <li>no aggregate call is reported as reducible by
     *       {@link CoreRules#AGGREGATE_REDUCE_FUNCTIONS};
     *   <li>the aggregate's group type is {@link Aggregate.Group#SIMPLE}.
     * </ol>
     *
     * @param aggregate the aggregate to inspect
     * @return {@code true} if every condition above is satisfied
     */
    private static boolean hasWithinDistinct(Aggregate aggregate) {
        return aggregate.getAggCallList().stream().anyMatch(call -> call.distinctKeys != null)
            && aggregate.getAggCallList().stream().noneMatch(CoreRules.AGGREGATE_REDUCE_FUNCTIONS::canReduce)
            && aggregate.getGroupType() == Aggregate.Group.SIMPLE;
    }

    /**
     * Rewrites the matched {@link Aggregate} as two stacked aggregates.
     *
     * <p>The inner aggregate groups by the original group keys plus the distinct keys of
     * the aggregate calls (one grouping set per distinct key set). The outer aggregate
     * groups by the original group keys and applies each original aggregate function to
     * the columns produced by the inner aggregate. The result is converted back to the
     * original row type and registered with the planner via
     * {@link RelOptRuleCall#transformTo}.
     *
     * @param call the rule call whose operand 0 is the aggregate to rewrite
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Aggregate aggregate = call.rel(0);

        // Normalize DISTINCT calls into calls that carry explicit distinct keys.
        final List<AggregateCall> aggCallList = aggregate.getAggCallList().stream()
            .map(c -> AggregateExpandWithinDistinctRule.unDistinct(c, aggregate.getInput()::fieldIsNullable))
            .collect(ImmutableList.toImmutableList());

        // Bucket the calls by their distinct keys, ignoring keys that are already group keys.
        final ArrayListMultimap<ImmutableBitSet, AggregateCall> argLists = ArrayListMultimap.create();
        // Sentinel key for calls that have no distinct keys at all. It holds the single
        // bit "input field count", and is recognized below by identity (==), not equality.
        final ImmutableBitSet notDistinct = ImmutableBitSet.of(aggregate.getInput().getRowType().getFieldCount());
        for (AggregateCall aggCall : aggCallList) {
            ImmutableBitSet distinctKeys = aggCall.distinctKeys;
            if (distinctKeys == null) {
                distinctKeys = notDistinct;
            } else if (distinctKeys.intersects(aggregate.getGroupSet())) {
                distinctKeys = distinctKeys.rebuild().removeAll(aggregate.getGroupSet()).build();
            }
            argLists.put(distinctKeys, aggCall);
        }

        // One grouping set per distinct key bucket: the bucket's keys unioned with the
        // original group set. The TreeSet de-duplicates and orders the sets.
        final TreeSet<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
        for (ImmutableBitSet key : argLists.keySet()) {
            groupSetTreeSet.add(key == notDistinct ? aggregate.getGroupSet() : ImmutableBitSet.of(key).union(aggregate.getGroupSet()));
        }
        final ImmutableList<ImmutableBitSet> groupSets = ImmutableList.copyOf(groupSetTreeSet);
        final boolean hasMultipleGroupSets = groupSets.size() > 1;
        final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);

        // Field list for the GROUPING call: original group keys first, then the remaining keys.
        final LinkedHashSet<Integer> fullGroupOrderedSet = new LinkedHashSet<>();
        fullGroupOrderedSet.addAll(aggregate.getGroupSet().asSet());
        fullGroupOrderedSet.addAll(fullGroupSet.asSet());
        final ImmutableIntList fullGroupList = ImmutableIntList.copyOf(fullGroupOrderedSet);

        final RelBuilder b = call.builder();
        b.push(aggregate.getInput());

        // Aggregate calls of the aggregate currently being built (inner first, then outer).
        final List<RelBuilder.AggCall> aggCalls = new ArrayList<>();

        /**
         * Tracks which output column of the inner aggregate holds each registered
         * expression. Every ordinal is computed as {@code g + aggCalls.size()} at the
         * moment of registration, so the order of registration calls is significant.
         */
        class Registrar {
            /** Number of group columns that precede the aggregate-call columns. */
            final int g;
            /** (input field, filter arg) to ordinal of the MIN column for that field. */
            final Map<IntPair, Integer> args;
            /** Index of an original aggregate call to ordinal of its inner-aggregate column. */
            final Map<Integer, Integer> aggs;
            /** Filter arg to ordinal of the filtered COUNT column. */
            final Map<Integer, Integer> counts;

            Registrar() {
                this.g = fullGroupSet.cardinality();
                this.args = new HashMap<>();
                this.aggs = new HashMap<>();
                this.counts = new HashMap<>();
            }

            /** Maps each input field to the ordinal registered for it under {@code filterArg}. */
            List<Integer> fields(List<Integer> fields2, int filterArg) {
                return Util.transform(fields2, f -> this.field(f, filterArg));
            }

            /** Returns the registered ordinal; throws NullPointerException if never registered. */
            int field(int field, int filterArg) {
                return Objects.requireNonNull(this.args.get(IntPair.of(field, filterArg)));
            }

            /**
             * Registers MIN(field), filtered when {@code filterArg >= 0}, unless the same
             * (field, filterArg) pair is already registered. When throwIfNotUnique is set,
             * a matching MAX(field) is added immediately after it, i.e. at ordinal + 1.
             */
            int register(int field, int filterArg) {
                return this.args.computeIfAbsent(IntPair.of(field, filterArg), j -> {
                    final int ordinal = this.g + aggCalls.size();
                    RelBuilder.AggCall groupedField = b.aggregateCall(SqlStdOperatorTable.MIN, b.field(field));
                    aggCalls.add(filterArg < 0 ? groupedField : groupedField.filter(b.field(filterArg)));
                    if (config.throwIfNotUnique()) {
                        groupedField = b.aggregateCall(SqlStdOperatorTable.MAX, b.field(field));
                        aggCalls.add(filterArg < 0 ? groupedField : groupedField.filter(b.field(filterArg)));
                    }
                    return ordinal;
                });
            }

            /** Appends {@code aggregateCall} and records its ordinal under key {@code i}. */
            int registerAgg(int i, RelBuilder.AggCall aggregateCall) {
                final int ordinal = this.g + aggCalls.size();
                this.aggs.put(i, ordinal);
                aggCalls.add(aggregateCall);
                return ordinal;
            }

            /** Returns the ordinal recorded by {@link #registerAgg} for key {@code i}. */
            int getAgg(int i) {
                return Objects.requireNonNull(this.aggs.get(i));
            }

            /** Registers a COUNT with no arguments filtered by {@code filterArg}, once per filter. */
            int registerCount(int filterArg) {
                assert (filterArg >= 0);
                return this.counts.computeIfAbsent(filterArg, i -> {
                    final int ordinal = this.g + aggCalls.size();
                    aggCalls.add(b.aggregateCall(SqlStdOperatorTable.COUNT).filter(b.field(filterArg)));
                    return ordinal;
                });
            }

            /** Returns the ordinal recorded by {@link #registerCount} for {@code filterArg}. */
            int getCount(int filterArg) {
                return Objects.requireNonNull(this.counts.get(filterArg));
            }
        }
        final Registrar registrar = new Registrar();

        // Pass 1: populate the inner aggregate.
        Ord.forEach(aggCallList, (c, i) -> {
            if (c.distinctKeys == null) {
                // No distinct keys: evaluate the original call directly in the inner aggregate.
                RelBuilder.AggCall aggCall = b.aggregateCall(c.getParserPosition(), c.getAggregation(), b.fields(c.getArgList()));
                if (c.hasFilter()) {
                    aggCall = aggCall.filter(b.field(c.filterArg));
                }
                if (c.hasCollation()) {
                    aggCall = aggCall.sort(b.fields(c.getCollation()));
                }
                registrar.registerAgg(i, aggCall);
            } else {
                // Distinct keys: collapse each argument to one value per inner group.
                for (int inputIdx : c.getArgList()) {
                    registrar.register(inputIdx, c.filterArg);
                }
                if (AggregateExpandWithinDistinctRule.mustBeCounted(c)) {
                    registrar.registerCount(c.filterArg);
                }
            }
        });

        // With several grouping sets, a GROUPING column (registered under key -1) lets the
        // outer aggregate tell which grouping set produced each row.
        final int grouping = hasMultipleGroupSets ? registrar.registerAgg(-1, b.aggregateCall(SqlStdOperatorTable.GROUPING, b.fields(fullGroupList))) : -1;
        b.aggregate(b.groupKey(fullGroupSet, (Iterable<? extends ImmutableBitSet>)groupSets), aggCalls);

        // Pass 2: populate the outer aggregate, reusing the same list.
        aggCalls.clear();
        Ord.forEach(aggCallList, (c, i) -> {
            final List<RexNode> filters = new ArrayList<>();
            RexNode groupFilter = null;
            if (hasMultipleGroupSets) {
                // Only consume rows produced by the grouping set that belongs to this call.
                groupFilter = b.equals(b.field(grouping), b.literal(AggregateExpandDistinctAggregatesRule.groupValue(fullGroupList, AggregateExpandWithinDistinctRule.union(aggregate.getGroupSet(), c.distinctKeys))));
                filters.add(groupFilter);
            }
            RelBuilder.AggCall aggCall;
            if (c.distinctKeys == null) {
                // The value was already computed by the inner aggregate; MIN just carries it up.
                aggCall = b.aggregateCall(c.getParserPosition(), SqlStdOperatorTable.MIN, b.field(registrar.getAgg(i)));
            } else {
                aggCall = b.aggregateCall(c.getParserPosition(), c.getAggregation(), b.fields(registrar.fields(c.getArgList(), c.filterArg)));
                if (AggregateExpandWithinDistinctRule.mustBeCounted(c)) {
                    // Skip inner groups in which no row passed the original filter.
                    filters.add(b.greaterThan(b.field(registrar.getCount(c.filterArg)), b.literal(0)));
                }
                if (config.throwIfNotUnique()) {
                    for (int j : c.getArgList()) {
                        // MIN sits at the registered ordinal and MAX right after it; if they
                        // differ, the argument had more than one value within the group.
                        RexNode isUniqueCondition = b.isNotDistinctFrom(b.field(registrar.field(j, c.filterArg)), b.field(registrar.field(j, c.filterArg) + 1));
                        if (groupFilter != null) {
                            // Rows belonging to other grouping sets must not trigger the check.
                            isUniqueCondition = b.or(b.not(groupFilter), isUniqueCondition);
                        }
                        String message = "more than one distinct value in agg UNIQUE_VALUE";
                        filters.add(b.call((SqlOperator)SqlInternalOperators.THROW_UNLESS, isUniqueCondition, b.literal(message)));
                    }
                }
            }
            if (!filters.isEmpty()) {
                aggCall = aggCall.filter(b.and(filters));
            }
            aggCalls.add(aggCall);
        });
        b.aggregate(b.groupKey(AggregateExpandDistinctAggregatesRule.remap(fullGroupSet, aggregate.getGroupSet()), (Iterable<? extends ImmutableBitSet>)AggregateExpandDistinctAggregatesRule.remap(fullGroupSet, aggregate.getGroupSets())), aggCalls);
        b.convert(aggregate.getRowType(), false);
        call.transformTo(b.build());
    }

    /**
     * Returns whether the rewrite needs a filtered COUNT column for {@code aggCall}.
     *
     * <p>Currently this is the case exactly when the call has a filter.
     *
     * @param aggCall the aggregate call to inspect
     * @return the result of {@code aggCall.hasFilter()}
     */
    private static boolean mustBeCounted(AggregateCall aggCall) {
        return aggCall.hasFilter();
    }

    /**
     * Converts a DISTINCT aggregate call into a non-distinct call that carries its
     * original arguments as distinct keys. Calls that are not DISTINCT are returned as is.
     *
     * <p>An argument is dropped from the new argument list only when the function kind is
     * {@link SqlKind#COUNT}, the call has no filter, and {@code isNullable} reports the
     * argument as not nullable. The distinct keys are always built from the full original
     * argument list.
     *
     * @param aggregateCall the call to convert
     * @param isNullable tells whether the input field with a given ordinal is nullable
     * @return the converted call, or {@code aggregateCall} itself if it is not DISTINCT
     */
    private static AggregateCall unDistinct(AggregateCall aggregateCall, IntPredicate isNullable) {
        if (!aggregateCall.isDistinct()) {
            return aggregateCall;
        }
        final List<Integer> retainedArgs = new ArrayList<>();
        for (Integer arg : aggregateCall.getArgList()) {
            final boolean droppable = aggregateCall.getAggregation().getKind() == SqlKind.COUNT
                && !aggregateCall.hasFilter()
                && !isNullable.test(arg);
            if (!droppable) {
                retainedArgs.add(arg);
            }
        }
        final AggregateCall nonDistinct = aggregateCall.withDistinct(false);
        final ImmutableBitSet keys = ImmutableBitSet.of(aggregateCall.getArgList());
        return nonDistinct.withDistinctKeys(keys).withArgList(retainedArgs);
    }

    /**
     * Null-tolerant union of two bit sets.
     *
     * @param s0 the first set; never null
     * @param s1 the second set, or null to mean "nothing to add"
     * @return {@code s0} itself when {@code s1} is null, otherwise {@code s0.union(s1)}
     */
    private static ImmutableBitSet union(ImmutableBitSet s0, @Nullable ImmutableBitSet s1) {
        if (s1 != null) {
            return s0.union(s1);
        }
        return s0;
    }

    /** Configuration for {@link AggregateExpandWithinDistinctRule}. */
    @Value.Immutable
    public interface Config extends RelRule.Config {
        /**
         * Default configuration: matches a {@link LogicalAggregate} with any inputs that
         * satisfies {@code hasWithinDistinct}.
         */
        Config DEFAULT = ImmutableAggregateExpandWithinDistinctRule.Config.of()
            .withOperandSupplier(b -> b.operand(LogicalAggregate.class)
                .predicate(AggregateExpandWithinDistinctRule::hasWithinDistinct)
                .anyInputs());

        /** Creates a rule that uses this configuration. */
        @Override
        default AggregateExpandWithinDistinctRule toRule() {
            return new AggregateExpandWithinDistinctRule(this);
        }

        /**
         * Whether the rewritten plan should check that each argument of an aggregate call
         * with distinct keys has a single value per group. When enabled, the rule emits a
         * MAX alongside each MIN and a {@code THROW_UNLESS} filter comparing the two.
         * Defaults to {@code true}.
         */
        @Value.Default
        default boolean throwIfNotUnique() {
            return true;
        }

        /** Returns a configuration with {@link #throwIfNotUnique()} set to the given value. */
        Config withThrowIfNotUnique(boolean throwIfNotUnique);
    }
}

