From df796bbb6236100fa2de3ef303cf9ea12bc91d64 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Thu, 3 Nov 2022 16:40:19 -0700 Subject: [PATCH] cast logits to half when T=MLFloat16 (#13454) ### Description ### Motivation and Context --- .../transformers/generation_device_helper.cc | 4 ++-- .../python/transformers/test_generation.py | 20 ++++++++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 6f5e59bbcc..13915081e1 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -475,8 +475,8 @@ Status GreedySearchProcessLogits( cudaMemcpyHostToDevice, cuda_stream)); } - cuda::LaunchLogitsProcessKernel<float>( - reinterpret_cast<float*>(next_token_scores.data()), + cuda::LaunchLogitsProcessKernel<CudaT>( + reinterpret_cast<CudaT*>(next_token_scores.data()), parameters->vocab_mask.data(), step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. 
parameters->batch_size, diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index d42b36fc96..26fe6c333c 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -36,13 +36,15 @@ class TestBeamSearchGpt(unittest.TestCase): f"-m {self.model_name}", f"--decoder_onnx {self.gpt2_onnx_path}", f"--output {self.beam_search_onnx_path}", - "--output_sequences_score", "--repetition_penalty 2.0", ] self.sentences = [ "The product is released", "I enjoy walking in the park", "Test best way to invest", + "The AI community building the future", + "The selloff in tech shares deepened", + "Abortion rights take centre stage", ] self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers() self.remove_onnx_files() @@ -57,12 +59,18 @@ class TestBeamSearchGpt(unittest.TestCase): if os.path.exists(self.beam_search_onnx_path): os.remove(self.beam_search_onnx_path) - def run_beam_search(self, extra_arguments: str, sentences=None, append_arguments=True): + def run_beam_search(self, extra_arguments: str, sentences=None, append_arguments=True, is_greedy=False): + if append_arguments: arguments = " ".join(self.default_arguments + [extra_arguments]).split() else: arguments = extra_arguments.split() + if is_greedy: + arguments.extend("--num_beams 1 --num_return_sequences 1".split()) + else: + arguments.extend("--output_sequences_score".split()) + # Test CPU result = run(arguments, sentences=self.sentences if sentences is None else sentences) self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}") @@ -97,7 +105,13 @@ class TestBeamSearchGpt(unittest.TestCase): @pytest.mark.slow def test_greedy_search(self): - self.run_beam_search("--num_beams 1 --num_return_sequences 1") + self.run_beam_search("", is_greedy=True) + + @pytest.mark.slow + def 
test_greedy_search_float16(self): + # TODO: investigate fp16 parity issue for greedy/beam search with repetition_penalty != 1.0 + if self.enable_cuda: + self.run_beam_search("--repetition_penalty 1.0 --use_gpu -p fp16", is_greedy=True) @pytest.mark.slow def test_external_data(self):