[Vulkan] Added Hardshrink op (#62870)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62870

Added Hardshrink operator for Vulkan
Added tests for Hardshrink op

Reference: [Hardshrink](https://pytorch.org/docs/stable/generated/torch.nn.Hardshrink.html#torch.nn.Hardshrink)

Test Plan: Imported from OSS

Reviewed By: SS-JIA

Differential Revision: D30174950

Pulled By: beback4u

fbshipit-source-id: 3e192390eb9f92abecae966e84bbfae356bfd7c8
This commit is contained in:
Sangbaek Park 2021-08-09 10:48:39 -07:00 committed by Facebook GitHub Bot
parent 922710f9b9
commit 710c419f11
4 changed files with 212 additions and 1 deletion

View file

@ -0,0 +1,26 @@
#version 450 core
#define PRECISION $precision

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */
layout(set = 0, binding = 0) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  ivec4 size;
  float lambd;
} uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Hardshrink: components with |x| <= lambd (boundary included) are shrunk
// to zero; all other components pass through unchanged.
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Skip invocations that fall outside the tensor's texel extents.
  if (all(lessThan(pos, uBlock.size.xyz))) {
    const vec4 x = texelFetch(uInput, pos, 0);
    // Per-component 1.0 where the value lies in [-lambd, lambd].
    const vec4 shrink = vec4(lessThanEqual(abs(x), vec4(uBlock.lambd)));
    imageStore(uOutput, pos, x * (vec4(1.0) - shrink));
  }
}

View file

@ -0,0 +1,25 @@
#version 450 core
#define PRECISION $precision
layout(std430) buffer;
/* Qualifiers: layout - storage - precision - memory */
// In-place variant: uOutput is both read (imageLoad) and written (imageStore).
layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput;
// Uniform block: tensor texel extents in size.xyz, shrink threshold in lambd.
layout(set = 0, binding = 1) uniform PRECISION restrict Block {
ivec4 size;
float lambd;
} uBlock;
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
// Hardshrink applied in place: components with -lambd <= x <= lambd
// (boundary included) become 0; all other components are unchanged.
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
// Skip invocations that fall outside the tensor's texel extents.
if (all(lessThan(pos, uBlock.size.xyz))) {
const vec4 inval = imageLoad(uOutput, pos);
// mask is 1.0 per component inside [-lambd, lambd] — the region to zero.
const vec4 mask = vec4(greaterThanEqual(inval, vec4(-uBlock.lambd)))*vec4(lessThanEqual(inval, vec4(uBlock.lambd)));
// Zero the masked components; pass the rest through.
const vec4 outval = (vec4(1.0) - mask)*inval;
imageStore(uOutput, pos, outval);
}
}

View file

