mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
* Make AT_ASSERT/AT_ERROR non-printf based, other tweaks

  - AT_ASSERT/AT_ERROR don't take printf strings anymore; instead, they take a comma-separated list of things you wanted to print (bringing it inline with Caffe2's conventions).

    Instead of AT_ASSERT(x == 0, "%d is not zero", x) you write AT_ASSERT(x == 0, x, " is not zero")

    This is done by way of a new variadic template at::str(), which takes a list of arguments and cats their string reps (as per operator<<) together.

  - A bunch of the demangling logic that was in Error.h is now moved to Error.cpp (better header hygiene.) Also, demangle has been moved out to its own helper function, and also a new helper demangle_type (from Caffe2) added.

  - A bunch of AT_ASSERT converted into AT_CHECK, to more properly convey which checks can be caused by user error, and which are due to logic error in ATen.

  Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* CR

  Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Fix test failure.

  Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* buildfix

  Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* More fixes.

  Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* One more fix

  Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Try harder

  Signed-off-by: Edward Z. Yang <ezyang@fb.com>
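The commit message describes at::str() as a variadic template that concatenates the operator<< representations of its arguments. The following is a minimal sketch of that idea, not the actual ATen implementation: the namespace `sketch` and the helper `append` are hypothetical names chosen for illustration, and the real at::str() may differ in signature and detail.

```cpp
#include <sstream>
#include <string>

namespace sketch {  // hypothetical namespace, not part of ATen

// Base case of the recursion: no arguments left to append.
inline void append(std::ostringstream&) {}

// Recursive case: stream the first argument with operator<<,
// then handle the remaining ones.
template <typename T, typename... Rest>
void append(std::ostringstream& out, const T& first, const Rest&... rest) {
  out << first;
  append(out, rest...);
}

// Concatenate the string representations of all arguments, in order.
template <typename... Args>
std::string str(const Args&... args) {
  std::ostringstream out;
  append(out, args...);
  return out.str();
}

}  // namespace sketch
```

Under this sketch, a call such as sketch::str(x, " is not zero") with x == 3 would build the message "3 is not zero", which mirrors how the new AT_ASSERT(x == 0, x, " is not zero") form assembles its text from a comma-separated argument list rather than a printf format string.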
19 lines
716 B
C++
#include <torch/torch.h>

// Declare the function from cuda_extension.cu. It will be compiled
// separately with nvcc and linked with the object file of cuda_extension.cpp
// into one shared library.
void sigmoid_add_cuda(const float* x, const float* y, float* output, int size);

at::Tensor sigmoid_add(at::Tensor x, at::Tensor y) {
  AT_CHECK(x.type().is_cuda(), "x must be a CUDA tensor");
  AT_CHECK(y.type().is_cuda(), "y must be a CUDA tensor");
  auto output = at::zeros_like(x);
  sigmoid_add_cuda(
      x.data<float>(), y.data<float>(), output.data<float>(), output.numel());
  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
}