Skip to content

Commit 3cae985

Browse files
authored
Host-wasmtime-rust: import functions are able to Trap execution (#388)
* drop `branch = "main"` from wasmtime git dep this is implied since main is the default branch, and means that in other places that depend on the default branch, the wasmtime brought in by these crates is not the same dependency additionally, drop the dep on wasmtime-wasi altogether, it is not used. * wit-bindgen-cli: parse out name from before first `.` in filename to handle `*.wit.md` files correctly. * add a runtime test to see what codegen for result types looks like * track when types appear in the error position, and impl std::error::Error on them * std Error impls plus the From into wit_bindgen_host_wasmtime_rust::Error * limit the special case of errors just to functions with a single result which is a result<a, e> and e is a defined type. This ensures we can generate the std::error::Error impl for those types, otherwise we cant wrap them up into an anyhow::Error. * every import func gets the opportunity to trap * test trapping behavior works as intended * exercise some ways to go about trapping a bit more * fix ci * comment
1 parent 0968716 commit 3cae985

21 files changed

Lines changed: 718 additions & 197 deletions

File tree

crates/bindgen-core/src/lib.rs

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ pub struct Types {
169169
type_info: HashMap<TypeId, TypeInfo>,
170170
}
171171

172-
#[derive(Default, Clone, Copy)]
172+
#[derive(Default, Clone, Copy, Debug)]
173173
pub struct TypeInfo {
174174
/// Whether or not this type is ever used (transitively) within the
175175
/// parameter of a function.
@@ -179,6 +179,10 @@ pub struct TypeInfo {
179179
/// result of a function.
180180
pub result: bool,
181181

182+
/// Whether or not this type is ever used (transitively) within the
183+
/// error case in the result of a function.
184+
pub error: bool,
185+
182186
/// Whether or not this type (transitively) has a list.
183187
pub has_list: bool,
184188
}
@@ -187,6 +191,7 @@ impl std::ops::BitOrAssign for TypeInfo {
187191
fn bitor_assign(&mut self, rhs: Self) {
188192
self.param |= rhs.param;
189193
self.result |= rhs.result;
194+
self.error |= rhs.error;
190195
self.has_list |= rhs.has_list;
191196
}
192197
}
@@ -198,10 +203,10 @@ impl Types {
198203
}
199204
for f in iface.functions.iter() {
200205
for (_, ty) in f.params.iter() {
201-
self.set_param_result_ty(iface, ty, true, false);
206+
self.set_param_result_ty(iface, ty, true, false, false);
202207
}
203208
for ty in f.results.iter_types() {
204-
self.set_param_result_ty(iface, ty, false, true);
209+
self.set_param_result_ty(iface, ty, false, true, false);
205210
}
206211
}
207212
}
@@ -281,56 +286,77 @@ impl Types {
281286
}
282287
}
283288

