onnxruntime/orttraining
Jameson Miller 975bb56e8c
Eager mode - argmax_out: set output tensor (#12233)
This change updates the implementation of the argmax_out operator to 1)
set the output tensor correctly and 2) remove the unnecessary use of a
temporary tensor to store the intermediate result of the ONNX ArgMax
operation.

Previously, the argmax_out operator did not correctly update the out
tensor: it replaced the OrtValue itself instead of updating the memory
backing the OrtValue. To properly update the output tensor, we need to
calculate the expected shape of the out tensor.
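
As a loose analogy only (the actual change lives in ONNX Runtime's C++ eager-mode code, not shown here), the following Python sketch contrasts the two behaviors described above. `Tensor`, `set_out_by_rebinding`, and `set_out_in_place` are hypothetical names. The point it illustrates is that an out-parameter has to have its own storage resized to the expected shape and then filled; binding a new object to a local name leaves the caller's tensor untouched.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Tensor:
    """Hypothetical stand-in for an OrtValue: a shape plus a flat data buffer."""
    shape: List[int]
    data: List[int] = field(default_factory=list)


def set_out_by_rebinding(out: Tensor, shape: List[int], values: List[int]) -> None:
    # Buggy pattern (analogy): binds the local name to a new object,
    # so the caller's tensor is never modified.
    out = Tensor(list(shape), list(values))


def set_out_in_place(out: Tensor, shape: List[int], values: List[int]) -> None:
    # Fixed pattern (analogy): check the values against the expected shape,
    # then resize and overwrite the caller's own storage.
    expected = 1
    for d in shape:
        expected *= d
    if len(values) != expected:
        raise ValueError("values do not match the expected output shape")
    out.shape[:] = shape
    out.data[:] = values
```

With `out = Tensor(shape=[0], data=[])`, calling `set_out_by_rebinding(out, [2], [1, 0])` leaves `out` unchanged, while `set_out_in_place(out, [2], [1, 0])` leaves it with shape `[2]` and data `[1, 0]`.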

We add the helper function calculate_reduction_shape to calculate the
shape of the reduced tensor from the input tensor, the dimension to
reduce, and whether to keep the reduced dimension. It is based on the
utility functions in aten/src/ATen/native/ReduceOpsUtils.h in the
PyTorch repository, but is tailored more specifically to our current
needs.
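
A minimal sketch of what such a reduction-shape calculation does, written in Python for illustration. This is not the ONNX Runtime implementation; the signature is an assumption. It follows PyTorch's convention that a negative dim counts from the end, and for simplicity it rejects 0-d inputs.

```python
from typing import List, Sequence


def calculate_reduction_shape(input_shape: Sequence[int], dim: int, keepdim: bool) -> List[int]:
    """Return the shape produced by reducing `input_shape` along `dim`.

    Hypothetical sketch: negative `dim` counts from the end (as in PyTorch);
    0-d inputs are not handled.
    """
    ndim = len(input_shape)
    if ndim == 0:
        raise ValueError("this sketch does not handle 0-d tensors")
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d tensor")
    dim %= ndim  # wrap a negative dim, e.g. -1 -> ndim - 1
    shape = list(input_shape)
    if keepdim:
        shape[dim] = 1  # keep the reduced axis with size 1
    else:
        del shape[dim]  # drop the reduced axis entirely
    return shape
```

For an input of shape `[2, 3, 4]`, reducing along dim `1` without keepdim gives `[2, 4]`, and reducing along dim `-1` with keepdim gives `[2, 3, 1]`.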

Notes:

We considered directly leveraging PyTorch's utility functions (e.g.
get_reduction_shape) from aten/src/ATen/native/ReduceOpsUtils.h in the
PyTorch repository to calculate the shape of the reduced tensor, but
including this header file produced unused-function warnings that we
would need to handle. As we only need limited functionality at the
moment, we instead implemented our own utility function to calculate the
reduction shape for our specific current needs. If we later need a
utility function that calculates the reduction shape more generally, we
could consider switching to the utility functions in PyTorch.
2022-07-19 14:37:03 -04:00
orttraining Eager mode - argmax_out: set output tensor (#12233) 2022-07-19 14:37:03 -04:00
pytorch_frontend_examples Set black's target version (#11370) 2022-04-27 14:52:19 -07:00
tools [UPDATE] update AMD CI pipeline to Rocm5.2 with torch1.11 (#12162) 2022-07-14 16:38:16 +08:00