mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Make _MeshEnv subclass threading.local so its state is thread-local (#124555)
Because `_mesh_resources` is a global variable, when thread-based process-group testing is used (i.e. `spawn_threads_and_init_comms()`), the last rank registering a given key overwrites the entries of the earlier ranks that used the same key. This isn't an issue in the regular process-based runtime, as logically each key is unique there. Example failure: https://github.com/pytorch/pytorch/actions/runs/8779134353/job/24087295785 ``` RuntimeError: Could not resolve the process group registered under the name 8 ``` or an "assert not None" error being thrown. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124555 Approved by: https://github.com/xunnanxu, https://github.com/wanchaol
This commit is contained in:
parent
e913f77c60
commit
43f4e71daa
1 changed file with 2 additions and 1 deletion
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -57,7 +58,7 @@ else:
|
|||
"DeviceMesh requires numpy >= 1.21 to be installed for type checking"
|
||||
)
|
||||
|
||||
class _MeshEnv:
|
||||
class _MeshEnv(threading.local):
|
||||
def __init__(self) -> None:
|
||||
self.mesh_stack: List[DeviceMesh] = []
|
||||
self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
|
||||
|
|
|
|||
Loading…
Reference in a new issue