mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9939 Pull Request resolved: https://github.com/facebookresearch/weakly-supervised-action-detection/pull/13 Pull Request resolved: https://github.com/pytorch/translate/pull/166 Pull Request resolved: https://github.com/pytorch/pytorch/pull/9125 Closes https://github.com/pytorch/pytorch/pull/9125 Use inheritance for polymorphism, and remove the template parameter. This changes the templating at call sites; the core implementations will change later. Before this change, the Caffe2 Tensor class was bound at compile time to a particular device/context. With this change, the device becomes a runtime property (stored inside the tensor), but the same semantics are preserved. For example, one has to specify a device type in order to create a Tensor — there are no uninitialized tensors. More specifically, the changes are: 1. We added an extra argument *DeviceType* to most of the tensor constructors, e.g. Tensor(DeviceType type). 2. The semantics of the constructor Tensor(const Tensor<SrcContext>& src, ContextForCopy* context) have changed. In this constructor, the second argument (the context) is passed in so that we can call the templated Copy function. Previously that context could differ from both the source and the target; now, if a context is provided, we enforce that it has the same device type as src. 3. To preserve the 'get-or-construct' semantics of Blob, we added a specialized getter, Blob::GetMutableTensor, that verifies both that the Blob contains a Tensor and that the Tensor is of the correct type. 4. Specifically, the Tensor type is no longer default-constructible (as we don't have unknown-device tensors), so some of the code handling STL containers needs to change. Note: Some changes are postponed just to keep this diff a bit smaller. Please see `TODO`s. Reviewed By: ezyang, houseroad Differential Revision: D9024330 fbshipit-source-id: e0b8295d2dc6ebe2963383ded5af799ad17164ba
143 lines
4.4 KiB
C++
143 lines
4.4 KiB
C++
#include <mutex>
|
|
#include "caffe2/core/context.h"
|
|
#include "caffe2/core/operator.h"
|
|
|
|
namespace caffe2 {
|
|
namespace fb {
|
|
namespace {
|
|
|
|
/**
 * CreateMutexOp: takes no inputs and stores a freshly allocated std::mutex,
 * owned by a std::unique_ptr, in output 0. If that unique_ptr already owned
 * a mutex, the old one is destroyed when it is replaced.
 */
class CreateMutexOp final : public Operator<CPUContext> {
 public:
  CreateMutexOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    // The output blob holds the owning pointer itself, not a tensor.
    auto* mutex_slot = OperatorBase::Output<std::unique_ptr<std::mutex>>(0);
    mutex_slot->reset(new std::mutex());
    return true;
  }
};
|
|
|
|
/**
 * AtomicFetchAddOp: fetch-and-add on an int32 value, serialized by a mutex.
 *
 * Inputs:
 *   0 - blob holding a std::unique_ptr<std::mutex> (as created by CreateMutex)
 *   1 - int32 tensor; element 0 is the current value
 *   2 - int32 tensor; element 0 is the increment
 * Outputs:
 *   0 - int32 scalar: value + increment
 *   1 - int32 scalar: value before the add
 *
 * The whole operation runs while holding the mutex, including resizing the
 * outputs and fetching their buffers. The schema allows running in place
 * (input 1 -> output 0). In that case output 0 is the very tensor being
 * counted. If other ops use the same mutex on it concurrently (the reason a
 * mutex is passed at all), touching it outside the lock would race with them.
 */
class AtomicFetchAddOp final : public Operator<CPUContext> {
 public:
  AtomicFetchAddOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(0);
    std::lock_guard<std::mutex> lg(*mutex);

    // Read both operands before preparing the outputs. Output 0 may alias
    // input 1, so resizing it first could disturb the value still to be read.
    // NOTE(review): only element 0 is read. The inputs are documented as
    // int32 scalars; an empty tensor is not rejected here.
    const int32_t value = Input(1).data<int32_t>()[0];
    const int32_t increment = Input(2).data<int32_t>()[0];

    auto* c = Output(0);
    auto* d = Output(1);
    c->Resize(std::vector<TIndex>());
    d->Resize(std::vector<TIndex>());

    // Same write order as before: the fetched value first, then the sum.
    *d->template mutable_data<int32_t>() = value;
    *c->template mutable_data<int32_t>() = value + increment;
    return true;
  }
};
|
|
|
|
// CreateAtomicBoolOp: takes no inputs and stores a new std::atomic<bool>,
// initialized to false and owned by a std::unique_ptr, in output 0. If that
// unique_ptr already owned an object, the old one is destroyed when replaced.
class CreateAtomicBoolOp final : public Operator<CPUContext> {
 public:
  using Operator::Operator;

  bool RunOnDevice() override {
    auto* flag_slot =
        OperatorBase::Output<std::unique_ptr<std::atomic<bool>>>(0);
    flag_slot->reset(new std::atomic<bool>(false));
    return true;
  }
};
|
|
|
|
// ConditionalSetAtomicBoolOp: if element 0 of the CONDITION tensor is true,
// stores true into the std::atomic<bool> held by the ATOMIC_BOOL input.
// A false condition leaves the flag untouched; this op never resets it to
// false. It has no outputs.
class ConditionalSetAtomicBoolOp final : public Operator<CPUContext> {
 public:
  using Operator::Operator;

