Revert "Remove LongformerAttentionBase workaround"

This reverts commit 648679b370.
This commit is contained in:
Ryan Hill 2021-05-10 14:56:05 -07:00
parent 567133235d
commit 7bdf68ffa2
2 changed files with 19 additions and 2 deletions

View file

@ -9,6 +9,18 @@
namespace onnxruntime {
namespace contrib {
// Free-function shim around LongformerAttentionBase::CheckInputs.
// Every shape argument is handed to the member function on `p` unchanged and
// in the same order, and the Status it produces is returned to the caller.
Status LongformerAttentionBase__CheckInputs(
    const LongformerAttentionBase* p,
    const TensorShape& input_shape,
    const TensorShape& weights_shape,
    const TensorShape& bias_shape,
    const TensorShape& mask_shape,
    const TensorShape& global_weights_shape,
    const TensorShape& global_bias_shape,
    const TensorShape& global_shape) {
  return p->CheckInputs(input_shape,
                        weights_shape,
                        bias_shape,
                        mask_shape,
                        global_weights_shape,
                        global_bias_shape,
                        global_shape);
}
namespace embed_layer_norm {
Status CheckInputs(const OpKernelContext* context) {

View file

@ -46,6 +46,11 @@
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
#include "contrib_ops/cpu/bert/embed_layer_norm_helper.h"
#include "contrib_ops/cpu/bert/longformer_attention_base.h"
namespace onnxruntime {
namespace contrib {
// Declaration only: the free-function form of the Longformer input check.
// No body is provided here; the definition is expected to be supplied by
// another translation unit.
Status LongformerAttentionBase__CheckInputs(
    const LongformerAttentionBase* p,
    const TensorShape& input_shape,
    const TensorShape& weights_shape,
    const TensorShape& bias_shape,
    const TensorShape& mask_shape,
    const TensorShape& global_weights_shape,
    const TensorShape& global_bias_shape,
    const TensorShape& global_shape);
}  // namespace contrib
}  // namespace onnxruntime
#include "contrib_ops/cpu/bert/attention_base.h"
#endif
@ -834,7 +839,7 @@ struct ProviderHostImpl : ProviderHost {
// Provider-host entry point that delegates to contrib::embed_layer_norm::CheckInputs.
Status embed_layer_norm__CheckInputs(const OpKernelContext* context) override {
  return contrib::embed_layer_norm::CheckInputs(context);
}
// Provider-host entry point that delegates to contrib::bias_gelu_helper::CheckInputs.
Status bias_gelu_helper__CheckInputs(const OpKernelContext* context) override {
  return contrib::bias_gelu_helper::CheckInputs(context);
}
// Provider-host entry point for the Longformer attention input-shape check on `p`.
// NOTE(review): the body below contains two consecutive return statements, so as
// written only the first (the direct p->CheckInputs call) can execute and the
// second (the contrib::LongformerAttentionBase__CheckInputs free function) is
// unreachable. This looks like the old and new lines of a change were both kept;
// confirm which call is intended and remove the other.
Status LongformerAttentionBase__CheckInputs(const contrib::LongformerAttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const TensorShape& mask_shape, const TensorShape& global_weights_shape, const TensorShape& global_bias_shape, const TensorShape& global_shape) override {
return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape);
return contrib::LongformerAttentionBase__CheckInputs(p, input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape);
}
// Provider-host entry point for the attention input check on `p`.
// The call is explicitly qualified with contrib::AttentionBase::, so it is the
// AttentionBase implementation of CheckInputs that gets invoked.
Status AttentionBase__CheckInputs(const contrib::AttentionBase* p,
                                  const TensorShape& input_shape,
                                  const TensorShape& weights_shape,
                                  const TensorShape& bias_shape,
                                  const Tensor*& mask_index,
                                  const Tensor* past,
                                  const int max_threads_per_block) override {
  return p->contrib::AttentionBase::CheckInputs(input_shape,
                                                weights_shape,
                                                bias_shape,
                                                mask_index,
                                                past,
                                                max_threads_per_block);
}
// Provider-host entry point that forwards to AttentionBase::GetPresent on `p`.
// The call is explicitly qualified with contrib::AttentionBase::, and
// `past_sequence_length` is passed through by reference.
Tensor* AttentionBase__GetPresent(const contrib::AttentionBase* p,
                                  OpKernelContext* context,
                                  const Tensor* past,
                                  int batch_size,
                                  int head_size,
                                  int sequence_length,
                                  int& past_sequence_length) override {
  return p->contrib::AttentionBase::GetPresent(context,
                                               past,
                                               batch_size,
                                               head_size,
                                               sequence_length,
                                               past_sequence_length);
}
@ -948,7 +953,7 @@ struct ProviderLibrary {
// File-scope ProviderLibrary objects for the CUDA and DNNL provider shared
// libraries; each library file name is assembled from LIBRARY_PREFIX, the
// provider name and LIBRARY_EXTENSION. On non-Windows builds the CUDA entry
// passes an additional `false` argument, annotated as "unload".
// NOTE(review): two `false /* unload ... */` lines appear back to back with no
// separator between them, which would not compile when _WIN32 is undefined.
// They look like the old and new wording of a single changed line; keep
// exactly one of them.
static ProviderLibrary s_library_cuda(LIBRARY_PREFIX "onnxruntime_providers_cuda" LIBRARY_EXTENSION
#ifndef _WIN32
,
false /* unload - On Linux if we unload the cuda shared provider we crash. On Windows we'll crash if we don't */
false /* unload - On Linux if we unload the cuda shared provider we crash */
#endif
);
static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);