// fleetforge_runtime/gateway/openai.rs

use std::fmt;
use std::sync::Arc;

use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use fleetforge_prompt::{ChatMessage, ChatRole, ModelResponse, ModelUsage, ToolSpec};
use fleetforge_telemetry::context::TraceContext;
use fleetforge_trust::Trust;
use reqwest::{Client, Url};
use serde::{Deserialize, Serialize};
use serde_json::{json, Map, Value};
use tracing::debug;

use crate::gateway::LanguageModel;
14
/// API origin used by `OpenAiConfig::new` until the `base_url` setter overrides it.
const DEFAULT_BASE_URL: &str = "https://api.openai.com";
/// Name sent with JSON-schema structured output when the config has no `schema_name`.
const DEFAULT_SCHEMA_NAME: &str = "structured_output";
17
18#[derive(Debug, Clone)]
19pub struct OpenAiConfig {
20    pub api_key: String,
21    pub default_model: String,
22    pub base_url: String,
23    pub organization: Option<String>,
24    pub schema_name: Option<String>,
25    pub client: Option<Client>,
26}
27
28impl OpenAiConfig {
29    pub fn new(api_key: impl Into<String>, default_model: impl Into<String>) -> Self {
30        Self {
31            api_key: api_key.into(),
32            default_model: default_model.into(),
33            base_url: DEFAULT_BASE_URL.to_string(),
34            organization: None,
35            schema_name: None,
36            client: None,
37        }
38    }
39
40    pub fn base_url(mut self, url: impl Into<String>) -> Self {
41        self.base_url = url.into();
42        self
43    }
44
45    pub fn organization(mut self, org: impl Into<String>) -> Self {
46        self.organization = Some(org.into());
47        self
48    }
49
50    pub fn schema_name(mut self, name: impl Into<String>) -> Self {
51        self.schema_name = Some(name.into());
52        self
53    }
54
55    pub fn client(mut self, client: Client) -> Self {
56        self.client = Some(client);
57        self
58    }
59}
60
61#[derive(Clone)]
62pub struct OpenAiLanguageModel {
63    client: Client,
64    config: Arc<OpenAiConfig>,
65    responses_url: Url,
66}
67
68impl OpenAiLanguageModel {
69    pub fn new(config: OpenAiConfig) -> Result<Self> {
70        if config.api_key.trim().is_empty() {
71            return Err(anyhow!("OpenAI api_key must not be empty"));
72        }
73        let client = config
74            .client
75            .clone()
76            .unwrap_or_else(|| Client::builder().build().expect("reqwest client"));
77        let base = Url::parse(&config.base_url).context("invalid OpenAI base_url")?;
78        let responses_url = base
79            .join("/v1/responses")
80            .context("invalid OpenAI responses endpoint")?;
81        Ok(Self {
82            client,
83            config: Arc::new(config),
84            responses_url,
85        })
86    }
87}
88
89#[async_trait]
90impl LanguageModel for OpenAiLanguageModel {
91    async fn chat(
92        &self,
93        messages: &[ChatMessage],
94        tools: Option<&[ToolSpec]>,
95        response_schema: Option<&Value>,
96        strict: bool,
97        params: &Value,
98        trace: &TraceContext,
99    ) -> Result<ModelResponse> {
100        let model = params
101            .get("model")
102            .and_then(Value::as_str)
103            .map(|s| s.to_string())
104            .unwrap_or_else(|| self.config.default_model.clone());
105
106        let provider_messages = map_messages(messages)?;
107        let tool_payload = tools.map(map_tools).transpose()?;
108        let tool_choice = parse_tool_choice(params.get("tool_choice"));
109        let temperature = params.get("temperature").and_then(Value::as_f64);
110        let mut payload = json!({
111            "model": model,
112            "input": provider_messages,
113        });
114
115        if let Some(temp) = temperature {
116            payload["temperature"] = json!(temp);
117        }
118        if let Some(tools_value) = tool_payload {
119            payload["tools"] = Value::Array(tools_value);
120        }
121        if let Some(choice) = tool_choice {
122            payload["tool_choice"] = choice;
123        }
124
125        if let Some(schema) = response_schema {
126            payload["response_format"] = json!({
127                "type": "json_schema",
128                "json_schema": {
129                    "name": self
130                        .config
131                        .schema_name
132                        .clone()
133                        .unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()),
134                    "schema": schema,
135                    "strict": strict,
136                }
137            });
138        } else if strict {
139            payload["response_format"] = json!({ "type": "json_object" });
140        }
141
142        let mut request = self
143            .client
144            .post(self.responses_url.clone())
145            .bearer_auth(&self.config.api_key)
146            .header(reqwest::header::CONTENT_TYPE, "application/json");
147        if let Some(org) = &self.config.organization {
148            request = request.header("OpenAI-Organization", org);
149        }
150        for (key, value) in trace.w3c_headers() {
151            request = request.header(&key, value);
152        }
153
154        let response = request
155            .json(&payload)
156            .send()
157            .await
158            .context("openai request failed")?;
159
160        let status = response.status();
161        let body: Value = response
162            .json()
163            .await
164            .context("failed to decode openai response body")?;
165
166        if !status.is_success() {
167            return Err(anyhow!(
168                "openai error {}: {}",
169                status,
170                body.get("error")
171                    .and_then(|err| err.get("message"))
172                    .and_then(Value::as_str)
173                    .unwrap_or_else(|| body.to_string().as_str())
174            ));
175        }
176
177        let (messages_out, response_json) = parse_responses_output(&body)?;
178
179        let usage = body.get("usage").cloned();
180
181        let usage_metrics = usage.as_ref().map(|u| {
182            let prompt_tokens = u
183                .get("prompt_tokens")
184                .or_else(|| u.get("input_tokens"))
185                .and_then(Value::as_i64);
186            let completion_tokens = u
187                .get("completion_tokens")
188                .or_else(|| u.get("output_tokens"))
189                .and_then(Value::as_i64);
190            let total_tokens = u.get("total_tokens").and_then(Value::as_i64).or_else(|| {
191                match (prompt_tokens, completion_tokens) {
192                    (Some(prompt), Some(completion)) => Some(prompt + completion),
193                    _ => None,
194                }
195            });
196            let cost = u.get("total_cost").and_then(Value::as_f64);
197            ModelUsage {
198                prompt_tokens,
199                completion_tokens,
200                total_tokens,
201                cost,
202            }
203        });
204
205        debug!("openai response usage" = ?usage_metrics);
206
207        Ok(ModelResponse {
208            messages: messages_out,
209            response_json,
210            usage: usage_metrics,
211            provider: Some("openai".to_string()),
212            provider_version: body
213                .get("system_fingerprint")
214                .or_else(|| {
215                    body.get("response")
216                        .and_then(|resp| resp.get("system_fingerprint"))
217                })
218                .and_then(Value::as_str)
219                .map(|s| s.to_string()),
220            raw: Some(body),
221        })
222    }
223}
224
225fn map_messages(messages: &[ChatMessage]) -> Result<Vec<Value>> {
226    messages
227        .iter()
228        .map(|message| {
229            let role = match message.role {
230                ChatRole::System => "system",
231                ChatRole::User => "user",
232                ChatRole::Assistant => "assistant",
233                ChatRole::Tool => "tool",
234            };
235
236            let content_segments = normalize_content_segments(&message.content)?;
237
238            let mut obj = json!({
239                "role": role,
240                "content": content_segments,
241            });
242            if let Some(name) = &message.name {
243                obj["name"] = json!(name);
244            }
245            if let Some(id) = &message.tool_call_id {
246                obj["tool_call_id"] = json!(id);
247            }
248            Ok(obj)
249        })
250        .collect()
251}
252
253fn map_tools(tools: &[ToolSpec]) -> Result<Vec<Value>> {
254    tools
255        .iter()
256        .map(|tool| {
257            if tool.name.is_empty() {
258                return Err(anyhow!("tool name must not be empty"));
259            }
260            Ok(json!({
261                "type": "function",
262                "function": {
263                    "name": tool.name,
264                    "description": tool.description,
265                    "parameters": tool.schema.clone(),
266                }
267            }))
268        })
269        .collect()
270}
271
272fn parse_tool_choice(value: Option<&Value>) -> Option<Value> {
273    let raw = value?;
274    match raw {
275        Value::String(s) => match s.to_ascii_lowercase().as_str() {
276            "auto" => Some(json!("auto")),
277            "none" => Some(json!("none")),
278            other => {
279                if let Some(name) = other.strip_prefix("required:") {
280                    Some(json!({"type":"function","function":{"name": name}}))
281                } else {
282                    None
283                }
284            }
285        },
286        Value::Object(map) => Some(Value::Object(map.clone())),
287        _ => None,
288    }
289}
290
291fn normalize_content_segments(content: &Value) -> Result<Value> {
292    match content {
293        Value::Array(items) => Ok(Value::Array(items.clone())),
294        Value::Null => Ok(Value::Array(vec![json!({
295            "type": "input_text",
296            "text": ""
297        })])),
298        Value::String(text) => Ok(Value::Array(vec![json!({
299            "type": "input_text",
300            "text": text
301        })])),
302        other => {
303            let serialized = serde_json::to_string(other).unwrap_or_else(|_| other.to_string());
304            Ok(Value::Array(vec![json!({
305                "type": "input_text",
306                "text": serialized
307            })]))
308        }
309    }
310}
311
312fn parse_responses_output(body: &Value) -> Result<(Vec<ChatMessage>, Option<Value>)> {
313    let Some(outputs) = body.get("output").and_then(Value::as_array) else {
314        return Err(anyhow!("missing output field in responses payload"));
315    };
316
317    let mut messages = Vec::new();
318    let mut response_json: Option<Value> = None;
319
320    for item in outputs {
321        let role = item
322            .get("role")
323            .and_then(Value::as_str)
324            .unwrap_or("assistant");
325        let role = match role {
326            "system" => ChatRole::System,
327            "user" => ChatRole::User,
328            "tool" => ChatRole::Tool,
329            _ => ChatRole::Assistant,
330        };
331
332        let mut text_parts: Vec<String> = Vec::new();
333        let mut tool_messages: Vec<ChatMessage> = Vec::new();
334        let mut metadata = Map::new();
335
336        if let Some(content) = item.get("content") {
337            metadata.insert("segments".to_string(), content.clone());
338            if let Some(contents) = content.as_array() {
339                for segment in contents {
340                    let segment_type = segment
341                        .get("type")
342                        .and_then(Value::as_str)
343                        .unwrap_or_default();
344                    match segment_type {
345                        "output_text" | "text" | "input_text" => {
346                            if let Some(text) = segment.get("text").and_then(Value::as_str) {
347                                text_parts.push(text.to_string());
348                                if response_json.is_none() {
349                                    if let Ok(candidate) = serde_json::from_str::<Value>(text) {
350                                        if candidate.is_object() || candidate.is_array() {
351                                            response_json = Some(candidate);
352                                        }
353                                    }
354                                }
355                            }
356                        }
357                        "output_json" | "json_object" | "json_schema" => {
358                            if response_json.is_none() {
359                                if let Some(json_value) = segment.get("json") {
360                                    response_json = Some(json_value.clone());
361                                } else if let Some(output_value) = segment.get("output") {
362                                    response_json = Some(output_value.clone());
363                                } else if let Some(text) =
364                                    segment.get("text").and_then(Value::as_str)
365                                {
366                                    if let Ok(candidate) = serde_json::from_str::<Value>(text) {
367                                        response_json = Some(candidate);
368                                    }
369                                }
370                            }
371                        }
372                        "tool_calls" => {
373                            if let Some(tool_calls) =
374                                segment.get("tool_calls").and_then(Value::as_array)
375                            {
376                                for call in tool_calls {
377                                    let tool_call_id = call
378                                        .get("id")
379                                        .and_then(Value::as_str)
380                                        .map(|s| s.to_string());
381                                    let name = call
382                                        .get("function")
383                                        .and_then(|f| f.get("name"))
384                                        .and_then(Value::as_str)
385                                        .map(|s| s.to_string());
386                                    tool_messages.push(ChatMessage {
387                                        role: ChatRole::Assistant,
388                                        content: call.clone(),
389                                        name,
390                                        tool_call_id,
391                                        metadata: Some(json!({ "kind": "tool_call" })),
392                                        trust: Some(Trust::Untrusted),
393                                        trust_origin: None,
394                                    });
395                                }
396                            }
397                        }
398                        "tool_use" | "tool_call" => {
399                            let tool_call_id = segment
400                                .get("id")
401                                .and_then(Value::as_str)
402                                .map(|s| s.to_string());
403                            let name = segment
404                                .get("name")
405                                .and_then(Value::as_str)
406                                .map(|s| s.to_string());
407                            let arguments =
408                                segment.get("arguments").cloned().unwrap_or(Value::Null);
409                            let payload = json!({
410                                "id": tool_call_id,
411                                "name": name,
412                                "arguments": arguments,
413                                "type": segment_type,
414                            });
415                            tool_messages.push(ChatMessage {
416                                role: ChatRole::Assistant,
417                                content: payload,
418                                name,
419                                tool_call_id,
420                                metadata: Some(json!({ "kind": "tool_call" })),
421                                trust: Some(Trust::Untrusted),
422                                trust_origin: None,
423                            });
424                        }
425                        _ => {}
426                    }
427                }
428            }
429        }
430
431        if let Some(status) = item.get("status") {
432            metadata.insert("status".to_string(), status.clone());
433        }
434        if let Some(id) = item.get("id") {
435            metadata.insert("id".to_string(), id.clone());
436        }
437
438        let content_value = if !text_parts.is_empty() {
439            Value::String(text_parts.join(""))
440        } else {
441            item.get("content").cloned().unwrap_or(Value::Null)
442        };
443
444        let mut message = ChatMessage {
445            role,
446            content: content_value,
447            name: item
448                .get("name")
449                .and_then(Value::as_str)
450                .map(|s| s.to_string()),
451            tool_call_id: item
452                .get("id")
453                .and_then(Value::as_str)
454                .map(|s| s.to_string()),
455            metadata: if metadata.is_empty() {
456                None
457            } else {
458                Some(Value::Object(metadata))
459            },
460            trust: Some(Trust::Untrusted),
461            trust_origin: None,
462        };
463
464        messages.push(message);
465        messages.extend(tool_messages);
466    }
467
468    if response_json.is_none() {
469        if let Some(candidate) = body
470            .get("response")
471            .and_then(|resp| resp.get("output"))
472            .and_then(Value::as_array)
473            .and_then(|arr| arr.first())
474            .and_then(|item| item.get("content"))
475            .and_then(Value::as_array)
476            .and_then(|segments| {
477                segments.iter().find_map(|segment| {
478                    segment
479                        .get("json")
480                        .cloned()
481                        .or_else(|| segment.get("output").cloned())
482                })
483            })
484        {
485            response_json = Some(candidate);
486        }
487    }
488
489    Ok((messages, response_json))
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_tool_choice_variants() {
        assert_eq!(
            parse_tool_choice(Some(&Value::String("auto".into()))),
            Some(json!("auto"))
        );
        assert_eq!(
            parse_tool_choice(Some(&Value::String("none".into()))),
            Some(json!("none"))
        );
        assert_eq!(
            parse_tool_choice(Some(&Value::String("required:weather".into()))),
            Some(json!({"type":"function","function":{"name":"weather"}}))
        );
    }

    #[test]
    fn parse_tool_choice_passthrough_and_rejects() {
        assert_eq!(parse_tool_choice(None), None);
        assert_eq!(parse_tool_choice(Some(&json!("AUTO"))), Some(json!("auto")));
        let custom = json!({"type": "function", "name": "weather"});
        assert_eq!(parse_tool_choice(Some(&custom)), Some(custom.clone()));
        assert_eq!(parse_tool_choice(Some(&json!("sometimes"))), None);
        assert_eq!(parse_tool_choice(Some(&json!(42))), None);
    }

    #[test]
    fn normalize_content_segments_wraps_non_arrays() {
        assert_eq!(
            normalize_content_segments(&json!("hi")).unwrap(),
            json!([{"type": "input_text", "text": "hi"}])
        );
        assert_eq!(
            normalize_content_segments(&Value::Null).unwrap(),
            json!([{"type": "input_text", "text": ""}])
        );
        assert_eq!(
            normalize_content_segments(&json!({"a": 1})).unwrap(),
            json!([{"type": "input_text", "text": "{\"a\":1}"}])
        );
        let segments = json!([{"type": "input_image", "image_url": "x"}]);
        assert_eq!(normalize_content_segments(&segments).unwrap(), segments);
    }

    #[test]
    fn parse_responses_output_extracts_text_and_json() {
        let body = json!({
            "output": [{
                "id": "msg_1",
                "role": "assistant",
                "status": "completed",
                "content": [{"type": "output_text", "text": "{\"answer\":42}"}]
            }]
        });
        let (messages, response_json) = parse_responses_output(&body).unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content, json!("{\"answer\":42}"));
        assert_eq!(messages[0].tool_call_id.as_deref(), Some("msg_1"));
        assert_eq!(response_json, Some(json!({"answer": 42})));
    }

    #[test]
    fn parse_responses_output_requires_output_array() {
        assert!(parse_responses_output(&json!({})).is_err());
        assert!(parse_responses_output(&json!({"output": "nope"})).is_err());
    }
}