Skip to main content

Tutorial: Build a Provider

What you'll build

A custom LLM provider that integrates with an imaginary "AcmeLLM" API. You'll implement the Provider trait, map its errors for the fallback FSM, register it in the provider registry, and test it against mock HTTP responses.

Prerequisites

The Provider Trait

Every LLM provider in ClawDesk implements a single trait:

// In clawdesk-providers/src/trait.rs

/// The single abstraction every ClawDesk LLM backend implements.
///
/// Implementors must be `Send + Sync + 'static`. The async methods are
/// expanded by `#[async_trait]`.
#[async_trait]
pub trait Provider: Send + Sync + 'static {
    /// Unique identifier of this provider.
    fn id(&self) -> ProviderId;

    /// Every model this provider can serve.
    fn supported_models(&self) -> Vec<ModelId>;

    /// Current health of the provider.
    async fn health(&self) -> ProviderHealth;

    /// Forwards `req` to the LLM and yields its reply, or a `ProviderError`
    /// describing why the call failed.
    async fn send(
        &self,
        req: ProviderRequest,
    ) -> Result<ProviderResponse, ProviderError>;
}

The trait is intentionally minimal. The complexity lives in error mapping — the fallback FSM needs to know whether an error is retryable, fatal, or rate-limited.


The Data Types


Step 1: Define the Provider Struct

// clawdesk-providers/src/acme.rs

use std::time::Duration;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use async_trait::async_trait;

use crate::{
Provider, ProviderId, ProviderRequest, ProviderResponse,
ProviderError, ProviderHealth, ModelId, Usage,
};

/// AcmeLLM provider — a custom LLM integration.
///
/// Bundles a pre-built HTTP client with the settings resolved from
/// `AcmeConfig` at construction time.
pub struct AcmeProvider {
    /// HTTP client; its per-request timeout is fixed when it is built.
    client: Client,
    /// Secret sent as a bearer token in the `Authorization` header.
    api_key: String,
    /// API root; endpoint paths such as `/chat/completions` are appended to it.
    base_url: String,
    /// Models reported by `supported_models()`.
    models: Vec<ModelId>,
    /// Timeout value reported back inside `ProviderError::Timeout`.
    timeout: Duration,
}

impl AcmeProvider {
    /// Request timeout used when `AcmeConfig::timeout` is not set.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

    /// API root used when `AcmeConfig::base_url` is not set.
    const DEFAULT_BASE_URL: &'static str = "https://api.acme-llm.example/v1";

    /// Builds a provider from `config`, filling in defaults for unset fields.
    ///
    /// Trailing `/` characters are stripped from the base URL. This keeps
    /// endpoint paths such as `/chat/completions` from producing a `//` in
    /// the request URL.
    ///
    /// # Errors
    /// Returns `ProviderError::Network` if the HTTP client cannot be built.
    pub fn new(config: AcmeConfig) -> Result<Self, ProviderError> {
        // Resolve the timeout once. The client and the value reported in
        // `ProviderError::Timeout` then always agree.
        let timeout = config.timeout.unwrap_or(Self::DEFAULT_TIMEOUT);

        let client = Client::builder()
            .timeout(timeout)
            .build()
            .map_err(|e| ProviderError::Network(e.to_string()))?;

        let base_url = config
            .base_url
            .as_deref()
            .unwrap_or(Self::DEFAULT_BASE_URL)
            .trim_end_matches('/')
            .to_string();

        let models = config.models.unwrap_or_else(|| {
            vec![ModelId::new("acme-fast"), ModelId::new("acme-smart")]
        });

        Ok(Self {
            client,
            api_key: config.api_key,
            base_url,
            models,
            timeout,
        })
    }
}

/// Configuration for the AcmeLLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcmeConfig {
pub api_key: String,
pub base_url: Option<String>,
pub models: Option<Vec<ModelId>>,
pub timeout: Option<Duration>,
}

Step 2: Implement the Provider Trait

#[async_trait]
impl Provider for AcmeProvider {
    /// Identifier under which this provider is registered.
    fn id(&self) -> ProviderId {
        ProviderId::new("acme")
    }

    /// Models resolved at construction (the configured list or built-in defaults).
    fn supported_models(&self) -> Vec<ModelId> {
        self.models.clone()
    }

    /// Probes `GET {base_url}/health`.
    ///
    /// - A 2xx status maps to `Healthy`.
    /// - Any other status maps to `Degraded`, carrying the status text.
    /// - A transport failure maps to `Unhealthy`, carrying the error message.
    async fn health(&self) -> ProviderHealth {
        let result = self
            .client
            .get(format!("{}/health", self.base_url))
            .bearer_auth(&self.api_key)
            .send()
            .await;

        match result {
            Ok(resp) if resp.status().is_success() => ProviderHealth::Healthy,
            Ok(resp) => ProviderHealth::Degraded(format!("Status: {}", resp.status())),
            Err(e) => ProviderHealth::Unhealthy(e.to_string()),
        }
    }

