diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py
index 7dc1aa91791..33f0a9a1568 100644
--- a/test/inductor/test_cutlass_backend.py
+++ b/test/inductor/test_cutlass_backend.py
@@ -643,6 +643,53 @@ class TestCutlassBackend(TestCase):
             Y = mm(a, b)
             torch.testing.assert_close(Y_compiled, Y)
 
+    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    def test_force_cutlass_backend_aoti_dynamic(self):
+        """
+        Run an ``x @ w`` module through ``AOTIRunnerUtil.run`` with CUTLASS as
+        the only configured GEMM backend and compare with the eager result.
+        """
+        # NOTE(review): this backend flag is set here and not reset in this test.
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+
+        class MyModel(torch.nn.Module):
+            def forward(self, x, w):
+                return x @ w
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "autotune_in_subproc": False,
+                "max_autotune_gemm_backends": "CUTLASS",
+                "autotune_fallback_to_aten": False,
+                "cuda.cutlass_dir": _CUTLASS_DIR,
+            }
+        ):
+            model = MyModel()
+            M, N, K = 16, 32, 64
+            # NOTE(review): dims are given as plain ints, not torch.export.Dim
+            # objects; confirm this matches the intended dynamic-shape
+            # coverage.
+            dynamic_shapes = {
+                "x": {0: M, 1: K},
+                "w": {0: K, 1: N},
+            }
+
+            x = torch.randn(M, K).cuda().half()
+            w = torch.randn(K, N).cuda().half()
+
+            actual = AOTIRunnerUtil.run(
+                "cuda",
+                model,
+                (x, w),
+                dynamic_shapes=dynamic_shapes,
+            )
+            expected = model(x, w)
+            # assert_close takes (actual, expected); its relative tolerance is
+            # scaled by the second argument, so the eager result goes last.
+            torch.testing.assert_close(actual, expected)
+
     # TODO: Enable dynamic test cases when dynamic support is added.
     @unittest.skipIf(not SM80, "need sm_80 exactly")
     @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py
index dbb4c49d289..92cf88df8eb 100644
--- a/torch/_inductor/cpp_builder.py
+++ b/torch/_inductor/cpp_builder.py
@@ -1327,6 +1327,22 @@ class CppTorchDeviceOptions(CppTorchOptions):
         _append_list(self._passthough_args, device_passthough_args)
         self._finalize_options()
 
+    def _finalize_options(self) -> None:
+        """
+        Finalize options via the parent class, then, when ``config.is_fbcode()``
+        is true, move the Python lib dir to the end of ``_libraries_dirs``.
+        """
+        super()._finalize_options()
+        if config.is_fbcode():
+            # Re-order library search paths in case there are lib conflicts
+            # that also live in the FBCode python lib dir.
+            _, python_lib_dirs = _get_python_related_args()
+            assert len(python_lib_dirs) == 1, f"Python lib dirs: {python_lib_dirs}"
+            python_lib_dir = python_lib_dirs[0]
+            if python_lib_dir in self._libraries_dirs:
+                self._libraries_dirs.remove(python_lib_dir)
+                self._libraries_dirs.append(python_lib_dir)
+
 
 def get_name_and_dir_from_output_file_path(
     file_path: str,