mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Re-enable multi-tensor-apply for LAMB optimizer
This commit is contained in:
parent
fc614ad050
commit
ced5b66306
1 changed files with 3 additions and 6 deletions
|
|
@ -196,8 +196,7 @@ Status launch_lamb_compute_direction(
|
|||
ORT_ENFORCE(group_count == static_cast<int>(epsilons.size()));
|
||||
|
||||
constexpr int tensor_count_per_group = 6;
|
||||
// const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
|
||||
const int max_tensor_size = 0;
|
||||
const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
|
||||
// Bucketize tensor groups by the associated optimizer configuration.
|
||||
// If two tensor groups use different "alpha", they should be put into two distinct buckets.
|
||||
std::map<std::tuple<float, float, float, float>, std::vector<std::vector<void*>>> buckets;
|
||||
|
|
@ -290,8 +289,7 @@ Status launch_lamb_reduction(
|
|||
// If two tensor groups use different "alpha", they should be put into two distinct buckets.
|
||||
std::vector<std::vector<void*>> buckets;
|
||||
std::vector<int> tensor_sizes_in_buckets;
|
||||
// const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
|
||||
const int max_tensor_size = 0;
|
||||
const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
|
||||
for (int i = 0; i < group_count; ++i) {
|
||||
if (tensor_sizes[i] > max_tensor_size) {
|
||||
reduce_square_sum(
|
||||
|
|
@ -368,8 +366,7 @@ Status launch_lamb_update(
|
|||
// If two tensor groups use different "alpha", they should be put into two distinct buckets.
|
||||
std::vector<std::vector<void*>> buckets;
|
||||
std::vector<int> tensor_sizes_in_bucket;
|
||||
// const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
|
||||
const int max_tensor_size = 0;
|
||||
const int max_tensor_size = compute_max_tensor_size_per_launch<tensor_count_per_group>(4);
|
||||
for (int i = 0; i < group_count; ++i) {
|
||||
if (tensor_sizes[i] > max_tensor_size) {
|
||||
LambUpdate(
|
||||
|
|
|
|||
Loading…
Reference in a new issue