pytorch/caffe2/python/modeling/compute_statistics_for_blobs.py
Paul Jesse Hellemn 771fcb3455 [caffe2] Fbcode to GitHub sync (#6208)
* [easy] allow empty tensor in cuda relu op

The diff has not enabled unit test of empty tensor, because MLKVersion of ReluOp need extra work to support

* Make blob norm plotting work with distributed trainer when the old framework is used
2018-04-02 16:35:27 -07:00

56 lines
2 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, schema
from caffe2.python.modeling.net_modifier import NetModifier
import numpy as np
class ComputeStatisticsForBlobs(NetModifier):
"""
This class modifies the net passed in by adding ops to compute statistics
for certain blobs. For each blob in the list, its min, max, mean and standard
deviation will be computed.
Args:
blobs: list of blobs to compute norm for
logging_frequency: frequency for printing norms to logs
"""
def __init__(self, blobs, logging_frequency):
self._blobs = blobs
self._logging_frequency = logging_frequency
self._field_name_suffix = '_summary'
def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
modify_output_record=False):
for blob_name in self._blobs:
blob = core.BlobReference(blob_name)
if not net.BlobIsDefined(blob):
raise Exception('blob {0} is not defined in net {1}'.format(
blob, net.Name()))
cast_blob = net.Cast(blob, to=core.DataType.FLOAT)
stats_name = net.NextScopedBlob(prefix=blob + self._field_name_suffix)
stats = net.Summarize(cast_blob, stats_name, to_file=0)
net.Print(stats, [], every_n=self._logging_frequency)
if modify_output_record:
output_field_name = str(blob) + self._field_name_suffix
output_scalar = schema.Scalar((np.float, (1,)), stats)
if net.output_record() is None:
net.set_output_record(
schema.Struct((output_field_name, output_scalar))
)
else:
net.AppendOutputRecordField(
output_field_name,
output_scalar)
def field_name_suffix(self):
return self._field_name_suffix