mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Differential Revision: D69068432 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146439 Approved by: https://github.com/avikchaudhuri
4547 lines
158 KiB
Python
4547 lines
158 KiB
Python
# Owner(s): ["module: inductor"]
|
|
import itertools
|
|
import logging
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
from unittest import skip
|
|
|
|
import torch
|
|
import torch._export
|
|
import torch._inductor
|
|
import torch._inductor.config
|
|
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
|
|
import torch.nn as nn
|
|
from torch._dynamo import config as dynamo_config
|
|
from torch._dynamo.device_interface import get_interface_for_device
|
|
from torch._dynamo.testing import rand_strided, same
|
|
from torch._dynamo.utils import counters
|
|
from torch._inductor import config
|
|
from torch._inductor.runtime.runtime_utils import cache_dir
|
|
from torch._inductor.test_case import TestCase
|
|
from torch._inductor.utils import is_big_gpu, run_and_get_cpp_code
|
|
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
|
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
|
|
from torch.export import Dim, export, export_for_training
|
|
from torch.testing import FileCheck
|
|
from torch.testing._internal import common_utils
|
|
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater
|
|
from torch.testing._internal.common_device_type import (
|
|
_has_sufficient_memory,
|
|
skipCUDAIf,
|
|
)
|
|
from torch.testing._internal.common_quantization import (
|
|
skip_if_no_torchvision,
|
|
skipIfNoFBGEMM,
|
|
)
|
|
from torch.testing._internal.common_utils import (
|
|
DeterministicGuard,
|
|
IS_CI,
|
|
IS_FBCODE,
|
|
IS_MACOS,
|
|
IS_WINDOWS,
|
|
skipIfRocm,
|
|
skipIfXpu,
|
|
TEST_WITH_ROCM,
|
|
)
|
|
from torch.testing._internal.custom_tensor import CustomTensorPlainOut
|
|
from torch.testing._internal.inductor_utils import GPU_TYPE
|
|
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
|
|
from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu
|
|
from torch.utils import _pytree as pytree
|
|
from torch.utils._triton import has_triton_tma
|
|
|
|
|
|
if HAS_GPU:
|
|
import triton # @manual
|
|
from triton import language as tl
|
|
|
|
from torch.testing._internal.triton_utils import (
|
|
add_kernel,
|
|
add_kernel_2d_autotuned,
|
|
add_kernel_autotuned,
|
|
add_kernel_autotuned_weird_param_order,
|
|
add_kernel_with_optional_param,
|
|
add_kernel_with_scaling,
|
|
add_kernel_with_tma_1d,
|
|
add_kernel_with_tma_2d,
|
|
mul2_inplace_kernel,
|
|
)
|
|
|
|
# Windows CI lacks the dependencies (sympy/functorch/filelock) needed by this
# suite; bail out early rather than failing with import errors.
if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_torchinductor yet\n"
    )
    if __name__ == "__main__":
        # Exit cleanly when run as a script so CI does not report a failure.
        sys.exit(0)
    # When imported as a module, signal pytest/unittest to skip the whole file.
    raise unittest.SkipTest("requires sympy/functorch/filelock")
|
|
|
|
try:
|
|
try:
|
|
from .test_aot_inductor_utils import (
|
|
AOTIRunnerUtil,
|
|
check_model,
|
|
check_model_with_multiple_inputs,
|
|
code_check_count,
|
|
)
|
|
from .test_control_flow import (
|
|
CondModels,
|
|
prepend_counters,
|
|
prepend_predicates,
|
|
WhileLoopModels,
|
|
)
|
|
from .test_torchinductor import copy_tests, requires_multigpu, TestFailure
|
|
except ImportError:
|
|
from test_aot_inductor_utils import ( # @manual=fbcode//caffe2/test/inductor:aot_inductor_utils-library
|
|
AOTIRunnerUtil,
|
|
check_model,
|
|
check_model_with_multiple_inputs,
|
|
code_check_count,
|
|
)
|
|
from test_control_flow import ( # @manual=fbcode//caffe2/test/inductor:control_flow-library
|
|
CondModels,
|
|
prepend_counters,
|
|
prepend_predicates,
|
|
WhileLoopModels,
|
|
)
|
|
from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library
|
|
copy_tests,
|
|
requires_multigpu,
|
|
TestFailure,
|
|
)
|
|
except (unittest.SkipTest, ImportError):
|
|
if __name__ == "__main__":
|
|
sys.exit(0)
|
|
raise
|
|
|
|
|
|
class AOTInductorTestsTemplate:
|
|
def test_simple(self):
    """Smoke test: compile and run a trivial add-plus-Linear model via AOTI."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x, y):
            return x + self.linear(y)

    example_inputs = (
        torch.randn(10, 10, device=self.device),
        torch.randn(10, 10, device=self.device),
    )
    model = Model()
    self.check_model(model, example_inputs)
    if self.use_minimal_arrayref_interface:
        # The minimal-arrayref codegen path must emit its dedicated entry point.
        self.code_check_count(
            model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1
        )
|
|
|
|
def test_compile_wrapper_with_O0(self):
    """Check the wrapper can be compiled at -O0 and still emits the expected
    __attribute__ annotations in the generated C++ code."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x, y):
            return x + self.linear(y)

    example_inputs = (
        torch.randn(10, 10, device=self.device),
        torch.randn(10, 10, device=self.device),
    )
    model = Model()
    with config.patch("aot_inductor.compile_wrapper_with_O0", True):
        self.check_model(model, example_inputs)
        # Two occurrences expected in the generated wrapper source.
        self.code_check_count(model, example_inputs, "__attribute__((", 2)
|
|
|
|
def test_small_constant(self):
    """Small weights must survive compilation when tensor constants are
    forced to be kept (always_keep_tensor_constants=True)."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(4, 4, device=self.device),)
    with config.patch({"always_keep_tensor_constants": True}):
        self.check_model(Model().to(self.device), example_inputs)
|
|
|
|
def test_output_path_1(self):
    """Compilation should work with a relative aot_inductor.output_path prefix."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x, y):
            return x + self.linear(y)

    example_inputs = (
        torch.randn(10, 10, device=self.device),
        torch.randn(10, 10, device=self.device),
    )
    with config.patch("aot_inductor.output_path", "tmp_output_"):
        self.check_model(Model(), example_inputs)
|
|
|
|
def test_output_path_2(self):
    """An absolute aot_inductor.output_path must be honored verbatim:
    the compiled .so lands exactly where requested."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x, y):
            return x + self.linear(y)

    model = Model().to(device=self.device)
    example_inputs = (
        torch.randn(10, 10, device=self.device),
        torch.randn(10, 10, device=self.device),
    )
    expected_path = os.path.join(tempfile.mkdtemp(dir=cache_dir()), "model.so")
    actual_path = AOTIRunnerUtil.compile(
        model, example_inputs, options={"aot_inductor.output_path": expected_path}
    )
    self.assertTrue(actual_path == expected_path)
|
|
|
|
def test_constant_folding(self):
    """Runtime constant folding: transpose/relu/add on frozen weights should
    be folded at model-load time rather than executed per call."""

    class Model(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            self.w_pre = torch.randn(4, 4, device=device)
            self.b = torch.randn(4, device=device)

        def forward(self, x):
            # This chain depends only on constants, so it is foldable.
            w_transpose = torch.transpose(self.w_pre, 0, 1)
            w_relu = torch.nn.functional.relu(w_transpose)
            w = w_relu + self.b
            return torch.matmul(x, w)

    example_inputs = (torch.randn(4, 4, device=self.device),)
    with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
        self.check_model(Model(self.device), example_inputs)
|
|
|
|
@requires_gpu
def test_duplicate_constant_folding(self):
    """Constant folding with several distinct weight constants that feed the
    same foldable op (cat) — folding must not confuse/merge them incorrectly."""

    class Model(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            self.w1 = torch.randn(4, 4, device=device)
            self.w2 = torch.randn(4, 4, device=device)
            self.w3 = torch.randn(4, 4, device=device)
            self.w4 = torch.randn(4, 4, device=device)

        def forward(self, x):
            # cat over constants is foldable; cat with x is not.
            w_concat = torch.cat((self.w1, self.w2, self.w3, self.w4))
            return torch.cat((x, w_concat))

    example_inputs = (torch.randn(4, 4, device=self.device),)
    with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
        self.check_model(Model(self.device), example_inputs)
|
|
|
|
@requires_gpu
def test_multi_device(self):
    """A graph that bounces tensors between CPU and GPU must compile and run."""
    if self.device == "cpu" and GPU_TYPE == "xpu":
        raise unittest.SkipTest(
            "In this scenario, the test case will run XPU code in "
            "AOTIModelContainerRunnerCpu, which is not reasonable,"
            "See issue #140805"
        )

    class Model(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            x = x.cpu()
            x = x + 2
            x = x.to(GPU_TYPE)
            return x

    example_inputs = (torch.randn(32, 64, device=self.device),)
    self.check_model(Model(), example_inputs)
|
|
|
|
def test_large_weight(self):
    """A very large Linear weight (2048 x 262144) must at least compile.
    Only compilation is exercised — running it OOMs in CI."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(2048, 262144)

        def forward(self, x, y):
            return x + self.linear(y)

    example_inputs = (
        torch.randn(1, 262144, device=self.device),
        torch.randn(1, 2048, device=self.device),
    )

    # We only test compilation since we often get OOM running in CI.
    model = Model()
    model = model.to(self.device)
    AOTIRunnerUtil.compile(model, example_inputs)
|
|
|
|
def test_subclasses(self):
    """Parameters that are tensor subclasses (CustomTensorPlainOut) must be
    handled by AOTI; the compiled output is compared against eager."""
    device_to_init = self.device

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.p1 = torch.nn.Parameter(torch.ones(3, 4, device=device_to_init))
            # p2 is a tensor-subclass parameter (wraps two plain tensors).
            self.p2 = torch.nn.Parameter(
                CustomTensorPlainOut(
                    torch.ones(3, 4, device=device_to_init),
                    torch.ones(3, 4, device=device_to_init),
                )
            )

        def forward(self, x):
            a = (2 * self.p1 + self.p2).sum()
            return x + a

    m = Foo()
    ref_x = torch.randn(3, 4, device=device_to_init)

    with torch.no_grad():
        result = AOTIRunnerUtil.run(
            self.device,
            m,
            (ref_x,),
        )
    actual = m(ref_x)
    self.assertTrue(same(result, actual))
