Image Classification in Rust with Tch-rs (Torch bindings)

by Didin J. on Nov 03, 2025

Learn how to build an image classification API in Rust using tch (PyTorch bindings) and Axum, with ResNet18, inference, and top-5 predictions.

Rust has rapidly become a favorite among developers who want both performance and safety. While languages like Python dominate the machine learning landscape, Rust is increasingly being used for AI and data-intensive applications — especially when speed, memory efficiency, and low-level control are critical.

In this tutorial, you’ll learn how to perform image classification in Rust using the powerful tch-rs library — the official Rust bindings for LibTorch, the core engine behind PyTorch. You’ll see how to:

  • Set up a Rust project with tch-rs

  • Load a pretrained neural network model such as ResNet18

  • Preprocess and load images into tensors

  • Run predictions and interpret the results

  • Optionally, load your own trained model for inference

By the end of this guide, you’ll have a fully working Rust-based image classifier capable of predicting objects in images with just a few lines of code — all powered by the performance of Torch and the safety of Rust.


Prerequisites

Before diving into code, make sure your system is ready for Rust and Torch development. In this section, we’ll set up the environment, install the required tools, and ensure that everything works smoothly.

What You’ll Need

  • Basic Rust knowledge – You should be comfortable with creating projects using Cargo, working with modules, and handling external crates.

  • Rust installed – Ensure you have the latest stable version of Rust and Cargo. You can check this by running:

     
    rustc --version
    cargo --version

     

    If not installed, download it using Rustup:

     
    curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

     

  • LibTorch installed – tch-rs requires LibTorch, the C++ backend used by PyTorch. You can download the prebuilt binaries from the official PyTorch website.
    Choose the LibTorch C++/Java distribution that matches your OS and CUDA configuration (use CPU-only if you don’t have a GPU).

    For example, on Linux:

     
    wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-latest.zip
    unzip libtorch-shared-with-deps-latest.zip
    export LIBTORCH=$(pwd)/libtorch
    export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH

     

    On macOS, the commands are similar (just adjust the paths accordingly).

    💡 Tip: You can add these environment variables to your .bashrc or .zshrc to make them persistent.

  • C++ Build Tools – Since tch-rs links to LibTorch (a C++ library), you’ll need a working C++ compiler toolchain:

    • macOS: Xcode Command Line Tools (xcode-select --install)

    • Linux: sudo apt install build-essential

    • Windows: Visual Studio Build Tools

  • An Image to Classify – Prepare a test image (e.g., cat.jpg, dog.jpg, or car.jpg) in your project directory for testing.

Optional Tools

  • VS Code or JetBrains RustRover for Rust development

  • Git for version control

  • Python (optional) if you plan to export your own .pt models from PyTorch

Once these prerequisites are in place, you’re ready to create your Rust project and start coding.


Setting Up the Rust Project

With your environment ready, the next step is to create a new Rust project and configure it to use the tch crate, which provides the bindings to the PyTorch C++ library (LibTorch).

Step 1: Create a New Cargo Project

Open your terminal and create a new Rust binary project called rust-image-classifier:

cargo new rust-image-classifier
cd rust-image-classifier

This will create a directory structure like:

rust-image-classifier/
├── Cargo.toml
└── src/
    └── main.rs

Step 2: Add the tch Dependency

Open the Cargo.toml file and add the following under [dependencies]:

[package]
name = "rust-image-classifier"
version = "0.1.0"
edition = "2024"

[dependencies]
tch = "0.22.0"
anyhow = "1.0"

💡 Check the latest version of tch on crates.io and update accordingly if there’s a newer release.

Step 3: Set Up the LibTorch Environment Variables

Tch-rs needs to know where your LibTorch installation is located. Make sure the following environment variables are set in your shell before running the program:

On macOS/Linux (adjust the path to your own LibTorch or Homebrew PyTorch installation):

export LIBTORCH=/opt/homebrew/Cellar/pytorch/2.9.0_1
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH

Note that on macOS the dynamic linker reads DYLD_LIBRARY_PATH rather than LD_LIBRARY_PATH, so set that variable to ${LIBTORCH}/lib as well.

On Windows (Command Prompt):

setx LIBTORCH "C:\path\to\libtorch"
setx PATH "%LIBTORCH%\lib;%PATH%"

You can verify that tch detects LibTorch correctly later by running your Rust program (we’ll check this soon).

Step 4: Verify the Setup

Before moving on, let’s ensure that tch can load and print the Torch version successfully.
Open src/main.rs and replace the contents with:

use tch::{ Tensor, Device };

fn main() {
    // Create a simple tensor using Tensor::from_slice
    let data = [1.0f32, 2.0, 3.0, 4.0, 5.0];
    let tensor = Tensor::from_slice(&data);
    println!("Tensor: {:?}", tensor);

    // Check available device
    let device = Device::cuda_if_available();
    println!("Using device: {:?}", device);

    // Move tensor to the device and perform a simple operation
    let tensor_on_device = tensor.to_device(device);
    let result = &tensor_on_device * 2;
    println!("Tensor * 2 = {:?}", result);
}

Now, run the program:

cargo run

If everything is set up correctly, you’ll see output similar to:

Tensor: [1.0, 2.0, 3.0, 4.0, 5.0]
Using device: Cpu
Tensor * 2 = [2.0, 4.0, 6.0, 8.0, 10.0]

If you have CUDA properly configured, it might show Cuda(0) instead of Cpu.

At this point, your Rust project is ready to use the Torch bindings!


Installing and Using tch-rs

Now that your project is set up and verified, let’s explore how to use the tch crate (Rust bindings for PyTorch) effectively. This section covers what’s inside tch, how it interfaces with LibTorch, and how to use its core modules.

What Is tch-rs?

tch-rs is a Rust wrapper around LibTorch, the C++ backend of PyTorch.
It provides a safe and idiomatic Rust API for performing operations such as:

  • Tensor creation and manipulation

  • Neural network model loading and inference

  • GPU/CPU device handling

  • Access to TorchVision models like ResNet, VGG, and MobileNet

Essentially, tch-rs lets you do inference and training in Rust using the same underlying engine as PyTorch — but with Rust’s safety, performance, and type guarantees.

Core Modules Overview

Some of the most commonly used modules include:

  • tch::Tensor – The core data structure, similar to PyTorch tensors
  • tch::nn – Neural network layers, modules, and the VarStore that holds parameters
  • tch::vision – Pretrained model architectures and image utilities
  • tch::Device – Computation devices (CPU, CUDA)
  • tch::Kind – Tensor data types (Float, Int64, etc.)

Example: Creating and Manipulating Tensors

Let’s try a few tensor operations to get familiar with the API.
Replace your src/main.rs with the following:

use tch::{Tensor, Device, Kind};

fn main() {
    // Create a tensor of random values
    let random_tensor = Tensor::randn([3, 3], (Kind::Float, Device::Cpu));
    println!("Random tensor:\n{:?}", random_tensor);

    // Create a tensor filled with zeros
    let zeros = Tensor::zeros([2, 4], (Kind::Float, Device::Cpu));
    println!("Zeros tensor:\n{:?}", zeros);

    // Perform arithmetic operations
    let ones = Tensor::ones([2, 4], (Kind::Float, Device::Cpu));
    let sum = &zeros + &ones;
    println!("Sum of zeros + ones:\n{:?}", sum);

    // Matrix multiplication
    let a = Tensor::randn([2, 3], (Kind::Float, Device::Cpu));
    let b = Tensor::randn([3, 2], (Kind::Float, Device::Cpu));
    let result = a.matmul(&b);
    println!("Matrix multiplication result:\n{:?}", result);
}

Run it:

cargo run

Expected output (values will vary):

Random tensor:
Tensor[[3, 3], Float]
Zeros tensor:
Tensor[[2, 4], Float]
Sum of zeros + ones:
Tensor[[2, 4], Float]
Matrix multiplication result:
Tensor[[2, 2], Float]

This demonstrates how simple tensor math works in Rust with Torch.

💡 Pro Tip

If you’re familiar with PyTorch in Python, you’ll notice that tch-rs is conceptually similar — most function names and tensor operations are almost identical. This makes transitioning between the two languages straightforward for developers.
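For instance, here is a small sketch that places a few familiar PyTorch-style operations next to their tch-rs equivalents (the values and operations are arbitrary, purely for illustration):

use tch::{Device, Kind, Tensor};

fn main() {
    // Python: x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    let x = Tensor::arange(6i64, (Kind::Float, Device::Cpu)).reshape([2, 3]);

    // Python: y = x.t() @ x
    let y = x.transpose(0, 1).matmul(&x);

    // Python: total = x.sum()
    let total = x.sum(Kind::Float);

    println!("y shape: {:?}, total: {:?}", y.size(), total);
}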


Loading a Pretrained Model (ResNet18)

One of the most powerful features of tch-rs is the ability to use pretrained models directly from TorchVision. These models come ready-trained on large datasets (like ImageNet), meaning you can perform high-quality image classification without building or training anything from scratch.

In this section, we’ll load the ResNet18 model — a lightweight and popular convolutional neural network for image recognition.

Step 1: Import Required Modules

Open src/main.rs and replace the code with:

use anyhow::Result;
use tch::{ Device, nn, vision::resnet };

fn main() -> Result<()> {
    // Choose device (CPU or CUDA)
    let device = Device::cuda_if_available();
    println!("Using device: {:?}", device);

    // Create a variable store
    let vs = nn::VarStore::new(device);

    // Build the ResNet18 architecture (1000 ImageNet classes, randomly initialized weights)
    let _model = resnet::resnet18(&vs.root(), 1000);

    println!("✅ ResNet18 model loaded successfully!");
    Ok(())
}

Step 2: Run the Program

cargo run

Running this builds the ResNet18 graph in memory with randomly initialized weights (tch-rs does not download pretrained weights automatically; we'll load real weights later in this tutorial).
You should see:

Using device: Cpu
✅ ResNet18 model loaded successfully!

If you have a CUDA-compatible GPU and LibTorch CUDA version, it may show:

Using device: Cuda(0)
✅ ResNet18 model loaded successfully!

If you see no extra output, that's simply because the model is constructed successfully but you're not running any inference or printing any tensors yet. The ResNet model itself doesn't log anything internally — it's just constructed and ready to use.

Step 3: Understanding What Happens

  • nn::VarStore manages the model parameters (weights and biases).

  • resnet::resnet18(&vs.root(), 1000) builds the ResNet18 architecture with 1000 output classes (the ImageNet class count). The second argument is the number of classes, not a pretrained flag.

  • The weights start out randomly initialized; to get meaningful predictions you load pretrained weights into the VarStore (or use a TorchScript export), as we'll do later in this tutorial.

  • The model is ready to perform inference (classification) on any input tensor that matches its expected shape and normalization.

Step 4: Optional — Inspect the Model Parameters

The value returned by resnet::resnet18 is a function-like module, so Debug-printing it won't give you a layer-by-layer summary. The parameters themselves are registered in the VarStore, which you can list to confirm the architecture was built as expected. This is useful for debugging or educational purposes.
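A minimal sketch (assuming the vs variable store from the code above) that prints every registered parameter and its shape:

// Each entry is a parameter registered by resnet18, e.g. convolution weights and batch-norm stats.
for (name, tensor) in vs.variables() {
    println!("{name}: {:?}", tensor.size());
}

This is handy for confirming that the VarStore layout matches the weights file you intend to load later.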

💡 Note on Pretrained Weights

Unlike torchvision in Python, tch-rs does not download pretrained weights automatically; resnet::resnet18 only builds the architecture. To run real classifications you need to provide the weights yourself, for example by:

  1. Exporting a TorchScript model (resnet18.pt) from Python, as covered later in this tutorial, or

  2. Converting a PyTorch checkpoint such as resnet18-f37072fd.pth into the .ot format that nn::VarStore::load() understands (pre-converted .ot weight files are also linked from the tch-rs repository's examples).

Either way, inference then runs entirely in Rust from a local weights file, with no re-downloading.


Loading and Preprocessing Images

Now that your ResNet18 model loads successfully, let’s feed it an image and get predictions.

Step 1: Add an Example Image

Put an image file (for example grass.jpg) inside your project folder:

rust-image-classifier/
 ├─ src/
 ├─ Cargo.toml
 └─ grass.jpg

Step 2: Update main.rs

Here’s a complete working example that:

  • Loads the image

  • Converts it into a tensor

  • Runs it through ResNet18

  • Prints the top prediction index

use anyhow::Result;
use tch::{ Device, nn::{ self, ModuleT }, vision::{ imagenet, resnet } };

fn main() -> Result<()> {
    // Choose device (CPU or CUDA)
    let device = Device::cuda_if_available();
    println!("Using device: {:?}", device);

    // Create a variable store
    let vs = nn::VarStore::new(device);

    // Build ResNet18 (1000 ImageNet classes; weights are randomly initialized for now)
    let model = resnet::resnet18(&vs.root(), 1000);

    // Load and preprocess an image (the loader resizes to 224x224 and applies ImageNet normalization)
    let image = imagenet::load_image_and_resize224("grass.jpg")?;

    let input_tensor = image
        .unsqueeze(0) // add batch dimension
        .to_device(device);

    // Run inference
    let output = model.forward_t(&input_tensor, false);
    let predicted = output.argmax(1, false);
    let class_idx = predicted.int64_value(&[0]);

    println!("🧠 Predicted class index: {}", class_idx);

    // Load ImageNet labels manually
    let labels: Vec<&str> = include_str!("imagenet_classes.txt").lines().collect();
    println!("🐾 Predicted class label: {}", labels[class_idx as usize]);

    Ok(())
}

Download the official ImageNet class list file from PyTorch:

https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

Then save it in your src/ folder beside your main file:

src/
 ├── main.rs
 ├── imagenet_classes.txt

Step 3: Run It

cargo run

Output example (note: with the randomly initialized weights used here, the predicted label will be arbitrary; we'll load real weights in a later section):

Using device: Cpu
🧠 Predicted class index: 207
🐾 Predicted class label: golden retriever

🧠 What Happens Here

  • imagenet::load_image_and_resize224() loads the image, resizes it to 224×224, and applies the standard ImageNet mean/std normalization.

  • unsqueeze(0) adds the batch dimension ([1, 3, 224, 224]).

  • forward_t() runs inference.

  • argmax(1, false) picks the index of the top predicted class along the class dimension.

  • The imagenet_classes.txt file provides the 1000 ImageNet category names.
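As an aside, tch also ships its own copy of the ImageNet label list, so you don't strictly need the text file: the imagenet::top() helper maps a probability tensor straight to (probability, label) pairs. A minimal sketch, assuming the model and input_tensor from the code above:

// Softmax the logits, then let tch look up the top-5 ImageNet labels.
let output = model
    .forward_t(&input_tensor, false)
    .softmax(-1, tch::Kind::Float);

for (probability, class) in tch::vision::imagenet::top(&output, 5).iter() {
    println!("{:<30} {:5.2}%", class, 100.0 * probability);
}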


Displaying Top-5 Predictions and Confidence Scores

Now that your model successfully produces output tensors, let’s extract and display the top 5 most likely ImageNet classes with their confidence percentages.

🦀 Code Example

use anyhow::Result;
use tch::{ Device, Kind, Tensor, nn::{ self, ModuleT }, no_grad, vision::{ imagenet, resnet } };

fn main() -> Result<()> {
    // Choose device (CPU or CUDA)
    let device = Device::cuda_if_available();
    println!("Using device: {:?}", device);

    // Create variable store
    let vs = nn::VarStore::new(device);

    // Load ResNet18 (1000 classes for ImageNet)
    let model = resnet::resnet18(&vs.root(), 1000);

    // Load and preprocess an image (resizing and ImageNet normalization are handled by the loader)
    let image = imagenet::load_image_and_resize224("grass.jpg")?;
    let input_tensor = image.unsqueeze(0).to_device(device);

    // Disable gradients during inference
    let output = no_grad(|| model.forward_t(&input_tensor, false));

    // Apply softmax to get probabilities
    let probabilities = output.softmax(-1, Kind::Float);

    // Get top-5 predictions
    let (top_probs, top_indices) = probabilities.topk(5, 1, true, true);

    // Convert tensors to Rust Vecs
    let top_probs: Vec<f32> = top_probs.squeeze().try_into()?;
    let top_indices: Vec<i64> = top_indices.squeeze().try_into()?;

    // Load labels from a local file (imagenet_classes.txt)
    let labels: Vec<&str> = include_str!("imagenet_classes.txt").lines().collect();

    println!("\n🏆 Top-5 Predictions:");
    for (i, (&idx, &prob)) in top_indices.iter().zip(top_probs.iter()).enumerate() {
        println!("{:>2}. {:<30} — {:.2}%", i + 1, labels[idx as usize], prob * 100.0);
    }

    Ok(())
}

🧠 What’s Happening Here

  1. Softmax – converts raw logits (unbounded numbers) into probabilities that sum to 1.
  2. topk(5, 1, true, true) – finds the values and indices of the top 5 probabilities along the class dimension.
  3. Vec conversion – turns the resulting tensors into Rust Vecs for easy printing.
  4. Mapping – uses the ImageNet label list to show readable class names.
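To make the softmax + topk step concrete, here is a tiny standalone sketch on hand-made logits for a 4-class problem:

use tch::{Kind, Tensor};

fn main() {
    // Fake logits for one sample and four classes.
    let logits = Tensor::from_slice(&[2.0f32, 0.5, 1.0, -1.0]).unsqueeze(0); // shape [1, 4]
    let probs = logits.softmax(-1, Kind::Float); // probabilities summing to 1

    // The two largest probabilities and their class indices, sorted descending.
    let (top_p, top_i) = probs.topk(2, 1, true, true);
    let top_p: Vec<f32> = top_p.squeeze().try_into().unwrap();
    let top_i: Vec<i64> = top_i.squeeze().try_into().unwrap();
    println!("top probabilities: {top_p:?}, class indices: {top_i:?}");
}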

📊 Example Output (random weights)

Using device: Cpu

🏆 Top-5 Predictions:
 1. bloodhound                     — 1.13%
 2. rugby ball                     — 0.74%
 3. pizza                          — 0.61%
 4. gong                           — 0.55%
 5. Greater Swiss Mountain dog     — 0.50%

If you load real pretrained weights, these probabilities will reflect meaningful classifications.

💡 Tip

If you later load real pretrained weights (e.g., from a .ot file or a TorchScript export, as shown in the next section), this code will start showing meaningful classification probabilities with no further changes.


Loading Pretrained Weights for Real Predictions

Until now, our model has been running with random weights.
To get meaningful predictions, we’ll load pretrained ResNet18 weights that were trained on the ImageNet dataset.

1️⃣ Export the Pretrained Model from Python

If you have Python with PyTorch installed, run this short script to export the model:

# export_resnet18.py
import torch
import torchvision.models as models

# Load pretrained ResNet18
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()

# Save to TorchScript format (compatible with tch-rs)
example = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model, example)
traced.save("resnet18.pt")

print("✅ Exported model saved as resnet18.pt")

This saves a resnet18.pt file in your project directory — which Rust can load directly using tch.

2️⃣ Load the TorchScript Model in Rust

Now, update your Rust code to use the exported model:

use anyhow::Result;
use tch::{ CModule, Device, no_grad, vision::imagenet, Kind, Tensor };

fn main() -> Result<()> {
    // Choose device (CPU or CUDA)
    let device = Device::cuda_if_available();
    println!("Using device: {:?}", device);

    // Load pretrained model exported from PyTorch
    let model = CModule::load_on_device("resnet18.pt", device)?;

    // Load and preprocess an image (the loader handles resizing and ImageNet normalization)
    let image = imagenet::load_image_and_resize224("grass.jpg")?;
    let input_tensor = image.unsqueeze(0).to_device(device);

    // Run inference with no_grad
    let output = no_grad(|| model.forward_ts(&[input_tensor]))?;

    // Apply softmax to get probabilities
    let probabilities = output.softmax(-1, Kind::Float);

    // Get top-5 predictions
    let (top_probs, top_indices) = probabilities.topk(5, 1, true, true);
    let top_probs: Vec<f32> = top_probs.squeeze().try_into()?;
    let top_indices: Vec<i64> = top_indices.squeeze().try_into()?;

    // Load class labels
    let labels: Vec<&str> = include_str!("imagenet_classes.txt").lines().collect();

    println!("\n🏆 Top-5 Predictions (Pretrained Model):");
    for (i, (&idx, &prob)) in top_indices.iter().zip(top_probs.iter()).enumerate() {
        println!("{:>2}. {:<30} — {:.2}%", i + 1, labels[idx as usize], prob * 100.0);
    }

    Ok(())
}

3️⃣ Explanation

  • CModule::load_on_device() – loads a TorchScript model (compiled from PyTorch) directly onto the chosen device
  • forward_ts() – runs inference with the TorchScript model
  • imagenet::load_image_and_resize224() – loads the image, resizes it to 224×224, and normalizes pixel values with ImageNet's mean and std
  • topk(5, …) – extracts the top-5 predicted classes and their probabilities
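Before wiring the exported model into anything bigger, it can be worth a quick sanity check that the TorchScript module produces the output shape you expect. A minimal sketch, assuming the resnet18.pt file from the export step:

use anyhow::Result;
use tch::{CModule, Device, Kind, Tensor};

fn main() -> Result<()> {
    let device = Device::cuda_if_available();
    let model = CModule::load_on_device("resnet18.pt", device)?;

    // A dummy batch: one all-zeros 224x224 RGB image.
    let dummy = Tensor::zeros([1, 3, 224, 224], (Kind::Float, device));
    let output = model.forward_ts(&[dummy])?;

    // ResNet18 exported from torchvision should emit one logit per ImageNet class.
    assert_eq!(output.size(), [1, 1000]);
    println!("Output shape: {:?}", output.size());
    Ok(())
}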

4️⃣ Example Output

After running this on your test image, you'll now get real predictions:

Using device: Cpu

🏆 Top-5 Predictions (Pretrained Model):
 1. matchstick                     — 9.29%
 2. spotlight                      — 5.38%
 3. nematode                       — 2.87%
 4. lighter                        — 2.76%
 5. digital clock                  — 2.74%

✅ Summary

In this section, you’ve learned how to:

  • Export a pretrained model from PyTorch as TorchScript

  • Load and run it in Rust using the tch crate

  • Perform top-5 image classification with real probabilities


Saving and Reusing the Model (Rust Inference API Example)

Once you have successfully loaded and tested your ResNet18 model with top-5 predictions, it’s time to make it reusable by saving it once and then creating a simple inference API.

This allows you to load your model only once and classify multiple images quickly — just like a real backend service.

🧠 1. Saving the Model in Rust

Although we exported the TorchScript model (resnet18.pt) from Python, you can save an updated version (e.g., fine-tuned or modified weights) directly from Rust using VarStore::save().

use anyhow::Result;
use tch::{nn, vision::resnet, Device};

fn main() -> Result<()> {
    let device = Device::cuda_if_available();
    let vs = nn::VarStore::new(device);

    let _model = resnet::resnet18(&vs.root(), 1000); // same number of classes as before

    // Save the model weights to a file
    vs.save("resnet18_saved.ot")?;
    println!("✅ Model parameters saved to resnet18_saved.ot");

    Ok(())
}

This saves only the model parameters (.ot file). You can later reload them with vs.load("resnet18_saved.ot")?;

🧱 2. Loading and Using the Saved Model

You can load the saved weights and reuse them without retracing or re-exporting.

use anyhow::Result;
use tch::{Device, nn, vision::resnet};

fn main() -> Result<()> {
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);

    // Rebuild the same architecture, then load the saved parameters into it
    let _model = resnet::resnet18(&vs.root(), 1000);
    vs.load("resnet18_saved.ot")?;

    println!("✅ Loaded model from resnet18_saved.ot");
    Ok(())
}

🌐 3. Creating a Simple Rust Inference API

To serve your model via an HTTP endpoint, you can use Axum, a modern async web framework for Rust.

Add this to your Cargo.toml:

[dependencies]
axum = "0.8.6"
axum-extra = { version = "0.12.1", features = ["multipart"] }
tokio = { version = "1", features = ["full"] }
tch = "0.22.0"
anyhow = "1"
serde_json = "1"
image = "0.25"

Then create src/main.rs like this:

use axum::{ Router, extract::State, response::Json, routing::post };
use axum_extra::extract::Multipart;
use serde_json::json;
use std::{ net::SocketAddr, sync::Arc };
use tokio::sync::Mutex;
use tch::{ CModule, Device, Tensor, Kind };

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Load the model and share it using Arc<Mutex<>>
    let model = Arc::new(
        Mutex::new(CModule::load_on_device("resnet18.pt", Device::cuda_if_available())?)
    );

    // Attach the model as global state
    let app = Router::new().route("/predict", post(predict_handler)).with_state(model.clone());

    let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
    println!("🚀 Server running at http://{addr}/predict");

    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}

async fn predict_handler(
    State(model): State<Arc<Mutex<CModule>>>,
    mut multipart: Multipart
) -> Result<Json<serde_json::Value>, (axum::http::StatusCode, Json<serde_json::Value>)> {
    while let Some(field) = multipart
        .next_field().await
        .map_err(|_| {
            (axum::http::StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid form data" })))
        })?
    {
        let data = field
            .bytes().await
            .map_err(|_| {
                (
                    axum::http::StatusCode::BAD_REQUEST,
                    Json(json!({ "error": "Failed to read file bytes" })),
                )
            })?;

        // 🧠 Load and preprocess image
        let image = image::load_from_memory(&data)
            .map_err(|_| {
                (
                    axum::http::StatusCode::BAD_REQUEST,
                    Json(json!({ "error": "Invalid image format" })),
                )
            })?
            .to_rgb8();

        let resized = image::imageops::resize(
            &image,
            224,
            224,
            image::imageops::FilterType::Nearest
        );
        let img_data = resized.into_raw();

        // Convert the HWC u8 pixels into an NCHW float tensor scaled to [0, 1].
        // (For best accuracy you would also apply the ImageNet mean/std normalization here.)
        let tensor =
            Tensor::from_slice(&img_data)
                .view([224, 224, 3])
                .permute(&[2, 0, 1])
                .unsqueeze(0)
                .to_kind(Kind::Float) / 255.0;

        let tensor = tensor.to_device(Device::cuda_if_available());

        let model = model.lock().await;
        let output = model
            .forward_ts(&[tensor])
            .map_err(|_| {
                (
                    axum::http::StatusCode::INTERNAL_SERVER_ERROR,
                    Json(json!({ "error": "Failed to run inference" })),
                )
            })?;

        let probabilities = output.softmax(-1, Kind::Float);
        let (confidence, class_index) = probabilities.max_dim(1, false);

        let confidence = confidence.double_value(&[]);
        let class_index = class_index.int64_value(&[]);

        return Ok(
            Json(
                json!({
            "class_index": class_index,
            "confidence": format!("{:.2}%", confidence * 100.0)
        })
            )
        );
    }

    Err((axum::http::StatusCode::BAD_REQUEST, Json(json!({ "error": "No file uploaded" }))))
}

🧪 4. Running the API

Run the server:

cargo run

Output:

🚀 Server running at http://127.0.0.1:8080/predict

Then test with an image file:

curl -F "file=@grass.jpg" http://127.0.0.1:8080/predict

Example response:

{"class_index":973,"confidence":"17.15%"}

✅ Summary

In this section, you learned how to:

  • Save and reload your model parameters in Rust.

  • Serve predictions via an Axum-based API.

  • Perform inference on uploaded images efficiently.


Conclusion and Next Steps

Congratulations! You’ve just built a complete deep learning inference API in Rust — from loading a pretrained ResNet18 model to serving real-time predictions through an HTTP endpoint. 🚀

Here’s what you achieved step by step:

✅ What You’ve Learned

  1. Setting up Rust and tch (LibTorch)

    • Installed and configured tch = "0.22" for deep learning in Rust.

    • Understood how to load and preprocess image tensors.

  2. Loading and Running a ResNet18 Model

    • Created and used a pretrained ResNet18 architecture.

    • Performed top-5 image classification inference.

  3. Saving and Reusing the Model

    • Exported a TorchScript .pt model.

    • Reloaded it for efficient inference, decoupled from training.

  4. Building an Inference API with Axum

    • Created a RESTful endpoint (POST /predict) for file uploads.

    • Integrated tch inference into a real-world Rust web service.

    • Handled JSON responses, file uploads, and error handling cleanly.

🧠 What Makes This Tutorial Unique

  • Zero Python at runtime — the model runs natively in Rust.

  • Fast and safe — leveraging Rust’s performance and memory safety.

  • Production-ready foundation — you can easily extend it into a full ML microservice.

🚀 Next Steps

Now that your API works locally, here are practical ways to extend it:

  1. 🧩 Add Support for Multiple Models
    Load different TorchScript models (e.g., ResNet, MobileNet, EfficientNet) and select one via a query parameter; a sketch follows this list.

  2. ⚙️ Serve via Docker or Kubernetes
    Package your Rust app with Docker for easy deployment:

     
    FROM rust:1.81 AS builder
    WORKDIR /app
    COPY . .
    RUN cargo build --release
    
    FROM debian:bookworm-slim
    WORKDIR /app
    # The runtime image also needs the LibTorch shared libraries available on its library path.
    COPY --from=builder /app/target/release/rust-image-classifier .
    COPY resnet18.pt .
    EXPOSE 8080
    CMD ["./rust-image-classifier"]

     

  3. 📊 Add Metrics and Logging
    Use crates like tracing or prometheus to monitor API performance.

  4. 🧪 Integrate into a Larger ML System
    Combine this with a frontend or queue-based architecture (RabbitMQ/Kafka) for distributed inference.

  5. ⚡ Try Other tch Models
    Explore the other architectures in tch::vision, such as vgg, densenet, and efficientnet, or load custom models exported from PyTorch.
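For the multiple-models idea in item 1, one possible shape (file names and model names here are placeholders you would swap for your own exports) is to load every TorchScript file at startup into a map and pick one per request:

use std::{collections::HashMap, sync::Arc};
use tch::{CModule, Device};
use tokio::sync::Mutex;

// Hypothetical registry: model name -> loaded TorchScript module.
type ModelRegistry = Arc<HashMap<String, Mutex<CModule>>>;

fn load_models(device: Device) -> anyhow::Result<ModelRegistry> {
    let mut models = HashMap::new();
    for (name, path) in [("resnet18", "resnet18.pt"), ("mobilenet_v2", "mobilenet_v2.pt")] {
        models.insert(name.to_string(), Mutex::new(CModule::load_on_device(path, device)?));
    }
    Ok(Arc::new(models))
}

The /predict handler could then read a ?model=resnet18 query parameter (via axum's Query extractor) and look up the matching entry before running inference.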

✨ Final Thoughts

Rust + tch gives you the best of both worlds:
machine learning power with system-level speed and safety.

You’ve built a foundation that can easily evolve into a production-grade AI inference service.

“Performance, safety, and reliability — that’s what Rust brings to AI.”

You can find the full source code on our GitHub.


Thanks!