Redesign InPlaceAccumulator op (#11842)

* op changes

* review comments

* shape consolidation, test trigger, cleanup

* review comments
This commit is contained in:
ashbhandare 2022-06-24 02:11:06 -07:00 committed by GitHub
parent 17a8ecee6f
commit c2fd5ccbe9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 300 additions and 73 deletions

View file

@ -1005,22 +1005,6 @@ class PlannerImpl {
}
}
}
#if defined(ENABLE_TRAINING) && defined(ENABLE_TRAINING_ON_DEVICE)
// This is required because for the on-device training case,
// the InPlaceAccumulator produces graph
// outputs which will be re-allocated instead of re-using
// the input accumulation buffer
if (pnode->OpType() == "InPlaceAccumulator") {
const NodeArg* input = pnode->InputDefs()[0];
const auto& input_name = input->Name();
const auto input_index = Index(input_name);
const auto& alloc_plan = AllocPlan(input_index);
if (alloc_plan.alloc_kind == AllocKind::kPreExisting) {
Reuse(input_index, current, AllocKind::kShare);
}
}
#endif
} else if (!context_.IsParallelExecutionEnabled() &&
FindReusableInput(*pnode, static_cast<int>(output_arg_def_index), &reused)) {
// Re-using inputs is applicable for tensors, sequence tensors,

View file

@ -1413,6 +1413,40 @@ void RegisterTrainingOpSchemas() {
propagateShapeAndTypeFromFirstInput(ctx);
});
// Schema for InPlaceAccumulatorV2 (com.microsoft domain, since version 1).
//   Inputs : accumulation_buffer (T), value (T_GRAD), optional overwrite_flag (T_BOOL).
//   Outputs: updated_flag (T_BOOL), optional accumulation_buffer_out (T).
ONNX_CONTRIB_OPERATOR_SCHEMA(InPlaceAccumulatorV2)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetDoc("In-place accumulator for tensors. Differs from older op by adding control input for reset, and optional output buffer.")
    .Input(0, "accumulation_buffer", "historical result of accumulator", "T")
    .Input(1, "value", "the value that will be added to the accumulator", "T_GRAD")
    .Input(2, "overwrite_flag", "Indicates if tensor should be overwritten. Default is accumulation", "T_BOOL", OpSchema::Optional)
    .Output(0, "updated_flag", "Whether the update was completed", "T_BOOL")
    .Output(1, "accumulation_buffer_out", "updated result of accumulator", "T", OpSchema::Optional)
    .TypeConstraint(
        "T",
        {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
        "Constrain input and output types to float tensors.")
    .TypeConstraint(
        "T_GRAD",
        {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
        "Constrain input and output types to float tensors.")
    .TypeConstraint(
        "T_BOOL",
        {"tensor(bool)"},
        "Constrain types to boolean tensors.")
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      // Output 0 (updated_flag) is always a bool tensor of shape [1].
      updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL);
      ONNX_NAMESPACE::TensorShapeProto updated_shape;
      updated_shape.add_dim()->set_dim_value(1);
      updateOutputShape(ctx, 0, updated_shape);

      // Output 1 is optional; when the node declares it, it mirrors the
      // element type and (if known) the shape of the accumulation buffer (input 0).
      if (ctx.getNumOutputs() == 2) {
        propagateElemTypeFromInputToOutput(ctx, 0, 1);
        if (hasNInputShapes(ctx, 1)) {
          propagateShapeFromInputToOutput(ctx, 0, 1);
        }
      }
    });
ONNX_CONTRIB_OPERATOR_SCHEMA(ZeroGradient)
.SetDomain(kMSDomain)
.SinceVersion(1)

View file

