mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Fix and enable few ORTModule Unit Tests (#19847)
### Fix and enable few ORTModule Unit Tests
Fix 'test_bert_inputs_with_dynamic_shape' and
'test_bert_result_with_layerwise_recompute' generate Nan loss in ORT
run.
The root cause is, the logic to generatic attention mask test data is
not correct, only 0 or 1 is allowed in the dataset, but we see lots of
other numbers. ( The reason we don't have this using old version of
transformers for example v4.4.2 or 4.16.2 is because they don't contains
such
d3cb28886a,
which increase the scaling to a bigger number, causing a overflow to
inf)
Another improvement during the investigation using convergence tools:
Don't dump the activations during model export phase, otherwise, the
dumped data might contains some PyTorch run's result making us confused
during comparing with stock PyTorch run results.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
0c078dfc8b
commit
3e954da3e6
3 changed files with 58 additions and 49 deletions
|
|
@ -89,7 +89,7 @@ The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output
|
|||
dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example:
|
||||
|
||||
```diff
|
||||
+ from onnxruntime.training.utils import inspect_activation
|
||||
+ from onnxruntime.training.utils.hooks import inspect_activation
|
||||
class BloomForCausalLM(BloomPreTrainedModel):
|
||||
def __init__(self, config: BloomConfig):
|
||||
...
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import onnx
|
|||
import torch
|
||||
|
||||
from ._subscriber_base import RuntimeStates, SubscriberBase
|
||||
from ._subscriber_manager import ORT_NO_INCREASE_GLOBAL_STEP
|
||||
|
||||
|
||||
class _InspectActivation(torch.autograd.Function):
|
||||
|
|
@ -176,21 +177,23 @@ class StatisticsSubscriber(SubscriberBase):
|
|||
display_name = name + " forward run" if is_forward is True else name + " backward run"
|
||||
output_file_name = name + "_forward" if is_forward is True else name + "_backward"
|
||||
|
||||
if tensor is None or not isinstance(tensor, torch.Tensor):
|
||||
print(f"{display_name} not a torch tensor, value: {tensor}")
|
||||
return
|
||||
# Skip dump during model pre-export output schema preparison run and export run.
|
||||
if ORT_NO_INCREASE_GLOBAL_STEP[0] is False:
|
||||
if tensor is None or not isinstance(tensor, torch.Tensor):
|
||||
print(f"{display_name} not a torch tensor, value: {tensor}")
|
||||
return
|
||||
|
||||
step_path = Path(step_folder)
|
||||
if not step_path.exists():
|
||||
step_path.mkdir(parents=True, exist_ok=False)
|
||||
order_file_path = step_path / "order.txt"
|
||||
tensor_file_path = step_path / output_file_name
|
||||
step_path = Path(step_folder)
|
||||
if not step_path.exists():
|
||||
step_path.mkdir(parents=True, exist_ok=False)
|
||||
order_file_path = step_path / "order.txt"
|
||||
tensor_file_path = step_path / output_file_name
|
||||
|
||||
with order_file_path.open(mode="a", encoding="utf-8") as f:
|
||||
f.write(f"{output_file_name}\n")
|
||||
with order_file_path.open(mode="a", encoding="utf-8") as f:
|
||||
f.write(f"{output_file_name}\n")
|
||||
|
||||
with tensor_file_path.open(mode="w", encoding="utf-8") as f:
|
||||
_summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
|
||||
with tensor_file_path.open(mode="w", encoding="utf-8") as f:
|
||||
_summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
|
||||
|
||||
|
||||
def _summarize_tensor(
|
||||
|
|
|
|||
|
|
@ -417,24 +417,38 @@ def _get_bert_for_sequence_classification_model(
|
|||
return model
|
||||
|
||||
|
||||
def _get_bert_for_sequence_classification_sample_data(device):
|
||||
"""Returns sample data to be used with BertForSequenceClassification model"""
|
||||
def _generate_attention_mask_for_encoder_following_hf(batch_size, seq_length, device, past_key_values_length=0):
|
||||
"""Generate attention mask for encoder following the implementation in HuggingFace.
|
||||
|
||||
input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
|
||||
input_mask = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
|
||||
labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)
|
||||
Be noted: past_key_values_length is 0 for training.
|
||||
|
||||
return input_ids, input_mask, labels
|
||||
Generate mask using this
|
||||
https://github.com/huggingface/transformers/blame/4f27ee936a861f56f32ea6db138978b274008006/src/transformers/models/bert/modeling_bert.py#L974C81-L974C81
|
||||
|
||||
"""
|
||||
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
return attention_mask
|
||||
|
||||
|
||||
def _get_bert_for_sequence_classification_sample_data_with_random_shapes(device):
|
||||
"""Returns sample data with random shape to be used with BertForSequenceClassification model"""
|
||||
|
||||
x = random.randint(1, 100)
|
||||
y = random.randint(1, 100)
|
||||
input_ids = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
|
||||
input_mask = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
|
||||
labels = torch.randint(0, 1, (x,), dtype=torch.long, device=device)
|
||||
bsz = random.randint(1, 100)
|
||||
seq_length = random.randint(1, 100)
|
||||
input_ids = torch.randint(0, 100, (bsz, seq_length), dtype=torch.long, device=device)
|
||||
input_mask = _generate_attention_mask_for_encoder_following_hf(bsz, seq_length, device)
|
||||
labels = torch.randint(0, 1, (bsz,), dtype=torch.long, device=device)
|
||||
|
||||
return input_ids, input_mask, labels
|
||||
|
||||
|
||||
def _get_bert_for_sequence_classification_sample_data(device):
|
||||
"""Returns sample data to be used with BertForSequenceClassification model"""
|
||||
|
||||
input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
|
||||
input_mask = _generate_attention_mask_for_encoder_following_hf(32, 64, device)
|
||||
labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)
|
||||
|
||||
return input_ids, input_mask, labels
|
||||
|
||||
|
|
@ -2211,32 +2225,27 @@ def test_ortmodule_inputs_with_dynamic_shape():
|
|||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
|
||||
# TODO(askhade): This test is failing with smaller tolerance, need to investigate! Disabling it right now to
|
||||
# unblock the move to a later version of transformers to resolve security vulnerability.
|
||||
# (Moving from transformers v4.4.2 to v4.30.0)
|
||||
# def test_bert_inputs_with_dynamic_shape():
|
||||
# # create pytorch model with dropout disabled
|
||||
# pt_model = _get_bert_for_sequence_classification_model(
|
||||
# "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
|
||||
# )
|
||||
# ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
def test_bert_inputs_with_dynamic_shape():
|
||||
# create pytorch model with dropout disabled
|
||||
pt_model = _get_bert_for_sequence_classification_model(
|
||||
"cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
|
||||
)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
# def run_step(model, x, y, z):
|
||||
# outputs = model(x, y, None, None, None, None, z)
|
||||
# loss = outputs[0]
|
||||
# loss.backward()
|
||||
# return outputs[0]
|
||||
def run_step(model, x, y, z):
|
||||
outputs = model(x, y, None, None, None, None, z)
|
||||
loss = outputs[0]
|
||||
loss.backward()
|
||||
return outputs[0]
|
||||
|
||||
# for _step in range(10):
|
||||
# x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
|
||||
for _step in range(10):
|
||||
x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
|
||||
|
||||
# pt_p = run_step(pt_model, x, y, z)
|
||||
# ort_p = run_step(ort_model, x, y, z)
|
||||
pt_p = run_step(pt_model, x, y, z)
|
||||
ort_p = run_step(ort_model, x, y, z)
|
||||
|
||||
# _test_helpers.assert_values_are_close(
|
||||
# ort_p, pt_p, atol=1e-01
|
||||
# ) # TODO: this assert is failing with smaller tolerance, need to investigate!!
|
||||
# # _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) #TODO - enable this check after the investigation
|
||||
_test_helpers.assert_values_are_close(ort_p, pt_p, atol=1e-01)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
|
|
@ -6424,9 +6433,6 @@ def test_conv_transpose_gradient_with_strides_padding_and_dilation(conv_algo_sea
|
|||
del os.environ["ORTMODULE_CONV_ALGO_SEARCH"]
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="This test fail because bert forward loss is nan in updated transformers lib, disable for now."
|
||||
)
|
||||
def test_bert_result_with_layerwise_recompute():
|
||||
original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None)
|
||||
# Create PyTorch model with dropout disabled.
|
||||
|
|
|
|||
Loading…
Reference in a new issue