diff --git a/torch/csrc/jit/passes/clear_profiling.cpp b/torch/csrc/jit/passes/clear_profiling.cpp index 405b110b2ef..ed797bdc023 100644 --- a/torch/csrc/jit/passes/clear_profiling.cpp +++ b/torch/csrc/jit/passes/clear_profiling.cpp @@ -1,4 +1,3 @@ - #include #include @@ -6,7 +5,7 @@ namespace torch { namespace jit { -static void unprofileGraphInputs(const std::shared_ptr& graph) { +void unprofileGraphInputs(const std::shared_ptr& graph) { for (auto i : graph->inputs()) { if (i->type()->isSubtypeOf(*TensorType::get())) { i->setType(unshapedType(i->type())); @@ -14,7 +13,7 @@ static void unprofileGraphInputs(const std::shared_ptr& graph) { } } -static void unprofileBlock(Block* start_block) { +void unprofileBlock(Block* start_block) { std::vector stack; stack.push_back(start_block); diff --git a/torch/csrc/jit/passes/clear_profiling.h b/torch/csrc/jit/passes/clear_profiling.h index 46915f5549f..b9ac975dee6 100644 --- a/torch/csrc/jit/passes/clear_profiling.h +++ b/torch/csrc/jit/passes/clear_profiling.h @@ -9,6 +9,10 @@ namespace torch { namespace jit { +TORCH_API void unprofileGraphInputs(const std::shared_ptr& graph); +TORCH_API void unprofileBlock(Block* start_block); +// Unprofiles all the node outputs in a block. + TORCH_API void ClearProfilingInformation(const std::shared_ptr& graph); } // namespace jit diff --git a/torch/csrc/jit/runtime/profiling_record.cpp b/torch/csrc/jit/runtime/profiling_record.cpp index e3d8ef653b7..42a74ffeb9b 100644 --- a/torch/csrc/jit/runtime/profiling_record.cpp +++ b/torch/csrc/jit/runtime/profiling_record.cpp @@ -113,33 +113,6 @@ ProfileIValueOp* ProfilingRecord::createProfileIValueNode( return pn; } -static void unprofileGraphInputs(const std::shared_ptr& graph) { - for (auto i : graph->inputs()) { - if (i->type()->isSubtypeOf(*TensorType::get())) { - i->setType(unshapedType(i->type())); - } - } -} - -static void unprofileBlock(Block* start_block) { - std::vector stack; - stack.push_back(start_block); - - while (!stack.empty()) { - Block* block = stack.back(); - stack.pop_back(); - - for (auto n : block->nodes()) { - for (auto o : n->outputs()) { - if (o->type()->isSubtypeOf(*TensorType::get())) { - o->setType(unshapedType(o->type())); - } - } - stack.insert(stack.end(), n->blocks().begin(), n->blocks().end()); - } - } -} - c10::SymbolicShape ProfilingRecord::mergeSymbolicShapes( const c10::SymbolicShape& new_sizes, const c10::SymbolicShape& sym_shapes,