enable bf16 emb (#94163)
Merge https://github.com/pytorch/pytorch/pull/89199 and https://github.com/pytorch/pytorch/pull/91949 into one PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94163
Approved by: https://github.com/jianyuh, https://github.com/malfet, https://github.com/jgong5
parent 020a0fbf62
commit ed54a5d06b

6 changed files with 273 additions and 166 deletions
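The net effect of the change is that CPU embedding_bag now accepts BFloat16 weights alongside Half, Float and Double. A minimal usage sketch against the public ATen API (shapes, values and the helper main() are illustrative, not taken from the PR):

#include <ATen/ATen.h>
#include <tuple>

int main() {
  // A 10 x 8 embedding table stored in bfloat16 (created in fp32, then cast).
  at::Tensor weight = at::randn({10, 8}).to(at::kBFloat16);
  // One flat index list split into two bags by offsets: [0, 2) and [2, 4).
  at::Tensor indices = at::arange(4, at::kLong);
  at::Tensor offsets = at::arange(0, 4, 2, at::kLong);
  // Default mode (sum); with this PR the CPU kernel dispatches a bf16 path.
  auto result = std::get<0>(at::embedding_bag(weight, indices, offsets));
  return result.scalar_type() == at::kBFloat16 ? 0 : 1;
}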
@@ -1,10 +1,11 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/EmbeddingBag.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/EmbeddingBag.h>

#include <ATen/native/CPUBlas.h>
#include <ATen/native/NonSymbolicBC.h>

@@ -86,14 +87,20 @@ std::pair<Tensor, Tensor> promoteIndicesAndOffsets(
// is only applicable if special conditions are met
template<typename index_t>
bool is_fast_path_index_select(const Tensor& src, Tensor& output, index_t padding_idx) {
  return (src.scalar_type() == kFloat || src.scalar_type() == kHalf) && src.strides()[1] == 1 && output.strides()[1] == 1 && padding_idx < static_cast<index_t>(0);
  return (src.scalar_type() == kFloat || src.scalar_type() == kHalf ||
          src.scalar_type() == kBFloat16) &&
      src.strides()[1] == 1 && output.strides()[1] == 1 &&
      padding_idx < static_cast<index_t>(0);
}

// Determines if we can use a fast implementation for index_select_scale_add,
// which is only applicable if special conditions are met
template<typename index_t>
bool is_fast_path_index_select_scale(const Tensor& src, const Tensor& scale, Tensor& output, index_t padding_idx) {
  return (src.scalar_type() == kFloat || src.scalar_type() == kHalf) && src.strides()[1] == 1 && output.strides()[1] == 1 && scale.strides()[0] == 1 && padding_idx < static_cast<index_t>(0);
  return (src.scalar_type() == kFloat || src.scalar_type() == kHalf ||
          src.scalar_type() == kBFloat16) &&
      src.strides()[1] == 1 && output.strides()[1] == 1 &&
      scale.strides()[0] == 1 && padding_idx < static_cast<index_t>(0);
}

template<typename index_t>

@@ -106,17 +113,18 @@ bool is_fast_path(const Tensor& src, const c10::optional<Tensor>& scale, Tensor&
// This function combines index_select (using select_indices as the index) and
// index_add (using add_indices as the index), without creating an intermediary
// tensor to hold the selected embeddings
template<typename data_t, typename index_t>
typename std::enable_if<!std::is_same<data_t, float>::value && !std::is_same<data_t, at::Half>::value, void>::type
index_select_add(const Tensor &select_indices,
                 const Tensor &add_indices,
                 const Tensor &src,
                 Tensor &output,
                 const Tensor& /*offsets*/,
                 bool /*include_last_offset*/,
                 Tensor &bag_size,
                 index_t padding_idx,
                 _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
template <typename data_t, typename index_t>
static typename std::enable_if<std::is_same<data_t, double>::value, void>::type
index_select_add(
    const Tensor& select_indices,
    const Tensor& add_indices,
    const Tensor& src,
    Tensor& output,
    const Tensor& /*offsets*/,
    bool /*include_last_offset*/,
    Tensor& bag_size,
    index_t padding_idx,
    _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
  TORCH_CHECK(select_indices.numel() == add_indices.numel());
  auto* add_indices_data = add_indices.data_ptr<index_t>();
  auto* select_indices_data = select_indices.data_ptr<index_t>();
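For reference, the conditions that is_fast_path_index_select (and its _scale variant) check before handing the reduction to FBGEMM can be restated in one place. This is a condensed sketch, not the exact source; bf16 simply joins fp32 and fp16 as an accepted weight dtype.

// Sketch only (assumes <ATen/ATen.h>): the FBGEMM fast path is used when the
// weight dtype is float/half/bfloat16, the innermost dimension of src and
// output is contiguous, and no padding_idx is in effect.
bool fast_path_ok(const at::Tensor& src, const at::Tensor& output, int64_t padding_idx) {
  const bool dtype_ok = src.scalar_type() == at::kFloat ||
      src.scalar_type() == at::kHalf || src.scalar_type() == at::kBFloat16;
  return dtype_ok && src.strides()[1] == 1 && output.strides()[1] == 1 &&
      padding_idx < 0;
}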
@@ -184,24 +192,28 @@ void fbgemm_spmdm_report_error_(
  }
} // namespace

template<typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, at::Half>::value, void>::type
index_select_add(const Tensor &select_indices,
                 const Tensor &add_indices,
                 const Tensor &src,
                 Tensor &output,
                 const Tensor& offsets,
                 bool include_last_offset,
                 Tensor &bag_size,
                 index_t padding_idx,
                 _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
template <typename data_t, typename index_t>
typename std::enable_if<
    std::is_same<data_t, at::Half>::value ||
        std::is_same<data_t, at::BFloat16>::value,
    void>::type
index_select_add(
    const Tensor& select_indices,
    const Tensor& add_indices,
    const Tensor& src,
    Tensor& output,
    const Tensor& offsets,
    bool include_last_offset,
    Tensor& bag_size,
    index_t padding_idx,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
  int64_t ddim = src.size(1);
  auto* select_indices_data = select_indices.data_ptr<index_t>();
  auto* output_data = output.data_ptr<at::Half>();
  auto* output_data = output.data_ptr<data_t>();

  if (is_fast_path_index_select(src, output, padding_idx)) {
    auto src_contig = src.contiguous();
    auto* src_data = src_contig.data_ptr<at::Half>();
    auto* src_data = src_contig.data_ptr<data_t>();
    int64_t output_size = offsets.numel() - 1;
    auto* offsets_data = offsets.data_ptr<index_t>();
    std::vector<index_t> offsets_include_last;
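The overload that used to be enable_if'd on at::Half alone now covers at::Half and at::BFloat16 together, so both 16-bit dtypes share one implementation. A standalone sketch of that selection pattern, using stand-in tag types rather than the real ATen types:

#include <type_traits>

struct HalfTag {};      // stand-in for at::Half
struct BFloat16Tag {};  // stand-in for at::BFloat16

// Overload kept for double: plain scalar accumulation.
template <typename T>
typename std::enable_if<std::is_same<T, double>::value, void>::type
bag_sum() { /* scalar path */ }

// One overload now serves both 16-bit types: accumulate in fp32 and convert
// the buffer back to the 16-bit output dtype at the end.
template <typename T>
typename std::enable_if<
    std::is_same<T, HalfTag>::value || std::is_same<T, BFloat16Tag>::value,
    void>::type
bag_sum() { /* fp32 accumulation path */ }

// bag_sum<double>() and bag_sum<BFloat16Tag>() resolve to different bodies at
// compile time, mirroring how the index_select_add overload is picked above.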
@@ -220,36 +232,31 @@ index_select_add(const Tensor &select_indices,
      offsets_include_last[offsets.numel()] = select_indices.numel();
      offsets_data = offsets_include_last.data();
    }

#ifdef USE_FBGEMM
    using float16 = uint16_t;
    auto kernel_fp16_index_t = fbgemm_kernel_cache ?
      fbgemm_kernel_cache->getCallback</* has_weight */ false, index_t, float16>(ddim) :
      fbgemm::GenerateEmbeddingSpMDM<float16, index_t, index_t, float16>(
        /* block_size */ddim,
        /* has_weight */false,
        /* normalize_by_lengths */false,
        /* prefetch */16,
        /* is_weight_positional */false,
        /* use_offsets */true
      );
#else
    // Initialize the intermediate output buffer to be 0.
    Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
    auto* output_data_fp32 = output_fp32.data_ptr<float>();
#endif
#if defined(USE_FBGEMM)
    bool isbf16 = std::is_same<data_t, at::Half>::value ? false : true;
    auto kernel_16bit_index_t = fbgemm_kernel_cache
        ? fbgemm_kernel_cache
              ->getCallback</* has_weight */ false, index_t, uint16_t>(ddim)
        : fbgemm::GenerateEmbeddingSpMDM<uint16_t, index_t, index_t, uint16_t>(
              /* block_size */ ddim,
              /* has_weight */ false,
              /* normalize_by_lengths */ false,
              /* prefetch */ 16,
              /* is_weight_positional */ false,
              /* use_offsets */ true,
              /* isbf16*/ isbf16);
    at::parallel_for(
        0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
#ifdef USE_FBGEMM
          bool success = kernel_fp16_index_t(
            /* output_size */end_idx - start_idx,
            /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
            /* data_size */src.size(0),
            /* input */reinterpret_cast<const float16*>(src_data),
            /* indices */select_indices_data + offsets_data[start_idx],
            /* offsets_or_lengths */offsets_data + start_idx,
            /* weights */nullptr,
            /* output */reinterpret_cast<float16*>(output_data + start_idx * ddim));
          bool success = kernel_16bit_index_t(
              /* output_size */ end_idx - start_idx,
              /* index_size */ offsets_data[end_idx] - offsets_data[start_idx],
              /* data_size */ src.size(0),
              /* input */ reinterpret_cast<const uint16_t*>(src_data),
              /* indices */ select_indices_data + offsets_data[start_idx],
              /* offsets_or_lengths */ offsets_data + start_idx,
              /* weights */ nullptr,
              /* output */
              reinterpret_cast<uint16_t*>(output_data + start_idx * ddim));
          if (!success) {
            fbgemm_spmdm_report_error_(
                end_idx - start_idx,

@@ -258,7 +265,15 @@ index_select_add(const Tensor &select_indices,
                offsets_data + start_idx,
                select_indices_data + offsets_data[start_idx]);
          }
        });
#else
    // Initialize the intermediate output buffer to be 0.
    Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
    auto* output_data_fp32 = output_fp32.data_ptr<float>();
    using bVec = vec::Vectorized<BFloat16>;
    using fVec = vec::Vectorized<float>;
    at::parallel_for(
        0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
          caffe2::EmbeddingLookupIdx(
              /*block_size=*/ddim,
              /*output_size=*/end_idx - start_idx,
@@ -271,18 +286,36 @@ index_select_add(const Tensor &select_indices,
              /*scale_bias=*/nullptr,
              /*normalize_by_lengths=*/false,
              /*out=*/output_data_fp32 + start_idx * ddim);
          for (const auto i : c10::irange(output_size)) {
            // Convert FP32 intermediate buffer result back to FP16 for output dtype
            for (const auto d : c10::irange(ddim)) {
              (output_data + i * ddim)[d] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]);
          for (int64_t i = start_idx; i < end_idx; i++) {
            // Convert FP32 intermediate buffer result back to 16 bit for
            // output dtype
            if (std::is_same<data_t, at::Half>::value) {
              // FP16
              for (const auto d : c10::irange(ddim)) {
                (output_data + i * ddim)[d] =
                    static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
              }
            } else {
              // BF16
              int64_t d = 0;
              for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) {
                fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d);
                fVec temp_fp32_1 =
                    fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size());
                convert_float_bfloat16(temp_fp32_0, temp_fp32_1)
                    .store(output_data + i * ddim + d);
              }
              for (; d < ddim; d++) {
                (output_data + i * ddim)[d] =
                    static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
              }
            }
          }
#endif
        });

#endif
  } else {
    TORCH_CHECK(select_indices.numel() == add_indices.numel());
    auto* src_data = src.data_ptr<at::Half>();
    auto* src_data = src.data_ptr<data_t>();
    auto* add_indices_data = add_indices.data_ptr<index_t>();
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    index_t* bag_size_data = nullptr;
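When FBGEMM is unavailable, the bf16 branch above converts the fp32 intermediate buffer back to bfloat16 with the at::vec primitives. A trimmed, self-contained sketch of that conversion loop (the function name and signature are illustrative; the Vectorized and convert_float_bfloat16 calls are the same ones the kernel uses):

#include <ATen/cpu/vec/vec.h>
#include <c10/util/BFloat16.h>
#include <cstdint>

// Convert n floats to bfloat16: two float vectors feed one bf16 store, with a
// scalar tail for the remainder.
void fp32_to_bf16(const float* src, c10::BFloat16* dst, int64_t n) {
  using bVec = at::vec::Vectorized<c10::BFloat16>;
  using fVec = at::vec::Vectorized<float>;
  int64_t d = 0;
  for (; d < n - (n % bVec::size()); d += bVec::size()) {
    fVec lo = fVec::loadu(src + d);
    fVec hi = fVec::loadu(src + d + fVec::size());
    at::vec::convert_float_bfloat16(lo, hi).store(dst + d);
  }
  for (; d < n; d++) {
    dst[d] = static_cast<c10::BFloat16>(src[d]);
  }
}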
@@ -300,7 +333,8 @@ index_select_add(const Tensor &select_indices,
    auto* src_data_fp32 = src_fp32.data_ptr<float>();

    // Initialize the intermediate output buffer to be 0.
    Tensor output_fp32 = at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
    Tensor output_fp32 =
        at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
    auto* output_data_fp32 = output_fp32.data_ptr<float>();

    for (const auto i : c10::irange(numel)) {

@@ -314,11 +348,16 @@ index_select_add(const Tensor &select_indices,
      if (idx != padding_idx) {
        // Copy src_data + src_stride0 * idx to src_data_fp32
        for (const auto d : c10::irange(ddim)) {
          src_data_fp32[d] = static_cast<float>((src_data + src_stride0 * idx)[d * src_stride1]);
          src_data_fp32[d] = static_cast<float>(
              (src_data + src_stride0 * idx)[d * src_stride1]);
        }
        at::native::cpublas::axpy<float>(ddim, 1,
                src_data_fp32, 1,
                output_data_fp32 + ddim * add_indices_data[i], 1);
        at::native::cpublas::axpy<float>(
            ddim,
            1,
            src_data_fp32,
            1,
            output_data_fp32 + ddim * add_indices_data[i],
            1);

      } else if (bag_size.defined()) {
        // Decrement bag_size to reflect that the index is padded

@@ -327,14 +366,15 @@ index_select_add(const Tensor &select_indices,
      }
    }
    for (const auto i : c10::irange(output.size(0))) {
      // Convert FP32 intermediate buffer result back to FP16 for output dtype
      // Convert FP32 intermediate buffer result back to 16 bit for output
      // dtype
      for (const auto d : c10::irange(ddim)) {
        (output_data + output_stride0 * i)[d * output_stride1] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]);
        (output_data + output_stride0 * i)[d * output_stride1] =
            static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
      }
    }
  }
}

template<typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, float>::value, void>::type
index_select_add(const Tensor &select_indices,
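Both the vectorized path and the slow path above accumulate into a float buffer (output_fp32) and convert to the 16-bit dtype only once per output element. The reason is rounding: bfloat16 keeps just 8 bits of mantissa, so accumulating directly in bf16 stalls quickly. A small standalone illustration (assumes only c10/util/BFloat16.h; the numbers are illustrative):

#include <c10/util/BFloat16.h>
#include <cstdio>

int main() {
  c10::BFloat16 acc_bf16(0.0f);
  float acc_fp32 = 0.0f;
  for (int i = 0; i < 1000; ++i) {
    // Once the bf16 running sum reaches 256, adding 1.0 rounds away entirely;
    // the fp32 accumulator stays exact.
    acc_bf16 = c10::BFloat16(static_cast<float>(acc_bf16) + 1.0f);
    acc_fp32 += 1.0f;
  }
  std::printf("bf16 sum: %.1f  fp32 sum: %.1f\n",
              static_cast<float>(acc_bf16), acc_fp32);  // roughly 256 vs 1000
  return 0;
}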
@@ -464,18 +504,19 @@ index_select_add(const Tensor &select_indices,
// index_select (using select_indices as the index)
// mul (scaling by per_sample_weights)
// index_add (using add_indices as the index)
template<typename data_t, typename index_t>
static typename std::enable_if<!std::is_same<data_t, float>::value && !std::is_same<data_t, at::Half>::value, void>::type
index_select_scale_add(const Tensor &select_indices,
                       const Tensor &add_indices,
                       const Tensor &scale,
                       const Tensor &src,
                       Tensor &output,
                       const Tensor& /*offsets*/,
                       bool /*include_last_offset*/,
                       Tensor &bag_size,
                       index_t padding_idx,
                       _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
template <typename data_t, typename index_t>
static typename std::enable_if<std::is_same<data_t, double>::value, void>::type
index_select_scale_add(
    const Tensor& select_indices,
    const Tensor& add_indices,
    const Tensor& scale,
    const Tensor& src,
    Tensor& output,
    const Tensor& /*offsets*/,
    bool /*include_last_offset*/,
    Tensor& bag_size,
    index_t padding_idx,
    _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) {
  AT_ASSERT(select_indices.numel() == add_indices.numel());
  auto* add_indices_data = add_indices.data_ptr<index_t>();
  auto* select_indices_data = select_indices.data_ptr<index_t>();
@@ -520,26 +561,30 @@ index_select_scale_add(const Tensor &select_indices,
  }
}

template<typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, at::Half>::value, void>::type
index_select_scale_add(const Tensor &select_indices,
                       const Tensor &add_indices,
                       const Tensor &scale,
                       const Tensor &src,
                       Tensor &output,
                       const Tensor& offsets,
                       bool include_last_offset,
                       Tensor &bag_size,
                       index_t padding_idx,
                       _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
template <typename data_t, typename index_t>
typename std::enable_if<
    std::is_same<data_t, at::Half>::value ||
        std::is_same<data_t, at::BFloat16>::value,
    void>::type
index_select_scale_add(
    const Tensor& select_indices,
    const Tensor& add_indices,
    const Tensor& scale,
    const Tensor& src,
    Tensor& output,
    const Tensor& offsets,
    bool include_last_offset,
    Tensor& bag_size,
    index_t padding_idx,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
  int64_t ddim = src.size(1);
  auto* scale_data = scale.data_ptr<at::Half>();
  auto* scale_data = scale.data_ptr<data_t>();
  auto* select_indices_data = select_indices.data_ptr<index_t>();
  auto* output_data = output.data_ptr<at::Half>();
  auto* output_data = output.data_ptr<data_t>();

  if (is_fast_path_index_select_scale(src, scale, output, padding_idx)) {
    auto src_contig = src.contiguous();
    auto* src_data = src_contig.data_ptr<at::Half>();
    auto* src_data = src_contig.data_ptr<data_t>();
    int64_t output_size = offsets.numel() - 1;
    auto* offsets_data = offsets.data_ptr<index_t>();
    std::vector<index_t> offsets_include_last;
@@ -560,40 +605,42 @@ index_select_scale_add(const Tensor &select_indices,
    Tensor scale_fp32 = at::empty(scale.sizes(), scale.options().dtype(at::kFloat));
    auto* scale_data_fp32 = scale_fp32.data_ptr<float>();

#ifdef USE_FBGEMM
    using float16 = uint16_t;
    fbgemm::Float16ToFloat_simd(reinterpret_cast<const float16*>(scale_data), scale_data_fp32, scale_fp32.numel());
    auto kernel_fp16_index_t =
      fbgemm_kernel_cache ?
      fbgemm_kernel_cache->getCallback</* has_weight */ true, index_t, float16>(ddim) :
      fbgemm::GenerateEmbeddingSpMDM<float16, index_t, index_t, float16>(
        /* block_size */ddim,
        /* has_weight */true,
        /* normalize_by_lengths */false,
        /* prefetch */16,
        /* is_weight_positional */false,
        /* use_offsets */true
      );
#else
    // Initialize the intermediate output buffer to be 0.
    Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
    auto* output_data_fp32 = output_fp32.data_ptr<float>();
    for (const auto i : c10::irange(scale.numel())) {
      scale_data_fp32[i] = static_cast<float>(scale_data[i]);
#if defined(USE_FBGEMM)
    bool isbf16 = std::is_same<data_t, at::Half>::value ? false : true;
    if (isbf16) {
      fbgemm::Bfloat16ToFloat_simd(
          reinterpret_cast<const fbgemm::bfloat16*>(scale_data),
          scale_data_fp32,
          scale_fp32.numel());
    } else {
      fbgemm::Float16ToFloat_simd(
          reinterpret_cast<const fbgemm::float16*>(scale_data),
          scale_data_fp32,
          scale_fp32.numel());
    }
#endif
    auto kernel_16bit_index_t = fbgemm_kernel_cache
        ? fbgemm_kernel_cache
              ->getCallback</* has_weight */ true, index_t, uint16_t>(ddim)
        : fbgemm::GenerateEmbeddingSpMDM<uint16_t, index_t, index_t, uint16_t>(
              /* block_size */ ddim,
              /* has_weight */ true,
              /* normalize_by_lengths */ false,
              /* prefetch */ 16,
              /* is_weight_positional */ false,
              /* use_offsets */ true,
              /* isbf16*/ isbf16);
    at::parallel_for(
        0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
#ifdef USE_FBGEMM
          bool success = kernel_fp16_index_t(
            /* output_size */end_idx - start_idx,
            /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
            /* data_size */src.size(0),
            /* input */reinterpret_cast<const float16*>(src_data),
            /* indices */select_indices_data + offsets_data[start_idx],
            /* offsets_or_lengths */offsets_data + start_idx,
            /* weights */scale_data_fp32 + offsets_data[start_idx],
            /* output */reinterpret_cast<float16*>(output_data + start_idx * ddim));
          bool success = kernel_16bit_index_t(
              /* output_size */ end_idx - start_idx,
              /* index_size */ offsets_data[end_idx] - offsets_data[start_idx],
              /* data_size */ src.size(0),
              /* input */ reinterpret_cast<const uint16_t*>(src_data),
              /* indices */ select_indices_data + offsets_data[start_idx],
              /* offsets_or_lengths */ offsets_data + start_idx,
              /* weights */ scale_data_fp32 + offsets_data[start_idx],
              /* output */
              reinterpret_cast<uint16_t*>(output_data + start_idx * ddim));
          if (!success) {
            fbgemm_spmdm_report_error_(
                end_idx - start_idx,
@@ -602,7 +649,19 @@ index_select_scale_add(const Tensor &select_indices,
                offsets_data + start_idx,
                select_indices_data + offsets_data[start_idx]);
          }
        });
#else
    // Initialize the intermediate output buffer to be 0.
    Tensor output_fp32 =
        at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
    auto* output_data_fp32 = output_fp32.data_ptr<float>();
    for (const auto i : c10::irange(scale.numel())) {
      scale_data_fp32[i] = static_cast<float>(scale_data[i]);
    }
    using bVec = vec::Vectorized<BFloat16>;
    using fVec = vec::Vectorized<float>;
    at::parallel_for(
        0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
          caffe2::EmbeddingLookupIdx(
              /*block_size=*/ddim,
              /*output_size=*/end_idx - start_idx,
@@ -615,17 +674,36 @@ index_select_scale_add(const Tensor &select_indices,
              /*scale_bias=*/nullptr,
              /*normalize_by_lengths=*/false,
              /*out=*/output_data_fp32 + start_idx * ddim);
          for (const auto i : c10::irange(output_size)) {
            // Convert FP32 intermediate buffer result back to FP16 for output dtype
            for (const auto d : c10::irange(ddim)) {
              (output_data + i * ddim)[d] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]);
          for (int64_t i = start_idx; i < end_idx; i++) {
            // Convert FP32 intermediate buffer result back to 16 bit for
            // output dtype
            if (std::is_same<data_t, at::Half>::value) {
              // FP16
              for (const auto d : c10::irange(ddim)) {
                (output_data + i * ddim)[d] =
                    static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
              }
            } else {
              // BF16
              int64_t d = 0;
              for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) {
                fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d);
                fVec temp_fp32_1 =
                    fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size());
                convert_float_bfloat16(temp_fp32_0, temp_fp32_1)
                    .store(output_data + i * ddim + d);
              }
              for (; d < ddim; d++) {
                (output_data + i * ddim)[d] =
                    static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
              }
            }
          }
#endif
        });
#endif
  } else {
    AT_ASSERT(select_indices.numel() == add_indices.numel());
    auto* src_data = src.data_ptr<at::Half>();
    auto* src_data = src.data_ptr<data_t>();
    auto* add_indices_data = add_indices.data_ptr<index_t>();
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    index_t* bag_size_data = nullptr;
@@ -641,7 +719,8 @@ index_select_scale_add(const Tensor &select_indices,
    auto numel = add_indices.numel();

    // Initialize the intermediate output buffer to be 0.
    Tensor output_fp32 = at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
    Tensor output_fp32 =
        at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
    auto* output_data_fp32 = output_fp32.data_ptr<float>();

    for (const auto i : c10::irange(numel)) {

@@ -653,12 +732,12 @@ index_select_scale_add(const Tensor &select_indices,
          "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
          idx);
      if (idx != padding_idx) {

        auto* src_base = src_data + src_stride0 * idx;
        auto* output_base_fp32 = output_data_fp32 + ddim * add_indices_data[i];
        auto scale = scale_data[i * scale_stride];
        for (const auto j : c10::irange(ddim)) {
          output_base_fp32[j] += static_cast<float>(src_base[j * src_stride1]) * static_cast<float>(scale);
          output_base_fp32[j] += static_cast<float>(src_base[j * src_stride1]) *
              static_cast<float>(scale);
        }
      } else if (bag_size.defined()) {
        // Decrement bag_size to reflect that the index is padded

@@ -667,14 +746,15 @@ index_select_scale_add(const Tensor &select_indices,
      }
    }
    for (const auto i : c10::irange(output.size(0))) {
      // Convert FP32 intermediate buffer result back to FP16 for output dtype
      // Convert FP32 intermediate buffer result back to 16 bit for output
      // dtype
      for (const auto d : c10::irange(ddim)) {
        (output_data + output_stride0 * i)[d * output_stride1] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]);
        (output_data + output_stride0 * i)[d * output_stride1] =
            static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
      }
    }
  }
}

template<typename data_t, typename index_t>
typename std::enable_if<std::is_same<data_t, float>::value, void>::type
index_select_scale_add(const Tensor &select_indices,
@@ -817,7 +897,8 @@ void check_arguments(
  checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt});
  checkSameType("embedding_bag", indices_arg, offsets_arg);
  auto weight_arg = TensorArg(weight, "weight", 1);
  checkScalarTypes("embedding_bag", weight_arg, {kHalf, kFloat, kDouble});
  checkScalarTypes(
      "embedding_bag", weight_arg, {kHalf, kBFloat16, kFloat, kDouble});

  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_embedding_bag_cpu_impl", [&]() {
    if (offsets.size(0) > 0) {
@@ -1086,12 +1167,22 @@ void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
      max_indices->copy_(bag_size);
    }
  } else { // MODE_MAX
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      weight.scalar_type(), "embedding_bag_cpu_max_out", [&]() {
        embedding_bag_cpu_max_out<scalar_t>(
          max_indices, weight, indices, offset2bag, output, include_last_offset, bag_size, padding_idx);
      }
    );
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        weight.scalar_type(),
        "embedding_bag_cpu_max_out",
        [&]() {
          embedding_bag_cpu_max_out<scalar_t>(
              max_indices,
              weight,
              indices,
              offset2bag,
              output,
              include_last_offset,
              bag_size,
              padding_idx);
        });
  }
}
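MODE_MAX used to dispatch with AT_DISPATCH_FLOATING_TYPES_AND_HALF; it now uses AT_DISPATCH_FLOATING_TYPES_AND2 so that BFloat16 is included. A minimal sketch of how that macro is typically used (the fill_one function is illustrative, not part of the PR):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Dispatch one kernel body over float, double, Half and BFloat16.
// Assumes `t` is a contiguous CPU tensor.
void fill_one(at::Tensor& t) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      t.scalar_type(),
      "fill_one",
      [&]() {
        // scalar_t is bound to the concrete element type, e.g. c10::BFloat16
        // when t is a bf16 tensor.
        auto* data = t.data_ptr<scalar_t>();
        for (int64_t i = 0; i < t.numel(); ++i) {
          data[i] = static_cast<scalar_t>(1.0f);
        }
      });
}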
@@ -1521,7 +1612,8 @@ Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indi
  // for more details.
  auto grad = grad_.contiguous();
  auto grad_arg = TensorArg(grad, "grad_", 1);
  checkScalarTypes("embedding_bag", grad_arg, {kHalf, kFloat, kDouble});
  checkScalarTypes(
      "embedding_bag", grad_arg, {kHalf, kBFloat16, kFloat, kDouble});

  if (mode == MODE_MAX) {
    return _embedding_bag_dense_backward_cpu_max(
@@ -98,14 +98,14 @@ struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
// instantiate the cache with the list of storage mixins
// for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
  _CallbackAndBlockSize<true, int32_t, float>,
  _CallbackAndBlockSize<false, int32_t, float>,
  _CallbackAndBlockSize<true, int64_t, float>,
  _CallbackAndBlockSize<false, int64_t, float>,
  _CallbackAndBlockSize<true, int32_t, unsigned short>,
  _CallbackAndBlockSize<false, int32_t, unsigned short>,
  _CallbackAndBlockSize<true, int64_t, unsigned short>,
  _CallbackAndBlockSize<false, int64_t, unsigned short>>;
    _CallbackAndBlockSize<true, int32_t, float>,
    _CallbackAndBlockSize<false, int32_t, float>,
    _CallbackAndBlockSize<true, int64_t, float>,
    _CallbackAndBlockSize<false, int64_t, float>,
    _CallbackAndBlockSize<true, int32_t, unsigned short>,
    _CallbackAndBlockSize<false, int32_t, unsigned short>,
    _CallbackAndBlockSize<true, int64_t, unsigned short>,
    _CallbackAndBlockSize<false, int64_t, unsigned short>>;
#else
struct _EmbeddingBagKernelCache {
  explicit _EmbeddingBagKernelCache(c10::optional<int64_t> /* maybe_block_size */) {}
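The kernel cache keys its FBGEMM callbacks on unsigned short rather than on a specific fp16 type, which works because both 16-bit floating dtypes have the same storage size. A short sketch of the assumption those reinterpret_casts rely on (static_asserts only, illustrative):

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <cstdint>

// Both 16-bit floating dtypes occupy exactly one uint16_t, which is why a
// single <..., unsigned short> cache slot (and the reinterpret_casts in the
// kernels above) can carry either at::Half or at::BFloat16 data.
static_assert(sizeof(c10::Half) == sizeof(std::uint16_t), "Half is 16 bits");
static_assert(sizeof(c10::BFloat16) == sizeof(std::uint16_t), "BFloat16 is 16 bits");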
@@ -818,7 +818,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
        return torch.stack(bags)

    @skipMeta
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.half, torch.float, torch.double)))
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
                               (torch.half, torch.bfloat16, torch.float, torch.double)))
    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
                                     (torch.float, torch.double, torch.half)))
    def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes):
        # Test empty input and per sample weight, and backward pass. There was a CUDA
        # invalid configuration bug (more context in #46572)

@@ -857,7 +860,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            test_per_sample_weights(mode, trainable)

    @skipMeta
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
                               (torch.float, torch.double, torch.half, torch.bfloat16)))
    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
                                     (torch.float, torch.double, torch.half)))
    def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes):
        def test_per_sample_weights(mode, trainable_scale):
            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device)

@@ -891,7 +897,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
            test_per_sample_weights(mode, trainable)

    @skipMeta
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
                               (torch.float, torch.double, torch.half, torch.bfloat16)))
    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
                                     (torch.float, torch.double, torch.half)))
    def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes):
        def test_per_sample_weights_new_offsets(mode, trainable_scale, include_last_offset, has_weight=True):
            es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtypes[2], device=device)

@@ -1156,7 +1165,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
        self.assertRaises(RuntimeError, lambda: es(input.view(-1), offset))

    @skipMeta
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
                               (torch.float, torch.double, torch.half, torch.bfloat16)))
    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
                                     (torch.float, torch.double, torch.half)))
    def test_embedding_bag_device(self, device, dtypes):
        with set_default_dtype(torch.double):
            self._test_EmbeddingBag(device, 'sum', False, wdtype=dtypes[2], dtype=dtypes[0], odtype=dtypes[1])

@@ -1192,7 +1204,10 @@ class TestEmbeddingNNDeviceType(NNTestCase):
        )

    @skipMeta
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
                               (torch.float, torch.double, torch.half, torch.bfloat16)))
    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long),
                                     (torch.float, torch.double, torch.half)))
    def test_embedding_bag_non_contiguous_weight(self, device, dtypes):
        weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device)

@@ -1216,7 +1231,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
        )
        self.assertEqual(output_non_contig, output_contig)

    @onlyCUDA
    @onlyNativeDeviceTypes # currently fails on XLA
    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
    def test_embedding_bag_bfloat16(self, device, dtypes):
        with set_default_dtype(torch.double):
@@ -967,7 +967,7 @@ meta_dispatch_device_expected_failures['cuda'] = {
}

meta_dispatch_device_skips['cpu'] = {
    aten._embedding_bag_forward_only.default: {f16, f32, f64},
    aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},
    aten.native_batch_norm.default: {f32, f64},
    aten._native_batch_norm_legit.default: {f32, f64},
    aten._native_batch_norm_legit.no_stats: {f32, f64},
third_party/fbgemm (vendored)

@@ -1 +1 @@
Subproject commit 80d64206c07879fd4683be66873de7cefa1a0a71
Subproject commit 03b2046676707da64504e898490ab46104d4682a
@@ -16969,7 +16969,7 @@ op_db: List[OpInfo] = [
           # This is because currently only the `input` field of SampleInput
           # is tested in gradient tests.
           op=lambda weight, idx, **kwargs: torch.nn.functional.embedding_bag(idx, weight, **kwargs),
           dtypes=floating_types_and(torch.float16),
           dtypes=floating_types_and(torch.bfloat16, torch.float16),
           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
           # backward is not supported for mode `max` and dtype `bfloat16`
           backward_dtypesIfCUDA=floating_types_and(torch.float16),