1- use crate :: config:: { ProviderConfig , RetryPolicy } ;
1+ use crate :: config:: RetryPolicy ;
22use crate :: error:: LlmError ;
33use crate :: types:: {
44 ChatCompletion , ChatRequest , ChatResponse , ChatStreamEvent , StreamConfig , StreamResult ,
55} ;
66use futures:: Stream ;
77use metrics:: { counter, histogram} ;
88
9- use std:: collections:: HashMap ;
109use std:: pin:: Pin ;
1110use std:: time:: { Duration , Instant } ;
1211
@@ -170,325 +169,9 @@ where
170169 }
171170}
172171
/// Enhanced HTTP client that handles timeout and retries
#[derive(Clone)]
pub struct EnhancedHttpClient {
    // Underlying reqwest client, built with the provider-configured request timeout.
    client: reqwest::Client,
    // Policy deciding how long to wait between retry attempts (fixed, exponential, or API-guided).
    retry_policy: RetryPolicy,
    // Upper bound on the number of retries per request (not counting the initial attempt).
    max_retries: u32,
}
180-
181- impl EnhancedHttpClient {
182- pub fn new < C : ProviderConfig > ( config : & C ) -> Result < Self , LlmError > {
183- let client = reqwest:: Client :: builder ( )
184- . timeout ( config. timeout ( ) )
185- . build ( )
186- . map_err ( |e| LlmError :: configuration ( format ! ( "Failed to create HTTP client: {e}" ) ) ) ?;
187-
188- Ok ( Self {
189- client,
190- retry_policy : config. retry_policy ( ) ,
191- max_retries : config. max_retries ( ) ,
192- } )
193- }
194-
195- /// Execute a GET request with retry logic
196- pub async fn get_with_retry (
197- & self ,
198- url : & str ,
199- headers : & HashMap < String , String > ,
200- ) -> Result < reqwest:: Response , LlmError > {
201- let mut attempt = 0 ;
202-
203- loop {
204- let mut req = self . client . get ( url) ;
205- for ( key, value) in headers {
206- req = req. header ( key, value) ;
207- }
208-
209- match req. send ( ) . await {
210- Ok ( response) => {
211- if response. status ( ) . is_success ( ) {
212- if attempt > 0 {
213- log:: info!( "GET request to {url} succeeded after {attempt} retries" ) ;
214- }
215- return Ok ( response) ;
216- }
217-
218- // Check if we should retry based on status code
219- let status_code = response. status ( ) . as_u16 ( ) ;
220- let should_retry = status_code >= 500 || status_code == 429 ;
221-
222- if !should_retry || attempt >= self . max_retries {
223- if !should_retry {
224- log:: debug!(
225- "GET request to {url} returned non-retryable status: {status_code}"
226- ) ;
227- } else {
228- log:: warn!(
229- "GET request to {} failed after {} retries, status: {}" ,
230- url,
231- self . max_retries,
232- status_code
233- ) ;
234- }
235- return Ok ( response) ; // Return the error response
236- }
237-
238- // Calculate delay and retry
239- let delay = self . calculate_delay ( attempt, Some ( & response) ) . await ;
240- let retry_info = RetryInfo {
241- attempt : attempt + 1 ,
242- max_retries : self . max_retries ,
243- delay,
244- reason : format ! ( "HTTP {status_code}" ) ,
245- response_status : Some ( status_code) ,
246- } ;
247-
248- log:: info!(
249- "Retrying GET request to {} (attempt {}/{}): {} - waiting {:?}" ,
250- url,
251- retry_info. attempt,
252- retry_info. max_retries,
253- retry_info. reason,
254- retry_info. delay
255- ) ;
256-
257- tokio:: time:: sleep ( delay) . await ;
258- attempt += 1 ;
259- }
260- Err ( e) => {
261- // Retry on network errors
262- if attempt >= self . max_retries {
263- log:: error!(
264- "GET request to {} failed after {} retries due to network error: {}" ,
265- url,
266- self . max_retries,
267- e
268- ) ;
269- return Err ( LlmError :: network ( format ! (
270- "Request failed after {} retries: {}" ,
271- self . max_retries, e
272- ) ) ) ;
273- }
274-
275- let delay = self . calculate_delay ( attempt, None ) . await ;
276- let retry_info = RetryInfo {
277- attempt : attempt + 1 ,
278- max_retries : self . max_retries ,
279- delay,
280- reason : format ! ( "Network error: {e}" ) ,
281- response_status : None ,
282- } ;
283-
284- log:: warn!(
285- "Retrying GET request to {} due to network error (attempt {}/{}): {} - waiting {:?}" ,
286- url,
287- retry_info. attempt,
288- retry_info. max_retries,
289- retry_info. reason,
290- retry_info. delay
291- ) ;
292-
293- tokio:: time:: sleep ( delay) . await ;
294- attempt += 1 ;
295- }
296- }
297- }
298- }
299-
300- /// Execute a POST request with retry logic
301- pub async fn post_with_retry (
302- & self ,
303- url : & str ,
304- headers : & HashMap < String , String > ,
305- body : serde_json:: Value ,
306- ) -> Result < reqwest:: Response , LlmError > {
307- let mut attempt = 0 ;
308-
309- loop {
310- let mut req = self . client . post ( url) ;
311- for ( key, value) in headers {
312- req = req. header ( key, value) ;
313- }
314- req = req. json ( & body) ;
315-
316- match req. send ( ) . await {
317- Ok ( response) => {
318- if response. status ( ) . is_success ( ) {
319- if attempt > 0 {
320- log:: info!( "POST request to {url} succeeded after {attempt} retries" ) ;
321- }
322- return Ok ( response) ;
323- }
324-
325- // Check if we should retry based on status code
326- let status_code = response. status ( ) . as_u16 ( ) ;
327- let should_retry = status_code >= 500 || status_code == 429 ;
328-
329- if !should_retry || attempt >= self . max_retries {
330- if !should_retry {
331- log:: debug!(
332- "POST request to {url} returned non-retryable status: {status_code}"
333- ) ;
334- } else {
335- log:: warn!(
336- "POST request to {} failed after {} retries, status: {}" ,
337- url,
338- self . max_retries,
339- status_code
340- ) ;
341- }
342- return Ok ( response) ; // Return the error response
343- }
344-
345- // Calculate delay and retry
346- let delay = self . calculate_delay ( attempt, Some ( & response) ) . await ;
347- let retry_info = RetryInfo {
348- attempt : attempt + 1 ,
349- max_retries : self . max_retries ,
350- delay,
351- reason : format ! ( "HTTP {status_code}" ) ,
352- response_status : Some ( status_code) ,
353- } ;
354-
355- log:: info!(
356- "Retrying POST request to {} (attempt {}/{}): {} - waiting {:?}" ,
357- url,
358- retry_info. attempt,
359- retry_info. max_retries,
360- retry_info. reason,
361- retry_info. delay
362- ) ;
363-
364- tokio:: time:: sleep ( delay) . await ;
365- attempt += 1 ;
366- }
367- Err ( e) => {
368- // Retry on network errors
369- if attempt >= self . max_retries {
370- log:: error!(
371- "POST request to {} failed after {} retries due to network error: {}" ,
372- url,
373- self . max_retries,
374- e
375- ) ;
376- return Err ( LlmError :: network ( format ! (
377- "Request failed after {} retries: {}" ,
378- self . max_retries, e
379- ) ) ) ;
380- }
381-
382- let delay = self . calculate_delay ( attempt, None ) . await ;
383- let retry_info = RetryInfo {
384- attempt : attempt + 1 ,
385- max_retries : self . max_retries ,
386- delay,
387- reason : format ! ( "Network error: {e}" ) ,
388- response_status : None ,
389- } ;
390-
391- log:: warn!(
392- "Retrying POST request to {} due to network error (attempt {}/{}): {} - waiting {:?}" ,
393- url,
394- retry_info. attempt,
395- retry_info. max_retries,
396- retry_info. reason,
397- retry_info. delay
398- ) ;
399-
400- tokio:: time:: sleep ( delay) . await ;
401- attempt += 1 ;
402- }
403- }
404- }
405- }
406-
407- async fn calculate_delay (
408- & self ,
409- attempt : u32 ,
410- response : Option < & reqwest:: Response > ,
411- ) -> Duration {
412- match & self . retry_policy {
413- RetryPolicy :: Fixed { delay_ms } => Duration :: from_millis ( * delay_ms) ,
414-
415- RetryPolicy :: ExponentialBackoff {
416- initial_delay_ms,
417- max_delay_ms,
418- multiplier,
419- jitter,
420- } => {
421- let base_delay = * initial_delay_ms as f64 * multiplier. powi ( attempt as i32 ) ;
422- let delay_ms = base_delay. min ( * max_delay_ms as f64 ) as u64 ;
423-
424- let final_delay = if * jitter {
425- // Add ±25% jitter
426- let jitter_factor = 0.75 + ( rand:: random :: < f64 > ( ) * 0.5 ) ;
427- ( delay_ms as f64 * jitter_factor) as u64
428- } else {
429- delay_ms
430- } ;
431-
432- Duration :: from_millis ( final_delay)
433- }
434-
435- RetryPolicy :: ApiGuided {
436- fallback,
437- max_api_delay_ms,
438- retry_headers,
439- } => {
440- // Try to parse delay from response headers
441- if let Some ( resp) = response {
442- let headers: HashMap < String , String > = resp
443- . headers ( )
444- . iter ( )
445- . map ( |( k, v) | ( k. to_string ( ) , v. to_str ( ) . unwrap_or ( "" ) . to_string ( ) ) )
446- . collect ( ) ;
447-
448- if let Some ( api_delay) = crate :: config:: retry_parsing:: parse_retry_delay (
449- & headers,
450- retry_headers,
451- * max_api_delay_ms,
452- ) {
453- return api_delay;
454- }
455- }
456-
457- // Fall back to the configured fallback policy
458- match & * * fallback {
459- RetryPolicy :: Fixed { delay_ms } => Duration :: from_millis ( * delay_ms) ,
460- RetryPolicy :: ExponentialBackoff {
461- initial_delay_ms,
462- max_delay_ms,
463- multiplier,
464- jitter,
465- } => {
466- let base_delay = * initial_delay_ms as f64 * multiplier. powi ( attempt as i32 ) ;
467- let delay_ms = base_delay. min ( * max_delay_ms as f64 ) as u64 ;
468-
469- let final_delay = if * jitter {
470- let jitter_factor = 0.75 + ( rand:: random :: < f64 > ( ) * 0.5 ) ;
471- ( delay_ms as f64 * jitter_factor) as u64
472- } else {
473- delay_ms
474- } ;
475-
476- Duration :: from_millis ( final_delay)
477- }
478- RetryPolicy :: ApiGuided { .. } => {
479- // For nested ApiGuided policies, use simple exponential backoff
480- Duration :: from_millis ( 1000 * 2_u64 . pow ( attempt) )
481- }
482- }
483- }
484- }
485- }
486- }
487-
488172#[ cfg( test) ]
489173mod tests {
490174 use super :: * ;
491- use crate :: config:: OpenAIConfig ;
492175
493176 #[ test]
494177 fn test_retry_info_creation ( ) {
@@ -506,49 +189,6 @@ mod tests {
506189 assert_eq ! ( retry_info. reason, "HTTP 500" ) ;
507190 assert_eq ! ( retry_info. response_status, Some ( 500 ) ) ;
508191 }
509-
510- #[ test]
511- fn test_enhanced_http_client_creation ( ) {
512- let config = OpenAIConfig :: new ( "test-key" ) ;
513- let client = EnhancedHttpClient :: new ( & config) ;
514-
515- assert ! ( client. is_ok( ) ) ;
516- let client = client. unwrap ( ) ;
517- assert_eq ! ( client. max_retries, 3 ) ; // Default from config
518- }
519-
520- #[ test]
521- fn test_calculate_delay_fixed_policy ( ) {
522- let config = OpenAIConfig :: new ( "test-key" ) . with_fixed_retry ( 3 , 100 ) ;
523- let client = EnhancedHttpClient :: new ( & config) . unwrap ( ) ;
524-
525- // Use a dummy async runtime for testing delay calculation
526- let rt = tokio:: runtime:: Runtime :: new ( ) . unwrap ( ) ;
527- let delay = rt. block_on ( async { client. calculate_delay ( 0 , None ) . await } ) ;
528-
529- assert_eq ! ( delay, Duration :: from_millis( 100 ) ) ;
530- }
531-
532- #[ test]
533- fn test_calculate_delay_exponential_backoff ( ) {
534- let config =
535- OpenAIConfig :: new ( "test-key" ) . with_exponential_backoff ( 3 , 100 , 1000 , 2.0 , false ) ;
536- let client = EnhancedHttpClient :: new ( & config) . unwrap ( ) ;
537-
538- let rt = tokio:: runtime:: Runtime :: new ( ) . unwrap ( ) ;
539-
540- // Test first attempt (should be initial delay)
541- let delay1 = rt. block_on ( async { client. calculate_delay ( 0 , None ) . await } ) ;
542- assert_eq ! ( delay1, Duration :: from_millis( 100 ) ) ;
543-
544- // Test second attempt (should be 200ms = 100 * 2^1)
545- let delay2 = rt. block_on ( async { client. calculate_delay ( 1 , None ) . await } ) ;
546- assert_eq ! ( delay2, Duration :: from_millis( 200 ) ) ;
547-
548- // Test third attempt (should be 400ms = 100 * 2^2)
549- let delay3 = rt. block_on ( async { client. calculate_delay ( 2 , None ) . await } ) ;
550- assert_eq ! ( delay3, Duration :: from_millis( 400 ) ) ;
551- }
552192}
553193
554194/// Middleware configuration for LLM providers
0 commit comments