“AI” models tend to be slow. To a large degree this is inevitable: more calculations to do means more time to run. Simply throwing more compute at the problem is one solution, but that limits the ability to run models “at the edge” (i.e., on a phone) and raises the cost.
Another solution is to come up with “clever” implementations of models that are naturally more efficient. As a basic example, division operations tend to be slower than other calculations because of how hardware implements them, so planning code to avoid division is one way to speed up math-heavy programs. In this case, for an image classifier neural network, people have already done the hard work in that regard. Here I am using the MobileNet family of models from Google, designed to be efficient by using operations that do more with less, such as depthwise convolutions in place of standard convolutions.
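To make the idea concrete, here is a minimal sketch (not the actual MobileNet code) of a depthwise separable convolution block in PyTorch: the depthwise step filters each input channel on its own (groups=in_channels), and a 1x1 pointwise convolution then mixes channels, which takes far fewer multiply-adds than one standard convolution.

```python
import torch
import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    """Sketch of the depthwise-separable building block MobileNets rely on."""
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()
        # Depthwise: one filter per input channel (groups=in_ch), spatial mixing only.
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=stride,
                                   padding=1, groups=in_ch, bias=False)
        # Pointwise: 1x1 convolution mixes information across channels.
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.pointwise(self.depthwise(x))))

x = torch.randn(1, 32, 112, 112)
print(DepthwiseSeparableConv(32, 64)(x).shape)  # torch.Size([1, 64, 112, 112])
```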
Here, for my goal of the fastest possible image inference, I took one more step: model quantization. Most neural networks are trained in FP32 precision, i.e., 32-bit floating-point numbers carrying a whole bunch of decimal places. The idea of quantization is that if we cut those down to a much smaller format, usually INT8, we have ‘smaller’ numbers to process. Smaller is faster, and can often use different, smaller, more efficient compute units on a CPU, or dedicated NPUs (many NPUs accept only INT8 models). FP32-level precision is necessary for computing gradients through backpropagation during training (FP16 is also used sometimes), but once training is done, a model can often be shifted to INT8 without much precision loss. In static quantization, nothing about the model’s training is changed; the finished model simply has its weights converted down to INT8, with activation ranges calibrated on a small sample of data.
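As a rough sketch of what post-training static quantization looks like with one of the torch.ao.quantization workflows (FX graph mode); the model, calibration loader, and example input here are placeholders, not my exact training code:

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

def static_quantize(model: torch.nn.Module, calib_loader, example_input):
    """Post-training static quantization: observe activation ranges, then convert to INT8."""
    model.eval()
    qconfig_mapping = get_default_qconfig_mapping("x86")  # default INT8 config for x86 CPUs
    prepared = prepare_fx(model, qconfig_mapping, example_inputs=(example_input,))
    # Calibration: run a few batches so the inserted observers record activation ranges.
    with torch.no_grad():
        for images, _ in calib_loader:
            prepared(images)
    return convert_fx(prepared)  # swaps float ops for quantized INT8 kernels
```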
However, as one can imagine, slicing off all those decimal places can sometimes hurt a model’s accuracy. That is why there is a technique called quantization-aware training (QAT): train the FP32 model for a short period as if it were an INT8 model (simulating the quantization on the forward pass), then export to INT8, essentially letting the model learn to handle the precision loss.
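A corresponding sketch of QAT with the same FX API, again with placeholder names and an abbreviated training loop rather than my full setup:

```python
import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

def qat_finetune(model, train_loader, example_input, epochs=3, lr=1e-4):
    """Quantization-aware training: fake-quantize during training, then convert to INT8."""
    model.train()
    qconfig_mapping = get_default_qat_qconfig_mapping("x86")
    prepared = prepare_qat_fx(model, qconfig_mapping, example_inputs=(example_input,))
    opt = torch.optim.SGD(prepared.parameters(), lr=lr, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()
    for _ in range(epochs):
        for images, labels in train_loader:
            opt.zero_grad()
            loss = loss_fn(prepared(images), labels)
            loss.backward()
            opt.step()
    prepared.eval()
    return convert_fx(prepared)  # export the model that has learned to live with INT8
```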
There isn’t much clear data out there on quantization, and especially on QAT, so I set out to test what would get me the fastest results across different MobileNet architectures. I trained models on ImageNet (a well-known image dataset) cut down to a subset of 42 hand-picked classes. I used the subset because, in my view, many ImageNet classes overlap too much and differ only trivially; smaller models generally can’t handle those tiny, finicky distinctions as well, nor do they often need to. I wanted to test the models’ ability to handle a useful level of data, not an arbitrarily complex one.
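A minimal sketch of how such a class subset can be built with torchvision; the WordNet IDs and directory layout below are placeholders, not my actual 42 classes:

```python
from torch.utils.data import Subset
from torchvision import datasets, transforms

# Hypothetical WordNet IDs for the hand-picked classes (placeholders, not my real list).
KEEP_WNIDS = {"n02084071", "n02121808", "n03791053"}

tfm = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
full = datasets.ImageFolder("imagenet/train", transform=tfm)  # folders named by WordNet ID
keep_idx = [i for i, (_, cls_idx) in enumerate(full.samples)
            if full.classes[cls_idx] in KEEP_WNIDS]
subset = Subset(full, keep_idx)
# Note: labels keep their original indices here; for training they would still
# need remapping to a contiguous 0..41 range.
```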
This was tested on a laptop with an RTX 4060 GPU (used for training) and an AMD 7840HS CPU (used for training and for INT8 inference).
Before getting into the detailed results, let’s take a brief look at the overall conclusions.
Conclusions:
- Architecture matters most for speed (and accuracy). As expected, MobileNet v4 was the fastest, then v3, then v2.
- A speedup of roughly 40% is a reasonable expectation for INT8 quantization, but it varies by architecture and inference platform; some benefit much more than others.
- Static quantization sometimes works very well (delivering nearly identical accuracy to the FP32 model with much faster inference) but sometimes fails completely, producing much slower and very inaccurate results.
- QAT is more reliable than static quantization, but isn’t always much faster.
- QAT requires a lot more code and the full training data, whereas static quantization is simpler.
- Surprisingly, the number of nodes in the ONNX graph is not a good indicator of runtime. Going from 1014 nodes to 69 in the optimized graph looks like it should be a huge speedup, but the runtime is about the same (see the node-counting sketch after this list).
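For reference, this is roughly how node counts and ONNX Runtime’s offline graph optimization can be inspected; the file names are placeholders:

```python
import onnx
import onnxruntime as ort

def count_nodes(path: str) -> int:
    """Count the operator nodes in an ONNX graph."""
    return len(onnx.load(path).graph.node)

def save_optimized(path: str, out_path: str) -> None:
    """Have ONNX Runtime apply its offline graph optimizations and save the result."""
    opts = ort.SessionOptions()
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    opts.optimized_model_filepath = out_path
    ort.InferenceSession(path, opts, providers=["CPUExecutionProvider"])

# Example usage (placeholder file names):
# print(count_nodes("model_int8_qat.onnx"))       # e.g. 1014 nodes before optimization
# save_optimized("model_int8_qat.onnx", "model_int8_qat_opt.onnx")
# print(count_nodes("model_int8_qat_opt.onnx"))   # far fewer nodes, similar runtime
```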
One big caveat: I observed roughly +/- 0.05 ms, or about 5%, variation between runs of the same ONNX model, caused by things like CPU temperature (already hot from previous tests), background processes, and other small factors. Factor that in when reading the results.
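For context, a simple benchmarking sketch along these lines (warm-up runs plus averaging over many iterations) is one way to tame that noise; the input shape and run counts below are assumptions, not my exact harness:

```python
import time
import numpy as np
import onnxruntime as ort

def bench(path: str, runs: int = 500, warmup: int = 50) -> float:
    """Average single-image CPU inference time in milliseconds."""
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    name = sess.get_inputs()[0].name
    x = np.random.rand(1, 3, 224, 224).astype(np.float32)  # assumed input shape
    for _ in range(warmup):            # warm-up: stabilize caches and CPU clocks
        sess.run(None, {name: x})
    start = time.perf_counter()
    for _ in range(runs):
        sess.run(None, {name: x})
    return (time.perf_counter() - start) / runs * 1000.0
```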
From going through this process, my conclusion is that static quantization should be tried first, making sure to validate the results. If the accuracy holds and the speedup is all that is hoped for, there is no need to bother with the slower, more involved QAT. If static quantization is disappointing, however, QAT does deliver more consistent improvements.
I should also add that quantization code is a bit of a mess right now. PyTorch (2.7 here) has at least three different partial APIs for it, and it isn’t clear which is best. I even messaged some of the team at Meta, and they didn’t all seem to agree on the official way forward. I think torch.ao.quantization, which I used here, is the best path for now. You can see the training and testing code I used here: https://github.com/winedarksea/qat_onnx_models
Accuracy results in this report should be taken with a degree of skepticism; my primary task here was evaluating inference time. For consistency the models generally used the same training parameters, but to save runtime I sometimes cut the number of epochs. Nothing is particularly tuned, so this is probably not truly the “best” these architectures could do. The two highest-accuracy models here, MNv3 Large and MNv2, both had pretrained weights available as a starting point, which likely contributed to that performance.
One architecture here, MobileNet v2, was actually slower in all quantized forms, so quantization clearly isn’t always going to work well. Perhaps that could be solved with custom graph fusing or a MobileNet v2 variant built with quantization in mind, but those are rather advanced, time-consuming steps to take.
| Model | Training | Save Type | Runtime | Accuracy Top 1 | Accuracy Top 5 | F1 Score | Nodes |
|---|---|---|---|---|---|---|---|
| Mnv3 | Epochs 500, pretrained | FP32 | 1.8584 ms | 89.19% | 97.57% | 89.08% | |
| | | INT8 Static | 2.8002 ms | 37.62% | 66.77% | 40.15% | |
| | | INT8 QAT | 1.6681 ms | 86.23% | 97.21% | 86.18% | |
| | | INT8 QAT Opt | 1.6605 ms | 86.23% | 97.21% | 86.18% | |
| Mnv4c | Epochs 550 | FP32 | 1.4384 ms | 89.92% | 98.03% | 89.99% | |
| | | INT8 Static | 1.8281 ms | 89.43% | 97.99% | 89.54% | |
| | | INT8 QAT | 0.8578 ms | 89.62% | 97.92% | 89.66% | |
| | | INT8 QAT Opt | 0.8472 ms | 89.62% | 97.92% | 89.66% | |
| Mnv4c 1.2 width | Epochs 550 | FP32 | 2.0254 ms | 90.26% | 98.13% | 90.35% | 98 |
| | | INT8 Static | 2.4744 ms | 89.86% | 98.08% | 89.85% | 280 |
| | | INT8 QAT | 1.1395 ms | 89.99% | 98.20% | 90.07% | 1014 |
| | | INT8 QAT Opt | 1.0963 ms | 89.99% | 98.20% | 90.07% | 69 |
| | | FP32 Dynamo | 2.1441 ms | 90.26% | 98.13% | 90.35% | 158 |
| | | INT8 Static Torch | 1.1018 ms | 90.20% | 98.09% | 90.35% | 1014 |
| Mnv4c 0.8 width | Epochs 550 | FP32 | 1.1883 ms | 89.16% | 97.97% | 89.15% | 98 |
| | | INT8 Static | 1.7362 ms | 87.00% | 97.38% | 86.88% | 280 |
| | | INT8 QAT | 0.8055 ms | 88.94% | 98.09% | 88.96% | 1014 |
| | | INT8 QAT Opt | 0.8236 ms | 88.94% | 98.09% | 88.96% | 69 |
| | | FP32 Dynamo | 1.3882 ms | 89.16% | 97.97% | 89.15% | 158 |
| | | INT8 Static Torch | 0.8576 ms | 88.44% | 97.74% | 88.52% | 1014 |
| Mnv4s | Epochs 105 | FP32 | 1.1893 ms | 88.34% | 97.75% | 88.41% | |
| | | INT8 Static | 1.8135 ms | 87.74% | 97.65% | 87.67% | |
| | | INT8 QAT | 0.9114 ms | 88.02% | 97.77% | 87.94% | 1012 |
| | | | 0.8211 ms | | | | |
| | | INT8 QAT Opt | 0.8571 ms | 88.02% | 97.77% | 87.94% | |
| Mnv4s | Epochs 750 | FP32 | 1.2013 ms | 89.99% | 97.92% | 89.96% | 96 |
| | | INT8 Static | 1.8782 ms | 89.96% | 97.87% | 90.01% | 280 |
| | | INT8 QAT | 0.8860 ms | 89.60% | 97.80% | 89.60% | 1012 |
| | | INT8 QAT Opt | 0.8290 ms | 89.60% | 97.80% | 89.60% | 69 |
| | | FP32 Dynamo | 1.3653 ms | 89.99% | 97.92% | 89.96% | 156 |
| | | INT8 Static Torch | 0.8990 ms | 89.29% | 97.28% | 89.30% | 1012 |
| Mnv2 | Epochs 105, pretrained | FP32 | 2.0646 ms | 91.49% | 98.59% | 91.53% | |
| | | INT8 Static | 3.2227 ms | 91.34% | 98.50% | 91.28% | |
| | | INT8 QAT | 4.0154 ms | 90.28% | 98.31% | 90.23% | |
| | | INT8 QAT Opt | 4.0773 ms | 90.28% | 98.31% | 90.23% | |
| | | FP32 Dynamo | 2.1608 ms | 91.49% | 98.59% | 91.53% | |
| Mnv4m | Epochs 105 | FP32 | 4.2657 ms | 88.61% | 97.87% | 88.73% | 143 |
| | | INT8 Static | 5.6956 ms | 88.56% | 97.62% | 88.58% | 459 |
| | | INT8 QAT | 1.8904 ms | 88.36% | 97.62% | 88.40% | 1674 |
| | | INT8 QAT Opt | 1.9162 ms | 88.36% | 97.62% | 88.40% | 108 |
| | | FP32 Dynamo | 4.1543 ms | 88.61% | 97.87% | 88.73% | 234 |
| Mnv3l | Epochs 75, pretrained | FP32 | 3.1451 ms | 92.05% | 98.35% | 92.12% | 147 |
| | | INT8 Static | 7.8488 ms | 18.18% | 36.36% | 20.08% | 500 |
| | | INT8 QAT | 2.6631 ms | 90.77% | 97.87% | 90.72% | 1686 |
| | | INT8 QAT Opt | 3.0029 ms | 90.77% | 97.87% | 90.72% | 208 |
| | | FP32 Dynamo | 4.2040 ms | 92.05% | 98.35% | 92.12% | 319 |
Colin Catlin, 2025