diff --git a/gateway/rpc/proto/gateway_rpc.proto b/gateway/rpc/proto/gateway_rpc.proto index 046b731e..d553ff55 100644 --- a/gateway/rpc/proto/gateway_rpc.proto +++ b/gateway/rpc/proto/gateway_rpc.proto @@ -399,6 +399,8 @@ service Admin { rpc GetGlobalConnections(google.protobuf.Empty) returns (GlobalConnectionsStats) {} // Get all node statuses rpc GetNodeStatuses(google.protobuf.Empty) returns (GetNodeStatusesResponse) {} + // Force remove an instance from the gateway and KvStore. + rpc ForceRemoveInstance(ForceRemoveInstanceRequest) returns (google.protobuf.Empty) {} // ==================== DNS Credential Management ==================== // List all DNS credentials @@ -608,6 +610,11 @@ message ForceReleaseCertLockRequest { string domain = 1; } +// Force remove instance request +message ForceRemoveInstanceRequest { + string instance_id = 1; +} + // Certificate attestation info message CertAttestationInfo { // Certificate public key (DER encoded) diff --git a/gateway/src/admin_service.rs b/gateway/src/admin_service.rs index 8a02bdb1..9cf0bb9a 100644 --- a/gateway/src/admin_service.rs +++ b/gateway/src/admin_service.rs @@ -10,11 +10,11 @@ use dstack_gateway_rpc::{ admin_server::{AdminRpc, AdminServer}, CertAttestationInfo, CertbotConfigResponse, ClearInstancePortPolicyRequest, CreateDnsCredentialRequest, DeleteDnsCredentialRequest, DeleteZtDomainRequest, - DnsCredentialInfo, ForceReleaseCertLockRequest, GetDefaultDnsCredentialResponse, - GetDnsCredentialRequest, GetInfoRequest, GetInfoResponse, GetInstanceHandshakesRequest, - GetInstanceHandshakesResponse, GetInstancePortPolicyRequest, GetInstancePortPolicyResponse, - GetMetaResponse, GetNodeStatusesResponse, GetZtDomainRequest, GlobalConnectionsStats, - HandshakeEntry, HostInfo, LastSeenEntry, ListCertAttestationsRequest, + DnsCredentialInfo, ForceReleaseCertLockRequest, ForceRemoveInstanceRequest, + GetDefaultDnsCredentialResponse, GetDnsCredentialRequest, GetInfoRequest, GetInfoResponse, + GetInstanceHandshakesRequest, GetInstanceHandshakesResponse, GetInstancePortPolicyRequest, + GetInstancePortPolicyResponse, GetMetaResponse, GetNodeStatusesResponse, GetZtDomainRequest, + GlobalConnectionsStats, HandshakeEntry, HostInfo, LastSeenEntry, ListCertAttestationsRequest, ListCertAttestationsResponse, ListDnsCredentialsResponse, ListZtDomainsResponse, NodeStatusEntry, PeerSyncStatus as ProtoPeerSyncStatus, PortAttrs as RpcPortAttrs, PortPolicy as RpcPortPolicy, RenewCertResponse, RenewZtDomainCertRequest, @@ -292,6 +292,17 @@ impl AdminRpc for AdminRpcHandler { Ok(GetNodeStatusesResponse { statuses: entries }) } + async fn force_remove_instance(self, request: ForceRemoveInstanceRequest) -> Result<()> { + if request.instance_id.is_empty() { + bail!("instance_id is empty"); + } + self.state + .lock() + .force_remove_instance(&request.instance_id)?; + info!("Force removed instance: {}", request.instance_id); + Ok(()) + } + // ==================== DNS Credential Management ==================== async fn list_dns_credentials(self) -> Result { diff --git a/gateway/src/main_service.rs b/gateway/src/main_service.rs index 120048a6..40f1395b 100644 --- a/gateway/src/main_service.rs +++ b/gateway/src/main_service.rs @@ -866,6 +866,14 @@ impl ProxyState { None } + fn ip_claimed_by_other(&self, id: &str, ip: Ipv4Addr) -> Option { + self.state + .instances + .iter() + .find(|(other_id, info)| other_id.as_str() != id && info.ip == ip) + .map(|(other_id, _)| other_id.clone()) + } + fn new_client_by_id( &mut self, id: &str, @@ -909,7 +917,9 @@ impl ProxyState { existing.port_policy = port_policy.clone(); } let existing = existing.clone(); - if self.valid_ip(existing.ip) { + let valid_ip = self.valid_ip(existing.ip); + let claimed_by = self.ip_claimed_by_other(id, existing.ip); + if valid_ip && claimed_by.is_none() { // Sync existing instance to KvStore (might be from legacy state) let data = InstanceData { app_id: existing.app_id.clone(), @@ -925,8 +935,22 @@ impl ProxyState { } return Ok(existing); } - info!("ip {} is invalid, removing", existing.ip); - self.state.allocated_addresses.remove(&existing.ip); + if let Some(claimed_by) = claimed_by { + warn!( + "ip {} for instance {id} is already claimed by instance {claimed_by}, reallocating", + existing.ip + ); + } else { + info!("ip {} is invalid, removing", existing.ip); + self.state.allocated_addresses.remove(&existing.ip); + } + self.state.instances.remove(id); + if let Some(app_instances) = self.state.apps.get_mut(&existing.app_id) { + app_instances.remove(id); + if app_instances.is_empty() { + self.state.apps.remove(&existing.app_id); + } + } } let ip = self .alloc_ip() @@ -1253,6 +1277,11 @@ impl ProxyState { Ok(()) } + pub(crate) fn force_remove_instance(&mut self, id: &str) -> Result<()> { + self.remove_instance(id)?; + self.reconfigure() + } + fn recycle(&mut self) -> Result<()> { // Refresh state: sync local handshakes to KvStore, update local last_seen from global if let Err(err) = self.refresh_state() { diff --git a/gateway/src/main_service/tests.rs b/gateway/src/main_service/tests.rs index ec7f8783..760c48be 100644 --- a/gateway/src/main_service/tests.rs +++ b/gateway/src/main_service/tests.rs @@ -25,6 +25,11 @@ async fn create_test_state() -> TestState { let mut config = figment.focus("core").extract::().unwrap(); let temp_dir = TempDir::new().expect("failed to create temp dir"); config.sync.data_dir = temp_dir.path().to_string_lossy().to_string(); + config.wg.config_path = temp_dir + .path() + .join("wg.conf") + .to_string_lossy() + .to_string(); let options = ProxyOptions { config, my_app_id: None, @@ -62,6 +67,37 @@ fn policy(restrict: bool, ports: &[u16]) -> PortPolicy { } } +fn insert_instance_with_ip( + proxy: &mut ProxyState, + id: &str, + app_id: &str, + ip: Ipv4Addr, + public_key: &str, +) { + let info = InstanceInfo { + id: id.to_string(), + app_id: app_id.to_string(), + ip, + public_key: public_key.to_string(), + reg_time: SystemTime::now(), + port_policy: None, + port_policy_hash: String::new(), + admin_port_policy: None, + connections: Default::default(), + }; + proxy + .state + .apps + .entry(info.app_id.clone()) + .or_default() + .insert(info.id.clone()); + proxy.state.instances.insert(info.id.clone(), info); +} + +fn allowed_ip_count(wg_config: &str, ip: Ipv4Addr) -> usize { + wg_config.matches(&format!("AllowedIPs = {ip}/32")).count() +} + #[tokio::test] async fn test_port_policy_restrict_mode_allows_listed_only() { let state = create_test_state().await; @@ -264,3 +300,56 @@ async fn test_config() { let wg_config = state.lock().generate_wg_config().unwrap(); insta::assert_snapshot!(wg_config); } + +#[tokio::test] +async fn test_reregister_reallocates_ip_claimed_by_other_instance() { + let state = create_test_state().await; + let first = state + .lock() + .new_client_by_id("inst-a", "app-a", "pubkey-a", "hash-a", None) + .unwrap(); + + { + let mut proxy = state.lock(); + insert_instance_with_ip(&mut proxy, "inst-b", "app-b", first.ip, "pubkey-b-old"); + } + + let second = state + .lock() + .new_client_by_id("inst-b", "app-b", "pubkey-b", "hash-b", None) + .unwrap(); + + assert_ne!(second.ip, first.ip); + let proxy = state.lock(); + assert_eq!(proxy.state.instances["inst-a"].ip, first.ip); + assert_eq!(proxy.state.instances["inst-b"].ip, second.ip); + assert!(proxy.state.allocated_addresses.contains(&first.ip)); + assert!(proxy.state.allocated_addresses.contains(&second.ip)); +} + +#[tokio::test] +async fn test_force_remove_instance_resolves_duplicate_allowed_ip() { + let state = create_test_state().await; + let first = state + .lock() + .new_client_by_id("inst-a", "app-a", "pubkey-a", "hash-a", None) + .unwrap(); + + let mut proxy = state.lock(); + insert_instance_with_ip( + &mut proxy, + "inst-stale", + "app-stale", + first.ip, + "pubkey-stale", + ); + + let wg_config = proxy.generate_wg_config().unwrap(); + assert_eq!(allowed_ip_count(&wg_config, first.ip), 2); + + proxy.force_remove_instance("inst-stale").unwrap(); + let wg_config = proxy.generate_wg_config().unwrap(); + assert_eq!(allowed_ip_count(&wg_config, first.ip), 1); + assert!(proxy.state.instances.contains_key("inst-a")); + assert!(!proxy.state.instances.contains_key("inst-stale")); +}