#pragma once
#include <c10/util/Optional.h>
#include <memory>
#include <vector>

#include <ATen/ThreadLocalState.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/frontend/source_range.h>

C10_DECLARE_bool(torch_jit_disable_warning_prints);

namespace at {
class Tensor;
TORCH_API void launch(std::function<void()> func);
} // namespace at
namespace c10 {
struct IValue;
struct OperatorName;
} // namespace c10

namespace torch {
namespace jit {

// The interpreter runs Graphs with Tensor inputs and Tensor outputs;
// a separate component in the autograd handles unwrapping and wrapping
// variable objects for use in the interpreter.

struct Node;
struct GraphExecutor;
struct CodeImpl;
struct InterpreterStateImpl;
struct Graph;
struct Instruction;
using Stack = std::vector<c10::IValue>;
using c10::ivalue::Future;
using TaskLauncher = std::function<void(std::function<void()>)>;

struct TORCH_API Code {
  Code() : pImpl(nullptr) {}
  // remaining_bailout_depth is irrelevant in a `Code` object unless the
  // `Code` is directly created by `GraphExecutor`, in which case it's likely
  // to contain `prim::BailOut`s that control the maximum depth of bailout
  // chains
  explicit Code(
      const std::shared_ptr<Graph>& graph,
      std::string function_name,
      size_t remaining_bailout_depth = 0);
  ~Code();

  const std::vector<GraphExecutor*>& grad_executors();

  explicit operator bool() const {
    return pImpl != nullptr;
  }
  size_t num_inputs() const;
  size_t num_outputs() const;
  size_t num_bailouts() const;
  const std::vector<c10::IValue>& constant_table() const;
  const std::vector<c10::TypePtr>& type_table() const;
  const std::vector<Instruction>& instructions() const;
  const std::vector<Node*>& instructions_source() const;
  void request_bailout(size_t index);
  size_t register_size() const;

 private:
  std::shared_ptr<CodeImpl> pImpl;
  friend struct InterpreterStateImpl;
  friend std::ostream& operator<<(std::ostream& out, const Code& code);
};

struct InterpreterState {
  TORCH_API InterpreterState(
      const Code& code,
      TaskLauncher taskLauncher = at::launch);
  TORCH_API void run(Stack& stack);
  TORCH_API c10::intrusive_ptr<Future> runAsync(Stack& stack);
  c10::intrusive_ptr<Future> getFuture();
  TORCH_API ~InterpreterState();

 private:
  InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl);
  // Ideally we should use c10::intrusive_ptr<InterpreterStateImpl> for pImpl;
  // but intrusive_ptr requires a full definition of InterpreterStateImpl,
  // which we need to hide in the header.
  c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl;
  friend struct InterpreterStateImpl;
};

// Created by wait()
struct Suspend : public std::exception {
  const char* what() const noexcept override {
    return "Suspend";
  }

  explicit Suspend(c10::intrusive_ptr<Future> future_)
      : future(std::move(future_)) {}

  c10::intrusive_ptr<Future> future;
};

// InterpreterContinuation propagates dist_autograd_context_id through (and
// only through) the forward pass manually; other thread-local settings are
// propagated with ThreadLocalState.
struct InterpreterContinuation {
  InterpreterContinuation(
      const InterpreterState& state_,
      Stack stack_,
      int64_t dist_autograd_context_id = 0,
      c10::optional<at::ThreadLocalState> tls_state = c10::nullopt)
      : state(state_),
        stack(std::move(stack_)),
        tls_state_(std::move(tls_state)) {
#ifdef USE_DISTRIBUTED
    dist_autograd_context_id_ = dist_autograd_context_id;
#endif
  }

  void operator()();

 private:
  InterpreterState state;
  Stack stack;
  c10::optional<at::ThreadLocalState> tls_state_ = c10::nullopt;
#ifdef USE_DISTRIBUTED
  int64_t dist_autograd_context_id_;
#endif
};

// Returns the tensor's type, including state from the current execution
// context that modifies how the tensor behaves. For instance, if no_grad is
// enabled, the resulting TensorType will have requires_grad=False.
TORCH_API at::TensorTypePtr tensorTypeInCurrentExecutionContext(
    const at::Tensor& t);

// current (TLS) TorchScript interpreter callstack
TORCH_API std::vector<StackEntry> currentCallstack();

} // namespace jit
} // namespace torch
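
// ---------------------------------------------------------------------------
// Illustrative usage sketch (comments only; not part of this header's API).
// The `graph` below is assumed to come from the TorchScript frontend, and the
// input value and function name are placeholders:
//
//   std::shared_ptr<Graph> graph = ...;    // built by the frontend/tracer
//   Code code(graph, "example_fn");        // compile to interpreter bytecode
//   InterpreterState state(code);
//   Stack stack;                           // Stack = std::vector<c10::IValue>
//   stack.emplace_back(at::ones({2, 2}));  // push the graph's inputs
//   state.run(stack);                      // outputs replace the inputs on
//                                          // the stack when run() returns
//   at::Tensor out = stack.back().toTensor();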
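
// A similar sketch for asynchronous execution, reusing the hypothetical
// `code` object from the example above. runAsync() returns a Future that
// completes when the interpreter finishes, possibly on another thread
// supplied by the TaskLauncher:
//
//   InterpreterState async_state(code);
//   Stack async_stack;
//   async_stack.emplace_back(at::ones({2, 2}));
//   auto future = async_state.runAsync(async_stack);
//   future->wait();                        // block until execution completes
//   c10::IValue result = future->value();  // the graph's output(s)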