From ed54a5d06bd5a7bd14bac58b956845c4cd292f68 Mon Sep 17 00:00:00 2001 From: "haozhe.zhu" Date: Sun, 12 Feb 2023 00:05:09 +0000 Subject: [PATCH] enable bf16 emb (#94163) Merge https://github.com/pytorch/pytorch/pull/89199 and https://github.com/pytorch/pytorch/pull/91949 into one PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94163 Approved by: https://github.com/jianyuh, https://github.com/malfet, https://github.com/jgong5 --- aten/src/ATen/native/EmbeddingBag.cpp | 390 +++++++++++------- aten/src/ATen/native/EmbeddingBag.h | 16 +- test/nn/test_embedding.py | 27 +- test/test_meta.py | 2 +- third_party/fbgemm | 2 +- .../_internal/common_methods_invocations.py | 2 +- 6 files changed, 273 insertions(+), 166 deletions(-) diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index 48537aacbdc..6a0ee75d814 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -1,10 +1,11 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include #include #include #include -#include #include +#include +#include +#include #include #include @@ -86,14 +87,20 @@ std::pair promoteIndicesAndOffsets( // is only applicable if special conditions are met template bool is_fast_path_index_select(const Tensor& src, Tensor& output, index_t padding_idx) { - return (src.scalar_type() == kFloat || src.scalar_type() == kHalf) && src.strides()[1] == 1 && output.strides()[1] == 1 && padding_idx < static_cast(0); + return (src.scalar_type() == kFloat || src.scalar_type() == kHalf || + src.scalar_type() == kBFloat16) && + src.strides()[1] == 1 && output.strides()[1] == 1 && + padding_idx < static_cast(0); } // Determines if we can use a fast implementation for index_select_scale_add, // which is only applicable if special conditions are met template bool is_fast_path_index_select_scale(const Tensor& src, const Tensor& scale, Tensor& output, index_t padding_idx) { - return (src.scalar_type() == kFloat || src.scalar_type() == kHalf) && src.strides()[1] == 1 && output.strides()[1] == 1 && scale.strides()[0] == 1 && padding_idx < static_cast(0); + return (src.scalar_type() == kFloat || src.scalar_type() == kHalf || + src.scalar_type() == kBFloat16) && + src.strides()[1] == 1 && output.strides()[1] == 1 && + scale.strides()[0] == 1 && padding_idx < static_cast(0); } template @@ -106,17 +113,18 @@ bool is_fast_path(const Tensor& src, const c10::optional& scale, Tensor& // This function combines index_select (using select_indices as the index) and // index_add (using add_indices as the index), without creating an intermediary // tensor to hold the selected embeddings -template -typename std::enable_if::value && !std::is_same::value, void>::type -index_select_add(const Tensor &select_indices, - const Tensor &add_indices, - const Tensor &src, - Tensor &output, - const Tensor& /*offsets*/, - bool /*include_last_offset*/, - Tensor &bag_size, - index_t padding_idx, - _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) { +template +static typename std::enable_if::value, void>::type +index_select_add( + const Tensor& select_indices, + const Tensor& add_indices, + const Tensor& src, + Tensor& output, + const Tensor& /*offsets*/, + bool /*include_last_offset*/, + Tensor& bag_size, + index_t padding_idx, + _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) { TORCH_CHECK(select_indices.numel() == add_indices.numel()); auto* add_indices_data = add_indices.data_ptr(); auto* select_indices_data = select_indices.data_ptr(); @@ -184,24 +192,28 @@ void 
fbgemm_spmdm_report_error_( } } // namespace -template -typename std::enable_if::value, void>::type -index_select_add(const Tensor &select_indices, - const Tensor &add_indices, - const Tensor &src, - Tensor &output, - const Tensor& offsets, - bool include_last_offset, - Tensor &bag_size, - index_t padding_idx, - _EmbeddingBagKernelCache* fbgemm_kernel_cache) { +template +typename std::enable_if< + std::is_same::value || + std::is_same::value, + void>::type +index_select_add( + const Tensor& select_indices, + const Tensor& add_indices, + const Tensor& src, + Tensor& output, + const Tensor& offsets, + bool include_last_offset, + Tensor& bag_size, + index_t padding_idx, + _EmbeddingBagKernelCache* fbgemm_kernel_cache) { int64_t ddim = src.size(1); auto* select_indices_data = select_indices.data_ptr(); - auto* output_data = output.data_ptr(); + auto* output_data = output.data_ptr(); if (is_fast_path_index_select(src, output, padding_idx)) { auto src_contig = src.contiguous(); - auto* src_data = src_contig.data_ptr(); + auto* src_data = src_contig.data_ptr(); int64_t output_size = offsets.numel() - 1; auto* offsets_data = offsets.data_ptr(); std::vector offsets_include_last; @@ -220,36 +232,31 @@ index_select_add(const Tensor &select_indices, offsets_include_last[offsets.numel()] = select_indices.numel(); offsets_data = offsets_include_last.data(); } - -#ifdef USE_FBGEMM - using float16 = uint16_t; - auto kernel_fp16_index_t = fbgemm_kernel_cache ? - fbgemm_kernel_cache->getCallback(ddim) : - fbgemm::GenerateEmbeddingSpMDM( - /* block_size */ddim, - /* has_weight */false, - /* normalize_by_lengths */false, - /* prefetch */16, - /* is_weight_positional */false, - /* use_offsets */true - ); -#else - // Initialize the intermediate output buffer to be 0. - Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat)); - auto* output_data_fp32 = output_fp32.data_ptr(); -#endif +#if defined(USE_FBGEMM) + bool isbf16 = std::is_same::value ? false : true; + auto kernel_16bit_index_t = fbgemm_kernel_cache + ? fbgemm_kernel_cache + ->getCallback(ddim) + : fbgemm::GenerateEmbeddingSpMDM( + /* block_size */ ddim, + /* has_weight */ false, + /* normalize_by_lengths */ false, + /* prefetch */ 16, + /* is_weight_positional */ false, + /* use_offsets */ true, + /* isbf16*/ isbf16); at::parallel_for( 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { -#ifdef USE_FBGEMM - bool success = kernel_fp16_index_t( - /* output_size */end_idx - start_idx, - /* index_size */offsets_data[end_idx] - offsets_data[start_idx], - /* data_size */src.size(0), - /* input */reinterpret_cast(src_data), - /* indices */select_indices_data + offsets_data[start_idx], - /* offsets_or_lengths */offsets_data + start_idx, - /* weights */nullptr, - /* output */reinterpret_cast(output_data + start_idx * ddim)); + bool success = kernel_16bit_index_t( + /* output_size */ end_idx - start_idx, + /* index_size */ offsets_data[end_idx] - offsets_data[start_idx], + /* data_size */ src.size(0), + /* input */ reinterpret_cast(src_data), + /* indices */ select_indices_data + offsets_data[start_idx], + /* offsets_or_lengths */ offsets_data + start_idx, + /* weights */ nullptr, + /* output */ + reinterpret_cast(output_data + start_idx * ddim)); if (!success) { fbgemm_spmdm_report_error_( end_idx - start_idx, @@ -258,7 +265,15 @@ index_select_add(const Tensor &select_indices, offsets_data + start_idx, select_indices_data + offsets_data[start_idx]); } + }); #else + // Initialize the intermediate output buffer to be 0. 
+ Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat)); + auto* output_data_fp32 = output_fp32.data_ptr(); + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + at::parallel_for( + 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { caffe2::EmbeddingLookupIdx( /*block_size=*/ddim, /*output_size=*/end_idx - start_idx, @@ -271,18 +286,36 @@ index_select_add(const Tensor &select_indices, /*scale_bias=*/nullptr, /*normalize_by_lengths=*/false, /*out=*/output_data_fp32 + start_idx * ddim); - for (const auto i : c10::irange(output_size)) { - // Convert FP32 intermediate buffer result back to FP16 for output dtype - for (const auto d : c10::irange(ddim)) { - (output_data + i * ddim)[d] = static_cast((output_data_fp32 + ddim * i)[d]); + for (int64_t i = start_idx; i < end_idx; i++) { + // Convert FP32 intermediate buffer result back to 16 bit for + // output dtype + if (std::is_same::value) { + // FP16 + for (const auto d : c10::irange(ddim)) { + (output_data + i * ddim)[d] = + static_cast((output_data_fp32 + ddim * i)[d]); + } + } else { + // BF16 + int64_t d = 0; + for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) { + fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d); + fVec temp_fp32_1 = + fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size()); + convert_float_bfloat16(temp_fp32_0, temp_fp32_1) + .store(output_data + i * ddim + d); + } + for (; d < ddim; d++) { + (output_data + i * ddim)[d] = + static_cast((output_data_fp32 + ddim * i)[d]); + } } } -#endif }); - +#endif } else { TORCH_CHECK(select_indices.numel() == add_indices.numel()); - auto* src_data = src.data_ptr(); + auto* src_data = src.data_ptr(); auto* add_indices_data = add_indices.data_ptr(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) index_t* bag_size_data = nullptr; @@ -300,7 +333,8 @@ index_select_add(const Tensor &select_indices, auto* src_data_fp32 = src_fp32.data_ptr(); // Initialize the intermediate output buffer to be 0. 
- Tensor output_fp32 = at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat)); + Tensor output_fp32 = + at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat)); auto* output_data_fp32 = output_fp32.data_ptr(); for (const auto i : c10::irange(numel)) { @@ -314,11 +348,16 @@ index_select_add(const Tensor &select_indices, if (idx != padding_idx) { // Copy src_data + src_stride0 * idx to src_data_fp32 for (const auto d : c10::irange(ddim)) { - src_data_fp32[d] = static_cast((src_data + src_stride0 * idx)[d * src_stride1]); + src_data_fp32[d] = static_cast( + (src_data + src_stride0 * idx)[d * src_stride1]); } - at::native::cpublas::axpy(ddim, 1, - src_data_fp32, 1, - output_data_fp32 + ddim * add_indices_data[i], 1); + at::native::cpublas::axpy( + ddim, + 1, + src_data_fp32, + 1, + output_data_fp32 + ddim * add_indices_data[i], + 1); } else if (bag_size.defined()) { // Decrement bag_size to reflect that the index is padded @@ -327,14 +366,15 @@ index_select_add(const Tensor &select_indices, } } for (const auto i : c10::irange(output.size(0))) { - // Convert FP32 intermediate buffer result back to FP16 for output dtype + // Convert FP32 intermediate buffer result back to 16 bit for output + // dtype for (const auto d : c10::irange(ddim)) { - (output_data + output_stride0 * i)[d * output_stride1] = static_cast((output_data_fp32 + ddim * i)[d]); + (output_data + output_stride0 * i)[d * output_stride1] = + static_cast((output_data_fp32 + ddim * i)[d]); } } } } - template typename std::enable_if::value, void>::type index_select_add(const Tensor &select_indices, @@ -464,18 +504,19 @@ index_select_add(const Tensor &select_indices, // index_select (using select_indices as the index) // mul (scaling by per_sample_weights) // index_add (using add_indices as the index) -template -static typename std::enable_if::value && !std::is_same::value, void>::type -index_select_scale_add(const Tensor &select_indices, - const Tensor &add_indices, - const Tensor &scale, - const Tensor &src, - Tensor &output, - const Tensor& /*offsets*/, - bool /*include_last_offset*/, - Tensor &bag_size, - index_t padding_idx, - _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) { +template +static typename std::enable_if::value, void>::type +index_select_scale_add( + const Tensor& select_indices, + const Tensor& add_indices, + const Tensor& scale, + const Tensor& src, + Tensor& output, + const Tensor& /*offsets*/, + bool /*include_last_offset*/, + Tensor& bag_size, + index_t padding_idx, + _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) { AT_ASSERT(select_indices.numel() == add_indices.numel()); auto* add_indices_data = add_indices.data_ptr(); auto* select_indices_data = select_indices.data_ptr(); @@ -520,26 +561,30 @@ index_select_scale_add(const Tensor &select_indices, } } -template -typename std::enable_if::value, void>::type -index_select_scale_add(const Tensor &select_indices, - const Tensor &add_indices, - const Tensor &scale, - const Tensor &src, - Tensor &output, - const Tensor& offsets, - bool include_last_offset, - Tensor &bag_size, - index_t padding_idx, - _EmbeddingBagKernelCache* fbgemm_kernel_cache) { +template +typename std::enable_if< + std::is_same::value || + std::is_same::value, + void>::type +index_select_scale_add( + const Tensor& select_indices, + const Tensor& add_indices, + const Tensor& scale, + const Tensor& src, + Tensor& output, + const Tensor& offsets, + bool include_last_offset, + Tensor& bag_size, + index_t padding_idx, + _EmbeddingBagKernelCache* 
fbgemm_kernel_cache) { int64_t ddim = src.size(1); - auto* scale_data = scale.data_ptr(); + auto* scale_data = scale.data_ptr(); auto* select_indices_data = select_indices.data_ptr(); - auto* output_data = output.data_ptr(); + auto* output_data = output.data_ptr(); if (is_fast_path_index_select_scale(src, scale, output, padding_idx)) { auto src_contig = src.contiguous(); - auto* src_data = src_contig.data_ptr(); + auto* src_data = src_contig.data_ptr(); int64_t output_size = offsets.numel() - 1; auto* offsets_data = offsets.data_ptr(); std::vector offsets_include_last; @@ -560,40 +605,42 @@ index_select_scale_add(const Tensor &select_indices, Tensor scale_fp32 = at::empty(scale.sizes(), scale.options().dtype(at::kFloat)); auto* scale_data_fp32 = scale_fp32.data_ptr(); -#ifdef USE_FBGEMM - using float16 = uint16_t; - fbgemm::Float16ToFloat_simd(reinterpret_cast(scale_data), scale_data_fp32, scale_fp32.numel()); - auto kernel_fp16_index_t = - fbgemm_kernel_cache ? - fbgemm_kernel_cache->getCallback(ddim) : - fbgemm::GenerateEmbeddingSpMDM( - /* block_size */ddim, - /* has_weight */true, - /* normalize_by_lengths */false, - /* prefetch */16, - /* is_weight_positional */false, - /* use_offsets */true - ); -#else - // Initialize the intermediate output buffer to be 0. - Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat)); - auto* output_data_fp32 = output_fp32.data_ptr(); - for (const auto i : c10::irange(scale.numel())) { - scale_data_fp32[i] = static_cast(scale_data[i]); +#if defined(USE_FBGEMM) + bool isbf16 = std::is_same::value ? false : true; + if (isbf16) { + fbgemm::Bfloat16ToFloat_simd( + reinterpret_cast(scale_data), + scale_data_fp32, + scale_fp32.numel()); + } else { + fbgemm::Float16ToFloat_simd( + reinterpret_cast(scale_data), + scale_data_fp32, + scale_fp32.numel()); } -#endif + auto kernel_16bit_index_t = fbgemm_kernel_cache + ? fbgemm_kernel_cache + ->getCallback(ddim) + : fbgemm::GenerateEmbeddingSpMDM( + /* block_size */ ddim, + /* has_weight */ true, + /* normalize_by_lengths */ false, + /* prefetch */ 16, + /* is_weight_positional */ false, + /* use_offsets */ true, + /* isbf16*/ isbf16); at::parallel_for( 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { -#ifdef USE_FBGEMM - bool success = kernel_fp16_index_t( - /* output_size */end_idx - start_idx, - /* index_size */offsets_data[end_idx] - offsets_data[start_idx], - /* data_size */src.size(0), - /* input */reinterpret_cast(src_data), - /* indices */select_indices_data + offsets_data[start_idx], - /* offsets_or_lengths */offsets_data + start_idx, - /* weights */scale_data_fp32 + offsets_data[start_idx], - /* output */reinterpret_cast(output_data + start_idx * ddim)); + bool success = kernel_16bit_index_t( + /* output_size */ end_idx - start_idx, + /* index_size */ offsets_data[end_idx] - offsets_data[start_idx], + /* data_size */ src.size(0), + /* input */ reinterpret_cast(src_data), + /* indices */ select_indices_data + offsets_data[start_idx], + /* offsets_or_lengths */ offsets_data + start_idx, + /* weights */ scale_data_fp32 + offsets_data[start_idx], + /* output */ + reinterpret_cast(output_data + start_idx * ddim)); if (!success) { fbgemm_spmdm_report_error_( end_idx - start_idx, @@ -602,7 +649,19 @@ index_select_scale_add(const Tensor &select_indices, offsets_data + start_idx, select_indices_data + offsets_data[start_idx]); } + }); #else + // Initialize the intermediate output buffer to be 0. 
+ Tensor output_fp32 = + at::zeros({output_size, ddim}, output.options().dtype(at::kFloat)); + auto* output_data_fp32 = output_fp32.data_ptr(); + for (const auto i : c10::irange(scale.numel())) { + scale_data_fp32[i] = static_cast(scale_data[i]); + } + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + at::parallel_for( + 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { caffe2::EmbeddingLookupIdx( /*block_size=*/ddim, /*output_size=*/end_idx - start_idx, @@ -615,17 +674,36 @@ index_select_scale_add(const Tensor &select_indices, /*scale_bias=*/nullptr, /*normalize_by_lengths=*/false, /*out=*/output_data_fp32 + start_idx * ddim); - for (const auto i : c10::irange(output_size)) { - // Convert FP32 intermediate buffer result back to FP16 for output dtype - for (const auto d : c10::irange(ddim)) { - (output_data + i * ddim)[d] = static_cast((output_data_fp32 + ddim * i)[d]); + for (int64_t i = start_idx; i < end_idx; i++) { + // Convert FP32 intermediate buffer result back to 16 bit for + // output dtype + if (std::is_same::value) { + // FP16 + for (const auto d : c10::irange(ddim)) { + (output_data + i * ddim)[d] = + static_cast((output_data_fp32 + ddim * i)[d]); + } + } else { + // BF16 + int64_t d = 0; + for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) { + fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d); + fVec temp_fp32_1 = + fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size()); + convert_float_bfloat16(temp_fp32_0, temp_fp32_1) + .store(output_data + i * ddim + d); + } + for (; d < ddim; d++) { + (output_data + i * ddim)[d] = + static_cast((output_data_fp32 + ddim * i)[d]); + } } } -#endif }); +#endif } else { AT_ASSERT(select_indices.numel() == add_indices.numel()); - auto* src_data = src.data_ptr(); + auto* src_data = src.data_ptr(); auto* add_indices_data = add_indices.data_ptr(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) index_t* bag_size_data = nullptr; @@ -641,7 +719,8 @@ index_select_scale_add(const Tensor &select_indices, auto numel = add_indices.numel(); // Initialize the intermediate output buffer to be 0. 
- Tensor output_fp32 = at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat)); + Tensor output_fp32 = + at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat)); auto* output_data_fp32 = output_fp32.data_ptr(); for (const auto i : c10::irange(numel)) { @@ -653,12 +732,12 @@ index_select_scale_add(const Tensor &select_indices, "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ", idx); if (idx != padding_idx) { - auto* src_base = src_data + src_stride0 * idx; auto* output_base_fp32 = output_data_fp32 + ddim * add_indices_data[i]; auto scale = scale_data[i * scale_stride]; for (const auto j : c10::irange(ddim)) { - output_base_fp32[j] += static_cast(src_base[j * src_stride1]) * static_cast(scale); + output_base_fp32[j] += static_cast(src_base[j * src_stride1]) * + static_cast(scale); } } else if (bag_size.defined()) { // Decrement bag_size to reflect that the index is padded @@ -667,14 +746,15 @@ index_select_scale_add(const Tensor &select_indices, } } for (const auto i : c10::irange(output.size(0))) { - // Convert FP32 intermediate buffer result back to FP16 for output dtype + // Convert FP32 intermediate buffer result back to 16 bit for output + // dtype for (const auto d : c10::irange(ddim)) { - (output_data + output_stride0 * i)[d * output_stride1] = static_cast((output_data_fp32 + ddim * i)[d]); + (output_data + output_stride0 * i)[d * output_stride1] = + static_cast((output_data_fp32 + ddim * i)[d]); } } } } - template typename std::enable_if::value, void>::type index_select_scale_add(const Tensor &select_indices, @@ -817,7 +897,8 @@ void check_arguments( checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt}); checkSameType("embedding_bag", indices_arg, offsets_arg); auto weight_arg = TensorArg(weight, "weight", 1); - checkScalarTypes("embedding_bag", weight_arg, {kHalf, kFloat, kDouble}); + checkScalarTypes( + "embedding_bag", weight_arg, {kHalf, kBFloat16, kFloat, kDouble}); AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_embedding_bag_cpu_impl", [&]() { if (offsets.size(0) > 0) { @@ -1086,12 +1167,22 @@ void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag, max_indices->copy_(bag_size); } } else { // MODE_MAX - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - weight.scalar_type(), "embedding_bag_cpu_max_out", [&]() { - embedding_bag_cpu_max_out( - max_indices, weight, indices, offset2bag, output, include_last_offset, bag_size, padding_idx); - } - ); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + weight.scalar_type(), + "embedding_bag_cpu_max_out", + [&]() { + embedding_bag_cpu_max_out( + max_indices, + weight, + indices, + offset2bag, + output, + include_last_offset, + bag_size, + padding_idx); + }); } } @@ -1521,7 +1612,8 @@ Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indi // for more details. auto grad = grad_.contiguous(); auto grad_arg = TensorArg(grad, "grad_", 1); - checkScalarTypes("embedding_bag", grad_arg, {kHalf, kFloat, kDouble}); + checkScalarTypes( + "embedding_bag", grad_arg, {kHalf, kBFloat16, kFloat, kDouble}); if (mode == MODE_MAX) { return _embedding_bag_dense_backward_cpu_max( diff --git a/aten/src/ATen/native/EmbeddingBag.h b/aten/src/ATen/native/EmbeddingBag.h index 9d44fa688b2..8ba7abe706c 100644 --- a/aten/src/ATen/native/EmbeddingBag.h +++ b/aten/src/ATen/native/EmbeddingBag.h @@ -98,14 +98,14 @@ struct _EmbeddingBagKernelCacheImpl : private StorageMixins... 
{ // instantiate the cache with the list of storage mixins // for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl< - _CallbackAndBlockSize, - _CallbackAndBlockSize, - _CallbackAndBlockSize, - _CallbackAndBlockSize, - _CallbackAndBlockSize, - _CallbackAndBlockSize, - _CallbackAndBlockSize, - _CallbackAndBlockSize>; + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize, + _CallbackAndBlockSize>; #else struct _EmbeddingBagKernelCache { explicit _EmbeddingBagKernelCache(c10::optional /* maybe_block_size */) {} diff --git a/test/nn/test_embedding.py b/test/nn/test_embedding.py index f4e42aa4cfd..edbff94e19b 100644 --- a/test/nn/test_embedding.py +++ b/test/nn/test_embedding.py @@ -818,7 +818,10 @@ class TestEmbeddingNNDeviceType(NNTestCase): return torch.stack(bags) @skipMeta - @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.half, torch.float, torch.double))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), + (torch.half, torch.bfloat16, torch.float, torch.double))) + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long), + (torch.float, torch.double, torch.half))) def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes): # Test empty input and per sample weight, and backward pass. There was a CUDA # invalid configuration bug (more context in #46572) @@ -857,7 +860,10 @@ class TestEmbeddingNNDeviceType(NNTestCase): test_per_sample_weights(mode, trainable) @skipMeta - @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), + (torch.float, torch.double, torch.half, torch.bfloat16))) + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long), + (torch.float, torch.double, torch.half))) def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes): def test_per_sample_weights(mode, trainable_scale): es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device) @@ -891,7 +897,10 @@ class TestEmbeddingNNDeviceType(NNTestCase): test_per_sample_weights(mode, trainable) @skipMeta - @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), + (torch.float, torch.double, torch.half, torch.bfloat16))) + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long), + (torch.float, torch.double, torch.half))) def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes): def test_per_sample_weights_new_offsets(mode, trainable_scale, include_last_offset, has_weight=True): es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtypes[2], device=device) @@ -1156,7 +1165,10 @@ class TestEmbeddingNNDeviceType(NNTestCase): self.assertRaises(RuntimeError, lambda: es(input.view(-1), offset)) @skipMeta - @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), + (torch.float, torch.double, torch.half, torch.bfloat16))) + 
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long), + (torch.float, torch.double, torch.half))) def test_embedding_bag_device(self, device, dtypes): with set_default_dtype(torch.double): self._test_EmbeddingBag(device, 'sum', False, wdtype=dtypes[2], dtype=dtypes[0], odtype=dtypes[1]) @@ -1192,7 +1204,10 @@ class TestEmbeddingNNDeviceType(NNTestCase): ) @skipMeta - @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half))) + @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), + (torch.float, torch.double, torch.half, torch.bfloat16))) + @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long), + (torch.float, torch.double, torch.half))) def test_embedding_bag_non_contiguous_weight(self, device, dtypes): weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device) @@ -1216,7 +1231,7 @@ class TestEmbeddingNNDeviceType(NNTestCase): ) self.assertEqual(output_non_contig, output_contig) - @onlyCUDA + @onlyNativeDeviceTypes # currently fails on XLA @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long))) def test_embedding_bag_bfloat16(self, device, dtypes): with set_default_dtype(torch.double): diff --git a/test/test_meta.py b/test/test_meta.py index 75d09cac828..bdd425b86f7 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -967,7 +967,7 @@ meta_dispatch_device_expected_failures['cuda'] = { } meta_dispatch_device_skips['cpu'] = { - aten._embedding_bag_forward_only.default: {f16, f32, f64}, + aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64}, aten.native_batch_norm.default: {f32, f64}, aten._native_batch_norm_legit.default: {f32, f64}, aten._native_batch_norm_legit.no_stats: {f32, f64}, diff --git a/third_party/fbgemm b/third_party/fbgemm index 80d64206c07..03b20466767 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 80d64206c07879fd4683be66873de7cefa1a0a71 +Subproject commit 03b2046676707da64504e898490ab46104d4682a diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b6145992363..73cdd909c89 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -16969,7 +16969,7 @@ op_db: List[OpInfo] = [ # This is because currently only the `input` field of SampleInput # is tested in gradient tests. op=lambda weight, idx, **kwargs: torch.nn.functional.embedding_bag(idx, weight, **kwargs), - dtypes=floating_types_and(torch.float16), + dtypes=floating_types_and(torch.bfloat16, torch.float16), dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), # backward is not supported for mode `max` and dtype `bfloat16` backward_dtypesIfCUDA=floating_types_and(torch.float16),
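
---
Usage note (not part of the patch): a minimal sketch of the CPU BF16 EmbeddingBag path this change enables, loosely following the updated tests in test/nn/test_embedding.py. The shapes, the per-sample-weights choice, and the tolerances used for the FP32 comparison are illustrative assumptions, not values taken from the PR.

    # Illustrative sketch: BF16 EmbeddingBag on CPU after this change.
    # 'sum' mode with per_sample_weights exercises index_select_scale_add;
    # dropping per_sample_weights would exercise index_select_add instead.
    import torch
    import torch.nn.functional as F

    num_embeddings, dim = 10, 8
    weight = torch.randn(num_embeddings, dim, dtype=torch.bfloat16)
    indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
    offsets = torch.tensor([0, 4], dtype=torch.long)  # two bags of four indices
    per_sample_weights = torch.rand(indices.numel(), dtype=torch.bfloat16)

    out_bf16 = F.embedding_bag(indices, weight, offsets,
                               mode="sum",
                               per_sample_weights=per_sample_weights)

    # Reference in float32: the kernels above accumulate in FP32 (either via
    # the FBGEMM EmbeddingSpMDM fast path or the caffe2::EmbeddingLookupIdx
    # fallback) and only convert back to 16 bit at the end, so the BF16 result
    # should agree with an FP32 reference to roughly BF16 precision.
    out_ref = F.embedding_bag(indices, weight.float(), offsets,
                              mode="sum",
                              per_sample_weights=per_sample_weights.float())
    torch.testing.assert_close(out_bf16.float(), out_ref, atol=1e-2, rtol=1e-2)
    print(out_bf16.dtype)  # torch.bfloat16

The tolerance of 1e-2 is an assumption chosen for BF16's ~8-bit mantissa; the PR's own tests compare against a double-precision reference with their own tolerances.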