pytorch/test/cpp/api/optim.cpp

#include <gtest/gtest.h>

#include <torch/torch.h>

#include <test/cpp/api/optim_baseline.h>
#include <test/cpp/api/support.h>

#include <cmath>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

using namespace torch::nn;
using namespace torch::optim;

template <typename OptimizerClass, typename Options>
bool test_optimizer_xor(Options options) {
  torch::manual_seed(0);

  Sequential model(
      Linear(2, 8),
      Functional(torch::sigmoid),
      Linear(8, 1),
      Functional(torch::sigmoid));

  const int64_t kBatchSize = 200;
  const int64_t kMaximumNumberOfEpochs = 3000;

  OptimizerClass optimizer(model->parameters(), options);

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto inputs = torch::empty({kBatchSize, 2});
    auto labels = torch::empty({kBatchSize});
    // Use a signed index so the comparison with kBatchSize (int64_t) does not
    // mix signed and unsigned types.
    for (int64_t i = 0; i < kBatchSize; i++) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    inputs.set_requires_grad(true);

    auto step = [&](OptimizerClass& optimizer,
                    Sequential model,
                    torch::Tensor inputs,
                    torch::Tensor labels) {
      auto closure = [&]() {
        optimizer.zero_grad();
        auto x = model->forward(inputs);
        auto loss = torch::binary_cross_entropy(x, labels);
        loss.backward();
        return loss;
      };
      return optimizer.step(closure);
    };

    torch::Tensor loss = step(optimizer, model, inputs, labels);

    running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
    if (epoch > kMaximumNumberOfEpochs) {
      std::cout << "Loss is too high after epoch " << epoch << ": "
                << running_loss << std::endl;
      return false;
    }
    epoch++;
  }
  return true;
}
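
// Illustrative sketch, not part of the original test. The loop in
// test_optimizer_xor stops once an exponential moving average of the loss,
// seeded at 1, is no longer above 0.1. Assuming a hypothetical constant
// per-step loss, this helper counts how many updates that takes, which gives
// a rough lower bound on the epochs any optimizer needs to pass. The function
// name and all of its parameters are assumptions made for illustration; the
// defaults mirror the constants used in the test above.
inline int ema_steps_until_converged(
    double per_step_loss,
    double threshold = 0.1,
    double decay = 0.99,
    int max_steps = 3000) {
  double running_loss = 1.0;
  int steps = 0;
  while (running_loss > threshold && steps <= max_steps) {
    // Same update rule as the test: keep `decay` of the old average and
    // blend in (1 - decay) of the newest loss.
    running_loss = running_loss * decay + per_step_loss * (1.0 - decay);
    ++steps;
  }
  // -1 signals that the average never reached the threshold within the budget.
  return running_loss > threshold ? -1 : steps;
}
// With a per-step loss of exactly 0 the average is 0.99^n, so the helper
// returns 230; with a constant loss of 0.2 the average never drops below 0.2
// and the helper returns -1.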

template <typename Parameters>
void assign_parameter(
    const Parameters& parameters,
    const char* name,
    torch::Tensor new_tensor) {
  auto parameter = parameters[name];
  parameter.set_requires_grad(false);
  parameter.flatten().copy_(new_tensor);
  parameter.set_requires_grad(true);
}

template <typename OptimizerClass, typename Options>
void check_exact_values(
    Options options,
    std::vector<std::vector<torch::Tensor>> expected_parameters) {
  const size_t kIterations = 1001;
  const size_t kSampleEvery = 100;

  torch::manual_seed(0);

  Sequential model(
      Linear(2, 3),
      Functional(torch::sigmoid),
      Linear(3, 1),
      Functional(torch::sigmoid));

  model->to(torch::kFloat64);

  // Use exact input values because matching random values is hard.
  auto parameters = model->named_parameters();
  assign_parameter(
      parameters,
      "0.weight",
      torch::tensor({-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976}, torch::kFloat64));
  assign_parameter(
      parameters, "0.bias", torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64));
  assign_parameter(
      parameters, "2.weight", torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64));
  assign_parameter(parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));

  auto optimizer = OptimizerClass(parameters.values(), options);
  torch::Tensor input =
      torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64).reshape({3, 2});

  for (size_t i = 0; i < kIterations; ++i) {
    optimizer.zero_grad();
    auto output = model->forward(input);
    auto loss = output.sum();
    loss.backward();

    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);

    if (i % kSampleEvery == 0) {
      ASSERT_TRUE(
          expected_parameters.at(i / kSampleEvery).size() == parameters.size());
      for (size_t p = 0; p < parameters.size(); ++p) {
        ASSERT_TRUE(parameters[p]->defined());
        // Always compare using double dtype, regardless of the original dtype
        // of the tensors.
        auto computed = parameters[p]->flatten().to(torch::kFloat64);
        auto expected = expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
        if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
          std::cout << "Iteration " << i << ": " << computed
                    << " != " << expected << " (parameter " << p << ")"
                    << std::endl;
          ASSERT_TRUE(false);
        }
      }
    }
  }
}

TEST(OptimTest, OptimizerAccessors) {
  auto options = AdagradOptions(1.0);
  std::vector<torch::Tensor> params;
  for (size_t i = 0; i < 3; i++) {
    params.push_back(torch::randn(10));
  }
  auto optimizer = Adagrad(params, options);
  // test for defaults() method with non-const reference
  auto& options_ = static_cast<AdagradOptions&>(optimizer.defaults());
  ASSERT_TRUE(options == options_);
  // test for param_groups() with non-const reference return
  auto& params_groups = optimizer.param_groups();
  params_groups.push_back(OptimizerParamGroup(params));
  auto& params_1 = params_groups[1].params();
  for (size_t i = 0; i < params_1.size(); i++) {
    // Assert on the comparison; a bare torch::equal call discards its result.
    ASSERT_TRUE(torch::equal(params[i], params_1[i]));
  }

  // test for add_param_group() when one or more params that already exist in
  // another param_group are passed in the new param group to be added
  ASSERT_THROWS_WITH(
      optimizer.add_param_group(OptimizerParamGroup(params)),
      "some parameters appear in more than one parameter group");

  // test for state() with non-const reference return
  auto& state_ = static_cast<AdagradParamState&>(
      *(optimizer.state()[c10::guts::to_string(params_1[0].unsafeGetTensorImpl())]));
  state_.step(state_.step() + 1);

  const auto& optimizer_ = Adagrad(params, options);
  optimizer_.defaults();
  // test for param_groups() with const reference return
  const auto& params_2 = optimizer_.param_groups();
  // test for state() with const reference return
  optimizer_.state();
}

#define OLD_INTERFACE_WARNING_CHECK(func) \
  { \
    torch::test::WarningCapture warnings; \
    func; \
    ASSERT_EQ( \
        torch::test::count_substr_occurrences( \
            warnings.str(), "will be removed"), \
        1); \
  }

struct MyOptimizerOptions : public OptimizerCloneableOptions<MyOptimizerOptions> {
  MyOptimizerOptions(double lr = 1.0) : lr_(lr) {}
  TORCH_ARG(double, lr) = 1.0;
};

TEST(OptimTest, OldInterface) {
  struct MyOptimizer : Optimizer {
    using Optimizer::Optimizer;
    torch::Tensor step(LossClosure closure = nullptr) override {
      return {};
    }
    // The param group is already a temporary, so no std::move is needed.
    explicit MyOptimizer(
        std::vector<at::Tensor> params,
        MyOptimizerOptions defaults = {})
        : Optimizer(
              {OptimizerParamGroup(params)},
              std::make_unique<MyOptimizerOptions>(defaults)) {}
  };
  std::vector<torch::Tensor> parameters = {
      torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
  {
    MyOptimizer optimizer(parameters);
    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());
  }
  {
    std::vector<at::Tensor> params;
    MyOptimizer optimizer(params);

    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, 0);

    OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));

    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, parameters.size());

    std::vector<torch::Tensor> params_;
    OLD_INTERFACE_WARNING_CHECK(params_ = optimizer.parameters());
    for (size_t p = 0; p < size; ++p) {
      ASSERT_TRUE(params_[p].allclose(parameters[p]));
    }
  }
  {
    Linear linear(3, 4);
    MyOptimizer optimizer(linear->parameters());

    size_t size;
    OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
    ASSERT_EQ(size, linear->parameters().size());
  }
}

TEST(OptimTest, XORConvergence_SGD) {
  ASSERT_TRUE(test_optimizer_xor<SGD>(
      SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_LBFGS) {
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(LBFGSOptions(1.0)));
  ASSERT_TRUE(test_optimizer_xor<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe")));
}

TEST(OptimTest, XORConvergence_Adagrad) {
  ASSERT_TRUE(test_optimizer_xor<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
}

TEST(OptimTest, XORConvergence_RMSprop) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
}

TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(
      RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_Adam) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(
      AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
}

TEST(OptimTest, ProducesPyTorchValues_Adam) {
  check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
  check_exact_values<Adam>(
AdamOptions(1.0).weight_decay(1e-2),
expected_parameters::Adam_with_weight_decay());
Remove use of data() in optimizers (#10490) Summary: After talking to users of the C++ API we found that having the tensor type be `autograd::Variable` causes more complications than having it be `at::Tensor`. It used to be a problem because `at::Tensor` didn't have the "autograd API" of variable (e.g. `detach()` or `grad()` methods), but those methods are now on `at::Tensor`. As such, we want to make a last big breaking change to have the tensor type be `at::Tensor`, while factory methods like `torch::ones` will return `Variable`s disguised as `at::Tensor`. This will make many things easier, like calling functions in ATen that take vectors of tensors. This PR makes a small step in this direction by updating the optimizer classes to not use `.data()` on `Variable` to access the underlying `at::Tensor`. Using `.data()` is effectively a hack to work around our modification rules for tensors that require grad. The proper way of doing things is to use `with torch.no_grad` or equivalently `NoGradGuard` in C++ to guard in-place operations. The next step can then simply redefine `torch::Tensor` to be `at::Tensor`. This transition should be smooth, since all methods available on `Variable` are at this point available on `at::Tensor`. For this PR I: 1. Modified the implementations of optimizers to not use `.data()`. This means the implementations are now different from PyTorch, which still uses the legacy method of using `.data`. 2. To properly verify (1), I added more fine-grained test cases to our optimizer tests, e.g. `SGD` with and without `weight_decay`, then with `nesterov` etc. Generally more tests = more happy! 3. Minor cleanup of the optimizer codebase ebetica apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/10490 Differential Revision: D9318229 Pulled By: goldsborough fbshipit-source-id: fb386700f37840542bc5d323f308ea88fe5ea5c5
2018-08-14 19:58:06 +00:00
}
TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
      expected_parameters::Adam_with_weight_decay_and_amsgrad());
}
TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0), expected_parameters::Adagrad());
}
TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adagrad_with_weight_decay());
}
TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
      expected_parameters::Adagrad_with_weight_decay_and_lr_decay());
}
TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1), expected_parameters::RMSprop());
}
TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-2),
      expected_parameters::RMSprop_with_weight_decay());
}
TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
      expected_parameters::RMSprop_with_weight_decay_and_centered());
}
TEST(
    OptimTest,
    ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
      expected_parameters::
          RMSprop_with_weight_decay_and_centered_and_momentum());
}
TEST(OptimTest, ProducesPyTorchValues_SGD) {
  check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD());
}
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2),
      expected_parameters::SGD_with_weight_decay());
}
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
      expected_parameters::SGD_with_weight_decay_and_momentum());
}
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
      expected_parameters::SGD_with_weight_decay_and_nesterov_momentum());
}
TEST(OptimTest, ProducesPyTorchValues_LBFGS) {
  check_exact_values<LBFGS>(
      LBFGSOptions(1.0),
      expected_parameters::LBFGS());
}
TEST(OptimTest, ProducesPyTorchValues_LBFGS_with_line_search) {
  check_exact_values<LBFGS>(
      LBFGSOptions(1.0).line_search_fn("strong_wolfe"),
      expected_parameters::LBFGS_with_line_search());
}
TEST(OptimTest, ZeroGrad) {
  torch::manual_seed(0);
  Linear model(2, 8);
  SGD optimizer(model->parameters(), 0.1);
  for (const auto& parameter : model->parameters()) {
Replace cursors with OrderedDict (#13427) Summary: This is a pre-cursor diff to Python <-> C++ frontend integration -- I have a follow-up PR coming for that. This PR changes the C++ frontend module interface to replace the custom "cursor"s I introduced some time ago with `OrderedDict`. I introduced cursors at the time as a convenient way of applying functions and query operations on a modules' parameters, buffers and modules, allowing things like `module.parameters().map(my_func)`. However, I noticed that (1) this functionality is easily implement-able on top of a regular data structure and (2) more importantly, using OrderedDicts is much, much easier for Python integration. This is especially true given that ScriptModule today also uses OrderedDict. Since C++ frontend modules and ScriptModules will soon too share as many implementation details as possible, it is overall the best move to ditch the custom cursor datastructure and pervasively use OrderedDict everywhere. For this I did: 1. Changed the C++ frontend module interface to more closely match the Python one by providing `parameters()`, `named_parameters()` and other methods Python provides. This is very important for the following diff which binds these into Python for inter-op with Python modules. 2. In lieu of the `Cursor::apply()` method I added `nn::Module::apply`. This again is one more unifying step between Python and C++, since Python modules have an apply function too. 3. Deleted all uses of Cursor. 4. Tidied and beefed up the `OrderedDict` class. In particular, I made `OrderedDict::Item` store an `std::pair` under the hood, because that is trivial to bind into Python and saved me a lot of headaches. `key` and `value` become methods instead of fields, which they should have been from the very start anyway because it allows exactly these kinds of changes, as per usual good software engineering principle of encapsulation. 5. Added many tests for the OrderedDict use in `nn::Module`. 
ebetica ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/13427 Differential Revision: D12894092 Pulled By: goldsborough fbshipit-source-id: 715770c95a9643753a1db26d7f9da9a78619a15d
2018-11-07 18:53:07 +00:00
    ASSERT_FALSE(parameter.grad().defined());
  }
  auto output = model->forward(torch::ones({5, 2}));
  auto loss = output.sum();
  loss.backward();
  for (const auto& parameter : model->parameters()) {
    ASSERT_TRUE(parameter.grad().defined());
    ASSERT_GT(parameter.grad().sum().item<float>(), 0);
  }
  optimizer.zero_grad();
  for (const auto& parameter : model->parameters()) {
    ASSERT_TRUE(parameter.grad().defined());
    ASSERT_EQ(parameter.grad().sum().item<float>(), 0);
  }
}
TEST(OptimTest, ExternalVectorOfParameters) {
  torch::manual_seed(0);
  std::vector<torch::Tensor> parameters = {
      torch::randn({2, 2}), torch::randn({3, 3}), torch::randn({4, 4})};
  std::vector<torch::Tensor> original_parameters = {
      parameters[0].clone(), parameters[1].clone(), parameters[2].clone()};
  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.grad() = torch::ones_like(parameter);
  }
  SGD optimizer(parameters, 1.0);
  optimizer.step();
  ASSERT_TRUE(parameters[0].allclose(original_parameters[0] - 1.0));
  ASSERT_TRUE(parameters[1].allclose(original_parameters[1] - 1.0));
  ASSERT_TRUE(parameters[2].allclose(original_parameters[2] - 1.0));
}
TEST(OptimTest, AddParameter_LBFGS) {
  torch::manual_seed(0);
  std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
  std::vector<torch::Tensor> original_parameters = {parameters[0].clone()};
  // Set all gradients to one
  for (auto& parameter : parameters) {
    parameter.grad() = torch::ones_like(parameter);
  }
  LBFGS optimizer(std::vector<torch::Tensor>{}, 1.0);
  OLD_INTERFACE_WARNING_CHECK(optimizer.add_parameters(parameters));
  optimizer.step([]() { return torch::tensor(1); });
  // The test passes as long as the step() call above does not throw.
}