fix local_thread:pytorch embeddingbag

Test Plan: buck test //deeplearning/fbgemm:EmbeddingSpMDMTest

Differential Revision: D35595457

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75702
Approved by: https://github.com/jianyuh
This commit is contained in:
Jiyuan Zhang 2022-04-18 19:52:04 +00:00 committed by PyTorch MergeBot
parent 381e725911
commit 2602a5e76f
2 changed files with 2 additions and 2 deletions

View file

@ -405,7 +405,7 @@ at::Tensor& embedding_bag_byte_impl(
if (!pruned_weights || fallback_to_no_sparse) {
auto kernel_i8 =
fbgemm::GenerateEmbeddingSpMDM<uint8_t, IndexType, OffsetType>(
fbgemm::GenerateEmbeddingSpMDM<uint8_t, IndexType, OffsetType, /*OutType=*/float, /*TRHEAD_LOCAL=*/true>(
/*block_size=*/D,
/*has_weight=*/per_sample_weights_.has_value(),
/*normalize_by_lengths=*/false,

2
third_party/fbgemm vendored

@ -1 +1 @@
Subproject commit 9cf1a9ffefbb439e823dd3340ab4967e0cfe23a6
Subproject commit 2e9be65810107a9595da717f95d21924b73be833