pytorch/test/inductor/test_combo_kernels.py
Aaron Orenstein 8c356ce3da Fix lint errors in fbcode (#135614)
Summary: Fixed a bunch of fbcode imports that happened to work but confused autodeps.  After this autodeps still suggests "improvements" to TARGETS (which breaks our builds) but at least it can find all the imports.

Test Plan:
```
fbpython fbcode/tools/build/buck/linters/lint_autoformat.py --linter=autodeps --default-exec-timeout=1800 -- fbcode/caffe2/TARGETS fbcode/caffe2/test/TARGETS
```
Before:
```
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/testing.py:229) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fbur$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export.py:87) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_serdes.py:9) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fb$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_serdes.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_retraceability.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https:$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_retraceability.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See ht$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_nonstrict.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See http$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_nonstrict.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See $
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:8) when processing rule "test_export". Please make sure it's listed in the srcs parameter of an$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of anoth$
ERROR while processing caffe2/test/TARGETS: Found "//python/typeshed_internal:typeshed_internal_library" owner for "cv2" but it is protected by visibility rules: [] (from caffe2/test/test_bundled_images.py:7) when processing rule "test_bundled_$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "caffe2.test.profiler_test_cpp_thread_lib" (from caffe2/test/profiler/test_cpp_thread.py:29) when processing rule "profiler_test_cpp_thread". Please make sure it's listed in t$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_custom_ops.py:23) when processing rule "custom_ops". Please make sure it's listed in the srcs parameter of anoth$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_public_bindings.py:13) when processing rule "public_bindings". Please make sure it's listed in the srcs paramete$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.symbolize_tracebacks" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another $
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.gather_traceback" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another rule$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for include <torch/csrc/autograd/profiler_kineto.h> (from caffe2/test/profiler/test_cpp_thread.cpp:2) when processing profiler_test_cpp_thread_lib.  Some things to try:
```

Differential Revision: D62049222

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135614
Approved by: https://github.com/oulgen, https://github.com/laithsakka
2024-09-13 02:04:34 +00:00

543 lines
17 KiB
Python

# Owner(s): ["module: inductor"]
import contextlib
import sys
import unittest
import torch
import torch._inductor
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TestCase,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.triton_utils import requires_cuda
# Convenience alias for the ATen op namespace (used by tests elsewhere in the suite).
aten = torch.ops.aten

try:
    try:
        # Package-relative import: works when run as part of the test package.
        from .test_torchinductor import check_model, check_model_cuda
    except ImportError:
        # Fallback for running the file directly (or under fbcode buck).
        from test_torchinductor import (  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
            check_model,
            check_model_cuda,
        )
except (unittest.SkipTest, ImportError) as e:
    # When executed as a script, missing-dependency skips exit cleanly;
    # under a test runner, re-raise so the failure is reported.
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise
@instantiate_parametrized_tests
class ComboKernelTests(TestCase):
    """Inductor combo-kernel tests with benchmarking disabled.

    Each test compiles a small multi-output function and checks both
    numerical parity with eager mode and the number of generated kernels.
    """

    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()
        # Enable combo kernels (without benchmarking) for the duration of each test.
        combo_config = torch._inductor.config.patch(
            {"combo_kernels": True, "benchmark_combo_kernel": False}
        )
        self._test_stack = contextlib.ExitStack()
        self._test_stack.enter_context(combo_config)

    def tearDown(self):
        self._test_stack.close()
        torch._inductor.metrics.reset()
        super().tearDown()

    @requires_cuda
    def test_activation_functions(self):
        # Three independent pointwise activations should fuse into one combo kernel.
        def fn(a, b, c):
            return (
                torch.nn.functional.relu(a),
                torch.nn.functional.sigmoid(b),
                torch.nn.functional.tanh(c),
            )

        inputs = (
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        )
        expected = fn(*inputs)
        actual = torch.compile(fn)(*inputs)
        self.assertEqual(expected, actual)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    def test_reduce_functions(self):
        # Mixed reductions plus one pointwise op; combo kernels should keep
        # the generated kernel count at two or fewer.
        def fn(a, b, c, d):
            return (
                torch.sum(a, dim=0),
                torch.max(b, dim=0),
                torch.min(c, dim=0),
                torch.nn.functional.tanh(d),
            )

        inputs = (
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        )
        expected = fn(*inputs)
        actual = torch.compile(fn)(*inputs)
        self.assertEqual(expected, actual)
        self.assertTrue(torch._inductor.metrics.generated_kernel_count <= 2)

    @requires_cuda
    def test_mutated_args(self):
        # In-place mutations of the inputs. Note: the eager call mutates the
        # tensors first, so the compiled call operates on the mutated values —
        # both paths see the same sequence of states.
        def fn(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()
            return a, b, c, d

        inputs = (
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        )
        expected = fn(*inputs)
        actual = torch.compile(fn)(*inputs)
        self.assertEqual(expected, actual)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    def test_reduce_split(self):
        # A full-tensor norm next to a dim-0 sum exercises split reductions.
        def fn(a, b):
            norm = torch.linalg.vector_norm(a)
            col_sum = torch.sum(b, dim=0)
            return norm, col_sum

        inputs = (
            torch.rand(2048, 512, device="cuda"),
            torch.rand(20, 20, device="cuda"),
        )
        expected = fn(*inputs)
        actual = torch.compile(fn)(*inputs)
        self.assertEqual(expected, actual)

    @requires_cuda
    def test_2d_blocking_partitioning(self):
        # Mixing contiguous and transposed operands forces 2D-blocked partitioning.
        def fn(a0, a1, a2, b0, b1, b2):
            return (
                torch.add(a0, b0),
                torch.add(a1, b1),
                torch.add(a2, b2),
            )

        args = (
            torch.rand(30, 20, device="cuda"),
            torch.rand(40, 30, device="cuda"),
            torch.rand(36, 40, device="cuda"),
            torch.rand(30, 20, device="cuda"),
            torch.rand(30, 40, device="cuda").t(),
            torch.rand(40, 36, device="cuda").t(),
        )
        self.check_model_cuda(fn, args)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
@instantiate_parametrized_tests
class ComboKernelBenchmarkTests(TestCase):
    """Combo-kernel tests with `benchmark_combo_kernel` enabled.

    With benchmarking on, inductor generates both the individual kernels and
    the combo kernel candidates, so the expected `generated_kernel_count`
    values here are larger than in ComboKernelTests and are pinned to the
    current partitioning heuristics (ranges are used where the exact count
    can vary).
    """

    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        """Reset kernel metrics and enable combo kernels with benchmarking."""
        super().setUp()
        torch._inductor.metrics.reset()
        self._test_stack = contextlib.ExitStack()
        self._test_stack.enter_context(
            torch._inductor.config.patch(
                {
                    "combo_kernels": True,
                    "benchmark_combo_kernel": True,
                }
            )
        )

    def tearDown(self):
        """Undo the config patch and reset metrics."""
        self._test_stack.close()
        torch._inductor.metrics.reset()
        super().tearDown()

    @requires_cuda
    def test_activation_benchmark(self):
        # Three independent pointwise activations; benchmarking yields 5 kernels
        # (presumably the individual + combo candidates — tied to inductor internals).
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]
        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

    @requires_cuda
    def test_reduce_benchmark(self):
        # Mixed reductions + pointwise; exact count varies, so a range is asserted.
        def test_reduce(a, b, c, d):
            a1 = torch.sum(a, dim=0)
            b1 = torch.max(b, dim=0)
            c1 = torch.min(c, dim=0)
            d1 = torch.nn.functional.tanh(d)
            return a1, b1, c1, d1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]
        out_eager = test_reduce(*inps)
        out_compiled = torch.compile(test_reduce)(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10)

    @requires_cuda
    def test_mutated_benchmark(self):
        # In-place mutations; note the eager call mutates inps before the
        # compiled call runs, so both see the same state sequence.
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()
            return a, b, c, d

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]
        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)
        self.assertEqual(out_eager, out_compiled)
        # Two observed-valid totals depending on partitioning outcome.
        self.assertTrue(torch._inductor.metrics.generated_kernel_count in [6, 9])

    @requires_cuda
    def test_round_robin_dispatch(self):
        # combo kernel dispatch strategy: round robin
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()
            return a, b, c, d

        # Deliberately mismatched shapes to trigger round-robin dispatch.
        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 5, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(5, 18, device="cuda"),
        ]
        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6)

    @requires_cuda
    def test_2d_blocking_benchmark(self):
        # Contiguous + transposed operands exercise 2D-blocked partitioning.
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda"),
                torch.rand(40, 30, device="cuda"),
                torch.rand(36, 40, device="cuda"),
                torch.rand(30, 20, device="cuda"),
                torch.rand(30, 40, device="cuda").t(),
                torch.rand(40, 36, device="cuda").t(),
            ),
        )
        self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)

    @requires_cuda
    def test_persistent_reduction_no_x_dim(self):
        # Persistent reductions with a dynamic batch dim and no x-dimension.
        def fn(x, y):
            return x.sum(1), y.sum(1)

        inps = (
            torch.rand(16, 256, device="cuda"),
            torch.rand(32, 256, device="cuda"),
        )
        # Mark dim 0 dynamic so the compiled kernel handles variable batch sizes.
        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)
