diff --git a/test/cpp/lazy/test_ir.cpp b/test/cpp/lazy/test_ir.cpp index 46c7d28ee07..3e775dce6e2 100644 --- a/test/cpp/lazy/test_ir.cpp +++ b/test/cpp/lazy/test_ir.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include diff --git a/test/cpp/lazy/test_ir_util.cpp b/test/cpp/lazy/test_ir_util.cpp index 48573ed763c..ad951956db7 100644 --- a/test/cpp/lazy/test_ir_util.cpp +++ b/test/cpp/lazy/test_ir_util.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include diff --git a/test/cpp/lazy/test_lazy_ops.cpp b/test/cpp/lazy/test_lazy_ops.cpp index 318d1ecbd48..6ac65c90105 100644 --- a/test/cpp/lazy/test_lazy_ops.cpp +++ b/test/cpp/lazy/test_lazy_ops.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include diff --git a/test/cpp/lazy/test_symbolic_shape.cpp b/test/cpp/lazy/test_symbolic_shape.cpp index 7243530e7ff..f0ce5a3083e 100644 --- a/test/cpp/lazy/test_symbolic_shape.cpp +++ b/test/cpp/lazy/test_symbolic_shape.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/lazy/core/ir.h b/torch/csrc/lazy/core/ir.h index a9b55fff1b7..33e33d3045b 100644 --- a/torch/csrc/lazy/core/ir.h +++ b/torch/csrc/lazy/core/ir.h @@ -175,12 +175,6 @@ inline std::ostream& operator<<(std::ostream& stream, const Node& node) { return stream; } -// TODO(alanwaketan): Support r-value reference argument type. -template -NodePtr MakeNode(Args&&... args) { - return std::make_shared(std::forward(args)...); -} - template const T* NodeCast(const Node* node, OpKind op) { if (op != node->op()) { diff --git a/torch/csrc/lazy/core/ir_builder.h b/torch/csrc/lazy/core/ir_builder.h index 93691b92783..ca34df2af83 100644 --- a/torch/csrc/lazy/core/ir_builder.h +++ b/torch/csrc/lazy/core/ir_builder.h @@ -12,6 +12,12 @@ namespace torch { namespace lazy { +// TODO(alanwaketan): Support r-value reference argument type. +template +NodePtr MakeNode(Args&&... args) { + return std::make_shared(std::forward(args)...); +} + struct IrBuilder { virtual NodePtr MakeDeviceData(const std::shared_ptr& data) const = 0; virtual NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) const = 0; diff --git a/torch/csrc/lazy/ts_backend/ops/generic.h b/torch/csrc/lazy/ts_backend/ops/generic.h index 2f9a837cd78..c605aaa437c 100644 --- a/torch/csrc/lazy/ts_backend/ops/generic.h +++ b/torch/csrc/lazy/ts_backend/ops/generic.h @@ -2,6 +2,8 @@ #include +#include + namespace torch { namespace lazy { diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py index 14fe5ef378e..a8d54b5f001 100644 --- a/torchgen/gen_lazy_tensor.py +++ b/torchgen/gen_lazy_tensor.py @@ -451,6 +451,7 @@ def run_gen_lazy_tensor( "ATen/MetaFunctions.h", "ATen/Operators.h", "ATen/native/CPUFallback.h", + "torch/csrc/lazy/core/ir_builder.h", "torch/csrc/lazy/core/lazy_graph_executor.h", "torch/csrc/lazy/core/metrics.h", "torch/csrc/lazy/core/shape.h",