From c2fd5ccbe94bc3e62eca770b17b0b631883419a3 Mon Sep 17 00:00:00 2001 From: ashbhandare Date: Fri, 24 Jun 2022 02:11:06 -0700 Subject: [PATCH] Redesign InPlaceAccumulator op (#11842) * op changes * review comments * shape consolidation, test trigger, cleanup * review comments --- .../core/framework/allocation_planner.cc | 16 --- .../training_api/optimizer_graph.onnx | Bin 3103 -> 0 bytes ...radient_graph.onnx => training_model.onnx} | Bin 3107 -> 3692 bytes .../core/graph/training_op_defs.cc | 34 +++++++ .../python/training/onnxblock/_graph_utils.py | 17 ++-- .../test/gradient/gradient_ops_test.cc | 94 ++++++++++++++++-- .../training_api/core/training_api_tests.cc | 9 +- .../test/training_api/trainer/trainer.cc | 4 +- .../orttraining/training_api/module.cc | 18 ++-- .../training_ops/cpu/cpu_training_kernels.cc | 2 + .../cpu/optimizer/gradient_control.cc | 85 ++++++++++++---- .../cpu/optimizer/gradient_control.h | 8 ++ .../cuda/cuda_training_kernels.cc | 5 + .../cuda/optimizer/gradient_control.cc | 74 +++++++++++--- .../cuda/optimizer/gradient_control.h | 7 ++ 15 files changed, 300 insertions(+), 73 deletions(-) delete mode 100644 onnxruntime/test/testdata/training_api/optimizer_graph.onnx rename onnxruntime/test/testdata/training_api/{gradient_graph.onnx => training_model.onnx} (54%) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 72e22efc79..44e63fa6cf 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -1005,22 +1005,6 @@ class PlannerImpl { } } } -#if defined(ENABLE_TRAINING) && defined(ENABLE_TRAINING_ON_DEVICE) - // This is required because for the on-device training case, - // the InPlaceAccumulator produces graph - // outputs which will be re-allocated instead of re-using - // the input accumulation buffer - if (pnode->OpType() == "InPlaceAccumulator") { - const NodeArg* input = pnode->InputDefs()[0]; - const auto& input_name = input->Name(); - const auto input_index = Index(input_name); - - const auto& alloc_plan = AllocPlan(input_index); - if (alloc_plan.alloc_kind == AllocKind::kPreExisting) { - Reuse(input_index, current, AllocKind::kShare); - } - } -#endif } else if (!context_.IsParallelExecutionEnabled() && FindReusableInput(*pnode, static_cast(output_arg_def_index), &reused)) { // Re-using inputs is applicable for tensors, sequence tensors, diff --git a/onnxruntime/test/testdata/training_api/optimizer_graph.onnx b/onnxruntime/test/testdata/training_api/optimizer_graph.onnx deleted file mode 100644 index 3e59aa32ef38b271790ba37bad2354a91e77991e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3103 zcmds3NpI6Y6vkcM2T2ptbQN6!t!UIn>VlAZL2%*7Jr`@dC6>J5IJ9u##tn&oK>rB* zDy~S(cs%17kKLY2PWJb{d44b7TSnInJ&vLqFNm+N{%*kcd1wVLN+UlSn<>g%SeRyR z0=j3P4J{u{VV6Cb;}khVWZScF7NE?JqhUPDpu_A#_a-sXr!hROBIfiXtaETWiD!Xh z&JqWI&ghmn{x0!otZ}IDn4=q26no7LPq+ae-fq8jQ1~Iq{Lug6rf2<@^gdXHMHD0x z)bl*=&8?;l+OS}`8R}jC{+hpUf#Jj^$=tN#G<9u~=%#k7%?-M+>?Tt`h@;-G*WX@| z(;Zj}P-r=*_Xz(8vIeaX-I!6FhNd0(2|cg;W5>-cS1mgZhoNt$ z@ig`__}QQbc=q{_x`6%ixj_@7WN=k_Dn@kK6ku5yEF=5mkkmx}XJZ*TD5|ri;!K~m zU{V!2|Bnu!6aZFt7iaAD!N}p=A1=y0BzIzjlAg zoEMGjjHIzrg!~MIsco!63kSV|69gLg#r=a^J#+{)vt6aVg!*x>GLP=q0Aez`BpEg( zj1IJMGBgpK&5xzgCCP@_>)Wd4)DcOConoz3!#gJVRC%~rYN&ZaoGLe_n9w>YDB5H~ z@xkbX)Dy)U(kM$6M!Te(J<4{B~o#7)s4)^n)8H;K(+ z9YNzYO5*K+bgPo8-zsvVw~_K{AO>?xXJbJT*HE<$tP{B7?x3&_pmLp1Dw>8@7699e zVt51Ezct}0g>9XZ__(_w6*i|;dY=Ty>3DjpB}%nqir2;0CtAu?lkW?mY^*~frOaVb ViL`B^WI-d*2)1`!kZ+$I?H@kgcdY;b diff --git a/onnxruntime/test/testdata/training_api/gradient_graph.onnx b/onnxruntime/test/testdata/training_api/training_model.onnx similarity index 54% rename from onnxruntime/test/testdata/training_api/gradient_graph.onnx rename to onnxruntime/test/testdata/training_api/training_model.onnx index fb84a16fd7b42874ef4e2a870f9d1eb93e2638d0..34b50bfab3c27c6c2c8dd44ac5685f019977f8bd 100644 GIT binary patch literal 3692 zcmb_f&2HO95EgA&rl!`9w{oOPjo~OkfvPnlDo&c{(83CeAV3d=0~CcW#7bI8RQLyy zitP5%Cn$;@dMteCLlo$_Z`23qE`KC0Dbfm%i?rMse*4YtH?u?W5FJ`Ku4#|w&4Mh+ zvefPW-U0^>Rl{1k=Uss5Se8~qoy}bk*?6R42OW6kX5kglb*0|K+#yeG0hM5AIT{%#WGSdM;(N<|D~6McN2@`8#<2WX&SMx zIBb%g%l)7fY$}CME6zJkOo!y+d}To3m8jrE{2CIz1;}<`^cT) zR*Lb1PLibim;sYU^T}>B{e6rS-mN70+MJ|F_c0YBjo`&j1X0!NtIijydrqz`({?-c zs?g4}?r6Z@=bgSNR1ftTs`0~#6%v)-+iCU^eQ)$*R3A<3?7LBI)gGg3c5dq zUWVP+vpjsTZaC_2A{NjwRM6&LXRi!TqkSW)LYV>s?2G^H)1Ti8!+NBb{0NHFdf-`y z_oE#hsz8aFim)g8FAwjr_kt!5K%~YzGY+5al=(@#?hdJlID%iXl`=%_*O)CKKW0m% zN-cOT8q@uxSed!#g)EE^Y+BC(w;8q~X%@TStB&U`bKiTo)12FdiIkkMX$O*Wne zOqz===d|cYNx?5f@u#UO{Pz;YLMXM%64Xj{MRRAXnFbi>)NF zC_S}Ah$Sb#xLAlICqF&DpeR2{N_=uWqr1UQMlMc}M0|QtVhR^r7OsmIW)4_^loe2u z_~f6ArTi6KC{mL*G8s;u!OF|S&Bc;fky`A)uz-<$@-P_RfRnB0&)_QQyr6&OLI$e5=-)ntay|2bMSjnXo_v8@S%{yD7s%t}V&dT9V&?e5oTNL6N0E=4ixViy3K21xtjD7)D8R)B6oM!M kicesh9K)loz$nB5(&{A0g&Md5j7}_E3<5lp_w%j=0AVG?CIA2c diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 21a1a2b681..8de5226331 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -1413,6 +1413,40 @@ void RegisterTrainingOpSchemas() { propagateShapeAndTypeFromFirstInput(ctx); }); +ONNX_CONTRIB_OPERATOR_SCHEMA(InPlaceAccumulatorV2) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc("In-place accumulator for tensors. Differs from older op by adding cotrol input for reset, and optional output buffer.") + .Input(0, "accumulation_buffer", "historical result of accumulator", "T") + .Input(1, "value", "the value that will be added to the accumulator", "T_GRAD") + .Input(2, "overwrite_flag", "Indicates if tensor should be overwritten. Default is accumulation", "T_BOOL", OpSchema::Optional) + .Output(0, "updated_flag", "Whether the update was completed", "T_BOOL") + .Output(1, "accumulation_buffer_out", "updated result of accumulator", "T", OpSchema::Optional) + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint( + "T_GRAD", + {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint( + "T_BOOL", + {"tensor(bool)"}, + "Constrain types to boolean tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + ONNX_NAMESPACE::TensorShapeProto updated_shape; + updated_shape.add_dim()->set_dim_value(1); + updateOutputShape(ctx, 0, updated_shape); + if (ctx.getNumOutputs() == 2){ + propagateElemTypeFromInputToOutput(ctx, 0, 1); + if (hasNInputShapes(ctx, 1)) { + propagateShapeFromInputToOutput(ctx, 0, 1); + } + } + }); + ONNX_CONTRIB_OPERATOR_SCHEMA(ZeroGradient) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py index f5291a69c3..97304e9985 100644 --- a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py @@ -109,11 +109,11 @@ def build_gradient_graph(accessor, user_args_requiring_grad, user_args_not_requi def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_names): """Builds gradient accumulation nodes on top of a training model. - Adds an InPlaceAccumulator node for every gradient so that the gradients - are accumulated in a gradient buffer (which is an input to InPlaceAccumulator). + Adds an InPlaceAccumulatorV2 node for every gradient so that the gradients + are accumulated in a gradient buffer (which is an input to InPlaceAccumulatorV2). For example, if there is a gradient in the graph called fc1.weight_grad, - an InPlaceAccumulator will be added for that gradient whose input will + an InPlaceAccumulatorV2 will be added for that gradient whose input will be a graph input (fc1.weight_grad.accumulation.buffer) and the newly computed gradient (fc1.weight_grad). @@ -122,7 +122,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na Ʌ v v | |_________________________| | | - | InPlaceAccumulator + | InPlaceAccumulatorV2 | | | v |______________________| @@ -154,7 +154,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na # Gradient accumulation node acc_node = onnx.helper.make_node( - "InPlaceAccumulator", + "InPlaceAccumulatorV2", [grad_accumulation_buffer_name, grad_name, lazy_reset_grad_input_name], [grad_accumulation_output_name], name=f"GradientAccumulator{idx}", @@ -168,9 +168,10 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na grad_accumulation_buffer_input.name = grad_accumulation_buffer_name graph_inputs.append(grad_accumulation_buffer_input) - # accumulated gradient is also a graph output - grad_accumulation_output = copy.deepcopy(graph_output) - grad_accumulation_output.name = grad_accumulation_output_name + # accumulated gradient update flag is also a graph output + grad_accumulation_output = onnx.helper.make_tensor_value_info( + grad_accumulation_output_name, onnx.TensorProto.BOOL, [1] + ) graph_outputs.append(grad_accumulation_output) lazy_reset_grad_input = onnx.helper.make_tensor_value_info(lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1]) diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 3c803748c7..29859f499c 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -50,18 +50,44 @@ static bool IsErrorWithinTolerance(float error, float tolerance) { static void RunReductionTests(const OpDef& op_def, bool axes_as_input = false, bool check_not_have_shape_inferencing = false) { std::vector> x_shapes = { - {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, + {4, 3, 2}, + {4, 3, 2}, + {4, 3, 2}, + {4, 3, 2}, + {4, 3, 2}, + {4, 3, 2}, + {4, 3, 2}, + {4, 3, 2}, }; std::vector> y_shapes = { - {1, 1, 1}, {}, {1, 3, 1}, {2}, {4, 1, 2}, {4, 3}, {4, 1, 2}, {4}, + {1, 1, 1}, + {}, + {1, 3, 1}, + {2}, + {4, 1, 2}, + {4, 3}, + {4, 1, 2}, + {4}, }; std::vector> axes_vec = { {}, // default case - {0, 1, 2}, {0, 2}, {0, 1}, {1}, {2}, {-2}, {-2, -1}, + {0, 1, 2}, + {0, 2}, + {0, 1}, + {1}, + {2}, + {-2}, + {-2, -1}, }; std::vector keepdims_ip = { -1, // default case - 0, 1, 0, 1, 0, 1, 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, }; GradientChecker gradient_checker; @@ -2084,15 +2110,69 @@ TEST(GradientCheckerTest, SimplifiedLayerNormGrad) { TEST(GradientUtilsTest, InPlaceAccumulatorFloat32) { OpTester test("InPlaceAccumulator", 1, onnxruntime::kMSDomain); - test.AddInput("old_sum", {3}, {1, 2, 3}); - test.AddInput("value", {3}, {4, 5, 6}); + test.AddInput("old_sum", {3}, {1.f, 2.f, 3.f}); + test.AddInput("value", {3}, {4.f, 5.f, 6.f}); - test.AddOutput("new_sum", {3}, {5, 7, 9}); + test.AddOutput("new_sum", {3}, {5.f, 7.f, 9.f}); + + test.Run(); +} + +TEST(GradientUtilsTest, InPlaceAccumulatorV2_CPU) { + OpTester test("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain); + + test.AddInput("old_sum", {3}, {1.f, 2.f, 3.f}); + test.AddInput("value", {3}, {4.f, 5.f, 6.f}); + test.AddOutput("updated", {1}, {true}); + test.AddOutput("new_sum", {3}, {5.f, 7.f, 9.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(GradientUtilsTest, InPlaceAccumulatorV2Overwrite) { + OpTester test("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain); + + test.AddInput("old_sum", {3}, {1.f, 2.f, 3.f}); + test.AddInput("value", {3}, {4.f, 5.f, 6.f}); + test.AddInput("overwrite", {1}, {true}); + test.AddOutput("updated", {1}, {true}); + test.AddOutput("new_sum", {3}, {4.f, 5.f, 6.f}); test.Run(); } #if defined(USE_CUDA) || defined(USE_ROCM) +TEST(GradientUtilsTest, InPlaceAccumulatorV2_GPU) { + OpTester test("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain); + + test.AddInput("old_sum", {3}, {1.f, 2.f, 3.f}); + test.AddInput("value", {3}, {4.f, 5.f, 6.f}); + test.AddOutput("updated", {1}, {true}); + test.AddOutput("new_sum", {3}, {5.f, 7.f, 9.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider}); +} + +TEST(GradientUtilsTest, InPlaceAccumulatorV2_Float16) { + OpTester test("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain); + + std::vector old_sum = {1.0f, 2.0f, 3.0f}; + std::vector value = {4.0f, 5.0f, 6.0f}; + std::vector new_sum = {4.0f, 5.0f, 6.0f}; + + std::vector value_half(3); + ConvertFloatToMLFloat16(value.data(), value_half.data(), 3); + + test.AddInput("old_sum", {3}, old_sum); + test.AddInput("value", {3}, value_half); + test.AddInput("overwrite", {1}, {true}); + test.AddOutput("updated", {1}, {true}); + test.AddOutput("new_sum", {3}, new_sum); + + // Didn't implement mixed precision InPlaceAccumulatorV2 in CPU + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider}); +} + TEST(GradientUtilsTest, InPlaceAccumulatorFloat16) { OpTester test("InPlaceAccumulator", 1, onnxruntime::kMSDomain); diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index 3621fcc0a6..3bcbb89d93 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -48,7 +48,7 @@ void GenerateRandomInput(gsl::span dims, OrtValue& input) { } TEST(TrainingApiTest, ModuleTrainStep) { - auto model_uri = MODEL_FOLDER "gradient_graph.onnx"; + auto model_uri = MODEL_FOLDER "training_model.onnx"; CheckpointState state; auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt"; @@ -59,7 +59,7 @@ TEST(TrainingApiTest, ModuleTrainStep) { ORT_THROW_IF_ERROR(Environment::Create(nullptr, env)); auto model = std::make_unique(model_uri, state.module_checkpoint_state.named_parameters, session_option, *env, std::vector>()); - + ORT_ENFORCE(model->GetTrainModeOutputCount() == 1); OrtValue input, target; GenerateRandomInput(std::array{2, 784}, input); CreateInputOrtValue(std::array{2}, std::vector(2, 1), &target); @@ -76,6 +76,7 @@ TEST(TrainingApiTest, ModuleTrainStep) { std::vector& inputs = *it; std::vector fetches; ORT_ENFORCE(model->TrainStep(inputs, fetches).IsOK()); + ORT_ENFORCE(fetches.size() == 1); bias_grad = bias_param->Gradient(); if (step > 1) { @@ -103,7 +104,7 @@ TEST(TrainingApiTest, ModuleTrainStep) { #if defined(USE_CUDA) || defined(USE_ROCM) TEST(TrainingApiTest, OptimStep) { - auto model_uri = MODEL_FOLDER "gradient_graph.onnx"; + auto model_uri = MODEL_FOLDER "training_model.onnx"; auto optim_uri = MODEL_FOLDER "adamw.onnx"; CheckpointState state; @@ -182,7 +183,7 @@ void CompareValue(float expected, float output, float rtol = 1e-4, float atol = void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count, int64_t warmup_step_count) { /// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances. - auto model_uri = MODEL_FOLDER "gradient_graph.onnx"; + auto model_uri = MODEL_FOLDER "training_model.onnx"; auto optim_uri = MODEL_FOLDER "adamw.onnx"; CheckpointState state; diff --git a/orttraining/orttraining/test/training_api/trainer/trainer.cc b/orttraining/orttraining/test/training_api/trainer/trainer.cc index d6ebad30f8..28d5c269f7 100644 --- a/orttraining/orttraining/test/training_api/trainer/trainer.cc +++ b/orttraining/orttraining/test/training_api/trainer/trainer.cc @@ -311,7 +311,9 @@ int RunTraining(const TestRunnerParameters& params) { g_ort_api->ReleaseValue(inputs[i]); } - // TODO(askhade): release output values. Needs changes from Aishwarya's PR. + for (size_t i = 0; i < fetches.size(); i++) { + g_ort_api->ReleaseValue(fetches[i]); + } } data_loader.ResetIterateIndex(); diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index e362e1118c..03957a5dd6 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -66,7 +66,8 @@ Module::Module(const std::string& train_model_path_or_bytes, ORT_THROW_IF_ERROR(train_sess_->Initialize()); // Extract model input and output names - utils::GetGraphInputOutputNames(train_sess_, train_input_names_, train_output_names_); + std::vector train_input_names, train_output_names; + utils::GetGraphInputOutputNames(train_sess_, train_input_names, train_output_names); // Reorder the extracted input names in the following order: // user inputs, weights, gradients, reset_grad @@ -74,7 +75,7 @@ Module::Module(const std::string& train_model_path_or_bytes, std::string param_name; std::unordered_map param_name_to_grad_input_index_map; - for (const auto& input_name : train_input_names_) { + for (const auto& input_name : train_input_names) { auto it = named_parameters_.find(input_name); if (it != named_parameters_.end()) { param_input_names.emplace_back(input_name); @@ -95,6 +96,12 @@ Module::Module(const std::string& train_model_path_or_bytes, train_input_names_.insert(train_input_names_.end(), grad_input_names.begin(), grad_input_names.end()); train_input_names_.insert(train_input_names_.end(), reset_grad_name.begin(), reset_grad_name.end()); + for (const auto& output_name : train_output_names) { + if (!utils::GetParamNameFromGradient(output_name, param_name)) { + train_output_names_.emplace_back(output_name); + } + } + // Loop each parameter, allocate it's memory based on user specified device. auto& train_sess_state = train_sess_->GetSessionState(); for (auto& param_name : param_input_names) { @@ -205,11 +212,10 @@ Status Module::TrainStep(const std::vector& inputs, std::vector(accumulate_gradient_, &do_update_input); - feeds.push_back(do_update_input); + OrtValue reset_grad_input; + utils::WrapInOrtValue(!accumulate_gradient_, &reset_grad_input); + feeds.push_back(reset_grad_input); - // TODO: need to filter out the grads from the output ortvalues auto status = train_sess_->Run(RunOptions(), train_input_names_, feeds, train_output_names_, &outputs); ORT_THROW_IF_ERROR(status); diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index 2475ffb9bb..f7d7fe7455 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -10,6 +10,7 @@ namespace contrib { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SGDOptimizer); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AdamOptimizer); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulator); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulatorV2); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ZeroGradient); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Group); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, PassThrough); @@ -121,6 +122,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cpu/optimizer/gradient_control.cc b/orttraining/orttraining/training_ops/cpu/optimizer/gradient_control.cc index 9b7aafbe37..7cbaeaedf2 100644 --- a/orttraining/orttraining/training_ops/cpu/optimizer/gradient_control.cc +++ b/orttraining/orttraining/training_ops/cpu/optimizer/gradient_control.cc @@ -9,6 +9,20 @@ namespace onnxruntime { namespace contrib { +template +void getBroadcastSpanFunc(ProcessBroadcastSpanFuncs& funcs) { + ProcessBroadcastSpanFuncs add_funcs{ + [](BroadcastHelper& per_iter_bh) { + per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); + }, + [](BroadcastHelper& per_iter_bh) { + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); + }, + [](BroadcastHelper& per_iter_bh) { + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); + }}; + funcs = std::move(add_funcs); +} ONNX_OPERATOR_KERNEL_EX( InPlaceAccumulator, kMSDomain, @@ -28,33 +42,17 @@ Status InPlaceAccumulator::Compute(OpKernelContext* context) const { if (do_update_tensor) { const bool do_update = *(do_update_tensor->template Data()); if (!do_update) { -#ifdef ENABLE_TRAINING_ON_DEVICE - // This is temporary fix till we potentially redesign inplaceaccumulator op - // to fit lazy reset grad functionality - const Tensor* new_gradient = context->Input(1); - const void* updated_data = new_gradient->template Data(); - memcpy(output_data, updated_data, new_gradient->SizeInBytes()); -#else const void* input_data = gradient_buffer->template Data(); if (output_data != input_data) { memcpy(output_data, input_data, gradient_buffer->SizeInBytes()); } -#endif return Status::OK(); } } - //Copy from Add CPU kernel - ProcessBroadcastSpanFuncs funcs{ - [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); - }, - [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); - }, - [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); - }}; + // Copy from Add CPU kernel + ProcessBroadcastSpanFuncs funcs; + getBroadcastSpanFunc(funcs); UntypedBroadcastTwo(*context, funcs); @@ -81,5 +79,54 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T2", DataTypeImpl::AllTensorTypes()), ZeroGradient); +ONNX_OPERATOR_KERNEL_EX( + InPlaceAccumulatorV2, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .Alias(0, 1) // accumulate tensors in-place + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + InPlaceAccumulatorV2); + +template +Status InPlaceAccumulatorV2::Compute(OpKernelContext* context) const { + Tensor* accumulation_buffer = const_cast(context->Input(0)); + const Tensor* new_value = context->Input(1); + const Tensor* overwrite_tensor = context->Input(2); + + void* accumulation_buffer_data = accumulation_buffer->template MutableData(); + const bool overwrite = overwrite_tensor != nullptr ? *(overwrite_tensor->template Data()) : false; + + if (overwrite) { + const void* updated_data = new_value->template Data(); + memcpy(accumulation_buffer_data, updated_data, new_value->SizeInBytes()); + } else { + // Copy from Add CPU kernel + ProcessBroadcastSpanFuncs funcs; + getBroadcastSpanFunc(funcs); + + InputBroadcaster input_broadcaster(*accumulation_buffer, *new_value); + OutputBroadcaster output_broadcaster(input_broadcaster.GetSpanSize(), *accumulation_buffer); + BroadcastHelper broadcast_helper(input_broadcaster, output_broadcaster, nullptr); + + BroadcastLooper(broadcast_helper, funcs); + } + + Tensor* updated_output = context->Output(0, {1}); + bool* updated_output_ptr = updated_output->template MutableData(); + *updated_output_ptr = true; + + Tensor* accumulated_value_out = context->Output(1, new_value->Shape()); + if (nullptr != accumulated_value_out) { + void* output_data = accumulated_value_out->template MutableData(); + if (output_data != accumulation_buffer_data) { + memcpy(output_data, accumulation_buffer_data, new_value->SizeInBytes()); + } + } + + return Status::OK(); +} + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/optimizer/gradient_control.h b/orttraining/orttraining/training_ops/cpu/optimizer/gradient_control.h index dc9b5285d7..ee459bfdad 100644 --- a/orttraining/orttraining/training_ops/cpu/optimizer/gradient_control.h +++ b/orttraining/orttraining/training_ops/cpu/optimizer/gradient_control.h @@ -21,5 +21,13 @@ class InPlaceAccumulator final : public OpKernel { InPlaceAccumulator(const OpKernelInfo& info) : OpKernel(info) {} Status Compute(OpKernelContext* context) const override; }; + +template +class InPlaceAccumulatorV2 final : public OpKernel { + public: + InPlaceAccumulatorV2(const OpKernelInfo& info) : OpKernel(info) {} + Status Compute(OpKernelContext* context) const override; +}; + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 6faf6edd40..b0f553ba9d 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -50,6 +50,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ZeroGradient); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SoftmaxCrossEntropy); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SoftmaxCrossEntropyGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, InPlaceAccumulatorV2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_MLFloat16, InPlaceAccumulatorV2); // class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropy); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int64_t, SparseSoftmaxCrossEntropy); // class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropyGrad); @@ -274,6 +276,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.cc b/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.cc index b21bd35eed..5261b63c81 100644 --- a/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.cc +++ b/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.cc @@ -4,6 +4,7 @@ #include "gradient_control.h" #include "gradient_control_impl.h" +#include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/providers/cuda/math/binary_elementwise_ops.h" #include "core/providers/cuda/reduction/reduction_functions.h" #include "core/providers/cuda/cuda_allocator.h" @@ -75,19 +76,7 @@ Status InPlaceAccumulator::ComputeInternal(OpKernelContext* ctx) cons if (do_update_tensor) { const bool do_update = *(do_update_tensor->template Data()); if (!do_update) { -#ifdef ENABLE_TRAINING_ON_DEVICE - // This is temporary fix till we potentially redesign inplaceaccumulator op - // to fit lazy reset grad functionality - if (std::is_same::value) { - const void* source = right_addee_buffer.template Data(); - T* target = accumulation_output.template MutableData(); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target, source, right_addee_buffer.SizeInBytes(), cudaMemcpyDeviceToDevice, Stream())); - } else { - ORT_NOT_IMPLEMENTED(); - } -#else ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer(Stream(), left_addee_buffer, accumulation_output)); -#endif return Status::OK(); } } @@ -102,5 +91,66 @@ Status InPlaceAccumulator::ComputeInternal(OpKernelContext* ctx) cons return Status::OK(); } +#define REGISTER_IN_PLACE_TENSOR_ACCUMULATORV2_TYPED(T, T_GRAD) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + InPlaceAccumulatorV2, \ + kMSDomain, \ + 1, \ + T##_##T_GRAD, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .Alias(0, 1) /* Accumulate tensors in-place */ \ + .InputMemoryType(OrtMemTypeCPUInput, 2) /* overwrite_flag is on CPU*/ \ + .OutputMemoryType(OrtMemTypeCPUOutput, 0) /* updated_flag is on CPU*/ \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T_GRAD", DataTypeImpl::GetTensorType()), \ + InPlaceAccumulatorV2); + +REGISTER_IN_PLACE_TENSOR_ACCUMULATORV2_TYPED(float, float) +REGISTER_IN_PLACE_TENSOR_ACCUMULATORV2_TYPED(float, MLFloat16) + +template +Status InPlaceAccumulatorV2::ComputeInternal(OpKernelContext* ctx) const { + typedef typename ToCudaType::MappedType CudaT; + typedef typename ToCudaType::MappedType CudaT_GRAD; + + Tensor& left_addee_buffer = *const_cast(ctx->Input(0)); + const Tensor& right_addee_buffer = *ctx->Input(1); + const Tensor* overwrite_tensor = ctx->Input(2); + const bool overwrite = overwrite_tensor != nullptr ? *(overwrite_tensor->template Data()) : false; + + if (overwrite) { + const T_GRAD* source = right_addee_buffer.template Data(); + T* target = left_addee_buffer.template MutableData(); + if (std::is_same::value) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target, source, right_addee_buffer.SizeInBytes(), cudaMemcpyDeviceToDevice, Stream())); + } else { + Impl_Cast( + Stream(), + reinterpret_cast(source), + reinterpret_cast(target), + right_addee_buffer.Shape().Size()); + } + } else { + InPlaceAccumulatorImpl( + Stream(), + reinterpret_cast(left_addee_buffer.template Data()), + reinterpret_cast(right_addee_buffer.template Data()), + reinterpret_cast(left_addee_buffer.template MutableData()), + right_addee_buffer.Shape().Size()); + } + + Tensor& updated_output = *ctx->Output(0, {1}); + bool* updated_output_ptr = updated_output.MutableData(); + *updated_output_ptr = true; + + Tensor* accumulation_output = ctx->Output(1, left_addee_buffer.Shape()); + if (nullptr != accumulation_output) { + ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer(Stream(), left_addee_buffer, *accumulation_output)); + } + + return Status::OK(); +} + } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.h b/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.h index a4f60bc66b..5c38e53bc4 100644 --- a/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.h +++ b/orttraining/orttraining/training_ops/cuda/optimizer/gradient_control.h @@ -22,5 +22,12 @@ class InPlaceAccumulator final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; }; +template +class InPlaceAccumulatorV2 final : public CudaKernel { + public: + InPlaceAccumulatorV2(const OpKernelInfo& info) : CudaKernel(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime \ No newline at end of file