|
|
|
|
def test_large_mmaped_weights(self):
    """Large weights stored via mmap (force_mmap_weights=True) must load and
    produce correct results."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(512, 250112)

        def forward(self, x, y):
            return x + self.linear(y)

    example_inputs = (
        torch.randn(1, 250112, device=self.device),
        torch.randn(1, 512, device=self.device),
    )
    with config.patch({"aot_inductor.force_mmap_weights": True}):
        self.check_model(Model(), example_inputs)
|
|
|
|
def test_with_offset(self):
    """Constants that are views with non-zero storage offsets (slices of a
    larger tensor) must be materialized correctly by AOTI."""

    class Model(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            # orig_tensor is itself a view (first slice of a (2,15,10) tensor),
            # and self.tensor is a further offset view into it.
            self.orig_tensor = torch.randn(2, 15, 10, device=device)[0]
            self.tensor = self.orig_tensor[5:, :]

        def forward(self, x, y):
            return (
                x
                + torch.nn.functional.linear(y, self.orig_tensor[:10, :])
                + self.tensor
            )

    example_inputs = (
        torch.randn(10, 10, device=self.device),
        torch.randn(10, 10, device=self.device),
    )
    self.check_model(Model(self.device), example_inputs)
|
|
|
|
@unittest.skipIf(
    IS_FBCODE,
    "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_freezing(self):
    """With freezing enabled, constant computation on weights (the cat below)
    should be baked in at compile time."""

    class Model(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            self.weight = torch.randn(9, 10, device=device)
            self.padding = torch.randn(1, 10, device=device)

        def forward(self, x, y):
            # Constant-only cat: a freezing candidate.
            padded_weight = torch.cat((self.weight, self.padding), dim=0)
            return x + torch.nn.functional.linear(y, padded_weight)

    example_inputs = (
        torch.randn(10, 10, device=self.device),
        torch.randn(10, 10, device=self.device),
    )

    with config.patch({"freezing": True}):
        self.check_model(Model(self.device), example_inputs)
|
|
|
|
@unittest.skipIf(
    IS_FBCODE,
    "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_conv_freezing(self):
    """Freezing with conv2d weights, sweeping dtype (float / bf16 on SM80+)
    and group counts."""
    dtypes = [torch.bfloat16, torch.float] if SM80OrLater else [torch.float]
    for dtype, groups in itertools.product(dtypes, [1, 2]):
        iC = 2  # input channels per group
        oC = 3  # output channels per group

        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.weight = torch.randn(oC * groups, iC, 3, 3, device=device).to(
                    dtype
                )

            def forward(self, y):
                return torch.nn.functional.conv2d(y, self.weight, groups=groups)

        example_inputs = (
            torch.randn(2, iC * groups, 10, 10, device=self.device).to(dtype),
        )

        with config.patch({"freezing": True}):
            self.check_model(Model(self.device), example_inputs)
|
|
|
|
@unittest.skipIf(
    IS_FBCODE,
    "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_deconv_freezing(self):
    """Freezing with conv_transpose2d weights, sweeping dtype and groups.
    bf16 is only added when MKLDNN reports support for it."""
    dtypes = [torch.float]
    if torch._C._has_mkldnn and torch.ops.mkldnn._is_mkldnn_bf16_supported():
        dtypes.append(torch.bfloat16)
    for dtype, groups in itertools.product(dtypes, [2, 1]):
        iC = 4
        oC = 2

        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                # conv_transpose weight layout: (in_channels, out_channels/groups, kH, kW)
                self.weight = torch.randn(iC, oC * groups, 2, 2, device=device).to(
                    dtype
                )

            def forward(self, y):
                return torch.nn.functional.conv_transpose2d(
                    y, self.weight, groups=groups
                )

        example_inputs = (torch.randn(1, iC, 3, 3, device=self.device).to(dtype),)
        with config.patch({"freezing": True}):
            self.check_model(Model(self.device), example_inputs)
|
|
|
|
@unittest.skipIf(
    IS_FBCODE,
    "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used",
)
def test_linear_freezing(self):
    """Freezing with a functional linear layer, sweeping dtype (bf16 on SM80+)."""
    dtypes = [torch.bfloat16, torch.float] if SM80OrLater else [torch.float]
    for dtype in dtypes:

        class LinearModel(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.weight = torch.randn(10, 10, device=device).to(dtype)
                self.bias = torch.randn(10, device=device).to(dtype)

            def forward(self, y):
                return torch.nn.functional.linear(y, self.weight, self.bias)

        example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),)

        with config.patch({"freezing": True}):
            self.check_model(LinearModel(self.device), example_inputs)
|
|
|
|
def test_linear_dynamic_maxautotune(self):
    """Dynamic batch dim + max-autotune (Triton GEMM backend): the packaged
    model must agree with eager at several runtime batch sizes within the
    exported Dim range."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(1, 1)

        def forward(self, x):
            return self.linear(x)

    model = Model().to(device=self.device)
    compile_inputs = (torch.randn(2048, 1, device=self.device),)
    dim0_x = Dim("dim0_x", min=2, max=2048)
    dynamic_shapes = {"x": {0: dim0_x}}
    ep = torch.export.export(model, compile_inputs, dynamic_shapes=dynamic_shapes)
    optimized = torch._inductor.aoti_load_package(
        torch._inductor.aoti_compile_and_package(
            ep,
            inductor_configs={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
        )
    )
    # Probe multiple batch sizes inside the declared dynamic range.
    runtime_input = torch.randn(10, 1, device=self.device)
    self.assertTrue(same(optimized(runtime_input), model(runtime_input)))
    runtime_input = torch.randn(16, 1, device=self.device)
    self.assertTrue(same(optimized(runtime_input), model(runtime_input)))
    runtime_input = torch.randn(100, 1, device=self.device)
    self.assertTrue(same(optimized(runtime_input), model(runtime_input)))
|
|
|
|
@torch._inductor.config.patch(
    pre_grad_fusion_options={
        "normalization_pass": {},
        "remove_split_with_size_one_pass": {},
        "merge_getitem_cat_pass": {},
        "merge_stack_tahn_unbind_pass": {},
        "merge_splits_pass": {},
        "mutate_cat_pass": {},
        "split_cat_pass": {},
        "unbind_stack_pass": {},
    },
    post_grad_fusion_options={},
)
def test_simple_split(self):
    """split+cat pattern: the pre-grad split/cat fusion passes should remove
    both ops, verified via the inductor counters."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x):
            return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)

    example_inputs = (torch.randn(2, 8, device=self.device),)
    counters.clear()
    self.check_model(Model(), example_inputs)
    self.assertEqual(counters["inductor"]["scmerge_split_removed"], 1)
    self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 1)
    self.assertEqual(counters["inductor"]["scmerge_split_sections_removed"], 1)
|
|
|
|
def test_amp_fallback_random(self):
    """Autocast + fallback_random: a plain linear under amp autocast must
    compile and match eager."""

    def fn(x, w):
        return torch.functional.F.linear(x, w)

    example_inputs = (
        torch.randn(10, 10, device=self.device),
        torch.randn(10, 10, device=self.device),
    )
    with config.patch({"fallback_random": True}):
        with torch.amp.autocast(device_type=self.device):
            self.check_model(fn, example_inputs)
|
|
|
|
def test_missing_output(self):
    """Intermediate buffers (a, b) that are not returned must still be
    scheduled correctly; only c is an output."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, y):
            a = torch.sin(x)
            b = torch.mm(a, y)
            c = torch.cos(b)
            return c

    example_inputs = (
        torch.randn(10, 10, device=self.device),
        torch.randn(10, 10, device=self.device),
    )
    self.check_model(Model(), example_inputs)
|
|
|
|
def test_output_misaligned(self):
    """Outputs that are views into a larger buffer (cat[0], cat[1]) may be
    misaligned; AOTI must return them correctly."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, y):
            x_unsqueeze = torch.unsqueeze(x, dim=0)
            y_unsqueeze = torch.unsqueeze(y, dim=0)
            cat = torch.cat([x_unsqueeze, y_unsqueeze], dim=0)
            # Both getitems are views into `cat`'s storage.
            x_getitem = cat[0]
            y_getitem = cat[1]
            x_sigmoid = torch.sigmoid(x_getitem)
            return x_sigmoid, y_getitem

    example_inputs = (
        torch.randn(10, 10, device=self.device),
        torch.randn(10, 10, device=self.device),
    )
    self.check_model(Model(), example_inputs)
|
|
|
|
@skip("Test was marked as expected failure, but does not fail always anymore.")
def test_dynamic_smem_above_default_limit(self):
    """Max-autotuned MM whose Triton kernel needs more dynamic shared memory
    than the device's default limit (49152 bytes on A100)."""

    class Model(torch.nn.Module):
        def forward(self, x, y):
            return x @ y

    model = Model().to(self.device)
    # on A100, the generated Triton kernel for this MM
    # requires 55296 bytes of dynamic SMEM which is above
    # the A100's default dynamic SMEM limit of 49152 bytes.
    example_inputs = (
        torch.randn(10285, 96, device=self.device),
        torch.randn(96, 1, device=self.device),
    )
    self.check_model(
        model,
        example_inputs,
        options={
            "max_autotune": True,
            "max_autotune_gemm_backends": "TRITON",
        },
    )
|
|
|
|
@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
def test_seq(self):
    """A Sequential that reuses the same LayerNorm module twice — shared
    submodule parameters must be handled once, not duplicated."""
    layernorm = torch.nn.LayerNorm(10)
    net = torch.nn.Sequential(
        layernorm,
        torch.nn.ReLU(),
        layernorm,  # intentionally the same instance as above
        torch.nn.ReLU(),
    )

    example_inputs = (torch.randn(10, device=self.device),)
    self.check_model(net.eval(), example_inputs)
|
|
|
|
def test_addmm(self):
    """Batched input through functional linear — exercises the addmm path."""

    class Model(torch.nn.Module):
        def __init__(self, n, k, device):
            super().__init__()
            self.weight = torch.randn(n, k, device=device)
            self.bias = torch.randn(n, device=device)

        def forward(self, a):
            return torch.nn.functional.linear(a, self.weight, self.bias)

    M = 8
    N = 6
    K = 16
    model = Model(N, K, self.device)
    batch = 2
    a = torch.randn(batch, M, K, device=self.device)
    example_inputs = (a,)
    self.check_model(model, example_inputs)
|
|
|
|
def test_aliased_buffer_reuse(self):
    """Buffer reuse when a later op consumes both a slice of one buffer and
    an earlier (potentially aliased) buffer — reuse must not corrupt x."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, y):
            x = 2 * x
            y = 2 * y
            c = torch.cat([x, y], dim=-1)
            d = 1 + c
            m = torch.mm(d, d)
            # m[:, :2] is a view; adding x forces x to stay live.
            return m[:, :2] + x

    example_inputs = (
        torch.randn(4, 2, device=self.device),
        torch.randn(4, 2, device=self.device),
    )
    self.check_model(Model(), example_inputs)
|
|
|
|
def test_buffer_reuse(self):
    """A chain of elementwise and mm ops producing several dead intermediates
    — exercises inductor's intermediate-buffer reuse."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, y):
            a = torch.sin(x)
            b = torch.cos(y)
            c = torch.mm(a, b)
            d = torch.relu(c)
            e = torch.sigmoid(d)
            f = torch.mm(x, y)
            g = e + f
            return g

    example_inputs = (
        torch.randn(4, 4, device=self.device),
        torch.randn(4, 4, device=self.device),
    )
    self.check_model(Model(), example_inputs)
|
|
|
|
def test_duplicated_params(self):
    """Two attributes referring to the same Parameter object — the duplicate
    must be deduplicated, not stored or bound twice."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.p = torch.nn.Parameter(torch.rand(6))
            self.q = self.p  # alias, not a copy

        def forward(self, x):
            return self.p * x + self.q

    example_inputs = (torch.rand(6, device=self.device),)
    self.check_model(Model(), example_inputs)
|
|
|
|
@unittest.skip("Skip this test, only for local test. SIGABRT is produced.")
def test_inf(self):
    """With debug_check_inf_and_nan, an Inf in the input should be caught by
    the runtime checker (aborts the process — hence skipped in CI)."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x, y):
            return x + self.linear(y)

    x = torch.randn(10, 10, device=self.device)
    x[0][0] = float("Inf")
    example_inputs = (
        x,
        torch.randn(10, 10, device=self.device),
    )
    self.check_model(
        Model().to(self.device),
        example_inputs,
        options={"debug_check_inf_and_nan": True},
    )
|
|
|
|
@unittest.skip("Skip this test, only for local test. SIGABRT is produced.")
def test_nan(self):
    """With debug_check_inf_and_nan, a NaN in the input should be caught by
    the runtime checker (aborts the process — hence skipped in CI)."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x, y):
            return x + self.linear(y)

    x = torch.randn(10, 10, device=self.device)
    x[0][0] = float("nan")
    example_inputs = (
        x,
        torch.randn(10, 10, device=self.device),
    )
    self.check_model(
        Model().to(self.device),
        example_inputs,
        options={"debug_check_inf_and_nan": True},
    )
|
|
|
|
def test_assert_async(self):
    """Data-dependent torch._check on an .item() value lowers to an async
    device-side assert — GPU only."""
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU_TYPE")

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x):
            u0 = x.item()  # unbacked symint
            torch._check(u0 > 3)
            return torch.ones(u0)[0]

    x = torch.tensor(23, device=self.device)
    example_inputs = (x,)
    self.check_model(Model(), example_inputs)
|
|
|
|
def test_simple_dynamic(self):
    """Basic pointwise model with one shared dynamic batch dimension."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, y):
            add_0 = x + y
            return torch.nn.functional.relu(input=add_0, inplace=False)

    x = torch.randn(128, 2048, device=self.device)
    y = torch.randn(128, 2048, device=self.device)
    dim0_x = Dim("dim0_x", min=1, max=2048)
    # x and y share the same symbolic dim so their shapes stay equal.
    dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
    example_inputs = (x, y)
    self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
|
|
|
|
@unittest.skipIf(
    not PLATFORM_SUPPORTS_FP8,
    "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
@skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
@skipIfXpu
def test_fp8(self):
    """torch._scaled_mm with fp8 inputs and a dynamic batch dimension.

    The model receives x (fp8), a weight (cast to fp8 inside forward),
    a half-precision bias, and the two inverse scales.
    """
    # cuda only
    if self.device != "cuda":
        return

    class Model(torch.nn.Module):
        def __init__(self, dtype):
            super().__init__()
            self.out_dtype = dtype

        def forward(self, x, weight, bias, scale_a, scale_b):
            weight = weight.to(torch.float8_e4m3fn)
            # BUGFIX: use the `bias` parameter instead of closing over the
            # enclosing test local `input_bias`. The closure made the bias a
            # baked-in constant in the exported graph even though it is passed
            # in example_inputs as a real input.
            output = torch._scaled_mm(
                x,
                weight,
                bias=bias,
                out_dtype=self.out_dtype,
                scale_a=scale_a,
                scale_b=scale_b,
            )
            return output

    dtype = torch.float16

    a_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
    b_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
    input_bias = torch.rand(32, device=GPU_TYPE, dtype=dtype)
    weight_shape = (32, 16)
    # .T so the weight is column-major, as _scaled_mm requires for operand B.
    weight = torch.rand(*weight_shape, device=GPU_TYPE, dtype=dtype).T
    a_inverse_scale = 1 / a_scale
    b_inverse_scale = 1 / b_scale

    x_shape = (16, 16)
    x = torch.rand(*x_shape, device=GPU_TYPE, dtype=dtype).to(torch.float8_e4m3fn)
    dim0_x = Dim("dim0_x", min=1, max=2048)
    dynamic_shapes = ({0: dim0_x}, None, None, None, None)
    self.check_model(
        Model(dtype),
        (x, weight, input_bias, a_inverse_scale, b_inverse_scale),
        dynamic_shapes=dynamic_shapes,
    )
|
|
|
|
@unittest.skipIf(
    not PLATFORM_SUPPORTS_FP8,
    "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
@skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
@skipIfXpu
def test_fp8_view_of_param(self):
    """torch._scaled_mm where the weight is a view (.T) of a module parameter
    taken inside forward — AOTI must materialize the view before handing it
    to the scaled_mm extern kernel."""
    # cuda only
    if self.device != GPU_TYPE:
        return

    class Model(torch.nn.Module):
        def __init__(self, dtype, weight):
            super().__init__()
            self.out_dtype = dtype
            self.weight = weight

        def forward(self, x, bias, scale_a, scale_b):
            # test: do the view inside of the graph,
            # AOTI needs to materialize this view before passing
            # it into the scaled_mm extern kernel
            weight = self.weight.T
            # BUGFIX: use the `bias` parameter instead of closing over the
            # enclosing test local `input_bias`; the closure turned the bias
            # into a baked-in constant even though it is a declared input.
            output = torch._scaled_mm(
                x,
                weight,
                bias=bias,
                out_dtype=self.out_dtype,
                scale_a=scale_a,
                scale_b=scale_b,
            )
            return output

    dtype = torch.float16

    a_scale = torch.Tensor([1.0]).to(device=self.device)
    b_scale = torch.Tensor([1.0]).to(device=self.device)
    input_bias = torch.rand(32, device=self.device, dtype=dtype)
    weight_shape = (32, 16)
    weight = torch.rand(*weight_shape, device=self.device, dtype=dtype).to(
        torch.float8_e4m3fn
    )
    a_inverse_scale = 1 / a_scale
    b_inverse_scale = 1 / b_scale

    x_shape = (16, 16)
    x = torch.rand(*x_shape, device=self.device, dtype=dtype).to(
        torch.float8_e4m3fn
    )
    dim0_x = Dim("dim0_x", min=1, max=2048)
    dynamic_shapes = ({0: dim0_x}, None, None, None)
    self.check_model(
        Model(dtype, weight),
        (x, input_bias, a_inverse_scale, b_inverse_scale),
        dynamic_shapes=dynamic_shapes,
    )
|
|
|
|
def test_poi_multiple_dynamic(self):
    """Pointwise model run with several different dynamic batch sizes through
    one compiled artifact."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, y):
            add_0 = x + y
            return torch.nn.functional.relu(input=add_0, inplace=False)

    x = torch.randn(128, 2048, device=self.device)
    y = torch.randn(128, 2048, device=self.device)
    dim0_x = Dim("dim0_x", min=1, max=2048)
    dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
    list_example_inputs = [(x, y)]
    list_example_inputs.append(
        (
            torch.randn(64, 2048, device=self.device),
            torch.randn(64, 2048, device=self.device),
        ),
    )
    list_example_inputs.append(
        (
            torch.randn(211, 2048, device=self.device),
            torch.randn(211, 2048, device=self.device),
        ),
    )
    self.check_model_with_multiple_inputs(
        Model(), list_example_inputs, dynamic_shapes=dynamic_shapes
    )
|
|
|
|
def test_addmm_multiple_dynamic(self):
    """addmm (functional linear) with a dynamic batch dim under max-autotune
    (Triton backend), exercised at batch sizes 2, 2048 and 128."""

    class Model(torch.nn.Module):
        def __init__(self, n, k, device):
            super().__init__()
            self.weight = torch.randn(n, k, device=device)
            self.bias = torch.randn(n, device=device)

        def forward(self, a):
            return torch.nn.functional.linear(a, self.weight, self.bias)

    M = 8
    N = 6
    K = 16
    model = Model(N, K, self.device)
    batch = 2
    a = torch.randn(batch, M, K, device=self.device)
    dim0_a = Dim("dim0_a", min=1, max=2048)
    dynamic_shapes = {"a": {0: dim0_a}}
    list_example_inputs = [(a,)]
    batch = 2048
    list_example_inputs.append(
        (torch.randn(batch, M, K, device=self.device),),
    )
    batch = 128
    list_example_inputs.append(
        (torch.randn(batch, M, K, device=self.device),),
    )
    self.check_model_with_multiple_inputs(
        model,
        list_example_inputs,
        dynamic_shapes=dynamic_shapes,
        options={
            "max_autotune": True,
            "max_autotune_gemm_backends": "TRITON",
        },
    )
|
|
|
|
def test_bmm_multiple_dynamic(self):
    """bmm with a dynamic batch dim under max-autotune (Triton backend),
    exercised at batch sizes 1024, 2048 and 128."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, a, b):
            return torch.bmm(a, b)

    M = 8
    N = 6
    K = 16
    model = Model()
    batch = 1024
    a = torch.randn(batch, M, K, device=self.device)
    b = torch.randn(batch, K, N, device=self.device)
    dim0_a = Dim("dim0_a", min=1, max=2048)
    # Both operands share the same symbolic batch dim.
    dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_a}}
    list_example_inputs = [(a, b)]
    batch = 2048
    list_example_inputs.append(
        (
            torch.randn(batch, M, K, device=self.device),
            torch.randn(batch, K, N, device=self.device),
        ),
    )
    batch = 128
    list_example_inputs.append(
        (
            torch.randn(batch, M, K, device=self.device),
            torch.randn(batch, K, N, device=self.device),
        ),
    )
    self.check_model_with_multiple_inputs(
        model,
        list_example_inputs,
        options={
            "max_autotune": True,
            "max_autotune_gemm_backends": "TRITON",
        },
        dynamic_shapes=dynamic_shapes,
    )
|
|
|
|
def test_foreach_multiple_dynamic(self):
    """unsqueeze+cat (foreach-style pattern) with a dynamic batch dim,
    run at several batch sizes through one compiled artifact."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, y):
            x_unsqueeze = torch.unsqueeze(x, dim=0)
            y_unsqueeze = torch.unsqueeze(y, dim=0)
            cat = torch.cat([x_unsqueeze, y_unsqueeze], dim=0)
            return cat

    model = Model()
    x = torch.randn(128, 2048, device=self.device)
    y = torch.randn(128, 2048, device=self.device)
    dim0_x = Dim("dim0_x", min=1, max=2048)
    dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
    list_example_inputs = [(x, y)]
    list_example_inputs.append(
        (
            torch.randn(64, 2048, device=self.device),
            torch.randn(64, 2048, device=self.device),
        ),
    )
    list_example_inputs.append(
        (
            torch.randn(211, 2048, device=self.device),
            torch.randn(211, 2048, device=self.device),
        ),
    )
    self.check_model_with_multiple_inputs(
        model,
        list_example_inputs,
        dynamic_shapes=dynamic_shapes,
    )
|
|
|
|
# scaled_dot_product_flash_attention
@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
@unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
def test_sdpa(self):
    """Scaled dot-product attention (flash path) in bf16 compiles and matches."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, q, k, v):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v)[0]

    example_inputs = (
        torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
        torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
        torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
    )
    self.check_model(Model(), example_inputs)
|
|
|
|
@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
@unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
def test_sdpa_2(self):
    """Causal SDPA in bf16 whose output feeds a subsequent pointwise add."""

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, q, k, v, x):
            t = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, is_causal=True
            )[0]
            return x + t

    example_inputs = (
        torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
        torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
        torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
        torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device=self.device),
    )
    self.check_model(Model(), example_inputs)
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_linear(self):
    """FBGEMM dynamic-fp16 quantized linear under runtime constant folding."""

    class Model(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            self.weight = torch.randn(10, 10, device=device)
            self.bias = torch.randn(10, device=device)

        def forward(self, x):
            return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight(
                x, self.weight, self.bias
            )

    example_inputs = (torch.randn(10, 10, device=self.device),)
    with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
        self.check_model(Model(self.device), example_inputs)
|
|
|
|
@skipIfNoFBGEMM
# NOTE(review): method name misspells "quantized" as "quanatized"; left as-is
# because renaming would change the discovered test id.
def test_quanatized_int8_linear(self):
    """FBGEMM wrapped int8 quantized linear (explicit scales / zero points)
    under runtime constant folding."""

    class Model(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            self.weight = torch.randn(10, 10, device=device)
            self.bias = torch.randn(10, device=device)
            # Quantization parameters for input, weight, and output.
            self.input_scale = torch.tensor(0.1)
            self.input_zero_point = torch.tensor(0)
            self.weight_scale = torch.tensor(0.1)
            self.weight_zero_point = torch.tensor(0)
            self.output_scale = torch.tensor(0.1)
            self.output_zero_point = torch.tensor(0)
            self.out_channel = 10

        def forward(self, x):
            return torch.ops._quantized.wrapped_quantized_linear(
                x,
                self.input_scale,
                self.input_zero_point,
                self.weight,
                self.weight_scale,
                self.weight_zero_point,
                self.bias,
                self.output_scale,
                self.output_zero_point,
                self.out_channel,
            )

    example_inputs = (torch.randn(10, 10, device=self.device),)
    with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
        self.check_model(Model(self.device), example_inputs)
|
|
|
|
def test_zero_grid_with_unbacked_symbols(self):
    """Data-dependent (unbacked) leading dim from nonzero() feeding a matmul."""

    class Repro(torch.nn.Module):
        def forward(self, x, y):
            # nonzero() makes the first dimension an unbacked symint.
            nz = torch.nonzero(x)
            ones = torch.ones_like(nz, dtype=torch.float16)
            zeros = torch.zeros_like(nz, dtype=torch.float16)
            return ((ones + zeros) @ y).sum()

    example_inputs = (
        torch.tensor([1, 1, 1], device=self.device),
        torch.randn((1, 32), dtype=torch.float16, device=self.device),
    )
    self.check_model(Repro(), example_inputs)
@config.patch({"triton.autotune_at_compile_time": None})
|
|
def test_stride_with_unbacked_expr(self):
|
|
class Repro(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
u0 = x.item()
|
|
torch._check(u0 >= 1)
|
|
s0 = y.size(0)
|
|
expr = u0 * s0
|
|
sevens = torch.empty_strided(
|
|
size=(10, expr, 32), stride=(expr * 32, 32, 1), device=x.device
|
|
).fill_(7)
|
|
return sevens * 3
|
|
|
|
example_inputs = (
|
|
torch.scalar_tensor(2, dtype=torch.int, device=self.device),
|
|
torch.ones(8, device=self.device),
|
|
)
|
|
self.check_model(Repro(), example_inputs)
|
|
|
|
@skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
def test_fallback_kernel_with_symexpr_output(self):
    # The flash-attention fallback produces outputs whose sizes are symbolic
    # expressions of the dynamic input dims; compile with aot_compile and
    # compare the loaded module against eager.
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU")

    class Module(torch.nn.Module):
        def forward(self, q, k, v):
            # Reshapes mix two dynamic dims (2 and 3) into one symbolic size.
            q = q.reshape(
                q.shape[0],
                2,
                q.shape[2] * q.shape[3],
                q.shape[1] // 2,
            )
            k = k.reshape(
                k.shape[0],
                2,
                k.shape[2] * k.shape[3],
                k.shape[1] // 2,
            )
            v = v.reshape(
                v.shape[0],
                2,
                v.shape[2] * v.shape[3],
                v.shape[1] // 2,
            )

            res = torch.ops.aten._scaled_dot_product_flash_attention.default(
                q,
                k,
                v,
            )
            return res[0]

    m = Module().to(device=self.device)
    tensor_shape = (4, 32, 4, 4)
    inputs = (
        torch.randn(tensor_shape, dtype=torch.float16, device=self.device),
        torch.randn(tensor_shape, dtype=torch.float16, device=self.device),
        torch.randn(tensor_shape, dtype=torch.float16, device=self.device),
    )

    dynamic_shapes = {
        "q": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC},
        "k": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC},
        "v": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC},
    }
    ep = torch.export.export(m, inputs, dynamic_shapes=dynamic_shapes, strict=False)
    path = torch._inductor.aot_compile(ep.module(), inputs)
    aot_model = torch._export.aot_load(path, device=self.device)
    torch.testing.assert_close(m(*inputs), aot_model(*inputs))
