Update ORTModule Default Opset Version to 15 (#12419)

* update ortmodule opset to 15

* update torch version

* fix ut

* fix ut

* rollback

* rollback for orttrainer
This commit is contained in:
Vincent Wang 2022-08-05 16:55:04 +08:00 committed by GitHub
parent a7d6290774
commit e85e31ee80
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 78 additions and 84 deletions

View file

@@ -1,23 +1,22 @@
import io
import os
import warnings
from distutils.version import LooseVersion
import numpy as np
import onnx
from onnx import numpy_helper
from onnx import helper
import torch
import torch.nn
import torch.onnx
from onnx import helper, numpy_helper
import onnxruntime as ort
from ..training import postprocess
from distutils.version import LooseVersion
import warnings
from .checkpointing_utils import list_checkpoint_files, get_checkpoint_name, CombineZeroCheckpoint
import onnxruntime.capi.pt_patch
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from ..training import postprocess
from .checkpointing_utils import CombineZeroCheckpoint, get_checkpoint_name, list_checkpoint_files
DEFAULT_OPSET_VERSION = 14

View file

@@ -4,13 +4,14 @@
# --------------------------------------------------------------------------
import importlib.util
import numpy as np
import os
import sys
from functools import wraps
import numpy as np
import torch
from onnx import TensorProto
from functools import wraps
from packaging.version import Version
def get_device_index(device):
@@ -94,7 +95,11 @@ def dtype_torch_to_numpy(torch_dtype):
return np.int8
elif torch_dtype == torch.uint8:
return np.uint8
elif torch_dtype == torch.complex32 or torch_dtype == torch.complex64:
elif torch_dtype == torch.complex64 or (
# complex32 is missing in torch-1.11.
(Version(torch.__version__) < Version("1.11.0") or Version(torch.__version__) >= Version("1.12.0"))
and torch_dtype == torch.complex32
):
# NOTE: numpy doesn't support complex32
return np.complex64
elif torch_dtype == torch.complex128 or torch_dtype == torch.cdouble:

View file

@@ -6,12 +6,14 @@
import os
import sys
import warnings
import torch
from packaging import version
from onnxruntime import set_seed
from onnxruntime.capi import build_and_package_info as ort_info
from ._fallback import _FallbackPolicy, ORTModuleFallbackException, ORTModuleInitException, wrap_exception
from ._fallback import ORTModuleFallbackException, ORTModuleInitException, _FallbackPolicy, wrap_exception
from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed
@@ -33,7 +35,7 @@ def _defined_from_envvar(name, default_value, warn=True):
# NOTE: To *change* values in runtime, import onnxruntime.training.ortmodule and
# assign them new values. Importing them directly do not propagate changes.
################################################################################
ONNX_OPSET_VERSION = 14
ONNX_OPSET_VERSION = 15
MINIMUM_RUNTIME_PYTORCH_VERSION_STR = "1.8.1"
ORTMODULE_TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__), "torch_cpp_extensions")
_FALLBACK_INIT_EXCEPTION = None
@@ -118,6 +120,7 @@ def _are_deterministic_algorithms_enabled():
return ORTMODULE_IS_DETERMINISTIC
from .debug_options import DebugOptions, LogLevel # noqa: E402
# ORTModule must be loaded only after all validation passes
from .ortmodule import ORTModule # noqa: E402
from .debug_options import DebugOptions, LogLevel # noqa: E402

View file

