diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index 1d80f2a48b..a7360e4df3 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -87,6 +87,7 @@ hand_implemented = { 'aten::max' : ReduceMax('self', keepdims=1), 'aten::min' : ReduceMin('self', keepdims=1), 'aten::_cat': Concat('tensors', 'dim'), + 'aten::fill_.Scalar': ConstantOfShape('self', value='value'), 'aten::ne.Scalar':MakeTorchFallback(), 'aten::ne.Scalar_out': MakeTorchFallback(), diff --git a/orttraining/orttraining/eager/opgen/opgen/generator.py b/orttraining/orttraining/eager/opgen/opgen/generator.py index 64073e5689..6bfc754344 100644 --- a/orttraining/orttraining/eager/opgen/opgen/generator.py +++ b/orttraining/orttraining/eager/opgen/opgen/generator.py @@ -360,6 +360,8 @@ class ORTGen: writer.write(f'"{attr_name}", {attr.value}') if attr.type.startswith('at::ScalarType::'): writer.write(f', {attr.type}') + elif attr.type == AttrType.TENSOR: + writer.write(f', true') elif attr.type != AttrType.STRING: raise FunctionGenerationError( cpp_func, diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index 6834ed7ea4..cd58a0004e 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -155,8 +155,44 @@ std::vector create_ort_value( onnx::AttributeProto create_ort_attribute( const char* name, - at::Scalar value) { - return create_ort_attribute(name, value, value.type()); + at::Scalar value, + const bool isTensor) { + if (isTensor){ + onnx::AttributeProto attr; + attr.set_name(name); + at::ScalarType type = value.type(); + attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR); + auto* constant_attribute_tensor_proto = attr.mutable_t(); + constant_attribute_tensor_proto->mutable_dims()->Clear(); + switch (type) { + case at::ScalarType::Float: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + *constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to(); + break; + case at::ScalarType::Double: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); + *constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to(); + break; + case at::ScalarType::Bool: + case at::ScalarType::Int: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + *constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to(); + break; + case at::ScalarType::Long: + constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + *constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to(); + break; + default: + // For most at::ScalarType, it should be safe to just call value.to<> + // on it, but for now we want to explicitly know when we've encountered + // a new scalar type while bringing up ORT eager mode. + ORT_THROW("Unsupported: at::ScalarType::", value.type()); + } + return attr; + } + else{ + return create_ort_attribute(name, value, value.type()); + } } onnx::AttributeProto create_ort_attribute( diff --git a/orttraining/orttraining/eager/ort_aten.h b/orttraining/orttraining/eager/ort_aten.h index bd2036b9aa..6e86198f31 100644 --- a/orttraining/orttraining/eager/ort_aten.h +++ b/orttraining/orttraining/eager/ort_aten.h @@ -90,7 +90,8 @@ OrtValue create_ort_value( onnx::AttributeProto create_ort_attribute( const char* name, - at::Scalar value); + at::Scalar value, + const bool isTensor=false); onnx::AttributeProto create_ort_attribute( const char* name, diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index bc52af5d50..9d1aaa81a5 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -114,6 +114,14 @@ class OrtOpTests(unittest.TestCase): x = cpu_tensor.min() assert torch.allclose(x, y.cpu()) + def test_torch_ones(self): + device = self.get_device() + cpu_ones = torch.ones((10,10)) + ort_ones = cpu_ones.to(device) + ort_ones_device = torch.ones((10, 10), device = device) + assert torch.allclose(cpu_ones, ort_ones.cpu()) + assert torch.allclose(cpu_ones, ort_ones_device.cpu()) + def test_narrow(self): cpu_tensor = torch.rand(10, 10) cpu_narrow = cpu_tensor.narrow(0, 5, 5)