Making _MeshEnv subclassing thread local (#124555)

Because _mesh_resources is a global variable, thread-based process group testing (i.e. spawn_threads_and_init_comms()) makes every rank share it, so the last rank to register under a given key overwrites the earlier ones. This isn't an issue in the regular process-based runtime, where each key is logically unique.
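
The overwrite can be reproduced in miniature without torch. The following is only a sketch: `SharedEnv`, `shared_env`, `register`, and the `"mesh"` key are hypothetical stand-ins, not names from the PyTorch code base. Each thread plays the role of one rank and registers itself under the same key in a single module-level object.

```python
import threading


class SharedEnv:
    """Hypothetical stand-in for state kept in a module-level global."""

    def __init__(self) -> None:
        self.registry = {}


# One instance, visible to every thread in the process.
shared_env = SharedEnv()


def register(rank: int) -> None:
    # Every simulated rank writes under the same key.
    shared_env.registry["mesh"] = rank


threads = [threading.Thread(target=register, args=(r,)) for r in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Only one entry survives; which rank wrote it depends on thread scheduling.
surviving = shared_env.registry
```

With separate processes, each process would hold its own copy of the global, so the same key would never collide.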

Example failure: https://github.com/pytorch/pytorch/actions/runs/8779134353/job/24087295785
```
RuntimeError: Could not resolve the process group registered under the name 8
or
Throwing assert not none error
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124555
Approved by: https://github.com/xunnanxu, https://github.com/wanchaol
Iris Zhang (PyTorch) 2024-04-26 02:45:42 +00:00 committed by PyTorch MergeBot
parent e913f77c60
commit 43f4e71daa


@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import logging
 import math
+import threading
 from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union

 import torch
@@ -57,7 +58,7 @@ else:
                 "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
             )

-    class _MeshEnv:
+    class _MeshEnv(threading.local):
         def __init__(self) -> None:
             self.mesh_stack: List[DeviceMesh] = []
             self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
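
The effect of deriving from threading.local can be illustrated with a small standalone sketch. The names `ThreadLocalEnv`, `env`, `worker`, and `seen` are hypothetical and exist only for this illustration; the attribute layout loosely mirrors the `__init__` shown in the diff but uses plain ints instead of DeviceMesh objects. For a subclass of `threading.local`, Python runs `__init__` again the first time each new thread touches the instance, so every thread gets its own fresh attributes.

```python
import threading
from typing import Dict, List


class ThreadLocalEnv(threading.local):
    """Hypothetical analogue of a per-thread environment object."""

    def __init__(self) -> None:
        # Runs once per thread, on that thread's first use of the instance.
        self.stack: List[int] = []
        self.mapping: Dict[str, int] = {}


# A single module-level instance, as with a global registry.
env = ThreadLocalEnv()

# Plain dict used only to collect what each thread observed.
seen: Dict[int, Dict[str, int]] = {}


def worker(rank: int) -> None:
    # Each simulated rank writes under the same key, but into its own copy.
    env.mapping["mesh"] = rank
    env.stack.append(rank)
    seen[rank] = dict(env.mapping)


threads = [threading.Thread(target=worker, args=(r,)) for r in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The main thread's view is untouched by the workers.
main_view = dict(env.mapping)
```

Each worker sees only its own entry, and the main thread's mapping stays empty, which is the isolation that thread-based tests need when several ranks share one process.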