Skip to content

Commit beb728a

Browse files
committed
Move scope storage into its own context key, separately from token info.
This allows us to provide scopes seperately in the remote server, where we have scopes before we do the auth.
1 parent aa30220 commit beb728a

File tree

5 files changed

+43
-21
lines changed

5 files changed

+43
-21
lines changed

pkg/context/token.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@ type tokenCtx string
1212
var tokenCtxKey tokenCtx = "tokenctx"
1313

1414
type TokenInfo struct {
15-
Token string
16-
TokenType utils.TokenType
17-
ScopesFetched bool
18-
Scopes []string
15+
Token string
16+
TokenType utils.TokenType
1917
}
2018

2119
// WithTokenInfo adds TokenInfo to the context
@@ -30,3 +28,20 @@ func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) {
3028
}
3129
return nil, false
3230
}
31+
32+
type TokenScopesKey tokenCtx
33+
34+
var tokenScopesKey TokenScopesKey = "tokenscopesctx"
35+
36+
// WithTokenScopes adds token scopes to the context
37+
func WithTokenScopes(ctx context.Context, scopes []string) context.Context {
38+
return context.WithValue(ctx, tokenScopesKey, scopes)
39+
}
40+
41+
// GetTokenScopes retrieves token scopes from the context
42+
func GetTokenScopes(ctx context.Context) ([]string, bool) {
43+
if scopes, ok := ctx.Value(tokenScopesKey).([]string); ok {
44+
return scopes, true
45+
}
46+
return nil, false
47+
}

pkg/http/handler.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,10 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche
271271
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
272272
// Fine-grained PATs and other token types don't support this, so we skip filtering.
273273
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
274-
if tokenInfo.ScopesFetched {
275-
return b.WithFilter(github.CreateToolScopeFilter(tokenInfo.Scopes))
274+
// Check if scopes are already in context (should be set by WithPATScopes). If not, fetch them.
275+
existingScopes, ok := ghcontext.GetTokenScopes(ctx)
276+
if ok {
277+
return b.WithFilter(github.CreateToolScopeFilter(existingScopes))
276278
}
277279

278280
scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token)

pkg/http/middleware/pat_scope.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,22 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu
2626
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
2727
// Fine-grained PATs and other token types don't support this, so we skip filtering.
2828
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
29+
existingScopes, ok := ghcontext.GetTokenScopes(ctx)
30+
if ok {
31+
logger.Debug("using existing scopes from context", "scopes", existingScopes)
32+
next.ServeHTTP(w, r)
33+
return
34+
}
35+
2936
scopesList, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
3037
if err != nil {
3138
logger.Warn("failed to fetch PAT scopes", "error", err)
3239
next.ServeHTTP(w, r)
3340
return
3441
}
3542

36-
tokenInfo.Scopes = scopesList
37-
tokenInfo.ScopesFetched = true
38-
3943
// Store fetched scopes in context for downstream use
40-
ctx := ghcontext.WithTokenInfo(ctx, tokenInfo)
44+
ctx = ghcontext.WithTokenScopes(ctx, scopesList)
4145

4246
next.ServeHTTP(w, r.WithContext(ctx))
4347
return

pkg/http/middleware/pat_scope_test.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,13 @@ func TestWithPATScopes(t *testing.T) {
111111

112112
for _, tt := range tests {
113113
t.Run(tt.name, func(t *testing.T) {
114-
var capturedTokenInfo *ghcontext.TokenInfo
114+
var capturedScopes []string
115+
var scopesFound bool
115116
var nextHandlerCalled bool
116117

117118
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
118119
nextHandlerCalled = true
119-
capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context())
120+
capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context())
120121
w.WriteHeader(http.StatusOK)
121122
})
122123

@@ -141,10 +142,9 @@ func TestWithPATScopes(t *testing.T) {
141142

142143
assert.Equal(t, tt.expectNextHandlerCalled, nextHandlerCalled, "next handler called mismatch")
143144

144-
if tt.expectNextHandlerCalled && tt.tokenInfo != nil {
145-
require.NotNil(t, capturedTokenInfo, "expected token info in context")
146-
assert.Equal(t, tt.expectScopesFetched, capturedTokenInfo.ScopesFetched)
147-
assert.Equal(t, tt.expectedScopes, capturedTokenInfo.Scopes)
145+
if tt.expectNextHandlerCalled {
146+
assert.Equal(t, tt.expectScopesFetched, scopesFound, "scopes found mismatch")
147+
assert.Equal(t, tt.expectedScopes, capturedScopes)
148148
}
149149
})
150150
}
@@ -154,9 +154,12 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) {
154154
logger := slog.Default()
155155

156156
var capturedTokenInfo *ghcontext.TokenInfo
157+
var capturedScopes []string
158+
var scopesFound bool
157159

158160
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
159161
capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context())
162+
capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context())
160163
w.WriteHeader(http.StatusOK)
161164
})
162165

@@ -182,6 +185,6 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) {
182185
require.NotNil(t, capturedTokenInfo)
183186
assert.Equal(t, originalTokenInfo.Token, capturedTokenInfo.Token)
184187
assert.Equal(t, originalTokenInfo.TokenType, capturedTokenInfo.TokenType)
185-
assert.True(t, capturedTokenInfo.ScopesFetched)
186-
assert.Equal(t, []string{"repo", "user"}, capturedTokenInfo.Scopes)
188+
assert.True(t, scopesFound)
189+
assert.Equal(t, []string{"repo", "user"}, capturedScopes)
187190
}

pkg/http/middleware/scope_challenge.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,7 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter
102102
}
103103

104104
// Store active scopes in context for downstream use
105-
tokenInfo.Scopes = activeScopes
106-
tokenInfo.ScopesFetched = true
107-
ctx = ghcontext.WithTokenInfo(ctx, tokenInfo)
105+
ctx = ghcontext.WithTokenScopes(ctx, activeScopes)
108106
r = r.WithContext(ctx)
109107

110108
// Check if user has the required scopes

0 commit comments

Comments
 (0)