pytorch/caffe2/python/layers/bucket_weighted.py
Frank Jiang b827a40880 Implement bucket-based attention pooling for IdScoreList features (#13004)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13004

Implement BucketWeighted model layer, which learns a weight for each possible score in an IdScoreList. Here, we assume that the scores in the IdScoreList have already been converted into the appropriate 'buckets'. If this is not done, then essentially each score represents its own bucket.

We assume that the scores/buckets are integers, and if max_score is not set, we assume that the maximum cardinality of the score is less than or equal to the cardinality of the ids.

Reviewed By: chonglinsun

Differential Revision: D10413186

fbshipit-source-id: 743e643a1b36adf124502a8b6b29976158cdb130
2018-10-25 18:04:08 -07:00

68 lines
2.2 KiB
Python

## @package bucket_weighted
# Module caffe2.python.layers.bucket_weighted
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
import numpy as np
from caffe2.python import core, schema
from caffe2.python.layers.layers import (
get_categorical_limit,
ModelLayer,
)
from caffe2.python.layers.tags import Tags
logger = logging.getLogger(__name__)
class BucketWeighted(ModelLayer):
    """Model layer that learns one weight per score 'bucket' of an IdScoreList.

    The incoming scores are treated as integer bucket ids. The number of
    learnable weights is determined, in priority order, by:
      1. ``bucket_boundaries`` — len(boundaries) + 1 buckets, and scores are
         bucketized at runtime via the Bucketize op;
      2. ``max_score`` (when > 0) — scores are assumed to already lie in
         ``[0, max_score)``;
      3. otherwise, the categorical limit of the input record (assumes the
         score cardinality does not exceed the id cardinality).

    Output schema: a float32 scalar of the per-item gathered bucket weights.
    """

    def __init__(self, model, input_record, max_score=0, bucket_boundaries=None,
                 weight_optim=None, name="bucket_weighted"):
        super(BucketWeighted, self).__init__(model, name, input_record)
        assert isinstance(input_record, schema.List), "Incorrect input type"

        self.bucket_boundaries = bucket_boundaries

        # Resolve how many distinct buckets (and hence weights) we need.
        if bucket_boundaries is not None:
            num_buckets = len(bucket_boundaries) + 1
        elif max_score > 0:
            num_buckets = max_score
        else:
            num_buckets = get_categorical_limit(input_record)
        self.shape = num_buckets

        # One learnable float weight per bucket, initialized to 1.0 so the
        # layer starts out as an identity-like weighting.
        self.bucket_w = self.create_param(
            param_name='bucket_w',
            shape=[num_buckets, ],
            initializer=('ConstantFill', {'value': 1.0}),
            optimizer=weight_optim,
        )

        self.output_schema = schema.Struct(
            ('bucket_weights',
             schema.Scalar((np.float32, num_buckets),
                           self.get_next_blob_reference("bucket_w_gather")))
        )

        self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER})

    def get_memory_usage(self):
        # One weight per bucket.
        return self.shape

    def add_ops(self, net):
        scores = self.input_record.values()

        # With explicit boundaries, map raw scores to bucket indices first;
        # otherwise each distinct score already acts as its own bucket id.
        if self.bucket_boundaries is None:
            buckets = scores
        else:
            buckets = net.Bucketize(
                scores,
                "buckets",
                boundaries=self.bucket_boundaries,
            )

        # Gather requires integer indices.
        buckets_int = net.Cast(
            buckets,
            "buckets_int",
            to=core.DataType.INT32,
        )
        net.Gather(
            [self.bucket_w, buckets_int],
            self.output_schema.bucket_weights.field_blobs())