From 8d0e86dec85281db3623809a72e5f6a36ac04eb0 Mon Sep 17 00:00:00 2001 From: Jameson Miller Date: Mon, 25 Jul 2022 07:26:35 -0400 Subject: [PATCH] Apply project formatting rules to ort_aten.cpp (#12294) * Apply project formatting rules to ort_aten.cpp Formatting applied by formatting the file in VS Code. This file is under active development and the inconsistent formatting was causing friction due to: 1. cpplint job on Pipeline was flagging a lot of style issues, resulting in a lot of noisy annotations. 2. local edits would result in changes that are not part of the core change. While there are other files in this part of the source tree with inconsistent formatting, this file was causing the most friction. We can come back and address the other files later, which would be a much larger change. * Apply consistent pattern for invoker.Invoke(...) --- orttraining/orttraining/eager/ort_aten.cpp | 574 ++++++++++----------- 1 file changed, 281 insertions(+), 293 deletions(-) diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index 718c05517d..4a18bc61f0 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -2,17 +2,17 @@ // Licensed under the MIT License. #include "ort_aten.h" -#include -#include -#include -#include -#include +#include +#include #include +#include +#include +#include #include -#include #include +#include namespace torch_ort { namespace eager { @@ -20,48 +20,48 @@ namespace eager { // #pragma region Helpers using NodeAttributes = onnxruntime::NodeAttributes; namespace { - inline bool is_device_supported(at::DeviceType type) { - return type == at::kORT || type == at::kCPU; +inline bool is_device_supported(at::DeviceType type) { + return type == at::kORT || type == at::kCPU; +} + +inline void assert_tensor_supported(const at::Tensor& tensor) { + if (tensor.is_sparse()) { + throw std::runtime_error("ORT copy: sparse not supported"); } - inline void assert_tensor_supported(const at::Tensor& tensor) { - if (tensor.is_sparse()) { - throw std::runtime_error("ORT copy: sparse not supported"); - } - - if (tensor.is_quantized()) { - throw std::runtime_error("ORT copy: quantized not supported"); - } - - if (!is_device_supported(tensor.device().type())) { - throw std::runtime_error("ORT copy: device not supported"); - } + if (tensor.is_quantized()) { + throw std::runtime_error("ORT copy: quantized not supported"); } + + if (!is_device_supported(tensor.device().type())) { + throw std::runtime_error("ORT copy: device not supported"); + } +} } // namespace at::Tensor aten_tensor_from_ort( - OrtValue&& ot, - const at::TensorOptions& options) { + OrtValue&& ot, + const at::TensorOptions& options) { return at::Tensor(c10::make_intrusive( - std::move(ot), - options)); + std::move(ot), + options)); } const std::vector aten_tensor_from_ort( - std::vector& ortvalues, - const at::TensorOptions& options) { - const size_t num_outputs = ortvalues.size(); - std::vector atvalues = std::vector(num_outputs); - for (size_t i = 0; i < num_outputs; i++) { - atvalues[i] = at::Tensor(c10::make_intrusive( + std::vector& ortvalues, + const at::TensorOptions& options) { + const size_t num_outputs = ortvalues.size(); + std::vector atvalues = std::vector(num_outputs); + for (size_t i = 0; i < num_outputs; i++) { + atvalues[i] = at::Tensor(c10::make_intrusive( std::move(ortvalues[i]), options)); - } - return atvalues; + } + return atvalues; } onnxruntime::MLDataType ort_scalar_type_from_aten( - at::ScalarType dtype) { + at::ScalarType dtype) { switch (dtype) { case at::kFloat: return onnxruntime::DataTypeImpl::GetType(); @@ -85,15 +85,15 @@ onnxruntime::MLDataType ort_scalar_type_from_aten( } OrtValue create_ort_value( - onnxruntime::ORTInvoker& invoker, - const at::Scalar& scalar) { + onnxruntime::ORTInvoker& invoker, + const at::Scalar& scalar) { return create_ort_value(invoker, scalar, scalar.type()); } OrtValue create_ort_value( - onnxruntime::ORTInvoker& invoker, - const at::Scalar& scalar, - at::ScalarType type) { + onnxruntime::ORTInvoker& invoker, + const at::Scalar& scalar, + at::ScalarType type) { OrtValue ort_val; onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(type), onnxruntime::TensorShape({}), invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ort_val); @@ -131,8 +131,8 @@ OrtValue create_ort_value( } OrtValue create_ort_value( - onnxruntime::ORTInvoker& invoker, - const at::Tensor& tensor) { + onnxruntime::ORTInvoker& invoker, + const at::Tensor& tensor) { assert_tensor_supported(tensor); auto* impl = dynamic_cast(tensor.unsafeGetTensorImpl()); @@ -140,18 +140,18 @@ OrtValue create_ort_value( return impl->tensor(); } - OrtMemoryInfo *mem_info; + OrtMemoryInfo* mem_info; Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info)); auto element_type = ort_scalar_type_from_aten(tensor.scalar_type()); OrtValue ort_tensor; onnxruntime::Tensor::InitOrtValue( - element_type, - onnxruntime::TensorShape(tensor.sizes().vec()), - tensor.data_ptr(), - *mem_info, ort_tensor, - 0L, // offset = 0 - because tensor.data_ptr() includes the underyling offset - tensor.strides().vec()); + element_type, + onnxruntime::TensorShape(tensor.sizes().vec()), + tensor.data_ptr(), + *mem_info, ort_tensor, + 0L, // offset = 0 - because tensor.data_ptr() includes the underlying offset + tensor.strides().vec()); return ort_tensor; } @@ -161,20 +161,20 @@ OrtValue create_ort_value(const at::Tensor& tensor) { } std::vector create_ort_value( - onnxruntime::ORTInvoker& invoker, - at::TensorList values) { - auto output = std::vector{}; - for (auto element : values) { - output.push_back(create_ort_value(element)); - } - return output; + onnxruntime::ORTInvoker& invoker, + at::TensorList values) { + auto output = std::vector{}; + for (auto element : values) { + output.push_back(create_ort_value(element)); + } + return output; } onnx::AttributeProto create_ort_attribute( - const char* name, - at::Scalar value, - const bool isTensor, - at::ScalarType type) { + const char* name, + at::Scalar value, + const bool isTensor, + at::ScalarType type) { if (isTensor) { onnx::AttributeProto attr; attr.set_name(name); @@ -184,28 +184,28 @@ onnx::AttributeProto create_ort_attribute( // Creating a 1 dim tensor of size 1, so add that dim now. constant_attribute_tensor_proto->add_dims(1); switch (type) { - case at::ScalarType::Float: - constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - *constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to(); - break; - case at::ScalarType::Double: - constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); - *constant_attribute_tensor_proto->mutable_double_data()->Add() = value.to(); - break; - case at::ScalarType::Bool: - case at::ScalarType::Int: - constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); - *constant_attribute_tensor_proto->mutable_int32_data()->Add() = value.to(); - break; - case at::ScalarType::Long: - constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - *constant_attribute_tensor_proto->mutable_int64_data()->Add() = value.to(); - break; - default: - // For most at::ScalarType, it should be safe to just call value.to<> - // on it, but for now we want to explicitly know when we've encountered - // a new scalar type while bringing up ORT eager mode. - ORT_THROW("Unsupported: at::ScalarType::", value.type()); + case at::ScalarType::Float: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + *constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to(); + break; + case at::ScalarType::Double: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); + *constant_attribute_tensor_proto->mutable_double_data()->Add() = value.to(); + break; + case at::ScalarType::Bool: + case at::ScalarType::Int: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + *constant_attribute_tensor_proto->mutable_int32_data()->Add() = value.to(); + break; + case at::ScalarType::Long: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + *constant_attribute_tensor_proto->mutable_int64_data()->Add() = value.to(); + break; + default: + // For most at::ScalarType, it should be safe to just call value.to<> + // on it, but for now we want to explicitly know when we've encountered + // a new scalar type while bringing up ORT eager mode. + ORT_THROW("Unsupported: at::ScalarType::", value.type()); } return attr; } else { @@ -214,16 +214,16 @@ onnx::AttributeProto create_ort_attribute( } onnx::AttributeProto create_ort_attribute( - const char* name, - at::Scalar value, - const bool isTensor) { - return create_ort_attribute(name, value, isTensor, value.type()); + const char* name, + at::Scalar value, + const bool isTensor) { + return create_ort_attribute(name, value, isTensor, value.type()); } onnx::AttributeProto create_ort_attribute( - const char* name, - at::Scalar value, - at::ScalarType type) { + const char* name, + at::Scalar value, + at::ScalarType type) { onnx::AttributeProto attr; attr.set_name(name); switch (type) { @@ -249,8 +249,8 @@ onnx::AttributeProto create_ort_attribute( } onnx::AttributeProto create_ort_attribute( - const char* name, - const char* value) { + const char* name, + const char* value) { onnx::AttributeProto attr; attr.set_name(name); attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING); @@ -259,8 +259,8 @@ onnx::AttributeProto create_ort_attribute( } onnx::AttributeProto create_ort_attribute( - const char* name, - const std::vector values) { + const char* name, + const std::vector values) { onnx::AttributeProto attr; attr.set_name(name); attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS); @@ -331,7 +331,6 @@ static c10::optional PromoteScalarTypes( return st; } - c10::optional PromoteScalarTypesWithCategory( const std::vector& typesFromTensors, const std::vector& typesFromScalars) { @@ -370,15 +369,15 @@ OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at: std::vector output(1); NodeAttributes attrs(2); attrs["to"] = create_ort_attribute( - "to", GetONNXTensorProtoDataType(type), at::ScalarType::Long); + "to", GetONNXTensorProtoDataType(type), at::ScalarType::Long); - auto status = invoker.Invoke("Cast", { - std::move(input), - }, output, &attrs); + auto status = invoker.Invoke("Cast", + {std::move(input)}, + output, &attrs); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); return output[0]; } @@ -389,9 +388,9 @@ OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at: * @param keepdim Whether to retain dim or not. Ignored if dimToReduce is null. */ inline at::DimVector calculate_reduction_shape( - const at::Tensor& self, - c10::optional dimToReduce, - bool keepdim) { + const at::Tensor& self, + c10::optional dimToReduce, + bool keepdim) { at::DimVector shape; // If we have dim value, then reduce that dimension. @@ -423,20 +422,20 @@ inline at::DimVector calculate_reduction_shape( * PyToch implementation of resize will warn about resizing * non-empty and indicate this is deprecated behavior that * can / will change. - * + * * In PyTorch repository see: aten/src/ATen/native/Resize.{h|cpp} */ void resize_output( - onnxruntime::ORTInvoker& invoker, - ORTTensorImpl* output, - at::IntArrayRef shape) { + onnxruntime::ORTInvoker& invoker, + ORTTensorImpl* output, + at::IntArrayRef shape) { if (output->sizes().equals(shape)) { return; } if (output->numel() != 0) { throw std::runtime_error( - "resizing a non-empty output tensor is not supported."); + "resizing a non-empty output tensor is not supported."); } resize_impl_ort_(invoker, output, shape); @@ -492,11 +491,11 @@ void resize_impl_ort_( // Just resize existing tensor and return OrtValue new_ort_value = reshape_invoke( - invoker, - self_ort_value, - size, - // invoke reshape kernel inplace - true); + invoker, + self_ort_value, + size, + // invoke reshape kernel inplace + true); // TODO(jamill): Investigate why reshape_invoke kernel does not update inplace self->set_tensor(new_ort_value); @@ -552,12 +551,12 @@ void resize_impl_ort_( namespace aten { at::Tensor empty_strided( - at::IntArrayRef size, - at::IntArrayRef stride, - c10::optional dtype_opt, - c10::optional layout_opt, // Ignored because there's no ONNX support. - c10::optional device_opt, // Will be ORT by the time this is dispatched. - c10::optional pin_memory_opt) { // Ignored because there's no ONNX support. + at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional dtype_opt, + c10::optional layout_opt, // Ignored because there's no ONNX support. + c10::optional device_opt, // Will be ORT by the time this is dispatched. + c10::optional pin_memory_opt) { // Ignored because there's no ONNX support. ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); OrtValue ot; @@ -568,19 +567,19 @@ at::Tensor empty_strided( invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot, stride.vec()); return aten_tensor_from_ort( - std::move(ot), - at::TensorOptions() - .device(*device_opt) - .dtype(dtype)); + std::move(ot), + at::TensorOptions() + .device(*device_opt) + .dtype(dtype)); } at::Tensor empty_memory_format( - at::IntArrayRef size, - c10::optional dtype_opt, - c10::optional layout_opt, - c10::optional device_opt, - c10::optional pin_memory, - c10::optional memory_format) { // Ignored because there's no ONNX support. + at::IntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory, + c10::optional memory_format) { // Ignored because there's no ONNX support. ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format); // Use the strided impl with default (no strides specified). @@ -589,10 +588,10 @@ at::Tensor empty_memory_format( // aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a) at::Tensor as_strided( - const at::Tensor& self, - at::IntArrayRef size, - at::IntArrayRef stride, - c10::optional storage_offset) { + const at::Tensor& self, + at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset) { ORT_LOG_FN(self, size, stride, storage_offset); auto& invoker = GetORTInvoker(self.device()); auto ort_input = create_ort_value(invoker, self); @@ -604,26 +603,26 @@ at::Tensor as_strided( invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault)->Info(), ot, byte_offset, stride.vec()); return aten_tensor_from_ort( - std::move(ot), - self.options()); + std::move(ot), + self.options()); } at::Tensor _reshape_alias( - const at::Tensor& self, - at::IntArrayRef size, - at::IntArrayRef stride) { + const at::Tensor& self, + at::IntArrayRef size, + at::IntArrayRef stride) { ORT_LOG_FN(self, size, stride); // TODO(unknown): support stride auto& invoker = GetORTInvoker(self.device()); auto ort_input = create_ort_value(invoker, self); return aten_tensor_from_ort( - reshape_invoke( - invoker, - ort_input, - size, - // invoke reshape kernel inplace - true), - self.options()); + reshape_invoke( + invoker, + ort_input, + size, + // invoke reshape kernel inplace + true), + self.options()); } at::Tensor view(const at::Tensor& self, at::IntArrayRef size) { @@ -631,27 +630,25 @@ at::Tensor view(const at::Tensor& self, at::IntArrayRef size) { auto& invoker = GetORTInvoker(self.device()); auto ort_input = create_ort_value(invoker, self); return aten_tensor_from_ort( - reshape_invoke( - invoker, - ort_input, - size, - // invoke reshape kernel inplace - true), - self.options()); + reshape_invoke( + invoker, + ort_input, + size, + // invoke reshape kernel inplace + true), + self.options()); } at::Tensor& copy_( - at::Tensor& self, - const at::Tensor& src, - bool non_blocking) { + at::Tensor& self, + const at::Tensor& src, + bool non_blocking) { ORT_LOG_FN(self, src, non_blocking); assert_tensor_supported(self); assert_tensor_supported(src); - auto& invoker = GetORTInvoker(self.device().type() == at::kORT - ? self.device() - : src.device()); + auto& invoker = GetORTInvoker(self.device().type() == at::kORT ? self.device() : src.device()); const auto ort_src = create_ort_value(invoker, src); auto ort_self = create_ort_value(invoker, self); if (self.scalar_type() != src.scalar_type()) { @@ -659,15 +656,15 @@ at::Tensor& copy_( std::vector ort_cast_output(1); onnxruntime::NodeAttributes attrs(1); attrs["to"] = create_ort_attribute( - "to", (int64_t)GetONNXTensorProtoDataType(self.scalar_type()), at::kLong); + "to", (int64_t)GetONNXTensorProtoDataType(self.scalar_type()), at::kLong); - auto status = invoker.Invoke("Cast", { - std::move(ort_src), - }, ort_cast_output, &attrs); + auto status = invoker.Invoke("Cast", + {std::move(ort_src)}, + ort_cast_output, &attrs); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); copy(invoker, ort_cast_output[0], ort_self); } else { @@ -678,16 +675,14 @@ at::Tensor& copy_( } at::Tensor _copy_from_and_resize( - const at::Tensor& self, - const at::Tensor& dst) { + const at::Tensor& self, + const at::Tensor& dst) { ORT_LOG_FN(self, dst); assert_tensor_supported(self); assert_tensor_supported(dst); - auto& invoker = GetORTInvoker(self.device().type() == at::kORT - ? self.device() - : dst.device()); + auto& invoker = GetORTInvoker(self.device().type() == at::kORT ? self.device() : dst.device()); const auto ort_self = create_ort_value(invoker, self); auto ort_dst = create_ort_value(invoker, dst); @@ -710,15 +705,13 @@ at::Tensor& zero_(at::Tensor& self) { std::vector ort_out = {ort_in_self}; - auto status = invoker.Invoke( - "ZeroGradient", { - std::move(ort_in_self), - std::move(flag_val) - }, ort_out, nullptr, onnxruntime::kMSDomain, 1); + auto status = invoker.Invoke("ZeroGradient", + {std::move(ort_in_self), std::move(flag_val)}, + ort_out, nullptr, onnxruntime::kMSDomain, 1); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); return self; } @@ -726,19 +719,19 @@ at::Tensor& zero_(at::Tensor& self) { // TODO(unknown): enhance opgen.py to support inplace binary operations. // aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) at::Tensor& add__Tensor( - at::Tensor& self, - const at::Tensor& other, - const at::Scalar& alpha) { + at::Tensor& self, + const at::Tensor& other, + const at::Scalar& alpha) { ORT_LOG_FN(self, other, alpha); auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}; if ( - !IsSupportedType(alpha, st) || - !IsSupportedType(other, st) || - !IsSupportedType(self, st)) { + !IsSupportedType(alpha, st) || + !IsSupportedType(other, st) || + !IsSupportedType(self, st)) { return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(add__Tensor)>::call(self, other, alpha); + &at::native::cpu_fallback, + ATEN_OP(add__Tensor)>::call(self, other, alpha); } auto& invoker = GetORTInvoker(self.device()); @@ -747,39 +740,37 @@ at::Tensor& add__Tensor( std::vector ort_outputs_0_Mul(1); - auto status = invoker.Invoke("Mul", { - std::move(ort_input_alpha), - std::move(ort_input_other), - }, ort_outputs_0_Mul, nullptr); + auto status = invoker.Invoke("Mul", + {std::move(ort_input_alpha), std::move(ort_input_other)}, + ort_outputs_0_Mul, nullptr); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); auto ort_input_self = create_ort_value(invoker, self); std::vector ort_outputs_1_Add(1); ort_outputs_1_Add[0] = ort_input_self; - status = invoker.Invoke("Add", { - std::move(ort_input_self), - std::move(ort_outputs_0_Mul[0]), - }, ort_outputs_1_Add, nullptr); + status = invoker.Invoke("Add", + {std::move(ort_input_self), std::move(ort_outputs_0_Mul[0])}, + ort_outputs_1_Add, nullptr); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); return self; } // aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a) at::Tensor slice_Tensor( - const at::Tensor& self, - int64_t dim, - c10::optional start, - c10::optional end, - int64_t step) { + const at::Tensor& self, + int64_t dim, + c10::optional start, + c10::optional end, + int64_t step) { ORT_LOG_FN(self, dim, start, end, step); int64_t ndim = self.dim(); if (ndim == 0) { @@ -823,39 +814,39 @@ at::Tensor slice_Tensor( ort_tensor->DataType(), onnxruntime::TensorShape(new_shape), ort_tensor->MutableDataRaw(), invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault)->Info(), ot, byte_offset, new_stride); return aten_tensor_from_ort( - std::move(ot), - self.options()); + std::move(ot), + self.options()); } // aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) at::Tensor& argmax_out( -const at::Tensor& self, -c10::optional dim, -bool keepdim, -// *, -at::Tensor& out) { + const at::Tensor& self, + c10::optional dim, + bool keepdim, + // *, + at::Tensor& out) { ORT_LOG_FN(self, dim, keepdim, out); auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}; if ( - !IsSupportedType(self, st)) { + !IsSupportedType(self, st)) { return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(argmax_out)>::call(self, dim, keepdim, out); + &at::native::cpu_fallback, + ATEN_OP(argmax_out)>::call(self, dim, keepdim, out); } auto& invoker = GetORTInvoker(self.device()); auto ort_input_self = - create_ort_value(invoker, dim.has_value() ? self : self.reshape({-1})); + create_ort_value(invoker, dim.has_value() ? self : self.reshape({-1})); int64_t l_axis = dim.has_value() ? *dim : 0; bool keepdim_effective_value = dim.has_value() ? keepdim : false; NodeAttributes attrs(2); attrs["axis"] = create_ort_attribute( - "axis", l_axis, at::ScalarType::Int); + "axis", l_axis, at::ScalarType::Int); attrs["keepdims"] = create_ort_attribute( - "keepdims", keepdim_effective_value, at::ScalarType::Bool); + "keepdims", keepdim_effective_value, at::ScalarType::Bool); std::vector ort_outputs_0_ArgMax(1); @@ -869,31 +860,31 @@ at::Tensor& out) { auto ort_input_out = create_ort_value(invoker, out); ort_outputs_0_ArgMax[0] = ort_input_out; - auto status = invoker.Invoke("ArgMax", { - std::move(ort_input_self), - }, ort_outputs_0_ArgMax, &attrs); + auto status = invoker.Invoke("ArgMax", + {std::move(ort_input_self)}, + ort_outputs_0_ArgMax, &attrs); if (!status.IsOK()) - throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + throw std::runtime_error( + "ORT return failure status:" + status.ErrorMessage()); return out; } // aten::equal(Tensor self, Tensor other) -> bool bool equal( - const at::Tensor& self, - const at::Tensor& other) { + const at::Tensor& self, + const at::Tensor& other) { ORT_LOG_FN(self, other); if ( - std::vector supportedTypes = - {at::kFloat, at::kBFloat16, at::kHalf, at::kDouble, at::kLong, at::kByte, at::kInt, at::kShort, at::kBool}; - !IsSupportedType(self, supportedTypes) || - !IsSupportedType(other, supportedTypes)) { + std::vector supportedTypes = + {at::kFloat, at::kBFloat16, at::kHalf, at::kDouble, at::kLong, at::kByte, at::kInt, at::kShort, at::kBool}; + !IsSupportedType(self, supportedTypes) || + !IsSupportedType(other, supportedTypes)) { return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(equal)>::call(self, other); + &at::native::cpu_fallback, + ATEN_OP(equal)>::call(self, other); } auto& invoker = GetORTInvoker(self.device()); @@ -914,19 +905,18 @@ bool equal( // being less than true, so any false will reduce to false. std::vector ort_outputs_0_Equal(1); - auto equalStatus = invoker.Invoke("Equal", { - std::move(ort_input_self), - std::move(ort_input_other), - }, ort_outputs_0_Equal, nullptr); + auto equalStatus = invoker.Invoke("Equal", + {std::move(ort_input_self), std::move(ort_input_other)}, + ort_outputs_0_Equal, nullptr); if (!equalStatus.IsOK()) throw std::runtime_error( - "ORT Equal return failure status:" + equalStatus.ErrorMessage()); + "ORT Equal return failure status:" + equalStatus.ErrorMessage()); // now reduce the resulting tensor of bool values to its minimum value (any false) NodeAttributes attrs(1); attrs["keepdims"] = create_ort_attribute( - "keepdims", 0, at::ScalarType::Int); + "keepdims", 0, at::ScalarType::Int); std::vector ort_outputs_0_ReduceMin(1); @@ -934,13 +924,13 @@ bool equal( // GetONNXTensorProtoDataType doesn't support byte, which leaves us with int OrtValue equalAsInt = CastToType(invoker, ort_outputs_0_Equal[0], at::ScalarType::Int); - auto reduceStatus = invoker.Invoke("ReduceMin", { - std::move(equalAsInt), - }, ort_outputs_0_ReduceMin, &attrs); + auto reduceStatus = invoker.Invoke("ReduceMin", + {std::move(equalAsInt)}, + ort_outputs_0_ReduceMin, &attrs); if (!reduceStatus.IsOK()) throw std::runtime_error( - "ORT ReduceMin return failure reduceStatus:" + reduceStatus.ErrorMessage()); + "ORT ReduceMin return failure reduceStatus:" + reduceStatus.ErrorMessage()); auto* ort_tensor = ort_outputs_0_ReduceMin[0].GetMutable(); // the first (and only) value of the tensor will be 0 for false else true @@ -970,18 +960,18 @@ const at::Tensor& resize_( // aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) at::Tensor& fill__Scalar( - at::Tensor& self, - const at::Scalar& value) { + at::Tensor& self, + const at::Scalar& value) { ORT_LOG_FN(self, value); if ( - std::vector supportedTypes = - {at::kHalf, at::kFloat, at::kInt, at::kDouble, at::kByte, at::kShort, at::kLong, at::kBFloat16, at::kBool}; - !IsSupportedType(self, supportedTypes)) { + std::vector supportedTypes = + {at::kHalf, at::kFloat, at::kInt, at::kDouble, at::kByte, at::kShort, at::kLong, at::kBFloat16, at::kBool}; + !IsSupportedType(self, supportedTypes)) { std::cout << "fill__Scalar - Fell back to cpu!\n"; return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(fill__Scalar)>::call(self, value); + &at::native::cpu_fallback, + ATEN_OP(fill__Scalar)>::call(self, value); } auto& invoker = GetORTInvoker(self.device()); @@ -989,37 +979,37 @@ at::Tensor& fill__Scalar( std::vector ort_outputs_0_Shape(1); - auto status = invoker.Invoke("Shape", { - std::move(ort_input_self), - }, ort_outputs_0_Shape, nullptr); + auto status = invoker.Invoke("Shape", + {std::move(ort_input_self)}, + ort_outputs_0_Shape, nullptr); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); std::vector ort_outputs_1_ConstantOfShape(1); ort_outputs_1_ConstantOfShape[0] = ort_input_self; NodeAttributes attrs(1); attrs["value"] = create_ort_attribute( - "value", value, true, self.scalar_type()); + "value", value, true, self.scalar_type()); - status = invoker.Invoke("ConstantOfShape", { - std::move(ort_outputs_0_Shape[0]), - }, ort_outputs_1_ConstantOfShape, &attrs); + status = invoker.Invoke("ConstantOfShape", + {std::move(ort_outputs_0_Shape[0])}, + ort_outputs_1_ConstantOfShape, &attrs); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); return self; } // aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) at::Tensor& nonzero_out( - const at::Tensor& self, - // *, - at::Tensor& out) { + const at::Tensor& self, + // *, + at::Tensor& out) { ORT_LOG_FN(self, out); auto temp = eager::aten::nonzero(self); @@ -1036,18 +1026,18 @@ at::Tensor& nonzero_out( // aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) at::Tensor& _log_softmax_out( - const at::Tensor& self, - int64_t dim, - bool half_to_float, - // *, - at::Tensor& out) { + const at::Tensor& self, + int64_t dim, + bool half_to_float, + // *, + at::Tensor& out) { ORT_LOG_FN(self, dim, half_to_float, out); if ( - !IsSupportedType(self, {at::kBFloat16, at::kDouble, at::kFloat, at::kHalf})) { + !IsSupportedType(self, {at::kBFloat16, at::kDouble, at::kFloat, at::kHalf})) { return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out); + &at::native::cpu_fallback, + ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out); } auto& invoker = GetORTInvoker(self.device()); @@ -1074,30 +1064,30 @@ at::Tensor& _log_softmax_out( for (int64_t i = 0; i < ndim; i++) axes.push_back(i); - axes[dim] = ndim-1; - axes[ndim-1] = dim; - dim = ndim-1; + axes[dim] = ndim - 1; + axes[ndim - 1] = dim; + dim = ndim - 1; NodeAttributes attrs_0(1); attrs_0["perm"] = create_ort_attribute("perm", axes); - auto status = invoker.Invoke("Transpose", { - std::move(ort_input_0_self), - }, ort_outputs_0_Transpose, &attrs_0); + auto status = invoker.Invoke("Transpose", + {std::move(ort_input_0_self)}, + ort_outputs_0_Transpose, &attrs_0); CHECK_STATUS(status); } NodeAttributes attrs_1(1); attrs_1["axis"] = create_ort_attribute( - "axis", dim, at::ScalarType::Int); + "axis", dim, at::ScalarType::Int); std::vector ort_outputs_1_LogSoftmax(1); if (!need_transpose) { ort_outputs_1_LogSoftmax[0] = ort_input_out; } - auto status = invoker.Invoke("LogSoftmax", { - std::move(need_transpose ? ort_outputs_0_Transpose[0] : ort_input_0_self), - }, ort_outputs_1_LogSoftmax, &attrs_1); + auto status = invoker.Invoke("LogSoftmax", + {std::move(need_transpose ? ort_outputs_0_Transpose[0] : ort_input_0_self)}, + ort_outputs_1_LogSoftmax, &attrs_1); CHECK_STATUS(status); std::vector ort_outputs_2_Transpose(1); @@ -1108,9 +1098,9 @@ at::Tensor& _log_softmax_out( NodeAttributes attrs_2(1); attrs_2["perm"] = create_ort_attribute("perm", axes); - status = invoker.Invoke("Transpose", { - std::move(ort_outputs_1_LogSoftmax[0]), - }, ort_outputs_2_Transpose, &attrs_2); + status = invoker.Invoke("Transpose", + {std::move(ort_outputs_1_LogSoftmax[0])}, + ort_outputs_2_Transpose, &attrs_2); CHECK_STATUS(status); } @@ -1121,28 +1111,28 @@ at::Tensor& _log_softmax_out( // mm is for matrix multiplication and does not broadcast. // https://pytorch.org/docs/stable/generated/torch.mm.html at::Tensor& mm_out( - const at::Tensor& self, - const at::Tensor& mat2, - // *, - at::Tensor& out) { + const at::Tensor& self, + const at::Tensor& mat2, + // *, + at::Tensor& out) { ORT_LOG_FN(self, mat2, out); if ( - std::vector supportedTypes = - {at::kDouble, at::kLong, at::kHalf, at::kFloat, at::kBFloat16, at::kInt}; - !IsSupportedType(self, supportedTypes) || - !IsSupportedType(mat2, supportedTypes) || - // to match cpu device behavior for torch.mm, verify the following and fall back to cpu to generate error message. - // 1. self and mat2 must be 2-D (matrices) - self.dim() != 2 || mat2.dim() != 2 || - // 2. self and mat2 can be multiplied - self.sizes()[1] != mat2.sizes()[0] || - // 3. self, mat2, and out are of the same type - self.scalar_type() != out.scalar_type() || - self.scalar_type() != mat2.scalar_type()) { + std::vector supportedTypes = + {at::kDouble, at::kLong, at::kHalf, at::kFloat, at::kBFloat16, at::kInt}; + !IsSupportedType(self, supportedTypes) || + !IsSupportedType(mat2, supportedTypes) || + // to match cpu device behavior for torch.mm, verify the following and fall back to cpu to generate error message. + // 1. self and mat2 must be 2-D (matrices) + self.dim() != 2 || mat2.dim() != 2 || + // 2. self and mat2 can be multiplied + self.sizes()[1] != mat2.sizes()[0] || + // 3. self, mat2, and out are of the same type + self.scalar_type() != out.scalar_type() || + self.scalar_type() != mat2.scalar_type()) { return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(mm_out)>::call(self, mat2, out); + &at::native::cpu_fallback, + ATEN_OP(mm_out)>::call(self, mat2, out); } auto& invoker = GetORTInvoker(self.device()); @@ -1165,16 +1155,14 @@ at::Tensor& mm_out( std::vector ort_outputs_0_MatMul(1); ort_outputs_0_MatMul[0] = ort_input_out; - auto status = invoker.Invoke("MatMul", { - std::move(ort_input_0_self), - std::move(ort_input_0_mat2), - }, ort_outputs_0_MatMul, nullptr); + auto status = invoker.Invoke("MatMul", + {std::move(ort_input_0_self), std::move(ort_input_0_mat2)}, + ort_outputs_0_MatMul, nullptr); CHECK_STATUS(status); return out; } - } // namespace aten // #pragma endregion