Another build test

This commit is contained in:
Ryan Hill 2021-04-19 14:26:38 -07:00
parent 8159924918
commit 03bdd1cf62
4 changed files with 20 additions and 17 deletions

View file

@ -11,8 +11,6 @@
namespace onnxruntime {
namespace contrib {
void Link_embed_layer_norm() {}
// These ops are internal-only, so register outside of onnx
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \

View file

@ -5,9 +5,22 @@
#include "core/framework/tensorprotoutils.h"
#include "onnx/defs/tensor_proto_util.h"
#include "longformer_attention_base.h"
namespace onnxruntime {
namespace contrib {
void Link_embed_layer_norm_helper() {}
Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p,
const TensorShape& input_shape,
const TensorShape& weights_shape,
const TensorShape& bias_shape,
const TensorShape& mask_shape,
const TensorShape& global_weights_shape,
const TensorShape& global_bias_shape,
const TensorShape& global_shape) {
return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape);
}
namespace embed_layer_norm {
Status CheckInputs(const OpKernelContext* context) {

View file

@ -1,13 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
namespace onnxruntime {
namespace contrib {
void Link_longformer_attention_base() {}
} // namespace contrib
} // namespace onnxruntime
#include "longformer_attention_base.h"
namespace onnxruntime {

View file

@ -45,6 +45,11 @@
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
#include "contrib_ops/cpu/bert/embed_layer_norm_helper.h"
#include "contrib_ops/cpu/bert/longformer_attention_base.h"
namespace onnxruntime {
namespace contrib {
Status LongformerAttentionBase__CheckInputs(const LongformerAttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const TensorShape& mask_shape, const TensorShape& global_weights_shape, const TensorShape& global_bias_shape, const TensorShape& global_shape);
}
} // namespace onnxruntime
#include "contrib_ops/cpu/bert/attention_base.h"
#endif
@ -98,8 +103,6 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef;
namespace onnxruntime {
namespace contrib {
void Link_embed_layer_norm();
void Link_longformer_attention_base();
void Link_embed_layer_norm_helper();
} // namespace contrib
@ -832,11 +835,7 @@ struct ProviderHostImpl : ProviderHost {
Status embed_layer_norm__CheckInputs(const OpKernelContext* context) override { return contrib::embed_layer_norm::CheckInputs(context); }
Status bias_gelu_helper__CheckInputs(const OpKernelContext* context) override { return contrib::bias_gelu_helper::CheckInputs(context); }
Status LongformerAttentionBase__CheckInputs(const contrib::LongformerAttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const TensorShape& mask_shape, const TensorShape& global_weights_shape, const TensorShape& global_bias_shape, const TensorShape& global_shape) override {
contrib::Link_embed_layer_norm();
contrib::Link_longformer_attention_base();
contrib::Link_embed_layer_norm_helper();
return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape);
return contrib::LongformerAttentionBase__CheckInputs(p, input_shape, weights_shape, bias_shape, mask_shape, global_weights_shape, global_bias_shape, global_shape);
}
Status AttentionBase__CheckInputs(const contrib::AttentionBase* p, const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const Tensor*& mask_index, const Tensor* past, const int max_threads_per_block) override { return p->CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, max_threads_per_block); }
Tensor* AttentionBase__GetPresent(const contrib::AttentionBase* p, OpKernelContext* context, const Tensor* past, int batch_size, int head_size, int sequence_length, int& past_sequence_length) override { return p->GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); }