[CUTLASS][AOTI] Fixes undefined symbol: cudaLaunchKernelExC (#142094)

Summary:
### Context
* When compiling the object file for a CUTLASS kernel, CUDA RT symbols are left undefined.
* When compiling the final shared object file, we statically link with `libcudart_static.a`.
* Note: ordering matters when specifying the library search paths (`-L`), so the static CUDA runtime library must be placed ahead of conflicting paths.

Test Plan:
```
// before diff
RuntimeError: Failure loading .so: /tmp/tmpqhz_dnza/model.so: undefined symbol: cudaLaunchKernelExC
```

Differential Revision: D66793974

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142094
Approved by: https://github.com/chenyang78, https://github.com/hl475
This commit is contained in:
Colin Peppler 2024-12-06 02:18:54 +00:00 committed by PyTorch MergeBot
parent 8bfc0094e4
commit 0602676c8d
2 changed files with 48 additions and 0 deletions

View file

@ -643,6 +643,43 @@ class TestCutlassBackend(TestCase):
Y = mm(a, b)
torch.testing.assert_close(Y_compiled, Y)
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
@unittest.skipIf(not SM90OrLater, "need sm_90")
def test_force_cutlass_backend_aoti_dynamic(self):
    """AOT Inductor with the CUTLASS backend forced must handle a
    dynamic-shape matmul and match eager-mode results."""
    # Disable reduced-precision fp16 reductions so the compiled result
    # can be compared directly against the eager computation.
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    class MyModel(torch.nn.Module):
        def forward(self, x, w):
            return x @ w

    with config.patch(
        {
            "max_autotune": True,
            "autotune_in_subproc": False,
            "max_autotune_gemm_backends": "CUTLASS",
            "autotune_fallback_to_aten": False,
            "cuda.cutlass_dir": _CUTLASS_DIR,
        }
    ):
        model = MyModel()
        M, N, K = 16, 32, 64
        # Keys must match the parameter names of MyModel.forward.
        dynamic_shapes = {
            "x": {0: M, 1: K},
            "w": {0: K, 1: N},
        }
        x = torch.randn(M, K).cuda().half()
        w = torch.randn(K, N).cuda().half()
        compiled_out = AOTIRunnerUtil.run(
            "cuda",
            model,
            (x, w),
            dynamic_shapes=dynamic_shapes,
        )
        eager_out = model(x, w)
        torch.testing.assert_close(eager_out, compiled_out)
# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM80, "need sm_80 exactly")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")

View file

@ -1327,6 +1327,17 @@ class CppTorchDeviceOptions(CppTorchOptions):
_append_list(self._passthough_args, device_passthough_args)
self._finalize_options()
def _finalize_options(self) -> None:
    """Finalize compile options, then (fbcode only) demote the FBCode
    python lib dir to the end of the library search path."""
    super()._finalize_options()
    if not config.is_fbcode():
        return
    # -L ordering matters: libs that also live in the FBCode python lib
    # dir would otherwise shadow the intended ones, so move that dir to
    # the back of the search order.
    _, py_lib_dirs = _get_python_related_args()
    assert len(py_lib_dirs) == 1, f"Python lib dirs: {py_lib_dirs}"
    py_lib_dir = py_lib_dirs[0]
    try:
        self._libraries_dirs.remove(py_lib_dir)
    except ValueError:
        # Not in the search path at all — nothing to re-order.
        pass
    else:
        self._libraries_dirs.append(py_lib_dir)
def get_name_and_dir_from_output_file_path(
file_path: str,