diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py
index 52882c6052..de8a9acc0f 100644
--- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py
+++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py
@@ -193,3 +193,18 @@ class ORTModule(torch.nn.Module):
         """Raises a NotImplementedError exception since ORTModule does not support adding modules to it"""
 
         raise NotImplementedError("ORTModule does not support adding modules to it.")
+
+    @property
+    def module(self):
+        """The original `torch.nn.Module` that this module wraps.
+
+        This property provides access to methods and properties on the original module.
+        """
+
+        # HuggingFace Trainer `save_model` method checks to see if the input model is a HuggingFace PreTrainedModel
+        # or if the model has an attribute called `module` which references a HuggingFace PreTrainedModel to save
+        # the entire context of the model so that it can be loaded using HuggingFace `from_pretrained` method.
+        # This `module` property enables HuggingFace Trainer to retrieve the underlying PreTrainedModel inside ORTModule
+        # to save and load a complete checkpoint
+
+        return self._module_metadata.original_module
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 28eb17987a..0abe778d90 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6,7 +6,7 @@ import math
 import random
 import copy
 import torch
-from transformers import AutoConfig, BertForSequenceClassification
+from transformers import AutoConfig, BertForSequenceClassification, Trainer
 from transformers.modeling_outputs import SequenceClassifierOutput
 import pytest
 from time import sleep
@@ -15,6 +15,7 @@ from unittest.mock import patch
 from collections import OrderedDict
 from collections import namedtuple
 from inspect import signature
+import tempfile
 
 from onnxruntime.training.ortmodule import ORTModule, _utils, _io
 import _test_helpers
@@ -2777,3 +2778,43 @@ def test_load_state_dict_for_wrapped_ortmodule():
     for param_name, param_value in state_dict1.items():
         assert param_name in state_dict2
         assert torch.equal(param_value, state_dict2[param_name])
+
+def test_hf_save_pretrained():
+    """Check that HuggingFace `Trainer.save_model` can checkpoint an ORTModule-wrapped model."""
+    device = 'cuda'
+
+    model1 = _get_bert_for_sequence_classification_model(device)
+    model1 = ORTModule(model1)
+
+    # Perturb the first state dict entry before saving. Tensor `+=` is an
+    # in-place add, so the tensor held by the state dict itself is modified.
+    state_dict = model1.state_dict()
+    first_tensor = next(iter(state_dict.values()))
+    first_tensor += 100
+    model1.load_state_dict(state_dict)
+
+    trainer = Trainer(model=model1)
+
+    # Assert that ORTModule has an attribute called module. This attribute is used
+    # for trainer.save_model to reference the underlying HuggingFace PreTrainedModel
+    assert hasattr(model1, "module")
+
+    # Create a temporary directory for the checkpoint from save_pretrained
+    with tempfile.TemporaryDirectory() as temporary_dir:
+        trainer.save_model(temporary_dir)
+
+        # Create a new model and compare all parameter values for equality
+        # to check if from_pretrained worked.
+        config = AutoConfig.from_pretrained(temporary_dir)
+        model2 = BertForSequenceClassification.from_pretrained(
+            temporary_dir, config=config,
+        ).to(device)
+        model2 = ORTModule(model2)
+
+        # zip() stops at the shorter input, so compare the parameter counts first;
+        # otherwise a reloaded model with missing parameters would pass vacuously.
+        params1 = list(model1.parameters())
+        params2 = list(model2.parameters())
+        assert len(params1) == len(params2)
+        for p1, p2 in zip(params1, params2):
+            assert torch.equal(p1, p2)
diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml
index 4f8ab9130b..d811f05e37 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml
@@ -23,7 +23,7 @@ jobs:
     DockerImageTag: 'onnxruntime_orttraining_ortmodule_tests_image'
     BuildConfig: $(buildConfig)
     ArtifactName: 'drop-linux'
-    TimeoutInMinutes: 120
+    TimeoutInMinutes: 140
     # Enable unreleased onnx opsets in CI builds
     # This facilitates testing the implementation for the new opsets
     AllowReleasedOpsetOnly: '0'
diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt
index 5426211953..6614e4c7ba 100644
--- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt
+++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt
@@ -1,7 +1,7 @@
 pandas
 sklearn
 numpy==1.19.5
-transformers==v4.3.2
+transformers==v4.4.2
 tensorboard>=2.2.0,<2.5.0
 h5py
 wget