#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/clear_undefinedness.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/guard_elimination.h>
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/inplace_check.h>
#include <torch/csrc/jit/passes/insert_guards.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>

C10_DEFINE_bool(
    torch_jit_enable_new_executor,
    true,
    "If this flag is set to false, TorchScript will use the legacy/original executor");

C10_DEFINE_bool(
    torch_jit_disable_warning_prints,
    false,
    "Disables warning.warn prints in TorchScript graph");

constexpr size_t kDefaultNumProfiledRuns = 1;
constexpr size_t kDefaultBailoutDepth = 20;

C10_DEFINE_int64(
    torch_jit_num_profiled_runs,
    kDefaultNumProfiledRuns,
    "Number of profiling runs");
C10_DEFINE_int64(
    torch_jit_bailout_depth,
    kDefaultBailoutDepth,
    "Number of re-specializations");

namespace torch {
namespace jit {

#if defined(C10_MOBILE)
static std::atomic<bool> executor_mode{true};
static std::atomic<bool> profiling_mode{false};
#else
static std::atomic<bool> executor_mode{true};
static std::atomic<bool> profiling_mode{true};
#endif

static std::atomic<size_t> num_profiled_runs{kDefaultNumProfiledRuns};
static std::atomic<size_t> bailout_depth{kDefaultBailoutDepth};

std::atomic<bool>& getProfilingMode() {
  return profiling_mode;
}

std::atomic<bool>& getExecutorMode() {
  return executor_mode;
}

std::atomic<size_t>& getNumProfiledRuns() {
  // Initialize num_profiled_runs from the command-line flag.
  static const size_t init = []() {
    return num_profiled_runs = FLAGS_torch_jit_num_profiled_runs;
  }();
  (void)init; // Silence clang-tidy.
  return num_profiled_runs;
}

std::atomic<size_t>& getBailoutDepth() {
  // Initialize bailout_depth from the command-line flag.
  static const size_t init = []() {
    return bailout_depth = FLAGS_torch_jit_bailout_depth;
  }();
  (void)init; // Silence clang-tidy.
  return bailout_depth;
}
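
// The accessors above expose the executor's tuning knobs (executor mode,
// profiling mode, number of profiled runs, bailout depth). As an
// illustration only, a minimal sketch of how they are commonly toggled from
// Python, assuming the corresponding torch._C bindings exist in this build:
//
//   torch._C._jit_set_profiling_executor(True)
//   torch._C._jit_set_profiling_mode(True)
//
// Treat the binding names as assumptions; only the C++ accessors are defined
// here.

// needsGradientInProfilingMode walks a block (and its nested blocks) and
// returns true if any prim::BailOut output or prim::profile profiled_type was
// observed to require grad. runProfilingOptimizations uses this to choose
// between the autodiff path and runNoGradOptimizations.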
static bool needsGradientInProfilingMode(Block* b) {
  for (auto n : b->nodes()) {
    if (n->kind() == prim::BailOut) {
      auto ptt = n->output()->type()->expect<TensorType>();
      if (ptt->requiresGrad() && *ptt->requiresGrad()) {
        return true;
      }
    }
    if (n->kind() == prim::profile) {
      auto type = n->ty(attr::profiled_type)->expect<TensorType>();
      if (type->requiresGrad() && *type->requiresGrad()) {
        return true;
      }
    }

    for (auto ib : n->blocks()) {
      if (needsGradientInProfilingMode(ib)) {
        return true;
      }
    }
  }
  return false;
}

void runNooptPassPipeline(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before LowerGradOf (beginning of runNooptPassPipeline)\n", *graph);
  LowerGradOf(*graph);
  GRAPH_DEBUG("After LowerGradOf, before RemoveExpands\n", *graph);
  RemoveExpands(graph);
  GRAPH_DEBUG("After RemoveExpands, before CanonicalizeOps\n", *graph);
  CanonicalizeOps(graph);
  GRAPH_DEBUG("After CanonicalizeOps, before EliminateDeadCode\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG(
      "After EliminateDeadCode (end of runNooptPassPipeline)\n", *graph);
}

void runPreAutodiffPassPipeline(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before InsertGuards (beginning of runPreAutodiffPassPipeline)\n",
      *graph);

  if (tensorExprFuserEnabled() || RegisterCudaFuseGraph::isRegistered()) {
    // With the TE fuser or nvfuser, we don't generate bailouts.
    LowerGradOf(*graph);
    GRAPH_DEBUG("After LowerGradOf, before specializeAutogradZero\n", *graph);
  } else {
    InsertGuards(graph);
    GRAPH_DEBUG("After InsertGuards, before LowerGradOf\n", *graph);
    LowerGradOf(*graph);
    GRAPH_DEBUG("After LowerGradOf, before EliminateRedundantGuards\n", *graph);
    EliminateRedundantGuards(graph);
    GRAPH_DEBUG(
        "After EliminateRedundantGuards, before InsertBailOuts\n", *graph);
    InsertBailOuts(graph);
    GRAPH_DEBUG(
        "After InsertBailOuts, before specializeAutogradZero\n", *graph);
  }

  specializeAutogradZero(graph);
  GRAPH_DEBUG("After specializeAutogradZero\n", *graph);

  // runRequiredPasses
  {
    RemoveExpands(graph);
    GRAPH_DEBUG("After RemoveExpands, before CanonicalizeOps\n", *graph);
    CanonicalizeOps(graph);
    GRAPH_DEBUG("After CanonicalizeOps, before EliminateDeadCode\n", *graph);
    EliminateDeadCode(graph);
    GRAPH_DEBUG("After EliminateDeadCode\n", *graph);
  }

  PeepholeOptimize(graph);
  GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
  ConstantPropagation(graph);

  // runOptimization:
  {
    EliminateDeadCode(graph);
    GRAPH_DEBUG(
        "After EliminateDeadCode, before EliminateCommonSubexpression\n",
        *graph);
    EliminateCommonSubexpression(graph);
    GRAPH_DEBUG(
        "After EliminateCommonSubexpression, before PeepholeOptimize\n",
        *graph);

    PeepholeOptimize(graph);
    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
    ConstantPropagation(graph);
    GRAPH_DEBUG("After ConstantPropagation, before ConstantPooling\n", *graph);
    ConstantPooling(graph);
    GRAPH_DEBUG("After ConstantPooling, before UnrollLoops\n", *graph);

    UnrollLoops(graph);
    GRAPH_DEBUG("After UnrollLoops, before RemoveListMutation\n", *graph);
    // Run these passes again now that the loops are unrolled.
    RemoveListMutation(graph);
    GRAPH_DEBUG("After RemoveListMutation, before PeepholeOptimize\n", *graph);
    PeepholeOptimize(graph);
    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
    ConstantPropagation(graph);
    GRAPH_DEBUG(
        "After ConstantPropagation, before EliminateCommonSubexpression\n",
        *graph);

    EliminateCommonSubexpression(graph);
    GRAPH_DEBUG(
        "After EliminateCommonSubexpression, before CheckInplace\n", *graph);

    CheckInplace(graph);
  }
  GRAPH_DEBUG(
      "After CheckInplace (end of runPreAutodiffPassPipeline)\n", *graph);
}
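
// runDiffGraphPasses is applied to the forward graph of each differentiable
// subgraph produced by CreateAutodiffSubgraphs: a generic optimization sweep
// (DCE, CSE, peephole, constant propagation/pooling, loop unrolling) followed
// by the backend-specific, non-differentiable optimizations (custom pre/post
// passes, tuple lowering, BatchMM, and fusion).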
void runDiffGraphPasses(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before EliminateDeadCode (beginning of runDiffGraphPasses)\n", *graph);
  // runOptimization:
  {
    // Basic graph preprocessing to eliminate noise.
    EliminateDeadCode(graph);
    GRAPH_DEBUG(
        "After EliminateDeadCode, before EliminateCommonSubexpression\n",
        *graph);
    EliminateCommonSubexpression(graph);
    GRAPH_DEBUG(
        "After EliminateCommonSubexpression, before PeepholeOptimize\n",
        *graph);

    PeepholeOptimize(graph);
    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
    ConstantPropagation(graph);
    GRAPH_DEBUG("After ConstantPropagation, before ConstantPooling\n", *graph);
    ConstantPooling(graph);
    GRAPH_DEBUG("After ConstantPooling, before UnrollLoops\n", *graph);

    UnrollLoops(graph);
    GRAPH_DEBUG("After UnrollLoops, before RemoveListMutation\n", *graph);
    // Run these passes again now that the loops are unrolled.
    RemoveListMutation(graph);
    GRAPH_DEBUG("After RemoveListMutation, before PeepholeOptimize\n", *graph);
    PeepholeOptimize(graph);
    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
    ConstantPropagation(graph);
    GRAPH_DEBUG(
        "After ConstantPropagation, before EliminateCommonSubexpression\n",
        *graph);

    EliminateCommonSubexpression(graph);
    GRAPH_DEBUG(
        "After EliminateCommonSubexpression, before CheckInplace\n", *graph);

    CheckInplace(graph);
  }
  GRAPH_DEBUG("After CheckInplace, before customPrePasses\n", *graph);

  // runNondiffOptimization
  {
    // Run custom passes that different backends can register.
    for (const auto& passPair : getCustomPrePasses()) {
      passPair.first(graph);
    }
    GRAPH_DEBUG("After customPrePasses, before LowerSimpleTuples\n", *graph);

    // TupleConstruct / TupleUnpack pairs can still be present at this point
    // and must be removed for fusion.
    LowerSimpleTuples(graph);
    GRAPH_DEBUG("After LowerSimpleTuples\n", *graph);

    if (tensorExprFuserEnabled()) {
      // Remove prim::profile nodes and embed the profile info directly in the
      // IR in value types. We do this because optimizations that try to
      // merge/fuse nodes in the graph (e.g. BatchMM and GraphFuser) work worse
      // in the presence of intermittent prim::profile nodes. Optimizations
      // relying on the type info are also responsible for inserting proper
      // type checks. Once we're done with these optimizations we wipe the
      // tensor type information from the IR, so that it's not accidentally
      // used by any other pass.
      RemoveProfileNodesAndSpecializeTypes(graph);
      GRAPH_DEBUG(
          "After RemoveProfileNodesAndSpecializeTypes, before BatchMM\n",
          *graph);
      // Rewrite subgraphs with many MMs into expressions that batch them.
      BatchMM(graph);
      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);

      FuseTensorExprs(graph, getFusionGroupInlining() ? 2 : 1);
      GRAPH_DEBUG(
          "After Fusion, before RemoveTensorTypeSpecializations\n", *graph);

      // Wipe tensor type info from the IR.
      RemoveTensorTypeSpecializations(graph);
      GRAPH_DEBUG(
          "After RemoveTensorTypeSpecializations, before customPostPasses\n",
          *graph);
    } else {
      // Rewrite subgraphs with many MMs into expressions that batch them.
      BatchMM(graph);
      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);

      FuseGraph(graph, true);
      GRAPH_DEBUG("After Fusion, before customPostPasses\n", *graph);
    }

    // Run custom post-fusion passes.
    for (const auto& passPair : getCustomPostPasses()) {
      passPair.first(graph);
    }
  }
  GRAPH_DEBUG("After customPostPasses (end of runDiffGraphPasses)\n", *graph);
}
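
// runNoGradOptimizations runs the same backend/non-differentiable
// optimization sequence as the second half of runDiffGraphPasses, but on the
// whole graph, for the case where profiling shows that nothing requires grad
// and no autodiff subgraphs are created.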
void runNoGradOptimizations(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "After customPostPasses (beginning of runNoGradOptimizations)\n", *graph);
  // runNondiffOptimization
  {
    // Run custom passes that different backends can register.
    for (const auto& passPair : getCustomPrePasses()) {
      passPair.first(graph);
    }
    GRAPH_DEBUG("After customPrePasses, before LowerSimpleTuples\n", *graph);

    // TupleConstruct / TupleUnpack pairs can still be present at this point
    // and must be removed for fusion.
    LowerSimpleTuples(graph);
    GRAPH_DEBUG("After LowerSimpleTuples\n", *graph);

    if (tensorExprFuserEnabled()) {
      // Remove prim::profile nodes and embed the profile info directly in the
      // IR in value types. We do this because optimizations that try to
      // merge/fuse nodes in the graph (e.g. BatchMM and GraphFuser) work worse
      // in the presence of intermittent prim::profile nodes. Optimizations
      // relying on the type info are also responsible for inserting proper
      // type checks. Once we're done with these optimizations we wipe the
      // tensor type information from the IR, so that it's not accidentally
      // used by any other pass.
      RemoveProfileNodesAndSpecializeTypes(graph);
      GRAPH_DEBUG(
          "After RemoveProfileNodesAndSpecializeTypes, before BatchMM\n",
          *graph);
      // Rewrite subgraphs with many MMs into expressions that batch them.
      BatchMM(graph);
      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);

      FuseTensorExprs(graph, getFusionGroupInlining() ? 2 : 1);
      GRAPH_DEBUG(
          "After Fusion, before RemoveTensorTypeSpecializations\n", *graph);

      // Wipe tensor type info from the IR.
      RemoveTensorTypeSpecializations(graph);
      GRAPH_DEBUG(
          "After RemoveTensorTypeSpecializations, before customPostPasses\n",
          *graph);
    } else {
      // Rewrite subgraphs with many MMs into expressions that batch them.
      BatchMM(graph);
      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);

      FuseGraph(graph, true);
      GRAPH_DEBUG("After Fusion, before customPostPasses\n", *graph);
    }

    // Run custom post-fusion passes.
    for (const auto& passPair : getCustomPostPasses()) {
      passPair.first(graph);
    }
  }
  GRAPH_DEBUG(
      "After customPostPasses (end of runNoGradOptimizations)\n", *graph);
}
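
// runProfilingOptimizations is the second-tier pipeline, run once profiling
// has completed. After the pre-autodiff pipeline it either creates and
// differentiates autodiff subgraphs (when profiling observed tensors that
// require grad) or falls back to runNoGradOptimizations.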
void ProfilingGraphExecutorImpl::runProfilingOptimizations(
    std::shared_ptr<Graph>& copy) {
  GRAPH_DEBUG("Before runProfilingOptimizations:\n", *copy);
  if (!getGraphExecutorOptimize()) {
    runNooptPassPipeline(copy);
    return;
  }

  runPreAutodiffPassPipeline(copy);

  if (needsGradientInProfilingMode(copy->block())) {
    auto diff_nodes = CreateAutodiffSubgraphs(
        copy,
        getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
    GRAPH_DEBUG("After CreateAutodiffSubgraphs\n", *copy);
    size_t idx = 0;
    for (Node* dnode : diff_nodes) {
      GRAPH_DEBUG("Optimizing diff node ", idx);
      auto diff_graph = std::move(dnode->g(attr::Subgraph));
      Gradient gradient = differentiate(diff_graph);
      GRAPH_DEBUG("Forward graph:\n", *(gradient.f));
      GRAPH_DEBUG("Backward graph:\n", *(gradient.df));
      runDiffGraphPasses(gradient.f);
      // Replaces fallback graphs inserted by the TE fuser.
      replaceFallbackGraphWithFallbackFunction(gradient.f->block());
      packGradient(gradient, dnode);
      GRAPH_DEBUG("Finished optimizing diff node ", idx++);
    }
    InlineAutodiffSubgraphs(
        copy,
        getAutodiffSubgraphInlining() ? autodiffSubgraphInlineThreshold : 1);
    RemoveProfilingNodes(copy);
    GRAPH_DEBUG(
        "After InlineAutodiffSubgraphs and Removing Profiling Nodes\n", *copy);
  } else {
    runNoGradOptimizations(copy);
  }
  EliminateDeadCode(copy);
  GRAPH_DEBUG("After runProfilingOptimizations:\n", *copy);
}

void ProfilingGraphExecutorImpl::runProfilingInsensitiveOptimizations(
    std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before inlining (beginning of runProfilingInsensitiveOptimizations)\n",
      *graph);
  // TODO: maybe this can go later in the pipeline / directly in autodiff
  // forward creation
  if (getGraphExecutorOptimize()) {
    Inline(*graph);
  }
  GRAPH_DEBUG("After inlining, before ClearProfilingInformation\n", *graph);
  ClearProfilingInformation(graph);
  GRAPH_DEBUG("After ClearProfilingInformation, before LowerGradOf\n", *graph);
  LowerGradOf(*graph);
  GRAPH_DEBUG("After LowerGradOf, before ClearUndefinedness\n", *graph);
  // Clear any residual undefinedness, as double-backward graph inputs may
  // carry over undefinedness from profiled backward graphs.
  ClearUndefinedness(graph);
  // runRequiredPasses
  {
    GRAPH_DEBUG("After ClearUndefinedness, before RemoveExpands\n", *graph);
    RemoveExpands(graph);
    GRAPH_DEBUG("After RemoveExpands, before CanonicalizeOps\n", *graph);
    CanonicalizeOps(graph);
    GRAPH_DEBUG("After CanonicalizeOps, before EliminateDeadCode\n", *graph);
    EliminateDeadCode(graph);
  }
  if (!getGraphExecutorOptimize()) {
    GRAPH_DEBUG(
        "After EliminateDeadCode (end of runProfilingInsensitiveOptimizations)\n",
        *graph);
    return;
  }

  GRAPH_DEBUG("After EliminateDeadCode, before DecomposeOps\n", *graph);
  DecomposeOps(graph);
  GRAPH_DEBUG("After DecomposeOps, before ConstantPropagation\n", *graph);
  ConstantPropagation(graph);
  GRAPH_DEBUG("After ConstantPropagation, before EliminateDeadCode\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG(
      "After EliminateDeadCode, before EliminateCommonSubexpression\n", *graph);
  EliminateCommonSubexpression(graph);
  GRAPH_DEBUG(
      "After EliminateCommonSubexpression, before ConstantPooling\n", *graph);
  ConstantPooling(graph);
  GRAPH_DEBUG("After ConstantPooling, before PeepholeOptimize\n", *graph);
  PeepholeOptimize(graph);
  GRAPH_DEBUG("After PeepholeOptimize, before EliminateDeadCode\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG("After EliminateDeadCode, before LowerSimpleTuples\n", *graph);
  LowerSimpleTuples(graph);
  GRAPH_DEBUG("After LowerSimpleTuples, before CheckInplace\n", *graph);
  CheckInplace(graph);
  GRAPH_DEBUG(
      "After CheckInplace (end of runProfilingInsensitiveOptimizations)\n",
      *graph);
}

ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl(
    const std::shared_ptr<Graph>& graph,
    std::string function_name)
    : GraphExecutorImplBase(graph, std::move(function_name)) {}
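
// getOptimizedPlanFor implements the executor's tiering:
//   * optimizations disabled -> a minimally lowered fallback plan,
//   * remaining bailout depth 0 -> a "simple executor" plan with only
//     profiling-insensitive optimizations,
//   * otherwise -> instrument the graph, keep returning the profiling plan
//     until the ProfilingRecord reports ready(), then build the fully
//     optimized plan.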
const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor(
    Stack& stack,
    size_t remaining_bailout_depth) {
  GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this);

  // no opt mode
  if (!getGraphExecutorOptimize()) {
    if (!fallback_plan_) {
      auto copy = graph->copy();
      GRAPH_DEBUG(
          "Before LowerGradOf (beginning of the no-opt path)\n", *copy);
      LowerGradOf(*copy);
      GRAPH_DEBUG("After LowerGradOf, before RemoveExpands\n", *copy);
      RemoveExpands(copy);
      fallback_plan_ = ExecutionPlan(copy, function_name_);
      GRAPH_DUMP("NoOpt Graph: ", copy);
    }
    return *fallback_plan_;
  }

  // If tensorExprFuserEnabled() returns true, we persist the
  // remaining_bailout_depth passed on the very first call to
  // ProfilingGraphExecutorImpl, so we can use it correctly for fallback
  // functions in ProfilingGraphExecutorImpl. Otherwise,
  // getPlanFor(remaining_bailout_depth) is corrected and persisted by the
  // Code object in the interpreter.
  if (!remaining_bailout_depth_.has_value() || !tensorExprFuserEnabled()) {
    remaining_bailout_depth_ = remaining_bailout_depth;
  }

  // simple executor
  if (*remaining_bailout_depth_ == 0) {
    auto copy = graph->copy();
    runProfilingInsensitiveOptimizations(copy);
    GRAPH_DUMP("Optimized SimpleExecutor Graph: ", copy);
    optimized_plan_ = ExecutionPlan(copy, function_name_);
    return *optimized_plan_;
  }

  // if a profiling graph hasn't been created yet
  if (!pr_) {
    auto copy = graph->copy();
    runProfilingInsensitiveOptimizations(copy);
    pr_ = ProfilingRecord::instrumentGraph(copy);
    GRAPH_DUMP("Profiled Graph: ", pr_->graph());
    profiling_plan_ = ExecutionPlan(pr_->graph(), function_name_);
    // fall-through
  }

  // profile until a graph is ready
  if (!pr_->ready()) {
    return *profiling_plan_;
  }

  auto copy = pr_->graph()->copy();
  ProfilingRecord::removeProfileCounter(copy->block());
  runProfilingOptimizations(copy);
  // Replaces a fallback graph inserted by specialize_autogradzero if one
  // exists.
  replaceFallbackGraphWithFallbackFunction(copy->block());
  GRAPH_DUMP("Optimized Graph: ", copy);
  optimized_plan_ =
      ExecutionPlan(copy, function_name_, *remaining_bailout_depth_);
  return *optimized_plan_;
}

const ExecutionPlan& ProfilingGraphExecutorImpl::getPlanFor(
    Stack& stack,
    size_t remaining_bailout_depth) {
  std::lock_guard<std::mutex> lock(compile_mutex);

  // IMPORTANT: This is a hot path for calling a TorchScript function. Try not
  // to add any code above this.
  if (optimized_plan_) {
    return *optimized_plan_;
  }

  return getOptimizedPlanFor(stack, remaining_bailout_depth);
}

GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() {
  GraphExecutorState state;
  TORCH_INTERNAL_ASSERT(optimized_plan_);
  auto opt_plan = *optimized_plan_;
  state.execution_plans.emplace(ArgumentSpec{0, 0}, opt_plan);
  return state;
}
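
// The helpers below implement the fallback-function mechanism: each
// prim::FallbackGraph node (inserted by specialize_autogradzero or the TE
// fuser) is cloned into a standalone GraphFunction whose outputs are packed
// into a tuple, and the node is replaced with a prim::CallFunction to that
// function followed by a TupleUnpack. The fallback function gets its own
// executor plan with a decremented bailout depth.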
Node* insertFallbackFunctionCall(
    Graph* graph,
    Function* func,
    ArrayRef<Value*> inputs) {
  auto tuple_type = func->graph()->return_node()->input(0)->type();
  Value* fn_constant = graph->insertNode(graph->create(prim::Constant))
                           ->s_(attr::name, func->name())
                           ->i_(Symbol::attr("fallback"), 1)
                           ->output()
                           ->setType(FunctionType::create(func));
  std::vector<Value*> func_call_inputs = {fn_constant};
  func_call_inputs.insert(func_call_inputs.end(), inputs.begin(), inputs.end());
  Value* result =
      graph->insertNode(graph->create(prim::CallFunction, func_call_inputs))
          ->output()
          ->setType(tuple_type);

  auto fun_unpack_tuple = graph->insertNode(graph->createTupleUnpack(result));
  return fun_unpack_tuple;
}

Function* createFallbackPathFunction(
    Block* b,
    const std::string& function_name) {
  auto value_map = [](Value* v) { return v; };
  auto graph = std::make_shared<Graph>();
  graph->block()->cloneFrom(b, value_map);

  auto otypes = c10::fmap(
      graph->return_node()->inputs(), [](Value* v) { return v->type(); });
  // A GraphFunction call can only have one output, so all the outputs need to
  // be packed into a tuple.
  auto tuple_type = TupleType::create(otypes);
  auto return_tuple = graph->createTuple(graph->return_node()->inputs());
  graph->appendNode(return_tuple);
  for (int i = static_cast<int>(graph->outputs().size()) - 1; i >= 0; i--) {
    graph->eraseOutput(i);
  }
  graph->registerOutput(return_tuple->output());
  return new GraphFunction(function_name, graph, nullptr);
}

void ProfilingGraphExecutorImpl::replaceFallbackGraphWithFallbackFunction(
    Block* b) {
  Stack s;
  for (auto it = b->nodes().begin(); it != b->nodes().end();) {
    if (it->kind() == prim::FallbackGraph) {
      auto fallback_func = createFallbackPathFunction(
          it->g(attr::Subgraph)->block(), "fallback_function");
      TORCH_INTERNAL_ASSERT(*remaining_bailout_depth_ > 0);
      GRAPH_DEBUG(
          "getPlanFor for ", getHeader(*it), " ", *remaining_bailout_depth_);
      fallback_func->get_executor().getPlanFor(
          s, *remaining_bailout_depth_ - 1);
      fallback_functions_.emplace_back(fallback_func);
      WithInsertPoint wip{*it};
      auto function_call = insertFallbackFunctionCall(
          b->owningGraph(), fallback_func, it->inputs());
      for (size_t i = 0; i < function_call->outputs().size(); i++) {
        it->output(i)->replaceAllUsesWith(function_call->output(i));
      }
      it.destroyCurrent();
    } else {
      for (Block* ib : it->blocks()) {
        replaceFallbackGraphWithFallbackFunction(ib);
      }
      it++;
    }
  }
}

} // namespace jit
} // namespace torch