def test_large_grid(self):
    # Input sized so the launch grid's y dimension reaches 65537 (see the
    # comment below), exceeding the 65535 per-dimension limit and exercising
    # grid wrapping in the generated launcher.
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU")

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, primals_5):
            view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
            primals_5 = None
            permute = torch.ops.aten.permute.default(view, [0, 2, 1])
            clone = torch.ops.aten.clone.default(
                permute, memory_format=torch.contiguous_format
            )
            return clone

    # let y_grid = 65537
    s0 = 16777472
    s1 = 8
    example_inputs = (torch.rand(s0, s1, device=self.device),)
    self.check_model(Model(), example_inputs)
def test_cond_simple(self):
    # torch.cond with one dynamic batch dim shared by both operands.
    inputs = (
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
    )
    dim0_ab = Dim("s0", min=2, max=1024)
    dynamic_shapes = {
        "p": {},
        "a": {0: dim0_ab, 1: None},
        "b": {0: dim0_ab, 1: None},
    }
    self.check_model_with_multiple_inputs(
        CondModels.Simple(),
        prepend_predicates(inputs),
        dynamic_shapes=dynamic_shapes,
    )

def test_cond_nested(self):
    # cond nested inside cond; three predicates, three operands.
    inputs = (
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
    )
    dim0_abc = Dim("s0", min=2, max=1024)
    dynamic_shapes = {
        "p0": {},
        "p1": {},
        "p2": {},
        "a": {0: dim0_abc, 1: None},
        "b": {0: dim0_abc, 1: None},
        "c": {0: dim0_abc, 1: None},
    }
    self.check_model_with_multiple_inputs(
        CondModels.Nested(),
        prepend_predicates(inputs, num_predicates=3),
        dynamic_shapes=dynamic_shapes,
    )

def test_cond_with_parameters(self):
    # cond branches that read module parameters.
    inputs = (torch.randn((10, 20), device=self.device),)
    dim0_abc = Dim("s0", min=2, max=1024)
    dynamic_shapes = {
        "p": {},
        "a": {0: dim0_abc, 1: None},
    }
    self.check_model_with_multiple_inputs(
        CondModels.Parameters(self.device),
        prepend_predicates(inputs),
        dynamic_shapes=dynamic_shapes,
    )

def test_cond_with_reinterpret_view_inputs_outputs(self):
    # cond whose branch inputs/outputs are reinterpret-views of the operands.
    inputs = (
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
    )
    dim0_ab = Dim("s0", min=3, max=1024)
    dynamic_shapes = {
        "p": {},
        "a": {0: dim0_ab, 1: None},
        "b": {0: dim0_ab, 1: None},
    }
    self.check_model_with_multiple_inputs(
        CondModels.ReinterpretView(),
        prepend_predicates(inputs),
        dynamic_shapes=dynamic_shapes,
    )

def test_cond_with_multiple_outputs(self):
    # cond branches returning multiple tensors; "c" has its own dynamic dim.
    inputs = (
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
        torch.randn((30, 40), device=self.device),
    )
    dim0_ab = Dim("s0", min=2, max=1024)
    dim0_c = Dim("s1", min=2, max=1024)
    dynamic_shapes = {
        "p": {},
        "a": {0: dim0_ab, 1: None},
        "b": {0: dim0_ab, 1: None},
        "c": {0: dim0_c, 1: None},
    }
    self.check_model_with_multiple_inputs(
        CondModels.MultipleOutputs(),
        prepend_predicates(inputs),
        dynamic_shapes=dynamic_shapes,
    )

def test_cond_with_outer_code_before_after(self):
    # Computation both before and after the cond in the same graph.
    inputs = (
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
    )
    dim0_ab = Dim("s0", min=2, max=1024)
    dynamic_shapes = {
        "p": {},
        "a": {0: dim0_ab, 1: None},
        "b": {0: dim0_ab, 1: None},
    }
    self.check_model_with_multiple_inputs(
        CondModels.OuterCode(),
        prepend_predicates(inputs),
        dynamic_shapes=dynamic_shapes,
    )

def test_cond_use_buffers_from_outer_scope(self):
    # cond branches reading buffers produced outside the cond.
    inputs = (
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
    )
    dim0_abc = Dim("s0", min=2, max=1024)
    dynamic_shapes = {
        "p": {},
        "a": {0: dim0_abc, 1: None},
        "b": {0: dim0_abc, 1: None},
        "c": {0: dim0_abc, 1: None},
    }
    self.check_model_with_multiple_inputs(
        CondModels.OuterBuffers(),
        prepend_predicates(inputs),
        dynamic_shapes=dynamic_shapes,
    )
@common_utils.parametrize("dynamic", [False, True])
|
|
def test_cond_non_tensor_predicates(self, dynamic):
|
|
inputs1 = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((15, 20), device=self.device),
|
|
)
|
|
inputs2 = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((5, 20), device=self.device),
|
|
)
|
|
inputs = (inputs1,)
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
inputs = (inputs1, inputs2)
|
|
dim0_a = Dim("s0", min=2, max=1024)
|
|
dim0_b = Dim("s1", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"a": {0: dim0_a, 1: None},
|
|
"b": {0: dim0_b, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.WithNonTensorPredicate(),
|
|
inputs,
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
@common_utils.parametrize("dynamic", [False, True])
|
|
def test_cond_unbacked_symint_closure(self, dynamic):
|
|
inputs = (
|
|
torch.randn((10, 20), device=self.device),
|
|
torch.randn((15, 20), device=self.device),
|
|
torch.randn((10, 20), device=self.device),
|
|
)
|
|
dynamic_shapes = None
|
|
if dynamic:
|
|
dim0_a = Dim("s0", min=2, max=1024)
|
|
dim0_b = Dim("s1", min=2, max=1024)
|
|
dynamic_shapes = {
|
|
"p": {},
|
|
"x": {0: dim0_a, 1: None},
|
|
"y": {0: dim0_b, 1: None},
|
|
"z": {0: dim0_a, 1: None},
|
|
}
|
|
self.check_model_with_multiple_inputs(
|
|
CondModels.UnbackedSymIntClosure(),
|
|
prepend_predicates(inputs),
|
|
dynamic_shapes=dynamic_shapes,
|
|
)
|
|
|
|
def test_cond_symint_input(self):
    # cond branches closing over symints (sizes of other inputs).
    class M(torch.nn.Module):
        def forward(self, x, y, z):
            a = y.shape[0]
            b = z.shape[0]

            def true_fn(x):
                return x + a

            def false_fn(x):
                return x + b * z

            return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))

    input1 = (
        torch.ones(3, 3, device=self.device),
        torch.ones(5, device=self.device),
        torch.ones(3, 3, device=self.device),
    )
    input2 = (
        torch.ones(10, 3, device=self.device),
        torch.ones(6, device=self.device),
        torch.ones(10, 3, device=self.device),
    )
    inputs = (input1, input2)
    # "x" and "z" use the same dim name "d", tying their first dimensions
    # together; "y" varies independently via "d1".
    dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
    self.check_model_with_multiple_inputs(
        M(),
        inputs,
        dynamic_shapes=dynamic_shapes,
    )
def test_while_loop_simple(self):
    # Basic while_loop with a counter carry and one dynamic batch dim.
    inputs = (
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
    )
    dim0_ab = Dim("s0", min=2, max=1024)
    dynamic_shapes = {
        "ci": {},
        "a": {0: dim0_ab, 1: None},
        "b": {0: dim0_ab, 1: None},
    }
    self.check_model_with_multiple_inputs(
        WhileLoopModels.Simple(),
        prepend_counters(inputs),
        dynamic_shapes=dynamic_shapes,
    )

def test_while_loop_nested(self):
    # while_loop nested inside another while_loop; two counter carries.
    inputs = (
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
    )
    dim0_ab = Dim("s0", min=2, max=1024)
    dynamic_shapes = {
        "ci": {},
        "cj": {},
        "a": {0: dim0_ab, 1: None},
        "b": {0: dim0_ab, 1: None},
    }
    self.check_model_with_multiple_inputs(
        WhileLoopModels.Nested(),
        prepend_counters(inputs, num_counters=2),
        dynamic_shapes=dynamic_shapes,
    )

def test_while_loop_with_outer_code(self):
    # Computation before and after the while_loop in the same graph.
    inputs = (
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
    )
    dim0_ab = Dim("s0", min=2, max=1024)
    dynamic_shapes = {
        "c": {},
        "a": {0: dim0_ab, 1: None},
        "b": {0: dim0_ab, 1: None},
    }
    self.check_model_with_multiple_inputs(
        WhileLoopModels.OuterCode(),
        prepend_counters(inputs),
        dynamic_shapes=dynamic_shapes,
    )

def test_while_loop_with_parameters(self):
    # Loop body reading module parameters.
    inputs = (torch.randn((10, 20), device=self.device),)
    dim0_a = Dim("s0", min=2, max=1024)
    dynamic_shapes = {
        "c": {},
        "a": {0: dim0_a, 1: None},
    }
    self.check_model_with_multiple_inputs(
        WhileLoopModels.Parameters(self.device),
        prepend_counters(inputs),
        dynamic_shapes=dynamic_shapes,
    )

def test_while_loop_with_outer_buffers(self):
    # Loop body reading buffers created outside the loop.
    inputs = (
        torch.randn((10, 20), device=self.device),
        torch.randn((10, 20), device=self.device),
    )
    # dynamic shapes don't work now due to
    # https://github.com/pytorch/pytorch/issues/123596
    # dim0_ab = Dim("s0", min=2, max=1024)
    # dynamic_shapes = {
    #     "c": {},
    #     "a": {0: dim0_ab, 1: None},
    #     "b": {0: dim0_ab, 1: None},
    # }
    dynamic_shapes = None
    self.check_model_with_multiple_inputs(
        WhileLoopModels.OuterBuffers(),
        prepend_counters(inputs),
        dynamic_shapes=dynamic_shapes,
    )

def test_while_loop_with_pytree_inputs(self):
    # Carried state is a pytree (list + dict) instead of flat tensors.
    inputs = (
        torch.tensor(0, device=self.device),
        (
            [torch.randn(10, 20, device=self.device)],
            {
                "x": torch.randn(10, 20, device=self.device),
                "y": torch.randn(10, 20, device=self.device),
            },
        ),
    )
    self.check_model_with_multiple_inputs(
        WhileLoopModels.PytreeCarry(),
        [inputs],
        dynamic_shapes=None,
    )

@common_utils.parametrize("dynamic", [False, True])
def test_while_loop_with_unbacked_symint_closure(self, dynamic):
    # Loop body closing over an unbacked symint.
    inputs = (
        torch.randn(10, 20, device=self.device),
        torch.randn(10, 20, device=self.device),
    )
    dim0_ab = Dim("s0", min=2, max=1024)
    dynamic_shapes = None
    if dynamic:
        dynamic_shapes = {
            "c": {},
            "a": {0: dim0_ab, 1: None},
            "b": {0: dim0_ab, 1: None},
        }
    self.check_model_with_multiple_inputs(
        WhileLoopModels.UnbackedSymIntClosure(),
        prepend_counters(inputs),
        dynamic_shapes=dynamic_shapes,
    )
@config.patch({"is_predispatch": True})
|
|
def test_constant(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self, device):
|
|
super().__init__()
|
|
self.device = device
|
|
|
|
def forward(self, x):
|
|
t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float)
|
|
t = torch.sqrt(t * 3)
|
|
return x * t
|
|
|
|
self.check_model(M(self.device), (torch.randn(5, 5, device=self.device),))
|
|
|
|
@unittest.skipIf(IS_MACOS, "no CUDA on Mac")
def test_zero_grid_with_backed_symbols(self):
    # Compile with a nonzero dynamic dim, then re-run the loaded module with
    # that dim == 0 so the kernel launches with an empty grid.
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU")

    class Repro(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, b):
            return x + b

    example_inputs = (
        torch.randn((3, 2), device=self.device),
        torch.randn((1, 2), device=self.device),
    )
    dynamic_shapes = {
        "x": {0: Dim("dx"), 1: Dim.STATIC},
        "b": None,
    }

    # Compile & run model where dynamic dim size > 0.
    so_path: str = AOTIRunnerUtil.compile(
        Repro(),
        example_inputs,
        dynamic_shapes=dynamic_shapes,
    )
    aot_inductor_module = AOTIRunnerUtil.load(self.device, so_path)
    aot_inductor_module(*example_inputs)

    # Re-run where dynamic dim size is 0.
    example_inputs = (
        torch.randn((0, 2), device=self.device),
        torch.randn((1, 2), device=self.device),
    )
    actual = aot_inductor_module(*example_inputs)
    expected = Repro()(*example_inputs)
    torch.testing.assert_close(actual, expected)
def test_repeat_interleave(self):
    """aten.repeat_interleave.Tensor with a statically-known output_size."""

    class Repro(torch.nn.Module):
        def forward(self, x):
            return torch.ops.aten.repeat_interleave.Tensor(x, output_size=12)

    repeats = torch.ones((1,), dtype=torch.int32, device=self.device) * 12
    self.check_model(Repro(), (repeats,))
def test_dynamic_cat(self):
    """cat along dim 0 where each input has an independent dynamic size."""

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.cat([a, b], dim=0)

    example_inputs = (
        torch.randn(2, 4, device=self.device),
        torch.randn(3, 4, device=self.device),
    )
    dynamic_shapes = {
        "a": {0: Dim("dim0_a", min=1, max=10)},
        "b": {0: Dim("dim0_b", min=1, max=20)},
    }
    self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
def test_buffer_mutation_1(self):
    """In-place mutation of a registered buffer inside forward."""

    class Model(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            self.foo = torch.nn.Buffer(torch.randn(4, 4, device=device))

        def forward(self, x):
            self.foo.add_(1)
            return self.foo + x

    inputs = (torch.rand(4, 4, device=self.device),)
    self.check_model(Model(self.device), inputs)
def test_non_tensor_input(self):
    # AOT-compile an op overload (aten.add) directly with a non-tensor kwarg
    # (alpha), under both cpp.simdlen settings; compare against an eager
    # module with the same alpha.
    class Model(torch.nn.Module):
        def forward(self, a, b, alpha=1.0):
            return torch.add(a, b, alpha=alpha)

    a = torch.randn(10, device=self.device)
    b = torch.randn(10, device=self.device)

    for simdlen in [0, None]:
        with torch._inductor.config.patch({"cpp.simdlen": simdlen}):
            so_path = torch._export.aot_compile(
                torch.ops.aten.add,
                args=(a, b),
                kwargs={"alpha": 2.0},
            )
            kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path)
            res = kernel_runner.run([a, b])
            self.assertTrue(isinstance(res, list))
            self.assertTrue(len(res) == 1)
            self.assertEqual(Model()(a, b, alpha=2.0), res[0])
