From 648679b37093740bbefc7b41ff53d3239451333f Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Fri, 7 May 2021 18:11:06 -0700 Subject: [PATCH] Remove LongformerAttentionBase workaround --- .../contrib_ops/cpu/bert/embed_layer_norm_helper.cc | 12 ------------ onnxruntime/core/framework/provider_bridge_ort.cc | 9 ++------- 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc index e991502745..a656197995 100644 --- a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc +++ b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc @@ -9,18 +9,6 @@ namespace onnxruntime { namespace contrib { - -Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p, - const TensorShape& input_shape, - const TensorShape& weights_shape, - const TensorShape& bias_shape, - const TensorShape& mask_shape, - const TensorShape& global_weights_shape, - const TensorShape& global_bias_shape, - const TensorShape& global_shape) { - return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape); -} - namespace embed_layer_norm { Status CheckInputs(const OpKernelContext* context) { diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc index 84067dea51..1cd612bb34 100644 --- a/onnxruntime/core/framework/provider_bridge_ort.cc +++ b/onnxruntime/core/framework/provider_bridge_ort.cc @@ -46,11 +46,6 @@ #include "contrib_ops/cpu/bert/bias_gelu_helper.h" #include "contrib_ops/cpu/bert/embed_layer_norm_helper.h" #include "contrib_ops/cpu/bert/longformer_attention_base.h" -namespace onnxruntime { -namespace contrib { -Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const TensorShape& mask_shape, const TensorShape& global_weights_shape, const TensorShape& global_bias_shape, const TensorShape& global_shape); -} -} // namespace onnxruntime #include "contrib_ops/cpu/bert/attention_base.h" #endif @@ -839,7 +834,7 @@ struct ProviderHostImpl : ProviderHost { Status embed_layer_norm__CheckInputs(const OpKernelContext* context) override { return contrib::embed_layer_norm::CheckInputs(context); } Status bias_gelu_helper__CheckInputs(const OpKernelContext* context) override { return contrib::bias_gelu_helper::CheckInputs(context); } Status LongformerAttentionBase__CheckInputs(const contrib::LongformerAttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const TensorShape& mask_shape, const TensorShape& global_weights_shape, const TensorShape& global_bias_shape, const TensorShape& global_shape) override { - return contrib::LongformerAttentionBase__CheckInputs(p, input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape); + return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape); } Status AttentionBase__CheckInputs(const contrib::AttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, const int max_threads_per_block) override { return p->contrib::AttentionBase::CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, max_threads_per_block); } Tensor* AttentionBase__GetPresent(const contrib::AttentionBase* p, OpKernelContext* context, const Tensor* past, int batch_size, int head_size, int sequence_length, int& past_sequence_length) override { return p->contrib::AttentionBase::GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); } @@ -953,7 +948,7 @@ struct ProviderLibrary { static ProviderLibrary s_library_cuda(LIBRARY_PREFIX "onnxruntime_providers_cuda" LIBRARY_EXTENSION #ifndef _WIN32 , - false /* unload - On Linux if we unload the cuda shared provider we crash */ + false /* unload - On Linux if we unload the cuda shared provider we crash. On Windows we'll crash if we don't */ #endif ); static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);