Add module attribute to ORTModule to support HuggingFace Trainer save_model (#8088)

This commit is contained in:
baijumeswani 2021-06-18 13:13:45 -07:00 committed by GitHub
parent 08eeb8763d
commit 7701c8703e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 3 deletions

View file

@ -193,3 +193,18 @@ class ORTModule(torch.nn.Module):
"""Raises a NotImplementedError exception since ORTModule does not support adding modules to it"""
raise NotImplementedError("ORTModule does not support adding modules to it.")
@property
def module(self):
    """Return the original `torch.nn.Module` wrapped by this ORTModule.

    Callers can use this property to reach methods and properties that are
    defined on the original module.
    """

    # Why this property exists: HuggingFace Trainer's `save_model` looks for a
    # HuggingFace PreTrainedModel either in the model it is given or behind an
    # attribute named `module` on that model. When it finds one, it saves the
    # entire context of the model so that the checkpoint can later be loaded
    # with HuggingFace's `from_pretrained`. Exposing the wrapped model under
    # `module` lets the Trainer find the PreTrainedModel inside ORTModule and
    # therefore save and load a complete checkpoint.
    metadata = self._module_metadata
    return metadata.original_module

View file

@ -6,7 +6,7 @@ import math
import random
import copy
import torch
from transformers import AutoConfig, BertForSequenceClassification
from transformers import AutoConfig, BertForSequenceClassification, Trainer
from transformers.modeling_outputs import SequenceClassifierOutput
import pytest
from time import sleep
@ -15,6 +15,7 @@ from unittest.mock import patch
from collections import OrderedDict
from collections import namedtuple
from inspect import signature
import tempfile
from onnxruntime.training.ortmodule import ORTModule, _utils, _io
import _test_helpers
@ -2777,3 +2778,33 @@ def test_load_state_dict_for_wrapped_ortmodule():
for param_name, param_value in state_dict1.items():
assert param_name in state_dict2
assert torch.equal(param_value, state_dict2[param_name])
def test_hf_save_pretrained():
    """Check that HuggingFace `Trainer.save_model` round-trips an ORTModule-wrapped model.

    The test wraps a BERT sequence classification model in ORTModule, perturbs one
    state dictionary entry so the weights differ from the freshly constructed model,
    saves the model through the HuggingFace Trainer, reloads it with
    `from_pretrained`, and asserts that every parameter matches.
    """
    device = 'cuda'

    model1 = _get_bert_for_sequence_classification_model(device)
    model1 = ORTModule(model1)

    # Perturb the first state dictionary entry in place and load it back, so the
    # model being saved no longer matches its freshly constructed values.
    state_dict = model1.state_dict()
    first_value = next(iter(state_dict.values()))
    first_value += 100
    model1.load_state_dict(state_dict)

    trainer = Trainer(model=model1)

    # Assert that ORTModule has an attribute called module. This attribute is used
    # for trainer.save_model to reference the underlying HuggingFace PreTrainedModel
    assert hasattr(model1, "module")

    # Create a temporary directory for the checkpoint from save_pretrained
    with tempfile.TemporaryDirectory() as temporary_dir:
        trainer.save_model(temporary_dir)

        # Create a new model and compare all parameter values for equality
        # to check if from_pretrained worked.
        config = AutoConfig.from_pretrained(temporary_dir)
        model2 = BertForSequenceClassification.from_pretrained(
            temporary_dir, config=config,
        ).to(device)
        model2 = ORTModule(model2)

        saved_params = list(model1.parameters())
        loaded_params = list(model2.parameters())

        # zip() stops at the shorter input, so compare the counts first; otherwise a
        # reloaded model with missing or extra parameters would pass unnoticed.
        assert len(saved_params) == len(loaded_params)
        for p1, p2 in zip(saved_params, loaded_params):
            assert torch.equal(p1.data, p2.data)

View file

@ -23,7 +23,7 @@ jobs:
DockerImageTag: 'onnxruntime_orttraining_ortmodule_tests_image'
BuildConfig: $(buildConfig)
ArtifactName: 'drop-linux'
TimeoutInMinutes: 120
TimeoutInMinutes: 140
# Enable unreleased onnx opsets in CI builds
# This facilitates testing the implementation for the new opsets
AllowReleasedOpsetOnly: '0'

View file

@ -1,7 +1,7 @@
pandas
sklearn
numpy==1.19.5
transformers==v4.3.2
transformers==v4.4.2
tensorboard>=2.2.0,<2.5.0
h5py
wget