Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,12 @@ class ShapePropagator : public PropertyPropBase {
"aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"},
[](Node* node) -> type_vec_t {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
auto dtype = type->scalarType();
at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
if (maybe_dtype_option && maybe_dtype_option->isInt()) {
dtype = maybe_dtype_option->toScalarType();
}

auto device = getDeviceFromValue(node->namedInput(attr::device));
if (type->dim()) {
auto scalarType =
Expand All @@ -1446,7 +1452,7 @@ class ShapePropagator : public PropertyPropBase {
scalarType = type->scalarType();
}
return {TensorType::create(
scalarType,
dtype,
device,
type->dim(),
/*requires_grad=*/c10::nullopt)
Expand Down
12 changes: 12 additions & 0 deletions pytorch_blade/tests/torchscript/since_1_10.graph
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,15 @@ graph(%p1 : Float(*, *, *, device=cpu)):
// CHECK: Float(*, *, *, device=cuda) = aten::to
%cuda_zeros : Tensor = aten::to(%new_zeros, %cuda, %none, %false, %false)
return (%cuda_zeros)

// aten::to.prim_Device with dtype
// CHECK-LABEL: graph
graph(%p1 : Bool(device=cuda:0)):
  %1 : Device = prim::Constant[value="cuda:1"]()
  // Integer dtype argument passed to aten::to; the CHECK line below expects it to yield Half.
  %2 : int = prim::Constant[value=5]()
  %3 : bool = prim::Constant[value=0]()
  // CHECK: Half(device=cuda:1) = aten::to(%p1, %1, %2, %3, %3)
  %5 : Tensor = aten::to(%p1, %1, %2, %3, %3)
  // Return the converted tensor so the op under test is the graph output.
  return (%5)


18 changes: 18 additions & 0 deletions tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ LogicalResult PadOpConvert::matchAndRewrite(mhlo::PadOp op,
} // namespace

namespace {

struct SliceOpConvert : public OpRewritePattern<mhlo::SliceOp> {
explicit SliceOpConvert(MLIRContext* context) : OpRewritePattern(context) {}
LogicalResult matchAndRewrite(mhlo::SliceOp op,
Expand All @@ -129,6 +130,22 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op,
auto operand = op.getOperand();
rewriter.replaceOpWithNewOp<mhlo::RealDynamicSliceOp>(
op, op.getType(), operand, startIndices, limitIndices, strides);

return success();
}

// Rewrite pattern anchored on `arith::ConstantOp`. The rewrite itself is
// implemented in the out-of-line `matchAndRewrite` definition.
struct ArithConstOpConvert : public OpRewritePattern<arith::ConstantOp> {
  explicit ArithConstOpConvert(MLIRContext* context)
      : OpRewritePattern<arith::ConstantOp>(context) {}

  LogicalResult matchAndRewrite(arith::ConstantOp op,
                                PatternRewriter& rewriter) const override;
};

// Replaces an `arith::ConstantOp` with an `mhlo::ConstantOp` built from the
// same value. Returns failure() (leaving the op untouched) unless the result
// type is a ranked tensor with rank >= 1.
LogicalResult ArithConstOpConvert::matchAndRewrite(
    arith::ConstantOp op, PatternRewriter& rewriter) const {
  auto rankedTy = op.getType().dyn_cast<RankedTensorType>();
  // Only ranked tensor constants with at least one dimension are converted.
  const bool convertible = rankedTy && rankedTy.getRank() >= 1;
  if (!convertible) {
    return failure();
  }

  rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, op.getValue());
  return success();
}
} // namespace
Expand All @@ -139,6 +156,7 @@ struct MhloDecompositionRewriterPass
func::FuncOp func = getOperation();
MLIRContext* ctx = func.getContext();
RewritePatternSet patterns(ctx);
patterns.insert<ArithConstOpConvert>(ctx);
patterns.insert<BatchNormInferenceOpConvert>(ctx);
patterns.insert<PadOpConvert>(ctx);
patterns.insert<SliceOpConvert>(ctx);
Expand Down