[vulkan] jit passes for vulkan conv2 prepack and fuse with clamp (#39282)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39282

Test Plan: Imported from OSS

Differential Revision: D21962424

Pulled By: IvanKobzarev

fbshipit-source-id: 2d20e827d2c3836b7e6b443293377c68dc1ffa5a
This commit is contained in:
Ivan Kobzarev 2020-06-20 14:08:52 -07:00 committed by Facebook GitHub Bot
parent f69460d0cb
commit 3852215170
16 changed files with 540 additions and 64 deletions

View file

@ -179,7 +179,7 @@ at::Tensor vulkan_convolution(
voutput,
vinput,
weight.data_ptr<float>(),
bias.defined() ? c10::make_optional<float*>(bias.data_ptr<float>())
bias.defined() ? c10::make_optional<const float*>(bias.data_ptr<float>())
: c10::nullopt,
params);
return new_with_vtensor_vulkan(std::move(voutput), input.options());
@ -242,7 +242,8 @@ at::Tensor vulkan_convolution_prepacked(
voutput,
vinput,
vweight,
hasBias ? c10::make_optional((*bias).data_ptr<float>()) : c10::nullopt,
hasBias ? c10::make_optional<const float*>((*bias).data_ptr<float>())
: c10::nullopt,
params,
output_min,
output_max);

View file

@ -66,14 +66,13 @@ ContextConv2D create(
const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
const auto dilation_expanded =
expand_param_if_needed(dilation, "dilation", 2);
const Tensor weight_nchw = weight.contiguous();
Tensor weight_nchw = weight.contiguous();
auto ws = weight_nchw.sizes();
return ContextConv2D{
at::native::vulkan_convolution_prepack_weights(weight),
groups == 1 ? at::native::vulkan_convolution_prepack_weights(weight_nchw)
: weight_nchw.vulkan(),
bias.has_value() ? c10::make_optional((*bias).vulkan()) : c10::nullopt,
{weight_nchw.sizes()[0],
weight_nchw.sizes()[1],
weight_nchw.sizes()[2],
weight_nchw.sizes()[3]},
{{ws[0], ws[1], ws[2], ws[3]}},
{padding_expanded[0], padding_expanded[1]},
{stride_expanded[0], stride_expanded[1]},
{dilation_expanded[0], dilation_expanded[1]},

View file

@ -176,7 +176,7 @@ VBuffer kernelNCHW_OCHW_repack_O4C4HWi4o4(
}
VBuffer bufferFromOptionalHostData(
c10::optional<float*> data,
c10::optional<const float*> data,
const uint32_t size) {
const auto sizeAligned =
ROUND_UP(size, context().limits().minStorageBufferOffsetAlignment);
@ -202,17 +202,15 @@ uint32_t conv2d_biasBufferSize(uint32_t oc) {
void conv2d_depthwise(
VulkanTensor& output,
const VulkanTensor& input,
const float* weight,
const c10::optional<float*> bias,
const Conv2DParams params,
const VulkanTensor& weight,
const VBuffer& biasBuffer,
const Conv2DParams& params,
c10::optional<float> output_min,
c10::optional<float> output_max) {
TORCH_INTERNAL_ASSERT(params.G == params.C);
auto osizes = output.sizes();
TORCH_INTERNAL_ASSERT(osizes[2] == params.OH);
TORCH_INTERNAL_ASSERT(osizes[3] == params.OW);
auto biasBuffer =
bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC));
struct ConstBlock {
int32_t padding[2];
int32_t kernelSize[2];
@ -234,9 +232,6 @@ void conv2d_depthwise(
output_max ? *output_max : std::numeric_limits<float>::infinity()};
VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb));
VulkanTensor kernel{{params.OC, params.KH, params.KW}};
kernel.set_data_from_host(weight);
VkDescriptorSetLayout descriptorSetLayout{};
VkDescriptorPool descriptorPool{};
VkDescriptorSet descriptorSet{};
@ -256,7 +251,7 @@ void conv2d_depthwise(
output.image()->bindStorageImage(descriptorSet, 0);
input.image()->bindShaderRead(descriptorSet, 1);
kernel.image()->bindShaderRead(descriptorSet, 2);
weight.image()->bindShaderRead(descriptorSet, 2);
biasBuffer.bind(descriptorSet, 3);
constBuffer.bind(descriptorSet, 4);
@ -269,7 +264,7 @@ void conv2d_depthwise(
auto commandBuffer = computeUnit.commandBuffer();
output.image()->addImageMemoryBarrierToGeneral(commandBuffer);
input.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
kernel.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
weight.image()->addImageMemoryBarrierToShaderRead(commandBuffer);
computeUnit.dispatchCommandBuffer(
params.OW, params.OH, params.OC_4, workGroupSize);
computeUnit.endCommandBuffer();
@ -279,6 +274,44 @@ void conv2d_depthwise(
vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
}
// Depthwise conv2d taking an already-prepacked weight VulkanTensor and an
// optional host-side bias pointer. Converts the bias into a device VBuffer
// and forwards to the VBuffer-based overload.
void conv2d_depthwise(
    VulkanTensor& output,
    const VulkanTensor& input,
    const VulkanTensor& weight,
    const c10::optional<const float*> bias,
    const Conv2DParams params,
    c10::optional<float> output_min,
    c10::optional<float> output_max) {
  conv2d_depthwise(
      output,
      input,
      weight,
      // Bias buffer is sized by conv2d_biasBufferSize(OC) regardless of
      // whether a host bias pointer was supplied.
      bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
      params,
      output_min,
      output_max);
}
// Depthwise conv2d from a raw host weight pointer: stages the {OC, KH, KW}
// weights into a VulkanTensor, then forwards to the prepacked-weight path.
void conv2d_depthwise(
    VulkanTensor& output,
    const VulkanTensor& input,
    const float* weight,
    const c10::optional<const float*> bias,
    const Conv2DParams params,
    c10::optional<float> output_min,
    c10::optional<float> output_max) {
  // Upload host weights to the device before dispatch.
  VulkanTensor weightTensor{{params.OC, params.KH, params.KW}};
  weightTensor.set_data_from_host(weight);
  conv2d_depthwise(
      output,
      input,
      weightTensor,
      bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
      params,
      output_min,
      output_max);
}
ImageSizes conv2d_prepack_weights_image_sizes(
int64_t OC,
int64_t C,
@ -463,7 +496,7 @@ void conv2d(
VulkanTensor& output,
const VulkanTensor& input,
const VImage& kernelImage,
const c10::optional<float*> bias,
const c10::optional<const float*> bias,
const Conv2DParams& params,
c10::optional<float> output_min,
c10::optional<float> output_max) {
@ -483,10 +516,22 @@ void conv2d(
VulkanTensor& output,
const VulkanTensor& input,
const VulkanTensor& weight_prepacked,
c10::optional<float*> bias,
c10::optional<const float*> bias,
const Conv2DParams params,
c10::optional<float> output_min,
c10::optional<float> output_max) {
if (params.G > 1) {
conv2d_depthwise(
output,
input,
weight_prepacked,
bufferFromOptionalHostData(bias, conv2d_biasBufferSize(params.OC)),
params,
output_min,
output_max);
return;
}
conv2d(
output,
input,
@ -505,6 +550,18 @@ void conv2d(
const Conv2DParams params,
c10::optional<float> output_min,
c10::optional<float> output_max) {
if (params.G > 1) {
conv2d_depthwise(
output,
input,
weight_prepacked,
*(bias.buffer()),
params,
output_min,
output_max);
return;
}
conv2d(
output,
input,
@ -519,7 +576,7 @@ void conv2d(
VulkanTensor& output,
const VulkanTensor& input,
const float* weight,
const c10::optional<float*> bias,
const c10::optional<const float*> bias,
const Conv2DParams params,
c10::optional<float> output_min,
c10::optional<float> output_max) {

View file

@ -37,7 +37,7 @@ void conv2d(
VulkanTensor& output,
const VulkanTensor& input,
const float* weight,
const c10::optional<float*> bias,
const c10::optional<const float*> bias,
const Conv2DParams params,
c10::optional<float> output_min = c10::nullopt,
c10::optional<float> output_max = c10::nullopt);
@ -46,7 +46,7 @@ void conv2d(
VulkanTensor& output,
const VulkanTensor& input,
const VulkanTensor& weight_prepacked,
const c10::optional<float*> bias,
const c10::optional<const float*> bias,
const Conv2DParams params,
c10::optional<float> output_min = c10::nullopt,
c10::optional<float> output_max = c10::nullopt);

View file

@ -496,7 +496,7 @@ TEST(VulkanTest, conv2dPrepack) {
ASSERT_TRUE(no_prepack_check);
auto prepack = callOpByName(
"vulkan::conv2d_clamp_prepack",
"vulkan_prepack::conv2d_clamp_prepack",
"",
t_w,
t_b,
@ -507,7 +507,7 @@ TEST(VulkanTest, conv2dPrepack) {
output_min,
output_max);
auto tv_out_prepack_ivalues =
callOpByName("vulkan::conv2d_clamp_run", "", tv_in, prepack[0]);
callOpByName("vulkan_prepack::conv2d_clamp_run", "", tv_in, prepack[0]);
auto tv_out_prepack = tv_out_prepack_ivalues[0].toTensor();
auto t_out_prepack = tv_out_prepack.cpu();
const auto prepack_check = almostEqual(t_out_prepack, t_out_expected);

View file

@ -103,3 +103,4 @@ endif()
caffe2_binary_target("tutorial_blob.cc")
caffe2_binary_target("dump_operator_names.cc")
caffe2_binary_target("optimize_for_mobile.cc")

View file

@ -17,6 +17,7 @@
#include <string>
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/passes/vulkan_rewrite.h"
#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
#include "torch/csrc/jit/serialization/import.h"
@ -29,6 +30,7 @@ C10_DEFINE_bool(
save_for_mobile,
false,
"Save the model with bytecode format compatible with lite inteprter.");
C10_DEFINE_bool(vulkan, false, "Vulkan optimize_for_mobile");
int main(int argc, char** argv) {
c10::SetUsageMessage(
@ -52,7 +54,10 @@ int main(int argc, char** argv) {
}
auto module = torch::jit::load(FLAGS_model);
auto optimized_module = torch::jit::optimizeForMobile(module);
auto optimized_module = FLAGS_vulkan
? torch::jit::vulkanOptimizeForMobile(module)
: torch::jit::optimizeForMobile(module);
if (FLAGS_save_for_mobile) {
optimized_module._save_for_mobile(output_model_name);

View file

@ -48,6 +48,7 @@ TESTS = [
'test_optim',
'test_mobile_optimizer',
'test_xnnpack_integration',
'test_vulkan',
'test_quantization',
'test_sparse',
'test_serialization',

162
test/test_vulkan.py Normal file
View file

@ -0,0 +1,162 @@
import unittest
import torch
from torch.nn import functional as F
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing import FileCheck
import io
@unittest.skipUnless(torch.is_vulkan_available(),
                     "Vulkan backend must be available for these tests.")
class TestVulkanRewritePass(TestCase):
    """Exercises the Vulkan JIT rewrite passes (prepack-op insertion,
    relu/hardtanh clamp fusion, prepack folding) and validates the
    resulting graph with FileCheck after a save/load round-trip."""

    @staticmethod
    def validate_transformed_module(
            # To please flake
            self,
            pattern_count_map,
            data_shape,
            prepack_removal=False,
            fuse_clamping_ops=False):
        # `self` is actually the nn.Module under test (static method; the
        # parameter is named `self` only to appease flake8, see comment).
        # pattern_count_map maps a FileCheck pattern to:
        #   0  -> pattern must appear (at least once)
        #   -1 -> pattern must be absent
        #   n>0 -> pattern must appear exactly n times
        module_instance = self
        scripted_model = torch.jit.script(module_instance)
        scripted_model.eval()
        input_data = torch.normal(1, 20, size=data_shape)
        # NOTE(review): ref_result is computed but never compared against the
        # transformed model's output -- confirm whether a numeric check was
        # intended here.
        ref_result = scripted_model(input_data)
        torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c)
        if fuse_clamping_ops or prepack_removal:
            # Both follow-up passes operate on frozen (constant-weight) modules.
            scripted_model._c = torch._C._freeze_module(scripted_model._c)
        if fuse_clamping_ops:
            torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c)
        if prepack_removal:
            torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c)
        # Round-trip through serialization so the checks run against the
        # deserialized (persisted) form of the graph.
        buffer = io.BytesIO()
        torch.jit.save(scripted_model, buffer)
        buffer.seek(0)
        deserialized_scripted_model = torch.jit.load(buffer)
        for pattern, v in pattern_count_map.items():
            if (v == 0):
                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
            elif (v == -1):
                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
            else:
                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)

    def test_conv(self):
        # Conv params
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        # Note: (output_channels) is a plain int, not a 1-tuple (no trailing
        # comma); torch.rand accepts either form.
        conv_bias_shape = (output_channels)

        class Conv2D(torch.nn.Module):
            # Plain conv2d: should be rewritten into prepack + run ops.
            def __init__(self):
                super(Conv2D, self).__init__()
                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(x, self.weight, self.bias,
                                self.strides, self.paddings, self.dilations, self.groups)

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
                             "vulkan_prepack::conv2d_clamp_run": 1}
        TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)

        class Conv2DRelu(torch.nn.Module):
            # conv2d followed by relu: the relu should fuse into clamp args.
            def __init__(self):
                super(Conv2DRelu, self).__init__()
                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                return o

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
                             "vulkan_prepack::conv2d_clamp_run": 1}
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DRelu(), pattern_count_map, data_shape)

        # With prepack folding (no clamp fusion): the prepack op disappears
        # but the relu remains in the graph.
        pattern_count_map["aten::relu"] = 1
        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)

        # With clamp fusion as well: the relu is absorbed into the conv.
        pattern_count_map["aten::relu"] = -1
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        class Conv2DHardtanh(torch.nn.Module):
            # conv2d followed by hardtanh: the hardtanh should fuse into
            # the clamp args.
            def __init__(self):
                super(Conv2DHardtanh, self).__init__()
                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.hardtanh(o)
                return o

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
                             "vulkan_prepack::conv2d_clamp_run": 1}
        TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)

        pattern_count_map["aten::hardtanh"] = 1
        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DHardtanh(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)

        pattern_count_map["aten::hardtanh"] = -1
        # NOTE(review): this final fusion check instantiates Conv2DRelu, not
        # Conv2DHardtanh -- looks like a copy-paste slip; confirm intended.
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)
# Allow running this test file directly.
if __name__ == "__main__":
    run_tests()

View file

@ -177,6 +177,7 @@ libtorch_core_sources = [
"torch/csrc/jit/passes/utils/memory_dag.cpp",
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
"torch/csrc/jit/passes/xnnpack_rewrite.cpp",
"torch/csrc/jit/passes/vulkan_rewrite.cpp",
"torch/csrc/jit/passes/quantization/helper.cpp",
"torch/csrc/jit/passes/quantization/quantization_type.cpp",
"torch/csrc/jit/passes/quantization/insert_observers.cpp",

View file

@ -152,6 +152,41 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
rewriter_conv3d.runOnGraph(graph, filter_conv3d);
}
// Subgraph-rewriter match filter: returns true when a clamp-style op
// (relu / hardtanh) that follows a prepacked conv may be folded into the
// prepack op's output_min/output_max arguments.
bool isClampFusable(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  const auto& match_vmap = match.values_map;
  TORCH_CHECK(
      vmap.find("dummy_min_max") != vmap.end(),
      "Expected to find dummy_min_max Value in the subgraph to be replaced.");
  // Fusable only if the pattern's placeholder min/max is None (i.e. the
  // conv has not already been fused with another clamp).
  auto dummy_min_max =
      graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
  auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
  // Also check if the output_min and output_max values are actually constant.
  // If hardtanh's min/max Value's are not actually constants, we will end up
  // rerouting those values to prepack op. And if they are not constants
  // we will not be able to remove prepacking ops.
  if (vmap.find("output_min") != vmap.end()) {
    // aten::relu pattern does not have output_min/output_max.
    // aten::hardtanh/_ does.
    TORCH_CHECK(
        vmap.find("output_max") != vmap.end(),
        "Expected to find output_max as well given "
        "output_min exist in pattern graph.");
    // If output_min/max are not constant, we get c10::nullopt.
    auto output_min =
        graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
    auto output_max =
        graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
    is_fusable =
        is_fusable && (output_min.has_value() && output_max.has_value());
  }
  return is_fusable;
}
} // namespace graph_rewrite_helper
} // namespace jit
} // namespace torch

View file

@ -19,6 +19,10 @@ c10::optional<IValue> getIValue(
const std::unordered_map<std::string, Value*>& vmap);
void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph);
bool isClampFusable(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap);
using MatchFilter = std::function<
bool(const Match&, const std::unordered_map<std::string, Value*>&)>;

View file

@ -0,0 +1,205 @@
#include <ATen/core/jit_type.h>
#ifdef USE_VULKAN
#include <ATen/native/vulkan/VulkanOpContext.h>
#endif
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
namespace torch {
namespace jit {
#ifdef USE_VULKAN
namespace {
// Rewrites aten::conv2d nodes into a vulkan_prepack::conv2d_clamp_prepack +
// conv2d_clamp_run pair. The prepack's min/max arguments are set to None
// here; a later fusion pass may fill them in from a trailing relu/hardtanh.
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  // Canonicalize aten::_convolution into aten::conv2d so a single pattern
  // matches both forms.
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_2d_pattern, prepacked_ops_conv2d_pattern);
  rewriter.runOnGraph(graph);
}
// Fuses aten::hardtanh / aten::hardtanh_ following a prepacked Vulkan conv
// into the prepack op's output_min/output_max arguments. Guarded by
// isClampFusable so fusion only happens when the pattern's dummy min/max is
// None and the hardtanh bounds are constants.
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  // Replacement graph: hardtanh's bounds are routed into the prepack op.
  // %dummy_min_max is unused here but kept so both graphs share a signature.
  std::string conv2d_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dOpContext = vulkan_prepack::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
        %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);

  // Same rewrite for the in-place hardtanh_ variant.
  std::string conv2d_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
        %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}
// Fuses aten::relu / aten::relu_ following a prepacked Vulkan conv into the
// prepack op's clamp arguments: output_min becomes the constant 0.0 and
// output_max stays None (unbounded above). Guarded by isClampFusable.
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  // Replacement graph: relu == clamp(min=0, max=unbounded).
  // %dummy_min_max is unused here but kept so both graphs share a signature.
  std::string conv2d_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dOpContext = vulkan_prepack::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
        %r = aten::relu(%conv2d_res)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);

  // Same rewrite for the in-place relu_ variant.
  std::string conv2d_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::conv2d_clamp_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::conv2d_clamp_run(%input, %packed_weight_bias)
        %r = aten::relu_(%conv2d_res)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}
} // namespace
// Graph-level entry point for the Vulkan prepack-insertion pass.
void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedConv2dOp(graph);
}
// Module-level variant: applies the graph pass to every method of the
// module and recurses into all child modules.
void vulkanInsertPrePackedOps(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    vulkanInsertPrePackedOps(graph);
  }
  for (script::Module m : module.children()) {
    vulkanInsertPrePackedOps(m);
  }
}
// Fuses relu/hardtanh ops that follow prepacked Vulkan convs into the
// prepack op's output_min/output_max.
// NOTE(review): unlike vulkanInsertPrePackedOps, this only rewrites the
// top-level "forward" graph and does not recurse into child modules or
// other methods -- confirm that is intended.
void vulkanFusePrePackedConvWithClamp(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  fuseReluWithPackedOps(graph);
  fuseHardtanhWithPackedOps(graph);
}
// Runs the generic prepack-folding pass, selecting only
// vulkan_prepack::conv2d_clamp_prepack nodes; "prepack_folding" is the
// attribute-name prefix handed to the folder.
void vulkanFoldPrePackingOps(script::Module& m) {
  PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
    return (
        n->kind() ==
        Symbol::fromQualString("vulkan_prepack::conv2d_clamp_prepack"));
  };
  PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}
