// Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-15 21:00:47 +00:00); 32 lines, 1 KiB, C++.
|
|
#include "test/cpp/jit/test_base.h"
|
||
|
|
#include "test/cpp/jit/test_utils.h"
|
||
|
|
|
||
|
|
namespace torch {
|
||
|
|
namespace jit {
|
||
|
|
|
||
|
|
void testInterp() {
|
||
|
|
constexpr int batch_size = 4;
|
||
|
|
constexpr int input_size = 256;
|
||
|
|
constexpr int seq_len = 32;
|
||
|
|
|
||
|
|
int hidden_size = 2 * input_size;
|
||
|
|
|
||
|
|
auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA);
|
||
|
|
auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
|
||
|
|
auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
|
||
|
|
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
|
||
|
|
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
|
||
|
|
|
||
|
|
auto lstm_g = build_lstm();
|
||
|
|
Code lstm_function(lstm_g);
|
||
|
|
InterpreterState lstm_interp(lstm_function);
|
||
|
|
auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
|
||
|
|
std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
|
||
|
|
|
||
|
|
// std::cout << almostEqual(outputs[0],hx) << "\n";
|
||
|
|
ASSERT_TRUE(exactlyEqual(outputs[0], hx));
|
||
|
|
ASSERT_TRUE(exactlyEqual(outputs[1], cx));
|
||
|
|
}
|
||
|
|
} // namespace jit
|
||
|
|
} // namespace torch
|