Fix and enable few ORTModule Unit Tests (#19847)

### Fix and enable few ORTModule Unit Tests

Fix 'test_bert_inputs_with_dynamic_shape' and
'test_bert_result_with_layerwise_recompute' generate Nan loss in ORT
run.

The root cause is, the logic to generatic attention mask test data is
not correct, only 0 or 1 is allowed in the dataset, but we see lots of
other numbers. ( The reason we don't have this using old version of
transformers for example v4.4.2 or 4.16.2 is because they don't contains
such
d3cb28886a,
which increase the scaling to a bigger number, causing a overflow to
inf)

Another improvement during the investigation using convergence tools:
Don't dump the activations during model export phase, otherwise, the
dumped data might contains some PyTorch run's result making us confused
during comparing with stock PyTorch run results.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2024-03-12 10:49:19 +08:00 committed by GitHub
parent 0c078dfc8b
commit 3e954da3e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 58 additions and 49 deletions

View file

@ -89,7 +89,7 @@ The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output
dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example:
```diff
+ from onnxruntime.training.utils import inspect_activation
+ from onnxruntime.training.utils.hooks import inspect_activation
class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config: BloomConfig):
...

View file

@ -14,6 +14,7 @@ import onnx
import torch
from ._subscriber_base import RuntimeStates, SubscriberBase
from ._subscriber_manager import ORT_NO_INCREASE_GLOBAL_STEP
class _InspectActivation(torch.autograd.Function):
@ -176,21 +177,23 @@ class StatisticsSubscriber(SubscriberBase):
display_name = name + " forward run" if is_forward is True else name + " backward run"
output_file_name = name + "_forward" if is_forward is True else name + "_backward"
if tensor is None or not isinstance(tensor, torch.Tensor):
print(f"{display_name} not a torch tensor, value: {tensor}")
return
# Skip dump during model pre-export output schema preparison run and export run.
if ORT_NO_INCREASE_GLOBAL_STEP[0] is False:
if tensor is None or not isinstance(tensor, torch.Tensor):
print(f"{display_name} not a torch tensor, value: {tensor}")
return
step_path = Path(step_folder)
if not step_path.exists():
step_path.mkdir(parents=True, exist_ok=False)
order_file_path = step_path / "order.txt"
tensor_file_path = step_path / output_file_name
step_path = Path(step_folder)
if not step_path.exists():
step_path.mkdir(parents=True, exist_ok=False)
order_file_path = step_path / "order.txt"
tensor_file_path = step_path / output_file_name
with order_file_path.open(mode="a", encoding="utf-8") as f:
f.write(f"{output_file_name}\n")
with order_file_path.open(mode="a", encoding="utf-8") as f:
f.write(f"{output_file_name}\n")
with tensor_file_path.open(mode="w", encoding="utf-8") as f:
_summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
with tensor_file_path.open(mode="w", encoding="utf-8") as f:
_summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
def _summarize_tensor(

View file

@ -417,24 +417,38 @@ def _get_bert_for_sequence_classification_model(
return model
def _get_bert_for_sequence_classification_sample_data(device):
"""Returns sample data to be used with BertForSequenceClassification model"""
def _generate_attention_mask_for_encoder_following_hf(batch_size, seq_length, device, past_key_values_length=0):
"""Generate attention mask for encoder following the implementation in HuggingFace.
input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
input_mask = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)
Be noted: past_key_values_length is 0 for training.
return input_ids, input_mask, labels
Generate mask using this
https://github.com/huggingface/transformers/blame/4f27ee936a861f56f32ea6db138978b274008006/src/transformers/models/bert/modeling_bert.py#L974C81-L974C81
"""
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
return attention_mask
def _get_bert_for_sequence_classification_sample_data_with_random_shapes(device):
"""Returns sample data with random shape to be used with BertForSequenceClassification model"""
x = random.randint(1, 100)
y = random.randint(1, 100)
input_ids = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
input_mask = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
labels = torch.randint(0, 1, (x,), dtype=torch.long, device=device)
bsz = random.randint(1, 100)
seq_length = random.randint(1, 100)
input_ids = torch.randint(0, 100, (bsz, seq_length), dtype=torch.long, device=device)
input_mask = _generate_attention_mask_for_encoder_following_hf(bsz, seq_length, device)
labels = torch.randint(0, 1, (bsz,), dtype=torch.long, device=device)
return input_ids, input_mask, labels
def _get_bert_for_sequence_classification_sample_data(device):
"""Returns sample data to be used with BertForSequenceClassification model"""
input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
input_mask = _generate_attention_mask_for_encoder_following_hf(32, 64, device)
labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)
return input_ids, input_mask, labels
@ -2211,32 +2225,27 @@ def test_ortmodule_inputs_with_dynamic_shape():
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
# TODO(askhade): This test is failing with smaller tolerance, need to investigate! Disabling it right now to
# unblock the move to a later version of transformers to resolve security vulnerability.
# (Moving from transformers v4.4.2 to v4.30.0)
# def test_bert_inputs_with_dynamic_shape():
# # create pytorch model with dropout disabled
# pt_model = _get_bert_for_sequence_classification_model(
# "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
# )
# ort_model = ORTModule(copy.deepcopy(pt_model))
def test_bert_inputs_with_dynamic_shape():
# create pytorch model with dropout disabled
pt_model = _get_bert_for_sequence_classification_model(
"cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
)
ort_model = ORTModule(copy.deepcopy(pt_model))
# def run_step(model, x, y, z):
# outputs = model(x, y, None, None, None, None, z)
# loss = outputs[0]
# loss.backward()
# return outputs[0]
def run_step(model, x, y, z):
outputs = model(x, y, None, None, None, None, z)
loss = outputs[0]
loss.backward()
return outputs[0]
# for _step in range(10):
# x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
for _step in range(10):
x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
# pt_p = run_step(pt_model, x, y, z)
# ort_p = run_step(ort_model, x, y, z)
pt_p = run_step(pt_model, x, y, z)
ort_p = run_step(ort_model, x, y, z)
# _test_helpers.assert_values_are_close(
# ort_p, pt_p, atol=1e-01
# ) # TODO: this assert is failing with smaller tolerance, need to investigate!!
# # _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) #TODO - enable this check after the investigation
_test_helpers.assert_values_are_close(ort_p, pt_p, atol=1e-01)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@ -6424,9 +6433,6 @@ def test_conv_transpose_gradient_with_strides_padding_and_dilation(conv_algo_sea
del os.environ["ORTMODULE_CONV_ALGO_SEARCH"]
@pytest.mark.skip(
reason="This test fail because bert forward loss is nan in updated transformers lib, disable for now."
)
def test_bert_result_with_layerwise_recompute():
original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None)
# Create PyTorch model with dropout disabled.