diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenOneInputOperator.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenOneInputOperator.java index 05645fda9ac..09e2b2bb624 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenOneInputOperator.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenOneInputOperator.java @@ -120,7 +120,7 @@ void initSession() { outputBridge = VectorOutputBridge.Factory.create(outClass); } sessionResource = new GlutenSessionResource(); - GlutenSessionResources.getInstance().addSessionResource(id, sessionResource); + GlutenTaskSessionContext.addSessionResource(id, sessionResource); inputQueue = sessionResource.getSession().externalStreamOps().newBlockingQueue(); // add a mock input as velox not allow the source is empty. if (inputType == null) { @@ -283,6 +283,9 @@ public void close() throws Exception { inputQueue.close(); } }, + () -> { + GlutenTaskSessionContext.unregisterSessionResource(id); + }, () -> { if (sessionResource != null) { sessionResource.close(); diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSessionResources.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSessionResource.java similarity index 76% rename from gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSessionResources.java rename to gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSessionResource.java index cda410f719e..ea38229e950 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSessionResources.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSessionResource.java @@ -26,9 +26,6 @@ import org.apache.arrow.memory.BufferAllocator; import 
org.apache.arrow.memory.RootAllocator; -import java.util.HashMap; -import java.util.Map; - // Manage the session and resource for Velox. class GlutenSessionResource { private Session session; @@ -77,26 +74,3 @@ public void setKeyedStateBackend(KeyedStateBackend keyedStateBackend) { this.keyedStateBackend = keyedStateBackend; } } - -public class GlutenSessionResources { - private static final GlutenSessionResources instance = new GlutenSessionResources(); - private Map sessionResources = new HashMap<>(); - - private GlutenSessionResources() {} - - public static GlutenSessionResources getInstance() { - return instance; - } - - public GlutenSessionResource getSessionResource(String id) { - return sessionResources.get(id); - } - - public void addSessionResource(String id, GlutenSessionResource sessionResource) { - sessionResources.put(id, sessionResource); - } - - public Session getSession(String id) { - return sessionResources.get(id).getSession(); - } -} diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java index 76cced93f15..b2caf31425b 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java @@ -199,6 +199,7 @@ public void close() throws Exception { task = null; } if (sessionResource != null) { + GlutenTaskSessionContext.unregisterSessionResource(id); sessionResource.close(); sessionResource = null; } @@ -240,7 +241,7 @@ private void initSession() { } sessionResource = new GlutenSessionResource(); - GlutenSessionResources.getInstance().addSessionResource(id, sessionResource); + GlutenTaskSessionContext.addSessionResource(id, sessionResource); Session session = sessionResource.getSession(); query = new Query( diff --git 
a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenTaskSessionContext.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenTaskSessionContext.java new file mode 100644 index 00000000000..035567dd94e --- /dev/null +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenTaskSessionContext.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.table.runtime.operators; + +import io.github.zhztheplayer.velox4j.session.Session; + +import java.util.HashMap; +import java.util.Map; + +/** Task-thread-local Gluten runtime context used by operators and serializers. 
*/ +public final class GlutenTaskSessionContext { + // One registry per thread: a resource registered on one thread is invisible to the others. + private static final ThreadLocal<Map<String, GlutenSessionResource>> SESSION_RESOURCES = + ThreadLocal.withInitial(HashMap::new); + + private GlutenTaskSessionContext() {} + + /** Returns the resource registered under {@code id} on this thread, or null if absent. */ + public static GlutenSessionResource getSessionResource(String id) { + return SESSION_RESOURCES.get().get(id); + } + + /** Registers {@code sessionResource} under {@code id} for the calling thread. */ + public static void addSessionResource(String id, GlutenSessionResource sessionResource) { + SESSION_RESOURCES.get().put(id, sessionResource); + } + + /** Same as {@link #addSessionResource(String, GlutenSessionResource)}. */ + public static void registerSessionResource(String operatorId, GlutenSessionResource resource) { + addSessionResource(operatorId, resource); + } + + /** Removes the calling thread's registration for {@code operatorId}, if any. */ + public static void unregisterSessionResource(String operatorId) { + Map<String, GlutenSessionResource> resources = SESSION_RESOURCES.get(); + resources.remove(operatorId); + if (resources.isEmpty()) { + // Nothing left on this thread: drop the thread-local entry instead of keeping an empty map. + SESSION_RESOURCES.remove(); + } + } + + /** + * Returns the session registered under {@code operatorId} on the calling thread. + * + * @throws IllegalStateException if nothing is registered for the operator on this thread, or the + * registered resource has no session + */ + public static Session getSession(String operatorId) { + Map<String, GlutenSessionResource> resources = SESSION_RESOURCES.get(); + GlutenSessionResource resource = resources.get(operatorId); + if (resource == null) { + throw new IllegalStateException( + "No Gluten session registered on the current task thread for operator " + + operatorId + + ". 
Registered operators: " + + resources.keySet()); + } + Session session = resource.getSession(); + if (session == null) { + throw new IllegalStateException( + "Gluten session is already closed for operator " + operatorId); + } + return session; + } +} diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenTwoInputOperator.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenTwoInputOperator.java index 1251bb084e5..7845e22a5b3 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenTwoInputOperator.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenTwoInputOperator.java @@ -258,6 +258,9 @@ public void close() throws Exception { task.close(); } }, + () -> { + GlutenTaskSessionContext.unregisterSessionResource(getId()); + }, () -> { if (sessionResource != null) { sessionResource.close(); @@ -324,7 +327,7 @@ private void initSession() { } sessionResource = new GlutenSessionResource(); - GlutenSessionResources.getInstance().addSessionResource(getId(), sessionResource); + GlutenTaskSessionContext.addSessionResource(getId(), sessionResource); leftInputQueue = sessionResource.getSession().externalStreamOps().newBlockingQueue(); rightInputQueue = sessionResource.getSession().externalStreamOps().newBlockingQueue(); diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/typeutils/GlutenStatefulRecordSerializer.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/typeutils/GlutenStatefulRecordSerializer.java index 66fdd7bc515..00e8abaab6d 100644 --- a/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/typeutils/GlutenStatefulRecordSerializer.java +++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/typeutils/GlutenStatefulRecordSerializer.java @@ -17,7 +17,7 @@ package org.apache.gluten.table.runtime.typeutils; import
org.apache.gluten.streaming.api.operators.GlutenOperator; -import org.apache.gluten.table.runtime.operators.GlutenSessionResources; +import org.apache.gluten.table.runtime.operators.GlutenTaskSessionContext; import io.github.zhztheplayer.velox4j.data.RowVector; import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; @@ -37,16 +37,20 @@ public class GlutenStatefulRecordSerializer extends TypeSerializer { private static final long serialVersionUID = 1L; private final RowType rowType; - private final GlutenOperator operator; + private final String operatorId; public GlutenStatefulRecordSerializer(RowType rowType, GlutenOperator operator) { + this(rowType, operator.getId()); + } + + private GlutenStatefulRecordSerializer(RowType rowType, String operatorId) { this.rowType = rowType; - this.operator = operator; + this.operatorId = operatorId; } @Override public TypeSerializer duplicate() { - return new GlutenStatefulRecordSerializer(rowType, operator); + return new GlutenStatefulRecordSerializer(rowType, operatorId); } @Override @@ -67,12 +71,11 @@ public StatefulRecord deserialize(DataInputView source) throws IOException { byte[] str = new byte[len]; source.readFully(str); RowVector rowVector = - GlutenSessionResources.getInstance() - .getSession(operator.getId()) + GlutenTaskSessionContext.getSession(operatorId) .baseVectorOps() .deserializeOne(new String(str)) .asRowVector(); - StatefulRecord record = new StatefulRecord(operator.getId(), rowVector.id(), 0, false, -1); + StatefulRecord record = new StatefulRecord(operatorId, rowVector.id(), 0, false, -1); record.setRowVector(rowVector); return record; } @@ -130,7 +133,7 @@ public int getLength() { @Override public TypeSerializerSnapshot snapshotConfiguration() { - return new RowVectorSerializerSnapshot(rowType, operator); + return new RowVectorSerializerSnapshot(rowType, operatorId); } /** {@link TypeSerializerSnapshot} for Gluten RowVector.. 
*/ @@ -139,16 +142,16 @@ public static final class RowVectorSerializerSnapshot private static final int CURRENT_VERSION = 1; private RowType rowType; - private GlutenOperator operator; + private String operatorId; @SuppressWarnings("unused") public RowVectorSerializerSnapshot() { // this constructor is used when restoring from a checkpoint/savepoint. } - RowVectorSerializerSnapshot(RowType rowType, GlutenOperator operator) { + RowVectorSerializerSnapshot(RowType rowType, String operatorId) { this.rowType = rowType; - this.operator = operator; + this.operatorId = operatorId; } @Override @@ -165,7 +168,7 @@ public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCode @Override public GlutenStatefulRecordSerializer restoreSerializer() { - return new GlutenStatefulRecordSerializer(rowType, operator); + return new GlutenStatefulRecordSerializer(rowType, operatorId); } @Override diff --git a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/operators/GlutenStatefulRecordSerializerTest.java b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/operators/GlutenStatefulRecordSerializerTest.java new file mode 100644 index 00000000000..4c38890446f --- /dev/null +++ b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/operators/GlutenStatefulRecordSerializerTest.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.table.runtime.operators; + +import org.apache.gluten.streaming.api.operators.GlutenOperator; +import org.apache.gluten.table.runtime.stream.common.Velox4jEnvironment; +import org.apache.gluten.table.runtime.typeutils.GlutenStatefulRecordSerializer; +import org.apache.gluten.util.LogicalTypeConverter; +import org.apache.gluten.vectorized.FlinkRowToVLVectorConvertor; + +import io.github.zhztheplayer.velox4j.data.RowVector; +import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; +import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; + +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests {@link GlutenStatefulRecordSerializer} against {@link GlutenTaskSessionContext}. */ +public class GlutenStatefulRecordSerializerTest { + private static 
final String OPERATOR_ID = "serializer-test-operator"; + + private GlutenSessionResource sourceResource; + private GlutenSessionResource targetResource; + + @BeforeAll + public static void initializeVelox() { + Velox4jEnvironment.initializeOnce(); + } + + @AfterEach + public void tearDown() { + GlutenTaskSessionContext.unregisterSessionResource(OPERATOR_ID); + if (sourceResource != null) { + sourceResource.close(); + sourceResource = null; + } + if (targetResource != null) { + targetResource.close(); + targetResource = null; + } + } + + /** Round-trips a record into the session registered for the current thread. */ + @Test + public void testDeserializeUsesTaskLocalSession() throws Exception { + io.github.zhztheplayer.velox4j.type.RowType veloxRowType = createVeloxRowType(); + GlutenStatefulRecordSerializer serializer = + new GlutenStatefulRecordSerializer(veloxRowType, new TestGlutenOperator(veloxRowType)); + + sourceResource = new GlutenSessionResource(); + RowVector sourceVector = + FlinkRowToVLVectorConvertor.fromRowData( + GenericRowData.of(7, StringData.fromString("Alice")), + sourceResource.getAllocator(), + sourceResource.getSession(), + veloxRowType); + StatefulRecord sourceRecord = new StatefulRecord(OPERATOR_ID, sourceVector.id(), 0, false, 3); + sourceRecord.setRowVector(sourceVector); + + DataOutputSerializer output = new DataOutputSerializer(128); + serializer.serialize(sourceRecord, output); + + targetResource = new GlutenSessionResource(); + GlutenTaskSessionContext.addSessionResource(OPERATOR_ID, targetResource); + + StatefulRecord restoredRecord = + serializer.deserialize(new DataInputDeserializer(output.getCopyOfBuffer())); + List<RowData> restoredRows = + FlinkRowToVLVectorConvertor.toRowData( + restoredRecord.getRowVector(), targetResource.getAllocator(), veloxRowType); + + assertThat(restoredRecord.getNodeId()).isEqualTo(OPERATOR_ID); + assertThat(restoredRows).hasSize(1); + RowData restoredRow = restoredRows.get(0); + assertThat(restoredRow.getInt(0)).isEqualTo(7); + 
assertThat(restoredRow.getString(1)).isEqualTo(StringData.fromString("Alice")); + + restoredRecord.close(); + sourceRecord.close(); + } + + /** Two threads register under the same id and must each read back their own session. */ + @Test + public void testTaskLocalSessionContextIsThreadIsolated() throws Exception { + ExecutorService executor = Executors.newFixedThreadPool(2); + CyclicBarrier barrier = new CyclicBarrier(2); + try { + Callable<Void> assertion = + () -> { + assertTaskLocalSessionIsIsolated(barrier); + return null; + }; + Future<Void> first = executor.submit(assertion); + Future<Void> second = executor.submit(assertion); + + first.get(); + second.get(); + } finally { + executor.shutdownNow(); + } + } + + private static void assertTaskLocalSessionIsIsolated(CyclicBarrier barrier) throws Exception { + GlutenSessionResource resource = new GlutenSessionResource(); + try { + GlutenTaskSessionContext.addSessionResource(OPERATOR_ID, resource); + // Wait until both threads have registered; the timeout turns a failed peer into an error + // instead of an indefinite hang. + barrier.await(30, TimeUnit.SECONDS); + + assertThat(GlutenTaskSessionContext.getSession(OPERATOR_ID)).isSameAs(resource.getSession()); + } finally { + GlutenTaskSessionContext.unregisterSessionResource(OPERATOR_ID); + resource.close(); + } + } + + private static io.github.zhztheplayer.velox4j.type.RowType createVeloxRowType() { + RowType flinkRowType = + RowType.of( + new LogicalType[] {new IntType(), new VarCharType(VarCharType.MAX_LENGTH)}, + new String[] {"id", "name"}); + return (io.github.zhztheplayer.velox4j.type.RowType) + LogicalTypeConverter.toVLType(flinkRowType); + } + + private static class TestGlutenOperator implements GlutenOperator { + private final io.github.zhztheplayer.velox4j.type.RowType rowType; + + private TestGlutenOperator(io.github.zhztheplayer.velox4j.type.RowType rowType) { + this.rowType = rowType; + } + + @Override + public StatefulPlanNode getPlanNode() { + return null; + } + + @Override + public io.github.zhztheplayer.velox4j.type.RowType getInputType() { + return rowType; + } + + @Override + public Map<String, io.github.zhztheplayer.velox4j.type.RowType> getOutputTypes() { + return Map.of(OPERATOR_ID, rowType); + } + + @Override + public String getId() { + return 
OPERATOR_ID; + } + } +} diff --git a/gluten-flink/ut/src/test/resources/nexmark/q12.sql b/gluten-flink/ut/src/test/resources/nexmark/q12.sql new file mode 100644 index 00000000000..f2cda4f463b --- /dev/null +++ b/gluten-flink/ut/src/test/resources/nexmark/q12.sql @@ -0,0 +1,20 @@ +CREATE TABLE nexmark_q12 ( + bidder BIGINT, + bid_count BIGINT, + starttime TIMESTAMP(3), + endtime TIMESTAMP(3) +) WITH ( + 'connector' = 'blackhole' +); + +CREATE VIEW B AS SELECT *, PROCTIME() as p_time FROM bid; + +INSERT INTO nexmark_q12 +SELECT + bidder, + count(*) as bid_count, + window_start AS starttime, + window_end AS endtime +FROM TABLE( + TUMBLE(TABLE B, DESCRIPTOR(p_time), INTERVAL '10' SECOND)) +GROUP BY bidder, window_start, window_end;