Add fusion support for batchnorm and convolution without bias

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10595

Reviewed By: bwasti

Differential Revision: D9110099

fbshipit-source-id: e1ed66c7d82b2f9987b7eb9c7f98877a6dbeb902
This commit is contained in:
Keren Zhou 2018-08-17 11:59:00 -07:00 committed by Facebook Github Bot
parent d35f365ad5
commit f3ac619764
2 changed files with 156 additions and 8 deletions

View file

@ -11,12 +11,17 @@ using namespace nom;
// $$ X_{bn} = \frac{s(X - m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
// $$ X_{conv} = X * W + b_{conv} $$
// thus, substituting $X$ with $X_{conv}$ in the BN equation we get:
// $$X_{bn} = X * \frac{sW}{\sqrt{\sigma + \epsilon}} + \frac{s(b_{conv} -
// m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$ or
// $$ W' = W\frac{s}{\sqrt{\sigma + \epsilon}}$$
// $$ b' = (b_{conv} - m)\frac{s}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
for (auto convNode : repr::nn::nodeIterator<repr::Conv>(nn->dataFlow)) {
size_t convOrder = 0;
for (auto node_pair : repr::nn::dataIterator<repr::Conv>(nn->dataFlow)) {
repr::NNGraph::NodeRef convNode;
repr::Conv* conv;
std::tie(conv, convNode) = node_pair;
auto output = repr::nn::getOutputs(convNode).front();
auto consumers = repr::nn::getConsumers(output);
NOM_REQUIRE_OR_CONT(consumers.size() == 1);
@ -31,9 +36,9 @@ bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
auto bnOutput = bnOutputs.front();
auto convInputs = repr::nn::getInputs(convNode);
CAFFE_ENFORCE(
convInputs.size() >= 3,
"Invalid convolution input size (TODO: optional bias)");
if (convInputs.size() < 2) {
continue;
}
auto bnInputs = repr::nn::getInputs(bnNode);
CAFFE_ENFORCE(
@ -46,13 +51,46 @@ bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
auto name##Data = name##Tensor->mutable_data<float>();
EXPOSE_TENSOR_DATA(filter, 1, convInputs);
EXPOSE_TENSOR_DATA(biasConv, 2, convInputs);
EXPOSE_TENSOR_DATA(scale, 1, bnInputs);
EXPOSE_TENSOR_DATA(biasBN, 2, bnInputs);
EXPOSE_TENSOR_DATA(mean, 3, bnInputs);
EXPOSE_TENSOR_DATA(variance, 4, bnInputs);
if (convInputs.size() == 2) {
NOM_REQUIRE_OR_CONT(conv->getMutableAnnotation() != nullptr);
auto annotation =
dyn_cast<caffe2::Caffe2Annotation>(conv->getMutableAnnotation());
NOM_REQUIRE_OR_CONT(annotation != nullptr);
auto op = annotation->getOperatorDef();
auto convName = op.name();
while (true) {
auto convBiasName = convName + "_bias" + to_string(convOrder);
if (!ws->HasBlob(convBiasName)) {
auto convBiasTensor = make_unique<repr::Tensor>(convBiasName);
convBiasTensor->setType(repr::Tensor::DataType::Float);
auto convBiasNode = nn->dataFlow.createNode(
unique_dyn_cast<repr::NeuralNetData>(convBiasTensor));
nn->inputs.insert(convBiasNode);
nn->dataFlow.createEdge(convBiasNode, convNode);
auto* blob = ws->CreateBlob(convBiasName);
caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
CHECK_NOTNULL(tensor);
// Get output channel
size_t c = filterTensor->dim32(0);
tensor->Resize(c);
tensor->mutable_data<float>();
break;
}
convOrder++;
}
}
convInputs = repr::nn::getInputs(convNode);
EXPOSE_TENSOR_DATA(biasConv, 2, convInputs);
#undef EXPOSE_TENSOR_DATA
// Assume M{CHW,HWC}

View file

@ -221,6 +221,116 @@ class TestTransformations(test_util.TestCase):
assert np.allclose(
preTransformOutput,
postTransformOutput,
rtol=1e-02,
rtol=5e-02,
atol=1e-03
)
@given(
    size=st.integers(7, 10),
    input_channels=st.integers(1, 10),
    seed=st.integers(0, 65535),
    order=st.sampled_from(["NCHW", "NHWC"]),
    epsilon=st.floats(min_value=1e-5, max_value=1e-2),
)
def test_transformer_FuseConvBNNoConvBias(self, size, input_channels, seed, order, epsilon):
    """Fusing Conv (deliberately built WITHOUT a bias input) followed by
    SpatialBN must collapse to a single op whose output matches the
    unfused network numerically."""
    workspace.ResetWorkspace()
    net = core.Net("net")
    chan = input_channels
    ksize = 3
    # Conv gets only ["X", "w"] — no bias blob; the fusion pass must
    # synthesize one.
    net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=ksize, order=order)
    net.SpatialBN(
        ["Y", "scale", "bias", "mean", "var"],
        ["Y2"],
        is_test=True,
        order=order,
        epsilon=epsilon,
    )
    np.random.seed(seed)
    # Input/filter layouts depend on the requested storage order.
    if order == "NCHW":
        x_shape = (1, chan, size, size)
        w_shape = (chan, chan, ksize, ksize)
    else:
        x_shape = (1, size, size, chan)
        w_shape = (chan, ksize, ksize, chan)
    workspace.FeedBlob("X", np.random.rand(*x_shape).astype(np.float32))
    workspace.FeedBlob("w", np.random.rand(*w_shape).astype(np.float32))
    for blob_name in ("scale", "bias", "mean"):
        workspace.FeedBlob(blob_name, np.random.rand(chan).astype(np.float32))
    # This is necessary because 1/sqrt(var) is used and if var is too small
    # we get floating point artifacts that cause test failures
    workspace.FeedBlob("var", np.random.rand(chan).astype(np.float32) + 0.5)
    workspace.RunNetOnce(net)
    preTransformOutput = workspace.FetchBlob("Y2").flatten()
    workspace.FeedBlob("Y2", np.zeros((1, 1)))
    transformer.FuseConvBN(net)
    # Ensure fusion: Conv + SpatialBN should have become one op.
    assert len(net.Proto().op) == 1
    workspace.RunNetOnce(net)
    postTransformOutput = workspace.FetchBlob("Y2").flatten()
    # Check that there is no numerical difference
    assert np.allclose(
        preTransformOutput,
        postTransformOutput,
        rtol=5e-02,
        atol=1e-03,
    )
@given(
    size=st.integers(7, 10),
    input_channels=st.integers(1, 10),
    seed=st.integers(0, 65535),
    order=st.sampled_from(["NCHW", "NHWC"]),
    epsilon=st.floats(min_value=1e-5, max_value=1e-2),
)
def test_transformer_FuseConvBNNoConvBiasDuplicatedName(self, size, input_channels, seed, order, epsilon):
    """Like test_transformer_FuseConvBNNoConvBias, but a blob named
    "_bias0" already exists, forcing the fusion pass to pick a fresh,
    non-colliding name for the conv bias it synthesizes (the pass tries
    "<conv name>_bias<N>" with increasing N)."""
    workspace.ResetWorkspace()
    net = core.Net("net")
    c = input_channels
    h = size
    w = size
    k = 3
    # Conv deliberately has no bias input.
    net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=k, order=order)
    # "_bias0" collides with the first candidate name the fuser generates
    # for the synthesized bias blob.
    net.SpatialBN(
        ["Y", "scale", "_bias0", "mean", "var"],
        ["Y2"],
        is_test=True,
        order=order,
        epsilon=epsilon,
    )
    np.random.seed(seed)
    if order == "NCHW":
        workspace.FeedBlob("X", np.random.rand(1, c, h, w).astype(np.float32))
        workspace.FeedBlob("w", np.random.rand(c, c, k, k).astype(np.float32))
    else:
        workspace.FeedBlob("X", np.random.rand(1, h, w, c).astype(np.float32))
        workspace.FeedBlob("w", np.random.rand(c, k, k, c).astype(np.float32))
    workspace.FeedBlob("scale", np.random.rand(c).astype(np.float32))
    workspace.FeedBlob("_bias0", np.random.rand(c).astype(np.float32))
    workspace.FeedBlob("mean", np.random.rand(c).astype(np.float32))
    # This is necessary because 1/sqrt(var) is used and if var is too small
    # we get floating point artifacts that cause test failures
    workspace.FeedBlob("var", np.random.rand(c).astype(np.float32) + 0.5)
    workspace.RunNetOnce(net)
    preTransformOutput = workspace.FetchBlob("Y2").flatten()
    workspace.FeedBlob("Y2", np.zeros((1, 1)))
    transformer.FuseConvBN(net)
    # Ensure fusion: Conv + SpatialBN should have become one op.
    assert len(net.Proto().op) == 1
    workspace.RunNetOnce(net)
    postTransformOutput = workspace.FetchBlob("Y2").flatten()
    # Check that there is no numerical difference
    assert np.allclose(
        preTransformOutput,
        postTransformOutput,
        rtol=5e-02,
        atol=1e-03,
    )