diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index 2a86a566b1f00..cbb4635a9a359 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -1811,13 +1811,22 @@ private OptionalLong restoreLatestCheckpointedStateInternal( vertexFinishedStateChecker.validateOperatorsFinishedState(); } + EdgeDistributionPatternSnapshot oldEdgePatterns = null; + for (MasterState ms : latest.getMasterHookStates()) { + if (EdgeDistributionPatternSnapshot.HOOK_IDENTIFIER.equals(ms.name())) { + oldEdgePatterns = EdgeDistributionPatternSnapshot.fromBytes(ms.bytes()); + break; + } + } + StateAssignmentOperation stateAssignmentOperation = new StateAssignmentOperation( latest.getCheckpointID(), tasks, operatorStates, allowNonRestoredState, - recoverOutputOnDownstreamTask); + recoverOutputOnDownstreamTask, + oldEdgePatterns); stateAssignmentOperation.assignStates(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshot.java new file mode 100644 index 0000000000000..701cb780133e9 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshot.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.OperatorID; + +import javax.annotation.Nullable; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Snapshot of per-operator output edge {@link DistributionPattern}s, persisted via {@link + * MasterTriggerRestoreHook} / {@link MasterState} so that the correct old distribution pattern is + * available when restoring from a checkpoint after a shuffle-mode change. + * + *

Key = outputOperatorID (generated, chain-tail), Value = one {@link DistributionPattern} per + * output partition (indexed by partitionIndex). + */ +public class EdgeDistributionPatternSnapshot { + + public static final String HOOK_IDENTIFIER = "__flink_edge_distribution_patterns__"; + + private static final int FORMAT_VERSION = 1; + private static final byte PATTERN_POINTWISE = 0; + private static final byte PATTERN_ALL_TO_ALL = 1; + + private final Map outputEdgePatterns; + + public EdgeDistributionPatternSnapshot( + Map outputEdgePatterns) { + this.outputEdgePatterns = Collections.unmodifiableMap(new HashMap<>(outputEdgePatterns)); + } + + @Nullable + public DistributionPattern[] getOutputPatterns(OperatorID outputOperatorID) { + DistributionPattern[] patterns = outputEdgePatterns.get(outputOperatorID); + return patterns != null ? patterns.clone() : null; + } + + @Nullable + public DistributionPattern getOutputPattern(OperatorID outputOperatorID, int partitionIndex) { + DistributionPattern[] patterns = outputEdgePatterns.get(outputOperatorID); + if (patterns == null || partitionIndex < 0 || partitionIndex >= patterns.length) { + return null; + } + return patterns[partitionIndex]; + } + + public byte[] toBytes() throws IOException { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos)) { + dos.writeInt(FORMAT_VERSION); + dos.writeInt(outputEdgePatterns.size()); + for (Map.Entry entry : + outputEdgePatterns.entrySet()) { + OperatorID opId = entry.getKey(); + dos.writeLong(opId.getLowerPart()); + dos.writeLong(opId.getUpperPart()); + DistributionPattern[] patterns = entry.getValue(); + dos.writeInt(patterns.length); + for (DistributionPattern p : patterns) { + dos.writeByte( + p == DistributionPattern.ALL_TO_ALL + ? PATTERN_ALL_TO_ALL + : PATTERN_POINTWISE); + } + } + dos.flush(); + return baos.toByteArray(); + } + } + + public static EdgeDistributionPatternSnapshot fromBytes(byte[] bytes) throws IOException { + try (ByteArrayInputStream bais = new ByteArrayInputStream(bytes); + DataInputStream dis = new DataInputStream(bais)) { + int version = dis.readInt(); + if (version != FORMAT_VERSION) { + throw new IOException( + "Unsupported EdgeDistributionPatternSnapshot version: " + version); + } + int numOperators = dis.readInt(); + Map map = new HashMap<>(numOperators); + for (int i = 0; i < numOperators; i++) { + OperatorID opId = new OperatorID(dis.readLong(), dis.readLong()); + int numPartitions = dis.readInt(); + DistributionPattern[] patterns = new DistributionPattern[numPartitions]; + for (int j = 0; j < numPartitions; j++) { + byte b = dis.readByte(); + if (b == PATTERN_ALL_TO_ALL) { + patterns[j] = DistributionPattern.ALL_TO_ALL; + } else if (b == PATTERN_POINTWISE) { + patterns[j] = DistributionPattern.POINTWISE; + } else { + throw new IOException("Unknown DistributionPattern byte in snapshot: " + b); + } + } + map.put(opId, patterns); + } + return new EdgeDistributionPatternSnapshot(map); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java index fffc7b44305de..c7c9f4d7023c4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java @@ -17,6 +17,8 @@ package org.apache.flink.runtime.checkpoint; +import java.io.IOException; +import java.io.ObjectInputStream; import java.io.ObjectStreamException; import java.io.Serializable; import java.util.Arrays; @@ -55,6 +57,11 @@ public RescaleMappings getChannelMapping(int gateOrPartitionIndex) { return gateOrPartitionDescriptors[gateOrPartitionIndex].getRescaleMappings(); } + public InflightDataGateOrPartitionRescalingDescriptor getGateOrPartitionDescriptor( + int gateOrPartitionIndex) { + return gateOrPartitionDescriptors[gateOrPartitionIndex]; + } + public boolean isAmbiguous(int gateOrPartitionIndex, int oldSubtaskIndex) { return gateOrPartitionDescriptors[gateOrPartitionIndex].ambiguousSubtaskIndexes.contains( oldSubtaskIndex); @@ -142,7 +149,7 @@ public RescaleMappings getRescaleMappings() { /** * Set when channels are merged because the connected operator has been rescaled for each - * gate/partition. + * gate/partition. Used for ALL_TO_ALL → ALL_TO_ALL path. */ private final RescaleMappings rescaledChannelsMappings; @@ -151,6 +158,9 @@ public RescaleMappings getRescaleMappings() { private final MappingType mappingType; + /** Bundled scalar topology parameters for POINTWISE-aware rescaling. */ + private PointwiseRescaleParams rescaleParams; + /** Type of mapping which should be used for this in-flight data. */ public enum MappingType { IDENTITY, @@ -162,10 +172,25 @@ public InflightDataGateOrPartitionRescalingDescriptor( RescaleMappings rescaledChannelsMappings, Set ambiguousSubtaskIndexes, MappingType mappingType) { + this( + oldSubtaskIndexes, + rescaledChannelsMappings, + ambiguousSubtaskIndexes, + mappingType, + PointwiseRescaleParams.EMPTY); + } + + public InflightDataGateOrPartitionRescalingDescriptor( + int[] oldSubtaskIndexes, + RescaleMappings rescaledChannelsMappings, + Set ambiguousSubtaskIndexes, + MappingType mappingType, + PointwiseRescaleParams rescaleParams) { this.oldSubtaskIndexes = oldSubtaskIndexes; this.rescaledChannelsMappings = rescaledChannelsMappings; this.ambiguousSubtaskIndexes = ambiguousSubtaskIndexes; this.mappingType = mappingType; + this.rescaleParams = rescaleParams; } public int[] getOldSubtaskInstances() { @@ -180,6 +205,21 @@ public boolean isIdentity() { return mappingType == MappingType.IDENTITY; } + public PointwiseRescaleParams getPointwiseRescaleParams() { + return rescaleParams; + } + + public boolean isPointwiseRescaling() { + return !rescaleParams.equals(PointwiseRescaleParams.EMPTY); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + if (rescaleParams == null) { + rescaleParams = PointwiseRescaleParams.EMPTY; + } + } + @Override public boolean equals(Object o) { if (this == o) { @@ -193,13 +233,18 @@ public boolean equals(Object o) { return Arrays.equals(oldSubtaskIndexes, that.oldSubtaskIndexes) && Objects.equals(rescaledChannelsMappings, that.rescaledChannelsMappings) && Objects.equals(ambiguousSubtaskIndexes, that.ambiguousSubtaskIndexes) - && mappingType == that.mappingType; + && mappingType == that.mappingType + && Objects.equals(rescaleParams, that.rescaleParams); } @Override public int hashCode() { int result = - Objects.hash(rescaledChannelsMappings, ambiguousSubtaskIndexes, mappingType); + Objects.hash( + rescaledChannelsMappings, + ambiguousSubtaskIndexes, + mappingType, + rescaleParams); result = 31 * result + Arrays.hashCode(oldSubtaskIndexes); return result; } @@ -215,6 +260,8 @@ public String toString() { + ambiguousSubtaskIndexes + ", mappingType=" + mappingType + + ", rescaleParams=" + + rescaleParams + '}'; } } @@ -231,6 +278,12 @@ public int[] getOldSubtaskIndexes(int gateOrPartitionIndex) { return EMPTY_INT_ARRAY; } + @Override + public InflightDataGateOrPartitionRescalingDescriptor getGateOrPartitionDescriptor( + int gateOrPartitionIndex) { + return InflightDataGateOrPartitionRescalingDescriptor.NO_STATE; + } + @Override public RescaleMappings getChannelMapping(int gateOrPartitionIndex) { return RescaleMappings.SYMMETRIC_IDENTITY; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtils.java new file mode 100644 index 0000000000000..9a155522c8664 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtils.java @@ -0,0 +1,546 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper; +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.util.Preconditions; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.IntStream; + +/** + * Topology computation utilities for POINTWISE edge unaligned-checkpoint rescaling. Used by both JM + * (subtask assignment / descriptor generation) and TM (channel routing at recovery time). + * + *

Atomic primitives

+ * + *

All business methods are compositions of six atomic primitives: + * + *

+ * + *

Core chain

+ * + *

One chain, two directions. JM walks it forward (set-level) to decide which old + * subtasks to recover from; TM walks it backward (per-buffer) to decide where to place + * each recovered buffer. + * + *

Forward direction (JM, set-level) — "new upstream U needs state from which old + * upstreams?" + * + *

+ *   newUp ──consumersOf──→ newDown ──oldSubtasksAssignedTo──→ oldDown ──producersOf──→ oldUp
+ * 
+ * + *

Implemented by {@link #traceOutputSources}. Its inverse ({@link #resolveInputOwnership}) + * resolves the input side: for a given new downstream, which new upstream owns each old upstream's + * data. + * + *

Backward direction (TM, per-buffer) — "old buffer at (oldUp, localSP) goes to which new + * subpartition?" + * + *

+ *   (oldUp, localSP) ──{@link #localIndexToGlobalSubtaskIndex}──→ oldDown
+ *                      ──{@link #newSubtaskAssignedFrom}──→ newDown
+ *                      ──{@link #globalSubtaskIndexToLocalIndex}──→ newLocalSP
+ * 
+ * + *

The first and third steps use modulo-based mapping on the output side ({@code D % + * numConsumers}), matching the execution graph's subpartition assignment in {@link + * org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils#computeConsumedSubpartitionRange}. + * + *

Routing principle

+ * + *

At TM recovery time, only the new upstream that is connected to a target new + * downstream writes recovered buffers to it. When {@code newUpPar > newDownPar} in POINTWISE, + * multiple new upstreams connect to the same downstream and receive the same old state (ambiguous + * assignment from {@link #traceOutputSources}). To avoid duplicates, only the primary + * producer ({@code producersOf(D)[0]}) writes; others discard via {@link + * #computeNewLocalSubpartitionIndex} returning -1. + * + *

All arithmetic mirrors {@link + * org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils#computeVertexInputInfoForPointwise} + * but operates on scalar parallelisms only (no scheduler-layer objects), so it can run on TM. + */ +@Internal +public final class PointwiseChannelMappingUtils { + + private PointwiseChannelMappingUtils() {} + + // ============================== Business methods ============================== + + /** + * Output-side tracing: computes old upstreams whose state a new upstream needs to recover. + * Chains consumersOf → oldSubtasksAssignedTo → producersOf (see class javadoc). The result is + * ambiguous (multiple new upstreams may claim the same old upstream); TM filters via {@link + * #computeNewLocalSubpartitionIndex} at recovery time. + */ + public static int[] traceOutputSources( + int newUpstreamSubtaskIndex, PointwiseRescaleParams params) { + final int newUpstreamParallelism = params.getNewUpParallelism(); + final int newDownstreamParallelism = params.getNewDownParallelism(); + final int oldUpstreamParallelism = params.getOldUpParallelism(); + final int oldDownstreamParallelism = params.getOldDownParallelism(); + + int[] newDownstreams; + if (params.getNewDistributionPattern() == DistributionPattern.POINTWISE) { + newDownstreams = + consumersOf( + newUpstreamSubtaskIndex, + newUpstreamParallelism, + newDownstreamParallelism); + } else { + newDownstreams = IntStream.range(0, newDownstreamParallelism).toArray(); + } + + Set oldDownstreams = new HashSet<>(); + for (int newDownstream : newDownstreams) { + for (int oldDownstream : + oldSubtasksAssignedTo( + newDownstream, oldDownstreamParallelism, newDownstreamParallelism)) { + oldDownstreams.add(oldDownstream); + } + } + + Set oldUpstreams = new HashSet<>(); + for (int oldDownstream : oldDownstreams) { + if (params.getOldDistributionPattern() == DistributionPattern.ALL_TO_ALL) { + for (int i = 0; i < oldUpstreamParallelism; i++) { + oldUpstreams.add(i); + } + } else { + for (int oldUpstream : + producersOf( + oldDownstream, oldUpstreamParallelism, oldDownstreamParallelism)) { + oldUpstreams.add(oldUpstream); + } + } + } + + return oldUpstreams.stream().sorted().mapToInt(Integer::intValue).toArray(); + } + + /** + * Input-side ownership: for a new downstream, determines which new upstream owns each old + * upstream's data. Inverse of {@link #traceOutputSources} scoped to connected new upstreams. + * First-claim wins. Returns {@code result[oldUp] = newUp}, or -1 if unclaimed. + */ + public static int[] resolveInputOwnership( + int newDownstreamSubtaskIndex, PointwiseRescaleParams params) { + int[] connectedNewUpstreams; + if (params.getNewDistributionPattern() == DistributionPattern.POINTWISE) { + connectedNewUpstreams = + producersOf( + newDownstreamSubtaskIndex, + params.getNewUpParallelism(), + params.getNewDownParallelism()); + } else { + connectedNewUpstreams = IntStream.range(0, params.getNewUpParallelism()).toArray(); + } + + int[] mapping = new int[params.getOldUpParallelism()]; + Arrays.fill(mapping, -1); + + for (int newUpstream : connectedNewUpstreams) { + for (int oldUpstream : traceOutputSources(newUpstream, params)) { + if (mapping[oldUpstream] < 0) { + mapping[oldUpstream] = newUpstream; + } + } + } + + return mapping; + } + + /** + * Topology-aware old-to-new subtask assignment. For downstream assignment, computes upstream + * assignment first to ensure coherence. Uses range-shift ({@code old * newPar / oldPar}) as + * preferred placement, then checks topology constraint via the other-side peer set. + * + * @param isInputSide true = assigning upstream (input side), false = downstream + * @return {@code int[oldPar]} where {@code result[oldIdx] = newIdx} + */ + public static int[] computeOldToNewSubtaskAssignment( + PointwiseRescaleParams params, boolean isInputSide) { + int[] otherOldToNew = null; + if (!isInputSide) { + otherOldToNew = computeOldToNewSubtaskAssignmentInternal(params, true, null); + } + return computeOldToNewSubtaskAssignmentInternal(params, isInputSide, otherOldToNew); + } + + /** + * Old-topology local input channel count: A2A = oldUpstreamParallelism, PW = + * producersOf(subtask).length. + */ + public static int getOldLocalChannelCount( + int downstreamSubtaskIndex, PointwiseRescaleParams params) { + if (params.getOldDistributionPattern() == DistributionPattern.ALL_TO_ALL) { + return params.getOldUpParallelism(); + } + return producersOf( + downstreamSubtaskIndex, + params.getOldUpParallelism(), + params.getOldDownParallelism()) + .length; + } + + /** + * New upstream subtask index → local input channel index in the new topology. Never returns -1; + * throws on miss because input-side recovery only routes to upstreams confirmed by {@link + * #resolveInputOwnership}. + */ + public static int computeNewLocalInputChannelIndex( + int newUpstreamSubtaskIndex, + int newDownstreamSubtaskIndex, + PointwiseRescaleParams params) { + if (params.getNewDistributionPattern() == DistributionPattern.ALL_TO_ALL) { + return newUpstreamSubtaskIndex; + } + int localIndex = + globalSubtaskIndexToLocalIndex( + newDownstreamSubtaskIndex, + newUpstreamSubtaskIndex, + params.getNewUpParallelism(), + params.getNewDownParallelism(), + true, + params.getNewDistributionPattern()); + if (localIndex < 0) { + throw new IllegalStateException( + "newUpstreamSubtaskIndex=" + + newUpstreamSubtaskIndex + + " not found in producers of subtask " + + newDownstreamSubtaskIndex + + " with params " + + params); + } + return localIndex; + } + + /** + * New downstream subtask index → local subpartition index in the new topology (modulo-based). + * Returns -1 (discard) in two cases: + * + *

    + *
  1. Not connected: downstream is not in {@code consumersOf(upstream)} + *
  2. Not primary producer: when {@code upPar > downPar}, multiple upstreams connect to the + * same downstream; only {@code producersOf(D)[0]} writes, others discard to avoid + * duplicates + *
+ */ + public static int computeNewLocalSubpartitionIndex( + int newDownstreamSubtaskIndex, + int newUpstreamSubtaskIndex, + PointwiseRescaleParams params) { + if (params.getNewDistributionPattern() == DistributionPattern.ALL_TO_ALL) { + return newDownstreamSubtaskIndex; + } + int localIndex = + globalSubtaskIndexToLocalIndex( + newUpstreamSubtaskIndex, + newDownstreamSubtaskIndex, + params.getNewUpParallelism(), + params.getNewDownParallelism(), + false, + params.getNewDistributionPattern()); + if (localIndex < 0) { + return -1; + } + // Primary producer dedup: only producersOf(D)[0] writes. + int[] producers = + producersOf( + newDownstreamSubtaskIndex, + params.getNewUpParallelism(), + params.getNewDownParallelism()); + if (producers[0] != newUpstreamSubtaskIndex) { + return -1; + } + return localIndex; + } + + // ============================== Internal ============================== + + private static int[] computeOldToNewSubtaskAssignmentInternal( + PointwiseRescaleParams params, boolean isInputSide, int[] otherOldToNew) { + final int oldSelfParallelism, newSelfParallelism, oldOtherParallelism, newOtherParallelism; + if (isInputSide) { + oldSelfParallelism = params.getOldUpParallelism(); + newSelfParallelism = params.getNewUpParallelism(); + oldOtherParallelism = params.getOldDownParallelism(); + newOtherParallelism = params.getNewDownParallelism(); + } else { + oldSelfParallelism = params.getOldDownParallelism(); + newSelfParallelism = params.getNewDownParallelism(); + oldOtherParallelism = params.getOldUpParallelism(); + newOtherParallelism = params.getNewUpParallelism(); + } + int[][] newOtherPeersCache = new int[newSelfParallelism][]; + for (int i = 0; i < newSelfParallelism; i++) { + newOtherPeersCache[i] = + isInputSide + ? consumersOf(i, newSelfParallelism, newOtherParallelism) + : producersOf(i, newOtherParallelism, newSelfParallelism); + } + int[] result = new int[oldSelfParallelism]; + for (int oldSelf = 0; oldSelf < oldSelfParallelism; oldSelf++) { + int[] oldOtherPeers = + isInputSide + ? consumersOf(oldSelf, oldSelfParallelism, oldOtherParallelism) + : producersOf(oldSelf, oldOtherParallelism, oldSelfParallelism); + int chosen = -1; + if (oldOtherPeers.length > 0) { + int targetNewOther = + otherOldToNew != null + ? otherOldToNew[oldOtherPeers[0]] + : newSubtaskAssignedFrom(oldOtherPeers[0], newOtherParallelism); + int preferred = oldSelf * newSelfParallelism / oldSelfParallelism; + if (containsValue(newOtherPeersCache[preferred], targetNewOther)) { + chosen = preferred; + } else { + for (int candidate = 0; candidate < newSelfParallelism; candidate++) { + if (containsValue(newOtherPeersCache[candidate], targetNewOther)) { + chosen = candidate; + break; + } + } + } + } + if (chosen < 0) { + throw new IllegalStateException( + "No topology-compatible new subtask found for oldSelf=" + + oldSelf + + " (params=" + + params + + ", isInputSide=" + + isInputSide + + ")"); + } + result[oldSelf] = chosen; + } + return result; + } + + // ============================== Atomic primitives ============================== + + /** Topology edge: downstream subtask → upstream subtask indices that produce for it. */ + static int[] producersOf( + int downstreamIndex, int upstreamParallelism, int downstreamParallelism) { + Preconditions.checkArgument( + upstreamParallelism > 0 && downstreamParallelism > 0, + "parallelisms must be positive, got upstreamParallelism=%s, downstreamParallelism=%s", + upstreamParallelism, + downstreamParallelism); + Preconditions.checkArgument( + downstreamIndex >= 0 && downstreamIndex < downstreamParallelism, + "index out of range: downstreamIndex=%s, downstreamParallelism=%s", + downstreamIndex, + downstreamParallelism); + if (upstreamParallelism >= downstreamParallelism) { + int start = downstreamIndex * upstreamParallelism / downstreamParallelism; + int end = (downstreamIndex + 1) * upstreamParallelism / downstreamParallelism; + int[] result = new int[end - start]; + for (int i = 0; i < result.length; i++) { + result[i] = start + i; + } + return result; + } else { + for (int u = 0; u < upstreamParallelism; u++) { + int start = + (u * downstreamParallelism + upstreamParallelism - 1) / upstreamParallelism; + int end = + ((u + 1) * downstreamParallelism + upstreamParallelism - 1) + / upstreamParallelism; + if (downstreamIndex >= start && downstreamIndex < end) { + return new int[] {u}; + } + } + throw new IllegalStateException( + "No producer found for downstreamIndex=" + + downstreamIndex + + " (upstreamParallelism=" + + upstreamParallelism + + ", downstreamParallelism=" + + downstreamParallelism + + ")"); + } + } + + /** Topology edge: upstream subtask → downstream subtask indices it produces for. */ + static int[] consumersOf( + int upstreamIndex, int upstreamParallelism, int downstreamParallelism) { + Preconditions.checkArgument( + upstreamParallelism > 0 && downstreamParallelism > 0, + "parallelisms must be positive, got upstreamParallelism=%s, downstreamParallelism=%s", + upstreamParallelism, + downstreamParallelism); + Preconditions.checkArgument( + upstreamIndex >= 0 && upstreamIndex < upstreamParallelism, + "index out of range: upstreamIndex=%s, upstreamParallelism=%s", + upstreamIndex, + upstreamParallelism); + if (upstreamParallelism < downstreamParallelism) { + int start = + (upstreamIndex * downstreamParallelism + upstreamParallelism - 1) + / upstreamParallelism; + int end = + ((upstreamIndex + 1) * downstreamParallelism + upstreamParallelism - 1) + / upstreamParallelism; + int[] result = new int[end - start]; + for (int i = 0; i < result.length; i++) { + result[i] = start + i; + } + return result; + } else { + for (int d = 0; d < downstreamParallelism; d++) { + int start = d * upstreamParallelism / downstreamParallelism; + int end = (d + 1) * upstreamParallelism / downstreamParallelism; + if (upstreamIndex >= start && upstreamIndex < end) { + return new int[] {d}; + } + } + throw new IllegalStateException( + "No consumer found for upstreamIndex=" + + upstreamIndex + + " (upstreamParallelism=" + + upstreamParallelism + + ", downstreamParallelism=" + + downstreamParallelism + + ")"); + } + } + + /** + * Cross-generation: new subtask → old subtasks it inherits state from. Delegates to {@link + * SubtaskStateMapper#ROUND_ROBIN}. + */ + public static int[] oldSubtasksAssignedTo( + int newSubtaskIndex, int oldParallelism, int newParallelism) { + return SubtaskStateMapper.ROUND_ROBIN.getOldSubtasks( + newSubtaskIndex, oldParallelism, newParallelism); + } + + /** Cross-generation inverse: old subtask → new subtask that inherits it ({@code old % new}). */ + public static int newSubtaskAssignedFrom(int oldSubtaskIndex, int newParallelism) { + return oldSubtaskIndex % newParallelism; + } + + /** + * Local-to-global: translates a local channel/subpartition index to the global subtask index of + * the peer on the other side of the edge. Topology-agnostic — caller passes whichever + * parallelisms (old or new) are relevant. + * + *

Input side (downstream→upstream): {@code producersOf(subtask)[localIndex]} — sequential. + * Output side (upstream→downstream): modulo-based — finds the consumer {@code D} where + * {@code D % numConsumers == localIndex}, matching {@code + * VertexInputInfoComputationUtils.computeConsumedSubpartitionRange}. + * + *

When the distribution pattern is {@link DistributionPattern#ALL_TO_ALL}, every subtask + * connects to every peer, so the local index already IS the global subtask index. + */ + public static int localIndexToGlobalSubtaskIndex( + int subtaskIndex, + int localIndex, + int upstreamParallelism, + int downstreamParallelism, + boolean isInputSide, + DistributionPattern distributionPattern) { + if (distributionPattern == DistributionPattern.ALL_TO_ALL) { + return localIndex; + } + if (isInputSide) { + int[] producers = producersOf(subtaskIndex, upstreamParallelism, downstreamParallelism); + return producers[localIndex]; + } else { + int[] consumers = consumersOf(subtaskIndex, upstreamParallelism, downstreamParallelism); + for (int c : consumers) { + if (c % consumers.length == localIndex) { + return c; + } + } + throw new IllegalStateException( + "No consumer with modulo position " + + localIndex + + " for upstream " + + subtaskIndex + + " (consumers=" + + java.util.Arrays.toString(consumers) + + ")"); + } + } + + /** + * Global-to-local: inverse of {@link #localIndexToGlobalSubtaskIndex}. Translates a global peer + * subtask index back to the local channel/subpartition index. + * + *

When the distribution pattern is {@link DistributionPattern#ALL_TO_ALL}, the local index + * equals the peer subtask index directly. + * + * @return the local index, or -1 if not connected + */ + static int globalSubtaskIndexToLocalIndex( + int subtaskIndex, + int peerSubtaskIndex, + int upstreamParallelism, + int downstreamParallelism, + boolean isInputSide, + DistributionPattern distributionPattern) { + if (distributionPattern == DistributionPattern.ALL_TO_ALL) { + return peerSubtaskIndex; + } + if (isInputSide) { + int[] producers = producersOf(subtaskIndex, upstreamParallelism, downstreamParallelism); + for (int i = 0; i < producers.length; i++) { + if (producers[i] == peerSubtaskIndex) { + return i; + } + } + return -1; + } else { + int[] consumers = consumersOf(subtaskIndex, upstreamParallelism, downstreamParallelism); + if (!containsValue(consumers, peerSubtaskIndex)) { + return -1; + } + return peerSubtaskIndex % consumers.length; + } + } + + // ============================== Utilities ============================== + + private static boolean containsValue(int[] arr, int v) { + for (int x : arr) { + if (x == v) { + return true; + } + } + return false; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseRescaleParams.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseRescaleParams.java new file mode 100644 index 0000000000000..959ba62f0c6f6 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PointwiseRescaleParams.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.runtime.jobgraph.DistributionPattern; + +import java.io.ObjectStreamException; +import java.io.Serializable; +import java.util.Objects; + +/** Bundled scalar parameters for POINTWISE-aware rescaling on a single gate or partition. */ +public final class PointwiseRescaleParams implements Serializable { + + private static final long serialVersionUID = 1L; + + public static final PointwiseRescaleParams EMPTY = + new PointwiseRescaleParams( + DistributionPattern.ALL_TO_ALL, DistributionPattern.ALL_TO_ALL, 0, 0, 0, 0); + + private final DistributionPattern oldDistributionPattern; + private final DistributionPattern newDistributionPattern; + private final int oldUpParallelism; + private final int oldDownParallelism; + private final int newUpParallelism; + private final int newDownParallelism; + + public PointwiseRescaleParams( + DistributionPattern oldDistributionPattern, + DistributionPattern newDistributionPattern, + int oldUpParallelism, + int oldDownParallelism, + int newUpParallelism, + int newDownParallelism) { + this.oldDistributionPattern = oldDistributionPattern; + this.newDistributionPattern = newDistributionPattern; + this.oldUpParallelism = oldUpParallelism; + this.oldDownParallelism = oldDownParallelism; + this.newUpParallelism = newUpParallelism; + this.newDownParallelism = newDownParallelism; + } + + public DistributionPattern getOldDistributionPattern() { + return oldDistributionPattern; + } + + public DistributionPattern getNewDistributionPattern() { + return newDistributionPattern; + } + + public int getOldUpParallelism() { + return oldUpParallelism; + } + + public int getOldDownParallelism() { + return oldDownParallelism; + } + + public int getNewUpParallelism() { + return newUpParallelism; + } + + public int getNewDownParallelism() { + return newDownParallelism; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PointwiseRescaleParams that = (PointwiseRescaleParams) o; + return oldUpParallelism == that.oldUpParallelism + && oldDownParallelism == that.oldDownParallelism + && newUpParallelism == that.newUpParallelism + && newDownParallelism == that.newDownParallelism + && oldDistributionPattern == that.oldDistributionPattern + && newDistributionPattern == that.newDistributionPattern; + } + + @Override + public int hashCode() { + return Objects.hash( + oldDistributionPattern, + newDistributionPattern, + oldUpParallelism, + oldDownParallelism, + newUpParallelism, + newDownParallelism); + } + + private Object readResolve() throws ObjectStreamException { + if (oldUpParallelism == 0 + && oldDownParallelism == 0 + && newUpParallelism == 0 + && newDownParallelism == 0) { + return EMPTY; + } + return this; + } + + @Override + public String toString() { + return "PointwiseRescaleParams{" + + "oldPattern=" + + oldDistributionPattern + + ", newPattern=" + + newDistributionPattern + + ", oldUp=" + + oldUpParallelism + + ", oldDown=" + + oldDownParallelism + + ", newUp=" + + newUpParallelism + + ", newDown=" + + newDownParallelism + + '}'; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java index 6bc30d488c4f2..59d89cd81a8ce 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -48,6 +48,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nullable; + import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -76,6 +78,7 @@ public class StateAssignmentOperation { private final long restoreCheckpointId; private final boolean allowNonRestoredState; private final boolean recoverOutputOnDownstreamTask; + @Nullable private final EdgeDistributionPatternSnapshot oldEdgePatterns; /** The state assignments for each ExecutionJobVertex that will be filled in multiple passes. */ private final Map vertexAssignments; @@ -93,12 +96,29 @@ public StateAssignmentOperation( Map operatorStates, boolean allowNonRestoredState, boolean recoverOutputOnDownstreamTask) { + this( + restoreCheckpointId, + tasks, + operatorStates, + allowNonRestoredState, + recoverOutputOnDownstreamTask, + null); + } + + public StateAssignmentOperation( + long restoreCheckpointId, + Set tasks, + Map operatorStates, + boolean allowNonRestoredState, + boolean recoverOutputOnDownstreamTask, + @Nullable EdgeDistributionPatternSnapshot oldEdgePatterns) { this.restoreCheckpointId = restoreCheckpointId; this.tasks = Preconditions.checkNotNull(tasks); this.operatorStates = Preconditions.checkNotNull(operatorStates); this.allowNonRestoredState = allowNonRestoredState; this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask; + this.oldEdgePatterns = oldEdgePatterns; this.vertexAssignments = CollectionUtil.newHashMapWithExpectedSize(tasks.size()); } @@ -148,7 +168,8 @@ private void buildStateAssignments() { operatorStates, consumerAssignment, vertexAssignments, - recoverOutputOnDownstreamTask); + recoverOutputOnDownstreamTask, + oldEdgePatterns); vertexAssignments.put(executionJobVertex, stateAssignment); for (final IntermediateResult producedDataSet : executionJobVertex.getInputs()) { consumerAssignment.put(producedDataSet.getId(), stateAssignment); @@ -408,7 +429,8 @@ public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment) final List outputs = executionJobVertex.getJobVertex().getProducedDataSets(); - if (outputState.getParallelism() == executionJobVertex.getParallelism()) { + if (outputState.getParallelism() == executionJobVertex.getParallelism() + && !assignment.hasPointwiseEdge(assignment.getDownstreamAssignments(), false)) { assignment.resultSubpartitionStates.putAll( toInstanceMap(assignment.outputOperatorID, outputOperatorState)); return; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java index a6db5e4837f85..372ea7c7d45a3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java @@ -26,6 +26,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.IntermediateResult; import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper; +import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.OperatorInstanceID; @@ -112,6 +113,7 @@ class TaskStateAssignment { outputRescalingDescriptors = new HashMap<>(); private final boolean recoverOutputOnDownstreamTask; + @Nullable private final EdgeDistributionPatternSnapshot oldEdgePatterns; @Nullable private TaskStateAssignment[] downstreamAssignments; @Nullable private TaskStateAssignment[] upstreamAssignments; @@ -127,6 +129,22 @@ public TaskStateAssignment( Map consumerAssignment, Map vertexAssignments, boolean recoverOutputOnDownstreamTask) { + this( + executionJobVertex, + oldState, + consumerAssignment, + vertexAssignments, + recoverOutputOnDownstreamTask, + null); + } + + public TaskStateAssignment( + ExecutionJobVertex executionJobVertex, + Map oldState, + Map consumerAssignment, + Map vertexAssignments, + boolean recoverOutputOnDownstreamTask, + @Nullable EdgeDistributionPatternSnapshot oldEdgePatterns) { this.executionJobVertex = executionJobVertex; this.oldState = oldState; @@ -144,6 +162,7 @@ public TaskStateAssignment( this.consumerAssignment = checkNotNull(consumerAssignment); this.vertexAssignments = checkNotNull(vertexAssignments); this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask; + this.oldEdgePatterns = oldEdgePatterns; final int expectedNumberOfSubtasks = newParallelism * oldState.size(); subManagedOperatorState = @@ -370,7 +389,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor( // no state on input and output, especially for any aligned checkpoint if (subtaskGateOrPartitionMappings.isEmpty() - && Arrays.stream(rescaledChannelsMappings).allMatch(Objects::isNull)) { + && Arrays.stream(rescaledChannelsMappings).allMatch(Objects::isNull) + && !hasPointwiseEdge(connectedAssignments, isInput)) { return InflightDataRescalingDescriptor.NO_RESCALE; } @@ -411,6 +431,25 @@ private InflightDataRescalingDescriptor createRescalingDescriptor( } TaskStateAssignment connectedAssignment = connectedAssignments[partition]; + + DistributionPattern oldPattern = + resolveOldDistributionPattern( + isInput, partition, connectedAssignment); + DistributionPattern newPattern = + resolveNewDistributionPattern(isInput, partition); + + if (oldEdgePatterns != null + && (oldPattern == DistributionPattern.POINTWISE + || newPattern == DistributionPattern.POINTWISE)) { + return createPointwiseRescalingDescriptor( + instanceID, + partition, + isInput, + connectedAssignment, + oldPattern, + newPattern); + } + SubtasksRescaleMapping rescaleMapping = Optional.ofNullable(rescaledChannelsMappings[partition]) .orElseGet( @@ -508,9 +547,38 @@ public SubtasksRescaleMapping getOutputMapping(int partitionIndex) { .get(gateIndex) .getUpstreamSubtaskStateMapper(), "No channel rescaler found during rescaling of channel state"); - final RescaleMappings mapping = - mapper.getNewToOldSubtasksMapping( - oldState.get(outputOperatorID).getParallelism(), newParallelism); + + final RescaleMappings mapping; + if (mapper == SubtaskStateMapper.POINTWISE_UPSTREAM) { + PointwiseRescaleParams params = + buildOutputPointwiseRescaleParams(partitionIndex, downstreamAssignment); + int oldParallelism = oldState.get(outputOperatorID).getParallelism(); + boolean isIdentity = + oldParallelism == newParallelism + && params.getOldUpParallelism() == params.getNewUpParallelism() + && params.getOldDownParallelism() == params.getNewDownParallelism() + && params.getOldDistributionPattern() + == params.getNewDistributionPattern(); + if (isIdentity) { + mapping = + RescaleMappings.of( + IntStream.range(0, newParallelism).mapToObj(idx -> new int[] {idx}), + oldParallelism); + } else { + mapping = + RescaleMappings.of( + IntStream.range(0, newParallelism) + .mapToObj( + idx -> + PointwiseChannelMappingUtils + .traceOutputSources(idx, params)), + oldParallelism); + } + } else { + mapping = + mapper.getNewToOldSubtasksMapping( + oldState.get(outputOperatorID).getParallelism(), newParallelism); + } return outputSubtaskMappings.compute( partitionIndex, (idx, oldMapping) -> @@ -581,10 +649,156 @@ public boolean hasInFlightDataForResultPartition(int partitionIndex) { return false; } + boolean hasPointwiseEdge(TaskStateAssignment[] connectedAssignments, boolean isInput) { + if (oldEdgePatterns == null) { + return false; + } + for (int i = 0; i < connectedAssignments.length; i++) { + if (!hasInFlightData(isInput, i)) { + continue; + } + DistributionPattern oldPattern = + resolveOldDistributionPattern(isInput, i, connectedAssignments[i]); + DistributionPattern newPattern = resolveNewDistributionPattern(isInput, i); + if (oldPattern == DistributionPattern.POINTWISE + || newPattern == DistributionPattern.POINTWISE) { + return true; + } + } + return false; + } + + private DistributionPattern resolveOldDistributionPattern( + boolean isInput, int gateOrPartitionIdx, TaskStateAssignment connectedAssignment) { + if (oldEdgePatterns == null) { + return DistributionPattern.ALL_TO_ALL; + } + if (isInput) { + int partitionIdxInUpstream = + Arrays.asList(connectedAssignment.getDownstreamAssignments()).indexOf(this); + if (partitionIdxInUpstream < 0) { + return DistributionPattern.ALL_TO_ALL; + } + DistributionPattern pattern = + oldEdgePatterns.getOutputPattern( + connectedAssignment.outputOperatorID, partitionIdxInUpstream); + return pattern != null ? pattern : DistributionPattern.ALL_TO_ALL; + } else { + DistributionPattern pattern = + oldEdgePatterns.getOutputPattern(outputOperatorID, gateOrPartitionIdx); + return pattern != null ? pattern : DistributionPattern.ALL_TO_ALL; + } + } + + private DistributionPattern resolveNewDistributionPattern( + boolean isInput, int gateOrPartitionIdx) { + if (isInput) { + return executionJobVertex + .getInputs() + .get(gateOrPartitionIdx) + .getConsumingDistributionPattern(); + } else { + return executionJobVertex.getProducedDataSets()[gateOrPartitionIdx] + .getConsumingDistributionPattern(); + } + } + + private InflightDataGateOrPartitionRescalingDescriptor createPointwiseRescalingDescriptor( + OperatorInstanceID instanceID, + int partition, + boolean isInput, + TaskStateAssignment connectedAssignment, + DistributionPattern oldPattern, + DistributionPattern newPattern) { + final int oldUpParallelism; + final int oldDownParallelism; + final int newUpParallelism; + final int newDownParallelism; + + if (isInput) { + oldUpParallelism = + connectedAssignment + .oldState + .get(connectedAssignment.outputOperatorID) + .getParallelism(); + oldDownParallelism = oldState.get(inputOperatorID).getParallelism(); + newUpParallelism = connectedAssignment.newParallelism; + newDownParallelism = newParallelism; + } else { + oldUpParallelism = oldState.get(outputOperatorID).getParallelism(); + oldDownParallelism = + connectedAssignment + .oldState + .get(connectedAssignment.inputOperatorID) + .getParallelism(); + newUpParallelism = newParallelism; + newDownParallelism = connectedAssignment.newParallelism; + } + + PointwiseRescaleParams params = + new PointwiseRescaleParams( + oldPattern, + newPattern, + oldUpParallelism, + oldDownParallelism, + newUpParallelism, + newDownParallelism); + + int[] oldSubtaskInstances = + computePointwiseOldSubtaskInstances(instanceID.getSubtaskId(), isInput, params); + + boolean isIdentity = + oldSubtaskInstances.length == 0 + || (oldUpParallelism == newUpParallelism + && oldDownParallelism == newDownParallelism + && oldPattern == newPattern); + + return log( + new InflightDataGateOrPartitionRescalingDescriptor( + oldSubtaskInstances, + RescaleMappings.SYMMETRIC_IDENTITY, + emptySet(), + isIdentity ? MappingType.IDENTITY : MappingType.RESCALING, + params), + instanceID.getSubtaskId(), + partition); + } + + private static int[] computePointwiseOldSubtaskInstances( + int newSubtaskIdx, boolean isInput, PointwiseRescaleParams params) { + if (isInput) { + return SubtaskStateMapper.ROUND_ROBIN.getOldSubtasks( + newSubtaskIdx, params.getOldDownParallelism(), params.getNewDownParallelism()); + } else { + return PointwiseChannelMappingUtils.traceOutputSources(newSubtaskIdx, params); + } + } + + private PointwiseRescaleParams buildOutputPointwiseRescaleParams( + int partitionIndex, TaskStateAssignment downstreamAssignment) { + DistributionPattern oldPattern = + resolveOldDistributionPattern(false, partitionIndex, downstreamAssignment); + DistributionPattern newPattern = resolveNewDistributionPattern(false, partitionIndex); + int oldUpParallelism = oldState.get(outputOperatorID).getParallelism(); + int oldDownParallelism = + downstreamAssignment + .oldState + .get(downstreamAssignment.inputOperatorID) + .getParallelism(); + return new PointwiseRescaleParams( + oldPattern, + newPattern, + oldUpParallelism, + oldDownParallelism, + newParallelism, + downstreamAssignment.newParallelism); + } + void distributeOutputBuffersToDownstream() { for (Map.Entry> entry : resultSubpartitionStates.entrySet()) { OperatorInstanceID operatorInstanceID = entry.getKey(); + int newUpstreamSubtaskIndex = operatorInstanceID.getSubtaskId(); List stateHandles = entry.getValue(); ResultSubpartitionDistributor distributor = @@ -592,23 +806,49 @@ void distributeOutputBuffersToDownstream() { getOutputRescalingDescriptor(operatorInstanceID)); for (final ResultSubpartitionStateHandle stateHandle : stateHandles) { - distributeOutputBufferToDownstream(stateHandle, distributor); + ResultSubpartitionInfo info = stateHandle.getInfo(); + int partitionIdx = info.getPartitionIdx(); + TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIdx]; + + DistributionPattern oldPattern = + resolveOldDistributionPattern(false, partitionIdx, downstreamAssignment); + DistributionPattern newPattern = resolveNewDistributionPattern(false, partitionIdx); + if (oldEdgePatterns != null + && (oldPattern == DistributionPattern.POINTWISE + || newPattern == DistributionPattern.POINTWISE)) { + // PW path (aligned with createGateOrPartitionRescalingDescriptors): + // uses PointwiseChannelMappingUtils topology computation for routing. + // Dedup across new upstream subtasks is handled by + // computeNewLocalSubpartitionIndex returning -1 for non-primary producers. + distributePointwiseOutputBufferToDownstream( + stateHandle, + partitionIdx, + downstreamAssignment, + newUpstreamSubtaskIndex); + } else { + // A2A→A2A path: uses ResultSubpartitionDistributor with the descriptor's + // channel mapping (old subpartition → new subpartition). No dedup needed + // because MappingBasedRepartitioner assigns each old subtask's handles + // to exactly one new subtask. + distributeA2AOutputBufferToDownstream( + stateHandle, distributor, partitionIdx, downstreamAssignment); + } } } } - private void distributeOutputBufferToDownstream( - ResultSubpartitionStateHandle stateHandle, ResultSubpartitionDistributor distributor) { - // From the perspective of the downstream task, the oldUpstreamSubtaskIndex will be - // treated as the inputChannelIdx, and the info.getSubPartitionIdx() will be treated - // as the oldDownstreamSubtaskIndex. - int oldUpstreamSubtaskIndex = stateHandle.getSubtaskIndex(); + private void distributeA2AOutputBufferToDownstream( + ResultSubpartitionStateHandle stateHandle, + ResultSubpartitionDistributor distributor, + int partitionIdx, + TaskStateAssignment downstreamAssignment) { ResultSubpartitionInfo info = stateHandle.getInfo(); - int partitionIdx = info.getPartitionIdx(); + // A2A path: oldUpstreamSubtaskIndex is the inputChannelIdx, and + // info.getSubPartitionIdx() is the oldDownstreamSubtaskIndex. + int oldUpstreamSubtaskIndex = stateHandle.getSubtaskIndex(); int oldDownstreamSubtaskIndex = info.getSubPartitionIdx(); int gateIdxResultPartition = findInputGateIdxForResultPartition(partitionIdx); - TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIdx]; List mappedSubpartitions = distributor.getMappedSubpartitions(info); for (final ResultSubpartitionInfo mappedSubpartition : mappedSubpartitions) { @@ -636,6 +876,84 @@ private void distributeOutputBufferToDownstream( } } + private void distributePointwiseOutputBufferToDownstream( + ResultSubpartitionStateHandle stateHandle, + int partitionIdx, + TaskStateAssignment downstreamAssignment, + int newUpstreamSubtaskIndex) { + int oldUpstreamSubtaskIndex = stateHandle.getSubtaskIndex(); + ResultSubpartitionInfo info = stateHandle.getInfo(); + int subPartitionIdx = info.getSubPartitionIdx(); + + PointwiseRescaleParams params = + buildOutputPointwiseRescaleParams(partitionIdx, downstreamAssignment); + + int oldUpPar = params.getOldUpParallelism(); + int oldDownPar = params.getOldDownParallelism(); + int newDownPar = params.getNewDownParallelism(); + + DistributionPattern oldPattern = params.getOldDistributionPattern(); + + int globalOldDownstream; + try { + globalOldDownstream = + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + oldUpstreamSubtaskIndex, + subPartitionIdx, + oldUpPar, + oldDownPar, + false, + oldPattern); + } catch (IllegalStateException e) { + return; + } + + int targetDownstreamSubtaskId = + PointwiseChannelMappingUtils.newSubtaskAssignedFrom( + globalOldDownstream, newDownPar); + + int newLocalSP = + PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex( + targetDownstreamSubtaskId, newUpstreamSubtaskIndex, params); + if (newLocalSP < 0) { + return; + } + + int localChannelIdx = + PointwiseChannelMappingUtils.globalSubtaskIndexToLocalIndex( + globalOldDownstream, + oldUpstreamSubtaskIndex, + oldUpPar, + oldDownPar, + true, + oldPattern); + if (localChannelIdx < 0) { + return; + } + + int gateIdxResultPartition = findInputGateIdxForResultPartition(partitionIdx); + + OperatorInstanceID downstreamOperatorInstance = + new OperatorInstanceID( + targetDownstreamSubtaskId, downstreamAssignment.inputOperatorID); + + InputChannelInfo inputChannelInfo = + new InputChannelInfo(gateIdxResultPartition, localChannelIdx); + + InputChannelStateHandle upstreamOutputBufferHandle = + new InputChannelStateHandle( + globalOldDownstream, + inputChannelInfo, + stateHandle.getDelegate(), + stateHandle.getOffsets(), + stateHandle.getStateSize()); + + List upstreamOutputBufferHandles = + downstreamAssignment.upstreamOutputBufferStates.computeIfAbsent( + downstreamOperatorInstance, k -> new ArrayList<>()); + upstreamOutputBufferHandles.add(upstreamOutputBufferHandle); + } + private int findInputGateIdxForResultPartition(int partitionIndex) { // Check downstream input state for this partition TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIndex]; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java index b257c3b40544e..c1c2e2f50cc6e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java @@ -21,6 +21,9 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataOutputSerializer; import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.PointwiseChannelMappingUtils; +import org.apache.flink.runtime.checkpoint.PointwiseRescaleParams; import org.apache.flink.runtime.checkpoint.RescaleMappings; import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; @@ -194,31 +197,68 @@ private static GateFilterHandler createGateHandler( Map> gateVirtualChannels = new HashMap<>(); - for (int oldSubtaskIndex : oldSubtaskIndexes) { - int numChannels = gate.getNumberOfInputChannels(); - int[] oldChannelIndexes = getOldChannelIndexes(channelMapping, numChannels); - - for (int oldChannelIndex : oldChannelIndexes) { - SubtaskConnectionDescriptor key = - new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex); + InflightDataGateOrPartitionRescalingDescriptor gateDesc = + rescalingDescriptor.getGateOrPartitionDescriptor(gateIndex); + + if (gateDesc.isPointwiseRescaling()) { + PointwiseRescaleParams params = gateDesc.getPointwiseRescaleParams(); + for (int oldSubtaskIndex : oldSubtaskIndexes) { + int numLocal = + PointwiseChannelMappingUtils.getOldLocalChannelCount( + oldSubtaskIndex, params); + for (int localIdx = 0; localIdx < numLocal; localIdx++) { + int oldChannelIndex = + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + oldSubtaskIndex, + localIdx, + params.getOldUpParallelism(), + params.getOldDownParallelism(), + true, + params.getOldDistributionPattern()); + SubtaskConnectionDescriptor key = + new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex); + if (gateVirtualChannels.containsKey(key)) { + continue; + } - if (gateVirtualChannels.containsKey(key)) { - continue; + // Pointwise channel are always exactly recovered + RecordFilter recordFilter = + VirtualChannelRecordFilterFactory.createPassThroughFilter(); + RecordDeserializer> deserializer = + createDeserializer(filterContext.getTmpDirectories()); + VirtualChannel vc = new VirtualChannel<>(deserializer, recordFilter); + gateVirtualChannels.put(key, vc); } + } + } else { + for (int oldSubtaskIndex : oldSubtaskIndexes) { + int numChannels = gate.getNumberOfInputChannels(); + int[] oldChannelIndexes = getOldChannelIndexes(channelMapping, numChannels); + + for (int oldChannelIndex : oldChannelIndexes) { + SubtaskConnectionDescriptor key = + new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex); - // Only ambiguous channels need actual filtering; non-ambiguous ones pass through - boolean isAmbiguous = rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex); + if (gateVirtualChannels.containsKey(key)) { + continue; + } - RecordFilter recordFilter = - isAmbiguous - ? filterFactory.createFilter() - : VirtualChannelRecordFilterFactory.createPassThroughFilter(); + // Only ambiguous channels need actual filtering; non-ambiguous ones pass + // through + boolean isAmbiguous = + rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex); - RecordDeserializer> deserializer = - createDeserializer(filterContext.getTmpDirectories()); + RecordFilter recordFilter = + isAmbiguous + ? filterFactory.createFilter() + : VirtualChannelRecordFilterFactory.createPassThroughFilter(); - VirtualChannel vc = new VirtualChannel<>(deserializer, recordFilter); - gateVirtualChannels.put(key, vc); + RecordDeserializer> deserializer = + createDeserializer(filterContext.getTmpDirectories()); + + VirtualChannel vc = new VirtualChannel<>(deserializer, recordFilter); + gateVirtualChannels.put(key, vc); + } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java index ca01ff37bd369..2a8a858bf1cab 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java @@ -21,6 +21,8 @@ import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.PointwiseChannelMappingUtils; +import org.apache.flink.runtime.checkpoint.PointwiseRescaleParams; import org.apache.flink.runtime.checkpoint.RescaleMappings; import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; @@ -78,9 +80,11 @@ class InputChannelRecoveredStateHandler private final InputGate[] inputGates; private final InflightDataRescalingDescriptor channelMapping; + private final int subtaskIndex; private final Map rescaledChannels = new HashMap<>(); private final Map oldToNewMappings = new HashMap<>(); + private final Map pointwiseUpstreamAssignmentCache = new HashMap<>(); /** * Optional filtering handler for filtering recovered buffers. When non-null, filtering is @@ -112,10 +116,12 @@ class InputChannelRecoveredStateHandler InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping, @Nullable ChannelStateFilteringHandler filteringHandler, - int memorySegmentSize) { + int memorySegmentSize, + int subtaskIndex) { this.inputGates = inputGates; this.channelMapping = channelMapping; this.filteringHandler = filteringHandler; + this.subtaskIndex = subtaskIndex; checkArgument( memorySegmentSize > 0, "memorySegmentSize must be positive: %s", memorySegmentSize); this.memorySegmentSize = memorySegmentSize; @@ -185,18 +191,25 @@ public void recover( Buffer buffer = bufferWithContext.context; try { if (buffer.readableBytes() > 0) { - RecoveredInputChannel channel = getMappedChannels(channelInfo); - - if (filteringHandler != null) { - recoverWithFiltering( - channel, channelInfo, oldSubtaskIndex, buffer.retainBuffer()); + if (isPointwiseRescaling(channelInfo.getGateIdx())) { + recoverPointwise(channelInfo, oldSubtaskIndex, buffer); } else { - channel.onRecoveredStateBuffer( - EventSerializer.toBuffer( - new SubtaskConnectionDescriptor( - oldSubtaskIndex, channelInfo.getInputChannelIdx()), - false)); - channel.onRecoveredStateBuffer(buffer.retainBuffer()); + RecoveredInputChannel channel = getMappedChannels(channelInfo); + if (filteringHandler != null) { + recoverWithFiltering( + channel, + channelInfo.getGateIdx(), + channelInfo.getInputChannelIdx(), + oldSubtaskIndex, + buffer.retainBuffer()); + } else { + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer( + new SubtaskConnectionDescriptor( + oldSubtaskIndex, channelInfo.getInputChannelIdx()), + false)); + channel.onRecoveredStateBuffer(buffer.retainBuffer()); + } } } } finally { @@ -204,18 +217,64 @@ public void recover( } } + private void recoverPointwise(InputChannelInfo channelInfo, int oldSubtaskIndex, Buffer buffer) + throws IOException, InterruptedException { + PointwiseRescaleParams params = + channelMapping + .getGateOrPartitionDescriptor(channelInfo.getGateIdx()) + .getPointwiseRescaleParams(); + + int oldUpstreamSubtaskIndex = + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + oldSubtaskIndex, + channelInfo.getInputChannelIdx(), + params.getOldUpParallelism(), + params.getOldDownParallelism(), + true, + params.getOldDistributionPattern()); + + int[] oldToNewUpstream = + pointwiseUpstreamAssignmentCache.computeIfAbsent( + channelInfo.getGateIdx(), + idx -> + PointwiseChannelMappingUtils.resolveInputOwnership( + subtaskIndex, params)); + int newUpstreamSubtaskIndex = oldToNewUpstream[oldUpstreamSubtaskIndex]; + + int newLocalChannelIndex = + PointwiseChannelMappingUtils.computeNewLocalInputChannelIndex( + newUpstreamSubtaskIndex, subtaskIndex, params); + + if (filteringHandler == null) { + SubtaskConnectionDescriptor descriptor = + new SubtaskConnectionDescriptor(oldSubtaskIndex, oldUpstreamSubtaskIndex); + RecoveredInputChannel channel = + getChannel(channelInfo.getGateIdx(), newLocalChannelIndex); + channel.onRecoveredStateBuffer(EventSerializer.toBuffer(descriptor, false)); + channel.onRecoveredStateBuffer(buffer.retainBuffer()); + } else { + recoverWithFiltering( + getChannel(channelInfo.getGateIdx(), newLocalChannelIndex), + channelInfo.getGateIdx(), + oldUpstreamSubtaskIndex, + oldSubtaskIndex, + buffer.retainBuffer()); + } + } + private void recoverWithFiltering( RecoveredInputChannel channel, - InputChannelInfo channelInfo, + int gateIdx, + int inputChannelIdx, int oldSubtaskIndex, Buffer retainedBuffer) throws IOException, InterruptedException { checkState(filteringHandler != null, "filtering handler not set."); List filteredBuffers = filteringHandler.filterAndRewrite( - channelInfo.getGateIdx(), + gateIdx, oldSubtaskIndex, - channelInfo.getInputChannelIdx(), + inputChannelIdx, retainedBuffer, channel::requestBufferBlocking); @@ -260,6 +319,9 @@ private RecoveredInputChannel getMappedChannels(InputChannelInfo channelInfo) { @Nonnull private RecoveredInputChannel calculateMapping(InputChannelInfo info) { + if (isPointwiseRescaling(info.getGateIdx())) { + return getChannel(info.getGateIdx(), 0); + } final RescaleMappings oldToNewMapping = oldToNewMappings.computeIfAbsent( info.getGateIdx(), idx -> channelMapping.getChannelMapping(idx).invert()); @@ -270,6 +332,10 @@ private RecoveredInputChannel calculateMapping(InputChannelInfo info) { + "one buffer is expected to be processed once by the same task."); return getChannel(info.getGateIdx(), mappedIndexes[0]); } + + private boolean isPointwiseRescaling(int gateIdx) { + return channelMapping.getGateOrPartitionDescriptor(gateIdx).isPointwiseRescaling(); + } } class ResultSubpartitionRecoveredStateHandler @@ -277,19 +343,20 @@ class ResultSubpartitionRecoveredStateHandler private final ResultPartitionWriter[] writers; private final boolean notifyAndBlockOnCompletion; + private final InflightDataRescalingDescriptor channelMapping; + private final int subtaskIndex; private final ResultSubpartitionDistributor resultSubpartitionDistributor; ResultSubpartitionRecoveredStateHandler( ResultPartitionWriter[] writers, boolean notifyAndBlockOnCompletion, - InflightDataRescalingDescriptor channelMapping) { + InflightDataRescalingDescriptor channelMapping, + int subtaskIndex) { this.writers = writers; + this.channelMapping = channelMapping; + this.subtaskIndex = subtaskIndex; this.resultSubpartitionDistributor = new ResultSubpartitionDistributor(channelMapping) { - /** - * Override the getSubpartitionInfo to perform type checking on the - * ResultPartitionWriter. - */ @Override ResultSubpartitionInfo getSubpartitionInfo( int partitionIndex, int subPartitionIdx) { @@ -323,25 +390,79 @@ public void recover( if (!bufferConsumer.isDataAvailable()) { return; } - final List mappedSubpartitions = - resultSubpartitionDistributor.getMappedSubpartitions(subpartitionInfo); - CheckpointedResultPartition checkpointedResultPartition = - getCheckpointedResultPartition(subpartitionInfo.getPartitionIdx()); - for (final ResultSubpartitionInfo mappedSubpartition : mappedSubpartitions) { - // channel selector is created from the downstream's point of view: the - // subtask of downstream = subpartition index of recovered buffer - final SubtaskConnectionDescriptor channelSelector = - new SubtaskConnectionDescriptor( - subpartitionInfo.getSubPartitionIdx(), oldSubtaskIndex); - checkpointedResultPartition.addRecovered( - mappedSubpartition.getSubPartitionIdx(), - EventSerializer.toBufferConsumer(channelSelector, false)); - checkpointedResultPartition.addRecovered( - mappedSubpartition.getSubPartitionIdx(), bufferConsumer.copy()); + if (isPointwiseRescaling(subpartitionInfo.getPartitionIdx())) { + recoverPointwise(subpartitionInfo, oldSubtaskIndex, bufferConsumer); + } else { + recoverAllToAll(subpartitionInfo, oldSubtaskIndex, bufferConsumer); } } } + private void recoverAllToAll( + ResultSubpartitionInfo subpartitionInfo, + int oldSubtaskIndex, + BufferConsumer bufferConsumer) + throws IOException { + final List mappedSubpartitions = + resultSubpartitionDistributor.getMappedSubpartitions(subpartitionInfo); + CheckpointedResultPartition checkpointedResultPartition = + getCheckpointedResultPartition(subpartitionInfo.getPartitionIdx()); + for (final ResultSubpartitionInfo mappedSubpartition : mappedSubpartitions) { + final SubtaskConnectionDescriptor channelSelector = + new SubtaskConnectionDescriptor( + subpartitionInfo.getSubPartitionIdx(), oldSubtaskIndex); + checkpointedResultPartition.addRecovered( + mappedSubpartition.getSubPartitionIdx(), + EventSerializer.toBufferConsumer(channelSelector, false)); + checkpointedResultPartition.addRecovered( + mappedSubpartition.getSubPartitionIdx(), bufferConsumer.copy()); + } + } + + private void recoverPointwise( + ResultSubpartitionInfo subpartitionInfo, + int oldSubtaskIndex, + BufferConsumer bufferConsumer) + throws IOException { + PointwiseRescaleParams params = + channelMapping + .getGateOrPartitionDescriptor(subpartitionInfo.getPartitionIdx()) + .getPointwiseRescaleParams(); + + int oldDownstreamSubtaskIndex = + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + oldSubtaskIndex, + subpartitionInfo.getSubPartitionIdx(), + params.getOldUpParallelism(), + params.getOldDownParallelism(), + false, + params.getOldDistributionPattern()); + + int newDownstreamSubtaskIndex = + PointwiseChannelMappingUtils.newSubtaskAssignedFrom( + oldDownstreamSubtaskIndex, params.getNewDownParallelism()); + + int newLocalSubpartitionIndex = + PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex( + newDownstreamSubtaskIndex, subtaskIndex, params); + if (newLocalSubpartitionIndex < 0) { + return; + } + + SubtaskConnectionDescriptor channelSelector = + new SubtaskConnectionDescriptor(oldDownstreamSubtaskIndex, oldSubtaskIndex); + CheckpointedResultPartition checkpointedResultPartition = + getCheckpointedResultPartition(subpartitionInfo.getPartitionIdx()); + checkpointedResultPartition.addRecovered( + newLocalSubpartitionIndex, + EventSerializer.toBufferConsumer(channelSelector, false)); + checkpointedResultPartition.addRecovered(newLocalSubpartitionIndex, bufferConsumer.copy()); + } + + private boolean isPointwiseRescaling(int partitionIdx) { + return channelMapping.getGateOrPartitionDescriptor(partitionIdx).isPointwiseRescaling(); + } + private CheckpointedResultPartition getCheckpointedResultPartition(int partitionIndex) { ResultPartitionWriter writer = writers[partitionIndex]; if (!(writer instanceof CheckpointedResultPartition)) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java index c52572e52faec..d366d7eaf4bc0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java @@ -52,9 +52,15 @@ public class SequentialChannelStateReaderImpl implements SequentialChannelStateR private final TaskStateSnapshot taskStateSnapshot; private final ChannelStateSerializer serializer; private final ChannelStateChunkReader chunkReader; + private final int subtaskIndex; public SequentialChannelStateReaderImpl(TaskStateSnapshot taskStateSnapshot) { + this(taskStateSnapshot, 0); + } + + public SequentialChannelStateReaderImpl(TaskStateSnapshot taskStateSnapshot, int subtaskIndex) { this.taskStateSnapshot = taskStateSnapshot; + this.subtaskIndex = subtaskIndex; serializer = new ChannelStateSerializerImpl(); chunkReader = new ChannelStateChunkReader(serializer); } @@ -75,7 +81,8 @@ public void readInputData(InputGate[] inputGates, RecordFilterContext filterCont inputGates, taskStateSnapshot.getInputRescalingDescriptor(), filteringHandler, - filterContext.getMemorySegmentSize())) { + filterContext.getMemorySegmentSize(), + subtaskIndex)) { read( stateHandler, groupByDelegate( @@ -102,7 +109,8 @@ public void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlo new ResultSubpartitionRecoveredStateHandler( writers, notifyAndBlockOnCompletion, - taskStateSnapshot.getOutputRescalingDescriptor())) { + taskStateSnapshot.getOutputRescalingDescriptor(), + subtaskIndex)) { read( stateHandler, groupByDelegate( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/SubtaskStateMapper.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/SubtaskStateMapper.java index c9136cebc549b..7c8d3886d9c72 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/SubtaskStateMapper.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/SubtaskStateMapper.java @@ -173,6 +173,21 @@ public int[] getOldSubtasks( } }, + POINTWISE_UPSTREAM { + @Override + public int[] getOldSubtasks( + int newSubtaskIndex, int oldNumberOfSubtasks, int newNumberOfSubtasks) { + throw new UnsupportedOperationException( + "POINTWISE_UPSTREAM requires PointwiseRescaleParams; " + + "use PointwiseChannelMappingUtils.traceOutputSources instead."); + } + + @Override + public boolean isAmbiguous() { + return true; + } + }, + UNSUPPORTED { @Override public int[] getOldSubtasks( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManagerImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManagerImpl.java index 61765ad6bc225..22d0d464e9de9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManagerImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateManagerImpl.java @@ -113,7 +113,8 @@ public TaskStateManagerImpl( new SequentialChannelStateReaderImpl( jobManagerTaskRestore == null ? new TaskStateSnapshot() - : jobManagerTaskRestore.getTaskStateSnapshot())); + : jobManagerTaskRestore.getTaskStateSnapshot(), + executionAttemptID.getSubtaskIndex())); } public TaskStateManagerImpl( diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/EdgeDistributionPatternHook.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/EdgeDistributionPatternHook.java new file mode 100644 index 0000000000000..7186389f75f5d --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/EdgeDistributionPatternHook.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.graph; + +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.runtime.checkpoint.EdgeDistributionPatternSnapshot; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; + +/** + * A {@link MasterTriggerRestoreHook} that persists the current job's edge {@link + * org.apache.flink.runtime.jobgraph.DistributionPattern}s into checkpoint metadata, so they are + * available when restoring after a shuffle-mode change. + */ +class EdgeDistributionPatternHook implements MasterTriggerRestoreHook { + + private final byte[] snapshotBytes; + + EdgeDistributionPatternHook(byte[] snapshotBytes) { + this.snapshotBytes = snapshotBytes; + } + + @Override + public String getIdentifier() { + return EdgeDistributionPatternSnapshot.HOOK_IDENTIFIER; + } + + @Nullable + @Override + public CompletableFuture triggerCheckpoint( + long checkpointId, long timestamp, Executor executor) { + return CompletableFuture.completedFuture(snapshotBytes); + } + + @Override + public void restoreCheckpoint(long checkpointId, @Nullable byte[] checkpointData) { + // no-op: data is extracted directly from MasterState before assignStates() + } + + @Nullable + @Override + public SimpleVersionedSerializer createCheckpointDataSerializer() { + return ByteArraySerializer.INSTANCE; + } + + private static final class ByteArraySerializer implements SimpleVersionedSerializer { + static final ByteArraySerializer INSTANCE = new ByteArraySerializer(); + + @Override + public int getVersion() { + return 1; + } + + @Override + public byte[] serialize(byte[] obj) throws IOException { + return obj; + } + + @Override + public byte[] deserialize(int version, byte[] serialized) throws IOException { + return serialized; + } + } + + static class Factory implements MasterTriggerRestoreHook.Factory { + private static final long serialVersionUID = 1L; + + private final byte[] snapshotBytes; + + Factory(EdgeDistributionPatternSnapshot snapshot) throws IOException { + this.snapshotBytes = snapshot.toBytes(); + } + + @SuppressWarnings("unchecked") + @Override + public MasterTriggerRestoreHook create() { + return (MasterTriggerRestoreHook) new EdgeDistributionPatternHook(snapshotBytes); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java index e65ec1246ea18..ca034504cc06f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java @@ -269,9 +269,14 @@ public StreamGraph generate() { LineageGraph lineageGraph = LineageGraphUtils.convertToLineageGraph(transformations); streamGraph.setLineageGraph(lineageGraph); + boolean forceUnaligned = + checkpointConfig.isForceUnalignedCheckpoints() && !streamGraph.isIterative(); for (StreamNode node : streamGraph.getStreamNodes()) { if (node.getInEdges().stream() - .anyMatch(e -> !e.getPartitioner().isSupportsUnalignedCheckpoint())) { + .anyMatch( + e -> + !e.getPartitioner() + .isSupportsUnalignedCheckpoint(forceUnaligned))) { for (StreamEdge edge : node.getInEdges()) { edge.setSupportsUnalignedCheckpoints(false); } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index 66776c03fb5fd..493908e2493bb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -34,6 +34,9 @@ import org.apache.flink.configuration.IllegalConfigurationException; import org.apache.flink.core.memory.ManagedMemoryUseCase; import org.apache.flink.runtime.OperatorIDPair; +import org.apache.flink.runtime.checkpoint.EdgeDistributionPatternSnapshot; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; +import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.InputOutputFormatVertex; @@ -49,6 +52,7 @@ import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; import org.apache.flink.runtime.jobgraph.forwardgroup.JobVertexForwardGroup; import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration; +import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings; import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable; import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalPipelinedRegion; import org.apache.flink.runtime.jobgraph.topology.DefaultLogicalTopology; @@ -253,9 +257,70 @@ private JobGraph createJobGraph() { // Wait for the serialization of operator coordinators and stream config. serializeOperatorCoordinatorsAndStreamConfig(serializationExecutor, jobVertexBuildContext); + configureEdgeDistributionPatternHook(); + return jobGraph; } + @SuppressWarnings("unchecked") + private void configureEdgeDistributionPatternHook() { + if (!streamGraph.getCheckpointConfig().isUnalignedCheckpointsEnabled() + || !streamGraph.getCheckpointConfig().isForceUnalignedCheckpoints()) { + return; + } + JobCheckpointingSettings settings = jobGraph.getCheckpointingSettings(); + if (settings == null) { + return; + } + try { + Map outputPatterns = new HashMap<>(); + for (JobVertex jv : jobGraph.getVerticesSortedTopologicallyFromSources()) { + OperatorID outputOpId = jv.getOperatorIDs().get(0).getGeneratedOperatorID(); + List outputs = jv.getProducedDataSets(); + DistributionPattern[] patterns = new DistributionPattern[outputs.size()]; + for (int i = 0; i < outputs.size(); i++) { + IntermediateDataSet ds = outputs.get(i); + patterns[i] = + ds.getConsumers().isEmpty() + ? DistributionPattern.ALL_TO_ALL + : ds.getDistributionPattern(); + } + outputPatterns.put(outputOpId, patterns); + } + + MasterTriggerRestoreHook.Factory[] existingHooks; + SerializedValue serializedHooks = + settings.getMasterHooks(); + if (serializedHooks != null) { + existingHooks = + serializedHooks.deserializeValue( + Thread.currentThread().getContextClassLoader()); + } else { + existingHooks = new MasterTriggerRestoreHook.Factory[0]; + } + + MasterTriggerRestoreHook.Factory[] newHooks = + new MasterTriggerRestoreHook.Factory[existingHooks.length + 1]; + System.arraycopy(existingHooks, 0, newHooks, 0, existingHooks.length); + newHooks[existingHooks.length] = + new EdgeDistributionPatternHook.Factory( + new EdgeDistributionPatternSnapshot(outputPatterns)); + + JobCheckpointingSettings newSettings = + new JobCheckpointingSettings( + settings.getCheckpointCoordinatorConfiguration(), + settings.getDefaultStateBackend(), + settings.isChangelogStateBackendEnabled(), + settings.getDefaultCheckpointStorage(), + new SerializedValue<>(newHooks), + settings.isStateBackendUseManagedMemory()); + jobGraph.setSnapshotSettings(newSettings); + } catch (Exception e) { + throw new FlinkRuntimeException( + "Failed to configure edge distribution pattern hook", e); + } + } + public static void serializeOperatorCoordinatorsAndStreamConfig( Executor serializationExecutor, JobVertexBuildContext jobVertexBuildContext) { try { @@ -1634,8 +1699,29 @@ public static IntermediateDataSet connect( // set strategy name so that web interface can show it. jobEdge.setShipStrategyName(partitioner.toString()); - jobEdge.setDownstreamSubtaskStateMapper(partitioner.getDownstreamSubtaskStateMapper()); - jobEdge.setUpstreamSubtaskStateMapper(partitioner.getUpstreamSubtaskStateMapper()); + SubtaskStateMapper downstreamMapper = partitioner.getDownstreamSubtaskStateMapper(); + SubtaskStateMapper upstreamMapper = partitioner.getUpstreamSubtaskStateMapper(); + + if (jobVertexBuildContext + .getStreamGraph() + .getCheckpointConfig() + .isForceUnalignedCheckpoints() + && jobVertexBuildContext + .getStreamGraph() + .getCheckpointConfig() + .isUnalignedCheckpointsEnabled() + && !partitioner.isSupportsUnalignedCheckpoint(false) + && partitioner.isSupportsUnalignedCheckpoint(true)) { + if (downstreamMapper == SubtaskStateMapper.UNSUPPORTED) { + downstreamMapper = SubtaskStateMapper.ROUND_ROBIN; + } + if (upstreamMapper == SubtaskStateMapper.UNSUPPORTED) { + upstreamMapper = SubtaskStateMapper.POINTWISE_UPSTREAM; + } + } + + jobEdge.setDownstreamSubtaskStateMapper(downstreamMapper); + jobEdge.setUpstreamSubtaskStateMapper(upstreamMapper); if (LOG.isDebugEnabled()) { LOG.debug( diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java index 6f1f3bda8b0e4..0d262776efbb6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/DemultiplexingRecordDeserializer.java @@ -19,6 +19,9 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.PointwiseChannelMappingUtils; +import org.apache.flink.runtime.checkpoint.PointwiseRescaleParams; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; @@ -152,6 +155,20 @@ static DemultiplexingRecordDeserializer create( if (oldSubtaskIndexes.length == 0) { return UNMAPPED; } + + InflightDataGateOrPartitionRescalingDescriptor gateDesc = + rescalingDescriptor.getGateOrPartitionDescriptor(channelInfo.getGateIdx()); + + if (gateDesc.isPointwiseRescaling()) { + return createPointwise( + channelInfo, + rescalingDescriptor, + gateDesc, + deserializerFactory, + recordFilterFactory, + oldSubtaskIndexes); + } + final int[] oldChannelIndexes = rescalingDescriptor .getChannelMapping(channelInfo.getGateIdx()) @@ -179,6 +196,56 @@ static DemultiplexingRecordDeserializer create( return new DemultiplexingRecordDeserializer(virtualChannels); } + private static DemultiplexingRecordDeserializer createPointwise( + InputChannelInfo channelInfo, + InflightDataRescalingDescriptor rescalingDescriptor, + InflightDataGateOrPartitionRescalingDescriptor gateDesc, + Function>> + deserializerFactory, + Function> recordFilterFactory, + int[] oldSubtaskIndexes) { + PointwiseRescaleParams params = gateDesc.getPointwiseRescaleParams(); + + int totalVirtualChannels = 0; + for (int subtask : oldSubtaskIndexes) { + totalVirtualChannels += + PointwiseChannelMappingUtils.getOldLocalChannelCount(subtask, params); + } + if (totalVirtualChannels == 0) { + return UNMAPPED; + } + + Map> virtualChannels = + Maps.newHashMapWithExpectedSize(totalVirtualChannels); + for (int subtask : oldSubtaskIndexes) { + int numLocalChannels = + PointwiseChannelMappingUtils.getOldLocalChannelCount(subtask, params); + for (int localChannelIndex = 0; + localChannelIndex < numLocalChannels; + localChannelIndex++) { + int oldUpstreamSubtaskIndex = + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + subtask, + localChannelIndex, + params.getOldUpParallelism(), + params.getOldDownParallelism(), + true, + params.getOldDistributionPattern()); + SubtaskConnectionDescriptor descriptor = + new SubtaskConnectionDescriptor(subtask, oldUpstreamSubtaskIndex); + virtualChannels.put( + descriptor, + new VirtualChannel<>( + deserializerFactory.apply(totalVirtualChannels), + rescalingDescriptor.isAmbiguous(channelInfo.getGateIdx(), subtask) + ? recordFilterFactory.apply(channelInfo) + : RecordFilter.acceptAll())); + } + } + + return new DemultiplexingRecordDeserializer(virtualChannels); + } + @Override public String toString() { return "DemultiplexingRecordDeserializer{" + "channels=" + channels.keySet() + '}'; diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java index 7a07c8bfa9e00..80db01c3d6b81 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java @@ -87,7 +87,11 @@ public SubtaskStateMapper getUpstreamSubtaskStateMapper() { public abstract boolean isPointwise(); public boolean isSupportsUnalignedCheckpoint() { - return supportsUnalignedCheckpoint && !isPointwise() && !isBroadcast(); + return isSupportsUnalignedCheckpoint(false); + } + + public boolean isSupportsUnalignedCheckpoint(boolean forceUnaligned) { + return supportsUnalignedCheckpoint && (forceUnaligned || !isPointwise()) && !isBroadcast(); } public void disableUnalignedCheckpoints() { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshotTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshotTest.java new file mode 100644 index 0000000000000..4074ce8b9963f --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/EdgeDistributionPatternSnapshotTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.OperatorID; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link EdgeDistributionPatternSnapshot}. */ +class EdgeDistributionPatternSnapshotTest { + + @Test + void roundTripSerialization() throws IOException { + Map patterns = new HashMap<>(); + OperatorID op1 = new OperatorID(); + OperatorID op2 = new OperatorID(); + patterns.put(op1, new DistributionPattern[] {DistributionPattern.POINTWISE}); + patterns.put( + op2, + new DistributionPattern[] { + DistributionPattern.ALL_TO_ALL, DistributionPattern.POINTWISE + }); + + EdgeDistributionPatternSnapshot original = new EdgeDistributionPatternSnapshot(patterns); + byte[] bytes = original.toBytes(); + EdgeDistributionPatternSnapshot restored = EdgeDistributionPatternSnapshot.fromBytes(bytes); + + assertThat(restored.getOutputPatterns(op1)).containsExactly(DistributionPattern.POINTWISE); + assertThat(restored.getOutputPatterns(op2)) + .containsExactly(DistributionPattern.ALL_TO_ALL, DistributionPattern.POINTWISE); + } + + @Test + void emptySnapshot() throws IOException { + EdgeDistributionPatternSnapshot original = + new EdgeDistributionPatternSnapshot(new HashMap<>()); + byte[] bytes = original.toBytes(); + EdgeDistributionPatternSnapshot restored = EdgeDistributionPatternSnapshot.fromBytes(bytes); + + assertThat(restored.getOutputPatterns(new OperatorID())).isNull(); + } + + @Test + void getOutputPatternByIndex() throws IOException { + Map patterns = new HashMap<>(); + OperatorID op = new OperatorID(); + patterns.put( + op, + new DistributionPattern[] { + DistributionPattern.POINTWISE, + DistributionPattern.ALL_TO_ALL, + DistributionPattern.POINTWISE + }); + + EdgeDistributionPatternSnapshot snapshot = new EdgeDistributionPatternSnapshot(patterns); + + assertThat(snapshot.getOutputPattern(op, 0)).isEqualTo(DistributionPattern.POINTWISE); + assertThat(snapshot.getOutputPattern(op, 1)).isEqualTo(DistributionPattern.ALL_TO_ALL); + assertThat(snapshot.getOutputPattern(op, 2)).isEqualTo(DistributionPattern.POINTWISE); + assertThat(snapshot.getOutputPattern(op, 3)).isNull(); + assertThat(snapshot.getOutputPattern(new OperatorID(), 0)).isNull(); + } + + @Test + void unknownPatternByteRejected() throws IOException { + java.io.ByteArrayOutputStream baos = new java.io.ByteArrayOutputStream(); + java.io.DataOutputStream dos = new java.io.DataOutputStream(baos); + dos.writeInt(1); // version + dos.writeInt(1); // numOperators + dos.writeLong(0L); // opId lower + dos.writeLong(0L); // opId upper + dos.writeInt(1); // numPartitions + dos.writeByte(7); // unknown pattern byte + dos.flush(); + + assertThatThrownBy(() -> EdgeDistributionPatternSnapshot.fromBytes(baos.toByteArray())) + .isInstanceOf(IOException.class) + .hasMessageContaining("Unknown DistributionPattern byte"); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtilsTest.java new file mode 100644 index 0000000000000..11da7bd5224a8 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseChannelMappingUtilsTest.java @@ -0,0 +1,788 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.runtime.jobgraph.DistributionPattern; + +import org.junit.jupiter.api.Test; + +import java.util.HashSet; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link PointwiseChannelMappingUtils}. */ +class PointwiseChannelMappingUtilsTest { + + private static final DistributionPattern PW = DistributionPattern.POINTWISE; + private static final DistributionPattern A2A = DistributionPattern.ALL_TO_ALL; + + private static PointwiseRescaleParams params(int oldUp, int oldDown, int newUp, int newDown) { + return new PointwiseRescaleParams(PW, PW, oldUp, oldDown, newUp, newDown); + } + + private static PointwiseRescaleParams params( + DistributionPattern oldPat, + DistributionPattern newPat, + int oldUp, + int oldDown, + int newUp, + int newDown) { + return new PointwiseRescaleParams(oldPat, newPat, oldUp, oldDown, newUp, newDown); + } + + // ======================== producersOf / consumersOf ======================== + + @Test + void producersOfAndConsumersOfAreInverses() { + for (int upPar = 1; upPar <= 8; upPar++) { + for (int downPar = 1; downPar <= 8; downPar++) { + for (int d = 0; d < downPar; d++) { + int[] producers = PointwiseChannelMappingUtils.producersOf(d, upPar, downPar); + for (int p : producers) { + int[] consumers = + PointwiseChannelMappingUtils.consumersOf(p, upPar, downPar); + assertThat(consumers) + .as( + "up=%d down=%d: d=%d should be in consumersOf(%d)", + upPar, downPar, d, p) + .contains(d); + } + } + } + } + } + + @Test + void producersOfCoversAllUpstream() { + for (int upPar = 1; upPar <= 8; upPar++) { + for (int downPar = 1; downPar <= 8; downPar++) { + Set allProducers = new HashSet<>(); + for (int d = 0; d < downPar; d++) { + int[] producers = PointwiseChannelMappingUtils.producersOf(d, upPar, downPar); + for (int p : producers) { + allProducers.add(p); + } + } + assertThat(allProducers) + .as("up=%d down=%d: all upstream covered", upPar, downPar) + .hasSize(upPar); + } + } + } + + @Test + void consumersOfCoversAllDownstream() { + for (int upPar = 1; upPar <= 8; upPar++) { + for (int downPar = 1; downPar <= 8; downPar++) { + Set allConsumers = new HashSet<>(); + for (int u = 0; u < upPar; u++) { + int[] consumers = PointwiseChannelMappingUtils.consumersOf(u, upPar, downPar); + for (int c : consumers) { + allConsumers.add(c); + } + } + assertThat(allConsumers) + .as("up=%d down=%d: all downstream covered", upPar, downPar) + .hasSize(downPar); + } + } + } + + @Test + void producersOfConcreteValues() { + // 2 upstream -> 4 downstream (upscale): each downstream has one producer + assertThat(PointwiseChannelMappingUtils.producersOf(0, 2, 4)).containsExactly(0); + assertThat(PointwiseChannelMappingUtils.producersOf(1, 2, 4)).containsExactly(0); + assertThat(PointwiseChannelMappingUtils.producersOf(2, 2, 4)).containsExactly(1); + assertThat(PointwiseChannelMappingUtils.producersOf(3, 2, 4)).containsExactly(1); + + // 4 upstream -> 2 downstream (downscale): each downstream merges 2 producers + assertThat(PointwiseChannelMappingUtils.producersOf(0, 4, 2)).containsExactly(0, 1); + assertThat(PointwiseChannelMappingUtils.producersOf(1, 4, 2)).containsExactly(2, 3); + + // 3 upstream -> 3 downstream (equal): identity + assertThat(PointwiseChannelMappingUtils.producersOf(0, 3, 3)).containsExactly(0); + assertThat(PointwiseChannelMappingUtils.producersOf(1, 3, 3)).containsExactly(1); + assertThat(PointwiseChannelMappingUtils.producersOf(2, 3, 3)).containsExactly(2); + } + + @Test + void consumersOfConcreteValues() { + // 2 upstream -> 4 downstream (upscale): each upstream serves 2 consumers + assertThat(PointwiseChannelMappingUtils.consumersOf(0, 2, 4)).containsExactly(0, 1); + assertThat(PointwiseChannelMappingUtils.consumersOf(1, 2, 4)).containsExactly(2, 3); + + // 4 upstream -> 2 downstream (downscale): each upstream feeds one consumer + assertThat(PointwiseChannelMappingUtils.consumersOf(0, 4, 2)).containsExactly(0); + assertThat(PointwiseChannelMappingUtils.consumersOf(1, 4, 2)).containsExactly(0); + assertThat(PointwiseChannelMappingUtils.consumersOf(2, 4, 2)).containsExactly(1); + assertThat(PointwiseChannelMappingUtils.consumersOf(3, 4, 2)).containsExactly(1); + } + + // ======================== localIndexToGlobalSubtaskIndex ======================== + + @Test + void localIndexToGlobalSubtaskIndexPointwiseInputSide() { + // 4 upstream -> 2 downstream: d=0 has producers {0,1}, d=1 has producers {2,3} + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 0, 0, 4, 2, true, PW)) + .isEqualTo(0); + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 0, 1, 4, 2, true, PW)) + .isEqualTo(1); + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 1, 0, 4, 2, true, PW)) + .isEqualTo(2); + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 1, 1, 4, 2, true, PW)) + .isEqualTo(3); + } + + @Test + void localIndexToGlobalSubtaskIndexPointwiseOutputSide() { + // 2 upstream -> 4 downstream: u=0 has consumers {0,1}, u=1 has consumers {2,3} + // consumers start at 0 and 2 (both % 2 == 0), so modulo == sequential here. + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 0, 0, 2, 4, false, PW)) + .isEqualTo(0); + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 0, 1, 2, 4, false, PW)) + .isEqualTo(1); + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 1, 0, 2, 4, false, PW)) + .isEqualTo(2); + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 1, 1, 2, 4, false, PW)) + .isEqualTo(3); + + // 4 upstream -> 6 downstream: u=2 has consumers {3,4} + // consumers[0]=3, 3%2=1 ≠ 0, so modulo ≠ sequential: + // sp 0 → consumer 4 (4%2=0), sp 1 → consumer 3 (3%2=1) + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 2, 0, 4, 6, false, PW)) + .isEqualTo(4); + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 2, 1, 4, 6, false, PW)) + .isEqualTo(3); + } + + @Test + void localIndexToGlobalSubtaskIndexBoundsCheck() { + assertThatThrownBy( + () -> + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 0, 5, 4, 2, true, PW)) + .isInstanceOf(ArrayIndexOutOfBoundsException.class); + } + + // ======================== getOldLocalChannelCount ======================== + + @Test + void getOldLocalChannelCountAllToAll() { + PointwiseRescaleParams p = new PointwiseRescaleParams(A2A, A2A, 5, 3, 5, 3); + assertThat(PointwiseChannelMappingUtils.getOldLocalChannelCount(0, p)).isEqualTo(5); + assertThat(PointwiseChannelMappingUtils.getOldLocalChannelCount(2, p)).isEqualTo(5); + } + + @Test + void getOldLocalChannelCountPointwise() { + // 4 upstream -> 2 downstream: each downstream has 2 producers + PointwiseRescaleParams p = params(4, 2, 4, 2); + assertThat(PointwiseChannelMappingUtils.getOldLocalChannelCount(0, p)).isEqualTo(2); + assertThat(PointwiseChannelMappingUtils.getOldLocalChannelCount(1, p)).isEqualTo(2); + } + + // ======================== computeNewLocalInputChannelIndex / SubpartitionIndex === + + @Test + void computeNewLocalInputChannelIndexAllToAll() { + PointwiseRescaleParams p = new PointwiseRescaleParams(PW, A2A, 4, 2, 4, 2); + assertThat(PointwiseChannelMappingUtils.computeNewLocalInputChannelIndex(3, 0, p)) + .isEqualTo(3); + } + + @Test + void computeNewLocalInputChannelIndexPointwise() { + // 4 upstream -> 2 downstream: d=0 connects to {0,1}, d=1 connects to {2,3} + PointwiseRescaleParams p = params(4, 2, 4, 2); + assertThat(PointwiseChannelMappingUtils.computeNewLocalInputChannelIndex(0, 0, p)) + .isEqualTo(0); + assertThat(PointwiseChannelMappingUtils.computeNewLocalInputChannelIndex(1, 0, p)) + .isEqualTo(1); + assertThat(PointwiseChannelMappingUtils.computeNewLocalInputChannelIndex(2, 1, p)) + .isEqualTo(0); + } + + @Test + void computeNewLocalSubpartitionIndexAllToAll() { + PointwiseRescaleParams p = new PointwiseRescaleParams(PW, A2A, 2, 4, 2, 4); + assertThat(PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(3, 0, p)) + .isEqualTo(3); + } + + @Test + void computeNewLocalSubpartitionIndexPointwise() { + // 2 upstream -> 4 downstream: u=0 connects to {0,1}, u=1 connects to {2,3} + PointwiseRescaleParams p = params(2, 4, 2, 4); + assertThat(PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(0, 0, p)) + .isEqualTo(0); + assertThat(PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(1, 0, p)) + .isEqualTo(1); + assertThat(PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(2, 1, p)) + .isEqualTo(0); + } + + // ======================== computeOldToNewSubtaskAssignment ======================== + + @Test + void assignmentCoversAllOldSubtasks() { + for (int oldUp = 1; oldUp <= 6; oldUp++) { + for (int newUp = 1; newUp <= 6; newUp++) { + for (int oldDown = 1; oldDown <= 6; oldDown++) { + for (int newDown = 1; newDown <= 6; newDown++) { + int[] mapping = + PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment( + params(oldUp, oldDown, newUp, newDown), true); + assertThat(mapping).hasSize(oldUp); + for (int i = 0; i < oldUp; i++) { + assertThat(mapping[i]) + .as( + "up %d->%d, down %d->%d: old[%d] in range", + oldUp, newUp, oldDown, newDown, i) + .isBetween(0, newUp - 1); + } + } + } + } + } + } + + @Test + void assignmentIsIdentityWhenUnchanged() { + for (int p = 1; p <= 8; p++) { + int[] mapping = + PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment( + params(p, p, p, p), true); + for (int i = 0; i < p; i++) { + assertThat(mapping[i]).as("p=%d: identity at %d", p, i).isEqualTo(i); + } + } + } + + @Test + void assignmentDownstreamIsIdentityWhenUnchanged() { + for (int p = 1; p <= 8; p++) { + int[] mapping = + PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment( + params(p, p, p, p), false); + for (int i = 0; i < p; i++) { + assertThat(mapping[i]).as("p=%d: identity at %d", p, i).isEqualTo(i); + } + } + } + + @Test + void assignmentUpstreamRespectTopology() { + // 4 upstream -> 2 downstream, rescale to 2 upstream -> 2 downstream. + // Old u=0 has consumers {0} in old topology (4 up, 2 down). + // New topology (2 up, 2 down): new u=0 -> consumers {0}, new u=1 -> consumers {1}. + // Old u=0's target counterpart = old consumer 0 mod 2 = 0, so must go to new u=0. + int[] mapping = + PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment( + params(4, 2, 2, 2), true); + // old u=0,1 have consumer 0 in old topology -> should map to new u=0 + // old u=2,3 have consumer 1 in old topology -> should map to new u=1 + assertThat(mapping[0]).isEqualTo(0); + assertThat(mapping[1]).isEqualTo(0); + assertThat(mapping[2]).isEqualTo(1); + assertThat(mapping[3]).isEqualTo(1); + } + + @Test + void assignmentDownstreamCoherentWithUpstream() { + // Verify downstream assignment is coherent: for any assigned pair (oldD -> newD), + // the upstream assignment of oldD's producers should produce subtasks that are + // connected to newD in the new topology. + int oldUp = 4, newUp = 2, oldDown = 4, newDown = 2; + PointwiseRescaleParams p = params(oldUp, oldDown, newUp, newDown); + int[] fUp = PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment(p, true); + int[] fDown = PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment(p, false); + for (int oldD = 0; oldD < oldDown; oldD++) { + int newD = fDown[oldD]; + int[] oldProducers = PointwiseChannelMappingUtils.producersOf(oldD, oldUp, oldDown); + int[] newProducersOfNewD = + PointwiseChannelMappingUtils.producersOf(newD, newUp, newDown); + for (int oldP : oldProducers) { + int newU = fUp[oldP]; + assertThat(newProducersOfNewD) + .as( + "oldD=%d->newD=%d, oldP=%d->newU=%d should connect", + oldD, newD, oldP, newU) + .contains(newU); + } + } + } + + @Test + void assignmentScaleUp() { + // 2 upstream -> 4 upstream, downstream unchanged at 4 + int[] mapping = + PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment( + params(2, 4, 4, 4), true); + assertThat(mapping).hasSize(2); + for (int v : mapping) { + assertThat(v).isBetween(0, 3); + } + } + + @Test + void assignmentScaleDown() { + // 6 upstream -> 2 upstream, downstream 6 -> 2 + int[] mapping = + PointwiseChannelMappingUtils.computeOldToNewSubtaskAssignment( + params(6, 6, 2, 2), true); + assertThat(mapping).hasSize(6); + for (int v : mapping) { + assertThat(v).isBetween(0, 1); + } + } + + // =============== oldSubtasksAssignedTo / newSubtaskAssignedFrom =============== + + @Test + void oldSubtasksAssignedToAndNewSubtaskAssignedFromAreConsistent() { + for (int oldPar = 1; oldPar <= 8; oldPar++) { + for (int newPar = 1; newPar <= 8; newPar++) { + for (int newIdx = 0; newIdx < newPar; newIdx++) { + int[] oldSubtasks = + PointwiseChannelMappingUtils.oldSubtasksAssignedTo( + newIdx, oldPar, newPar); + for (int oldIdx : oldSubtasks) { + assertThat( + PointwiseChannelMappingUtils.newSubtaskAssignedFrom( + oldIdx, newPar)) + .as( + "old=%d,new=%d: newSubtaskAssignedFrom(%d) == %d", + oldPar, newPar, oldIdx, newIdx) + .isEqualTo(newIdx); + } + } + } + } + } + + @Test + void oldSubtasksAssignedToCoversAllOld() { + for (int oldPar = 1; oldPar <= 8; oldPar++) { + for (int newPar = 1; newPar <= 8; newPar++) { + Set covered = new HashSet<>(); + for (int newIdx = 0; newIdx < newPar; newIdx++) { + for (int oldIdx : + PointwiseChannelMappingUtils.oldSubtasksAssignedTo( + newIdx, oldPar, newPar)) { + covered.add(oldIdx); + } + } + assertThat(covered) + .as("old=%d new=%d: all old subtasks assigned", oldPar, newPar) + .hasSize(oldPar); + } + } + } + + @Test + void newSubtaskAssignedFromIsModulo() { + assertThat(PointwiseChannelMappingUtils.newSubtaskAssignedFrom(5, 3)).isEqualTo(2); + assertThat(PointwiseChannelMappingUtils.newSubtaskAssignedFrom(0, 4)).isEqualTo(0); + assertThat(PointwiseChannelMappingUtils.newSubtaskAssignedFrom(7, 3)).isEqualTo(1); + } + + // ======================== traceOutputSources ======================== + + @Test + void traceOutputSourcesPwToPw() { + // 3 up -> 3 down (PW), rescale to 4 up -> 3 down (PW) + // newUp=0: consumers in new topo = consumersOf(0,4,3) = {0} + // newDown=0: oldSubtasksAssignedTo(0,3,3) = {0} + // oldDown=0: producersOf(0,3,3) = {0} + // result = {0} + PointwiseRescaleParams p = params(3, 3, 4, 3); + assertThat(PointwiseChannelMappingUtils.traceOutputSources(0, p)).containsExactly(0); + + // newUp=3: consumers in new topo = consumersOf(3,4,3) = {2} + // newDown=2: oldSubtasksAssignedTo(2,3,3) = {2} + // oldDown=2: producersOf(2,3,3) = {2} + // result = {2} + assertThat(PointwiseChannelMappingUtils.traceOutputSources(3, p)).containsExactly(2); + } + + @Test + void traceOutputSourcesA2aToPw() { + // Old A2A (3 up, 3 down) -> New PW (4 up, 3 down) + // Any old downstream connects to all 3 old upstreams (A2A), so result = {0,1,2} + PointwiseRescaleParams p = params(A2A, PW, 3, 3, 4, 3); + for (int newUp = 0; newUp < 4; newUp++) { + assertThat(PointwiseChannelMappingUtils.traceOutputSources(newUp, p)) + .as("newUp=%d", newUp) + .containsExactly(0, 1, 2); + } + } + + @Test + void traceOutputSourcesPwToA2a() { + // Old PW (3 up, 3 down) -> New A2A (4 up, 3 down) + // New A2A: every new upstream serves ALL new downstreams {0,1,2} + // oldSubtasksAssignedTo covers all 3 old downstreams → all old producers reached + PointwiseRescaleParams p = params(PW, A2A, 3, 3, 4, 3); + for (int newUp = 0; newUp < 4; newUp++) { + assertThat(PointwiseChannelMappingUtils.traceOutputSources(newUp, p)) + .as("newUp=%d", newUp) + .containsExactly(0, 1, 2); + } + } + + @Test + void traceOutputSourcesCoversAllOldUpstreams() { + for (int oldUp = 1; oldUp <= 6; oldUp++) { + for (int oldDown = 1; oldDown <= 6; oldDown++) { + for (int newUp = 1; newUp <= 6; newUp++) { + for (int newDown = 1; newDown <= 6; newDown++) { + PointwiseRescaleParams p = params(oldUp, oldDown, newUp, newDown); + Set allClaimed = new HashSet<>(); + for (int u = 0; u < newUp; u++) { + for (int oldU : PointwiseChannelMappingUtils.traceOutputSources(u, p)) { + allClaimed.add(oldU); + } + } + assertThat(allClaimed) + .as( + "PW %d:%d->%d:%d: all old upstreams covered", + oldUp, oldDown, newUp, newDown) + .hasSize(oldUp); + } + } + } + } + } + + // ======================== resolveInputOwnership ======================== + + @Test + void resolveInputOwnershipPwToPw() { + // 3 up -> 3 down (PW), rescale to 4 up -> 3 down (PW) + // newDown=0: connected new upstreams = producersOf(0,4,3) = {0} + // newUp=0: traceOutputSources = {0} → mapping[0] = 0 + PointwiseRescaleParams p = params(3, 3, 4, 3); + int[] ownership = PointwiseChannelMappingUtils.resolveInputOwnership(0, p); + assertThat(ownership).hasSize(3); + assertThat(ownership[0]).isEqualTo(0); + } + + @Test + void resolveInputOwnershipNoUnclaimedForConnectedOldUpstreams() { + for (int oldUp = 1; oldUp <= 6; oldUp++) { + for (int oldDown = 1; oldDown <= 6; oldDown++) { + for (int newUp = 1; newUp <= 6; newUp++) { + for (int newDown = 1; newDown <= 6; newDown++) { + PointwiseRescaleParams p = params(oldUp, oldDown, newUp, newDown); + for (int newD = 0; newD < newDown; newD++) { + int[] ownership = + PointwiseChannelMappingUtils.resolveInputOwnership(newD, p); + // Every old upstream that was connected to any old downstream + // that maps to this new downstream must be claimed. + int[] oldDownstreams = + PointwiseChannelMappingUtils.oldSubtasksAssignedTo( + newD, oldDown, newDown); + for (int oldD : oldDownstreams) { + int[] oldProducers = + PointwiseChannelMappingUtils.producersOf( + oldD, oldUp, oldDown); + for (int oldU : oldProducers) { + assertThat(ownership[oldU]) + .as( + "%d:%d->%d:%d newD=%d oldD=%d oldU=%d must be claimed", + oldUp, oldDown, newUp, newDown, newD, oldD, + oldU) + .isGreaterThanOrEqualTo(0); + } + } + } + } + } + } + } + } + + @Test + void resolveInputOwnershipOwnerIsConnected() { + for (int oldUp = 1; oldUp <= 6; oldUp++) { + for (int oldDown = 1; oldDown <= 6; oldDown++) { + for (int newUp = 1; newUp <= 6; newUp++) { + for (int newDown = 1; newDown <= 6; newDown++) { + PointwiseRescaleParams p = params(oldUp, oldDown, newUp, newDown); + for (int newD = 0; newD < newDown; newD++) { + int[] ownership = + PointwiseChannelMappingUtils.resolveInputOwnership(newD, p); + int[] connectedNewUp = + PointwiseChannelMappingUtils.producersOf(newD, newUp, newDown); + for (int oldU = 0; oldU < oldUp; oldU++) { + if (ownership[oldU] >= 0) { + assertThat(connectedNewUp) + .as( + "%d:%d->%d:%d newD=%d: owner of oldU=%d should be connected", + oldUp, oldDown, newUp, newDown, newD, oldU) + .contains(ownership[oldU]); + } + } + } + } + } + } + } + } + + // ======================== Primary producer dedup ======================== + + @Test + void primaryProducerDedupExactlyOneWriterPerDownstream() { + // When newUpPar > newDownPar, multiple upstreams connect to the same downstream. + // computeNewLocalSubpartitionIndex must return non-negative for exactly one. + int[][] cases = {{4, 2}, {10, 3}, {11, 3}, {6, 2}, {7, 1}}; + for (int[] c : cases) { + int newUp = c[0], newDown = c[1]; + PointwiseRescaleParams p = params(newUp, newDown, newUp, newDown); + for (int d = 0; d < newDown; d++) { + int writerCount = 0; + int writerIndex = -1; + for (int u = 0; u < newUp; u++) { + int sub = + PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(d, u, p); + if (sub >= 0) { + writerCount++; + writerIndex = u; + } + } + assertThat(writerCount) + .as("%d:%d downstream %d: exactly one writer", newUp, newDown, d) + .isEqualTo(1); + // The writer must be producersOf(d)[0] (primary producer) + int[] producers = PointwiseChannelMappingUtils.producersOf(d, newUp, newDown); + assertThat(writerIndex) + .as("%d:%d downstream %d: writer is primary producer", newUp, newDown, d) + .isEqualTo(producers[0]); + } + } + } + + // ======================== Modulo-based subpartition/channel mapping ======================== + + /** + * The execution graph assigns subpartitions via {@code consumerSubtaskIndex % numConsumers} + * (see {@code VertexInputInfoComputationUtils.computeConsumedSubpartitionRange}). When + * consumers don't start at a multiple of numConsumers, this differs from sequential indexing + * ({@code D - consumers[0]}). Both {@code localIndexToGlobalSubtaskIndex} (output side) and + * {@code computeNewLocalSubpartitionIndex} must use modulo, not sequential offset. + * + *

Concrete example: 4 upstream → 6 downstream POINTWISE. Upstream 2 has consumers [3,4]. + * Sequential would map sp0→D3, sp1→D4. But the execution graph assigns D3→sp(3%2=1), + * D4→sp(4%2=0). So sp0→D4, sp1→D3. Using sequential causes buffers to be routed to the wrong + * downstream, leading to "Cannot select SubtaskConnectionDescriptor" errors at the DEMUX layer. + */ + @Test + void moduloMappingDiffersFromSequentialWhenOffsetNotAligned() { + // 4 up -> 6 down: upstream 2 has consumers [3,4] + int[] consumers = PointwiseChannelMappingUtils.consumersOf(2, 4, 6); + assertThat(consumers).containsExactly(3, 4); + + // localIndexToGlobalSubtaskIndex output side: sp → consumer via modulo + // sp 0 → consumer where c%2==0 → 4 (NOT sequential consumers[0]=3) + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 2, 0, 4, 6, false, PW)) + .isEqualTo(4); + // sp 1 → consumer where c%2==1 → 3 (NOT sequential consumers[1]=4) + assertThat( + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + 2, 1, 4, 6, false, PW)) + .isEqualTo(3); + + // computeNewLocalSubpartitionIndex: consumer → sp via modulo + PointwiseRescaleParams p = params(4, 6, 4, 6); + // D=3 → sp 3%2=1 (NOT sequential 3-3=0) + assertThat(PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(3, 2, p)) + .isEqualTo(1); + // D=4 → sp 4%2=0 (NOT sequential 4-3=1) + assertThat(PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(4, 2, p)) + .isEqualTo(0); + } + + @Test + void subpartitionIndexModuloForNonZeroOffset() { + // Subpartition assignment uses D % numConsumers (matching the execution graph), + // NOT D - consumers[0] (sequential offset). This distinction matters when + // consumers[0] % consumers.length != 0. + + // 3 up -> 10 down (PW): upstream 1 has consumers [4,5,6] + // D=4 → sp 4%3=1, D=5 → sp 5%3=2, D=6 → sp 6%3=0 + PointwiseRescaleParams p = params(3, 10, 3, 10); + int[] consumers = PointwiseChannelMappingUtils.consumersOf(1, 3, 10); + assertThat(consumers).startsWith(4); // sanity check: non-zero offset + for (int c : consumers) { + assertThat(PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(c, 1, p)) + .as("3:10 upstream 1, D=%d → sp %d", c, c % consumers.length) + .isEqualTo(c % consumers.length); + } + + // 2 up -> 5 down (PW): upstream 1 has consumers [3,4] + // D=3 → sp 3%2=1, D=4 → sp 4%2=0 + PointwiseRescaleParams p2 = params(2, 5, 2, 5); + int[] consumers2 = PointwiseChannelMappingUtils.consumersOf(1, 2, 5); + assertThat(consumers2).startsWith(3); + for (int c : consumers2) { + assertThat(PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex(c, 1, p2)) + .as("2:5 upstream 1, D=%d → sp %d", c, c % consumers2.length) + .isEqualTo(c % consumers2.length); + } + } + + @Test + void subpartitionIndexIsModuloForAllParallelisms() { + // Exhaustively verify D % numConsumers for all upPar x downPar up to 8. + for (int upPar = 1; upPar <= 8; upPar++) { + for (int downPar = 1; downPar <= 8; downPar++) { + PointwiseRescaleParams p = params(upPar, downPar, upPar, downPar); + for (int u = 0; u < upPar; u++) { + int[] consumers = PointwiseChannelMappingUtils.consumersOf(u, upPar, downPar); + for (int c : consumers) { + int sp = + PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex( + c, u, p); + int[] producers = + PointwiseChannelMappingUtils.producersOf(c, upPar, downPar); + if (producers[0] == u) { + assertThat(sp) + .as( + "up=%d down=%d u=%d D=%d: sp should be D %% numConsumers", + upPar, downPar, u, c) + .isEqualTo(c % consumers.length); + } + } + } + } + } + } + + @Test + void outputSideLocalIndexModuloForAllParallelisms() { + // Exhaustively verify modulo mapping on the output side of localIndexToGlobalSubtaskIndex. + for (int upPar = 1; upPar <= 8; upPar++) { + for (int downPar = 1; downPar <= 8; downPar++) { + for (int u = 0; u < upPar; u++) { + int[] consumers = PointwiseChannelMappingUtils.consumersOf(u, upPar, downPar); + for (int sp = 0; sp < consumers.length; sp++) { + int globalD = + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + u, sp, upPar, downPar, false, PW); + assertThat(globalD % consumers.length) + .as( + "up=%d down=%d u=%d sp=%d: D %% numConsumers should equal sp", + upPar, downPar, u, sp) + .isEqualTo(sp); + } + } + } + } + } + + // =============== Round-trip: localIndexToGlobalSubtaskIndex ↔ subpartition index + // =============== + + @Test + void localIndexToGlobalSubtaskIndexAndSubpartitionIndexAreInverses() { + for (int upPar = 1; upPar <= 6; upPar++) { + for (int downPar = 1; downPar <= 6; downPar++) { + PointwiseRescaleParams p = params(upPar, downPar, upPar, downPar); + for (int u = 0; u < upPar; u++) { + int[] consumers = PointwiseChannelMappingUtils.consumersOf(u, upPar, downPar); + // Only the primary producer can round-trip (others get -1) + for (int localIndex = 0; localIndex < consumers.length; localIndex++) { + int globalD = + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + u, localIndex, upPar, downPar, false, PW); + // Output side uses modulo mapping (D % numConsumers == localIndex), + // not sequential consumers[localIndex]. + assertThat(consumers).contains(globalD); + assertThat(globalD % consumers.length) + .as("up=%d down=%d u=%d local=%d", upPar, downPar, u, localIndex) + .isEqualTo(localIndex); + int[] producers = + PointwiseChannelMappingUtils.producersOf(globalD, upPar, downPar); + if (producers[0] == u) { + int backLocal = + PointwiseChannelMappingUtils.computeNewLocalSubpartitionIndex( + globalD, u, p); + assertThat(backLocal) + .as( + "up=%d down=%d u=%d: round-trip localIndex=%d", + upPar, downPar, u, localIndex) + .isEqualTo(localIndex); + } + } + } + } + } + } + + @Test + void localIndexToGlobalSubtaskIndexAndInputChannelIndexAreInverses() { + for (int upPar = 1; upPar <= 6; upPar++) { + for (int downPar = 1; downPar <= 6; downPar++) { + PointwiseRescaleParams p = params(upPar, downPar, upPar, downPar); + for (int d = 0; d < downPar; d++) { + int[] producers = PointwiseChannelMappingUtils.producersOf(d, upPar, downPar); + for (int localIndex = 0; localIndex < producers.length; localIndex++) { + int globalU = + PointwiseChannelMappingUtils.localIndexToGlobalSubtaskIndex( + d, localIndex, upPar, downPar, true, PW); + assertThat(globalU).isEqualTo(producers[localIndex]); + int backLocal = + PointwiseChannelMappingUtils.computeNewLocalInputChannelIndex( + globalU, d, p); + assertThat(backLocal) + .as( + "up=%d down=%d d=%d: round-trip localIndex=%d", + upPar, downPar, d, localIndex) + .isEqualTo(localIndex); + } + } + } + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseRescalingDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseRescalingDescriptorTest.java new file mode 100644 index 0000000000000..5bc8fad43f355 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PointwiseRescalingDescriptorTest.java @@ -0,0 +1,632 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.runtime.JobException; +import org.apache.flink.runtime.OperatorIDPair; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor; +import org.apache.flink.runtime.client.JobExecutionException; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.JobEdge; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobGraphTestUtils; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.testtasks.NoOpInvokable; +import org.apache.flink.runtime.util.JobVertexConnectionUtils; +import org.apache.flink.testutils.TestingUtils; +import org.apache.flink.testutils.executor.TestExecutorExtension; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ScheduledExecutorService; +import java.util.stream.Collectors; + +import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewInputChannelStateHandle; +import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewResultSubpartitionStateHandle; +import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ROUND_ROBIN; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for POINTWISE edge rescaling descriptor generation in {@link + * StateAssignmentOperation} and {@link TaskStateAssignment}. + */ +class PointwiseRescalingDescriptorTest { + + @RegisterExtension + private static final TestExecutorExtension EXECUTOR_EXTENSION = + TestingUtils.defaultExecutorExtension(); + + private static final int MAX_P = 256; + + // ===== Test 1: POINTWISE same parallelism => IDENTITY ===== + @Test + void testPointwiseSameParallelismIsIdentity() throws Exception { + OperatorID sourceId = new OperatorID(); + OperatorID sinkId = new OperatorID(); + int oldPar = 3; + int newPar = 3; + + JobVertex source = createJobVertex(sourceId, newPar); + JobVertex sink = createJobVertex(sinkId, newPar); + connectPointwise(source, sink); + + Map vertices = toExecutionVertices(source, sink); + Map states = + buildStatesWithChannelState(sourceId, sinkId, oldPar); + + EdgeDistributionPatternSnapshot oldEdgePatterns = + buildEdgePatterns(sourceId, DistributionPattern.POINTWISE); + + new StateAssignmentOperation( + 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns) + .assignStates(); + + // Same parallelism + same pattern => IDENTITY (NO_RESCALE) + for (int subtask = 0; subtask < newPar; subtask++) { + OperatorSubtaskState sinkState = + getAssignedState(vertices.get(sinkId), sinkId, subtask); + assertThat(sinkState.getInputRescalingDescriptor()) + .isEqualTo(InflightDataRescalingDescriptor.NO_RESCALE); + + OperatorSubtaskState sourceState = + getAssignedState(vertices.get(sourceId), sourceId, subtask); + assertThat(sourceState.getOutputRescalingDescriptor()) + .isEqualTo(InflightDataRescalingDescriptor.NO_RESCALE); + } + } + + // ===== Test 2: POINTWISE scale up - subtasks with old state get descriptors ===== + @Test + void testPointwiseScaleUp() throws Exception { + // old: source(2) --PW--> sink(2), new: source(4) --PW--> sink(4) + OperatorID sourceId = new OperatorID(); + OperatorID sinkId = new OperatorID(); + int oldPar = 2; + int newPar = 4; + + JobVertex source = createJobVertex(sourceId, newPar); + JobVertex sink = createJobVertex(sinkId, newPar); + connectPointwise(source, sink); + + Map vertices = toExecutionVertices(source, sink); + Map states = + buildStatesWithChannelState(sourceId, sinkId, oldPar); + + EdgeDistributionPatternSnapshot oldEdgePatterns = + buildEdgePatterns(sourceId, DistributionPattern.POINTWISE); + + new StateAssignmentOperation( + 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns) + .assignStates(); + + // Subtasks that absorb old state should get POINTWISE descriptors (newUpPar > 0) + int subtasksWithDescriptor = 0; + for (int subtask = 0; subtask < newPar; subtask++) { + OperatorSubtaskState sinkState = + getAssignedState(vertices.get(sinkId), sinkId, subtask); + InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor(); + + if (!inputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) { + subtasksWithDescriptor++; + InflightDataGateOrPartitionRescalingDescriptor gateDesc = + inputDesc.getGateOrPartitionDescriptor(0); + assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism()) + .isEqualTo(newPar); + assertThat(gateDesc.getPointwiseRescaleParams().getNewDownParallelism()) + .isEqualTo(newPar); + assertThat(gateDesc.getPointwiseRescaleParams().getOldUpParallelism()) + .isEqualTo(oldPar); + assertThat(gateDesc.getPointwiseRescaleParams().getOldDownParallelism()) + .isEqualTo(oldPar); + assertThat(gateDesc.getPointwiseRescaleParams().getOldDistributionPattern()) + .isEqualTo(DistributionPattern.POINTWISE); + assertThat(gateDesc.isIdentity()).isFalse(); + + int[] oldSubtaskIndexes = inputDesc.getOldSubtaskIndexes(0); + assertThat(oldSubtaskIndexes.length).isGreaterThan(0); + } + } + // At least the subtasks that got old state must have descriptors + assertThat(subtasksWithDescriptor).isGreaterThan(0); + + // Verify output side similarly + int sourceSubtasksWithDescriptor = 0; + for (int subtask = 0; subtask < newPar; subtask++) { + OperatorSubtaskState sourceState = + getAssignedState(vertices.get(sourceId), sourceId, subtask); + InflightDataRescalingDescriptor outputDesc = sourceState.getOutputRescalingDescriptor(); + + if (!outputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) { + sourceSubtasksWithDescriptor++; + InflightDataGateOrPartitionRescalingDescriptor partDesc = + outputDesc.getGateOrPartitionDescriptor(0); + assertThat(partDesc.getPointwiseRescaleParams().getNewUpParallelism()) + .isEqualTo(newPar); + assertThat(partDesc.getPointwiseRescaleParams().getNewDownParallelism()) + .isEqualTo(newPar); + assertThat(partDesc.getPointwiseRescaleParams().getOldUpParallelism()) + .isEqualTo(oldPar); + assertThat(partDesc.getPointwiseRescaleParams().getOldDownParallelism()) + .isEqualTo(oldPar); + assertThat(partDesc.getPointwiseRescaleParams().getOldDistributionPattern()) + .isEqualTo(DistributionPattern.POINTWISE); + } + } + assertThat(sourceSubtasksWithDescriptor).isGreaterThan(0); + } + + // ===== Test 3: POINTWISE scale down - all old subtasks covered ===== + @Test + void testPointwiseScaleDown() throws Exception { + // old: source(4) --PW--> sink(4), new: source(2) --PW--> sink(2) + OperatorID sourceId = new OperatorID(); + OperatorID sinkId = new OperatorID(); + int oldPar = 4; + int newPar = 2; + + JobVertex source = createJobVertex(sourceId, newPar); + JobVertex sink = createJobVertex(sinkId, newPar); + connectPointwise(source, sink); + + Map vertices = toExecutionVertices(source, sink); + Map states = + buildStatesWithChannelState(sourceId, sinkId, oldPar); + + EdgeDistributionPatternSnapshot oldEdgePatterns = + buildEdgePatterns(sourceId, DistributionPattern.POINTWISE); + + new StateAssignmentOperation( + 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns) + .assignStates(); + + // Scale down: every new subtask absorbs state from multiple old subtasks + boolean[] covered = new boolean[oldPar]; + for (int subtask = 0; subtask < newPar; subtask++) { + OperatorSubtaskState sinkState = + getAssignedState(vertices.get(sinkId), sinkId, subtask); + InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor(); + assertThat(inputDesc) + .as("subtask %d should have rescaling descriptor", subtask) + .isNotEqualTo(InflightDataRescalingDescriptor.NO_RESCALE); + + InflightDataGateOrPartitionRescalingDescriptor gateDesc = + inputDesc.getGateOrPartitionDescriptor(0); + assertThat(gateDesc.getPointwiseRescaleParams().getOldUpParallelism()) + .isEqualTo(oldPar); + assertThat(gateDesc.getPointwiseRescaleParams().getOldDownParallelism()) + .isEqualTo(oldPar); + assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism()) + .isEqualTo(newPar); + assertThat(gateDesc.getPointwiseRescaleParams().getNewDownParallelism()) + .isEqualTo(newPar); + assertThat(gateDesc.getPointwiseRescaleParams().getOldDistributionPattern()) + .isEqualTo(DistributionPattern.POINTWISE); + + int[] oldSubtaskIndexes = inputDesc.getOldSubtaskIndexes(0); + assertThat(oldSubtaskIndexes.length).isGreaterThan(0); + for (int oldIdx : oldSubtaskIndexes) { + covered[oldIdx] = true; + } + } + // Every old subtask must be assigned to some new subtask + for (int i = 0; i < oldPar; i++) { + assertThat(covered[i]).as("Old subtask %d should be covered", i).isTrue(); + } + } + + // ===== Test 4: A2A -> POINTWISE migration ===== + @Test + void testAllToAllToPointwiseMigration() throws Exception { + // old: source(3) --A2A--> sink(3), new: source(3) --PW--> sink(3) + OperatorID sourceId = new OperatorID(); + OperatorID sinkId = new OperatorID(); + int oldPar = 3; + int newPar = 3; + + JobVertex source = createJobVertex(sourceId, newPar); + JobVertex sink = createJobVertex(sinkId, newPar); + connectPointwise(source, sink); + + Map vertices = toExecutionVertices(source, sink); + Map states = + buildStatesWithChannelState(sourceId, sinkId, oldPar); + + // Old pattern was A2A + EdgeDistributionPatternSnapshot oldEdgePatterns = + buildEdgePatterns(sourceId, DistributionPattern.ALL_TO_ALL); + + new StateAssignmentOperation( + 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns) + .assignStates(); + + // A2A->PW same-par: new sink k reads from old sink k only (ROUND_ROBIN 1:1). + // TM-side recoverPointwise then routes each old upstream's buffer to the correct new + // upstream channel via resolveInputOwnership, so reading the extra old sinks would + // cause each inflight record to be delivered to all new subtasks (3x duplicates). + for (int subtask = 0; subtask < newPar; subtask++) { + OperatorSubtaskState sinkState = + getAssignedState(vertices.get(sinkId), sinkId, subtask); + InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor(); + assertThat(inputDesc) + .as("subtask %d", subtask) + .isNotEqualTo(InflightDataRescalingDescriptor.NO_RESCALE); + + int[] oldSubtaskIndexes = inputDesc.getOldSubtaskIndexes(0); + // ROUND_ROBIN 1:1 for same parallelism: new sink k -> old sink k + assertThat(oldSubtaskIndexes).isEqualTo(new int[] {subtask}); + + InflightDataGateOrPartitionRescalingDescriptor gateDesc = + inputDesc.getGateOrPartitionDescriptor(0); + assertThat(gateDesc.getPointwiseRescaleParams().getOldDistributionPattern()) + .isEqualTo(DistributionPattern.ALL_TO_ALL); + assertThat(gateDesc.getPointwiseRescaleParams().getNewDistributionPattern()) + .isEqualTo(DistributionPattern.POINTWISE); + assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism()) + .isEqualTo(newPar); + } + } + + // ===== Test 5: POINTWISE -> A2A migration ===== + @Test + void testPointwiseToAllToAllMigration() throws Exception { + // old: source(3) --PW--> sink(3), new: source(3) --A2A--> sink(3) + OperatorID sourceId = new OperatorID(); + OperatorID sinkId = new OperatorID(); + int oldPar = 3; + int newPar = 3; + + JobVertex source = createJobVertex(sourceId, newPar); + JobVertex sink = createJobVertex(sinkId, newPar); + connectAllToAll(source, sink); + + Map vertices = toExecutionVertices(source, sink); + Map states = + buildStatesWithChannelState(sourceId, sinkId, oldPar); + + // Old pattern was POINTWISE + EdgeDistributionPatternSnapshot oldEdgePatterns = + buildEdgePatterns(sourceId, DistributionPattern.POINTWISE); + + new StateAssignmentOperation( + 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns) + .assignStates(); + + // PW->A2A: enters the pointwise path (oldPattern != A2A) + for (int subtask = 0; subtask < newPar; subtask++) { + OperatorSubtaskState sinkState = + getAssignedState(vertices.get(sinkId), sinkId, subtask); + InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor(); + assertThat(inputDesc) + .as("subtask %d", subtask) + .isNotEqualTo(InflightDataRescalingDescriptor.NO_RESCALE); + + InflightDataGateOrPartitionRescalingDescriptor gateDesc = + inputDesc.getGateOrPartitionDescriptor(0); + assertThat(gateDesc.getPointwiseRescaleParams().getOldDistributionPattern()) + .isEqualTo(DistributionPattern.POINTWISE); + assertThat(gateDesc.getPointwiseRescaleParams().getNewDistributionPattern()) + .isEqualTo(DistributionPattern.ALL_TO_ALL); + // PW->A2A: newUpParallelism > 0 confirms pointwise path used + assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism()) + .isEqualTo(newPar); + } + } + + // ===== Test 6: Pure A2A rescaling (regression - original path unchanged) ===== + @Test + void testPureAllToAllRescalingRegression() throws Exception { + // old: source(2) --A2A--> sink(2), new: source(3) --A2A--> sink(3) + OperatorID sourceId = new OperatorID(); + OperatorID sinkId = new OperatorID(); + int oldPar = 2; + int newPar = 3; + + JobVertex source = createJobVertex(sourceId, newPar); + JobVertex sink = createJobVertex(sinkId, newPar); + connectAllToAll(source, sink); + + Map vertices = toExecutionVertices(source, sink); + Map states = + buildStatesWithChannelState(sourceId, sinkId, oldPar); + + // Old pattern was A2A (same as new) + EdgeDistributionPatternSnapshot oldEdgePatterns = + buildEdgePatterns(sourceId, DistributionPattern.ALL_TO_ALL); + + new StateAssignmentOperation( + 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns) + .assignStates(); + + // A2A->A2A: should use the original all-to-all path + // The descriptor should have newUpParallelism == 0 (no pointwise fields) + for (int subtask = 0; subtask < newPar; subtask++) { + OperatorSubtaskState sinkState = + getAssignedState(vertices.get(sinkId), sinkId, subtask); + InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor(); + + if (!inputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) { + InflightDataGateOrPartitionRescalingDescriptor gateDesc = + inputDesc.getGateOrPartitionDescriptor(0); + // A2A->A2A uses 4-arg constructor, scalars all 0 + assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism()).isEqualTo(0); + assertThat(gateDesc.getPointwiseRescaleParams().getOldUpParallelism()).isEqualTo(0); + } + } + } + + // ===== Test 7: POINTWISE cross-parallelism ===== + @Test + void testPointwiseCrossParallelism() throws Exception { + OperatorID sourceId = new OperatorID(); + OperatorID sinkId = new OperatorID(); + int oldSourcePar = 3; + int oldSinkPar = 2; + int newSourcePar = 4; + int newSinkPar = 3; + + JobVertex source = createJobVertex(sourceId, newSourcePar); + JobVertex sink = createJobVertex(sinkId, newSinkPar); + connectPointwise(source, sink); + + Map vertices = toExecutionVertices(source, sink); + + Map states = new HashMap<>(); + Random random = new Random(42); + + OperatorState sourceState = new OperatorState(null, null, sourceId, oldSourcePar, MAX_P); + for (int i = 0; i < oldSourcePar; i++) { + sourceState.putState( + i, + OperatorSubtaskState.builder() + .setResultSubpartitionState( + new StateObjectCollection<>( + Collections.singletonList( + createNewResultSubpartitionStateHandle( + 10, 0, random)))) + .build()); + } + states.put(sourceId, sourceState); + + OperatorState sinkState = new OperatorState(null, null, sinkId, oldSinkPar, MAX_P); + for (int i = 0; i < oldSinkPar; i++) { + sinkState.putState( + i, + OperatorSubtaskState.builder() + .setInputChannelState( + new StateObjectCollection<>( + Collections.singletonList( + createNewInputChannelStateHandle( + 10, 0, random)))) + .build()); + } + states.put(sinkId, sinkState); + + EdgeDistributionPatternSnapshot oldEdgePatterns = + buildEdgePatterns(sourceId, DistributionPattern.POINTWISE); + + new StateAssignmentOperation( + 0, new HashSet<>(vertices.values()), states, false, true, oldEdgePatterns) + .assignStates(); + + // Verify input descriptors carry correct cross-parallelism values + int subtasksWithDesc = 0; + for (int subtask = 0; subtask < newSinkPar; subtask++) { + OperatorSubtaskState assigned = getAssignedState(vertices.get(sinkId), sinkId, subtask); + InflightDataRescalingDescriptor inputDesc = assigned.getInputRescalingDescriptor(); + + if (!inputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) { + subtasksWithDesc++; + InflightDataGateOrPartitionRescalingDescriptor gateDesc = + inputDesc.getGateOrPartitionDescriptor(0); + assertThat(gateDesc.getPointwiseRescaleParams().getOldUpParallelism()) + .isEqualTo(oldSourcePar); + assertThat(gateDesc.getPointwiseRescaleParams().getOldDownParallelism()) + .isEqualTo(oldSinkPar); + assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism()) + .isEqualTo(newSourcePar); + assertThat(gateDesc.getPointwiseRescaleParams().getNewDownParallelism()) + .isEqualTo(newSinkPar); + assertThat(gateDesc.getPointwiseRescaleParams().getOldDistributionPattern()) + .isEqualTo(DistributionPattern.POINTWISE); + } + } + assertThat(subtasksWithDesc).isGreaterThan(0); + + // Verify output descriptors + int sourceSubtasksWithDesc = 0; + for (int subtask = 0; subtask < newSourcePar; subtask++) { + OperatorSubtaskState assigned = + getAssignedState(vertices.get(sourceId), sourceId, subtask); + InflightDataRescalingDescriptor outputDesc = assigned.getOutputRescalingDescriptor(); + + if (!outputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) { + sourceSubtasksWithDesc++; + InflightDataGateOrPartitionRescalingDescriptor partDesc = + outputDesc.getGateOrPartitionDescriptor(0); + assertThat(partDesc.getPointwiseRescaleParams().getOldUpParallelism()) + .isEqualTo(oldSourcePar); + assertThat(partDesc.getPointwiseRescaleParams().getOldDownParallelism()) + .isEqualTo(oldSinkPar); + assertThat(partDesc.getPointwiseRescaleParams().getNewUpParallelism()) + .isEqualTo(newSourcePar); + assertThat(partDesc.getPointwiseRescaleParams().getNewDownParallelism()) + .isEqualTo(newSinkPar); + } + } + assertThat(sourceSubtasksWithDesc).isGreaterThan(0); + } + + // ===== Test 8: No EdgeDistributionPatternSnapshot -> treated as A2A (backward compat) ===== + @Test + void testNullEdgePatternsDefaultsToAllToAll() throws Exception { + // old: source(2) --?--> sink(2), new: source(3) --A2A--> sink(3) + OperatorID sourceId = new OperatorID(); + OperatorID sinkId = new OperatorID(); + int oldPar = 2; + int newPar = 3; + + JobVertex source = createJobVertex(sourceId, newPar); + JobVertex sink = createJobVertex(sinkId, newPar); + connectAllToAll(source, sink); + + Map vertices = toExecutionVertices(source, sink); + Map states = + buildStatesWithChannelState(sourceId, sinkId, oldPar); + + // No edge pattern snapshot (null) - backward compatible + new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, true, null) + .assignStates(); + + // Should produce standard A2A descriptors (newUpParallelism == 0) + for (int subtask = 0; subtask < newPar; subtask++) { + OperatorSubtaskState sinkState = + getAssignedState(vertices.get(sinkId), sinkId, subtask); + InflightDataRescalingDescriptor inputDesc = sinkState.getInputRescalingDescriptor(); + + if (!inputDesc.equals(InflightDataRescalingDescriptor.NO_RESCALE)) { + InflightDataGateOrPartitionRescalingDescriptor gateDesc = + inputDesc.getGateOrPartitionDescriptor(0); + assertThat(gateDesc.getPointwiseRescaleParams().getNewUpParallelism()).isEqualTo(0); + } + } + } + + // ===== Helper methods ===== + + private JobVertex createJobVertex(OperatorID operatorID, int parallelism) { + JobVertex jobVertex = + new JobVertex( + operatorID.toHexString(), + new JobVertexID(), + Collections.singletonList( + OperatorIDPair.of(operatorID, operatorID, null, null))); + jobVertex.setInvokableClass(NoOpInvokable.class); + jobVertex.setParallelism(parallelism); + return jobVertex; + } + + private void connectPointwise(JobVertex upstream, JobVertex downstream) { + final JobEdge jobEdge = + JobVertexConnectionUtils.connectNewDataSetAsInput( + downstream, + upstream, + DistributionPattern.POINTWISE, + ResultPartitionType.PIPELINED); + jobEdge.setDownstreamSubtaskStateMapper(ROUND_ROBIN); + jobEdge.setUpstreamSubtaskStateMapper(ROUND_ROBIN); + } + + private void connectAllToAll(JobVertex upstream, JobVertex downstream) { + final JobEdge jobEdge = + JobVertexConnectionUtils.connectNewDataSetAsInput( + downstream, + upstream, + DistributionPattern.ALL_TO_ALL, + ResultPartitionType.PIPELINED); + jobEdge.setDownstreamSubtaskStateMapper(ROUND_ROBIN); + jobEdge.setUpstreamSubtaskStateMapper(ROUND_ROBIN); + } + + private Map toExecutionVertices(JobVertex... jobVertices) + throws JobException, JobExecutionException { + JobGraph jobGraph = JobGraphTestUtils.streamingJobGraph(jobVertices); + ExecutionGraph eg = + TestingDefaultExecutionGraphBuilder.newBuilder() + .setJobGraph(jobGraph) + .build(EXECUTOR_EXTENSION.getExecutor()); + return Arrays.stream(jobVertices) + .collect( + Collectors.toMap( + jobVertex -> + jobVertex.getOperatorIDs().get(0).getGeneratedOperatorID(), + jobVertex -> { + try { + return eg.getJobVertex(jobVertex.getID()); + } catch (Exception e) { + throw new RuntimeException(e); + } + })); + } + + private Map buildStatesWithChannelState( + OperatorID sourceId, OperatorID sinkId, int oldParallelism) { + Map states = new HashMap<>(); + Random random = new Random(42); + + OperatorState sourceState = new OperatorState(null, null, sourceId, oldParallelism, MAX_P); + for (int i = 0; i < oldParallelism; i++) { + sourceState.putState( + i, + OperatorSubtaskState.builder() + .setResultSubpartitionState( + new StateObjectCollection<>( + Collections.singletonList( + createNewResultSubpartitionStateHandle( + 10, 0, random)))) + .build()); + } + states.put(sourceId, sourceState); + + OperatorState sinkState = new OperatorState(null, null, sinkId, oldParallelism, MAX_P); + for (int i = 0; i < oldParallelism; i++) { + sinkState.putState( + i, + OperatorSubtaskState.builder() + .setInputChannelState( + new StateObjectCollection<>( + Collections.singletonList( + createNewInputChannelStateHandle( + 10, 0, random)))) + .build()); + } + states.put(sinkId, sinkState); + + return states; + } + + private EdgeDistributionPatternSnapshot buildEdgePatterns( + OperatorID sourceId, DistributionPattern pattern) { + Map patterns = new HashMap<>(); + patterns.put(sourceId, new DistributionPattern[] {pattern}); + return new EdgeDistributionPatternSnapshot(patterns); + } + + private OperatorSubtaskState getAssignedState( + ExecutionJobVertex executionJobVertex, OperatorID operatorId, int subtaskIdx) { + return executionJobVertex + .getTaskVertices()[subtaskIdx] + .getCurrentExecutionAttempt() + .getTaskRestore() + .getTaskStateSnapshot() + .getSubtaskStateByOperatorID(operatorId); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java index 9c4aab0bc7a5d..5b8badef9d967 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java @@ -83,7 +83,8 @@ private InputChannelRecoveredStateHandler buildInputChannelStateHandler( .MappingType.IDENTITY) }), null, - MemoryManager.DEFAULT_PAGE_SIZE); + MemoryManager.DEFAULT_PAGE_SIZE, + 0); } private InputChannelRecoveredStateHandler buildMultiChannelHandler() { @@ -111,7 +112,8 @@ private InputChannelRecoveredStateHandler buildMultiChannelHandler() { .MappingType.RESCALING) }), null, - MemoryManager.DEFAULT_PAGE_SIZE); + MemoryManager.DEFAULT_PAGE_SIZE, + 0); } /** Builds a handler in filtering mode (non-null filtering handler, no-op stub). */ @@ -136,7 +138,8 @@ private InputChannelRecoveredStateHandler buildFilteringInputChannelStateHandler .MappingType.IDENTITY) }), stubFilteringHandler, - MemoryManager.DEFAULT_PAGE_SIZE); + MemoryManager.DEFAULT_PAGE_SIZE, + 0); } @Test diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java index 91d4800e6736a..cbb2ca60b34ee 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ResultSubpartitionRecoveredStateHandlerTest.java @@ -70,7 +70,8 @@ private ResultSubpartitionRecoveredStateHandler buildResultStateHandler( InflightDataRescalingDescriptor .InflightDataGateOrPartitionRescalingDescriptor .MappingType.IDENTITY) - })); + }), + 0); } @Test diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/partitioner/ForceUnalignedSupportTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/partitioner/ForceUnalignedSupportTest.java new file mode 100644 index 0000000000000..b78b8a8ed31b8 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/partitioner/ForceUnalignedSupportTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.partitioner; + +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link StreamPartitioner#isSupportsUnalignedCheckpoint(boolean)} — specifically the + * {@code forceUnaligned} flag introduced to allow unaligned checkpoints on forward (pointwise) + * edges while keeping broadcast edges aligned. + */ +class ForceUnalignedSupportTest { + + @Test + void testForwardPartitionerRespectsForceUnaligned() { + ForwardPartitioner partitioner = new ForwardPartitioner<>(); + + // Without force: pointwise edges block unaligned checkpoints. + assertThat(partitioner.isSupportsUnalignedCheckpoint(false)).isFalse(); + // With force: pointwise edges are allowed to use unaligned checkpoints. + assertThat(partitioner.isSupportsUnalignedCheckpoint(true)).isTrue(); + } + + @Test + void testRescalePartitionerRespectsForceUnaligned() { + RescalePartitioner partitioner = new RescalePartitioner<>(); + + assertThat(partitioner.isSupportsUnalignedCheckpoint(false)).isFalse(); + assertThat(partitioner.isSupportsUnalignedCheckpoint(true)).isTrue(); + } + + @Test + void testBroadcastPartitionerIgnoresForceUnaligned() { + BroadcastPartitioner partitioner = new BroadcastPartitioner<>(); + + // Broadcast edges must never support unaligned checkpoints, regardless of force. + assertThat(partitioner.isSupportsUnalignedCheckpoint(false)).isFalse(); + assertThat(partitioner.isSupportsUnalignedCheckpoint(true)).isFalse(); + } + + @Test + void testRebalancePartitionerForceDoesNotFlipBehavior() { + RebalancePartitioner partitioner = new RebalancePartitioner<>(); + + boolean withoutForce = partitioner.isSupportsUnalignedCheckpoint(false); + boolean withForce = partitioner.isSupportsUnalignedCheckpoint(true); + + assertThat(withoutForce).isTrue(); + assertThat(withForce).isEqualTo(withoutForce); + } + + @Test + void testKeyGroupPartitionerForceDoesNotFlipBehavior() { + KeyGroupStreamPartitioner partitioner = + new KeyGroupStreamPartitioner<>((KeySelector) value -> value, 128); + + boolean withoutForce = partitioner.isSupportsUnalignedCheckpoint(false); + boolean withForce = partitioner.isSupportsUnalignedCheckpoint(true); + + assertThat(withoutForce).isTrue(); + assertThat(withForce).isEqualTo(withoutForce); + } +} diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointShuffleChangeITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointShuffleChangeITCase.java new file mode 100644 index 0000000000000..10a1ec1fffbfd --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointShuffleChangeITCase.java @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.checkpointing; + +import org.apache.flink.api.common.JobExecutionResult; +import org.apache.flink.api.common.JobStatus; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.connector.source.util.ratelimit.RateLimiterStrategy; +import org.apache.flink.configuration.CheckpointingOptions; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.MemorySize; +import org.apache.flink.configuration.TaskManagerOptions; +import org.apache.flink.connector.datagen.source.DataGeneratorSource; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.legacy.RichSinkFunction; +import org.apache.flink.testutils.junit.extensions.parameterized.Parameter; +import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension; +import org.apache.flink.testutils.junit.extensions.parameterized.Parameters; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.atomic.AtomicLong; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration test for changing distribution patterns (POINTWISE / ALL_TO_ALL) and parallelism + * between runs when restoring from an unaligned checkpoint. + * + *

Each parameterized case is a 9-tuple: (desc, oldUpPar, oldDownPar, oldPattern, newUpPar, + * newDownPar, newPattern, recoverOutputOnDownstream, checkpointDuringRecovery). Run 1 emits a + * bounded monotonic sequence {@code [0, TOTAL_COUNT)} XORed with a header, takes an in-flight + * unaligned checkpoint, and is cancelled. Run 2 restores with the new topology and runs to + * completion. The verifying sink asserts: (a) no header corruption (mis-routed buffers), (b) no + * per-subtask duplicates after restore, and (c) every value in {@code [0, TOTAL_COUNT)} arrives at + * the sink. + * + *

The test is parameterized over three recovery configurations: + * + *

+ */ +@ExtendWith(ParameterizedTestExtension.class) +class UnalignedCheckpointShuffleChangeITCase extends UnalignedCheckpointTestBase { + + private static final int TOTAL_COUNT = 20_000; + + private static final ConcurrentSkipListSet SEEN_GLOBAL = new ConcurrentSkipListSet<>(); + private static final AtomicLong POST_RESTORE_DUPLICATES = new AtomicLong(); + private static final AtomicLong CORRUPTIONS = new AtomicLong(); + + enum EdgePattern { + POINTWISE, + ALL_TO_ALL + } + + @Parameter(0) + public String desc; + + @Parameter(1) + public int oldUpPar; + + @Parameter(2) + public int oldDownPar; + + @Parameter(3) + public EdgePattern oldPattern; + + @Parameter(4) + public int newUpPar; + + @Parameter(5) + public int newDownPar; + + @Parameter(6) + public EdgePattern newPattern; + + @Parameter(7) + public boolean recoverOutputOnDownstream; + + @Parameter(8) + public boolean checkpointDuringRecovery; + + @Parameters(name = "{0}") + public static Collection parameters() { + EdgePattern a2a = EdgePattern.ALL_TO_ALL; + EdgePattern pw = EdgePattern.POINTWISE; + + Object[][] baseCases = + new Object[][] { + // ---- A2A -> A2A ---- + {"A2A->A2A same", 3, 3, a2a, 3, 3, a2a}, + {"A2A->A2A up-up", 3, 3, a2a, 4, 3, a2a}, + {"A2A->A2A up-up wide", 3, 3, a2a, 11, 3, a2a}, + {"A2A->A2A up-down", 4, 3, a2a, 3, 3, a2a}, + {"A2A->A2A up-down wide", 11, 3, a2a, 3, 3, a2a}, + {"A2A->A2A down-up", 3, 3, a2a, 3, 4, a2a}, + {"A2A->A2A down-up wide", 3, 3, a2a, 3, 11, a2a}, + {"A2A->A2A down-down", 3, 4, a2a, 3, 3, a2a}, + {"A2A->A2A down-down wide", 3, 11, a2a, 3, 3, a2a}, + // ---- A2A -> PW ---- + {"A2A->PW same", 3, 3, a2a, 3, 3, pw}, + {"A2A->PW up-up", 3, 3, a2a, 4, 3, pw}, + {"A2A->PW up-up wide", 3, 3, a2a, 11, 3, pw}, + {"A2A->PW up-down", 4, 3, a2a, 3, 3, pw}, + {"A2A->PW up-down wide", 11, 3, a2a, 3, 3, pw}, + {"A2A->PW down-up", 3, 3, a2a, 3, 4, pw}, + {"A2A->PW down-up wide", 3, 3, a2a, 3, 11, pw}, + {"A2A->PW down-down", 3, 4, a2a, 3, 3, pw}, + {"A2A->PW down-down wide", 3, 11, a2a, 3, 3, pw}, + // ---- PW -> A2A ---- + {"PW->A2A same", 3, 3, pw, 3, 3, a2a}, + {"PW->A2A up-up", 3, 3, pw, 4, 3, a2a}, + {"PW->A2A up-up wide", 3, 3, pw, 11, 3, a2a}, + {"PW->A2A up-down", 4, 3, pw, 3, 3, a2a}, + {"PW->A2A up-down wide", 11, 3, pw, 3, 3, a2a}, + {"PW->A2A down-up", 3, 3, pw, 3, 4, a2a}, + {"PW->A2A down-up wide", 3, 3, pw, 3, 11, a2a}, + {"PW->A2A down-down", 3, 4, pw, 3, 3, a2a}, + {"PW->A2A down-down wide", 3, 11, pw, 3, 3, a2a}, + // ---- PW -> PW ---- + {"PW->PW same", 3, 3, pw, 3, 3, pw}, + {"PW->PW up-up", 3, 3, pw, 4, 3, pw}, + {"PW->PW up-up wide", 3, 3, pw, 11, 3, pw}, + {"PW->PW up-down", 4, 3, pw, 3, 3, pw}, + {"PW->PW up-down wide", 11, 3, pw, 3, 3, pw}, + {"PW->PW down-up", 3, 3, pw, 3, 4, pw}, + {"PW->PW down-up wide", 3, 3, pw, 3, 11, pw}, + {"PW->PW down-down", 3, 4, pw, 3, 3, pw}, + {"PW->PW down-down wide", 3, 11, pw, 3, 3, pw}, + {"PW->PW both-up", 3, 3, pw, 4, 4, pw}, + {"PW->PW both-down", 4, 4, pw, 3, 3, pw}, + }; + + // Three recovery configurations: + // 1. Neither enabled (false, false) + // 2. Only recoverOutputOnDownstream (true, false) + // 3. Both enabled (true, true) + boolean[][] recoveryConfigs = + new boolean[][] { + {false, false}, + {true, false}, + {true, true}, + }; + String[] configLabels = new String[] {"base", "recoverOut", "recoverOut+cpDuringRecov"}; + + List result = new ArrayList<>(); + for (Object[] base : baseCases) { + for (int c = 0; c < recoveryConfigs.length; c++) { + Object[] row = new Object[9]; + row[0] = base[0] + " [" + configLabels[c] + "]"; + System.arraycopy(base, 1, row, 1, 6); + row[7] = recoveryConfigs[c][0]; + row[8] = recoveryConfigs[c][1]; + result.add(row); + } + } + return result; + } + + @BeforeEach + void setup() { + SEEN_GLOBAL.clear(); + POST_RESTORE_DUPLICATES.set(0); + CORRUPTIONS.set(0); + } + + @TestTemplate + void testShuffleChangeWithUnalignedCheckpoint(TestInfo testInfo) throws Exception { + // Phase 1: run with old topology, generate checkpoint, cancel + UnalignedSettings phase1Settings = + createSettings(oldPattern, oldUpPar, oldDownPar) + .setCheckpointGenerationMode( + CheckpointGenerationMode.WAIT_FOR_CHECKPOINT_AND_CANCEL); + String checkpointPath = super.execute(phase1Settings, testInfo); + assertThat(checkpointPath) + .as("Phase 1 must generate a checkpoint for restore test to be valid.") + .isNotNull(); + + // Phase 2: restore with new topology and run to completion + UnalignedSettings phase2Settings = + createSettings(newPattern, newUpPar, newDownPar) + .setRestoreCheckpoint(checkpointPath) + .setExpectedFinalJobStatus(JobStatus.FINISHED); + super.execute(phase2Settings, testInfo); + + assertThat(CORRUPTIONS.get()) + .as("records with bad header (mis-routed or wrongly-decoded buffers)") + .isZero(); + assertThat(POST_RESTORE_DUPLICATES.get()) + .as("post-restore duplicates (replayed buffer delivered twice to a sink subtask)") + .isZero(); + assertThat(SEEN_GLOBAL) + .as("union of all sink output must cover [0, %d)", TOTAL_COUNT) + .hasSize(TOTAL_COUNT); + assertThat(SEEN_GLOBAL.first()).isZero(); + assertThat(SEEN_GLOBAL.last()).isEqualTo((long) (TOTAL_COUNT - 1)); + } + + @Override + protected void checkCounters(JobExecutionResult result) {} + + private UnalignedSettings createSettings(EdgePattern pattern, int upPar, int downPar) { + ShuffleChangeSettings settings = + new ShuffleChangeSettings( + new ShuffleChangeDagCreator(pattern, upPar, downPar), + recoverOutputOnDownstream, + checkpointDuringRecovery); + settings.setParallelism(Math.max(upPar, downPar)); + settings.setChannelTypes(ChannelType.LOCAL); + return settings; + } + + // ------------------------------------------------------------------------- + // Settings + // ------------------------------------------------------------------------- + + private static class ShuffleChangeSettings extends UnalignedSettings { + private final boolean recoverOutputOnDownstream; + private final boolean checkpointDuringRecovery; + + ShuffleChangeSettings( + DagCreator dagCreator, + boolean recoverOutputOnDownstream, + boolean checkpointDuringRecovery) { + super(dagCreator); + this.recoverOutputOnDownstream = recoverOutputOnDownstream; + this.checkpointDuringRecovery = checkpointDuringRecovery; + } + + @Override + public Configuration getConfiguration(File checkpointDir) { + Configuration conf = super.getConfiguration(checkpointDir); + conf.set(TaskManagerOptions.MEMORY_SEGMENT_SIZE, MemorySize.parse("1 kb")); + conf.set(CheckpointingOptions.MAX_RETAINED_CHECKPOINTS, 50); + conf.set( + CheckpointingOptions.UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM, + recoverOutputOnDownstream); + conf.set( + CheckpointingOptions.CHECKPOINTING_DURING_RECOVERY_ENABLED, + checkpointDuringRecovery); + return conf; + } + } + + // ------------------------------------------------------------------------- + // DAG creation + // ------------------------------------------------------------------------- + + private static class ShuffleChangeDagCreator implements DagCreator { + private final EdgePattern pattern; + private final int upPar; + private final int downPar; + + ShuffleChangeDagCreator(EdgePattern pattern, int upPar, int downPar) { + this.pattern = pattern; + this.upPar = upPar; + this.downPar = downPar; + } + + @Override + public void create( + StreamExecutionEnvironment env, + int minCheckpoints, + boolean slotSharing, + int expectedFailuresUntilSourceFinishes, + long sourceSleepMs) { + env.disableOperatorChaining(); + + DataGeneratorSource source = + new DataGeneratorSource<>( + UnalignedCheckpointTestBase::withHeader, + TOTAL_COUNT, + RateLimiterStrategy.perSecond(5000), + Types.LONG); + + DataStream sourceStream = + env.fromSource(source, WatermarkStrategy.noWatermarks(), "source") + .uid("source") + .setParallelism(upPar); + + DataStream shuffled = + (pattern == EdgePattern.POINTWISE) + ? sourceStream.rescale() + : sourceStream.rebalance(); + + shuffled.map( + x -> { + Thread.sleep(1); + return x; + }) + .returns(Types.LONG) + .name("map") + .uid("map") + .setParallelism(downPar) + .addSink(new VerifyingSink()) + .uid("sink") + .name("sink") + .setParallelism(downPar); + } + } + + // ------------------------------------------------------------------------- + // Verification sink + // ------------------------------------------------------------------------- + + private static class VerifyingSink extends RichSinkFunction + implements CheckpointedFunction { + private transient ListState seenState; + private transient HashSet seen; + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + seenState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("seen", Types.LONG)); + seen = new HashSet<>(); + for (Long v : seenState.get()) { + seen.add(v); + } + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception { + seenState.update(new ArrayList<>(seen)); + } + + @Override + public void invoke(Long value, Context context) { + try { + long base = withoutHeader(value); + if (!seen.add(base)) { + POST_RESTORE_DUPLICATES.incrementAndGet(); + } + SEEN_GLOBAL.add(base); + } catch (IllegalArgumentException e) { + CORRUPTIONS.incrementAndGet(); + } + } + } +}