Skip to content

Commit 2c62e38

Browse files
committed
optimize num_chars for aarch64 more
1 parent 10ab8fb commit 2c62e38

1 file changed

Lines changed: 34 additions & 29 deletions

File tree

src/simd/aarch64.rs

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use core::arch::aarch64::{
22
uint8x16_t, uint8x16x4_t, vaddlvq_u8, vandq_u8, vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4,
3-
vmvnq_u8, vsubq_u8,
3+
vsubq_u8,
44
};
55

66
const MASK: [u8; 32] = [
@@ -35,6 +35,10 @@ unsafe fn sum(u8s: uint8x16_t) -> usize {
3535
vaddlvq_u8(u8s) as usize
3636
}
3737

38+
/// Horizontally adds every lane of four `u8x16` vectors and returns the
/// combined total as a `usize`.
///
/// Each `vaddlvq_u8` call reduces sixteen `u8` lanes into a widened result, so a
/// single call yields at most 16 * 255 = 4080 and the four-way sum at most 16320;
/// the additions are grouped pairwise before the final cast to `usize`.
///
/// # Safety
///
/// Uses NEON intrinsics, so the caller must be running on a CPU with NEON.
// NOTE(review): unlike the neighbouring functions in this diff, no
// `#[target_feature(enable = "neon")]` attribute precedes this function —
// confirm that omission is intentional.
unsafe fn sum4(u1: uint8x16_t, u2: uint8x16_t, u3: uint8x16_t, u4: uint8x16_t) -> usize {
39+
((vaddlvq_u8(u1) + vaddlvq_u8(u2)) + (vaddlvq_u8(u3) + vaddlvq_u8(u4))) as usize
40+
}
41+
3842
#[target_feature(enable = "neon")]
3943
pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
4044
assert!(haystack.len() >= 16);
@@ -56,7 +60,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
5660
count4 = vsubq_u8(count4, vceqq_u8(h4, needles));
5761
offset += 64;
5862
}
59-
count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
63+
count += sum4(count1, count2, count3, count4);
6064
}
6165

6266
// 64
@@ -70,7 +74,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
7074
count4 = vsubq_u8(count4, vceqq_u8(h4, needles));
7175
offset += 64;
7276
}
73-
count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
77+
count += sum4(count1, count2, count3, count4);
7478

7579
let mut counts = vdupq_n_u8(0);
7680
// 16
@@ -93,11 +97,11 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
9397
}
9498

9599
#[target_feature(enable = "neon")]
96-
unsafe fn is_leading_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
97-
vmvnq_u8(vceqq_u8(
100+
/// Builds a per-lane mask selecting UTF-8 continuation ("following") bytes.
///
/// A lane matches when its top two bits are `10`, i.e.
/// `byte & 0b1100_0000 == 0b1000_0000`. The comparison result is returned
/// directly; callers subtract it from a running `u8` counter with `vsubq_u8`,
/// which relies on matching lanes being all-ones so that each match bumps the
/// counter by one.
///
/// This replaces the earlier `is_leading_utf8_byte`, which returned the bitwise
/// negation (`vmvnq_u8`) of the same comparison; the caller now derives the
/// character count as `len - following_bytes` instead.
///
/// # Safety
///
/// Uses NEON intrinsics, so the caller must be running on a CPU with NEON.
unsafe fn is_following_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
101+
vceqq_u8(
98102
vandq_u8(u8s, vdupq_n_u8(0b1100_0000)),
99103
vdupq_n_u8(0b1000_0000),
100-
))
104+
)
101105
}
102106

103107
#[target_feature(enable = "neon")]
@@ -108,50 +112,51 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
108112
let mut count = 0;
109113

110114
// 16320
111-
while utf8_chars.len() >= offset + 16 * 255 {
112-
let mut counts = vdupq_n_u8(0);
115+
while utf8_chars.len() >= offset + 64 * 255 {
116+
let (mut count1, mut count2, mut count3, mut count4) =
117+
(vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
113118

114119
for _ in 0..255 {
115-
counts = vsubq_u8(
116-
counts,
117-
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
118-
);
119-
offset += 16;
120+
let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(utf8_chars, offset);
121+
count1 = vsubq_u8(count1,is_following_utf8_byte(h1));
122+
count2 = vsubq_u8(count2,is_following_utf8_byte(h2));
123+
count3 = vsubq_u8(count3,is_following_utf8_byte(h3));
124+
count4 = vsubq_u8(count4,is_following_utf8_byte(h4));
125+
offset += 64;
120126
}
121-
count += sum(counts);
127+
count += sum4(count1, count2, count3, count4);
122128
}
123129

124-
// 2048
125-
if utf8_chars.len() >= offset + 16 * 128 {
126-
let mut counts = vdupq_n_u8(0);
127-
for _ in 0..128 {
128-
counts = vsubq_u8(
129-
counts,
130-
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
131-
);
132-
offset += 16;
130+
// 64
131+
let (mut count1, mut count2, mut count3, mut count4) =
132+
(vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
133+
for _ in 0..(utf8_chars.len() - offset) / 64 {
134+
let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(utf8_chars, offset);
135+
count1 = vsubq_u8(count1, is_following_utf8_byte(h1));
136+
count2 = vsubq_u8(count2, is_following_utf8_byte(h2));
137+
count3 = vsubq_u8(count3, is_following_utf8_byte(h3));
138+
count4 = vsubq_u8(count4, is_following_utf8_byte(h4));
139+
offset += 64;
133140
}
134-
count += sum(counts);
135-
}
136-
141+
count += sum4(count1, count2, count3, count4);
137142
// 16
138143
let mut counts = vdupq_n_u8(0);
139144
for i in 0..(utf8_chars.len() - offset) / 16 {
140145
counts = vsubq_u8(
141146
counts,
142-
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
147+
is_following_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
143148
);
144149
}
145150
if utf8_chars.len() % 16 != 0 {
146151
counts = vsubq_u8(
147152
counts,
148153
vandq_u8(
149-
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
154+
is_following_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
150155
u8x16_from_offset(&MASK, utf8_chars.len() % 16),
151156
),
152157
);
153158
}
154159
count += sum(counts);
155160

156-
count
161+
utf8_chars.len() - count
157162
}

0 commit comments

Comments
 (0)