Skip to content

Commit 5c13966

Browse files
committed
add aarch64
1 parent fbad8d4 commit 5c13966

3 files changed

Lines changed: 172 additions & 6 deletions

File tree

src/lib.rs

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
//! still on small strings.
3333
3434
#![deny(missing_docs)]
35-
3635
#![cfg_attr(not(feature = "runtime-dispatch-simd"), no_std)]
3736

3837
#[cfg(not(feature = "runtime-dispatch-simd"))]
@@ -45,7 +44,11 @@ pub use naive::*;
4544
mod integer_simd;
4645

4746
#[cfg(any(
48-
all(feature = "runtime-dispatch-simd", any(target_arch = "x86", target_arch = "x86_64")),
47+
all(
48+
feature = "runtime-dispatch-simd",
49+
any(target_arch = "x86", target_arch = "x86_64")
50+
),
51+
target_arch = "aarch64",
4952
feature = "generic-simd"
5053
))]
5154
mod simd;
@@ -64,7 +67,9 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
6467
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
6568
{
6669
if is_x86_feature_detected!("avx2") {
67-
unsafe { return simd::x86_avx2::chunk_count(haystack, needle); }
70+
unsafe {
71+
return simd::x86_avx2::chunk_count(haystack, needle);
72+
}
6873
}
6974
}
7075

@@ -80,7 +85,15 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
8085
))]
8186
{
8287
if is_x86_feature_detected!("sse2") {
83-
unsafe { return simd::x86_sse2::chunk_count(haystack, needle); }
88+
unsafe {
89+
return simd::x86_sse2::chunk_count(haystack, needle);
90+
}
91+
}
92+
}
93+
#[cfg(all(target_arch = "aarch64", not(feature = "generic_simd")))]
94+
{
95+
unsafe {
96+
return simd::aarch64::chunk_count(haystack, needle);
8497
}
8598
}
8699
}
@@ -109,7 +122,9 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
109122
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
110123
{
111124
if is_x86_feature_detected!("avx2") {
112-
unsafe { return simd::x86_avx2::chunk_num_chars(utf8_chars); }
125+
unsafe {
126+
return simd::x86_avx2::chunk_num_chars(utf8_chars);
127+
}
113128
}
114129
}
115130

@@ -125,7 +140,15 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
125140
))]
126141
{
127142
if is_x86_feature_detected!("sse2") {
128-
unsafe { return simd::x86_sse2::chunk_num_chars(utf8_chars); }
143+
unsafe {
144+
return simd::x86_sse2::chunk_num_chars(utf8_chars);
145+
}
146+
}
147+
}
148+
#[cfg(all(target_arch = "aarch64", not(feature = "generic_simd")))]
149+
{
150+
unsafe {
151+
return simd::aarch64::chunk_num_chars(utf8_chars);
129152
}
130153
}
131154
}

src/simd/aarch64.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
use core::arch::aarch64::{
2+
uint8x16_t, vaddlvq_u8, vandq_u8, vceqq_u8, vcgtq_u8, vdupq_n_u8, vld1q_u8, vmvnq_u8, vsubq_u8,
3+
};
4+
5+
const MASK: [u8; 32] = [
6+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255,
7+
255, 255, 255, 255, 255, 255, 255,
8+
];
9+
10+
#[target_feature(enable = "neon")]
11+
unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> uint8x16_t {
12+
vld1q_u8(slice.as_ptr().add(offset) as *const _) // TODO: does this need to be aligned?
13+
}
14+
15+
#[target_feature(enable = "neon")]
16+
unsafe fn sum(u8s: &uint8x16_t) -> usize {
17+
vaddlvq_u8(*u8s) as usize
18+
}
19+
20+
#[target_feature(enable = "neon")]
21+
pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
22+
assert!(haystack.len() >= 16);
23+
24+
let mut offset = 0;
25+
let mut count = 0;
26+
27+
let needles = vdupq_n_u8(needle);
28+
29+
// 4080
30+
while haystack.len() >= offset + 16 * 255 {
31+
let mut counts = vdupq_n_u8(0);
32+
for _ in 0..255 {
33+
counts = vsubq_u8(
34+
counts,
35+
vceqq_u8(u8x16_from_offset(haystack, offset), needles),
36+
);
37+
offset += 16;
38+
}
39+
count += sum(&counts);
40+
}
41+
42+
// 2048
43+
if haystack.len() >= offset + 16 * 128 {
44+
let mut counts = vdupq_n_u8(0);
45+
for _ in 0..128 {
46+
counts = vsubq_u8(
47+
counts,
48+
vceqq_u8(u8x16_from_offset(haystack, offset), needles),
49+
);
50+
offset += 16;
51+
}
52+
count += sum(&counts);
53+
}
54+
55+
// 16
56+
let mut counts = vdupq_n_u8(0);
57+
for i in 0..(haystack.len() - offset) / 16 {
58+
counts = vsubq_u8(
59+
counts,
60+
vcgtq_u8(u8x16_from_offset(haystack, offset + i * 32), needles),
61+
);
62+
}
63+
if haystack.len() % 16 != 0 {
64+
counts = vsubq_u8(
65+
counts,
66+
vandq_u8(
67+
vceqq_u8(u8x16_from_offset(haystack, haystack.len() - 16), needles),
68+
u8x16_from_offset(&MASK, haystack.len() % 16),
69+
),
70+
);
71+
}
72+
count += sum(&counts);
73+
74+
count
75+
}
76+
77+
#[target_feature(enable = "neon")]
78+
unsafe fn is_leading_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
79+
vmvnq_u8(vceqq_u8(
80+
vandq_u8(u8s, vdupq_n_u8(0b1100_0000)),
81+
vdupq_n_u8(0b1000_0000),
82+
))
83+
}
84+
85+
#[target_feature(enable = "neon")]
86+
pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
87+
assert!(utf8_chars.len() >= 16);
88+
89+
let mut offset = 0;
90+
let mut count = 0;
91+
92+
// 4080
93+
while utf8_chars.len() >= offset + 16 * 255 {
94+
let mut counts = vdupq_n_u8(0);
95+
96+
for _ in 0..255 {
97+
counts = vsubq_u8(
98+
counts,
99+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
100+
);
101+
offset += 16;
102+
}
103+
count += sum(&counts);
104+
}
105+
106+
// 2048
107+
if utf8_chars.len() >= offset + 16 * 128 {
108+
let mut counts = vdupq_n_u8(0);
109+
for _ in 0..128 {
110+
counts = vsubq_u8(
111+
counts,
112+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
113+
);
114+
offset += 16;
115+
}
116+
count += sum(&counts);
117+
}
118+
119+
// 16
120+
let mut counts = vdupq_n_u8(0);
121+
for i in 0..(utf8_chars.len() - offset) / 16 {
122+
counts = vsubq_u8(
123+
counts,
124+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 32)),
125+
);
126+
}
127+
if utf8_chars.len() % 16 != 0 {
128+
counts = vsubq_u8(
129+
counts,
130+
vandq_u8(
131+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
132+
u8x16_from_offset(&MASK, utf8_chars.len() % 16),
133+
),
134+
);
135+
}
136+
count += sum(&counts);
137+
138+
count
139+
}

src/simd/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@ pub mod x86_sse2;
1515
// Runtime feature detection is not available with no_std.
1616
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
1717
pub mod x86_avx2;
18+
19+
/// Modern ARM machines are also quite capable thanks to NEON
20+
#[cfg(target_arch = "aarch64")]
21+
pub mod aarch64;

0 commit comments

Comments
 (0)