Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public final class OffsetAwareOutputStream implements Closeable {

private long position;

OffsetAwareOutputStream(OutputStream currentOut, long position) {
public OffsetAwareOutputStream(OutputStream currentOut, long position) {
this.currentOut = checkNotNull(currentOut);
this.position = position;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.channel.FetchedChannelStateReader.SpillSegment;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger;
import org.apache.flink.runtime.jobgraph.JobVertexID;
Expand All @@ -41,6 +42,7 @@
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION;
Expand Down Expand Up @@ -161,6 +163,48 @@ void writeInput(
}
}

/**
 * Persists every segment yielded by {@code reader} as input-channel state of the given subtask
 * and closes {@code reader} afterwards.
 *
 * <p>For each segment the position of {@code checkpointStream} is sampled before and after
 * {@code serializer.writeData(...)}; the resulting (offset, size) pair is recorded under the
 * segment's channel in the pending result's input-channel offsets.
 *
 * <p>{@code reader} is closed on every path this method controls: when the writer is already
 * done, when the pending-result lookup fails, and when the write action fails or completes.
 *
 * @param jobVertexID vertex of the subtask whose pending result receives the offsets
 * @param subtaskIndex index of that subtask
 * @param reader source of the spilled segments; closed by this method
 */
void writeInputFromSpill(
        JobVertexID jobVertexID, int subtaskIndex, FetchedChannelStateReader reader) {
    if (isDone()) {
        try {
            reader.close();
        } catch (Exception ignored) {
            // Best effort: nothing is written on this path, so a close failure is
            // deliberately not propagated.
        }
        return;
    }

    final ChannelStatePendingResult pendingResult;
    try {
        pendingResult = getChannelStatePendingResult(jobVertexID, subtaskIndex);
    } catch (RuntimeException | Error e) {
        // The write action below never starts, so its finally block cannot close the
        // reader; do it here and keep the lookup failure as the primary exception.
        try {
            reader.close();
        } catch (Exception closeFailure) {
            e.addSuppressed(closeFailure);
        }
        throw e;
    }

    runWithChecks(
            () -> {
                try {
                    // Checked inside the try so that a violated precondition still
                    // releases the reader.
                    checkState(!pendingResult.isAllInputsReceived());
                    String action = "ChannelStateCheckpointWriter#writeInputFromSpill";
                    Optional<SpillSegment> next;
                    while ((next = reader.nextSegment()).isPresent()) {
                        SpillSegment seg = next.get();
                        // Stream position before the write marks where this segment starts.
                        long offset = checkpointStream.getPos();
                        try (AutoCloseable ignored =
                                NetworkActionsLogger.measureIO(action, seg.channelInfo())) {
                            serializer.writeData(dataStream, seg.bodyStream(), seg.length());
                        }
                        // Size is measured on the stream, not taken from seg.length().
                        long size = checkpointStream.getPos() - offset;
                        pendingResult
                                .getInputChannelOffsets()
                                .computeIfAbsent(
                                        seg.channelInfo(), unused -> new StateContentMetaInfo())
                                .withDataAdded(offset, size);
                        NetworkActionsLogger.tracePersist(
                                action,
                                seg.length() + " bytes",
                                seg.channelInfo(),
                                checkpointId);
                    }
                } finally {
                    reader.close();
                }
            });
}

void writeOutput(
JobVertexID jobVertexID, int subtaskIndex, ResultSubpartitionInfo info, Buffer buffer) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,24 +101,18 @@ public static ChannelStateFilteringHandler createFromContext(
}

/**
* Filters a recovered buffer from the specified virtual channel, returning new buffers
* containing only the records that belong to the current subtask.
*
* <p>One source buffer may produce 0 to N result buffers: 0 if all records are filtered out,
* and potentially more than 1 when a spanning record completes in this buffer. The deserializer
* caches partial record data from previous buffers, so the output may contain data that was not
* in the current source buffer, causing the total output size to exceed one buffer capacity.
* This can happen with any spanning record regardless of its size.
*
* @return filtered buffers, possibly empty if all records were filtered out.
* Filters {@code sourceBuffer} through the virtual channel identified by {@code gateIndex} /
* {@code oldChannelIndex}, appending each surviving record (length-prefixed) into {@code
* outputSerializer}. One call may emit 0..N records depending on the filter result and whether
* records spanning previous buffers complete here. The caller owns the segment boundary.
*/
public List<Buffer> filterAndRewrite(
public void filterAndRewrite(
int gateIndex,
int oldSubtaskIndex,
int oldChannelIndex,
Buffer sourceBuffer,
BufferSupplier bufferSupplier)
throws IOException, InterruptedException {
DataOutputSerializer outputSerializer)
throws IOException {

if (gateIndex < 0 || gateIndex >= gateHandlers.length) {
throw new IllegalStateException(
Expand All @@ -135,8 +129,8 @@ public List<Buffer> filterAndRewrite(
+ gateIndex
+ ". This gate is not a network input and should not have recovered buffers.");
}
return gateHandler.filterAndRewrite(
oldSubtaskIndex, oldChannelIndex, sourceBuffer, bufferSupplier);
gateHandler.filterAndRewrite(
oldSubtaskIndex, oldChannelIndex, sourceBuffer, outputSerializer);
}

/** Returns {@code true} if any virtual channel has a partial (spanning) record pending. */
Expand Down Expand Up @@ -215,7 +209,8 @@ private static <T> GateFilterHandler<T> createGateHandler(
: VirtualChannelRecordFilterFactory.createPassThroughFilter();

RecordDeserializer<DeserializationDelegate<StreamElement>> deserializer =
createDeserializer(filterContext.getTmpDirectories());
new SpillingAdaptiveSpanningRecordDeserializer<>(
filterContext.getTmpDirectories());

VirtualChannel<T> vc = new VirtualChannel<>(deserializer, recordFilter);
gateVirtualChannels.put(key, vc);
Expand Down Expand Up @@ -246,26 +241,10 @@ private static int[] getOldChannelIndexes(RescaleMappings channelMapping, int nu
return oldIndexes.stream().mapToInt(Integer::intValue).toArray();
}

/**
 * Builds the spilling record deserializer for a virtual channel.
 *
 * <p>Uses {@code tmpDirectories} when it is non-null and non-empty; otherwise falls back to a
 * single-element array holding the {@code java.io.tmpdir} system property.
 *
 * @param tmpDirectories configured spill directories, may be {@code null} or empty
 * @return a new {@code SpillingAdaptiveSpanningRecordDeserializer} over the chosen directories
 */
private static RecordDeserializer<DeserializationDelegate<StreamElement>> createDeserializer(
        String[] tmpDirectories) {
    final boolean hasConfiguredDirs = tmpDirectories != null && tmpDirectories.length > 0;
    final String[] spillDirs =
            hasConfiguredDirs
                    ? tmpDirectories
                    : new String[] {System.getProperty("java.io.tmpdir")};
    return new SpillingAdaptiveSpanningRecordDeserializer<>(spillDirs);
}

// -------------------------------------------------------------------------------------------
// Inner classes
// -------------------------------------------------------------------------------------------

/**
 * Provides buffers for re-serializing filtered records. Implementations may block.
 *
 * <p>Declared as a functional interface, so it can be supplied as a lambda or method
 * reference.
 */
@FunctionalInterface
public interface BufferSupplier {
    /**
     * Requests the next buffer to write into; the call may block.
     *
     * @return the buffer to write into
     * @throws IOException if the implementation fails to provide a buffer
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    Buffer requestBufferBlocking() throws IOException, InterruptedException;
}

/**
* Handles record filtering for a single input gate. Each gate has its own serializer and set of
* virtual channels, allowing different gates to handle different record types independently.
Expand All @@ -275,32 +254,28 @@ static class GateFilterHandler<T> {
private final Map<SubtaskConnectionDescriptor, VirtualChannel<T>> virtualChannels;
private final StreamElementSerializer<T> serializer;
private final DeserializationDelegate<StreamElement> deserializationDelegate;
private final DataOutputSerializer outputSerializer;
private final byte[] lengthBuffer = new byte[4];

/**
 * Creates the filter handler for a single gate.
 *
 * <p>Both arguments are null-checked. A {@code NonReusingDeserializationDelegate} is built
 * around {@code serializer}, and a {@code DataOutputSerializer} is created with the
 * constructor argument 128 (NOTE(review): presumably its initial capacity in bytes —
 * confirm).
 *
 * @param virtualChannels virtual channels of this gate, keyed by connection descriptor
 * @param serializer serializer for the gate's stream elements
 */
GateFilterHandler(
Map<SubtaskConnectionDescriptor, VirtualChannel<T>> virtualChannels,
StreamElementSerializer<T> serializer) {
this.virtualChannels = checkNotNull(virtualChannels);
this.serializer = checkNotNull(serializer);
this.deserializationDelegate = new NonReusingDeserializationDelegate<>(serializer);
this.outputSerializer = new DataOutputSerializer(128);
}

/**
* Deserializes records from {@code sourceBuffer}, applies the virtual channel's record
* filter, and immediately re-serializes each surviving record into output buffers.
* filter, and re-serializes each surviving record into {@code outputSerializer}. No
* intermediate network buffer is used; the caller owns the segment boundary.
*/
List<Buffer> filterAndRewrite(
void filterAndRewrite(
int oldSubtaskIndex,
int oldChannelIndex,
Buffer sourceBuffer,
BufferSupplier bufferSupplier)
throws IOException, InterruptedException {
DataOutputSerializer outputSerializer)
throws IOException {

boolean sourceBufferOwnershipTransferred = false;
List<Buffer> resultBuffers = new ArrayList<>();
Buffer currentBuffer = null;
try {
SubtaskConnectionDescriptor key =
new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex);
Expand All @@ -319,132 +294,33 @@ List<Buffer> filterAndRewrite(
while (true) {
DeserializationResult result = vc.getNextRecord(deserializationDelegate);
if (result.isFullRecord()) {
if (currentBuffer == null) {
currentBuffer = bufferSupplier.requestBufferBlocking();
}
currentBuffer =
serializeElement(
deserializationDelegate.getInstance(),
currentBuffer,
resultBuffers,
bufferSupplier);
serializeElement(deserializationDelegate.getInstance(), outputSerializer);
}
if (result.isBufferConsumed()) {
break;
}
}

if (currentBuffer != null) {
if (currentBuffer.readableBytes() > 0) {
resultBuffers.add(currentBuffer);
} else {
currentBuffer.recycleBuffer();
}
currentBuffer = null;
}

return resultBuffers;
} catch (Throwable t) {
if (!sourceBufferOwnershipTransferred) {
sourceBuffer.recycleBuffer();
}
// Avoid double-recycle: currentBuffer may already be the last element in
// resultBuffers if serializeElement added it before the exception.
if (currentBuffer != null
&& (resultBuffers.isEmpty()
|| resultBuffers.get(resultBuffers.size() - 1) != currentBuffer)) {
currentBuffer.recycleBuffer();
}
for (Buffer buf : resultBuffers) {
buf.recycleBuffer();
}
resultBuffers.clear();
throw t;
}
}

/**
* Serializes a single stream element into the current buffer using the length-prefixed
* format (4-byte big-endian length + record bytes) expected by Flink's record
* deserializers. Spills into new buffers from {@code bufferSupplier} when needed.
*
* @return the buffer to continue writing into (may differ from the input buffer).
* Appends one stream element as a length-prefixed record. Reserves the 4B prefix,
* serializes the element, then backfills the length, because {@code outputSerializer}
* already holds the segment header and earlier records, so the prefix cannot be written
* from a fixed offset.
*/
private Buffer serializeElement(
StreamElement element,
Buffer currentBuffer,
List<Buffer> resultBuffers,
BufferSupplier bufferSupplier)
throws IOException, InterruptedException {
outputSerializer.clear();
private void serializeElement(StreamElement element, DataOutputSerializer outputSerializer)
throws IOException {
int startPos = outputSerializer.length();
outputSerializer.writeInt(0); // length placeholder
serializer.serialize(element, outputSerializer);
int recordLength = outputSerializer.length();

writeLengthToBuffer(recordLength);
currentBuffer =
writeDataToBuffer(
lengthBuffer, 0, 4, currentBuffer, resultBuffers, bufferSupplier);

byte[] serializedData = outputSerializer.getSharedBuffer();
currentBuffer =
writeDataToBuffer(
serializedData,
0,
recordLength,
currentBuffer,
resultBuffers,
bufferSupplier);
return currentBuffer;
}

/**
 * Stores {@code length} into {@code lengthBuffer} as four bytes, most significant byte first.
 *
 * @param length the value to encode
 */
private void writeLengthToBuffer(int length) {
    int shift = 24;
    for (int i = 0; i < 4; i++) {
        lengthBuffer[i] = (byte) (length >> shift);
        shift -= 8;
    }
}

/**
* Writes data to the current buffer, spilling into new buffers from {@code bufferSupplier}
* when the current one is full.
*
* @return the buffer to continue writing into (may differ from the input buffer).
*/
private Buffer writeDataToBuffer(
byte[] data,
int dataOffset,
int dataLength,
Buffer currentBuffer,
List<Buffer> resultBuffers,
BufferSupplier bufferSupplier)
throws IOException, InterruptedException {
int offset = dataOffset;
int remaining = dataLength;

while (remaining > 0) {
int writableBytes = currentBuffer.getMaxCapacity() - currentBuffer.getSize();

if (writableBytes == 0) {
// Buffer is full, transfer ownership to resultBuffers
resultBuffers.add(currentBuffer);
currentBuffer = bufferSupplier.requestBufferBlocking();
writableBytes = currentBuffer.getMaxCapacity();
}

int bytesToWrite = Math.min(remaining, writableBytes);
currentBuffer
.getMemorySegment()
.put(
currentBuffer.getMemorySegmentOffset() + currentBuffer.getSize(),
data,
offset,
bytesToWrite);
currentBuffer.setSize(currentBuffer.getSize() + bytesToWrite);

offset += bytesToWrite;
remaining -= bytesToWrite;
}
return currentBuffer;
int recordLength = outputSerializer.length() - startPos - Integer.BYTES;
outputSerializer.writeIntUnsafe(recordLength, startPos);
}

boolean hasPartialData() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ interface ChannelStateSerializer {

void writeData(DataOutputStream stream, Buffer... flinkBuffers) throws IOException;

void writeData(DataOutputStream stream, InputStream input, int length) throws IOException;

void readHeader(InputStream stream) throws IOException;

int readLength(InputStream stream) throws IOException;
Expand Down Expand Up @@ -165,6 +167,18 @@ public void writeData(DataOutputStream stream, Buffer... flinkBuffers) throws IO
}
}

/**
 * Writes a length-prefixed body: the 4-byte {@code length} followed by exactly {@code length}
 * bytes read from {@code input}.
 *
 * <p>The body is copied in bounded chunks, so at most {@code length} body bytes reach
 * {@code stream} even when {@code input} holds more data than announced. After the copy,
 * {@code input} is probed for one more byte to detect such trailing data.
 *
 * @param stream destination stream
 * @param input source of the body bytes; must end after exactly {@code length} bytes
 * @param length announced body size in bytes, must be non-negative
 * @throws java.io.EOFException if {@code input} ends before {@code length} bytes were read
 * @throws IOException if {@code input} holds more than {@code length} bytes, or on any I/O
 *     failure of the underlying streams
 */
@Override
public void writeData(DataOutputStream stream, InputStream input, int length)
        throws IOException {
    Preconditions.checkArgument(length >= 0, "negative state size");
    stream.writeInt(length);

    final byte[] chunk = new byte[Math.min(length, 8192)];
    int copied = 0;
    while (copied < length) {
        // Never request more than what is still owed to the announced length.
        int read = input.read(chunk, 0, Math.min(chunk.length, length - copied));
        if (read < 0) {
            throw new java.io.EOFException(
                    "Unexpected EOF: expected "
                            + length
                            + " bytes of segment body, got "
                            + copied);
        }
        stream.write(chunk, 0, read);
        copied += read;
    }

    // The body must end exactly at the announced length; anything beyond it is an error
    // and must not be written behind the length prefix.
    if (input.read() >= 0) {
        throw new IOException(
                "Segment body is longer than the announced " + length + " bytes");
    }
}

private int getSize(Buffer[] buffers) {
int len = 0;
for (Buffer buffer : buffers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,20 @@ static ChannelStateWriteRequest abort(
return new CheckpointAbortRequest(jobVertexID, subtaskIndex, checkpointId, cause);
}

/**
 * Creates the in-progress request that replays spilled input-channel data for the checkpoint
 * identified by {@code checkpointId}.
 *
 * <p>When run against a writer, the request calls {@code
 * writer.writeInputFromSpill(jobVertexID, subtaskIndex, reader)}. The second callback handed to
 * the request closes {@code reader}. NOTE(review): presumably that callback is the discard
 * path for a request that is not executed — confirm against {@code
 * CheckpointInProgressRequest}.
 *
 * @param jobVertexID vertex of the subtask the data belongs to
 * @param subtaskIndex index of that subtask
 * @param checkpointId checkpoint the request is bound to
 * @param reader source of the spilled segments
 * @return the request wrapping the replay action
 */
static ChannelStateWriteRequest replayInputDataFromSpill(
        JobVertexID jobVertexID,
        int subtaskIndex,
        long checkpointId,
        FetchedChannelStateReader reader) {
    final String requestName = "writeInputFromSpill";
    return new CheckpointInProgressRequest(
            requestName,
            jobVertexID,
            subtaskIndex,
            checkpointId,
            checkpointWriter ->
                    checkpointWriter.writeInputFromSpill(jobVertexID, subtaskIndex, reader),
            ignoredCause -> reader.close());
}

static ChannelStateWriteRequest registerSubtask(JobVertexID jobVertexID, int subtaskIndex) {
return new SubtaskRegisterRequest(jobVertexID, subtaskIndex);
}
Expand Down
Loading