Faster argmin on floats
Consider the following problem: you are given a dynamically sized array of floating-point numbers, and you are asked to find the index of the smallest one (commonly called argmin) as fast as possible. But there’s a catch: you know that all values in the list are positive or +0, finite, and not NaN.
First solution
The first solution that comes to mind for this, in Rust, is:
let argmin = data.iter().enumerate().min_by(|a, b| a.1.total_cmp(b.1));
For a million numbers, this runs in around 511 us. Not bad. However, we can do better, by using what we know about the data itself.
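Note that the expression above yields an Option<(usize, &f32)> rather than a bare index. A minimal sketch of wrapping it in a function, assuming a slice as input, could look like this; the name argmin_total is a placeholder of mine and does not appear in the benchmark.

```rust
/// Returns the index of the smallest element, or `None` for an empty slice.
/// `argmin_total` is a hypothetical helper name used only for illustration.
fn argmin_total(data: &[f32]) -> Option<usize> {
    data.iter()
        .enumerate()
        // `total_cmp` is a total order on f32, so no unwrap is needed here.
        .min_by(|a, b| a.1.total_cmp(b.1))
        // Drop the reference to the value and keep only the index.
        .map(|(index, _)| index)
}
```

When several elements are equally small, min_by keeps the first of them, so argmin_total(&[0.5, 1.0, 0.5]) is Some(0).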
Second solution
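When optimizing this, my first instinct was to implement our own comparator function, using the natural partial order of floats. This will of course return wrong results for inputs outside our constraints: every comparison involving NaN is false, so a NaN can displace the true minimum. On ties it also keeps the last of the equal elements, whereas min_by keeps the first.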
let argmin = data.iter().enumerate()
    .reduce(|a, b| if a.1 < b.1 { a } else { b });
The runtime here is 489 us for a million numbers.
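To see where this comparator stops being safe, here is a small sketch that runs it on three inputs: one valid, one with a tie, and one containing a NaN. The helper names argmin_reduce and demo_reduce are hypothetical and only serve the illustration.

```rust
/// Hypothetical wrapper around the reduce-based comparator from the text.
fn argmin_reduce(data: &[f32]) -> Option<usize> {
    data.iter()
        .enumerate()
        // `a` is the best candidate so far, `b` is the next element.
        .reduce(|a, b| if a.1 < b.1 { a } else { b })
        .map(|(index, _)| index)
}

/// Runs the comparator on three illustrative inputs.
fn demo_reduce() -> [Option<usize>; 3] {
    [
        argmin_reduce(&[3.0, 0.5, 2.0]),      // valid input
        argmin_reduce(&[0.5, 1.0, 0.5]),      // tie: the later 0.5 wins
        argmin_reduce(&[2.0, f32::NAN, 3.0]), // NaN displaces the real minimum
    ]
}
```

demo_reduce() returns [Some(1), Some(2), Some(2)]. The last result points at 3.0 even though 2.0 is smaller, because both comparisons against NaN are false and the candidate is replaced each time. None of this matters under the constraints of our problem, but it is why the comparator should not be reused on arbitrary data.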
Third solution
A different, easy solution would be to use the partial comparison function from the Rust standard library, partial_cmp, and fall back to Ordering::Equal if the two values cannot be compared (which should never happen in our case).
let argmin = data.iter().enumerate()
    .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal));
This turns out to be slightly faster than our second solution; for our million-number benchmark, this runs in 470 us.
I suspect this is because the compiler can better optimize the code.
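For reference, the only time partial_cmp returns None on f32 is when at least one operand is NaN, which is exactly the case our input rules out. A short sketch, with hypothetical helper names argmin_partial and falls_back, makes the fallback visible:

```rust
use std::cmp::Ordering;

/// Hypothetical wrapper around the partial_cmp-based comparator from the text.
fn argmin_partial(data: &[f32]) -> Option<usize> {
    data.iter()
        .enumerate()
        .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal))
        .map(|(index, _)| index)
}

/// True when the two values have no defined order, i.e. when the
/// `unwrap_or(Ordering::Equal)` fallback above would be taken.
fn falls_back(a: f32, b: f32) -> bool {
    a.partial_cmp(&b).is_none()
}
```

Here falls_back(1.0, f32::NAN) is true, while falls_back(0.0, 1.0) and falls_back(0.0, -0.0) are both false, so on NaN-free data the fallback branch is never reached.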
Fourth solution
We are also told that the list contains only positive numbers.
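Based on this, we can use a very elegant property of the floating-point representation: if you have only positive numbers, you can sort the f32 values as u32. With the sign bit clear, the exponent sits in the higher bits and the mantissa in the lower bits, so a larger float always has a larger bit pattern.
This is also why you can do radix sort on floats.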
let argmin = data.iter().enumerate().min_by_key(|a| a.1.to_bits());
For positive numbers, this is blazingly fast: it takes only 370 us, roughly 28% less time than the 511 us baseline (about 1.4x faster).
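The ordering property is easy to check on a few pairs. The sketch below compares each pair once as f32 and once as raw u32 bit patterns; argmin_bits and same_order are hypothetical helper names, not part of the benchmark.

```rust
/// Hypothetical wrapper around the bit-pattern version from the text.
/// Only valid for non-negative, non-NaN input.
fn argmin_bits(data: &[f32]) -> Option<usize> {
    data.iter()
        .enumerate()
        .min_by_key(|a| a.1.to_bits())
        .map(|(index, _)| index)
}

/// True when comparing two floats as f32 and as u32 bit patterns
/// gives the same ordering.
fn same_order(a: f32, b: f32) -> bool {
    a.partial_cmp(&b) == Some(a.to_bits().cmp(&b.to_bits()))
}
```

For non-negative pairs such as same_order(0.0, 1.0) or same_order(f32::MIN_POSITIVE, 1.0) the result is true. As soon as a negative number is involved it breaks: same_order(-2.0, -1.0) and same_order(-1.0, 1.0) are both false, because the set sign bit makes the bit pattern of a negative float larger than that of any positive one. Accordingly, argmin_bits(&[1.0, -1.0]) returns Some(0), which is wrong for general input but cannot happen under our constraints.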
For more information, I recommend reading the answer in https://stackoverflow.com/a/59349481.
Benchmark program
use criterion::{Criterion, criterion_group, criterion_main};
use std::{cmp::Ordering, hint::black_box};

fn generate_data(n: usize) -> Vec<f32> {
    let half = n / 2;
    let mut vec = Vec::with_capacity(n * 2);
    vec.extend(
        (0..=half)
            .rev()
            .chain(0..=half)
            .map(|x| (2 * x + 1) as f32 / 2.0),
    );
    vec
}

const N: usize = 1_000_000;

fn bench_normal_min(c: &mut Criterion) {
    let data = generate_data(N);
    c.bench_function("normal_min", |b| {
        b.iter(|| {
            let argmin = data.iter().enumerate().min_by(|a, b| a.1.total_cmp(b.1));
            black_box(argmin);
        })
    });
}

fn bench_reduce_min(c: &mut Criterion) {
    let data = generate_data(N);
    c.bench_function("reduce_min", |b| {
        b.iter(|| {
            let min = data
                .iter()
                .enumerate()
                .reduce(|a, b| if a.1 < b.1 { a } else { b });
            black_box(min);
        })
    });
}

fn bench_partial_min(c: &mut Criterion) {
    let data = generate_data(N);
    c.bench_function("partial_unwrap_min", |b| {
        b.iter(|| {
            let argmin = data
                .iter()
                .enumerate()
                .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(Ordering::Equal));
            black_box(argmin);
        })
    });
}

fn bench_u32_min_positive(c: &mut Criterion) {
    let data = generate_data(N);
    c.bench_function("u32_min_positive", |b| {
        b.iter(|| {
            let min = data.iter().enumerate().min_by_key(|a| a.1.to_bits());
            black_box(min);
        })
    });
}

criterion_group!(
    benches,
    bench_normal_min,
    bench_reduce_min,
    bench_partial_min,
    bench_u32_min_positive,
);
criterion_main!(benches);