Skip to content

Commit 6d3051d

Browse files
committed
Fallback to default branch in get_file_contents when main doesn't exist
1 parent 2f31c15 commit 6d3051d

File tree

2 files changed

+122
-12
lines changed

2 files changed

+122
-12
lines changed

pkg/github/repositories.go

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,8 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
671671
if err != nil {
672672
return utils.NewToolResultError(err.Error()), nil, nil
673673
}
674+
originalRef := ref
675+
674676
sha, err := OptionalParam[string](args, "sha")
675677
if err != nil {
676678
return utils.NewToolResultError(err.Error()), nil, nil
@@ -747,6 +749,12 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
747749
}
748750
}
749751

752+
// main branch ref passed in ref parameter but it doesn't exist - default branch was used
753+
var successNote string
754+
if !strings.HasSuffix(rawOpts.Ref, originalRef) {
755+
successNote = fmt.Sprintf(" Note: the provided ref '%s' does not exist, default branch '%s' was used instead.", originalRef, rawOpts.Ref)
756+
}
757+
750758
// Determine if content is text or binary
751759
isTextContent := strings.HasPrefix(contentType, "text/") ||
752760
contentType == "application/json" ||
@@ -762,9 +770,9 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
762770
}
763771
// Include SHA in the result metadata
764772
if fileSHA != "" {
765-
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA), result), nil, nil
773+
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA)+successNote, result), nil, nil
766774
}
767-
return utils.NewToolResultResource("successfully downloaded text file", result), nil, nil
775+
return utils.NewToolResultResource("successfully downloaded text file"+successNote, result), nil, nil
768776
}
769777

770778
result := &mcp.ResourceContents{
@@ -774,9 +782,9 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
774782
}
775783
// Include SHA in the result metadata
776784
if fileSHA != "" {
777-
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA), result), nil, nil
785+
return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA)+successNote, result), nil, nil
778786
}
779-
return utils.NewToolResultResource("successfully downloaded binary file", result), nil, nil
787+
return utils.NewToolResultResource("successfully downloaded binary file"+successNote, result), nil, nil
780788
}
781789

782790
// Raw API call failed
@@ -1897,12 +1905,11 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner
18971905
switch {
18981906
case originalRef == "":
18991907
// 2a) If ref is empty, determine the default branch.
1900-
repoInfo, resp, err := githubClient.Repositories.Get(ctx, owner, repo)
1908+
reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo)
19011909
if err != nil {
1902-
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get repository info", resp, err)
1903-
return nil, fmt.Errorf("failed to get repository info: %w", err)
1910+
return nil, err // Error is already wrapped in resolveDefaultBranch.
19041911
}
1905-
ref = fmt.Sprintf("refs/heads/%s", repoInfo.GetDefaultBranch())
1912+
ref = reference.GetRef()
19061913
case strings.HasPrefix(originalRef, "refs/"):
19071914
// 2b) Already fully qualified. The reference will be fetched at the end.
19081915
case strings.HasPrefix(originalRef, "heads/") || strings.HasPrefix(originalRef, "tags/"):
@@ -1928,7 +1935,13 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner
19281935
ghErr2, isGhErr2 := err.(*github.ErrorResponse)
19291936
if isGhErr2 && ghErr2.Response.StatusCode == http.StatusNotFound {
19301937
if originalRef == "main" {
1931-
return nil, fmt.Errorf("could not find branch or tag 'main'. Some repositories use 'master' as the default branch name")
1938+
reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo)
1939+
if err != nil {
1940+
return nil, err // Error is already wrapped in resolveDefaultBranch.
1941+
}
1942+
// Update ref to the actual default branch ref so the note can be generated
1943+
ref = reference.GetRef()
1944+
break
19321945
}
19331946
return nil, fmt.Errorf("could not resolve ref %q as a branch or a tag", originalRef)
19341947
}
@@ -1949,17 +1962,41 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner
19491962
reference, resp, err = githubClient.Git.GetRef(ctx, owner, repo, ref)
19501963
if err != nil {
19511964
if ref == "refs/heads/main" {
1952-
return nil, fmt.Errorf("could not find branch 'main'. Some repositories use 'master' as the default branch name")
1965+
reference, err = resolveDefaultBranch(ctx, githubClient, owner, repo)
1966+
if err != nil {
1967+
return nil, err // Error is already wrapped in resolveDefaultBranch.
1968+
}
1969+
// Update ref to the actual default branch ref so the note can be generated
1970+
ref = reference.GetRef()
1971+
} else {
1972+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get final reference", resp, err)
1973+
return nil, fmt.Errorf("failed to get final reference for %q: %w", ref, err)
19531974
}
1954-
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get final reference", resp, err)
1955-
return nil, fmt.Errorf("failed to get final reference for %q: %w", ref, err)
19561975
}
19571976
}
19581977

19591978
sha = reference.GetObject().GetSHA()
19601979
return &raw.ContentOpts{Ref: ref, SHA: sha}, nil
19611980
}
19621981

