diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java index d90c9a8eb8270b..81b0830be66c81 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java @@ -156,15 +156,15 @@ default R visitBitmapUnionInt(BitmapUnionInt bitmapUnionInt, C context) { } default R visitBoolAnd(BoolAnd boolAnd, C context) { - return visitAggregateFunction(boolAnd, context); + return visitNullableAggregateFunction(boolAnd, context); } default R visitBoolOr(BoolOr boolOr, C context) { - return visitAggregateFunction(boolOr, context); + return visitNullableAggregateFunction(boolOr, context); } default R visitBoolXor(BoolXor boolXor, C context) { - return visitAggregateFunction(boolXor, context); + return visitNullableAggregateFunction(boolXor, context); } default R visitCollectList(CollectList collectList, C context) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitorTest.java new file mode 100644 index 00000000000000..d77f1007a808c9 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitorTest.java @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.visitor; + +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +class AggregateFunctionVisitorTest { + private static final String AGG_PACKAGE = "org.apache.doris.nereids.trees.expressions.functions.agg."; + private static final String COMBINATOR_PACKAGE = "org.apache.doris.nereids.trees.expressions.functions.combinator."; + + @Test + void testNullableAggregateFunctionsVisitNullableDefault() throws Exception { + AggregateFunctionVisitor visitor = new AggregateFunctionVisitor() { + @Override + public String visitAggregateFunction(AggregateFunction function, Void context) { + return "aggregate"; + } + + @Override + public String visitNullableAggregateFunction(NullableAggregateFunction nullableAggregateFunction, + Void context) { + return "nullable"; + } + }; + + for (Class functionClass : nullableAggregateFunctionClasses()) { + List visitorMethods = Arrays.stream(AggregateFunctionVisitor.class.getMethods()) + .filter(method -> method.getParameterCount() == 2) + .filter(method -> method.getParameterTypes()[0].equals(functionClass)) + .collect(Collectors.toList()); + Assertions.assertEquals(1, visitorMethods.size(), functionClass.getName()); + Assertions.assertEquals("nullable", visitorMethods.get(0).invoke(visitor, null, null), + functionClass.getName()); + } + } + + private static List> nullableAggregateFunctionClasses() throws ClassNotFoundException { + return Arrays.asList( + aggregateClass("AIAgg"), + aggregateClass("AnyValue"), + aggregateClass("Avg"), + aggregateClass("AvgWeighted"), + aggregateClass("BoolAnd"), + aggregateClass("BoolOr"), + aggregateClass("BoolXor"), + aggregateClass("Corr"), + aggregateClass("CorrWelford"), + aggregateClass("Covar"), + aggregateClass("CovarSamp"), + aggregateClass("ExponentialMovingAverage"), + aggregateClass("GroupBitAnd"), + aggregateClass("GroupBitOr"), + aggregateClass("GroupBitXor"), + aggregateClass("GroupBitmapXor"), + aggregateClass("GroupConcat"), + aggregateClass("Max"), + aggregateClass("MaxBy"), + aggregateClass("Median"), + aggregateClass("Min"), + aggregateClass("MinBy"), + aggregateClass("MultiDistinctGroupConcat"), + aggregateClass("MultiDistinctSum"), + aggregateClass("Percentile"), + aggregateClass("PercentileApprox"), + aggregateClass("PercentileApproxWeighted"), + aggregateClass("PercentileReservoir"), + aggregateClass("Retention"), + aggregateClass("Sem"), + aggregateClass("SequenceMatch"), + aggregateClass("Stddev"), + aggregateClass("StddevSamp"), + aggregateClass("Sum"), + aggregateClass("TopN"), + aggregateClass("TopNArray"), + aggregateClass("TopNWeighted"), + aggregateClass("Variance"), + aggregateClass("VarianceSamp"), + aggregateClass("WindowFunnel"), + aggregateClass("WindowFunnelV2"), + combinatorClass("ForEachCombinator")); + } + + private static Class aggregateClass(String name) throws ClassNotFoundException { + return Class.forName(AGG_PACKAGE + name); + } + + private static Class combinatorClass(String name) throws ClassNotFoundException { + return Class.forName(COMBINATOR_PACKAGE + name); + } +}