mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
This change updates the implementation of the argmax_out operator to 1) set the output tensor correctly and 2) remove the unnecessary temporary tensor that stored the intermediate result of the ONNX ArgMax operation. Previously, the argmax_out operator did not correctly update the out tensor: it replaced the OrtValue instead of writing to the memory backing the OrtValue. To update the output tensor properly, we need to calculate the expected shape of the out tensor. We add the helper function calculate_reduction_shape, which calculates the shape of the reduced tensor from the input tensor, the dimension to reduce, and the option to keep the reduced dimension or not. It is based on the utility functions in aten/src/ATen/native/ReduceOpsUtils.h in the PyTorch repository, but is tailored to our current needs. Notes: We considered directly using PyTorch's utility functions from that header (e.g. get_reduction_shape) to calculate the shape of the reduced tensor, but including the header produced unused-function warnings that we would have had to handle. As we only need limited functionality at the moment, we implemented our own utility function to calculate the reduction shape. If we later need a more general reduction-shape utility, we could consider switching to the utility methods in PyTorch. |
||
|---|---|---|
| .. | ||
| orttraining | ||
| pytorch_frontend_examples | ||
| tools | ||
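
The reduction-shape calculation described in the commit message above can be sketched as follows. This is a minimal illustration, not the code from the change: the signature (a `std::vector<int64_t>` shape, a single `dim`, and a `keepdim` flag) is an assumption, and the real calculate_reduction_shape helper may take a tensor and differ in its details. Zero-dimensional inputs and the "reduce over all elements" case are deliberately left out.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical signature for illustration only; the helper in the actual
// change may accept a tensor rather than a plain shape vector.
std::vector<int64_t> calculate_reduction_shape(
    const std::vector<int64_t>& input_shape, int64_t dim, bool keepdim) {
  const int64_t rank = static_cast<int64_t>(input_shape.size());

  // Accept dims in [-rank, rank); this also rejects rank-0 inputs,
  // which this sketch does not handle.
  if (dim < -rank || dim >= rank) {
    throw std::out_of_range("dim is out of range for the input shape");
  }

  // Normalize a negative dim so that -1 refers to the last dimension.
  if (dim < 0) {
    dim += rank;
  }

  std::vector<int64_t> shape(input_shape);
  if (keepdim) {
    // The reduced dimension is kept with size 1.
    shape[dim] = 1;
  } else {
    // The reduced dimension is dropped entirely.
    shape.erase(shape.begin() + dim);
  }
  return shape;
}
```

Under these assumptions, reducing an input of shape {2, 3, 4} along dim 1 gives {2, 4} without keepdim and {2, 1, 4} with it. An out tensor can then be resized to that shape and written in place, rather than having its backing value replaced.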