|
20 | 20 |
|
21 | 21 | use anyhow::{Context, Result, anyhow, bail, ensure}; |
22 | 22 | use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; |
23 | | -use common_traits::{ToBytes, UnsignedInt}; |
| 23 | +use common_traits::{AsBytes, FromBytes, ToBytes, UnsignedInt}; |
24 | 24 | use dsi_bitstream::dispatch::Codes; |
| 25 | +use epserde::deser::Deserialize; |
25 | 26 | use epserde::ser::Serialize; |
26 | | -use std::io::{BufWriter, Write}; |
| 27 | +use std::io::{BufRead, BufReader, BufWriter, Read, Write}; |
27 | 28 | use std::path::{Path, PathBuf}; |
28 | 29 | use std::time::Duration; |
29 | 30 | use std::time::SystemTime; |
@@ -280,6 +281,135 @@ impl FloatVectorFormat { |
280 | 281 |
|
281 | 282 | Ok(()) |
282 | 283 | } |
| 284 | + |
| 285 | + /// Loads float values from the specified `path` using the format defined |
| 286 | + /// by `self`. |
| 287 | + pub fn load<F>(&self, path: impl AsRef<Path>) -> Result<Vec<F>> |
| 288 | + where |
| 289 | + F: FromBytes + std::str::FromStr + Copy, |
| 290 | + <F as AsBytes>::Bytes: for<'a> TryFrom<&'a [u8]>, |
| 291 | + <F as std::str::FromStr>::Err: std::error::Error + Send + Sync + 'static, |
| 292 | + Vec<F>: epserde::deser::Deserialize, |
| 293 | + { |
| 294 | + let path = path.as_ref(); |
| 295 | + let path_display = path.display(); |
| 296 | + |
| 297 | + match self { |
| 298 | + FloatVectorFormat::Epserde => { |
| 299 | + log::info!("Loading ε-serde format from {}", path_display); |
| 300 | + Ok(unsafe { |
| 301 | + <Vec<F>>::load_full(path) |
| 302 | + .with_context(|| format!("Could not load vector from {}", path_display))? |
| 303 | + }) |
| 304 | + } |
| 305 | + FloatVectorFormat::Java => { |
| 306 | + log::info!("Loading Java format from {}", path_display); |
| 307 | + let file = std::fs::File::open(path) |
| 308 | + .with_context(|| format!("Could not open {}", path_display))?; |
| 309 | + let file_len = file.metadata()?.len() as usize; |
| 310 | + let byte_size = size_of::<F>(); |
| 311 | + ensure!( |
| 312 | + file_len % byte_size == 0, |
| 313 | + "File size ({}) is not a multiple of {} bytes", |
| 314 | + file_len, |
| 315 | + byte_size |
| 316 | + ); |
| 317 | + let n = file_len / byte_size; |
| 318 | + let mut reader = BufReader::new(file); |
| 319 | + let mut result = Vec::with_capacity(n); |
| 320 | + let mut buf = vec![0u8; byte_size]; |
| 321 | + for i in 0..n { |
| 322 | + reader.read_exact(&mut buf).with_context(|| { |
| 323 | + format!("Could not read value at index {i} from {}", path_display) |
| 324 | + })?; |
| 325 | + let bytes = buf.as_slice().into(); |
| 326 | + result.push(F::from_be_bytes(bytes)); |
| 327 | + } |
| 328 | + Ok(result) |
| 329 | + } |
| 330 | + FloatVectorFormat::Ascii => { |
| 331 | + log::info!("Loading ASCII format from {}", path_display); |
| 332 | + let file = std::fs::File::open(path) |
| 333 | + .with_context(|| format!("Could not open {}", path_display))?; |
| 334 | + let reader = BufReader::new(file); |
| 335 | + reader |
| 336 | + .lines() |
| 337 | + .enumerate() |
| 338 | + .filter(|(_, line)| line.as_ref().map_or(true, |l| !l.trim().is_empty())) |
| 339 | + .map(|(i, line)| { |
| 340 | + let line = line.with_context(|| { |
| 341 | + format!("Error reading line {} of {}", i + 1, path_display) |
| 342 | + })?; |
| 343 | + line.trim().parse::<F>().map_err(|e| { |
| 344 | + anyhow!("Error parsing line {} of {}: {}", i + 1, path_display, e) |
| 345 | + }) |
| 346 | + }) |
| 347 | + .collect() |
| 348 | + } |
| 349 | + FloatVectorFormat::Json => { |
| 350 | + log::info!("Loading JSON format from {}", path_display); |
| 351 | + let file = std::fs::File::open(path) |
| 352 | + .with_context(|| format!("Could not open {}", path_display))?; |
| 353 | + let mut reader = BufReader::new(file); |
| 354 | + let mut result = Vec::new(); |
| 355 | + let mut byte = [0u8; 1]; |
| 356 | + |
| 357 | + // Skip whitespace and opening bracket |
| 358 | + loop { |
| 359 | + reader |
| 360 | + .read_exact(&mut byte) |
| 361 | + .with_context(|| format!("Unexpected end of file in {}", path_display))?; |
| 362 | + match byte[0] { |
| 363 | + b'[' => break, |
| 364 | + b if b.is_ascii_whitespace() => continue, |
| 365 | + _ => bail!("Expected '[' at start of JSON array in {}", path_display), |
| 366 | + } |
| 367 | + } |
| 368 | + |
| 369 | + // Parse comma-separated values until ']' |
| 370 | + let mut token = String::new(); |
| 371 | + let mut index = 0usize; |
| 372 | + loop { |
| 373 | + reader |
| 374 | + .read_exact(&mut byte) |
| 375 | + .with_context(|| format!("Unexpected end of file in {}", path_display))?; |
| 376 | + match byte[0] { |
| 377 | + b']' => { |
| 378 | + let trimmed = token.trim(); |
| 379 | + if !trimmed.is_empty() { |
| 380 | + result.push(trimmed.parse::<F>().map_err(|e| { |
| 381 | + anyhow!( |
| 382 | + "Error parsing element {} of {}: {}", |
| 383 | + index + 1, |
| 384 | + path_display, |
| 385 | + e |
| 386 | + ) |
| 387 | + })?); |
| 388 | + } |
| 389 | + break; |
| 390 | + } |
| 391 | + b',' => { |
| 392 | + let trimmed = token.trim(); |
| 393 | + result.push(trimmed.parse::<F>().map_err(|e| { |
| 394 | + anyhow!( |
| 395 | + "Error parsing element {} of {}: {}", |
| 396 | + index + 1, |
| 397 | + path_display, |
| 398 | + e |
| 399 | + ) |
| 400 | + })?); |
| 401 | + token.clear(); |
| 402 | + index += 1; |
| 403 | + } |
| 404 | + c => { |
| 405 | + token.push(c as char); |
| 406 | + } |
| 407 | + } |
| 408 | + } |
| 409 | + Ok(result) |
| 410 | + } |
| 411 | + } |
| 412 | + } |
283 | 413 | } |
284 | 414 |
|
285 | 415 | #[derive(Debug, Clone, Copy, ValueEnum)] |
|
0 commit comments