mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Misc transformer fixes (#14103)
### Description 1. SkipLayerNormalization has a new output (https://github.com/microsoft/onnxruntime/pull/13988) and the symbolic shape inference script needs corresponding updates 2. The greedy sampling op (https://github.com/microsoft/onnxruntime/pull/13426) shouldn't re-use the logits buffer as its corresponding kernel doesn't seem to support it yet. ### Motivation and Context Fix some transformer issues
This commit is contained in:
parent
e7f9d40dde
commit
d43e0ec9ba
3 changed files with 35 additions and 11 deletions
|
|
@ -369,8 +369,8 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
parameters->temperature,
|
||||
parameters->batch_size,
|
||||
parameters->num_beams,
|
||||
parameters->vocab_size,
|
||||
parameters->vocab_size,
|
||||
vocab_size,
|
||||
vocab_size,
|
||||
(parameters->min_length > 0 && current_sequence_length < parameters->min_length) ? parameters->eos_token_id : -1,
|
||||
reinterpret_cast<int32_t*>(sequences_buffer.get()),
|
||||
parameters->max_length,
|
||||
|
|
@ -508,7 +508,7 @@ template <typename T>
|
|||
Status GreedySearchProcessLogits(
|
||||
const OrtValue& logits, // logits output of subgraph
|
||||
transformers::IGreedySearchState<T>* greedy_state, // state
|
||||
transformers::ISamplingState<T>* sampling_state, // buffers
|
||||
transformers::ISamplingState<T>* sampling_state, // buffers
|
||||
transformers::ISequences* sequences, // sequences
|
||||
AllocatorPtr& allocator, // default allocator
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
|
||||
|
|
@ -556,7 +556,10 @@ Status GreedySearchProcessLogits(
|
|||
// In greedy search, next_token_scores is next_token_logits.
|
||||
gsl::span<T>& next_token_scores = greedy_state->next_token_scores;
|
||||
|
||||
auto is_reuse_logits_buffer = (input_length == 1);
|
||||
// TODO(hasesh/wy): Support re-using logits buffer for the sampling case.
|
||||
// Currently, we cannot re-use the logits because the sampling logic expects
|
||||
// `next_token_scores` to be populated.
|
||||
auto is_reuse_logits_buffer = !do_sampling && (input_length == 1);
|
||||
|
||||
// Copy over the logits data into the staging buffer, only if
|
||||
// we do not plan to re-use the logits buffer directly
|
||||
|
|
@ -605,7 +608,7 @@ Status GreedySearchProcessLogits(
|
|||
gsl::span<int>& presence_mask = sampling_state->d_presence_mask;
|
||||
if (step == 1 && parameters->presence_mask.data() != nullptr) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(presence_mask.data(), parameters->presence_mask.data(),
|
||||
sizeof(int) * batch_size * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream));
|
||||
sizeof(int) * batch_size * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream));
|
||||
}
|
||||
|
||||
// TODO(hasesh): Can we avoid the const_cast by changing the interface of
|
||||
|
|
@ -622,10 +625,11 @@ Status GreedySearchProcessLogits(
|
|||
parameters->temperature,
|
||||
parameters->batch_size,
|
||||
parameters->num_beams,
|
||||
parameters->vocab_size,
|
||||
is_reuse_logits_buffer ? padded_vocab_size : parameters->vocab_size,
|
||||
vocab_size,
|
||||
is_reuse_logits_buffer ? padded_vocab_size : vocab_size,
|
||||
(parameters->min_length > 0 && current_sequence_length < parameters->sequence_length + parameters->min_length)
|
||||
? parameters->eos_token_id : -1,
|
||||
? parameters->eos_token_id
|
||||
: -1,
|
||||
reinterpret_cast<int32_t*>(sequences_buffer.get()),
|
||||
parameters->max_length,
|
||||
current_sequence_length,
|
||||
|
|
@ -686,8 +690,8 @@ Status GreedySearchProcessLogits(
|
|||
*topk_scores, *topk_indices));
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("topk_scores", *(topk_scores.get()));
|
||||
dumper->Print("topk_indices", *(topk_indices.get()));
|
||||
dumper->Print("topk_scores", *(topk_scores.get()));
|
||||
dumper->Print("topk_indices", *(topk_indices.get()));
|
||||
#endif
|
||||
|
||||
const int64_t* next_token_indices = topk_indices->Data<int64_t>();
|
||||
|
|
|
|||
|
|
@ -2030,6 +2030,11 @@ class SymbolicShapeInference:
|
|||
def _infer_SkipLayerNormalization(self, node):
|
||||
self._propagate_shape_and_type(node)
|
||||
|
||||
# If the SkipLayerNormalization node contains the optional
|
||||
# output for inference, infer the shape and type for it too
|
||||
if len(node.output) > 3:
|
||||
self._propagate_shape_and_type(node, 0, 3)
|
||||
|
||||
def _infer_PythonOp(self, node):
|
||||
output_tensor_types = get_attribute(node, "output_tensor_types")
|
||||
assert output_tensor_types
|
||||
|
|
@ -2212,6 +2217,11 @@ class SymbolicShapeInference:
|
|||
self._check_merged_dims(in_dims, allow_broadcast=True)
|
||||
|
||||
for i_o in range(len(node.output)):
|
||||
# Special case: We do not care about the training related
|
||||
# outputs of SkipLayerNormalization
|
||||
if node.op_type == "SkipLayerNormalization" and i_o in [1, 2]:
|
||||
continue
|
||||
|
||||
vi = self.known_vi_[node.output[i_o]]
|
||||
out_type = vi.type
|
||||
out_type_kind = out_type.WhichOneof("value")
|
||||
|
|
|
|||
|
|
@ -51,8 +51,18 @@ class Gpt2OnnxModel(BertOnnxModel):
|
|||
[0, 0],
|
||||
output_name_to_node,
|
||||
)
|
||||
|
||||
if nodes is None:
|
||||
continue
|
||||
nodes = self.match_parent_path(
|
||||
gemm_node,
|
||||
["Reshape", "SkipLayerNormalization"],
|
||||
[0, 0],
|
||||
output_name_to_node,
|
||||
)
|
||||
|
||||
if nodes is None:
|
||||
continue
|
||||
|
||||
(reshape_before_gemm, root_node) = nodes
|
||||
|
||||
matmul_node_name = self.create_node_name("MatMul", "FullyConnect_MatMul")
|
||||
|
|
|
|||
Loading…
Reference in a new issue