Summary:
Added `nn::Module::as<T>()`, which `dynamic_cast`s a module to `T` and returns a pointer to it, or `nullptr` if the module is not a `T`. `nn::Module::is<T>()` only checked whether that `dynamic_cast` returned `nullptr`, so I removed `is<T>()`: it is equivalent to `as<T>() != nullptr`, or simply `as<T>()` thanks to the pointer's implicit conversion to `bool`.
Type-dependent initialization now looks like this:
```
if (auto* conv = module.as<nn::Conv2d>()) {
  conv->weight.data().normal_(0.0, 0.02);
} else if (auto* bn = module.as<nn::BatchNorm>()) {
  bn->weight.data().normal_(1.0, 0.02);
  bn->bias.data().fill_(0);
}
```
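To make the cast semantics concrete, here is a minimal standalone model of the pattern, assuming nothing more than a polymorphic base class. `Module`, `Conv2d`, `BatchNorm`, `is` and `init` are hypothetical stand-ins written for illustration; they are not the libtorch classes, and the `double` fields only stand in for the weight and bias tensors.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for nn::Module, nn::Conv2d and nn::BatchNorm.
// These are NOT the libtorch classes; they only model the cast pattern.
struct Module {
  virtual ~Module() = default;  // a polymorphic base makes dynamic_cast legal

  // Returns this module viewed as a T, or nullptr if it is not a T.
  template <typename T>
  T* as() noexcept {
    return dynamic_cast<T*>(this);
  }

  template <typename T>
  const T* as() const noexcept {
    return dynamic_cast<const T*>(this);
  }
};

struct Conv2d : Module {
  double weight_mean = -1.0;  // placeholder for the weight tensor
};

struct BatchNorm : Module {
  double weight_mean = -1.0;  // placeholder for the weight tensor
  double bias = -1.0;         // placeholder for the bias tensor
};

// The removed is<T>() reduces to a null check on as<T>().
template <typename T>
bool is(const Module& module) noexcept {
  return module.as<T>() != nullptr;
}

// Mirrors the if / else-if initialization above and reports which
// branch was taken.
std::string init(Module& module) {
  if (auto* conv = module.as<Conv2d>()) {
    conv->weight_mean = 0.0;
    return "conv";
  } else if (auto* bn = module.as<BatchNorm>()) {
    bn->weight_mean = 1.0;
    bn->bias = 0.0;
    return "bn";
  }
  return "other";
}
```

Calling `init` on a `Conv2d` takes the first branch, on a `BatchNorm` the second, and on a plain `Module` neither. `is<Conv2d>(m)` is true exactly when `m.as<Conv2d>()` is non-null, which is the equivalence that makes a separate `is<T>()` redundant.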
ezyang apaszke ebetica
Closes https://github.com/pytorch/pytorch/pull/9149
Differential Revision: D8735954
Pulled By: goldsborough
fbshipit-source-id: e2b8f6f0cea16a621f8bc0807a33cc7651d25154