diff --git a/onnxruntime/test/testdata/kernel_def_hashes/training_ops.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/training_ops.cpu.json index 1cc69e3975..44ea17b717 100644 --- a/onnxruntime/test/testdata/kernel_def_hashes/training_ops.cpu.json +++ b/onnxruntime/test/testdata/kernel_def_hashes/training_ops.cpu.json @@ -274,5 +274,9 @@ [ "InPlaceAccumulatorV2 com.microsoft CPUExecutionProvider", 12968279839987729832 + ], + [ + "InplaceClipGradNorm com.microsoft CPUExecutionProvider", + 10251631611024722504 ] ] diff --git a/onnxruntime/test/testdata/training_api/adamw.onnx b/onnxruntime/test/testdata/training_api/adamw.onnx index 6eb4e83a76..eed8ef5e48 100644 Binary files a/onnxruntime/test/testdata/training_api/adamw.onnx and b/onnxruntime/test/testdata/training_api/adamw.onnx differ diff --git a/orttraining/orttraining/python/training/onnxblock/model.py b/orttraining/orttraining/python/training/onnxblock/model.py index b2cd3906c2..41bd15b118 100644 --- a/orttraining/orttraining/python/training/onnxblock/model.py +++ b/orttraining/orttraining/python/training/onnxblock/model.py @@ -17,7 +17,6 @@ class Model(building_blocks.Block): def __init__(self): super().__init__() - ... @abstractmethod def build(self, *args, **kwargs): diff --git a/orttraining/orttraining/python/training/onnxblock/optim/optim.py b/orttraining/orttraining/python/training/onnxblock/optim/optim.py index 7650657365..5bf10da1f6 100644 --- a/orttraining/orttraining/python/training/onnxblock/optim/optim.py +++ b/orttraining/orttraining/python/training/onnxblock/optim/optim.py @@ -99,33 +99,36 @@ class ClipGradNorm(building_blocks.Block): self._max_norm = max_norm - self._reduce = building_blocks.ReduceAllL2() - self._add = building_blocks.Add() - self._div = building_blocks.Div() - self._mul = building_blocks.Mul() - self._clip = building_blocks.Clip(clip_max=1.0) - - def build(self, *gradient_names): + def build(self, gradients_name: str): """Adds a clip grad norm sub graph to the onnx model.""" # get the model to manipulate onnx_model = accessor.global_accessor.model - # add the necessary graph initializers - add_node_eps_name = graph_utils.generate_random_graph_name("add_eps") - onnx_model.graph.initializer.append( - onnx.helper.make_tensor(add_node_eps_name, onnx.TensorProto.FLOAT, [1], [1e-6]) + node_attributes = { + "max_norm": self._max_norm, + } + + # create the graph node for InplaceClipGradNorm + cgn_node_input_names = [gradients_name] + cgn_node_output_name = graph_utils.generate_random_graph_name("clip_grad_norm_output") + cgn_node_output_names = [cgn_node_output_name] + cgn_node = onnx.helper.make_node( + "InplaceClipGradNorm", + cgn_node_input_names, + cgn_node_output_names, + name=graph_utils.generate_random_graph_name("InplaceClipGradNorm"), + domain="com.microsoft", + **node_attributes, ) - max_norm_name = graph_utils.generate_random_graph_name("max_norm") - onnx_model.graph.initializer.append( - onnx.helper.make_tensor(max_norm_name, onnx.TensorProto.FLOAT, [1], [self._max_norm]) + onnx_model.graph.node.append(cgn_node) + + # Add the output to the value info of the model. + onnx_model.graph.value_info.append( + onnx.helper.make_tensor_sequence_value_info(cgn_node_output_name, onnx.TensorProto.FLOAT, None) ) - # perform gradient clipping - total_norm_name = self._reduce(*gradient_names) - adjusted_total_norm_name = self._add(total_norm_name, add_node_eps_name) - clip_coef_name = self._clip(self._div(max_norm_name, adjusted_total_norm_name)) - return [self._mul(grad_name, clip_coef_name) for grad_name in gradient_names] + return cgn_node_output_name class AdamW(model.Model): @@ -181,7 +184,7 @@ class AdamW(model.Model): params_name = "params" first_order_moments_name = "first_order_moments" second_order_moments_name = "second_order_moments" - gradient_suffix = "_grad" + gradients_name = "gradients" trainable_parameters, _ = parameters @@ -194,30 +197,21 @@ class AdamW(model.Model): ) # Prepare the tensor sequence inputs for params and moments - for input_name in [params_name, first_order_moments_name, second_order_moments_name]: + for input_name in [params_name, gradients_name, first_order_moments_name, second_order_moments_name]: onnx_model.graph.input.append( onnx.helper.make_tensor_sequence_value_info(input_name, trainable_parameters[0].data_type, None) ) - # TODO: Make the grads as a tensor sequence input after implementing clip grad - # normalization implementation which takes in a tensor sequence. - grad_names = [] - for param in trainable_parameters: - grad_names.append(f"{param.name}{gradient_suffix}") - onnx_model.graph.input.append( - onnx.helper.make_tensor_value_info(grad_names[-1], param.data_type, param.dims) - ) - # Clip the gradients if needed if self._clip_grad is not None: - grad_names = self._clip_grad(*grad_names) + gradients_name = self._clip_grad(gradients_name) # Run multi tensor AdamWOptimizer updated_flag_name = self._adamw( learning_rate_name, step_name, params_name, - self._sc(*grad_names), + gradients_name, first_order_moments_name, second_order_moments_name, ) diff --git a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py index 15ae7824dd..2dafe7e8f1 100644 --- a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py @@ -478,12 +478,13 @@ def test_adamw_optimizer_execution(): "learning_rate": np.full(1, learning_rate, dtype=np.float32), "step": np.full(1, step, dtype=np.int64), "params": [], + "gradients": [], "first_order_moments": [], "second_order_moments": [], } - for name, param in pt_model.named_parameters(): + for _, param in pt_model.named_parameters(): ort_inputs["params"].append(_to_numpy(copy.deepcopy(param))) - ort_inputs[f"{name}_grad"] = _to_numpy(copy.deepcopy(param.grad)) + ort_inputs["gradients"].append(_to_numpy(copy.deepcopy(param.grad))) ort_inputs["first_order_moments"].append(_to_numpy(torch.zeros_like(param))) ort_inputs["second_order_moments"].append(_to_numpy(torch.zeros_like(param))) @@ -696,30 +697,30 @@ def test_grad_clipping_execution(): # Prepare the onnx model with only grad clipping onnx_model = onnx.ModelProto() - onnx_model.graph.name = "AdamW Optimizer Model" + onnx_model.graph.name = "ClipGradNorm Model" onnx_model.producer_name = "grad clipping test" onnx_model.opset_import.extend(onnxblock.optim.optim._OPSET_IMPORTS) onnx_model.ir_version = onnx.IR_VERSION class GradClippingModel(onnxblock.Model): def __init__(self, max_norm): + super().__init__() self._grad_clip = onnxblock.optim.ClipGradNorm(max_norm) - def build(self, *grad_names): - return self._grad_clip(*grad_names) + def build(self, grads_name): + return self._grad_clip(grads_name) - grad_names = [] - for name, param in pt_model.named_parameters(): - grad_names.append(f"{name}_grad") - - onnx_model.graph.input.append( - onnx.helper.make_tensor_value_info(grad_names[-1], onnx.TensorProto.FLOAT, param.shape) - ) + onnx_model.graph.input.append( + onnx.helper.make_tensor_sequence_value_info("gradients", onnx.TensorProto.FLOAT, None) + ) grad_clip = GradClippingModel(2.5) - with onnxblock.onnx_model(onnx_model): - ort_output_names = grad_clip(*grad_names) + ort_output_names = grad_clip("gradients") + + onnx_model.graph.output.append( + onnx.helper.make_tensor_sequence_value_info(ort_output_names, onnx.TensorProto.FLOAT, None) + ) def mse_loss(prediction, target): loss = torch.nn.MSELoss() @@ -732,16 +733,16 @@ def test_grad_clipping_execution(): loss = mse_loss(pt_model(x), target) loss.backward() - ort_inputs = {} - for name, param in pt_model.named_parameters(): - ort_inputs[f"{name}_grad"] = _to_numpy(copy.deepcopy(param.grad)) + ort_inputs = {"gradients": []} + for _, param in pt_model.named_parameters(): + ort_inputs["gradients"].append(_to_numpy(copy.deepcopy(param.grad))) torch.nn.utils.clip_grad_norm_(pt_model.parameters(), 2.5) # Then no error occurs when executing the model ort_session = onnxruntime.InferenceSession(onnx_fo.name, providers=C.get_available_providers()) - ort_outs = ort_session.run(ort_output_names, ort_inputs) + ort_outs = ort_session.run([ort_output_names], ort_inputs) # assert all the gradients are close - for ort_grad, pt_param in zip(ort_outs, pt_model.parameters()): + for ort_grad, pt_param in zip(ort_outs[0], pt_model.parameters()): assert np.allclose(ort_grad, _to_numpy(pt_param.grad)) diff --git a/orttraining/orttraining/test/training_ops/cuda/optimizer/clip_grad_norm_test.cc b/orttraining/orttraining/test/training_ops/cuda/optimizer/clip_grad_norm_test.cc index 65b0b8d0bf..a2b974de54 100644 --- a/orttraining/orttraining/test/training_ops/cuda/optimizer/clip_grad_norm_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/optimizer/clip_grad_norm_test.cc @@ -8,9 +8,9 @@ namespace onnxruntime { namespace test { -#ifdef USE_CUDA +namespace { -TEST(OptimizerTest, InplaceClipGradNorm) { +void InplaceClipGradNormTest(std::vector>* providers) { OpTester test("InplaceClipGradNorm", 1, onnxruntime::kMSDomain); SeqTensors gradients_input; @@ -28,12 +28,10 @@ TEST(OptimizerTest, InplaceClipGradNorm) { clipped_gradients.AddTensor({5}, {3.7654f, 4.2361f, 4.7068f, 5.1775f, 5.6481f}); test.AddSeqOutput("clipped_gradients", clipped_gradients); - std::vector> providers; - providers.emplace_back(DefaultCudaExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, providers); } -TEST(OptimizerTest, InplaceClipGradNormNoClipping) { +void InplaceClipGradNormNoClippingTest(std::vector>* providers) { OpTester test("InplaceClipGradNorm", 1, onnxruntime::kMSDomain); SeqTensors gradients_input; @@ -51,9 +49,35 @@ TEST(OptimizerTest, InplaceClipGradNormNoClipping) { clipped_gradients.AddTensor({5}, {8.f, 9.f, 10.f, 11.f, 12.f}); test.AddSeqOutput("clipped_gradients", clipped_gradients); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, providers); +} + +} // namespace + +TEST(OptimizerTest, InplaceClipGradNorm_CPU) { + std::vector> providers; + providers.emplace_back(DefaultCpuExecutionProvider()); + InplaceClipGradNormTest(&providers); +} + +TEST(OptimizerTest, InplaceClipGradNormNoClipping_CPU) { + std::vector> providers; + providers.emplace_back(DefaultCpuExecutionProvider()); + InplaceClipGradNormNoClippingTest(&providers); +} + +#ifdef USE_CUDA + +TEST(OptimizerTest, InplaceClipGradNorm_CUDA) { std::vector> providers; providers.emplace_back(DefaultCudaExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); + InplaceClipGradNormTest(&providers); +} + +TEST(OptimizerTest, InplaceClipGradNormNoClipping_CUDA) { + std::vector> providers; + providers.emplace_back(DefaultCudaExecutionProvider()); + InplaceClipGradNormNoClippingTest(&providers); } #endif diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc index ddefa6333f..80916f9c3c 100644 --- a/orttraining/orttraining/training_api/optimizer.cc +++ b/orttraining/orttraining/training_api/optimizer.cc @@ -22,11 +22,50 @@ constexpr char GROUP_ZERO_NAME[] = "group0"; // TODO: Conolidate with frontend tooling const std::vector MOMENT_STATE_NAMES{"momentum0", "momentum1"}; -constexpr char LearningRateName[] = "learning_rate"; -constexpr char StepName[] = "step"; -constexpr char ParamsName[] = "params"; -constexpr char FirstOrderMomentsName[] = "first_order_moments"; -constexpr char SecondOrderMomentsName[] = "second_order_moments"; +constexpr std::array AdamWOptimizerInputs = { + "learning_rate", + "step", + "params", + "gradients", + "first_order_moments", + "second_order_moments"}; + +Status GraphInputsAreExpected(gsl::span actual_graph_inputs, + gsl::span expected_graph_inputs) { + const auto stringify = [](const auto& container) { + if (container.empty()) { + return std::string("[]"); + } + std::string container_str("["); + for (const auto& val : container) { + container_str += std::string(val) + ", "; + } + container_str.pop_back(); + container_str.back() = ']'; + + return container_str; + }; + + const auto construct_unexpected_input_status = [&stringify](const auto& actual_inputs, const auto& expected_inputs) { + std::ostringstream error_stream; + error_stream << "Invalid graph inputs." + << "\n\tExpected: " << stringify(expected_inputs) + << "\n\tActual: " << stringify(actual_inputs); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, error_stream.str()); + }; + + if (actual_graph_inputs.size() != expected_graph_inputs.size()) { + return construct_unexpected_input_status(actual_graph_inputs, expected_graph_inputs); + } + + for (size_t input_idx = 0; input_idx < expected_graph_inputs.size(); ++input_idx) { + if (actual_graph_inputs[input_idx] != expected_graph_inputs[input_idx]) { + return construct_unexpected_input_status(actual_graph_inputs, expected_graph_inputs); + } + } + + return Status::OK(); +} } // namespace @@ -48,38 +87,30 @@ Status Optimizer::GenerateMomentumNamedStates() { return Status::OK(); } -// Constructs the ortvalue inputs to be fed to the graph -// at each step +// Constructs the ortvalue inputs to be fed to the graph at each step Status Optimizer::ConstructInputs() { if (optimizer_type_ == OptimizerType::AdamW) { auto& param_named_optimizer_states = optimizer_state_.param_named_optimizer_states; - std::vector params, first_order_moments, second_order_moments; - // TODO: Change to tensor seq implementation once clip grad norm op - // that accepts tensor seq as input for gradients is complete. - std::vector grads; - - // Input names 0-4 are reserved for lr, step, params, first order moments, second order moments - // input names 5 onwards are all the gradient names. - // Collect all the inputs based on the gradient names order. - for (size_t i = 5; i < input_names_.size(); i++) { - std::string param_name; - if (utils::GetParamNameFromGradient(input_names_[i], param_name)) { - const auto named_parameter_it = named_parameters_.find(param_name); - ORT_ENFORCE(named_parameter_it != named_parameters_.end(), - "Unknown param: ", param_name, " for field: ", input_names_[i]); - - // Collect the gradients as ortvalues - grads.push_back(named_parameter_it->second->Gradient()); + std::vector params, grads, first_order_moments, second_order_moments; + // Collect all the non user defined inputs from the named_parameters_. + for (auto& [parameter_name, parameter] : named_parameters_) { + if (parameter->RequiresGrad()) { // Collect parameters and prepare for tensorseq creation - auto* param_tensor = named_parameter_it->second->Data().GetMutable(); + auto* param_tensor = parameter->Data().GetMutable(); params.emplace_back( Tensor(param_tensor->DataType(), param_tensor->Shape(), param_tensor->MutableDataRaw(), param_tensor->Location())); + // Collect gradients and prepare for tensorseq creation + auto* grad_tensor = parameter->Gradient().GetMutable(); + grads.emplace_back( + Tensor(grad_tensor->DataType(), grad_tensor->Shape(), + grad_tensor->MutableDataRaw(), grad_tensor->Location())); + // Collect first order moments and prepare for tensorseq creation - auto* first_order_moment_tensor = param_named_optimizer_states.at(param_name) + auto* first_order_moment_tensor = param_named_optimizer_states.at(parameter_name) .momentum_named_states.at(MOMENT_STATE_NAMES[0]) .GetMutable(); first_order_moments.emplace_back( @@ -87,20 +118,17 @@ Status Optimizer::ConstructInputs() { first_order_moment_tensor->MutableDataRaw(), first_order_moment_tensor->Location())); // Collect second order moments and prepare for tensorseq creation - auto* second_order_moment_tensor = param_named_optimizer_states.at(param_name) + auto* second_order_moment_tensor = param_named_optimizer_states.at(parameter_name) .momentum_named_states.at(MOMENT_STATE_NAMES[1]) .GetMutable(); second_order_moments.emplace_back( Tensor(second_order_moment_tensor->DataType(), second_order_moment_tensor->Shape(), second_order_moment_tensor->MutableDataRaw(), second_order_moment_tensor->Location())); - } else { - ORT_ENFORCE( - false, "This is an invalid graph. Optimizer graph contains unknown user input:", input_names_[i]); } } const auto tensorseq_inserter = [](auto& tensors, auto* inputs) { - ORT_ENFORCE(!tensors.empty(), "Tensors cannot be empty while building a tensor sequence."); + ORT_ENFORCE(!tensors.empty(), "Tensors vector cannot be empty while building a tensor sequence."); auto tensor_seq = std::make_unique(tensors.front().DataType()); tensor_seq->SetElements(std::move(tensors)); @@ -111,13 +139,9 @@ Status Optimizer::ConstructInputs() { // Add the params and moments as tensorseq ortvalues to inputs tensorseq_inserter(params, &inputs_); + tensorseq_inserter(grads, &inputs_); tensorseq_inserter(first_order_moments, &inputs_); tensorseq_inserter(second_order_moments, &inputs_); - - // Add the gradients as ortvalues to inputs - inputs_.insert(inputs_.end(), - std::make_move_iterator(grads.begin()), - std::make_move_iterator(grads.end())); } // Add other optimizer reordering logic here return Status::OK(); @@ -138,13 +162,9 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes, ORT_THROW_IF_ERROR(optim_sess_->Initialize()); utils::GetGraphInputOutputNames(optim_sess_, input_names_, output_names_); - ORT_ENFORCE(input_names_[0] == LearningRateName); // TODO: make this better - ORT_ENFORCE(input_names_[1] == StepName); // TODO: make this better - ORT_ENFORCE(input_names_[2] == ParamsName); // TODO: make this better if (optimizer_type_ == OptimizerType::AdamW) { - ORT_ENFORCE(input_names_[3] == FirstOrderMomentsName); // TODO: make this better - ORT_ENFORCE(input_names_[4] == SecondOrderMomentsName); // TODO: make this better + ORT_THROW_IF_ERROR(GraphInputsAreExpected(input_names_, AdamWOptimizerInputs)); ORT_THROW_IF_ERROR(GenerateMomentumNamedStates()); } else { diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index 082be5eef1..c30851425e 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -86,6 +86,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_float, ReduceAllL2); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InplaceClipGradNorm); + // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training #ifdef ENABLE_TRAINING @@ -198,6 +200,8 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, + BuildKernelCreateInfo, + // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training #ifdef ENABLE_TRAINING diff --git a/orttraining/orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.cc b/orttraining/orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.cc new file mode 100644 index 0000000000..8d441e18d4 --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.h" +#include "core/providers/cpu/math/element_wise_ops.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/cpu/reduction/reduction_ops.h" + +namespace onnxruntime { +namespace contrib { + +namespace { + +constexpr float Epsilon = 0.000001f; + +template +T GetL2Norm(const TensorSeq& gradients) { + T l2_norm = 0; + for (const auto& tensor : gradients) { + l2_norm += + ReduceAggregatorSumSquare(tensor.Shape().Size(), *tensor.Data()).aggall(tensor.Data()); + } + return reduce_sqrt(l2_norm); +} + +template +void ClipGradNorm(T total_norm, T max_norm, TensorSeq& gradients) { + const T clip_coefficient = std::min(max_norm / (total_norm + static_cast(Epsilon)), static_cast(1.0f)); + + for (const auto& grad : gradients) { + auto& tensor = const_cast(grad); + MakeEigenArrayMap(tensor) *= clip_coefficient; + } +} + +Status PopulateOutput(OpKernelContext* ctx, const TensorSeq* gradients, TensorSeq* clipped_gradients) { + if (gradients == clipped_gradients) { + return Status::OK(); + } + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc)); + + clipped_gradients->SetType(gradients->DataType()); + clipped_gradients->Reserve(gradients->Size()); + for (const auto& grad : *gradients) { + Tensor target_tensor(grad.DataType(), grad.Shape(), alloc); + CopyCpuTensor(&grad, &target_tensor); + clipped_gradients->Add(std::move(target_tensor)); // Add will check for type consistency + } + + return Status::OK(); +} + +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + InplaceClipGradNorm, + kMSDomain, + 1, + kCpuExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) /* Return updated gradients in-place */ + .TypeConstraint("S_GRAD", DataTypeImpl::AllFixedSizeSequenceTensorTypes()), + InplaceClipGradNorm); + +template +Status InplaceClipGradNorm::Compute(OpKernelContext* ctx) const { + const TensorSeq* gradients = ctx->Input(0); + + const T total_norm = GetL2Norm(*gradients); + + auto grads = const_cast(gradients); + ClipGradNorm(total_norm, max_norm_, *grads); + + // Populate the output sequence tensors. + TensorSeq* clipped_gradients = ctx->Output(0); + ORT_RETURN_IF_ERROR(PopulateOutput(ctx, gradients, clipped_gradients)); + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.h b/orttraining/orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.h new file mode 100644 index 0000000000..d5da3445ed --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/optimizer/clip_grad_norm/clip_grad_norm.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +template +class InplaceClipGradNorm final : public OpKernel { + public: + InplaceClipGradNorm(const OpKernelInfo& info) + : OpKernel(info) { + info.GetAttrOrDefault("max_norm", &max_norm_, 1.0f); + info.GetAttrOrDefault("norm_type", &norm_type_, std::string("fro")); + ORT_ENFORCE(norm_type_ == "fro", "Given norm type ", norm_type_, " is not supported for InplaceClipGradNorm."); + } + + Status Compute(OpKernelContext* context) const override; + + private: + float max_norm_; + std::string norm_type_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/clip_grad_norm/clip_grad_norm.cc b/orttraining/orttraining/training_ops/cuda/optimizer/clip_grad_norm/clip_grad_norm.cc index 064c1b2486..9841f2b7f6 100644 --- a/orttraining/orttraining/training_ops/cuda/optimizer/clip_grad_norm/clip_grad_norm.cc +++ b/orttraining/orttraining/training_ops/cuda/optimizer/clip_grad_norm/clip_grad_norm.cc @@ -40,7 +40,7 @@ Status PopulateOutput(cudaStream_t stream, AllocatorPtr alloc, const TensorSeq* TensorSeq** clipped_gradients) { // If the output buffer is the same as the input buffer, the planner has // decided to reuse the buffer. No need to perform a memcpy in that case. - if (const_cast(gradients) == *clipped_gradients) { + if (gradients == *clipped_gradients) { return Status::OK(); } @@ -84,7 +84,7 @@ Status InplaceClipGradNorm::ComputeInternal(OpKernelContext* ctx) const { GetGroupedTensors(gradients, &tensor_sizes, &grouped_tensor_pointers); AllocatorPtr alloc; - ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK(), "InplaceClipGradNorm: Unable to get an allocator."); + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc)); // Get frobenius norm for the grouped inputs float* total_norm = reinterpret_cast(alloc->Alloc(sizeof(float)));