def test_buffer_mutation_2(self):
    # Mutates one buffer in place and writes a single element of another.
    class Model(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            self.foo = torch.nn.Buffer(torch.arange(10, device=device))
            self.bar = torch.nn.Buffer(torch.arange(10, device=device))

        def forward(self, x):
            self.bar.mul_(2)
            self.foo[5] = self.bar[0]
            return x + self.bar, x * self.foo

    example_inputs = (torch.randn(10, device=self.device),)
    self.check_model(Model(self.device), example_inputs)

def test_buffer_mutation_3(self):
    # KV-cache-style scatter writes into buffers; also asserts the generated
    # wrapper contains exactly two "empty_strided" allocations.
    class KVCache(torch.nn.Module):
        def __init__(
            self,
            max_batch_size,
            max_seq_length,
            n_heads,
            head_dim,
            dtype=torch.float,
        ):
            super().__init__()
            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
            self.k_cache = torch.nn.Buffer(torch.zeros(cache_shape, dtype=dtype))
            self.v_cache = torch.nn.Buffer(torch.zeros(cache_shape, dtype=dtype))

        def update(self, input_pos, k_val, v_val):
            # input_pos: [S], k_val: [B, H, S, D]
            k_out = self.k_cache
            v_out = self.v_cache
            k_out[:, :, input_pos] = k_val
            v_out[:, :, input_pos] = v_val

            return k_out, v_out

    class Model(torch.nn.Module):
        def __init__(self, device):
            # NOTE(review): `device` is accepted but unused; the KVCache
            # buffers are created on the default device.
            super().__init__()
            self.kv_cache = KVCache(1, 256, 6, 48)

        def forward(self, inp_pos, k, v):
            self.kv_cache.update(inp_pos, k, v)
            return self.kv_cache.k_cache + 1, self.kv_cache.v_cache / 2

    example_inputs = (
        torch.tensor([0], device=self.device),
        torch.randn(1, 6, 1, 48, device=self.device),
        torch.randn(1, 6, 1, 48, device=self.device),
    )
    model = Model(self.device)
    self.check_model(model, example_inputs)
    self.code_check_count(model, example_inputs, "empty_strided", 2)

def test_buffer_mutation_4(self):
    # A CPU-resident buffer that forward moves to GPU device 0.
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU")

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.register_buffer(
                "_tensor_constant0",
                torch.randint(1, size=[38], dtype=torch.int64, device="cpu"),
            )

        def forward(self, x):
            return x + self._tensor_constant0.to(
                torch.device(type=GPU_TYPE, index=0)
            )

    example_inputs = (
        torch.randint(1, size=[38], dtype=torch.int64, device=GPU_TYPE),
    )
    torch._export.aot_compile(Model(), example_inputs)

@skipCUDAIf(True, "Test for x86 backend")
@skipIfXpu
def test_buffer_mutation_and_force_mmap_weights(self):
    # PT2E-quantized model compiled with freezing and memory-mapped weights;
    # buffer mutation must still work when weights are mmap-ed.
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(16, 15)
            self.linear2 = torch.nn.Linear(15, 14)

        def forward(self, x):
            x = self.linear1(x)
            out = self.linear2(x)
            return out

    example_inputs = (torch.randn(32, 16),)
    model = Model().eval()
    with config.patch(
        {"freezing": True, "aot_inductor.force_mmap_weights": True}
    ), torch.no_grad():
        exported_model = export_for_training(model, example_inputs).module()
        quantizer = X86InductorQuantizer()
        quantizer.set_global(
            xiq.get_default_x86_inductor_quantization_config(reduce_range=True)
        )
        prepared_model = prepare_pt2e(exported_model, quantizer)
        prepared_model(*example_inputs)
        converted_model = convert_pt2e(prepared_model)
        torch.ao.quantization.move_exported_model_to_eval(converted_model)

        self.check_model(converted_model, example_inputs)
@requires_multigpu()
def test_replicate_on_devices(self):
    # Compile once on GPU 0, then load and run the same .so on every
    # available GPU, comparing each result against a CPU reference.
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU")

    class Model(torch.nn.Module):
        def __init__(self, w1, w2):
            super().__init__()
            self.w1 = w1
            self.w2 = w2

        def forward(self, x, y):
            a = x * self.w1
            b = y * self.w2
            return a + b

    w1 = torch.randn(10, 10)
    w2 = torch.randn(10, 10)
    inputs = (torch.randn(10, 10), torch.randn(10, 10))
    result_cpu = Model(w1, w2)(*inputs)

    # Compile model with AOTInductor
    device_interface = get_interface_for_device(GPU_TYPE)
    with device_interface.device(0):
        so_path = AOTIRunnerUtil.compile(
            model=Model(
                w1.to(torch.device(GPU_TYPE, 0)), w2.to(torch.device(GPU_TYPE, 0))
            ),
            example_inputs=tuple(t.to(torch.device(GPU_TYPE, 0)) for t in inputs),
        )

    # Run model on gpu:N
    for i in range(device_interface.device_count()):
        with device_interface.device(i):
            example_inputs = tuple(t.to(torch.device(GPU_TYPE, i)) for t in inputs)
            optimized = AOTIRunnerUtil.load(GPU_TYPE, so_path)
            result_gpu = optimized(*example_inputs)
            self.assertTrue(same(result_cpu, result_gpu.cpu()))
@requires_multigpu()
def test_on_gpu_device1(self):
    # Compile and run entirely on the second GPU (index 1).
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU")

    device_interface = get_interface_for_device(GPU_TYPE)
    try:
        device_interface.get_device_properties(1)
    except AssertionError:
        raise unittest.SkipTest("GPU device 1 is not available") from None

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 16)
            self.relu = torch.nn.ReLU()
            self.fc2 = torch.nn.Linear(16, 1)
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.sigmoid(x)
            return x

    device = f"{GPU_TYPE}:1"
    model = Model().to(device)
    example_inputs = (torch.randn(8, 10, device=device),)
    expected = model(*example_inputs)

    so_path = AOTIRunnerUtil.compile(model, example_inputs)
    optimized = AOTIRunnerUtil.load(device, so_path)
    actual = optimized(*example_inputs)
    torch.testing.assert_close(actual, expected)
def test_pytree_inputs(self):
    """Dict-valued (pytree) input; the model returns a list of reductions."""

    class M(torch.nn.Module):
        def forward(self, x: dict[str, torch.Tensor]):
            dev = next(iter(x.values())).device
            total = torch.zeros(5, device=dev)
            product = torch.ones(5, device=dev)
            for value in x.values():
                total += value
                product *= value

            return [total, product]

    example = {
        "x": torch.ones(5, device=self.device),
        "y": torch.ones(5, device=self.device),
    }
    self.check_model(M(), (example,))
@requires_multigpu()
def test_non_default_gpu_device(self):
    # Compile and run the same model on GPU 0 and GPU 1; both runs must
    # match the CPU reference.
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU")

    class Model(torch.nn.Module):
        def __init__(self, weight):
            super().__init__()
            self.weight = weight

        def forward(self, x, y):
            return x + torch.nn.functional.linear(y, self.weight)

    weight = torch.randn(10, 10)
    inputs = (torch.randn(10, 10), torch.randn(10, 10))
    result_cpu = Model(weight)(*inputs)

    device_interface = get_interface_for_device(GPU_TYPE)
    with device_interface.device(0), torch.no_grad():
        result_gpu_0 = AOTIRunnerUtil.run(
            GPU_TYPE,
            Model(weight.to(torch.device(GPU_TYPE, 0))),
            tuple(t.to(torch.device(GPU_TYPE, 0)) for t in inputs),
        )

    with device_interface.device(1), torch.no_grad():
        result_gpu_1 = AOTIRunnerUtil.run(
            GPU_TYPE,
            Model(weight.to(torch.device(GPU_TYPE, 1))),
            tuple(t.to(torch.device(GPU_TYPE, 1)) for t in inputs),
        )

    self.assertTrue(same(result_cpu, result_gpu_0.cpu()))
    self.assertTrue(same(result_cpu, result_gpu_1.cpu()))
def test_reuse_kernel(self):
    # Two identical sin computations should generate a single kernel; on GPU,
    # the wrapper must contain exactly one loadKernel for the fused sin.
    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, y):
            a = torch.sin(x)
            b = torch.mm(a, y)
            c = torch.sin(b)
            d = torch.mm(b, c)
            return d

    example_inputs = (
        torch.randn(87, 87, device=self.device),
        torch.randn(87, 87, device=self.device),
    )
    model = Model()
    self.check_model(
        model, example_inputs, atol=1e-4, rtol=1e-4
    )  # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py

    if self.device == GPU_TYPE:
        self.code_check_count(
            model, example_inputs, "triton_poi_fused_sin_0 = loadKernel(", 1
        )
def test_reuse_kernel_dynamic(self):
    # Two structurally identical sub-graphs (same ops, different constants)
    # under a shared dynamic dim; generated kernels should be reused between
    # them without shape miscompilation.
    class Model(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            self.cst = torch.randn(48, device=device, dtype=torch.float)
            self.weights = torch.randn(6, 48, 48, device=device, dtype=torch.float)
            self.cst_1 = torch.randn(48, device=device, dtype=torch.float)
            self.weights_1 = torch.randn(
                6, 48, 48, device=device, dtype=torch.float
            )

        def forward(self, x, y, z):
            dim0 = x.size(1)
            # First copy of the sub-graph, using self.weights / self.cst.
            add_0 = z + z
            expand_2 = add_0.expand(-1, -1, 48)
            # [s0, 6, 48]
            mul_3 = add_0 * expand_2
            # [6, s0, 48]
            permute_4 = torch.permute(mul_3, (1, 0, 2))
            # [6, s0, 48]
            bmm_5 = torch.bmm(permute_4, self.weights)
            add_6 = bmm_5 + self.cst
            reshape_7 = torch.reshape(add_6, [6, dim0 * 6, 8])
            # [6*s0, 6, 8]
            permute_8 = torch.permute(reshape_7, (1, 0, 2))
            mul_9 = permute_8 * 0.123
            reshape_10 = torch.reshape(y, [8, dim0 * 6, 4])
            # [6*s0, 8, 4]
            permute_11 = torch.permute(reshape_10, (1, 0, 2))
            bmm_12 = torch.bmm(mul_9, permute_11)

            # Second, identical copy using self.weights_1 / self.cst_1.
            add_0_1 = z + z
            expand_2_1 = add_0_1.expand(-1, -1, 48)
            # [s0, 6, 48]
            mul_3_1 = add_0_1 * expand_2_1
            # [6, s0, 48]
            permute_4_1 = torch.permute(mul_3_1, (1, 0, 2))
            # [6, s0, 48]
            bmm_5_1 = torch.bmm(permute_4_1, self.weights_1)
            add_6_1 = bmm_5_1 + self.cst_1
            reshape_7_1 = torch.reshape(add_6_1, [6, dim0 * 6, 8])
            # [6*s0, 6, 8]
            permute_8_1 = torch.permute(reshape_7_1, (1, 0, 2))
            mul_9_1 = permute_8_1 * 0.123
            reshape_10_1 = torch.reshape(y, [8, dim0 * 6, 4])
            # [6*s0, 8, 4]
            permute_11_1 = torch.permute(reshape_10_1, (1, 0, 2))
            bmm_12_1 = torch.bmm(mul_9_1, permute_11_1)
            return bmm_12 + bmm_12_1

    x = torch.randn(6, 2, 48, device=self.device, dtype=torch.float)
    y = torch.randn(48, 2, 4, device=self.device, dtype=torch.float)
    z = torch.randn(2, 6, 1, device=self.device, dtype=torch.float)
    dim0 = Dim("dim0", min=1, max=2048)
    dynamic_shapes = {
        "x": {1: dim0},
        "y": {1: dim0},
        "z": {0: dim0},
    }

    example_inputs = (x, y, z)
    m = Model(self.device).to(dtype=torch.float)
    self.check_model(m, example_inputs, dynamic_shapes=dynamic_shapes)
def test_fake_tensor_device_validation(self):
    # Exporting on CPU then AOT-compiling with GPU inputs must raise a
    # device-mismatch ValueError rather than silently miscompiling.
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU")

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, y):
            return x + y

    example_inputs = (torch.randn(10, 10), torch.randn(10, 10))

    # Export on CPU
    exported_program = export(Model(), example_inputs, strict=True)

    # Compile exported model on GPU
    gm = exported_program.graph_module.to(self.device)
    with self.assertRaisesRegex(ValueError, "Device mismatch between fake input"):
        torch._inductor.aot_compile(
            gm, tuple(i.to(self.device) for i in example_inputs)
        )

def test_fx_gm_return_tuple_validation(self):
    # aot_compile requires make_fx-traced graphs to return a tuple; verify
    # the assertion message raised for a non-tuple output.
    from torch.fx.experimental.proxy_tensor import make_fx

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x, y):
            return x + y

    example_inputs = (torch.randn(10, 10), torch.randn(10, 10))

    gm = make_fx(Model(), tracing_mode="symbolic")(*example_inputs)
    with self.assertRaisesRegex(
        AssertionError,
        r"Graph output must be a tuple\(\). This is so that we can avoid "
        "pytree processing of the outputs.",
    ):
        torch._inductor.aot_compile(gm, example_inputs)
def test_consecutive_compiles(self):
    """Test that compilation behaves correctly with cache hits"""

    class TestModule(torch.nn.Module):
        def forward(self, x):
            return x + 1

    mod = TestModule()
    inp = torch.rand(1)
    mod(inp)
    traced = torch.fx.symbolic_trace(mod, concrete_args=[inp])
    # First compile populates the cache; the second should be a cache hit.
    for _ in range(2):
        so = torch._export.aot_compile(traced, (inp,))
        assert so is not None
def test_normal_functional(self):
    # Exercises lowering of aten.normal_functional.default (random sampling).
    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x):
            return torch.ops.aten.normal_functional.default(x)

    self.check_model(Model(), (torch.empty(4, 1, 4, 4, device=self.device),))

def test_empty_graph(self):
    # Identity model: the compiled graph contains no kernels at all.
    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x):
            return x

    example_inputs = (torch.randn(8, 4, 4, device=self.device),)
    self.check_model(Model(), example_inputs)
@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
def test_dup_unbacked_sym_decl(self):
    # Two index ops on the same boolean mask each introduce the same unbacked
    # size; codegen must not declare the symbol twice.
    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x):
            abs_1 = torch.ops.aten.abs.default(x)
            lt = torch.ops.aten.lt.Scalar(abs_1, 0.001)
            eq = torch.ops.aten.eq.Scalar(lt, 0)
            index_1 = torch.ops.aten.index.Tensor(x, [eq])
            sin = torch.ops.aten.sin.default(index_1)
            index_2 = torch.ops.aten.index.Tensor(x, [eq])
            div_3 = torch.ops.aten.div.Tensor(sin, index_2)
            return div_3

    example_inputs = (torch.randn(4, 4, 4, 4).to(self.device),)
    self.check_model(Model(), example_inputs)

# This exercises _eliminate_unbacked path in ShapeEnv
@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
def test_dup_unbacked_sym_decl_with_refinement(self):
    # Same graph as above, but torch._check refines the unbacked size to the
    # constant 4**4, eliminating the unbacked symbol.
    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x):
            abs_1 = torch.ops.aten.abs.default(x)
            lt = torch.ops.aten.lt.Scalar(abs_1, 0.001)
            eq = torch.ops.aten.eq.Scalar(lt, 0)
            index_1 = torch.ops.aten.index.Tensor(x, [eq])
            torch._check(index_1.size(0) == 4**4)
            sin = torch.ops.aten.sin.default(index_1)
            index_2 = torch.ops.aten.index.Tensor(x, [eq])
            div_3 = torch.ops.aten.div.Tensor(sin, index_2)
            return div_3

    example_inputs = (torch.ones(4, 4, 4, 4).to(self.device),)
    self.check_model(Model(), example_inputs)
def test_run_with_grad_enabled(self):
    # Compile under no_grad, then run the loaded module with grad enabled;
    # results must still match eager.
    class Model(torch.nn.Module):
        def forward(self, x, weight, bias):
            return torch.ops.aten.addmm(bias, weight, x)

    m = Model().to(device=self.device)
    x = torch.rand(8, 8, device=self.device, requires_grad=True)
    weight = torch.rand(8, 8, device=self.device, requires_grad=True)
    bias = torch.rand(8, device=self.device, requires_grad=True)
    example_inputs = (x, weight, bias)

    expected = m(*example_inputs)
    expected = pytree.tree_leaves(expected)

    # compiler under no_grad
    with torch.no_grad():
        so_path = AOTIRunnerUtil.compile(m, example_inputs)

    # run under grad enabled
    self.assertTrue(torch.is_grad_enabled())

    optimized = AOTIRunnerUtil.load(self.device, so_path)
    actual = optimized(*example_inputs)
    actual = pytree.tree_leaves(actual)

    self.assertTrue(same(actual, expected))
    def test_return_constant(self):
        """Return a clone of a module-level constant tensor alongside an input."""

        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                # Plain attribute (not a Parameter/Buffer) baked in as a constant.
                self.cst = torch.randn(5, 5, device=device)

            def forward(self, x):
                a = self.cst.clone()
                return (x, a)

        x = torch.randn(5, device=self.device)
        self.check_model(Model(self.device), (x,))
|
|
|
|
    def test_return_view_constant(self):
        """Return a view (transpose) of a module-level constant tensor."""

        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.cst = torch.randn(5, 5, device=device)

            def forward(self, x):
                # transpose is a view op, so the output aliases the constant.
                a = torch.transpose(self.cst, 0, 1)
                return (x, a)

        x = torch.randn(5, device=self.device)
        self.check_model(Model(self.device), (x,))
|
|
|
|
    def test_profile_benchmark_harness(self):
        """Compile with profile_bandwidth_output set but the benchmark harness
        disabled; the compiled function must still run."""
        batch_size = 32
        seq_length = 50
        hidden_size = 768

        def create_test_fn():
            def test_fn():
                inp = torch.randn(
                    batch_size, seq_length, hidden_size, device=self.device
                )
                weight = torch.randn(hidden_size, hidden_size, device=self.device)
                matmul_output = inp @ weight
                torch.nn.LayerNorm(hidden_size, device=self.device)(matmul_output)
                return True

            return test_fn

        fn = torch.compile(
            options={"profile_bandwidth_output": "foo", "benchmark_harness": False}
        )(create_test_fn())
        fn()
|
|
|
|
    def test_with_profiler(self):
        """Compile and check a simple linear model with bandwidth profiling
        enabled via config."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x, y):
                return x + self.linear(y)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        with config.patch({"profile_bandwidth": "1", "profile_bandwidth_regex": ""}):
            self.check_model(Model(), example_inputs)
|
|
|
|
    def test_with_no_triton_profiler(self):
        """Bandwidth profiling enabled on a model that lowers to no Triton
        kernels (a pure permute/view)."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.permute(x, (1, 0))

        example_inputs = (torch.randn(10, 10, device=self.device),)
        with config.patch({"profile_bandwidth": "1", "profile_bandwidth_regex": ""}):
            self.check_model(Model(), example_inputs)
|
|
|
|
    def test_repeat_output(self):
        """Return the same tensor twice; AOTI must handle duplicated outputs."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                y = torch.sin(x)
                return y, y

        example_inputs = (torch.randn(3, 10, device=self.device),)
        self.check_model(Model(), example_inputs)
|
|
|
|
    def test_repeated_calling(self):
        """Repeatedly invoke a packaged AOTI model and verify (via the CUDA
        memory snapshot) that allocations do not grow across calls."""
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        example_inputs = (torch.randn(10, 10, device=self.device),)
        optimized = torch._inductor.aoti_load_package(
            torch._inductor.aoti_compile_and_package(
                torch.export.export(Model(), example_inputs, strict=True)
            )
        )
        try:
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history(context=None)
            for _ in range(10):
                optimized(*example_inputs)
        finally:
            # Always stop recording, even if the calls above raise.
            torch.cuda.memory._record_memory_history(False)
        segments = torch.cuda.memory._snapshot()["segments"]
        # 10x10 float32 output = 400 bytes; only one such allocation expected.
        self.assertEqual(segments[0]["requested_size"], 400)
|
|
|
|
    def test_view_outputs(self):
        """Return a tensor together with same-size and different-size views of it."""

        class Model(torch.nn.Module):
            def forward(self, x):
                y = torch.sin(x)
                y_same_size = y.view(*y.shape)
                y_diff_size = y.view(1, *y.shape)
                return y, y_same_size, y_diff_size

        example_inputs = (torch.randn(3, 10, device=self.device),)
        self.check_model(Model(), example_inputs)
|
|
|
|
    @skip_if_no_torchvision
    def test_missing_cubin(self):
        """Compile a ResNet-style feature extractor returning intermediate
        feature maps (regression test for missing-cubin handling)."""
        from torchvision.models.resnet import Bottleneck, ResNet

        class Model(ResNet):
            def __init__(self) -> None:
                super().__init__(
                    block=Bottleneck,
                    layers=[3, 4, 6, 3],
                    replace_stride_with_dilation=[False, False, True],
                    norm_layer=None,
                )

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu(x)
                f1 = x
                x = self.maxpool(x)
                x = self.layer1(x)
                f2 = x
                x = self.layer2(x)
                f3 = x
                x = self.layer3(x)
                x = self.layer4(x)
                f4 = x
                return [f1, f2, f3, f4]

        # Call eval() here so that batch_norm won't update the running stats
        # Use float64 to avoid numeric difference failure
        model = Model().to(device=self.device, dtype=torch.float64).eval()
        example_inputs = (
            torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64),
        )
        self.check_model(model, example_inputs)
|
|
|
|
    def test_triton_next_power_of_2(self):
        """User-defined Triton kernel whose scalar argument is computed with
        triton.next_power_of_2 on a data-dependent value."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def forward(self, a, b, lengths):
                n_elements = a.numel()
                out = torch.empty_like(a)
                # Data-dependent scalar (.item()) fed through next_power_of_2.
                max_len = int(lengths.max())
                scaling_factor = triton.next_power_of_2(max_len)
                add_kernel_with_scaling[(n_elements,)](
                    a,
                    b,
                    out,
                    n_elements,
                    scaling_factor,
                    BLOCK_SIZE=16,
                )
                return out

        example_inputs = (
            torch.randn(2, device=self.device),
            torch.randn(2, device=self.device),
            torch.arange(end=4, device=self.device),
        )
        self.check_model(Model(), example_inputs)
