diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index 882bd42cf6..e3586a0f46 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -107,6 +107,7 @@ hand_implemented = { "aten::_reshape_alias": SignatureOnly(), "aten::view": SignatureOnly(), "aten::_copy_from_and_resize": SignatureOnly(), + "aten::resize_": SignatureOnly(), "aten::as_strided": SignatureOnly(), # manually implement Slice using stride and offset. "aten::slice.Tensor": SignatureOnly(), diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index b2133eb8c9..b70e158197 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -346,6 +346,109 @@ OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at: //#pragma endregion +/* + * Resize backing store of a TensorImpl. + * + * See notes for implementation details and potential differences from canonical implementations due to constraints in + * ORT model. + * + * If new size is the same size as existing tensor: reshape the existing tensor + * If new size is larger: allocate new memory and copy over existing elements. New memory is uninitialized. + * If new size is smaller: allocate a smaller backing tensor, and copy over + * as many elements as will fit. + * + * Notes: + * There are some implementation details that might deviate from expectations: + * - As the Onnxruntime::tensor does not support resize operation, this functionality is supported on the TensorImpl + * by swapping out the backing tensor if the size changes. + * + * - In the ORT model the shape of the TensorImpl is defined by the backing onnxruntime::tensor, so it is not supported + * to have a TensorImpl with a different shape / size than the backing onnxruntime::tensor. This means when resizing + * to a smaller TensorImpl, other implementations might keep the same backing storage, ORT will re-allocate a new + * onnxruntime::tensor and copy over as many of the existing elements that fit. Functionally, you will end up with + * same output, but the underlying buffer will be re-allocated. + * + * A future change could be to allow ORTTensorImpl to have a different size / shape than the onnxrutime::tensor + * backing it, and then we could improve this behavior. + * + * The canonical CPU / CUDA implementations in PyTorch repository: + * CPU: aten/src/ATen/native/Resize.cpp + * CUDA: aten/src/ATen/native/cuda/Resize.cpp + */ +void resize_impl_ort_( + onnxruntime::ORTInvoker& invoker, + ORTTensorImpl* self, + at::IntArrayRef size) { + auto self_ort_value = self->tensor(); + + // If shape and size are the same, then nothing to do + if (self->sizes() == size) { + return; + } + + auto old_shape = onnxruntime::TensorShape(self->sizes()); + auto new_shape = onnxruntime::TensorShape(size); + + if (new_shape.Size() == old_shape.Size()) { + // Requested size is the same, only shape is different. + // Just resize existing tensor and return + + OrtValue new_ort_value = reshape_invoke( + invoker, + self_ort_value, + size, + // invoke reshape kernel inplace + true); + + // TODO(jamill): Investigate why reshape_invoke kernel does not update inplace + self->set_tensor(new_ort_value); + } else { + // Requested size is different - allocate a new onnxruntime::tensor and update ORTTensorImpl + // with new backing onnxruntime::tensor. + auto* self_ort_tensor = self_ort_value.GetMutable(); + + OrtValue new_ort_value; + onnxruntime::Tensor::InitOrtValue(self_ort_tensor->DataType(), new_shape, + invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), + new_ort_value); + + auto* new_ort_tensor = new_ort_value.GetMutable(); + + // Copy over existing elements from current tensor as appropriate + if (self_ort_tensor->SizeInBytes() == 0) { + // self is empty, nothing to copy over + } else if (new_ort_tensor->SizeInBytes() > self_ort_tensor->SizeInBytes()) { + // Copy elements from (smaller) old tensor to (larger) new self tensor + + // See function comments to see details on why we need to create temporary ORTValue here + // (Copying elements between tensors of different sizes is not supported) + OrtValue tmp; + onnxruntime::Tensor::InitOrtValue(new_ort_tensor->DataType(), old_shape, + new_ort_tensor->MutableDataRaw(), + new_ort_tensor->Location(), + tmp); + + copy(invoker, self_ort_value, tmp); + } else if (new_ort_tensor->SizeInBytes() < self_ort_tensor->SizeInBytes()) { + // Copy elements from (larger) initial self tensor to (smaller) updated self tensor + + // See function comments to see details on why we need to create temporary ORTValue here + // (Copying elements between tensors of different sizes is not supported) + OrtValue tmp; + onnxruntime::Tensor::InitOrtValue(self_ort_tensor->DataType(), new_shape, + self_ort_tensor->MutableDataRaw(), + self_ort_tensor->Location(), + tmp); + + copy(invoker, tmp, new_ort_value); + } + + self->set_tensor(new_ort_value); + } + + return; +} + //#pragma region Hand-Implemented ATen Ops namespace aten { @@ -749,6 +852,27 @@ bool equal( return *(ort_tensor->Data()) != 0; } +// aten::resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) +const at::Tensor& resize_( + const at::Tensor& self, + at::IntArrayRef size, + c10::optional optional_memory_format) { + ORT_LOG_FN(self, size, optional_memory_format); + assert_tensor_supported(self); + + // If self is already desired size, then return early + if (self.sizes() == size) { + return self; + } + + auto& invoker = GetORTInvoker(self.device()); + resize_impl_ort_( + invoker, + dynamic_cast(self.unsafeGetTensorImpl()), + size); + return self; +} + } // namespace aten //#pragma endregion diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index 53bda0752a..04e2c23b30 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -196,6 +196,64 @@ class OrtOpTests(unittest.TestCase): assert torch.allclose(cpu_result, ort_result.cpu()) assert cpu_result.dim() == ort_result.dim() + def test_resize(self): + device = self.get_device() + + sizes = [[1], [1, 1], [2, 2], [1, 4]] + + # Basic resize from empty Tensor + for size in sizes: + torch_size = torch.Size(size) + cpu_tensor = torch.tensor([]) + ort_tensor = torch.tensor([]).to(device) + + cpu_tensor.resize_(torch_size) + ort_tensor.resize_(torch_size) + + self.assertEqual(cpu_tensor.size(), ort_tensor.size()) + + # Validate cases where we resize from a non-empty tensor + # to a larger tensor + cpu_tensor = torch.tensor([1.0, 2.0]) + ort_tensor = cpu_tensor.to(device) + + cpu_tensor.resize_(torch.Size([3])) + ort_tensor.resize_(torch.Size([3])) + + self.assertEqual(cpu_tensor.size(), ort_tensor.size()) + self.assertTrue(torch.allclose(cpu_tensor[:2], ort_tensor.cpu()[:2])) + + # Validate case when calling resize with current shape & size + cpu_tensor = torch.tensor([1.0, 2.0]) + ort_tensor = cpu_tensor.to(device) + + cpu_tensor.resize_(torch.Size([2])) + ort_tensor.resize_(torch.Size([2])) + + self.assertEqual(cpu_tensor.size(), ort_tensor.size()) + self.assertTrue(torch.allclose(cpu_tensor, ort_tensor.cpu())) + + # Validate case when calling resize with different shape but same size + cpu_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + ort_tensor = cpu_tensor.to(device) + + cpu_tensor.resize_(torch.Size([1, 4])) + ort_tensor.resize_(torch.Size([1, 4])) + + self.assertEqual(cpu_tensor.size(), ort_tensor.size()) + self.assertTrue(torch.allclose(cpu_tensor, ort_tensor.cpu())) + + # Validate cases where we resize from a non-empty tensor + # to a smaller tensor + cpu_tensor = torch.tensor([1.0, 2.0]) + ort_tensor = cpu_tensor.to(device) + + cpu_tensor.resize_(torch.Size([1])) + ort_tensor.resize_(torch.Size([1])) + + self.assertEqual(cpu_tensor.size(), ort_tensor.size()) + self.assertTrue(torch.allclose(cpu_tensor, ort_tensor.cpu())) + if __name__ == "__main__": unittest.main()