Skip to content

Commit 8e97eec

Browse files
committed
Generic support for loading vectors, too
1 parent 0a74db7 commit 8e97eec

2 files changed

Lines changed: 143 additions & 30 deletions

File tree

cli/src/lib.rs

Lines changed: 132 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,10 +20,11 @@
2020

2121
use anyhow::{Context, Result, anyhow, bail, ensure};
2222
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
23-
use common_traits::{ToBytes, UnsignedInt};
23+
use common_traits::{AsBytes, FromBytes, ToBytes, UnsignedInt};
2424
use dsi_bitstream::dispatch::Codes;
25+
use epserde::deser::Deserialize;
2526
use epserde::ser::Serialize;
26-
use std::io::{BufWriter, Write};
27+
use std::io::{BufRead, BufReader, BufWriter, Read, Write};
2728
use std::path::{Path, PathBuf};
2829
use std::time::Duration;
2930
use std::time::SystemTime;
@@ -280,6 +281,135 @@ impl FloatVectorFormat {
280281

281282
Ok(())
282283
}
284+
285+
/// Loads float values from the specified `path` using the format defined
286+
/// by `self`.
287+
pub fn load<F>(&self, path: impl AsRef<Path>) -> Result<Vec<F>>
288+
where
289+
F: FromBytes + std::str::FromStr + Copy,
290+
<F as AsBytes>::Bytes: for<'a> TryFrom<&'a [u8]>,
291+
<F as std::str::FromStr>::Err: std::error::Error + Send + Sync + 'static,
292+
Vec<F>: epserde::deser::Deserialize,
293+
{
294+
let path = path.as_ref();
295+
let path_display = path.display();
296+
297+
match self {
298+
FloatVectorFormat::Epserde => {
299+
log::info!("Loading ε-serde format from {}", path_display);
300+
Ok(unsafe {
301+
<Vec<F>>::load_full(path)
302+
.with_context(|| format!("Could not load vector from {}", path_display))?
303+
})
304+
}
305+
FloatVectorFormat::Java => {
306+
log::info!("Loading Java format from {}", path_display);
307+
let file = std::fs::File::open(path)
308+
.with_context(|| format!("Could not open {}", path_display))?;
309+
let file_len = file.metadata()?.len() as usize;
310+
let byte_size = size_of::<F>();
311+
ensure!(
312+
file_len % byte_size == 0,
313+
"File size ({}) is not a multiple of {} bytes",
314+
file_len,
315+
byte_size
316+
);
317+
let n = file_len / byte_size;
318+
let mut reader = BufReader::new(file);
319+
let mut result = Vec::with_capacity(n);
320+
let mut buf = vec![0u8; byte_size];
321+
for i in 0..n {
322+
reader.read_exact(&mut buf).with_context(|| {
323+
format!("Could not read value at index {i} from {}", path_display)
324+
})?;
325+
let bytes = buf.as_slice().into();
326+
result.push(F::from_be_bytes(bytes));
327+
}
328+
Ok(result)
329+
}
330+
FloatVectorFormat::Ascii => {
331+
log::info!("Loading ASCII format from {}", path_display);
332+
let file = std::fs::File::open(path)
333+
.with_context(|| format!("Could not open {}", path_display))?;
334+
let reader = BufReader::new(file);
335+
reader
336+
.lines()
337+
.enumerate()
338+
.filter(|(_, line)| line.as_ref().map_or(true, |l| !l.trim().is_empty()))
339+
.map(|(i, line)| {
340+
let line = line.with_context(|| {
341+
format!("Error reading line {} of {}", i + 1, path_display)
342+
})?;
343+
line.trim().parse::<F>().map_err(|e| {
344+
anyhow!("Error parsing line {} of {}: {}", i + 1, path_display, e)
345+
})
346+
})
347+
.collect()
348+
}
349+
FloatVectorFormat::Json => {
350+
log::info!("Loading JSON format from {}", path_display);
351+
let file = std::fs::File::open(path)
352+
.with_context(|| format!("Could not open {}", path_display))?;
353+
let mut reader = BufReader::new(file);
354+
let mut result = Vec::new();
355+
let mut byte = [0u8; 1];
356+
357+
// Skip whitespace and opening bracket
358+
loop {
359+
reader
360+
.read_exact(&mut byte)
361+
.with_context(|| format!("Unexpected end of file in {}", path_display))?;
362+
match byte[0] {
363+
b'[' => break,
364+
b if b.is_ascii_whitespace() => continue,
365+
_ => bail!("Expected '[' at start of JSON array in {}", path_display),
366+
}
367+
}
368+
369+
// Parse comma-separated values until ']'
370+
let mut token = String::new();
371+
let mut index = 0usize;
372+
loop {
373+
reader
374+
.read_exact(&mut byte)
375+
.with_context(|| format!("Unexpected end of file in {}", path_display))?;
376+
match byte[0] {
377+
b']' => {
378+
let trimmed = token.trim();
379+
if !trimmed.is_empty() {
380+
result.push(trimmed.parse::<F>().map_err(|e| {
381+
anyhow!(
382+
"Error parsing element {} of {}: {}",
383+
index + 1,
384+
path_display,
385+
e
386+
)
387+
})?);
388+
}
389+
break;
390+
}
391+
b',' => {
392+
let trimmed = token.trim();
393+
result.push(trimmed.parse::<F>().map_err(|e| {
394+
anyhow!(
395+
"Error parsing element {} of {}: {}",
396+
index + 1,
397+
path_display,
398+
e
399+
)
400+
})?);
401+
token.clear();
402+
index += 1;
403+
}
404+
c => {
405+
token.push(c as char);
406+
}
407+
}
408+
}
409+
Ok(result)
410+
}
411+
}
412+
}
283413
}
284414

