Skip to content

Commit 7b5c1c6

Browse files
authored
feat(go): add map type support (#1583)
* feat(go): add map type support to Go backend * refactor(go): combine duplicate GuestDeallocate match arms
1 parent 4992ae6 commit 7b5c1c6

File tree

4 files changed

+337
-5
lines changed

4 files changed

+337
-5
lines changed

crates/go/src/lib.rs

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,11 @@ impl Go {
395395
TypeDefKind::Type(ty) => {
396396
self.type_name(resolve, *ty, local, in_import, imports)
397397
}
398+
TypeDefKind::Map(key, value) => {
399+
let key = self.type_name(resolve, *key, local, in_import, imports);
400+
let value = self.type_name(resolve, *value, local, in_import, imports);
401+
format!("map[{key}]{value}")
402+
}
398403
_ => todo!("{:?}", ty.kind),
399404
}
400405
}
@@ -696,6 +701,13 @@ func Lift{upper_kind}{camel}(handle int32) *witTypes.{upper_kind}Reader[{payload
696701
.unwrap_or_else(|| "unit".into())
697702
)
698703
}
704+
TypeDefKind::Map(key, value) => {
705+
format!(
706+
"map_{}_{}",
707+
self.mangle_name(resolve, *key, local),
708+
self.mangle_name(resolve, *value, local)
709+
)
710+
}
699711
kind => todo!("{kind:?}"),
700712
}
701713
}
@@ -2546,7 +2558,62 @@ lifters = append(lifters, func() {{
25462558
let handle = &operands[0];
25472559
results.push(format!("{package}LiftStream{camel}({handle})"));
25482560
}
2549-
Instruction::GuestDeallocate { .. } => {
2561+
Instruction::MapLower { key, value, .. } => {
2562+
self.need_unsafe = true;
2563+
self.need_pinner = true;
2564+
self.imports.insert(remote_pkg("runtime"));
2565+
let (body, _) = self.blocks.pop().unwrap();
2566+
let value_op = &operands[0];
2567+
let src_map = self.locals.tmp("srcMap");
2568+
let result = self.locals.tmp("result");
2569+
let length = self.locals.tmp("length");
2570+
let idx = self.locals.tmp("idx");
2571+
let entry = self.generator.sizes.record([*key, *value]);
2572+
let size = entry.size.format(POINTER_SIZE_EXPRESSION);
2573+
let align = entry.align.format(POINTER_SIZE_EXPRESSION);
2574+
uwriteln!(
2575+
self.src,
2576+
"{src_map} := {value_op}
2577+
{length} := uint32(len({src_map}))
2578+
{result} := witRuntime.Allocate({PINNER}, uintptr({length} * {size}), {align})
2579+
var {idx} int
2580+
for map_key, map_value := range {src_map} {{
2581+
{ITER_BASE_POINTER} := unsafe.Add({result}, {idx} * {size})
2582+
{body}
2583+
{idx}++
2584+
}}
2585+
"
2586+
);
2587+
results.push(format!("uintptr({result})"));
2588+
results.push(length);
2589+
}
2590+
Instruction::MapLift { key, value, .. } => {
2591+
self.need_unsafe = true;
2592+
let (body, body_results) = self.blocks.pop().unwrap();
2593+
let pointer = &operands[0];
2594+
let length = &operands[1];
2595+
let result = self.locals.tmp("result");
2596+
let entry = self.generator.sizes.record([*key, *value]);
2597+
let size = entry.size.format(POINTER_SIZE_EXPRESSION);
2598+
let key_type = self.type_name(resolve, **key);
2599+
let value_type = self.type_name(resolve, **value);
2600+
let body_key = &body_results[0];
2601+
let body_value = &body_results[1];
2602+
uwriteln!(
2603+
self.src,
2604+
"{result} := make(map[{key_type}]{value_type}, {length})
2605+
for index := 0; index < int({length}); index++ {{
2606+
{ITER_BASE_POINTER} := unsafe.Add(unsafe.Pointer({pointer}), index * {size})
2607+
{body}
2608+
{result}[{body_key}] = {body_value}
2609+
}}
2610+
"
2611+
);
2612+
results.push(result);
2613+
}
2614+
Instruction::IterMapKey { .. } => results.push("map_key".into()),
2615+
Instruction::IterMapValue { .. } => results.push("map_value".into()),
2616+
Instruction::GuestDeallocateMap { .. } | Instruction::GuestDeallocate { .. } => {
25502617
// Nothing to do here; should be handled when calling `pinner.Unpin()`
25512618
}
25522619
_ => unimplemented!("{instruction:?}"),
@@ -2980,8 +3047,12 @@ const (
29803047
uwriteln!(self.src, "{docs}type {name} = [{size}]{ty}");
29813048
}
29823049

2983-
fn type_map(&mut self, _id: TypeId, _name: &str, _key: &Type, _value: &Type, _docs: &Docs) {
2984-
todo!("map types are not yet supported in the Go backend")
3050+
fn type_map(&mut self, _id: TypeId, name: &str, key: &Type, value: &Type, docs: &Docs) {
3051+
let name = name.to_upper_camel_case();
3052+
let key = self.type_name(self.resolve, *key);
3053+
let value = self.type_name(self.resolve, *value);
3054+
let docs = format_docs(docs);
3055+
uwriteln!(self.src, "{docs}type {name} = map[{key}]{value}");
29853056
}
29863057

29873058
fn type_builtin(&mut self, id: TypeId, name: &str, ty: &Type, docs: &Docs) {
@@ -3146,6 +3217,9 @@ fn any(resolve: &Resolve, ty: Type, fun: &dyn Fn(Type) -> bool) -> bool {
31463217
TypeDefKind::Future(ty) | TypeDefKind::Stream(ty) => {
31473218
ty.map(|ty| any(resolve, ty, fun)).unwrap_or(false)
31483219
}
3220+
TypeDefKind::Map(key, value) => {
3221+
any(resolve, *key, fun) || any(resolve, *value, fun)
3222+
}
31493223
_ => todo!("{:?}", ty.kind),
31503224
}
31513225
}

crates/test/src/go.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ impl LanguageMethods for Go {
3030
config.error_context
3131
|| name == "async-trait-function.wit"
3232
|| name == "named-fixed-length-list.wit"
33-
|| name == "map.wit"
3433
}
3534

3635
fn default_bindgen_args_for_codegen(&self) -> &[&str] {
@@ -162,7 +161,7 @@ fn replace_bindings_go_mod(runner: &Runner, bindings_dir: &Path) -> Result<()> {
162161
super::write_if_different(
163162
&bindings_dir.join("go.mod"),
164163
format!(
165-
"module wit_component\n\ngo 1.25\n\nreplace go.bytecodealliance.org => {}",
164+
"module wit_component\n\ngo 1.25\n\nreplace go.bytecodealliance.org/pkg => {}",
166165
go_package_path.display()
167166
),
168167
)?;

tests/runtime/map/runner.go

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
//@ wasmtime-flags = '-Wcomponent-model-map'
2+
3+
package export_wit_world
4+
5+
import (
6+
"fmt"
7+
test "wit_component/test_maps_to_test"
8+
9+
. "go.bytecodealliance.org/pkg/wit/types"
10+
)
11+
12+
func Run() {
13+
testNamedRoundtrip()
14+
testBytesRoundtrip()
15+
testEmptyRoundtrip()
16+
testOptionRoundtrip()
17+
testRecordRoundtrip()
18+
testInlineRoundtrip()
19+
testLargeRoundtrip()
20+
testMultiParamRoundtrip()
21+
testNestedRoundtrip()
22+
testVariantRoundtrip()
23+
testResultRoundtrip()
24+
testTupleRoundtrip()
25+
testSingleEntryRoundtrip()
26+
}
27+
28+
func testNamedRoundtrip() {
29+
input := test.NamesById{
30+
1: "uno",
31+
2: "two",
32+
}
33+
result := test.NamedRoundtrip(input)
34+
assertEqual(result["uno"], uint32(1))
35+
assertEqual(result["two"], uint32(2))
36+
}
37+
38+
func testBytesRoundtrip() {
39+
input := test.BytesByName{
40+
"hello": []uint8("world"),
41+
"bin": {0, 1, 2},
42+
}
43+
result := test.BytesRoundtrip(input)
44+
assertSliceEqual(result["hello"], []uint8("world"))
45+
assertSliceEqual(result["bin"], []uint8{0, 1, 2})
46+
}
47+
48+
func testEmptyRoundtrip() {
49+
input := test.NamesById{}
50+
result := test.EmptyRoundtrip(input)
51+
assertEqual(len(result), 0)
52+
}
53+
54+
func testOptionRoundtrip() {
55+
input := map[string]Option[uint32]{
56+
"some": Some[uint32](42),
57+
"none": None[uint32](),
58+
}
59+
result := test.OptionRoundtrip(input)
60+
assertEqual(len(result), 2)
61+
assertEqual(result["some"].Some(), uint32(42))
62+
assertEqual(result["none"].Tag(), OptionNone)
63+
}
64+
65+
func testRecordRoundtrip() {
66+
entry := test.LabeledEntry{
67+
Label: "test-label",
68+
Values: test.NamesById{
69+
10: "ten",
70+
20: "twenty",
71+
},
72+
}
73+
result := test.RecordRoundtrip(entry)
74+
assertEqual(result.Label, "test-label")
75+
assertEqual(len(result.Values), 2)
76+
assertEqual(result.Values[10], "ten")
77+
assertEqual(result.Values[20], "twenty")
78+
}
79+
80+
func testInlineRoundtrip() {
81+
input := map[uint32]string{
82+
1: "one",
83+
2: "two",
84+
}
85+
result := test.InlineRoundtrip(input)
86+
assertEqual(len(result), 2)
87+
assertEqual(result["one"], uint32(1))
88+
assertEqual(result["two"], uint32(2))
89+
}
90+
91+
func testLargeRoundtrip() {
92+
input := make(test.NamesById)
93+
for i := uint32(0); i < 100; i++ {
94+
input[i] = fmt.Sprintf("value-%d", i)
95+
}
96+
result := test.LargeRoundtrip(input)
97+
assertEqual(len(result), 100)
98+
for i := uint32(0); i < 100; i++ {
99+
assertEqual(result[i], fmt.Sprintf("value-%d", i))
100+
}
101+
}
102+
103+
func testMultiParamRoundtrip() {
104+
names := test.NamesById{
105+
1: "one",
106+
2: "two",
107+
}
108+
bytes := test.BytesByName{
109+
"key": {42},
110+
}
111+
ids, bytesOut := test.MultiParamRoundtrip(names, bytes)
112+
assertEqual(len(ids), 2)
113+
assertEqual(ids["one"], uint32(1))
114+
assertEqual(ids["two"], uint32(2))
115+
assertEqual(len(bytesOut), 1)
116+
assertSliceEqual(bytesOut["key"], []uint8{42})
117+
}
118+
119+
func testNestedRoundtrip() {
120+
input := map[string]map[uint32]string{
121+
"group-a": {
122+
1: "one",
123+
2: "two",
124+
},
125+
"group-b": {
126+
10: "ten",
127+
},
128+
}
129+
result := test.NestedRoundtrip(input)
130+
assertEqual(len(result), 2)
131+
assertEqual(result["group-a"][1], "one")
132+
assertEqual(result["group-a"][2], "two")
133+
assertEqual(result["group-b"][10], "ten")
134+
}
135+
136+
func testVariantRoundtrip() {
137+
m := test.NamesById{1: "one"}
138+
asMap := test.VariantRoundtrip(test.MakeMapOrStringAsMap(m))
139+
assertEqual(asMap.Tag(), test.MapOrStringAsMap)
140+
assertEqual(asMap.AsMap()[1], "one")
141+
142+
asStr := test.VariantRoundtrip(test.MakeMapOrStringAsString("hello"))
143+
assertEqual(asStr.Tag(), test.MapOrStringAsString)
144+
assertEqual(asStr.AsString(), "hello")
145+
}
146+
147+
func testResultRoundtrip() {
148+
m := test.NamesById{5: "five"}
149+
okResult := test.ResultRoundtrip(Ok[test.NamesById, string](m))
150+
assertEqual(okResult.Tag(), ResultOk)
151+
assertEqual(okResult.Ok()[5], "five")
152+
153+
errResult := test.ResultRoundtrip(Err[test.NamesById, string]("bad input"))
154+
assertEqual(errResult.Tag(), ResultErr)
155+
assertEqual(errResult.Err(), "bad input")
156+
}
157+
158+
func testTupleRoundtrip() {
159+
m := test.NamesById{7: "seven"}
160+
resultMap, resultNum := test.TupleRoundtrip(Tuple2[test.NamesById, uint64]{m, 42})
161+
assertEqual(len(resultMap), 1)
162+
assertEqual(resultMap[7], "seven")
163+
assertEqual(resultNum, uint64(42))
164+
}
165+
166+
func testSingleEntryRoundtrip() {
167+
input := test.NamesById{99: "ninety-nine"}
168+
result := test.SingleEntryRoundtrip(input)
169+
assertEqual(len(result), 1)
170+
assertEqual(result[99], "ninety-nine")
171+
}
172+
173+
func assertEqual[T comparable](a T, b T) {
174+
if a != b {
175+
panic(fmt.Sprintf("%v not equal to %v", a, b))
176+
}
177+
}
178+
179+
func assertSliceEqual[T comparable](a []T, b []T) {
180+
if len(a) != len(b) {
181+
panic(fmt.Sprintf("slices have different lengths: %d vs %d", len(a), len(b)))
182+
}
183+
for i := range a {
184+
if a[i] != b[i] {
185+
panic(fmt.Sprintf("slices differ at index %d: %v vs %v", i, a[i], b[i]))
186+
}
187+
}
188+
}

tests/runtime/map/test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package export_test_maps_to_test
2+
3+
import (
4+
. "wit_component/test_maps_to_test"
5+
6+
. "go.bytecodealliance.org/pkg/wit/types"
7+
)
8+
9+
func NamedRoundtrip(a NamesById) IdsByName {
10+
result := make(IdsByName)
11+
for id, name := range a {
12+
result[name] = id
13+
}
14+
return result
15+
}
16+
17+
func BytesRoundtrip(a BytesByName) BytesByName {
18+
return a
19+
}
20+
21+
func EmptyRoundtrip(a NamesById) NamesById {
22+
return a
23+
}
24+
25+
func OptionRoundtrip(a map[string]Option[uint32]) map[string]Option[uint32] {
26+
return a
27+
}
28+
29+
func RecordRoundtrip(a LabeledEntry) LabeledEntry {
30+
return a
31+
}
32+
33+
func InlineRoundtrip(a map[uint32]string) map[string]uint32 {
34+
result := make(map[string]uint32)
35+
for k, v := range a {
36+
result[v] = k
37+
}
38+
return result
39+
}
40+
41+
func LargeRoundtrip(a NamesById) NamesById {
42+
return a
43+
}
44+
45+
func MultiParamRoundtrip(a NamesById, b BytesByName) (IdsByName, BytesByName) {
46+
ids := make(IdsByName)
47+
for id, name := range a {
48+
ids[name] = id
49+
}
50+
return ids, b
51+
}
52+
53+
func NestedRoundtrip(a map[string]map[uint32]string) map[string]map[uint32]string {
54+
return a
55+
}
56+
57+
func VariantRoundtrip(a MapOrString) MapOrString {
58+
return a
59+
}
60+
61+
func ResultRoundtrip(a Result[NamesById, string]) Result[NamesById, string] {
62+
return a
63+
}
64+
65+
func TupleRoundtrip(a Tuple2[NamesById, uint64]) (NamesById, uint64) {
66+
return a.F0, a.F1
67+
}
68+
69+
func SingleEntryRoundtrip(a NamesById) NamesById {
70+
return a
71+
}

0 commit comments

Comments
 (0)