Skip to content

Commit 28bc3f4

Browse files
committed
Update context tools and tests to use ToolHandlerFor with typed arguments and return values.
1 parent 385dd8d commit 28bc3f4

3 files changed

Lines changed: 45 additions & 37 deletions

File tree

pkg/github/context_tools.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ type UserDetails struct {
3535
}
3636

3737
// GetMe creates a tool to get details of the authenticated user.
38-
func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandler) {
38+
func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) {
3939
tool := mcp.Tool{
4040
Name: "get_me",
4141
Description: t("TOOL_GET_ME_DESCRIPTION", "Get details of the authenticated GitHub user. Use this when a request is about the user's own profile for GitHub. Or when information is missing to build other tool calls."),
@@ -45,10 +45,10 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too
4545
},
4646
}
4747

48-
handler := mcp.ToolHandler(func(ctx context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
48+
handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) {
4949
client, err := getClient(ctx)
5050
if err != nil {
51-
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), err
51+
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err
5252
}
5353

5454
user, res, err := client.Users.Get(ctx, "")
@@ -57,7 +57,7 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too
5757
"failed to get user",
5858
res,
5959
err,
60-
), err
60+
), nil, err
6161
}
6262

6363
// Create minimal user representation instead of returning full user object
@@ -87,7 +87,7 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too
8787
},
8888
}
8989

90-
return MarshalledTextResult(minimalUser), nil
90+
return MarshalledTextResult(minimalUser), nil, nil
9191
})
9292