@@ -1,11 +1,12 @@
import cerberus
import torch
from .optim import lr_scheduler
from .amp import loss_scaler
from . import PropagateCastOpsStrategy
import onnxruntime as ort
from . import PropagateCastOpsStrategy
from .amp import loss_scaler
from .optim import lr_scheduler
class ORTTrainerOptions(object):
r"""Settings used by ONNX Runtime training backend
@@ -291,8 +292,8 @@ class ORTTrainerOptions(object):
'onnx_opset_version': {
'type': 'integer',
'min' : 12,
'max' : 13,
'default': 12
'max' :14,
'default': 14
},
'enable_onnx_contrib_ops' : {
'type' : 'boolean',

View file

@@ -1,22 +1,21 @@
import unittest
import pytest
import sys
import os
import copy
from numpy.testing import assert_allclose, assert_array_equal
import os
import sys
import unittest
import onnx
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from orttraining_test_utils import map_optimizer_attributes
from orttraining_test_transformers import BertModelTest, BertForPreTraining
from orttraining_test_data_loader import create_ort_test_dataloader
from numpy.testing import assert_allclose, assert_array_equal
from orttraining_test_bert_postprocess import postprocess_model
import onnxruntime
from orttraining_test_data_loader import create_ort_test_dataloader
from orttraining_test_transformers import BertForPreTraining, BertModelTest
from orttraining_test_utils import map_optimizer_attributes
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler, generate_sample
import onnxruntime
from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer, generate_sample
torch.manual_seed(1)
onnxruntime.set_seed(1)

View file

@@ -1,10 +1,12 @@
"""
@brief test log(time=3s)
"""
import unittest
import copy
import unittest
import numpy as np
import torch
from onnxruntime.training.ortmodule import ORTModule
@@ -77,7 +79,7 @@ class TestOnnxOpsOrtModule(unittest.TestCase):
for onnx_model in [onnx_graph_inf, onnx_graph_train]:
for oimp in onnx_model.opset_import:
if oimp.domain == "":
self.assertEqual(oimp.version, 14)
self.assertEqual(oimp.version, 15)
if op_grad_type is not None:
if isinstance(op_grad_type, tuple):
text = str(onnx_graph_train)

View file

@@ -39,7 +39,7 @@ from onnxruntime.training.ortmodule import (
)
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
DEFAULT_OPSET = 14
DEFAULT_OPSET = 15
# PyTorch model definitions for tests
@@ -5079,7 +5079,7 @@ def test_sigmoid_grad_opset13():
old_opst_cst = ortmodule.ONNX_OPSET_VERSION
old_opset = os.getenv("ORTMODULE_ONNX_OPSET_VERSION", None)
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = "13"
assert ortmodule.ONNX_OPSET_VERSION == 14
assert ortmodule.ONNX_OPSET_VERSION == 15
ort_model = ORTModule(copy.deepcopy(pt_model))
@@ -5109,7 +5109,7 @@ def test_sigmoid_grad_opset13():
ortmodule.ONNX_OPSET_VERSION = old_opst_cst
@pytest.mark.parametrize("opset_version", [12, 13, 14])
@pytest.mark.parametrize("opset_version", [12, 13, 14, 15])
def test_opset_version_change(opset_version):
device = "cuda"

View file

@@ -1,35 +1,27 @@
from functools import partial
import inspect
import math
from distutils.version import StrictVersion
from numpy.testing import assert_allclose
import onnx
import os
import pytest
import tempfile
from distutils.version import StrictVersion
from functools import partial
import _test_commons
import _test_helpers
import onnx
import pytest
import torch
import torch.nn.functional as F
from numpy.testing import assert_allclose
from onnxruntime import set_seed
from onnxruntime.capi.ort_trainer import (
IODescription as Legacy_IODescription,
ModelDescription as Legacy_ModelDescription,
LossScaler as Legacy_LossScaler,
ORTTrainer as Legacy_ORTTrainer,
)
from onnxruntime.training import (
_utils,
amp,
checkpoint,
optim,
orttrainer,
TrainStepInfo,
model_desc_validation as md_val,
orttrainer_options as orttrainer_options,
)
import _test_commons, _test_helpers
from onnxruntime import SessionOptions
from onnxruntime.training import PropagateCastOpsStrategy
from onnxruntime import SessionOptions, set_seed
from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription
from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler
from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription
from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer
from onnxruntime.training import PropagateCastOpsStrategy, TrainStepInfo, _utils, amp, checkpoint
from onnxruntime.training import model_desc_validation as md_val
from onnxruntime.training import optim, orttrainer
from onnxruntime.training import orttrainer_options as orttrainer_options
###############################################################################
# Testing starts here #########################################################
@@ -708,7 +700,7 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device)
assert trainer._onnx_model.graph.output[i].name == output_name
for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim):
if opset != 14:
if opset is None or opset <= 12:
assert output_dim[dim_idx] == dim.dim_value
assert output_type == _utils.dtype_onnx_to_torch(
trainer._onnx_model.graph.output[i].type.tensor_type.elem_type

View file

@@ -1,19 +1,12 @@
import torch
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription
from orttraining_test_data_loader import create_ort_test_dataloader, BatchArgsOption, split_batch
from orttraining_test_bert_postprocess import postprocess_model
from orttraining_test_data_loader import BatchArgsOption, create_ort_test_dataloader, split_batch
from onnxruntime.training import (
_utils,
amp,
optim,
orttrainer,
TrainStepInfo,
model_desc_validation as md_val,
orttrainer_options as orttrainer_options,
)
from onnxruntime.capi.ort_trainer import IODescription, ORTTrainer
from onnxruntime.training import TrainStepInfo, _utils, amp
from onnxruntime.training import model_desc_validation as md_val
from onnxruntime.training import optim, orttrainer
from onnxruntime.training import orttrainer_options as orttrainer_options
from onnxruntime.training.optim import _LRScheduler

View file

@@ -13,7 +13,7 @@ stages:
parameters:
build_py_parameters: --enable_training --update --build
torch_version: '1.11.0'
opset_version: '14'
opset_version: '15'
cuda_version: '11.3'
gcc_version: 10
cmake_cuda_architectures: 37;50;52;60;61;70;75;80;86

View file

@@ -13,7 +13,7 @@ stages:
parameters:
build_py_parameters: --enable_training --update --build
torch_version: '1.11.0'
opset_version: '14'
opset_version: '15'
cuda_version: '11.5'
gcc_version: 10
cmake_cuda_architectures: 37;50;52;60;61;70;75;80;86;87

View file

@@ -75,7 +75,7 @@ RUN export LIBXCRYPT_VERSION=4.4.28 && \
export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \
export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \
manylinux-entrypoint /build_scripts/install-libxcrypt.sh
COPY scripts/install-protobuf.sh /build_scripts/
RUN export PROTOBUF_VERSION=3.17.3 && \
export PROTOBUF_ROOT=protobuf-all-${PROTOBUF_VERSION} && \
@@ -163,7 +163,7 @@ CMD ["/bin/bash"]
ARG PYTHON_VERSION=3.7
ARG TORCH_VERSION=1.11.0
ARG OPSET_VERSION=14
ARG OPSET_VERSION=15
ARG INSTALL_DEPS_EXTRA_ARGS
#Add our own dependencies

View file

@@ -163,7 +163,7 @@ CMD ["/bin/bash"]
ARG PYTHON_VERSION=3.7
ARG TORCH_VERSION=1.11.0
ARG OPSET_VERSION=14
ARG OPSET_VERSION=15
ARG INSTALL_DEPS_EXTRA_ARGS
#Add our own dependencies

View file

@@ -166,7 +166,7 @@ CMD ["/bin/bash"]
#Build manylinux2014 docker image end
ARG PYTHON_VERSION=3.9
ARG TORCH_VERSION=1.11.0
ARG OPSET_VERSION=14
ARG OPSET_VERSION=15
ARG INSTALL_DEPS_EXTRA_ARGS
#Add our own dependencies

View file

@@ -166,7 +166,7 @@ CMD ["/bin/bash"]
#Build manylinux2014 docker image end
ARG PYTHON_VERSION=3.9
ARG TORCH_VERSION=1.11.0
ARG OPSET_VERSION=14
ARG OPSET_VERSION=15
ARG INSTALL_DEPS_EXTRA_ARGS
#Add our own dependencies

View file

@@ -7,7 +7,7 @@ ORTMODULE_BUILD=false
TARGET_ROCM=false
CU_VER="11.1"
ROCM_VER="5.1.1"
TORCH_VERSION='1.10.0'
TORCH_VERSION='1.11.0'
USE_CONDA=false
while getopts p:h:d:v:tmurc parameter_Option

View file

@@ -41,4 +41,4 @@ RUN pip install \
pytorch_lightning==1.6.0
RUN pip install torch-ort --no-dependencies
ENV ORTMODULE_ONNX_OPSET_VERSION=14
ENV ORTMODULE_ONNX_OPSET_VERSION=15