Fix usages of contextmanager without finally (#96170)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96170
Approved by: https://github.com/ngimel, https://github.com/malfet
Horace He 2023-03-07 23:27:01 +00:00 committed by PyTorch MergeBot
parent 34d18c8bee
commit 5bbec680d7
15 changed files with 91 additions and 59 deletions
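With @contextlib.contextmanager, any code written after a bare yield only runs if the body of the with block finishes without raising: an exception skips the clean-up entirely and leaves whatever state the manager set up still in place. The fix applied throughout this commit is to wrap the yield in try/finally so the restore code runs unconditionally. A minimal sketch of the pattern (the set_flag helper below is illustrative, not taken from the diff):

from contextlib import contextmanager

_flag = False

@contextmanager
def set_flag():
    global _flag
    prior = _flag
    _flag = True
    try:
        yield              # the body of the `with` block runs here
    finally:
        _flag = prior      # restored even if the body raised

Without the try/finally, a `with set_flag():` block whose body raises would leave _flag stuck at True for the rest of the process.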

View file

@@ -228,14 +228,15 @@ def cuda_pointwise_context(loop_levels, block_count, block_size):
     old_block_size = torch._C._jit_get_te_cuda_pointwise_block_size()
     torch._C._jit_set_te_cuda_pointwise_block_size(block_size)
-    yield
-    if loop_levels:
-        torch._C._jit_set_te_cuda_pointwise_loop_levels(old_loop_levels)
-    if block_count:
-        torch._C._jit_set_te_cuda_pointwise_block_count(old_block_count)
-    if block_size:
-        torch._C._jit_set_te_cuda_pointwise_block_size(old_block_size)
+    try:
+        yield
+    finally:
+        if loop_levels:
+            torch._C._jit_set_te_cuda_pointwise_loop_levels(old_loop_levels)
+        if block_count:
+            torch._C._jit_set_te_cuda_pointwise_block_count(old_block_count)
+        if block_size:
+            torch._C._jit_set_te_cuda_pointwise_block_size(old_block_size)

 # Auxiliary class to facilitate dynamic input shape
 class DynamicShape:

View file

@@ -65,8 +65,10 @@ class Analyzer(Visitor):
         if do_copy:
             ws = copy(ws)
         self.workspace_ctx.append(ws)
-        yield ws
-        del self.workspace_ctx[-1]
+        try:
+            yield ws
+        finally:
+            del self.workspace_ctx[-1]

     def define_blob(self, blob):
         self.workspace[blob] += 1
@@ -166,12 +168,14 @@ class Text:
             self.add('with %s:' % text)
             self._indent += 4
         self._lines_in_context.append(0)
-        yield
-        if text is not None:
-            if self._lines_in_context[-1] == 0:
-                self.add('pass')
-            self._indent -= 4
-        del self._lines_in_context[-1]
+        try:
+            yield
+        finally:
+            if text is not None:
+                if self._lines_in_context[-1] == 0:
+                    self.add('pass')
+                self._indent -= 4
+            del self._lines_in_context[-1]

     def add(self, text):
         self._lines_in_context[-1] += 1

View file

@@ -811,8 +811,10 @@ class TestFFT(TestCase):
                 plan_cache = torch.backends.cuda.cufft_plan_cache[device]
             original = plan_cache.max_size
             plan_cache.max_size = n
-            yield
-            plan_cache.max_size = original
+            try:
+                yield
+            finally:
+                plan_cache.max_size = original

         with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)):
             self._test_fft_ifft_rfft_irfft(devices[0], dtype)

View file

@@ -1012,7 +1012,6 @@ def disable_cache_limit():
     try:
         yield
     finally:
-        pass
         config.cache_size_limit = prior

View file

@@ -1228,10 +1228,12 @@ def track_graph_compiling(aot_config, graph_name):
     global graph_being_compiled
     # TODO: Don't shove the aot_id in here; set it in the context
     graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"]
-    yield
-    global nth_graph
-    nth_graph += 1
-    graph_being_compiled = []
+    try:
+        yield
+    finally:
+        global nth_graph
+        nth_graph += 1
+        graph_being_compiled = []

 def make_boxed_func(f):

View file

@@ -560,8 +560,10 @@ class Kernel(CodeGen):
     def set_current_node(self, node):
         prior = self.current_node
         self.current_node = node
-        yield
-        self.current_node = prior
+        try:
+            yield
+        finally:
+            self.current_node = prior

     @contextlib.contextmanager
     def swap_buffers(self, lb, cb=None, sb=None):
@@ -575,11 +577,13 @@ class Kernel(CodeGen):
         self.compute = cb
         self.stores = sb
         self.cse = cse.clone()
-        yield
-        self.loads = loads
-        self.compute = compute
-        self.stores = stores
-        self.cse = cse
+        try:
+            yield
+        finally:
+            self.loads = loads
+            self.compute = compute
+            self.stores = stores
+            self.cse = cse

     def load(self, name: str, index: sympy.Expr):
         raise NotImplementedError()

View file

@@ -696,11 +696,13 @@ class TritonKernel(Kernel):
             # and write out a reduction loop
             self.codegen_body()
             self.inside_reduction = False
-            yield
-            if not self.persistent_reduction:
-                # flush out any code before opening the next loop
-                self.codegen_body()
-            self.inside_reduction = True
+            try:
+                yield
+                if not self.persistent_reduction:
+                    # flush out any code before opening the next loop
+                    self.codegen_body()
+            finally:
+                self.inside_reduction = True

         return ctx()
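Note the split in the hunk above: the flush (codegen_body()) that should only happen when the body completes normally stays inside the try after the yield, while restoring inside_reduction moves to finally so it runs even on an exception. A small generic sketch of that success-only vs. always split (the names are illustrative, not from the kernel code):

from contextlib import contextmanager

@contextmanager
def phase(state):
    state["active"] = True
    try:
        yield
        state["flushed"] = True    # success-only step, like codegen_body() above
    finally:
        state["active"] = False    # always restored, like inside_reduction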
@@ -957,10 +959,12 @@ class TritonKernel(Kernel):
             mask = self.cse.generate(self.compute, f"{mask} & {prior}")
         self._load_mask = mask
-        with self.swap_buffers(self.compute, self.compute):
-            # TODO(jansel): do we need a reshape here?
-            yield mask
-        self._load_mask = prior
+        try:
+            with self.swap_buffers(self.compute, self.compute):
+                # TODO(jansel): do we need a reshape here?
+                yield mask
+        finally:
+            self._load_mask = prior

     def load(self, name: str, index: sympy.Expr):
         var = self.args.input(name)

View file

@@ -228,7 +228,7 @@ def end_graph():
     cur_file = inspect.stack()[1].filename
     print(f"SUMMARY ({cur_file})")
     print(
-        f"{overall_time:.2f}ms\t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s"
+        f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s"
     )
     print()
@@ -250,10 +250,12 @@ class DebugAutotuner(CachingAutotuner):
         num_gb = get_num_bytes(*args) / 1e9
         gb_per_s = num_gb / (ms / 1e3)
-        collected_calls.append((kernel_name, ms, num_gb, gb_per_s))
+        collected_calls.append((ms, num_gb, gb_per_s, kernel_name)),
         import colorama
-        info_str = f"{kernel_name}\t {ms:.3f}ms\t{num_gb:.3f} GB \t {gb_per_s:.2f}GB/s"
+        info_str = (
+            f"{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:.2f}GB/s \t {kernel_name}"
+        )
         if ms > 0.012 and gb_per_s < 650:
             print(colorama.Fore.RED + info_str + colorama.Fore.RESET)
         else:
View file

@@ -444,8 +444,10 @@ class IndentedBuffer:
         @contextlib.contextmanager
         def ctx():
             self._indent += offset
-            yield
-            self._indent -= offset
+            try:
+                yield
+            finally:
+                self._indent -= offset

         return ctx()

View file

@@ -1243,8 +1243,10 @@ def _create_named_tuple(
 def _disable_emit_hooks():
     hooks = torch._C._jit_get_emit_hooks()
     torch._C._jit_set_emit_hooks(None, None)
-    yield
-    torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
+    try:
+        yield
+    finally:
+        torch._C._jit_set_emit_hooks(hooks[0], hooks[1])

 def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None: # noqa: F811

View file

@@ -84,5 +84,7 @@ def range(msg, *args, **kwargs):
         msg (str): message to associate with the range
     """
     range_push(msg.format(*args, **kwargs))
-    yield
-    range_pop()
+    try:
+        yield
+    finally:
+        range_pop()

View file

@@ -93,8 +93,10 @@ def redirect(std: str, to_file: str):
     with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst:
         _redirect(dst)
-        yield
-        _redirect(orig_std)
+        try:
+            yield
+        finally:
+            _redirect(orig_std)

 redirect_stdout = partial(redirect, "stdout")

View file

@@ -69,5 +69,7 @@ def range(msg, *args, **kwargs):
         msg (str): message to associate with the range
     """
     range_push(msg.format(*args, **kwargs))
-    yield
-    range_pop()
+    try:
+        yield
+    finally:
+        range_pop()

View file

@@ -56,8 +56,10 @@ class SourceChangeWarning(Warning):
 @contextmanager
 def mkdtemp():
     path = tempfile.mkdtemp()
-    yield path
-    shutil.rmtree(path)
+    try:
+        yield path
+    finally:
+        shutil.rmtree(path)

 _package_registry = []
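The serialization hunk above shows the practical effect: previously, an exception raised inside the with block skipped shutil.rmtree and leaked the temporary directory; with try/finally the directory is always removed. A standalone illustration (it re-declares a local mkdtemp mirroring the fixed helper rather than importing it):

import os
import shutil
import tempfile
from contextlib import contextmanager

@contextmanager
def mkdtemp():
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)         # runs even if the body raises

try:
    with mkdtemp() as path:
        raise RuntimeError("boom")  # simulate a failure inside the block
except RuntimeError:
    pass
assert not os.path.exists(path)     # the directory was still cleaned up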

View file

@@ -1189,11 +1189,13 @@ def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True):
         c10d.init_process_group("nccl", rank=rank, world_size=world_size)
     torch._dynamo.reset()
     torch._dynamo.utils.counters.clear()
-    yield
-    torch._dynamo.reset()
-    torch._dynamo.utils.counters.clear()
-    if init_pg:
-        c10d.destroy_process_group()
+    try:
+        yield
+    finally:
+        torch._dynamo.reset()
+        torch._dynamo.utils.counters.clear()
+        if init_pg:
+            c10d.destroy_process_group()

 class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):