mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
cpplint & Eager mode: refactor and add comments to empty_* functions, general lint cleanup in ort_aten (#12238)
* empty* comments and code reuse * lint * more cpplint * add cpplint settings * test empty
This commit is contained in:
parent
72c689a502
commit
424120d0fa
3 changed files with 85 additions and 74 deletions
7
.vscode/settings.json
vendored
7
.vscode/settings.json
vendored
|
|
@ -34,5 +34,10 @@
|
|||
"python.linting.pydocstyleArgs": [
|
||||
"--convention=google"
|
||||
],
|
||||
"python.linting.banditEnabled": true
|
||||
"python.linting.banditEnabled": true,
|
||||
"cpplint.lineLength": 120,
|
||||
"cpplint.filters": [
|
||||
"-build/include_subdir",
|
||||
"-runtime/references"
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,11 +10,14 @@
|
|||
#include <c10/util/irange.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
namespace torch_ort {
|
||||
namespace eager {
|
||||
|
||||
//#pragma region Helpers
|
||||
// #pragma region Helpers
|
||||
using NodeAttributes = onnxruntime::NodeAttributes;
|
||||
namespace {
|
||||
inline bool is_device_supported(at::DeviceType type) {
|
||||
|
|
@ -34,7 +37,7 @@ namespace {
|
|||
throw std::runtime_error("ORT copy: device not supported");
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
at::Tensor aten_tensor_from_ort(
|
||||
OrtValue&& ot,
|
||||
|
|
@ -59,7 +62,7 @@ const std::vector<at::Tensor> aten_tensor_from_ort(
|
|||
|
||||
onnxruntime::MLDataType ort_scalar_type_from_aten(
|
||||
at::ScalarType dtype) {
|
||||
switch (dtype){
|
||||
switch (dtype) {
|
||||
case at::kFloat:
|
||||
return onnxruntime::DataTypeImpl::GetType<float>();
|
||||
case at::kDouble:
|
||||
|
|
@ -107,7 +110,7 @@ OrtValue create_ort_value(
|
|||
break;
|
||||
}
|
||||
default:
|
||||
// TODO: support more types
|
||||
// TODO(unknown): support more types
|
||||
// For most at::ScalarType, it should be safe to just call value.to<>
|
||||
// on it, but for now we want to explicitly know when we've encountered
|
||||
// a new scalar type while bringing up ORT eager mode.
|
||||
|
|
@ -131,13 +134,17 @@ OrtValue create_ort_value(
|
|||
auto element_type = ort_scalar_type_from_aten(tensor.scalar_type());
|
||||
|
||||
OrtValue ort_tensor;
|
||||
onnxruntime::Tensor::InitOrtValue(element_type, onnxruntime::TensorShape(tensor.sizes().vec()), tensor.data_ptr(),
|
||||
*mem_info, ort_tensor, 0L /* offset = 0 - because tensor.data_ptr() includes the underyling offset */,
|
||||
tensor.strides().vec());
|
||||
onnxruntime::Tensor::InitOrtValue(
|
||||
element_type,
|
||||
onnxruntime::TensorShape(tensor.sizes().vec()),
|
||||
tensor.data_ptr(),
|
||||
*mem_info, ort_tensor,
|
||||
0L, // offset = 0 - because tensor.data_ptr() includes the underyling offset
|
||||
tensor.strides().vec());
|
||||
return ort_tensor;
|
||||
}
|
||||
|
||||
OrtValue create_ort_value(const at::Tensor& tensor){
|
||||
OrtValue create_ort_value(const at::Tensor& tensor) {
|
||||
auto& invoker = GetORTInvoker(tensor.device());
|
||||
return create_ort_value(invoker, tensor);
|
||||
}
|
||||
|
|
@ -146,7 +153,7 @@ std::vector<OrtValue> create_ort_value(
|
|||
onnxruntime::ORTInvoker& invoker,
|
||||
at::TensorList values) {
|
||||
auto output = std::vector<OrtValue>{};
|
||||
for (auto element: values){
|
||||
for (auto element : values) {
|
||||
output.push_back(create_ort_value(element));
|
||||
}
|
||||
return output;
|
||||
|
|
@ -157,7 +164,7 @@ onnx::AttributeProto create_ort_attribute(
|
|||
at::Scalar value,
|
||||
const bool isTensor,
|
||||
at::ScalarType type) {
|
||||
if (isTensor){
|
||||
if (isTensor) {
|
||||
onnx::AttributeProto attr;
|
||||
attr.set_name(name);
|
||||
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
|
||||
|
|
@ -190,8 +197,7 @@ onnx::AttributeProto create_ort_attribute(
|
|||
ORT_THROW("Unsupported: at::ScalarType::", value.type());
|
||||
}
|
||||
return attr;
|
||||
}
|
||||
else{
|
||||
} else {
|
||||
return create_ort_attribute(name, value, value.type());
|
||||
}
|
||||
}
|
||||
|
|
@ -254,33 +260,33 @@ onnx::AttributeProto create_ort_attribute(
|
|||
return attr;
|
||||
}
|
||||
|
||||
bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types){
|
||||
bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types) {
|
||||
return std::find(valid_types.begin(), valid_types.end(), scalar.type()) != valid_types.end();
|
||||
}
|
||||
|
||||
bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types){
|
||||
bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types) {
|
||||
return std::find(valid_types.begin(), valid_types.end(), tensor.scalar_type()) != valid_types.end();
|
||||
}
|
||||
|
||||
bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types){
|
||||
bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types) {
|
||||
return std::find(valid_types.begin(), valid_types.end(), at::kInt) != valid_types.end() ||
|
||||
std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
|
||||
}
|
||||
|
||||
bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types){
|
||||
bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types) {
|
||||
return std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
|
||||
}
|
||||
|
||||
bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types){
|
||||
bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types) {
|
||||
return IsSupportedType(val.value(), valid_types);
|
||||
}
|
||||
|
||||
bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types){
|
||||
bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types) {
|
||||
return IsSupportedType(tensors[0], valid_types);
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype){
|
||||
switch (dtype){
|
||||
ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype) {
|
||||
switch (dtype) {
|
||||
case at::kFloat:
|
||||
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
|
||||
case at::kDouble:
|
||||
|
|
@ -349,7 +355,7 @@ c10::optional<at::ScalarType> PromoteScalarTypesWithCategory(
|
|||
return typeFromTensor;
|
||||
}
|
||||
|
||||
OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type){
|
||||
OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type) {
|
||||
std::vector<OrtValue> output(1);
|
||||
NodeAttributes attrs(2);
|
||||
attrs["to"] = create_ort_attribute(
|
||||
|
|
@ -425,7 +431,7 @@ void resize_output(
|
|||
resize_impl_ort_(invoker, output, shape);
|
||||
}
|
||||
|
||||
//#pragma endregion
|
||||
// #pragma endregion
|
||||
|
||||
/*
|
||||
* Resize backing store of a TensorImpl.
|
||||
|
|
@ -530,52 +536,44 @@ void resize_impl_ort_(
|
|||
return;
|
||||
}
|
||||
|
||||
//#pragma region Hand-Implemented ATen Ops
|
||||
// #pragma region Hand-Implemented ATen Ops
|
||||
|
||||
namespace aten {
|
||||
|
||||
at::Tensor empty_memory_format(
|
||||
at::Tensor empty_strided(
|
||||
at::IntArrayRef size,
|
||||
// *,
|
||||
at::IntArrayRef stride,
|
||||
c10::optional<at::ScalarType> dtype_opt,
|
||||
c10::optional<at::Layout> layout_opt,
|
||||
c10::optional<at::Device> device_opt,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<at::MemoryFormat> memory_format) {
|
||||
ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format);
|
||||
|
||||
assert(dtype_opt.has_value());
|
||||
assert(device_opt.has_value());
|
||||
|
||||
// TODO: validate options and memory format
|
||||
// TODO: figure out how to get the correct element type.
|
||||
OrtValue ot;
|
||||
auto& invoker = GetORTInvoker(*device_opt);
|
||||
onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(*dtype_opt), onnxruntime::TensorShape(size.vec()),
|
||||
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot);
|
||||
return aten_tensor_from_ort(
|
||||
std::move(ot),
|
||||
at::TensorOptions()
|
||||
.device(*device_opt)
|
||||
.dtype(*dtype_opt));
|
||||
}
|
||||
|
||||
at::Tensor empty_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype_opt,
|
||||
c10::optional<at::Layout> layout_opt, c10::optional<at::Device> device_opt,
|
||||
c10::optional<bool> pin_memory_opt) {
|
||||
c10::optional<at::Layout> layout_opt, // Ignored because there's no ONNX support.
|
||||
c10::optional<at::Device> device_opt, // Will be ORT by the time this is dispatched.
|
||||
c10::optional<bool> pin_memory_opt) { // Ignored because there's no ONNX support.
|
||||
ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
|
||||
// TODO: how to handle type conversion
|
||||
OrtValue ot;
|
||||
assert(device_opt.has_value());
|
||||
// TODO: how to support layout
|
||||
// assert(!layout_opt.has_value());
|
||||
at::ScalarType dtype = c10::dtype_or_default(dtype_opt);
|
||||
auto& invoker = GetORTInvoker(*device_opt);
|
||||
onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(dtype), onnxruntime::TensorShape(size.vec()),
|
||||
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot,
|
||||
stride.vec());
|
||||
return aten_tensor_from_ort(std::move(ot), at::TensorOptions().device(*device_opt).dtype(dtype));
|
||||
return aten_tensor_from_ort(
|
||||
std::move(ot),
|
||||
at::TensorOptions()
|
||||
.device(*device_opt)
|
||||
.dtype(dtype));
|
||||
}
|
||||
|
||||
at::Tensor empty_memory_format(
|
||||
at::IntArrayRef size,
|
||||
c10::optional<at::ScalarType> dtype_opt,
|
||||
c10::optional<at::Layout> layout_opt,
|
||||
c10::optional<at::Device> device_opt,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<at::MemoryFormat> memory_format) { // Ignored because there's no ONNX support.
|
||||
ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format);
|
||||
|
||||
// Use the strided impl with default (no strides specified).
|
||||
return empty_strided(size, at::IntArrayRef({}), dtype_opt, layout_opt, device_opt, pin_memory);
|
||||
}
|
||||
|
||||
// aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
|
||||
|
|
@ -602,9 +600,9 @@ at::Tensor as_strided(
|
|||
at::Tensor _reshape_alias(
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride){
|
||||
at::IntArrayRef stride) {
|
||||
ORT_LOG_FN(self, size, stride);
|
||||
// TODO: support stride
|
||||
// TODO(unknown): support stride
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
auto ort_input = create_ort_value(invoker, self);
|
||||
return aten_tensor_from_ort(
|
||||
|
|
@ -645,7 +643,7 @@ at::Tensor& copy_(
|
|||
: src.device());
|
||||
const auto ort_src = create_ort_value(invoker, src);
|
||||
auto ort_self = create_ort_value(invoker, self);
|
||||
if (self.scalar_type() != src.scalar_type()){
|
||||
if (self.scalar_type() != src.scalar_type()) {
|
||||
// invoke cast first
|
||||
std::vector<OrtValue> ort_cast_output(1);
|
||||
onnxruntime::NodeAttributes attrs(1);
|
||||
|
|
@ -661,8 +659,7 @@ at::Tensor& copy_(
|
|||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
copy(invoker, ort_cast_output[0], ort_self);
|
||||
}
|
||||
else{
|
||||
} else {
|
||||
copy(invoker, ort_src, ort_self);
|
||||
}
|
||||
|
||||
|
|
@ -671,7 +668,7 @@ at::Tensor& copy_(
|
|||
|
||||
at::Tensor _copy_from_and_resize(
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& dst){
|
||||
const at::Tensor& dst) {
|
||||
ORT_LOG_FN(self, dst);
|
||||
|
||||
assert_tensor_supported(self);
|
||||
|
|
@ -688,11 +685,11 @@ at::Tensor _copy_from_and_resize(
|
|||
return self;
|
||||
}
|
||||
|
||||
at::Tensor& zero_(at::Tensor& self){
|
||||
at::Tensor& zero_(at::Tensor& self) {
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
auto ort_in_self = create_ort_value(invoker, self);
|
||||
OrtValue flag_val;
|
||||
//construct a constant tensor
|
||||
// construct a constant tensor
|
||||
auto element_type = onnxruntime::DataTypeImpl::GetType<int64_t>();
|
||||
onnxruntime::Tensor::InitOrtValue(element_type, onnxruntime::TensorShape({}),
|
||||
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), flag_val);
|
||||
|
|
@ -715,7 +712,7 @@ at::Tensor& zero_(at::Tensor& self){
|
|||
return self;
|
||||
}
|
||||
|
||||
// TODO: enhance opgen.py to support inplace binary operations.
|
||||
// TODO(unknown): enhance opgen.py to support inplace binary operations.
|
||||
// aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
|
||||
at::Tensor& add__Tensor(
|
||||
at::Tensor& self,
|
||||
|
|
@ -723,10 +720,11 @@ at::Tensor& add__Tensor(
|
|||
const at::Scalar& alpha) {
|
||||
ORT_LOG_FN(self, other, alpha);
|
||||
|
||||
auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
|
||||
if (
|
||||
!IsSupportedType(alpha, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
|
||||
!IsSupportedType(other, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
|
||||
!IsSupportedType(self, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16})) {
|
||||
!IsSupportedType(alpha, st) ||
|
||||
!IsSupportedType(other, st) ||
|
||||
!IsSupportedType(self, st)) {
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(add__Tensor)>::call(self, other, alpha);
|
||||
|
|
@ -827,8 +825,9 @@ bool keepdim,
|
|||
at::Tensor& out) {
|
||||
ORT_LOG_FN(self, dim, keepdim, out);
|
||||
|
||||
auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
|
||||
if (
|
||||
!IsSupportedType(self, {at::kLong, at::kShort, at::kHalf, at::kBFloat16, at::kFloat, at::kByte, at::kInt, at::kDouble})) {
|
||||
!IsSupportedType(self, st)) {
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(argmax_out)>::call(self, dim, keepdim, out);
|
||||
|
|
@ -1034,7 +1033,7 @@ at::Tensor& _log_softmax_out(
|
|||
ORT_LOG_FN(self, dim, half_to_float, out);
|
||||
|
||||
if (
|
||||
!IsSupportedType(self, {at::kBFloat16,at::kDouble,at::kFloat,at::kHalf})) {
|
||||
!IsSupportedType(self, {at::kBFloat16, at::kDouble, at::kFloat, at::kHalf})) {
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out);
|
||||
|
|
@ -1096,7 +1095,7 @@ at::Tensor& _log_softmax_out(
|
|||
ort_outputs_2_Transpose[0] = ort_input_out;
|
||||
|
||||
NodeAttributes attrs_2(1);
|
||||
attrs_2["perm"] = create_ort_attribute("perm", axes);;
|
||||
attrs_2["perm"] = create_ort_attribute("perm", axes);
|
||||
|
||||
status = invoker.Invoke("Transpose", {
|
||||
std::move(ort_outputs_1_LogSoftmax[0]),
|
||||
|
|
@ -1165,9 +1164,9 @@ at::Tensor& mm_out(
|
|||
}
|
||||
|
||||
|
||||
} // namespace aten
|
||||
} // namespace aten
|
||||
|
||||
//#pragma endregion
|
||||
// #pragma endregion
|
||||
|
||||
} // namespace eager
|
||||
} // namespace torch_ort
|
||||
} // namespace eager
|
||||
} // namespace torch_ort
|
||||
|
|
|
|||
|
|
@ -231,6 +231,13 @@ class OrtOpTests(unittest.TestCase):
|
|||
cpu_tensor_copied = ort_tensor.cpu()
|
||||
assert cpu_tensor_copied.stride() == (0, 0, 0)
|
||||
|
||||
def test_empty(self):
|
||||
device = self.get_device()
|
||||
cpu_tensor = torch.empty(size=(3, 4))
|
||||
ort_tensor = torch.empty(size=(3, 4), device=device)
|
||||
assert ort_tensor.is_ort
|
||||
assert ort_tensor.size() == cpu_tensor.size()
|
||||
|
||||
def test_softmax(self):
|
||||
device = self.get_device()
|
||||
cpu_tensor = torch.rand(3, 5)
|
||||
|
|
|
|||
Loading…
Reference in a new issue