onnxruntime/orttraining/orttraining/python/training/ortmodule
pengwa f6c81d8aca
Introduce padding inspector in ORTModule (#14652)
### Introduce padding inspector in ORTModule

In some Transformer-based LLM training recipes, high data sparsity is
observed due to 1) token padding (to the max sequence length) and 2) labels
that contain many ignore_index entries, which are skipped when calculating the loss.

This PR introduces a switch to enable data sparsity inspection, which
1) in the short term, can inform training users so they can apply techniques like
dynamic batching to mitigate the issue;
2) in the medium and longer term, also helps us (the training team) better
understand what our training customers' models look like from the
perspective of data sparsity (and potentially motivates us to make
improvements in the runtime).
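
One way to read "dynamic batching" here is length-based bucketing: group sequences of similar length so each batch is padded only to its own longest sequence. The sketch below illustrates that reading and is not ORTModule code. `padded_density`, `fixed_batches`, and `length_bucketed_batches` are hypothetical helper names, and the lengths are the first eight per-sequence valid token counts from the step 60 row of the sample output in this description.

```python
def padded_density(batches):
    """Valid tokens / total tokens when each batch is padded to its own longest sequence."""
    valid = sum(sum(batch) for batch in batches)
    total = sum(len(batch) * max(batch) for batch in batches)
    return valid / total


def fixed_batches(lengths, batch_size):
    """Batch sequences in arrival order."""
    return [lengths[i:i + batch_size] for i in range(0, len(lengths), batch_size)]


def length_bucketed_batches(lengths, batch_size):
    """Sort by length first so each batch holds similarly sized sequences (assumed strategy)."""
    return fixed_batches(sorted(lengths), batch_size)


# Per-sequence valid token counts (hypothetical workload of 8 sequences).
lengths = [50, 81, 35, 11, 29, 36, 66, 19]

print(f"{padded_density(fixed_batches(lengths, 4)):.2%}")            # 55.61%
print(f"{padded_density(length_bucketed_batches(lengths, 4)):.2%}")  # 70.47%
```

In this toy case bucketing raises density because short sequences no longer share a batch with the longest one. Padding every batch to a fixed `--max_seq_length` instead would remove that benefit.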

Here are examples of different data sparsity with the same training model
architecture and the same training inputs, but different user models.

**Low Embed Density, High Label Density Case - Sentence Classification**
`
python -m torch.distributed.launch --nproc_per_node=4
examples/onnxruntime/training/text-classification/run_glue.py
--model_name_or_path roberta-large-openai-detector --task_name mnli
--do_train --do_eval --max_seq_length 128 --per_device_train_batch_size
32 --learning_rate 2e-5 --num_train_epochs 3 --overwrite_output_dir
--output_dir ./outputs/ --per_device_eval_batch_size 32 --seed 1137
--fp16 True --ignore_mismatched_sizes True --optim adamw_ort_fused
`
```
>>>Valid token/label density (e.g. valid/total) in passing 10 steps:
        | STEP       | INPUT TYPE |  INPUT NAME     | PAD IDX    | DENSITY    | VALID TOKENS    | TOTAL TOKENS    | VALID TOKENS/BATCH |
        | 60         | EMBED      | input_ids       | 1          | 35.21    % | 1442            | 4096            | [50, 81, 35, 11, 29, 36, 66, 19, 40, 22, 21, 42, 17, 37, 40, 41, 26, 58, 38, 54, 41, 73, 48, 57, 50, 51, 49, 85, 48, 36, 79, 62] |
        | 61         | LABEL      | labels          | -100       | 100.00   % | 32              | 32              | N/A             |
        | 62         | EMBED      | input_ids       | 1          | 30.00    % | 1229            | 4096            | [36, 73, 13, 47, 27, 33, 53, 25, 51, 28, 36, 42, 42, 32, 39, 52, 27, 13, 31, 66, 42, 45, 52, 45, 58, 42, 37, 66, 12, 18, 29, 17] |
        | 63         | LABEL      | labels          | -100       | 100.00   % | 32              | 32              | N/A             |
        | 64         | EMBED      | input_ids       | 1          | 26.73    % | 1095            | 4096            | [37, 28, 20, 53, 16, 20, 44, 52, 27, 28, 16, 19, 16, 24, 63, 31, 24, 42, 33, 41, 44, 60, 44, 67, 54, 30, 20, 19, 33, 23, 24, 43] |
        | 65         | LABEL      | labels          | -100       | 100.00   % | 32              | 32              | N/A             |
        | 66         | EMBED      | input_ids       | 1          | 30.03    % | 1230            | 4096            | [22, 46, 36, 41, 46, 43, 26, 50, 60, 16, 24, 42, 56, 35, 35, 59, 29, 39, 34, 20, 66, 23, 47, 53, 19, 35, 44, 23, 34, 81, 21, 25] |
        | 67         | LABEL      | labels          | -100       | 100.00   % | 32              | 32              | N/A             |
        | 68         | EMBED      | input_ids       | 1          | 31.62    % | 1295            | 4096            | [75, 36, 48, 20, 38, 21, 49, 54, 38, 41, 26, 28, 80, 45, 48, 16, 22, 41, 34, 28, 37, 16, 74, 63, 62, 34, 22, 45, 23, 27, 37, 67] |
        | 69         | LABEL      | labels          | -100       | 100.00   % | 32              | 32              | N/A             |
<<<
```

**High Embed Density, Low Label Density Case - Masked Language Model**
`
python -m torch.distributed.launch --nproc_per_node=4
examples/onnxruntime/training/language-modeling/run_mlm.py
--model_name_or_path bert-base-uncased --dataset_name wikitext
--dataset_config_name wikitext-2-raw-v1 --num_train_epochs 10
--per_device_train_batch_size 8 --per_device_eval_batch_size 8
--do_train --do_eval --overwrite_output_dir --output_dir ./outputs/
--seed 1137 --fp16 --report_to none --optim adamw_ort_fused
`
```
>>>Valid token/label density (e.g. valid/total) in passing 10 steps:
        | STEP       | INPUT TYPE |  INPUT NAME     | PAD IDX    | DENSITY    | VALID TOKENS    | TOTAL TOKENS    | VALID TOKENS/BATCH |
        | 710        | EMBED      | input_ids       | 0          | 100.00   % | 4096            | 4096            | [512, 512, 512, 512, 512, 512, 512, 512] |
        | 711        | LABEL      | labels          | -100       | 13.77    % | 564             | 4096            | N/A             |
        | 712        | EMBED      | input_ids       | 0          | 100.00   % | 4096            | 4096            | [512, 512, 512, 512, 512, 512, 512, 512] |
        | 713        | LABEL      | labels          | -100       | 14.48    % | 593             | 4096            | N/A             |
        | 714        | EMBED      | input_ids       | 0          | 100.00   % | 4096            | 4096            | [512, 512, 512, 512, 512, 512, 512, 512] |
        | 715        | LABEL      | labels          | -100       | 14.18    % | 581             | 4096            | N/A             |
        | 716        | EMBED      | input_ids       | 0          | 100.00   % | 4096            | 4096            | [512, 512, 512, 512, 512, 512, 512, 512] |
        | 717        | LABEL      | labels          | -100       | 14.53    % | 595             | 4096            | N/A             |
        | 718        | EMBED      | input_ids       | 0          | 100.00   % | 4096            | 4096            | [512, 512, 512, 512, 512, 512, 512, 512] |
        | 719        | LABEL      | labels          | -100       | 15.31    % | 627             | 4096            | N/A             |
<<<
```
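
The DENSITY, VALID TOKENS, TOTAL TOKENS, and VALID TOKENS/BATCH columns in the tables above can be reproduced with a computation like the following. This is a minimal sketch in plain Python, not the inspector's actual implementation. `token_density` and `format_density` are hypothetical names, and the toy token ids are made up.

```python
def token_density(batch, pad_idx):
    """Return (valid, total, per_row_valid) for a batch of token id rows."""
    per_row_valid = [sum(1 for tok in row if tok != pad_idx) for row in batch]
    valid = sum(per_row_valid)
    total = sum(len(row) for row in batch)
    return valid, total, per_row_valid


def format_density(valid, total):
    """Format valid/total as a percentage with two decimals."""
    return f"{100.0 * valid / total:.2f} %"


# Toy EMBED input: 2 sequences padded to length 6 with pad_idx=1.
input_ids = [
    [0, 713, 16, 2, 1, 1],
    [0, 42, 2, 1, 1, 1],
]
valid, total, per_row = token_density(input_ids, pad_idx=1)
print(valid, total, per_row, format_density(valid, total))

# Toy LABEL input: positions set to -100 (ignore_index) do not contribute to the loss.
labels = [[-100, 713, -100, -100, -100, -100]]
label_valid, label_total, _ = token_density(labels, pad_idx=-100)
print(label_valid, label_total, format_density(label_valid, label_total))
```

Applied to the reported counts, the same formula gives `format_density(1442, 4096) == "35.21 %"` for step 60 and `format_density(564, 4096) == "13.77 %"` for step 711.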

#### Next Step

The next step is to leverage this data sparsity for improvement.
Optimizations under way as part of compute optimizer wave 2:
> Loss compute FLOPs reduction.
> Flatten/Unflatten embedding tokens to save compute FLOPs.
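
The flatten/unflatten idea can be sketched as follows: drop padding positions before a per-token computation, then scatter the results back into the padded layout. This only illustrates the general technique in plain Python and is not the planned ORT optimization. `flatten_valid` and `unflatten` are hypothetical names, and the multiplication stands in for real per-token work.

```python
def flatten_valid(input_ids, pad_idx):
    """Drop padding and remember where every kept token came from."""
    tokens, positions = [], []
    for r, row in enumerate(input_ids):
        for c, tok in enumerate(row):
            if tok != pad_idx:
                tokens.append(tok)
                positions.append((r, c))
    return tokens, positions


def unflatten(values, positions, shape, fill):
    """Scatter per-token results back into a padded (rows, cols) layout."""
    rows, cols = shape
    out = [[fill] * cols for _ in range(rows)]
    for value, (r, c) in zip(values, positions):
        out[r][c] = value
    return out


input_ids = [[5, 7, 0, 0], [9, 0, 0, 0]]  # pad_idx = 0
tokens, positions = flatten_valid(input_ids, pad_idx=0)

# Stand-in for per-token work (e.g. an embedding lookup): runs on 3 tokens, not 8.
processed = [tok * 10 for tok in tokens]

restored = unflatten(processed, positions, shape=(2, 4), fill=0)
print(tokens, restored)
```

The same filtering applies to the loss: with labels filtered on ignore_index, only the valid label positions need to be computed.
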
2023-03-03 18:36:08 +08:00

| Name | Last commit | Date |
| --- | --- | --- |
| `experimental` | Fix a PoliCheck finding in _hierarchical_ortmodule.py(#13462) | 2022-10-26 15:45:18 -07:00 |
| `torch_cpp_extensions` | Add support for ORTModule Torch cpp CUDA extension build within docker (#12868) | 2022-09-08 15:30:44 -04:00 |
| `__init__.py` | Update ORTModule Default Opset Version to 15 (#12419) | 2022-08-05 16:55:04 +08:00 |
| `_custom_autograd_function.py` | [ORTModule] Add Env Variable to Control Disabling Custom AutoGrad Function Support (#13430) | 2022-10-25 16:58:04 +08:00 |
| `_custom_autograd_function_exporter.py` | SCELoss(SCELossGrad) support half(float) input float(half) output (#13972) | 2023-02-28 18:02:08 +08:00 |
| `_custom_autograd_function_runner.py` | Fix the tensor save for backward release problem (#13679) | 2022-11-22 17:32:19 +08:00 |
| `_custom_gradient_registry.py` | [ORTModule] ATen Support for upsample_bilinear (#14519) | 2023-02-04 15:20:18 +08:00 |
| `_custom_op_symbolic_registry.py` | SCELoss(SCELossGrad) support half(float) input float(half) output (#13972) | 2023-02-28 18:02:08 +08:00 |
| `_execution_agent.py` | fix memory profile for partial graph run (#11911) | 2022-06-24 13:08:14 +08:00 |
| `_fallback.py` | Format all python files under onnxruntime with black and isort (#11324) | 2022-04-26 09:35:16 -07:00 |
| `_fallback_exceptions.py` | Format all python files under onnxruntime with black and isort (#11324) | 2022-04-26 09:35:16 -07:00 |
| `_gradient_accumulation_manager.py` | Move OrtValueVector from onnxruntime-training to onnxruntime (#11176) | 2022-06-15 09:36:28 +02:00 |
| `_graph_execution_interface.py` | Add PyTorch fallback for ORTModule forward exceptions (#8346) | 2021-08-17 10:41:15 -07:00 |
| `_graph_execution_manager.py` | Introduce padding inspector in ORTModule (#14652) | 2023-03-03 18:36:08 +08:00 |
| `_graph_execution_manager_factory.py` | Add PyTorch fallback for ORTModule forward exceptions (#8346) | 2021-08-17 10:41:15 -07:00 |
| `_inference_manager.py` | Introduce padding inspector in ORTModule (#14652) | 2023-03-03 18:36:08 +08:00 |
| `_io.py` | Introduce padding inspector in ORTModule (#14652) | 2023-03-03 18:36:08 +08:00 |
| `_logger.py` | Format all python files under onnxruntime with black and isort (#11324) | 2022-04-26 09:35:16 -07:00 |
| `_onnx_models.py` | [ORTModule] Fix Graph Builder for Eval Mode (#13255) | 2022-10-12 14:39:54 +08:00 |
| `_runtime_inspector.py` | Introduce padding inspector in ORTModule (#14652) | 2023-03-03 18:36:08 +08:00 |
| `_torch_module_factory.py` | Add PyTorch fallback for ORTModule forward exceptions (#8346) | 2021-08-17 10:41:15 -07:00 |
| `_torch_module_interface.py` | Format all python files under onnxruntime with black and isort (#11324) | 2022-04-26 09:35:16 -07:00 |
| `_torch_module_ort.py` | Format all python files under onnxruntime with black and isort (#11324) | 2022-04-26 09:35:16 -07:00 |
| `_torch_module_pytorch.py` | Format all python files under onnxruntime with black and isort (#11324) | 2022-04-26 09:35:16 -07:00 |
| `_training_manager.py` | Introduce padding inspector in ORTModule (#14652) | 2023-03-03 18:36:08 +08:00 |
| `_utils.py` | Replace distutils by setuptools to import build_ext (#14108) | 2023-01-09 11:48:01 +01:00 |
| `debug_options.py` | Format all python files under onnxruntime with black and isort (#11324) | 2022-04-26 09:35:16 -07:00 |
| `ortmodule.py` | Del ort_model._modules to foward its accessing to torch_model._modules (#14563) | 2023-03-03 10:12:37 +08:00 |