Skip to content

Commit 385dd8d

Browse files
committed
Move to go-sdk IOTransport
1 parent 5a92f9c commit 385dd8d

2 files changed

Lines changed: 79 additions & 57 deletions

File tree

internal/ghmcp/server.go

Lines changed: 77 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"fmt"
66
"io"
7-
"log"
87
"log/slog"
98
"net/http"
109
"net/url"
@@ -20,8 +19,7 @@ import (
2019
"github.com/github/github-mcp-server/pkg/raw"
2120
"github.com/github/github-mcp-server/pkg/translations"
2221
gogithub "github.com/google/go-github/v77/github"
23-
"github.com/mark3labs/mcp-go/mcp"
24-
"github.com/mark3labs/mcp-go/server"
22+
"github.com/modelcontextprotocol/go-sdk/mcp"
2523
"github.com/shurcooL/githubv4"
2624
)
2725

@@ -54,11 +52,14 @@ type MCPServerConfig struct {
5452

5553
// LockdownMode indicates if we should enable lockdown mode
5654
LockdownMode bool
55+
56+
// Logger is used for logging within the server
57+
Logger *slog.Logger
5758
}
5859

5960
const stdioServerLogPrefix = "stdioserver"
6061

61-
func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
62+
func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) {
6263
apiHost, err := parseAPIHost(cfg.Host)
6364
if err != nil {
6465
return nil, fmt.Errorf("failed to parse API host: %w", err)
@@ -81,34 +82,6 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
8182
} // We're going to wrap the Transport later in beforeInit
8283
gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient)
8384

84-
// When a client send an initialize request, update the user agent to include the client info.
85-
beforeInit := func(_ context.Context, _ any, message *mcp.InitializeRequest) {
86-
userAgent := fmt.Sprintf(
87-
"github-mcp-server/%s (%s/%s)",
88-
cfg.Version,
89-
message.Params.ClientInfo.Name,
90-
message.Params.ClientInfo.Version,
91-
)
92-
93-
restClient.UserAgent = userAgent
94-
95-
gqlHTTPClient.Transport = &userAgentTransport{
96-
transport: gqlHTTPClient.Transport,
97-
agent: userAgent,
98-
}
99-
}
100-
101-
hooks := &server.Hooks{
102-
OnBeforeInitialize: []server.OnBeforeInitializeFunc{beforeInit},
103-
OnBeforeAny: []server.BeforeAnyHookFunc{
104-
func(ctx context.Context, _ any, _ mcp.MCPMethod, _ any) {
105-
// Ensure the context is cleared of any previous errors
106-
// as context isn't propagated through middleware
107-
errors.ContextWithGitHubErrors(ctx)
108-
},
109-
},
110-
}
111-
11285
enabledToolsets := cfg.EnabledToolsets
11386

11487
// If dynamic toolsets are enabled, remove "all" from the enabled toolsets
@@ -135,10 +108,14 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
135108
// Generate instructions based on enabled toolsets
136109
instructions := github.GenerateInstructions(enabledToolsets)
137110

138-
ghServer := github.NewServer(cfg.Version,
139-
server.WithInstructions(instructions),
140-
server.WithHooks(hooks),
141-
)
111+
ghServer := github.NewServer(cfg.Version, &mcp.ServerOptions{
112+
Instructions: instructions,
113+
Logger: cfg.Logger,
114+
})
115+
116+
// Add middlewares
117+
ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext)
118+
ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, restClient, gqlHTTPClient))
142119

