Tutorial: Build a Provider
In this tutorial you'll build a custom LLM provider that integrates with an imaginary "AcmeLLM" API. You'll implement the Provider trait, map its errors for the fallback FSM, register it in the provider registry, and test it against mock responses.
Prerequisites
- Completed the Message Flow Tutorial
- Familiarity with the Fallback FSM (recommended)
The Provider Trait
Every LLM provider in ClawDesk implements a single trait:
```rust
// In clawdesk-providers/src/trait.rs
#[async_trait]
pub trait Provider: Send + Sync + 'static {
    /// Send a request to the LLM and return the response.
    async fn send(
        &self,
        req: ProviderRequest,
    ) -> Result<ProviderResponse, ProviderError>;

    /// Returns the provider's unique identifier.
    fn id(&self) -> ProviderId;

    /// Returns the list of models this provider supports.
    fn supported_models(&self) -> Vec<ModelId>;

    /// Returns the provider's current health status.
    async fn health(&self) -> ProviderHealth;
}
```
The trait is intentionally minimal. The complexity lives in error mapping — the fallback FSM needs to know whether an error is retryable, fatal, or rate-limited.
The Data Types
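This part of the tutorial doesn't list the shared types. As a reading aid, the sketch below reconstructs the request, usage, and health types from how the code in the following steps uses them.

It is hypothetical and simplified:

- Message is an invented name.
- role and model are plain Strings rather than ClawDesk's own role and ModelId types.
- Tools are omitted.
- The real definitions in clawdesk-providers may differ.

ProviderRequest::simple mirrors the constructor that the Step 6 tests call.

```rust
// Hypothetical reconstruction inferred from usage in this tutorial;
// the real clawdesk-providers definitions may differ.

/// Token accounting returned with every response.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Result of Provider::health().
#[derive(Debug, Clone, PartialEq)]
pub enum ProviderHealth {
    Healthy,
    /// Reachable, but the health endpoint returned a non-2xx status.
    Degraded(String),
    /// Unreachable: the transport itself failed.
    Unhealthy(String),
}

/// One chat message (simplified: role is a String here).
#[derive(Debug, Clone, PartialEq)]
pub struct Message {
    pub role: String,
    pub content: String,
}

/// Provider-agnostic request (tools omitted in this sketch).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ProviderRequest {
    /// None lets the provider pick its default model.
    pub model: Option<String>,
    pub messages: Vec<Message>,
    pub temperature: Option<f32>,
    pub max_tokens: Option<u32>,
}

impl ProviderRequest {
    /// A single user message with default model and sampling settings.
    pub fn simple(text: &str) -> Self {
        Self {
            messages: vec![Message {
                role: "user".to_string(),
                content: text.to_string(),
            }],
            ..Default::default()
        }
    }
}
```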
Step 1: Define the Provider Struct
```rust
// clawdesk-providers/src/acme.rs
use std::time::Duration;

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use crate::{
    ModelId, Provider, ProviderError, ProviderHealth, ProviderId,
    ProviderRequest, ProviderResponse, Usage,
};

/// Request timeout used when the config doesn't set one.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// AcmeLLM endpoint used when the config doesn't set one.
const DEFAULT_BASE_URL: &str = "https://api.acme-llm.example/v1";

/// AcmeLLM provider: a custom LLM integration.
pub struct AcmeProvider {
    client: Client,
    api_key: String,
    base_url: String,
    models: Vec<ModelId>,
    timeout: Duration,
}

impl AcmeProvider {
    pub fn new(config: AcmeConfig) -> Result<Self, ProviderError> {
        // Resolve the timeout once so the HTTP client and the error mapper agree on it.
        let timeout = config.timeout.unwrap_or(DEFAULT_TIMEOUT);

        let client = Client::builder()
            .timeout(timeout)
            .build()
            .map_err(|e| ProviderError::Network(e.to_string()))?;

        Ok(Self {
            client,
            api_key: config.api_key,
            base_url: config
                .base_url
                .unwrap_or_else(|| DEFAULT_BASE_URL.to_string()),
            models: config.models.unwrap_or_else(|| {
                vec![ModelId::new("acme-fast"), ModelId::new("acme-smart")]
            }),
            timeout,
        })
    }
}

/// Configuration for the AcmeLLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcmeConfig {
    pub api_key: String,
    pub base_url: Option<String>,
    pub models: Option<Vec<ModelId>>,
    /// Human-readable duration such as "30s" (see clawdesk.toml).
    /// serde's built-in Duration format can't parse that, so this assumes
    /// the humantime-serde crate is a dependency.
    #[serde(default, with = "humantime_serde")]
    pub timeout: Option<Duration>,
}
```
Step 2: Implement the Provider Trait
```rust
#[async_trait]
impl Provider for AcmeProvider {
    fn id(&self) -> ProviderId {
        ProviderId::new("acme")
    }

    fn supported_models(&self) -> Vec<ModelId> {
        self.models.clone()
    }

    async fn health(&self) -> ProviderHealth {
        match self
            .client
            .get(format!("{}/health", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await
        {
            Ok(resp) if resp.status().is_success() => ProviderHealth::Healthy,
            Ok(resp) => ProviderHealth::Degraded(format!("Status: {}", resp.status())),
            Err(e) => ProviderHealth::Unhealthy(e.to_string()),
        }
    }

    async fn send(
        &self,
        req: ProviderRequest,
    ) -> Result<ProviderResponse, ProviderError> {
        // 1. Build the Acme-specific API request
        let api_request = AcmeApiRequest {
            model: req
                .model
                .clone()
                .unwrap_or_else(|| ModelId::new("acme-fast")),
            messages: req
                .messages
                .iter()
                .map(|m| AcmeMessage {
                    role: m.role.to_string(),
                    content: m.content.clone(),
                })
                .collect(),
            tools: req
                .tools
                .iter()
                .map(|t| AcmeTool {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters: t.parameters.clone(),
                })
                .collect(),
            temperature: req.temperature,
            max_tokens: req.max_tokens,
        };

        // 2. Send the HTTP request (.json() also sets Content-Type: application/json)
        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&api_request)
            .send()
            .await
            .map_err(|e| self.map_reqwest_error(e))?;

        // 3. Turn non-2xx statuses into typed errors for the fallback FSM
        let status = response.status();
        if !status.is_success() {
            return Err(self.map_http_error(status, response).await);
        }

        // 4. Parse the response body
        let api_response: AcmeApiResponse = response.json().await.map_err(|e| {
            ProviderError::InvalidResponse(format!("Failed to parse Acme response: {}", e))
        })?;

        // 5. Reject an empty choices array instead of panicking on choices[0]
        let choice = api_response.choices.first().ok_or_else(|| {
            ProviderError::InvalidResponse("Acme response contained no choices".to_string())
        })?;

        // 6. Convert to ClawDesk's unified response type.
        //    extract_tool_calls() is a helper on AcmeProvider that this tutorial doesn't show.
        Ok(ProviderResponse {
            content: choice.message.content.clone(),
            tool_calls: self.extract_tool_calls(&api_response),
            model: api_response.model,
            usage: Usage {
                prompt_tokens: api_response.usage.prompt_tokens,
                completion_tokens: api_response.usage.completion_tokens,
                total_tokens: api_response.usage.total_tokens,
            },
            metadata: Default::default(),
        })
    }
}
```
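send() calls self.extract_tool_calls(), which the tutorial never defines. Below is a minimal std-only sketch of what such a helper might do. It assumes the helper reads the first choice and returns None when the model requested no tools. The Step 6 test only shows that it returns Some when a tool call is present.

Several names and shapes are simplified or invented for illustration:

- AcmeChoice is flattened here.
- arguments is kept as a raw JSON string instead of serde_json::Value.
- ToolCall is a hypothetical stand-in for ClawDesk's tool-call type.

```rust
// Hypothetical sketch of extract_tool_calls with simplified, std-only types.

#[derive(Debug, Clone, PartialEq)]
pub struct AcmeToolCall {
    pub name: String,
    /// Raw JSON arguments (serde_json::Value in the real API types).
    pub arguments: String,
}

/// Flattened stand-in for AcmeChoice { message: AcmeResponseMessage }.
#[derive(Debug, Clone)]
pub struct AcmeChoice {
    pub content: String,
    pub tool_calls: Vec<AcmeToolCall>,
}

/// ClawDesk-side tool call (hypothetical shape).
#[derive(Debug, Clone, PartialEq)]
pub struct ToolCall {
    pub name: String,
    pub arguments: String,
}

/// Collect the tool calls from the first choice.
/// Returns None when there are no choices or the model requested no tools.
pub fn extract_tool_calls(choices: &[AcmeChoice]) -> Option<Vec<ToolCall>> {
    let calls: Vec<ToolCall> = choices
        .first()?
        .tool_calls
        .iter()
        .map(|c| ToolCall {
            name: c.name.clone(),
            arguments: c.arguments.clone(),
        })
        .collect();

    if calls.is_empty() {
        None
    } else {
        Some(calls)
    }
}
```

Collapsing an empty list to None gives downstream code a single check for "did the model call a tool?". If ClawDesk's ProviderResponse expects Some(vec![]) instead, drop the final is_empty check.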
Step 3: Error Mapping — The Critical Part
The fallback FSM categorizes errors to decide whether to retry, try another provider, or abort. Your error mapping must be accurate:
```rust
impl AcmeProvider {
    /// Map reqwest transport errors to ProviderError.
    fn map_reqwest_error(&self, err: reqwest::Error) -> ProviderError {
        if err.is_timeout() {
            ProviderError::Timeout(self.timeout)
        } else if err.is_connect() {
            ProviderError::Network(format!("Connection failed: {}", err))
        } else {
            ProviderError::Network(err.to_string())
        }
    }

    /// Map HTTP error responses to ProviderError.
    async fn map_http_error(
        &self,
        status: reqwest::StatusCode,
        response: reqwest::Response,
    ) -> ProviderError {
        // Read headers before text() consumes the response.
        // Only the delay-seconds form of Retry-After is handled; anything else
        // (e.g. an HTTP date) falls back to a 60-second default.
        let retry_after = response
            .headers()
            .get(reqwest::header::RETRY_AFTER)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.trim().parse::<u64>().ok())
            .map(Duration::from_secs)
            .unwrap_or(Duration::from_secs(60));

        let body = response.text().await.unwrap_or_default();

        match status.as_u16() {
            401 | 403 => ProviderError::AuthError(format!("Authentication failed: {}", body)),
            // TODO: parse the model name from the body if Acme includes it
            404 => ProviderError::ModelNotFound(ModelId::new("unknown")),
            429 => ProviderError::RateLimit(retry_after),
            // TODO: parse the actual context-length values from the body
            413 => ProviderError::ContextTooLong(0, 0),
            // Everything else, including 5xx, keeps its status code
            code => ProviderError::ServerError(code, body),
        }
    }
}
```
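The 429 branch uses a 60-second default wait. Retry-After (RFC 9110) can carry either a number of seconds or an HTTP date. Pulling its interpretation into a pure function lets you unit-test it without a mock server.

The sketch below is a hypothetical helper:

- It handles only the seconds form and treats anything else as missing.
- It adds an optional cap, which is our addition and not part of the tutorial. The cap stops one provider from stalling the fallback chain for hours.

```rust
use std::time::Duration;

/// Wait used when no usable Retry-After value is present (matches the 60 s default above).
const DEFAULT_RETRY_AFTER: Duration = Duration::from_secs(60);

/// Interpret a Retry-After header value.
/// Handles the delay-seconds form ("120"). The HTTP-date form and malformed
/// values fall back to the default. The result is capped at `max`.
pub fn parse_retry_after(value: Option<&str>, max: Duration) -> Duration {
    value
        .and_then(|v| v.trim().parse::<u64>().ok())
        .map(Duration::from_secs)
        .unwrap_or(DEFAULT_RETRY_AFTER)
        .min(max)
}
```

Inside map_http_error you would pass response.headers().get(reqwest::header::RETRY_AFTER).and_then(|v| v.to_str().ok()) as the value, before consuming the body.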
Error → FSM Event Mapping
The fallback FSM uses ProviderError::classify() to determine the event:
| ProviderError variant | FallbackEvent | FSM action |
|---|---|---|
| RateLimit(_) | Throttled | Wait, then retry same provider |
| AuthError(_) | Fatal | Skip provider permanently |
| ModelNotFound(_) | Fatal | Skip provider permanently |
| ContextTooLong(_, _) | Failure(Recoverable) | Try next provider |
| Network(_) | Failure(Transient) | Retry same provider once |
| Timeout(_) | Failure(Transient) | Retry same provider once |
| ServerError(5xx, _) | Failure(Transient) | Retry, then next provider |
| InvalidResponse(_) | Failure(Recoverable) | Try next provider |
If you classify a fatal error as transient, the FSM will waste time retrying a permanently broken provider. If you classify a transient error as fatal, users lose access to a provider that might recover.
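One way to make the table executable is a single match with one arm per row. A reviewer can then check the classification against the table line by line.

The sketch below is hypothetical and self-contained:

- ProviderError is simplified: ModelId becomes String, and the ContextTooLong fields are assumed to be usize.
- Throttled carries the wait duration.
- Non-5xx ServerError codes map to Recoverable. This is an assumption, because the table doesn't cover them.
- The names follow the table's FallbackEvent column. The Step 6 test refers to an ErrorClassification type, so ClawDesk's actual return type may differ.

```rust
use std::time::Duration;

/// Simplified ProviderError mirroring the variants used in this tutorial.
#[derive(Debug)]
pub enum ProviderError {
    RateLimit(Duration),
    AuthError(String),
    ModelNotFound(String),
    ContextTooLong(usize, usize),
    Network(String),
    Timeout(Duration),
    ServerError(u16, String),
    InvalidResponse(String),
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailureKind {
    Transient,
    Recoverable,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FallbackEvent {
    /// Wait for the given duration, then retry the same provider.
    Throttled(Duration),
    /// Skip the provider permanently.
    Fatal,
    Failure(FailureKind),
}

impl ProviderError {
    /// Classify an error; each arm corresponds to rows of the table above.
    pub fn classify(&self) -> FallbackEvent {
        match self {
            ProviderError::RateLimit(wait) => FallbackEvent::Throttled(*wait),
            ProviderError::AuthError(_) | ProviderError::ModelNotFound(_) => FallbackEvent::Fatal,
            ProviderError::ContextTooLong(..) | ProviderError::InvalidResponse(_) => {
                FallbackEvent::Failure(FailureKind::Recoverable)
            }
            ProviderError::Network(_) | ProviderError::Timeout(_) => {
                FallbackEvent::Failure(FailureKind::Transient)
            }
            ProviderError::ServerError(500..=599, _) => FallbackEvent::Failure(FailureKind::Transient),
            // Not covered by the table: treat other codes as "try the next provider" (assumption).
            ProviderError::ServerError(..) => FallbackEvent::Failure(FailureKind::Recoverable),
        }
    }
}
```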
Step 4: Acme-Specific API Types
```rust
// Internal API types: not exposed outside this module
#[derive(Serialize)]
struct AcmeApiRequest {
    model: ModelId,
    messages: Vec<AcmeMessage>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<AcmeTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

#[derive(Serialize)]
struct AcmeMessage {
    role: String,
    content: String,
}

#[derive(Serialize)]
struct AcmeTool {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize)]
struct AcmeApiResponse {
    model: String,
    choices: Vec<AcmeChoice>,
    usage: AcmeUsage,
}

#[derive(Deserialize)]
struct AcmeChoice {
    message: AcmeResponseMessage,
}

#[derive(Deserialize)]
struct AcmeResponseMessage {
    content: String,
    #[serde(default)]
    tool_calls: Vec<AcmeToolCall>,
}

#[derive(Deserialize)]
struct AcmeToolCall {
    name: String,
    arguments: serde_json::Value,
}

#[derive(Deserialize)]
struct AcmeUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
```
Step 5: Register in the Provider Registry
```rust
// In clawdesk-providers/src/registry.rs
use crate::acme::{AcmeConfig, AcmeProvider};

pub fn register_providers(
    registry: &mut ProviderRegistry,
    config: &AppConfig,
) -> Result<(), ProviderError> {
    // ... existing providers (Anthropic, OpenAI, etc.) ...

    if let Some(acme_config) = &config.providers.acme {
        let provider = AcmeProvider::new(acme_config.clone())?;
        // Capture what we log first: register() takes ownership of the provider.
        let models = provider.supported_models();
        registry.register(Box::new(provider));
        tracing::info!(models = ?models, "Registered Acme provider");
    }

    Ok(())
}
```
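ProviderRegistry's internals aren't covered here. To show what registration typically involves, here is a hypothetical minimal registry that keys providers by id() and indexes them by supported model.

Two simplifications keep the sketch compilable on its own:

- The Provider trait is reduced to a synchronous two-method stand-in.
- FakeAcme is a test double, not the AcmeProvider above.

```rust
use std::collections::HashMap;

/// Minimal synchronous stand-in for the Provider trait.
pub trait Provider {
    fn id(&self) -> String;
    fn supported_models(&self) -> Vec<String>;
}

/// Hypothetical registry: providers keyed by id, plus a model -> provider-id index.
#[derive(Default)]
pub struct ProviderRegistry {
    providers: HashMap<String, Box<dyn Provider>>,
    by_model: HashMap<String, String>,
}

impl ProviderRegistry {
    /// Register a provider. Returns false (keeping the existing one) on a duplicate id.
    pub fn register(&mut self, provider: Box<dyn Provider>) -> bool {
        let id = provider.id();
        if self.providers.contains_key(&id) {
            return false;
        }
        for model in provider.supported_models() {
            // The first provider to claim a model keeps it.
            self.by_model.entry(model).or_insert_with(|| id.clone());
        }
        self.providers.insert(id, provider);
        true
    }

    /// Find the provider responsible for a model, if any.
    pub fn for_model(&self, model: &str) -> Option<&dyn Provider> {
        let id = self.by_model.get(model)?;
        self.providers.get(id).map(|p| p.as_ref())
    }
}

/// Test double standing in for AcmeProvider.
pub struct FakeAcme;

impl Provider for FakeAcme {
    fn id(&self) -> String {
        "acme".to_string()
    }
    fn supported_models(&self) -> Vec<String> {
        vec!["acme-fast".to_string(), "acme-smart".to_string()]
    }
}
```

Rejecting duplicate ids and letting the first provider claim a model are design choices for this sketch. A real registry might instead rank providers by priority or health.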
Configuration in clawdesk.toml:
```toml
[providers.acme]
api_key = "${ACME_API_KEY}"
# Optional: each key below falls back to the default in AcmeProvider::new()
base_url = "https://api.acme-llm.example/v1"
models = ["acme-fast", "acme-smart"]
timeout = "30s"
```
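The timeout = "30s" line is a human-readable duration. serde's built-in Duration support doesn't accept that form; it expects a structure with secs and nanos fields. The config type therefore needs a custom deserializer, for example one from the humantime-serde crate.

If you'd rather not add a dependency, a minimal hand-rolled parser might look like this. It is a hypothetical helper that accepts a whole number with an optional ms, s, m, or h suffix. A bare number is read as seconds.

```rust
use std::time::Duration;

/// Parse strings like "500ms", "30s", "5m", "1h" into a Duration.
/// Whole numbers and a single unit suffix only; a bare number means seconds.
pub fn parse_duration(input: &str) -> Result<Duration, String> {
    let s = input.trim();
    // Split at the first non-digit: "30s" -> ("30", "s").
    let split = s.find(|c: char| !c.is_ascii_digit()).unwrap_or(s.len());
    let (digits, unit) = s.split_at(split);

    let value: u64 = digits
        .parse()
        .map_err(|_| format!("invalid duration {input:?}: expected a leading number"))?;

    let secs_per_unit: u64 = match unit.trim() {
        "ms" => return Ok(Duration::from_millis(value)),
        "" | "s" => 1,
        "m" => 60,
        "h" => 3600,
        other => return Err(format!("invalid duration {input:?}: unknown unit {other:?}")),
    };

    // checked_mul guards against overflow for very large minute/hour values.
    value
        .checked_mul(secs_per_unit)
        .map(Duration::from_secs)
        .ok_or_else(|| format!("invalid duration {input:?}: value too large"))
}
```

To use it, wrap it in a small function and pass that function to serde's deserialize_with attribute on the timeout field.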
Step 6: Testing with Mock LLM Responses
```rust
#[cfg(test)]
mod tests {
    use super::*;
    // Assumed location of the classification type returned by ProviderError::classify().
    use crate::ErrorClassification;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Create a provider pointing at the mock server.
    async fn mock_provider(server: &MockServer) -> AcmeProvider {
        AcmeProvider::new(AcmeConfig {
            api_key: "test-key".to_string(),
            base_url: Some(server.uri()),
            models: None,
            timeout: Some(Duration::from_secs(5)),
        })
        .unwrap()
    }

    #[tokio::test]
    async fn successful_completion() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("Authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "model": "acme-fast",
                "choices": [{
                    "message": {
                        "content": "Hello! How can I help you?",
                        "tool_calls": []
                    }
                }],
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 8,
                    "total_tokens": 18
                }
            })))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;
        let request = ProviderRequest::simple("Hello");
        let response = provider.send(request).await.unwrap();

        assert_eq!(response.content, "Hello! How can I help you?");
        assert_eq!(response.usage.total_tokens, 18);
    }

    #[tokio::test]
    async fn rate_limit_returns_correct_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(429).set_body_string("Rate limit exceeded"))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;
        let result = provider.send(ProviderRequest::simple("test")).await;
        assert!(matches!(result, Err(ProviderError::RateLimit(_))));
    }

    #[tokio::test]
    async fn auth_error_is_fatal() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401).set_body_string("Invalid API key"))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;
        let err = match provider.send(ProviderRequest::simple("test")).await {
            Err(e) => e,
            Ok(_) => panic!("expected AuthError, got a successful response"),
        };

        assert!(matches!(err, ProviderError::AuthError(_)), "expected AuthError, got {:?}", err);
        // The FSM must skip this provider permanently rather than retry it.
        assert_eq!(err.classify(), ErrorClassification::Fatal);
    }

    #[tokio::test]
    async fn server_error_is_transient() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;
        let result = provider.send(ProviderRequest::simple("test")).await;
        assert!(matches!(result, Err(ProviderError::ServerError(500, _))));
    }

    #[tokio::test]
    async fn health_check_healthy() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/health"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;
        assert!(matches!(provider.health().await, ProviderHealth::Healthy));
    }

    #[tokio::test]
    async fn tool_calls_are_extracted() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "model": "acme-smart",
                "choices": [{
                    "message": {
                        "content": "",
                        "tool_calls": [{
                            "name": "get_weather",
                            "arguments": {"city": "London"}
                        }]
                    }
                }],
                "usage": {
                    "prompt_tokens": 15,
                    "completion_tokens": 20,
                    "total_tokens": 35
                }
            })))
            .mount(&server)
            .await;

        let provider = mock_provider(&server).await;
        let response = provider.send(ProviderRequest::simple("weather")).await.unwrap();
        let tool_calls = response.tool_calls.expect("expected tool calls in the response");

        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].name, "get_weather");
    }
}
```
Provider Implementation Checklist
| Requirement | Status | Notes |
|---|---|---|
| Implement Provider::send() | ✅ | Maps request/response |
| Implement Provider::id() | ✅ | Unique identifier |
| Implement Provider::supported_models() | ✅ | Model list |
| Implement Provider::health() | ✅ | Health endpoint check |
| Map all HTTP errors to ProviderError | ✅ | 429, 401, 5xx, etc. |
| Error classification is accurate | ✅ | Fatal vs transient |
| Register in ProviderRegistry | ✅ | Config-driven |
| Unit tests with mock server | ✅ | 6 test cases |
What's Next?
- Build a Skill — create tool bindings that the provider can invoke
- Fallback FSM Deep Dive — full analysis of the state machine your error mapping feeds into
- Adding a Provider (Contributing Guide) — review checklist for merging