#include <torch/csrc/jit/runtime/graph_executor.h>

#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/inplace_check.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>
#include <torch/csrc/jit/runtime/profiling_record.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/runtime/logging.h>

#include <cstdint>
#include <iterator>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {

EnableProfilingGuard::EnableProfilingGuard() {
  auto& profiling_mode = getProfilingMode();
  old_profiling_mode = profiling_mode;
  profiling_mode = true;
  auto& executor_mode = getExecutorMode();
  old_executor_mode = executor_mode;
  executor_mode = true;
}

EnableProfilingGuard::~EnableProfilingGuard() {
  getProfilingMode() = old_profiling_mode;
  getExecutorMode() = old_executor_mode;
}

namespace {
c10::AliasAnalysisKind aliasAnalysisInternalSpecialCase() {
  return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
}
} // namespace
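// Example (hypothetical sketch, not part of this file): EnableProfilingGuard
// is an RAII switch, so a caller can force the profiling executor for one
// region and rely on the destructor to restore the previous modes:
//
//   void runWithProfilingExecutor(GraphExecutor& executor, Stack& stack) {
//     EnableProfilingGuard guard; // profiling + executor mode turned on
//     executor.run(stack);
//   } // previous modes restored when `guard` goes out of scope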
// for debugging it is helpful to be able to force autodiff subgraphs
// to be created, to check their correctness, even when the
// size of the subgraph is too small to be profitable.
thread_local bool autodiff_subgraph_inlining = true;
void debugSetAutodiffSubgraphInlining(bool state) {
  autodiff_subgraph_inlining = state;
}

bool getAutodiffSubgraphInlining() {
  return autodiff_subgraph_inlining;
}

// for debugging it is helpful to be able to force fusion groups
// to be created
static std::atomic<bool> fusion_group_inlining(true);
void debugSetFusionGroupInlining(bool state) {
  fusion_group_inlining = state;
}

bool getFusionGroupInlining() {
  return fusion_group_inlining;
}

thread_local std::weak_ptr<Graph> last_executed_optimized_graph;
std::shared_ptr<Graph> lastExecutedOptimizedGraph() {
  return last_executed_optimized_graph.lock();
}

namespace {

using tensor_list = std::vector<at::Tensor>;
using Variable = autograd::Variable;
using autograd::variable_list;

struct CaptureList {
  CaptureList(size_t capture_size) {
    capture_types_.reserve(capture_size);
    var_captures_.reserve(capture_size); // var_captures_.size() might be
                                         // greater than capture_size
    ivalue_captures_.reserve(capture_size);
  }

  void captureTensor(const at::Tensor& tensor, bool is_output) {
    var_captures_.emplace_back(Variable(tensor), is_output);
  }

  void capture(const IValue& val, bool is_output) {
    if (val.isTensor()) {
      capture_types_.emplace_back(CAPTURE_TENSOR);
      captureTensor(val.toTensor(), is_output);
    } else if (val.isTensorList()) {
      // For TensorList, we have to flatten it to Tensors during saving and
      // unflatten it back to TensorList when using it in backward apply().
      // This is to avoid any implicit mutation of the TensorList happening
      // between forward & backward.
      capture_types_.emplace_back(CAPTURE_LIST);
      auto tensors = val.toTensorList();
      sizes_.push_back(tensors.size());

      for (const at::Tensor tensor : tensors) {
        captureTensor(tensor, is_output);
      }
    } else {
      capture_types_.emplace_back(CAPTURE_IVALUE);
      ivalue_captures_.push_back(val);
    }
  }

  size_t size() const {
    return capture_types_.size();
  }

  void unpack(Stack& stack, const std::shared_ptr<autograd::Node>& saved_for) {
    auto var_capture_it = var_captures_.begin();
    auto ivalue_capture_it = ivalue_captures_.begin();
    auto size_it = sizes_.begin();
    for (Capture capture_type : capture_types_) {
      switch (capture_type) {
        case CAPTURE_TENSOR: {
          stack.emplace_back(var_capture_it->unpack(saved_for));
          ++var_capture_it;
        } break;
        case CAPTURE_LIST: {
          c10::List<at::Tensor> lst;
          auto size = *size_it++;
          for (size_t i = 0; i < size; i++) {
            lst.emplace_back(var_capture_it->unpack(saved_for));
            var_capture_it++;
          }
          stack.emplace_back(std::move(lst));
        } break;
        case CAPTURE_IVALUE: {
          stack.push_back(*ivalue_capture_it++);
        } break;
      }
    }
  }

  void release_variables() {
    for (auto& var_capture_ : var_captures_) {
      var_capture_.reset_data();
    }
  }

 private:
  enum Capture : uint8_t {
    CAPTURE_TENSOR,
    CAPTURE_LIST,
    CAPTURE_IVALUE,
  };

  std::vector<Capture> capture_types_;
  std::vector<autograd::SavedVariable> var_captures_;
  std::vector<IValue> ivalue_captures_;
  std::vector<size_t> sizes_;
};
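// Example (illustrative walk-through, not meant to be executed as written):
// capturing a TensorList of two tensors records one CAPTURE_LIST marker plus
// its length, and flattens the elements into individual SavedVariables:
//
//   CaptureList captures(/*capture_size=*/1);
//   captures.capture(
//       IValue(c10::List<at::Tensor>({a, b})), /*is_output=*/false);
//   // capture_types_ == {CAPTURE_LIST}, sizes_ == {2},
//   // var_captures_.size() == 2
//
// unpack() later consumes sizes_ to rebuild a c10::List<at::Tensor> of the
// same length on the backward stack.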
// how do we turn a flattened list of tensors back into the ivalues that
// the DifferentiableGraphBackward expects
struct UnpackInstructions {
  UnpackInstructions(size_t num_inputs) {
    insts_.reserve(num_inputs);
  }
  void pushTensor() {
    insts_.emplace_back(PUSH_TENSOR);
  }
  void pushTensorList(size_t size) {
    insts_.emplace_back(PUSH_LIST);
    sizes_.push_back(size);
  }
  void unpack(variable_list&& inputs, Stack& stack) {
    auto input_it = std::make_move_iterator(inputs.begin());
    auto sizes_it = sizes_.begin();
    for (Inst inst : insts_) {
      switch (inst) {
        case PUSH_TENSOR: {
          at::Tensor t = *input_it++;
          stack.emplace_back(std::move(t));
        } break;
        case PUSH_LIST: {
          std::vector<at::Tensor> lst(input_it, input_it + *sizes_it++);
          stack.emplace_back(lst);
        } break;
      }
    }
  }

 private:
  enum Inst : uint8_t {
    PUSH_TENSOR,
    PUSH_LIST, // consumes one size
  };
  std::vector<Inst> insts_;
  std::vector<size_t> sizes_;
};

// unpack values packed by `packReturnValuesIntoTuple`
static void unpackReturnTuple(Stack& stack) {
  auto tuple = pop(stack).toTuple();
  stack.insert(
      stack.end(), tuple->elements().begin(), tuple->elements().end());
}

struct DifferentiableGraphBackward : public autograd::Node {
  DifferentiableGraphBackward(
      GraphExecutor executor,
      size_t input_size,
      size_t capture_size)
      : executor(std::move(executor)),
        captures_(capture_size),
        input_instructions_(input_size) {}

  variable_list apply(variable_list&& inputs) override {
    Stack stack;
    stack.reserve(captures_.size() + inputs.size());

    input_instructions_.unpack(std::move(inputs), stack);
    captures_.unpack(stack, shared_from_this());
    GRAPH_DEBUG("Running DifferentiableGraphBackward for ", &executor);
    executor.run(stack);
    unpackReturnTuple(stack);

    // NB: stack.size() == num_outputs() is not always true
    // after we added TensorList support.
    // Example: aten::stack(Tensor[] tensors, int) where
    // tensors = [x, x]
    // Here stack.size()[=1] with a TensorList IValue of
    // backward graph output.
    // num_outputs()[=2], however, is the number of outputs of
    // grad_fn (an autograd::Node). grad_fn's outputs are
    // grads with regard to Tensor/Variables `x`, but not
    // graph input TensorList [x, x]. These two grads will
    // be accumulated to x.grad later using autograd::InputBuffer.
    variable_list outputs;
    outputs.reserve(num_outputs());
    size_t output_index = 0;
    for (IValue& v : stack) {
      if (v.isTensorList()) {
        for (at::Tensor tensor : v.toTensorList()) {
          produceOutput(output_index++, std::move(tensor), outputs);
        }
      } else if (v.isTensor()) {
        produceOutput(output_index++, std::move(v).toTensor(), outputs);
      } else {
        // Input grad can also be None even if it requires grad
        // Example: `other` in expand_as(self, other)
        outputs.emplace_back();
      }
    }
    return outputs;
  }

  void capture(const IValue& val, bool is_output) {
    captures_.capture(val, is_output);
  }

  void addOutputForTensor(const at::Tensor& tensor) {
    auto v = Variable(tensor);
    add_next_edge(
        v.defined() ? torch::autograd::impl::gradient_edge(v)
                    : autograd::Edge{});
  }
  void addOutputForIValue(const IValue& value) {
    if (value.isTensorList()) {
      for (const at::Tensor tensor : value.toTensorList()) {
        addOutputForTensor(tensor);
      }
    } else {
      addOutputForTensor(value.toTensor());
    }
  }

  void addInputVariable(Variable output) {
    // NB: since our requires_grad setting is only a heuristic we might end
    // up wanting to differentiate through integral tensors, which is
    // generally a hard error in autograd.
    if (at::isFloatingType(output.scalar_type())) {
      autograd::create_gradient_edge(output, shared_from_this());
      output.set_requires_grad(true);
    } else {
      add_input_metadata(autograd::Node::undefined_input{});
    }
  }

  void addInputIValue(const IValue& v) {
    if (v.isTensorList()) {
      auto tensors = v.toTensorList();
      input_instructions_.pushTensorList(tensors.size());
      for (const at::Tensor tensor : tensors) {
        addInputVariable(tensor);
      }
    } else if (v.isTensor()) {
      input_instructions_.pushTensor();
      addInputVariable(v.toTensor());
    }
  }

  void release_variables() override {
    captures_.release_variables();
  }

 private:
  void produceOutput(size_t i, at::Tensor output, variable_list& outputs) {
    if (should_compute_output(i)) {
      const auto& edge = next_edge(i);
      if (output.defined()) {
        outputs.emplace_back(std::move(output));
      } else if (edge.is_valid()) {
        outputs.emplace_back(
            edge.function->input_metadata(edge.input_nr).zeros_like());
      } else {
        outputs.emplace_back();
      }
    } else {
      outputs.emplace_back();
    }
  }

  friend struct ExecutionPlan;
  GraphExecutor executor;
  CaptureList captures_;
  UnpackInstructions input_instructions_;
};
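// Worked example (illustrative): for a backward subgraph that takes one
// incoming vjp and two captured values, apply() assembles the stack as
//
//   [ vjp_0,       // pushed by input_instructions_.unpack()
//     capture_0,   // pushed by captures_.unpack()
//     capture_1 ]
//
// i.e. gradient inputs precede captured values, which is the order in which
// the backward GraphExecutor consumes its inputs.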
// an optimized way of executing the subgraph computed directly on
// tensors rather than Variables.
// This will unwrap Variables, run the plan, and re-wrap them.
// It can optionally also have a gradient which is hooked up
// to the output Variables if present.
struct DifferentiableGraphOp {
  DifferentiableGraphOp(Gradient grad)
      : f(grad.f, "<forward op>"),
        grad(std::move(grad)),
        grad_executor(this->grad.df, "<backward op>"),
        num_inputs(this->grad.f->inputs().size()),
        num_outputs(this->grad.f->outputs().size()) {}

  // XXX: keep in mind that stack can be larger than the inputs we need!
  void operator()(Stack* stack) const {
    auto grad_fn = std::make_shared<DifferentiableGraphBackward>(
        grad_executor,
        grad.df_input_vjps.size(),
        grad.df_input_captured_inputs.size() +
            grad.df_input_captured_outputs.size());

    {
      auto inputs = last(stack, num_inputs);
      // hook up the outputs of df to the gradient functions of the inputs
      // that require gradients
      for (auto idx : grad.df_output_vjps) {
        grad_fn->addOutputForIValue(inputs[idx]);
      }
      captureInputs(*grad_fn, inputs);
    }

    detachVariables(*stack);
    InterpreterState(f).run(*stack);

    {
      auto outputs = last(stack, num_outputs);
      // hook up the gradients for the output tensors that require gradients
      // to the inputs of our gradient function df
      // TODO - XXX - if any output is the same tensor multiple times, views
      // have to be set up here. We need to refactor autograd until it is safe
      // for tensors to be constructed without all the viewing infrastructure.
      // this is currently intentionally not done here so we can get an idea of
      // our perf before introducing overhead for correctness
      for (auto idx : grad.df_input_vjps) {
        grad_fn->addInputIValue(outputs[idx]);
      }
      captureOutputs(*grad_fn, outputs);
      // drop the temporary outputs so that we return the same number of
      // outputs as if we were not also calculating gradient
      const size_t num_temporary_outputs = num_outputs - grad.f_real_outputs;
      stack->erase(stack->end() - num_temporary_outputs, stack->end());
    }
  }

 private:
  friend GraphExecutor* detail::getGradExecutor(Operation& op);

  at::Tensor detach(at::Tensor t) const {
    if (!t.defined()) {
      return t;
    }
    return t.detach();
  }

  void detach(IValue& v) const {
    if (v.isTensor()) {
      v = IValue(detach(std::move(v).toTensor()));
    } else if (v.isTensorList()) {
      c10::List<at::Tensor> lst = std::move(v).toTensorList();
      for (size_t i = 0; i < lst.size(); ++i) {
        lst.set(i, detach(lst.extract(i)));
      }
      v = std::move(lst);
    }
  }

  void detachVariables(Stack& stack) const {
    // It would be nice to use an ArrayRef here, but unfortunately those can
    // only return const references, so we need to do a bunch of indexing
    // ourselves.
    const int64_t stack_size = stack.size();
    const int64_t stack_offset = stack_size - num_inputs;
    for (int64_t i = stack_offset; i < stack_size; ++i) {
      detach(stack[i]);
    }
  }

  // Capture (save) inputs that would be required to subsequently run backwards
  void captureInputs(
      DifferentiableGraphBackward& grad_fn,
      at::ArrayRef<IValue> inputs) const {
    for (size_t offset : grad.df_input_captured_inputs) {
      grad_fn.capture(inputs[offset], /*is_output*/ false);
    }
  }
  void captureOutputs(
      DifferentiableGraphBackward& grad_fn,
      at::ArrayRef<IValue> outputs) const {
    for (size_t offset : grad.df_input_captured_outputs) {
      grad_fn.capture(outputs[offset], /*is_output*/ true);
    }
  }

  Code f;
  Gradient grad;
  GraphExecutor grad_executor;

  const size_t num_inputs;
  const size_t num_outputs;
};

Gradient getGradient(const Node* n) {
  AT_ASSERT(n->kind() == prim::DifferentiableGraph);
  Gradient grad;
  grad.f = n->g(attr::Subgraph);
  grad.df = n->g(attr::ReverseSubgraph);
  grad.f_real_outputs = n->i(attr::f_real_outputs);
  grad.df_input_vjps = fmap<size_t>(n->is(attr::df_input_vjps));
  grad.df_input_captured_inputs =
      fmap<size_t>(n->is(attr::df_input_captured_inputs));
  grad.df_input_captured_outputs =
      fmap<size_t>(n->is(attr::df_input_captured_outputs));
  grad.df_output_vjps = fmap<size_t>(n->is(attr::df_output_vjps));
  return grad;
}
} // anonymous namespace

RegisterOperators reg_graph_executor_ops({Operator(
    prim::DifferentiableGraph,
    [](const Node* n) -> Operation {
      return DifferentiableGraphOp(getGradient(n));
    },
    aliasAnalysisInternalSpecialCase())});

namespace detail {

GraphExecutor* getGradExecutor(Operation& op) {
  if (auto diff_op = op.target<DifferentiableGraphOp>()) {
    return &diff_op->grad_executor;
  }
  return nullptr;
}

} // namespace detail
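// Note (illustrative sketch): getGradient() is the inverse of packGradient()
// defined further below. A prim::DifferentiableGraph node carries the forward
// and reverse graphs plus the index lists as node attributes, e.g.
//
//   Gradient g = ...;        // produced by differentiate()
//   Node* dnode = ...;       // a prim::DifferentiableGraph node
//   packGradient(g, dnode);  // stores Subgraph / ReverseSubgraph / ... attrs
//   Gradient round_trip = getGradient(dnode); // reads the same attributes
//
// which is how the operator registered above reconstructs the Gradient when
// the node is instantiated.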
void GraphExecutorImplBase::run(Stack& stack) {
  TORCH_CHECK(
      stack.size() >= num_inputs,
      "expected ",
      num_inputs,
      " inputs, but got only ",
      stack.size());

  C10_LOG_API_USAGE_ONCE("torch.graph_executor.run");
  logging::getLogger()->addStatValue(
      logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);

  const ExecutionPlan& plan =
      getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts());
  InterpreterState(plan.code).run(stack);
  last_executed_optimized_graph = plan.graph;
}

c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  TORCH_CHECK(
      stack.size() >= num_inputs,
      "expected ",
      num_inputs,
      " inputs, but got only ",
      stack.size());

  C10_LOG_API_USAGE_ONCE("torch.graph_executor.runAsync");
  logging::getLogger()->addStatValue(
      logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);

  struct Frame {
    explicit Frame(ExecutionPlan eplan, TaskLauncher taskLauncher)
        : plan(std::move(eplan)), state(plan.code, std::move(taskLauncher)) {}
    ExecutionPlan plan;
    InterpreterState state;
  };
  auto frame = std::make_shared<Frame>(
      getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()),
      std::move(taskLauncher));
  auto res = frame->state.runAsync(stack);
  last_executed_optimized_graph = frame->plan.graph;
  if (!res->completed()) {
    // If not completed, persist the Frame until complete.
    res->addCallback([frame] {});
  }
  return res;
}
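// Example (hypothetical sketch, assuming parseIR from
// torch/csrc/jit/ir/irparser.h): constructing a GraphExecutor directly from a
// graph and running it with a Stack of IValues; the result is left on the
// stack and the optimized graph can be inspected afterwards:
//
//   auto graph = std::make_shared<Graph>();
//   parseIR(R"IR(
//     graph(%x : Tensor, %y : Tensor):
//       %z : Tensor = aten::mul(%x, %y)
//       return (%z))IR", graph.get());
//   GraphExecutor executor(graph, "example");
//   Stack stack{at::rand({2, 2}), at::rand({2, 2})};
//   executor.run(stack);                      // stack now holds %z
//   auto optimized = lastExecutedOptimizedGraph();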
// a Graph can be created via tracing, or via a language-based frontend
// GraphExecutor runs it. It can run the same graph on many different sizes
// and different requires_grad states, and handles specializations for each
// situation. GraphExecutor is completely unaware of tracing or module
// parameters to keep the tracing concerns separated.
struct GraphExecutorImpl : public GraphExecutorImplBase {
  GraphExecutorImpl(
      const std::shared_ptr<Graph>& graph,
      std::string function_name)
      : GraphExecutorImplBase(graph, std::move(function_name)),
        arg_spec_creator_(*graph) {
    logging::getLogger()->addStatValue(
        logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
  }

  const ExecutionPlan& getPlanFor(Stack& stack, size_t remaining_bailout_depth)
      override {
    return getGraphExecutorOptimize() ? getOrCompile(stack)
                                      : getOrCompileFallback();
  }

  GraphExecutorState getDebugState() override {
    GraphExecutorState state;
    state.graph = graph.get();
    if (fallback) {
      state.fallback = fallback;
    }
    for (auto& entry : plan_cache) {
      state.execution_plans.emplace(entry.first, entry.second);
    }
    return state;
  }

 protected:
  friend struct GraphExecutor;

  const ExecutionPlan& getOrCompileFallback() {
    std::lock_guard<std::mutex> lock(compile_mutex);
    if (!fallback) {
      auto graph_ = graph->copy();
      runRequiredPasses(graph_);
      fallback = ExecutionPlan(graph_, function_name_);
    }
    return fallback;
  }

  const ExecutionPlan& getOrCompile(const Stack& stack) {
    // outside lock guard, to minimize the time holding the lock on the fast
    // path. ArgumentSpec even computes its hashCode here.
    ArgumentSpec spec =
        arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
    {
      std::lock_guard<std::mutex> lock(compile_mutex);
      auto it = plan_cache.find(spec);
      if (it != plan_cache.end()) {
        logging::getLogger()->addStatValue(
            logging::runtime_counters::EXECUTION_PLAN_CACHE_HIT, 1.0);
        return it->second;
      }
      auto plan = compileSpec(spec);
      auto r = plan_cache.emplace(std::move(spec), std::move(plan));
      logging::getLogger()->addStatValue(
          logging::runtime_counters::EXECUTION_PLAN_CACHE_MISS, 1.0);
      return r.first->second;
    }
  }
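  // Note (illustrative sketch): the ArgumentSpec computed above keys the plan
  // cache on grad mode plus per-argument properties such as dtype, device,
  // definedness and requires_grad, so the same executor compiles one plan per
  // distinct input configuration, e.g.
  //
  //   executor.run(stack_float);   // miss -> compileSpec(spec), cached
  //   executor.run(stack_float);   // hit, same ArgumentSpec
  //   executor.run(stack_double);  // different dtype -> new spec -> new plan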
  ExecutionPlan compileSpec(const ArgumentSpec& spec) {
    auto opt_graph = graph->copy();
    GRAPH_DUMP("Optimizing the following function:", opt_graph);
    arg_spec_creator_.specializeTypes(*opt_graph, spec);

    // Phase 0. Inline functions, then clean up any artifacts that the inliner
    //          left in that may inhibit optimization
    Inline(*opt_graph);
    GRAPH_DEBUG("After Inline, before LowerGradOf\n", *opt_graph);
    LowerGradOf(*opt_graph);
    GRAPH_DEBUG(
        "After LowerGradOf, before specializeAutogradZero\n", *opt_graph);
    specializeAutogradZero(opt_graph);
    GRAPH_DEBUG(
        "After specializeAutogradZero, before LowerSimpleTuples\n",
        *opt_graph);
    LowerSimpleTuples(opt_graph);
    GRAPH_DEBUG(
        "After LowerSimpleTuples, before ConstantPooling\n", *opt_graph);
    ConstantPooling(opt_graph);
    GRAPH_DEBUG(
        "After ConstantPooling, before runRequiredPasses\n", *opt_graph);

    // Phase 1. Specialize to input definedness (this is very important for
    //          gradient graphs), and run required passes to bring the graph
    //          to an executable form.
    runRequiredPasses(opt_graph);
    GRAPH_DEBUG(
        "After runRequiredPasses, before ConstantPropagation\n", *opt_graph);

    // Phase 2. Propagate detailed information about the spec through the
    //          graph (enables more specializations in later passes).
    //          Shape propagation sometimes depends on certain arguments being
    //          constants, and constant propagation doesn't need shape
    //          information anyway, so it's better to run it first.
    ConstantPropagation(opt_graph);
    GRAPH_DEBUG(
        "After ConstantPropagation, before PropagateInputShapes\n",
        *opt_graph);
    PropagateInputShapes(opt_graph);
    GRAPH_DEBUG(
        "After PropagateInputShapes, before PropagateRequiresGrad\n",
        *opt_graph);
    PropagateRequiresGrad(opt_graph);
    GRAPH_DEBUG(
        "After PropagateRequiresGrad, before runOptimization\n", *opt_graph);

    // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
    //          that we can still execute using autograd).
    runOptimization(opt_graph);

    // Phase 4. If this graph will be differentiated, we need to slice out the
    //          symbolically differentiable subgraphs for further
    //          optimizations.
    // Phase 5. Apply non-differentiable optimizations to the graphs we've
    //          found (or the whole graph if we know we won't need its
    //          derivative).
    if (needsGradient(opt_graph)) {
      auto diff_nodes = CreateAutodiffSubgraphs(
          opt_graph,
          autodiff_subgraph_inlining ? autodiffSubgraphNodeThreshold : 1);
      GRAPH_DEBUG("After CreateAutodiffSubgraphs\n", *opt_graph);
      size_t idx = 0;
      for (Node* dnode : diff_nodes) {
        GRAPH_DEBUG("Optimizing diff node ", idx);
        auto diff_graph = std::move(dnode->g(attr::Subgraph));
        Gradient gradient = differentiate(diff_graph);
        GRAPH_DEBUG("Forward graph:\n", *(gradient.f));
        GRAPH_DEBUG("Backward graph:\n", *(gradient.df));
        // Run post-differentiation optimizations. Autodiff will replace some
        // parts of the graph with new graphs; these new graphs usually
        // consist of control flow and lack shape information on nodes, so we
        // run shape prop and differentiable optimizations to ensure the graph
        // is optimized.
        PropagateInputShapes(gradient.f);
        GRAPH_DEBUG("After PropagateInputShapes\n", *(gradient.f));
        runOptimization(gradient.f);
        // run non-diff optimization on the forward graph
        runNondiffOptimization(gradient.f);
        packGradient(gradient, dnode);
        GRAPH_DEBUG("Finished optimizing diff node ", idx++);
      }
      InlineAutodiffSubgraphs(
          opt_graph,
          autodiff_subgraph_inlining ? autodiffSubgraphInlineThreshold : 1);
      GRAPH_DEBUG("After InlineAutodiffSubgraphs\n", *opt_graph);
    } else {
      runNondiffOptimization(opt_graph);
    }
    // Make sure there are no leftovers from any passes.
    EliminateDeadCode(opt_graph);
    GRAPH_DUMP("After compileSpec optimizations:", opt_graph);
    return ExecutionPlan(opt_graph, function_name_);
  }
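  // Example (hypothetical sketch): the debugSetAutodiffSubgraphInlining()
  // hook defined at the top of this file feeds the thresholds used in
  // compileSpec() above, so a test can force even tiny differentiable
  // subgraphs to be created and kept (both thresholds drop to 1):
  //
  //   debugSetAutodiffSubgraphInlining(false);
  //   GraphExecutor executor(graph, "test");
  //   executor.run(stack); // compiles with CreateAutodiffSubgraphs(graph, 1)
  //   debugSetAutodiffSubgraphInlining(true);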
  ~GraphExecutorImpl() override = default;

  ArgumentSpecCreator arg_spec_creator_;
  // Populated only when optimize is false (and in that case plan_cache will
  // be unused). The compiled version of graph.
  ExecutionPlan fallback;

  // Mapping from argument configurations to optimized versions of the graph
  // that are specialized to the spec.
  std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;
};

GraphExecutor::GraphExecutor(
    const std::shared_ptr<Graph>& graph,
    std::string function_name)
    : pImpl(
          IsNewExecutorEnabled()
              ? dynamic_cast<GraphExecutorImplBase*>(
                    new ProfilingGraphExecutorImpl(
                        graph,
                        std::move(function_name)))
              : dynamic_cast<GraphExecutorImplBase*>(
                    new GraphExecutorImpl(graph, std::move(function_name)))) {}

void GraphExecutor::run(Stack& inputs) {
  return pImpl->run(inputs);
}

c10::intrusive_ptr<Future> GraphExecutor::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  return pImpl->runAsync(stack, std::move(taskLauncher));
}

size_t GraphExecutor::getDefaultNumBailOuts() {
  return getProfilingMode() ? getBailoutDepth().load() : 0;
}

const ExecutionPlan& GraphExecutor::getPlanFor(
    Stack& inputs,
    size_t remaining_bailout_depth) {
  return pImpl->getPlanFor(inputs, remaining_bailout_depth);
}

std::shared_ptr<Graph> GraphExecutor::graph() const {
  return pImpl->graph;
}

GraphExecutorState GraphExecutor::getDebugState() {
  return pImpl->getDebugState();
}

TORCH_API bool IsNewExecutorEnabled() {
  static const auto disable_new_executor =
      std::getenv("TORCH_JIT_DISABLE_NEW_EXECUTOR");
  return getExecutorMode() && FLAGS_torch_jit_enable_new_executor &&
      !disable_new_executor;
}

void runRequiredPasses(const std::shared_ptr<Graph>& g) {
  // implicitly inserted expand nodes are not necessarily always valid
  // when used inside script methods that might have unstable shapes
  // we remove the implicitly created ones, and have shape analysis
  // add valid expand nodes when the shapes are stable
  RemoveExpands(g);
  CanonicalizeOps(g);
  EliminateDeadCode(g);
}

void packGradient(const Gradient& gradient, Node* dnode) {
  AT_ASSERT(dnode->kind() == prim::DifferentiableGraph);
  dnode->g_(attr::Subgraph, gradient.f)
      ->g_(attr::ReverseSubgraph, gradient.df)
      ->i_(attr::f_real_outputs, gradient.f_real_outputs)
      ->is_(attr::df_input_vjps, fmap<int64_t>(gradient.df_input_vjps))
      ->is_(
          attr::df_input_captured_inputs,
          fmap<int64_t>(gradient.df_input_captured_inputs))
      ->is_(
          attr::df_input_captured_outputs,
          fmap<int64_t>(gradient.df_input_captured_outputs))
      ->is_(attr::df_output_vjps, fmap<int64_t>(gradient.df_output_vjps));
}

static bool mayIntroduceGradient(const Block* b) {
  for (const Node* n : b->nodes()) {
    if (n->kind() == prim::PythonOp)
      return true;
    for (const Block* bb : n->blocks()) {
      if (mayIntroduceGradient(bb))
        return true;
    }
  }
  return false;
}

bool needsGradient(const std::shared_ptr<const Graph>& graph) {
  if (!autograd::GradMode::is_enabled()) {
    return false;
  }

  if (mayIntroduceGradient(graph->block())) {
    return true;
  }

  for (const Value* input : graph->inputs()) {
    if (input->type()->requires_grad()) {
      return true;
    }
  }

  return false;
}
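// Example (hypothetical sketch): executor selection via IsNewExecutorEnabled()
// above is effectively decided once per process, because the environment
// variable is read into a function-local static on the first call:
//
//   setenv("TORCH_JIT_DISABLE_NEW_EXECUTOR", "1", /*overwrite=*/1); // POSIX
//   GraphExecutor legacy(graph, "fn"); // constructs GraphExecutorImpl
//
// Alternatively, setting getExecutorMode() = false programmatically also
// routes construction to the legacy GraphExecutorImpl.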
void runNondiffOptimization(
    std::shared_ptr<Graph>& graph,
    bool strict_fuser_check) {
  GRAPH_DEBUG(
      "Before customPrePasses (beginning of runNondiffOptimization)\n",
      *graph);
  // Run custom passes that different backends can register.
  for (const auto& passPair : getCustomPrePasses()) {
    passPair.first(graph);
  }
  GRAPH_DEBUG("After customPrePasses\n", *graph);

  // decomposition pass, decompose certain ops that will be used in the
  // following passes (like batchmm and jit fusion)
  if (!getProfilingMode()) {
    DecomposeOps(graph);
    GRAPH_DEBUG("After DecomposeOps\n", *graph);
  }

  // TupleConstruct / TupleUnpack pairs can still be present at this point
  // and must be removed for fusion.
  LowerSimpleTuples(graph);
  GRAPH_DEBUG("After LowerSimpleTuples, before BatchMM\n", *graph);

  // Rewrite subgraphs with many MMs into expressions that batch them.
  BatchMM(graph);
  GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);

  if (getProfilingMode()) {
    if (tensorExprFuserEnabled()) {
      FuseTensorExprs(graph);
    }
  } else {
    FuseGraph(graph, strict_fuser_check);
  }
  GRAPH_DEBUG("After Fusion\n", *graph);

  // Run custom post-fusion passes
  for (const auto& passPair : getCustomPostPasses()) {
    passPair.first(graph);
  }
  GRAPH_DEBUG(
      "After customPostPasses (end of runNondiffOptimization)\n", *graph);
}
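// Example (hypothetical sketch, assuming the RegisterPass helper from
// torch/csrc/jit/passes/pass_manager.h): backends hook into the custom pass
// lists iterated in runNondiffOptimization() above by registering a
// std::function<void(std::shared_ptr<Graph>&)> at static-initialization time:
//
//   static RegisterPass my_backend_pass([](std::shared_ptr<Graph>& g) {
//     // rewrite nodes in `g` for a custom backend
//   });
//
// Registered passes then run for every graph that reaches this point.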
void runOptimization(
    std::shared_ptr<Graph>& graph,
    bool unroll,
    bool const_prop_user_classes) {
  // Basic graph preprocessing to eliminate noise.
  GRAPH_DEBUG(
      "Before EliminateDeadCode (beginning of runOptimization)\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG(
      "After EliminateDeadCode, before EliminateCommonSubexpression\n",
      *graph);
  EliminateCommonSubexpression(graph);
  GRAPH_DEBUG(
      "After EliminateCommonSubexpression, before PeepholeOptimize\n",
      *graph);

  PeepholeOptimize(graph);
  GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);

  if (const_prop_user_classes) {
    ConstantPropagation(graph);
  } else {
    ConstantPropagation(graph, true);
  }
  GRAPH_DEBUG("After ConstantPropagation, before ConstantPooling\n", *graph);

  ConstantPooling(graph);
  GRAPH_DEBUG("After ConstantPooling\n", *graph);

  // Unroll small loops, and eliminate expressions that are the same at every
  // iteration.
  if (unroll) {
    UnrollLoops(graph);
    GRAPH_DEBUG("After UnrollLoops, before RemoveListMutation\n", *graph);
    // run again with unrolled loops
    RemoveListMutation(graph);
    GRAPH_DEBUG("After RemoveListMutation, before PeepholeOptimize\n", *graph);
    PeepholeOptimize(graph);
    GRAPH_DEBUG(
        "After PeepholeOptimize, before ConstantPropagation\n", *graph);
    ConstantPropagation(graph);
    GRAPH_DEBUG("After ConstantPropagation\n", *graph);
  }

  EliminateCommonSubexpression(graph);
  GRAPH_DEBUG(
      "After EliminateCommonSubexpression, before CheckInplace\n", *graph);

  CheckInplace(graph);
  GRAPH_DEBUG("After CheckInplace (end of runOptimization)", *graph);
}

Node* replaceBlockWithFallbackGraph(Block* b, ArrayRef<Value*> inputs) {
  auto graph = std::make_shared<Graph>();

  // we are copying the block inside If or prim::Loop; otherwise we are
  // copying the whole graph. We need to differentiate the two cases because
  // cloneFrom automatically adds inputs if we are copying a graph's block,
  // and we will need the inputs from a user otherwise.
  if (b->owningNode() != nullptr) {
    std::unordered_map<Value*, Value*> input_mapping;
    auto value_map = [&input_mapping](Value* v) { return input_mapping[v]; };
    for (auto inp : inputs) {
      input_mapping[inp] = graph->block()->addInput();
    }

    graph->block()->cloneFrom(b, value_map);
  } else {
    auto value_map = [](Value* v) { return v; };
    graph->block()->cloneFrom(b, value_map);
  }

  auto fallback = b->owningGraph()->create(
      prim::FallbackGraph, inputs, b->outputs().size());
  fallback->g_(attr::Subgraph, graph);
  b->prependNode(fallback);

  for (size_t i = 0; i < inputs.size(); i++) {
    graph->inputs()[i]->setType(inputs[i]->type());
    graph->inputs()[i]->copyMetadata(inputs[i]);
  }

  for (size_t i = 0; i < b->outputs().size(); i++) {
    fallback->output(i)->setType(b->outputs()[i]->type());
    fallback->output(i)->copyMetadata(b->outputs()[i]);
    b->replaceOutput(i, fallback->output(i));
  }

  ProfilingRecord::removeProfilingNodes(graph->block());

  for (auto it = b->nodes().rbegin(); it != fallback->iterator(); it++) {
    it.destroyCurrent();
  }

  return fallback;
}

} // namespace jit
} // namespace torch