mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
Fix type mismatch in CUDA Trilu op. (#12863)
Added type cast to int64_t to avoid overflow errors/alerts.
This commit is contained in:
parent
9f6d452ca6
commit
dcc8fe656b
1 changed files with 2 additions and 2 deletions
|
|
@ -34,12 +34,12 @@ Status Trilu::ComputeInternal(OpKernelContext* ctx) const {
|
|||
const Tensor& input = *input_ptr;
|
||||
const auto& shape = input.Shape();
|
||||
const auto& input_dims = shape.GetDims();
|
||||
int32_t rank = gsl::narrow_cast<int32_t>(input_dims.size());
|
||||
auto rank = input_dims.size();
|
||||
if (rank < 2) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Input tensor should have a rank of at least 2");
|
||||
}
|
||||
Tensor* output = ctx->Output(0, shape);
|
||||
int64_t matrix_size = input_dims[rank - 1] * input_dims[rank - 2];
|
||||
auto matrix_size = input_dims[rank - 1] * input_dims[rank - 2];
|
||||
if (matrix_size == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue