mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Address lintrunner errors
This commit is contained in:
parent
53dbe12ea4
commit
45cd195f24
1 changed files with 8 additions and 4 deletions
|
|
@ -40,6 +40,7 @@ if dist.is_available():
|
|||
from torch.distributed.distributed_c10d import (
|
||||
_get_default_group,
|
||||
_rank_not_in_group,
|
||||
ProcessGroup,
|
||||
ReduceOp,
|
||||
)
|
||||
from torch.distributed.utils import (
|
||||
|
|
@ -642,6 +643,8 @@ class DistributedDataParallel(Module, Joinable):
|
|||
# used to track whether the given thread is inside ddp forward for torchdynamo purposes
|
||||
_active_ddp_module: Optional["DistributedDataParallel"] = None
|
||||
|
||||
reducer: dist.Reducer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: Module,
|
||||
|
|
@ -650,7 +653,7 @@ class DistributedDataParallel(Module, Joinable):
|
|||
dim: int = 0,
|
||||
broadcast_buffers: bool = True,
|
||||
init_sync=True,
|
||||
process_group: Optional[Any] = None,
|
||||
process_group: "Optional[ProcessGroup]" = None,
|
||||
bucket_cap_mb: float = 25,
|
||||
find_unused_parameters: bool = False,
|
||||
check_reduction: bool = False,
|
||||
|
|
@ -680,7 +683,7 @@ class DistributedDataParallel(Module, Joinable):
|
|||
elif process_group is None and device_mesh is None:
|
||||
self.process_group = _get_default_group()
|
||||
elif device_mesh is None:
|
||||
self.process_group = process_group
|
||||
self.process_group = process_group # type: ignore[assignment]
|
||||
else:
|
||||
if device_mesh.ndim != 1:
|
||||
raise RuntimeError(
|
||||
|
|
@ -814,7 +817,7 @@ class DistributedDataParallel(Module, Joinable):
|
|||
if bucket_cap_mb is None:
|
||||
# default case (bucket cap is 25 MiB)
|
||||
bucket_cap_mb = 25
|
||||
self.bucket_bytes_cap_default = True
|
||||
self.bucket_bytes_cap_default: bool = True
|
||||
else:
|
||||
self.bucket_bytes_cap_default = False
|
||||
self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
|
||||
|
|
@ -926,7 +929,7 @@ class DistributedDataParallel(Module, Joinable):
|
|||
)
|
||||
if self._use_python_reducer:
|
||||
torch._inductor.config._fuse_ddp_communication = True
|
||||
torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
|
||||
torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb # type: ignore[assignment]
|
||||
# Directly adding this to the trace rule will disturb the users
|
||||
# who are using DDPOptimizer.
|
||||
torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add(
|
||||
|
|
@ -1095,6 +1098,7 @@ class DistributedDataParallel(Module, Joinable):
|
|||
# Do not cast DDP ignored parameters.
|
||||
if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
|
||||
continue
|
||||
assert hasattr(param, "_mp_param")
|
||||
_alloc_storage(param._mp_param, param.size())
|
||||
# copy() implicitly casts to low precision
|
||||
with torch.no_grad():
|
||||
|
|
|
|||
Loading…
Reference in a new issue