Skip to content

Commit b375732

Browse files
authored
Merge pull request #82 from llogiq/aarch64
Add aarch64 SIMD specialization
2 parents fbad8d4 + ffd810a commit b375732

9 files changed

Lines changed: 248 additions & 57 deletions

File tree

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ jobs:
2222
arch:
2323
- i686
2424
- x86_64
25+
- aarch64
2526
features:
2627
- default
2728
- runtime-dispatch-simd

src/integer_simd.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ unsafe fn usize_load_unchecked(bytes: &[u8], offset: usize) -> usize {
1313
ptr::copy_nonoverlapping(
1414
bytes.as_ptr().add(offset),
1515
&mut output as *mut usize as *mut u8,
16-
mem::size_of::<usize>()
16+
mem::size_of::<usize>(),
1717
);
1818
output
1919
}
@@ -65,11 +65,17 @@ pub fn chunk_count(haystack: &[u8], needle: u8) -> usize {
6565
// 8
6666
let mut counts = 0;
6767
for i in 0..(haystack.len() - offset) / chunksize {
68-
counts += bytewise_equal(usize_load_unchecked(haystack, offset + i * chunksize), needles);
68+
counts += bytewise_equal(
69+
usize_load_unchecked(haystack, offset + i * chunksize),
70+
needles,
71+
);
6972
}
7073
if haystack.len() % 8 != 0 {
7174
let mask = usize::from_le(!(!0 >> ((haystack.len() % chunksize) * 8)));
72-
counts += bytewise_equal(usize_load_unchecked(haystack, haystack.len() - chunksize), needles) & mask;
75+
counts += bytewise_equal(
76+
usize_load_unchecked(haystack, haystack.len() - chunksize),
77+
needles,
78+
) & mask;
7379
}
7480
count += sum_usize(counts);
7581

@@ -98,11 +104,15 @@ pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
98104
// 8
99105
let mut counts = 0;
100106
for i in 0..(utf8_chars.len() - offset) / chunksize {
101-
counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset + i * chunksize));
107+
counts +=
108+
is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset + i * chunksize));
102109
}
103110
if utf8_chars.len() % 8 != 0 {
104111
let mask = usize::from_le(!(!0 >> ((utf8_chars.len() % chunksize) * 8)));
105-
counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, utf8_chars.len() - chunksize)) & mask;
112+
counts += is_leading_utf8_byte(usize_load_unchecked(
113+
utf8_chars,
114+
utf8_chars.len() - chunksize,
115+
)) & mask;
106116
}
107117
count += sum_usize(counts);
108118

src/lib.rs

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
//! still on small strings.
3333
3434
#![deny(missing_docs)]
35-
3635
#![cfg_attr(not(feature = "runtime-dispatch-simd"), no_std)]
3736

3837
#[cfg(not(feature = "runtime-dispatch-simd"))]
@@ -45,7 +44,11 @@ pub use naive::*;
4544
mod integer_simd;
4645

4746
#[cfg(any(
48-
all(feature = "runtime-dispatch-simd", any(target_arch = "x86", target_arch = "x86_64")),
47+
all(
48+
feature = "runtime-dispatch-simd",
49+
any(target_arch = "x86", target_arch = "x86_64")
50+
),
51+
target_arch = "aarch64",
4952
feature = "generic-simd"
5053
))]
5154
mod simd;
@@ -64,7 +67,9 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
6467
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
6568
{
6669
if is_x86_feature_detected!("avx2") {
67-
unsafe { return simd::x86_avx2::chunk_count(haystack, needle); }
70+
unsafe {
71+
return simd::x86_avx2::chunk_count(haystack, needle);
72+
}
6873
}
6974
}
7075

@@ -80,7 +85,15 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
8085
))]
8186
{
8287
if is_x86_feature_detected!("sse2") {
83-
unsafe { return simd::x86_sse2::chunk_count(haystack, needle); }
88+
unsafe {
89+
return simd::x86_sse2::chunk_count(haystack, needle);
90+
}
91+
}
92+
}
93+
#[cfg(all(target_arch = "aarch64", not(feature = "generic_simd")))]
94+
{
95+
unsafe {
96+
return simd::aarch64::chunk_count(haystack, needle);
8497
}
8598
}
8699
}
@@ -109,7 +122,9 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
109122
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
110123
{
111124
if is_x86_feature_detected!("avx2") {
112-
unsafe { return simd::x86_avx2::chunk_num_chars(utf8_chars); }
125+
unsafe {
126+
return simd::x86_avx2::chunk_num_chars(utf8_chars);
127+
}
113128
}
114129
}
115130

