# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package cnn
# Module caffe2.python.cnn
2016-11-14 22:58:04 +00:00
|
|
|
from __future__ import absolute_import
|
|
|
|
|
from __future__ import division
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
|
|
2017-04-25 22:59:13 +00:00
|
|
|
from caffe2.python import brew
|
2017-04-23 10:23:37 +00:00
|
|
|
from caffe2.python.model_helper import ModelHelper
|
2016-05-13 21:43:48 +00:00
|
|
|
from caffe2.proto import caffe2_pb2
|
2017-05-19 06:25:14 +00:00
|
|
|
import logging
|
2016-05-13 21:43:48 +00:00
|
|
|
|
2016-07-21 17:16:42 +00:00
|
|
|
|
2017-04-23 10:23:37 +00:00
|
|
|
class CNNModelHelper(ModelHelper):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.

    Deprecated: new code should use ``ModelHelper`` together with the ``brew``
    module. Each layer method below is a thin wrapper that forwards to the
    matching ``brew`` helper, injecting this model's storage order and/or
    cuDNN settings where that helper takes them.
    """

    def __init__(self, order="NCHW", name=None,
                 use_cudnn=True, cudnn_exhaustive_search=False,
                 ws_nbytes_limit=None, init_params=True,
                 skip_sparse_optim=False,
                 param_model=None):
        """Create the CNN model helper.

        Args:
            order: storage order of image blobs; must be "NCHW" or "NHWC".
            name: model name; "CNN" is used when None.
            use_cudnn: stored on the model, placed in the arg_scope and
                forwarded to the cuDNN-capable brew helpers.
            cudnn_exhaustive_search: stored on the model, placed in the
                arg_scope and forwarded to the convolution helpers.
            ws_nbytes_limit: stored on the model and forwarded to the
                convolution helpers; only added to the arg_scope when truthy.
            init_params: passed through to ModelHelper.
            skip_sparse_optim: passed through to ModelHelper.
            param_model: passed through to ModelHelper.

        Raises:
            ValueError: if ``order`` is neither "NHWC" nor "NCHW".
        """
        # Validate first so an invalid order fails fast, before the
        # deprecation warning is logged or the parent model is built.
        # `(order,)` keeps the message formatting safe even if a tuple is
        # passed as the order.
        if order not in ("NHWC", "NCHW"):
            raise ValueError(
                "Cannot understand the CNN storage order %s." % (order,)
            )

        logging.warning(
            "[====DEPRECATE WARNING====]: you are creating an "
            "object from CNNModelHelper class which will be deprecated soon. "
            "Please use ModelHelper object with brew module. For more "
            "information, please refer to caffe2.ai and python/brew.py, "
            "python/brew_test.py for more information."
        )

        cnn_arg_scope = {
            'order': order,
            'use_cudnn': use_cudnn,
            'cudnn_exhaustive_search': cudnn_exhaustive_search,
        }
        # A falsy limit (None or 0) is treated as "not set".
        if ws_nbytes_limit:
            cnn_arg_scope['ws_nbytes_limit'] = ws_nbytes_limit

        super(CNNModelHelper, self).__init__(
            skip_sparse_optim=skip_sparse_optim,
            name="CNN" if name is None else name,
            init_params=init_params,
            param_model=param_model,
            arg_scope=cnn_arg_scope,
        )

        self.order = order
        self.use_cudnn = use_cudnn
        self.cudnn_exhaustive_search = cudnn_exhaustive_search
        self.ws_nbytes_limit = ws_nbytes_limit

    # ------------------------------------------------------------------
    # Input layers
    # ------------------------------------------------------------------

    def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, **kwargs):
        """Add an image input op via brew.image_input using this model's order."""
        return brew.image_input(
            self,
            blob_in,
            blob_out,
            order=self.order,
            use_gpu_transform=use_gpu_transform,
            **kwargs
        )

    def VideoInput(self, blob_in, blob_out, **kwargs):
        """Add a video input op via brew.video_input."""
        return brew.video_input(
            self,
            blob_in,
            blob_out,
            **kwargs
        )

    def PadImage(self, blob_in, blob_out, **kwargs):
        """Add a PadImage op directly to ``self.net`` and return its output.

        Returning the op's output matches the other layer wrappers; the
        original implementation discarded it and returned None.
        """
        # TODO(wyiming): remove this dummy helper later
        return self.net.PadImage(blob_in, blob_out, **kwargs)

    # ------------------------------------------------------------------
    # Convolutions
    # ------------------------------------------------------------------

    def _forward_to_conv(self, brew_fn, *args, **kwargs):
        """Call a brew convolution helper with this model's conv settings.

        Settings are passed as explicit keywords (not merged into kwargs), so
        a caller supplying one of them again still gets a TypeError, as with
        the original per-method calls.
        """
        return brew_fn(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def ConvNd(self, *args, **kwargs):
        """Add an N-d convolution via brew.conv_nd."""
        return self._forward_to_conv(brew.conv_nd, *args, **kwargs)

    def Conv(self, *args, **kwargs):
        """Add a convolution via brew.conv."""
        return self._forward_to_conv(brew.conv, *args, **kwargs)

    def ConvTranspose(self, *args, **kwargs):
        """Add a transposed convolution via brew.conv_transpose."""
        return self._forward_to_conv(brew.conv_transpose, *args, **kwargs)

    def GroupConv(self, *args, **kwargs):
        """Add a group convolution via brew.group_conv."""
        return self._forward_to_conv(brew.group_conv, *args, **kwargs)

    def GroupConv_Deprecated(self, *args, **kwargs):
        """Add a group convolution via brew.group_conv_deprecated."""
        return self._forward_to_conv(
            brew.group_conv_deprecated, *args, **kwargs
        )

    # ------------------------------------------------------------------
    # Fully-connected layers
    # ------------------------------------------------------------------

    def FC(self, *args, **kwargs):
        """Add a fully-connected layer via brew.fc."""
        return brew.fc(self, *args, **kwargs)

    def PackedFC(self, *args, **kwargs):
        """Add a packed fully-connected layer via brew.packed_fc."""
        return brew.packed_fc(self, *args, **kwargs)

    def FC_Prune(self, *args, **kwargs):
        """Add a pruned fully-connected layer via brew.fc_prune."""
        return brew.fc_prune(self, *args, **kwargs)

    def FC_Decomp(self, *args, **kwargs):
        """Add a decomposed fully-connected layer via brew.fc_decomp."""
        return brew.fc_decomp(self, *args, **kwargs)

    def FC_Sparse(self, *args, **kwargs):
        """Add a sparse fully-connected layer via brew.fc_sparse."""
        return brew.fc_sparse(self, *args, **kwargs)

    # ------------------------------------------------------------------
    # Normalization / activation / misc layers
    # ------------------------------------------------------------------

    def Dropout(self, *args, **kwargs):
        """Add dropout via brew.dropout with this model's order and cuDNN flag."""
        return brew.dropout(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def LRN(self, *args, **kwargs):
        """Add local response normalization via brew.lrn."""
        return brew.lrn(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def Softmax(self, *args, **kwargs):
        """Add a softmax via brew.softmax with this model's cuDNN flag."""
        return brew.softmax(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def SpatialBN(self, *args, **kwargs):
        """Add spatial batch normalization via brew.spatial_bn."""
        return brew.spatial_bn(self, *args, order=self.order, **kwargs)

    def InstanceNorm(self, *args, **kwargs):
        """Add instance normalization via brew.instance_norm."""
        return brew.instance_norm(self, *args, order=self.order, **kwargs)

    def Relu(self, *args, **kwargs):
        """Add a ReLU via brew.relu with this model's order and cuDNN flag."""
        return brew.relu(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def PRelu(self, *args, **kwargs):
        """Add a PReLU via brew.prelu."""
        return brew.prelu(self, *args, **kwargs)

    def Concat(self, *args, **kwargs):
        """Concatenate blobs via brew.concat using this model's order."""
        return brew.concat(self, *args, order=self.order, **kwargs)

    def DepthConcat(self, *args, **kwargs):
        """The old depth concat function - we should move to use concat."""
        # Report the deprecation through logging, consistent with __init__,
        # rather than printing to stdout.
        logging.warning("DepthConcat is deprecated. use Concat instead.")
        return self.Concat(*args, **kwargs)

    def Sum(self, *args, **kwargs):
        """Add a sum op via brew.sum."""
        return brew.sum(self, *args, **kwargs)

    def Transpose(self, *args, **kwargs):
        """Add a transpose via brew.transpose with this model's cuDNN flag."""
        return brew.transpose(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def Iter(self, *args, **kwargs):
        """Add an iteration counter op via brew.iter."""
        return brew.iter(self, *args, **kwargs)

    def Accuracy(self, *args, **kwargs):
        """Add an accuracy op via brew.accuracy."""
        return brew.accuracy(self, *args, **kwargs)

    # ------------------------------------------------------------------
    # Pooling
    # ------------------------------------------------------------------

    def MaxPool(self, *args, **kwargs):
        """Add max pooling via brew.max_pool with this model's settings."""
        return brew.max_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    def MaxPoolWithIndex(self, *args, **kwargs):
        """Add max pooling with index output via brew.max_pool_with_index."""
        return brew.max_pool_with_index(self, *args, order=self.order, **kwargs)

    def AveragePool(self, *args, **kwargs):
        """Add average pooling via brew.average_pool with this model's settings."""
        return brew.average_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    # ------------------------------------------------------------------
    # Parameter initializers: (fill op type, fill op kwargs) pairs
    # ------------------------------------------------------------------

    @property
    def XavierInit(self):
        """Initializer spec using the XavierFill op."""
        return ('XavierFill', {})

    def ConstantInit(self, value):
        """Initializer spec using ConstantFill with the given ``value``."""
        return ('ConstantFill', dict(value=value))

    @property
    def MSRAInit(self):
        """Initializer spec using the MSRAFill op."""
        return ('MSRAFill', {})

    @property
    def ZeroInit(self):
        """Initializer spec using ConstantFill with no explicit value."""
        return ('ConstantFill', {})

    def AddWeightDecay(self, weight_decay):
        """Add weight decay to the model via brew.add_weight_decay."""
        return brew.add_weight_decay(self, weight_decay)

    # ------------------------------------------------------------------
    # Device options
    # ------------------------------------------------------------------

    @property
    def CPU(self):
        """A fresh DeviceOption whose device_type is CPU."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CPU
        return device_option

    @property
    def GPU(self, gpu_id=0):
        """A fresh DeviceOption whose device_type is CUDA.

        NOTE: because this is a property, the getter is only ever called with
        ``self``; ``gpu_id`` can never be supplied by callers and is always 0.
        The signature is kept for backward compatibility.
        """
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CUDA
        device_option.cuda_gpu_id = gpu_id
        return device_option
|