mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
hHandle upper-cased subscript labels in Einsum (#4964)
This commit is contained in:
parent
f4b057b098
commit
7080e485a3
3 changed files with 39 additions and 6 deletions
|
|
@ -134,8 +134,10 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
|
|||
num_of_ellipsis_dims_ = static_cast<size_t>(current_num_of_ellipsis_dims);
|
||||
}
|
||||
|
||||
// We reserve 'EinsumOp::num_of_letters' for broadcasted dims as we only allow 'a' - 'z' (0 - 25) for non-broadcasted dims
|
||||
// We will assign appropriate indices (based on number of dimensions the ellipsis corresponds to) during broadcasting related post-processing
|
||||
// We reserve 'EinsumOp::num_of_letters' for broadcasted dims as we only allow 'a' - 'z'
|
||||
// and 'A' - 'Z' (0 - 51) for non-broadcasted dims.
|
||||
// We will assign appropriate indices (based on number of dimensions the ellipsis corresponds to)
|
||||
// during broadcasting related post-processing.
|
||||
for (size_t i = 0; i < num_of_ellipsis_dims_; ++i) {
|
||||
current_subscript_indices.push_back(EinsumOp::num_of_letters);
|
||||
}
|
||||
|
|
@ -150,12 +152,13 @@ Status EinsumComputePreprocessor::ProcessSubscripts() {
|
|||
"Found '.' not part of an ellipsis in input: ", input_index);
|
||||
}
|
||||
|
||||
if (!(subscript_label >= 'a' && subscript_label <= 'z')) {
|
||||
auto letter_index = EinsumOp::LetterToIndex(subscript_label);
|
||||
if (letter_index == -1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"The only subscript labels allowed are lowercase letters (a-z)");
|
||||
"The only subscript labels allowed are lower-cased letters (a-z) and "
|
||||
"upper-cased letters (A-Z)");
|
||||
}
|
||||
|
||||
auto letter_index = static_cast<int64_t>(subscript_label - 'a');
|
||||
auto dim_value = dims[dim_counter];
|
||||
|
||||
// Subscript label not found in global subscript label array
|
||||
|
|
|
|||
|
|
@ -19,7 +19,27 @@ namespace onnxruntime {
|
|||
|
||||
namespace EinsumOp {
|
||||
|
||||
constexpr size_t num_of_letters = 26;
|
||||
// Einsum accepts 'a' - 'z' and 'A' - 'Z' and needs to differentiate between lower-cased
|
||||
// and upper-cased letters in the equation string (26 * 2 = 52).
|
||||
constexpr size_t num_of_letters = 52;
|
||||
|
||||
/** Returns the index associated with the input character
|
||||
* Returns a value between 0 - 25 for input in 'a' - 'z'
|
||||
* Returns a value between 26 - 51 for input in 'A' - 'Z'
|
||||
* Returns -1 for invalid input not in 'a' - 'z' or 'A' - 'Z' (caller should handle the returned result)
|
||||
*/
|
||||
inline int64_t LetterToIndex(char ch) {
|
||||
if (ch >= 'a' && ch <= 'z') {
|
||||
return static_cast<int64_t>(ch - 'a');
|
||||
}
|
||||
|
||||
if (ch >= 'A' && ch <= 'Z') {
|
||||
return 26 + static_cast<int64_t>(ch - 'A');
|
||||
}
|
||||
|
||||
// invalid character - return error value
|
||||
return -1;
|
||||
}
|
||||
|
||||
} // namespace EinsumOp
|
||||
|
||||
|
|
|
|||
|
|
@ -177,6 +177,16 @@ TEST(Einsum, ExplicitEinsumAsMatmul) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedLabel) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
// 'K' != 'k' (and dim values differ too) and Einsum should handle be able to handle that
|
||||
test.AddAttribute<std::string>("equation", "iK,Kk->ik");
|
||||
test.AddInput<float>("x", {2, 1}, {1.f, 2.f});
|
||||
test.AddInput<float>("y", {1, 2}, {1.f, 2.f});
|
||||
test.AddOutput<float>("o", {2, 2}, {1.f, 2.f, 2.f, 4.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(Einsum, ExplicitEinsumAsMatmul_Multi_Input) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
test.AddAttribute<std::string>("equation", "ij,jk,kl->li");
|
||||
|
|
|
|||
Loading…
Reference in a new issue