mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix usages of contextmanager without finally (#96170)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96170 Approved by: https://github.com/ngimel, https://github.com/malfet
This commit is contained in:
parent
34d18c8bee
commit
5bbec680d7
15 changed files with 91 additions and 59 deletions
|
|
@ -228,14 +228,15 @@ def cuda_pointwise_context(loop_levels, block_count, block_size):
|
|||
old_block_size = torch._C._jit_get_te_cuda_pointwise_block_size()
|
||||
torch._C._jit_set_te_cuda_pointwise_block_size(block_size)
|
||||
|
||||
yield
|
||||
|
||||
if loop_levels:
|
||||
torch._C._jit_set_te_cuda_pointwise_loop_levels(old_loop_levels)
|
||||
if block_count:
|
||||
torch._C._jit_set_te_cuda_pointwise_block_count(old_block_count)
|
||||
if block_size:
|
||||
torch._C._jit_set_te_cuda_pointwise_block_size(old_block_size)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if loop_levels:
|
||||
torch._C._jit_set_te_cuda_pointwise_loop_levels(old_loop_levels)
|
||||
if block_count:
|
||||
torch._C._jit_set_te_cuda_pointwise_block_count(old_block_count)
|
||||
if block_size:
|
||||
torch._C._jit_set_te_cuda_pointwise_block_size(old_block_size)
|
||||
|
||||
# Auxiliary class to facilitate dynamic input shape
|
||||
class DynamicShape:
|
||||
|
|
|
|||
|
|
@ -65,8 +65,10 @@ class Analyzer(Visitor):
|
|||
if do_copy:
|
||||
ws = copy(ws)
|
||||
self.workspace_ctx.append(ws)
|
||||
yield ws
|
||||
del self.workspace_ctx[-1]
|
||||
try:
|
||||
yield ws
|
||||
finally:
|
||||
del self.workspace_ctx[-1]
|
||||
|
||||
def define_blob(self, blob):
|
||||
self.workspace[blob] += 1
|
||||
|
|
@ -166,12 +168,14 @@ class Text:
|
|||
self.add('with %s:' % text)
|
||||
self._indent += 4
|
||||
self._lines_in_context.append(0)
|
||||
yield
|
||||
if text is not None:
|
||||
if self._lines_in_context[-1] == 0:
|
||||
self.add('pass')
|
||||
self._indent -= 4
|
||||
del self._lines_in_context[-1]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if text is not None:
|
||||
if self._lines_in_context[-1] == 0:
|
||||
self.add('pass')
|
||||
self._indent -= 4
|
||||
del self._lines_in_context[-1]
|
||||
|
||||
def add(self, text):
|
||||
self._lines_in_context[-1] += 1
|
||||
|
|
|
|||
|
|
@ -811,8 +811,10 @@ class TestFFT(TestCase):
|
|||
plan_cache = torch.backends.cuda.cufft_plan_cache[device]
|
||||
original = plan_cache.max_size
|
||||
plan_cache.max_size = n
|
||||
yield
|
||||
plan_cache.max_size = original
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
plan_cache.max_size = original
|
||||
|
||||
with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)):
|
||||
self._test_fft_ifft_rfft_irfft(devices[0], dtype)
|
||||
|
|
|
|||
|
|
@ -1012,7 +1012,6 @@ def disable_cache_limit():
|
|||
try:
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
config.cache_size_limit = prior
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1228,10 +1228,12 @@ def track_graph_compiling(aot_config, graph_name):
|
|||
global graph_being_compiled
|
||||
# TODO: Don't shove the aot_id in here; set it in the context
|
||||
graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"]
|
||||
yield
|
||||
global nth_graph
|
||||
nth_graph += 1
|
||||
graph_being_compiled = []
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
global nth_graph
|
||||
nth_graph += 1
|
||||
graph_being_compiled = []
|
||||
|
||||
|
||||
def make_boxed_func(f):
|
||||
|
|
|
|||
|
|
@ -560,8 +560,10 @@ class Kernel(CodeGen):
|
|||
def set_current_node(self, node):
|
||||
prior = self.current_node
|
||||
self.current_node = node
|
||||
yield
|
||||
self.current_node = prior
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.current_node = prior
|
||||
|
||||
@contextlib.contextmanager
|
||||
def swap_buffers(self, lb, cb=None, sb=None):
|
||||
|
|
@ -575,11 +577,13 @@ class Kernel(CodeGen):
|
|||
self.compute = cb
|
||||
self.stores = sb
|
||||
self.cse = cse.clone()
|
||||
yield
|
||||
self.loads = loads
|
||||
self.compute = compute
|
||||
self.stores = stores
|
||||
self.cse = cse
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.loads = loads
|
||||
self.compute = compute
|
||||
self.stores = stores
|
||||
self.cse = cse
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -696,11 +696,13 @@ class TritonKernel(Kernel):
|
|||
# and write out a reduction loop
|
||||
self.codegen_body()
|
||||
self.inside_reduction = False
|
||||
yield
|
||||
if not self.persistent_reduction:
|
||||
# flush out any code before opening the next loop
|
||||
self.codegen_body()
|
||||
self.inside_reduction = True
|
||||
try:
|
||||
yield
|
||||
if not self.persistent_reduction:
|
||||
# flush out any code before opening the next loop
|
||||
self.codegen_body()
|
||||
finally:
|
||||
self.inside_reduction = True
|
||||
|
||||
return ctx()
|
||||
|
||||
|
|
@ -957,10 +959,12 @@ class TritonKernel(Kernel):
|
|||
mask = self.cse.generate(self.compute, f"{mask} & {prior}")
|
||||
|
||||
self._load_mask = mask
|
||||
with self.swap_buffers(self.compute, self.compute):
|
||||
# TODO(jansel): do we need a reshape here?
|
||||
yield mask
|
||||
self._load_mask = prior
|
||||
try:
|
||||
with self.swap_buffers(self.compute, self.compute):
|
||||
# TODO(jansel): do we need a reshape here?
|
||||
yield mask
|
||||
finally:
|
||||
self._load_mask = prior
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
var = self.args.input(name)
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ def end_graph():
|
|||
cur_file = inspect.stack()[1].filename
|
||||
print(f"SUMMARY ({cur_file})")
|
||||
print(
|
||||
f"{overall_time:.2f}ms\t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s"
|
||||
f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s"
|
||||
)
|
||||
print()
|
||||
|
||||
|
|
@ -250,10 +250,12 @@ class DebugAutotuner(CachingAutotuner):
|
|||
num_gb = get_num_bytes(*args) / 1e9
|
||||
gb_per_s = num_gb / (ms / 1e3)
|
||||
|
||||
collected_calls.append((kernel_name, ms, num_gb, gb_per_s))
|
||||
collected_calls.append((ms, num_gb, gb_per_s, kernel_name)),
|
||||
import colorama
|
||||
|
||||
info_str = f"{kernel_name}\t {ms:.3f}ms\t{num_gb:.3f} GB \t {gb_per_s:.2f}GB/s"
|
||||
info_str = (
|
||||
f"{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:.2f}GB/s \t {kernel_name}"
|
||||
)
|
||||
if ms > 0.012 and gb_per_s < 650:
|
||||
print(colorama.Fore.RED + info_str + colorama.Fore.RESET)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -444,8 +444,10 @@ class IndentedBuffer:
|
|||
@contextlib.contextmanager
|
||||
def ctx():
|
||||
self._indent += offset
|
||||
yield
|
||||
self._indent -= offset
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._indent -= offset
|
||||
|
||||
return ctx()
|
||||
|
||||
|
|
|
|||
|
|
@ -1243,8 +1243,10 @@ def _create_named_tuple(
|
|||
def _disable_emit_hooks():
|
||||
hooks = torch._C._jit_get_emit_hooks()
|
||||
torch._C._jit_set_emit_hooks(None, None)
|
||||
yield
|
||||
torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
|
||||
|
||||
|
||||
def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None: # noqa: F811
|
||||
|
|
|
|||
|
|
@ -84,5 +84,7 @@ def range(msg, *args, **kwargs):
|
|||
msg (str): message to associate with the range
|
||||
"""
|
||||
range_push(msg.format(*args, **kwargs))
|
||||
yield
|
||||
range_pop()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
range_pop()
|
||||
|
|
|
|||
|
|
@ -93,8 +93,10 @@ def redirect(std: str, to_file: str):
|
|||
|
||||
with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst:
|
||||
_redirect(dst)
|
||||
yield
|
||||
_redirect(orig_std)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_redirect(orig_std)
|
||||
|
||||
|
||||
redirect_stdout = partial(redirect, "stdout")
|
||||
|
|
|
|||
|
|
@ -69,5 +69,7 @@ def range(msg, *args, **kwargs):
|
|||
msg (str): message to associate with the range
|
||||
"""
|
||||
range_push(msg.format(*args, **kwargs))
|
||||
yield
|
||||
range_pop()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
range_pop()
|
||||
|
|
|
|||
|
|
@ -56,8 +56,10 @@ class SourceChangeWarning(Warning):
|
|||
@contextmanager
|
||||
def mkdtemp():
|
||||
path = tempfile.mkdtemp()
|
||||
yield path
|
||||
shutil.rmtree(path)
|
||||
try:
|
||||
yield path
|
||||
finally:
|
||||
shutil.rmtree(path)
|
||||
|
||||
|
||||
_package_registry = []
|
||||
|
|
|
|||
|
|
@ -1189,11 +1189,13 @@ def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True):
|
|||
c10d.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch._dynamo.reset()
|
||||
torch._dynamo.utils.counters.clear()
|
||||
yield
|
||||
torch._dynamo.reset()
|
||||
torch._dynamo.utils.counters.clear()
|
||||
if init_pg:
|
||||
c10d.destroy_process_group()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._dynamo.reset()
|
||||
torch._dynamo.utils.counters.clear()
|
||||
if init_pg:
|
||||
c10d.destroy_process_group()
|
||||
|
||||
|
||||
class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
|
||||
|
|
|
|||
Loading…
Reference in a new issue