143120
getClient := func(_ context.Context) (*gogithub.Client, error) {
144121
return restClient, nil // closing over client
@@ -229,23 +206,6 @@ func RunStdioServer(cfg StdioServerConfig) error {
229206

230207
t, dumpTranslations := translations.TranslationHelper()
231208

232-
ghServer, err := NewMCPServer(MCPServerConfig{
233-
Version: cfg.Version,
234-
Host: cfg.Host,
235-
Token: cfg.Token,
236-
EnabledToolsets: cfg.EnabledToolsets,
237-
DynamicToolsets: cfg.DynamicToolsets,
238-
ReadOnly: cfg.ReadOnly,
239-
Translator: t,
240-
ContentWindowSize: cfg.ContentWindowSize,
241-
LockdownMode: cfg.LockdownMode,
242-
})
243-
if err != nil {
244-
return fmt.Errorf("failed to create MCP server: %w", err)
245-
}
246-
247-
stdioServer := server.NewStdioServer(ghServer)
248-
249209
var slogHandler slog.Handler
250210
var logOutput io.Writer
251211
if cfg.LogFilePath != "" {
@@ -261,8 +221,22 @@ func RunStdioServer(cfg StdioServerConfig) error {
261221
}
262222
logger := slog.New(slogHandler)
263223
logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode)
264-
stdLogger := log.New(logOutput, stdioServerLogPrefix, 0)
265-
stdioServer.SetErrorLogger(stdLogger)
224+
225+
ghServer, err := NewMCPServer(MCPServerConfig{
226+
Version: cfg.Version,
227+
Host: cfg.Host,
228+
Token: cfg.Token,
229+
EnabledToolsets: cfg.EnabledToolsets,
230+
DynamicToolsets: cfg.DynamicToolsets,
231+
ReadOnly: cfg.ReadOnly,
232+
Translator: t,
233+
ContentWindowSize: cfg.ContentWindowSize,
234+
LockdownMode: cfg.LockdownMode,
235+
Logger: logger,
236+
})
237+
if err != nil {
238+
return fmt.Errorf("failed to create MCP server: %w", err)
239+
}
266240

267241
if cfg.ExportTranslations {
268242
// Once server is initialized, all translations are loaded
@@ -272,15 +246,20 @@ func RunStdioServer(cfg StdioServerConfig) error {
272246
// Start listening for messages
273247
errC := make(chan error, 1)
274248
go func() {
275-
in, out := io.Reader(os.Stdin), io.Writer(os.Stdout)
249+
var in io.ReadCloser
250+
var out io.WriteCloser
251+
252+
in = os.Stdin
253+
out = os.Stdout
276254

277255
if cfg.EnableCommandLogging {
278256
loggedIO := mcplog.NewIOLogger(in, out, logger)
279257
in, out = loggedIO, loggedIO
280258
}
259+
281260
// enable GitHub errors in the context
282261
ctx := errors.ContextWithGitHubErrors(ctx)
283-
errC <- stdioServer.Listen(ctx, in, out)
262+
errC <- ghServer.Run(ctx, &mcp.IOTransport{Reader: in, Writer: out})
284263
}()
285264

286265
// Output github-mcp-server string
@@ -497,3 +476,44 @@ func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, erro
497476
req.Header.Set("Authorization", "Bearer "+t.token)
498477
return t.transport.RoundTrip(req)
499478
}
479+
480+
func addGitHubAPIErrorToContext(next mcp.MethodHandler) mcp.MethodHandler {
481+
return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) {
482+
// Ensure the context is cleared of any previous errors
483+
// as context isn't propagated through middleware
484+
ctx = errors.ContextWithGitHubErrors(ctx)
485+
return next(ctx, method, req)
486+
}
487+
}
488+
489+
func addUserAgentsMiddleware(cfg MCPServerConfig, restClient *gogithub.Client, gqlHTTPClient *http.Client) func(next mcp.MethodHandler) mcp.MethodHandler {
490+
return func(next mcp.MethodHandler) mcp.MethodHandler {
491+
return func(ctx context.Context, method string, request mcp.Request) (result mcp.Result, err error) {
492+
if method != "initialize" {
493+
return next(ctx, method, request)
494+
}
495+
496+
initializeRequest, ok := request.(*mcp.InitializeRequest)
497+
if !ok {
498+
return next(ctx, method, request)
499+
}
500+
501+
message := initializeRequest
502+
userAgent := fmt.Sprintf(
503+
"github-mcp-server/%s (%s/%s)",
504+
cfg.Version,
505+
message.Params.ClientInfo.Name,
506+
message.Params.ClientInfo.Version,
507+
)
508+
509+
restClient.UserAgent = userAgent
510+
511+
gqlHTTPClient.Transport = &userAgentTransport{
512+
transport: gqlHTTPClient.Transport,
513+
agent: userAgent,
514+
}
515+
516+
return next(ctx, method, request)
517+
}
518+
}
519+
}

pkg/log/io.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
// IOLogger is a wrapper around io.Reader and io.Writer that can be used
1010
// to log the data being read and written from the underlying streams
1111
type IOLogger struct {
12+
io.ReadWriteCloser
13+
1214
reader io.Reader
1315
writer io.Writer
1416
logger *slog.Logger

0 commit comments

Comments
 (0)