Eager mode: Argmax and fixup max and min. (#11861)

* Eager mode ArgMax support.

* Fix basic max and min functionality with minor generator update. Note this does not address all max and min api scope.

* Add addmm test.
This commit is contained in:
Wil Brady 2022-06-23 15:55:34 -04:00 committed by GitHub
parent 2c4e4b6afc
commit fa7f80c847
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 119 additions and 48 deletions

View file

@ -124,8 +124,8 @@ hand_implemented = {
"aten::softshrink": Shrink("self", bias="lambd", lambd="lambd"), # yes, bias is set to 'lambd'
"aten::hardshrink": Shrink("self", bias=0, lambd="lambd"),
"aten::gelu": Gelu("self"),
"aten::max": ReduceMax("self", keepdims=1),
"aten::min": ReduceMin("self", keepdims=1),
"aten::max": ReduceMax("self", keepdims=0),
"aten::min": ReduceMin("self", keepdims=0),
"aten::_cat": Concat("tensors", "dim"),
"aten::fill_.Scalar": ConstantOfShape("self", value="value"),
"aten::ne.Scalar": MakeTorchFallback(),
@ -137,8 +137,10 @@ hand_implemented = {
"aten::masked_select": MakeTorchFallback(),
"aten::_local_scalar_dense": MakeTorchFallback(),
"aten::gt.Scalar_out": MakeTorchFallback(),
"aten::lt.Scalar_out": MakeTorchFallback(),
"aten::equal": MakeTorchFallback(),
"aten::_softmax": Softmax("self", axis="dim"),
"aten::argmax.out": SignatureOnly(),
}
# Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439

View file

@ -1,15 +1,11 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from typing import Optional, Dict, List, Union
import sys
import json
import sys
from typing import Dict, List, Optional, Union
import opgen.lexer as lexer
import opgen.parser as parser
import opgen.ast as ast
import opgen.writer as writer
from opgen import ast, lexer, parser, writer
class Outputs:
@ -356,7 +352,7 @@ class ORTGen:
writer.writeline("}")
# Torch kwargs -> ORT attributes
attrs = {k: v for k, v in onnx_op.attributes.items() if v and v.value}
attrs = {k: v for k, v in onnx_op.attributes.items() if v and v.value is not None}
if len(attrs) > 0:
attrs_arg = "attrs"
writer.writeline()

View file

@ -106,7 +106,7 @@ OrtValue create_ort_value(
Ort::BFloat16_t *valOrtBFloat16 = reinterpret_cast<Ort::BFloat16_t *>(&valBFloat16);
CopyVectorToTensor<Ort::BFloat16_t>(invoker, valOrtBFloat16, 1, *ort_tensor);
break;
}
}
default:
// TODO: support more types
// For most at::ScalarType, it should be safe to just call value.to<>
@ -163,7 +163,7 @@ onnx::AttributeProto create_ort_attribute(
at::ScalarType type = value.type();
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
auto* constant_attribute_tensor_proto = attr.mutable_t();
constant_attribute_tensor_proto->mutable_dims()->Clear();
constant_attribute_tensor_proto->mutable_dims()->Clear();
switch (type) {
case at::ScalarType::Float:
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
@ -341,7 +341,7 @@ OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at:
if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
return output[0];
return output[0];
}
//#pragma endregion
@ -396,9 +396,9 @@ at::Tensor empty_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::opti
// aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
at::Tensor as_strided(
const at::Tensor& self,
at::IntArrayRef size,
at::IntArrayRef stride,
const at::Tensor& self,
at::IntArrayRef size,
at::IntArrayRef stride,
c10::optional<int64_t> storage_offset) {
ORT_LOG_FN(self, size, stride, storage_offset);
auto& invoker = GetORTInvoker(self.device());
@ -416,8 +416,8 @@ at::Tensor as_strided(
}
at::Tensor _reshape_alias(
const at::Tensor& self,
at::IntArrayRef size,
const at::Tensor& self,
at::IntArrayRef size,
at::IntArrayRef stride){
ORT_LOG_FN(self, size, stride);
// TODO: support stride
@ -471,22 +471,22 @@ at::Tensor& copy_(
auto status = invoker.Invoke("Cast", {
std::move(ort_src),
}, ort_cast_output, &attrs);
if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
copy(invoker, ort_cast_output[0], ort_self);
}
else{
copy(invoker, ort_src, ort_self);
}
return self;
}
at::Tensor _copy_from_and_resize(
const at::Tensor& self,
const at::Tensor& self,
const at::Tensor& dst){
ORT_LOG_FN(self, dst);
@ -517,7 +517,7 @@ at::Tensor& zero_(at::Tensor& self){
CopyVectorToTensor<int64_t>(invoker, &one, 1, *ort_flag_tensor);
std::vector<OrtValue> ort_out = {ort_in_self};
auto status = invoker.Invoke(
"ZeroGradient", {
std::move(ort_in_self),
@ -534,58 +534,58 @@ at::Tensor& zero_(at::Tensor& self){
// TODO: enhance opgen.py to support inplace binary operations.
// aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
at::Tensor& add__Tensor(
at::Tensor& self,
const at::Tensor& other,
at::Tensor& self,
const at::Tensor& other,
const at::Scalar& alpha) {
ORT_LOG_FN(self, other, alpha);
if (
!IsSupportedType(alpha, {at::kDouble,at::kLong,at::kHalf,at::kShort,at::kInt,at::kByte,at::kFloat,at::kBFloat16}) ||
!IsSupportedType(other, {at::kDouble,at::kLong,at::kHalf,at::kShort,at::kInt,at::kByte,at::kFloat,at::kBFloat16}) ||
!IsSupportedType(self, {at::kDouble,at::kLong,at::kHalf,at::kShort,at::kInt,at::kByte,at::kFloat,at::kBFloat16})) {
!IsSupportedType(alpha, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
!IsSupportedType(other, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
!IsSupportedType(self, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16})) {
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(add__Tensor)>::call(self, other, alpha);
}
auto& invoker = GetORTInvoker(self.device());
auto ort_input_alpha = create_ort_value(invoker, alpha, other.scalar_type());
auto ort_input_other = create_ort_value(invoker, other);
std::vector<OrtValue> ort_outputs_0_Mul(1);
auto status = invoker.Invoke("Mul", {
std::move(ort_input_alpha),
std::move(ort_input_other),
}, ort_outputs_0_Mul, nullptr);
if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
auto ort_input_self = create_ort_value(invoker, self);
std::vector<OrtValue> ort_outputs_1_Add(1);
ort_outputs_1_Add[0] = ort_input_self;
status = invoker.Invoke("Add", {
std::move(ort_input_self),
std::move(ort_outputs_0_Mul[0]),
}, ort_outputs_1_Add, nullptr);
if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
return self;
}
// aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)
at::Tensor slice_Tensor(
const at::Tensor& self,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
const at::Tensor& self,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
ORT_LOG_FN(self, dim, start, end, step);
int64_t ndim = self.dim();
@ -634,6 +634,55 @@ at::Tensor slice_Tensor(
self.options());
}
// aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
at::Tensor& argmax_out(
const at::Tensor& self,
c10::optional<int64_t> dim,
bool keepdim,
// *,
at::Tensor& out) {
ORT_LOG_FN(self, dim, keepdim, out);
if (
!IsSupportedType(self, {at::kLong, at::kShort, at::kHalf, at::kBFloat16, at::kFloat, at::kByte, at::kInt, at::kDouble})) {
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(argmax_out)>::call(self, dim, keepdim, out);
}
auto& invoker = GetORTInvoker(self.device());
auto ort_input_self =
create_ort_value(invoker, dim.has_value() ? self : self.reshape({-1}));
// Remove this hand signature once the generator can support this one line below.
int64_t l_axis = dim.has_value() ? *dim : 0;
NodeAttributes attrs(2);
attrs["axis"] = create_ort_attribute(
"axis", l_axis, at::ScalarType::Int);
attrs["keepdims"] = create_ort_attribute(
"keepdims", keepdim, at::ScalarType::Int);
std::vector<OrtValue> ort_outputs_0_ArgMax(1);
auto status = invoker.Invoke("ArgMax", {
std::move(ort_input_self),
}, ort_outputs_0_ArgMax, &attrs);
if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
at::TensorOptions tensor_options = out.options();
// generator also needs to do this to handle the out param!
out = aten_tensor_from_ort(
std::move(ort_outputs_0_ArgMax[0]),
tensor_options);
return out;
}
} // namespace aten
//#pragma endregion

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# pylint: disable=missing-docstring
import unittest
import numpy as np
@ -9,6 +11,8 @@ import torch
class OrtOpTests(unittest.TestCase):
"""test cases for supported eager ops"""
def get_device(self):
return torch_ort.device()
@ -101,16 +105,18 @@ class OrtOpTests(unittest.TestCase):
def test_max(self):
cpu_tensor = torch.rand(10, 10)
ort_tensor = cpu_tensor.to("ort")
y = ort_tensor.max()
x = cpu_tensor.max()
assert torch.allclose(x, y.cpu())
ort_min = ort_tensor.max()
cpu_min = cpu_tensor.max()
assert torch.allclose(cpu_min, ort_min.cpu())
assert cpu_min.dim() == ort_min.dim()
def test_min(self):
cpu_tensor = torch.rand(10, 10)
ort_tensor = cpu_tensor.to("ort")
y = ort_tensor.min()
x = cpu_tensor.min()
assert torch.allclose(x, y.cpu())
ort_min = ort_tensor.min()
cpu_min = cpu_tensor.min()
assert torch.allclose(cpu_min, ort_min.cpu())
assert cpu_min.dim() == ort_min.dim()
def test_equal(self):
device = self.get_device()
@ -152,6 +158,24 @@ class OrtOpTests(unittest.TestCase):
ort_result = torch.softmax(ort_tensor, dim=1)
assert torch.allclose(cpu_result, ort_result.cpu())
def test_addmm(self):
device = self.get_device()
size = 4
ort_tensor = torch.ones([size, size]).to(device)
input_bias = torch.ones([size]).to(device)
output = torch.addmm(input_bias, ort_tensor, ort_tensor)
expected = torch.ones([size, size]) * 5
assert torch.equal(output.to("cpu"), expected)
def test_argmax(self):
device = self.get_device()
cpu_tensor = torch.rand(3, 5)
ort_tensor = cpu_tensor.to(device)
cpu_result = torch.argmax(cpu_tensor, dim=1)
ort_result = torch.argmax(ort_tensor, dim=1)
assert torch.allclose(cpu_result, ort_result.cpu())
assert cpu_result.dim() == ort_result.dim()
if __name__ == "__main__":
unittest.main()