diff --git a/client_test.go b/client_test.go index 6d3ebe7..76b2d99 100644 --- a/client_test.go +++ b/client_test.go @@ -543,6 +543,84 @@ func TestCreatePrediction(t *testing.T) { assert.Equal(t, "https://streaming.api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", prediction.URLs["stream"]) } +func TestCreatePredictionWithVersionlessModel(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + } + defer r.Body.Close() + + var requestBody map[string]interface{} + err = json.Unmarshal(body, &requestBody) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, map[string]interface{}{"text": "Alice"}, requestBody["input"]) + assert.Equal(t, "https://example.com/webhook", requestBody["webhook"]) + assert.Equal(t, []interface{}{"start", "completed"}, requestBody["webhook_events_filter"]) + assert.Equal(t, true, requestBody["stream"]) + + switch r.URL.Path { + case "/models/owner/model/predictions": + response := replicate.Prediction{ + ID: "ufawqhfynnddngldkgtslldrkq", + Model: "replicate/hello-world", + Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + Status: "starting", + Input: map[string]interface{}{"text": "Alice"}, + Output: nil, + Error: nil, + Logs: nil, + Metrics: nil, + CreatedAt: "2022-04-26T22:13:06.224088Z", + URLs: map[string]string{ + "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", + "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + "stream": "https://streaming.api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", + }, + } + responseBytes, err := json.Marshal(response) + if err != nil { + t.Fatal(err) + } + + w.WriteHeader(http.StatusCreated) + w.Write(responseBytes) + default: + t.Fatalf("Unexpected request to %s", r.URL.Path) + } + })) + defer mockServer.Close() + + client, err := replicate.NewClient( + replicate.WithToken("test-token"), + replicate.WithBaseURL(mockServer.URL), + ) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + input := replicate.PredictionInput{"text": "Alice"} + webhook := replicate.Webhook{ + URL: "https://example.com/webhook", + Events: []replicate.WebhookEventType{"start", "completed"}, + } + prediction, err := client.CreatePrediction(ctx, "owner/model", input, &webhook, true) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, "ufawqhfynnddngldkgtslldrkq", prediction.ID) + assert.Equal(t, "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", prediction.Version) + assert.Equal(t, replicate.Starting, prediction.Status) + assert.Equal(t, "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", prediction.URLs["get"]) + assert.Equal(t, "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", prediction.URLs["cancel"]) + assert.Equal(t, "https://streaming.api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", prediction.URLs["stream"]) +} + func TestCreatePredictionWithDeployment(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodPost, r.Method) diff --git a/prediction.go b/prediction.go index 9f59fb8..4959d67 100644 --- a/prediction.go +++ b/prediction.go @@ -147,10 +147,17 @@ func (r *Client) createPredictionRequest(ctx context.Context, path string, data } // CreatePrediction creates a prediction for a specific version of a model. -func (r *Client) CreatePrediction(ctx context.Context, version string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error) { +func (r *Client) CreatePrediction(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error) { + // Parse the identifier to extract version + id, err := ParseIdentifier(identifier) + path := "/predictions" - data := map[string]interface{}{ - "version": version, + data := map[string]interface{}{} + // Set the model path or version in the data + if err == nil && id.Version == nil { + path = fmt.Sprintf("/models/%s/%s/predictions", id.Owner, id.Name) + } else { + data["version"] = identifier } req, err := r.createPredictionRequest(ctx, path, data, input, webhook, stream)