mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Forward blobs into workspace
Summary: Better isolation for workspaces to allow forwarding selected blobs from parent to child workspace, possibly under new names. Used for proper isolation of subnets (loops, then/else branhes, etc) from outer workspace. Reviewed By: azzolini Differential Revision: D5681667 fbshipit-source-id: e61a2c7c98ee2abf1f0761905f4bfae47c201c32
This commit is contained in:
parent
502b43641f
commit
67a55b81e3
2 changed files with 71 additions and 30 deletions
|
|
@ -74,6 +74,7 @@ void Workspace::PrintBlobSizes() {
|
|||
|
||||
vector<string> Workspace::LocalBlobs() const {
|
||||
vector<string> names;
|
||||
names.reserve(blob_map_.size());
|
||||
for (auto& entry : blob_map_) {
|
||||
names.push_back(entry.first);
|
||||
}
|
||||
|
|
@ -82,12 +83,20 @@ vector<string> Workspace::LocalBlobs() const {
|
|||
|
||||
vector<string> Workspace::Blobs() const {
|
||||
vector<string> names;
|
||||
names.reserve(blob_map_.size());
|
||||
for (auto& entry : blob_map_) {
|
||||
names.push_back(entry.first);
|
||||
}
|
||||
if (shared_) {
|
||||
vector<string> shared_blobs = shared_->Blobs();
|
||||
names.insert(names.end(), shared_blobs.begin(), shared_blobs.end());
|
||||
for (const auto& forwarded : forwarded_blobs_) {
|
||||
if (shared_->HasBlob(forwarded.second)) {
|
||||
names.push_back(forwarded.first);
|
||||
}
|
||||
}
|
||||
if (blob_inheritance_) {
|
||||
const auto& shared_blobs = shared_->Blobs();
|
||||
names.insert(names.end(), shared_blobs.begin(), shared_blobs.end());
|
||||
}
|
||||
}
|
||||
return names;
|
||||
}
|
||||
|
|
@ -95,6 +104,10 @@ vector<string> Workspace::Blobs() const {
|
|||
Blob* Workspace::CreateBlob(const string& name) {
|
||||
if (HasBlob(name)) {
|
||||
VLOG(1) << "Blob " << name << " already exists. Skipping.";
|
||||
} else if (forwarded_blobs_.count(name)) {
|
||||
// possible if parent workspace deletes forwarded blob
|
||||
VLOG(1) << "Blob " << name << " is already forwarded from parent workspace "
|
||||
<< "(blob " << forwarded_blobs_[name] << "). Skipping.";
|
||||
} else {
|
||||
VLOG(1) << "Creating blob " << name;
|
||||
blob_map_[name] = unique_ptr<Blob>(new Blob());
|
||||
|
|
@ -110,7 +123,7 @@ bool Workspace::RemoveBlob(const string& name) {
|
|||
return true;
|
||||
}
|
||||
|
||||
// won't go into share_ here
|
||||
// won't go into shared_ here
|
||||
VLOG(1) << "Blob " << name << " not exists. Skipping.";
|
||||
return false;
|
||||
}
|
||||
|
|
@ -118,17 +131,21 @@ bool Workspace::RemoveBlob(const string& name) {
|
|||
const Blob* Workspace::GetBlob(const string& name) const {
|
||||
if (blob_map_.count(name)) {
|
||||
return blob_map_.at(name).get();
|
||||
} else if (shared_ && shared_->HasBlob(name)) {
|
||||
return shared_->GetBlob(name);
|
||||
} else {
|
||||
LOG(WARNING) << "Blob " << name << " not in the workspace.";
|
||||
// TODO(Yangqing): do we want to always print out the list of blobs here?
|
||||
// LOG(WARNING) << "Current blobs:";
|
||||
// for (const auto& entry : blob_map_) {
|
||||
// LOG(WARNING) << entry.first;
|
||||
// }
|
||||
return nullptr;
|
||||
} else if (shared_) {
|
||||
if (forwarded_blobs_.count(name)) {
|
||||
return shared_->GetBlob(forwarded_blobs_.at(name));
|
||||
}
|
||||
if (blob_inheritance_ && shared_->HasBlob(name)) {
|
||||
return shared_->GetBlob(name);
|
||||
}
|
||||
}
|
||||
LOG(WARNING) << "Blob " << name << " not in the workspace.";
|
||||
// TODO(Yangqing): do we want to always print out the list of blobs here?
|
||||
// LOG(WARNING) << "Current blobs:";
|
||||
// for (const auto& entry : blob_map_) {
|
||||
// LOG(WARNING) << entry.first;
|
||||
// }
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Blob* Workspace::GetBlob(const string& name) {
|
||||
|
|
|
|||
|
|
@ -44,7 +44,6 @@ struct StopOnSignal {
|
|||
std::shared_ptr<SignalHandler> handler_;
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* Workspace is a class that holds all the related objects created during
|
||||
* runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of
|
||||
|
|
@ -58,8 +57,8 @@ class Workspace {
|
|||
/**
|
||||
* Initializes an empty workspace.
|
||||
*/
|
||||
Workspace() {
|
||||
}
|
||||
Workspace() : root_folder_("."), shared_(nullptr), blob_inheritance_(false) {}
|
||||
|
||||
/**
|
||||
* Initializes an empty workspace with the given root folder.
|
||||
*
|
||||
|
|
@ -68,7 +67,8 @@ class Workspace {
|
|||
* by the workspace.
|
||||
*/
|
||||
explicit Workspace(const string& root_folder)
|
||||
: root_folder_(root_folder) {}
|
||||
: root_folder_(root_folder), shared_(nullptr), blob_inheritance_(false) {}
|
||||
|
||||
/**
|
||||
* Initializes a workspace with a shared workspace.
|
||||
*
|
||||
|
|
@ -79,26 +79,37 @@ class Workspace {
|
|||
* created workspace.
|
||||
*/
|
||||
explicit Workspace(Workspace* const shared)
|
||||
: shared_(shared) {}
|
||||
: root_folder_("."), shared_(shared), blob_inheritance_(true) {}
|
||||
|
||||
/**
|
||||
* Initializes workspace with parent workspace, blob name remapping
|
||||
* (new name -> parent blob name), no other blobs are inherited from
|
||||
* parent workspace
|
||||
*/
|
||||
Workspace(
|
||||
Workspace* const shared,
|
||||
const std::unordered_map<string, string>& forwarded_blobs)
|
||||
: root_folder_("."), shared_(shared), blob_inheritance_(false) {
|
||||
CAFFE_ENFORCE(shared_, "Parent workspace must be specified");
|
||||
for (const auto& forwarded : forwarded_blobs) {
|
||||
CAFFE_ENFORCE(
|
||||
shared_->HasBlob(forwarded.second), "Invalid parent workspace blob");
|
||||
}
|
||||
forwarded_blobs_ = forwarded_blobs; // copy
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a workspace with a root folder and a shared workspace.
|
||||
*/
|
||||
Workspace(const string& root_folder, Workspace* shared)
|
||||
: root_folder_(root_folder), shared_(shared) {}
|
||||
: root_folder_(root_folder), shared_(shared), blob_inheritance_(true) {}
|
||||
|
||||
~Workspace() {
|
||||
if (FLAGS_caffe2_print_blob_sizes_at_exit) {
|
||||
PrintBlobSizes();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Allows to add a parent workspace post factum after the object
|
||||
* was already constructed.
|
||||
*/
|
||||
void SetParentWorkspace(Workspace* shared) {
|
||||
shared_ = shared;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return list of blobs owned by this Workspace, not including blobs
|
||||
* shared from parent workspace.
|
||||
|
|
@ -120,7 +131,18 @@ class Workspace {
|
|||
* Checks if a blob with the given name is present in the current workspace.
|
||||
*/
|
||||
inline bool HasBlob(const string& name) const {
|
||||
return (blob_map_.count(name) || (shared_ && shared_->HasBlob(name)));
|
||||
// First, check the local workspace,
|
||||
// Then, check the forwarding map, then the parent workspace
|
||||
if (blob_map_.count(name)) {
|
||||
return true;
|
||||
}
|
||||
if (shared_) {
|
||||
if (forwarded_blobs_.count(name)) {
|
||||
return shared_->HasBlob(forwarded_blobs_.at(name));
|
||||
}
|
||||
return blob_inheritance_ && shared_->HasBlob(name);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void PrintBlobSizes();
|
||||
|
|
@ -217,8 +239,10 @@ class Workspace {
|
|||
private:
|
||||
BlobMap blob_map_;
|
||||
NetMap net_map_;
|
||||
string root_folder_ = ".";
|
||||
Workspace* shared_ = nullptr;
|
||||
const string root_folder_;
|
||||
const Workspace* shared_;
|
||||
std::unordered_map<string, string> forwarded_blobs_;
|
||||
const bool blob_inheritance_;
|
||||
#if CAFFE2_MOBILE
|
||||
std::unique_ptr<ThreadPool> thread_pool_;
|
||||
std::mutex thread_pool_creation_mutex_;
|
||||
|
|
|
|||
Loading…
Reference in a new issue