Add support for Trilu<bool>. (#20917)

### Description
<!-- Describe your changes. -->
Trilu<bool> is used by phi-3 when exported with torch.onnx.export.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Scott McKay 2024-06-06 15:21:34 +10:00 committed by GitHub
parent eb2ec66716
commit 3ecf48e3b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 118 additions and 314 deletions

View file

@ -421,7 +421,7 @@ Do not modify directly.*
|Transpose|*in* data:**T**<br> *out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)|
|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int64)|
|Unique|*in* X:**T**<br> *out* Y:**T**<br> *out* indices:**tensor(int64)**<br> *out* inverse_indices:**tensor(int64)**<br> *out* counts:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(int64), tensor(int8), tensor(string)|
|Unsqueeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* expanded:**T**<br><br>or<br><br>*in* data:**T**<br> *out* expanded:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

View file

@ -31,7 +31,7 @@ ONNX_OPERATOR_KERNEL_EX(
kOnnxDomain,
14,
kCpuExecutionProvider,
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints<float, double, int64_t>()),
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints<float, double, int64_t, bool>()),
Trilu);
template <typename T>
@ -110,6 +110,9 @@ Status Trilu::Compute(OpKernelContext* ctx) const {
case sizeof(double):
status = TriluImpl<double>(X, Y, k_val, up);
break;
case sizeof(bool):
status = TriluImpl<bool>(X, Y, k_val, up);
break;
default:
ORT_THROW("Unsupported input data type of ", data_type);
}

View file

@ -62,63 +62,54 @@ TEST(TriluOpTest, two_by_two_long_lower) {
test.Run();
}
TEST(TriluOpTest, two_by_two_bool_upper) {
OpTester test("Trilu", 14, kOnnxDomain);
int64_t up = 1;
test.AddAttribute("upper", up);
test.AddInput<bool>("X", {2, 2},
{true, true,
true, true});
test.AddOutput<bool>("Y", {2, 2},
{true, true,
false, true});
test.Run();
}
TEST(TriluOpTest, three_by_three_bool_lower) {
OpTester test("Trilu", 14, kOnnxDomain);
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<bool>("X", {3, 3},
// include a couple of false values to check they are copied
{true, true, true,
true, false, true,
true, true, false});
test.AddOutput<bool>("Y", {3, 3},
{true, false, false,
true, false, false,
true, true, false});
test.Run();
}
TEST(TriluOpTest, three_dim_float_upper) {
OpTester test("Trilu", 14, kOnnxDomain);
test.AddInput<float>("X", {2, 3, 4},
{
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
6.f,
1.f,
2.f,
3.f,
1.f,
6.f,
2.f,
1.f,
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
});
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f});
test.AddInput<int64_t>("k", {1}, {1});
test.AddOutput<float>("Y", {2, 3, 4},
{
0.f,
1.f,
5.f,
8.f,
0.f,
0.f,
2.f,
4.f,
0.f,
0.f,
0.f,
3.f,
0.f,
6.f,
2.f,
1.f,
0.f,
0.f,
5.f,
8.f,
0.f,
0.f,
0.f,
4.f,
});
{0.f, 1.f, 5.f, 8.f,
0.f, 0.f, 2.f, 4.f,
0.f, 0.f, 0.f, 3.f,
0.f, 6.f, 2.f, 1.f,
0.f, 0.f, 5.f, 8.f,
0.f, 0.f, 0.f, 4.f});
test.Run();
}
@ -127,60 +118,22 @@ TEST(TriluOpTest, three_dim_float_lower) {
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 3, 4},
{
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
6.f,
1.f,
2.f,
3.f,
1.f,
6.f,
2.f,
1.f,
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
});
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f});
test.AddInput<int64_t>("k", {1}, {1});
test.AddOutput<float>("Y", {2, 3, 4},
{
4.f,
1.f,
0.f,
0.f,
4.f,
3.f,
2.f,
0.f,
6.f,
1.f,
2.f,
3.f,
1.f,
6.f,
0.f,
0.f,
4.f,
1.f,
5.f,
0.f,
4.f,
3.f,
2.f,
4.f,
});
{4.f, 1.f, 0.f, 0.f,
4.f, 3.f, 2.f, 0.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 0.f, 0.f,
4.f, 1.f, 5.f, 0.f,
4.f, 3.f, 2.f, 4.f});
test.Run();
}
@ -189,60 +142,22 @@ TEST(TriluOpTest, neg_k_float_upper) {
int64_t up = 1;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 3, 4},
{
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
6.f,
1.f,
2.f,
3.f,
1.f,
6.f,
2.f,
1.f,
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
});
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f});
test.AddInput<int64_t>("k", {1}, {-1});
test.AddOutput<float>("Y", {2, 3, 4},
{
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
0.f,
1.f,
2.f,
3.f,
1.f,
6.f,
2.f,
1.f,
4.f,
1.f,
5.f,
8.f,
0.f,
3.f,
2.f,
4.f,
});
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
0.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
0.f, 3.f, 2.f, 4.f});
test.Run();
}
@ -251,120 +166,44 @@ TEST(TriluOpTest, neg_k_float_lower) {
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 3, 4},
{
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
6.f,
1.f,
2.f,
3.f,
1.f,
6.f,
2.f,
1.f,
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
});
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f});
test.AddInput<int64_t>("k", {1}, {-1});
test.AddOutput<float>("Y", {2, 3, 4},
{
0.f,
0.f,
0.f,
0.f,
4.f,
0.f,
0.f,
0.f,
6.f,
1.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
4.f,
0.f,
0.f,
0.f,
4.f,
3.f,
0.f,
0.f,
});
{0.f, 0.f, 0.f, 0.f,
4.f, 0.f, 0.f, 0.f,
6.f, 1.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f,
4.f, 0.f, 0.f, 0.f,
4.f, 3.f, 0.f, 0.f});
test.Run();
}
TEST(TriluTest, small_k_float_upper) {
OpTester test("Trilu", 14, kOnnxDomain);
test.AddInput<float>("X", {2, 3, 4},
{
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
6.f,
1.f,
2.f,
3.f,
1.f,
6.f,
2.f,
1.f,
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
});
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f});
test.AddInput<int64_t>("k", {1}, {-5});
test.AddOutput<float>("Y", {2, 3, 4},
{
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
6.f,
1.f,
2.f,
3.f,
1.f,
6.f,
2.f,
1.f,
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
});
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f});
test.Run();
}
@ -373,60 +212,22 @@ TEST(TriluOpTest, small_k_float_lower) {
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 3, 4},
{
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
6.f,
1.f,
2.f,
3.f,
1.f,
6.f,
2.f,
1.f,
4.f,
1.f,
5.f,
8.f,
4.f,
3.f,
2.f,
4.f,
});
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f});
test.AddInput<int64_t>("k", {1}, {-5});
test.AddOutput<float>("Y", {2, 3, 4},
{
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
0.f,
});
{0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f});
test.Run();
}