mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Apply project formatting rules to ort_aten.cpp (#12294)
* Apply project formatting rules to ort_aten.cpp
Formatting applied by formatting the file in VS Code.
This file is under active development and the inconsistent formatting
was causing friction due to:
1. cpplint job on Pipeline was flagging a lot of style issues,
resulting in a lot of noisy annotations.
2. local edits would result in changes that are not part of the core change.
While there are other files in this part of the source tree with
inconsistent formatting, this file was causing the most friction. We can
come back and address the other files later, which would be a much
larger change.
* Apply consistent pattern for invoker.Invoke(...)
This commit is contained in:
parent
0fa3aeb65c
commit
8d0e86dec8
1 changed file with 281 additions and 293 deletions
|
|
@ -2,17 +2,17 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "ort_aten.h"
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <ATen/native/CPUFallback.h>
|
||||
#include <ATen/InferSize.h>
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <ATen/InferSize.h>
|
||||
#include <ATen/native/CPUFallback.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace torch_ort {
|
||||
namespace eager {
|
||||
|
|
@ -20,48 +20,48 @@ namespace eager {
|
|||
// #pragma region Helpers
|
||||
using NodeAttributes = onnxruntime::NodeAttributes;
|
||||
namespace {
|
||||
inline bool is_device_supported(at::DeviceType type) {
|
||||
return type == at::kORT || type == at::kCPU;
|
||||
inline bool is_device_supported(at::DeviceType type) {
|
||||
return type == at::kORT || type == at::kCPU;
|
||||
}
|
||||
|
||||
inline void assert_tensor_supported(const at::Tensor& tensor) {
|
||||
if (tensor.is_sparse()) {
|
||||
throw std::runtime_error("ORT copy: sparse not supported");
|
||||
}
|
||||
|
||||
inline void assert_tensor_supported(const at::Tensor& tensor) {
|
||||
if (tensor.is_sparse()) {
|
||||
throw std::runtime_error("ORT copy: sparse not supported");
|
||||
}
|
||||
|
||||
if (tensor.is_quantized()) {
|
||||
throw std::runtime_error("ORT copy: quantized not supported");
|
||||
}
|
||||
|
||||
if (!is_device_supported(tensor.device().type())) {
|
||||
throw std::runtime_error("ORT copy: device not supported");
|
||||
}
|
||||
if (tensor.is_quantized()) {
|
||||
throw std::runtime_error("ORT copy: quantized not supported");
|
||||
}
|
||||
|
||||
if (!is_device_supported(tensor.device().type())) {
|
||||
throw std::runtime_error("ORT copy: device not supported");
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
at::Tensor aten_tensor_from_ort(
|
||||
OrtValue&& ot,
|
||||
const at::TensorOptions& options) {
|
||||
OrtValue&& ot,
|
||||
const at::TensorOptions& options) {
|
||||
return at::Tensor(c10::make_intrusive<ORTTensorImpl>(
|
||||
std::move(ot),
|
||||
options));
|
||||
std::move(ot),
|
||||
options));
|
||||
}
|
||||
|
||||
const std::vector<at::Tensor> aten_tensor_from_ort(
|
||||
std::vector<OrtValue>& ortvalues,
|
||||
const at::TensorOptions& options) {
|
||||
const size_t num_outputs = ortvalues.size();
|
||||
std::vector<at::Tensor> atvalues = std::vector<at::Tensor>(num_outputs);
|
||||
for (size_t i = 0; i < num_outputs; i++) {
|
||||
atvalues[i] = at::Tensor(c10::make_intrusive<ORTTensorImpl>(
|
||||
std::vector<OrtValue>& ortvalues,
|
||||
const at::TensorOptions& options) {
|
||||
const size_t num_outputs = ortvalues.size();
|
||||
std::vector<at::Tensor> atvalues = std::vector<at::Tensor>(num_outputs);
|
||||
for (size_t i = 0; i < num_outputs; i++) {
|
||||
atvalues[i] = at::Tensor(c10::make_intrusive<ORTTensorImpl>(
|
||||
std::move(ortvalues[i]),
|
||||
options));
|
||||
}
|
||||
return atvalues;
|
||||
}
|
||||
return atvalues;
|
||||
}
|
||||
|
||||
onnxruntime::MLDataType ort_scalar_type_from_aten(
|
||||
at::ScalarType dtype) {
|
||||
at::ScalarType dtype) {
|
||||
switch (dtype) {
|
||||
case at::kFloat:
|
||||
return onnxruntime::DataTypeImpl::GetType<float>();
|
||||
|
|
@ -85,15 +85,15 @@ onnxruntime::MLDataType ort_scalar_type_from_aten(
|
|||
}
|
||||
|
||||
OrtValue create_ort_value(
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
const at::Scalar& scalar) {
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
const at::Scalar& scalar) {
|
||||
return create_ort_value(invoker, scalar, scalar.type());
|
||||
}
|
||||
|
||||
OrtValue create_ort_value(
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
const at::Scalar& scalar,
|
||||
at::ScalarType type) {
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
const at::Scalar& scalar,
|
||||
at::ScalarType type) {
|
||||
OrtValue ort_val;
|
||||
onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(type), onnxruntime::TensorShape({}),
|
||||
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ort_val);
|
||||
|
|
@ -131,8 +131,8 @@ OrtValue create_ort_value(
|
|||
}
|
||||
|
||||
OrtValue create_ort_value(
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
const at::Tensor& tensor) {
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
const at::Tensor& tensor) {
|
||||
assert_tensor_supported(tensor);
|
||||
|
||||
auto* impl = dynamic_cast<ORTTensorImpl*>(tensor.unsafeGetTensorImpl());
|
||||
|
|
@ -140,18 +140,18 @@ OrtValue create_ort_value(
|
|||
return impl->tensor();
|
||||
}
|
||||
|
||||
OrtMemoryInfo *mem_info;
|
||||
OrtMemoryInfo* mem_info;
|
||||
Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info));
|
||||
auto element_type = ort_scalar_type_from_aten(tensor.scalar_type());
|
||||
|
||||
OrtValue ort_tensor;
|
||||
onnxruntime::Tensor::InitOrtValue(
|
||||
element_type,
|
||||
onnxruntime::TensorShape(tensor.sizes().vec()),
|
||||
tensor.data_ptr(),
|
||||
*mem_info, ort_tensor,
|
||||
0L, // offset = 0 - because tensor.data_ptr() includes the underyling offset
|
||||
tensor.strides().vec());
|
||||
element_type,
|
||||
onnxruntime::TensorShape(tensor.sizes().vec()),
|
||||
tensor.data_ptr(),
|
||||
*mem_info, ort_tensor,
|
||||
0L, // offset = 0 - because tensor.data_ptr() includes the underlying offset
|
||||
tensor.strides().vec());
|
||||
return ort_tensor;
|
||||
}
|
||||
|
||||
|
|
@ -161,20 +161,20 @@ OrtValue create_ort_value(const at::Tensor& tensor) {
|
|||
}
|
||||
|
||||
std::vector<OrtValue> create_ort_value(
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
at::TensorList values) {
|
||||
auto output = std::vector<OrtValue>{};
|
||||
for (auto element : values) {
|
||||
output.push_back(create_ort_value(element));
|
||||
}
|
||||
return output;
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
at::TensorList values) {
|
||||
auto output = std::vector<OrtValue>{};
|
||||
for (auto element : values) {
|
||||
output.push_back(create_ort_value(element));
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
at::Scalar value,
|
||||
const bool isTensor,
|
||||
at::ScalarType type) {
|
||||
const char* name,
|
||||
at::Scalar value,
|
||||
const bool isTensor,
|
||||
at::ScalarType type) {
|
||||
if (isTensor) {
|
||||
onnx::AttributeProto attr;
|
||||
attr.set_name(name);
|
||||
|
|
@ -184,28 +184,28 @@ onnx::AttributeProto create_ort_attribute(
|
|||
// Creating a 1 dim tensor of size 1, so add that dim now.
|
||||
constant_attribute_tensor_proto->add_dims(1);
|
||||
switch (type) {
|
||||
case at::ScalarType::Float:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<float>();
|
||||
break;
|
||||
case at::ScalarType::Double:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE);
|
||||
*constant_attribute_tensor_proto->mutable_double_data()->Add() = value.to<double>();
|
||||
break;
|
||||
case at::ScalarType::Bool:
|
||||
case at::ScalarType::Int:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
*constant_attribute_tensor_proto->mutable_int32_data()->Add() = value.to<int>();
|
||||
break;
|
||||
case at::ScalarType::Long:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
*constant_attribute_tensor_proto->mutable_int64_data()->Add() = value.to<int64_t>();
|
||||
break;
|
||||
default:
|
||||
// For most at::ScalarType, it should be safe to just call value.to<>
|
||||
// on it, but for now we want to explicitly know when we've encountered
|
||||
// a new scalar type while bringing up ORT eager mode.
|
||||
ORT_THROW("Unsupported: at::ScalarType::", value.type());
|
||||
case at::ScalarType::Float:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<float>();
|
||||
break;
|
||||
case at::ScalarType::Double:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE);
|
||||
*constant_attribute_tensor_proto->mutable_double_data()->Add() = value.to<double>();
|
||||
break;
|
||||
case at::ScalarType::Bool:
|
||||
case at::ScalarType::Int:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
*constant_attribute_tensor_proto->mutable_int32_data()->Add() = value.to<int>();
|
||||
break;
|
||||
case at::ScalarType::Long:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
*constant_attribute_tensor_proto->mutable_int64_data()->Add() = value.to<int64_t>();
|
||||
break;
|
||||
default:
|
||||
// For most at::ScalarType, it should be safe to just call value.to<>
|
||||
// on it, but for now we want to explicitly know when we've encountered
|
||||
// a new scalar type while bringing up ORT eager mode.
|
||||
ORT_THROW("Unsupported: at::ScalarType::", value.type());
|
||||
}
|
||||
return attr;
|
||||
} else {
|
||||
|
|
@ -214,16 +214,16 @@ onnx::AttributeProto create_ort_attribute(
|
|||
}
|
||||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
at::Scalar value,
|
||||
const bool isTensor) {
|
||||
return create_ort_attribute(name, value, isTensor, value.type());
|
||||
const char* name,
|
||||
at::Scalar value,
|
||||
const bool isTensor) {
|
||||
return create_ort_attribute(name, value, isTensor, value.type());
|
||||
}
|
||||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
at::Scalar value,
|
||||
at::ScalarType type) {
|
||||
const char* name,
|
||||
at::Scalar value,
|
||||
at::ScalarType type) {
|
||||
onnx::AttributeProto attr;
|
||||
attr.set_name(name);
|
||||
switch (type) {
|
||||
|
|
@ -249,8 +249,8 @@ onnx::AttributeProto create_ort_attribute(
|
|||
}
|
||||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
const char* value) {
|
||||
const char* name,
|
||||
const char* value) {
|
||||
onnx::AttributeProto attr;
|
||||
attr.set_name(name);
|
||||
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING);
|
||||
|
|
@ -259,8 +259,8 @@ onnx::AttributeProto create_ort_attribute(
|
|||
}
|
||||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
const std::vector<int64_t> values) {
|
||||
const char* name,
|
||||
const std::vector<int64_t> values) {
|
||||
onnx::AttributeProto attr;
|
||||
attr.set_name(name);
|
||||
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS);
|
||||
|
|
@ -331,7 +331,6 @@ static c10::optional<at::ScalarType> PromoteScalarTypes(
|
|||
return st;
|
||||
}
|
||||
|
||||
|
||||
c10::optional<at::ScalarType> PromoteScalarTypesWithCategory(
|
||||
const std::vector<at::ScalarType>& typesFromTensors,
|
||||
const std::vector<at::ScalarType>& typesFromScalars) {
|
||||
|
|
@ -370,15 +369,15 @@ OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at:
|
|||
std::vector<OrtValue> output(1);
|
||||
NodeAttributes attrs(2);
|
||||
attrs["to"] = create_ort_attribute(
|
||||
"to", GetONNXTensorProtoDataType(type), at::ScalarType::Long);
|
||||
"to", GetONNXTensorProtoDataType(type), at::ScalarType::Long);
|
||||
|
||||
auto status = invoker.Invoke("Cast", {
|
||||
std::move(input),
|
||||
}, output, &attrs);
|
||||
auto status = invoker.Invoke("Cast",
|
||||
{std::move(input)},
|
||||
output, &attrs);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
return output[0];
|
||||
}
|
||||
|
||||
|
|
@ -389,9 +388,9 @@ OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at:
|
|||
* @param keepdim Whether to retain dim or not. Ignored if dimToReduce is null.
|
||||
*/
|
||||
inline at::DimVector calculate_reduction_shape(
|
||||
const at::Tensor& self,
|
||||
c10::optional<int64_t> dimToReduce,
|
||||
bool keepdim) {
|
||||
const at::Tensor& self,
|
||||
c10::optional<int64_t> dimToReduce,
|
||||
bool keepdim) {
|
||||
at::DimVector shape;
|
||||
|
||||
// If we have dim value, then reduce that dimension.
|
||||
|
|
@ -423,20 +422,20 @@ inline at::DimVector calculate_reduction_shape(
|
|||
* PyTorch implementation of resize will warn about resizing
|
||||
* non-empty and indicate this is deprecated behavior that
|
||||
* can / will change.
|
||||
*
|
||||
*
|
||||
* In PyTorch repository see: aten/src/ATen/native/Resize.{h|cpp}
|
||||
*/
|
||||
void resize_output(
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
ORTTensorImpl* output,
|
||||
at::IntArrayRef shape) {
|
||||
onnxruntime::ORTInvoker& invoker,
|
||||
ORTTensorImpl* output,
|
||||
at::IntArrayRef shape) {
|
||||
if (output->sizes().equals(shape)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (output->numel() != 0) {
|
||||
throw std::runtime_error(
|
||||
"resizing a non-empty output tensor is not supported.");
|
||||
"resizing a non-empty output tensor is not supported.");
|
||||
}
|
||||
|
||||
resize_impl_ort_(invoker, output, shape);
|
||||
|
|
@ -492,11 +491,11 @@ void resize_impl_ort_(
|
|||
// Just resize existing tensor and return
|
||||
|
||||
OrtValue new_ort_value = reshape_invoke(
|
||||
invoker,
|
||||
self_ort_value,
|
||||
size,
|
||||
// invoke reshape kernel inplace
|
||||
true);
|
||||
invoker,
|
||||
self_ort_value,
|
||||
size,
|
||||
// invoke reshape kernel inplace
|
||||
true);
|
||||
|
||||
// TODO(jamill): Investigate why reshape_invoke kernel does not update inplace
|
||||
self->set_tensor(new_ort_value);
|
||||
|
|
@ -552,12 +551,12 @@ void resize_impl_ort_(
|
|||
namespace aten {
|
||||
|
||||
at::Tensor empty_strided(
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride,
|
||||
c10::optional<at::ScalarType> dtype_opt,
|
||||
c10::optional<at::Layout> layout_opt, // Ignored because there's no ONNX support.
|
||||
c10::optional<at::Device> device_opt, // Will be ORT by the time this is dispatched.
|
||||
c10::optional<bool> pin_memory_opt) { // Ignored because there's no ONNX support.
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride,
|
||||
c10::optional<at::ScalarType> dtype_opt,
|
||||
c10::optional<at::Layout> layout_opt, // Ignored because there's no ONNX support.
|
||||
c10::optional<at::Device> device_opt, // Will be ORT by the time this is dispatched.
|
||||
c10::optional<bool> pin_memory_opt) { // Ignored because there's no ONNX support.
|
||||
ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
|
||||
OrtValue ot;
|
||||
|
|
@ -568,19 +567,19 @@ at::Tensor empty_strided(
|
|||
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot,
|
||||
stride.vec());
|
||||
return aten_tensor_from_ort(
|
||||
std::move(ot),
|
||||
at::TensorOptions()
|
||||
.device(*device_opt)
|
||||
.dtype(dtype));
|
||||
std::move(ot),
|
||||
at::TensorOptions()
|
||||
.device(*device_opt)
|
||||
.dtype(dtype));
|
||||
}
|
||||
|
||||
at::Tensor empty_memory_format(
|
||||
at::IntArrayRef size,
|
||||
c10::optional<at::ScalarType> dtype_opt,
|
||||
c10::optional<at::Layout> layout_opt,
|
||||
c10::optional<at::Device> device_opt,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<at::MemoryFormat> memory_format) { // Ignored because there's no ONNX support.
|
||||
at::IntArrayRef size,
|
||||
c10::optional<at::ScalarType> dtype_opt,
|
||||
c10::optional<at::Layout> layout_opt,
|
||||
c10::optional<at::Device> device_opt,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<at::MemoryFormat> memory_format) { // Ignored because there's no ONNX support.
|
||||
ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format);
|
||||
|
||||
// Use the strided impl with default (no strides specified).
|
||||
|
|
@ -589,10 +588,10 @@ at::Tensor empty_memory_format(
|
|||
|
||||
// aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
|
||||
at::Tensor as_strided(
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride,
|
||||
c10::optional<int64_t> storage_offset) {
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride,
|
||||
c10::optional<int64_t> storage_offset) {
|
||||
ORT_LOG_FN(self, size, stride, storage_offset);
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
auto ort_input = create_ort_value(invoker, self);
|
||||
|
|
@ -604,26 +603,26 @@ at::Tensor as_strided(
|
|||
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault)->Info(),
|
||||
ot, byte_offset, stride.vec());
|
||||
return aten_tensor_from_ort(
|
||||
std::move(ot),
|
||||
self.options());
|
||||
std::move(ot),
|
||||
self.options());
|
||||
}
|
||||
|
||||
at::Tensor _reshape_alias(
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride) {
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride) {
|
||||
ORT_LOG_FN(self, size, stride);
|
||||
// TODO(unknown): support stride
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
auto ort_input = create_ort_value(invoker, self);
|
||||
return aten_tensor_from_ort(
|
||||
reshape_invoke(
|
||||
invoker,
|
||||
ort_input,
|
||||
size,
|
||||
// invoke reshape kernel inplace
|
||||
true),
|
||||
self.options());
|
||||
reshape_invoke(
|
||||
invoker,
|
||||
ort_input,
|
||||
size,
|
||||
// invoke reshape kernel inplace
|
||||
true),
|
||||
self.options());
|
||||
}
|
||||
|
||||
at::Tensor view(const at::Tensor& self, at::IntArrayRef size) {
|
||||
|
|
@ -631,27 +630,25 @@ at::Tensor view(const at::Tensor& self, at::IntArrayRef size) {
|
|||
auto& invoker = GetORTInvoker(self.device());
|
||||
auto ort_input = create_ort_value(invoker, self);
|
||||
return aten_tensor_from_ort(
|
||||
reshape_invoke(
|
||||
invoker,
|
||||
ort_input,
|
||||
size,
|
||||
// invoke reshape kernel inplace
|
||||
true),
|
||||
self.options());
|
||||
reshape_invoke(
|
||||
invoker,
|
||||
ort_input,
|
||||
size,
|
||||
// invoke reshape kernel inplace
|
||||
true),
|
||||
self.options());
|
||||
}
|
||||
|
||||
at::Tensor& copy_(
|
||||
at::Tensor& self,
|
||||
const at::Tensor& src,
|
||||
bool non_blocking) {
|
||||
at::Tensor& self,
|
||||
const at::Tensor& src,
|
||||
bool non_blocking) {
|
||||
ORT_LOG_FN(self, src, non_blocking);
|
||||
|
||||
assert_tensor_supported(self);
|
||||
assert_tensor_supported(src);
|
||||
|
||||
auto& invoker = GetORTInvoker(self.device().type() == at::kORT
|
||||
? self.device()
|
||||
: src.device());
|
||||
auto& invoker = GetORTInvoker(self.device().type() == at::kORT ? self.device() : src.device());
|
||||
const auto ort_src = create_ort_value(invoker, src);
|
||||
auto ort_self = create_ort_value(invoker, self);
|
||||
if (self.scalar_type() != src.scalar_type()) {
|
||||
|
|
@ -659,15 +656,15 @@ at::Tensor& copy_(
|
|||
std::vector<OrtValue> ort_cast_output(1);
|
||||
onnxruntime::NodeAttributes attrs(1);
|
||||
attrs["to"] = create_ort_attribute(
|
||||
"to", (int64_t)GetONNXTensorProtoDataType(self.scalar_type()), at::kLong);
|
||||
"to", (int64_t)GetONNXTensorProtoDataType(self.scalar_type()), at::kLong);
|
||||
|
||||
auto status = invoker.Invoke("Cast", {
|
||||
std::move(ort_src),
|
||||
}, ort_cast_output, &attrs);
|
||||
auto status = invoker.Invoke("Cast",
|
||||
{std::move(ort_src)},
|
||||
ort_cast_output, &attrs);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
copy(invoker, ort_cast_output[0], ort_self);
|
||||
} else {
|
||||
|
|
@ -678,16 +675,14 @@ at::Tensor& copy_(
|
|||
}
|
||||
|
||||
at::Tensor _copy_from_and_resize(
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& dst) {
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& dst) {
|
||||
ORT_LOG_FN(self, dst);
|
||||
|
||||
assert_tensor_supported(self);
|
||||
assert_tensor_supported(dst);
|
||||
|
||||
auto& invoker = GetORTInvoker(self.device().type() == at::kORT
|
||||
? self.device()
|
||||
: dst.device());
|
||||
auto& invoker = GetORTInvoker(self.device().type() == at::kORT ? self.device() : dst.device());
|
||||
const auto ort_self = create_ort_value(invoker, self);
|
||||
auto ort_dst = create_ort_value(invoker, dst);
|
||||
|
||||
|
|
@ -710,15 +705,13 @@ at::Tensor& zero_(at::Tensor& self) {
|
|||
|
||||
std::vector<OrtValue> ort_out = {ort_in_self};
|
||||
|
||||
auto status = invoker.Invoke(
|
||||
"ZeroGradient", {
|
||||
std::move(ort_in_self),
|
||||
std::move(flag_val)
|
||||
}, ort_out, nullptr, onnxruntime::kMSDomain, 1);
|
||||
auto status = invoker.Invoke("ZeroGradient",
|
||||
{std::move(ort_in_self), std::move(flag_val)},
|
||||
ort_out, nullptr, onnxruntime::kMSDomain, 1);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
return self;
|
||||
}
|
||||
|
|
@ -726,19 +719,19 @@ at::Tensor& zero_(at::Tensor& self) {
|
|||
// TODO(unknown): enhance opgen.py to support inplace binary operations.
|
||||
// aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
|
||||
at::Tensor& add__Tensor(
|
||||
at::Tensor& self,
|
||||
const at::Tensor& other,
|
||||
const at::Scalar& alpha) {
|
||||
at::Tensor& self,
|
||||
const at::Tensor& other,
|
||||
const at::Scalar& alpha) {
|
||||
ORT_LOG_FN(self, other, alpha);
|
||||
|
||||
auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
|
||||
if (
|
||||
!IsSupportedType(alpha, st) ||
|
||||
!IsSupportedType(other, st) ||
|
||||
!IsSupportedType(self, st)) {
|
||||
!IsSupportedType(alpha, st) ||
|
||||
!IsSupportedType(other, st) ||
|
||||
!IsSupportedType(self, st)) {
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(add__Tensor)>::call(self, other, alpha);
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(add__Tensor)>::call(self, other, alpha);
|
||||
}
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
|
||||
|
|
@ -747,39 +740,37 @@ at::Tensor& add__Tensor(
|
|||
|
||||
std::vector<OrtValue> ort_outputs_0_Mul(1);
|
||||
|
||||
auto status = invoker.Invoke("Mul", {
|
||||
std::move(ort_input_alpha),
|
||||
std::move(ort_input_other),
|
||||
}, ort_outputs_0_Mul, nullptr);
|
||||
auto status = invoker.Invoke("Mul",
|
||||
{std::move(ort_input_alpha), std::move(ort_input_other)},
|
||||
ort_outputs_0_Mul, nullptr);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
auto ort_input_self = create_ort_value(invoker, self);
|
||||
|
||||
std::vector<OrtValue> ort_outputs_1_Add(1);
|
||||
ort_outputs_1_Add[0] = ort_input_self;
|
||||
|
||||
status = invoker.Invoke("Add", {
|
||||
std::move(ort_input_self),
|
||||
std::move(ort_outputs_0_Mul[0]),
|
||||
}, ort_outputs_1_Add, nullptr);
|
||||
status = invoker.Invoke("Add",
|
||||
{std::move(ort_input_self), std::move(ort_outputs_0_Mul[0])},
|
||||
ort_outputs_1_Add, nullptr);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
// aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)
|
||||
at::Tensor slice_Tensor(
|
||||
const at::Tensor& self,
|
||||
int64_t dim,
|
||||
c10::optional<int64_t> start,
|
||||
c10::optional<int64_t> end,
|
||||
int64_t step) {
|
||||
const at::Tensor& self,
|
||||
int64_t dim,
|
||||
c10::optional<int64_t> start,
|
||||
c10::optional<int64_t> end,
|
||||
int64_t step) {
|
||||
ORT_LOG_FN(self, dim, start, end, step);
|
||||
int64_t ndim = self.dim();
|
||||
if (ndim == 0) {
|
||||
|
|
@ -823,39 +814,39 @@ at::Tensor slice_Tensor(
|
|||
ort_tensor->DataType(), onnxruntime::TensorShape(new_shape), ort_tensor->MutableDataRaw(),
|
||||
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault)->Info(), ot, byte_offset, new_stride);
|
||||
return aten_tensor_from_ort(
|
||||
std::move(ot),
|
||||
self.options());
|
||||
std::move(ot),
|
||||
self.options());
|
||||
}
|
||||
|
||||
// aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||
at::Tensor& argmax_out(
|
||||
const at::Tensor& self,
|
||||
c10::optional<int64_t> dim,
|
||||
bool keepdim,
|
||||
// *,
|
||||
at::Tensor& out) {
|
||||
const at::Tensor& self,
|
||||
c10::optional<int64_t> dim,
|
||||
bool keepdim,
|
||||
// *,
|
||||
at::Tensor& out) {
|
||||
ORT_LOG_FN(self, dim, keepdim, out);
|
||||
|
||||
auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
|
||||
if (
|
||||
!IsSupportedType(self, st)) {
|
||||
!IsSupportedType(self, st)) {
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(argmax_out)>::call(self, dim, keepdim, out);
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(argmax_out)>::call(self, dim, keepdim, out);
|
||||
}
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
|
||||
auto ort_input_self =
|
||||
create_ort_value(invoker, dim.has_value() ? self : self.reshape({-1}));
|
||||
create_ort_value(invoker, dim.has_value() ? self : self.reshape({-1}));
|
||||
|
||||
int64_t l_axis = dim.has_value() ? *dim : 0;
|
||||
bool keepdim_effective_value = dim.has_value() ? keepdim : false;
|
||||
|
||||
NodeAttributes attrs(2);
|
||||
attrs["axis"] = create_ort_attribute(
|
||||
"axis", l_axis, at::ScalarType::Int);
|
||||
"axis", l_axis, at::ScalarType::Int);
|
||||
attrs["keepdims"] = create_ort_attribute(
|
||||
"keepdims", keepdim_effective_value, at::ScalarType::Bool);
|
||||
"keepdims", keepdim_effective_value, at::ScalarType::Bool);
|
||||
|
||||
std::vector<OrtValue> ort_outputs_0_ArgMax(1);
|
||||
|
||||
|
|
@ -869,31 +860,31 @@ at::Tensor& out) {
|
|||
auto ort_input_out = create_ort_value(invoker, out);
|
||||
ort_outputs_0_ArgMax[0] = ort_input_out;
|
||||
|
||||
auto status = invoker.Invoke("ArgMax", {
|
||||
std::move(ort_input_self),
|
||||
}, ort_outputs_0_ArgMax, &attrs);
|
||||
auto status = invoker.Invoke("ArgMax",
|
||||
{std::move(ort_input_self)},
|
||||
ort_outputs_0_ArgMax, &attrs);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
// aten::equal(Tensor self, Tensor other) -> bool
|
||||
bool equal(
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& other) {
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& other) {
|
||||
ORT_LOG_FN(self, other);
|
||||
|
||||
if (
|
||||
std::vector<at::ScalarType> supportedTypes =
|
||||
{at::kFloat, at::kBFloat16, at::kHalf, at::kDouble, at::kLong, at::kByte, at::kInt, at::kShort, at::kBool};
|
||||
!IsSupportedType(self, supportedTypes) ||
|
||||
!IsSupportedType(other, supportedTypes)) {
|
||||
std::vector<at::ScalarType> supportedTypes =
|
||||
{at::kFloat, at::kBFloat16, at::kHalf, at::kDouble, at::kLong, at::kByte, at::kInt, at::kShort, at::kBool};
|
||||
!IsSupportedType(self, supportedTypes) ||
|
||||
!IsSupportedType(other, supportedTypes)) {
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(equal)>::call(self, other);
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(equal)>::call(self, other);
|
||||
}
|
||||
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
|
|
@ -914,19 +905,18 @@ bool equal(
|
|||
// being less than true, so any false will reduce to false.
|
||||
std::vector<OrtValue> ort_outputs_0_Equal(1);
|
||||
|
||||
auto equalStatus = invoker.Invoke("Equal", {
|
||||
std::move(ort_input_self),
|
||||
std::move(ort_input_other),
|
||||
}, ort_outputs_0_Equal, nullptr);
|
||||
auto equalStatus = invoker.Invoke("Equal",
|
||||
{std::move(ort_input_self), std::move(ort_input_other)},
|
||||
ort_outputs_0_Equal, nullptr);
|
||||
|
||||
if (!equalStatus.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT Equal return failure status:" + equalStatus.ErrorMessage());
|
||||
"ORT Equal return failure status:" + equalStatus.ErrorMessage());
|
||||
|
||||
// now reduce the resulting tensor of bool values to its minimum value (any false)
|
||||
NodeAttributes attrs(1);
|
||||
attrs["keepdims"] = create_ort_attribute(
|
||||
"keepdims", 0, at::ScalarType::Int);
|
||||
"keepdims", 0, at::ScalarType::Int);
|
||||
|
||||
std::vector<OrtValue> ort_outputs_0_ReduceMin(1);
|
||||
|
||||
|
|
@ -934,13 +924,13 @@ bool equal(
|
|||
// GetONNXTensorProtoDataType doesn't support byte, which leaves us with int
|
||||
OrtValue equalAsInt = CastToType(invoker, ort_outputs_0_Equal[0], at::ScalarType::Int);
|
||||
|
||||
auto reduceStatus = invoker.Invoke("ReduceMin", {
|
||||
std::move(equalAsInt),
|
||||
}, ort_outputs_0_ReduceMin, &attrs);
|
||||
auto reduceStatus = invoker.Invoke("ReduceMin",
|
||||
{std::move(equalAsInt)},
|
||||
ort_outputs_0_ReduceMin, &attrs);
|
||||
|
||||
if (!reduceStatus.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT ReduceMin return failure reduceStatus:" + reduceStatus.ErrorMessage());
|
||||
"ORT ReduceMin return failure reduceStatus:" + reduceStatus.ErrorMessage());
|
||||
|
||||
auto* ort_tensor = ort_outputs_0_ReduceMin[0].GetMutable<onnxruntime::Tensor>();
|
||||
// the first (and only) value of the tensor will be 0 for false else true
|
||||
|
|
@ -970,18 +960,18 @@ const at::Tensor& resize_(
|
|||
|
||||
// aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
|
||||
at::Tensor& fill__Scalar(
|
||||
at::Tensor& self,
|
||||
const at::Scalar& value) {
|
||||
at::Tensor& self,
|
||||
const at::Scalar& value) {
|
||||
ORT_LOG_FN(self, value);
|
||||
|
||||
if (
|
||||
std::vector<at::ScalarType> supportedTypes =
|
||||
{at::kHalf, at::kFloat, at::kInt, at::kDouble, at::kByte, at::kShort, at::kLong, at::kBFloat16, at::kBool};
|
||||
!IsSupportedType(self, supportedTypes)) {
|
||||
std::vector<at::ScalarType> supportedTypes =
|
||||
{at::kHalf, at::kFloat, at::kInt, at::kDouble, at::kByte, at::kShort, at::kLong, at::kBFloat16, at::kBool};
|
||||
!IsSupportedType(self, supportedTypes)) {
|
||||
std::cout << "fill__Scalar - Fell back to cpu!\n";
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(fill__Scalar)>::call(self, value);
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(fill__Scalar)>::call(self, value);
|
||||
}
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
|
||||
|
|
@ -989,37 +979,37 @@ at::Tensor& fill__Scalar(
|
|||
|
||||
std::vector<OrtValue> ort_outputs_0_Shape(1);
|
||||
|
||||
auto status = invoker.Invoke("Shape", {
|
||||
std::move(ort_input_self),
|
||||
}, ort_outputs_0_Shape, nullptr);
|
||||
auto status = invoker.Invoke("Shape",
|
||||
{std::move(ort_input_self)},
|
||||
ort_outputs_0_Shape, nullptr);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
std::vector<OrtValue> ort_outputs_1_ConstantOfShape(1);
|
||||
ort_outputs_1_ConstantOfShape[0] = ort_input_self;
|
||||
|
||||
NodeAttributes attrs(1);
|
||||
attrs["value"] = create_ort_attribute(
|
||||
"value", value, true, self.scalar_type());
|
||||
"value", value, true, self.scalar_type());
|
||||
|
||||
status = invoker.Invoke("ConstantOfShape", {
|
||||
std::move(ort_outputs_0_Shape[0]),
|
||||
}, ort_outputs_1_ConstantOfShape, &attrs);
|
||||
status = invoker.Invoke("ConstantOfShape",
|
||||
{std::move(ort_outputs_0_Shape[0])},
|
||||
ort_outputs_1_ConstantOfShape, &attrs);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
// aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
at::Tensor& nonzero_out(
|
||||
const at::Tensor& self,
|
||||
// *,
|
||||
at::Tensor& out) {
|
||||
const at::Tensor& self,
|
||||
// *,
|
||||
at::Tensor& out) {
|
||||
ORT_LOG_FN(self, out);
|
||||
|
||||
auto temp = eager::aten::nonzero(self);
|
||||
|
|
@ -1036,18 +1026,18 @@ at::Tensor& nonzero_out(
|
|||
|
||||
// aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
|
||||
at::Tensor& _log_softmax_out(
|
||||
const at::Tensor& self,
|
||||
int64_t dim,
|
||||
bool half_to_float,
|
||||
// *,
|
||||
at::Tensor& out) {
|
||||
const at::Tensor& self,
|
||||
int64_t dim,
|
||||
bool half_to_float,
|
||||
// *,
|
||||
at::Tensor& out) {
|
||||
ORT_LOG_FN(self, dim, half_to_float, out);
|
||||
|
||||
if (
|
||||
!IsSupportedType(self, {at::kBFloat16, at::kDouble, at::kFloat, at::kHalf})) {
|
||||
!IsSupportedType(self, {at::kBFloat16, at::kDouble, at::kFloat, at::kHalf})) {
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out);
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out);
|
||||
}
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
|
||||
|
|
@ -1074,30 +1064,30 @@ at::Tensor& _log_softmax_out(
|
|||
for (int64_t i = 0; i < ndim; i++)
|
||||
axes.push_back(i);
|
||||
|
||||
axes[dim] = ndim-1;
|
||||
axes[ndim-1] = dim;
|
||||
dim = ndim-1;
|
||||
axes[dim] = ndim - 1;
|
||||
axes[ndim - 1] = dim;
|
||||
dim = ndim - 1;
|
||||
|
||||
NodeAttributes attrs_0(1);
|
||||
attrs_0["perm"] = create_ort_attribute("perm", axes);
|
||||
auto status = invoker.Invoke("Transpose", {
|
||||
std::move(ort_input_0_self),
|
||||
}, ort_outputs_0_Transpose, &attrs_0);
|
||||
auto status = invoker.Invoke("Transpose",
|
||||
{std::move(ort_input_0_self)},
|
||||
ort_outputs_0_Transpose, &attrs_0);
|
||||
CHECK_STATUS(status);
|
||||
}
|
||||
|
||||
NodeAttributes attrs_1(1);
|
||||
attrs_1["axis"] = create_ort_attribute(
|
||||
"axis", dim, at::ScalarType::Int);
|
||||
"axis", dim, at::ScalarType::Int);
|
||||
|
||||
std::vector<OrtValue> ort_outputs_1_LogSoftmax(1);
|
||||
if (!need_transpose) {
|
||||
ort_outputs_1_LogSoftmax[0] = ort_input_out;
|
||||
}
|
||||
|
||||
auto status = invoker.Invoke("LogSoftmax", {
|
||||
std::move(need_transpose ? ort_outputs_0_Transpose[0] : ort_input_0_self),
|
||||
}, ort_outputs_1_LogSoftmax, &attrs_1);
|
||||
auto status = invoker.Invoke("LogSoftmax",
|
||||
{std::move(need_transpose ? ort_outputs_0_Transpose[0] : ort_input_0_self)},
|
||||
ort_outputs_1_LogSoftmax, &attrs_1);
|
||||
CHECK_STATUS(status);
|
||||
|
||||
std::vector<OrtValue> ort_outputs_2_Transpose(1);
|
||||
|
|
@ -1108,9 +1098,9 @@ at::Tensor& _log_softmax_out(
|
|||
NodeAttributes attrs_2(1);
|
||||
attrs_2["perm"] = create_ort_attribute("perm", axes);
|
||||
|
||||
status = invoker.Invoke("Transpose", {
|
||||
std::move(ort_outputs_1_LogSoftmax[0]),
|
||||
}, ort_outputs_2_Transpose, &attrs_2);
|
||||
status = invoker.Invoke("Transpose",
|
||||
{std::move(ort_outputs_1_LogSoftmax[0])},
|
||||
ort_outputs_2_Transpose, &attrs_2);
|
||||
CHECK_STATUS(status);
|
||||
}
|
||||
|
||||
|
|
@ -1121,28 +1111,28 @@ at::Tensor& _log_softmax_out(
|
|||
// mm is for matrix multiplication and does not broadcast.
|
||||
// https://pytorch.org/docs/stable/generated/torch.mm.html
|
||||
at::Tensor& mm_out(
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& mat2,
|
||||
// *,
|
||||
at::Tensor& out) {
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& mat2,
|
||||
// *,
|
||||
at::Tensor& out) {
|
||||
ORT_LOG_FN(self, mat2, out);
|
||||
|
||||
if (
|
||||
std::vector<at::ScalarType> supportedTypes =
|
||||
{at::kDouble, at::kLong, at::kHalf, at::kFloat, at::kBFloat16, at::kInt};
|
||||
!IsSupportedType(self, supportedTypes) ||
|
||||
!IsSupportedType(mat2, supportedTypes) ||
|
||||
// to match cpu device behavior for torch.mm, verify the following and fall back to cpu to generate error message.
|
||||
// 1. self and mat2 must be 2-D (matrices)
|
||||
self.dim() != 2 || mat2.dim() != 2 ||
|
||||
// 2. self and mat2 can be multiplied
|
||||
self.sizes()[1] != mat2.sizes()[0] ||
|
||||
// 3. self, mat2, and out are of the same type
|
||||
self.scalar_type() != out.scalar_type() ||
|
||||
self.scalar_type() != mat2.scalar_type()) {
|
||||
std::vector<at::ScalarType> supportedTypes =
|
||||
{at::kDouble, at::kLong, at::kHalf, at::kFloat, at::kBFloat16, at::kInt};
|
||||
!IsSupportedType(self, supportedTypes) ||
|
||||
!IsSupportedType(mat2, supportedTypes) ||
|
||||
// to match cpu device behavior for torch.mm, verify the following and fall back to cpu to generate error message.
|
||||
// 1. self and mat2 must be 2-D (matrices)
|
||||
self.dim() != 2 || mat2.dim() != 2 ||
|
||||
// 2. self and mat2 can be multiplied
|
||||
self.sizes()[1] != mat2.sizes()[0] ||
|
||||
// 3. self, mat2, and out are of the same type
|
||||
self.scalar_type() != out.scalar_type() ||
|
||||
self.scalar_type() != mat2.scalar_type()) {
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(mm_out)>::call(self, mat2, out);
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(mm_out)>::call(self, mat2, out);
|
||||
}
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
|
||||
|
|
@ -1165,16 +1155,14 @@ at::Tensor& mm_out(
|
|||
std::vector<OrtValue> ort_outputs_0_MatMul(1);
|
||||
ort_outputs_0_MatMul[0] = ort_input_out;
|
||||
|
||||
auto status = invoker.Invoke("MatMul", {
|
||||
std::move(ort_input_0_self),
|
||||
std::move(ort_input_0_mat2),
|
||||
}, ort_outputs_0_MatMul, nullptr);
|
||||
auto status = invoker.Invoke("MatMul",
|
||||
{std::move(ort_input_0_self), std::move(ort_input_0_mat2)},
|
||||
ort_outputs_0_MatMul, nullptr);
|
||||
CHECK_STATUS(status);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
} // namespace aten
|
||||
|
||||
// #pragma endregion
|
||||
|
|
|
|||
Loading…
Reference in a new issue