Mirror of https://github.com/saymrwulf/pytorch.git

2025-02-10 nightly release (e8304f08fe)
commit 5f7ce38e44
parent 9b43fab6c5
8 changed files with 68 additions and 9 deletions
@@ -758,11 +758,10 @@ const auto sinc_string = jiterator_stringify(
   T sinc(T a) {
     if (a == T(0)) {
       return T(1);
-    } else {
-      constexpr T pi = T(3.14159265358979323846L);
-      T product = pi * a;
-      return std::sin(product) / product;
     }
+    constexpr T pi = T(3.14159265358979323846L);
+    T product = pi * a;
+    return std::sin(product) / product;
   }
 ); // sinc_string
 
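For reference, the simplification above leans on the early return: once a == 0 has returned, the else wrapper adds nothing, so the tail of the function can be flattened. A minimal Python sketch of the same normalized sinc (nothing here comes from the PyTorch sources beyond the formula):

    import math

    def sinc(a: float) -> float:
        # Normalized sinc: sin(pi*a) / (pi*a), with the removable
        # singularity at a == 0 handled by an early return.
        if a == 0.0:
            return 1.0
        product = math.pi * a
        return math.sin(product) / product

    assert sinc(0.0) == 1.0
    assert abs(sinc(0.5) - 2.0 / math.pi) < 1e-12  # sin(pi/2) / (pi/2) = 2/pi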
@@ -477,5 +477,31 @@ inline float2 sinc(float2 inp) {
   return float2(re, im) / a2;
 }
 
+template <typename T>
+inline T spherical_bessel_j0(T x) {
+  if (::metal::isinf(x))
+    return T(0.0);
+  T x2 = x * x;
+  T k1 = static_cast<T>(-1.0);
+  T k2 = static_cast<T>(1.0);
+
+  if (::metal::abs(x) < T(0.5)) {
+    return T(1.0) +
+        x2 *
+            (k1 / T(6.0) +
+             x2 *
+                 (k2 / T(120.0) +
+                  x2 *
+                      (k1 / T(5040.0) +
+                       x2 *
+                           (k2 / T(362880.0) +
+                            x2 *
+                                (k1 / T(39916800.0) +
+                                 x2 * (k2 / T(6227020800.0)))))));
+  }
+
+  return ::metal::sin(x) / x;
+}
+
 } // namespace metal
 } // namespace c10
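The new kernel evaluates j0(x) = sin(x)/x directly for |x| >= 0.5 and switches to a truncated Taylor series of sin(x)/x near zero, where the direct form would hit 0/0; the constants 6, 120, 5040, ... are the odd factorials (2n+1)!. A Python cross-check of the same series (standard library only; the 0.5 cutoff and the six series terms mirror the Metal code above):

    import math

    def spherical_bessel_j0(x: float) -> float:
        # j0(x) = sum_{n>=0} (-1)**n * x**(2n) / (2n+1)!  for small |x|,
        # sin(x)/x otherwise, matching the kernel's structure.
        if math.isinf(x):
            return 0.0
        if abs(x) < 0.5:
            return sum((-1) ** n * x ** (2 * n) / math.factorial(2 * n + 1)
                       for n in range(7))
        return math.sin(x) / x

    for v in (0.25, 0.49, 0.5, 3.0):
        assert abs(spherical_bessel_j0(v) - math.sin(v) / v) < 1e-12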
@@ -6510,6 +6510,27 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
         ).sum()
         self.assertEqual(actual, expected)
 
+    def test_incompatible_configs(self):
+        with torch._dynamo.config.patch(
+            suppress_errors=False, fail_on_recompile_limit_hit=False
+        ):
+            torch.compile(lambda: None)
+
+        with torch._dynamo.config.patch(
+            suppress_errors=True, fail_on_recompile_limit_hit=False
+        ):
+            torch.compile(lambda: None)
+
+        with torch._dynamo.config.patch(
+            suppress_errors=False, fail_on_recompile_limit_hit=True
+        ):
+            torch.compile(lambda: None)
+
+        with torch._dynamo.config.patch(
+            suppress_errors=True, fail_on_recompile_limit_hit=True
+        ), self.assertRaises(AssertionError):
+            torch.compile(lambda: None)
+
 
 class ReproTestsDevice(torch._dynamo.test_case.TestCase):
     def test_sub_alpha_scalar_repro(self, device):
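Only the last combination in the new test is expected to be rejected; each flag on its own stays legal. A standalone repro sketch of the failing case (assuming a torch build that contains this change):

    import torch
    import torch._dynamo

    torch._dynamo.config.suppress_errors = True
    torch._dynamo.config.fail_on_recompile_limit_hit = True

    try:
        torch.compile(lambda: None)  # the config check fires before any compilation
    except AssertionError as err:
        print("rejected:", err)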
@@ -2510,9 +2510,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
     @supported_platform
     def test_strided_backwards(self):
         shape = (1, 2, 4096, 64)
-        Q = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
-        K = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
-        V = torch.randn(shape, requires_grad=True, device="cuda", dtype=torch.bfloat16)
+        Q = torch.randn(shape, requires_grad=True, device="cuda")
+        K = torch.randn(shape, requires_grad=True, device="cuda")
+        V = torch.randn(shape, requires_grad=True, device="cuda")
         func = torch.compile(flex_attention, dynamic=True, fullgraph=True)
 
         K_sliced = K[:, :, :-128]
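The dtype change aside, the point of this test is the strided backward: slicing K along the sequence dimension leaves a non-contiguous view. A small sketch of that effect (plain torch; flex_attention itself is not needed to see it):

    import torch

    K = torch.randn(1, 2, 4096, 64)
    K_sliced = K[:, :, :-128]        # drop the last 128 positions of the sequence dim
    print(K.is_contiguous())         # True
    print(K_sliced.is_contiguous())  # False: strides unchanged, sizes shrunk
    print(K.stride() == K_sliced.stride())  # True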
@@ -78,6 +78,9 @@ def create_build_plan() -> list[tuple[str, str]]:
         if line.startswith(": &&") and line.endswith("&& :"):
             line = line[4:-4]
         line = line.replace("-O2", "-g").replace("-O3", "-g")
+        # Build Metal shaders with debug information
+        if "xcrun metal " in line and "-frecord-sources" not in line:
+            line += " -frecord-sources -gline-tables-only"
         try:
             name = line.split("-o ", 1)[1].split(" ")[0]
             rc.append((name, line))
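The three added lines only touch ninja command lines that invoke the Metal compiler and do not already carry the flag. A standalone sketch of the rewrite on a made-up command line (the xcrun invocation below is illustrative, not taken from a real build log):

    line = ": && xcrun metal -O3 -c Indexing.metal -o Indexing.air && :"

    # Same transformation as the build script: strip the ': && ... && :' wrapper,
    # trade optimization for debug info, and have metal embed its sources.
    if line.startswith(": &&") and line.endswith("&& :"):
        line = line[4:-4]
    line = line.replace("-O2", "-g").replace("-O3", "-g")
    if "xcrun metal " in line and "-frecord-sources" not in line:
        line += " -frecord-sources -gline-tables-only"

    print(line)  # ... xcrun metal -g -c ... -frecord-sources -gline-tables-only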
@@ -52,6 +52,7 @@ skip_code_recursive_on_recompile_limit_hit = True
 # raise a hard error if cache limit is hit. If you are on a model where you
 # know you've sized the cache correctly, this can help detect problems when
 # you regress guards/specialization. This works best when recompile_limit = 1.
+# This flag is incompatible with: suppress_errors.
 # [@compile_ignored: runtime_behaviour]
 fail_on_recompile_limit_hit = False
 
@@ -164,6 +165,7 @@ traceable_tensor_subclasses: set[type[Any]] = set()
 # This is a good way to get your model to work one way or another, but you may
 # lose optimization opportunities this way. Devs, if your benchmark model is failing
 # this way, you should figure out why instead of suppressing it.
+# This flag is incompatible with: fail_on_recompile_limit_hit.
 suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
 
 # Record and write an execution record of the current frame to a file
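The two comment additions simply point each flag at the other; how the flags are set does not change. A hedged usage sketch (only the flag names come from the diff; the rest is ordinary dynamo config usage):

    import torch
    import torch._dynamo

    # Hard-fail when the recompile limit is hit; per the comment above, this
    # works best together with recompile_limit = 1.
    torch._dynamo.config.fail_on_recompile_limit_hit = True

    # suppress_errors is seeded from TORCHDYNAMO_SUPPRESS_ERRORS at import time
    # but can be toggled directly. Enabling it alongside the flag above now
    # trips the assertion added further down in this commit.
    torch._dynamo.config.suppress_errors = False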
@@ -833,6 +833,13 @@ def is_inductor_supported():
         return False
 
 
+def check_for_incompatible_configs():
+    # Some of the configs should be mutually exclusive
+    assert not (
+        config.suppress_errors and config.fail_on_recompile_limit_hit
+    ), "Dynamo configs suppress_errors and fail_on_recompile_limit_hit cannot both be active at the same time."
+
+
 def optimize(*args, **kwargs):
     def rebuild_ctx():
         ca_kwargs_override = config.compiled_autograd_kwargs_override
@@ -885,6 +892,7 @@ def _optimize(
         ...
     """
     check_if_dynamo_supported()
+    check_for_incompatible_configs()
     # Note: The hooks object could be global instead of passed around, *however* that would make
     # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
     # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same
@@ -11100,8 +11100,8 @@ are designed to work with this function. See the examples below.
 
 Args:
     {input}
-    indices (tensor): the indices into :attr:`input`. Must have long dtype.
-    dim (int, optional): dimension to select along.
+    indices (LongTensor): the indices into :attr:`input`. Must have long dtype.
+    dim (int, optional): dimension to select along. Default: 0
 
 Keyword args:
     {out}