pytorch/test/distributed/_tensor/test_optimizers.py
wz337 bb67a28738 [DTensor] Enable Adamax foreach optimizer (#119850)
Enable Adamax foreach optimizer and add DTensor unit test for Adamax.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119850
Approved by: https://github.com/wanchaol
2024-02-14 20:43:00 +00:00

472 lines
15 KiB
Python

# Owner(s): ["oncall: distributed"]
from copy import deepcopy
import torch
import torch.nn as nn
from torch.distributed._tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
MLPModule,
with_comms,
)
# Shard function: fully shard every nn.Linear parameter along dim 0.
def shard_fn(name, module, device_mesh):
    if not isinstance(module, nn.Linear):
        return
    for param_name, param in module.named_parameters():
        sharded = distribute_tensor(param, device_mesh, [Shard(0)])
        dist_param = torch.nn.Parameter(sharded)
        # Re-shard the gradient so any partial sums produced during
        # backward() are cleared before the optimizer step.
        dist_param.register_hook(
            lambda grad: grad.redistribute(placements=[Shard(0)])
        )
        module.register_parameter(param_name, dist_param)
# Prepare input: shard the first input tensor along dim 0 across the mesh.
def input_fn(inputs, device_mesh):
    first_input, *_ = inputs
    return distribute_tensor(first_input, device_mesh, [Shard(0)])
# Prepare output: replicate the DTensor result and return it as a plain
# local torch.Tensor.
def output_fn(outputs, device_mesh):
    assert isinstance(outputs, DTensor)
    placements = [Replicate()] * device_mesh.ndim
    replicated = outputs.redistribute(placements=placements)
    return replicated.to_local()
