diff --git a/onnxruntime/core/providers/cuda/integer_gemm.cc b/onnxruntime/core/providers/cuda/integer_gemm.cc index 76bc708736..9c2a05e9a9 100644 --- a/onnxruntime/core/providers/cuda/integer_gemm.cc +++ b/onnxruntime/core/providers/cuda/integer_gemm.cc @@ -20,8 +20,9 @@ Status GemmInt8(int m, int n, int k, const CudaKernel* cuda_kernel, onnxruntime::Stream* ort_stream) { ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null"); ORT_ENFORCE(cuda_kernel != nullptr, "kernel is null"); + ORT_ENFORCE(ort_stream != nullptr, "Cuda kernel must have the stream instance"); - cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; + cudaStream_t stream = static_cast(ort_stream->GetHandle()); // pad A and B to make their leading dimension be multiples of 32 // because cublasGemmEx requires: