Skip to content

Commit 4989da6

Browse files
committed
Resolve the URL at request time.
If we do this at start time, we won't know which GHEC tenant we're on.
1 parent b81964b commit 4989da6

File tree

2 files changed

+74
-47
lines changed

2 files changed

+74
-47
lines changed

pkg/http/oauth/oauth.go

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ type Config struct {
5454

5555
// AuthHandler handles OAuth-related HTTP endpoints.
5656
type AuthHandler struct {
57-
cfg *Config
57+
cfg *Config
58+
apiHost utils.APIHostResolver
5859
}
5960

6061
// NewAuthHandler creates a new OAuth auth handler.
@@ -63,18 +64,9 @@ func NewAuthHandler(ctx context.Context, cfg *Config, apiHost utils.APIHostResol
6364
cfg = &Config{}
6465
}
6566

66-
// Default authorization server to GitHub
67-
if cfg.AuthorizationServer == "" {
68-
url, err := apiHost.AuthorizationServerURL(ctx)
69-
if err != nil {
70-
return nil, fmt.Errorf("failed to get authorization server URL from API host: %w", err)
71-
}
72-
73-
cfg.AuthorizationServer = url.String()
74-
}
75-
7667
return &AuthHandler{
77-
cfg: cfg,
68+
cfg: cfg,
69+
apiHost: apiHost,
7870
}, nil
7971
}
8072