    /// Sends a chat-completion request to Acme and converts the reply.
    ///
    /// # Errors
    /// - Transport failures go through `map_reqwest_error`.
    /// - Non-2xx statuses go through `map_http_error`.
    /// - An unparseable body, or a reply whose `choices` array is empty, yields
    ///   `ProviderError::InvalidResponse`.
    async fn send(
        &self,
        req: ProviderRequest,
    ) -> Result<ProviderResponse, ProviderError> {
        // 1. Build the Acme-specific API request. Without an explicit model, use
        //    the first configured model so the default comes from
        //    `supported_models()`. "acme-fast" is the last resort for an empty list.
        let model = req
            .model
            .clone()
            .or_else(|| self.models.first().cloned())
            .unwrap_or_else(|| ModelId::new("acme-fast"));

        let api_request = AcmeApiRequest {
            model,
            messages: req
                .messages
                .iter()
                .map(|m| AcmeMessage {
                    role: m.role.to_string(),
                    content: m.content.clone(),
                })
                .collect(),
            tools: req
                .tools
                .iter()
                .map(|t| AcmeTool {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters: t.parameters.clone(),
                })
                .collect(),
            temperature: req.temperature,
            max_tokens: req.max_tokens,
        };

        // 2. Send the HTTP request. `.json()` also sets `Content-Type: application/json`.
        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .bearer_auth(&self.api_key)
            .json(&api_request)
            .send()
            .await
            .map_err(|e| self.map_reqwest_error(e))?;

        // 3. Handle HTTP status codes.
        let status = response.status();
        if !status.is_success() {
            return Err(self.map_http_error(status, response).await);
        }

        // 4. Parse the response.
        let api_response: AcmeApiResponse = response.json().await.map_err(|e| {
            ProviderError::InvalidResponse(format!("Failed to parse Acme response: {}", e))
        })?;

        // A well-formed body may still carry no choices. Report that as an
        // invalid response instead of panicking on `choices[0]`.
        let content = api_response
            .choices
            .first()
            .map(|choice| choice.message.content.clone())
            .ok_or_else(|| {
                ProviderError::InvalidResponse("Acme response contained no choices".to_string())
            })?;

        // NOTE(review): `extract_tool_calls` is not defined anywhere in this
        // tutorial. It must be supplied in another `impl AcmeProvider` block.
        let tool_calls = self.extract_tool_calls(&api_response);

        // 5. Convert to ClawDesk's unified response type.
        Ok(ProviderResponse {
            content,
            tool_calls,
            model: api_response.model,
            usage: Usage {
                prompt_tokens: api_response.usage.prompt_tokens,
                completion_tokens: api_response.usage.completion_tokens,
                total_tokens: api_response.usage.total_tokens,
            },
            metadata: Default::default(),
        })
    }
}

Step 3: Error Mapping — The Critical Part

The fallback FSM categorizes errors to decide whether to retry, try another provider, or abort. Your error mapping must be accurate:

impl AcmeProvider {
    /// Back-off reported for a 429 when the server sends no usable `Retry-After`.
    const DEFAULT_RETRY_AFTER: Duration = Duration::from_secs(60);

    /// Map reqwest transport errors to ProviderError.
    ///
    /// Timeouts become `ProviderError::Timeout` carrying the configured
    /// timeout. Every other transport failure becomes `ProviderError::Network`.
    fn map_reqwest_error(&self, err: reqwest::Error) -> ProviderError {
        if err.is_timeout() {
            ProviderError::Timeout(self.timeout)
        } else if err.is_connect() {
            ProviderError::Network(format!("Connection failed: {}", err))
        } else {
            ProviderError::Network(err.to_string())
        }
    }

    /// Map HTTP error responses to ProviderError.
    ///
    /// The response body is read best-effort: an unreadable body becomes `""`.
    async fn map_http_error(
        &self,
        status: reqwest::StatusCode,
        response: reqwest::Response,
    ) -> ProviderError {
        // Headers must be read before `text()` consumes the response.
        let retry_after = Self::parse_retry_after(response.headers());
        let body = response.text().await.unwrap_or_default();
        let code = status.as_u16();

        match code {
            401 | 403 => ProviderError::AuthError(format!("Authentication failed: {}", body)),
            404 => ProviderError::ModelNotFound(
                ModelId::new("unknown"), // Parse from body if possible
            ),
            429 => ProviderError::RateLimit(retry_after.unwrap_or(Self::DEFAULT_RETRY_AFTER)),
            413 => {
                // Parse context length from error body
                ProviderError::ContextTooLong(0, 0) // Parse actual values
            }
            // 5xx and any other unexpected status.
            _ => ProviderError::ServerError(code, body),
        }
    }

