Skip to content

Commit 9aa3f77

Browse files
Merge pull request #166 from go-viper/unmarshal2
Add unmarshaler interface
2 parents fd74c75 + ae32a61 commit 9aa3f77

File tree

3 files changed

+709
-30
lines changed

3 files changed

+709
-30
lines changed

mapstructure.go

Lines changed: 116 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,25 @@
173173
// Public: "I made it through!"
174174
// }
175175
//
176+
// # Custom Decoding with Unmarshaler
177+
//
178+
// Types can implement the Unmarshaler interface to control their own decoding. The interface
179+
// behaves similarly to how UnmarshalJSON does in the standard library. It can be used as an
180+
// alternative or companion to a DecodeHook.
181+
//
182+
// type TrimmedString string
183+
//
184+
// func (t *TrimmedString) UnmarshalMapstructure(input any) error {
185+
// str, ok := input.(string)
186+
// if !ok {
187+
// return fmt.Errorf("expected string, got %T", input)
188+
// }
189+
// *t = TrimmedString(strings.TrimSpace(str))
190+
// return nil
191+
// }
192+
//
193+
// See the Unmarshaler interface documentation for more details.
194+
//
176195
// # Other Configuration
177196
//
178197
// mapstructure is highly configurable. See the DecoderConfig struct
@@ -218,6 +237,17 @@ type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, any) (any, error)
218237
// values.
219238
type DecodeHookFuncValue func(from reflect.Value, to reflect.Value) (any, error)
220239

