Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 99 additions & 10 deletions cmd/cli/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ import (

"github.com/docker/model-runner/cmd/cli/pkg/standalone"
"github.com/docker/model-runner/pkg/distribution/distribution"
"github.com/docker/model-runner/pkg/distribution/oci/authn"
"github.com/docker/model-runner/pkg/distribution/oci/reference"
"github.com/docker/model-runner/pkg/distribution/oci/remote"
"github.com/docker/model-runner/pkg/inference"
dmrm "github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/inference/scheduling"
Expand All @@ -29,6 +32,78 @@ var (
ErrServiceUnavailable = errors.New("service unavailable")
)

// resolveCredentials resolves Docker registry credentials for a model reference
// and exchanges them for a short-lived bearer token.
// Returns nil if no credentials are found (anonymous access).
func resolveCredentials(ctx context.Context, model string) *distribution.Credentials {
// Skip credential resolution for Hugging Face models (use HF_TOKEN env var instead).
if strings.HasPrefix(strings.ToLower(model), "hf.co/") {
if hfToken := os.Getenv("HF_TOKEN"); hfToken != "" {
return &distribution.Credentials{BearerToken: hfToken}
}
return nil
}

ref, err := reference.ParseReference(model)
if err != nil {
return nil
}

resource := authn.NewResource(ref)
auth, err := authn.DefaultKeychain.Resolve(resource)
if err != nil {
return nil
}

authConfig, err := auth.Authorization()
if err != nil || authConfig == nil {
return nil
}

if authConfig.RegistryToken != "" {
return &distribution.Credentials{BearerToken: authConfig.RegistryToken}
}
if authConfig.IdentityToken != "" {
return &distribution.Credentials{BearerToken: authConfig.IdentityToken}
}

if authConfig.Username != "" && authConfig.Password != "" {
token, err := exchangeForToken(ctx, ref, auth)
if err != nil {
return &distribution.Credentials{
Username: authConfig.Username,
Password: authConfig.Password,
}
}
return &distribution.Credentials{BearerToken: token}
}

return nil
}

// exchangeForToken exchanges credentials for a short-lived bearer token.
func exchangeForToken(ctx context.Context, ref reference.Reference, auth authn.Authenticator) (string, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

pr, err := remote.Ping(ctx, ref.Context().Registry, nil)
if err != nil {
return "", fmt.Errorf("pinging registry: %w", err)
}

if pr.WWWAuthenticate.Realm == "" {
return "", fmt.Errorf("no auth required")
}

scope := ref.Scope(remote.PushScope)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 issue (security): Using PushScope for all token exchanges may be too restrictive for pull-only operations

resolveCredentials is used for both pull and push, but exchangeForToken always requests remote.PushScope. For pull-only callers this can both fail unnecessarily and request overly broad (write) tokens. Consider making the scope a parameter (pull vs push) or deriving it from the caller so we only request the minimum required scope.

token, err := remote.Exchange(ctx, ref.Context().Registry, auth, nil, []string{scope}, pr)
if err != nil {
return "", fmt.Errorf("exchanging credentials: %w", err)
}

return token.Token, nil
}

type otelErrorSilencer struct{}

func (oes *otelErrorSilencer) Handle(error) {}
Expand Down Expand Up @@ -98,17 +173,16 @@ func (c *Client) Status() Status {
}

func (c *Client) Pull(model string, printer standalone.StatusPrinter) (string, bool, error) {
// Check if this is a Hugging Face model and if HF_TOKEN is set
var hfToken string
if strings.HasPrefix(strings.ToLower(model), "hf.co/") {
hfToken = os.Getenv("HF_TOKEN")
}
creds := resolveCredentials(context.Background(), model)

return c.withRetries("download", 3, printer, func(attempt int) (string, bool, error, bool) {
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{
From: model,
BearerToken: hfToken,
})
req := dmrm.ModelCreateRequest{From: model}
if creds != nil {
req.Username = creds.Username
req.Password = creds.Password
req.BearerToken = creds.BearerToken
}
jsonData, err := json.Marshal(req)
if err != nil {
// Marshaling errors are not retryable
return "", false, fmt.Errorf("error marshaling request: %w", err), false
Expand Down Expand Up @@ -223,12 +297,27 @@ func (c *Client) withRetries(
}

func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) {
creds := resolveCredentials(context.Background(), model)

return c.withRetries("push", 3, printer, func(attempt int) (string, bool, error, bool) {
var body io.Reader
if creds != nil {
jsonData, err := json.Marshal(dmrm.ModelPushRequest{
Username: creds.Username,
Password: creds.Password,
BearerToken: creds.BearerToken,
})
if err != nil {
return "", false, fmt.Errorf("error marshaling request: %w", err), false
}
body = bytes.NewReader(jsonData)
}

pushPath := inference.ModelsPrefix + "/" + model + "/push"
resp, err := c.doRequest(
http.MethodPost,
pushPath,
nil, // Assuming no body is needed for the push request
body,
)
if err != nil {
// Only retry on network errors, not on client errors
Expand Down
70 changes: 0 additions & 70 deletions cmd/cli/pkg/standalone/containers.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package standalone

import (
"archive/tar"
"bytes"
"context"
"errors"
"fmt"
Expand All @@ -29,66 +27,6 @@ import (
// controllerContainerName is the name to use for the controller container.
const controllerContainerName = "docker-model-runner"

// copyDockerConfigToContainer copies the Docker config file from the host to the container
// and sets up proper ownership and permissions for the modelrunner user.
// It does nothing for Desktop and Cloud engine kinds.
func copyDockerConfigToContainer(ctx context.Context, dockerClient *client.Client, containerID string, engineKind types.ModelRunnerEngineKind) error {
// Do nothing for Desktop and Cloud engine kinds
if engineKind == types.ModelRunnerEngineKindDesktop || engineKind == types.ModelRunnerEngineKindCloud ||
os.Getenv("_MODEL_RUNNER_TREAT_DESKTOP_AS_MOBY") == "1" {
return nil
}

dockerConfigPath := os.ExpandEnv("$HOME/.docker/config.json")
if s, err := os.Stat(dockerConfigPath); err != nil || s.Mode()&os.ModeType != 0 {
return nil
}

configData, err := os.ReadFile(dockerConfigPath)
if err != nil {
return fmt.Errorf("failed to read Docker config file: %w", err)
}

var buf bytes.Buffer
tw := tar.NewWriter(&buf)
header := &tar.Header{
Name: ".docker/config.json",
Mode: 0600,
Size: int64(len(configData)),
}
if err := tw.WriteHeader(header); err != nil {
return fmt.Errorf("failed to write tar header: %w", err)
}
if _, err := tw.Write(configData); err != nil {
return fmt.Errorf("failed to write config data to tar: %w", err)
}
if err := tw.Close(); err != nil {
return fmt.Errorf("failed to close tar writer: %w", err)
}

// Ensure the .docker directory exists
mkdirCmd := "mkdir -p /home/modelrunner/.docker && chown modelrunner:modelrunner /home/modelrunner/.docker"
if err := execInContainer(ctx, dockerClient, containerID, mkdirCmd, false); err != nil {
return err
}

// Copy directly into the .docker directory
err = dockerClient.CopyToContainer(ctx, containerID, "/home/modelrunner", &buf, container.CopyToContainerOptions{
CopyUIDGID: true,
})
if err != nil {
return fmt.Errorf("failed to copy config file to container: %w", err)
}

// Set correct ownership and permissions
chmodCmd := "chown modelrunner:modelrunner /home/modelrunner/.docker/config.json && chmod 600 /home/modelrunner/.docker/config.json"
if err := execInContainer(ctx, dockerClient, containerID, chmodCmd, false); err != nil {
return err
}

return nil
}

func execInContainer(ctx context.Context, dockerClient *client.Client, containerID, cmd string, asRoot bool) error {
execConfig := container.ExecOptions{
Cmd: []string{"sh", "-c", cmd},
Expand Down Expand Up @@ -447,14 +385,6 @@ func CreateControllerContainer(ctx context.Context, dockerClient *client.Client,
return fmt.Errorf("failed to start container %s: %w", controllerContainerName, err)
}

// Copy Docker config file if it exists and we're the container creator.
if created && !vllmOnWSL {
if err := copyDockerConfigToContainer(ctx, dockerClient, resp.ID, engineKind); err != nil {
// Log warning but continue - don't fail container creation
printer.Printf("Warning: failed to copy Docker config: %v\n", err)
}
}

// Add proxy certificate to the system CA bundle (requires root for update-ca-certificates)
if created && proxyCert != "" {
printer.Printf("Updating CA certificates...\n")
Expand Down
45 changes: 32 additions & 13 deletions pkg/distribution/distribution/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ import (
"github.com/sirupsen/logrus"
)

// Credentials holds authentication credentials for registry operations.
type Credentials struct {
// Username for basic authentication.
Username string
// Password for basic authentication.
Password string
// BearerToken for token-based authentication (e.g., Hugging Face).
BearerToken string
}

// Client provides model distribution functionality
type Client struct {
store *store.LocalStore
Expand Down Expand Up @@ -227,32 +237,32 @@ func (c *Client) resolveID(id string) string {
}

// PullModel pulls a model from a registry and returns the local file path
func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer, bearerToken ...string) error {
func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer, creds *Credentials) error {
// Store original reference before normalization (needed for case-sensitive HuggingFace API)
originalReference := reference
// Normalize the model reference
reference = c.normalizeModelName(reference)
c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference))

// Handle bearer token for registry authentication
var token string
if len(bearerToken) > 0 && bearerToken[0] != "" {
token = bearerToken[0]
}

// HuggingFace references always use native pull (download raw files from HF Hub)
if isHuggingFaceReference(originalReference) {
c.log.Infoln("Using native HuggingFace pull for:", utils.SanitizeForLog(reference))
Comment on lines +240 to 249
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Authentication precedence differs between PullModel and PushModel and may be surprising

In PullModel, basic auth currently takes precedence over BearerToken, while PushModel does the opposite. This inconsistent precedence means callers who set both fields get different auth behavior between pull and push. Consider standardizing the precedence across both methods or rejecting calls where multiple auth fields are set to avoid confusion and unintended auth schemes.

Suggested implementation:

 // PullModel pulls a model from a registry and returns the local file path
func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer, creds *Credentials) error {
	// Ensure callers don't mix auth mechanisms; this matches PushModel behavior and avoids
	// surprising precedence differences when multiple credential fields are set.
	if creds != nil {
		hasBearer := creds.BearerToken != ""
		// NOTE: adjust these field names if your Credentials type uses different ones
		hasBasic := creds.Username != "" || creds.Password != ""

		if hasBearer && hasBasic {
			return fmt.Errorf("multiple authentication methods configured: use either bearer token or basic auth, not both")
		}
	}

	// Store original reference before normalization (needed for case-sensitive HuggingFace API)
	originalReference := reference
	// Normalize the model reference
	reference = c.normalizeModelName(reference)
	c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference))
  1. Ensure the Credentials struct has Username and Password fields; if they differ (e.g. User / Pass), update the hasBasic line accordingly.
  2. Add the fmt package to the imports in pkg/distribution/distribution/client.go:
    • import "fmt"
  3. For full consistency with PushModel, confirm that PushModel either:
    • Uses the same mutual-exclusion rule (rejecting mixed auth), or
    • If it assumes a specific precedence (e.g. bearer over basic), document that and potentially mirror that behavior here instead of rejecting.

// Pass original reference to preserve case-sensitivity for HuggingFace API
var token string
if creds != nil && creds.BearerToken != "" {
token = creds.BearerToken
}
return c.pullNativeHuggingFace(ctx, originalReference, progressWriter, token)
}

// For non-HF references, use OCI registry
registryClient := c.registry
if token != "" {
// Create a temporary registry client with bearer token authentication
auth := authn.NewBearer(token)
registryClient = registry.FromClient(c.registry, registry.WithAuth(auth))
if creds != nil {
if creds.Username != "" && creds.Password != "" {
registryClient = registry.FromClient(c.registry, registry.WithAuthConfig(creds.Username, creds.Password))
} else if creds.BearerToken != "" {
registryClient = registry.FromClient(c.registry, registry.WithAuth(authn.NewBearer(creds.BearerToken)))
}
}

// Fetch the remote model to get the manifest
Expand Down Expand Up @@ -538,9 +548,18 @@ func (c *Client) Tag(source string, target string) error {
}

// PushModel pushes a tagged model from the content store to the registry.
func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Writer) (err error) {
func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Writer, creds *Credentials) (err error) {
registryClient := c.registry
if creds != nil {
if creds.BearerToken != "" {
registryClient = registry.FromClient(c.registry, registry.WithAuth(authn.NewBearer(creds.BearerToken)))
} else if creds.Username != "" && creds.Password != "" {
registryClient = registry.FromClient(c.registry, registry.WithAuthConfig(creds.Username, creds.Password))
}
}

// Parse the tag
target, err := c.registry.NewTarget(tag)
target, err := registryClient.NewTarget(tag)
if err != nil {
return fmt.Errorf("new tag: %w", err)
}
Expand Down
Loading
Loading