Skip to content

Commit be1402f

Browse files
Add --enable-method-chaining for the Rust generator. (#1586)
* Implement optional `--enable-method-chaining` for the Rust generator. * Add `enable_method_chaining` to the `generate` macro * Factor out `should_return_self`
1 parent 9f20dc3 commit be1402f

File tree

9 files changed

+164
-22
lines changed

9 files changed

+164
-22
lines changed

crates/guest-rust/macro/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ impl Parse for Config {
171171
}
172172
opts.async_ = val;
173173
}
174+
Opt::EnableMethodChaining(enable) => {
175+
opts.enable_method_chaining = enable.value();
176+
}
174177
}
175178
}
176179
} else {
@@ -324,6 +327,7 @@ mod kw {
324327
syn::custom_keyword!(disable_custom_section_link_helpers);
325328
syn::custom_keyword!(imports);
326329
syn::custom_keyword!(debug);
330+
syn::custom_keyword!(enable_method_chaining);
327331
}
328332

329333
#[derive(Clone)]
@@ -404,6 +408,7 @@ enum Opt {
404408
DisableCustomSectionLinkHelpers(syn::LitBool),
405409
Async(AsyncFilterSet, Span),
406410
Debug(syn::LitBool),
411+
EnableMethodChaining(syn::LitBool),
407412
}
408413

409414
impl Parse for Opt {
@@ -568,6 +573,10 @@ impl Parse for Opt {
568573
input.parse::<kw::debug>()?;
569574
input.parse::<Token![:]>()?;
570575
Ok(Opt::Debug(input.parse()?))
576+
} else if l.peek(kw::enable_method_chaining) {
577+
input.parse::<kw::enable_method_chaining>()?;
578+
input.parse::<Token![:]>()?;
579+
Ok(Opt::EnableMethodChaining(input.parse()?))
571580
} else if l.peek(Token![async]) {
572581
let span = input.parse::<Token![async]>()?.span;
573582
input.parse::<Token![:]>()?;

crates/guest-rust/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,11 @@ extern crate std;
860860
/// "-export:wasi:http/handler@0.3.0-draft#handle",
861861
/// "all",
862862
/// ],
863+
///
864+
/// // All resource methods with empty returns are instead generated as
865+
/// // returning `-> &Self`, to permit method chaining (e.g. for builder).
866+
/// // This expectation is also imposed on exports.
867+
/// enable_method_chaining: true,
863868
/// });
864869
/// ```
865870
///

crates/rust/src/bindgen.rs

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub(super) struct FunctionBindgen<'a, 'b> {
2121
pub import_return_pointer_area_align: Alignment,
2222
pub handle_decls: Vec<String>,
2323
always_owned: bool,
24+
return_self: bool,
2425
}
2526

2627
pub const POINTER_SIZE_EXPRESSION: &str = "::core::mem::size_of::<*const u8>()";
@@ -31,6 +32,7 @@ impl<'a, 'b> FunctionBindgen<'a, 'b> {
3132
params: Vec<String>,
3233
wasm_import_module: &'b str,
3334
always_owned: bool,
35+
return_self: bool,
3436
) -> FunctionBindgen<'a, 'b> {
3537
FunctionBindgen {
3638
r#gen,
@@ -45,6 +47,7 @@ impl<'a, 'b> FunctionBindgen<'a, 'b> {
4547
import_return_pointer_area_align: Default::default(),
4648
handle_decls: Vec::new(),
4749
always_owned,
50+
return_self,
4851
}
4952
}
5053

@@ -1048,18 +1051,26 @@ impl Bindgen for FunctionBindgen<'_, '_> {
10481051
}
10491052
}
10501053

1051-
Instruction::Return { amt, .. } => match amt {
1052-
0 => {}
1053-
1 => {
1054-
self.push_str(&operands[0]);
1055-
self.push_str("\n");
1056-
}
1057-
_ => {
1058-
self.push_str("(");
1059-
self.push_str(&operands.join(", "));
1060-
self.push_str(")\n");
1054+
Instruction::Return { amt, .. } => {
1055+
assert!(!self.return_self || *amt == 0);
1056+
1057+
match amt {
1058+
0 => {
1059+
if self.return_self {
1060+
self.push_str("self\n");
1061+
}
1062+
}
1063+
1 => {
1064+
self.push_str(&operands[0]);
1065+
self.push_str("\n");
1066+
}
1067+
_ => {
1068+
self.push_str("(");
1069+
self.push_str(&operands.join(", "));
1070+
self.push_str(")\n");
1071+
}
10611072
}
1062-
},
1073+
}
10631074

10641075
Instruction::I32Load { offset } => {
10651076
let tmp = self.tmp();

crates/rust/src/interface.rs

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ pub mod vtable{ordinal} {{
768768
}
769769

770770
fn lower_to_memory(&mut self, address: &str, value: &str, ty: &Type, module: &str) -> String {
771-
let mut f = FunctionBindgen::new(self, Vec::new(), module, true);
771+
let mut f = FunctionBindgen::new(self, Vec::new(), module, true, false);
772772
abi::lower_to_memory(f.r#gen.resolve, &mut f, address.into(), value.into(), ty);
773773
format!("unsafe {{ {} }}", String::from(f.src))
774774
}
@@ -780,7 +780,7 @@ pub mod vtable{ordinal} {{
780780
indirect: bool,
781781
module: &str,
782782
) -> String {
783-
let mut f = FunctionBindgen::new(self, Vec::new(), module, true);
783+
let mut f = FunctionBindgen::new(self, Vec::new(), module, true, false);
784784
abi::deallocate_lists_in_types(f.r#gen.resolve, types, operands, indirect, &mut f);
785785
format!("unsafe {{ {} }}", String::from(f.src))
786786
}
@@ -792,13 +792,13 @@ pub mod vtable{ordinal} {{
792792
indirect: bool,
793793
module: &str,
794794
) -> String {
795-
let mut f = FunctionBindgen::new(self, Vec::new(), module, true);
795+
let mut f = FunctionBindgen::new(self, Vec::new(), module, true, false);
796796
abi::deallocate_lists_and_own_in_types(f.r#gen.resolve, types, operands, indirect, &mut f);
797797
format!("unsafe {{ {} }}", String::from(f.src))
798798
}
799799

800800
fn lift_from_memory(&mut self, address: &str, ty: &Type, module: &str) -> String {
801-
let mut f = FunctionBindgen::new(self, Vec::new(), module, true);
801+
let mut f = FunctionBindgen::new(self, Vec::new(), module, true, false);
802802
let result = abi::lift_from_memory(f.r#gen.resolve, &mut f, address.into(), ty);
803803
format!("unsafe {{ {}\n{result} }}", String::from(f.src))
804804
}
@@ -809,7 +809,13 @@ pub mod vtable{ordinal} {{
809809
func: &Function,
810810
params: Vec<String>,
811811
) {
812-
let mut f = FunctionBindgen::new(self, params, module, false);
812+
let mut f = FunctionBindgen::new(
813+
self,
814+
params,
815+
module,
816+
false,
817+
self.r#gen.should_return_self(func),
818+
);
813819
abi::call(
814820
f.r#gen.resolve,
815821
AbiVariant::GuestImport,
@@ -1032,7 +1038,7 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8)
10321038
}
10331039
lowers.push("ParamsLower(_ptr,)".to_string());
10341040
} else {
1035-
let mut f = FunctionBindgen::new(self, Vec::new(), module, true);
1041+
let mut f = FunctionBindgen::new(self, Vec::new(), module, true, false);
10361042
let mut results = Vec::new();
10371043
for (i, Param { ty, .. }) in func.params.iter().enumerate() {
10381044
let name = format!("_lower{i}");
@@ -1078,8 +1084,13 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8)
10781084
}
10791085
uwriteln!(
10801086
self.src,
1081-
"_MySubtask {{ _unused: core::marker::PhantomData }}.call(({})).await",
1082-
params.join(" ")
1087+
"_MySubtask {{ _unused: core::marker::PhantomData }}.call(({})).await{}",
1088+
params.join(" "),
1089+
if self.r#gen.should_return_self(func) {
1090+
";\nself"
1091+
} else {
1092+
""
1093+
}
10831094
);
10841095
}
10851096

@@ -1121,7 +1132,7 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8)
11211132
);
11221133
}
11231134

1124-
let mut f = FunctionBindgen::new(self, params, self.wasm_import_module, false);
1135+
let mut f = FunctionBindgen::new(self, params, self.wasm_import_module, false, false);
11251136
let variant = if async_ {
11261137
AbiVariant::GuestExportAsync
11271138
} else {
@@ -1191,7 +1202,7 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8)
11911202
let params = self.print_post_return_sig(func);
11921203
self.src.push_str("{ unsafe {\n");
11931204

1194-
let mut f = FunctionBindgen::new(self, params, self.wasm_import_module, false);
1205+
let mut f = FunctionBindgen::new(self, params, self.wasm_import_module, false, false);
11951206
abi::post_return(f.r#gen.resolve, func, &mut f);
11961207
let FunctionBindgen {
11971208
needs_cleanup_list,
@@ -1451,7 +1462,11 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8)
14511462
}
14521463
}
14531464
} else {
1454-
self.print_result_type(&func.result);
1465+
if self.r#gen.should_return_self(func) {
1466+
self.push_str("&Self");
1467+
} else {
1468+
self.print_result_type(&func.result);
1469+
}
14551470
}
14561471
params
14571472
}

crates/rust/src/lib.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,10 @@ pub struct Opts {
300300
arg(long, require_equals = true, value_name = "true|false")
301301
)]
302302
pub merge_structurally_equal_types: Option<Option<bool>>,
303+
304+
/// If true, methods normally returning `()` instead return `&Self`. This applies to both imported and exported methods.
305+
#[cfg_attr(feature = "clap", arg(long))]
306+
pub enable_method_chaining: bool,
303307
}
304308

305309
impl Opts {
@@ -1056,6 +1060,12 @@ macro_rules! __export_{world_name}_impl {{
10561060
.async_
10571061
.is_async(resolve, interface, func, is_import)
10581062
}
1063+
1064+
fn should_return_self(&self, func: &Function) -> bool {
1065+
self.opts.enable_method_chaining
1066+
&& func.result.is_none()
1067+
&& matches!(&func.kind, FunctionKind::Method(_))
1068+
}
10591069
}
10601070

10611071
impl WorldGenerator for RustWasm {

crates/rust/tests/codegen.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,22 @@ mod retyped_list {
216216
}
217217
});
218218
}
219+
220+
#[allow(unused, reason = "testing codegen, not functionality")]
221+
mod method_chaining {
222+
wit_bindgen::generate!({
223+
inline: r#"
224+
package test:method-chaining;
225+
world test {
226+
resource a {
227+
constructor();
228+
set-a: func(arg: u32);
229+
set-b: func(arg: bool);
230+
do: func();
231+
}
232+
}
233+
"#,
234+
generate_all,
235+
enable_method_chaining: true
236+
});
237+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//@ args = '--enable-method-chaining'
2+
3+
include!(env!("BINDINGS"));
4+
5+
use crate::foo::bar::i::A;
6+
7+
struct Component;
8+
export!(Component);
9+
10+
impl Guest for Component {
11+
fn run() {
12+
let my_a = A::new();
13+
my_a.set_a(42).set_b(true).do_();
14+
}
15+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//@ args = '--enable-method-chaining'
2+
3+
include!(env!("BINDINGS"));
4+
5+
use crate::exports::foo::bar::i::{Guest, GuestA};
6+
use std::cell::Cell;
7+
8+
struct Component;
9+
export!(Component);
10+
impl Guest for Component {
11+
type A = MyA;
12+
}
13+
14+
struct MyA {
15+
prop_a: Cell<u32>,
16+
prop_b: Cell<bool>,
17+
}
18+
19+
impl GuestA for MyA {
20+
fn new() -> MyA {
21+
MyA {
22+
prop_a: Cell::new(0),
23+
prop_b: Cell::new(false),
24+
}
25+
}
26+
27+
fn set_a(&self, a: u32) -> &Self {
28+
self.prop_a.set(a);
29+
self
30+
}
31+
32+
fn set_b(&self, b: bool) -> &Self {
33+
self.prop_b.set(b);
34+
self
35+
}
36+
37+
fn do_(&self) -> &Self {
38+
self
39+
}
40+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package foo:bar;
2+
3+
interface i {
4+
resource a {
5+
constructor();
6+
set-a: func(arg: u32);
7+
set-b: func(arg: bool);
8+
do: func();
9+
}
10+
}
11+
world runner {
12+
import i;
13+
14+
export run: func();
15+
}
16+
world test {
17+
export i;
18+
}

0 commit comments

Comments
 (0)