cast logits to half when T=MLFloat16 (#13454)

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Ye Wang 2022-11-03 16:40:19 -07:00 committed by GitHub
parent b4a1ae8350
commit df796bbb62
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 5 deletions

View file

@ -475,8 +475,8 @@ Status GreedySearchProcessLogits(
cudaMemcpyHostToDevice, cuda_stream));
}
cuda::LaunchLogitsProcessKernel<float>(
reinterpret_cast<float*>(next_token_scores.data()),
cuda::LaunchLogitsProcessKernel<CudaT>(
reinterpret_cast<CudaT*>(next_token_scores.data()),
parameters->vocab_mask.data(),
step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only.
parameters->batch_size,

View file

@ -36,13 +36,15 @@ class TestBeamSearchGpt(unittest.TestCase):
f"-m {self.model_name}",
f"--decoder_onnx {self.gpt2_onnx_path}",
f"--output {self.beam_search_onnx_path}",
"--output_sequences_score",
"--repetition_penalty 2.0",
]
self.sentences = [
"The product is released",
"I enjoy walking in the park",
"Test best way to invest",
"The AI community building the future",
"The selloff in tech shares deepened",
"Abortion rights take centre stage",
]
self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()
self.remove_onnx_files()
@ -57,12 +59,18 @@ class TestBeamSearchGpt(unittest.TestCase):
if os.path.exists(self.beam_search_onnx_path):
os.remove(self.beam_search_onnx_path)
def run_beam_search(self, extra_arguments: str, sentences=None, append_arguments=True):
def run_beam_search(self, extra_arguments: str, sentences=None, append_arguments=True, is_greedy=False):
if append_arguments:
arguments = " ".join(self.default_arguments + [extra_arguments]).split()
else:
arguments = extra_arguments.split()
if is_greedy:
arguments.extend("--num_beams 1 --num_return_sequences 1".split())
else:
arguments.extend("--output_sequences_score".split())
# Test CPU
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}")
@ -97,7 +105,13 @@ class TestBeamSearchGpt(unittest.TestCase):
@pytest.mark.slow
def test_greedy_search(self):
self.run_beam_search("--num_beams 1 --num_return_sequences 1")
self.run_beam_search("", is_greedy=True)
@pytest.mark.slow
def test_greedy_search_float16(self):
# TODO: investigate fp16 parity issue for greedy/beam search with repetition_penalty != 1.0
if self.enable_cuda:
self.run_beam_search("--repetition_penalty 1.0 --use_gpu -p fp16", is_greedy=True)
@pytest.mark.slow
def test_external_data(self):