9393
return tool, handler
@@ -200,11 +200,11 @@ func GetTeamMembers(getGQLClient GetGQLClientFn, t translations.TranslationHelpe
200200
},
201201
InputSchema: &jsonschema.Schema{
202202
Properties: map[string]*jsonschema.Schema{
203-
"org": &jsonschema.Schema{
203+
"org": {
204204
Type: "string",
205205
Description: t("TOOL_GET_TEAM_MEMBERS_ORG_DESCRIPTION", "Organization login (owner) that contains the team."),
206206
},
207-
"team_slug": &jsonschema.Schema{
207+
"team_slug": {
208208
Type: "string",
209209
Description: t("TOOL_GET_TEAM_MEMBERS_TEAM_SLUG_DESCRIPTION", "Team slug"),
210210
},

pkg/github/context_tools_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func Test_GetMe(t *testing.T) {
2525

2626
// Verify some basic very important properties
2727
assert.Equal(t, "get_me", tool.Name)
28-
assert.True(t, *tool.Annotations.ReadOnlyHint, "get_me tool should be read-only")
28+
assert.True(t, tool.Annotations.ReadOnlyHint, "get_me tool should be read-only")
2929

3030
// Setup mock user response
3131
mockUser := &github.User{
@@ -111,7 +111,7 @@ func Test_GetMe(t *testing.T) {
111111
_, handler := GetMe(tc.stubbedGetClientFn, translations.NullTranslationHelper)
112112

113113
request := createMCPRequest(tc.requestArgs)
114-
result, err := handler(context.Background(), request)
114+
result, _, err := handler(context.Background(), &request, tc.requestArgs)
115115
require.NoError(t, err)
116116
textContent := getTextResult(t, result)
117117

@@ -150,7 +150,7 @@ func Test_GetTeams(t *testing.T) {
150150
require.NoError(t, toolsnaps.Test(tool.Name, tool))
151151

152152
assert.Equal(t, "get_teams", tool.Name)
153-
assert.True(t, *tool.Annotations.ReadOnlyHint, "get_teams tool should be read-only")
153+
assert.True(t, tool.Annotations.ReadOnlyHint, "get_teams tool should be read-only")
154154

155155
mockUser := &github.User{
156156
Login: github.Ptr("testuser"),
@@ -335,7 +335,7 @@ func Test_GetTeams(t *testing.T) {
335335
_, handler := GetTeams(tc.stubbedGetClientFn, tc.stubbedGetGQLClientFn, translations.NullTranslationHelper)
336336

337337
request := createMCPRequest(tc.requestArgs)
338-
result, err := handler(context.Background(), request)
338+
result, _, err := handler(context.Background(), &request, tc.requestArgs)
339339
require.NoError(t, err)
340340
textContent := getTextResult(t, result)
341341

@@ -377,7 +377,7 @@ func Test_GetTeamMembers(t *testing.T) {
377377
require.NoError(t, toolsnaps.Test(tool.Name, tool))
378378

379379
assert.Equal(t, "get_team_members", tool.Name)
380-
assert.True(t, *tool.Annotations.ReadOnlyHint, "get_team_members tool should be read-only")
380+
assert.True(t, tool.Annotations.ReadOnlyHint, "get_team_members tool should be read-only")
381381

382382
mockTeamMembersResponse := githubv4mock.DataResponse(map[string]any{
383383
"organization": map[string]any{
@@ -471,7 +471,7 @@ func Test_GetTeamMembers(t *testing.T) {
471471
_, handler := GetTeamMembers(tc.stubbedGetGQLClientFn, translations.NullTranslationHelper)
472472

473473
request := createMCPRequest(tc.requestArgs)
474-
result, err := handler(context.Background(), request)
474+
result, _, err := handler(context.Background(), &request, tc.requestArgs)
475475
require.NoError(t, err)
476476
textContent := getTextResult(t, result)
477477

pkg/github/helper_test.go

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
"net/http"
66
"testing"
77

8-
"github.com/mark3labs/mcp-go/mcp"
8+
"github.com/modelcontextprotocol/go-sdk/mcp"
99
"github.com/stretchr/testify/assert"
1010
"github.com/stretchr/testify/require"
1111
)
@@ -110,56 +110,66 @@ func mockResponse(t *testing.T, code int, body interface{}) http.HandlerFunc {
110110

111111
// createMCPRequest is a helper function to create a MCP request with the given arguments.
112112
func createMCPRequest(args any) mcp.CallToolRequest {
113+
// convert args to map[string]interface{} and serialize to JSON
114+
argsMap, ok := args.(map[string]interface{})
115+
if !ok {
116+
argsMap = make(map[string]interface{})
117+
}
118+
119+
argsJSON, err := json.Marshal(argsMap)
120+
require.NoError(nil, err)
121+
122+
jsonRawMessage := json.RawMessage(argsJSON)
123+
113124
return mcp.CallToolRequest{
114-
Params: struct {
115-
Name string `json:"name"`
116-
Arguments any `json:"arguments,omitempty"`
117-
Meta *mcp.Meta `json:"_meta,omitempty"`
118-
}{
119-
Arguments: args,
125+
Params: &mcp.CallToolParamsRaw{
126+
Arguments: jsonRawMessage,
120127
},
121128
}
122129
}
123130

124131
// getTextResult is a helper function that returns a text result from a tool call.
125-
func getTextResult(t *testing.T, result *mcp.CallToolResult) mcp.TextContent {
132+
func getTextResult(t *testing.T, result *mcp.CallToolResult) *mcp.TextContent {
126133
t.Helper()
127134
assert.NotNil(t, result)
128135
require.Len(t, result.Content, 1)
129136
require.IsType(t, mcp.TextContent{}, result.Content[0])
130-
textContent := result.Content[0].(mcp.TextContent)
131-
assert.Equal(t, "text", textContent.Type)
137+
textContent := result.Content[0].(*mcp.TextContent)
132138
return textContent
133139
}
134140

135-
func getErrorResult(t *testing.T, result *mcp.CallToolResult) mcp.TextContent {
141+
func getErrorResult(t *testing.T, result *mcp.CallToolResult) *mcp.TextContent {
136142
res := getTextResult(t, result)
137143
require.True(t, result.IsError, "expected tool call result to be an error")
138144
return res
139145
}
140146

141147
// getTextResourceResult is a helper function that returns a text result from a tool call.
142-
func getTextResourceResult(t *testing.T, result *mcp.CallToolResult) mcp.TextResourceContents {
148+
func getTextResourceResult(t *testing.T, result *mcp.CallToolResult) *mcp.ResourceContents {
143149
t.Helper()
144150
assert.NotNil(t, result)
145151
require.Len(t, result.Content, 2)
146152
content := result.Content[1]
147153
require.IsType(t, mcp.EmbeddedResource{}, content)
148-
resource := content.(mcp.EmbeddedResource)
149-
require.IsType(t, mcp.TextResourceContents{}, resource.Resource)
150-
return resource.Resource.(mcp.TextResourceContents)
154+
resource := content.(*mcp.EmbeddedResource)
155+
156+
require.IsType(t, mcp.ResourceContents{}, resource.Resource)
157+
require.NotEmpty(t, resource.Resource.Text)
158+
return resource.Resource
151159
}
152160

153161
// getBlobResourceResult is a helper function that returns a blob result from a tool call.
154-
func getBlobResourceResult(t *testing.T, result *mcp.CallToolResult) mcp.BlobResourceContents {
162+
func getBlobResourceResult(t *testing.T, result *mcp.CallToolResult) *mcp.ResourceContents {
155163
t.Helper()
156164
assert.NotNil(t, result)
157165
require.Len(t, result.Content, 2)
158166
content := result.Content[1]
159167
require.IsType(t, mcp.EmbeddedResource{}, content)
160-
resource := content.(mcp.EmbeddedResource)
161-
require.IsType(t, mcp.BlobResourceContents{}, resource.Resource)
162-
return resource.Resource.(mcp.BlobResourceContents)
168+
169+
resource := content.(*mcp.EmbeddedResource)
170+
require.IsType(t, mcp.ResourceContents{}, resource.Resource)
171+
require.NotEmpty(t, resource.Resource.Blob)
172+
return resource.Resource
163173
}
164174

165175
func TestOptionalParamOK(t *testing.T) {
@@ -226,11 +236,9 @@ func TestOptionalParamOK(t *testing.T) {
226236

227237
for _, tc := range tests {
228238
t.Run(tc.name, func(t *testing.T) {
229-
request := createMCPRequest(tc.args)
230-
231239
// Test with string type assertion
232240
if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" {
233-
val, ok, err := OptionalParamOK[string](request, tc.paramName)
241+
val, ok, err := OptionalParamOK[string, map[string]any](tc.args, tc.paramName)
234242
if tc.expectError {
235243
require.Error(t, err)
236244
assert.Contains(t, err.Error(), tc.errorMsg)
@@ -245,7 +253,7 @@ func TestOptionalParamOK(t *testing.T) {
245253

246254
// Test with bool type assertion
247255
if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" {
248-
val, ok, err := OptionalParamOK[bool](request, tc.paramName)
256+
val, ok, err := OptionalParamOK[bool, map[string]any](tc.args, tc.paramName)
249257
if tc.expectError {
250258
require.Error(t, err)
251259
assert.Contains(t, err.Error(), tc.errorMsg)
@@ -260,7 +268,7 @@ func TestOptionalParamOK(t *testing.T) {
260268

261269
// Test with float64 type assertion (for number case)
262270
if _, isFloat := tc.expectedVal.(float64); isFloat {
263-
val, ok, err := OptionalParamOK[float64](request, tc.paramName)
271+
val, ok, err := OptionalParamOK[float64, map[string]any](tc.args, tc.paramName)
264272
if tc.expectError {
265273
// This case shouldn't happen for float64 in the defined tests
266274
require.Fail(t, "Unexpected error case for float64")

0 commit comments

Comments
 (0)