Forward blobs into workspace

Summary:
Better isolation for workspaces to allow forwarding selected blobs
from parent to child workspace, possibly under new names. Used for proper
isolation of subnets (loops, then/else branhes, etc) from outer workspace.

Reviewed By: azzolini

Differential Revision: D5681667

fbshipit-source-id: e61a2c7c98ee2abf1f0761905f4bfae47c201c32
This commit is contained in:
Ilia Cherniavskii 2017-08-22 18:39:36 -07:00 committed by Facebook Github Bot
parent 502b43641f
commit 67a55b81e3
2 changed files with 71 additions and 30 deletions

View file

@ -74,6 +74,7 @@ void Workspace::PrintBlobSizes() {
vector<string> Workspace::LocalBlobs() const {
vector<string> names;
names.reserve(blob_map_.size());
for (auto& entry : blob_map_) {
names.push_back(entry.first);
}
@ -82,12 +83,20 @@ vector<string> Workspace::LocalBlobs() const {
vector<string> Workspace::Blobs() const {
vector<string> names;
names.reserve(blob_map_.size());
for (auto& entry : blob_map_) {
names.push_back(entry.first);
}
if (shared_) {
vector<string> shared_blobs = shared_->Blobs();
names.insert(names.end(), shared_blobs.begin(), shared_blobs.end());
for (const auto& forwarded : forwarded_blobs_) {
if (shared_->HasBlob(forwarded.second)) {
names.push_back(forwarded.first);
}
}
if (blob_inheritance_) {
const auto& shared_blobs = shared_->Blobs();
names.insert(names.end(), shared_blobs.begin(), shared_blobs.end());
}
}
return names;
}
@ -95,6 +104,10 @@ vector<string> Workspace::Blobs() const {
Blob* Workspace::CreateBlob(const string& name) {
if (HasBlob(name)) {
VLOG(1) << "Blob " << name << " already exists. Skipping.";
} else if (forwarded_blobs_.count(name)) {
// possible if parent workspace deletes forwarded blob
VLOG(1) << "Blob " << name << " is already forwarded from parent workspace "
<< "(blob " << forwarded_blobs_[name] << "). Skipping.";
} else {
VLOG(1) << "Creating blob " << name;
blob_map_[name] = unique_ptr<Blob>(new Blob());
@ -110,7 +123,7 @@ bool Workspace::RemoveBlob(const string& name) {
return true;
}
// won't go into share_ here
// won't go into shared_ here
VLOG(1) << "Blob " << name << " not exists. Skipping.";
return false;
}
@ -118,17 +131,21 @@ bool Workspace::RemoveBlob(const string& name) {
const Blob* Workspace::GetBlob(const string& name) const {
if (blob_map_.count(name)) {
return blob_map_.at(name).get();
} else if (shared_ && shared_->HasBlob(name)) {
return shared_->GetBlob(name);
} else {
LOG(WARNING) << "Blob " << name << " not in the workspace.";
// TODO(Yangqing): do we want to always print out the list of blobs here?
// LOG(WARNING) << "Current blobs:";
// for (const auto& entry : blob_map_) {
// LOG(WARNING) << entry.first;
// }
return nullptr;
} else if (shared_) {
if (forwarded_blobs_.count(name)) {
return shared_->GetBlob(forwarded_blobs_.at(name));
}
if (blob_inheritance_ && shared_->HasBlob(name)) {
return shared_->GetBlob(name);
}
}
LOG(WARNING) << "Blob " << name << " not in the workspace.";
// TODO(Yangqing): do we want to always print out the list of blobs here?
// LOG(WARNING) << "Current blobs:";
// for (const auto& entry : blob_map_) {
// LOG(WARNING) << entry.first;
// }
return nullptr;
}
Blob* Workspace::GetBlob(const string& name) {

View file

@ -44,7 +44,6 @@ struct StopOnSignal {
std::shared_ptr<SignalHandler> handler_;
};
/**
* Workspace is a class that holds all the related objects created during
* runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of
@ -58,8 +57,8 @@ class Workspace {
/**
* Initializes an empty workspace.
*/
Workspace() {
}
Workspace() : root_folder_("."), shared_(nullptr), blob_inheritance_(false) {}
/**
* Initializes an empty workspace with the given root folder.
*
@ -68,7 +67,8 @@ class Workspace {
* by the workspace.
*/
explicit Workspace(const string& root_folder)
: root_folder_(root_folder) {}
: root_folder_(root_folder), shared_(nullptr), blob_inheritance_(false) {}
/**
* Initializes a workspace with a shared workspace.
*
@ -79,26 +79,37 @@ class Workspace {
* created workspace.
*/
explicit Workspace(Workspace* const shared)
: shared_(shared) {}
: root_folder_("."), shared_(shared), blob_inheritance_(true) {}
/**
* Initializes workspace with parent workspace, blob name remapping
* (new name -> parent blob name), no other blobs are inherited from
* parent workspace
*/
Workspace(
Workspace* const shared,
const std::unordered_map<string, string>& forwarded_blobs)
: root_folder_("."), shared_(shared), blob_inheritance_(false) {
CAFFE_ENFORCE(shared_, "Parent workspace must be specified");
for (const auto& forwarded : forwarded_blobs) {
CAFFE_ENFORCE(
shared_->HasBlob(forwarded.second), "Invalid parent workspace blob");
}
forwarded_blobs_ = forwarded_blobs; // copy
}
/**
* Initializes a workspace with a root folder and a shared workspace.
*/
Workspace(const string& root_folder, Workspace* shared)
: root_folder_(root_folder), shared_(shared) {}
: root_folder_(root_folder), shared_(shared), blob_inheritance_(true) {}
~Workspace() {
if (FLAGS_caffe2_print_blob_sizes_at_exit) {
PrintBlobSizes();
}
}
/**
* Allows to add a parent workspace post factum after the object
* was already constructed.
*/
void SetParentWorkspace(Workspace* shared) {
shared_ = shared;
}
/**
* Return list of blobs owned by this Workspace, not including blobs
* shared from parent workspace.
@ -120,7 +131,18 @@ class Workspace {
* Checks if a blob with the given name is present in the current workspace.
*/
inline bool HasBlob(const string& name) const {
return (blob_map_.count(name) || (shared_ && shared_->HasBlob(name)));
// First, check the local workspace,
// Then, check the forwarding map, then the parent workspace
if (blob_map_.count(name)) {
return true;
}
if (shared_) {
if (forwarded_blobs_.count(name)) {
return shared_->HasBlob(forwarded_blobs_.at(name));
}
return blob_inheritance_ && shared_->HasBlob(name);
}
return false;
}
void PrintBlobSizes();
@ -217,8 +239,10 @@ class Workspace {
private:
BlobMap blob_map_;
NetMap net_map_;
string root_folder_ = ".";
Workspace* shared_ = nullptr;
const string root_folder_;
const Workspace* shared_;
std::unordered_map<string, string> forwarded_blobs_;
const bool blob_inheritance_;
#if CAFFE2_MOBILE
std::unique_ptr<ThreadPool> thread_pool_;
std::mutex thread_pool_creation_mutex_;