From 43f4e71daa6dce6014d30da046d28f14cf30d5a4 Mon Sep 17 00:00:00 2001 From: "Iris Zhang (PyTorch)" Date: Fri, 26 Apr 2024 02:45:42 +0000 Subject: [PATCH] Making _MeshEnv subclassing thread local (#124555) With _mesh_resources being a global variable, when thread-pg-based testing is used (i.e., spawn_threads_and_init_comms()), the last rank with the same key would overwrite the earlier ones. This isn't an issue in the regular process-based runtime, since each key is logically unique there. Example failure: https://github.com/pytorch/pytorch/actions/runs/8779134353/job/24087295785 ``` RuntimeError: Could not resolve the process group registered under the name 8 or Throwing assert not none error ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/124555 Approved by: https://github.com/xunnanxu, https://github.com/wanchaol --- torch/distributed/device_mesh.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 79e5c508a68..2b176583dea 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -1,6 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import logging import math +import threading from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch @@ -57,7 +58,7 @@ else: "DeviceMesh requires numpy >= 1.21 to be installed for type checking" ) - class _MeshEnv: + class _MeshEnv(threading.local): def __init__(self) -> None: self.mesh_stack: List[DeviceMesh] = [] self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}