Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 71 additions & 9 deletions pkg/images/really_remix.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,58 @@ package images
import (
"fmt"
"os"
"strings"
"time"

"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/crane"
v1 "github.com/google/go-containerregistry/pkg/v1"
"github.com/google/go-containerregistry/pkg/v1/mutate"
)

func getCOGWeights(imageRef v1.Image) (string, error) {
cfg, err := imageRef.ConfigFile()
if err != nil {
return "", fmt.Errorf("getting config %w", err)
}

for _, envVar := range cfg.Config.Env {
if strings.HasPrefix(envVar, "COG_WEIGHTS=") {
return strings.TrimPrefix(envVar, "COG_WEIGHTS="), nil
}
}

return "", nil
}

func addCogWeights(baseImage v1.Image, cogWeights string) (v1.Image, error) {
cfg, err := baseImage.ConfigFile()
if err != nil {
return nil, fmt.Errorf("getting config %w", err)
}

// Find and update COG_WEIGHTS if it exists, otherwise append it
found := false
for i, envVar := range cfg.Config.Env {
if strings.HasPrefix(envVar, "COG_WEIGHTS=") {
cfg.Config.Env[i] = "COG_WEIGHTS=" + cogWeights
found = true
break
}
}
if !found {
cfg.Config.Env = append(cfg.Config.Env, "COG_WEIGHTS="+cogWeights)
}

// Create a new image with the updated config
mutant, err := mutate.Config(baseImage, cfg.Config)
if err != nil {
return nil, fmt.Errorf("mutating config %w", err)
}

return mutant, nil
}

func ReallyRemix(baseRef string, weightsRef string, dest string, auth authn.Authenticator) (string, error) {
fmt.Fprintln(os.Stderr, "fetching metadata for", weightsRef)
start := time.Now()
Expand All @@ -29,19 +74,36 @@ func ReallyRemix(baseRef string, weightsRef string, dest string, auth authn.Auth

fmt.Fprintln(os.Stderr, "finding weights layer")

start = time.Now()
weightsLayer, err := findWeightsLayer(weightsImage)
cogWeights, err := getCOGWeights(weightsImage)
if err != nil {
return "", fmt.Errorf("getting layers %w", err)
return "", fmt.Errorf("getting cog weights %w", err)
}
fmt.Fprintln(os.Stderr, "finding weights layer took", time.Since(start))

start = time.Now()
mutant, err := appendLayers(baseImage, weightsLayer)
if err != nil {
return "", fmt.Errorf("appending layers %w", err)
var mutant v1.Image

if cogWeights != "" {
fmt.Println("found cog weights", cogWeights)
start = time.Now()
mutant, err = addCogWeights(baseImage, cogWeights)
if err != nil {
return "", fmt.Errorf("adding cog weights %w", err)
}
fmt.Fprintln(os.Stderr, "adding cog weights took", time.Since(start))
} else {
start = time.Now()
weightsLayer, err := findWeightsLayer(weightsImage)
if err != nil {
return "", fmt.Errorf("getting layers %w", err)
}
fmt.Fprintln(os.Stderr, "finding weights layer took", time.Since(start))

start = time.Now()
mutant, err = appendLayers(baseImage, weightsLayer)
if err != nil {
return "", fmt.Errorf("appending layers %w", err)
}
fmt.Fprintln(os.Stderr, "appending layers took", time.Since(start))
}
fmt.Fprintln(os.Stderr, "appending layers took", time.Since(start))

fmt.Fprintln(os.Stderr, "mutant image:", mutant)

Expand Down