mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Fix build
Summary: (1) BlobsQueue is causing a gcc error (google search suggested it was a bug, but we'll put the implementation in a separate cc file). (2) Preparing for cuda 9: update cub. (3) Prepare for cudnn 7: update cudnn rnn op. (4) Fix an MSVC issue Reviewed By: sf-wind, jerryzh168 Differential Revision: D5574352 fbshipit-source-id: 230820ce3ceaa32bee8323bdc509de352c93fcf2
This commit is contained in:
parent
a11aa0ab35
commit
5ae3865112
7 changed files with 167 additions and 118 deletions
|
|
@ -63,10 +63,12 @@ using std::vector;
|
|||
#define CAFFE_NOT_IMPLEMENTED CAFFE_THROW("Not Implemented.")
|
||||
|
||||
// suppress an unused variable.
|
||||
#ifndef _MSC_VER
|
||||
#define CAFFE2_UNUSED __attribute__((__unused__))
|
||||
#else
|
||||
#ifdef _MSC_VER
|
||||
#define CAFFE2_UNUSED
|
||||
#define CAFFE2_USED
|
||||
#else
|
||||
#define CAFFE2_UNUSED __attribute__((__unused__))
|
||||
#define CAFFE2_USED __attribute__((__used__))
|
||||
#endif //_MSC_VER
|
||||
|
||||
// Disable the copy and assignment operator for a class. Note that this will
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@
|
|||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#include <direct.h> // for _mkdir
|
||||
#endif
|
||||
|
||||
#include "caffe2/utils/murmur_hash3.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
|
@ -37,7 +41,11 @@ FileStoreHandler::FileStoreHandler(
|
|||
if (!prefix.empty()) {
|
||||
basePath_ = basePath_ + "/" + encodeName(prefix);
|
||||
}
|
||||
#if defined(_MSC_VER)
|
||||
auto ret = _mkdir(basePath_.c_str());
|
||||
#else
|
||||
auto ret = mkdir(basePath_.c_str(), 0777);
|
||||
#endif // defined(_MSC_VER)
|
||||
if (ret == -1) {
|
||||
CHECK_EQ(errno, EEXIST) << "mkdir: " << strerror(errno);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -116,6 +116,9 @@ void RecurrentBaseOp<T>::initialize(
|
|||
// RNN setup
|
||||
{
|
||||
CUDNN_ENFORCE(cudnnSetRNNDescriptor(
|
||||
#if CUDNN_MAJOR >= 7
|
||||
cudnn_wrapper_.inline_cudnn_handle(),
|
||||
#endif
|
||||
rnnDesc_,
|
||||
hiddenSize,
|
||||
numLayers,
|
||||
|
|
@ -123,6 +126,9 @@ void RecurrentBaseOp<T>::initialize(
|
|||
rnnInput,
|
||||
rnnDirection,
|
||||
rnnMode,
|
||||
#if CUDNN_MAJOR >= 7
|
||||
CUDNN_RNN_ALGO_STANDARD, // TODO: verify correctness / efficiency.
|
||||
#endif
|
||||
cudnnTypeWrapper<T>::type));
|
||||
}
|
||||
// X setup
|
||||
|
|
|
|||
138
caffe2/queue/blobs_queue.cc
Normal file
138
caffe2/queue/blobs_queue.cc
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
#include "caffe2/queue/blobs_queue.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
|
||||
#include "caffe2/core/blob_stats.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/stats.h"
|
||||
#include "caffe2/core/tensor.h"
|
||||
#include "caffe2/core/workspace.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
// Constructs a queue with `capacity` slots, each holding `numBlobs` blobs.
//
// For every slot i in [0, capacity) and every field j in [0, numBlobs) a blob
// named "<queueName>_<i>_<j>" is created in `ws` via Workspace::CreateBlob and
// its pointer is stored in the slot.
//
// Params:
//   ws                - workspace in which the backing blobs are created.
//   queueName         - prefix of the backing blob names; also passed to the
//                       stats_ member.
//   capacity          - number of slots in the queue.
//   numBlobs          - number of blobs per slot.
//   enforceUniqueName - when true, enforces that ws->GetBlob(name) yields a
//                       null result for every backing blob name before the
//                       blob is created.
//   fieldNames        - optional; when non-empty it must contain exactly
//                       numBlobs entries, which are handed to
//                       stats_.queue_dequeued_bytes.setDetails().
//
// Raises (through CAFFE_ENFORCE*) when fieldNames has the wrong size or when
// enforceUniqueName is set and a backing blob name is already present.
BlobsQueue::BlobsQueue(
    Workspace* ws,
    const std::string& queueName,
    size_t capacity,
    size_t numBlobs,
    bool enforceUniqueName,
    const std::vector<std::string>& fieldNames)
    : numBlobs_(numBlobs), stats_(queueName) {
  if (!fieldNames.empty()) {
    CAFFE_ENFORCE_EQ(
        fieldNames.size(), numBlobs, "Wrong number of fieldNames provided.");
    stats_.queue_dequeued_bytes.setDetails(fieldNames);
  }
  queue_.reserve(capacity);
  // Indices are size_t to match the types of `capacity` and `numBlobs`; an
  // `int` index would be compared signed-vs-unsigned and could overflow.
  for (size_t i = 0; i < capacity; ++i) {
    std::vector<Blob*> blobs;
    blobs.reserve(numBlobs);
    for (size_t j = 0; j < numBlobs; ++j) {
      const auto blobName = queueName + "_" + to_string(i) + "_" + to_string(j);
      if (enforceUniqueName) {
        CAFFE_ENFORCE(
            !ws->GetBlob(blobName),
            "Queue internal blob already exists: ",
            blobName);
      }
      blobs.push_back(ws->CreateBlob(blobName));
    }
    queue_.push_back(blobs);
  }
  DCHECK_EQ(queue_.size(), capacity);
}
|
||||
|
||||
bool BlobsQueue::blockingRead(
|
||||
const std::vector<Blob*>& inputs,
|
||||
float timeout_secs) {
|
||||
auto keeper = this->shared_from_this();
|
||||
std::unique_lock<std::mutex> g(mutex_);
|
||||
auto canRead = [this]() {
|
||||
CAFFE_ENFORCE_LE(reader_, writer_);
|
||||
return reader_ != writer_;
|
||||
};
|
||||
CAFFE_EVENT(stats_, queue_balance, -1);
|
||||
if (timeout_secs > 0) {
|
||||
std::chrono::milliseconds timeout_ms(int(timeout_secs * 1000));
|
||||
cv_.wait_for(
|
||||
g, timeout_ms, [this, canRead]() { return closing_ || canRead(); });
|
||||
} else {
|
||||
cv_.wait(g, [this, canRead]() { return closing_ || canRead(); });
|
||||
}
|
||||
if (!canRead()) {
|
||||
if (timeout_secs > 0 && !closing_) {
|
||||
LOG(ERROR) << "DequeueBlobs timed out in " << timeout_secs << " secs";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
DCHECK(canRead());
|
||||
auto& result = queue_[reader_ % queue_.size()];
|
||||
CAFFE_ENFORCE(inputs.size() >= result.size());
|
||||
for (auto i = 0; i < result.size(); ++i) {
|
||||
auto bytes = BlobStat::sizeBytes(*result[i]);
|
||||
CAFFE_EVENT(stats_, queue_dequeued_bytes, bytes, i);
|
||||
using std::swap;
|
||||
swap(*(inputs[i]), *(result[i]));
|
||||
}
|
||||
CAFFE_EVENT(stats_, queue_dequeued_records);
|
||||
++reader_;
|
||||
cv_.notify_all();
|
||||
return true;
|
||||
}
|
||||
|
||||
// Non-blocking enqueue: writes `inputs` into the queue only if canWrite()
// reports room at the moment of the call.
//
// Returns true when the record was written, false when the queue had no room.
// The balance stat is incremented only on a successful write.
bool BlobsQueue::tryWrite(const std::vector<Blob*>& inputs) {
  // Hold a shared_ptr to this object for the duration of the call.
  auto selfRef = this->shared_from_this();
  std::unique_lock<std::mutex> lock(mutex_);
  const bool hasRoom = canWrite();
  if (hasRoom) {
    CAFFE_EVENT(stats_, queue_balance, 1);
    DCHECK(canWrite());
    doWrite(inputs);
  }
  return hasRoom;
}
|
||||
|
||||
// Blocking enqueue: waits until the queue has room or is closing, then writes
// `inputs` if there is room.
//
// Returns true when the record was written, false when the wait ended without
// room being available. The balance stat is incremented for every attempt,
// before waiting.
bool BlobsQueue::blockingWrite(const std::vector<Blob*>& inputs) {
  // Hold a shared_ptr to this object for the duration of the call.
  auto selfRef = this->shared_from_this();
  std::unique_lock<std::mutex> lock(mutex_);
  CAFFE_EVENT(stats_, queue_balance, 1);
  const auto closingOrWritable = [this]() { return closing_ || canWrite(); };
  cv_.wait(lock, closingOrWritable);
  if (canWrite()) {
    DCHECK(canWrite());
    doWrite(inputs);
    return true;
  }
  return false;
}
|
||||
|
||||
void BlobsQueue::close() {
|
||||
closing_ = true;
|
||||
|
||||
std::lock_guard<std::mutex> g(mutex_);
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
// Reports whether the queue has a free slot.
//
// Enforced invariant: reader_ <= writer_ <= reader_ + queue_.size().
// The queue is full exactly when writer_ has reached reader_ + queue_.size(),
// so writing is possible while writer_ is within
// [reader_, reader_ + queue_.size()).
bool BlobsQueue::canWrite() {
  CAFFE_ENFORCE_LE(reader_, writer_);
  CAFFE_ENFORCE_LE(writer_, reader_ + queue_.size());
  const bool isFull = (writer_ == reader_ + queue_.size());
  return !isFull;
}
|
||||
|
||||
// Swaps `inputs` into the slot at the current write position and advances
// writer_. Performs no locking and no capacity check itself; the callers in
// this file acquire mutex_ and check canWrite() before calling.
//
// Params:
//   inputs - source blobs; must contain at least as many entries as the slot.
//            The first result.size() entries are swapped with the slot's
//            blobs, so after the call they hold the slot's previous contents.
void BlobsQueue::doWrite(const std::vector<Blob*>& inputs) {
  auto& result = queue_[writer_ % queue_.size()];
  CAFFE_ENFORCE(inputs.size() >= result.size());
  // size_t index: result.size() is unsigned, so avoid a signed comparison.
  for (size_t i = 0; i < result.size(); ++i) {
    using std::swap;
    swap(*(inputs[i]), *(result[i]));
  }
  ++writer_;
  // Wake any thread waiting on the condition variable now that the write
  // position has advanced.
  cv_.notify_all();
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
|
@ -28,32 +28,7 @@ class BlobsQueue : public std::enable_shared_from_this<BlobsQueue> {
|
|||
size_t capacity,
|
||||
size_t numBlobs,
|
||||
bool enforceUniqueName,
|
||||
const std::vector<std::string>& fieldNames = {})
|
||||
: numBlobs_(numBlobs), stats_(queueName) {
|
||||
if (!fieldNames.empty()) {
|
||||
CAFFE_ENFORCE_EQ(
|
||||
fieldNames.size(), numBlobs, "Wrong number of fieldNames provided.");
|
||||
stats_.queue_dequeued_bytes.setDetails(fieldNames);
|
||||
}
|
||||
queue_.reserve(capacity);
|
||||
for (auto i = 0; i < capacity; ++i) {
|
||||
std::vector<Blob*> blobs;
|
||||
blobs.reserve(numBlobs);
|
||||
for (auto j = 0; j < numBlobs; ++j) {
|
||||
const auto blobName =
|
||||
queueName + "_" + to_string(i) + "_" + to_string(j);
|
||||
if (enforceUniqueName) {
|
||||
CAFFE_ENFORCE(
|
||||
!ws->GetBlob(blobName),
|
||||
"Queue internal blob already exists: ",
|
||||
blobName);
|
||||
}
|
||||
blobs.push_back(ws->CreateBlob(blobName));
|
||||
}
|
||||
queue_.push_back(blobs);
|
||||
}
|
||||
DCHECK_EQ(queue_.size(), capacity);
|
||||
}
|
||||
const std::vector<std::string>& fieldNames = {});
|
||||
|
||||
~BlobsQueue() {
|
||||
close();
|
||||
|
|
@ -61,97 +36,17 @@ class BlobsQueue : public std::enable_shared_from_this<BlobsQueue> {
|
|||
|
||||
bool blockingRead(
|
||||
const std::vector<Blob*>& inputs,
|
||||
float timeout_secs = 0.0f) {
|
||||
auto keeper = this->shared_from_this();
|
||||
std::unique_lock<std::mutex> g(mutex_);
|
||||
auto canRead = [this]() {
|
||||
CAFFE_ENFORCE_LE(reader_, writer_);
|
||||
return reader_ != writer_;
|
||||
};
|
||||
CAFFE_EVENT(stats_, queue_balance, -1);
|
||||
if (timeout_secs > 0) {
|
||||
std::chrono::milliseconds timeout_ms(int(timeout_secs * 1000));
|
||||
cv_.wait_for(
|
||||
g, timeout_ms, [this, canRead]() { return closing_ || canRead(); });
|
||||
} else {
|
||||
cv_.wait(g, [this, canRead]() { return closing_ || canRead(); });
|
||||
}
|
||||
if (!canRead()) {
|
||||
if (timeout_secs > 0 && !closing_) {
|
||||
LOG(ERROR) << "DequeueBlobs timed out in " << timeout_secs << " secs";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
DCHECK(canRead());
|
||||
auto& result = queue_[reader_ % queue_.size()];
|
||||
CAFFE_ENFORCE(inputs.size() >= result.size());
|
||||
for (auto i = 0; i < result.size(); ++i) {
|
||||
auto bytes = BlobStat::sizeBytes(*result[i]);
|
||||
CAFFE_EVENT(stats_, queue_dequeued_bytes, bytes, i);
|
||||
using std::swap;
|
||||
swap(*(inputs[i]), *(result[i]));
|
||||
}
|
||||
CAFFE_EVENT(stats_, queue_dequeued_records);
|
||||
++reader_;
|
||||
cv_.notify_all();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool tryWrite(const std::vector<Blob*>& inputs) {
|
||||
auto keeper = this->shared_from_this();
|
||||
std::unique_lock<std::mutex> g(mutex_);
|
||||
if (!canWrite()) {
|
||||
return false;
|
||||
}
|
||||
CAFFE_EVENT(stats_, queue_balance, 1);
|
||||
DCHECK(canWrite());
|
||||
doWrite(inputs);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool blockingWrite(const std::vector<Blob*>& inputs) {
|
||||
auto keeper = this->shared_from_this();
|
||||
std::unique_lock<std::mutex> g(mutex_);
|
||||
CAFFE_EVENT(stats_, queue_balance, 1);
|
||||
cv_.wait(g, [this]() { return closing_ || canWrite(); });
|
||||
if (!canWrite()) {
|
||||
return false;
|
||||
}
|
||||
DCHECK(canWrite());
|
||||
doWrite(inputs);
|
||||
return true;
|
||||
}
|
||||
|
||||
void close() {
|
||||
closing_ = true;
|
||||
|
||||
std::lock_guard<std::mutex> g(mutex_);
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
float timeout_secs = 0.0f);
|
||||
bool tryWrite(const std::vector<Blob*>& inputs);
|
||||
bool blockingWrite(const std::vector<Blob*>& inputs);
|
||||
void close();
|
||||
size_t getNumBlobs() const {
|
||||
return numBlobs_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool canWrite() {
|
||||
// writer is always within [reader, reader + size)
|
||||
// we can write if reader is within [reader, reader + size)
|
||||
CAFFE_ENFORCE_LE(reader_, writer_);
|
||||
CAFFE_ENFORCE_LE(writer_, reader_ + queue_.size());
|
||||
return writer_ != reader_ + queue_.size();
|
||||
}
|
||||
|
||||
void doWrite(const std::vector<Blob*>& inputs) {
|
||||
auto& result = queue_[writer_ % queue_.size()];
|
||||
CAFFE_ENFORCE(inputs.size() >= result.size());
|
||||
for (auto i = 0; i < result.size(); ++i) {
|
||||
using std::swap;
|
||||
swap(*(inputs[i]), *(result[i]));
|
||||
}
|
||||
++writer_;
|
||||
cv_.notify_all();
|
||||
}
|
||||
bool canWrite();
|
||||
void doWrite(const std::vector<Blob*>& inputs);
|
||||
|
||||
std::atomic<bool> closing_{false};
|
||||
|
||||
|
|
@ -169,4 +64,4 @@ class BlobsQueue : public std::enable_shared_from_this<BlobsQueue> {
|
|||
CAFFE_DETAILED_EXPORTED_STAT(queue_dequeued_bytes);
|
||||
} stats_;
|
||||
};
|
||||
}
|
||||
} // namespace caffe2
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
#include "blobs_queue_db.h"
|
||||
#include "caffe2/queue/blobs_queue_db.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
|
|
|
|||
2
third_party/cub
vendored
2
third_party/cub
vendored
|
|
@ -1 +1 @@
|
|||
Subproject commit 01347a797c620618d09e7d2d90bce4be4c42513e
|
||||
Subproject commit b1370adb972a8345de92e2d69e61daf7f9bce43e
|
||||
Loading…
Reference in a new issue