From 777a80fbc15eaf1b79e06b86dcd57cdf1d156f50 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 13 Dec 2021 13:23:46 -0800 Subject: [PATCH] Abjindal/eager onnx operators fix (#9968) * adding view operator changes * adding the slice operator definition * moving to opgen script for slice op and removing redundant steps in view op and reshape_copy * adding for at definition * adding for at::infer_size definition * changing template style for reshape_copy to ensure int64_t type --- .../orttraining/eager/opgen/opgen/atenops.py | 2 +- orttraining/orttraining/eager/ort_aten.cpp | 18 ++++++++----- orttraining/orttraining/eager/ort_aten.h | 27 +++++++++++++++++++ orttraining/orttraining/eager/ort_ops.cpp | 20 -------------- orttraining/orttraining/eager/ort_ops.h | 18 ++++++++++++- 5 files changed, 56 insertions(+), 29 deletions(-) diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index 5da6ac607f..f8cda84df4 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -34,7 +34,6 @@ ops = { 'aten::_reshape_alias': SignatureOnly(), 'aten::view': SignatureOnly(), 'aten::_copy_from_and_resize' : SignatureOnly(), - 'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'), 'aten::t': Transpose('self'), 'aten::mm': MatMul('self', 'mat2'), @@ -52,6 +51,7 @@ ops = { 'aten::gelu_backward' : GeluGrad('grad', 'self'), 'aten::max' : ReduceMax('self', keepdims=1), 'aten::min' : ReduceMin('self', keepdims=1), + 'aten::slice.Tensor' : Slice('self', 'start', 'end', 'dim', 'step'), 'aten::ne.Scalar':MakeTorchFallback(), 'aten::ne.Scalar_out': MakeTorchFallback(), diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index 00969e1339..a698fccb69 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -5,6 +5,7 @@ #include "ort_tensor.h" #include #include +#include namespace torch_ort { namespace eager { @@ -173,6 +174,14 @@ bool IsSupportedType(at::IntArrayRef arrary, const std::vector& std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end(); } +bool IsSupportedType(int64_t val, const std::vector& valid_types){ + return std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end(); +} + +bool IsSupportedType(c10::optional val, const std::vector& valid_types){ + return IsSupportedType(val.value(), valid_types); +} + //#pragma endregion //#pragma region Hand-Implemented ATen Ops @@ -250,23 +259,18 @@ at::Tensor _reshape_alias( reshape_copy( invoker, create_ort_value(invoker, self), - at::infer_size( - size, - self.numel())), + size), self.options()); } at::Tensor view(const at::Tensor& self, at::IntArrayRef size) { ORT_LOG_FN(self, size); - auto& invoker = GetORTInvoker(self.device()); return aten_tensor_from_ort( reshape_copy( invoker, create_ort_value(invoker, self), - at::infer_size( - size, - self.numel())), + size), self.options()); } diff --git a/orttraining/orttraining/eager/ort_aten.h b/orttraining/orttraining/eager/ort_aten.h index b91c9ab381..a63f914b30 100644 --- a/orttraining/orttraining/eager/ort_aten.h +++ b/orttraining/orttraining/eager/ort_aten.h @@ -34,6 +34,29 @@ OrtValue create_ort_value( OrtValue create_ort_value(const at::Tensor& tensor); +// Create 1-dimensional ORT tensor from a given value +template +OrtValue create_ort_value( + onnxruntime::ORTInvoker& invoker, + const T val) { + OrtValue ort_val; + CreateMLValue( + invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), + onnxruntime::DataTypeImpl::GetType(), + {1,}, + &ort_val); + auto* ort_tensor = ort_val.GetMutable(); + CopyVectorToTensor(invoker, {val}, *ort_tensor); + return ort_val; +} + +template +OrtValue create_ort_value( + onnxruntime::ORTInvoker& invoker, + c10::optional val) { + return create_ort_value(invoker, val.value()); +} + template OrtValue create_ort_value( onnxruntime::ORTInvoker& invoker, @@ -79,5 +102,9 @@ bool IsSupportedType(at::Tensor tensor, const std::vector& valid bool IsSupportedType(at::IntArrayRef arrary, const std::vector& valid_types); +bool IsSupportedType(int64_t val, const std::vector& valid_types); + +bool IsSupportedType(c10::optional val, const std::vector& valid_types); + } // namespace eager } // namespace torch_ort \ No newline at end of file diff --git a/orttraining/orttraining/eager/ort_ops.cpp b/orttraining/orttraining/eager/ort_ops.cpp index 5d83d728c9..2cc98b356b 100644 --- a/orttraining/orttraining/eager/ort_ops.cpp +++ b/orttraining/orttraining/eager/ort_ops.cpp @@ -8,26 +8,6 @@ namespace torch_ort { namespace eager { -OrtValue reshape_copy( - onnxruntime::ORTInvoker& invoker, - const OrtValue& input, - std::vector shape) { - - // TODO: actual reshape on buffer - const onnxruntime::Tensor& input_tensor = input.Get(); - auto new_shape = at::infer_size(shape, input_tensor.Shape().Size()); - OrtValue shape_tensor; - //todo: avoid the copy on this small shape vector; - auto element_type = onnxruntime::DataTypeImpl::GetType(); - CreateMLValue(invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), - element_type, {(int64_t)new_shape.size(),}, &shape_tensor); - auto* ort_shape_tensor = shape_tensor.GetMutable(); - CopyVectorToTensor(invoker, new_shape, *ort_shape_tensor); - std::vector result(1); - ORT_THROW_IF_ERROR(invoker.Invoke("Reshape", {input, shape_tensor}, result, nullptr)); - return result[0]; -} - void copy(onnxruntime::ORTInvoker& invoker, const OrtValue& src, OrtValue& dst){ auto& ort_ep = invoker.GetCurrentExecutionProvider(); diff --git a/orttraining/orttraining/eager/ort_ops.h b/orttraining/orttraining/eager/ort_ops.h index cb78d011b0..13f7c30be4 100644 --- a/orttraining/orttraining/eager/ort_ops.h +++ b/orttraining/orttraining/eager/ort_ops.h @@ -3,16 +3,32 @@ #pragma once +#include "ort_util.h" #include #include namespace torch_ort { namespace eager { +template class V> OrtValue reshape_copy( onnxruntime::ORTInvoker& invoker, const OrtValue& input, - std::vector shape); + V shape) { + // TODO: actual reshape on buffer + const onnxruntime::Tensor& input_tensor = input.Get(); + auto new_shape = at::infer_size(shape, input_tensor.Shape().Size()); + OrtValue shape_tensor; + //todo: avoid the copy on this small shape vector; + auto element_type = onnxruntime::DataTypeImpl::GetType(); + CreateMLValue(invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), + element_type, {(int64_t)new_shape.size(),}, &shape_tensor); + auto* ort_shape_tensor = shape_tensor.GetMutable(); + CopyVectorToTensor(invoker, new_shape, *ort_shape_tensor); + std::vector result(1); + ORT_THROW_IF_ERROR(invoker.Invoke("Reshape", {input, shape_tensor}, result, nullptr)); + return result[0]; +} OrtValue add(onnxruntime::ORTInvoker& invoker, const OrtValue& A,