Skip to content

Commit 10ab8fb

Browse files
committed
implement wasm32 simd
1 parent 150b4aa commit 10ab8fb

3 files changed

Lines changed: 202 additions & 0 deletions

File tree

src/lib.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ mod integer_simd;
4949
any(target_arch = "x86", target_arch = "x86_64")
5050
),
5151
target_arch = "aarch64",
52+
target_arch = "wasm32",
5253
feature = "generic-simd"
5354
))]
5455
mod simd;
@@ -96,6 +97,13 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
9697
return simd::aarch64::chunk_count(haystack, needle);
9798
}
9899
}
100+
101+
#[cfg(target_arch = "wasm32")]
102+
{
103+
unsafe {
104+
return simd::wasm::chunk_count(haystack, needle);
105+
}
106+
}
99107
}
100108

101109
if haystack.len() >= mem::size_of::<usize>() {
@@ -151,6 +159,13 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
151159
return simd::aarch64::chunk_num_chars(utf8_chars);
152160
}
153161
}
162+
163+
#[cfg(target_arch = "wasm32")]
164+
{
165+
unsafe {
166+
return simd::wasm::chunk_num_chars(utf8_chars);
167+
}
168+
}
154169
}
155170

156171
if utf8_chars.len() >= mem::size_of::<usize>() {

src/simd/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ pub mod x86_avx2;
1919
/// Modern ARM machines are also quite capable thanks to NEON
2020
#[cfg(target_arch = "aarch64")]
2121
pub mod aarch64;
22+
23+
#[cfg(target_arch = "wasm32")]
24+
pub mod wasm;

src/simd/wasm.rs

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
use core::arch::wasm32::*;
2+
3+
const MASK: [u8; 32] = [
4+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255,
5+
255, 255, 255, 255, 255, 255, 255,
6+
];
7+
8+
#[target_feature(enable = "simd128")]
9+
unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> v128 {
10+
debug_assert!(
11+
offset + 16 <= slice.len(),
12+
"{} + 16 ≥ {}",
13+
offset,
14+
slice.len()
15+
);
16+
v128_load(slice.as_ptr().add(offset) as *const _)
17+
}
18+
19+
#[target_feature(enable = "simd128")]
20+
unsafe fn u8x16x4_from_offset(slice: &[u8], offset: usize) -> (v128, v128, v128, v128) {
21+
debug_assert!(
22+
offset + 64 <= slice.len(),
23+
"{} + 64 ≥ {}",
24+
offset,
25+
slice.len()
26+
);
27+
(
28+
v128_load(slice.as_ptr().add(offset + 0) as *const _),
29+
v128_load(slice.as_ptr().add(offset + 16) as *const _),
30+
v128_load(slice.as_ptr().add(offset + 32) as *const _),
31+
v128_load(slice.as_ptr().add(offset + 48) as *const _),
32+
)
33+
}
34+
35+
// TODO: We might want to amortize some additions by
36+
// keeping in multiple u16s and u32s respectively for a few ns
37+
#[target_feature(enable = "simd128")]
38+
unsafe fn sum(u8s: v128) -> usize {
39+
let u16s = u16x8_extadd_pairwise_u8x16(u8s);
40+
let u32s = u32x4_extadd_pairwise_u16x8(u16s);
41+
let (u1, u2, u3, u4) = (
42+
u32x4_extract_lane::<1>(u32s),
43+
u32x4_extract_lane::<2>(u32s),
44+
u32x4_extract_lane::<3>(u32s),
45+
u32x4_extract_lane::<4>(u32s),
46+
);
47+
((u1 + u2) + (u3 + u4)) as usize
48+
}
49+
50+
#[target_feature(enable = "simd128")]
51+
unsafe fn sum4(u1: v128, u2: v128, u3: v128, u4: v128) -> usize {
52+
// sum < (2^2 * 2^3 * 2^8 = 2^13) < 2^16, therefore no overflow here
53+
let u16s = u16x8_add(
54+
u16x8_add(u16x8_extadd_pairwise_u8x16(u1), u16x8_extadd_pairwise_u8x16(u2)),
55+
u16x8_add(u16x8_extadd_pairwise_u8x16(u3), u16x8_extadd_pairwise_u8x16(u4)),
56+
);
57+
let u32s = u32x4_extadd_pairwise_u16x8(u16s);
58+
let (u1, u2, u3, u4) = (
59+
u32x4_extract_lane::<1>(u32s),
60+
u32x4_extract_lane::<2>(u32s),
61+
u32x4_extract_lane::<3>(u32s),
62+
u32x4_extract_lane::<4>(u32s),
63+
);
64+
((u1 + u2) + (u3 + u4)) as usize
65+
}
66+
67+
#[target_feature(enable = "simd128")]
68+
pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
69+
let needles = u8x16_splat(needle);
70+
let mut count = 0;
71+
let mut offset = 0;
72+
73+
while haystack.len() >= offset + 16 * 255 {
74+
let (mut count1, mut count2, mut count3, mut count4) =
75+
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));
76+
for _ in 0..255 {
77+
let (h1, h2, h3, h4) = u8x16x4_from_offset(haystack, offset);
78+
count1 = u8x16_sub(count1, u8x16_eq(h1, needles));
79+
count2 = u8x16_sub(count2, u8x16_eq(h2, needles));
80+
count3 = u8x16_sub(count3, u8x16_eq(h3, needles));
81+
count4 = u8x16_sub(count4, u8x16_eq(h4, needles));
82+
offset += 64;
83+
}
84+
count += sum4(count1, count2, count3, count4);
85+
}
86+
87+
// 64
88+
let (mut count1, mut count2, mut count3, mut count4) =
89+
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));
90+
for _ in 0..(haystack.len() - offset) / 64 {
91+
let (h1, h2, h3, h4) = u8x16x4_from_offset(haystack, offset);
92+
count1 = u8x16_sub(count1, u8x16_eq(h1, needles));
93+
count2 = u8x16_sub(count2, u8x16_eq(h2, needles));
94+
count3 = u8x16_sub(count3, u8x16_eq(h3, needles));
95+
count4 = u8x16_sub(count4, u8x16_eq(h4, needles));
96+
offset += 64;
97+
}
98+
count += sum4(count1, count2, count3, count4);
99+
100+
let mut counts = u8x16_splat(0);
101+
// 16
102+
for i in 0..(haystack.len() - offset) / 16 {
103+
counts = u8x16_sub(
104+
counts,
105+
u8x16_eq(u8x16_from_offset(haystack, offset + i * 16), needles),
106+
);
107+
}
108+
if haystack.len() % 16 != 0 {
109+
counts = u8x16_sub(
110+
counts,
111+
v128_and(
112+
u8x16_eq(u8x16_from_offset(haystack, haystack.len() - 16), needles),
113+
u8x16_from_offset(&MASK, haystack.len() % 16),
114+
),
115+
);
116+
}
117+
count + sum(counts)
118+
}
119+
120+
#[target_feature(enable = "simd128")]
121+
unsafe fn is_leading_utf8_byte(u8s: v128) -> v128 {
122+
u8x16_ne(
123+
v128_and(u8s, u8x16_splat(0b1100_0000)),
124+
u8x16_splat(0b1000_0000),
125+
)
126+
}
127+
128+
#[target_feature(enable = "simd128")]
129+
pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
130+
assert!(utf8_chars.len() >= 16);
131+
132+
let mut offset = 0;
133+
let mut count = 0;
134+
135+
// 4080
136+
while utf8_chars.len() >= offset + 64 * 255 {
137+
let (mut count1, mut count2, mut count3, mut count4) =
138+
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));
139+
140+
for _ in 0..255 {
141+
let (h1, h2, h3, h4) = u8x16x4_from_offset(utf8_chars, offset);
142+
count1 = u8x16_sub(count1,is_leading_utf8_byte(h1));
143+
count2 = u8x16_sub(count2,is_leading_utf8_byte(h2));
144+
count3 = u8x16_sub(count3,is_leading_utf8_byte(h3));
145+
count4 = u8x16_sub(count4,is_leading_utf8_byte(h4));
146+
offset += 64;
147+
}
148+
count += sum4(count1, count2, count3, count4);
149+
}
150+
151+
// 4080
152+
let (mut count1, mut count2, mut count3, mut count4) =
153+
(u8x16_splat(0), u8x16_splat(0), u8x16_splat(0), u8x16_splat(0));
154+
for _ in 0..(utf8_chars.len() - offset) / 64 {
155+
let (h1, h2, h3, h4) = u8x16x4_from_offset(utf8_chars, offset);
156+
count1 = u8x16_sub(count1, is_leading_utf8_byte(h1));
157+
count2 = u8x16_sub(count2, is_leading_utf8_byte(h2));
158+
count3 = u8x16_sub(count3, is_leading_utf8_byte(h3));
159+
count4 = u8x16_sub(count4, is_leading_utf8_byte(h4));
160+
offset += 64;
161+
}
162+
count += sum4(count1, count2, count3, count4);
163+
164+
// 16
165+
let mut counts = u8x16_splat(0);
166+
for i in 0..(utf8_chars.len() - offset) / 16 {
167+
counts = u8x16_sub(
168+
counts,
169+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
170+
);
171+
}
172+
if utf8_chars.len() % 16 != 0 {
173+
counts = u8x16_sub(
174+
counts,
175+
v128_and(
176+
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
177+
u8x16_from_offset(&MASK, utf8_chars.len() % 16),
178+
),
179+
);
180+
}
181+
count += sum(counts);
182+
183+
count
184+
}

0 commit comments

Comments
 (0)