haozhe.zhu 2023-02-12 00:05:09 +00:00 committed by PyTorch MergeBot
parent 020a0fbf62
commit ed54a5d06b
6 changed files with 273 additions and 166 deletions

View file

@ -1,10 +1,11 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/EmbeddingBag.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/EmbeddingBag.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/NonSymbolicBC.h>
@ -86,14 +87,20 @@ std::pair<Tensor, Tensor> promoteIndicesAndOffsets(
// is only applicable if special conditions are met
template<typename index_t>
bool is_fast_path_index_select(const Tensor& src, Tensor& output, index_t padding_idx) {
return (src.scalar_type() == kFloat || src.scalar_type() == kHalf) && src.strides()[1] == 1 && output.strides()[1] == 1 && padding_idx < static_cast<index_t>(0);
return (src.scalar_type() == kFloat || src.scalar_type() == kHalf ||
src.scalar_type() == kBFloat16) &&
src.strides()[1] == 1 && output.strides()[1] == 1 &&
padding_idx < static_cast<index_t>(0);
}
// Determines if we can use a fast implementation for index_select_scale_add,
// which is only applicable if special conditions are met
template<typename index_t>
bool is_fast_path_index_select_scale(const Tensor& src, const Tensor& scale, Tensor& output, index_t padding_idx) {
return (src.scalar_type() == kFloat || src.scalar_type() == kHalf) && src.strides()[1] == 1 && output.strides()[1] == 1 && scale.strides()[0] == 1 && padding_idx < static_cast<index_t>(0);
return (src.scalar_type() == kFloat || src.scalar_type() == kHalf ||
src.scalar_type() == kBFloat16) &&
src.strides()[1] == 1 && output.strides()[1] == 1 &&
scale.strides()[0] == 1 && padding_idx < static_cast<index_t>(0);
}
template<typename index_t>
@ -106,17 +113,18 @@ bool is_fast_path(const Tensor& src, const c10::optional<Tensor>& scale, Tensor&
// This function combines index_select (using select_indices as the index) and
// index_add (using add_indices as the index), without creating an intermediary
// tensor to hold the selected embeddings
template<typename data_t, typename index_t>
typename std::enable_if<!std::is_same<data_t, float>::value && !std::is_same<data_t, at::Half>::value, void>::type
index_select_add(const Tensor &select_indices,
const Tensor &add_indices,
const Tensor &src,
Tensor &output,
const Tensor& /*offsets*/,
bool /*include_last_offset*/,
Tensor &bag_size,
index_t padding_idx,
_EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
template <typename data_t, typename index_t>
static typename std::enable_if<std::is_same<data_t, double>::value, void>::type
index_select_add(
const Tensor& select_indices,
const Tensor& add_indices,
const Tensor& src,
Tensor& output,
const Tensor& /*offsets*/,
bool /*include_last_offset*/,
Tensor& bag_size,
index_t padding_idx,
_EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
TORCH_CHECK(select_indices.numel() == add_indices.numel());
auto* add_indices_data = add_indices.data_ptr<index_t>();
auto* select_indices_data = select_indices.data_ptr<index_t>();
@ -184,24 +192,28 @@ void fbgemm_spmdm_report_error_(
}
} // namespace
template<typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, at::Half>::value, void>::type
index_select_add(const Tensor &select_indices,
const Tensor &add_indices,
const Tensor &src,
Tensor &output,
const Tensor& offsets,
bool include_last_offset,
Tensor &bag_size,
index_t padding_idx,
_EmbeddingBagKernelCache* fbgemm_kernel_cache) {
template <typename data_t, typename index_t>
typename std::enable_if<
std::is_same<data_t, at::Half>::value ||
std::is_same<data_t, at::BFloat16>::value,
void>::type
index_select_add(
const Tensor& select_indices,
const Tensor& add_indices,
const Tensor& src,
Tensor& output,
const Tensor& offsets,
bool include_last_offset,
Tensor& bag_size,
index_t padding_idx,
_EmbeddingBagKernelCache* fbgemm_kernel_cache) {
int64_t ddim = src.size(1);
auto* select_indices_data = select_indices.data_ptr<index_t>();
auto* output_data = output.data_ptr<at::Half>();
auto* output_data = output.data_ptr<data_t>();
if (is_fast_path_index_select(src, output, padding_idx)) {
auto src_contig = src.contiguous();
auto* src_data = src_contig.data_ptr<at::Half>();
auto* src_data = src_contig.data_ptr<data_t>();
int64_t output_size = offsets.numel() - 1;
auto* offsets_data = offsets.data_ptr<index_t>();
std::vector<index_t> offsets_include_last;
@ -220,36 +232,31 @@ index_select_add(const Tensor &select_indices,
offsets_include_last[offsets.numel()] = select_indices.numel();
offsets_data = offsets_include_last.data();
}
#ifdef USE_FBGEMM
using float16 = uint16_t;
auto kernel_fp16_index_t = fbgemm_kernel_cache ?
fbgemm_kernel_cache->getCallback</* has_weight */ false, index_t, float16>(ddim) :
fbgemm::GenerateEmbeddingSpMDM<float16, index_t, index_t, float16>(
/* block_size */ddim,
/* has_weight */false,
/* normalize_by_lengths */false,
/* prefetch */16,
/* is_weight_positional */false,
/* use_offsets */true
);
#else
// Initialize the intermediate output buffer to be 0.
Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
auto* output_data_fp32 = output_fp32.data_ptr<float>();
#endif
#if defined(USE_FBGEMM)
bool isbf16 = std::is_same<data_t, at::Half>::value ? false : true;
auto kernel_16bit_index_t = fbgemm_kernel_cache
? fbgemm_kernel_cache
->getCallback</* has_weight */ false, index_t, uint16_t>(ddim)
: fbgemm::GenerateEmbeddingSpMDM<uint16_t, index_t, index_t, uint16_t>(
/* block_size */ ddim,
/* has_weight */ false,
/* normalize_by_lengths */ false,
/* prefetch */ 16,
/* is_weight_positional */ false,
/* use_offsets */ true,
/* isbf16*/ isbf16);
at::parallel_for(
0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
#ifdef USE_FBGEMM
bool success = kernel_fp16_index_t(
/* output_size */end_idx - start_idx,
/* index_size */offsets_data[end_idx] - offsets_data[start_idx],
/* data_size */src.size(0),
/* input */reinterpret_cast<const float16*>(src_data),
/* indices */select_indices_data + offsets_data[start_idx],
/* offsets_or_lengths */offsets_data + start_idx,
/* weights */nullptr,
/* output */reinterpret_cast<float16*>(output_data + start_idx * ddim));
bool success = kernel_16bit_index_t(
/* output_size */ end_idx - start_idx,
/* index_size */ offsets_data[end_idx] - offsets_data[start_idx],
/* data_size */ src.size(0),
/* input */ reinterpret_cast<const uint16_t*>(src_data),
/* indices */ select_indices_data + offsets_data[start_idx],
/* offsets_or_lengths */ offsets_data + start_idx,
/* weights */ nullptr,
/* output */
reinterpret_cast<uint16_t*>(output_data + start_idx * ddim));
if (!success) {
fbgemm_spmdm_report_error_(
end_idx - start_idx,
@ -258,7 +265,15 @@ index_select_add(const Tensor &select_indices,
offsets_data + start_idx,
select_indices_data + offsets_data[start_idx]);
}
});
#else
// Initialize the intermediate output buffer to be 0.
Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
auto* output_data_fp32 = output_fp32.data_ptr<float>();
using bVec = vec::Vectorized<BFloat16>;
using fVec = vec::Vectorized<float>;
at::parallel_for(
0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
caffe2::EmbeddingLookupIdx(
/*block_size=*/ddim,
/*output_size=*/end_idx - start_idx,
@ -271,18 +286,36 @@ index_select_add(const Tensor &select_indices,
/*scale_bias=*/nullptr,
/*normalize_by_lengths=*/false,
/*out=*/output_data_fp32 + start_idx * ddim);
for (const auto i : c10::irange(output_size)) {
// Convert FP32 intermediate buffer result back to FP16 for output dtype
for (const auto d : c10::irange(ddim)) {
(output_data + i * ddim)[d] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]);
for (int64_t i = start_idx; i < end_idx; i++) {
// Convert FP32 intermediate buffer result back to 16 bit for
// output dtype
if (std::is_same<data_t, at::Half>::value) {
// FP16
for (const auto d : c10::irange(ddim)) {
(output_data + i * ddim)[d] =
static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
}
} else {
// BF16
int64_t d = 0;
for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) {
fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d);
fVec temp_fp32_1 =
fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size());
convert_float_bfloat16(temp_fp32_0, temp_fp32_1)
.store(output_data + i * ddim + d);
}
for (; d < ddim; d++) {
(output_data + i * ddim)[d] =
static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
}
}
}
#endif
});
#endif
} else {
TORCH_CHECK(select_indices.numel() == add_indices.numel());
auto* src_data = src.data_ptr<at::Half>();
auto* src_data = src.data_ptr<data_t>();
auto* add_indices_data = add_indices.data_ptr<index_t>();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
index_t* bag_size_data = nullptr;
@ -300,7 +333,8 @@ index_select_add(const Tensor &select_indices,
auto* src_data_fp32 = src_fp32.data_ptr<float>();
// Initialize the intermediate output buffer to be 0.
Tensor output_fp32 = at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
Tensor output_fp32 =
at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
auto* output_data_fp32 = output_fp32.data_ptr<float>();
for (const auto i : c10::irange(numel)) {
@ -314,11 +348,16 @@ index_select_add(const Tensor &select_indices,
if (idx != padding_idx) {
// Copy src_data + src_stride0 * idx to src_data_fp32
for (const auto d : c10::irange(ddim)) {
src_data_fp32[d] = static_cast<float>((src_data + src_stride0 * idx)[d * src_stride1]);
src_data_fp32[d] = static_cast<float>(
(src_data + src_stride0 * idx)[d * src_stride1]);
}
at::native::cpublas::axpy<float>(ddim, 1,
src_data_fp32, 1,
output_data_fp32 + ddim * add_indices_data[i], 1);
at::native::cpublas::axpy<float>(
ddim,
1,
src_data_fp32,
1,
output_data_fp32 + ddim * add_indices_data[i],
1);
} else if (bag_size.defined()) {
// Decrement bag_size to reflect that the index is padded
@ -327,14 +366,15 @@ index_select_add(const Tensor &select_indices,
}
}
for (const auto i : c10::irange(output.size(0))) {
// Convert FP32 intermediate buffer result back to FP16 for output dtype
// Convert FP32 intermediate buffer result back to 16 bit for output
// dtype
for (const auto d : c10::irange(ddim)) {
(output_data + output_stride0 * i)[d * output_stride1] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]);
(output_data + output_stride0 * i)[d * output_stride1] =
static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
}
}
}
}
template<typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, float>::value, void>::type
index_select_add(const Tensor &select_indices,
@ -464,18 +504,19 @@ index_select_add(const Tensor &select_indices,
// index_select (using select_indices as the index)
// mul (scaling by per_sample_weights)
// index_add (using add_indices as the index)
template<typename data_t, typename index_t>
static typename std::enable_if<!std::is_same<data_t, float>::value && !std::is_same<data_t, at::Half>::value, void>::type
index_select_scale_add(const Tensor &select_indices,
const Tensor &add_indices,
const Tensor &scale,
const Tensor &src,
Tensor &output,
const Tensor& /*offsets*/,
bool /*include_last_offset*/,
Tensor &bag_size,
index_t padding_idx,
_EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
template <typename data_t, typename index_t>
static typename std::enable_if<std::is_same<data_t, double>::value, void>::type
index_select_scale_add(
const Tensor& select_indices,
const Tensor& add_indices,
const Tensor& scale,
const Tensor& src,
Tensor& output,
const Tensor& /*offsets*/,
bool /*include_last_offset*/,
Tensor& bag_size,
index_t padding_idx,
_EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
AT_ASSERT(select_indices.numel() == add_indices.numel());
auto* add_indices_data = add_indices.data_ptr<index_t>();
auto* select_indices_data = select_indices.data_ptr<index_t>();
@ -520,26 +561,30 @@ index_select_scale_add(const Tensor &select_indices,
}
}
template<typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, at::Half>::value, void>::type
index_select_scale_add(const Tensor &select_indices,
const Tensor &add_indices,
const Tensor &scale,
const Tensor &src,
Tensor &output,
const Tensor& offsets,
bool include_last_offset,
Tensor &bag_size,
index_t padding_idx,
_EmbeddingBagKernelCache* fbgemm_kernel_cache) {
template <typename data_t, typename index_t>
typename std::enable_if<
std::is_same<data_t, at::Half>::value ||
std::is_same<data_t, at::BFloat16>::value,
void>::type
index_select_scale_add(
const Tensor& select_indices,
const Tensor& add_indices,
const Tensor& scale,
const Tensor& src,
Tensor& output,
const Tensor& offsets,
bool include_last_offset,
Tensor& bag_size,
index_t padding_idx,
_EmbeddingBagKernelCache* fbgemm_kernel_cache) {
int64_t ddim = src.size(1);
auto* scale_data = scale.data_ptr<at::Half>();
auto* scale_data = scale.data_ptr<data_t>();
auto* select_indices_data = select_indices.data_ptr<index_t>();
auto* output_data = output.data_ptr<at::Half>();
auto* output_data = output.data_ptr<data_t>();
if (is_fast_path_index_select_scale(src, scale, output, padding_idx)) {
auto src_contig = src.contiguous();
auto* src_data = src_contig.data_ptr<at::Half>();
auto* src_data = src_contig.data_ptr<data_t>();
int64_t output_size = offsets.numel() - 1;
auto* offsets_data = offsets.data_ptr<index_t>();
std::vector<index_t> offsets_include_last;
@ -560,40 +605,42 @@ index_select_scale_add(const Tensor &select_indices,
Tensor scale_fp32 = at::empty(scale.sizes(), scale.options().dtype(at::kFloat));
auto* scale_data_fp32 = scale_fp32.data_ptr<float>();
#ifdef USE_FBGEMM
using float16 = uint16_t;
fbgemm::Float16ToFloat_simd(reinterpret_cast<const float16*>(scale_data), scale_data_fp32, scale_fp32.numel());
auto kernel_fp16_index_t =
fbgemm_kernel_cache ?
fbgemm_kernel_cache->getCallback</* has_weight */ true, index_t, float16>(ddim) :
fbgemm::GenerateEmbeddingSpMDM<float16, index_t, index_t, float16>(
/* block_size */ddim,
/* has_weight */true,
/* normalize_by_lengths */false,
/* prefetch */16,
/* is_weight_positional */false,
/* use_offsets */true
);
#else
// Initialize the intermediate output buffer to be 0.
Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
auto* output_data_fp32 = output_fp32.data_ptr<float>();
for (const auto i : c10::irange(scale.numel())) {
scale_data_fp32[i] = static_cast<float>(scale_data[i]);
#if defined(USE_FBGEMM)
bool isbf16 = std::is_same<data_t, at::Half>::value ? false : true;
if (isbf16) {
fbgemm::Bfloat16ToFloat_simd(
reinterpret_cast<const fbgemm::bfloat16*>(scale_data),
scale_data_fp32,
scale_fp32.numel());
} else {
fbgemm::Float16ToFloat_simd(
reinterpret_cast<const fbgemm::float16*>(scale_data),
scale_data_fp32,
scale_fp32.numel());
}
#endif
auto kernel_16bit_index_t = fbgemm_kernel_cache
? fbgemm_kernel_cache
->getCallback</* has_weight */ true, index_t, uint16_t>(ddim)
: fbgemm::GenerateEmbeddingSpMDM<uint16_t, index_t, index_t, uint16_t>(
/* block_size */ ddim,
/* has_weight */ true,
/* normalize_by_lengths */ false,
/* prefetch */ 16,
/* is_weight_positional */ false,
/* use_offsets */ true,
/* isbf16*/ isbf16);
at::parallel_for(
0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
#ifdef USE_FBGEMM
bool success = kernel_fp16_index_t(
/* output_size */end_idx - start_idx,
/* index_size */offsets_data[end_idx] - offsets_data[start_idx],
/* data_size */src.size(0),
/* input */reinterpret_cast<const float16*>(src_data),
/* indices */select_indices_data + offsets_data[start_idx],
/* offsets_or_lengths */offsets_data + start_idx,
/* weights */scale_data_fp32 + offsets_data[start_idx],
/* output */reinterpret_cast<float16*>(output_data + start_idx * ddim));
bool success = kernel_16bit_index_t(
/* output_size */ end_idx - start_idx,
/* index_size */ offsets_data[end_idx] - offsets_data[start_idx],
/* data_size */ src.size(0),
/* input */ reinterpret_cast<const uint16_t*>(src_data),
/* indices */ select_indices_data + offsets_data[start_idx],
/* offsets_or_lengths */ offsets_data + start_idx,
/* weights */ scale_data_fp32 + offsets_data[start_idx],
/* output */
reinterpret_cast<uint16_t*>(output_data + start_idx * ddim));
if (!success) {
fbgemm_spmdm_report_error_(
end_idx - start_idx,
@ -602,7 +649,19 @@ index_select_scale_add(const Tensor &select_indices,
offsets_data + start_idx,
select_indices_data + offsets_data[start_idx]);
}
});
#else
// Initialize the intermediate output buffer to be 0.
Tensor output_fp32 =
at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
auto* output_data_fp32 = output_fp32.data_ptr<float>();
for (const auto i : c10::irange(scale.numel())) {
scale_data_fp32[i] = static_cast<float>(scale_data[i]);
}
using bVec = vec::Vectorized<BFloat16>;
using fVec = vec::Vectorized<float>;
at::parallel_for(
0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
caffe2::EmbeddingLookupIdx(
/*block_size=*/ddim,
/*output_size=*/end_idx - start_idx,
@ -615,17 +674,36 @@ index_select_scale_add(const Tensor &select_indices,
/*scale_bias=*/nullptr,
/*normalize_by_lengths=*/false,
/*out=*/output_data_fp32 + start_idx * ddim);
for (const auto i : c10::irange(output_size)) {
// Convert FP32 intermediate buffer result back to FP16 for output dtype
for (const auto d : c10::irange(ddim)) {
(output_data + i * ddim)[d] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]);
for (int64_t i = start_idx; i < end_idx; i++) {
// Convert FP32 intermediate buffer result back to 16 bit for
// output dtype
if (std::is_same<data_t, at::Half>::value) {
// FP16
for (const auto d : c10::irange(ddim)) {
(output_data + i * ddim)[d] =
static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
}
} else {
// BF16
int64_t d = 0;
for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) {
fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d);
fVec temp_fp32_1 =
fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size());
convert_float_bfloat16(temp_fp32_0, temp_fp32_1)
.store(output_data + i * ddim + d);
}
for (; d < ddim; d++) {
(output_data + i * ddim)[d] =
static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
}
}
}
#endif
});
#endif
} else {
AT_ASSERT(select_indices.numel() == add_indices.numel());
auto* src_data = src.data_ptr<at::Half>();
auto* src_data = src.data_ptr<data_t>();
auto* add_indices_data = add_indices.data_ptr<index_t>();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
index_t* bag_size_data = nullptr;
@ -641,7 +719,8 @@ index_select_scale_add(const Tensor &select_indices,
auto numel = add_indices.numel();
// Initialize the intermediate output buffer to be 0.
Tensor output_fp32 = at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
Tensor output_fp32 =
at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
auto* output_data_fp32 = output_fp32.data_ptr<float>();
for (const auto i : c10::irange(numel)) {
@ -653,12 +732,12 @@ index_select_scale_add(const Tensor &select_indices,
"embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
idx);
if (idx != padding_idx) {
auto* src_base = src_data + src_stride0 * idx;
auto* output_base_fp32 = output_data_fp32 + ddim * add_indices_data[i];
auto scale = scale_data[i * scale_stride];
for (const auto j : c10::irange(ddim)) {
output_base_fp32[j] += static_cast<float>(src_base[j * src_stride1]) * static_cast<float>(scale);
output_base_fp32[j] += static_cast<float>(src_base[j * src_stride1]) *
static_cast<float>(scale);
}
} else if (bag_size.defined()) {
// Decrement bag_size to reflect that the index is padded
@ -667,14 +746,15 @@ index_select_scale_add(const Tensor &select_indices,
}
}
for (const auto i : c10::irange(output.size(0))) {
// Convert FP32 intermediate buffer result back to FP16 for output dtype
// Convert FP32 intermediate buffer result back to 16 bit for output
// dtype
for (const auto d : c10::irange(ddim)) {
(output_data + output_stride0 * i)[d * output_stride1] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]);
(output_data + output_stride0 * i)[d * output_stride1] =
static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
}
}
}
}
template<typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, float>::value, void>::type
index_select_scale_add(const Tensor &select_indices,
@ -817,7 +897,8 @@ void check_arguments(
checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt});
checkSameType("embedding_bag", indices_arg, offsets_arg);
auto weight_arg = TensorArg(weight, "weight", 1);
checkScalarTypes("embedding_bag", weight_arg, {kHalf, kFloat, kDouble});
checkScalarTypes(
"embedding_bag", weight_arg, {kHalf, kBFloat16, kFloat, kDouble});
AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_embedding_bag_cpu_impl", [&]() {
if (offsets.size(0) > 0) {
@ -1086,12 +1167,22 @@ void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
max_indices->copy_(bag_size);
}
} else { // MODE_MAX
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
weight.scalar_type(), "embedding_bag_cpu_max_out", [&]() {
embedding_bag_cpu_max_out<scalar_t>(
max_indices, weight, indices, offset2bag, output, include_last_offset, bag_size, padding_idx);
}
);
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
weight.scalar_type(),
"embedding_bag_cpu_max_out",
[&]() {
embedding_bag_cpu_max_out<scalar_t>(
max_indices,
weight,
indices,
offset2bag,
output,
include_last_offset,
bag_size,
padding_idx);
});
}
}
@ -1521,7 +1612,8 @@ Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indi
// for more details.
auto grad = grad_.contiguous();
auto grad_arg = TensorArg(grad, "grad_", 1);
checkScalarTypes("embedding_bag", grad_arg, {kHalf, kFloat, kDouble});
checkScalarTypes(
"embedding_bag", grad_arg, {kHalf, kBFloat16, kFloat, kDouble});
if (mode == MODE_MAX) {
return _embedding_bag_dense_backward_cpu_max(

View file

@ -98,14 +98,14 @@ struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
// instantiate the cache with the list of storage mixins
// for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
_CallbackAndBlockSize<true, int32_t, float>,
_CallbackAndBlockSize<false, int32_t, float>,
_CallbackAndBlockSize<true, int64_t, float>,
_CallbackAndBlockSize<false, int64_t, float>,
_CallbackAndBlockSize<true, int32_t, unsigned short>,
_CallbackAndBlockSize<false, int32_t, unsigned short>,
_CallbackAndBlockSize<true, int64_t, unsigned short>,
_CallbackAndBlockSize<false, int64_t, unsigned short>>;
_CallbackAndBlockSize<true, int32_t, float>,
_CallbackAndBlockSize<false, int32_t, float>,
_CallbackAndBlockSize<true, int64_t, float>,
_CallbackAndBlockSize<false, int64_t, float>,
_CallbackAndBlockSize<true, int32_t, unsigned short>,
_CallbackAndBlockSize<false, int32_t, unsigned short>,
_CallbackAndBlockSize<true, int64_t, unsigned short>,
_CallbackAndBlockSize<false, int64_t, unsigned short>>;
#else
struct _EmbeddingBagKernelCache {
explicit _EmbeddingBagKernelCache(c10::optional<int64_t> /* maybe_block_size */) {}

View file

@ -818,7 +818,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
return torch.stack(bags)
@skipMeta
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.half, torch.float, torch.double)))
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
(torch.half, torch.bfloat16, torch.float, torch.double)))
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
(torch.float, torch.double, torch.half)))
def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes):
# Test empty input and per sample weight, and backward pass. There was a CUDA
# invalid configuration bug (more context in #46572)
@ -857,7 +860,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
test_per_sample_weights(mode, trainable)
@skipMeta
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
(torch.float, torch.double, torch.half, torch.bfloat16)))
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
(torch.float, torch.double, torch.half)))
def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes):
def test_per_sample_weights(mode, trainable_scale):
es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device)
@ -891,7 +897,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
test_per_sample_weights(mode, trainable)
@skipMeta
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
(torch.float, torch.double, torch.half, torch.bfloat16)))
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
(torch.float, torch.double, torch.half)))
def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes):
def test_per_sample_weights_new_offsets(mode, trainable_scale, include_last_offset, has_weight=True):
es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtypes[2], device=device)
@ -1156,7 +1165,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
self.assertRaises(RuntimeError, lambda: es(input.view(-1), offset))
@skipMeta
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
(torch.float, torch.double, torch.half, torch.bfloat16)))
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
(torch.float, torch.double, torch.half)))
def test_embedding_bag_device(self, device, dtypes):
with set_default_dtype(torch.double):
self._test_EmbeddingBag(device, 'sum', False, wdtype=dtypes[2], dtype=dtypes[0], odtype=dtypes[1])
@ -1192,7 +1204,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
)
@skipMeta
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
(torch.float, torch.double, torch.half, torch.bfloat16)))
@dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
(torch.float, torch.double, torch.half)))
def test_embedding_bag_non_contiguous_weight(self, device, dtypes):
weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device)
@ -1216,7 +1231,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
)
self.assertEqual(output_non_contig, output_contig)
@onlyCUDA
@onlyNativeDeviceTypes # currently fails on XLA
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
def test_embedding_bag_bfloat16(self, device, dtypes):
with set_default_dtype(torch.double):

View file

@ -967,7 +967,7 @@ meta_dispatch_device_expected_failures['cuda'] = {
}
meta_dispatch_device_skips['cpu'] = {
aten._embedding_bag_forward_only.default: {f16, f32, f64},
aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},
aten.native_batch_norm.default: {f32, f64},
aten._native_batch_norm_legit.default: {f32, f64},
aten._native_batch_norm_legit.no_stats: {f32, f64},

2
third_party/fbgemm vendored

@ -1 +1 @@
Subproject commit 80d64206c07879fd4683be66873de7cefa1a0a71
Subproject commit 03b2046676707da64504e898490ab46104d4682a

View file

@ -16969,7 +16969,7 @@ op_db: List[OpInfo] = [
# This is because currently only the `input` field of SampleInput
# is tested in gradient tests.
op=lambda weight, idx, **kwargs: torch.nn.functional.embedding_bag(idx, weight, **kwargs),
dtypes=floating_types_and(torch.float16),
dtypes=floating_types_and(torch.bfloat16, torch.float16),
dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
# backward is not supported for mode `max` and dtype `bfloat16`
backward_dtypesIfCUDA=floating_types_and(torch.float16),