@@ -16,7 +16,7 @@ mod storjwt;
1616
1717use axum:: {
1818 body:: Body ,
19- extract:: { ConnectInfo , DefaultBodyLimit , Multipart , Path } ,
19+ extract:: { ConnectInfo , DefaultBodyLimit , Multipart , Path , State } ,
2020 http:: { header, Method , StatusCode } ,
2121 response:: IntoResponse ,
2222 routing:: { get, post} ,
@@ -31,6 +31,8 @@ use tokio::io::AsyncSeekExt;
3131use tokio_util:: io:: ReaderStream ;
3232use toml:: Table ;
3333use tower:: ServiceBuilder ;
34+ use std:: { collections:: HashMap , sync:: Arc } ;
35+ use tokio:: sync:: { RwLock , Semaphore } ;
3436
3537#[ derive( Parser , Debug ) ]
3638#[ command( version, about, long_about = None ) ]
@@ -61,6 +63,25 @@ struct Args {
6163 generate_jwt_token : String ,
6264}
6365
66+
67+ type FileSemaphores = Arc < RwLock < HashMap < String , Arc < Semaphore > > > > ;
68+
69+ #[ derive( Clone ) ]
70+ struct AppState {
71+ file_locks : FileSemaphores ,
72+ }
73+
74+
75+ async fn get_or_create_semaphore (
76+ locks : & FileSemaphores ,
77+ filename : & str ,
78+ ) -> Arc < Semaphore > {
79+ let mut map = locks. write ( ) . await ;
80+ map. entry ( filename. to_string ( ) )
81+ . or_insert_with ( || Arc :: new ( Semaphore :: new ( 1 ) ) )
82+ . clone ( )
83+ }
84+
6485struct ReceivedFile {
6586 original_filename : String ,
6687 cached_filename : String ,
@@ -177,7 +198,9 @@ async fn main() {
177198 tracing_subscriber:: fmt:: init ( ) ;
178199 let tlscfg = initial_setup ( ) . await ;
179200 let port = 3000 ;
180-
201+ let state = AppState {
202+ file_locks : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
203+ } ;
181204 println ! ( "Starting server, tls: {:?}" , tlscfg) ;
182205
183206 // Supported endpoints:
@@ -193,7 +216,8 @@ async fn main() {
193216 . route ( "/upload" , post ( ax_post_file) )
194217 . route ( "/*filepath" , get ( ax_get_file) )
195218 . route ( "/v1/list" , get ( ax_list_files) )
196- . layer ( ServiceBuilder :: new ( ) . layer ( DefaultBodyLimit :: max ( 1024 * 1024 * 1024 * 4 ) ) ) ;
219+ . layer ( ServiceBuilder :: new ( ) . layer ( DefaultBodyLimit :: max ( 1024 * 1024 * 1024 * 4 ) ) )
220+ . with_state ( state) ;
197221
198222 /*
199223 .layer(SecureClientIpSource::ConnectInfo.into_extension())
@@ -313,7 +337,7 @@ fn verify_upload_permissions(owner: &str, path: &str) -> Result<(), String> {
313337 This function will check if the Authorization header is present and if the token is correct
314338 If the token is correct, it will write the content of the file to the server
315339*/
316- async fn ax_post_file ( headers : HeaderMap , mut multipart : Multipart ) -> ( StatusCode , Vec < u8 > ) {
340+ async fn ax_post_file ( headers : HeaderMap , State ( state ) : State < AppState > , mut multipart : Multipart ) -> ( StatusCode , Vec < u8 > ) {
317341 // call check_auth
318342 let message = verify_auth_hdr ( & headers) ;
319343 let owner = match message {
@@ -337,6 +361,7 @@ async fn ax_post_file(headers: HeaderMap, mut multipart: Multipart) -> (StatusCo
337361 let mut path: String = "" . to_string ( ) ;
338362 let mut file0: Vec < u8 > = Vec :: new ( ) ;
339363 let mut file0_filename: String = "" . to_string ( ) ;
364+ let full_path = format ! ( "{}/{}" , path, file0_filename) ;
340365
341366 // verify upload permissions, some users have upload permissions only for certain prefix(path)
342367 // check config.toml for upload_prefixes
@@ -345,6 +370,16 @@ async fn ax_post_file(headers: HeaderMap, mut multipart: Multipart) -> (StatusCo
345370 Err ( e) => return ( StatusCode :: FORBIDDEN , e. to_string ( ) . into_bytes ( ) ) ,
346371 }
347372
373+ let semaphore = get_or_create_semaphore ( & state. file_locks , & full_path) . await ;
374+
375+ // Try to acquire permit - fails immediately if upload in progress
376+ let permit = match semaphore. try_acquire ( ) {
377+ Ok ( permit) => permit,
378+ Err ( _) => {
379+ return ( StatusCode :: CONFLICT , "Upload already in progress" . to_string ( ) . into_bytes ( ) ) ;
380+ }
381+ } ;
382+
348383 while let Some ( field) = multipart. next_field ( ) . await . unwrap ( ) {
349384 let name = field. name ( ) . unwrap ( ) . to_string ( ) ;
350385 //let filename = field.file_name();
@@ -388,7 +423,7 @@ async fn ax_post_file(headers: HeaderMap, mut multipart: Multipart) -> (StatusCo
388423 println ! ( "Removing trailing /, workaround" ) ;
389424 path. pop ( ) ;
390425 }
391- let full_path = format ! ( "{}/{}" , path , file0_filename ) ;
426+
392427 let hdr_content_type = headers. get ( "Content-Type-Upstream" ) ;
393428 let content_type: String = match hdr_content_type {
394429 Some ( content_type) => {
@@ -442,6 +477,7 @@ async fn ax_get_file(
442477 rxheaders : HeaderMap ,
443478 method : Method ,
444479 ConnectInfo ( remote_addr) : ConnectInfo < SocketAddr > ,
480+ State ( state) : State < AppState > ,
445481) -> impl IntoResponse {
446482 let timestamp = std:: time:: SystemTime :: now ( ) ;
447483 let human_time = chrono:: DateTime :: < chrono:: Utc > :: from ( timestamp) ;
@@ -451,6 +487,21 @@ async fn ax_get_file(
451487 None => "" ,
452488 } ;
453489
490+ let semaphore = get_or_create_semaphore ( & state. file_locks , & filepath) . await ;
491+ // Wait for permit with timeout
492+ let _permit = match tokio:: time:: timeout (
493+ tokio:: time:: Duration :: from_secs ( 30 ) ,
494+ semaphore. acquire ( ) ,
495+ ) . await {
496+ Ok ( Ok ( permit) ) => permit,
497+ Ok ( Err ( _) ) => {
498+ return ( StatusCode :: INTERNAL_SERVER_ERROR , "Semaphore closed" ) . into_response ( ) ;
499+ }
500+ Err ( _) => {
501+ return ( StatusCode :: REQUEST_TIMEOUT , "Timeout waiting for upload" ) . into_response ( ) ;
502+ }
503+ } ;
504+
454505 let received_file = driver_get_file ( filepath. clone ( ) ) ;
455506 if !received_file. valid {
456507 println ! (
0 commit comments