Eager mode: implement resize_ operation (#12004)

Add support for PyTorch `resize_` operation. The PyTorch API method is documented
here:

https://pytorch.org/docs/stable/generated/torch.Tensor.resize_.html

Implementation notes:

There are some implementation details that might deviate from
expectations:

  - As onnxruntime::tensor does not support a resize operation, this
    functionality is implemented on the TensorImpl by swapping out the
    backing tensor if the size changes.

  - In the ORT model the shape of the TensorImpl is defined by the
    backing onnxruntime::tensor, so a TensorImpl cannot have a
    different shape / size than its backing onnxruntime::tensor. This
    means that when resizing to a smaller TensorImpl, where other
    implementations might keep the same backing storage, ORT will
    re-allocate a new onnxruntime::tensor and copy over as many of the
    existing elements as fit. Functionally, you will end up with the
    same output, but the underlying buffer will be re-allocated.

    A future change could be to allow ORTTensorImpl to have a different
    size / shape than the onnxruntime::tensor backing it, and then we
    could improve this behavior.

 The canonical CPU / CUDA implementations in PyTorch repository:
     CPU: aten/src/ATen/native/Resize.cpp
     CUDA: aten/src/ATen/native/cuda/Resize.cpp
This commit is contained in:
Jameson Miller 2022-06-30 22:14:37 -04:00 committed by GitHub
parent b858c2f725
commit 3e6b8d159a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 183 additions and 0 deletions

View file

@ -107,6 +107,7 @@ hand_implemented = {
"aten::_reshape_alias": SignatureOnly(),
"aten::view": SignatureOnly(),
"aten::_copy_from_and_resize": SignatureOnly(),
"aten::resize_": SignatureOnly(),
"aten::as_strided": SignatureOnly(),
# manually implement Slice using stride and offset.
"aten::slice.Tensor": SignatureOnly(),

View file

@ -346,6 +346,109 @@ OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at:
//#pragma endregion
/*
 * Resize the tensor backing an ORTTensorImpl to `size`.
 *
 * Cases handled:
 *   - `size` equals the current sizes: nothing is done.
 *   - Same element count, different shape: the current value is run through reshape_invoke and the
 *     returned value is stored back on `self`.
 *   - Larger element count: a new tensor is allocated and the existing elements are copied into it.
 *     The newly added memory is uninitialized.
 *   - Smaller element count: a new, smaller tensor is allocated and as many existing elements as fit
 *     are copied into it.
 *
 * Differences from the canonical implementations, caused by constraints in the ORT model:
 *   - onnxruntime::Tensor has no resize operation, so resizing is done at the TensorImpl level by
 *     swapping out the backing tensor whenever the element count changes.
 *   - The shape of the TensorImpl is defined by the backing onnxruntime::Tensor, so the two cannot
 *     differ in shape / size. Where other implementations might keep the same storage when shrinking,
 *     this one re-allocates and copies. The observable contents are the same; only the underlying
 *     buffer differs.
 *
 *   A future change could allow ORTTensorImpl to have a different size / shape than the
 *   onnxruntime::Tensor backing it, which would make it possible to improve this behavior.
 *
 * Canonical CPU / CUDA implementations in the PyTorch repository:
 *   CPU:  aten/src/ATen/native/Resize.cpp
 *   CUDA: aten/src/ATen/native/cuda/Resize.cpp
 */
void resize_impl_ort_(
    onnxruntime::ORTInvoker& invoker,
    ORTTensorImpl* self,
    at::IntArrayRef size) {
  // Already the requested sizes: nothing to do.
  if (self->sizes() == size) {
    return;
  }

  auto current_value = self->tensor();
  const auto current_shape = onnxruntime::TensorShape(self->sizes());
  const auto target_shape = onnxruntime::TensorShape(size);

  if (target_shape.Size() == current_shape.Size()) {
    // Element count is unchanged and only the shape differs, so a reshape is enough.
    // The final `true` asks the reshape kernel to run in place.
    OrtValue reshaped_value = reshape_invoke(invoker, current_value, size, true);
    // TODO(jamill): Investigate why reshape_invoke kernel does not update inplace
    self->set_tensor(reshaped_value);
    return;
  }

  // Element count differs: allocate a replacement tensor with the requested shape. It becomes the
  // new backing tensor of `self` once the overlapping elements have been carried over.
  auto* current_tensor = current_value.GetMutable<onnxruntime::Tensor>();
  OrtValue replacement_value;
  onnxruntime::Tensor::InitOrtValue(
      current_tensor->DataType(),
      target_shape,
      invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault),
      replacement_value);
  auto* replacement_tensor = replacement_value.GetMutable<onnxruntime::Tensor>();

  // An empty source tensor has nothing to carry over.
  const auto current_bytes = current_tensor->SizeInBytes();
  if (current_bytes != 0) {
    // Copying between tensors of different sizes is not supported, so the larger of the two buffers
    // is wrapped in a temporary OrtValue that carries the smaller shape, and the copy is done
    // between two equally shaped values.
    const auto replacement_bytes = replacement_tensor->SizeInBytes();
    if (replacement_bytes > current_bytes) {
      // Growing: view the new buffer with the old shape and copy the old tensor into that view.
      OrtValue old_shaped_view;
      onnxruntime::Tensor::InitOrtValue(
          replacement_tensor->DataType(),
          current_shape,
          replacement_tensor->MutableDataRaw(),
          replacement_tensor->Location(),
          old_shaped_view);
      copy(invoker, current_value, old_shaped_view);
    } else if (replacement_bytes < current_bytes) {
      // Shrinking: view the old buffer with the new shape and copy that view into the new tensor.
      OrtValue new_shaped_view;
      onnxruntime::Tensor::InitOrtValue(
          current_tensor->DataType(),
          target_shape,
          current_tensor->MutableDataRaw(),
          current_tensor->Location(),
          new_shaped_view);
      copy(invoker, new_shaped_view, replacement_value);
    }
  }

  self->set_tensor(replacement_value);
}
//#pragma region Hand-Implemented ATen Ops
namespace aten {
@ -749,6 +852,27 @@ bool equal(
return *(ort_tensor->Data<int>()) != 0;
}
// aten::resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
//
// Resizes `self` to `size` by delegating to resize_impl_ort_ and returns `self`.
// When `self` already has the requested sizes, it is returned unchanged.
//
// NOTE(review): `optional_memory_format` is only passed to the logging macro here; it is not
// forwarded to resize_impl_ort_. Confirm that ignoring the requested memory format is intended.
const at::Tensor& resize_(
    const at::Tensor& self,
    at::IntArrayRef size,
    c10::optional<at::MemoryFormat> optional_memory_format) {
  ORT_LOG_FN(self, size, optional_memory_format);
  assert_tensor_supported(self);

  // Only touch the backing tensor when the requested sizes differ from the current ones.
  if (self.sizes() != size) {
    auto& ort_invoker = GetORTInvoker(self.device());
    // NOTE(review): the cast result is not null-checked; presumably assert_tensor_supported
    // guarantees an ORTTensorImpl -- confirm.
    auto* ort_impl = dynamic_cast<ORTTensorImpl*>(self.unsafeGetTensorImpl());
    resize_impl_ort_(ort_invoker, ort_impl, size);
  }

  return self;
}
} // namespace aten
//#pragma endregion

View file

@ -196,6 +196,64 @@ class OrtOpTests(unittest.TestCase):
assert torch.allclose(cpu_result, ort_result.cpu())
assert cpu_result.dim() == ort_result.dim()
def test_resize(self):
device = self.get_device()
sizes = [[1], [1, 1], [2, 2], [1, 4]]
# Basic resize from empty Tensor
for size in sizes:
torch_size = torch.Size(size)
cpu_tensor = torch.tensor([])
ort_tensor = torch.tensor([]).to(device)
cpu_tensor.resize_(torch_size)
ort_tensor.resize_(torch_size)
self.assertEqual(cpu_tensor.size(), ort_tensor.size())
# Validate cases where we resize from a non-empty tensor
# to a larger tensor
cpu_tensor = torch.tensor([1.0, 2.0])
ort_tensor = cpu_tensor.to(device)
cpu_tensor.resize_(torch.Size([3]))
ort_tensor.resize_(torch.Size([3]))
self.assertEqual(cpu_tensor.size(), ort_tensor.size())
self.assertTrue(torch.allclose(cpu_tensor[:2], ort_tensor.cpu()[:2]))
# Validate case when calling resize with current shape & size
cpu_tensor = torch.tensor([1.0, 2.0])
ort_tensor = cpu_tensor.to(device)
cpu_tensor.resize_(torch.Size([2]))
ort_tensor.resize_(torch.Size([2]))
self.assertEqual(cpu_tensor.size(), ort_tensor.size())
self.assertTrue(torch.allclose(cpu_tensor, ort_tensor.cpu()))
# Validate case when calling resize with different shape but same size
cpu_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
ort_tensor = cpu_tensor.to(device)
cpu_tensor.resize_(torch.Size([1, 4]))
ort_tensor.resize_(torch.Size([1, 4]))
self.assertEqual(cpu_tensor.size(), ort_tensor.size())
self.assertTrue(torch.allclose(cpu_tensor, ort_tensor.cpu()))
# Validate cases where we resize from a non-empty tensor
# to a smaller tensor
cpu_tensor = torch.tensor([1.0, 2.0])
ort_tensor = cpu_tensor.to(device)
cpu_tensor.resize_(torch.Size([1]))
ort_tensor.resize_(torch.Size([1]))
self.assertEqual(cpu_tensor.size(), ort_tensor.size())
self.assertTrue(torch.allclose(cpu_tensor, ort_tensor.cpu()))
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()