Abjindal/eager onnx operators fix (#9968)

* adding view operator changes

* adding the slice operator definition

* moving to opgen script for slice op and removing redundant steps in view op and reshape_copy

* adding for at definition

* adding for at::infer_size definition

* changing template style for reshape_copy to ensure int64_t type
This commit is contained in:
Abhishek Jindal 2021-12-13 13:23:46 -08:00 committed by GitHub
parent d0b08af37a
commit 777a80fbc1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 29 deletions

View file

@ -34,7 +34,6 @@ ops = {
'aten::_reshape_alias': SignatureOnly(),
'aten::view': SignatureOnly(),
'aten::_copy_from_and_resize' : SignatureOnly(),
'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'),
'aten::t': Transpose('self'),
'aten::mm': MatMul('self', 'mat2'),
@ -52,6 +51,7 @@ ops = {
'aten::gelu_backward' : GeluGrad('grad', 'self'),
'aten::max' : ReduceMax('self', keepdims=1),
'aten::min' : ReduceMin('self', keepdims=1),
'aten::slice.Tensor' : Slice('self', 'start', 'end', 'dim', 'step'),
'aten::ne.Scalar':MakeTorchFallback(),
'aten::ne.Scalar_out': MakeTorchFallback(),

View file

@ -5,6 +5,7 @@
#include "ort_tensor.h"
#include <c10/core/TensorImpl.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/InferSize.h>
namespace torch_ort {
namespace eager {
@ -173,6 +174,14 @@ bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>&
std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
}
bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types){
return std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
}
bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types){
return IsSupportedType(val.value(), valid_types);
}
//#pragma endregion
//#pragma region Hand-Implemented ATen Ops
@ -250,23 +259,18 @@ at::Tensor _reshape_alias(
reshape_copy(
invoker,
create_ort_value(invoker, self),
at::infer_size(
size,
self.numel())),
size),
self.options());
}
at::Tensor view(const at::Tensor& self, at::IntArrayRef size) {
ORT_LOG_FN(self, size);
auto& invoker = GetORTInvoker(self.device());
return aten_tensor_from_ort(
reshape_copy(
invoker,
create_ort_value(invoker, self),
at::infer_size(
size,
self.numel())),
size),
self.options());
}

View file

@ -34,6 +34,29 @@ OrtValue create_ort_value(
OrtValue create_ort_value(const at::Tensor& tensor);
// Create 1-dimensional ORT tensor from a given value
template <typename T>
OrtValue create_ort_value(
onnxruntime::ORTInvoker& invoker,
const T val) {
OrtValue ort_val;
CreateMLValue(
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault),
onnxruntime::DataTypeImpl::GetType<T>(),
{1,},
&ort_val);
auto* ort_tensor = ort_val.GetMutable<onnxruntime::Tensor>();
CopyVectorToTensor<int64_t>(invoker, {val}, *ort_tensor);
return ort_val;
}
template <typename T>
OrtValue create_ort_value(
onnxruntime::ORTInvoker& invoker,
c10::optional<T> val) {
return create_ort_value(invoker, val.value());
}
template<typename T>
OrtValue create_ort_value(
onnxruntime::ORTInvoker& invoker,
@ -79,5 +102,9 @@ bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid
bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types);
bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types);
bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types);
} // namespace eager
} // namespace torch_ort

View file

@ -8,26 +8,6 @@
namespace torch_ort {
namespace eager {
OrtValue reshape_copy(
onnxruntime::ORTInvoker& invoker,
const OrtValue& input,
std::vector<int64_t> shape) {
// TODO: actual reshape on buffer
const onnxruntime::Tensor& input_tensor = input.Get<onnxruntime::Tensor>();
auto new_shape = at::infer_size(shape, input_tensor.Shape().Size());
OrtValue shape_tensor;
//todo: avoid the copy on this small shape vector;
auto element_type = onnxruntime::DataTypeImpl::GetType<int64_t>();
CreateMLValue(invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault),
element_type, {(int64_t)new_shape.size(),}, &shape_tensor);
auto* ort_shape_tensor = shape_tensor.GetMutable<onnxruntime::Tensor>();
CopyVectorToTensor<int64_t>(invoker, new_shape, *ort_shape_tensor);
std::vector<OrtValue> result(1);
ORT_THROW_IF_ERROR(invoker.Invoke("Reshape", {input, shape_tensor}, result, nullptr));
return result[0];
}
void copy(onnxruntime::ORTInvoker& invoker,
const OrtValue& src, OrtValue& dst){
auto& ort_ep = invoker.GetCurrentExecutionProvider();

View file

@ -3,16 +3,32 @@
#pragma once
#include "ort_util.h"
#include <core/framework/ort_value.h>
#include <core/eager/ort_kernel_invoker.h>
namespace torch_ort {
namespace eager {
template <template<class> class V>
OrtValue reshape_copy(
onnxruntime::ORTInvoker& invoker,
const OrtValue& input,
std::vector<int64_t> shape);
V<int64_t> shape) {
// TODO: actual reshape on buffer
const onnxruntime::Tensor& input_tensor = input.Get<onnxruntime::Tensor>();
auto new_shape = at::infer_size(shape, input_tensor.Shape().Size());
OrtValue shape_tensor;
//todo: avoid the copy on this small shape vector;
auto element_type = onnxruntime::DataTypeImpl::GetType<int64_t>();
CreateMLValue(invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault),
element_type, {(int64_t)new_shape.size(),}, &shape_tensor);
auto* ort_shape_tensor = shape_tensor.GetMutable<onnxruntime::Tensor>();
CopyVectorToTensor<int64_t>(invoker, new_shape, *ort_shape_tensor);
std::vector<OrtValue> result(1);
ORT_THROW_IF_ERROR(invoker.Invoke("Reshape", {input, shape_tensor}, result, nullptr));
return result[0];
}
OrtValue add(onnxruntime::ORTInvoker& invoker,
const OrtValue& A,