mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Revert "Remove LongformerAttentionBase workaround"
This reverts commit 648679b370.
This commit is contained in:
parent
567133235d
commit
7bdf68ffa2
2 changed files with 19 additions and 2 deletions
|
|
@ -9,6 +9,18 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p,
|
||||
const TensorShape& input_shape,
|
||||
const TensorShape& weights_shape,
|
||||
const TensorShape& bias_shape,
|
||||
const TensorShape& mask_shape,
|
||||
const TensorShape& global_weights_shape,
|
||||
const TensorShape& global_bias_shape,
|
||||
const TensorShape& global_shape) {
|
||||
return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape);
|
||||
}
|
||||
|
||||
namespace embed_layer_norm {
|
||||
|
||||
Status CheckInputs(const OpKernelContext* context) {
|
||||
|
|
|
|||
|
|
@ -46,6 +46,11 @@
|
|||
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
|
||||
#include "contrib_ops/cpu/bert/embed_layer_norm_helper.h"
|
||||
#include "contrib_ops/cpu/bert/longformer_attention_base.h"
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const TensorShape& mask_shape, const TensorShape& global_weights_shape, const TensorShape& global_bias_shape, const TensorShape& global_shape);
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
#endif
|
||||
|
||||
|
|
@ -834,7 +839,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
Status embed_layer_norm__CheckInputs(const OpKernelContext* context) override { return contrib::embed_layer_norm::CheckInputs(context); }
|
||||
Status bias_gelu_helper__CheckInputs(const OpKernelContext* context) override { return contrib::bias_gelu_helper::CheckInputs(context); }
|
||||
Status LongformerAttentionBase__CheckInputs(const contrib::LongformerAttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const TensorShape& mask_shape, const TensorShape& global_weights_shape, const TensorShape& global_bias_shape, const TensorShape& global_shape) override {
|
||||
return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape);
|
||||
return contrib::LongformerAttentionBase__CheckInputs(p, input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape);
|
||||
}
|
||||
Status AttentionBase__CheckInputs(const contrib::AttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, const int max_threads_per_block) override { return p->contrib::AttentionBase::CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, max_threads_per_block); }
|
||||
Tensor* AttentionBase__GetPresent(const contrib::AttentionBase* p, OpKernelContext* context, const Tensor* past, int batch_size, int head_size, int sequence_length, int& past_sequence_length) override { return p->contrib::AttentionBase::GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); }
|
||||
|
|
@ -948,7 +953,7 @@ struct ProviderLibrary {
|
|||
static ProviderLibrary s_library_cuda(LIBRARY_PREFIX "onnxruntime_providers_cuda" LIBRARY_EXTENSION
|
||||
#ifndef _WIN32
|
||||
,
|
||||
false /* unload - On Linux if we unload the cuda shared provider we crash. On Windows we'll crash if we don't */
|
||||
false /* unload - On Linux if we unload the cuda shared provider we crash */
|
||||
#endif
|
||||
);
|
||||
static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);
|
||||
|
|
|
|||
Loading…
Reference in a new issue