diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index ca4f8d6c5a4b04..447de055bc5a78 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -249,7 +249,7 @@ private TypeCheckResult checkInputDataTypesWithExpectTypes( DataType expected = expectedTypes.get(i); if (!checkInputDataTypesWithExpectType(input.getDataType(), expected)) { errorMessages.add(String.format("argument %d requires %s type, however '%s' is of %s type", - i + 1, expected.simpleString(), input.toSql(), input.getDataType().simpleString())); + i + 1, expected, input.toSql(), input.getDataType())); } } if (!errorMessages.isEmpty()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java index 5410de371a793c..3d186812d42b6b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java @@ -18,24 +18,41 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.functions.RewriteWhenAnalyze; import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.BooleanType; +import org.apache.doris.nereids.types.coercion.AnyDataType; + +import com.google.common.collect.ImmutableList; import java.util.List; /** * ScalarFunction 'array_first'. */ -public class ArrayFirst extends ElementAt - implements HighOrderFunction { +public class ArrayFirst extends ScalarFunction + implements HighOrderFunction, PropagateNullLiteral, PropagateNullable, RewriteWhenAnalyze { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.retArgType(0).args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX), + ArrayType.of(BooleanType.INSTANCE)) + ); /** * constructor with arguments. - * array_first(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), 1) */ public ArrayFirst(Expression arg) { - super(new ArrayFilter(arg), new BigIntLiteral(1)); + super("array_first", arg instanceof Lambda ? arg.child(1).child(0) : arg, new ArrayMap(arg)); + if (!(arg instanceof Lambda)) { + throw new AnalysisException( + String.format("The 1st arg of %s must be lambda but is %s", getName(), arg)); + } } /** constructor for withChildren and reuse signature */ @@ -44,7 +61,7 @@ private ArrayFirst(ScalarFunctionParams functionParams) { } @Override - public ElementAt withChildren(List children) { + public ArrayFirst withChildren(List children) { return new ArrayFirst(getFunctionParams(children)); } @@ -57,4 +74,10 @@ public List getImplSignature() { public R accept(ExpressionVisitor visitor, C context) { return visitor.visitArrayFirst(this, context); } + + // array_first(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), 1) + @Override + public Expression rewriteWhenAnalyze() { + return new ElementAt(new ArrayFilter(getArgument(0), getArgument(1)), new BigIntLiteral(1)); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java index b9f5650156f083..45ab415b6ff1fd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java @@ -18,24 +18,41 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.functions.RewriteWhenAnalyze; import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.BooleanType; +import org.apache.doris.nereids.types.coercion.AnyDataType; + +import com.google.common.collect.ImmutableList; import java.util.List; /** * ScalarFunction 'array_last'. */ -public class ArrayLast extends ElementAt - implements HighOrderFunction { +public class ArrayLast extends ScalarFunction + implements HighOrderFunction, PropagateNullLiteral, PropagateNullable, RewriteWhenAnalyze { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.retArgType(0).args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX), + ArrayType.of(BooleanType.INSTANCE)) + ); /** * constructor with arguments. - * array_last(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), -1) */ public ArrayLast(Expression arg) { - super(new ArrayFilter(arg), new BigIntLiteral(-1)); + super("array_last", arg instanceof Lambda ? arg.child(1).child(0) : arg, new ArrayMap(arg)); + if (!(arg instanceof Lambda)) { + throw new AnalysisException( + String.format("The 1st arg of %s must be lambda but is %s", getName(), arg)); + } } /** constructor for withChildren and reuse signature */ @@ -49,7 +66,7 @@ public List getImplSignature() { } @Override - public ElementAt withChildren(List children) { + public ArrayLast withChildren(List children) { return new ArrayLast(getFunctionParams(children)); } @@ -57,4 +74,10 @@ public ElementAt withChildren(List children) { public R accept(ExpressionVisitor visitor, C context) { return visitor.visitArrayLast(this, context); } + + // array_last(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), -1) + @Override + public Expression rewriteWhenAnalyze() { + return new ElementAt(new ArrayFilter(getArgument(0), getArgument(1)), new BigIntLiteral(-1)); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java index 2edfc133152a4f..f22f384b1c1b0a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java @@ -699,7 +699,7 @@ default R visitArrayFilter(ArrayFilter arrayFilter, C context) { } default R visitArrayFirst(ArrayFirst arrayFirst, C context) { - return visitElementAt(arrayFirst, context); + return visitScalarFunction(arrayFirst, context); } default R visitArrayFirstIndex(ArrayFirstIndex arrayFirstIndex, C context) { @@ -715,7 +715,7 @@ default R visitArrayJoin(ArrayJoin arrayJoin, C context) { } default R visitArrayLast(ArrayLast arrayLast, C context) { - return visitElementAt(arrayLast, context); + return visitScalarFunction(arrayLast, context); } default R visitArrayLastIndex(ArrayLastIndex arrayLastIndex, C context) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirstLastTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirstLastTest.java new file mode 100644 index 00000000000000..1d7b285fdd61ff --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirstLastTest.java @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.BooleanType; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ArrayFirstLastTest extends ExpressionRewriteTestHelper { + + @Test + public void testArrayFirstLambdaResultCanCastToBoolean() { + assertLambdaResultCastToBoolean("array_first(x -> x, [0, 1])", 1); + } + + @Test + public void testArrayLastLambdaResultCanCastToBoolean() { + assertLambdaResultCastToBoolean("array_last(x -> x, [0, 1])", -1); + } + + private void assertLambdaResultCastToBoolean(String sql, long expectedIndex) { + Expression analyzed = typeCoercion(PARSER.parseExpression(sql)); + Assertions.assertTrue(analyzed instanceof ElementAt); + + ElementAt elementAt = (ElementAt) analyzed; + Assertions.assertTrue(elementAt.left() instanceof ArrayFilter); + Assertions.assertEquals(expectedIndex, ((BigIntLiteral) elementAt.right()).getValue()); + + ArrayFilter arrayFilter = (ArrayFilter) elementAt.left(); + Expression filterResult = arrayFilter.child(1); + Assertions.assertTrue(filterResult instanceof Cast); + Assertions.assertEquals(ArrayType.of(BooleanType.INSTANCE), filterResult.getDataType()); + Assertions.assertTrue(((Cast) filterResult).child() instanceof ArrayMap); + } +} diff --git a/regression-test/data/query_p0/sql_functions/array_functions/test_array_first.out b/regression-test/data/query_p0/sql_functions/array_functions/test_array_first.out index bc4f80a9576f03..24a5dc6f88e7f2 100644 --- a/regression-test/data/query_p0/sql_functions/array_functions/test_array_first.out +++ b/regression-test/data/query_p0/sql_functions/array_functions/test_array_first.out @@ -14,6 +14,9 @@ b -- !select_05 -- 10.2 +-- !select_lambda_result_cast -- +1 + -- !select_06 -- 0 [2] ["123", "124", "125"] 1 [1, 2, 3, 4, 5] ["234", "124", "125"] diff --git a/regression-test/data/query_p0/sql_functions/array_functions/test_array_last.out b/regression-test/data/query_p0/sql_functions/array_functions/test_array_last.out index b00916ba420d23..45c56736f0f8eb 100644 --- a/regression-test/data/query_p0/sql_functions/array_functions/test_array_last.out +++ b/regression-test/data/query_p0/sql_functions/array_functions/test_array_last.out @@ -14,6 +14,9 @@ c -- !select_05 -- 5.3 +-- !select_lambda_result_cast -- +2 + -- !select_06 -- 0 [2] ["123", "124", "125"] 1 [1, 2, 3, 4, 5] ["234", "124", "125"] diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_first.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_first.groovy index 2b4fc07860225c..ee11b7ab1085cf 100644 --- a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_first.groovy +++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_first.groovy @@ -45,9 +45,10 @@ suite("test_array_first") { qt_select_03 " select array_first(x -> x>=5,[1,2,3,4,5]);" qt_select_04 " select array_first(x -> x > 'abc', ['a','b','c']);" qt_select_05 " select array_first(x -> x > 5.2 , [10.2, 5.3, 4]);" + qt_select_lambda_result_cast " select array_first(x -> x, [0, 1, 2]);" qt_select_06 "select * from ${tableName} order by id;" qt_select_07 " select array_first(x->x>3,c_array1), array_first(x-> x>'124',c_array2) from test_array_first order by id;" sql "DROP TABLE IF EXISTS ${tableName}" -} \ No newline at end of file +} diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_last.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_last.groovy index fe5fdec9ffceb2..82df24c8eab748 100644 --- a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_last.groovy +++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_last.groovy @@ -45,9 +45,10 @@ suite("test_array_last") { qt_select_03 " select array_last(x -> x>=5,[1,2,3,4,5]);" qt_select_04 " select array_last(x -> x > 'abc', ['a','b','c']);" qt_select_05 " select array_last(x -> x > 5.2 , [10.2, 5.3, 4]);" + qt_select_lambda_result_cast " select array_last(x -> x, [0, 1, 2]);" qt_select_06 "select * from ${tableName} order by id;" qt_select_07 " select array_last(x->x>3,c_array1), array_last(x-> x>'124',c_array2) from test_array_last order by id;" sql "DROP TABLE IF EXISTS ${tableName}" -} \ No newline at end of file +}