Skip to content

Commit 44cdde5

Browse files
committed
feat(libsql-server): send correct sse response
1 parent f2271b0 commit 44cdde5

1 file changed

Lines changed: 61 additions & 23 deletions

File tree

libsql-server/src/http/user/listen.rs

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,19 @@ use crate::{
66
namespace::{NamespaceName, NamespaceStore},
77
};
88
use axum::extract::State as AxumState;
9-
use axum::http::header::{CACHE_CONTROL, CONTENT_TYPE};
10-
use axum::http::{HeaderValue, Uri};
11-
use axum::response::{IntoResponse, Redirect, Response};
12-
use axum_extra::{extract::Query, json_lines::JsonLines};
9+
use axum::http::Uri;
10+
use axum::response::{
11+
sse::{Event, Sse},
12+
IntoResponse, Redirect,
13+
};
14+
use axum_extra::extract::Query;
1315
use futures::{Stream, StreamExt};
1416
use hyper::HeaderMap;
1517
use serde::{Deserialize, Serialize};
18+
use std::boxed::Box;
19+
use std::convert::Infallible;
20+
use std::pin::Pin;
21+
use std::time::Duration;
1622
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
1723

1824
use super::db_factory::namespace_from_headers;
@@ -33,16 +39,13 @@ pub struct ListenQuery {
3339
action: Option<Vec<Action>>,
3440
}
3541

36-
const EVENT_STREAM: HeaderValue = HeaderValue::from_static("text/event-stream");
37-
const NO_CACHE: HeaderValue = HeaderValue::from_static("no-cache");
38-
3942
pub(super) async fn handle_listen(
4043
auth: Authenticated,
4144
AxumState(state): AxumState<AppState>,
4245
headers: HeaderMap,
4346
uri: Uri,
4447
query: Query<ListenQuery>,
45-
) -> crate::Result<Response> {
48+
) -> crate::Result<impl IntoResponse> {
4649
let namespace = namespace_from_headers(
4750
&headers,
4851
state.disable_default_namespace,
@@ -55,23 +58,24 @@ pub(super) async fn handle_listen(
5558

5659
if let Some(primary_url) = state.primary_url {
5760
let url = primary_url + uri.path_and_query().map_or("", |x| x.as_str());
58-
return Ok(Redirect::temporary(&url).into_response());
61+
return Ok(ListenResponse::Redirect(Redirect::temporary(&url)));
5962
}
6063

61-
let stream = listen_stream(
64+
let stream = sse_stream(
6265
state.namespaces.clone(),
6366
namespace,
6467
query.table.clone(),
6568
query.action.clone(),
6669
)
6770
.await;
6871

69-
let mut response = JsonLines::new(stream).into_response();
70-
let headers = response.headers_mut();
71-
headers.insert(CONTENT_TYPE, EVENT_STREAM);
72-
headers.insert(CACHE_CONTROL, NO_CACHE);
73-
74-
Ok(response)
72+
Ok(ListenResponse::SSE(
73+
Sse::new(stream).keep_alive(
74+
axum::response::sse::KeepAlive::new()
75+
.interval(Duration::from_secs(15))
76+
.text("keep-alive"),
77+
),
78+
))
7579
}
7680

7781
static LAGGED_MSG: &str = "some changes were lost";
@@ -96,6 +100,27 @@ impl Drop for Subscription {
96100
}
97101
}
98102

103+
async fn sse_stream(
104+
store: NamespaceStore,
105+
namespace: NamespaceName,
106+
table: String,
107+
actions: Option<Vec<Action>>,
108+
) -> SseStream {
109+
Box::pin(
110+
listen_stream(store, namespace, table, actions)
111+
.await
112+
.map(|result| {
113+
Ok(match result {
114+
Ok(AggregatorEvent::Error(msg)) => Event::default().event("error").data(msg),
115+
Ok(AggregatorEvent::Changes(msg)) => {
116+
Event::default().event("changes").json_data(msg).unwrap()
117+
}
118+
Err(e) => Event::default().event("error").data(e.to_string()),
119+
})
120+
}),
121+
)
122+
}
123+
99124
async fn listen_stream(
100125
store: NamespaceStore,
101126
namespace: NamespaceName,
@@ -119,7 +144,7 @@ async fn listen_stream(
119144
},
120145
Err(BroadcastStreamRecvError::Lagged(n)) => {
121146
LISTEN_EVENTS_DROPPED.increment(n as u64);
122-
yield AggregatorEvent::Error(&LAGGED_MSG);
147+
yield AggregatorEvent::Error(LAGGED_MSG);
123148
},
124149
}
125150
}
@@ -128,17 +153,30 @@ async fn listen_stream(
128153

129154
fn filter_actions(msg: &BroadcastMsg, actions: &Option<Vec<Action>>) -> bool {
130155
actions.as_ref().map_or(true, |actions| {
131-
for action in actions {
156+
actions.iter().any(|action| {
132157
let count = match action {
133158
Action::DELETE => msg.delete,
134159
Action::INSERT => msg.insert,
135160
Action::UPDATE => msg.update,
136161
Action::UNKNOWN => msg.unknown,
137162
};
138-
if count > 0 {
139-
return true;
140-
}
141-
}
142-
false
163+
count > 0
164+
})
143165
})
144166
}
167+
168+
type SseStream = Pin<Box<dyn Stream<Item = Result<Event, Infallible>> + Send>>;
169+
170+
enum ListenResponse {
171+
SSE(Sse<SseStream>),
172+
Redirect(Redirect),
173+
}
174+
175+
impl IntoResponse for ListenResponse {
176+
fn into_response(self) -> axum::response::Response {
177+
match self {
178+
ListenResponse::SSE(sse) => sse.into_response(),
179+
ListenResponse::Redirect(redirect) => redirect.into_response(),
180+
}
181+
}
182+
}

0 commit comments

Comments
 (0)