mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[caffe2] use JIT'ed fp16 SLS (#32432)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32432 Use JIT'ed fp16 SLS in D19477209 from Caffe2 operators Test Plan: CI Reviewed By: jianyuh Differential Revision: D19477208 fbshipit-source-id: ef2ccba10f5f4c475166141bf09c266dedb92d38
This commit is contained in:
parent
642bd51043
commit
92fbf7cf97
1 changed files with 105 additions and 53 deletions
|
|
@ -80,32 +80,56 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
|
|||
}
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
if (std::is_same<InputType, float>::value) {
|
||||
// If this is the first call or block size has changed (should never
|
||||
// happen actually), generate a kernel.
|
||||
if (D != last_block_size) {
|
||||
last_block_size = D;
|
||||
// If this is the first call or block size has changed (should never
|
||||
// happen actually), generate a kernel.
|
||||
if (D != last_block_size) {
|
||||
last_block_size = D;
|
||||
if (std::is_same<InputType, float>::value) {
|
||||
if (std::is_same<IndexType, std::int32_t>::value) {
|
||||
kernel32_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
|
||||
D,
|
||||
USE_WEIGHT,
|
||||
USE_MEAN,
|
||||
/*prefetch distance*/ 16,
|
||||
USE_POSITIONAL_WEIGHT);
|
||||
kernel_fp32_i32_ =
|
||||
fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
|
||||
D,
|
||||
USE_WEIGHT,
|
||||
USE_MEAN,
|
||||
/*prefetch distance*/ 16,
|
||||
USE_POSITIONAL_WEIGHT);
|
||||
} else {
|
||||
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
||||
kernel64_ = fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
|
||||
D,
|
||||
USE_WEIGHT,
|
||||
USE_MEAN,
|
||||
/*prefetch distance*/ 16,
|
||||
USE_POSITIONAL_WEIGHT);
|
||||
kernel_fp32_i64_ =
|
||||
fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
|
||||
D,
|
||||
USE_WEIGHT,
|
||||
USE_MEAN,
|
||||
/*prefetch distance*/ 16,
|
||||
USE_POSITIONAL_WEIGHT);
|
||||
}
|
||||
} else {
|
||||
CAFFE_ENFORCE((std::is_same<InputType, at::Half>::value));
|
||||
if (std::is_same<IndexType, std::int32_t>::value) {
|
||||
kernel_fp16_i32_ =
|
||||
fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int32_t>(
|
||||
D,
|
||||
USE_WEIGHT,
|
||||
USE_MEAN,
|
||||
/*prefetch distance*/ 16,
|
||||
USE_POSITIONAL_WEIGHT);
|
||||
} else {
|
||||
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
||||
kernel_fp16_i64_ =
|
||||
fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int64_t>(
|
||||
D,
|
||||
USE_WEIGHT,
|
||||
USE_MEAN,
|
||||
/*prefetch distance*/ 16,
|
||||
USE_POSITIONAL_WEIGHT);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool success;
|
||||
bool success;
|
||||
if (std::is_same<InputType, float>::value) {
|
||||
if (std::is_same<IndexType, std::int32_t>::value) {
|
||||
success = kernel32_(
|
||||
success = kernel_fp32_i32_(
|
||||
M,
|
||||
indices_size,
|
||||
N,
|
||||
|
|
@ -115,7 +139,7 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
|
|||
in_weight,
|
||||
out_data);
|
||||
} else {
|
||||
success = kernel64_(
|
||||
success = kernel_fp32_i64_(
|
||||
M,
|
||||
indices_size,
|
||||
N,
|
||||
|
|
@ -125,39 +149,61 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
|
|||
in_weight,
|
||||
out_data);
|
||||
}
|
||||
|
||||
if (success) {
|
||||
return true;
|
||||
} else {
|
||||
if (std::is_same<IndexType, std::int32_t>::value) {
|
||||
success = kernel_fp16_i32_(
|
||||
M,
|
||||
indices_size,
|
||||
N,
|
||||
reinterpret_cast<const fbgemm::float16*>(in_data),
|
||||
indicesInput.template data<std::int32_t>(),
|
||||
lengths,
|
||||
in_weight,
|
||||
out_data);
|
||||
} else {
|
||||
success = kernel_fp16_i64_(
|
||||
M,
|
||||
indices_size,
|
||||
N,
|
||||
reinterpret_cast<const fbgemm::float16*>(in_data),
|
||||
indicesInput.template data<std::int64_t>(),
|
||||
lengths,
|
||||
in_weight,
|
||||
out_data);
|
||||
}
|
||||
|
||||
int64_t current = 0;
|
||||
for (int m = 0; m < M; ++m) {
|
||||
for (int i = 0; i < lengths[m]; ++i) {
|
||||
CAFFE_ENFORCE_LT(
|
||||
current,
|
||||
indices_size,
|
||||
"Your input seems to be incorrect: the sum of lengths values "
|
||||
"should be the size of the indices tensor, but it appears not.");
|
||||
IndexType idx = indices[current];
|
||||
CAFFE_ENFORCE(
|
||||
0 <= idx && idx < N,
|
||||
"Index ",
|
||||
current,
|
||||
" is out of bounds: ",
|
||||
idx,
|
||||
", range 0 to ",
|
||||
N);
|
||||
++current;
|
||||
}
|
||||
}
|
||||
CAFFE_ENFORCE_EQ(
|
||||
current,
|
||||
indices_size,
|
||||
"Your input seems to be incorrect: the sum of lengths values should be "
|
||||
"the size of the indices tensor, but it appears not.");
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
if (success) {
|
||||
return true;
|
||||
}
|
||||
|
||||
int64_t current = 0;
|
||||
for (int m = 0; m < M; ++m) {
|
||||
for (int i = 0; i < lengths[m]; ++i) {
|
||||
CAFFE_ENFORCE_LT(
|
||||
current,
|
||||
indices_size,
|
||||
"Your input seems to be incorrect: the sum of lengths values "
|
||||
"should be the size of the indices tensor, but it appears not.");
|
||||
IndexType idx = indices[current];
|
||||
CAFFE_ENFORCE(
|
||||
0 <= idx && idx < N,
|
||||
"Index ",
|
||||
current,
|
||||
" is out of bounds: ",
|
||||
idx,
|
||||
", range 0 to ",
|
||||
N);
|
||||
++current;
|
||||
}
|
||||
}
|
||||
CAFFE_ENFORCE_EQ(
|
||||
current,
|
||||
indices_size,
|
||||
"Your input seems to be incorrect: the sum of lengths values should be "
|
||||
"the size of the indices tensor, but it appears not.");
|
||||
|
||||
return false;
|
||||
#endif
|
||||
|
||||
// delegate work to perfkernel that branches based on architecture
|
||||
|
|
@ -188,8 +234,14 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
|
|||
#ifdef USE_FBGEMM
|
||||
private:
|
||||
std::int64_t last_block_size{-1};
|
||||
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type kernel32_;
|
||||
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type kernel64_;
|
||||
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type
|
||||
kernel_fp32_i32_;
|
||||
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type
|
||||
kernel_fp32_i64_;
|
||||
fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int32_t>::Type
|
||||
kernel_fp16_i32_;
|
||||
fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int64_t>::Type
|
||||
kernel_fp16_i64_;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue