diff --git a/caffe2/opt/fusion.cc b/caffe2/opt/fusion.cc
index f5ea0f678ed..61d5301adb7 100644
--- a/caffe2/opt/fusion.cc
+++ b/caffe2/opt/fusion.cc
@@ -11,12 +11,17 @@ using namespace nom;
 // $$ X_{bn} = \frac{s(X - m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
 // $$ X_{conv} = X * W + b_{conv} $$
 // thus, substituting $X$ with $X_{conv}$ in the BN equation we get:
-// $$X_{bn} = X * \frac{sW}{\sqrt{\sigma + \epsilon}} + \frac{s(b_{conv} - m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
-// or
+// $$X_{bn} = X * \frac{sW}{\sqrt{\sigma + \epsilon}} + \frac{s(b_{conv} -
+// m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$ or
 // $$ W' = W\frac{s}{\sqrt{\sigma + \epsilon}}$$
 // $$ b' = (b_{conv} - m)\frac{s}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
 bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
-  for (auto convNode : repr::nn::nodeIterator<repr::Conv>(nn->dataFlow)) {
+  size_t convOrder = 0;
+  for (auto node_pair : repr::nn::dataIterator<repr::Conv>(nn->dataFlow)) {
+    repr::NNGraph::NodeRef convNode;
+    repr::Conv* conv;
+    std::tie(conv, convNode) = node_pair;
+
     auto output = repr::nn::getOutputs(convNode).front();
     auto consumers = repr::nn::getConsumers(output);
     NOM_REQUIRE_OR_CONT(consumers.size() == 1);
@@ -31,9 +36,9 @@ bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
     auto bnOutput = bnOutputs.front();
 
     auto convInputs = repr::nn::getInputs(convNode);
-    CAFFE_ENFORCE(
-        convInputs.size() >= 3,
-        "Invalid convolution input size (TODO: optional bias)");
+    if (convInputs.size() < 2) {
+      continue;
+    }
 
     auto bnInputs = repr::nn::getInputs(bnNode);
     CAFFE_ENFORCE(
@@ -46,13 +51,46 @@ bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
   auto name##Data = name##Tensor->mutable_data<float>();
 
     EXPOSE_TENSOR_DATA(filter, 1, convInputs);
-    EXPOSE_TENSOR_DATA(biasConv, 2, convInputs);
     EXPOSE_TENSOR_DATA(scale, 1, bnInputs);
     EXPOSE_TENSOR_DATA(biasBN, 2, bnInputs);
     EXPOSE_TENSOR_DATA(mean, 3, bnInputs);
     EXPOSE_TENSOR_DATA(variance, 4, bnInputs);
 
+    if (convInputs.size() == 2) {
+      NOM_REQUIRE_OR_CONT(conv->getMutableAnnotation() != nullptr);
+      auto annotation =
+          dyn_cast<caffe2::Caffe2Annotation>(conv->getMutableAnnotation());
+      NOM_REQUIRE_OR_CONT(annotation != nullptr);
+      auto op = annotation->getOperatorDef();
+      auto convName = op.name();
+
+      while (true) {
+        auto convBiasName = convName + "_bias" + to_string(convOrder);
+        if (!ws->HasBlob(convBiasName)) {
+          auto convBiasTensor = make_unique<repr::Tensor>(convBiasName);
+          convBiasTensor->setType(repr::Tensor::DataType::Float);
+          auto convBiasNode = nn->dataFlow.createNode(
+              unique_dyn_cast<repr::NeuralNetData>(convBiasTensor));
+          nn->inputs.insert(convBiasNode);
+          nn->dataFlow.createEdge(convBiasNode, convNode);
+
+          auto* blob = ws->CreateBlob(convBiasName);
+          caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
+          CHECK_NOTNULL(tensor);
+          // Get output channel
+          size_t c = filterTensor->dim32(0);
+          tensor->Resize(c);
+          tensor->mutable_data<float>();
+          break;
+        }
+        convOrder++;
+      }
+    }
+
+    convInputs = repr::nn::getInputs(convNode);
+    EXPOSE_TENSOR_DATA(biasConv, 2, convInputs);
+
 #undef EXPOSE_TENSOR_DATA
 
     // Assume M{CHW,HWC}
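The comment block above derives the fused parameters: W' = W·s/√(σ+ε) and b' = (b_conv − m)·s/√(σ+ε) + b_bn. For readers who want to sanity-check the algebra outside the graph transform, here is a minimal NumPy sketch of the same folding (illustration only, not part of the patch; the helper name `fold_bn_into_conv` is made up, and filters are assumed to carry output channels on axis 0, matching `filterTensor->dim32(0)` in the C++ above):

```python
import numpy as np

def fold_bn_into_conv(W, b_conv, s, b_bn, m, var, eps):
    # rstd = 1 / sqrt(sigma + epsilon), computed per output channel
    rstd = 1.0 / np.sqrt(var + eps)
    # W' = W * s / sqrt(sigma + eps): scale each output-channel slice of W
    W_fused = W * (s * rstd).reshape(-1, 1, 1, 1)
    # b' = (b_conv - m) * s / sqrt(sigma + eps) + b_bn
    b_fused = (b_conv - m) * s * rstd + b_bn
    return W_fused, b_fused

# The no-bias case this patch handles is equivalent to b_conv = 0:
c, k = 4, 3
W = np.random.rand(c, c, k, k).astype(np.float32)
b_conv = np.zeros(c, dtype=np.float32)  # synthesized zero bias
s, b_bn, m = (np.random.rand(c).astype(np.float32) for _ in range(3))
var = np.random.rand(c).astype(np.float32) + 0.5  # keep var away from 0
W_fused, b_fused = fold_bn_into_conv(W, b_conv, s, b_bn, m, var, 1e-3)
```

A Conv run with W_fused/b_fused should match Conv followed by SpatialBN up to rounding, which is exactly what the tests below check with np.allclose.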
diff --git a/caffe2/python/transformations_test.py b/caffe2/python/transformations_test.py
index 4e215b586e5..d9992116a69 100644
--- a/caffe2/python/transformations_test.py
+++ b/caffe2/python/transformations_test.py
@@ -221,6 +221,116 @@ class TestTransformations(test_util.TestCase):
         assert np.allclose(
             preTransformOutput,
             postTransformOutput,
-            rtol=1e-02,
+            rtol=5e-02,
+            atol=1e-03
         )
+
+    @given(
+        size=st.integers(7, 10),
+        input_channels=st.integers(1, 10),
+        seed=st.integers(0, 65535),
+        order=st.sampled_from(["NCHW", "NHWC"]),
+        epsilon=st.floats(min_value=1e-5, max_value=1e-2),
+    )
+    def test_transformer_FuseConvBNNoConvBias(self, size, input_channels, seed, order, epsilon):
+        workspace.ResetWorkspace()
+        net = core.Net("net")
+        c = input_channels
+        h = size
+        w = size
+        k = 3
+        net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=k, order=order)
+        net.SpatialBN(
+            ["Y", "scale", "bias", "mean", "var"],
+            ["Y2"],
+            is_test=True,
+            order=order,
+            epsilon=epsilon,
+        )
+
+        np.random.seed(seed)
+        if order == "NCHW":
+            workspace.FeedBlob("X", np.random.rand(1, c, h, w).astype(np.float32))
+            workspace.FeedBlob("w", np.random.rand(c, c, k, k).astype(np.float32))
+        else:
+            workspace.FeedBlob("X", np.random.rand(1, h, w, c).astype(np.float32))
+            workspace.FeedBlob("w", np.random.rand(c, k, k, c).astype(np.float32))
+        workspace.FeedBlob("scale", np.random.rand(c).astype(np.float32))
+        workspace.FeedBlob("bias", np.random.rand(c).astype(np.float32))
+        workspace.FeedBlob("mean", np.random.rand(c).astype(np.float32))
+        # This is necessary because 1/sqrt(var) is used and if var is too small
+        # we get floating point artifacts that cause test failures
+        workspace.FeedBlob("var", np.random.rand(c).astype(np.float32) + 0.5)
+        workspace.RunNetOnce(net)
+        preTransformOutput = workspace.FetchBlob("Y2").flatten()
+        workspace.FeedBlob("Y2", np.zeros((1, 1)))
+        transformer.FuseConvBN(net)
+
+        # Ensure fusion
+        assert len(net.Proto().op) == 1
+        workspace.RunNetOnce(net)
+        postTransformOutput = workspace.FetchBlob("Y2").flatten()
+        # Check that there is no numerical difference
+        assert np.allclose(
+            preTransformOutput,
+            postTransformOutput,
+            rtol=5e-02,
+            atol=1e-03
+        )
+
+    @given(
+        size=st.integers(7, 10),
+        input_channels=st.integers(1, 10),
+        seed=st.integers(0, 65535),
+        order=st.sampled_from(["NCHW", "NHWC"]),
+        epsilon=st.floats(min_value=1e-5, max_value=1e-2),
+    )
+    def test_transformer_FuseConvBNNoConvBiasDuplicatedName(self, size, input_channels, seed, order, epsilon):
+        workspace.ResetWorkspace()
+        net = core.Net("net")
+        c = input_channels
+        h = size
+        w = size
+        k = 3
+        net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=k, order=order)
+        net.SpatialBN(
+            ["Y", "scale", "_bias0", "mean", "var"],
+            ["Y2"],
+            is_test=True,
+            order=order,
+            epsilon=epsilon,
+        )
+
+        np.random.seed(seed)
+        if order == "NCHW":
+            workspace.FeedBlob("X", np.random.rand(1, c, h, w).astype(np.float32))
+            workspace.FeedBlob("w", np.random.rand(c, c, k, k).astype(np.float32))
+        else:
+            workspace.FeedBlob("X", np.random.rand(1, h, w, c).astype(np.float32))
+            workspace.FeedBlob("w", np.random.rand(c, k, k, c).astype(np.float32))
+        workspace.FeedBlob("scale", np.random.rand(c).astype(np.float32))
+        workspace.FeedBlob("_bias0", np.random.rand(c).astype(np.float32))
+        workspace.FeedBlob("mean", np.random.rand(c).astype(np.float32))
+        # This is necessary because 1/sqrt(var) is used and if var is too small
+        # we get floating point artifacts that cause test failures
+        workspace.FeedBlob("var", np.random.rand(c).astype(np.float32) + 0.5)
+        workspace.RunNetOnce(net)
+        preTransformOutput = workspace.FetchBlob("Y2").flatten()
+        workspace.FeedBlob("Y2", np.zeros((1, 1)))
+        transformer.FuseConvBN(net)
+
+        # Ensure fusion
+        assert len(net.Proto().op) == 1
+        workspace.RunNetOnce(net)
+        postTransformOutput = workspace.FetchBlob("Y2").flatten()
+        print("pre")
+        print(preTransformOutput)
+        print("after")
+        print(postTransformOutput)
+        # Check that there is no numerical difference
+        assert np.allclose(
+            preTransformOutput,
+            postTransformOutput,
+            rtol=5e-02,
+            atol=1e-03
+        )
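Usage note on the duplicated-name test: operators created through core.Net typically carry an empty `name` field, so the bias blobs synthesized by fuseConvBNHelper come out as "_bias0", "_bias1", and so on; pre-feeding a blob named "_bias0" forces the `while (true)` loop in fusion.cc to bump convOrder to the next free suffix. A hypothetical inspection fragment (not part of the test) that could be dropped in right after the transformer.FuseConvBN(net) call in either test:

```python
# Hypothetical fragment; assumes the same context as the tests above
# (workspace, transformer, net already set up), placed right after the
# transformer.FuseConvBN(net) call.
op = net.Proto().op[0]
assert op.type == "Conv"
# Third input is the synthesized bias created by fuseConvBNHelper; with an
# empty op name it is "_bias0", or "_bias1" in the duplicated-name test
# because "_bias0" already exists in the workspace.
print(op.input)
print(workspace.FetchBlob(op.input[2]))  # after fusion this holds b'
```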