pytorch/torch/_inductor
Adnan Akhundov b1b3f61f2c Skip Triton templates in MM max autotune with zero-size inputs (#106865)
Summary:

MM max autotune (and friends) crashes when one of the inputs is zero-size.

E.g., running this code:

```
@torch.compile()
def fn(x, y):
    return torch.mm(x, y)

inps = [torch.rand([0, 30]), torch.rand([30, 40])]
inps = [x.to(device="cuda") for x in inps]
out = fn(*inps)
```

with this command:

```
TORCHINDUCTOR_MAX_AUTOTUNE=1 python test.py
```

raises this error (with the top of the stack trace omitted for brevity):

```
...
  File "/data/users/aakhundov/pytorch/torch/_inductor/kernel/mm.py", line 119, in tuned_mm
    return autotune_select_algorithm("mm", choices, [mat1, mat2], layout)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/aakhundov/pytorch/torch/_inductor/select_algorithm.py", line 960, in autotune_select_algorithm
    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/aakhundov/pytorch/torch/_inductor/select_algorithm.py", line 787, in __call__
    timings = self.lookup(
              ^^^^^^^^^^^^
  File "/data/users/aakhundov/pytorch/torch/_inductor/codecache.py", line 267, in lookup
    timings[choice] = benchmark(choice)
                      ^^^^^^^^^^^^^^^^^
  File "/data/users/aakhundov/pytorch/torch/_inductor/select_algorithm.py", line 774, in autotune
    raise ErrorFromChoice(msg, choice, benchmark_fn.debug_str())
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: ErrorFromChoice: Please run `ptxas /tmp/compile-ptx-src-bfb1c6` to confirm that this is a bug in `ptxas`

From choice TritonTemplateCaller(/tmp/torchinductor_aakhundov/z7/cz7n7nn6rdlaelu4pbaaurgmu74ikl6g76lkngwawrevlfxlc6re.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4)
inputs = [
    torch.empty_strided((0, 30), (30, 1), dtype=torch.float32, device='cuda'),
    torch.empty_strided((30, 40), (40, 1), dtype=torch.float32, device='cuda'),
]
out = torch.empty_strided((0, 40), (40, 1), dtype=torch.float32, device='cuda')

  target: aten.mm.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.float32, size=[0, s0], stride=[s0, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s1], stride=[s1, 1]))
  ))
```

This PR adds a check that skips Triton templates in the `mm`, `addmm`, and `mm_plus_mm` autotuning when the product of the MM problem shape (`m * n * k`) is zero.
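The shape-product check can be sketched roughly as follows (an illustrative sketch, not the actual code in `kernel/mm.py`; the function and variable names here are assumptions):

```python
def mm_autotune_choices(m, n, k, aten_choice, triton_choices):
    # Always keep the ATen fallback as a choice. Only add the Triton
    # template choices when the MM problem is non-degenerate: if any of
    # m, n, k is zero, benchmarking the Triton templates on zero-size
    # inputs crashes, so they are skipped entirely.
    choices = [aten_choice]
    if m * n * k != 0:
        choices.extend(triton_choices)
    return choices
```

With a zero-size input, only the ATen choice survives, so autotuning degenerates to the (safe) fallback.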

Additionally, early exit conditions on `M * N * K == 0` have been added to the `mm` and `mm_plus_mm` Triton templates, to guard against the case where autotuning was done on non-zero-size inputs with dynamic shapes, but the compiled model later encounters zero-size inputs.
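The semantics of that runtime guard can be modeled in plain Python (a minimal sketch; the real change lives in the Triton template source, where the sizes are runtime kernel arguments under dynamic shapes):

```python
def should_skip_mm_kernel(M, N, K):
    # Under dynamic shapes, M, N, K are only known at run time, so the
    # compiled kernel itself must bail out on a degenerate problem:
    # if any dimension is zero, there are no output elements to compute
    # (or the reduction over K is empty), and the kernel body is skipped.
    return M * N * K == 0
```

This is why the compile-time skip alone is not enough: a kernel tuned on non-zero sizes can still be invoked with a zero-size input later.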

Test Plan:

```
$ python test/inductor/test_max_autotune.py -v

...

----------------------------------------------------------------------
Ran 16 tests in 29.569s

OK
```

Reviewers: @eellison

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106865
Approved by: https://github.com/jansel
2023-08-11 19:10:16 +00:00