The JIT runtime - Calling a ScriptFunction
In the first post of our series on the PyTorch JIT, we got an overview of how the PyTorch JIT works while looking at how it optimizes models. Today we take a close look at what happens behind the scenes when we call a TorchScript function from Python.
What happens when you call a TorchScript function?
This is structured a bit differently from the last post in that I will take you by the hand and we hop down the rabbit hole of the PyTorch source code. To do this, I have added links to a copy of the source code below; they will open the source code at the corresponding line on the right when you move the mouse over them or click them (and sorry, if you read this on mobile, you're probably out of luck). As I want you to click on all the internal links, I have marked the very few external links separately.
Let us try:
@torch.jit.script
def fn(x):
    return x * 2 + x
fn.__class__
this gives
torch.jit.ScriptFunction
so we have to look at what happens when a ScriptFunction
is called.
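Before we dive into the C++ side, we can poke at the object a little from Python. As far as I know, a ScriptFunction also exposes the TorchScript source and the IR graph as .code and .graph attributes - take this as a quick sketch:
fn(torch.ones(3))  # it still behaves like an ordinary callable
print(fn.code)     # the TorchScript source, reconstructed from the graph
print(fn.graph)    # the TorchScript IR graph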
This will be quite a journey, so off we go!
Functions from Python to C++
ScriptFunction
is a class defined via PyBind11 in torch/csrc/jit/python/script_init.cpp
. It wraps a StrongFunctionPtr
. It defines a __call__
method.
The StrongFunctionPtr
has a shared pointer to a CompilationUnit
and a torch::jit::Function*
.
The CompilationUnit
owns the function (put there e.g. by scripting our function), and potentially several others. But to adapt to Python, we want something that is specific to this one function, and that is the StrongFunctionPtr
.
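We can see a CompilationUnit owning several functions at once directly from Python: torch.jit.CompilationUnit takes a string of TorchScript source (at least in the versions I tried), and each function we pull out of it comes back wrapped as its own ScriptFunction - i.e. as a StrongFunctionPtr:
cu = torch.jit.CompilationUnit("""
def foo(x):
    return x + 1

def bar(x):
    return foo(x) * 2
""")
print(type(cu.foo))           # each function comes back as a ScriptFunction
print(cu.bar(torch.ones(2)))  # tensor([4., 4.])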
So for calling, we need the Function
. This is passed with the (Python) args to invokeScriptFunctionFromPython
, defined in pybind_utils.h
, which passes things on to runAndInsertCall
in the same file.
There we use the function createStackForSchema
to create a stack of IValue
s from the Python arguments using the function's Schema
- TorchScript's way of keeping track of the signature of a function. We skip the details of that. Then the Function
's run
method is called to execute the function. Afterwards, the stack contains the results in IValue
s; these are converted to Python values (toPyObject
, same file) and returned. If we were in tracing mode (i.e. running torch.jit.trace
on something calling our script function), we would record the call in the tracing graph.
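From Python we can at least look at the Schema that createStackForSchema matches the arguments against - the scripted function exposes it as fn.schema, and a mismatched call fails with an error quoting that schema (a small sketch):
print(fn.schema)         # something like: fn(Tensor x) -> Tensor
try:
    fn("not a tensor")   # schema matching fails before anything runs
except RuntimeError as e:
    print(e)             # the message quotes the expected schema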
So we are looking for Function::run
. A Function
is an abstract class defined in aten/src/ATen/core/function.h
. The interesting specialization is the GraphFunction
from torch/csrc/jit/api/function_impl.h
and .cpp
. It has a GraphExecutor
attribute executor_
that can be obtained through (and is instantiated in) its get_executor()
(from the .h
). With that, all GraphFunction::run
does is call the run
method of the GraphExecutor
with the IValue
-stack as the argument.
The GraphExecutor
is instantiated with a TorchScript Graph
from the GraphFunction
's optimized_graph()
method and the function's name. The optimized_graph
is basically the graph with some initial optimizations (applied by GraphFunction::preoptimizeGraph
in the .cpp
): PeepholeOptimize
, ConstantPropagationImmutableTypes
and ConstantPooling
. I will have to do a separate blog post on all the optimizations and passes, so let's skip the details for now.
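Several of these passes are exposed to Python as torch._C._jit_pass_* functions, so we can get a rough feel for what preoptimizeGraph does by applying them to a graph ourselves. This is only a sketch: it mutates fn.graph in place, and I use the plain constant propagation pass as a stand-in for the immutable-types variant:
g = fn.graph  # note: the passes modify this graph in place
torch._C._jit_pass_peephole(g)
torch._C._jit_pass_constant_propagation(g)  # stand-in for the immutable-types variant
torch._C._jit_pass_constant_pooling(g)
print(g)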
Graph Executors
Now we are in the depths of the JIT runtime and in torch/csrc/jit/runtime
! The GraphExecutor
is a wrapper for GraphExecutorImpl
(from graph_executor_impl.h
and graph_executor.cpp
) or the newer ProfilingGraphExecutorImpl
(from profiling_graph_executor_impl.h
and .cpp
) that both extend the GraphExecutorImplBase
(also graph_executor_impl.h
and graph_executor.cpp
). The wrapper mostly handles instantiation and forwards a few method calls, including, of course, run
. Which executor implementation gets used is decided in the GraphExecutor
constructor based on three things: the environment variable TORCH_JIT_DISABLE_NEW_EXECUTOR
, the C++ command line flag torch_jit_enable_new_executor
, and - most importantly for us - a flag obtained through getExecutorMode()
, which is exposed to Python via torch._C._jit_set_profiling_executor
and set via the torch.jit.fuser
context. The default nowadays is the profiling one.
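From Python we can flip that flag and switch executors; the setter returns the previous value, so it is easy to restore:
was_profiling = torch._C._jit_set_profiling_executor(False)  # switch to the traditional executor
# ... run some scripted functions ...
torch._C._jit_set_profiling_executor(was_profiling)          # restore the previous setting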
The graph executor is the component that handles the optimization and running of TorchScript code. It works at the level of TorchScript IR graphs and will itself call into the bytecode interpreter for the actual execution, as we will see.
Functions (torch::jit::Function
) are not the only things that get graph executors; sometimes we also want to have executors for other graphs. This will be important to us below, when we use a separate executor for some things in order to have specialized optimizations.
(We'll drop the Impl from the graph executor names when referring to them below.)
Execution Plans
The run
(GraphExecutorImplBase::run
in graph_executor.cpp
) method is relatively simple: It gets an ExecutionPlan
(defined in graph_executor.h
) using the getPlanFor
method (which takes our stack - for the types - and the remaining bailout depth as arguments; more on the latter below). This execution plan has a member code
holding the bytecode in a Code
object. The GraphExecutor
's run
method then instantiates an InterpreterState
with this code
and calls its run
method with the stack.
But this means the optimization magic is in getPlanFor
, which is specific to the two executors.
The Profiling Graph Executor
The Profiling Graph Executor's getPlanFor
is again very easy: If the optimized_plan_
is initialized, that is what it returns (so we only have one optimized execution plan which is used regardless of the input types). If not, it calls getOptimizedPlanFor
to make one.
When not disabled through a flag that can be queried and set in Python through torch._C._get/set_graph_executor_optimize
, our getOptimizedPlanFor
goes through the following:
- If the bailout depth is 0, it uses the graph after only the profiling insensitive optimizations (quite a list, called from the function runProfilingInsensitiveOptimizations). In particular, no profiling or specialization takes place.
- If not, it creates a profiling plan (if it has not done so yet) that is used to record the shape information in a ProfilingRecord (assigned to the class member pr_; ProfilingRecord is defined in profiling_record.h/cpp). The profiling graph is created from the graph we got by running the profiling insensitive optimizations and then instrumenting it (through the ProfilingRecord::instrumentGraph factory function). This graph then contains lots of prim::profile nodes. The ProfilingRecord has a counter for the number of remaining profiling runs until it is ready (ready()); it starts from getNumProfiledRuns(), which can be controlled from Python by torch.jit._jit_get/set_num_profiled_runs, and the default value is 1. If the profiling graph is there but the profiling record is not ready yet, this profiling graph is run (through an ExecutionPlan created from it).
- Once ready, it creates the optimized graph by running the profiling optimizations (runProfilingOptimizations) on a copy of the profiling graph. This is what the ExecutionPlan that becomes our optimized_plan_ is instantiated with. When we call fn.get_debug_state() on our Python script function, the debug state's execution_plans member is a dictionary that has this execution plan as its only entry. Currently, we get an internal assert error when it doesn't exist (e.g. because we're still in the profiling phase).
And this is really all there is to the Profiling Graph Executor - largely because we skipped the details of the two bits at its core: the profiling mechanism and the optimizations, in particular those applied after profiling.
We can look at the profiling executor in action:
@torch.jit.script
def fn(x):
    return torch.sin(x) * 2 + x
x = torch.randn(5, device="cuda")
fn(x) # call the function once
print(torch.jit.last_executed_optimized_graph())
This gives us the profiled graph (see the prim::profile
nodes):
graph(%x.1 : Tensor):
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]() # <ipython-input-11-43e780f89194>:3:26
%6 : Tensor = prim::profile[profiled_type=Float(5, strides=[1], requires_grad=0, device=cuda:0)](%x.1)
%3 : Tensor = aten::sin(%6) # <ipython-input-11-43e780f89194>:3:11
%7 : Tensor = prim::profile[profiled_type=Float(5, strides=[1], requires_grad=0, device=cuda:0)](%3)
%4 : Tensor = aten::mul(%7, %2) # <ipython-input-11-43e780f89194>:3:11
%8 : Tensor = prim::profile[profiled_type=Float(5, strides=[1], requires_grad=0, device=cuda:0)](%4)
%9 : Tensor = prim::profile[profiled_type=Float(5, strides=[1], requires_grad=0, device=cuda:0)](%x.1)
%5 : Tensor = aten::add(%8, %9, %1) # <ipython-input-11-43e780f89194>:3:11
%10 : Tensor = prim::profile[profiled_type=Float(5, strides=[1], requires_grad=0, device=cuda:0)](%5)
= prim::profile()
return (%10)
As there isn't an optimized graph yet, getting the debug state throws an exception:
print(fn.get_debug_state().execution_plans)
gives
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-12-22ea0456b2eb> in <module>
1 # as there isn't an optimized graph yet, this throws an exception
----> 2 print(fn.get_debug_state().execution_plans)
RuntimeError: optimized_plan_ INTERNAL ASSERT FAILED at "../torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp":556, please report a bug to PyTorch.
We hit an internal assert because there is no optimized plan yet.
But when we run the function another time (and because the number of profiling runs is 1 by default), we get the optimized graph with a TensorExpr fusion group:
fn(x) # run a second time
print(torch.jit.last_executed_optimized_graph())
Note the TypeCheck node and the bailout function we met in the introduction to how the JIT optimizes functions. There we also saw a trick to get at the fallback function's graph.
With the optimized plan defined, we can now also get the debug state.
print(fn.get_debug_state().execution_plans)
Regardless of how often we call the function and with what arguments, we only ever have one execution plan in the dictionary; the dictionary is just a dummy mapping in a debug state format designed for the traditional graph executor:
{<torch._C.ArgumentSpec object at 0x7f92021cb470>: <torch._C.ExecutionPlan object at 0x7f92021cbd30>}
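If I remember the bindings correctly, the ExecutionPlan objects in that dictionary expose the graph they will run (and the Code object), so we can pull out the single plan and look at its optimized graph - assuming the .graph attribute is there:
plan = next(iter(fn.get_debug_state().execution_plans.values()))
print(plan.graph)  # the optimized graph backing the plan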
Having seen the profiling graph executor in code and in action, let us now look briefly at the traditional one.
The traditional Graph Executor
The traditional GraphExecutor isn't as interesting to us because it's on its way to retirement (well, maybe slowly), so here is just a brief overview, along with some details that offer an interesting contrast to the profiling executor:
The difference starts when the getPlanFor
method is called. Depending on whether optimizations are enabled, it calls getOrCompile
or getOrCompileFallback
, the interesting bit being the getOrCompile
. (The non-optimized parts still have some passes they need to apply, but they basically leave out most things that are optional.)
The traditional executor distinguishes between different input configurations (shapes, requires grad, whether optionals are None
etc.) and creates distinct optimized graphs for them. This was a key ingredient to optimizing the LSTM backward because it allowed making the information about whether broadcasting happened in the forward "static" for the backward. It does so by keeping a condensed version of the information listed above and a hash table (the execution_plans
dictionary in the debug state). So in getOrCompile
the JIT creates an ArgumentSpec
that is the key into the plan cache. If it finds a plan, it returns it; otherwise it compiles a new one for this spec in compileSpec
.
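Conceptually, this caching is just a hash map keyed by a condensed description of the inputs. Here is a toy Python sketch of the idea - not the actual C++ implementation, and compile_spec below is a hypothetical stand-in for compileSpec:
def argument_spec(args):
    # condense each input to the properties the optimizer cares about
    return tuple(
        (a.dtype, a.dim(), a.requires_grad, str(a.device)) if isinstance(a, torch.Tensor) else a
        for a in args
    )

plan_cache = {}

def get_or_compile(args):
    spec = argument_spec(args)
    if spec not in plan_cache:
        plan_cache[spec] = compile_spec(spec)  # hypothetical stand-in for compileSpec
    return plan_cache[spec]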
As the traditional executor relies on shape propagation to apply optimizations, it seeds the graph with the inputs' shape information. Then it applies the optimizations that always work (inlining functions etc., similar to the profiling insensitive ones in the profiling executor), followed by the differentiation mechanism and the optimizations that can only run when things do not require gradients (namely fusion with the traditional fuser) - either inside the differentiable graph's forward or on graphs that don't need gradients at all.
And this is really all we need to know about the traditional executor. To see it in action, we can switch to it and run script functions as in the following (note that you want to re-define and re-script the function to not get cached results):
@torch.jit.script
def fn(x):
    return torch.sin(x) * 2 + x
with torch.jit.fuser("fuser0"):
    old_pe = torch._C._jit_set_profiling_executor(False)
    gr1 = fn.graph_for(torch.randn(5, device="cuda", requires_grad=False))
    gr2 = fn.graph_for(torch.randn(5, device="cuda", requires_grad=True))
    torch._C._jit_set_profiling_executor(old_pe)
# we find two execution plans, but there isn't a way to see the argspec in Python
print(fn.get_debug_state().execution_plans)
print(gr1) # it seems the fuser0 is already gone here...
We see that there are two execution plans:
{<torch._C.ArgumentSpec object at 0x7f402564a370>: <torch._C.ExecutionPlan object at 0x7f4023119eb0>, <torch._C.ArgumentSpec object at 0x7f40257bc530>: <torch._C.ExecutionPlan object at 0x7f40257bc2b0>}
But no sign of fusion:
graph(%x.1 : Float(*, requires_grad=0, device=cuda:0)):
%1 : int = prim::Constant[value=2]() # <ipython-input-67-7eb141224258>:3:26
%2 : int = prim::Constant[value=1]()
%3 : Float(*, requires_grad=0, device=cuda:0) = aten::sin(%x.1) # <ipython-input-67-7eb141224258>:3:11
%4 : Float(*, requires_grad=0, device=cuda:0) = aten::mul(%3, %1) # <ipython-input-67-7eb141224258>:3:11
%5 : Float(*, requires_grad=0, device=cuda:0) = aten::add(%4, %x.1, %2) # <ipython-input-67-7eb141224258>:3:11
return (%5)
The interpreter
The main part of the interpreter is in the Code
class, or rather the CodeImpl
one in torch/csrc/jit/runtime/interpreter.cpp
.
This has two main parts:
- Translating graphs to "bytecode" sequences, this is done on instantiation.
- Running the bytecode.
For the translation, the constructor of CodeImpl
calls emitCodeForBlock
on the input graph's main block.
emitCodeForBlock
then follows a typical recursive visitor pattern, producing bytecode for the various constructs.
The regular bits like calls to PyTorch functions are handled by emitOperator
, which deals with the JIT ops (aten::...
and custom ops), but anything control-flow related gets its own code generation function, dispatched from emitNode
in a large switch statement on the various special node types, with some amendments valid inside blocks in emitNodeAtBlockLevel
. Note that some interesting things, like differentiable graphs from autodiff, are also implemented via "regular" operators, so they don't show up here.
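To get a feel for what the emitter walks over, we can iterate over a graph's nodes (and, for control flow nodes, their nested blocks) from Python:
def walk(block, indent=0):
    # print the node kinds roughly the way the emitter would visit them
    for node in block.nodes():
        print(" " * indent + node.kind())
        for inner_block in node.blocks():
            walk(inner_block, indent + 2)

walk(fn.graph)  # a Graph offers .nodes() just like a Block does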
As is a familiar pattern now, the InterpreterState
is a forwarding wrapper for InterpreterStateImpl
which holds the TorchScript VM execution state (the call stack of frames, registers, etc.) and does the actual execution.
The execution goes from InterpreterStateImpl::run
(or runAsync
, as the graph executors have, too) to runImpl
, which is a state machine with a large switch statement for the various instructions. Of special note are CALL
and INTERFACE_CALL
as well as (context-manager) EXIT
, which run functions (builtin or graph ones). Graph functions are run in runGraphFunction
just like they are when called from Python above: we call the function's get_executor()
method and then the executor's getPlanFor
, and take the plan's code
member. In a new interpreter frame, we run that code. Note how, this way, the called function's graph goes through optimization as needed.
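We can see such a call in the IR when one scripted function calls another. Depending on the PyTorch version, the stored graph keeps the call as a prim::CallFunction node (which the interpreter executes via the CALL path described above), while the optimizations typically inline it later:
@torch.jit.script
def inner(x):
    return torch.sin(x) * 2

@torch.jit.script
def outer(x):
    return inner(x) + x

print(outer.graph)  # look for a prim::CallFunction node (version dependent)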
Conclusion
So that is what happens when you run a TorchScript function. I hope you enjoyed the technical dive into the parts of the JIT runtime that do the executing, and the (to me) new form of walking through the code. So far we have left out the compilation parts; those are for another time, as a detailed look at the frontend, i.e. how we get script functions in the first place. This will be very important to us at a later point, too.
I appreciate your feedback and comments - mail me at tv@lernapparat.de.