from pycaffe2 import core


class CNNModelHelper(object):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.

    A minimal usage sketch appears at the end of this file.
    """

    def __init__(self, order, name=None):
        if name is None:
            name = "CNN"
        self.net = core.Net(name)
        self.param_init_net = core.Net(name + '_init')
        self.params = []
        self.weights = []
        self.biases = []
        self.order = order
        if self.order != "NHWC" and self.order != "NCHW":
            raise ValueError("Cannot understand the CNN storage order %s."
                             % self.order)

    def Proto(self):
        return self.net.Proto()

    def Conv(self, blob_in, blob_out, dim_in, dim_out, kernel,
             weight_init, bias_init, **kwargs):
        """Convolution. We intentionally do not provide odd kernel/stride/pad
        settings in order to discourage the use of odd cases.
        """
        weight_shape = ([dim_out, dim_in, kernel, kernel]
                        if self.order == "NCHW"
                        else [dim_out, kernel, kernel, dim_in])
        weight = getattr(self.param_init_net, weight_init[0])(
            [], blob_out + '_w', shape=weight_shape, **weight_init[1])
        bias = getattr(self.param_init_net, bias_init[0])(
            [], blob_out + '_b', shape=[dim_out], **bias_init[1])
        self.params.extend([weight, bias])
        self.weights.append(weight)
        self.biases.append(bias)
        return self.net.Conv([blob_in, weight, bias], blob_out, kernel=kernel,
                             order=self.order, **kwargs)

    def GroupConv(self, blob_in, blob_out, dim_in, dim_out, kernel,
                  weight_init, bias_init, group=1, **kwargs):
        """Group convolution: split the input channels into `group` equal
        slices, convolve each slice independently, and concatenate the
        per-group outputs along the depth dimension. As with Conv, we
        intentionally do not provide odd kernel/stride/pad settings in order
        to discourage the use of odd cases.
        """
        if dim_in % group:
            raise ValueError("dim_in should be divisible by group.")
        splitted_blobs = self.net.DepthSplit(
            blob_in,
            ['_' + blob_out + '_gconv_split_' + str(i) for i in range(group)],
            dimensions=[dim_in // group for i in range(group)],
            order=self.order)
        weight_shape = ([dim_out // group, dim_in // group, kernel, kernel]
                        if self.order == "NCHW"
                        else [dim_out // group, kernel, kernel, dim_in // group])
        conv_blobs = []
        for i in range(group):
            weight = getattr(self.param_init_net, weight_init[0])(
                [], blob_out + '_gconv_%d_w' % i, shape=weight_shape,
                **weight_init[1])
            bias = getattr(self.param_init_net, bias_init[0])(
                [], blob_out + '_gconv_%d_b' % i, shape=[dim_out // group],
                **bias_init[1])
            self.params.extend([weight, bias])
            self.weights.append(weight)
            self.biases.append(bias)
            conv_blobs.append(
                splitted_blobs[i].Conv(
                    [weight, bias], blob_out + '_gconv_%d' % i,
                    kernel=kernel, order=self.order, **kwargs))
        concat, concat_dims = self.net.DepthConcat(
            conv_blobs, [blob_out, "_" + blob_out + "_concat_dims"],
            order=self.order)
        return concat

    def FC(self, blob_in, blob_out, dim_in, dim_out, weight_init, bias_init,
           **kwargs):
        """Fully connected (FC) layer."""
        weight = getattr(self.param_init_net, weight_init[0])(
            [], blob_out + '_w', shape=[dim_out, dim_in], **weight_init[1])
        bias = getattr(self.param_init_net, bias_init[0])(
            [], blob_out + '_b', shape=[dim_out], **bias_init[1])
        self.params.extend([weight, bias])
        return self.net.FC([blob_in, weight, bias], blob_out, **kwargs)

    def LRN(self, blob_in, blob_out, **kwargs):
        """Local response normalization (LRN)."""
        return self.net.LRN(blob_in, [blob_out, "_" + blob_out + "_scale"],
                            order=self.order, **kwargs)[0]

    def Dropout(self, blob_in, blob_out, **kwargs):
        """Dropout"""
        return self.net.Dropout(blob_in, [blob_out, "_" + blob_out + "_mask"],
                                **kwargs)[0]

    def MaxPool(self, blob_in, blob_out, **kwargs):
        """Max pooling"""
        return self.net.MaxPool(blob_in, [blob_out, "_" + blob_out + "_idx"],
                                order=self.order, **kwargs)[0]

    def AddGradientOperators(self):
        """Add gradient operators for the whole net."""
        self.net.AddGradientOperators()

    def __getattr__(self, op_type):
        """Catch-all for all other operators, mostly those without params."""
        if not core.IsOperator(op_type):
            raise RuntimeError(
                'Method ' + op_type + ' is not a registered operator.')
        return getattr(self.net, op_type)
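

# Minimal usage sketch. weight_init and bias_init are (operator_name, kwargs)
# pairs dispatched onto param_init_net; 'XavierFill' and 'ConstantFill' are
# assumed here to be registered fill operators, and the blob names, dimensions,
# and pooling arguments below are illustrative only (keyword arguments are
# forwarded to the underlying operators).
if __name__ == "__main__":
    model = CNNModelHelper(order="NCHW", name="example")
    conv1 = model.Conv('data', 'conv1', dim_in=3, dim_out=16, kernel=3,
                       weight_init=('XavierFill', {}),
                       bias_init=('ConstantFill', {}))
    pool1 = model.MaxPool(conv1, 'pool1', kernel=2, stride=2)
    fc1 = model.FC(pool1, 'fc1', dim_in=16 * 8 * 8, dim_out=10,
                   weight_init=('XavierFill', {}),
                   bias_init=('ConstantFill', {}))
    model.AddGradientOperators()
    print(model.Proto())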