mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Vulkan] Added Hardshrink op (#62870)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62870 Added Hardshrink operator for Vulkan Added tests for Hardshrink op Reference: [Hardshrink](https://pytorch.org/docs/stable/generated/torch.nn.Hardshrink.html#torch.nn.Hardshrink) Test Plan: Imported from OSS Reviewed By: SS-JIA Differential Revision: D30174950 Pulled By: beback4u fbshipit-source-id: 3e192390eb9f92abecae966e84bbfae356bfd7c8
This commit is contained in:
parent
922710f9b9
commit
710c419f11
4 changed files with 212 additions and 1 deletions
26
aten/src/ATen/native/vulkan/glsl/hardshrink.glsl
Normal file
26
aten/src/ATen/native/vulkan/glsl/hardshrink.glsl
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#version 450 core
#define PRECISION $precision

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0) uniform PRECISION restrict writeonly image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  ivec4 size;
  float lambd;
} uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Out-of-place Hardshrink: components whose magnitude does not exceed
 * uBlock.lambd are zeroed; all other components pass through unchanged.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Guard against the dispatch over-covering the texture extents.
  if (all(lessThan(pos, uBlock.size.xyz))) {
    const vec4 inval = texelFetch(uInput, pos, 0);
    // keep[i] == 1.0 exactly when |inval[i]| > lambd (the pass-through
    // region); multiplying zeroes the shrink band [-lambd, lambd].
    const vec4 keep = vec4(greaterThan(abs(inval), vec4(uBlock.lambd)));
    imageStore(uOutput, pos, inval * keep);
  }
}
|
||||
25
aten/src/ATen/native/vulkan/glsl/hardshrink_.glsl
Normal file
25
aten/src/ATen/native/vulkan/glsl/hardshrink_.glsl
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#version 450 core
#define PRECISION $precision

layout(std430) buffer;

/* Qualifiers: layout - storage - precision - memory */

layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput;
layout(set = 0, binding = 1) uniform PRECISION restrict Block {
  ivec4 size;
  float lambd;
} uBlock;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * In-place Hardshrink: reads and writes the same image. Components whose
 * magnitude does not exceed uBlock.lambd are zeroed; the rest are unchanged.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Guard against the dispatch over-covering the image extents.
  if (all(lessThan(pos, uBlock.size.xyz))) {
    const vec4 inval = imageLoad(uOutput, pos);
    // keep[i] == 1.0 exactly when |inval[i]| > lambd (the pass-through
    // region); multiplying zeroes the shrink band [-lambd, lambd].
    const vec4 keep = vec4(greaterThan(abs(inval), vec4(uBlock.lambd)));
    imageStore(uOutput, pos, inval * keep);
  }
}
|
||||
|
|
@ -289,6 +289,121 @@ Tensor& hardsigmoid_(Tensor& self) {
|
|||
return ops::activation_(self, VK_KERNEL(hardsigmoid_));
|
||||
}
|
||||
|
||||
// Out-of-place aten::hardshrink for the Vulkan backend.
// hardshrink(x) = x when x < -lambd or x > lambd, and 0 otherwise; the
// element-wise work is done by the hardshrink.glsl compute shader.
Tensor hardshrink(
    const Tensor& self_arg,
    const Scalar& lambd) {
  api::Context* const context = api::context();

  // Move the input onto the Vulkan backend if it is not there already.
  const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();
  const vTensor& v_self = convert(self);

  // Output tensor with the same shape and options as the input.
  vTensor v_output{
    context,
    v_self.sizes(),
    v_self.options(),
  };

  api::Command::Pool& command_pool = context->command().pool;
  api::Command::Buffer& command_buffer = command_pool.stream();
  {
    if C10_LIKELY(v_output.has_image() && v_self.has_image()) {
      // Mirrors the shader's uniform Block: the extents plus one padding
      // word fill the `ivec4 size` slot, followed by `float lambd`.
      const struct Block final {
        uvec3 extents;
        uint32_t _;
        float lambd;
      } block {
        v_output.extents(),
        0u,
        lambd.to<float>(),
      };

      context->dispatch(
          command_buffer,
          {
            // Descriptor types in binding order, matching hardshrink.glsl:
            // 0 = output image, 1 = input sampler, 2 = uniform block.
            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
          },
          VK_KERNEL(hardshrink),
          v_output.extents(),
          context->gpu().adapter->local_work_group_size(),
          // Write-only access bypasses synchronization but inserts appropriate
          // barriers if necessary.
          v_output.image(
              command_buffer,
              vTensor::Stage::Compute,
              vTensor::Access::Write),
          // Read-only access is implied on const tensors and triggers an async
          // synchronization if necessary.
          v_self.image(
              command_buffer,
              vTensor::Stage::Compute),
          // Object lifetime is managed by the resource pool.
          // It is OK not to keep track of the handle.
          context->resource().pool.uniform(block).object);
    }
    else {
      TORCH_CHECK(false, "Not implemented!");
    }
  }
  command_pool.submit(context->gpu().queue, command_buffer);

  return convert(v_output);
}
|
||||
|
||||
// In-place aten::hardshrink_ for the Vulkan backend; the element-wise work
// is done by the hardshrink_.glsl compute shader, which reads and writes
// the tensor's backing image directly.
Tensor& hardshrink_(
    Tensor& self,
    const Scalar& lambd) {
  api::Context* const context = api::context();

  // In-place ops cannot transparently copy to Vulkan: mutations must land
  // on the caller's tensor, so a Vulkan tensor is required up front.
  TORCH_CHECK(
      self.is_vulkan(),
      "Vulkan: In-place hardshrink is only supported on Vulkan tensors.");

  vTensor& v_self = convert(self);

  api::Command::Pool& command_pool = context->command().pool;
  api::Command::Buffer& command_buffer = command_pool.stream();
  {
    if C10_LIKELY(v_self.has_image()) {
      // Mirrors the shader's uniform Block: the extents plus one padding
      // word fill the `ivec4 size` slot, followed by `float lambd`.
      const struct Block final {
        uvec3 extents;
        uint32_t _;
        float lambd;
      } block {
        v_self.extents(),
        0u,
        lambd.to<float>(),
      };

      context->dispatch(
          command_buffer,
          {
            // Descriptor types in binding order, matching hardshrink_.glsl:
            // 0 = read-write image, 1 = uniform block.
            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
            VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
          },
          VK_KERNEL(hardshrink_),
          v_self.extents(),
          context->gpu().adapter->local_work_group_size(),
          // Read-Write access triggers an async synchronization if necessary
          // and inserts appropriate barriers if hazards are detected.
          v_self.image(
              command_buffer,
              vTensor::Stage::Compute,
              vTensor::Access::Read | vTensor::Access::Write),
          // Object lifetime is managed by the resource pool.
          // It is OK not to keep track of the handle.
          context->resource().pool.uniform(block).object);
    }
    else {
      TORCH_CHECK(false, "Not implemented!");
    }
  }
  command_pool.submit(context->gpu().queue, command_buffer);

  return self;
}
|
||||
|
||||
// Element-wise sigmoid for Vulkan tensors; delegates to the shared
// out-of-place activation helper with the sigmoid compute shader.
Tensor sigmoid(const Tensor& self) {
  return ops::activation(self, VK_KERNEL(sigmoid));
}
|
||||
|
|
@ -312,6 +427,8 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
|
|||
m.impl(TORCH_SELECTIVE_NAME("aten::clamp_"), TORCH_FN(clamp_));
|
||||
m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid"), hardsigmoid);
|
||||
m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid_"), hardsigmoid_);
|
||||
m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink"), TORCH_FN(hardshrink));
|
||||
m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink_"), TORCH_FN(hardshrink_));
|
||||
m.impl(TORCH_SELECTIVE_NAME("aten::hardswish"), hardswish);
|
||||
m.impl(TORCH_SELECTIVE_NAME("aten::hardswish_"), hardswish_);
|
||||
m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh"), hardtanh);
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor>& inputs) {
|
|||
constexpr float tolerance = 1e-5;
|
||||
#endif
|
||||
|
||||
return diff.abs().max().item<float>() < (tolerance * maxValue);
|
||||
return diff.abs().max().item<float>() <= (tolerance * maxValue);
|
||||
}
|
||||
|
||||
bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
|
||||
|
|
@ -936,6 +936,49 @@ TEST(VulkanAPITest, hardsigmoid_) {
|
|||
ASSERT_TRUE(check);
|
||||
}
|
||||
|
||||
// Compares Vulkan hardshrink against the CPU reference implementation
// across negative, zero, and positive lambd thresholds.
TEST(VulkanAPITest, hardshrink) {
  if (!at::is_vulkan_available()) {
    return;
  }

  for (const auto lambd_value : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
    const auto input_cpu =
        at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat));
    const auto input_vulkan = input_cpu.vulkan();

    const auto output_cpu = at::hardshrink(input_cpu, lambd_value);
    // Read the Vulkan result back once and reuse it for both the
    // comparison and the failure diagnostics.
    const auto output_vulkan = at::hardshrink(input_vulkan, lambd_value).cpu();

    const auto check = almostEqual(output_cpu, output_vulkan);
    if (!check) {
      showRtol(output_cpu, output_vulkan);
    }
    ASSERT_TRUE(check);
  }
}
|
||||
|
||||
// Validates the Vulkan hardshrink kernel across negative, zero, and
// positive lambd thresholds.
//
// Bug fix: the original test called `cpu.hardshrink(...)` and
// `vulkan.hardshrink(...)` — which are out-of-place — and discarded both
// results, then compared the two untouched random inputs. It therefore
// passed vacuously without exercising the operator. The outputs are now
// captured and compared.
TEST(VulkanAPITest, hardshrink_) {
  if (!at::is_vulkan_available()) {
    return;
  }

  for (const auto lambd_value : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
    const auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat));
    const auto vulkan = cpu.vulkan();

    // NOTE(review): this exercises the functional overload through a Vulkan
    // tensor; switch to the in-place Tensor method here if/when it is
    // exposed, so VK_KERNEL(hardshrink_) itself is covered — TODO confirm.
    const auto out_cpu = cpu.hardshrink(lambd_value);
    const auto out_vulkan = vulkan.hardshrink(lambd_value);

    const auto check = almostEqual(out_cpu, out_vulkan.cpu());
    if (!check) {
      showRtol(out_cpu, out_vulkan.cpu());
    }

    ASSERT_TRUE(check);
  }
}
|
||||
|
||||
TEST(VulkanAPITest, hardswish) {
|
||||
if (!at::is_vulkan_available()) {
|
||||
return;
|
||||
|
|
|
|||
Loading…
Reference in a new issue