diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index c72ef5338b..fdfd64cb83 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -757,11 +757,6 @@ if (onnxruntime_USE_CUDA) file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) set(ONNXRUNTIME_CUDA_LIBRARIES ${CUDA_LIBRARIES}) list(APPEND ONNXRUNTIME_CUDA_LIBRARIES cublas cudnn curand) - if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "10.1.0") - list(APPEND ONNXRUNTIME_CUDA_LIBRARIES cublasLt) - else() - message(WARNING "cublasLT is not supported in CUDA with version lower than 10.1.") - endif() if (WIN32) link_directories(${onnxruntime_CUDNN_HOME}/lib/x64) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 934caedaba..5b8fb9a1c6 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -551,7 +551,9 @@ if (onnxruntime_USE_DML) set(delayload_dml "/DELAYLOAD:directml.dll") endif(onnxruntime_USE_DML) -target_link_options(winml_dll PRIVATE /DEF:${WINML_DIR}/windows.ai.machinelearning.def ${os_component_link_flags} /DELAYLOAD:api-ms-win-core-libraryloader-l1-2-1.dll /DELAYLOAD:api-ms-win-core-threadpool-legacy-l1-1-0.dll /DELAYLOAD:api-ms-win-core-processtopology-obsolete-l1-1-0.dll /DELAYLOAD:api-ms-win-core-kernel32-legacy-l1-1-0.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:d3d11.dll /DELAYLOAD:dxgi.dll ${delayload_dml}) +set(os_component_link_flags_list ${os_component_link_flags}) +separate_arguments(os_component_link_flags_list) +target_link_options(winml_dll PRIVATE /DEF:${WINML_DIR}/windows.ai.machinelearning.def ${os_component_link_flags_list} /DELAYLOAD:api-ms-win-core-libraryloader-l1-2-1.dll /DELAYLOAD:api-ms-win-core-threadpool-legacy-l1-1-0.dll /DELAYLOAD:api-ms-win-core-processtopology-obsolete-l1-1-0.dll /DELAYLOAD:api-ms-win-core-kernel32-legacy-l1-1-0.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:d3d11.dll /DELAYLOAD:dxgi.dll ${delayload_dml}) if (EXISTS ${dxcore_header}) target_link_options(winml_dll PRIVATE /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index bf5182a2e7..1732a26c77 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -35,7 +35,7 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const { // Input 0 - input : (batch_size, sequence_length, hidden_size) // Input 1 - weights : (hidden_size, 3 * hidden_size) // Input 2 - bias : (3 * hidden_size) - // Input 3 - mask_index : (batch_size) + // Input 3 - mask_index : (batch_size) if presented // Output : (batch_size, sequence_length, hidden_size) const Tensor* input = context->Input(0); @@ -77,13 +77,15 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const { } const Tensor* mask_index = context->Input(3); - const auto mask_dims = mask_index->Shape().GetDims(); - if (mask_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 3 is expected to have 1 dimension, got ", - mask_dims.size()); - } - if (static_cast(mask_dims[0]) != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 3 and 0 shall have same length at dimension 0"); + if (mask_index != nullptr) { // mask_index is optional + const auto mask_dims = mask_index->Shape().GetDims(); + if (mask_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 3 is expected to have 1 dimension, got ", + mask_dims.size()); + } + if (static_cast(mask_dims[0]) != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 3 and 0 shall have same length at dimension 0"); + } } return Status::OK(); @@ -179,22 +181,36 @@ Status Attention::Compute(OpKernelContext* context) const { // STEP.2: scratch(B, N, S, S) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S, H -> B, N, H, S) + 1 x mask_index(B -> B, 1, // 1, 1) - auto scratch_data = - allocator->Alloc(SafeInt(batch_size) * num_heads_ * sequence_length * sequence_length * element_size); + size_t scratch_data_bytes = SafeInt(batch_size) * num_heads_ * sequence_length * sequence_length * element_size; + auto scratch_data = allocator->Alloc(scratch_data_bytes); BufferUniquePtr scratch_buffer(scratch_data, BufferDeleter(allocator)); { - auto scratch_broadcast_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * element_size); - BufferUniquePtr scratch_broadcast_buffer(scratch_broadcast_data, BufferDeleter(allocator)); - memset(scratch_broadcast_data, 0, batch_size * sequence_length * element_size); - T* p_scratch_broadcast_current_data = reinterpret_cast(scratch_broadcast_data); - for (int b_i = 0; b_i < batch_size; b_i++) { - // TODO: mask_index can be used in softmax to save some calculation. - int mask = mask_index->template Data()[b_i]; - for (int m_i = mask; m_i < sequence_length; m_i++) { - p_scratch_broadcast_current_data[m_i] = static_cast(-10000.0); + size_t mask_data_bytes = 0; + if (mask_index != nullptr) { + mask_data_bytes = SafeInt(batch_size) * sequence_length * element_size; + } + + void* mask_data = nullptr; + if (mask_data_bytes > 0) { + mask_data = allocator->Alloc(mask_data_bytes); + memset(mask_data, 0, mask_data_bytes); + } + BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator)); + + if (mask_index != nullptr) { + T* p_mask = reinterpret_cast(mask_data); + for (int b_i = 0; b_i < batch_size; b_i++) { + // TODO: mask_index can be used in softmax to save some calculation. + // Convert mask_index to mask (-10000 means out of range, which will be 0 after softmax): B => BxS + int valid_length = mask_index->template Data()[b_i]; + for (int m_i = valid_length; m_i < sequence_length; m_i++) { + p_mask[m_i] = static_cast(-10000.0); + } + p_mask += sequence_length; } - p_scratch_broadcast_current_data += sequence_length; + } else { + memset(scratch_data, 0, scratch_data_bytes); } const int loop_len = batch_size * num_heads_; @@ -206,12 +222,15 @@ Status Attention::Compute(OpKernelContext* context) const { ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const std::ptrdiff_t batch_index = i / num_heads_; + // broadcast masks (B) -> (B.N.)S.S - const T* broadcast_data_src = reinterpret_cast(scratch_broadcast_data) + batch_index * sequence_length; - T* broadcast_data_dest = reinterpret_cast(scratch_data) + sequence_length * sequence_length * i; - for (int seq_index = 0; seq_index < sequence_length; seq_index++) { - memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * sizeof(T)); - broadcast_data_dest += sequence_length; + if (mask_index != nullptr) { + const T* broadcast_data_src = reinterpret_cast(mask_data) + batch_index * sequence_length; + T* broadcast_data_dest = reinterpret_cast(scratch_data) + sequence_length * sequence_length * i; + for (int seq_index = 0; seq_index < sequence_length; seq_index++) { + memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * sizeof(T)); + broadcast_data_dest += sequence_length; + } } // gemm diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 0b70556ffa..4959b07b6a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -40,7 +40,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Input 0 - input : (batch_size, sequence_length, hidden_size) // Input 1 - weights : (hidden_size, 3 * hidden_size) // Input 2 - bias : (3 * hidden_size) - // Input 3 - mask_index : (batch_size) + // Input 3 - mask_index : (batch_size) if presented // Output : (batch_size, sequence_length, hidden_size) const Tensor* input = context->Input(0); const Tensor* weights = context->Input(1); @@ -88,7 +88,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { auto temp_buffer = GetScratchBuffer(workSpaceSize); if (!LaunchAttentionKernel( reinterpret_cast(gemm_buffer.get()), - mask_index->template Data(), + nullptr == mask_index ? nullptr : mask_index->template Data(), output->template MutableData(), batch_size, sequence_length, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 26aee9affb..adb73cbf64 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -152,6 +152,38 @@ __device__ inline void SoftmaxSmall(const int ld, const int num_valid, const T* } } +template +__global__ void SoftmaxKernelSmall(const int sequence_length, const T* input, T* output) { + SoftmaxSmall(sequence_length, sequence_length, input, output); +} + +template +__global__ void SoftmaxKernel(const int sequence_length, const T* input, T* output) { + Softmax(sequence_length, sequence_length, input, output); +} + +template +bool ComputeSoftmax( + cudaStream_t stream, const int sequence_length, const int batch_size, const int num_heads, + const T* input, T* output) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + if (sequence_length <= 32) { + const int blockSize = 32; + SoftmaxKernelSmall<<>>(sequence_length, input, output); + } else if (sequence_length <= 128) { + const int blockSize = 128; + SoftmaxKernelSmall<<>>(sequence_length, input, output); + } else if (sequence_length == 384) { + const int blockSize = 384; + SoftmaxKernelSmall<<>>(sequence_length, input, output); + } else { + const int blockSize = 256; + SoftmaxKernel<<>>(sequence_length, input, output); + } + + return CUDA_CALL(cudaPeekAtLastError()); +} + template __global__ void MaskedSoftmaxKernelSmall(const int sequence_length, const int* mask_index, const T* input, T* output) { __shared__ int num_valid; @@ -390,8 +422,14 @@ bool QkvToContext( } // apply softmax and store result P to scratch2: BxNxSxS - if (!ComputeMaskedSoftmax(stream, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2)) { - return false; + if (nullptr != mask_index) { + if (!ComputeMaskedSoftmax(stream, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2)) { + return false; + } + } else { + if (!ComputeSoftmax(stream, sequence_length, batch_size, num_heads, scratch1, scratch2)) { + return false; + } } // compute P*V (as V*P), and store in scratch3: BxNxSxH diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 5e8a53c539..d3a1857fce 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -7,20 +7,20 @@ namespace onnxruntime { namespace contrib { namespace cuda { - size_t GetAttentionWorkspaceSize(size_t element_size, int batchsize, int num_heads, int head_size, int sequence_length); +size_t GetAttentionWorkspaceSize(size_t element_size, int batchsize, int num_heads, int head_size, int sequence_length); - bool LaunchAttentionKernel( - const void* input, // Input tensor - const int* mask_index, // Nask index where each element is length of a sequence - void* output, // Output tensor - int batch_size, // Batch size (B) - int sequence_length, // Sequence length (S) - int num_heads, // Number of attention heads (N) - int head_size, // Hidden layer size per head (H) - void* workspace, // Temporary buffer - cublasHandle_t& cublas, // Cublas handle - const size_t element_size // Element size of input tensor - ); +bool LaunchAttentionKernel( + const void* input, // Input tensor + const int* mask_index, // Mask index (length of each sequence). NULL means no mask. + void* output, // Output tensor + int batch_size, // Batch size (B) + int sequence_length, // Sequence length (S) + int num_heads, // Number of attention heads (N) + int head_size, // Hidden layer size per head (H) + void* workspace, // Temporary buffer + cublasHandle_t& cublas, // Cublas handle + const size_t element_size // Element size of input tensor +); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index 2b73bd6d19..ab844e30cb 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -137,7 +137,8 @@ bool KernelRegistry::VerifyKernelDef(const onnxruntime::Node& node, || (kernel_start_version < node_since_version && kernel_end_version != INT_MAX && kernel_end_version >= node_since_version); if (!valid_version) { std::ostringstream ostr; - ostr << "Op: " << node.OpType() + ostr << "Op with name (" << node.Name() << ")" + << " and type (" << node.OpType() << ")" << " Version mismatch." << " node_version: " << node_since_version << " kernel start version: " << kernel_start_version @@ -168,17 +169,31 @@ bool KernelRegistry::VerifyKernelDef(const onnxruntime::Node& node, // missing optional parameter, which can be skipped. // TODO: We should check that names specified in kernel_type_constraints are // valid names (of types or parameters) at the time that kernels are registered. - if ((nullptr != actual_type) && - !std::any_of(allowed_types.begin(), allowed_types.end(), - [actual_type, &node, &error_str](const DataTypeImpl* expected_type) { - bool rc = expected_type->IsCompatible(*actual_type); // for easier debugging - if (!rc) { - // TODO print type information as well - error_str = "Op: " + node.OpType() + " Incompatible types."; - } - return rc; - })) { - return false; + if (nullptr != actual_type) { + bool is_type_compatible = std::any_of(allowed_types.begin(), allowed_types.end(), + [actual_type](const DataTypeImpl* expected_type) { + bool rc = expected_type->IsCompatible(*actual_type); // for easier debugging + return rc; + }); + if (!is_type_compatible) { + std::ostringstream ostr; + ostr << "Found kernel for Op with name (" << node.Name() << ")" + << " and type (" << node.OpType() << ")" + << " in the supported version range" + << " (node_version: " << node_since_version + << " kernel start version: " << kernel_start_version + << " kernel_end_version: " << kernel_end_version << ")." + << " However the types are incompatible." + << " This op has been implemented only for the following types ("; + for (const auto& allowed_type : allowed_types) { + ostr << DataTypeImpl::ToString(allowed_type) << ","; + } + ostr << "),"; + const char* actual_type_str = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*actual_type)); + ostr << " but the node in the model has the following type (" << actual_type_str << ")"; + error_str = ostr.str(); + return false; + } } } return true; @@ -240,7 +255,7 @@ Status KernelRegistry::TryCreateKernel(const onnxruntime::Node& node, static std::string ToString(const std::vector& error_strs) { std::ostringstream ostr; std::for_each(std::begin(error_strs), std::end(error_strs), - [&ostr](const std::string& str) { ostr << str << " "; }); + [&ostr](const std::string& str) { ostr << str << "\n"; }); return ostr.str(); } @@ -256,6 +271,9 @@ const KernelCreateInfo* KernelRegistry::TryFindKernel(const onnxruntime::Node& n auto range = kernel_creator_fn_map_.equal_range(GetMapKey(node.OpType(), node.Domain(), expected_provider)); std::vector verify_kernel_def_error_strs; + LOGS_DEFAULT(VERBOSE) << "Trying to find a kernel for op with name (" << node.Name() << ")" + << " and type (" << node.OpType() << ")" + << " and execution provider (" << expected_provider << ")"; for (auto i = range.first; i != range.second; ++i) { if (!i->second.status.IsOK()) { LOGS_DEFAULT(ERROR) << "Failed to create kernel for op: " << node.OpType() @@ -270,7 +288,9 @@ const KernelCreateInfo* KernelRegistry::TryFindKernel(const onnxruntime::Node& n } if (!verify_kernel_def_error_strs.empty()) { - LOGS_DEFAULT(INFO) << node.OpType() << " kernel is not supported in " << expected_provider + LOGS_DEFAULT(INFO) << "Op with name (" << node.Name() << ")" + << " and type (" << node.OpType() << ")" + << " kernel is not supported in " << expected_provider << "." << " Encountered following errors: (" << ToString(verify_kernel_def_error_strs) << ")"; } return nullptr; diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5472912013..25cfd774e1 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -299,7 +299,7 @@ void RegisterBertSchemas() { .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T") .Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T") .Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T") - .Input(3, "mask_index", "Attention mask index with shape (batch_size)", "M") + .Input(3, "mask_index", "Attention mask index with shape (batch_size).", "M", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types") @@ -2325,7 +2325,6 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); RegisterBertSchemas(); - } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 1327aee8ff..2cba64b36a 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2530,7 +2530,9 @@ Status Graph::SetGraphInputsOutputs() { for (auto& graph_value_info : graph_proto_->value_info()) { auto& name = graph_value_info.name(); const auto* node_arg = GetNodeArg(name); - value_info_.push_back(node_arg); + if (node_arg != nullptr) { + value_info_.push_back(node_arg); + } } } else { diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 0fa130328a..0274273190 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -159,20 +159,13 @@ class CudaKernel : public OpKernel { return provider_->PerThreadCublasHandle(); } -#if CUDA_VERSION >= 10010 - inline cublasLtHandle_t CublasLtHandle() const { - return provider_->PerThreadCublasLtHandle(); - } -#endif - inline cudnnHandle_t CudnnHandle() const { return provider_->PerThreadCudnnHandle(); } - inline curandGenerator_t CurandGenerator() const { return provider_->PerThreadCurandGenerator(); } - + template inline const T* GetConstOnes(size_t count) const { return provider_->template GetConstOnes(count); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 6e1b5b6eac..d4e5d39b15 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -50,9 +50,6 @@ thread_local std::unique_ptr CUDAExe CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, size_t cuda_mem_limit) { CUDA_CALL_THROW(cudaSetDevice(device_id)); CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_)); - #if CUDA_VERSION >= 10010 - CUBLAS_CALL_THROW(cublasLtCreate(&cublasLt_handle_)); - #endif CUDNN_CALL_THROW(cudnnCreate(&cudnn_handle_)); CURAND_CALL_THROW(curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); @@ -72,14 +69,6 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() { LOGS_DEFAULT(ERROR) << "cublasDestroy threw:" << ex.what(); } -#if CUDA_VERSION >= 10010 - try { - CUBLAS_CALL(cublasLtDestroy(cublasLt_handle_)); - } catch (const std::exception& ex) { - LOGS_DEFAULT(ERROR) << "cublasLtDestroy threw:" << ex.what(); - } -#endif - try { CUDNN_CALL(cudnnDestroy(cudnn_handle_)); } catch (const std::exception& ex) { @@ -209,6 +198,8 @@ void CUDAExecutionProvider::AddDeferredReleaseCPUPtr(void* p) { } Status CUDAExecutionProvider::OnRunStart() { + // always set CUDA device when session::Run() in case it runs in a worker thread + CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId())); auto cpu_alloc = GetAllocator(0, OrtMemTypeCPU); // check if cudaEvents has passed for deferred release // note that we need to take a mutex in case of multi-threaded Run() @@ -754,6 +745,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, int8_t, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, uint8_t, ReduceMin); + static void RegisterCudaKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index c4501ae23c..3f119fa379 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -44,12 +44,6 @@ class CUDAExecutionProvider : public IExecutionProvider { return GetPerThreadContext().CublasHandle(); } -#if CUDA_VERSION >= 10010 - cublasLtHandle_t PerThreadCublasLtHandle() { - return GetPerThreadContext().CublasLtHandle(); - } -#endif - cudnnHandle_t PerThreadCudnnHandle() { return GetPerThreadContext().CudnnHandle(); } @@ -101,12 +95,6 @@ class CUDAExecutionProvider : public IExecutionProvider { return cublas_handle_; } -#if CUDA_VERSION >= 10010 - cublasLtHandle_t CublasLtHandle() const { - return cublasLt_handle_; - } -#endif - cudnnHandle_t CudnnHandle() const { return cudnn_handle_; } @@ -147,9 +135,6 @@ class CUDAExecutionProvider : public IExecutionProvider { private: cublasHandle_t cublas_handle_ = nullptr; -#if CUDA_VERSION >= 10010 - cublasLtHandle_t cublasLt_handle_ = nullptr; -#endif cudnnHandle_t cudnn_handle_ = nullptr; curandGenerator_t curand_generator_ = nullptr; diff --git a/onnxruntime/core/providers/cuda/cuda_pch.h b/onnxruntime/core/providers/cuda/cuda_pch.h index 0c16cc0d17..3235505060 100644 --- a/onnxruntime/core/providers/cuda/cuda_pch.h +++ b/onnxruntime/core/providers/cuda/cuda_pch.h @@ -14,11 +14,6 @@ #include #include -// support of cublasLt starts 10.1 -#if CUDA_VERSION >= 10010 -#include -#endif - #ifdef USE_NCCL #include #endif diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index de8669cc6b..d0d50718ec 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -115,12 +115,12 @@ Status CudnnFilterDescriptor::Set(const std::vector& filter_dims, cudnn template cudnnDataType_t CudnnTensor::GetDataType() { ORT_THROW("cuDNN engine currently supports only single/double/half/int8/uint8 precision data types. Got:", - typeid(ElemType).name()); + typeid(ElemType).name()); // Not reachable but GCC complains return 0; } -template <> +template<> cudnnDataType_t CudnnTensor::GetDataType() { return CUDNN_DATA_FLOAT; } diff --git a/onnxruntime/core/providers/cuda/generator/constant_of_shape.cc b/onnxruntime/core/providers/cuda/generator/constant_of_shape.cc index 52d3450230..f598a0337b 100644 --- a/onnxruntime/core/providers/cuda/generator/constant_of_shape.cc +++ b/onnxruntime/core/providers/cuda/generator/constant_of_shape.cc @@ -29,11 +29,11 @@ Status ConstantOfShape::ComputeInternal(OpKernelContext* ctx) const { const void* value_ptr = GetValuePtr(); const auto element_size = output_tensor->DataType()->Size(); -#define CASE(TYPE) \ - case sizeof(TYPE): \ - if (size > 0) { \ - cuda::Fill(reinterpret_cast(output_data), *(reinterpret_cast(value_ptr)), size); \ - } \ +#define CASE(TYPE) \ + case sizeof(TYPE): \ + if (size > 0) { \ + cuda::Fill(reinterpret_cast(output_data), *(reinterpret_cast(value_ptr)), size); \ + } \ break; switch (element_size) { diff --git a/onnxruntime/core/providers/cuda/generator/constant_of_shape.h b/onnxruntime/core/providers/cuda/generator/constant_of_shape.h index 7443a1753e..caa41aedb0 100644 --- a/onnxruntime/core/providers/cuda/generator/constant_of_shape.h +++ b/onnxruntime/core/providers/cuda/generator/constant_of_shape.h @@ -15,7 +15,7 @@ namespace cuda { class ConstantOfShape final : public ConstantOfShapeBase, public CudaKernel { public: - explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), CudaKernel(info){}; + explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), CudaKernel(info) {}; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConstantOfShape); diff --git a/onnxruntime/core/providers/cuda/generator/range_impl.h b/onnxruntime/core/providers/cuda/generator/range_impl.h index 9ab5521022..684978d544 100644 --- a/onnxruntime/core/providers/cuda/generator/range_impl.h +++ b/onnxruntime/core/providers/cuda/generator/range_impl.h @@ -7,6 +7,7 @@ namespace onnxruntime { namespace cuda { + template bool RangeImpl(const T start, const T delta, const int count, T* output); diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc index 28430bd078..8fae7ae8b0 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc @@ -18,10 +18,8 @@ GPUDataTransfer::~GPUDataTransfer() { } bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || - src_device.MemType() == OrtDevice::MemType::CUDA_PINNED || - dst_device.Type() == OrtDevice::GPU || - dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED; + return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED + || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED; } common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const { diff --git a/onnxruntime/core/providers/cuda/igemm.cc b/onnxruntime/core/providers/cuda/igemm.cc deleted file mode 100644 index 68bc9f4b19..0000000000 --- a/onnxruntime/core/providers/cuda/igemm.cc +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "igemm.h" - -#include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/shared_inc/cuda_call.h" - -namespace onnxruntime { -namespace cuda { - -#if CUDA_VERSION >= 10010 -void LtIgemmTensor(int m, - int n, - int k, - int32_t alpha_matmul, - int32_t beta_matmul, - const int8_t* a, - int lda, - const int8_t* b, - int ldb, - int32_t* c, - int ldc, - const CudaKernel* cuda_kernel, - cublasLtHandle_t lt_handle) { - // Create descriptors for the original matrices - cublasLtMatrixLayout_t a_desc = nullptr; - cublasLtMatrixLayout_t b_desc = nullptr; - cublasLtMatrixLayout_t c_desc = nullptr; - CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&a_desc, CUDA_R_8I, m, k, lda)); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&b_desc, CUDA_R_8I, n, k, ldb)); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&c_desc, CUDA_R_32I, m, n, ldc)); - - // Set A and C row major order. - // No need for B because B need to be transposed - cublasLtOrder_t row_order = CUBLASLT_ORDER_ROW; - CUBLAS_CALL_THROW(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row_order, sizeof(row_order))); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &row_order, sizeof(row_order))); - - // The tensor operations IGEMM kernels require specialized memory order of data. - // Matrix A and Matrix C need to be in CUBLASLT_ORDER_COL32 order - // And Matric B needs to be in CUBLASLT_ORDER_COL4_4R2_8C order - - cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; - cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C; - - // For CUBLASLT_ORDER_COL32 order, Data is ordered in column-major ordered tiles of 32 columns. - // The leading dimension is the stride (in elements) to the beginning of next group of 32-columns. - - // For CUBLASLT_ORDER_COL4_4R2_8C, Data is ordered in column-major ordered tiles of composite tiles - // with total 32 columns and 8 rows. - // A tile is composed of interleaved inner tiles of 4 columns within 4 even or odd rows in an alternating pattern. - // The leading dimension is the stride (in elements) to the beginning of the first 32 column x 8 row tile - // for the next 32-wide group of columns. - int lda_transform = 32 * m; - int ldb_transform = 32 * roundoff(n, 8); - int ldc_transform = 32 * m; - - // Allocate memory for transform - IAllocatorUniquePtr a_transform = cuda_kernel->GetScratchBuffer(roundoff(k, 32) / 32 * lda_transform); - IAllocatorUniquePtr b_transform = cuda_kernel->GetScratchBuffer(roundoff(k, 32) / 32 * ldb_transform); - IAllocatorUniquePtr c_transform = cuda_kernel->GetScratchBuffer(roundoff(k, 32) / 32 * ldc_transform); - - // Create descriptors for the transformed matrices - cublasLtMatrixLayout_t a_transform_desc = nullptr; - cublasLtMatrixLayout_t b_transform_desc = nullptr; - cublasLtMatrixLayout_t c_transform_desc = nullptr; - CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&a_transform_desc, CUDA_R_8I, m, k, lda_transform)); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&b_transform_desc, CUDA_R_8I, n, k, ldb_transform)); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutCreate(&c_transform_desc, CUDA_R_32I, m, n, ldc_transform)); - - CUBLAS_CALL_THROW(cublasLtMatrixLayoutSetAttribute(a_transform_desc, - CUBLASLT_MATRIX_LAYOUT_ORDER, - &order_COL32, - sizeof(order_COL32))); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutSetAttribute(b_transform_desc, - CUBLASLT_MATRIX_LAYOUT_ORDER, - &order_COL4_4R2_8C, - sizeof(order_COL4_4R2_8C))); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutSetAttribute(c_transform_desc, - CUBLASLT_MATRIX_LAYOUT_ORDER, - &order_COL32, - sizeof(order_COL32))); - - cublasLtMatrixTransformDesc_t transform_desc = nullptr; - CUBLAS_CALL_THROW(cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F)); - - float alpha_transform = 1.0f; - float beta_transform = 0.0f; - CUBLAS_CALL_THROW(cublasLtMatrixTransform(lt_handle, - transform_desc, - &alpha_transform, - a, - a_desc, - &beta_transform, - nullptr, - nullptr, - a_transform.get(), - a_transform_desc, - 0)); - - CUBLAS_CALL_THROW(cublasLtMatrixTransform(lt_handle, - transform_desc, - &alpha_transform, - b, - b_desc, - &beta_transform, - nullptr, - nullptr, - b_transform.get(), - b_transform_desc, - 0)); - - if (beta_matmul == 1) { - CUBLAS_CALL_THROW(cublasLtMatrixTransform(lt_handle, - transform_desc, - &alpha_transform, - c, - c_desc, - &beta_transform, - nullptr, - nullptr, - c_transform.get(), - c_transform_desc, - 0)); - } - - // Tensor op igemm kernels only support NT gemm - cublasLtMatmulDesc_t matmul_desc = nullptr; - cublasOperation_t op_trans = CUBLAS_OP_T; - CUBLAS_CALL_THROW(cublasLtMatmulDescCreate(&matmul_desc, CUDA_R_32I)); - CUBLAS_CALL_THROW(cublasLtMatmulDescSetAttribute(matmul_desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &op_trans, - sizeof(op_trans))); - - CUBLAS_CALL_THROW(cublasLtMatmul(lt_handle, - matmul_desc, - &alpha_matmul, - a_transform.get(), - a_transform_desc, - b_transform.get(), - b_transform_desc, - &beta_matmul, - c_transform.get(), - c_transform_desc, - c_transform.get(), - c_transform_desc, - nullptr, - nullptr, - 0, - 0)); - - CUBLAS_CALL_THROW(cublasLtMatrixTransform(lt_handle, - transform_desc, - &alpha_transform, - c_transform.get(), - c_transform_desc, - &beta_transform, - nullptr, - nullptr, - c, - c_desc, - 0)); - - CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(c_transform_desc)); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(b_transform_desc)); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(a_transform_desc)); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(c_desc)); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(b_desc)); - CUBLAS_CALL_THROW(cublasLtMatrixLayoutDestroy(a_desc)); - CUBLAS_CALL_THROW(cublasLtMatmulDescDestroy(matmul_desc)); - CUBLAS_CALL_THROW(cublasLtMatrixTransformDescDestroy(transform_desc)); -} -#endif -} -} diff --git a/onnxruntime/core/providers/cuda/igemm.h b/onnxruntime/core/providers/cuda/igemm.h deleted file mode 100644 index 4910e36c30..0000000000 --- a/onnxruntime/core/providers/cuda/igemm.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/providers/cuda/cuda_common.h" - -namespace onnxruntime { -namespace cuda { - -inline int roundoff(int v, int d) { - return (v + d - 1) / d * d; -} - -#if CUDA_VERSION >= 10010 -void LtIgemmTensor(int m, - int n, - int k, - int32_t alpha_matmul, - int32_t beta_matmul, - const int8_t* a, - int lda, - const int8_t* b, - int ldb, - int32_t* c, - int ldc, - const CudaKernel* cuda_kernel, - cublasLtHandle_t lt_handle); -#endif -} -} \ No newline at end of file diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h index 452eec6cb2..f9ff9e4cfb 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h @@ -14,7 +14,7 @@ struct BinaryElementwisePreparation { const Tensor* lhs_tensor = nullptr; const Tensor* rhs_tensor = nullptr; Tensor* output_tensor = nullptr; - int32_t output_rank_or_simple_broadcast = 0; // for no_broadcast|left_scalar|right_scalar cases, output_rank uses SimpleBroadcast enums + int32_t output_rank_or_simple_broadcast = 0; // for no_broadcast|left_scalar|right_scalar cases, output_rank uses SimpleBroadcast enums TArray lhs_padded_strides; TArray rhs_padded_strides; @@ -42,8 +42,8 @@ struct BinaryElementwisePreparation { // early return if one operand is scalar if (lhs_shape.Size() == 1 || rhs_shape.Size() == 1) { output_rank_or_simple_broadcast = static_cast(lhs_shape.Size() == 1 - ? SimpleBroadcast::LeftScalar - : SimpleBroadcast::RightScalar); + ? SimpleBroadcast::LeftScalar + : SimpleBroadcast::RightScalar); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.h index 15a6ecd697..ac717524b1 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops_impl.h @@ -34,18 +34,18 @@ namespace cuda { // NOTE that cu files are compiled with nvcc and should not refer to any onnxruntime headers // so struct BinaryElementwisePreparation cannot be used here -#define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \ - template \ - void Impl_##name( \ - int32_t output_rank_or_simple_broadcast, \ - const TArray* lhs_padded_strides, \ - const T* lhs_data, \ - const TArray* rhs_padded_strides, \ - const T* rhs_data, \ +#define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \ + template \ + void Impl_##name( \ + int32_t output_rank_or_simple_broadcast, \ + const TArray* lhs_padded_strides, \ + const T* lhs_data, \ + const TArray* rhs_padded_strides, \ + const T* rhs_data, \ const TArray* fdm_output_strides, \ - const fast_divmod& fdm_H, \ - const fast_divmod& fdm_C, \ - T* output_data, \ + const fast_divmod& fdm_H, \ + const fast_divmod& fdm_C, \ + T* output_data, \ size_t count) #define BINARY_OP_NAME_EXPR(name, expr) BINARY_ELEMENTWISE_IMPL_DECLARATION(name); diff --git a/onnxruntime/core/providers/cuda/math/matmul_integer.cc b/onnxruntime/core/providers/cuda/math/matmul_integer.cc index 61164a1ddb..e4547da039 100644 --- a/onnxruntime/core/providers/cuda/math/matmul_integer.cc +++ b/onnxruntime/core/providers/cuda/math/matmul_integer.cc @@ -6,7 +6,6 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" #include "core/providers/cuda/cuda_allocator.h" -#include "core/providers/cuda/igemm.h" #include "core/providers/common.h" namespace onnxruntime { @@ -107,28 +106,6 @@ Status MatMulInteger::ComputeInternal(OpKernelContext* ctx) cons beta = 1; } -#if CUDA_VERSION >= 10010 - if (DeviceProp::GetDeviceProps().major >= 7 && DeviceProp::GetDeviceProps().minor >= 5) { - for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) { - LtIgemmTensor( - static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K()), - alpha, - beta, - a_ptr + helper.LeftOffsets()[batch], - static_cast(helper.K()), - b_ptr + helper.RightOffsets()[batch], - static_cast(helper.N()), - output_ptr + helper.OutputOffsets()[batch], - static_cast(helper.N()), - this, - Base::CublasLtHandle()); - } - return Status::OK(); - } -#endif - // pad A and B to make their leading dimension be multiples of 32 // because cublasGemmEx requires: // 1. leading dimension is multiples of 4 diff --git a/onnxruntime/core/providers/cuda/math/topk.cc b/onnxruntime/core/providers/cuda/math/topk.cc index a04bcd9b93..de78cd6833 100644 --- a/onnxruntime/core/providers/cuda/math/topk.cc +++ b/onnxruntime/core/providers/cuda/math/topk.cc @@ -10,7 +10,7 @@ namespace cuda { ONNX_OPERATOR_VERSIONED_KERNEL_EX( TopK, kOnnxDomain, - 1, 9, + 1,9, kCudaExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), TopK); @@ -18,7 +18,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( ONNX_OPERATOR_VERSIONED_KERNEL_EX( TopK, kOnnxDomain, - 10, 10, + 10,10, kCudaExecutionProvider, KernelDefBuilder().InputMemoryType(1).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), TopK); diff --git a/onnxruntime/core/providers/cuda/math/topk.h b/onnxruntime/core/providers/cuda/math/topk.h index caf6d02851..944c95e369 100644 --- a/onnxruntime/core/providers/cuda/math/topk.h +++ b/onnxruntime/core/providers/cuda/math/topk.h @@ -7,7 +7,7 @@ namespace onnxruntime { namespace cuda { -template +template class TopK final : public CudaKernel { public: TopK(const OpKernelInfo&); diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 4bb529d58d..ca3266c811 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -88,7 +88,7 @@ class lru_unordered_map { lru_list_.clear(); } - private: +private: using list_type = std::list; using iterator_type = typename list_type::iterator; struct value_type { @@ -126,12 +126,12 @@ struct CudnnConvState { CudnnConvolutionDescriptor conv_desc; struct PerfResultParams { - decltype(AlgoPerfType().algo) algo; - decltype(AlgoPerfType().memory) memory; + decltype(AlgoPerfType().algo) algo; + decltype(AlgoPerfType().memory) memory; decltype(AlgoPerfType().mathType) mathType; }; - lru_unordered_map, PerfResultParams, vector_hash> cached_benchmark_results{MAX_CACHED_ALGO_PERF_RESULTS}; + lru_unordered_map, PerfResultParams, vector_hash> cached_benchmark_results { MAX_CACHED_ALGO_PERF_RESULTS }; // note that conv objects are shared between execution frames, and a lock is needed to avoid multi-thread racing OrtMutex mutex; diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index 2b5830c5f3..367930d9b3 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -32,6 +32,7 @@ namespace cuda { KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).TypeConstraint("I", DataTypeImpl::GetTensorType()), \ Pool); + POOLING_KERNEL_VERSIONED(AveragePool, float, AveragePool, 7, 9) POOLING_KERNEL_VERSIONED(AveragePool, double, AveragePool, 7, 9) POOLING_KERNEL_VERSIONED(AveragePool, MLFloat16, AveragePool, 7, 9) @@ -63,6 +64,7 @@ POOLING_KERNEL(MaxPool, MLFloat16, MaxPool<8>, 12) POOLING_KERNEL(MaxPool, int8_t, MaxPool<8>, 12) POOLING_KERNEL(MaxPool, uint8_t, MaxPool<8>, 12) + POOLING_KERNEL(GlobalMaxPool, float, MaxPool<1>, 1) POOLING_KERNEL(GlobalMaxPool, double, MaxPool<1>, 1) POOLING_KERNEL(GlobalMaxPool, MLFloat16, MaxPool<1>, 1) @@ -165,8 +167,8 @@ Status Pool::ComputeInternal(OpKernelContext* context) const { cudnnPoolingMode_t mode = CUDNN_POOLING_MAX; if (PoolType::type == onnxruntime::PoolType::kAveragePool) { - mode = pool_attrs_.count_include_pad ? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING - : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; + mode = pool_attrs_.count_include_pad ? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; } CudnnPoolingDescriptor pooling_desc; ORT_RETURN_IF_ERROR(pooling_desc.Set(mode, kernel_shape, pads, strides)); diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign.cc b/onnxruntime/core/providers/cuda/object_detection/roialign.cc index 1d33617673..5ca757382f 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign.cc +++ b/onnxruntime/core/providers/cuda/object_detection/roialign.cc @@ -1,5 +1,5 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #include "roialign.h" #include "roialign_impl.h" @@ -15,7 +15,7 @@ namespace cuda { T, \ kCudaExecutionProvider, \ KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ RoiAlign); @@ -58,7 +58,8 @@ Status RoiAlign::ComputeInternal(OpKernelContext* context) const { num_roi_cols, reinterpret_cast::MappedType*>(Y.template MutableData()), this->mode_ == RoiAlignMode::avg, - batch_indices_ptr->template Data()); + batch_indices_ptr->template Data() + ); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign.h b/onnxruntime/core/providers/cuda/object_detection/roialign.h index 4ef23f0380..fdd0f95ccf 100644 --- a/onnxruntime/core/providers/cuda/object_detection/roialign.h +++ b/onnxruntime/core/providers/cuda/object_detection/roialign.h @@ -1,5 +1,5 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #pragma once diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 728bab9a8a..f080672777 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -59,6 +59,7 @@ namespace cuda { KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); + // CUDA's reduction descriptor cudnnReduceTensorDescriptor_t is a pointer so // it's safer to wrap it with automatically memory deleter as CudnnReduceDescriptor. // An implicit caster from CudnnReduceDescriptor to cudnnReduceTensorDescriptor_t diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index c0b5cb8a15..41a6c94c8c 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -91,7 +91,7 @@ class CudnnRnnBase : public CudaKernel { rnn_mode_ = CUDNN_LSTM; weight_cached_ = false; w_data_cache_ = nullptr; - + size_t state_size; cudnn_dropout_desc_.CreateDescriptorIfNeeded(); cudnn_dropout_desc_.GetCudnnDropoutStatesSize(CudnnHandle(), state_size); diff --git a/onnxruntime/core/providers/cuda/rnn/lstm.h b/onnxruntime/core/providers/cuda/rnn/lstm.h index b8f8465962..3ed12cfa7f 100644 --- a/onnxruntime/core/providers/cuda/rnn/lstm.h +++ b/onnxruntime/core/providers/cuda/rnn/lstm.h @@ -10,6 +10,7 @@ namespace cuda { template class LSTM final : public CudnnRnnBase { + public: LSTM(const OpKernelInfo& info) : CudnnRnnBase(info) { CudnnRnnBase::SetRNNMode(CUDNN_LSTM); diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h index 1c4fbc5a24..78ceabf23b 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h @@ -8,7 +8,7 @@ namespace onnxruntime { namespace cuda { -template +template void ReverseBySequence(const int32_t seq_length, const int32_t batch_size, const int32_t input_or_hidden_size, diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index f499a5d1c3..fa3955ce85 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -48,7 +48,7 @@ struct TArray { ORT_ENFORCE(size <= capacity, "TArray size was set to ", size, ", exeeding the capacity limit of ", capacity); } - TArray(const std::vector& vec) : TArray(static_cast(vec.size())) { + TArray(const std::vector& vec) : TArray(static_cast(vec.size())) { memcpy(data_, vec.data(), vec.size() * sizeof(T)); } diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index d56d4273f1..0ff4447927 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -79,7 +79,7 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, const double* beta, double* C, int ldc, long long int strideC, - int batch_count) { + int batch_count){ return cublasDgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batch_count); } diff --git a/onnxruntime/core/providers/cuda/tensor/expand.cc b/onnxruntime/core/providers/cuda/tensor/expand.cc index ea92651db5..2b7ffd463f 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.cc +++ b/onnxruntime/core/providers/cuda/tensor/expand.cc @@ -22,7 +22,8 @@ static void CalcEffectiveDims(vector& x_dims, vector& y_dims) if (xdim == ydim || xdim == 1) { x_reverse.push_back(xdim); y_reverse.push_back(ydim); - } else { // xdim < ydim && xdim > 1, split + } + else { // xdim < ydim && xdim > 1, split ydim /= xdim; x_reverse.push_back(xdim); y_reverse.push_back(xdim); @@ -43,15 +44,18 @@ static void CalcEffectiveDims(vector& x_dims, vector& y_dims) } if (x_dims.back() == 1) { y_dims.back() *= y_reverse[i]; - } else { + } + else { x_dims.push_back(1); y_dims.push_back(y_reverse[i]); } - } else { // x_reverse[i] == y_reverse[i] + } + else { // x_reverse[i] == y_reverse[i] if (x_dims.back() == y_dims.back()) { x_dims.back() *= x_reverse[i]; y_dims.back() *= y_reverse[i]; - } else { + } + else { x_dims.push_back(x_reverse[i]); y_dims.push_back(y_reverse[i]); } @@ -103,6 +107,7 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const { input_strides); } + ONNX_OPERATOR_KERNEL_EX( Expand, kOnnxDomain, diff --git a/onnxruntime/core/providers/cuda/tensor/expand_impl.h b/onnxruntime/core/providers/cuda/tensor/expand_impl.h index 77813a8b28..27d5d69d9c 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/expand_impl.h @@ -20,5 +20,6 @@ Status ExpandImpl( const TArray& output_strides, const TArray& input_strides); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/eye_like.cc b/onnxruntime/core/providers/cuda/tensor/eye_like.cc index 78c51d03b6..cdbbeb4fac 100644 --- a/onnxruntime/core/providers/cuda/tensor/eye_like.cc +++ b/onnxruntime/core/providers/cuda/tensor/eye_like.cc @@ -21,23 +21,25 @@ ONNX_OPERATOR_KERNEL_EX( DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .TypeConstraint("T2", - std::vector{ - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + DataTypeImpl::GetTensorType() + }) + .TypeConstraint("T2", + std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType() + }), EyeLike); -#define TYPED_FUNCTION_CALL(T) \ - EyeLikeImpl::MappedType>( \ - offset, \ - dim1 + 1, \ - reinterpret_cast::MappedType*>(T2->template MutableData()), \ - diag_count); \ - break; +#define TYPED_FUNCTION_CALL(T) \ + EyeLikeImpl::MappedType>( \ + offset, \ + dim1 + 1, \ + reinterpret_cast::MappedType *>(T2->template MutableData()), \ + diag_count); \ + break; Status EyeLike::ComputeInternal(OpKernelContext* context) const { const auto* T1 = context->Input(0); diff --git a/onnxruntime/core/providers/cuda/tensor/eye_like_impl.h b/onnxruntime/core/providers/cuda/tensor/eye_like_impl.h index daf0d2b686..f95ca63782 100644 --- a/onnxruntime/core/providers/cuda/tensor/eye_like_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/eye_like_impl.h @@ -12,11 +12,11 @@ namespace cuda { template void EyeLikeImpl( - size_t offset, // offset of first element in diagnal - size_t stripe, // stripe, here it's width + 1 - T* output_data, // output buffer - size_t diag_count // total number of elements in diagnal - ); + size_t offset, // offset of first element in diagnal + size_t stripe, // stripe, here it's width + 1 + T* output_data, // output buffer + size_t diag_count // total number of elements in diagnal +); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/identity_op.cc b/onnxruntime/core/providers/cuda/tensor/identity_op.cc index 3d5f456774..890bdf5cac 100644 --- a/onnxruntime/core/providers/cuda/tensor/identity_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/identity_op.cc @@ -11,8 +11,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 7, 9, kCudaExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .Alias(0, 0), IdentityOp); diff --git a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.h b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.h index 5288a64d18..7d55e83133 100644 --- a/onnxruntime/core/providers/cuda/tensor/nonzero_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/nonzero_impl.h @@ -10,20 +10,21 @@ namespace cuda { int NonZeroCalcBlockCount(int64_t x_size); -cudaError_t NonZeroCalcPrefixSumTempStorageBytes(int* prefix_counts, int number_of_blocks, size_t&); +cudaError_t NonZeroCalcPrefixSumTempStorageBytes(int* prefix_counts, int number_of_blocks, size_t& ); cudaError_t NonZeroInclusivePrefixSum(void* d_temp_storage, size_t temp_storage_bytes, int* prefix_counts, int number_of_blocks); -// count nonzero elements in each block into counts_in_blocks, +// count nonzero elements in each block into counts_in_blocks, // the counts_in_blocks buffer is pre-allocated on gpu first. -template +template cudaError_t NonZeroCountEachBlock(const InputT* x, int64_t x_size, int* counts_in_blocks); // output nonzero positions using input x and prefix_counts for each blocks -template +template cudaError_t NonZeroOutputPositions( - const InputT* x, int64_t x_size, int x_rank, const TArray& x_strides, + const InputT *x, int64_t x_size, int x_rank, const TArray& x_strides, const int* prefix_counts, int nonzero_elements, int64_t* results); } // namespace cuda } // namespace onnxruntime + diff --git a/onnxruntime/core/providers/cuda/tensor/nonzero_op.h b/onnxruntime/core/providers/cuda/tensor/nonzero_op.h index e19d432e54..1091e6fd9e 100644 --- a/onnxruntime/core/providers/cuda/tensor/nonzero_op.h +++ b/onnxruntime/core/providers/cuda/tensor/nonzero_op.h @@ -10,8 +10,8 @@ namespace onnxruntime { namespace cuda { template -class NonZero final : public CudaKernel { - public: +class NonZero final: public CudaKernel { +public: NonZero(const OpKernelInfo& info) : CudaKernel(info) {} Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/tensor/reverse_sequence.cc b/onnxruntime/core/providers/cuda/tensor/reverse_sequence.cc index 1a1f9afc24..f51d99c549 100644 --- a/onnxruntime/core/providers/cuda/tensor/reverse_sequence.cc +++ b/onnxruntime/core/providers/cuda/tensor/reverse_sequence.cc @@ -20,7 +20,7 @@ ONNX_OPERATOR_KERNEL_EX( ReverseSequenceOp); #define ReverseSequenceCallCudaImplTypeAs(T, TEqual) \ - if (X.IsDataType()) { \ + if (X.IsDataType()) { \ CUDA_RETURN_IF_ERROR(ReverseSequenceCudaImpl( \ reinterpret_cast::MappedType*>(X.template Data()), \ seq_lengths.Data(), \ diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h old mode 100644 new mode 100755 index 3cd55b8c9c..f70bf6b778 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h @@ -24,3 +24,4 @@ class ScatterElements final : public CudaKernel { } // namespace cuda } // namespace onnxruntime + diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.h b/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.h old mode 100644 new mode 100755 index 5eea6ab808..2f08d542e0 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements_impl.h @@ -26,3 +26,4 @@ Status ScatterElementsImpl( } // namespace cuda } // namespace onnxruntime + diff --git a/onnxruntime/core/providers/cuda/tensor/shape_op.cc b/onnxruntime/core/providers/cuda/tensor/shape_op.cc index be96ef1040..d1969785b4 100644 --- a/onnxruntime/core/providers/cuda/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/shape_op.cc @@ -14,7 +14,7 @@ ONNX_OPERATOR_KERNEL_EX( kCudaExecutionProvider, KernelDefBuilder() .OutputMemoryType(0) - .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); diff --git a/onnxruntime/core/providers/cuda/tensor/slice.h b/onnxruntime/core/providers/cuda/tensor/slice.h index c83b17c8fa..e4c301ad41 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.h +++ b/onnxruntime/core/providers/cuda/tensor/slice.h @@ -9,7 +9,7 @@ namespace onnxruntime { namespace cuda { -template +template class Slice final : public CudaKernel, public SliceBase { public: Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info, dynamic) {} diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.h b/onnxruntime/core/providers/cuda/tensor/split_impl.h index 4638aca421..fa07a68fb5 100644 --- a/onnxruntime/core/providers/cuda/tensor/split_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/split_impl.h @@ -9,7 +9,7 @@ namespace onnxruntime { namespace cuda { -Status SplitImpl(const size_t element_size, +Status SplitImpl(const size_t element_size, const int block_size_including_axis_dim, const int block_size_inside_axis_dim, const int64_t* split_sizes, diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc index 5031e57bba..a66a11bb43 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile.cc +++ b/onnxruntime/core/providers/cuda/tensor/tile.cc @@ -8,16 +8,16 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Tile, \ - kOnnxDomain, \ - 6, \ - T, \ - kCudaExecutionProvider, \ - KernelDefBuilder() \ - .InputMemoryType(1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Tile, \ + kOnnxDomain, \ + 6, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .InputMemoryType(1) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ Tile); diff --git a/onnxruntime/core/providers/cuda/tensor/where.cc b/onnxruntime/core/providers/cuda/tensor/where.cc index 19f889594f..3fc0b9272e 100644 --- a/onnxruntime/core/providers/cuda/tensor/where.cc +++ b/onnxruntime/core/providers/cuda/tensor/where.cc @@ -68,10 +68,10 @@ struct TernaryElementwisePreparation { const Tensor* a_tensor = nullptr; const Tensor* b_tensor = nullptr; const Tensor* c_tensor = nullptr; - size_t output_rank_or_simple_broadcast = 0; // for no_broadcast cases, output_rank uses SimpleBroadcast enums - TArray a_padded_strides; // for a shape == output shape, this is nullptr - TArray b_padded_strides; // for b shape == output shape, this is nullptr - TArray c_padded_strides; // for c shape == output shape, this is nullptr + size_t output_rank_or_simple_broadcast = 0; // for no_broadcast cases, output_rank uses SimpleBroadcast enums + TArray a_padded_strides; // for a shape == output shape, this is nullptr + TArray b_padded_strides; // for b shape == output shape, this is nullptr + TArray c_padded_strides; // for c shape == output shape, this is nullptr TArray fdm_output_strides; TernaryElementwisePreparation(const Tensor* a, const Tensor* b, const Tensor* c) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index a090837361..2d5db7de07 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -183,11 +183,10 @@ bool FindCycleHelper(int i, const std::list* adjacency_map, // Remove nodes with empty shape (for example [1, 0]) because TensorRT 7 doens't support empty shape SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph) { - // Here only NonZero and NonMaxSuppression related empty shape nodes are removed, particularly for Faster-rcnn and Mask-rcnn models. + // Here only NonZero, NonMaxSuppression and TopK related empty shape nodes are removed, particularly for RCNN models. // TODO: Remove the code if TensorRT fixed the issue in the future release, or find a better generic way here to work around const std::vector& node_index = graph.GetNodesInTopologicalOrder(); - const std::string exclude_dim_name1 = "NonZero"; - const std::string exclude_dim_name2 = "NonMaxSuppression"; + const std::vector exclude_dim_names{"NonZero", "NonMaxSuppression", "TopK"}; SubGraphCollection_t parser_nodes_vector = {{{}, false}}; std::vector nodes_vector(node_index.size()); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); @@ -201,8 +200,13 @@ SubGraphCollection_t RemoveEmptyShapeNodes(const onnxruntime::GraphViewer& graph for (const auto& dim : input_shape->dim()) { std::string dim_name = dim.dim_param(); if (!dim_name.empty()) { - if ((dim_name.find(exclude_dim_name1) != std::string::npos) || (dim_name.find(exclude_dim_name2) != std::string::npos)) { - exclude_node = true; + for (const auto& exclude : exclude_dim_names) { + if (dim_name.find(exclude) != std::string::npos) { + exclude_node = true; + break; + } + } + if (exclude_node) { break; } } @@ -260,7 +264,7 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } } - // For output searching, there is two special cases, + // For output searching, there are two special cases, // One is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once, // if the output is connected to nodes that don't belong to the subgraph, the output need to be added // to the output list @@ -322,11 +326,15 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph meta_def->domain = kMSDomain; for (const auto& input : inputs) { - meta_def->inputs.push_back(input.second->Name()); + if (input.second->Exists()) { + meta_def->inputs.push_back(input.second->Name()); + } } for (const auto& output : outputs) { - meta_def->outputs.push_back(output.second->Name()); + if (output.second->Exists()) { + meta_def->outputs.push_back(output.second->Name()); + } } meta_def->since_version = 1; @@ -385,6 +393,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect graph_build.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain()); } + // Add initializers to the subgraph + const auto& init_tensors = graph.GetAllInitializedTensors(); + for (const auto& tensor : init_tensors) { + graph_build.AddInitializedTensor(*(tensor.second)); + } + ORT_ENFORCE(graph_build.Resolve().IsOK()); // Add parent graph output to the subgraph @@ -400,13 +414,6 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect auto& graph_build_outputs = graph_build.GetOutputs(); subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end()); graph_build.SetOutputs(graph_build_outputs); - - // Add initializers to the subgraph - const auto& init_tensors = graph.GetAllInitializedTensors(); - for (const auto& tensor : init_tensors) { - graph_build.AddInitializedTensor(*(tensor.second)); - } - ORT_ENFORCE(graph_build.Resolve().IsOK()); // Check if input tensors have shapes diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 464524b016..571d4e99b1 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -39,16 +39,18 @@ static void RunAttentionTest( tester.AddInput("input", input_dims, ToFloat16(input_data)); tester.AddInput("weight", weights_dims, ToFloat16(weights_data)); tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); - tester.AddInput("mask_index", mask_index_dims, mask_index_data); tester.AddOutput("output", output_dims, ToFloat16(output_data)); } else { tester.AddInput("input", input_dims, input_data); tester.AddInput("weight", weights_dims, weights_data); tester.AddInput("bias", bias_dims, bias_data); - tester.AddInput("mask_index", mask_index_dims, mask_index_data); tester.AddOutput("output", output_dims, output_data); } + if (mask_index_data.size() > 0) { // mask index is optional. + tester.AddInput("mask_index", mask_index_dims, mask_index_data); + } + tester.Run(); } } @@ -204,5 +206,34 @@ TEST(AttentionTest, AttentionMaskExceedSequence) { batch_size, sequence_length, hidden_size, number_of_heads); } +TEST(AttentionTest, AttentionNoMaskIndex) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // No mask_index + std::vector mask_index_data = {}; + + std::vector output_data = { + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f}; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 1ea8bc8084..104ef377c3 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -246,6 +246,31 @@ TEST_F(GraphTest, SimpleUnique) { std::shared_ptr model; ASSERT_STATUS_OK(Model::Load(std::move(m), model, nullptr, *logger_)); } + +TEST_F(GraphTest, UnusedValueInfoSerializes) { + ModelProto m; + m.set_ir_version(4); + ImportOpset(m, "", 11); + GraphProto& g = *m.mutable_graph(); + NodeProto* node = g.add_node(); + *node->add_input() = "x"; + *node->add_output() = "sum"; + node->set_op_type("Unique"); + node->set_domain(""); + ValueInfoProto* input1 = g.add_input(); + input1->set_name("x"); + SetTypeAndShape(input1->mutable_type()->mutable_tensor_type(), 1, {3, 4, 5}); + ValueInfoProto* output = g.add_output(); + output->set_name("sum"); + SetTypeAndShape(output->mutable_type()->mutable_tensor_type(), 1, {60}); + ValueInfoProto* unused = g.add_value_info(); + unused->set_name("unused"); + SetTypeAndShape(unused->mutable_type()->mutable_tensor_type(), 1, {123}); + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(std::move(m), model, nullptr, *logger_)); + model->MainGraph().SetGraphProtoSyncNeeded(); + EXPECT_TRUE(Model::Save(*model, "graph_with_unused_value_info.onnx").IsOK()); +} TEST_F(GraphTest, WrongOpset) { ModelProto m; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index e071c93b4b..d722034073 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -31,6 +31,7 @@ namespace perftest { "\t\tProvide 'duration' to run the test for a fix duration, and 'times' to repeated for a certain times. \n" "\t-M: Disable memory pattern.\n" "\t-A: Disable memory arena\n" + "\t-I: Generate tensor input binding (Free dimensions are treated as 1.)\n" "\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n" "\t-e [cpu|cuda|dnnl|tensorrt|ngraph|openvino|nuphar|dml|acl]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', " "'ngraph', 'openvino', 'nuphar', 'dml' or 'acl'. " @@ -52,7 +53,7 @@ namespace perftest { /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:o:u:AMPvhs"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:o:u:AMPIvhs"))) != -1) { switch (ch) { case 'm': if (!CompareCString(optarg, ORT_TSTR("duration"))) { @@ -170,6 +171,9 @@ namespace perftest { case 'u': test_config.run_config.optimized_model_path = optarg; break; + case 'I': + test_config.run_config.generate_model_input_binding = true; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index f5b74e19d8..331a899b4e 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -136,5 +136,30 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } } +bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData() +{ + // iterate over all input nodes + for (size_t i = 0; i < static_cast(input_length_); i++) { + Ort::TypeInfo type_info = session_.GetInputTypeInfo(i); + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + if (type_info.GetONNXType() == ONNX_TYPE_TENSOR) { + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + std::vector input_node_dim = tensor_info.GetShape(); + + // free dimensions are treated as 1 + for (int64_t& dim : input_node_dim) { + if (dim == -1) { + dim = 1; + } + } + // default allocator doesn't have to be freed by user + auto allocator = static_cast(Ort::AllocatorWithDefaultOptions()); + Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(), input_node_dim.size(), tensor_info.GetElementType()); + PreLoadTestData(0, i, input_tensor.release()); + } + } + return true; +} + } // namespace perftest } // namespace onnxruntime diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index e71b5ad86f..7cfbe8ea6e 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -25,6 +25,8 @@ class OnnxRuntimeTestSession : public TestSession { test_inputs_[test_data_id][input_id] = Ort::Value{value}; } + bool PopulateGeneratedInputTestData(); + ~OnnxRuntimeTestSession() override { for (char* p : input_names_) { free(p); diff --git a/onnxruntime/test/perftest/performance_runner.cc b/onnxruntime/test/perftest/performance_runner.cc index 6e539162bf..f8c2aa04f6 100644 --- a/onnxruntime/test/perftest/performance_runner.cc +++ b/onnxruntime/test/perftest/performance_runner.cc @@ -218,6 +218,11 @@ bool PerformanceRunner::Initialize() { test_case_.reset(CreateOnnxTestCase(narrow_model_name, test_model_info_, 0.0, 0.0)); + if (performance_test_config_.run_config.generate_model_input_binding) + { + return static_cast(session_.get())->PopulateGeneratedInputTestData(); + } + // TODO: Place input tensor on cpu memory if dnnl provider type to avoid CopyTensor logic in CopyInputAcrossDevices size_t test_data_count = test_case_->GetDataCount(); if (test_data_count == 0) { diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 36186e129f..9554f12486 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -44,6 +44,7 @@ struct RunConfig { bool f_verbose{false}; bool enable_memory_pattern{true}; bool enable_cpu_mem_arena{true}; + bool generate_model_input_binding{false}; ExecutionMode execution_mode{ExecutionMode::ORT_SEQUENTIAL}; int intra_op_num_threads{0}; int inter_op_num_threads{0}; diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux1 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux1 index 5d6dd4da94..9e2089cc24 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux1 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux1 @@ -1,4 +1,4 @@ -FROM quay.io/pypa/manylinux1_x86_64:latest +FROM quay.io/pypa/manylinux1_x86_64:2020-04-01-7a4ddf4 ARG PYTHON_VERSION ADD scripts /tmp/scripts