    /// Reads a `Retry-After` header given as delay-seconds.
    ///
    /// Returns `None` when the header is absent, not valid text, or in the
    /// HTTP-date form, which this helper does not parse.
    fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
        headers
            .get(reqwest::header::RETRY_AFTER)?
            .to_str()
            .ok()?
            .trim()
            .parse::<u64>()
            .ok()
            .map(Duration::from_secs)
    }
}

Error → FSM Event Mapping

The fallback FSM uses ProviderError::classify() to determine the event:

| `ProviderError` variant | `FallbackEvent` | FSM action |
|---|---|---|
| `RateLimit(_)` | `Throttled` | Wait, then retry same provider |
| `AuthError(_)` | `Fatal` | Skip provider permanently |
| `ModelNotFound(_)` | `Fatal` | Skip provider permanently |
| `ContextTooLong(_, _)` | `Failure(Recoverable)` | Try next provider |
| `Network(_)` | `Failure(Transient)` | Retry same provider once |
| `Timeout(_)` | `Failure(Transient)` | Retry same provider once |
| `ServerError(5xx, _)` | `Failure(Transient)` | Retry, then next provider |
| `InvalidResponse(_)` | `Failure(Recoverable)` | Try next provider |
Accuracy matters

If you classify a fatal error as transient, the FSM will waste time retrying a permanently broken provider. If you classify a transient error as fatal, users lose access to a provider that might recover.


Step 4: Acme-Specific API Types

// Internal API types — not exposed outside this module.
// Request side: serialized as the JSON body of `POST {base_url}/chat/completions`.
// Field order is kept as-is because it determines the order of JSON keys.

/// Body of an Acme chat-completion request.
#[derive(Serialize)]
struct AcmeApiRequest {
    /// Model the request targets.
    model: ModelId,
    /// Conversation messages, in the order supplied by the `ProviderRequest`.
    messages: Vec<AcmeMessage>,
    /// Tool definitions; the key is omitted entirely when the list is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<AcmeTool>,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Completion length cap; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

/// One chat message in Acme's wire format.
#[derive(Serialize)]
struct AcmeMessage {
    /// Role rendered via the ClawDesk role's `to_string()`.
    role: String,
    /// Message text.
    content: String,
}

/// A tool the model may call.
#[derive(Serialize)]
struct AcmeTool {
    /// Tool name.
    name: String,
    /// Human-readable description of the tool.
    description: String,
    /// Parameter definition, forwarded unchanged as JSON.
    parameters: serde_json::Value,
}

/// Successful body returned by `POST {base_url}/chat/completions`.
#[derive(Deserialize)]
struct AcmeApiResponse {
    /// Model that actually produced the reply.
    model: String,
    /// Candidate replies; `send` reads the first one.
    choices: Vec<AcmeChoice>,
    /// Token accounting for the call.
    usage: AcmeUsage,
}

/// A single candidate reply.
#[derive(Deserialize)]
struct AcmeChoice {
    /// The assistant message of this candidate.
    message: AcmeResponseMessage,
}

/// Assistant message inside a choice.
#[derive(Deserialize)]
struct AcmeResponseMessage {
    /// Reply text.
    content: String,
    /// Tool invocations; an empty list when the key is absent.
    #[serde(default)]
    tool_calls: Vec<AcmeToolCall>,
}

/// One tool invocation requested by the model.
#[derive(Deserialize)]
struct AcmeToolCall {
    /// Name of the tool to call.
    name: String,
    /// Arguments as raw JSON.
    arguments: serde_json::Value,
}

/// Token counts reported by Acme.
#[derive(Deserialize)]
struct AcmeUsage {
    /// Tokens consumed by the prompt.
    prompt_tokens: u32,
    /// Tokens generated in the completion.
    completion_tokens: u32,
    /// Tokens counted for the whole call.
    total_tokens: u32,
}

Step 5: Register in the Provider Registry

// In clawdesk-providers/src/registry.rs

use crate::acme::{AcmeProvider, AcmeConfig};

/// Registers every provider enabled in `config` with `registry`.
///
/// # Errors
/// Propagates any `ProviderError` raised while constructing a provider.
pub fn register_providers(
    registry: &mut ProviderRegistry,
    config: &AppConfig,
) -> Result<(), ProviderError> {
    // ... existing providers (Anthropic, OpenAI, etc.) ...

    if let Some(acme_config) = &config.providers.acme {
        let provider = AcmeProvider::new(acme_config.clone())?;
        // Capture what we want to log before `register` takes ownership of
        // `provider`. Calling it afterwards is a use-after-move compile error.
        let models = provider.supported_models();
        registry.register(Box::new(provider));
        tracing::info!(models = ?models, "Registered Acme provider");
    }

    Ok(())
}

Configuration in clawdesk.toml:

[providers.acme]
# Required: AcmeConfig::api_key is a non-optional String.
# NOTE(review): assumes the ClawDesk config loader expands `${VAR}` references
# from the environment — confirm before relying on it.
api_key = "${ACME_API_KEY}"
# Optional: AcmeProvider::new falls back to this same URL when omitted.
base_url = "https://api.acme-llm.example/v1"
# Optional: defaults to ["acme-fast", "acme-smart"] when omitted.
models = ["acme-fast", "acme-smart"]
# Optional: AcmeProvider::new uses a 30-second timeout when omitted.
# NOTE(review): a string like "30s" only loads if AcmeConfig deserializes
# `timeout` from text. A plain `Option<Duration>` field expects
# `timeout = { secs = 30, nanos = 0 }` instead.
timeout = "30s"

Step 6: Testing with Mock LLM Responses

#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Create a provider pointing at the mock server, with a 5-second timeout.
    async fn mock_provider(server: &MockServer) -> AcmeProvider {
        AcmeProvider::new(AcmeConfig {
            api_key: "test-key".to_string(),
            base_url: Some(server.uri()),
            models: None,
            timeout: Some(Duration::from_secs(5)),
        })
        .unwrap()
    }