240+
// Unmarshaler is the interface implemented by types that can unmarshal
241+
// themselves. UnmarshalMapstructure receives the input data (potentially
242+
// transformed by DecodeHook) and should populate the receiver with the
243+
// decoded values.
244+
//
245+
// The Unmarshaler interface takes precedence over the default decoding
246+
// logic for any type (structs, slices, maps, primitives, etc.).
247+
type Unmarshaler interface {
248+
UnmarshalMapstructure(any) error
249+
}
250+
221251
// DecoderConfig is the configuration that is used to create a new decoder
222252
// and allows customization of various aspects of decoding.
223253
type DecoderConfig struct {
@@ -340,6 +370,11 @@ type DecoderConfig struct {
340370
// the initial lookup. If not found, MatchName is used as a fallback comparison.
341371
// Explicit struct tags always take precedence over MapFieldName.
342372
MapFieldName func(string) string
373+
374+
// DisableUnmarshaler, if set to true, disables the use of the Unmarshaler
375+
// interface. Types implementing Unmarshaler will be decoded using the
376+
// standard struct decoding logic instead.
377+
DisableUnmarshaler bool
343378
}
344379

345380
// A Decoder takes a raw interface value and turns it into structured
@@ -577,36 +612,50 @@ func (d *Decoder) decode(name string, input any, outVal reflect.Value) error {
577612

578613
var err error
579614
addMetaKey := true
580-
switch outputKind {
581-
case reflect.Bool:
582-
err = d.decodeBool(name, input, outVal)
583-
case reflect.Interface:
584-
err = d.decodeBasic(name, input, outVal)
585-
case reflect.String:
586-
err = d.decodeString(name, input, outVal)
587-
case reflect.Int:
588-
err = d.decodeInt(name, input, outVal)
589-
case reflect.Uint:
590-
err = d.decodeUint(name, input, outVal)
591-
case reflect.Float32:
592-
err = d.decodeFloat(name, input, outVal)
593-
case reflect.Complex64:
594-
err = d.decodeComplex(name, input, outVal)
595-
case reflect.Struct:
596-
err = d.decodeStruct(name, input, outVal)
597-
case reflect.Map:
598-
err = d.decodeMap(name, input, outVal)
599-
case reflect.Ptr:
600-
addMetaKey, err = d.decodePtr(name, input, outVal)
601-
case reflect.Slice:
602-
err = d.decodeSlice(name, input, outVal)
603-
case reflect.Array:
604-
err = d.decodeArray(name, input, outVal)
605-
case reflect.Func:
606-
err = d.decodeFunc(name, input, outVal)
607-
default:
608-
// If we reached this point then we weren't able to decode it
609-
return newDecodeError(name, fmt.Errorf("unsupported type: %s", outputKind))
615+
616+
// Check if the target implements Unmarshaler and use it if not disabled
617+
unmarshaled := false
618+
if !d.config.DisableUnmarshaler {
619+
if unmarshaler, ok := getUnmarshaler(outVal); ok {
620+
if err = unmarshaler.UnmarshalMapstructure(input); err != nil {
621+
err = newDecodeError(name, err)
622+
}
623+
unmarshaled = true
624+
}
625+
}
626+
627+
if !unmarshaled {
628+
switch outputKind {
629+
case reflect.Bool:
630+
err = d.decodeBool(name, input, outVal)
631+
case reflect.Interface:
632+
err = d.decodeBasic(name, input, outVal)
633+
case reflect.String:
634+
err = d.decodeString(name, input, outVal)
635+
case reflect.Int:
636+
err = d.decodeInt(name, input, outVal)
637+
case reflect.Uint:
638+
err = d.decodeUint(name, input, outVal)
639+
case reflect.Float32:
640+
err = d.decodeFloat(name, input, outVal)
641+
case reflect.Complex64:
642+
err = d.decodeComplex(name, input, outVal)
643+
case reflect.Struct:
644+
err = d.decodeStruct(name, input, outVal)
645+
case reflect.Map:
646+
err = d.decodeMap(name, input, outVal)
647+
case reflect.Ptr:
648+
addMetaKey, err = d.decodePtr(name, input, outVal)
649+
case reflect.Slice:
650+
err = d.decodeSlice(name, input, outVal)
651+
case reflect.Array:
652+
err = d.decodeArray(name, input, outVal)
653+
case reflect.Func:
654+
err = d.decodeFunc(name, input, outVal)
655+
default:
656+
// If we reached this point then we weren't able to decode it
657+
return newDecodeError(name, fmt.Errorf("unsupported type: %s", outputKind))
658+
}
610659
}
611660

612661
// If we reached here, then we successfully decoded SOMETHING, so
@@ -1860,3 +1909,40 @@ func splitTagNames(tagName string) []string {
18601909

18611910
return result
18621911
}
1912+
1913+
// unmarshalerType is cached for performance
1914+
var unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
1915+
1916+
// getUnmarshaler checks if the value implements Unmarshaler and returns
1917+
// the Unmarshaler and a boolean indicating if it was found. It handles both
1918+
// pointer and value receivers.
1919+
func getUnmarshaler(val reflect.Value) (Unmarshaler, bool) {
1920+
// Skip invalid or nil values
1921+
if !val.IsValid() {
1922+
return nil, false
1923+
}
1924+
1925+
switch val.Kind() {
1926+
case reflect.Pointer, reflect.Interface:
1927+
if val.IsNil() {
1928+
return nil, false
1929+
}
1930+
}
1931+
1932+
// Check pointer receiver first (most common case)
1933+
if val.CanAddr() {
1934+
ptrVal := val.Addr()
1935+
// Quick check: if no methods, can't implement any interface
1936+
if ptrVal.Type().NumMethod() > 0 && ptrVal.Type().Implements(unmarshalerType) {
1937+
return ptrVal.Interface().(Unmarshaler), true
1938+
}
1939+
}
1940+
1941+
// Check value receiver
1942+
// Quick check: if no methods, can't implement any interface
1943+
if val.Type().NumMethod() > 0 && val.CanInterface() && val.Type().Implements(unmarshalerType) {
1944+
return val.Interface().(Unmarshaler), true
1945+
}
1946+
1947+
return nil, false
1948+
}

mapstructure_examples_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,3 +357,56 @@ func ExampleDecode_decodeHookFunc() {
357357
// Output:
358358
// mapstructure.Person{Name:"Mitchell", Location:mapstructure.PersonLocation{Latitude:-35.2809, Longtitude:149.13}}
359359
}
360+
361+
// ExampleServerConfig is used by ExampleDecode_unmarshaler.
362+
// It implements the Unmarshaler interface to apply custom decoding logic.
363+
type ExampleServerConfig struct {
364+
Host string
365+
Port int
366+
}
367+
368+
// UnmarshalMapstructure implements the Unmarshaler interface.
369+
// It applies default values when fields are missing from the input.
370+
func (s *ExampleServerConfig) UnmarshalMapstructure(data any) error {
371+
m, ok := data.(map[string]any)
372+
if !ok {
373+
return fmt.Errorf("expected map[string]any, got %T", data)
374+
}
375+
376+
// Apply defaults first
377+
s.Host = "localhost"
378+
s.Port = 8080
379+
380+
// Override with provided values
381+
if host, ok := m["host"].(string); ok {
382+
s.Host = host
383+
}
384+
if port, ok := m["port"].(int); ok {
385+
s.Port = port
386+
}
387+
388+
return nil
389+
}
390+
391+
func ExampleDecode_unmarshaler() {
392+
// Types that implement the Unmarshaler interface can control how they
393+
// are decoded from map data. This is useful for applying defaults,
394+
// custom validation, or complex transformation logic.
395+
396+
input := map[string]any{
397+
"host": "example.com",
398+
// Note: port is intentionally omitted to demonstrate default handling
399+
}
400+
401+
var result ExampleServerConfig
402+
err := Decode(input, &result)
403+
if err != nil {
404+
panic(err)
405+
}
406+
407+
// The Unmarshaler applied the default port value since it wasn't in the input
408+
fmt.Printf("Host: %s, Port: %d", result.Host, result.Port)
409+
410+
// Output:
411+
// Host: example.com, Port: 8080
412+
}

0 commit comments

Comments
 (0)