Skip to content

Commit 14c993b

Browse files
committed
Task 8
1 parent 63224bf commit 14c993b

File tree

3 files changed

+330
-2
lines changed

3 files changed

+330
-2
lines changed

Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul
349349
_logger.LogInformation("Starting InvokeOnceAsync");
350350

351351
var invocation = await Client.GetNextInvocationAsync(cancellationToken);
352+
var isMultiConcurrency = Utils.IsUsingMultiConcurrency(_environmentVariables);
352353

353354
Func<Task> processingFunc = async () =>
354355
{
@@ -358,6 +359,18 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul
358359
SetInvocationTraceId(impl.RuntimeApiHeaders.TraceId);
359360
}
360361

362+
// Initialize ResponseStreamFactory — includes RuntimeApiClient reference
363+
var runtimeApiClient = Client as RuntimeApiClient;
364+
if (runtimeApiClient != null)
365+
{
366+
ResponseStreamFactory.InitializeInvocation(
367+
invocation.LambdaContext.AwsRequestId,
368+
StreamingConstants.MaxResponseSize,
369+
isMultiConcurrency,
370+
runtimeApiClient,
371+
cancellationToken);
372+
}
373+
361374
try
362375
{
363376
InvocationResponse response = null;
@@ -372,15 +385,39 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul
372385
catch (Exception exception)
373386
{
374387
WriteUnhandledExceptionToLog(exception);
375-
await Client.ReportInvocationErrorAsync(invocation.LambdaContext.AwsRequestId, exception, cancellationToken);
388+
389+
var streamIfCreated = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency);
390+
if (streamIfCreated != null && streamIfCreated.BytesWritten > 0)
391+
{
392+
// Midstream error — report via trailers on the already-open HTTP connection
393+
await streamIfCreated.ReportErrorAsync(exception);
394+
}
395+
else
396+
{
397+
// Error before streaming started — use standard error reporting
398+
await Client.ReportInvocationErrorAsync(invocation.LambdaContext.AwsRequestId, exception, cancellationToken);
399+
}
376400
}
377401
finally
378402
{
379403
_logger.LogInformation("Finished invoking handler");
380404
}
381405

382-
if (invokeSucceeded)
406+
// If streaming was started, await the HTTP send task to ensure it completes
407+
var sendTask = ResponseStreamFactory.GetSendTask(isMultiConcurrency);
408+
if (sendTask != null)
383409
{
410+
var stream = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency);
411+
if (stream != null && !stream.IsCompleted && !stream.HasError)
412+
{
413+
// Handler returned successfully — signal stream completion
414+
stream.MarkCompleted();
415+
}
416+
await sendTask; // Wait for HTTP request to finish
417+
}
418+
else if (invokeSucceeded)
419+
{
420+
// No streaming — send buffered response
384421
_logger.LogInformation("Starting sending response");
385422
try
386423
{
@@ -415,6 +452,10 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul
415452
}
416453
finally
417454
{
455+
if (runtimeApiClient != null)
456+
{
457+
ResponseStreamFactory.CleanupInvocation(isMultiConcurrency);
458+
}
418459
invocation.Dispose();
419460
}
420461
};

Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
*/
1515
using System;
1616
using System.Collections.Generic;
17+
using System.IO;
1718
using System.Linq;
1819
using System.Net.Http;
1920
using System.Text;
21+
using System.Threading;
2022
using System.Threading.Tasks;
2123
using Xunit;
2224