|
|
|
|
    @common_utils.parametrize("grid_type", [1, 2, 3])
    @common_utils.parametrize("num_dims", [1, 2])
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("autotune", [False, True])
    def test_triton_kernel(self, grid_type, num_dims, dynamic, autotune):
        """User-defined Triton kernels across every combination of grid style
        (tuple / lambda / named function), dimensionality, dynamic shapes, and
        autotuning."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                output = torch.zeros_like(x)
                if autotune and num_dims == 2:
                    x_elements = output.size()[0]
                    y_elements = output.size()[1]
                else:
                    n_elements = output.numel()

                # Select grid
                if autotune and num_dims == 2:
                    if grid_type == 1:
                        # Plain tuple grid.
                        grid = (x_elements, y_elements)
                    elif grid_type == 2:
                        # Lambda grid using autotuned meta-params.
                        grid = lambda meta: (  # noqa: E731
                            triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                            triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                        )
                    else:
                        # Named-function grid.

                        def grid_fn(meta):
                            return (
                                triton.cdiv(x_elements, meta["BLOCK_SIZE_X"]),
                                triton.cdiv(y_elements, meta["BLOCK_SIZE_Y"]),
                            )

                        grid = grid_fn
                else:
                    if grid_type == 1:
                        grid = (n_elements,)
                    elif grid_type == 2:
                        grid = lambda meta: (  # noqa: E731
                            triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                        )
                    else:

                        def grid_fn(meta):
                            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

                        grid = grid_fn

                # Select kernel
                if autotune:
                    if num_dims == 1:
                        add_kernel_autotuned[grid](x, y, output, n_elements)
                    else:
                        add_kernel_2d_autotuned[grid](
                            x, y, output, x_elements, y_elements
                        )
                else:
                    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output

        dims = [10] * num_dims
        x = torch.randn(*dims, device=self.device)
        y = torch.randn(*dims, device=self.device)
        dynamic_shapes = []
        if dynamic:
            dim0_x = Dim("dim0_x", min=1, max=10)
            dim0_y = Dim("dim0_y", min=1, max=10)
            dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
        self.check_model(Model(), (x, y), dynamic_shapes=dynamic_shapes)
|
|
|
|
    def test_triton_kernel_dynamic_shape_with_div(self):
        """Triton kernel grid computed from a floor-divided dynamic size."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        @triton.jit
        def pass_kernel(x, num):
            pass

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                # Floor division of a dynamic size appears in the grid expr.
                num = x.numel() // 4

                grid = lambda meta: (triton.cdiv(num, 16),)  # noqa: E731
                pass_kernel[grid](x, num)
                return x

        x = torch.randn(10, device=self.device)
        dim0_x = Dim("dim0_x", min=1, max=10)
        dynamic_shapes = {"x": {0: dim0_x}}
        self.check_model(Model(), (x,), dynamic_shapes=dynamic_shapes)
|
|
|
|
    def test_triton_kernel_reinterpret_view(self):
        """Triton kernel receiving sliced (ReinterpretView) tensor arguments
        with different storage offsets."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        @triton.jit
        def pass_kernel(x, y):
            pass

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                out = torch.zeros_like(x[:, 4:])
                # the slicing below creates two ReinterpretView
                # instances: with offset=3 and offset=4
                add_kernel[(10,)](
                    in_ptr0=x[:, 3:-1],
                    in_ptr1=x[:, 4:],
                    out_ptr=out,
                    n_elements=160,
                    BLOCK_SIZE=16,
                )
                return out

        example_inputs = (torch.randn(10, 20, device=self.device),)
        self.check_model(Model(), example_inputs)
|
|
|
|
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_tma_descriptor_1d(self, dynamic):
        """Triton kernel taking host-side 1D TMA descriptors built from raw
        data pointers, with and without dynamic shapes."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")
        if not has_triton_tma():
            raise unittest.SkipTest("requires Triton TMA")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                BLOCK_SIZE = 256
                out = torch.zeros_like(a)
                n_elements = out.numel()

                # One descriptor per tensor, built from its raw data_ptr().
                desc_a, desc_b, desc_out = (
                    triton.tools.experimental_descriptor.create_1d_tma_descriptor(
                        t.data_ptr(),
                        n_elements,
                        BLOCK_SIZE,
                        t.element_size(),
                    )
                    for t in (a, b, out)
                )

                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                )
                add_kernel_with_tma_1d[grid](
                    desc_a,
                    desc_b,
                    desc_out,
                    BLOCK_SIZE=BLOCK_SIZE,
                )

                return out

        a = torch.randn(301, device=self.device)
        b = torch.randn(301, device=self.device)
        example_inputs = (a, b)

        dynamic_shapes = None
        if dynamic:
            dim0_ab = Dim("s0", min=2, max=1024)
            dynamic_shapes = {
                "a": {0: dim0_ab, 1: None},
                "b": {0: dim0_ab, 1: None},
            }

        self.check_model(
            Model(),
            example_inputs=example_inputs,
            dynamic_shapes=dynamic_shapes,
        )
|
|
|
|
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_tma_descriptor_2d(self, dynamic):
        """Triton kernel taking host-side 2D TMA descriptors, with and without
        dynamic shapes."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")
        if not has_triton_tma():
            raise unittest.SkipTest("requires Triton TMA")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                BLOCK_SIZE_X = 16
                BLOCK_SIZE_Y = 32
                out = torch.zeros_like(a)
                x_size, y_size = out.size()

                # One 2D descriptor per tensor, built from its raw data_ptr().
                desc_a, desc_b, desc_out = (
                    triton.tools.experimental_descriptor.create_2d_tma_descriptor(
                        t.data_ptr(),
                        x_size,
                        y_size,
                        BLOCK_SIZE_X,
                        BLOCK_SIZE_Y,
                        t.element_size(),
                    )
                    for t in (a, b, out)
                )

                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(x_size, meta["BLOCK_SIZE_X"]),
                    triton.cdiv(y_size, meta["BLOCK_SIZE_Y"]),
                )
                add_kernel_with_tma_2d[grid](
                    desc_a,
                    desc_b,
                    desc_out,
                    BLOCK_SIZE_X=BLOCK_SIZE_X,
                    BLOCK_SIZE_Y=BLOCK_SIZE_Y,
                )

                return out

        a = torch.randn((25, 16), device=self.device)
        b = torch.randn((25, 16), device=self.device)
        example_inputs = (a, b)

        dynamic_shapes = None
        if dynamic:
            dim0_ab = Dim("s0", min=2, max=1024)
            dynamic_shapes = {
                "a": {0: dim0_ab, 1: None},
                "b": {0: dim0_ab, 1: None},
            }

        self.check_model(
            Model(),
            example_inputs=example_inputs,
            dynamic_shapes=dynamic_shapes,
        )
|
|
|
|
    def test_triton_kernel_sympy_expr_arg(self):
        """Triton kernel whose scalar argument is a sympy expression derived
        from an unbacked .item() value."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def forward(self, x, e):
                # max(1, unbacked symint) becomes a sympy expression arg.
                sympy_expr = max(1, e.item())
                out = torch.zeros_like(x)
                add_kernel[(1,)](
                    in_ptr0=x,
                    in_ptr1=x,
                    out_ptr=out,
                    n_elements=sympy_expr,
                    BLOCK_SIZE=1,
                )
                return out

        NUMEL = 64
        inputs = (
            torch.randn(NUMEL, device=self.device),
            torch.tensor(NUMEL, device=self.device),
        )
        self.check_model(Model(), inputs)
|
|
|
|
    def test_triton_kernel_sympy_fn_like_arg(self):
        # This test should hit sympy.expand("sqrt") which crashes with
        # AttributeError: 'function' object has no attribute 'expand'.
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def forward(self, x):
                out = torch.zeros_like(x)
                add_kernel_with_optional_param[1,](
                    in_ptr0=x,
                    in_ptr1=x,
                    out_ptr=out,
                    n_elements=x.numel(),
                    BLOCK_SIZE=1,
                    ARGS_PASSED="sqrt",  # sqrt is a valid sympy fn
                )
                return out

        inputs = (torch.randn(4, device=self.device),)
        self.check_model(Model(), inputs)
|
|
|
|
    def test_triton_kernel_with_none_input(self):
        """Triton kernel invoked with None for an optional pointer argument."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                n_elements = x.size()[0]
                BLOCK_SIZE = 1024

                output_wo_y = torch.empty_like(x)
                output_with_y = torch.empty_like(x)

                # First call passes None for the optional second input.
                add_kernel_with_optional_param[(1,)](
                    x,
                    None,
                    output_wo_y,
                    n_elements,
                    ARGS_PASSED="one",
                    BLOCK_SIZE=BLOCK_SIZE,
                )
                add_kernel_with_optional_param[(1,)](
                    x,
                    y,
                    output_with_y,
                    n_elements,
                    ARGS_PASSED="two",
                    BLOCK_SIZE=BLOCK_SIZE,
                )

                return 2.71 * output_wo_y + 3.14 * output_with_y

        example_inputs = (
            torch.randn(1023, device=self.device),
            torch.randn(1023, device=self.device),
        )

        self.check_model(Model(), example_inputs)
|
|
|
|
    def test_triton_kernel_equal_to_1_arg(self):
        """Triton kernel where n_elements == 1 (Triton specializes args equal
        to 1)."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def forward(self, x, y):
                out = torch.empty_like(x)
                n_elements = x.numel()
                add_kernel[(n_elements,)](x, y, out, n_elements, BLOCK_SIZE=16)
                return out

        # Single-element inputs make n_elements exactly 1.
        example_inputs = (
            torch.randn(1, device=self.device),
            torch.randn(1, device=self.device),
        )

        self.check_model(Model(), example_inputs)
|
|
|
|
    @common_utils.parametrize("dynamic", [False, True])
    def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
        """Triton kernel with a float argument that evaluates to exactly 1.0
        (floats must not be specialized like int 1)."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def forward(self, x, y):
                out = torch.empty_like(x)
                n_elements = x.numel()
                # Always 1.0, but as a float expression.
                scaling_factor = (n_elements**0) / 1.0
                add_kernel_with_scaling[(n_elements,)](
                    x,
                    y,
                    out,
                    n_elements,
                    scaling_factor,
                    BLOCK_SIZE=16,
                )
                return out

        dynamic_shapes = None
        if dynamic:
            dim0_xy = Dim("s0", min=2, max=1024)
            dynamic_shapes = {
                "x": {0: dim0_xy, 1: None},
                "y": {0: dim0_xy, 1: None},
            }
        example_inputs = (
            torch.randn(2, device=self.device),
            torch.randn(2, device=self.device),
        )
        self.check_model(
            Model(),
            example_inputs,
            dynamic_shapes=dynamic_shapes,
        )
|
|
|
|
    def test_triton_kernel_weird_param_order(self):
        """Autotuned Triton kernel called with keyword args in a different
        order than the kernel's parameter declaration."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                out = torch.empty_like(x)
                # out_ptr intentionally passed last, after n_elements.
                add_kernel_autotuned_weird_param_order[16,](
                    in_ptr0=x,
                    in_ptr1=x,
                    n_elements=x.numel(),
                    out_ptr=out,
                )
                return out

        x = torch.randn(16, 16, device=self.device)
        self.check_model(Model(), (x,))
|
|
|
|
    def test_shifted_constraint_ranges(self):
        """Derived dynamic dims: y's dim0 is expressed as x's dim0 + 1, with a
        matching torch._check inside the model."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(
                self,
                x: torch.Tensor,
                y: torch.Tensor,
            ):
                torch._check(y.size(0) == x.size(0) + 1)
                return x.sum(0) + y.sum(0)

        a = torch.randn((4, 5), device=self.device)
        b = torch.randn((5, 5), device=self.device)
        dim0_x = Dim("dim0_x", min=2, max=1024)
        # Derived Dim: dim0_y is always dim0_x + 1.
        dim0_y = dim0_x + 1
        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
        self.check_model(
            Model(),
            (a, b),
            dynamic_shapes=dynamic_shapes,
        )
|
|
|
|
    def test_scatter_fallback(self):
        """torch.scatter on int64 tensors (exercises the scatter fallback path)."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(
                self,
                inp: torch.Tensor,
                index: torch.Tensor,
                src: torch.Tensor,
            ):
                return torch.scatter(inp, 1, index, src)

        inputs = (
            torch.ones((3, 5), device=self.device, dtype=torch.int64),
            torch.tensor([[0, 1, 2, 0]], device=self.device, dtype=torch.int64),
            torch.zeros((2, 5), device=self.device, dtype=torch.int64),
        )

        self.check_model(Model(), inputs)
|
|
|
|
    def test_scatter_reduce_fallback(self):
        """torch.scatter_reduce with reduce='sum' on int64 tensors (exercises
        the scatter_reduce fallback path)."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(
                self,
                inp: torch.Tensor,
                index: torch.Tensor,
                src: torch.Tensor,
            ):
                return torch.scatter_reduce(inp, 0, index, src, reduce="sum")

        inputs = (
            torch.tensor([1, 10, 100, 1000], device=self.device, dtype=torch.int64),
            torch.tensor([0, 1, 0, 1, 2, 1], device=self.device, dtype=torch.int64),
            torch.tensor([1, 2, 3, 4, 5, 6], device=self.device, dtype=torch.int64),
        )

        self.check_model(Model(), inputs)
|
|
|
|
    def test_index_put_fallback(self):
        """torch.index_put with accumulate=True under deterministic mode,
        where it must fall back instead of using the Inductor lowering."""
        # index_put falls back in the deterministic mode
        with DeterministicGuard(True):

            class Model(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()

                def forward(
                    self,
                    self_tensor: torch.Tensor,
                    indices: tuple[torch.Tensor],
                    values: torch.Tensor,
                ):
                    return torch.index_put(
                        self_tensor, indices, values, accumulate=True
                    )

            inputs = (
                torch.ones(4, device=self.device, dtype=torch.int64),
                # Boolean mask index (nonzero ints coerce to True).
                (torch.tensor([1, 1, 2, 2], device=self.device, dtype=torch.bool),),
                torch.ones(4, device=self.device, dtype=torch.int64),
            )

            self.check_model(Model(), inputs)
|
|
|
|
    def test_repeated_user_defined_triton_kernel(self):
        """Same user-defined Triton kernel launched multiple times, mutating
        its input in place each time."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                for _ in range(3):
                    mul2_inplace_kernel[4,](x, n_elements=4, BLOCK_SIZE=16)
                return x

        inputs = (torch.randn(4, 4, device=self.device),)
        self.check_model(Model(), inputs)
|
|
|
|
    def test_convolution(self):
        """Transposed 1D convolution (aten.convolution with transposed=True)
        under max-autotune with the Triton GEMM backend."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, w, b):
                # stride=[4], padding=[0], dilation=[1], transposed=True,
                # output_padding=[0], groups=1
                return torch.ops.aten.convolution(x, w, b, [4], [0], [1], True, [0], 1)

        example_inputs = (
            torch.randn([2, 32, 90], device=self.device),
            torch.randn([32, 16, 8], device=self.device),
            torch.randn([16], device=self.device),
        )
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "Triton",
            }
        ):
            self.check_model(Model(), example_inputs)
|
|
|
|
    def test_zero_size_weight(self):
        """Squeeze-and-excitation style block where channel // r == 0 makes the
        first Linear weight zero-sized."""

        class Model(torch.nn.Module):
            def __init__(self, channel, r=8):
                super().__init__()
                self.pool = torch.nn.AdaptiveAvgPool2d(1)
                # With channel=4 and r=8, channel // r == 0: zero-size weights.
                self.net = torch.nn.Sequential(
                    torch.nn.Linear(channel, channel // r, bias=False),
                    torch.nn.ReLU(inplace=True),
                    torch.nn.Linear(channel // r, channel, bias=False),
                    torch.nn.Sigmoid(),
                )

            def forward(self, inp):
                b, c, _, _ = inp.shape
                x = self.pool(inp).view(b, c)
                x = self.net(x).view(b, c, 1, 1)
                x = inp * x
                return x

        inputs = (torch.rand(4, 4, 4, 4, device=self.device),)
        self.check_model(Model(4), inputs)
|
|
|
|
    def test_zero_size_buffer(self):
        """Model returning a zero-sized (0, 0) buffer as one of its outputs."""

        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.foo = torch.nn.Buffer(torch.zeros((0, 0), device=device))

            def forward(self, x):
                return x + 1, self.foo

        example_inputs = (torch.rand(4, 4, device=self.device),)
        self.check_model(Model(self.device), example_inputs)
|
|
|
|
    def test_no_args(self):
        """Model taking no forward arguments; output depends only on parameters."""

        class Model(torch.nn.Module):
            def __init__(self, m, n):
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.randn(m, n),
                )
                self.alpha = torch.nn.Parameter(torch.randn(m, n))

            def forward(self):
                return self.weight * self.alpha

        self.check_model(Model(6, 4), ())
|
|
|
|
    def test_dynamic_scalar(self):
        """Model calling .item() on a computed loss (dynamic scalar output
        packed into a dict)."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.criterion_ce = torch.nn.CrossEntropyLoss(reduction="none")

            def forward(self, inputs, targets, split_index=None):
                statistics = {}
                total_loss = self.criterion_ce(inputs, targets).sum()
                # .item() creates a dynamic (data-dependent) scalar.
                statistics["dl"] = total_loss.item()
                return total_loss, statistics

        inputs = (
            torch.rand(4, 4, 4, 4, device=self.device),
            torch.rand(4, 4, 4, 4, device=self.device),
        )
        self.check_model(Model(), inputs)
|
|
|
|
    def test_symint_item(self):
        """.item() on an int tensor producing a SymInt output."""

        class Model(torch.nn.Module):
            def forward(self, tensor):
                return tensor.item()

        inputs = (torch.tensor([1], dtype=torch.int, device=self.device),)
        self.check_model(Model(), inputs)
