mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[vulkan] jit passes for vulkan conv2 prepack and fuse with clamp (#39282)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39282 Test Plan: Imported from OSS Differential Revision: D21962424 Pulled By: IvanKobzarev fbshipit-source-id: 2d20e827d2c3836b7e6b443293377c68dc1ffa5a
This commit is contained in:
parent
f69460d0cb
commit
3852215170
16 changed files with 540 additions and 64 deletions
|
|
@ -179,7 +179,7 @@ at::Tensor vulkan_convolution(
|
|||
voutput,
|
||||
vinput,
|
||||
weight.data_ptr<float>(),
|
||||
bias.defined() ? c10::make_optional<float*>(bias.data_ptr<float>())
|
||||
bias.defined() ? c10::make_optional<const float*>(bias.data_ptr<float>())
|
||||
: c10::nullopt,
|
||||
params);
|
||||
return new_with_vtensor_vulkan(std::move(voutput), input.options());
|
||||
|
|
@ -242,7 +242,8 @@ at::Tensor vulkan_convolution_prepacked(
|
|||
voutput,
|
||||
vinput,
|
||||
vweight,
|
||||
hasBias ? c10::make_optional((*bias).data_ptr<float>()) : c10::nullopt,
|
||||
hasBias ? c10::make_optional<const float*>((*bias).data_ptr<float>())
|
||||
: c10::nullopt,
|
||||
params,
|
||||
output_min,
|
||||
output_max);
|
||||
|
|
|
|||
|
|
@ -66,14 +66,13 @@ ContextConv2D create(
|
|||
const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
|
||||
const auto dilation_expanded =
|
||||
expand_param_if_needed(dilation, "dilation", 2);
|
||||
const Tensor weight_nchw = weight.contiguous();
|
||||
Tensor weight_nchw = weight.contiguous();
|
||||
auto ws = weight_nchw.sizes();
|
||||
return ContextConv2D{
|
||||
at::native::vulkan_convolution_prepack_weights(weight),
|
||||
groups == 1 ? at::native::vulkan_convolution_prepack_weights(weight_nchw)
|
||||
: weight_nchw.vulkan(),
|
||||
bias.has_value() ? c10::make_optional((*bias).vulkan()) : c10::nullopt,
|
||||
{weight_nchw.sizes()[0],
|
||||
weight_nchw.sizes()[1],
|
||||
weight_nchw.sizes()[2],
|
||||
weight_nchw.sizes()[3]},
|
||||
{{ws[0], ws[1], ws[2], ws[3]}},
|
||||
{padding_expanded[0], padding_expanded[1]},
|
||||
{stride_expanded[0], stride_expanded[1]},
|
||||
{dilation_expanded[0], dilation_expanded[1]},
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ VBuffer kernelNCHW_OCHW_repack_O4C4HWi4o4(
|
|||
}
|
||||
|
||||
VBuffer bufferFromOptionalHostData(
|
||||
c10::optional<float*> data,
|
||||
c10::optional<const float*> data,
|
||||
const uint32_t size) {
|
||||
const auto sizeAligned =
|
||||
ROUND_UP(size, context().limits().minStorageBufferOffsetAlignment);
|
||||
|
|
@ -202,17 +202,15 @@ uint32_t conv2d_biasBufferSize(uint32_t oc) {
|
|||
void conv2d_depthwise(
|
||||
VulkanTensor& output,
|
||||
const VulkanTensor& input,
|
||||
const float* weight,
|
||||
const c10::optional<float*> bias,
|
||||
const Conv2DParams params,
|
||||
const VulkanTensor& weight,
|
||||
const VBuffer& biasBuffer,
|
||||
const Conv2DParams& params,
|
||||
c10::optional<float> output_min,
|
||||
c10::optional<float> output_max) {
|
||||
TORCH_INTERNAL_ASSERT(params.G == params.C);
|
||||
auto osizes = output.sizes();
|
||||
TORCH_INTERNAL_ASSERT(osizes[2] == params.OH);
|
||||
TORCH_INTERNAL_ASSERT(osizes[3] == params.OW);
|
||||
auto biasBuffer =
|
||||
bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC));
|
||||
struct ConstBlock {
|
||||
int32_t padding[2];
|
||||
int32_t kernelSize[2];
|
||||
|
|
@ -234,9 +232,6 @@ void conv2d_depthwise(
|
|||
output_max ? *output_max : std::numeric_limits<float>::infinity()};
|
||||
VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb));
|
||||
|
||||
VulkanTensor kernel{{params.OC, params.KH, params.KW}};
|
||||
kernel.set_data_from_host(weight);
|
||||
|
||||
VkDescriptorSetLayout descriptorSetLayout{};
|
||||
VkDescriptorPool descriptorPool{};
|
||||
VkDescriptorSet descriptorSet{};
|
||||
|
|
@ -256,7 +251,7 @@ void conv2d_depthwise(
|
|||
|
||||
output.image()->bindStorageImage(descriptorSet, 0);
|
||||
input.image()->bindShaderRead(descriptorSet, 1);
|
||||
kernel.image()->bindShaderRead(descriptorSet, 2);
|
||||
weight.image()->bindShaderRead(descriptorSet, 2);
|
||||
biasBuffer.bind(descriptorSet, 3);
|
||||
constBuffer.bind(descriptorSet, 4);
|
||||
|
||||
|
|
@ -269,7 +264,7 @@ void conv2d_depthwise(
|
|||
auto commandBuffer = computeUnit.commandBuffer();
|
||||
output.image()->addImageMemoryBarrierToGeneral(commandBuffer);
|
||||
input.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
|
||||
kernel.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
|
||||
weight.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
|
||||
computeUnit.dispatchCommandBuffer(
|
||||
params.OW, params.OH, params.OC_4, workGroupSize);
|
||||
computeUnit.endCommandBuffer();
|
||||
|
|
@ -279,6 +274,44 @@ void conv2d_depthwise(
|
|||
vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
|
||||
}
|
||||
|
||||
// Depthwise convolution overload for a weight that is already a VulkanTensor
// and a bias that is optional host memory. The bias is uploaded into a buffer
// sized by conv2d_biasBufferSize(params.OC) and the call is forwarded to the
// overload that takes the bias as a VBuffer.
void conv2d_depthwise(
    VulkanTensor& output,
    const VulkanTensor& input,
    const VulkanTensor& weight,
    const c10::optional<const float*> bias,
    const Conv2DParams params,
    c10::optional<float> output_min,
    c10::optional<float> output_max) {
  const auto biasSize = conv2d_biasBufferSize(params.OC);
  auto biasBuffer = bufferFromOptionalHostData(bias, biasSize);
  conv2d_depthwise(
      output, input, weight, biasBuffer, params, output_min, output_max);
}
|
||||
|
||||
// Depthwise convolution overload for weight and bias that both live in host
// memory. The weight is first copied into a temporary VulkanTensor of sizes
// {OC, KH, KW}; the bias is then uploaded into a buffer sized by
// conv2d_biasBufferSize(params.OC), and both are forwarded to the
// VulkanTensor/VBuffer overload.
void conv2d_depthwise(
    VulkanTensor& output,
    const VulkanTensor& input,
    const float* weight,
    const c10::optional<const float*> bias,
    const Conv2DParams params,
    c10::optional<float> output_min,
    c10::optional<float> output_max) {
  // Upload the weight before the bias buffer is created (same order as the
  // argument evaluation of a direct nested call).
  VulkanTensor deviceWeight{{params.OC, params.KH, params.KW}};
  deviceWeight.set_data_from_host(weight);

  auto biasBuffer =
      bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC));
  conv2d_depthwise(
      output, input, deviceWeight, biasBuffer, params, output_min, output_max);
}
|
||||
|
||||
ImageSizes conv2d_prepack_weights_image_sizes(
|
||||
int64_t OC,
|
||||
int64_t C,
|
||||
|
|
@ -463,7 +496,7 @@ void conv2d(
|
|||
VulkanTensor& output,
|
||||
const VulkanTensor& input,
|
||||
const VImage& kernelImage,
|
||||
const c10::optional<float*> bias,
|
||||
const c10::optional<const float*> bias,
|
||||
const Conv2DParams& params,
|
||||
c10::optional<float> output_min,
|
||||
c10::optional<float> output_max) {
|
||||
|
|
@ -483,10 +516,22 @@ void conv2d(
|
|||
VulkanTensor& output,
|
||||
const VulkanTensor& input,
|
||||
const VulkanTensor& weight_prepacked,
|
||||
c10::optional<float*> bias,
|
||||
c10::optional<const float*> bias,
|
||||
const Conv2DParams params,
|
||||
c10::optional<float> output_min,
|
||||
c10::optional<float> output_max) {
|
||||
if (params.G > 1) {
|
||||
conv2d_depthwise(
|
||||
output,
|
||||
input,
|
||||
weight_prepacked,
|
||||
bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
|
||||
params,
|
||||
output_min,
|
||||
output_max);
|
||||
return;
|
||||
}
|
||||
|
||||
conv2d(
|
||||
output,
|
||||
input,
|
||||
|
|
@ -505,6 +550,18 @@ void conv2d(
|
|||
const Conv2DParams params,
|
||||
c10::optional<float> output_min,
|
||||
c10::optional<float> output_max) {
|
||||
if (params.G > 1) {
|
||||
conv2d_depthwise(
|
||||
output,
|
||||
input,
|
||||
weight_prepacked,
|
||||
*(bias.buffer()),
|
||||
params,
|
||||
output_min,
|
||||
output_max);
|
||||
return;
|
||||
}
|
||||
|
||||
conv2d(
|
||||
output,
|
||||
input,
|
||||
|
|
@ -519,7 +576,7 @@ void conv2d(
|
|||
VulkanTensor& output,
|
||||
const VulkanTensor& input,
|
||||
const float* weight,
|
||||
const c10::optional<float*> bias,
|
||||
const c10::optional<const float*> bias,
|
||||
const Conv2DParams params,
|
||||
c10::optional<float> output_min,
|
||||
c10::optional<float> output_max) {
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ void conv2d(
|
|||
VulkanTensor& output,
|
||||
const VulkanTensor& input,
|
||||
const float* weight,
|
||||
const c10::optional<float*> bias,
|
||||
const c10::optional<const float*> bias,
|
||||
const Conv2DParams params,
|
||||
c10::optional<float> output_min = c10::nullopt,
|
||||
c10::optional<float> output_max = c10::nullopt);
|
||||
|
|
@ -46,7 +46,7 @@ void conv2d(
|
|||
VulkanTensor& output,
|
||||
const VulkanTensor& input,
|
||||
const VulkanTensor& weight_prepacked,
|
||||
const c10::optional<float*> bias,
|
||||
const c10::optional<const float*> bias,
|
||||
const Conv2DParams params,
|
||||
c10::optional<float> output_min = c10::nullopt,
|
||||
c10::optional<float> output_max = c10::nullopt);
|
||||
|
|
|
|||
|
|
@ -496,7 +496,7 @@ TEST(VulkanTest, conv2dPrepack) {
|
|||
ASSERT_TRUE(no_prepack_check);
|
||||
|
||||
auto prepack = callOpByName(
|
||||
"vulkan::conv2d_clamp_prepack",
|
||||
"vulkan_prepack::conv2d_clamp_prepack",
|
||||
"",
|
||||
t_w,
|
||||
t_b,
|
||||
|
|
@ -507,7 +507,7 @@ TEST(VulkanTest, conv2dPrepack) {
|
|||
output_min,
|
||||
output_max);
|
||||
auto tv_out_prepack_ivalues =
|
||||
callOpByName("vulkan::conv2d_clamp_run", "", tv_in, prepack[0]);
|
||||
callOpByName("vulkan_prepack::conv2d_clamp_run", "", tv_in, prepack[0]);
|
||||
auto tv_out_prepack = tv_out_prepack_ivalues[0].toTensor();
|
||||
auto t_out_prepack = tv_out_prepack.cpu();
|
||||
const auto prepack_check = almostEqual(t_out_prepack, t_out_expected);
|
||||
|
|
|
|||
|
|
@ -103,3 +103,4 @@ endif()
|
|||
caffe2_binary_target("tutorial_blob.cc")
|
||||
|
||||
caffe2_binary_target("dump_operator_names.cc")
|
||||
caffe2_binary_target("optimize_for_mobile.cc")
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
#include <string>
|
||||
|
||||
#include "torch/csrc/jit/api/module.h"
|
||||
#include "torch/csrc/jit/passes/vulkan_rewrite.h"
|
||||
#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
|
||||
#include "torch/csrc/jit/serialization/import.h"
|
||||
|
||||
|
|
@ -29,6 +30,7 @@ C10_DEFINE_bool(
|
|||
save_for_mobile,
|
||||
false,
|
||||
"Save the model with bytecode format compatible with lite inteprter.");
|
||||
C10_DEFINE_bool(vulkan, false, "Vulkan optimize_for_mobile");
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
c10::SetUsageMessage(
|
||||
|
|
@ -52,7 +54,10 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
auto module = torch::jit::load(FLAGS_model);
|
||||
auto optimized_module = torch::jit::optimizeForMobile(module);
|
||||
|
||||
auto optimized_module = FLAGS_vulkan
|
||||
? torch::jit::vulkanOptimizeForMobile(module)
|
||||
: torch::jit::optimizeForMobile(module);
|
||||
|
||||
if (FLAGS_save_for_mobile) {
|
||||
optimized_module._save_for_mobile(output_model_name);
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ TESTS = [
|
|||
'test_optim',
|
||||
'test_mobile_optimizer',
|
||||
'test_xnnpack_integration',
|
||||
'test_vulkan',
|
||||
'test_quantization',
|
||||
'test_sparse',
|
||||
'test_serialization',
|
||||
|
|
|
|||
162
test/test_vulkan.py
Normal file
162
test/test_vulkan.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
import unittest
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.testing import FileCheck
|
||||
import io
|
||||
|
||||
@unittest.skipUnless(torch.is_vulkan_available(),
|
||||
"Vulkan backend must be available for these tests.")
|
||||
class TestVulkanRewritePass(TestCase):
|
||||
@staticmethod
|
||||
def validate_transformed_module(
|
||||
# To please flake
|
||||
self,
|
||||
pattern_count_map,
|
||||
data_shape,
|
||||
prepack_removal=False,
|
||||
fuse_clamping_ops=False):
|
||||
module_instance = self
|
||||
scripted_model = torch.jit.script(module_instance)
|
||||
scripted_model.eval()
|
||||
input_data = torch.normal(1, 20, size=data_shape)
|
||||
ref_result = scripted_model(input_data)
|
||||
torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c)
|
||||
if fuse_clamping_ops or prepack_removal:
|
||||
scripted_model._c = torch._C._freeze_module(scripted_model._c)
|
||||
if fuse_clamping_ops:
|
||||
torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c)
|
||||
if prepack_removal:
|
||||
torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(scripted_model, buffer)
|
||||
buffer.seek(0)
|
||||
deserialized_scripted_model = torch.jit.load(buffer)
|
||||
for pattern, v in pattern_count_map.items():
|
||||
if (v == 0):
|
||||
FileCheck().check(pattern).run(deserialized_scripted_model.graph)
|
||||
elif (v == -1):
|
||||
FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
|
||||
else:
|
||||
FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
|
||||
|
||||
def test_conv(self):
|
||||
# Conv params
|
||||
batch_size = 2
|
||||
input_channels_per_group = 6
|
||||
height = 16
|
||||
width = 16
|
||||
output_channels_per_group = 6
|
||||
groups = 4
|
||||
kernel_h = kernel_w = 3
|
||||
stride_h = stride_w = 1
|
||||
pad_h = pad_w = 1
|
||||
dilation = 1
|
||||
input_channels = input_channels_per_group * groups
|
||||
output_channels = output_channels_per_group * groups
|
||||
kernels = (kernel_h, kernel_w)
|
||||
strides = (stride_h, stride_w)
|
||||
paddings = (pad_h, pad_w)
|
||||
dilations = (dilation, dilation)
|
||||
conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
|
||||
conv_bias_shape = (output_channels)
|
||||
|
||||
class Conv2D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Conv2D, self).__init__()
|
||||
self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
|
||||
self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
|
||||
self.strides = strides
|
||||
self.paddings = paddings
|
||||
self.dilations = dilations
|
||||
self.groups = groups
|
||||
|
||||
def forward(self, x):
|
||||
return F.conv2d(x, self.weight, self.bias,
|
||||
self.strides, self.paddings, self.dilations, self.groups)
|
||||
|
||||
data_shape = (batch_size, input_channels, height, width)
|
||||
pattern_count_map = {"Tensor = aten::conv2d": -1,
|
||||
"vulkan_prepack::conv2d_clamp_prepack": 1,
|
||||
"vulkan_prepack::conv2d_clamp_run": 1}
|
||||
TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
|
||||
|
||||
class Conv2DRelu(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Conv2DRelu, self).__init__()
|
||||
self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
|
||||
self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
|
||||
self.strides = strides
|
||||
self.paddings = paddings
|
||||
self.dilations = dilations
|
||||
self.groups = groups
|
||||
|
||||
def forward(self, x):
|
||||
o = F.conv2d(x, self.weight, self.bias,
|
||||
self.strides, self.paddings, self.dilations, self.groups)
|
||||
o = F.relu(o)
|
||||
return o
|
||||
|
||||
data_shape = (batch_size, input_channels, height, width)
|
||||
pattern_count_map = {"Tensor = aten::conv2d": -1,
|
||||
"vulkan_prepack::conv2d_clamp_prepack": 1,
|
||||
"vulkan_prepack::conv2d_clamp_run": 1}
|
||||
TestVulkanRewritePass.validate_transformed_module(
|
||||
Conv2DRelu(), pattern_count_map, data_shape)
|
||||
|
||||
pattern_count_map["aten::relu"] = 1
|
||||
pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
|
||||
TestVulkanRewritePass.validate_transformed_module(
|
||||
Conv2DRelu(),
|
||||
pattern_count_map,
|
||||
data_shape,
|
||||
prepack_removal=True)
|
||||
pattern_count_map["aten::relu"] = -1
|
||||
TestVulkanRewritePass.validate_transformed_module(
|
||||
Conv2DRelu(),
|
||||
pattern_count_map,
|
||||
data_shape,
|
||||
prepack_removal=True,
|
||||
fuse_clamping_ops=True)
|
||||
|
||||
|
||||
class Conv2DHardtanh(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Conv2DHardtanh, self).__init__()
|
||||
self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
|
||||
self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
|
||||
self.strides = strides
|
||||
self.paddings = paddings
|
||||
self.dilations = dilations
|
||||
self.groups = groups
|
||||
|
||||
def forward(self, x):
|
||||
o = F.conv2d(x, self.weight, self.bias,
|
||||
self.strides, self.paddings, self.dilations, self.groups)
|
||||
o = F.hardtanh(o)
|
||||
return o
|
||||
|
||||
data_shape = (batch_size, input_channels, height, width)
|
||||
pattern_count_map = {"Tensor = aten::conv2d": -1,
|
||||
"vulkan_prepack::conv2d_clamp_prepack": 1,
|
||||
"vulkan_prepack::conv2d_clamp_run": 1}
|
||||
TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
|
||||
pattern_count_map["aten::hardtanh"] = 1
|
||||
pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
|
||||
TestVulkanRewritePass.validate_transformed_module(
|
||||
Conv2DHardtanh(),
|
||||
pattern_count_map,
|
||||
data_shape,
|
||||
prepack_removal=True)
|
||||
pattern_count_map["aten::hardtanh"] = -1
|
||||
TestVulkanRewritePass.validate_transformed_module(
|
||||
Conv2DRelu(),
|
||||
pattern_count_map,
|
||||
data_shape,
|
||||
prepack_removal=True,
|
||||
fuse_clamping_ops=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
@ -177,6 +177,7 @@ libtorch_core_sources = [
|
|||
"torch/csrc/jit/passes/utils/memory_dag.cpp",
|
||||
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
|
||||
"torch/csrc/jit/passes/xnnpack_rewrite.cpp",
|
||||
"torch/csrc/jit/passes/vulkan_rewrite.cpp",
|
||||
"torch/csrc/jit/passes/quantization/helper.cpp",
|
||||
"torch/csrc/jit/passes/quantization/quantization_type.cpp",
|
||||
"torch/csrc/jit/passes/quantization/insert_observers.cpp",
|
||||
|
|
|
|||
|
|
@ -152,6 +152,41 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
|||
rewriter_conv3d.runOnGraph(graph, filter_conv3d);
|
||||
}
|
||||
|
||||
bool isClampFusable(
|
||||
const Match& match,
|
||||
const std::unordered_map<std::string, Value*>& vmap) {
|
||||
const auto& match_vmap = match.values_map;
|
||||
TORCH_CHECK(
|
||||
vmap.find("dummy_min_max") != vmap.end(),
|
||||
"Expected to find dummy_min_max Value in the subgraph to be replaced.");
|
||||
auto dummy_min_max =
|
||||
graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
|
||||
|
||||
auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
|
||||
|
||||
// Also check if the output_min and output_max values are actually constant.
|
||||
// If hardtanh's min/max Value's are not actually constants, we will end up
|
||||
// rerouting those values to prepack op. And if they are not constants
|
||||
// we will not be able to remove prepacking ops.
|
||||
if (vmap.find("output_min") != vmap.end()) {
|
||||
// aten::relu pattern does not have output_min/output_max.
|
||||
// aten::hardtanh/_ does.
|
||||
TORCH_CHECK(
|
||||
vmap.find("output_max") != vmap.end(),
|
||||
"Expected to find output_max as well given "
|
||||
"output_min exist in pattern graph.");
|
||||
// If output_min/max are not constant, we get c10::nullopt.
|
||||
auto output_min =
|
||||
graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
|
||||
auto output_max =
|
||||
graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
|
||||
is_fusable =
|
||||
is_fusable && (output_min.has_value() && output_max.has_value());
|
||||
}
|
||||
|
||||
return is_fusable;
|
||||
}
|
||||
|
||||
} // namespace graph_rewrite_helper
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@ c10::optional<IValue> getIValue(
|
|||
const std::unordered_map<std::string, Value*>& vmap);
|
||||
void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph);
|
||||
|
||||
bool isClampFusable(
|
||||
const Match& match,
|
||||
const std::unordered_map<std::string, Value*>& vmap);
|
||||
|
||||
using MatchFilter = std::function<
|
||||
bool(const Match&, const std::unordered_map<std::string, Value*>&)>;
|
||||
|
||||
|
|
|
|||
205
torch/csrc/jit/passes/vulkan_rewrite.cpp
Normal file
205
torch/csrc/jit/passes/vulkan_rewrite.cpp
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
#include <ATen/core/jit_type.h>
|
||||
#ifdef USE_VULKAN
|
||||
#include <ATen/native/vulkan/VulkanOpContext.h>
|
||||
#endif
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/ir/subgraph_matcher.h>
|
||||
#include <torch/csrc/jit/passes/constant_pooling.h>
|
||||
#include <torch/csrc/jit/passes/fold_conv_bn.h>
|
||||
#include <torch/csrc/jit/passes/freeze_module.h>
|
||||
#include <torch/csrc/jit/passes/fuse_linear.h>
|
||||
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
|
||||
#include <torch/csrc/jit/passes/prepack_folding.h>
|
||||
#include <torch/csrc/jit/passes/remove_dropout.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
#ifdef USE_VULKAN
|
||||
|
||||
namespace {
|
||||
|
||||
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
|
||||
graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
|
||||
|
||||
std::string conv_2d_pattern = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
|
||||
%r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
|
||||
std::string prepacked_ops_conv2d_pattern = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
|
||||
%output_min_max : None = prim::Constant()
|
||||
%packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%output_min_max, %output_min_max)
|
||||
%r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
|
||||
return (%r) )";
|
||||
|
||||
SubgraphRewriter rewriter;
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv_2d_pattern, prepacked_ops_conv2d_pattern);
|
||||
rewriter.runOnGraph(graph);
|
||||
}
|
||||
|
||||
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
|
||||
SubgraphRewriter rewriter;
|
||||
|
||||
std::string conv2d_prepack_run_hardtanh_fused = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
|
||||
%packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dOpContext = vulkan_prepack::conv2d_clamp_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%output_min, %output_max)
|
||||
%r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv2d_prepack_run_hardtanh = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
|
||||
%packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%dummy_min_max, %dummy_min_max)
|
||||
%conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
|
||||
%r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
|
||||
return (%r) )";
|
||||
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);
|
||||
|
||||
std::string conv2d_prepack_run_hardtanh_inplace = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
|
||||
%packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%dummy_min_max, %dummy_min_max)
|
||||
%conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
|
||||
%r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
|
||||
return (%r) )";
|
||||
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
|
||||
|
||||
rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
|
||||
}
|
||||
|
||||
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
|
||||
SubgraphRewriter rewriter;
|
||||
|
||||
std::string conv2d_prepack_run_relu_fused = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %dummy_min_max):
|
||||
%output_min: float = prim::Constant[value=0.0]()
|
||||
%output_max: None = prim::Constant()
|
||||
%packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dOpContext = vulkan_prepack::conv2d_clamp_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%output_min, %output_max)
|
||||
%r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv2d_prepack_run_relu = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %dummy_min_max):
|
||||
%packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%dummy_min_max, %dummy_min_max)
|
||||
%conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
|
||||
%r = aten::relu(%conv2d_res)
|
||||
return (%r) )";
|
||||
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);
|
||||
|
||||
std::string conv2d_prepack_run_relu_inplace = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %dummy_min_max):
|
||||
%packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%dummy_min_max, %dummy_min_max)
|
||||
%conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
|
||||
%r = aten::relu_(%conv2d_res)
|
||||
return (%r) )";
|
||||
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
|
||||
rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Graph-level entry point: replaces aten::conv2d in `graph` with the Vulkan
// prepack/run op pair. Conv2d is the only op rewritten here.
void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedConv2dOp(graph);
}
|
||||
|
||||
// Module-level entry point: applies the graph-level prepack insertion to the
// graph of every method of `module`, then recurses into each child module.
void vulkanInsertPrePackedOps(script::Module& module) {
  for (auto& method : module.get_methods()) {
    // The graph overload takes a non-const reference, so bind a named local.
    auto method_graph = method.graph();
    vulkanInsertPrePackedOps(method_graph);
  }

  for (script::Module child : module.children()) {
    vulkanInsertPrePackedOps(child);
  }
}
|
||||
|
||||
// Fuses relu and then hardtanh clamps into the prepacked conv2d ops of the
// module's "forward" method graph. Only "forward" is processed.
void vulkanFusePrePackedConvWithClamp(script::Module& module) {
  auto forward_graph = module.get_method("forward").graph();
  // Relu fusion runs before hardtanh fusion.
  fuseReluWithPackedOps(forward_graph);
  fuseHardtanhWithPackedOps(forward_graph);
}
|
||||
|
||||
// Runs PrePackingOpsFolder on `m`, selecting only nodes whose kind is
// vulkan_prepack::conv2d_clamp_prepack; results are stored under the
// "prepack_folding" attribute prefix.
void vulkanFoldPrePackingOps(script::Module& m) {
  PrePackingOpsFilterFn is_conv2d_prepack = [](const Node* node) -> bool {
    const auto prepack_kind =
        Symbol::fromQualString("vulkan_prepack::conv2d_clamp_prepack");
    return node->kind() == prepack_kind;
  };
  PrePackingOpsFolder(m, is_conv2d_prepack, "prepack_folding");
}
|
||||
|
||||
// Produces a Vulkan-optimized copy of `m`; the input module is not modified.
// Pipeline, in order: clone -> eval mode -> fold conv + batch norm ->
// insert prepacked conv2d ops -> freeze -> fuse clamps into prepacked convs ->
// fold prepacking ops -> remove dropout.
script::Module vulkanOptimizeForMobile(const script::Module& m) {
  auto optimized = m.clone();
  optimized.eval();
  optimized = FoldConvBatchNorm2d(optimized);
  vulkanInsertPrePackedOps(optimized);
  // Freezing happens before clamp fusion and prepack folding.
  optimized = freeze_module(optimized);
  vulkanFusePrePackedConvWithClamp(optimized);
  vulkanFoldPrePackingOps(optimized);
  removeDropout(optimized);
  return optimized;
}
|
||||
|
||||
#else
|
||||
|
||||
// Stub for builds without USE_VULKAN: always fails with an internal assert,
// since the Vulkan rewrite passes are not compiled in.
void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  // The condition must be an explicit `false`: a bare string literal decays to
  // a non-null pointer, so asserting on the message alone can never fail.
  TORCH_INTERNAL_ASSERT(
      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
}
|
||||
|
||||
// Stub for builds without USE_VULKAN: always fails with an internal assert,
// since the Vulkan rewrite passes are not compiled in.
void vulkanInsertPrePackedOps(script::Module& module) {
  // The condition must be an explicit `false`: a bare string literal decays to
  // a non-null pointer, so asserting on the message alone can never fail.
  TORCH_INTERNAL_ASSERT(
      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
}
|
||||
|
||||
// Stub for builds without USE_VULKAN: always fails with an internal assert,
// since the Vulkan rewrite passes are not compiled in.
void vulkanFusePrePackedConvWithClamp(script::Module& module) {
  // The condition must be an explicit `false`: a bare string literal decays to
  // a non-null pointer, so asserting on the message alone can never fail.
  TORCH_INTERNAL_ASSERT(
      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
}
|
||||
|
||||
// Stub for builds without USE_VULKAN: always fails with an internal assert,
// since the Vulkan rewrite passes are not compiled in.
void vulkanFoldPrePackingOps(script::Module& m) {
  // The condition must be an explicit `false`: a bare string literal decays to
  // a non-null pointer, so asserting on the message alone can never fail.
  TORCH_INTERNAL_ASSERT(
      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
}
|
||||
|
||||
// Stub for builds without USE_VULKAN: always fails with an internal assert
// instead of silently handing back the unoptimized module.
script::Module vulkanOptimizeForMobile(const script::Module& module) {
  // The condition must be an explicit `false`: a bare string literal decays to
  // a non-null pointer, so asserting on the message alone can never fail.
  TORCH_INTERNAL_ASSERT(
      false,
      "Mobile optimizaiton only available with Vulkan at the moment. "
      "Vulkan is not enabled. Please build with USE_VULKAN=1");
  // Not reached; kept so every control path of this non-void function
  // returns a value.
  return module;
}
|
||||
|
||||
#endif
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
14
torch/csrc/jit/passes/vulkan_rewrite.h
Normal file
14
torch/csrc/jit/passes/vulkan_rewrite.h
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
TORCH_API void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph);
|
||||
TORCH_API void vulkanInsertPrePackedOps(script::Module& module);
|
||||
TORCH_API void vulkanFusePrePackedConvWithClamp(script::Module& module);
|
||||
TORCH_API void vulkanFoldPrePackingOps(script::Module& module);
|
||||
TORCH_API script::Module vulkanOptimizeForMobile(const script::Module& module);
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -93,41 +93,6 @@ void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
|
|||
rewriter.runOnGraph(graph);
|
||||
}
|
||||
|
||||
bool isClampFusable(
|
||||
const Match& match,
|
||||
const std::unordered_map<std::string, Value*>& vmap) {
|
||||
const auto& match_vmap = match.values_map;
|
||||
TORCH_CHECK(
|
||||
vmap.find("dummy_min_max") != vmap.end(),
|
||||
"Expected to find dummy_min_max Value in the subgraph to be replaced.");
|
||||
auto dummy_min_max =
|
||||
graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
|
||||
|
||||
auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
|
||||
|
||||
// Also check if the output_min and output_max values are actually constant.
|
||||
// If hardtanh's min/max Value's are not actually constants, we will end up
|
||||
// rerouting those values to prepack op. And if they are not constants
|
||||
// we will not be able to remove prepacking ops.
|
||||
if (vmap.find("output_min") != vmap.end()) {
|
||||
// aten::relu pattern does not have output_min/output_max.
|
||||
// aten::hardtanh/_ does.
|
||||
TORCH_CHECK(
|
||||
vmap.find("output_max") != vmap.end(),
|
||||
"Expected to find output_max as well given "
|
||||
"output_min exist in pattern graph.");
|
||||
// If output_min/max are not constant, we get c10::nullopt.
|
||||
auto output_min =
|
||||
graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
|
||||
auto output_max =
|
||||
graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
|
||||
is_fusable =
|
||||
is_fusable && (output_min.has_value() && output_max.has_value());
|
||||
}
|
||||
|
||||
return is_fusable;
|
||||
}
|
||||
|
||||
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
|
||||
SubgraphRewriter rewriter;
|
||||
|
||||
|
|
@ -194,7 +159,7 @@ void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
|
|||
rewriter.RegisterRewritePattern(
|
||||
conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
|
||||
|
||||
rewriter.runOnGraph(graph, isClampFusable);
|
||||
rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
|
||||
}
|
||||
|
||||
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
|
||||
|
|
@ -266,7 +231,7 @@ void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
|
|||
linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
|
||||
rewriter.runOnGraph(graph, isClampFusable);
|
||||
rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
|
||||
}
|
||||
|
||||
void runCanonicalOptimizations(script::Module& module) {
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@
|
|||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
|
||||
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
|
||||
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
|
||||
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/jit/python/python_arg_flatten.h>
|
||||
|
|
@ -541,6 +542,31 @@ void initJITBindings(PyObject* module) {
|
|||
std::set<MobileOptimizerType>& optimization_blacklist) {
|
||||
return optimizeForMobile(module, optimization_blacklist);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_vulkan_insert_prepacked_ops",
|
||||
[](std::shared_ptr<Graph>& graph) {
|
||||
return vulkanInsertPrePackedOps(graph);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_vulkan_insert_prepacked_ops",
|
||||
[](script::Module& module) {
|
||||
return vulkanInsertPrePackedOps(module);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_vulkan_fuse_clamp_w_prepacked_conv",
|
||||
[](script::Module& module) {
|
||||
return vulkanFusePrePackedConvWithClamp(module);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_vulkan_fold_prepacking_ops",
|
||||
[](script::Module& module) {
|
||||
return vulkanFoldPrePackingOps(module);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_vulkan_optimize_for_mobile",
|
||||
[](script::Module& module) {
|
||||
return vulkanOptimizeForMobile(module);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_onnx_unpack_quantized_weights",
|
||||
[](std::shared_ptr<Graph>& graph,
|
||||
|
|
|
|||
Loading…
Reference in a new issue