mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Update create_ort_attribute to set the tensor dimension and value correctly. Implement eager fill_ (#12018)
* Update create_ort_attribute to set the tensor dimension and value correctly. * Eager mode support for fill_ and mm.out (mm uses mm.out).
This commit is contained in:
parent
1c39d22f4e
commit
418cfdc766
3 changed files with 73 additions and 7 deletions
|
|
@ -142,7 +142,7 @@ hand_implemented = {
|
|||
"aten::addmm": Gemm("mat1", "mat2", "self", alpha="alpha", beta="beta"),
|
||||
"aten::add_.Tensor": SignatureOnly(),
|
||||
"aten::t": Transpose("self"),
|
||||
"aten::mm": MatMul("self", "mat2"),
|
||||
"aten::mm.out": MatMul("self", "mat2"),
|
||||
"aten::zeros_like": ConstantOfShape(
|
||||
Shape("self")
|
||||
), # the default constant is 0, so don't need to speicify attribute
|
||||
|
|
@ -156,7 +156,7 @@ hand_implemented = {
|
|||
"aten::max": ReduceMax("self", keepdims=0),
|
||||
"aten::min": ReduceMin("self", keepdims=0),
|
||||
"aten::_cat": Concat("tensors", "dim"),
|
||||
"aten::fill_.Scalar": ConstantOfShape("self", value="value"),
|
||||
"aten::fill_.Scalar": SignatureOnly(),
|
||||
"aten::ne.Scalar_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"),
|
||||
"aten::ne.Tensor_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"),
|
||||
"aten::eq.Tensor_out": Cast(Equal("self", "other"), to="GetONNXTensorProtoDataType(out.scalar_type())"),
|
||||
|
|
|
|||
|
|
@ -155,14 +155,16 @@ std::vector<OrtValue> create_ort_value(
|
|||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
at::Scalar value,
|
||||
const bool isTensor) {
|
||||
const bool isTensor,
|
||||
at::ScalarType type) {
|
||||
if (isTensor){
|
||||
onnx::AttributeProto attr;
|
||||
attr.set_name(name);
|
||||
at::ScalarType type = value.type();
|
||||
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
|
||||
auto* constant_attribute_tensor_proto = attr.mutable_t();
|
||||
constant_attribute_tensor_proto->mutable_dims()->Clear();
|
||||
// Creating a 1 dim tensor of size 1, so add that dim now.
|
||||
constant_attribute_tensor_proto->add_dims(1);
|
||||
switch (type) {
|
||||
case at::ScalarType::Float:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
|
|
@ -170,16 +172,16 @@ onnx::AttributeProto create_ort_attribute(
|
|||
break;
|
||||
case at::ScalarType::Double:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE);
|
||||
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<double>();
|
||||
*constant_attribute_tensor_proto->mutable_double_data()->Add() = value.to<double>();
|
||||
break;
|
||||
case at::ScalarType::Bool:
|
||||
case at::ScalarType::Int:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<int>();
|
||||
*constant_attribute_tensor_proto->mutable_int32_data()->Add() = value.to<int>();
|
||||
break;
|
||||
case at::ScalarType::Long:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<int64_t>();
|
||||
*constant_attribute_tensor_proto->mutable_int64_data()->Add() = value.to<int64_t>();
|
||||
break;
|
||||
default:
|
||||
// For most at::ScalarType, it should be safe to just call value.to<>
|
||||
|
|
@ -194,6 +196,13 @@ onnx::AttributeProto create_ort_attribute(
|
|||
}
|
||||
}
|
||||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
at::Scalar value,
|
||||
const bool isTensor) {
|
||||
return create_ort_attribute(name, value, isTensor, value.type());
|
||||
}
|
||||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
at::Scalar value,
|
||||
|
|
@ -901,6 +910,52 @@ const at::Tensor& resize_(
|
|||
return self;
|
||||
}
|
||||
|
||||
// aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
|
||||
at::Tensor& fill__Scalar(
|
||||
at::Tensor& self,
|
||||
const at::Scalar& value) {
|
||||
ORT_LOG_FN(self, value);
|
||||
|
||||
if (
|
||||
!IsSupportedType(self, {at::kHalf,at::kFloat,at::kInt,at::kDouble,at::kByte,at::kShort,at::kLong,at::kBFloat16,at::kBool})) {
|
||||
std::cout << "fill__Scalar - Fell back to cpu!\n";
|
||||
return at::native::call_fallback_fn<
|
||||
&at::native::cpu_fallback,
|
||||
ATEN_OP(fill__Scalar)>::call(self, value);
|
||||
}
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
|
||||
auto ort_input_self = create_ort_value(invoker, self);
|
||||
|
||||
std::vector<OrtValue> ort_outputs_0_Shape(1);
|
||||
|
||||
auto status = invoker.Invoke("Shape", {
|
||||
std::move(ort_input_self),
|
||||
}, ort_outputs_0_Shape, nullptr);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
std::vector<OrtValue> ort_outputs_1_ConstantOfShape(1);
|
||||
ort_outputs_1_ConstantOfShape[0] = ort_input_self;
|
||||
|
||||
NodeAttributes attrs(1);
|
||||
attrs["value"] = create_ort_attribute(
|
||||
"value", value, true, self.scalar_type());
|
||||
|
||||
status = invoker.Invoke("ConstantOfShape", {
|
||||
std::move(ort_outputs_0_Shape[0]),
|
||||
}, ort_outputs_1_ConstantOfShape, &attrs);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
} // namespace aten
|
||||
|
||||
//#pragma endregion
|
||||
|
|
|
|||
|
|
@ -344,6 +344,17 @@ class OrtOpTests(unittest.TestCase):
|
|||
assert torch.equal(cpu_float_float_result, ort_float_float_result.to("cpu"))
|
||||
assert torch.equal(cpu_float_float_not_result, ort_float_float_not_result.to("cpu"))
|
||||
|
||||
def test_fill(self):
|
||||
device = self.get_device()
|
||||
for type in {torch.int, torch.float}:
|
||||
cpu_tensor = torch.zeros(2, 2, dtype=type)
|
||||
ort_tensor = cpu_tensor.to(device)
|
||||
for value in {True, 1.1, -1, 0}:
|
||||
cpu_tensor.fill_(value)
|
||||
ort_tensor.fill_(value)
|
||||
assert cpu_tensor.dtype == ort_tensor.dtype
|
||||
assert torch.equal(cpu_tensor, ort_tensor.to("cpu"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue