Fix recompilation issue with content store (#113533)

While running the accuracy minifier, I was getting the error:
```
NotImplementedError("xor_sum only implemented with inductor")
```

The logs showed that the cache limit was exceeded and that dynamo was falling back to
eager mode, which doesn't work for this function. The cache failures were caused by
the guard on the id of the function being compiled: in this case that function is
a closure that gets re-created on every call, so the guard always fails.
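
For illustration, a minimal sketch (not code from this PR; the names are made up) of the failing pattern: compiling a closure that is re-created on each call, so every call presents a new function id to dynamo and the recompile cache fills up:
```python
import torch

def hash_with_inner_kernel(x):
    # A fresh closure object is created on every call; dynamo's guard on the
    # function id therefore never matches, and each call triggers a recompile.
    @torch.compile(dynamic=True)
    def kernel(t):
        return (t * 3 + 1).sum()

    return kernel(x)

for _ in range(10):
    hash_with_inner_kernel(torch.arange(8))  # recompiles until the cache limit is hit
```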

This fixes the issue by making the storage hash kernel a global function and
working around the dynamo import dependency with a `lazy_compile` helper, which
defers the `torch.compile` call to the first invocation.
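
As a hedged illustration of the intended behavior after the fix (assuming the patched module is `torch.utils._content_store`, and relying on the per-process consistency described in the module's own comments), repeated hashing of the same storage now reuses a single compiled kernel:
```python
import torch
from torch.utils._content_store import hash_storage

s = torch.arange(16, dtype=torch.float32).untyped_storage()
# Every call goes through the same module-level, lazily compiled kernel,
# so no recompilation is triggered and the digest is stable within a process.
digests = [hash_storage(s) for _ in range(4)]
assert len(set(digests)) == 1
```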
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113533
Approved by: https://github.com/Skylion007
Peter Bell 2023-11-12 23:20:25 +00:00 committed by PyTorch MergeBot
parent ad06e9f060
commit 4b09b08d2e
2 changed files with 55 additions and 30 deletions


```diff
@@ -60,6 +60,13 @@ class TestContentStore(TestCase):
         # Should not raise an error
         hash_storage(torch.tensor(2, device=device).untyped_storage())
 
+    @torch._dynamo.config.patch(cache_size_limit=1)
+    def test_repeated_hash(self, device):
+        # Test that repeated hashing doesn't trigger a recompile in dynamo
+        # If it does, we will execute prims.xor_sum in eager which fails
+        for _ in range(4):
+            hash_storage(torch.tensor(2, device=device).untyped_storage())
+
     @skipIfRocm
     def test_load_tensor(self, device):
         with tempfile.TemporaryDirectory() as loc:
```


```diff
@@ -28,6 +28,7 @@
 # users.
 
 import ctypes
+import functools
 import hashlib
 import os.path
 import struct
@@ -43,6 +44,52 @@ from torch._C import default_generator
 from torch.multiprocessing.reductions import StorageWeakRef
 
+
+def lazy_compile(**compile_kwargs):
+    """Lazily wrap a function with torch.compile on the first call
+
+    This avoids eagerly importing dynamo.
+    """
+
+    def decorate_fn(fn):
+        @functools.wraps(fn)
+        def compile_hook(*args, **kwargs):
+            compiled_fn = torch.compile(fn, **compile_kwargs)
+            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
+            return compiled_fn(*args, **kwargs)
+
+        return compile_hook
+
+    return decorate_fn
+
+
+# Use of torch.compile is mandatory for (1) good memory usage
+# and (2) xor_sum implementation. This is our first instance of
+# using PT2 to implement a kernel in PyTorch; if we get AOT capabilities
+# it would be good to apply it here.
+@lazy_compile(dynamic=True)
+def hash_storage_kernel(x):
+    # The randint calls are carefully written to hit things we
+    # have lowerings for in inductor. Lack of unsigned 32-bit integer
+    # is a pain.
+    a = torch.randint(
+        -(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32
+    ).abs()
+    a = ((a % (2**31 - 1)) + 1).long()
+    b = (
+        torch.randint(-(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32)
+        .abs()
+        .long()
+    )
+    # This is a standard shift-multiply universal hash family
+    # plus xor sum hash, using Philox to generate random numbers.
+    # Our Philox RNG is not deterministic across devices so
+    # don't use this for stable hashing.
+    #
+    # This assumes fixed length so you're also obligated to bucket
+    # by the length of tensor as well
+    return prims.xor_sum((a * x + b).int(), [0])
+
+
 # Returns a hex digest of the data in the storage. Guaranteed to be
 # SHA-1 if stable_hash=True, otherwise it will consistent for a single
 # process run but not necessarily across processes.
@@ -62,35 +109,6 @@ def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) ->
         sha1.update(buf)
         return sha1.hexdigest()
 
-    # Use of torch.compile is mandatory for (1) good memory usage
-    # and (2) xor_sum implementation. This is our first instance of
-    # using PT2 to implement a kernel in PyTorch; if we get AOT capabilities
-    # it would be good to apply it here.
-    @torch.compile(dynamic=True)
-    def kernel(x):
-        # The randint calls are carefully written to hit things we
-        # have lowerings for in inductor. Lack of unsigned 32-bit integer
-        # is a pain.
-        a = torch.randint(
-            -(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32
-        ).abs()
-        a = ((a % (2**31 - 1)) + 1).long()
-        b = (
-            torch.randint(
-                -(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32
-            )
-            .abs()
-            .long()
-        )
-        # This is a standard shift-multiply universal hash family
-        # plus xor sum hash, using Philox to generate random numbers.
-        # Our Philox RNG is not deterministic across devices so
-        # don't use this for stable hashing.
-        #
-        # This assumes fixed length so you're also obligated to bucket
-        # by the length of tensor as well
-        return prims.xor_sum((a * x + b).int(), [0])
-
     # TODO: factor this into a random utility
     if device_type == "cpu":
         generator = default_generator
@@ -114,7 +132,7 @@ def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) ->
         # We run the 32-bit hash five times with differing parameters to
         # reduce chance of collision
         ITER = 5
-        cs = [kernel(x).item() for _ in range(ITER)]
+        cs = [hash_storage_kernel(x).item() for _ in range(ITER)]
         return struct.pack(">" + "i" * ITER, *cs).hex()
     finally:
         generator.set_state(state)
```