diff --git a/pkg/images/really_remix.go b/pkg/images/really_remix.go index 8a9b529..4a2967e 100644 --- a/pkg/images/really_remix.go +++ b/pkg/images/really_remix.go @@ -3,13 +3,58 @@ package images import ( "fmt" "os" + "strings" "time" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/crane" v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/mutate" ) +func getCOGWeights(imageRef v1.Image) (string, error) { + cfg, err := imageRef.ConfigFile() + if err != nil { + return "", fmt.Errorf("getting config %w", err) + } + + for _, envVar := range cfg.Config.Env { + if strings.HasPrefix(envVar, "COG_WEIGHTS=") { + return strings.TrimPrefix(envVar, "COG_WEIGHTS="), nil + } + } + + return "", nil +} + +func addCogWeights(baseImage v1.Image, cogWeights string) (v1.Image, error) { + cfg, err := baseImage.ConfigFile() + if err != nil { + return nil, fmt.Errorf("getting config %w", err) + } + + // Find and update COG_WEIGHTS if it exists, otherwise append it + found := false + for i, envVar := range cfg.Config.Env { + if strings.HasPrefix(envVar, "COG_WEIGHTS=") { + cfg.Config.Env[i] = "COG_WEIGHTS=" + cogWeights + found = true + break + } + } + if !found { + cfg.Config.Env = append(cfg.Config.Env, "COG_WEIGHTS="+cogWeights) + } + + // Create a new image with the updated config + mutant, err := mutate.Config(baseImage, cfg.Config) + if err != nil { + return nil, fmt.Errorf("mutating config %w", err) + } + + return mutant, nil +} + func ReallyRemix(baseRef string, weightsRef string, dest string, auth authn.Authenticator) (string, error) { fmt.Fprintln(os.Stderr, "fetching metadata for", weightsRef) start := time.Now() @@ -29,19 +74,36 @@ func ReallyRemix(baseRef string, weightsRef string, dest string, auth authn.Auth fmt.Fprintln(os.Stderr, "finding weights layer") - start = time.Now() - weightsLayer, err := findWeightsLayer(weightsImage) + cogWeights, err := getCOGWeights(weightsImage) if err != nil { - return "", fmt.Errorf("getting layers %w", err) + return "", fmt.Errorf("getting cog weights %w", err) } - fmt.Fprintln(os.Stderr, "finding weights layer took", time.Since(start)) - start = time.Now() - mutant, err := appendLayers(baseImage, weightsLayer) - if err != nil { - return "", fmt.Errorf("appending layers %w", err) + var mutant v1.Image + + if cogWeights != "" { + fmt.Println("found cog weights", cogWeights) + start = time.Now() + mutant, err = addCogWeights(baseImage, cogWeights) + if err != nil { + return "", fmt.Errorf("adding cog weights %w", err) + } + fmt.Fprintln(os.Stderr, "adding cog weights took", time.Since(start)) + } else { + start = time.Now() + weightsLayer, err := findWeightsLayer(weightsImage) + if err != nil { + return "", fmt.Errorf("getting layers %w", err) + } + fmt.Fprintln(os.Stderr, "finding weights layer took", time.Since(start)) + + start = time.Now() + mutant, err = appendLayers(baseImage, weightsLayer) + if err != nil { + return "", fmt.Errorf("appending layers %w", err) + } + fmt.Fprintln(os.Stderr, "appending layers took", time.Since(start)) } - fmt.Fprintln(os.Stderr, "appending layers took", time.Since(start)) fmt.Fprintln(os.Stderr, "mutant image:", mutant)