mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Fix a potential race in the CUDA TopK kernel (#19917)
### Description If the `K` value is flowing through as a tensor, we are updating a mutable member of the `TopK` class and basing the compute off that - which is likely to cause data race issues with concurrent Run() calls and `K` value changes. ### Motivation and Context Fix potential race in CUDA TopK kernel
This commit is contained in:
parent
bcf47d3546
commit
42399dfd2b
2 changed files with 18 additions and 8 deletions
|
|
@ -56,7 +56,7 @@ TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
|
|||
info.GetAttrOrDefault<int64_t>("largest", &largest_, 1);
|
||||
info.GetAttrOrDefault<int64_t>("sorted", &sorted_, 1);
|
||||
if (!inputk) {
|
||||
info.GetAttrOrDefault<int64_t>("k", &K_, 0);
|
||||
info.GetAttrOrDefault<int64_t>("k", &attr_k_, 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -67,7 +67,7 @@ TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
|
|||
static_cast<int64_t*>(tensor_I->MutableDataRaw()), \
|
||||
elem_nums_cuda, \
|
||||
elem_nums.size(), \
|
||||
axis, K_, largest_, sorted_, N, dimension)
|
||||
axis, k_value, largest_, sorted_, N, dimension)
|
||||
|
||||
template <bool inputk>
|
||||
Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
|
|
@ -77,19 +77,29 @@ Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
int32_t axis = static_cast<int32_t>(axis_ < 0 ? rank + axis_ : axis_);
|
||||
ORT_ENFORCE(axis > -1 && axis < rank);
|
||||
|
||||
int64_t k_value = 0;
|
||||
if (inputk) {
|
||||
auto tensor_K = ctx->Input<Tensor>(1);
|
||||
ORT_ENFORCE(nullptr != tensor_K);
|
||||
K_ = *tensor_K->Data<int64_t>();
|
||||
ORT_ENFORCE(K_ >= 0 && K_ <= tensor_X->Shape().GetDims()[axis]);
|
||||
k_value = *tensor_K->Data<int64_t>();
|
||||
} else { // from attribute
|
||||
k_value = attr_k_;
|
||||
}
|
||||
|
||||
auto output_shape = tensor_X->Shape();
|
||||
output_shape[axis] = K_;
|
||||
// Now that we know the value of 'K' and the input shape,
|
||||
// make a final validation before going to the implementation
|
||||
const auto& input_shape = tensor_X->Shape();
|
||||
if ((k_value < 0) || (k_value > input_shape.GetDims()[axis])) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Value of K outside range. K value: ", k_value,
|
||||
". Input shape: ", input_shape, " . Axis: ", axis);
|
||||
}
|
||||
|
||||
auto output_shape = input_shape;
|
||||
output_shape[axis] = k_value;
|
||||
auto tensor_V = ctx->Output(0, output_shape);
|
||||
auto tensor_I = ctx->Output(1, output_shape);
|
||||
|
||||
if (0 == K_) {
|
||||
if (output_shape.Size() == 0) { // Bail out early if the output is going to be empty
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class TopK final : public CudaKernel {
|
|||
int64_t axis_;
|
||||
int64_t largest_;
|
||||
int64_t sorted_;
|
||||
mutable int64_t K_;
|
||||
int64_t attr_k_;
|
||||
};
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue