11use super :: * ;
22use crate :: util:: Socket ;
33use std:: pin:: Pin ;
4- use std:: sync:: atomic:: { AtomicU32 , Ordering } ;
4+ use std:: sync:: atomic:: { AtomicBool , AtomicU32 , Ordering } ;
55use std:: sync:: Arc ;
66use std:: task:: { Context , Poll } ;
77use tempfile:: tempdir;
88use tokio:: io:: { duplex, AsyncRead , AsyncWrite , DuplexStream } ;
99use tower:: Service ;
10+ use std:: time:: Duration ;
1011
1112#[ tokio:: test]
1213async fn test_sync_context_push_frame ( ) {
@@ -131,6 +132,50 @@ async fn test_sync_context_corrupted_metadata() {
131132 assert_eq ! ( sync_ctx. generation( ) , 1 ) ;
132133}
133134
135+ #[ tokio:: test]
136+ async fn test_sync_context_retry_on_error ( ) {
137+ // Pause time to control it manually
138+ tokio:: time:: pause ( ) ;
139+
140+ let server = MockServer :: start ( ) ;
141+ let temp_dir = tempdir ( ) . unwrap ( ) ;
142+ let db_path = temp_dir. path ( ) . join ( "test.db" ) ;
143+
144+ let sync_ctx = SyncContext :: new (
145+ server. connector ( ) ,
146+ db_path. to_str ( ) . unwrap ( ) . to_string ( ) ,
147+ server. url ( ) ,
148+ None ,
149+ )
150+ . await
151+ . unwrap ( ) ;
152+
153+ let mut sync_ctx = sync_ctx;
154+ let frame = Bytes :: from ( "test frame data" ) ;
155+
156+ // Set server to return errors
157+ server. return_error . store ( true , Ordering :: SeqCst ) ;
158+
159+ // First attempt should fail but retry
160+ let result = sync_ctx. push_one_frame ( frame. clone ( ) , 1 , 0 ) . await ;
161+ assert ! ( result. is_err( ) ) ;
162+
163+ // Advance time to trigger retries faster
164+ tokio:: time:: advance ( Duration :: from_secs ( 2 ) ) . await ;
165+
166+ // Verify multiple requests were made (retries occurred)
167+ assert ! ( server. request_count( ) > 1 ) ;
168+
169+ // Allow the server to succeed
170+ server. return_error . store ( false , Ordering :: SeqCst ) ;
171+
172+ // Next attempt should succeed
173+ let durable_frame = sync_ctx. push_one_frame ( frame, 1 , 0 ) . await . unwrap ( ) ;
174+ sync_ctx. write_metadata ( ) . await . unwrap ( ) ;
175+ assert_eq ! ( durable_frame, 1 ) ;
176+ assert_eq ! ( server. frame_count( ) , 1 ) ;
177+ }
178+
134179#[ test]
135180fn test_hash_verification ( ) {
136181 let mut metadata = MetadataJson {
@@ -212,11 +257,15 @@ struct MockServer {
212257 url : String ,
213258 frame_count : Arc < AtomicU32 > ,
214259 connector : ConnectorService ,
260+ return_error : Arc < AtomicBool > ,
261+ request_count : Arc < AtomicU32 > ,
215262}
216263
217264impl MockServer {
218265 fn start ( ) -> Self {
219266 let frame_count = Arc :: new ( AtomicU32 :: new ( 0 ) ) ;
267+ let return_error = Arc :: new ( AtomicBool :: new ( false ) ) ;
268+ let request_count = Arc :: new ( AtomicU32 :: new ( 0 ) ) ;
220269
221270 // Create the mock connector with Some(client_stream)
222271 let ( tx, mut rx) = tokio:: sync:: mpsc:: channel ( 1 ) ;
@@ -227,23 +276,43 @@ impl MockServer {
227276 url : "http://mock.server" . to_string ( ) ,
228277 frame_count : frame_count. clone ( ) ,
229278 connector,
279+ return_error : return_error. clone ( ) ,
280+ request_count : request_count. clone ( ) ,
230281 } ;
231282
232283 // Spawn the server handler
233284 let frame_count_clone = frame_count. clone ( ) ;
285+ let return_error_clone = return_error. clone ( ) ;
286+ let request_count_clone = request_count. clone ( ) ;
234287
235288 tokio:: spawn ( async move {
236289 while let Some ( server_stream) = rx. recv ( ) . await {
237290 let frame_count_clone = frame_count_clone. clone ( ) ;
291+ let return_error_clone = return_error_clone. clone ( ) ;
292+ let request_count_clone = request_count_clone. clone ( ) ;
238293
239294 tokio:: spawn ( async move {
240295 use hyper:: server:: conn:: Http ;
241296 use hyper:: service:: service_fn;
242297
243298 let frame_count_clone = frame_count_clone. clone ( ) ;
299+ let return_error_clone = return_error_clone. clone ( ) ;
300+ let request_count_clone = request_count_clone. clone ( ) ;
244301 let service = service_fn ( move |req : http:: Request < Body > | {
245302 let frame_count = frame_count_clone. clone ( ) ;
303+ let return_error = return_error_clone. clone ( ) ;
304+ let request_count = request_count_clone. clone ( ) ;
246305 async move {
306+ request_count. fetch_add ( 1 , Ordering :: SeqCst ) ;
307+ if return_error. load ( Ordering :: SeqCst ) {
308+ return Ok :: < _ , hyper:: Error > (
309+ http:: Response :: builder ( )
310+ . status ( 500 )
311+ . body ( Body :: from ( "Internal Server Error" ) )
312+ . unwrap ( ) ,
313+ ) ;
314+ }
315+
247316 let current_count = frame_count. fetch_add ( 1 , Ordering :: SeqCst ) ;
248317
249318 if req. uri ( ) . path ( ) . contains ( "/sync/" ) {
@@ -287,6 +356,10 @@ impl MockServer {
287356 fn frame_count ( & self ) -> u32 {
288357 self . frame_count . load ( Ordering :: SeqCst )
289358 }
359+
360+ fn request_count ( & self ) -> u32 {
361+ self . request_count . load ( Ordering :: SeqCst )
362+ }
290363}
291364
292365// Mock connection that implements the Socket trait
0 commit comments