// End-to-end Vulkan mobile optimization pipeline. Operates on a clone so
// the caller's module is left untouched.
script::Module vulkanOptimizeForMobile(const script::Module& m) {
  auto cloned_module = m.clone();
  cloned_module.eval();
  // Fold batchnorm into preceding conv weights before prepacking.
  cloned_module = FoldConvBatchNorm2d(cloned_module);
  vulkanInsertPrePackedOps(cloned_module);
  // Freezing makes weights constant so prepack calls can then be folded.
  cloned_module = freeze_module(cloned_module);
  vulkanFusePrePackedConvWithClamp(cloned_module);
  vulkanFoldPrePackingOps(cloned_module);
  removeDropout(cloned_module);
  return cloned_module;
}
#else
// Stub used when the build does not enable Vulkan.
void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  // Fix: the original passed only the message string as the assert
  // condition; a non-null string literal is truthy, so the assert could
  // never fire. Pass `false` so the stub actually fails loudly.
  TORCH_INTERNAL_ASSERT(
      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
}
// Stub used when the build does not enable Vulkan.
void vulkanInsertPrePackedOps(script::Module& module) {
  // Fix: assert on `false` with a message; the original asserted on the
  // (truthy) string literal itself and so never triggered.
  TORCH_INTERNAL_ASSERT(
      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
}
// Stub used when the build does not enable Vulkan.
void vulkanFusePrePackedConvWithClamp(script::Module& module) {
  // Fix: assert on `false` with a message; the original asserted on the
  // (truthy) string literal itself and so never triggered.
  TORCH_INTERNAL_ASSERT(
      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
}
// Stub used when the build does not enable Vulkan.
void vulkanFoldPrePackingOps(script::Module& m) {
  // Fix: assert on `false` with a message; the original asserted on the
  // (truthy) string literal itself and so never triggered.
  TORCH_INTERNAL_ASSERT(
      false, "Vulkan is not enabled. Please build with USE_VULKAN=1");
}
// Stub used when the build does not enable Vulkan.
script::Module vulkanOptimizeForMobile(const script::Module& module) {
  // Fixes: (1) assert on `false` -- the original asserted on the truthy
  // string literal and never fired; (2) corrected the "optimizaiton" typo
  // in the user-facing message.
  TORCH_INTERNAL_ASSERT(
      false,
      "Mobile optimization only available with Vulkan at the moment. "
      "Vulkan is not enabled. Please build with USE_VULKAN=1");
  return module;
}
#endif
} // namespace jit
} // namespace torch

