Allow Upper case letters in RHS of einsum equations. (#5569)

Co-authored-by: Andrew McDowell <andrew@neva-labs.com>
This commit is contained in:
Andrew McDowell 2020-10-26 01:11:12 +00:00 committed by GitHub
parent 51af108af5
commit b2da700e4d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 5 deletions

View file

@ -369,14 +369,13 @@ Status EinsumComputePreprocessor::CalculateOutputShape() {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Found '.' not part of an ellipsis in the output subscript provided");
}
if (!(subscript_label >= 'a' && subscript_label <= 'z')) {
auto letter_index = EinsumOp::LetterToIndex(subscript_label);
if (letter_index == -1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The only subscript labels allowed in the output subscript "
"are lowercase letters (a-z)");
"The only subscript labels allowed are lower-cased letters (a-z) and "
"upper-cased letters (A-Z)");
}
auto letter_index = static_cast<int64_t>(subscript_label - 'a');
if (output_letter_to_count[letter_index] != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Output subscript contains repeated letters");

View file

@ -187,6 +187,16 @@ TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedLabel) {
test.Run();
}
TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedOutputLabel) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
// Einsum should handle be able to handle upper case on both LHS and RHS
test.AddAttribute<std::string>("equation", "Ki,ik->Kk");
test.AddInput<float>("x", {2, 1}, {1.f, 2.f});
test.AddInput<float>("y", {1, 2}, {1.f, 2.f});
test.AddOutput<float>("o", {2, 2}, {1.f, 2.f, 2.f, 4.f});
test.Run();
}
TEST(Einsum, ExplicitEinsumAsMatmul_Multi_Input) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "ij,jk,kl->li");