Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions gateway/rpc/proto/gateway_rpc.proto
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ service Admin {
rpc GetGlobalConnections(google.protobuf.Empty) returns (GlobalConnectionsStats) {}
// Get all node statuses
rpc GetNodeStatuses(google.protobuf.Empty) returns (GetNodeStatusesResponse) {}
// Force remove an instance from the gateway and KvStore.
rpc ForceRemoveInstance(ForceRemoveInstanceRequest) returns (google.protobuf.Empty) {}

// ==================== DNS Credential Management ====================
// List all DNS credentials
Expand Down Expand Up @@ -608,6 +610,11 @@ message ForceReleaseCertLockRequest {
string domain = 1;
}

// Force remove instance request
message ForceRemoveInstanceRequest {
string instance_id = 1;
}

// Certificate attestation info
message CertAttestationInfo {
// Certificate public key (DER encoded)
Expand Down
21 changes: 16 additions & 5 deletions gateway/src/admin_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use dstack_gateway_rpc::{
admin_server::{AdminRpc, AdminServer},
CertAttestationInfo, CertbotConfigResponse, ClearInstancePortPolicyRequest,
CreateDnsCredentialRequest, DeleteDnsCredentialRequest, DeleteZtDomainRequest,
DnsCredentialInfo, ForceReleaseCertLockRequest, GetDefaultDnsCredentialResponse,
GetDnsCredentialRequest, GetInfoRequest, GetInfoResponse, GetInstanceHandshakesRequest,
GetInstanceHandshakesResponse, GetInstancePortPolicyRequest, GetInstancePortPolicyResponse,
GetMetaResponse, GetNodeStatusesResponse, GetZtDomainRequest, GlobalConnectionsStats,
HandshakeEntry, HostInfo, LastSeenEntry, ListCertAttestationsRequest,
DnsCredentialInfo, ForceReleaseCertLockRequest, ForceRemoveInstanceRequest,
GetDefaultDnsCredentialResponse, GetDnsCredentialRequest, GetInfoRequest, GetInfoResponse,
GetInstanceHandshakesRequest, GetInstanceHandshakesResponse, GetInstancePortPolicyRequest,
GetInstancePortPolicyResponse, GetMetaResponse, GetNodeStatusesResponse, GetZtDomainRequest,
GlobalConnectionsStats, HandshakeEntry, HostInfo, LastSeenEntry, ListCertAttestationsRequest,
ListCertAttestationsResponse, ListDnsCredentialsResponse, ListZtDomainsResponse,
NodeStatusEntry, PeerSyncStatus as ProtoPeerSyncStatus, PortAttrs as RpcPortAttrs,
PortPolicy as RpcPortPolicy, RenewCertResponse, RenewZtDomainCertRequest,
Expand Down Expand Up @@ -292,6 +292,17 @@ impl AdminRpc for AdminRpcHandler {
Ok(GetNodeStatusesResponse { statuses: entries })
}

async fn force_remove_instance(self, request: ForceRemoveInstanceRequest) -> Result<()> {
if request.instance_id.is_empty() {
bail!("instance_id is empty");
}
self.state
.lock()
.force_remove_instance(&request.instance_id)?;
info!("Force removed instance: {}", request.instance_id);
Ok(())
}

// ==================== DNS Credential Management ====================

async fn list_dns_credentials(self) -> Result<ListDnsCredentialsResponse> {
Expand Down
35 changes: 32 additions & 3 deletions gateway/src/main_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,14 @@ impl ProxyState {
None
}

fn ip_claimed_by_other(&self, id: &str, ip: Ipv4Addr) -> Option<String> {
self.state
.instances
.iter()
.find(|(other_id, info)| other_id.as_str() != id && info.ip == ip)
.map(|(other_id, _)| other_id.clone())
}

fn new_client_by_id(
&mut self,
id: &str,
Expand Down Expand Up @@ -909,7 +917,9 @@ impl ProxyState {
existing.port_policy = port_policy.clone();
}
let existing = existing.clone();
if self.valid_ip(existing.ip) {
let valid_ip = self.valid_ip(existing.ip);
let claimed_by = self.ip_claimed_by_other(id, existing.ip);
if valid_ip && claimed_by.is_none() {
// Sync existing instance to KvStore (might be from legacy state)
let data = InstanceData {
app_id: existing.app_id.clone(),
Expand All @@ -925,8 +935,22 @@ impl ProxyState {
}
return Ok(existing);
}
info!("ip {} is invalid, removing", existing.ip);
self.state.allocated_addresses.remove(&existing.ip);
if let Some(claimed_by) = claimed_by {
warn!(
"ip {} for instance {id} is already claimed by instance {claimed_by}, reallocating",
existing.ip
);
} else {
info!("ip {} is invalid, removing", existing.ip);
self.state.allocated_addresses.remove(&existing.ip);
}
self.state.instances.remove(id);
if let Some(app_instances) = self.state.apps.get_mut(&existing.app_id) {
app_instances.remove(id);
if app_instances.is_empty() {
self.state.apps.remove(&existing.app_id);
}
}
}
let ip = self
.alloc_ip()
Expand Down Expand Up @@ -1253,6 +1277,11 @@ impl ProxyState {
Ok(())
}

pub(crate) fn force_remove_instance(&mut self, id: &str) -> Result<()> {
self.remove_instance(id)?;
self.reconfigure()
}

fn recycle(&mut self) -> Result<()> {
// Refresh state: sync local handshakes to KvStore, update local last_seen from global
if let Err(err) = self.refresh_state() {
Expand Down
89 changes: 89 additions & 0 deletions gateway/src/main_service/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ async fn create_test_state() -> TestState {
let mut config = figment.focus("core").extract::<Config>().unwrap();
let temp_dir = TempDir::new().expect("failed to create temp dir");
config.sync.data_dir = temp_dir.path().to_string_lossy().to_string();
config.wg.config_path = temp_dir
.path()
.join("wg.conf")
.to_string_lossy()
.to_string();
let options = ProxyOptions {
config,
my_app_id: None,
Expand Down Expand Up @@ -62,6 +67,37 @@ fn policy(restrict: bool, ports: &[u16]) -> PortPolicy {
}
}

fn insert_instance_with_ip(
proxy: &mut ProxyState,
id: &str,
app_id: &str,
ip: Ipv4Addr,
public_key: &str,
) {
let info = InstanceInfo {
id: id.to_string(),
app_id: app_id.to_string(),
ip,
public_key: public_key.to_string(),
reg_time: SystemTime::now(),
port_policy: None,
port_policy_hash: String::new(),
admin_port_policy: None,
connections: Default::default(),
};
proxy
.state
.apps
.entry(info.app_id.clone())
.or_default()
.insert(info.id.clone());
proxy.state.instances.insert(info.id.clone(), info);
}

fn allowed_ip_count(wg_config: &str, ip: Ipv4Addr) -> usize {
wg_config.matches(&format!("AllowedIPs = {ip}/32")).count()
}

#[tokio::test]
async fn test_port_policy_restrict_mode_allows_listed_only() {
let state = create_test_state().await;
Expand Down Expand Up @@ -264,3 +300,56 @@ async fn test_config() {
let wg_config = state.lock().generate_wg_config().unwrap();
insta::assert_snapshot!(wg_config);
}

#[tokio::test]
async fn test_reregister_reallocates_ip_claimed_by_other_instance() {
let state = create_test_state().await;
let first = state
.lock()
.new_client_by_id("inst-a", "app-a", "pubkey-a", "hash-a", None)
.unwrap();

{
let mut proxy = state.lock();
insert_instance_with_ip(&mut proxy, "inst-b", "app-b", first.ip, "pubkey-b-old");
}

let second = state
.lock()
.new_client_by_id("inst-b", "app-b", "pubkey-b", "hash-b", None)
.unwrap();

assert_ne!(second.ip, first.ip);
let proxy = state.lock();
assert_eq!(proxy.state.instances["inst-a"].ip, first.ip);
assert_eq!(proxy.state.instances["inst-b"].ip, second.ip);
assert!(proxy.state.allocated_addresses.contains(&first.ip));
assert!(proxy.state.allocated_addresses.contains(&second.ip));
}

#[tokio::test]
async fn test_force_remove_instance_resolves_duplicate_allowed_ip() {
let state = create_test_state().await;
let first = state
.lock()
.new_client_by_id("inst-a", "app-a", "pubkey-a", "hash-a", None)
.unwrap();

let mut proxy = state.lock();
insert_instance_with_ip(
&mut proxy,
"inst-stale",
"app-stale",
first.ip,
"pubkey-stale",
);

let wg_config = proxy.generate_wg_config().unwrap();
assert_eq!(allowed_ip_count(&wg_config, first.ip), 2);

proxy.force_remove_instance("inst-stale").unwrap();
let wg_config = proxy.generate_wg_config().unwrap();
assert_eq!(allowed_ip_count(&wg_config, first.ip), 1);
assert!(proxy.state.instances.contains_key("inst-a"));
assert!(!proxy.state.instances.contains_key("inst-stale"));
}
Loading