From 6cb7966004370e1e812c066acdc38cb5a6800b21 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 11 Jan 2026 05:39:00 +0000 Subject: [PATCH 1/3] Add gpu-search command to search and filter GPU instance types Introduces a new `brev gpu-search` command (also aliased as `brev gpu`, `brev gpus`, and `brev gpu-list`) that allows users to search and filter GPU instance types from the Brev API. Features include: - Filter by GPU name (case-insensitive partial match) - Filter by minimum VRAM per GPU (in GB) - Filter by minimum total VRAM (GPU count * VRAM) - Filter by minimum GPU compute capability (e.g., 8.0 for Ampere) - Sort by price, gpu-count, vram, total-vram, vcpu, type, or capability - Support for ascending/descending sort order The command displays results in a formatted table showing instance type, GPU name, count, VRAM per GPU, total VRAM, compute capability, vCPUs, and hourly price. Includes comprehensive unit tests for filtering, sorting, and data processing functionality. --- pkg/cmd/cmd.go | 2 + pkg/cmd/gpusearch/gpusearch.go | 429 ++++++++++++++++++++++++++++ pkg/cmd/gpusearch/gpusearch_test.go | 388 +++++++++++++++++++++++++ pkg/store/instancetypes.go | 48 ++++ 4 files changed, 867 insertions(+) create mode 100644 pkg/cmd/gpusearch/gpusearch.go create mode 100644 pkg/cmd/gpusearch/gpusearch_test.go create mode 100644 pkg/store/instancetypes.go diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 2980b0ce..04564f9c 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -14,6 +14,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/delete" "github.com/brevdev/brev-cli/pkg/cmd/envvars" "github.com/brevdev/brev-cli/pkg/cmd/fu" + "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" "github.com/brevdev/brev-cli/pkg/cmd/healthcheck" "github.com/brevdev/brev-cli/pkg/cmd/hello" "github.com/brevdev/brev-cli/pkg/cmd/importideconfig" @@ -270,6 +271,7 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor } cmd.AddCommand(workspacegroups.NewCmdWorkspaceGroups(t, loginCmdStore)) cmd.AddCommand(scale.NewCmdScale(t, noLoginCmdStore)) + cmd.AddCommand(gpusearch.NewCmdGPUSearch(t, noLoginCmdStore)) cmd.AddCommand(configureenvvars.NewCmdConfigureEnvVars(t, loginCmdStore)) cmd.AddCommand(importideconfig.NewCmdImportIDEConfig(t, noLoginCmdStore)) cmd.AddCommand(shell.NewCmdShell(t, loginCmdStore, noLoginCmdStore)) diff --git a/pkg/cmd/gpusearch/gpusearch.go b/pkg/cmd/gpusearch/gpusearch.go new file mode 100644 index 00000000..d66a5545 --- /dev/null +++ b/pkg/cmd/gpusearch/gpusearch.go @@ -0,0 +1,429 @@ +// Package gpusearch provides a command to search and filter GPU instance types +package gpusearch + +import ( + "fmt" + "os" + "regexp" + "sort" + "strconv" + "strings" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/jedib0t/go-pretty/v6/table" + "github.com/spf13/cobra" +) + +// MemoryBytes represents the memory size with value and unit +type MemoryBytes struct { + Value int64 `json:"value"` + Unit string `json:"unit"` +} + +// GPU represents a GPU configuration within an instance type +type GPU struct { + Count int `json:"count"` + Name string `json:"name"` + Manufacturer string `json:"manufacturer"` + Memory string `json:"memory"` + MemoryBytes MemoryBytes `json:"memory_bytes"` +} + +// BasePrice represents the pricing information +type BasePrice struct { + Currency string `json:"currency"` + Amount string `json:"amount"` +} + +// InstanceType represents an instance type from the API +type InstanceType struct { + Type string `json:"type"` + SupportedGPUs []GPU `json:"supported_gpus"` + SupportedStorage []interface{} `json:"supported_storage"` // Complex objects, not used in filtering + Memory string `json:"memory"` + VCPU int `json:"vcpu"` + BasePrice BasePrice `json:"base_price"` + Location string `json:"location"` + SubLocation string `json:"sub_location"` + AvailableLocations []string `json:"available_locations"` +} + +// InstanceTypesResponse represents the API response +type InstanceTypesResponse struct { + Items []InstanceType `json:"items"` +} + +// GPUSearchStore defines the interface for fetching instance types +type GPUSearchStore interface { + GetInstanceTypes() (*InstanceTypesResponse, error) +} + +var ( + long = `Search and filter GPU instance types available on Brev. + +Filter instances by GPU name, VRAM, total VRAM, and GPU compute capability. +Sort results by various columns to find the best instance for your needs.` + + example = ` + # List all GPU instances + brev gpu-search + + # Filter by GPU name (case-insensitive, partial match) + brev gpu-search --gpu-name A100 + brev gpu-search --gpu-name "L40S" + + # Filter by minimum VRAM per GPU (in GB) + brev gpu-search --min-vram 24 + + # Filter by minimum total VRAM (in GB) + brev gpu-search --min-total-vram 80 + + # Filter by minimum GPU compute capability + brev gpu-search --min-capability 8.0 + + # Sort by different columns (price, gpu-count, vram, total-vram, vcpu) + brev gpu-search --sort price + brev gpu-search --sort total-vram --desc + + # Combine filters + brev gpu-search --gpu-name A100 --min-vram 40 --sort price +` +) + +// NewCmdGPUSearch creates the gpu-search command +func NewCmdGPUSearch(t *terminal.Terminal, store GPUSearchStore) *cobra.Command { + var gpuName string + var minVRAM float64 + var minTotalVRAM float64 + var minCapability float64 + var sortBy string + var descending bool + + cmd := &cobra.Command{ + Annotations: map[string]string{"workspace": ""}, + Use: "gpu-search", + Aliases: []string{"gpu", "gpus", "gpu-list"}, + DisableFlagsInUseLine: true, + Short: "Search and filter GPU instance types", + Long: long, + Example: example, + RunE: func(cmd *cobra.Command, args []string) error { + err := RunGPUSearch(t, store, gpuName, minVRAM, minTotalVRAM, minCapability, sortBy, descending) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil + }, + } + + cmd.Flags().StringVarP(&gpuName, "gpu-name", "g", "", "Filter by GPU name (case-insensitive, partial match)") + cmd.Flags().Float64VarP(&minVRAM, "min-vram", "v", 0, "Minimum VRAM per GPU in GB") + cmd.Flags().Float64VarP(&minTotalVRAM, "min-total-vram", "t", 0, "Minimum total VRAM (GPU count * VRAM) in GB") + cmd.Flags().Float64VarP(&minCapability, "min-capability", "c", 0, "Minimum GPU compute capability (e.g., 8.0 for Ampere)") + cmd.Flags().StringVarP(&sortBy, "sort", "s", "price", "Sort by: price, gpu-count, vram, total-vram, vcpu, type") + cmd.Flags().BoolVarP(&descending, "desc", "d", false, "Sort in descending order") + + return cmd +} + +// GPUInstanceInfo holds processed GPU instance information for display +type GPUInstanceInfo struct { + Type string + GPUName string + GPUCount int + VRAMPerGPU float64 // in GB + TotalVRAM float64 // in GB + Capability float64 + VCPUs int + Memory string + PricePerHour float64 + Manufacturer string +} + +// RunGPUSearch executes the GPU search with filters and sorting +func RunGPUSearch(t *terminal.Terminal, store GPUSearchStore, gpuName string, minVRAM, minTotalVRAM, minCapability float64, sortBy string, descending bool) error { + response, err := store.GetInstanceTypes() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if response == nil || len(response.Items) == 0 { + t.Vprint(t.Yellow("No instance types found")) + return nil + } + + // Process and filter instances + instances := processInstances(response.Items) + + // Apply filters + filtered := filterInstances(instances, gpuName, minVRAM, minTotalVRAM, minCapability) + + if len(filtered) == 0 { + t.Vprint(t.Yellow("No GPU instances match the specified filters")) + return nil + } + + // Sort instances + sortInstances(filtered, sortBy, descending) + + // Display results + displayGPUTable(t, filtered) + + t.Vprintf("\n%s\n", t.Green(fmt.Sprintf("Found %d GPU instance types", len(filtered)))) + + return nil +} + +// parseMemoryToGB converts memory string like "22GiB360MiB" or "40GiB" to GB +func parseMemoryToGB(memory string) float64 { + // Handle memory_bytes if provided (in MiB) + // Otherwise parse the string format + + var totalGB float64 + + // Match GiB values + gibRe := regexp.MustCompile(`(\d+(?:\.\d+)?)\s*GiB`) + gibMatches := gibRe.FindAllStringSubmatch(memory, -1) + for _, match := range gibMatches { + if val, err := strconv.ParseFloat(match[1], 64); err == nil { + totalGB += val + } + } + + // Match MiB values and convert to GB + mibRe := regexp.MustCompile(`(\d+(?:\.\d+)?)\s*MiB`) + mibMatches := mibRe.FindAllStringSubmatch(memory, -1) + for _, match := range mibMatches { + if val, err := strconv.ParseFloat(match[1], 64); err == nil { + totalGB += val / 1024 + } + } + + // Match GB values (in case API uses GB instead of GiB) + gbRe := regexp.MustCompile(`(\d+(?:\.\d+)?)\s*GB`) + gbMatches := gbRe.FindAllStringSubmatch(memory, -1) + for _, match := range gbMatches { + if val, err := strconv.ParseFloat(match[1], 64); err == nil { + totalGB += val + } + } + + return totalGB +} + +// gpuCapabilityEntry represents a GPU pattern and its compute capability +type gpuCapabilityEntry struct { + pattern string + capability float64 +} + +// getGPUCapability returns the compute capability for known GPU types +func getGPUCapability(gpuName string) float64 { + gpuName = strings.ToUpper(gpuName) + + // Order matters: more specific patterns must come before less specific ones + // (e.g., "A100" before "A10", "L40S" before "L40") + capabilities := []gpuCapabilityEntry{ + // NVIDIA Hopper + {"H100", 9.0}, + {"H200", 9.0}, + + // NVIDIA Ada Lovelace (L40S before L40, L4) + {"L40S", 8.9}, + {"L40", 8.9}, + {"L4", 8.9}, + {"RTX4090", 8.9}, + {"RTX4080", 8.9}, + + // NVIDIA Ampere (A100 before A10G, A10) + {"A100", 8.0}, + {"A10G", 8.6}, + {"A10", 8.6}, + {"A40", 8.6}, + {"A30", 8.0}, + {"A16", 8.6}, + {"RTX3090", 8.6}, + {"RTX3080", 8.6}, + + // NVIDIA Turing + {"T4", 7.5}, + {"RTX2080", 7.5}, + + // NVIDIA Volta + {"V100", 7.0}, + + // NVIDIA Pascal (P100 before P40, P4) + {"P100", 6.0}, + {"P40", 6.1}, + {"P4", 6.1}, + + // NVIDIA Kepler + {"K80", 3.7}, + + // Gaudi (Habana) - not CUDA compatible + {"HL-205", 0}, + {"GAUDI3", 0}, + {"GAUDI2", 0}, + {"GAUDI", 0}, + } + + for _, entry := range capabilities { + if strings.Contains(gpuName, entry.pattern) { + return entry.capability + } + } + return 0 +} + +// processInstances converts raw instance types to GPUInstanceInfo +func processInstances(items []InstanceType) []GPUInstanceInfo { + var instances []GPUInstanceInfo + + for _, item := range items { + if len(item.SupportedGPUs) == 0 { + continue // Skip non-GPU instances + } + + for _, gpu := range item.SupportedGPUs { + vramPerGPU := parseMemoryToGB(gpu.Memory) + // Also check memory_bytes as fallback + if vramPerGPU == 0 && gpu.MemoryBytes.Value > 0 { + // Convert based on unit + if gpu.MemoryBytes.Unit == "MiB" { + vramPerGPU = float64(gpu.MemoryBytes.Value) / 1024 // MiB to GiB + } else if gpu.MemoryBytes.Unit == "GiB" { + vramPerGPU = float64(gpu.MemoryBytes.Value) + } + } + + totalVRAM := vramPerGPU * float64(gpu.Count) + capability := getGPUCapability(gpu.Name) + + price := 0.0 + if item.BasePrice.Amount != "" { + price, _ = strconv.ParseFloat(item.BasePrice.Amount, 64) + } + + instances = append(instances, GPUInstanceInfo{ + Type: item.Type, + GPUName: gpu.Name, + GPUCount: gpu.Count, + VRAMPerGPU: vramPerGPU, + TotalVRAM: totalVRAM, + Capability: capability, + VCPUs: item.VCPU, + Memory: item.Memory, + PricePerHour: price, + Manufacturer: gpu.Manufacturer, + }) + } + } + + return instances +} + +// filterInstances applies all filters to the instance list +func filterInstances(instances []GPUInstanceInfo, gpuName string, minVRAM, minTotalVRAM, minCapability float64) []GPUInstanceInfo { + var filtered []GPUInstanceInfo + + for _, inst := range instances { + // Filter by GPU name (case-insensitive partial match) + if gpuName != "" && !strings.Contains(strings.ToLower(inst.GPUName), strings.ToLower(gpuName)) { + continue + } + + // Filter by minimum VRAM per GPU + if minVRAM > 0 && inst.VRAMPerGPU < minVRAM { + continue + } + + // Filter by minimum total VRAM + if minTotalVRAM > 0 && inst.TotalVRAM < minTotalVRAM { + continue + } + + // Filter by minimum GPU capability + if minCapability > 0 && inst.Capability < minCapability { + continue + } + + filtered = append(filtered, inst) + } + + return filtered +} + +// sortInstances sorts the instance list by the specified column +func sortInstances(instances []GPUInstanceInfo, sortBy string, descending bool) { + sort.Slice(instances, func(i, j int) bool { + var less bool + switch strings.ToLower(sortBy) { + case "price": + less = instances[i].PricePerHour < instances[j].PricePerHour + case "gpu-count": + less = instances[i].GPUCount < instances[j].GPUCount + case "vram": + less = instances[i].VRAMPerGPU < instances[j].VRAMPerGPU + case "total-vram": + less = instances[i].TotalVRAM < instances[j].TotalVRAM + case "vcpu": + less = instances[i].VCPUs < instances[j].VCPUs + case "type": + less = instances[i].Type < instances[j].Type + case "capability": + less = instances[i].Capability < instances[j].Capability + default: + less = instances[i].PricePerHour < instances[j].PricePerHour + } + + if descending { + return !less + } + return less + }) +} + +// getBrevTableOptions returns table styling options +func getBrevTableOptions() table.Options { + options := table.OptionsDefault + options.DrawBorder = false + options.SeparateColumns = false + options.SeparateRows = false + options.SeparateHeader = false + return options +} + +// displayGPUTable renders the GPU instances as a table +func displayGPUTable(t *terminal.Terminal, instances []GPUInstanceInfo) { + ta := table.NewWriter() + ta.SetOutputMirror(os.Stdout) + ta.Style().Options = getBrevTableOptions() + + header := table.Row{"TYPE", "GPU", "COUNT", "VRAM/GPU", "TOTAL VRAM", "CAPABILITY", "VCPUs", "$/HR"} + ta.AppendHeader(header) + + for _, inst := range instances { + vramStr := fmt.Sprintf("%.0f GB", inst.VRAMPerGPU) + totalVramStr := fmt.Sprintf("%.0f GB", inst.TotalVRAM) + capStr := "-" + if inst.Capability > 0 { + capStr = fmt.Sprintf("%.1f", inst.Capability) + } + priceStr := fmt.Sprintf("$%.2f", inst.PricePerHour) + + row := table.Row{ + inst.Type, + t.Green(inst.GPUName), + inst.GPUCount, + vramStr, + totalVramStr, + capStr, + inst.VCPUs, + priceStr, + } + ta.AppendRow(row) + } + + ta.Render() +} diff --git a/pkg/cmd/gpusearch/gpusearch_test.go b/pkg/cmd/gpusearch/gpusearch_test.go new file mode 100644 index 00000000..0714874f --- /dev/null +++ b/pkg/cmd/gpusearch/gpusearch_test.go @@ -0,0 +1,388 @@ +package gpusearch + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// MockGPUSearchStore is a mock implementation of GPUSearchStore for testing +type MockGPUSearchStore struct { + Response *InstanceTypesResponse + Err error +} + +func (m *MockGPUSearchStore) GetInstanceTypes() (*InstanceTypesResponse, error) { + if m.Err != nil { + return nil, m.Err + } + return m.Response, nil +} + +func createTestInstanceTypes() *InstanceTypesResponse { + return &InstanceTypesResponse{ + Items: []InstanceType{ + { + Type: "g5.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A10G", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "1.006"}, + }, + { + Type: "g5.2xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A10G", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "32GiB", + VCPU: 8, + BasePrice: BasePrice{Currency: "USD", Amount: "1.212"}, + }, + { + Type: "p3.2xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "V100", Manufacturer: "NVIDIA", Memory: "16GiB"}, + }, + Memory: "61GiB", + VCPU: 8, + BasePrice: BasePrice{Currency: "USD", Amount: "3.06"}, + }, + { + Type: "p3.8xlarge", + SupportedGPUs: []GPU{ + {Count: 4, Name: "V100", Manufacturer: "NVIDIA", Memory: "16GiB"}, + }, + Memory: "244GiB", + VCPU: 32, + BasePrice: BasePrice{Currency: "USD", Amount: "12.24"}, + }, + { + Type: "p4d.24xlarge", + SupportedGPUs: []GPU{ + {Count: 8, Name: "A100", Manufacturer: "NVIDIA", Memory: "40GiB"}, + }, + Memory: "1152GiB", + VCPU: 96, + BasePrice: BasePrice{Currency: "USD", Amount: "32.77"}, + }, + { + Type: "g4dn.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "T4", Manufacturer: "NVIDIA", Memory: "16GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "0.526"}, + }, + { + Type: "g6.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "L4", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "0.805"}, + }, + }, + } +} + +func TestParseMemoryToGB(t *testing.T) { + tests := []struct { + name string + input string + expected float64 + }{ + {"Simple GiB", "24GiB", 24}, + {"GiB with MiB", "22GiB360MiB", 22.3515625}, + {"Simple GB", "16GB", 16}, + {"Large GiB", "1152GiB", 1152}, + {"Empty string", "", 0}, + {"MiB only", "512MiB", 0.5}, + {"With spaces", "24 GiB", 24}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseMemoryToGB(tt.input) + assert.InDelta(t, tt.expected, result, 0.01, "Memory parsing failed for %s", tt.input) + }) + } +} + +func TestGetGPUCapability(t *testing.T) { + tests := []struct { + name string + gpuName string + expected float64 + }{ + {"A100", "A100", 8.0}, + {"A10G", "A10G", 8.6}, + {"V100", "V100", 7.0}, + {"T4", "T4", 7.5}, + {"L4", "L4", 8.9}, + {"L40S", "L40S", 8.9}, + {"H100", "H100", 9.0}, + {"Unknown GPU", "Unknown", 0}, + {"Case insensitive", "a100", 8.0}, + {"Gaudi", "HL-205", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getGPUCapability(tt.gpuName) + assert.Equal(t, tt.expected, result, "GPU capability mismatch for %s", tt.gpuName) + }) + } +} + +func TestProcessInstances(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + assert.Len(t, instances, 7, "Expected 7 GPU instances") + + // Check specific instance + var a10gInstance *GPUInstanceInfo + for i := range instances { + if instances[i].Type == "g5.xlarge" { + a10gInstance = &instances[i] + break + } + } + + assert.NotNil(t, a10gInstance, "g5.xlarge instance should exist") + assert.Equal(t, "A10G", a10gInstance.GPUName) + assert.Equal(t, 1, a10gInstance.GPUCount) + assert.Equal(t, 24.0, a10gInstance.VRAMPerGPU) + assert.Equal(t, 24.0, a10gInstance.TotalVRAM) + assert.Equal(t, 8.6, a10gInstance.Capability) + assert.Equal(t, 4, a10gInstance.VCPUs) + assert.InDelta(t, 1.006, a10gInstance.PricePerHour, 0.001) +} + +func TestFilterInstancesByGPUName(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Filter by A10G + filtered := filterInstances(instances, "A10G", 0, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 A10G instances") + + // Filter by V100 + filtered = filterInstances(instances, "V100", 0, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 V100 instances") + + // Filter by lowercase (case-insensitive) + filtered = filterInstances(instances, "v100", 0, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 V100 instances (case-insensitive)") + + // Filter by partial match + filtered = filterInstances(instances, "A1", 0, 0, 0) + assert.Len(t, filtered, 3, "Should have 3 instances matching 'A1' (A10G and A100)") +} + +func TestFilterInstancesByMinVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Filter by min VRAM 24GB + filtered := filterInstances(instances, "", 24, 0, 0) + assert.Len(t, filtered, 4, "Should have 4 instances with >= 24GB VRAM") + + // Filter by min VRAM 40GB + filtered = filterInstances(instances, "", 40, 0, 0) + assert.Len(t, filtered, 1, "Should have 1 instance with >= 40GB VRAM") + assert.Equal(t, "A100", filtered[0].GPUName) +} + +func TestFilterInstancesByMinTotalVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Filter by min total VRAM 60GB + filtered := filterInstances(instances, "", 0, 60, 0) + assert.Len(t, filtered, 2, "Should have 2 instances with >= 60GB total VRAM") + + // Filter by min total VRAM 300GB + filtered = filterInstances(instances, "", 0, 300, 0) + assert.Len(t, filtered, 1, "Should have 1 instance with >= 300GB total VRAM") + assert.Equal(t, "p4d.24xlarge", filtered[0].Type) +} + +func TestFilterInstancesByMinCapability(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Filter by capability >= 8.0 + filtered := filterInstances(instances, "", 0, 0, 8.0) + assert.Len(t, filtered, 4, "Should have 4 instances with capability >= 8.0") + + // Filter by capability >= 8.5 + filtered = filterInstances(instances, "", 0, 0, 8.5) + assert.Len(t, filtered, 3, "Should have 3 instances with capability >= 8.5") +} + +func TestFilterInstancesCombined(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Filter by GPU name and min VRAM + filtered := filterInstances(instances, "A10G", 24, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 A10G instances with >= 24GB VRAM") + + // Filter by GPU name, min VRAM, and capability + filtered = filterInstances(instances, "", 24, 0, 8.5) + assert.Len(t, filtered, 3, "Should have 3 instances with >= 24GB VRAM and capability >= 8.5") +} + +func TestSortInstancesByPrice(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by price ascending + sortInstances(instances, "price", false) + assert.Equal(t, "g4dn.xlarge", instances[0].Type, "Cheapest should be g4dn.xlarge") + assert.Equal(t, "p4d.24xlarge", instances[len(instances)-1].Type, "Most expensive should be p4d.24xlarge") + + // Sort by price descending + sortInstances(instances, "price", true) + assert.Equal(t, "p4d.24xlarge", instances[0].Type, "Most expensive should be first when descending") +} + +func TestSortInstancesByGPUCount(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by GPU count ascending + sortInstances(instances, "gpu-count", false) + assert.Equal(t, 1, instances[0].GPUCount, "Instances with 1 GPU should be first") + + // Sort by GPU count descending + sortInstances(instances, "gpu-count", true) + assert.Equal(t, 8, instances[0].GPUCount, "Instance with 8 GPUs should be first when descending") +} + +func TestSortInstancesByVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by VRAM ascending + sortInstances(instances, "vram", false) + assert.Equal(t, 16.0, instances[0].VRAMPerGPU, "Instances with 16GB VRAM should be first") + + // Sort by VRAM descending + sortInstances(instances, "vram", true) + assert.Equal(t, 40.0, instances[0].VRAMPerGPU, "Instance with 40GB VRAM should be first when descending") +} + +func TestSortInstancesByTotalVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by total VRAM ascending + sortInstances(instances, "total-vram", false) + assert.Equal(t, 16.0, instances[0].TotalVRAM, "Instances with 16GB total VRAM should be first") + + // Sort by total VRAM descending + sortInstances(instances, "total-vram", true) + assert.Equal(t, 320.0, instances[0].TotalVRAM, "Instance with 320GB total VRAM should be first when descending") +} + +func TestSortInstancesByVCPU(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by vCPU ascending + sortInstances(instances, "vcpu", false) + assert.Equal(t, 4, instances[0].VCPUs, "Instances with 4 vCPUs should be first") + + // Sort by vCPU descending + sortInstances(instances, "vcpu", true) + assert.Equal(t, 96, instances[0].VCPUs, "Instance with 96 vCPUs should be first when descending") +} + +func TestSortInstancesByCapability(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by capability ascending + sortInstances(instances, "capability", false) + assert.Equal(t, 7.0, instances[0].Capability, "Instances with capability 7.0 should be first") + + // Sort by capability descending + sortInstances(instances, "capability", true) + assert.Equal(t, 8.9, instances[0].Capability, "Instance with capability 8.9 should be first when descending") +} + +func TestSortInstancesByType(t *testing.T) { + response := createTestInstanceTypes() + instances := processInstances(response.Items) + + // Sort by type ascending + sortInstances(instances, "type", false) + assert.Equal(t, "g4dn.xlarge", instances[0].Type, "g4dn.xlarge should be first alphabetically") + + // Sort by type descending + sortInstances(instances, "type", true) + assert.Equal(t, "p4d.24xlarge", instances[0].Type, "p4d.24xlarge should be first when descending") +} + +func TestEmptyInstanceTypes(t *testing.T) { + response := &InstanceTypesResponse{Items: []InstanceType{}} + instances := processInstances(response.Items) + + assert.Len(t, instances, 0, "Should have 0 instances") + + filtered := filterInstances(instances, "A100", 0, 0, 0) + assert.Len(t, filtered, 0, "Filtered should also be empty") +} + +func TestNonGPUInstancesAreFiltered(t *testing.T) { + response := &InstanceTypesResponse{ + Items: []InstanceType{ + { + Type: "m5.xlarge", + SupportedGPUs: []GPU{}, // No GPUs + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "0.192"}, + }, + { + Type: "g5.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A10G", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "1.006"}, + }, + }, + } + + instances := processInstances(response.Items) + assert.Len(t, instances, 1, "Should only have 1 GPU instance, non-GPU instances should be filtered") + assert.Equal(t, "g5.xlarge", instances[0].Type) +} + +func TestMemoryBytesAsFallback(t *testing.T) { + response := &InstanceTypesResponse{ + Items: []InstanceType{ + { + Type: "test.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "TestGPU", Manufacturer: "NVIDIA", Memory: "", MemoryBytes: MemoryBytes{Value: 24576, Unit: "MiB"}}, // 24GB in MiB + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "1.00"}, + }, + }, + } + + instances := processInstances(response.Items) + assert.Len(t, instances, 1) + assert.Equal(t, 24.0, instances[0].VRAMPerGPU, "Should fall back to MemoryBytes when Memory string is empty") +} diff --git a/pkg/store/instancetypes.go b/pkg/store/instancetypes.go new file mode 100644 index 00000000..4f12710a --- /dev/null +++ b/pkg/store/instancetypes.go @@ -0,0 +1,48 @@ +package store + +import ( + "encoding/json" + + "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + resty "github.com/go-resty/resty/v2" +) + +const ( + instanceTypesAPIURL = "https://api.brev.dev" + instanceTypesAPIPath = "v1/instance/types" +) + +// GetInstanceTypes fetches all available instance types from the public API +func (s NoAuthHTTPStore) GetInstanceTypes() (*gpusearch.InstanceTypesResponse, error) { + return fetchInstanceTypes() +} + +// GetInstanceTypes fetches all available instance types from the public API +func (s AuthHTTPStore) GetInstanceTypes() (*gpusearch.InstanceTypesResponse, error) { + return fetchInstanceTypes() +} + +// fetchInstanceTypes fetches instance types from the public Brev API +func fetchInstanceTypes() (*gpusearch.InstanceTypesResponse, error) { + client := resty.New() + client.SetBaseURL(instanceTypesAPIURL) + + res, err := client.R(). + SetHeader("Accept", "application/json"). + Get(instanceTypesAPIPath) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if res.IsError() { + return nil, NewHTTPResponseError(res) + } + + var result gpusearch.InstanceTypesResponse + err = json.Unmarshal(res.Body(), &result) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + return &result, nil +} From 701e2321db0dc5664cde8916341b133c9dec36ea Mon Sep 17 00:00:00 2001 From: Alec Fong Date: Sat, 10 Jan 2026 23:14:43 -0800 Subject: [PATCH 2/3] Add compute capabilities for additional NVIDIA GPUs Add capability mappings for: - RTXPro6000 (12.0), B200 and RTX5090 (10.0 Blackwell) - RTX6000Ada, RTX4000Ada (8.9 Ada Lovelace) - A6000, A5000, A4000 (8.6 Ampere) - RTX6000 (7.5 Turing) - M60 (5.2 Maxwell) Co-Authored-By: Claude Opus 4.5 --- pkg/cmd/gpusearch/gpusearch.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/pkg/cmd/gpusearch/gpusearch.go b/pkg/cmd/gpusearch/gpusearch.go index d66a5545..c9a3e88f 100644 --- a/pkg/cmd/gpusearch/gpusearch.go +++ b/pkg/cmd/gpusearch/gpusearch.go @@ -225,14 +225,23 @@ func getGPUCapability(gpuName string) float64 { // Order matters: more specific patterns must come before less specific ones // (e.g., "A100" before "A10", "L40S" before "L40") capabilities := []gpuCapabilityEntry{ + // NVIDIA Professional (before other RTX patterns) + {"RTXPRO6000", 12.0}, + + // NVIDIA Blackwell + {"B200", 10.0}, + {"RTX5090", 10.0}, + // NVIDIA Hopper {"H100", 9.0}, {"H200", 9.0}, - // NVIDIA Ada Lovelace (L40S before L40, L4) + // NVIDIA Ada Lovelace (L40S before L40, L4; RTX*Ada before RTX*) {"L40S", 8.9}, {"L40", 8.9}, {"L4", 8.9}, + {"RTX6000ADA", 8.9}, + {"RTX4000ADA", 8.9}, {"RTX4090", 8.9}, {"RTX4080", 8.9}, @@ -241,6 +250,9 @@ func getGPUCapability(gpuName string) float64 { {"A10G", 8.6}, {"A10", 8.6}, {"A40", 8.6}, + {"A6000", 8.6}, + {"A5000", 8.6}, + {"A4000", 8.6}, {"A30", 8.0}, {"A16", 8.6}, {"RTX3090", 8.6}, @@ -248,6 +260,7 @@ func getGPUCapability(gpuName string) float64 { // NVIDIA Turing {"T4", 7.5}, + {"RTX6000", 7.5}, {"RTX2080", 7.5}, // NVIDIA Volta @@ -258,6 +271,9 @@ func getGPUCapability(gpuName string) float64 { {"P40", 6.1}, {"P4", 6.1}, + // NVIDIA Maxwell + {"M60", 5.2}, + // NVIDIA Kepler {"K80", 3.7}, From 89be4b66374d942793a748a59e68c981675e5a6c Mon Sep 17 00:00:00 2001 From: Alec Fong Date: Sat, 10 Jan 2026 23:17:34 -0800 Subject: [PATCH 3/3] Filter out non-NVIDIA GPUs from gpu-search results Only show NVIDIA GPUs (exclude AMD Radeon, Intel Gaudi, etc.) since compute capability is NVIDIA-specific. Co-Authored-By: Claude Opus 4.5 --- pkg/cmd/gpusearch/gpusearch.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pkg/cmd/gpusearch/gpusearch.go b/pkg/cmd/gpusearch/gpusearch.go index c9a3e88f..54d1bd37 100644 --- a/pkg/cmd/gpusearch/gpusearch.go +++ b/pkg/cmd/gpusearch/gpusearch.go @@ -344,6 +344,11 @@ func filterInstances(instances []GPUInstanceInfo, gpuName string, minVRAM, minTo var filtered []GPUInstanceInfo for _, inst := range instances { + // Filter out non-NVIDIA GPUs (AMD, Intel/Habana, etc.) + if !strings.Contains(strings.ToUpper(inst.Manufacturer), "NVIDIA") { + continue + } + // Filter by GPU name (case-insensitive partial match) if gpuName != "" && !strings.Contains(strings.ToLower(inst.GPUName), strings.ToLower(gpuName)) { continue