diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index 718c05517d..4a18bc61f0 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -2,17 +2,17 @@ // Licensed under the MIT License. #include "ort_aten.h" -#include -#include -#include -#include -#include +#include +#include #include +#include +#include +#include #include -#include #include +#include namespace torch_ort { namespace eager { @@ -20,48 +20,48 @@ namespace eager { // #pragma region Helpers using NodeAttributes = onnxruntime::NodeAttributes; namespace { - inline bool is_device_supported(at::DeviceType type) { - return type == at::kORT || type == at::kCPU; +inline bool is_device_supported(at::DeviceType type) { + return type == at::kORT || type == at::kCPU; +} + +inline void assert_tensor_supported(const at::Tensor& tensor) { + if (tensor.is_sparse()) { + throw std::runtime_error("ORT copy: sparse not supported"); } - inline void assert_tensor_supported(const at::Tensor& tensor) { - if (tensor.is_sparse()) { - throw std::runtime_error("ORT copy: sparse not supported"); - } - - if (tensor.is_quantized()) { - throw std::runtime_error("ORT copy: quantized not supported"); - } - - if (!is_device_supported(tensor.device().type())) { - throw std::runtime_error("ORT copy: device not supported"); - } + if (tensor.is_quantized()) { + throw std::runtime_error("ORT copy: quantized not supported"); } + + if (!is_device_supported(tensor.device().type())) { + throw std::runtime_error("ORT copy: device not supported"); + } +} } // namespace at::Tensor aten_tensor_from_ort( - OrtValue&& ot, - const at::TensorOptions& options) { + OrtValue&& ot, + const at::TensorOptions& options) { return at::Tensor(c10::make_intrusive( - std::move(ot), - options)); + std::move(ot), + options)); } const std::vector aten_tensor_from_ort( - std::vector& ortvalues, - const at::TensorOptions& options) { - const size_t num_outputs = ortvalues.size(); - std::vector atvalues = std::vector(num_outputs); - for (size_t i = 0; i < num_outputs; i++) { - atvalues[i] = at::Tensor(c10::make_intrusive( + std::vector& ortvalues, + const at::TensorOptions& options) { + const size_t num_outputs = ortvalues.size(); + std::vector atvalues = std::vector(num_outputs); + for (size_t i = 0; i < num_outputs; i++) { + atvalues[i] = at::Tensor(c10::make_intrusive( std::move(ortvalues[i]), options)); - } - return atvalues; + } + return atvalues; } onnxruntime::MLDataType ort_scalar_type_from_aten( - at::ScalarType dtype) { + at::ScalarType dtype) { switch (dtype) { case at::kFloat: return onnxruntime::DataTypeImpl::GetType(); @@ -85,15 +85,15 @@ onnxruntime::MLDataType ort_scalar_type_from_aten( } OrtValue create_ort_value( - onnxruntime::ORTInvoker& invoker, - const at::Scalar& scalar) { + onnxruntime::ORTInvoker& invoker, + const at::Scalar& scalar) { return create_ort_value(invoker, scalar, scalar.type()); } OrtValue create_ort_value( - onnxruntime::ORTInvoker& invoker, - const at::Scalar& scalar, - at::ScalarType type) { + onnxruntime::ORTInvoker& invoker, + const at::Scalar& scalar, + at::ScalarType type) { OrtValue ort_val; onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(type), onnxruntime::TensorShape({}), invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ort_val); @@ -131,8 +131,8 @@ OrtValue create_ort_value( } OrtValue create_ort_value( - onnxruntime::ORTInvoker& invoker, - const at::Tensor& tensor) { + onnxruntime::ORTInvoker& invoker, + const at::Tensor& tensor) { assert_tensor_supported(tensor); auto* impl = dynamic_cast(tensor.unsafeGetTensorImpl()); @@ -140,18 +140,18 @@ OrtValue create_ort_value( return impl->tensor(); } - OrtMemoryInfo *mem_info; + OrtMemoryInfo* mem_info; Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info)); auto element_type = ort_scalar_type_from_aten(tensor.scalar_type()); OrtValue ort_tensor; onnxruntime::Tensor::InitOrtValue( - element_type, - onnxruntime::TensorShape(tensor.sizes().vec()), - tensor.data_ptr(), - *mem_info, ort_tensor, - 0L, // offset = 0 - because tensor.data_ptr() includes the underyling offset - tensor.strides().vec()); + element_type, + onnxruntime::TensorShape(tensor.sizes().vec()), + tensor.data_ptr(), + *mem_info, ort_tensor, + 0L, // offset = 0 - because tensor.data_ptr() includes the underlying offset + tensor.strides().vec()); return ort_tensor; } @@ -161,20 +161,20 @@ OrtValue create_ort_value(const at::Tensor& tensor) { } std::vector create_ort_value( - onnxruntime::ORTInvoker& invoker, - at::TensorList values) { - auto output = std::vector{}; - for (auto element : values) { - output.push_back(create_ort_value(element)); - } - return output; + onnxruntime::ORTInvoker& invoker, + at::TensorList values) { + auto output = std::vector{}; + for (auto element : values) { + output.push_back(create_ort_value(element)); + } + return output; } onnx::AttributeProto create_ort_attribute( - const char* name, - at::Scalar value, - const bool isTensor, - at::ScalarType type) { + const char* name, + at::Scalar value, + const bool isTensor, + at::ScalarType type) { if (isTensor) { onnx::AttributeProto attr; attr.set_name(name); @@ -184,28 +184,28 @@ onnx::AttributeProto create_ort_attribute( // Creating a 1 dim tensor of size 1, so add that dim now. constant_attribute_tensor_proto->add_dims(1); switch (type) { - case at::ScalarType::Float: - constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - *constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to(); - break; - case at::ScalarType::Double: - constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); - *constant_attribute_tensor_proto->mutable_double_data()->Add() = value.to(); - break; - case at::ScalarType::Bool: - case at::ScalarType::Int: - constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); - *constant_attribute_tensor_proto->mutable_int32_data()->Add() = value.to(); - break; - case at::ScalarType::Long: - constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - *constant_attribute_tensor_proto->mutable_int64_data()->Add() = value.to(); - break; - default: - // For most at::ScalarType, it should be safe to just call value.to<> - // on it, but for now we want to explicitly know when we've encountered - // a new scalar type while bringing up ORT eager mode. - ORT_THROW("Unsupported: at::ScalarType::", value.type()); + case at::ScalarType::Float: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + *constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to(); + break; + case at::ScalarType::Double: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); + *constant_attribute_tensor_proto->mutable_double_data()->Add() = value.to(); + break; + case at::ScalarType::Bool: + case at::ScalarType::Int: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + *constant_attribute_tensor_proto->mutable_int32_data()->Add() = value.to(); + break; + case at::ScalarType::Long: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + *constant_attribute_tensor_proto->mutable_int64_data()->Add() = value.to(); + break; + default: + // For most at::ScalarType, it should be safe to just call value.to<> + // on it, but for now we want to explicitly know when we've encountered + // a new scalar type while bringing up ORT eager mode. + ORT_THROW("Unsupported: at::ScalarType::", value.type()); } return attr; } else { @@ -214,16 +214,16 @@ onnx::AttributeProto create_ort_attribute( } onnx::AttributeProto create_ort_attribute( - const char* name, - at::Scalar value, - const bool isTensor) { - return create_ort_attribute(name, value, isTensor, value.type()); + const char* name, + at::Scalar value, + const bool isTensor) { + return create_ort_attribute(name, value, isTensor, value.type()); } onnx::AttributeProto create_ort_attribute( - const char* name, - at::Scalar value, - at::ScalarType type) { + const char* name, + at::Scalar value, + at::ScalarType type) { onnx::AttributeProto attr; attr.set_name(name); switch (type) { @@ -249,8 +249,8 @@ onnx::AttributeProto create_ort_attribute( } onnx::AttributeProto create_ort_attribute( - const char* name, - const char* value) { + const char* name, + const char* value) { onnx::AttributeProto attr; attr.set_name(name); attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING); @@ -259,8 +259,8 @@ onnx::AttributeProto create_ort_attribute( } onnx::AttributeProto create_ort_attribute( - const char* name, - const std::vector values) { + const char* name, + const std::vector values) { onnx::AttributeProto attr; attr.set_name(name); attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS); @@ -331,7 +331,6 @@ static c10::optional PromoteScalarTypes( return st; } - c10::optional PromoteScalarTypesWithCategory( const std::vector& typesFromTensors, const std::vector& typesFromScalars) { @@ -370,15 +369,15 @@ OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at: std::vector output(1); NodeAttributes attrs(2); attrs["to"] = create_ort_attribute( - "to", GetONNXTensorProtoDataType(type), at::ScalarType::Long); + "to", GetONNXTensorProtoDataType(type), at::ScalarType::Long); - auto status = invoker.Invoke("Cast", { - std::move(input), - }, output, &attrs); + auto status = invoker.Invoke("Cast", + {std::move(input)}, + output, &attrs); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); return output[0]; } @@ -389,9 +388,9 @@ OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at: * @param keepdim Whether to retain dim or not. Ignored if dimToReduce is null. */ inline at::DimVector calculate_reduction_shape( - const at::Tensor& self, - c10::optional dimToReduce, - bool keepdim) { + const at::Tensor& self, + c10::optional dimToReduce, + bool keepdim) { at::DimVector shape; // If we have dim value, then reduce that dimension. @@ -423,20 +422,20 @@ inline at::DimVector calculate_reduction_shape( * PyToch implementation of resize will warn about resizing * non-empty and indicate this is deprecated behavior that * can / will change. - * + * * In PyTorch repository see: aten/src/ATen/native/Resize.{h|cpp} */ void resize_output( - onnxruntime::ORTInvoker& invoker, - ORTTensorImpl* output, - at::IntArrayRef shape) { + onnxruntime::ORTInvoker& invoker, + ORTTensorImpl* output, + at::IntArrayRef shape) { if (output->sizes().equals(shape)) { return; } if (output->numel() != 0) { throw std::runtime_error( - "resizing a non-empty output tensor is not supported."); + "resizing a non-empty output tensor is not supported."); } resize_impl_ort_(invoker, output, shape); @@ -492,11 +491,11 @@ void resize_impl_ort_( // Just resize existing tensor and return OrtValue new_ort_value = reshape_invoke( - invoker, - self_ort_value, - size, - // invoke reshape kernel inplace - true); + invoker, + self_ort_value, + size, + // invoke reshape kernel inplace + true); // TODO(jamill): Investigate why reshape_invoke kernel does not update inplace self->set_tensor(new_ort_value); @@ -552,12 +551,12 @@ void resize_impl_ort_( namespace aten { at::Tensor empty_strided( - at::IntArrayRef size, - at::IntArrayRef stride, - c10::optional dtype_opt, - c10::optional layout_opt, // Ignored because there's no ONNX support. - c10::optional device_opt, // Will be ORT by the time this is dispatched. - c10::optional pin_memory_opt) { // Ignored because there's no ONNX support. + at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional dtype_opt, + c10::optional layout_opt, // Ignored because there's no ONNX support. + c10::optional device_opt, // Will be ORT by the time this is dispatched. + c10::optional pin_memory_opt) { // Ignored because there's no ONNX support. ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); OrtValue ot; @@ -568,19 +567,19 @@ at::Tensor empty_strided( invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot, stride.vec()); return aten_tensor_from_ort( - std::move(ot), - at::TensorOptions() - .device(*device_opt) - .dtype(dtype)); + std::move(ot), + at::TensorOptions() + .device(*device_opt) + .dtype(dtype)); } at::Tensor empty_memory_format( - at::IntArrayRef size, - c10::optional dtype_opt, - c10::optional layout_opt, - c10::optional device_opt, - c10::optional pin_memory, - c10::optional memory_format) { // Ignored because there's no ONNX support. + at::IntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory, + c10::optional memory_format) { // Ignored because there's no ONNX support. ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format); // Use the strided impl with default (no strides specified). @@ -589,10 +588,10 @@ at::Tensor empty_memory_format( // aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a) at::Tensor as_strided( - const at::Tensor& self, - at::IntArrayRef size, - at::IntArrayRef stride, - c10::optional storage_offset) { + const at::Tensor& self, + at::IntArrayRef size, + at::IntArrayRef stride, + c10::optional storage_offset) { ORT_LOG_FN(self, size, stride, storage_offset); auto& invoker = GetORTInvoker(self.device()); auto ort_input = create_ort_value(invoker, self); @@ -604,26 +603,26 @@ at::Tensor as_strided( invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault)->Info(), ot, byte_offset, stride.vec()); return aten_tensor_from_ort( - std::move(ot), - self.options()); + std::move(ot), + self.options()); } at::Tensor _reshape_alias( - const at::Tensor& self, - at::IntArrayRef size, - at::IntArrayRef stride) { + const at::Tensor& self, + at::IntArrayRef size, + at::IntArrayRef stride) { ORT_LOG_FN(self, size, stride); // TODO(unknown): support stride auto& invoker = GetORTInvoker(self.device()); auto ort_input = create_ort_value(invoker, self); return aten_tensor_from_ort( - reshape_invoke( - invoker, - ort_input, - size, - // invoke reshape kernel inplace - true), - self.options()); + reshape_invoke( + invoker, + ort_input, + size, + // invoke reshape kernel inplace + true), + self.options()); } at::Tensor view(const at::Tensor& self, at::IntArrayRef size) { @@ -631,27 +630,25 @@ at::Tensor view(const at::Tensor& self, at::IntArrayRef size) { auto& invoker = GetORTInvoker(self.device()); auto ort_input = create_ort_value(invoker, self); return aten_tensor_from_ort( - reshape_invoke( - invoker, - ort_input, - size, - // invoke reshape kernel inplace - true), - self.options()); + reshape_invoke( + invoker, + ort_input, + size, + // invoke reshape kernel inplace + true), + self.options()); } at::Tensor& copy_( - at::Tensor& self, - const at::Tensor& src, - bool non_blocking) { + at::Tensor& self, + const at::Tensor& src, + bool non_blocking) { ORT_LOG_FN(self, src, non_blocking); assert_tensor_supported(self); assert_tensor_supported(src); - auto& invoker = GetORTInvoker(self.device().type() == at::kORT - ? self.device() - : src.device()); + auto& invoker = GetORTInvoker(self.device().type() == at::kORT ? self.device() : src.device()); const auto ort_src = create_ort_value(invoker, src); auto ort_self = create_ort_value(invoker, self); if (self.scalar_type() != src.scalar_type()) { @@ -659,15 +656,15 @@ at::Tensor& copy_( std::vector ort_cast_output(1); onnxruntime::NodeAttributes attrs(1); attrs["to"] = create_ort_attribute( - "to", (int64_t)GetONNXTensorProtoDataType(self.scalar_type()), at::kLong); + "to", (int64_t)GetONNXTensorProtoDataType(self.scalar_type()), at::kLong); - auto status = invoker.Invoke("Cast", { - std::move(ort_src), - }, ort_cast_output, &attrs); + auto status = invoker.Invoke("Cast", + {std::move(ort_src)}, + ort_cast_output, &attrs); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); copy(invoker, ort_cast_output[0], ort_self); } else { @@ -678,16 +675,14 @@ at::Tensor& copy_( } at::Tensor _copy_from_and_resize( - const at::Tensor& self, - const at::Tensor& dst) { + const at::Tensor& self, + const at::Tensor& dst) { ORT_LOG_FN(self, dst); assert_tensor_supported(self); assert_tensor_supported(dst); - auto& invoker = GetORTInvoker(self.device().type() == at::kORT - ? self.device() - : dst.device()); + auto& invoker = GetORTInvoker(self.device().type() == at::kORT ? self.device() : dst.device()); const auto ort_self = create_ort_value(invoker, self); auto ort_dst = create_ort_value(invoker, dst); @@ -710,15 +705,13 @@ at::Tensor& zero_(at::Tensor& self) { std::vector ort_out = {ort_in_self}; - auto status = invoker.Invoke( - "ZeroGradient", { - std::move(ort_in_self), - std::move(flag_val) - }, ort_out, nullptr, onnxruntime::kMSDomain, 1); + auto status = invoker.Invoke("ZeroGradient", + {std::move(ort_in_self), std::move(flag_val)}, + ort_out, nullptr, onnxruntime::kMSDomain, 1); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); return self; } @@ -726,19 +719,19 @@ at::Tensor& zero_(at::Tensor& self) { // TODO(unknown): enhance opgen.py to support inplace binary operations. // aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) at::Tensor& add__Tensor( - at::Tensor& self, - const at::Tensor& other, - const at::Scalar& alpha) { + at::Tensor& self, + const at::Tensor& other, + const at::Scalar& alpha) { ORT_LOG_FN(self, other, alpha); auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}; if ( - !IsSupportedType(alpha, st) || - !IsSupportedType(other, st) || - !IsSupportedType(self, st)) { + !IsSupportedType(alpha, st) || + !IsSupportedType(other, st) || + !IsSupportedType(self, st)) { return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(add__Tensor)>::call(self, other, alpha); + &at::native::cpu_fallback, + ATEN_OP(add__Tensor)>::call(self, other, alpha); } auto& invoker = GetORTInvoker(self.device()); @@ -747,39 +740,37 @@ at::Tensor& add__Tensor( std::vector ort_outputs_0_Mul(1); - auto status = invoker.Invoke("Mul", { - std::move(ort_input_alpha), - std::move(ort_input_other), - }, ort_outputs_0_Mul, nullptr); + auto status = invoker.Invoke("Mul", + {std::move(ort_input_alpha), std::move(ort_input_other)}, + ort_outputs_0_Mul, nullptr); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); auto ort_input_self = create_ort_value(invoker, self); std::vector ort_outputs_1_Add(1); ort_outputs_1_Add[0] = ort_input_self; - status = invoker.Invoke("Add", { - std::move(ort_input_self), - std::move(ort_outputs_0_Mul[0]), - }, ort_outputs_1_Add, nullptr); + status = invoker.Invoke("Add", + {std::move(ort_input_self), std::move(ort_outputs_0_Mul[0])}, + ort_outputs_1_Add, nullptr); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); return self; } // aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a) at::Tensor slice_Tensor( - const at::Tensor& self, - int64_t dim, - c10::optional start, - c10::optional end, - int64_t step) { + const at::Tensor& self, + int64_t dim, + c10::optional start, + c10::optional end, + int64_t step) { ORT_LOG_FN(self, dim, start, end, step); int64_t ndim = self.dim(); if (ndim == 0) { @@ -823,39 +814,39 @@ at::Tensor slice_Tensor( ort_tensor->DataType(), onnxruntime::TensorShape(new_shape), ort_tensor->MutableDataRaw(), invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault)->Info(), ot, byte_offset, new_stride); return aten_tensor_from_ort( - std::move(ot), - self.options()); + std::move(ot), + self.options()); } // aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) at::Tensor& argmax_out( -const at::Tensor& self, -c10::optional dim, -bool keepdim, -// *, -at::Tensor& out) { + const at::Tensor& self, + c10::optional dim, + bool keepdim, + // *, + at::Tensor& out) { ORT_LOG_FN(self, dim, keepdim, out); auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}; if ( - !IsSupportedType(self, st)) { + !IsSupportedType(self, st)) { return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(argmax_out)>::call(self, dim, keepdim, out); + &at::native::cpu_fallback, + ATEN_OP(argmax_out)>::call(self, dim, keepdim, out); } auto& invoker = GetORTInvoker(self.device()); auto ort_input_self = - create_ort_value(invoker, dim.has_value() ? self : self.reshape({-1})); + create_ort_value(invoker, dim.has_value() ? self : self.reshape({-1})); int64_t l_axis = dim.has_value() ? *dim : 0; bool keepdim_effective_value = dim.has_value() ? keepdim : false; NodeAttributes attrs(2); attrs["axis"] = create_ort_attribute( - "axis", l_axis, at::ScalarType::Int); + "axis", l_axis, at::ScalarType::Int); attrs["keepdims"] = create_ort_attribute( - "keepdims", keepdim_effective_value, at::ScalarType::Bool); + "keepdims", keepdim_effective_value, at::ScalarType::Bool); std::vector ort_outputs_0_ArgMax(1); @@ -869,31 +860,31 @@ at::Tensor& out) { auto ort_input_out = create_ort_value(invoker, out); ort_outputs_0_ArgMax[0] = ort_input_out; - auto status = invoker.Invoke("ArgMax", { - std::move(ort_input_self), - }, ort_outputs_0_ArgMax, &attrs); + auto status = invoker.Invoke("ArgMax", + {std::move(ort_input_self)}, + ort_outputs_0_ArgMax, &attrs); if (!status.IsOK()) - throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + throw std::runtime_error( + "ORT return failure status:" + status.ErrorMessage()); return out; } // aten::equal(Tensor self, Tensor other) -> bool bool equal( - const at::Tensor& self, - const at::Tensor& other) { + const at::Tensor& self, + const at::Tensor& other) { ORT_LOG_FN(self, other); if ( - std::vector supportedTypes = - {at::kFloat, at::kBFloat16, at::kHalf, at::kDouble, at::kLong, at::kByte, at::kInt, at::kShort, at::kBool}; - !IsSupportedType(self, supportedTypes) || - !IsSupportedType(other, supportedTypes)) { + std::vector supportedTypes = + {at::kFloat, at::kBFloat16, at::kHalf, at::kDouble, at::kLong, at::kByte, at::kInt, at::kShort, at::kBool}; + !IsSupportedType(self, supportedTypes) || + !IsSupportedType(other, supportedTypes)) { return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(equal)>::call(self, other); + &at::native::cpu_fallback, + ATEN_OP(equal)>::call(self, other); } auto& invoker = GetORTInvoker(self.device()); @@ -914,19 +905,18 @@ bool equal( // being less than true, so any false will reduce to false. std::vector ort_outputs_0_Equal(1); - auto equalStatus = invoker.Invoke("Equal", { - std::move(ort_input_self), - std::move(ort_input_other), - }, ort_outputs_0_Equal, nullptr); + auto equalStatus = invoker.Invoke("Equal", + {std::move(ort_input_self), std::move(ort_input_other)}, + ort_outputs_0_Equal, nullptr); if (!equalStatus.IsOK()) throw std::runtime_error( - "ORT Equal return failure status:" + equalStatus.ErrorMessage()); + "ORT Equal return failure status:" + equalStatus.ErrorMessage()); // now reduce the resulting tensor of bool values to its minimum value (any false) NodeAttributes attrs(1); attrs["keepdims"] = create_ort_attribute( - "keepdims", 0, at::ScalarType::Int); + "keepdims", 0, at::ScalarType::Int); std::vector ort_outputs_0_ReduceMin(1); @@ -934,13 +924,13 @@ bool equal( // GetONNXTensorProtoDataType doesn't support byte, which leaves us with int OrtValue equalAsInt = CastToType(invoker, ort_outputs_0_Equal[0], at::ScalarType::Int); - auto reduceStatus = invoker.Invoke("ReduceMin", { - std::move(equalAsInt), - }, ort_outputs_0_ReduceMin, &attrs); + auto reduceStatus = invoker.Invoke("ReduceMin", + {std::move(equalAsInt)}, + ort_outputs_0_ReduceMin, &attrs); if (!reduceStatus.IsOK()) throw std::runtime_error( - "ORT ReduceMin return failure reduceStatus:" + reduceStatus.ErrorMessage()); + "ORT ReduceMin return failure reduceStatus:" + reduceStatus.ErrorMessage()); auto* ort_tensor = ort_outputs_0_ReduceMin[0].GetMutable(); // the first (and only) value of the tensor will be 0 for false else true @@ -970,18 +960,18 @@ const at::Tensor& resize_( // aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) at::Tensor& fill__Scalar( - at::Tensor& self, - const at::Scalar& value) { + at::Tensor& self, + const at::Scalar& value) { ORT_LOG_FN(self, value); if ( - std::vector supportedTypes = - {at::kHalf, at::kFloat, at::kInt, at::kDouble, at::kByte, at::kShort, at::kLong, at::kBFloat16, at::kBool}; - !IsSupportedType(self, supportedTypes)) { + std::vector supportedTypes = + {at::kHalf, at::kFloat, at::kInt, at::kDouble, at::kByte, at::kShort, at::kLong, at::kBFloat16, at::kBool}; + !IsSupportedType(self, supportedTypes)) { std::cout << "fill__Scalar - Fell back to cpu!\n"; return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(fill__Scalar)>::call(self, value); + &at::native::cpu_fallback, + ATEN_OP(fill__Scalar)>::call(self, value); } auto& invoker = GetORTInvoker(self.device()); @@ -989,37 +979,37 @@ at::Tensor& fill__Scalar( std::vector ort_outputs_0_Shape(1); - auto status = invoker.Invoke("Shape", { - std::move(ort_input_self), - }, ort_outputs_0_Shape, nullptr); + auto status = invoker.Invoke("Shape", + {std::move(ort_input_self)}, + ort_outputs_0_Shape, nullptr); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); std::vector ort_outputs_1_ConstantOfShape(1); ort_outputs_1_ConstantOfShape[0] = ort_input_self; NodeAttributes attrs(1); attrs["value"] = create_ort_attribute( - "value", value, true, self.scalar_type()); + "value", value, true, self.scalar_type()); - status = invoker.Invoke("ConstantOfShape", { - std::move(ort_outputs_0_Shape[0]), - }, ort_outputs_1_ConstantOfShape, &attrs); + status = invoker.Invoke("ConstantOfShape", + {std::move(ort_outputs_0_Shape[0])}, + ort_outputs_1_ConstantOfShape, &attrs); if (!status.IsOK()) throw std::runtime_error( - "ORT return failure status:" + status.ErrorMessage()); + "ORT return failure status:" + status.ErrorMessage()); return self; } // aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) at::Tensor& nonzero_out( - const at::Tensor& self, - // *, - at::Tensor& out) { + const at::Tensor& self, + // *, + at::Tensor& out) { ORT_LOG_FN(self, out); auto temp = eager::aten::nonzero(self); @@ -1036,18 +1026,18 @@ at::Tensor& nonzero_out( // aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) at::Tensor& _log_softmax_out( - const at::Tensor& self, - int64_t dim, - bool half_to_float, - // *, - at::Tensor& out) { + const at::Tensor& self, + int64_t dim, + bool half_to_float, + // *, + at::Tensor& out) { ORT_LOG_FN(self, dim, half_to_float, out); if ( - !IsSupportedType(self, {at::kBFloat16, at::kDouble, at::kFloat, at::kHalf})) { + !IsSupportedType(self, {at::kBFloat16, at::kDouble, at::kFloat, at::kHalf})) { return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out); + &at::native::cpu_fallback, + ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out); } auto& invoker = GetORTInvoker(self.device()); @@ -1074,30 +1064,30 @@ at::Tensor& _log_softmax_out( for (int64_t i = 0; i < ndim; i++) axes.push_back(i); - axes[dim] = ndim-1; - axes[ndim-1] = dim; - dim = ndim-1; + axes[dim] = ndim - 1; + axes[ndim - 1] = dim; + dim = ndim - 1; NodeAttributes attrs_0(1); attrs_0["perm"] = create_ort_attribute("perm", axes); - auto status = invoker.Invoke("Transpose", { - std::move(ort_input_0_self), - }, ort_outputs_0_Transpose, &attrs_0); + auto status = invoker.Invoke("Transpose", + {std::move(ort_input_0_self)}, + ort_outputs_0_Transpose, &attrs_0); CHECK_STATUS(status); } NodeAttributes attrs_1(1); attrs_1["axis"] = create_ort_attribute( - "axis", dim, at::ScalarType::Int); + "axis", dim, at::ScalarType::Int); std::vector ort_outputs_1_LogSoftmax(1); if (!need_transpose) { ort_outputs_1_LogSoftmax[0] = ort_input_out; } - auto status = invoker.Invoke("LogSoftmax", { - std::move(need_transpose ? ort_outputs_0_Transpose[0] : ort_input_0_self), - }, ort_outputs_1_LogSoftmax, &attrs_1); + auto status = invoker.Invoke("LogSoftmax", + {std::move(need_transpose ? ort_outputs_0_Transpose[0] : ort_input_0_self)}, + ort_outputs_1_LogSoftmax, &attrs_1); CHECK_STATUS(status); std::vector ort_outputs_2_Transpose(1); @@ -1108,9 +1098,9 @@ at::Tensor& _log_softmax_out( NodeAttributes attrs_2(1); attrs_2["perm"] = create_ort_attribute("perm", axes); - status = invoker.Invoke("Transpose", { - std::move(ort_outputs_1_LogSoftmax[0]), - }, ort_outputs_2_Transpose, &attrs_2); + status = invoker.Invoke("Transpose", + {std::move(ort_outputs_1_LogSoftmax[0])}, + ort_outputs_2_Transpose, &attrs_2); CHECK_STATUS(status); } @@ -1121,28 +1111,28 @@ at::Tensor& _log_softmax_out( // mm is for matrix multiplication and does not broadcast. // https://pytorch.org/docs/stable/generated/torch.mm.html at::Tensor& mm_out( - const at::Tensor& self, - const at::Tensor& mat2, - // *, - at::Tensor& out) { + const at::Tensor& self, + const at::Tensor& mat2, + // *, + at::Tensor& out) { ORT_LOG_FN(self, mat2, out); if ( - std::vector supportedTypes = - {at::kDouble, at::kLong, at::kHalf, at::kFloat, at::kBFloat16, at::kInt}; - !IsSupportedType(self, supportedTypes) || - !IsSupportedType(mat2, supportedTypes) || - // to match cpu device behavior for torch.mm, verify the following and fall back to cpu to generate error message. - // 1. self and mat2 must be 2-D (matrices) - self.dim() != 2 || mat2.dim() != 2 || - // 2. self and mat2 can be multiplied - self.sizes()[1] != mat2.sizes()[0] || - // 3. self, mat2, and out are of the same type - self.scalar_type() != out.scalar_type() || - self.scalar_type() != mat2.scalar_type()) { + std::vector supportedTypes = + {at::kDouble, at::kLong, at::kHalf, at::kFloat, at::kBFloat16, at::kInt}; + !IsSupportedType(self, supportedTypes) || + !IsSupportedType(mat2, supportedTypes) || + // to match cpu device behavior for torch.mm, verify the following and fall back to cpu to generate error message. + // 1. self and mat2 must be 2-D (matrices) + self.dim() != 2 || mat2.dim() != 2 || + // 2. self and mat2 can be multiplied + self.sizes()[1] != mat2.sizes()[0] || + // 3. self, mat2, and out are of the same type + self.scalar_type() != out.scalar_type() || + self.scalar_type() != mat2.scalar_type()) { return at::native::call_fallback_fn< - &at::native::cpu_fallback, - ATEN_OP(mm_out)>::call(self, mat2, out); + &at::native::cpu_fallback, + ATEN_OP(mm_out)>::call(self, mat2, out); } auto& invoker = GetORTInvoker(self.device()); @@ -1165,16 +1155,14 @@ at::Tensor& mm_out( std::vector ort_outputs_0_MatMul(1); ort_outputs_0_MatMul[0] = ort_input_out; - auto status = invoker.Invoke("MatMul", { - std::move(ort_input_0_self), - std::move(ort_input_0_mat2), - }, ort_outputs_0_MatMul, nullptr); + auto status = invoker.Invoke("MatMul", + {std::move(ort_input_0_self), std::move(ort_input_0_mat2)}, + ort_outputs_0_MatMul, nullptr); CHECK_STATUS(status); return out; } - } // namespace aten // #pragma endregion