diff --git a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc index a656197995..e991502745 100644 --- a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc +++ b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc @@ -9,6 +9,18 @@ namespace onnxruntime { namespace contrib { + +Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p, + const TensorShape& input_shape, + const TensorShape& weights_shape, + const TensorShape& bias_shape, + const TensorShape& mask_shape, + const TensorShape& global_weights_shape, + const TensorShape& global_bias_shape, + const TensorShape& global_shape) { + return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape); +} + namespace embed_layer_norm { Status CheckInputs(const OpKernelContext* context) { diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc index 1cd612bb34..84067dea51 100644 --- a/onnxruntime/core/framework/provider_bridge_ort.cc +++ b/onnxruntime/core/framework/provider_bridge_ort.cc @@ -46,6 +46,11 @@ #include "contrib_ops/cpu/bert/bias_gelu_helper.h" #include "contrib_ops/cpu/bert/embed_layer_norm_helper.h" #include "contrib_ops/cpu/bert/longformer_attention_base.h" +namespace onnxruntime { +namespace contrib { +Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const TensorShape& mask_shape, const TensorShape& global_weights_shape, const TensorShape& global_bias_shape, const TensorShape& global_shape); +} +} // namespace onnxruntime #include "contrib_ops/cpu/bert/attention_base.h" #endif @@ -834,7 +839,7 @@ struct ProviderHostImpl : ProviderHost { Status embed_layer_norm__CheckInputs(const OpKernelContext* context) override { return contrib::embed_layer_norm::CheckInputs(context); } Status bias_gelu_helper__CheckInputs(const OpKernelContext* context) override { return contrib::bias_gelu_helper::CheckInputs(context); } Status LongformerAttentionBase__CheckInputs(const contrib::LongformerAttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const TensorShape& mask_shape, const TensorShape& global_weights_shape, const TensorShape& global_bias_shape, const TensorShape& global_shape) override { - return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape); + return contrib::LongformerAttentionBase__CheckInputs(p, input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape); } Status AttentionBase__CheckInputs(const contrib::AttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, const int max_threads_per_block) override { return p->contrib::AttentionBase::CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, max_threads_per_block); } Tensor* AttentionBase__GetPresent(const contrib::AttentionBase* p, OpKernelContext* context, const Tensor* past, int batch_size, int head_size, int sequence_length, int& past_sequence_length) override { return p->contrib::AttentionBase::GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); } @@ -948,7 +953,7 @@ struct ProviderLibrary { static ProviderLibrary s_library_cuda(LIBRARY_PREFIX "onnxruntime_providers_cuda" LIBRARY_EXTENSION #ifndef _WIN32 , - false /* unload - On Linux if we unload the cuda shared provider we crash. On Windows we'll crash if we don't */ + false /* unload - On Linux if we unload the cuda shared provider we crash */ #endif ); static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);