pytorch/torch/csrc/jit/api/function_impl.cpp

#include <c10/util/Flags.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>
#ifndef C10_MOBILE
#include <ATen/autocast_mode.h>
#include <torch/csrc/jit/passes/autocast.h>
#endif
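
// When torch_jit_do_not_store_optimized_graph is set, optimized_graph() does
// not cache its result in optimized_graphs_ and instead rebuilds the graph on
// each call. The GraphExecutor keeps its own copy of the optimized graph, so
// skipping the cache avoids storing it twice at the cost of recomputing it.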
C10_DEFINE_bool(
    torch_jit_do_not_store_optimized_graph,
    false,
    "Do not store the optimized graph.");

namespace torch::jit {
namespace {
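// Derives a fallback schema from the graph itself: one argument per graph
// input (named after the value's debug name, or "argument_<i>"), one unnamed
// return per graph output, with shape information stripped from the types.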
c10::FunctionSchema defaultSchemaFor(const GraphFunction& function) {
  std::vector<c10::Argument> args;
  std::vector<c10::Argument> returns;
  Graph& g = *function.graph();
  size_t num_inputs = function.num_inputs();
  for (const auto i : c10::irange(num_inputs)) {
    const Value* v = g.inputs().at(i);
    std::string name = v->hasDebugName() ? v->debugNameBase()
                                         : ("argument_" + std::to_string(i));
    args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
  }
  for (const auto i : c10::irange(g.outputs().size())) {
    returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
  }
  return {function.name(), "", std::move(args), std::move(returns)};
}
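
// Shared implementation for the (try)toGraphFunction accessors below:
// a checked downcast from Function to GraphFunction.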
template <typename T, typename F>
T* tryToGraphFunctionImpl(F& function) noexcept {
  if (!function.isGraphFunction()) {
    return nullptr;
  }
  return static_cast<T*>(&function);
}
template <typename T, typename F>
T& toGraphFunctionImpl(F& function) {
  if (auto* g = tryToGraphFunctionImpl<T>(function)) {
    return *g;
  }
  TORCH_INTERNAL_ASSERT(
      false,
      "Failed to downcast a Function to a GraphFunction. "
      "This probably indicates that the JIT calling context needs a "
      "special case on tryToGraphFunction() instead.");
}
} // namespace
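
// Installed as the creator while ensure_defined() is running so that a
// recursive call into a method that is still being defined raises
// RecursiveMethodCallError instead of recursing indefinitely.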
static void placeholderCreator(GraphFunction&) {
  throw RecursiveMethodCallError();
}
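
// Synchronous and asynchronous execution entry points; both defer to the
// GraphExecutor owned by this function.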
void GraphFunction::run(Stack& stack) {
  get_executor().run(stack);
}
c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  return get_executor().runAsync(stack, std::move(taskLauncher));
}
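
// Invokes the stored creator callback (if any) exactly once to materialize
// the graph, then checks that the function produces a single output.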
void GraphFunction::ensure_defined() {
  if (function_creator_) {
    auto creator = function_creator_;
    function_creator_ = placeholderCreator;
    creator(*this);
    function_creator_ = nullptr;
  }
  check_single_output();
}
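
// Lazily builds and caches the default schema derived from the graph on
// first request.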
const c10::FunctionSchema& GraphFunction::getSchema() const {
  if (schema_ == nullptr) {
    schema_ = std::make_unique<c10::FunctionSchema>(defaultSchemaFor(*this));
  }
  return *schema_;
}
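
// Returns a pre-optimized copy of graph_ for the current autocast
// specialization. By default the copy is cached in optimized_graphs_; with
// torch_jit_do_not_store_optimized_graph set it is recomputed on every call
// (see the flag definition above).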
std::shared_ptr<Graph> GraphFunction::optimized_graph() const {
  std::lock_guard<std::recursive_mutex> lock(compile_mutex);
  decltype(optimized_graphs_)::value_type graph;
  auto& graph_ref = !FLAGS_torch_jit_do_not_store_optimized_graph
      ? optimized_graphs_[currentSpecialization()]
      : graph;
  if (graph_ref) {
    return graph_ref;
  }
  graph_ref = graph_->copy();
  if (getGraphExecutorOptimize()) {
    preoptimizeGraph(graph_ref, force_no_amp_);
  }
  return graph_ref;
}
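
// Selects which cached graph variant matches the current state: autocast
// forced off for this function, or the thread-local CPU/GPU autocast flags.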
GraphFunction::SpecializationKey GraphFunction::currentSpecialization() const {
  if (force_no_amp_) {
    return SpecializationKey::AutocastOff;
  }
#ifdef C10_MOBILE
  // disabling the autocast pass for mobile builds since autocast APIs don't
  // exist there
  return SpecializationKey::AutocastOff;
#else
  bool cpu_enabled = at::autocast::is_autocast_enabled(at::kCPU);
  bool gpu_enabled = at::autocast::is_autocast_enabled(at::kCUDA);
  if (cpu_enabled && gpu_enabled) {
    return SpecializationKey::CpuGpuAutocastOn;
  } else if (!cpu_enabled && !gpu_enabled) {
    return SpecializationKey::AutocastOff;
  } else {
    return gpu_enabled ? SpecializationKey::GpuAutocastOn
                       : SpecializationKey::CpuAutocastOn;
  }
#endif
}
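
// Cheap, input-independent passes applied once per specialization before the
// graph reaches the executor: inlining, peephole simplification, constant
// propagation over immutable types, optional AMP cast insertion, and
// constant pooling.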
void preoptimizeGraph(std::shared_ptr<Graph>& graph, bool disable_autocast) {
  Inline(*graph);

  // Peephole Optimize cleans up many "is None" checks and creates constant
  // prop opportunities
  PeepholeOptimize(graph, true);

  // AliasDb construction can be slow, so run it just on immutable types
  // to clean up constant Ifs & other easy wins
  ConstantPropagationImmutableTypes(graph);

#ifndef C10_MOBILE
  // Inject casts for automatic mixed precision
  //
  // TODO: Ideally, this pass could run earlier, before inlining
  //  or any other optimizations. That setup is preferable because:
  //  1. The AMP pass would be self-contained and function independently
  //     of any other optimizations
  //  2. AMP transformations would benefit from follow-up passes' cleanup
  //
  if (!disable_autocast) {
    Autocast(graph);
  }
#endif

  ConstantPooling(graph);
}
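
// Public accessors: tryToGraphFunction returns nullptr for non-graph
// functions, while the toGraphFunction overloads assert.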
GraphFunction* tryToGraphFunction(Function& function) noexcept {
  return tryToGraphFunctionImpl<GraphFunction>(function);
}

GraphFunction& toGraphFunction(Function& function) {
  return toGraphFunctionImpl<GraphFunction>(function);
}

const GraphFunction& toGraphFunction(const Function& function) {
  return toGraphFunctionImpl<const GraphFunction>(function);
}

} // namespace torch::jit