mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
Zhijxu/fix softmax cudnn bf16 (#21045)
if seq >2048, ort will fallback to cudnn version, while when dtype is bf16, ort will throw exception, this PR trying to fix it.
This commit is contained in:
parent
5b5ce0bfb0
commit
269d9b094f
3 changed files with 37 additions and 1 deletions
|
|
@ -174,7 +174,11 @@ cudnnDataType_t CudnnTensor::GetDataType<half>() {
|
|||
|
||||
template <>
|
||||
cudnnDataType_t CudnnTensor::GetDataType<BFloat16>() {
|
||||
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200
|
||||
return CUDNN_DATA_BFLOAT16;
|
||||
#else
|
||||
ORT_THROW("cuDNN doesn't support BFloat16.");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class RandomValueGenerator {
|
|||
// Random values generated are in the range [min, max).
|
||||
template <typename TFloat16>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<TFloat16, MLFloat16>,
|
||||
std::is_same_v<TFloat16, MLFloat16> || std::is_same_v<TFloat16, BFloat16>,
|
||||
std::vector<TFloat16>>::type
|
||||
Uniform(gsl::span<const int64_t> dims, float min, float max) {
|
||||
std::vector<TFloat16> val(detail::SizeFromDims(dims));
|
||||
|
|
|
|||
|
|
@ -146,6 +146,38 @@ class TestOnnxOpsOrtModule(unittest.TestCase):
|
|||
device = torch.device(device_name)
|
||||
self.gradient_correctness(name, device)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_bf16_supported(), "Test requires CUDA and BF16 support")
|
||||
def test_softmax_bf16_large(self):
|
||||
if not torch.cuda.is_available():
|
||||
# only test bf16 on cuda
|
||||
return
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, input):
|
||||
out = torch.softmax(input, dim=-1)
|
||||
return out
|
||||
|
||||
device = "cuda:0"
|
||||
input_shape = [2, 4096]
|
||||
# run torch to get the expected result
|
||||
data_torch = torch.randn(size=input_shape, device=device, dtype=torch.bfloat16) + 10
|
||||
data_torch.requires_grad = True
|
||||
torch_model = Model()
|
||||
torch_res = torch_model(input=data_torch)
|
||||
init_grad = torch.ones_like(torch_res)
|
||||
torch_res.backward(gradient=init_grad)
|
||||
# run ort
|
||||
ort_model = ORTModule(torch_model)
|
||||
data_ort = data_torch.detach().clone()
|
||||
data_ort.requires_grad = True
|
||||
ort_res = ort_model(input=data_ort)
|
||||
ort_res.backward(gradient=init_grad)
|
||||
# compara result
|
||||
torch.testing.assert_close(data_torch.grad, data_ort.grad, rtol=1e-5, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue