# Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] import torch from torch.distributed.pipelining.microbatch import ( merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec, ) from torch.testing._internal.common_utils import run_tests, TestCase d_hid = 512 class MicrobatchTests(TestCase): def test_split_and_merge(self): x0 = torch.randn(128, d_hid) x1 = torch.randn(256, d_hid) x2 = torch.randn(512, d_hid) args = (x0, x1, x2) kwargs = {"x0": x0, "x1": x1, "x2": x2} # Default chunking: dim 0 arg_chunks, kwarg_chunks = split_args_kwargs_into_chunks(args, kwargs, 2) assert len(arg_chunks) == 2 assert len(kwarg_chunks) == 2 assert arg_chunks[0][0].shape == torch.Size([64, d_hid]) assert arg_chunks[1][0].shape == torch.Size([64, d_hid]) assert arg_chunks[0][1].shape == torch.Size([128, d_hid]) assert arg_chunks[0][2].shape == torch.Size([256, d_hid]) assert kwarg_chunks[0]["x0"].shape == torch.Size([64, d_hid]) assert kwarg_chunks[0]["x1"].shape == torch.Size([128, d_hid]) assert kwarg_chunks[1]["x2"].shape == torch.Size([256, d_hid]) # Merge chunks back together merged_args = merge_chunks( arg_chunks, (TensorChunkSpec(0), TensorChunkSpec(0), TensorChunkSpec(0)), ) torch.testing.assert_close(merged_args, args) merged_kwargs = merge_chunks( kwarg_chunks, { "x0": TensorChunkSpec(0), "x1": TensorChunkSpec(0), "x2": TensorChunkSpec(0), }, ) torch.testing.assert_close(merged_kwargs, kwargs) print("Microbatch test passed") if __name__ == "__main__": run_tests()