pytorch/test/cpp/lazy/test_ir_util.cpp
Antonio Kim e1b4117e30 Move shape and operand definitions to base node (#75223)
Summary:
First stage of breaking up https://github.com/pytorch/pytorch/pull/74710

Moves the shape and operand definitions from `TsNode` to the base `Node`

CC: wconstab JackCaoG henrytwo

Partially Fixes https://github.com/pytorch/pytorch/issues/74628

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75223

Reviewed By: zou3519

Differential Revision: D35410285

Pulled By: wconstab

fbshipit-source-id: bb84d3fb636882cbe7e18af4b35ff2c0e22aaa58
(cherry picked from commit a4144c9a48379d8a9007cff845796608b597cce1)
2022-04-06 01:43:46 +00:00

71 lines
1.9 KiB
C++

#include <gtest/gtest.h>
#include <c10/util/Exception.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/core/ir_util.h>
namespace torch {
namespace lazy {

// Minimal concrete Node used to build small IR graphs by hand: default
// OpKind, a single output, and a constant hash. Tests wire operand edges
// explicitly through AddOperand().
class IrUtilNode : public Node {
 public:
  IrUtilNode()
      : Node(
            OpKind(),
            /* num_outputs */ 1,
            /* hash_func */ [](bool /*bakeInSizes*/) -> hash_t {
              return Hash(0);
            }) {}
  ~IrUtilNode() override = default;

  // Appends `v` as an operand of this node. A Value whose node is null is
  // ignored. The operand node is retained as a strong reference in
  // operands_, and a (node pointer, output index) pair is recorded in
  // operands_as_outputs_.
  void AddOperand(Value v) {
    if (!v.node) {
      return;
    }
    operands_as_outputs_.emplace_back(v.node.get(), v.index);
    operands_.push_back(std::move(v.node));
  }

  // Removes every operand edge from this node. Because operands_ holds
  // strong references, a test that builds a cycle must call this on at
  // least one node of the cycle, otherwise the nodes are never released.
  void ClearOperands() {
    operands_as_outputs_.clear();
    operands_.clear();
  }
};

/*    a
 *   / \
 *  b   c
 *   \ /
 *    d
 * Post-order: d c b a
 */
TEST(IrUtilTest, BasicTest) {
  NodePtr a = MakeNode<IrUtilNode>();
  NodePtr b = MakeNode<IrUtilNode>();
  NodePtr c = MakeNode<IrUtilNode>();
  NodePtr d = MakeNode<IrUtilNode>();
  // IrUtilNode has exactly one output, so every edge refers to output 0.
  dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(b, 0));
  dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(c, 0));
  dynamic_cast<IrUtilNode*>(b.get())->AddOperand(Value(d, 0));
  dynamic_cast<IrUtilNode*>(c.get())->AddOperand(Value(d, 0));

  const std::vector<Node*> postorder = Util::ComputePostOrder({a.get()});
  // Fatal check: the indexed checks below are meaningless on a size mismatch.
  ASSERT_EQ(postorder.size(), 4u);
  EXPECT_EQ(postorder.at(0), d.get());
  EXPECT_EQ(postorder.at(1), c.get());
  EXPECT_EQ(postorder.at(2), b.get());
  EXPECT_EQ(postorder.at(3), a.get());
}

/*    a
 *   / \
 *  b---c
 * The graph contains the cycle a -> b -> c -> a, so no post-order exists
 * and ComputePostOrder is expected to throw.
 */
TEST(IrUtilTest, TestCircle) {
  NodePtr a = MakeNode<IrUtilNode>();
  NodePtr b = MakeNode<IrUtilNode>();
  NodePtr c = MakeNode<IrUtilNode>();
  dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(b, 0));
  dynamic_cast<IrUtilNode*>(b.get())->AddOperand(Value(c, 0));
  dynamic_cast<IrUtilNode*>(c.get())->AddOperand(Value(a, 0));

  EXPECT_THROW(Util::ComputePostOrder({a.get()}), c10::Error);

  // Operands are strong NodePtr references, so the a -> b -> c -> a cycle
  // would keep all three nodes alive after the locals go out of scope.
  // Dropping the c -> a edge lets the whole chain be freed.
  dynamic_cast<IrUtilNode*>(c.get())->ClearOperands();
}

} // namespace lazy
} // namespace torch