|
|
|
|
    def test_symbool_item(self):
        """.item() on a bool tensor producing a SymBool output."""

        class Model(torch.nn.Module):
            def forward(self, tensor):
                return tensor.item()

        inputs = (torch.tensor([0], dtype=torch.bool, device=self.device),)
        self.check_model(Model(), inputs)
|
|
|
|
    def test_symfloat_item(self):
        """.item() on a float tensor producing a SymFloat output."""

        class Model(torch.nn.Module):
            def forward(self, tensor):
                return tensor.item()

        inputs = (torch.tensor([3.14], dtype=torch.float, device=self.device),)
        self.check_model(Model(), inputs)
|
|
|
|
    def test_constant_original_fqn_and_dtype(self):
        """The compiled runner reports the original FQNs and dtypes of all
        constants, including a parameter named with the digit '0'."""

        class FooBarModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Parameter name "0" is only reachable via getattr.
                self.register_parameter("0", torch.nn.Parameter(torch.randn(3, 4)))
                self.test_buf = torch.nn.Buffer(torch.randn(3, 4))
                self.register_parameter(
                    "test_param", torch.nn.Parameter(torch.randn(3, 4))
                )

            def forward(self, x):
                return ((x + self.test_buf) * getattr(self, "0")) / self.test_param

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo_bar = FooBarModule()
                self.register_parameter(
                    "test_param", torch.nn.Parameter(torch.randn(3, 4))
                )
                self.test_buf = torch.nn.Buffer(torch.randn(3, 4))

            def forward(self, x):
                return (self.foo_bar(x) + self.test_param) * self.test_buf

        with torch.no_grad():
            so_path = AOTIRunnerUtil.compile(
                model=TestModule().to(device=self.device),
                example_inputs=(torch.rand(3, 4, device=self.device),),
            )

        runner = AOTIRunnerUtil.load_runner(self.device, so_path)

        expected_original_fqns = {
            "L__self___test_param": "test_param",
            "L__self___test_buf": "test_buf",
            "getattr_L__self___foo_bar___0__": "foo_bar.0",
            "L__self___foo_bar_test_param": "foo_bar.test_param",
            "L__self___foo_bar_test_buf": "foo_bar.test_buf",
        }
        self.assertEqual(
            expected_original_fqns, runner.get_constant_names_to_original_fqns()
        )

        # 6 is the scalar-type code for float32 in these mappings.
        expected_dtypes = {
            "L__self___test_param": 6,
            "L__self___test_buf": 6,
            "getattr_L__self___foo_bar___0__": 6,
            "L__self___foo_bar_test_param": 6,
            "L__self___foo_bar_test_buf": 6,
        }
        self.assertEqual(expected_dtypes, runner.get_constant_names_to_dtypes())
|
|
|
|
    def test_masked_select_dynamic(self):
        """torch.masked_select (data-dependent output size) combined with
        dynamic input dims."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                mask = x.ge(0.5)
                return torch.masked_select(x, mask)

        example_args = (torch.randn(3, 4, 5, device=self.device),)
        dim0_x_max, dim1_x_max = 100, 7
        dynamic_shapes = {
            "x": {
                0: Dim("dim0_x", max=dim0_x_max),
                1: Dim("dim1_x_max", max=dim1_x_max),
            }
        }
        m = M()
        self.check_model(m, example_args, dynamic_shapes=dynamic_shapes)
|
|
|
|
    def test_fqn(self):
        """Nested submodules with parameters and buffers at multiple depths,
        exercising fully-qualified-name handling for constants."""

        class NestedChild(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nestedchild3buffer = torch.nn.Buffer(torch.ones(2, 3) * 3)

            def forward(self, x):
                return x / self.nestedchild3buffer

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3) * 2)

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3) * 4)
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        self.check_model(MyModule(), (torch.randn(2, 3, device=self.device),))
|
|
|
|
    def test_model_modified_weights(self):
        """Recompilation after in-place weight mutation: the second check must
        see the updated weights, not stale ones baked into a cached .so."""

        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                # Plain attributes (not Parameters) so they are treated as constants.
                self.weight = torch.randn(n, k, device=device)
                self.bias = torch.randn(n, device=device)

            def forward(self, a):
                return torch.nn.functional.linear(a, self.weight, self.bias)

        M = 16
        N = 10
        K = 128
        example_inputs = (torch.randn(2, M, K, device=self.device),)
        model = Model(N, K, self.device)
        self.check_model(model, example_inputs)
        # Update model weights, after this AOTInductor should re-generate model.so
        # if weights are stored in the model.so
        model.weight += 1
        self.check_model(model, example_inputs)
|
|
|
|
    def test_triton_kernel_extern_kernel_arg(self):
        """Triton kernel whose input is produced by an extern kernel (torch.mm)."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def forward(self, x, y):
                out = torch.zeros_like(x)
                # torch.mm is ExternKernelOut
                add_kernel[(4,)](x, torch.mm(x, y), out, 4, 16)
                return out

        example_inputs = (
            torch.randn(4, 4, device=GPU_TYPE),
            torch.randn(4, 4, device=GPU_TYPE),
        )

        self.check_model(Model(), example_inputs)
|
|
|
|
    def test_triton_kernel_multi_output_arg(self):
        """Triton kernel whose input is one output of a multi-output fallback
        kernel (torch.sort)."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def forward(self, x, y):
                out = torch.zeros_like(x)
                # torch.sort creates fallback kernel and hence MultiOutput
                add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16)
                return out

        example_inputs = (
            torch.randn(4, 4, device=GPU_TYPE),
            torch.randn(4, 4, device=GPU_TYPE),
        )

        self.check_model(Model(), example_inputs)
|
|
|
|
    # @skipIfXpu(msg="torch.xpu.memory_allocated not supported yet")
    def test_triton_kernel_reinterpret_view_mem_leak(self):
        # Check for memory leak when using user-defined Triton Kernel + AOTI.
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                out = torch.zeros_like(x)
                yy = y * y
                # reshape creates a ReinterpretView
                add_kernel[(4,)](x, yy.reshape_as(x), out, 4, 16)
                return out

        example_inputs = (
            torch.randn(4, 4, device=GPU_TYPE),
            torch.randn(1, 16, device=GPU_TYPE),
        )

        so_path: str = AOTIRunnerUtil.compile(
            Model(),
            example_inputs,
        )
        aot_inductor_module = AOTIRunnerUtil.load(GPU_TYPE, so_path)

        # Don't assign outputs to a variable b/c it will allocate GPU memory.
        device_interface = get_interface_for_device(GPU_TYPE)
        device: int = device_interface.current_device()
        mem_before = device_interface.memory_allocated(device)
        aot_inductor_module(*example_inputs)
        aot_inductor_module(*example_inputs)
        mem_after = device_interface.memory_allocated(device)
        # No net allocation growth after two calls.
        self.assertEqual(mem_before, mem_after)

        actual = aot_inductor_module(*example_inputs)
        expected = Model()(*example_inputs)
        torch.testing.assert_close(actual, expected)
|
|
|
|
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @common_utils.parametrize("dynamic", [False, True])
    @common_utils.parametrize("autotuning", [False, True])
    def test_triton_kernel_unbacked_symint_in_grid(self, dynamic, autotuning):
        """Triton kernel whose grid size is an unbacked SymInt (from .item()),
        with and without dynamic shapes / autotuning."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def forward(self, x, y, n_elements_tensor):
                output = torch.zeros_like(x)
                # Unbacked symint used in the grid computation only.
                n_elements_symint = n_elements_tensor.item()
                n_elements = x.numel()

                def grid(meta):
                    return (triton.cdiv(n_elements_symint, meta["BLOCK_SIZE"]),)

                if autotuning:
                    add_kernel_autotuned[grid](
                        x,
                        y,
                        output,
                        n_elements,
                    )
                else:
                    add_kernel[grid](
                        x,
                        y,
                        output,
                        n_elements,
                        BLOCK_SIZE=16,
                    )

                return output

        example_inputs = (
            torch.randn(123, device=GPU_TYPE),
            torch.randn(123, device=GPU_TYPE),
            torch.tensor(123),
        )

        dynamic_shapes = None
        if dynamic:
            dim0 = Dim("s0", min=2, max=1024)
            dynamic_shapes = {
                "x": {0: dim0},
                "y": {0: dim0},
                "n_elements_tensor": {},
            }

        self.check_model(
            Model(),
            example_inputs,
            dynamic_shapes=dynamic_shapes,
        )
|
|
|
|
    def test_scaled_dot_product_efficient_attention(self):
        """aten._scaled_dot_product_efficient_attention with an explicit
        attention bias tensor."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def forward(self, q, k, v, attn_bias):
                # Only the attention output (index 0) is returned.
                return torch.ops.aten._scaled_dot_product_efficient_attention(
                    q, k, v, attn_bias, False
                )[0]

        example_inputs = (
            torch.randn(4, 4, 36, 36, device=GPU_TYPE),
            torch.randn(4, 4, 36, 36, device=GPU_TYPE),
            torch.randn(4, 4, 36, 36, device=GPU_TYPE),
            torch.randn(4, 4, 36, 36, device=GPU_TYPE),
        )
        self.check_model(Model(), example_inputs)
|
|
|
|
    def test_index_put_with_none_index(self):
        """aten.index_put with leading None indices (implicit full slices)
        and accumulate=True under deterministic mode."""
        # index_put falls back in the deterministic mode
        with DeterministicGuard(True):

            class Model(torch.nn.Module):
                def forward(self, x, i1, i2, y):
                    return torch.ops.aten.index_put(
                        x,
                        (None, None, i1, i2.transpose(0, 1)),
                        y,
                        accumulate=True,
                    )

            example_inputs = (
                torch.rand(8, 192, 30, 30, device=self.device),
                torch.zeros(3, 14, 1, 1, dtype=torch.int64, device=self.device),
                torch.ones(14, 3, dtype=torch.int64, device=self.device),
                torch.randn(8, 192, 3, 14, 3, 14, device=self.device),
            )
            self.check_model(Model(), example_inputs)
|
|
|
|
    def test_runtime_checks(self):
        """With aot_inductor.debug_compile, the generated C++ wrapper must
        contain runtime guards for every input: dtype checks, dim-value
        checks (exact and range), and stride checks. The test greps the
        emitted .cpp source for the expected number of each guard, then
        runs the compiled model to confirm it still produces correct output.
        """

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            # bfloat16 (x3) is only exercised on SM80+; the model signature
            # shrinks by one input on older GPUs.
            if SM80OrLater:

                def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9):
                    return (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9)

            else:

                def forward(self, x0, x1, x2, x4, x5, x6, x7, x8, x9):
                    return (x0, x1, x2, x4, x5, x6, x7, x8, x9)

        inputs = []
        # One input per dtype, so every dtype gets its own runtime dtype check.
        dtypes = [
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
        ]
        if SM80OrLater:
            dtypes.append(torch.bfloat16)
        for dtype in dtypes:
            inputs.append(torch.ones(4, 8, 10, dtype=dtype, device=self.device))

        # Mark different dims dynamic on different inputs so both exact-value
        # and min/max range checks are generated.
        dim0 = Dim("s0", min=2, max=1024)
        dim1 = Dim("s1", min=2, max=512)
        dim2 = Dim("s2", min=2, max=128)
        dynamic_shapes = {
            "x0": {0: dim0},
            "x1": {0: dim0},
            "x2": {0: dim0},
            "x4": {1: dim1},
            "x5": {1: dim1},
            "x6": {},
            "x7": {2: dim2},
            "x8": {2: dim2},
            "x9": {2: dim2},
        }
        if SM80OrLater:
            dynamic_shapes["x3"] = {1: dim1}

        m = Model()
        inputs = tuple(inputs)
        with torch.no_grad(), config.patch(
            {
                "aot_inductor.debug_compile": True,
            }
        ):
            so_path = AOTIRunnerUtil.compile(m, inputs, dynamic_shapes=dynamic_shapes)
        # Inspect the generated C++ next to the .so for the expected guards.
        with open(os.path.splitext(so_path)[0] + ".cpp") as cpp:
            src_code = cpp.read()
            FileCheck().check_count(
                "unmatched dtype",
                10 if SM80OrLater else 9,
                exactly=True,
            ).run(src_code)
            FileCheck().check_count(
                "unmatched dim value at",
                21
                if SM80OrLater
                else 19,  # we have 9 dynamic dims for which we generate different checks
                exactly=True,
            ).run(src_code)
            FileCheck().check_count(
                "dim value is too",
                18
                if SM80OrLater
                else 16,  # we have 9 dynamic dims for which we generate two checks
                exactly=True,
            ).run(src_code)
            FileCheck().check_count(
                "unmatched stride value at",
                21
                if SM80OrLater
                else 19,  # we have 9 symbolic strides for which we don't generate checks
                exactly=True,
            ).run(src_code)
        optimized = AOTIRunnerUtil.load(self.device, so_path)
        actual = optimized(*inputs)
        expected = m(*inputs)
        torch.testing.assert_close(actual, expected)
|
|
|
|
    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FP8,
        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
    )
    def test_runtime_checks_fp8(self):
        """Runtime input checks must also work for FP8 dtypes: compile with
        debug_compile and a dynamic batch dim, then verify via check_model."""
        # cuda only
        if self.device != "cuda":
            return

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x0, x1):
                # Upcast to float since FP8 arithmetic is not generally
                # supported directly.
                t = x0.to(torch.float) + x1.to(torch.float)
                return t

        inputs = []
        for dtype in (
            torch.float8_e4m3fn,
            torch.float8_e5m2,
            # FP8 funz are for AMD
            # see https://github.com/pytorch/pytorch/issues/126734
            # torch.float8_e4m3fnuz,
            # torch.float8_e5m2fnuz,
        ):
            inputs.append(torch.ones(8, 8, 8, dtype=dtype, device=self.device))
        dim0 = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "x0": {0: dim0},
            "x1": {0: dim0},
        }
        with torch.no_grad(), config.patch(
            {
                "aot_inductor.debug_compile": True,
            }
        ):
            self.check_model(
                Model(),
                tuple(inputs),
                dynamic_shapes=dynamic_shapes,
            )
|
|
|
|
    def test_runtime_checks_complex(self):
        """Runtime input checks for complex dtypes (complex32/64/128),
        including a 0-d complex tensor and one dynamic dim."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x0, x1, x2):
                # Identity passthrough: only the input guards are under test.
                return (x0, x1, x2)

        inputs = []
        x0 = torch.tensor([1, -1], dtype=torch.complex32, device=self.device)
        x1 = torch.tensor(
            [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1],
            dtype=torch.complex64,
            device=self.device,
        )
        # 0-d (scalar) complex tensor.
        x2 = torch.tensor(128, dtype=torch.complex128, device=self.device)
        inputs.append(x0)
        inputs.append(x1)
        inputs.append(x2)
        dim0 = Dim("s0", min=2, max=1024)
        dynamic_shapes = {
            "x0": {0: dim0},
            "x1": {},
            "x2": {},
        }
        with torch.no_grad(), config.patch(
            {
                "aot_inductor.debug_compile": True,
            }
        ):
            self.check_model(
                Model(),
                tuple(inputs),
                dynamic_shapes=dynamic_shapes,
            )
