Update create_ort_attribute to set the tensor dimension and value correctly. Implement eager fill_ (#12018)

* Update create_ort_attribute to set the tensor dimension and value correctly.

* Eager mode support for fill_ and mm.out (mm uses mm.out).
This commit is contained in:
Wil Brady 2022-07-11 11:18:04 -04:00 committed by GitHub
parent 1c39d22f4e
commit 418cfdc766
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 7 deletions

View file

@ -142,7 +142,7 @@ hand_implemented = {
"aten::addmm": Gemm("mat1", "mat2", "self", alpha="alpha", beta="beta"),
"aten::add_.Tensor": SignatureOnly(),
"aten::t": Transpose("self"),
"aten::mm": MatMul("self", "mat2"),
"aten::mm.out": MatMul("self", "mat2"),
"aten::zeros_like": ConstantOfShape(
Shape("self")
), # the default constant is 0, so don't need to speicify attribute
@ -156,7 +156,7 @@ hand_implemented = {
"aten::max": ReduceMax("self", keepdims=0),
"aten::min": ReduceMin("self", keepdims=0),
"aten::_cat": Concat("tensors", "dim"),
"aten::fill_.Scalar": ConstantOfShape("self", value="value"),
"aten::fill_.Scalar": SignatureOnly(),
"aten::ne.Scalar_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"),
"aten::ne.Tensor_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"),
"aten::eq.Tensor_out": Cast(Equal("self", "other"), to="GetONNXTensorProtoDataType(out.scalar_type())"),

View file

@ -155,14 +155,16 @@ std::vector<OrtValue> create_ort_value(
onnx::AttributeProto create_ort_attribute(
const char* name,
at::Scalar value,
const bool isTensor) {
const bool isTensor,
at::ScalarType type) {
if (isTensor){
onnx::AttributeProto attr;
attr.set_name(name);
at::ScalarType type = value.type();
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
auto* constant_attribute_tensor_proto = attr.mutable_t();
constant_attribute_tensor_proto->mutable_dims()->Clear();
// Creating a 1 dim tensor of size 1, so add that dim now.
constant_attribute_tensor_proto->add_dims(1);
switch (type) {
case at::ScalarType::Float:
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
@ -170,16 +172,16 @@ onnx::AttributeProto create_ort_attribute(
break;
case at::ScalarType::Double:
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE);
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<double>();
*constant_attribute_tensor_proto->mutable_double_data()->Add() = value.to<double>();
break;
case at::ScalarType::Bool:
case at::ScalarType::Int:
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<int>();
*constant_attribute_tensor_proto->mutable_int32_data()->Add() = value.to<int>();
break;
case at::ScalarType::Long:
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<int64_t>();
*constant_attribute_tensor_proto->mutable_int64_data()->Add() = value.to<int64_t>();
break;
default:
// For most at::ScalarType, it should be safe to just call value.to<>
@ -194,6 +196,13 @@ onnx::AttributeProto create_ort_attribute(
}
}
onnx::AttributeProto create_ort_attribute(
const char* name,
at::Scalar value,
const bool isTensor) {
return create_ort_attribute(name, value, isTensor, value.type());
}
onnx::AttributeProto create_ort_attribute(
const char* name,
at::Scalar value,
@ -901,6 +910,52 @@ const at::Tensor& resize_(
return self;
}
// aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
at::Tensor& fill__Scalar(
at::Tensor& self,
const at::Scalar& value) {
ORT_LOG_FN(self, value);
if (
!IsSupportedType(self, {at::kHalf,at::kFloat,at::kInt,at::kDouble,at::kByte,at::kShort,at::kLong,at::kBFloat16,at::kBool})) {
std::cout << "fill__Scalar - Fell back to cpu!\n";
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(fill__Scalar)>::call(self, value);
}
auto& invoker = GetORTInvoker(self.device());
auto ort_input_self = create_ort_value(invoker, self);
std::vector<OrtValue> ort_outputs_0_Shape(1);
auto status = invoker.Invoke("Shape", {
std::move(ort_input_self),
}, ort_outputs_0_Shape, nullptr);
if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
std::vector<OrtValue> ort_outputs_1_ConstantOfShape(1);
ort_outputs_1_ConstantOfShape[0] = ort_input_self;
NodeAttributes attrs(1);
attrs["value"] = create_ort_attribute(
"value", value, true, self.scalar_type());
status = invoker.Invoke("ConstantOfShape", {
std::move(ort_outputs_0_Shape[0]),
}, ort_outputs_1_ConstantOfShape, &attrs);
if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
return self;
}
} // namespace aten
//#pragma endregion

View file

@ -344,6 +344,17 @@ class OrtOpTests(unittest.TestCase):
assert torch.equal(cpu_float_float_result, ort_float_float_result.to("cpu"))
assert torch.equal(cpu_float_float_not_result, ort_float_float_not_result.to("cpu"))
def test_fill(self):
device = self.get_device()
for type in {torch.int, torch.float}:
cpu_tensor = torch.zeros(2, 2, dtype=type)
ort_tensor = cpu_tensor.to(device)
for value in {True, 1.1, -1, 0}:
cpu_tensor.fill_(value)
ort_tensor.fill_(value)
assert cpu_tensor.dtype == ort_tensor.dtype
assert torch.equal(cpu_tensor, ort_tensor.to("cpu"))
if __name__ == "__main__":
unittest.main()