mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Differential Revision: D50752096 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112675 Approved by: https://github.com/Skylion007
54 lines
1.9 KiB
Python
54 lines
1.9 KiB
Python
|
|
|
|
|
|
|
|
|
|
from caffe2.python import core, schema
|
|
from caffe2.python.modeling.net_modifier import NetModifier
|
|
|
|
import numpy as np
|
|
|
|
|
|
class ComputeStatisticsForBlobs(NetModifier):
|
|
"""
|
|
This class modifies the net passed in by adding ops to compute statistics
|
|
for certain blobs. For each blob in the list, its min, max, mean and standard
|
|
deviation will be computed.
|
|
|
|
Args:
|
|
blobs: list of blobs to compute norm for
|
|
logging_frequency: frequency for printing norms to logs
|
|
"""
|
|
|
|
def __init__(self, blobs, logging_frequency):
|
|
self._blobs = blobs
|
|
self._logging_frequency = logging_frequency
|
|
self._field_name_suffix = '_summary'
|
|
|
|
def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
|
|
modify_output_record=False):
|
|
|
|
for blob_name in self._blobs:
|
|
blob = core.BlobReference(blob_name)
|
|
assert net.BlobIsDefined(blob), 'blob {} is not defined in net {} whose proto is {}'.format(blob, net.Name(), net.Proto())
|
|
|
|
cast_blob = net.Cast(blob, to=core.DataType.FLOAT)
|
|
stats_name = net.NextScopedBlob(prefix=blob + self._field_name_suffix)
|
|
stats = net.Summarize(cast_blob, stats_name, to_file=0)
|
|
net.Print(stats, [], every_n=self._logging_frequency)
|
|
|
|
if modify_output_record:
|
|
output_field_name = str(blob) + self._field_name_suffix
|
|
output_scalar = schema.Scalar((np.float64, (1,)), stats)
|
|
|
|
if net.output_record() is None:
|
|
net.set_output_record(
|
|
schema.Struct((output_field_name, output_scalar))
|
|
)
|
|
else:
|
|
net.AppendOutputRecordField(
|
|
output_field_name,
|
|
output_scalar)
|
|
|
|
def field_name_suffix(self):
|
|
return self._field_name_suffix
|