@@ -30,18 +30,6 @@ func TestNewAuthHandler(t *testing.T) {
3030 expectedAuthServer string
3131 expectedResourcePath string
3232 }{
33- {
34- name : "nil config uses defaults" ,
35- cfg : nil ,
36- expectedAuthServer : defaultAuthorizationServer ,
37- expectedResourcePath : "" ,
38- },
39- {
40- name : "empty config uses defaults" ,
41- cfg : & Config {},
42- expectedAuthServer : defaultAuthorizationServer ,
43- expectedResourcePath : "" ,
44- },
4533 {
4634 name : "custom authorization server" ,
4735 cfg : & Config {
@@ -56,7 +44,7 @@ func TestNewAuthHandler(t *testing.T) {
5644 BaseURL : "https://example.com" ,
5745 ResourcePath : "/mcp" ,
5846 },
59- expectedAuthServer : defaultAuthorizationServer ,
47+ expectedAuthServer : "" ,
6048 expectedResourcePath : "/mcp" ,
6149 },
6250 }
@@ -636,42 +624,44 @@ func TestAPIHostResolver_AuthorizationServerURL(t *testing.T) {
636624 t .Parallel ()
637625
638626 tests := []struct {
639- name string
640- host string
641- expectedURL string
642- expectError bool
643- errorContains string
627+ name string
628+ host string
629+ expectedURL string
630+ expectedError bool
631+ expectedStatusCode int
632+ errorContains string
644633 }{
645634 {
646- name : "valid host returns authorization server URL" ,
647- host : "http://github.com" ,
648- expectedURL : "https://github.com/login/oauth" ,
649- expectError : false ,
635+ name : "valid host returns authorization server URL" ,
636+ host : "http://github.com" ,
637+ expectedURL : "https://github.com/login/oauth" ,
638+ expectedStatusCode : http . StatusOK ,
650639 },
651640 {
652641 name : "invalid host returns error" ,
653642 host : "://invalid-url" ,
654643 expectedURL : "" ,
655- expectError : true ,
644+ expectedError : true ,
656645 errorContains : "could not parse host as URL" ,
657646 },
658647 {
659648 name : "host without scheme returns error" ,
660649 host : "github.com" ,
661650 expectedURL : "" ,
662- expectError : true ,
651+ expectedError : true ,
663652 errorContains : "host must have a scheme" ,
664653 },
665654 {
666- name : "GHEC host returns correct authorization server URL" ,
667- host : "https://test.ghe.com" ,
668- expectedURL : "https://test.ghe.com/login/oauth" ,
655+ name : "GHEC host returns correct authorization server URL" ,
656+ host : "https://test.ghe.com" ,
657+ expectedURL : "https://test.ghe.com/login/oauth" ,
658+ expectedStatusCode : http .StatusOK ,
669659 },
670660 {
671- name : "GHES host returns correct authorization server URL" ,
672- host : "https://ghe.example.com" ,
673- expectedURL : "https://ghe.example.com/login/oauth" ,
674- expectError : false ,
661+ name : "GHES host returns correct authorization server URL" ,
662+ host : "https://ghe.example.com" ,
663+ expectedURL : "https://ghe.example.com/login/oauth" ,
664+ expectedStatusCode : http . StatusOK ,
675665 },
676666 }
677667
@@ -680,18 +670,50 @@ func TestAPIHostResolver_AuthorizationServerURL(t *testing.T) {
680670 t .Parallel ()
681671
682672 apiHost , err := utils .NewAPIHost (tc .host )
683- if tc .expectError {
673+ if tc .expectedError {
684674 require .Error (t , err )
685675 if tc .errorContains != "" {
686676 assert .Contains (t , err .Error (), tc .errorContains )
687677 }
688678 return
679+ } else {
680+ require .NoError (t , err )
689681 }
682+
683+ handler , err := NewAuthHandler (t .Context (), & Config {
684+ BaseURL : "https://api.example.com" ,
685+ }, apiHost )
690686 require .NoError (t , err )
691687
692- url , err := apiHost .AuthorizationServerURL (t .Context ())
688+ router := chi .NewRouter ()
689+ handler .RegisterRoutes (router )
690+
691+ req := httptest .NewRequest (http .MethodGet , OAuthProtectedResourcePrefix , nil )
692+ req .Host = "api.example.com"
693+
694+ rec := httptest .NewRecorder ()
695+ router .ServeHTTP (rec , req )
696+
697+ require .Equal (t , http .StatusOK , rec .Code )
698+
699+ var response map [string ]any
700+ err = json .Unmarshal (rec .Body .Bytes (), & response )
693701 require .NoError (t , err )
694- assert .Equal (t , tc .expectedURL , url .String ())
702+
703+ assert .Contains (t , response , "authorization_servers" )
704+ if tc .expectedStatusCode != http .StatusOK {
705+ require .Equal (t , tc .expectedStatusCode , rec .Code )
706+ if tc .errorContains != "" {
707+ assert .Contains (t , rec .Body .String (), tc .errorContains )
708+ }
709+ return
710+ }
711+ require .NoError (t , err )
712+
713+ responseAuthServers , ok := response ["authorization_servers" ].([]any )
714+ require .True (t , ok )
715+ require .Len (t , responseAuthServers , 1 )
716+ assert .Equal (t , tc .expectedURL , responseAuthServers [0 ])
695717 })
696718 }
697719}
0 commit comments