# onnxruntime/tools/python/register_custom_ops_pytorch_exporter.py


# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Register PyTorch symbolic functions for export using ONNX Runtime contrib ops
from torch.onnx import register_custom_op_symbolic
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _get_tensor_dim_size, _get_tensor_sizes
_onnx_opset_version = 1
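

# Illustrative sketch (hypothetical helper, not used by the exporter): a pure-Python
# reference for what the `triu`/`tril` symbolics below ask of com.microsoft::Trilu.
# Both map to that one contrib op: `upper=1` keeps the upper triangle (torch.triu),
# `upper=0` keeps the lower triangle (torch.tril), and the `diagonal` input (k)
# shifts the boundary. This assumes the contrib op follows the same convention as
# torch.triu/torch.tril; it operates on a 2-D list rather than a tensor.
def _trilu_reference(matrix, k=0, upper=1):
    """Return a copy of a 2-D list with the discarded triangle zeroed out."""
    result = []
    for i, row in enumerate(matrix):
        # Upper keeps entries on or above diagonal k (j - i >= k);
        # lower keeps entries on or below it (j - i <= k).
        result.append([
            value if ((j - i >= k) if upper else (j - i <= k)) else 0
            for j, value in enumerate(row)
        ])
    return result


# Example: _trilu_reference([[1, 2], [3, 4]], k=0, upper=1) gives [[1, 2], [0, 4]].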


def register_custom_op(is_ortmodule=False):
    """
    This function registers symbolic functions for
    custom ops that are implemented as part of ONNX Runtime
    """

    # Symbolic definition
    def inverse(g, self):
        return g.op("com.microsoft::Inverse", self).setType(self.type())

    def gelu(g, self):
        return g.op("com.microsoft::Gelu", self).setType(self.type())

    def triu(g, self, diagonal):
        return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1).setType(self.type())

    def tril(g, self, diagonal):
        return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0).setType(self.type())

    # Op Registration
    register_custom_op_symbolic('::inverse', inverse, _onnx_opset_version)
    register_custom_op_symbolic('::gelu', gelu, _onnx_opset_version)
    register_custom_op_symbolic('::triu', triu, _onnx_opset_version)
    register_custom_op_symbolic('::tril', tril, _onnx_opset_version)

    if is_ortmodule:
        @parse_args('v', 'v', 'i', 'b', 'b')
        def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
            custom_attributes_json = (
                '{'
                f'"padding_idx":{str(padding_idx)},'
                f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
                f'"sparse":{str(sparse).lower()}'
                '}'
            )
            # Dispatch to the ATen kernel through ORT's ATenOp contrib op.
            output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
                          custom_attributes_json_s=custom_attributes_json)
            # Set the output type explicitly (indices shape + embedding dim): the exporter
            # cannot run standard ONNX shape inference on a contrib op.
            indices_shape = _get_tensor_sizes(indices)
            if indices_shape is not None and hasattr(weight.type(), 'with_sizes'):
                output_type = weight.type().with_sizes(indices_shape + [_get_tensor_dim_size(weight, 1)])
                output.setType(output_type)
            return output

        register_custom_op_symbolic('::embedding', embedding, _onnx_opset_version)

        # ignore_index is parsed as a graph value ('v') rather than a constant,
        # so it may be a dynamic input.
        @parse_args('v', 'v', 'v', 'i', 'v')
        def cross_entropy_loss(g, self, target, weight, reduction, ignore_index):
            # reduction: 0->none, 1->mean, 2->sum
            reduction = sym_help._maybe_get_const(reduction, 'i')
            reduction_vals = ['none', 'mean', 'sum']
            reduction = reduction_vals[reduction]
            output, log_prob = g.op("com.microsoft::SoftmaxCrossEntropyLossInternal",
                                    self, target, weight, ignore_index,
                                    reduction_s=reduction, outputs=2)
            output.setType(self.type())
            log_prob.setType(self.type())
            return output

        register_custom_op_symbolic('::cross_entropy_loss', cross_entropy_loss, _onnx_opset_version)

        @parse_args('v', 'v', 'v', 'i', 'v')
        def nll_loss(g, self, target, weight, reduction, ignore_index):
            # reduction: 0->none, 1->mean, 2->sum
            reduction = sym_help._maybe_get_const(reduction, 'i')
            reduction_vals = ['none', 'mean', 'sum']
            reduction = reduction_vals[reduction]
            output = g.op("com.microsoft::NegativeLogLikelihoodLossInternal",
                          self, target, weight, ignore_index, reduction_s=reduction)
            output.setType(self.type())
            return output

        register_custom_op_symbolic('::nll_loss', nll_loss, _onnx_opset_version)

        @parse_args('v', 'is', 'is', 'is', 'is', 'b')
        def max_pool2d(g, self, kernel_size, stride, padding, dilation, ceil_mode):
            custom_attributes_json = (
                '{'
                f'"kernel_size":{str(kernel_size)},'
                f'"stride":{str(stride)},'
                f'"padding":{str(padding)},'
                f'"dilation":{str(dilation)},'
                f'"ceil_mode":{str(ceil_mode).lower()}'
                '}'
            )
            # aten::max_pool2d_with_indices has two outputs (values, indices);
            # only the pooled values are returned.
            return g.op("com.microsoft::ATenOp", self, name_s='aten::max_pool2d_with_indices',
                        custom_attributes_json_s=custom_attributes_json, outputs=2)[0]

        register_custom_op_symbolic('::max_pool2d', max_pool2d, _onnx_opset_version)

        @parse_args('v', 'i', 'i', 'i')
        def unfold(g, input, dimension, size, step):
            custom_attributes_json = (
                '{'
                f'"dimension":{str(dimension)},'
                f'"size":{str(size)},'
                f'"step":{str(step)}'
                '}'
            )
            return g.op("com.microsoft::ATenOp", input, name_s='aten::unfold',
                        custom_attributes_json_s=custom_attributes_json)

        register_custom_op_symbolic('::unfold', unfold, _onnx_opset_version)
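

# Illustrative sketch (hypothetical helper, not used by the exporter): the ATenOp
# symbolics above (embedding, max_pool2d, unfold) hand-assemble
# `custom_attributes_json` with f-strings, relying on `str()` of ints and int lists
# and on `str(bool).lower()` to produce valid JSON. The helper below builds an
# equivalent payload with `json.dumps`; whitespace may differ (e.g. `[3, 3]` vs
# `[3,3]`), but both strings decode to the same object.
import json


def _custom_attributes_json_sketch(**attributes):
    """Serialize ATen attributes as a compact JSON object, keeping argument order."""
    return json.dumps(attributes, separators=(",", ":"))


# Example: _custom_attributes_json_sketch(padding_idx=-1, scale_grad_by_freq=False, sparse=False)
# gives '{"padding_idx":-1,"scale_grad_by_freq":false,"sparse":false}', the same
# string the embedding symbolic builds for those arguments.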


def unregister_custom_op():
    """
    This function unregisters symbolic functions for
    custom ops that are implemented as part of ONNX Runtime
    """

    import torch.onnx.symbolic_registry as sym_registry

    # TODO: replace this once PyTorch supports unregister natively.
    def unregister(name, opset_version):
        ns, kind = name.split("::")
        from torch.onnx.symbolic_helper import _onnx_stable_opsets

        for version in _onnx_stable_opsets:
            if version >= opset_version and sym_registry.is_registered_op(kind, ns, version):
                del sym_registry._registry[(ns, version)][kind]

    unregister('::inverse', _onnx_opset_version)
    unregister('::gelu', _onnx_opset_version)
    unregister('::triu', _onnx_opset_version)
    unregister('::tril', _onnx_opset_version)
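

# Illustrative sketch (hypothetical helper, not used by the exporter): a
# dependency-free model of `unregister` above. As we understand the PyTorch
# versions this file targets, `register_custom_op_symbolic` stores a symbolic in a
# registry keyed by (domain, opset_version) for every opset >= the requested one,
# so removal has to walk each of those versions. The plain dict and list arguments
# stand in for `sym_registry._registry` and `_onnx_stable_opsets`.
def _unregister_sketch(registry, stable_opsets, name, opset_version):
    """Remove `name` ("domain::kind") from every listed opset >= opset_version."""
    ns, kind = name.split("::")
    for version in stable_opsets:
        table = registry.get((ns, version))
        if version >= opset_version and table is not None and kind in table:
            del table[kind]
    return registry


# Example: with registry {("", 7): {"inverse": "a"}, ("", 9): {"inverse": "b"}},
# _unregister_sketch(registry, [7, 9], "::inverse", 9) removes only the opset-9 entry.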