@@ -125,7 +140,15 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
125140
))]
126141
{
127142
if is_x86_feature_detected!("sse2") {
128-
unsafe { return simd::x86_sse2::chunk_num_chars(utf8_chars); }
143+
unsafe {
144+
return simd::x86_sse2::chunk_num_chars(utf8_chars);
145+
}
146+
}
147+
}
148+
#[cfg(all(target_arch = "aarch64", not(feature = "generic_simd")))]
149+
{
150+
unsafe {
151+
return simd::aarch64::chunk_num_chars(utf8_chars);
129152
}
130153
}
131154
}

src/naive.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ pub fn naive_count_32(haystack: &[u8], needle: u8) -> usize {
2222
/// assert_eq!(number_of_spaces, 6);
2323
/// ```
2424
pub fn naive_count(utf8_chars: &[u8], needle: u8) -> usize {
25-
utf8_chars.iter().fold(0, |n, c| n + (*c == needle) as usize)
25+
utf8_chars
26+
.iter()
27+
.fold(0, |n, c| n + (*c == needle) as usize)
2628
}
2729

2830
/// Count the number of UTF-8 encoded Unicode codepoints in a slice of bytes, simple
@@ -38,5 +40,8 @@ pub fn naive_count(utf8_chars: &[u8], needle: u8) -> usize {
3840
/// assert_eq!(char_count, 4);
3941
/// ```
4042
pub fn naive_num_chars(utf8_chars: &[u8]) -> usize {
41-
utf8_chars.iter().filter(|&&byte| (byte >> 6) != 0b10).count()
43+
utf8_chars
44+
.iter()
45+
.filter(|&&byte| (byte >> 6) != 0b10)
46+
.count()
4247
}

src/simd/aarch64.rs

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
use core::arch::aarch64::{
2+
uint8x16_t, uint8x16x4_t, vaddlvq_u8, vandq_u8, vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4,
3+
vmvnq_u8, vsubq_u8,
4+
};
5+
6+
const MASK: [u8; 32] = [
7+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255,
8+
255, 255, 255, 255, 255, 255, 255,
9+
];
10+
11+
#[target_feature(enable = "neon")]
12+
unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> uint8x16_t {
13+
debug_assert!(
14+
offset + 16 <= slice.len(),
15+
"{} + 16 ≥ {}",
16+
offset,
17+
slice.len()
18+
);
19+
vld1q_u8(slice.as_ptr().add(offset) as *const _) // TODO: does this need to be aligned?
20+
}
21+
22+
#[target_feature(enable = "neon")]
23+
unsafe fn u8x16_x4_from_offset(slice: &[u8], offset: usize) -> uint8x16x4_t {
24+
debug_assert!(
25+
offset + 64 <= slice.len(),
26+
"{} + 64 ≥ {}",
27+
offset,
28+
slice.len()
29+
);
30+
vld1q_u8_x4(slice.as_ptr().add(offset) as *const _)
31+
}
32+
33+
#[target_feature(enable = "neon")]
34+
unsafe fn sum(u8s: uint8x16_t) -> usize {
35+
vaddlvq_u8(u8s) as usize
36+
}
37+
38+
#[target_feature(enable = "neon")]
39+
pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
40+
assert!(haystack.len() >= 16);
41+
42+
let mut offset = 0;
43+
let mut count = 0;
44+
45+
let needles = vdupq_n_u8(needle);
46+
47+
// 16320
48+
while haystack.len() >= offset + 64 * 255 {
49+
let (mut count1, mut count2, mut count3, mut count4) =
50+
(vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
51+
for _ in 0..255 {
52+
let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(haystack, offset);
53+
count1 = vsubq_u8(count1, vceqq_u8(h1, needles));
54+
count2 = vsubq_u8(count2, vceqq_u8(h2, needles));
55+
count3 = vsubq_u8(count3, vceqq_u8(h3, needles));
56+
count4 = vsubq_u8(count4, vceqq_u8(h4, needles));
57+
offset += 64;
58+
}
59+
count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
60+
}
61+
62+
// 64
63+
let (mut count1, mut count2, mut count3, mut count4) =
64+
(vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
65+
for _ in 0..(haystack.len() - offset) / 64 {
66+
let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(haystack, offset);
67+
count1 = vsubq_u8(count1, vceqq_u8(h1, needles));
68+
count2 = vsubq_u8(count2, vceqq_u8(h2, needles));
69+
count3 = vsubq_u8(count3, vceqq_u8(h3, needles));
70+
count4 = vsubq_u8(count4, vceqq_u8(h4, needles));
71+
offset += 64;
72+
}
73+
count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
74+
75+
let mut counts = vdupq_n_u8(0);
76+
// 16
77+
for i in 0..(haystack.len() - offset) / 16 {
78+
counts = vsubq_u8(
79+
counts,
80+
vceqq_u8(u8x16_from_offset(haystack, offset + i * 16), needles),
81+
);
82+
}
83+
if haystack.len() % 16 != 0 {
84+
counts = vsubq_u8(
85+
counts,
86+
vandq_u8(
87+
vceqq_u8(u8x16_from_offset(haystack, haystack.len() - 16), needles),
88+
u8x16_from_offset(&MASK, haystack.len() % 16),
89+
),
90+
);
91+
}
92+
count + sum(counts)
93+
}
94+
95+
#[target_feature(enable = "neon")]
96+
unsafe fn is_leading_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
97+
vmvnq_u8(vceqq_u8(
98+
vandq_u8(u8s, vdupq_n_u8(0b1100_0000)),
99+
vdupq_n_u8(0b1000_0000),
100+
))
101+
}
102+
103+
#[target_feature(enable = "neon")]
104+
pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
105+
assert!(utf8_chars.len() >= 16);
106+
107+
let mut offset = 0;
108+
let mut count = 0;
109+
110+
// 4080
111+
while utf8_chars.len() >= offset + 16 * 255 {
112+
let mut counts = vdupq_n_u8(0);
113+
114+
for _ in 0..255 {
115+
counts = vsubq_u8(
116+
counts,
117+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
118+
);
119+
offset += 16;
120+
}
121+
count += sum(counts);
122+
}
123+
124+
// 2048
125+
if utf8_chars.len() >= offset + 16 * 128 {
126+
let mut counts = vdupq_n_u8(0);
127+
for _ in 0..128 {
128+
counts = vsubq_u8(
129+
counts,
130+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
131+
);
132+
offset += 16;
133+
}
134+
count += sum(counts);
135+
}
136+
137+
// 16
138+
let mut counts = vdupq_n_u8(0);
139+
for i in 0..(utf8_chars.len() - offset) / 16 {
140+
counts = vsubq_u8(
141+
counts,
142+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
143+
);
144+
}
145+
if utf8_chars.len() % 16 != 0 {
146+
counts = vsubq_u8(
147+
counts,
148+
vandq_u8(
149+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
150+
u8x16_from_offset(&MASK, utf8_chars.len() % 16),
151+
),
152+
);
153+
}
154+
count += sum(counts);
155+
156+
count
157+
}

src/simd/generic.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@ use std::mem;
88
use self::packed_simd::{u8x32, u8x64, FromCast};
99

1010
const MASK: [u8; 64] = [
11-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
12-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
13-
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
14-
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
11+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
12+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
13+
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
1514
];
1615

1716
unsafe fn u8x64_from_offset(slice: &[u8], offset: usize) -> u8x64 {
@@ -66,15 +65,17 @@ pub fn chunk_count(haystack: &[u8], needle: u8) -> usize {
6665
// 32
6766
let mut counts = u8x32::splat(0);
6867
for i in 0..(haystack.len() - offset) / 32 {
69-
counts -= u8x32::from_cast(u8x32_from_offset(haystack, offset + i * 32).eq(needles_x32));
68+
counts -=
69+
u8x32::from_cast(u8x32_from_offset(haystack, offset + i * 32).eq(needles_x32));
7070
}
7171
count += sum_x32(&counts);
7272

7373
// Straggler; need to reset counts because prior loop can run 255 times
7474
counts = u8x32::splat(0);
7575
if haystack.len() % 32 != 0 {
76-
counts -= u8x32::from_cast(u8x32_from_offset(haystack, haystack.len() - 32).eq(needles_x32)) &
77-
u8x32_from_offset(&MASK, haystack.len() % 32);
76+
counts -=
77+
u8x32::from_cast(u8x32_from_offset(haystack, haystack.len() - 32).eq(needles_x32))
78+
& u8x32_from_offset(&MASK, haystack.len() % 32);
7879
}
7980
count += sum_x32(&counts);
8081

@@ -127,8 +128,9 @@ pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
127128
// Straggler; need to reset counts because prior loop can run 255 times
128129
counts = u8x32::splat(0);
129130
if utf8_chars.len() % 32 != 0 {
130-
counts -= is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, utf8_chars.len() - 32)) &
131-
u8x32_from_offset(&MASK, utf8_chars.len() % 32);
131+
counts -=
132+
is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, utf8_chars.len() - 32))
133+
& u8x32_from_offset(&MASK, utf8_chars.len() % 32);
132134
}
133135
count += sum_x32(&counts);
134136

src/simd/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@ pub mod x86_sse2;
1515
// Runtime feature detection is not available with no_std.
1616
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
1717
pub mod x86_avx2;
18+
19+
/// Modern ARM machines are also quite capable thanks to NEON
20+
#[cfg(target_arch = "aarch64")]
21+
pub mod aarch64;

0 commit comments

Comments
 (0)