Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-15 21:00:47 +00:00
[DTensor][Optim] Add support for fused_adam and fused_adamw when lr is a tensor (#126750)
Fixes #126670

In this PR, we update the following:
1. lr is a kwarg. Add support to automatically turn on implicit replication for kwargs; previously we only did this for args.
2. Add the associated tensor_lr ops in pointwise_ops.py.
3. Add the associated unit tests in test_optimizers.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126750
Approved by: https://github.com/wanchaol, https://github.com/msaroufim
parent 7ee74d986a
commit fed536dbcf
3 changed files with 85 additions and 38 deletions
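As a rough sketch of what this change enables (not code from the PR; the mesh setup below is illustrative, world_size stands in for the number of ranks, and torch.distributed is assumed to be initialized), a fused Adam optimizer can now take lr as a tensor even when the parameters are DTensors:

import torch
from torch.distributed._tensor import DeviceMesh, distribute_module

# Assumes an initialized process group; world_size is a placeholder for the rank count.
mesh = DeviceMesh("cuda", list(range(world_size)))

# Any nn.Module works; distribute_module with no partition_fn replicates the parameters.
model = distribute_module(torch.nn.Linear(16, 16), mesh)

# lr may now be a 0-dim (or 1-element) torch.Tensor when fused=True; the DTensor
# dispatcher implicitly replicates this plain-tensor kwarg across the mesh instead
# of raising the mixed torch.Tensor / DTensor error.
optim = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.1), fused=True)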
@@ -93,8 +93,8 @@ class TestDTensorOptimizer(DTensorTestBase):
    def test_adam_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        # TODO: add fused_adam support
        adam_configs = [
        # lr as a Tensor is not supported for capturable=False and foreach=True
        adam_float_lr_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05},
@@ -105,6 +105,8 @@ class TestDTensorOptimizer(DTensorTestBase):
                "maximize": True,
                "amsgrad": True,
            },
        ]
        fused_adam_float_lr_configs = [
            {"lr": 0.1, "fused": True},
            {"lr": 0.1, "weight_decay": 0.05, "amsgrad": True, "fused": True},
            {
@@ -115,6 +117,22 @@ class TestDTensorOptimizer(DTensorTestBase):
                "fused": True,
            },
        ]
        # lr could be a Tensor or a float when fused=True for adam optimizer
        fused_adam_tensor_lr_configs = [
            {**config, "lr": torch.tensor(0.1)}
            for config in fused_adam_float_lr_configs
        ]
        fused_adam_tensor_lr_configs.extend(
            [
                {**config, "lr": torch.tensor([0.1])}
                for config in fused_adam_float_lr_configs
            ]
        )
        adam_configs = [
            *adam_float_lr_configs,
            *fused_adam_float_lr_configs,
            *fused_adam_tensor_lr_configs,
        ]

        for config in adam_configs:
            mod = MLPModule(self.device_type)
@@ -134,7 +152,8 @@ class TestDTensorOptimizer(DTensorTestBase):
    def test_adamw_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        adamw_configs = [
        # lr as a Tensor is not supported for capturable=False and foreach=True
        adamw_float_lr_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05},
@@ -153,6 +172,8 @@ class TestDTensorOptimizer(DTensorTestBase):
                "maximize": True,
                "amsgrad": True,
            },
        ]
        fused_adamw_float_lr_configs = [
            {"lr": 0.1, "weight_decay": 0.05, "fused": True},
            {
                "lr": 0.1,
@@ -172,6 +193,22 @@ class TestDTensorOptimizer(DTensorTestBase):
                "fused": True,
            },
        ]
        # lr could be a Tensor or a float when fused=True for adamW optimizer
        fused_adamw_tensor_lr_configs = [
            {**config, "lr": torch.tensor(0.1)}
            for config in fused_adamw_float_lr_configs
        ]
        fused_adamw_tensor_lr_configs.extend(
            [
                {**config, "lr": torch.tensor([0.1])}
                for config in fused_adamw_float_lr_configs
            ]
        )
        adamw_configs = [
            *adamw_float_lr_configs,
            *fused_adamw_float_lr_configs,
            *fused_adamw_tensor_lr_configs,
        ]

        for config in adamw_configs:
            mod = MLPModule(self.device_type)
@@ -297,6 +297,41 @@ class OpDispatcher:
        local_kwargs: Dict[str, object] = {}
        mesh: Optional[DeviceMesh] = None

        def try_get_replicate_spec(
            tensor_arg: torch.Tensor, mesh: "DeviceMesh"
        ) -> DTensorSpec:
            # tensor_arg is an instance of torch.Tensor and could be an arg or kwarg.
            if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
                warnings.warn(
                    "Found a non-scalar tensor with numel=1 and ndim!=0, "
                    "we are implicitly creating a replicated DTensor for it. "
                    "However, please consider changing it to a scalar tensor "
                    "or explicitly create a DTensor under distributed enviroment."
                )

            # if the arg.numel() == 1, arg.ndim could be 0 or 1.
            if (
                tensor_arg.ndim <= 1
                and tensor_arg.numel() == 1
                or self._allow_implicit_replication
            ):
                # scalar tensor can be safely treated as replicated
                replication_spec = DTensorSpec(
                    mesh,
                    (Replicate(),) * mesh.ndim,
                    tensor_meta=TensorMeta(
                        shape=tensor_arg.shape,
                        stride=tensor_arg.stride(),
                        dtype=tensor_arg.dtype,
                    ),
                )
            else:
                raise RuntimeError(
                    f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                    " torch.Tensor to DTensor before calling distributed operators!"
                )
            return replication_spec

        for arg in args_list:
            if isinstance(arg, dtensor.DTensor):
                args_schema.append(arg._spec)
@@ -309,37 +344,9 @@ class OpDispatcher:
                else:
                    mesh = arg.device_mesh
            elif isinstance(arg, torch.Tensor):
                if arg.numel() == 1 and arg.ndim == 1:
                    warnings.warn(
                        "Found a non-scalar tensor with numel=1 and ndim!=0, "
                        "we are implicitly creating a replicated DTensor for it. "
                        "However, please consider changing it to a scalar tensor "
                        "or explicitly create a DTensor under distributed enviroment."
                    )

                # if the arg.numel() == 1, arg.ndim could be 0 or 1.
                if (
                    arg.ndim <= 1
                    and arg.numel() == 1
                    or self._allow_implicit_replication
                ):
                    mesh = mesh or try_find_mesh_from_args(op_call, args_list)
                    # scalar tensor can be safely treated as replicated
                    args_schema.append(
                        DTensorSpec(
                            mesh,
                            (Replicate(),) * mesh.ndim,
                            tensor_meta=TensorMeta(
                                shape=arg.shape, stride=arg.stride(), dtype=arg.dtype
                            ),
                        )
                    )
                    local_args.append(arg)
                else:
                    raise RuntimeError(
                        f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                        " torch.Tensor to DTensor before calling distributed operators!"
                    )
                mesh = mesh or try_find_mesh_from_args(op_call, args_list)
                args_schema.append(try_get_replicate_spec(arg, mesh))
                local_args.append(arg)
            else:
                args_schema.append(arg)
                local_args.append(arg)
@@ -356,10 +363,9 @@ class OpDispatcher:
                else:
                    mesh = v.device_mesh
            elif isinstance(v, torch.Tensor):
                raise RuntimeError(
                    f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                    " torch.Tensor to DTensor before calling distributed operators!"
                )
                mesh = mesh or try_find_mesh_from_args(op_call, args_list)
                kwargs_schema[k] = try_get_replicate_spec(v, mesh)
                local_kwargs[k] = v
            else:
                kwargs_schema[k] = v
                local_kwargs[k] = v
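The dispatcher hunks above factor the existing scalar-tensor handling for positional args into try_get_replicate_spec and reuse it for tensor kwargs, which is what lets a plain-tensor lr kwarg through. A minimal sketch of the user-visible rule (assuming an initialized distributed environment and a 1-D mesh; world_size is a placeholder):

import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

mesh = DeviceMesh("cuda", list(range(world_size)))            # placeholder world_size
dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])   # a sharded DTensor

scale = torch.tensor(2.0)  # plain 0-dim torch.Tensor, not a DTensor
out = dt * scale           # implicitly treated as Replicate() across the mesh

# A non-scalar plain tensor still raises the error quoted in the hunk above:
# dt * torch.randn(8, 8)   # "got mixed torch.Tensor and DTensor, ..."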
@@ -644,8 +644,12 @@ for op in for_each_linearity_ops:
fused_ops = [
    aten._fused_adam_.default,
    aten._fused_adam.default,
    aten._fused_adam.tensor_lr,
    aten._fused_adam_.tensor_lr,
    aten._fused_adamw_.default,
    aten._fused_adamw.default,
    aten._fused_adamw.tensor_lr,
    aten._fused_adamw_.tensor_lr,
]

for op in fused_ops:
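The fused Adam(W) ATen ops each come in a default overload (lr as a float) and a tensor_lr overload (lr as a Tensor); adding the tensor_lr variants to fused_ops gives them the same pointwise sharding strategy as the default ones. A quick, illustrative way to inspect the overloads:

import torch

aten = torch.ops.aten
# Each fused op packet exposes both overloads; the call resolves to .tensor_lr
# when lr is passed as a torch.Tensor and to .default when it is a float.
print(aten._fused_adam_.overloads())    # e.g. ['default', 'tensor_lr']
print(aten._fused_adamw_.tensor_lr)     # the OpOverload for the Tensor-lr variant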