enable fp16 autocast for dynamo benchmark (#114088)
`--amp` enables the AMP path for `CUDA` (the default amp_dtype is float16) and for `CPU` (the default amp_dtype is bfloat16). If the user passes `--amp-dtype`, the user-supplied amp_dtype takes the highest priority.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114088
Approved by: https://github.com/jgong5, https://github.com/jansel
parent afe6d272c6
commit 6500ccebd7
4 changed files with 47 additions and 32 deletions
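Usage sketch (the model name and the --performance/--inference flags below are ordinary harness options used for illustration, not part of this commit): with the new flag, a CPU run can request float16 autocast explicitly instead of the bfloat16 default.

    python benchmarks/dynamo/torchbench.py --performance --inference \
        --devices cpu --amp --amp-dtype float16 --only resnet50

Omitting --amp-dtype keeps the per-device defaults: bfloat16 autocast on CPU, float16 on CUDA.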
@@ -1759,38 +1759,48 @@ class BenchmarkRunner:
         self.model_iter_fn = None
         self.grad_scaler = DummyGradScaler()
         self.autocast = contextlib.nullcontext
+        self.autocast_arg = {}
         self.optimizer = None
         self._args = None
 
-    def setup_amp(self):
+    def setup_amp(self, current_device=None):
         if self.args.only in self.fp32_only_models:
             return
 
-        if self.args.amp and self.args.devices == ["cuda"]:
-            # AMP training can lead to small loss values which can undeflow
-            # gradient values returning in zero gradients. To solve this
-            # problem, PyTorch introduces GradScaler. GradScaler is a stateful
-            # structure, that scales the loss values to prevent underflow. Loss
-            # values are big at the beginning of training (therefore not
-            # requiring scaling), while loss value tends to be small as network
-            # starts getting better (requiring scaling). GradScaler manages all
-            # of this fine tuning, checking the gradients are turning to inf,
-            # discarding such batches.
+        devices = [current_device] if current_device else self.args.devices
+        if self.args.amp:
+            if devices == ["cuda"]:
+                # AMP training can lead to small loss values which can undeflow
+                # gradient values returning in zero gradients. To solve this
+                # problem, PyTorch introduces GradScaler. GradScaler is a stateful
+                # structure, that scales the loss values to prevent underflow. Loss
+                # values are big at the beginning of training (therefore not
+                # requiring scaling), while loss value tends to be small as network
+                # starts getting better (requiring scaling). GradScaler manages all
+                # of this fine tuning, checking the gradients are turning to inf,
+                # discarding such batches.
 
-            # Since we are not running a long iteration, default value of
-            # init_scale 65536 is going to turn all gradients to inf. Therefore,
-            # we just use a init_scale of 2.0 for benchmarking purpose.
+                # Since we are not running a long iteration, default value of
+                # init_scale 65536 is going to turn all gradients to inf. Therefore,
+                # we just use a init_scale of 2.0 for benchmarking purpose.
 
-            # Disabling Gradscaler because
-            # 1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
-            # 2) Current setup shares grad_scaler for eager and dynamo model,
-            # which is bad as Gradscaler has state and can adjust the scaling
-            # factor between eager and dynamo run, making accuracy check
-            # harder.
-            # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
-            self.autocast = torch.cuda.amp.autocast
-        elif (self.args.bfloat16 or self.args.amp) and self.args.devices == ["cpu"]:
-            self.autocast = torch.cpu.amp.autocast
+                # Disabling Gradscaler because
+                # 1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
+                # 2) Current setup shares grad_scaler for eager and dynamo model,
+                # which is bad as Gradscaler has state and can adjust the scaling
+                # factor between eager and dynamo run, making accuracy check
+                # harder.
+                # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
+                self.autocast = torch.cuda.amp.autocast
+            if devices == ["cpu"]:
+                self.autocast = torch.cpu.amp.autocast
+            if self.args.amp_dtype:
+                amp_dtype = (
+                    torch.float16
+                    if self.args.amp_dtype == "float16"
+                    else torch.bfloat16
+                )
+                self.autocast_arg["dtype"] = amp_dtype
 
     def init_optimizer(self, name, device, params):
         if device == "cuda" and self.args.training and name not in CI_SKIP_OPTIMIZER:
@@ -2232,7 +2242,7 @@ class BenchmarkRunner:
             # apply export on module directly
             # no need for n iterations
             # the logic should be the same to self.model_iter_fn (forward_pass)
-            with self.autocast():
+            with self.autocast(**self.autocast_arg):
                 optimized_model_iter_fn = optimize_ctx(
                     model_copy, example_inputs
                 )
@@ -2968,7 +2978,11 @@ def parse_args(args=None):
     group_prec.add_argument(
         "--amp", action="store_true", help="use automatic mixed precision"
     )
+    parser.add_argument(
+        "--amp-dtype",
+        choices=("bfloat16", "float16"),
+        help="the data type used with automatic mixed precision",
+    )
     group_printout = parser.add_mutually_exclusive_group()
     group_printout.add_argument(
         "--verbose", "-v", action="store_true", help="enable verbose debug printouts"
@@ -3654,6 +3668,7 @@ def run(runner, args, original_dir=None):
 
             else:
                 model, example_inputs = runner.cast_based_on_args(model, example_inputs)
+            runner.setup_amp(current_device)
             runner.run_one_model(
                 name,
                 model,
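The hunks above rework BenchmarkRunner.setup_amp to be device-aware and to thread an optional dtype through self.autocast_arg. As a reading aid, here is a standalone sketch of that selection logic (the make_autocast helper is hypothetical, not part of the harness): the per-device autocast context is chosen first, and an explicit amp_dtype then overrides the dtype handed to it.

    import contextlib
    import torch

    def make_autocast(device, amp=True, amp_dtype=None):
        # Hypothetical helper mirroring the patched setup_amp: pick the autocast
        # context for the current device plus an optional dtype override.
        if not amp:
            return contextlib.nullcontext, {}
        autocast = torch.cuda.amp.autocast if device == "cuda" else torch.cpu.amp.autocast
        autocast_arg = {}
        if amp_dtype is not None:
            # A user-supplied --amp-dtype wins over the per-device default
            # (float16 on CUDA, bfloat16 on CPU).
            autocast_arg["dtype"] = (
                torch.float16 if amp_dtype == "float16" else torch.bfloat16
            )
        return autocast, autocast_arg

    autocast, autocast_arg = make_autocast("cpu", amp=True, amp_dtype="float16")
    with autocast(**autocast_arg):
        out = torch.nn.Linear(16, 16)(torch.randn(4, 16))
    print(out.dtype)  # expected: torch.float16 where CPU fp16 autocast is supported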
@@ -553,13 +553,13 @@ class HuggingfaceRunner(BenchmarkRunner):
         return pred[0]
 
     def forward_pass(self, mod, inputs, collect_outputs=True):
-        with self.autocast():
+        with self.autocast(**self.autocast_arg):
             return mod(**inputs)
 
     def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
         cloned_inputs = clone_inputs(inputs)
         self.optimizer_zero_grad(mod)
-        with self.autocast():
+        with self.autocast(**self.autocast_arg):
             pred = mod(**cloned_inputs)
             loss = self.compute_loss(pred)
             self.grad_scaler.scale(loss).backward()
@@ -332,13 +332,13 @@ class TimmRunner(BenchmarkRunner):
         return reduce_to_scalar_loss(pred) / 1000.0
 
     def forward_pass(self, mod, inputs, collect_outputs=True):
-        with self.autocast():
+        with self.autocast(**self.autocast_arg):
             return mod(*inputs)
 
     def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
         cloned_inputs = clone_inputs(inputs)
         self.optimizer_zero_grad(mod)
-        with self.autocast():
+        with self.autocast(**self.autocast_arg):
             pred = mod(*cloned_inputs)
             if isinstance(pred, tuple):
                 pred = pred[0]
@@ -528,13 +528,13 @@ class TorchBenchmarkRunner(BenchmarkRunner):
         return reduce_to_scalar_loss(pred)
 
     def forward_pass(self, mod, inputs, collect_outputs=True):
-        with self.autocast():
+        with self.autocast(**self.autocast_arg):
             return mod(*inputs)
 
     def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
         cloned_inputs = clone_inputs(inputs)
         self.optimizer_zero_grad(mod)
-        with self.autocast():
+        with self.autocast(**self.autocast_arg):
             pred = mod(*cloned_inputs)
             loss = self.compute_loss(pred)
             self.grad_scaler.scale(loss).backward()
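Finally, a short runtime check of what the forwarded autocast_arg changes on CPU (assumes a PyTorch build where CPU float16 autocast is supported; the module and shapes are arbitrary):

    import torch

    model = torch.nn.Linear(32, 32)
    x = torch.randn(8, 32)

    # Default CPU AMP dtype: bfloat16.
    with torch.cpu.amp.autocast():
        print(model(x).dtype)  # torch.bfloat16

    # Explicit override, as the harness now forwards via autocast_arg["dtype"]:
    with torch.cpu.amp.autocast(dtype=torch.float16):
        print(model(x).dtype)  # torch.float16 on builds with CPU fp16 autocast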