mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16751 This was made more complicated by the fact that ivalue::IntList is a thing. So I had to fix all of the sites where we referring to IValue post facto. The following codemods were run, in this order: ``` codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in IntList IntArrayRef codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in IntArrayRef::create IntList::create codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in ivalue::IntArrayRef ivalue::IntList codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in Tag::IntArrayRef Tag::IntList codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in isIntArrayRef isIntList codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in toIntArrayRef toIntList codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in 'Shared<IntArrayRef>' 'Shared<IntList>' codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in 'intrusive_ptr<IntArrayRef>' 'intrusive_ptr<IntList>' ``` Some manual fixups were done afterwards; they can be reviewed separately at https://github.com/pytorch/pytorch/pull/16752 Reviewed By: dzhulgakov Differential Revision: D13954363 fbshipit-source-id: b5c40aacba042402155a2f5a229fa6db7992ac64
85 lines
2.2 KiB
C++
85 lines
2.2 KiB
C++
#include <torch/extension.h>
|
|
|
|
#include <ATen/CPUFloatType.h>
|
|
#include <ATen/Type.h>
|
|
#include <ATen/core/VariableHooksInterface.h>
|
|
#include <ATen/detail/ComplexHooksInterface.h>
|
|
|
|
#include <c10/core/Allocator.h>
|
|
#include <ATen/CPUGenerator.h>
|
|
#include <ATen/DeviceGuard.h>
|
|
#include <ATen/NativeFunctions.h>
|
|
#include <ATen/Utils.h>
|
|
#include <ATen/WrapDimUtils.h>
|
|
#include <c10/util/Half.h>
|
|
#include <c10/core/TensorImpl.h>
|
|
#include <c10/core/UndefinedTensorImpl.h>
|
|
#include <c10/util/Optional.h>
|
|
|
|
#include <cstddef>
|
|
#include <functional>
|
|
#include <memory>
|
|
#include <utility>
|
|
|
|
#include <ATen/Config.h>
|
|
|
|
namespace at {
|
|
|
|
struct CPUComplexFloatType : public at::CPUTypeDefault {
|
|
CPUComplexFloatType()
|
|
: CPUTypeDefault(
|
|
CPUTensorId(),
|
|
/*is_variable=*/false,
|
|
/*is_undefined=*/false) {}
|
|
|
|
ScalarType scalarType() const override;
|
|
caffe2::TypeMeta typeMeta() const override;
|
|
Backend backend() const override;
|
|
const char* toString() const override;
|
|
size_t elementSizeInBytes() const override;
|
|
TypeID ID() const override;
|
|
|
|
Tensor empty(IntArrayRef size, const TensorOptions & options) const override {
|
|
// Delegate to the appropriate cpu tensor factory
|
|
const DeviceGuard device_guard(options.device());
|
|
return at::native::empty_cpu(/* actuals */ size, options);
|
|
}
|
|
};
|
|
|
|
struct ComplexHooks : public at::ComplexHooksInterface {
|
|
ComplexHooks(ComplexHooksArgs) {}
|
|
void registerComplexTypes(Context* context) const override {
|
|
context->registerType(
|
|
Backend::CPU, ScalarType::ComplexFloat, new CPUComplexFloatType());
|
|
}
|
|
};
|
|
|
|
// This Type holds complex-float scalars.
ScalarType CPUComplexFloatType::scalarType() const {
  constexpr auto kScalarType = ScalarType::ComplexFloat;
  return kScalarType;
}
|
|
|
|
// TypeMeta corresponding to ScalarType::ComplexFloat.
caffe2::TypeMeta CPUComplexFloatType::typeMeta() const {
  const auto meta = scalarTypeToTypeMeta(ScalarType::ComplexFloat);
  return meta;
}
|
|
|
|
// This Type lives on the CPU backend.
Backend CPUComplexFloatType::backend() const {
  constexpr auto kBackend = Backend::CPU;
  return kBackend;
}
|
|
|
|
// Human-readable name of this Type, used in error messages and printing.
const char* CPUComplexFloatType::toString() const {
  const char* const kName = "CPUComplexFloatType";
  return kName;
}
|
|
|
|
// Stable TypeID for dispatch/identification of this Type.
TypeID CPUComplexFloatType::ID() const {
  constexpr auto kTypeId = TypeID::CPUComplexFloat;
  return kTypeId;
}
|
|
|
|
// Size in bytes of one ComplexFloat element. A complex<float> stores a real
// and an imaginary float component, so an element occupies two floats
// (8 bytes). The previous implementation returned sizeof(float) (4 bytes),
// under-reporting the element size by a factor of two, which would corrupt
// any size/stride arithmetic done in bytes on tensors of this type.
size_t CPUComplexFloatType::elementSizeInBytes() const {
  return 2 * sizeof(float);
}
|
|
|
|
// Register ComplexHooks with ATen so the Context can install the complex
// types defined above when complex support is first requested.
REGISTER_COMPLEX_HOOKS(ComplexHooks);
|
|
|
|
} // namespace at
|
|
|
|
// Intentionally empty pybind11 module: this extension exposes no Python
// bindings of its own — loading it triggers the static hook registration above.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { }
|