Skip to content

Commit 2d41959

Browse files
authored
Merge pull request #88 from llogiq/wasm
implement wasm32 simd, optimize aarch64 num_chars, bump version
2 parents 150b4aa + d5d3acc commit 2d41959

6 files changed

Lines changed: 238 additions & 31 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ authors = ["Andre Bogus <bogusandre@gmail.de>", "Joshua Landau <joshua@landau.ws
33
description = "count occurrences of a given byte, or the number of UTF-8 code points, in a byte slice, fast"
44
edition = "2018"
55
name = "bytecount"
6-
version = "0.6.4"
6+
version = "0.6.5"
77
license = "Apache-2.0/MIT"
88
repository = "https://github.com/llogiq/bytecount"
99
categories = ["algorithms", "no-std"]

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The [newlinebench](https://github.com/llogiq/newlinebench) repository has furthe
1212

1313
To use bytecount in your crate, if you have [cargo-edit](https://github.com/killercup/cargo-edit), just type
1414
`cargo add bytecount` in a terminal with the crate root as the current path. Otherwise you can manually edit your
15-
`Cargo.toml` to add `bytecount = 0.6.4` to your `[dependencies]` section.
15+
`Cargo.toml` to add `bytecount = "0.6.5"` to your `[dependencies]` section.
1616

1717
In your crate root (`lib.rs` or `main.rs`, depending on if you are writing a
1818
library or application), add `extern crate bytecount;`. Now you can simply use

src/lib.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ mod integer_simd;
4949
any(target_arch = "x86", target_arch = "x86_64")
5050
),
5151
target_arch = "aarch64",
52+
target_arch = "wasm32",
5253
feature = "generic-simd"
5354
))]
5455
mod simd;
@@ -96,6 +97,13 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
9697
return simd::aarch64::chunk_count(haystack, needle);
9798
}
9899
}
100+
101+
#[cfg(target_arch = "wasm32")]
102+
{
103+
unsafe {
104+
return simd::wasm::chunk_count(haystack, needle);
105+
}
106+
}
99107
}
100108

101109
if haystack.len() >= mem::size_of::<usize>() {
@@ -151,6 +159,13 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
151159
return simd::aarch64::chunk_num_chars(utf8_chars);
152160
}
153161
}
162+
163+
#[cfg(target_arch = "wasm32")]
164+
{
165+
unsafe {
166+
return simd::wasm::chunk_num_chars(utf8_chars);
167+
}
168+
}
154169
}
155170

156171
if utf8_chars.len() >= mem::size_of::<usize>() {

src/simd/aarch64.rs

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use core::arch::aarch64::{
22
uint8x16_t, uint8x16x4_t, vaddlvq_u8, vandq_u8, vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4,
3-
vmvnq_u8, vsubq_u8,
3+
vsubq_u8,
44
};
55

66
const MASK: [u8; 32] = [
@@ -35,6 +35,10 @@ unsafe fn sum(u8s: uint8x16_t) -> usize {
3535
vaddlvq_u8(u8s) as usize
3636
}
3737

38+
unsafe fn sum4(u1: uint8x16_t, u2: uint8x16_t, u3: uint8x16_t, u4: uint8x16_t) -> usize {
39+
((vaddlvq_u8(u1) + vaddlvq_u8(u2)) + (vaddlvq_u8(u3) + vaddlvq_u8(u4))) as usize
40+
}
41+
3842
#[target_feature(enable = "neon")]
3943
pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
4044
assert!(haystack.len() >= 16);
@@ -56,7 +60,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
5660
count4 = vsubq_u8(count4, vceqq_u8(h4, needles));
5761
offset += 64;
5862
}
59-
count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
63+
count += sum4(count1, count2, count3, count4);
6064
}
6165

6266
// 64
@@ -70,7 +74,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
7074
count4 = vsubq_u8(count4, vceqq_u8(h4, needles));
7175
offset += 64;
7276
}
73-
count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
77+
count += sum4(count1, count2, count3, count4);
7478