    /// A 200 reply is converted into a `ProviderResponse` with content and usage.
    #[tokio::test]
    async fn successful_completion() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("Authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "model": "acme-fast",
                "choices": [{
                    "message": {
                        "content": "Hello! How can I help you?",
                        "tool_calls": []
                    }
                }],
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 8,
                    "total_tokens": 18
                }
            })))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;
        let request = ProviderRequest::simple("Hello");

        let response = provider.send(request).await.unwrap();

        assert_eq!(response.content, "Hello! How can I help you?");
        assert_eq!(response.usage.total_tokens, 18);
    }

    /// HTTP 429 maps to `ProviderError::RateLimit`.
    #[tokio::test]
    async fn rate_limit_returns_correct_error() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(429).set_body_string("Rate limit exceeded"))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;
        let result = provider.send(ProviderRequest::simple("test")).await;

        assert!(matches!(result, Err(ProviderError::RateLimit(_))));
    }

    /// HTTP 401 maps to `ProviderError::AuthError`, which must classify as fatal.
    #[tokio::test]
    async fn auth_error_is_fatal() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401).set_body_string("Invalid API key"))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;

        match provider.send(ProviderRequest::simple("test")).await {
            // Bind the matched error directly instead of re-unwrapping `result`.
            Err(err @ ProviderError::AuthError(_)) => {
                // NOTE(review): `ErrorClassification` is not brought into scope
                // by `use super::*` in this tutorial. Import it from wherever
                // `ProviderError::classify()`'s return type is defined.
                assert_eq!(err.classify(), ErrorClassification::Fatal);
            }
            other => panic!("Expected AuthError, got {:?}", other),
        }
    }

    /// HTTP 500 maps to `ProviderError::ServerError(500, _)`.
    #[tokio::test]
    async fn server_error_is_transient() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;
        let result = provider.send(ProviderRequest::simple("test")).await;

        assert!(matches!(result, Err(ProviderError::ServerError(500, _))));
    }

    /// A 200 from `GET /health` reports `ProviderHealth::Healthy`.
    #[tokio::test]
    async fn health_check_healthy() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/health"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;
        assert!(matches!(provider.health().await, ProviderHealth::Healthy));
    }

    /// Tool calls in the reply surface on `ProviderResponse::tool_calls`.
    #[tokio::test]
    async fn tool_calls_are_extracted() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "model": "acme-smart",
                "choices": [{
                    "message": {
                        "content": "",
                        "tool_calls": [{
                            "name": "get_weather",
                            "arguments": {"city": "London"}
                        }]
                    }
                }],
                "usage": {
                    "prompt_tokens": 15,
                    "completion_tokens": 20,
                    "total_tokens": 35
                }
            })))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;
        let response = provider.send(ProviderRequest::simple("weather")).await.unwrap();

        let tool_calls = response.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].name, "get_weather");
    }
}

Provider Implementation Checklist

| Requirement | Status | Notes |
|---|---|---|
| Implement `Provider::send()` | | Maps request/response |
| Implement `Provider::id()` | | Unique identifier |
| Implement `Provider::supported_models()` | | Model list |
| Implement `Provider::health()` | | Health endpoint check |
| Map all HTTP errors to `ProviderError` | | 429, 401, 5xx, etc. |
| Error classification is accurate | | Fatal vs transient |
| Register in `ProviderRegistry` | | Config-driven |
| Unit tests with mock server | | 6 test cases |

What's Next?