Skip to content

Commit 7f6e0e8

Browse files
committed
Add scope challenge & pat filtering based on token scopes
1 parent 9a338d7 commit 7f6e0e8

File tree

11 files changed

+175
-30
lines changed

11 files changed

+175
-30
lines changed

cmd/github-mcp-server/main.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ var (
109109
ContentWindowSize: viper.GetInt("content-window-size"),
110110
LockdownMode: viper.GetBool("lockdown-mode"),
111111
RepoAccessCacheTTL: &ttl,
112+
ScopeChallenge: viper.GetBool("scope-challenge"),
112113
}
113114

114115
return ghhttp.RunHTTPServer(httpConfig)
@@ -141,6 +142,7 @@ func init() {
141142
httpCmd.PersistentFlags().Int("port", 8082, "HTTP server port")
142143
httpCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)")
143144
httpCmd.PersistentFlags().String("base-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)")
145+
httpCmd.PersistentFlags().Bool("scope-challenge", false, "Enable OAuth scope challenge responses and tool filtering based on token scopes")
144146

145147
// Bind flag to viper
146148
_ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets"))
@@ -161,6 +163,7 @@ func init() {
161163
_ = viper.BindPFlag("port", httpCmd.PersistentFlags().Lookup("port"))
162164
_ = viper.BindPFlag("base-url", httpCmd.PersistentFlags().Lookup("base-url"))
163165
_ = viper.BindPFlag("base-path", httpCmd.PersistentFlags().Lookup("base-path"))
166+
_ = viper.BindPFlag("scope-challenge", httpCmd.PersistentFlags().Lookup("scope-challenge"))
164167

165168
// Add subcommands
166169
rootCmd.AddCommand(stdioCmd)

internal/ghmcp/server.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -366,13 +366,8 @@ func fetchTokenScopesForHost(ctx context.Context, token, host string) ([]string,
366366
return nil, fmt.Errorf("failed to parse API host: %w", err)
367367
}
368368

369-
baseRestURL, err := apiHost.BaseRESTURL(ctx)
370-
if err != nil {
371-
return nil, fmt.Errorf("failed to get base REST URL: %w", err)
372-
}
373-
374369
fetcher := scopes.NewFetcher(scopes.FetcherOptions{
375-
APIHost: baseRestURL.String(),
370+
APIHost: apiHost,
376371
})
377372

378373
return fetcher.FetchTokenScopes(ctx, token)

pkg/context/token.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,26 @@ var tokenCtxKey tokenCtx = "tokenctx"
1414
type TokenInfo struct {
1515
Token string
1616
TokenType utils.TokenType
17+
Scopes []string
1718
}
1819

1920
// WithTokenInfo adds TokenInfo to the context
20-
func WithTokenInfo(ctx context.Context, token string, tokenType utils.TokenType) context.Context {
21-
return context.WithValue(ctx, tokenCtxKey, TokenInfo{Token: token, TokenType: tokenType})
21+
func WithTokenInfo(ctx context.Context, tokenInfo *TokenInfo) context.Context {
22+
return context.WithValue(ctx, tokenCtxKey, tokenInfo)
23+
}
24+
25+
func SetTokenScopes(ctx context.Context, scopes []string) context.Context {
26+
if tokenInfo, ok := GetTokenInfo(ctx); ok {
27+
tokenInfo.Scopes = scopes
28+
return WithTokenInfo(ctx, tokenInfo)
29+
}
30+
return ctx
2231
}
2332

2433
// GetTokenInfo retrieves the authentication token from the context
25-
func GetTokenInfo(ctx context.Context) (TokenInfo, bool) {
26-
if tokenInfo, ok := ctx.Value(tokenCtxKey).(TokenInfo); ok {
34+
func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) {
35+
if tokenInfo, ok := ctx.Value(tokenCtxKey).(*TokenInfo); ok {
2736
return tokenInfo, true
2837
}
29-
return TokenInfo{}, false
38+
return nil, false
3039
}

pkg/http/handler.go

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import (
1111
"github.com/github/github-mcp-server/pkg/http/middleware"
1212
"github.com/github/github-mcp-server/pkg/http/oauth"
1313
"github.com/github/github-mcp-server/pkg/inventory"
14+
"github.com/github/github-mcp-server/pkg/scopes"
1415
"github.com/github/github-mcp-server/pkg/translations"
16+
"github.com/github/github-mcp-server/pkg/utils"
1517
"github.com/go-chi/chi/v5"
1618
"github.com/modelcontextprotocol/go-sdk/mcp"
1719
)
@@ -24,20 +26,29 @@ type Handler struct {
2426
config *ServerConfig
2527
deps github.ToolDependencies
2628
logger *slog.Logger
29+
apiHosts utils.APIHostResolver
2730
t translations.TranslationHelperFunc
2831
githubMcpServerFactory GitHubMCPServerFactoryFunc
2932
inventoryFactoryFunc InventoryFactoryFunc
3033
oauthCfg *oauth.Config
34+
scopeFetcher scopes.FetcherInterface
3135
}
3236

3337
type HandlerOptions struct {
3438
GitHubMcpServerFactory GitHubMCPServerFactoryFunc
3539
InventoryFactory InventoryFactoryFunc
3640
OAuthConfig *oauth.Config
41+
ScopeFetcher scopes.FetcherInterface
3742
}
3843

3944
type HandlerOption func(*HandlerOptions)
4045

46+
func WithScopeFetcher(f scopes.FetcherInterface) HandlerOption {
47+
return func(o *HandlerOptions) {
48+
o.ScopeFetcher = f
49+
}
50+
}
51+
4152
func WithGitHubMCPServerFactory(f GitHubMCPServerFactoryFunc) HandlerOption {
4253
return func(o *HandlerOptions) {
4354
o.GitHubMcpServerFactory = f
@@ -62,6 +73,7 @@ func NewHTTPMcpHandler(
6273
deps github.ToolDependencies,
6374
t translations.TranslationHelperFunc,
6475
logger *slog.Logger,
76+
apiHost utils.APIHostResolver,
6577
options ...HandlerOption) *Handler {
6678
opts := &HandlerOptions{}
6779
for _, o := range options {
@@ -75,29 +87,39 @@ func NewHTTPMcpHandler(
7587

7688
inventoryFactory := opts.InventoryFactory
7789
if inventoryFactory == nil {
78-
inventoryFactory = DefaultInventoryFactory(cfg, t, nil)
90+
inventoryFactory = DefaultInventoryFactory(cfg, t, nil, opts.ScopeFetcher)
91+
}
92+
93+
scopeFetcher := opts.ScopeFetcher
94+
if scopeFetcher == nil {
95+
scopeFetcher = scopes.NewFetcher(scopes.FetcherOptions{
96+
APIHost: apiHost,
97+
})
7998
}
8099

81100
return &Handler{
82101
ctx: ctx,
83102
config: cfg,
84103
deps: deps,
85104
logger: logger,
105+
apiHosts: apiHost,
86106
t: t,
87107
githubMcpServerFactory: githubMcpServerFactory,
88108
inventoryFactoryFunc: inventoryFactory,
89109
oauthCfg: opts.OAuthConfig,
110+
scopeFetcher: scopeFetcher,
90111
}
91112
}
92113

93114
func (h *Handler) RegisterMiddleware(r chi.Router) {
94115
r.Use(
95116
middleware.ExtractUserToken(h.oauthCfg),
96117
middleware.WithRequestConfig,
97-
middleware.WithScopeChallenge(h.oauthCfg),
98118
)
99119

100-
r.Use(middleware.WithScopeChallenge(h.oauthCfg))
120+
if h.config.ScopeChallenge {
121+
r.Use(middleware.WithScopeChallenge(h.oauthCfg, h.scopeFetcher))
122+
}
101123
}
102124

103125
// RegisterRoutes registers the routes for the MCP server
@@ -159,7 +181,7 @@ func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies
159181
return github.NewMCPServer(r.Context(), cfg, deps, inventory)
160182
}
161183

162-
func DefaultInventoryFactory(_ *ServerConfig, t translations.TranslationHelperFunc, staticChecker inventory.FeatureFlagChecker) InventoryFactoryFunc {
184+
func DefaultInventoryFactory(cfg *ServerConfig, t translations.TranslationHelperFunc, staticChecker inventory.FeatureFlagChecker, scopeFetcher scopes.FetcherInterface) InventoryFactoryFunc {
163185
return func(r *http.Request) (*inventory.Inventory, error) {
164186
b := github.NewInventory(t).WithDeprecatedAliases(github.DeprecatedToolAliases)
165187

@@ -170,6 +192,11 @@ func DefaultInventoryFactory(_ *ServerConfig, t translations.TranslationHelperFu
170192
}
171193

172194
b = InventoryFiltersForRequest(r, b)
195+
196+
if cfg.ScopeChallenge {
197+
b = b.WithFilter(ScopeChallengeFilter(r, scopeFetcher))
198+
}
199+
173200
b.WithServerInstructions()
174201

175202
return b.Build()
@@ -198,3 +225,26 @@ func InventoryFiltersForRequest(r *http.Request, builder *inventory.Builder) *in
198225

199226
return builder
200227
}
228+
229+
func ScopeChallengeFilter(r *http.Request, fetcher scopes.FetcherInterface) inventory.ToolFilter {
230+
ctx := r.Context()
231+
232+
tokenInfo, ok := ghcontext.GetTokenInfo(ctx)
233+
if !ok || tokenInfo == nil {
234+
return nil
235+
}
236+
237+
// Fetch token scopes for scope-based tool filtering (PAT tokens only)
238+
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
239+
// Fine-grained PATs and other token types don't support this, so we skip filtering.
240+
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
241+
scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token)
242+
if err != nil {
243+
return nil
244+
}
245+
246+
return github.CreateToolScopeFilter(scopesList)
247+
}
248+
249+
return nil
250+
}

pkg/http/handler_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ import (
1212
"github.com/github/github-mcp-server/pkg/github"
1313
"github.com/github/github-mcp-server/pkg/http/headers"
1414
"github.com/github/github-mcp-server/pkg/inventory"
15+
"github.com/github/github-mcp-server/pkg/scopes"
1516
"github.com/github/github-mcp-server/pkg/translations"
17+
"github.com/github/github-mcp-server/pkg/utils"
1618
"github.com/go-chi/chi/v5"
1719
"github.com/modelcontextprotocol/go-sdk/mcp"
1820
"github.com/stretchr/testify/assert"
@@ -32,6 +34,20 @@ func mockTool(name, toolsetID string, readOnly bool) inventory.ServerTool {
3234
}
3335
}
3436

37+
type allScopesFetcher struct{}
38+
39+
func (f allScopesFetcher) FetchTokenScopes(_ context.Context, _ string) ([]string, error) {
40+
return []string{
41+
string(scopes.Repo),
42+
string(scopes.WriteOrg),
43+
string(scopes.User),
44+
string(scopes.Gist),
45+
string(scopes.Notifications),
46+
}, nil
47+
}
48+
49+
var _ scopes.FetcherInterface = allScopesFetcher{}
50+
3551
func TestInventoryFiltersForRequest(t *testing.T) {
3652
tools := []inventory.ServerTool{
3753
mockTool("get_file_contents", "repos", true),
@@ -230,6 +246,8 @@ func TestHTTPHandlerRoutes(t *testing.T) {
230246
t.Run(tt.name, func(t *testing.T) {
231247
var capturedInventory *inventory.Inventory
232248

249+
apiHost := utils.NewDefaultAPIHostResolver()
250+
233251
// Create inventory factory that captures the built inventory
234252
inventoryFactory := func(r *http.Request) (*inventory.Inventory, error) {
235253
builder := inventory.NewBuilder().
@@ -249,23 +267,32 @@ func TestHTTPHandlerRoutes(t *testing.T) {
249267
return mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil), nil
250268
}
251269

270+
allScopesFetcher := allScopesFetcher{}
271+
252272
// Create handler with our factories
253273
handler := NewHTTPMcpHandler(
254274
context.Background(),
255275
&ServerConfig{Version: "test"},
256276
nil, // deps not needed for this test
257277
translations.NullTranslationHelper,
258278
slog.Default(),
279+
apiHost,
259280
WithInventoryFactory(inventoryFactory),
260281
WithGitHubMCPServerFactory(mcpServerFactory),
282+
WithScopeFetcher(allScopesFetcher),
261283
)
262284

263285
// Create router and register routes
264286
r := chi.NewRouter()
287+
handler.RegisterMiddleware(r)
265288
handler.RegisterRoutes(r)
266289

267290
// Create request
268291
req := httptest.NewRequest(http.MethodPost, tt.path, nil)
292+
293+
// Ensure we're setting Authorization header for token context
294+
req.Header.Set(headers.AuthorizationHeader, "Bearer ghp_testtoken")
295+
269296
for k, v := range tt.headers {
270297
req.Header.Set(k, v)
271298
}

pkg/http/middleware/scope_challenge.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func FetchScopesFromGitHubAPI(ctx context.Context, token string, apiHost utils.A
5757

5858
// WithScopeChallenge creates a new middleware that determines if an OAuth request contains sufficient scopes to
5959
// complete the request and returns a scope challenge if not.
60-
func WithScopeChallenge(oauthCfg *oauth.Config) func(http.Handler) http.Handler {
60+
func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInterface) func(http.Handler) http.Handler {
6161
return func(next http.Handler) http.Handler {
6262
fn := func(w http.ResponseWriter, r *http.Request) {
6363
ctx := r.Context()
@@ -136,7 +136,11 @@ func WithScopeChallenge(oauthCfg *oauth.Config) func(http.Handler) http.Handler
136136
}
137137

138138
// Get OAuth scopes from GitHub API
139-
activeScopes, err := FetchScopesFromGitHubAPI(ctx, tokenInfo.Token, oauthCfg.ApiHosts)
139+
activeScopes, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
140+
if err != nil {
141+
next.ServeHTTP(w, r)
142+
return
143+
}
140144

141145
// Check if user has the required scopes
142146
if toolScopeInfo.HasAcceptedScope(activeScopes...) {

pkg/http/middleware/token.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl
2626
}
2727

2828
ctx := r.Context()
29-
ctx = ghcontext.WithTokenInfo(ctx, token, tokenType)
29+
ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{
30+
Token: token,
31+
TokenType: tokenType,
32+
})
3033
r = r.WithContext(ctx)
3134

3235
next.ServeHTTP(w, r)

pkg/http/server.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ type ServerConfig struct {
5757

5858
// RepoAccessCacheTTL overrides the default TTL for repository access cache entries.
5959
RepoAccessCacheTTL *time.Duration
60+
61+
// ScopeChallenge indicates if we should return OAuth scope challenges, and if we should perform
62+
// tool filtering based on token scopes.
63+
ScopeChallenge bool
6064
}
6165

6266
func RunHTTPServer(cfg ServerConfig) error {
@@ -117,8 +121,16 @@ func RunHTTPServer(cfg ServerConfig) error {
117121
ResourcePath: cfg.ResourcePath,
118122
}
119123

124+
severOptions := []HandlerOption{}
125+
if cfg.ScopeChallenge {
126+
scopeFetcher := scopes.NewFetcher(scopes.FetcherOptions{
127+
APIHost: apiHost,
128+
})
129+
severOptions = append(severOptions, WithScopeFetcher(scopeFetcher))
130+
}
131+
120132
r := chi.NewRouter()
121-
handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, WithOAuthConfig(oauthCfg))
133+
handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(severOptions, WithOAuthConfig(oauthCfg))...)
122134
oauthHandler, err := oauth.NewAuthHandler(oauthCfg)
123135
if err != nil {
124136
return fmt.Errorf("failed to create OAuth handler: %w", err)

0 commit comments

Comments
 (0)