diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index 0c83e0e223..4babc3fa38 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -142,7 +142,7 @@ hand_implemented = { "aten::addmm": Gemm("mat1", "mat2", "self", alpha="alpha", beta="beta"), "aten::add_.Tensor": SignatureOnly(), "aten::t": Transpose("self"), - "aten::mm": MatMul("self", "mat2"), + "aten::mm.out": MatMul("self", "mat2"), "aten::zeros_like": ConstantOfShape( Shape("self") ), # the default constant is 0, so don't need to speicify attribute @@ -156,7 +156,7 @@ hand_implemented = { "aten::max": ReduceMax("self", keepdims=0), "aten::min": ReduceMin("self", keepdims=0), "aten::_cat": Concat("tensors", "dim"), - "aten::fill_.Scalar": ConstantOfShape("self", value="value"), + "aten::fill_.Scalar": SignatureOnly(), "aten::ne.Scalar_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"), "aten::ne.Tensor_out": Cast(Not(Equal("self", "other")), to="GetONNXTensorProtoDataType(out.scalar_type())"), "aten::eq.Tensor_out": Cast(Equal("self", "other"), to="GetONNXTensorProtoDataType(out.scalar_type())"), diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index 5de3da912d..5215762619 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -155,14 +155,16 @@ std::vector create_ort_value( onnx::AttributeProto create_ort_attribute( const char* name, at::Scalar value, - const bool isTensor) { + const bool isTensor, + at::ScalarType type) { if (isTensor){ onnx::AttributeProto attr; attr.set_name(name); - at::ScalarType type = value.type(); attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR); auto* constant_attribute_tensor_proto = attr.mutable_t(); constant_attribute_tensor_proto->mutable_dims()->Clear(); + // Creating a 1 dim tensor of size 1, so add that dim now. + constant_attribute_tensor_proto->add_dims(1); switch (type) { case at::ScalarType::Float: constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); @@ -170,16 +172,16 @@ onnx::AttributeProto create_ort_attribute( break; case at::ScalarType::Double: constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); - *constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to(); + *constant_attribute_tensor_proto->mutable_double_data()->Add() = value.to(); break; case at::ScalarType::Bool: case at::ScalarType::Int: constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); - *constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to(); + *constant_attribute_tensor_proto->mutable_int32_data()->Add() = value.to(); break; case at::ScalarType::Long: constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - *constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to(); + *constant_attribute_tensor_proto->mutable_int64_data()->Add() = value.to(); break; default: // For most at::ScalarType, it should be safe to just call value.to<> @@ -194,6 +196,13 @@ onnx::AttributeProto create_ort_attribute( } } +onnx::AttributeProto create_ort_attribute( + const char* name, + at::Scalar value, + const bool isTensor) { + return create_ort_attribute(name, value, isTensor, value.type()); +} + onnx::AttributeProto create_ort_attribute( const char* name, at::Scalar value, @@ -901,6 +910,52 @@ const at::Tensor& resize_( return self; } +// aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) +at::Tensor& fill__Scalar( + at::Tensor& self, + const at::Scalar& value) { + ORT_LOG_FN(self, value); + + if ( + !IsSupportedType(self, {at::kHalf,at::kFloat,at::kInt,at::kDouble,at::kByte,at::kShort,at::kLong,at::kBFloat16,at::kBool})) { + std::cout << "fill__Scalar - Fell back to cpu!\n"; + return at::native::call_fallback_fn< + &at::native::cpu_fallback, + ATEN_OP(fill__Scalar)>::call(self, value); + } + auto& invoker = GetORTInvoker(self.device()); + + auto ort_input_self = create_ort_value(invoker, self); + + std::vector ort_outputs_0_Shape(1); + + auto status = invoker.Invoke("Shape", { + std::move(ort_input_self), + }, ort_outputs_0_Shape, nullptr); + + if (!status.IsOK()) + throw std::runtime_error( + "ORT return failure status:" + status.ErrorMessage()); + + std::vector ort_outputs_1_ConstantOfShape(1); + ort_outputs_1_ConstantOfShape[0] = ort_input_self; + + NodeAttributes attrs(1); + attrs["value"] = create_ort_attribute( + "value", value, true, self.scalar_type()); + + status = invoker.Invoke("ConstantOfShape", { + std::move(ort_outputs_0_Shape[0]), + }, ort_outputs_1_ConstantOfShape, &attrs); + + if (!status.IsOK()) + throw std::runtime_error( + "ORT return failure status:" + status.ErrorMessage()); + + return self; +} + + } // namespace aten //#pragma endregion diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index 4224e472d0..f73abf5d32 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -344,6 +344,17 @@ class OrtOpTests(unittest.TestCase): assert torch.equal(cpu_float_float_result, ort_float_float_result.to("cpu")) assert torch.equal(cpu_float_float_not_result, ort_float_float_not_result.to("cpu")) + def test_fill(self): + device = self.get_device() + for type in {torch.int, torch.float}: + cpu_tensor = torch.zeros(2, 2, dtype=type) + ort_tensor = cpu_tensor.to(device) + for value in {True, 1.1, -1, 0}: + cpu_tensor.fill_(value) + ort_tensor.fill_(value) + assert cpu_tensor.dtype == ort_tensor.dtype + assert torch.equal(cpu_tensor, ort_tensor.to("cpu")) + if __name__ == "__main__": unittest.main()