From ced5b66306cd6884124042cc0d9db714e7444aac Mon Sep 17 00:00:00 2001
From: Jesse Benson
Date: Fri, 13 Nov 2020 12:28:44 -0800
Subject: [PATCH] Re-enable multi-tensor-apply for LAMB optimizer

---
 .../orttraining/training_ops/rocm/optimizer/lamb.cc | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/orttraining/orttraining/training_ops/rocm/optimizer/lamb.cc b/orttraining/orttraining/training_ops/rocm/optimizer/lamb.cc
index 0656e9671a..865769cdd7 100644
--- a/orttraining/orttraining/training_ops/rocm/optimizer/lamb.cc
+++ b/orttraining/orttraining/training_ops/rocm/optimizer/lamb.cc
@@ -196,8 +196,7 @@ Status launch_lamb_compute_direction(
   ORT_ENFORCE(group_count == static_cast(epsilons.size()));
   constexpr int tensor_count_per_group = 6;
 
-  // const int max_tensor_size = compute_max_tensor_size_per_launch(4);
-  const int max_tensor_size = 0;
+  const int max_tensor_size = compute_max_tensor_size_per_launch(4);
   // Bucketize tensor groups by the associated optimizer configuration.
   // If two tensor groups use different "alpha", they should be put into two distinct buckets.
   std::map, std::vector>> buckets;
@@ -290,8 +289,7 @@ Status launch_lamb_reduction(
   // If two tensor groups use different "alpha", they should be put into two distinct buckets.
   std::vector> buckets;
   std::vector tensor_sizes_in_buckets;
-  // const int max_tensor_size = compute_max_tensor_size_per_launch(4);
-  const int max_tensor_size = 0;
+  const int max_tensor_size = compute_max_tensor_size_per_launch(4);
   for (int i = 0; i < group_count; ++i) {
     if (tensor_sizes[i] > max_tensor_size) {
       reduce_square_sum(
@@ -368,8 +366,7 @@ Status launch_lamb_update(
   // If two tensor groups use different "alpha", they should be put into two distinct buckets.
   std::vector> buckets;
   std::vector tensor_sizes_in_bucket;
-  // const int max_tensor_size = compute_max_tensor_size_per_launch(4);
-  const int max_tensor_size = 0;
+  const int max_tensor_size = compute_max_tensor_size_per_launch(4);
   for (int i = 0; i < group_count; ++i) {
     if (tensor_sizes[i] > max_tensor_size) {
       LambUpdate(