#include <torch/csrc/jit/api/module.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>

namespace torch {
namespace jit {

static ObjectPtr create_module_object(
    c10::QualifiedName class_name,
    std::shared_ptr<CompilationUnit> cu,
    bool shouldMangle = false) {
  // If the name is unqualified, prepend a `__torch__`, similar to what Python
  // does with `__main__` for top-level code.
  if (class_name.prefix().empty()) {
    class_name = c10::QualifiedName("__torch__", class_name.name());
  }
  if (shouldMangle && cu->get_class(class_name) != nullptr) {
    class_name = cu->mangle(class_name);
  }
  auto cls = ClassType::create(std::move(class_name), cu, /*is_module=*/true);
  cu->register_type(cls);
  return c10::ivalue::Object::create(
      c10::StrongTypePtr(std::move(cu), std::move(cls)), 0);
}

Module::Module(c10::QualifiedName class_name)
    : Object(create_module_object(
          std::move(class_name),
          std::make_shared<CompilationUnit>())) {}

Module::Module(
    std::shared_ptr<CompilationUnit> cu,
    const c10::ClassTypePtr& type)
    : Object(c10::ivalue::Object::create(
          c10::StrongTypePtr(std::move(cu), type),
          type->numAttributes())) {}

Module::Module(
    c10::QualifiedName class_name,
    std::shared_ptr<CompilationUnit> cu,
    bool shouldMangle)
    : Object(create_module_object(
          std::move(class_name),
          std::move(cu),
          shouldMangle)) {}

// First-class mode runs models as first-class objects and does not force
// inlining everywhere. This is experimental while we bring up the system,
// since it will degrade performance and may introduce bugs. test_jit.py
// provides context managers that enable it for specific tests.
thread_local bool inline_everything = false;
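// Note: the flag is exposed by reference, so callers toggle it with
// `getInlineEverythingMode() = true;` (illustrative).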
bool& getInlineEverythingMode() {
  return inline_everything;
}
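
// Move/convert the module's parameters and buffers (see to_impl below).
// Illustrative usage, assuming the defaults declared in the header:
//   module.to(at::kCUDA);
//   module.to(at::kHalf, /*non_blocking=*/true);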
void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) {
  to_impl(device, dtype, non_blocking);
}

void Module::to(at::ScalarType dtype, bool non_blocking) {
  to_impl(/*device=*/c10::nullopt, dtype, non_blocking);
}

void Module::to(at::Device device, bool non_blocking) {
  to_impl(device, /*dtype=*/c10::nullopt, non_blocking);
}
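
// Moves a single parameter or buffer to the requested device/dtype in place
// via `Variable::set_data`, falling back to the tensor's current device or
// dtype when none is supplied.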
void module_state_to(
    autograd::Variable variable,
    const c10::optional<at::Device>& device,
    const c10::optional<at::ScalarType>& dtype,
    bool non_blocking) {
  // Need to access the `at::Tensor` as a `Variable` here.
  // Use the data's original device or dtype if not supplied here.
  auto new_data = variable.to(
      device.value_or(variable.device()),
      dtype.value_or(variable.scalar_type()),
      non_blocking);
  variable.set_data(new_data);
}

void Module::to_impl(
    const c10::optional<at::Device>& device,
    const c10::optional<at::ScalarType>& dtype,
    bool non_blocking) {
  for (at::Tensor e : parameters()) {
    module_state_to(e, device, dtype, non_blocking);
  }
  for (at::Tensor e : buffers()) {
    module_state_to(e, device, dtype, non_blocking);
  }
}

Method::Method(ModulePtr owner, Function* function)
    : owner_(std::move(owner)), function_(function) {}

Module Method::owner() const {
  return Module(owner_);
}
void Method::run(Stack& stack) {
  stack.insert(stack.begin(), owner()._ivalue());
  RECORD_TORCHSCRIPT_FUNCTION(name(), stack);
  function_->run(stack);
}
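
// Invoke the method with `self` (the owning module) prepended to the stack.
// Illustrative usage: `IValue out = module.get_method("forward")({input}, /*kwargs=*/{});`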
IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
  stack.insert(stack.begin(), owner()._ivalue());
  RECORD_TORCHSCRIPT_FUNCTION(name(), stack);
  return (*function_)(std::move(stack), kwargs);
}

void Module::clone_method(
    const Module& orig,
    const Function& method,
    const std::unordered_map<TypePtr, TypePtr>& type_remap) {
  // type remapping - when we copy method implementations from one module
  // singleton to another, we need to update the types of the self arguments
  // to match the new module.
  // XXX - this only handles modules that occur as variables, not modules
  // that appear in aggregate types. Currently this works fine because
  // we restrict how modules can be used during the lowering step. Eventually,
  // we will need to decide what it means for us to 'copy' a module.
  // For instance, we can copy just the state (parameters, attributes),
  // but share the code. Or we can copy the code. If we choose to copy the
  // code, what should we do about aggregate types that contain a module?
  auto type_remap_fn = [&](TypePtr in) {
    auto it = type_remap.find(in);
    if (it == type_remap.end())
      return in;
    return it->second;
  };
  auto graph = method.graph()->copy();
  graph->remapTypes(type_remap_fn);
  auto schema = method.getSchema().cloneWithRemappedTypes(type_remap_fn);
  const auto this_method_name = getNameForMethod(method.name());
  auto copied =
      _ivalue()->compilation_unit()->create_function(this_method_name, graph);
  type()->addMethod(copied);
  copied->setSchema(std::move(schema));
}
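
// Clone a single method by name. The old-type -> new-type map is built by
// walking the submodule trees of `orig` and of this module in lockstep with an
// explicit work list, so that `self` types can be remapped in the copied graph.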
void Module::clone_method(const Module& orig, const std::string& name) {
  std::unordered_map<TypePtr, TypePtr> type_remap;
  std::vector<std::pair<Module, Module>> to_scan = {{orig, *this}};
  while (!to_scan.empty()) {
    auto entry = to_scan.back();
    to_scan.pop_back();
    type_remap[entry.first._ivalue()->type()] = entry.second._ivalue()->type();
    for (const NameModule& s : entry.first.named_children()) {
      to_scan.emplace_back(
          s.value, Module(entry.second.attr(s.name).toObject()));
    }
  }
  return clone_method(orig, orig.get_method(name).function(), type_remap);
}
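
// copy() creates a new module instance that shares this module's ClassType and
// shallow-copies its slots; deepcopy() also deep-copies the slot values
// (including tensors); clone() additionally creates fresh ClassTypes and copies
// constants and methods (see clone_impl below).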
Module Module::copy() const {
  return Module(_ivalue()->copy());
}

Module Module::deepcopy() const {
  return Module(_ivalue()->deepcopy());
}

Module Module::clone(bool inplace) const {
  std::unordered_map<TypePtr, TypePtr> type_remap;
  IValue::HashAliasedIValueMap memo;
  return clone_impl(type_remap, inplace, memo);
}

Module Module::clone_impl(
    std::unordered_map<TypePtr, TypePtr>& type_remap,
    bool inplace,
    IValue::HashAliasedIValueMap memo) const {
  // Create a new _ivalue in the same compilation unit.
  // Since ClassTypes are now shared, we need to preserve that sharing during
  // cloning: first check whether the type has already been cloned. If so,
  // create the new module with the cloned ClassType; if not, create a new
  // module and a new ClassType.
  bool type_already_cloned = type_remap.find(type()) != type_remap.end();
  Module r;
  if (type_already_cloned) {
    // if we cloned the class type before, we'll reuse it
    Module new_module(
        _ivalue()->compilation_unit(), type_remap[type()]->cast<ClassType>());
    r = new_module;
  } else {
    Module new_module(*type()->name(), _ivalue()->compilation_unit(), true);
    r = new_module;
    type_remap[type()] = r.type();
  }

  // Copy slots. If a slot is a module, recursively clone it.
  size_t N = type()->numAttributes();
  for (size_t i = 0; i < N; ++i) {
    IValue s = _ivalue()->getSlot(i);
    std::string attr_name = type()->getAttributeName(i);
    TypePtr attr_type = type()->getAttribute(i);
    if (attr_type->is_module()) {
      const Module& orig = Module(s.toObject());
      Module cloned = orig.clone_impl(type_remap, inplace, memo);
      type_remap[orig.type()] = cloned.type();
      // NOTE: we manually setAttr on the object instead of using
      // register_module here because the attr can be a module interface
      // type and still hold a Module object. register_module would not let us
      // set up the type for this attr correctly, so we have to do it manually.
      // If it is an interface type, the type will be shared with the new
      // cloned instance in the same compilation unit, because it only contains
      // a list of FunctionSchemas.
      r.type()->addOrCheckAttribute(
          attr_name, attr_type->cast<ClassType>() ? cloned.type() : attr_type);
      r._ivalue()->setAttr(attr_name, cloned._ivalue());
    } else {
      // This adds a new slot and creates a new attribute on the underlying
      // type if the type has not been cloned yet; otherwise it only adds a new
      // slot and typechecks.
      r.register_attribute(
          type()->getAttributeName(i),
          attr_type,
          // we'll deepcopy the IValue in the non-inplace option
          inplace ? s : s.deepcopy(memo),
          type()->is_parameter(i),
          type()->is_buffer(i));
    }
  }

  // only clone the methods and constants if the ClassType was not cloned before
  if (!type_already_cloned) {
    // clone constants
    for (size_t i = 0; i < type()->numConstants(); ++i) {
      r.type()->addConstant(type()->getConstantName(i), type()->getConstant(i));
    }
    // clone methods, remapping the types to the cloned ones.
    for (auto& fn : type()->methods()) {
      r.clone_method(*this, *fn, type_remap);
    }
  }
  return r;
}
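
// Recursively set the `training` attribute on this module and every submodule.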
void Module::train(bool on) {
  for (Module m : modules()) {
    if (auto slot = m._ivalue()->type()->findAttributeSlot("training")) {
      m._ivalue()->setSlot(*slot, on);
    } else {
      TORCH_INTERNAL_ASSERT(false, "'training' attribute not found");
    }
  }
}
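
// Construct an instance of a class registered in this module's compilation
// unit and run its `__init__` with the given arguments, e.g. (illustratively)
// `module.create_class("__torch__.Foo", {arg0, arg1})`.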
IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
  // Look up the class
  const auto classType =
      _ivalue()->compilation_unit()->get_class(c10::QualifiedName(name));
  if (!classType) {
    AT_ERROR(
        "Could not find class with name: '",
        name.qualifiedName(),
        "' in module.");
  }

  // Create a bare object with correct number of slots
  const size_t numAttrs = classType->numAttributes();
  auto obj = c10::ivalue::Object::create(
      c10::StrongTypePtr(_ivalue()->compilation_unit(), classType), numAttrs);

  // Invoke the `__init__()` of the class with the arguments provided.
  Stack stackWithSelf = {obj};
  for (auto& arg : stack) {
    stackWithSelf.push_back(std::move(arg));
  }
  // Note: following Python, `__init__()` modifies its first parameter in-place
  // and returns nothing.
  classType->getMethod("__init__").operator()(std::move(stackWithSelf));

  return obj;
}

buffer_list Module::buffers(bool recurse) const {
  return buffer_list(*this, recurse, /*return_module=*/false);
}
named_buffer_list Module::named_buffers(bool recurse) const {
  return named_buffer_list(*this, recurse, /*return_module=*/false);
}

module_list Module::children() const {
  return module_list(*this, /*recurse=*/false, /*return_module=*/false);
}
named_module_list Module::named_children() const {
  return named_module_list(*this, /*recurse=*/false, /*return_module=*/false);
}
module_list Module::modules() const {
  return module_list(*this, /*recurse=*/true, /*return_module=*/true);
}
named_module_list Module::named_modules() const {
  return named_module_list(*this, /*recurse=*/true, /*return_module=*/true);
}

parameter_list Module::parameters(bool recurse) const {
  return parameter_list(*this, recurse, /*return_module=*/false);
}
named_parameter_list Module::named_parameters(bool recurse) const {
  return named_parameter_list(*this, recurse, /*return_module=*/false);
}

attribute_list Module::attributes(bool recurse) const {
  return attribute_list(*this, recurse, /*return_module=*/false);
}
named_attribute_list Module::named_attributes(bool recurse) const {
  return named_attribute_list(*this, recurse, /*return_module=*/false);
}

void Module::apply(const std::function<void(Module&)>& fn) {
  for (Module s : modules()) {
    fn(s);
  }
}
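
// Render the module as a human-readable string of the form
//   module <qualified name> { parameters {...} attributes {...} methods {...} submodules {...} }
// where `level` controls the indentation of nested submodules.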
std::string Module::dump_to_str(
    bool print_method_bodies,
    bool print_attr_values,
    bool print_param_values,
    int level = 0) const {
  std::stringstream ss;
  std::stringstream parameters_ss;
  std::stringstream attributes_ss;
  std::stringstream methods_ss;
  std::stringstream submodules_ss;

  for (const NameTensor& p : named_parameters(/*recurse=*/false)) {
    parameters_ss << p.name << " = ";
    if (print_param_values) {
      parameters_ss << p.value << std::endl;
    } else {
      parameters_ss << "..." << std::endl;
    }
  }

  for (const NameValue& p : named_attributes(/*recurse=*/false)) {
    attributes_ss << p.name << " = ";
    if (!p.value.isTensor() || print_attr_values) {
      attributes_ss << p.value << std::endl;
    } else {
      attributes_ss << "..." << std::endl;
    }
  }

  for (const Method& method : get_methods()) {
    methods_ss << " method " << method.name() << " {" << std::endl;
    if (print_method_bodies) {
      methods_ss << torch::jit::jit_log_prefix(
                        " ", method.graph()->toString())
                 << std::endl;
    }
    methods_ss << " }" << std::endl;
  }

  ss << "module " << type()->name()->qualifiedName() << " {" << std::endl;
  ss << " parameters {" << std::endl;
  ss << torch::jit::jit_log_prefix(" ", parameters_ss.str());
  ss << " }" << std::endl;
  ss << " attributes {" << std::endl;
  ss << torch::jit::jit_log_prefix(" ", attributes_ss.str());
  ss << " }" << std::endl;
  ss << " methods {" << std::endl;
  ss << torch::jit::jit_log_prefix(" ", methods_ss.str());
  ss << " }" << std::endl;
  ss << " submodules {" << std::endl;
  for (const NameModule& s : named_children()) {
    // We do level + 2 because one level of indentation comes from the
    // 'submodules' scope and the other comes from the specific submodule we
    // are printing.
    ss << s.value.dump_to_str(
        print_method_bodies, print_attr_values, print_param_values, level + 2);
  }
  ss << " }" << std::endl;
  ss << "}" << std::endl;

  std::string indent(2 * level, ' ');
  return torch::jit::jit_log_prefix(indent, ss.str());
}

void Module::dump(
    bool print_method_bodies = true,
    bool print_attr_values = true,
    bool print_param_values = true) const {
  std::cout << dump_to_str(
                   print_method_bodies, print_attr_values, print_param_values)
            << std::endl;
}

} // namespace jit
} // namespace torch

namespace c10 {
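
// Allow a script module stored in an IValue (e.g. as an attribute of another
// module) to be unwrapped back into a torch::jit::Module.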
torch::jit::Module IValue::toModule() const {
  return torch::jit::Module(toObject());
}
bool IValue::isModule() const {
  return isObject() && toObjectRef().type()->is_module();
}
} // namespace c10