Traceable and Differentiable Extensions with PyTorch
Editorial note: This post - from the end of June 2019 - described how to do manually what is now, three months later, offered painlessly by the torch::autograd::Function mechanism. In fact, you can see how we used autograd::Function for the function discussed below in the TorchVision source code. Of course, what autograd::Function does is provide a nice wrapper for the things I described below.
Three of the best-liked features of PyTorch are the extensible autograd mechanism, the ability to extend PyTorch efficiently with C++, and the tracing/scripting mechanism, the PyTorch JIT. This leads to the natural question: can we have all three at the same time?
In this post, we dive into the autograd internals and come out with a solution. (This is not the @torch.jit.script-decorated autograd.Function I am hacking on, but something that is available today and perhaps not even breaking news - though by my estimate, only 3-5 people are aware that it works now. We will have an exciting followup.)
Let us see. While C++ extensions are arguably the most popular way of extending PyTorch, there also are C++ custom ops that extend TorchScript, i.e. are traceable/scriptable. The price you pay is that you are limited to the data types that TorchScript provides, but with Tensor, float, int, string, and even lists and dicts of them, you are pretty much set. All you need to do is use RegisterOperators in your module, i.e.
static auto registry = torch::RegisterOperators().op("mylib::something", &do_something);
You can do this in addition to the C++ extension (PyBind11) bindings for your functions, or instead of them. If you remove the extension bindings, you become independent of Python and can load your library in C++ programs, too, but you might have to think about how to load it from Python.
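To make this concrete, here is a minimal sketch of what such a custom op file might look like - mylib::something and do_something are the placeholder names from the snippet above, and the library name in the loading comment is made up:

#include <torch/script.h>

// a toy op: take a Tensor, return a Tensor of the same shape
torch::Tensor do_something(torch::Tensor input) {
  return input * 2.0;
}

// registration happens at static initialization time; once the shared
// library is loaded (e.g. with torch.ops.load_library("libmylib.so")
// from Python), the op is callable as torch.ops.mylib.something
static auto registry =
    torch::RegisterOperators().op("mylib::something", &do_something);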
We have C++ and scripting. But now comes the difficulty: how do we get differentiability?
The standard way of doing this - e.g. the one suggested in the C++ extension tutorial - is to implement forward and backward functions and then wrap them using a torch.autograd.Function in Python. But here is the problem: those are not scriptable.
So if a C++-implemented operator wrapped in an autograd.Function is not scriptable, maybe we need an "autograd.Function" wrapped in a C++ operator. The only problem is, we do not have (or want) Python inside C++.
So how do Functions work at the C++ level? If you've read my selective excursion into PyTorch internals, I have not gone into any detail there. Edward Yang's great blog post and slides do have it, but alas, he skipped the following seven slides in his NY meetup talk and is also going to delay the writeup for them - you'll have to wait for the sequel for some text. But we had a glimpse at the Function slide, and that is the thing we need. I cannot recommend Edward's blog post highly enough; it is a very gentle introduction, yet it gets to the point where even I learned a lot from reading it. Thank you, Edward!
If you have not worked with Python's torch.autograd.Function or want to refresh your memory, I recommend checking out the documentation as well as the Autograd mechanics chapter of the PyTorch documentation. As this Python side of autograd is well known, we will use it as a reference when diving into the internals.
I should add a rather large caveat here: just like PyTorch 0.4 brought the great Variable/Tensor merge on the Python side, Will Feng is currently (June 2019, with master showing 1.2 as the next version) working on merging them in C++, too. This has already led to some breaking changes - for me it was .data() going away in the code we are going to develop - just as you should not use .data in Python these days.
Dissecting how PyTorch builds the computational graph
So let us look at how PyTorch's own functions manage to work with autograd. We pick the simplest possible function - a (pointwise, unary) function that takes one tensor and produces one of the same shape - say atan. After you compile PyTorch, there is a file torch/csrc/autograd/generated/VariableTypeEverything.cpp. (I use rgrep a lot to find my way through the PyTorch code. After forming some basic opinion about where what I am looking for might be - say torch/csrc/autograd - I use rgrep '\batan\b' torch/csrc/autograd. VariableTypeEverything.cpp has been split up into pieces for compilation, so we also get hits in those pieces.) Note that we want atan, not the inplace atan_ nor the atan_out function that writes to a given output tensor.
Here is how it looks for me:
Tensor VariableType::atan(const Tensor & self) const {
  RECORD_FUNCTION("atan", std::vector<c10::IValue>({self}), Function::peek_at_next_sequence_nr());
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<AtanBackward> grad_fn;
  if (compute_requires_grad( self )) {
    grad_fn = std::shared_ptr<AtanBackward>(new AtanBackward(), deleteFunction);
    grad_fn->set_next_edges(collect_next_edges( self ));
    grad_fn->self_ = SavedVariable(self, false);
  }
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::atan");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    tracer_state->graph->insertNode(node);
    jit::tracer::setTracingState(nullptr);
  }
  #ifndef NDEBUG
  c10::optional<Storage> self__storage_saved =
    self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
  c10::intrusive_ptr<TensorImpl> self__impl_saved;
  if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
  #endif
  auto tmp = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return baseType->atan(self_);
  })();
  auto result = as_variable(tmp);
  #ifndef NDEBUG
  if (self__storage_saved.has_value())
    AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
  if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
  #endif
  if (grad_fn) {
    set_history(flatten_tensor_args( result ), grad_fn);
  }
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}
Uh. That is a lot to digest! We need to focus, so we will first ignore everything that sits in those #ifndef NDEBUG blocks or looks jit::tracer-related (we want to use this in a C++ custom op, so we already get tracing from that). Let us delete those lines. We are left with:
Tensor VariableType::atan(const Tensor & self) const {
  RECORD_FUNCTION("atan", std::vector<c10::IValue>({self}), Function::peek_at_next_sequence_nr());
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<AtanBackward> grad_fn;
  if (compute_requires_grad( self )) {
    grad_fn = std::shared_ptr<AtanBackward>(new AtanBackward(), deleteFunction);
    grad_fn->set_next_edges(collect_next_edges( self ));
    grad_fn->self_ = SavedVariable(self, false);
  }
  auto tmp = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return baseType->atan(self_);
  })();
  auto result = as_variable(tmp);
  if (grad_fn) {
    set_history(flatten_tensor_args( result ), grad_fn);
  }
  return result;
}
Well, still large, but much more manageable! Let us go through it in pieces and see what we find. There is some very good news right away. Looking at the signature
Tensor VariableType::atan(const Tensor & self) const {
we see that while we won't be part of a class like VariableType, we can perfectly relate to

Tensor atan(const Tensor & self)

- we would use something like that in our extension function or custom op implementation, too, with Tensor being torch::Tensor. Awesome!
The next line

RECORD_FUNCTION("atan", std::vector<c10::IValue>({self}), Function::peek_at_next_sequence_nr());

is there for the purposes of the PyTorch profiler. (Did you know that exists? It is great! You might have guessed: I used rgrep RECORD_FUNCTION torch/csrc/ to find the definition.)
The next line is unpack, and we look at it together with as_variable near the end:
auto& self_ = unpack(self, "self", 0);
...
auto result = as_variable(tmp);
These bits are the effect of still having both Tensor and Variable - the incoming Tensor self is a variable in disguise, and we unpack it to a pure (non-Variable) Tensor; at the end, we wrap the result tmp into a variable-in-disguise Tensor result. Unsurprisingly, unpack is not a very unique name in the PyTorch codebase; here it is VariableType::unpack. This is a bit that may well change in the near future with Will's work mentioned above. For now we need these functions, and we will come back to them later.
In the middle we have the actual calculation:
auto tmp = ([&]() {
  at::AutoNonVariableTypeMode non_var_type_mode(true);
  return baseType->atan(self_);
})();
This is using lambdas in a creative way in order to not spell out the type of tmp and still get scoping. What happens is that at::AutoNonVariableTypeMode is a guard variable that makes the remainder of its scope behave like with torch.no_grad():. Indeed, it is a functional equivalent of the more familiar torch::NoGradGuard and will be merged into that in due course. (Just like with statements are often handy in Python for backend flags, default types, and autograd mode, PyTorch C++ uses such guards both for user-facing things and internally.) The line return baseType->atan(self_) is the actual calculation. Of course, we would do our own instead.
What remains now are the actual graph-recording bits:
std::shared_ptr<AtanBackward> grad_fn;
if (compute_requires_grad( self )) {
  grad_fn = std::shared_ptr<AtanBackward>(new AtanBackward(), deleteFunction);
  grad_fn->set_next_edges(collect_next_edges( self ));
  grad_fn->self_ = SavedVariable(self, false);
}
...
if (grad_fn) {
  set_history(flatten_tensor_args( result ), grad_fn);
}
Note carefully that while the actual calculation worked with the unpacked objects self_ and tmp, these bits only work on the wrapped objects self and result. (Some of you asked for a talk or video on this topic; we will see how good I am at keeping self_ and self apart when speaking.)
By checking compute_requires_grad( self ) we only record a graph when something needs a gradient (i.e. requires_grad is set and we're not in a with torch.no_grad() block). The function takes any number of tensors as arguments. The second if will be true precisely when the first was, because that is where we initialize the shared pointer.
The AtanBackward object, created in the first line of the if block (using a shared pointer - thanks for not needing raw pointers!), is what you also see in Python when you check out x.grad_fn for some calculation result. It is a node in the graph recording the calculation. AtanBackward is a subclass of torch::autograd::Function, which does the required bookkeeping for calling backward later. (It may be surprising that the superclass of all Backward classes is called Function. In a way, this is the PyTorch 0.1.2 way: before Python-level new-style torch.autograd.Functions were introduced, the information that we now save in the ctx context was stored in instances of the torch.autograd.Function object itself.) After instantiating it (the deleteFunction deleter was not declared as part of the public API, which we fixed while writing this, so you need a very recent nightly/master for things to work), we hook it up to the existing graph by calling set_next_edges with the edges from collect_next_edges. collect_next_edges again takes a variable number of arguments, so you would hook up all your inputs' subgraphs here. (Alas, I think collect_next_edges might not be public API either; I missed this in the first version of the draft because I was working off a feature branch for other work. Hopefully, we'll get an official API soon.)
Then grad_fn->self_ (which has nothing to do with self_, other than that both are related to self - the former is more wrapped, the latter unwrapped...) is set to the SavedVariable-wrapped self. This is similar to Python torch.autograd.Functions calling ctx.save_for_backward, which wraps all of its arguments into SavedVariables. The purpose of SavedVariable is to sanity-check that nothing bad happened to our tensor in between - the most (in)famous non-sane thing caught and raised by this check being the one of the variables needed for gradient computation has been modified by an inplace operation exception. (This is also why you should save inputs and outputs using ctx.save_for_backward but do not necessarily need to do so for intermediate results that have no references outside your torch.autograd.Function - if no one knows your intermediates, no one can drive you crazy by modifying them without you noticing.) self_ is just an arbitrary member of AtanBackward; we will define our own subclass of torch::autograd::Function with our own members later.
Finally, we need to attach our graph to our results, which is done in the second if block by
set_history(flatten_tensor_args( result ), grad_fn);
As autograd needs to know how many and which outputs we have, we pass them all in one go: flatten_tensor_args takes the results (varargs again) and hands a list to set_history, which then connects grad_fn to each of them.
Phew. So this is what needs to happen in our forward - but how does AtanBackward work?
The backward
AtanBackward is defined in torch/csrc/autograd/generated/Functions.h, with the implementation details in Functions.cpp.

Let us first look at the declaration in Functions.h:
struct TORCH_API AtanBackward : public TraceableFunction {
  using TraceableFunction::TraceableFunction;
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "AtanBackward"; }
  void release_variables() override {
    self_.reset_data();
    self_.reset_grad_function();
  }
  SavedVariable self_;
};
We will conveniently ignore that it is a TraceableFunction subclass and pretend that it is only a Function subclass. It defines overrides for three methods: apply, which is the actual backward computation; name, returning the name; and release_variables, which cleans up after the SavedVariable members by calling their reset_data and reset_grad_function functions. The latter two are straightforward to adapt when we define our own: return our own name from name, and just do the right thing for all our SavedVariable members (we see that we only have self_ here).
This leaves the apply method, defined in Functions.cpp, which needs a closer look:
variable_list AtanBackward::apply(variable_list&& grads) {
  IndexRangeGenerator gen;
  auto self_ix = gen.range(1);
  variable_list grad_inputs(gen.size());
  auto& grad = grads[0];
  auto self = self_.unpack();
  if (should_compute_output({ self_ix })) {
    auto grad_result = grad / (self * self + 1);
    copy_range(grad_inputs, self_ix, grad_result);
  }
  return grad_inputs;
}
Looking at the signature

variable_list AtanBackward::apply(variable_list&& grads) {

we get a variable list of gradients (grad_outs, if you wish) and need to produce a variable list of input gradients (grad_ins). So far so good. A variable_list is just a std::vector<Variable>, and we can ignore the Variable vs. Tensor distinction for a moment.
The IndexRangeGenerator business looks daunting. Remember that we called collect_next_edges with our inputs? This now helps us map the list of inbound edges (and grad_ins in the backward results) back to the various arguments. For Tensor arguments this is fairly boring; it would get more interesting for lists of tensors and such. After we have got all those index ranges, we know how many inputs we had in total and can allocate grad_inputs to the right size (note that by default, the gradients are set to undefined Tensors here; these map to None in Python).
Similar to the inputs, we used flatten_tensor_args on the outputs, so we get to collect the grads. In
auto& grad = grads[0];
grad is the gradient of our only output. We only take a reference, not a copy (for efficiency reasons).
Next we unpack our SavedVariables:
auto self = self_.unpack();
Note that this SavedVariable::unpack is strictly unrelated to the unpack function we used in the forward. It does the aforementioned sanity checks.
Next we go through our inputs (we only have one) and, if should_compute_output says we should, compute the gradient of the input and copy it into the grad_inputs vector. If you look above the function, you see that there is another variant of should_compute_output that foregoes the IndexRange business and takes just a simple integer index.
The actual gradient computation is
auto grad_result = grad / (self * self + 1);
Now, this code is generated from a template by clever tools in the PyTorch tools/autograd/ directory during compilation. If we wrote it manually, we could rip out the IndexRangeGenerator and write our AtanBackward::apply as
variable_list AtanBackward::apply(variable_list&& grads) {
  variable_list grad_inputs(1);
  auto& grad = grads[0];
  auto self = self_.unpack();
  if (should_compute_output(0)) {
    grad_inputs[0] = grad / (self * self + 1);
  }
  return grad_inputs;
}
which has a bit less baggage.
Rolling our own
So is our new knowledge sufficient to conquer the quest of differentiable, traceable C++ operators?
Let us try our hand at a practical example. My fellow PyTorch developer and online friend Francisco Massa introduced a C++ extension for Mask R-CNN functions in TorchVision 0.3, and my extensive discussions with him around the topic have been part of the inspiration for this blog post. We pick one of the functions, say roi_align, and try.
As recommended by the official extension tutorial, TorchVision wraps the forward and backward C++ functions into a torch.autograd.Function in torchvision/ops/roi_align.py. I will not copy the full thing here, but it takes two Tensor inputs, input and rois, and the parameters output_size (two ints, pool_h and pool_w), a float spatial_scale, and another int, sampling_ratio.
It computes a gradient only for input (rois might be integral, but I have not checked). For this, the backward needs the parameters and the rois Tensor, but only the shape of the input tensor, not the entire input tensor. It saves those in the context. The backward does not need the output. Other than these administrative things, the function merely calls roi_align_forward and roi_align_backward from the C++ extension.
So let us create a C++ custom operator equivalent of the function. (Recall that we prefer the custom op over an extension because it is traceable. We could also just wrap the op's function in the extension interface and get a differentiable but not traceable extension function.)
For simplicity, we will put our changes directly into torchvision, in the main module file vision.cpp. (I would not expect that dumping this in there is up to TorchVision standards, but we want a prototype fast...)
First we need all the utility functions mentioned above, so we include the world. (Maybe it would be beneficial to streamline this and bless some more official way of defining this. I hope this blog post can facilitate a discussion here.)
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>

#include "torch/script.h"
We also want these fancy shortcuts Tensor and variable_list, so:
using torch::Tensor;
using torch::autograd::variable_list;
So we declare our Backward object:
struct ROIAlignBackward : public torch::autograd::TraceableFunction {
  variable_list apply(variable_list&& grads) override;
  std::string name() const override {
    return "ROIAlignBackward";
  }
  void release_variables() override {
    rois_.reset_data();
    rois_.reset_grad_function();
  }
  torch::autograd::SavedVariable rois_;
  double spatial_scale;
  int64_t pooled_height;
  int64_t pooled_width;
  int64_t sampling_ratio;
  int64_t batch_size, channels, height, width;
};
Nothing surprising here: what the Python function saved into the ctx context is now declared as fields. The rois_ are a SavedVariable. Note that Python floats and ints map to double and int64_t. While the extension mechanism (via PyBind11) is lenient here (maybe it should not be), this is a necessity for successfully using custom ops and dealing with TorchScript in general.
Our ROIAlignBackward::apply method also mimics the simplified version of AtanBackward:
variable_list ROIAlignBackward::apply(variable_list&& grads) {
  variable_list grad_inputs(1);
  auto& grad = grads[0];
  auto rois = rois_.unpack();
  if (should_compute_output(0)) {
    grad_inputs[0] = ROIAlign_backward(
        grad,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        batch_size,
        channels,
        height,
        width,
        sampling_ratio);
  }
  return grad_inputs;
}
Even though the forward has two tensor inputs (and a bunch of other arguments), we only return a one-element grad_inputs (for input). This is different from how things work in torch.autograd.Functions, but it means that we will only pass input to collect_next_edges in the forward. Just like the Python function, we simply hand off the calculation to the ROIAlign_backward function.
That was not too bad! Let us do the forward. Here, we hit a small stumbling block: VariableType::unpack is a private static method! Looking at it, it only calls VariableType::checked_cast_variable, so we might just substitute that - but alas, that, too, is a private static method. So what does it do?
It checks t.defined() and t.is_variable(), raising an exception if either returns false, and then returns as_variable_ref(t). Happily, that is a free function and available to us. As we are not in the torch::autograd namespace, we need to prefix the functions with the namespace.
With this in mind, we write our function roi_align. It turns out a bit long because of the many inputs we need to add to grad_fn, but it is very straightforward otherwise.
Tensor roi_align(
    const Tensor& input,
    const Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio) {
  // checks from VariableType::unpack
  TORCH_CHECK(input.defined() && input.is_variable(), "invalid argument input");
  TORCH_CHECK(rois.defined() && rois.is_variable(), "invalid argument rois");
  // we might error if rois requires grad...
  auto& input_ = torch::autograd::as_variable_ref(input);
  auto& rois_ = torch::autograd::as_variable_ref(rois);
  std::shared_ptr<ROIAlignBackward> grad_fn;
  if (torch::autograd::compute_requires_grad(input, rois)) {
    grad_fn = std::shared_ptr<ROIAlignBackward>(
        new ROIAlignBackward(), torch::autograd::deleteFunction);
    grad_fn->set_next_edges(
        torch::autograd::collect_next_edges(input)); // note, only input!
    grad_fn->rois_ = torch::autograd::SavedVariable(rois, false);
    // extra bookkeeping
    grad_fn->spatial_scale = spatial_scale;
    grad_fn->pooled_height = pooled_height;
    grad_fn->pooled_width = pooled_width;
    grad_fn->sampling_ratio = sampling_ratio;
    grad_fn->batch_size = input.size(0);
    grad_fn->channels = input.size(1);
    grad_fn->height = input.size(2);
    grad_fn->width = input.size(3);
  }
  auto tmp = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return ROIAlign_forward(
        input_,
        rois_,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio);
  })();
  auto result = torch::autograd::as_variable(tmp);
  if (grad_fn) {
    set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
  }
  return result;
}
Note how we pass only input to collect_next_edges, and how the set of inputs we compute gradients for is separate from the set of tensors we store for the backward.
With that we are all set. All that is left to do is export our operator. (We could just do static auto registry = torch::RegisterOperators().op("torchvision::roi_align", &roi_align); similar to the tutorial examples, but for a widely used library like torchvision, it might be good to have argument names show up in error messages etc.)
static auto registry = torch::RegisterOperators().op(
    "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor",
    &roi_align);
Done - here we have our own differentiable, traceable/scriptable custom operator, torch.ops.torchvision.roi_align!
Even for this little exercise, we include one test case - proper testing would cover much more. In test/test_ops.py we find a test, test_roi_align_gradient_cpu, testing the gradient. We add a test for our shiny new op:
x2 = x.detach().requires_grad_()
y2 = torch.ops.torchvision.roi_align(x2, rois, roi_align.spatial_scale,
                                     roi_align.output_size[0], roi_align.output_size[1],
                                     roi_align.sampling_ratio)
y2.sum().backward()
assert torch.allclose(x2.grad, gt_grad), 'gradient incorrect for RoIAlign CPU in custom op'
This can be run using python3 test/test_ops.py RoIAlignTester.test_roi_align_gradient_cpu after installing torchvision. It works!
Conclusion and Outlook
We inspected PyTorch's autograd mechanism in great detail to uncover how it works in C++. We then used this pattern to build differentiability right into a C++ custom op. (Note that the backward itself will not be differentiable unless you use methods similar to the ones we used here.)
As we saw, getting there had some rough edges, and the Tensor/Variable merge might bring changes that break this (but probably also make it simpler, so that is good). In order to make this ready for mainstream consumption, we should bless some variant as an official way to implement Functions in C++, so that this blog post could become a tutorial on the PyTorch web site.
My code is here, but it is more an example than PR material.
Is this the best way to do differentiability in a JIT-compatible way? It is the one that works today. But would it not be neat to define forward and backward custom ops and just @torch.jit.script our autograd.Function? That would let us tie into the JIT's source-to-source differentiation capabilities and would be much easier for implementors. But it needs quite a bit of hacking in PyTorch. I am very proud to have made a prototype of that, and we will discuss it next time.
PyTorch Training
As you can tell, I like PyTorch internals. I also like to talk and write about them and about how to use PyTorch efficiently. I offer in-house and public workshops for beginner, intermediate, and expert levels. If you are near Munich (say, in Europe) and need PyTorch training, I would love to hear from you! I also do bespoke development.
I hope this blog post is useful to you - I appreciate and read every mail you send to tv@lernapparat.de.