pytorch/test/cpp_extensions/extension.cpp
gchanan a3442f62bc
Support native namespace functions with type dispatch. (#5576)
* Support native namespace functions with type dispatch.

Use 'ones' as an example.  Note this is a "halfway" solution; i.e. the call chain is:
at::ones(shape, dtype) -> dtype.ones(shape, dtype) -> CPUFloatType.ones(shape, dtype) -> at::native::ones(shape, dtype)

The "nicer" solution would probably be something like:
at::ones(shape, dtype) -> dtype.ones(shape) -> CPUFloatType.ones(shape) -> at::native::ones(shape, this)
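
The "halfway" call chain above can be illustrated with a small, self-contained sketch. This is not the ATen code: the `sketch` namespace, the simplified `Tensor` struct, and the use of an element count in place of a shape are assumptions made for illustration. Only the layering (free function -> virtual method on the dtype -> concrete type -> native implementation) follows the commit message, with the dtype argument placed first as in a later bullet.

```cpp
#include <cstddef>
#include <string>
#include <vector>

namespace sketch {  // hypothetical namespace, not part of ATen

struct Type;  // forward declaration

// Hypothetical stand-in for a tensor: records which type created it.
struct Tensor {
  std::string type_name;
  std::vector<double> data;
};

namespace native {
// Last step of the chain: the shared implementation, still taking the dtype.
Tensor ones(const Type& dtype, std::size_t n);
}  // namespace native

// Base type: the dtype argument selects the override at run time.
struct Type {
  virtual ~Type() = default;
  virtual std::string name() const = 0;
  // "Halfway" signature: the dtype is passed along even though the
  // receiver already is the dtype.
  virtual Tensor ones(const Type& dtype, std::size_t n) const = 0;
};

// Concrete type: forwards to the native function.
struct CPUFloatType : Type {
  std::string name() const override { return "CPUFloatType"; }
  Tensor ones(const Type& dtype, std::size_t n) const override {
    return native::ones(dtype, n);
  }
};

namespace native {
Tensor ones(const Type& dtype, std::size_t n) {
  return Tensor{dtype.name(), std::vector<double>(n, 1.0)};
}
}  // namespace native

// Entry point: the free function dispatches on the dtype argument.
Tensor ones(const Type& dtype, std::size_t n) { return dtype.ones(dtype, n); }

}  // namespace sketch
```

In the "nicer" variant, the virtual method would drop its redundant `dtype` parameter and the concrete type would pass `*this` to the native function instead.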

* Fix type inference.

* Fix test install.

* Fix extensions.

* Put dtype argument at the beginning.

* Fix extension.cpp.

* Fix rnn.

* Move zeros in the same manner.

* Fix cuda.

* Change randn.

* Change rand.

* Change randperm.

* Fix aten contrib.

* Resize in randperm_out.

* Implement eye.

* Fix sparse zeros.

* linspace, logspace.

* arange.

* range.

* Remove type dispatch from gen_python_functions.

* Properly generate maybe_init_cuda for type dispatch functions not named type.

* Don't duplicate the dtype and this parameters for native type-dispatched functions.

* Call VariableType factory methods from the base type so it gets version number 0.

* Address review comments.
2018-03-09 10:52:53 -05:00

30 lines
675 B
C++

#include <torch/torch.h>

using namespace at;

// Element-wise sigmoid(x) + sigmoid(y).
Tensor sigmoid_add(Tensor x, Tensor y) {
  return x.sigmoid() + y.sigmoid();
}

struct MatrixMultiplier {
  // Stores an A x B matrix of ones as a CPU double tensor.
  MatrixMultiplier(int A, int B) {
    tensor_ = ones(CPU(kDouble), {A, B});
  }
  // Matrix-multiplies the stored tensor by `weights`.
  Tensor forward(Tensor weights) {
    return tensor_.mm(weights);
  }
  Tensor get() const {
    return tensor_;
  }

 private:
  Tensor tensor_;
};

// Expose the function and the class to Python.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
  py::class_<MatrixMultiplier>(m, "MatrixMultiplier")
      .def(py::init<int, int>())
      .def("forward", &MatrixMultiplier::forward)
      .def("get", &MatrixMultiplier::get);
}