**Summary**
Fixed issue with updating the current module when transitioning between child module to parent module and in the backward pass. The first issue is caused because the prehook is not called again when we go back to the parent module and that the hook being used was a register_module_forward_hook, which runs before the register_module_hook used in redistribute, causing the collective call to be assigned to the incorrect module. In order to do this, I updated the current module to be the parent module in a register_forward_hook in the module tracker. The second issue was caused by the parent set in the module tracker I inherit from being incorrect. I fixed this issue by saving the parents of each module and using them in collective counter instead of the incorrect set. I have updated the example in module_operation_tracing to reflect the correct output. In addition, I changed the test cases that used the incompatible old CommDebugMode.
**Test Case**
1. torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/comm_mode_features_example.py -e MLP_operation_tracing
2. pytest test/distributed/_tensor/debug/test_comm_mode_features.py -s -k test_transformer_module_tracing
3. python test/distributed/_composable/fsdp/test_fully_shard_training.py -k TestFullyShardGradientAccumulation.test_gradient_accumulation
4. python test/distributed/_tensor/test_math_ops.py -k DistMathOpsTest.test_layer_norm_bwd
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130995
Approved by: https://github.com/XilunWu
ghstack dependencies: #130410
**Summary**
In order to give users more information, I have added the deviceMesh for operations with DTensor inputs, and module parameter sharding and FQN. These changes have only been placed in operation tracing log. In the future, I plan to just have one logging function with an argument to show how detailed a user wants the log to be, and will get rid of the module tracing log function. This information has also been added to the JSON dump and can be seen in the browser visual. I have also edited the test case file as the module_depth dictionary has been replaced with module_helper_dict and have edited the example output for the MLP operation tracing which can be seen below:
**Test Plan**
1. torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/comm_mode_features_example.py -e MLP_json_dump
2. torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/comm_mode_features_example.py -e transformer_json_dump
3. torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/comm_mode_features_example.py -e MLP_operation_tracing
4. torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/comm_mode_features_example.py -e transformer_operation_tracing
5. pytest test/distributed/_tensor/debug/test_comm_mode_features.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130072
Approved by: https://github.com/XilunWu
ghstack dependencies: #129994
**Summary**
Currently, comm_mode only allowed users to differentiate between forward and backward passes at the operational level. I modified the code so that users can now see the collective counts for the passes at a module level. I decided to slightly change how the output was formatted making it easier to differentiate between a collective count and an operation. I have designed the operational trace table function so that in the future, a user can use command line arguments in order to determine the level of information they want to display instead of having two similar functions. Finally, I have updated the new output and test cases for comm_mode example and test files. The expected output for the first 3 examples are shown below:
<img width="320" alt="Screenshot 2024-06-26 at 2 30 25 PM" src="https://github.com/pytorch/pytorch/assets/50644008/b8e88075-a07f-4e84-b728-a08959df3661">
<img width="497" alt="Screenshot 2024-06-26 at 2 29 15 PM" src="https://github.com/pytorch/pytorch/assets/50644008/5ef4bea7-1355-4089-bfb0-c7e3f588ac77">
<img width="615" alt="Screenshot 2024-06-26 at 2 31 05 PM" src="https://github.com/pytorch/pytorch/assets/50644008/feacae51-76f7-403b-b6cd-dd15e981770e">
**Test Plan**
1. torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/comm_mode_features_example.py -e MLP_module_tracing
2. torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/comm_mode_features_example.py -e transformer_module_tracing
3. torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/comm_mode_features_example.py -e MLP_operation_tracing
4. torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/comm_mode_features_example.py -e transformer_operation_tracing
5. pytest test/distributed/_tensor/debug/test_comm_mode_features.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129602
Approved by: https://github.com/XilunWu, https://github.com/wz337
**Summary**
Currently, there is only an example file for comm_mode and its features. I have created test cases that mirror the examples while the more complicated test cases also ensure that comm_mode resets all variables when used multiple times in the same function. This test case suite will also help developers ensure that new code they add to comm_mode does not affect correctness of old features.
#128536
**Test Plan**
pytest test/distributed/_tensor/debug/test_comm_mode_features.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128729
Approved by: https://github.com/XilunWu
**Summary**
Added all_reduce_coalesced tracing to CommDebugMode and added test case to test_comm_mode test suite.
**Test Plan**
pytest test/distributed/_tensor/debug/test_comm_mode.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127025
Approved by: https://github.com/XilunWu
This PR adds a CommDebugMode debugging tool to record the number of
distributed collectives, utilizing TorchDispatchMode, the idea borrows
from the FlopCounterMode and we can expand this later to make it more
feature complete like the FlopCounterMode
This is useful for debugging with DTensor and testing, in general this
fits for any complex distributed algorithms where it's non-trival to
understand the algorithm, we can use this tool to understand what
happened under the hood., we can later cover c10d collectives directly
Not sure if it would be a good general distributed debug tool yet,
so adding to the dtensor package first
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113592
Approved by: https://github.com/wconstab