diff --git a/benchmarks/fastrnns/bench.py b/benchmarks/fastrnns/bench.py index 433a4309f69..1984d280d5c 100644 --- a/benchmarks/fastrnns/bench.py +++ b/benchmarks/fastrnns/bench.py @@ -85,9 +85,12 @@ def trainbench(name, rnn_creator, nloops=100, warmup=10, return fwd_time, bwd_time assert device == 'cuda' - creator_args = dict(seqLength=seqLength, numLayers=numLayers, - inputSize=inputSize, hiddenSize=hiddenSize, - miniBatch=miniBatch, device=device, seed=seed) + creator_args = creator_args = { + 'seqLength': seqLength, 'numLayers': numLayers, + 'inputSize': inputSize, 'hiddenSize': hiddenSize, + 'miniBatch': miniBatch, 'device': device, 'seed': seed + } + modeldef = rnn_creator(**creator_args) [train_batch(modeldef) for _ in range(warmup)] @@ -217,7 +220,7 @@ if __name__ == '__main__': del bench_args['cnns'] del bench_args['variable_lstms'] - results = dict() + results = {} if should_bench_varlen_lstms: if args.nloops + args.warmup > 30: print_stderr( diff --git a/torch/serialization.py b/torch/serialization.py index 1e41b22e1ff..40236b258b0 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -161,10 +161,10 @@ def validate_cuda_device(location): 'to map your storages to the CPU.') if device >= torch.cuda.device_count(): raise RuntimeError('Attempting to deserialize object on CUDA device ' - '{} but torch.cuda.device_count() is {}. Please use ' + '{device} but torch.cuda.device_count() is {device_count}. Please use ' 'torch.load with map_location to map your storages ' 'to an existing device.'.format( - device, torch.cuda.device_count())) + device=device, device_count=torch.cuda.device_count())) return device @@ -188,8 +188,8 @@ def location_tag(storage): location = tagger(storage) if location: return location - raise RuntimeError("don't know how to determine data location of " + - torch.typename(storage)) + raise RuntimeError("don't know how to determine data location of " + + torch.typename(storage)) def default_restore_location(storage, location): @@ -197,9 +197,9 @@ def default_restore_location(storage, location): result = fn(storage, location) if result is not None: return result - raise RuntimeError("don't know how to restore data location of " + - torch.typename(storage) + " (tagged with " + - location + ")") + raise RuntimeError("don't know how to restore data location of " + + torch.typename(storage) + " (tagged with " + + location + ")") def normalize_storage_type(storage_type): @@ -320,9 +320,9 @@ def _check_seekable(f): def raise_err_msg(patterns, e): for p in patterns: if p in str(e): - msg = (str(e) + ". You can only torch.load from a file that is seekable." + - " Please pre-load the data into a buffer like io.BytesIO and" + - " try to load from it instead.") + msg = (str(e) + ". You can only torch.load from a file that is seekable." + + " Please pre-load the data into a buffer like io.BytesIO and" + + " try to load from it instead.") raise type(e)(msg) raise e @@ -656,8 +656,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args): "accessing the object's source attribute or set " "`torch.nn.Module.dump_patches = True` and use the " "patch tool to revert the changes.") - msg = ("source code of class '{}' has changed. {}" - .format(torch.typename(container_type), msg)) + msg = ("source code of class '{container_type}' has changed. {msg}" + .format(container_type=torch.typename(container_type), msg=msg)) warnings.warn(msg, SourceChangeWarning) def legacy_load(f): @@ -753,7 +753,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args): except tarfile.TarError: if _is_zipfile(f): # .zip is used for torch.jit.save and will throw an un-pickling error here - raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name)) + raise RuntimeError( + "{filename} is a zip archive (did you mean to use torch.jit.load()?)".format(filename=f.name)) # if not a tarfile, reset file offset and proceed f.seek(0) diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index 19a7363906c..0a2836da4b9 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -1,5 +1,5 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import collections +from collections import OrderedDict import weakref import warnings @@ -25,7 +25,7 @@ class RemovableHandle(object): def __setstate__(self, state): if state[0] is None: # create a dead reference - self.hooks_dict_ref = weakref.ref(collections.OrderedDict()) + self.hooks_dict_ref = weakref.ref(OrderedDict()) else: self.hooks_dict_ref = weakref.ref(state[0]) self.id = state[1]