From e06bff8bbea84ce672285bc34690b2e45a1b63ab Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 4 Dec 2023 19:56:18 -0800 Subject: [PATCH] [AOTI] Handle empty input args (#114682) Summary: When the model takes no inputs, AOTInductor relies on checking weights to figure out which device to compile the model into. Currently recording buffer device type happens too late, and this PR fixes that. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114682 Approved by: https://github.com/chenyang78 --- test/inductor/test_aot_inductor.py | 14 ++++++++++++++ torch/_inductor/graph.py | 19 +++++++++---------- torch/_inductor/ir.py | 6 ++---- torch/_inductor/scheduler.py | 3 +-- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 2f296af308c..93aaac3076b 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1465,6 +1465,20 @@ class AOTInductorTestsTemplate: inputs = (torch.rand(4, 4, 4, 4, device=self.device),) self.check_model(Model(4), inputs) + def test_no_args(self): + class Model(torch.nn.Module): + def __init__(self, m, n): + super().__init__() + self.weight = torch.nn.Parameter( + torch.randn(m, n), + ) + self.alpha = torch.nn.Parameter(torch.randn(m, n)) + + def forward(self): + return self.weight * self.alpha + + self.check_model(Model(6, 4), ()) + common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index e4840b7777f..2e2e2d19c22 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -472,9 +472,10 @@ class GraphLowering(torch.fx.Interpreter): self._warned_fallback.add(name) perf_hint_log.info("Using FallbackKernel: %s", name) - def add_device_idx(self, idx: Optional[int]): - if idx is not None: - self.device_idxs.add(idx) + def add_device_info(self, device: torch.device): + self.device_types.add(device.type) + if device.index is not None: + self.device_idxs.add(device.index) @property def fake_mode(self): @@ -521,6 +522,9 @@ class GraphLowering(torch.fx.Interpreter): name = f"buf{len(self.buffers)}" self.buffers.append(buffer) self.name_to_buffer[name] = buffer + # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144 + if not isinstance(buffer, ir.ComputedBuffer) or not buffer.is_zero_elements(): + self.add_device_info(buffer.get_device()) return name def register_list(self, buffer_names: List[str]): @@ -645,8 +649,7 @@ class GraphLowering(torch.fx.Interpreter): ) self.graph_inputs[target] = tensor self.graph_inputs_original[target] = tensor.data.data - self.device_types.add(example.device.type) - self.add_device_idx(example.device.index) + self.add_device_info(example.device) return tensor def call_function(self, target, args, kwargs): @@ -979,10 +982,6 @@ class GraphLowering(torch.fx.Interpreter): return device_types = self.device_types.copy() - # In terms of some operations that don't have input tensors, we need to - # check the device of the buffers. - for buffer in self.buffers: - device_types.add(buffer.get_device().type) device_types.discard("cpu") # TODO(Eikan): Only support mixing cpu and other device now. assert len(device_types) <= 1, "Does not support mixing {}".format( @@ -1015,7 +1014,7 @@ class GraphLowering(torch.fx.Interpreter): else: assert isinstance( x, torch.Tensor - ), "Unknown type when creating real inputs" + ), "Unknown type when creating real inputs" + str(type(x)) return x with torch.utils._python_dispatch._disable_current_modes(): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 5ad4635af2e..d454f550860 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4160,10 +4160,8 @@ class DeviceCopy(ExternKernelOut): ): return x.constant_to_device(device) - V.graph.device_types.add(device.type) - V.graph.add_device_idx(device.index) - V.graph.device_types.add(x.get_device().type) - V.graph.add_device_idx(x.get_device().index) + V.graph.add_device_info(device) + V.graph.add_device_info(x.get_device()) developer_warning("DeviceCopy in input program") return DeviceCopy( diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index e668ad4082e..2d78e3e113c 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2130,8 +2130,7 @@ class Scheduler: assert ( device.type != "cuda" or device.index is not None ), f"{device} should have been normalized in lowering" - V.graph.device_types.add(device.type) - V.graph.add_device_idx(device.index) + V.graph.add_device_info(device) device_scheduling = get_scheduling_for_device(device.type) if device_scheduling is None: