PyTorch best practices
We look at some best practices, but we also try to shed some light on the rationale behind them. Whether this becomes a series or an updated blog post, we will see.
Are you an intermediate PyTorch programmer? Are you following documented best practices? Do you form opinions on which ones to adhere to and which ones can be forgone without trouble?
I will freely admit that I sometimes have a hard time following best practices when the thing they advise against seems to work and I don't fully understand their rationale. So here is a little thing that happened to me.
A story of me doing stuff (Quantization)
After building PyTorch for the Raspberry Pi, I've been looking to do some fun projects with it. Sure enough, I found a model that I wanted to adapt to running on the Pi. I had the thing running soon enough, but it was not as fast as I would like it to be. So I started looking at quantizing it.
Quantization makes any operation stateful (temporarily)
Now if you think of a PyTorch computation as a set of (Tensor) values linked by operations, quantization consists of taking each operation and forming an opinion about the range of values its output Tensors take, in order to approximate the numbers in that range by integers of the quantized element type via an affine transformation. Don't worry if that sounds complicated; the main thing is that each operation now needs to be associated "with an opinion", or more precisely an observer that records the minimum and maximum values seen over some exemplary use of the model. But this means that during quantization, all operations become stateful (more precisely, they become stateful when preparing for quantization and remain so until the quantization is done).
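To make "an opinion" tangible: the default observers record running minimum and maximum values and turn them into the scale and zero point of that affine mapping. A small illustration, using MinMaxObserver directly outside any model (the input values are made up):

import torch

obs = torch.quantization.MinMaxObserver()     # records the min/max of what it sees
obs(torch.tensor([-1.0, 0.5, 3.0]))           # observers are called like modules
print(obs.min_val, obs.max_val)               # tensor(-1.) tensor(3.)
scale, zero_point = obs.calculate_qparams()   # parameters of the affine mapping
print(scale, zero_point)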
I often mention this when I advocate against declaring an activation function once and re-using it at several points in the computation. This is because the observer attached to it would, in general, see different values at the different use sites, so after quantization the shared activation would no longer behave the same at each of them.
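To make that concrete, here is a minimal sketch (layer types and sizes are made up) of the pattern to avoid and the one to prefer when you plan to quantize:

import torch

class SharedActBlock(torch.nn.Module):
    # discouraged for quantization: one activation module used at two sites,
    # so a single observer would have to cover both output distributions
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(8, 8, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(8, 8, 3, padding=1)
        self.act = torch.nn.ReLU()

    def forward(self, x):
        return self.act(self.conv2(self.act(self.conv1(x))))

class SeparateActBlock(torch.nn.Module):
    # preferred: one activation module per use site, so each gets its own observer
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(8, 8, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(8, 8, 3, padding=1)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()

    def forward(self, x):
        return self.act2(self.conv2(self.act1(self.conv1(x))))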
This new stateful nature also applies to simple things like adding tensors, usually just expressed as a + b. For this, PyTorch provides the torch.nn.quantized.FloatFunctional module. It is a usual Module with the twist that instead of using forward in the computation, it has several methods corresponding to basic operations, in our case .add.
So I took the residual module, which looked roughly like this (note how it declares the activations separately, which is a good thing!):
class ResBlock(torch.nn.Module):
    def __init__(self, ...):
        self.conv1 = ...
        self.act1 = ...
        self.conv2 = ...
        self.act2 = ...

    def forward(self, x):
        return self.act2(x + self.conv2(self.act1(self.conv1(x))))
And I added self.add = torch.nn.quantized.FloatFunctional() to __init__ and replaced the x + ... with self.add.add(x, ...). Done!
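Putting the two changes together, the modified block looks roughly like this (the concrete convolution and activation layers here are placeholders of mine, not the original model's):

class ResBlock(torch.nn.Module):
    def __init__(self, channels=64):  # hypothetical signature
        super().__init__()
        self.conv1 = torch.nn.Conv2d(channels, channels, 3, padding=1)
        self.act1 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(channels, channels, 3, padding=1)
        self.act2 = torch.nn.ReLU()
        self.add = torch.nn.quantized.FloatFunctional()  # stateful add for quantization

    def forward(self, x):
        return self.act2(self.add.add(x, self.conv2(self.act1(self.conv1(x)))))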
With the model thus prepared, I could add the quantization itself, which is simple enough following the PyTorch tutorial. At the bottom of the evaluation script, with the model all loaded, set to eval etc., I added the following, restarted the notebook kernel I was working with, and ran everything.
# config
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.backends.quantized.engine = 'qnnpack'
# wrap in quantization of inputs / de-quantization of outputs
model = torch.quantization.QuantWrapper(model)
# insert observers
torch.quantization.prepare(model, inplace=True)
and then later (after running the model a bit to collect observations), I would call torch.quantization.convert(model, inplace=True) to get the quantized model. Easy!
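Here, "running the model a bit" just means feeding a few representative batches through the prepared model so the observers can record value ranges. A minimal sketch (calib_loader is a hypothetical iterable of example inputs):

model.eval()
with torch.no_grad():
    for inp in calib_loader:  # hypothetical calibration data
        model(inp)
# only after this calibration pass:
torch.quantization.convert(model, inplace=True)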
An unexpected error
And now I just had to run through a few batches of input.
preds = model(inp)
But what happened was:
ModuleAttributeError: 'ResBlock' object has no attribute 'add'
Bad!
What went wrong? Maybe I had a typo in ResBlock?
In Jupyter, you can check very easily using ??model.resblock1. But this was all right, no typos.
So this is where the PyTorch best practices come in.
Serialization best practices
The PyTorch documentation has a note on serialization that contains - or consists of - a best practices section. It starts with
There are two main approaches for serializing and restoring a model. The first (recommended) saves and loads only the model parameters:
and then shows how this works using the state_dict() and load_state_dict() methods. The second method is to save and load the entire model.
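In code, the two approaches look roughly like this (TheModel is a stand-in for whatever model class you actually use):

import torch

class TheModel(torch.nn.Module):  # hypothetical model class
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 2)

model = TheModel()

# recommended: save / load only the parameters (the state dict)
torch.save(model.state_dict(), 'model_params.pt')
restored = TheModel()
restored.load_state_dict(torch.load('model_params.pt'))

# discouraged: pickle and unpickle the entire Module object
torch.save(model, 'model_full.pt')
restored_full = torch.load('model_full.pt')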
The note provides the following rationale for preferring serializing parameters only:
However in [the case of saving the model], the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.
It turns out that this is quite an understatement, and even with our very modest modification - hardly a serious refactor - we ran into the problem it alludes to.
What went wrong?
To get to the core of what went wrong, we have to think about what an object is in Python. In a gross oversimplification, it is completely defined by its __dict__ attribute, holding all the ("data") members, and its __class__ attribute, pointing to its type (so e.g. for Module instances this will be Module, and for Module itself - a class - it will be type).
When we call a method, it typically is not in the __dict__ (it could be, if we tried hard), but Python will automatically consult the __class__ to find the methods (or other things it could not find in the __dict__).
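A tiny, PyTorch-free illustration of this lookup:

class Greeter:
    def __init__(self):
        self.name = 'world'   # data member, ends up in the instance __dict__

    def greet(self):          # method, lives on the class
        return 'hello ' + self.name

g = Greeter()
print(g.__dict__)                        # {'name': 'world'}
print(g.__class__)                       # <class '__main__.Greeter'>
print('greet' in g.__dict__)             # False
print('greet' in g.__class__.__dict__)   # True: method lookup goes via the class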
When deserializing the model (and the author of the model I used didn't follow the best practice advice), Python will construct an object by looking up the type for the __class__ and combining it with the deserialized __dict__. But the thing it (rightfully) does not do is call __init__ to set up the object - it should not, not least because things might have been modified between __init__ and serialization, or __init__ might have side-effects we do not want. This means that when we call the module, we are using the new forward but get the __dict__ as prepared by the original author's __init__ and subsequent training, without the new attribute add that our modified __init__ added.
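We can reproduce this failure mode with plain pickle and a made-up class, no PyTorch required:

import pickle

class Block:
    def __init__(self):
        self.weight = 1.0

    def forward(self, x):
        return self.weight * x

blob = pickle.dumps(Block())  # the class is stored by reference, the __dict__ by value

# later, the code is "refactored": __init__ gains an attribute that forward now uses
class Block:
    def __init__(self):
        self.weight = 1.0
        self.bias = 0.0       # new attribute

    def forward(self, x):
        return self.weight * x + self.bias

restored = pickle.loads(blob)  # new class code, but the old __dict__ (no 'bias')
restored.forward(2.0)          # AttributeError: 'Block' object has no attribute 'bias'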
So, in a nutshell, this is why serializing PyTorch modules - or Python objects in general - is dangerous: you very easily end up with something where the data attributes and the code are out of sync.
Maintaining compatibility
An obvious thing here - a drawback if you wish - is that we need to keep track of the setup configuration in addition to the state dictionary. But you can easily serialize that configuration along with the state dict if you want - just stick both into a joint dictionary.
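For example (the configuration keys and the build_model factory here are made up):

checkpoint = {
    'config': {'num_blocks': 4, 'channels': 64},  # hypothetical setup parameters
    'state_dict': model.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pt')

# later, possibly in another project
checkpoint = torch.load('checkpoint.pt')
model = build_model(**checkpoint['config'])       # hypothetical factory function
model.load_state_dict(checkpoint['state_dict'])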
But there are other advantages to not serializing the modules themselves:
The obvious one is that we can work with the state dictionary directly: we can load it without having the Module classes around, and we can inspect and modify it if we changed something important.
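Since the state dict is just an (ordered) dictionary mapping parameter names to tensors, such an inspection or fix-up can be as simple as this (the key names are made up):

sd = torch.load('model_params.pt')
print(list(sd.keys())[:3])  # peek at the parameter names
# e.g. rename a key after a refactor before loading it into the new model
sd['block1.conv1.weight'] = sd.pop('resblock1.conv1.weight')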
The not quite as obvious one is that the implementor or user can customize how modules process the state dict. This works in two ways:
- For users, there are hooks. Well, they're not very official, but there is _register_load_state_dict_pre_hook, which you can use to register hooks that process the state dict before it is used to update the model, and there is _register_state_dict_hook to register hooks that are called after the state dict has been collected and before it is returned from state_dict(). (There is a small sketch of a pre-hook below the code example.)
- More importantly, though, implementors can override _load_from_state_dict. When the class has an attribute _version, this is saved as the version metadata in the state dict. With this, you can add conversions from older state dictionaries. BatchNorm provides an example of how to do this; it roughly looks like this:
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    version = local_metadata.get('version', None)
    if (version is None or version < 2) and self.have_new_thing:
        new_key = prefix + 'new_thing_param'
        if new_key not in state_dict:
            state_dict[new_key] = ...  # some default here
    super()._load_from_state_dict(
        state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs)
So here we check whether the version is old and we need the new key; if so, we add it to the state dict before handing it over to the superclass's (typically torch.nn.Module's) usual processing.
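To also illustrate the first, hook-based way: a pre-hook registered with _register_load_state_dict_pre_hook sees the state dict before it is applied and can, for example, rename keys coming from an older checkpoint. A minimal sketch (the key names are made up, and remember this is a private, not officially documented API):

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        # process the state dict before it is used to update this module
        self._register_load_state_dict_pre_hook(self._rename_old_keys)

    def _rename_old_keys(self, state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
        for suffix in ('weight', 'bias'):
            old_key = prefix + 'lin.' + suffix        # hypothetical old parameter name
            if old_key in state_dict:
                state_dict[prefix + 'linear.' + suffix] = state_dict.pop(old_key)

    def forward(self, x):
        return self.linear(x)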
Summary
So here we saw in great detail what went wrong when we saved the whole model rather than following the best practice of saving just the parameters. My personal takeaway is that this pitfall is rather large and easy to fall into, so we should really take care to save models only as parameters and not as pickled Module objects.
I hope you enjoyed this little deep dive into a PyTorch best practice. More of this can be found in Piotr's and my imaginary PyTorch book and, until it materializes, in my no-nonsense PyTorch workshops. A special shout-out and thank you to Piotr - I couldn't do half the PyTorch things I do without him! Do send me an e-mail at tv@lernapparat.de if you want to become a better PyTorch programmer or if I can help you with PyTorch and ML consulting.
I appreciate your comments and feedback at tv@lernapparat.de.