Misc transformer fixes (#14103)

### Description
1. SkipLayerNormalization has a new output
(https://github.com/microsoft/onnxruntime/pull/13988) and the symbolic
shape inference script needs corresponding updates

2. The greedy sampling op
(https://github.com/microsoft/onnxruntime/pull/13426) shouldn't re-use
the logits buffer as its corresponding kernel doesn't seem to support it
yet.

### Motivation and Context
Fix some transformer issues
This commit is contained in:
Hariharan Seshadri 2023-01-03 13:05:55 -08:00 committed by GitHub
parent e7f9d40dde
commit d43e0ec9ba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 11 deletions

View file

@ -369,8 +369,8 @@ Status ProcessLogits(const OrtValue& logits, //
parameters->temperature,
parameters->batch_size,
parameters->num_beams,
parameters->vocab_size,
parameters->vocab_size,
vocab_size,
vocab_size,
(parameters->min_length > 0 && current_sequence_length < parameters->min_length) ? parameters->eos_token_id : -1,
reinterpret_cast<int32_t*>(sequences_buffer.get()),
parameters->max_length,
@ -508,7 +508,7 @@ template <typename T>
Status GreedySearchProcessLogits(
const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState<T>* greedy_state, // state
transformers::ISamplingState<T>* sampling_state, // buffers
transformers::ISamplingState<T>* sampling_state, // buffers
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
@ -556,7 +556,10 @@ Status GreedySearchProcessLogits(
// In greedy search, next_token_scores is next_token_logits.
gsl::span<T>& next_token_scores = greedy_state->next_token_scores;
auto is_reuse_logits_buffer = (input_length == 1);
// TODO(hasesh/wy): Support re-using logits buffer for the sampling case.
// Currently, we cannot re-use the logits because the sampling logic expects
// `next_token_scores` to be populated.
auto is_reuse_logits_buffer = !do_sampling && (input_length == 1);
// Copy over the logits data into the staging buffer, only if
// we do not plan to re-use the logits buffer directly
@ -605,7 +608,7 @@ Status GreedySearchProcessLogits(
gsl::span<int>& presence_mask = sampling_state->d_presence_mask;
if (step == 1 && parameters->presence_mask.data() != nullptr) {
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(presence_mask.data(), parameters->presence_mask.data(),
sizeof(int) * batch_size * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream));
sizeof(int) * batch_size * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream));
}
// TODO(hasesh): Can we avoid the const_cast by changing the interface of
@ -622,10 +625,11 @@ Status GreedySearchProcessLogits(
parameters->temperature,
parameters->batch_size,
parameters->num_beams,
parameters->vocab_size,
is_reuse_logits_buffer ? padded_vocab_size : parameters->vocab_size,
vocab_size,
is_reuse_logits_buffer ? padded_vocab_size : vocab_size,
(parameters->min_length > 0 && current_sequence_length < parameters->sequence_length + parameters->min_length)
? parameters->eos_token_id : -1,
? parameters->eos_token_id
: -1,
reinterpret_cast<int32_t*>(sequences_buffer.get()),
parameters->max_length,
current_sequence_length,
@ -686,8 +690,8 @@ Status GreedySearchProcessLogits(
*topk_scores, *topk_indices));
#ifdef DEBUG_GENERATION
dumper->Print("topk_scores", *(topk_scores.get()));
dumper->Print("topk_indices", *(topk_indices.get()));
dumper->Print("topk_scores", *(topk_scores.get()));
dumper->Print("topk_indices", *(topk_indices.get()));
#endif
const int64_t* next_token_indices = topk_indices->Data<int64_t>();

View file

@ -2030,6 +2030,11 @@ class SymbolicShapeInference:
def _infer_SkipLayerNormalization(self, node):
self._propagate_shape_and_type(node)
# If the SkipLayerNormalization node contains the optional
# output for inference, infer the shape and type for it too
if len(node.output) > 3:
self._propagate_shape_and_type(node, 0, 3)
def _infer_PythonOp(self, node):
output_tensor_types = get_attribute(node, "output_tensor_types")
assert output_tensor_types
@ -2212,6 +2217,11 @@ class SymbolicShapeInference:
self._check_merged_dims(in_dims, allow_broadcast=True)
for i_o in range(len(node.output)):
# Special case: We do not care about the training related
# outputs of SkipLayerNormalization
if node.op_type == "SkipLayerNormalization" and i_o in [1, 2]:
continue
vi = self.known_vi_[node.output[i_o]]
out_type = vi.type
out_type_kind = out_type.WhichOneof("value")

View file

@ -51,8 +51,18 @@ class Gpt2OnnxModel(BertOnnxModel):
[0, 0],
output_name_to_node,
)
if nodes is None:
continue
nodes = self.match_parent_path(
gemm_node,
["Reshape", "SkipLayerNormalization"],
[0, 0],
output_name_to_node,
)
if nodes is None:
continue
(reshape_before_gemm, root_node) = nodes
matmul_node_name = self.create_node_name("MatMul", "FullyConnect_MatMul")