Skip to content

Commit 6ec5c07

Browse files
refactor(oauth): remove dead code and consolidate helpers
- Remove unused StartDeviceFlow, StartInteractiveFlow, StartOAuthFlow (170+ lines) - Consolidate generateState and generateElicitationID into generateRandomToken - Simplify tests to verify behavior not internal state - Net reduction: ~200 lines of dead code
1 parent 0a3fc25 commit 6ec5c07

File tree

3 files changed

+15
-210
lines changed

3 files changed

+15
-210
lines changed

internal/oauth/manager.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ func (m *Manager) startDeviceFlowWithElicitation(ctx context.Context, session *m
116116

117117
// Use session elicitation if available to show the user the verification URL and code
118118
if session != nil {
119-
elicitID, err := generateElicitationID()
119+
elicitID, err := generateRandomToken()
120120
if err != nil {
121121
// Log warning but continue - elicitation ID is for tracking only
122122
elicitID = "fallback-id"
@@ -158,7 +158,7 @@ func (m *Manager) startPKCEFlowWithElicitation(ctx context.Context, session *mcp
158158
}
159159

160160
// Generate state for CSRF protection
161-
state, err := generateState()
161+
state, err := generateRandomToken()
162162
if err != nil {
163163
return m.startDeviceFlowWithElicitation(ctx, session)
164164
}
@@ -206,7 +206,7 @@ func (m *Manager) startPKCEFlowWithElicitation(ctx context.Context, session *mcp
206206
// Only elicit if browser failed to open (e.g., headless environment)
207207
// and we need to show the user the URL manually
208208
if browserErr != nil && session != nil {
209-
elicitID, _ := generateElicitationID()
209+
elicitID, _ := generateRandomToken()
210210
_, _ = session.Elicit(ctx, &mcp.ElicitParams{
211211
Mode: "url",
212212
URL: authURL,
@@ -258,15 +258,9 @@ func (m *Manager) setToken(token *Result) {
258258

259259
// Helper functions
260260

261-
func generateElicitationID() (string, error) {
262-
b := make([]byte, 16)
263-
if _, err := rand.Read(b); err != nil {
264-
return "", err
265-
}
266-
return base64.RawURLEncoding.EncodeToString(b), nil
267-
}
268-
269-
func generateState() (string, error) {
261+
// generateRandomToken generates a cryptographically random URL-safe token.
262+
// Used for CSRF state and elicitation IDs.
263+
func generateRandomToken() (string, error) {
270264
b := make([]byte, 16)
271265
if _, err := rand.Read(b); err != nil {
272266
return "", err

internal/oauth/oauth.go

Lines changed: 0 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
package oauth
22

33
import (
4-
"context"
54
"crypto/rand"
65
"encoding/base64"
76
"fmt"
87
"io"
9-
"log"
108
"net"
119
"net/http"
1210
"os"
1311
"os/exec"
1412
"runtime"
1513
"strings"
1614
"time"
17-
18-
"golang.org/x/oauth2"
1915
)
2016

2117
const (
@@ -72,178 +68,6 @@ func isRunningInDocker() bool {
7268
return false
7369
}
7470

75-
// StartDeviceFlow initiates an OAuth device authorization flow
76-
// This is suitable for environments without callback capabilities (like Docker containers)
77-
func StartDeviceFlow(ctx context.Context, cfg Config) (*Result, error) {
78-
oauth2Cfg := &oauth2.Config{
79-
ClientID: cfg.ClientID,
80-
ClientSecret: cfg.ClientSecret,
81-
Scopes: cfg.Scopes,
82-
Endpoint: oauth2.Endpoint{
83-
AuthURL: cfg.AuthURL,
84-
TokenURL: cfg.TokenURL,
85-
DeviceAuthURL: cfg.DeviceAuthURL,
86-
},
87-
}
88-
89-
// Request device authorization
90-
deviceAuth, err := oauth2Cfg.DeviceAuth(ctx)
91-
if err != nil {
92-
return nil, fmt.Errorf("failed to get device authorization: %w", err)
93-
}
94-
95-
// Display verification instructions to user
96-
fmt.Fprint(os.Stderr, "\n"+strings.Repeat("=", 80)+"\n")
97-
fmt.Fprint(os.Stderr, "GitHub OAuth Device Authorization\n")
98-
fmt.Fprint(os.Stderr, strings.Repeat("=", 80)+"\n\n")
99-
fmt.Fprintf(os.Stderr, "Please visit: %s\n\n", deviceAuth.VerificationURI)
100-
fmt.Fprintf(os.Stderr, "And enter code: %s\n\n", deviceAuth.UserCode)
101-
fmt.Fprint(os.Stderr, strings.Repeat("=", 80)+"\n\n")
102-
103-
// Poll for token
104-
token, err := oauth2Cfg.DeviceAccessToken(ctx, deviceAuth)
105-
if err != nil {
106-
return nil, fmt.Errorf("failed to get device access token: %w", err)
107-
}
108-
109-
fmt.Fprint(os.Stderr, "\n✓ Authorization successful!\n\n")
110-
111-
return &Result{
112-
AccessToken: token.AccessToken,
113-
RefreshToken: token.RefreshToken,
114-
TokenType: token.TokenType,
115-
Expiry: token.Expiry,
116-
}, nil
117-
}
118-
119-
// StartOAuthFlow automatically selects the appropriate OAuth flow based on environment
120-
// - Device flow for Docker containers (no callback server possible)
121-
// - Interactive PKCE flow for native binaries (best UX with browser)
122-
func StartOAuthFlow(ctx context.Context, cfg Config) (*Result, error) {
123-
// Check if we're in Docker
124-
if isRunningInDocker() && cfg.CallbackPort == 0 {
125-
// Docker without explicit callback port - use device flow
126-
log.Printf("Detected Docker environment, using device flow")
127-
return StartDeviceFlow(ctx, cfg)
128-
}
129-
130-
// Use interactive PKCE flow (browser-based)
131-
return StartInteractiveFlow(ctx, cfg)
132-
}
133-
134-
// StartInteractiveFlow initiates an interactive OAuth flow with PKCE
135-
// This is intended for stdio mode only and opens a browser for user consent
136-
func StartInteractiveFlow(ctx context.Context, cfg Config) (*Result, error) {
137-
// Generate PKCE verifier
138-
verifier, err := generatePKCEVerifier()
139-
if err != nil {
140-
return nil, fmt.Errorf("failed to generate PKCE verifier: %w", err)
141-
}
142-
143-
// Create OAuth2 config
144-
oauth2Cfg := &oauth2.Config{
145-
ClientID: cfg.ClientID,
146-
ClientSecret: cfg.ClientSecret,
147-
RedirectURL: cfg.RedirectURL,
148-
Scopes: cfg.Scopes,
149-
Endpoint: oauth2.Endpoint{
150-
AuthURL: cfg.AuthURL,
151-
TokenURL: cfg.TokenURL,
152-
},
153-
}
154-
155-
// Generate state for CSRF protection
156-
stateBytes := make([]byte, 16)
157-
if _, err := rand.Read(stateBytes); err != nil {
158-
return nil, fmt.Errorf("failed to generate state: %w", err)
159-
}
160-
state := base64.RawURLEncoding.EncodeToString(stateBytes)
161-
162-
// Start local HTTP server for callback
163-
listener, port, err := startLocalServer(cfg.CallbackPort)
164-
if err != nil {
165-
return nil, fmt.Errorf("failed to start local server: %w", err)
166-
}
167-
defer listener.Close()
168-
169-
// Update redirect URL with actual port
170-
oauth2Cfg.RedirectURL = fmt.Sprintf("http://localhost:%d/callback", port)
171-
172-
// Channel to receive the authorization code
173-
codeChan := make(chan string, 1)
174-
errChan := make(chan error, 1)
175-
176-
// Setup HTTP handler for callback
177-
server := &http.Server{
178-
Handler: createCallbackHandler(state, codeChan, errChan),
179-
ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks
180-
}
181-
182-
// Start server in background
183-
go func() {
184-
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
185-
errChan <- fmt.Errorf("server error: %w", err)
186-
}
187-
}()
188-
189-
// Build authorization URL with PKCE
190-
authURL := oauth2Cfg.AuthCodeURL(
191-
state,
192-
oauth2.S256ChallengeOption(verifier),
193-
)
194-
195-
// Display URL to user and try to open browser
196-
fmt.Fprint(os.Stderr, "\n"+strings.Repeat("=", 80)+"\n")
197-
fmt.Fprint(os.Stderr, "GitHub OAuth Authorization Required\n")
198-
fmt.Fprint(os.Stderr, strings.Repeat("=", 80)+"\n\n")
199-
fmt.Fprint(os.Stderr, "Opening your browser to complete authorization...\n\n")
200-
fmt.Fprint(os.Stderr, "If your browser doesn't open automatically, please visit this URL:\n\n")
201-
fmt.Fprintf(os.Stderr, " %s\n\n", authURL)
202-
fmt.Fprint(os.Stderr, strings.Repeat("=", 80)+"\n\n")
203-
204-
// Try to open browser
205-
if err := openBrowser(authURL); err != nil {
206-
log.Printf("Warning: Could not open browser automatically: %v", err)
207-
}
208-
209-
// Wait for callback with timeout
210-
var code string
211-
select {
212-
case code = <-codeChan:
213-
// Success
214-
case err := <-errChan:
215-
return nil, fmt.Errorf("callback error: %w", err)
216-
case <-ctx.Done():
217-
return nil, fmt.Errorf("context cancelled: %w", ctx.Err())
218-
case <-time.After(DefaultAuthTimeout):
219-
return nil, fmt.Errorf("authorization timeout after %v", DefaultAuthTimeout)
220-
}
221-
222-
// Shutdown server gracefully
223-
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
224-
defer cancel()
225-
_ = server.Shutdown(shutdownCtx)
226-
227-
// Exchange authorization code for token with PKCE verifier
228-
token, err := oauth2Cfg.Exchange(
229-
ctx,
230-
code,
231-
oauth2.VerifierOption(verifier),
232-
)
233-
if err != nil {
234-
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
235-
}
236-
237-
fmt.Fprint(os.Stderr, "\n✓ Authorization successful!\n\n")
238-
239-
return &Result{
240-
AccessToken: token.AccessToken,
241-
RefreshToken: token.RefreshToken,
242-
TokenType: token.TokenType,
243-
Expiry: token.Expiry,
244-
}, nil
245-
}
246-
24771
// startLocalServer starts a local HTTP server on the specified port
24872
// If port is 0, uses a random available port
24973
func startLocalServer(port int) (net.Listener, int, error) {

internal/oauth/oauth_test.go

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,7 @@ func TestNewManager(t *testing.T) {
110110
mgr := NewManager(cfg)
111111

112112
assert.NotNil(t, mgr)
113-
assert.Equal(t, cfg.ClientID, mgr.config.ClientID)
114-
assert.Equal(t, cfg.ClientSecret, mgr.config.ClientSecret)
115-
assert.Equal(t, cfg.Scopes, mgr.config.Scopes)
113+
// Test observable behavior, not internal state
116114
assert.False(t, mgr.HasToken())
117115
assert.Empty(t, mgr.GetAccessToken())
118116
}
@@ -180,28 +178,17 @@ func TestManagerSetToken(t *testing.T) {
180178
assert.True(t, mgr.HasToken())
181179
}
182180

183-
func TestGenerateState(t *testing.T) {
184-
state1, err := generateState()
181+
func TestGenerateRandomToken(t *testing.T) {
182+
token1, err := generateRandomToken()
185183
require.NoError(t, err)
186-
require.NotEmpty(t, state1)
184+
require.NotEmpty(t, token1)
187185

188-
// State should be URL-safe base64 encoded
186+
// Token should be URL-safe base64 encoded
189187
// 16 bytes of random data = ~22 chars in base64url
190-
assert.GreaterOrEqual(t, len(state1), 20)
188+
assert.GreaterOrEqual(t, len(token1), 20)
191189

192-
// Each call should produce unique state
193-
state2, err := generateState()
190+
// Each call should produce unique token
191+
token2, err := generateRandomToken()
194192
require.NoError(t, err)
195-
assert.NotEqual(t, state1, state2)
196-
}
197-
198-
func TestGenerateElicitationID(t *testing.T) {
199-
id1, err := generateElicitationID()
200-
require.NoError(t, err)
201-
require.NotEmpty(t, id1)
202-
203-
// Each call should produce unique ID
204-
id2, err := generateElicitationID()
205-
require.NoError(t, err)
206-
assert.NotEqual(t, id1, id2)
193+
assert.NotEqual(t, token1, token2)
207194
}

0 commit comments

Comments
 (0)