Skip to content

Commit a317aa8

Browse files
committed
Introduce semaphores to prevent incomplete downloads
We have situations, when file upload is not complete yet, but download trying to fetch it and getting incomplete file. Let's make sure upload is completed, by using semaphores and "locking" url if it is being uploaded. Signed-off-by: Denys Fedoryshchenko <denys.f@collabora.com>
1 parent 1dfcb6e commit a317aa8

1 file changed

Lines changed: 56 additions & 5 deletions

File tree

src/main.rs

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ mod storjwt;
1616

1717
use axum::{
1818
body::Body,
19-
extract::{ConnectInfo, DefaultBodyLimit, Multipart, Path},
19+
extract::{ConnectInfo, DefaultBodyLimit, Multipart, Path, State},
2020
http::{header, Method, StatusCode},
2121
response::IntoResponse,
2222
routing::{get, post},
@@ -31,6 +31,8 @@ use tokio::io::AsyncSeekExt;
3131
use tokio_util::io::ReaderStream;
3232
use toml::Table;
3333
use tower::ServiceBuilder;
34+
use std::{collections::HashMap, sync::Arc};
35+
use tokio::sync::{RwLock, Semaphore};
3436

3537
#[derive(Parser, Debug)]
3638
#[command(version, about, long_about = None)]
@@ -61,6 +63,25 @@ struct Args {
6163
generate_jwt_token: String,
6264
}
6365

66+
67+
type FileSemaphores = Arc<RwLock<HashMap<String, Arc<Semaphore>>>>;
68+
69+
#[derive(Clone)]
70+
struct AppState {
71+
file_locks: FileSemaphores,
72+
}
73+
74+
75+
async fn get_or_create_semaphore(
76+
locks: &FileSemaphores,
77+
filename: &str,
78+
) -> Arc<Semaphore> {
79+
let mut map = locks.write().await;
80+
map.entry(filename.to_string())
81+
.or_insert_with(|| Arc::new(Semaphore::new(1)))
82+
.clone()
83+
}
84+
6485
struct ReceivedFile {
6586
original_filename: String,
6687
cached_filename: String,
@@ -177,7 +198,9 @@ async fn main() {
177198
tracing_subscriber::fmt::init();
178199
let tlscfg = initial_setup().await;
179200
let port = 3000;
180-
201+
let state = AppState {
202+
file_locks: Arc::new(RwLock::new(HashMap::new())),
203+
};
181204
println!("Starting server, tls: {:?}", tlscfg);
182205

183206
// Supported endpoints:
@@ -193,7 +216,8 @@ async fn main() {
193216
.route("/upload", post(ax_post_file))
194217
.route("/*filepath", get(ax_get_file))
195218
.route("/v1/list", get(ax_list_files))
196-
.layer(ServiceBuilder::new().layer(DefaultBodyLimit::max(1024 * 1024 * 1024 * 4)));
219+
.layer(ServiceBuilder::new().layer(DefaultBodyLimit::max(1024 * 1024 * 1024 * 4)))
220+
.with_state(state);
197221

198222
/*
199223
.layer(SecureClientIpSource::ConnectInfo.into_extension())
@@ -313,7 +337,7 @@ fn verify_upload_permissions(owner: &str, path: &str) -> Result<(), String> {
313337
This function will check if the Authorization header is present and if the token is correct
314338
If the token is correct, it will write the content of the file to the server
315339
*/
316-
async fn ax_post_file(headers: HeaderMap, mut multipart: Multipart) -> (StatusCode, Vec<u8>) {
340+
async fn ax_post_file(headers: HeaderMap, State(state): State<AppState>, mut multipart: Multipart) -> (StatusCode, Vec<u8>) {
317341
// call check_auth
318342
let message = verify_auth_hdr(&headers);
319343
let owner = match message {
@@ -337,6 +361,7 @@ async fn ax_post_file(headers: HeaderMap, mut multipart: Multipart) -> (StatusCo
337361
let mut path: String = "".to_string();
338362
let mut file0: Vec<u8> = Vec::new();
339363
let mut file0_filename: String = "".to_string();
364+
let full_path = format!("{}/{}", path, file0_filename);
340365

341366
// verify upload permissions, some users have upload permissions only for certain prefix(path)
342367
// check config.toml for upload_prefixes
@@ -345,6 +370,16 @@ async fn ax_post_file(headers: HeaderMap, mut multipart: Multipart) -> (StatusCo
345370
Err(e) => return (StatusCode::FORBIDDEN, e.to_string().into_bytes()),
346371
}
347372

373+
let semaphore = get_or_create_semaphore(&state.file_locks, &full_path).await;
374+
375+
// Try to acquire permit - fails immediately if upload in progress
376+
let permit = match semaphore.try_acquire() {
377+
Ok(permit) => permit,
378+
Err(_) => {
379+
return (StatusCode::CONFLICT, "Upload already in progress".to_string().into_bytes());
380+
}
381+
};
382+
348383
while let Some(field) = multipart.next_field().await.unwrap() {
349384
let name = field.name().unwrap().to_string();
350385
//let filename = field.file_name();
@@ -388,7 +423,7 @@ async fn ax_post_file(headers: HeaderMap, mut multipart: Multipart) -> (StatusCo
388423
println!("Removing trailing /, workaround");
389424
path.pop();
390425
}
391-
let full_path = format!("{}/{}", path, file0_filename);
426+
392427
let hdr_content_type = headers.get("Content-Type-Upstream");
393428
let content_type: String = match hdr_content_type {
394429
Some(content_type) => {
@@ -442,6 +477,7 @@ async fn ax_get_file(
442477
rxheaders: HeaderMap,
443478
method: Method,
444479
ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
480+
State(state): State<AppState>,
445481
) -> impl IntoResponse {
446482
let timestamp = std::time::SystemTime::now();
447483
let human_time = chrono::DateTime::<chrono::Utc>::from(timestamp);
@@ -451,6 +487,21 @@ async fn ax_get_file(
451487
None => "",
452488
};
453489

490+
let semaphore = get_or_create_semaphore(&state.file_locks, &filepath).await;
491+
// Wait for permit with timeout
492+
let _permit = match tokio::time::timeout(
493+
tokio::time::Duration::from_secs(30),
494+
semaphore.acquire(),
495+
).await {
496+
Ok(Ok(permit)) => permit,
497+
Ok(Err(_)) => {
498+
return (StatusCode::INTERNAL_SERVER_ERROR, "Semaphore closed").into_response();
499+
}
500+
Err(_) => {
501+
return (StatusCode::REQUEST_TIMEOUT, "Timeout waiting for upload").into_response();
502+
}
503+
};
504+
454505
let received_file = driver_get_file(filepath.clone());
455506
if !received_file.valid {
456507
println!(

0 commit comments

Comments
 (0)