mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
Add initializer for embed layer norm unit tests. (#8196)
This commit is contained in:
parent
9ec0fd6a1c
commit
507d97b200
2 changed files with 62 additions and 35 deletions
|
|
@ -64,19 +64,28 @@ static void RunTest(const embedlayernorm::OpData& data,
|
|||
if (use_float16) {
|
||||
tester.AddInput<MLFloat16>("word_embedding",
|
||||
word_embedding_dims,
|
||||
ToFloat16(data.word_embedding_data));
|
||||
ToFloat16(data.word_embedding_data),
|
||||
/*is_initializer=*/true);
|
||||
tester.AddInput<MLFloat16>("position_embedding",
|
||||
position_embedding_dims,
|
||||
ToFloat16(data.position_embedding_data));
|
||||
ToFloat16(data.position_embedding_data),
|
||||
/*is_initializer=*/true);
|
||||
if (!data.has_segment) {
|
||||
tester.AddMissingOptionalInput<MLFloat16>();
|
||||
} else {
|
||||
tester.AddInput<MLFloat16>("segment_embedding",
|
||||
segment_embedding_dims,
|
||||
ToFloat16(data.segment_embedding_data));
|
||||
ToFloat16(data.segment_embedding_data),
|
||||
/*is_initializer=*/true);
|
||||
}
|
||||
tester.AddInput<MLFloat16>("gamma", gamma_dims, ToFloat16(data.gamma_data));
|
||||
tester.AddInput<MLFloat16>("beta", beta_dims, ToFloat16(data.beta_data));
|
||||
tester.AddInput<MLFloat16>("gamma",
|
||||
gamma_dims,
|
||||
ToFloat16(data.gamma_data),
|
||||
/*is_initializer=*/true);
|
||||
tester.AddInput<MLFloat16>("beta",
|
||||
beta_dims,
|
||||
ToFloat16(data.beta_data),
|
||||
/*is_initializer=*/true);
|
||||
tester.AddAttribute("epsilon", data.epsilon);
|
||||
if (data.has_mask) {
|
||||
tester.AddInput<int32_t>("mask", mask_dims, data.mask_data);
|
||||
|
|
@ -85,19 +94,22 @@ static void RunTest(const embedlayernorm::OpData& data,
|
|||
} else {
|
||||
tester.AddInput<float>("word_embedding",
|
||||
word_embedding_dims,
|
||||
data.word_embedding_data);
|
||||
data.word_embedding_data,
|
||||
/*is_initializer=*/true);
|
||||
tester.AddInput<float>("position_embedding",
|
||||
position_embedding_dims,
|
||||
data.position_embedding_data);
|
||||
data.position_embedding_data,
|
||||
/*is_initializer=*/true);
|
||||
if (!data.has_segment) {
|
||||
tester.AddMissingOptionalInput<MLFloat16>();
|
||||
} else {
|
||||
tester.AddInput<float>("segment_embedding",
|
||||
segment_embedding_dims,
|
||||
data.segment_embedding_data);
|
||||
data.segment_embedding_data,
|
||||
/*is_initializer=*/true);
|
||||
}
|
||||
tester.AddInput<float>("gamma", gamma_dims, data.gamma_data);
|
||||
tester.AddInput<float>("beta", beta_dims, data.beta_data);
|
||||
tester.AddInput<float>("gamma", gamma_dims, data.gamma_data, /*is_initializer=*/true);
|
||||
tester.AddInput<float>("beta", beta_dims, data.beta_data, /*is_initializer=*/true);
|
||||
tester.AddAttribute("epsilon", data.epsilon);
|
||||
if (data.has_mask) {
|
||||
tester.AddInput<int32_t>("mask", mask_dims, data.mask_data);
|
||||
|
|
|
|||
|
|
@ -13,22 +13,22 @@ namespace test {
|
|||
namespace {
|
||||
|
||||
static void RunTest(const embedlayernorm::OpData& data,
|
||||
float accuracy_threshold = 0.25f) {
|
||||
float accuracy_threshold = 0.25f) {
|
||||
ASSERT_TRUE(data.word_embedding_data.size() % data.hidden_size == 0);
|
||||
ASSERT_TRUE(data.position_embedding_data.size() % data.hidden_size == 0);
|
||||
ASSERT_TRUE(data.segment_embedding_data.size() % data.hidden_size == 0);
|
||||
|
||||
std::vector<int64_t> input_ids_dims = {data.batch_size, data.sequence_size};
|
||||
std::vector<int64_t> segment_ids_dims = {data.batch_size, data.sequence_size};
|
||||
std::vector<int64_t> word_embedding_dims = {
|
||||
static_cast<int64_t>(data.word_embedding_data.size() / data.hidden_size),
|
||||
data.hidden_size};
|
||||
std::vector<int64_t> position_embedding_dims = {
|
||||
static_cast<int64_t>(data.position_embedding_data.size() / data.hidden_size),
|
||||
data.hidden_size};
|
||||
std::vector<int64_t> segment_embedding_dims = {
|
||||
static_cast<int64_t>(data.segment_embedding_data.size() / data.hidden_size),
|
||||
data.hidden_size};
|
||||
std::vector<int64_t> word_embedding_dims = {
|
||||
static_cast<int64_t>(data.word_embedding_data.size() / data.hidden_size),
|
||||
data.hidden_size};
|
||||
std::vector<int64_t> position_embedding_dims = {
|
||||
static_cast<int64_t>(data.position_embedding_data.size() / data.hidden_size),
|
||||
data.hidden_size};
|
||||
std::vector<int64_t> segment_embedding_dims = {
|
||||
static_cast<int64_t>(data.segment_embedding_data.size() / data.hidden_size),
|
||||
data.hidden_size};
|
||||
std::vector<int64_t> gamma_dims = {data.hidden_size};
|
||||
std::vector<int64_t> beta_dims = {data.hidden_size};
|
||||
std::vector<int64_t> output_dims = {data.batch_size, data.sequence_size, data.hidden_size};
|
||||
|
|
@ -80,23 +80,28 @@ static void RunTest(const embedlayernorm::OpData& data,
|
|||
// Quantized initializer inputs:
|
||||
tester.AddInput<uint8_t>("word_embedding_data",
|
||||
word_embedding_dims,
|
||||
word_embedding_data_quant);
|
||||
word_embedding_data_quant,
|
||||
/*is_initializer=*/true);
|
||||
tester.AddInput<uint8_t>("position_embedding_data",
|
||||
position_embedding_dims,
|
||||
position_embedding_data_quant);
|
||||
position_embedding_data_quant,
|
||||
/*is_initializer=*/true);
|
||||
if (data.has_segment) {
|
||||
tester.AddInput<uint8_t>("segment_embedding_data",
|
||||
segment_embedding_dims,
|
||||
segment_embedding_data_quant);
|
||||
segment_embedding_data_quant,
|
||||
/*is_initializer=*/true);
|
||||
} else {
|
||||
tester.AddMissingOptionalInput<uint8_t>();
|
||||
}
|
||||
tester.AddInput<uint8_t>("gamma",
|
||||
gamma_dims,
|
||||
gamma_data_quant);
|
||||
gamma_data_quant,
|
||||
/*is_initializer=*/true);
|
||||
tester.AddInput<uint8_t>("beta",
|
||||
beta_dims,
|
||||
beta_data_quant);
|
||||
beta_data_quant,
|
||||
/*is_initializer=*/true);
|
||||
if (data.has_mask) {
|
||||
std::vector<int64_t> mask_dims = {data.batch_size, data.sequence_size};
|
||||
tester.AddInput<int32_t>("mask", mask_dims, data.mask_data);
|
||||
|
|
@ -107,44 +112,54 @@ static void RunTest(const embedlayernorm::OpData& data,
|
|||
// Quantized scales:
|
||||
tester.AddInput<float>("word_embedding_scale",
|
||||
/*dims=*/{},
|
||||
{word_embedding_scale});
|
||||
{word_embedding_scale},
|
||||
/*is_initializer=*/true);
|
||||
tester.AddInput<float>("position_embedding_scale",
|
||||
/*dims=*/{},
|
||||
{position_embedding_scale});
|
||||
{position_embedding_scale},
|
||||
/*is_initializer=*/true);
|
||||
if (data.has_segment) {
|
||||
tester.AddInput<float>("segment_embedding_scale",
|
||||
/*dims=*/{},
|
||||
{segment_embedding_scale});
|
||||
{segment_embedding_scale},
|
||||
/*is_initializer=*/true);
|
||||
} else {
|
||||
tester.AddMissingOptionalInput<float>();
|
||||
}
|
||||
tester.AddInput<float>("gamma_scale",
|
||||
/*dims=*/{},
|
||||
{gamma_scale});
|
||||
{gamma_scale},
|
||||
/*is_initializer=*/true);
|
||||
tester.AddInput<float>("beta_scale",
|
||||
/*dims=*/{},
|
||||
{beta_scale});
|
||||
{beta_scale},
|
||||
/*is_initializer=*/true);
|
||||
|
||||
// Quantized zero points:
|
||||
tester.AddInput<uint8_t>("word_embedding_zero_point",
|
||||
/*dims=*/{},
|
||||
{word_embedding_zero_point});
|
||||
{word_embedding_zero_point},
|
||||
/*is_initializer=*/true);
|
||||
tester.AddInput<uint8_t>("position_embedding_zero_point",
|
||||
/*dims=*/{},
|
||||
{position_embedding_zero_point});
|
||||
{position_embedding_zero_point},
|
||||
/*is_initializer=*/true);
|
||||
if (data.has_segment) {
|
||||
tester.AddInput<uint8_t>("segment_embedding_zero_point",
|
||||
/*dims=*/{},
|
||||
{segment_embedding_zero_point});
|
||||
{segment_embedding_zero_point},
|
||||
/*is_initializer=*/true);
|
||||
} else {
|
||||
tester.AddMissingOptionalInput<uint8_t>();
|
||||
}
|
||||
tester.AddInput<uint8_t>("gamma_zero_point",
|
||||
/*dims=*/{},
|
||||
{gamma_zero_point});
|
||||
{gamma_zero_point},
|
||||
/*is_initializer=*/true);
|
||||
tester.AddInput<uint8_t>("beta_zero_point",
|
||||
/*dims=*/{},
|
||||
{beta_zero_point});
|
||||
{beta_zero_point},
|
||||
/*is_initializer=*/true);
|
||||
// Outputs:
|
||||
tester.AddOutput<float>("output", output_dims, data.output_data);
|
||||
tester.AddOutput<int32_t>("mask_index", mask_index_dims, data.mask_index_data);
|
||||
|
|
|
|||
Loading…
Reference in a new issue