@instantiate_parametrized_tests
class ComboKernelDynamicShapesTests(TestCase):
    """Combo-kernel tests under fully dynamic shapes.

    In addition to enabling combo kernels with benchmarking, setUp() disables
    `automatic_dynamic_shapes` and `assume_static_by_default` so every
    dimension is treated as dynamic unless a test re-patches those flags.
    Kernel-count assertions are pinned to current inductor heuristics.
    """

    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        """Reset metrics, enable combo kernels, and force dynamic shapes."""
        super().setUp()
        torch._inductor.metrics.reset()
        self._test_stack = contextlib.ExitStack()
        self._test_stack.enter_context(
            torch._inductor.config.patch(
                {
                    "combo_kernels": True,
                    "benchmark_combo_kernel": True,
                }
            )
        )
        self._test_stack.enter_context(
            torch._dynamo.config.patch(
                {
                    "automatic_dynamic_shapes": False,
                    "assume_static_by_default": False,
                }
            )
        )

    def tearDown(self):
        """Undo both config patches and reset metrics."""
        self._test_stack.close()
        torch._inductor.metrics.reset()
        super().tearDown()

    @requires_cuda
    def test_dynamic_shapes_activations(self):
        # Pointwise activations under dynamic shapes.
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]
        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

    @requires_cuda
    def test_dynamic_shapes_2d_blocking(self):
        # 2D-blocked partitioning (transposed operands) under dynamic shapes.
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda"),
                torch.rand(40, 30, device="cuda"),
                torch.rand(36, 40, device="cuda"),
                torch.rand(30, 20, device="cuda"),
                torch.rand(30, 40, device="cuda").t(),
                torch.rand(40, 36, device="cuda").t(),
            ),
        )
        self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)

    @requires_cuda
    def test_dynamic_shapes_reduce(self):
        # Mixed reductions + pointwise under dynamic shapes; range assertion.
        def test_reduce(a, b, c, d):
            a1 = torch.sum(a, dim=0)
            b1 = torch.max(b, dim=0)
            c1 = torch.min(c, dim=0)
            d1 = torch.nn.functional.tanh(d)
            return a1, b1, c1, d1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]
        out_eager = test_reduce(*inps)
        out_compiled = torch.compile(test_reduce)(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10)

    @requires_cuda
    def test_dynamic_shapes_mutated(self):
        # combo kernel dispatch strategy: round robin
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()
            return a, b, c, d

        # Mismatched shapes trigger round-robin dispatch; eager call mutates
        # inps before the compiled call, so both see the same state sequence.
        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 5, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(5, 18, device="cuda"),
        ]
        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6)

    @requires_cuda
    @torch._inductor.config.patch("combo_kernels_autotune", 0)
    def test_dynamic_shapes_activations_no_autotune(self):
        # Same as test_dynamic_shapes_activations but with combo-kernel
        # autotuning disabled; kernel count is expected to be unchanged.
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]
        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", True)
    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_dynamic_shapes_persistent_reduction_no_x_dim(self):
        # Persistent reduction with only dim 0 marked dynamic (the class-level
        # force-dynamic flags are overridden by the decorators above).
        def fn(x, y):
            return x.sum(1), y.sum(1)

        inps = (
            torch.rand(16, 256, device="cuda"),
            torch.rand(32, 256, device="cuda"),
        )
        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)

    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", True)
    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_dynamic_shapes_2d_blocking_round_robin(self):
        # 2D blocking with round-robin dispatch; the second set of inputs has
        # different sizes to check the compiled artifact handles a re-trace
        # (or dynamic dims) without recompdu— without changing kernel counts.
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        inps = (
            torch.rand(20, 30, device="cuda"),
            torch.rand(30, 30, device="cuda"),
            torch.rand(40, 32, device="cuda"),
            torch.rand(30, 20, device="cuda").t(),
            torch.rand(30, 30, device="cuda").t(),
            torch.rand(32, 40, device="cuda").t(),
        )
        out_eager = fn(*inps)
        compiled = torch.compile(fn)
        out_compiled = compiled(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6)
        torch._inductor.metrics.reset()

        # Second call with different shapes on the same compiled function.
        inps = (
            torch.rand(24, 30, device="cuda"),
            torch.rand(32, 30, device="cuda"),
            torch.rand(48, 32, device="cuda"),
            torch.rand(30, 24, device="cuda").t(),
            torch.rand(30, 32, device="cuda").t(),
            torch.rand(32, 48, device="cuda").t(),
        )
        out_compiled = compiled(*inps)
        out_eager = fn(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6)

    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", True)
    @torch._dynamo.config.patch("assume_static_by_default", True)
    @torch._inductor.config.patch("triton.autotune_at_compile_time", True)
    def test_dynamic_shapes_persistent_reduction_mixed_x_dim_cuda(self):
        # Persistent reductions over different row lengths with compile-time
        # autotuning; only numerical parity is asserted (no kernel-count pin).
        def fn(x, y, z):
            return x.sum(1), y.mean(1), z.max(1)

        inps = (
            torch.rand(16, 128, device="cuda"),
            torch.rand(32, 128, device="cuda"),
            torch.rand(32, 256, device="cuda"),
        )
        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[2], 0, min=1, max=256)
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)
        self.assertEqual(out_eager, out_compiled)
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    # Only run when at least one inductor backend (CPU or CUDA) is available.
    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")