mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Add module attribute to ORTModule to support HuggingFace Trainer save_model (#8088)
This commit is contained in:
parent
08eeb8763d
commit
7701c8703e
4 changed files with 49 additions and 3 deletions
|
|
@ -193,3 +193,18 @@ class ORTModule(torch.nn.Module):
|
|||
"""Raises a NotImplementedError exception since ORTModule does not support adding modules to it"""
|
||||
|
||||
raise NotImplementedError("ORTModule does not support adding modules to it.")
|
||||
|
||||
@property
|
||||
def module(self):
|
||||
"""The original `torch.nn.Module` that this module wraps.
|
||||
|
||||
This property provides access to methods and properties on the original module.
|
||||
"""
|
||||
|
||||
# HuggingFace Trainer `save_model` method checks to see if the input model is a HuggingFace PreTrainedModel
|
||||
# or if the model has an attribute called `module` which references a HuggingFace PreTrainedModel to save
|
||||
# the entire context of the model so that it can be loaded using HuggingFace `from_pretrained` method.
|
||||
# This `module` property enables HuggingFace Trainer to retrieve the underlying PreTrainedModel inside ORTModule
|
||||
# to save and load a complete checkpoint
|
||||
|
||||
return self._module_metadata.original_module
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import math
|
|||
import random
|
||||
import copy
|
||||
import torch
|
||||
from transformers import AutoConfig, BertForSequenceClassification
|
||||
from transformers import AutoConfig, BertForSequenceClassification, Trainer
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
import pytest
|
||||
from time import sleep
|
||||
|
|
@ -15,6 +15,7 @@ from unittest.mock import patch
|
|||
from collections import OrderedDict
|
||||
from collections import namedtuple
|
||||
from inspect import signature
|
||||
import tempfile
|
||||
|
||||
from onnxruntime.training.ortmodule import ORTModule, _utils, _io
|
||||
import _test_helpers
|
||||
|
|
@ -2777,3 +2778,33 @@ def test_load_state_dict_for_wrapped_ortmodule():
|
|||
for param_name, param_value in state_dict1.items():
|
||||
assert param_name in state_dict2
|
||||
assert torch.equal(param_value, state_dict2[param_name])
|
||||
|
||||
def test_hf_save_pretrained():
|
||||
device = 'cuda'
|
||||
|
||||
model1 = _get_bert_for_sequence_classification_model(device)
|
||||
model1 = ORTModule(model1)
|
||||
state_dict = model1.state_dict()
|
||||
list(next(iter(state_dict.items())))[1] += 100
|
||||
model1.load_state_dict(state_dict)
|
||||
|
||||
trainer = Trainer(model=model1)
|
||||
|
||||
# Assert that ORTModule has an attribute called module. This attribute is used
|
||||
# for trainer.save_model to reference the underlying HuggingFace PreTrainedModel
|
||||
assert hasattr(model1, "module")
|
||||
|
||||
# Create a temporary directory for the checkpoint from save_pretrained
|
||||
with tempfile.TemporaryDirectory() as temporary_dir:
|
||||
trainer.save_model(temporary_dir)
|
||||
|
||||
# Create a new model and compare all state dictionary values for equality
|
||||
# to check if from_pretrained worked.
|
||||
config = AutoConfig.from_pretrained(temporary_dir)
|
||||
model2 = BertForSequenceClassification.from_pretrained(
|
||||
temporary_dir, config=config,
|
||||
).to(device)
|
||||
model2 = ORTModule(model2)
|
||||
|
||||
for p1, p2 in zip(model1.parameters(), model2.parameters()):
|
||||
assert p1.data.ne(p2.data).sum() == 0
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ jobs:
|
|||
DockerImageTag: 'onnxruntime_orttraining_ortmodule_tests_image'
|
||||
BuildConfig: $(buildConfig)
|
||||
ArtifactName: 'drop-linux'
|
||||
TimeoutInMinutes: 120
|
||||
TimeoutInMinutes: 140
|
||||
# Enable unreleased onnx opsets in CI builds
|
||||
# This facilitates testing the implementation for the new opsets
|
||||
AllowReleasedOpsetOnly: '0'
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
pandas
|
||||
sklearn
|
||||
numpy==1.19.5
|
||||
transformers==v4.3.2
|
||||
transformers==v4.4.2
|
||||
tensorboard>=2.2.0,<2.5.0
|
||||
h5py
|
||||
wget
|
||||
|
|
|
|||
Loading…
Reference in a new issue