|
|
|
|
    @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
    def test_runtime_checks_dtype_failed(self):
        """Negative test: a model compiled for float16 input must reject a
        float32 input at runtime when debug_compile checks are enabled."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                y = x.type(torch.float)
                return y

        x = torch.randn(1, 4, dtype=torch.float16, device=self.device)
        model = Model()
        with torch.no_grad(), config.patch(
            {
                "aot_inductor.debug_compile": True,
            }
        ):
            so_path: str = AOTIRunnerUtil.compile(
                model,
                (x,),
            )
        aot_inductor_module = AOTIRunnerUtil.load(self.device, so_path)
        x_casted = x.float()
        # Any exception is acceptable; only the rejection itself is asserted.
        with self.assertRaisesRegex(Exception, ""):
            aot_inductor_module(x_casted)
|
|
|
|
    def test_non_contiguous_output_alias(self):
        """Returning both a non-contiguous view and its .contiguous() copy:
        the copy must be a distinct allocation, not an alias."""

        # Test return x, x.contiguous() where x is non-contiguous.
        class Model(torch.nn.Module):
            def forward(self, x):
                squared = x * x
                transposed = squared.t()  # non-contiguous
                contig = transposed.contiguous()
                return transposed, contig

        x = torch.randn(3, 4, dtype=torch.float16, device=self.device)
        model = Model()
        with torch.no_grad():
            result = AOTIRunnerUtil.run(
                self.device,
                model,
                (x,),
            )
        actual = model(x)
        self.assertTrue(same(result, actual))

        # contiguous() should create a new tensor
        self.assertTrue(result[0].data_ptr() != result[1].data_ptr())
|
|
|
|
    def test_multiple_output_alias(self):
        """Multiple outputs aliasing one tensor must share storage in the
        AOTI result; a genuinely new tensor must not alias them."""

        # Test when multiple outputs alias the same tensor
        class Model(torch.nn.Module):
            def forward(self, x):
                squared = x * x
                contig = squared.contiguous()  # alias
                reshaped = squared.reshape(squared.shape)  # alias
                cubed = squared * x
                return squared, contig, reshaped, cubed

        x = torch.randn(3, 4, dtype=torch.float32, device=self.device)
        model = Model()

        with torch.no_grad():
            result = AOTIRunnerUtil.run(
                self.device,
                model,
                (x,),
            )
        actual = model(x)
        self.assertTrue(same(result, actual))

        # squared, contig and reshaped alias the same tensor.
        self.assertTrue(result[0].data_ptr() == result[1].data_ptr())
        self.assertTrue(result[0].data_ptr() == result[2].data_ptr())
        # cubed shouldn't be an alias.
        self.assertTrue(result[0].data_ptr() != result[3].data_ptr())
|
|
|
|
    def test_runtime_checks_shape_failed(self):
        """Negative shape tests: a model compiled with one dynamic dim must
        accept a new size within the Dim range but reject a changed static
        dim, a changed stride layout, and sizes outside the declared range.
        """

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x

        x = torch.randn(4, 4, 4, dtype=torch.float16, device=self.device)
        # y0: only the dynamic dim differs -> must be accepted.
        y0 = torch.randn(8, 4, 4, dtype=torch.float16, device=self.device)
        # y1: a static dim differs -> must be rejected.
        y1 = torch.randn(4, 8, 4, dtype=torch.float16, device=self.device)
        # y2: same sizes but non-standard strides -> must be rejected.
        y2 = rand_strided(
            (4, 4, 4), (16, 1, 4), dtype=torch.float16, device=self.device
        )
        # batch size is outside of the range
        y3 = torch.randn(2048, 3, 4, dtype=torch.float16, device=self.device)
        y4 = torch.randn(2048, 4, 4, dtype=torch.float16, device=self.device)
        dim0 = Dim("s0", min=4, max=1024)
        dynamic_shapes = {
            "x": {0: dim0},
        }
        model = Model()
        with torch.no_grad(), config.patch(
            {
                "aot_inductor.debug_compile": True,
            }
        ):
            so_path: str = AOTIRunnerUtil.compile(
                model, (x,), dynamic_shapes=dynamic_shapes
            )
        aot_inductor_module = AOTIRunnerUtil.load(self.device, so_path)
        # dynamic dim works fine
        _ = aot_inductor_module(y0)
        with self.assertRaisesRegex(Exception, ""):
            aot_inductor_module(y1)
        with self.assertRaisesRegex(Exception, ""):
            aot_inductor_module(y2)
        with self.assertRaisesRegex(Exception, ""):
            aot_inductor_module(y3)
        with self.assertRaisesRegex(Exception, ""):
            aot_inductor_module(y4)
|
|
|
|
def test_add_complex(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, a, b):
|
|
return torch.add(a, b)
|
|
|
|
x = torch.tensor(
|
|
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], device=self.device
|
|
)
|
|
y = torch.tensor(
|
|
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], device=self.device
|
|
)
|
|
self.check_model(Model(), (x, y))
|
|
|
|
def test_embedding_bag(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, w, i, o):
|
|
return torch.ops.aten._embedding_bag(w, i, o, False, 0, False, None)
|
|
|
|
example_inputs = (
|
|
torch.randn([10, 4], device=self.device),
|
|
torch.randint(10, [8], device=self.device),
|
|
torch.tensor([0, 2, 6], device=self.device),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_fft_c2c(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.fft.fftn(x), torch.fft.fftn(x).real
|
|
|
|
example_inputs = (torch.randn(16, 16, 16, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_bool_input(self):
|
|
# Specialize on whichever branch the example input for b is
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, b):
|
|
if b:
|
|
return x * x
|
|
else:
|
|
return x + x
|
|
|
|
example_inputs = (torch.randn(3, 3, device=self.device), True)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
def test_int_list_input(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, i):
|
|
return x * i[0] * i[1]
|
|
|
|
example_inputs = (torch.randn(3, 3, device=self.device), [3, 4])
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
    def test_nested_tensor_from_jagged(self):
        """Build a jagged NestedTensor inside the model from (values, offsets)
        and run an MLP over it, across four input variants that vary the data
        size and the NT batch size, with both dims dynamic."""

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mlp = nn.Sequential(
                    nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32), nn.Sigmoid()
                )

            def forward(self, values, offsets):
                # offsets partitions the packed values rows into NT components.
                nt = torch.nested.nested_tensor_from_jagged(values, offsets)
                res = self.mlp(nt)
                return res.values()

        model = Model().to(device=self.device)

        example_inputs_1 = (
            torch.randn((15, 128), device=self.device),
            torch.tensor([0, 3, 4, 10, 15], device=self.device),
        )

        # same "NT batch size", different actual amount of data
        example_inputs_2 = (
            torch.randn((31, 128), device=self.device),
            torch.tensor([0, 1, 20, 25, 31], device=self.device),
        )

        # same actual amount of data, different "NT batch size"
        example_inputs_3 = (
            torch.randn((15, 128), device=self.device),
            torch.tensor([0, 3, 10, 15], device=self.device),
        )

        # different "NT batch size"
        example_inputs_4 = (
            torch.randn((37, 128), device=self.device),
            torch.tensor([0, 5, 16, 25, 29, 37], device=self.device),
        )

        dim0_values = Dim("dim0_values", min=1, max=128)
        dim0_offsets = Dim("dim0_offsets", min=1, max=9)
        dynamic_shapes = {"values": {0: dim0_values}, "offsets": {0: dim0_offsets}}
        example_inputs_list = [
            example_inputs_1,
            example_inputs_2,
            example_inputs_3,
            example_inputs_4,
        ]

        self.check_model_with_multiple_inputs(
            model, example_inputs_list, dynamic_shapes=dynamic_shapes
        )
|
|
|
|
    @common_utils.parametrize("max_autotune", [True, False])
    def test_misc_1(self, max_autotune):
        """A mixed MLP + EmbeddingBag + over-arch model, compiled both with
        and without max_autotune."""
        if self.device == "cpu" and IS_MACOS and max_autotune:
            raise unittest.SkipTest("max_autotune not supported on macos")

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mlp = nn.Sequential(
                    nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32), nn.Sigmoid()
                )
                self.emb = nn.EmbeddingBag(num_embeddings=128, embedding_dim=32)
                self.over_arch = nn.Sequential(
                    nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 32), nn.Sigmoid()
                )

            def forward(self, x, y):
                # Dense path (mlp) and sparse path (emb) are concatenated
                # feature-wise before the over-arch.
                mlp_output = self.mlp(x)
                emb_output = self.emb(y)
                return self.over_arch(torch.concat([mlp_output, emb_output], dim=1))

        example_inputs = (
            torch.randn(16, 128, device=self.device),
            torch.randint(0, 128, (16, 10), device=self.device),
        )
        self.check_model(
            Model(), example_inputs, options=dict(max_autotune=max_autotune)
        )
|
|
|
|
    @skip_if_no_torchvision
    def test_torchvision_transforms_functional_tensor_resize(self):
        """Bilinear resize with a data-dependent target size: the size comes
        from a tensor's .tolist(), bounded via torch._check so export can
        reason about the unbacked symints."""
        import torchvision

        # https://fb.workplace.com/groups/1075192433118967/permalink/1501860707118802/
        class A(torch.nn.Module):
            def forward(self, image: torch.Tensor, target_size: torch.Tensor):
                target_h, target_w = target_size.tolist()
                # Bound the unbacked sizes so the compiler can allocate.
                torch._check(target_h > 0)
                torch._check(target_w > 0)
                torch._check(target_h <= 4000)
                torch._check(target_w <= 4000)

                return torchvision.transforms._functional_tensor.resize(
                    image,
                    size=[target_h, target_w],
                    interpolation="bilinear",
                    antialias=False,
                )

        model = A()
        example_inputs = (
            torch.ones([3, 800, 600], device=self.device),
            torch.tensor([448, 336], device=self.device),
        )
        dynamic_shapes = {
            "image": {
                1: torch.export.Dim("height", min=1, max=4000),
                2: torch.export.Dim("width", min=1, max=4000),
            },
            "target_size": None,
        }
        self.check_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
|
|
|
|
    def test_aoti_debug_printer_codegen(self):
        """The AOTI intermediate-value debug printer must emit
        before_launch/after_launch prints around each kernel call, both in
        print-everything mode and when filtered to a single kernel name."""

        # basic addmm model to test codegen for aoti intermediate debug printer
        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                self.weight = torch.randn(n, k, device=device)
                self.bias = torch.randn(n, device=device)

            def forward(self, a):
                return torch.nn.functional.linear(a, self.weight, self.bias)

        M = 8
        N = 6
        K = 16
        model = Model(N, K, self.device)
        batch = 2
        a = torch.randn(batch, M, K, device=self.device)
        example_inputs = (a,)

        # Expected (kernel name, number of debug-print occurrences) pairs;
        # GPU additionally codegens a pointwise triton kernel.
        kernel_calls = (
            [
                ("triton_poi_fused_0", 1),
                (f"aoti_torch_{GPU_TYPE}_addmm_out", 2),
            ]
            if self.device == GPU_TYPE
            else [
                ("aoti_torch_cpu_addmm_out", 2),
            ]
        )

        # test default debug printing all tensor values codegen
        with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
            result, code = run_and_get_cpp_code(
                AOTIRunnerUtil.compile, model, example_inputs
            )

            # check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected
            self.assertEqual("aoti_torch_print_tensor_handle" in code, True)

            # check the codegen for debug printing around the actual kernel call is expected

            for kernel_call, count in kernel_calls:
                FileCheck().check_count(
                    f"before_launch - {kernel_call}",
                    count,
                ).run(code)
                FileCheck().check_count(
                    f"after_launch - {kernel_call}",
                    count,
                ).run(code)

        # test printing selected kernel's tensor values codegen
        filtered_kernel_name = f"aoti_torch_{self.device}_addmm_out"
        with config.patch(
            {
                "aot_inductor.debug_intermediate_value_printer": "2",
                "aot_inductor.filtered_kernel_names": filtered_kernel_name,
            }
        ):
            result, code = run_and_get_cpp_code(
                AOTIRunnerUtil.compile, model, example_inputs
            )
            filtered_kernel_calls = [
                (filtered_kernel_name, 2),
            ]
            for kernel_call, count in filtered_kernel_calls:
                FileCheck().check_count(
                    f"before_launch - {kernel_call}",
                    count,
                ).run(code)
                FileCheck().check_count(
                    f"after_launch - {kernel_call}",
                    count,
                ).run(code)

            # Kernels outside the filter must not be printed at all.
            kernel_calls_not_to_print = [
                kernel_call
                for kernel_call in kernel_calls
                if kernel_call[0] != filtered_kernel_name
            ]
            for kernel_name, _ in kernel_calls_not_to_print:
                FileCheck().check_not(f"before_launch - {kernel_name}").run(code)
                FileCheck().check_not(f"after_launch - {kernel_name}").run(code)
|
|
|
|
    def test_aoti_debug_printer_user_defined_triton_kernel(self):
        """Debug printer must also wrap user-defined triton kernel launches
        (not just inductor-generated kernels)."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                out = torch.zeros_like(x)
                # User-defined kernel launch with an explicit grid.
                add_kernel[(4,)](x, y, out, n_elements=4, BLOCK_SIZE=16)
                return out

        example_inputs = (
            torch.randn(4, 4, device=self.device),
            torch.randn(4, 4, device=self.device),
        )

        # 3 tensor args (x, y, out) get printed around the single launch.
        kernel_calls = [
            ("add_kernel_0", 3),
        ]

        with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
            result, code = run_and_get_cpp_code(
                AOTIRunnerUtil.compile, Model(), example_inputs
            )
            # check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected
            self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
            # check the codegen for debug printing around the actual kernel call is expected
            for kernel_call, count in kernel_calls:
                FileCheck().check_count(
                    f"before_launch - {kernel_call}",
                    count,
                ).run(code)
                FileCheck().check_count(
                    f"after_launch - {kernel_call}",
                    count,
                ).run(code)
|
|
|
|
    def test_aoti_debug_printer_cpp_kernel(self):
        """Debug printer codegen for a fused CPU C++ kernel."""
        if self.device != "cpu":
            raise unittest.SkipTest("cpu test case only")

        # a simple cpp kernel test case for testing the debug printer codegen
        # on cpp kernel cpu device.
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                t = torch.tensor(x.size(-1), device="cpu", dtype=torch.float)
                t = torch.sqrt(t * 3)
                return x * t

        example_inputs = (torch.randn(4, 4, device="cpu"),)

        # Single fused mul+sqrt cpp kernel; 2 tensor handles printed per side.
        kernel_calls = [
            ("cpp_fused_mul_sqrt_0", 2),
        ]

        with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
            result, code = run_and_get_cpp_code(
                AOTIRunnerUtil.compile, Model(), example_inputs
            )
            # check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected
            self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
            # check the codegen for debug printing around the actual kernel call is expected
            for kernel_call, count in kernel_calls:
                FileCheck().check_count(
                    f"before_launch - {kernel_call}",
                    count,
                ).run(code)
                FileCheck().check_count(
                    f"after_launch - {kernel_call}",
                    count,
                ).run(code)
|
|
|
|
    def test_aoti_debug_printer_sym_inputs(self):
        """Debug printer must also print symbolic scalar kernel arguments
        (xnumel expressions, unbacked symint max() expression)."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        from torch.testing._internal.triton_utils import add_kernel

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                # Data-dependent size clamped to at least 512.
                maxlen = max(x.item(), 512)
                a = torch.ones(maxlen, device=GPU_TYPE)
                b = torch.ones(maxlen, device=GPU_TYPE)
                out = torch.zeros_like(a)
                # unbacked symint in grid
                add_kernel[(1, 1, maxlen)](a, b, out, maxlen, 32)
                return out

        example_inputs = (torch.randint(high=1024, size=(1,), device=self.device),)

        # Scalar expressions expected to appear (twice each) in the printer output.
        expected_scalar_args = [
            "triton_poi_fused_zeros_like_0_xnumel",
            "triton_poi_fused_1_xnumel",
            "std::max(static_cast<int64_t>(512L), static_cast<int64_t>(u0))",
        ]

        with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
            result, code = run_and_get_cpp_code(
                AOTIRunnerUtil.compile, Model(), example_inputs
            )
            self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
            for scalar in expected_scalar_args:
                FileCheck().check_count(
                    f"{scalar}",
                    2,
                ).run(code)
|
|
|
|
    def test_aoti_debug_printing_model_inputs_codegen(self):
        """Debug printer must also print the model's top-level inputs
        (tagged 'aoti_model_inputs'), one print per input tensor."""
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b, c):
                x = a * 3.14
                y = torch.addmm(c, x, b)
                z = torch.nn.functional.gelu(y)
                return z

        example_inputs = (
            torch.randn(10, 20, device="cuda"),
            torch.randn(20, 30, device="cuda"),
            torch.randn(10, 30, device="cuda"),
        )
        model = Model()
        # 3 model inputs -> 3 "aoti_model_inputs" print sites.
        kernel_calls = [
            ("aoti_model_inputs", 3),
        ]

        with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}):
            result, code = run_and_get_cpp_code(
                AOTIRunnerUtil.compile, model, example_inputs
            )
            self.assertEqual("aoti_torch_print_tensor_handle" in code, True)
            # check the codegen for debug printing around aoti model inputs is expected
            for kernel_call, count in kernel_calls:
                FileCheck().check_count(
                    f"{kernel_call}",
                    count,
                ).run(code)
|
|
|
|
def test_size_from_multi_output(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
_x, _i = torch.unique(x, sorted=True, return_inverse=True)
|
|
_x = _x.detach().clone()
|
|
return self.relu(_x), _i
|
|
|
|
example_inputs = (torch.randn(8, device=self.device),)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
    @dynamo_config.patch({"capture_scalar_outputs": True})
    def test_sym_i64_input_codegen(self):
        """An int32 scalar input feeding i64-typed kernel args: the cpp
        wrapper must explicitly declare each kernel var arg as int64_t."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        from torch.testing._internal.triton_utils import add_kernel

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                x_symint = x.item()
                a = torch.ones(x_symint, device=GPU_TYPE)
                b = torch.ones(x_symint, device=GPU_TYPE)
                out = torch.zeros_like(a)
                # unbacked symint in grid
                add_kernel[(1, 1, x_symint)](a, b, out, x_symint, 32)
                return out

        example_inputs = (
            torch.randint(high=1024, size=(1,), device=self.device, dtype=torch.int32),
        )
        # This simple unit test case model generates two triton kernels:
        # 1. triton_poi_fused_ones_1:
        # triton_meta={'signature': {'out_ptr0': '*fp32', 'xnumel': 'i64'}
        # 2. add_kernel:
        # triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr': '*fp32', 'n_elements': 'i64'}
        # input u0 was defined as int32_t initially, verify for every kernel var args downstream,
        # it gets explicitly declared using its data types in the cpp wrapper codegen code.
        expected_scalar_args = [
            "int64_t var_1 = u0;",
            "int64_t var_3 = u0;",
            "int64_t var_5 = u0;",
            "int64_t var_9 = u0;",
        ]
        # check the new behavior of codegen is expected
        result, code = run_and_get_cpp_code(
            AOTIRunnerUtil.compile, Model(), example_inputs
        )
        for scalar_line in expected_scalar_args:
            FileCheck().check_count(
                scalar_line,
                1,
            ).run(code)

        self.check_model(Model(), example_inputs)
|
|
|
|
    @common_utils.parametrize("mark_unbacked", (True, False))
    def test_unbacked_equals_input_size_runtime_assertion(self, mark_unbacked: bool):
        """Runtime assertions for unbacked-symint equalities must be codegen'ed
        and enforced at run time."""

        # This test checks the unbacked symint runtime assertions, for the following cases:
        # (A) an unbacked symint equals an unbacked symint (mark_unbacked=True)
        # (B) an unbacked symint equals a backed symint (mark_unbacked=False)
        class Model(torch.nn.Module):
            def forward(self, a, b, c):
                # nonzero produces an unbacked size that must equal c.size(0)
                # for the add to be valid.
                nz = torch.nonzero(a)
                ones = a.new_ones([nz.size(0), b.size(0)])
                torch._check(ones.size(0) >= 1)
                equals = torch.add(ones, c)
                return equals

        model = Model()
        example_inputs = (
            torch.ones(64, device=self.device),
            b := torch.randn((32,), device=self.device),
            c := torch.randn((64, 32), device=self.device),
        )
        if mark_unbacked:
            torch._dynamo.decorators.mark_unbacked(c, 0)
        else:
            torch._dynamo.mark_dynamic(c, 0)

        # Check the runtime assertion is codegen'ed.
        so_path, code = run_and_get_cpp_code(
            AOTIRunnerUtil.compile, model, example_inputs
        )
        lowerbound_check = "u1 >= 1" if mark_unbacked else "u0 >= 2"
        FileCheck().check_count(lowerbound_check, 1).run(code)

        compiled = AOTIRunnerUtil.load(self.device, so_path)
        compiled(*example_inputs)

        # Check the runtime assertion.
        with self.assertRaisesRegex(Exception, ""):
            # An all-zeros `a` makes nonzero's count 0, violating the check.
            unexpected_inputs = (torch.ones(0, device=self.device), b, c)
            compiled(*unexpected_inputs)
