mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Another build test
This commit is contained in:
parent
8159924918
commit
03bdd1cf62
4 changed files with 20 additions and 17 deletions
|
|
@ -11,8 +11,6 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
void Link_embed_layer_norm() {}
|
||||
|
||||
// These ops are internal-only, so register outside of onnx
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
|
|
|
|||
|
|
@ -5,9 +5,22 @@
|
|||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "onnx/defs/tensor_proto_util.h"
|
||||
|
||||
#include "longformer_attention_base.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
void Link_embed_layer_norm_helper() {}
|
||||
|
||||
Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p,
|
||||
const TensorShape& input_shape,
|
||||
const TensorShape& weights_shape,
|
||||
const TensorShape& bias_shape,
|
||||
const TensorShape& mask_shape,
|
||||
const TensorShape& global_weights_shape,
|
||||
const TensorShape& global_bias_shape,
|
||||
const TensorShape& global_shape) {
|
||||
return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape);
|
||||
}
|
||||
|
||||
namespace embed_layer_norm {
|
||||
|
||||
Status CheckInputs(const OpKernelContext* context) {
|
||||
|
|
|
|||
|
|
@ -1,13 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
void Link_longformer_attention_base() {}
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#include "longformer_attention_base.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -45,6 +45,11 @@
|
|||
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
|
||||
#include "contrib_ops/cpu/bert/embed_layer_norm_helper.h"
|
||||
#include "contrib_ops/cpu/bert/longformer_attention_base.h"
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const TensorShape& mask_shape, const TensorShape& global_weights_shape, const TensorShape& global_bias_shape, const TensorShape& global_shape);
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
#endif
|
||||
|
||||
|
|
@ -98,8 +103,6 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef;
|
|||
namespace onnxruntime {
|
||||
|
||||
namespace contrib {
|
||||
void Link_embed_layer_norm();
|
||||
void Link_longformer_attention_base();
|
||||
void Link_embed_layer_norm_helper();
|
||||
|
||||
} // namespace contrib
|
||||
|
|
@ -832,11 +835,7 @@ struct ProviderHostImpl : ProviderHost {
|
|||
Status embed_layer_norm__CheckInputs(const OpKernelContext* context) override { return contrib::embed_layer_norm::CheckInputs(context); }
|
||||
Status bias_gelu_helper__CheckInputs(const OpKernelContext* context) override { return contrib::bias_gelu_helper::CheckInputs(context); }
|
||||
Status LongformerAttentionBase__CheckInputs(const contrib::LongformerAttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const TensorShape& mask_shape, const TensorShape& global_weights_shape, const TensorShape& global_bias_shape, const TensorShape& global_shape) override {
|
||||
contrib::Link_embed_layer_norm();
|
||||
contrib::Link_longformer_attention_base();
|
||||
contrib::Link_embed_layer_norm_helper();
|
||||
|
||||
return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape);
|
||||
return contrib::LongformerAttentionBase__CheckInputs(p, input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape);
|
||||
}
|
||||
Status AttentionBase__CheckInputs(const contrib::AttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, const int max_threads_per_block) override { return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, max_threads_per_block); }
|
||||
Tensor* AttentionBase__GetPresent(const contrib::AttentionBase* p, OpKernelContext* context, const Tensor* past, int batch_size, int head_size, int sequence_length, int& past_sequence_length) override { return p->GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); }
|
||||
|
|
|
|||
Loading…
Reference in a new issue