From fcc167dd47371b0cfafdc4dcde6eb13f652d8f67 Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Thu, 18 Nov 2021 19:26:49 -0800 Subject: [PATCH] fix reshape implementation in eager mode (#9741) --- .../orttraining/eager/opgen/opgen/atenops.py | 2 +- orttraining/orttraining/eager/ort_aten.cpp | 15 +++++++++++---- orttraining/orttraining/eager/test/ort_tensor.py | 7 +++++++ 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index c07c7d281d..34a4b80845 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -31,7 +31,7 @@ ops = { 'aten::empty_strided': SignatureOnly(), 'aten::zero_': SignatureOnly(), 'aten::copy_': SignatureOnly(), - 'aten::reshape': SignatureOnly(), + 'aten::_reshape_alias': SignatureOnly(), 'aten::view': SignatureOnly(), 'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'), diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index 79f33c41a2..90a4d96a10 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -67,6 +67,8 @@ onnxruntime::MLDataType ort_scalar_type_from_aten( return onnxruntime::DataTypeImpl::GetType(); case at::kLong: return onnxruntime::DataTypeImpl::GetType(); + case at::kBool: + return onnxruntime::DataTypeImpl::GetType(); default: ORT_THROW("Unsupport aten scalar type: ", dtype); } @@ -222,15 +224,20 @@ at::Tensor empty_strided( .dtype(dtype)); } -at::Tensor reshape(at::Tensor const& self, at::IntArrayRef shape) { - ORT_LOG_FN(self, shape); - +at::Tensor _reshape_alias( + const at::Tensor& self, + at::IntArrayRef size, + at::IntArrayRef stride){ + ORT_LOG_FN(self, size, stride); + // TODO: support stride auto& invoker = GetORTInvoker(self.device()); return aten_tensor_from_ort( reshape_copy( invoker, create_ort_value(invoker, self), - shape.vec()), + at::infer_size( + size, + self.numel())), self.options()); } diff --git a/orttraining/orttraining/eager/test/ort_tensor.py b/orttraining/orttraining/eager/test/ort_tensor.py index c0a1b8eb5e..772c26287a 100644 --- a/orttraining/orttraining/eager/test/ort_tensor.py +++ b/orttraining/orttraining/eager/test/ort_tensor.py @@ -19,6 +19,13 @@ class OrtTensorTests(unittest.TestCase): ort_ones = cpu_ones.to('ort') assert ort_ones.is_ort assert torch.allclose(cpu_ones, ort_ones.cpu()) + + def test_reshape(self): + cpu_ones = torch.ones(10, 10) + ort_ones = cpu_ones.to('ort') + y = ort_ones.reshape(-1) + assert len(y.size()) == 1 + assert y.size()[0] == 100 if __name__ == '__main__': unittest.main() \ No newline at end of file