diff --git a/caffe2/core/net_async_task.cc b/caffe2/core/net_async_task.cc
new file mode 100644
index 00000000000..1496779a0b2
--- /dev/null
+++ b/caffe2/core/net_async_task.cc
@@ -0,0 +1,121 @@
+#include "caffe2/core/net_async_task.h"
+
+#include "caffe2/core/net_async_task_graph.h"
+
+namespace caffe2 {
+
+// Builds a task from a non-empty chain of ops. Every op must be on the same
+// device as the first one; op events and the task future start out reset.
+AsyncTask::AsyncTask(const std::vector<OperatorBase*>& ops) : ops_(ops) {
+  CAFFE_ENFORCE(!ops_.empty());
+  device_option_ = ops_.front()->device_option();
+  for (auto& op : ops_) {
+    CAFFE_ENFORCE(IsSameDevice(device_option_, op->device_option()));
+  }
+  Reset();
+}
+
+// Reports a failure of the chain: logs the message, finishes the event of the
+// last op with it (through SetFinishedWithException when save_exception is
+// set, SetFinished otherwise) and completes the task future with the same
+// message. op may be null, in which case no op name is appended.
+void AsyncTask::handleChainError(
+    OperatorBase* op,
+    const char* err_str,
+    bool save_exception) {
+  std::string err_msg = err_str;
+  if (op) {
+    err_msg += ", op " + (op->has_debug_def() ? op->type() : " unknown");
+  }
+  LOG(ERROR) << err_msg;
+
+  // save error message and exception in chain's Event
+  auto last_op = ops_.back();
+  if (save_exception) {
+    last_op->event().SetFinishedWithException(err_msg.c_str());
+  } else {
+    last_op->event().SetFinished(err_msg.c_str());
+  }
+
+  // set future as completed with an error
+  // TODO: exceptions in future
+  future_.SetCompleted(err_msg.c_str());
+}
+
+// Runs the ops of the chain in order. Returns false (after reporting through
+// handleChainError) if RunAsync of an op returns false or anything throws;
+// otherwise returns true and the future is completed either right away or,
+// for a CPU chain whose last op has an async part, from that op's callback.
+bool AsyncTask::Run(const ExecutionOptions& options) {
+  // TODO: insert CUDA's async stream waits; tracing and counters
+  // tracks the op being processed so that error messages can name it
+  OperatorBase* op = nullptr;
+  try {
+    for (auto* chain_op : ops_) {
+      op = chain_op;
+      int stream_id = 0; // TODO: thread local stream id
+      if (!op->RunAsync(stream_id)) {
+        handleChainError(op, "Failed to execute an op");
+        return false;
+      }
+    }
+
+    if (options.finish_chain_) {
+      op = ops_.back();
+      op->Finish();
+    }
+
+    // set the future as successfully completed or, in case of async CPU,
+    // use op's callback
+    if (IsCPUDeviceType(device_option_.device_type()) &&
+        ops_.back()->HasAsyncPart()) {
+      auto& event = ops_.back()->event();
+      event.SetCallback([this, &event]() {
+        CAFFE_ENFORCE(event.IsFinished());
+        if (event.Query() == EventStatus::EVENT_SUCCESS) {
+          future_.SetCompleted();
+        } else {
+          // TODO: support for exceptions
+          future_.SetCompleted(event.ErrorMessage().c_str());
+        }
+      });
+    } else {
+      future_.SetCompleted();
+    }
+  } catch (const std::exception& e) {
+    handleChainError(op, e.what(), /* save_exception */ true);
+    return false;
+  } catch (...) {
+    handleChainError(
+        op,
+        "Failed to execute task: unknown error",
+        /* save_exception */ true);
+    return false;
+  }
+
+  return true;
+}
+
+// Prepares the task for another run: resets every op's event and the future.
+void AsyncTask::Reset() {
+  for (auto& op : ops_) {
+    op->ResetEvent();
+  }
+  future_.ResetState();
+}
+
+// Device option of the chain, taken from its first op.
+DeviceOption AsyncTask::GetDeviceOption() const {
+  return device_option_;
+}
+
+// Future that is completed when the chain finishes or fails.
+AsyncTaskFuture& AsyncTask::GetFuture() {
+  return future_;
+}
+
+const AsyncTaskFuture& AsyncTask::GetFuture() const {
+  return future_;
+}
+
+} // namespace caffe2
diff --git a/caffe2/core/net_async_task.h b/caffe2/core/net_async_task.h
new file mode 100644
index 00000000000..e9aa699854e
--- /dev/null
+++ b/caffe2/core/net_async_task.h
@@ -0,0 +1,41 @@
+#ifndef CAFFE2_NET_ASYNC_TASK_H
+#define CAFFE2_NET_ASYNC_TASK_H
+
+#include "caffe2/core/net_async_base.h"
+#include "caffe2/core/net_async_task_future.h"
+#include "caffe2/core/operator.h"
+
+#include <vector>
+
+namespace caffe2 {
+
+// AsyncTask represents an asynchronous execution of a chain of ops.
+// All ops of a chain are on the same device; completion and errors of a run
+// are reported through the task's future.
+class AsyncTask {
+ public:
+  AsyncTask(const std::vector<OperatorBase*>& ops);
+  // Runs the chain of ops; returns false if an op failed or threw.
+  bool Run(const ExecutionOptions& options);
+  // Resets the op events and the future so that the task can run again.
+  void Reset();
+  // Device option of the chain, taken from its first op.
+  DeviceOption GetDeviceOption() const;
+  // Future tracking the completion (or failure) of the chain.
+  AsyncTaskFuture& GetFuture();
+  const AsyncTaskFuture& GetFuture() const;
+
+ private:
+  void handleChainError(
+      OperatorBase* op,
+      const char* err_msg,
+      bool save_exception = false); // logs, fails last op's event and future
+
+  std::vector<OperatorBase*> ops_;
+  DeviceOption device_option_;
+  AsyncTaskFuture future_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_NET_ASYNC_TASK_H
diff --git a/caffe2/core/net_async_task_future.cc b/caffe2/core/net_async_task_future.cc
new file mode 100644
index 00000000000..da33d21fb1e
--- /dev/null
+++ b/caffe2/core/net_async_task_future.cc
@@ -0,0 +1,119 @@
+#include "caffe2/core/net_async_task_future.h"
+
+#include "c10/util/Logging.h"
+#include "caffe2/core/common.h"
+
+namespace caffe2 {
+
+// Creates a standalone future that is neither completed nor failed.
+AsyncTaskFuture::AsyncTaskFuture() : completed_(false), failed_(false) {}
+
+// Creates a future that is completed once all of the given futures are
+// completed, by registering a callback (capturing `this`) on each of them.
+// The new future fails if any of the given futures failed; with several
+// parents the error messages of the failed ones are joined with ", ".
+// At least one future must be given.
+AsyncTaskFuture::AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures)
+    : completed_(false), failed_(false) {
+  if (futures.size() > 1) {
+    parent_counter_ = caffe2::make_unique<ParentCounter>(futures.size());
+    for (auto* future : futures) {
+      future->SetCallback([this](const AsyncTaskFuture* f) {
+        if (f->IsFailed()) {
+          // err_mutex guards the error message aggregated from the parents
+          std::unique_lock<std::mutex> lock(parent_counter_->err_mutex);
+          if (parent_counter_->parent_failed) {
+            parent_counter_->err_msg += ", " + f->ErrorMessage();
+          } else {
+            parent_counter_->parent_failed = true;
+            parent_counter_->err_msg = f->ErrorMessage();
+          }
+        }
+        // the parent that brings the counter to zero completes this future
+        int count = --parent_counter_->parent_count;
+        if (count == 0) {
+          // thread safe to use parent_counter here
+          if (!parent_counter_->parent_failed) {
+            SetCompleted();
+          } else {
+            SetCompleted(parent_counter_->err_msg.c_str());
+          }
+        }
+      });
+    }
+  } else {
+    // single parent: mirror its outcome, no counter is needed
+    CAFFE_ENFORCE_EQ(futures.size(), 1);
+    auto future = futures.back();
+    future->SetCallback([this](const AsyncTaskFuture* f) {
+      if (!f->IsFailed()) {
+        SetCompleted();
+      } else {
+        SetCompleted(f->ErrorMessage().c_str());
+      }
+    });
+  }
+}
+
+bool AsyncTaskFuture::IsCompleted() const { // true once SetCompleted ran
+  return completed_;
+}
+// True if the future was completed with an error message.
+bool AsyncTaskFuture::IsFailed() const {
+  return failed_;
+}
+// Error message given to SetCompleted; empty when the future has not failed.
+std::string AsyncTaskFuture::ErrorMessage() const {
+  return err_msg_;
+}
+// Blocks the calling thread until the future is completed.
+void AsyncTaskFuture::Wait() const {
+  std::unique_lock<std::mutex> lock(mutex_);
+  while (!completed_) {
+    cv_completed_.wait(lock);
+  }
+}
+// Registers a completion callback; it runs right away if already completed.
+void AsyncTaskFuture::SetCallback(
+    std::function<void(const AsyncTaskFuture*)> callback) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  // stored even if fired now: callbacks are kept across ResetState
+  callbacks_.push_back(callback);
+  if (completed_) {
+    callback(this);
+  }
+}
+// Completes the future (as failed if err_msg is given) and runs the callbacks.
+void AsyncTaskFuture::SetCompleted(const char* err_msg) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  // a future can be completed only once between resets
+  CAFFE_ENFORCE(!completed_, "Calling SetCompleted on a completed future");
+  completed_ = true;
+
+  if (err_msg) {
+    failed_ = true;
+    err_msg_ = err_msg;
+  }
+  // run under mutex_: callbacks must not lock this future again (e.g. Wait)
+  for (auto& callback : callbacks_) {
+    callback(this);
+  }
+  // waiters are woken up only after every callback has returned
+  cv_completed_.notify_all();
+}
+
+// ResetState is called on a completed future,
+// does not reset callbacks to keep task graph structure
+void AsyncTaskFuture::ResetState() {
+  std::unique_lock<std::mutex> lock(mutex_);
+  if (parent_counter_) {
+    parent_counter_->Reset();
+  }
+  completed_ = false;
+  failed_ = false;
+  err_msg_ = "";
+}
+
+AsyncTaskFuture::~AsyncTaskFuture() {}
+
+} // namespace caffe2
diff --git a/caffe2/core/net_async_task_future.h b/caffe2/core/net_async_task_future.h
new file mode 100644
index 00000000000..167b4b29c52
--- /dev/null
+++ b/caffe2/core/net_async_task_future.h
@@ -0,0 +1,80 @@
+#ifndef CAFFE2_NET_ASYNC_TASK_FUTURE_H
+#define CAFFE2_NET_ASYNC_TASK_FUTURE_H
+
+// NOTE(review): the #include targets below were lost in extraction and are
+// restored from the std types used in this header; confirm against upstream.
+#include <atomic>
+#include <condition_variable>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <vector>
+
+namespace caffe2 {
+
+// Represents the state of AsyncTask execution, that can be queried with
+// IsCompleted/IsFailed. Callbacks are supported through SetCallback and
+// are called upon future's completion.
+// Callbacks are invoked while the future's mutex is held, so a callback must
+// not call Wait, SetCallback, SetCompleted or ResetState on the same future.
+
+class AsyncTaskFuture {
+ public:
+  AsyncTaskFuture();
+  // Creates a future completed when all given futures are completed
+  explicit AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures);
+  ~AsyncTaskFuture();
+  // Futures are neither copyable nor assignable.
+  AsyncTaskFuture(const AsyncTaskFuture&) = delete;
+
+  AsyncTaskFuture& operator=(const AsyncTaskFuture&) = delete;
+  // True once SetCompleted has been called, until the next ResetState.
+  bool IsCompleted() const;
+  // True if the future was completed with an error message.
+  bool IsFailed() const;
+  // Error message given to SetCompleted; empty if the future has not failed.
+  std::string ErrorMessage() const;
+  // Blocks until the future is completed.
+  void Wait() const;
+  // Registers a completion callback; runs it at once if already completed.
+  void SetCallback(std::function<void(const AsyncTaskFuture*)> callback);
+  // Completes the future, as failed when err_msg is non-null.
+  void SetCompleted(const char* err_msg = nullptr);
+  // Clears completion state and error; registered callbacks are kept.
+  void ResetState();
+
+ private:
+  mutable std::mutex mutex_;
+  mutable std::condition_variable cv_completed_;
+  std::atomic<bool> completed_;
+  std::atomic<bool> failed_;
+  std::string err_msg_;
+  std::vector<std::function<void(const AsyncTaskFuture*)>> callbacks_;
+  // Counts the not yet completed parents of an aggregating future.
+  struct ParentCounter {
+    explicit ParentCounter(int init_parent_count)
+        : init_parent_count_(init_parent_count),
+          parent_count(init_parent_count),
+          parent_failed(false) {}
+    // Restores the initial count and clears the aggregated error.
+    void Reset() {
+      std::unique_lock<std::mutex> lock(err_mutex);
+      parent_count = init_parent_count_;
+      parent_failed = false;
+      err_msg = "";
+    }
+
+    const int init_parent_count_;
+    std::atomic<int> parent_count;
+    std::mutex err_mutex;
+    std::atomic<bool> parent_failed;
+    std::string err_msg;
+  };
+  // Set only for futures created from more than one parent future.
+  std::unique_ptr<ParentCounter> parent_counter_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_NET_ASYNC_TASK_FUTURE_H
diff --git a/caffe2/core/net_async_task_graph.cc b/caffe2/core/net_async_task_graph.cc
new file mode 100644
index 00000000000..c2732ee97c0
--- /dev/null
+++ b/caffe2/core/net_async_task_graph.cc
@@ -0,0 +1,158 @@
+#include "caffe2/core/net_async_task_graph.h"
+
+#include "caffe2/core/net_parallel.h"
+
+namespace caffe2 {
+
+// Creates an empty, not yet frozen graph. helper gives access to the
+// executor's thread pools; options are copied and used for every task run.
+AsyncTaskGraph::AsyncTaskGraph(
+    ExecutorHelper* helper,
+    const ExecutionOptions& options)
+    : helper_(helper), options_(options), frozen_(false) {}
+
+// Adds a node that runs the given chain of ops. Returns false and leaves the
+// graph unchanged if node_id is already taken. Only allowed before freezing.
+bool AsyncTaskGraph::CreateNode(
+    int node_id,
+    const std::vector<OperatorBase*>& ops) {
+  CAFFE_ENFORCE(!frozen_);
+  if (!nodes_.count(node_id)) {
+    nodes_[node_id] = caffe2::make_unique<AsyncTask>(ops);
+    return true;
+  } else {
+    return false;
+  }
+}
+
+// Makes the child node run once all the parent nodes have completed. The
+// nodes must exist and the parents of a child can be set only once. If a
+// parent fails, the child is skipped and its future fails with the parent's
+// error message. Only allowed before freezing. Returns true; violations are
+// reported through CAFFE_ENFORCE.
+bool AsyncTaskGraph::AddDependency(
+    int child_node_id,
+    const std::vector<int>& parent_node_ids) {
+  CAFFE_ENFORCE(!frozen_);
+  CAFFE_ENFORCE(!parent_node_ids.empty());
+  CAFFE_ENFORCE(nodes_.count(child_node_id));
+  for (auto node_id : parent_node_ids) {
+    CAFFE_ENFORCE(nodes_.count(node_id));
+  }
+  CAFFE_ENFORCE(!parents_.count(child_node_id));
+
+  auto* child_task = nodes_[child_node_id].get();
+  auto child_device = child_task->GetDeviceOption();
+
+  std::vector<AsyncTaskFuture*> parent_futures;
+  for (auto node_id : parent_node_ids) {
+    parents_[child_node_id].insert(node_id);
+    children_[node_id].insert(child_node_id);
+    parent_futures.push_back(&nodes_[node_id]->GetFuture());
+  }
+
+  // several parents are joined by an aggregating future owned by the graph
+  AsyncTaskFuture* parents_future = nullptr;
+  if (parent_futures.size() > 1) {
+    edge_futures_.push_back(
+        caffe2::make_unique<AsyncTaskFuture>(parent_futures));
+    parents_future = edge_futures_.back().get();
+  } else {
+    CAFFE_ENFORCE_EQ(parent_futures.size(), 1);
+    parents_future = parent_futures.back();
+  }
+
+  // TODO: CUDA polling
+  parents_future->SetCallback(
+      [this, child_task, child_device](const AsyncTaskFuture* f) {
+        CAFFE_ENFORCE(f->IsCompleted());
+        if (!f->IsFailed()) {
+          // if we're in the correct thread pool and DFS scheduling is enabled,
+          // immediately call task inline, otherwise send task into thread pool
+          auto* pool = helper_->GetPool(child_device);
+          if (pool->inThreadPool() && options_.use_dfs_scheduling_) {
+            child_task->Run(options_);
+          } else {
+            pool->run([this, child_task]() { child_task->Run(options_); });
+          }
+        } else {
+          // skip task execution and propagate error further
+          child_task->GetFuture().SetCompleted(f->ErrorMessage().c_str());
+        }
+      });
+
+  return true;
+}
+
+// Finalizes the graph: collects the root tasks (nodes without parents) and
+// creates the run future from the futures of the nodes without children.
+// Does nothing if the graph is already frozen.
+void AsyncTaskGraph::FreezeGraph() {
+  if (frozen_) {
+    return;
+  }
+
+  CAFFE_ENFORCE(!run_future_);
+  CAFFE_ENFORCE(root_tasks_.empty());
+
+  std::vector<AsyncTaskFuture*> final_futures;
+  for (auto& kv : nodes_) {
+    auto task_id = kv.first;
+    auto* task = kv.second.get();
+
+    // find() instead of operator[], which would insert an entry per node
+    const auto parents_it = parents_.find(task_id);
+    if (parents_it == parents_.end() || parents_it->second.empty()) {
+      root_tasks_.push_back(task);
+    }
+
+    const auto children_it = children_.find(task_id);
+    if (children_it == children_.end() || children_it->second.empty()) {
+      final_futures.push_back(&task->GetFuture());
+    }
+  }
+
+  CAFFE_ENFORCE(!root_tasks_.empty());
+  CAFFE_ENFORCE(!final_futures.empty());
+
+  run_future_ = caffe2::make_unique<AsyncTaskFuture>(final_futures);
+
+  frozen_ = true;
+}
+
+// Schedules all root tasks on the thread pools of their devices and returns
+// the run future. Requires a frozen graph whose run future is not completed.
+AsyncTaskFuture* AsyncTaskGraph::ExecuteGraph() {
+  CAFFE_ENFORCE(frozen_);
+  CAFFE_ENFORCE(run_future_ && !run_future_->IsCompleted());
+
+  // TODO: run root tasks inline in inference mode
+  for (auto* task : root_tasks_) {
+    auto task_device = task->GetDeviceOption();
+    helper_->GetPool(task_device)->run([this, task]() { task->Run(options_); });
+  }
+
+  return run_future_.get();
+}
+
+// Returns the future of the whole graph run; requires a frozen graph.
+AsyncTaskFuture* AsyncTaskGraph::GetFuture() {
+  CAFFE_ENFORCE(frozen_);
+  return run_future_.get();
+}
+
+// Resets all tasks, edge futures and the run future for the next run.
+void AsyncTaskGraph::Reset() {
+  CAFFE_ENFORCE(frozen_);
+  for (auto& kv : nodes_) {
+    kv.second->Reset();
+  }
+  for (auto& future : edge_futures_) {
+    future->ResetState();
+  }
+  if (run_future_) {
+    run_future_->ResetState();
+  }
+}
+
+} // namespace caffe2
diff --git a/caffe2/core/net_async_task_graph.h b/caffe2/core/net_async_task_graph.h
new file mode 100644
index 00000000000..8ccec923ea6
--- /dev/null
+++ b/caffe2/core/net_async_task_graph.h
@@ -0,0 +1,78 @@
+#ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H
+#define CAFFE2_NET_ASYNC_TASK_GRAPH_H
+
+#include "caffe2/core/net_async_base.h"
+#include "caffe2/core/net_async_task.h"
+#include "caffe2/core/net_async_task_future.h"
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+// AsyncTaskGraph represents an execution of a net, it owns the tasks and
+// associated futures, sets up future callbacks and propagates errors.
+// Usage steps:
+// - Adding graph nodes and edges through CreateNode/AddDependency;
+// - Freezing the graph (FreezeGraph), after the freezing a future
+//   can be obtained using GetFuture;
+// - Execution of the graph is scheduled through ExecuteGraph, after each
+//   execution Reset must be called to prepare the graph for the next run
+
+class AsyncTaskGraphBase {
+ public:
+  virtual bool CreateNode(
+      int node_id,
+      const std::vector<OperatorBase*>& ops) = 0;
+  // Makes the child node depend on the completion of the parent nodes.
+  virtual bool AddDependency(
+      int child_node_id,
+      const std::vector<int>& parent_node_ids) = 0;
+  // Finalizes the graph structure.
+  virtual void FreezeGraph() = 0;
+  // Schedules an execution of the graph and returns its future.
+  virtual AsyncTaskFuture* ExecuteGraph() = 0;
+  // Future of the graph execution, available after the freezing.
+  virtual AsyncTaskFuture* GetFuture() = 0;
+  // Prepares the graph for the next run.
+  virtual void Reset() = 0;
+
+  virtual ~AsyncTaskGraphBase() noexcept {}
+};
+
+class AsyncTaskGraph : public AsyncTaskGraphBase {
+ public:
+  AsyncTaskGraph(ExecutorHelper* helper, const ExecutionOptions& options);
+  // Returns false if a node with this id already exists.
+  bool CreateNode(int node_id, const std::vector<OperatorBase*>& ops) override;
+  // Requires existing nodes; the parents of a child can be set only once.
+  bool AddDependency(int child_node_id, const std::vector<int>& parent_node_ids)
+      override;
+  // No-op if the graph is already frozen.
+  void FreezeGraph() override;
+  // Runs the root tasks on the thread pools of their devices.
+  AsyncTaskFuture* ExecuteGraph() override;
+
+  AsyncTaskFuture* GetFuture() override;
+  // Resets all tasks and futures; requires a frozen graph.
+  void Reset() override;
+
+ private:
+  // used to, e.g., get access to executor's thread pools
+  // TODO: pass tracer and counters through ExecutorHelper
+  ExecutorHelper* helper_;
+  ExecutionOptions options_;
+  // set by FreezeGraph; nodes and edges can be added only before that
+  bool frozen_;
+  // NOTE(review): set type below restored as std::unordered_set - confirm
+  std::unordered_map<int, std::unique_ptr<AsyncTask>> nodes_;
+  std::unordered_map<int, std::unordered_set<int>> parents_;
+  std::unordered_map<int, std::unordered_set<int>> children_;
+  std::vector<std::unique_ptr<AsyncTaskFuture>> edge_futures_;
+  // tasks without parents, collected by FreezeGraph
+  std::vector<AsyncTask*> root_tasks_;
+  // completed when all the tasks without children are completed
+  std::unique_ptr<AsyncTaskFuture> run_future_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H
diff --git a/caffe2/core/net_parallel.cc b/caffe2/core/net_parallel.cc
new file mode 100644
index 00000000000..b9a9f0869f7
--- /dev/null
+++ b/caffe2/core/net_parallel.cc
@@ -0,0 +1,227 @@
+#include "caffe2/core/net_parallel.h"
+
+#include "caffe2/core/operator.h"
+
+// NOTE(review): the target of the next #include was lost in extraction;
+// <sstream> is an assumption, confirm it against the upstream file.
+#include <sstream>
+
+C10_DEFINE_string(
+    caffe2_task_graph_engine,
+    "futures",
+    "Task graph engine type used by net executor");
+
+namespace caffe2 {
+
+// Creates the operators of the net, groups them into chains and builds the
+// task graph (engine selected by the caffe2_task_graph_engine flag) with one
+// node per chain and the dependencies between the chains. The graph is frozen
+// here; when its run future completes, the observers are stopped and
+// finishRun is called.
+ParallelNet::ParallelNet(
+    const std::shared_ptr<const NetDef>& net_def,
+    Workspace* ws)
+    : NetBase(net_def, ws), options_(net_def), run_future_(nullptr) {
+  num_workers_ = net_def->num_workers();
+  CAFFE_ENFORCE_GT(
+      num_workers_, 0, "Expected positive number of worker threads");
+
+  helper_ = caffe2::make_unique<ParallelNetExecutorHelper>(this);
+  task_graph_ = TaskGraphRegistry()->Create(
+      FLAGS_caffe2_task_graph_engine, helper_.get(), options_);
+  // fail with a clear message instead of dereferencing a null graph below
+  CAFFE_ENFORCE(
+      task_graph_,
+      "Can't create task graph, unknown engine: ",
+      FLAGS_caffe2_task_graph_engine);
+
+  // initialize operators
+  operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
+  operators_.reserve(operator_nodes_.size());
+  for (const auto& node : operator_nodes_) {
+    auto op = node.operator_.get();
+    op->SetExecutorHelper(helper_.get());
+    operators_.push_back(op);
+  }
+
+  // compute chains
+  // TODO: inference mode for chaining
+  auto execution_chains = dag_utils::computeChains(operator_nodes_);
+  std::vector<std::vector<int>> chains;
+  chains.reserve(execution_chains.size());
+  for (const auto& kv : execution_chains) {
+    chains.push_back(kv.second);
+  }
+  auto chain_nodes = dag_utils::prepareChainGraphNodes(operator_nodes_, chains);
+  CAFFE_ENFORCE_EQ(chains.size(), chain_nodes.size());
+
+  // disable unused events
+  for (const auto& chain : chains) {
+    for (const auto& op_id : chain) {
+      if (op_id == chain.back() || op_id == chain.front()) {
+        continue;
+      }
+      auto op = operators_[op_id];
+      if (IsCPUDeviceType(op->device_option().device_type()) &&
+          op->HasAsyncPart()) {
+        continue;
+      }
+      op->DisableEvent();
+    }
+  }
+
+  // initialize task graph
+  // node ids are the (int) chain indices; chain_nodes has the size of chains
+  const int num_chains = static_cast<int>(chains.size());
+  for (int chain_id = 0; chain_id < num_chains; ++chain_id) {
+    std::vector<OperatorBase*> ops;
+    ops.reserve(chains[chain_id].size());
+    for (auto op_id : chains[chain_id]) {
+      ops.push_back(operators_[op_id]);
+    }
+    CAFFE_ENFORCE(task_graph_->CreateNode(chain_id, ops));
+  }
+  for (int chain_id = 0; chain_id < num_chains; ++chain_id) {
+    if (!chain_nodes[chain_id].parents_.empty()) {
+      CAFFE_ENFORCE(
+          task_graph_->AddDependency(chain_id, chain_nodes[chain_id].parents_));
+    }
+  }
+
+  // Freeze graph and initialize graph execution future
+  task_graph_->FreezeGraph();
+  run_future_ = task_graph_->GetFuture();
+  run_future_->SetCallback([this](const AsyncTaskFuture* /* unused */) {
+    StopAllObservers();
+    finishRun();
+  });
+
+  LOG(INFO) << "Initialized parallel net: '" << Name()
+            << "', #ops: " << net_def->op_size()
+            << ", #chains: " << chains.size() << ", #workers: " << num_workers_
+            << ", dfs scheduling: " << options_.use_dfs_scheduling_
+            << ", task graph engine: " << FLAGS_caffe2_task_graph_engine;
+}
+
+// Resets the task graph, starts the observers and schedules a run of the
+// graph. Returns false if scheduling throws a std::exception.
+bool ParallelNet::RunAsync() {
+  reset();
+  StartAllObservers();
+
+  try {
+    task_graph_->ExecuteGraph();
+  } catch (const std::exception&) {
+    StopAllObservers();
+    return false;
+  }
+
+  return true;
+}
+
+// Blocks until the run future of the graph is completed.
+void ParallelNet::Wait() {
+  CAFFE_ENFORCE(run_future_);
+  run_future_->Wait();
+}
+
+// Prepares the task graph for the next run.
+void ParallelNet::reset() {
+  task_graph_->Reset();
+}
+
+// Logs the error of a failed run. Returns true if the completed run
+// succeeded; must be called only after the run future is completed.
+bool ParallelNet::handleRunError() {
+  CAFFE_ENFORCE(run_future_ && run_future_->IsCompleted());
+  // TODO: throw saved exceptions
+  if (run_future_->IsFailed()) {
+    LOG(ERROR) << "Failed parallel run (" << Name()
+               << "): " << run_future_->ErrorMessage();
+  }
+  return !run_future_->IsFailed();
+}
+
+// Returns the pool stored in pools for (device_id, pool_size), creating it
+// through ThreadPoolRegistry on first use. Serialized by pools_mutex_.
+TaskThreadPoolBase* ParallelNet::poolGetter(
+    PoolsMap& pools,
+    int device_type,
+    int device_id,
+    int pool_size) {
+  std::unique_lock<std::mutex> pools_lock(pools_mutex_);
+  // reference to the map slot, so a newly created pool is stored in place
+  auto& pool = pools[device_id][pool_size];
+  if (!pool) {
+    pool = ThreadPoolRegistry()->Create(
+        DeviceTypeName(device_type),
+        device_id,
+        pool_size,
+        options_.use_per_net_pools_);
+  }
+  return pool.get();
+}
+
+// Returns the thread pool to run tasks of the given device on: a single CPU
+// pool if use_single_pool_ is set, otherwise a pool per NUMA node (-1 when
+// no node is given) for CPU devices and a pool per GPU id for GPU devices.
+// Throws on an out of range NUMA node/GPU id or an unsupported device type.
+TaskThreadPoolBase* ParallelNet::Pool(const DeviceOption& device_option) {
+  if (options_.use_single_pool_) {
+    return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
+  }
+  const auto device_type = device_option.device_type();
+  if (IsCPUDeviceType(device_type)) {
+    auto numa_node_id = -1;
+    if (device_option.has_numa_node_id()) {
+      numa_node_id = device_option.numa_node_id();
+      CAFFE_ENFORCE_GE(numa_node_id, 0, "Invalid NUMA node id: ", numa_node_id);
+    }
+    CAFFE_ENFORCE_LT(
+        numa_node_id,
+        FLAGS_caffe2_net_async_max_numa_nodes,
+        "Invalid NUMA node id: ",
+        numa_node_id);
+    return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
+  } else if (IsGPUDeviceType(device_type)) {
+    auto gpu_id = device_option.device_id();
+    CAFFE_ENFORCE(
+        gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
+        "Invalid GPU id: " + caffe2::to_string(gpu_id));
+    return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
+  } else {
+    CAFFE_THROW("Unsupported device type " + caffe2::to_string(device_type));
+  }
+}
+
+// The parallel executor always supports asynchronous runs.
+bool ParallelNet::SupportsAsync() {
+  return true;
+}
+
+// Hook called when a run of the graph completes; does nothing by default.
+void ParallelNet::finishRun() {}
+
+// Returns the operators of the net (a copy of the pointer list).
+std::vector<OperatorBase*> ParallelNet::GetOperators() const {
+  return operators_;
+}
+
+// Creator of the default ("futures") task graph engine.
+std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
+    ExecutorHelper* helper,
+    const ExecutionOptions& options) {
+  return std::make_shared<AsyncTaskGraph>(helper, options);
+}
+
+C10_DEFINE_SHARED_REGISTRY(
+    TaskGraphRegistry,
+    AsyncTaskGraphBase,
+    ExecutorHelper*,
+    const ExecutionOptions&);
+
+C10_REGISTER_CREATOR(TaskGraphRegistry, futures, GetAsyncTaskGraph);
+
+REGISTER_NET(parallel, ParallelNet);
+
+} // namespace caffe2
diff --git a/caffe2/core/net_parallel.h b/caffe2/core/net_parallel.h
new file mode 100644
index 00000000000..df5b6d8a4ef
--- /dev/null
+++ b/caffe2/core/net_parallel.h
@@ -0,0 +1,89 @@
+#ifndef CAFFE2_CORE_NET_PARALLEL_H
+#define CAFFE2_CORE_NET_PARALLEL_H
+
+#include "caffe2/core/net_async_base.h"
+#include "caffe2/core/net_async_task_graph.h"
+
+C10_DECLARE_string(caffe2_task_graph_engine);
+
+namespace caffe2 {
+
+class ParallelNetExecutorHelper;
+
+// Net executor that splits the net into chains of operators and runs them
+// as a graph of asynchronous tasks on per-device thread pools. The task
+// graph engine is created through TaskGraphRegistry.
+class CAFFE2_API ParallelNet : public NetBase {
+ public:
+  // NOTE(review): template argument restored as <const NetDef> - confirm.
+  ParallelNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
+
+  // RunAsync schedules a run of the task graph, Wait blocks until it is done.
+  bool RunAsync() override;
+  void Wait() override;
+
+  bool SupportsAsync() override;
+  std::vector<OperatorBase*> GetOperators() const override;
+
+  // Returns the thread pool used to run tasks of the given device.
+  TaskThreadPoolBase* Pool(const DeviceOption& device_option);
+
+ protected:
+  bool handleRunError() override;
+  // finishRun is called when a run completes, reset before a run is scheduled.
+  virtual void finishRun();
+  virtual void reset();
+
+  ExecutionOptions options_;
+  int num_workers_;
+
+  std::unique_ptr<ParallelNetExecutorHelper> helper_;
+  std::shared_ptr<AsyncTaskGraphBase> task_graph_;
+  AsyncTaskFuture* run_future_;
+
+  // NOTE(review): element type restored as dag_utils::OperatorNode - confirm.
+  std::vector<dag_utils::OperatorNode> operator_nodes_;
+  std::vector<OperatorBase*> operators_;
+
+  // Thread pools by device (NUMA node or GPU) id and pool size; accessed
+  // through poolGetter under pools_mutex_.
+  std::mutex pools_mutex_;
+  typedef std::unordered_map<
+      int,
+      std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
+      PoolsMap;
+  PoolsMap cpu_pools_;
+  PoolsMap gpu_pools_;
+  TaskThreadPoolBase*
+  poolGetter(PoolsMap& pools, int device_type, int device_id, int pool_size);
+
+  friend class ParallelNetExecutorHelper;
+  C10_DISABLE_COPY_AND_ASSIGN(ParallelNet);
+};
+
+C10_DECLARE_SHARED_REGISTRY(
+    TaskGraphRegistry,
+    AsyncTaskGraphBase,
+    ExecutorHelper*,
+    const ExecutionOptions&);
+
+// Creator of the default task graph engine (AsyncTaskGraph).
+std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
+    ExecutorHelper* helper,
+    const ExecutionOptions& options);
+
+// Gives operators and tasks access to the thread pools of a ParallelNet.
+class ParallelNetExecutorHelper : public ExecutorHelper {
+ public:
+  explicit ParallelNetExecutorHelper(ParallelNet* net) : net_(net) {}
+  TaskThreadPoolBase* GetPool(const DeviceOption& option) const override {
+    return net_->Pool(option);
+  }
+
+ private:
+  ParallelNet* net_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_CORE_NET_PARALLEL_H
diff --git a/caffe2/python/test/executor_test.py b/caffe2/python/test/executor_test.py
index d4ff0c328cb..ee52717c0f4 100644
--- a/caffe2/python/test/executor_test.py
+++ b/caffe2/python/test/executor_test.py
@@ -19,7 +19,7 @@ import hypothesis.strategies as st
 import unittest
 
 
-EXECUTORS = ["async_scheduling", "dag", "async_dag"]
+EXECUTORS = ["parallel", "async_scheduling"]
 ITERATIONS = 1
 
 