1982+
func resolveDefaultBranch(ctx context.Context, githubClient *github.Client, owner, repo string) (*github.Reference, error) {
1983+
repoInfo, resp, err := githubClient.Repositories.Get(ctx, owner, repo)
1984+
if err != nil {
1985+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get repository info", resp, err)
1986+
return nil, fmt.Errorf("failed to get repository info: %w", err)
1987+
}
1988+
defaultBranch := repoInfo.GetDefaultBranch()
1989+
1990+
defaultRef, resp, err := githubClient.Git.GetRef(ctx, owner, repo, "heads/"+defaultBranch)
1991+
defer func() { _ = resp.Body.Close() }()
1992+
if err != nil {
1993+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get default branch reference", resp, err)
1994+
return nil, fmt.Errorf("failed to get default branch reference: %w", err)
1995+
}
1996+
1997+
return defaultRef, nil
1998+
}
1999+
19632000
// ListStarredRepositories creates a tool to list starred repositories for the authenticated user or a specified user.
19642001
func ListStarredRepositories(t translations.TranslationHelperFunc) inventory.ServerTool {
19652002
return NewTool(

pkg/github/repositories_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ func Test_GetFileContents(t *testing.T) {
6969
expectedResult interface{}
7070
expectedErrMsg string
7171
expectStatus int
72+
expectedMsg string // optional: expected message text to verify in result
7273
}{
7374
{
7475
name: "successful text content fetch",
@@ -290,6 +291,70 @@ func Test_GetFileContents(t *testing.T) {
290291
MIMEType: "text/markdown",
291292
},
292293
},
294+
{
295+
name: "successful text content fetch with note when ref falls back to default branch",
296+
mockedClient: mock.NewMockedHTTPClient(
297+
mock.WithRequestMatchHandler(
298+
mock.GetReposByOwnerByRepo,
299+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
300+
w.WriteHeader(http.StatusOK)
301+
_, _ = w.Write([]byte(`{"name": "repo", "default_branch": "develop"}`))
302+
}),
303+
),
304+
mock.WithRequestMatchHandler(
305+
mock.GetReposGitRefByOwnerByRepoByRef,
306+
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
307+
// Request for "refs/heads/main" -> 404 (doesn't exist)
308+
// Request for "refs/heads/develop" (default branch) -> 200
309+
switch {
310+
case strings.Contains(r.URL.Path, "heads/main"):
311+
w.WriteHeader(http.StatusNotFound)
312+
_, _ = w.Write([]byte(`{"message": "Not Found"}`))
313+
case strings.Contains(r.URL.Path, "heads/develop"):
314+
w.WriteHeader(http.StatusOK)
315+
_, _ = w.Write([]byte(`{"ref": "refs/heads/develop", "object": {"sha": "abc123def456"}}`))
316+
default:
317+
w.WriteHeader(http.StatusNotFound)
318+
_, _ = w.Write([]byte(`{"message": "Not Found"}`))
319+
}
320+
}),
321+
),
322+
mock.WithRequestMatchHandler(
323+
mock.GetReposContentsByOwnerByRepoByPath,
324+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
325+
w.WriteHeader(http.StatusOK)
326+
fileContent := &github.RepositoryContent{
327+
Name: github.Ptr("README.md"),
328+
Path: github.Ptr("README.md"),
329+
SHA: github.Ptr("abc123"),
330+
Type: github.Ptr("file"),
331+
}
332+
contentBytes, _ := json.Marshal(fileContent)
333+
_, _ = w.Write(contentBytes)
334+
}),
335+
),
336+
mock.WithRequestMatchHandler(
337+
raw.GetRawReposContentsByOwnerByRepoBySHAByPath,
338+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
339+
w.Header().Set("Content-Type", "text/markdown")
340+
_, _ = w.Write(mockRawContent)
341+
}),
342+
),
343+
),
344+
requestArgs: map[string]interface{}{
345+
"owner": "owner",
346+
"repo": "repo",
347+
"path": "README.md",
348+
"ref": "main",
349+
},
350+
expectError: false,
351+
expectedResult: mcp.ResourceContents{
352+
URI: "repo://owner/repo/abc123def456/contents/README.md",
353+
Text: "# Test Repository\n\nThis is a test repository.",
354+
MIMEType: "text/markdown",
355+
},
356+
expectedMsg: " Note: the provided ref 'main' does not exist, default branch 'refs/heads/develop' was used instead.",
357+
},
293358
{
294359
name: "content fetch fails",
295360
mockedClient: mock.NewMockedHTTPClient(
@@ -358,6 +423,14 @@ func Test_GetFileContents(t *testing.T) {
358423
// Handle both text and blob resources
359424
resource := getResourceResult(t, result)
360425
assert.Equal(t, expected, *resource)
426+
427+
// If expectedMsg is set, verify the message text
428+
if tc.expectedMsg != "" {
429+
require.Len(t, result.Content, 2)
430+
textContent, ok := result.Content[0].(*mcp.TextContent)
431+
require.True(t, ok, "expected Content[0] to be TextContent")
432+
assert.Contains(t, textContent.Text, tc.expectedMsg)
433+
}
361434
case []*github.RepositoryContent:
362435
// Directory content fetch returns a text result (JSON array)
363436
textContent := getTextResult(t, result)

0 commit comments

Comments
 (0)