Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -1811,13 +1811,22 @@ private OptionalLong restoreLatestCheckpointedStateInternal(
vertexFinishedStateChecker.validateOperatorsFinishedState();
}

EdgeDistributionPatternSnapshot oldEdgePatterns = null;
for (MasterState ms : latest.getMasterHookStates()) {
if (EdgeDistributionPatternSnapshot.HOOK_IDENTIFIER.equals(ms.name())) {
oldEdgePatterns = EdgeDistributionPatternSnapshot.fromBytes(ms.bytes());
break;
}
}

StateAssignmentOperation stateAssignmentOperation =
new StateAssignmentOperation(
latest.getCheckpointID(),
tasks,
operatorStates,
allowNonRestoredState,
recoverOutputOnDownstreamTask);
recoverOutputOnDownstreamTask,
oldEdgePatterns);

stateAssignmentOperation.assignStates();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.checkpoint;

import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.OperatorID;

import javax.annotation.Nullable;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * Snapshot of per-operator output edge {@link DistributionPattern}s, persisted via {@link
 * MasterTriggerRestoreHook} / {@link MasterState} so that the correct old distribution pattern is
 * available when restoring from a checkpoint after a shuffle-mode change.
 *
 * <p>Key = outputOperatorID (generated, chain-tail), Value = one {@link DistributionPattern} per
 * output partition (indexed by partitionIndex).
 *
 * <p>Instances are immutable: the constructor copies the map and every pattern array, and {@link
 * #getOutputPatterns(OperatorID)} hands out copies only.
 *
 * <p>Binary layout produced by {@link #toBytes()} (all values as written by {@link
 * DataOutputStream}):
 *
 * <pre>
 * int  formatVersion
 * int  numOperators
 * numOperators x {
 *     long operatorId.lowerPart
 *     long operatorId.upperPart
 *     int  numPartitions
 *     numPartitions x byte pattern   (0 = POINTWISE, 1 = ALL_TO_ALL)
 * }
 * </pre>
 */
public class EdgeDistributionPatternSnapshot {

    public static final String HOOK_IDENTIFIER = "__flink_edge_distribution_patterns__";

    private static final int FORMAT_VERSION = 1;
    private static final byte PATTERN_POINTWISE = 0;
    private static final byte PATTERN_ALL_TO_ALL = 1;

    /** Smallest possible serialized operator entry: two longs (id) plus one int (count). */
    private static final int MIN_OPERATOR_ENTRY_BYTES = 2 * Long.BYTES + Integer.BYTES;

    private final Map<OperatorID, DistributionPattern[]> outputEdgePatterns;

    /**
     * Creates a snapshot from the given mapping.
     *
     * <p>The map and each pattern array are copied, so later modifications by the caller do not
     * leak into this snapshot.
     *
     * @param outputEdgePatterns patterns per output operator, one array element per partition
     * @throws NullPointerException if the map or any of its pattern arrays is {@code null}
     */
    public EdgeDistributionPatternSnapshot(
            Map<OperatorID, DistributionPattern[]> outputEdgePatterns) {
        Map<OperatorID, DistributionPattern[]> copy = new HashMap<>(outputEdgePatterns.size());
        for (Map.Entry<OperatorID, DistributionPattern[]> entry :
                outputEdgePatterns.entrySet()) {
            DistributionPattern[] patterns = entry.getValue();
            if (patterns == null) {
                // A null array could never be serialized by toBytes(); fail at construction.
                throw new NullPointerException(
                        "Output patterns must not be null for operator " + entry.getKey());
            }
            copy.put(entry.getKey(), patterns.clone());
        }
        this.outputEdgePatterns = Collections.unmodifiableMap(copy);
    }

    /**
     * Returns a copy of all output patterns recorded for the given operator.
     *
     * @param outputOperatorID the operator to look up
     * @return a fresh array with one pattern per partition, or {@code null} if the operator is
     *     not part of this snapshot
     */
    @Nullable
    public DistributionPattern[] getOutputPatterns(OperatorID outputOperatorID) {
        DistributionPattern[] patterns = outputEdgePatterns.get(outputOperatorID);
        return patterns != null ? patterns.clone() : null;
    }

