From 7ee5baa60db7ecfbfe956b0c8d58f1c2461f6972 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Tue, 16 Feb 2021 17:24:50 -0800 Subject: [PATCH] Remove monkey patch for PyTorch Nightly + ORTTrainer (#6659) --- .../orttraining/python/training/ortmodule.py | 6 ++-- .../orttraining/python/training/orttrainer.py | 29 ------------------- 2 files changed, 3 insertions(+), 32 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index 067d99dc50..b9e3fe5824 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -10,7 +10,7 @@ import numpy as np from inspect import signature from torch.utils.dlpack import from_dlpack -from torch._six import container_abcs +from collections import abc # Needed to re-implement PyTorch's cpu,cuda,to methods from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict @@ -134,9 +134,9 @@ def _parse_outputs_for_onnx_export(module, inputs): output_dynamic_axes = {} if isinstance(sample_outputs, torch.Tensor): output_names, output_dynamic_axes = _create_output_dim_names(sample_outputs, 0, False) - elif isinstance(sample_outputs, container_abcs.Mapping): + elif isinstance(sample_outputs, abc.Mapping): raise NotImplementedError('Dictionaries are not supported as output yet') - elif isinstance(sample_outputs, container_abcs.Sequence): + elif isinstance(sample_outputs, abc.Sequence): for idx, out in enumerate(sample_outputs): tmp_output_names, tmp_output_dynamic_axes = _create_output_dim_names(out, idx, True) output_names += tmp_output_names diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index 798c3b416a..2348d9094d 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -14,29 +14,6 @@ from .model_desc_validation import _ORTTrainerModelDesc from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -def monkey_patch_pytorch(): - warnings.warn('ORTTrainer: Remove this monkey patch when https://github.com/pytorch/pytorch/pull/51396 is merged') - - def ort_prim_ConstantChunk(g, self, chunks, dim): - input_shape = g.op("Shape", self) - axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) - input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0) - start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) - chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)) - chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long)) - input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1) - chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size) - res = [] - for i in range(chunks): - index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long)) - end = g.op("Mul", chunk_dim, index) - res.append(g.op("Slice", self, start, end, axis)) - start = end - return res - - import torch.onnx.symbolic_opset11 - torch.onnx.symbolic_opset11.prim_ConstantChunk = ort_prim_ConstantChunk - class TrainStepInfo(object): r"""Private class used to store runtime information from current train step. @@ -146,12 +123,6 @@ class ORTTrainer(object): """ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None): - - # DO NOT MERGE THIS ON MASTER - # TODO: Remove after https://github.com/pytorch/pytorch/pull/51396 is merged - monkey_patch_pytorch() - - # Basic validation assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model" assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'" assert isinstance(optim_config, optim._OptimizerConfig),\