pytorch/test/cpp_extensions/cuda_extension.cpp
Edward Z. Yang 0427afadd1
Make AT_ASSERT/AT_ERROR non-printf based, other tweaks (#7104)
* Make AT_ASSERT/AT_ERROR non-printf based, other tweaks

- AT_ASSERT/AT_ERROR don't take printf strings anymore; instead,
  they take a comma-separated list of things you wanted to print
  (bringing it inline with Caffe2's conventions).

  Instead of AT_ASSERT(x == 0, "%d is not zero", x)
  you write AT_ASSERT(x == 0, x, " is not zero")

  This is done by way of a new variadic template at::str(), which
  takes a list of arguments and cats their string reps (as per
  operator<<) together.

- A bunch of the demangling logic that was in Error.h is now
  moved to Error.cpp (better header hygiene.)  Also, demangle
  has been moved out to its own helper function, and also
  a new helper demangle_type (from Caffe2) added.

- A bunch of AT_ASSERT converted into AT_CHECK, to more properly
  convey which checks can be caused by user error, and which are
  due to logic error in ATen.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* CR

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Fix test failure.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* buildfix

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* More fixes.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* One more fix

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Try harder

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-05-01 10:28:31 -04:00

19 lines
716 B
C++

#include <torch/torch.h>
// Declare the function from cuda_extension.cu. It will be compiled
// separately with nvcc and linked with the object file of cuda_extension.cpp
// into one shared library.
void sigmoid_add_cuda(const float* x, const float* y, float* output, int size);
at::Tensor sigmoid_add(at::Tensor x, at::Tensor y) {
AT_CHECK(x.type().is_cuda(), "x must be a CUDA tensor");
AT_CHECK(y.type().is_cuda(), "y must be a CUDA tensor");
auto output = at::zeros_like(x);
sigmoid_add_cuda(
x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
}