mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Eager mode: Argmax and fixup max and min. (#11861)
* Eager mode ArgMax support. * Fix basic max and min functionality with minor generator update. Note this does not address all max and min api scope. * Add addmm test.
This commit is contained in:
parent
2c4e4b6afc
commit
fa7f80c847
4 changed files with 119 additions and 48 deletions
|
|
@ -124,8 +124,8 @@ hand_implemented = {
|
|||
"aten::softshrink": Shrink("self", bias="lambd", lambd="lambd"), # yes, bias is set to 'lambd'
|
||||
"aten::hardshrink": Shrink("self", bias=0, lambd="lambd"),
|
||||
"aten::gelu": Gelu("self"),
|
||||
"aten::max": ReduceMax("self", keepdims=1),
|
||||
"aten::min": ReduceMin("self", keepdims=1),
|
||||
"aten::max": ReduceMax("self", keepdims=0),
|
||||
"aten::min": ReduceMin("self", keepdims=0),
|
||||
"aten::_cat": Concat("tensors", "dim"),
|
||||
"aten::fill_.Scalar": ConstantOfShape("self", value="value"),
|
||||
"aten::ne.Scalar": MakeTorchFallback(),
|
||||
|
|
@ -137,8 +137,10 @@ hand_implemented = {
|
|||
"aten::masked_select": MakeTorchFallback(),
|
||||
"aten::_local_scalar_dense": MakeTorchFallback(),
|
||||
"aten::gt.Scalar_out": MakeTorchFallback(),
|
||||
"aten::lt.Scalar_out": MakeTorchFallback(),
|
||||
"aten::equal": MakeTorchFallback(),
|
||||
"aten::_softmax": Softmax("self", axis="dim"),
|
||||
"aten::argmax.out": SignatureOnly(),
|
||||
}
|
||||
|
||||
# Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439
|
||||
|
|
|
|||
|
|
@ -1,15 +1,11 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Optional, Dict, List, Union
|
||||
|
||||
import sys
|
||||
import json
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import opgen.lexer as lexer
|
||||
import opgen.parser as parser
|
||||
import opgen.ast as ast
|
||||
import opgen.writer as writer
|
||||
from opgen import ast, lexer, parser, writer
|
||||
|
||||
|
||||
class Outputs:
|
||||
|
|
@ -356,7 +352,7 @@ class ORTGen:
|
|||
writer.writeline("}")
|
||||
|
||||
# Torch kwargs -> ORT attributes
|
||||
attrs = {k: v for k, v in onnx_op.attributes.items() if v and v.value}
|
||||
attrs = {k: v for k, v in onnx_op.attributes.items() if v and v.value is not None}
|
||||
if len(attrs) > 0:
|
||||
attrs_arg = "attrs"
|
||||
writer.writeline()
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ OrtValue create_ort_value(
|
|||
Ort::BFloat16_t *valOrtBFloat16 = reinterpret_cast<Ort::BFloat16_t *>(&valBFloat16);
|
||||
CopyVectorToTensor<Ort::BFloat16_t>(invoker, valOrtBFloat16, 1, *ort_tensor);
|
||||
break;
|
||||
}
|
||||
}
|
||||
default:
|
||||
// TODO: support more types
|
||||
// For most at::ScalarType, it should be safe to just call value.to<>
|
||||
|
|
@ -163,7 +163,7 @@ onnx::AttributeProto create_ort_attribute(
|
|||
at::ScalarType type = value.type();
|
||||
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
|
||||
auto* constant_attribute_tensor_proto = attr.mutable_t();
|
||||
constant_attribute_tensor_proto->mutable_dims()->Clear();
|
||||
constant_attribute_tensor_proto->mutable_dims()->Clear();
|
||||
switch (type) {
|
||||
case at::ScalarType::Float:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
|
|
@ -341,7 +341,7 @@ OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at:
|
|||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
return output[0];
|
||||
return output[0];
|
||||
}
|
||||
|
||||
//#pragma endregion
|
||||
|
|
@ -396,9 +396,9 @@ at::Tensor empty_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::opti
|
|||
|
||||
// aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
|
||||
at::Tensor as_strided(
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride,
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride,
|
||||
c10::optional<int64_t> storage_offset) {
|
||||
ORT_LOG_FN(self, size, stride, storage_offset);
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
|
|
@ -416,8 +416,8 @@ at::Tensor as_strided(
|
|||
}
|
||||
|
||||
at::Tensor _reshape_alias(
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride){
|
||||
ORT_LOG_FN(self, size, stride);
|
||||
// TODO: support stride
|
||||
|
|
@ -471,22 +471,22 @@ at::Tensor& copy_(
|
|||
auto status = invoker.Invoke("Cast", {
|
||||
std::move(ort_src),
|
||||
}, ort_cast_output, &attrs);
|
||||
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
|
||||
copy(invoker, ort_cast_output[0], ort_self);
|
||||
}
|
||||
else{
|
||||
copy(invoker, ort_src, ort_self);
|
||||
}
|
||||
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
at::Tensor _copy_from_and_resize(
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& dst){
|
||||
ORT_LOG_FN(self, dst);
|
||||
|
||||
|
|
@ -517,7 +517,7 @@ at::Tensor& zero_(at::Tensor& self){
|
|||
CopyVectorToTensor<int64_t>(invoker, &one, 1, *ort_flag_tensor);
|
||||
|
||||
std::vector<OrtValue> ort_out = {ort_in_self};
|
||||
|
||||
|
||||
auto status = invoker.Invoke(
|
||||
"ZeroGradient", {
|
||||
std::move(ort_in_self),
|
||||
|
|
@ -534,58 +534,58 @@ at::Tensor& zero_(at::Tensor& self){
|
|||
// TODO: enhance opgen.py to support inplace binary operations.
|
||||
// aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
|
||||
at::Tensor& add__Tensor(
|
||||
at::Tensor& self,
|
||||
const at::Tensor& other,
|
||||
at::Tensor& self,
|
||||
const at::Tensor& other,
|
||||
const at::Scalar& alpha) {
|
||||
ORT_LOG_FN(self, other, alpha);
|
||||
|
||||
|
||||
if (
|
||||
!IsSupportedType(alpha, {at::kDouble,at::kLong,at::kHalf,at::kShort,at::kInt,at::kByte,at::kFloat,at::kBFloat16}) ||
|
||||
!IsSupportedType(other, {at::kDouble,at::kLong,at::kHalf,at::kShort,at::kInt,at::kByte,at::kFloat,at::kBFloat16}) ||
|
||||
!IsSupportedType(self, {at::kDouble,at::kLong,at::kHalf,at::kShort,at::kInt,at::kByte,at::kFloat,at::kBFloat16})) {
|
||||
!IsSupportedType(alpha, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
|
||||
!IsSupportedType(other, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
|
||||
!IsSupportedType(self, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16})) {
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(add__Tensor)>::call(self, other, alpha);
|
||||
}
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
|
||||
|
||||
auto ort_input_alpha = create_ort_value(invoker, alpha, other.scalar_type());
|
||||
auto ort_input_other = create_ort_value(invoker, other);
|
||||
|
||||
|
||||
std::vector<OrtValue> ort_outputs_0_Mul(1);
|
||||
|
||||
|
||||
auto status = invoker.Invoke("Mul", {
|
||||
std::move(ort_input_alpha),
|
||||
std::move(ort_input_other),
|
||||
}, ort_outputs_0_Mul, nullptr);
|
||||
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
|
||||
auto ort_input_self = create_ort_value(invoker, self);
|
||||
|
||||
|
||||
std::vector<OrtValue> ort_outputs_1_Add(1);
|
||||
ort_outputs_1_Add[0] = ort_input_self;
|
||||
|
||||
|
||||
status = invoker.Invoke("Add", {
|
||||
std::move(ort_input_self),
|
||||
std::move(ort_outputs_0_Mul[0]),
|
||||
}, ort_outputs_1_Add, nullptr);
|
||||
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
// aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)
|
||||
at::Tensor slice_Tensor(
|
||||
const at::Tensor& self,
|
||||
int64_t dim,
|
||||
c10::optional<int64_t> start,
|
||||
c10::optional<int64_t> end,
|
||||
const at::Tensor& self,
|
||||
int64_t dim,
|
||||
c10::optional<int64_t> start,
|
||||
c10::optional<int64_t> end,
|
||||
int64_t step) {
|
||||
ORT_LOG_FN(self, dim, start, end, step);
|
||||
int64_t ndim = self.dim();
|
||||
|
|
@ -634,6 +634,55 @@ at::Tensor slice_Tensor(
|
|||
self.options());
|
||||
}
|
||||
|
||||
// aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||
at::Tensor& argmax_out(
|
||||
const at::Tensor& self,
|
||||
c10::optional<int64_t> dim,
|
||||
bool keepdim,
|
||||
// *,
|
||||
at::Tensor& out) {
|
||||
ORT_LOG_FN(self, dim, keepdim, out);
|
||||
|
||||
if (
|
||||
!IsSupportedType(self, {at::kLong, at::kShort, at::kHalf, at::kBFloat16, at::kFloat, at::kByte, at::kInt, at::kDouble})) {
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(argmax_out)>::call(self, dim, keepdim, out);
|
||||
}
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
|
||||
auto ort_input_self =
|
||||
create_ort_value(invoker, dim.has_value() ? self : self.reshape({-1}));
|
||||
|
||||
// Remove this hand signature once the generator can support this one line below.
|
||||
int64_t l_axis = dim.has_value() ? *dim : 0;
|
||||
|
||||
NodeAttributes attrs(2);
|
||||
attrs["axis"] = create_ort_attribute(
|
||||
"axis", l_axis, at::ScalarType::Int);
|
||||
attrs["keepdims"] = create_ort_attribute(
|
||||
"keepdims", keepdim, at::ScalarType::Int);
|
||||
|
||||
std::vector<OrtValue> ort_outputs_0_ArgMax(1);
|
||||
|
||||
auto status = invoker.Invoke("ArgMax", {
|
||||
std::move(ort_input_self),
|
||||
}, ort_outputs_0_ArgMax, &attrs);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
at::TensorOptions tensor_options = out.options();
|
||||
|
||||
// generator also needs to do this to handle the out param!
|
||||
out = aten_tensor_from_ort(
|
||||
std::move(ort_outputs_0_ArgMax[0]),
|
||||
tensor_options);
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
} // namespace aten
|
||||
|
||||
//#pragma endregion
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: disable=missing-docstring
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -9,6 +11,8 @@ import torch
|
|||
|
||||
|
||||
class OrtOpTests(unittest.TestCase):
|
||||
"""test cases for supported eager ops"""
|
||||
|
||||
def get_device(self):
|
||||
return torch_ort.device()
|
||||
|
||||
|
|
@ -101,16 +105,18 @@ class OrtOpTests(unittest.TestCase):
|
|||
def test_max(self):
|
||||
cpu_tensor = torch.rand(10, 10)
|
||||
ort_tensor = cpu_tensor.to("ort")
|
||||
y = ort_tensor.max()
|
||||
x = cpu_tensor.max()
|
||||
assert torch.allclose(x, y.cpu())
|
||||
ort_min = ort_tensor.max()
|
||||
cpu_min = cpu_tensor.max()
|
||||
assert torch.allclose(cpu_min, ort_min.cpu())
|
||||
assert cpu_min.dim() == ort_min.dim()
|
||||
|
||||
def test_min(self):
|
||||
cpu_tensor = torch.rand(10, 10)
|
||||
ort_tensor = cpu_tensor.to("ort")
|
||||
y = ort_tensor.min()
|
||||
x = cpu_tensor.min()
|
||||
assert torch.allclose(x, y.cpu())
|
||||
ort_min = ort_tensor.min()
|
||||
cpu_min = cpu_tensor.min()
|
||||
assert torch.allclose(cpu_min, ort_min.cpu())
|
||||
assert cpu_min.dim() == ort_min.dim()
|
||||
|
||||
def test_equal(self):
|
||||
device = self.get_device()
|
||||
|
|
@ -152,6 +158,24 @@ class OrtOpTests(unittest.TestCase):
|
|||
ort_result = torch.softmax(ort_tensor, dim=1)
|
||||
assert torch.allclose(cpu_result, ort_result.cpu())
|
||||
|
||||
def test_addmm(self):
|
||||
device = self.get_device()
|
||||
size = 4
|
||||
ort_tensor = torch.ones([size, size]).to(device)
|
||||
input_bias = torch.ones([size]).to(device)
|
||||
output = torch.addmm(input_bias, ort_tensor, ort_tensor)
|
||||
expected = torch.ones([size, size]) * 5
|
||||
assert torch.equal(output.to("cpu"), expected)
|
||||
|
||||
def test_argmax(self):
|
||||
device = self.get_device()
|
||||
cpu_tensor = torch.rand(3, 5)
|
||||
ort_tensor = cpu_tensor.to(device)
|
||||
cpu_result = torch.argmax(cpu_tensor, dim=1)
|
||||
ort_result = torch.argmax(ort_tensor, dim=1)
|
||||
assert torch.allclose(cpu_result, ort_result.cpu())
|
||||
assert cpu_result.dim() == ort_result.dim()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue