Commit graph

26 commits

Author SHA1 Message Date
guyang3532
cfe830b248
Generalize label input sparsity check and refactor (#20636)
### Description
The InsertGatherBeforeSceLoss optimization is enabled when the density
of label padding less than 90%. We need to check the density of the
label padding to decide whether enable the optimization.

Before this pr, we just check the inputs of graph and correlate one with
the SCE node by iterate graph from the SCE node back to one graph input.
This is hard to be general because there may be complicated pattern
between graph input and SCE node.

This pr check padding density by the direct input of SCE module rather
than the input of graph at the first graph execution when exporting onnx
graph.
And if the density < 90%, insert a flag PythonOp after the SCE node as:
```
           SoftmaxCrossEntropy
		  |
            PythonOp (func_name: FlagAndPrintDensity)   (insert if density < 90%)
		  |
            Following graph
```

When the InsertGatherBeforeSceLoss is invoked, it check if there is the
flag PythonOp(func_name: FlagAndPrintDensity) after the SCE node and if
it is, remove it and do the padding elimination optimization.

If the env of ORTMODULE_PRINT_INPUT_DENSITY is 1, we will print input
density each step by the PythonOp (func_name: FlagAndPrintDensity). In
this case the PythonOp will not be removed.
2024-05-10 21:55:43 +08:00
Adam Louly
ee74fb6908
Introducing ORTPipelineModule - DeepSpeed Parallel Pipeline Support. (#20287)
### Description
Introducing a new class ORTPipelineModule to handle wrapping layers in
DeepSpeed pipeline parallel.


### Motivation and Context
To support pipeline parallelism on ORTModule.

This PR will include an initial support of deepspeed Pipeline
parallelism.

- [x] Support Pipeline parallel where layers are nn Modules in
Sequential.
- [ ] Support LayerSpec and TiedLayerSpec
- [ ] Enable partitioning to accept List
- [ ] Full-GPU Graph Consolidation
- [ ] Subgraph Merging for Inference
2024-04-18 11:30:15 -07:00
pengwa
d102569755
Fix seed for recomputed Dropout (#19715)
### Fix seed for recomputed Dropout

If Dropout node is recomputed in the backward, we should make sure its
execution is same as the run in the forward.
If we don't set seed attribute, then this cannot be guaranteed. 

Add ` export ORTMODULE_MEMORY_OPT_LEVEL=2` to enabled per layer
recompute with compromised recomputable subgraphs.
2024-03-06 10:06:25 +08:00
guyang3532
cd56ea4a74
enable embedding sparse optimization by default (#19714) 2024-03-05 13:15:30 +08:00
pengwa
acbfc29f27
Follow up fix for Gelu impl (#19693)
### Follow up fix for Gelu impl

There are two minor comments in
https://github.com/microsoft/onnxruntime/pull/19560.

Fix them in this pull request. 


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2024-03-01 10:57:14 +08:00
pengwa
1150b1f81e
ORTModule memory improvement (#18924)
## Dependency

https://github.com/microsoft/onnxruntime/pull/19007

## ORTModule memory efficient gradient management

Previously I have tried to solve the coarsed-grained gradient
accumulation/update problem in ORTModule with
https://github.com/microsoft/onnxruntime/pull/8979, while that
resolution somehow is not fully validated with DDP or there is user
hooks on the gradient accumulation on torch parameter.

This PR is addressing the problem in the similar approach as PR 8979,
e.g. trigger gradient accumulation once ORT computed the grad, but
instead of use a AccumulateGrad op, this time with a ONNX operator
PythonOp, internally it will call param.backward(grad), which will help
handle all related hooks correctly.


## Design

Check the details from


https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

## Convergence Validation:


![image](https://github.com/microsoft/onnxruntime/assets/10530022/ccf3a213-e815-4b23-b759-165033b2d9fe)

differences are on mostly 0.000x, sometimes 0.00x, which may comes from
the different order gradient apply happens before or after this change
(on deepspeed zero stage 2)


## TODO

Consolidate the logic with Stage3's similar logic.
2024-01-16 08:57:37 +08:00
pengwa
ccf3b2054b
Allow layer-wise recompute (#18566)
### Allow layer-wise recompute 

Early, we need users/developers to specify the subgraphs to recompute,
now we introduced a more user-friendly way to enable recompute for all
detected stashed activation recomputation subgraphs. This scarifies
getting the best configs while makes it easier to support user
requirements when they switches from PyTorch per-layer gradient
checkpoint to ORTModule.

`ORTMODULE_MEMORY_OPT_LEVEL` is introduced to control the usage, by
default, it is 0, e.g. `USER_SPECIFIED`, all subgraphs definedin
`ORTMODULE_MEMORY_OPT_CONFIG` will be recomputed. So this is compatible
to existing recompute usage in ORTModule integrated models.

Using `ORTMODULE_MEMORY_OPT_LEVEL=1`, we will enable all recompute plans
detected, so those configs in `ORTMODULE_MEMORY_OPT_CONFIG` will not be
respected any more.


Add Unit Tests using 3 layer blooms. 



https://github.com/microsoft/onnxruntime/blob/pengwa/add_aggresive_recompute/docs/Memory_Optimizer.md
2023-12-12 08:44:05 +08:00
pengwa
4bfa84487c
Skip module clone for preparing large model export (#18663)
### Skip module clone for preparing large model export

For LLAMA2 13B, when running with Lora, DeepSpeed stage2 on 8 GPUs . It
failed during preparing outputs which will be used for
torch.onnx.export. The reason, we deep copy all the params including
both big sizes of frozen weights, + a little bit of Lora trainable
weight.

This PR will firstly check whether the GPU memmory is enough for a
cloned module, if not, skip the copy.

Copying the module is to guarantee the fw path run may change the
weight, while this case should be rare. But for now, Not-Able-To-Run is
worse than Runnable-with-A-little-bit-different-initial-weight,
especially for large models.
2023-12-05 12:41:17 -08:00
Vincent Wang
e1d1033131
[ORTModule] Remove Unused Arguments from Generated Triton Code (#18636)
This PR:
- Remove unused arguments from generated triton code,
- Remove unnecessary mask for symbolic shape case from generated triton
code.
- Add doc for usage of ORTMODULE_TRITON_CONFIG_FILE.
2023-11-30 18:32:36 +08:00
Vincent Wang
3bc9efc7b2
[ORTModule] Adjust Attention Patterns for Efficient Attention ATen Fallback (#18471)
Adjust attention patterns to match latest Whisper+exporter. Also add
some condition check and add docs.
2023-11-22 15:24:05 +08:00
pengwa
2151c79bf1
Tune ORTModule logging experience a bit (#18298)
### Tune logging experience a bit

After last time we update the ORTModule log experience, we found few
issues:
1. `INFO` level output too many things, including PyTorch exporter
verbose logs (tracing graphs) on every ranks. On this level, we only
want to
- Output a little bit more information to Users than `WARNING` level,
for example the memory recomputation recommendations or other
not-fully-ready features.
- Output a little bit more information for a quick diagnostic, collected
on rank-0 only.
2. ONNX Runtime logging filter during graph build, session init
sometimes will hide the issues (for example segement fault), there is no
useful information in `WARNING`/`INFO` for users to report to us. This
is not good!
3. Some of our devs like using `pdb` to debug Python code, but if we add
`import pdb; pdb.set_trace()` in models' code might hang when they use
`INFO` or `WARNING`, where exporter happens and all output got
redirected due to log filtering. The only workaround is to switch to
VERBOSE, which output toooooooooooo many logs.

The corresponding changes proposed here are:
1. For `INFO` logging, 
    - We only logs rank-0. 
- We restricted the ORT backend logging level to be WARNING in this
case, because ORT backend code output way too many logs that should be
under verbose, while we cannot guarantee we can get them cleaned up
immediately once they are added.
- We output the PyTorch exporter verbose log (including tracing graph),
which is useful for a quick diagnostic when an issue happens.
2. Remove all logging filtering on ORT backend, then the segment fault
issue details will not be hidden once it happens again.
 3. Introduced a `DEVINFO` logging,
     - Log logs on all ranks
     - Log ORT backend logging level INFO
- PyTorch exporter logging filtering are all turned OFF (to unblock the
pdb debugging).
4. Currently, to use Memory Optimizer, need use DEVINFO (which will
output ORT backend INFO log). So update memory optimizer document to
reflect this. https://github.com/microsoft/onnxruntime/pull/17481 will
update the requirement back to INFO for show memory optimization infos.

You can check
https://github.com/microsoft/onnxruntime/blob/pengwa/devinfo_level/docs/ORTModule_Training_Guidelines.md#log-level-explanations
for a better view of different log levels.

This PR also extract some changes from a bigger one
https://github.com/microsoft/onnxruntime/pull/17481, to reduce its
complexity for review.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: mindest <30493312+mindest@users.noreply.github.com>
2023-11-08 17:42:50 +08:00
pengwa
6e6f582e08
Use full qualified name for PythonOp export (#17021)
### Use full qualified name for PythonOp export

Originally, when there are duplicate named torch.autograd.Function in
different module, for example:

`a.b.c.Gelu` v.s. `d.e.func.<locals>.Gelu`

We by default will throw exception to let user be aware we cannot
distinguish the two Gelu because during model export, we did not module
path. The workaround is we introduced
`ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS` to ignore those duplicated named
Gelu that is not used by model run. This has limitations obviously for
example if two Gelus are both used in training.



This PR finds a way to construct a full qualified name.

`def _export_pt_1_10(g, n, *args, **kwargs):`

1. in exporter function, kwargs contains `name` and `module`, in the
above example:
   `a.b.c.Gelu`  --> name: `Gelu`, module: `a.b.c`
   `d.e.func.<locals>.Gelu` --> name: `Gelu`, module: `d.e`
   
 
Using name and module is not enough to get a full qualified name, for
the second case, where `d.e` is the module path, then there is a
function called `func`, in this function, there is a local
auto.grad.Function named `Gelu`. (Many of our UT looks like this). We
can only get `d.e.Gelu`, but this is not the correct full qual name.

The reason for this: `kwargs[name]` or `n.name` only return the class's
name, not the class's full qual name. (be noted kwargs[module]` is
correct).

2. `n` is torch.Node, we can access `pyobj` to get the
torch.autograd.Function's apply method instance, then use `._self` to
get the torch.autograd.Function class. Then we can get the `module` and
`class`'s ful qual name, added together, we get the full qual name.

With the above change, we don't need use `kwargs[name]` and
`kwargs[module]` , and don't need check naming conflicting or
`ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS` env var any more.
2023-08-09 10:58:33 +08:00
Prathik Rao
779fba1666
ORT Cache (#16744)
### Description
<!-- Describe your changes. -->

This PR adds support to cache the exported training/evaluation ONNX
model in `ORTModule`. On future runs, instead of exporting the model
again, we can pick up the model from a location on disc and run
`ORTModule` training/evaluation.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

ORT Training DRI Contribution

---------

Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
Co-authored-by: pengwa <pengwa@microsoft.com>
2023-07-27 09:00:43 -07:00
Vincent Wang
c07a3b869c
Triton Codegen for ORTModule (#15831)
Fuse connected elementwise and reduce Ops to TritonOp and codegen triton
code to run the kernel.

This PR is co-edited by @wejoncy and @er3x3
2023-07-13 18:17:58 +08:00
Adam Louly
211fe5988e
add steps to write modulewithloss wrapper (#16486)
### Description
This PR includes documentation updates, providing step-by-step
instructions on how to implement the ModuleWithLoss wrapper in a
different codebase.
The documentation outlines the necessary code changes and offers
customization options based on specific requirements.

---------

Co-authored-by: Adam Louly <adamlouly@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
2023-07-11 09:07:35 +08:00
pengwa
a49bb85cfe
Manage ORTModule configurations consistently (#16396)
### Manage ORTModule options

Move all env vars that used for feature ON/OFF into runtime options for
consistent managements.


Be noted: the features' switch are assigned in 2 phases: default values,
overwritten by env vars (if specified by users). So env vars take the
highest priority when all 2 phases both given value explicitly for one
feature.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-06-27 19:19:36 +08:00
guyang3532
341484e67c
Embedding sparsity optimization (#16141)
### Description
Optimize compute graph by eliminating padding in embedding.


### Motivation and Context
The computation for padding in nodes after embedding is unnecessary and
waste computation resources.
This pr just add an Optimizer of PaddingElimination to check and
eliminate the padding after embedding automatically by modifying the
graph.

### Implementation:
1. Find and check embedding node in graph.
2. Iterate the subgraph afterward the embedding node and record all the
input nodes and output nodes to this subgraph.
3. Insert 'Reshape + ShrunkenGather' to flatten each input node shape
from [batch_size, seqlen, ...] to [valid_token_without_padding, ...],
and insert 'GatherGrad + Reshape' to unflatten each output node shape
from [valid_token_without_padding, ...] to [batch_size, seqlen, ...]

---------

Co-authored-by: mindest <linminuser@gmail.com>
2023-06-19 20:34:53 +08:00
pengwa
735a32fee1
Introduce memory observer for ORTModule (#16213)
### Introduce memory observer for ORTModule

To analyze memory usage for ORTModule training, we need collect
per-iteration memory footprint in different stages (pre-forward,
post-forward, pre-backward, and post-backward).

Currently we only collect the data using torch.cuda APIs. The next step
is, we could collect the detailed stashed activation list and its
percentage within ORT backend, which is beyond this PR.

Sample as below: 
```
0/8] step 0 memory (MiB) | phase: pre_forward | allocated: 1866 | max allocated: 1866 | cached: 1874 | max cached: 1874 | inactive: 8 | max inactive: 8
[0/8] step 0 memory (MiB) | phase: post_forward | allocated: 23277 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 193 | max inactive: 405
[0/8] step 0 memory (MiB) | phase: pre_backward | allocated: 23277 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 193 | max inactive: 405
[0/8] step 0 memory (MiB) | phase: post_backward | allocated: 2932 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 6158 | max inactive: 6158
  0%|█                                                                                                                                                                                                            | 1/200 [00:26<1:26:18, 26.02s/it]
[0/8] step 1 memory (MiB) | phase: pre_forward | allocated: 2356 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 2454 | max inactive: 6165
[0/8] step 1 memory (MiB) | phase: post_forward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 1 memory (MiB) | phase: pre_backward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 1 memory (MiB) | phase: post_backward | allocated: 3422 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 5284 | max inactive: 6165
  1%|██                                                                                                                                                                                                             | 2/200 [00:26<36:47, 11.15s/it]
[0/8] step 2 memory (MiB) | phase: pre_forward | allocated: 2356 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2454 | max inactive: 6165
[0/8] step 2 memory (MiB) | phase: post_forward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 2 memory (MiB) | phase: pre_backward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165
[0/8] step 2 memory (MiB) | phase: post_backward | allocated: 3422 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 5284 | max inactive: 6165
```
2023-06-15 15:45:36 +08:00
pengwa
b457cfaa8f
Enable conditional optimization automatically (#15885)
### Enable conditional optimization on inputs

Label sparsity based optimization can be enabled depending on the input
inspection result.

So this PR introduce a conditional optimization path for ORTModule,
where we automatically detect data sparsity from label or embedding, and
enable the graph optimization accordingly without any user interaction.

This feature had a new requirement of delaying passing pre_grad graph
transformation config to OrtModuleGraphBuilder, from `Initialize` phase
to its `Build` phase. Because once after `_initialize_graph_builder` we
can detect the input sparsity, and make a decision to enable the
label/embed sparisty based graph optimizations.

Add UT cases for label/embed input runtime inspector.
2023-05-23 13:08:05 +08:00
Baiju Meswani
11b0a18de6
Add support for cuda 11.8 and python 3.11 for training (#15548) 2023-04-20 12:56:45 -07:00
pengwa
516c8e95fa
Optimize SCE loss compute (#15401)
### Optimize SCE loss compute

Compute optimization based on label data sparsity:
- Insert ShrunkenGather before SCELoss node, to filter out invalid
labels for compute.
- Support ShrunkenGather upstream.
- Added test for the above.
- Added flag to enable label sparsity optimization with env var, by
default disabled now. Will enable after comprehensive benchmarking
later.
- Extract common logic into test_optimizer_utils.h/cc from
core/optimizer/compute_optimzier_test.cc, then the common functions can
be shared by both core/optimizer/compute_optimzier_test.cc and
orttraining/core/optimizer/compute_optimzier_test.cc
- Extract common logic into shared_utils.h/cc: `GetONNXOpSetVersion` and
`Create1DInitializerFromVector`


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2023-04-13 13:02:12 +08:00
pengwa
f6c81d8aca
Introduce padding inspector in ORTModule (#14652)
### Introduce padding inspector in ORTModule

In some Transformer-based LLM training recipes, high data sparsity is
observed due to 1). token padding (to max sequence length), 2). labels
contains many ignore_index for calculate loss.

This PR introduces a switch to enable data sparsity inspection, which 
1). in short term, can inform training users to use techniques like
dynamic batching to amortize the issue.
2). in medium and longer term, also helps us (training team) to have
better understanding what our training customers' models looks like from
perspective of data sparsity (and potentially motivate us to improve
with runtime).

Here is an example of different data sparsity with same training model
arch, same training input, but with different user models.

**Low Embed Density, High Label Density Case - Sentence Classification**
`
python -m torch.distributed.launch --nproc_per_node=4
examples/onnxruntime/training/text-classification/run_glue.py
--model_name_or_path roberta-large-openai-detector --task_name mnli
--do_train --do_eval --max_seq_length 128 --per_device_train_batch_size
32 --learning_rate 2e-5 --num_train_epochs 3 --overwrite_output_dir
--output_dir ./outputs/ --per_device_eval_batch_size 32 --seed 1137
--fp16 True --ignore_mismatched_sizes True --optim adamw_ort_fused
`
```
>>>Valid token/label density (e.g. valid/total) in passing 10 steps:
        | STEP       | INPUT TYPE |  INPUT NAME     | PAD IDX    | DENSITY    | VALID TOKENS    | TOTAL TOKENS    | VALID TOKENS/BATCH |
        | 60         | EMBED      | input_ids       | 1          | 35.21    % | 1442            | 4096            | [50, 81, 35, 11, 29, 36, 66, 19, 40, 22, 21, 42, 17, 37, 40, 41, 26, 58, 38, 54, 41, 73, 48, 57, 50, 51, 49, 85, 48, 36, 79, 62] |
        | 61         | LABEL      | labels          | -100       | 100.00   % | 32              | 32              | N/A             |
        | 62         | EMBED      | input_ids       | 1          | 30.00    % | 1229            | 4096            | [36, 73, 13, 47, 27, 33, 53, 25, 51, 28, 36, 42, 42, 32, 39, 52, 27, 13, 31, 66, 42, 45, 52, 45, 58, 42, 37, 66, 12, 18, 29, 17] |
        | 63         | LABEL      | labels          | -100       | 100.00   % | 32              | 32              | N/A             |
        | 64         | EMBED      | input_ids       | 1          | 26.73    % | 1095            | 4096            | [37, 28, 20, 53, 16, 20, 44, 52, 27, 28, 16, 19, 16, 24, 63, 31, 24, 42, 33, 41, 44, 60, 44, 67, 54, 30, 20, 19, 33, 23, 24, 43] |
        | 65         | LABEL      | labels          | -100       | 100.00   % | 32              | 32              | N/A             |
        | 66         | EMBED      | input_ids       | 1          | 30.03    % | 1230            | 4096            | [22, 46, 36, 41, 46, 43, 26, 50, 60, 16, 24, 42, 56, 35, 35, 59, 29, 39, 34, 20, 66, 23, 47, 53, 19, 35, 44, 23, 34, 81, 21, 25] |
        | 67         | LABEL      | labels          | -100       | 100.00   % | 32              | 32              | N/A             |
        | 68         | EMBED      | input_ids       | 1          | 31.62    % | 1295            | 4096            | [75, 36, 48, 20, 38, 21, 49, 54, 38, 41, 26, 28, 80, 45, 48, 16, 22, 41, 34, 28, 37, 16, 74, 63, 62, 34, 22, 45, 23, 27, 37, 67] |
        | 69         | LABEL      | labels          | -100       | 100.00   % | 32              | 32              | N/A             |
<<<
```

**High Embed Density, Low Label Density Case - masked language model** 
`
python -m torch.distributed.launch --nproc_per_node=4
examples/onnxruntime/training/language-modeling/run_mlm.py
--model_name_or_path bert-base-uncased --dataset_name wikitext
--dataset_config_name wikitext-2-raw-v1 --num_train_epochs 10
--per_device_train_batch_size 8 --per_device_eval_batch_size 8
--do_train --do_eval --overwrite_output_dir --output_dir ./outputs/
--seed 1137 --fp16 --report_to none --optim adamw_ort_fused
`
```
>>>Valid token/label density (e.g. valid/total) in passing 10 steps:
        | STEP       | INPUT TYPE |  INPUT NAME     | PAD IDX    | DENSITY    | VALID TOKENS    | TOTAL TOKENS    | VALID TOKENS/BATCH |
        | 710        | EMBED      | input_ids       | 0          | 100.00   % | 4096            | 4096            | [512, 512, 512, 512, 512, 512, 512, 512] |
        | 711        | LABEL      | labels          | -100       | 13.77    % | 564             | 4096            | N/A             |
        | 712        | EMBED      | input_ids       | 0          | 100.00   % | 4096            | 4096            | [512, 512, 512, 512, 512, 512, 512, 512] |
        | 713        | LABEL      | labels          | -100       | 14.48    % | 593             | 4096            | N/A             |
        | 714        | EMBED      | input_ids       | 0          | 100.00   % | 4096            | 4096            | [512, 512, 512, 512, 512, 512, 512, 512] |
        | 715        | LABEL      | labels          | -100       | 14.18    % | 581             | 4096            | N/A             |
        | 716        | EMBED      | input_ids       | 0          | 100.00   % | 4096            | 4096            | [512, 512, 512, 512, 512, 512, 512, 512] |
        | 717        | LABEL      | labels          | -100       | 14.53    % | 595             | 4096            | N/A             |
        | 718        | EMBED      | input_ids       | 0          | 100.00   % | 4096            | 4096            | [512, 512, 512, 512, 512, 512, 512, 512] |
        | 719        | LABEL      | labels          | -100       | 15.31    % | 627             | 4096            | N/A             |
<<<
```

#### Next Step

Let's see how we leverage the data sparsity for improvement.
Optimizations on the way around compute optimizer wave 2:
> Loss compute flops reduction.
> Flatten/Unflatten embedding tokens to save compute flops.
2023-03-03 18:36:08 +08:00
Ashwini Khade
68b5b2d7d3
Refactor training build options (#13964)
### Description
1. Renames all references of on device training to training apis. This
is to keep the naming general. Nothing really prevents us from using the
same apis on servers\non-edge devices.
2. Update ENABLE_TRAINING option: With this PR when this option is
enabled, training apis and torch interop is also enabled.
3. Refactoring for onnxruntime_ENABLE_TRAINING_TORCH_INTEROP option: 
   -  Removed user facing option
- Setting onnxruntime_ENABLE_TRAINING_TORCH_INTEROP to ON when
onnxruntime_ENABLE_TRAINING is ON as we always build with torch interop.

Once this PR is merged when --enable_training is selected we will do a
"FULL Build" for training (with all the training entry points and
features).
Training entry points include:
1. ORTModule
2. Training APIs

Features include:
1. ATen Fallback
2. All Training OPs includes communication and collectives
3. Strided Tensor Support
4. Python Op (torch interop)
5. ONNXBlock (Front end tools for training artifacts prep when using
trianing apis)

### Motivation and Context
Intention is to simply the options for building training enabled builds.
This is part of the larger work item to create dedicated build for
learning on the edge scenarios with just training apis enabled.
2023-01-03 13:28:16 -08:00
pengwa
2f5bf75e51
Optimize computation orders (#13672)
### Optimize computation orders

In `Roberta/Electra`, when `ClassificationHead` is used, there is
slicing operation on features on sequence_length dimensions, then loss
calculations only depend on this sliced data. This is a slicing at axis
1. Before slicing the shape is [batch, sequence_length, hidden], after
slicing, it becomes [batch , hidden_stage]

We had opportunities to bring this slicing earlier as much as possible,
by passing through simple elementwise ops (like Add/Div), or
Layernorm/Softmax(if their reduce axis is after the slicing axis), or
even MatMul's the left operand (if only it did not affect the last
dims).

For operators like Reshape/Transpose, it is special since they have
either data specified (after slicing we need update), or they have perm
specified, which requires the input rank remain unchanged. So for those
kinds of operators, we can remain the original rank, but just leave the
sliced dim to be 1, after the compute completed, we do a Squeeze.

```
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
```

src\transformers\models\roberta\modeling_roberta.py
src\transformers\models\electra\modeling_electra.py

#### Benchmark

A simple benchmark shows Robeta training latency dropped from 208ms ~
199ms. 4.5+% reduction.
More comprehensive tests are on the way.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2022-12-22 15:12:52 +08:00
pengwa
d5721b3464
Fix wrong import path in docs (#13680)
### Fix wrong import path in docs
2022-11-17 18:15:02 +08:00
pengwa
ab9ac2acc4
Add guidelines for ORTModule (#13553)
### Add guidelines for ORTModule

As title.

Feel free to let me know if I missed something. 

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
2022-11-04 19:42:10 +08:00