Skip to content

Commit 925798d

Browse files
committed
Addressing review comments && checking etag
1 parent 27f3e6a commit 925798d

File tree

2 files changed

+208
-85
lines changed

2 files changed

+208
-85
lines changed

pkg/github/repositories.go

Lines changed: 66 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -411,60 +411,66 @@ If the SHA is not provided, the tool will attempt to acquire it by fetching the
411411
}
412412

413413
path = strings.TrimPrefix(path, "/")
414-
fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts)
415-
if err != nil {
416-
if strings.Contains(err.Error(), `"sha" wasn't supplied`) && sha == "" {
417-
// Close the response from the initial failed CreateFile call
418-
if resp != nil {
419-
_ = resp.Body.Close()
420-
}
421414

422-
// attempt to get the current file SHA by fetching the file contents
423-
getOpts := &github.RepositoryContentGetOptions{
424-
Ref: branch,
425-
}
426-
currentFileContent, _, respContents, err := client.Repositories.GetContents(ctx, owner, repo, path, getOpts)
427-
428-
if err == nil {
429-
// Close the GetContents response before making the retry call
430-
if respContents != nil {
431-
_ = respContents.Body.Close()
432-
}
415+
// SHA validation using conditional HEAD request (efficient - no body transfer)
416+
var previousSHA string
417+
contentURL := fmt.Sprintf("repos/%s/%s/contents/%s", owner, repo, url.PathEscape(path))
418+
if branch != "" {
419+
contentURL += "?ref=" + url.QueryEscape(branch)
420+
}
433421

434-
if currentFileContent != nil && currentFileContent.SHA != nil {
435-
opts.SHA = currentFileContent.SHA
436-
fileContent, resp, err = client.Repositories.CreateFile(ctx, owner, repo, path, opts)
437-
defer func() { _ = resp.Body.Close() }()
438-
439-
if err != nil {
440-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
441-
"failed to create/update file after retrieving current SHA",
442-
resp,
443-
err,
444-
), nil, nil
445-
}
446-
} else {
447-
return utils.NewToolResultError("file content SHA is nil, cannot update the file"), nil, nil
422+
if sha != "" {
423+
// User provided SHA - validate it's still current
424+
req, err := client.NewRequest("HEAD", contentURL, nil)
425+
if err == nil {
426+
req.Header.Set("If-None-Match", fmt.Sprintf(`"%s"`, sha))
427+
resp, _ := client.Do(ctx, req, nil)
428+
if resp != nil {
429+
defer resp.Body.Close()
430+
431+
switch resp.StatusCode {
432+
case http.StatusNotModified:
433+
// SHA matches current - proceed
434+
opts.SHA = github.Ptr(sha)
435+
case http.StatusOK:
436+
// SHA is stale - reject with current SHA so user can check diff
437+
currentSHA := strings.Trim(resp.Header.Get("ETag"), `"`)
438+
return utils.NewToolResultError(fmt.Sprintf(
439+
"SHA mismatch: provided SHA %s is stale. Current file SHA is %s. "+
440+
"Use get_file_contents or compare commits to review changes before updating.",
441+
sha, currentSHA)), nil, nil
442+
case http.StatusNotFound:
443+
// File doesn't exist - this is a create, ignore provided SHA
448444
}
449-
} else {
450-
// Close the GetContents response before returning error
451-
if respContents != nil {
452-
_ = respContents.Body.Close()
445+
}
446+
}
447+
} else {
448+
// No SHA provided - check if file exists to warn about blind update
449+
req, err := client.NewRequest("HEAD", contentURL, nil)
450+
if err == nil {
451+
resp, _ := client.Do(ctx, req, nil)
452+
if resp != nil {
453+
defer resp.Body.Close()
454+
if resp.StatusCode == http.StatusOK {
455+
previousSHA = strings.Trim(resp.Header.Get("ETag"), `"`)
453456
}
454-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
455-
"failed to get file SHA for update",
456-
respContents,
457-
err,
458-
), nil, nil
457+
// 404 = new file, no previous SHA needed
459458
}
460-
} else {
461-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
462-
"failed to create/update file",
463-
resp,
464-
err,
465-
), nil, nil
466459
}
467460
}
461+
462+
if previousSHA != "" {
463+
opts.SHA = github.Ptr(previousSHA)
464+
}
465+
466+
fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts)
467+
if err != nil {
468+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
469+
"failed to create/update file",
470+
resp,
471+
err,
472+
), nil, nil
473+
}
468474
defer func() { _ = resp.Body.Close() }()
469475