  bool RunOnDevice() override {
    // The flag blob is fetched before the condition tensor is read.
    auto& flag =
        OperatorBase::Input<std::unique_ptr<std::atomic<bool>>>(ATOMIC_BOOL);
    const bool should_set = Input(CONDITION).data<bool>()[0];
    if (!should_set) {
      return true;
    }
    flag->store(true);
    return true;
  }

 private:
  INPUT_TAGS(ATOMIC_BOOL, CONDITION);
};
|
|
|
|
// CheckAtomicBoolOp: copies the current value of the std::atomic<bool> held
// by input 0 into output 0, a bool tensor resized to a single element.
class CheckAtomicBoolOp final : public Operator<CPUContext> {
 public:
  using Operator::Operator;

  bool RunOnDevice() override {
    const auto& flag =
        OperatorBase::Input<std::unique_ptr<std::atomic<bool>>>(0);
    auto* out = Output(0);
    out->Resize(1);
    bool* out_data = out->template mutable_data<bool>();
    *out_data = flag->load();
    return true;
  }
};
|
|
|
|
// CPU registrations for the ops defined above.
// Mutex-based ops:
REGISTER_CPU_OPERATOR(CreateMutex, CreateMutexOp);
REGISTER_CPU_OPERATOR(AtomicFetchAdd, AtomicFetchAddOp);

// atomic<bool>-based ops:
REGISTER_CPU_OPERATOR(CreateAtomicBool, CreateAtomicBoolOp);
REGISTER_CPU_OPERATOR(ConditionalSetAtomicBool, ConditionalSetAtomicBoolOp);
REGISTER_CPU_OPERATOR(CheckAtomicBool, CheckAtomicBoolOp);
|
|
|
|
// Schema for CreateMutex: no inputs and one output. The output is documented
// as a blob holding a std::unique_ptr<mutex>. Its scalar type is declared
// UNDEFINED, presumably because it is not a numeric tensor (TODO confirm).
OPERATOR_SCHEMA(CreateMutex)
|
|
.NumInputs(0)
|
|
.NumOutputs(1)
|
|
.SetDoc("Creates an unlocked mutex and returns it in a unique_ptr blob.")
|
|
.Output(0, "mutex_ptr", "Blob containing a std::unique_ptr<mutex>.")
|
|
.ScalarType(TensorProto_DataType_UNDEFINED);
|
|
|
|
// Schema for AtomicFetchAdd: inputs are (mutex, value, increment) and outputs
// are (sum, value before the add). Input 1 may be updated in place through
// output 0 (AllowInplace {1 -> 0}).
OPERATOR_SCHEMA(AtomicFetchAdd)
    .NumInputs(3)
    .NumOutputs(2)
    .SetDoc(R"DOC(
Given a mutex and two int32 scalar tensors, performs an atomic fetch add:
while holding the mutex, adds the second integer input (`increment`) to the
first integer input (`mut_value`). Returns the updated integer and the value
prior to the update. When run in place (output 0 aliasing input 1), this
mutates the first integer input.
)DOC")
    .Input(0, "mutex_ptr", "Blob containing a unique_ptr<mutex>")
    .Input(1, "mut_value", "Value to be mutated after the sum.")
    .Input(2, "increment", "Value to add to the first operand.")
    .Output(0, "mut_value", "Mutated value after sum. Usually same as input 1.")
    .Output(1, "fetched_value", "Value of the first operand before sum.")
    .AllowInplace({{1, 0}});
|
|
|
|
// Schema for CreateAtomicBool: no inputs and one output, documented as a blob
// holding a unique_ptr<atomic<bool>>.
OPERATOR_SCHEMA(CreateAtomicBool)
|
|
.NumInputs(0)
|
|
.NumOutputs(1)
|
|
.SetDoc("Create an unique_ptr blob to hold an atomic<bool>")
|
|
.Output(0, "atomic_bool", "Blob containing a unique_ptr<atomic<bool>>");
|
|
|
|
// Schema for ConditionalSetAtomicBool: two inputs (the atomic<bool> blob and
// a bool condition) and no outputs.
OPERATOR_SCHEMA(ConditionalSetAtomicBool)
|
|
.NumInputs(2)
|
|
.NumOutputs(0)
|
|
.SetDoc(R"DOC(
|
|
Set an atomic<bool> to true if the given condition bool variable is true
|
|
)DOC")
|
|
.Input(0, "atomic_bool", "Blob containing a unique_ptr<atomic<bool>>")
|
|
.Input(1, "condition", "Blob containing a bool");
|
|
|
|
// Schema for CheckAtomicBool: one input (the atomic<bool> blob) and one
// output receiving a copy of its current value.
OPERATOR_SCHEMA(CheckAtomicBool)
|
|
.NumInputs(1)
|
|
.NumOutputs(1)
|
|
.SetDoc("Copy the value of an atomic<bool> to a bool")
|
|
.Input(0, "atomic_bool", "Blob containing a unique_ptr<atomic<bool>>")
|
|
.Output(0, "value", "Copy of the value for the atomic<bool>");
|
|
|
|
// Gradients are explicitly disabled for every op in this file.
// Mutex-based ops:
SHOULD_NOT_DO_GRADIENT(CreateMutex);
SHOULD_NOT_DO_GRADIENT(AtomicFetchAdd);

// atomic<bool>-based ops:
SHOULD_NOT_DO_GRADIENT(CreateAtomicBool);
SHOULD_NOT_DO_GRADIENT(ConditionalSetAtomicBool);
SHOULD_NOT_DO_GRADIENT(CheckAtomicBool);
|
|
}
|
|
}
|
|
}
|