mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
cast logits to half when T=MLFloat16 (#13454)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
b4a1ae8350
commit
df796bbb62
2 changed files with 19 additions and 5 deletions
|
|
@ -475,8 +475,8 @@ Status GreedySearchProcessLogits(
|
|||
cudaMemcpyHostToDevice, cuda_stream));
|
||||
}
|
||||
|
||||
cuda::LaunchLogitsProcessKernel<float>(
|
||||
reinterpret_cast<float*>(next_token_scores.data()),
|
||||
cuda::LaunchLogitsProcessKernel<CudaT>(
|
||||
reinterpret_cast<CudaT*>(next_token_scores.data()),
|
||||
parameters->vocab_mask.data(),
|
||||
step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only.
|
||||
parameters->batch_size,
|
||||
|
|
|
|||
|
|
@ -36,13 +36,15 @@ class TestBeamSearchGpt(unittest.TestCase):
|
|||
f"-m {self.model_name}",
|
||||
f"--decoder_onnx {self.gpt2_onnx_path}",
|
||||
f"--output {self.beam_search_onnx_path}",
|
||||
"--output_sequences_score",
|
||||
"--repetition_penalty 2.0",
|
||||
]
|
||||
self.sentences = [
|
||||
"The product is released",
|
||||
"I enjoy walking in the park",
|
||||
"Test best way to invest",
|
||||
"The AI community building the future",
|
||||
"The selloff in tech shares deepened",
|
||||
"Abortion rights take centre stage",
|
||||
]
|
||||
self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()
|
||||
self.remove_onnx_files()
|
||||
|
|
@ -57,12 +59,18 @@ class TestBeamSearchGpt(unittest.TestCase):
|
|||
if os.path.exists(self.beam_search_onnx_path):
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
|
||||
def run_beam_search(self, extra_arguments: str, sentences=None, append_arguments=True):
|
||||
def run_beam_search(self, extra_arguments: str, sentences=None, append_arguments=True, is_greedy=False):
|
||||
|
||||
if append_arguments:
|
||||
arguments = " ".join(self.default_arguments + [extra_arguments]).split()
|
||||
else:
|
||||
arguments = extra_arguments.split()
|
||||
|
||||
if is_greedy:
|
||||
arguments.extend("--num_beams 1 --num_return_sequences 1".split())
|
||||
else:
|
||||
arguments.extend("--output_sequences_score".split())
|
||||
|
||||
# Test CPU
|
||||
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
|
||||
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}")
|
||||
|
|
@ -97,7 +105,13 @@ class TestBeamSearchGpt(unittest.TestCase):
|
|||
|
||||
@pytest.mark.slow
|
||||
def test_greedy_search(self):
|
||||
self.run_beam_search("--num_beams 1 --num_return_sequences 1")
|
||||
self.run_beam_search("", is_greedy=True)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_greedy_search_float16(self):
|
||||
# TODO: investigate fp16 parity issue for greedy/beam search with repetition_penalty != 1.0
|
||||
if self.enable_cuda:
|
||||
self.run_beam_search("--repetition_penalty 1.0 --use_gpu -p fp16", is_greedy=True)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_external_data(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue