mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Abjindal/eager onnx operators fix (#9968)
* adding view operator changes * adding the slice operator definition * moving to opgen script for slice op and removing redundant steps in view op and reshape_copy * adding for at definition * adding for at::infer_size definition * changing template style for reshape_copy to ensure int64_t type
This commit is contained in:
parent
d0b08af37a
commit
777a80fbc1
5 changed files with 56 additions and 29 deletions
|
|
@ -34,7 +34,6 @@ ops = {
|
|||
'aten::_reshape_alias': SignatureOnly(),
|
||||
'aten::view': SignatureOnly(),
|
||||
'aten::_copy_from_and_resize' : SignatureOnly(),
|
||||
|
||||
'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'),
|
||||
'aten::t': Transpose('self'),
|
||||
'aten::mm': MatMul('self', 'mat2'),
|
||||
|
|
@ -52,6 +51,7 @@ ops = {
|
|||
'aten::gelu_backward' : GeluGrad('grad', 'self'),
|
||||
'aten::max' : ReduceMax('self', keepdims=1),
|
||||
'aten::min' : ReduceMin('self', keepdims=1),
|
||||
'aten::slice.Tensor' : Slice('self', 'start', 'end', 'dim', 'step'),
|
||||
|
||||
'aten::ne.Scalar':MakeTorchFallback(),
|
||||
'aten::ne.Scalar_out': MakeTorchFallback(),
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include "ort_tensor.h"
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <ATen/native/CPUFallback.h>
|
||||
#include <ATen/InferSize.h>
|
||||
|
||||
namespace torch_ort {
|
||||
namespace eager {
|
||||
|
|
@ -173,6 +174,14 @@ bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>&
|
|||
std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
|
||||
}
|
||||
|
||||
bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types){
|
||||
return std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
|
||||
}
|
||||
|
||||
// Optional int64 overload of the type check.
//
// Returns false for an empty optional, otherwise defers to the int64_t
// overload (which only checks that at::kLong is in `valid_types`).
//
// A type check should never throw: calling `val.value()` on an empty optional
// would raise before the caller gets a chance to react. The optional overload
// of create_ort_value unwraps unconditionally, so an absent value cannot be
// turned into an ORT tensor; reporting it as unsupported lets the caller take
// its "unsupported" path instead.
// NOTE(review): presumably the generated op code falls back to the Torch CPU
// implementation when this returns false - confirm against the op generator.
bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types) {
  if (!val.has_value()) {
    return false;
  }
  return IsSupportedType(*val, valid_types);
}
|
||||
|
||||
//#pragma endregion
|
||||
|
||||
//#pragma region Hand-Implemented ATen Ops
|
||||
|
|
@ -250,23 +259,18 @@ at::Tensor _reshape_alias(
|
|||
reshape_copy(
|
||||
invoker,
|
||||
create_ort_value(invoker, self),
|
||||
at::infer_size(
|
||||
size,
|
||||
self.numel())),
|
||||
size),
|
||||
self.options());
|
||||
}
|
||||
|
||||
// ATen `view` for the ORT eager backend.
// Looks up the ORT invoker for `self`'s device, converts `self` to an OrtValue,
// passes it to reshape_copy, and wraps the result as an ATen tensor built with
// `self.options()`. ORT_LOG_FN is invoked first (presumably call tracing).
// NOTE(review): this span comes from a rendered diff, so two alternatives for
// reshape_copy's last argument appear below: the `at::infer_size(size,
// self.numel())` call and the bare `size`. reshape_copy already calls
// at::infer_size on its shape argument, so the bare `size` looks like the
// intended (post-change) form - confirm against the real file.
at::Tensor view(const at::Tensor& self, at::IntArrayRef size) {
|
||||
ORT_LOG_FN(self, size);
|
||||

||||
auto& invoker = GetORTInvoker(self.device());
|
||||
return aten_tensor_from_ort(
|
||||
reshape_copy(
|
||||
invoker,
|
||||
create_ort_value(invoker, self),
|
||||
at::infer_size(
|
||||
size,
|
||||
self.numel())),
|
||||
size),
|
||||
self.options());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,29 @@ OrtValue create_ort_value(
|
|||
|
||||
OrtValue create_ort_value(const at::Tensor& tensor);
|
||||
|
||||
// Create 1-dimensional ORT tensor from a given value
|
||||
template <typename T>
|
||||
OrtValue create_ort_value(
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
const T val) {
|
||||
OrtValue ort_val;
|
||||
CreateMLValue(
|
||||
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault),
|
||||
onnxruntime::DataTypeImpl::GetType<T>(),
|
||||
{1,},
|
||||
&ort_val);
|
||||
auto* ort_tensor = ort_val.GetMutable<onnxruntime::Tensor>();
|
||||
CopyVectorToTensor<int64_t>(invoker, {val}, *ort_tensor);
|
||||
return ort_val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
OrtValue create_ort_value(
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
c10::optional<T> val) {
|
||||
return create_ort_value(invoker, val.value());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
OrtValue create_ort_value(
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
|
|
@ -79,5 +102,9 @@ bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid
|
|||
|
||||
bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types);
|
||||
|
||||
bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types);
|
||||
|
||||
bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types);
|
||||
|
||||
} // namespace eager
|
||||
} // namespace torch_ort
|
||||
|
|
@ -8,26 +8,6 @@
|
|||
namespace torch_ort {
|
||||
namespace eager {
|
||||
|
||||
// Reshapes `input` to `shape` by invoking the ORT "Reshape" kernel and returns
// the kernel's single output.
//
// `shape` is first passed through at::infer_size together with
// input_tensor.Shape().Size(); the resolved dimensions are copied into a 1-D
// int64 ORT tensor that is handed to the kernel as its second input.
// A failed kernel invocation is surfaced through ORT_THROW_IF_ERROR.
OrtValue reshape_copy(
    onnxruntime::ORTInvoker& invoker,
    const OrtValue& input,
    std::vector<int64_t> shape) {
  // TODO: actual reshape on buffer
  const auto& source = input.Get<onnxruntime::Tensor>();
  auto resolved_dims = at::infer_size(shape, source.Shape().Size());

  // TODO: avoid the copy on this small shape vector.
  OrtValue dims_value;
  auto dims_element_type = onnxruntime::DataTypeImpl::GetType<int64_t>();
  CreateMLValue(
      invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault),
      dims_element_type,
      {static_cast<int64_t>(resolved_dims.size()),},
      &dims_value);
  CopyVectorToTensor<int64_t>(
      invoker, resolved_dims, *dims_value.GetMutable<onnxruntime::Tensor>());

