mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26572. Combined with isinstance specialization, this allows some polymorphic functions to work without resorting to our weirder overload hacks. We do not define any operators on `Any`, so the only things you can do with it are put it in containers or refine its type using an isinstance check. `Any` is restricted from appearing in non-argument positions because we cannot restore type tags if it ends up as a field in a class. Test Plan: Imported from OSS. Differential Revision: D17530643. Pulled By: zdevito. fbshipit-source-id: f06f78ce84819f7773953a492f3d4c49219ee94c
126 lines
3 KiB
C++
126 lines
3 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/python_headers.h>
|
|
|
|
#include <torch/csrc/DynamicTypes.h>
|
|
#include <torch/csrc/THP.h>
|
|
#include <torch/csrc/autograd/variable.h>
|
|
#include <ATen/core/interned_strings.h>
|
|
#include <ATen/core/ivalue.h>
|
|
#include <torch/csrc/jit/pybind_utils.h>
|
|
#include <torch/csrc/jit/tracer.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
|
|
#include <pybind11/functional.h>
|
|
#include <pybind11/pybind11.h>
|
|
#include <pybind11/stl.h>
|
|
|
|
namespace py = pybind11;
|
|
|
|
namespace pybind11 {
|
|
namespace detail {
|
|
|
|
template <>
|
|
struct type_caster<torch::jit::IValue> {
|
|
public:
|
|
PYBIND11_TYPE_CASTER(torch::jit::IValue, _("IValue"));
|
|
|
|
bool load(handle src, bool) {
|
|
try {
|
|
value = torch::jit::toTypeInferredIValue(src);
|
|
return true;
|
|
} catch (std::exception& e) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
static handle cast(
|
|
torch::jit::IValue src,
|
|
return_value_policy /* policy */,
|
|
handle /* parent */) {
|
|
return torch::jit::toPyObject(std::move(src)).release();
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct type_caster<torch::jit::Symbol> {
|
|
public:
|
|
PYBIND11_TYPE_CASTER(torch::jit::Symbol, _("Symbol"));
|
|
|
|
bool load(handle src, bool) {
|
|
// TODO: Is there a way to py::cast that doesn't raise an exception on
|
|
// failure? Can we catch pybind11::cast_error here instead?
|
|
std::string src_str;
|
|
try {
|
|
src_str = py::cast<std::string>(src);
|
|
} catch (std::exception& e) {
|
|
return false;
|
|
}
|
|
value = torch::jit::Symbol::fromQualString(src_str);
|
|
return true;
|
|
}
|
|
|
|
static handle cast(
|
|
torch::jit::Symbol src,
|
|
return_value_policy /* policy */,
|
|
handle /* parent */) {
|
|
return py::cast(std::string(src.toQualString()), return_value_policy::copy)
|
|
.release();
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct type_caster<torch::jit::AttributeKind> {
|
|
public:
|
|
PYBIND11_TYPE_CASTER(torch::jit::AttributeKind, _("AttributeKind"));
|
|
|
|
bool load(handle src, bool) {
|
|
return false;
|
|
}
|
|
|
|
static handle cast(
|
|
torch::jit::AttributeKind src,
|
|
return_value_policy /* policy */,
|
|
handle /* parent */) {
|
|
return py::cast(
|
|
std::string(torch::jit::toString(src)),
|
|
return_value_policy::copy)
|
|
.release();
|
|
}
|
|
};
|
|
|
|
// See https://github.com/pybind/pybind11/issues/637
|
|
using ListCasterBase = pybind11::detail::
|
|
list_caster<std::vector<torch::jit::Node*>, torch::jit::Node*>;
|
|
template <>
|
|
struct type_caster<std::vector<torch::jit::Node*>> : ListCasterBase {
|
|
static handle cast(
|
|
const std::vector<torch::jit::Node*>& src,
|
|
return_value_policy,
|
|
handle parent) {
|
|
return ListCasterBase::cast(src, return_value_policy::reference, parent);
|
|
}
|
|
static handle cast(
|
|
const std::vector<torch::jit::Node*>* src,
|
|
return_value_policy pol,
|
|
handle parent) {
|
|
return cast(*src, pol, parent);
|
|
}
|
|
};
|
|
|
|
} // namespace detail
|
|
} // namespace pybind11
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Returns a new tuple containing every element of `tup` except the first,
// i.e. the equivalent of Python's `tup[1:]`.
//
// An empty input yields an empty tuple, as `()[1:]` does in Python. Without
// this guard, `tup.size() - 1` would wrap around to SIZE_MAX and the tuple
// allocation would fail with an unhelpful error.
static inline py::tuple tuple_tail(const py::tuple& tup) {
  const size_t n = tup.size();
  if (n == 0) {
    return py::tuple();
  }
  py::tuple r(n - 1);
  for (size_t i = 1; i < n; ++i) {
    r[i - 1] = tup[i];
  }
  return r;
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|