mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Improve NonZero on CUDA/ROCM (#10307)
* improve NonZero * fix megatron_fp16 optimizer, fix the doc * multi_tensor_applier * resolve comment * fix building warning * fix build error when enabling training and using tensorrt
This commit is contained in:
parent
1e917c879e
commit
89ef987ab1
8 changed files with 212 additions and 153 deletions
|
|
@ -632,6 +632,9 @@ if (onnxruntime_USE_TENSORRT)
|
|||
# Needed for the provider interface, as it includes training headers when training is enabled
|
||||
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
|
||||
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ORTTRAINING_ROOT})
|
||||
if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
|
||||
onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt Python::Module)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(APPLE)
|
||||
|
|
|
|||
|
|
@ -597,8 +597,8 @@ Do not modify directly.*
|
|||
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|
||||
|Neg|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
|
||||
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)|
|
||||
|NonZero|*in* X:**T**<br> *out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)|
|
||||
|||[9, 12]|**T** = tensor(bool), tensor(float), tensor(int32), tensor(int64), tensor(uint8)|
|
||||
|NonZero|*in* X:**T**<br> *out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
|
||||
|||[9, 12]|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)|
|
||||
|Not|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bool)|
|
||||
|OneHot|*in* indices:**T1**<br> *in* depth:**T2**<br> *in* values:**T3**<br> *out* output:**T3**|11+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(int32), tensor(int64)<br/> **T3** = tensor(float), tensor(float16), tensor(int64)|
|
||||
|Or|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
|
||||
|
|
|
|||
|
|
@ -791,6 +791,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, float, NonZero);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, NonZero);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, TopK);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, If);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, 8, Scan);
|
||||
|
|
@ -1085,6 +1086,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Cast);
|
||||
|
|
@ -1661,6 +1663,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, float, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 9, TopK)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, 8, Scan)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, Scan)>,
|
||||
|
|
@ -1955,6 +1958,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Cast)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Cast)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Cast)>,
|
||||
|
|
|
|||
|
|
@ -69,6 +69,40 @@ __global__ void NonZeroOutputPositionsKernel(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
constexpr int MAX_DIMS = 16;
|
||||
|
||||
// Writes the coordinates of every nonzero element of `x` into `results`.
// Each thread handles at most one element (index = block offset + thread id).
// For an element assigned output slot p, the coordinate along axis a is stored
// at results[a * nonzero_elements + p].
// The axis loop is bounded by the compile-time constant MAX_DIMS and exits
// early at x_rank, so axes at or beyond MAX_DIMS are never written here.
// NOTE(review): prefix_counts looks like an inclusive running total of the
// per-block nonzero counts -- confirm against the code that produces it.
template <typename InputT, int THREADS_PER_BLOCK>
|
||||
__global__ void UnRolledNonZeroOutputPositionsKernel(
|
||||
const InputT* x, int64_t x_size, int x_rank, const TArray<fast_divmod> x_strides,
|
||||
const int* prefix_counts, int nonzero_elements, int64_t* results) {
|
||||
// Block-wide scan used to rank the nonzero elements inside this block.
typedef cub::BlockScan<int, THREADS_PER_BLOCK> BlockScanT;
|
||||
__shared__ typename BlockScanT::TempStorage temp_storage;
|
||||
|
||||
// Flat index of the element handled by this thread.
int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// const cub::CastOp<bool> cast_to_bool; not supported on amd hipcub
|
||||
// nz is 1 when this thread's element is in range and converts to true, else 0.
int nz = 0;
|
||||
if (index < x_size && bool(x[index])) ++nz;
|
||||
int pos_in_block = 0;
|
||||
// Inclusive sum over the block: pos_in_block counts the nonzeros up to and
// including this thread.
BlockScanT(temp_storage).InclusiveSum(nz, pos_in_block);
|
||||
|
||||
// Subtracting nz turns the inclusive rank into an exclusive one; every block
// except the first adds the value stored for the previous block.
int result_position = ((blockIdx.x == 0) ? 0 : prefix_counts[blockIdx.x - 1]) + pos_in_block - nz;
|
||||
|
||||
if (index < x_size && bool(x[index])) {
|
||||
// index is narrowed to int here.
// NOTE(review): assumes x_size fits in int -- confirm with the caller.
int remain = (int)index, dim = 0;
|
||||
int rp = result_position;
|
||||
// The trip count is the constant MAX_DIMS, presumably so the unroll pragma
// can take effect; the break stops the loop once all x_rank axes are done.
#pragma unroll
|
||||
for (int axis = 0; axis < MAX_DIMS; ++axis) {
|
||||
if (axis == x_rank) {
|
||||
break;
|
||||
}
|
||||
// NOTE(review): divmod presumably yields the quotient in dim (this axis'
// coordinate) and leaves the remainder in remain -- confirm against
// fast_divmod.
x_strides[axis].divmod(remain, dim, remain);
|
||||
results[rp] = (int64_t)dim;
|
||||
// Advance by nonzero_elements to reach this element's slot for the next axis.
rp += nonzero_elements;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InputT>
|
||||
cudaError_t NonZeroCountEachBlock(cudaStream_t stream, const InputT* x, int64_t x_size, int* count_in_blocks) {
|
||||
int num_blocks = NonZeroCalcBlockCount(x_size);
|
||||
|
|
@ -82,9 +116,15 @@ cudaError_t NonZeroOutputPositions(
|
|||
cudaStream_t stream, const InputT* x, int64_t x_size, int x_rank, const TArray<fast_divmod>& x_strides,
|
||||
const int* prefix_counts, int nonzero_elements, int64_t* results) {
|
||||
int num_blocks = NonZeroCalcBlockCount(x_size);
|
||||
NonZeroOutputPositionsKernel<InputT, NONZERO_THREADS_PER_BLOCK><<<num_blocks, NONZERO_THREADS_PER_BLOCK, 0, stream>>>(
|
||||
if (x_rank > MAX_DIMS) {
|
||||
NonZeroOutputPositionsKernel<InputT, NONZERO_THREADS_PER_BLOCK><<<num_blocks, NONZERO_THREADS_PER_BLOCK, 0, stream>>>(
|
||||
x, x_size, x_rank, x_strides,
|
||||
prefix_counts, nonzero_elements, results);
|
||||
} else {
|
||||
UnRolledNonZeroOutputPositionsKernel<InputT, NONZERO_THREADS_PER_BLOCK><<<num_blocks, NONZERO_THREADS_PER_BLOCK, 0, stream>>>(
|
||||
x, x_size, x_rank, x_strides,
|
||||
prefix_counts, nonzero_elements, results);
|
||||
}
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -710,6 +710,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, NonZero);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, NonZero);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 9, TopK);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, If);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 8, Scan);
|
||||
|
|
@ -1004,6 +1005,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, NonZero);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Cast);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Cast);
|
||||
|
|
@ -1232,7 +1234,7 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
|
|||
|
||||
static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<void>, //default entry to avoid the list become empty after ops-reducing
|
||||
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyToHost)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 4, 10, Concat)>,
|
||||
|
|
@ -1579,6 +1581,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int32_t, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, int64_t, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 9, TopK)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 8, 8, Scan)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 10, Scan)>,
|
||||
|
|
@ -1873,6 +1876,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, NonZero)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Cast)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Cast)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Cast)>,
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type,
|
|||
norm_type = float(norm_type)
|
||||
total_norm = 0.0
|
||||
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
# Calculate norm.
|
||||
if norm_type == inf:
|
||||
total_norm = max(grad.abs().max() for grad in grads_for_norm)
|
||||
|
|
@ -104,7 +105,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type,
|
|||
|
||||
else:
|
||||
if norm_type == 2.0:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
# Use apex's multi-tensor applier for efficiency reasons.
|
||||
# Multi-tensor applier takes a function and a list of list
|
||||
# and performs the operation on that list all in one kernel.
|
||||
|
|
@ -133,6 +133,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type,
|
|||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=get_horizontal_model_parallel_group())
|
||||
total_norm = total_norm.item() ** (1.0 / norm_type)
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
# Filter parameters with gradients.
|
||||
grads = [p.grad for p in parameters if p.grad is not None]
|
||||
if clip_coef < 1.0:
|
||||
multi_tensor_applier(
|
||||
amp_C.multi_tensor_scale,
|
||||
dummy_overflow_buf,
|
||||
[grads, grads],
|
||||
clip_coef)
|
||||
|
||||
return total_norm
|
||||
|
||||
|
|
|
|||
|
|
@ -26,15 +26,14 @@ void multi_tensor_adam_cuda(int chunk_size,
|
|||
const int bias_correction,
|
||||
const float weight_decay);
|
||||
|
||||
// This function is adapted from NVIDIA/apex
|
||||
// This function is adapted from NVIDIA/apex
|
||||
// https://github.com/NVIDIA/apex/blob/0c7d8e3fa9a095a1641a2290877436d0314b69c6/csrc/amp_C_frontend.cpp#L3
|
||||
void multi_tensor_scale_cuda(int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>>& tensor_lists,
|
||||
float scale);
|
||||
|
||||
|
||||
// This function is adapted from NVIDIA/apex
|
||||
// This function is adapted from NVIDIA/apex
|
||||
// https://github.com/NVIDIA/apex/blob/0c7d8e3fa9a095a1641a2290877436d0314b69c6/csrc/amp_C_frontend.cpp#L22
|
||||
void multi_tensor_axpby_cuda(int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
|
|
@ -43,43 +42,42 @@ void multi_tensor_axpby_cuda(int chunk_size,
|
|||
float b,
|
||||
int arg_to_check);
|
||||
|
||||
|
||||
class MemoryBuffer {
|
||||
public:
|
||||
MemoryBuffer(size_t numel, at::Tensor val){
|
||||
data_buffer_ = at::empty({numel}, val.options());
|
||||
}
|
||||
public:
|
||||
MemoryBuffer(size_t numel, at::Tensor val) {
|
||||
data_buffer_ = at::empty({static_cast<int64_t>(numel)}, val.options());
|
||||
}
|
||||
|
||||
at::Tensor Get(at::Tensor param, size_t start_index) {
|
||||
size_t end_index = start_index + param.numel();
|
||||
return data_buffer_.slice(0, start_index, end_index).view(param.sizes());
|
||||
}
|
||||
at::Tensor Get(at::Tensor param, size_t start_index) {
|
||||
size_t end_index = start_index + param.numel();
|
||||
return data_buffer_.slice(0, start_index, end_index).view(param.sizes());
|
||||
}
|
||||
|
||||
private:
|
||||
at::Tensor data_buffer_;
|
||||
private:
|
||||
at::Tensor data_buffer_;
|
||||
};
|
||||
|
||||
class CachedStates {
|
||||
public:
|
||||
static CachedStates& GetInstance(){
|
||||
static CachedStates states_;
|
||||
return states_;
|
||||
}
|
||||
public:
|
||||
static CachedStates& GetInstance() {
|
||||
static CachedStates states_;
|
||||
return states_;
|
||||
}
|
||||
|
||||
void ClearStates(){
|
||||
idx_to_numel_map.clear();
|
||||
}
|
||||
void ClearStates() {
|
||||
idx_to_numel_map.clear();
|
||||
}
|
||||
|
||||
// Parameter index to number of element mapping for each parameter.
|
||||
std::vector<std::pair<size_t, size_t>> idx_to_numel_map;
|
||||
// Parameter index to number of element mapping for each parameter.
|
||||
std::vector<std::pair<size_t, size_t>> idx_to_numel_map;
|
||||
|
||||
private:
|
||||
CachedStates(){}
|
||||
private:
|
||||
CachedStates() {}
|
||||
};
|
||||
|
||||
bool SortByElementSizeDesc(const std::pair<size_t, size_t> &a,
|
||||
const std::pair<size_t, size_t> &b) {
|
||||
return (a.second > b.second);
|
||||
bool SortByElementSizeDesc(const std::pair<size_t, size_t>& a,
|
||||
const std::pair<size_t, size_t>& b) {
|
||||
return (a.second > b.second);
|
||||
};
|
||||
|
||||
// This function is trying to move into C++ implementation from Python logic
|
||||
|
|
@ -89,134 +87,135 @@ void unscale_fp16_grads_into_fp32_grads(std::vector<at::Tensor>& all_fp16_params
|
|||
std::vector<at::Tensor>& all_fp32_from_fp16_params,
|
||||
at::Tensor is_overflow_buffer,
|
||||
float scale) {
|
||||
if (all_fp16_params.size() == 0 || all_fp32_from_fp16_params.size() == 0) {
|
||||
return;
|
||||
}
|
||||
if (all_fp16_params.size() == 0 || all_fp32_from_fp16_params.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float inv_scale = 1.0 / scale;
|
||||
TORCH_CHECK(all_fp16_params.size() == all_fp32_from_fp16_params.size(),
|
||||
"mismatch param size between fp16_param and fp32_from_fp16_param.");
|
||||
const float inv_scale = 1.0 / scale;
|
||||
TORCH_CHECK(all_fp16_params.size() == all_fp32_from_fp16_params.size(),
|
||||
"mismatch param size between fp16_param and fp32_from_fp16_param.");
|
||||
|
||||
// Use cached states only parameter count did not get changed.
|
||||
bool need_reset_states =
|
||||
all_fp32_from_fp16_params.size() != CachedStates::GetInstance().idx_to_numel_map.size();
|
||||
// Use cached states only parameter count did not get changed.
|
||||
bool need_reset_states =
|
||||
all_fp32_from_fp16_params.size() != CachedStates::GetInstance().idx_to_numel_map.size();
|
||||
if (need_reset_states) {
|
||||
CachedStates::GetInstance().ClearStates();
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> fp16_grads_needing_unscale;
|
||||
std::vector<at::Tensor> fp16_grads_needing_unscale_with_stash;
|
||||
std::vector<at::Tensor> preexisting_fp32_grads;
|
||||
|
||||
// Parameter index to parameter mapping for each fp32_from_fp16 parameter.
|
||||
std::unordered_map<size_t, at::Tensor> idx_to_fp32_from_fp16_params;
|
||||
|
||||
// "buffer index" to "offset in memory buffer" mapping for each fp32_from_fp16 parameter.
|
||||
std::vector<size_t> memory_buffer_idx_to_offset_map;
|
||||
size_t memory_buffer_size = 0;
|
||||
auto& idx_to_numel_map = CachedStates::GetInstance().idx_to_numel_map;
|
||||
|
||||
for (size_t idx = 0; idx < all_fp16_params.size(); ++idx) {
|
||||
auto& fp16_param_grad = all_fp16_params[idx].grad();
|
||||
bool fp16_param_has_grad = fp16_param_grad.defined();
|
||||
|
||||
auto& fp32_from_fp16_param = all_fp32_from_fp16_params[idx];
|
||||
auto& fp32_from_fp16_param_grad = fp32_from_fp16_param.grad();
|
||||
bool fp32_from_fp16_param_has_grad = fp32_from_fp16_param_grad.defined();
|
||||
|
||||
size_t num_elem = fp32_from_fp16_param.numel();
|
||||
if (need_reset_states) {
|
||||
CachedStates::GetInstance().ClearStates();
|
||||
idx_to_numel_map.push_back(std::make_pair(idx, num_elem));
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> fp16_grads_needing_unscale;
|
||||
std::vector<at::Tensor> fp16_grads_needing_unscale_with_stash;
|
||||
std::vector<at::Tensor> preexisting_fp32_grads;
|
||||
if (fp16_param_has_grad && !fp32_from_fp16_param_has_grad) {
|
||||
idx_to_fp32_from_fp16_params.emplace(std::make_pair(idx, fp32_from_fp16_param));
|
||||
fp16_grads_needing_unscale.emplace_back(fp16_param_grad);
|
||||
memory_buffer_idx_to_offset_map.emplace_back(memory_buffer_size);
|
||||
memory_buffer_size += num_elem;
|
||||
} else if (fp16_param_has_grad && fp32_from_fp16_param_has_grad) {
|
||||
fp16_grads_needing_unscale_with_stash.emplace_back(fp16_param_grad);
|
||||
preexisting_fp32_grads.emplace_back(fp32_from_fp16_param_grad);
|
||||
}
|
||||
}
|
||||
|
||||
// Parameter index to parameter mapping for each fp32_from_fp16 parameter.
|
||||
std::unordered_map<size_t, at::Tensor> idx_to_fp32_from_fp16_params;
|
||||
if (need_reset_states) {
|
||||
std::sort(idx_to_numel_map.begin(), idx_to_numel_map.end(), SortByElementSizeDesc);
|
||||
}
|
||||
|
||||
// "buffer index" to "offset in memory buffer" mapping for each fp32_from_fp16 parameter.
|
||||
std::vector<size_t> memory_buffer_idx_to_offset_map;
|
||||
size_t memory_buffer_size = 0;
|
||||
auto& idx_to_numel_map = CachedStates::GetInstance().idx_to_numel_map;
|
||||
if (idx_to_fp32_from_fp16_params.size() > 0) {
|
||||
auto mem_buffer = MemoryBuffer(memory_buffer_size, idx_to_fp32_from_fp16_params.begin()->second);
|
||||
const size_t emit_threshhold = memory_buffer_size / EMIT_NUM;
|
||||
|
||||
for (size_t idx = 0; idx < all_fp16_params.size(); ++idx) {
|
||||
auto& fp16_param_grad = all_fp16_params[idx].grad();
|
||||
bool fp16_param_has_grad = fp16_param_grad.defined();
|
||||
size_t acc_size = 0;
|
||||
std::vector<at::Tensor> partial_new_fp32_grads;
|
||||
std::vector<at::Tensor> partial_fp16_grads_needing_unscale;
|
||||
for (size_t idx = 0, fp32_from_fp16_param_idx = 0; idx < idx_to_numel_map.size(); ++idx) {
|
||||
if (idx_to_fp32_from_fp16_params.find(idx) == idx_to_fp32_from_fp16_params.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& fp32_from_fp16_param = all_fp32_from_fp16_params[idx];
|
||||
auto& fp32_from_fp16_param_grad = fp32_from_fp16_param.grad();
|
||||
bool fp32_from_fp16_param_has_grad = fp32_from_fp16_param_grad.defined();
|
||||
acc_size += idx_to_numel_map[idx].second;
|
||||
idx_to_fp32_from_fp16_params[idx].mutable_grad() =
|
||||
mem_buffer.Get(idx_to_fp32_from_fp16_params[idx],
|
||||
memory_buffer_idx_to_offset_map[fp32_from_fp16_param_idx]);
|
||||
partial_new_fp32_grads.emplace_back(idx_to_fp32_from_fp16_params[idx].grad());
|
||||
partial_fp16_grads_needing_unscale.emplace_back(fp16_grads_needing_unscale[fp32_from_fp16_param_idx]);
|
||||
|
||||
size_t num_elem = fp32_from_fp16_param.numel();
|
||||
if (need_reset_states) {
|
||||
idx_to_numel_map.push_back(std::make_pair(idx, num_elem));
|
||||
}
|
||||
|
||||
if (fp16_param_has_grad && !fp32_from_fp16_param_has_grad) {
|
||||
idx_to_fp32_from_fp16_params.emplace(std::make_pair(idx, fp32_from_fp16_param));
|
||||
fp16_grads_needing_unscale.emplace_back(fp16_param_grad);
|
||||
memory_buffer_idx_to_offset_map.emplace_back(memory_buffer_size);
|
||||
memory_buffer_size += num_elem;
|
||||
} else if (fp16_param_has_grad && fp32_from_fp16_param_has_grad) {
|
||||
fp16_grads_needing_unscale_with_stash.emplace_back(fp16_param_grad);
|
||||
preexisting_fp32_grads.emplace_back(fp32_from_fp16_param_grad);
|
||||
if (acc_size > emit_threshhold || fp32_from_fp16_param_idx == idx_to_fp32_from_fp16_params.size() - 1) {
|
||||
if (partial_fp16_grads_needing_unscale.size() > 0) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists;
|
||||
tensor_lists.emplace_back(partial_fp16_grads_needing_unscale);
|
||||
tensor_lists.emplace_back(partial_new_fp32_grads);
|
||||
multi_tensor_scale_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale);
|
||||
|
||||
partial_fp16_grads_needing_unscale.clear();
|
||||
partial_new_fp32_grads.clear();
|
||||
acc_size = 0;
|
||||
}
|
||||
}
|
||||
++fp32_from_fp16_param_idx;
|
||||
}
|
||||
}
|
||||
|
||||
if (need_reset_states) {
|
||||
std::sort(idx_to_numel_map.begin(), idx_to_numel_map.end(), SortByElementSizeDesc);
|
||||
}
|
||||
|
||||
if (idx_to_fp32_from_fp16_params.size() > 0) {
|
||||
auto mem_buffer = MemoryBuffer(memory_buffer_size, idx_to_fp32_from_fp16_params.begin()->second);
|
||||
const size_t emit_threshhold = memory_buffer_size / EMIT_NUM;
|
||||
|
||||
size_t acc_size = 0;
|
||||
std::vector<at::Tensor> partial_new_fp32_grads;
|
||||
std::vector<at::Tensor> partial_fp16_grads_needing_unscale;
|
||||
for (size_t idx = 0, fp32_from_fp16_param_idx = 0; idx < idx_to_numel_map.size(); ++idx) {
|
||||
if (idx_to_fp32_from_fp16_params.find(idx) == idx_to_fp32_from_fp16_params.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
acc_size += idx_to_numel_map[idx].second;
|
||||
idx_to_fp32_from_fp16_params[idx].mutable_grad() =
|
||||
mem_buffer.Get(idx_to_fp32_from_fp16_params[idx],
|
||||
memory_buffer_idx_to_offset_map[fp32_from_fp16_param_idx]);
|
||||
partial_new_fp32_grads.emplace_back(idx_to_fp32_from_fp16_params[idx].grad());
|
||||
partial_fp16_grads_needing_unscale.emplace_back(fp16_grads_needing_unscale[fp32_from_fp16_param_idx]);
|
||||
|
||||
if (acc_size > emit_threshhold || fp32_from_fp16_param_idx == idx_to_fp32_from_fp16_params.size() - 1) {
|
||||
if (partial_fp16_grads_needing_unscale.size() > 0) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists;
|
||||
tensor_lists.emplace_back(partial_fp16_grads_needing_unscale);
|
||||
tensor_lists.emplace_back(partial_new_fp32_grads);
|
||||
multi_tensor_scale_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale);
|
||||
|
||||
partial_fp16_grads_needing_unscale.clear();
|
||||
partial_new_fp32_grads.clear();
|
||||
acc_size = 0;
|
||||
}
|
||||
}
|
||||
++fp32_from_fp16_param_idx;
|
||||
}
|
||||
}
|
||||
|
||||
if (fp16_grads_needing_unscale_with_stash.size() > 0) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists;
|
||||
tensor_lists.emplace_back(fp16_grads_needing_unscale_with_stash);
|
||||
tensor_lists.emplace_back(preexisting_fp32_grads);
|
||||
tensor_lists.emplace_back(preexisting_fp32_grads);
|
||||
// a * x + b * y
|
||||
multi_tensor_axpby_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale, float(1.0), 0);
|
||||
}
|
||||
if (fp16_grads_needing_unscale_with_stash.size() > 0) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists;
|
||||
tensor_lists.emplace_back(fp16_grads_needing_unscale_with_stash);
|
||||
tensor_lists.emplace_back(preexisting_fp32_grads);
|
||||
tensor_lists.emplace_back(preexisting_fp32_grads);
|
||||
// a * x + b * y
|
||||
multi_tensor_axpby_cuda(MTA_CHUNK_SIZE, is_overflow_buffer, tensor_lists, inv_scale, float(1.0), 0);
|
||||
}
|
||||
};
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
// Cannot use the shortcut API below because https://github.com/pybind/pybind11/issues/1470
|
||||
// py::bind_vector<std::vector<at::Tensor>>(m, "TorchTensorVector");
|
||||
py::class_<std::vector<at::Tensor>>(m, "TorchTensorVector")
|
||||
.def(py::init<>())
|
||||
.def("clear", &std::vector<at::Tensor>::clear)
|
||||
.def("pop_back", &std::vector<at::Tensor>::pop_back)
|
||||
.def("__len__", [](const std::vector<at::Tensor> &v) { return v.size(); })
|
||||
.def("__iter__", [](std::vector<at::Tensor> &v) {
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
// Cannot use the shortcut API below because https://github.com/pybind/pybind11/issues/1470
|
||||
// py::bind_vector<std::vector<at::Tensor>>(m, "TorchTensorVector");
|
||||
py::class_<std::vector<at::Tensor>>(m, "TorchTensorVector")
|
||||
.def(py::init<>())
|
||||
.def("clear", &std::vector<at::Tensor>::clear)
|
||||
.def("pop_back", &std::vector<at::Tensor>::pop_back)
|
||||
.def("__len__", [](const std::vector<at::Tensor>& v) { return v.size(); })
|
||||
.def(
|
||||
"__iter__", [](std::vector<at::Tensor>& v) {
|
||||
return py::make_iterator(v.begin(), v.end());
|
||||
}, py::keep_alive<0, 1>())
|
||||
.def("extend", [](std::vector<at::Tensor> &v, const std::vector<at::Tensor> &src) {
|
||||
v.insert(v.end(), src.begin(), src.end());
|
||||
})
|
||||
.def(py::init([](const py::iterable &it) {
|
||||
auto v = std::unique_ptr<std::vector<at::Tensor>>(new std::vector<at::Tensor>());
|
||||
v->reserve(py::len_hint(it));
|
||||
for (py::handle h : it) {
|
||||
v->push_back(h.cast<at::Tensor>());
|
||||
}
|
||||
return v.release();
|
||||
}));
|
||||
},
|
||||
py::keep_alive<0, 1>())
|
||||
.def("extend", [](std::vector<at::Tensor>& v, const std::vector<at::Tensor>& src) {
|
||||
v.insert(v.end(), src.begin(), src.end());
|
||||
})
|
||||
.def(py::init([](const py::iterable& it) {
|
||||
auto v = std::unique_ptr<std::vector<at::Tensor>>(new std::vector<at::Tensor>());
|
||||
v->reserve(py::len_hint(it));
|
||||
for (py::handle h : it) {
|
||||
v->push_back(h.cast<at::Tensor>());
|
||||
}
|
||||
return v.release();
|
||||
}));
|
||||
|
||||
m.def("multi_tensor_adam",
|
||||
&multi_tensor_adam_cuda,
|
||||
"Compute and apply gradient update to parameters for Adam optimizer");
|
||||
m.def("unscale_fp16_grads_into_fp32_grads",
|
||||
&unscale_fp16_grads_into_fp32_grads,
|
||||
"Unscale those fp16 gradients into fp32 gradient buffers.");
|
||||
m.def("multi_tensor_adam",
|
||||
&multi_tensor_adam_cuda,
|
||||
"Compute and apply gradient update to parameters for Adam optimizer");
|
||||
m.def("unscale_fp16_grads_into_fp32_grads",
|
||||
&unscale_fp16_grads_into_fp32_grads,
|
||||
"Unscale those fp16 gradients into fp32 gradient buffers.");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatTensorLarge) {
|
|||
OpTester test("IsAllFinite", 1, kMSDomain);
|
||||
bool expected_answer = false;
|
||||
auto tensors = generate_is_all_finite_test_data(13, 941736, expected_answer, test_count);
|
||||
for (int i = 0; i < tensors.size(); ++i) {
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
auto name = std::string("X") + std::to_string(i);
|
||||
auto size = static_cast<int64_t>(tensors[i].size());
|
||||
test.AddInput<float>(name.c_str(), {size}, tensors[i]);
|
||||
|
|
@ -186,7 +186,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatManyBlock) {
|
|||
OpTester test("IsAllFinite", 1, kMSDomain);
|
||||
bool expected_answer = false;
|
||||
auto tensors = generate_is_all_finite_test_data(894, 17, expected_answer, test_count);
|
||||
for (int i = 0; i < tensors.size(); ++i) {
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
auto name = std::string("X") + std::to_string(i);
|
||||
auto size = static_cast<int64_t>(tensors[i].size());
|
||||
test.AddInput<float>(name.c_str(), {size}, tensors[i]);
|
||||
|
|
@ -199,7 +199,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatManyBlock) {
|
|||
TEST(IsAllFiniteTest, MoreFalseFloatMultipleFalse) {
|
||||
OpTester test("IsAllFinite", 1, kMSDomain);
|
||||
auto tensors = generate_is_all_finite_test_data(1234, 1987, 0.1f, 0);
|
||||
for (int i = 0; i < tensors.size(); ++i) {
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
auto name = std::string("X") + std::to_string(i);
|
||||
auto size = static_cast<int64_t>(tensors[i].size());
|
||||
test.AddInput<float>(name.c_str(), {size}, tensors[i]);
|
||||
|
|
@ -212,7 +212,7 @@ TEST(IsAllFiniteTest, MoreTrueFloatTensorLarge) {
|
|||
OpTester test("IsAllFinite", 1, kMSDomain);
|
||||
bool expected_answer = true;
|
||||
auto tensors = generate_is_all_finite_test_data(12, 941736, expected_answer, 0);
|
||||
for (int i = 0; i < tensors.size(); ++i) {
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
auto name = std::string("X") + std::to_string(i);
|
||||
auto size = static_cast<int64_t>(tensors[i].size());
|
||||
test.AddInput<float>(name.c_str(), {size}, tensors[i]);
|
||||
|
|
@ -227,7 +227,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatManyBlockFloat16) {
|
|||
OpTester test("IsAllFinite", 1, kMSDomain);
|
||||
bool expected_answer = false;
|
||||
auto tensors = generate_is_all_finite_test_data(894, 17, expected_answer, 0);
|
||||
for (int i = 0; i < tensors.size(); ++i) {
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
auto name = std::string("X") + std::to_string(i);
|
||||
auto size = static_cast<int64_t>(tensors[i].size());
|
||||
std::vector<MLFloat16> buffer_half(tensors[i].size());
|
||||
|
|
@ -242,7 +242,7 @@ TEST(IsAllFiniteTest, MoreFalseFloatTensorLargeFloat16) {
|
|||
OpTester test("IsAllFinite", 1, kMSDomain);
|
||||
bool expected_answer = false;
|
||||
auto tensors = generate_is_all_finite_test_data(12, 941736, expected_answer, 0);
|
||||
for (int i = 0; i < tensors.size(); ++i) {
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
auto name = std::string("X") + std::to_string(i);
|
||||
auto size = static_cast<int64_t>(tensors[i].size());
|
||||
std::vector<MLFloat16> buffer_half(tensors[i].size());
|
||||
|
|
|
|||
Loading…
Reference in a new issue