  std::vector<OrtValue> outputs(1);
  ORT_THROW_IF_ERROR(invoker.Invoke("Reshape", {input, dims_value}, outputs, nullptr));
  return outputs[0];
}
|
||||
|
||||
void copy(onnxruntime::ORTInvoker& invoker,
|
||||
const OrtValue& src, OrtValue& dst){
|
||||
auto& ort_ep = invoker.GetCurrentExecutionProvider();
|
||||
|
|
|
|||
|
|
@ -3,16 +3,32 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "ort_util.h"
|
||||
#include <core/framework/ort_value.h>
|
||||
#include <core/eager/ort_kernel_invoker.h>
|
||||
|
||||
namespace torch_ort {
|
||||
namespace eager {
|
||||
|
||||
// Reshapes `input` to `shape` by invoking the ORT "Reshape" kernel and returns
// the kernel's single output (result[0]).
//
// V is a single-parameter class template; `shape` is a V<int64_t> that must be
// acceptable to at::infer_size.
// NOTE(review): this span comes from a rendered diff. The line
// `std::vector<int64_t> shape);` looks like the removed non-template prototype
// and `V<int64_t> shape) {` like its replacement - confirm against the real
// header before compiling.
template <template<class> class V>
|
||||
OrtValue reshape_copy(
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
const OrtValue& input,
|
||||
std::vector<int64_t> shape);
|
||||
V<int64_t> shape) {
|
||||
// TODO: actual reshape on buffer
|
||||
const onnxruntime::Tensor& input_tensor = input.Get<onnxruntime::Tensor>();
|
||||
// Resolve the requested shape against input_tensor.Shape().Size()
|
||||
// (presumably the input's element count).
auto new_shape = at::infer_size(shape, input_tensor.Shape().Size());
|
||||
OrtValue shape_tensor;
|
||||
//todo: avoid the copy on this small shape vector;
|
||||
auto element_type = onnxruntime::DataTypeImpl::GetType<int64_t>();
|
||||
// Allocate a 1-D int64 tensor with one entry per resolved dimension.
CreateMLValue(invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault),
|
||||
element_type, {(int64_t)new_shape.size(),}, &shape_tensor);
|
||||
auto* ort_shape_tensor = shape_tensor.GetMutable<onnxruntime::Tensor>();
|
||||
CopyVectorToTensor<int64_t>(invoker, new_shape, *ort_shape_tensor);
|
||||
std::vector<OrtValue> result(1);
|
||||
// Kernel inputs are the data tensor and the shape tensor; a failed
// invocation is surfaced through ORT_THROW_IF_ERROR.
ORT_THROW_IF_ERROR(invoker.Invoke("Reshape", {input, shape_tensor}, result, nullptr));
|
||||
return result[0];
|
||||
}
|
||||
|
||||
OrtValue add(onnxruntime::ORTInvoker& invoker,
|
||||
const OrtValue& A,
|
||||
|
|
|
|||
Loading…
Reference in a new issue