mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: It appears that my initial implementation was not really working when one starts doing nesting. This diff is fixing this by replacing itertools with something that is really easy to reason about. Reviewed By: idning Differential Revision: D6933763 fbshipit-source-id: f7a1de996d878a41bac2b2acd9d87a7c4b416778
170 lines
7.9 KiB
Python
170 lines
7.9 KiB
Python
# Copyright (c) 2016-present, Facebook, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
##############################################################################
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import brew, model_helper, scope
|
|
from caffe2.python.modeling.parameter_sharing import (
|
|
ParameterSharing,
|
|
parameter_sharing_context,
|
|
)
|
|
from caffe2.python.modeling.initializers import (
|
|
Initializer
|
|
)
|
|
import unittest
|
|
|
|
|
|
class ParameterSharingTest(unittest.TestCase):
    """Tests for parameter-name resolution under NameScope/ParameterSharing.

    Each test enters a combination of ``scope.NameScope`` and
    ``ParameterSharing`` context managers and checks the name returned by
    ``parameter_sharing_context.get_parameter_name`` or the parameters
    registered on a ``ModelHelper``.
    """

    def test_parameter_sharing_default_scopes(self):
        """Without any sharing, names are prefixed by the active scopes."""
        top_level = parameter_sharing_context.get_parameter_name('w')
        self.assertEqual(top_level, 'w')
        with scope.NameScope('scope'):
            one_deep = parameter_sharing_context.get_parameter_name('w')
            self.assertEqual(one_deep, 'scope/w')
            with scope.NameScope('scope_2'):
                two_deep = parameter_sharing_context.get_parameter_name('w')
                self.assertEqual(two_deep, 'scope/scope_2/w')

    def test_parameter_sharing_nested_scopes(self):
        """Nested sharing maps compose with the enclosing name scopes."""
        with scope.NameScope('global_scope'):
            with ParameterSharing({'model_b': 'model_a'}):
                param_global = parameter_sharing_context.get_parameter_name(
                    'w')
                self.assertEqual(param_global, 'global_scope/w')

                # This scope is overridden to match 'model_a'.
                with scope.NameScope('model_b'):
                    with ParameterSharing({'shared_scope': ''}):
                        in_model = (
                            parameter_sharing_context.get_parameter_name('w'))
                        self.assertEqual(in_model, 'global_scope/model_a/w')
                        # 'shared_scope' maps to '', so the name is expected
                        # to be the same as in the enclosing scope.
                        with scope.NameScope('shared_scope'):
                            in_shared = (
                                parameter_sharing_context.get_parameter_name(
                                    'w'))
                            self.assertEqual(
                                in_shared, 'global_scope/model_a/w')

                # This scope is supposed to have no sharing.
                with scope.NameScope('model_c'):
                    with ParameterSharing({'shared_scope': ''}):
                        in_model = (
                            parameter_sharing_context.get_parameter_name('w'))
                        self.assertEqual(in_model, 'global_scope/model_c/w')
                        with scope.NameScope('shared_scope'):
                            in_shared = (
                                parameter_sharing_context.get_parameter_name(
                                    'w'))
                            self.assertEqual(
                                in_shared, 'global_scope/model_c/w')

    def test_parameter_sharing_subscopes(self):
        """Sharing only one of the subscopes leaves the siblings untouched."""
        with ParameterSharing({'global_scope/b': 'global_scope/a'}):
            with scope.NameScope('global_scope'):
                param_root = parameter_sharing_context.get_parameter_name('w')
                self.assertEqual(param_root, 'global_scope/w')
                with scope.NameScope('a'):
                    param_a = parameter_sharing_context.get_parameter_name('w')
                    self.assertEqual(param_a, 'global_scope/a/w')
                # 'global_scope/b' is the only scope listed in the map.
                with scope.NameScope('b'):
                    param_b = parameter_sharing_context.get_parameter_name('w')
                    self.assertEqual(param_b, 'global_scope/a/w')
                with scope.NameScope('c'):
                    param_c = parameter_sharing_context.get_parameter_name('w')
                    self.assertEqual(param_c, 'global_scope/c/w')

    def test_create_param(self):
        """The same param name in two scopes yields two distinct params."""
        model = model_helper.ModelHelper(name="test")
        # No sharing is configured: only the default scoping is exercised.
        p1 = model.create_param(
            'w',
            shape=[2],
            initializer=Initializer("ConstantFill"),
        )
        with scope.NameScope('some_global_scope'):
            p2 = model.create_param(
                'w',
                shape=[2],
                initializer=Initializer("ConstantFill"),
            )
        info_1 = model.get_param_info(p1)
        info_2 = model.get_param_info(p2)
        self.assertIsNotNone(info_1)
        self.assertIsNotNone(info_2)
        self.assertNotEqual(info_1, info_2)
        model.Validate()

    def test_deep_hierarchy(self):
        """A param created under three nested sharing levels is registered."""
        model = model_helper.ModelHelper(name="test")
        with ParameterSharing({'a': 'b'}):
            with scope.NameScope('a'):
                with ParameterSharing({'c': 'd'}):
                    with scope.NameScope('c'):
                        with ParameterSharing({'e': 'f'}):
                            with scope.NameScope('e'):
                                p = model.create_param(
                                    'w',
                                    shape=[2],
                                    initializer=Initializer("ConstantFill"),
                                )
        self.assertIsNotNone(model.get_param_info(p))

    def test_parameter_sharing_brew(self):
        """brew.fc layers created under sharing reuse the mapped params."""
        model = model_helper.ModelHelper(name="test")
        data = model.net.AddExternalInput("data")
        fc1 = brew.fc(model, data, "fc1", dim_in=16, dim_out=16)
        # Shared params are expected to share the same shape and fail if it's
        # not true.
        with self.assertRaises(AssertionError):
            brew.fc(model, data, "fc1", dim_in=2, dim_out=2)

        output_blobs = set()
        with scope.NameScope('some_global_scope'):
            with scope.NameScope('model_a'):
                output_blobs.add(str(brew.fc(model, fc1, 'output', 16, 16)))
            with ParameterSharing({'model_b': 'model_a'}),\
                    scope.NameScope('model_b'):
                with ParameterSharing({'shared_1': '', 'shared_2': ''}):
                    # All params of the layers from shared_1, shared_2 and
                    # model_a are shared and will be pointing to:
                    # [some_global_scope/model_a/output_w,
                    #  some_global_scope/model_a/output_b]
                    for shared_name in ('shared_1', 'shared_2'):
                        with scope.NameScope(shared_name):
                            output_blobs.add(
                                str(brew.fc(model, fc1, 'output', 16, 16)))
                    # Params of this layer are not shared with anyone unless
                    # there is some explicit sharing with model_a/unshared (not
                    # in this example).
                    # Names of the params are
                    # [some_global_scope/model_a/unshared/output_w,
                    #  some_global_scope/model_a/unshared/output_b]
                    with scope.NameScope('unshared'):
                        output_blobs.add(
                            str(brew.fc(model, fc1, 'output', 16, 16)))

        # fc1 (2) + model_a/output (2) + model_a/unshared/output (2).
        self.assertEqual(len(model._parameters_info), 6)
        self.assertEqual(len(output_blobs), 4)
        self.assertEqual(sorted(model._parameters_info.keys()), [
            'fc1_b',
            'fc1_w',
            'some_global_scope/model_a/output_b',
            'some_global_scope/model_a/output_w',
            'some_global_scope/model_a/unshared/output_b',
            'some_global_scope/model_a/unshared/output_w',
        ])
        model.Validate()
|