pytorch/test/cpp_extensions/cpp_api_extension.cpp
Peter Goldsborough e05d689c49 Unify C++ API with C++ extensions (#11510)
Summary:
Currently the C++ API and C++ extensions are two separate code paths. This PR unifies the C++ API with the C++ extension API by adding Python binding support to the C++ API. This means the `torch/torch.h` included by C++ extensions, which currently routes to `torch/csrc/torch.h`, can now be rerouted to `torch/csrc/api/include/torch/torch.h` -- i.e. the main C++ API header. This header then includes Python binding support conditioned on a define (`TORCH_WITH_PYTHON_BINDINGS`), *which is only passed when building a C++ extension*.
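
The define-gated include described above can be sketched roughly as follows. This is a minimal illustration, not the actual contents of `torch/torch.h`: only the macro name `TORCH_WITH_PYTHON_BINDINGS` comes from this PR, while the `sketch` namespace and its functions are hypothetical stand-ins, and no real PyTorch header is included.

```cpp
// Sketch of a header whose Python-binding section is gated on a define.
// Everything except the macro name TORCH_WITH_PYTHON_BINDINGS is a
// hypothetical stand-in for illustration.
#include <string>

namespace sketch {

// Always available: stands in for the pure C++ API part of the header.
inline std::string core_api() {
  return "core";
}

#ifdef TORCH_WITH_PYTHON_BINDINGS
// Only compiled when the build passes -DTORCH_WITH_PYTHON_BINDINGS,
// as a C++ extension build would.
inline std::string python_bindings() {
  return "python";
}
#endif

// Reports which parts of the header this translation unit can see.
inline std::string visible_features() {
  std::string features = core_api();
#ifdef TORCH_WITH_PYTHON_BINDINGS
  features += "+" + python_bindings();
#endif
  return features;
}

} // namespace sketch
```

Compiled without the define, `sketch::visible_features()` returns `"core"`; compiled with `-DTORCH_WITH_PYTHON_BINDINGS`, it returns `"core+python"`. A plain C++ API build therefore never sees the binding code or its Python dependency.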

Currently stacked on top of https://github.com/pytorch/pytorch/pull/11498

Why is this useful?

1. One fewer code path. In particular, the two `torch/torch.h` header files have repeatedly caused trouble through ambiguity when both ended up in the include path. This is now fixed.
2. I have found that it is quite common to want to bind a C++ API module back into Python. This could be for simple experimentation, or to have your training loop in Python but your models in C++. This PR makes this easier by adding pybind11 support to the C++ API.
3. The C++ extension API simply becomes richer by gaining access to the C++ API headers.

soumith ezyang apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11510

Reviewed By: ezyang

Differential Revision: D9998835

Pulled By: goldsborough

fbshipit-source-id: 7a94b44a9d7e0377b7f1cfc99ba2060874d51535
2018-09-24 14:44:21 -07:00

38 lines
933 B
C++

#include <torch/extension.h>
#include <torch/python.h>
#include <torch/torch.h>

// A small C++ API module that is bound back into Python below.
struct Net : torch::nn::Module {
  Net(int64_t in, int64_t out)
      : fc(in, out),
        bn(torch::nn::BatchNormOptions(out).stateful(true)),
        dropout(0.5) {
    // Register the submodules so that Net tracks their parameters and buffers.
    register_module("fc", fc);
    register_module("bn", bn);
    register_module("dropout", dropout);
  }

  torch::Tensor forward(torch::Tensor x) {
    return dropout->forward(bn->forward(torch::relu(fc->forward(x))));
  }

  void set_bias(torch::Tensor bias) {
    fc->bias = bias;
  }

  torch::Tensor get_bias() const {
    return fc->bias;
  }

  torch::nn::Linear fc;
  torch::nn::BatchNorm bn;
  torch::nn::Dropout dropout;
};

// Expose Net to Python as a class of the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  torch::python::bind_module<Net>(m, "Net")
      .def(py::init<int64_t, int64_t>())
      .def("forward", &Net::forward)
      .def("set_bias", &Net::set_bias)
      .def("get_bias", &Net::get_bias);
}