mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Remove monkey patch for PyTorch Nightly + ORTTrainer (#6659)
This commit is contained in:
parent
ff465483b1
commit
7ee5baa60d
2 changed files with 3 additions and 32 deletions
|
|
@ -10,7 +10,7 @@ import numpy as np
|
|||
from inspect import signature
|
||||
|
||||
from torch.utils.dlpack import from_dlpack
|
||||
from torch._six import container_abcs
|
||||
from collections import abc
|
||||
|
||||
# Needed to re-implement PyTorch's cpu,cuda,to methods
|
||||
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
|
||||
|
|
@ -134,9 +134,9 @@ def _parse_outputs_for_onnx_export(module, inputs):
|
|||
output_dynamic_axes = {}
|
||||
if isinstance(sample_outputs, torch.Tensor):
|
||||
output_names, output_dynamic_axes = _create_output_dim_names(sample_outputs, 0, False)
|
||||
elif isinstance(sample_outputs, container_abcs.Mapping):
|
||||
elif isinstance(sample_outputs, abc.Mapping):
|
||||
raise NotImplementedError('Dictionaries are not supported as output yet')
|
||||
elif isinstance(sample_outputs, container_abcs.Sequence):
|
||||
elif isinstance(sample_outputs, abc.Sequence):
|
||||
for idx, out in enumerate(sample_outputs):
|
||||
tmp_output_names, tmp_output_dynamic_axes = _create_output_dim_names(out, idx, True)
|
||||
output_names += tmp_output_names
|
||||
|
|
|
|||
|
|
@ -14,29 +14,6 @@ from .model_desc_validation import _ORTTrainerModelDesc
|
|||
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
|
||||
def monkey_patch_pytorch():
|
||||
warnings.warn('ORTTrainer: Remove this monkey patch when https://github.com/pytorch/pytorch/pull/51396 is merged')
|
||||
|
||||
def ort_prim_ConstantChunk(g, self, chunks, dim):
|
||||
input_shape = g.op("Shape", self)
|
||||
axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
|
||||
input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0)
|
||||
start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
|
||||
chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long))
|
||||
chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long))
|
||||
input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1)
|
||||
chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size)
|
||||
res = []
|
||||
for i in range(chunks):
|
||||
index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long))
|
||||
end = g.op("Mul", chunk_dim, index)
|
||||
res.append(g.op("Slice", self, start, end, axis))
|
||||
start = end
|
||||
return res
|
||||
|
||||
import torch.onnx.symbolic_opset11
|
||||
torch.onnx.symbolic_opset11.prim_ConstantChunk = ort_prim_ConstantChunk
|
||||
|
||||
|
||||
class TrainStepInfo(object):
|
||||
r"""Private class used to store runtime information from current train step.
|
||||
|
|
@ -146,12 +123,6 @@ class ORTTrainer(object):
|
|||
"""
|
||||
|
||||
def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
|
||||
|
||||
# DO NOT MERGE THIS ON MASTER
|
||||
# TODO: Remove after https://github.com/pytorch/pytorch/pull/51396 is merged
|
||||
monkey_patch_pytorch()
|
||||
|
||||
# Basic validation
|
||||
assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model"
|
||||
assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'"
|
||||
assert isinstance(optim_config, optim._OptimizerConfig),\
|
||||
|
|
|
|||
Loading…
Reference in a new issue