7579
let mut counts = vdupq_n_u8(0);
7680
// 16
@@ -93,11 +97,11 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
9397
}
9498

9599
#[target_feature(enable = "neon")]
96-
unsafe fn is_leading_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
97-
vmvnq_u8(vceqq_u8(
100+
unsafe fn is_following_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
101+
vceqq_u8(
98102
vandq_u8(u8s, vdupq_n_u8(0b1100_0000)),
99103
vdupq_n_u8(0b1000_0000),
100-
))
104+
)
101105
}
102106

103107
#[target_feature(enable = "neon")]
@@ -108,50 +112,51 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
108112
let mut count = 0;
109113

110114
// 4080
111-
while utf8_chars.len() >= offset + 16 * 255 {
112-
let mut counts = vdupq_n_u8(0);
115+
while utf8_chars.len() >= offset + 64 * 255 {
116+
let (mut count1, mut count2, mut count3, mut count4) =
117+
(vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
113118

114119
for _ in 0..255 {
115-
counts = vsubq_u8(
116-
counts,
117-
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
118-
);
119-
offset += 16;
120+
let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(utf8_chars, offset);
121+
count1 = vsubq_u8(count1,is_following_utf8_byte(h1));
122+
count2 = vsubq_u8(count2,is_following_utf8_byte(h2));
123+
count3 = vsubq_u8(count3,is_following_utf8_byte(h3));
124+
count4 = vsubq_u8(count4,is_following_utf8_byte(h4));
125+
offset += 64;
120126
}
121-
count += sum(counts);
127+
count += sum4(count1, count2, count3, count4);
122128
}
123129

124-
// 2048
125-
if utf8_chars.len() >= offset + 16 * 128 {
126-
let mut counts = vdupq_n_u8(0);
127-
for _ in 0..128 {
128-
counts = vsubq_u8(
129-
counts,
130-
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
131-
);
132-
offset += 16;
130+
// 4080
131+
let (mut count1, mut count2, mut count3, mut count4) =
132+
(vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
133+
for _ in 0..(utf8_chars.len() - offset) / 64 {
134+
let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(utf8_chars, offset);
135+
count1 = vsubq_u8(count1, is_following_utf8_byte(h1));
136+
count2 = vsubq_u8(count2, is_following_utf8_byte(h2));
137+
count3 = vsubq_u8(count3, is_following_utf8_byte(h3));
138+
count4 = vsubq_u8(count4, is_following_utf8_byte(h4));
139+
offset += 64;
133140
}
134-
count += sum(counts);
135-
}
136-
141+
count += sum4(count1, count2, count3, count4);
137142
// 16
138143
let mut counts = vdupq_n_u8(0);
139144
for i in 0..(utf8_chars.len() - offset) / 16 {
140145
counts = vsubq_u8(
141146
counts,
142-
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
147+
is_following_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
143148
);
144149
}
145150
if utf8_chars.len() % 16 != 0 {
146151
counts = vsubq_u8(
147152
counts,
148153
vandq_u8(
149-
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
154+
is_following_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
150155
u8x16_from_offset(&MASK, utf8_chars.len() % 16),
151156
),
152157
);
153158
}
154159
count += sum(counts);
155160

156-
count
161+
utf8_chars.len() - count
157162
}

src/simd/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ pub mod x86_avx2;
1919
/// Modern ARM machines are also quite capable thanks to NEON
2020
#[cfg(target_arch = "aarch64")]
2121
pub mod aarch64;
22+
23+
#[cfg(target_arch = "wasm32")]
24+
pub mod wasm;

src/simd/wasm.rs

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
use core::arch::wasm32::*;
2+
3+
/// Sliding lane mask: sixteen `0x00` bytes followed by sixteen `0xFF` bytes.
///
/// A 16-byte load starting at index `r` (with `0 < r < 16`) yields a vector
/// whose last `r` lanes are all-ones and whose other lanes are zero. The
/// counting kernels AND it with the result of their final, overlapping load
/// so that only the last `len % 16` bytes contribute.
const MASK: [u8; 32] = [
    // lanes that are masked out
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    // lanes that are kept
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
];
7+
8+
#[target_feature(enable = "simd128")]
9+
unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> v128 {
10+
debug_assert!(
11+
offset + 16 <= slice.len(),
12+
"{} + 16 ≥ {}",
13+
offset,
14+
slice.len()
15+
);
16+
v128_load(slice.as_ptr().add(offset) as *const _)
17+
}
18+
19+
#[target_feature(enable = "simd128")]
20+
unsafe fn u8x16x4_from_offset(slice: &[u8], offset: usize) -> (v128, v128, v128, v128) {
21+
debug_assert!(
22+
offset + 64 <= slice.len(),
23+
"{} + 64 ≥ {}",
24+
offset,
25+
slice.len()
26+
);
27+
(
28+
v128_load(slice.as_ptr().add(offset + 0) as *const _),
29+
v128_load(slice.as_ptr().add(offset + 16) as *const _),
30+
v128_load(slice.as_ptr().add(offset + 32) as *const _),
31+
v128_load(slice.as_ptr().add(offset + 48) as *const _),
32+
)
33+
}
34+
35+
// TODO: We might want to amortize some additions by
36+
// keeping in multiple u16s and u32s respectively for a few ns
37+
#[target_feature(enable = "simd128")]
38+
unsafe fn sum(u8s: v128) -> usize {
39+
let u16s = u16x8_extadd_pairwise_u8x16(u8s);
40+
let u32s = u32x4_extadd_pairwise_u16x8(u16s);
41+
let (u1, u2, u3, u4) = (
42+
u32x4_extract_lane::<1>(u32s),
43+
u32x4_extract_lane::<2>(u32s),
44+
u32x4_extract_lane::<3>(u32s),
45+
u32x4_extract_lane::<4>(u32s),
46+
);
47+
((u1 + u2) + (u3 + u4)) as usize
48+
}
49+
50+
#[target_feature(enable = "simd128")]
51+
unsafe fn sum4(u1: v128, u2: v128, u3: v128, u4: v128) -> usize {
52+
// sum < (2^2 * 2^3 * 2^8 = 2^13) < 2^16, therefore no overflow here
53+
let u16s = u16x8_add(
54+
u16x8_add(u16x8_extadd_pairwise_u8x16(u1), u16x8_extadd_pairwise_u8x16(u2)),
55+
u16x8_add(u16x8_extadd_pairwise_u8x16(u3), u16x8_extadd_pairwise_u8x16(u4)),
56+
);
57+
let u32s = u32x4_extadd_pairwise_u16x8(u16s);
58+
let (u1, u2, u3, u4) = (
59+
u32x4_extract_lane::<1>(u32s),
60+
u32x4_extract_lane::<2>(u32s),
61+
u32x4_extract_lane::<3>(u32s),
62+
u32x4_extract_lane::<4>(u32s),
63+
);
64+
((u1 + u2) + (u3 + u4)) as usize
65+
}
66+
67+
#[target_feature(enable = "simd128")]
68+
pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
69+
let needles = u8x16_splat(needle);
70+
let mut count = 0;
71+
let mut offset = 0;
72+
73+
while haystack.len() >= offset + 16 * 255 {
74+
let (mut count1, mut count2, mut count3, mut count4) =
75+
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));
76+
for _ in 0..255 {
77+
let (h1, h2, h3, h4) = u8x16x4_from_offset(haystack, offset);
78+
count1 = u8x16_sub(count1, u8x16_eq(h1, needles));
79+
count2 = u8x16_sub(count2, u8x16_eq(h2, needles));
80+
count3 = u8x16_sub(count3, u8x16_eq(h3, needles));
81+
count4 = u8x16_sub(count4, u8x16_eq(h4, needles));
82+
offset += 64;
83+
}
84+
count += sum4(count1, count2, count3, count4);
85+
}
86+
87+
// 64
88+
let (mut count1, mut count2, mut count3, mut count4) =
89+
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));
90+
for _ in 0..(haystack.len() - offset) / 64 {
91+
let (h1, h2, h3, h4) = u8x16x4_from_offset(haystack, offset);
92+
count1 = u8x16_sub(count1, u8x16_eq(h1, needles));
93+
count2 = u8x16_sub(count2, u8x16_eq(h2, needles));
94+
count3 = u8x16_sub(count3, u8x16_eq(h3, needles));
95+
count4 = u8x16_sub(count4, u8x16_eq(h4, needles));
96+
offset += 64;
97+
}
98+
count += sum4(count1, count2, count3, count4);
99+
100+
let mut counts = u8x16_splat(0);
101+
// 16
102+
for i in 0..(haystack.len() - offset) / 16 {
103+
counts = u8x16_sub(
104+
counts,
105+
u8x16_eq(u8x16_from_offset(haystack, offset + i * 16), needles),
106+
);
107+
}
108+
if haystack.len() % 16 != 0 {
109+
counts = u8x16_sub(
110+
counts,
111+
v128_and(
112+
u8x16_eq(u8x16_from_offset(haystack, haystack.len() - 16), needles),
113+
u8x16_from_offset(&MASK, haystack.len() % 16),
114+
),
115+
);
116+
}
117+
count + sum(counts)
118+
}
119+
120+
#[target_feature(enable = "simd128")]
121+
unsafe fn is_leading_utf8_byte(u8s: v128) -> v128 {
122+
u8x16_ne(
123+
v128_and(u8s, u8x16_splat(0b1100_0000)),
124+
u8x16_splat(0b1000_0000),
125+
)
126+
}
127+
128+
#[target_feature(enable = "simd128")]
129+
pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
130+
assert!(utf8_chars.len() >= 16);
131+
132+
let mut offset = 0;
133+
let mut count = 0;
134+
135+
// 4080
136+
while utf8_chars.len() >= offset + 64 * 255 {
137+
let (mut count1, mut count2, mut count3, mut count4) =
138+
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));
139+
140+
for _ in 0..255 {
141+
let (h1, h2, h3, h4) = u8x16x4_from_offset(utf8_chars, offset);
142+
count1 = u8x16_sub(count1,is_leading_utf8_byte(h1));
143+
count2 = u8x16_sub(count2,is_leading_utf8_byte(h2));
144+
count3 = u8x16_sub(count3,is_leading_utf8_byte(h3));
145+
count4 = u8x16_sub(count4,is_leading_utf8_byte(h4));
146+
offset += 64;
147+
}
148+
count += sum4(count1, count2, count3, count4);
149+
}
150+
151+
// 4080
152+
let (mut count1, mut count2, mut count3, mut count4) =
153+
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));
154+
for _ in 0..(utf8_chars.len() - offset) / 64 {
155+
let (h1, h2, h3, h4) = u8x16x4_from_offset(utf8_chars, offset);
156+
count1 = u8x16_sub(count1, is_leading_utf8_byte(h1));
157+
count2 = u8x16_sub(count2, is_leading_utf8_byte(h2));
158+
count3 = u8x16_sub(count3, is_leading_utf8_byte(h3));
159+
count4 = u8x16_sub(count4, is_leading_utf8_byte(h4));
160+
offset += 64;
161+
}
162+
count += sum4(count1, count2, count3, count4);
163+
164+
// 16
165+
let mut counts = u8x16_splat(0);
166+
for i in 0..(utf8_chars.len() - offset) / 16 {
167+
counts = u8x16_sub(
168+
counts,
169+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
170+
);
171+
}
172+
if utf8_chars.len() % 16 != 0 {
173+
counts = u8x16_sub(
174+
counts,
175+
v128_and(
176+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
177+
u8x16_from_offset(&MASK, utf8_chars.len() % 16),
178+
),
179+
);
180+
}
181+
count += sum(counts);
182+
183+
count
184+
}

0 commit comments

Comments
 (0)