Some modifications to improve readability (#31352)

Summary:
In long format strings, it is better for readability to give the placeholders names (use named fields with str.format instead of positional ones).

When creating a dict, a literal (`{...}`) is more readable and faster than the `dict()` constructor.

I always appreciate your efforts in creating the world's best frameworks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31352

Differential Revision: D19191967

Pulled By: ngimel

fbshipit-source-id: 21f063b163b67de8cf9761a4db5991f74318e991
This commit is contained in:
olramde 2020-01-02 12:45:14 -08:00 committed by Facebook Github Bot
parent 7078f4b27d
commit d770fbc1d2
3 changed files with 23 additions and 19 deletions

View file

@@ -85,9 +85,12 @@ def trainbench(name, rnn_creator, nloops=100, warmup=10,
return fwd_time, bwd_time
assert device == 'cuda'
creator_args = dict(seqLength=seqLength, numLayers=numLayers,
inputSize=inputSize, hiddenSize=hiddenSize,
miniBatch=miniBatch, device=device, seed=seed)
creator_args = creator_args = {
'seqLength': seqLength, 'numLayers': numLayers,
'inputSize': inputSize, 'hiddenSize': hiddenSize,
'miniBatch': miniBatch, 'device': device, 'seed': seed
}
modeldef = rnn_creator(**creator_args)
[train_batch(modeldef) for _ in range(warmup)]
@@ -217,7 +220,7 @@ if __name__ == '__main__':
del bench_args['cnns']
del bench_args['variable_lstms']
results = dict()
results = {}
if should_bench_varlen_lstms:
if args.nloops + args.warmup > 30:
print_stderr(

View file

@@ -161,10 +161,10 @@ def validate_cuda_device(location):
'to map your storages to the CPU.')
if device >= torch.cuda.device_count():
raise RuntimeError('Attempting to deserialize object on CUDA device '
'{} but torch.cuda.device_count() is {}. Please use '
'{device} but torch.cuda.device_count() is {device_count}. Please use '
'torch.load with map_location to map your storages '
'to an existing device.'.format(
device, torch.cuda.device_count()))
device=device, device_count=torch.cuda.device_count()))
return device
@@ -188,8 +188,8 @@ def location_tag(storage):
location = tagger(storage)
if location:
return location
raise RuntimeError("don't know how to determine data location of " +
torch.typename(storage))
raise RuntimeError("don't know how to determine data location of "
+ torch.typename(storage))
def default_restore_location(storage, location):
@@ -197,9 +197,9 @@ def default_restore_location(storage, location):
result = fn(storage, location)
if result is not None:
return result
raise RuntimeError("don't know how to restore data location of " +
torch.typename(storage) + " (tagged with " +
location + ")")
raise RuntimeError("don't know how to restore data location of "
+ torch.typename(storage) + " (tagged with "
+ location + ")")
def normalize_storage_type(storage_type):
@@ -320,9 +320,9 @@ def _check_seekable(f):
def raise_err_msg(patterns, e):
for p in patterns:
if p in str(e):
msg = (str(e) + ". You can only torch.load from a file that is seekable." +
" Please pre-load the data into a buffer like io.BytesIO and" +
" try to load from it instead.")
msg = (str(e) + ". You can only torch.load from a file that is seekable."
+ " Please pre-load the data into a buffer like io.BytesIO and"
+ " try to load from it instead.")
raise type(e)(msg)
raise e
@@ -656,8 +656,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
"accessing the object's source attribute or set "
"`torch.nn.Module.dump_patches = True` and use the "
"patch tool to revert the changes.")
msg = ("source code of class '{}' has changed. {}"
.format(torch.typename(container_type), msg))
msg = ("source code of class '{container_type}' has changed. {msg}"
.format(container_type=torch.typename(container_type), msg=msg))
warnings.warn(msg, SourceChangeWarning)
def legacy_load(f):
@@ -753,7 +753,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
except tarfile.TarError:
if _is_zipfile(f):
# .zip is used for torch.jit.save and will throw an un-pickling error here
raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))
raise RuntimeError(
"{filename} is a zip archive (did you mean to use torch.jit.load()?)".format(filename=f.name))
# if not a tarfile, reset file offset and proceed
f.seek(0)

View file

@@ -1,5 +1,5 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import collections
from collections import OrderedDict
import weakref
import warnings
@@ -25,7 +25,7 @@ class RemovableHandle(object):
def __setstate__(self, state):
if state[0] is None:
# create a dead reference
self.hooks_dict_ref = weakref.ref(collections.OrderedDict())
self.hooks_dict_ref = weakref.ref(OrderedDict())
else:
self.hooks_dict_ref = weakref.ref(state[0])
self.id = state[1]