Zhijxu/fix softmax cudnn bf16 (#21045)

If seq > 2048, ORT falls back to the cuDNN version of softmax; however, when the
dtype is bf16, ORT throws an exception on that path. This PR fixes that.
This commit is contained in:
zhijiang 2024-06-24 16:07:39 +08:00 committed by GitHub
parent 5b5ce0bfb0
commit 269d9b094f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 37 additions and 1 deletions

View file

@ -174,7 +174,11 @@ cudnnDataType_t CudnnTensor::GetDataType<half>() {
// Specialization of CudnnTensor::GetDataType for BFloat16.
// Returns CUDNN_DATA_BFLOAT16 when CUDNN_VERSION is defined and at least 8200;
// in every other build configuration it throws via ORT_THROW instead.
template <>
cudnnDataType_t CudnnTensor::GetDataType<BFloat16>() {
// The branch is selected by the preprocessor, so it is fixed when this file is built.
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200
return CUDNN_DATA_BFLOAT16;
#else
// CUDNN_VERSION is missing or below 8200: the build still succeeds and the
// unsupported type is reported only if this specialization is actually called.
ORT_THROW("cuDNN doesn't support BFloat16.");
#endif
}
template <>

View file

@ -70,7 +70,7 @@ class RandomValueGenerator {
// Random values generated are in the range [min, max).
template <typename TFloat16>
typename std::enable_if<
std::is_same_v<TFloat16, MLFloat16>,
std::is_same_v<TFloat16, MLFloat16> || std::is_same_v<TFloat16, BFloat16>,
std::vector<TFloat16>>::type
Uniform(gsl::span<const int64_t> dims, float min, float max) {
std::vector<TFloat16> val(detail::SizeFromDims(dims));

View file

@ -146,6 +146,38 @@ class TestOnnxOpsOrtModule(unittest.TestCase):
device = torch.device(device_name)
self.gradient_correctness(name, device)
# Check CUDA availability first: the decorator argument is evaluated when the
# class body is executed, and is_bf16_supported() should not be queried on a
# machine without a usable CUDA device. `and` short-circuits, so it is only
# called when CUDA is present.
@unittest.skipIf(
    not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
    "Test requires CUDA and BF16 support",
)
def test_softmax_bf16_large(self):
    """Compare bf16 softmax input gradients between PyTorch and ORTModule.

    A [2, 4096] bfloat16 tensor is pushed through ``torch.softmax(dim=-1)``
    once with the plain torch module and once with the same module wrapped in
    ORTModule. Both runs are back-propagated with an all-ones upstream
    gradient, and the resulting input gradients must match within tolerance.

    NOTE(review): the 4096-wide last dimension is presumably chosen to hit the
    large-softmax code path in ORT -- confirm against the kernel dispatch.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, input):
            out = torch.softmax(input, dim=-1)
            return out

    device = "cuda:0"
    input_shape = [2, 4096]

    # run torch to get the expected result
    data_torch = torch.randn(size=input_shape, device=device, dtype=torch.bfloat16) + 10
    data_torch.requires_grad = True
    torch_model = Model()
    torch_res = torch_model(input=data_torch)
    init_grad = torch.ones_like(torch_res)
    torch_res.backward(gradient=init_grad)

    # run ort on an independent copy of the same input so that the two
    # backward passes accumulate into separate .grad tensors
    ort_model = ORTModule(torch_model)
    data_ort = data_torch.detach().clone()
    data_ort.requires_grad = True
    ort_res = ort_model(input=data_ort)
    ort_res.backward(gradient=init_grad)

    # compare result
    torch.testing.assert_close(data_torch.grad, data_ort.grad, rtol=1e-5, atol=1e-4)
# Run the unittest test runner when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()