From 6db6f28bb0f952889a3035a6baca385c14d2091a Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 07:10:36 -0800 Subject: [PATCH 01/23] added new `Middleware` --- genkit-tools/common/src/types/index.ts | 1 + genkit-tools/common/src/types/middleware.ts | 37 +++ genkit-tools/common/src/types/model.ts | 3 + go/ai/gen.go | 2 + go/ai/generate.go | 125 ++++++++++- go/ai/middleware.go | 155 +++++++++++++ go/ai/middleware_test.go | 237 ++++++++++++++++++++ go/ai/option.go | 18 +- go/ai/prompt.go | 22 ++ go/genkit/genkit.go | 10 + go/genkit/reflection.go | 22 ++ 11 files changed, 628 insertions(+), 4 deletions(-) create mode 100644 genkit-tools/common/src/types/middleware.ts create mode 100644 go/ai/middleware.go create mode 100644 go/ai/middleware_test.go diff --git a/genkit-tools/common/src/types/index.ts b/genkit-tools/common/src/types/index.ts index ea12971f0e..360546af0e 100644 --- a/genkit-tools/common/src/types/index.ts +++ b/genkit-tools/common/src/types/index.ts @@ -23,6 +23,7 @@ export * from './document'; export * from './env'; export * from './eval'; export * from './evaluator'; +export * from './middleware'; export * from './model'; export * from './prompt'; export * from './retriever'; diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts new file mode 100644 index 0000000000..7f41af991e --- /dev/null +++ b/genkit-tools/common/src/types/middleware.ts @@ -0,0 +1,37 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { z } from 'zod'; +import { JSONSchema7Schema } from './action'; + +/** Descriptor for a registered middleware, returned by reflection API. */ +export const MiddlewareDescSchema = z.object({ + /** Unique name of the middleware. */ + name: z.string(), + /** Human-readable description of what the middleware does. */ + description: z.string().optional(), + /** JSON Schema for the middleware's configuration. */ + configSchema: JSONSchema7Schema.optional(), +}); +export type MiddlewareDesc = z.infer; + +/** Reference to a registered middleware with optional configuration. */ +export const MiddlewareRefSchema = z.object({ + /** Name of the registered middleware. */ + name: z.string(), + /** Configuration for the middleware (schema defined by the middleware). */ + config: z.any().optional(), +}); +export type MiddlewareRef = z.infer; diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index a36d9f288f..62fa83dedb 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -15,6 +15,7 @@ */ import { z } from 'zod'; import { DocumentDataSchema } from './document'; +import { MiddlewareRefSchema } from './middleware'; import { CustomPartSchema, DataPartSchema, @@ -399,5 +400,7 @@ export const GenerateActionOptionsSchema = z.object({ maxTurns: z.number().optional(), /** Custom step name for this generate call to display in trace views. Defaults to "generate". */ stepName: z.string().optional(), + /** Middleware to apply to this generation. */ + use: z.array(MiddlewareRefSchema).optional(), }); export type GenerateActionOptions = z.infer; diff --git a/go/ai/gen.go b/go/ai/gen.go index e391ef2215..6f6ee06a19 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -96,6 +96,8 @@ type GenerateActionOptions struct { ToolChoice ToolChoice `json:"toolChoice,omitempty"` // Tools is a list of registered tool names for this generation if supported. Tools []string `json:"tools,omitempty"` + // Use is middleware to apply to this generation, referenced by name with optional config. + Use []*MiddlewareRef `json:"use,omitempty"` } // GenerateActionResume holds options for resuming an interrupted generation. diff --git a/go/ai/generate.go b/go/ai/generate.go index 003eb0b653..64aad71d33 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -67,6 +67,8 @@ type ModelFunc = core.StreamingFunc[*ModelRequest, *ModelResponse, *ModelRespons type ModelStreamCallback = func(context.Context, *ModelResponseChunk) error // ModelMiddleware is middleware for model generate requests that takes in a ModelFunc, does something, then returns another ModelFunc. +// +// Deprecated: Use [Middleware] interface with [WithUse] instead, which supports Generate, Model, and Tool hooks. type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk] // model is an action with functions specific to model generation such as Generate(). @@ -313,6 +315,27 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } + // Resolve middleware from Use refs. + var middlewareHandlers []Middleware + if len(opts.Use) > 0 { + middlewareHandlers = make([]Middleware, 0, len(opts.Use)) + for _, ref := range opts.Use { + desc := LookupMiddleware(r, ref.Name) + if desc == nil { + return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: middleware %q not found", ref.Name) + } + configJSON, err := json.Marshal(ref.Config) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai.GenerateWithRequest: failed to marshal config for middleware %q: %v", ref.Name, err) + } + handler, err := desc.configFromJSON(configJSON) + if err != nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: failed to create middleware %q: %v", ref.Name, err) + } + middlewareHandlers = append(middlewareHandlers, handler) + } + } + fn := m.Generate if bm != nil { if cb != nil { @@ -320,6 +343,24 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } fn = backgroundModelToModelFn(bm.Start) } + + // Apply Model hooks from new middleware as a ModelMiddleware, then chain with legacy mw. + if len(middlewareHandlers) > 0 { + modelHook := func(next ModelFunc) ModelFunc { + wrapped := next + for i := len(middlewareHandlers) - 1; i >= 0; i-- { + h := middlewareHandlers[i] + inner := wrapped + wrapped = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return h.Model(ctx, &ModelState{Request: req, Callback: cb}, func(ctx context.Context, state *ModelState) (*ModelResponse, error) { + return inner(ctx, state.Request, state.Callback) + }) + } + } + return wrapped + } + mw = append([]ModelMiddleware{modelHook}, mw...) + } fn = core.ChainMiddleware(mw...)(fn) // Inline recursive helper function that captures variables from parent scope. @@ -388,7 +429,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.ABORTED, "exceeded maximum tool call iterations (%d)", maxTurns) } - newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex) + newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, middlewareHandlers) if err != nil { return nil, err } @@ -406,6 +447,28 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi }) } + // Wrap generate with the Generate hook chain from middleware. + if len(middlewareHandlers) > 0 { + innerGenerate := generate + generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { + innerFn := func(ctx context.Context, state *GenerateState) (*ModelResponse, error) { + return innerGenerate(ctx, state.Request, currentTurn, messageIndex) + } + for i := len(middlewareHandlers) - 1; i >= 0; i-- { + h := middlewareHandlers[i] + next := innerFn + innerFn = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) { + return h.Generate(ctx, state, next) + } + } + return innerFn(ctx, &GenerateState{ + Options: opts, + Request: req, + Iteration: currentTurn, + }) + } + } + return generate(ctx, req, 0, 0) } @@ -535,6 +598,28 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } + // Register dynamic middleware (like dynamic tools) and build MiddlewareRefs. + if len(genOpts.Use) > 0 { + for _, mw := range genOpts.Use { + name := mw.Name() + if LookupMiddleware(r, name) == nil { + if !r.IsChild() { + r = r.NewChild() + } + NewMiddleware("", mw).Register(r) + } + configJSON, err := json.Marshal(mw) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai.Generate: failed to marshal middleware %q config: %v", name, err) + } + var config any + if err := json.Unmarshal(configJSON, &config); err != nil { + return nil, core.NewError(core.INTERNAL, "ai.Generate: failed to unmarshal middleware %q config: %v", name, err) + } + actionOpts.Use = append(actionOpts.Use, &MiddlewareRef{Name: name, Config: config}) + } + } + // Process resources in messages processedMessages, err := processResources(ctx, r, messages) if err != nil { @@ -773,7 +858,7 @@ func clone[T any](obj *T) *T { // handleToolRequests processes any tool requests in the response, returning // either a new request to continue the conversation or nil if no tool requests // need handling. -func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int) (*ModelRequest, *Message, error) { +func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int, middlewareHandlers []Middleware) (*ModelRequest, *Message, error) { toolCount := len(resp.ToolRequests()) if toolCount == 0 { return nil, nil, nil @@ -796,7 +881,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return } - multipartResp, err := tool.RunRawMultipart(ctx, toolReq.Input) + multipartResp, err := runToolWithMiddleware(ctx, tool, toolReq, middlewareHandlers) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { @@ -879,6 +964,39 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return newReq, nil, nil } +// runToolWithMiddleware runs a tool, wrapping the execution with Tool hooks from middleware. +func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, handlers []Middleware) (*MultipartToolResponse, error) { + if len(handlers) == 0 { + return tool.RunRawMultipart(ctx, toolReq.Input) + } + + inner := func(ctx context.Context, state *ToolState) (*ToolResponse, error) { + resp, err := state.Tool.RunRawMultipart(ctx, state.Request.Input) + if err != nil { + return nil, err + } + return &ToolResponse{ + Name: state.Request.Name, + Output: resp.Output, + }, nil + } + + for i := len(handlers) - 1; i >= 0; i-- { + h := handlers[i] + next := inner + inner = func(ctx context.Context, state *ToolState) (*ToolResponse, error) { + return h.Tool(ctx, state, next) + } + } + + toolResp, err := inner(ctx, &ToolState{Request: toolReq, Tool: tool}) + if err != nil { + return nil, err + } + + return &MultipartToolResponse{Output: toolResp.Output}, nil +} + // Text returns the contents of the first candidate in a // [ModelResponse] as a string. It returns an empty string if there // are no candidates or if the candidate has no message. @@ -1357,6 +1475,7 @@ func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateAc Docs: genOpts.Docs, ReturnToolRequests: genOpts.ReturnToolRequests, Output: genOpts.Output, + Use: genOpts.Use, }, toolMessage: toolMessage, }, nil diff --git a/go/ai/middleware.go b/go/ai/middleware.go new file mode 100644 index 0000000000..71d5b93d1a --- /dev/null +++ b/go/ai/middleware.go @@ -0,0 +1,155 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" +) + +// Middleware provides hooks for different stages of generation. +type Middleware interface { + // Name returns the middleware's unique identifier. + Name() string + // New returns a fresh instance for each ai.Generate() call, enabling per-invocation state. + New() Middleware + // Generate wraps each iteration of the tool loop. + Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) + // Model wraps each model API call. + Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) + // Tool wraps each tool execution. + Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) +} + +// GenerateState holds state for the Generate hook. +type GenerateState struct { + // Options is the original options passed to [Generate]. + Options *GenerateActionOptions + // Request is the current model request for this iteration, with accumulated messages. + Request *ModelRequest + // Iteration is the current tool-loop iteration (0-indexed). + Iteration int +} + +// ModelState holds state for the Model hook. +type ModelState struct { + // Request is the model request about to be sent. + Request *ModelRequest + // Callback is the streaming callback, or nil if not streaming. + Callback ModelStreamCallback +} + +// ToolState holds state for the Tool hook. +type ToolState struct { + // Request is the tool request about to be executed. + Request *ToolRequest + // Tool is the resolved tool being called. + Tool Tool +} + +// GenerateNext is the next function in the Generate hook chain. +type GenerateNext = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) + +// ModelNext is the next function in the Model hook chain. +type ModelNext = func(ctx context.Context, state *ModelState) (*ModelResponse, error) + +// ToolNext is the next function in the Tool hook chain. +type ToolNext = func(ctx context.Context, state *ToolState) (*ToolResponse, error) + +// BaseMiddleware provides default pass-through for the three hooks. +// Embed this so you only need to implement Name() and New(). +type BaseMiddleware struct{} + +func (b *BaseMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { + return next(ctx, state) +} + +func (b *BaseMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + return next(ctx, state) +} + +func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { + return next(ctx, state) +} + +// MiddlewareDesc is the registered descriptor for a middleware. +type MiddlewareDesc struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + ConfigSchema map[string]any `json:"configSchema,omitempty"` + configFromJSON func([]byte) (Middleware, error) +} + +// Register registers the descriptor with the registry. +func (d *MiddlewareDesc) Register(r api.Registry) { + r.RegisterValue("/middleware/"+d.Name, d) +} + +// NewMiddleware creates a middleware descriptor without registering it. +// The prototype carries stable state; configFromJSON calls prototype.New() +// then unmarshals user config on top. +func NewMiddleware[T Middleware](description string, prototype T) *MiddlewareDesc { + return &MiddlewareDesc{ + Name: prototype.Name(), + Description: description, + ConfigSchema: core.InferSchemaMap(*new(T)), + configFromJSON: func(configJSON []byte) (Middleware, error) { + inst := prototype.New() + if len(configJSON) > 0 { + if err := json.Unmarshal(configJSON, inst); err != nil { + return nil, fmt.Errorf("middleware %q: %w", prototype.Name(), err) + } + } + return inst, nil + }, + } +} + +// DefineMiddleware creates and registers a middleware descriptor. +func DefineMiddleware[T Middleware](r api.Registry, description string, prototype T) *MiddlewareDesc { + d := NewMiddleware(description, prototype) + d.Register(r) + return d +} + +// LookupMiddleware looks up a registered middleware descriptor by name. +func LookupMiddleware(r api.Registry, name string) *MiddlewareDesc { + v := r.LookupValue("/middleware/" + name) + if v == nil { + return nil + } + d, ok := v.(*MiddlewareDesc) + if !ok { + return nil + } + return d +} + +// MiddlewareRef is a serializable reference to a registered middleware with config. +type MiddlewareRef struct { + Name string `json:"name"` + Config any `json:"config,omitempty"` +} + +// MiddlewarePlugin is implemented by plugins that provide middleware. +type MiddlewarePlugin interface { + ListMiddleware(ctx context.Context) ([]*MiddlewareDesc, error) +} diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go new file mode 100644 index 0000000000..0613e3e63e --- /dev/null +++ b/go/ai/middleware_test.go @@ -0,0 +1,237 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "sync/atomic" + "testing" +) + +// testMiddleware is a simple middleware for testing that tracks hook invocations. +type testMiddleware struct { + BaseMiddleware + Label string `json:"label"` + generateCalls int + modelCalls int + toolCalls int32 // atomic since tool hooks run in parallel +} + +func (m *testMiddleware) Name() string { return "test" } + +func (m *testMiddleware) New() Middleware { + return &testMiddleware{Label: m.Label} +} + +func (m *testMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { + m.generateCalls++ + return next(ctx, state) +} + +func (m *testMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + m.modelCalls++ + return next(ctx, state) +} + +func (m *testMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { + atomic.AddInt32(&m.toolCalls, 1) + return next(ctx, state) +} + +func TestNewMiddleware(t *testing.T) { + proto := &testMiddleware{Label: "original"} + desc := NewMiddleware("test middleware", proto) + + if desc.Name != "test" { + t.Errorf("got name %q, want %q", desc.Name, "test") + } + if desc.Description != "test middleware" { + t.Errorf("got description %q, want %q", desc.Description, "test middleware") + } +} + +func TestDefineAndLookupMiddleware(t *testing.T) { + r := newTestRegistry(t) + proto := &testMiddleware{Label: "original"} + DefineMiddleware(r, "test middleware", proto) + + found := LookupMiddleware(r, "test") + if found == nil { + t.Fatal("expected to find middleware, got nil") + } + if found.Name != "test" { + t.Errorf("got name %q, want %q", found.Name, "test") + } +} + +func TestLookupMiddlewareNotFound(t *testing.T) { + r := newTestRegistry(t) + found := LookupMiddleware(r, "nonexistent") + if found != nil { + t.Errorf("expected nil, got %v", found) + } +} + +func TestConfigFromJSON(t *testing.T) { + proto := &testMiddleware{Label: "stable"} + desc := NewMiddleware("test middleware", proto) + + handler, err := desc.configFromJSON([]byte(`{"label": "custom"}`)) + if err != nil { + t.Fatalf("configFromJSON failed: %v", err) + } + + tm, ok := handler.(*testMiddleware) + if !ok { + t.Fatalf("expected *testMiddleware, got %T", handler) + } + if tm.Label != "custom" { + t.Errorf("got label %q, want %q", tm.Label, "custom") + } + // Per-request state should be zeroed by New() + if tm.generateCalls != 0 { + t.Errorf("got generateCalls %d, want 0", tm.generateCalls) + } +} + +func TestConfigFromJSONPreservesStableState(t *testing.T) { + // Simulate a plugin middleware with unexported stable state + proto := &stableStateMiddleware{apiKey: "secret123"} + desc := NewMiddleware("middleware with stable state", proto) + + handler, err := desc.configFromJSON([]byte(`{"sampleRate": 0.5}`)) + if err != nil { + t.Fatalf("configFromJSON failed: %v", err) + } + + sm, ok := handler.(*stableStateMiddleware) + if !ok { + t.Fatalf("expected *stableStateMiddleware, got %T", handler) + } + if sm.apiKey != "secret123" { + t.Errorf("got apiKey %q, want %q", sm.apiKey, "secret123") + } + if sm.SampleRate != 0.5 { + t.Errorf("got SampleRate %f, want 0.5", sm.SampleRate) + } +} + +func TestMiddlewareModelHook(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + DefineMiddleware(r, "tracks calls", &testMiddleware{}) + + resp, err := Generate(ctx, r, + WithModel(m), + WithPrompt("hello"), + WithUse(&testMiddleware{}), + ) + assertNoError(t, err) + if resp == nil { + t.Fatal("expected response, got nil") + } +} + +func TestMiddlewareToolHook(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolModel", + handler: toolCallingModelHandler("myTool", map[string]any{"value": "test"}, "done"), + }) + defineFakeTool(t, r, "myTool", "A test tool") + + mw := &testMiddleware{} + DefineMiddleware(r, "tracks calls", mw) + + _, err := Generate(ctx, r, + WithModelName("test/toolModel"), + WithPrompt("use the tool"), + WithTools(ToolName("myTool")), + WithUse(&testMiddleware{}), + ) + assertNoError(t, err) +} + +func TestMiddlewareOrdering(t *testing.T) { + // First middleware is outermost + var order []string + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + mwA := &orderMiddleware{label: "A", order: &order} + mwB := &orderMiddleware{label: "B", order: &order} + DefineMiddleware(r, "middleware A", mwA) + DefineMiddleware(r, "middleware B", mwB) + + _, err := Generate(ctx, r, + WithModel(m), + WithPrompt("hello"), + WithUse( + &orderMiddleware{label: "A", order: &order}, + &orderMiddleware{label: "B", order: &order}, + ), + ) + assertNoError(t, err) + + // Expect: A-before, B-before, B-after, A-after (first is outermost) + want := []string{"A-model-before", "B-model-before", "B-model-after", "A-model-after"} + if len(order) != len(want) { + t.Fatalf("got order %v, want %v", order, want) + } + for i := range want { + if order[i] != want[i] { + t.Errorf("order[%d] = %q, want %q", i, order[i], want[i]) + } + } +} + +// --- helper middleware types for tests --- + +// stableStateMiddleware has unexported stable state preserved by New(). +type stableStateMiddleware struct { + BaseMiddleware + SampleRate float64 `json:"sampleRate"` + apiKey string +} + +func (m *stableStateMiddleware) Name() string { return "stableState" } + +func (m *stableStateMiddleware) New() Middleware { + return &stableStateMiddleware{apiKey: m.apiKey} +} + +// orderMiddleware tracks the order of Model hook invocations. +type orderMiddleware struct { + BaseMiddleware + label string + order *[]string +} + +func (m *orderMiddleware) Name() string { return "order-" + m.label } + +func (m *orderMiddleware) New() Middleware { + return &orderMiddleware{label: m.label, order: m.order} +} + +func (m *orderMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + *m.order = append(*m.order, m.label+"-model-before") + resp, err := next(ctx, state) + *m.order = append(*m.order, m.label+"-model-after") + return resp, err +} + +var ctx = context.Background() diff --git a/go/ai/option.go b/go/ai/option.go index d28c68e3e9..84019b11d7 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -109,7 +109,8 @@ type commonGenOptions struct { ToolChoice ToolChoice // Whether tool calls are required, disabled, or optional. MaxTurns int // Maximum number of tool call iterations. ReturnToolRequests *bool // Whether to return tool requests instead of making the tool calls and continuing the generation. - Middleware []ModelMiddleware // Middleware to apply to the model request and model response. + Middleware []ModelMiddleware // Deprecated: Use WithUse instead. Middleware to apply to the model request and model response. + Use []Middleware // Middleware to apply to generation (Generate, Model, and Tool hooks). } type CommonGenOption interface { @@ -181,6 +182,13 @@ func (o *commonGenOptions) applyCommonGen(opts *commonGenOptions) error { opts.Middleware = o.Middleware } + if o.Use != nil { + if opts.Use != nil { + return errors.New("cannot set middleware more than once (WithUse)") + } + opts.Use = o.Use + } + return nil } @@ -233,10 +241,18 @@ func WithModelName(name string) CommonGenOption { } // WithMiddleware sets middleware to apply to the model request. +// +// Deprecated: Use [WithUse] instead, which supports Generate, Model, and Tool hooks. func WithMiddleware(middleware ...ModelMiddleware) CommonGenOption { return &commonGenOptions{Middleware: middleware} } +// WithUse sets middleware to apply to generation. Middleware hooks wrap +// the generate loop, model calls, and tool executions. +func WithUse(middleware ...Middleware) CommonGenOption { + return &commonGenOptions{Use: middleware} +} + // WithMaxTurns sets the maximum number of tool call iterations before erroring. // A tool call happens when tools are provided in the request and a model decides to call one or more as a response. // Each round trip, including multiple tools in parallel, counts as one turn. diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 4d0151c4c8..88c36e0cd7 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -249,6 +249,28 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod } } + // Register dynamic middleware and build MiddlewareRefs. + if len(execOpts.Use) > 0 { + for _, mw := range execOpts.Use { + name := mw.Name() + if LookupMiddleware(r, name) == nil { + if !r.IsChild() { + r = r.NewChild() + } + NewMiddleware("", mw).Register(r) + } + configJSON, err := json.Marshal(mw) + if err != nil { + return nil, fmt.Errorf("Prompt.Execute: failed to marshal middleware %q config: %w", name, err) + } + var config any + if err := json.Unmarshal(configJSON, &config); err != nil { + return nil, fmt.Errorf("Prompt.Execute: failed to unmarshal middleware %q config: %w", name, err) + } + actionOpts.Use = append(actionOpts.Use, &MiddlewareRef{Name: name, Config: config}) + } + } + return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream) } diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 377fb5e836..8fd32913c2 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -228,6 +228,16 @@ func Init(ctx context.Context, opts ...GenkitOption) *Genkit { action.Register(r) } r.RegisterPlugin(plugin.Name(), plugin) + + if mp, ok := plugin.(ai.MiddlewarePlugin); ok { + descs, err := mp.ListMiddleware(ctx) + if err != nil { + panic(fmt.Errorf("genkit.Init: plugin %q ListMiddleware failed: %w", plugin.Name(), err)) + } + for _, d := range descs { + d.Register(r) + } + } } ai.ConfigureFormats(r) diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index 1bd675f75a..9936936e61 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -303,6 +303,7 @@ func serveMux(g *Genkit, s *reflectionServer) *http.ServeMux { mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g, s.activeActions))) mux.HandleFunc("POST /api/notify", wrapReflectionHandler(handleNotify())) mux.HandleFunc("POST /api/cancelAction", wrapReflectionHandler(handleCancelAction(s.activeActions))) + mux.HandleFunc("GET /api/values", wrapReflectionHandler(handleListValues(g))) return mux } @@ -598,6 +599,27 @@ func handleListActions(g *Genkit) func(w http.ResponseWriter, r *http.Request) e } } +// handleListValues returns registered values filtered by type query parameter. +// Matches JS: GET /api/values?type=middleware +func handleListValues(g *Genkit) func(w http.ResponseWriter, r *http.Request) error { + return func(w http.ResponseWriter, r *http.Request) error { + valueType := r.URL.Query().Get("type") + if valueType == "" { + http.Error(w, `query parameter "type" is required`, http.StatusBadRequest) + return nil + } + prefix := "/" + valueType + "/" + result := map[string]any{} + for key, val := range g.reg.ListValues() { + if strings.HasPrefix(key, prefix) { + name := strings.TrimPrefix(key, prefix) + result[name] = val + } + } + return writeJSON(r.Context(), w, result) + } +} + // listActions lists all the registered actions. func listActions(g *Genkit) []api.ActionDesc { ads := []api.ActionDesc{} From a21b8e6c8bf36ab54e3c79d7ce234b0ff9de5e17 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 07:49:54 -0800 Subject: [PATCH 02/23] updated Genkit schema --- genkit-tools/genkit-schema.json | 48 +++++++++++++++++++++++++ genkit-tools/scripts/schema-exporter.ts | 1 + 2 files changed, 49 insertions(+) diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 26cc4fbf4f..2cbd939736 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -270,6 +270,48 @@ ], "additionalProperties": false }, + "MiddlewareDesc": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "configSchema": { + "anyOf": [ + { + "type": "object", + "properties": {}, + "additionalProperties": false, + "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." + }, + { + "type": "null" + } + ], + "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." + } + }, + "required": [ + "name" + ], + "additionalProperties": false + }, + "MiddlewareRef": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "config": {} + }, + "required": [ + "name" + ], + "additionalProperties": false + }, "CandidateError": { "type": "object", "properties": { @@ -466,6 +508,12 @@ }, "stepName": { "type": "string" + }, + "use": { + "type": "array", + "items": { + "$ref": "#/$defs/MiddlewareRef" + } } }, "required": [ diff --git a/genkit-tools/scripts/schema-exporter.ts b/genkit-tools/scripts/schema-exporter.ts index 48df79b56a..7462a12a8d 100644 --- a/genkit-tools/scripts/schema-exporter.ts +++ b/genkit-tools/scripts/schema-exporter.ts @@ -26,6 +26,7 @@ const EXPORTED_TYPE_MODULES = [ '../common/src/types/embedder.ts', '../common/src/types/evaluator.ts', '../common/src/types/error.ts', + '../common/src/types/middleware.ts', '../common/src/types/model.ts', '../common/src/types/parts.ts', '../common/src/types/reranker.ts', From 79382d5c80f5e934bb855904f8f63105ad287ef5 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:02:14 -0800 Subject: [PATCH 03/23] updated common schema --- genkit-tools/common/src/types/middleware.ts | 3 +- genkit-tools/genkit-schema.json | 7 +--- go/ai/gen.go | 19 +++++++++ go/ai/middleware.go | 17 ++------ go/core/schemas.config | 43 +++++++++++++++++++++ 5 files changed, 68 insertions(+), 21 deletions(-) diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts index 7f41af991e..4bb1297ede 100644 --- a/genkit-tools/common/src/types/middleware.ts +++ b/genkit-tools/common/src/types/middleware.ts @@ -14,7 +14,6 @@ * limitations under the License. */ import { z } from 'zod'; -import { JSONSchema7Schema } from './action'; /** Descriptor for a registered middleware, returned by reflection API. */ export const MiddlewareDescSchema = z.object({ @@ -23,7 +22,7 @@ export const MiddlewareDescSchema = z.object({ /** Human-readable description of what the middleware does. */ description: z.string().optional(), /** JSON Schema for the middleware's configuration. */ - configSchema: JSONSchema7Schema.optional(), + configSchema: z.record(z.any()).nullish(), }); export type MiddlewareDesc = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 2cbd939736..d808e46df9 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -283,15 +283,12 @@ "anyOf": [ { "type": "object", - "properties": {}, - "additionalProperties": false, - "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." + "additionalProperties": {} }, { "type": "null" } - ], - "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." + ] } }, "required": [ diff --git a/go/ai/gen.go b/go/ai/gen.go index 6f6ee06a19..963b8cd737 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -208,6 +208,25 @@ type Message struct { Role Role `json:"role,omitempty"` } +// MiddlewareDesc is the registered descriptor for a middleware. +type MiddlewareDesc struct { + // ConfigSchema is a JSON Schema describing the middleware's configuration. + ConfigSchema map[string]any `json:"configSchema,omitempty"` + // Description explains what the middleware does. + Description string `json:"description,omitempty"` + // Name is the middleware's unique identifier. + Name string `json:"name,omitempty"` + configFromJSON middlewareConfigFunc +} + +// MiddlewareRef is a serializable reference to a registered middleware with config. +type MiddlewareRef struct { + // Config contains the middleware configuration. + Config any `json:"config,omitempty"` + // Name is the name of the registered middleware. + Name string `json:"name,omitempty"` +} + // ModelInfo contains metadata about a model's capabilities and characteristics. type ModelInfo struct { // ConfigSchema defines the model-specific configuration schema. diff --git a/go/ai/middleware.go b/go/ai/middleware.go index 71d5b93d1a..35b2faf37f 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -25,6 +25,9 @@ import ( "github.com/firebase/genkit/go/core/api" ) +// middlewareConfigFunc creates a Middleware instance from JSON config. +type middlewareConfigFunc = func([]byte) (Middleware, error) + // Middleware provides hooks for different stages of generation. type Middleware interface { // Name returns the middleware's unique identifier. @@ -90,14 +93,6 @@ func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNe return next(ctx, state) } -// MiddlewareDesc is the registered descriptor for a middleware. -type MiddlewareDesc struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - ConfigSchema map[string]any `json:"configSchema,omitempty"` - configFromJSON func([]byte) (Middleware, error) -} - // Register registers the descriptor with the registry. func (d *MiddlewareDesc) Register(r api.Registry) { r.RegisterValue("/middleware/"+d.Name, d) @@ -143,12 +138,6 @@ func LookupMiddleware(r api.Registry, name string) *MiddlewareDesc { return d } -// MiddlewareRef is a serializable reference to a registered middleware with config. -type MiddlewareRef struct { - Name string `json:"name"` - Config any `json:"config,omitempty"` -} - // MiddlewarePlugin is implemented by plugins that provide middleware. type MiddlewarePlugin interface { ListMiddleware(ctx context.Context) ([]*MiddlewareDesc, error) diff --git a/go/core/schemas.config b/go/core/schemas.config index 70798f2eb3..2fe8cc6d54 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -732,6 +732,10 @@ StepName is a custom step name for this generate call to display in trace views. Defaults to "generate". . +GenerateActionOptions.use doc +Use is middleware to apply to this generation, referenced by name with optional config. +. + GenerateActionOptionsResume doc GenerateActionResume holds options for resuming an interrupted generation. . @@ -840,6 +844,38 @@ PathMetadata.error doc Error contains error information if the path failed. . +# ---------------------------------------------------------------------------- +# Middleware Types +# ---------------------------------------------------------------------------- + +MiddlewareDesc doc +MiddlewareDesc is the registered descriptor for a middleware. +. + +MiddlewareDesc.name doc +Name is the middleware's unique identifier. +. + +MiddlewareDesc.description doc +Description explains what the middleware does. +. + +MiddlewareDesc.configSchema doc +ConfigSchema is a JSON Schema describing the middleware's configuration. +. + +MiddlewareRef doc +MiddlewareRef is a serializable reference to a registered middleware with config. +. + +MiddlewareRef.name doc +Name is the name of the registered middleware. +. + +MiddlewareRef.config doc +Config contains the middleware configuration. +. + # ---------------------------------------------------------------------------- # Multipart Tool Response # ---------------------------------------------------------------------------- @@ -1060,6 +1096,7 @@ GenerateActionOptions.config type any GenerateActionOptions.output type *GenerateActionOutputConfig GenerateActionOptions.returnToolRequests type bool GenerateActionOptions.maxTurns type int +GenerateActionOptions.use type []*MiddlewareRef GenerateActionOptionsResume name GenerateActionResume # GenerateActionOutputConfig @@ -1101,6 +1138,12 @@ ModelResponseChunk.index type int ModelResponseChunk.role type Role ModelResponseChunk field formatHandler StreamingFormatHandler +# Middleware +MiddlewareDesc pkg ai +MiddlewareDesc.configSchema type map[string]any +MiddlewareDesc field configFromJSON middlewareConfigFunc +MiddlewareRef pkg ai + Score omit Embedding.embedding type []float32 From 02aec1cf6e3e5ba4dd4811a47351049b68fa576a Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:03:47 -0800 Subject: [PATCH 04/23] Update generate.go --- go/ai/generate.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 64aad71d33..70d53e8da5 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -976,8 +976,9 @@ func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, return nil, err } return &ToolResponse{ - Name: state.Request.Name, - Output: resp.Output, + Name: state.Request.Name, + Output: resp.Output, + Content: resp.Content, }, nil } @@ -994,7 +995,7 @@ func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, return nil, err } - return &MultipartToolResponse{Output: toolResp.Output}, nil + return &MultipartToolResponse{Output: toolResp.Output, Content: toolResp.Content}, nil } // Text returns the contents of the first candidate in a From c24403529cfc36e04822fd4313fd40c29a727789 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:04:21 -0800 Subject: [PATCH 05/23] Update middleware_test.go --- go/ai/middleware_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go index 0613e3e63e..a0f9f935ea 100644 --- a/go/ai/middleware_test.go +++ b/go/ai/middleware_test.go @@ -25,7 +25,7 @@ import ( // testMiddleware is a simple middleware for testing that tracks hook invocations. type testMiddleware struct { BaseMiddleware - Label string `json:"label"` + Label string `json:"label"` generateCalls int modelCalls int toolCalls int32 // atomic since tool hooks run in parallel @@ -149,7 +149,7 @@ func TestMiddlewareModelHook(t *testing.T) { func TestMiddlewareToolHook(t *testing.T) { r := newTestRegistry(t) defineFakeModel(t, r, fakeModelConfig{ - name: "test/toolModel", + name: "test/toolModel", handler: toolCallingModelHandler("myTool", map[string]any{"value": "test"}, "done"), }) defineFakeTool(t, r, "myTool", "A test tool") From ebb1d5c59723ba87e3537f32b3a184d307543835 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:06:30 -0800 Subject: [PATCH 06/23] Update typing.py --- py/packages/genkit/src/genkit/core/typing.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index 02f5927450..f67b0e91e2 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -120,6 +120,23 @@ class GenkitError(BaseModel): data: Data | None = None +class MiddlewareDesc(BaseModel): + """Model for middlewaredesc data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str + description: str | None = None + config_schema: dict[str, Any] | None = Field(default=None) + + +class MiddlewareRef(BaseModel): + """Model for middlewareref data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str + config: Any | None = None + + class Code(StrEnum): """Code data type class.""" @@ -1002,6 +1019,7 @@ class GenerateActionOptions(BaseModel): return_tool_requests: bool | None = Field(default=None) max_turns: float | None = Field(default=None) step_name: str | None = Field(default=None) + use: list[MiddlewareRef] | None = None class GenerateRequest(BaseModel): From e6f453553fab338c0d875283e37bd30c4df59ac1 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:26:52 -0800 Subject: [PATCH 07/23] fixes --- go/ai/generate.go | 6 +----- go/ai/prompt.go | 3 +-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 70d53e8da5..c6b2a8066a 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -315,7 +315,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } - // Resolve middleware from Use refs. var middlewareHandlers []Middleware if len(opts.Use) > 0 { middlewareHandlers = make([]Middleware, 0, len(opts.Use)) @@ -344,7 +343,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi fn = backgroundModelToModelFn(bm.Start) } - // Apply Model hooks from new middleware as a ModelMiddleware, then chain with legacy mw. if len(middlewareHandlers) > 0 { modelHook := func(next ModelFunc) ModelFunc { wrapped := next @@ -598,7 +596,6 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } - // Register dynamic middleware (like dynamic tools) and build MiddlewareRefs. if len(genOpts.Use) > 0 { for _, mw := range genOpts.Use { name := mw.Name() @@ -606,7 +603,7 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod if !r.IsChild() { r = r.NewChild() } - NewMiddleware("", mw).Register(r) + DefineMiddleware(r, "", mw) } configJSON, err := json.Marshal(mw) if err != nil { @@ -620,7 +617,6 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } - // Process resources in messages processedMessages, err := processResources(ctx, r, messages) if err != nil { return nil, core.NewError(core.INTERNAL, "ai.Generate: error processing resources: %v", err) diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 88c36e0cd7..9e4dff9f14 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -249,7 +249,6 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod } } - // Register dynamic middleware and build MiddlewareRefs. if len(execOpts.Use) > 0 { for _, mw := range execOpts.Use { name := mw.Name() @@ -257,7 +256,7 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod if !r.IsChild() { r = r.NewChild() } - NewMiddleware("", mw).Register(r) + DefineMiddleware(r, "", mw) } configJSON, err := json.Marshal(mw) if err != nil { From 9e4293151baa799819d481565d86175d3250f2e5 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 9 Feb 2026 09:53:04 -0800 Subject: [PATCH 08/23] Update genkit-tools/common/src/types/middleware.ts Co-authored-by: Pavel Jbanov --- genkit-tools/common/src/types/middleware.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts index 4bb1297ede..6fd2dd9810 100644 --- a/genkit-tools/common/src/types/middleware.ts +++ b/genkit-tools/common/src/types/middleware.ts @@ -23,6 +23,8 @@ export const MiddlewareDescSchema = z.object({ description: z.string().optional(), /** JSON Schema for the middleware's configuration. */ configSchema: z.record(z.any()).nullish(), + /** User defined metadata for the middleware. */ + metadata: z.record(z.any()).optional(), }); export type MiddlewareDesc = z.infer; From a962631a70d8a9223e59fbd73c01cbe9183c927e Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 9 Feb 2026 10:00:40 -0800 Subject: [PATCH 09/23] added new fields --- genkit-tools/genkit-schema.json | 4 ++++ go/ai/gen.go | 2 ++ go/core/schemas.config | 4 ++++ 3 files changed, 10 insertions(+) diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index d808e46df9..c921650c0c 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -289,6 +289,10 @@ "type": "null" } ] + }, + "metadata": { + "type": "object", + "additionalProperties": {} } }, "required": [ diff --git a/go/ai/gen.go b/go/ai/gen.go index 963b8cd737..41a7eb1c56 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -214,6 +214,8 @@ type MiddlewareDesc struct { ConfigSchema map[string]any `json:"configSchema,omitempty"` // Description explains what the middleware does. Description string `json:"description,omitempty"` + // Metadata contains additional context for the middleware. + Metadata map[string]any `json:"metadata,omitempty"` // Name is the middleware's unique identifier. Name string `json:"name,omitempty"` configFromJSON middlewareConfigFunc diff --git a/go/core/schemas.config b/go/core/schemas.config index 2fe8cc6d54..68ca942f64 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -864,6 +864,10 @@ MiddlewareDesc.configSchema doc ConfigSchema is a JSON Schema describing the middleware's configuration. . +MiddlewareDesc.metadata doc +Metadata contains additional context for the middleware. +. + MiddlewareRef doc MiddlewareRef is a serializable reference to a registered middleware with config. . From aa0d2e294d6298e084245d8969697f5aacf1111f Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 9 Feb 2026 10:41:15 -0800 Subject: [PATCH 10/23] Update typing.py --- py/packages/genkit/src/genkit/core/typing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index f67b0e91e2..85d311dd2e 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -127,6 +127,7 @@ class MiddlewareDesc(BaseModel): name: str description: str | None = None config_schema: dict[str, Any] | None = Field(default=None) + metadata: dict[str, Any] | None = None class MiddlewareRef(BaseModel): From 690fe39bde649cd9314714f4c97f10402a1e22ad Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 07:10:36 -0800 Subject: [PATCH 11/23] added new `Middleware` --- genkit-tools/common/src/types/index.ts | 1 + genkit-tools/common/src/types/middleware.ts | 37 +++ genkit-tools/common/src/types/model.ts | 3 + go/ai/gen.go | 2 + go/ai/generate.go | 125 ++++++++++- go/ai/middleware.go | 155 +++++++++++++ go/ai/middleware_test.go | 237 ++++++++++++++++++++ go/ai/option.go | 18 +- go/ai/prompt.go | 22 ++ go/genkit/genkit.go | 10 + go/genkit/reflection.go | 22 ++ 11 files changed, 628 insertions(+), 4 deletions(-) create mode 100644 genkit-tools/common/src/types/middleware.ts create mode 100644 go/ai/middleware.go create mode 100644 go/ai/middleware_test.go diff --git a/genkit-tools/common/src/types/index.ts b/genkit-tools/common/src/types/index.ts index ea12971f0e..360546af0e 100644 --- a/genkit-tools/common/src/types/index.ts +++ b/genkit-tools/common/src/types/index.ts @@ -23,6 +23,7 @@ export * from './document'; export * from './env'; export * from './eval'; export * from './evaluator'; +export * from './middleware'; export * from './model'; export * from './prompt'; export * from './retriever'; diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts new file mode 100644 index 0000000000..7f41af991e --- /dev/null +++ b/genkit-tools/common/src/types/middleware.ts @@ -0,0 +1,37 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { z } from 'zod'; +import { JSONSchema7Schema } from './action'; + +/** Descriptor for a registered middleware, returned by reflection API. */ +export const MiddlewareDescSchema = z.object({ + /** Unique name of the middleware. */ + name: z.string(), + /** Human-readable description of what the middleware does. */ + description: z.string().optional(), + /** JSON Schema for the middleware's configuration. */ + configSchema: JSONSchema7Schema.optional(), +}); +export type MiddlewareDesc = z.infer; + +/** Reference to a registered middleware with optional configuration. */ +export const MiddlewareRefSchema = z.object({ + /** Name of the registered middleware. */ + name: z.string(), + /** Configuration for the middleware (schema defined by the middleware). */ + config: z.any().optional(), +}); +export type MiddlewareRef = z.infer; diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index a36d9f288f..62fa83dedb 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -15,6 +15,7 @@ */ import { z } from 'zod'; import { DocumentDataSchema } from './document'; +import { MiddlewareRefSchema } from './middleware'; import { CustomPartSchema, DataPartSchema, @@ -399,5 +400,7 @@ export const GenerateActionOptionsSchema = z.object({ maxTurns: z.number().optional(), /** Custom step name for this generate call to display in trace views. Defaults to "generate". */ stepName: z.string().optional(), + /** Middleware to apply to this generation. */ + use: z.array(MiddlewareRefSchema).optional(), }); export type GenerateActionOptions = z.infer; diff --git a/go/ai/gen.go b/go/ai/gen.go index e391ef2215..6f6ee06a19 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -96,6 +96,8 @@ type GenerateActionOptions struct { ToolChoice ToolChoice `json:"toolChoice,omitempty"` // Tools is a list of registered tool names for this generation if supported. Tools []string `json:"tools,omitempty"` + // Use is middleware to apply to this generation, referenced by name with optional config. + Use []*MiddlewareRef `json:"use,omitempty"` } // GenerateActionResume holds options for resuming an interrupted generation. diff --git a/go/ai/generate.go b/go/ai/generate.go index 003eb0b653..64aad71d33 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -67,6 +67,8 @@ type ModelFunc = core.StreamingFunc[*ModelRequest, *ModelResponse, *ModelRespons type ModelStreamCallback = func(context.Context, *ModelResponseChunk) error // ModelMiddleware is middleware for model generate requests that takes in a ModelFunc, does something, then returns another ModelFunc. +// +// Deprecated: Use [Middleware] interface with [WithUse] instead, which supports Generate, Model, and Tool hooks. type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk] // model is an action with functions specific to model generation such as Generate(). @@ -313,6 +315,27 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } + // Resolve middleware from Use refs. + var middlewareHandlers []Middleware + if len(opts.Use) > 0 { + middlewareHandlers = make([]Middleware, 0, len(opts.Use)) + for _, ref := range opts.Use { + desc := LookupMiddleware(r, ref.Name) + if desc == nil { + return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: middleware %q not found", ref.Name) + } + configJSON, err := json.Marshal(ref.Config) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai.GenerateWithRequest: failed to marshal config for middleware %q: %v", ref.Name, err) + } + handler, err := desc.configFromJSON(configJSON) + if err != nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: failed to create middleware %q: %v", ref.Name, err) + } + middlewareHandlers = append(middlewareHandlers, handler) + } + } + fn := m.Generate if bm != nil { if cb != nil { @@ -320,6 +343,24 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } fn = backgroundModelToModelFn(bm.Start) } + + // Apply Model hooks from new middleware as a ModelMiddleware, then chain with legacy mw. + if len(middlewareHandlers) > 0 { + modelHook := func(next ModelFunc) ModelFunc { + wrapped := next + for i := len(middlewareHandlers) - 1; i >= 0; i-- { + h := middlewareHandlers[i] + inner := wrapped + wrapped = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return h.Model(ctx, &ModelState{Request: req, Callback: cb}, func(ctx context.Context, state *ModelState) (*ModelResponse, error) { + return inner(ctx, state.Request, state.Callback) + }) + } + } + return wrapped + } + mw = append([]ModelMiddleware{modelHook}, mw...) + } fn = core.ChainMiddleware(mw...)(fn) // Inline recursive helper function that captures variables from parent scope. @@ -388,7 +429,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.ABORTED, "exceeded maximum tool call iterations (%d)", maxTurns) } - newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex) + newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, middlewareHandlers) if err != nil { return nil, err } @@ -406,6 +447,28 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi }) } + // Wrap generate with the Generate hook chain from middleware. + if len(middlewareHandlers) > 0 { + innerGenerate := generate + generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { + innerFn := func(ctx context.Context, state *GenerateState) (*ModelResponse, error) { + return innerGenerate(ctx, state.Request, currentTurn, messageIndex) + } + for i := len(middlewareHandlers) - 1; i >= 0; i-- { + h := middlewareHandlers[i] + next := innerFn + innerFn = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) { + return h.Generate(ctx, state, next) + } + } + return innerFn(ctx, &GenerateState{ + Options: opts, + Request: req, + Iteration: currentTurn, + }) + } + } + return generate(ctx, req, 0, 0) } @@ -535,6 +598,28 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } + // Register dynamic middleware (like dynamic tools) and build MiddlewareRefs. + if len(genOpts.Use) > 0 { + for _, mw := range genOpts.Use { + name := mw.Name() + if LookupMiddleware(r, name) == nil { + if !r.IsChild() { + r = r.NewChild() + } + NewMiddleware("", mw).Register(r) + } + configJSON, err := json.Marshal(mw) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai.Generate: failed to marshal middleware %q config: %v", name, err) + } + var config any + if err := json.Unmarshal(configJSON, &config); err != nil { + return nil, core.NewError(core.INTERNAL, "ai.Generate: failed to unmarshal middleware %q config: %v", name, err) + } + actionOpts.Use = append(actionOpts.Use, &MiddlewareRef{Name: name, Config: config}) + } + } + // Process resources in messages processedMessages, err := processResources(ctx, r, messages) if err != nil { @@ -773,7 +858,7 @@ func clone[T any](obj *T) *T { // handleToolRequests processes any tool requests in the response, returning // either a new request to continue the conversation or nil if no tool requests // need handling. -func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int) (*ModelRequest, *Message, error) { +func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int, middlewareHandlers []Middleware) (*ModelRequest, *Message, error) { toolCount := len(resp.ToolRequests()) if toolCount == 0 { return nil, nil, nil @@ -796,7 +881,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return } - multipartResp, err := tool.RunRawMultipart(ctx, toolReq.Input) + multipartResp, err := runToolWithMiddleware(ctx, tool, toolReq, middlewareHandlers) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { @@ -879,6 +964,39 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return newReq, nil, nil } +// runToolWithMiddleware runs a tool, wrapping the execution with Tool hooks from middleware. +func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, handlers []Middleware) (*MultipartToolResponse, error) { + if len(handlers) == 0 { + return tool.RunRawMultipart(ctx, toolReq.Input) + } + + inner := func(ctx context.Context, state *ToolState) (*ToolResponse, error) { + resp, err := state.Tool.RunRawMultipart(ctx, state.Request.Input) + if err != nil { + return nil, err + } + return &ToolResponse{ + Name: state.Request.Name, + Output: resp.Output, + }, nil + } + + for i := len(handlers) - 1; i >= 0; i-- { + h := handlers[i] + next := inner + inner = func(ctx context.Context, state *ToolState) (*ToolResponse, error) { + return h.Tool(ctx, state, next) + } + } + + toolResp, err := inner(ctx, &ToolState{Request: toolReq, Tool: tool}) + if err != nil { + return nil, err + } + + return &MultipartToolResponse{Output: toolResp.Output}, nil +} + // Text returns the contents of the first candidate in a // [ModelResponse] as a string. It returns an empty string if there // are no candidates or if the candidate has no message. @@ -1357,6 +1475,7 @@ func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateAc Docs: genOpts.Docs, ReturnToolRequests: genOpts.ReturnToolRequests, Output: genOpts.Output, + Use: genOpts.Use, }, toolMessage: toolMessage, }, nil diff --git a/go/ai/middleware.go b/go/ai/middleware.go new file mode 100644 index 0000000000..71d5b93d1a --- /dev/null +++ b/go/ai/middleware.go @@ -0,0 +1,155 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" +) + +// Middleware provides hooks for different stages of generation. +type Middleware interface { + // Name returns the middleware's unique identifier. + Name() string + // New returns a fresh instance for each ai.Generate() call, enabling per-invocation state. + New() Middleware + // Generate wraps each iteration of the tool loop. + Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) + // Model wraps each model API call. + Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) + // Tool wraps each tool execution. + Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) +} + +// GenerateState holds state for the Generate hook. +type GenerateState struct { + // Options is the original options passed to [Generate]. + Options *GenerateActionOptions + // Request is the current model request for this iteration, with accumulated messages. + Request *ModelRequest + // Iteration is the current tool-loop iteration (0-indexed). + Iteration int +} + +// ModelState holds state for the Model hook. +type ModelState struct { + // Request is the model request about to be sent. + Request *ModelRequest + // Callback is the streaming callback, or nil if not streaming. + Callback ModelStreamCallback +} + +// ToolState holds state for the Tool hook. +type ToolState struct { + // Request is the tool request about to be executed. + Request *ToolRequest + // Tool is the resolved tool being called. + Tool Tool +} + +// GenerateNext is the next function in the Generate hook chain. +type GenerateNext = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) + +// ModelNext is the next function in the Model hook chain. +type ModelNext = func(ctx context.Context, state *ModelState) (*ModelResponse, error) + +// ToolNext is the next function in the Tool hook chain. +type ToolNext = func(ctx context.Context, state *ToolState) (*ToolResponse, error) + +// BaseMiddleware provides default pass-through for the three hooks. +// Embed this so you only need to implement Name() and New(). +type BaseMiddleware struct{} + +func (b *BaseMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { + return next(ctx, state) +} + +func (b *BaseMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + return next(ctx, state) +} + +func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { + return next(ctx, state) +} + +// MiddlewareDesc is the registered descriptor for a middleware. +type MiddlewareDesc struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + ConfigSchema map[string]any `json:"configSchema,omitempty"` + configFromJSON func([]byte) (Middleware, error) +} + +// Register registers the descriptor with the registry. +func (d *MiddlewareDesc) Register(r api.Registry) { + r.RegisterValue("/middleware/"+d.Name, d) +} + +// NewMiddleware creates a middleware descriptor without registering it. +// The prototype carries stable state; configFromJSON calls prototype.New() +// then unmarshals user config on top. +func NewMiddleware[T Middleware](description string, prototype T) *MiddlewareDesc { + return &MiddlewareDesc{ + Name: prototype.Name(), + Description: description, + ConfigSchema: core.InferSchemaMap(*new(T)), + configFromJSON: func(configJSON []byte) (Middleware, error) { + inst := prototype.New() + if len(configJSON) > 0 { + if err := json.Unmarshal(configJSON, inst); err != nil { + return nil, fmt.Errorf("middleware %q: %w", prototype.Name(), err) + } + } + return inst, nil + }, + } +} + +// DefineMiddleware creates and registers a middleware descriptor. +func DefineMiddleware[T Middleware](r api.Registry, description string, prototype T) *MiddlewareDesc { + d := NewMiddleware(description, prototype) + d.Register(r) + return d +} + +// LookupMiddleware looks up a registered middleware descriptor by name. +func LookupMiddleware(r api.Registry, name string) *MiddlewareDesc { + v := r.LookupValue("/middleware/" + name) + if v == nil { + return nil + } + d, ok := v.(*MiddlewareDesc) + if !ok { + return nil + } + return d +} + +// MiddlewareRef is a serializable reference to a registered middleware with config. +type MiddlewareRef struct { + Name string `json:"name"` + Config any `json:"config,omitempty"` +} + +// MiddlewarePlugin is implemented by plugins that provide middleware. +type MiddlewarePlugin interface { + ListMiddleware(ctx context.Context) ([]*MiddlewareDesc, error) +} diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go new file mode 100644 index 0000000000..0613e3e63e --- /dev/null +++ b/go/ai/middleware_test.go @@ -0,0 +1,237 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "sync/atomic" + "testing" +) + +// testMiddleware is a simple middleware for testing that tracks hook invocations. +type testMiddleware struct { + BaseMiddleware + Label string `json:"label"` + generateCalls int + modelCalls int + toolCalls int32 // atomic since tool hooks run in parallel +} + +func (m *testMiddleware) Name() string { return "test" } + +func (m *testMiddleware) New() Middleware { + return &testMiddleware{Label: m.Label} +} + +func (m *testMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { + m.generateCalls++ + return next(ctx, state) +} + +func (m *testMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + m.modelCalls++ + return next(ctx, state) +} + +func (m *testMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { + atomic.AddInt32(&m.toolCalls, 1) + return next(ctx, state) +} + +func TestNewMiddleware(t *testing.T) { + proto := &testMiddleware{Label: "original"} + desc := NewMiddleware("test middleware", proto) + + if desc.Name != "test" { + t.Errorf("got name %q, want %q", desc.Name, "test") + } + if desc.Description != "test middleware" { + t.Errorf("got description %q, want %q", desc.Description, "test middleware") + } +} + +func TestDefineAndLookupMiddleware(t *testing.T) { + r := newTestRegistry(t) + proto := &testMiddleware{Label: "original"} + DefineMiddleware(r, "test middleware", proto) + + found := LookupMiddleware(r, "test") + if found == nil { + t.Fatal("expected to find middleware, got nil") + } + if found.Name != "test" { + t.Errorf("got name %q, want %q", found.Name, "test") + } +} + +func TestLookupMiddlewareNotFound(t *testing.T) { + r := newTestRegistry(t) + found := LookupMiddleware(r, "nonexistent") + if found != nil { + t.Errorf("expected nil, got %v", found) + } +} + +func TestConfigFromJSON(t *testing.T) { + proto := &testMiddleware{Label: "stable"} + desc := NewMiddleware("test middleware", proto) + + handler, err := desc.configFromJSON([]byte(`{"label": "custom"}`)) + if err != nil { + t.Fatalf("configFromJSON failed: %v", err) + } + + tm, ok := handler.(*testMiddleware) + if !ok { + t.Fatalf("expected *testMiddleware, got %T", handler) + } + if tm.Label != "custom" { + t.Errorf("got label %q, want %q", tm.Label, "custom") + } + // Per-request state should be zeroed by New() + if tm.generateCalls != 0 { + t.Errorf("got generateCalls %d, want 0", tm.generateCalls) + } +} + +func TestConfigFromJSONPreservesStableState(t *testing.T) { + // Simulate a plugin middleware with unexported stable state + proto := &stableStateMiddleware{apiKey: "secret123"} + desc := NewMiddleware("middleware with stable state", proto) + + handler, err := desc.configFromJSON([]byte(`{"sampleRate": 0.5}`)) + if err != nil { + t.Fatalf("configFromJSON failed: %v", err) + } + + sm, ok := handler.(*stableStateMiddleware) + if !ok { + t.Fatalf("expected *stableStateMiddleware, got %T", handler) + } + if sm.apiKey != "secret123" { + t.Errorf("got apiKey %q, want %q", sm.apiKey, "secret123") + } + if sm.SampleRate != 0.5 { + t.Errorf("got SampleRate %f, want 0.5", sm.SampleRate) + } +} + +func TestMiddlewareModelHook(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + DefineMiddleware(r, "tracks calls", &testMiddleware{}) + + resp, err := Generate(ctx, r, + WithModel(m), + WithPrompt("hello"), + WithUse(&testMiddleware{}), + ) + assertNoError(t, err) + if resp == nil { + t.Fatal("expected response, got nil") + } +} + +func TestMiddlewareToolHook(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolModel", + handler: toolCallingModelHandler("myTool", map[string]any{"value": "test"}, "done"), + }) + defineFakeTool(t, r, "myTool", "A test tool") + + mw := &testMiddleware{} + DefineMiddleware(r, "tracks calls", mw) + + _, err := Generate(ctx, r, + WithModelName("test/toolModel"), + WithPrompt("use the tool"), + WithTools(ToolName("myTool")), + WithUse(&testMiddleware{}), + ) + assertNoError(t, err) +} + +func TestMiddlewareOrdering(t *testing.T) { + // First middleware is outermost + var order []string + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + mwA := &orderMiddleware{label: "A", order: &order} + mwB := &orderMiddleware{label: "B", order: &order} + DefineMiddleware(r, "middleware A", mwA) + DefineMiddleware(r, "middleware B", mwB) + + _, err := Generate(ctx, r, + WithModel(m), + WithPrompt("hello"), + WithUse( + &orderMiddleware{label: "A", order: &order}, + &orderMiddleware{label: "B", order: &order}, + ), + ) + assertNoError(t, err) + + // Expect: A-before, B-before, B-after, A-after (first is outermost) + want := []string{"A-model-before", "B-model-before", "B-model-after", "A-model-after"} + if len(order) != len(want) { + t.Fatalf("got order %v, want %v", order, want) + } + for i := range want { + if order[i] != want[i] { + t.Errorf("order[%d] = %q, want %q", i, order[i], want[i]) + } + } +} + +// --- helper middleware types for tests --- + +// stableStateMiddleware has unexported stable state preserved by New(). +type stableStateMiddleware struct { + BaseMiddleware + SampleRate float64 `json:"sampleRate"` + apiKey string +} + +func (m *stableStateMiddleware) Name() string { return "stableState" } + +func (m *stableStateMiddleware) New() Middleware { + return &stableStateMiddleware{apiKey: m.apiKey} +} + +// orderMiddleware tracks the order of Model hook invocations. +type orderMiddleware struct { + BaseMiddleware + label string + order *[]string +} + +func (m *orderMiddleware) Name() string { return "order-" + m.label } + +func (m *orderMiddleware) New() Middleware { + return &orderMiddleware{label: m.label, order: m.order} +} + +func (m *orderMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + *m.order = append(*m.order, m.label+"-model-before") + resp, err := next(ctx, state) + *m.order = append(*m.order, m.label+"-model-after") + return resp, err +} + +var ctx = context.Background() diff --git a/go/ai/option.go b/go/ai/option.go index d28c68e3e9..84019b11d7 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -109,7 +109,8 @@ type commonGenOptions struct { ToolChoice ToolChoice // Whether tool calls are required, disabled, or optional. MaxTurns int // Maximum number of tool call iterations. ReturnToolRequests *bool // Whether to return tool requests instead of making the tool calls and continuing the generation. - Middleware []ModelMiddleware // Middleware to apply to the model request and model response. + Middleware []ModelMiddleware // Deprecated: Use WithUse instead. Middleware to apply to the model request and model response. + Use []Middleware // Middleware to apply to generation (Generate, Model, and Tool hooks). } type CommonGenOption interface { @@ -181,6 +182,13 @@ func (o *commonGenOptions) applyCommonGen(opts *commonGenOptions) error { opts.Middleware = o.Middleware } + if o.Use != nil { + if opts.Use != nil { + return errors.New("cannot set middleware more than once (WithUse)") + } + opts.Use = o.Use + } + return nil } @@ -233,10 +241,18 @@ func WithModelName(name string) CommonGenOption { } // WithMiddleware sets middleware to apply to the model request. +// +// Deprecated: Use [WithUse] instead, which supports Generate, Model, and Tool hooks. func WithMiddleware(middleware ...ModelMiddleware) CommonGenOption { return &commonGenOptions{Middleware: middleware} } +// WithUse sets middleware to apply to generation. Middleware hooks wrap +// the generate loop, model calls, and tool executions. +func WithUse(middleware ...Middleware) CommonGenOption { + return &commonGenOptions{Use: middleware} +} + // WithMaxTurns sets the maximum number of tool call iterations before erroring. // A tool call happens when tools are provided in the request and a model decides to call one or more as a response. // Each round trip, including multiple tools in parallel, counts as one turn. diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 4d0151c4c8..88c36e0cd7 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -249,6 +249,28 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod } } + // Register dynamic middleware and build MiddlewareRefs. + if len(execOpts.Use) > 0 { + for _, mw := range execOpts.Use { + name := mw.Name() + if LookupMiddleware(r, name) == nil { + if !r.IsChild() { + r = r.NewChild() + } + NewMiddleware("", mw).Register(r) + } + configJSON, err := json.Marshal(mw) + if err != nil { + return nil, fmt.Errorf("Prompt.Execute: failed to marshal middleware %q config: %w", name, err) + } + var config any + if err := json.Unmarshal(configJSON, &config); err != nil { + return nil, fmt.Errorf("Prompt.Execute: failed to unmarshal middleware %q config: %w", name, err) + } + actionOpts.Use = append(actionOpts.Use, &MiddlewareRef{Name: name, Config: config}) + } + } + return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream) } diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 377fb5e836..8fd32913c2 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -228,6 +228,16 @@ func Init(ctx context.Context, opts ...GenkitOption) *Genkit { action.Register(r) } r.RegisterPlugin(plugin.Name(), plugin) + + if mp, ok := plugin.(ai.MiddlewarePlugin); ok { + descs, err := mp.ListMiddleware(ctx) + if err != nil { + panic(fmt.Errorf("genkit.Init: plugin %q ListMiddleware failed: %w", plugin.Name(), err)) + } + for _, d := range descs { + d.Register(r) + } + } } ai.ConfigureFormats(r) diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index 1bd675f75a..9936936e61 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -303,6 +303,7 @@ func serveMux(g *Genkit, s *reflectionServer) *http.ServeMux { mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g, s.activeActions))) mux.HandleFunc("POST /api/notify", wrapReflectionHandler(handleNotify())) mux.HandleFunc("POST /api/cancelAction", wrapReflectionHandler(handleCancelAction(s.activeActions))) + mux.HandleFunc("GET /api/values", wrapReflectionHandler(handleListValues(g))) return mux } @@ -598,6 +599,27 @@ func handleListActions(g *Genkit) func(w http.ResponseWriter, r *http.Request) e } } +// handleListValues returns registered values filtered by type query parameter. +// Matches JS: GET /api/values?type=middleware +func handleListValues(g *Genkit) func(w http.ResponseWriter, r *http.Request) error { + return func(w http.ResponseWriter, r *http.Request) error { + valueType := r.URL.Query().Get("type") + if valueType == "" { + http.Error(w, `query parameter "type" is required`, http.StatusBadRequest) + return nil + } + prefix := "/" + valueType + "/" + result := map[string]any{} + for key, val := range g.reg.ListValues() { + if strings.HasPrefix(key, prefix) { + name := strings.TrimPrefix(key, prefix) + result[name] = val + } + } + return writeJSON(r.Context(), w, result) + } +} + // listActions lists all the registered actions. func listActions(g *Genkit) []api.ActionDesc { ads := []api.ActionDesc{} From a6e3b549e15f29fb4c1cc9da613bf95f50ddaa94 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 07:49:54 -0800 Subject: [PATCH 12/23] updated Genkit schema --- genkit-tools/genkit-schema.json | 48 +++++++++++++++++++++++++ genkit-tools/scripts/schema-exporter.ts | 1 + 2 files changed, 49 insertions(+) diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 26cc4fbf4f..2cbd939736 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -270,6 +270,48 @@ ], "additionalProperties": false }, + "MiddlewareDesc": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "configSchema": { + "anyOf": [ + { + "type": "object", + "properties": {}, + "additionalProperties": false, + "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." + }, + { + "type": "null" + } + ], + "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." + } + }, + "required": [ + "name" + ], + "additionalProperties": false + }, + "MiddlewareRef": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "config": {} + }, + "required": [ + "name" + ], + "additionalProperties": false + }, "CandidateError": { "type": "object", "properties": { @@ -466,6 +508,12 @@ }, "stepName": { "type": "string" + }, + "use": { + "type": "array", + "items": { + "$ref": "#/$defs/MiddlewareRef" + } } }, "required": [ diff --git a/genkit-tools/scripts/schema-exporter.ts b/genkit-tools/scripts/schema-exporter.ts index 48df79b56a..7462a12a8d 100644 --- a/genkit-tools/scripts/schema-exporter.ts +++ b/genkit-tools/scripts/schema-exporter.ts @@ -26,6 +26,7 @@ const EXPORTED_TYPE_MODULES = [ '../common/src/types/embedder.ts', '../common/src/types/evaluator.ts', '../common/src/types/error.ts', + '../common/src/types/middleware.ts', '../common/src/types/model.ts', '../common/src/types/parts.ts', '../common/src/types/reranker.ts', From 2467e333a8e30d033e458129753639c32fea4b66 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:02:14 -0800 Subject: [PATCH 13/23] updated common schema --- genkit-tools/common/src/types/middleware.ts | 3 +- genkit-tools/genkit-schema.json | 7 +--- go/ai/gen.go | 19 +++++++++ go/ai/middleware.go | 17 ++------ go/core/schemas.config | 43 +++++++++++++++++++++ 5 files changed, 68 insertions(+), 21 deletions(-) diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts index 7f41af991e..4bb1297ede 100644 --- a/genkit-tools/common/src/types/middleware.ts +++ b/genkit-tools/common/src/types/middleware.ts @@ -14,7 +14,6 @@ * limitations under the License. */ import { z } from 'zod'; -import { JSONSchema7Schema } from './action'; /** Descriptor for a registered middleware, returned by reflection API. */ export const MiddlewareDescSchema = z.object({ @@ -23,7 +22,7 @@ export const MiddlewareDescSchema = z.object({ /** Human-readable description of what the middleware does. */ description: z.string().optional(), /** JSON Schema for the middleware's configuration. */ - configSchema: JSONSchema7Schema.optional(), + configSchema: z.record(z.any()).nullish(), }); export type MiddlewareDesc = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 2cbd939736..d808e46df9 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -283,15 +283,12 @@ "anyOf": [ { "type": "object", - "properties": {}, - "additionalProperties": false, - "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." + "additionalProperties": {} }, { "type": "null" } - ], - "description": "A JSON Schema Draft 7 (http://json-schema.org/draft-07/schema) object." + ] } }, "required": [ diff --git a/go/ai/gen.go b/go/ai/gen.go index 6f6ee06a19..963b8cd737 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -208,6 +208,25 @@ type Message struct { Role Role `json:"role,omitempty"` } +// MiddlewareDesc is the registered descriptor for a middleware. +type MiddlewareDesc struct { + // ConfigSchema is a JSON Schema describing the middleware's configuration. + ConfigSchema map[string]any `json:"configSchema,omitempty"` + // Description explains what the middleware does. + Description string `json:"description,omitempty"` + // Name is the middleware's unique identifier. + Name string `json:"name,omitempty"` + configFromJSON middlewareConfigFunc +} + +// MiddlewareRef is a serializable reference to a registered middleware with config. +type MiddlewareRef struct { + // Config contains the middleware configuration. + Config any `json:"config,omitempty"` + // Name is the name of the registered middleware. + Name string `json:"name,omitempty"` +} + // ModelInfo contains metadata about a model's capabilities and characteristics. type ModelInfo struct { // ConfigSchema defines the model-specific configuration schema. diff --git a/go/ai/middleware.go b/go/ai/middleware.go index 71d5b93d1a..35b2faf37f 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -25,6 +25,9 @@ import ( "github.com/firebase/genkit/go/core/api" ) +// middlewareConfigFunc creates a Middleware instance from JSON config. +type middlewareConfigFunc = func([]byte) (Middleware, error) + // Middleware provides hooks for different stages of generation. type Middleware interface { // Name returns the middleware's unique identifier. @@ -90,14 +93,6 @@ func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNe return next(ctx, state) } -// MiddlewareDesc is the registered descriptor for a middleware. -type MiddlewareDesc struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - ConfigSchema map[string]any `json:"configSchema,omitempty"` - configFromJSON func([]byte) (Middleware, error) -} - // Register registers the descriptor with the registry. func (d *MiddlewareDesc) Register(r api.Registry) { r.RegisterValue("/middleware/"+d.Name, d) @@ -143,12 +138,6 @@ func LookupMiddleware(r api.Registry, name string) *MiddlewareDesc { return d } -// MiddlewareRef is a serializable reference to a registered middleware with config. -type MiddlewareRef struct { - Name string `json:"name"` - Config any `json:"config,omitempty"` -} - // MiddlewarePlugin is implemented by plugins that provide middleware. type MiddlewarePlugin interface { ListMiddleware(ctx context.Context) ([]*MiddlewareDesc, error) diff --git a/go/core/schemas.config b/go/core/schemas.config index 70798f2eb3..2fe8cc6d54 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -732,6 +732,10 @@ StepName is a custom step name for this generate call to display in trace views. Defaults to "generate". . +GenerateActionOptions.use doc +Use is middleware to apply to this generation, referenced by name with optional config. +. + GenerateActionOptionsResume doc GenerateActionResume holds options for resuming an interrupted generation. . @@ -840,6 +844,38 @@ PathMetadata.error doc Error contains error information if the path failed. . +# ---------------------------------------------------------------------------- +# Middleware Types +# ---------------------------------------------------------------------------- + +MiddlewareDesc doc +MiddlewareDesc is the registered descriptor for a middleware. +. + +MiddlewareDesc.name doc +Name is the middleware's unique identifier. +. + +MiddlewareDesc.description doc +Description explains what the middleware does. +. + +MiddlewareDesc.configSchema doc +ConfigSchema is a JSON Schema describing the middleware's configuration. +. + +MiddlewareRef doc +MiddlewareRef is a serializable reference to a registered middleware with config. +. + +MiddlewareRef.name doc +Name is the name of the registered middleware. +. + +MiddlewareRef.config doc +Config contains the middleware configuration. +. + # ---------------------------------------------------------------------------- # Multipart Tool Response # ---------------------------------------------------------------------------- @@ -1060,6 +1096,7 @@ GenerateActionOptions.config type any GenerateActionOptions.output type *GenerateActionOutputConfig GenerateActionOptions.returnToolRequests type bool GenerateActionOptions.maxTurns type int +GenerateActionOptions.use type []*MiddlewareRef GenerateActionOptionsResume name GenerateActionResume # GenerateActionOutputConfig @@ -1101,6 +1138,12 @@ ModelResponseChunk.index type int ModelResponseChunk.role type Role ModelResponseChunk field formatHandler StreamingFormatHandler +# Middleware +MiddlewareDesc pkg ai +MiddlewareDesc.configSchema type map[string]any +MiddlewareDesc field configFromJSON middlewareConfigFunc +MiddlewareRef pkg ai + Score omit Embedding.embedding type []float32 From 1dfc1c899a6c004559a29d8b9210718e92e9baf3 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:03:47 -0800 Subject: [PATCH 14/23] Update generate.go --- go/ai/generate.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 64aad71d33..70d53e8da5 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -976,8 +976,9 @@ func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, return nil, err } return &ToolResponse{ - Name: state.Request.Name, - Output: resp.Output, + Name: state.Request.Name, + Output: resp.Output, + Content: resp.Content, }, nil } @@ -994,7 +995,7 @@ func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, return nil, err } - return &MultipartToolResponse{Output: toolResp.Output}, nil + return &MultipartToolResponse{Output: toolResp.Output, Content: toolResp.Content}, nil } // Text returns the contents of the first candidate in a From e1e17215f7d71709009adc24916b65f3ab27f0ed Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:04:21 -0800 Subject: [PATCH 15/23] Update middleware_test.go --- go/ai/middleware_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go index 0613e3e63e..a0f9f935ea 100644 --- a/go/ai/middleware_test.go +++ b/go/ai/middleware_test.go @@ -25,7 +25,7 @@ import ( // testMiddleware is a simple middleware for testing that tracks hook invocations. type testMiddleware struct { BaseMiddleware - Label string `json:"label"` + Label string `json:"label"` generateCalls int modelCalls int toolCalls int32 // atomic since tool hooks run in parallel @@ -149,7 +149,7 @@ func TestMiddlewareModelHook(t *testing.T) { func TestMiddlewareToolHook(t *testing.T) { r := newTestRegistry(t) defineFakeModel(t, r, fakeModelConfig{ - name: "test/toolModel", + name: "test/toolModel", handler: toolCallingModelHandler("myTool", map[string]any{"value": "test"}, "done"), }) defineFakeTool(t, r, "myTool", "A test tool") From 53eb2f754b306fa8686f9eba93b1a195995761db Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:06:30 -0800 Subject: [PATCH 16/23] Update typing.py --- py/packages/genkit/src/genkit/core/typing.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index 02f5927450..f67b0e91e2 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -120,6 +120,23 @@ class GenkitError(BaseModel): data: Data | None = None +class MiddlewareDesc(BaseModel): + """Model for middlewaredesc data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str + description: str | None = None + config_schema: dict[str, Any] | None = Field(default=None) + + +class MiddlewareRef(BaseModel): + """Model for middlewareref data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str + config: Any | None = None + + class Code(StrEnum): """Code data type class.""" @@ -1002,6 +1019,7 @@ class GenerateActionOptions(BaseModel): return_tool_requests: bool | None = Field(default=None) max_turns: float | None = Field(default=None) step_name: str | None = Field(default=None) + use: list[MiddlewareRef] | None = None class GenerateRequest(BaseModel): From 6e467e47c91b80f05b1ebb862866fe02bbdc57db Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 6 Feb 2026 08:26:52 -0800 Subject: [PATCH 17/23] fixes --- go/ai/generate.go | 6 +----- go/ai/prompt.go | 3 +-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 70d53e8da5..c6b2a8066a 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -315,7 +315,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } - // Resolve middleware from Use refs. var middlewareHandlers []Middleware if len(opts.Use) > 0 { middlewareHandlers = make([]Middleware, 0, len(opts.Use)) @@ -344,7 +343,6 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi fn = backgroundModelToModelFn(bm.Start) } - // Apply Model hooks from new middleware as a ModelMiddleware, then chain with legacy mw. if len(middlewareHandlers) > 0 { modelHook := func(next ModelFunc) ModelFunc { wrapped := next @@ -598,7 +596,6 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } - // Register dynamic middleware (like dynamic tools) and build MiddlewareRefs. if len(genOpts.Use) > 0 { for _, mw := range genOpts.Use { name := mw.Name() @@ -606,7 +603,7 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod if !r.IsChild() { r = r.NewChild() } - NewMiddleware("", mw).Register(r) + DefineMiddleware(r, "", mw) } configJSON, err := json.Marshal(mw) if err != nil { @@ -620,7 +617,6 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } - // Process resources in messages processedMessages, err := processResources(ctx, r, messages) if err != nil { return nil, core.NewError(core.INTERNAL, "ai.Generate: error processing resources: %v", err) diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 88c36e0cd7..9e4dff9f14 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -249,7 +249,6 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod } } - // Register dynamic middleware and build MiddlewareRefs. if len(execOpts.Use) > 0 { for _, mw := range execOpts.Use { name := mw.Name() @@ -257,7 +256,7 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod if !r.IsChild() { r = r.NewChild() } - NewMiddleware("", mw).Register(r) + DefineMiddleware(r, "", mw) } configJSON, err := json.Marshal(mw) if err != nil { From f98d60e4c73bbb804c5e3fa7f8e04029bfc4bab6 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 9 Feb 2026 09:53:04 -0800 Subject: [PATCH 18/23] Update genkit-tools/common/src/types/middleware.ts Co-authored-by: Pavel Jbanov --- genkit-tools/common/src/types/middleware.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts index 4bb1297ede..6fd2dd9810 100644 --- a/genkit-tools/common/src/types/middleware.ts +++ b/genkit-tools/common/src/types/middleware.ts @@ -23,6 +23,8 @@ export const MiddlewareDescSchema = z.object({ description: z.string().optional(), /** JSON Schema for the middleware's configuration. */ configSchema: z.record(z.any()).nullish(), + /** User defined metadata for the middleware. */ + metadata: z.record(z.any()).optional(), }); export type MiddlewareDesc = z.infer; From 3d19238818e597e83a1d7169121bc6aa80ebc4a2 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 9 Feb 2026 10:00:40 -0800 Subject: [PATCH 19/23] added new fields --- genkit-tools/genkit-schema.json | 4 ++++ go/ai/gen.go | 2 ++ go/core/schemas.config | 4 ++++ 3 files changed, 10 insertions(+) diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index d808e46df9..c921650c0c 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -289,6 +289,10 @@ "type": "null" } ] + }, + "metadata": { + "type": "object", + "additionalProperties": {} } }, "required": [ diff --git a/go/ai/gen.go b/go/ai/gen.go index 963b8cd737..41a7eb1c56 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -214,6 +214,8 @@ type MiddlewareDesc struct { ConfigSchema map[string]any `json:"configSchema,omitempty"` // Description explains what the middleware does. Description string `json:"description,omitempty"` + // Metadata contains additional context for the middleware. + Metadata map[string]any `json:"metadata,omitempty"` // Name is the middleware's unique identifier. Name string `json:"name,omitempty"` configFromJSON middlewareConfigFunc diff --git a/go/core/schemas.config b/go/core/schemas.config index 2fe8cc6d54..68ca942f64 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -864,6 +864,10 @@ MiddlewareDesc.configSchema doc ConfigSchema is a JSON Schema describing the middleware's configuration. . +MiddlewareDesc.metadata doc +Metadata contains additional context for the middleware. +. + MiddlewareRef doc MiddlewareRef is a serializable reference to a registered middleware with config. . From c90892ff1cf3d5af16f19bbb46e7644f420d6abd Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 9 Feb 2026 10:41:15 -0800 Subject: [PATCH 20/23] Update typing.py --- py/packages/genkit/src/genkit/core/typing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index f67b0e91e2..85d311dd2e 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -127,6 +127,7 @@ class MiddlewareDesc(BaseModel): name: str description: str | None = None config_schema: dict[str, Any] | None = Field(default=None) + metadata: dict[str, Any] | None = None class MiddlewareRef(BaseModel): From 78958d823ac4e8b6c053ca397605e12b118254b9 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 10 Feb 2026 09:17:28 -0800 Subject: [PATCH 21/23] added tools to middleware interface --- go/ai/generate.go | 8 ++++++++ go/ai/middleware.go | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/go/ai/generate.go b/go/ai/generate.go index c6b2a8066a..7262e928a8 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -500,6 +500,14 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod return nil, err } + // Collect tools provided by middleware. + for _, mw := range genOpts.Use { + for _, t := range mw.Tools() { + dynamicTools = append(dynamicTools, t) + toolNames = append(toolNames, t.Name()) + } + } + if len(dynamicTools) > 0 { if !r.IsChild() { r = r.NewChild() diff --git a/go/ai/middleware.go b/go/ai/middleware.go index 35b2faf37f..d5bd63f792 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -40,6 +40,9 @@ type Middleware interface { Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) // Tool wraps each tool execution. Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) + // Tools returns additional tools to make available during generation. + // These tools are dynamically registered when the middleware is used via [WithUse]. + Tools() []Tool } // GenerateState holds state for the Generate hook. @@ -93,6 +96,8 @@ func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNe return next(ctx, state) } +func (b *BaseMiddleware) Tools() []Tool { return nil } + // Register registers the descriptor with the registry. func (d *MiddlewareDesc) Register(r api.Registry) { r.RegisterValue("/middleware/"+d.Name, d) From 3d7afe9225059e53ddcfe0b16f9bd6eea1077513 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 13 Feb 2026 14:31:21 -0800 Subject: [PATCH 22/23] renames --- go/ai/generate.go | 28 +++++++++++------------ go/ai/middleware.go | 48 ++++++++++++++++++++-------------------- go/ai/middleware_test.go | 18 +++++++-------- 3 files changed, 47 insertions(+), 47 deletions(-) diff --git a/go/ai/generate.go b/go/ai/generate.go index 7262e928a8..7f8778664f 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -350,8 +350,8 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi h := middlewareHandlers[i] inner := wrapped wrapped = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { - return h.Model(ctx, &ModelState{Request: req, Callback: cb}, func(ctx context.Context, state *ModelState) (*ModelResponse, error) { - return inner(ctx, state.Request, state.Callback) + return h.WrapModel(ctx, &ModelParams{Request: req, Callback: cb}, func(ctx context.Context, params *ModelParams) (*ModelResponse, error) { + return inner(ctx, params.Request, params.Callback) }) } } @@ -449,17 +449,17 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi if len(middlewareHandlers) > 0 { innerGenerate := generate generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { - innerFn := func(ctx context.Context, state *GenerateState) (*ModelResponse, error) { - return innerGenerate(ctx, state.Request, currentTurn, messageIndex) + innerFn := func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { + return innerGenerate(ctx, params.Request, currentTurn, messageIndex) } for i := len(middlewareHandlers) - 1; i >= 0; i-- { h := middlewareHandlers[i] next := innerFn - innerFn = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) { - return h.Generate(ctx, state, next) + innerFn = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) { + return h.WrapGenerate(ctx, params, next) } } - return innerFn(ctx, &GenerateState{ + return innerFn(ctx, &GenerateParams{ Options: opts, Request: req, Iteration: currentTurn, @@ -968,19 +968,19 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return newReq, nil, nil } -// runToolWithMiddleware runs a tool, wrapping the execution with Tool hooks from middleware. +// runToolWithMiddleware runs a tool, wrapping the execution with WrapTool hooks from middleware. func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, handlers []Middleware) (*MultipartToolResponse, error) { if len(handlers) == 0 { return tool.RunRawMultipart(ctx, toolReq.Input) } - inner := func(ctx context.Context, state *ToolState) (*ToolResponse, error) { - resp, err := state.Tool.RunRawMultipart(ctx, state.Request.Input) + inner := func(ctx context.Context, params *ToolParams) (*ToolResponse, error) { + resp, err := params.Tool.RunRawMultipart(ctx, params.Request.Input) if err != nil { return nil, err } return &ToolResponse{ - Name: state.Request.Name, + Name: params.Request.Name, Output: resp.Output, Content: resp.Content, }, nil @@ -989,12 +989,12 @@ func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, for i := len(handlers) - 1; i >= 0; i-- { h := handlers[i] next := inner - inner = func(ctx context.Context, state *ToolState) (*ToolResponse, error) { - return h.Tool(ctx, state, next) + inner = func(ctx context.Context, params *ToolParams) (*ToolResponse, error) { + return h.WrapTool(ctx, params, next) } } - toolResp, err := inner(ctx, &ToolState{Request: toolReq, Tool: tool}) + toolResp, err := inner(ctx, &ToolParams{Request: toolReq, Tool: tool}) if err != nil { return nil, err } diff --git a/go/ai/middleware.go b/go/ai/middleware.go index d5bd63f792..aff3b063fc 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -34,19 +34,19 @@ type Middleware interface { Name() string // New returns a fresh instance for each ai.Generate() call, enabling per-invocation state. New() Middleware - // Generate wraps each iteration of the tool loop. - Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) - // Model wraps each model API call. - Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) - // Tool wraps each tool execution. - Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) + // WrapGenerate wraps each iteration of the tool loop. + WrapGenerate(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) + // WrapModel wraps each model API call. + WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) + // WrapTool wraps each tool execution. + WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) // Tools returns additional tools to make available during generation. // These tools are dynamically registered when the middleware is used via [WithUse]. Tools() []Tool } -// GenerateState holds state for the Generate hook. -type GenerateState struct { +// GenerateParams holds params for the WrapGenerate hook. +type GenerateParams struct { // Options is the original options passed to [Generate]. Options *GenerateActionOptions // Request is the current model request for this iteration, with accumulated messages. @@ -55,45 +55,45 @@ type GenerateState struct { Iteration int } -// ModelState holds state for the Model hook. -type ModelState struct { +// ModelParams holds params for the WrapModel hook. +type ModelParams struct { // Request is the model request about to be sent. Request *ModelRequest // Callback is the streaming callback, or nil if not streaming. Callback ModelStreamCallback } -// ToolState holds state for the Tool hook. -type ToolState struct { +// ToolParams holds params for the WrapTool hook. +type ToolParams struct { // Request is the tool request about to be executed. Request *ToolRequest // Tool is the resolved tool being called. Tool Tool } -// GenerateNext is the next function in the Generate hook chain. -type GenerateNext = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) +// GenerateNext is the next function in the WrapGenerate hook chain. +type GenerateNext = func(ctx context.Context, params *GenerateParams) (*ModelResponse, error) -// ModelNext is the next function in the Model hook chain. -type ModelNext = func(ctx context.Context, state *ModelState) (*ModelResponse, error) +// ModelNext is the next function in the WrapModel hook chain. +type ModelNext = func(ctx context.Context, params *ModelParams) (*ModelResponse, error) -// ToolNext is the next function in the Tool hook chain. -type ToolNext = func(ctx context.Context, state *ToolState) (*ToolResponse, error) +// ToolNext is the next function in the WrapTool hook chain. +type ToolNext = func(ctx context.Context, params *ToolParams) (*ToolResponse, error) // BaseMiddleware provides default pass-through for the three hooks. // Embed this so you only need to implement Name() and New(). type BaseMiddleware struct{} -func (b *BaseMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { - return next(ctx, state) +func (b *BaseMiddleware) WrapGenerate(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) { + return next(ctx, params) } -func (b *BaseMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { - return next(ctx, state) +func (b *BaseMiddleware) WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) { + return next(ctx, params) } -func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { - return next(ctx, state) +func (b *BaseMiddleware) WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) { + return next(ctx, params) } func (b *BaseMiddleware) Tools() []Tool { return nil } diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go index a0f9f935ea..4361d00b58 100644 --- a/go/ai/middleware_test.go +++ b/go/ai/middleware_test.go @@ -37,19 +37,19 @@ func (m *testMiddleware) New() Middleware { return &testMiddleware{Label: m.Label} } -func (m *testMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { +func (m *testMiddleware) WrapGenerate(ctx context.Context, params *GenerateParams, next GenerateNext) (*ModelResponse, error) { m.generateCalls++ - return next(ctx, state) + return next(ctx, params) } -func (m *testMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { +func (m *testMiddleware) WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) { m.modelCalls++ - return next(ctx, state) + return next(ctx, params) } -func (m *testMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { +func (m *testMiddleware) WrapTool(ctx context.Context, params *ToolParams, next ToolNext) (*ToolResponse, error) { atomic.AddInt32(&m.toolCalls, 1) - return next(ctx, state) + return next(ctx, params) } func TestNewMiddleware(t *testing.T) { @@ -214,7 +214,7 @@ func (m *stableStateMiddleware) New() Middleware { return &stableStateMiddleware{apiKey: m.apiKey} } -// orderMiddleware tracks the order of Model hook invocations. +// orderMiddleware tracks the order of WrapModel hook invocations. type orderMiddleware struct { BaseMiddleware label string @@ -227,9 +227,9 @@ func (m *orderMiddleware) New() Middleware { return &orderMiddleware{label: m.label, order: m.order} } -func (m *orderMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { +func (m *orderMiddleware) WrapModel(ctx context.Context, params *ModelParams, next ModelNext) (*ModelResponse, error) { *m.order = append(*m.order, m.label+"-model-before") - resp, err := next(ctx, state) + resp, err := next(ctx, params) *m.order = append(*m.order, m.label+"-model-after") return resp, err } From 0aa7385ff756ff6f42d6d49fb62a01358a11602c Mon Sep 17 00:00:00 2001 From: Google Admin Date: Wed, 18 Feb 2026 09:18:18 -0500 Subject: [PATCH 23/23] Refactor Github Action per b/485167538 (#4761) --- .github/workflows/build-cli-binaries.yml | 14 +++++++--- .github/workflows/bump-cli-version.yml | 4 ++- .github/workflows/bump-js-version.yml | 4 ++- .github/workflows/bump-package-version.yml | 6 ++++- .github/workflows/publish_python.yml | 11 +++++--- .github/workflows/release_js_main.yml | 3 ++- .github/workflows/release_js_package.yml | 6 +++-- .github/workflows/releasekit-uv.yml | 31 +++++++++++++--------- 8 files changed, 54 insertions(+), 25 deletions(-) diff --git a/.github/workflows/build-cli-binaries.yml b/.github/workflows/build-cli-binaries.yml index 8bbfb73f9b..d374b8944f 100644 --- a/.github/workflows/build-cli-binaries.yml +++ b/.github/workflows/build-cli-binaries.yml @@ -146,20 +146,26 @@ jobs: shell: bash run: | echo "Testing genkit --help" - ./genkit-${{ matrix.target }}${{ steps.binary.outputs.ext }} --help + ./genkit-${{ matrix.target }}${STEPS_BINARY_OUTPUTS_EXT} --help + env: + STEPS_BINARY_OUTPUTS_EXT: ${{ steps.binary.outputs.ext }} - name: Test --version command shell: bash run: | echo "Testing genkit --version" - ./genkit-${{ matrix.target }}${{ steps.binary.outputs.ext }} --version + ./genkit-${{ matrix.target }}${STEPS_BINARY_OUTPUTS_EXT} --version + env: + STEPS_BINARY_OUTPUTS_EXT: ${{ steps.binary.outputs.ext }} - name: Verify UI commands exist shell: bash run: | echo "Verifying UI commands are available" - ./genkit-${{ matrix.target }}${{ steps.binary.outputs.ext }} ui:start --help - ./genkit-${{ matrix.target }}${{ steps.binary.outputs.ext }} ui:stop --help + ./genkit-${{ matrix.target }}${STEPS_BINARY_OUTPUTS_EXT} ui:start --help + ./genkit-${{ matrix.target }}${STEPS_BINARY_OUTPUTS_EXT} ui:stop --help + env: + STEPS_BINARY_OUTPUTS_EXT: ${{ steps.binary.outputs.ext }} - name: Test UI start functionality (Unix only) if: runner.os != 'Windows' diff --git a/.github/workflows/bump-cli-version.yml b/.github/workflows/bump-cli-version.yml index 6ce05c4da4..49e442d2f9 100644 --- a/.github/workflows/bump-cli-version.yml +++ b/.github/workflows/bump-cli-version.yml @@ -63,7 +63,9 @@ jobs: - name: Bump and Tag run: | - js/scripts/bump_and_tag_cli.sh ${{ inputs.releaseType }} ${{ inputs.preid }} + js/scripts/bump_and_tag_cli.sh ${{ inputs.releaseType }} ${INPUTS_PREID} + env: + INPUTS_PREID: ${{ inputs.preid }} - name: Push shell: bash diff --git a/.github/workflows/bump-js-version.yml b/.github/workflows/bump-js-version.yml index d139baea42..65b9fafd87 100644 --- a/.github/workflows/bump-js-version.yml +++ b/.github/workflows/bump-js-version.yml @@ -63,7 +63,9 @@ jobs: - name: Bump and Tag run: | - js/scripts/bump_and_tag_js.sh ${{ inputs.releaseType }} ${{ inputs.preid }} + js/scripts/bump_and_tag_js.sh ${{ inputs.releaseType }} ${INPUTS_PREID} + env: + INPUTS_PREID: ${{ inputs.preid }} - name: Push shell: bash diff --git a/.github/workflows/bump-package-version.yml b/.github/workflows/bump-package-version.yml index d1d8ad81a2..79a392b1b5 100644 --- a/.github/workflows/bump-package-version.yml +++ b/.github/workflows/bump-package-version.yml @@ -71,7 +71,11 @@ jobs: - name: Bump and Tag run: | - js/scripts/bump_and_tag.sh ${{ inputs.packageDir }} ${{ inputs.packageName }} ${{ inputs.releaseType }} ${{ inputs.preid }} + js/scripts/bump_and_tag.sh ${INPUTS_PACKAGEDIR} ${INPUTS_PACKAGENAME} ${{ inputs.releaseType }} ${INPUTS_PREID} + env: + INPUTS_PACKAGEDIR: ${{ inputs.packageDir }} + INPUTS_PACKAGENAME: ${{ inputs.packageName }} + INPUTS_PREID: ${{ inputs.preid }} - name: Push shell: bash diff --git a/.github/workflows/publish_python.yml b/.github/workflows/publish_python.yml index ccc16dc189..3176658809 100644 --- a/.github/workflows/publish_python.yml +++ b/.github/workflows/publish_python.yml @@ -224,16 +224,16 @@ jobs: echo "**Dry Run:** ${{ inputs.dry_run }}" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY - if [ "${{ needs.publish.result }}" == "success" ]; then + if [ "${NEEDS_PUBLISH_RESULT}" == "success" ]; then echo "### ✅ Publish Status: Success" >> $GITHUB_STEP_SUMMARY else echo "### ❌ Publish Status: Failed" >> $GITHUB_STEP_SUMMARY fi echo "" >> $GITHUB_STEP_SUMMARY - if [ "${{ needs.verify.result }}" == "success" ]; then + if [ "${NEEDS_VERIFY_RESULT}" == "success" ]; then echo "### ✅ Verification: Passed" >> $GITHUB_STEP_SUMMARY - elif [ "${{ needs.verify.result }}" == "failure" ]; then + elif [ "${NEEDS_VERIFY_RESULT}" == "failure" ]; then echo "### ⚠️ Verification: Some packages failed" >> $GITHUB_STEP_SUMMARY else echo "### ⏭️ Verification: Skipped" >> $GITHUB_STEP_SUMMARY @@ -244,3 +244,8 @@ jobs: echo "1. Verify on PyPI: https://pypi.org/project/genkit/$VERSION/" >> $GITHUB_STEP_SUMMARY echo "2. Test installation: \`pip install genkit==$VERSION\`" >> $GITHUB_STEP_SUMMARY echo "3. Update documentation if needed" >> $GITHUB_STEP_SUMMARY + + env: + NEEDS_PUBLISH_RESULT: ${{ needs.publish.result }} + + NEEDS_VERIFY_RESULT: ${{ needs.verify.result }} diff --git a/.github/workflows/release_js_main.yml b/.github/workflows/release_js_main.yml index 5ae4453b48..9dec1f58d8 100644 --- a/.github/workflows/release_js_main.yml +++ b/.github/workflows/release_js_main.yml @@ -62,6 +62,7 @@ jobs: registry-url: 'https://wombat-dressing-room.appspot.com/' - name: release script shell: bash - run: RELEASE_BRANCH=${{ steps.extract_branch.outputs.branch }} RELEASE_TAG=${{ inputs.releaseTag }} scripts/release_main.sh + run: RELEASE_BRANCH=${STEPS_EXTRACT_BRANCH_OUTPUTS_BRANCH} RELEASE_TAG=${{ inputs.releaseTag }} scripts/release_main.sh env: NODE_AUTH_TOKEN: ${{ secrets.NODE_AUTH_TOKEN }} + STEPS_EXTRACT_BRANCH_OUTPUTS_BRANCH: ${{ steps.extract_branch.outputs.branch }} diff --git a/.github/workflows/release_js_package.yml b/.github/workflows/release_js_package.yml index c3119251a0..dd215e57fe 100644 --- a/.github/workflows/release_js_package.yml +++ b/.github/workflows/release_js_package.yml @@ -65,7 +65,9 @@ jobs: - name: release script shell: bash run: | - cd ${{ inputs.packageDir }} - pnpm publish --tag ${{ inputs.releaseTag }} --publish-branch ${{ steps.extract_branch.outputs.branch }} --access=public --registry https://wombat-dressing-room.appspot.com + cd ${INPUTS_PACKAGEDIR} + pnpm publish --tag ${{ inputs.releaseTag }} --publish-branch ${STEPS_EXTRACT_BRANCH_OUTPUTS_BRANCH} --access=public --registry https://wombat-dressing-room.appspot.com env: NODE_AUTH_TOKEN: ${{ secrets.NODE_AUTH_TOKEN }} + INPUTS_PACKAGEDIR: ${{ inputs.packageDir }} + STEPS_EXTRACT_BRANCH_OUTPUTS_BRANCH: ${{ steps.extract_branch.outputs.branch }} diff --git a/.github/workflows/releasekit-uv.yml b/.github/workflows/releasekit-uv.yml index e79066933b..090526bbd7 100644 --- a/.github/workflows/releasekit-uv.yml +++ b/.github/workflows/releasekit-uv.yml @@ -234,14 +234,14 @@ jobs: if [ "${{ inputs.force_prepare }}" = "true" ]; then cmd+=(--force) fi - if [ -n "${{ inputs.group }}" ]; then - cmd+=(--group "${{ inputs.group }}") + if [ -n "${INPUTS_GROUP}" ]; then + cmd+=(--group "${INPUTS_GROUP}") fi if [ "${{ inputs.bump_type }}" != "auto" ] && [ -n "${{ inputs.bump_type }}" ]; then cmd+=(--bump "${{ inputs.bump_type }}") fi - if [ -n "${{ inputs.prerelease }}" ]; then - cmd+=(--prerelease "${{ inputs.prerelease }}") + if [ -n "${INPUTS_PRERELEASE}" ]; then + cmd+=(--prerelease "${INPUTS_PRERELEASE}") fi # Run prepare — capture output even on failure so CI logs @@ -264,6 +264,10 @@ jobs: env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + INPUTS_GROUP: ${{ inputs.group }} + + INPUTS_PRERELEASE: ${{ inputs.prerelease }} + # ═══════════════════════════════════════════════════════════════════════ # RELEASE: Tag merge commit and create GitHub Release # @@ -308,7 +312,7 @@ jobs: echo "::group::Execution Plan (ASCII)" uv run --directory ${{ env.RELEASEKIT_DIR }} releasekit --workspace py plan --format full 2>&1 || true echo "::endgroup::" - if [ "${{ env.DRY_RUN }}" = "true" ]; then + if [ "${DRY_RUN}" = "true" ]; then echo "::notice::DRY RUN — no tags or releases will be created" else echo "::notice::LIVE RUN — tags and GitHub Release will be created" @@ -322,7 +326,7 @@ jobs: set -euo pipefail DRY_RUN_FLAG="" - if [ "${{ env.DRY_RUN }}" = "true" ]; then + if [ "${DRY_RUN}" = "true" ]; then DRY_RUN_FLAG="--dry-run" fi @@ -389,7 +393,7 @@ jobs: echo "::group::Execution Plan (ASCII)" uv run --directory ${{ env.RELEASEKIT_DIR }} releasekit --workspace py plan --format full 2>&1 || true echo "::endgroup::" - if [ "${{ env.DRY_RUN }}" = "true" ]; then + if [ "${DRY_RUN}" = "true" ]; then echo "::notice::DRY RUN — no packages will be published" else echo "::notice::LIVE RUN — packages will be published to ${{ inputs.target || 'pypi' }}" @@ -401,7 +405,7 @@ jobs: cmd=(uv run --directory ${{ env.RELEASEKIT_DIR }} releasekit --workspace py publish --force) - if [ "${{ env.DRY_RUN }}" = "true" ]; then + if [ "${DRY_RUN}" = "true" ]; then cmd+=(--dry-run) fi @@ -412,16 +416,16 @@ jobs: cmd+=(--index-url "$PUBLISH_INDEX_URL") fi - CONCURRENCY="${{ inputs.concurrency }}" + CONCURRENCY="${INPUTS_CONCURRENCY}" if [ -n "$CONCURRENCY" ] && [ "$CONCURRENCY" != "0" ]; then cmd+=(--concurrency "$CONCURRENCY") fi - if [ -n "${{ inputs.group }}" ]; then - cmd+=(--group "${{ inputs.group }}") + if [ -n "${INPUTS_GROUP}" ]; then + cmd+=(--group "${INPUTS_GROUP}") fi - MAX_RETRIES="${{ inputs.max_retries }}" + MAX_RETRIES="${INPUTS_MAX_RETRIES}" if [ -n "$MAX_RETRIES" ] && [ "$MAX_RETRIES" != "0" ]; then cmd+=(--max-retries "$MAX_RETRIES") fi @@ -439,6 +443,9 @@ jobs: # For trusted publishing (OIDC), no token needed. # For API token auth, set PYPI_TOKEN / TESTPYPI_TOKEN in repo secrets. UV_PUBLISH_TOKEN: ${{ inputs.target == 'testpypi' && secrets.TESTPYPI_TOKEN || secrets.PYPI_TOKEN }} + INPUTS_CONCURRENCY: ${{ inputs.concurrency }} + INPUTS_GROUP: ${{ inputs.group }} + INPUTS_MAX_RETRIES: ${{ inputs.max_retries }} - name: Upload manifest artifact uses: actions/upload-artifact@v4