285415
#[derive(Debug, Clone, Copy, ValueEnum)]

cli/src/rank/pagerank.rs

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
*/
66

77
use crate::{FloatVectorFormat, GlobalArgs, GranularityArgs, NumThreadsArg, get_thread_pool};
8-
use anyhow::{Context, Result, ensure};
8+
use anyhow::{Result, ensure};
99
use clap::Parser;
1010
use dsi_bitstream::prelude::*;
1111
use dsi_progress_logger::{ProgressLog, concurrent_progress_logger, progress_logger};
1212
use predicates::prelude::*;
13-
use std::io::BufRead;
1413
use std::path::PathBuf;
1514
use webgraph::graphs::bvgraph::get_endianness;
1615
use webgraph::prelude::BvGraph;
@@ -67,9 +66,13 @@ pub struct CliArgs {
6766
pub threshold: f64,
6867

6968
#[arg(short, long)]
70-
/// Path to a preference (personalization) vector (one f64 per line).
69+
/// Path to a preference (personalization) vector.
7170
pub preference: Option<PathBuf>,
7271

72+
#[arg(long, value_enum, default_value_t = FloatVectorFormat::Ascii)]
73+
/// The input format for the preference vector.
74+
pub preference_fmt: FloatVectorFormat,
75+
7376
#[arg(short, long, value_enum, default_value_t = CliMode::StronglyPreferential)]
7477
/// The PageRank mode.
7578
pub mode: CliMode,
@@ -127,7 +130,11 @@ pub fn pagerank<E: Endianness>(global_args: GlobalArgs, args: CliArgs) -> Result
127130
);
128131
let transpose = BvGraph::with_basename(&args.transpose).load()?;
129132

130-
let preference = args.preference.as_ref().map(load_f64_vector).transpose()?;
133+
let preference: Option<Vec<f64>> = args
134+
.preference
135+
.as_ref()
136+
.map(|path| args.preference_fmt.load(path))
137+
.transpose()?;
131138

132139
// Build stopping predicate
133140
let mut predicate = L1Norm::try_from(args.threshold)?.boxed();
@@ -156,27 +163,3 @@ pub fn pagerank<E: Endianness>(global_args: GlobalArgs, args: CliArgs) -> Result
156163

157164
Ok(())
158165
}
159-
160-
/// Reads a text file containing one `f64` per line.
161-
fn load_f64_vector(path: impl AsRef<std::path::Path>) -> Result<Vec<f64>> {
162-
let path = path.as_ref();
163-
let file = std::fs::File::open(path)
164-
.with_context(|| format!("Could not open vector file {}", path.display()))?;
165-
let reader = std::io::BufReader::new(file);
166-
reader
167-
.lines()
168-
.enumerate()
169-
.map(|(i, line)| {
170-
let line = line
171-
.with_context(|| format!("Error reading line {} of {}", i + 1, path.display()))?;
172-
line.trim().parse::<f64>().with_context(|| {
173-
format!(
174-
"Error parsing line {} of {}: {:?}",
175-
i + 1,
176-
path.display(),
177-
line
178-
)
179-
})
180-
})
181-
.collect()
182-
}

0 commit comments

Comments
 (0)