mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add interface to provide blob types to shape&type inference (#9643)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9643 Current map interface assumes float data type, which is not always correct. Reviewed By: kennyhorror Differential Revision: D8455784 fbshipit-source-id: b94a31267760f7f97c15aa4b03008affc347fd10
This commit is contained in:
parent
7af5883860
commit
2b134c72e6
8 changed files with 106 additions and 4 deletions
|
|
@ -571,6 +571,31 @@ TensorShapes InferBlobShapesAndTypesFromMap(
|
|||
return InferBlobShapesAndTypes(blob_desc, nets);
|
||||
}
|
||||
|
||||
TensorShapes InferBlobShapesAndTypesFromMap(
|
||||
const CaffeMap<std::string, std::vector<TIndex>>& blob_dimensions,
|
||||
const CaffeMap<std::string, TensorProto_DataType>& blob_types,
|
||||
const vector<NetDef*>& nets) {
|
||||
CaffeMap<string, TensorShape> blob_desc;
|
||||
// Populate shapes from known blobs
|
||||
for (const auto& blob : blob_dimensions) {
|
||||
TensorShape tp;
|
||||
for (auto d : blob.second) {
|
||||
CAFFE_ENFORCE_GE(d, 0, blob.first);
|
||||
tp.add_dims(d);
|
||||
}
|
||||
auto blob_type = blob_types.find(blob.first);
|
||||
if (blob_type == blob_types.end()) {
|
||||
LOG(WARNING) << "Missing type of " << blob.first
|
||||
<< "; assuming to be UNDEFINED";
|
||||
tp.set_data_type(TensorProto_DataType_UNDEFINED);
|
||||
} else {
|
||||
tp.set_data_type(blob_type->second);
|
||||
}
|
||||
blob_desc[blob.first] = tp;
|
||||
}
|
||||
return InferBlobShapesAndTypes(blob_desc, nets);
|
||||
}
|
||||
|
||||
std::map<string, std::pair<DeviceOption, DeviceOption>> ValidateTensorDevices(
|
||||
OperatorBase& op,
|
||||
const OperatorDef& op_def) {
|
||||
|
|
|
|||
|
|
@ -968,6 +968,11 @@ TensorShapes InferBlobShapesAndTypesFromMap(
|
|||
const CaffeMap<std::string, std::vector<TIndex>>& blob_dimensions,
|
||||
const vector<NetDef*>& nets);
|
||||
|
||||
TensorShapes InferBlobShapesAndTypesFromMap(
|
||||
const CaffeMap<std::string, std::vector<TIndex>>& blob_dimensions,
|
||||
const CaffeMap<std::string, TensorProto_DataType>& blob_types,
|
||||
const vector<NetDef*>& nets);
|
||||
|
||||
std::map<string, std::pair<DeviceOption, DeviceOption>> ValidateTensorDevices(
|
||||
OperatorBase& op,
|
||||
const OperatorDef& op_def);
|
||||
|
|
|
|||
|
|
@ -93,7 +93,8 @@ OPERATOR_SCHEMA(CreateMutex)
|
|||
.NumInputs(0)
|
||||
.NumOutputs(1)
|
||||
.SetDoc("Creates an unlocked mutex and returns it in a unique_ptr blob.")
|
||||
.Output(0, "mutex_ptr", "Blob containing a std::unique_ptr<mutex>.");
|
||||
.Output(0, "mutex_ptr", "Blob containing a std::unique_ptr<mutex>.")
|
||||
.ScalarType(TensorProto_DataType_UNDEFINED);
|
||||
|
||||
OPERATOR_SCHEMA(AtomicFetchAdd)
|
||||
.NumInputs(3)
|
||||
|
|
|
|||
|
|
@ -58,7 +58,8 @@ OPERATOR_SCHEMA(CreateMap)
|
|||
.SetDoc("Create an empty map blob")
|
||||
.Arg("key_dtype", "Key's TensorProto::DataType (default INT32)")
|
||||
.Arg("value_dtype", "Value's TensorProto::DataType (default INT32)")
|
||||
.Output(0, "map blob", "Blob reference to the map");
|
||||
.Output(0, "map blob", "Blob reference to the map")
|
||||
.ScalarType(TensorProto_DataType_UNDEFINED);
|
||||
|
||||
OPERATOR_SCHEMA(KeyValueToMap)
|
||||
.NumInputs(2)
|
||||
|
|
|
|||
|
|
@ -2007,6 +2007,13 @@ i.e. `len(LENGTHS)`. Other dimensions are inherited from the input tensor.
|
|||
"OUTPUT",
|
||||
"Aggregated output tensor. Has the first dimension of K "
|
||||
"(the number of segments).");
|
||||
schema.TensorInferenceFunction(
|
||||
[](const OperatorDef&, const std::vector<TensorShape>& input_types) {
|
||||
std::vector<TensorShape> out(1);
|
||||
out[0] = input_types[0];
|
||||
out[0].set_dims(0, input_types[Reducer::kInputCount + 1].dims(0));
|
||||
return out;
|
||||
});
|
||||
ReducerDef::PopulateSchema(schema);
|
||||
}
|
||||
using Reducer = typename ReducerDef::template Reducer<T, Context>;
|
||||
|
|
|
|||
|
|
@ -431,6 +431,36 @@ class TestShapeInference(test_util.TestCase):
|
|||
self.assertEqual(shapes['E'], [10, 23, 9, 10])
|
||||
self.assertEqual(shapes['G'], [10, 23, 9, 2, 10])
|
||||
|
||||
def testConcatInt32(self):
|
||||
net = core.Net("concat")
|
||||
|
||||
net.Concat(["A", "B"], ["C", "splits"], axis=1)
|
||||
net.Concat(["C", "D"], ["E"], order="NCHW")
|
||||
net.Concat(["E", "F"], ["G"], add_axis=1, order="NHWC")
|
||||
(shapes, types) = workspace.InferShapesAndTypes(
|
||||
[net],
|
||||
blob_dimensions={
|
||||
'A': [10, 12, 9, 10],
|
||||
'B': [10, 9, 9, 10],
|
||||
'D': [10, 2, 9, 10],
|
||||
'F': [10, 23, 9, 10]
|
||||
},
|
||||
blob_types={
|
||||
'A': core.DataType.INT32,
|
||||
'B': core.DataType.INT32,
|
||||
'D': core.DataType.INT32,
|
||||
'F': core.DataType.INT32,
|
||||
}
|
||||
)
|
||||
self.assertEqual(shapes['C'], [10, 21, 9, 10])
|
||||
self.assertEqual(shapes['splits'], [2])
|
||||
self.assertEqual(shapes['E'], [10, 23, 9, 10])
|
||||
self.assertEqual(shapes['G'], [10, 23, 9, 2, 10])
|
||||
self.assertEqual(types['C'], core.DataType.INT32)
|
||||
self.assertEqual(types['splits'], core.DataType.INT32)
|
||||
self.assertEqual(types['E'], core.DataType.INT32)
|
||||
self.assertEqual(types['G'], core.DataType.INT32)
|
||||
|
||||
def testSqueeze(self):
|
||||
net = core.Net("sq")
|
||||
net.Squeeze(["data"], ["data_squeezed"], dims=[3, 1])
|
||||
|
|
|
|||
|
|
@ -1354,6 +1354,33 @@ void addGlobalMethods(py::module& m) {
|
|||
|
||||
auto blob_info = InferBlobShapesAndTypesFromMap(blob_dimensions, nets_ptr);
|
||||
|
||||
std::string protob;
|
||||
CAFFE_ENFORCE(blob_info.SerializeToString(&protob));
|
||||
return py::bytes(protob);
|
||||
});
|
||||
m.def(
|
||||
"infer_shapes_and_types_from_map",
|
||||
[](const std::vector<py::bytes>& net_protos,
|
||||
const std::map<std::string, std::vector<TIndex>> blob_dimensions,
|
||||
const std::map<std::string, int> int_blob_types) {
|
||||
// Parse protobuffers to NetDefs
|
||||
std::vector<std::unique_ptr<caffe2::NetDef>> nets;
|
||||
std::vector<caffe2::NetDef*> nets_ptr;
|
||||
for (auto proto : net_protos) {
|
||||
std::unique_ptr<NetDef> def(new NetDef());
|
||||
CAFFE_ENFORCE(def->ParseFromString(proto));
|
||||
nets_ptr.push_back(def.get());
|
||||
nets.push_back(std::move(def));
|
||||
}
|
||||
std::map<std::string, TensorProto_DataType> blob_types;
|
||||
for (auto blob_type : int_blob_types) {
|
||||
blob_types[blob_type.first] =
|
||||
static_cast<TensorProto_DataType>(blob_type.second);
|
||||
}
|
||||
|
||||
auto blob_info = InferBlobShapesAndTypesFromMap(
|
||||
blob_dimensions, blob_types, nets_ptr);
|
||||
|
||||
std::string protob;
|
||||
CAFFE_ENFORCE(blob_info.SerializeToString(&protob));
|
||||
return py::bytes(protob);
|
||||
|
|
|
|||
|
|
@ -236,7 +236,8 @@ def RunPlanInBackground(plan_or_step):
|
|||
return C.run_plan_in_background(StringifyProto(plan_or_step))
|
||||
|
||||
|
||||
def InferShapesAndTypes(nets, blob_dimensions=None, nets_proto=False):
|
||||
def InferShapesAndTypes(nets, blob_dimensions=None, nets_proto=False,
|
||||
blob_types=None):
|
||||
"""Infers the shapes and types for the specified nets.
|
||||
|
||||
Inputs:
|
||||
|
|
@ -253,11 +254,16 @@ def InferShapesAndTypes(nets, blob_dimensions=None, nets_proto=False):
|
|||
else:
|
||||
net_protos = [StringifyProto(n.Proto()) for n in nets]
|
||||
if blob_dimensions is None:
|
||||
assert blob_types is None
|
||||
blobdesc_prototxt = C.infer_shapes_and_types_from_workspace(net_protos)
|
||||
else:
|
||||
elif blob_types is None:
|
||||
blobdesc_prototxt = C.infer_shapes_and_types_from_map(
|
||||
net_protos, blob_dimensions
|
||||
)
|
||||
else:
|
||||
blobdesc_prototxt = C.infer_shapes_and_types_from_map(
|
||||
net_protos, blob_dimensions, blob_types
|
||||
)
|
||||
blobdesc_proto = caffe2_pb2.TensorShapes()
|
||||
blobdesc_proto.ParseFromString(blobdesc_prototxt)
|
||||
shapes = {}
|
||||
|
|
|
|||
Loading…
Reference in a new issue