mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Allow Upper case letters in RHS of einsum equations. (#5569)
Co-authored-by: Andrew McDowell <andrew@neva-labs.com>
This commit is contained in:
parent
51af108af5
commit
b2da700e4d
2 changed files with 14 additions and 5 deletions
|
|
@ -369,14 +369,13 @@ Status EinsumComputePreprocessor::CalculateOutputShape() {
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Found '.' not part of an ellipsis in the output subscript provided");
|
||||
}
|
||||
|
||||
if (!(subscript_label >= 'a' && subscript_label <= 'z')) {
|
||||
auto letter_index = EinsumOp::LetterToIndex(subscript_label);
|
||||
if (letter_index == -1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"The only subscript labels allowed in the output subscript "
|
||||
"are lowercase letters (a-z)");
|
||||
"The only subscript labels allowed are lower-cased letters (a-z) and "
|
||||
"upper-cased letters (A-Z)");
|
||||
}
|
||||
|
||||
auto letter_index = static_cast<int64_t>(subscript_label - 'a');
|
||||
|
||||
if (output_letter_to_count[letter_index] != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Output subscript contains repeated letters");
|
||||
|
|
|
|||
|
|
@ -187,6 +187,16 @@ TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedLabel) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedOutputLabel) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
// Einsum should handle be able to handle upper case on both LHS and RHS
|
||||
test.AddAttribute<std::string>("equation", "Ki,ik->Kk");
|
||||
test.AddInput<float>("x", {2, 1}, {1.f, 2.f});
|
||||
test.AddInput<float>("y", {1, 2}, {1.f, 2.f});
|
||||
test.AddOutput<float>("o", {2, 2}, {1.f, 2.f, 2.f, 4.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(Einsum, ExplicitEinsumAsMatmul_Multi_Input) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
test.AddAttribute<std::string>("equation", "ij,jk,kl->li");
|
||||
|
|
|
|||
Loading…
Reference in a new issue