How to Convert PyTorch Models to ONNX Format for Faster Inference

I remember the first time I deployed a PyTorch model to production. I wrapped a beautifully trained ResNet model in a Flask API, spun up a Docker container, and watched in horror as the latency spiked to 800 milliseconds per request under moderate load. PyTorch is an absolute joy for research and training, but its native eager execution mode and reliance on the Python GIL make it a massive bottleneck for high-throughput production environments.

If you want to serve models at scale, you need to step outside the native PyTorch ecosystem. You need a format that C++ backends, hardware accelerators, and specialized inference servers can understand. This is where ONNX (Open Neural Network Exchange) comes in. When you convert a PyTorch model to ONNX format, you decouple your model’s architecture and weights from the Python runtime, opening the door to massive latency reductions and throughput gains.

In my experience building machine learning infrastructure, mastering this conversion process is the single highest-ROI skill for an ML engineer transitioning from research to deployment. Whether you plan to compile your model with TensorRT for NVIDIA GPUs, run it through OpenVINO for Intel hardware, or serve it globally via Triton Inference Server, ONNX is the required bridge.

The Toolchain: What You Need to Get Started

Before we touch any code, let’s get our environment strictly defined. Mismatched versions between PyTorch and ONNX are the root cause of 90% of export failures. If you’ve been following recent PyTorch and ONNX releases, you know that the ecosystem moves fast, and backwards compatibility isn’t always guaranteed.

For this tutorial, I am using the following stack. I highly recommend pinning these versions in your requirements.txt or pyproject.toml:

  • torch==2.2.0 (The core framework)
  • onnx==1.15.0 (The ONNX specification and checker)
  • onnxruntime==1.17.0 (The inference engine to test our exported model)
  • onnxsim==0.4.33 (A crucial tool we’ll use to optimize the exported graph)

Install them via pip:

pip install torch onnx onnxruntime onnxsim torchvision

The Golden Rule of ONNX Export: Tracing vs. Scripting

Before we write the export script, you need to understand exactly what PyTorch is doing under the hood. When you call torch.onnx.export, PyTorch does not parse your Python code. Instead, it uses a technique called tracing.

PyTorch pushes a “dummy” tensor through your model. As the tensor flows through your forward pass, PyTorch records every single operation (convolutions, matrix multiplications, activations) into a static graph. This has a massive implication: data-dependent control flow will be silently hardcoded.

If your forward() method contains an if x.sum() > 0: statement, the ONNX graph will only contain the branch that the dummy tensor triggered during the export. If you need dynamic control flow, you have to use TorchScript first, but for 95% of standard CNNs, LLMs, and Transformers, basic tracing is exactly what you want.
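To make the pitfall concrete, here is a minimal sketch (a toy Gate module, not from the tutorial above) showing how tracing bakes in whichever branch the dummy input triggers. The same mechanism underlies torch.onnx.export; torch.jit.trace just lets us observe it directly:

```python
import torch
import torch.nn as nn

class Gate(nn.Module):
    def forward(self, x):
        # Data-dependent branch: tracing records only the path taken by the dummy input
        if x.sum() > 0:
            return x + 10
        return x - 10

gate = Gate()
pos = torch.ones(3)
traced = torch.jit.trace(gate, pos)  # emits a TracerWarning about the branch

neg = -torch.ones(3)
eager_out = gate(neg)     # eager mode takes the else branch: x - 10
traced_out = traced(neg)  # traced graph still computes x + 10, baked in at trace time
print(eager_out, traced_out)
```

The traced module silently returns the wrong answer for negative inputs, which is exactly why the exported ONNX graph must never contain data-dependent Python control flow.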

How to Convert a PyTorch Model to ONNX: The Baseline Export

Let’s start by exporting a standard Computer Vision model. We’ll use a pre-trained ResNet18 from torchvision. This is the exact boilerplate I use as a starting point for every vision model deployment.

import torch
import torchvision.models as models

# 1. Initialize the model and set it to evaluation mode
# THIS IS CRITICAL. If you leave it in training mode, Dropout and BatchNorm 
# will behave incorrectly in the exported ONNX graph.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval()

# 2. Create a dummy input tensor
# The dimensions must exactly match what your model expects.
# For ResNet: Batch Size (1), Channels (3), Height (224), Width (224)
dummy_input = torch.randn(1, 3, 224, 224)

# 3. Define the export path
onnx_file_path = "resnet18_static.onnx"

# 4. Export the model
torch.onnx.export(
    model,                      # The PyTorch model
    dummy_input,                # The dummy input tensor
    onnx_file_path,             # Where to save the file
    export_params=True,         # Store the trained parameter weights inside the model file
    opset_version=17,           # The ONNX version to export the model to
    do_constant_folding=True,   # Optimize the graph by pre-computing constant nodes
    input_names=['input_image'],   # Give the input a semantic name
    output_names=['class_logits']  # Give the output a semantic name
)

print(f"Successfully exported model to {onnx_file_path}")

Let’s break down the critical parameters here, because copying and pasting without understanding will lead to immense pain later:

  • model.eval(): I cannot stress this enough. If you forget this, your ONNX model will contain training-specific dropout nodes, and your inference results will be garbage.
  • opset_version=17: The opset defines the mathematical operators available in the ONNX standard. Older opsets (like 11 or 12) lack support for newer PyTorch functions. Unless you are deploying to legacy hardware, always use opset 16 or 17.
  • do_constant_folding=True: This is a free performance boost. If your graph has operations that only involve constants (e.g., reshaping a static weight matrix), PyTorch will compute it once during export rather than forcing the ONNX runtime to compute it on every inference pass.

The #1 Production Trap: Handling Dynamic Axes

If you take the resnet18_static.onnx file we just generated and deploy it to a FastAPI endpoint, you will hit a wall the moment you try to process a batch of two images. Because we used a dummy input with a batch size of 1, the ONNX graph has hardcoded the batch dimension to exactly 1.

In the real world, batch sizes fluctuate. If you are building robust APIs, especially ones that serve models behind FastAPI with asynchronous request batching, you need dynamic axes. Here is how you modify the export to allow flexible batch sizes and image dimensions.

# Define which dimensions are allowed to change at runtime
dynamic_axes_config = {
    'input_image': {
        0: 'batch_size',  # The 0th dimension is the batch size
        2: 'height',      # The 2nd dimension is height
        3: 'width'        # The 3rd dimension is width
    },
    'class_logits': {
        0: 'batch_size'   # The output's 0th dimension will match the dynamic batch size
    }
}

torch.onnx.export(
    model,
    dummy_input,
    "resnet18_dynamic.onnx",
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=['input_image'],
    output_names=['class_logits'],
    dynamic_axes=dynamic_axes_config  # Inject the dynamic configuration
)

By mapping the integer indices of the dimensions to string names (like 'batch_size'), we tell the ONNX exporter to leave those dimensions as variables. Now, your inference server can accept batches of 1, 8, or 64 without throwing a shape mismatch error.

Validating the ONNX Model (Trust, but Verify)

Just because PyTorch generated an .onnx file without throwing a Python exception does not mean the model is valid. The graph might be malformed, or floating-point drift might have destroyed your accuracy.

As a senior developer, I never deploy an exported model without running an automated parity check. We need to run the exact same tensor through both PyTorch and ONNX Runtime and assert that the outputs are mathematically identical.

import onnx
import onnxruntime as ort
import numpy as np

# 1. Structural Validation
onnx_model = onnx.load("resnet18_dynamic.onnx")
try:
    onnx.checker.check_model(onnx_model)
    print("ONNX graph is structurally valid.")
except onnx.checker.ValidationError as e:
    print(f"The model is invalid: {e}")

# 2. Mathematical Parity Check
# Generate a fresh random tensor
test_input = torch.randn(4, 3, 256, 256) # Testing dynamic batch and spatial size

# Get PyTorch output
with torch.no_grad():
    pytorch_output = model(test_input).numpy()

# Get ONNX Runtime output
ort_session = ort.InferenceSession("resnet18_dynamic.onnx", providers=['CPUExecutionProvider'])
ort_inputs = {ort_session.get_inputs()[0].name: test_input.numpy()}
ort_output = ort_session.run(None, ort_inputs)[0]

# Compare the results using numpy's allclose
# Allow a small tolerance (rtol=1e-3, atol=1e-5) for floating-point arithmetic differences
np.testing.assert_allclose(pytorch_output, ort_output, rtol=1e-03, atol=1e-05)
print("SUCCESS: PyTorch and ONNX outputs match!")

Notice the rtol and atol parameters in assert_allclose. CPU and GPU backends handle floating-point math slightly differently (fused multiply-add vs discrete operations). A tiny variance is normal; massive variance means your model contains an operator that exported incorrectly.

Simplifying the Graph with ONNX-Simplifier

PyTorch’s exporter is notorious for leaving “glue” nodes in the graph—unnecessary Gather, Unsqueeze, and Cast operations that slow down inference. Before moving to a production server, I always run the model through onnxsim.

onnxsim resnet18_dynamic.onnx resnet18_dynamic_opt.onnx

This CLI tool mathematically reduces the graph, folding redundant operations. If you open the before and after files in Netron (an absolute must-have visualizer for ONNX models), you will often see a 10-20% reduction in total node count.

Advanced Export: Hugging Face Transformers

If you work with Hugging Face Transformers, you know that NLP models are significantly more complex to export than CNNs. They require multiple inputs (input_ids, attention_mask, and sometimes token_type_ids), and they rely heavily on dynamic sequence lengths.

While you can use Hugging Face’s optimum library for a shortcut, doing it manually with PyTorch teaches you exactly how multi-input exports work.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
nlp_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
nlp_model.eval()

# Create dummy text inputs
text = ["This is a sample sentence for ONNX export."]
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

# Define dynamic axes for both batch size and sequence length
nlp_dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'sequence_length'},
    'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
    'logits': {0: 'batch_size'}
}

torch.onnx.export(
    nlp_model,
    (inputs['input_ids'], inputs['attention_mask']), # Tuple of inputs
    "distilbert.onnx",
    input_names=['input_ids', 'attention_mask'],
    output_names=['logits'],
    dynamic_axes=nlp_dynamic_axes,
    opset_version=16
)

The crucial detail here is passing the inputs as a tuple: (inputs['input_ids'], inputs['attention_mask']). PyTorch tracing requires raw tensors, not Python dictionaries. When you write your inference code later, you must pass the inputs to the ONNX session using the exact string names defined in input_names.

Benchmarking: Proving the ROI

Why did we do all of this? Let’s write a quick benchmarking script to prove the latency reduction. I usually run this to justify the engineering time spent on optimization.

import time

def benchmark(model_func, inputs, iterations=100):
    # Warmup
    for _ in range(10):
        model_func(inputs)
        
    start_time = time.perf_counter()
    for _ in range(iterations):
        model_func(inputs)
    end_time = time.perf_counter()
    
    return (end_time - start_time) / iterations * 1000 # returns ms per inference

# PyTorch Benchmark (disable autograd so we measure pure inference, not graph building)
with torch.no_grad():
    pytorch_latency = benchmark(lambda x: model(x), test_input)

# ONNX Benchmark
ort_latency = benchmark(lambda x: ort_session.run(None, x), ort_inputs)

print(f"PyTorch Latency: {pytorch_latency:.2f} ms")
print(f"ONNX Runtime Latency: {ort_latency:.2f} ms")
print(f"Speedup: {pytorch_latency / ort_latency:.2f}x")

Even on a standard CPU execution provider, you will typically see a 2x to 3x speedup. But this is just the beginning.

Next Steps: TensorRT and Triton Inference Server

Converting a PyTorch model to ONNX is rarely the final step; it is the enabler for true hardware acceleration. If you are deploying on NVIDIA hardware, ONNX is the primary ingestion format for TensorRT. By passing your .onnx file into trtexec, NVIDIA’s compiler will fuse layers, convert weights to FP16 or INT8, and aggressively optimize memory bandwidth specifically for your target GPU architecture (like Ampere or Hopper).

Similarly, if you are building enterprise AI platforms, perhaps on AWS SageMaker or Azure Machine Learning, you will likely deploy this model using NVIDIA’s Triton Inference Server. Triton natively supports ONNX Runtime backends, allowing you to serve the model via gRPC or HTTP with zero-copy shared memory, dynamic batching, and concurrent model execution out of the box.

For Intel environments, you would take this exact same ONNX file and compile it using OpenVINO. The beauty of the ONNX standard is that you write the export code once, and the hardware deployment landscape becomes entirely agnostic.
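As a concrete illustration of the TensorRT handoff, a trtexec invocation might look like the sketch below. The shape-profile values are example settings, not part of this tutorial; check trtexec --help for the flags available in your TensorRT version:

```shell
# Compile the dynamic ONNX export into a serialized FP16 TensorRT engine,
# declaring the min/optimal/max shapes the engine must support
trtexec \
  --onnx=resnet18_dynamic_opt.onnx \
  --saveEngine=resnet18_fp16.engine \
  --fp16 \
  --minShapes=input_image:1x3x224x224 \
  --optShapes=input_image:8x3x224x224 \
  --maxShapes=input_image:32x3x224x224
```

Note that the shape flags reference the semantic name input_image we assigned during export, which is another reason to set input_names deliberately rather than accepting the defaults.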

Frequently Asked Questions

Can I convert any PyTorch model to ONNX?

Most standard architectures (CNNs, Transformers, MLPs) export flawlessly. However, models with complex, data-dependent control flow (like dynamic while-loops) or highly customized C++ operators may fail during the tracing process. In those cases, you must rewrite the custom logic using standard PyTorch operations or register a custom ONNX symbolic function.

Why is my ONNX model output slightly different from PyTorch?

Minor discrepancies (differences smaller than 1e-4) are normal and result from floating-point arithmetic differences between the PyTorch eager engine and the ONNX Runtime backend. If the difference is large, it usually means you forgot to call model.eval() before exporting, causing dropout or batch normalization to remain active.

How do I visually inspect the architecture of an ONNX file?

The industry standard tool for this is Netron. It is an open-source visualizer that allows you to drag and drop your .onnx file into a browser window to inspect every node, tensor shape, and learned weight in the exported computational graph.

Does ONNX support PyTorch 2.0’s torch.compile?

torch.compile() is PyTorch’s native JIT compiler (using TorchDynamo) designed to speed up native PyTorch execution, whereas ONNX is an entirely separate export format. You do not use torch.compile() before exporting to ONNX; you simply export the standard nn.Module and let ONNX Runtime or TensorRT handle the compilation and optimization.

Conclusion

Learning how to convert a PyTorch model to ONNX format fundamentally shifts how you approach machine learning engineering. You stop viewing models as Python scripts and start viewing them as portable, highly optimizable computational graphs. By strictly managing your opset versions, handling dynamic axes properly, and rigorously validating parity between PyTorch and ONNX Runtime, you eliminate the massive latency overhead of native Python execution.

Take the time to integrate this export process into your CI/CD pipelines. Combine it with graph simplifiers like onnxsim, and use it as the foundation for hardware-specific compilers like TensorRT or OpenVINO. The effort you spend mastering this conversion today will pay off exponentially the next time you need to scale an AI API to millions of requests.