mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: For LSTM, the input and hidden state are projected with Linear layers to construct the 4 gates. This is typically performed together as a single Linear (for each state) with output channel count `4 * hidden_dim` for efficiency.

https://www.internalfb.com/code/fbsource/[ebef7c4238aa55948b2b444044f2c8ed2040de55]/fbcode/caffe2/torch/ao/nn/quantizable/modules/rnn.py?lines=52-58

The output is then ultimately split into 4:

https://www.internalfb.com/code/fbsource/[ebef7c4238aa55948b2b444044f2c8ed2040de55]/fbcode/caffe2/torch/ao/nn/quantizable/modules/rnn.py?lines=83-87

For on-device latency (and possibly memory) considerations, we want to avoid constructing the intermediate `gates` tensor (which can be relatively large), by splitting `igates` and `hgates` first (as 4x `Linear(hidden_dim, hidden_dim)` each), applying add separately, then proceeding as usual.

This functionality can be enabled by specifying `split_gates=True` (default False is original behavior) at any entry point (directly with `torch.ao.nn.quantizable.LSTM` or via `_get_lstm_with_individually_observed_parts`).

Test Plan: piggy back on existing test to check for correct swap handling, numerics, and jit.script during prepare/convert

```
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/quantization:test_quantization -- --exact 'caffe2/test/quantization:test_quantization - test_custom_module_lstm (caffe2.test.quantization.core.test_quantized_op.TestQuantizedOps)'
```

https://www.internalfb.com/intern/testinfra/testrun/11540474102848372

This test is quite long running now (more than double original).

Reviewed By: Ninja91

Differential Revision: D65283170

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140868

Approved by: https://github.com/jerryzh168

| | | |
|---|---|---|
| .. | ||
| experimental | ||
| __init__.py | ||
| test_backend_config.py | ||
| test_docs.py | ||
| test_quantized_functional.py | ||
| test_quantized_module.py | ||
| test_quantized_op.py | ||
| test_quantized_tensor.py | ||
| test_top_level_apis.py | ||
| test_utils.py | ||
| test_workflow_module.py | ||
| test_workflow_ops.py | ||
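
The split-gates idea described in the commit summary can be illustrated with a minimal, framework-free sketch. This is not the code in `torch/ao/nn/quantizable/modules/rnn.py`; it uses plain Python lists instead of tensors, and every helper name (`linear`, `finish_cell`, `fused_cell`, `split_cell`) is hypothetical. It also assumes the gate order input, forget, cell, output when chunking the fused projection. The point is only that slicing the `4 * hidden_dim` projection into four per-gate projections and adding them separately gives the same result as building the full `gates` vector first.

```python
import math
import random


def linear(weight, bias, x):
    """Compute W x + b, with the weight given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def finish_cell(i, f, g, o, c_prev):
    """Standard LSTM update from the four pre-activation gate vectors."""
    c = [sigmoid(fk) * ck + sigmoid(ik) * math.tanh(gk)
         for ik, fk, gk, ck in zip(i, f, g, c_prev)]
    h = [sigmoid(ok) * math.tanh(ck) for ok, ck in zip(o, c)]
    return h, c


def fused_cell(w_i, b_i, w_h, b_h, x, h_prev, c_prev, hidden):
    """Original behavior: one projection per state, then chunk into 4."""
    igates = linear(w_i, b_i, x)        # length 4 * hidden
    hgates = linear(w_h, b_h, h_prev)   # length 4 * hidden
    gates = [a + b for a, b in zip(igates, hgates)]  # large intermediate
    chunks = [gates[k * hidden:(k + 1) * hidden] for k in range(4)]
    return finish_cell(*chunks, c_prev)


def split_cell(w_i, b_i, w_h, b_h, x, h_prev, c_prev, hidden):
    """Split variant: four per-gate projections, each added separately."""
    pre = []
    for k in range(4):
        rows = slice(k * hidden, (k + 1) * hidden)
        ig = linear(w_i[rows], b_i[rows], x)
        hg = linear(w_h[rows], b_h[rows], h_prev)
        pre.append([a + b for a, b in zip(ig, hg)])  # only `hidden` wide
    return finish_cell(*pre, c_prev)


def rand_vec(n):
    return [random.uniform(-1.0, 1.0) for _ in range(n)]


# Small illustrative sizes; the values are arbitrary.
random.seed(0)
input_dim, hidden = 2, 3
w_i = [rand_vec(input_dim) for _ in range(4 * hidden)]
w_h = [rand_vec(hidden) for _ in range(4 * hidden)]
b_i, b_h = rand_vec(4 * hidden), rand_vec(4 * hidden)
x, h_prev, c_prev = rand_vec(input_dim), rand_vec(hidden), rand_vec(hidden)

h_fused, c_fused = fused_cell(w_i, b_i, w_h, b_h, x, h_prev, c_prev, hidden)
h_split, c_split = split_cell(w_i, b_i, w_h, b_h, x, h_prev, c_prev, hidden)
```

In this sketch both variants perform the same arithmetic per output element, so the outputs should agree; the split variant simply never materializes a vector of width `4 * hidden`. Whether that translates into a latency or memory win depends on the backend, as the commit summary itself hedges.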