pytorch/test/cpp/api/optim.cpp

#include <gtest/gtest.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/optim.h>
#include <torch/types.h>
#include <torch/utils.h>
#include <test/cpp/api/optim_baseline.h>
#include <test/cpp/api/support.h>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
using namespace torch::nn;
using namespace torch::optim;
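
// Trains a small two-layer MLP on the XOR function using the given optimizer
// options; returns true if the running loss drops below 0.1 within
// kMaximumNumberOfEpochs epochs, false otherwise.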
template <typename OptimizerClass, typename Options>
bool test_optimizer_xor(Options options) {
  torch::manual_seed(0);

  Sequential model(
      Linear(2, 8),
      Functional(torch::sigmoid),
      Linear(8, 1),
      Functional(torch::sigmoid));

  const int64_t kBatchSize = 4;
  const int64_t kMaximumNumberOfEpochs = 3000;

  OptimizerClass optimizer(model->parameters(), options);

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto inputs = torch::empty({kBatchSize, 2});
    auto labels = torch::empty({kBatchSize});
    for (size_t i = 0; i < kBatchSize; i++) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    inputs.set_requires_grad(true);

    optimizer.zero_grad();
    auto x = model->forward(inputs);
    torch::Tensor loss = torch::binary_cross_entropy(x, labels);
    loss.backward();

    optimizer.step();
    running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
    if (epoch > kMaximumNumberOfEpochs) {
      std::cout << "Loss is too high after epoch " << epoch << ": "
                << running_loss << std::endl;
      return false;
    }
    epoch++;
  }
  return true;
}
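
// Copies `new_tensor` into the named parameter (flattened), temporarily
// disabling gradient tracking so the copy does not touch autograd state.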
template <typename Parameters>
void assign_parameter(
    const Parameters& parameters,
    const char* name,
    torch::Tensor new_tensor) {
  auto parameter = parameters.at(name);
  parameter.set_requires_grad(false);
  parameter.flatten().copy_(new_tensor);
  parameter.set_requires_grad(true);
}
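
// Runs `kIterations` optimizer steps on a small fixed model and input, and
// every `kSampleEvery` steps compares the parameters against the baseline
// values from optim_baseline.h (see the expected_parameters namespace).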
template <typename OptimizerClass, typename Options>
void check_exact_values(
    Options options,
    std::vector<std::vector<torch::Tensor>> expected_parameters) {
  const size_t kIterations = 1001;
  const size_t kSampleEvery = 100;

  torch::manual_seed(0);

  Sequential model(
      Linear(2, 3),
      Functional(torch::sigmoid),
      Linear(3, 1),
      Functional(torch::sigmoid));
  model->to(torch::kFloat64);

  // Use exact input values because matching random values is hard.
  auto parameters = model->parameters();
  assign_parameter(
      parameters,
      "0.weight",
      torch::tensor({-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976}));
  assign_parameter(
      parameters, "0.bias", torch::tensor({-0.1085, -0.2979, 0.6892}));
  assign_parameter(
      parameters, "2.weight", torch::tensor({-0.0508, -0.3941, -0.2843}));
  assign_parameter(parameters, "2.bias", torch::tensor({-0.0711}));

  auto optimizer = OptimizerClass(parameters, options);
  torch::Tensor input =
      torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}).reshape({3, 2});

  for (size_t i = 0; i < kIterations; ++i) {
    optimizer.zero_grad();
    auto output = model->forward(input);
    auto loss = output.sum();
    loss.backward();

    optimizer.step();

    if (i % kSampleEvery == 0) {
      ASSERT_TRUE(
          expected_parameters.at(i / kSampleEvery).size() == parameters.size());
      for (size_t p = 0; p < parameters.size(); ++p) {
        ASSERT_TRUE(parameters.at(p)->defined());
        auto computed = parameters.at(p)->flatten();
        auto expected = expected_parameters.at(i / kSampleEvery).at(p);
        if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
          std::cout << "Iteration " << i << ": " << computed
                    << " != " << expected << " (parameter " << p << ")"
                    << std::endl;
          ASSERT_TRUE(false);
        }
      }
    }
  }
}
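
// Exercises the basic Optimizer interface: construction with and without
// parameters, add_parameters(), parameters() and size().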
TEST(OptimTest, BasicInterface) {
  struct MyOptimizer : Optimizer {
    using Optimizer::Optimizer;
    void step() override {}
  };
  std::vector<torch::Tensor> parameters = {
      torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
  {
    MyOptimizer optimizer(parameters);
    ASSERT_EQ(optimizer.size(), parameters.size());
  }
  {
    MyOptimizer optimizer;
    ASSERT_EQ(optimizer.size(), 0);
    optimizer.add_parameters(parameters);
    ASSERT_EQ(optimizer.size(), parameters.size());
    for (size_t p = 0; p < parameters.size(); ++p) {
      ASSERT_TRUE(optimizer.parameters()[p].allclose(parameters[p]));
    }
  }
  {
    Linear linear(3, 4);
    MyOptimizer optimizer(linear->parameters());
    ASSERT_EQ(optimizer.size(), linear->parameters().size());
  }
}
TEST(OptimTest, XORConvergence_SGD) {
  ASSERT_TRUE(test_optimizer_xor<SGD>(
      SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
}
TEST(OptimTest, XORConvergence_Adagrad) {
  ASSERT_TRUE(test_optimizer_xor<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
}
TEST(OptimTest, XORConvergence_RMSprop) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
}
TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(
      RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
}
TEST(OptimTest, XORConvergence_Adam) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
}
TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(
      AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
}
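
// The ProducesPyTorchValues_* tests check that each C++ optimizer reproduces
// the parameter values recorded from the Python optimizers (optim_baseline.h).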
TEST(OptimTest, ProducesPyTorchValues_Adam) {
  check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam());
}
TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adam_with_weight_decay());
}
TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
      expected_parameters::Adam_with_weight_decay_and_amsgrad());
}
TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0), expected_parameters::Adagrad());
}
TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adagrad_with_weight_decay());
}
TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
      expected_parameters::Adagrad_with_weight_decay_and_lr_decay());
}
TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1), expected_parameters::RMSprop());
}
TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-2),
      expected_parameters::RMSprop_with_weight_decay());
}
TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
      expected_parameters::RMSprop_with_weight_decay_and_centered());
}
TEST(
    OptimTest,
    ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
      expected_parameters::
          RMSprop_with_weight_decay_and_centered_and_momentum());
}
TEST(OptimTest, ProducesPyTorchValues_SGD) {
  check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD());
}
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2),
      expected_parameters::SGD_with_weight_decay());
}
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
      expected_parameters::SGD_with_weight_decay_and_momentum());
}
TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
  check_exact_values<SGD>(
SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
expected_parameters::SGD_with_weight_decay_and_nesterov_momentum());
}
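
// Verifies that zero_grad() resets every gradient to zero without leaving
// any of them undefined.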
TEST(OptimTest, ZeroGrad) {
torch::manual_seed(0);
Linear model(2, 8);
SGD optimizer(model->parameters(), 0.1);
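// No gradients should be defined before the first backward pass.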
for (const auto& parameter : model->parameters()) {
ASSERT_FALSE(parameter->grad().defined());
}
auto output = model->forward(torch::ones({5, 2}));
auto loss = output.sum();
loss.backward();
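// After backward(), every parameter should have a defined gradient with a
// positive sum (the loss is a plain sum over all-ones inputs).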
for (const auto& parameter : model->parameters()) {
ASSERT_TRUE(parameter->grad().defined());
ASSERT_GT(parameter->grad().sum().item<float>(), 0);
}
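// zero_grad() should zero the gradients but keep them defined.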
optimizer.zero_grad();
for (const auto& parameter : model->parameters()) {
ASSERT_TRUE(parameter->grad().defined());
ASSERT_EQ(parameter->grad().sum().item<float>(), 0);
}
}
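
// Optimizers also work on a plain std::vector of tensors that is not owned
// by any module.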
TEST(OptimTest, ExternalVectorOfParameters) {
torch::manual_seed(0);
std::vector<torch::Tensor> parameters = {
torch::randn({2, 2}), torch::randn({3, 3}), torch::randn({4, 4})};
std::vector<torch::Tensor> original_parameters = {
parameters[0].clone(), parameters[1].clone(), parameters[2].clone()};
// Set all gradients to one
for (auto& parameter : parameters) {
parameter.grad() = torch::ones_like(parameter);
}
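// With a learning rate of 1.0 and all-ones gradients, one SGD step subtracts
// exactly 1 from every parameter element.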
SGD optimizer(parameters, 1.0);
optimizer.step();
ASSERT_TRUE(parameters[0].allclose(original_parameters[0] - 1.0));
ASSERT_TRUE(parameters[1].allclose(original_parameters[1] - 1.0));
ASSERT_TRUE(parameters[2].allclose(original_parameters[2] - 1.0));
}
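
// Parameters can be registered with an optimizer after construction via
// add_parameters().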
TEST(OptimTest, AddParameter_LBFGS) {
torch::manual_seed(0);
std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
std::vector<torch::Tensor> original_parameters = {parameters[0].clone()};
// Set all gradients to one
for (auto& parameter : parameters) {
parameter.grad() = torch::ones_like(parameter);
}
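// Construct LBFGS without any parameters, then add them afterwards.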
LBFGS optimizer(std::vector<torch::Tensor>{}, 1.0);
optimizer.add_parameters(parameters);
optimizer.step([]() { return torch::tensor(1); });
// The test passes as long as the step above does not throw.
}