Merging the implementations of ClearProfiling (#67575)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67575

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D32175959

Pulled By: Gamrix

fbshipit-source-id: b335dacce709a64e3d5779f9c6e9569f86e10748
This commit is contained in:
John Clow 2021-11-04 18:57:19 -07:00 committed by Facebook GitHub Bot
parent b8e165e841
commit f1754319e3
3 changed files with 6 additions and 30 deletions

View file

@ -1,4 +1,3 @@
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/jit_log.h>
@ -6,7 +5,7 @@
namespace torch {
namespace jit {
static void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) {
void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) {
for (auto i : graph->inputs()) {
if (i->type()->isSubtypeOf(*TensorType::get())) {
i->setType(unshapedType(i->type()));
@ -14,7 +13,7 @@ static void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) {
}
}
static void unprofileBlock(Block* start_block) {
void unprofileBlock(Block* start_block) {
std::vector<Block*> stack;
stack.push_back(start_block);

View file

@ -9,6 +9,10 @@
namespace torch {
namespace jit {
TORCH_API void unprofileGraphInputs(const std::shared_ptr<Graph>& graph);
TORCH_API void unprofileBlock(Block* start_block);
// Unprofiles all the node outputs in a block.
TORCH_API void ClearProfilingInformation(const std::shared_ptr<Graph>& graph);
} // namespace jit

View file

@ -113,33 +113,6 @@ ProfileIValueOp* ProfilingRecord::createProfileIValueNode(
return pn;
}
static void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) {
for (auto i : graph->inputs()) {
if (i->type()->isSubtypeOf(*TensorType::get())) {
i->setType(unshapedType(i->type()));
}
}
}
static void unprofileBlock(Block* start_block) {
std::vector<Block*> stack;
stack.push_back(start_block);
while (!stack.empty()) {
Block* block = stack.back();
stack.pop_back();
for (auto n : block->nodes()) {
for (auto o : n->outputs()) {
if (o->type()->isSubtypeOf(*TensorType::get())) {
o->setType(unshapedType(o->type()));
}
}
stack.insert(stack.end(), n->blocks().begin(), n->blocks().end());
}
}
}
c10::SymbolicShape ProfilingRecord::mergeSymbolicShapes(
const c10::SymbolicShape& new_sizes,
const c10::SymbolicShape& sym_shapes,