## The Problem
You've fine-tuned a BERT model that achieves great accuracy on your task. But when you deploy it, inference takes 200ms per request — way too slow for a real-time API with a 50ms SLA. Sound familiar?
This post walks through the exact optimization pipeline I used to reduce transformer inference latency by 92% while retaining 98% of the original accuracy.
## Baseline Measurement
Always start by profiling your baseline. Here's what our starting point looked like:
```python
import time

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model.eval()

# Benchmark a single short input
inputs = tokenizer("Sample input text for benchmarking", return_tensors="pt")

times = []
for _ in range(100):
    start = time.perf_counter()
    with torch.no_grad():
        outputs = model(**inputs)
    times.append(time.perf_counter() - start)

print(f"Mean latency: {sum(times)/len(times)*1000:.1f}ms")
# Output: Mean latency: 198.3ms (CPU)
```
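With a 50ms SLA, the mean is not the whole story: a handful of slow requests can blow the budget even when the average looks fine, so it pays to look at p95 as well. A small nearest-rank percentile helper over the `times` list collected above (the `percentile` function name is mine, not from a library):

```python
def percentile(samples, pct):
    """Return the pct-th percentile (0-100) of a list of latency samples,
    using nearest-rank selection on the sorted values."""
    ranked = sorted(samples)
    # Nearest-rank index, clamped to the valid range
    idx = min(len(ranked) - 1, max(0, round(pct / 100 * len(ranked)) - 1))
    return ranked[idx]

# Example with synthetic latencies in seconds (one slow outlier)
times = [0.190, 0.195, 0.198, 0.201, 0.205, 0.310]
print(f"p50: {percentile(times, 50)*1000:.1f}ms")  # p50: 198.0ms
print(f"p95: {percentile(times, 95)*1000:.1f}ms")  # p95: 310.0ms
```

The same idea applies after every optimization step below: compare percentiles, not just means.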
## Step 1: ONNX Export
Converting to ONNX format alone provides significant speedup through graph optimizations:
```python
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="./fine-tuned-bert",
    output="./onnx-model",
    task="text-classification",
)
```
Result: 198ms → 85ms (57% reduction)
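To actually serve the exported graph, load it with `onnxruntime.InferenceSession` and feed it the tokenizer's NumPy output. A minimal sketch, assuming the standard `input_ids`/`attention_mask` input names produced by the export; the `onnx_predict` helper is hypothetical:

```python
import numpy as np

def onnx_predict(session, tokenizer, text):
    """Tokenize `text`, run it through an onnxruntime InferenceSession,
    and return the predicted class index."""
    enc = tokenizer(text, return_tensors="np")
    # Only pass the inputs the exported graph actually declares
    feed = {i.name: enc[i.name] for i in session.get_inputs() if i.name in enc}
    logits = session.run(None, feed)[0]
    return int(np.argmax(logits, axis=-1)[0])

# Usage sketch (assumes the export step above has been run):
# import onnxruntime as ort
# session = ort.InferenceSession("./onnx-model/model.onnx")
# onnx_predict(session, tokenizer, "Sample input text for benchmarking")
```

Filtering the feed dict against `session.get_inputs()` avoids errors when the tokenizer emits keys (such as `token_type_ids`) that a particular export did not keep.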
## Step 2: Quantization
Dynamic INT8 quantization reduces model size and leverages integer arithmetic:
```python
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="./onnx-model/model.onnx",
    model_output="./onnx-model/model_quantized.onnx",
    weight_type=QuantType.QInt8,
)
```
Result: 85ms → 32ms (62% further reduction)
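"Minimal accuracy loss" is worth verifying on your own eval set rather than taking on faith. A simple sanity check is prediction agreement: run both models over held-out examples and measure how often they pick the same label. The `agreement_rate` helper below is a hypothetical sketch of that check:

```python
def agreement_rate(preds_a, preds_b):
    """Fraction of examples where two models predict the same label."""
    if len(preds_a) != len(preds_b):
        raise ValueError("prediction lists must be the same length")
    matches = sum(a == b for a, b in zip(preds_a, preds_b))
    return matches / len(preds_a)

# e.g. full-precision vs. quantized predictions on a held-out set
print(agreement_rate([0, 1, 1, 0, 1], [0, 1, 0, 0, 1]))  # → 0.8
```

If agreement drops more than a point or two, inspect the disagreeing examples before shipping the quantized model.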
## Step 3: Knowledge Distillation
For the final push, we trained a 4-layer DistilBERT student model:
| Model | Params | Accuracy | Latency (CPU) |
|---|---|---|---|
| BERT-base | 110M | 92.3% | 198ms |
| ONNX + Quantized BERT | 110M | 91.8% | 32ms |
| Distilled + Quantized | 28M | 90.5% | 15ms |
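The distillation objective blends a soft-target term (KL divergence against the teacher's temperature-softened output distribution) with ordinary cross-entropy on the hard labels. A minimal PyTorch sketch; the temperature `T=2.0` and mixing weight `alpha=0.5` are illustrative defaults, not the values behind the table above:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Blend soft-target KL against the teacher with hard-label cross-entropy."""
    # Soft targets: KL between temperature-softened distributions.
    # The T*T factor keeps gradient magnitudes comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard
```

When the student's logits match the teacher's exactly, the soft term vanishes and only the hard-label loss remains, which is a useful sanity check while debugging the training loop.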
## Key Takeaways
- Profile before optimizing — know where your bottleneck actually is
- ONNX export is low-hanging fruit — always do this first
- Quantization is nearly free — minimal accuracy loss for big speedups
- Distillation is worth it if you need sub-20ms latency
- Batch your requests when possible — GPU utilization matters
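On that last point: padding a group of requests to a common length and running them as a single forward pass amortizes per-call overhead and keeps the accelerator busy. A minimal sketch of the grouping step, with `batched` as a hypothetical helper:

```python
def batched(items, batch_size):
    """Split a list of requests into consecutive batches of at most batch_size."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

# Each batch can then be tokenized together and run in one forward pass, e.g.:
# enc = tokenizer(batch, padding=True, return_tensors="pt")
print(batched(["a", "b", "c", "d", "e"], 2))  # → [['a', 'b'], ['c', 'd'], ['e']]
```

In a real service this sits behind a queue that flushes either when a batch fills or when a short timeout expires, so single requests still meet the latency SLA.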