diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 2a86a566b1f00..cbb4635a9a359 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -1811,13 +1811,22 @@ private OptionalLong restoreLatestCheckpointedStateInternal(
vertexFinishedStateChecker.validateOperatorsFinishedState();
}
+ EdgeDistributionPatternSnapshot oldEdgePatterns = null;
+ for (MasterState ms : latest.getMasterHookStates()) {
+ if (EdgeDistributionPatternSnapshot.HOOK_IDENTIFIER.equals(ms.name())) {
+ oldEdgePatterns = EdgeDistributionPatternSnapshot.fromBytes(ms.bytes());
+ break;
+ }
+ }
+
StateAssignmentOperation stateAssignmentOperation =
new StateAssignmentOperation(
latest.getCheckpointID(),
tasks,
operatorStates,
allowNonRestoredState,
- recoverOutputOnDownstreamTask);
+ recoverOutputOnDownstreamTask,
+ oldEdgePatterns);
stateAssignmentOperation.assignStates();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshot.java
new file mode 100644
index 0000000000000..701cb780133e9
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshot.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+
+import javax.annotation.Nullable;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Snapshot of per-operator output edge {@link DistributionPattern}s, persisted via {@link
+ * MasterTriggerRestoreHook} / {@link MasterState} so that the correct old distribution pattern is
+ * available when restoring from a checkpoint after a shuffle-mode change.
+ *
+ * <p>Key = outputOperatorID (generated, chain-tail), Value = one {@link DistributionPattern} per
+ * output partition (indexed by partitionIndex).
+ *
+ * <p>Binary layout produced by {@link #toBytes()} and consumed by {@link #fromBytes(byte[])}:
+ * format version (int), operator count (int), then per operator the lower and upper part of its
+ * id (two longs), the partition count (int) and one byte per partition.
+ */
+public class EdgeDistributionPatternSnapshot {
+
+    /** Master-state name that identifies the serialized snapshot among a checkpoint's states. */
+    public static final String HOOK_IDENTIFIER = "__flink_edge_distribution_patterns__";
+
+    private static final int FORMAT_VERSION = 1;
+    private static final byte PATTERN_POINTWISE = 0;
+    private static final byte PATTERN_ALL_TO_ALL = 1;
+
+    /** Size of one serialized operator entry before its patterns: id (2 longs) + count (int). */
+    private static final int OPERATOR_ENTRY_HEADER_BYTES = 2 * Long.BYTES + Integer.BYTES;
+
+    private final Map<OperatorID, DistributionPattern[]> outputEdgePatterns;
+
+    /**
+     * Creates a snapshot from the given mapping. The map itself is copied (shallowly) and wrapped
+     * as unmodifiable; the pattern arrays are shared with the caller.
+     *
+     * @param outputEdgePatterns output operator id → distribution pattern per output partition
+     */
+    public EdgeDistributionPatternSnapshot(
+            Map<OperatorID, DistributionPattern[]> outputEdgePatterns) {
+        this.outputEdgePatterns = Collections.unmodifiableMap(new HashMap<>(outputEdgePatterns));
+    }
+
+    /**
+     * Returns a copy of all output patterns recorded for the given operator.
+     *
+     * @param outputOperatorID id of the operator whose output patterns are requested
+     * @return a fresh array of patterns, or {@code null} if the operator is not in the snapshot
+     */
+    @Nullable
+    public DistributionPattern[] getOutputPatterns(OperatorID outputOperatorID) {
+        DistributionPattern[] patterns = outputEdgePatterns.get(outputOperatorID);
+        return patterns != null ? patterns.clone() : null;
+    }
+
+    /**
+     * Returns the pattern recorded for a single output partition of the given operator.
+     *
+     * @param outputOperatorID id of the operator whose output pattern is requested
+     * @param partitionIndex index of the output partition
+     * @return the recorded pattern, or {@code null} if the operator is unknown or the index is
+     *     outside the recorded range
+     */
+    @Nullable
+    public DistributionPattern getOutputPattern(OperatorID outputOperatorID, int partitionIndex) {
+        DistributionPattern[] patterns = outputEdgePatterns.get(outputOperatorID);
+        if (patterns == null || partitionIndex < 0 || partitionIndex >= patterns.length) {
+            return null;
+        }
+        return patterns[partitionIndex];
+    }
+
+    /**
+     * Serializes this snapshot into the versioned binary layout described on the class.
+     *
+     * @return the serialized snapshot
+     * @throws IOException if writing to the in-memory stream fails
+     */
+    public byte[] toBytes() throws IOException {
+        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
+                DataOutputStream dos = new DataOutputStream(baos)) {
+            dos.writeInt(FORMAT_VERSION);
+            dos.writeInt(outputEdgePatterns.size());
+            for (Map.Entry<OperatorID, DistributionPattern[]> entry :
+                    outputEdgePatterns.entrySet()) {
+                OperatorID opId = entry.getKey();
+                dos.writeLong(opId.getLowerPart());
+                dos.writeLong(opId.getUpperPart());
+                DistributionPattern[] patterns = entry.getValue();
+                dos.writeInt(patterns.length);
+                for (DistributionPattern p : patterns) {
+                    // Anything that is not ALL_TO_ALL is written as POINTWISE.
+                    dos.writeByte(
+                            p == DistributionPattern.ALL_TO_ALL
+                                    ? PATTERN_ALL_TO_ALL
+                                    : PATTERN_POINTWISE);
+                }
+            }
+            dos.flush();
+            return baos.toByteArray();
+        }
+    }
+
+    /**
+     * Deserializes a snapshot previously produced by {@link #toBytes()}.
+     *
+     * @param bytes the serialized snapshot
+     * @return the restored snapshot
+     * @throws IOException if the format version is unsupported, a count is negative or larger than
+     *     the remaining data can hold, a pattern byte is unknown, or the data is truncated
+     */
+    public static EdgeDistributionPatternSnapshot fromBytes(byte[] bytes) throws IOException {
+        try (ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
+                DataInputStream dis = new DataInputStream(bais)) {
+            int version = dis.readInt();
+            if (version != FORMAT_VERSION) {
+                throw new IOException(
+                        "Unsupported EdgeDistributionPatternSnapshot version: " + version);
+            }
+            int numOperators = readCount(dis, OPERATOR_ENTRY_HEADER_BYTES, "operator");
+            Map<OperatorID, DistributionPattern[]> map = new HashMap<>(numOperators);
+            for (int i = 0; i < numOperators; i++) {
+                // Arguments are evaluated left to right: lower part first, matching toBytes().
+                OperatorID opId = new OperatorID(dis.readLong(), dis.readLong());
+                int numPartitions = readCount(dis, Byte.BYTES, "partition");
+                DistributionPattern[] patterns = new DistributionPattern[numPartitions];
+                for (int j = 0; j < numPartitions; j++) {
+                    byte b = dis.readByte();
+                    if (b == PATTERN_ALL_TO_ALL) {
+                        patterns[j] = DistributionPattern.ALL_TO_ALL;
+                    } else if (b == PATTERN_POINTWISE) {
+                        patterns[j] = DistributionPattern.POINTWISE;
+                    } else {
+                        throw new IOException("Unknown DistributionPattern byte in snapshot: " + b);
+                    }
+                }
+                map.put(opId, patterns);
+            }
+            return new EdgeDistributionPatternSnapshot(map);
+        }
+    }
+
+    /**
+     * Reads an element count and rejects values that cannot be valid for the remaining input, so
+     * that corrupt data surfaces as an {@link IOException} instead of a negative array size or an
+     * oversized allocation.
+     */
+    private static int readCount(DataInputStream in, int minBytesPerElement, String what)
+            throws IOException {
+        int count = in.readInt();
+        // long arithmetic: the product must not overflow for large corrupt counts.
+        if (count < 0 || (long) count * minBytesPerElement > in.available()) {
+            throw new IOException(
+                    "Corrupt EdgeDistributionPatternSnapshot: invalid "
+                            + what
+                            + " count "
+                            + count);
+        }
+        return count;
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java
index fffc7b44305de..c7c9f4d7023c4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java
@@ -17,6 +17,8 @@
package org.apache.flink.runtime.checkpoint;
+import java.io.IOException;
+import java.io.ObjectInputStream;
import java.io.ObjectStreamException;
import java.io.Serializable;
import java.util.Arrays;
@@ -55,6 +57,11 @@ public RescaleMappings getChannelMapping(int gateOrPartitionIndex) {
return gateOrPartitionDescriptors[gateOrPartitionIndex].getRescaleMappings();
}
+    /**
+     * Looks up the complete rescaling descriptor of one gate or partition.
+     *
+     * @param gateOrPartitionIndex position in the per-gate/partition descriptor array
+     * @return the descriptor stored at that position
+     */
+    public InflightDataGateOrPartitionRescalingDescriptor getGateOrPartitionDescriptor(
+            int gateOrPartitionIndex) {
+        final InflightDataGateOrPartitionRescalingDescriptor descriptor =
+                gateOrPartitionDescriptors[gateOrPartitionIndex];
+        return descriptor;
+    }
+
public boolean isAmbiguous(int gateOrPartitionIndex, int oldSubtaskIndex) {
return gateOrPartitionDescriptors[gateOrPartitionIndex].ambiguousSubtaskIndexes.contains(
oldSubtaskIndex);
@@ -142,7 +149,7 @@ public RescaleMappings getRescaleMappings() {
/**
* Set when channels are merged because the connected operator has been rescaled for each
- * gate/partition.
+ * gate/partition. Used for ALL_TO_ALL → ALL_TO_ALL path.
*/
private final RescaleMappings rescaledChannelsMappings;
@@ -151,6 +158,9 @@ public RescaleMappings getRescaleMappings() {
private final MappingType mappingType;
+ /** Bundled scalar topology parameters for POINTWISE-aware rescaling. */
+ private PointwiseRescaleParams rescaleParams;
+
/** Type of mapping which should be used for this in-flight data. */
public enum MappingType {
IDENTITY,
@@ -162,10 +172,25 @@ public InflightDataGateOrPartitionRescalingDescriptor(
RescaleMappings rescaledChannelsMappings,
Set ambiguousSubtaskIndexes,
MappingType mappingType) {
+ this(
+ oldSubtaskIndexes,
+ rescaledChannelsMappings,
+ ambiguousSubtaskIndexes,
+ mappingType,
+ PointwiseRescaleParams.EMPTY);
+ }
+
+ /**
+ * Creates a descriptor that, in addition to the channel mapping data, carries the scalar
+ * topology parameters used for POINTWISE-aware rescaling. All arguments are stored as given.
+ *
+ * @param oldSubtaskIndexes old subtask indexes stored in this descriptor
+ * @param rescaledChannelsMappings channel mappings for this gate/partition
+ * @param ambiguousSubtaskIndexes old subtask indexes flagged as ambiguous
+ * @param mappingType type of mapping to use for the in-flight data
+ * @param rescaleParams bundled topology parameters; the four-argument constructor passes
+ * {@link PointwiseRescaleParams#EMPTY}
+ */
+ public InflightDataGateOrPartitionRescalingDescriptor(
+ int[] oldSubtaskIndexes,
+ RescaleMappings rescaledChannelsMappings,
+ Set ambiguousSubtaskIndexes,
+ MappingType mappingType,
+ PointwiseRescaleParams rescaleParams) {
this.oldSubtaskIndexes = oldSubtaskIndexes;
this.rescaledChannelsMappings = rescaledChannelsMappings;
this.ambiguousSubtaskIndexes = ambiguousSubtaskIndexes;
this.mappingType = mappingType;
+ this.rescaleParams = rescaleParams;
}
public int[] getOldSubtaskInstances() {
@@ -180,6 +205,21 @@ public boolean isIdentity() {
return mappingType == MappingType.IDENTITY;
}
+        /**
+         * Returns the bundled topology parameters held by this descriptor.
+         *
+         * @return the stored {@link PointwiseRescaleParams}
+         */
+        public PointwiseRescaleParams getPointwiseRescaleParams() {
+            return this.rescaleParams;
+        }
+
+        /**
+         * Tells whether this descriptor carries real rescale parameters.
+         *
+         * @return {@code true} unless the stored parameters equal {@link
+         *     PointwiseRescaleParams#EMPTY}
+         */
+        public boolean isPointwiseRescaling() {
+            final boolean hasNoParams = rescaleParams.equals(PointwiseRescaleParams.EMPTY);
+            return !hasNoParams;
+        }
+
+        /**
+         * Java-serialization hook: performs default field deserialization and afterwards replaces
+         * a missing {@code rescaleParams} with {@link PointwiseRescaleParams#EMPTY}.
+         */
+        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+            in.defaultReadObject();
+            if (rescaleParams != null) {
+                return;
+            }
+            rescaleParams = PointwiseRescaleParams.EMPTY;
+        }
+
@Override
public boolean equals(Object o) {
if (this == o) {
@@ -193,13 +233,18 @@ public boolean equals(Object o) {
return Arrays.equals(oldSubtaskIndexes, that.oldSubtaskIndexes)
&& Objects.equals(rescaledChannelsMappings, that.rescaledChannelsMappings)
&& Objects.equals(ambiguousSubtaskIndexes, that.ambiguousSubtaskIndexes)
- && mappingType == that.mappingType;
+ && mappingType == that.mappingType
+ && Objects.equals(rescaleParams, that.rescaleParams);
}
@Override
public int hashCode() {
int result =
- Objects.hash(rescaledChannelsMappings, ambiguousSubtaskIndexes, mappingType);
+ // rescaleParams takes part in equals(), so it has to contribute to the hash as well.
+ Objects.hash(
+ rescaledChannelsMappings,
+ ambiguousSubtaskIndexes,
+ mappingType,
+ rescaleParams);
result = 31 * result + Arrays.hashCode(oldSubtaskIndexes);
return result;
}
@@ -215,6 +260,8 @@ public String toString() {
+ ambiguousSubtaskIndexes
+ ", mappingType="
+ mappingType
+ + ", rescaleParams="
+ + rescaleParams
+ '}';
}
}
@@ -231,6 +278,12 @@ public int[] getOldSubtaskIndexes(int gateOrPartitionIndex) {
return EMPTY_INT_ARRAY;
}
+        /** Ignores the index and always answers with the shared {@code NO_STATE} descriptor. */
+        @Override
+        public InflightDataGateOrPartitionRescalingDescriptor getGateOrPartitionDescriptor(
+                int gateOrPartitionIndex) {
+            final InflightDataGateOrPartitionRescalingDescriptor noState =
+                    InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
+            return noState;
+        }
+
@Override
public RescaleMappings getChannelMapping(int gateOrPartitionIndex) {
return RescaleMappings.SYMMETRIC_IDENTITY;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtils.java
new file mode 100644
index 0000000000000..9a155522c8664
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtils.java
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.IntStream;
+
+/**
+ * Topology computation utilities for POINTWISE edge unaligned-checkpoint rescaling. Used by both JM
+ * (subtask assignment / descriptor generation) and TM (channel routing at recovery time).
+ *
+ * <h3>Atomic primitives</h3>
+ *
+ * <p>All business methods are compositions of six atomic primitives:
+ *
+ * <ul>
+ *   <li>{@link #producersOf} — topology edge: downstream subtask → upstream subtask indices that
+ *       produce for it
+ *   <li>{@link #consumersOf} — topology edge: upstream subtask → downstream subtask indices it
+ *       produces for
+ *   <li>{@link #localIndexToGlobalSubtaskIndex} — local channel/subpartition index → global subtask
+ *       index of the peer (local-to-global)
+ *   <li>{@link #globalSubtaskIndexToLocalIndex} — inverse of the above: global peer subtask index →
+ *       local channel/subpartition index (global-to-local)
+ *   <li>{@link #oldSubtasksAssignedTo} — cross-generation: new subtask → old subtasks it inherits
+ *       (ROUND_ROBIN)
+ *   <li>{@link #newSubtaskAssignedFrom} — cross-generation inverse: old subtask → new subtask
+ *       (modulo)
+ * </ul>
+ *
+ * <h3>Core chain</h3>
+ *
+ * <p>One chain, two directions. JM walks it forward (set-level) to decide which old
+ * subtasks to recover from; TM walks it backward (per-buffer) to decide where to place
+ * each recovered buffer.
+ *
+ * <p>Forward direction (JM, set-level) — "new upstream U needs state from which old
+ * upstreams?" The chain is {@code consumersOf → oldSubtasksAssignedTo → producersOf}.
+ *
+ * <p>Implemented by {@link #traceOutputSources}. Its inverse ({@link #resolveInputOwnership})
+ * resolves the input side: for a given new downstream, which new upstream owns each old upstream's
+ * data.
+ *
+ * <p>Backward direction (TM, per-buffer) — "old buffer at (oldUp, localSP) goes to which new
+ * subpartition?"
+ *
+ * <p>The first and third steps use modulo-based mapping on the output side ({@code D %
+ * numConsumers}), matching the execution graph's subpartition assignment in {@link
+ * org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils#computeConsumedSubpartitionRange}.
+ *
+ * <h3>Routing principle</h3>
+ *
+ * <p>At TM recovery time, only the new upstream that is connected to a target new
+ * downstream writes recovered buffers to it. When {@code newUpPar > newDownPar} in POINTWISE,
+ * multiple new upstreams connect to the same downstream and receive the same old state (ambiguous
+ * assignment from {@link #traceOutputSources}). To avoid duplicates, only the primary
+ * producer ({@code producersOf(D)[0]}) writes; others discard via {@link
+ * #computeNewLocalSubpartitionIndex} returning -1.
+ *
+ * <p>All arithmetic mirrors {@link
+ * org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils#computeVertexInputInfoForPointwise}
+ * but operates on scalar parallelisms only (no scheduler-layer objects), so it can run on TM.
+ */
+@Internal
+public final class PointwiseChannelMappingUtils {
+
+ private PointwiseChannelMappingUtils() {}
+
+ // ============================== Business methods ==============================
+
+    /**
+     * Output-side tracing: computes old upstreams whose state a new upstream needs to recover.
+     * Chains consumersOf → oldSubtasksAssignedTo → producersOf (see class javadoc). The result is
+     * ambiguous (multiple new upstreams may claim the same old upstream); TM filters via {@link
+     * #computeNewLocalSubpartitionIndex} at recovery time.
+     *
+     * @param newUpstreamSubtaskIndex index of the new upstream subtask being restored
+     * @param params old/new parallelisms and distribution patterns of the edge
+     * @return the old upstream subtask indices, ascending and without duplicates
+     */
+    public static int[] traceOutputSources(
+            int newUpstreamSubtaskIndex, PointwiseRescaleParams params) {
+        final int newDownstreamParallelism = params.getNewDownParallelism();
+        final int oldUpstreamParallelism = params.getOldUpParallelism();
+        final int oldDownstreamParallelism = params.getOldDownParallelism();
+
+        // Step 1: new downstreams this new upstream sends to.
+        final int[] newDownstreams;
+        if (params.getNewDistributionPattern() == DistributionPattern.POINTWISE) {
+            newDownstreams =
+                    consumersOf(
+                            newUpstreamSubtaskIndex,
+                            params.getNewUpParallelism(),
+                            newDownstreamParallelism);
+        } else {
+            newDownstreams = IntStream.range(0, newDownstreamParallelism).toArray();
+        }
+
+        // Step 2: old downstreams whose state those new downstreams inherit.
+        final Set<Integer> oldDownstreams = new HashSet<>();
+        for (int newDownstream : newDownstreams) {
+            for (int oldDownstream :
+                    oldSubtasksAssignedTo(
+                            newDownstream, oldDownstreamParallelism, newDownstreamParallelism)) {
+                oldDownstreams.add(oldDownstream);
+            }
+        }
+
+        // Step 3: old upstreams that produced for those old downstreams. The pattern check does
+        // not depend on the individual downstream, so it is decided once: with ALL_TO_ALL any
+        // inherited old downstream pulls in every old upstream, none pulls in nothing.
+        if (params.getOldDistributionPattern() == DistributionPattern.ALL_TO_ALL) {
+            return oldDownstreams.isEmpty()
+                    ? new int[0]
+                    : IntStream.range(0, oldUpstreamParallelism).toArray();
+        }
+
+        final Set<Integer> oldUpstreams = new HashSet<>();
+        for (int oldDownstream : oldDownstreams) {
+            for (int oldUpstream :
+                    producersOf(oldDownstream, oldUpstreamParallelism, oldDownstreamParallelism)) {
+                oldUpstreams.add(oldUpstream);
+            }
+        }
+
+        return oldUpstreams.stream().mapToInt(Integer::intValue).sorted().toArray();
+    }
+
+    /**
+     * Input-side ownership: decides, from the point of view of one new downstream, which of its
+     * connected new upstreams is responsible for the data of each old upstream. This is the
+     * inverse of {@link #traceOutputSources}, restricted to the new upstreams connected to the
+     * given downstream. When several new upstreams claim the same old upstream, the first one
+     * (in iteration order) keeps it.
+     *
+     * @param newDownstreamSubtaskIndex index of the new downstream subtask
+     * @param params old/new parallelisms and distribution patterns of the edge
+     * @return array indexed by old upstream subtask; each entry is the owning new upstream, or -1
+     *     if no connected new upstream claims that old upstream
+     */
+    public static int[] resolveInputOwnership(
+            int newDownstreamSubtaskIndex, PointwiseRescaleParams params) {
+        final int newUpParallelism = params.getNewUpParallelism();
+        final boolean newEdgeIsPointwise =
+                params.getNewDistributionPattern() == DistributionPattern.POINTWISE;
+
+        final int[] candidateOwners =
+                newEdgeIsPointwise
+                        ? producersOf(
+                                newDownstreamSubtaskIndex,
+                                newUpParallelism,
+                                params.getNewDownParallelism())
+                        : IntStream.range(0, newUpParallelism).toArray();
+
+        final int[] ownerByOldUpstream = new int[params.getOldUpParallelism()];
+        Arrays.fill(ownerByOldUpstream, -1);
+
+        for (int candidate : candidateOwners) {
+            for (int claimedOldUpstream : traceOutputSources(candidate, params)) {
+                if (ownerByOldUpstream[claimedOldUpstream] >= 0) {
+                    // Already owned by an earlier candidate.
+                    continue;
+                }
+                ownerByOldUpstream[claimedOldUpstream] = candidate;
+            }
+        }
+
+        return ownerByOldUpstream;
+    }
+
+    /**
+     * Topology-aware mapping of every old subtask to one new subtask. When the downstream side is
+     * requested, the upstream assignment is computed first and fed into the downstream
+     * computation so that both sides stay coherent. Placement prefers the range-shift position
+     * ({@code old * newPar / oldPar}) and falls back to another subtask if that position violates
+     * the topology constraint given by the peer set on the other side.
+     *
+     * @param params old/new parallelisms and distribution patterns of the edge
+     * @param isInputSide true = assigning upstream (input side), false = downstream
+     * @return {@code int[oldPar]} where {@code result[oldIdx] = newIdx}
+     */
+    public static int[] computeOldToNewSubtaskAssignment(
+            PointwiseRescaleParams params, boolean isInputSide) {
+        final int[] upstreamAssignment =
+                isInputSide ? null : computeOldToNewSubtaskAssignmentInternal(params, true, null);
+        return computeOldToNewSubtaskAssignmentInternal(params, isInputSide, upstreamAssignment);
+    }
+
+    /**
+     * Number of local input channels a downstream subtask had in the old topology: the old
+     * upstream parallelism for ALL_TO_ALL, otherwise the number of its POINTWISE producers.
+     *
+     * @param downstreamSubtaskIndex index of the downstream subtask in the old topology
+     * @param params old/new parallelisms and distribution patterns of the edge
+     * @return the old local channel count
+     */
+    public static int getOldLocalChannelCount(
+            int downstreamSubtaskIndex, PointwiseRescaleParams params) {
+        final int oldUpParallelism = params.getOldUpParallelism();
+        if (params.getOldDistributionPattern() != DistributionPattern.ALL_TO_ALL) {
+            final int[] oldProducers =
+                    producersOf(
+                            downstreamSubtaskIndex,
+                            oldUpParallelism,
+                            params.getOldDownParallelism());
+            return oldProducers.length;
+        }
+        return oldUpParallelism;
+    }
+
+    /**
+     * Translates a new upstream subtask index into the local input channel index it occupies at
+     * the given new downstream. For ALL_TO_ALL the upstream index is the channel index. A miss is
+     * treated as a programming error and never reported as -1, because input-side recovery only
+     * routes to upstreams confirmed by {@link #resolveInputOwnership}.
+     *
+     * @param newUpstreamSubtaskIndex index of the new upstream subtask
+     * @param newDownstreamSubtaskIndex index of the new downstream subtask owning the channel
+     * @param params old/new parallelisms and distribution patterns of the edge
+     * @return the local input channel index in the new topology
+     * @throws IllegalStateException if the upstream is not a producer of the downstream
+     */
+    public static int computeNewLocalInputChannelIndex(
+            int newUpstreamSubtaskIndex,
+            int newDownstreamSubtaskIndex,
+            PointwiseRescaleParams params) {
+        final DistributionPattern newPattern = params.getNewDistributionPattern();
+        if (newPattern == DistributionPattern.ALL_TO_ALL) {
+            return newUpstreamSubtaskIndex;
+        }
+
+        final int channelIndex =
+                globalSubtaskIndexToLocalIndex(
+                        newDownstreamSubtaskIndex,
+                        newUpstreamSubtaskIndex,
+                        params.getNewUpParallelism(),
+                        params.getNewDownParallelism(),
+                        true,
+                        newPattern);
+        if (channelIndex >= 0) {
+            return channelIndex;
+        }
+
+        throw new IllegalStateException(
+                "newUpstreamSubtaskIndex="
+                        + newUpstreamSubtaskIndex
+                        + " not found in producers of subtask "
+                        + newDownstreamSubtaskIndex
+                        + " with params "
+                        + params);
+    }
+
+    /**
+     * Translates a new downstream subtask index into the local subpartition index it occupies at
+     * the given new upstream (modulo-based). For ALL_TO_ALL the downstream index is the
+     * subpartition index. Returns -1 (discard) in two cases:
+     *
+     * <ul>
+     *   <li>Not connected: downstream is not in {@code consumersOf(upstream)}
+     *   <li>Not primary producer: when {@code upPar > downPar}, multiple upstreams connect to the
+     *       same downstream; only {@code producersOf(D)[0]} writes, others discard to avoid
+     *       duplicates
+     * </ul>
+     *
+     * @param newDownstreamSubtaskIndex index of the new downstream subtask
+     * @param newUpstreamSubtaskIndex index of the new upstream subtask owning the subpartition
+     * @param params old/new parallelisms and distribution patterns of the edge
+     * @return the local subpartition index in the new topology, or -1 to discard
+     */
+    public static int computeNewLocalSubpartitionIndex(
+            int newDownstreamSubtaskIndex,
+            int newUpstreamSubtaskIndex,
+            PointwiseRescaleParams params) {
+        final DistributionPattern newPattern = params.getNewDistributionPattern();
+        if (newPattern == DistributionPattern.ALL_TO_ALL) {
+            return newDownstreamSubtaskIndex;
+        }
+
+        final int newUpParallelism = params.getNewUpParallelism();
+        final int newDownParallelism = params.getNewDownParallelism();
+
+        final int subpartitionIndex =
+                globalSubtaskIndexToLocalIndex(
+                        newUpstreamSubtaskIndex,
+                        newDownstreamSubtaskIndex,
+                        newUpParallelism,
+                        newDownParallelism,
+                        false,
+                        newPattern);
+        if (subpartitionIndex < 0) {
+            // Not connected.
+            return -1;
+        }
+
+        // Deduplication: only the first producer of the downstream writes.
+        final int primaryProducer =
+                producersOf(newDownstreamSubtaskIndex, newUpParallelism, newDownParallelism)[0];
+        return primaryProducer == newUpstreamSubtaskIndex ? subpartitionIndex : -1;
+    }
+
+ // ============================== Internal ==============================
+
+ /**
+ * Assigns every old subtask of one side of the edge ("self") to a new subtask of the same side
+ * such that the chosen new subtask is connected to the new peer that takes over the old
+ * subtask's first old peer on the other side.
+ *
+ * <p>Peers are always derived with {@link #producersOf} / {@link #consumersOf}; the distribution
+ * patterns stored in {@code params} are not consulted here.
+ *
+ * @param params source of the four old/new parallelisms
+ * @param isInputSide true: "self" is the upstream side and peers are consumers; false: "self"
+ * is the downstream side and peers are producers
+ * @param otherOldToNew old-to-new assignment of the other side, or {@code null} to fall back
+ * to {@link #newSubtaskAssignedFrom}
+ * @return {@code int[oldSelfParallelism]} where {@code result[oldSelf] = newSelf}
+ * @throws IllegalStateException if no new subtask is connected to the required new peer
+ */
+ private static int[] computeOldToNewSubtaskAssignmentInternal(
+ PointwiseRescaleParams params, boolean isInputSide, int[] otherOldToNew) {
+ final int oldSelfParallelism, newSelfParallelism, oldOtherParallelism, newOtherParallelism;
+ if (isInputSide) {
+ oldSelfParallelism = params.getOldUpParallelism();
+ newSelfParallelism = params.getNewUpParallelism();
+ oldOtherParallelism = params.getOldDownParallelism();
+ newOtherParallelism = params.getNewDownParallelism();
+ } else {
+ oldSelfParallelism = params.getOldDownParallelism();
+ newSelfParallelism = params.getNewDownParallelism();
+ oldOtherParallelism = params.getOldUpParallelism();
+ newOtherParallelism = params.getNewUpParallelism();
+ }
+ // Peers of every new "self" subtask in the new topology, computed once and reused below.
+ int[][] newOtherPeersCache = new int[newSelfParallelism][];
+ for (int i = 0; i < newSelfParallelism; i++) {
+ newOtherPeersCache[i] =
+ isInputSide
+ ? consumersOf(i, newSelfParallelism, newOtherParallelism)
+ : producersOf(i, newOtherParallelism, newSelfParallelism);
+ }
+ int[] result = new int[oldSelfParallelism];
+ for (int oldSelf = 0; oldSelf < oldSelfParallelism; oldSelf++) {
+ // Peers of the old "self" subtask in the old topology.
+ int[] oldOtherPeers =
+ isInputSide
+ ? consumersOf(oldSelf, oldSelfParallelism, oldOtherParallelism)
+ : producersOf(oldSelf, oldOtherParallelism, oldSelfParallelism);
+ int chosen = -1;
+ if (oldOtherPeers.length > 0) {
+ // Only the first old peer decides which new peer the chosen subtask must reach.
+ int targetNewOther =
+ otherOldToNew != null
+ ? otherOldToNew[oldOtherPeers[0]]
+ : newSubtaskAssignedFrom(oldOtherPeers[0], newOtherParallelism);
+ // Range-shift placement is tried first.
+ int preferred = oldSelf * newSelfParallelism / oldSelfParallelism;
+ if (containsValue(newOtherPeersCache[preferred], targetNewOther)) {
+ chosen = preferred;
+ } else {
+ // Otherwise take the lowest-indexed new subtask connected to the target peer.
+ for (int candidate = 0; candidate < newSelfParallelism; candidate++) {
+ if (containsValue(newOtherPeersCache[candidate], targetNewOther)) {
+ chosen = candidate;
+ break;
+ }
+ }
+ }
+ }
+ if (chosen < 0) {
+ throw new IllegalStateException(
+ "No topology-compatible new subtask found for oldSelf="
+ + oldSelf
+ + " (params="
+ + params
+ + ", isInputSide="
+ + isInputSide
+ + ")");
+ }
+ result[oldSelf] = chosen;
+ }
+ return result;
+ }
+
+ // ============================== Atomic primitives ==============================
+
+    /**
+     * Topology edge: downstream subtask → upstream subtask indices that produce for it.
+     *
+     * <p>With at least as many upstreams as downstreams, the result is the contiguous range
+     * {@code [d * U / D, (d + 1) * U / D)}. With fewer upstreams than downstreams there is exactly
+     * one producer, computed in constant time instead of scanning all upstreams.
+     *
+     * @param downstreamIndex index of the downstream subtask, in {@code [0,
+     *     downstreamParallelism)}
+     * @param upstreamParallelism number of upstream subtasks, must be positive
+     * @param downstreamParallelism number of downstream subtasks, must be positive
+     * @return ascending upstream subtask indices connected to the downstream subtask
+     * @throws IllegalArgumentException if a parallelism is not positive or the index is out of
+     *     range
+     */
+    static int[] producersOf(
+            int downstreamIndex, int upstreamParallelism, int downstreamParallelism) {
+        Preconditions.checkArgument(
+                upstreamParallelism > 0 && downstreamParallelism > 0,
+                "parallelisms must be positive, got upstreamParallelism=%s, downstreamParallelism=%s",
+                upstreamParallelism,
+                downstreamParallelism);
+        Preconditions.checkArgument(
+                downstreamIndex >= 0 && downstreamIndex < downstreamParallelism,
+                "index out of range: downstreamIndex=%s, downstreamParallelism=%s",
+                downstreamIndex,
+                downstreamParallelism);
+        if (upstreamParallelism >= downstreamParallelism) {
+            int start = downstreamIndex * upstreamParallelism / downstreamParallelism;
+            int end = (downstreamIndex + 1) * upstreamParallelism / downstreamParallelism;
+            return IntStream.range(start, end).toArray();
+        }
+        // Upstream u serves downstreams [ceil(u * D / U), ceil((u + 1) * D / U)). For an integer
+        // d this membership is equivalent to u * D <= d * U < (u + 1) * D, i.e.
+        // u = floor(d * U / D).
+        return new int[] {downstreamIndex * upstreamParallelism / downstreamParallelism};
+    }
+
+    /**
+     * Topology edge: upstream subtask → downstream subtask indices it produces for.
+     *
+     * <p>With fewer upstreams than downstreams, the result is the contiguous range {@code
+     * [ceil(u * D / U), ceil((u + 1) * D / U))}. Otherwise there is exactly one consumer, computed
+     * in constant time instead of scanning all downstreams.
+     *
+     * @param upstreamIndex index of the upstream subtask, in {@code [0, upstreamParallelism)}
+     * @param upstreamParallelism number of upstream subtasks, must be positive
+     * @param downstreamParallelism number of downstream subtasks, must be positive
+     * @return ascending downstream subtask indices connected to the upstream subtask
+     * @throws IllegalArgumentException if a parallelism is not positive or the index is out of
+     *     range
+     */
+    static int[] consumersOf(
+            int upstreamIndex, int upstreamParallelism, int downstreamParallelism) {
+        Preconditions.checkArgument(
+                upstreamParallelism > 0 && downstreamParallelism > 0,
+                "parallelisms must be positive, got upstreamParallelism=%s, downstreamParallelism=%s",
+                upstreamParallelism,
+                downstreamParallelism);
+        Preconditions.checkArgument(
+                upstreamIndex >= 0 && upstreamIndex < upstreamParallelism,
+                "index out of range: upstreamIndex=%s, upstreamParallelism=%s",
+                upstreamIndex,
+                upstreamParallelism);
+        if (upstreamParallelism < downstreamParallelism) {
+            int start =
+                    (upstreamIndex * downstreamParallelism + upstreamParallelism - 1)
+                            / upstreamParallelism;
+            int end =
+                    ((upstreamIndex + 1) * downstreamParallelism + upstreamParallelism - 1)
+                            / upstreamParallelism;
+            return IntStream.range(start, end).toArray();
+        }
+        // Downstream d is fed by upstreams [floor(d * U / D), floor((d + 1) * U / D)). For an
+        // integer u this membership is equivalent to d * U < (u + 1) * D <= (d + 1) * U, i.e.
+        // d = ceil((u + 1) * D / U) - 1.
+        int consumer =
+                ((upstreamIndex + 1) * downstreamParallelism + upstreamParallelism - 1)
+                                / upstreamParallelism
+                        - 1;
+        return new int[] {consumer};
+    }
+
+    /**
+     * Cross-generation lookup: the old subtasks whose state a new subtask inherits, as decided by
+     * {@link SubtaskStateMapper#ROUND_ROBIN}.
+     *
+     * @param newSubtaskIndex index of the new subtask
+     * @param oldParallelism parallelism before rescaling
+     * @param newParallelism parallelism after rescaling
+     * @return the old subtask indices reported by the round-robin mapper
+     */
+    public static int[] oldSubtasksAssignedTo(
+            int newSubtaskIndex, int oldParallelism, int newParallelism) {
+        final SubtaskStateMapper roundRobin = SubtaskStateMapper.ROUND_ROBIN;
+        return roundRobin.getOldSubtasks(newSubtaskIndex, oldParallelism, newParallelism);
+    }
+
+    /**
+     * Cross-generation inverse: old subtask → new subtask that inherits it ({@code old % new}).
+     *
+     * @param oldSubtaskIndex index of the old subtask, must not be negative
+     * @param newParallelism parallelism after rescaling, must be positive
+     * @return the new subtask index in {@code [0, newParallelism)}
+     * @throws IllegalArgumentException if the index is negative or the parallelism is not positive
+     */
+    public static int newSubtaskAssignedFrom(int oldSubtaskIndex, int newParallelism) {
+        // Validate like the other primitives: a zero parallelism would otherwise surface as a
+        // bare ArithmeticException and a negative index as a negative (invalid) subtask index.
+        Preconditions.checkArgument(
+                newParallelism > 0, "newParallelism must be positive, got %s", newParallelism);
+        Preconditions.checkArgument(
+                oldSubtaskIndex >= 0,
+                "oldSubtaskIndex must not be negative, got %s",
+                oldSubtaskIndex);
+        return oldSubtaskIndex % newParallelism;
+    }
+
+    /**
+     * Local-to-global: translates a local channel/subpartition index to the global subtask index of
+     * the peer on the other side of the edge. Topology-agnostic — caller passes whichever
+     * parallelisms (old or new) are relevant.
+     *
+     * <p>Input side (downstream→upstream): {@code producersOf(subtask)[localIndex]} — sequential.
+     * Output side (upstream→downstream): modulo-based — finds the consumer {@code D} where
+     * {@code D % numConsumers == localIndex}, matching {@code
+     * VertexInputInfoComputationUtils.computeConsumedSubpartitionRange}.
+     *
+     * <p>When the distribution pattern is {@link DistributionPattern#ALL_TO_ALL}, every subtask
+     * connects to every peer, so the local index already IS the global subtask index.
+     *
+     * @param subtaskIndex index of the subtask owning the local channel/subpartition
+     * @param localIndex local channel index (input side) or local subpartition index (output side)
+     * @param upstreamParallelism number of upstream subtasks
+     * @param downstreamParallelism number of downstream subtasks
+     * @param isInputSide true if {@code subtaskIndex} is a downstream subtask, false if upstream
+     * @param distributionPattern distribution pattern of the edge
+     * @return the global subtask index of the peer
+     * @throws IllegalStateException if no peer exists at {@code localIndex} on a POINTWISE edge
+     */
+    public static int localIndexToGlobalSubtaskIndex(
+            int subtaskIndex,
+            int localIndex,
+            int upstreamParallelism,
+            int downstreamParallelism,
+            boolean isInputSide,
+            DistributionPattern distributionPattern) {
+        if (distributionPattern == DistributionPattern.ALL_TO_ALL) {
+            return localIndex;
+        }
+        if (isInputSide) {
+            int[] producers = producersOf(subtaskIndex, upstreamParallelism, downstreamParallelism);
+            // Fail with the same kind of descriptive error as the output side instead of a bare
+            // ArrayIndexOutOfBoundsException.
+            if (localIndex < 0 || localIndex >= producers.length) {
+                throw new IllegalStateException(
+                        "No producer at local channel index "
+                                + localIndex
+                                + " for downstream "
+                                + subtaskIndex
+                                + " (producers="
+                                + Arrays.toString(producers)
+                                + ")");
+            }
+            return producers[localIndex];
+        }
+        int[] consumers = consumersOf(subtaskIndex, upstreamParallelism, downstreamParallelism);
+        for (int c : consumers) {
+            if (c % consumers.length == localIndex) {
+                return c;
+            }
+        }
+        throw new IllegalStateException(
+                "No consumer with modulo position "
+                        + localIndex
+                        + " for upstream "
+                        + subtaskIndex
+                        + " (consumers="
+                        + Arrays.toString(consumers)
+                        + ")");
+    }
+
+    /**
+     * Global-to-local: the inverse of {@link #localIndexToGlobalSubtaskIndex}. Maps the global
+     * subtask index of a peer back to the local channel index (input side) or local subpartition
+     * index (output side) under which the given subtask addresses it.
+     *
+     * <p>When the distribution pattern is {@link DistributionPattern#ALL_TO_ALL}, the local index
+     * equals the peer subtask index directly.
+     *
+     * @param subtaskIndex index of the subtask owning the local channel/subpartition
+     * @param peerSubtaskIndex global subtask index of the peer on the other side of the edge
+     * @param upstreamParallelism number of upstream subtasks
+     * @param downstreamParallelism number of downstream subtasks
+     * @param isInputSide true if {@code subtaskIndex} is a downstream subtask, false if upstream
+     * @param distributionPattern distribution pattern of the edge
+     * @return the local index, or -1 if not connected
+     */
+    static int globalSubtaskIndexToLocalIndex(
+            int subtaskIndex,
+            int peerSubtaskIndex,
+            int upstreamParallelism,
+            int downstreamParallelism,
+            boolean isInputSide,
+            DistributionPattern distributionPattern) {
+        if (distributionPattern == DistributionPattern.ALL_TO_ALL) {
+            return peerSubtaskIndex;
+        }
+
+        if (!isInputSide) {
+            // Output side: the subpartition position is the consumer index modulo the number of
+            // consumers of this upstream.
+            final int[] consumers =
+                    consumersOf(subtaskIndex, upstreamParallelism, downstreamParallelism);
+            return containsValue(consumers, peerSubtaskIndex)
+                    ? peerSubtaskIndex % consumers.length
+                    : -1;
+        }
+
+        // Input side: the channel position is the position of the peer among the producers.
+        final int[] producers =
+                producersOf(subtaskIndex, upstreamParallelism, downstreamParallelism);
+        int position = 0;
+        while (position < producers.length && producers[position] != peerSubtaskIndex) {
+            position++;
+        }
+        return position < producers.length ? position : -1;
+    }
+
+ // ============================== Utilities ==============================
+
+    /** Returns whether {@code v} occurs anywhere in {@code arr}. */
+    private static boolean containsValue(int[] arr, int v) {
+        return Arrays.stream(arr).anyMatch(element -> element == v);
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseRescaleParams.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseRescaleParams.java
new file mode 100644
index 0000000000000..959ba62f0c6f6
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseRescaleParams.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+
+import java.io.ObjectStreamException;
+import java.io.Serializable;
+import java.util.Objects;
+
+/**
+ * Bundled scalar parameters for POINTWISE-aware rescaling on a single gate or partition.
+ *
+ * <p>The object is an immutable value: it carries the distribution pattern plus the upstream and
+ * downstream parallelism of one edge, each for the old and for the new topology. {@link #EMPTY}
+ * is the shared placeholder instance.
+ */
+public final class PointwiseRescaleParams implements Serializable {
+
+    private static final long serialVersionUID = 1L;
+
+    /** Placeholder with ALL_TO_ALL on both sides and every parallelism set to 0. */
+    public static final PointwiseRescaleParams EMPTY =
+            new PointwiseRescaleParams(
+                    DistributionPattern.ALL_TO_ALL, DistributionPattern.ALL_TO_ALL, 0, 0, 0, 0);
+
+    private final DistributionPattern oldDistributionPattern;
+    private final DistributionPattern newDistributionPattern;
+    private final int oldUpParallelism;
+    private final int oldDownParallelism;
+    private final int newUpParallelism;
+    private final int newDownParallelism;
+
+    /**
+     * Creates a parameter bundle.
+     *
+     * @param oldDistributionPattern distribution pattern of the edge before rescaling
+     * @param newDistributionPattern distribution pattern of the edge after rescaling
+     * @param oldUpParallelism upstream parallelism before rescaling
+     * @param oldDownParallelism downstream parallelism before rescaling
+     * @param newUpParallelism upstream parallelism after rescaling
+     * @param newDownParallelism downstream parallelism after rescaling
+     */
+    public PointwiseRescaleParams(
+            DistributionPattern oldDistributionPattern,
+            DistributionPattern newDistributionPattern,
+            int oldUpParallelism,
+            int oldDownParallelism,
+            int newUpParallelism,
+            int newDownParallelism) {
+        this.oldDistributionPattern = oldDistributionPattern;
+        this.newDistributionPattern = newDistributionPattern;
+        this.oldUpParallelism = oldUpParallelism;
+        this.oldDownParallelism = oldDownParallelism;
+        this.newUpParallelism = newUpParallelism;
+        this.newDownParallelism = newDownParallelism;
+    }
+
+    public DistributionPattern getOldDistributionPattern() {
+        return oldDistributionPattern;
+    }
+
+    public DistributionPattern getNewDistributionPattern() {
+        return newDistributionPattern;
+    }
+
+    public int getOldUpParallelism() {
+        return oldUpParallelism;
+    }
+
+    public int getOldDownParallelism() {
+        return oldDownParallelism;
+    }
+
+    public int getNewUpParallelism() {
+        return newUpParallelism;
+    }
+
+    public int getNewDownParallelism() {
+        return newDownParallelism;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        PointwiseRescaleParams that = (PointwiseRescaleParams) o;
+        return oldUpParallelism == that.oldUpParallelism
+                && oldDownParallelism == that.oldDownParallelism
+                && newUpParallelism == that.newUpParallelism
+                && newDownParallelism == that.newDownParallelism
+                && oldDistributionPattern == that.oldDistributionPattern
+                && newDistributionPattern == that.newDistributionPattern;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(
+                oldDistributionPattern,
+                newDistributionPattern,
+                oldUpParallelism,
+                oldDownParallelism,
+                newUpParallelism,
+                newDownParallelism);
+    }
+
+    /**
+     * Canonicalizes a deserialized placeholder so that a round-tripped {@link #EMPTY} resolves to
+     * the {@link #EMPTY} instance itself.
+     *
+     * <p>Only an instance that is {@link #equals(Object) equal} to {@link #EMPTY} is replaced.
+     * Looking at the four parallelisms alone would also replace an instance whose distribution
+     * patterns differ from {@link #EMPTY}, changing its value during deserialization.
+     *
+     * @return {@link #EMPTY} if this instance equals it, otherwise this instance
+     */
+    private Object readResolve() throws ObjectStreamException {
+        return EMPTY.equals(this) ? EMPTY : this;
+    }
+
+    @Override
+    public String toString() {
+        return "PointwiseRescaleParams{"
+                + "oldPattern="
+                + oldDistributionPattern
+                + ", newPattern="
+                + newDistributionPattern
+                + ", oldUp="
+                + oldUpParallelism
+                + ", oldDown="
+                + oldDownParallelism
+                + ", newUp="
+                + newUpParallelism
+                + ", newDown="
+                + newDownParallelism
+                + '}';
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 6bc30d488c4f2..59d89cd81a8ce 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -48,6 +48,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import javax.annotation.Nullable;
+
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
@@ -76,6 +78,7 @@ public class StateAssignmentOperation {
private final long restoreCheckpointId;
private final boolean allowNonRestoredState;
private final boolean recoverOutputOnDownstreamTask;
+ @Nullable private final EdgeDistributionPatternSnapshot oldEdgePatterns;
/** The state assignments for each ExecutionJobVertex that will be filled in multiple passes. */
private final Map vertexAssignments;
@@ -93,12 +96,29 @@ public StateAssignmentOperation(
Map operatorStates,
boolean allowNonRestoredState,
boolean recoverOutputOnDownstreamTask) {
+ this(
+ restoreCheckpointId,
+ tasks,
+ operatorStates,
+ allowNonRestoredState,
+ recoverOutputOnDownstreamTask,
+ null);
+ }
+
+ /**
+ * Creates a state assignment operation that additionally carries a snapshot of the old edge
+ * distribution patterns.
+ *
+ * @param restoreCheckpointId id of the checkpoint being restored
+ * @param tasks tasks to assign state to; must not be null
+ * @param operatorStates operator states to distribute; must not be null
+ * @param allowNonRestoredState stored as given
+ * @param recoverOutputOnDownstreamTask stored as given
+ * @param oldEdgePatterns snapshot of the old edge distribution patterns, or {@code null} if
+ * none is available
+ */
+ public StateAssignmentOperation(
+ long restoreCheckpointId,
+ Set tasks,
+ Map operatorStates,
+ boolean allowNonRestoredState,
+ boolean recoverOutputOnDownstreamTask,
+ @Nullable EdgeDistributionPatternSnapshot oldEdgePatterns) {
this.restoreCheckpointId = restoreCheckpointId;
this.tasks = Preconditions.checkNotNull(tasks);
this.operatorStates = Preconditions.checkNotNull(operatorStates);
this.allowNonRestoredState = allowNonRestoredState;
this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask;
+ // May be null; kept as-is without a null check.
+ this.oldEdgePatterns = oldEdgePatterns;
this.vertexAssignments = CollectionUtil.newHashMapWithExpectedSize(tasks.size());
}
@@ -148,7 +168,8 @@ private void buildStateAssignments() {
operatorStates,
consumerAssignment,
vertexAssignments,
- recoverOutputOnDownstreamTask);
+ recoverOutputOnDownstreamTask,
+ oldEdgePatterns);
vertexAssignments.put(executionJobVertex, stateAssignment);
for (final IntermediateResult producedDataSet : executionJobVertex.getInputs()) {
consumerAssignment.put(producedDataSet.getId(), stateAssignment);
@@ -408,7 +429,8 @@ public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment)
final List outputs =
executionJobVertex.getJobVertex().getProducedDataSets();
- if (outputState.getParallelism() == executionJobVertex.getParallelism()) {
+ if (outputState.getParallelism() == executionJobVertex.getParallelism()
+ && !assignment.hasPointwiseEdge(assignment.getDownstreamAssignments(), false)) {
assignment.resultSubpartitionStates.putAll(
toInstanceMap(assignment.outputOperatorID, outputOperatorState));
return;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
index a6db5e4837f85..372ea7c7d45a3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
@@ -26,6 +26,7 @@
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
@@ -112,6 +113,7 @@ class TaskStateAssignment {
outputRescalingDescriptors = new HashMap<>();
private final boolean recoverOutputOnDownstreamTask;
+ @Nullable private final EdgeDistributionPatternSnapshot oldEdgePatterns;
@Nullable private TaskStateAssignment[] downstreamAssignments;
@Nullable private TaskStateAssignment[] upstreamAssignments;
@@ -127,6 +129,22 @@ public TaskStateAssignment(
Map consumerAssignment,
Map vertexAssignments,
boolean recoverOutputOnDownstreamTask) {
+ this(
+ executionJobVertex,
+ oldState,
+ consumerAssignment,
+ vertexAssignments,
+ recoverOutputOnDownstreamTask,
+ null);
+ }
+
+ public TaskStateAssignment(
+ ExecutionJobVertex executionJobVertex,
+ Map oldState,
+ Map consumerAssignment,
+ Map vertexAssignments,
+ boolean recoverOutputOnDownstreamTask,
+ @Nullable EdgeDistributionPatternSnapshot oldEdgePatterns) {
this.executionJobVertex = executionJobVertex;
this.oldState = oldState;
@@ -144,6 +162,7 @@ public TaskStateAssignment(
this.consumerAssignment = checkNotNull(consumerAssignment);
this.vertexAssignments = checkNotNull(vertexAssignments);
this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask;
+ this.oldEdgePatterns = oldEdgePatterns;
final int expectedNumberOfSubtasks = newParallelism * oldState.size();
subManagedOperatorState =
@@ -370,7 +389,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
// no state on input and output, especially for any aligned checkpoint
if (subtaskGateOrPartitionMappings.isEmpty()
- && Arrays.stream(rescaledChannelsMappings).allMatch(Objects::isNull)) {
+ && Arrays.stream(rescaledChannelsMappings).allMatch(Objects::isNull)
+ && !hasPointwiseEdge(connectedAssignments, isInput)) {
return InflightDataRescalingDescriptor.NO_RESCALE;
}
@@ -411,6 +431,25 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
}
TaskStateAssignment connectedAssignment =
connectedAssignments[partition];
+
+ DistributionPattern oldPattern =
+ resolveOldDistributionPattern(
+ isInput, partition, connectedAssignment);
+ DistributionPattern newPattern =
+ resolveNewDistributionPattern(isInput, partition);
+
+ if (oldEdgePatterns != null
+ && (oldPattern == DistributionPattern.POINTWISE
+ || newPattern == DistributionPattern.POINTWISE)) {
+ return createPointwiseRescalingDescriptor(
+ instanceID,
+ partition,
+ isInput,
+ connectedAssignment,
+ oldPattern,
+ newPattern);
+ }
+
SubtasksRescaleMapping rescaleMapping =
Optional.ofNullable(rescaledChannelsMappings[partition])
.orElseGet(
@@ -508,9 +547,38 @@ public SubtasksRescaleMapping getOutputMapping(int partitionIndex) {
.get(gateIndex)
.getUpstreamSubtaskStateMapper(),
"No channel rescaler found during rescaling of channel state");
- final RescaleMappings mapping =
- mapper.getNewToOldSubtasksMapping(
- oldState.get(outputOperatorID).getParallelism(), newParallelism);
+
+ final RescaleMappings mapping;
+ if (mapper == SubtaskStateMapper.POINTWISE_UPSTREAM) {
+ PointwiseRescaleParams params =
+ buildOutputPointwiseRescaleParams(partitionIndex, downstreamAssignment);
+ int oldParallelism = oldState.get(outputOperatorID).getParallelism();
+ boolean isIdentity =
+ oldParallelism == newParallelism
+ && params.getOldUpParallelism() == params.getNewUpParallelism()
+ && params.getOldDownParallelism() == params.getNewDownParallelism()
+ && params.getOldDistributionPattern()
+ == params.getNewDistributionPattern();
+ if (isIdentity) {
+ mapping =
+ RescaleMappings.of(
+ IntStream.range(0, newParallelism).mapToObj(idx -> new int[] {idx}),
+ oldParallelism);
+ } else {
+ mapping =
+ RescaleMappings.of(
+ IntStream.range(0, newParallelism)
+ .mapToObj(
+ idx ->
+ PointwiseChannelMappingUtils
+ .traceOutputSources(idx, params)),
+ oldParallelism);
+ }
+ } else {
+ mapping =
+ mapper.getNewToOldSubtasksMapping(
+ oldState.get(outputOperatorID).getParallelism(), newParallelism);
+ }
return outputSubtaskMappings.compute(
partitionIndex,
(idx, oldMapping) ->
@@ -581,10 +649,156 @@ public boolean hasInFlightDataForResultPartition(int partitionIndex) {
return false;
}
+    /**
+     * Tells whether at least one connected edge for which {@code hasInFlightData} holds is
+     * POINTWISE in the old or in the new topology.
+     *
+     * @param connectedAssignments assignments on the other end of each gate (input) or partition
+     *     (output), indexed by gate/partition
+     * @param isInput {@code true} to inspect input gates, {@code false} for result partitions
+     * @return {@code false} when no old edge pattern snapshot is available or no such edge exists
+     */
+    boolean hasPointwiseEdge(TaskStateAssignment[] connectedAssignments, boolean isInput) {
+        if (oldEdgePatterns == null) {
+            return false;
+        }
+        int edgeIdx = -1;
+        for (TaskStateAssignment peer : connectedAssignments) {
+            edgeIdx++;
+            if (hasInFlightData(isInput, edgeIdx)) {
+                final DistributionPattern before =
+                        resolveOldDistributionPattern(isInput, edgeIdx, peer);
+                final DistributionPattern after = resolveNewDistributionPattern(isInput, edgeIdx);
+                if (before == DistributionPattern.POINTWISE
+                        || after == DistributionPattern.POINTWISE) {
+                    return true;
+                }
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Looks up the distribution pattern an edge had in the restored checkpoint.
+     *
+     * <p>The snapshot is keyed by the producing operator and its partition index. For an input
+     * gate the partition index is therefore the position of this assignment among the downstream
+     * assignments of {@code connectedAssignment}. Every miss (no snapshot, this assignment not
+     * found among the downstream assignments, no recorded pattern) falls back to {@link
+     * DistributionPattern#ALL_TO_ALL}.
+     *
+     * @param isInput {@code true} if {@code gateOrPartitionIdx} denotes an input gate
+     * @param gateOrPartitionIdx gate or partition index; only used directly on the output side
+     * @param connectedAssignment assignment on the other end of the edge
+     * @return the recorded old pattern, or ALL_TO_ALL if none can be determined
+     */
+    private DistributionPattern resolveOldDistributionPattern(
+            boolean isInput, int gateOrPartitionIdx, TaskStateAssignment connectedAssignment) {
+        if (oldEdgePatterns == null) {
+            return DistributionPattern.ALL_TO_ALL;
+        }
+
+        final DistributionPattern recorded;
+        if (!isInput) {
+            recorded = oldEdgePatterns.getOutputPattern(outputOperatorID, gateOrPartitionIdx);
+        } else {
+            final int indexInProducer =
+                    Arrays.asList(connectedAssignment.getDownstreamAssignments()).indexOf(this);
+            if (indexInProducer < 0) {
+                return DistributionPattern.ALL_TO_ALL;
+            }
+            recorded =
+                    oldEdgePatterns.getOutputPattern(
+                            connectedAssignment.outputOperatorID, indexInProducer);
+        }
+        return recorded == null ? DistributionPattern.ALL_TO_ALL : recorded;
+    }
+
+    /**
+     * Reads the consuming distribution pattern of an edge from the current execution job vertex.
+     *
+     * @param isInput {@code true} to index into the vertex inputs, {@code false} to index into
+     *     its produced data sets
+     * @param gateOrPartitionIdx index of the input gate or of the produced data set
+     * @return the consuming distribution pattern of the selected intermediate result
+     */
+    private DistributionPattern resolveNewDistributionPattern(
+            boolean isInput, int gateOrPartitionIdx) {
+        return isInput
+                ? executionJobVertex
+                        .getInputs()
+                        .get(gateOrPartitionIdx)
+                        .getConsumingDistributionPattern()
+                : executionJobVertex.getProducedDataSets()[gateOrPartitionIdx]
+                        .getConsumingDistributionPattern();
+    }
+
+    /**
+     * Builds the rescaling descriptor of one gate or partition from the POINTWISE parameters of
+     * its edge.
+     *
+     * <p>The descriptor always uses {@code RescaleMappings.SYMMETRIC_IDENTITY} as channel mapping
+     * and an empty ambiguous-subtask set; the actual routing information travels in the attached
+     * {@link PointwiseRescaleParams}. The mapping type is IDENTITY when there are no old subtask
+     * instances, or when neither the parallelisms nor the pattern changed.
+     *
+     * @param instanceID new operator instance the descriptor is created for
+     * @param partition gate or partition index, only used for logging
+     * @param isInput {@code true} if this assignment is the consumer of the edge
+     * @param connectedAssignment assignment on the other end of the edge
+     * @param oldPattern distribution pattern of the edge in the restored checkpoint
+     * @param newPattern distribution pattern of the edge in the new topology
+     * @return the logged descriptor
+     */
+    private InflightDataGateOrPartitionRescalingDescriptor createPointwiseRescalingDescriptor(
+            OperatorInstanceID instanceID,
+            int partition,
+            boolean isInput,
+            TaskStateAssignment connectedAssignment,
+            DistributionPattern oldPattern,
+            DistributionPattern newPattern) {
+        // Normalize the two ends of the edge so that the parallelisms are read uniformly.
+        final TaskStateAssignment producerSide = isInput ? connectedAssignment : this;
+        final TaskStateAssignment consumerSide = isInput ? this : connectedAssignment;
+
+        final int oldUp =
+                producerSide.oldState.get(producerSide.outputOperatorID).getParallelism();
+        final int oldDown =
+                consumerSide.oldState.get(consumerSide.inputOperatorID).getParallelism();
+        final int newUp = producerSide.newParallelism;
+        final int newDown = consumerSide.newParallelism;
+
+        final PointwiseRescaleParams rescaleParams =
+                new PointwiseRescaleParams(
+                        oldPattern, newPattern, oldUp, oldDown, newUp, newDown);
+
+        final int[] oldSubtasks =
+                computePointwiseOldSubtaskInstances(
+                        instanceID.getSubtaskId(), isInput, rescaleParams);
+
+        final boolean topologyUnchanged =
+                oldUp == newUp && oldDown == newDown && oldPattern == newPattern;
+        final MappingType mappingType;
+        if (oldSubtasks.length == 0 || topologyUnchanged) {
+            mappingType = MappingType.IDENTITY;
+        } else {
+            mappingType = MappingType.RESCALING;
+        }
+
+        return log(
+                new InflightDataGateOrPartitionRescalingDescriptor(
+                        oldSubtasks,
+                        RescaleMappings.SYMMETRIC_IDENTITY,
+                        emptySet(),
+                        mappingType,
+                        rescaleParams),
+                instanceID.getSubtaskId(),
+                partition);
+    }
+
+    /**
+     * Determines which old subtasks feed the given new subtask.
+     *
+     * <p>The input side uses the ROUND_ROBIN mapper over the old and new downstream parallelism;
+     * the output side delegates to {@code PointwiseChannelMappingUtils.traceOutputSources}.
+     *
+     * @param newSubtaskIdx index of the new subtask
+     * @param isInput {@code true} for an input gate, {@code false} for a result partition
+     * @param params POINTWISE parameters of the edge
+     * @return indexes of the old subtasks
+     */
+    private static int[] computePointwiseOldSubtaskInstances(
+            int newSubtaskIdx, boolean isInput, PointwiseRescaleParams params) {
+        if (!isInput) {
+            return PointwiseChannelMappingUtils.traceOutputSources(newSubtaskIdx, params);
+        }
+        final int oldDownstreamParallelism = params.getOldDownParallelism();
+        final int newDownstreamParallelism = params.getNewDownParallelism();
+        return SubtaskStateMapper.ROUND_ROBIN.getOldSubtasks(
+                newSubtaskIdx, oldDownstreamParallelism, newDownstreamParallelism);
+    }
+
+    /**
+     * Assembles the POINTWISE parameters of an output edge, with this assignment as the upstream
+     * side and {@code downstreamAssignment} as the downstream side.
+     *
+     * @param partitionIndex index of the result partition
+     * @param downstreamAssignment assignment consuming the partition
+     * @return old/new distribution patterns and old/new parallelisms of both ends
+     */
+    private PointwiseRescaleParams buildOutputPointwiseRescaleParams(
+            int partitionIndex, TaskStateAssignment downstreamAssignment) {
+        // Arguments are evaluated left to right: patterns first, then the parallelisms.
+        return new PointwiseRescaleParams(
+                resolveOldDistributionPattern(false, partitionIndex, downstreamAssignment),
+                resolveNewDistributionPattern(false, partitionIndex),
+                oldState.get(outputOperatorID).getParallelism(),
+                downstreamAssignment
+                        .oldState
+                        .get(downstreamAssignment.inputOperatorID)
+                        .getParallelism(),
+                newParallelism,
+                downstreamAssignment.newParallelism);
+    }
+
void distributeOutputBuffersToDownstream() {
for (Map.Entry> entry :
resultSubpartitionStates.entrySet()) {
OperatorInstanceID operatorInstanceID = entry.getKey();
+ int newUpstreamSubtaskIndex = operatorInstanceID.getSubtaskId();
List stateHandles = entry.getValue();
ResultSubpartitionDistributor distributor =
@@ -592,23 +806,49 @@ void distributeOutputBuffersToDownstream() {
getOutputRescalingDescriptor(operatorInstanceID));
for (final ResultSubpartitionStateHandle stateHandle : stateHandles) {
- distributeOutputBufferToDownstream(stateHandle, distributor);
+ ResultSubpartitionInfo info = stateHandle.getInfo();
+ int partitionIdx = info.getPartitionIdx();
+ TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIdx];
+
+ DistributionPattern oldPattern =
+ resolveOldDistributionPattern(false, partitionIdx, downstreamAssignment);
+ DistributionPattern newPattern = resolveNewDistributionPattern(false, partitionIdx);
+ if (oldEdgePatterns != null
+ && (oldPattern == DistributionPattern.POINTWISE
+ || newPattern == DistributionPattern.POINTWISE)) {
+ // PW path (aligned with createGateOrPartitionRescalingDescriptors):
+ // uses PointwiseChannelMappingUtils topology computation for routing.
+ // Dedup across new upstream subtasks is handled by
+ // computeNewLocalSubpartitionIndex returning -1 for non-primary producers.
+ distributePointwiseOutputBufferToDownstream(
+ stateHandle,
+ partitionIdx,
+ downstreamAssignment,
+ newUpstreamSubtaskIndex);
+ } else {
+ // A2A→A2A path: uses ResultSubpartitionDistributor with the descriptor's
+ // channel mapping (old subpartition → new subpartition). No dedup needed
+ // because MappingBasedRepartitioner assigns each old subtask's handles
+ // to exactly one new subtask.
+ distributeA2AOutputBufferToDownstream(
+ stateHandle, distributor, partitionIdx, downstreamAssignment);
+ }
}
}
}
- private void distributeOutputBufferToDownstream(
- ResultSubpartitionStateHandle stateHandle, ResultSubpartitionDistributor distributor) {
- // From the perspective of the downstream task, the oldUpstreamSubtaskIndex will be
- // treated as the inputChannelIdx, and the info.getSubPartitionIdx() will be treated
- // as the oldDownstreamSubtaskIndex.
- int oldUpstreamSubtaskIndex = stateHandle.getSubtaskIndex();
+ private void distributeA2AOutputBufferToDownstream(
+ ResultSubpartitionStateHandle stateHandle,
+ ResultSubpartitionDistributor distributor,
+ int partitionIdx,
+ TaskStateAssignment downstreamAssignment) {
ResultSubpartitionInfo info = stateHandle.getInfo();
- int partitionIdx = info.getPartitionIdx();
+ // A2A path: oldUpstreamSubtaskIndex is the inputChannelIdx, and
+ // info.getSubPartitionIdx() is the oldDownstreamSubtaskIndex.
+ int oldUpstreamSubtaskIndex = stateHandle.getSubtaskIndex();
int oldDownstreamSubtaskIndex = info.getSubPartitionIdx();
int gateIdxResultPartition = findInputGateIdxForResultPartition(partitionIdx);
- TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIdx];
List mappedSubpartitions = distributor.getMappedSubpartitions(info);
for (final ResultSubpartitionInfo mappedSubpartition : mappedSubpartitions) {
@@ -636,6 +876,84 @@ private void distributeOutputBufferToDownstream(
}
}
+    /**
+     * Routes one recovered output-buffer handle of a POINTWISE edge to the downstream subtask
+     * that takes it over, re-expressed as an input channel state handle.
+     *
+     * <p>Steps: translate the old local subpartition into the old global downstream subtask,
+     * pick the new downstream subtask for it, check that the given new upstream subtask has a
+     * local subpartition towards that target, and compute the old local channel index of the
+     * producer as seen by the old downstream subtask. The handle is skipped without any effect
+     * when one of these lookups fails.
+     *
+     * @param stateHandle recovered result subpartition state of an old upstream subtask
+     * @param partitionIdx index of the result partition the handle belongs to
+     * @param downstreamAssignment assignment consuming the partition
+     * @param newUpstreamSubtaskIndex new upstream subtask the handle is currently assigned to
+     */
+    private void distributePointwiseOutputBufferToDownstream(
+            ResultSubpartitionStateHandle stateHandle,
+            int partitionIdx,
+            TaskStateAssignment downstreamAssignment,
+            int newUpstreamSubtaskIndex) {
+        final int oldProducerIdx = stateHandle.getSubtaskIndex();
+        final int oldLocalSubpartition = stateHandle.getInfo().getSubPartitionIdx();
+
+        final PointwiseRescaleParams rescaleParams =
+                buildOutputPointwiseRescaleParams(partitionIdx, downstreamAssignment);
+        final int oldUp = rescaleParams.getOldUpParallelism();
+        final int oldDown = rescaleParams.getOldDownParallelism();
+        final DistributionPattern previousPattern = rescaleParams.getOldDistributionPattern();
+
+        int oldConsumerIdx;
+        try {
+            oldConsumerIdx =
+                    PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                            oldProducerIdx,
+                            oldLocalSubpartition,
+                            oldUp,
+                            oldDown,
+                            false,
+                            previousPattern);
+        } catch (IllegalStateException e) {
+            // NOTE(review): the handle is dropped silently when the old local index cannot be
+            // resolved - confirm that this is intended and cannot lose in-flight data.
+            return;
+        }
+
+        final int newConsumerIdx =
+                PointwiseChannelMappingUtils.newSubtaskAssignedFrom(
+                        oldConsumerIdx, rescaleParams.getNewDownParallelism());
+
+        // A negative local subpartition means this new upstream subtask must not emit the handle.
+        final int newLocalSubpartition =
+                PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(
+                        newConsumerIdx, newUpstreamSubtaskIndex, rescaleParams);
+        if (newLocalSubpartition < 0) {
+            return;
+        }
+
+        final int oldLocalChannel =
+                PointwiseChannelMappingUtils.globalSubtaskIndexToLocalIndex(
+                        oldConsumerIdx, oldProducerIdx, oldUp, oldDown, true, previousPattern);
+        if (oldLocalChannel < 0) {
+            return;
+        }
+
+        final int inputGateIdx = findInputGateIdxForResultPartition(partitionIdx);
+        final OperatorInstanceID targetInstance =
+                new OperatorInstanceID(newConsumerIdx, downstreamAssignment.inputOperatorID);
+        final InputChannelStateHandle reroutedHandle =
+                new InputChannelStateHandle(
+                        oldConsumerIdx,
+                        new InputChannelInfo(inputGateIdx, oldLocalChannel),
+                        stateHandle.getDelegate(),
+                        stateHandle.getOffsets(),
+                        stateHandle.getStateSize());
+
+        downstreamAssignment
+                .upstreamOutputBufferStates
+                .computeIfAbsent(targetInstance, k -> new ArrayList<>())
+                .add(reroutedHandle);
+    }
+
private int findInputGateIdxForResultPartition(int partitionIndex) {
// Check downstream input state for this partition
TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIndex];
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java
index b257c3b40544e..c1c2e2f50cc6e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java
@@ -21,6 +21,9 @@
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.PointwiseChannelMappingUtils;
+import org.apache.flink.runtime.checkpoint.PointwiseRescaleParams;
import org.apache.flink.runtime.checkpoint.RescaleMappings;
import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
@@ -194,31 +197,68 @@ private static GateFilterHandler createGateHandler(
Map> gateVirtualChannels = new HashMap<>();
- for (int oldSubtaskIndex : oldSubtaskIndexes) {
- int numChannels = gate.getNumberOfInputChannels();
- int[] oldChannelIndexes = getOldChannelIndexes(channelMapping, numChannels);
-
- for (int oldChannelIndex : oldChannelIndexes) {
- SubtaskConnectionDescriptor key =
- new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex);
+ InflightDataGateOrPartitionRescalingDescriptor gateDesc =
+ rescalingDescriptor.getGateOrPartitionDescriptor(gateIndex);
+
+ if (gateDesc.isPointwiseRescaling()) {
+ PointwiseRescaleParams params = gateDesc.getPointwiseRescaleParams();
+ for (int oldSubtaskIndex : oldSubtaskIndexes) {
+ int numLocal =
+ PointwiseChannelMappingUtils.getOldLocalChannelCount(
+ oldSubtaskIndex, params);
+ for (int localIdx = 0; localIdx < numLocal; localIdx++) {
+ int oldChannelIndex =
+ PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+ oldSubtaskIndex,
+ localIdx,
+ params.getOldUpParallelism(),
+ params.getOldDownParallelism(),
+ true,
+ params.getOldDistributionPattern());
+ SubtaskConnectionDescriptor key =
+ new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex);
+ if (gateVirtualChannels.containsKey(key)) {
+ continue;
+ }
- if (gateVirtualChannels.containsKey(key)) {
- continue;
+ // Pointwise channels are always recovered exactly, so a pass-through filter suffices
+ RecordFilter recordFilter =
+ VirtualChannelRecordFilterFactory.createPassThroughFilter();
+ RecordDeserializer> deserializer =
+ createDeserializer(filterContext.getTmpDirectories());
+ VirtualChannel vc = new VirtualChannel<>(deserializer, recordFilter);
+ gateVirtualChannels.put(key, vc);
}
+ }
+ } else {
+ for (int oldSubtaskIndex : oldSubtaskIndexes) {
+ int numChannels = gate.getNumberOfInputChannels();
+ int[] oldChannelIndexes = getOldChannelIndexes(channelMapping, numChannels);
+
+ for (int oldChannelIndex : oldChannelIndexes) {
+ SubtaskConnectionDescriptor key =
+ new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex);
- // Only ambiguous channels need actual filtering; non-ambiguous ones pass through
- boolean isAmbiguous = rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex);
+ if (gateVirtualChannels.containsKey(key)) {
+ continue;
+ }
- RecordFilter recordFilter =
- isAmbiguous
- ? filterFactory.createFilter()
- : VirtualChannelRecordFilterFactory.createPassThroughFilter();
+ // Only ambiguous channels need actual filtering; non-ambiguous ones pass
+ // through
+ boolean isAmbiguous =
+ rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex);
- RecordDeserializer> deserializer =
- createDeserializer(filterContext.getTmpDirectories());
+ RecordFilter recordFilter =
+ isAmbiguous
+ ? filterFactory.createFilter()
+ : VirtualChannelRecordFilterFactory.createPassThroughFilter();
- VirtualChannel vc = new VirtualChannel<>(deserializer, recordFilter);
- gateVirtualChannels.put(key, vc);
+ RecordDeserializer> deserializer =
+ createDeserializer(filterContext.getTmpDirectories());
+
+ VirtualChannel vc = new VirtualChannel<>(deserializer, recordFilter);
+ gateVirtualChannels.put(key, vc);
+ }
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
index ca01ff37bd369..2a8a858bf1cab 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
@@ -21,6 +21,8 @@
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.PointwiseChannelMappingUtils;
+import org.apache.flink.runtime.checkpoint.PointwiseRescaleParams;
import org.apache.flink.runtime.checkpoint.RescaleMappings;
import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
@@ -78,9 +80,11 @@ class InputChannelRecoveredStateHandler
private final InputGate[] inputGates;
private final InflightDataRescalingDescriptor channelMapping;
+ private final int subtaskIndex;
private final Map rescaledChannels = new HashMap<>();
private final Map oldToNewMappings = new HashMap<>();
+ private final Map pointwiseUpstreamAssignmentCache = new HashMap<>();
/**
* Optional filtering handler for filtering recovered buffers. When non-null, filtering is
@@ -112,10 +116,12 @@ class InputChannelRecoveredStateHandler
InputGate[] inputGates,
InflightDataRescalingDescriptor channelMapping,
@Nullable ChannelStateFilteringHandler filteringHandler,
- int memorySegmentSize) {
+ int memorySegmentSize,
+ int subtaskIndex) {
this.inputGates = inputGates;
this.channelMapping = channelMapping;
this.filteringHandler = filteringHandler;
+ this.subtaskIndex = subtaskIndex;
checkArgument(
memorySegmentSize > 0, "memorySegmentSize must be positive: %s", memorySegmentSize);
this.memorySegmentSize = memorySegmentSize;
@@ -185,18 +191,25 @@ public void recover(
Buffer buffer = bufferWithContext.context;
try {
if (buffer.readableBytes() > 0) {
- RecoveredInputChannel channel = getMappedChannels(channelInfo);
-
- if (filteringHandler != null) {
- recoverWithFiltering(
- channel, channelInfo, oldSubtaskIndex, buffer.retainBuffer());
+ if (isPointwiseRescaling(channelInfo.getGateIdx())) {
+ recoverPointwise(channelInfo, oldSubtaskIndex, buffer);
} else {
- channel.onRecoveredStateBuffer(
- EventSerializer.toBuffer(
- new SubtaskConnectionDescriptor(
- oldSubtaskIndex, channelInfo.getInputChannelIdx()),
- false));
- channel.onRecoveredStateBuffer(buffer.retainBuffer());
+ RecoveredInputChannel channel = getMappedChannels(channelInfo);
+ if (filteringHandler != null) {
+ recoverWithFiltering(
+ channel,
+ channelInfo.getGateIdx(),
+ channelInfo.getInputChannelIdx(),
+ oldSubtaskIndex,
+ buffer.retainBuffer());
+ } else {
+ channel.onRecoveredStateBuffer(
+ EventSerializer.toBuffer(
+ new SubtaskConnectionDescriptor(
+ oldSubtaskIndex, channelInfo.getInputChannelIdx()),
+ false));
+ channel.onRecoveredStateBuffer(buffer.retainBuffer());
+ }
}
}
} finally {
@@ -204,18 +217,64 @@ public void recover(
}
}
+    /**
+     * Recovers one buffer of a gate that is rescaled with POINTWISE parameters.
+     *
+     * <p>The old local channel index is first translated into the old global upstream subtask.
+     * The per-gate ownership table (computed once and cached) then yields the new upstream
+     * subtask, from which the new local input channel is derived. The buffer is handed to that
+     * channel either directly, preceded by a {@link SubtaskConnectionDescriptor} event, or
+     * through the filtering handler when one is configured.
+     *
+     * @param channelInfo old gate and local channel the buffer was stored under
+     * @param oldSubtaskIndex old subtask the buffer belonged to
+     * @param buffer recovered buffer; retained before it is passed on
+     */
+    private void recoverPointwise(InputChannelInfo channelInfo, int oldSubtaskIndex, Buffer buffer)
+            throws IOException, InterruptedException {
+        final int gateIdx = channelInfo.getGateIdx();
+        final PointwiseRescaleParams rescaleParams =
+                channelMapping.getGateOrPartitionDescriptor(gateIdx).getPointwiseRescaleParams();
+
+        final int oldProducerIdx =
+                PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                        oldSubtaskIndex,
+                        channelInfo.getInputChannelIdx(),
+                        rescaleParams.getOldUpParallelism(),
+                        rescaleParams.getOldDownParallelism(),
+                        true,
+                        rescaleParams.getOldDistributionPattern());
+
+        // Ownership table: old upstream subtask -> new upstream subtask, one table per gate.
+        int[] ownershipByOldProducer =
+                pointwiseUpstreamAssignmentCache.computeIfAbsent(
+                        gateIdx,
+                        idx ->
+                                PointwiseChannelMappingUtils.resolveInputOwnership(
+                                        subtaskIndex, rescaleParams));
+        final int newProducerIdx = ownershipByOldProducer[oldProducerIdx];
+
+        final int newLocalChannelIdx =
+                PointwiseChannelMappingUtils.computeNewLocalInputChannelIndex(
+                        newProducerIdx, subtaskIndex, rescaleParams);
+        final RecoveredInputChannel target = getChannel(gateIdx, newLocalChannelIdx);
+
+        if (filteringHandler != null) {
+            recoverWithFiltering(
+                    target, gateIdx, oldProducerIdx, oldSubtaskIndex, buffer.retainBuffer());
+            return;
+        }
+
+        target.onRecoveredStateBuffer(
+                EventSerializer.toBuffer(
+                        new SubtaskConnectionDescriptor(oldSubtaskIndex, oldProducerIdx), false));
+        target.onRecoveredStateBuffer(buffer.retainBuffer());
+    }
+
private void recoverWithFiltering(
RecoveredInputChannel channel,
- InputChannelInfo channelInfo,
+ int gateIdx,
+ int inputChannelIdx,
int oldSubtaskIndex,
Buffer retainedBuffer)
throws IOException, InterruptedException {
checkState(filteringHandler != null, "filtering handler not set.");
List filteredBuffers =
filteringHandler.filterAndRewrite(
- channelInfo.getGateIdx(),
+ gateIdx,
oldSubtaskIndex,
- channelInfo.getInputChannelIdx(),
+ inputChannelIdx,
retainedBuffer,
channel::requestBufferBlocking);
@@ -260,6 +319,9 @@ private RecoveredInputChannel getMappedChannels(InputChannelInfo channelInfo) {
@Nonnull
private RecoveredInputChannel calculateMapping(InputChannelInfo info) {
+ if (isPointwiseRescaling(info.getGateIdx())) {
+ return getChannel(info.getGateIdx(), 0);
+ }
final RescaleMappings oldToNewMapping =
oldToNewMappings.computeIfAbsent(
info.getGateIdx(), idx -> channelMapping.getChannelMapping(idx).invert());
@@ -270,6 +332,10 @@ private RecoveredInputChannel calculateMapping(InputChannelInfo info) {
+ "one buffer is expected to be processed once by the same task.");
return getChannel(info.getGateIdx(), mappedIndexes[0]);
}
+
+ /**
+ * Returns whether the rescaling descriptor registered in {@code channelMapping} for the given
+ * input gate reports pointwise rescaling.
+ *
+ * @param gateIdx index of the input gate to look up
+ */
+ private boolean isPointwiseRescaling(int gateIdx) {
+ return channelMapping.getGateOrPartitionDescriptor(gateIdx).isPointwiseRescaling();
+ }
}
class ResultSubpartitionRecoveredStateHandler
@@ -277,19 +343,20 @@ class ResultSubpartitionRecoveredStateHandler
private final ResultPartitionWriter[] writers;
private final boolean notifyAndBlockOnCompletion;
+ private final InflightDataRescalingDescriptor channelMapping;
+ private final int subtaskIndex;
private final ResultSubpartitionDistributor resultSubpartitionDistributor;
ResultSubpartitionRecoveredStateHandler(
ResultPartitionWriter[] writers,
boolean notifyAndBlockOnCompletion,
- InflightDataRescalingDescriptor channelMapping) {
+ InflightDataRescalingDescriptor channelMapping,
+ int subtaskIndex) {
this.writers = writers;
+ this.channelMapping = channelMapping;
+ this.subtaskIndex = subtaskIndex;
this.resultSubpartitionDistributor =
new ResultSubpartitionDistributor(channelMapping) {
- /**
- * Override the getSubpartitionInfo to perform type checking on the
- * ResultPartitionWriter.
- */
@Override
ResultSubpartitionInfo getSubpartitionInfo(
int partitionIndex, int subPartitionIdx) {
@@ -323,25 +390,79 @@ public void recover(
if (!bufferConsumer.isDataAvailable()) {
return;
}
- final List mappedSubpartitions =
- resultSubpartitionDistributor.getMappedSubpartitions(subpartitionInfo);
- CheckpointedResultPartition checkpointedResultPartition =
- getCheckpointedResultPartition(subpartitionInfo.getPartitionIdx());
- for (final ResultSubpartitionInfo mappedSubpartition : mappedSubpartitions) {
- // channel selector is created from the downstream's point of view: the
- // subtask of downstream = subpartition index of recovered buffer
- final SubtaskConnectionDescriptor channelSelector =
- new SubtaskConnectionDescriptor(
- subpartitionInfo.getSubPartitionIdx(), oldSubtaskIndex);
- checkpointedResultPartition.addRecovered(
- mappedSubpartition.getSubPartitionIdx(),
- EventSerializer.toBufferConsumer(channelSelector, false));
- checkpointedResultPartition.addRecovered(
- mappedSubpartition.getSubPartitionIdx(), bufferConsumer.copy());
+ if (isPointwiseRescaling(subpartitionInfo.getPartitionIdx())) {
+ recoverPointwise(subpartitionInfo, oldSubtaskIndex, bufferConsumer);
+ } else {
+ recoverAllToAll(subpartitionInfo, oldSubtaskIndex, bufferConsumer);
}
}
}
+    /**
+     * Replays a recovered output buffer through the regular (non-pointwise) rescale mapping.
+     *
+     * <p>For every new subpartition that {@code resultSubpartitionDistributor} maps the old
+     * subpartition to, a serialized {@link SubtaskConnectionDescriptor} event is added first,
+     * followed by a copy of the recovered buffer.
+     *
+     * @param subpartitionInfo old partition/subpartition the buffer was recorded for
+     * @param oldSubtaskIndex index of the old subtask the buffer belongs to
+     * @param bufferConsumer recovered data; only copies of it are added to the partition
+     * @throws IOException propagated from event serialization or from the partition
+     */
+    private void recoverAllToAll(
+            ResultSubpartitionInfo subpartitionInfo,
+            int oldSubtaskIndex,
+            BufferConsumer bufferConsumer)
+            throws IOException {
+        final List<ResultSubpartitionInfo> mappedSubpartitions =
+                resultSubpartitionDistributor.getMappedSubpartitions(subpartitionInfo);
+        final CheckpointedResultPartition checkpointedResultPartition =
+                getCheckpointedResultPartition(subpartitionInfo.getPartitionIdx());
+        // channel selector is created from the downstream's point of view: the
+        // subtask of downstream = subpartition index of recovered buffer.
+        // It does not depend on the mapped target, so it is built once for all of them.
+        final SubtaskConnectionDescriptor channelSelector =
+                new SubtaskConnectionDescriptor(
+                        subpartitionInfo.getSubPartitionIdx(), oldSubtaskIndex);
+        for (final ResultSubpartitionInfo mappedSubpartition : mappedSubpartitions) {
+            final int targetSubpartition = mappedSubpartition.getSubPartitionIdx();
+            checkpointedResultPartition.addRecovered(
+                    targetSubpartition, EventSerializer.toBufferConsumer(channelSelector, false));
+            checkpointedResultPartition.addRecovered(targetSubpartition, bufferConsumer.copy());
+        }
+    }
+
+    /**
+     * Replays a recovered output buffer for a partition whose descriptor carries {@link
+     * PointwiseRescaleParams}.
+     *
+     * <p>The old local subpartition index is translated to the old downstream subtask, that
+     * subtask is mapped to its new downstream subtask, and the result is translated back to a
+     * local subpartition index of this subtask. A negative local index means nothing is added.
+     *
+     * @param subpartitionInfo old partition/subpartition the buffer was recorded for
+     * @param oldSubtaskIndex index of the old subtask the buffer belongs to
+     * @param bufferConsumer recovered data; only a copy of it is added to the partition
+     * @throws IOException propagated from event serialization or from the partition
+     */
+    private void recoverPointwise(
+            ResultSubpartitionInfo subpartitionInfo,
+            int oldSubtaskIndex,
+            BufferConsumer bufferConsumer)
+            throws IOException {
+        final int partitionIdx = subpartitionInfo.getPartitionIdx();
+        final PointwiseRescaleParams rescaleParams =
+                channelMapping
+                        .getGateOrPartitionDescriptor(partitionIdx)
+                        .getPointwiseRescaleParams();
+
+        // "false" selects the output-side (subpartition) interpretation of the local index.
+        final int oldConsumerSubtask =
+                PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                        oldSubtaskIndex,
+                        subpartitionInfo.getSubPartitionIdx(),
+                        rescaleParams.getOldUpParallelism(),
+                        rescaleParams.getOldDownParallelism(),
+                        false,
+                        rescaleParams.getOldDistributionPattern());
+        final int newConsumerSubtask =
+                PointwiseChannelMappingUtils.newSubtaskAssignedFrom(
+                        oldConsumerSubtask, rescaleParams.getNewDownParallelism());
+        final int targetSubpartition =
+                PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(
+                        newConsumerSubtask, subtaskIndex, rescaleParams);
+
+        if (targetSubpartition >= 0) {
+            final SubtaskConnectionDescriptor selector =
+                    new SubtaskConnectionDescriptor(oldConsumerSubtask, oldSubtaskIndex);
+            final CheckpointedResultPartition partition =
+                    getCheckpointedResultPartition(partitionIdx);
+            partition.addRecovered(
+                    targetSubpartition, EventSerializer.toBufferConsumer(selector, false));
+            partition.addRecovered(targetSubpartition, bufferConsumer.copy());
+        }
+    }
+
+ /**
+ * Returns whether the rescaling descriptor registered in {@code channelMapping} for the given
+ * result partition reports pointwise rescaling.
+ *
+ * @param partitionIdx index of the result partition to look up
+ */
+ private boolean isPointwiseRescaling(int partitionIdx) {
+ return channelMapping.getGateOrPartitionDescriptor(partitionIdx).isPointwiseRescaling();
+ }
+
private CheckpointedResultPartition getCheckpointedResultPartition(int partitionIndex) {
ResultPartitionWriter writer = writers[partitionIndex];
if (!(writer instanceof CheckpointedResultPartition)) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
index c52572e52faec..d366d7eaf4bc0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
@@ -52,9 +52,15 @@ public class SequentialChannelStateReaderImpl implements SequentialChannelStateR
private final TaskStateSnapshot taskStateSnapshot;
private final ChannelStateSerializer serializer;
private final ChannelStateChunkReader chunkReader;
+ private final int subtaskIndex;
public SequentialChannelStateReaderImpl(TaskStateSnapshot taskStateSnapshot) {
+ this(taskStateSnapshot, 0);
+ }
+
+ public SequentialChannelStateReaderImpl(TaskStateSnapshot taskStateSnapshot, int subtaskIndex) {
this.taskStateSnapshot = taskStateSnapshot;
+ this.subtaskIndex = subtaskIndex;
serializer = new ChannelStateSerializerImpl();
chunkReader = new ChannelStateChunkReader(serializer);
}
@@ -75,7 +81,8 @@ public void readInputData(InputGate[] inputGates, RecordFilterContext filterCont
inputGates,
taskStateSnapshot.getInputRescalingDescriptor(),
filteringHandler,
- filterContext.getMemorySegmentSize())) {
+ filterContext.getMemorySegmentSize(),
+ subtaskIndex)) {
read(
stateHandler,
groupByDelegate(
@@ -102,7 +109,8 @@ public void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlo
new ResultSubpartitionRecoveredStateHandler(
writers,
notifyAndBlockOnCompletion,
- taskStateSnapshot.getOutputRescalingDescriptor())) {
+ taskStateSnapshot.getOutputRescalingDescriptor(),
+ subtaskIndex)) {
read(
stateHandler,
groupByDelegate(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/SubtaskStateMapper.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/SubtaskStateMapper.java
index c9136cebc549b..7c8d3886d9c72 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/SubtaskStateMapper.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/SubtaskStateMapper.java
@@ -173,6 +173,21 @@ public int[] getOldSubtasks(
}
},
+ /**
+ * Mapper constant whose old-subtask lookup cannot be answered from the subtask counts alone:
+ * {@link #getOldSubtasks(int, int, int)} always throws, and the mapper reports itself as
+ * ambiguous.
+ */
+ POINTWISE_UPSTREAM {
+ @Override
+ public int[] getOldSubtasks(
+ int newSubtaskIndex, int oldNumberOfSubtasks, int newNumberOfSubtasks) {
+ // Deliberately unsupported; the message points callers at the alternative.
+ throw new UnsupportedOperationException(
+ "POINTWISE_UPSTREAM requires PointwiseRescaleParams; "
+ + "use PointwiseChannelMappingUtils.traceOutputSources instead.");
+ }
+
+ @Override
+ public boolean isAmbiguous() {
+ return true;
+ }
+ },
+
UNSUPPORTED {
@Override
public int[] getOldSubtasks(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManagerImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManagerImpl.java
index 61765ad6bc225..22d0d464e9de9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManagerImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManagerImpl.java
@@ -113,7 +113,8 @@ public TaskStateManagerImpl(
new SequentialChannelStateReaderImpl(
jobManagerTaskRestore == null
? new TaskStateSnapshot()
- : jobManagerTaskRestore.getTaskStateSnapshot()));
+ : jobManagerTaskRestore.getTaskStateSnapshot(),
+ executionAttemptID.getSubtaskIndex()));
}
public TaskStateManagerImpl(
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/EdgeDistributionPatternHook.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/EdgeDistributionPatternHook.java
new file mode 100644
index 0000000000000..7186389f75f5d
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/EdgeDistributionPatternHook.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.graph;
+
+import org.apache.flink.core.io.SimpleVersionedSerializer;
+import org.apache.flink.runtime.checkpoint.EdgeDistributionPatternSnapshot;
+import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
+
+/**
+ * A {@link MasterTriggerRestoreHook} that persists the current job's edge {@link
+ * org.apache.flink.runtime.jobgraph.DistributionPattern}s into checkpoint metadata, so they are
+ * available when restoring after a shuffle-mode change.
+ *
+ * <p>The hook state is an already serialized {@link EdgeDistributionPatternSnapshot}; the same
+ * bytes are returned for every triggered checkpoint.
+ */
+class EdgeDistributionPatternHook implements MasterTriggerRestoreHook<byte[]> {
+
+    /** Serialized snapshot handed out unchanged on every checkpoint. */
+    private final byte[] snapshotBytes;
+
+    EdgeDistributionPatternHook(byte[] snapshotBytes) {
+        this.snapshotBytes = snapshotBytes;
+    }
+
+    @Override
+    public String getIdentifier() {
+        return EdgeDistributionPatternSnapshot.HOOK_IDENTIFIER;
+    }
+
+    /** Returns an already completed future holding the serialized snapshot. */
+    @Nullable
+    @Override
+    public CompletableFuture<byte[]> triggerCheckpoint(
+            long checkpointId, long timestamp, Executor executor) {
+        return CompletableFuture.completedFuture(snapshotBytes);
+    }
+
+    @Override
+    public void restoreCheckpoint(long checkpointId, @Nullable byte[] checkpointData) {
+        // no-op: data is extracted directly from MasterState before assignStates()
+    }
+
+    @Nullable
+    @Override
+    public SimpleVersionedSerializer<byte[]> createCheckpointDataSerializer() {
+        return ByteArraySerializer.INSTANCE;
+    }
+
+    /** Pass-through serializer: the hook state already is a byte array. */
+    private static final class ByteArraySerializer implements SimpleVersionedSerializer<byte[]> {
+        static final ByteArraySerializer INSTANCE = new ByteArraySerializer();
+
+        @Override
+        public int getVersion() {
+            return 1;
+        }
+
+        @Override
+        public byte[] serialize(byte[] obj) throws IOException {
+            return obj;
+        }
+
+        @Override
+        public byte[] deserialize(int version, byte[] serialized) throws IOException {
+            return serialized;
+        }
+    }
+
+    /** Serializable factory that carries the snapshot bytes and creates the hook from them. */
+    static class Factory implements MasterTriggerRestoreHook.Factory {
+        private static final long serialVersionUID = 1L;
+
+        private final byte[] snapshotBytes;
+
+        Factory(EdgeDistributionPatternSnapshot snapshot) throws IOException {
+            this.snapshotBytes = snapshot.toBytes();
+        }
+
+        // The hook is always a MasterTriggerRestoreHook<byte[]>; the interface forces a generic
+        // return type, hence the unchecked cast.
+        @SuppressWarnings("unchecked")
+        @Override
+        public <V> MasterTriggerRestoreHook<V> create() {
+            return (MasterTriggerRestoreHook<V>) new EdgeDistributionPatternHook(snapshotBytes);
+        }
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
index e65ec1246ea18..ca034504cc06f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
@@ -269,9 +269,14 @@ public StreamGraph generate() {
LineageGraph lineageGraph = LineageGraphUtils.convertToLineageGraph(transformations);
streamGraph.setLineageGraph(lineageGraph);
+ boolean forceUnaligned =
+ checkpointConfig.isForceUnalignedCheckpoints() && !streamGraph.isIterative();
for (StreamNode node : streamGraph.getStreamNodes()) {
if (node.getInEdges().stream()
- .anyMatch(e -> !e.getPartitioner().isSupportsUnalignedCheckpoint())) {
+ .anyMatch(
+ e ->
+ !e.getPartitioner()
+ .isSupportsUnalignedCheckpoint(forceUnaligned))) {
for (StreamEdge edge : node.getInEdges()) {
edge.setSupportsUnalignedCheckpoints(false);
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index 66776c03fb5fd..493908e2493bb 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -34,6 +34,9 @@
import org.apache.flink.configuration.IllegalConfigurationException;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.runtime.OperatorIDPair;
+import org.apache.flink.runtime.checkpoint.EdgeDistributionPatternSnapshot;
+import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook;
+import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.InputOutputFormatVertex;
@@ -49,6 +52,7 @@
import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil;
import org.apache.flink.runtime.jobgraph.forwardgroup.JobVertexForwardGroup;
import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
+import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings;
import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable;
import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalPipelinedRegion;
import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalTopology;
@@ -253,9 +257,70 @@ private JobGraph createJobGraph() {
// Wait for the serialization of operator coordinators and stream config.
serializeOperatorCoordinatorsAndStreamConfig(serializationExecutor, jobVertexBuildContext);
+ configureEdgeDistributionPatternHook();
+
return jobGraph;
}
+    /**
+     * Appends an {@link EdgeDistributionPatternHook} factory to the job's master hooks so that
+     * the distribution pattern of every produced data set is stored with each checkpoint.
+     *
+     * <p>Returns without changes unless unaligned checkpoints are both enabled and forced, or
+     * when the job graph has no checkpointing settings. Already configured master hooks are
+     * kept; the new factory is added after them and the settings are replaced on the job graph.
+     *
+     * @throws FlinkRuntimeException if collecting the patterns or (de)serializing the hooks fails
+     */
+    @SuppressWarnings("unchecked")
+    private void configureEdgeDistributionPatternHook() {
+        if (!streamGraph.getCheckpointConfig().isUnalignedCheckpointsEnabled()
+                || !streamGraph.getCheckpointConfig().isForceUnalignedCheckpoints()) {
+            return;
+        }
+        JobCheckpointingSettings settings = jobGraph.getCheckpointingSettings();
+        if (settings == null) {
+            return;
+        }
+        try {
+            // Patterns are keyed by the generated ID of the first operator ID pair of each vertex.
+            Map<OperatorID, DistributionPattern[]> outputPatterns = new HashMap<>();
+            for (JobVertex jv : jobGraph.getVerticesSortedTopologicallyFromSources()) {
+                OperatorID outputOpId = jv.getOperatorIDs().get(0).getGeneratedOperatorID();
+                List<IntermediateDataSet> outputs = jv.getProducedDataSets();
+                DistributionPattern[] patterns = new DistributionPattern[outputs.size()];
+                for (int i = 0; i < outputs.size(); i++) {
+                    IntermediateDataSet ds = outputs.get(i);
+                    // A data set without consumers is recorded as ALL_TO_ALL.
+                    patterns[i] =
+                            ds.getConsumers().isEmpty()
+                                    ? DistributionPattern.ALL_TO_ALL
+                                    : ds.getDistributionPattern();
+                }
+                outputPatterns.put(outputOpId, patterns);
+            }
+
+            MasterTriggerRestoreHook.Factory[] existingHooks;
+            SerializedValue<MasterTriggerRestoreHook.Factory[]> serializedHooks =
+                    settings.getMasterHooks();
+            if (serializedHooks != null) {
+                existingHooks =
+                        serializedHooks.deserializeValue(
+                                Thread.currentThread().getContextClassLoader());
+            } else {
+                existingHooks = new MasterTriggerRestoreHook.Factory[0];
+            }
+
+            MasterTriggerRestoreHook.Factory[] newHooks =
+                    new MasterTriggerRestoreHook.Factory[existingHooks.length + 1];
+            System.arraycopy(existingHooks, 0, newHooks, 0, existingHooks.length);
+            newHooks[existingHooks.length] =
+                    new EdgeDistributionPatternHook.Factory(
+                            new EdgeDistributionPatternSnapshot(outputPatterns));
+
+            // Same settings as before, only the master hooks are swapped for the extended array.
+            JobCheckpointingSettings newSettings =
+                    new JobCheckpointingSettings(
+                            settings.getCheckpointCoordinatorConfiguration(),
+                            settings.getDefaultStateBackend(),
+                            settings.isChangelogStateBackendEnabled(),
+                            settings.getDefaultCheckpointStorage(),
+                            new SerializedValue<>(newHooks),
+                            settings.isStateBackendUseManagedMemory());
+            jobGraph.setSnapshotSettings(newSettings);
+        } catch (Exception e) {
+            throw new FlinkRuntimeException(
+                    "Failed to configure edge distribution pattern hook", e);
+        }
+    }
+
public static void serializeOperatorCoordinatorsAndStreamConfig(
Executor serializationExecutor, JobVertexBuildContext jobVertexBuildContext) {
try {
@@ -1634,8 +1699,29 @@ public static IntermediateDataSet connect(
// set strategy name so that web interface can show it.
jobEdge.setShipStrategyName(partitioner.toString());
- jobEdge.setDownstreamSubtaskStateMapper(partitioner.getDownstreamSubtaskStateMapper());
- jobEdge.setUpstreamSubtaskStateMapper(partitioner.getUpstreamSubtaskStateMapper());
+ SubtaskStateMapper downstreamMapper = partitioner.getDownstreamSubtaskStateMapper();
+ SubtaskStateMapper upstreamMapper = partitioner.getUpstreamSubtaskStateMapper();
+
+ if (jobVertexBuildContext
+ .getStreamGraph()
+ .getCheckpointConfig()
+ .isForceUnalignedCheckpoints()
+ && jobVertexBuildContext
+ .getStreamGraph()
+ .getCheckpointConfig()
+ .isUnalignedCheckpointsEnabled()
+ && !partitioner.isSupportsUnalignedCheckpoint(false)
+ && partitioner.isSupportsUnalignedCheckpoint(true)) {
+ if (downstreamMapper == SubtaskStateMapper.UNSUPPORTED) {
+ downstreamMapper = SubtaskStateMapper.ROUND_ROBIN;
+ }
+ if (upstreamMapper == SubtaskStateMapper.UNSUPPORTED) {
+ upstreamMapper = SubtaskStateMapper.POINTWISE_UPSTREAM;
+ }
+ }
+
+ jobEdge.setDownstreamSubtaskStateMapper(downstreamMapper);
+ jobEdge.setUpstreamSubtaskStateMapper(upstreamMapper);
if (LOG.isDebugEnabled()) {
LOG.debug(
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java
index 6f1f3bda8b0e4..0d262776efbb6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java
@@ -19,6 +19,9 @@
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.PointwiseChannelMappingUtils;
+import org.apache.flink.runtime.checkpoint.PointwiseRescaleParams;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
@@ -152,6 +155,20 @@ static DemultiplexingRecordDeserializer create(
if (oldSubtaskIndexes.length == 0) {
return UNMAPPED;
}
+
+ InflightDataGateOrPartitionRescalingDescriptor gateDesc =
+ rescalingDescriptor.getGateOrPartitionDescriptor(channelInfo.getGateIdx());
+
+ if (gateDesc.isPointwiseRescaling()) {
+ return createPointwise(
+ channelInfo,
+ rescalingDescriptor,
+ gateDesc,
+ deserializerFactory,
+ recordFilterFactory,
+ oldSubtaskIndexes);
+ }
+
final int[] oldChannelIndexes =
rescalingDescriptor
.getChannelMapping(channelInfo.getGateIdx())
@@ -179,6 +196,56 @@ static DemultiplexingRecordDeserializer create(
return new DemultiplexingRecordDeserializer(virtualChannels);
}
+ /**
+ * Builds the deserializer for a gate whose rescaling descriptor is flagged as pointwise.
+ *
+ * <p>One virtual channel is registered per (old subtask, old local channel) pair, keyed by a
+ * {@link SubtaskConnectionDescriptor} of the old subtask and the old upstream subtask that the
+ * local channel index resolves to. Returns {@code UNMAPPED} when the old subtasks have no
+ * local channels at all.
+ *
+ * <p>NOTE(review): the generic type arguments in this signature look stripped in this
+ * rendering of the patch - confirm against the real source file.
+ */
+ private static DemultiplexingRecordDeserializer createPointwise(
+ InputChannelInfo channelInfo,
+ InflightDataRescalingDescriptor rescalingDescriptor,
+ InflightDataGateOrPartitionRescalingDescriptor gateDesc,
+ Function>>
+ deserializerFactory,
+ Function> recordFilterFactory,
+ int[] oldSubtaskIndexes) {
+ PointwiseRescaleParams params = gateDesc.getPointwiseRescaleParams();
+
+ // First pass: count all old local channels; the total is passed to every deserializer.
+ int totalVirtualChannels = 0;
+ for (int subtask : oldSubtaskIndexes) {
+ totalVirtualChannels +=
+ PointwiseChannelMappingUtils.getOldLocalChannelCount(subtask, params);
+ }
+ if (totalVirtualChannels == 0) {
+ return UNMAPPED;
+ }
+
+ Map> virtualChannels =
+ Maps.newHashMapWithExpectedSize(totalVirtualChannels);
+ for (int subtask : oldSubtaskIndexes) {
+ int numLocalChannels =
+ PointwiseChannelMappingUtils.getOldLocalChannelCount(subtask, params);
+ for (int localChannelIndex = 0;
+ localChannelIndex < numLocalChannels;
+ localChannelIndex++) {
+ // "true" selects the input-side interpretation of the local channel index.
+ int oldUpstreamSubtaskIndex =
+ PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+ subtask,
+ localChannelIndex,
+ params.getOldUpParallelism(),
+ params.getOldDownParallelism(),
+ true,
+ params.getOldDistributionPattern());
+ SubtaskConnectionDescriptor descriptor =
+ new SubtaskConnectionDescriptor(subtask, oldUpstreamSubtaskIndex);
+ // Records are filtered only when the descriptor marks this old subtask as ambiguous.
+ virtualChannels.put(
+ descriptor,
+ new VirtualChannel<>(
+ deserializerFactory.apply(totalVirtualChannels),
+ rescalingDescriptor.isAmbiguous(channelInfo.getGateIdx(), subtask)
+ ? recordFilterFactory.apply(channelInfo)
+ : RecordFilter.acceptAll()));
+ }
+ }
+
+ return new DemultiplexingRecordDeserializer(virtualChannels);
+ }
+
@Override
public String toString() {
return "DemultiplexingRecordDeserializer{" + "channels=" + channels.keySet() + '}';
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java
index 7a07c8bfa9e00..80db01c3d6b81 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java
@@ -87,7 +87,11 @@ public SubtaskStateMapper getUpstreamSubtaskStateMapper() {
public abstract boolean isPointwise();
public boolean isSupportsUnalignedCheckpoint() {
- return supportsUnalignedCheckpoint && !isPointwise() && !isBroadcast();
+ return isSupportsUnalignedCheckpoint(false);
+ }
+
+ public boolean isSupportsUnalignedCheckpoint(boolean forceUnaligned) {
+ return supportsUnalignedCheckpoint && (forceUnaligned || !isPointwise()) && !isBroadcast();
}
public void disableUnalignedCheckpoints() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshotTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshotTest.java
new file mode 100644
index 0000000000000..4074ce8b9963f
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshotTest.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link EdgeDistributionPatternSnapshot}. */
+class EdgeDistributionPatternSnapshotTest {
+
+    /** Patterns of several operators survive a {@code toBytes}/{@code fromBytes} round trip. */
+    @Test
+    void roundTripSerialization() throws IOException {
+        Map<OperatorID, DistributionPattern[]> patterns = new HashMap<>();
+        OperatorID op1 = new OperatorID();
+        OperatorID op2 = new OperatorID();
+        patterns.put(op1, new DistributionPattern[] {DistributionPattern.POINTWISE});
+        patterns.put(
+                op2,
+                new DistributionPattern[] {
+                    DistributionPattern.ALL_TO_ALL, DistributionPattern.POINTWISE
+                });
+
+        EdgeDistributionPatternSnapshot original = new EdgeDistributionPatternSnapshot(patterns);
+        byte[] bytes = original.toBytes();
+        EdgeDistributionPatternSnapshot restored = EdgeDistributionPatternSnapshot.fromBytes(bytes);
+
+        assertThat(restored.getOutputPatterns(op1)).containsExactly(DistributionPattern.POINTWISE);
+        assertThat(restored.getOutputPatterns(op2))
+                .containsExactly(DistributionPattern.ALL_TO_ALL, DistributionPattern.POINTWISE);
+    }
+
+    /** An empty snapshot round-trips and yields {@code null} for an unknown operator. */
+    @Test
+    void emptySnapshot() throws IOException {
+        Map<OperatorID, DistributionPattern[]> noPatterns = new HashMap<>();
+        EdgeDistributionPatternSnapshot original = new EdgeDistributionPatternSnapshot(noPatterns);
+        byte[] bytes = original.toBytes();
+        EdgeDistributionPatternSnapshot restored = EdgeDistributionPatternSnapshot.fromBytes(bytes);
+
+        assertThat(restored.getOutputPatterns(new OperatorID())).isNull();
+    }
+
+    /** Index lookup returns the stored pattern, or {@code null} when out of range or unknown. */
+    @Test
+    void getOutputPatternByIndex() throws IOException {
+        Map<OperatorID, DistributionPattern[]> patterns = new HashMap<>();
+        OperatorID op = new OperatorID();
+        patterns.put(
+                op,
+                new DistributionPattern[] {
+                    DistributionPattern.POINTWISE,
+                    DistributionPattern.ALL_TO_ALL,
+                    DistributionPattern.POINTWISE
+                });
+
+        EdgeDistributionPatternSnapshot snapshot = new EdgeDistributionPatternSnapshot(patterns);
+
+        assertThat(snapshot.getOutputPattern(op, 0)).isEqualTo(DistributionPattern.POINTWISE);
+        assertThat(snapshot.getOutputPattern(op, 1)).isEqualTo(DistributionPattern.ALL_TO_ALL);
+        assertThat(snapshot.getOutputPattern(op, 2)).isEqualTo(DistributionPattern.POINTWISE);
+        assertThat(snapshot.getOutputPattern(op, 3)).isNull();
+        assertThat(snapshot.getOutputPattern(new OperatorID(), 0)).isNull();
+    }
+
+    /** A hand-written payload with an unknown pattern byte is rejected on deserialization. */
+    @Test
+    void unknownPatternByteRejected() throws IOException {
+        java.io.ByteArrayOutputStream baos = new java.io.ByteArrayOutputStream();
+        try (java.io.DataOutputStream dos = new java.io.DataOutputStream(baos)) {
+            dos.writeInt(1); // version
+            dos.writeInt(1); // numOperators
+            dos.writeLong(0L); // opId lower
+            dos.writeLong(0L); // opId upper
+            dos.writeInt(1); // numPartitions
+            dos.writeByte(7); // unknown pattern byte
+            dos.flush();
+        }
+
+        assertThatThrownBy(() -> EdgeDistributionPatternSnapshot.fromBytes(baos.toByteArray()))
+                .isInstanceOf(IOException.class)
+                .hasMessageContaining("Unknown DistributionPattern byte");
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtilsTest.java
new file mode 100644
index 0000000000000..11da7bd5224a8
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtilsTest.java
@@ -0,0 +1,788 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link PointwiseChannelMappingUtils}. */
+class PointwiseChannelMappingUtilsTest {
+
+ private static final DistributionPattern PW = DistributionPattern.POINTWISE;
+ private static final DistributionPattern A2A = DistributionPattern.ALL_TO_ALL;
+
+    /** Builds rescale parameters where both the old and the new pattern are {@code POINTWISE}. */
+    private static PointwiseRescaleParams params(int oldUp, int oldDown, int newUp, int newDown) {
+        return params(PW, PW, oldUp, oldDown, newUp, newDown);
+    }
+
+    /**
+     * Builds rescale parameters for explicit old/new patterns and parallelisms; the arguments are
+     * passed to the {@link PointwiseRescaleParams} constructor in the order given here.
+     */
+    private static PointwiseRescaleParams params(
+            DistributionPattern oldPat,
+            DistributionPattern newPat,
+            int oldUp,
+            int oldDown,
+            int newUp,
+            int newDown) {
+        PointwiseRescaleParams rescaleParams =
+                new PointwiseRescaleParams(oldPat, newPat, oldUp, oldDown, newUp, newDown);
+        return rescaleParams;
+    }
+
+ // ======================== producersOf / consumersOf ========================
+
+    /**
+     * For every parallelism pair from 1x1 to 8x8: each producer listed for a downstream subtask
+     * must report that subtask among its own consumers.
+     */
+    @Test
+    void producersOfAndConsumersOfAreInverses() {
+        for (int upPar = 1; upPar <= 8; upPar++) {
+            for (int downPar = 1; downPar <= 8; downPar++) {
+                for (int d = 0; d < downPar; d++) {
+                    for (int p : PointwiseChannelMappingUtils.producersOf(d, upPar, downPar)) {
+                        int[] consumersOfP =
+                                PointwiseChannelMappingUtils.consumersOf(p, upPar, downPar);
+                        assertThat(consumersOfP)
+                                .as(
+                                        "up=%d down=%d: d=%d should be in consumersOf(%d)",
+                                        upPar, downPar, d, p)
+                                .contains(d);
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * For every parallelism pair from 1x1 to 8x8: the union of {@code producersOf(d)} over all
+     * downstream subtasks must contain exactly {@code upPar} distinct upstream indices.
+     */
+    @Test
+    void producersOfCoversAllUpstream() {
+        for (int upPar = 1; upPar <= 8; upPar++) {
+            for (int downPar = 1; downPar <= 8; downPar++) {
+                // Typed set (not a raw Set) so the boxed int insertions are compiler-checked.
+                Set<Integer> allProducers = new HashSet<>();
+                for (int d = 0; d < downPar; d++) {
+                    int[] producers = PointwiseChannelMappingUtils.producersOf(d, upPar, downPar);
+                    for (int p : producers) {
+                        allProducers.add(p);
+                    }
+                }
+                assertThat(allProducers)
+                        .as("up=%d down=%d: all upstream covered", upPar, downPar)
+                        .hasSize(upPar);
+            }
+        }
+    }
+
+    /**
+     * For every parallelism pair from 1x1 to 8x8: the union of {@code consumersOf(u)} over all
+     * upstream subtasks must contain exactly {@code downPar} distinct downstream indices.
+     */
+    @Test
+    void consumersOfCoversAllDownstream() {
+        for (int upPar = 1; upPar <= 8; upPar++) {
+            for (int downPar = 1; downPar <= 8; downPar++) {
+                // Typed set (not a raw Set) so the boxed int insertions are compiler-checked.
+                Set<Integer> allConsumers = new HashSet<>();
+                for (int u = 0; u < upPar; u++) {
+                    int[] consumers = PointwiseChannelMappingUtils.consumersOf(u, upPar, downPar);
+                    for (int c : consumers) {
+                        allConsumers.add(c);
+                    }
+                }
+                assertThat(allConsumers)
+                        .as("up=%d down=%d: all downstream covered", upPar, downPar)
+                        .hasSize(downPar);
+            }
+        }
+    }
+
+    /** Pins {@code producersOf} for one upscale, one downscale and one equal-parallelism edge. */
+    @Test
+    void producersOfConcreteValues() {
+        // 2 upstream -> 4 downstream (upscale): each downstream has one producer.
+        int[][] upscale = {{0}, {0}, {1}, {1}};
+        for (int d = 0; d < upscale.length; d++) {
+            assertThat(PointwiseChannelMappingUtils.producersOf(d, 2, 4))
+                    .containsExactly(upscale[d]);
+        }
+
+        // 4 upstream -> 2 downstream (downscale): each downstream merges 2 producers.
+        int[][] downscale = {{0, 1}, {2, 3}};
+        for (int d = 0; d < downscale.length; d++) {
+            assertThat(PointwiseChannelMappingUtils.producersOf(d, 4, 2))
+                    .containsExactly(downscale[d]);
+        }
+
+        // 3 upstream -> 3 downstream (equal): identity.
+        for (int d = 0; d < 3; d++) {
+            assertThat(PointwiseChannelMappingUtils.producersOf(d, 3, 3)).containsExactly(d);
+        }
+    }
+
+    /** Pins {@code consumersOf} for one upscale and one downscale edge. */
+    @Test
+    void consumersOfConcreteValues() {
+        // 2 upstream -> 4 downstream (upscale): each upstream serves 2 consumers.
+        int[][] upscale = {{0, 1}, {2, 3}};
+        for (int u = 0; u < upscale.length; u++) {
+            assertThat(PointwiseChannelMappingUtils.consumersOf(u, 2, 4))
+                    .containsExactly(upscale[u]);
+        }
+
+        // 4 upstream -> 2 downstream (downscale): each upstream feeds one consumer.
+        int[][] downscale = {{0}, {0}, {1}, {1}};
+        for (int u = 0; u < downscale.length; u++) {
+            assertThat(PointwiseChannelMappingUtils.consumersOf(u, 4, 2))
+                    .containsExactly(downscale[u]);
+        }
+    }
+
+ // ======================== localIndexToGlobalSubtaskIndex ========================
+
+    /** Input side: each local channel index of a downstream subtask maps to a global producer. */
+    @Test
+    void localIndexToGlobalSubtaskIndexPointwiseInputSide() {
+        // 4 upstream -> 2 downstream: d=0 has producers {0,1}, d=1 has producers {2,3}.
+        int[][] expectedProducer = {{0, 1}, {2, 3}};
+        for (int d = 0; d < 2; d++) {
+            for (int channel = 0; channel < 2; channel++) {
+                assertThat(
+                                PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                                        d, channel, 4, 2, true, PW))
+                        .isEqualTo(expectedProducer[d][channel]);
+            }
+        }
+    }
+
+    /**
+     * Output side: each local subpartition index of an upstream subtask maps to a global consumer,
+     * both when the consumer range is aligned to the consumer count and when it is not.
+     */
+    @Test
+    void localIndexToGlobalSubtaskIndexPointwiseOutputSide() {
+        // 2 upstream -> 4 downstream: u=0 has consumers {0,1}, u=1 has consumers {2,3}.
+        // Both ranges start at an index with % 2 == 0, so modulo == sequential here.
+        int[][] alignedExpected = {{0, 1}, {2, 3}};
+        for (int u = 0; u < 2; u++) {
+            for (int sp = 0; sp < 2; sp++) {
+                assertThat(
+                                PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                                        u, sp, 2, 4, false, PW))
+                        .isEqualTo(alignedExpected[u][sp]);
+            }
+        }
+
+        // 4 upstream -> 6 downstream: u=2 has consumers {3,4}.
+        // consumers[0]=3 and 3%2=1, so modulo differs from sequential:
+        // sp 0 -> consumer 4 (4%2=0), sp 1 -> consumer 3 (3%2=1).
+        int[] shiftedExpected = {4, 3};
+        for (int sp = 0; sp < shiftedExpected.length; sp++) {
+            assertThat(
+                            PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                                    2, sp, 4, 6, false, PW))
+                    .isEqualTo(shiftedExpected[sp]);
+        }
+    }
+
+    /**
+     * A local index of 5 on a 4 -> 2 edge is expected to fail with an {@link
+     * ArrayIndexOutOfBoundsException}.
+     */
+    @Test
+    void localIndexToGlobalSubtaskIndexBoundsCheck() {
+        assertThatThrownBy(
+                        () -> {
+                            PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                                    0, 5, 4, 2, true, PW);
+                        })
+                .isInstanceOf(ArrayIndexOutOfBoundsException.class);
+    }
+
+ // ======================== getOldLocalChannelCount ========================
+
+    /** Old pattern ALL_TO_ALL with 5 old upstreams: the expected old local channel count is 5. */
+    @Test
+    void getOldLocalChannelCountAllToAll() {
+        PointwiseRescaleParams rescale = params(A2A, A2A, 5, 3, 5, 3);
+        for (int subtask : new int[] {0, 2}) {
+            int channelCount =
+                    PointwiseChannelMappingUtils.getOldLocalChannelCount(subtask, rescale);
+            assertThat(channelCount).isEqualTo(5);
+        }
+    }
+
+    /** Old pattern POINTWISE on a 4 -> 2 edge: the expected old local channel count is 2. */
+    @Test
+    void getOldLocalChannelCountPointwise() {
+        // 4 upstream -> 2 downstream: each downstream has 2 producers.
+        PointwiseRescaleParams rescale = params(4, 2, 4, 2);
+        for (int d = 0; d < 2; d++) {
+            int channelCount = PointwiseChannelMappingUtils.getOldLocalChannelCount(d, rescale);
+            assertThat(channelCount).isEqualTo(2);
+        }
+    }
+
+ // ======================== computeNewLocalInputChannelIndex / SubpartitionIndex ===
+
+    /** New pattern ALL_TO_ALL: upstream 3 is expected at local input channel index 3. */
+    @Test
+    void computeNewLocalInputChannelIndexAllToAll() {
+        PointwiseRescaleParams rescale = params(PW, A2A, 4, 2, 4, 2);
+        int localChannel =
+                PointwiseChannelMappingUtils.computeNewLocalInputChannelIndex(3, 0, rescale);
+        assertThat(localChannel).isEqualTo(3);
+    }
+
+    /** New pattern POINTWISE on a 4 -> 2 edge: pins the local input channel of three upstreams. */
+    @Test
+    void computeNewLocalInputChannelIndexPointwise() {
+        // 4 upstream -> 2 downstream: d=0 connects to {0,1}, d=1 connects to {2,3}.
+        PointwiseRescaleParams rescale = params(4, 2, 4, 2);
+        // Each row: {upstream index, downstream index, expected local channel}.
+        int[][] cases = {{0, 0, 0}, {1, 0, 1}, {2, 1, 0}};
+        for (int[] c : cases) {
+            int localChannel =
+                    PointwiseChannelMappingUtils.computeNewLocalInputChannelIndex(
+                            c[0], c[1], rescale);
+            assertThat(localChannel).isEqualTo(c[2]);
+        }
+    }
+
+    /** New pattern ALL_TO_ALL: downstream 3 is expected at local subpartition index 3. */
+    @Test
+    void computeNewLocalSubpartitionIndexAllToAll() {
+        PointwiseRescaleParams rescale = params(PW, A2A, 2, 4, 2, 4);
+        int localSubpartition =
+                PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(3, 0, rescale);
+        assertThat(localSubpartition).isEqualTo(3);
+    }
+
+    /** New pattern POINTWISE on a 2 -> 4 edge: pins the local subpartition of three consumers. */
+    @Test
+    void computeNewLocalSubpartitionIndexPointwise() {
+        // 2 upstream -> 4 downstream: u=0 connects to {0,1}, u=1 connects to {2,3}.
+        PointwiseRescaleParams rescale = params(2, 4, 2, 4);
+        // Each row: {downstream index, upstream index, expected local subpartition}.
+        int[][] cases = {{0, 0, 0}, {1, 0, 1}, {2, 1, 0}};
+        for (int[] c : cases) {
+            int localSubpartition =
+                    PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(
+                            c[0], c[1], rescale);
+            assertThat(localSubpartition).isEqualTo(c[2]);
+        }
+    }
+
+ // ======================== computeOldToNewSubtaskAssignment ========================
+
+    /**
+     * For all old/new parallelism combinations from 1 to 6: the assignment computed with the flag
+     * set to {@code true} has one entry per old upstream subtask, each within {@code [0, newUp)}.
+     */
+    @Test
+    void assignmentCoversAllOldSubtasks() {
+        for (int oldUp = 1; oldUp <= 6; oldUp++) {
+            for (int newUp = 1; newUp <= 6; newUp++) {
+                for (int oldDown = 1; oldDown <= 6; oldDown++) {
+                    for (int newDown = 1; newDown <= 6; newDown++) {
+                        PointwiseRescaleParams rescale = params(oldUp, oldDown, newUp, newDown);
+                        int[] oldToNew =
+                                PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment(
+                                        rescale, true);
+                        assertThat(oldToNew).hasSize(oldUp);
+                        int maxNewIndex = newUp - 1;
+                        for (int oldIdx = 0; oldIdx < oldUp; oldIdx++) {
+                            assertThat(oldToNew[oldIdx])
+                                    .as(
+                                            "up %d->%d, down %d->%d: old[%d] in range",
+                                            oldUp, newUp, oldDown, newDown, oldIdx)
+                                    .isBetween(0, maxNewIndex);
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    /** With identical old and new parallelism (1..8), the {@code true}-flag assignment is i -> i. */
+    @Test
+    void assignmentIsIdentityWhenUnchanged() {
+        for (int p = 1; p <= 8; p++) {
+            PointwiseRescaleParams unchanged = params(p, p, p, p);
+            int[] oldToNew =
+                    PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment(unchanged, true);
+            for (int i = 0; i < p; i++) {
+                assertThat(oldToNew[i]).as("p=%d: identity at %d", p, i).isEqualTo(i);
+            }
+        }
+    }
+
+    /** With identical old and new parallelism (1..8), the {@code false}-flag assignment is i -> i. */
+    @Test
+    void assignmentDownstreamIsIdentityWhenUnchanged() {
+        for (int p = 1; p <= 8; p++) {
+            PointwiseRescaleParams unchanged = params(p, p, p, p);
+            int[] oldToNew =
+                    PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment(unchanged, false);
+            for (int i = 0; i < p; i++) {
+                assertThat(oldToNew[i]).as("p=%d: identity at %d", p, i).isEqualTo(i);
+            }
+        }
+    }
+
+    /** Pins the upstream assignment for a 4:2 -> 2:2 rescale. */
+    @Test
+    void assignmentUpstreamRespectTopology() {
+        // Old topology (4 up, 2 down): old u=0,1 feed consumer 0; old u=2,3 feed consumer 1.
+        // New topology (2 up, 2 down): new u=0 -> consumers {0}, new u=1 -> consumers {1}.
+        // So old u=0,1 should map to new u=0 and old u=2,3 should map to new u=1.
+        PointwiseRescaleParams rescale = params(4, 2, 2, 2);
+        int[] oldToNew =
+                PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment(rescale, true);
+        int[] expected = {0, 0, 1, 1};
+        for (int oldU = 0; oldU < expected.length; oldU++) {
+            assertThat(oldToNew[oldU]).isEqualTo(expected[oldU]);
+        }
+    }
+
+    /**
+     * Coherence check on a 4:4 -> 2:2 rescale: for every pair (oldD -> newD) in the downstream
+     * assignment, each old producer of oldD must be assigned to a new upstream subtask that is a
+     * producer of newD in the new topology.
+     */
+    @Test
+    void assignmentDownstreamCoherentWithUpstream() {
+        int oldUp = 4;
+        int newUp = 2;
+        int oldDown = 4;
+        int newDown = 2;
+        PointwiseRescaleParams rescale = params(oldUp, oldDown, newUp, newDown);
+        int[] upAssignment =
+                PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment(rescale, true);
+        int[] downAssignment =
+                PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment(rescale, false);
+        for (int oldD = 0; oldD < oldDown; oldD++) {
+            int newD = downAssignment[oldD];
+            int[] producersBefore = PointwiseChannelMappingUtils.producersOf(oldD, oldUp, oldDown);
+            int[] producersAfter = PointwiseChannelMappingUtils.producersOf(newD, newUp, newDown);
+            for (int oldP : producersBefore) {
+                int newU = upAssignment[oldP];
+                assertThat(producersAfter)
+                        .as(
+                                "oldD=%d->newD=%d, oldP=%d->newU=%d should connect",
+                                oldD, newD, oldP, newU)
+                        .contains(newU);
+            }
+        }
+    }
+
+    /** Scale-up 2 -> 4 upstream: two entries, each a valid new upstream index in [0, 3]. */
+    @Test
+    void assignmentScaleUp() {
+        // 2 upstream -> 4 upstream, downstream unchanged at 4.
+        PointwiseRescaleParams rescale = params(2, 4, 4, 4);
+        int[] oldToNew =
+                PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment(rescale, true);
+        assertThat(oldToNew).hasSize(2);
+        for (int i = 0; i < oldToNew.length; i++) {
+            assertThat(oldToNew[i]).isBetween(0, 3);
+        }
+    }
+
+    /** Scale-down 6 -> 2 upstream: six entries, each a valid new upstream index in [0, 1]. */
+    @Test
+    void assignmentScaleDown() {
+        // 6 upstream -> 2 upstream, downstream 6 -> 2.
+        PointwiseRescaleParams rescale = params(6, 6, 2, 2);
+        int[] oldToNew =
+                PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment(rescale, true);
+        assertThat(oldToNew).hasSize(6);
+        for (int i = 0; i < oldToNew.length; i++) {
+            assertThat(oldToNew[i]).isBetween(0, 1);
+        }
+    }
+
+ // =============== oldSubtasksAssignedTo / newSubtaskAssignedFrom ===============
+
+    /**
+     * For every old/new parallelism pair from 1x1 to 8x8: each old subtask returned by {@code
+     * oldSubtasksAssignedTo(newIdx, ...)} must map back to {@code newIdx} through {@code
+     * newSubtaskAssignedFrom}.
+     */
+    @Test
+    void oldSubtasksAssignedToAndNewSubtaskAssignedFromAreConsistent() {
+        for (int oldPar = 1; oldPar <= 8; oldPar++) {
+            for (int newPar = 1; newPar <= 8; newPar++) {
+                for (int newIdx = 0; newIdx < newPar; newIdx++) {
+                    for (int oldIdx :
+                            PointwiseChannelMappingUtils.oldSubtasksAssignedTo(
+                                    newIdx, oldPar, newPar)) {
+                        int mappedBack =
+                                PointwiseChannelMappingUtils.newSubtaskAssignedFrom(oldIdx, newPar);
+                        assertThat(mappedBack)
+                                .as(
+                                        "old=%d,new=%d: newSubtaskAssignedFrom(%d) == %d",
+                                        oldPar, newPar, oldIdx, newIdx)
+                                .isEqualTo(newIdx);
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * For every old/new parallelism pair from 1x1 to 8x8: the union of {@code
+     * oldSubtasksAssignedTo(newIdx, ...)} over all new subtasks must contain exactly {@code oldPar}
+     * distinct old subtask indices.
+     */
+    @Test
+    void oldSubtasksAssignedToCoversAllOld() {
+        for (int oldPar = 1; oldPar <= 8; oldPar++) {
+            for (int newPar = 1; newPar <= 8; newPar++) {
+                // Typed set (not a raw Set) so the boxed int insertions are compiler-checked.
+                Set<Integer> covered = new HashSet<>();
+                for (int newIdx = 0; newIdx < newPar; newIdx++) {
+                    for (int oldIdx :
+                            PointwiseChannelMappingUtils.oldSubtasksAssignedTo(
+                                    newIdx, oldPar, newPar)) {
+                        covered.add(oldIdx);
+                    }
+                }
+                assertThat(covered)
+                        .as("old=%d new=%d: all old subtasks assigned", oldPar, newPar)
+                        .hasSize(oldPar);
+            }
+        }
+    }
+
+    /** Pins three values of {@code newSubtaskAssignedFrom}; each equals {@code oldIdx % newPar}. */
+    @Test
+    void newSubtaskAssignedFromIsModulo() {
+        // Each row: {old subtask index, new parallelism, expected new subtask index}.
+        int[][] cases = {{5, 3, 2}, {0, 4, 0}, {7, 3, 1}};
+        for (int[] c : cases) {
+            assertThat(PointwiseChannelMappingUtils.newSubtaskAssignedFrom(c[0], c[1]))
+                    .isEqualTo(c[2]);
+        }
+    }
+
+ // ======================== traceOutputSources ========================
+
+    /** Pins {@code traceOutputSources} for two new upstreams of a PW 3:3 -> PW 4:3 rescale. */
+    @Test
+    void traceOutputSourcesPwToPw() {
+        // 3 up -> 3 down (PW), rescale to 4 up -> 3 down (PW).
+        PointwiseRescaleParams rescale = params(3, 3, 4, 3);
+
+        // newUp=0: consumers in new topology = consumersOf(0,4,3) = {0}
+        //   newDown=0: oldSubtasksAssignedTo(0,3,3) = {0}
+        //   oldDown=0: producersOf(0,3,3) = {0}
+        //   => result {0}
+        int[] sourcesOfFirst = PointwiseChannelMappingUtils.traceOutputSources(0, rescale);
+        assertThat(sourcesOfFirst).containsExactly(0);
+
+        // newUp=3: consumers in new topology = consumersOf(3,4,3) = {2}
+        //   newDown=2: oldSubtasksAssignedTo(2,3,3) = {2}
+        //   oldDown=2: producersOf(2,3,3) = {2}
+        //   => result {2}
+        int[] sourcesOfLast = PointwiseChannelMappingUtils.traceOutputSources(3, rescale);
+        assertThat(sourcesOfLast).containsExactly(2);
+    }
+
+    /** Old A2A 3:3 -> new PW 4:3: every new upstream is expected to trace back to {0, 1, 2}. */
+    @Test
+    void traceOutputSourcesA2aToPw() {
+        // Under the old A2A pattern any old downstream connects to all 3 old upstreams,
+        // so the result is {0,1,2} regardless of the new upstream index.
+        PointwiseRescaleParams rescale = params(A2A, PW, 3, 3, 4, 3);
+        for (int newUp = 0; newUp < 4; newUp++) {
+            int[] sources = PointwiseChannelMappingUtils.traceOutputSources(newUp, rescale);
+            assertThat(sources).as("newUp=%d", newUp).containsExactly(0, 1, 2);
+        }
+    }
+
+    /** Old PW 3:3 -> new A2A 4:3: every new upstream is expected to trace back to {0, 1, 2}. */
+    @Test
+    void traceOutputSourcesPwToA2a() {
+        // Under the new A2A pattern every new upstream serves ALL new downstreams {0,1,2};
+        // oldSubtasksAssignedTo then covers all 3 old downstreams, reaching all old producers.
+        PointwiseRescaleParams rescale = params(PW, A2A, 3, 3, 4, 3);
+        for (int newUp = 0; newUp < 4; newUp++) {
+            int[] sources = PointwiseChannelMappingUtils.traceOutputSources(newUp, rescale);
+            assertThat(sources).as("newUp=%d", newUp).containsExactly(0, 1, 2);
+        }
+    }
+
+    /**
+     * For all PW -> PW parallelism combinations from 1 to 6: the union of {@code
+     * traceOutputSources(u, p)} over all new upstream subtasks must contain exactly {@code oldUp}
+     * distinct old upstream indices.
+     */
+    @Test
+    void traceOutputSourcesCoversAllOldUpstreams() {
+        for (int oldUp = 1; oldUp <= 6; oldUp++) {
+            for (int oldDown = 1; oldDown <= 6; oldDown++) {
+                for (int newUp = 1; newUp <= 6; newUp++) {
+                    for (int newDown = 1; newDown <= 6; newDown++) {
+                        PointwiseRescaleParams p = params(oldUp, oldDown, newUp, newDown);
+                        // Typed set (not a raw Set) so the boxed int insertions are
+                        // compiler-checked.
+                        Set<Integer> allClaimed = new HashSet<>();
+                        for (int u = 0; u < newUp; u++) {
+                            for (int oldU : PointwiseChannelMappingUtils.traceOutputSources(u, p)) {
+                                allClaimed.add(oldU);
+                            }
+                        }
+                        assertThat(allClaimed)
+                                .as(
+                                        "PW %d:%d->%d:%d: all old upstreams covered",
+                                        oldUp, oldDown, newUp, newDown)
+                                .hasSize(oldUp);
+                    }
+                }
+            }
+        }
+    }
+
+ // ======================== resolveInputOwnership ========================
+
+    /** Pins the ownership array of new downstream 0 for a PW 3:3 -> PW 4:3 rescale. */
+    @Test
+    void resolveInputOwnershipPwToPw() {
+        // 3 up -> 3 down (PW), rescale to 4 up -> 3 down (PW).
+        // newDown=0: connected new upstreams = producersOf(0,4,3) = {0}
+        //   newUp=0: traceOutputSources = {0}  =>  mapping[0] = 0
+        PointwiseRescaleParams rescale = params(3, 3, 4, 3);
+        int[] ownerByOldUpstream = PointwiseChannelMappingUtils.resolveInputOwnership(0, rescale);
+        assertThat(ownerByOldUpstream).hasSize(3);
+        assertThat(ownerByOldUpstream[0]).isEqualTo(0);
+    }
+
+    /**
+     * For all PW -> PW parallelism combinations from 1 to 6 and every new downstream subtask: each
+     * old upstream that produced for an old downstream assigned to that new downstream must have a
+     * non-negative entry in the ownership array.
+     */
+    @Test
+    void resolveInputOwnershipNoUnclaimedForConnectedOldUpstreams() {
+        for (int oldUp = 1; oldUp <= 6; oldUp++) {
+            for (int oldDown = 1; oldDown <= 6; oldDown++) {
+                for (int newUp = 1; newUp <= 6; newUp++) {
+                    for (int newDown = 1; newDown <= 6; newDown++) {
+                        PointwiseRescaleParams rescale = params(oldUp, oldDown, newUp, newDown);
+                        for (int newD = 0; newD < newDown; newD++) {
+                            int[] owners =
+                                    PointwiseChannelMappingUtils.resolveInputOwnership(
+                                            newD, rescale);
+                            // Old downstreams whose state is assigned to this new downstream.
+                            int[] assignedOldDownstreams =
+                                    PointwiseChannelMappingUtils.oldSubtasksAssignedTo(
+                                            newD, oldDown, newDown);
+                            for (int oldD : assignedOldDownstreams) {
+                                // Every old producer of such an old downstream must be claimed.
+                                for (int oldU :
+                                        PointwiseChannelMappingUtils.producersOf(
+                                                oldD, oldUp, oldDown)) {
+                                    assertThat(owners[oldU])
+                                            .as(
+                                                    "%d:%d->%d:%d newD=%d oldD=%d oldU=%d must be claimed",
+                                                    oldUp, oldDown, newUp, newDown, newD, oldD,
+                                                    oldU)
+                                            .isGreaterThanOrEqualTo(0);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * For all PW -> PW parallelism combinations from 1 to 6 and every new downstream subtask: each
+     * non-negative owner in the ownership array must be one of that downstream's producers in the
+     * new topology.
+     */
+    @Test
+    void resolveInputOwnershipOwnerIsConnected() {
+        for (int oldUp = 1; oldUp <= 6; oldUp++) {
+            for (int oldDown = 1; oldDown <= 6; oldDown++) {
+                for (int newUp = 1; newUp <= 6; newUp++) {
+                    for (int newDown = 1; newDown <= 6; newDown++) {
+                        PointwiseRescaleParams rescale = params(oldUp, oldDown, newUp, newDown);
+                        for (int newD = 0; newD < newDown; newD++) {
+                            int[] owners =
+                                    PointwiseChannelMappingUtils.resolveInputOwnership(
+                                            newD, rescale);
+                            int[] newProducers =
+                                    PointwiseChannelMappingUtils.producersOf(newD, newUp, newDown);
+                            for (int oldU = 0; oldU < oldUp; oldU++) {
+                                // Negative entries are skipped; only claimed entries are checked.
+                                if (owners[oldU] < 0) {
+                                    continue;
+                                }
+                                assertThat(newProducers)
+                                        .as(
+                                                "%d:%d->%d:%d newD=%d: owner of oldU=%d should be connected",
+                                                oldUp, oldDown, newUp, newDown, newD, oldU)
+                                        .contains(owners[oldU]);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+ // ======================== Primary producer dedup ========================
+
+    /**
+     * When several upstreams connect to the same downstream, {@code
+     * computeNewLocalSubpartitionIndex} must be non-negative for exactly one of them, and that one
+     * must be the first entry of {@code producersOf} for the downstream.
+     */
+    @Test
+    void primaryProducerDedupExactlyOneWriterPerDownstream() {
+        // All cases have newUpPar > newDownPar, so multiple upstreams share a downstream.
+        int[][] parallelisms = {{4, 2}, {10, 3}, {11, 3}, {6, 2}, {7, 1}};
+        for (int[] pair : parallelisms) {
+            int newUp = pair[0];
+            int newDown = pair[1];
+            PointwiseRescaleParams rescale = params(newUp, newDown, newUp, newDown);
+            for (int d = 0; d < newDown; d++) {
+                int writers = 0;
+                int lastWriter = -1;
+                for (int u = 0; u < newUp; u++) {
+                    int localSubpartition =
+                            PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(
+                                    d, u, rescale);
+                    if (localSubpartition >= 0) {
+                        writers++;
+                        lastWriter = u;
+                    }
+                }
+                assertThat(writers)
+                        .as("%d:%d downstream %d: exactly one writer", newUp, newDown, d)
+                        .isEqualTo(1);
+                // The writer must be producersOf(d)[0] (primary producer).
+                int primary = PointwiseChannelMappingUtils.producersOf(d, newUp, newDown)[0];
+                assertThat(lastWriter)
+                        .as("%d:%d downstream %d: writer is primary producer", newUp, newDown, d)
+                        .isEqualTo(primary);
+            }
+        }
+    }
+
+ // ======================== Modulo-based subpartition/channel mapping ========================
+
+    /**
+     * The execution graph assigns subpartitions via {@code consumerSubtaskIndex % numConsumers}
+     * (see {@code VertexInputInfoComputationUtils.computeConsumedSubpartitionRange}). When
+     * consumers don't start at a multiple of numConsumers, this differs from sequential indexing
+     * ({@code D - consumers[0]}). Both {@code localIndexToGlobalSubtaskIndex} (output side) and
+     * {@code computeNewLocalSubpartitionIndex} must use modulo, not sequential offset.
+     *
+     * <p>Concrete example: 4 upstream → 6 downstream POINTWISE. Upstream 2 has consumers [3,4].
+     * Sequential would map sp0→D3, sp1→D4. But the execution graph assigns D3→sp(3%2=1),
+     * D4→sp(4%2=0). So sp0→D4, sp1→D3. Using sequential causes buffers to be routed to the wrong
+     * downstream, leading to "Cannot select SubtaskConnectionDescriptor" errors at the DEMUX layer.
+     */
+    @Test
+    void moduloMappingDiffersFromSequentialWhenOffsetNotAligned() {
+        // 4 up -> 6 down: upstream 2 has consumers [3,4]
+        int[] consumers = PointwiseChannelMappingUtils.consumersOf(2, 4, 6);
+        assertThat(consumers).containsExactly(3, 4);
+
+        // localIndexToGlobalSubtaskIndex output side: sp → consumer via modulo
+        // sp 0 → consumer where c%2==0 → 4 (NOT sequential consumers[0]=3)
+        assertThat(
+                        PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                                2, 0, 4, 6, false, PW))
+                .isEqualTo(4);
+        // sp 1 → consumer where c%2==1 → 3 (NOT sequential consumers[1]=4)
+        assertThat(
+                        PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                                2, 1, 4, 6, false, PW))
+                .isEqualTo(3);
+
+        // computeNewLocalSubpartitionIndex: consumer → sp via modulo
+        PointwiseRescaleParams p = params(4, 6, 4, 6);
+        // D=3 → sp 3%2=1 (NOT sequential 3-3=0)
+        assertThat(PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(3, 2, p))
+                .isEqualTo(1);
+        // D=4 → sp 4%2=0 (NOT sequential 4-3=1)
+        assertThat(PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(4, 2, p))
+                .isEqualTo(0);
+    }
+
+    /**
+     * Two edges whose consumer range for upstream 1 starts at a non-zero offset: the local
+     * subpartition index of every consumer must equal {@code D % numConsumers}, not {@code D -
+     * consumers[0]}.
+     */
+    @Test
+    void subpartitionIndexModuloForNonZeroOffset() {
+        // The two formulas only disagree when consumers[0] % consumers.length != 0.
+
+        // 3 up -> 10 down (PW): upstream 1 has consumers [4,5,6]
+        // D=4 → sp 4%3=1, D=5 → sp 5%3=2, D=6 → sp 6%3=0
+        PointwiseRescaleParams threeToTen = params(3, 10, 3, 10);
+        int[] consumersA = PointwiseChannelMappingUtils.consumersOf(1, 3, 10);
+        assertThat(consumersA).startsWith(4); // sanity check: non-zero offset
+        for (int consumer : consumersA) {
+            int expectedSp = consumer % consumersA.length;
+            assertThat(
+                            PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(
+                                    consumer, 1, threeToTen))
+                    .as("3:10 upstream 1, D=%d → sp %d", consumer, expectedSp)
+                    .isEqualTo(expectedSp);
+        }
+
+        // 2 up -> 5 down (PW): upstream 1 has consumers [3,4]
+        // D=3 → sp 3%2=1, D=4 → sp 4%2=0
+        PointwiseRescaleParams twoToFive = params(2, 5, 2, 5);
+        int[] consumersB = PointwiseChannelMappingUtils.consumersOf(1, 2, 5);
+        assertThat(consumersB).startsWith(3);
+        for (int consumer : consumersB) {
+            int expectedSp = consumer % consumersB.length;
+            assertThat(
+                            PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(
+                                    consumer, 1, twoToFive))
+                    .as("2:5 upstream 1, D=%d → sp %d", consumer, expectedSp)
+                    .isEqualTo(expectedSp);
+        }
+    }
+
+    /**
+     * Exhaustive check for all upPar x downPar up to 8: whenever upstream {@code u} is the first
+     * producer of consumer {@code D}, the local subpartition index must be {@code D %
+     * numConsumers}.
+     */
+    @Test
+    void subpartitionIndexIsModuloForAllParallelisms() {
+        for (int upPar = 1; upPar <= 8; upPar++) {
+            for (int downPar = 1; downPar <= 8; downPar++) {
+                PointwiseRescaleParams rescale = params(upPar, downPar, upPar, downPar);
+                for (int u = 0; u < upPar; u++) {
+                    int[] consumers = PointwiseChannelMappingUtils.consumersOf(u, upPar, downPar);
+                    for (int consumer : consumers) {
+                        int localSubpartition =
+                                PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(
+                                        consumer, u, rescale);
+                        int primary =
+                                PointwiseChannelMappingUtils.producersOf(consumer, upPar, downPar)[
+                                        0];
+                        // Only the first producer of a consumer is checked.
+                        if (primary != u) {
+                            continue;
+                        }
+                        assertThat(localSubpartition)
+                                .as(
+                                        "up=%d down=%d u=%d D=%d: sp should be D %% numConsumers",
+                                        upPar, downPar, u, consumer)
+                                .isEqualTo(consumer % consumers.length);
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * Exhaustive check for all upPar x downPar up to 8: on the output side, the global consumer
+     * returned for subpartition {@code sp} must satisfy {@code D % numConsumers == sp}.
+     */
+    @Test
+    void outputSideLocalIndexModuloForAllParallelisms() {
+        for (int upPar = 1; upPar <= 8; upPar++) {
+            for (int downPar = 1; downPar <= 8; downPar++) {
+                for (int u = 0; u < upPar; u++) {
+                    int numConsumers =
+                            PointwiseChannelMappingUtils.consumersOf(u, upPar, downPar).length;
+                    for (int sp = 0; sp < numConsumers; sp++) {
+                        int consumer =
+                                PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                                        u, sp, upPar, downPar, false, PW);
+                        assertThat(consumer % numConsumers)
+                                .as(
+                                        "up=%d down=%d u=%d sp=%d: D %% numConsumers should equal sp",
+                                        upPar, downPar, u, sp)
+                                .isEqualTo(sp);
+                    }
+                }
+            }
+        }
+    }
+
+ // ===== Round-trip: localIndexToGlobalSubtaskIndex ↔ subpartition / input channel index =====
+
+    /**
+     * Output-side round trip for all upPar x downPar up to 6: local subpartition index -> global
+     * consumer -> local subpartition index. The way back is only checked when the upstream is the
+     * first producer of the consumer.
+     */
+    @Test
+    void localIndexToGlobalSubtaskIndexAndSubpartitionIndexAreInverses() {
+        for (int upPar = 1; upPar <= 6; upPar++) {
+            for (int downPar = 1; downPar <= 6; downPar++) {
+                PointwiseRescaleParams rescale = params(upPar, downPar, upPar, downPar);
+                for (int u = 0; u < upPar; u++) {
+                    int[] consumers = PointwiseChannelMappingUtils.consumersOf(u, upPar, downPar);
+                    for (int local = 0; local < consumers.length; local++) {
+                        int consumer =
+                                PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                                        u, local, upPar, downPar, false, PW);
+                        // Output side uses modulo mapping (D % numConsumers == localIndex),
+                        // not sequential consumers[localIndex].
+                        assertThat(consumers).contains(consumer);
+                        assertThat(consumer % consumers.length)
+                                .as("up=%d down=%d u=%d local=%d", upPar, downPar, u, local)
+                                .isEqualTo(local);
+                        int primary =
+                                PointwiseChannelMappingUtils.producersOf(consumer, upPar, downPar)[
+                                        0];
+                        // Only the primary producer can round-trip (others get -1).
+                        if (primary != u) {
+                            continue;
+                        }
+                        int roundTripped =
+                                PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(
+                                        consumer, u, rescale);
+                        assertThat(roundTripped)
+                                .as(
+                                        "up=%d down=%d u=%d: round-trip localIndex=%d",
+                                        upPar, downPar, u, local)
+                                .isEqualTo(local);
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * Input-side round trip for all upPar x downPar up to 6: local channel index -> global
+     * producer (which must equal {@code producers[localIndex]}) -> local channel index.
+     */
+    @Test
+    void localIndexToGlobalSubtaskIndexAndInputChannelIndexAreInverses() {
+        for (int upPar = 1; upPar <= 6; upPar++) {
+            for (int downPar = 1; downPar <= 6; downPar++) {
+                PointwiseRescaleParams rescale = params(upPar, downPar, upPar, downPar);
+                for (int d = 0; d < downPar; d++) {
+                    int[] producers = PointwiseChannelMappingUtils.producersOf(d, upPar, downPar);
+                    for (int local = 0; local < producers.length; local++) {
+                        int producer =
+                                PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex(
+                                        d, local, upPar, downPar, true, PW);
+                        assertThat(producer).isEqualTo(producers[local]);
+                        int roundTripped =
+                                PointwiseChannelMappingUtils.computeNewLocalInputChannelIndex(
+                                        producer, d, rescale);
+                        assertThat(roundTripped)
+                                .as(
+                                        "up=%d down=%d d=%d: round-trip localIndex=%d",
+                                        upPar, downPar, d, local)
+                                .isEqualTo(local);
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseRescalingDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseRescalingDescriptorTest.java
new file mode 100644
index 0000000000000..5bc8fad43f355
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseRescalingDescriptorTest.java
@@ -0,0 +1,632 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.JobException;
+import org.apache.flink.runtime.OperatorIDPair;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
+import org.apache.flink.runtime.client.JobExecutionException;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.JobEdge;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.JobGraphTestUtils;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.testtasks.NoOpInvokable;
+import org.apache.flink.runtime.util.JobVertexConnectionUtils;
+import org.apache.flink.testutils.TestingUtils;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewInputChannelStateHandle;
+import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewResultSubpartitionStateHandle;
+import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ROUND_ROBIN;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Integration tests for POINTWISE edge rescaling descriptor generation in {@link
+ * StateAssignmentOperation} and {@link TaskStateAssignment}.
+ */
+class PointwiseRescalingDescriptorTest {
+
+ @RegisterExtension
+ private static final TestExecutorExtension EXECUTOR_EXTENSION =
+ TestingUtils.defaultExecutorExtension();
+
+ private static final int MAX_P = 256;
+
+ // ===== Test 1: POINTWISE, same parallelism and same pattern => NO_RESCALE on both sides =====
+ @Test
+ void testPointwiseSameParallelismIsIdentity() throws Exception {
+ OperatorID sourceId = new OperatorID();
+ OperatorID sinkId = new OperatorID();
+ int oldPar = 3; // parallelism recorded in the restored operator states
+ int newPar = 3; // parallelism of the newly created job vertices
+
+ JobVertex source = createJobVertex(sourceId, newPar);
+ JobVertex sink = createJobVertex(sinkId, newPar);
+ connectPointwise(source, sink);
+
+ Map vertices = toExecutionVertices(source, sink);
+ Map states =
+ buildStatesWithChannelState(sourceId, sinkId, oldPar);
+
+ EdgeDistributionPatternSnapshot oldEdgePatterns =
+ buildEdgePatterns(sourceId, DistributionPattern.POINTWISE);
+
+ new StateAssignmentOperation(
+ 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns)
+ .assignStates();
+
+ // Same parallelism + same pattern: every sink input and source output descriptor is NO_RESCALE
+ for (int subtask = 0; subtask < newPar; subtask++) {
+ OperatorSubtaskState sinkState =
+ getAssignedState(vertices.get(sinkId), sinkId, subtask);
+ assertThat(sinkState.getInputRescalingDescriptor())
+ .isEqualTo(InflightDataRescalingDescriptor.NO_RESCALE);
+
+ OperatorSubtaskState sourceState =
+ getAssignedState(vertices.get(sourceId), sourceId, subtask);
+ assertThat(sourceState.getOutputRescalingDescriptor())
+ .isEqualTo(InflightDataRescalingDescriptor.NO_RESCALE);
+ }
+ }
+
+ // ===== Test 2: POINTWISE scale up - subtasks with a non-trivial descriptor carry old/new parallelism =====
+ @Test
+ void testPointwiseScaleUp() throws Exception {
+ // old: source(2) --PW--> sink(2), new: source(4) --PW--> sink(4)
+ OperatorID sourceId = new OperatorID();
+ OperatorID sinkId = new OperatorID();
+ int oldPar = 2;
+ int newPar = 4;
+
+ JobVertex source = createJobVertex(sourceId, newPar);
+ JobVertex sink = createJobVertex(sinkId, newPar);
+ connectPointwise(source, sink);
+
+ Map vertices = toExecutionVertices(source, sink);
+ Map states =
+ buildStatesWithChannelState(sourceId, sinkId, oldPar);
+
+ EdgeDistributionPatternSnapshot oldEdgePatterns =
+ buildEdgePatterns(sourceId, DistributionPattern.POINTWISE);
+
+ new StateAssignmentOperation(
+ 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns)
+ .assignStates();
+
+ // Sink subtasks whose input descriptor is not NO_RESCALE must carry the old/new parallelism and old pattern
+ int subtasksWithDescriptor = 0;
+ for (int subtask = 0; subtask < newPar; subtask++) {
+ OperatorSubtaskState sinkState =
+ getAssignedState(vertices.get(sinkId), sinkId, subtask);
+ InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor();
+
+ if (!inputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) {
+ subtasksWithDescriptor++;
+ InflightDataGateOrPartitionRescalingDescriptor gateDesc =
+ inputDesc.getGateOrPartitionDescriptor(0);
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism())
+ .isEqualTo(newPar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewDownParallelism())
+ .isEqualTo(newPar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldUpParallelism())
+ .isEqualTo(oldPar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldDownParallelism())
+ .isEqualTo(oldPar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldDistributionPattern())
+ .isEqualTo(DistributionPattern.POINTWISE);
+ assertThat(gateDesc.isIdentity()).isFalse();
+
+ int[] oldSubtaskIndexes = inputDesc.getOldSubtaskIndexes(0);
+ assertThat(oldSubtaskIndexes.length).isGreaterThan(0);
+ }
+ }
+ // At least one sink subtask must have received a non-trivial input descriptor
+ assertThat(subtasksWithDescriptor).isGreaterThan(0);
+
+ // Same check on the output side; isIdentity() and old subtask indexes are not checked here
+ int sourceSubtasksWithDescriptor = 0;
+ for (int subtask = 0; subtask < newPar; subtask++) {
+ OperatorSubtaskState sourceState =
+ getAssignedState(vertices.get(sourceId), sourceId, subtask);
+ InflightDataRescalingDescriptor outputDesc = sourceState.getOutputRescalingDescriptor();
+
+ if (!outputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) {
+ sourceSubtasksWithDescriptor++;
+ InflightDataGateOrPartitionRescalingDescriptor partDesc =
+ outputDesc.getGateOrPartitionDescriptor(0);
+ assertThat(partDesc.getPointwiseRescaleParams().getNewUpParallelism())
+ .isEqualTo(newPar);
+ assertThat(partDesc.getPointwiseRescaleParams().getNewDownParallelism())
+ .isEqualTo(newPar);
+ assertThat(partDesc.getPointwiseRescaleParams().getOldUpParallelism())
+ .isEqualTo(oldPar);
+ assertThat(partDesc.getPointwiseRescaleParams().getOldDownParallelism())
+ .isEqualTo(oldPar);
+ assertThat(partDesc.getPointwiseRescaleParams().getOldDistributionPattern())
+ .isEqualTo(DistributionPattern.POINTWISE);
+ }
+ }
+ assertThat(sourceSubtasksWithDescriptor).isGreaterThan(0); // at least one source subtask has a non-trivial output descriptor
+ }
+
+ // ===== Test 3: POINTWISE scale down - all old subtasks covered =====
+ @Test
+ void testPointwiseScaleDown() throws Exception {
+ // old: source(4) --PW--> sink(4), new: source(2) --PW--> sink(2)
+ OperatorID sourceId = new OperatorID();
+ OperatorID sinkId = new OperatorID();
+ int oldPar = 4;
+ int newPar = 2;
+
+ JobVertex source = createJobVertex(sourceId, newPar);
+ JobVertex sink = createJobVertex(sinkId, newPar);
+ connectPointwise(source, sink);
+
+ Map vertices = toExecutionVertices(source, sink);
+ Map states =
+ buildStatesWithChannelState(sourceId, sinkId, oldPar);
+
+ EdgeDistributionPatternSnapshot oldEdgePatterns =
+ buildEdgePatterns(sourceId, DistributionPattern.POINTWISE);
+
+ new StateAssignmentOperation(
+ 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns)
+ .assignStates();
+
+ // Scale down: every new sink subtask must have a descriptor that references at least one old subtask
+ boolean[] covered = new boolean[oldPar]; // covered[i] is set once old subtask i appears in some descriptor
+ for (int subtask = 0; subtask < newPar; subtask++) {
+ OperatorSubtaskState sinkState =
+ getAssignedState(vertices.get(sinkId), sinkId, subtask);
+ InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor();
+ assertThat(inputDesc)
+ .as("subtask %d should have rescaling descriptor", subtask)
+ .isNotEqualTo(InflightDataRescalingDescriptor.NO_RESCALE);
+
+ InflightDataGateOrPartitionRescalingDescriptor gateDesc =
+ inputDesc.getGateOrPartitionDescriptor(0);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldUpParallelism())
+ .isEqualTo(oldPar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldDownParallelism())
+ .isEqualTo(oldPar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism())
+ .isEqualTo(newPar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewDownParallelism())
+ .isEqualTo(newPar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldDistributionPattern())
+ .isEqualTo(DistributionPattern.POINTWISE);
+
+ int[] oldSubtaskIndexes = inputDesc.getOldSubtaskIndexes(0);
+ assertThat(oldSubtaskIndexes.length).isGreaterThan(0);
+ for (int oldIdx : oldSubtaskIndexes) {
+ covered[oldIdx] = true;
+ }
+ }
+ // Every old sink subtask must be referenced by at least one new sink subtask
+ for (int i = 0; i < oldPar; i++) {
+ assertThat(covered[i]).as("Old subtask %d should be covered", i).isTrue();
+ }
+ }
+
+ // ===== Test 4: A2A -> POINTWISE migration =====
+ @Test
+ void testAllToAllToPointwiseMigration() throws Exception {
+ // old: source(3) --A2A--> sink(3), new: source(3) --PW--> sink(3)
+ OperatorID sourceId = new OperatorID();
+ OperatorID sinkId = new OperatorID();
+ int oldPar = 3;
+ int newPar = 3;
+
+ JobVertex source = createJobVertex(sourceId, newPar);
+ JobVertex sink = createJobVertex(sinkId, newPar);
+ connectPointwise(source, sink); // the new edge is POINTWISE; the old pattern comes from the snapshot below
+
+ Map vertices = toExecutionVertices(source, sink);
+ Map states =
+ buildStatesWithChannelState(sourceId, sinkId, oldPar);
+
+ // Old pattern was A2A
+ EdgeDistributionPatternSnapshot oldEdgePatterns =
+ buildEdgePatterns(sourceId, DistributionPattern.ALL_TO_ALL);
+
+ new StateAssignmentOperation(
+ 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns)
+ .assignStates();
+
+ // A2A->PW same-par: new sink k reads from old sink k only (ROUND_ROBIN 1:1).
+ // TM-side recoverPointwise then routes each old upstream's buffer to the correct new
+ // upstream channel via resolveInputOwnership, so reading the extra old sinks would
+ // cause each inflight record to be delivered to all new subtasks (3x duplicates).
+ for (int subtask = 0; subtask < newPar; subtask++) {
+ OperatorSubtaskState sinkState =
+ getAssignedState(vertices.get(sinkId), sinkId, subtask);
+ InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor();
+ assertThat(inputDesc)
+ .as("subtask %d", subtask)
+ .isNotEqualTo(InflightDataRescalingDescriptor.NO_RESCALE); // a pattern change alone must yield a descriptor
+
+ int[] oldSubtaskIndexes = inputDesc.getOldSubtaskIndexes(0);
+ // ROUND_ROBIN 1:1 for same parallelism: new sink k -> old sink k
+ assertThat(oldSubtaskIndexes).isEqualTo(new int[] {subtask});
+
+ InflightDataGateOrPartitionRescalingDescriptor gateDesc =
+ inputDesc.getGateOrPartitionDescriptor(0);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldDistributionPattern())
+ .isEqualTo(DistributionPattern.ALL_TO_ALL);
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewDistributionPattern())
+ .isEqualTo(DistributionPattern.POINTWISE);
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism())
+ .isEqualTo(newPar);
+ }
+ }
+
+ // ===== Test 5: POINTWISE -> A2A migration =====
+ @Test
+ void testPointwiseToAllToAllMigration() throws Exception {
+ // old: source(3) --PW--> sink(3), new: source(3) --A2A--> sink(3)
+ OperatorID sourceId = new OperatorID();
+ OperatorID sinkId = new OperatorID();
+ int oldPar = 3;
+ int newPar = 3;
+
+ JobVertex source = createJobVertex(sourceId, newPar);
+ JobVertex sink = createJobVertex(sinkId, newPar);
+ connectAllToAll(source, sink); // the new edge is ALL_TO_ALL; the old pattern comes from the snapshot below
+
+ Map vertices = toExecutionVertices(source, sink);
+ Map states =
+ buildStatesWithChannelState(sourceId, sinkId, oldPar);
+
+ // Old pattern was POINTWISE
+ EdgeDistributionPatternSnapshot oldEdgePatterns =
+ buildEdgePatterns(sourceId, DistributionPattern.POINTWISE);
+
+ new StateAssignmentOperation(
+ 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns)
+ .assignStates();
+
+ // PW->A2A: enters the pointwise path (oldPattern != A2A)
+ for (int subtask = 0; subtask < newPar; subtask++) {
+ OperatorSubtaskState sinkState =
+ getAssignedState(vertices.get(sinkId), sinkId, subtask);
+ InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor();
+ assertThat(inputDesc)
+ .as("subtask %d", subtask)
+ .isNotEqualTo(InflightDataRescalingDescriptor.NO_RESCALE);
+
+ InflightDataGateOrPartitionRescalingDescriptor gateDesc =
+ inputDesc.getGateOrPartitionDescriptor(0);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldDistributionPattern())
+ .isEqualTo(DistributionPattern.POINTWISE);
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewDistributionPattern())
+ .isEqualTo(DistributionPattern.ALL_TO_ALL);
+ // PW->A2A: newUpParallelism == newPar (non-zero) confirms pointwise path used
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism())
+ .isEqualTo(newPar);
+ }
+ }
+
+ // ===== Test 6: Pure A2A rescaling (regression - original path unchanged) =====
+ @Test
+ void testPureAllToAllRescalingRegression() throws Exception {
+ // old: source(2) --A2A--> sink(2), new: source(3) --A2A--> sink(3)
+ OperatorID sourceId = new OperatorID();
+ OperatorID sinkId = new OperatorID();
+ int oldPar = 2;
+ int newPar = 3;
+
+ JobVertex source = createJobVertex(sourceId, newPar);
+ JobVertex sink = createJobVertex(sinkId, newPar);
+ connectAllToAll(source, sink);
+
+ Map vertices = toExecutionVertices(source, sink);
+ Map states =
+ buildStatesWithChannelState(sourceId, sinkId, oldPar);
+
+ // Old pattern was A2A (same as new)
+ EdgeDistributionPatternSnapshot oldEdgePatterns =
+ buildEdgePatterns(sourceId, DistributionPattern.ALL_TO_ALL);
+
+ new StateAssignmentOperation(
+ 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns)
+ .assignStates();
+
+ // A2A->A2A: should use the original all-to-all path
+ // Any non-trivial descriptor should have newUpParallelism == 0 and oldUpParallelism == 0
+ for (int subtask = 0; subtask < newPar; subtask++) {
+ OperatorSubtaskState sinkState =
+ getAssignedState(vertices.get(sinkId), sinkId, subtask);
+ InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor();
+
+ if (!inputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) { // NOTE(review): nothing asserts that any subtask enters this branch, so the test could pass without checking anything
+ InflightDataGateOrPartitionRescalingDescriptor gateDesc =
+ inputDesc.getGateOrPartitionDescriptor(0);
+ // A2A->A2A uses 4-arg constructor, scalars all 0
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism()).isEqualTo(0);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldUpParallelism()).isEqualTo(0);
+ }
+ }
+ }
+
+ // ===== Test 7: POINTWISE cross-parallelism (source and sink parallelism differ, both change) =====
+ @Test
+ void testPointwiseCrossParallelism() throws Exception {
+ OperatorID sourceId = new OperatorID();
+ OperatorID sinkId = new OperatorID();
+ int oldSourcePar = 3;
+ int oldSinkPar = 2;
+ int newSourcePar = 4;
+ int newSinkPar = 3;
+
+ JobVertex source = createJobVertex(sourceId, newSourcePar);
+ JobVertex sink = createJobVertex(sinkId, newSinkPar);
+ connectPointwise(source, sink);
+
+ Map vertices = toExecutionVertices(source, sink);
+
+ Map states = new HashMap<>(); // built inline: buildStatesWithChannelState takes a single old parallelism
+ Random random = new Random(42);
+
+ OperatorState sourceState = new OperatorState(null, null, sourceId, oldSourcePar, MAX_P);
+ for (int i = 0; i < oldSourcePar; i++) {
+ sourceState.putState(
+ i,
+ OperatorSubtaskState.builder()
+ .setResultSubpartitionState(
+ new StateObjectCollection<>(
+ Collections.singletonList(
+ createNewResultSubpartitionStateHandle(
+ 10, 0, random))))
+ .build());
+ }
+ states.put(sourceId, sourceState);
+
+ OperatorState sinkState = new OperatorState(null, null, sinkId, oldSinkPar, MAX_P);
+ for (int i = 0; i < oldSinkPar; i++) {
+ sinkState.putState(
+ i,
+ OperatorSubtaskState.builder()
+ .setInputChannelState(
+ new StateObjectCollection<>(
+ Collections.singletonList(
+ createNewInputChannelStateHandle(
+ 10, 0, random))))
+ .build());
+ }
+ states.put(sinkId, sinkState);
+
+ EdgeDistributionPatternSnapshot oldEdgePatterns =
+ buildEdgePatterns(sourceId, DistributionPattern.POINTWISE);
+
+ new StateAssignmentOperation(
+ 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns)
+ .assignStates();
+
+ // Non-trivial input descriptors must carry all four distinct parallelism values
+ int subtasksWithDesc = 0;
+ for (int subtask = 0; subtask < newSinkPar; subtask++) {
+ OperatorSubtaskState assigned = getAssignedState(vertices.get(sinkId), sinkId, subtask);
+ InflightDataRescalingDescriptor inputDesc = assigned.getInputRescalingDescriptor();
+
+ if (!inputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) {
+ subtasksWithDesc++;
+ InflightDataGateOrPartitionRescalingDescriptor gateDesc =
+ inputDesc.getGateOrPartitionDescriptor(0);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldUpParallelism())
+ .isEqualTo(oldSourcePar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldDownParallelism())
+ .isEqualTo(oldSinkPar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism())
+ .isEqualTo(newSourcePar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewDownParallelism())
+ .isEqualTo(newSinkPar);
+ assertThat(gateDesc.getPointwiseRescaleParams().getOldDistributionPattern())
+ .isEqualTo(DistributionPattern.POINTWISE);
+ }
+ }
+ assertThat(subtasksWithDesc).isGreaterThan(0); // at least one sink subtask has a non-trivial descriptor
+
+ // Same parallelism checks on the output side (the old pattern is not checked here)
+ int sourceSubtasksWithDesc = 0;
+ for (int subtask = 0; subtask < newSourcePar; subtask++) {
+ OperatorSubtaskState assigned =
+ getAssignedState(vertices.get(sourceId), sourceId, subtask);
+ InflightDataRescalingDescriptor outputDesc = assigned.getOutputRescalingDescriptor();
+
+ if (!outputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) {
+ sourceSubtasksWithDesc++;
+ InflightDataGateOrPartitionRescalingDescriptor partDesc =
+ outputDesc.getGateOrPartitionDescriptor(0);
+ assertThat(partDesc.getPointwiseRescaleParams().getOldUpParallelism())
+ .isEqualTo(oldSourcePar);
+ assertThat(partDesc.getPointwiseRescaleParams().getOldDownParallelism())
+ .isEqualTo(oldSinkPar);
+ assertThat(partDesc.getPointwiseRescaleParams().getNewUpParallelism())
+ .isEqualTo(newSourcePar);
+ assertThat(partDesc.getPointwiseRescaleParams().getNewDownParallelism())
+ .isEqualTo(newSinkPar);
+ }
+ }
+ assertThat(sourceSubtasksWithDesc).isGreaterThan(0); // at least one source subtask has a non-trivial descriptor
+ }
+
+ // ===== Test 8: No EdgeDistributionPatternSnapshot -> treated as A2A (backward compat) =====
+ @Test
+ void testNullEdgePatternsDefaultsToAllToAll() throws Exception {
+ // old: source(2) --?--> sink(2), new: source(3) --A2A--> sink(3)
+ OperatorID sourceId = new OperatorID();
+ OperatorID sinkId = new OperatorID();
+ int oldPar = 2;
+ int newPar = 3;
+
+ JobVertex source = createJobVertex(sourceId, newPar);
+ JobVertex sink = createJobVertex(sinkId, newPar);
+ connectAllToAll(source, sink);
+
+ Map vertices = toExecutionVertices(source, sink);
+ Map states =
+ buildStatesWithChannelState(sourceId, sinkId, oldPar);
+
+ // No edge pattern snapshot: null is passed as the last constructor argument
+ new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, true, null)
+ .assignStates();
+
+ // Any non-trivial descriptor should be a standard A2A one (newUpParallelism == 0)
+ for (int subtask = 0; subtask < newPar; subtask++) {
+ OperatorSubtaskState sinkState =
+ getAssignedState(vertices.get(sinkId), sinkId, subtask);
+ InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor();
+
+ if (!inputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) { // NOTE(review): nothing asserts that any subtask enters this branch, so the test could pass without checking anything
+ InflightDataGateOrPartitionRescalingDescriptor gateDesc =
+ inputDesc.getGateOrPartitionDescriptor(0);
+ assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism()).isEqualTo(0);
+ }
+ }
+ }
+
+ // ===== Helper methods =====
+
+ private JobVertex createJobVertex(OperatorID operatorID, int parallelism) { // builds a single-operator NoOpInvokable vertex
+ JobVertex jobVertex =
+ new JobVertex(
+ operatorID.toHexString(),
+ new JobVertexID(),
+ Collections.singletonList(
+ OperatorIDPair.of(operatorID, operatorID, null, null))); // the same id is passed for both id arguments
+ jobVertex.setInvokableClass(NoOpInvokable.class);
+ jobVertex.setParallelism(parallelism);
+ return jobVertex;
+ }
+
+ private void connectPointwise(JobVertex upstream, JobVertex downstream) { // adds a POINTWISE, PIPELINED edge upstream -> downstream
+ final JobEdge jobEdge =
+ JobVertexConnectionUtils.connectNewDataSetAsInput(
+ downstream,
+ upstream,
+ DistributionPattern.POINTWISE,
+ ResultPartitionType.PIPELINED);
+ jobEdge.setDownstreamSubtaskStateMapper(ROUND_ROBIN); // both state mappers are ROUND_ROBIN, same as connectAllToAll
+ jobEdge.setUpstreamSubtaskStateMapper(ROUND_ROBIN);
+ }
+
+ private void connectAllToAll(JobVertex upstream, JobVertex downstream) { // adds an ALL_TO_ALL, PIPELINED edge upstream -> downstream
+ final JobEdge jobEdge =
+ JobVertexConnectionUtils.connectNewDataSetAsInput(
+ downstream,
+ upstream,
+ DistributionPattern.ALL_TO_ALL,
+ ResultPartitionType.PIPELINED);
+ jobEdge.setDownstreamSubtaskStateMapper(ROUND_ROBIN); // both state mappers are ROUND_ROBIN, same as connectPointwise
+ jobEdge.setUpstreamSubtaskStateMapper(ROUND_ROBIN);
+ }
+
+ private Map toExecutionVertices(JobVertex... jobVertices) // builds an ExecutionGraph and indexes its job vertices by operator id
+ throws JobException, JobExecutionException {
+ JobGraph jobGraph = JobGraphTestUtils.streamingJobGraph(jobVertices);
+ ExecutionGraph eg =
+ TestingDefaultExecutionGraphBuilder.newBuilder()
+ .setJobGraph(jobGraph)
+ .build(EXECUTOR_EXTENSION.getExecutor());
+ return Arrays.stream(jobVertices)
+ .collect(
+ Collectors.toMap(
+ jobVertex -> // key: generated operator id of the vertex's first operator
+ jobVertex.getOperatorIDs().get(0).getGeneratedOperatorID(),
+ jobVertex -> { // value: the ExecutionJobVertex the graph holds for this JobVertex
+ try {
+ return eg.getJobVertex(jobVertex.getID());
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }));
+ }
+
+ private Map buildStatesWithChannelState( // old-checkpoint states: one channel-state handle per old subtask of source and sink
+ OperatorID sourceId, OperatorID sinkId, int oldParallelism) {
+ Map states = new HashMap<>();
+ Random random = new Random(42); // fixed seed, shared by all handle factory calls below
+
+ OperatorState sourceState = new OperatorState(null, null, sourceId, oldParallelism, MAX_P);
+ for (int i = 0; i < oldParallelism; i++) {
+ sourceState.putState(
+ i,
+ OperatorSubtaskState.builder()
+ .setResultSubpartitionState( // the source side carries only result-subpartition state
+ new StateObjectCollection<>(
+ Collections.singletonList(
+ createNewResultSubpartitionStateHandle(
+ 10, 0, random))))
+ .build());
+ }
+ states.put(sourceId, sourceState);
+
+ OperatorState sinkState = new OperatorState(null, null, sinkId, oldParallelism, MAX_P);
+ for (int i = 0; i < oldParallelism; i++) {
+ sinkState.putState(
+ i,
+ OperatorSubtaskState.builder()
+ .setInputChannelState( // the sink side carries only input-channel state
+ new StateObjectCollection<>(
+ Collections.singletonList(
+ createNewInputChannelStateHandle(
+ 10, 0, random))))
+ .build());
+ }
+ states.put(sinkId, sinkState);
+
+ return states;
+ }
+
+ private EdgeDistributionPatternSnapshot buildEdgePatterns( // snapshot with a single entry: sourceId -> {pattern}
+ OperatorID sourceId, DistributionPattern pattern) {
+ Map patterns = new HashMap<>();
+ patterns.put(sourceId, new DistributionPattern[] {pattern}); // one-element array; presumably one slot per output edge - TODO confirm
+ return new EdgeDistributionPatternSnapshot(patterns);
+ }
+
+ private OperatorSubtaskState getAssignedState( // reads the state assigned to one operator on one subtask
+ ExecutionJobVertex executionJobVertex, OperatorID operatorId, int subtaskIdx) {
+ return executionJobVertex
+ .getTaskVertices()[subtaskIdx]
+ .getCurrentExecutionAttempt()
+ .getTaskRestore() // NOTE(review): no null check in this chain - presumably every subtask has a restore after assignStates(); confirm
+ .getTaskStateSnapshot()
+ .getSubtaskStateByOperatorID(operatorId);
+ }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
index 9c4aab0bc7a5d..5b8badef9d967 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
@@ -83,7 +83,8 @@ private InputChannelRecoveredStateHandler buildInputChannelStateHandler(
.MappingType.IDENTITY)
}),
null,
- MemoryManager.DEFAULT_PAGE_SIZE);
+ MemoryManager.DEFAULT_PAGE_SIZE,
+ 0);
}
private InputChannelRecoveredStateHandler buildMultiChannelHandler() {
@@ -111,7 +112,8 @@ private InputChannelRecoveredStateHandler buildMultiChannelHandler() {
.MappingType.RESCALING)
}),
null,
- MemoryManager.DEFAULT_PAGE_SIZE);
+ MemoryManager.DEFAULT_PAGE_SIZE,
+ 0);
}
/** Builds a handler in filtering mode (non-null filtering handler, no-op stub). */
@@ -136,7 +138,8 @@ private InputChannelRecoveredStateHandler buildFilteringInputChannelStateHandler
.MappingType.IDENTITY)
}),
stubFilteringHandler,
- MemoryManager.DEFAULT_PAGE_SIZE);
+ MemoryManager.DEFAULT_PAGE_SIZE,
+ 0);
}
@Test
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java
index 91d4800e6736a..cbb2ca60b34ee 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java
@@ -70,7 +70,8 @@ private ResultSubpartitionRecoveredStateHandler buildResultStateHandler(
InflightDataRescalingDescriptor
.InflightDataGateOrPartitionRescalingDescriptor
.MappingType.IDENTITY)
- }));
+ }),
+ 0);
}
@Test
diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/partitioner/ForceUnalignedSupportTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/partitioner/ForceUnalignedSupportTest.java
new file mode 100644
index 0000000000000..b78b8a8ed31b8
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/partitioner/ForceUnalignedSupportTest.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.partitioner;
+
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests for {@link StreamPartitioner#isSupportsUnalignedCheckpoint(boolean)} — specifically the
+ * {@code forceUnaligned} flag introduced to allow unaligned checkpoints on forward (pointwise)
+ * edges while keeping broadcast edges aligned.
+ */
+class ForceUnalignedSupportTest {
+
+ @Test
+ void testForwardPartitionerRespectsForceUnaligned() {
+ ForwardPartitioner partitioner = new ForwardPartitioner<>();
+
+ // Without force: the forward partitioner reports no support for unaligned checkpoints.
+ assertThat(partitioner.isSupportsUnalignedCheckpoint(false)).isFalse();
+ // With force: the forward partitioner reports support for unaligned checkpoints.
+ assertThat(partitioner.isSupportsUnalignedCheckpoint(true)).isTrue();
+ }
+
+ @Test
+ void testRescalePartitionerRespectsForceUnaligned() {
+ RescalePartitioner partitioner = new RescalePartitioner<>();
+
+ assertThat(partitioner.isSupportsUnalignedCheckpoint(false)).isFalse(); // same expectations as the forward partitioner
+ assertThat(partitioner.isSupportsUnalignedCheckpoint(true)).isTrue();
+ }
+
+ @Test
+ void testBroadcastPartitionerIgnoresForceUnaligned() {
+ BroadcastPartitioner partitioner = new BroadcastPartitioner<>();
+
+ // Broadcast edges must never support unaligned checkpoints, regardless of force.
+ assertThat(partitioner.isSupportsUnalignedCheckpoint(false)).isFalse();
+ assertThat(partitioner.isSupportsUnalignedCheckpoint(true)).isFalse();
+ }
+
+ @Test
+ void testRebalancePartitionerForceDoesNotFlipBehavior() {
+ RebalancePartitioner partitioner = new RebalancePartitioner<>();
+
+ boolean withoutForce = partitioner.isSupportsUnalignedCheckpoint(false);
+ boolean withForce = partitioner.isSupportsUnalignedCheckpoint(true);
+
+ assertThat(withoutForce).isTrue(); // supported even without the force flag
+ assertThat(withForce).isEqualTo(withoutForce); // the force flag must not change the answer
+ }
+
+ @Test
+ void testKeyGroupPartitionerForceDoesNotFlipBehavior() {
+ KeyGroupStreamPartitioner partitioner =
+ new KeyGroupStreamPartitioner<>((KeySelector) value -> value, 128); // identity key selector; 128 is presumably the max parallelism - TODO confirm
+
+ boolean withoutForce = partitioner.isSupportsUnalignedCheckpoint(false);
+ boolean withForce = partitioner.isSupportsUnalignedCheckpoint(true);
+
+ assertThat(withoutForce).isTrue(); // supported even without the force flag
+ assertThat(withForce).isEqualTo(withoutForce); // the force flag must not change the answer
+ }
+}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointShuffleChangeITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointShuffleChangeITCase.java
new file mode 100644
index 0000000000000..10a1ec1fffbfd
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointShuffleChangeITCase.java
@@ -0,0 +1,377 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing;
+
+import org.apache.flink.api.common.JobExecutionResult;
+import org.apache.flink.api.common.JobStatus;
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.connector.source.util.ratelimit.RateLimiterStrategy;
+import org.apache.flink.configuration.CheckpointingOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.configuration.TaskManagerOptions;
+import org.apache.flink.connector.datagen.source.DataGeneratorSource;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.legacy.RichSinkFunction;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
+import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestInfo;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.concurrent.ConcurrentSkipListSet;
+import java.util.concurrent.atomic.AtomicLong;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Integration test for changing distribution patterns (POINTWISE / ALL_TO_ALL) and parallelism
+ * between runs when restoring from an unaligned checkpoint.
+ *
+ * <p>Each parameterized case is a 9-tuple: (desc, oldUpPar, oldDownPar, oldPattern, newUpPar,
+ * newDownPar, newPattern, recoverOutputOnDownstream, checkpointDuringRecovery). Run 1 emits a
+ * bounded monotonic sequence {@code [0, TOTAL_COUNT)} XORed with a header, takes an in-flight
+ * unaligned checkpoint, and is cancelled. Run 2 restores with the new topology and runs to
+ * completion. The verifying sink asserts: (a) no header corruption (mis-routed buffers), (b) no
+ * per-subtask duplicates after restore, and (c) every value in {@code [0, TOTAL_COUNT)} arrives at
+ * the sink.
+ *
+ * <p>The test is parameterized over three recovery configurations:
+ *
+ *