[CUTLASS] fix addmm (#143537)
We previously hit a CUDA IMA (illegal memory access) because Bias was passed in for X, so the inputs need to be re-ordered.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143537
Approved by: https://github.com/chenyang78
ghstack dependencies: #143528
parent b54620f40f
commit e3fefdfbf0
2 changed files with 34 additions and 0 deletions
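For context, torch.addmm(bias, mat1, mat2) computes beta * bias + alpha * (mat1 @ mat2). The new test below exercises this with a bias that is materialized transposed (hence non-contiguous), the pattern that previously triggered the illegal memory access. A minimal standalone sketch of that call pattern, with shapes taken from the test (requires a CUDA device):

import torch

# Shapes from the new test; fp16 on CUDA as in the test body.
M, N, K = 2048, 3072, 6144
x = torch.randn(M, K, dtype=torch.float16, device="cuda")
w = torch.randn(K, N, dtype=torch.float16, device="cuda")

# The bias is created as (N, M) and transposed to (M, N), so its layout
# is non-contiguous -- the case the new test is built around.
bias = torch.zeros(N, M, dtype=torch.float16, device="cuda").t()

# addmm computes beta * bias + alpha * (x @ w); beta and alpha default to 1.
out = torch.addmm(bias, x, w)
assert out.shape == (M, N)

Under max-autotune with the CUTLASS backend enabled (which the test's config.patch sets up), this graph is lowered through tuned_addmm, shown in the second hunk below.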
@@ -602,6 +602,39 @@ class TestCutlassBackend(TestCase):
         # Broadcast last dim.
         compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 1])
 
+    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    def test_addmm_with_expanded_bias(self):
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+
+        class MyModel(torch.nn.Module):
+            def forward(self, x, w):
+                bias = torch.zeros(
+                    size=(w.size(1), x.size(0)), dtype=torch.float16, device="cuda"
+                ).t()
+                return torch.addmm(bias, x, w)
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "autotune_in_subproc": False,
+                "max_autotune_gemm_backends": "ATEN,CUTLASS",
+                "cuda.cutlass_dir": _CUTLASS_DIR,
+                "cuda.cutlass_max_profiling_configs": 1,
+            }
+        ):
+            model = MyModel()
+            M, N, K = 2048, 3072, 6144
+            x = torch.randn(M, K).cuda().half()
+            w = torch.randn(K, N).cuda().half()
+
+            actual = AOTIRunnerUtil.run(
+                "cuda",
+                model,
+                (x, w),
+            )
+            expected = model(x, w)
+            torch.testing.assert_close(expected, actual)
+
     # TODO: Enable dynamic test cases when dynamic support is added.
     @unittest.skipIf(not SM80OrLater, "need sm_80")
     @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@@ -557,6 +557,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
             [mat1, mat2, inp_expanded],
             alpha=alpha,
             beta=beta,
+            input_reorder=[2, 0, 1],
         )
 
     if is_nonzero and use_ck_gemm_template(layout, m, n, k):
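A hedged note on the one-line fix: the exact convention for input_reorder lives in the CUTLASS template code, not in this diff, but the usual reading is a gather, where entry i of the reordered sequence is input_nodes[input_reorder[i]]. Under that reading, [2, 0, 1] makes the generated kernel consume [mat1, mat2, inp_expanded] as (inp_expanded, mat1, mat2), i.e. Bias, X, W, matching the order in which aten.addmm(bias, mat1, mat2) actually supplies its arguments, instead of Bias landing in the X slot. A toy illustration of that assumed gather, not the actual inductor template code:

# Toy illustration of the assumed gather semantics of input_reorder;
# hypothetical, for exposition only.
input_nodes = ["mat1", "mat2", "inp_expanded"]  # X, W, Bias
input_reorder = [2, 0, 1]
reordered = [input_nodes[i] for i in input_reorder]
assert reordered == ["inp_expanded", "mat1", "mat2"]  # Bias, X, W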