View file

@ -0,0 +1,14 @@
#pragma once

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// Rewrites aten::conv2d into vulkan_prepack::conv2d_clamp_prepack/run pairs.
TORCH_API void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph);
// Module-level variant of the pass above.
TORCH_API void vulkanInsertPrePackedOps(script::Module& module);
// Folds trailing relu/hardtanh clamps into the prepack op's min/max args.
TORCH_API void vulkanFusePrePackedConvWithClamp(script::Module& module);
// Constant-folds vulkan_prepack::conv2d_clamp_prepack calls (after freezing).
TORCH_API void vulkanFoldPrePackingOps(script::Module& module);
// Full Vulkan mobile pipeline: clone, eval, conv-bn fold, prepack insert,
// freeze, clamp fuse, prepack fold, dropout removal.
TORCH_API script::Module vulkanOptimizeForMobile(const script::Module& module);

} // namespace jit
} // namespace torch

View file

@ -93,41 +93,6 @@ void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
rewriter.runOnGraph(graph);
}
bool isClampFusable(
const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
TORCH_CHECK(
vmap.find("dummy_min_max") != vmap.end(),
"Expected to find dummy_min_max Value in the subgraph to be replaced.");
auto dummy_min_max =
graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
// Also check if the output_min and output_max values are actually constant.
// If hardtanh's min/max Value's are not actually constants, we will end up
// rerouting those values to prepack op. And if they are not constants
// we will not be able to remove prepacking ops.
if (vmap.find("output_min") != vmap.end()) {
// aten::relu pattern does not have output_min/output_max.
// aten::hardtanh/_ does.
TORCH_CHECK(
vmap.find("output_max") != vmap.end(),
"Expected to find output_max as well given "
"output_min exist in pattern graph.");
// If output_min/max are not constant, we get c10::nullopt.
auto output_min =
graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
auto output_max =
graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
is_fusable =
is_fusable && (output_min.has_value() && output_max.has_value());
}
return is_fusable;
}
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
SubgraphRewriter rewriter;
@ -194,7 +159,7 @@ void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
rewriter.RegisterRewritePattern(
conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
rewriter.runOnGraph(graph, isClampFusable);
rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
@ -266,7 +231,7 @@ void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
rewriter.RegisterRewritePattern(
conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
rewriter.runOnGraph(graph, isClampFusable);
rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}
void runCanonicalOptimizations(script::Module& module) {

View file

@ -54,6 +54,7 @@
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_arg_flatten.h>
@ -541,6 +542,31 @@ void initJITBindings(PyObject* module) {
std::set<MobileOptimizerType>& optimization_blacklist) {
return optimizeForMobile(module, optimization_blacklist);
})
.def(
"_jit_pass_vulkan_insert_prepacked_ops",
[](std::shared_ptr<Graph>& graph) {
return vulkanInsertPrePackedOps(graph);
})
.def(
"_jit_pass_vulkan_insert_prepacked_ops",
[](script::Module& module) {
return vulkanInsertPrePackedOps(module);
})
.def(
"_jit_pass_vulkan_fuse_clamp_w_prepacked_conv",
[](script::Module& module) {
return vulkanFusePrePackedConvWithClamp(module);
})
.def(
"_jit_pass_vulkan_fold_prepacking_ops",
[](script::Module& module) {
return vulkanFoldPrePackingOps(module);
})
.def(
"_jit_pass_vulkan_optimize_for_mobile",
[](script::Module& module) {
return vulkanOptimizeForMobile(module);
})
.def(
"_jit_pass_onnx_unpack_quantized_weights",
[](std::shared_ptr<Graph>& graph,