@@ -283,5 +285,159 @@ public void IsCallPreJitTest()
283285
environmentVariables.SetEnvironmentVariable(ENVIRONMENT_VARIABLE_AWS_LAMBDA_INITIALIZATION_TYPE, AWS_LAMBDA_INITIALIZATION_TYPE_PC);
284286
Assert.True(UserCodeInit.IsCallPreJit(environmentVariables));
285287
}
288+
289+
// --- Streaming Integration Tests ---
290+
291+
private TestStreamingRuntimeApiClient CreateStreamingClient()
292+
{
293+
var envVars = new TestEnvironmentVariables();
294+
var headers = new Dictionary<string, IEnumerable<string>>
295+
{
296+
{ RuntimeApiHeaders.HeaderAwsRequestId, new List<string> { "streaming-request-id" } },
297+
{ RuntimeApiHeaders.HeaderInvokedFunctionArn, new List<string> { "invoked_function_arn" } },
298+
{ RuntimeApiHeaders.HeaderAwsTenantId, new List<string> { "tenant_id" } }
299+
};
300+
return new TestStreamingRuntimeApiClient(envVars, headers);
301+
}
302+
303+
/// <summary>
304+
/// Property 2: CreateStream Enables Streaming Mode
305+
/// When a handler calls ResponseStreamFactory.CreateStream(), the response is transmitted
306+
/// using streaming mode. LambdaBootstrap awaits the send task.
307+
/// **Validates: Requirements 1.4, 6.1, 6.2, 6.3, 6.4**
308+
/// </summary>
309+
[Fact]
310+
public async Task StreamingMode_HandlerCallsCreateStream_SendTaskAwaited()
311+
{
312+
var streamingClient = CreateStreamingClient();
313+
314+
LambdaBootstrapHandler handler = async (invocation) =>
315+
{
316+
var stream = ResponseStreamFactory.CreateStream();
317+
await stream.WriteAsync(Encoding.UTF8.GetBytes("hello"));
318+
return new InvocationResponse(Stream.Null, false);
319+
};
320+
321+
using (var bootstrap = new LambdaBootstrap(handler, null))
322+
{
323+
bootstrap.Client = streamingClient;
324+
await bootstrap.InvokeOnceAsync();
325+
}
326+
327+
Assert.True(streamingClient.StartStreamingResponseAsyncCalled);
328+
Assert.False(streamingClient.SendResponseAsyncCalled);
329+
}
330+
331+
/// <summary>
332+
/// Property 3: Default Mode Is Buffered
333+
/// When a handler does not call ResponseStreamFactory.CreateStream(), the response
334+
/// is transmitted using buffered mode via SendResponseAsync.
335+
/// **Validates: Requirements 1.5, 7.2**
336+
/// </summary>
337+
[Fact]
338+
public async Task BufferedMode_HandlerDoesNotCallCreateStream_UsesSendResponse()
339+
{
340+
var streamingClient = CreateStreamingClient();
341+
342+
LambdaBootstrapHandler handler = async (invocation) =>
343+
{
344+
var outputStream = new MemoryStream(Encoding.UTF8.GetBytes("buffered response"));
345+
return new InvocationResponse(outputStream);
346+
};
347+
348+
using (var bootstrap = new LambdaBootstrap(handler, null))
349+
{
350+
bootstrap.Client = streamingClient;
351+
await bootstrap.InvokeOnceAsync();
352+
}
353+
354+
Assert.False(streamingClient.StartStreamingResponseAsyncCalled);
355+
Assert.True(streamingClient.SendResponseAsyncCalled);
356+
}
357+
358+
/// <summary>
359+
/// Property 14: Exception After Writes Uses Trailers
360+
/// When a handler throws an exception after writing data to an IResponseStream,
361+
/// the error is reported via trailers (ReportErrorAsync) rather than standard error reporting.
362+
/// **Validates: Requirements 5.6, 5.7**
363+
/// </summary>
364+
[Fact]
365+
public async Task MidstreamError_ExceptionAfterWrites_ReportsViaTrailers()
366+
{
367+
var streamingClient = CreateStreamingClient();
368+
369+
LambdaBootstrapHandler handler = async (invocation) =>
370+
{
371+
var stream = ResponseStreamFactory.CreateStream();
372+
await stream.WriteAsync(Encoding.UTF8.GetBytes("partial data"));
373+
throw new InvalidOperationException("midstream failure");
374+
};
375+
376+
using (var bootstrap = new LambdaBootstrap(handler, null))
377+
{
378+
bootstrap.Client = streamingClient;
379+
await bootstrap.InvokeOnceAsync();
380+
}
381+
382+
// Error should be reported via trailers on the stream, not via standard error reporting
383+
Assert.True(streamingClient.StartStreamingResponseAsyncCalled);
384+
Assert.NotNull(streamingClient.LastStreamingResponseStream);
385+
Assert.True(streamingClient.LastStreamingResponseStream.HasError);
386+
Assert.False(streamingClient.ReportInvocationErrorAsyncExceptionCalled);
387+
}
388+
389+
/// <summary>
390+
/// Property 15: Exception Before CreateStream Uses Standard Error
391+
/// When a handler throws an exception before calling ResponseStreamFactory.CreateStream(),
392+
/// the error is reported using the standard Lambda error reporting mechanism.
393+
/// **Validates: Requirements 5.7, 7.1**
394+
/// </summary>
395+
[Fact]
396+
public async Task PreStreamError_ExceptionBeforeCreateStream_UsesStandardErrorReporting()
397+
{
398+
var streamingClient = CreateStreamingClient();
399+
400+
LambdaBootstrapHandler handler = async (invocation) =>
401+
{
402+
await Task.Yield();
403+
throw new InvalidOperationException("pre-stream failure");
404+
};
405+
406+
using (var bootstrap = new LambdaBootstrap(handler, null))
407+
{
408+
bootstrap.Client = streamingClient;
409+
await bootstrap.InvokeOnceAsync();
410+
}
411+
412+
Assert.False(streamingClient.StartStreamingResponseAsyncCalled);
413+
Assert.True(streamingClient.ReportInvocationErrorAsyncExceptionCalled);
414+
}
415+
416+
/// <summary>
417+
/// State Isolation: ResponseStreamFactory state is cleared after each invocation.
418+
/// **Validates: Requirements 6.5, 8.9**
419+
/// </summary>
420+
[Fact]
421+
public async Task Cleanup_ResponseStreamFactoryStateCleared_AfterInvocation()
422+
{
423+
var streamingClient = CreateStreamingClient();
424+
425+
LambdaBootstrapHandler handler = async (invocation) =>
426+
{
427+
var stream = ResponseStreamFactory.CreateStream();
428+
await stream.WriteAsync(Encoding.UTF8.GetBytes("data"));
429+
return new InvocationResponse(Stream.Null, false);
430+
};
431+
432+
using (var bootstrap = new LambdaBootstrap(handler, null))
433+
{
434+
bootstrap.Client = streamingClient;
435+
await bootstrap.InvokeOnceAsync();
436+
}
437+
438+
// After invocation, factory state should be cleaned up
439+
Assert.Null(ResponseStreamFactory.GetStreamIfCreated(false));
440+
Assert.Null(ResponseStreamFactory.GetSendTask(false));
441+
}
286442
}
287443
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
using Amazon.Lambda.RuntimeSupport.Helpers;
17+
using Amazon.Lambda.RuntimeSupport.UnitTests.TestHelpers;
18+
using System;
19+
using System.Collections.Generic;
20+
using System.IO;
21+
using System.Text;
22+
using System.Threading;
23+
using System.Threading.Tasks;
24+
25+
namespace Amazon.Lambda.RuntimeSupport.UnitTests
26+
{
27+
/// <summary>
28+
/// A RuntimeApiClient subclass for testing LambdaBootstrap streaming integration.
29+
/// Extends RuntimeApiClient so the (RuntimeApiClient)Client cast in LambdaBootstrap works.
30+
/// Overrides StartStreamingResponseAsync to avoid real HTTP calls.
31+
/// </summary>
32+
internal class TestStreamingRuntimeApiClient : RuntimeApiClient, IRuntimeApiClient
33+
{
34+
private readonly IEnvironmentVariables _environmentVariables;
35+
private readonly Dictionary<string, IEnumerable<string>> _headers;
36+
37+
public new IConsoleLoggerWriter ConsoleLogger { get; } = new LogLevelLoggerWriter(new SystemEnvironmentVariables());
38+
39+
public TestStreamingRuntimeApiClient(IEnvironmentVariables environmentVariables, Dictionary<string, IEnumerable<string>> headers)
40+
: base(environmentVariables, new NoOpInternalRuntimeApiClient())
41+
{
42+
_environmentVariables = environmentVariables;
43+
_headers = headers;
44+
}
45+
46+
// Tracking flags
47+
public bool GetNextInvocationAsyncCalled { get; private set; }
48+
public bool ReportInitializationErrorAsyncExceptionCalled { get; private set; }
49+
public bool ReportInvocationErrorAsyncExceptionCalled { get; private set; }
50+
public bool SendResponseAsyncCalled { get; private set; }
51+
public bool StartStreamingResponseAsyncCalled { get; private set; }
52+
53+
public string LastTraceId { get; private set; }
54+
public byte[] FunctionInput { get; set; }
55+
public Stream LastOutputStream { get; private set; }
56+
public Exception LastRecordedException { get; private set; }
57+
public ResponseStream LastStreamingResponseStream { get; private set; }
58+
59+
public new async Task<InvocationRequest> GetNextInvocationAsync(CancellationToken cancellationToken = default)
60+
{
61+
GetNextInvocationAsyncCalled = true;
62+
63+
LastTraceId = Guid.NewGuid().ToString();
64+
_headers[RuntimeApiHeaders.HeaderTraceId] = new List<string>() { LastTraceId };
65+
66+
var inputStream = new MemoryStream(FunctionInput == null ? new byte[0] : FunctionInput);
67+
inputStream.Position = 0;
68+
69+
return new InvocationRequest()
70+
{
71+
InputStream = inputStream,
72+
LambdaContext = new LambdaContext(
73+
new RuntimeApiHeaders(_headers),
74+
new LambdaEnvironment(_environmentVariables),
75+
new TestDateTimeHelper(), new SimpleLoggerWriter(_environmentVariables))
76+
};
77+
}
78+
79+
public new Task ReportInitializationErrorAsync(Exception exception, String errorType = null, CancellationToken cancellationToken = default)
80+
{
81+
LastRecordedException = exception;
82+
ReportInitializationErrorAsyncExceptionCalled = true;
83+
return Task.CompletedTask;
84+
}
85+
86+
public new Task ReportInitializationErrorAsync(string errorType, CancellationToken cancellationToken = default)
87+
{
88+
return Task.CompletedTask;
89+
}
90+
91+
public new Task ReportInvocationErrorAsync(string awsRequestId, Exception exception, CancellationToken cancellationToken = default)
92+
{
93+
LastRecordedException = exception;
94+
ReportInvocationErrorAsyncExceptionCalled = true;
95+
return Task.CompletedTask;
96+
}
97+
98+
public new async Task SendResponseAsync(string awsRequestId, Stream outputStream, CancellationToken cancellationToken = default)
99+
{
100+
if (outputStream != null)
101+
{
102+
LastOutputStream = new MemoryStream((int)outputStream.Length);
103+
outputStream.CopyTo(LastOutputStream);
104+
LastOutputStream.Position = 0;
105+
}
106+
107+
SendResponseAsyncCalled = true;
108+
}
109+
110+
internal override async Task StartStreamingResponseAsync(
111+
string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default)
112+
{
113+
StartStreamingResponseAsyncCalled = true;
114+
LastStreamingResponseStream = responseStream;
115+
116+
// Simulate the HTTP stream being available
117+
responseStream.SetHttpOutputStream(new MemoryStream());
118+
119+
// Wait for the handler to finish writing (mirrors real SerializeToStreamAsync behavior)
120+
await responseStream.WaitForCompletionAsync();
121+
}
122+
123+
#if NET8_0_OR_GREATER
124+
public new Task RestoreNextInvocationAsync(CancellationToken cancellationToken = default)
125+
=> Task.CompletedTask;
126+
127+
public new Task ReportRestoreErrorAsync(Exception exception, String errorType = null, CancellationToken cancellationToken = default)
128+
=> Task.CompletedTask;
129+
#endif
130+
}
131+
}

0 commit comments

Comments
 (0)