Aller au contenu principal
Engineering · Mar 28, 2026

Deep EVM #22: Dependency Injection in Rust — ServiceLocator, Arc, and Trait Objects

OS
Open Soft Team

Engineering Team

The DI Problem in Rust

Dependency injection is a fundamental design principle: a component should receive its dependencies from the outside rather than creating them internally. In Java or C#, DI frameworks like Spring or Autofac handle this automatically. Rust does not have a standard DI framework, and its ownership model makes naive DI patterns awkward.

Consider a service that needs a database connection and a cache:

// Bad: hardcoded dependencies
struct UserService {
    db: PgPool,          // Concrete type, not mockable
    cache: RedisClient,  // Concrete type
}

impl UserService {
    /// Anti-pattern: the service constructs its own infrastructure.
    ///
    /// Connection strings are baked in, callers cannot substitute test
    /// doubles, and any connection failure panics via `unwrap()`. The
    /// function must be `async` because `PgPool::connect` is awaited.
    async fn new() -> Self {
        // Worst pattern: constructing own dependencies
        Self {
            db: PgPool::connect("postgres://...").await.unwrap(),
            cache: RedisClient::new("redis://...").unwrap(),
        }
    }
}

This is untestable. You cannot run unit tests without a real PostgreSQL and Redis instance. Let us fix this with Rust-idiomatic DI.

Pattern 1: Trait Objects with Arc

Define behavior as traits, then accept Arc<dyn Trait> for runtime polymorphism:

use std::sync::Arc;
use async_trait::async_trait;

/// Storage contract for `User` records.
///
/// The `Send + Sync` supertraits let implementations live behind an
/// `Arc<dyn UserRepository>` that is shared between tasks.
#[async_trait]
trait UserRepository: Send + Sync {
    /// Fetches the user with `id`; `Ok(None)` means there is no such user.
    async fn find_by_id(
        &self,
        id: i64,
    ) -> anyhow::Result<Option<User>>;

    /// Persists `user`.
    async fn save(
        &self,
        user: &User,
    ) -> anyhow::Result<()>;

    /// Removes the user with `id`.
    async fn delete(
        &self,
        id: i64,
    ) -> anyhow::Result<()>;
}

/// Byte-oriented key/value cache contract.
///
/// Values are opaque bytes; callers decide how to encode them. The
/// `Send + Sync` supertraits allow sharing via `Arc<dyn CacheStore>`.
#[async_trait]
trait CacheStore: Send + Sync {
    /// Reads `key`; `Ok(None)` signals a miss.
    async fn get(
        &self,
        key: &str,
    ) -> anyhow::Result<Option<Vec<u8>>>;

    /// Stores `value` under `key` with a time-to-live of `ttl_secs` seconds.
    async fn set(
        &self,
        key: &str,
        value: &[u8],
        ttl_secs: u64,
    ) -> anyhow::Result<()>;

    /// Drops any cached value for `key`.
    async fn invalidate(
        &self,
        key: &str,
    ) -> anyhow::Result<()>;
}

/// User lookups with a read-through cache in front of the repository.
///
/// Both dependencies are injected as `Arc<dyn Trait>`, so production wiring
/// and tests can supply different implementations without touching this type.
struct UserService {
    repo: Arc<dyn UserRepository>,
    cache: Arc<dyn CacheStore>,
}

impl UserService {
    /// Lifetime of a cached user entry, in seconds.
    const USER_CACHE_TTL_SECS: u64 = 300;

    /// Builds the service from already-constructed dependencies.
    fn new(
        repo: Arc<dyn UserRepository>,
        cache: Arc<dyn CacheStore>,
    ) -> Self {
        Self { repo, cache }
    }

    /// Returns the user with `id`, or `Ok(None)` if the repository has none.
    ///
    /// The repository is the source of truth and the cache is only an
    /// optimization. A cache read error or an entry that fails to deserialize
    /// is treated as a miss, and failing to populate the cache does not fail
    /// the lookup.
    ///
    /// # Errors
    /// Returns an error only when the repository lookup itself fails.
    async fn get_user(&self, id: i64) -> anyhow::Result<Option<User>> {
        let cache_key = format!("user:{id}");

        // Check cache first. Only a cleanly decoded entry counts as a hit;
        // otherwise a corrupt entry would fail every request until it expired.
        if let Ok(Some(bytes)) = self.cache.get(&cache_key).await {
            if let Ok(user) = serde_json::from_slice::<User>(&bytes) {
                return Ok(Some(user));
            }
        }

        // Cache miss — query database. Repository errors do propagate.
        let user = self.repo.find_by_id(id).await?;

        // Populate the cache best-effort: errors are deliberately ignored here
        // (log them in a real service). Absent users are not cached.
        if let Some(u) = &user {
            if let Ok(bytes) = serde_json::to_vec(u) {
                let _ = self
                    .cache
                    .set(&cache_key, &bytes, Self::USER_CACHE_TTL_SECS)
                    .await;
            }
        }

        Ok(user)
    }
}

Production Implementations

/// `UserRepository` backed by PostgreSQL through a shared `sqlx` pool.
struct PostgresUserRepository {
    pool: PgPool,
}

#[async_trait]
impl UserRepository for PostgresUserRepository {
    /// Selects one row by primary key; `Ok(None)` when nothing matches.
    async fn find_by_id(&self, id: i64) -> anyhow::Result<Option<User>> {
        let query = sqlx::query_as::<_, User>("SELECT id, name, email FROM users WHERE id = $1");
        Ok(query.bind(id).fetch_optional(&self.pool).await?)
    }

    /// Inserts the row, or updates `name`/`email` when the id already exists.
    async fn save(&self, user: &User) -> anyhow::Result<()> {
        let upsert = sqlx::query(
            "INSERT INTO users (id, name, email) VALUES ($1, $2, $3)
             ON CONFLICT (id) DO UPDATE SET name = $2, email = $3"
        );
        upsert
            .bind(user.id)
            .bind(&user.name)
            .bind(&user.email)
            .execute(&self.pool)
            .await?;
        Ok(())
    }

    /// Deletes the row with `id`; the affected-row count is not inspected.
    async fn delete(&self, id: i64) -> anyhow::Result<()> {
        let statement = sqlx::query("DELETE FROM users WHERE id = $1").bind(id);
        statement.execute(&self.pool).await?;
        Ok(())
    }
}

/// `CacheStore` backed by a Redis server.
///
/// NOTE(review): every operation calls `get_multiplexed_async_connection`
/// afresh. A multiplexed connection is designed to be created once and then
/// cloned; if connection setup shows up in profiles, keep one around instead.
struct RedisCacheStore {
    client: redis::Client,
}

#[async_trait]
impl CacheStore for RedisCacheStore {
    /// Issues `GET key`, decoding the reply as `Option<Vec<u8>>`.
    async fn get(&self, key: &str) -> anyhow::Result<Option<Vec<u8>>> {
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        let result: Option<Vec<u8>> = redis::cmd("GET")
            .arg(key)
            .query_async(&mut conn)
            .await?;
        Ok(result)
    }

    /// Issues `SETEX key ttl_secs value`.
    async fn set(&self, key: &str, value: &[u8], ttl_secs: u64) -> anyhow::Result<()> {
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        // The reply is discarded, but its type must still be named. Left
        // unconstrained, it is inferred as `()` only via never-type fallback,
        // which newer compilers warn about and the 2024 edition changes.
        let _: () = redis::cmd("SETEX")
            .arg(key)
            .arg(ttl_secs)
            .arg(value)
            .query_async(&mut conn)
            .await?;
        Ok(())
    }

    /// Issues `DEL key`, ignoring the reply.
    async fn invalidate(&self, key: &str) -> anyhow::Result<()> {
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        // Same reason as in `set`: pin the ignored reply type explicitly.
        let _: () = redis::cmd("DEL").arg(key).query_async(&mut conn).await?;
        Ok(())
    }
}

Test Implementations

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tokio::sync::RwLock;

    /// In-memory `UserRepository` test double, keyed by user id.
    struct MockUserRepository {
        users: RwLock<HashMap<i64, User>>,
    }

    impl MockUserRepository {
        /// An empty repository: every lookup misses.
        fn new() -> Self {
            Self {
                users: RwLock::new(HashMap::new()),
            }
        }

        /// A repository pre-seeded with `users`, indexed by `User::id`.
        fn with_users(users: Vec<User>) -> Self {
            let map: HashMap<i64, User> = users
                .into_iter()
                .map(|u| (u.id, u))
                .collect();
            Self {
                users: RwLock::new(map),
            }
        }
    }

    #[async_trait]
    impl UserRepository for MockUserRepository {
        async fn find_by_id(&self, id: i64) -> anyhow::Result<Option<User>> {
            Ok(self.users.read().await.get(&id).cloned())
        }

        async fn save(&self, user: &User) -> anyhow::Result<()> {
            self.users.write().await.insert(user.id, user.clone());
            Ok(())
        }

        async fn delete(&self, id: i64) -> anyhow::Result<()> {
            self.users.write().await.remove(&id);
            Ok(())
        }
    }

    /// In-memory `CacheStore` test double; TTLs are accepted but ignored.
    struct MockCacheStore {
        data: RwLock<HashMap<String, Vec<u8>>>,
    }

    impl MockCacheStore {
        fn new() -> Self {
            Self {
                data: RwLock::new(HashMap::new()),
            }
        }
    }

    #[async_trait]
    impl CacheStore for MockCacheStore {
        async fn get(&self, key: &str) -> anyhow::Result<Option<Vec<u8>>> {
            Ok(self.data.read().await.get(key).cloned())
        }

        async fn set(&self, key: &str, value: &[u8], _ttl: u64) -> anyhow::Result<()> {
            self.data.write().await.insert(key.to_string(), value.to_vec());
            Ok(())
        }

        async fn invalidate(&self, key: &str) -> anyhow::Result<()> {
            self.data.write().await.remove(key);
            Ok(())
        }
    }

    /// Miss: the user comes from the repository and is written to the cache.
    #[tokio::test]
    async fn test_get_user_cache_miss() {
        let user = User { id: 1, name: "Alice".into(), email: "alice@test.com".into() };
        let repo = Arc::new(MockUserRepository::with_users(vec![user.clone()]));
        let cache = Arc::new(MockCacheStore::new());
        let service = UserService::new(repo, cache.clone());

        let result = service.get_user(1).await.unwrap();
        assert_eq!(result, Some(user));

        // Verify it was cached
        assert!(cache.data.read().await.contains_key("user:1"));
    }

    /// Hit: served from the cache even though the repository is empty.
    #[tokio::test]
    async fn test_get_user_cache_hit() {
        let user = User { id: 2, name: "Bob".into(), email: "bob@test.com".into() };
        let repo = Arc::new(MockUserRepository::new());
        let cache = Arc::new(MockCacheStore::new());
        cache
            .data
            .write()
            .await
            .insert("user:2".to_string(), serde_json::to_vec(&user).unwrap());
        let service = UserService::new(repo, cache);

        let result = service.get_user(2).await.unwrap();
        assert_eq!(result, Some(user));
    }

    /// Unknown id: returns `None` and leaves the cache empty.
    #[tokio::test]
    async fn test_get_user_not_found() {
        let repo = Arc::new(MockUserRepository::new());
        let cache = Arc::new(MockCacheStore::new());
        let service = UserService::new(repo, cache.clone());

        assert_eq!(service.get_user(42).await.unwrap(), None);
        assert!(cache.data.read().await.is_empty());
    }
}

Pattern 2: Generics (Zero-Cost Abstraction)

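For performance-critical paths where dynamic dispatch overhead matters, use generics. One caveat: `#[async_trait]` boxes every future it returns, so even a generic service still pays a heap allocation per async trait call. Removing that too requires native `async fn` in traits (stable since Rust 1.75) instead of `#[async_trait]`: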

struct UserService<R: UserRepository, C: CacheStore> {
    repo: R,
    cache: C,
}

impl<R: UserRepository, C: CacheStore> UserService<R, C> {
    fn new(repo: R, cache: C) -> Self {
        Self { repo, cache }
    }

    async fn get_user(&self, id: i64) -> anyhow::Result<Option<User>> {
        // Same logic, but statically dispatched
        // Compiler monomorphizes for each (R, C) combination
    }
}

Trade-offs:

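| Aspect | `Arc<dyn Trait>` | Generics |
|---|---|---|
| Runtime cost | vtable lookup (~1 ns) | None (static dispatch, inlinable) |
| Binary size | Smaller | Larger (monomorphization) |
| Flexibility | Can swap at runtime | Fixed at compile time |
| Compile time | Faster | Slower |
| Ergonomics | Simpler signatures | Generic bounds propagate |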

For most applications, Arc<dyn Trait> is the right choice. The vtable lookup cost is negligible compared to the database or network calls your service makes. Use generics only in hot loops where nanoseconds matter.

The Composition Root Pattern

Wire everything together at the application entry point:

/// Composition root: creates every concrete dependency once and passes it
/// downward as a trait object or shared handle.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Reads a required environment variable and names it in the error;
    // `VarError`'s own message does not say which variable was missing.
    fn required_env(key: &str) -> anyhow::Result<String> {
        std::env::var(key).map_err(|err| anyhow::anyhow!("{key}: {err}"))
    }

    // Initialize infrastructure
    let db_pool = PgPool::connect(&required_env("DATABASE_URL")?).await?;
    let redis = redis::Client::open(required_env("REDIS_URL")?)?;

    // Build repositories. The `Arc<dyn Trait>` annotations are where the
    // concrete types are erased; everything downstream sees only the traits.
    let user_repo: Arc<dyn UserRepository> = Arc::new(
        PostgresUserRepository { pool: db_pool.clone() }
    );
    let cache: Arc<dyn CacheStore> = Arc::new(
        RedisCacheStore { client: redis }
    );

    // Build services. `Arc::clone` only bumps a refcount, so further services
    // can share the same repository and cache handles.
    let user_service = Arc::new(UserService::new(
        Arc::clone(&user_repo),
        Arc::clone(&cache),
    ));

    // Build app state
    let state = AppState {
        user_service,
        // ... other services
    };

    // Start server
    // NOTE(review): axum 0.8 changed path captures to `/users/{id}`; `:id` is
    // the axum 0.7 form — match this to the axum version in Cargo.toml.
    let app = Router::new()
        .route("/users/:id", get(get_user_handler))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3001").await?;
    axum::serve(listener, app).await?;
    Ok(())
}

All dependencies are created in main() and passed downward. No global state, no service locators called at arbitrary points. This is the purest form of DI.

The ServiceLocator Pattern

For complex systems with dozens of services, passing every dependency through constructors becomes unwieldy. The ServiceLocator pattern provides a registry that components can query:

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;

/// Read-only registry mapping a type to the one service registered under it.
///
/// Built once via [`ServiceLocator::builder`]. Clones share the same map
/// through an `Arc`, so handing a locator to many handlers does not copy it.
#[derive(Clone, Default)]
struct ServiceLocator {
    services: Arc<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}

impl ServiceLocator {
    /// Starts an empty [`ServiceLocatorBuilder`].
    fn builder() -> ServiceLocatorBuilder {
        ServiceLocatorBuilder {
            services: HashMap::new(),
        }
    }

    /// Returns the service registered under exactly `T`, or `None`.
    ///
    /// Lookup is keyed by `TypeId::of::<T>()`, so `T` must be the same type
    /// used at registration: a value registered as `Arc<dyn Repo>` is not
    /// found by asking for `Arc<ConcreteRepo>`, and vice versa.
    fn get<T: 'static + Send + Sync>(&self) -> Option<&T> {
        self.services
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }

    /// Returns the service registered under `T`.
    ///
    /// # Panics
    /// Panics if nothing is registered under `T`. `#[track_caller]` makes the
    /// reported location the call site that requested the missing service.
    #[track_caller]
    fn resolve<T: 'static + Send + Sync>(&self) -> &T {
        // `panic!` sits directly in this body (not in a closure) so that
        // `#[track_caller]` applies to it.
        match self.get::<T>() {
            Some(service) => service,
            None => panic!(
                "Service not registered: {}",
                std::any::type_name::<T>()
            ),
        }
    }
}

/// Collects registrations before freezing them into a [`ServiceLocator`].
#[must_use = "a builder registers nothing until `.build()` is called"]
struct ServiceLocatorBuilder {
    services: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl ServiceLocatorBuilder {
    /// Registers `service` under its exact type `T`.
    ///
    /// A later registration of the same `T` replaces the earlier one. Use a
    /// turbofish (e.g. `register::<Arc<dyn Trait>>(..)`) to register under a
    /// trait-object type rather than the concrete type.
    fn register<T: 'static + Send + Sync>(mut self, service: T) -> Self {
        self.services.insert(TypeId::of::<T>(), Box::new(service));
        self
    }

    /// Freezes the registrations into a shareable, read-only locator.
    fn build(self) -> ServiceLocator {
        ServiceLocator {
            services: Arc::new(self.services),
        }
    }
}

Usage:

// Services are keyed by the exact type they are registered under, so erase
// the concrete repository/cache types up front: handlers must later ask for
// these same `Arc<dyn ...>` types.
let user_repo: Arc<dyn UserRepository> = Arc::new(pg_repo);
let cache_store: Arc<dyn CacheStore> = Arc::new(redis_cache);
let locator = ServiceLocator::builder()
    .register(user_repo)
    .register(cache_store)
    .register::<Arc<UserService>>(Arc::new(user_service))
    .build();

// Later, in a handler — the requested type must match the registered one:
let user_service = locator.resolve::<Arc<UserService>>();

The ServiceLocator trades compile-time safety for runtime flexibility. It is a controlled anti-pattern: use it at the composition root to simplify wiring, but prefer constructor injection within individual services.

Conclusion

Dependency injection in Rust does not require a framework. Traits define contracts, Arc<dyn Trait> provides runtime polymorphism, and the composition root wires everything together. Use mock implementations for testing, generics for hot paths, and the ServiceLocator pattern only when constructor chains become unmanageable. The key principle is the same across all languages: depend on abstractions, not concretions.