Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void initSession() {
outputBridge = VectorOutputBridge.Factory.create(outClass);
}
sessionResource = new GlutenSessionResource();
GlutenSessionResources.getInstance().addSessionResource(id, sessionResource);
GlutenTaskSessionContext.addSessionResource(id, sessionResource);
inputQueue = sessionResource.getSession().externalStreamOps().newBlockingQueue();
// add a mock input as velox not allow the source is empty.
if (inputType == null) {
Expand Down Expand Up @@ -283,6 +283,9 @@ public void close() throws Exception {
inputQueue.close();
}
},
() -> {
GlutenTaskSessionContext.unregisterSessionResource(id);
},
() -> {
if (sessionResource != null) {
sessionResource.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;

import java.util.HashMap;
import java.util.Map;

// Manage the session and resource for Velox.
class GlutenSessionResource {
private Session session;
Expand Down Expand Up @@ -77,26 +74,3 @@ public void setKeyedStateBackend(KeyedStateBackend<?> keyedStateBackend) {
this.keyedStateBackend = keyedStateBackend;
}
}

public class GlutenSessionResources {
private static final GlutenSessionResources instance = new GlutenSessionResources();
private Map<String, GlutenSessionResource> sessionResources = new HashMap<>();

private GlutenSessionResources() {}

public static GlutenSessionResources getInstance() {
return instance;
}

public GlutenSessionResource getSessionResource(String id) {
return sessionResources.get(id);
}

public void addSessionResource(String id, GlutenSessionResource sessionResource) {
sessionResources.put(id, sessionResource);
}

public Session getSession(String id) {
return sessionResources.get(id).getSession();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ public void close() throws Exception {
task = null;
}
if (sessionResource != null) {
GlutenTaskSessionContext.unregisterSessionResource(id);
sessionResource.close();
sessionResource = null;
}
Expand Down Expand Up @@ -240,7 +241,7 @@ private void initSession() {
}

sessionResource = new GlutenSessionResource();
GlutenSessionResources.getInstance().addSessionResource(id, sessionResource);
GlutenTaskSessionContext.addSessionResource(id, sessionResource);
Session session = sessionResource.getSession();
query =
new Query(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.table.runtime.operators;

import io.github.zhztheplayer.velox4j.session.Session;

import java.util.HashMap;
import java.util.Map;

/** Task-thread-local Gluten runtime context used by operators and serializers. */
public final class GlutenTaskSessionContext {
private static final ThreadLocal<Map<String, GlutenSessionResource>> SESSION_RESOURCES =
ThreadLocal.withInitial(HashMap::new);

private GlutenTaskSessionContext() {}

public static GlutenSessionResource getSessionResource(String id) {
return SESSION_RESOURCES.get().get(id);
}

public static void addSessionResource(String id, GlutenSessionResource sessionResource) {
SESSION_RESOURCES.get().put(id, sessionResource);
}

public static void registerSessionResource(String operatorId, GlutenSessionResource resource) {
addSessionResource(operatorId, resource);
}

public static void unregisterSessionResource(String operatorId) {
Map<String, GlutenSessionResource> resources = SESSION_RESOURCES.get();
resources.remove(operatorId);
if (resources.isEmpty()) {
SESSION_RESOURCES.remove();
}
}

public static Session getSession(String operatorId) {
Map<String, GlutenSessionResource> resources = SESSION_RESOURCES.get();
GlutenSessionResource resource = resources.get(operatorId);
if (resource == null) {
throw new IllegalStateException(
"No Gluten session registered on the current task thread for operator "
+ operatorId
+ ". Registered operators: "
+ resources.keySet());
}
Session session = resource.getSession();
if (session == null) {
throw new IllegalStateException(
"Gluten session is already closed for operator " + operatorId);
}
return session;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ public void close() throws Exception {
task.close();
}
},
() -> {
GlutenTaskSessionContext.unregisterSessionResource(getId());
},
() -> {
if (sessionResource != null) {
sessionResource.close();
Expand Down Expand Up @@ -324,7 +327,7 @@ private void initSession() {
}

sessionResource = new GlutenSessionResource();
GlutenSessionResources.getInstance().addSessionResource(getId(), sessionResource);
GlutenTaskSessionContext.addSessionResource(getId(), sessionResource);
leftInputQueue = sessionResource.getSession().externalStreamOps().newBlockingQueue();
rightInputQueue = sessionResource.getSession().externalStreamOps().newBlockingQueue();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.gluten.table.runtime.typeutils;

import org.apache.gluten.streaming.api.operators.GlutenOperator;
import org.apache.gluten.table.runtime.operators.GlutenSessionResources;
import org.apache.gluten.table.runtime.operators.GlutenTaskSessionContext;

import io.github.zhztheplayer.velox4j.data.RowVector;
import io.github.zhztheplayer.velox4j.stateful.StatefulRecord;
Expand All @@ -37,16 +37,20 @@
public class GlutenStatefulRecordSerializer extends TypeSerializer<StatefulRecord> {
private static final long serialVersionUID = 1L;
private final RowType rowType;
private final GlutenOperator operator;
private final String operatorId;

public GlutenStatefulRecordSerializer(RowType rowType, GlutenOperator operator) {
this(rowType, operator.getId());
}

private GlutenStatefulRecordSerializer(RowType rowType, String operatorId) {
this.rowType = rowType;
this.operator = operator;
this.operatorId = operatorId;
}

@Override
public TypeSerializer<StatefulRecord> duplicate() {
return new GlutenStatefulRecordSerializer(rowType, operator);
return new GlutenStatefulRecordSerializer(rowType, operatorId);
}

@Override
Expand All @@ -67,12 +71,11 @@ public StatefulRecord deserialize(DataInputView source) throws IOException {
byte[] str = new byte[len];
source.readFully(str);
RowVector rowVector =
GlutenSessionResources.getInstance()
.getSession(operator.getId())
GlutenTaskSessionContext.getSession(operatorId)
.baseVectorOps()
.deserializeOne(new String(str))
.asRowVector();
StatefulRecord record = new StatefulRecord(operator.getId(), rowVector.id(), 0, false, -1);
StatefulRecord record = new StatefulRecord(operatorId, rowVector.id(), 0, false, -1);
record.setRowVector(rowVector);
return record;
}
Expand Down Expand Up @@ -130,7 +133,7 @@ public int getLength() {

@Override
public TypeSerializerSnapshot<StatefulRecord> snapshotConfiguration() {
return new RowVectorSerializerSnapshot(rowType, operator);
return new RowVectorSerializerSnapshot(rowType, operatorId);
}

/** {@link TypeSerializerSnapshot} for Gluten RowVector.. */
Expand All @@ -139,16 +142,16 @@ public static final class RowVectorSerializerSnapshot
private static final int CURRENT_VERSION = 1;

private RowType rowType;
private GlutenOperator operator;
private String operatorId;

@SuppressWarnings("unused")
public RowVectorSerializerSnapshot() {
// this constructor is used when restoring from a checkpoint/savepoint.
}

RowVectorSerializerSnapshot(RowType rowType, GlutenOperator operator) {
RowVectorSerializerSnapshot(RowType rowType, String operatorId) {
this.rowType = rowType;
this.operator = operator;
this.operatorId = operatorId;
}

@Override
Expand All @@ -165,7 +168,7 @@ public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCode

@Override
public GlutenStatefulRecordSerializer restoreSerializer() {
return new GlutenStatefulRecordSerializer(rowType, operator);
return new GlutenStatefulRecordSerializer(rowType, operatorId);
}

@Override
Expand Down
Loading
Loading