284-
fn set_param_result_id(&mut self, iface: &Interface, ty: TypeId, param: bool, result: bool) {
289+
fn set_param_result_id(
290+
&mut self,
291+
iface: &Interface,
292+
ty: TypeId,
293+
param: bool,
294+
result: bool,
295+
error: bool,
296+
) {
285297
match &iface.types[ty].kind {
286298
TypeDefKind::Record(r) => {
287299
for field in r.fields.iter() {
288-
self.set_param_result_ty(iface, &field.ty, param, result)
300+
self.set_param_result_ty(iface, &field.ty, param, result, error)
289301
}
290302
}
291303
TypeDefKind::Tuple(t) => {
292304
for ty in t.types.iter() {
293-
self.set_param_result_ty(iface, ty, param, result)
305+
self.set_param_result_ty(iface, ty, param, result, error)
294306
}
295307
}
296308
TypeDefKind::Flags(_) => {}
297309
TypeDefKind::Enum(_) => {}
298310
TypeDefKind::Variant(v) => {
299311
for case in v.cases.iter() {
300-
self.set_param_result_optional_ty(iface, case.ty.as_ref(), param, result)
312+
self.set_param_result_optional_ty(iface, case.ty.as_ref(), param, result, error)
301313
}
302314
}
303315
TypeDefKind::List(ty) | TypeDefKind::Type(ty) | TypeDefKind::Option(ty) => {
304-
self.set_param_result_ty(iface, ty, param, result)
316+
self.set_param_result_ty(iface, ty, param, result, error)
305317
}
306318
TypeDefKind::Result(r) => {
307-
self.set_param_result_optional_ty(iface, r.ok.as_ref(), param, result);
308-
self.set_param_result_optional_ty(iface, r.err.as_ref(), param, result);
319+
self.set_param_result_optional_ty(iface, r.ok.as_ref(), param, result, error);
320+
self.set_param_result_optional_ty(iface, r.err.as_ref(), param, result, result);
309321
}
310322
TypeDefKind::Union(u) => {
311323
for case in u.cases.iter() {
312-
self.set_param_result_ty(iface, &case.ty, param, result)
324+
self.set_param_result_ty(iface, &case.ty, param, result, error)
313325
}
314326
}
315327
TypeDefKind::Future(ty) => {
316-
self.set_param_result_optional_ty(iface, ty.as_ref(), param, result)
328+
self.set_param_result_optional_ty(iface, ty.as_ref(), param, result, error)
317329
}
318330
TypeDefKind::Stream(stream) => {
319-
self.set_param_result_optional_ty(iface, stream.element.as_ref(), param, result);
320-
self.set_param_result_optional_ty(iface, stream.end.as_ref(), param, result);
331+
self.set_param_result_optional_ty(
332+
iface,
333+
stream.element.as_ref(),
334+
param,
335+
result,
336+
error,
337+
);
338+
self.set_param_result_optional_ty(iface, stream.end.as_ref(), param, result, error);
321339
}
322340
}
323341
}
324342

325-
fn set_param_result_ty(&mut self, iface: &Interface, ty: &Type, param: bool, result: bool) {
343+
fn set_param_result_ty(
344+
&mut self,
345+
iface: &Interface,
346+
ty: &Type,
347+
param: bool,
348+
result: bool,
349+
error: bool,
350+
) {
326351
match ty {
327352
Type::Id(id) => {
328353
self.type_id_info(iface, *id);
329354
let info = self.type_info.get_mut(id).unwrap();
330-
if (param && !info.param) || (result && !info.result) {
355+
if (param && !info.param) || (result && !info.result) || (error && !info.error) {
331356
info.param = info.param || param;
332357
info.result = info.result || result;
333-
self.set_param_result_id(iface, *id, param, result);
358+
info.error = info.error || error;
359+
self.set_param_result_id(iface, *id, param, result, error);
334360
}
335361
}
336362
_ => {}
@@ -343,9 +369,10 @@ impl Types {
343369
ty: Option<&Type>,
344370
param: bool,
345371
result: bool,
372+
error: bool,
346373
) {
347374
match ty {
348-
Some(ty) => self.set_param_result_ty(iface, ty, param, result),
375+
Some(ty) => self.set_param_result_ty(iface, ty, param, result, error),
349376
None => (),
350377
}
351378
}

crates/gen-host-wasmtime-rust/src/lib.rs

Lines changed: 94 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ impl WorldGenerator for Wasmtime {
5353
fn import(&mut self, name: &str, iface: &Interface, _files: &mut Files) {
5454
let mut gen = InterfaceGenerator::new(self, iface, TypeMode::Owned);
5555
gen.types();
56+
gen.generate_from_error_impls();
5657
gen.generate_add_to_linker(name);
5758

5859
let snake = name.to_snake_case();
@@ -77,6 +78,7 @@ impl WorldGenerator for Wasmtime {
7778
fn export(&mut self, name: &str, iface: &Interface, _files: &mut Files) {
7879
let mut gen = InterfaceGenerator::new(self, iface, TypeMode::AllBorrowed("'a"));
7980
gen.types();
81+
gen.generate_from_error_impls();
8082

8183
let camel = name.to_upper_camel_case();
8284
uwriteln!(gen.src, "pub struct {camel} {{");
@@ -312,6 +314,27 @@ impl<'a> InterfaceGenerator<'a> {
312314
}
313315
}
314316

317+
fn special_case_host_error(&self, results: &Results) -> Option<&Result_> {
318+
// We only support the wit_bindgen_host_wasmtime_rust::Error case when
319+
// a function has just one result, which is itself a `result<a, e>`, and the
320+
// `e` is *not* a primitive (i.e. defined in std) type.
321+
let mut i = results.iter_types();
322+
if i.len() == 1 {
323+
match i.next().unwrap() {
324+
Type::Id(id) => match &self.iface.types[*id].kind {
325+
TypeDefKind::Result(r) => match r.err {
326+
Some(Type::Id(_)) => Some(&r),
327+
_ => None,
328+
},
329+
_ => None,
330+
},
331+
_ => None,
332+
}
333+
} else {
334+
None
335+
}
336+
}
337+
315338
fn generate_add_to_linker(&mut self, name: &str) {
316339
let camel = name.to_upper_camel_case();
317340

@@ -327,12 +350,34 @@ impl<'a> InterfaceGenerator<'a> {
327350
fnsig.private = true;
328351
fnsig.self_arg = Some("&mut self".to_string());
329352

330-
// These trait method args used to be TypeMode::LeafBorrowed, but wasmtime
331-
// Lift is not impled for borrowed types, so I don't think we can
332-
// support that anymore?
333353
self.print_docs_and_params(func, TypeMode::Owned, &fnsig);
334354
self.push_str(" -> ");
335-
self.print_result_ty(&func.results, TypeMode::Owned);
355+
356+
if let Some(r) = self.special_case_host_error(&func.results).cloned() {
357+
// Functions which have a single result `result<ok,err>` get special
358+
// cased to use the host_wasmtime_rust::Error<err>, making it possible
359+
// for them to trap or use `?` to propogate their errors
360+
self.push_str("wit_bindgen_host_wasmtime_rust::Result<");
361+
if let Some(ok) = r.ok {
362+
self.print_ty(&ok, TypeMode::Owned);
363+
} else {
364+
self.push_str("()");
365+
}
366+
self.push_str(",");
367+
if let Some(err) = r.err {
368+
self.print_ty(&err, TypeMode::Owned);
369+
} else {
370+
self.push_str("()");
371+
}
372+
self.push_str(">");
373+
} else {
374+
// All other functions get their return values wrapped in an anyhow::Result.
375+
// Returning the anyhow::Error case can be used to trap.
376+
self.push_str("anyhow::Result<");
377+
self.print_result_ty(&func.results, TypeMode::Owned);
378+
self.push_str(">");
379+
}
380+
336381
self.push_str(";\n");
337382
}
338383
uwriteln!(self.src, "}}");
@@ -420,10 +465,22 @@ impl<'a> InterfaceGenerator<'a> {
420465
} else {
421466
uwrite!(self.src, ");\n");
422467
}
423-
if func.results.iter_types().len() == 1 {
424-
uwrite!(self.src, "Ok((r,))\n");
468+
469+
if self.special_case_host_error(&func.results).is_some() {
470+
uwrite!(
471+
self.src,
472+
"match r {{
473+
Ok(a) => Ok((Ok(a),)),
474+
Err(e) => match e.downcast() {{
475+
Ok(api_error) => Ok((Err(api_error),)),
476+
Err(anyhow_error) => Err(anyhow_error),
477+
}}
478+
}}"
479+
);
480+
} else if func.results.iter_types().len() == 1 {
481+
uwrite!(self.src, "Ok((r?,))\n");
425482
} else {
426-
uwrite!(self.src, "Ok(r)\n");
483+
uwrite!(self.src, "r\n");
427484
}
428485

429486
if self.gen.opts.async_ {
@@ -553,6 +610,36 @@ impl<'a> InterfaceGenerator<'a> {
553610
// End function body
554611
self.src.push_str("}\n");
555612
}
613+
614+
fn generate_from_error_impls(&mut self) {
615+
for (id, ty) in self.iface.types.iter() {
616+
if ty.name.is_none() {
617+
continue;
618+
}
619+
let info = self.info(id);
620+
if info.error {
621+
for (name, mode) in self.modes_of(id) {
622+
let name = name.to_upper_camel_case();
623+
if self.lifetime_for(&info, mode).is_some() {
624+
continue;
625+
}
626+
self.push_str("impl From<");
627+
self.push_str(&name);
628+
self.push_str("> for wit_bindgen_host_wasmtime_rust::Error<");
629+
self.push_str(&name);
630+
self.push_str("> {\n");
631+
self.push_str("fn from(e: ");
632+
self.push_str(&name);
633+
self.push_str(") -> wit_bindgen_host_wasmtime_rust::Error::< ");
634+
self.push_str(&name);
635+
self.push_str("> {\n");
636+
self.push_str("wit_bindgen_host_wasmtime_rust::Error::new(e)\n");
637+
self.push_str("}\n");
638+
self.push_str("}\n");
639+
}
640+
}
641+
}
642+
}
556643
}
557644

558645
impl<'a> RustGenerator<'a> for InterfaceGenerator<'a> {

crates/gen-host-wasmtime-rust/tests/runtime.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,11 @@ wit_bindgen_host_wasmtime_rust::generate!({
9292
pub struct TestWasi;
9393

9494
impl testwasi::Testwasi for TestWasi {
95-
fn log(&mut self, bytes: Vec<u8>) {
95+
fn log(&mut self, bytes: Vec<u8>) -> Result<()> {
9696
match std::str::from_utf8(&bytes) {
9797
Ok(s) => print!("{}", s),
9898
Err(_) => println!("\nbinary: {:?}", bytes),
9999
}
100+
Ok(())
100101
}
101102
}

crates/gen-rust-lib/src/lib.rs

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,26 @@ pub trait RustGenerator<'a> {
537537
self.push_str(".finish()\n");
538538
self.push_str("}\n");
539539
self.push_str("}\n");
540+
541+
if info.error {
542+
self.push_str("impl");
543+
self.print_generics(lt);
544+
self.push_str(" core::fmt::Display for ");
545+
self.push_str(&name);
546+
self.print_generics(lt);
547+
self.push_str(" {\n");
548+
self.push_str(
549+
"fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n",
550+
);
551+
self.push_str("write!(f, \"{:?}\", self)\n");
552+
self.push_str("}\n");
553+
self.push_str("}\n");
554+
if self.use_std() {
555+
self.push_str("impl std::error::Error for ");
556+
self.push_str(&name);
557+
self.push_str("{}\n");
558+
}
559+
}
540560
}
541561
}
542562

@@ -619,6 +639,7 @@ pub trait RustGenerator<'a> {
619639

620640
for (name, mode) in self.modes_of(id) {
621641
let name = name.to_upper_camel_case();
642+
622643
self.rustdoc(docs);
623644
let lt = self.lifetime_for(&info, mode);
624645
if let Some(derive_component) = derive_component {
@@ -663,6 +684,31 @@ pub trait RustGenerator<'a> {
663684
.into_iter()
664685
.map(|(name, _attr, _docs, ty)| (name, ty)),
665686
);
687+
688+
if info.error {
689+
self.push_str("impl");
690+
self.print_generics(lt);
691+
self.push_str(" core::fmt::Display for ");
692+
self.push_str(&name);
693+
self.print_generics(lt);
694+
self.push_str(" {\n");
695+
self.push_str(
696+
"fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n",
697+
);
698+
self.push_str("write!(f, \"{:?}\", self)");
699+
self.push_str("}\n");
700+
self.push_str("}\n");
701+
self.push_str("\n");
702+
703+
if self.use_std() {
704+
self.push_str("impl");
705+
self.print_generics(lt);
706+
self.push_str(" std::error::Error for ");
707+
self.push_str(&name);
708+
self.print_generics(lt);
709+
self.push_str(" {}\n");
710+
}
711+
}
666712
}
667713
}
668714

@@ -746,8 +792,7 @@ pub trait RustGenerator<'a> {
746792
) where
747793
Self: Sized,
748794
{
749-
// TODO: should this perhaps be an attribute in the wit file?
750-
let is_error = name.contains("errno");
795+
let info = self.info(id);
751796

752797
let name = name.to_upper_camel_case();
753798
self.rustdoc(docs);
@@ -768,7 +813,7 @@ pub trait RustGenerator<'a> {
768813

769814
// Auto-synthesize an implementation of the standard `Error` trait for
770815
// error-looking types based on their name.
771-
if is_error {
816+
if info.error {
772817
self.push_str("impl ");
773818
self.push_str(&name);
774819
self.push_str("{\n");

0 commit comments

Comments
 (0)