From 5f7ce38e44791817d326467813e354fde1d01db0 Mon Sep 17 00:00:00 2001
From: pytorchbot
Date: Mon, 10 Feb 2025 07:34:04 +0000
Subject: [PATCH] 2025-02-10 nightly release (e8304f08fedc802a90f9361c30861f8c5aab946e)

---
 aten/src/ATen/native/cuda/Math.cuh   |  7 +++----
 c10/metal/special_math.h             | 26 ++++++++++++++++++++++++++
 test/dynamo/test_repros.py           | 21 +++++++++++++++++++++
 test/inductor/test_flex_attention.py |  6 +++---
 tools/build_with_debinfo.py          |  3 +++
 torch/_dynamo/config.py              |  2 ++
 torch/_dynamo/eval_frame.py          |  8 ++++++++
 torch/_torch_docs.py                 |  4 ++--
 8 files changed, 68 insertions(+), 9 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh
index b99e9d0c94d..2fe8f5dd2e3 100644
--- a/aten/src/ATen/native/cuda/Math.cuh
+++ b/aten/src/ATen/native/cuda/Math.cuh
@@ -758,11 +758,10 @@ const auto sinc_string = jiterator_stringify(
   template <typename T>
   T sinc(T a) {
     if (a == T(0)) {
       return T(1);
-    } else {
-      constexpr T pi = T(3.14159265358979323846L);
-      T product = pi * a;
-      return std::sin(product) / product;
     }
+    constexpr T pi = T(3.14159265358979323846L);
+    T product = pi * a;
+    return std::sin(product) / product;
   }
 ); // sinc_string
diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h
index 8bcb1f7a53e..04fd7eee18f 100644
--- a/c10/metal/special_math.h
+++ b/c10/metal/special_math.h
@@ -477,5 +477,31 @@ inline float2 sinc(float2 inp) {
   return float2(re, im) / a2;
 }
 
+template <typename T>
+inline T spherical_bessel_j0(T x) {
+  if (::metal::isinf(x))
+    return T(0.0);
+  T x2 = x * x;
+  T k1 = static_cast<T>(-1.0);
+  T k2 = static_cast<T>(1.0);
+
+  if (::metal::abs(x) < T(0.5)) {
+    return T(1.0) +
+        x2 *
+            (k1 / T(6.0) +
+             x2 *
+                 (k2 / T(120.0) +
+                  x2 *
+                      (k1 / T(5040.0) +
+                       x2 *
+                           (k2 / T(362880.0) +
+                            x2 *
+                                (k1 / T(39916800.0) +
+                                 x2 * (k2 / T(6227020800.0)))))));
+  }
+
+  return ::metal::sin(x) / x;
+}
+
 } // namespace metal
 } // namespace c10
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index dc3d58f5b0d..eb1a8d2d6ca 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -6510,6 +6510,27 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
         ).sum()
         self.assertEqual(actual, expected)
 
+    def test_incompatible_configs(self):
+        with torch._dynamo.config.patch(
+            suppress_errors=False, fail_on_recompile_limit_hit=False
+        ):
+            torch.compile(lambda: None)
+
+        with torch._dynamo.config.patch(
+            suppress_errors=True, fail_on_recompile_limit_hit=False
+        ):
+            torch.compile(lambda: None)
+
+        with torch._dynamo.config.patch(
+            suppress_errors=False, fail_on_recompile_limit_hit=True
+        ):
+            torch.compile(lambda: None)
+
+        with torch._dynamo.config.patch(
+            suppress_errors=True, fail_on_recompile_limit_hit=True
+        ), self.assertRaises(AssertionError):
+            torch.compile(lambda: None)
+
 
 class ReproTestsDevice(torch._dynamo.test_case.TestCase):
     def test_sub_alpha_scalar_repro(self, device):
diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 8b4382061b0..99440593c2b 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -2510,9 +2510,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
     @supported_platform
     def test_strided_backwards(self):
         shape = (1, 2, 4096, 64)
-        Q = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
-        K = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
-        V = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
+        Q = torch.randn(shape, requires_grad=True, device="cuda")
+        K = torch.randn(shape, requires_grad=True, device="cuda")
+        V = torch.randn(shape, requires_grad=True, device="cuda")
         func = torch.compile(flex_attention, dynamic=True, fullgraph=True)
 
         K_sliced = K[:, :, :-128]
diff --git a/tools/build_with_debinfo.py b/tools/build_with_debinfo.py
index 26c054bf2a0..0c9553b963e 100755
--- a/tools/build_with_debinfo.py
+++ b/tools/build_with_debinfo.py
@@ -78,6 +78,9 @@ def create_build_plan() -> list[tuple[str, str]]:
         if line.startswith(": &&") and line.endswith("&& :"):
             line = line[4:-4]
         line = line.replace("-O2", "-g").replace("-O3", "-g")
+        # Build Metal shaders with debug information
+        if "xcrun metal " in line and "-frecord-sources" not in line:
+            line += " -frecord-sources -gline-tables-only"
         try:
             name = line.split("-o ", 1)[1].split(" ")[0]
             rc.append((name, line))
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index f033a282d27..84ed3be5a70 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -52,6 +52,7 @@ skip_code_recursive_on_recompile_limit_hit = True
 # raise a hard error if cache limit is hit. If you are on a model where you
 # know you've sized the cache correctly, this can help detect problems when
 # you regress guards/specialization. This works best when recompile_limit = 1.
+# This flag is incompatible with: suppress_errors.
 # [@compile_ignored: runtime_behaviour]
 fail_on_recompile_limit_hit = False
 
@@ -164,6 +165,7 @@ traceable_tensor_subclasses: set[type[Any]] = set()
 # This is a good way to get your model to work one way or another, but you may
 # lose optimization opportunities this way. Devs, if your benchmark model is failing
 # this way, you should figure out why instead of suppressing it.
+# This flag is incompatible with: fail_on_recompile_limit_hit.
 suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
 
 # Record and write an execution record of the current frame to a file
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 045cd350b60..789ed41d3a2 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -833,6 +833,13 @@ def is_inductor_supported():
         return False
 
 
+def check_for_incompatible_configs():
+    # Some of the configs should be mutually exclusive
+    assert not (
+        config.suppress_errors and config.fail_on_recompile_limit_hit
+    ), "Dynamo configs suppress_errors and fail_on_recompile_limit_hit cannot both be active at the same time."
+
+
 def optimize(*args, **kwargs):
     def rebuild_ctx():
         ca_kwargs_override = config.compiled_autograd_kwargs_override
@@ -885,6 +892,7 @@ def _optimize(
             ...
     """
     check_if_dynamo_supported()
+    check_for_incompatible_configs()
     # Note: The hooks object could be global instead of passed around, *however* that would make
     # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
     # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 057ed0fe63e..2dd16890880 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -11100,8 +11100,8 @@ are designed to work with this function. See the examples below.
 
 Args:
     {input}
-    indices (tensor): the indices into :attr:`input`. Must have long dtype.
-    dim (int, optional): dimension to select along.
+    indices (LongTensor): the indices into :attr:`input`. Must have long dtype.
+    dim (int, optional): dimension to select along. Default: 0
 
 Keyword args:
     {out}
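
Note on the constants in the new Metal spherical_bessel_j0 kernel: they read as the Maclaurin coefficients of sin(x)/x, i.e. (2k+1)! with alternating sign (k1/k2). A small Python check of that reading, illustrative only and not part of the patch:

    import math

    # The divisors 6.0, 120.0, 5040.0, 362880.0, 39916800.0, 6227020800.0 used in
    # the near-zero branch should equal (2k+1)! for k = 1..6, matching the series
    # sin(x)/x = sum_k (-1)^k x^(2k) / (2k+1)!.
    for k, coeff in enumerate(
        [6.0, 120.0, 5040.0, 362880.0, 39916800.0, 6227020800.0], start=1
    ):
        assert math.factorial(2 * k + 1) == coeff, (k, coeff)
    print("coefficients match the Maclaurin series of sin(x)/x")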
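
And a minimal usage sketch of the new Dynamo config guard, mirroring the added test_incompatible_configs test (illustrative only, not part of the patch): with both suppress_errors and fail_on_recompile_limit_hit enabled, torch.compile is expected to fail fast with an AssertionError from check_for_incompatible_configs().

    import torch
    import torch._dynamo

    # Either flag alone is accepted; combining them trips the assertion added in
    # torch/_dynamo/eval_frame.py at the time torch.compile sets up the backend.
    with torch._dynamo.config.patch(
        suppress_errors=True, fail_on_recompile_limit_hit=True
    ):
        try:
            torch.compile(lambda: None)
        except AssertionError as exc:
            print("incompatible Dynamo configs:", exc)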