@ -289,6 +289,121 @@ Tensor& hardsigmoid_(Tensor& self) {
return ops::activation_(self, VK_KERNEL(hardsigmoid_));
}
// Out-of-place Hardshrink: dispatches the hardshrink compute shader, reading
// from the (possibly CPU-sourced) input image and writing a fresh output
// tensor of the same sizes/options.
Tensor hardshrink(
    const Tensor& self_arg,
    const Scalar& lambd) {
  api::Context* const context = api::context();

  // Move the argument onto Vulkan if it is not there already.
  const Tensor input = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();
  const vTensor& v_input = convert(input);

  vTensor v_output{
    context,
    v_input.sizes(),
    v_input.options(),
  };

  api::Command::Pool& command_pool = context->command().pool;
  api::Command::Buffer& command_buffer = command_pool.stream();
  {
    if C10_LIKELY(v_output.has_image() && v_input.has_image()) {
      // Mirrors the shader's uniform Block: extents, padding, threshold.
      const struct Block final {
        uvec3 extents;
        uint32_t _;
        float lambd;
      } block {
        v_output.extents(),
        0u,
        lambd.to<float>(),
      };

      context->dispatch(
          command_buffer,
          {
            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
          },
          VK_KERNEL(hardshrink),
          v_output.extents(),
          context->gpu().adapter->local_work_group_size(),
          // Write-only access bypasses synchronization but inserts appropriate
          // barriers if necessary.
          v_output.image(
              command_buffer,
              vTensor::Stage::Compute,
              vTensor::Access::Write),
          // Read-only access is implied on const tensors and triggers an async
          // synchronization if necessary.
          v_input.image(
              command_buffer,
              vTensor::Stage::Compute),
          // Object lifetime is managed by the resource pool.
          // It is OK not to keep track of the handle.
          context->resource().pool.uniform(block).object);
    }
    else {
      TORCH_CHECK(false, "Not implemented!");
    }
  }
  command_pool.submit(context->gpu().queue, command_buffer);

  return convert(v_output);
}
// In-place Hardshrink: dispatches the hardshrink_ compute shader, which reads
// and writes self's image directly. Only valid for Vulkan tensors.
//
// self  - Vulkan tensor to be modified in place (checked below).
// lambd - shrink threshold, forwarded to the shader as a float.
// Returns self to support the usual in-place chaining convention.
Tensor& hardshrink_(
Tensor& self,
const Scalar& lambd) {
api::Context* const context = api::context();
TORCH_CHECK(
self.is_vulkan(),
"Vulkan: In-place hardshrink is only supported on Vulkan tensors.");
vTensor& v_self = convert(self);
api::Command::Pool& command_pool = context->command().pool;
api::Command::Buffer& command_buffer = command_pool.stream();
{
if C10_LIKELY(v_self.has_image()) {
// Mirrors the shader's uniform Block: extents, padding, threshold.
const struct Block final {
uvec3 extents;
uint32_t _;
float lambd;
} block {
v_self.extents(),
0u,
lambd.to<float>(),
};
context->dispatch(
command_buffer,
{
VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
},
VK_KERNEL(hardshrink_),
v_self.extents(),
context->gpu().adapter->local_work_group_size(),
// Read-Write access triggers an async synchronization if necessary
// and inserts appropriate barriers if hazards are detected.
v_self.image(
command_buffer,
vTensor::Stage::Compute,
vTensor::Access::Read | vTensor::Access::Write),
// Object lifetime is managed by the resource pool.
// It is OK not to keep track of the handle.
context->resource().pool.uniform(block).object);
}
else {
// No image-backed storage for this tensor; no buffer fallback exists.
TORCH_CHECK(false, "Not implemented!");
}
}
command_pool.submit(context->gpu().queue, command_buffer);
return self;
}
// Out-of-place sigmoid: delegates to the shared ops::activation dispatch
// with the sigmoid compute shader.
Tensor sigmoid(const Tensor& self) {
return ops::activation(self, VK_KERNEL(sigmoid));
}
@ -312,6 +427,8 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
m.impl(TORCH_SELECTIVE_NAME("aten::clamp_"), TORCH_FN(clamp_));
m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid"), hardsigmoid);
m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid_"), hardsigmoid_);
m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink"), TORCH_FN(hardshrink));
m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink_"), TORCH_FN(hardshrink_));
m.impl(TORCH_SELECTIVE_NAME("aten::hardswish"), hardswish);
m.impl(TORCH_SELECTIVE_NAME("aten::hardswish_"), hardswish_);
m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh"), hardtanh);

View file

@ -20,7 +20,7 @@ bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor>& inputs) {
constexpr float tolerance = 1e-5;
#endif
return diff.abs().max().item<float>() < (tolerance * maxValue);
return diff.abs().max().item<float>() <= (tolerance * maxValue);
}
bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
@ -936,6 +936,49 @@ TEST(VulkanAPITest, hardsigmoid_) {
ASSERT_TRUE(check);
}
// Compares Vulkan hardshrink against the CPU reference across a range of
// negative, zero, and positive thresholds.
TEST(VulkanAPITest, hardshrink) {
  if (!at::is_vulkan_available()) {
    return;
  }

  for (const auto lambd_value : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
    const auto input_cpu =
        at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat));
    const auto input_vulkan = input_cpu.vulkan();

    const auto output_cpu = at::hardshrink(input_cpu, lambd_value);
    const auto output_vulkan = at::hardshrink(input_vulkan, lambd_value).cpu();

    const auto check = almostEqual(output_cpu, output_vulkan);
    if (!check) {
      showRtol(output_cpu, output_vulkan);
    }
    ASSERT_TRUE(check);
  }
}
// Compares in-place Vulkan hardshrink_ against the CPU reference across a
// range of negative, zero, and positive thresholds.
TEST(VulkanAPITest, hardshrink_) {
  if (!at::is_vulkan_available()) {
    return;
  }

  for (const auto lambd_value : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
    const auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat));
    const auto vulkan = cpu.vulkan();

    // Fix: call the in-place variants. The previous out-of-place
    // `cpu.hardshrink(...)` / `vulkan.hardshrink(...)` calls discarded their
    // results, so both tensors stayed untouched, the comparison below was
    // vacuously true, and the hardshrink_ kernel was never exercised.
    cpu.hardshrink_(lambd_value);
    vulkan.hardshrink_(lambd_value);

    const auto check = almostEqual(cpu, vulkan.cpu());
    if (!check) {
      showRtol(cpu, vulkan.cpu());
    }
    ASSERT_TRUE(check);
  }
}
TEST(VulkanAPITest, hardswish) {
if (!at::is_vulkan_available()) {
return;