mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-20 21:40:57 +00:00
fix reshape implementation in eager mode (#9741)
This commit is contained in:
parent
7ea19539f8
commit
fcc167dd47
3 changed files with 19 additions and 5 deletions
|
|
@ -31,7 +31,7 @@ ops = {
|
|||
'aten::empty_strided': SignatureOnly(),
|
||||
'aten::zero_': SignatureOnly(),
|
||||
'aten::copy_': SignatureOnly(),
|
||||
'aten::reshape': SignatureOnly(),
|
||||
'aten::_reshape_alias': SignatureOnly(),
|
||||
'aten::view': SignatureOnly(),
|
||||
|
||||
'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'),
|
||||
|
|
|
|||
|
|
@ -67,6 +67,8 @@ onnxruntime::MLDataType ort_scalar_type_from_aten(
|
|||
return onnxruntime::DataTypeImpl::GetType<int16_t>();
|
||||
case at::kLong:
|
||||
return onnxruntime::DataTypeImpl::GetType<int64_t>();
|
||||
case at::kBool:
|
||||
return onnxruntime::DataTypeImpl::GetType<bool>();
|
||||
default:
|
||||
ORT_THROW("Unsupport aten scalar type: ", dtype);
|
||||
}
|
||||
|
|
@ -222,15 +224,20 @@ at::Tensor empty_strided(
|
|||
.dtype(dtype));
|
||||
}
|
||||
|
||||
at::Tensor reshape(at::Tensor const& self, at::IntArrayRef shape) {
|
||||
ORT_LOG_FN(self, shape);
|
||||
|
||||
at::Tensor _reshape_alias(
|
||||
const at::Tensor& self,
|
||||
at::IntArrayRef size,
|
||||
at::IntArrayRef stride){
|
||||
ORT_LOG_FN(self, size, stride);
|
||||
// TODO: support stride
|
||||
auto& invoker = GetORTInvoker(self.device());
|
||||
return aten_tensor_from_ort(
|
||||
reshape_copy(
|
||||
invoker,
|
||||
create_ort_value(invoker, self),
|
||||
shape.vec()),
|
||||
at::infer_size(
|
||||
size,
|
||||
self.numel())),
|
||||
self.options());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,13 @@ class OrtTensorTests(unittest.TestCase):
|
|||
ort_ones = cpu_ones.to('ort')
|
||||
assert ort_ones.is_ort
|
||||
assert torch.allclose(cpu_ones, ort_ones.cpu())
|
||||
|
||||
def test_reshape(self):
|
||||
cpu_ones = torch.ones(10, 10)
|
||||
ort_ones = cpu_ones.to('ort')
|
||||
y = ort_ones.reshape(-1)
|
||||
assert len(y.size()) == 1
|
||||
assert y.size()[0] == 100
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in a new issue