mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Expose PiecewiseLinearTransform to PyTorch
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26903 Test Plan: Unit Test Reviewed By: bddppq Differential Revision: D17585637 fbshipit-source-id: fe669aaf3301d7efb5c28ec0097945d55a71773d
This commit is contained in:
parent
71011211c1
commit
d63d7ab997
4 changed files with 61 additions and 0 deletions
|
|
@ -82,3 +82,19 @@ bound.
|
|||
|
||||
SHOULD_NOT_DO_GRADIENT(PiecewiseLinearTransform);
|
||||
} // namespace caffe2
|
||||
|
||||
using PiecewiseLinearTransformOpFloatCPU =
|
||||
caffe2::PiecewiseLinearTransformOp<float, caffe2::CPUContext>;
|
||||
|
||||
// clang-format off
|
||||
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
|
||||
PiecewiseLinearTransform,
|
||||
"_caffe2::PiecewiseLinearTransform("
|
||||
"Tensor predictions, "
|
||||
"float[] bounds, "
|
||||
"float[] slopes, "
|
||||
"float[] intercepts, "
|
||||
"bool binary"
|
||||
") -> (Tensor output_0)",
|
||||
PiecewiseLinearTransformOpFloatCPU);
|
||||
// clang-format on
|
||||
|
|
|
|||
|
|
@ -281,3 +281,10 @@ REGISTER_CUDA_OPERATOR(
|
|||
PiecewiseLinearTransformOp<float, CUDAContext>);
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
using PiecewiseLinearTransformOpFloatCUDA =
|
||||
caffe2::PiecewiseLinearTransformOp<float, caffe2::CUDAContext>;
|
||||
|
||||
C10_EXPORT_CAFFE2_OP_TO_C10_CUDA(
|
||||
PiecewiseLinearTransform,
|
||||
PiecewiseLinearTransformOpFloatCUDA);
|
||||
|
|
|
|||
|
|
@ -2,8 +2,11 @@
|
|||
#define CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/export_caffe2_op_to_c10.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
|
||||
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(PiecewiseLinearTransform);
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <typename T, class Context>
|
||||
|
|
|
|||
|
|
@ -676,5 +676,40 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
reference = fused_rowwise_8bit_quantize_dequantize_reference(input_data)
|
||||
np.testing.assert_array_almost_equal(dequantized_data.numpy(), reference)
|
||||
|
||||
@given(binary_input=st.booleans())
|
||||
def test_piecewise_linear_op(self, binary_input):
|
||||
if binary_input:
|
||||
num_dims = 1
|
||||
else:
|
||||
num_dims = 3
|
||||
data = np.random.rand(1024, num_dims).astype(np.float32)
|
||||
slopes = np.zeros(4 * num_dims).astype(np.float32)
|
||||
bounds = np.sort(np.random.rand(5, num_dims).astype(np.float32), axis=0).flatten('F')
|
||||
intercepts = np.random.rand(4 * num_dims).astype(np.float32)
|
||||
|
||||
def _piecewise_linear_ref(X):
|
||||
ref_op = core.CreateOperator(
|
||||
"PiecewiseLinearTransform",
|
||||
["data",
|
||||
"bounds",
|
||||
"slopes",
|
||||
"intercepts"],
|
||||
["calibrated"],
|
||||
binary=binary_input,
|
||||
)
|
||||
workspace.FeedBlob("data", X)
|
||||
workspace.FeedBlob("bounds", bounds)
|
||||
workspace.FeedBlob("slopes", slopes)
|
||||
workspace.FeedBlob("intercepts", intercepts)
|
||||
workspace.RunOperatorOnce(ref_op)
|
||||
return workspace.FetchBlob("calibrated")
|
||||
|
||||
expected_output = _piecewise_linear_ref(data)
|
||||
actual_output = torch.ops._caffe2.PiecewiseLinearTransform(
|
||||
torch.tensor(data), bounds.tolist(), slopes.tolist(), intercepts.tolist(), binary_input)
|
||||
|
||||
torch.testing.assert_allclose(torch.tensor(expected_output), actual_output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue