Remove monkey patch for PyTorch Nightly + ORTTrainer (#6659)

This commit is contained in:
Thiago Crepaldi 2021-02-16 17:24:50 -08:00 committed by GitHub
parent ff465483b1
commit 7ee5baa60d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 32 deletions

View file

@ -10,7 +10,7 @@ import numpy as np
from inspect import signature
from torch.utils.dlpack import from_dlpack
from torch._six import container_abcs
from collections import abc
# Needed to re-implement PyTorch's cpu,cuda,to methods
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
@ -134,9 +134,9 @@ def _parse_outputs_for_onnx_export(module, inputs):
output_dynamic_axes = {}
if isinstance(sample_outputs, torch.Tensor):
output_names, output_dynamic_axes = _create_output_dim_names(sample_outputs, 0, False)
elif isinstance(sample_outputs, container_abcs.Mapping):
elif isinstance(sample_outputs, abc.Mapping):
raise NotImplementedError('Dictionaries are not supported as output yet')
elif isinstance(sample_outputs, container_abcs.Sequence):
elif isinstance(sample_outputs, abc.Sequence):
for idx, out in enumerate(sample_outputs):
tmp_output_names, tmp_output_dynamic_axes = _create_output_dim_names(out, idx, True)
output_names += tmp_output_names

View file

@ -14,29 +14,6 @@ from .model_desc_validation import _ORTTrainerModelDesc
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
def monkey_patch_pytorch():
warnings.warn('ORTTrainer: Remove this monkey patch when https://github.com/pytorch/pytorch/pull/51396 is merged')
def ort_prim_ConstantChunk(g, self, chunks, dim):
input_shape = g.op("Shape", self)
axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0)
start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long))
chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long))
input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1)
chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size)
res = []
for i in range(chunks):
index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long))
end = g.op("Mul", chunk_dim, index)
res.append(g.op("Slice", self, start, end, axis))
start = end
return res
import torch.onnx.symbolic_opset11
torch.onnx.symbolic_opset11.prim_ConstantChunk = ort_prim_ConstantChunk
class TrainStepInfo(object):
r"""Private class used to store runtime information from current train step.
@ -146,12 +123,6 @@ class ORTTrainer(object):
"""
def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
# DO NOT MERGE THIS ON MASTER
# TODO: Remove after https://github.com/pytorch/pytorch/pull/51396 is merged
monkey_patch_pytorch()
# Basic validation
assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model"
assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'"
assert isinstance(optim_config, optim._OptimizerConfig),\