hHandle upper-cased subscript labels in Einsum (#4964)

This commit is contained in:
Hariharan Seshadri 2020-08-29 15:18:21 -07:00 committed by GitHub
parent f4b057b098
commit 7080e485a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 6 deletions

View file

@ -134,8 +134,10 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
num_of_ellipsis_dims_ = static_cast<size_t>(current_num_of_ellipsis_dims);
}
// We reserve 'EinsumOp::num_of_letters' for broadcasted dims as we only allow 'a' - 'z' (0 - 25) for non-broadcasted dims
// We will assign appropriate indices (based on number of dimensions the ellipsis corresponds to) during broadcasting related post-processing
// We reserve 'EinsumOp::num_of_letters' for broadcasted dims as we only allow 'a' - 'z'
// and 'A' - 'Z' (0 - 51) for non-broadcasted dims.
// We will assign appropriate indices (based on number of dimensions the ellipsis corresponds to)
// during broadcasting related post-processing.
for (size_t i = 0; i < num_of_ellipsis_dims_; ++i) {
current_subscript_indices.push_back(EinsumOp::num_of_letters);
}
@ -150,12 +152,13 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
"Found '.' not part of an ellipsis in input: ", input_index);
}
if (!(subscript_label >= 'a' && subscript_label <= 'z')) {
auto letter_index = EinsumOp::LetterToIndex(subscript_label);
if (letter_index == -1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The only subscript labels allowed are lowercase letters (a-z)");
"The only subscript labels allowed are lower-cased letters (a-z) and "
"upper-cased letters (A-Z)");
}
auto letter_index = static_cast<int64_t>(subscript_label - 'a');
auto dim_value = dims[dim_counter];
// Subscript label not found in global subscript label array

View file

@ -19,7 +19,27 @@ namespace onnxruntime {
namespace EinsumOp {
constexpr size_t num_of_letters = 26;
// Einsum accepts 'a' - 'z' and 'A' - 'Z' and needs to differentiate between lower-cased
// and upper-cased letters in the equation string (26 * 2 = 52).
constexpr size_t num_of_letters = 52;
/** Returns the index associated with the input character
* Returns a value between 0 - 25 for input in 'a' - 'z'
* Returns a value between 26 - 51 for input in 'A' - 'Z'
* Returns -1 for invalid input not in 'a' - 'z' or 'A' - 'Z' (caller should handle the returned result)
*/
inline int64_t LetterToIndex(char ch) {
if (ch >= 'a' && ch <= 'z') {
return static_cast<int64_t>(ch - 'a');
}
if (ch >= 'A' && ch <= 'Z') {
return 26 + static_cast<int64_t>(ch - 'A');
}
// invalid character - return error value
return -1;
}
} // namespace EinsumOp

View file

@ -177,6 +177,16 @@ TEST(Einsum, ExplicitEinsumAsMatmul) {
test.Run();
}
TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedLabel) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
// 'K' != 'k' (and dim values differ too) and Einsum should handle be able to handle that
test.AddAttribute<std::string>("equation", "iK,Kk->ik");
test.AddInput<float>("x", {2, 1}, {1.f, 2.f});
test.AddInput<float>("y", {1, 2}, {1.f, 2.f});
test.AddOutput<float>("o", {2, 2}, {1.f, 2.f, 2.f, 4.f});
test.Run();
}
TEST(Einsum, ExplicitEinsumAsMatmul_Multi_Input) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "ij,jk,kl->li");