// fleetforge_runtime/gateway/openai.rs
use std::fmt;
use std::sync::Arc;

use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use fleetforge_prompt::{ChatMessage, ChatRole, ModelResponse, ModelUsage, ToolSpec};
use fleetforge_telemetry::context::TraceContext;
use fleetforge_trust::Trust;
use reqwest::{Client, Url};
use serde::{Deserialize, Serialize};
use serde_json::{json, Map, Value};
use tracing::debug;

use crate::gateway::LanguageModel;
14
/// Base URL used by [`OpenAiConfig::new`] until [`OpenAiConfig::base_url`] overrides it.
const DEFAULT_BASE_URL: &str = "https://api.openai.com";
/// Schema name attached to structured-output requests when
/// [`OpenAiConfig::schema_name`] was never set.
const DEFAULT_SCHEMA_NAME: &str = "structured_output";
17
18#[derive(Debug, Clone)]
19pub struct OpenAiConfig {
20 pub api_key: String,
21 pub default_model: String,
22 pub base_url: String,
23 pub organization: Option<String>,
24 pub schema_name: Option<String>,
25 pub client: Option<Client>,
26}
27
28impl OpenAiConfig {
29 pub fn new(api_key: impl Into<String>, default_model: impl Into<String>) -> Self {
30 Self {
31 api_key: api_key.into(),
32 default_model: default_model.into(),
33 base_url: DEFAULT_BASE_URL.to_string(),
34 organization: None,
35 schema_name: None,
36 client: None,
37 }
38 }
39
40 pub fn base_url(mut self, url: impl Into<String>) -> Self {
41 self.base_url = url.into();
42 self
43 }
44
45 pub fn organization(mut self, org: impl Into<String>) -> Self {
46 self.organization = Some(org.into());
47 self
48 }
49
50 pub fn schema_name(mut self, name: impl Into<String>) -> Self {
51 self.schema_name = Some(name.into());
52 self
53 }
54
55 pub fn client(mut self, client: Client) -> Self {
56 self.client = Some(client);
57 self
58 }
59}
60
61#[derive(Clone)]
62pub struct OpenAiLanguageModel {
63 client: Client,
64 config: Arc<OpenAiConfig>,
65 responses_url: Url,
66}
67
68impl OpenAiLanguageModel {
69 pub fn new(config: OpenAiConfig) -> Result<Self> {
70 if config.api_key.trim().is_empty() {
71 return Err(anyhow!("OpenAI api_key must not be empty"));
72 }
73 let client = config
74 .client
75 .clone()
76 .unwrap_or_else(|| Client::builder().build().expect("reqwest client"));
77 let base = Url::parse(&config.base_url).context("invalid OpenAI base_url")?;
78 let responses_url = base
79 .join("/v1/responses")
80 .context("invalid OpenAI responses endpoint")?;
81 Ok(Self {
82 client,
83 config: Arc::new(config),
84 responses_url,
85 })
86 }
87}
88
89#[async_trait]
90impl LanguageModel for OpenAiLanguageModel {
91 async fn chat(
92 &self,
93 messages: &[ChatMessage],
94 tools: Option<&[ToolSpec]>,
95 response_schema: Option<&Value>,
96 strict: bool,
97 params: &Value,
98 trace: &TraceContext,
99 ) -> Result<ModelResponse> {
100 let model = params
101 .get("model")
102 .and_then(Value::as_str)
103 .map(|s| s.to_string())
104 .unwrap_or_else(|| self.config.default_model.clone());
105
106 let provider_messages = map_messages(messages)?;
107 let tool_payload = tools.map(map_tools).transpose()?;
108 let tool_choice = parse_tool_choice(params.get("tool_choice"));
109 let temperature = params.get("temperature").and_then(Value::as_f64);
110 let mut payload = json!({
111 "model": model,
112 "input": provider_messages,
113 });
114
115 if let Some(temp) = temperature {
116 payload["temperature"] = json!(temp);
117 }
118 if let Some(tools_value) = tool_payload {
119 payload["tools"] = Value::Array(tools_value);
120 }
121 if let Some(choice) = tool_choice {
122 payload["tool_choice"] = choice;
123 }
124
125 if let Some(schema) = response_schema {
126 payload["response_format"] = json!({
127 "type": "json_schema",
128 "json_schema": {
129 "name": self
130 .config
131 .schema_name
132 .clone()
133 .unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()),
134 "schema": schema,
135 "strict": strict,
136 }
137 });
138 } else if strict {
139 payload["response_format"] = json!({ "type": "json_object" });
140 }
141
142 let mut request = self
143 .client
144 .post(self.responses_url.clone())
145 .bearer_auth(&self.config.api_key)
146 .header(reqwest::header::CONTENT_TYPE, "application/json");
147 if let Some(org) = &self.config.organization {
148 request = request.header("OpenAI-Organization", org);
149 }
150 for (key, value) in trace.w3c_headers() {
151 request = request.header(&key, value);
152 }
153
154 let response = request
155 .json(&payload)
156 .send()
157 .await
158 .context("openai request failed")?;
159
160 let status = response.status();
161 let body: Value = response
162 .json()
163 .await
164 .context("failed to decode openai response body")?;
165
166 if !status.is_success() {
167 return Err(anyhow!(
168 "openai error {}: {}",
169 status,
170 body.get("error")
171 .and_then(|err| err.get("message"))
172 .and_then(Value::as_str)
173 .unwrap_or_else(|| body.to_string().as_str())
174 ));
175 }
176
177 let (messages_out, response_json) = parse_responses_output(&body)?;
178
179 let usage = body.get("usage").cloned();
180
181 let usage_metrics = usage.as_ref().map(|u| {
182 let prompt_tokens = u
183 .get("prompt_tokens")
184 .or_else(|| u.get("input_tokens"))
185 .and_then(Value::as_i64);
186 let completion_tokens = u
187 .get("completion_tokens")
188 .or_else(|| u.get("output_tokens"))
189 .and_then(Value::as_i64);
190 let total_tokens = u.get("total_tokens").and_then(Value::as_i64).or_else(|| {
191 match (prompt_tokens, completion_tokens) {
192 (Some(prompt), Some(completion)) => Some(prompt + completion),
193 _ => None,
194 }
195 });
196 let cost = u.get("total_cost").and_then(Value::as_f64);
197 ModelUsage {
198 prompt_tokens,
199 completion_tokens,
200 total_tokens,
201 cost,
202 }
203 });
204
205 debug!("openai response usage" = ?usage_metrics);
206
207 Ok(ModelResponse {
208 messages: messages_out,
209 response_json,
210 usage: usage_metrics,
211 provider: Some("openai".to_string()),
212 provider_version: body
213 .get("system_fingerprint")
214 .or_else(|| {
215 body.get("response")
216 .and_then(|resp| resp.get("system_fingerprint"))
217 })
218 .and_then(Value::as_str)
219 .map(|s| s.to_string()),
220 raw: Some(body),
221 })
222 }
223}
224
225fn map_messages(messages: &[ChatMessage]) -> Result<Vec<Value>> {
226 messages
227 .iter()
228 .map(|message| {
229 let role = match message.role {
230 ChatRole::System => "system",
231 ChatRole::User => "user",
232 ChatRole::Assistant => "assistant",
233 ChatRole::Tool => "tool",
234 };
235
236 let content_segments = normalize_content_segments(&message.content)?;
237
238 let mut obj = json!({
239 "role": role,
240 "content": content_segments,
241 });
242 if let Some(name) = &message.name {
243 obj["name"] = json!(name);
244 }
245 if let Some(id) = &message.tool_call_id {
246 obj["tool_call_id"] = json!(id);
247 }
248 Ok(obj)
249 })
250 .collect()
251}
252
253fn map_tools(tools: &[ToolSpec]) -> Result<Vec<Value>> {
254 tools
255 .iter()
256 .map(|tool| {
257 if tool.name.is_empty() {
258 return Err(anyhow!("tool name must not be empty"));
259 }
260 Ok(json!({
261 "type": "function",
262 "function": {
263 "name": tool.name,
264 "description": tool.description,
265 "parameters": tool.schema.clone(),
266 }
267 }))
268 })
269 .collect()
270}
271
272fn parse_tool_choice(value: Option<&Value>) -> Option<Value> {
273 let raw = value?;
274 match raw {
275 Value::String(s) => match s.to_ascii_lowercase().as_str() {
276 "auto" => Some(json!("auto")),
277 "none" => Some(json!("none")),
278 other => {
279 if let Some(name) = other.strip_prefix("required:") {
280 Some(json!({"type":"function","function":{"name": name}}))
281 } else {
282 None
283 }
284 }
285 },
286 Value::Object(map) => Some(Value::Object(map.clone())),
287 _ => None,
288 }
289}
290
291fn normalize_content_segments(content: &Value) -> Result<Value> {
292 match content {
293 Value::Array(items) => Ok(Value::Array(items.clone())),
294 Value::Null => Ok(Value::Array(vec![json!({
295 "type": "input_text",
296 "text": ""
297 })])),
298 Value::String(text) => Ok(Value::Array(vec![json!({
299 "type": "input_text",
300 "text": text
301 })])),
302 other => {
303 let serialized = serde_json::to_string(other).unwrap_or_else(|_| other.to_string());
304 Ok(Value::Array(vec![json!({
305 "type": "input_text",
306 "text": serialized
307 })]))
308 }
309 }
310}
311
312fn parse_responses_output(body: &Value) -> Result<(Vec<ChatMessage>, Option<Value>)> {
313 let Some(outputs) = body.get("output").and_then(Value::as_array) else {
314 return Err(anyhow!("missing output field in responses payload"));
315 };
316
317 let mut messages = Vec::new();
318 let mut response_json: Option<Value> = None;
319
320 for item in outputs {
321 let role = item
322 .get("role")
323 .and_then(Value::as_str)
324 .unwrap_or("assistant");
325 let role = match role {
326 "system" => ChatRole::System,
327 "user" => ChatRole::User,
328 "tool" => ChatRole::Tool,
329 _ => ChatRole::Assistant,
330 };
331
332 let mut text_parts: Vec<String> = Vec::new();
333 let mut tool_messages: Vec<ChatMessage> = Vec::new();
334 let mut metadata = Map::new();
335
336 if let Some(content) = item.get("content") {
337 metadata.insert("segments".to_string(), content.clone());
338 if let Some(contents) = content.as_array() {
339 for segment in contents {
340 let segment_type = segment
341 .get("type")
342 .and_then(Value::as_str)
343 .unwrap_or_default();
344 match segment_type {
345 "output_text" | "text" | "input_text" => {
346 if let Some(text) = segment.get("text").and_then(Value::as_str) {
347 text_parts.push(text.to_string());
348 if response_json.is_none() {
349 if let Ok(candidate) = serde_json::from_str::<Value>(text) {
350 if candidate.is_object() || candidate.is_array() {
351 response_json = Some(candidate);
352 }
353 }
354 }
355 }
356 }
357 "output_json" | "json_object" | "json_schema" => {
358 if response_json.is_none() {
359 if let Some(json_value) = segment.get("json") {
360 response_json = Some(json_value.clone());
361 } else if let Some(output_value) = segment.get("output") {
362 response_json = Some(output_value.clone());
363 } else if let Some(text) =
364 segment.get("text").and_then(Value::as_str)
365 {
366 if let Ok(candidate) = serde_json::from_str::<Value>(text) {
367 response_json = Some(candidate);
368 }
369 }
370 }
371 }
372 "tool_calls" => {
373 if let Some(tool_calls) =
374 segment.get("tool_calls").and_then(Value::as_array)
375 {
376 for call in tool_calls {
377 let tool_call_id = call
378 .get("id")
379 .and_then(Value::as_str)
380 .map(|s| s.to_string());
381 let name = call
382 .get("function")
383 .and_then(|f| f.get("name"))
384 .and_then(Value::as_str)
385 .map(|s| s.to_string());
386 tool_messages.push(ChatMessage {
387 role: ChatRole::Assistant,
388 content: call.clone(),
389 name,
390 tool_call_id,
391 metadata: Some(json!({ "kind": "tool_call" })),
392 trust: Some(Trust::Untrusted),
393 trust_origin: None,
394 });
395 }
396 }
397 }
398 "tool_use" | "tool_call" => {
399 let tool_call_id = segment
400 .get("id")
401 .and_then(Value::as_str)
402 .map(|s| s.to_string());
403 let name = segment
404 .get("name")
405 .and_then(Value::as_str)
406 .map(|s| s.to_string());
407 let arguments =
408 segment.get("arguments").cloned().unwrap_or(Value::Null);
409 let payload = json!({
410 "id": tool_call_id,
411 "name": name,
412 "arguments": arguments,
413 "type": segment_type,
414 });
415 tool_messages.push(ChatMessage {
416 role: ChatRole::Assistant,
417 content: payload,
418 name,
419 tool_call_id,
420 metadata: Some(json!({ "kind": "tool_call" })),
421 trust: Some(Trust::Untrusted),
422 trust_origin: None,
423 });
424 }
425 _ => {}
426 }
427 }
428 }
429 }
430
431 if let Some(status) = item.get("status") {
432 metadata.insert("status".to_string(), status.clone());
433 }
434 if let Some(id) = item.get("id") {
435 metadata.insert("id".to_string(), id.clone());
436 }
437
438 let content_value = if !text_parts.is_empty() {
439 Value::String(text_parts.join(""))
440 } else {
441 item.get("content").cloned().unwrap_or(Value::Null)
442 };
443
444 let mut message = ChatMessage {
445 role,
446 content: content_value,
447 name: item
448 .get("name")
449 .and_then(Value::as_str)
450 .map(|s| s.to_string()),
451 tool_call_id: item
452 .get("id")
453 .and_then(Value::as_str)
454 .map(|s| s.to_string()),
455 metadata: if metadata.is_empty() {
456 None
457 } else {
458 Some(Value::Object(metadata))
459 },
460 trust: Some(Trust::Untrusted),
461 trust_origin: None,
462 };
463
464 messages.push(message);
465 messages.extend(tool_messages);
466 }
467
468 if response_json.is_none() {
469 if let Some(candidate) = body
470 .get("response")
471 .and_then(|resp| resp.get("output"))
472 .and_then(Value::as_array)
473 .and_then(|arr| arr.first())
474 .and_then(|item| item.get("content"))
475 .and_then(Value::as_array)
476 .and_then(|segments| {
477 segments.iter().find_map(|segment| {
478 segment
479 .get("json")
480 .cloned()
481 .or_else(|| segment.get("output").cloned())
482 })
483 })
484 {
485 response_json = Some(candidate);
486 }
487 }
488
489 Ok((messages, response_json))
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_tool_choice_variants() {
        assert_eq!(
            parse_tool_choice(Some(&Value::String("auto".into()))),
            Some(json!("auto"))
        );
        assert_eq!(
            parse_tool_choice(Some(&Value::String("none".into()))),
            Some(json!("none"))
        );
        assert_eq!(
            parse_tool_choice(Some(&Value::String("required:weather".into()))),
            Some(json!({"type":"function","function":{"name":"weather"}}))
        );
    }

    #[test]
    fn parse_tool_choice_passes_objects_through_and_ignores_unknown_values() {
        let explicit = json!({"type": "function", "function": {"name": "lookup"}});
        assert_eq!(parse_tool_choice(Some(&explicit)), Some(explicit.clone()));
        assert_eq!(parse_tool_choice(Some(&json!("sometimes"))), None);
        assert_eq!(parse_tool_choice(Some(&json!(42))), None);
        assert_eq!(parse_tool_choice(None), None);
    }

    #[test]
    fn normalize_content_segments_wraps_non_array_content() {
        assert_eq!(
            normalize_content_segments(&json!("hi")).unwrap(),
            json!([{"type": "input_text", "text": "hi"}])
        );
        assert_eq!(
            normalize_content_segments(&Value::Null).unwrap(),
            json!([{"type": "input_text", "text": ""}])
        );
        assert_eq!(
            normalize_content_segments(&json!({"a": 1})).unwrap(),
            json!([{"type": "input_text", "text": "{\"a\":1}"}])
        );
        let parts = json!([{"type": "input_text", "text": "kept"}]);
        assert_eq!(normalize_content_segments(&parts).unwrap(), parts);
    }

    #[test]
    fn parse_responses_output_collects_text_and_structured_json() {
        let body = json!({
            "output": [{
                "type": "message",
                "id": "msg_1",
                "status": "completed",
                "role": "assistant",
                "content": [{"type": "output_text", "text": "{\"ok\":true}"}]
            }]
        });
        let (messages, response_json) = parse_responses_output(&body).unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content, json!("{\"ok\":true}"));
        assert_eq!(response_json, Some(json!({"ok": true})));
    }

    #[test]
    fn parse_responses_output_requires_output_array() {
        assert!(parse_responses_output(&json!({})).is_err());
    }
}