Summary:

# Latest Update

This diff is no longer needed because we did need the check to exist, to make meta behave the same as other devices; see D54526190.

---

# Background

T176105639

| case | embedding bag weight | per_sample_weights | fbgemm lookup | forward in meta |
|---|---|---|---|---|
| A | fp32 | fp32 | good | good |
| B | fp16 | fp32 | good | fails the [check](https://fburl.com/code/k3n3h031) that forces weight dtype == per_sample_weights dtype |
| C | fp16 | fp16 | P1046999270, RuntimeError: "expected scalar type Float but found Half from fbgemm call" | good |
| D | fp32 | fp16 | N/A | N/A |

Currently we are in case A. Users need to add `use_fp32_embedding` in training to force the embedding bag dtype to be fp32. However, users actually want case B, using fp16 as the embedding bag weight. When they delete `use_fp32_embedding`, they fail the [check](https://fburl.com/code/k3n3h031) in the meta registration that forces `weight dtype == per_sample_weights dtype`.

The check is actually not necessary, because the fbgemm backend does support case B. Additionally, later on in `meta_embedding_bag`, `weight` and `per_sample_weights` do not need to have the same dtype for `is_fast_path_index_select` (https://fburl.com/code/q0tho05h: weight is src, per_sample_weights is scale).

# This diff

Therefore, this diff removes the unnecessary [check](https://fburl.com/code/k3n3h031) to support case B in the meta forward. With it, users can use fp16 as the emb bag dtype without having to force per_sample_weights to the same dtype in the meta forward (see Test Plan; a minimal sketch of the case B call pattern is also included below).

# Reference diffs to resolve this issue

Diff 1: D52591217

This passes the embedding bag dtype to the feature_processor so that per_sample_weights gets the same dtype as the emb bag weight. However, `is_meta` also needs to be passed because of case C: fbgemm still does not support per_sample_weights = fp16 (see the table above). Therefore users are forced to make per_sample_weights fp16 only when it is on meta. The solution requires too many hacks.

Diff 2: D53232739

This does basically the same thing as diff 1 (D52591217), except that the hack lives in the TorchRec library. It adds an `if` in EBC and PEA so that when the emb bag weight is fp16, per_sample_weights is forced to fp16 too. However, that then hits the fbgemm issue as well and has broken a number of prod models.

Test Plan:

# APS

The following command runs icvr_launcher, which triggers ads_launcher and runs the forward pass on the meta device:

```
buck2 run mode/opt -c python.package_style=inplace //aps_models/ads/icvr:icvr_launcher_publish -- mode=mast_ig_fm_when_combo0_uhm_publish launcher.fbl_entitlement=ads_global_tc_ads_score launcher.data_project=oncall_ads_model_platform launcher.tags=[ads_ranking_taxonomy_exlarge_fm_prod] stages.train=false
```

Result: {F1461463993}

Reviewed By: ezyang

Differential Revision: D54175438

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136774
Approved by: https://github.com/ezyang
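To make the dtype combination concrete, below is a minimal sketch of "case B" from the table above: an fp16 embedding bag weight combined with fp32 per_sample_weights, traced on the meta device. This is not code from the diff; the shapes and values are made up, and `torch.nn.functional.embedding_bag` is used as a generic stand-in for the embedding bag lookup that the meta registration intercepts.

```python
import torch
import torch.nn.functional as F

# Illustrative "case B" setup: fp16 embedding bag weight, fp32 per_sample_weights.
# Shapes and values are arbitrary; only the dtype/device combination matters here.
weight = torch.empty(10, 4, dtype=torch.float16, device="meta")          # emb bag weight in fp16
indices = torch.zeros(6, dtype=torch.long, device="meta")                # flat bag of indices
offsets = torch.zeros(2, dtype=torch.long, device="meta")                # two bags
per_sample_weights = torch.empty(6, dtype=torch.float32, device="meta")  # fp32, unlike weight

# per_sample_weights is only supported with mode="sum".
try:
    out = F.embedding_bag(indices, weight, offsets,
                          mode="sum", per_sample_weights=per_sample_weights)
    # Without the dtype check, the meta forward traces through (here: a (2, 4) meta tensor).
    print("meta forward traced through:", out.shape, out.dtype)
except RuntimeError as e:
    # With the dtype check in the meta registration, this mixed-dtype call is rejected.
    print("meta kernel rejected mixed dtypes:", e)
```

Whether the call succeeds or raises depends on whether the weight/per_sample_weights dtype check is present in the meta registration of the PyTorch build being used, which is exactly the behavior this diff changes.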
Files in this directory:

- test_convolution.py
- test_dropout.py
- test_embedding.py
- test_init.py
- test_lazy_modules.py
- test_load_state_dict.py
- test_module_hooks.py
- test_multihead_attention.py
- test_packed_sequence.py
- test_parametrization.py
- test_pooling.py
- test_pruning.py