mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Update ORTModule Default Opset Version to 15 (#12419)
* update ortmodule opset to 15 * update torch version * fix ut * fix ut * rollback * rollback for orttrainer
This commit is contained in:
parent
a7d6290774
commit
e85e31ee80
17 changed files with 78 additions and 84 deletions
|
|
@ -1,23 +1,22 @@
|
|||
import io
|
||||
import os
|
||||
import warnings
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
from onnx import numpy_helper
|
||||
from onnx import helper
|
||||
import torch
|
||||
import torch.nn
|
||||
import torch.onnx
|
||||
from onnx import helper, numpy_helper
|
||||
|
||||
import onnxruntime as ort
|
||||
from ..training import postprocess
|
||||
from distutils.version import LooseVersion
|
||||
import warnings
|
||||
|
||||
from .checkpointing_utils import list_checkpoint_files, get_checkpoint_name, CombineZeroCheckpoint
|
||||
import onnxruntime.capi.pt_patch
|
||||
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
|
||||
from ..training import postprocess
|
||||
from .checkpointing_utils import CombineZeroCheckpoint, get_checkpoint_name, list_checkpoint_files
|
||||
|
||||
DEFAULT_OPSET_VERSION = 14
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@
|
|||
# --------------------------------------------------------------------------
|
||||
|
||||
import importlib.util
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
from functools import wraps
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from onnx import TensorProto
|
||||
|
||||
from functools import wraps
|
||||
from packaging.version import Version
|
||||
|
||||
|
||||
def get_device_index(device):
|
||||
|
|
@ -94,7 +95,11 @@ def dtype_torch_to_numpy(torch_dtype):
|
|||
return np.int8
|
||||
elif torch_dtype == torch.uint8:
|
||||
return np.uint8
|
||||
elif torch_dtype == torch.complex32 or torch_dtype == torch.complex64:
|
||||
elif torch_dtype == torch.complex64 or (
|
||||
# complex32 is missing in torch-1.11.
|
||||
(Version(torch.__version__) < Version("1.11.0") or Version(torch.__version__) >= Version("1.12.0"))
|
||||
and torch_dtype == torch.complex32
|
||||
):
|
||||
# NOTE: numpy doesn't support complex32
|
||||
return np.complex64
|
||||
elif torch_dtype == torch.complex128 or torch_dtype == torch.cdouble:
|
||||
|
|
|
|||
|
|
@ -6,12 +6,14 @@
|
|||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from onnxruntime import set_seed
|
||||
from onnxruntime.capi import build_and_package_info as ort_info
|
||||
from ._fallback import _FallbackPolicy, ORTModuleFallbackException, ORTModuleInitException, wrap_exception
|
||||
|
||||
from ._fallback import ORTModuleFallbackException, ORTModuleInitException, _FallbackPolicy, wrap_exception
|
||||
from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed
|
||||
|
||||
|
||||
|
|
@ -33,7 +35,7 @@ def _defined_from_envvar(name, default_value, warn=True):
|
|||
# NOTE: To *change* values at runtime, import onnxruntime.training.ortmodule and
|
||||
# assign them new values. Importing them directly does not propagate changes.
|
||||
################################################################################
|
||||
ONNX_OPSET_VERSION = 14
|
||||
ONNX_OPSET_VERSION = 15
|
||||
MINIMUM_RUNTIME_PYTORCH_VERSION_STR = "1.8.1"
|
||||
ORTMODULE_TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__), "torch_cpp_extensions")
|
||||
_FALLBACK_INIT_EXCEPTION = None
|
||||
|
|
@ -118,6 +120,7 @@ def _are_deterministic_algorithms_enabled():
|
|||
return ORTMODULE_IS_DETERMINISTIC
|
||||
|
||||
|
||||
from .debug_options import DebugOptions, LogLevel # noqa: E402
|
||||
|
||||
# ORTModule must be loaded only after all validation passes
|
||||
from .ortmodule import ORTModule # noqa: E402
|
||||
from .debug_options import DebugOptions, LogLevel # noqa: E402
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import cerberus
|
||||
import torch
|
||||
|
||||
from .optim import lr_scheduler
|
||||
from .amp import loss_scaler
|
||||
from . import PropagateCastOpsStrategy
|
||||
import onnxruntime as ort
|
||||
|
||||
from . import PropagateCastOpsStrategy
|
||||
from .amp import loss_scaler
|
||||
from .optim import lr_scheduler
|
||||
|
||||
|
||||
class ORTTrainerOptions(object):
|
||||
r"""Settings used by ONNX Runtime training backend
|
||||
|
|
@ -291,8 +292,8 @@ class ORTTrainerOptions(object):
|
|||
'onnx_opset_version': {
|
||||
'type': 'integer',
|
||||
'min' : 12,
|
||||
'max' : 13,
|
||||
'default': 12
|
||||
'max' :14,
|
||||
'default': 14
|
||||
},
|
||||
'enable_onnx_contrib_ops' : {
|
||||
'type' : 'boolean',
|
||||
|
|
|
|||
|
|
@ -1,22 +1,21 @@
|
|||
import unittest
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
import copy
|
||||
from numpy.testing import assert_allclose, assert_array_equal
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import onnx
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from orttraining_test_utils import map_optimizer_attributes
|
||||
from orttraining_test_transformers import BertModelTest, BertForPreTraining
|
||||
from orttraining_test_data_loader import create_ort_test_dataloader
|
||||
from numpy.testing import assert_allclose, assert_array_equal
|
||||
from orttraining_test_bert_postprocess import postprocess_model
|
||||
import onnxruntime
|
||||
from orttraining_test_data_loader import create_ort_test_dataloader
|
||||
from orttraining_test_transformers import BertForPreTraining, BertModelTest
|
||||
from orttraining_test_utils import map_optimizer_attributes
|
||||
|
||||
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler, generate_sample
|
||||
import onnxruntime
|
||||
from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer, generate_sample
|
||||
|
||||
torch.manual_seed(1)
|
||||
onnxruntime.set_seed(1)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
"""
|
||||
@brief test log(time=3s)
|
||||
"""
|
||||
import unittest
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
|
||||
|
||||
|
|
@ -77,7 +79,7 @@ class TestOnnxOpsOrtModule(unittest.TestCase):
|
|||
for onnx_model in [onnx_graph_inf, onnx_graph_train]:
|
||||
for oimp in onnx_model.opset_import:
|
||||
if oimp.domain == "":
|
||||
self.assertEqual(oimp.version, 14)
|
||||
self.assertEqual(oimp.version, 15)
|
||||
if op_grad_type is not None:
|
||||
if isinstance(op_grad_type, tuple):
|
||||
text = str(onnx_graph_train)
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ from onnxruntime.training.ortmodule import (
|
|||
)
|
||||
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient
|
||||
|
||||
DEFAULT_OPSET = 14
|
||||
DEFAULT_OPSET = 15
|
||||
|
||||
|
||||
# PyTorch model definitions for tests
|
||||
|
|
@ -5079,7 +5079,7 @@ def test_sigmoid_grad_opset13():
|
|||
old_opst_cst = ortmodule.ONNX_OPSET_VERSION
|
||||
old_opset = os.getenv("ORTMODULE_ONNX_OPSET_VERSION", None)
|
||||
os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = "13"
|
||||
assert ortmodule.ONNX_OPSET_VERSION == 14
|
||||
assert ortmodule.ONNX_OPSET_VERSION == 15
|
||||
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
|
|
@ -5109,7 +5109,7 @@ def test_sigmoid_grad_opset13():
|
|||
ortmodule.ONNX_OPSET_VERSION = old_opst_cst
|
||||
|
||||
|
||||
@pytest.mark.parametrize("opset_version", [12, 13, 14])
|
||||
@pytest.mark.parametrize("opset_version", [12, 13, 14, 15])
|
||||
def test_opset_version_change(opset_version):
|
||||
device = "cuda"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,35 +1,27 @@
|
|||
from functools import partial
|
||||
import inspect
|
||||
import math
|
||||
from distutils.version import StrictVersion
|
||||
from numpy.testing import assert_allclose
|
||||
import onnx
|
||||
import os
|
||||
import pytest
|
||||
import tempfile
|
||||
from distutils.version import StrictVersion
|
||||
from functools import partial
|
||||
|
||||
import _test_commons
|
||||
import _test_helpers
|
||||
import onnx
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from onnxruntime import set_seed
|
||||
from onnxruntime.capi.ort_trainer import (
|
||||
IODescription as Legacy_IODescription,
|
||||
ModelDescription as Legacy_ModelDescription,
|
||||
LossScaler as Legacy_LossScaler,
|
||||
ORTTrainer as Legacy_ORTTrainer,
|
||||
)
|
||||
from onnxruntime.training import (
|
||||
_utils,
|
||||
amp,
|
||||
checkpoint,
|
||||
optim,
|
||||
orttrainer,
|
||||
TrainStepInfo,
|
||||
model_desc_validation as md_val,
|
||||
orttrainer_options as orttrainer_options,
|
||||
)
|
||||
import _test_commons, _test_helpers
|
||||
from onnxruntime import SessionOptions
|
||||
from onnxruntime.training import PropagateCastOpsStrategy
|
||||
from onnxruntime import SessionOptions, set_seed
|
||||
from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription
|
||||
from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler
|
||||
from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription
|
||||
from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer
|
||||
from onnxruntime.training import PropagateCastOpsStrategy, TrainStepInfo, _utils, amp, checkpoint
|
||||
from onnxruntime.training import model_desc_validation as md_val
|
||||
from onnxruntime.training import optim, orttrainer
|
||||
from onnxruntime.training import orttrainer_options as orttrainer_options
|
||||
|
||||
###############################################################################
|
||||
# Testing starts here #########################################################
|
||||
|
|
@ -708,7 +700,7 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device)
|
|||
|
||||
assert trainer._onnx_model.graph.output[i].name == output_name
|
||||
for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim):
|
||||
if opset != 14:
|
||||
if opset is None or opset <= 12:
|
||||
assert output_dim[dim_idx] == dim.dim_value
|
||||
assert output_type == _utils.dtype_onnx_to_torch(
|
||||
trainer._onnx_model.graph.output[i].type.tensor_type.elem_type
|
||||
|
|
|
|||
|
|
@ -1,19 +1,12 @@
|
|||
import torch
|
||||
|
||||
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription
|
||||
|
||||
from orttraining_test_data_loader import create_ort_test_dataloader, BatchArgsOption, split_batch
|
||||
from orttraining_test_bert_postprocess import postprocess_model
|
||||
from orttraining_test_data_loader import BatchArgsOption, create_ort_test_dataloader, split_batch
|
||||
|
||||
from onnxruntime.training import (
|
||||
_utils,
|
||||
amp,
|
||||
optim,
|
||||
orttrainer,
|
||||
TrainStepInfo,
|
||||
model_desc_validation as md_val,
|
||||
orttrainer_options as orttrainer_options,
|
||||
)
|
||||
from onnxruntime.capi.ort_trainer import IODescription, ORTTrainer
|
||||
from onnxruntime.training import TrainStepInfo, _utils, amp
|
||||
from onnxruntime.training import model_desc_validation as md_val
|
||||
from onnxruntime.training import optim, orttrainer
|
||||
from onnxruntime.training import orttrainer_options as orttrainer_options
|
||||
from onnxruntime.training.optim import _LRScheduler
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ stages:
|
|||
parameters:
|
||||
build_py_parameters: --enable_training --update --build
|
||||
torch_version: '1.11.0'
|
||||
opset_version: '14'
|
||||
opset_version: '15'
|
||||
cuda_version: '11.3'
|
||||
gcc_version: 10
|
||||
cmake_cuda_architectures: 37;50;52;60;61;70;75;80;86
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ stages:
|
|||
parameters:
|
||||
build_py_parameters: --enable_training --update --build
|
||||
torch_version: '1.11.0'
|
||||
opset_version: '14'
|
||||
opset_version: '15'
|
||||
cuda_version: '11.5'
|
||||
gcc_version: 10
|
||||
cmake_cuda_architectures: 37;50;52;60;61;70;75;80;86;87
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ RUN export LIBXCRYPT_VERSION=4.4.28 && \
|
|||
export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \
|
||||
export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \
|
||||
manylinux-entrypoint /build_scripts/install-libxcrypt.sh
|
||||
|
||||
|
||||
COPY scripts/install-protobuf.sh /build_scripts/
|
||||
RUN export PROTOBUF_VERSION=3.17.3 && \
|
||||
export PROTOBUF_ROOT=protobuf-all-${PROTOBUF_VERSION} && \
|
||||
|
|
@ -163,7 +163,7 @@ CMD ["/bin/bash"]
|
|||
|
||||
ARG PYTHON_VERSION=3.7
|
||||
ARG TORCH_VERSION=1.11.0
|
||||
ARG OPSET_VERSION=14
|
||||
ARG OPSET_VERSION=15
|
||||
ARG INSTALL_DEPS_EXTRA_ARGS
|
||||
|
||||
#Add our own dependencies
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ CMD ["/bin/bash"]
|
|||
|
||||
ARG PYTHON_VERSION=3.7
|
||||
ARG TORCH_VERSION=1.11.0
|
||||
ARG OPSET_VERSION=14
|
||||
ARG OPSET_VERSION=15
|
||||
ARG INSTALL_DEPS_EXTRA_ARGS
|
||||
|
||||
#Add our own dependencies
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ CMD ["/bin/bash"]
|
|||
#Build manylinux2014 docker image end
|
||||
ARG PYTHON_VERSION=3.9
|
||||
ARG TORCH_VERSION=1.11.0
|
||||
ARG OPSET_VERSION=14
|
||||
ARG OPSET_VERSION=15
|
||||
ARG INSTALL_DEPS_EXTRA_ARGS
|
||||
|
||||
#Add our own dependencies
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ CMD ["/bin/bash"]
|
|||
#Build manylinux2014 docker image end
|
||||
ARG PYTHON_VERSION=3.9
|
||||
ARG TORCH_VERSION=1.11.0
|
||||
ARG OPSET_VERSION=14
|
||||
ARG OPSET_VERSION=15
|
||||
ARG INSTALL_DEPS_EXTRA_ARGS
|
||||
|
||||
#Add our own dependencies
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ ORTMODULE_BUILD=false
|
|||
TARGET_ROCM=false
|
||||
CU_VER="11.1"
|
||||
ROCM_VER="5.1.1"
|
||||
TORCH_VERSION='1.10.0'
|
||||
TORCH_VERSION='1.11.0'
|
||||
USE_CONDA=false
|
||||
|
||||
while getopts p:h:d:v:tmurc parameter_Option
|
||||
|
|
|
|||
|
|
@ -41,4 +41,4 @@ RUN pip install \
|
|||
pytorch_lightning==1.6.0
|
||||
|
||||
RUN pip install torch-ort --no-dependencies
|
||||
ENV ORTMODULE_ONNX_OPSET_VERSION=14
|
||||
ENV ORTMODULE_ONNX_OPSET_VERSION=15
|
||||
|
|
|
|||
Loading…
Reference in a new issue