class TestDTensorOptimizer(DTensorTestBase):
    """Check that optimizers stepping DTensor-sharded parameters produce the
    same numerics as the identical optimizer stepping a plain local model."""

    def _assert_optimizer(
        self,
        mesh,
        model,
        optim,
        dist_model,
        dist_optim,
        inputs,
    ):
        """Run two forward/backward/step iterations on both models and assert
        their parameters stay equal.

        ``mesh`` is accepted for call-site symmetry but is not read here.
        Iteration 0 calls ``zero_grad(set_to_none=True)`` and iteration 1
        ``set_to_none=False`` so both zero_grad code paths are exercised.
        """
        for iter_idx in range(2):
            # run forward/backward/optim for original model
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            out = model(inputs)
            loss = out.sum()
            loss.backward()
            optim.step()

            # run forward/backward/optim for distributed model
            dist_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            dist_out = dist_model(inputs)
            dist_loss = dist_out.sum()
            dist_loss.backward()
            dist_optim.step()

            # check that the optimizer update parameters with same numerics
            for p1, p2 in zip(model.parameters(), dist_model.parameters()):
                p2 = p2.full_tensor()
                self.assertEqual(p1, p2)

    def _run_optimizer_configs(self, optim_cls, configs):
        """Shared driver for all per-optimizer tests.

        For each hyper-parameter ``config``, build a plain MLP and a fully
        1-D-sharded copy, create one ``optim_cls`` instance per model with the
        same config, and assert the two stay numerically identical. Factoring
        this out removes the near-verbatim loop previously duplicated in
        every ``test_*_1d_sharding`` method.
        """
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        for config in configs:
            mod = MLPModule(self.device_type)
            opt = optim_cls(mod.parameters(), **config)
            dist_mod = distribute_module(
                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
            )
            dist_opt = optim_cls(dist_mod.parameters(), **config)
            # use ones to make sure the single machine model have the same
            # input on different ranks
            inp = torch.ones(8, 10, device=self.device_type)
            self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp)

    @with_comms
    def test_adam_1d_sharding(self):
        # TODO: add fused_adam support
        adam_configs = [
            {"lr": 0.1},
            {"lr": 0.1, "weight_decay": 0.05},
            {"lr": 0.1, "foreach": True},
            {"lr": 0.1, "weight_decay": 0.05, "foreach": True},
            {"lr": 0.1, "weight_decay": 0.05, "amsgrad": True, "foreach": True},
            {
                "lr": 0.1,
                "weight_decay": 0.05,
                "maximize": True,
                "amsgrad": True,
                "foreach": True,
            },
        ]
        self._run_optimizer_configs(torch.optim.Adam, adam_configs)

    @with_comms
    def test_adamw_1d_sharding(self):
        # TODO: add fused_adamw support
        adamw_configs = [
            {"lr": 0.1},
            {"lr": 0.1, "weight_decay": 0.05},
            {"lr": 0.1, "weight_decay": 0.05, "foreach": True},
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "amsgrad": True,
                "foreach": True,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "maximize": True,
                "amsgrad": True,
                "foreach": True,
            },
        ]
        self._run_optimizer_configs(torch.optim.AdamW, adamw_configs)

    @with_comms
    def test_sgd_1d_sharding(self):
        sgd_configs = [
            {"lr": 0.1},
            {"lr": 0.1, "momentum": 0.05},
            {"lr": 0.1, "momentum": 0.05, "foreach": True},
            {"lr": 0.1, "momentum": 0.06, "dampening": 0.07, "foreach": True},
            {
                "lr": 0.1,
                "momentum": 0.08,
                "weight_decay": 0.05,
                "nesterov": True,
                "maximize": True,
            },
            {
                "lr": 0.1,
                "momentum": 0.08,
                "weight_decay": 0.05,
                "nesterov": True,
                "maximize": True,
                "foreach": True,
            },
        ]
        self._run_optimizer_configs(torch.optim.SGD, sgd_configs)

    @with_comms
    def test_adagrad_1d_sharding(self):
        adagrad_configs = [
            {"lr": 0.1},
            {"lr": 0.1, "lr_decay": 0.05},
            {"lr": 0.1, "lr_decay": 0.02, "weight_decay": 0.05},
            {
                "lr": 0.1,
                "lr_decay": 0.02,
                "weight_decay": 0.05,
                "initial_accumulator_value": 0.03,
            },
            {
                "lr": 0.1,
                "lr_decay": 0.02,
                "weight_decay": 0.05,
                "initial_accumulator_value": 0.03,
                "eps": 1e-6,
            },
            {
                "lr": 0.1,
                "lr_decay": 0.02,
                "weight_decay": 0.05,
                "initial_accumulator_value": 0.03,
                "eps": 1e-6,
                "maximize": True,
            },
            {
                "lr": 0.1,
                "lr_decay": 0.02,
                "weight_decay": 0.05,
                "initial_accumulator_value": 0.03,
                "eps": 1e-6,
                "maximize": True,
                "foreach": True,
            },
        ]
        self._run_optimizer_configs(torch.optim.Adagrad, adagrad_configs)

    @with_comms
    def test_RMSprop_1d_sharding(self):
        RMSprop_configs = [
            {"lr": 0.1},
            {"lr": 0.1, "alpha": 0.85},
            {"lr": 0.1, "alpha": 0.88, "eps": 1e-6},
            {"lr": 0.1, "alpha": 0.88, "eps": 1e-6, "weight_decay": 0.05},
            {
                "lr": 0.1,
                "alpha": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "momentum": 0.9,
            },
            {
                "lr": 0.1,
                "alpha": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "momentum": 0.9,
                "centered": True,
            },
            {
                "lr": 0.1,
                "alpha": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "momentum": 0.9,
                "centered": True,
                "maximize": True,
            },
            {
                "lr": 0.1,
                "alpha": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "momentum": 0.9,
                "centered": True,
                "maximize": True,
                "foreach": True,
            },
        ]
        self._run_optimizer_configs(torch.optim.RMSprop, RMSprop_configs)

    @with_comms
    def test_adadelta_1d_sharding(self):
        adadelta_configs = [
            {"lr": 0.1},
            {"lr": 0.1, "rho": 0.85},
            {"lr": 0.1, "rho": 0.88, "eps": 1e-5},
            {"lr": 0.1, "rho": 0.88, "eps": 1e-6, "weight_decay": 0.05},
            {
                "lr": 0.1,
                "rho": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "foreach": True,
            },
            {
                "lr": 0.1,
                "rho": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "foreach": True,
                "maximize": True,
            },
        ]
        self._run_optimizer_configs(torch.optim.Adadelta, adadelta_configs)

    @with_comms
    def test_nadam_1d_sharding(self):
        nadam_configs = [
            {"lr": 0.1},
            {"lr": 0.1, "weight_decay": 0.05},
            {"lr": 0.1, "weight_decay": 0.05, "foreach": True},
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "foreach": True,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "decoupled_weight_decay": True,
                "foreach": True,
            },
        ]
        self._run_optimizer_configs(torch.optim.NAdam, nadam_configs)

    @with_comms
    def test_radam_1d_sharding(self):
        radam_configs = [
            {"lr": 0.1},
            {"lr": 0.1, "weight_decay": 0.05},
            {"lr": 0.1, "weight_decay": 0.05, "foreach": True},
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "foreach": True,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "decoupled_weight_decay": True,
                "foreach": True,
            },
        ]
        self._run_optimizer_configs(torch.optim.RAdam, radam_configs)

    @with_comms
    def test_adamax_1d_sharding(self):
        adamax_configs = [
            {"lr": 0.1},
            {"lr": 0.1, "betas": (0.6, 0.66)},
            {"lr": 0.1, "betas": (0.6, 0.66), "eps": 1e-6},
            {"lr": 0.1, "betas": (0.6, 0.66), "eps": 1e-6, "weight_decay": 0.05},
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "foreach": True,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "foreach": True,
                "maximize": True,
            },
        ]
        self._run_optimizer_configs(torch.optim.Adamax, adamax_configs)
# Spawn the distributed test processes only when run as a script.
if __name__ == "__main__":
    run_tests()