[DTensor][Optim] Add support for fused_adam and fused_adamw when lr is a tensor (#126750)

Fixes #126670

In this PR, we update the following:
1. lr is a kwarg. Add support to automatically turn on implicit replication for kwargs; previously we only did this for args.
2. add associated tensor_lr ops in pointwises.py
3. add associated unit test in test_optimizers.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126750
Approved by: https://github.com/wanchaol, https://github.com/msaroufim
This commit is contained in:
wz337 2024-05-21 21:38:01 +00:00 committed by PyTorch MergeBot
parent 7ee74d986a
commit fed536dbcf
3 changed files with 85 additions and 38 deletions

View file

@ -93,8 +93,8 @@ class TestDTensorOptimizer(DTensorTestBase):
def test_adam_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
# TODO: add fused_adam support
adam_configs = [
# lr as a Tensor is not supported for capturable=False and foreach=True
adam_float_lr_configs = [
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05},
@ -105,6 +105,8 @@ class TestDTensorOptimizer(DTensorTestBase):
"maximize": True,
"amsgrad": True,
},
]
fused_adam_float_lr_configs = [
{"lr": 0.1, "fused": True},
{"lr": 0.1, "weight_decay": 0.05, "amsgrad": True, "fused": True},
{
@ -115,6 +117,22 @@ class TestDTensorOptimizer(DTensorTestBase):
"fused": True,
},
]
# lr could be a Tensor or a float when fused=True for adam optimizer
fused_adam_tensor_lr_configs = [
{**config, "lr": torch.tensor(0.1)}
for config in fused_adam_float_lr_configs
]
fused_adam_tensor_lr_configs.extend(
[
{**config, "lr": torch.tensor([0.1])}
for config in fused_adam_float_lr_configs
]
)
adam_configs = [
*adam_float_lr_configs,
*fused_adam_float_lr_configs,
*fused_adam_tensor_lr_configs,
]
for config in adam_configs:
mod = MLPModule(self.device_type)
@ -134,7 +152,8 @@ class TestDTensorOptimizer(DTensorTestBase):
def test_adamw_1d_sharding(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
adamw_configs = [
# lr as a Tensor is not supported for capturable=False and foreach=True
adamw_float_lr_configs = [
{"lr": 0.1, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05, "foreach": False},
{"lr": 0.1, "weight_decay": 0.05},
@ -153,6 +172,8 @@ class TestDTensorOptimizer(DTensorTestBase):
"maximize": True,
"amsgrad": True,
},
]
fused_adamw_float_lr_configs = [
{"lr": 0.1, "weight_decay": 0.05, "fused": True},
{
"lr": 0.1,
@ -172,6 +193,22 @@ class TestDTensorOptimizer(DTensorTestBase):
"fused": True,
},
]
# lr could be a Tensor or a float when fused=True for adamW optimizer
fused_adamw_tensor_lr_configs = [
{**config, "lr": torch.tensor(0.1)}
for config in fused_adamw_float_lr_configs
]
fused_adamw_tensor_lr_configs.extend(
[
{**config, "lr": torch.tensor([0.1])}
for config in fused_adamw_float_lr_configs
]
)
adamw_configs = [
*adamw_float_lr_configs,
*fused_adamw_float_lr_configs,
*fused_adamw_tensor_lr_configs,
]
for config in adamw_configs:
mod = MLPModule(self.device_type)

View file

@ -297,6 +297,41 @@ class OpDispatcher:
local_kwargs: Dict[str, object] = {}
mesh: Optional[DeviceMesh] = None
def try_get_replicate_spec(
tensor_arg: torch.Tensor, mesh: "DeviceMesh"
) -> DTensorSpec:
# tensor_arg is an instance of torch.Tensor and could be an arg or kwarg.
if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
warnings.warn(
"Found a non-scalar tensor with numel=1 and ndim!=0, "
"we are implicitly creating a replicated DTensor for it. "
"However, please consider changing it to a scalar tensor "
"or explicitly create a DTensor under distributed enviroment."
)
# if the arg.numel() == 1, arg.ndim could be 0 or 1.
if (
tensor_arg.ndim <= 1
and tensor_arg.numel() == 1
or self._allow_implicit_replication
):
# scalar tensor can be safely treated as replicated
replication_spec = DTensorSpec(
mesh,
(Replicate(),) * mesh.ndim,
tensor_meta=TensorMeta(
shape=tensor_arg.shape,
stride=tensor_arg.stride(),
dtype=tensor_arg.dtype,
),
)
else:
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
return replication_spec
for arg in args_list:
if isinstance(arg, dtensor.DTensor):
args_schema.append(arg._spec)
@ -309,37 +344,9 @@ class OpDispatcher:
else:
mesh = arg.device_mesh
elif isinstance(arg, torch.Tensor):
if arg.numel() == 1 and arg.ndim == 1:
warnings.warn(
"Found a non-scalar tensor with numel=1 and ndim!=0, "
"we are implicitly creating a replicated DTensor for it. "
"However, please consider changing it to a scalar tensor "
"or explicitly create a DTensor under distributed enviroment."
)
# if the arg.numel() == 1, arg.ndim could be 0 or 1.
if (
arg.ndim <= 1
and arg.numel() == 1
or self._allow_implicit_replication
):
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
# scalar tensor can be safely treated as replicated
args_schema.append(
DTensorSpec(
mesh,
(Replicate(),) * mesh.ndim,
tensor_meta=TensorMeta(
shape=arg.shape, stride=arg.stride(), dtype=arg.dtype
),
)
)
local_args.append(arg)
else:
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
args_schema.append(try_get_replicate_spec(arg, mesh))
local_args.append(arg)
else:
args_schema.append(arg)
local_args.append(arg)
@ -356,10 +363,9 @@ class OpDispatcher:
else:
mesh = v.device_mesh
elif isinstance(v, torch.Tensor):
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
kwargs_schema[k] = try_get_replicate_spec(v, mesh)
local_kwargs[k] = v
else:
kwargs_schema[k] = v
local_kwargs[k] = v

View file

@ -644,8 +644,12 @@ for op in for_each_linearity_ops:
fused_ops = [
aten._fused_adam_.default,
aten._fused_adam.default,
aten._fused_adam.tensor_lr,
aten._fused_adam_.tensor_lr,
aten._fused_adamw_.default,
aten._fused_adamw.default,
aten._fused_adamw.tensor_lr,
aten._fused_adamw_.tensor_lr,
]
for op in fused_ops: