Address lintrunner errors

Mauricio Villegas 2024-09-30 07:40:30 +02:00
parent 53dbe12ea4
commit 45cd195f24

@@ -40,6 +40,7 @@ if dist.is_available():
     from torch.distributed.distributed_c10d import (
         _get_default_group,
         _rank_not_in_group,
+        ProcessGroup,
         ReduceOp,
     )
     from torch.distributed.utils import (
@@ -642,6 +643,8 @@ class DistributedDataParallel(Module, Joinable):
     # used to track whether the given thread is inside ddp forward for torchdynamo purposes
     _active_ddp_module: Optional["DistributedDataParallel"] = None

+    reducer: dist.Reducer
+
     def __init__(
         self,
         module: Module,
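
Side note on the new bare annotation above: a class-level annotation without an assignment only tells the type checker that the attribute will exist; the value is still bound at runtime (here by DDP's init code). A minimal sketch of the pattern, using an invented Worker class:

# Illustrative sketch (not from this commit): a bare class-level annotation
# declares an attribute for the type checker; the value is assigned later.
class Worker:
    reducer: str  # declared only; no value bound at class definition time

    def setup(self) -> None:
        self.reducer = "ready"

w = Worker()
w.setup()
print(w.reducer)  # -> ready
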
@@ -650,7 +653,7 @@ class DistributedDataParallel(Module, Joinable):
         dim: int = 0,
         broadcast_buffers: bool = True,
         init_sync=True,
-        process_group: Optional[Any] = None,
+        process_group: "Optional[ProcessGroup]" = None,
         bucket_cap_mb: float = 25,
         find_unused_parameters: bool = False,
         check_reduction: bool = False,
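
As a rough illustration of what the tightened Optional[ProcessGroup] annotation describes, the sketch below wraps a module in DDP with an explicitly created process group. It assumes a single-process gloo setup (rank 0, world size 1) purely so the snippet can run standalone; real jobs supply their own backend, rank, and world size.

# Sketch only: pass an explicit ProcessGroup to DistributedDataParallel.
# A single-process "gloo" setup is assumed just to keep the example self-contained.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

pg = dist.new_group(ranks=[0])  # an explicit ProcessGroup instance
ddp = DistributedDataParallel(torch.nn.Linear(4, 4), process_group=pg)

dist.destroy_process_group()
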
@@ -680,7 +683,7 @@ class DistributedDataParallel(Module, Joinable):
         elif process_group is None and device_mesh is None:
             self.process_group = _get_default_group()
         elif device_mesh is None:
-            self.process_group = process_group
+            self.process_group = process_group  # type: ignore[assignment]
         else:
             if device_mesh.ndim != 1:
                 raise RuntimeError(
@@ -814,7 +817,7 @@ class DistributedDataParallel(Module, Joinable):
         if bucket_cap_mb is None:
             # default case (bucket cap is 25 MiB)
             bucket_cap_mb = 25
-            self.bucket_bytes_cap_default = True
+            self.bucket_bytes_cap_default: bool = True
         else:
             self.bucket_bytes_cap_default = False
         self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
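
The unchanged last line above converts the cap from MiB to bytes; a quick sanity check of that arithmetic for the 25 MiB default:

# 25 MiB in bytes, matching int(bucket_cap_mb * 1024 * 1024) above.
bucket_cap_mb = 25
print(int(bucket_cap_mb * 1024 * 1024))  # -> 26214400
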
@@ -926,7 +929,7 @@ class DistributedDataParallel(Module, Joinable):
         )
         if self._use_python_reducer:
             torch._inductor.config._fuse_ddp_communication = True
-            torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
+            torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb  # type: ignore[assignment]
             # Directly adding this to the trace rule will disturb the users
             # who are using DDPOptimizer.
             torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST.add(
@@ -1095,6 +1098,7 @@ class DistributedDataParallel(Module, Joinable):
                 # Do not cast DDP ignored parameters.
                 if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
                     continue
+                assert hasattr(param, "_mp_param")
                 _alloc_storage(param._mp_param, param.size())
                 # copy() implicitly casts to low precision
                 with torch.no_grad():
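
The added assert guards a dynamically attached attribute (_mp_param) before it is touched, failing loudly if mixed-precision setup did not attach it, and presumably also satisfies the type checker for the access that follows. A generic sketch of the same pattern, with an invented Holder class and cache attribute:

# Generic sketch (not from this commit): assert that a dynamically attached
# attribute exists before using it, so a missing attribute fails loudly.
from typing import Any

def warm_up(obj: Any) -> None:
    assert hasattr(obj, "cache"), "expected a pre-attached cache attribute"
    obj.cache.clear()

class Holder:
    def __init__(self) -> None:
        self.cache: dict[str, int] = {"warm": 1}

warm_up(Holder())
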