Skip to content

Commit 5b22829

Browse files
committed
feat: add unmarshaler interface
Signed-off-by: Mark Sagi-Kazar <mark.sagikazar@gmail.com>
1 parent fd74c75 commit 5b22829

File tree

3 files changed

+665
-30
lines changed

3 files changed

+665
-30
lines changed

mapstructure.go

Lines changed: 97 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,17 @@ type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, any) (any, error)
218218
// values.
219219
type DecodeHookFuncValue func(from reflect.Value, to reflect.Value) (any, error)
220220

221+
// Unmarshaler is the interface implemented by types that can unmarshal
222+
// themselves. UnmarshalMapstructure receives the input data (potentially
223+
// transformed by DecodeHook) and should populate the receiver with the
224+
// decoded values.
225+
//
226+
// The Unmarshaler interface takes precedence over the default decoding
227+
// logic for any type (structs, slices, maps, primitives, etc.).
228+
type Unmarshaler interface {
229+
UnmarshalMapstructure(any) error
230+
}
231+
221232
// DecoderConfig is the configuration that is used to create a new decoder
222233
// and allows customization of various aspects of decoding.
223234
type DecoderConfig struct {
@@ -340,6 +351,11 @@ type DecoderConfig struct {
340351
// the initial lookup. If not found, MatchName is used as a fallback comparison.
341352
// Explicit struct tags always take precedence over MapFieldName.
342353
MapFieldName func(string) string
354+
355+
// DisableUnmarshaler, if set to true, disables the use of the Unmarshaler
356+
// interface. Types implementing Unmarshaler will be decoded using the
357+
// standard struct decoding logic instead.
358+
DisableUnmarshaler bool
343359
}
344360

345361
// A Decoder takes a raw interface value and turns it into structured
@@ -577,36 +593,50 @@ func (d *Decoder) decode(name string, input any, outVal reflect.Value) error {
577593

578594
var err error
579595
addMetaKey := true
580-
switch outputKind {
581-
case reflect.Bool:
582-
err = d.decodeBool(name, input, outVal)
583-
case reflect.Interface:
584-
err = d.decodeBasic(name, input, outVal)
585-
case reflect.String:
586-
err = d.decodeString(name, input, outVal)
587-
case reflect.Int:
588-
err = d.decodeInt(name, input, outVal)
589-
case reflect.Uint:
590-
err = d.decodeUint(name, input, outVal)
591-
case reflect.Float32:
592-
err = d.decodeFloat(name, input, outVal)
593-
case reflect.Complex64:
594-
err = d.decodeComplex(name, input, outVal)
595-
case reflect.Struct:
596-
err = d.decodeStruct(name, input, outVal)
597-
case reflect.Map:
598-
err = d.decodeMap(name, input, outVal)
599-
case reflect.Ptr:
600-
addMetaKey, err = d.decodePtr(name, input, outVal)
601-
case reflect.Slice:
602-
err = d.decodeSlice(name, input, outVal)
603-
case reflect.Array:
604-
err = d.decodeArray(name, input, outVal)
605-
case reflect.Func:
606-
err = d.decodeFunc(name, input, outVal)
607-
default:
608-
// If we reached this point then we weren't able to decode it
609-
return newDecodeError(name, fmt.Errorf("unsupported type: %s", outputKind))
596+
597+
// Check if the target implements Unmarshaler and use it if not disabled
598+
unmarshaled := false
599+
if !d.config.DisableUnmarshaler {
600+
if unmarshaler, ok := getUnmarshaler(outVal); ok {
601+
if err = unmarshaler.UnmarshalMapstructure(input); err != nil {
602+
err = newDecodeError(name, err)
603+
}
604+
unmarshaled = true
605+
}
606+
}
607+
608+
if !unmarshaled {
609+
switch outputKind {
610+
case reflect.Bool:
611+
err = d.decodeBool(name, input, outVal)
612+
case reflect.Interface:
613+
err = d.decodeBasic(name, input, outVal)
614+
case reflect.String:
615+
err = d.decodeString(name, input, outVal)
616+
case reflect.Int:
617+
err = d.decodeInt(name, input, outVal)
618+
case reflect.Uint:
619+
err = d.decodeUint(name, input, outVal)
620+
case reflect.Float32:
621+
err = d.decodeFloat(name, input, outVal)
622+
case reflect.Complex64:
623+
err = d.decodeComplex(name, input, outVal)
624+
case reflect.Struct:
625+
err = d.decodeStruct(name, input, outVal)
626+
case reflect.Map:
627+
err = d.decodeMap(name, input, outVal)
628+
case reflect.Ptr:
629+
addMetaKey, err = d.decodePtr(name, input, outVal)
630+
case reflect.Slice:
631+
err = d.decodeSlice(name, input, outVal)
632+
case reflect.Array:
633+
err = d.decodeArray(name, input, outVal)
634+
case reflect.Func:
635+
err = d.decodeFunc(name, input, outVal)
636+
default:
637+
// If we reached this point then we weren't able to decode it
638+
return newDecodeError(name, fmt.Errorf("unsupported type: %s", outputKind))
639+
}
610640
}
611641

612642
// If we reached here, then we successfully decoded SOMETHING, so
@@ -1860,3 +1890,40 @@ func splitTagNames(tagName string) []string {
18601890

18611891
return result
18621892
}
1893+
1894+
// unmarshalerType is cached for performance
1895+
var unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
1896+
1897+
// getUnmarshaler checks if the value implements Unmarshaler and returns
1898+
// the Unmarshaler and a boolean indicating if it was found. It handles both
1899+
// pointer and value receivers.
1900+
func getUnmarshaler(val reflect.Value) (Unmarshaler, bool) {
1901+
// Skip invalid or nil values
1902+
if !val.IsValid() {
1903+
return nil, false
1904+
}
1905+
1906+
switch val.Kind() {
1907+
case reflect.Pointer, reflect.Interface:
1908+
if val.IsNil() {
1909+
return nil, false
1910+
}
1911+
}
1912+
1913+
// Check pointer receiver first (most common case)
1914+
if val.CanAddr() {
1915+
ptrVal := val.Addr()
1916+
// Quick check: if no methods, can't implement any interface
1917+
if ptrVal.Type().NumMethod() > 0 && ptrVal.Type().Implements(unmarshalerType) {
1918+
return ptrVal.Interface().(Unmarshaler), true
1919+
}
1920+
}
1921+
1922+
// Check value receiver
1923+
// Quick check: if no methods, can't implement any interface
1924+
if val.Type().NumMethod() > 0 && val.CanInterface() && val.Type().Implements(unmarshalerType) {
1925+
return val.Interface().(Unmarshaler), true
1926+
}
1927+
1928+
return nil, false
1929+
}

mapstructure_examples_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,3 +357,56 @@ func ExampleDecode_decodeHookFunc() {
357357
// Output:
358358
// mapstructure.Person{Name:"Mitchell", Location:mapstructure.PersonLocation{Latitude:-35.2809, Longtitude:149.13}}
359359
}
360+
361+
// ExampleServerConfig is used by ExampleDecode_unmarshaler.
362+
// It implements the Unmarshaler interface to apply custom decoding logic.
363+
type ExampleServerConfig struct {
364+
Host string
365+
Port int
366+
}
367+
368+
// UnmarshalMapstructure implements the Unmarshaler interface.
369+
// It applies default values when fields are missing from the input.
370+
func (s *ExampleServerConfig) UnmarshalMapstructure(data any) error {
371+
m, ok := data.(map[string]any)
372+
if !ok {
373+
return fmt.Errorf("expected map[string]any, got %T", data)
374+
}
375+
376+
// Apply defaults first
377+
s.Host = "localhost"
378+
s.Port = 8080
379+
380+
// Override with provided values
381+
if host, ok := m["host"].(string); ok {
382+
s.Host = host
383+
}
384+
if port, ok := m["port"].(int); ok {
385+
s.Port = port
386+
}
387+
388+
return nil
389+
}
390+
391+
func ExampleDecode_unmarshaler() {
392+
// Types that implement the Unmarshaler interface can control how they
393+
// are decoded from map data. This is useful for applying defaults,
394+
// custom validation, or complex transformation logic.
395+
396+
input := map[string]any{
397+
"host": "example.com",
398+
// Note: port is intentionally omitted to demonstrate default handling
399+
}
400+
401+
var result ExampleServerConfig
402+
err := Decode(input, &result)
403+
if err != nil {
404+
panic(err)
405+
}
406+
407+
// The Unmarshaler applied the default port value since it wasn't in the input
408+
fmt.Printf("Host: %s, Port: %d", result.Host, result.Port)
409+
410+
// Output:
411+
// Host: example.com, Port: 8080
412+
}

0 commit comments

Comments
 (0)