mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
Redesign InPlaceAccumulator op (#11842)
* op changes * review comments * shape consolidation, test trigger, cleanup * review comments
This commit is contained in:
parent
17a8ecee6f
commit
c2fd5ccbe9
15 changed files with 300 additions and 73 deletions
|
|
@ -1005,22 +1005,6 @@ class PlannerImpl {
|
|||
}
|
||||
}
|
||||
}
|
||||
#if defined(ENABLE_TRAINING) && defined(ENABLE_TRAINING_ON_DEVICE)
|
||||
// This is required because for the on-device training case,
|
||||
// the InPlaceAccumulator produces graph
|
||||
// outputs which will be re-allocated instead of re-using
|
||||
// the input accumulation buffer
|
||||
if (pnode->OpType() == "InPlaceAccumulator") {
|
||||
const NodeArg* input = pnode->InputDefs()[0];
|
||||
const auto& input_name = input->Name();
|
||||
const auto input_index = Index(input_name);
|
||||
|
||||
const auto& alloc_plan = AllocPlan(input_index);
|
||||
if (alloc_plan.alloc_kind == AllocKind::kPreExisting) {
|
||||
Reuse(input_index, current, AllocKind::kShare);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} else if (!context_.IsParallelExecutionEnabled() &&
|
||||
FindReusableInput(*pnode, static_cast<int>(output_arg_def_index), &reused)) {
|
||||
// Re-using inputs is applicable for tensors, sequence tensors,
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
|
|
@ -1413,6 +1413,40 @@ void RegisterTrainingOpSchemas() {
|
|||
propagateShapeAndTypeFromFirstInput(ctx);
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(InPlaceAccumulatorV2)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc("In-place accumulator for tensors. Differs from older op by adding cotrol input for reset, and optional output buffer.")
|
||||
.Input(0, "accumulation_buffer", "historical result of accumulator", "T")
|
||||
.Input(1, "value", "the value that will be added to the accumulator", "T_GRAD")
|
||||
.Input(2, "overwrite_flag", "Indicates if tensor should be overwritten. Default is accumulation", "T_BOOL", OpSchema::Optional)
|
||||
.Output(0, "updated_flag", "Whether the update was completed", "T_BOOL")
|
||||
.Output(1, "accumulation_buffer_out", "updated result of accumulator", "T", OpSchema::Optional)
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeConstraint(
|
||||
"T_GRAD",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeConstraint(
|
||||
"T_BOOL",
|
||||
{"tensor(bool)"},
|
||||
"Constrain types to boolean tensors.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL);
|
||||
ONNX_NAMESPACE::TensorShapeProto updated_shape;
|
||||
updated_shape.add_dim()->set_dim_value(1);
|
||||
updateOutputShape(ctx, 0, updated_shape);
|
||||
if (ctx.getNumOutputs() == 2){
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 1);
|
||||
if (hasNInputShapes(ctx, 1)) {
|
||||
propagateShapeFromInputToOutput(ctx, 0, 1);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ZeroGradient)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
|
|
|
|||
|
|
@ -109,11 +109,11 @@ def build_gradient_graph(accessor, user_args_requiring_grad, user_args_not_requi
|
|||
def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_names):
|
||||
"""Builds gradient accumulation nodes on top of a training model.
|
||||
|
||||
Adds an InPlaceAccumulator node for every gradient so that the gradients
|
||||
are accumulated in a gradient buffer (which is an input to InPlaceAccumulator).
|
||||
Adds an InPlaceAccumulatorV2 node for every gradient so that the gradients
|
||||
are accumulated in a gradient buffer (which is an input to InPlaceAccumulatorV2).
|
||||
|
||||
For example, if there is a gradient in the graph called fc1.weight_grad,
|
||||
an InPlaceAccumulator will be added for that gradient whose input will
|
||||
an InPlaceAccumulatorV2 will be added for that gradient whose input will
|
||||
be a graph input (fc1.weight_grad.accumulation.buffer) and the newly
|
||||
computed gradient (fc1.weight_grad).
|
||||
|
||||
|
|
@ -122,7 +122,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
|
|||
Ʌ v v
|
||||
| |_________________________|
|
||||
| |
|
||||
| InPlaceAccumulator
|
||||
| InPlaceAccumulatorV2
|
||||
| |
|
||||
| v
|
||||
|______________________|
|
||||
|
|
@ -154,7 +154,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
|
|||
|
||||
# Gradient accumulation node
|
||||
acc_node = onnx.helper.make_node(
|
||||
"InPlaceAccumulator",
|
||||
"InPlaceAccumulatorV2",
|
||||
[grad_accumulation_buffer_name, grad_name, lazy_reset_grad_input_name],
|
||||
[grad_accumulation_output_name],
|
||||
name=f"GradientAccumulator{idx}",
|
||||
|
|
@ -168,9 +168,10 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
|
|||
grad_accumulation_buffer_input.name = grad_accumulation_buffer_name
|
||||
graph_inputs.append(grad_accumulation_buffer_input)
|
||||
|
||||
# accumulated gradient is also a graph output
|
||||
grad_accumulation_output = copy.deepcopy(graph_output)
|
||||
grad_accumulation_output.name = grad_accumulation_output_name
|
||||
# accumulated gradient update flag is also a graph output
|
||||
grad_accumulation_output = onnx.helper.make_tensor_value_info(
|
||||
grad_accumulation_output_name, onnx.TensorProto.BOOL, [1]
|
||||
)
|
||||
graph_outputs.append(grad_accumulation_output)
|
||||
|
||||
lazy_reset_grad_input = onnx.helper.make_tensor_value_info(lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1])
|
||||
|
|
|
|||
|
|
@ -50,18 +50,44 @@ static bool IsErrorWithinTolerance(float error, float tolerance) {
|
|||
static void RunReductionTests(const OpDef& op_def, bool axes_as_input = false,
|
||||
bool check_not_have_shape_inferencing = false) {
|
||||
std::vector<std::vector<int64_t>> x_shapes = {
|
||||
{4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2}, {4, 3, 2},
|
||||
{4, 3, 2},
|
||||
{4, 3, 2},
|
||||
{4, 3, 2},
|
||||
{4, 3, 2},
|
||||
{4, 3, 2},
|
||||
{4, 3, 2},
|
||||
{4, 3, 2},
|
||||
{4, 3, 2},
|
||||
};
|
||||
std::vector<std::vector<int64_t>> y_shapes = {
|
||||
{1, 1, 1}, {}, {1, 3, 1}, {2}, {4, 1, 2}, {4, 3}, {4, 1, 2}, {4},
|
||||
{1, 1, 1},
|
||||
{},
|
||||
{1, 3, 1},
|
||||
{2},
|
||||
{4, 1, 2},
|
||||
{4, 3},
|
||||
{4, 1, 2},
|
||||
{4},
|
||||
};
|
||||
std::vector<std::vector<int64_t>> axes_vec = {
|
||||
{}, // default case
|
||||
{0, 1, 2}, {0, 2}, {0, 1}, {1}, {2}, {-2}, {-2, -1},
|
||||
{0, 1, 2},
|
||||
{0, 2},
|
||||
{0, 1},
|
||||
{1},
|
||||
{2},
|
||||
{-2},
|
||||
{-2, -1},
|
||||
};
|
||||
std::vector<int64_t> keepdims_ip = {
|
||||
-1, // default case
|
||||
0, 1, 0, 1, 0, 1, 0,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
};
|
||||
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
|
|
@ -2084,15 +2110,69 @@ TEST(GradientCheckerTest, SimplifiedLayerNormGrad) {
|
|||
TEST(GradientUtilsTest, InPlaceAccumulatorFloat32) {
|
||||
OpTester test("InPlaceAccumulator", 1, onnxruntime::kMSDomain);
|
||||
|
||||
test.AddInput<float>("old_sum", {3}, {1, 2, 3});
|
||||
test.AddInput<float>("value", {3}, {4, 5, 6});
|
||||
test.AddInput<float>("old_sum", {3}, {1.f, 2.f, 3.f});
|
||||
test.AddInput<float>("value", {3}, {4.f, 5.f, 6.f});
|
||||
|
||||
test.AddOutput<float>("new_sum", {3}, {5, 7, 9});
|
||||
test.AddOutput<float>("new_sum", {3}, {5.f, 7.f, 9.f});
|
||||
|
||||
test.Run();
|
||||
}
|
||||
|
||||
// InPlaceAccumulatorV2 without an overwrite input: the expected "new_sum" is
// the element-wise sum of "old_sum" and "value". The CUDA execution provider
// is excluded from this run.
TEST(GradientUtilsTest, InPlaceAccumulatorV2_CPU) {
  OpTester accumulator_test("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain);

  // Inputs: accumulation buffer and the value to accumulate.
  accumulator_test.AddInput<float>("old_sum", {3}, {1.f, 2.f, 3.f});
  accumulator_test.AddInput<float>("value", {3}, {4.f, 5.f, 6.f});

  // Outputs: update flag followed by the accumulated buffer.
  accumulator_test.AddOutput<bool>("updated", {1}, {true});
  accumulator_test.AddOutput<float>("new_sum", {3}, {5.f, 7.f, 9.f});

  accumulator_test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider});
}
|
||||
|
||||
// InPlaceAccumulatorV2 with the overwrite input set to true: the expected
// "new_sum" equals "value" rather than the sum of both inputs.
TEST(GradientUtilsTest, InPlaceAccumulatorV2Overwrite) {
  OpTester overwrite_test("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain);

  // Inputs: accumulation buffer, new value, and the overwrite flag.
  overwrite_test.AddInput<float>("old_sum", {3}, {1.f, 2.f, 3.f});
  overwrite_test.AddInput<float>("value", {3}, {4.f, 5.f, 6.f});
  overwrite_test.AddInput<bool>("overwrite", {1}, {true});

  // Outputs: update flag followed by the (overwritten) buffer.
  overwrite_test.AddOutput<bool>("updated", {1}, {true});
  overwrite_test.AddOutput<float>("new_sum", {3}, {4.f, 5.f, 6.f});

  overwrite_test.Run();
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
// Same accumulate scenario as the CPU variant, but with the CPU execution
// provider excluded from the run.
TEST(GradientUtilsTest, InPlaceAccumulatorV2_GPU) {
  OpTester accumulator_test("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain);

  // Inputs: accumulation buffer and the value to accumulate.
  accumulator_test.AddInput<float>("old_sum", {3}, {1.f, 2.f, 3.f});
  accumulator_test.AddInput<float>("value", {3}, {4.f, 5.f, 6.f});

  // Outputs: update flag followed by the accumulated buffer.
  accumulator_test.AddOutput<bool>("updated", {1}, {true});
  accumulator_test.AddOutput<float>("new_sum", {3}, {5.f, 7.f, 9.f});

  accumulator_test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider});
}
|
||||
|
||||
// Mixed-precision overwrite: a float accumulation buffer receives an
// MLFloat16 value with the overwrite flag set, so the expected result is the
// value itself expressed as float.
TEST(GradientUtilsTest, InPlaceAccumulatorV2_Float16) {
  OpTester mixed_precision_test("InPlaceAccumulatorV2", 1, onnxruntime::kMSDomain);

  std::vector<float> buffer_values = {1.0f, 2.0f, 3.0f};
  std::vector<float> incoming_values = {4.0f, 5.0f, 6.0f};
  std::vector<float> expected_values = {4.0f, 5.0f, 6.0f};

  // Convert the incoming float values to half precision for the "value" input.
  std::vector<MLFloat16> incoming_half(3);
  ConvertFloatToMLFloat16(incoming_values.data(), incoming_half.data(), 3);

  mixed_precision_test.AddInput<float>("old_sum", {3}, buffer_values);
  mixed_precision_test.AddInput<MLFloat16>("value", {3}, incoming_half);
  mixed_precision_test.AddInput<bool>("overwrite", {1}, {true});

  mixed_precision_test.AddOutput<bool>("updated", {1}, {true});
  mixed_precision_test.AddOutput<float>("new_sum", {3}, expected_values);

  // Mixed precision InPlaceAccumulatorV2 is not implemented for CPU, so the
  // CPU execution provider is excluded.
  mixed_precision_test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider});
}
|
||||
|
||||
TEST(GradientUtilsTest, InPlaceAccumulatorFloat16) {
|
||||
OpTester test("InPlaceAccumulator", 1, onnxruntime::kMSDomain);
|
||||
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ void GenerateRandomInput(gsl::span<const int64_t> dims, OrtValue& input) {
|
|||
}
|
||||
|
||||
TEST(TrainingApiTest, ModuleTrainStep) {
|
||||
auto model_uri = MODEL_FOLDER "gradient_graph.onnx";
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
|
||||
CheckpointState state;
|
||||
auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt";
|
||||
|
|
@ -59,7 +59,7 @@ TEST(TrainingApiTest, ModuleTrainStep) {
|
|||
ORT_THROW_IF_ERROR(Environment::Create(nullptr, env));
|
||||
auto model = std::make_unique<Module>(model_uri, state.module_checkpoint_state.named_parameters, session_option,
|
||||
*env, std::vector<std::shared_ptr<IExecutionProvider>>());
|
||||
|
||||
ORT_ENFORCE(model->GetTrainModeOutputCount() == 1);
|
||||
OrtValue input, target;
|
||||
GenerateRandomInput(std::array<int64_t, 2>{2, 784}, input);
|
||||
CreateInputOrtValue<int32_t>(std::array<int64_t, 1>{2}, std::vector<int32_t>(2, 1), &target);
|
||||
|
|
@ -76,6 +76,7 @@ TEST(TrainingApiTest, ModuleTrainStep) {
|
|||
std::vector<OrtValue>& inputs = *it;
|
||||
std::vector<OrtValue> fetches;
|
||||
ORT_ENFORCE(model->TrainStep(inputs, fetches).IsOK());
|
||||
ORT_ENFORCE(fetches.size() == 1);
|
||||
bias_grad = bias_param->Gradient();
|
||||
|
||||
if (step > 1) {
|
||||
|
|
@ -103,7 +104,7 @@ TEST(TrainingApiTest, ModuleTrainStep) {
|
|||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
|
||||
TEST(TrainingApiTest, OptimStep) {
|
||||
auto model_uri = MODEL_FOLDER "gradient_graph.onnx";
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
auto optim_uri = MODEL_FOLDER "adamw.onnx";
|
||||
|
||||
CheckpointState state;
|
||||
|
|
@ -182,7 +183,7 @@ void CompareValue(float expected, float output, float rtol = 1e-4, float atol =
|
|||
void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count,
|
||||
int64_t warmup_step_count) {
|
||||
/// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances.
|
||||
auto model_uri = MODEL_FOLDER "gradient_graph.onnx";
|
||||
auto model_uri = MODEL_FOLDER "training_model.onnx";
|
||||
auto optim_uri = MODEL_FOLDER "adamw.onnx";
|
||||
|
||||
CheckpointState state;
|
||||
|
|
|
|||
|
|
@ -311,7 +311,9 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
g_ort_api->ReleaseValue(inputs[i]);
|
||||
}
|
||||
|
||||
// TODO(askhade): release output values. Needs changes from Aishwarya's PR.
|
||||
for (size_t i = 0; i < fetches.size(); i++) {
|
||||
g_ort_api->ReleaseValue(fetches[i]);
|
||||
}
|
||||
}
|
||||
|
||||
data_loader.ResetIterateIndex();
|
||||
|
|
|
|||
|
|
@ -66,7 +66,8 @@ Module::Module(const std::string& train_model_path_or_bytes,
|
|||
ORT_THROW_IF_ERROR(train_sess_->Initialize());
|
||||
|
||||
// Extract model input and output names
|
||||
utils::GetGraphInputOutputNames(train_sess_, train_input_names_, train_output_names_);
|
||||
std::vector<std::string> train_input_names, train_output_names;
|
||||
utils::GetGraphInputOutputNames(train_sess_, train_input_names, train_output_names);
|
||||
|
||||
// Reorder the extracted input names in the following order:
|
||||
// user inputs, weights, gradients, reset_grad
|
||||
|
|
@ -74,7 +75,7 @@ Module::Module(const std::string& train_model_path_or_bytes,
|
|||
std::string param_name;
|
||||
|
||||
std::unordered_map<std::string, size_t> param_name_to_grad_input_index_map;
|
||||
for (const auto& input_name : train_input_names_) {
|
||||
for (const auto& input_name : train_input_names) {
|
||||
auto it = named_parameters_.find(input_name);
|
||||
if (it != named_parameters_.end()) {
|
||||
param_input_names.emplace_back(input_name);
|
||||
|
|
@ -95,6 +96,12 @@ Module::Module(const std::string& train_model_path_or_bytes,
|
|||
train_input_names_.insert(train_input_names_.end(), grad_input_names.begin(), grad_input_names.end());
|
||||
train_input_names_.insert(train_input_names_.end(), reset_grad_name.begin(), reset_grad_name.end());
|
||||
|
||||
for (const auto& output_name : train_output_names) {
|
||||
if (!utils::GetParamNameFromGradient(output_name, param_name)) {
|
||||
train_output_names_.emplace_back(output_name);
|
||||
}
|
||||
}
|
||||
|
||||
// Loop over each parameter and allocate its memory based on the user-specified device.
|
||||
auto& train_sess_state = train_sess_->GetSessionState();
|
||||
for (auto& param_name : param_input_names) {
|
||||
|
|
@ -205,11 +212,10 @@ Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtVal
|
|||
feeds.insert(feeds.end(), weights_.begin(), weights_.end());
|
||||
feeds.insert(feeds.end(), gradients_.begin(), gradients_.end());
|
||||
// TODO: consider maintaining this as ortvalue instead of bool
|
||||
OrtValue do_update_input;
|
||||
utils::WrapInOrtValue<bool>(accumulate_gradient_, &do_update_input);
|
||||
feeds.push_back(do_update_input);
|
||||
OrtValue reset_grad_input;
|
||||
utils::WrapInOrtValue<bool>(!accumulate_gradient_, &reset_grad_input);
|
||||
feeds.push_back(reset_grad_input);
|
||||
|
||||
// TODO: need to filter out the grads from the output ortvalues
|
||||
auto status = train_sess_->Run(RunOptions(), train_input_names_, feeds, train_output_names_, &outputs);
|
||||
ORT_THROW_IF_ERROR(status);
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ namespace contrib {
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SGDOptimizer);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AdamOptimizer);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulator);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulatorV2);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ZeroGradient);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Group);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, PassThrough);
|
||||
|
|
@ -121,6 +122,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SGDOptimizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AdamOptimizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulator)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, InPlaceAccumulatorV2)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ZeroGradient)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Group)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, PassThrough)>,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,20 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
template <typename T>
|
||||
void getBroadcastSpanFunc(ProcessBroadcastSpanFuncs& funcs) {
|
||||
ProcessBroadcastSpanFuncs add_funcs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
|
||||
}};
|
||||
funcs = std::move(add_funcs);
|
||||
}
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
InPlaceAccumulator,
|
||||
kMSDomain,
|
||||
|
|
@ -28,33 +42,17 @@ Status InPlaceAccumulator<T>::Compute(OpKernelContext* context) const {
|
|||
if (do_update_tensor) {
|
||||
const bool do_update = *(do_update_tensor->template Data<bool>());
|
||||
if (!do_update) {
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
// This is temporary fix till we potentially redesign inplaceaccumulator op
|
||||
// to fit lazy reset grad functionality
|
||||
const Tensor* new_gradient = context->Input<Tensor>(1);
|
||||
const void* updated_data = new_gradient->template Data<T>();
|
||||
memcpy(output_data, updated_data, new_gradient->SizeInBytes());
|
||||
#else
|
||||
const void* input_data = gradient_buffer->template Data<T>();
|
||||
if (output_data != input_data) {
|
||||
memcpy(output_data, input_data, gradient_buffer->SizeInBytes());
|
||||
}
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
//Copy from Add CPU kernel
|
||||
ProcessBroadcastSpanFuncs funcs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
|
||||
}};
|
||||
// Copy from Add CPU kernel
|
||||
ProcessBroadcastSpanFuncs funcs;
|
||||
getBroadcastSpanFunc<T>(funcs);
|
||||
|
||||
UntypedBroadcastTwo(*context, funcs);
|
||||
|
||||
|
|
@ -81,5 +79,54 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
.TypeConstraint("T2", DataTypeImpl::AllTensorTypes()),
|
||||
ZeroGradient<float>);
|
||||
|
||||
// CPU kernel registration for InPlaceAccumulatorV2 (kMSDomain, version 1).
// Only the float/float combination is implemented on CPU: Compute reads both
// the accumulation buffer and the incoming value as T, so T_GRAD is
// constrained to float as well. Without that constraint a node whose value
// input uses another type allowed by the schema could be matched to this
// kernel even though it cannot handle it.
ONNX_OPERATOR_KERNEL_EX(
    InPlaceAccumulatorV2,
    kMSDomain,
    1,
    kCpuExecutionProvider,
    KernelDefBuilder()
        .Alias(0, 1)  // accumulate tensors in-place
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .TypeConstraint("T_GRAD", DataTypeImpl::GetTensorType<float>()),
    InPlaceAccumulatorV2<float>);
|
||||
|
||||
template <typename T>
|
||||
Status InPlaceAccumulatorV2<T>::Compute(OpKernelContext* context) const {
|
||||
Tensor* accumulation_buffer = const_cast<Tensor*>(context->Input<Tensor>(0));
|
||||
const Tensor* new_value = context->Input<Tensor>(1);
|
||||
const Tensor* overwrite_tensor = context->Input<Tensor>(2);
|
||||
|
||||
void* accumulation_buffer_data = accumulation_buffer->template MutableData<T>();
|
||||
const bool overwrite = overwrite_tensor != nullptr ? *(overwrite_tensor->template Data<bool>()) : false;
|
||||
|
||||
if (overwrite) {
|
||||
const void* updated_data = new_value->template Data<T>();
|
||||
memcpy(accumulation_buffer_data, updated_data, new_value->SizeInBytes());
|
||||
} else {
|
||||
// Copy from Add CPU kernel
|
||||
ProcessBroadcastSpanFuncs funcs;
|
||||
getBroadcastSpanFunc<T>(funcs);
|
||||
|
||||
InputBroadcaster input_broadcaster(*accumulation_buffer, *new_value);
|
||||
OutputBroadcaster output_broadcaster(input_broadcaster.GetSpanSize(), *accumulation_buffer);
|
||||
BroadcastHelper broadcast_helper(input_broadcaster, output_broadcaster, nullptr);
|
||||
|
||||
BroadcastLooper(broadcast_helper, funcs);
|
||||
}
|
||||
|
||||
Tensor* updated_output = context->Output(0, {1});
|
||||
bool* updated_output_ptr = updated_output->template MutableData<bool>();
|
||||
*updated_output_ptr = true;
|
||||
|
||||
Tensor* accumulated_value_out = context->Output(1, new_value->Shape());
|
||||
if (nullptr != accumulated_value_out) {
|
||||
void* output_data = accumulated_value_out->template MutableData<T>();
|
||||
if (output_data != accumulation_buffer_data) {
|
||||
memcpy(output_data, accumulation_buffer_data, new_value->SizeInBytes());
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -21,5 +21,13 @@ class InPlaceAccumulator final : public OpKernel {
|
|||
InPlaceAccumulator(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class InPlaceAccumulatorV2 final : public OpKernel {
|
||||
public:
|
||||
InPlaceAccumulatorV2(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ZeroGradient);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SoftmaxCrossEntropy);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SoftmaxCrossEntropyGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, InPlaceAccumulatorV2);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_MLFloat16, InPlaceAccumulatorV2);
|
||||
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropy);
|
||||
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int64_t, SparseSoftmaxCrossEntropy);
|
||||
// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, int32_t, SparseSoftmaxCrossEntropyGrad);
|
||||
|
|
@ -274,6 +276,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, InPlaceAccumulator)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, InPlaceAccumulator)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, InPlaceAccumulatorV2)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_MLFloat16, InPlaceAccumulatorV2)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ZeroGradient)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ZeroGradient)>,
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "gradient_control.h"
|
||||
#include "gradient_control_impl.h"
|
||||
|
||||
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
|
||||
#include "core/providers/cuda/math/binary_elementwise_ops.h"
|
||||
#include "core/providers/cuda/reduction/reduction_functions.h"
|
||||
#include "core/providers/cuda/cuda_allocator.h"
|
||||
|
|
@ -75,19 +76,7 @@ Status InPlaceAccumulator<T, T_GRAD>::ComputeInternal(OpKernelContext* ctx) cons
|
|||
if (do_update_tensor) {
|
||||
const bool do_update = *(do_update_tensor->template Data<bool>());
|
||||
if (!do_update) {
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
// This is temporary fix till we potentially redesign inplaceaccumulator op
|
||||
// to fit lazy reset grad functionality
|
||||
if (std::is_same<T,T_GRAD>::value) {
|
||||
const void* source = right_addee_buffer.template Data<T>();
|
||||
T* target = accumulation_output.template MutableData<T>();
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target, source, right_addee_buffer.SizeInBytes(), cudaMemcpyDeviceToDevice, Stream()));
|
||||
} else {
|
||||
ORT_NOT_IMPLEMENTED();
|
||||
}
|
||||
#else
|
||||
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<T>(Stream(), left_addee_buffer, accumulation_output));
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
|
@ -102,5 +91,66 @@ Status InPlaceAccumulator<T, T_GRAD>::ComputeInternal(OpKernelContext* ctx) cons
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#define REGISTER_IN_PLACE_TENSOR_ACCUMULATORV2_TYPED(T, T_GRAD) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
InPlaceAccumulatorV2, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T##_##T_GRAD, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.Alias(0, 1) /* Accumulate tensors in-place */ \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 2) /* overwrite_flag is on CPU*/ \
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 0) /* updated_flag is on CPU*/ \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T_GRAD", DataTypeImpl::GetTensorType<T_GRAD>()), \
|
||||
InPlaceAccumulatorV2<T, T_GRAD>);
|
||||
|
||||
REGISTER_IN_PLACE_TENSOR_ACCUMULATORV2_TYPED(float, float)
|
||||
REGISTER_IN_PLACE_TENSOR_ACCUMULATORV2_TYPED(float, MLFloat16)
|
||||
|
||||
template <typename T, typename T_GRAD>
|
||||
Status InPlaceAccumulatorV2<T, T_GRAD>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
typedef typename ToCudaType<T_GRAD>::MappedType CudaT_GRAD;
|
||||
|
||||
Tensor& left_addee_buffer = *const_cast<Tensor*>(ctx->Input<Tensor>(0));
|
||||
const Tensor& right_addee_buffer = *ctx->Input<Tensor>(1);
|
||||
const Tensor* overwrite_tensor = ctx->Input<Tensor>(2);
|
||||
const bool overwrite = overwrite_tensor != nullptr ? *(overwrite_tensor->template Data<bool>()) : false;
|
||||
|
||||
if (overwrite) {
|
||||
const T_GRAD* source = right_addee_buffer.template Data<T_GRAD>();
|
||||
T* target = left_addee_buffer.template MutableData<T>();
|
||||
if (std::is_same<T, T_GRAD>::value) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target, source, right_addee_buffer.SizeInBytes(), cudaMemcpyDeviceToDevice, Stream()));
|
||||
} else {
|
||||
Impl_Cast<CudaT_GRAD, CudaT>(
|
||||
Stream(),
|
||||
reinterpret_cast<const CudaT_GRAD*>(source),
|
||||
reinterpret_cast<CudaT*>(target),
|
||||
right_addee_buffer.Shape().Size());
|
||||
}
|
||||
} else {
|
||||
InPlaceAccumulatorImpl(
|
||||
Stream(),
|
||||
reinterpret_cast<const CudaT*>(left_addee_buffer.template Data<T>()),
|
||||
reinterpret_cast<const CudaT_GRAD*>(right_addee_buffer.template Data<T_GRAD>()),
|
||||
reinterpret_cast<CudaT*>(left_addee_buffer.template MutableData<T>()),
|
||||
right_addee_buffer.Shape().Size());
|
||||
}
|
||||
|
||||
Tensor& updated_output = *ctx->Output(0, {1});
|
||||
bool* updated_output_ptr = updated_output.MutableData<bool>();
|
||||
*updated_output_ptr = true;
|
||||
|
||||
Tensor* accumulation_output = ctx->Output(1, left_addee_buffer.Shape());
|
||||
if (nullptr != accumulation_output) {
|
||||
ORT_RETURN_IF_ERROR(CopyIfNotSameBuffer<T>(Stream(), left_addee_buffer, *accumulation_output));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -22,5 +22,12 @@ class InPlaceAccumulator final : public CudaKernel {
|
|||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
template <typename T, typename T_GRAD>
|
||||
class InPlaceAccumulatorV2 final : public CudaKernel {
|
||||
public:
|
||||
InPlaceAccumulatorV2(const OpKernelInfo& info) : CudaKernel(info) {}
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue