Exploring Python fallback for the JIT
One of the key difficulties of the "almost everything can be scripted" promise is what to do with functions the JIT doesn't understand. In lieu of re-implementing all of Python, we need to selectively fall back to the Python we already have. Join me today in looking at how that can be done.
The JIT has awesome optimizations, but they only work when our models - or at least the parts that are to be optimized - are JITed. This is easy enough if our models trace well, but scripting often gets messy because we need to make our entire model JIT-compatible before we have something we can run at all. This is what a fallback is designed to make easier.
Acknowledgement: This work was made possible by grid.ai, the people known for PyTorch Lightning. Thank you!
Fallbacks
Obviously, the JIT cannot implement everything - all of Python's standard library and all other packages. But then what if my model needs to use something the JIT cannot handle?
Because I lack imagination and I like to work from examples, I took a very simple example with a built-in function. It wasn't all that easy to come up with one, because many useful functions are already implemented. So I don't necessarily want to imply that every model should need to read random files from disk.
```python
import torch

@torch.jit.script
def fn(x : str):
    return open(x).read()

print(fn.graph)
print(fn(__file__))
```
But this program doesn't work on today's PyTorch:
```
Traceback (most recent call last):
  File "/home/tv/pytorch/pytorch/../scripts/fallback.py", line 4, in <module>
    def fn(x : str) -> str:
  File "/usr/local/lib/python3.9/dist-packages/torch/jit/_script.py", line 939, in script
    fn = torch._C._jit_script_compile(
RuntimeError: Python builtin <built-in function open> is currently not supported in Torchscript:
  File "/home/tv/pytorch/pytorch/../scripts/fallback.py", line 5
    @torch.jit.script
    def fn(x : str) -> str:
        return open(x).read()
               ~~~~ <--- HERE
```
Depending on which slide you take, you get the TorchScript interpreter or the Python fallback. Kids can take any slide they want! It's their playground I borrowed here.
Before we can fix it, we need to know a bit about how functions get into TorchScript. I touched briefly on this in my post on fusers, but now we look at it in (selectively) more detail.
How the JIT makes a graph from source code
When scripting a function from Python, the JIT grabs the Python source code (via the `inspect` module of the standard Python library) and then runs the Python parser from the `ast` (for Abstract Syntax Tree) module. It then transforms the Python AST into the TorchScript AST (implemented in C++). The TorchScript AST is very much what you would expect; it is defined in `torch/csrc/jit/frontend/tree_views.h`, and every node type has a class. Thankfully it also has a `dump()` function that shows a lisp-like representation of the tree: every opening parenthesis followed by a name is a node of that class, and the further elements in the list until the closing parenthesis are its children.
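For comparison, we can reproduce the Python half of this pipeline ourselves with nothing but the standard library - a minimal sketch of the same source-grabbing and parsing (this is not the JIT's actual code path, just the stdlib tools it builds on):

```python
import ast
import inspect

def fn(x: str) -> str:
    return open(x).read()

# the JIT starts from the same information we obtain here:
source = inspect.getsource(fn)   # retrieve the source text of fn
python_tree = ast.parse(source)  # parse it into a Python AST
print(ast.dump(python_tree))     # a tree quite similar in spirit to the one below
```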
For our little function, the tree looks like this:
```
(def
  (ident fn)
  (decl
    (list
      (param
        (ident x)
        (option (variable (ident str)))
        (option)
        (False)))
    (option (variable (ident str))))
  (list
    (return
      (apply
        (.
          (apply
            (variable (ident open))
            (list (variable (ident x)))
            (list))
          (ident read))
        (list)
        (list)))))
```
This is what gets handed to `CompilationUnit::define` in `torch/csrc/jit/frontend/ir_emitter.cpp`. (Avid readers of this blog may recall that we met compilation units in two previous posts: in the JIT runtime overview we saw that they hold script functions, and in the exploration of graph manipulation in Python we created functions from graphs with their `create_function` method, though I didn't talk about it much.)
So `define` calls (well, instantiates, but the gist is a call) `to_ir` to get a graph from this tree. The conversion is done in three steps:
- With a fresh graph instance, a set of node visitors is called along the tree structure, starting with `emitDef`. It produces an initial graph that "looks like Python" and is in non-SSA form. (The defining feature of static single assignment (SSA) form is that it eliminates conventional variables: values are only assigned once. This makes it easier for optimization passes to reason about them and to generate code. SSA has other aspects, like how loops are handled, but the variable bit will be the most important part for us - a tiny illustration follows this list.)
- This is passed to a `ConvertToSSA` pass, which does what the name insinuates.
- Then some normalization is carried out in `CanonicalizeModifiedLoops` and `NormalizeOps`, followed by some initial cleanup passes (e.g. splitting tuples into separate values in simple cases).
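Here is the promised illustration of the SSA idea, in plain Python (just a conceptual sketch, not JIT code):

```python
a, b = 1.0, 2.0

# non-SSA: the name x refers to different values over time
x = a + b
x = x * 2

# SSA: every value is assigned exactly once; reassignment introduces
# a fresh value, so passes can reason about x_1 and x_2 independently
x_1 = a + b
x_2 = x_1 * 2
assert x == x_2
```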
One thing to know about the graph visitation in the first step is that it has to deal with "external" (to our function) references. It uses a resolver passed to `define` to find the matching Python objects or - for particular functions - the TorchScript overrides. Now these things can vary strongly in their nature - from Tensor variables (or, to us, constants) to other functions to classes - and we do not want all of that to show up in our TorchScript graphs.
To deal with the discrepancy between all the things that might be referenced and the things that can actually be values in a graph, PyTorch defines a data structure `Environment` (in `ir_emitter.cpp`) which captures the lexical scopes, local variables, and so on. Things are stored in the environment as `SugaredValue`s, which we look at in detail below. To process any "top-level" identifier (as opposed to an attribute lookup), the graph visitors call `Environment::getSugaredVar`. This method
- first checks all local scopes using `Environment::findInAnyFrame`,
- if that didn't find anything, checks a table of magic global sugared values (defined in the static table `global` in the `getSugaredVar` method),
- if these didn't return anything, calls a resolver that looks up things in the Python environment outside TorchScript and tries to convert the result into a `SugaredValue` (see the small example after this list). More precisely, it calls the resolver three times, first trying to find NamedTuple types, then arbitrary values, and then classes. The `resolveValue` call will return a `SugaredValue`, while in the other two cases we get a type and instantiate specialized `SugaredValue` subclasses.
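We can watch the resolver at work from the Python side. In the sketch below (`SCALE` is just a name I made up for illustration), the global is neither a local nor a magic global, so the resolver looks it up in the surrounding Python environment and it ends up baked into the graph:

```python
import torch

SCALE = 2.0  # a Python global, resolved when scripting

@torch.jit.script
def scaled(x: torch.Tensor) -> torch.Tensor:
    # SCALE is found via the resolver and becomes a constant in the graph
    return x * SCALE

print(scaled.graph)  # the float shows up as a prim::Constant, not a name lookup
```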
The `PythonResolver` is defined in `python/script_init.cpp`. Its `resolveValue` method calls into a Python-defined resolution callback to get the desired Python object and then calls `toSugaredValue` (in `python/python_sugared_value.cpp`) to convert the Python object it found into a `SugaredValue`.
But what, then, is a `SugaredValue`?
Sugar, Value!
As mentioned above, `SugaredValue`s bridge the gap between anything that can be referenced by name in our programs and what is a `Value` in the narrow sense of TorchScript graphs.
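A quick way to see the gap being bridged on stock PyTorch: in the scripted function below, the name `torch` can never be a graph `Value`, yet we can use it, because it resolves to a sugared module value whose attribute lookup and call desugar directly into an operator:

```python
import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    # `torch` itself never appears in the graph; the attribute lookup
    # and the call are resolved at compile time into an aten op
    return torch.relu(x)

print(f.graph)  # contains aten::relu, but no value representing `torch`
```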
At the sugared value level, the things you can do with a variable call into the `SugaredValue` object to accomplish them:
- Sometimes you just want to get a (JIT) `Value` for the sugared value; this can be obtained through the `asValue` method. You can then use the value as an argument in function calls etc.
- For things you can do with references, but not generally with values, the compiler can call into methods on the `SugaredValue`. Notable examples include `call` (for calling functions, constructors, etc.) and `attr` for attribute lookup (à la `getattr`).
- `SugaredValue`s are subclassed to define the different effects of the various methods and to help the compiler distinguish valid uses of values (e.g. calling a function, adding something to a simple value) from invalid ones (like calling a string literal, or adding something to a function); an example follows below. These methods potentially insert things into the graph we're building (e.g. for calling functions) and return a `SugaredValue` representing the result.
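The promised example of valid and invalid uses - both decided at compilation time through the `SugaredValue` methods (the exact error type and wording may differ between PyTorch versions, so this is hedged accordingly):

```python
import torch

@torch.jit.script
def ok(x: torch.Tensor) -> torch.Tensor:
    return x.relu()  # attr() on a SimpleValue, then call() emits the op

try:
    @torch.jit.script
    def bad() -> str:
        return "hello"()  # the sugared value for a str has no valid call()
except Exception as e:
    print("scripting failed as expected:", type(e).__name__)
```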
To give you a taste of the richness of this "type system", these are the subclasses of `SugaredValue` in the JIT (from `frontend/sugared_value.h`, `python/python_sugared_value.h`, and - a bit special - three from `serialization/import_source.cpp`): `BooleanDispatchValue`, `BuiltinFunction`, `BuiltinModule`, `ClassNamespaceValue`, `ClassNamespaceValue`, `ClassValue`, `ClosureValue`, `ConstantParameterList`, `ConstantTableValue`, `ExceptionMessageValue`, `ExceptionValue`, `FunctionValue`, `IterableTree`, `MagicMethod`, `MethodValue`, `ModuleDictMethod`, `ModuleValue`, `NamedTupleConstructor`, `NoneValue`, `OpsValue`, `PrintValue`, `PythonSliceClass`, `PythonValue`, `RangeValue`, `SimpleValue`, `SliceValue`, `SpecialFormValue`, `SugaredDict`, `SugaredEnumClass`, `SugaredTupleValue`, `TensorCastValue`.
As you can imagine, the `toSugaredValue` function mentioned above uses many of these to represent the various things: notably constants (represented as `SimpleValue`), functions from Python (e.g. those marked `torch.jit.ignore`, represented as `PythonValue`), and all sorts of special things the JIT knows more about.
In the end, everything we do on the sugared values is translated into graph operations and values, so we end up with a graph that only contains the runtime universe of JIT types. It still contains loads and stores to variables (`prim::Load` and `prim::Store`) and similar things, which the `ConvertToSSA` pass then eliminates (by again building up an environment, but this time either of only types or of only values, as those are separated in TorchScript). This conversion is done in `frontend/convert_to_ssa.cpp`; for loads and stores in a block (so after loops and the like have been dealt with), it is done in `EraseLoadStores::eraseBlockLoadStores`. This will be important to us later.
Our fallback in the compiler frontend
So what does all that mean for our function above?
It shouldn't be much of a surprise - but we can also fire up the debugger to see it - that the error message above stems from `toSugaredValue` noticing that it does not want to deal with our case. We can fix that.
```
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -1018,8 +1072,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
   if (py::isinstance<py::function>(obj)) {
     if (typeString(obj) == "builtin_function_or_method") {
-      throw ErrorReport(loc) << "Python builtin " << py::str(obj)
-                             << " is currently not supported in Torchscript";
+      return bindPythonObjectValue(obj, m, loc);
     }
   }
```
We now have two options: we could try to make an existing `SugaredValue` class do what we want (and likely this would be `PythonValue`). Or - and this is the route we take today, a decision to be revisited later, but to me it looks like the semantics of the result are different enough to want a separate type - we make a new `PythonObjectValue`, which we make a subclass of `SimpleValue`, as we want it to look like what becomes a JIT `Value`.
We define this new `PythonObjectValue` sugared value class in `python_sugared_value.h` (the signatures are given, as they are overrides).
```
--- a/torch/csrc/jit/python/python_sugared_value.h
+++ b/torch/csrc/jit/python/python_sugared_value.h
@@ -345,5 +345,25 @@ struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue {
       size_t n_binders) override;
 };
 
+struct VISIBILITY_HIDDEN PythonObjectValue : public SimpleValue {
+  PythonObjectValue(Value* v) : SimpleValue(v) {}
+
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      Function& m,
+      const std::string& field) override;
+
+  std::string kind() const override {
+    return "computed Python value";
+  }
+
+  std::shared_ptr<SugaredValue> call(
+      const SourceRange& loc,
+      Function& caller,
+      at::ArrayRef<NamedValue> args,
+      at::ArrayRef<NamedValue> kwargs,
+      size_t n_binders) override;
+};
+
 } // namespace jit
 } // namespace torch
```
We introduce a helper function `bindPythonObjectValue` in `python_sugared_value.cpp` that inserts a node of the new kind `prim::PyConstant` binding a Python value (actually it isn't quite as constant as the name suggests, as changes in mutable Python objects will be reflected) and returns the new `PythonObjectValue` sugared value. `prim::PyConstant` mimics `prim::Constant` but is kept apart from it, as the separation of Python and Python-less parts is crucial to PyTorch (this also gives us the `sugared_values` and `python_sugared_values` distinction). Happily, there already is a `PyObject` JIT type, and IValues can be made from Python objects through `toIValue` (as we met in our JIT runtime overview).
```
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -905,6 +905,60 @@
   return std::make_shared<SliceValue>(start, stop, step);
 }
 
+std::shared_ptr<PythonObjectValue> bindPythonObjectValue(
+    py::object obj,
+    Function& m,
+    SourceRange loc) {
+  Node* n = m.graph()->insertNode(m.graph()->create(prim::PyConstant));
+  n->ival_(attr::value, toIValue(obj, PyObjectType::get()));
+  n->setSourceRange(loc);
+  n->output()->setType(PyObjectType::get());
+  return std::make_shared<PythonObjectValue>(n->output());
+}
+
....
```
In contrast to `SimpleValue`, we want to enable calls and attribute lookups (and more might be added in the future), so we implement `call` and `attr` in `python_sugared_value.cpp`. As there already is a `prim::PythonOp` to launch any Python function call, we hook into this for `call`, but in contrast to the conventional use, our `call` will use the first input as the function to call - so we remember that we want to change the implementation of `prim::PythonOp` later. (To see `prim::PythonOp` in action, create a simple function - e.g. multiplying a `Tensor` by 2 - fully type annotate it, and decorate the definition with `@torch.jit.ignore`; this is spelled out after the next diff.) We said that `attr` is similar to Python's `getattr`, and indeed we'll just use a `prim::PythonOp` calling `getattr` as the graph created by `attr`. Both these methods again return `PythonObjectValue`s as sugared values, and the graph `Value`s they create are of type `PyObject`.
```
...
+std::shared_ptr<SugaredValue> PythonObjectValue::attr(
+    const SourceRange& loc,
+    Function& m,
+    const std::string& field) {
+  // using prim::GetAttr would look nicer in the graph, but we would need
+  // to implement it in the interpreter or move to replacing prim::GetAttr
+  // on PythonObjects later as a pass
+  std::string cconv(2, 'd');
+  Value* v_field = insertConstant(*m.graph(), field, loc);
+  py::object getattr = py::module::import("builtins").attr("getattr");
+  Node* n = m.graph()->insertNode(m.graph()->createPythonOp(
+      THPObjectPtr(getattr.release().ptr()), cconv, {}));
+  n->setSourceRange(loc);
+  n->addInput(getValue());
+  n->addInput(v_field);
+  n->addOutput()->setType(PyObjectType::get());
+  return std::make_shared<PythonObjectValue>(n->output());
+}
+
+std::shared_ptr<SugaredValue> PythonObjectValue::call(
+    const SourceRange& loc,
+    Function& m,
+    at::ArrayRef<NamedValue> args,
+    at::ArrayRef<NamedValue> kwargs,
+    size_t /*n_binders*/) {
+  auto inputs = toValues(*m.graph(), args);
+  std::string cconv(inputs.size(), 'd');
+  if (!kwargs.empty()) {
+    throw ErrorReport(loc) << "KWARGS currently not supported";
+  }
+  Node* new_node =
+      m.graph()->insertNode(m.graph()->createPythonOp({}, cconv, {}));
+
+  new_node->setSourceRange(loc);
+  new_node->addInput(getValue());
+  for (auto& i : inputs)
+    new_node->addInput(i);
+
+  Value* output = new_node->addOutput()->setType(PyObjectType::get());
+  return std::make_shared<PythonObjectValue>(output);
+}
+
 std::shared_ptr<SugaredValue> toSugaredValue(
     py::object obj,
     Function& m,
```
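Here is the footnote's `prim::PythonOp` experiment spelled out - this works on stock PyTorch, since `@torch.jit.ignore` is precisely the existing mechanism for keeping a call in Python:

```python
import torch

@torch.jit.ignore
def double_it(x: torch.Tensor) -> torch.Tensor:
    return x * 2

@torch.jit.script
def fn(x: torch.Tensor) -> torch.Tensor:
    return double_it(x)  # stays a Python call inside the graph

print(fn.graph)  # the call appears as a PythonOp (printed with a ^ prefix)
```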
In order to print graphs with `prim::PyConstant` nodes, which have a PyObject `IValue` as the `value` attribute (compare `prim::Constant` in graphs), we need to add a case to `IValue::repr` in `aten/src/ATen/core/ivalue.cpp`. `prim::PyConstant` itself is declared - like the `prim::CastFromPython` that we will need in a bit - in `aten/src/ATen/core/interned_strings.h`.
```
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -66,6 +66,8 @@ namespace c10 {
   _(prim, AutogradAllNonZero)        \
   _(prim, AutogradAllZero)           \
   _(prim, Starred)                   \
+  _(prim, PyConstant)                \
+  _(prim, CastFromPython)            \
   _(prim, TupleConstruct)            \
   _(prim, TupleUnpack)               \
   _(prim, TupleIndex)                \
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -532,6 +532,9 @@ std::ostream& IValue::repr(
       return out << enum_holder->qualifiedClassName() << "."
                  << enum_holder->name();
     }
+    case IValue::Tag::PyObject: {
+      return out << "<python object>";
+    }
     case IValue::Tag::Object: {
       TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind(),
           ". Perhaps you've frozen a module with custom classes?");
     }
```
Rising forward?
It seems dubious that rising forward might be an antonym of falling back, but we need a reversal.
Above we said that operations on `PyObject` inputs yield `PyObject` results. If we do not want to basically use the JIT as a particularly indirect line-by-line Python interpreter, we also need a plan to get our beloved JIT-typed objects - such as `Tensor`s, `int`s, `float`s, and `str`s - back from these `PyObject`s.
At the technical level, we introduce a second new operator, the `prim::CastFromPython` mentioned above. It takes a `types` argument (slightly misnamed, because it is only one type) with a JIT `Type` attribute. When executed, it casts the Python object to the desired type, just as passing it to a JIT function called from Python would. On error it raises a `ValueError` at runtime.
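The casting behaviour we want to mirror is the one at today's Python-to-TorchScript boundary, which we can observe on stock PyTorch (error wording may vary by version):

```python
import torch

@torch.jit.script
def takes_str(x: str) -> str:
    return x

print(takes_str("fine"))  # a Python str converts to the JIT str type

try:
    takes_str(42)  # the wrong type is rejected at the boundary
except RuntimeError as e:
    print("rejected:", str(e).splitlines()[0])
```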
But how would we do this in our program? We use type annotations on the variable assignment to express our desire to have a certain type. In our example we might say that the result is a `str`:
```python
import torch

@torch.jit.script
def fn(x : str):
    res: str = open(x).read()
    return res
```
If we do this, we get an error (the program before would run, but `PyObject` would show up in the schema - the PyTorch JIT signature for a function - as the return type), because the tree visiting in `ir_emitter.cpp` dispatches the assignment to `Environment::setSugaredVar`, which checks whether the annotated type matches the assigned value. We tell it not to throw an exception when the assigned type is `PyObject` and add a `check_type` argument to `Environment::insertStore` which, if `true`, causes `insertStore` to add a `types` attribute with the desired type to the `prim::Store` node it creates and to set the input `Value`'s type to `PyObject`.
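The type check in `Environment::setSugaredVar` that we are relaxing here can be seen on stock PyTorch, where a mismatched annotation is rejected at scripting time:

```python
import torch

try:
    @torch.jit.script
    def g() -> int:
        res: int = "not an int"  # annotation and value disagree
        return res
except RuntimeError as e:
    print(str(e).splitlines()[0])  # "Variable 'res' is annotated with type int ..."
```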
```
--- a/torch/csrc/jit/frontend/ir_emitter.cpp
+++ b/torch/csrc/jit/frontend/ir_emitter.cpp
@@ -280,9 +280,14 @@ struct Environment {
       const std::string& name,
       const SourceRange& loc,
       Value* v,
-      TypePtr type) {
+      TypePtr type,
+      bool check_type) {
     auto g = b->owningGraph();
-    g->insertNode(g->createStore(name, v))->setSourceRange(loc);
+    auto n = g->insertNode(g->createStore(name, v))->setSourceRange(loc);
+    if (check_type) {
+      v->setType(PyObjectType::get());
+      n->ty_(attr::types, type);
+    }
     type_table[name] = std::move(type);
   }
 
@@ -399,14 +404,20 @@ struct Environment {
       if (!annotated_type) {
         annotated_type = as_simple_value->type();
       }
-      if (!as_simple_value->type()->isSubtypeOf(annotated_type)) {
+      if (as_simple_value->type() == PyObjectType::get()) {
+      } else if (!as_simple_value->type()->isSubtypeOf(annotated_type)) {
         throw ErrorReport(loc)
             << "Variable '" << name << "' is annotated with type "
             << annotated_type->repr_str()
             << " but is being assigned to a value of type "
             << as_simple_value->type()->repr_str();
       }
-      insertStore(name, loc, as_simple_value, annotated_type);
+      insertStore(
+          name,
+          loc,
+          as_simple_value,
+          annotated_type,
+          as_simple_value->type() == PyObjectType::get());
     } else {
       value_table[name] = std::move(value);
     }
```
Then, in `EraseLoadStores::eraseBlockLoadStores` in `convert_to_ssa.cpp`, we insert the `prim::CastFromPython` node if the `prim::Store` we process has the `types` attribute set.
```
--- a/torch/csrc/jit/frontend/convert_to_ssa.cpp
+++ b/torch/csrc/jit/frontend/convert_to_ssa.cpp
@@ -194,7 +194,18 @@ struct EraseLoadStores {
       switch (n->kind()) {
         case prim::Store: {
-          environment_stack->setVar(n->s(attr::name), n->input());
+          auto v = n->input();
+          if (n->hasAttribute(attr::types)) {
+            auto ty = n->ty(attr::types);
+            auto ta = n->owningGraph()
+                          ->create(prim::CastFromPython)
+                          ->setSourceRange(n->sourceRange())
+                          ->ty_(attr::types, ty)
+                          ->insertAfter(n);
+            ta->addInput(v);
+            v = ta->output()->setType(ty);
+          }
+          environment_stack->setVar(n->s(attr::name), v);
           n->destroy();
         } break;
         case prim::Load: {
```
We are not quite done here, because we also want to enable casting by return type annotations, so
```python
import torch

@torch.jit.script
def fn(x : str) -> str:
    return open(x).read()
```
does the right thing. For this, we go back to `ir_emitter.cpp`, find `emitReturn`, and insert a `prim::CastFromPython` node if we find a type annotation and a return value of type `PyObject`.
```
--- a/torch/csrc/jit/frontend/ir_emitter.cpp
+++ b/torch/csrc/jit/frontend/ir_emitter.cpp
@@ -1007,7 +1018,16 @@ struct to_ir::emitReturn
         /*allow_conversions=*/true);
     }
 
-    if (!result->type()->isSubtypeOf(result_type)) {
+    if ((result->type() == PyObjectType::get()) &&
+        (result_type != AnyType::get())) {
+      auto n = graph->insertNode(graph->create(prim::CastFromPython)
+                                     ->setSourceRange(stmt.range())
+                                     ->ty_(attr::types, result_type));
+      n->addInput(result);
+      result = n->output()->setType(result_type);
+    }
+
+    if (!(result->type()->isSubtypeOf(result_type))) {
       throw ErrorReport(stmt.range())
           << "Return value was annotated as having type "
           << result_type->repr_str() << " but is actually of type "
```
There might still be cases missing when you want to work with tuples, but that is how it is.
So yay, the JIT frontend can deal with our fallback, and we can move on.
The middle ages, er, layers
As will be no surprise to readers of this blog, there are a number of analysis and optimization passes the JIT runs before actually executing code. Most of them are happy to ignore our two new node kinds `prim::PyConstant` and `prim::CastFromPython`, but there are two exceptions: the alias analysis and the printing mechanism want to know more. I ran `gdb` again to find the stack traces for the exceptions being raised, and so we need to modify two and a half places:
- In `runtime/operator.cpp`, we need to add our new operators to the lists in `printerHasSpecialCaseFor` and `aliasAnalysisHasSpecialCaseFor`.
```
--- a/torch/csrc/jit/runtime/operator.cpp
+++ b/torch/csrc/jit/runtime/operator.cpp
@@ -214,7 +214,7 @@ bool printerHasSpecialCaseFor(Symbol sym) {
       prim::CreateObject, prim::GetAttr, prim::SetAttr, prim::CallFunction,
       prim::isinstance, prim::unchecked_cast, prim::tolist, prim::rpc_async,
       prim::rpc_sync,
-      prim::rpc_remote};
+      prim::rpc_remote, prim::PyConstant, prim::CastFromPython};
 
   // WARNING: by adding a value to this set, you are asserting that your
   // primitive is only ever added during optimization and does not need
@@ -324,6 +324,8 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
       prim::Enter,
       prim::Exit,
       prim::FallbackGraph,
+      prim::CastFromPython,
+      prim::PyConstant,
   };
 
   // Operators that should not be used by alias analysis
```
- For the actual alias analysis, we need to specify the aliasing relations in `AliasDb::analyzeImpl`. I chose to add `prim::PyConstant` to the creator case and `prim::CastFromPython` to link input and output (though I don't know how well that works with wherever the `PyObject` might have come from; `wildcard` might be an alternative).
```
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -498,6 +498,7 @@ void AliasDb::analyzeImpl(Node* node) {
     case prim::Closure:
     case prim::CreateObject:
     case prim::tolist:
+    case prim::PyConstant:
       return analyzeCreator(node);
     case prim::TupleConstruct:
     case prim::DictConstruct:
@@ -516,6 +517,7 @@
         }
       }
       return analyzeExtractor(node);
+    case prim::CastFromPython:
     case prim::unchecked_cast:
       return makePointerTo(node->output(), node->input());
     case prim::ConstantChunk:
```
And this is all the purely administrative work we needed. I should say that only when I started investigating fallbacks did I notice that most of the infrastructure (`Type`s, `Value`s, `IValue`s, ...) is already in place.
Extending and implementing the operators
Actually, there is a small bit of administrative stuff left. The existing `prim::PythonOp` that we wanted to extend to take the function as its first input, instead of a fixed one bound directly from Python, is implemented via a `ConcretePythonOp` class in `python/python_ir.cpp`.
This has two places where we need to deal with the possibility that the Python object that would be the function is "empty" (i.e. a null object), so we add two quick if cases: in `name` (for printing, just emitting a default `<PyObjectCall>` for now) and in the copying (setting the target Python object to null when the source is).
```
--- a/torch/csrc/jit/python/python_ir.cpp
+++ b/torch/csrc/jit/python/python_ir.cpp
@@ -129,7 +129,9 @@ Node* findNode(Block* block, Symbol kind, bool recurse = true) {
 
 std::string ConcretePythonOp::name() const {
   pybind11::gil_scoped_acquire gil;
-  if (auto autograd = autogradFunction()) {
+  if (!pyobj) {
+    return "<PyObjectCall>";
+  } else if (auto autograd = autogradFunction()) {
     return getPythonName(autograd->get());
   } else {
     return getPythonName(pyobj.get());
@@ -140,8 +142,12 @@ void ConcretePythonOp::cloneFrom(Node* other_) {
   Node::cloneFrom(other_);
   auto other = other_->cast<ConcretePythonOp>();
   this->cconv = other->cconv;
-  Py_INCREF(other->pyobj.get());
-  this->pyobj = THPObjectPtr(other->pyobj.get());
+  if (other->pyobj) {
+    Py_INCREF(other->pyobj.get());
+    this->pyobj = THPObjectPtr(other->pyobj.get());
+  } else {
+    this->pyobj = {};
+  }
   for (auto& sa : other->scalar_args) {
     Py_INCREF(sa.get());
     this->scalar_args.emplace_back(sa.get());
```
But now, all that is left is to extend the implementation of `prim::PythonOp` and to write the new `prim::PyConstant` and `prim::CastFromPython`.
The functions implementing them are defined and registered as JIT operators in `python/python_interpreter.cpp`. They take a `Node` as an argument and return an `Operation`, a lambda that takes a `Stack` of `IValue`s as input (actually, there might be additional values below the ones for the operator) and returns the stack with the inputs replaced by the outputs.
Adapting `prim::PythonOp` is straightforward, thanks to the excellent PyBind11 tooling. On creation, we just check the truth value of the `pyobj` member of the `ConcretePythonOp` instance behind the node to decide whether the local `func` object should be initialized from it or left empty. On execution, we check if `func` is empty, and if so, we pop the top input from the stack and cast it to a Python object to get the function to call.
```
--- a/torch/csrc/jit/python/python_interpreter.cpp
+++ b/torch/csrc/jit/python/python_interpreter.cpp
@@ -12,6 +12,7 @@
 #include <torch/csrc/jit/runtime/graph_executor.h>
 #include <torch/csrc/jit/runtime/operator.h>
 
+#include <sstream>
 #include <typeinfo>
 
 #include <pybind11/pybind11.h>
@@ -31,11 +32,14 @@ namespace {
 Operation createPythonOperation(const Node* op_) {
   pybind11::gil_scoped_acquire gil;
   const ConcretePythonOp* op = static_cast<const ConcretePythonOp*>(op_);
-  const py::function func = py::reinterpret_borrow<const py::function>(
-      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-      py::handle(const_cast<ConcretePythonOp*>(op)->pyobj.get()));
+  const py::function func =
+      (op->pyobj
+           ? py::reinterpret_borrow<const py::function>(
+                 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+                 py::handle(const_cast<ConcretePythonOp*>(op)->pyobj.get()))
+           : py::function{});
 
-  size_t num_inputs = 0;
+  size_t num_inputs = op->pyobj ? 0 : 1;
   for (auto arg_type : op->cconv) {
     if (arg_type == 'd')
       num_inputs++;
@@ -49,6 +53,12 @@
     size_t i = 0;
     size_t next_scalar = 0;
     size_t next_tensor = 0;
+    py::function func_from_val;
+    if (!func) {
+      func_from_val =
+          toPyObject(std::move(peek(stack, next_tensor, num_inputs)));
+      next_tensor++;
+    }
     for (auto arg_type : op->cconv) {
       if (arg_type == 'c') {
         py_inputs[i] = py::reinterpret_borrow<const py::object>(
@@ -65,7 +75,7 @@
     }
     drop(stack, num_inputs);
     try {
-      py::object py_output(func(*py_inputs));
+      py::object py_output((func ? func : func_from_val)(*py_inputs));
       stack->push_back(returnToIValue(op->output()->type(), py_output));
     } catch (py::error_already_set& e) {
       throw std::runtime_error(e.what());
```
The `PyConstant` operation is even simpler: just put the `IValue` saved in the `value` attribute on the stack. We need to take the GIL because copying the `IValue` here likely needs to increase the Python object's reference count. I think. I might want to double check.
Similarly, the cast operation implementing `CastFromPython` takes a Python `IValue` from the top of the stack, casts it with the utility functions also used when invoking `ScriptFunction`s from Python, and returns the result. On error, the cast throws an exception, but we catch it to translate it into a more Pythonic `ValueError`.
```
--- a/torch/csrc/jit/python/python_interpreter.cpp
+++ b/torch/csrc/jit/python/python_interpreter.cpp
@@ -77,10 +87,47 @@
   return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
 }
 
-RegisterOperators reg({Operator(
-    prim::PythonOp,
-    createPythonOperation,
-    aliasAnalysisIsSpecialCase())});
+Operation createPyConstantOperation(const Node* node) {
+  pybind11::gil_scoped_acquire gil;
+  auto val = node->ival(attr::value);
+  return [=](Stack* stack) {
+    pybind11::gil_scoped_acquire gil;
+    stack->push_back(val);
+  };
+}
+
+Operation createCastFromPythonOperation(const Node* node) {
+  TypePtr typ = node->ty(attr::types);
+  return [=](Stack* stack) {
+    pybind11::gil_scoped_acquire gil;
+
+    py::object pyobj = toPyObject(std::move(pop(stack)));
+    try {
+      stack->push_back(toIValue(pyobj, typ));
+    } catch (py::cast_error& e) {
+      std::stringstream msg;
+      py::object pytype =
+          py::module::import("builtins").attr("type")(pyobj).attr("__name__");
+      msg << "ValueError: cannot cast Python object of type " << pytype
+          << " to TorchScript type " << *typ;
+      throw std::runtime_error(msg.str());
+    }
+  };
+}
+
+RegisterOperators reg(
+    {Operator(
+         prim::PythonOp,
+         createPythonOperation,
+         aliasAnalysisIsSpecialCase()),
+     Operator(
+         prim::PyConstant,
+         createPyConstantOperation,
+         aliasAnalysisIsSpecialCase()),
+     Operator(
+         prim::CastFromPython,
+         createCastFromPythonOperation,
+         aliasAnalysisIsSpecialCase())});
 
 } // namespace
 } // namespace jit
```
And that is it!
Now our little TorchScript function runs:
```
$ PYTHONPATH=build/lib.linux-x86_64-3.8/ python3.8 ../scripts/fallback.py
graph(%x.1 : str):
  %4 : str = prim::Constant[value="read"]() # ../scripts/fallback.py:5:11
  %1 : PyObject = prim::PyConstant[value=<python object>]() # ../scripts/fallback.py:5:11
  %3 : PyObject = ^<PyObjectCall>()(%1, %x.1) # ../scripts/fallback.py:5:11
  %5 : PyObject = ^getattr()(%3, %4) # ../scripts/fallback.py:5:11
  %6 : PyObject = ^<PyObjectCall>()(%5) # ../scripts/fallback.py:5:11
  %7 : str = prim::CastFromPython[types=str](%6) # ../scripts/fallback.py:5:4
  return (%7)
#!/usr/bin/python
import torch

@torch.jit.script
def fn(x : str) -> str:
    return open(x).read()

print(fn.graph)
print(fn(__file__))
```

(The second half of the output is the script's own source, since `fn(__file__)` reads the very file it lives in.)
Our patches clock in at a little under 200 lines. It appears that, at least in this case, a relatively simple fallback is feasible. We should not kid ourselves, though: there will be many more cases to handle in the frontend around the `SugaredValue`s, and we need tests, too!
Conclusion
Our exploration of a `SugaredValue`-based fallback mechanism worked surprisingly (or embarrassingly - I might have forgotten something) well. It remains to be seen how much of the problem it solves. If you want to play with it without copy-pasting the diff snippets, you can also use my ScriptTorch git branch. Some caveats and open ends:
- We might have cases where we discover during partial de-sugaring (i.e. after attribute lookups) that we cannot complete it and would have needed to invoke the fallback mechanism earlier (i.e. before we knew). This will likely be a tough one to analyse.
- We did not look at syntax constructs currently not handled by the JIT.
- We need to work on "standard functions" (the operators and magic functions) working on `PyObject`s. These would need to be re-routed to `PythonOp`s.
I hope you enjoyed this little expedition into the internals of the JIT, with a view towards implementing a fallback. The JIT already provided us with most of the infrastructure, so this was easy.
But here is your part: in addition to helping out with the code, what is your favourite model bit that you cannot yet script? I look forward to hearing from you at tv@lernapparat.de.