mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add fusion support for batchnorm and convolution without bias
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10595 Reviewed By: bwasti Differential Revision: D9110099 fbshipit-source-id: e1ed66c7d82b2f9987b7eb9c7f98877a6dbeb902
This commit is contained in:
parent
d35f365ad5
commit
f3ac619764
2 changed files with 156 additions and 8 deletions
|
|
@ -11,12 +11,17 @@ using namespace nom;
|
|||
// $$ X_{bn} = \frac{s(X - m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
|
||||
// $$ X_{conv} = X * W + b_{conv} $$
|
||||
// thus, substituting $X$ with $X_{conv}$ in the BN equation we get:
|
||||
// $$X_{bn} = X * \frac{sW}{\sqrt{\sigma + \epsilon}} + \frac{s(b_{conv} - m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
|
||||
// or
|
||||
// $$X_{bn} = X * \frac{sW}{\sqrt{\sigma + \epsilon}} + \frac{s(b_{conv} -
|
||||
// m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$ or
|
||||
// $$ W' = W\frac{s}{\sqrt{\sigma + \epsilon}}$$
|
||||
// $$ b' = (b_{conv} - m)\frac{s}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
|
||||
bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
|
||||
for (auto convNode : repr::nn::nodeIterator<repr::Conv>(nn->dataFlow)) {
|
||||
size_t convOrder = 0;
|
||||
for (auto node_pair : repr::nn::dataIterator<repr::Conv>(nn->dataFlow)) {
|
||||
repr::NNGraph::NodeRef convNode;
|
||||
repr::Conv* conv;
|
||||
std::tie(conv, convNode) = node_pair;
|
||||
|
||||
auto output = repr::nn::getOutputs(convNode).front();
|
||||
auto consumers = repr::nn::getConsumers(output);
|
||||
NOM_REQUIRE_OR_CONT(consumers.size() == 1);
|
||||
|
|
@ -31,9 +36,9 @@ bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
|
|||
auto bnOutput = bnOutputs.front();
|
||||
|
||||
auto convInputs = repr::nn::getInputs(convNode);
|
||||
CAFFE_ENFORCE(
|
||||
convInputs.size() >= 3,
|
||||
"Invalid convolution input size (TODO: optional bias)");
|
||||
if (convInputs.size() < 2) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto bnInputs = repr::nn::getInputs(bnNode);
|
||||
CAFFE_ENFORCE(
|
||||
|
|
@ -46,13 +51,46 @@ bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
|
|||
auto name##Data = name##Tensor->mutable_data<float>();
|
||||
|
||||
EXPOSE_TENSOR_DATA(filter, 1, convInputs);
|
||||
EXPOSE_TENSOR_DATA(biasConv, 2, convInputs);
|
||||
|
||||
EXPOSE_TENSOR_DATA(scale, 1, bnInputs);
|
||||
EXPOSE_TENSOR_DATA(biasBN, 2, bnInputs);
|
||||
EXPOSE_TENSOR_DATA(mean, 3, bnInputs);
|
||||
EXPOSE_TENSOR_DATA(variance, 4, bnInputs);
|
||||
|
||||
if (convInputs.size() == 2) {
|
||||
NOM_REQUIRE_OR_CONT(conv->getMutableAnnotation() != nullptr);
|
||||
auto annotation =
|
||||
dyn_cast<caffe2::Caffe2Annotation>(conv->getMutableAnnotation());
|
||||
NOM_REQUIRE_OR_CONT(annotation != nullptr);
|
||||
auto op = annotation->getOperatorDef();
|
||||
auto convName = op.name();
|
||||
|
||||
while (true) {
|
||||
auto convBiasName = convName + "_bias" + to_string(convOrder);
|
||||
if (!ws->HasBlob(convBiasName)) {
|
||||
auto convBiasTensor = make_unique<repr::Tensor>(convBiasName);
|
||||
convBiasTensor->setType(repr::Tensor::DataType::Float);
|
||||
auto convBiasNode = nn->dataFlow.createNode(
|
||||
unique_dyn_cast<repr::NeuralNetData>(convBiasTensor));
|
||||
nn->inputs.insert(convBiasNode);
|
||||
nn->dataFlow.createEdge(convBiasNode, convNode);
|
||||
|
||||
auto* blob = ws->CreateBlob(convBiasName);
|
||||
caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
|
||||
CHECK_NOTNULL(tensor);
|
||||
// Get output channel
|
||||
size_t c = filterTensor->dim32(0);
|
||||
tensor->Resize(c);
|
||||
tensor->mutable_data<float>();
|
||||
break;
|
||||
}
|
||||
convOrder++;
|
||||
}
|
||||
}
|
||||
|
||||
convInputs = repr::nn::getInputs(convNode);
|
||||
EXPOSE_TENSOR_DATA(biasConv, 2, convInputs);
|
||||
|
||||
#undef EXPOSE_TENSOR_DATA
|
||||
|
||||
// Assume M{CHW,HWC}
|
||||
|
|
|
|||
|
|
@ -221,6 +221,116 @@ class TestTransformations(test_util.TestCase):
|
|||
assert np.allclose(
|
||||
preTransformOutput,
|
||||
postTransformOutput,
|
||||
rtol=5e-02,
|
||||
atol=1e-03
|
||||
)
|
||||
|
||||
@given(
    size=st.integers(7, 10),
    input_channels=st.integers(1, 10),
    seed=st.integers(0, 65535),
    order=st.sampled_from(["NCHW", "NHWC"]),
    epsilon=st.floats(min_value=1e-5, max_value=1e-2),
)
def test_transformer_FuseConvBNNoConvBias(self, size, input_channels, seed, order, epsilon):
    """Fusing a bias-less Conv with a following SpatialBN must collapse the
    pair into a single op and leave the network output numerically unchanged
    (up to float tolerance)."""
    workspace.ResetWorkspace()
    net = core.Net("net")
    channels = input_channels
    height = size
    width = size
    kernel = 3
    # Conv deliberately has no bias input — that is the case under test.
    net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=kernel, order=order)
    net.SpatialBN(
        ["Y", "scale", "bias", "mean", "var"],
        ["Y2"],
        is_test=True,
        order=order,
        epsilon=epsilon,
    )

    np.random.seed(seed)
    # Pick layout-dependent shapes; blobs are fed in the same order (X then w)
    # as before so the random stream is identical.
    if order == "NCHW":
        x_shape = (1, channels, height, width)
        w_shape = (channels, channels, kernel, kernel)
    else:
        x_shape = (1, height, width, channels)
        w_shape = (channels, kernel, kernel, channels)
    workspace.FeedBlob("X", np.random.rand(*x_shape).astype(np.float32))
    workspace.FeedBlob("w", np.random.rand(*w_shape).astype(np.float32))
    workspace.FeedBlob("scale", np.random.rand(channels).astype(np.float32))
    workspace.FeedBlob("bias", np.random.rand(channels).astype(np.float32))
    workspace.FeedBlob("mean", np.random.rand(channels).astype(np.float32))
    # Keep the variance away from zero: 1/sqrt(var) amplifies floating point
    # artifacts for tiny variances and would cause spurious failures.
    workspace.FeedBlob("var", np.random.rand(channels).astype(np.float32) + 0.5)
    workspace.RunNetOnce(net)
    preTransformOutput = workspace.FetchBlob("Y2").flatten()
    workspace.FeedBlob("Y2", np.zeros((1, 1)))
    transformer.FuseConvBN(net)

    # Conv + SpatialBN must have been fused into exactly one op.
    assert len(net.Proto().op) == 1
    workspace.RunNetOnce(net)
    postTransformOutput = workspace.FetchBlob("Y2").flatten()
    # Fusion must not change the numerics.
    assert np.allclose(
        preTransformOutput,
        postTransformOutput,
        rtol=5e-02,
        atol=1e-03,
    )
|
||||
|
||||
@given(
    size=st.integers(7, 10),
    input_channels=st.integers(1, 10),
    seed=st.integers(0, 65535),
    order=st.sampled_from(["NCHW", "NHWC"]),
    epsilon=st.floats(min_value=1e-5, max_value=1e-2),
)
def test_transformer_FuseConvBNNoConvBiasDuplicatedName(self, size, input_channels, seed, order, epsilon):
    """Same as test_transformer_FuseConvBNNoConvBias, but an existing blob is
    named "_bias0" to collide with the name the fusion pass would synthesize
    for the new conv bias; the pass must pick a fresh, unused name and still
    produce numerically identical output."""
    workspace.ResetWorkspace()
    net = core.Net("net")
    c = input_channels
    h = size
    w = size
    k = 3
    # Conv deliberately has no bias input — fusion must create one, and
    # "_bias0" is already taken by the BN bias below.
    net.Conv(["X", "w"], ["Y"], stride=1, pad=0, kernel=k, order=order)
    net.SpatialBN(
        ["Y", "scale", "_bias0", "mean", "var"],
        ["Y2"],
        is_test=True,
        order=order,
        epsilon=epsilon,
    )

    np.random.seed(seed)
    if order == "NCHW":
        workspace.FeedBlob("X", np.random.rand(1, c, h, w).astype(np.float32))
        workspace.FeedBlob("w", np.random.rand(c, c, k, k).astype(np.float32))
    else:
        workspace.FeedBlob("X", np.random.rand(1, h, w, c).astype(np.float32))
        workspace.FeedBlob("w", np.random.rand(c, k, k, c).astype(np.float32))
    workspace.FeedBlob("scale", np.random.rand(c).astype(np.float32))
    workspace.FeedBlob("_bias0", np.random.rand(c).astype(np.float32))
    workspace.FeedBlob("mean", np.random.rand(c).astype(np.float32))
    # This is necessary because 1/sqrt(var) is used and if var is too small
    # we get floating point artifacts that cause test failures
    workspace.FeedBlob("var", np.random.rand(c).astype(np.float32) + 0.5)
    workspace.RunNetOnce(net)
    preTransformOutput = workspace.FetchBlob("Y2").flatten()
    workspace.FeedBlob("Y2", np.zeros((1, 1)))
    transformer.FuseConvBN(net)

    # Ensure fusion: Conv + SpatialBN collapsed into a single op.
    assert len(net.Proto().op) == 1
    workspace.RunNetOnce(net)
    postTransformOutput = workspace.FetchBlob("Y2").flatten()
    # Check that there is no numerical difference.
    # (Leftover debug prints of pre/post outputs removed.)
    assert np.allclose(
        preTransformOutput,
        postTransformOutput,
        rtol=5e-02,
        atol=1e-03,
    )
|
||||
|
|
|
|||
Loading…
Reference in a new issue