    /**
     * Returns the pattern recorded for a single output partition.
     *
     * @param outputOperatorID the operator to look up
     * @param partitionIndex index of the output partition
     * @return the recorded pattern, or {@code null} if the operator is unknown or the index is
     *     outside the recorded range
     */
    @Nullable
    public DistributionPattern getOutputPattern(OperatorID outputOperatorID, int partitionIndex) {
        DistributionPattern[] patterns = outputEdgePatterns.get(outputOperatorID);
        if (patterns == null || partitionIndex < 0 || partitionIndex >= patterns.length) {
            return null;
        }
        return patterns[partitionIndex];
    }

    /**
     * Serializes this snapshot into the binary layout described in the class documentation.
     *
     * @return the serialized snapshot
     * @throws IOException if writing to the in-memory stream fails
     */
    public byte[] toBytes() throws IOException {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                DataOutputStream dos = new DataOutputStream(baos)) {
            dos.writeInt(FORMAT_VERSION);
            dos.writeInt(outputEdgePatterns.size());
            for (Map.Entry<OperatorID, DistributionPattern[]> entry :
                    outputEdgePatterns.entrySet()) {
                OperatorID opId = entry.getKey();
                // Lower part first, upper part second; fromBytes() reads in the same order.
                dos.writeLong(opId.getLowerPart());
                dos.writeLong(opId.getUpperPart());
                DistributionPattern[] patterns = entry.getValue();
                dos.writeInt(patterns.length);
                for (DistributionPattern p : patterns) {
                    // NOTE: every element that is not ALL_TO_ALL (including a null element) is
                    // encoded as POINTWISE.
                    dos.writeByte(
                            p == DistributionPattern.ALL_TO_ALL
                                    ? PATTERN_ALL_TO_ALL
                                    : PATTERN_POINTWISE);
                }
            }
            dos.flush();
            return baos.toByteArray();
        }
    }

    /**
     * Restores a snapshot from bytes previously produced by {@link #toBytes()}.
     *
     * <p>Element counts found in the payload are checked against the number of bytes that are
     * actually left, so a corrupt or truncated payload is reported as an {@link IOException}
     * instead of an unchecked exception or an oversized allocation.
     *
     * @param bytes the serialized snapshot
     * @return the restored snapshot
     * @throws IOException if the version is unsupported or the payload is truncated or corrupt
     * @throws NullPointerException if {@code bytes} is {@code null}
     */
    public static EdgeDistributionPatternSnapshot fromBytes(byte[] bytes) throws IOException {
        if (bytes == null) {
            throw new NullPointerException("bytes must not be null");
        }
        try (ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
                DataInputStream dis = new DataInputStream(bais)) {
            int version = dis.readInt();
            if (version != FORMAT_VERSION) {
                throw new IOException(
                        "Unsupported EdgeDistributionPatternSnapshot version: " + version);
            }
            int numOperators = dis.readInt();
            // DataInputStream does not buffer, so available() is the exact remaining byte count.
            if (numOperators < 0 || numOperators > bais.available() / MIN_OPERATOR_ENTRY_BYTES) {
                throw new IOException(
                        "Corrupt EdgeDistributionPatternSnapshot: invalid operator count "
                                + numOperators);
            }
            Map<OperatorID, DistributionPattern[]> map = new HashMap<>(numOperators);
            for (int i = 0; i < numOperators; i++) {
                long lowerPart = dis.readLong();
                long upperPart = dis.readLong();
                OperatorID opId = new OperatorID(lowerPart, upperPart);
                int numPartitions = dis.readInt();
                // Each partition occupies exactly one byte in the payload.
                if (numPartitions < 0 || numPartitions > bais.available()) {
                    throw new IOException(
                            "Corrupt EdgeDistributionPatternSnapshot: invalid partition count "
                                    + numPartitions);
                }
                DistributionPattern[] patterns = new DistributionPattern[numPartitions];
                for (int j = 0; j < numPartitions; j++) {
                    patterns[j] = decodePattern(dis.readByte());
                }
                map.put(opId, patterns);
            }
            return new EdgeDistributionPatternSnapshot(map);
        }
    }

    /** Maps a serialized pattern byte back to its {@link DistributionPattern}. */
    private static DistributionPattern decodePattern(byte b) throws IOException {
        if (b == PATTERN_ALL_TO_ALL) {
            return DistributionPattern.ALL_TO_ALL;
        }
        if (b == PATTERN_POINTWISE) {
            return DistributionPattern.POINTWISE;
        }
        throw new IOException("Unknown DistributionPattern byte in snapshot: " + b);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.flink.runtime.checkpoint;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectStreamException;
import java.io.Serializable;
import java.util.Arrays;
Expand Down Expand Up @@ -55,6 +57,11 @@ public RescaleMappings getChannelMapping(int gateOrPartitionIndex) {
return gateOrPartitionDescriptors[gateOrPartitionIndex].getRescaleMappings();
}

/**
 * Returns the rescaling descriptor stored for the given gate or partition.
 *
 * @param gateOrPartitionIndex index into the descriptor array of this instance
 * @return the descriptor at that index
 */
public InflightDataGateOrPartitionRescalingDescriptor getGateOrPartitionDescriptor(
        int gateOrPartitionIndex) {
    final InflightDataGateOrPartitionRescalingDescriptor descriptor =
            gateOrPartitionDescriptors[gateOrPartitionIndex];
    return descriptor;
}

public boolean isAmbiguous(int gateOrPartitionIndex, int oldSubtaskIndex) {
return gateOrPartitionDescriptors[gateOrPartitionIndex].ambiguousSubtaskIndexes.contains(
oldSubtaskIndex);
Expand Down Expand Up @@ -142,7 +149,7 @@ public RescaleMappings getRescaleMappings() {

/**
 * Set when channels are merged because the connected operator has been rescaled for each
 * gate/partition. Used for the ALL_TO_ALL → ALL_TO_ALL path.
 */
private final RescaleMappings rescaledChannelsMappings;

Expand All @@ -151,6 +158,9 @@ public RescaleMappings getRescaleMappings() {

private final MappingType mappingType;

/** Bundled scalar topology parameters for POINTWISE-aware rescaling. */
private PointwiseRescaleParams rescaleParams;

/** Type of mapping which should be used for this in-flight data. */
public enum MappingType {
IDENTITY,
Expand All @@ -162,10 +172,25 @@ public InflightDataGateOrPartitionRescalingDescriptor(
RescaleMappings rescaledChannelsMappings,
Set<Integer> ambiguousSubtaskIndexes,
MappingType mappingType) {
this(
oldSubtaskIndexes,
rescaledChannelsMappings,
ambiguousSubtaskIndexes,
mappingType,
PointwiseRescaleParams.EMPTY);
}

/**
 * Creates a descriptor that additionally carries POINTWISE rescale parameters.
 *
 * @param oldSubtaskIndexes stored as the old subtask indexes of this descriptor
 * @param rescaledChannelsMappings stored as the channel mappings of this descriptor
 * @param ambiguousSubtaskIndexes stored as the set of ambiguous old subtask indexes
 * @param mappingType the mapping type to use
 * @param rescaleParams POINTWISE rescale parameters; {@code null} is treated as {@link
 *     PointwiseRescaleParams#EMPTY}
 */
public InflightDataGateOrPartitionRescalingDescriptor(
        int[] oldSubtaskIndexes,
        RescaleMappings rescaledChannelsMappings,
        Set<Integer> ambiguousSubtaskIndexes,
        MappingType mappingType,
        PointwiseRescaleParams rescaleParams) {
    this.oldSubtaskIndexes = oldSubtaskIndexes;
    this.rescaledChannelsMappings = rescaledChannelsMappings;
    this.ambiguousSubtaskIndexes = ambiguousSubtaskIndexes;
    this.mappingType = mappingType;
    // Same normalization as readObject(): the field is dereferenced by isPointwiseRescaling(),
    // so it must never be null.
    this.rescaleParams = rescaleParams != null ? rescaleParams : PointwiseRescaleParams.EMPTY;
}

public int[] getOldSubtaskInstances() {
Expand All @@ -180,6 +205,21 @@ public boolean isIdentity() {
return mappingType == MappingType.IDENTITY;
}

/**
 * Returns the POINTWISE rescale parameters held by this descriptor.
 *
 * @return the stored {@code rescaleParams} field
 */
public PointwiseRescaleParams getPointwiseRescaleParams() {
    return this.rescaleParams;
}

/**
 * Tells whether this descriptor carries POINTWISE rescale parameters.
 *
 * @return {@code true} if the stored parameters are present and differ from {@link
 *     PointwiseRescaleParams#EMPTY}; {@code false} otherwise
 */
public boolean isPointwiseRescaling() {
    // Null guard: a missing parameter object means "no POINTWISE rescaling" rather than an NPE.
    return rescaleParams != null && !rescaleParams.equals(PointwiseRescaleParams.EMPTY);
}

/**
 * Java serialization hook: restores the default fields and then replaces a {@code null}
 * {@code rescaleParams} with {@link PointwiseRescaleParams#EMPTY}.
 *
 * @param in the stream to read this object from
 * @throws IOException if reading from the stream fails
 * @throws ClassNotFoundException if a serialized class cannot be resolved
 */
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
    in.defaultReadObject();
    if (rescaleParams != null) {
        return;
    }
    rescaleParams = PointwiseRescaleParams.EMPTY;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -193,13 +233,18 @@ public boolean equals(Object o) {
return Arrays.equals(oldSubtaskIndexes, that.oldSubtaskIndexes)
&& Objects.equals(rescaledChannelsMappings, that.rescaledChannelsMappings)
&& Objects.equals(ambiguousSubtaskIndexes, that.ambiguousSubtaskIndexes)
&& mappingType == that.mappingType;
&& mappingType == that.mappingType
&& Objects.equals(rescaleParams, that.rescaleParams);
}

/**
 * Computes a hash over the channel mappings, the ambiguous subtask indexes, the mapping type,
 * the rescale parameters and the contents of the old-subtask-index array.
 */
@Override
public int hashCode() {
    // rescaleParams is part of the hash; previously the four-field hash was computed and
    // then discarded, leaving only the three-field hash in the result.
    int result =
            Objects.hash(
                    rescaledChannelsMappings,
                    ambiguousSubtaskIndexes,
                    mappingType,
                    rescaleParams);
    // Arrays need a content-based hash; Objects.hash would use the array identity.
    result = 31 * result + Arrays.hashCode(oldSubtaskIndexes);
    return result;
}
Expand All @@ -215,6 +260,8 @@ public String toString() {
+ ambiguousSubtaskIndexes
+ ", mappingType="
+ mappingType
+ ", rescaleParams="
+ rescaleParams
+ '}';
}
}
Expand All @@ -231,6 +278,12 @@ public int[] getOldSubtaskIndexes(int gateOrPartitionIndex) {
return EMPTY_INT_ARRAY;
}

/**
 * Returns {@code InflightDataGateOrPartitionRescalingDescriptor.NO_STATE} for every index; the
 * argument is not consulted.
 */
@Override
public InflightDataGateOrPartitionRescalingDescriptor getGateOrPartitionDescriptor(
        int gateOrPartitionIndex) {
    final InflightDataGateOrPartitionRescalingDescriptor noState =
            InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
    return noState;
}

@Override
public RescaleMappings getChannelMapping(int gateOrPartitionIndex) {
return RescaleMappings.SYMMETRIC_IDENTITY;
Expand Down
Loading