Add operand type invariant to torch.overwrite.tensor.contents (#606)
This commit adds the invariant to the op `torch.overwrite.tensor.contents` that
both of its operands have the same shape and dtype. In order to
maintain the invariant, special handling of this op is added to the
`RefineTypes` pass.
ramiro050 authored Feb 22, 2022
1 parent 5dbace2 commit ba29d4f
Showing 8 changed files with 125 additions and 22 deletions.
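As a quick illustration of the new invariant (a hypothetical sketch; the value names are made up), the overwritten !torch.tensor operand must have exactly the non-value-semantic counterpart of the overwriter's !torch.vtensor type, i.e. the same shape and dtype:

%dest = torch.copy.to_tensor %t : !torch.tensor<[2,3],f32>
torch.overwrite.tensor.contents %new overwrites %dest : !torch.vtensor<[2,3],f32>, !torch.tensor<[2,3],f32>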
14 changes: 9 additions & 5 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -884,8 +884,10 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
let verifier = "return ::verify(*this);";
}

-def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
-  AllowsTypeRefinement
+def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
+  TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type",
+                 "value", "overwritten",
+                 "$_self.cast<ValueTensorType>().getWithoutValueSemantics()">
]> {
let summary = "Overwrite the contents of tensor with values from another.";
let description = [{
@@ -895,10 +897,12 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
Immediately after this op has completed, indexing `overwritten` will result
in identical values as indexing into `value`. Of course, later ops
might mutate `overwritten`, so this relationship need not hold for the
-entire program.
+entire program. This op only updates the tensor data (not metadata).
+In other words, it cannot change the (dynamic) shape of the overwritten tensor.

-This op has undefined behavior if the two tensors have different
-shapes or dtypes.
+This op does not have the AllowsTypeRefinement trait because the types of the
+two operands are coupled. Only places that know how to simultaneously update
+both types should be changing the type of this op.
}];
let arguments = (ins
Torch_ValueTensorType:$value,
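Concretely, the new TypesMatchWith constraint means the verifier only accepts operand pairs whose types correspond exactly; a sketch with illustrative names:

// Accepted: !torch.tensor<[5],f32> is the !torch.tensor counterpart of !torch.vtensor<[5],f32>.
torch.overwrite.tensor.contents %v overwrites %t : !torch.vtensor<[5],f32>, !torch.tensor<[5],f32>
// Rejected: the shapes differ, so verification fails (see the new test in invalid.mlir below).
torch.overwrite.tensor.contents %v overwrites %t : !torch.vtensor<[?],f32>, !torch.tensor<[5],f32>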
9 changes: 5 additions & 4 deletions lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -47,7 +47,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
if (user->getBlock() != copy->getBlock())
return failure();
// We can only analyze these ops or view-like ops.
-if (isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
+if (isa<CopyToValueTensorOp, OverwriteTensorContentsOp>(user))
foundNonViewLikeOpUser = true;
else if (!isViewLikeOp(user))
return failure();
@@ -71,9 +71,10 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
for (Operation *user : users) {
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
-} else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
-  currentlyHeldValueTensor = overwriteTensor.value();
-  rewriter.eraseOp(overwriteTensor);
+} else if (auto overwriteTensorContents =
+               dyn_cast<OverwriteTensorContentsOp>(user)) {
+  currentlyHeldValueTensor = overwriteTensorContents.value();
+  rewriter.eraseOp(overwriteTensorContents);
} else if (isViewLikeOp(user)) {
// This case currently only handles view-like ops that have one tensor
// input and one tensor output.
23 changes: 21 additions & 2 deletions lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -18,6 +18,24 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

// Create an overwrite in a manner that preserves the
// `OverwriteTensorContentsOp` invariant that both arguments
// must have the same shape and dtype.
static void createOverwriteTensorContents(PatternRewriter &rewriter,
                                          Location loc, Value overwriterTensor,
                                          Value overwrittenTensor) {
  Type overwriterTensorType = overwriterTensor.getType();
  Type overwrittenTensorType = overwrittenTensor.getType()
                                   .dyn_cast<NonValueTensorType>()
                                   .getWithValueSemantics();
  if (overwriterTensorType != overwrittenTensorType) {
    overwriterTensor = rewriter.create<TensorStaticInfoCastOp>(
        loc, overwrittenTensorType, overwriterTensor);
  }
  rewriter.create<OverwriteTensorContentsOp>(loc, overwriterTensor,
                                             overwrittenTensor);
}

namespace {
// Convert value semantic ops operating on mutable arrays to instead operate on
// immutable tensors.
@@ -143,7 +161,7 @@ class ReduceNonValueSemanticOps : public RewritePattern {

auto tensor =
rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
-rewriter.create<OverwriteTensorOp>(loc, tensor, op->getOperand(0));
+createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
rewriter.replaceOp(op, op->getOperand(0));
return success();
}
@@ -180,7 +198,8 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern {
Operation *newOp = rewriter.createOperation(state);
auto tensor =
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
-rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
+createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
+                              op->getOperand(0));
rewriter.replaceOp(op, op->getOperand(0));

return success();
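Based on the CHECK lines in reduce-op-variants.mlir below, the reduced form of an in-place op now ends with an overwrite created by the helper; a sketch for torch.aten.bernoulli_.float (no cast is needed here because the operand types already correspond):

%v = torch.copy.to_vtensor %t : !torch.vtensor
%vret = torch.pseudo.aten.bernoulli.float %v, %p, %none : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
%ret = torch.copy.to_tensor %vret : !torch.tensor
%new = torch.copy.to_vtensor %ret : !torch.vtensor
torch.overwrite.tensor.contents %new overwrites %t : !torch.vtensor, !torch.tensor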
37 changes: 37 additions & 0 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -2105,7 +2105,44 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
if (isSafeToRefineOperandInPlace(use, refinedType)) {
use->set(newTypedValue);
continue;
} else if (auto overwriteTensorContents =
               dyn_cast<OverwriteTensorContentsOp>(use->getOwner())) {
  // `OverwriteTensorContentsOp` has special handling here because
  // it requires that both of its operands always have the same
  // shape and dtype.
  //
  // WARNING: In order to simplify the implementation, the type
  // used for both operands is the type of the overwritten tensor.
  // A better way of doing this would be to join the two operand
  // types to create the most specific type possible and use that
  // for both arguments, allowing static sizes to always propagate.
  const unsigned overwriterOperandIndex = 0;
  const unsigned overwrittenOperandIndex = 1;
  unsigned operandNumber = use->getOperandNumber();
  if (operandNumber != overwrittenOperandIndex)
    continue;

  Location loc = overwriteTensorContents.getLoc();
  Value overwriterTensor = overwriteTensorContents.value();
  Type overwriterTensorType = overwriterTensor.getType();
  Type overwrittenTensorType = newTypedValue.getType()
                                   .dyn_cast<NonValueTensorType>()
                                   .getWithValueSemantics();
  if (overwriterTensorType == overwrittenTensorType)
    continue;

  {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(overwriteTensorContents);
    Value castedOverwriterTensor = b.create<TensorStaticInfoCastOp>(
        loc, overwrittenTensorType, overwriterTensor);
    overwriteTensorContents.setOperand(overwriterOperandIndex,
                                       castedOverwriterTensor);
  }
  continue;
}

// If needed, create a value of the original type to appease users
// that cannot accept the new type.
if (!oldTypedValue) {
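In IR terms (per the refine-types.mlir tests below), when the analysis refines the overwritten operand's type, the pass now routes the overwriter through a cast instead of retyping the operand in place; a sketch with illustrative names:

%cast = torch.tensor_static_info_cast %overwriter : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
torch.overwrite.tensor.contents %cast overwrites %dest : !torch.vtensor<[2],f32>, !torch.tensor<[2],f32>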
10 changes: 10 additions & 0 deletions test/Dialect/Torch/invalid.mlir
@@ -169,3 +169,13 @@ builtin.func @torch.prim.ListConstruct() {
torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<!torch.tensor>
return
}

// -----

builtin.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32>
// expected-error@+1 {{'torch.overwrite.tensor.contents' op failed to verify that overwritten tensor type is corresponding !torch.tensor of value tensor type}}
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor<[?],f32>, !torch.tensor<[1],f32>
%1 = torch.copy.to_vtensor %0 : !torch.vtensor<[1],f32>
return %1 : !torch.vtensor<[1],f32>
}
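For contrast, a hypothetical variant of this test that satisfies the verifier would first cast the overwriter so the operand types correspond:

builtin.func @torch.overwrite.tensor.contents$valid(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
  %0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32>
  %cast = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[?],f32> to !torch.vtensor<[1],f32>
  torch.overwrite.tensor.contents %cast overwrites %0 : !torch.vtensor<[1],f32>, !torch.tensor<[1],f32>
  %1 = torch.copy.to_vtensor %0 : !torch.vtensor<[1],f32>
  return %1 : !torch.vtensor<[1],f32>
}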
14 changes: 7 additions & 7 deletions test/Dialect/Torch/maximize-value-semantics.mlir
@@ -17,7 +17,7 @@ func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.
func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
-torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
+torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
%equal_to_arg1 = torch.copy.to_vtensor %0 : !torch.vtensor
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
}
@@ -34,12 +34,12 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor

// Overwrite with %arg1
-torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
+torch.overwrite.tensor.contents %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
%equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
%equal_to_arg1_again = torch.copy.to_vtensor %tensor : !torch.vtensor

// Overwrite with %arg2
-torch.overwrite.tensor %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
+torch.overwrite.tensor.contents %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
%equal_to_arg2 = torch.copy.to_vtensor %tensor : !torch.vtensor

return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
@@ -52,18 +52,18 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor
-torch.overwrite.tensor %overwriter overwrites %t : !torch.vtensor, !torch.tensor
+torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
%result = torch.aten.permute %view, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
%value_result = torch.copy.to_vtensor %result : !torch.vtensor
return %value_result : !torch.vtensor
}

// CHECK-LABEL: func @unmodeled_mutation(
-// CHECK: torch.overwrite.tensor
+// CHECK: torch.overwrite.tensor.contents
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
-torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
+torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
"some.op"(%0) : (!torch.tensor) -> ()
%result = torch.copy.to_vtensor %0 : !torch.vtensor
return %result : !torch.vtensor
@@ -76,7 +76,7 @@ func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
torch.prim.If %cond -> () {
-torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
+torch.overwrite.tensor.contents %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
torch.prim.If.yield
} else {
torch.prim.If.yield
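Given the rewriting logic in MaximizeValueSemantics.cpp above, the expected output for @one_mutation_in_a_block (a sketch of the behavior, not the literal FileCheck output) erases the copy and overwrite ops and returns the held value tensors directly:

func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
  return %arg0, %arg1 : !torch.vtensor, !torch.vtensor
}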
8 changes: 4 additions & 4 deletions test/Dialect/Torch/reduce-op-variants.mlir
@@ -95,7 +95,7 @@ func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
// being applied in sequence.
// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
-// CHECK: torch.overwrite.tensor %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
+// CHECK: torch.overwrite.tensor.contents %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
%c1 = torch.constant.int 1
@@ -138,7 +138,7 @@ func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f
// CHECK-SAME: !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
-// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
+// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor {
%ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
@@ -153,7 +153,7 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
-// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
+// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
%generator = torch.constant.none
@@ -169,7 +169,7 @@ func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
-// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
+// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
%value = torch.constant.int 1
32 changes: 32 additions & 0 deletions test/Dialect/Torch/refine-types.mlir
@@ -1157,3 +1157,35 @@ func @torch.aten.BinaryBroadcasting(%arg0: !torch.vtensor<[5,4,3,3,1],f32>, %arg
%0 = torch.aten.add.Tensor %arg0, %arg1, %arg2: !torch.vtensor<[5,4,3,3,1],f32>, !torch.vtensor<[?,3,1,2],f32>, !torch.int -> !torch.tensor
return %0 : !torch.tensor
}

// -----
// CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
// CHECK: torch.overwrite.tensor.contents %[[CAST]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<[2],f32>, !torch.tensor<[2],f32>
func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
torch.overwrite.tensor.contents %dynamic_no_type overwrites %static_copy : !torch.vtensor, !torch.tensor
%static_value_copy = torch.copy.to_vtensor %static_copy : !torch.vtensor
%result = torch.tensor_static_info_cast %static_value_copy : !torch.vtensor to !torch.vtensor<[2],f32>
return %result : !torch.vtensor<[2],f32>
}

// -----
// CHECK-LABEL: func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[STATIC_COPY:.*]] : !torch.vtensor<[2],f32> to !torch.vtensor<[?],f32>
// CHECK: torch.overwrite.tensor.contents %[[CAST]] overwrites %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32>, !torch.tensor<[?],f32>
func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
%dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor
torch.overwrite.tensor.contents %static_no_type overwrites %dynamic_copy : !torch.vtensor, !torch.tensor
%dynamic_value_copy = torch.copy.to_vtensor %dynamic_copy : !torch.vtensor
%result = torch.tensor_static_info_cast %dynamic_value_copy : !torch.vtensor to !torch.vtensor<[?],f32>
return %result : !torch.vtensor<[?],f32>
}
