pytorch/test/expect
mingfeima c620ece726 port sparse_mm.reduce to pytorch and optimize it on CPU (#83727)
### Motivation of this PR

This patch migrates `spmm_reduce` from `torch-sparse` (a third-party dependency of PyG) to `torch`, in response to the initial proposal to fuse **Gather, Apply, Scatter** in message passing for GNN inference and training: https://github.com/pytorch/pytorch/issues/71300

**GAS** is the major step of message passing. Its behavior falls into two kinds, depending on the storage type of `EdgeIndex`, which records the connections between nodes:

* COO: the hotspot is `scatter_reduce`
* CSR: the hotspot is `spmm_reduce`
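
To make the COO case concrete, here is a minimal pure-Python sketch of gather-then-scatter aggregation over an edge list. It is only an illustration of the access pattern, not code from this PR or from PyG; the function name `gather_scatter_coo` and the toy graph are hypothetical, and only the "sum" and "mean" reductions are shown.

```python
from typing import List


def gather_scatter_coo(src: List[int], dst: List[int], x: List[List[float]],
                       num_nodes: int, reduce: str = "sum") -> List[List[float]]:
    """Aggregate source-node features into destination nodes (hypothetical helper).

    Edge e goes from node src[e] to node dst[e]; x[n] is the feature row of node n.
    """
    if reduce not in ("sum", "mean"):
        raise ValueError(f"unsupported reduce: {reduce}")
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    count = [0] * num_nodes
    for s, d in zip(src, dst):
        message = x[s]                 # gather: read the source node's features
        for k in range(dim):
            out[d][k] += message[k]    # scatter: accumulate into the destination row
        count[d] += 1
    if reduce == "mean":
        for n in range(num_nodes):
            if count[n] > 0:
                out[n] = [v / count[n] for v in out[n]]
    return out


# Toy graph with three nodes and edges 0->0, 2->0, 2->1.
src = [0, 2, 2]
dst = [0, 0, 1]
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
summed = gather_scatter_coo(src, dst, x, num_nodes=3, reduce="sum")
averaged = gather_scatter_coo(src, dst, x, num_nodes=3, reduce="mean")
print(summed)    # [[6.0, 8.0], [5.0, 6.0], [0.0, 0.0]]
print(averaged)  # [[3.0, 4.0], [5.0, 6.0], [0.0, 0.0]]
```

Because edges arrive in arbitrary order, writes to `out` jump between destination rows. With CSR storage the edges of each destination row are contiguous, so the same aggregation becomes a row-wise sparse-dense product.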

The reduce type can be chosen from: "sum", "mean", "max", "min".

This PR extends `torch.sparse.mm` with a `reduce` argument, which maps to `torch.sparse_mm.reduce` internally.
`sparse_mm_reduce` is registered under the TensorTypeId of `SparseCsrCPU`. The operator relies on an internal interface, `_sparse_mm_reduce_impl`, which has two outputs:
* `out` - the actual output
* `arg_out` - records, for each output element, the index of the selected non-zero element when the reduce type is "max" or "min". It is only needed for training, so it is not computed for inference.
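
The semantics of the two outputs can be sketched with a plain-Python reference over CSR arrays. This is not the optimized CPU kernel from this PR: the name `spmm_reduce_reference` is hypothetical, and the conventions for empty rows (an output of `0.0` and an `arg_out` entry of `-1`) and for "mean" (dividing by the number of stored elements in the row) are assumptions made for illustration.

```python
from typing import List, Optional, Tuple


def spmm_reduce_reference(
    crow: List[int], col: List[int], val: List[float],
    dense: List[List[float]], reduce: str = "sum",
) -> Tuple[List[List[float]], Optional[List[List[int]]]]:
    """Reference CSR (sparse) x dense product with a reduction (hypothetical helper).

    crow, col, val are the CSR row pointers, column indices and values.
    Returns (out, arg_out); arg_out is None unless reduce is "max" or "min".
    """
    if reduce not in ("sum", "mean", "max", "min"):
        raise ValueError(f"unsupported reduce: {reduce}")
    n_rows = len(crow) - 1
    n_cols = len(dense[0])
    need_arg = reduce in ("max", "min")
    out = [[0.0] * n_cols for _ in range(n_rows)]
    # Assumption: -1 marks output elements that no non-zero element contributed to.
    arg_out = [[-1] * n_cols for _ in range(n_rows)] if need_arg else None

    for i in range(n_rows):
        start, end = crow[i], crow[i + 1]
        for k in range(n_cols):
            acc = 0.0
            best = None
            for e in range(start, end):      # e indexes the non-zero elements
                prod = val[e] * dense[col[e]][k]
                if need_arg:
                    better = best is None or (
                        prod > best if reduce == "max" else prod < best
                    )
                    if better:
                        best = prod
                        arg_out[i][k] = e    # remember which non-zero element won
                else:
                    acc += prod
            if need_arg:
                out[i][k] = best if best is not None else 0.0
            elif reduce == "mean":
                out[i][k] = acc / (end - start) if end > start else 0.0
            else:
                out[i][k] = acc
    return out, arg_out


# Sparse matrix [[1, 0, 2], [0, 0, 3], [0, 0, 0]] in CSR form.
crow = [0, 2, 3, 3]
col = [0, 2, 2]
val = [1.0, 2.0, 3.0]
dense = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

out_sum, arg_sum = spmm_reduce_reference(crow, col, val, dense, "sum")
out_max, arg_max = spmm_reduce_reference(crow, col, val, dense, "max")
print(out_sum)  # [[11.0, 14.0], [15.0, 18.0], [0.0, 0.0]]
print(out_max)  # [[10.0, 12.0], [15.0, 18.0], [0.0, 0.0]]
print(arg_max)  # [[1, 1], [2, 2], [-1, -1]]
```

For "max" and "min", the backward pass only needs to route gradients to the elements recorded in `arg_out`, which is why that output can be skipped entirely for inference.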

### Performance

In a GCN benchmark on ogbn-products on a single-socket Xeon, this patch speeds up the workload by `4.3x`.

The performance benefit for training will be larger: the original backward implementation for `sum|mean` is sequential, and the original backward implementation for `max|min` is not fused.

#### before
```
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
       torch_sparse::spmm_sum        97.09%       56.086s        97.09%       56.088s        6.232s             9
                 aten::linear         0.00%      85.000us         1.38%     795.485ms      88.387ms             9
                 aten::matmul         0.00%      57.000us         1.38%     795.260ms      88.362ms             9
                     aten::mm         1.38%     795.201ms         1.38%     795.203ms      88.356ms             9
                   aten::relu         0.00%      50.000us         0.76%     440.434ms      73.406ms             6
              aten::clamp_min         0.76%     440.384ms         0.76%     440.384ms      73.397ms             6
                   aten::add_         0.57%     327.801ms         0.57%     327.801ms      36.422ms             9
            aten::log_softmax         0.00%      23.000us         0.10%      55.503ms      18.501ms             3
```

#### after
```
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
               aten::spmm_sum        87.35%       11.826s        87.36%       11.827s        1.314s             9
                 aten::linear         0.00%      92.000us         5.87%     794.451ms      88.272ms             9
                 aten::matmul         0.00%      62.000us         5.87%     794.208ms      88.245ms             9
                     aten::mm         5.87%     794.143ms         5.87%     794.146ms      88.238ms             9
                   aten::relu         0.00%      53.000us         3.35%     452.977ms      75.496ms             6
              aten::clamp_min         3.35%     452.924ms         3.35%     452.924ms      75.487ms             6
                   aten::add_         2.58%     348.663ms         2.58%     348.663ms      38.740ms             9
                 aten::argmax         0.42%      57.473ms         0.42%      57.475ms      14.369ms             4
            aten::log_softmax         0.00%      22.000us         0.39%      52.605ms      17.535ms             3
```
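
As a rough cross-check, the `4.3x` figure can be recovered from the two tables. The sketch below assumes that the total profiled time can be estimated by dividing the `spmm_sum` "Self CPU" time by its "Self CPU %" share, since the tables list only the top rows; the variable and function names are illustrative.

```python
# Values copied from the "before" and "after" profiler tables.
before_self_s, before_self_pct = 56.086, 97.09   # torch_sparse::spmm_sum
after_self_s, after_self_pct = 11.826, 87.35     # aten::spmm_sum


def estimated_total(self_seconds: float, self_percent: float) -> float:
    """Estimate total profiled time from one row's self time and its share (assumption)."""
    return self_seconds / (self_percent / 100.0)


before_total = estimated_total(before_self_s, before_self_pct)
after_total = estimated_total(after_self_s, after_self_pct)

kernel_speedup = before_self_s / after_self_s    # spmm_sum alone
workload_speedup = before_total / after_total    # whole profiled workload

print(f"estimated totals: {before_total:.2f}s -> {after_total:.2f}s")  # 57.77s -> 13.54s
print(f"spmm_sum speedup: {kernel_speedup:.2f}x")                       # 4.74x
print(f"workload speedup: {workload_speedup:.2f}x")                     # 4.27x
```

The kernel itself is about 4.7x faster. The end-to-end gain is slightly lower, about 4.3x, because the other operators take roughly the same time in both runs.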

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83727
Approved by: https://github.com/jgong5, https://github.com/cpuhrsch, https://github.com/rusty1s, https://github.com/pearu
2023-02-10 15:56:40 +00:00
__init__.py
HasDecompTest.test_has_decomposition.expect port sparse_mm.reduce to pytorch and optimize it on CPU (#83727) 2023-02-10 15:56:40 +00:00
TestAutograd.test_function-x_grad_desc.expect
TestAutograd.test_function-y_grad_desc.expect
TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect Fix dumb make_fx issue (#84011) 2022-08-25 06:52:01 +00:00
TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect fx.replace_pattern accepts pattern/replacement as GraphModule (#88479) 2022-11-05 03:35:30 +00:00
TestJit.test_cu_escaped_number.expect
TestJit.test_import_method.expect
TestJit.test_non_ascii_string.expect
TestJit.test_pretty_printer-empty_float_list_test.expect
TestJit.test_pretty_printer-empty_int_list_test.expect
TestJit.test_pretty_printer-if_one.expect
TestJit.test_pretty_printer-if_test.expect
TestJit.test_pretty_printer-loop_use_test.expect
TestJit.test_pretty_printer-print_weird_test.expect
TestJit.test_pretty_printer-python_op_name_test.expect
TestJit.test_pretty_printer-while_if_test.expect
TestJit.test_pretty_printer-while_test.expect
TestPytorchExportModes.test_aten_fallback.expect
TestPytorchExportModes.test_onnx_aten.expect
TestScript.test_annot_ast_mypy_fn.expect
TestScript.test_annot_ast_mypy_method.expect
TestScript.test_annot_ast_py3_fn.expect
TestScript.test_annot_ast_py3_method.expect
TestScript.test_annot_string_mypy_fn.expect
TestScript.test_annot_string_mypy_method.expect
TestScript.test_annot_string_py3_fn.expect
TestScript.test_annot_string_py3_method.expect
TestScript.test_annotated_script_fn.expect Make string serialization of C++ FunctionSchema consistent with torchgen.model.FunctionSchema 2022-05-24 19:39:26 +00:00
TestScript.test_annotated_script_method.expect
TestScript.test_format-stdout.expect
TestScript.test_listconstruct_erasure.expect
TestScript.test_parser_type_annotations.expect
TestScript.test_parser_type_annotations_comment.expect
TestScript.test_print-stdout.expect
TestScript.test_python_frontend.expect
TestScript.test_python_frontend_py2.expect
TestScript.test_python_frontend_py3.expect
TestScript.test_string_print-stdout.expect
TestScript.test_torch_dot_tensor_annotation.expect
TestSparseCompressedCPU.test_print_SparseBSC_cpu.expect Generator of tensor inputs with variable layout and structure (batch/non-batch, hybrid/non-hybrid, block/non-block) (#88914) 2022-11-30 02:13:33 +00:00
TestSparseCompressedCPU.test_print_SparseBSR_cpu.expect Generator of tensor inputs with variable layout and structure (batch/non-batch, hybrid/non-hybrid, block/non-block) (#88914) 2022-11-30 02:13:33 +00:00
TestSparseCompressedCPU.test_print_SparseCSC_cpu.expect Generator of tensor inputs with variable layout and structure (batch/non-batch, hybrid/non-hybrid, block/non-block) (#88914) 2022-11-30 02:13:33 +00:00
TestSparseCompressedCPU.test_print_SparseCSR_cpu.expect Generator of tensor inputs with variable layout and structure (batch/non-batch, hybrid/non-hybrid, block/non-block) (#88914) 2022-11-30 02:13:33 +00:00
TestSparseCompressedCUDA.test_print_SparseBSC_cuda.expect Generator of tensor inputs with variable layout and structure (batch/non-batch, hybrid/non-hybrid, block/non-block) (#88914) 2022-11-30 02:13:33 +00:00
TestSparseCompressedCUDA.test_print_SparseBSR_cuda.expect Generator of tensor inputs with variable layout and structure (batch/non-batch, hybrid/non-hybrid, block/non-block) (#88914) 2022-11-30 02:13:33 +00:00
TestSparseCompressedCUDA.test_print_SparseCSC_cuda.expect Generator of tensor inputs with variable layout and structure (batch/non-batch, hybrid/non-hybrid, block/non-block) (#88914) 2022-11-30 02:13:33 +00:00
TestSparseCompressedCUDA.test_print_SparseCSR_cuda.expect Generator of tensor inputs with variable layout and structure (batch/non-batch, hybrid/non-hybrid, block/non-block) (#88914) 2022-11-30 02:13:33 +00:00
TestSparseCPU.test_print_coalesced_cpu_float64.expect
TestSparseCPU.test_print_uncoalesced_cpu_float64.expect
TestSparseCUDA.test_print_coalesced_cuda_float64.expect
TestSparseCUDA.test_print_uncoalesced_cuda_float64.expect
TestTensorBoard.test_audio.expect
TestTensorBoard.test_caffe2_simple_cnnmodel.expect
TestTensorBoard.test_caffe2_simple_model.expect
TestTensorBoard.test_histogram_auto.expect
TestTensorBoard.test_histogram_doane.expect
TestTensorBoard.test_histogram_fd.expect
TestTensorBoard.test_hparams_bool.expect
TestTensorBoard.test_hparams_number.expect
TestTensorBoard.test_hparams_string.expect
TestTensorBoard.test_image_with_3_channel_batched.expect Avoid overflow in tensorboard image summary (#90423) 2022-12-08 08:31:52 +00:00
TestTensorBoard.test_image_with_boxes.expect Avoid overflow in tensorboard image summary (#90423) 2022-12-08 08:31:52 +00:00
TestTensorBoard.test_image_with_one_channel.expect Avoid overflow in tensorboard image summary (#90423) 2022-12-08 08:31:52 +00:00
TestTensorBoard.test_image_with_one_channel_batched.expect Avoid overflow in tensorboard image summary (#90423) 2022-12-08 08:31:52 +00:00
TestTensorBoard.test_image_without_channel.expect Avoid overflow in tensorboard image summary (#90423) 2022-12-08 08:31:52 +00:00
TestTensorBoard.test_mesh.expect
TestTensorBoard.test_nested_nn_squential.expect
TestTensorBoard.test_pr_curve.expect
TestTensorBoard.test_pr_curve_raw.expect
TestTensorBoard.test_pytorch_graph.expect
TestTensorBoard.test_scalar_new_style.expect
TestTensorBoard.test_text.expect
TestTensorBoard.test_video.expect
TestTorch.test_is_nonzero-empty.expect
TestTorch.test_is_nonzero-multiple.expect
TestTorch.test_print-non_contiguous.expect