mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/13004 Implement BucketWeighted model layer, which learns a weight for each possible score in an IdScoreList. Here, we assume that the scores in the IdScoreList have already been converted into the appropriate 'buckets'. If this is not done, then essentially each score represents its own bucket. We assume that the scores/buckets are integers, and if max_score is not set, we assume that the maximum cardinality of the score is less than or equal to the cardinality of the ids. Reviewed By: chonglinsun Differential Revision: D10413186 fbshipit-source-id: 743e643a1b36adf124502a8b6b29976158cdb130
68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
## @package bucket_weighted
|
|
# Module caffe2.python.layers.bucket_weighted
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import logging
|
|
import numpy as np
|
|
|
|
from caffe2.python import core, schema
|
|
from caffe2.python.layers.layers import (
|
|
get_categorical_limit,
|
|
ModelLayer,
|
|
)
|
|
|
|
from caffe2.python.layers.tags import Tags
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BucketWeighted(ModelLayer):
    """Model layer that learns one weight per integer score "bucket".

    The scores of the input IdScoreList are treated as integer bucket ids
    (or are bucketized via ``bucket_boundaries``), and a learned 1-D weight
    table ``bucket_w`` is gathered at those ids to produce the output.
    """

    def __init__(self, model, input_record, max_score=0, bucket_boundaries=None,
                 weight_optim=None, name="bucket_weighted"):
        """
        Args:
            model: parent model the layer is attached to.
            input_record: a schema.List (IdScoreList) whose values are the
                scores/bucket ids.
            max_score: upper bound on the score cardinality; used as the
                weight-table size when positive and no boundaries are given.
            bucket_boundaries: optional explicit bucket boundaries; when set,
                scores are bucketized and the table has len(boundaries) + 1
                rows.
            weight_optim: optimizer for the learned weight table.
            name: layer name.
        """
        super(BucketWeighted, self).__init__(model, name, input_record)

        assert isinstance(input_record, schema.List), "Incorrect input type"
        self.bucket_boundaries = bucket_boundaries

        # Size of the weight table: explicit boundaries take priority,
        # then max_score, then the categorical limit of the ids themselves.
        if bucket_boundaries is not None:
            num_buckets = len(bucket_boundaries) + 1
        elif max_score > 0:
            num_buckets = max_score
        else:
            num_buckets = get_categorical_limit(input_record)
        self.shape = num_buckets

        # One learnable float per bucket, initialized to 1.0.
        self.bucket_w = self.create_param(
            param_name='bucket_w',
            shape=[self.shape, ],
            initializer=('ConstantFill', {'value': 1.0}),
            optimizer=weight_optim)

        self.output_schema = schema.Struct(
            ('bucket_weights',
             schema.Scalar((np.float32, self.shape),
                           self.get_next_blob_reference("bucket_w_gather")))
        )

        self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER})

    def get_memory_usage(self):
        # One entry per bucket in the weight table.
        return self.shape

    def add_ops(self, net):
        """Emit Bucketize (optional) -> Cast -> Gather into ``net``."""
        # When no boundaries are provided, the scores are assumed to
        # already be integer bucket ids.
        if self.bucket_boundaries is None:
            bucket_ids = self.input_record.values()
        else:
            bucket_ids = net.Bucketize(
                self.input_record.values(),
                "buckets",
                boundaries=self.bucket_boundaries,
            )
        # Gather requires integer indices; scores may arrive as floats.
        int_ids = net.Cast(
            bucket_ids,
            "buckets_int",
            to=core.DataType.INT32,
        )
        net.Gather(
            [self.bucket_w, int_ids],
            self.output_schema.bucket_weights.field_blobs())