mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
As title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145058 Approved by: https://github.com/jansel
41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
# Owner(s): ["module: inductor"]
|
|
import torch
|
|
from torch._inductor import config
|
|
from torch._inductor.async_compile import AsyncCompile, shutdown_compile_workers
|
|
from torch._inductor.test_case import run_tests, TestCase
|
|
from torch._inductor.utils import fresh_inductor_cache
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
)
|
|
from torch.testing._internal.inductor_utils import (
|
|
GPU_TYPE,
|
|
requires_gpu,
|
|
requires_triton,
|
|
)
|
|
|
|
|
|
@instantiate_parametrized_tests
class TestAsyncCompile(TestCase):
    """Tests that run torch.compile under each ``worker_start_method`` setting."""

    @requires_gpu()
    @requires_triton()
    @parametrize("method", ("subprocess", "fork", "spawn"))
    def test_pool(self, method):
        # Compiles a trivial elementwise add while ``worker_start_method`` is
        # patched to ``method`` and checks the compiled output against eager.
        # (Kept as comments rather than a docstring so unittest's verbose
        # output for this test is unchanged.)

        def add(lhs, rhs):
            return lhs + rhs

        # Inputs are created with torch.rand and then moved to the GPU device.
        lhs = torch.rand(10).to(GPU_TYPE)
        rhs = torch.rand(10).to(GPU_TYPE)

        with config.patch("worker_start_method", method):
            # NOTE(review): presumably this tears down any existing workers so
            # that process_pool() below builds a pool under the patched start
            # method -- confirm against torch._inductor.async_compile.
            shutdown_compile_workers()

            # Block (for at most 120 seconds) until the pool's ready_future
            # resolves; a timeout or worker error is raised here, before any
            # compilation is attempted.
            ready = AsyncCompile.process_pool().ready_future
            ready.result(timeout=120)

            with fresh_inductor_cache():
                compiled_add = torch.compile(add)
                # Eager result first, then the compiled call -- same
                # evaluation order as passing both calls inline.
                eager_out = add(lhs, rhs)
                compiled_out = compiled_add(lhs, rhs)
                self.assertEqual(eager_out, compiled_out)
|
|
|
|
|
|
# Script entry point: hand control to the inductor test runner only when this
# file is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    run_tests()
|