mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add support for Trilu<bool>. (#20917)
### Description <!-- Describe your changes. --> Trilu<bool> is used by phi-3 when exported with torch.onnx.export. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
eb2ec66716
commit
3ecf48e3b5
3 changed files with 118 additions and 314 deletions
|
|
@ -421,7 +421,7 @@ Do not modify directly.*
|
|||
|Transpose|*in* data:**T**<br> *out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|
||||
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)|
|
||||
|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int64)|
|
||||
|Unique|*in* X:**T**<br> *out* Y:**T**<br> *out* indices:**tensor(int64)**<br> *out* inverse_indices:**tensor(int64)**<br> *out* counts:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(int64), tensor(int8), tensor(string)|
|
||||
|Unsqueeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* expanded:**T**<br><br>or<br><br>*in* data:**T**<br> *out* expanded:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
kOnnxDomain,
|
||||
14,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints<float, double, int64_t>()),
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints<float, double, int64_t, bool>()),
|
||||
Trilu);
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -110,6 +110,9 @@ Status Trilu::Compute(OpKernelContext* ctx) const {
|
|||
case sizeof(double):
|
||||
status = TriluImpl<double>(X, Y, k_val, up);
|
||||
break;
|
||||
case sizeof(bool):
|
||||
status = TriluImpl<bool>(X, Y, k_val, up);
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unsupported input data type of ", data_type);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,63 +62,54 @@ TEST(TriluOpTest, two_by_two_long_lower) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluOpTest, two_by_two_bool_upper) {
|
||||
OpTester test("Trilu", 14, kOnnxDomain);
|
||||
int64_t up = 1;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<bool>("X", {2, 2},
|
||||
{true, true,
|
||||
true, true});
|
||||
test.AddOutput<bool>("Y", {2, 2},
|
||||
{true, true,
|
||||
false, true});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluOpTest, three_by_three_bool_lower) {
|
||||
OpTester test("Trilu", 14, kOnnxDomain);
|
||||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<bool>("X", {3, 3},
|
||||
// include a couple of false values to check they are copied
|
||||
{true, true, true,
|
||||
true, false, true,
|
||||
true, true, false});
|
||||
test.AddOutput<bool>("Y", {3, 3},
|
||||
{true, false, false,
|
||||
true, false, false,
|
||||
true, true, false});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluOpTest, three_dim_float_upper) {
|
||||
OpTester test("Trilu", 14, kOnnxDomain);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
6.f,
|
||||
1.f,
|
||||
2.f,
|
||||
3.f,
|
||||
1.f,
|
||||
6.f,
|
||||
2.f,
|
||||
1.f,
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
});
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f});
|
||||
test.AddInput<int64_t>("k", {1}, {1});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{
|
||||
0.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
0.f,
|
||||
0.f,
|
||||
2.f,
|
||||
4.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
3.f,
|
||||
0.f,
|
||||
6.f,
|
||||
2.f,
|
||||
1.f,
|
||||
0.f,
|
||||
0.f,
|
||||
5.f,
|
||||
8.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
4.f,
|
||||
});
|
||||
{0.f, 1.f, 5.f, 8.f,
|
||||
0.f, 0.f, 2.f, 4.f,
|
||||
0.f, 0.f, 0.f, 3.f,
|
||||
|
||||
0.f, 6.f, 2.f, 1.f,
|
||||
0.f, 0.f, 5.f, 8.f,
|
||||
0.f, 0.f, 0.f, 4.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
@ -127,60 +118,22 @@ TEST(TriluOpTest, three_dim_float_lower) {
|
|||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
6.f,
|
||||
1.f,
|
||||
2.f,
|
||||
3.f,
|
||||
1.f,
|
||||
6.f,
|
||||
2.f,
|
||||
1.f,
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
});
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f});
|
||||
test.AddInput<int64_t>("k", {1}, {1});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{
|
||||
4.f,
|
||||
1.f,
|
||||
0.f,
|
||||
0.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
0.f,
|
||||
6.f,
|
||||
1.f,
|
||||
2.f,
|
||||
3.f,
|
||||
1.f,
|
||||
6.f,
|
||||
0.f,
|
||||
0.f,
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
0.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
});
|
||||
{4.f, 1.f, 0.f, 0.f,
|
||||
4.f, 3.f, 2.f, 0.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
|
||||
1.f, 6.f, 0.f, 0.f,
|
||||
4.f, 1.f, 5.f, 0.f,
|
||||
4.f, 3.f, 2.f, 4.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
@ -189,60 +142,22 @@ TEST(TriluOpTest, neg_k_float_upper) {
|
|||
int64_t up = 1;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
6.f,
|
||||
1.f,
|
||||
2.f,
|
||||
3.f,
|
||||
1.f,
|
||||
6.f,
|
||||
2.f,
|
||||
1.f,
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
});
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f});
|
||||
test.AddInput<int64_t>("k", {1}, {-1});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
0.f,
|
||||
1.f,
|
||||
2.f,
|
||||
3.f,
|
||||
1.f,
|
||||
6.f,
|
||||
2.f,
|
||||
1.f,
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
0.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
});
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
0.f, 1.f, 2.f, 3.f,
|
||||
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
0.f, 3.f, 2.f, 4.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
@ -251,120 +166,44 @@ TEST(TriluOpTest, neg_k_float_lower) {
|
|||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
6.f,
|
||||
1.f,
|
||||
2.f,
|
||||
3.f,
|
||||
1.f,
|
||||
6.f,
|
||||
2.f,
|
||||
1.f,
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
});
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f});
|
||||
test.AddInput<int64_t>("k", {1}, {-1});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
4.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
6.f,
|
||||
1.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
4.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
4.f,
|
||||
3.f,
|
||||
0.f,
|
||||
0.f,
|
||||
});
|
||||
{0.f, 0.f, 0.f, 0.f,
|
||||
4.f, 0.f, 0.f, 0.f,
|
||||
6.f, 1.f, 0.f, 0.f,
|
||||
|
||||
0.f, 0.f, 0.f, 0.f,
|
||||
4.f, 0.f, 0.f, 0.f,
|
||||
4.f, 3.f, 0.f, 0.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluTest, small_k_float_upper) {
|
||||
OpTester test("Trilu", 14, kOnnxDomain);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
6.f,
|
||||
1.f,
|
||||
2.f,
|
||||
3.f,
|
||||
1.f,
|
||||
6.f,
|
||||
2.f,
|
||||
1.f,
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
});
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f});
|
||||
test.AddInput<int64_t>("k", {1}, {-5});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
6.f,
|
||||
1.f,
|
||||
2.f,
|
||||
3.f,
|
||||
1.f,
|
||||
6.f,
|
||||
2.f,
|
||||
1.f,
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
});
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
@ -373,60 +212,22 @@ TEST(TriluOpTest, small_k_float_lower) {
|
|||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
6.f,
|
||||
1.f,
|
||||
2.f,
|
||||
3.f,
|
||||
1.f,
|
||||
6.f,
|
||||
2.f,
|
||||
1.f,
|
||||
4.f,
|
||||
1.f,
|
||||
5.f,
|
||||
8.f,
|
||||
4.f,
|
||||
3.f,
|
||||
2.f,
|
||||
4.f,
|
||||
});
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f});
|
||||
test.AddInput<int64_t>("k", {1}, {-5});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
0.f,
|
||||
});
|
||||
{0.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 0.f,
|
||||
|
||||
0.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 0.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue