@@ -345,7 +345,7 @@ async def test_use_function_create_method(client_mode):
         run = hotdog_detector.create(prompt="hello world")

     # Assert that run is a Run object with a prediction
-    from replicate.use import Run, AsyncRun
+    from replicate.use import AsyncRun, Run

     if client_mode == ClientMode.ASYNC:
         assert isinstance(run, AsyncRun)
@@ -621,6 +621,226 @@ async def async_iterator():
     assert str(result) == "['Hello', ' ', 'World']"  # str() gives list representation


+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_iterator_output_returns_immediately(client_mode):
+    """Test that OutputIterator is returned immediately without waiting for completion."""
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+                        "components": {
+                            "schemas": {
+                                "Output": {
+                                    "type": "array",
+                                    "items": {"type": "string"},
+                                    "x-cog-array-type": "iterator",
+                                    "x-cog-array-display": "concatenate",
+                                }
+                            }
+                        }
+                    }
+                }
+            )
+        ]
+    )
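+    # Note: "x-cog-array-type": "iterator" and "x-cog-array-display": "concatenate"
+    # are Cog's OpenAPI schema extensions, as this suite uses them: the first marks
+    # the output as a stream of chunks, the second hints that string chunks join
+    # into a single string.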
+
+    # Mock prediction that starts as processing (not completed)
+    mock_prediction_endpoints(
+        predictions=[
+            create_mock_prediction({"status": "processing", "output": []}),
+            create_mock_prediction({"status": "processing", "output": ["Hello"]}),
+            create_mock_prediction(
+                {"status": "succeeded", "output": ["Hello", " ", "World"]}
+            ),
+        ]
+    )
+
+    # Call use with "acme/hotdog-detector"
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+
+    # Get the output iterator - this should return immediately even though the
+    # prediction is still processing
+    if client_mode == ClientMode.ASYNC:
+        run = await hotdog_detector.create(prompt="hello world")
+        output_iterator = await run.output()
+    else:
+        run = hotdog_detector.create(prompt="hello world")
+        output_iterator = run.output()
+
+    # Assert that we get an OutputIterator immediately (without waiting for completion)
+    from replicate.use import OutputIterator
+
+    assert isinstance(output_iterator, OutputIterator)
+
+    # Verify the prediction is still processing when we get the iterator
+    assert run.prediction.status == "processing"
+
+
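+# A minimal consumption sketch of what the test above verifies (the model name is
+# the suite's placeholder; it assumes only the use()/create()/output() calls
+# exercised above):
+#
+#     detector = replicate.use("acme/hotdog-detector")
+#     run = detector.create(prompt="hello world")
+#     for chunk in run.output():  # OutputIterator yields chunks while polling
+#         print(chunk, end="")
+
+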
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_streaming_output_yields_incrementally(client_mode):
+    """Test that OutputIterator yields results incrementally during polling."""
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+                        "components": {
+                            "schemas": {
+                                "Output": {
+                                    "type": "array",
+                                    "items": {"type": "string"},
+                                    "x-cog-array-type": "iterator",
+                                    "x-cog-array-display": "concatenate",
+                                }
+                            }
+                        }
+                    }
+                }
+            )
+        ]
+    )
+
+    # Create a prediction that will be polled multiple times
+    prediction_id = "pred123"
+
+    # Mock the initial prediction creation
+    initial_prediction = create_mock_prediction(
+        {"id": prediction_id, "status": "processing", "output": []},
+        prediction_id=prediction_id,
+    )
+
+    # Both client modes create the prediction against the same endpoint
+    respx.post("https://api.replicate.com/v1/predictions").mock(
+        return_value=httpx.Response(201, json=initial_prediction)
+    )
+
+    # Mock incremental polling responses - each poll returns more data
+    poll_responses = [
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello"]}, prediction_id=prediction_id
+        ),
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello", " "]},
+            prediction_id=prediction_id,
+        ),
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello", " ", "streaming"]},
+            prediction_id=prediction_id,
+        ),
+        create_mock_prediction(
+            {"status": "processing", "output": ["Hello", " ", "streaming", " "]},
+            prediction_id=prediction_id,
+        ),
+        create_mock_prediction(
+            {
+                "status": "succeeded",
+                "output": ["Hello", " ", "streaming", " ", "world!"],
+            },
+            prediction_id=prediction_id,
+        ),
+    ]
+
+    # Mock the polling endpoint to return different responses in sequence
+    respx.get(f"https://api.replicate.com/v1/predictions/{prediction_id}").mock(
+        side_effect=[httpx.Response(200, json=resp) for resp in poll_responses]
+    )
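+    # (respx consumes one side_effect response per request, so each successive
+    # GET poll sees a longer "output" list until the terminal "succeeded" state.)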
+
+    # Call use with "acme/hotdog-detector"
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+
+    # Get the output iterator immediately
+    if client_mode == ClientMode.ASYNC:
+        run = await hotdog_detector.create(prompt="hello world")
+        output_iterator = await run.output()
+    else:
+        run = hotdog_detector.create(prompt="hello world")
+        output_iterator = run.output()
+
+    # Assert that we get an OutputIterator immediately
+    from replicate.use import OutputIterator
+
+    assert isinstance(output_iterator, OutputIterator)
+
+    # Track when we receive each item to verify incremental delivery
+    collected_items = []
+
+    if client_mode == ClientMode.ASYNC:
+        async for item in output_iterator:
+            collected_items.append(item)
+            # Break after we get some incremental results to verify polling works
+            if len(collected_items) >= 3:
+                break
+    else:
+        for item in output_iterator:
+            collected_items.append(item)
+            # Break after we get some incremental results to verify polling works
+            if len(collected_items) >= 3:
+                break
+
+    # Verify we got incremental streaming results
+    assert len(collected_items) >= 3
+    # The items should be the concatenated string parts from the incremental output
+    result = "".join(collected_items)
+    assert "Hello" in result  # Should contain the first part we streamed
+
+
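+# Async flavor of the same sketch (mirrors the awaited calls in the test above;
+# the model name is again the suite's placeholder):
+#
+#     detector = replicate.use("acme/hotdog-detector", use_async=True)
+#     run = await detector.create(prompt="hello world")
+#     async for chunk in await run.output():
+#         ...
+
+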
+@pytest.mark.asyncio
+@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
+@respx.mock
+async def test_non_streaming_output_waits_for_completion(client_mode):
+    """Test that non-iterator outputs still wait for completion."""
+    mock_model_endpoints(
+        versions=[
+            create_mock_version(
+                {
+                    "openapi_schema": {
+                        "components": {
+                            "schemas": {
+                                "Output": {"type": "string"}  # Non-iterator output
+                            }
+                        }
+                    }
+                }
+            )
+        ]
+    )
+
+    mock_prediction_endpoints(
+        predictions=[
+            create_mock_prediction({"status": "processing", "output": None}),
+            create_mock_prediction({"status": "succeeded", "output": "Final result"}),
+        ]
+    )
+
+    # Call use with "acme/hotdog-detector"
+    hotdog_detector = replicate.use(
+        "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC
+    )
+
+    # For non-iterator output, this should wait for completion
+    if client_mode == ClientMode.ASYNC:
+        run = await hotdog_detector.create(prompt="hello world")
+        output = await run.output()
+    else:
+        run = hotdog_detector.create(prompt="hello world")
+        output = run.output()
+
+    # Should get the final result directly
+    assert output == "Final result"
+
+
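+# For non-iterator outputs the sketch collapses to a single blocking call
+# (run.output() polls until a terminal status before returning):
+#
+#     detector = replicate.use("acme/hotdog-detector")
+#     output = detector.create(prompt="hello world").output()  # "Final result"
+
+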
 @pytest.mark.asyncio
 @pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
 @respx.mock