diff --git a/.github/workflows/flink.yml b/.github/workflows/flink.yml index cf607f5f304..7c25cdabdb5 100644 --- a/.github/workflows/flink.yml +++ b/.github/workflows/flink.yml @@ -70,7 +70,7 @@ jobs: export fmt_SOURCE=BUNDLED export folly_SOURCE=BUNDLED git clone -b gluten-0530 https://github.com/bigo-sg/velox4j.git - cd velox4j && git reset --hard 115edf79d265a61c30d45dfcc6ce932ad92378ca + cd velox4j && git reset --hard 72ba7a565104e6da6eb34bb03aef2c8f39c4c0e2 git apply $GITHUB_WORKSPACE/gluten-flink/patches/fix-velox4j.patch $GITHUB_WORKSPACE/build/mvn clean install -DskipTests -Dgpg.skip -Dspotless.skip=true cd .. diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java index 53f36fcf67c..c1cd0bcfd14 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java @@ -26,12 +26,15 @@ import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; import io.github.zhztheplayer.velox4j.query.Query; import io.github.zhztheplayer.velox4j.query.SerialTask; +import io.github.zhztheplayer.velox4j.serde.Serde; import io.github.zhztheplayer.velox4j.session.Session; import io.github.zhztheplayer.velox4j.stateful.StatefulElement; import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; import io.github.zhztheplayer.velox4j.stateful.StatefulWatermark; import io.github.zhztheplayer.velox4j.type.RowType; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; @@ -53,6 +56,7 @@ public class GlutenSourceFunction extends RichParallelSourceFunction implements CheckpointedFunction { private static final Logger LOG = LoggerFactory.getLogger(GlutenSourceFunction.class); + private static final String SOURCE_SPLIT_STATE_NAME = "gluten-source-split-state"; private final StatefulPlanNode planNode; private final Map outputTypes; @@ -65,6 +69,8 @@ public class GlutenSourceFunction extends RichParallelSourceFunction private SerialTask task; private SourceTaskMetrics taskMetrics; private final Class outClass; + private transient ListState sourceSplitState; + private transient ConnectorSplit restoredSplit; public GlutenSourceFunction( StatefulPlanNode planNode, @@ -205,14 +211,25 @@ public void close() throws Exception { @Override public void snapshotState(FunctionSnapshotContext context) throws Exception { - // TODO: implement it this.task.snapshotState(0); + if (sourceSplitState != null) { + sourceSplitState.clear(); + for (String splitState : task.snapshotSourceState()) { + sourceSplitState.add(splitState); + } + } } @Override public void initializeState(FunctionInitializationContext context) throws Exception { + sourceSplitState = + context + .getOperatorStateStore() + .getListState(new ListStateDescriptor<>(SOURCE_SPLIT_STATE_NAME, String.class)); + if (context.isRestored()) { + restoredSplit = readRestoredSplit(); + } initSession(); - // TODO: implement it this.task.initializeState(0, null); } @@ -239,8 +256,22 @@ private void initSession() { VeloxQueryConfig.getConfig(getRuntimeContext()), VeloxConnectorConfig.getConfig(getRuntimeContext())); task = session.queryOps().execute(query); - task.addSplit(id, split); + task.addSplit(id, restoredSplit != null ? restoredSplit : split); task.noMoreSplits(id); taskMetrics = new SourceTaskMetrics(getRuntimeContext().getMetricGroup()); } + + private ConnectorSplit readRestoredSplit() throws Exception { + ConnectorSplit result = null; + for (String splitState : sourceSplitState.get()) { + if (splitState == null || splitState.isEmpty()) { + continue; + } + if (result != null) { + throw new IllegalStateException("Only one restored source split is supported."); + } + result = Serde.fromJson(splitState, ConnectorSplit.class); + } + return result; + } } diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java index 8d684fab52e..346e6abaa90 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java @@ -65,10 +65,13 @@ private interface VLTypeConverter { logicalType -> new io.github.zhztheplayer.velox4j.type.VarCharType()), Map.entry( CharType.class, logicalType -> new io.github.zhztheplayer.velox4j.type.VarCharType()), - // TODO: may need precision Map.entry( TimestampType.class, - logicalType -> new io.github.zhztheplayer.velox4j.type.TimestampType()), + logicalType -> { + TimestampType timestampType = (TimestampType) logicalType; + return new io.github.zhztheplayer.velox4j.type.TimestampType( + timestampType.getPrecision(), false); + }), Map.entry( DecimalType.class, logicalType -> { @@ -110,10 +113,13 @@ private interface VLTypeConverter { // Map the flink's `TimestampLTZ` type to velox `Timestamp` type. And the timezone would // be specified by using flink's table config `LOCAL_TIME_ZONE`, which would be passed to // velox's `session_timezone` config. - // TODO: may need precision Map.entry( LocalZonedTimestampType.class, - logicalType -> new io.github.zhztheplayer.velox4j.type.TimestampType()), + logicalType -> { + LocalZonedTimestampType timestampType = (LocalZonedTimestampType) logicalType; + return new io.github.zhztheplayer.velox4j.type.TimestampType( + timestampType.getPrecision(), true); + }), Map.entry( TinyIntType.class, logicalType -> new io.github.zhztheplayer.velox4j.type.TinyIntType()), diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java index c397e6fada7..eb682f0879e 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java @@ -30,7 +30,14 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.Float8Vector; import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampNanoTZVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TimeStampSecTZVector; +import org.apache.arrow.vector.TimeStampSecVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; @@ -63,7 +70,14 @@ private interface AccessorBuilder { Map.entry(StructVector.class, vector -> new StructVectorAccessor(vector)), Map.entry(ListVector.class, vector -> new ListVectorAccessor(vector)), Map.entry(DateDayVector.class, vector -> new DateDayVectorAccessor(vector)), - Map.entry(TimeStampMicroVector.class, vector -> new TimeStampMicroVectorAccessor(vector)), + Map.entry(TimeStampSecVector.class, vector -> new TimeStampVectorAccessor(vector)), + Map.entry(TimeStampSecTZVector.class, vector -> new TimeStampVectorAccessor(vector)), + Map.entry(TimeStampMilliVector.class, vector -> new TimeStampVectorAccessor(vector)), + Map.entry(TimeStampMilliTZVector.class, vector -> new TimeStampVectorAccessor(vector)), + Map.entry(TimeStampMicroVector.class, vector -> new TimeStampVectorAccessor(vector)), + Map.entry(TimeStampMicroTZVector.class, vector -> new TimeStampVectorAccessor(vector)), + Map.entry(TimeStampNanoVector.class, vector -> new TimeStampVectorAccessor(vector)), + Map.entry(TimeStampNanoTZVector.class, vector -> new TimeStampVectorAccessor(vector)), Map.entry(MapVector.class, vector -> new MapVectorAccessor(vector))); public static ArrowVectorAccessor create(FieldVector vector) { @@ -268,15 +282,38 @@ protected Object getImpl(int rowIndex) { } } -class TimeStampMicroVectorAccessor extends BaseArrowVectorAccessor { +class TimeStampVectorAccessor extends BaseArrowVectorAccessor { - public TimeStampMicroVectorAccessor(FieldVector vector) { + public TimeStampVectorAccessor(FieldVector vector) { super(vector); } @Override public Object getImpl(int rowIndex) { - long milliseconds = typedVector.get(rowIndex) / 1000; - return TimestampData.fromEpochMillis(milliseconds); + if (typedVector instanceof TimeStampSecVector) { + return TimestampData.fromEpochMillis(((TimeStampSecVector) typedVector).get(rowIndex) * 1000); + } else if (typedVector instanceof TimeStampSecTZVector) { + return TimestampData.fromEpochMillis( + ((TimeStampSecTZVector) typedVector).get(rowIndex) * 1000); + } else if (typedVector instanceof TimeStampMilliVector) { + return TimestampData.fromEpochMillis(((TimeStampMilliVector) typedVector).get(rowIndex)); + } else if (typedVector instanceof TimeStampMilliTZVector) { + return TimestampData.fromEpochMillis(((TimeStampMilliTZVector) typedVector).get(rowIndex)); + } else if (typedVector instanceof TimeStampMicroVector) { + return fromSubMillis(((TimeStampMicroVector) typedVector).get(rowIndex), 1000, 1000); + } else if (typedVector instanceof TimeStampMicroTZVector) { + return fromSubMillis(((TimeStampMicroTZVector) typedVector).get(rowIndex), 1000, 1000); + } else if (typedVector instanceof TimeStampNanoVector) { + return fromSubMillis(((TimeStampNanoVector) typedVector).get(rowIndex), 1000000, 1); + } else if (typedVector instanceof TimeStampNanoTZVector) { + return fromSubMillis(((TimeStampNanoTZVector) typedVector).get(rowIndex), 1000000, 1); + } + throw new IllegalStateException("Unexpected vector type: " + typedVector.getClass().getName()); + } + + private TimestampData fromSubMillis(long value, int unitsPerMillisecond, int nanosPerUnit) { + long milliseconds = Math.floorDiv(value, unitsPerMillisecond); + int nanoOfMillisecond = (int) Math.floorMod(value, unitsPerMillisecond) * nanosPerUnit; + return TimestampData.fromEpochMillis(milliseconds, nanoOfMillisecond); } } diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java index e6a511a4921..c144c43fde0 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java @@ -161,9 +161,14 @@ private interface ArrowTypeConverter { Map.entry(VarCharType.class, (dataType, timeZoneId) -> ArrowType.Utf8.INSTANCE), Map.entry( TimestampType.class, - (dataType, timeZoneId) -> - new ArrowType.Timestamp( - TimeUnit.MILLISECOND, timeZoneId == null ? "UTC" : timeZoneId)), + (dataType, timeZoneId) -> { + TimestampType timestampType = (TimestampType) dataType; + return new ArrowType.Timestamp( + toTimeUnit(timestampType.getPrecision()), + timestampType.isLocalZoned() + ? (timeZoneId == null ? "UTC" : timeZoneId) + : null); + }), Map.entry(DateType.class, (dataType, timeZoneId) -> new ArrowType.Date(DateUnit.DAY)), Map.entry( DecimalType.class, @@ -181,6 +186,15 @@ private static ArrowType toArrowType(Type dataType, String timeZoneId) { return converter.convert(dataType, timeZoneId); } + private static TimeUnit toTimeUnit(int precision) { + if (precision <= 3) { + return TimeUnit.MILLISECOND; + } else if (precision <= 6) { + return TimeUnit.MICROSECOND; + } + return TimeUnit.NANOSECOND; + } + private static Field toArrowField( String name, Type dataType, boolean nullable, String timeZoneId) { if (dataType instanceof ArrayType) { @@ -396,42 +410,62 @@ protected void setValue(int index, byte[] value) { } class TimestampVectorWriter extends BaseVectorWriter { - private final int precision = 3; // Millisecond precision + private final int precision; public TimestampVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) { super(vector); - // Verify that the vector is a timestamp vector (either TimeStampMilliVector or - // TimeStampMilliTZVector) - if (!(vector instanceof TimeStampMilliVector) && !(vector instanceof TimeStampMilliTZVector)) { + this.precision = ((TimestampType) fieldType).getPrecision(); + if (!(vector instanceof TimeStampMilliVector) + && !(vector instanceof TimeStampMilliTZVector) + && !(vector instanceof TimeStampMicroVector) + && !(vector instanceof TimeStampMicroTZVector) + && !(vector instanceof TimeStampNanoVector) + && !(vector instanceof TimeStampNanoTZVector)) { throw new IllegalArgumentException( - "Expected TimeStampMilliVector or TimeStampMilliTZVector, but got: " - + vector.getClass().getName()); + "Expected timestamp vector, but got: " + vector.getClass().getName()); } } @Override protected Long getValue(RowData rowData, int fieldIndex) { - return rowData.getTimestamp(fieldIndex, precision).getMillisecond(); + return toArrowValue(rowData.getTimestamp(fieldIndex, precision)); } @Override protected Long getValue(ArrayData arrayData, int index) { - return arrayData.getTimestamp(index, precision).getMillisecond(); + return toArrowValue(arrayData.getTimestamp(index, precision)); } @Override protected void setValue(int index, Long value) { - - // Both TimeStampMilliVector and TimeStampMilliTZVector support setSafe with long value if (this.typedVector instanceof TimeStampMilliVector) { ((TimeStampMilliVector) this.typedVector).setSafe(index, value); } else if (this.typedVector instanceof TimeStampMilliTZVector) { ((TimeStampMilliTZVector) this.typedVector).setSafe(index, value); + } else if (this.typedVector instanceof TimeStampMicroVector) { + ((TimeStampMicroVector) this.typedVector).setSafe(index, value); + } else if (this.typedVector instanceof TimeStampMicroTZVector) { + ((TimeStampMicroTZVector) this.typedVector).setSafe(index, value); + } else if (this.typedVector instanceof TimeStampNanoVector) { + ((TimeStampNanoVector) this.typedVector).setSafe(index, value); + } else if (this.typedVector instanceof TimeStampNanoTZVector) { + ((TimeStampNanoTZVector) this.typedVector).setSafe(index, value); } else { throw new IllegalStateException( "Unexpected vector type: " + this.typedVector.getClass().getName()); } } + + private Long toArrowValue(org.apache.flink.table.data.TimestampData value) { + long milliseconds = value.getMillisecond(); + int nanoOfMillisecond = value.getNanoOfMillisecond(); + if (precision <= 3) { + return milliseconds; + } else if (precision <= 6) { + return milliseconds * 1000 + nanoOfMillisecond / 1000; + } + return milliseconds * 1000000 + nanoOfMillisecond; + } } class DateDayVectorWriter extends BaseVectorWriter { diff --git a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java index ffc2c9f99f4..63deda9fcad 100644 --- a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java +++ b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java @@ -313,6 +313,34 @@ void testDateFormat() { "+I[2, 2025-02-28 12:12:12, 2024-02-28 20:12:12]")); } + @Test + void testTimestampMicros() { + List rows = + Arrays.asList( + Row.of(1, LocalDateTime.of(2024, 12, 31, 12, 12, 12, 123456000)), + Row.of(2, LocalDateTime.of(2025, 2, 28, 12, 12, 12, 654321000))); + createSimpleBoundedValuesTable("timestampMicrosTable", "a int, b Timestamp(6)", rows); + + String query = "select a, cast(b as string) from timestampMicrosTable"; + runAndCheck( + query, + Arrays.asList("+I[1, 2024-12-31 12:12:12.123456]", "+I[2, 2025-02-28 12:12:12.654321]")); + + rows = + Arrays.asList( + Row.of( + 1, LocalDateTime.of(2024, 12, 31, 12, 12, 12, 123456000).toInstant(ZoneOffset.UTC)), + Row.of( + 2, LocalDateTime.of(2025, 2, 28, 12, 12, 12, 654321000).toInstant(ZoneOffset.UTC))); + createSimpleBoundedValuesTable("timestampLtzMicrosTable", "a int, b Timestamp_LTZ(6)", rows); + + query = "select a, cast(b as string) from timestampLtzMicrosTable"; + tEnv().getConfig().setLocalTimeZone(ZoneId.of("Asia/Shanghai")); + runAndCheck( + query, + Arrays.asList("+I[1, 2024-12-31 20:12:12.123456]", "+I[2, 2025-02-28 20:12:12.654321]")); + } + @Test void testNotEqual() { List rows = diff --git a/gluten-flink/ut/src/test/java/org/apache/gluten/vectorized/ArrowTimestampVectorTest.java b/gluten-flink/ut/src/test/java/org/apache/gluten/vectorized/ArrowTimestampVectorTest.java new file mode 100644 index 00000000000..2d73cc3a132 --- /dev/null +++ b/gluten-flink/ut/src/test/java/org/apache/gluten/vectorized/ArrowTimestampVectorTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.vectorized; + +import io.github.zhztheplayer.velox4j.type.TimestampType; + +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.TimestampData; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.junit.jupiter.api.Test; + +import java.time.LocalDateTime; + +import static org.assertj.core.api.Assertions.assertThat; + +class ArrowTimestampVectorTest { + @Test + void preservesTimestampMicros() { + assertPreservesTimestampMicros(new TimestampType(6, false)); + } + + @Test + void preservesTimestampLtzMicros() { + assertPreservesTimestampMicros(new TimestampType(6, true)); + } + + private void assertPreservesTimestampMicros(TimestampType timestampType) { + TimestampData expected = + TimestampData.fromLocalDateTime(LocalDateTime.of(2026, 6, 22, 11, 12, 13, 123456000)); + + try (BufferAllocator allocator = new RootAllocator()) { + ArrowVectorWriter writer = ArrowVectorWriter.create("ts", timestampType, allocator); + try (FieldVector vector = writer.getVector()) { + writer.write(0, GenericRowData.of(expected)); + writer.finish(); + + TimestampData actual = (TimestampData) ArrowVectorAccessor.create(vector).get(0); + + assertThat(actual.getMillisecond()).isEqualTo(expected.getMillisecond()); + assertThat(actual.getNanoOfMillisecond()).isEqualTo(expected.getNanoOfMillisecond()); + } + } + } +}