diff --git a/aten/src/ATen/native/vulkan/glsl/hardshrink.glsl b/aten/src/ATen/native/vulkan/glsl/hardshrink.glsl new file mode 100644 index 00000000000..ed6e45b3f09 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/hardshrink.glsl @@ -0,0 +1,26 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + float lambd; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const vec4 inval = texelFetch(uInput, pos, 0); + const vec4 mask = vec4(greaterThanEqual(inval, vec4(-uBlock.lambd)))*vec4(lessThanEqual(inval, vec4(uBlock.lambd))); + const vec4 outval = (vec4(1.0) - mask)*inval; + imageStore(uOutput, pos, outval); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/hardshrink_.glsl b/aten/src/ATen/native/vulkan/glsl/hardshrink_.glsl new file mode 100644 index 00000000000..21b5866c025 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/hardshrink_.glsl @@ -0,0 +1,25 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION restrict Block { + ivec4 size; + float lambd; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const vec4 inval = imageLoad(uOutput, pos); + const vec4 mask = vec4(greaterThanEqual(inval, 
vec4(-uBlock.lambd)))*vec4(lessThanEqual(inval, vec4(uBlock.lambd))); + const vec4 outval = (vec4(1.0) - mask)*inval; + imageStore(uOutput, pos, outval); + } +} diff --git a/aten/src/ATen/native/vulkan/ops/Clamp.cpp b/aten/src/ATen/native/vulkan/ops/Clamp.cpp index e56c005acc5..c6f046e84fd 100644 --- a/aten/src/ATen/native/vulkan/ops/Clamp.cpp +++ b/aten/src/ATen/native/vulkan/ops/Clamp.cpp @@ -289,6 +289,121 @@ Tensor& hardsigmoid_(Tensor& self) { return ops::activation_(self, VK_KERNEL(hardsigmoid_)); } +Tensor hardshrink( + const Tensor& self_arg, + const Scalar& lambd) { + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + v_self.sizes(), + v_self.options(), + }; + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_output.has_image() && v_self.has_image()) { + const struct Block final { + uvec3 extents; + uint32_t _; + float lambd; + } block { + v_output.extents(), + 0u, + lambd.to<float>(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(hardshrink), + v_output.extents(), + context->gpu().adapter->local_work_group_size(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image( + command_buffer, + vTensor::Stage::Compute), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. 
+ context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return convert(v_output); +} + +Tensor& hardshrink_( + Tensor& self, + const Scalar& lambd) { + api::Context* const context = api::context(); + + TORCH_CHECK( + self.is_vulkan(), + "Vulkan: In-place hardshrink is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self); + + api::Command::Pool& command_pool = context->command().pool; + api::Command::Buffer& command_buffer = command_pool.stream(); + { + if C10_LIKELY(v_self.has_image()) { + const struct Block final { + uvec3 extents; + uint32_t _; + float lambd; + } block { + v_self.extents(), + 0u, + lambd.to<float>(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(hardshrink_), + v_self.extents(), + context->gpu().adapter->local_work_group_size(), + // Read-Write access triggers an async synchronization if necessary + // and inserts appropriate barriers if hazards are detected. + v_self.image( + command_buffer, + vTensor::Stage::Compute, + vTensor::Access::Read | vTensor::Access::Write), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. 
+ context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_pool.submit(context->gpu().queue, command_buffer); + + return self; +} + Tensor sigmoid(const Tensor& self) { return ops::activation(self, VK_KERNEL(sigmoid)); } @@ -312,6 +427,8 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl(TORCH_SELECTIVE_NAME("aten::clamp_"), TORCH_FN(clamp_)); m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid"), hardsigmoid); m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid_"), hardsigmoid_); + m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink"), TORCH_FN(hardshrink)); + m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink_"), TORCH_FN(hardshrink_)); m.impl(TORCH_SELECTIVE_NAME("aten::hardswish"), hardswish); m.impl(TORCH_SELECTIVE_NAME("aten::hardswish_"), hardswish_); m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh"), hardtanh); diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index c98fea5336a..b2c6daf115d 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -20,7 +20,7 @@ bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor>& inputs) { constexpr float tolerance = 1e-5; #endif - return diff.abs().max().item<float>() < (tolerance * maxValue); + return diff.abs().max().item<float>() <= (tolerance * maxValue); } bool almostEqual(const at::Tensor& a, const at::Tensor& b) { @@ -936,6 +936,49 @@ TEST(VulkanAPITest, hardsigmoid_) { ASSERT_TRUE(check); } +TEST(VulkanAPITest, hardshrink) { + if (!at::is_vulkan_available()) { + return; + } + + for (const auto lambd_value : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) { + const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan = in_cpu.vulkan(); + + const auto out_cpu = at::hardshrink(in_cpu, lambd_value); + const auto out_vulkan = at::hardshrink(in_vulkan, lambd_value); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + + if (!check) { + 
showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); + } +} + +TEST(VulkanAPITest, hardshrink_) { + if (!at::is_vulkan_available()) { + return; + } + + for (const auto lambd_value : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) { + const auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + const auto vulkan = cpu.vulkan(); + + cpu.hardshrink_(lambd_value); + vulkan.hardshrink_(lambd_value); + + const auto check = almostEqual(cpu, vulkan.cpu()); + if (!check) { + showRtol(cpu, vulkan.cpu()); + } + + ASSERT_TRUE(check); + } +} + +TEST(VulkanAPITest, hardswish) { + if (!at::is_vulkan_available()) { + return;