|
|
|
|
    def test_none_args_aot_codegen(self):
        """User triton kernel with an optional (possibly-None) pointer arg:
        AOTI removes None args from the launch arg list, so constexpr/eq-1
        arg indices must be recomputed correctly for both specializations."""
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("requires GPU")

        @triton.autotune(
            configs=[
                triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2),
                triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
            ],
            key=["n_elements"],
        )
        @triton.jit
        def sin_kernel(
            in_ptr0,
            out_ptr,
            # We want to include an arg known to be 1 at compile time
            # This is because we remove None args from the arg list; changing the eq_1/constexpr arg indices.
            # We want to make sure we recompute these correctly
            EQ_1_ARG,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            if in_ptr0 is not None:
                x = tl.load(in_ptr0 + offsets, mask=mask)
            else:
                x = 0.0
            output = tl.sin(x) + EQ_1_ARG
            tl.store(out_ptr + offsets, output, mask=mask)

        def sin_triton(x, out):
            n_elements = out.numel()
            sin_kernel[(n_elements,)](x, out, 1, n_elements)
            return out

        x = torch.randn(65, device=self.device)
        out = torch.empty_like(x)

        not_none_inputs = (x, out)
        none_inputs = (None, out)

        # AOTI compilation specializes on either None or non-None inputs
        # So we have to check twice here

        self.check_model(sin_triton, none_inputs)
        self.check_model(sin_triton, not_none_inputs)
|
|
|
|
    def test_issue_140766(self):
        """Regression test for issue #140766: transformer-style block with
        SDPA + LayerNorm + MLP and reshape/transpose shuffles, compiled with
        a dynamic batch dim."""

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mlp = torch.nn.Sequential(
                    torch.nn.Linear(128, 512),
                    torch.nn.ReLU(),
                    torch.nn.Linear(512, 128),
                )
                self.norm = torch.nn.LayerNorm(128)
                self.attn = torch.nn.functional.scaled_dot_product_attention

            def forward(self, x):
                # [2, 128, 4096]
                x = x.transpose(1, 2)
                # [2, 4096, 128]
                for _ in range(2):
                    x = self.forward_block(x)
                return x

            def forward_block(self, x):
                # x: B, H*W, C
                B = x.shape[0]
                H, W, C = 64, 64, 128
                shortcut = x
                x = self.norm(x)
                x = x.reshape(B, H, W, C)
                # B, H, W, C
                x = self.attn(x, x, x)
                # Window-style 8x8 block shuffle before flattening back.
                x = x.reshape(B, H // 8, W // 8, 8, 8, -1)
                x = x.transpose(2, 3).reshape(B, H * W, -1)

                x = shortcut + x
                x = x + self.mlp(self.norm(x))
                return x

        bs = torch.export.Dim("bs", max=12)
        example_inputs = (torch.randn(2, 128, 4096, device=self.device),)
        self.check_model(Model(), example_inputs, dynamic_shapes={"x": {0: bs}})
|
|
|
|
    def test_so_without_weight(self):
        """Compile the same model with and without packaging constants into
        the .so; the weightless .so must be much smaller, expose the original
        constant FQNs, and produce correct results once weights are attached
        via update_constant_buffer."""

        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                self.weight = torch.randn(n, k, device=device)
                self.bias = torch.randn(n, device=device)

            def forward(self, a):
                return torch.nn.functional.linear(a, self.weight, self.bias)

        # Large N*K so the packed weights dominate the .so size.
        M, N, K = 128, 2048, 4096
        model = Model(N, K, self.device)
        a = torch.randn(M, K, device=self.device)
        example_inputs = (a,)
        with torch.no_grad(), config.patch(
            {
                "always_keep_tensor_constants": True,
                "aot_inductor.package_constants_in_so": True,
            }
        ):
            so_path = AOTIRunnerUtil.compile(
                model=model,
                example_inputs=example_inputs,
            )

        with torch.no_grad(), config.patch(
            {
                "always_keep_tensor_constants": True,
                "aot_inductor.package_constants_in_so": False,
            }
        ):
            so_path_weightless = AOTIRunnerUtil.compile(
                model=model,
                example_inputs=example_inputs,
            )
        # ~N*K*4 bytes of weights: with-weights .so is >10MB, weightless is not.
        self.assertTrue(os.path.getsize(so_path) > 10_000_000)
        self.assertTrue(os.path.getsize(so_path_weightless) < 10_000_000)

        runner = AOTIRunnerUtil.load_runner(self.device, so_path_weightless)

        # Let's check whether the model has correct constant name mapping.
        expected_original_fqns = {
            "L__self___weight": "L__self___weight",
            "L__self___bias": "L__self___bias",
        }
        self.assertEqual(
            expected_original_fqns, runner.get_constant_names_to_original_fqns()
        )

        def runner_call(*args, **kwargs):
            # Flatten args per the runner's call spec, run, and unflatten.
            import torch.fx._pytree as fx_pytree

            call_spec = runner.get_call_spec()
            in_spec = pytree.treespec_loads(call_spec[0])
            out_spec = pytree.treespec_loads(call_spec[1])
            flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
            flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
            flat_outputs = runner.run(flat_inputs)
            return pytree.tree_unflatten(flat_outputs, out_spec)

        test_inputs = torch.randn(M, K, device=self.device)
        # Attach the real weights to the weightless runner before calling it.
        attach_weights = {
            "L__self___weight": model.weight,
            "L__self___bias": model.bias,
        }
        runner.update_constant_buffer(attach_weights, False, False)
        expected = model(test_inputs)
        output = runner_call(test_inputs)
        self.assertEqual(expected, output)
|
|
|
|
    def test_update_constant_buffer(self):
        """Swap the model's weights at runtime via update_constant_buffer and
        verify the compiled model computes with the new weights."""

        class Model(torch.nn.Module):
            def __init__(self, n, k, device):
                super().__init__()
                self.weight = torch.randn(n, k, device=device)
                self.bias = torch.randn(n, device=device)

            def forward(self, a):
                return torch.nn.functional.linear(a, self.weight, self.bias)

        M, N, K = 8, 6, 16
        model = Model(N, K, self.device)
        a = torch.randn(M, K, device=self.device)
        example_inputs = (a,)
        with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}):
            so_path = AOTIRunnerUtil.compile(
                model=model,
                example_inputs=example_inputs,
            )

        runner = AOTIRunnerUtil.load_runner(self.device, so_path)

        # Let's check whether the model has correct constant name mapping.
        expected_original_fqns = {
            "L__self___weight": "L__self___weight",
            "L__self___bias": "L__self___bias",
        }
        self.assertEqual(
            expected_original_fqns, runner.get_constant_names_to_original_fqns()
        )

        def runner_call(*args, **kwargs):
            # Flatten args per the runner's call spec, run, and unflatten.
            import torch.fx._pytree as fx_pytree

            call_spec = runner.get_call_spec()
            in_spec = pytree.treespec_loads(call_spec[0])
            out_spec = pytree.treespec_loads(call_spec[1])
            flat_inputs = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
            flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
            flat_outputs = runner.run(flat_inputs)
            return pytree.tree_unflatten(flat_outputs, out_spec)

        # Baseline: compiled output matches eager with the original weights.
        test_inputs = torch.randn(M, K, device=self.device)
        expected = model(test_inputs)
        output = runner_call(test_inputs)
        self.assertEqual(expected, output)

        # Now swap in fresh weights and compare against eager linear.
        new_weights = {
            "L__self___weight": torch.randn(N, K, device=self.device),
            "L__self___bias": torch.randn(N, device=self.device),
        }
        runner.update_constant_buffer(new_weights, False, False)
        new_output = runner_call(test_inputs)
        new_expected = torch.nn.functional.linear(
            test_inputs, new_weights["L__self___weight"], new_weights["L__self___bias"]
        )
        self.assertEqual(new_expected, new_output)
|
|
|
|
def test_cond_share_predicte(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, predicate, x):
|
|
y = torch.cond(
|
|
predicate,
|
|
lambda: x + 1,
|
|
lambda: x + 2,
|
|
)
|
|
|
|
z = torch.cond(
|
|
predicate,
|
|
lambda: y + 1,
|
|
lambda: y + 2,
|
|
)
|
|
return (z,)
|
|
|
|
example_inputs = (
|
|
torch.tensor([True]).to(self.device),
|
|
torch.tensor([1, 2, 3]).to(self.device),
|
|
)
|
|
self.check_model(Model(), example_inputs)
|
|
|
|
    @unittest.skipIf(
        IS_FBCODE,
        "To enable after the C shim FC window ends",
    )
    def test_misaligned_input_1(self):
        """Model compiled with an aligned input must still accept a
        misaligned input at runtime, via a codegen'ed clone-on-misalignment
        check (aoti_torch_clone_preserve_strides)."""
        if self.device != "cuda":
            raise unittest.SkipTest("CUDA test only")

        class Model(torch.nn.Module):
            def forward(self, x):
                return x.sin() + x.cos()

        N = 64 * 64 * 64 + 64
        arg = torch.randn(N, device=self.device)
        example_inputs = (arg,)
        model = Model()
        expected = model(*example_inputs)
        so_path = AOTIRunnerUtil.compile(model, example_inputs)
        optimized = AOTIRunnerUtil.load(self.device, so_path)
        # If the model is compiled with aligned inputs, the generated
        # code will check inputs alignment at runtime
        self.code_check_count(
            model, example_inputs, "aoti_torch_clone_preserve_strides", 1
        )

        # Slicing off the first element produces a misaligned data pointer.
        misaligned_arg = torch.zeros(N + 1, device=self.device)
        misaligned_arg = misaligned_arg[1:]
        misaligned_arg.copy_(arg)
        actual = optimized(misaligned_arg)
        torch.testing.assert_close(actual, expected)
|
|
|
|
def test_misaligned_input_2(self):
    """When the model is compiled with an already-misaligned input, the
    generated code must NOT contain an alignment guard for that input."""
    if self.device != "cuda":
        raise unittest.SkipTest("CUDA test only")

    class Model(torch.nn.Module):
        def forward(self, x):
            return x.sin() + x.cos()

    numel = 64 * 64 * 64 + 64
    source = torch.randn(numel, device=self.device)
    # Slice off the first element so the storage offset is misaligned.
    shifted = torch.zeros(numel + 1, device=self.device)[1:]
    shifted.copy_(source)
    example_inputs = (shifted,)

    model = Model()
    self.check_model(model, example_inputs)
    # Misaligned at compile time => no runtime clone guard is generated.
    self.code_check_count(
        model, example_inputs, "aoti_torch_clone_preserve_strides", 0
    )
|
|
|
|
@unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
def test_stft(self):
    """A short-time Fourier transform with a complex result lowers and
    matches eager."""
    n_fft = 400
    hop_length = 160

    class Model(torch.nn.Module):
        def forward(self, x):
            window = torch.hann_window(n_fft).to(x.device)
            spec = torch.stft(
                x, n_fft, hop_length, window=window, return_complex=True
            )
            # Power spectrogram, dropping the final frame.
            return spec[..., :-1].abs() ** 2

    self.check_model(Model(), (torch.randn(500, device=self.device),))
|
|
|
|
def test_conv3d(self):
    # Exercises a raw aten.convolution.default lowering under Triton-only
    # max-autotune with partially dynamic input shapes.
    # NOTE(review): despite the name, the stride/padding/dilation lists
    # below have two elements and the input is 4-D, i.e. this looks like a
    # 2-D convolution — confirm whether the name or the shapes are intended.
    if self.device != GPU_TYPE or not is_big_gpu():
        raise unittest.SkipTest("requires modern GPU to run max-autotune")

    # The 1x64x5160x5160 input is very large; bail out early rather than OOM.
    if not _has_sufficient_memory(self.device, 2**35):
        raise unittest.SkipTest("insufficient memory")

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(
            self,
            convert_element_type_1271,
            convert_element_type_1272,
            convert_element_type_1273,
        ):
            # aten.convolution(input, weight, bias, stride, padding,
            # dilation, transposed, output_padding, groups)
            return torch.ops.aten.convolution.default(
                convert_element_type_1271,
                convert_element_type_1272,
                convert_element_type_1273,
                [1, 1],
                [1, 1],
                [1, 1],
                False,
                [0, 0],
                1,
            )

    example_inputs = (
        torch.randn(1, 64, 5160, 5160, device=self.device),
        torch.randn(3, 64, 3, 3, device=self.device),
        torch.randn(3, device=self.device),
    )
    # NOTE(review): the first input is 4-D (valid dim indices 0-3), yet
    # dims 3 AND 4 are marked dynamic below — verify index 4 is valid for
    # this spec.
    dynamic_shapes = {
        "convert_element_type_1271": {
            3: torch.export.Dim.DYNAMIC,
            4: torch.export.Dim.DYNAMIC,
        },
        "convert_element_type_1272": None,
        "convert_element_type_1273": None,
    }
    # Restrict autotuning to Triton conv backends only.
    with config.patch(
        {
            "max_autotune": True,
            "max_autotune_conv_backends": "TRITON",
        }
    ):
        self.check_model(
            Model(),
            example_inputs,
            atol=0.1,
            rtol=1e-3,
            dynamic_shapes=dynamic_shapes,
        )
|
|
|
|
@skipIfXpu(
    msg="The operator 'aten::_int_mm' is not currently implemented for the XPU device"
)
def test__int_mm(self):
    """int8 x int8 -> int32 matrix multiply via torch._int_mm."""

    class Model(torch.nn.Module):
        def forward(self, x, y):
            return torch._int_mm(x, y)

    lhs = torch.randint(-10, 10, (64, 32), device=self.device, dtype=torch.int8)
    rhs = torch.randint(-10, 10, (32, 64), device=self.device, dtype=torch.int8)
    self.check_model(Model(), (lhs, rhs))
|
|
|
|
def test_assert_tensor_meta(self):
    """aten._assert_tensor_metadata is handled even when implicit
    fallbacks are disabled."""

    class Module(torch.nn.Module):
        def forward(self, x):
            # Runtime dtype assertion on the input tensor.
            torch.ops.aten._assert_tensor_metadata.default(
                x,
                dtype=torch.int32,
            )
            return (x + 1,)

    example_inputs = (torch.tensor(1, dtype=torch.int32),)
    with config.patch({"implicit_fallbacks": False}):
        self.check_model(
            Module(),
            example_inputs,
            atol=0.1,
            rtol=1e-3,
        )
|
|
|
|
def test_composed_dynamic_size(self):
    """A derived dim expression (2 * Dim) is accepted as a dynamic-shape
    spec."""

    class Model(torch.nn.Module):
        def forward(self, x):
            return x + 1

    base_dim = torch.export.Dim("dim_0")
    # Constrain dim 0 to even sizes via the composed expression 2 * dim.
    dynamic_shapes = {"x": {0: 2 * base_dim}}
    self.check_model(
        Model(),
        (torch.randn(10, device=self.device),),
        dynamic_shapes=dynamic_shapes,
    )
|
|
|
|
|
|
class AOTInductorLoggingTest(LoggingTestCase):
    @make_logging_test(dynamic=logging.DEBUG)
    def test_shape_env_reuse(self, records):
        """AOT-compiling an exported module must reuse the export-time
        ShapeEnv instead of creating a fresh one."""

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + 2

        inputs = (torch.randn(4, 4),)
        spec = {"x": {0: Dim.AUTO, 1: Dim.AUTO}}
        ep = export(Foo(), inputs, dynamic_shapes=spec, strict=False)
        with torch.no_grad():
            torch._inductor.aot_compile(ep.module(), inputs)
        # Exactly one "create_env" log record => exactly one ShapeEnv made.
        creations = sum(1 for record in records if record.msg == "create_env")
        self.assertEqual(creations, 1)
|
|
|
|
|
|
# Materialize the parametrized tests declared on the shared template class.
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
|
|
|
|
|
|
def fail_cpu(is_skip=False):
    """Mark a test as an expected failure (or skip, when *is_skip* is
    True) on the CPU backend."""
    return TestFailure(("cpu",), is_skip=is_skip)
|
|
|
|
|
|
def fail_gpu(suffixes: tuple[str, ...], is_skip=False):
    """Mark a test as an expected failure (or skip, when *is_skip* is
    True) on the given GPU device suffixes."""
    return TestFailure(suffixes, is_skip=is_skip)
|
|
|
|
|
|
# test_failures, xfail by default, set is_skip=True to skip
# Keys are AOTInductorTestsTemplate method names; copy_tests() consumes
# this mapping when cloning the template onto the CPU test class.
CPU_TEST_FAILURES = {
    # TODO: failed internally
    "test_multiple_output_alias": fail_cpu(is_skip=True),
    "test_update_constant_buffer": fail_cpu(is_skip=True),
    "test_so_without_weight": fail_cpu(is_skip=True),
}
|
|
|
|
# test_failures, xfail by default, set is_skip=True to skip
# Keys are AOTInductorTestsTemplate method names; copy_tests() consumes
# this mapping when cloning the template onto the GPU test class.
GPU_TEST_FAILURES = {
    # quantized unsupported for GPU
    "test_quantized_linear": fail_gpu(("cuda", "xpu")),
    # NOTE(review): "quanatized" looks misspelled, but the key must match
    # the actual test method name — verify before renaming either side.
    "test_quanatized_int8_linear": fail_gpu(("cuda", "xpu")),
    # No scaled_dot_product_efficient_attention implementation for XPU yet.
    "test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)),
}
|
|
|
|
|
|
class AOTInductorTestABICompatibleCpu(TestCase):
    # Concrete CPU instantiation of the shared AOTInductor test template;
    # the copy_tests() call below clones the template's tests onto it.
    device = "cpu"
    device_type = "cpu"
    # Shared helper functions bound as methods for the copied tests.
    check_model = check_model
    check_model_with_multiple_inputs = check_model_with_multiple_inputs
    code_check_count = code_check_count
    allow_stack_allocation = False
    use_minimal_arrayref_interface = False
|
|
|
|
|
|
# Clone every template test onto the CPU class, applying CPU_TEST_FAILURES
# to mark known xfails/skips for the "cpu" suffix.
copy_tests(
    AOTInductorTestsTemplate,
    AOTInductorTestABICompatibleCpu,
    "cpu",
    CPU_TEST_FAILURES,
)
|
|
|
|
|
|
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
class AOTInductorTestABICompatibleGpu(TestCase):
    # Concrete GPU instantiation of the shared AOTInductor test template;
    # the copy_tests() call below clones the template's tests onto it.
    device = GPU_TYPE
    device_type = GPU_TYPE
    # Shared helper functions bound as methods for the copied tests.
    check_model = check_model
    check_model_with_multiple_inputs = check_model_with_multiple_inputs
    code_check_count = code_check_count
    allow_stack_allocation = False
    use_minimal_arrayref_interface = False
|
|
|
|
|
|
# Clone every template test onto the GPU class, applying GPU_TEST_FAILURES
# to mark known xfails/skips for the current GPU_TYPE suffix.
copy_tests(
    AOTInductorTestsTemplate,
    AOTInductorTestABICompatibleGpu,
    GPU_TYPE,
    GPU_TEST_FAILURES,
)
|
|
|
|
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # cpp_extension N/A in fbcode
    # NOTE(review): tests run when a GPU is available OR on macOS —
    # presumably darwin covers the CPU-only path; confirm the intended
    # gating, since a Linux CPU-only environment skips everything here.
    if HAS_GPU or sys.platform == "darwin":
        run_tests(needs="filelock")
|