diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 98719280b2..88040c194a 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -632,6 +632,9 @@ if (onnxruntime_USE_TENSORRT)
# Needed for the provider interface, as it includes training headers when training is enabled
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ORTTRAINING_ROOT})
+ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
+ onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt Python::Module)
+ endif()
endif()
if(APPLE)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 8bb10aa10f..f45b47d160 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -597,8 +597,8 @@ Do not modify directly.*
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Neg|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
-|NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)|
-|||[9, 12]|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)|
+|NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
+|||[9, 12]|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
|Not|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bool)|
|OneHot|*in* indices:**T1**
*in* depth:**T2**
*in* values:**T3**
*out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)
**T3** = tensor(float), tensor(float16), tensor(int64)|
|Or|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)|
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index f2743c9c94..386163ef81 100755
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -791,6 +791,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, float, NonZero);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, NonZero);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, TopK);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, 8, Scan);
@@ -1085,6 +1086,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, NonZero);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Cast);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Cast);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Cast);
@@ -1661,6 +1663,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -1955,6 +1958,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu
index 90be2b8b27..2e1ea4e3af 100644
--- a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.cu
@@ -69,6 +69,40 @@ __global__ void NonZeroOutputPositionsKernel(
}
}
+
+constexpr int MAX_DIMS = 16;
+
+template  // Variant of NonZeroOutputPositionsKernel whose per-axis loop has the fixed bound MAX_DIMS.
+__global__ void UnRolledNonZeroOutputPositionsKernel(  // Writes the per-axis coordinates of each nonzero element of x into results.
+ const InputT* x, int64_t x_size, int x_rank, const TArray x_strides,  // x_strides[axis] must provide divmod(); presumably per-axis strides - confirm against TArray's element type.
+ const int* prefix_counts, int nonzero_elements, int64_t* results) {  // prefix_counts[b - 1] is used as the first output slot of block b.
+ typedef cub::BlockScan BlockScanT;
+ __shared__ typename BlockScanT::TempStorage temp_storage;  // Shared-memory scratch space required by the block-wide scan.
+
+ int64_t index = blockIdx.x * blockDim.x + threadIdx.x;  // Linear index of the element handled by this thread.
+ // cub::CastOp (cast_to_bool) is not used because AMD hipcub does not support it; bool(...) is applied directly instead.
+ int nz = 0;  // Becomes 1 when this thread's element is in range and nonzero.
+ if (index < x_size && bool(x[index])) ++nz;
+ int pos_in_block = 0;
+ BlockScanT(temp_storage).InclusiveSum(nz, pos_in_block);  // Inclusive count of nonzero elements up to and including this thread.
+
+ int result_position = ((blockIdx.x == 0) ? 0 : prefix_counts[blockIdx.x - 1]) + pos_in_block - nz;  // Block base offset plus the exclusive in-block position (inclusive sum minus own flag).
+
+ if (index < x_size && bool(x[index])) {  // Only threads that own a nonzero element write output.
+ int remain = (int)index, dim = 0;  // NOTE(review): index is narrowed from int64_t to int here - confirm x_size always fits in int.
+ int rp = result_position;
+ #pragma unroll
+ for (int axis = 0; axis < MAX_DIMS; ++axis) {  // Constant trip count so the unroll pragma can apply; the runtime rank is handled by the break below.
+ if (axis == x_rank) {
+ break;
+ }
+ x_strides[axis].divmod(remain, dim, remain);  // Presumably dim = quotient and remain = remainder for this axis - confirm divmod's argument order.
+ results[rp] = (int64_t)dim;  // Coordinate along `axis` for this nonzero element.
+ rp += nonzero_elements;  // Consecutive axes are nonzero_elements apart, i.e. results is indexed as [axis][nonzero position].
+ }
+ }
+}
+
template
cudaError_t NonZeroCountEachBlock(cudaStream_t stream, const InputT* x, int64_t x_size, int* count_in_blocks) {
int num_blocks = NonZeroCalcBlockCount(x_size);
@@ -82,9 +116,15 @@ cudaError_t NonZeroOutputPositions(
cudaStream_t stream, const InputT* x, int64_t x_size, int x_rank, const TArray& x_strides,
const int* prefix_counts, int nonzero_elements, int64_t* results) {
int num_blocks = NonZeroCalcBlockCount(x_size);
- NonZeroOutputPositionsKernel<<>>(
+ if (x_rank > MAX_DIMS) {
+ NonZeroOutputPositionsKernel<<>>(
x, x_size, x_rank, x_strides,
prefix_counts, nonzero_elements, results);
+ } else {
+ UnRolledNonZeroOutputPositionsKernel<<>>(
+ x, x_size, x_rank, x_strides,
+ prefix_counts, nonzero_elements, results);
+ }
return cudaSuccess;
}
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index 5121c1e626..24e7b7187c 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -710,6 +710,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, NonZero);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, NonZero);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 9, TopK);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, If);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 8, Scan);
@@ -1004,6 +1005,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, NonZero);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Cast);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Cast);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Cast);
@@ -1232,7 +1234,7 @@ KernelCreateInfo BuildKernelCreateInfo() {
static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
- BuildKernelCreateInfo, //default entry to avoid the list become empty after ops-reducing
+ BuildKernelCreateInfo, // default entry to keep the list from becoming empty after ops-reducing
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -1579,6 +1581,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
// BuildKernelCreateInfo,
// BuildKernelCreateInfo,
@@ -1873,6 +1876,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/orttraining/orttraining/python/training/optim/_modifier.py b/orttraining/orttraining/python/training/optim/_modifier.py
index c364156840..26c5c72a7c 100644
--- a/orttraining/orttraining/python/training/optim/_modifier.py
+++ b/orttraining/orttraining/python/training/optim/_modifier.py
@@ -90,6 +90,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type,
norm_type = float(norm_type)
total_norm = 0.0
+ dummy_overflow_buf = torch.cuda.IntTensor([0])
# Calculate norm.
if norm_type == inf:
total_norm = max(grad.abs().max() for grad in grads_for_norm)
@@ -104,7 +105,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type,
else:
if norm_type == 2.0:
- dummy_overflow_buf = torch.cuda.IntTensor([0])
# Use apex's multi-tensor applier for efficiency reasons.
# Multi-tensor applier takes a function and a list of list
# and performs the operation on that list all in one kernel.
@@ -133,6 +133,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type,
op=torch.distributed.ReduceOp.SUM,
group=get_horizontal_model_parallel_group())
total_norm = total_norm.item() ** (1.0 / norm_type)
+ clip_coef = max_norm / (total_norm + 1e-6)
+ # Filter parameters with gradients.
+ grads = [p.grad for p in parameters if p.grad is not None]
+ if clip_coef < 1.0:
+ multi_tensor_applier(
+ amp_C.multi_tensor_scale,
+ dummy_overflow_buf,
+ [grads, grads],
+ clip_coef)
return total_norm
diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/fused_ops_frontend.cpp b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/fused_ops_frontend.cpp
index 3bcdd37a2d..6111214676 100644
--- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/fused_ops_frontend.cpp
+++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/fused_ops_frontend.cpp
@@ -26,15 +26,14 @@ void multi_tensor_adam_cuda(int chunk_size,
const int bias_correction,
const float weight_decay);
-// This function is adapted from NVIDIA/apex
+// This function is adapted from NVIDIA/apex
// https://github.com/NVIDIA/apex/blob/0c7d8e3fa9a095a1641a2290877436d0314b69c6/csrc/amp_C_frontend.cpp#L3
void multi_tensor_scale_cuda(int chunk_size,
at::Tensor noop_flag,
std::vector>& tensor_lists,
float scale);
-
-// This function is adapted from NVIDIA/apex
+// This function is adapted from NVIDIA/apex
// https://github.com/NVIDIA/apex/blob/0c7d8e3fa9a095a1641a2290877436d0314b69c6/csrc/amp_C_frontend.cpp#L22
void multi_tensor_axpby_cuda(int chunk_size,
at::Tensor noop_flag,
@@ -43,43 +42,42 @@ void multi_tensor_axpby_cuda(int chunk_size,
float b,
int arg_to_check);
-
class MemoryBuffer {
- public:
- MemoryBuffer(size_t numel, at::Tensor val){
- data_buffer_ = at::empty({numel}, val.options());
- }
+ public:
+ MemoryBuffer(size_t numel, at::Tensor val) {
+ data_buffer_ = at::empty({static_cast(numel)}, val.options());
+ }
- at::Tensor Get(at::Tensor param, size_t start_index) {
- size_t end_index = start_index + param.numel();
- return data_buffer_.slice(0, start_index, end_index).view(param.sizes());
- }
+ at::Tensor Get(at::Tensor param, size_t start_index) {
+ size_t end_index = start_index + param.numel();
+ return data_buffer_.slice(0, start_index, end_index).view(param.sizes());
+ }
- private:
- at::Tensor data_buffer_;
+ private:
+ at::Tensor data_buffer_;
};
class CachedStates {
- public:
- static CachedStates& GetInstance(){
- static CachedStates states_;
- return states_;
- }
+ public:
+ static CachedStates& GetInstance() {
+ static CachedStates states_;
+ return states_;
+ }
- void ClearStates(){
- idx_to_numel_map.clear();
- }
+ void ClearStates() {
+ idx_to_numel_map.clear();
+ }
- // Parameter index to number of element mapping for each parameter.
- std::vector> idx_to_numel_map;
+ // Parameter index to number of element mapping for each parameter.
+ std::vector> idx_to_numel_map;
- private:
- CachedStates(){}
+ private:
+ CachedStates() {}
};
-bool SortByElementSizeDesc(const std::pair &a,
- const std::pair &b) {
- return (a.second > b.second);
+bool SortByElementSizeDesc(const std::pair& a,
+ const std::pair& b) {
+ return (a.second > b.second);
};
// This function is trying to move into C++ implementation from Python logic
@@ -89,134 +87,135 @@ void unscale_fp16_grads_into_fp32_grads(std::vector& all_fp16_params
std::vector& all_fp32_from_fp16_params,
at::Tensor is_overflow_buffer,
float scale) {
- if (all_fp16_params.size() == 0 || all_fp32_from_fp16_params.size() == 0) {
- return;
- }
+ if (all_fp16_params.size() == 0 || all_fp32_from_fp16_params.size() == 0) {
+ return;
+ }
- const float inv_scale = 1.0 / scale;
- TORCH_CHECK(all_fp16_params.size() == all_fp32_from_fp16_params.size(),
- "mismatch param size between fp16_param and fp32_from_fp16_param.");
+ const float inv_scale = 1.0 / scale;
+ TORCH_CHECK(all_fp16_params.size() == all_fp32_from_fp16_params.size(),
+ "mismatch param size between fp16_param and fp32_from_fp16_param.");
- // Use cached states only parameter count did not get changed.
- bool need_reset_states =
- all_fp32_from_fp16_params.size() != CachedStates::GetInstance().idx_to_numel_map.size();
+ // Reuse the cached states only if the parameter count has not changed.
+ bool need_reset_states =
+ all_fp32_from_fp16_params.size() != CachedStates::GetInstance().idx_to_numel_map.size();
+ if (need_reset_states) {
+ CachedStates::GetInstance().ClearStates();
+ }
+
+ std::vector fp16_grads_needing_unscale;
+ std::vector fp16_grads_needing_unscale_with_stash;
+ std::vector preexisting_fp32_grads;
+
+ // Parameter index to parameter mapping for each fp32_from_fp16 parameter.
+ std::unordered_map idx_to_fp32_from_fp16_params;
+
+ // "buffer index" to "offset in memory buffer" mapping for each fp32_from_fp16 parameter.
+ std::vector memory_buffer_idx_to_offset_map;
+ size_t memory_buffer_size = 0;
+ auto& idx_to_numel_map = CachedStates::GetInstance().idx_to_numel_map;
+
+ for (size_t idx = 0; idx < all_fp16_params.size(); ++idx) {
+ auto& fp16_param_grad = all_fp16_params[idx].grad();
+ bool fp16_param_has_grad = fp16_param_grad.defined();
+
+ auto& fp32_from_fp16_param = all_fp32_from_fp16_params[idx];
+ auto& fp32_from_fp16_param_grad = fp32_from_fp16_param.grad();
+ bool fp32_from_fp16_param_has_grad = fp32_from_fp16_param_grad.defined();
+
+ size_t num_elem = fp32_from_fp16_param.numel();
if (need_reset_states) {
- CachedStates::GetInstance().ClearStates();
+ idx_to_numel_map.push_back(std::make_pair(idx, num_elem));
}
- std::vector fp16_grads_needing_unscale;
- std::vector fp16_grads_needing_unscale_with_stash;
- std::vector preexisting_fp32_grads;
+ if (fp16_param_has_grad && !fp32_from_fp16_param_has_grad) {
+ idx_to_fp32_from_fp16_params.emplace(std::make_pair(idx, fp32_from_fp16_param));
+ fp16_grads_needing_unscale.emplace_back(fp16_param_grad);
+ memory_buffer_idx_to_offset_map.emplace_back(memory_buffer_size);
+ memory_buffer_size += num_elem;
+ } else if (fp16_param_has_grad && fp32_from_fp16_param_has_grad) {
+ fp16_grads_needing_unscale_with_stash.emplace_back(fp16_param_grad);
+ preexisting_fp32_grads.emplace_back(fp32_from_fp16_param_grad);
+ }
+ }
- // Parameter index to parameter mapping for each fp32_from_fp16 parameter.
- std::unordered_map idx_to_fp32_from_fp16_params;
+ if (need_reset_states) {
+ std::sort(idx_to_numel_map.begin(), idx_to_numel_map.end(), SortByElementSizeDesc);
+ }
- // "buffer index" to "offset in memory buffer" mapping for each fp32_from_fp16 parameter.
- std::vector memory_buffer_idx_to_offset_map;
- size_t memory_buffer_size = 0;
- auto& idx_to_numel_map = CachedStates::GetInstance().idx_to_numel_map;
+ if (idx_to_fp32_from_fp16_params.size() > 0) {
+ auto mem_buffer = MemoryBuffer(memory_buffer_size, idx_to_fp32_from_fp16_params.begin()->second);
+ const size_t emit_threshhold = memory_buffer_size / EMIT_NUM;
- for (size_t idx = 0; idx < all_fp16_params.size(); ++idx) {
- auto& fp16_param_grad = all_fp16_params[idx].grad();
- bool fp16_param_has_grad = fp16_param_grad.defined();
+ size_t acc_size = 0;
+ std::vector partial_new_fp32_grads;
+ std::vector partial_fp16_grads_needing_unscale;
+ for (size_t idx = 0, fp32_from_fp16_param_idx = 0; idx < idx_to_numel_map.size(); ++idx) {
+ if (idx_to_fp32_from_fp16_params.find(idx) == idx_to_fp32_from_fp16_params.end()) {
+ continue;
+ }
- auto& fp32_from_fp16_param = all_fp32_from_fp16_params[idx];
- auto& fp32_from_fp16_param_grad = fp32_from_fp16_param.grad();
- bool fp32_from_fp16_param_has_grad = fp32_from_fp16_param_grad.defined();
+ acc_size += idx_to_numel_map[idx].second;
+ idx_to_fp32_from_fp16_params[idx].mutable_grad() =
+ mem_buffer.Get(idx_to_fp32_from_fp16_params[idx],
+ memory_buffer_idx_to_offset_map[fp32_from_fp16_param_idx]);
+ partial_new_fp32_grads.emplace_back(idx_to_fp32_from_fp16_params[idx].grad());
+ partial_fp16_grads_needing_unscale.emplace_back(fp16_grads_needing_unscale[fp32_from_fp16_param_idx]);
- size_t num_elem = fp32_from_fp16_param.numel();
- if (need_reset_states) {
- idx_to_numel_map.push_back(std::make_pair(idx, num_elem));
- }
-
- if (fp16_param_has_grad && !fp32_from_fp16_param_has_grad) {
- idx_to_fp32_from_fp16_params.emplace(std::make_pair(idx, fp32_from_fp16_param));
- fp16_grads_needing_unscale.emplace_back(fp16_param_grad);
- memory_buffer_idx_to_offset_map.emplace_back(memory_buffer_size);
- memory_buffer_size += num_elem;
- } else if (fp16_param_has_grad && fp32_from_fp16_param_has_grad) {
- fp16_grads_needing_unscale_with_stash.emplace_back(fp16_param_grad);
- preexisting_fp32_grads.emplace_back(fp32_from_fp16_param_grad);
+ if (acc_size > emit_threshhold || fp32_from_fp16_param_idx == idx_to_fp32_from_fp16_params.size() - 1) {
+ if (partial_fp16_grads_needing_unscale.size() > 0) {
+ std::vector> tensor_lists;
+ tensor_lists.emplace_back(partial_fp16_grads_needing_unscale);
+ tensor_lists.emplace_back(partial_new_fp32_grads);
+ multi_tensor_scale_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale);
+
+ partial_fp16_grads_needing_unscale.clear();
+ partial_new_fp32_grads.clear();
+ acc_size = 0;
}
+ }
+ ++fp32_from_fp16_param_idx;
}
+ }
- if (need_reset_states) {
- std::sort(idx_to_numel_map.begin(), idx_to_numel_map.end(), SortByElementSizeDesc);
- }
-
- if (idx_to_fp32_from_fp16_params.size() > 0) {
- auto mem_buffer = MemoryBuffer(memory_buffer_size, idx_to_fp32_from_fp16_params.begin()->second);
- const size_t emit_threshhold = memory_buffer_size / EMIT_NUM;
-
- size_t acc_size = 0;
- std::vector partial_new_fp32_grads;
- std::vector partial_fp16_grads_needing_unscale;
- for (size_t idx = 0, fp32_from_fp16_param_idx = 0; idx < idx_to_numel_map.size(); ++idx) {
- if (idx_to_fp32_from_fp16_params.find(idx) == idx_to_fp32_from_fp16_params.end()) {
- continue;
- }
-
- acc_size += idx_to_numel_map[idx].second;
- idx_to_fp32_from_fp16_params[idx].mutable_grad() =
- mem_buffer.Get(idx_to_fp32_from_fp16_params[idx],
- memory_buffer_idx_to_offset_map[fp32_from_fp16_param_idx]);
- partial_new_fp32_grads.emplace_back(idx_to_fp32_from_fp16_params[idx].grad());
- partial_fp16_grads_needing_unscale.emplace_back(fp16_grads_needing_unscale[fp32_from_fp16_param_idx]);
-
- if (acc_size > emit_threshhold || fp32_from_fp16_param_idx == idx_to_fp32_from_fp16_params.size() - 1) {
- if (partial_fp16_grads_needing_unscale.size() > 0) {
- std::vector> tensor_lists;
- tensor_lists.emplace_back(partial_fp16_grads_needing_unscale);
- tensor_lists.emplace_back(partial_new_fp32_grads);
- multi_tensor_scale_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale);
-
- partial_fp16_grads_needing_unscale.clear();
- partial_new_fp32_grads.clear();
- acc_size = 0;
- }
- }
- ++fp32_from_fp16_param_idx;
- }
- }
-
- if (fp16_grads_needing_unscale_with_stash.size() > 0) {
- std::vector> tensor_lists;
- tensor_lists.emplace_back(fp16_grads_needing_unscale_with_stash);
- tensor_lists.emplace_back(preexisting_fp32_grads);
- tensor_lists.emplace_back(preexisting_fp32_grads);
- // a * x + b * y
- multi_tensor_axpby_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale, float(1.0), 0);
- }
+ if (fp16_grads_needing_unscale_with_stash.size() > 0) {
+ std::vector> tensor_lists;
+ tensor_lists.emplace_back(fp16_grads_needing_unscale_with_stash);
+ tensor_lists.emplace_back(preexisting_fp32_grads);
+ tensor_lists.emplace_back(preexisting_fp32_grads);
+ // a * x + b * y
+ multi_tensor_axpby_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale, float(1.0), 0);
+ }
};
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
-{
- // Cannot use the shortcut API below because https://github.com/pybind/pybind11/issues/1470
- // py::bind_vector>(m, "TorchTensorVector");
- py::class_>(m, "TorchTensorVector")
- .def(py::init<>())
- .def("clear", &std::vector::clear)
- .def("pop_back", &std::vector::pop_back)
- .def("__len__", [](const std::vector &v) { return v.size(); })
- .def("__iter__", [](std::vector &v) {
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ // Cannot use the shortcut API below because of https://github.com/pybind/pybind11/issues/1470
+ // py::bind_vector>(m, "TorchTensorVector");
+ py::class_>(m, "TorchTensorVector")
+ .def(py::init<>())
+ .def("clear", &std::vector::clear)
+ .def("pop_back", &std::vector::pop_back)
+ .def("__len__", [](const std::vector& v) { return v.size(); })
+ .def(
+ "__iter__", [](std::vector& v) {
return py::make_iterator(v.begin(), v.end());
- }, py::keep_alive<0, 1>())
- .def("extend", [](std::vector &v, const std::vector &src) {
- v.insert(v.end(), src.begin(), src.end());
- })
- .def(py::init([](const py::iterable &it) {
- auto v = std::unique_ptr>(new std::vector());
- v->reserve(py::len_hint(it));
- for (py::handle h : it) {
- v->push_back(h.cast());
- }
- return v.release();
- }));
+ },
+ py::keep_alive<0, 1>())
+ .def("extend", [](std::vector& v, const std::vector& src) {
+ v.insert(v.end(), src.begin(), src.end());
+ })
+ .def(py::init([](const py::iterable& it) {
+ auto v = std::unique_ptr>(new std::vector());
+ v->reserve(py::len_hint(it));
+ for (py::handle h : it) {
+ v->push_back(h.cast());
+ }
+ return v.release();
+ }));
- m.def("multi_tensor_adam",
- &multi_tensor_adam_cuda,
- "Compute and apply gradient update to parameters for Adam optimizer");
- m.def("unscale_fp16_grads_into_fp32_grads",
- &unscale_fp16_grads_into_fp32_grads,
- "Unscale those fp16 gradients into fp32 gradient buffers.");
+ m.def("multi_tensor_adam",
+ &multi_tensor_adam_cuda,
+ "Compute and apply gradient update to parameters for Adam optimizer");
+ m.def("unscale_fp16_grads_into_fp32_grads",
+ &unscale_fp16_grads_into_fp32_grads,
+ "Unscale those fp16 gradients into fp32 gradient buffers.");
}
diff --git a/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc b/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc
index 675d1ad7c6..d8d2cb8e83 100644
--- a/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc
+++ b/orttraining/orttraining/test/training_ops/cpu/math/isfinite_ops_test.cc
@@ -171,7 +171,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatTensorLarge) {
OpTester test("IsAllFinite", 1, kMSDomain);
bool expected_answer = false;
auto tensors = generate_is_all_finite_test_data(13, 941736, expected_answer, test_count);
- for (int i = 0; i < tensors.size(); ++i) {
+ for (size_t i = 0; i < tensors.size(); ++i) {
auto name = std::string("X") + std::to_string(i);
auto size = static_cast(tensors[i].size());
test.AddInput(name.c_str(), {size}, tensors[i]);
@@ -186,7 +186,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatManyBlock) {
OpTester test("IsAllFinite", 1, kMSDomain);
bool expected_answer = false;
auto tensors = generate_is_all_finite_test_data(894, 17, expected_answer, test_count);
- for (int i = 0; i < tensors.size(); ++i) {
+ for (size_t i = 0; i < tensors.size(); ++i) {
auto name = std::string("X") + std::to_string(i);
auto size = static_cast(tensors[i].size());
test.AddInput(name.c_str(), {size}, tensors[i]);
@@ -199,7 +199,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatManyBlock) {
TEST(IsAllFiniteTest, MoreFalseFloatMultipleFalse) {
OpTester test("IsAllFinite", 1, kMSDomain);
auto tensors = generate_is_all_finite_test_data(1234, 1987, 0.1f, 0);
- for (int i = 0; i < tensors.size(); ++i) {
+ for (size_t i = 0; i < tensors.size(); ++i) {
auto name = std::string("X") + std::to_string(i);
auto size = static_cast(tensors[i].size());
test.AddInput(name.c_str(), {size}, tensors[i]);
@@ -212,7 +212,7 @@ TEST(IsAllFiniteTest, MoreTrueFloatTensorLarge) {
OpTester test("IsAllFinite", 1, kMSDomain);
bool expected_answer = true;
auto tensors = generate_is_all_finite_test_data(12, 941736, expected_answer, 0);
- for (int i = 0; i < tensors.size(); ++i) {
+ for (size_t i = 0; i < tensors.size(); ++i) {
auto name = std::string("X") + std::to_string(i);
auto size = static_cast(tensors[i].size());
test.AddInput(name.c_str(), {size}, tensors[i]);
@@ -227,7 +227,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatManyBlockFloat16) {
OpTester test("IsAllFinite", 1, kMSDomain);
bool expected_answer = false;
auto tensors = generate_is_all_finite_test_data(894, 17, expected_answer, 0);
- for (int i = 0; i < tensors.size(); ++i) {
+ for (size_t i = 0; i < tensors.size(); ++i) {
auto name = std::string("X") + std::to_string(i);
auto size = static_cast(tensors[i].size());
std::vector buffer_half(tensors[i].size());
@@ -242,7 +242,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatTensorLargeFloat16) {
OpTester test("IsAllFinite", 1, kMSDomain);
bool expected_answer = false;
auto tensors = generate_is_all_finite_test_data(12, 941736, expected_answer, 0);
- for (int i = 0; i < tensors.size(); ++i) {
+ for (size_t i = 0; i < tensors.size(); ++i) {
auto name = std::string("X") + std::to_string(i);
auto size = static_cast(tensors[i].size());
std::vector buffer_half(tensors[i].size());