adding fill scalar for torch ones direct initialization on ort device (#10898)

* adding fill scalar for torch ones direct initialization on device and adding test case for it

* using ConstantOfShape to for implementing fill Scalar in atenops

* adding case for handling at::Tensor attribute

* handling the at::Tensor type for ConstantOfShape

* handling the at::Tensor type for ConstantOfShape with attr type

* handling the at::Tensor type case

* converting the data to tensor in case of aten tensor mapping is needed

* handling aten tensor case

* handling aten tensor case and reversing the string case

* changing type of scalar
This commit is contained in:
Abhishek Jindal 2022-04-05 11:17:25 -07:00 committed by GitHub
parent 2c2408814f
commit 91c940b619
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 3 deletions

View file

@ -87,6 +87,7 @@ hand_implemented = {
'aten::max' : ReduceMax('self', keepdims=1),
'aten::min' : ReduceMin('self', keepdims=1),
'aten::_cat': Concat('tensors', 'dim'),
'aten::fill_.Scalar': ConstantOfShape('self', value='value'),
'aten::ne.Scalar':MakeTorchFallback(),
'aten::ne.Scalar_out': MakeTorchFallback(),

View file

@ -360,6 +360,8 @@ class ORTGen:
writer.write(f'"{attr_name}", {attr.value}')
if attr.type.startswith('at::ScalarType::'):
writer.write(f', {attr.type}')
elif attr.type == AttrType.TENSOR:
writer.write(f', true')
elif attr.type != AttrType.STRING:
raise FunctionGenerationError(
cpp_func,

View file

@ -155,8 +155,44 @@ std::vector<OrtValue> create_ort_value(
onnx::AttributeProto create_ort_attribute(
const char* name,
at::Scalar value) {
return create_ort_attribute(name, value, value.type());
at::Scalar value,
const bool isTensor) {
if (isTensor){
onnx::AttributeProto attr;
attr.set_name(name);
at::ScalarType type = value.type();
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
auto* constant_attribute_tensor_proto = attr.mutable_t();
constant_attribute_tensor_proto->mutable_dims()->Clear();
switch (type) {
case at::ScalarType::Float:
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<float>();
break;
case at::ScalarType::Double:
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE);
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<double>();
break;
case at::ScalarType::Bool:
case at::ScalarType::Int:
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<int>();
break;
case at::ScalarType::Long:
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<int64_t>();
break;
default:
// For most at::ScalarType, it should be safe to just call value.to<>
// on it, but for now we want to explicitly know when we've encountered
// a new scalar type while bringing up ORT eager mode.
ORT_THROW("Unsupported: at::ScalarType::", value.type());
}
return attr;
}
else{
return create_ort_attribute(name, value, value.type());
}
}
onnx::AttributeProto create_ort_attribute(

View file

@ -90,7 +90,8 @@ OrtValue create_ort_value(
onnx::AttributeProto create_ort_attribute(
const char* name,
at::Scalar value);
at::Scalar value,
const bool isTensor=false);
onnx::AttributeProto create_ort_attribute(
const char* name,

View file

@ -114,6 +114,14 @@ class OrtOpTests(unittest.TestCase):
x = cpu_tensor.min()
assert torch.allclose(x, y.cpu())
def test_torch_ones(self):
device = self.get_device()
cpu_ones = torch.ones((10,10))
ort_ones = cpu_ones.to(device)
ort_ones_device = torch.ones((10, 10), device = device)
assert torch.allclose(cpu_ones, ort_ones.cpu())
assert torch.allclose(cpu_ones, ort_ones_device.cpu())
def test_narrow(self):
cpu_tensor = torch.rand(10, 10)
cpu_narrow = cpu_tensor.narrow(0, 5, 5)