@ -109,11 +109,11 @@ def build_gradient_graph(accessor, user_args_requiring_grad, user_args_not_requi
def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_names):
"""Builds gradient accumulation nodes on top of a training model.
Adds an InPlaceAccumulator node for every gradient so that the gradients
are accumulated in a gradient buffer (which is an input to InPlaceAccumulator).
Adds an InPlaceAccumulatorV2 node for every gradient so that the gradients
are accumulated in a gradient buffer (which is an input to InPlaceAccumulatorV2).
For example, if there is a gradient in the graph called fc1.weight_grad,
an InPlaceAccumulator will be added for that gradient whose input will
an InPlaceAccumulatorV2 will be added for that gradient whose input will
be a graph input (fc1.weight_grad.accumulation.buffer) and the newly
computed gradient (fc1.weight_grad).
@ -122,7 +122,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
Ʌ v v
| |_________________________|
| |
| InPlaceAccumulator
| InPlaceAccumulatorV2
| |
| v
|______________________|
@ -154,7 +154,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
# Gradient accumulation node
acc_node = onnx.helper.make_node(
"InPlaceAccumulator",
"InPlaceAccumulatorV2",
[grad_accumulation_buffer_name, grad_name, lazy_reset_grad_input_name],
[grad_accumulation_output_name],
name=f"GradientAccumulator{idx}",
@ -168,9 +168,10 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
grad_accumulation_buffer_input.name = grad_accumulation_buffer_name
graph_inputs.append(grad_accumulation_buffer_input)
# accumulated gradient is also a graph output
grad_accumulation_output = copy.deepcopy(graph_output)
grad_accumulation_output.name = grad_accumulation_output_name
# accumulated gradient update flag is also a graph output
grad_accumulation_output = onnx.helper.make_tensor_value_info(
grad_accumulation_output_name, onnx.TensorProto.BOOL, [1]
)
graph_outputs.append(grad_accumulation_output)
lazy_reset_grad_input = onnx.helper.make_tensor_value_info(lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1])

View file

@ -50,18 +50,44 @@ static bool IsErrorWithinTolerance(float error, float tolerance) {
static void RunReductionTests(const OpDef& op_def, bool axes_as_input = false,
bool check_not_have_shape_inferencing = false) {
std::vector<std::vector<int64_t>> x_shapes = {
{4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2},
{4, 3, 2},
{4, 3, 2},
{4, 3, 2},
{4, 3, 2},
{4, 3, 2},
{4, 3, 2},
{4, 3, 2},
{4, 3, 2},
};
std::vector<std::vector<int64_t>> y_shapes = {
{1, 1, 1}, {}, {1, 3, 1}, {2}, {4, 1, 2}, {4, 3}, {4, 1, 2}, {4},
{1, 1, 1},
{},
{1, 3, 1},
{2},
{4, 1, 2},
{4, 3},
{4, 1, 2},
{4},
};
std::vector<std::vector<int64_t>> axes_vec = {
{}, // default case
{0, 1, 2}, {0, 2}, {0, 1}, {1}, {2}, {-2}, {-2, -1},
{0, 1, 2},
{0, 2},
{0, 1},
{1},
{2},
{-2},
{-2, -1},
};
std::vector<int64_t> keepdims_ip = {
-1, // default case
0, 1, 0, 1, 0, 1, 0,
0,
1,
0,
1,
0,
1,
0,
};
GradientChecker<float, float, float> gradient_checker;
@ -2084,15 +2110,69 @@ TEST(GradientCheckerTest, SimplifiedLayerNormGrad) {
TEST(GradientUtilsTest, InPlaceAccumulatorFloat32) {
OpTester test("InPlaceAccumulator", 1, onnxruntime::kMSDomain);
test.AddInput<float>("old_sum", {3}, {1, 2, 3});
test.AddInput<float>("value", {3}, {4, 5, 6});
test.AddInput<float>("old_sum", {3}, {1.f, 2.f, 3.f});
test.AddInput<float>("value", {3}, {4.f, 5.f, 6.f});
test.AddOutput<float>("new_sum", {3}, {5, 7, 9});
test.AddOutput<float>("new_sum", {3}, {5.f, 7.f, 9.f});
test.Run();
}
// Accumulation without an overwrite flag: the buffer and the incoming value
// are summed element-wise and the op reports that the update happened.
TEST(GradientUtilsTest, InPlaceAccumulatorV2_CPU) {
  const std::vector<float> buffer_values{1.f, 2.f, 3.f};
  const std::vector<float> incoming_values{4.f, 5.f, 6.f};
  const std::vector<float> expected_sum{5.f, 7.f, 9.f};

  OpTester tester("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain);

  tester.AddInput<float>("old_sum", {3}, buffer_values);
  tester.AddInput<float>("value", {3}, incoming_values);

  tester.AddOutput<bool>("updated", {1}, {true});
  tester.AddOutput<float>("new_sum", {3}, expected_sum);

  // The CUDA execution provider is passed in the excluded-providers argument.
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider});
}
// With overwrite == true the incoming value replaces the buffer contents
// instead of being added to them.
TEST(GradientUtilsTest, InPlaceAccumulatorV2Overwrite) {
  const std::vector<float> buffer_values{1.f, 2.f, 3.f};
  const std::vector<float> incoming_values{4.f, 5.f, 6.f};

  OpTester tester("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain);

  tester.AddInput<float>("old_sum", {3}, buffer_values);
  tester.AddInput<float>("value", {3}, incoming_values);
  tester.AddInput<bool>("overwrite", {1}, {true});

  tester.AddOutput<bool>("updated", {1}, {true});
  // Expected result equals the incoming value, not the sum.
  tester.AddOutput<float>("new_sum", {3}, incoming_values);

  tester.Run();
}
#if defined(USE_CUDA) || defined(USE_ROCM)
// Same accumulation scenario as the CPU test, but with the CPU execution
// provider excluded so that a GPU provider has to run the kernel.
TEST(GradientUtilsTest, InPlaceAccumulatorV2_GPU) {
  const std::vector<float> buffer_values{1.f, 2.f, 3.f};
  const std::vector<float> incoming_values{4.f, 5.f, 6.f};
  const std::vector<float> expected_sum{5.f, 7.f, 9.f};

  OpTester tester("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain);

  tester.AddInput<float>("old_sum", {3}, buffer_values);
  tester.AddInput<float>("value", {3}, incoming_values);

  tester.AddOutput<bool>("updated", {1}, {true});
  tester.AddOutput<float>("new_sum", {3}, expected_sum);

  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider});
}
// Mixed precision overwrite: a float32 buffer receives a float16 value with
// the overwrite flag set, so the result is the (converted) incoming value.
TEST(GradientUtilsTest, InPlaceAccumulatorV2_Float16) {
  const std::vector<float> buffer_fp32 = {1.0f, 2.0f, 3.0f};
  const std::vector<float> incoming_fp32 = {4.0f, 5.0f, 6.0f};
  const std::vector<float> expected_fp32 = {4.0f, 5.0f, 6.0f};

  // Convert the incoming value to half precision for the "value" input.
  std::vector<MLFloat16> incoming_fp16(3);
  ConvertFloatToMLFloat16(incoming_fp32.data(), incoming_fp16.data(), 3);

  OpTester tester("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain);

  tester.AddInput<float>("old_sum", {3}, buffer_fp32);
  tester.AddInput<MLFloat16>("value", {3}, incoming_fp16);
  tester.AddInput<bool>("overwrite", {1}, {true});

  tester.AddOutput<bool>("updated", {1}, {true});
  tester.AddOutput<float>("new_sum", {3}, expected_fp32);

  // Didn't implement mixed precision InPlaceAccumulatorV2 in CPU
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider});
}
TEST(GradientUtilsTest, InPlaceAccumulatorFloat16) {
OpTester test("InPlaceAccumulator", 1, onnxruntime::kMSDomain);

View file

@ -48,7 +48,7 @@ void GenerateRandomInput(gsl::span<const int64_t> dims, OrtValue& input) {
}
TEST(TrainingApiTest, ModuleTrainStep) {
auto model_uri = MODEL_FOLDER "gradient_graph.onnx";
auto model_uri = MODEL_FOLDER "training_model.onnx";
CheckpointState state;
auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt";
@ -59,7 +59,7 @@ TEST(TrainingApiTest, ModuleTrainStep) {
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
auto model = std::make_unique<Module>(model_uri, state.module_checkpoint_state.named_parameters, session_option,
*env, std::vector<std::shared_ptr<IExecutionProvider>>());
ORT_ENFORCE(model->GetTrainModeOutputCount() == 1);
OrtValue input, target;
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
CreateInputOrtValue<int32_t>(std::array<int64_t, 1>{2}, std::vector<int32_t>(2, 1), &target);
@ -76,6 +76,7 @@ TEST(TrainingApiTest, ModuleTrainStep) {
std::vector<OrtValue>& inputs = *it;
std::vector<OrtValue> fetches;
ORT_ENFORCE(model->TrainStep(inputs, fetches).IsOK());
ORT_ENFORCE(fetches.size() == 1);
bias_grad = bias_param->Gradient();
if (step > 1) {
@ -103,7 +104,7 @@ TEST(TrainingApiTest, ModuleTrainStep) {
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(TrainingApiTest, OptimStep) {
auto model_uri = MODEL_FOLDER "gradient_graph.onnx";
auto model_uri = MODEL_FOLDER "training_model.onnx";
auto optim_uri = MODEL_FOLDER "adamw.onnx";
CheckpointState state;
@ -182,7 +183,7 @@ void CompareValue(float expected, float output, float rtol = 1e-4, float atol =
void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count,
int64_t warmup_step_count) {
/// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances.
auto model_uri = MODEL_FOLDER "gradient_graph.onnx";
auto model_uri = MODEL_FOLDER "training_model.onnx";
auto optim_uri = MODEL_FOLDER "adamw.onnx";
CheckpointState state;

View file

@ -311,7 +311,9 @@ int RunTraining(const TestRunnerParameters& params) {
g_ort_api->ReleaseValue(inputs[i]);
}
// TODO(askhade): release output values. Needs changes from Aishwarya's PR.
for (size_t i = 0; i < fetches.size(); i++) {
g_ort_api->ReleaseValue(fetches[i]);
}
}
data_loader.ResetIterateIndex();

View file

@ -66,7 +66,8 @@ Module::Module(const std::string& train_model_path_or_bytes,
ORT_THROW_IF_ERROR(train_sess_->Initialize());
// Extract model input and output names
utils::GetGraphInputOutputNames(train_sess_, train_input_names_, train_output_names_);
std::vector<std::string> train_input_names, train_output_names;
utils::GetGraphInputOutputNames(train_sess_, train_input_names, train_output_names);
// Reorder the extracted input names in the following order:
// user inputs, weights, gradients, reset_grad
@ -74,7 +75,7 @@ Module::Module(const std::string& train_model_path_or_bytes,
std::string param_name;
std::unordered_map<std::string, size_t> param_name_to_grad_input_index_map;
for (const auto& input_name : train_input_names_) {
for (const auto& input_name : train_input_names) {
auto it = named_parameters_.find(input_name);
if (it != named_parameters_.end()) {
param_input_names.emplace_back(input_name);
@ -95,6 +96,12 @@ Module::Module(const std::string& train_model_path_or_bytes,
train_input_names_.insert(train_input_names_.end(), grad_input_names.begin(), grad_input_names.end());
train_input_names_.insert(train_input_names_.end(), reset_grad_name.begin(), reset_grad_name.end());
for (const auto& output_name : train_output_names) {
if (!utils::GetParamNameFromGradient(output_name, param_name)) {
train_output_names_.emplace_back(output_name);
}
}
// Loop over each parameter and allocate its memory based on the user-specified device.
auto& train_sess_state = train_sess_->GetSessionState();
for (auto& param_name : param_input_names) {
@ -205,11 +212,10 @@ Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtVal
feeds.insert(feeds.end(), weights_.begin(), weights_.end());
feeds.insert(feeds.end(), gradients_.begin(), gradients_.end());
// TODO: consider maintaining this as ortvalue instead of bool
OrtValue do_update_input;
utils::WrapInOrtValue<bool>(accumulate_gradient_, &do_update_input);
feeds.push_back(do_update_input);
OrtValue reset_grad_input;
utils::WrapInOrtValue<bool>(!accumulate_gradient_, &reset_grad_input);
feeds.push_back(reset_grad_input);
// TODO: need to filter out the grads from the output ortvalues
auto status = train_sess_->Run(RunOptions(), train_input_names_, feeds, train_output_names_, &outputs);
ORT_THROW_IF_ERROR(status);

View file

@ -10,6 +10,7 @@ namespace contrib {
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SGDOptimizer);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AdamOptimizer);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulator);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulatorV2);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ZeroGradient);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Group);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, PassThrough);
@ -121,6 +122,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SGDOptimizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AdamOptimizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulator)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulatorV2)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ZeroGradient)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Group)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, PassThrough)>,

View file

@ -9,6 +9,20 @@
namespace onnxruntime {
namespace contrib {
// Fills `funcs` with the three element-wise addition callbacks used by the
// broadcast machinery: (scalar + span), (span + scalar) and (span + span).
template <typename T>
void getBroadcastSpanFunc(ProcessBroadcastSpanFuncs& funcs) {
  funcs = ProcessBroadcastSpanFuncs{
      // Input 0 is a scalar, input 1 is a span.
      [](BroadcastHelper& helper) {
        helper.OutputEigen<T>() = helper.ScalarInput0<T>() + helper.EigenInput1<T>().array();
      },
      // Input 0 is a span, input 1 is a scalar.
      [](BroadcastHelper& helper) {
        helper.OutputEigen<T>() = helper.EigenInput0<T>().array() + helper.ScalarInput1<T>();
      },
      // Both inputs are spans.
      [](BroadcastHelper& helper) {
        helper.OutputEigen<T>() = helper.EigenInput0<T>() + helper.EigenInput1<T>();
      }};
}
ONNX_OPERATOR_KERNEL_EX(
InPlaceAccumulator,
kMSDomain,
@ -28,33 +42,17 @@ Status InPlaceAccumulator<T>::Compute(OpKernelContext* context) const {
if (do_update_tensor) {
const bool do_update = *(do_update_tensor->template Data<bool>());
if (!do_update) {
#ifdef ENABLE_TRAINING_ON_DEVICE
// This is temporary fix till we potentially redesign inplaceaccumulator op
// to fit lazy reset grad functionality
const Tensor* new_gradient = context->Input<Tensor>(1);
const void* updated_data = new_gradient->template Data<T>();
memcpy(output_data, updated_data, new_gradient->SizeInBytes());
#else
const void* input_data = gradient_buffer->template Data<T>();
if (output_data != input_data) {
memcpy(output_data, input_data, gradient_buffer->SizeInBytes());
}
#endif
return Status::OK();
}
}
//Copy from Add CPU kernel
ProcessBroadcastSpanFuncs funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
}};
// Copy from Add CPU kernel
ProcessBroadcastSpanFuncs funcs;
getBroadcastSpanFunc<T>(funcs);
UntypedBroadcastTwo(*context, funcs);
@ -81,5 +79,54 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("T2", DataTypeImpl::AllTensorTypes()),
ZeroGradient<float>);
// Registers the CPU kernel for InPlaceAccumulatorV2 (kMSDomain, version 1).
// Only T = float is registered here. Input 0 is aliased with output 1 so the
// optional accumulation output can share the accumulation buffer.
ONNX_OPERATOR_KERNEL_EX(
InPlaceAccumulatorV2,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder()
.Alias(0, 1) // accumulate tensors in-place
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
InPlaceAccumulatorV2<float>);
// CPU implementation of InPlaceAccumulatorV2.
//
// Inputs:
//   0: accumulation buffer, updated in place.
//   1: new value to add to (or write over) the buffer.
//   2: optional bool flag; when present and true the buffer is overwritten
//      with the new value, otherwise the new value is added to it.
// Outputs:
//   0: bool tensor of shape {1}, always set to true.
//   1: optional copy of the updated accumulation buffer; it has the shape of
//      the accumulation buffer (input 0).
template <typename T>
Status InPlaceAccumulatorV2<T>::Compute(OpKernelContext* context) const {
  // Input 0 is mutated in place, so its constness is deliberately cast away.
  Tensor* accumulation_buffer = const_cast<Tensor*>(context->Input<Tensor>(0));
  const Tensor* new_value = context->Input<Tensor>(1);
  const Tensor* overwrite_tensor = context->Input<Tensor>(2);

  void* accumulation_buffer_data = accumulation_buffer->template MutableData<T>();
  // A missing flag means "accumulate".
  const bool overwrite = overwrite_tensor != nullptr ? *(overwrite_tensor->template Data<bool>()) : false;

  if (overwrite) {
    // NOTE(review): this copy assumes `value` is no larger than the
    // accumulation buffer; no shape check is performed here - confirm callers.
    const void* updated_data = new_value->template Data<T>();
    memcpy(accumulation_buffer_data, updated_data, new_value->SizeInBytes());
  } else {
    // Copy from Add CPU kernel
    ProcessBroadcastSpanFuncs funcs;
    getBroadcastSpanFunc<T>(funcs);
    // The accumulation buffer is both the first operand and the destination.
    InputBroadcaster input_broadcaster(*accumulation_buffer, *new_value);
    OutputBroadcaster output_broadcaster(input_broadcaster.GetSpanSize(), *accumulation_buffer);
    BroadcastHelper broadcast_helper(input_broadcaster, output_broadcaster, nullptr);
    BroadcastLooper(broadcast_helper, funcs);
  }

  Tensor* updated_output = context->Output(0, {1});
  bool* updated_output_ptr = updated_output->template MutableData<bool>();
  *updated_output_ptr = true;

  // Output 1 mirrors the accumulation buffer, so it must use the buffer's
  // shape and byte size (not those of `value`, which may differ when `value`
  // is broadcast into the buffer).
  Tensor* accumulated_value_out = context->Output(1, accumulation_buffer->Shape());
  if (nullptr != accumulated_value_out) {
    void* output_data = accumulated_value_out->template MutableData<T>();
    // Skip the copy when the output already shares the buffer's memory.
    if (output_data != accumulation_buffer_data) {
      memcpy(output_data, accumulation_buffer_data, accumulation_buffer->SizeInBytes());
    }
  }

  return Status::OK();
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -21,5 +21,13 @@ class InPlaceAccumulator final : public OpKernel {
InPlaceAccumulator(const OpKernelInfo& info) : OpKernel(info) {}
Status Compute(OpKernelContext* context) const override;
};
// CPU kernel class for the InPlaceAccumulatorV2 operator.
// T is the element type of the accumulation buffer.
template <typename T>
class InPlaceAccumulatorV2 final : public OpKernel {
 public:
  // `explicit` prevents accidental implicit conversion from OpKernelInfo.
  explicit InPlaceAccumulatorV2(const OpKernelInfo& info) : OpKernel(info) {}

  // Updates the accumulation buffer and writes the op outputs.
  Status Compute(OpKernelContext* context) const override;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -50,6 +50,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ZeroGradient);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SoftmaxCrossEntropy);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SoftmaxCrossEntropyGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, InPlaceAccumulatorV2);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_MLFloat16, InPlaceAccumulatorV2);
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropy);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int64_t, SparseSoftmaxCrossEntropy);
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropyGrad);
@ -274,6 +276,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, InPlaceAccumulator)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, InPlaceAccumulator)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, InPlaceAccumulatorV2)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_MLFloat16, InPlaceAccumulatorV2)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ZeroGradient)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ZeroGradient)>,

View file

@ -4,6 +4,7 @@
#include "gradient_control.h"
#include "gradient_control_impl.h"
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
#include "core/providers/cuda/math/binary_elementwise_ops.h"
#include "core/providers/cuda/reduction/reduction_functions.h"
#include "core/providers/cuda/cuda_allocator.h"
@ -75,19 +76,7 @@ Status InPlaceAccumulator<T, T_GRAD>::ComputeInternal(OpKernelContext* ctx) cons
if (do_update_tensor) {
const bool do_update = *(do_update_tensor->template Data<bool>());
if (!do_update) {
#ifdef ENABLE_TRAINING_ON_DEVICE
// This is temporary fix till we potentially redesign inplaceaccumulator op
// to fit lazy reset grad functionality
if (std::is_same<T,T_GRAD>::value) {
const void* source = right_addee_buffer.template Data<T>();
T* target = accumulation_output.template MutableData<T>();
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target, source, right_addee_buffer.SizeInBytes(), cudaMemcpyDeviceToDevice, Stream()));
} else {
ORT_NOT_IMPLEMENTED();
}
#else
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<T>(Stream(), left_addee_buffer, accumulation_output));
#endif
return Status::OK();
}
}
@ -102,5 +91,66 @@ Status InPlaceAccumulator<T, T_GRAD>::ComputeInternal(OpKernelContext* ctx) cons
return Status::OK();
}
// Registers a CUDA InPlaceAccumulatorV2 kernel for one (T, T_GRAD) pair, where
// T constrains the accumulation buffer type and T_GRAD the incoming value type.
// Input 0 is aliased with output 1; the overwrite flag (input 2) and the
// updated flag (output 0) are placed in CPU memory.
// Instantiated below for (float, float) and (float, MLFloat16).
#define REGISTER_IN_PLACE_TENSOR_ACCUMULATORV2_TYPED(T, T_GRAD) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
InPlaceAccumulatorV2, \
kMSDomain, \
1, \
T##_##T_GRAD, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.Alias(0, 1) /* Accumulate tensors in-place */ \
.InputMemoryType(OrtMemTypeCPUInput, 2) /* overwrite_flag is on CPU*/ \
.OutputMemoryType(OrtMemTypeCPUOutput, 0) /* updated_flag is on CPU*/ \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T_GRAD", DataTypeImpl::GetTensorType<T_GRAD>()), \
InPlaceAccumulatorV2<T, T_GRAD>);
REGISTER_IN_PLACE_TENSOR_ACCUMULATORV2_TYPED(float, float)
REGISTER_IN_PLACE_TENSOR_ACCUMULATORV2_TYPED(float, MLFloat16)
// CUDA implementation of InPlaceAccumulatorV2.
//
// Input 0 (accumulation buffer, type T) is updated in place with input 1
// (value, type T_GRAD). When the optional input 2 is present and true, the
// buffer is overwritten with the value (cast to T if the types differ);
// otherwise the value is added to the buffer. Output 0 is set to true and the
// optional output 1 receives the updated buffer.
template <typename T, typename T_GRAD>
Status InPlaceAccumulatorV2<T, T_GRAD>::ComputeInternal(OpKernelContext* ctx) const {
  using CudaT = typename ToCudaType<T>::MappedType;
  using CudaT_GRAD = typename ToCudaType<T_GRAD>::MappedType;

  // The accumulation buffer is mutated in place, hence the const_cast.
  Tensor& accumulation_buffer = *const_cast<Tensor*>(ctx->Input<Tensor>(0));
  const Tensor& new_value = *ctx->Input<Tensor>(1);
  const Tensor* overwrite_flag = ctx->Input<Tensor>(2);

  // A missing flag means "accumulate".
  bool overwrite = false;
  if (overwrite_flag != nullptr) {
    overwrite = *(overwrite_flag->template Data<bool>());
  }

  if (!overwrite) {
    // buffer = buffer + value
    InPlaceAccumulatorImpl(
        Stream(),
        reinterpret_cast<const CudaT*>(accumulation_buffer.template Data<T>()),
        reinterpret_cast<const CudaT_GRAD*>(new_value.template Data<T_GRAD>()),
        reinterpret_cast<CudaT*>(accumulation_buffer.template MutableData<T>()),
        new_value.Shape().Size());
  } else {
    const T_GRAD* src = new_value.template Data<T_GRAD>();
    T* dst = accumulation_buffer.template MutableData<T>();
    if (std::is_same<T, T_GRAD>::value) {
      // Same element type: a plain device-to-device copy is enough.
      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst, src, new_value.SizeInBytes(), cudaMemcpyDeviceToDevice, Stream()));
    } else {
      // Different element types: convert while copying.
      Impl_Cast<CudaT_GRAD, CudaT>(
          Stream(),
          reinterpret_cast<const CudaT_GRAD*>(src),
          reinterpret_cast<CudaT*>(dst),
          new_value.Shape().Size());
    }
  }

  // Output 0: the "updated" flag, always true.
  Tensor& updated_flag = *ctx->Output(0, {1});
  bool* updated_flag_data = updated_flag.MutableData<bool>();
  *updated_flag_data = true;

  // Output 1 is optional; copy the buffer only if it does not already share memory.
  Tensor* accumulation_out = ctx->Output(1, accumulation_buffer.Shape());
  if (accumulation_out != nullptr) {
    ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<T>(Stream(), accumulation_buffer, *accumulation_out));
  }

  return Status::OK();
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -22,5 +22,12 @@ class InPlaceAccumulator final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
};
// CUDA kernel class for the InPlaceAccumulatorV2 operator.
// T is the element type of the accumulation buffer and T_GRAD the element
// type of the incoming value.
template <typename T, typename T_GRAD>
class InPlaceAccumulatorV2 final : public CudaKernel {
 public:
  // `explicit` prevents accidental implicit conversion from OpKernelInfo.
  explicit InPlaceAccumulatorV2(const OpKernelInfo& info) : CudaKernel(info) {}

  // Updates the accumulation buffer and writes the op outputs.
  Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace cuda
} // namespace onnxruntime