fix reshape implementation in eager mode (#9741)

This commit is contained in:
Tang, Cheng 2021-11-18 19:26:49 -08:00 committed by GitHub
parent 7ea19539f8
commit fcc167dd47
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 5 deletions

View file

@ -31,7 +31,7 @@ ops = {
'aten::empty_strided': SignatureOnly(),
'aten::zero_': SignatureOnly(),
'aten::copy_': SignatureOnly(),
'aten::reshape': SignatureOnly(),
'aten::_reshape_alias': SignatureOnly(),
'aten::view': SignatureOnly(),
'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'),

View file

@ -67,6 +67,8 @@ onnxruntime::MLDataType ort_scalar_type_from_aten(
return onnxruntime::DataTypeImpl::GetType<int16_t>();
case at::kLong:
return onnxruntime::DataTypeImpl::GetType<int64_t>();
case at::kBool:
return onnxruntime::DataTypeImpl::GetType<bool>();
default:
ORT_THROW("Unsupport aten scalar type: ", dtype);
}
@ -222,15 +224,20 @@ at::Tensor empty_strided(
.dtype(dtype));
}
at::Tensor reshape(at::Tensor const& self, at::IntArrayRef shape) {
ORT_LOG_FN(self, shape);
at::Tensor _reshape_alias(
const at::Tensor& self,
at::IntArrayRef size,
at::IntArrayRef stride){
ORT_LOG_FN(self, size, stride);
// TODO: support stride
auto& invoker = GetORTInvoker(self.device());
return aten_tensor_from_ort(
reshape_copy(
invoker,
create_ort_value(invoker, self),
shape.vec()),
at::infer_size(
size,
self.numel())),
self.options());
}

View file

@ -19,6 +19,13 @@ class OrtTensorTests(unittest.TestCase):
ort_ones = cpu_ones.to('ort')
assert ort_ones.is_ort
assert torch.allclose(cpu_ones, ort_ones.cpu())
def test_reshape(self):
cpu_ones = torch.ones(10, 10)
ort_ones = cpu_ones.to('ort')
y = ort_ones.reshape(-1)
assert len(y.size()) == 1
assert y.size()[0] == 100
if __name__ == '__main__':
unittest.main()