From 89ef987ab1cddd0efab0fd1f1e27738da7befe92 Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 25 Mar 2022 07:35:45 +0800 Subject: [PATCH] Improve NonZero on CUDA/ROCM (#10307) * improve NonZero * fix megatron_fp16 optimzier, fix the doc * multi_tensor_applier * resolve comment * fix building warning * fix build error when enabling training and use tensorrt --- cmake/onnxruntime_providers.cmake | 3 + docs/OperatorKernels.md | 4 +- .../providers/cuda/cuda_execution_provider.cc | 4 + .../providers/cuda/tensor/nonzero_impl.cu | 42 ++- .../providers/rocm/rocm_execution_provider.cc | 6 +- .../python/training/optim/_modifier.py | 11 +- .../cuda/fused_ops/fused_ops_frontend.cpp | 283 +++++++++--------- .../cpu/math/isfinite_ops_test.cc | 12 +- 8 files changed, 212 insertions(+), 153 deletions(-) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 98719280b2..88040c194a 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -632,6 +632,9 @@ if (onnxruntime_USE_TENSORRT) # Needed for the provider interface, as it includes training headers when training is enabled if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ORTTRAINING_ROOT}) + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt Python::Module) + endif() endif() if(APPLE) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8bb10aa10f..f45b47d160 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -597,8 +597,8 @@ Do not modify directly.* |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Neg|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| -|NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)| -|||[9, 12]|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)| +|NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| +|||[9, 12]|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| |Not|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bool)| |OneHot|*in* indices:**T1**
*in* depth:**T2**
*in* values:**T3**
*out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)
**T3** = tensor(float), tensor(float16), tensor(int64)| |Or|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index f2743c9c94..386163ef81 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -791,6 +791,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, float, NonZero); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, NonZero); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, 8, Scan); @@ -1085,6 +1086,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, NonZero); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Cast); @@ -1661,6 +1663,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1955,6 +1958,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu index 90be2b8b27..2e1ea4e3af 100644 --- a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu @@ -69,6 +69,40 @@ __global__ void NonZeroOutputPositionsKernel( } } + +constexpr int MAX_DIMS = 16; + +template +__global__ void UnRolledNonZeroOutputPositionsKernel( + const InputT* x, int64_t x_size, int x_rank, const TArray x_strides, + const int* prefix_counts, int nonzero_elements, int64_t* results) { + typedef cub::BlockScan BlockScanT; + __shared__ typename BlockScanT::TempStorage temp_storage; + + int64_t index = blockIdx.x * blockDim.x + threadIdx.x; + // const cub::CastOp cast_to_bool; not supported on amd hipcub + int nz = 0; + if (index < x_size && bool(x[index])) ++nz; + int pos_in_block = 0; + BlockScanT(temp_storage).InclusiveSum(nz, pos_in_block); + + int result_position = ((blockIdx.x == 0) ? 0 : prefix_counts[blockIdx.x - 1]) + pos_in_block - nz; + + if (index < x_size && bool(x[index])) { + int remain = (int)index, dim = 0; + int rp = result_position; + #pragma unroll + for (int axis = 0; axis < MAX_DIMS; ++axis) { + if (axis == x_rank) { + break; + } + x_strides[axis].divmod(remain, dim, remain); + results[rp] = (int64_t)dim; + rp += nonzero_elements; + } + } +} + template cudaError_t NonZeroCountEachBlock(cudaStream_t stream, const InputT* x, int64_t x_size, int* count_in_blocks) { int num_blocks = NonZeroCalcBlockCount(x_size); @@ -82,9 +116,15 @@ cudaError_t NonZeroOutputPositions( cudaStream_t stream, const InputT* x, int64_t x_size, int x_rank, const TArray& x_strides, const int* prefix_counts, int nonzero_elements, int64_t* results) { int num_blocks = NonZeroCalcBlockCount(x_size); - NonZeroOutputPositionsKernel<<>>( + if (x_rank > MAX_DIMS) { + NonZeroOutputPositionsKernel<<>>( x, x_size, x_rank, x_strides, prefix_counts, nonzero_elements, results); + } else { + UnRolledNonZeroOutputPositionsKernel<<>>( + x, x_size, x_rank, x_strides, + prefix_counts, nonzero_elements, results); + } return cudaSuccess; } diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 5121c1e626..24e7b7187c 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -710,6 +710,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, NonZero); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, NonZero); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 9, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 8, Scan); @@ -1004,6 +1005,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, NonZero); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Cast); @@ -1232,7 +1234,7 @@ KernelCreateInfo BuildKernelCreateInfo() { static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, //default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1579,6 +1581,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -1873,6 +1876,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/python/training/optim/_modifier.py b/orttraining/orttraining/python/training/optim/_modifier.py index c364156840..26c5c72a7c 100644 --- a/orttraining/orttraining/python/training/optim/_modifier.py +++ b/orttraining/orttraining/python/training/optim/_modifier.py @@ -90,6 +90,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type, norm_type = float(norm_type) total_norm = 0.0 + dummy_overflow_buf = torch.cuda.IntTensor([0]) # Calculate norm. if norm_type == inf: total_norm = max(grad.abs().max() for grad in grads_for_norm) @@ -104,7 +105,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type, else: if norm_type == 2.0: - dummy_overflow_buf = torch.cuda.IntTensor([0]) # Use apex's multi-tensor applier for efficiency reasons. # Multi-tensor applier takes a function and a list of list # and performs the operation on that list all in one kernel. @@ -133,6 +133,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type, op=torch.distributed.ReduceOp.SUM, group=get_horizontal_model_parallel_group()) total_norm = total_norm.item() ** (1.0 / norm_type) + clip_coef = max_norm / (total_norm + 1e-6) + # Filter parameters with gradients. + grads = [p.grad for p in parameters if p.grad is not None] + if clip_coef < 1.0: + multi_tensor_applier( + amp_C.multi_tensor_scale, + dummy_overflow_buf, + [grads, grads], + clip_coef) return total_norm diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/fused_ops_frontend.cpp b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/fused_ops_frontend.cpp index 3bcdd37a2d..6111214676 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/fused_ops_frontend.cpp +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/fused_ops_frontend.cpp @@ -26,15 +26,14 @@ void multi_tensor_adam_cuda(int chunk_size, const int bias_correction, const float weight_decay); -// This function is adapted from NVIDIA/apex +// This function is adapted from NVIDIA/apex // https://github.com/NVIDIA/apex/blob/0c7d8e3fa9a095a1641a2290877436d0314b69c6/csrc/amp_C_frontend.cpp#L3 void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, std::vector>& tensor_lists, float scale); - -// This function is adapted from NVIDIA/apex +// This function is adapted from NVIDIA/apex // https://github.com/NVIDIA/apex/blob/0c7d8e3fa9a095a1641a2290877436d0314b69c6/csrc/amp_C_frontend.cpp#L22 void multi_tensor_axpby_cuda(int chunk_size, at::Tensor noop_flag, @@ -43,43 +42,42 @@ void multi_tensor_axpby_cuda(int chunk_size, float b, int arg_to_check); - class MemoryBuffer { - public: - MemoryBuffer(size_t numel, at::Tensor val){ - data_buffer_ = at::empty({numel}, val.options()); - } + public: + MemoryBuffer(size_t numel, at::Tensor val) { + data_buffer_ = at::empty({static_cast(numel)}, val.options()); + } - at::Tensor Get(at::Tensor param, size_t start_index) { - size_t end_index = start_index + param.numel(); - return data_buffer_.slice(0, start_index, end_index).view(param.sizes()); - } + at::Tensor Get(at::Tensor param, size_t start_index) { + size_t end_index = start_index + param.numel(); + return data_buffer_.slice(0, start_index, end_index).view(param.sizes()); + } - private: - at::Tensor data_buffer_; + private: + at::Tensor data_buffer_; }; class CachedStates { - public: - static CachedStates& GetInstance(){ - static CachedStates states_; - return states_; - } + public: + static CachedStates& GetInstance() { + static CachedStates states_; + return states_; + } - void ClearStates(){ - idx_to_numel_map.clear(); - } + void ClearStates() { + idx_to_numel_map.clear(); + } - // Parameter index to number of element mapping for each parameter. - std::vector> idx_to_numel_map; + // Parameter index to number of element mapping for each parameter. + std::vector> idx_to_numel_map; - private: - CachedStates(){} + private: + CachedStates() {} }; -bool SortByElementSizeDesc(const std::pair &a, - const std::pair &b) { - return (a.second > b.second); +bool SortByElementSizeDesc(const std::pair& a, + const std::pair& b) { + return (a.second > b.second); }; // This function is trying to move into C++ implementation from Python logic @@ -89,134 +87,135 @@ void unscale_fp16_grads_into_fp32_grads(std::vector& all_fp16_params std::vector& all_fp32_from_fp16_params, at::Tensor is_overflow_buffer, float scale) { - if (all_fp16_params.size() == 0 || all_fp32_from_fp16_params.size() == 0) { - return; - } + if (all_fp16_params.size() == 0 || all_fp32_from_fp16_params.size() == 0) { + return; + } - const float inv_scale = 1.0 / scale; - TORCH_CHECK(all_fp16_params.size() == all_fp32_from_fp16_params.size(), - "mismatch param size between fp16_param and fp32_from_fp16_param."); + const float inv_scale = 1.0 / scale; + TORCH_CHECK(all_fp16_params.size() == all_fp32_from_fp16_params.size(), + "mismatch param size between fp16_param and fp32_from_fp16_param."); - // Use cached states only parameter count did not get changed. - bool need_reset_states = - all_fp32_from_fp16_params.size() != CachedStates::GetInstance().idx_to_numel_map.size(); + // Use cached states only parameter count did not get changed. + bool need_reset_states = + all_fp32_from_fp16_params.size() != CachedStates::GetInstance().idx_to_numel_map.size(); + if (need_reset_states) { + CachedStates::GetInstance().ClearStates(); + } + + std::vector fp16_grads_needing_unscale; + std::vector fp16_grads_needing_unscale_with_stash; + std::vector preexisting_fp32_grads; + + // Parameter index to parameter mapping for each fp32_from_fp16 parameter. + std::unordered_map idx_to_fp32_from_fp16_params; + + // "buffer index" to "offset in memory buffer" mapping for each fp32_from_fp16 parameter. + std::vector memory_buffer_idx_to_offset_map; + size_t memory_buffer_size = 0; + auto& idx_to_numel_map = CachedStates::GetInstance().idx_to_numel_map; + + for (size_t idx = 0; idx < all_fp16_params.size(); ++idx) { + auto& fp16_param_grad = all_fp16_params[idx].grad(); + bool fp16_param_has_grad = fp16_param_grad.defined(); + + auto& fp32_from_fp16_param = all_fp32_from_fp16_params[idx]; + auto& fp32_from_fp16_param_grad = fp32_from_fp16_param.grad(); + bool fp32_from_fp16_param_has_grad = fp32_from_fp16_param_grad.defined(); + + size_t num_elem = fp32_from_fp16_param.numel(); if (need_reset_states) { - CachedStates::GetInstance().ClearStates(); + idx_to_numel_map.push_back(std::make_pair(idx, num_elem)); } - std::vector fp16_grads_needing_unscale; - std::vector fp16_grads_needing_unscale_with_stash; - std::vector preexisting_fp32_grads; + if (fp16_param_has_grad && !fp32_from_fp16_param_has_grad) { + idx_to_fp32_from_fp16_params.emplace(std::make_pair(idx, fp32_from_fp16_param)); + fp16_grads_needing_unscale.emplace_back(fp16_param_grad); + memory_buffer_idx_to_offset_map.emplace_back(memory_buffer_size); + memory_buffer_size += num_elem; + } else if (fp16_param_has_grad && fp32_from_fp16_param_has_grad) { + fp16_grads_needing_unscale_with_stash.emplace_back(fp16_param_grad); + preexisting_fp32_grads.emplace_back(fp32_from_fp16_param_grad); + } + } - // Parameter index to parameter mapping for each fp32_from_fp16 parameter. - std::unordered_map idx_to_fp32_from_fp16_params; + if (need_reset_states) { + std::sort(idx_to_numel_map.begin(), idx_to_numel_map.end(), SortByElementSizeDesc); + } - // "buffer index" to "offset in memory buffer" mapping for each fp32_from_fp16 parameter. - std::vector memory_buffer_idx_to_offset_map; - size_t memory_buffer_size = 0; - auto& idx_to_numel_map = CachedStates::GetInstance().idx_to_numel_map; + if (idx_to_fp32_from_fp16_params.size() > 0) { + auto mem_buffer = MemoryBuffer(memory_buffer_size, idx_to_fp32_from_fp16_params.begin()->second); + const size_t emit_threshhold = memory_buffer_size / EMIT_NUM; - for (size_t idx = 0; idx < all_fp16_params.size(); ++idx) { - auto& fp16_param_grad = all_fp16_params[idx].grad(); - bool fp16_param_has_grad = fp16_param_grad.defined(); + size_t acc_size = 0; + std::vector partial_new_fp32_grads; + std::vector partial_fp16_grads_needing_unscale; + for (size_t idx = 0, fp32_from_fp16_param_idx = 0; idx < idx_to_numel_map.size(); ++idx) { + if (idx_to_fp32_from_fp16_params.find(idx) == idx_to_fp32_from_fp16_params.end()) { + continue; + } - auto& fp32_from_fp16_param = all_fp32_from_fp16_params[idx]; - auto& fp32_from_fp16_param_grad = fp32_from_fp16_param.grad(); - bool fp32_from_fp16_param_has_grad = fp32_from_fp16_param_grad.defined(); + acc_size += idx_to_numel_map[idx].second; + idx_to_fp32_from_fp16_params[idx].mutable_grad() = + mem_buffer.Get(idx_to_fp32_from_fp16_params[idx], + memory_buffer_idx_to_offset_map[fp32_from_fp16_param_idx]); + partial_new_fp32_grads.emplace_back(idx_to_fp32_from_fp16_params[idx].grad()); + partial_fp16_grads_needing_unscale.emplace_back(fp16_grads_needing_unscale[fp32_from_fp16_param_idx]); - size_t num_elem = fp32_from_fp16_param.numel(); - if (need_reset_states) { - idx_to_numel_map.push_back(std::make_pair(idx, num_elem)); - } - - if (fp16_param_has_grad && !fp32_from_fp16_param_has_grad) { - idx_to_fp32_from_fp16_params.emplace(std::make_pair(idx, fp32_from_fp16_param)); - fp16_grads_needing_unscale.emplace_back(fp16_param_grad); - memory_buffer_idx_to_offset_map.emplace_back(memory_buffer_size); - memory_buffer_size += num_elem; - } else if (fp16_param_has_grad && fp32_from_fp16_param_has_grad) { - fp16_grads_needing_unscale_with_stash.emplace_back(fp16_param_grad); - preexisting_fp32_grads.emplace_back(fp32_from_fp16_param_grad); + if (acc_size > emit_threshhold || fp32_from_fp16_param_idx == idx_to_fp32_from_fp16_params.size() - 1) { + if (partial_fp16_grads_needing_unscale.size() > 0) { + std::vector> tensor_lists; + tensor_lists.emplace_back(partial_fp16_grads_needing_unscale); + tensor_lists.emplace_back(partial_new_fp32_grads); + multi_tensor_scale_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale); + + partial_fp16_grads_needing_unscale.clear(); + partial_new_fp32_grads.clear(); + acc_size = 0; } + } + ++fp32_from_fp16_param_idx; } + } - if (need_reset_states) { - std::sort(idx_to_numel_map.begin(), idx_to_numel_map.end(), SortByElementSizeDesc); - } - - if (idx_to_fp32_from_fp16_params.size() > 0) { - auto mem_buffer = MemoryBuffer(memory_buffer_size, idx_to_fp32_from_fp16_params.begin()->second); - const size_t emit_threshhold = memory_buffer_size / EMIT_NUM; - - size_t acc_size = 0; - std::vector partial_new_fp32_grads; - std::vector partial_fp16_grads_needing_unscale; - for (size_t idx = 0, fp32_from_fp16_param_idx = 0; idx < idx_to_numel_map.size(); ++idx) { - if (idx_to_fp32_from_fp16_params.find(idx) == idx_to_fp32_from_fp16_params.end()) { - continue; - } - - acc_size += idx_to_numel_map[idx].second; - idx_to_fp32_from_fp16_params[idx].mutable_grad() = - mem_buffer.Get(idx_to_fp32_from_fp16_params[idx], - memory_buffer_idx_to_offset_map[fp32_from_fp16_param_idx]); - partial_new_fp32_grads.emplace_back(idx_to_fp32_from_fp16_params[idx].grad()); - partial_fp16_grads_needing_unscale.emplace_back(fp16_grads_needing_unscale[fp32_from_fp16_param_idx]); - - if (acc_size > emit_threshhold || fp32_from_fp16_param_idx == idx_to_fp32_from_fp16_params.size() - 1) { - if (partial_fp16_grads_needing_unscale.size() > 0) { - std::vector> tensor_lists; - tensor_lists.emplace_back(partial_fp16_grads_needing_unscale); - tensor_lists.emplace_back(partial_new_fp32_grads); - multi_tensor_scale_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale); - - partial_fp16_grads_needing_unscale.clear(); - partial_new_fp32_grads.clear(); - acc_size = 0; - } - } - ++fp32_from_fp16_param_idx; - } - } - - if (fp16_grads_needing_unscale_with_stash.size() > 0) { - std::vector> tensor_lists; - tensor_lists.emplace_back(fp16_grads_needing_unscale_with_stash); - tensor_lists.emplace_back(preexisting_fp32_grads); - tensor_lists.emplace_back(preexisting_fp32_grads); - // a * x + b * y - multi_tensor_axpby_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale, float(1.0), 0); - } + if (fp16_grads_needing_unscale_with_stash.size() > 0) { + std::vector> tensor_lists; + tensor_lists.emplace_back(fp16_grads_needing_unscale_with_stash); + tensor_lists.emplace_back(preexisting_fp32_grads); + tensor_lists.emplace_back(preexisting_fp32_grads); + // a * x + b * y + multi_tensor_axpby_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale, float(1.0), 0); + } }; -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - // Cannot use the shortcut API below because https://github.com/pybind/pybind11/issues/1470 - // py::bind_vector>(m, "TorchTensorVector"); - py::class_>(m, "TorchTensorVector") - .def(py::init<>()) - .def("clear", &std::vector::clear) - .def("pop_back", &std::vector::pop_back) - .def("__len__", [](const std::vector &v) { return v.size(); }) - .def("__iter__", [](std::vector &v) { +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // Cannot use the shortcut API below because https://github.com/pybind/pybind11/issues/1470 + // py::bind_vector>(m, "TorchTensorVector"); + py::class_>(m, "TorchTensorVector") + .def(py::init<>()) + .def("clear", &std::vector::clear) + .def("pop_back", &std::vector::pop_back) + .def("__len__", [](const std::vector& v) { return v.size(); }) + .def( + "__iter__", [](std::vector& v) { return py::make_iterator(v.begin(), v.end()); - }, py::keep_alive<0, 1>()) - .def("extend", [](std::vector &v, const std::vector &src) { - v.insert(v.end(), src.begin(), src.end()); - }) - .def(py::init([](const py::iterable &it) { - auto v = std::unique_ptr>(new std::vector()); - v->reserve(py::len_hint(it)); - for (py::handle h : it) { - v->push_back(h.cast()); - } - return v.release(); - })); + }, + py::keep_alive<0, 1>()) + .def("extend", [](std::vector& v, const std::vector& src) { + v.insert(v.end(), src.begin(), src.end()); + }) + .def(py::init([](const py::iterable& it) { + auto v = std::unique_ptr>(new std::vector()); + v->reserve(py::len_hint(it)); + for (py::handle h : it) { + v->push_back(h.cast()); + } + return v.release(); + })); - m.def("multi_tensor_adam", - &multi_tensor_adam_cuda, - "Compute and apply gradient update to parameters for Adam optimizer"); - m.def("unscale_fp16_grads_into_fp32_grads", - &unscale_fp16_grads_into_fp32_grads, - "Unscale those fp16 gradients into fp32 gradient buffers."); + m.def("multi_tensor_adam", + &multi_tensor_adam_cuda, + "Compute and apply gradient update to parameters for Adam optimizer"); + m.def("unscale_fp16_grads_into_fp32_grads", + &unscale_fp16_grads_into_fp32_grads, + "Unscale those fp16 gradients into fp32 gradient buffers."); } diff --git a/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc b/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc index 675d1ad7c6..d8d2cb8e83 100644 --- a/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc @@ -171,7 +171,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatTensorLarge) { OpTester test("IsAllFinite", 1, kMSDomain); bool expected_answer = false; auto tensors = generate_is_all_finite_test_data(13, 941736, expected_answer, test_count); - for (int i = 0; i < tensors.size(); ++i) { + for (size_t i = 0; i < tensors.size(); ++i) { auto name = std::string("X") + std::to_string(i); auto size = static_cast(tensors[i].size()); test.AddInput(name.c_str(), {size}, tensors[i]); @@ -186,7 +186,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatManyBlock) { OpTester test("IsAllFinite", 1, kMSDomain); bool expected_answer = false; auto tensors = generate_is_all_finite_test_data(894, 17, expected_answer, test_count); - for (int i = 0; i < tensors.size(); ++i) { + for (size_t i = 0; i < tensors.size(); ++i) { auto name = std::string("X") + std::to_string(i); auto size = static_cast(tensors[i].size()); test.AddInput(name.c_str(), {size}, tensors[i]); @@ -199,7 +199,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatManyBlock) { TEST(IsAllFiniteTest, MoreFalseFloatMultipleFalse) { OpTester test("IsAllFinite", 1, kMSDomain); auto tensors = generate_is_all_finite_test_data(1234, 1987, 0.1f, 0); - for (int i = 0; i < tensors.size(); ++i) { + for (size_t i = 0; i < tensors.size(); ++i) { auto name = std::string("X") + std::to_string(i); auto size = static_cast(tensors[i].size()); test.AddInput(name.c_str(), {size}, tensors[i]); @@ -212,7 +212,7 @@ TEST(IsAllFiniteTest, MoreTrueFloatTensorLarge) { OpTester test("IsAllFinite", 1, kMSDomain); bool expected_answer = true; auto tensors = generate_is_all_finite_test_data(12, 941736, expected_answer, 0); - for (int i = 0; i < tensors.size(); ++i) { + for (size_t i = 0; i < tensors.size(); ++i) { auto name = std::string("X") + std::to_string(i); auto size = static_cast(tensors[i].size()); test.AddInput(name.c_str(), {size}, tensors[i]); @@ -227,7 +227,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatManyBlockFloat16) { OpTester test("IsAllFinite", 1, kMSDomain); bool expected_answer = false; auto tensors = generate_is_all_finite_test_data(894, 17, expected_answer, 0); - for (int i = 0; i < tensors.size(); ++i) { + for (size_t i = 0; i < tensors.size(); ++i) { auto name = std::string("X") + std::to_string(i); auto size = static_cast(tensors[i].size()); std::vector buffer_half(tensors[i].size()); @@ -242,7 +242,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatTensorLargeFloat16) { OpTester test("IsAllFinite", 1, kMSDomain); bool expected_answer = false; auto tensors = generate_is_all_finite_test_data(12, 941736, expected_answer, 0); - for (int i = 0; i < tensors.size(); ++i) { + for (size_t i = 0; i < tensors.size(); ++i) { auto name = std::string("X") + std::to_string(i); auto size = static_cast(tensors[i].size()); std::vector buffer_half(tensors[i].size());