470476
if resp.StatusCode != 200 && resp.StatusCode != 201 {
@@ -480,6 +486,19 @@ If the SHA is not provided, the tool will attempt to acquire it by fetching the
480486
return nil, nil, fmt.Errorf("failed to marshal response: %w", err)
481487
}
482488

489+
// Warn if file was updated without SHA validation (blind update)
490+
if sha == "" && previousSHA != "" {
491+
return utils.NewToolResultText(fmt.Sprintf(
492+
"Warning: File updated without SHA validation. Previous file SHA was %s. "+
493+
`Verify no unintended changes were overwritten:
494+
1. Extract the SHA of the local version using git ls-tree HEAD %s.
495+
2. Compare with the previous SHA above.
496+
3. Revert changes if shas do not match.
497+
498+
%s`,
499+
previousSHA, path, string(r))), nil, nil
500+
}
501+
483502
return utils.NewToolResultText(string(r)), nil, nil
484503
})
485504

pkg/github/repositories_test.go

Lines changed: 142 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,81 +1122,179 @@ func Test_CreateOrUpdateFile(t *testing.T) {
11221122
expectedErrMsg: "failed to create/update file",
11231123
},
11241124
{
1125-
name: "file creation fails with missing sha, then succeeds after fetching sha",
1125+
name: "sha validation - current sha matches (304 Not Modified)",
11261126
mockedClient: mock.NewMockedHTTPClient(
11271127
mock.WithRequestMatchHandler(
1128-
mock.PutReposContentsByOwnerByRepoByPath,
1129-
func() http.HandlerFunc {
1130-
callCount := 0
1131-
return func(w http.ResponseWriter, _ *http.Request) {
1132-
callCount++
1133-
if callCount == 1 {
1134-
// First call fails with "sha wasn't supplied" error
1135-
w.WriteHeader(http.StatusUnprocessableEntity)
1136-
_, _ = w.Write([]byte(`{"message": "\"sha\" wasn't supplied"}`))
1137-
} else {
1138-
// Second call succeeds after SHA is retrieved
1139-
w.WriteHeader(http.StatusOK)
1140-
respBytes, _ := json.Marshal(mockFileResponse)
1141-
_, _ = w.Write(respBytes)
1142-
}
1128+
mock.EndpointPattern{
1129+
Pattern: "/repos/owner/repo/contents/docs/example.md",
1130+
Method: "HEAD",
1131+
},
1132+
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
1133+
// Verify If-None-Match header is set correctly
1134+
ifNoneMatch := req.Header.Get("If-None-Match")
1135+
if ifNoneMatch == `"abc123def456"` {
1136+
w.WriteHeader(http.StatusNotModified)
1137+
} else {
1138+
w.WriteHeader(http.StatusOK)
1139+
w.Header().Set("ETag", `"abc123def456"`)
11431140
}
1144-
}(),
1141+
}),
11451142
),
11461143
mock.WithRequestMatchHandler(
1147-
mock.GetReposContentsByOwnerByRepoByPath,
1144+
mock.PutReposContentsByOwnerByRepoByPath,
1145+
expectRequestBody(t, map[string]interface{}{
1146+
"message": "Update example file",
1147+
"content": "IyBVcGRhdGVkIEV4YW1wbGUKClRoaXMgZmlsZSBoYXMgYmVlbiB1cGRhdGVkLg==",
1148+
"branch": "main",
1149+
"sha": "abc123def456",
1150+
}).andThen(
1151+
mockResponse(t, http.StatusOK, mockFileResponse),
1152+
),
1153+
),
1154+
),
1155+
requestArgs: map[string]interface{}{
1156+
"owner": "owner",
1157+
"repo": "repo",
1158+
"path": "docs/example.md",
1159+
"content": "# Updated Example\n\nThis file has been updated.",
1160+
"message": "Update example file",
1161+
"branch": "main",
1162+
"sha": "abc123def456",
1163+
},
1164+
expectError: false,
1165+
expectedContent: mockFileResponse,
1166+
},
1167+
{
1168+
name: "sha validation - stale sha detected (200 OK with different ETag)",
1169+
mockedClient: mock.NewMockedHTTPClient(
1170+
mock.WithRequestMatchHandler(
1171+
mock.EndpointPattern{
1172+
Pattern: "/repos/owner/repo/contents/docs/example.md",
1173+
Method: "HEAD",
1174+
},
11481175
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1176+
// SHA doesn't match - return 200 with current ETag
1177+
w.Header().Set("ETag", `"newsha999888"`)
11491178
w.WriteHeader(http.StatusOK)
1150-
existingFile := &github.RepositoryContent{
1151-
Name: github.Ptr("example.md"),
1152-
Path: github.Ptr("docs/example.md"),
1153-
SHA: github.Ptr("abc123def456"),
1154-
Type: github.Ptr("file"),
1155-
}
1156-
contentBytes, _ := json.Marshal(existingFile)
1157-
_, _ = w.Write(contentBytes)
11581179
}),
11591180
),
11601181
),
11611182
requestArgs: map[string]interface{}{
11621183
"owner": "owner",
11631184
"repo": "repo",
11641185
"path": "docs/example.md",
1165-
"content": "# Example\n\nThis is an example file.",
1166-
"message": "Add example file",
1186+
"content": "# Updated Example\n\nThis file has been updated.",
1187+
"message": "Update example file",
1188+
"branch": "main",
1189+
"sha": "oldsha123456",
1190+
},
1191+
expectError: true,
1192+
expectedErrMsg: "SHA mismatch: provided SHA oldsha123456 is stale. Current file SHA is newsha999888",
1193+
},
1194+
{
1195+
name: "sha validation - file doesn't exist (404), proceed with create",
1196+
mockedClient: mock.NewMockedHTTPClient(
1197+
mock.WithRequestMatchHandler(
1198+
mock.EndpointPattern{
1199+
Pattern: "/repos/owner/repo/contents/docs/example.md",
1200+
Method: "HEAD",
1201+
},
1202+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1203+
w.WriteHeader(http.StatusNotFound)
1204+
}),
1205+
),
1206+
mock.WithRequestMatchHandler(
1207+
mock.PutReposContentsByOwnerByRepoByPath,
1208+
expectRequestBody(t, map[string]interface{}{
1209+
"message": "Create new file",
1210+
"content": "IyBOZXcgRmlsZQoKVGhpcyBpcyBhIG5ldyBmaWxlLg==",
1211+
"branch": "main",
1212+
"sha": "ignoredsha", // SHA is sent but GitHub API ignores it for new files
1213+
}).andThen(
1214+
mockResponse(t, http.StatusCreated, mockFileResponse),
1215+
),
1216+
),
1217+
),
1218+
requestArgs: map[string]interface{}{
1219+
"owner": "owner",
1220+
"repo": "repo",
1221+
"path": "docs/example.md",
1222+
"content": "# New File\n\nThis is a new file.",
1223+
"message": "Create new file",
11671224
"branch": "main",
1225+
"sha": "ignoredsha",
11681226
},
11691227
expectError: false,
11701228
expectedContent: mockFileResponse,
11711229
},
11721230
{
1173-
name: "file creation fails with missing sha and GetContents also fails",
1231+
name: "no sha provided - file exists, returns warning",
11741232
mockedClient: mock.NewMockedHTTPClient(
11751233
mock.WithRequestMatchHandler(
1176-
mock.PutReposContentsByOwnerByRepoByPath,
1234+
mock.EndpointPattern{
1235+
Pattern: "/repos/owner/repo/contents/docs/example.md",
1236+
Method: "HEAD",
1237+
},
11771238
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1178-
w.WriteHeader(http.StatusUnprocessableEntity)
1179-
_, _ = w.Write([]byte(`{"message": "\"sha\" wasn't supplied"}`))
1239+
w.Header().Set("ETag", `"existing123"`)
1240+
w.WriteHeader(http.StatusOK)
11801241
}),
11811242
),
11821243
mock.WithRequestMatchHandler(
1183-
mock.GetReposContentsByOwnerByRepoByPath,
1244+
mock.PutReposContentsByOwnerByRepoByPath,
1245+
expectRequestBody(t, map[string]interface{}{
1246+
"message": "Update without SHA",
1247+
"content": "IyBVcGRhdGVkCgpVcGRhdGVkIHdpdGhvdXQgU0hBLg==",
1248+
"branch": "main",
1249+
}).andThen(
1250+
mockResponse(t, http.StatusOK, mockFileResponse),
1251+
),
1252+
),
1253+
),
1254+
requestArgs: map[string]interface{}{
1255+
"owner": "owner",
1256+
"repo": "repo",
1257+
"path": "docs/example.md",
1258+
"content": "# Updated\n\nUpdated without SHA.",
1259+
"message": "Update without SHA",
1260+
"branch": "main",
1261+
},
1262+
expectError: false,
1263+
expectedErrMsg: "Warning: File updated without SHA validation. Previous file SHA was existing123",
1264+
},
1265+
{
1266+
name: "no sha provided - file doesn't exist, no warning",
1267+
mockedClient: mock.NewMockedHTTPClient(
1268+
mock.WithRequestMatchHandler(
1269+
mock.EndpointPattern{
1270+
Pattern: "/repos/owner/repo/contents/docs/example.md",
1271+
Method: "HEAD",
1272+
},
11841273
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
11851274
w.WriteHeader(http.StatusNotFound)
1186-
_, _ = w.Write([]byte(`{"message": "Not Found"}`))
11871275
}),
11881276
),
1277+
mock.WithRequestMatchHandler(
1278+
mock.PutReposContentsByOwnerByRepoByPath,
1279+
expectRequestBody(t, map[string]interface{}{
1280+
"message": "Create new file",
1281+
"content": "IyBOZXcgRmlsZQoKQ3JlYXRlZCB3aXRob3V0IFNIQQ==",
1282+
"branch": "main",
1283+
}).andThen(
1284+
mockResponse(t, http.StatusCreated, mockFileResponse),
1285+
),
1286+
),
11891287
),
11901288
requestArgs: map[string]interface{}{
11911289
"owner": "owner",
11921290
"repo": "repo",
11931291
"path": "docs/example.md",
1194-
"content": "# Example\n\nThis is an example file.",
1195-
"message": "Add example file",
1292+
"content": "# New File\n\nCreated without SHA",
1293+
"message": "Create new file",
11961294
"branch": "main",
11971295
},
1198-
expectError: true,
1199-
expectedErrMsg: "failed to get file SHA for update",
1296+
expectError: false,
1297+
expectedContent: mockFileResponse,
12001298
},
12011299
}
12021300

@@ -1227,6 +1325,12 @@ func Test_CreateOrUpdateFile(t *testing.T) {
12271325
// Parse the result and get the text content if no error
12281326
textContent := getTextResult(t, result)
12291327

1328+
// If expectedErrMsg is set (but expectError is false), this is a warning case
1329+
if tc.expectedErrMsg != "" {
1330+
assert.Contains(t, textContent.Text, tc.expectedErrMsg)
1331+
return
1332+
}
1333+
12301334
// Unmarshal and verify the result
12311335
var returnedContent github.RepositoryContentResponse
12321336
err = json.Unmarshal([]byte(textContent.Text), &returnedContent)

0 commit comments

Comments
 (0)