pytorch/test/distributed/tensor/parallel
Wanchao Liang 7775fee10f [tp] refactor and fix PrepareModuleInput for DTensor inputs (#128431)
As titled, this PR refactors the PrepareModuleInput style to add a common
method prepare_input_arg, allowing both args and kwargs to reuse this logic.

This also fixes https://github.com/pytorch/pytorch/issues/128365

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128431
Approved by: https://github.com/awgu
2024-06-12 19:16:33 +00:00
__init__.py
test_ddp_2d_parallel.py
test_fsdp_2d_parallel.py
test_micro_pipeline_tp.py Introduce Inductor passes to micro-pipeline all-gather-matmul and matmul-reduce-scatter in certain cases (#126598) 2024-06-04 09:06:56 +00:00
test_parallelize_api.py
test_tp_examples.py [DTensor] Turn on foreach implementation of optimizer for DTensor by default (#123394) 2024-05-15 16:45:42 +00:00
test_tp_random_state.py
test_tp_style.py [tp] refactor and fix PrepareModuleInput for DTensor inputs (#128431) 2024-06-12 19:16:33 +00:00