Making Sense of PyTorch
I’m a full-stack software engineer who’s been writing Python since One Laptop per Child, around 13 years ago. I could see that there are plenty of deep learning libraries in Python: PyTorch, TensorFlow, Keras, JAX. But when I peeked inside example training notebooks and their objects, I couldn’t figure out what was going on.
Sure, any new codebase is difficult to parse, but this was all new and math-y.
All I want to do is learn to label a folder of images, a really common and straightforward task.
This is a first draft of ‘explaining PyTorch to my previous self’; I’m not sure yet what the title should be.
What I ended up discovering:
First impression: PyTorch code examples require a lot of code to accomplish a simple task (labeling images).
Reality: Most of these lines of code are knobs for controlling and tuning training, not the task itself.
First impression: I know what a neural network looks like, but I don’t see that in the way this code is written.
Reality: PyTorch models are written as nested modules.
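To make ‘nested modules’ concrete, here is a minimal sketch of what that looks like. The class names (ConvBlock, TinyClassifier) and all layer sizes are made up for illustration; they are not from any particular tutorial.

```python
import torch
from torch import nn


class ConvBlock(nn.Module):
    """A small reusable piece: convolution -> batch norm -> ReLU."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.norm(self.conv(x)))


class TinyClassifier(nn.Module):
    """A model is itself a module that contains other modules."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.features = nn.Sequential(ConvBlock(3, 8), ConvBlock(8, 16))
        self.pool = nn.AdaptiveAvgPool2d(1)  # average each channel down to 1x1
        self.head = nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.features(x))  # (batch, 16, 1, 1)
        return self.head(x.flatten(1))   # (batch, num_classes)


model = TinyClassifier()
print(model)  # prints the tree of nested modules

logits = model(torch.randn(4, 3, 32, 32))  # 4 fake RGB images, 32x32
print(logits.shape)  # torch.Size([4, 3])
```

Printing the model shows the nesting directly, which is the closest thing in the code to the familiar ‘circles and arrows’ diagram.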
What I thought neural networks were like: Round #1
What I thought neural networks were like: Round #2
What I think neural networks are like today: Round #3
- A great deal of work can be encapsulated in pre-trained networks and weights (such as ResNet). For a basic image labeler around that module I need to know only a linear transform and something like softmax or argsort.
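As a rough sketch of that last point: the backbone below is a tiny untrained stand-in for a pre-trained network like ResNet (an assumption, so the example runs without a download), and the label names are invented. The only parts I would write myself are the linear layer and the softmax / argsort at the end.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for a pre-trained backbone such as ResNet (hypothetical, untrained).
backbone = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),  # -> (batch, 16) feature vectors
)
for param in backbone.parameters():
    param.requires_grad = False  # freeze: only the head would be trained

labels = ["cat", "dog", "bird"]        # made-up label set
head = nn.Linear(16, len(labels))      # the linear transform

images = torch.randn(2, 3, 64, 64)     # 2 fake RGB images
with torch.no_grad():
    logits = head(backbone(images))

probs = torch.softmax(logits, dim=1)                    # each row sums to 1
ranked = torch.argsort(probs, dim=1, descending=True)   # best label first
top_labels = [labels[i] for i in ranked[:, 0].tolist()]

print(probs)
print(top_labels)
```

Because the head here is untrained, the predicted labels are meaningless; the point is only the shape of the pipeline.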
Understanding inputs as tensors
There are guides to understanding what a tensor is, but it’s essentially an n-dimensional array with fixed dimensions and a uniform type.
Any type of text or image gets consumed by the neural network as a tensor. In practice, for images I’d start by resizing them all to the same width x height x rgb dimensions. With Tweets, I’d use word2vec or a tokenizer to convert each one into a 200+ dimensional vector, and even that needs some padding tokens to make the texts the same length.
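Here is a small sketch of both steps, using random data in place of real photos and invented token ids in place of a real tokenizer. One thing worth knowing: PyTorch layers expect images channels-first, so the order is rgb x height x width rather than width x height x rgb.

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

# Two fake "images" of different sizes, stored channels-first.
img_a = torch.rand(3, 120, 80)
img_b = torch.rand(3, 64, 64)


def resize(img: torch.Tensor, size=(32, 32)) -> torch.Tensor:
    # interpolate expects a batch dimension, so add one and remove it again
    resized = F.interpolate(img.unsqueeze(0), size=size, mode="bilinear", align_corners=False)
    return resized.squeeze(0)


image_batch = torch.stack([resize(img_a), resize(img_b)])
print(image_batch.shape, image_batch.dtype)  # torch.Size([2, 3, 32, 32]) torch.float32

# Two tokenized texts of different lengths (token ids are made up).
PAD_ID = 0
text_a = torch.tensor([5, 17, 42])
text_b = torch.tensor([8, 3, 99, 12, 7])
text_batch = pad_sequence([text_a, text_b], batch_first=True, padding_value=PAD_ID)
print(text_batch)
# tensor([[ 5, 17, 42,  0,  0],
#         [ 8,  3, 99, 12,  7]])
```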
Then, to avoid fitting to one extreme or another, inputs are typically sent in batches of size 2^n (e.g. 8 photos x width x height x rgb). BatchNorm is then commonly used between layers. When I have limited resources on Google Colab, I sometimes reduce the batch size to 4 or 2.
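A minimal sketch of batching, assuming the images already fit in memory as one tensor (real projects usually write a Dataset that loads files from disk). The data and labels are random placeholders.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# 20 fake RGB images (32x32) with made-up labels from 3 classes.
images = torch.rand(20, 3, 32, 32)
targets = torch.randint(0, 3, (20,))
dataset = TensorDataset(images, targets)

# batch_size is the knob to turn down (8 -> 4 -> 2) when memory is tight.
loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch_shapes = [tuple(x.shape) for x, _ in loader]
print(batch_shapes)  # [(8, 3, 32, 32), (8, 3, 32, 32), (4, 3, 32, 32)]

# BatchNorm rescales each channel using statistics from the current batch.
norm = nn.BatchNorm2d(3)
normalized = norm(images[:8])
channel_means = normalized.mean(dim=(0, 2, 3))
print(channel_means)  # each value is very close to 0
```

The last batch is smaller because 20 does not divide evenly by 8; DataLoader has a drop_last option if that is a problem.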
Tensors also make it possible to process data quickly on GPUs or other specialized hardware.
Using a GPU
Training on GPU hardware saves time. You will absolutely want this for training image models. If you aren’t in a pre-installed ML notebook environment like Colab, you will likely need to install CUDA first.
This will be more complicated if you have too much data for one GPU, or want to train on a TPU.
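For the simple single-GPU case, the usual pattern looks something like the sketch below: pick a device, then move both the model and each batch onto it. The model here is a throwaway linear layer, and the code falls back to the CPU when no GPU is visible.

```python
import torch
from torch import nn

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 2).to(device)   # move the weights
batch = torch.randn(4, 10).to(device)  # move the data to the same device

output = model(batch)
print(device, output.shape)
```

Forgetting to move one of the two is a common source of ‘expected all tensors to be on the same device’ errors.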
AutoML and Pre-Trained Models vs. From Scratch
For a more hands-off approach, there are AutoML libraries (such as AutoKeras) which take in images and labels and produce models. I don’t want to knock this, especially if you just want a baseline accuracy score or want to peek at what pieces go into their model.
Alternatively, you can bring in a pre-trained model which has already been trained on a general dataset (such as ImageNet). Reusing the network’s structure and/or weights saves a lot of time and resources.
In this case, I’ve used Ross Wightman’s PyTorch Image Models library (pip install timm), which can bring in over 300 models, and decided on ResNeXt for my project.
Which config and options should I choose?
This is a question of research, scale, and specialization.
In GPT-type models, you use cross-entropy loss. For vector search, you use cosine similarity. For a lot of scenarios there is a preferred loss function, activation function, learning rate / schedule, etc. You could use these as a jumping-off point, or bring them into your project unchanged and focus on some other aspect (a new dataset, pre-processing, etc).
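Both of those are one-liners in PyTorch. A small sketch with hand-picked numbers (all values invented) shows what each one computes:

```python
import torch
import torch.nn.functional as F

# Cross-entropy: compares raw scores (logits) against the correct class index.
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])
loss = F.cross_entropy(logits, targets)

# The same thing by hand: average negative log-probability of the right class.
log_probs = torch.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(2), targets].mean()
print(loss.item(), manual_loss.item())  # the two values match

# Cosine similarity: compares direction only, ignoring vector length.
query = torch.tensor([[1.0, 0.0, 1.0]])
docs = torch.tensor([[2.0, 0.0, 2.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]])
scores = F.cosine_similarity(query, docs, dim=1)
print(scores)  # approximately [1.0, 0.0, 0.5]
best_match = int(torch.argmax(scores))
```

The first document points the same way as the query (score near 1) even though it is twice as long, which is why cosine similarity suits vector search.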
For experimental projects, it’s totally OK to start with one of these preferred configs. You might change only one aspect of the model at a time to get a sense of whether that change had a positive effect.
For larger-scale and higher-stakes projects, there’s interest in hyperparameter search and/or training test models on a data sample before trying the best parameters on the full-scale model.
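The simplest version of this idea is a grid search on a small sample: train a cheap model once per candidate setting and keep whichever scores best on held-out data. The sketch below does that for the learning rate only, on a toy dataset I made up; real searches usually go through a dedicated library and cover more than one parameter.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy data sample: 2 features, label is 1 when their sum is positive.
x = torch.randn(200, 2)
y = (x.sum(dim=1) > 0).long()
x_train, y_train = x[:150], y[:150]
x_val, y_val = x[150:], y[150:]


def train_and_score(lr: float, epochs: int = 50) -> float:
    """Train a tiny model with this learning rate; return validation loss."""
    torch.manual_seed(0)  # same starting weights for every trial
    model = nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(x_train), y_train)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        return nn.functional.cross_entropy(model(x_val), y_val).item()


candidates = [0.001, 0.01, 0.1, 1.0]
results = {lr: train_and_score(lr) for lr in candidates}
best_lr = min(results, key=results.get)  # lowest validation loss wins

for lr, val_loss in results.items():
    print(f"lr={lr}: validation loss {val_loss:.4f}")
print("best:", best_lr)
```

Resetting the seed inside each trial follows the ‘change only one thing at a time’ advice: the learning rate is the only difference between runs.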
There’s also interesting research in having a preferred config/model and replacing one of its components with something new. Two examples in this category would be…
- Including the optimizer in the learned parameters:
GitHub - google/learned_optimization
learned_optimization is a research codebase for training learned optimizers. It implements hand designed and learned…
- Removing BatchNorm without sacrificing quality:
Training Deep Neural Networks Without Batch Normalization
Training neural networks is an optimization problem, and finding a decent set of parameters through gradient descent…
One of the surprisingly difficult parts of neural networks is knowing if they are working well.
For a labeling / classification problem with a manageable number of labels, you probably want a confusion matrix. If you got everything right, a normalized matrix would have a stripe of 1’s along the diagonal and 0’s everywhere else. The off-diagonal squares show which true labels were mistaken for which predicted labels.
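Here’s some scikit-learn code for a labeled confusion matrix:
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
# y_true and y_pred are placeholders for your lists of true and predicted labels
ConfusionMatrixDisplay.from_predictions(y_true, y_pred)
plt.show()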
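To see the numbers behind a plot like that, scikit-learn’s confusion_matrix function returns the raw array. The labels below are invented for a six-image example; normalize="true" divides each row by the number of true examples of that class, which is what gives the 1’s on the diagonal for a perfect model.

```python
from sklearn.metrics import confusion_matrix

# Made-up ground truth and predictions for six images.
y_true = ["cat", "cat", "dog", "dog", "bird", "bird"]
y_pred = ["cat", "dog", "dog", "dog", "bird", "cat"]
class_names = ["bird", "cat", "dog"]  # fixes the row / column order

# Rows are true labels, columns are predicted labels.
cm = confusion_matrix(y_true, y_pred, labels=class_names, normalize="true")
print(cm)
# [[0.5 0.5 0. ]
#  [0.  0.5 0.5]
#  [0.  0.  1. ]]
```

Reading the first row: half of the birds were labeled correctly and half were labeled as cats. Only the dog row is perfect.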
Stuff that still is kinda sketchy (for researchers, not just newbies)
Transformers vs. Doing ResNet Better
After success in masked language models, transformer models with image-masking are becoming more popular. It seemed like a natural evolution to a new architecture, until ‘ResNet strikes back’ revisited the old architecture with new training settings. It could be that an even better ResNet exists with the right components.
ResNet strikes back: An improved training procedure in timm
The influential Residual Networks designed by He et al. remain the gold-standard architecture in numerous scientific…
Normalizing and Augmenting
Earlier I said that images were width x height x rgb. You might instead consider normalizing the pixel values or converting to grayscale, to make the images more general or learnable before feeding them into the neural network.
These and other transform operations are available in the
Data augmentation applies transforms (including some more radical changes) to images to increase the variety in the training data. Ideally this makes it easier for the model to find true patterns which are more consistent (e.g. a car is detected whether it is upright or on its side, in light or dark images).
Sometimes these processes help even if the augmented images are unrecognizable by humans.
‘Adversarial’ covers many different practices now, but essentially: fooling a model is surprisingly easy. Especially when you have direct access to the model, you can iterate and find small changes which look similar to humans but strongly bias the model in a different direction.
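One well-known recipe for this is the fast gradient sign method (FGSM): compute the gradient of the loss with respect to the input pixels, then nudge every pixel a tiny step in the direction that increases the loss. I’m naming FGSM as one example of the idea, not as the only approach. The sketch uses an untrained toy linear model and a random ‘image’, so it shows the mechanics rather than a real attack.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(3 * 8 * 8, 2)  # toy, untrained "classifier" on flattened 8x8 RGB
image = torch.rand(1, 3 * 8 * 8, requires_grad=True)  # track gradients on the input
label = torch.tensor([1])  # pretend this is the correct class

loss = nn.functional.cross_entropy(model(image), label)
loss.backward()  # fills image.grad with d(loss) / d(pixel)

epsilon = 0.05  # maximum change allowed per pixel
step = epsilon * image.grad.sign()
adversarial = (image + step).clamp(0.0, 1.0).detach()  # keep valid pixel values

new_loss = nn.functional.cross_entropy(model(adversarial), label)
max_change = (adversarial - image.detach()).abs().max().item()

print(f"loss before: {loss.item():.4f}, after: {new_loss.item():.4f}")
print(f"largest pixel change: {max_change:.4f}")
```

No pixel moves by more than epsilon, yet the loss on the correct label goes up. With a real trained model and direct access to its gradients, repeating small steps like this is how the ‘looks the same to a human’ examples get found.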
In 2021, researchers from the Weizmann Institute of Science proposed The Dimpled Manifold Model to argue that most mental models of decision boundaries and training were clouded by the hundreds of dimensions of data used in models’ vector space. In this view (as I understand it), decision boundaries are a small ‘dimple’ in the space of typical data, and adversarial examples nudge an image outside of that dimple in some dimensional direction. Data augmentation then would be about expanding the dimple to cover more possible situations.
This paper is still controversial so I think the safest thing to say is… we are not 200-dimensional beings so we do not fully visualize and anticipate the flaws of 200-dimensional models.
Robustness and drift
Neural network research takes English words (attention, sparse, robustness) and packs in new meanings. I think the best way to describe ‘robustness’ would be whether a model can be consistently good at its job given real-world variation in a task. If an X-ray model is trained and tested on images from one machine or clinic, or if a face-recognition model is trained and tested on people in a particular age/race/gender group, we will likely find it is not robust enough to apply to others.
Robustness is a separate term of art from drift, but the two are neighboring problems (drift covers changes over time, changes in human perception of labels, changes in distribution, etc).
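One cheap way to start checking robustness is to break accuracy down by the source of each example instead of reporting a single number. The records below are invented (two hypothetical clinics, binary labels) purely to show how an acceptable-looking overall score can hide a group where the model fails.

```python
from collections import defaultdict

# (group, true label, predicted label) for each test example; all made up.
records = [
    ("clinic_a", 1, 1), ("clinic_a", 0, 0), ("clinic_a", 1, 1), ("clinic_a", 0, 0),
    ("clinic_b", 1, 0), ("clinic_b", 0, 0), ("clinic_b", 1, 0), ("clinic_b", 0, 1),
]

correct = defaultdict(int)
total = defaultdict(int)
for group, true_label, predicted in records:
    total[group] += 1
    correct[group] += int(true_label == predicted)

accuracy_by_group = {group: correct[group] / total[group] for group in total}
overall = sum(correct.values()) / sum(total.values())

print(accuracy_by_group)  # {'clinic_a': 1.0, 'clinic_b': 0.25}
print(overall)            # 0.625
```

Running the same breakdown by time period instead of by clinic is a simple first check for drift.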