@@ -218,6 +218,17 @@ type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, any) (any, error)
218218// values.
219219type DecodeHookFuncValue func (from reflect.Value , to reflect.Value ) (any , error )
220220
221+ // Unmarshaler is the interface implemented by types that can unmarshal
222+ // themselves. UnmarshalMapstructure receives the input data (potentially
223+ // transformed by DecodeHook) and should populate the receiver with the
224+ // decoded values.
225+ //
226+ // The Unmarshaler interface takes precedence over the default decoding
227+ // logic for any type (structs, slices, maps, primitives, etc.).
228+ type Unmarshaler interface {
229+ UnmarshalMapstructure (any ) error
230+ }
231+
221232// DecoderConfig is the configuration that is used to create a new decoder
222233// and allows customization of various aspects of decoding.
223234type DecoderConfig struct {
@@ -340,6 +351,11 @@ type DecoderConfig struct {
340351 // the initial lookup. If not found, MatchName is used as a fallback comparison.
341352 // Explicit struct tags always take precedence over MapFieldName.
342353 MapFieldName func (string ) string
354+
355+ // DisableUnmarshaler, if set to true, disables the use of the Unmarshaler
356+ // interface. Types implementing Unmarshaler will be decoded using the
357+ // standard struct decoding logic instead.
358+ DisableUnmarshaler bool
343359}
344360
345361// A Decoder takes a raw interface value and turns it into structured
@@ -577,36 +593,50 @@ func (d *Decoder) decode(name string, input any, outVal reflect.Value) error {
577593
578594 var err error
579595 addMetaKey := true
580- switch outputKind {
581- case reflect .Bool :
582- err = d .decodeBool (name , input , outVal )
583- case reflect .Interface :
584- err = d .decodeBasic (name , input , outVal )
585- case reflect .String :
586- err = d .decodeString (name , input , outVal )
587- case reflect .Int :
588- err = d .decodeInt (name , input , outVal )
589- case reflect .Uint :
590- err = d .decodeUint (name , input , outVal )
591- case reflect .Float32 :
592- err = d .decodeFloat (name , input , outVal )
593- case reflect .Complex64 :
594- err = d .decodeComplex (name , input , outVal )
595- case reflect .Struct :
596- err = d .decodeStruct (name , input , outVal )
597- case reflect .Map :
598- err = d .decodeMap (name , input , outVal )
599- case reflect .Ptr :
600- addMetaKey , err = d .decodePtr (name , input , outVal )
601- case reflect .Slice :
602- err = d .decodeSlice (name , input , outVal )
603- case reflect .Array :
604- err = d .decodeArray (name , input , outVal )
605- case reflect .Func :
606- err = d .decodeFunc (name , input , outVal )
607- default :
608- // If we reached this point then we weren't able to decode it
609- return newDecodeError (name , fmt .Errorf ("unsupported type: %s" , outputKind ))
596+
597+ // Check if the target implements Unmarshaler and use it if not disabled
598+ unmarshaled := false
599+ if ! d .config .DisableUnmarshaler {
600+ if unmarshaler , ok := getUnmarshaler (outVal ); ok {
601+ if err = unmarshaler .UnmarshalMapstructure (input ); err != nil {
602+ err = newDecodeError (name , err )
603+ }
604+ unmarshaled = true
605+ }
606+ }
607+
608+ if ! unmarshaled {
609+ switch outputKind {
610+ case reflect .Bool :
611+ err = d .decodeBool (name , input , outVal )
612+ case reflect .Interface :
613+ err = d .decodeBasic (name , input , outVal )
614+ case reflect .String :
615+ err = d .decodeString (name , input , outVal )
616+ case reflect .Int :
617+ err = d .decodeInt (name , input , outVal )
618+ case reflect .Uint :
619+ err = d .decodeUint (name , input , outVal )
620+ case reflect .Float32 :
621+ err = d .decodeFloat (name , input , outVal )
622+ case reflect .Complex64 :
623+ err = d .decodeComplex (name , input , outVal )
624+ case reflect .Struct :
625+ err = d .decodeStruct (name , input , outVal )
626+ case reflect .Map :
627+ err = d .decodeMap (name , input , outVal )
628+ case reflect .Ptr :
629+ addMetaKey , err = d .decodePtr (name , input , outVal )
630+ case reflect .Slice :
631+ err = d .decodeSlice (name , input , outVal )
632+ case reflect .Array :
633+ err = d .decodeArray (name , input , outVal )
634+ case reflect .Func :
635+ err = d .decodeFunc (name , input , outVal )
636+ default :
637+ // If we reached this point then we weren't able to decode it
638+ return newDecodeError (name , fmt .Errorf ("unsupported type: %s" , outputKind ))
639+ }
610640 }
611641
612642 // If we reached here, then we successfully decoded SOMETHING, so
@@ -1860,3 +1890,40 @@ func splitTagNames(tagName string) []string {
18601890
18611891 return result
18621892}
1893+
1894+ // unmarshalerType is cached for performance
1895+ var unmarshalerType = reflect .TypeOf ((* Unmarshaler )(nil )).Elem ()
1896+
1897+ // getUnmarshaler checks if the value implements Unmarshaler and returns
1898+ // the Unmarshaler and a boolean indicating if it was found. It handles both
1899+ // pointer and value receivers.
1900+ func getUnmarshaler (val reflect.Value ) (Unmarshaler , bool ) {
1901+ // Skip invalid or nil values
1902+ if ! val .IsValid () {
1903+ return nil , false
1904+ }
1905+
1906+ switch val .Kind () {
1907+ case reflect .Pointer , reflect .Interface :
1908+ if val .IsNil () {
1909+ return nil , false
1910+ }
1911+ }
1912+
1913+ // Check pointer receiver first (most common case)
1914+ if val .CanAddr () {
1915+ ptrVal := val .Addr ()
1916+ // Quick check: if no methods, can't implement any interface
1917+ if ptrVal .Type ().NumMethod () > 0 && ptrVal .Type ().Implements (unmarshalerType ) {
1918+ return ptrVal .Interface ().(Unmarshaler ), true
1919+ }
1920+ }
1921+
1922+ // Check value receiver
1923+ // Quick check: if no methods, can't implement any interface
1924+ if val .Type ().NumMethod () > 0 && val .CanInterface () && val .Type ().Implements (unmarshalerType ) {
1925+ return val .Interface ().(Unmarshaler ), true
1926+ }
1927+
1928+ return nil , false
1929+ }
0 commit comments