onnxruntime/orttraining/orttraining/python/training/ortmodule
Adam Louly 4ce7bbf6f1
Add LayerSpec Support to ORTPipelineModule (#20410)
### Description
DeepSpeed's pipeline-parallel implementation provides a class, LayerSpec,
that defers instantiating a layer until the layer has been assigned to a
stage, so it is created on that stage's device.

This approach helps reduce peak memory usage.

This PR adds support in ORT for wrapping layers described by a LayerSpec.
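
The deferred-construction idea described above can be illustrated with a minimal, dependency-free sketch. `LazyLayerSpec`, `Linear`, and `build_stage` are hypothetical names made up for this illustration; they are not DeepSpeed's or ORT's actual classes, and the even split of layers across stages is an assumption.

```python
class LazyLayerSpec:
    """Hypothetical stand-in: records how to build a layer, builds it later."""

    def __init__(self, factory, *args, **kwargs):
        self.factory = factory
        self.args = args
        self.kwargs = kwargs

    def build(self):
        # The layer (and its memory) only comes into existence here.
        return self.factory(*self.args, **self.kwargs)


class Linear:
    """Toy layer that allocates its weights at construction time."""

    instances = 0  # counts how many layers were actually built

    def __init__(self, in_features, out_features):
        Linear.instances += 1
        self.weight = [[0.0] * in_features for _ in range(out_features)]


def build_stage(specs, stage_id, num_stages):
    """Instantiate only the specs in this stage's contiguous slice.

    Assumes a simple even partition; a real partitioner may differ.
    """
    per_stage = (len(specs) + num_stages - 1) // num_stages
    start = stage_id * per_stage
    return [spec.build() for spec in specs[start:start + per_stage]]


# Eight layers are described, but nothing is allocated yet.
specs = [LazyLayerSpec(Linear, 4, 4) for _ in range(8)]

# Stage 1 of 4 builds only the layers it owns.
stage_layers = build_stage(specs, stage_id=1, num_stages=4)
```

In this sketch, the process for stage 1 constructs only its own slice of the specs, which is where the peak-memory saving would come from.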
2024-04-23 17:57:08 -07:00
experimental Add LayerSpec Support to ORTPipelineModule (#20410) 2024-04-23 17:57:08 -07:00
graph_optimizers ATen Op Supports Int Return Type and CPU Tensor Arguments (#19773) 2024-03-06 10:11:46 +08:00
torch_cpp_extensions Fix torch cpp extension build warnings (#19842) 2024-03-12 10:51:30 +08:00
__init__.py Prompt layer-wise recompute when applicable (#20126) 2024-04-10 11:50:28 +08:00
_custom_autograd_function.py Improve perf for stage3 training (#18099) 2023-12-15 13:32:19 +08:00
_custom_autograd_function_exporter.py Prompt layer-wise recompute when applicable (#20126) 2024-04-10 11:50:28 +08:00
_custom_gradient_registry.py Bump ruff linter to 0.2.1 (#19471) 2024-02-08 16:08:27 -08:00
_custom_op_symbolic_registry.py Fix softmax export (#20057) 2024-03-26 13:09:20 +08:00
_execution_agent.py Memory optimization refactor and refinement (#17481) 2023-11-23 11:39:00 +08:00
_fallback.py Bump linter versions (#18341) 2023-11-08 13:04:40 -08:00
_fallback_exceptions.py Bump ruff to 0.3.2 and black to 24 (#19878) 2024-03-13 10:00:32 -07:00
_gradient_accumulation_manager.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_graph_execution_interface.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_graph_execution_manager.py fix embedding sparsity log bug of -1% density (#20420) 2024-04-23 20:37:50 +08:00
_graph_execution_manager_factory.py ORTModule log clean up (#16795) 2023-07-26 12:42:50 +08:00
_inference_manager.py Fix memory stats printing (#20061) 2024-03-26 21:25:59 +08:00
_io.py Skip module clone for preparing large model export (#18663) 2023-12-05 12:41:17 -08:00
_logger.py Bump ruff to 0.3.2 and black to 24 (#19878) 2024-03-13 10:00:32 -07:00
_mem_efficient_grad_mgmt.py ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
_onnx_models.py ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
_pythonop_helper.py ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
_runtime_inspector.py Check padding density by input of embedding module (#19821) 2024-04-10 18:45:51 +08:00
_torch_module_factory.py Manage ORTModule configurations consistently (#16396) 2023-06-27 19:19:36 +08:00
_torch_module_interface.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_torch_module_ort.py ORTModule log clean up (#16795) 2023-07-26 12:42:50 +08:00
_torch_module_pytorch.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_training_manager.py Fix memory stats printing (#20061) 2024-03-26 21:25:59 +08:00
_utils.py Bump ruff to 0.3.2 and black to 24 (#19878) 2024-03-13 10:00:32 -07:00
_zero_stage3_compatibility.py Improve perf for stage3 training (#18099) 2023-12-15 13:32:19 +08:00
graph_optimizer_registry.py [ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959) 2023-10-27 10:29:27 +08:00
options.py Bump ruff to 0.3.2 and black to 24 (#19878) 2024-03-13 10:00:32 -07:00
ortmodule.py ORTModule log clean up (#16795) 2023-07-26 12:42:50 +08:00