Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ private TypeCheckResult checkInputDataTypesWithExpectTypes(
DataType expected = expectedTypes.get(i);
if (!checkInputDataTypesWithExpectType(input.getDataType(), expected)) {
errorMessages.add(String.format("argument %d requires %s type, however '%s' is of %s type",
i + 1, expected.simpleString(), input.toSql(), input.getDataType().simpleString()));
i + 1, expected, input.toSql(), input.getDataType()));
}
}
if (!errorMessages.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,41 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.RewriteWhenAnalyze;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.coercion.AnyDataType;

import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* ScalarFunction 'array_first'.
*/
public class ArrayFirst extends ElementAt
implements HighOrderFunction {
public class ArrayFirst extends ScalarFunction
implements HighOrderFunction, PropagateNullLiteral, PropagateNullable, RewriteWhenAnalyze {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0).args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX),
ArrayType.of(BooleanType.INSTANCE))
);

/**
* constructor with arguments.
* array_first(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), 1)
*/
public ArrayFirst(Expression arg) {
super(new ArrayFilter(arg), new BigIntLiteral(1));
super("array_first", arg instanceof Lambda ? arg.child(1).child(0) : arg, new ArrayMap(arg));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));
}
}

/** constructor for withChildren and reuse signature */
Expand All @@ -44,7 +61,7 @@ private ArrayFirst(ScalarFunctionParams functionParams) {
}

@Override
public ElementAt withChildren(List<Expression> children) {
public ArrayFirst withChildren(List<Expression> children) {
return new ArrayFirst(getFunctionParams(children));
}

Expand All @@ -57,4 +74,10 @@ public List<FunctionSignature> getImplSignature() {
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitArrayFirst(this, context);
}

// array_first(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), 1)
@Override
public Expression rewriteWhenAnalyze() {
return new ElementAt(new ArrayFilter(getArgument(0), getArgument(1)), new BigIntLiteral(1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,41 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.RewriteWhenAnalyze;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.coercion.AnyDataType;

import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* ScalarFunction 'array_last'.
*/
public class ArrayLast extends ElementAt
implements HighOrderFunction {
public class ArrayLast extends ScalarFunction
implements HighOrderFunction, PropagateNullLiteral, PropagateNullable, RewriteWhenAnalyze {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0).args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX),
ArrayType.of(BooleanType.INSTANCE))
);

/**
* constructor with arguments.
* array_last(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), -1)
*/
public ArrayLast(Expression arg) {
super(new ArrayFilter(arg), new BigIntLiteral(-1));
super("array_last", arg instanceof Lambda ? arg.child(1).child(0) : arg, new ArrayMap(arg));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));
}
}

/** constructor for withChildren and reuse signature */
Expand All @@ -49,12 +66,18 @@ public List<FunctionSignature> getImplSignature() {
}

@Override
public ElementAt withChildren(List<Expression> children) {
public ArrayLast withChildren(List<Expression> children) {
return new ArrayLast(getFunctionParams(children));
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitArrayLast(this, context);
}

// array_last(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), -1)
@Override
public Expression rewriteWhenAnalyze() {
return new ElementAt(new ArrayFilter(getArgument(0), getArgument(1)), new BigIntLiteral(-1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ default R visitArrayFilter(ArrayFilter arrayFilter, C context) {
}

default R visitArrayFirst(ArrayFirst arrayFirst, C context) {
return visitElementAt(arrayFirst, context);
return visitScalarFunction(arrayFirst, context);
}

default R visitArrayFirstIndex(ArrayFirstIndex arrayFirstIndex, C context) {
Expand All @@ -715,7 +715,7 @@ default R visitArrayJoin(ArrayJoin arrayJoin, C context) {
}

default R visitArrayLast(ArrayLast arrayLast, C context) {
return visitElementAt(arrayLast, context);
return visitScalarFunction(arrayLast, context);
}

default R visitArrayLastIndex(ArrayLastIndex arrayLastIndex, C context) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BooleanType;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class ArrayFirstLastTest extends ExpressionRewriteTestHelper {

@Test
public void testArrayFirstLambdaResultCanCastToBoolean() {
assertLambdaResultCastToBoolean("array_first(x -> x, [0, 1])", 1);
}

@Test
public void testArrayLastLambdaResultCanCastToBoolean() {
assertLambdaResultCastToBoolean("array_last(x -> x, [0, 1])", -1);
}

private void assertLambdaResultCastToBoolean(String sql, long expectedIndex) {
Expression analyzed = typeCoercion(PARSER.parseExpression(sql));
Assertions.assertTrue(analyzed instanceof ElementAt);

ElementAt elementAt = (ElementAt) analyzed;
Assertions.assertTrue(elementAt.left() instanceof ArrayFilter);
Assertions.assertEquals(expectedIndex, ((BigIntLiteral) elementAt.right()).getValue());

ArrayFilter arrayFilter = (ArrayFilter) elementAt.left();
Expression filterResult = arrayFilter.child(1);
Assertions.assertTrue(filterResult instanceof Cast);
Assertions.assertEquals(ArrayType.of(BooleanType.INSTANCE), filterResult.getDataType());
Assertions.assertTrue(((Cast) filterResult).child() instanceof ArrayMap);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ b
-- !select_05 --
10.2

-- !select_lambda_result_cast --
1

-- !select_06 --
0 [2] ["123", "124", "125"]
1 [1, 2, 3, 4, 5] ["234", "124", "125"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ c
-- !select_05 --
5.3

-- !select_lambda_result_cast --
2

-- !select_06 --
0 [2] ["123", "124", "125"]
1 [1, 2, 3, 4, 5] ["234", "124", "125"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ suite("test_array_first") {
qt_select_03 " select array_first(x -> x>=5,[1,2,3,4,5]);"
qt_select_04 " select array_first(x -> x > 'abc', ['a','b','c']);"
qt_select_05 " select array_first(x -> x > 5.2 , [10.2, 5.3, 4]);"
qt_select_lambda_result_cast " select array_first(x -> x, [0, 1, 2]);"

qt_select_06 "select * from ${tableName} order by id;"

qt_select_07 " select array_first(x->x>3,c_array1), array_first(x-> x>'124',c_array2) from test_array_first order by id;"
sql "DROP TABLE IF EXISTS ${tableName}"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ suite("test_array_last") {
qt_select_03 " select array_last(x -> x>=5,[1,2,3,4,5]);"
qt_select_04 " select array_last(x -> x > 'abc', ['a','b','c']);"
qt_select_05 " select array_last(x -> x > 5.2 , [10.2, 5.3, 4]);"
qt_select_lambda_result_cast " select array_last(x -> x, [0, 1, 2]);"

qt_select_06 "select * from ${tableName} order by id;"

qt_select_07 " select array_last(x->x>3,c_array1), array_last(x-> x>'124',c_array2) from test_array_last order by id;"
sql "DROP TABLE IF EXISTS ${tableName}"
}
}
Loading