mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
legacy_megatron-lm/deepspeed_ZERO1&2 FP16_Optimizer wrapper (#9184)
* megatron-lm FP16_Optimizer Wrap, allow model parallelism aggregation optional * add deepspeed zero1 and zero2 - checkoverflow & clip norm * re-structure code and add the copyright * update the document * refine the code after validation
This commit is contained in:
parent
4771256be3
commit
5ee47e3ffa
5 changed files with 460 additions and 0 deletions
152
orttraining/orttraining/python/training/optim/_ds_modifier.py
Normal file
152
orttraining/orttraining/python/training/optim/_ds_modifier.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
# Copyright 2020 The Microsoft DeepSpeed Team
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION.
|
||||
# Some functions in this file are adapted from following sources:
|
||||
# - has_overflow_serial : https://github.com/microsoft/DeepSpeed/blob/d8e9ef6f99e27bb95e10bd146d145b3372b4cfda/deepspeed/runtime/zero/stage2.py#L1792
|
||||
# - get_grad_norm_direct : https://github.com/microsoft/DeepSpeed/blob/d8e9ef6f99e27bb95e10bd146d145b3372b4cfda/deepspeed/runtime/zero/stage2.py#L1466
|
||||
# - has_overflow_partitioned_grads_serial : https://github.com/microsoft/DeepSpeed/blob/d8e9ef6f99e27bb95e10bd146d145b3372b4cfda/deepspeed/runtime/zero/stage2.py#L1799
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import types
|
||||
import warnings
|
||||
from distutils.version import LooseVersion
|
||||
from numpy import inf
|
||||
|
||||
from ._modifier import FP16OptimizerModifier, check_overflow, check_overflow_for_grads
|
||||
from .multi_tensor_apply import MultiTensorApply
|
||||
multi_tensor_applier = MultiTensorApply(2048 * 32)
|
||||
|
||||
class DeepSpeedZeROModifier(FP16OptimizerModifier):
|
||||
def __init__(self, optimizer, **kwargs) -> None:
|
||||
super().__init__(optimizer)
|
||||
|
||||
def can_be_modified(self):
|
||||
try:
|
||||
import deepspeed
|
||||
v = LooseVersion(deepspeed.__version__)
|
||||
if v > LooseVersion("0.5.4") or v < LooseVersion("0.4.0"):
|
||||
warnings.warn('Unsupported DeepSpeed version to override, skipped.', UserWarning)
|
||||
return False
|
||||
except Exception as _:
|
||||
return False
|
||||
|
||||
return self.check_requirements(["has_overflow_serial", "get_grad_norm_direct", "has_overflow_partitioned_grads_serial"],
|
||||
require_apex=True, require_torch_non_finite_check=True)
|
||||
|
||||
def override_function(self):
|
||||
warnings.warn('DeepSpeed fp16_optimizer functions are overrided with faster implementation.', UserWarning)
|
||||
def get_grad_norm_direct(target, gradients, params, norm_type=2):
|
||||
import amp_C
|
||||
def is_model_parallel_parameter(p):
|
||||
return hasattr(p, 'model_parallel') and p.model_parallel
|
||||
|
||||
norm_type = float(norm_type)
|
||||
if norm_type == inf:
|
||||
total_norm = max(g.data.abs().max() for g in gradients)
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
torch.distributed.all_reduce(total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=target.dp_process_group)
|
||||
|
||||
# Take max across all GPUs.
|
||||
target._model_parallel_all_reduce(tensor=total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.MAX)
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
total_norm = 0.0
|
||||
|
||||
#### THIS IS THE ORIGINAL IMPLEMENTATION ####
|
||||
# # if dist.get_rank() == 0:
|
||||
# # logger.info(f"Total Norm beginning {total_norm}")
|
||||
# for g, p in zip(gradients, params):
|
||||
# # Pipeline parallelism may replicate parameters. Avoid multi-counting.
|
||||
# if hasattr(p, 'ds_pipe_replicated') and p.ds_pipe_replicated:
|
||||
# continue
|
||||
# if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
|
||||
# param_norm = g.data.double().norm(2)
|
||||
# total_norm += param_norm.item()**2
|
||||
# # Sum across all model parallel GPUs.
|
||||
# total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
#### END OF THE ORIGINAL IMPLEMENTATION ####
|
||||
|
||||
#### THIS IS THE FASTER IMPLEMENTATION ####
|
||||
grads_for_norm = []
|
||||
for g, p in zip(gradients, params):
|
||||
if is_model_parallel_parameter(p) or (target.model_parallel_rank == 0):
|
||||
# BE NOTED: deepspeed original give a double type conversion here, not sure whether this is impacting some models.
|
||||
# https://github.com/microsoft/DeepSpeed/blob/9e5c0c5c3ecabb68b7e9dffac0e9b8d167e3cab8/deepspeed/runtime/zero/stage2.py#L1501
|
||||
# grads_for_norm.append(g.data.double())
|
||||
grads_for_norm.append(g.data)
|
||||
|
||||
if len(grads_for_norm) > 0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
# Use apex's multi-tensor applier for efficiency reasons.
|
||||
# Multi-tensor applier takes a function and a list of list
|
||||
# and performs the operation on that list all in one kernel.
|
||||
grad_norm, _ = multi_tensor_applier(
|
||||
amp_C.multi_tensor_l2norm,
|
||||
dummy_overflow_buf,
|
||||
[grads_for_norm],
|
||||
False # no per-parameter norm
|
||||
)
|
||||
# Since we will be summing across data parallel groups,
|
||||
# we need the pow(norm-type).
|
||||
total_norm_cuda = grad_norm ** norm_type
|
||||
else:
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
#### END OF THE FASTER IMPLEMENTATION ####
|
||||
|
||||
# Sum across all model parallel GPUs.
|
||||
torch.distributed.all_reduce(total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=target.dp_process_group)
|
||||
|
||||
target._model_parallel_all_reduce(tensor=total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.SUM)
|
||||
|
||||
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
|
||||
|
||||
if total_norm == float(
|
||||
'inf') or total_norm == -float('inf') or total_norm != total_norm:
|
||||
total_norm = -1
|
||||
|
||||
return total_norm
|
||||
|
||||
def has_overflow_serial(target, params, is_grad_list=False):
|
||||
#### THIS IS THE ORIGINAL IMPLEMENTATION ####
|
||||
# for p in params:
|
||||
# if p.grad is not None and self._has_inf_or_nan(p.grad.data):
|
||||
# return True
|
||||
#
|
||||
# return False
|
||||
#### END OF THE ORIGINAL IMPLEMENTATION ####
|
||||
|
||||
#### THIS IS THE FASTER IMPLEMENTATION ####
|
||||
return check_overflow(params)
|
||||
#### END OF THE FASTER IMPLEMENTATION ####
|
||||
|
||||
def has_overflow_partitioned_grads_serial(target):
|
||||
#### THIS IS THE ORIGINAL IMPLEMENTATION ####
|
||||
# for i in range(len(self.fp16_groups)):
|
||||
# for j, grad in enumerate(self.averaged_gradients[i]):
|
||||
# if grad is not None and self._has_inf_or_nan(grad.data, j):
|
||||
# return True
|
||||
# return False
|
||||
#### END OF THE ORIGINAL IMPLEMENTATION ####
|
||||
|
||||
#### THIS IS THE FASTER IMPLEMENTATION ####
|
||||
for i in range(len(target.fp16_groups)):
|
||||
grad_data = [grad.data for grad in target.averaged_gradients[i] if grad is not None]
|
||||
if check_overflow_for_grads(grad_data):
|
||||
return True
|
||||
return False
|
||||
#### END OF THE FASTER IMPLEMENTATION ####
|
||||
|
||||
self._optimizer.has_overflow_serial = types.MethodType(has_overflow_serial, self._optimizer)
|
||||
self._optimizer.get_grad_norm_direct = types.MethodType(get_grad_norm_direct, self._optimizer)
|
||||
# zero1 should not call into following function, is this a deepspeed bug?
|
||||
self._optimizer.has_overflow_partitioned_grads_serial = types.MethodType(has_overflow_partitioned_grads_serial, self._optimizer)
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
# Copyright 2020 The Microsoft DeepSpeed Team
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION.
|
||||
# Some functions in this file are adapted from following sources:
|
||||
# - _check_overflow : https://github.com/microsoft/DeepSpeedExamples/blob/590364d482b592c3a8a44c28141a8139c7918c55/Megatron-LM-v1.1.5-ZeRO3/megatron/fp16/fp16.py#L294
|
||||
# - clip_master_grads : https://github.com/microsoft/DeepSpeedExamples/blob/590364d482b592c3a8a44c28141a8139c7918c55/Megatron-LM-v1.1.5-ZeRO3/megatron/fp16/fp16.py#L332
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import types
|
||||
import warnings
|
||||
from numpy import inf
|
||||
from ._modifier import FP16OptimizerModifier, check_overflow, clip_grad_norm_fp32
|
||||
|
||||
class LegacyMegatronLMModifier(FP16OptimizerModifier):
|
||||
def __init__(self, optimizer, **kwargs) -> None:
|
||||
super().__init__(optimizer)
|
||||
self.get_horizontal_model_parallel_rank = kwargs.get("get_horizontal_model_parallel_rank", None)
|
||||
self.get_horizontal_model_parallel_group = kwargs.get("get_horizontal_model_parallel_group", None)
|
||||
|
||||
def can_be_modified(self):
|
||||
return self.check_requirements(["_check_overflow", "clip_master_grads"],
|
||||
require_apex=True, require_torch_non_finite_check=True)
|
||||
|
||||
def override_function(self):
|
||||
warnings.warn('Megatron-LM fp16_optimizer functions are overrided with faster implementation.', UserWarning)
|
||||
def clip_master_grads(target, max_norm, norm_type=2):
|
||||
"""
|
||||
Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.
|
||||
|
||||
Args:
|
||||
max_norm (float or int): max norm of the gradients
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
|
||||
infinity norm.
|
||||
|
||||
Returns:
|
||||
Total norm of the current fp32 gradients (viewed as a single vector).
|
||||
|
||||
.. warning::
|
||||
Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``).
|
||||
"""
|
||||
if not target.overflow:
|
||||
fp32_params = []
|
||||
for param_group in target.optimizer.param_groups:
|
||||
for param in param_group['params']:
|
||||
fp32_params.append(param)
|
||||
#### THIS IS THE ORIGINAL IMPLEMENTATION ####
|
||||
#return self.clip_grad_norm(fp32_params, max_norm, norm_type)
|
||||
#### END OF THE ORIGINAL IMPLEMENTATION ####
|
||||
|
||||
#### THIS IS THE FASTER IMPLEMENTATION ####
|
||||
return clip_grad_norm_fp32(fp32_params, max_norm, norm_type,
|
||||
get_horizontal_model_parallel_rank=self.get_horizontal_model_parallel_rank,
|
||||
get_horizontal_model_parallel_group=self.get_horizontal_model_parallel_group)
|
||||
#### END OF THE FASTER IMPLEMENTATION ####
|
||||
else:
|
||||
return -1
|
||||
|
||||
def _check_overflow(target):
|
||||
params = []
|
||||
for group in target.fp16_groups:
|
||||
for param in group:
|
||||
params.append(param)
|
||||
for group in target.fp32_from_fp32_groups:
|
||||
for param in group:
|
||||
params.append(param)
|
||||
#### THIS IS THE ORIGINAL IMPLEMENTATION ####
|
||||
# self.overflow = self.loss_scaler.has_overflow(params)
|
||||
#### END OF THE ORIGINAL IMPLEMENTATION ####
|
||||
|
||||
#### THIS IS THE FASTER IMPLEMENTATION ####
|
||||
target.overflow = check_overflow(params)
|
||||
#### END OF THE FASTER IMPLEMENTATION ####
|
||||
return target.overflow
|
||||
|
||||
self._optimizer._check_overflow = types.MethodType(_check_overflow, self._optimizer)
|
||||
self._optimizer.clip_master_grads = types.MethodType(clip_master_grads, self._optimizer)
|
||||
137
orttraining/orttraining/python/training/optim/_modifier.py
Normal file
137
orttraining/orttraining/python/training/optim/_modifier.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION.
|
||||
# Some functions in this file are adapted from following sources:
|
||||
# - clip_grad_norm_fp32 : https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/optimizer/clip_grads.py#L29
|
||||
# - check_overflow_for_grads : https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/optimizer/optimizer.py#L341
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
from numpy import inf
|
||||
from .multi_tensor_apply import MultiTensorApply
|
||||
multi_tensor_applier = MultiTensorApply(2048 * 32)
|
||||
|
||||
class FP16OptimizerModifier(object):
|
||||
def __init__(self, optimizer) -> None:
|
||||
super().__init__()
|
||||
self._optimizer = optimizer
|
||||
|
||||
def apply(self):
|
||||
if self.can_be_modified():
|
||||
self.override_function()
|
||||
|
||||
def check_requirements(self, required_funcs, require_apex=False, require_torch_non_finite_check=False):
|
||||
try:
|
||||
if require_apex:
|
||||
import amp_C
|
||||
if require_torch_non_finite_check:
|
||||
_ = torch._amp_foreach_non_finite_check_and_unscale_
|
||||
except Exception as _:
|
||||
return False
|
||||
|
||||
if not required_funcs:
|
||||
for func_name in required_funcs:
|
||||
func = getattr(self._optimizer, func_name, None)
|
||||
if not func or not callable(func):
|
||||
return False
|
||||
return True
|
||||
|
||||
def check_overflow(params):
|
||||
grad_data = [p.grad.data for p in params if p.grad is not None]
|
||||
return check_overflow_for_grads(grad_data)
|
||||
|
||||
def check_overflow_for_grads(grad_data):
|
||||
found_inf = torch.cuda.FloatTensor([0.0])
|
||||
scaler = torch.cuda.FloatTensor([1.0])
|
||||
# Unscale and set found inf/nan
|
||||
torch._amp_foreach_non_finite_check_and_unscale_(grad_data, found_inf, scaler)
|
||||
|
||||
# Check for nan.
|
||||
overflow = (found_inf.item() > 0)
|
||||
return overflow
|
||||
|
||||
def clip_grad_norm_fp32(parameters, max_norm, norm_type,
|
||||
get_horizontal_model_parallel_rank=None, get_horizontal_model_parallel_group=None):
|
||||
import amp_C
|
||||
|
||||
horizontal_model_parallel_grad_norm_aggregation = False
|
||||
if get_horizontal_model_parallel_rank and get_horizontal_model_parallel_group:
|
||||
horizontal_model_parallel_grad_norm_aggregation = True
|
||||
|
||||
def param_is_not_tensor_parallel_duplicate(param):
|
||||
is_mp_tensor = hasattr(param, 'model_parallel') and param.model_parallel
|
||||
return is_mp_tensor or (get_horizontal_model_parallel_rank() == 0)
|
||||
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
|
||||
# Filter parameters based on:
|
||||
# - grad should not be none
|
||||
# - should not be a replica due to tensor model parallelism
|
||||
grads_for_norm = []
|
||||
for param in parameters:
|
||||
grad_not_none = param.grad is not None
|
||||
grad = param.grad.detach()
|
||||
if grad_not_none:
|
||||
# Make sure the grads are in fp32
|
||||
assert param.grad.type() == 'torch.cuda.FloatTensor'
|
||||
if horizontal_model_parallel_grad_norm_aggregation:
|
||||
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
|
||||
if grad_not_none and is_not_tp_duplicate:
|
||||
grads_for_norm.append(grad)
|
||||
else:
|
||||
grads_for_norm.append(grad)
|
||||
|
||||
# Norm parameters.
|
||||
max_norm = float(max_norm)
|
||||
norm_type = float(norm_type)
|
||||
total_norm = 0.0
|
||||
|
||||
# Calculate norm.
|
||||
if norm_type == inf:
|
||||
total_norm = max(grad.abs().max() for grad in grads_for_norm)
|
||||
if horizontal_model_parallel_grad_norm_aggregation:
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
|
||||
# Take max across all model-parallel GPUs.
|
||||
torch.distributed.all_reduce(total_norm_cuda,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=get_horizontal_model_parallel_group())
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
|
||||
else:
|
||||
if norm_type == 2.0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
# Use apex's multi-tensor applier for efficiency reasons.
|
||||
# Multi-tensor applier takes a function and a list of list
|
||||
# and performs the operation on that list all in one kernel.
|
||||
grad_norm, _ = multi_tensor_applier(
|
||||
amp_C.multi_tensor_l2norm,
|
||||
dummy_overflow_buf,
|
||||
[grads_for_norm],
|
||||
False # no per-parameter norm
|
||||
)
|
||||
|
||||
if not horizontal_model_parallel_grad_norm_aggregation:
|
||||
return grad_norm.item()
|
||||
|
||||
# Since we will be summing across data parallel groups,
|
||||
# we need the pow(norm-type).
|
||||
total_norm = grad_norm ** norm_type
|
||||
|
||||
else:
|
||||
for grad in grads_for_norm:
|
||||
grad_norm = torch.norm(grad, norm_type)
|
||||
total_norm += grad_norm ** norm_type
|
||||
|
||||
if horizontal_model_parallel_grad_norm_aggregation:
|
||||
# Sum across all model-parallel GPUs.
|
||||
torch.distributed.all_reduce(total_norm,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=get_horizontal_model_parallel_group())
|
||||
total_norm = total_norm.item() ** (1.0 / norm_type)
|
||||
|
||||
return total_norm
|
||||
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from ._ds_modifier import DeepSpeedZeROModifier
|
||||
from ._megatron_modifier import LegacyMegatronLMModifier
|
||||
|
||||
LEAGCY_MEGATRON_LM_OPTIMIZER_NAME = "megatron.fp16.fp16.FP16_Optimizer"
|
||||
DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME = "deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer"
|
||||
|
||||
OptimizerModifierTypeRegistry = {
|
||||
LEAGCY_MEGATRON_LM_OPTIMIZER_NAME: LegacyMegatronLMModifier,
|
||||
DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME : DeepSpeedZeROModifier,
|
||||
}
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from ._modifier_registry import OptimizerModifierTypeRegistry
|
||||
|
||||
def FP16_Optimizer(optimizer, **kwargs):
|
||||
"""
|
||||
Simple wrapper to replace inefficient FP16_Optimizer function calls implemented by libraries for example
|
||||
Apex, DeepSpeed, Megatron-LM.
|
||||
|
||||
Usage:
|
||||
1. DeepSpeed ZeRO Optimizer Override:
|
||||
|
||||
>>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
|
||||
>>> optimizer = Adam(param_groups,
|
||||
>>> lr=args.lr,
|
||||
>>> weight_decay=args.weight_decay,
|
||||
>>> betas=(args.adam_beta1, args.adam_beta2),
|
||||
>>> eps=args.adam_eps)
|
||||
|
||||
>>> model, optimizer, _, lr_scheduler = deepspeed.initialize(
|
||||
>>> model=model,
|
||||
>>> optimizer=optimizer,
|
||||
>>> args=args,
|
||||
>>> lr_scheduler=lr_scheduler,
|
||||
>>> mpu=mpu,
|
||||
>>> dist_init_required=False)
|
||||
>>> if args.fp16:
|
||||
>>> optimizer = FP16_Optimizer(optimizer)
|
||||
|
||||
2. Megatron-LM-v1.1.5 Optimizer Override:
|
||||
|
||||
>>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
|
||||
>>> optimizer = Adam(param_groups,
|
||||
>>> lr=args.lr,
|
||||
>>> weight_decay=args.weight_decay,
|
||||
>>> betas=(args.adam_beta1, args.adam_beta2),
|
||||
>>> eps=args.adam_eps)
|
||||
|
||||
>>> # Wrap into fp16 optimizer.
|
||||
>>> if args.fp16:
|
||||
>>> optimizer = FP16_Optimizer(optimizer,
|
||||
>>> static_loss_scale=args.loss_scale,
|
||||
>>> dynamic_loss_scale=args.dynamic_loss_scale,
|
||||
>>> dynamic_loss_args={
|
||||
>>> 'scale_window': args.loss_scale_window,
|
||||
>>> 'min_scale': args.min_scale,
|
||||
>>> 'delayed_shift': args.hysteresis},
|
||||
>>> verbose=True)
|
||||
>>> optimizer = ORT_FP16_Optimizer(optimizer,
|
||||
>>> get_tensor_model_parallel_rank=mpu.get_model_parallel_rank,
|
||||
>>> get_tensor_model_parallel_group=mpu.get_model_parallel_group)
|
||||
|
||||
Args:
|
||||
optimizer: the FP16_Optimizer instance
|
||||
|
||||
Returns:
|
||||
The modified FP16_Optimizer instance
|
||||
|
||||
"""
|
||||
def get_full_qualified_type_name(o):
|
||||
klass = o.__class__
|
||||
module = klass.__module__
|
||||
if module == 'builtins':
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
optimizer_full_qualified_name = get_full_qualified_type_name(optimizer)
|
||||
if optimizer_full_qualified_name not in OptimizerModifierTypeRegistry:
|
||||
return optimizer
|
||||
|
||||
modifier = OptimizerModifierTypeRegistry[optimizer_full_qualified_name](optimizer, **kwargs)
|
||||
modifier.apply()
|
||||
|
||||
return optimizer
|
||||
Loading…
Reference in a new issue