From e85e31ee803555f7b2eec4f7b11a1d8645b2c115 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Fri, 5 Aug 2022 16:55:04 +0800 Subject: [PATCH] Update ORTModule Default Opset Version to 15 (#12419) * update ortmodule opset to 15 * update torch version * fix ut * fix ut * rollback * rollback for orttrainer --- orttraining/orttraining/python/ort_trainer.py | 15 +++---- .../orttraining/python/training/_utils.py | 13 ++++-- .../python/training/ortmodule/__init__.py | 9 ++-- .../python/training/orttrainer_options.py | 11 ++--- .../python/onnxruntime_test_postprocess.py | 21 +++++---- .../orttraining_test_onnx_ops_ortmodule.py | 6 ++- .../python/orttraining_test_ortmodule_api.py | 6 +-- .../orttraining_test_orttrainer_frontend.py | 44 ++++++++----------- .../test/python/orttraining_test_utils.py | 19 +++----- ...training-py-packaging-pipeline-cuda113.yml | 2 +- ...training-py-packaging-pipeline-cuda115.yml | 2 +- .../docker/Dockerfile.manylinux2014_rocm5.1.1 | 4 +- .../docker/Dockerfile.manylinux2014_rocm5.2 | 2 +- ...Dockerfile.manylinux2014_training_cuda11_3 | 2 +- ...Dockerfile.manylinux2014_training_cuda11_5 | 2 +- .../docker/scripts/install_python_deps.sh | 2 +- .../pai/rocm-ci-pipeline-env.Dockerfile | 2 +- 17 files changed, 78 insertions(+), 84 deletions(-) diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index 80c04fc0ed..ec159d83bd 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -1,23 +1,22 @@ import io import os import warnings +from distutils.version import LooseVersion + import numpy as np import onnx -from onnx import numpy_helper -from onnx import helper import torch import torch.nn import torch.onnx +from onnx import helper, numpy_helper + import onnxruntime as ort -from ..training import postprocess -from distutils.version import LooseVersion -import warnings - -from .checkpointing_utils import list_checkpoint_files, get_checkpoint_name, CombineZeroCheckpoint import onnxruntime.capi.pt_patch - from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference +from ..training import postprocess +from .checkpointing_utils import CombineZeroCheckpoint, get_checkpoint_name, list_checkpoint_files + DEFAULT_OPSET_VERSION = 14 diff --git a/orttraining/orttraining/python/training/_utils.py b/orttraining/orttraining/python/training/_utils.py index d42e1d9ef5..979ecfbde6 100644 --- a/orttraining/orttraining/python/training/_utils.py +++ b/orttraining/orttraining/python/training/_utils.py @@ -4,13 +4,14 @@ # -------------------------------------------------------------------------- import importlib.util -import numpy as np import os import sys +from functools import wraps + +import numpy as np import torch from onnx import TensorProto - -from functools import wraps +from packaging.version import Version def get_device_index(device): @@ -94,7 +95,11 @@ def dtype_torch_to_numpy(torch_dtype): return np.int8 elif torch_dtype == torch.uint8: return np.uint8 - elif torch_dtype == torch.complex32 or torch_dtype == torch.complex64: + elif torch_dtype == torch.complex64 or ( + # complex32 is missing in torch-1.11. + (Version(torch.__version__) < Version("1.11.0") or Version(torch.__version__) >= Version("1.12.0")) + and torch_dtype == torch.complex32 + ): # NOTE: numpy doesn't support complex32 return np.complex64 elif torch_dtype == torch.complex128 or torch_dtype == torch.cdouble: diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index 4e6c3a21f4..f6ed8827bd 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -6,12 +6,14 @@ import os import sys import warnings + import torch from packaging import version from onnxruntime import set_seed from onnxruntime.capi import build_and_package_info as ort_info -from ._fallback import _FallbackPolicy, ORTModuleFallbackException, ORTModuleInitException, wrap_exception + +from ._fallback import ORTModuleFallbackException, ORTModuleInitException, _FallbackPolicy, wrap_exception from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed @@ -33,7 +35,7 @@ def _defined_from_envvar(name, default_value, warn=True): # NOTE: To *change* values in runtime, import onnxruntime.training.ortmodule and # assign them new values. Importing them directly do not propagate changes. ################################################################################ -ONNX_OPSET_VERSION = 14 +ONNX_OPSET_VERSION = 15 MINIMUM_RUNTIME_PYTORCH_VERSION_STR = "1.8.1" ORTMODULE_TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__), "torch_cpp_extensions") _FALLBACK_INIT_EXCEPTION = None @@ -118,6 +120,7 @@ def _are_deterministic_algorithms_enabled(): return ORTMODULE_IS_DETERMINISTIC +from .debug_options import DebugOptions, LogLevel # noqa: E402 + # ORTModule must be loaded only after all validation passes from .ortmodule import ORTModule # noqa: E402 -from .debug_options import DebugOptions, LogLevel # noqa: E402 diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py index 080b8202a5..9e7a2bde4d 100644 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ b/orttraining/orttraining/python/training/orttrainer_options.py @@ -1,11 +1,12 @@ import cerberus import torch -from .optim import lr_scheduler -from .amp import loss_scaler -from . import PropagateCastOpsStrategy import onnxruntime as ort +from . import PropagateCastOpsStrategy +from .amp import loss_scaler +from .optim import lr_scheduler + class ORTTrainerOptions(object): r"""Settings used by ONNX Runtime training backend @@ -291,8 +292,8 @@ class ORTTrainerOptions(object): 'onnx_opset_version': { 'type': 'integer', 'min' : 12, - 'max' : 13, - 'default': 12 + 'max' :14, + 'default': 14 }, 'enable_onnx_contrib_ops' : { 'type' : 'boolean', diff --git a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py b/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py index e91dd480b7..83bd524e7d 100644 --- a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py +++ b/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py @@ -1,22 +1,21 @@ -import unittest -import pytest -import sys -import os import copy -from numpy.testing import assert_allclose, assert_array_equal +import os +import sys +import unittest import onnx +import pytest import torch import torch.nn as nn import torch.nn.functional as F - -from orttraining_test_utils import map_optimizer_attributes -from orttraining_test_transformers import BertModelTest, BertForPreTraining -from orttraining_test_data_loader import create_ort_test_dataloader +from numpy.testing import assert_allclose, assert_array_equal from orttraining_test_bert_postprocess import postprocess_model -import onnxruntime +from orttraining_test_data_loader import create_ort_test_dataloader +from orttraining_test_transformers import BertForPreTraining, BertModelTest +from orttraining_test_utils import map_optimizer_attributes -from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler, generate_sample +import onnxruntime +from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer, generate_sample torch.manual_seed(1) onnxruntime.set_seed(1) diff --git a/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py b/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py index 1411d43be7..4c84d7ffb6 100644 --- a/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py +++ b/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py @@ -1,10 +1,12 @@ """ @brief test log(time=3s) """ -import unittest import copy +import unittest + import numpy as np import torch + from onnxruntime.training.ortmodule import ORTModule @@ -77,7 +79,7 @@ class TestOnnxOpsOrtModule(unittest.TestCase): for onnx_model in [onnx_graph_inf, onnx_graph_train]: for oimp in onnx_model.opset_import: if oimp.domain == "": - self.assertEqual(oimp.version, 14) + self.assertEqual(oimp.version, 15) if op_grad_type is not None: if isinstance(op_grad_type, tuple): text = str(onnx_graph_train) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 32543b7226..e824cbda53 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -39,7 +39,7 @@ from onnxruntime.training.ortmodule import ( ) from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient -DEFAULT_OPSET = 14 +DEFAULT_OPSET = 15 # PyTorch model definitions for tests @@ -5079,7 +5079,7 @@ def test_sigmoid_grad_opset13(): old_opst_cst = ortmodule.ONNX_OPSET_VERSION old_opset = os.getenv("ORTMODULE_ONNX_OPSET_VERSION", None) os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = "13" - assert ortmodule.ONNX_OPSET_VERSION == 14 + assert ortmodule.ONNX_OPSET_VERSION == 15 ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -5109,7 +5109,7 @@ def test_sigmoid_grad_opset13(): ortmodule.ONNX_OPSET_VERSION = old_opst_cst -@pytest.mark.parametrize("opset_version", [12, 13, 14]) +@pytest.mark.parametrize("opset_version", [12, 13, 14, 15]) def test_opset_version_change(opset_version): device = "cuda" diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index 924750e738..85234d3d92 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -1,35 +1,27 @@ -from functools import partial import inspect import math -from distutils.version import StrictVersion -from numpy.testing import assert_allclose -import onnx import os -import pytest import tempfile +from distutils.version import StrictVersion +from functools import partial + +import _test_commons +import _test_helpers +import onnx +import pytest import torch import torch.nn.functional as F +from numpy.testing import assert_allclose -from onnxruntime import set_seed -from onnxruntime.capi.ort_trainer import ( - IODescription as Legacy_IODescription, - ModelDescription as Legacy_ModelDescription, - LossScaler as Legacy_LossScaler, - ORTTrainer as Legacy_ORTTrainer, -) -from onnxruntime.training import ( - _utils, - amp, - checkpoint, - optim, - orttrainer, - TrainStepInfo, - model_desc_validation as md_val, - orttrainer_options as orttrainer_options, -) -import _test_commons, _test_helpers -from onnxruntime import SessionOptions -from onnxruntime.training import PropagateCastOpsStrategy +from onnxruntime import SessionOptions, set_seed +from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription +from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler +from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription +from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer +from onnxruntime.training import PropagateCastOpsStrategy, TrainStepInfo, _utils, amp, checkpoint +from onnxruntime.training import model_desc_validation as md_val +from onnxruntime.training import optim, orttrainer +from onnxruntime.training import orttrainer_options as orttrainer_options ############################################################################### # Testing starts here ######################################################### @@ -708,7 +700,7 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device) assert trainer._onnx_model.graph.output[i].name == output_name for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim): - if opset != 14: + if opset is None or opset <= 12: assert output_dim[dim_idx] == dim.dim_value assert output_type == _utils.dtype_onnx_to_torch( trainer._onnx_model.graph.output[i].type.tensor_type.elem_type diff --git a/orttraining/orttraining/test/python/orttraining_test_utils.py b/orttraining/orttraining/test/python/orttraining_test_utils.py index 0ec501b0d7..7397cc9d51 100644 --- a/orttraining/orttraining/test/python/orttraining_test_utils.py +++ b/orttraining/orttraining/test/python/orttraining_test_utils.py @@ -1,19 +1,12 @@ import torch - -from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription - -from orttraining_test_data_loader import create_ort_test_dataloader, BatchArgsOption, split_batch from orttraining_test_bert_postprocess import postprocess_model +from orttraining_test_data_loader import BatchArgsOption, create_ort_test_dataloader, split_batch -from onnxruntime.training import ( - _utils, - amp, - optim, - orttrainer, - TrainStepInfo, - model_desc_validation as md_val, - orttrainer_options as orttrainer_options, -) +from onnxruntime.capi.ort_trainer import IODescription, ORTTrainer +from onnxruntime.training import TrainStepInfo, _utils, amp +from onnxruntime.training import model_desc_validation as md_val +from onnxruntime.training import optim, orttrainer +from onnxruntime.training import orttrainer_options as orttrainer_options from onnxruntime.training.optim import _LRScheduler diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda113.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda113.yml index 47334e74bb..421d0d377f 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda113.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda113.yml @@ -13,7 +13,7 @@ stages: parameters: build_py_parameters: --enable_training --update --build torch_version: '1.11.0' - opset_version: '14' + opset_version: '15' cuda_version: '11.3' gcc_version: 10 cmake_cuda_architectures: 37;50;52;60;61;70;75;80;86 diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda115.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda115.yml index 99ac852921..05ca414826 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda115.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda115.yml @@ -13,7 +13,7 @@ stages: parameters: build_py_parameters: --enable_training --update --build torch_version: '1.11.0' - opset_version: '14' + opset_version: '15' cuda_version: '11.5' gcc_version: 10 cmake_cuda_architectures: 37;50;52;60;61;70;75;80;86;87 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm5.1.1 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm5.1.1 index 7bedb994c7..5ee0fcccb7 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm5.1.1 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm5.1.1 @@ -75,7 +75,7 @@ RUN export LIBXCRYPT_VERSION=4.4.28 && \ export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \ export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \ manylinux-entrypoint /build_scripts/install-libxcrypt.sh - + COPY scripts/install-protobuf.sh /build_scripts/ RUN export PROTOBUF_VERSION=3.17.3 && \ export PROTOBUF_ROOT=protobuf-all-${PROTOBUF_VERSION} && \ @@ -163,7 +163,7 @@ CMD ["/bin/bash"] ARG PYTHON_VERSION=3.7 ARG TORCH_VERSION=1.11.0 -ARG OPSET_VERSION=14 +ARG OPSET_VERSION=15 ARG INSTALL_DEPS_EXTRA_ARGS #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm5.2 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm5.2 index c0b7bb97cc..4edd0c8203 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm5.2 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm5.2 @@ -163,7 +163,7 @@ CMD ["/bin/bash"] ARG PYTHON_VERSION=3.7 ARG TORCH_VERSION=1.11.0 -ARG OPSET_VERSION=14 +ARG OPSET_VERSION=15 ARG INSTALL_DEPS_EXTRA_ARGS #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_3 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_3 index 9108a713df..e09d239f5b 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_3 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_3 @@ -166,7 +166,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ARG PYTHON_VERSION=3.9 ARG TORCH_VERSION=1.11.0 -ARG OPSET_VERSION=14 +ARG OPSET_VERSION=15 ARG INSTALL_DEPS_EXTRA_ARGS #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_5 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_5 index dbeca18c89..d79a32fbe5 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_5 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_5 @@ -166,7 +166,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ARG PYTHON_VERSION=3.9 ARG TORCH_VERSION=1.11.0 -ARG OPSET_VERSION=14 +ARG OPSET_VERSION=15 ARG INSTALL_DEPS_EXTRA_ARGS #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh b/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh index e079f00edf..f1bf87c667 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh @@ -7,7 +7,7 @@ ORTMODULE_BUILD=false TARGET_ROCM=false CU_VER="11.1" ROCM_VER="5.1.1" -TORCH_VERSION='1.10.0' +TORCH_VERSION='1.11.0' USE_CONDA=false while getopts p:h:d:v:tmurc parameter_Option diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 63deb4b920..5955f672f2 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -41,4 +41,4 @@ RUN pip install \ pytorch_lightning==1.6.0 RUN pip install torch-ort --no-dependencies -ENV ORTMODULE_ONNX_OPSET_VERSION=14 +ENV ORTMODULE_ONNX_OPSET_VERSION=15