From d43e0ec9ba75db16c7d151f3281a61ae4168777e Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 3 Jan 2023 13:05:55 -0800 Subject: [PATCH] Misc transformer fixes (#14103) ### Description 1. SkipLayerNormalization has a new output (https://github.com/microsoft/onnxruntime/pull/13988) and the symbolic shape inference script needs corresponding updates 2. The greedy sampling op (https://github.com/microsoft/onnxruntime/pull/13426) shouldn't re-use the logits buffer as its corresponding kernel doesn't seem to support it yet. ### Motivation and Context Fix some transformer issues --- .../transformers/generation_device_helper.cc | 24 +++++++++++-------- .../python/tools/symbolic_shape_infer.py | 10 ++++++++ .../tools/transformers/onnx_model_gpt2.py | 12 +++++++++- 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 5377c45176..bf3f11a6b3 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -369,8 +369,8 @@ Status ProcessLogits(const OrtValue& logits, // parameters->temperature, parameters->batch_size, parameters->num_beams, - parameters->vocab_size, - parameters->vocab_size, + vocab_size, + vocab_size, (parameters->min_length > 0 && current_sequence_length < parameters->min_length) ? parameters->eos_token_id : -1, reinterpret_cast(sequences_buffer.get()), parameters->max_length, @@ -508,7 +508,7 @@ template Status GreedySearchProcessLogits( const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state - transformers::ISamplingState* sampling_state, // buffers + transformers::ISamplingState* sampling_state, // buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) @@ -556,7 +556,10 @@ Status GreedySearchProcessLogits( // In greedy search, next_token_scores is next_token_logits. gsl::span& next_token_scores = greedy_state->next_token_scores; - auto is_reuse_logits_buffer = (input_length == 1); + // TODO(hasesh/wy): Support re-using logits buffer for the sampling case. + // Currently, we cannot re-use the logits because the sampling logic expects + // `next_token_scores` to be populated. + auto is_reuse_logits_buffer = !do_sampling && (input_length == 1); // Copy over the logits data into the staging buffer, only if // we do not plan to re-use the logits buffer directly @@ -605,7 +608,7 @@ Status GreedySearchProcessLogits( gsl::span& presence_mask = sampling_state->d_presence_mask; if (step == 1 && parameters->presence_mask.data() != nullptr) { CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(presence_mask.data(), parameters->presence_mask.data(), - sizeof(int) * batch_size * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream)); + sizeof(int) * batch_size * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream)); } // TODO(hasesh): Can we avoid the const_cast by changing the interface of @@ -622,10 +625,11 @@ Status GreedySearchProcessLogits( parameters->temperature, parameters->batch_size, parameters->num_beams, - parameters->vocab_size, - is_reuse_logits_buffer ? padded_vocab_size : parameters->vocab_size, + vocab_size, + is_reuse_logits_buffer ? padded_vocab_size : vocab_size, (parameters->min_length > 0 && current_sequence_length < parameters->sequence_length + parameters->min_length) - ? parameters->eos_token_id : -1, + ? parameters->eos_token_id + : -1, reinterpret_cast(sequences_buffer.get()), parameters->max_length, current_sequence_length, @@ -686,8 +690,8 @@ Status GreedySearchProcessLogits( *topk_scores, *topk_indices)); #ifdef DEBUG_GENERATION - dumper->Print("topk_scores", *(topk_scores.get())); - dumper->Print("topk_indices", *(topk_indices.get())); + dumper->Print("topk_scores", *(topk_scores.get())); + dumper->Print("topk_indices", *(topk_indices.get())); #endif const int64_t* next_token_indices = topk_indices->Data(); diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 8a92e63425..2f44675bb2 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -2030,6 +2030,11 @@ class SymbolicShapeInference: def _infer_SkipLayerNormalization(self, node): self._propagate_shape_and_type(node) + # If the SkipLayerNormalization node contains the optional + # output for inference, infer the shape and type for it too + if len(node.output) > 3: + self._propagate_shape_and_type(node, 0, 3) + def _infer_PythonOp(self, node): output_tensor_types = get_attribute(node, "output_tensor_types") assert output_tensor_types @@ -2212,6 +2217,11 @@ class SymbolicShapeInference: self._check_merged_dims(in_dims, allow_broadcast=True) for i_o in range(len(node.output)): + # Special case: We do not care about the training related + # outputs of SkipLayerNormalization + if node.op_type == "SkipLayerNormalization" and i_o in [1, 2]: + continue + vi = self.known_vi_[node.output[i_o]] out_type = vi.type out_type_kind = out_type.WhichOneof("value") diff --git a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py index 4f922820bb..92197e7e4f 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py +++ b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py @@ -51,8 +51,18 @@ class Gpt2OnnxModel(BertOnnxModel): [0, 0], output_name_to_node, ) + if nodes is None: - continue + nodes = self.match_parent_path( + gemm_node, + ["Reshape", "SkipLayerNormalization"], + [0, 0], + output_name_to_node, + ) + + if nodes is None: + continue + (reshape_before_gemm, root_node) = nodes matmul_node_name = self.create_node_name("MatMul", "FullyConnect_MatMul")