mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Fixed an issue in updating realized dims (#2597)
when we update realized dims for scan's output, the sliced axis also needs to be inclusive, i.e. we should check with "dim >= insert_inclusive_axis", because the offset in the symbols are based on Scan sugraph. Otherwise, we would end up with shape mismatch later.
This commit is contained in:
parent
78099701b4
commit
b0128a4843
2 changed files with 50 additions and 5 deletions
|
|
@ -178,18 +178,18 @@ class KernelComputeCtx {
|
|||
}
|
||||
|
||||
// UpdateRealizedDims is used to sync realize dim
|
||||
// Note insert_exclusive_axis is introduced to adjusted shape.
|
||||
// Note insert_inclusive_axis is introduced to adjusted shape.
|
||||
// It is commonly used in Scan or other subgraphs
|
||||
// when Tensors' shapes in a subgraph are sliced from the main grahp.
|
||||
// Using the sliced axis as insert_exclusive_axis can find the correct shape dim in the main graph
|
||||
// Using the sliced axis as insert_inclusive_axis can find the correct shape dim in the main graph
|
||||
inline void UpdateRealizedDims(
|
||||
const std::vector<std::pair<size_t, std::string>>& symbols,
|
||||
std::vector<int64_t>& realized_output_shape,
|
||||
size_t insert_exclusive_axis = 65535 /*minimal maximum of size_t*/) {
|
||||
size_t insert_inclusive_axis = 65535 /*minimal maximum of size_t*/) {
|
||||
for (const auto& s_pair : symbols) {
|
||||
size_t dim = s_pair.first;
|
||||
size_t adjusted_dim = dim;
|
||||
if (dim > insert_exclusive_axis) {
|
||||
if (dim >= insert_inclusive_axis) {
|
||||
adjusted_dim = dim + 1;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import onnx
|
|||
from onnx import numpy_helper
|
||||
import onnxruntime as onnxrt
|
||||
import os
|
||||
from onnxruntime.nuphar.rnn_benchmark import perf_test
|
||||
from onnxruntime.nuphar.rnn_benchmark import perf_test, generate_model
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import sys
|
||||
|
|
@ -131,6 +131,51 @@ class TestNuphar(unittest.TestCase):
|
|||
min_duration_seconds=1)
|
||||
|
||||
|
||||
def test_batch_scan(self):
|
||||
input_dim = 3
|
||||
hidden_dim = 5
|
||||
bidirectional = False
|
||||
layers = 3
|
||||
|
||||
lstm_model_name = 'test_batch_rnn_lstm.onnx'
|
||||
# create an LSTM model for generating baseline data
|
||||
generate_model('lstm', input_dim, hidden_dim, bidirectional, layers, lstm_model_name, batch_one=False, has_seq_len=True)
|
||||
|
||||
seq_len = 8
|
||||
batch_size = 2
|
||||
# prepare input
|
||||
data_input = (np.random.rand(seq_len, batch_size, input_dim) * 2 - 1).astype(np.float32)
|
||||
data_seq_len = np.random.randint(1, seq_len, size=(batch_size,), dtype=np.int32)
|
||||
|
||||
# run lstm as baseline
|
||||
sess = onnxrt.InferenceSession(lstm_model_name)
|
||||
first_lstm_data_output = sess.run([], {'input':data_input[:,0:1,:], 'seq_len':data_seq_len[0:1]})
|
||||
|
||||
lstm_data_output = []
|
||||
lstm_data_output = first_lstm_data_output
|
||||
|
||||
for b in range(1, batch_size):
|
||||
lstm_data_output = lstm_data_output + sess.run([], {'input':data_input[:,b:(b+1),:], 'seq_len':data_seq_len[b:(b+1)]})
|
||||
lstm_data_output = np.concatenate(lstm_data_output, axis=1)
|
||||
|
||||
# generate a batch scan model
|
||||
scan_model_name = 'test_batch_rnn_scan.onnx'
|
||||
subprocess.run([sys.executable, '-m', 'onnxruntime.nuphar.model_editor', '--input', lstm_model_name, '--output', scan_model_name, '--mode', 'to_scan'], check=True)
|
||||
|
||||
# run scan_batch with batch size 1
|
||||
sess = onnxrt.InferenceSession(scan_model_name)
|
||||
scan_batch_data_output = sess.run([], {'input':data_input[:,0:1,:], 'seq_len':data_seq_len[0:1]})
|
||||
assert np.allclose(first_lstm_data_output, scan_batch_data_output)
|
||||
|
||||
# run scan_batch with batch size 2
|
||||
scan_batch_data_output = sess.run([], {'input':data_input, 'seq_len':data_seq_len})
|
||||
assert np.allclose(lstm_data_output, scan_batch_data_output)
|
||||
|
||||
# run scan_batch with batch size 1 again
|
||||
scan_batch_data_output = sess.run([], {'input':data_input[:,0:1,:], 'seq_len':data_seq_len[0:1]})
|
||||
assert np.allclose(first_lstm_data_output, scan_batch_data_output)
|
||||
|
||||
|
||||
def test_symbolic_shape_infer(self):
|
||||
cwd = os.getcwd()
|
||||
test_model_dir = os.path.join(cwd, '..', 'models')
|
||||
|
|
|
|||
Loading…
Reference in a new issue