Skip to content

Commit a405727

Browse files
committed
Override user agent
1 parent 8e8fa59 commit a405727

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

pkg/aws/aws_config.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@ package aws
22

33
import (
44
"context"
5+
56
"github.com/aws/aws-sdk-go-v2/aws"
6-
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
77
"github.com/aws/aws-sdk-go-v2/aws/ratelimit"
88
"github.com/aws/aws-sdk-go-v2/aws/retry"
99
"github.com/aws/aws-sdk-go-v2/config"
1010
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
1111
smithymiddleware "github.com/aws/smithy-go/middleware"
12+
smithyhttp "github.com/aws/smithy-go/transport/http"
1213
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle"
1314
awsmetrics "sigs.k8s.io/aws-load-balancer-controller/pkg/metrics/aws"
1415
"sigs.k8s.io/aws-load-balancer-controller/pkg/version"
@@ -50,7 +51,7 @@ func (gen *awsConfigGeneratorImpl) GenerateAWSConfig(optFns ...func(*config.Load
5051
}),
5152
config.WithEC2IMDSEndpointMode(gen.ec2IMDSEndpointMode),
5253
config.WithAPIOptions([]func(stack *smithymiddleware.Stack) error{
53-
awsmiddleware.AddUserAgentKeyValue(userAgent, version.GitVersion),
54+
overrideUserAgentMiddleware(userAgent + "/" + version.GitVersion),
5455
}),
5556
}
5657

@@ -79,3 +80,17 @@ func (gen *awsConfigGeneratorImpl) GenerateAWSConfig(optFns ...func(*config.Load
7980
}
8081

8182
var _ AWSConfigGenerator = &awsConfigGeneratorImpl{}
83+
84+
// overrideUserAgentMiddleware returns a middleware that replaces the User-Agent
85+
// header with the given value, stripping all SDK-generated metadata.
86+
func overrideUserAgentMiddleware(ua string) func(stack *smithymiddleware.Stack) error {
87+
return func(stack *smithymiddleware.Stack) error {
88+
return stack.Build.Add(smithymiddleware.BuildMiddlewareFunc("OverrideUserAgent",
89+
func(ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler) (smithymiddleware.BuildOutput, smithymiddleware.Metadata, error) {
90+
if req, ok := in.Request.(*smithyhttp.Request); ok {
91+
req.Header.Set("User-Agent", ua)
92+
}
93+
return next.HandleBuild(ctx, in)
94+
}), smithymiddleware.After)
95+
}
96+
}

pkg/aws/aws_config_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package aws
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"net/url"
8+
"testing"
9+
10+
"github.com/aws/aws-sdk-go-v2/aws"
11+
"github.com/aws/aws-sdk-go-v2/config"
12+
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
13+
"github.com/aws/aws-sdk-go-v2/service/ec2"
14+
smithyendpoints "github.com/aws/smithy-go/endpoints"
15+
"github.com/stretchr/testify/assert"
16+
"github.com/stretchr/testify/require"
17+
"sigs.k8s.io/aws-load-balancer-controller/pkg/version"
18+
)
19+
20+
// staticEC2EndpointResolver redirects all ec2 calls to a fixed URL.
21+
type staticEC2EndpointResolver struct{ url string }
22+
23+
func (r staticEC2EndpointResolver) ResolveEndpoint(_ context.Context, _ ec2.EndpointParameters) (smithyendpoints.Endpoint, error) {
24+
u, err := url.Parse(r.url)
25+
if err != nil {
26+
return smithyendpoints.Endpoint{}, err
27+
}
28+
return smithyendpoints.Endpoint{URI: *u}, nil
29+
}
30+
31+
func TestGenerateAWSConfig_UserAgentHeader(t *testing.T) {
32+
tests := []struct {
33+
name string
34+
gitVersion string
35+
wantUA string
36+
}{
37+
{
38+
name: "user agent is exactly elbv2.k8s.aws/<version>",
39+
gitVersion: "v2.14.1",
40+
wantUA: "elbv2.k8s.aws/v2.14.1",
41+
},
42+
{
43+
name: "user agent reflects updated git version",
44+
gitVersion: "v3.0.0",
45+
wantUA: "elbv2.k8s.aws/v3.0.0",
46+
},
47+
}
48+
49+
for _, tt := range tests {
50+
t.Run(tt.name, func(t *testing.T) {
51+
version.GitVersion = tt.gitVersion
52+
53+
var capturedUA string
54+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
55+
capturedUA = r.Header.Get("User-Agent")
56+
w.WriteHeader(http.StatusBadRequest)
57+
}))
58+
defer srv.Close()
59+
60+
gen := NewAWSConfigGenerator(CloudConfig{Region: "us-east-1", MaxRetries: 1}, imds.EndpointModeStateIPv4, nil)
61+
awsCfg, err := gen.GenerateAWSConfig(
62+
config.WithCredentialsProvider(aws.AnonymousCredentials{}),
63+
)
64+
require.NoError(t, err)
65+
66+
ec2Client := ec2.NewFromConfig(awsCfg, func(o *ec2.Options) {
67+
o.EndpointResolverV2 = staticEC2EndpointResolver{url: srv.URL}
68+
})
69+
_, _ = ec2Client.DescribeInstances(t.Context(), &ec2.DescribeInstancesInput{})
70+
71+
require.NotEmpty(t, capturedUA, "no request was captured")
72+
assert.Equal(t, tt.wantUA, capturedUA)
73+
})
74+
}
75+
}

0 commit comments

Comments
 (0)