diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go
index f02a36557..b41913a3a 100644
--- a/cmd/epp/runner/runner.go
+++ b/cmd/epp/runner/runner.go
@@ -465,6 +465,7 @@ func (r *Runner) parseConfigurationPhaseOne(ctx context.Context) (*configapi.End
 	loader.RegisterFeatureGate(datalayer.FeatureGate)
 	loader.RegisterFeatureGate(flowcontrol.FeatureGate)
+	loader.RegisterFeatureGate(datalayer.PrepareDataPluginsFeatureGate)
 
 	r.registerInTreePlugins()
@@ -504,8 +505,9 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf
 	// Add requestControl plugins
 	r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...)
 
-	if r.requestControlConfig.PrepareDataPluginGraph() != nil {
+	// Sort prepare data plugins in DAG order (topological sort) and check for cycles.
+	if r.requestControlConfig.PrepareDataPluginGraph(r.featureGates[datalayer.PrepareDataPluginsFeatureGate]) != nil {
 		return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies")
 	}
diff --git a/pkg/epp/datalayer/factory.go b/pkg/epp/datalayer/factory.go
index 3a81763d5..08c5f3fd2 100644
--- a/pkg/epp/datalayer/factory.go
+++ b/pkg/epp/datalayer/factory.go
@@ -26,7 +26,8 @@ import (
 )
 
 const (
-	FeatureGate = "dataLayer"
+	FeatureGate                   = "dataLayer"
+	PrepareDataPluginsFeatureGate = "prepareDataPlugins"
 )
 
 // PoolInfo represents the DataStore information needed for endpoints.
diff --git a/pkg/epp/datalayer/plugins/data_types.go b/pkg/epp/datalayer/plugins/data_types.go
new file mode 100644
index 000000000..63501ebb9
--- /dev/null
+++ b/pkg/epp/datalayer/plugins/data_types.go
@@ -0,0 +1,45 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package plugins
+
+import (
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
+)
+
+const (
+	// PrefixCacheMatchInfoKey is the data layer key under which per-pod prefix cache match information is stored.
+	PrefixCacheMatchInfoKey = "PrefixCacheMatchInfoKey"
+)
+
+// PrefixCacheMatchInfo holds the prefix cache match percentage computed for a pod.
+type PrefixCacheMatchInfo struct {
+	matchPercentage float64
+}
+
+// NewPrefixCacheMatchInfo returns a PrefixCacheMatchInfo with the given match percentage.
+func NewPrefixCacheMatchInfo(matchPercentage float64) *PrefixCacheMatchInfo {
+	return &PrefixCacheMatchInfo{
+		matchPercentage: matchPercentage,
+	}
+}
+
+// MatchPercentage returns the stored prefix cache match percentage.
+func (p *PrefixCacheMatchInfo) MatchPercentage() float64 {
+	return p.matchPercentage
+}
+
+// Clone returns a deep copy, satisfying datalayer.Cloneable.
+func (p *PrefixCacheMatchInfo) Clone() datalayer.Cloneable {
+	return &PrefixCacheMatchInfo{
+		matchPercentage: p.matchPercentage,
+	}
+}
diff --git a/pkg/epp/requestcontrol/request_control_config.go b/pkg/epp/requestcontrol/request_control_config.go
index 8f08ac121..0bfa36e8a 100644
--- a/pkg/epp/requestcontrol/request_control_config.go
+++ b/pkg/epp/requestcontrol/request_control_config.go
@@ -107,7 +107,11 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) {
 
 // PrepareDataPluginGraph creates data dependency graph and sorts the plugins in topological order.
 // If a cycle is detected, it returns an error.
-func (c *Config) PrepareDataPluginGraph() error {
+func (c *Config) PrepareDataPluginGraph(enablePrepareDataPlugins bool) error {
+	if !enablePrepareDataPlugins {
+		c.prepareDataPlugins = []PrepareDataPlugin{}
+		return nil
+	}
 	dag := buildDAG(c.prepareDataPlugins)
 	plugins, err := sortPlugins(dag, c.prepareDataPlugins)
 	if err != nil {
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
index 2a1a3a8b2..b0c494def 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -28,6 +28,7 @@ import (
 	k8stypes "k8s.io/apimachinery/pkg/types"
 	"sigs.k8s.io/controller-runtime/pkg/log"
 
+	dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
@@ -206,6 +207,40 @@ func (p *Plugin) WithName(name string) *Plugin {
 	return p
 }
 
+// Produces declares the data this plugin writes to the data layer.
+func (p *Plugin) Produces() map[string]any {
+	return map[string]any{dplugins.PrefixCacheMatchInfoKey: dplugins.PrefixCacheMatchInfo{}}
+}
+
+// Consumes declares the data this plugin reads from the data layer; the prefix plugin consumes nothing.
+func (p *Plugin) Consumes() map[string]any {
+	return map[string]any{}
+}
+
+// PrepareRequestData hashes the prompt, finds the longest prefix match per server,
+// and stores each pod's match ratio in the data layer under PrefixCacheMatchInfoKey.
+func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error {
+	// Hash the prompt and find the longest prefix match per server.
+	hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch)
+	state := &SchedulingContextState{
+		PrefixHashes:       hashes,
+		PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
+	}
+	for server, matchLen := range state.PrefixCacheServers {
+		log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "server", server, "longest-prefix-match", matchLen)
+	}
+
+	total := len(state.PrefixHashes)
+	podScoreFunc := func(pod types.Pod) float64 {
+		if total == 0 {
+			return 0
+		}
+		matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)]
+		return float64(matchLen) / float64(total)
+	}
+	for _, pod := range pods {
+		pod.Put(dplugins.PrefixCacheMatchInfoKey, dplugins.NewPrefixCacheMatchInfo(podScoreFunc(pod)))
+	}
+	return nil
+}
+
 // Score returns the scoring result for the given list of pods based on context.
 func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
 	// pre score step, hashing prompt and find longest prefix match.
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
index f0feeef68..682ff08f0 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -30,9 +30,13 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )
 
+// Static check to ensure Plugin implements the PrepareDataPlugin interface.
+var _ requestcontrol.PrepareDataPlugin = &Plugin{}
+
 func TestPrefixPluginCompletion(t *testing.T) {
 	config := Config{
 		BlockSize: 4,
diff --git a/pkg/epp/scheduling/framework/plugins/scorer/prefix_cache_match_scorer.go b/pkg/epp/scheduling/framework/plugins/scorer/prefix_cache_match_scorer.go
new file mode 100644
index 000000000..3066b7d5a
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/scorer/prefix_cache_match_scorer.go
@@ -0,0 +1,85 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package scorer
+
+import (
+	"context"
+	"encoding/json"
+
+	k8stypes "k8s.io/apimachinery/pkg/types"
+
+	dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+const (
+	PrefixCacheMatchScorerType = "prefix-cache-match-scorer"
+)
+
+type ServerID k8stypes.NamespacedName
+
+// Compile-time check that PrefixCacheScorer implements framework.Scorer.
+var _ framework.Scorer = &PrefixCacheScorer{}
+
+// PrefixCacheScorerFactory defines the factory function for PrefixCacheScorer.
+func PrefixCacheScorerFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+	return NewPrefixCacheScorer().WithName(name), nil
+}
+
+// NewPrefixCacheScorer initializes a new PrefixCacheScorer and returns its pointer.
+func NewPrefixCacheScorer() *PrefixCacheScorer {
+	return &PrefixCacheScorer{
+		tn: plugins.TypedName{Type: PrefixCacheMatchScorerType, Name: PrefixCacheMatchScorerType},
+	}
+}
+
+// PrefixCacheScorer scores candidate pods based on the prefix cache match
+// information produced by the prefix cache prepare data plugin.
+type PrefixCacheScorer struct {
+	tn plugins.TypedName
+}
+
+// TypedName returns the type and name tuple of this plugin instance.
+func (s *PrefixCacheScorer) TypedName() plugins.TypedName {
+	return s.tn
+}
+
+// Consumes returns the list of data that is consumed by the plugin.
+func (s *PrefixCacheScorer) Consumes() map[string]any {
+	return map[string]any{}
+}
+
+// WithName sets the name of the scorer.
+func (s *PrefixCacheScorer) WithName(name string) *PrefixCacheScorer {
+	s.tn.Name = name
+	return s
+}
+
+// Score returns each pod's prefix cache match percentage read from the data layer;
+// pods without the data score 0.
+func (s *PrefixCacheScorer) Score(_ context.Context, cycleState *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
+	// calculate the scores of pods
+	scores := make(map[types.Pod]float64, len(pods))
+
+	for _, pod := range pods {
+		matchPercent, ok := pod.Get(dplugins.PrefixCacheMatchInfoKey)
+		if !ok {
+			scores[pod] = 0.0
+			continue
+		}
+		scores[pod] = matchPercent.(*dplugins.PrefixCacheMatchInfo).MatchPercentage()
+	}
+	return scores
+}
diff --git a/pkg/epp/scheduling/framework/plugins/scorer/prefix_cache_match_scorer_test.go b/pkg/epp/scheduling/framework/plugins/scorer/prefix_cache_match_scorer_test.go
new file mode 100644
index 000000000..948044c5a
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/scorer/prefix_cache_match_scorer_test.go
@@ -0,0 +1,150 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package scorer
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
+	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
+	dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+// mockPod is a mock implementation of types.Pod for testing purposes.
+type mockPod struct {
+	data map[string]datalayer.Cloneable
+}
+
+func newMockPod() *mockPod {
+	return &mockPod{
+		data: make(map[string]datalayer.Cloneable),
+	}
+}
+
+func (p *mockPod) Get(key string) (datalayer.Cloneable, bool) {
+	val, ok := p.data[key]
+	return val, ok
+}
+
+func (p *mockPod) Put(key string, value datalayer.Cloneable) {
+	p.data[key] = value
+}
+
+func (p *mockPod) GetPod() *backend.Pod {
+	return nil
+}
+
+func (p *mockPod) GetMetrics() *backendmetrics.MetricsState {
+	return nil
+}
+
+func (p *mockPod) String() string {
+	return ""
+}
+
+func (p *mockPod) Keys() []string {
+	keys := make([]string, 0, len(p.data))
+	for k := range p.data {
+		keys = append(keys, k)
+	}
+	return keys
+}
+
+func TestPrefixCacheScorer_Score(t *testing.T) {
+	pod1 := newMockPod()
+	pod1.Put(dplugins.PrefixCacheMatchInfoKey, dplugins.NewPrefixCacheMatchInfo(50.0))
+
+	pod2 := newMockPod()
+	pod2.Put(dplugins.PrefixCacheMatchInfoKey, dplugins.NewPrefixCacheMatchInfo(100.0))
+
+	pod3 := newMockPod()
+
+	testCases := []struct {
+		name     string
+		pods     []types.Pod
+		expected map[types.Pod]float64
+	}{
+		{
+			name: "pods with prefix cache match percent",
+			pods: []types.Pod{pod1, pod2},
+			expected: map[types.Pod]float64{
+				pod1: 50.0,
+				pod2: 100.0,
+			},
+		},
+		{
+			name: "pod without prefix cache match percent",
+			pods: []types.Pod{pod3},
+			expected: map[types.Pod]float64{
+				pod3: 0.0,
+			},
+		},
+		{
+			name: "mixed pods",
+			pods: []types.Pod{pod1, pod3},
+			expected: map[types.Pod]float64{
+				pod1: 50.0,
+				pod3: 0.0,
+			},
+		},
+		{
+			name:     "empty pods list",
+			pods:     []types.Pod{},
+			expected: map[types.Pod]float64{},
+		},
+	}
+
+	scorer := NewPrefixCacheScorer()
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			scores := scorer.Score(context.Background(), nil, nil, tc.pods)
+			assert.Equal(t, tc.expected, scores)
+		})
+	}
+}
+
+func TestNewPrefixCacheScorer(t *testing.T) {
+	scorer := NewPrefixCacheScorer()
+	assert.NotNil(t, scorer)
+	assert.Equal(t, PrefixCacheMatchScorerType, scorer.tn.Type)
+	assert.Equal(t, PrefixCacheMatchScorerType, scorer.tn.Name)
+}
+
+func TestPrefixCacheScorer_WithName(t *testing.T) {
+	scorer := NewPrefixCacheScorer()
+	customName := "custom-scorer"
+	scorer.WithName(customName)
+	assert.Equal(t, customName, scorer.TypedName().Name)
+}
+
+func TestPrefixCacheScorer_TypedName(t *testing.T) {
+	scorer := NewPrefixCacheScorer()
+	tn := scorer.TypedName()
+	assert.Equal(t, PrefixCacheMatchScorerType, tn.Type)
+	assert.Equal(t, PrefixCacheMatchScorerType, tn.Name)
+}
+
+func TestPrefixCacheScorer_Consumes(t *testing.T) {
+	scorer := NewPrefixCacheScorer()
+	consumes := scorer.Consumes()
+	assert.Empty(t, consumes)
+}
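As a quick illustration of the data flow this patch introduces, here is a minimal sketch (not part of the diff): it uses only the PrefixCacheMatchInfo API added above and the datalayer.Cloneable interface, with a plain map standing in for a pod's data layer attribute store (types.Pod.Put/Get in the real code).

package main

import (
	"fmt"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
	dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins"
)

func main() {
	// Stand-in for a single pod's data layer attribute map.
	store := map[string]datalayer.Cloneable{}

	// Producer side: prefix.Plugin.PrepareRequestData stores the match ratio per pod.
	store[dplugins.PrefixCacheMatchInfoKey] = dplugins.NewPrefixCacheMatchInfo(0.75)

	// Consumer side: PrefixCacheScorer.Score reads it back and uses it as the score.
	if v, ok := store[dplugins.PrefixCacheMatchInfoKey]; ok {
		fmt.Println(v.(*dplugins.PrefixCacheMatchInfo).MatchPercentage()) // 0.75
	}
}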