Merged
1 change: 1 addition & 0 deletions codex-rs/codex-api/src/common.rs
@@ -59,6 +59,7 @@ pub enum ResponseEvent {
summary_index: i64,
},
RateLimits(RateLimitSnapshot),
ModelsEtag(String),
}

#[derive(Debug, Serialize, Clone)]
3 changes: 3 additions & 0 deletions codex-rs/codex-api/src/endpoint/chat.rs
@@ -152,6 +152,9 @@ impl Stream for AggregatedStream {
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
}
Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag)))) => {
return Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag))));
}
Poll::Ready(Some(Ok(ResponseEvent::Completed {
response_id,
token_usage,
54 changes: 24 additions & 30 deletions codex-rs/codex-api/src/endpoint/models.rs
@@ -5,6 +5,7 @@ use crate::provider::Provider;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ModelsResponse;
use http::HeaderMap;
use http::Method;
@@ -41,7 +42,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
&self,
client_version: &str,
extra_headers: HeaderMap,
) -> Result<ModelsResponse, ApiError> {
) -> Result<(Vec<ModelInfo>, Option<String>), ApiError> {
let builder = || {
let mut req = self.provider.build_request(Method::GET, self.path());
req.headers.extend(extra_headers.clone());
@@ -66,17 +67,15 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
.and_then(|value| value.to_str().ok())
.map(ToString::to_string);

let ModelsResponse { models, etag } = serde_json::from_slice::<ModelsResponse>(&resp.body)
let ModelsResponse { models } = serde_json::from_slice::<ModelsResponse>(&resp.body)
.map_err(|e| {
ApiError::Stream(format!(
"failed to decode models response: {e}; body: {}",
String::from_utf8_lossy(&resp.body)
))
})?;

let etag = header_etag.unwrap_or(etag);

Ok(ModelsResponse { models, etag })
Ok((models, header_etag))
}
}

@@ -102,16 +101,15 @@ mod tests {
struct CapturingTransport {
last_request: Arc<Mutex<Option<Request>>>,
body: Arc<ModelsResponse>,
etag: Option<String>,
}

impl Default for CapturingTransport {
fn default() -> Self {
Self {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(ModelsResponse {
models: Vec::new(),
etag: String::new(),
}),
body: Arc::new(ModelsResponse { models: Vec::new() }),
etag: None,
}
}
}
@@ -122,8 +120,8 @@ mod tests {
*self.last_request.lock().unwrap() = Some(req);
let body = serde_json::to_vec(&*self.body).unwrap();
let mut headers = HeaderMap::new();
if !self.body.etag.is_empty() {
headers.insert(ETAG, self.body.etag.parse().unwrap());
if let Some(etag) = &self.etag {
headers.insert(ETAG, etag.parse().unwrap());
}
Ok(Response {
status: StatusCode::OK,
@@ -166,14 +164,12 @@ mod tests {

#[tokio::test]
async fn appends_client_version_query() {
let response = ModelsResponse {
models: Vec::new(),
etag: String::new(),
};
let response = ModelsResponse { models: Vec::new() };

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
etag: None,
};

let client = ModelsClient::new(
@@ -182,12 +178,12 @@ mod tests {
DummyAuth,
);

let result = client
let (models, _) = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");

assert_eq!(result.models.len(), 0);
assert_eq!(models.len(), 0);

let url = transport
.last_request
@@ -232,12 +228,12 @@ mod tests {
}))
.unwrap(),
],
etag: String::new(),
};

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
etag: None,
};

let client = ModelsClient::new(
@@ -246,27 +242,25 @@ mod tests {
DummyAuth,
);

let result = client
let (models, _) = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");

assert_eq!(result.models.len(), 1);
assert_eq!(result.models[0].slug, "gpt-test");
assert_eq!(result.models[0].supported_in_api, true);
assert_eq!(result.models[0].priority, 1);
assert_eq!(models.len(), 1);
assert_eq!(models[0].slug, "gpt-test");
assert_eq!(models[0].supported_in_api, true);
assert_eq!(models[0].priority, 1);
}

#[tokio::test]
async fn list_models_includes_etag() {
let response = ModelsResponse {
models: Vec::new(),
etag: "\"abc\"".to_string(),
};
let response = ModelsResponse { models: Vec::new() };

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
etag: Some("\"abc\"".to_string()),
};

let client = ModelsClient::new(
@@ -275,12 +269,12 @@ mod tests {
DummyAuth,
);

let result = client
let (models, etag) = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("request should succeed");

assert_eq!(result.models.len(), 0);
assert_eq!(result.etag, "\"abc\"");
assert_eq!(models.len(), 0);
assert_eq!(etag, Some("\"abc\"".to_string()));
}
}
8 changes: 8 additions & 0 deletions codex-rs/codex-api/src/sse/responses.rs
@@ -51,11 +51,19 @@ pub fn spawn_response_stream(
telemetry: Option<Arc<dyn SseTelemetry>>,
) -> ResponseStream {
let rate_limits = parse_rate_limit(&stream_response.headers);
let models_etag = stream_response
.headers
.get("X-Models-Etag")
.and_then(|v| v.to_str().ok())
.map(ToString::to_string);
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
tokio::spawn(async move {
if let Some(snapshot) = rate_limits {
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
}
if let Some(etag) = models_etag {
let _ = tx_event.send(Ok(ResponseEvent::ModelsEtag(etag))).await;
Collaborator:
bit annoying to have to push everything through events but that's how it is.
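
A minimal, self-contained sketch of the pattern in play (stand-in enum, not the crate's real types; assumes only tokio's mpsc): the header-derived etag rides the same channel as every other event, and the consumer just matches on one more variant.

use tokio::sync::mpsc;

#[derive(Debug)]
enum ResponseEvent {
    ModelsEtag(String), // the variant this PR adds
    Completed,          // stand-in for the other variants
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<ResponseEvent>(16);

    // Producer side: emit the etag parsed from the X-Models-Etag header
    // before the body events, as spawn_response_stream does above.
    tx.send(ResponseEvent::ModelsEtag("\"abc\"".to_string()))
        .await
        .unwrap();
    tx.send(ResponseEvent::Completed).await.unwrap();
    drop(tx);

    // Consumer side: react to the etag without disturbing other events.
    while let Some(event) = rx.recv().await {
        match event {
            ResponseEvent::ModelsEtag(etag) => println!("refresh models if {etag} is new"),
            other => println!("forward {other:?}"),
        }
    }
}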

}
process_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
});

7 changes: 3 additions & 4 deletions codex-rs/codex-api/tests/models_integration.rs
@@ -88,7 +88,6 @@ async fn models_client_hits_models_endpoint() {
reasoning_summary_format: ReasoningSummaryFormat::None,
experimental_supported_tools: Vec::new(),
}],
etag: String::new(),
};

Mock::given(method("GET"))
@@ -104,13 +103,13 @@ async fn models_client_hits_models_endpoint() {
let transport = ReqwestTransport::new(reqwest::Client::new());
let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);

let result = client
let (models, _) = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("models request should succeed");

assert_eq!(result.models.len(), 1);
assert_eq!(result.models[0].slug, "gpt-test");
assert_eq!(models.len(), 1);
assert_eq!(models[0].slug, "gpt-test");

let received = server
.received_requests()
12 changes: 9 additions & 3 deletions codex-rs/core/src/codex.rs
@@ -246,7 +246,9 @@ impl Codex {

let config = Arc::new(config);
if config.features.enabled(Feature::RemoteModels)
&& let Err(err) = models_manager.refresh_available_models(&config).await
&& let Err(err) = models_manager
.refresh_available_models_with_cache(&config)
.await
{
error!("failed to refresh available models: {err:?}");
}
@@ -2611,6 +2613,10 @@ async fn try_run_turn(
// token usage is available to avoid duplicate TokenCount events.
sess.update_rate_limits(&turn_context, snapshot).await;
}
ResponseEvent::ModelsEtag(etag) => {
// Update internal state with latest models etag
sess.services.models_manager.refresh_if_new_etag(etag).await;
}
ResponseEvent::Completed {
response_id: _,
token_usage,
@@ -3138,7 +3144,7 @@ mod tests {
exec_policy,
auth_manager: auth_manager.clone(),
otel_manager: otel_manager.clone(),
models_manager,
models_manager: Arc::clone(&models_manager),
tool_approvals: Mutex::new(ApprovalStore::default()),
skills_manager,
};
@@ -3225,7 +3231,7 @@ mod tests {
exec_policy,
auth_manager: Arc::clone(&auth_manager),
otel_manager: otel_manager.clone(),
models_manager,
models_manager: Arc::clone(&models_manager),
tool_approvals: Mutex::new(ApprovalStore::default()),
skills_manager,
};