@@ -99,15 +91,28 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) {
9991

10092
func (h *AuthHandler) metadataHandler() http.Handler {
10193
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
94+
ctx := r.Context()
10295
resourcePath := resolveResourcePath(
10396
strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix),
10497
h.cfg.ResourcePath,
10598
)
10699
resourceURL := h.buildResourceURL(r, resourcePath)
107100

101+
var authorizationServerURL string
102+
if h.cfg.AuthorizationServer != "" {
103+
authorizationServerURL = h.cfg.AuthorizationServer
104+
} else {
105+
authURL, err := h.apiHost.AuthorizationServerURL(ctx)
106+
if err != nil {
107+
http.Error(w, fmt.Sprintf("failed to resolve authorization server URL: %v", err), http.StatusInternalServerError)
108+
return
109+
}
110+
authorizationServerURL = authURL.String()
111+
}
112+
108113
metadata := &oauthex.ProtectedResourceMetadata{
109114
Resource: resourceURL,
110-
AuthorizationServers: []string{h.cfg.AuthorizationServer},
115+
AuthorizationServers: []string{authorizationServerURL},
111116
ResourceName: "GitHub MCP Server",
112117
ScopesSupported: SupportedScopes,
113118
BearerMethodsSupported: []string{"header"},

pkg/http/oauth/oauth_test.go

Lines changed: 56 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,6 @@ func TestNewAuthHandler(t *testing.T) {
3030
expectedAuthServer string
3131
expectedResourcePath string
3232
}{
33-
{
34-
name: "nil config uses defaults",
35-
cfg: nil,
36-
expectedAuthServer: defaultAuthorizationServer,
37-
expectedResourcePath: "",
38-
},
39-
{
40-
name: "empty config uses defaults",
41-
cfg: &Config{},
42-
expectedAuthServer: defaultAuthorizationServer,
43-
expectedResourcePath: "",
44-
},
4533
{
4634
name: "custom authorization server",
4735
cfg: &Config{
@@ -56,7 +44,7 @@ func TestNewAuthHandler(t *testing.T) {
5644
BaseURL: "https://example.com",
5745
ResourcePath: "/mcp",
5846
},
59-
expectedAuthServer: defaultAuthorizationServer,
47+
expectedAuthServer: "",
6048
expectedResourcePath: "/mcp",
6149
},
6250
}
@@ -636,42 +624,44 @@ func TestAPIHostResolver_AuthorizationServerURL(t *testing.T) {
636624
t.Parallel()
637625

638626
tests := []struct {
639-
name string
640-
host string
641-
expectedURL string
642-
expectError bool
643-
errorContains string
627+
name string
628+
host string
629+
expectedURL string
630+
expectedError bool
631+
expectedStatusCode int
632+
errorContains string
644633
}{
645634
{
646-
name: "valid host returns authorization server URL",
647-
host: "http://github.com",
648-
expectedURL: "https://github.com/login/oauth",
649-
expectError: false,
635+
name: "valid host returns authorization server URL",
636+
host: "http://github.com",
637+
expectedURL: "https://github.com/login/oauth",
638+
expectedStatusCode: http.StatusOK,
650639
},
651640
{
652641
name: "invalid host returns error",
653642
host: "://invalid-url",
654643
expectedURL: "",
655-
expectError: true,
644+
expectedError: true,
656645
errorContains: "could not parse host as URL",
657646
},
658647
{
659648
name: "host without scheme returns error",
660649
host: "github.com",
661650
expectedURL: "",
662-
expectError: true,
651+
expectedError: true,
663652
errorContains: "host must have a scheme",
664653
},
665654
{
666-
name: "GHEC host returns correct authorization server URL",
667-
host: "https://test.ghe.com",
668-
expectedURL: "https://test.ghe.com/login/oauth",
655+
name: "GHEC host returns correct authorization server URL",
656+
host: "https://test.ghe.com",
657+
expectedURL: "https://test.ghe.com/login/oauth",
658+
expectedStatusCode: http.StatusOK,
669659
},
670660
{
671-
name: "GHES host returns correct authorization server URL",
672-
host: "https://ghe.example.com",
673-
expectedURL: "https://ghe.example.com/login/oauth",
674-
expectError: false,
661+
name: "GHES host returns correct authorization server URL",
662+
host: "https://ghe.example.com",
663+
expectedURL: "https://ghe.example.com/login/oauth",
664+
expectedStatusCode: http.StatusOK,
675665
},
676666
}
677667

@@ -680,18 +670,50 @@ func TestAPIHostResolver_AuthorizationServerURL(t *testing.T) {
680670
t.Parallel()
681671

682672
apiHost, err := utils.NewAPIHost(tc.host)
683-
if tc.expectError {
673+
if tc.expectedError {
684674
require.Error(t, err)
685675
if tc.errorContains != "" {
686676
assert.Contains(t, err.Error(), tc.errorContains)
687677
}
688678
return
679+
} else {
680+
require.NoError(t, err)
689681
}
682+
683+
handler, err := NewAuthHandler(t.Context(), &Config{
684+
BaseURL: "https://api.example.com",
685+
}, apiHost)
690686
require.NoError(t, err)
691687

692-
url, err := apiHost.AuthorizationServerURL(t.Context())
688+
router := chi.NewRouter()
689+
handler.RegisterRoutes(router)
690+
691+
req := httptest.NewRequest(http.MethodGet, OAuthProtectedResourcePrefix, nil)
692+
req.Host = "api.example.com"
693+
694+
rec := httptest.NewRecorder()
695+
router.ServeHTTP(rec, req)
696+
697+
require.Equal(t, http.StatusOK, rec.Code)
698+
699+
var response map[string]any
700+
err = json.Unmarshal(rec.Body.Bytes(), &response)
693701
require.NoError(t, err)
694-
assert.Equal(t, tc.expectedURL, url.String())
702+
703+
assert.Contains(t, response, "authorization_servers")
704+
if tc.expectedStatusCode != http.StatusOK {
705+
require.Equal(t, tc.expectedStatusCode, rec.Code)
706+
if tc.errorContains != "" {
707+
assert.Contains(t, rec.Body.String(), tc.errorContains)
708+
}
709+
return
710+
}
711+
require.NoError(t, err)
712+
713+
responseAuthServers, ok := response["authorization_servers"].([]any)
714+
require.True(t, ok)
715+
require.Len(t, responseAuthServers, 1)
716+
assert.Equal(t, tc.expectedURL, responseAuthServers[0])
695717
})
696718
}
697719
}

0 commit comments

Comments
 (0)