mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
adding fill scalar for torch ones direct initialization on ort device (#10898)
* adding fill scalar for torch ones direct initialization on device and adding test case for it * using ConstantOfShape to for implementing fill Scalar in atenops * adding case for handling at::Tensor attribute * handling the at::Tensor type for ConstantOfShape * handling the at::Tensor type for ConstantOfShape with attr type * handling the at::Tensor type case * converting the data to tensor in case of aten tensor mapping is needed * handling aten tensor case * handling aten tensor case and reversing the string case * changing type of scalar
This commit is contained in:
parent
2c2408814f
commit
91c940b619
5 changed files with 51 additions and 3 deletions
|
|
@ -87,6 +87,7 @@ hand_implemented = {
|
|||
'aten::max' : ReduceMax('self', keepdims=1),
|
||||
'aten::min' : ReduceMin('self', keepdims=1),
|
||||
'aten::_cat': Concat('tensors', 'dim'),
|
||||
'aten::fill_.Scalar': ConstantOfShape('self', value='value'),
|
||||
|
||||
'aten::ne.Scalar':MakeTorchFallback(),
|
||||
'aten::ne.Scalar_out': MakeTorchFallback(),
|
||||
|
|
|
|||
|
|
@ -360,6 +360,8 @@ class ORTGen:
|
|||
writer.write(f'"{attr_name}", {attr.value}')
|
||||
if attr.type.startswith('at::ScalarType::'):
|
||||
writer.write(f', {attr.type}')
|
||||
elif attr.type == AttrType.TENSOR:
|
||||
writer.write(f', true')
|
||||
elif attr.type != AttrType.STRING:
|
||||
raise FunctionGenerationError(
|
||||
cpp_func,
|
||||
|
|
|
|||
|
|
@ -155,8 +155,44 @@ std::vector<OrtValue> create_ort_value(
|
|||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
at::Scalar value) {
|
||||
return create_ort_attribute(name, value, value.type());
|
||||
at::Scalar value,
|
||||
const bool isTensor) {
|
||||
if (isTensor){
|
||||
onnx::AttributeProto attr;
|
||||
attr.set_name(name);
|
||||
at::ScalarType type = value.type();
|
||||
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
|
||||
auto* constant_attribute_tensor_proto = attr.mutable_t();
|
||||
constant_attribute_tensor_proto->mutable_dims()->Clear();
|
||||
switch (type) {
|
||||
case at::ScalarType::Float:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<float>();
|
||||
break;
|
||||
case at::ScalarType::Double:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE);
|
||||
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<double>();
|
||||
break;
|
||||
case at::ScalarType::Bool:
|
||||
case at::ScalarType::Int:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<int>();
|
||||
break;
|
||||
case at::ScalarType::Long:
|
||||
constant_attribute_tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
*constant_attribute_tensor_proto->mutable_float_data()->Add() = value.to<int64_t>();
|
||||
break;
|
||||
default:
|
||||
// For most at::ScalarType, it should be safe to just call value.to<>
|
||||
// on it, but for now we want to explicitly know when we've encountered
|
||||
// a new scalar type while bringing up ORT eager mode.
|
||||
ORT_THROW("Unsupported: at::ScalarType::", value.type());
|
||||
}
|
||||
return attr;
|
||||
}
|
||||
else{
|
||||
return create_ort_attribute(name, value, value.type());
|
||||
}
|
||||
}
|
||||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
|
|
|
|||
|
|
@ -90,7 +90,8 @@ OrtValue create_ort_value(
|
|||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
at::Scalar value);
|
||||
at::Scalar value,
|
||||
const bool isTensor=false);
|
||||
|
||||
onnx::AttributeProto create_ort_attribute(
|
||||
const char* name,
|
||||
|
|
|
|||
|
|
@ -114,6 +114,14 @@ class OrtOpTests(unittest.TestCase):
|
|||
x = cpu_tensor.min()
|
||||
assert torch.allclose(x, y.cpu())
|
||||
|
||||
def test_torch_ones(self):
|
||||
device = self.get_device()
|
||||
cpu_ones = torch.ones((10,10))
|
||||
ort_ones = cpu_ones.to(device)
|
||||
ort_ones_device = torch.ones((10, 10), device = device)
|
||||
assert torch.allclose(cpu_ones, ort_ones.cpu())
|
||||
assert torch.allclose(cpu_ones, ort_ones_device.cpu())
|
||||
|
||||
def test_narrow(self):
|
||||
cpu_tensor = torch.rand(10, 10)
|
||||
cpu_narrow = cpu_tensor.narrow(0, 5, 5)
|
||||
|
|
|
|||
Loading…
Reference in a new issue