Skip to content

Commit 249755f

Browse files
Add GraphQLFeaturesTransport for reusable header handling
Co-authored-by: SamMorrowDrums <4811358+SamMorrowDrums@users.noreply.github.com>
1 parent 0ea475c commit 249755f

File tree

3 files changed

+192
-9
lines changed

3 files changed

+192
-9
lines changed

internal/ghmcp/server.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,13 @@ func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients,
9494

9595
// Construct GraphQL client
9696
// We use NewEnterpriseClient unconditionally since we already parsed the API host
97+
// Layer transports: DefaultTransport -> bearerAuthTransport -> GraphQLFeaturesTransport
9798
gqlHTTPClient := &http.Client{
98-
Transport: &bearerAuthTransport{
99-
transport: http.DefaultTransport,
100-
token: cfg.Token,
99+
Transport: &github.GraphQLFeaturesTransport{
100+
Transport: &bearerAuthTransport{
101+
transport: http.DefaultTransport,
102+
token: cfg.Token,
103+
},
101104
},
102105
}
103106
gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient)
@@ -622,12 +625,6 @@ type bearerAuthTransport struct {
622625
func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
623626
req = req.Clone(req.Context())
624627
req.Header.Set("Authorization", "Bearer "+t.token)
625-
626-
// Check for GraphQL-Features in context and add header if present
627-
if features := github.GetGraphQLFeatures(req.Context()); len(features) > 0 {
628-
req.Header.Set("GraphQL-Features", strings.Join(features, ", "))
629-
}
630-
631628
return t.transport.RoundTrip(req)
632629
}
633630

pkg/github/transport.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package github
2+
3+
import (
4+
"net/http"
5+
"strings"
6+
)
7+
8+
// GraphQLFeaturesTransport is an http.RoundTripper that adds GraphQL-Features
9+
// header based on context values set by withGraphQLFeatures.
10+
//
11+
// This transport should be used in the HTTP client chain for githubv4.Client
12+
// to ensure GraphQL feature flags are properly sent to the GitHub API.
13+
//
14+
// Example usage:
15+
//
16+
// httpClient := &http.Client{
17+
// Transport: &github.GraphQLFeaturesTransport{
18+
// Transport: http.DefaultTransport,
19+
// },
20+
// }
21+
// gqlClient := githubv4.NewClient(httpClient)
22+
type GraphQLFeaturesTransport struct {
23+
// Transport is the underlying http.RoundTripper. If nil, http.DefaultTransport is used.
24+
Transport http.RoundTripper
25+
}
26+
27+
// RoundTrip implements http.RoundTripper.
28+
// It adds the GraphQL-Features header if features are present in the request context.
29+
func (t *GraphQLFeaturesTransport) RoundTrip(req *http.Request) (*http.Response, error) {
30+
transport := t.Transport
31+
if transport == nil {
32+
transport = http.DefaultTransport
33+
}
34+
35+
// Clone request to avoid modifying the original
36+
req = req.Clone(req.Context())
37+
38+
// Check for GraphQL-Features in context and add header if present
39+
if features := GetGraphQLFeatures(req.Context()); len(features) > 0 {
40+
req.Header.Set("GraphQL-Features", strings.Join(features, ", "))
41+
}
42+
43+
return transport.RoundTrip(req)
44+
}

pkg/github/transport_test.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestGraphQLFeaturesTransport(t *testing.T) {
14+
tests := []struct {
15+
name string
16+
features []string
17+
expectHeader bool
18+
expectedHeaderVal string
19+
}{
20+
{
21+
name: "adds single feature to header",
22+
features: []string{"issues_copilot_assignment_api_support"},
23+
expectHeader: true,
24+
expectedHeaderVal: "issues_copilot_assignment_api_support",
25+
},
26+
{
27+
name: "adds multiple features to header",
28+
features: []string{"feature1", "feature2", "feature3"},
29+
expectHeader: true,
30+
expectedHeaderVal: "feature1, feature2, feature3",
31+
},
32+
{
33+
name: "no header when no features in context",
34+
features: nil,
35+
expectHeader: false,
36+
},
37+
{
38+
name: "no header when empty features slice",
39+
features: []string{},
40+
expectHeader: false,
41+
},
42+
}
43+
44+
for _, tt := range tests {
45+
t.Run(tt.name, func(t *testing.T) {
46+
// Create a test server that captures the request
47+
var capturedReq *http.Request
48+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49+
capturedReq = r
50+
w.WriteHeader(http.StatusOK)
51+
}))
52+
defer server.Close()
53+
54+
// Create HTTP client with GraphQLFeaturesTransport
55+
client := &http.Client{
56+
Transport: &GraphQLFeaturesTransport{
57+
Transport: http.DefaultTransport,
58+
},
59+
}
60+
61+
// Create request with or without features in context
62+
ctx := context.Background()
63+
if tt.features != nil {
64+
ctx = withGraphQLFeatures(ctx, tt.features...)
65+
}
66+
67+
req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
68+
require.NoError(t, err)
69+
70+
// Make request
71+
resp, err := client.Do(req)
72+
require.NoError(t, err)
73+
defer resp.Body.Close()
74+
75+
// Verify header
76+
if tt.expectHeader {
77+
assert.Equal(t, tt.expectedHeaderVal, capturedReq.Header.Get("GraphQL-Features"))
78+
} else {
79+
assert.Empty(t, capturedReq.Header.Get("GraphQL-Features"))
80+
}
81+
})
82+
}
83+
}
84+
85+
func TestGraphQLFeaturesTransport_NilTransport(t *testing.T) {
86+
// Test that nil Transport falls back to http.DefaultTransport
87+
var capturedReq *http.Request
88+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
89+
capturedReq = r
90+
w.WriteHeader(http.StatusOK)
91+
}))
92+
defer server.Close()
93+
94+
client := &http.Client{
95+
Transport: &GraphQLFeaturesTransport{
96+
Transport: nil, // Explicitly nil
97+
},
98+
}
99+
100+
ctx := withGraphQLFeatures(context.Background(), "test_feature")
101+
req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
102+
require.NoError(t, err)
103+
104+
resp, err := client.Do(req)
105+
require.NoError(t, err)
106+
defer resp.Body.Close()
107+
108+
assert.Equal(t, "test_feature", capturedReq.Header.Get("GraphQL-Features"))
109+
}
110+
111+
func TestGraphQLFeaturesTransport_PreservesOtherHeaders(t *testing.T) {
112+
// Test that the transport doesn't interfere with other headers
113+
var capturedReq *http.Request
114+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
115+
capturedReq = r
116+
w.WriteHeader(http.StatusOK)
117+
}))
118+
defer server.Close()
119+
120+
client := &http.Client{
121+
Transport: &GraphQLFeaturesTransport{
122+
Transport: http.DefaultTransport,
123+
},
124+
}
125+
126+
ctx := withGraphQLFeatures(context.Background(), "feature1")
127+
req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
128+
require.NoError(t, err)
129+
130+
// Add custom headers
131+
req.Header.Set("Authorization", "Bearer test-token")
132+
req.Header.Set("User-Agent", "test-agent")
133+
134+
resp, err := client.Do(req)
135+
require.NoError(t, err)
136+
defer resp.Body.Close()
137+
138+
// Verify all headers are preserved
139+
assert.Equal(t, "feature1", capturedReq.Header.Get("GraphQL-Features"))
140+
assert.Equal(t, "Bearer test-token", capturedReq.Header.Get("Authorization"))
141+
assert.Equal(t, "test-agent", capturedReq.Header.Get("User-Agent"))
142+
}

0 commit comments

Comments
 (0)