2017-09-14 05:41:54 +00:00
|
|
|
#pragma once
|
|
|
|
|
|
Canonicalize all includes in PyTorch. (#14849)
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.
I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.
I used the following script to do the canonicalization:
```
import subprocess
import re
import os.path
files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
for fn in files:
if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
continue
if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
continue
with open(fn, 'r') as f:
c = f.read()
def fmt(p):
return "#include <{}>".format(p)
def repl(m):
p = m.group(1)
if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
return fmt(p)
if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
return fmt(p)
for root in ["aten/src", "torch/lib", ""]:
for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
new_p = os.path.relpath(os.path.join(bad_root, p), root)
if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
return fmt(new_p)
print("ERROR: ", fn, p)
return m.group(0)
new_c = re.sub(r'#include "([^"]+)"', repl, c)
if new_c != c:
print(fn)
with open(fn, 'w') as f:
f.write(new_c)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849
Reviewed By: dzhulgakov
Differential Revision: D13363445
Pulled By: ezyang
fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
2018-12-09 03:32:01 +00:00
|
|
|
#include <torch/csrc/python_headers.h>
|
|
|
|
|
|
2020-03-26 18:15:49 +00:00
|
|
|
#include <ATen/core/interned_strings.h>
|
|
|
|
|
#include <ATen/core/ivalue.h>
|
2021-07-28 20:28:39 +00:00
|
|
|
#include <c10/util/irange.h>
|
Canonicalize all includes in PyTorch. (#14849)
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.
I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.
I used the following script to do the canonicalization:
```
import subprocess
import re
import os.path
files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
for fn in files:
if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
continue
if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
continue
with open(fn, 'r') as f:
c = f.read()
def fmt(p):
return "#include <{}>".format(p)
def repl(m):
p = m.group(1)
if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
return fmt(p)
if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
return fmt(p)
for root in ["aten/src", "torch/lib", ""]:
for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
new_p = os.path.relpath(os.path.join(bad_root, p), root)
if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
return fmt(new_p)
print("ERROR: ", fn, p)
return m.group(0)
new_c = re.sub(r'#include "([^"]+)"', repl, c)
if new_c != c:
print(fn)
with open(fn, 'w') as f:
f.write(new_c)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849
Reviewed By: dzhulgakov
Differential Revision: D13363445
Pulled By: ezyang
fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
2018-12-09 03:32:01 +00:00
|
|
|
#include <torch/csrc/DynamicTypes.h>
|
|
|
|
|
#include <torch/csrc/THP.h>
|
|
|
|
|
#include <torch/csrc/autograd/variable.h>
|
2020-02-27 20:18:24 +00:00
|
|
|
#include <torch/csrc/jit/frontend/tracer.h>
|
2020-03-26 18:15:49 +00:00
|
|
|
#include <torch/csrc/jit/python/pybind_utils.h>
|
2018-12-26 14:52:25 +00:00
|
|
|
#include <torch/csrc/utils/pybind.h>
|
2018-02-13 04:26:26 +00:00
|
|
|
|
2018-03-01 03:45:04 +00:00
|
|
|
#include <pybind11/functional.h>
|
2018-02-13 04:26:26 +00:00
|
|
|
#include <pybind11/pybind11.h>
|
|
|
|
|
#include <pybind11/stl.h>
|
2017-09-14 05:41:54 +00:00
|
|
|
|
|
|
|
|
// Shorthand alias used by the casters and helpers below.
namespace py = pybind11;
|
|
|
|
|
|
2021-02-04 09:32:42 +00:00
|
|
|
namespace torch {
namespace jit {

// This is a variant of shared_ptr that "sees through" a wrapper.
// We use it to convert Value, Node, Block and node to "wrapped" Python
// values. When we destruct the C++ object, the wrapper's pointer will
// be set to 0 and any future dereferencing will throw. We need this
// because the Python objects may hang around after the C++ object
// has already been destroyed.
// This also needs the magic type_caster below, which is from the
// workaround offered in https://github.com/pybind/pybind11/issues/2751
template <typename T>
class unwrapping_shared_ptr {
  static_assert(
      std::is_same<T, torch::jit::Value>::value ||
          std::is_same<T, torch::jit::Node>::value ||
          std::is_same<T, torch::jit::Block>::value,
      "unwrapping type only defined for Graph object types");

 private:
  // Shared handle to the wrapper; Wrap<T>::elem is nulled out when the
  // underlying C++ object dies (semantics of Wrap<T> are defined
  // elsewhere in the project — see torch::jit::Wrap).
  std::shared_ptr<torch::jit::Wrap<T>> impl;

 public:
  // Empty holder; get() on a default-constructed instance would
  // dereference a null impl — callers are expected not to do that.
  unwrapping_shared_ptr() : impl({}) {}
  // Wraps a live object. The clear callback detaches any registered
  // pybind instances when the C++ side invalidates the wrapper.
  explicit unwrapping_shared_ptr(T* p) : impl(p->wrap()) {
    impl->clear_cb = &clear_registered_instances;
  }
  // Returns the wrapped pointer, or throws if the C++ object has
  // already been destroyed (wrapper pointer cleared to 0).
  T* get() const {
    if (!impl->elem) {
      throw std::logic_error("has been invalidated");
    }
    return impl->elem;
  }
  // we need to disable the overloaded & for PyBind11 < 2.3 due.
  // see https://github.com/pybind/pybind11/pull/1435
#if (PYBIND11_VERSION_MAJOR > 2) || \
    ((PYBIND11_VERSION_MAJOR == 2) && (PYBIND11_VERSION_MINOR >= 3))
  // Address-of yields the slot holding the wrapped pointer; throws on an
  // invalidated wrapper, same as get().
  T** operator&() {
    if (!impl->elem) {
      throw std::logic_error("has been invalidated");
    }
    return &(impl->elem);
  }
#endif
};

} // namespace jit
} // namespace torch

// Register unwrapping_shared_ptr as a pybind11 holder type so bound
// methods can return/accept it; `true` marks it always-constructible
// from a raw pointer.
PYBIND11_DECLARE_HOLDER_TYPE(T, torch::jit::unwrapping_shared_ptr<T>, true);
|
|
|
|
|
|
|
|
|
|
namespace pybind11 {
|
|
|
|
|
namespace detail {
|
|
|
|
|
|
|
|
|
|
// Defines a pybind11 type_caster for Class that loads values through the
// unwrapping_shared_ptr holder declared above, so accessing a Python
// object whose C++ counterpart has been destroyed raises instead of
// dereferencing freed memory. Part of the workaround from
// https://github.com/pybind/pybind11/issues/2751.
#define CREATE_UNWRAPPING_CASTER(Class)                                 \
  template <>                                                           \
  struct type_caster<Class> : public type_caster_base<Class> {          \
   public:                                                              \
    using type = Class;                                                 \
    using holder_type = torch::jit::unwrapping_shared_ptr<Class>;       \
                                                                        \
    bool load(handle src, bool convert) {                               \
      return load_impl<type_caster<Class>>(src, convert);               \
    }                                                                   \
                                                                        \
    explicit operator type*() {                                         \
      return static_cast<type*>(value);                                 \
    }                                                                   \
    explicit operator type&() {                                         \
      return *static_cast<type*>(value);                                \
    }                                                                   \
                                                                        \
   protected:                                                           \
    friend class type_caster_generic;                                   \
                                                                        \
    /* Called by pybind11's generic loader: pull the raw pointer out of \
       the holder (which throws if invalidated). */                     \
    bool load_value(value_and_holder&& v_h) {                           \
      if (v_h.holder_constructed()) {                                   \
        value = v_h.template holder<holder_type>().get();               \
        return true;                                                    \
      } else {                                                          \
        throw cast_error(                                               \
            "Unable to cast from non-held to held instance (#Class& to Holder<#Class>)"); \
      }                                                                 \
    }                                                                   \
  }

CREATE_UNWRAPPING_CASTER(torch::jit::Node);
CREATE_UNWRAPPING_CASTER(torch::jit::Value);
CREATE_UNWRAPPING_CASTER(torch::jit::Block);

#undef CREATE_UNWRAPPING_CASTER
|
|
|
|
|
|
|
|
|
|
} // namespace detail
|
|
|
|
|
} // namespace pybind11
|
|
|
|
|
|
2018-12-26 14:52:25 +00:00
|
|
|
namespace pybind11 {
|
|
|
|
|
namespace detail {
|
2017-09-14 05:41:54 +00:00
|
|
|
|
2018-12-26 14:52:25 +00:00
|
|
|
template <>
|
|
|
|
|
struct type_caster<torch::jit::IValue> {
|
|
|
|
|
public:
|
Make PyTorch code-base clang-tidy compliant (#56892)
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
2021-04-28 21:09:06 +00:00
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
2018-08-22 22:21:04 +00:00
|
|
|
PYBIND11_TYPE_CASTER(torch::jit::IValue, _("IValue"));
|
|
|
|
|
|
|
|
|
|
bool load(handle src, bool) {
|
2018-09-11 12:56:17 +00:00
|
|
|
try {
|
2019-10-16 18:05:32 +00:00
|
|
|
value = torch::jit::toTypeInferredIValue(src);
|
2018-09-11 12:56:17 +00:00
|
|
|
return true;
|
|
|
|
|
} catch (std::exception& e) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
2018-08-22 22:21:04 +00:00
|
|
|
}
|
|
|
|
|
|
2018-12-26 14:52:25 +00:00
|
|
|
static handle cast(
|
|
|
|
|
torch::jit::IValue src,
|
|
|
|
|
return_value_policy /* policy */,
|
|
|
|
|
handle /* parent */) {
|
2018-10-16 08:23:08 +00:00
|
|
|
return torch::jit::toPyObject(std::move(src)).release();
|
2018-08-22 22:21:04 +00:00
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2018-12-26 14:52:25 +00:00
|
|
|
template <>
|
|
|
|
|
struct type_caster<torch::jit::Symbol> {
|
|
|
|
|
public:
|
Make PyTorch code-base clang-tidy compliant (#56892)
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
2021-04-28 21:09:06 +00:00
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
2017-09-14 05:41:54 +00:00
|
|
|
PYBIND11_TYPE_CASTER(torch::jit::Symbol, _("Symbol"));
|
|
|
|
|
|
|
|
|
|
bool load(handle src, bool) {
|
Namespaced symbols (#5820)
* Namespaced symbols
- Our interned strings now have structure, "ns::symname" rather than just
"symname" before. We support efficient namespace testing for uniques
by encoding the namespace in one byte in the Symbol internal representation.
See torch/csrc/jit/interned_strings.h for a more in-depth implementation
discussion.
- All uses of ksymbol are now attr::symbol (or some appropriate namespace).
The valid namespaces are prim, attr, onnx and aten.
- Symbol is bound in Python as a qualified string "attr::symbol", EXCEPT for the
attribute setting/getting API, whose symbols must always be attr
symbols; they get special cased to assume strings are passed.
There's a little bit of naughtiness in the implementation, maybe you know
how to solve it.
- However, the g.op() convenience function assumes that you're generating
ONNX operators, unless you explicitly qualify.
- All ATen operators and nodes have built-in interned strings generated
for them, so you should never have to write a string literal ever again.
The tracing code is adjusted to use it.
- ONNX exporter now properly tests to see that all operators are in
onnx namespace before accepting the export. This is way more
robust than the previous exporter, which would be willing to
export capitalized operators which were not actually ONNX operators.
- A slight organizational change for symbolic.py; this module now ONLY
contains aten operators. In particular, the exporter for Constant
has moved into utils.py (along with Undefined, from the C++ side),
since primitive ops get "special treatment."
- The un-inplacing logic in recording is more robust, so that we don't
delete a trailing underscore from __and__. This never affected us
before because we didn't have any tests for it.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-03-16 17:36:11 +00:00
|
|
|
// TODO: Is there a way to py::cast that doesn't raise an exception on
|
|
|
|
|
// failure? Can we catch pybind11::cast_error here instead?
|
|
|
|
|
std::string src_str;
|
2017-09-14 05:41:54 +00:00
|
|
|
try {
|
Namespaced symbols (#5820)
* Namespaced symbols
- Our interned strings now have structure, "ns::symname" rather than just
"symname" before. We support efficient namespace testing for uniques
by encoding the namespace in one byte in the Symbol internal representation.
See torch/csrc/jit/interned_strings.h for a more in-depth implementation
discussion.
- All uses of ksymbol are now attr::symbol (or some appropriate namespace).
The valid namespaces are prim, attr, onnx and aten.
- Symbol is bound in Python as a qualified string "attr::symbol", EXCEPT for the
attribute setting/getting API, whose symbols must always be attr
symbols; they get special cased to assume strings are passed.
There's a little bit of naughtiness in the implementation, maybe you know
how to solve it.
- However, the g.op() convenience function assumes that you're generating
ONNX operators, unless you explicitly qualify.
- All ATen operators and nodes have built-in interned strings generated
for them, so you should never have to write a string literal ever again.
The tracing code is adjusted to use it.
- ONNX exporter now properly tests to see that all operators are in
onnx namespace before accepting the export. This is way more
robust than the previous exporter, which would be willing to
export capitalized operators which were not actually ONNX operators.
- A slight organizational change for symbolic.py; this module now ONLY
contains aten operators. In particular, the exporter for Constant
has moved into utils.py (along with Undefined, from the C++ side),
since primitive ops get "special treatment."
- The un-inplacing logic in recording is more robust, so that we don't
delete a trailing underscore from __and__. This never affected us
before because we didn't have any tests for it.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-03-16 17:36:11 +00:00
|
|
|
src_str = py::cast<std::string>(src);
|
2017-09-14 05:41:54 +00:00
|
|
|
} catch (std::exception& e) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
Namespaced symbols (#5820)
* Namespaced symbols
- Our interned strings now have structure, "ns::symname" rather than just
"symname" before. We support efficient namespace testing for uniques
by encoding the namespace in one byte in the Symbol internal representation.
See torch/csrc/jit/interned_strings.h for a more in-depth implementation
discussion.
- All uses of ksymbol are now attr::symbol (or some appropriate namespace).
The valid namespaces are prim, attr, onnx and aten.
- Symbol is bound in Python as a qualified string "attr::symbol", EXCEPT for the
attribute setting/getting API, whose symbols must always be attr
symbols; they get special cased to assume strings are passed.
There's a little bit of naughtiness in the implementation, maybe you know
how to solve it.
- However, the g.op() convenience function assumes that you're generating
ONNX operators, unless you explicitly qualify.
- All ATen operators and nodes have built-in interned strings generated
for them, so you should never have to write a string literal ever again.
The tracing code is adjusted to use it.
- ONNX exporter now properly tests to see that all operators are in
onnx namespace before accepting the export. This is way more
robust than the previous exporter, which would be willing to
export capitalized operators which were not actually ONNX operators.
- A slight organizational change for symbolic.py; this module now ONLY
contains aten operators. In particular, the exporter for Constant
has moved into utils.py (along with Undefined, from the C++ side),
since primitive ops get "special treatment."
- The un-inplacing logic in recording is more robust, so that we don't
delete a trailing underscore from __and__. This never affected us
before because we didn't have any tests for it.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-03-16 17:36:11 +00:00
|
|
|
value = torch::jit::Symbol::fromQualString(src_str);
|
2017-09-14 05:41:54 +00:00
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
2018-12-26 14:52:25 +00:00
|
|
|
static handle cast(
|
|
|
|
|
torch::jit::Symbol src,
|
|
|
|
|
return_value_policy /* policy */,
|
|
|
|
|
handle /* parent */) {
|
|
|
|
|
return py::cast(std::string(src.toQualString()), return_value_policy::copy)
|
|
|
|
|
.release();
|
2017-09-14 05:41:54 +00:00
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2018-12-26 14:52:25 +00:00
|
|
|
template <>
|
|
|
|
|
struct type_caster<torch::jit::AttributeKind> {
|
|
|
|
|
public:
|
Make PyTorch code-base clang-tidy compliant (#56892)
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os
def get_compiled_files_list():
import json
with open("build/compile_commands.json") as f:
data = json.load(f)
files = [os.path.relpath(node['file']) for node in data]
for idx, fname in enumerate(files):
if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
return files
def run_clang_tidy(fname):
check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
changes = check_output(["git", "ls-files", "-m"])
if len(changes) == 0:
return
check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])
def main():
git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
compiled_files = get_compiled_files_list()
for idx, fname in enumerate(git_files):
if fname not in compiled_files:
continue
if fname.startswith("caffe2/contrib/aten/"):
continue
print(f"[{idx}/{len(git_files)}] Processing {fname}")
run_clang_tidy(fname)
if __name__ == "__main__":
main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892
Reviewed By: H-Huang
Differential Revision: D27991944
Pulled By: malfet
fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
2021-04-28 21:09:06 +00:00
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
2017-09-14 05:41:54 +00:00
|
|
|
PYBIND11_TYPE_CASTER(torch::jit::AttributeKind, _("AttributeKind"));
|
|
|
|
|
|
|
|
|
|
bool load(handle src, bool) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
2018-12-26 14:52:25 +00:00
|
|
|
static handle cast(
|
|
|
|
|
torch::jit::AttributeKind src,
|
|
|
|
|
return_value_policy /* policy */,
|
|
|
|
|
handle /* parent */) {
|
|
|
|
|
return py::cast(
|
|
|
|
|
std::string(torch::jit::toString(src)),
|
|
|
|
|
return_value_policy::copy)
|
|
|
|
|
.release();
|
2017-09-14 05:41:54 +00:00
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
Improve const-correctness of JIT.
This started off as a minor fix based on Adam's question, "why is printing
a graph not const" and snowballed into a giant yak shaving exercise.
- The Graph and Node APIs now uniformly enforce deep constness; e.g., if you
get a const Node* or const Graph*, it is not possible to get a non-const
Node*/Graph* somewhere else in the graph (even though the member variables
of these are non-const. Hooray for private access specifier.)
- A big pile of functions got const versions, most notably the printing
functions, and functions for accessing inputs().
- REALLY IMPORTANT, BC-BREAKING CHANGE: inputs() now returns a COPY of the
inputs, rather than a reference to the underlying. I was forced to do this
because there is no way to portably turn a std::vector<Node*> into a
std::vector<const Node*>, which is necessary to provide a const-correct
version of inputs() that enforces deep const-correctness. I then justified
this choice to myself with the observation that outputs() returned a
copy (by necessity), so this makes the API more uniform.
But making this change uncovered two very subtle bugs:
1. If you change functions from returning a reference to returning a copy,
the idiom node->inputs().begin() is no longer valid, because the memory
the iterator points to immediately becomes invalid. THIS SUCKS.
Honestly, we should add a lint rule rejecting calling begin()/end() on
temporaries because this is very dangerous. To excise this pattern from
the codebase, I added begin() and end() methods to Graph, so that we got
rid of the graph->nodes().begin() idiom, which happens to be sound,
despite not returning a reference, because graph_node_list is a
non-owning reference.
2. pybind11 doesn't handle std::vector<Node*> cast out of the box.
Fortunately, I found a simple fix in the GitHub issues tracker
that involved adding an extra type converter. And yes, this
does mean that outputs() in Python never worked correctly.
- New const_graph_node_list, which is a graph_node_list that gives you const
Node*
There are some more miscellaneous improvements:
- Applied CR comment fixes on export.cpp; using replaceInput, and renaming
variables for clarity.
- assertValidInput helper method added, and applied to replaceInput
- Use an explicit function to print THPObjectPtr, otherwise we get
the wrong overload.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-10-30 03:24:58 +00:00
|
|
|
// See https://github.com/pybind/pybind11/issues/637
// Without this specialization pybind11 cannot cast std::vector<Node*>
// out of the box (Node* elements need reference, not copy, semantics).
using ListCasterBase = pybind11::detail::
    list_caster<std::vector<torch::jit::Node*>, torch::jit::Node*>;
template <>
struct type_caster<std::vector<torch::jit::Node*>> : ListCasterBase {
  // Force reference semantics for the Node* elements regardless of the
  // policy the caller requested: Python must not take ownership of
  // graph-owned nodes.
  static handle cast(
      const std::vector<torch::jit::Node*>& src,
      return_value_policy,
      handle parent) {
    return ListCasterBase::cast(src, return_value_policy::reference, parent);
  }
  // Pointer overload simply delegates to the reference overload above.
  static handle cast(
      const std::vector<torch::jit::Node*>* src,
      return_value_policy pol,
      handle parent) {
    return cast(*src, pol, parent);
  }
};
|
|
|
|
|
|
2018-12-26 14:52:25 +00:00
|
|
|
} // namespace detail
|
|
|
|
|
} // namespace pybind11
|
2018-02-03 01:45:59 +00:00
|
|
|
|
2018-12-26 14:52:25 +00:00
|
|
|
namespace torch {
|
|
|
|
|
namespace jit {
|
2018-02-03 01:45:59 +00:00
|
|
|
|
2018-12-26 14:52:25 +00:00
|
|
|
static inline py::tuple tuple_tail(const py::tuple& tup) {
|
2018-02-03 01:45:59 +00:00
|
|
|
py::tuple r(tup.size() - 1);
|
2021-07-28 20:28:39 +00:00
|
|
|
for (const auto i : c10::irange(1, tup.size())) {
|
2018-12-26 14:52:25 +00:00
|
|
|
r[i - 1] = tup[i];
|
2018-02-03 01:45:59 +00:00
|
|
|
}
|
|
|
|
|
return r;
|
|
|
|
|
}
|
|
|
|
|
|
2018-12-26 14:52:25 +00:00
|
|
|
} // namespace jit
|
|
|
|
|
} // namespace torch
|