mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12180

I had to fix a lot of call sites, because many places assume that you can actually get a const vector&, and if the internal representation of sizes in a tensor is NOT a vector, it's not possible to fulfill this API contract.

Framework changes:
- I deleted TensorImpl::dims(); caffe2::Tensor::dims() now just forwards to sizes().
- De-templatized SetDims; it is now an explicit list of ArrayRef and variadic overloads. This makes implicit conversions work again, so I don't need to explicitly list the std::vector cases too.
  - As a knock-on effect, Reset() now accepts at::IntList as well as const std::vector<int64_t>&.
- Edited the variadic overloads of SetDims to all forward to the underlying arbitrary-dim implementation, reducing code duplication. (It's probably marginally less efficient in the new world.)
- Replaced the Tensor constructor accepting const std::vector<int64_t>& with one accepting at::IntList.
- Made MKLTensor accept ArrayRef along with vector in its constructor and Reset (unfortunately, no implicit conversions here, since it's templated on index type).
- In a few other places, like cudnn, I changed functions that previously took const std::vector<int64_t>& to take at::IntList instead.

Classification of call site changes:
- 'const std::vector<int64_t>& x_dims = x.dims()' ==> 'at::IntList x_dims = x.dims()'
- 'std::vector<int64_t> x_dims = x.dims()' ==> 'std::vector<int64_t> x_dims = x.dims().vec()' (we need a copy!). Usually this is because we're about to mutate the vector to compute some new dimension. However, it also very commonly occurs in the form 'x_dims_ = x.dims()', because we frequently cache sizes in operators.
- Instead of constructing std::vector<int64_t>{blah, blah}, construct an at::IntList directly.

ArrayRef changes:
- Added cbegin()/cend() iterators; they behave the same as begin()/end() because everything on ArrayRef is const.
- Moved operator<< into ArrayRef.h, so that it's always available when working with ArrayRef. I also templated it, so it now works on an ArrayRef of any type.
- Added an operator== overload for ArrayRef, plus variants that permit comparing an ArrayRef with a std::vector, a very common operation. (A non-templated operator== would get these automatically via implicit conversion, but with templates C++ refuses to do any implicit conversions.)

I'm planning to audit all dims() call sites to make sure they don't expect 'auto x = t.dims()' to give you an x whose lifetime can validly outlive the tensor.

I opted not to do a dims() to sizes() rename, because dims() also matches the protobuf accessor. Bad news!

Reviewed By: jerryzh168

Differential Revision: D10111759

fbshipit-source-id: a2a81dc4b92c22ad4b3b8ef4077a7e97b6479452

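The ArrayRef-related changes in the commit message can be illustrated with a small, self-contained C++ sketch. The `ArrayRef` and `Tensor` classes below are hypothetical, simplified stand-ins written only for illustration. They are not the real `c10::ArrayRef`/`at::IntList` or `caffe2::Tensor`, and details such as the `[2, 3, 4]` print format are choices of this sketch. It shows three ideas the summary relies on:

- `dims()` returns a non-owning view.
- `.vec()` makes an owning copy when a call site needs to mutate or cache the sizes.
- The mixed `operator==` overloads exist because template argument deduction does not consider implicit conversions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <ostream>
#include <sstream>
#include <vector>

// Hypothetical, simplified stand-in for c10::ArrayRef / at::IntList.
// It is a non-owning (pointer, length) view: it neither copies nor extends
// the lifetime of the storage it points at, so it must not outlive its source.
template <typename T>
class ArrayRef {
 public:
  using const_iterator = const T*;

  ArrayRef() = default;
  ArrayRef(const T* data, std::size_t size) : data_(data), size_(size) {}
  // Non-explicit on purpose: a function taking ArrayRef<T> also accepts a
  // const std::vector<T>& argument through this implicit conversion.
  ArrayRef(const std::vector<T>& v) : data_(v.data()), size_(v.size()) {}

  // Everything is const, so cbegin()/cend() behave exactly like begin()/end().
  const_iterator begin() const { return data_; }
  const_iterator end() const { return data_ + size_; }
  const_iterator cbegin() const { return begin(); }
  const_iterator cend() const { return end(); }

  std::size_t size() const { return size_; }
  const T& operator[](std::size_t i) const { return data_[i]; }

  // Explicit owning copy for call sites that mutate or cache the sizes.
  std::vector<T> vec() const { return std::vector<T>(begin(), end()); }

 private:
  const T* data_ = nullptr;
  std::size_t size_ = 0;
};

template <typename T>
bool operator==(ArrayRef<T> a, ArrayRef<T> b) {
  return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin());
}

// Template argument deduction ignores implicit conversions, so comparing an
// ArrayRef<T> with a std::vector<T> needs these mixed overloads.
template <typename T>
bool operator==(ArrayRef<T> a, const std::vector<T>& b) {
  return a == ArrayRef<T>(b);
}
template <typename T>
bool operator==(const std::vector<T>& a, ArrayRef<T> b) {
  return ArrayRef<T>(a) == b;
}

// Templated stream operator: works for an ArrayRef of any printable type.
template <typename T>
std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
  out << "[";
  for (std::size_t i = 0; i < list.size(); ++i) {
    if (i > 0) out << ", ";
    out << list[i];
  }
  return out << "]";
}

// Hypothetical tensor that stores its sizes in a std::vector but only
// exposes them as a view, so callers cannot rely on a const vector&.
class Tensor {
 public:
  explicit Tensor(ArrayRef<std::int64_t> sizes) : sizes_(sizes.vec()) {}
  ArrayRef<std::int64_t> dims() const { return sizes_; }

 private:
  std::vector<std::int64_t> sizes_;
};

// The call-site rewrites from the summary, applied to the sketch above.
void call_site_examples() {
  Tensor x(std::vector<std::int64_t>{2, 3, 4});

  // Read-only use: bind a view instead of a const std::vector<int64_t>&.
  ArrayRef<std::int64_t> x_dims = x.dims();
  assert(x_dims.size() == 3);

  // Mutable use: take an owning copy with .vec().
  std::vector<std::int64_t> new_dims = x.dims().vec();
  new_dims.back() = 8;  // changes the copy only
  assert(x.dims()[2] == 4);

  // Mixed ArrayRef/std::vector comparisons resolve to the overloads above.
  const std::vector<std::int64_t> expected{2, 3, 4};
  assert(x.dims() == expected);
  assert(!(x.dims() == new_dims));

  std::ostringstream os;
  os << x.dims();
  assert(os.str() == "[2, 3, 4]");
}
```

Because the view holds only a pointer and a length, a line such as `auto d = Tensor(std::vector<std::int64_t>{1}).dims();` leaves `d` dangling once the temporary tensor is destroyed. This is the kind of lifetime hazard the planned audit of `auto x = t.dims()` call sites targets.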
| Name | | |
|---|---|---|
| docs | ||
| examples | ||
| helpers | ||
| ideep | ||
| layers | ||
| mint | ||
| mkl | ||
| modeling | ||
| models | ||
| onnx | ||
| operator_test | ||
| predictor | ||
| rnn | ||
| serialized_test | ||
| test | ||
| trt | ||
| __init__.py | ||
| _import_c_extension.py | ||
| allcompare_test.py | ||
| attention.py | ||
| benchmark_generator.py | ||
| binarysize.py | ||
| brew.py | ||
| brew_test.py | ||
| build.py | ||
| cached_reader.py | ||
| caffe_translator.py | ||
| caffe_translator_test.py | ||
| checkpoint.py | ||
| checkpoint_test.py | ||
| CMakeLists.txt | ||
| cnn.py | ||
| compatibility.py | ||
| context.py | ||
| context_test.py | ||
| control.py | ||
| control_ops_grad.py | ||
| control_ops_util.py | ||
| control_test.py | ||
| convert.py | ||
| convert_test.py | ||
| convnet_benchmarks.py | ||
| convnet_benchmarks_test.py | ||
| core.py | ||
| core_gradients_test.py | ||
| core_test.py | ||
| crf.py | ||
| crf_predict.py | ||
| crf_viterbi_test.py | ||
| data_parallel_model.py | ||
| data_parallel_model_test.py | ||
| data_workers.py | ||
| data_workers_test.py | ||
| dataio.py | ||
| dataio_test.py | ||
| dataset.py | ||
| db_file_reader.py | ||
| db_test.py | ||
| device_checker.py | ||
| dlpack.h | ||
| dyndep.py | ||
| embedding_generation_benchmark.py | ||
| experiment_util.py | ||
| extension_loader.py | ||
| functional.py | ||
| functional_test.py | ||
| fused_8bit_rowwise_conversion_ops_test.py | ||
| gradient_check_test.py | ||
| gradient_checker.py | ||
| gru_cell.py | ||
| hsm_util.py | ||
| hypothesis_test.py | ||
| hypothesis_test_util.py | ||
| ideep_test_util.py | ||
| layer_model_helper.py | ||
| layer_model_instantiator.py | ||
| layer_parameter_sharing_test.py | ||
| layer_test_util.py | ||
| layers_test.py | ||
| lengths_reducer_fused_8bit_rowwise_ops_test.py | ||
| lengths_reducer_rowwise_8bit_ops_test.py | ||
| lstm_benchmark.py | ||
| memonger.py | ||
| memonger_test.py | ||
| mkl_test_util.py | ||
| model_device_test.py | ||
| model_helper.py | ||
| model_helper_test.py | ||
| modifier_context.py | ||
| mpi_python.cc | ||
| muji.py | ||
| muji_test.py | ||
| net_builder.py | ||
| net_builder_test.py | ||
| net_drawer.py | ||
| net_printer.py | ||
| net_printer_test.py | ||
| nomnigraph.py | ||
| nomnigraph_test.py | ||
| normalizer.py | ||
| normalizer_context.py | ||
| normalizer_test.py | ||
| numa_benchmark.py | ||
| numa_test.py | ||
| observer_test.py | ||
| optimizer.py | ||
| optimizer_context.py | ||
| optimizer_test.py | ||
| optimizer_test_util.py | ||
| parallel_workers.py | ||
| parallel_workers_test.py | ||
| parallelize_bmuf_distributed_test.py | ||
| pipeline.py | ||
| pipeline_test.py | ||
| predictor_constants.py | ||
| pybind_state.cc | ||
| pybind_state.h | ||
| pybind_state_dlpack.cc | ||
| pybind_state_dlpack.h | ||
| pybind_state_gpu.cc | ||
| pybind_state_hip.cc | ||
| pybind_state_ideep.cc | ||
| pybind_state_int8.cc | ||
| pybind_state_mkl.cc | ||
| pybind_state_nomni.cc | ||
| pybind_state_registry.cc | ||
| pybind_state_registry.h | ||
| python_op_test.py | ||
| queue_util.py | ||
| record_queue.py | ||
| recurrent.py | ||
| regularizer.py | ||
| regularizer_context.py | ||
| regularizer_test.py | ||
| rnn_cell.py | ||
| schema.py | ||
| schema_test.py | ||
| scope.py | ||
| scope_test.py | ||
| session.py | ||
| session_test.py | ||
| sparse_to_dense_mask_test.py | ||
| sparse_to_dense_test.py | ||
| task.py | ||
| test_util.py | ||
| text_file_reader.py | ||
| timeout_guard.py | ||
| toy_regression_test.py | ||
| transformations.py | ||
| transformations_test.py | ||
| tt_core.py | ||
| tt_core_test.py | ||
| utils.py | ||
| visualize.py | ||
| workspace.py | ||
| workspace_test.py | ||