Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/flink.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jobs:
export fmt_SOURCE=BUNDLED
export folly_SOURCE=BUNDLED
git clone -b gluten-0530 https://github.com/bigo-sg/velox4j.git
cd velox4j && git reset --hard 115edf79d265a61c30d45dfcc6ce932ad92378ca
cd velox4j && git reset --hard 72ba7a565104e6da6eb34bb03aef2c8f39c4c0e2
git apply $GITHUB_WORKSPACE/gluten-flink/patches/fix-velox4j.patch
$GITHUB_WORKSPACE/build/mvn clean install -DskipTests -Dgpg.skip -Dspotless.skip=true
cd ..
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@
import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode;
import io.github.zhztheplayer.velox4j.query.Query;
import io.github.zhztheplayer.velox4j.query.SerialTask;
import io.github.zhztheplayer.velox4j.serde.Serde;
import io.github.zhztheplayer.velox4j.session.Session;
import io.github.zhztheplayer.velox4j.stateful.StatefulElement;
import io.github.zhztheplayer.velox4j.stateful.StatefulRecord;
import io.github.zhztheplayer.velox4j.stateful.StatefulWatermark;
import io.github.zhztheplayer.velox4j.type.RowType;

import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
Expand All @@ -53,6 +56,7 @@
public class GlutenSourceFunction<OUT> extends RichParallelSourceFunction<OUT>
implements CheckpointedFunction {
private static final Logger LOG = LoggerFactory.getLogger(GlutenSourceFunction.class);
private static final String SOURCE_SPLIT_STATE_NAME = "gluten-source-split-state";

private final StatefulPlanNode planNode;
private final Map<String, RowType> outputTypes;
Expand All @@ -65,6 +69,8 @@ public class GlutenSourceFunction<OUT> extends RichParallelSourceFunction<OUT>
private SerialTask task;
private SourceTaskMetrics taskMetrics;
private final Class<OUT> outClass;
private transient ListState<String> sourceSplitState;
private transient ConnectorSplit restoredSplit;

public GlutenSourceFunction(
StatefulPlanNode planNode,
Expand Down Expand Up @@ -205,14 +211,25 @@ public void close() throws Exception {

@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
  // Snapshot the native task first, then mirror its source state into the Flink list state.
  // NOTE(review): the argument is the literal 0 rather than a value read from `context`
  // (presumably a checkpoint id) — confirm that is intended.
  this.task.snapshotState(0);
  if (sourceSplitState == null) {
    // The list state has not been registered, so there is nowhere to persist split state.
    return;
  }
  // Replace, rather than append to, whatever the previous snapshot stored.
  sourceSplitState.clear();
  for (String serializedSplit : task.snapshotSourceState()) {
    sourceSplitState.add(serializedSplit);
  }
}

@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
  // Register the operator list state that carries the serialized source split.
  ListStateDescriptor<String> splitStateDescriptor =
      new ListStateDescriptor<>(SOURCE_SPLIT_STATE_NAME, String.class);
  sourceSplitState = context.getOperatorStateStore().getListState(splitStateDescriptor);

  // The restored split must be read before initSession(), which hands it to the task.
  if (context.isRestored()) {
    restoredSplit = readRestoredSplit();
  }
  initSession();
  // NOTE(review): the task is initialized with the constants (0, null); confirm whether real
  // restore arguments should be passed here.
  this.task.initializeState(0, null);
}

Expand All @@ -239,8 +256,22 @@ private void initSession() {
VeloxQueryConfig.getConfig(getRuntimeContext()),
VeloxConnectorConfig.getConfig(getRuntimeContext()));
task = session.queryOps().execute(query);
task.addSplit(id, split);
task.addSplit(id, restoredSplit != null ? restoredSplit : split);
task.noMoreSplits(id);
taskMetrics = new SourceTaskMetrics(getRuntimeContext().getMetricGroup());
}

/**
 * Reads the single source split stored in {@code sourceSplitState}.
 *
 * <p>Null and empty entries are ignored.
 *
 * @return the deserialized split, or {@code null} when the state holds no non-empty entry
 * @throws IllegalStateException if more than one non-empty entry is present
 */
