Fixed an issue in updating realized dims (#2597)

when we update realized dims for scan's output, the sliced axis also
needs to be inclusive, i.e. we should check with "dim >= insert_inclusive_axis",
because the offset in the symbols are based on Scan sugraph.
Otherwise, we would end up with shape mismatch later.
This commit is contained in:
Yang Chen 2019-12-09 22:56:47 -08:00 committed by GitHub
parent 78099701b4
commit b0128a4843
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 5 deletions

View file

@ -178,18 +178,18 @@ class KernelComputeCtx {
}
// UpdateRealizedDims is used to sync realize dim
// Note insert_exclusive_axis is introduced to adjusted shape.
// Note insert_inclusive_axis is introduced to adjusted shape.
// It is commonly used in Scan or other subgraphs
// when Tensors' shapes in a subgraph are sliced from the main grahp.
// Using the sliced axis as insert_exclusive_axis can find the correct shape dim in the main graph
// Using the sliced axis as insert_inclusive_axis can find the correct shape dim in the main graph
inline void UpdateRealizedDims(
const std::vector<std::pair<size_t, std::string>>& symbols,
std::vector<int64_t>& realized_output_shape,
size_t insert_exclusive_axis = 65535 /*minimal maximum of size_t*/) {
size_t insert_inclusive_axis = 65535 /*minimal maximum of size_t*/) {
for (const auto& s_pair : symbols) {
size_t dim = s_pair.first;
size_t adjusted_dim = dim;
if (dim > insert_exclusive_axis) {
if (dim >= insert_inclusive_axis) {
adjusted_dim = dim + 1;
}

View file

@ -7,7 +7,7 @@ import onnx
from onnx import numpy_helper
import onnxruntime as onnxrt
import os
from onnxruntime.nuphar.rnn_benchmark import perf_test
from onnxruntime.nuphar.rnn_benchmark import perf_test, generate_model
from pathlib import Path
import shutil
import sys
@ -131,6 +131,51 @@ class TestNuphar(unittest.TestCase):
min_duration_seconds=1)
def test_batch_scan(self):
input_dim = 3
hidden_dim = 5
bidirectional = False
layers = 3
lstm_model_name = 'test_batch_rnn_lstm.onnx'
# create an LSTM model for generating baseline data
generate_model('lstm', input_dim, hidden_dim, bidirectional, layers, lstm_model_name, batch_one=False, has_seq_len=True)
seq_len = 8
batch_size = 2
# prepare input
data_input = (np.random.rand(seq_len, batch_size, input_dim) * 2 - 1).astype(np.float32)
data_seq_len = np.random.randint(1, seq_len, size=(batch_size,), dtype=np.int32)
# run lstm as baseline
sess = onnxrt.InferenceSession(lstm_model_name)
first_lstm_data_output = sess.run([], {'input':data_input[:,0:1,:], 'seq_len':data_seq_len[0:1]})
lstm_data_output = []
lstm_data_output = first_lstm_data_output
for b in range(1, batch_size):
lstm_data_output = lstm_data_output + sess.run([], {'input':data_input[:,b:(b+1),:], 'seq_len':data_seq_len[b:(b+1)]})
lstm_data_output = np.concatenate(lstm_data_output, axis=1)
# generate a batch scan model
scan_model_name = 'test_batch_rnn_scan.onnx'
subprocess.run([sys.executable, '-m', 'onnxruntime.nuphar.model_editor', '--input', lstm_model_name, '--output', scan_model_name, '--mode', 'to_scan'], check=True)
# run scan_batch with batch size 1
sess = onnxrt.InferenceSession(scan_model_name)
scan_batch_data_output = sess.run([], {'input':data_input[:,0:1,:], 'seq_len':data_seq_len[0:1]})
assert np.allclose(first_lstm_data_output, scan_batch_data_output)
# run scan_batch with batch size 2
scan_batch_data_output = sess.run([], {'input':data_input, 'seq_len':data_seq_len})
assert np.allclose(lstm_data_output, scan_batch_data_output)
# run scan_batch with batch size 1 again
scan_batch_data_output = sess.run([], {'input':data_input[:,0:1,:], 'seq_len':data_seq_len[0:1]})
assert np.allclose(first_lstm_data_output, scan_batch_data_output)
def test_symbolic_shape_infer(self):
cwd = os.getcwd()
test_model_dir = os.path.join(cwd, '..', 'models')