[caffe2] use JIT'ed fp16 SLS (#32432)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32432

Use the JIT'ed fp16 SLS kernel introduced in D19477209 from Caffe2 operators

Test Plan: CI

Reviewed By: jianyuh

Differential Revision: D19477208

fbshipit-source-id: ef2ccba10f5f4c475166141bf09c266dedb92d38
This commit is contained in:
Jongsoo Park 2020-02-13 21:11:58 -08:00 committed by Facebook Github Bot
parent 642bd51043
commit 92fbf7cf97

View file

@ -80,32 +80,56 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
}
#ifdef USE_FBGEMM
if (std::is_same<InputType, float>::value) {
// If this is the first call or block size has changed (should never
// happen actually), generate a kernel.
if (D != last_block_size) {
last_block_size = D;
// If this is the first call or block size has changed (should never
// happen actually), generate a kernel.
if (D != last_block_size) {
last_block_size = D;
if (std::is_same<InputType, float>::value) {
if (std::is_same<IndexType, std::int32_t>::value) {
kernel32_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT);
kernel_fp32_i32_ =
fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT);
} else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
kernel64_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT);
kernel_fp32_i64_ =
fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT);
}
} else {
CAFFE_ENFORCE((std::is_same<InputType, at::Half>::value));
if (std::is_same<IndexType, std::int32_t>::value) {
kernel_fp16_i32_ =
fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int32_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT);
} else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
kernel_fp16_i64_ =
fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int64_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT);
}
}
}
bool success;
bool success;
if (std::is_same<InputType, float>::value) {
if (std::is_same<IndexType, std::int32_t>::value) {
success = kernel32_(
success = kernel_fp32_i32_(
M,
indices_size,
N,
@ -115,7 +139,7 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
in_weight,
out_data);
} else {
success = kernel64_(
success = kernel_fp32_i64_(
M,
indices_size,
N,
@ -125,39 +149,61 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
in_weight,
out_data);
}
if (success) {
return true;
} else {
if (std::is_same<IndexType, std::int32_t>::value) {
success = kernel_fp16_i32_(
M,
indices_size,
N,
reinterpret_cast<const fbgemm::float16*>(in_data),
indicesInput.template data<std::int32_t>(),
lengths,
in_weight,
out_data);
} else {
success = kernel_fp16_i64_(
M,
indices_size,
N,
reinterpret_cast<const fbgemm::float16*>(in_data),
indicesInput.template data<std::int64_t>(),
lengths,
in_weight,
out_data);
}
int64_t current = 0;
for (int m = 0; m < M; ++m) {
for (int i = 0; i < lengths[m]; ++i) {
CAFFE_ENFORCE_LT(
current,
indices_size,
"Your input seems to be incorrect: the sum of lengths values "
"should be the size of the indices tensor, but it appears not.");
IndexType idx = indices[current];
CAFFE_ENFORCE(
0 <= idx && idx < N,
"Index ",
current,
" is out of bounds: ",
idx,
", range 0 to ",
N);
++current;
}
}
CAFFE_ENFORCE_EQ(
current,
indices_size,
"Your input seems to be incorrect: the sum of lengths values should be "
"the size of the indices tensor, but it appears not.");
return false;
}
if (success) {
return true;
}
int64_t current = 0;
for (int m = 0; m < M; ++m) {
for (int i = 0; i < lengths[m]; ++i) {
CAFFE_ENFORCE_LT(
current,
indices_size,
"Your input seems to be incorrect: the sum of lengths values "
"should be the size of the indices tensor, but it appears not.");
IndexType idx = indices[current];
CAFFE_ENFORCE(
0 <= idx && idx < N,
"Index ",
current,
" is out of bounds: ",
idx,
", range 0 to ",
N);
++current;
}
}
CAFFE_ENFORCE_EQ(
current,
indices_size,
"Your input seems to be incorrect: the sum of lengths values should be "
"the size of the indices tensor, but it appears not.");
return false;
#endif
// delegate work to perfkernel that branches based on architecture
@ -188,8 +234,14 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
#ifdef USE_FBGEMM
private:
std::int64_t last_block_size{-1};
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type kernel32_;
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type kernel64_;
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type
kernel_fp32_i32_;
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type
kernel_fp32_i64_;
fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int32_t>::Type
kernel_fp16_i32_;
fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int64_t>::Type
kernel_fp16_i64_;
#endif
};