private ConnectorSplit readRestoredSplit() throws Exception {
  ConnectorSplit restored = null;
  for (String serializedSplit : sourceSplitState.get()) {
    boolean hasPayload = serializedSplit != null && !serializedSplit.isEmpty();
    if (hasPayload) {
      if (restored != null) {
        throw new IllegalStateException("Only one restored source split is supported.");
      }
      restored = Serde.fromJson(serializedSplit, ConnectorSplit.class);
    }
  }
  return restored;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,13 @@ private interface VLTypeConverter {
logicalType -> new io.github.zhztheplayer.velox4j.type.VarCharType()),
Map.entry(
CharType.class, logicalType -> new io.github.zhztheplayer.velox4j.type.VarCharType()),
// TODO: may need precision
Map.entry(
TimestampType.class,
logicalType -> new io.github.zhztheplayer.velox4j.type.TimestampType()),
logicalType -> {
TimestampType timestampType = (TimestampType) logicalType;
return new io.github.zhztheplayer.velox4j.type.TimestampType(
timestampType.getPrecision(), false);
}),
Map.entry(
DecimalType.class,
logicalType -> {
Expand Down Expand Up @@ -110,10 +113,13 @@ private interface VLTypeConverter {
// Map the flink's `TimestampLTZ` type to velox `Timestamp` type. And the timezone would
// be specified by using flink's table config `LOCAL_TIME_ZONE`, which would be passed to
// velox's `session_timezone` config.
// TODO: may need precision
Map.entry(
LocalZonedTimestampType.class,
logicalType -> new io.github.zhztheplayer.velox4j.type.TimestampType()),
logicalType -> {
LocalZonedTimestampType timestampType = (LocalZonedTimestampType) logicalType;
return new io.github.zhztheplayer.velox4j.type.TimestampType(
timestampType.getPrecision(), true);
}),
Map.entry(
TinyIntType.class,
logicalType -> new io.github.zhztheplayer.velox4j.type.TinyIntType()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.TimeStampMicroTZVector;
import org.apache.arrow.vector.TimeStampMicroVector;
import org.apache.arrow.vector.TimeStampMilliTZVector;
import org.apache.arrow.vector.TimeStampMilliVector;
import org.apache.arrow.vector.TimeStampNanoTZVector;
import org.apache.arrow.vector.TimeStampNanoVector;
import org.apache.arrow.vector.TimeStampSecTZVector;
import org.apache.arrow.vector.TimeStampSecVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
Expand Down Expand Up @@ -63,7 +70,14 @@ private interface AccessorBuilder {
Map.entry(StructVector.class, vector -> new StructVectorAccessor(vector)),
Map.entry(ListVector.class, vector -> new ListVectorAccessor(vector)),
Map.entry(DateDayVector.class, vector -> new DateDayVectorAccessor(vector)),
Map.entry(TimeStampMicroVector.class, vector -> new TimeStampMicroVectorAccessor(vector)),
Map.entry(TimeStampSecVector.class, vector -> new TimeStampVectorAccessor(vector)),
Map.entry(TimeStampSecTZVector.class, vector -> new TimeStampVectorAccessor(vector)),
Map.entry(TimeStampMilliVector.class, vector -> new TimeStampVectorAccessor(vector)),
Map.entry(TimeStampMilliTZVector.class, vector -> new TimeStampVectorAccessor(vector)),
Map.entry(TimeStampMicroVector.class, vector -> new TimeStampVectorAccessor(vector)),
Map.entry(TimeStampMicroTZVector.class, vector -> new TimeStampVectorAccessor(vector)),
Map.entry(TimeStampNanoVector.class, vector -> new TimeStampVectorAccessor(vector)),
Map.entry(TimeStampNanoTZVector.class, vector -> new TimeStampVectorAccessor(vector)),
Map.entry(MapVector.class, vector -> new MapVectorAccessor(vector)));

public static ArrowVectorAccessor create(FieldVector vector) {
Expand Down Expand Up @@ -268,15 +282,38 @@ protected Object getImpl(int rowIndex) {
}
}

class TimeStampMicroVectorAccessor extends BaseArrowVectorAccessor<TimeStampMicroVector> {
class TimeStampVectorAccessor extends BaseArrowVectorAccessor<FieldVector> {

public TimeStampMicroVectorAccessor(FieldVector vector) {
public TimeStampVectorAccessor(FieldVector vector) {
super(vector);
}

@Override
public Object getImpl(int rowIndex) {
long milliseconds = typedVector.get(rowIndex) / 1000;
return TimestampData.fromEpochMillis(milliseconds);
if (typedVector instanceof TimeStampSecVector) {
return TimestampData.fromEpochMillis(((TimeStampSecVector) typedVector).get(rowIndex) * 1000);
} else if (typedVector instanceof TimeStampSecTZVector) {
return TimestampData.fromEpochMillis(
((TimeStampSecTZVector) typedVector).get(rowIndex) * 1000);
} else if (typedVector instanceof TimeStampMilliVector) {
return TimestampData.fromEpochMillis(((TimeStampMilliVector) typedVector).get(rowIndex));
} else if (typedVector instanceof TimeStampMilliTZVector) {
return TimestampData.fromEpochMillis(((TimeStampMilliTZVector) typedVector).get(rowIndex));
} else if (typedVector instanceof TimeStampMicroVector) {
return fromSubMillis(((TimeStampMicroVector) typedVector).get(rowIndex), 1000, 1000);
} else if (typedVector instanceof TimeStampMicroTZVector) {
return fromSubMillis(((TimeStampMicroTZVector) typedVector).get(rowIndex), 1000, 1000);
} else if (typedVector instanceof TimeStampNanoVector) {
return fromSubMillis(((TimeStampNanoVector) typedVector).get(rowIndex), 1000000, 1);
} else if (typedVector instanceof TimeStampNanoTZVector) {
return fromSubMillis(((TimeStampNanoTZVector) typedVector).get(rowIndex), 1000000, 1);
}
throw new IllegalStateException("Unexpected vector type: " + typedVector.getClass().getName());
}

/**
 * Converts an epoch value expressed in a unit finer than milliseconds into a
 * {@link TimestampData}.
 *
 * <p>Floor division and floor modulus are used so that values before the epoch (negative
 * {@code value}) still produce a non-negative sub-millisecond remainder.
 *
 * @param value epoch value in the source unit
 * @param unitsPerMillisecond how many source units make up one millisecond
 * @param nanosPerUnit how many nanoseconds one source unit represents
 */
private TimestampData fromSubMillis(long value, int unitsPerMillisecond, int nanosPerUnit) {
  long wholeMillis = Math.floorDiv(value, unitsPerMillisecond);
  int unitsIntoMilli = (int) Math.floorMod(value, unitsPerMillisecond);
  return TimestampData.fromEpochMillis(wholeMillis, unitsIntoMilli * nanosPerUnit);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,14 @@ private interface ArrowTypeConverter {
Map.entry(VarCharType.class, (dataType, timeZoneId) -> ArrowType.Utf8.INSTANCE),
Map.entry(
TimestampType.class,
(dataType, timeZoneId) ->
new ArrowType.Timestamp(
TimeUnit.MILLISECOND, timeZoneId == null ? "UTC" : timeZoneId)),
(dataType, timeZoneId) -> {
TimestampType timestampType = (TimestampType) dataType;
return new ArrowType.Timestamp(
toTimeUnit(timestampType.getPrecision()),
timestampType.isLocalZoned()
? (timeZoneId == null ? "UTC" : timeZoneId)
: null);
}),
Map.entry(DateType.class, (dataType, timeZoneId) -> new ArrowType.Date(DateUnit.DAY)),
Map.entry(
DecimalType.class,
Expand All @@ -181,6 +186,15 @@ private static ArrowType toArrowType(Type dataType, String timeZoneId) {
return converter.convert(dataType, timeZoneId);
}

/**
 * Picks the Arrow time unit for a timestamp precision: up to 3 digits maps to milliseconds,
 * 4 to 6 digits to microseconds, and anything larger to nanoseconds.
 */
private static TimeUnit toTimeUnit(int precision) {
  if (precision > 6) {
    return TimeUnit.NANOSECOND;
  }
  return precision > 3 ? TimeUnit.MICROSECOND : TimeUnit.MILLISECOND;
}

private static Field toArrowField(
String name, Type dataType, boolean nullable, String timeZoneId) {
if (dataType instanceof ArrayType) {
Expand Down Expand Up @@ -396,42 +410,62 @@ protected void setValue(int index, byte[] value) {
}

class TimestampVectorWriter extends BaseVectorWriter<FieldVector, Long> {
private final int precision = 3; // Millisecond precision
private final int precision;

public TimestampVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
super(vector);
// Verify that the vector is a timestamp vector (either TimeStampMilliVector or
// TimeStampMilliTZVector)
if (!(vector instanceof TimeStampMilliVector) && !(vector instanceof TimeStampMilliTZVector)) {
this.precision = ((TimestampType) fieldType).getPrecision();
if (!(vector instanceof TimeStampMilliVector)
&& !(vector instanceof TimeStampMilliTZVector)
&& !(vector instanceof TimeStampMicroVector)
&& !(vector instanceof TimeStampMicroTZVector)
&& !(vector instanceof TimeStampNanoVector)
&& !(vector instanceof TimeStampNanoTZVector)) {
throw new IllegalArgumentException(
"Expected TimeStampMilliVector or TimeStampMilliTZVector, but got: "
+ vector.getClass().getName());
"Expected timestamp vector, but got: " + vector.getClass().getName());
}
}

@Override
protected Long getValue(RowData rowData, int fieldIndex) {
return rowData.getTimestamp(fieldIndex, precision).getMillisecond();
return toArrowValue(rowData.getTimestamp(fieldIndex, precision));
}

@Override
protected Long getValue(ArrayData arrayData, int index) {
return arrayData.getTimestamp(index, precision).getMillisecond();
return toArrowValue(arrayData.getTimestamp(index, precision));
}

@Override
protected void setValue(int index, Long value) {
  // Dispatch on the concrete Arrow vector class; each branch writes the raw long and returns.
  if (this.typedVector instanceof TimeStampMilliVector) {
    ((TimeStampMilliVector) this.typedVector).setSafe(index, value);
    return;
  }
  if (this.typedVector instanceof TimeStampMilliTZVector) {
    ((TimeStampMilliTZVector) this.typedVector).setSafe(index, value);
    return;
  }
  if (this.typedVector instanceof TimeStampMicroVector) {
    ((TimeStampMicroVector) this.typedVector).setSafe(index, value);
    return;
  }
  if (this.typedVector instanceof TimeStampMicroTZVector) {
    ((TimeStampMicroTZVector) this.typedVector).setSafe(index, value);
    return;
  }
  if (this.typedVector instanceof TimeStampNanoVector) {
    ((TimeStampNanoVector) this.typedVector).setSafe(index, value);
    return;
  }
  if (this.typedVector instanceof TimeStampNanoTZVector) {
    ((TimeStampNanoTZVector) this.typedVector).setSafe(index, value);
    return;
  }
  throw new IllegalStateException(
      "Unexpected vector type: " + this.typedVector.getClass().getName());
}

/**
 * Converts a Flink timestamp into the raw epoch value stored in the Arrow vector, using the
 * unit implied by {@code precision}: milliseconds for precision up to 3, microseconds for
 * precision up to 6, nanoseconds otherwise.
 *
 * <p>The scaling uses exact arithmetic. A timestamp whose micro- or nanosecond epoch value
 * does not fit in a {@code long} raises {@link ArithmeticException} instead of silently
 * wrapping around and writing a corrupt value. The check is on the scaled millisecond part,
 * so a value within one millisecond of {@code Long.MIN_VALUE} is also rejected.
 *
 * @param value the timestamp to convert
 * @return the epoch value in the unit selected by {@code precision}
 * @throws ArithmeticException if the scaled value overflows a {@code long}
 */
private Long toArrowValue(org.apache.flink.table.data.TimestampData value) {
  long milliseconds = value.getMillisecond();
  int nanoOfMillisecond = value.getNanoOfMillisecond();
  if (precision <= 3) {
    return milliseconds;
  }
  if (precision <= 6) {
    // Sub-microsecond digits are truncated, as before.
    return Math.addExact(Math.multiplyExact(milliseconds, 1000L), nanoOfMillisecond / 1000);
  }
  return Math.addExact(Math.multiplyExact(milliseconds, 1_000_000L), nanoOfMillisecond);
}
}

class DateDayVectorWriter extends BaseVectorWriter<DateDayVector, Integer> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,34 @@ void testDateFormat() {
"+I[2, 2025-02-28 12:12:12, 2024-02-28 20:12:12]"));
}

/**
 * Checks that microsecond digits are preserved when TIMESTAMP(6) and TIMESTAMP_LTZ(6) columns
 * are cast to string.
 */
@Test
void testTimestampMicros() {
  LocalDateTime firstDateTime = LocalDateTime.of(2024, 12, 31, 12, 12, 12, 123456000);
  LocalDateTime secondDateTime = LocalDateTime.of(2025, 2, 28, 12, 12, 12, 654321000);

  // TIMESTAMP(6): rendered exactly as inserted.
  List<Row> timestampRows = Arrays.asList(Row.of(1, firstDateTime), Row.of(2, secondDateTime));
  createSimpleBoundedValuesTable("timestampMicrosTable", "a int, b Timestamp(6)", timestampRows);
  runAndCheck(
      "select a, cast(b as string) from timestampMicrosTable",
      Arrays.asList("+I[1, 2024-12-31 12:12:12.123456]", "+I[2, 2025-02-28 12:12:12.654321]"));

  // TIMESTAMP_LTZ(6): UTC instants rendered in the session time zone (expected output is +8h).
  List<Row> ltzRows =
      Arrays.asList(
          Row.of(1, firstDateTime.toInstant(ZoneOffset.UTC)),
          Row.of(2, secondDateTime.toInstant(ZoneOffset.UTC)));
  createSimpleBoundedValuesTable("timestampLtzMicrosTable", "a int, b Timestamp_LTZ(6)", ltzRows);
  tEnv().getConfig().setLocalTimeZone(ZoneId.of("Asia/Shanghai"));
  runAndCheck(
      "select a, cast(b as string) from timestampLtzMicrosTable",
      Arrays.asList("+I[1, 2024-12-31 20:12:12.123456]", "+I[2, 2025-02-28 20:12:12.654321]"));
}

@Test
void testNotEqual() {
List<Row> rows =
Expand Down
Loading
Loading