|
173 | 173 | // Public: "I made it through!" |
174 | 174 | // } |
175 | 175 | // |
| 176 | +// # Custom Decoding with Unmarshaler |
| 177 | +// |
| 178 | +// Types can implement the Unmarshaler interface to control their own decoding. The interface |
| 179 | +// behaves similarly to how UnmarshalJSON does in the standard library. It can be used as an |
| 180 | +// alternative or companion to a DecodeHook. |
| 181 | +// |
| 182 | +// type TrimmedString string |
| 183 | +// |
| 184 | +// func (t *TrimmedString) UnmarshalMapstructure(input any) error { |
| 185 | +// str, ok := input.(string) |
| 186 | +// if !ok { |
| 187 | +// return fmt.Errorf("expected string, got %T", input) |
| 188 | +// } |
| 189 | +// *t = TrimmedString(strings.TrimSpace(str)) |
| 190 | +// return nil |
| 191 | +// } |
| 192 | +// |
| 193 | +// See the Unmarshaler interface documentation for more details. |
| 194 | +// |
176 | 195 | // # Other Configuration |
177 | 196 | // |
178 | 197 | // mapstructure is highly configurable. See the DecoderConfig struct |
@@ -218,6 +237,17 @@ type DecodeHookFuncKind func(reflect.Kind, reflect.Kind, any) (any, error) |
218 | 237 | // values. |
219 | 238 | type DecodeHookFuncValue func(from reflect.Value, to reflect.Value) (any, error) |
220 | 239 |
|
| 240 | +// Unmarshaler is the interface implemented by types that can unmarshal |
| 241 | +// themselves. UnmarshalMapstructure receives the input data (potentially |
| 242 | +// transformed by DecodeHook) and should populate the receiver with the |
| 243 | +// decoded values. |
| 244 | +// |
| 245 | +// The Unmarshaler interface takes precedence over the default decoding |
| 246 | +// logic for any type (structs, slices, maps, primitives, etc.). |
| 247 | +type Unmarshaler interface { |
| 248 | + UnmarshalMapstructure(any) error |
| 249 | +} |
| 250 | + |
221 | 251 | // DecoderConfig is the configuration that is used to create a new decoder |
222 | 252 | // and allows customization of various aspects of decoding. |
223 | 253 | type DecoderConfig struct { |
@@ -340,6 +370,11 @@ type DecoderConfig struct { |
340 | 370 | // the initial lookup. If not found, MatchName is used as a fallback comparison. |
341 | 371 | // Explicit struct tags always take precedence over MapFieldName. |
342 | 372 | MapFieldName func(string) string |
| 373 | + |
| 374 | + // DisableUnmarshaler, if set to true, disables the use of the Unmarshaler |
| 375 | + // interface. Types implementing Unmarshaler will be decoded using the |
| 376 | + // standard struct decoding logic instead. |
| 377 | + DisableUnmarshaler bool |
343 | 378 | } |
344 | 379 |
|
345 | 380 | // A Decoder takes a raw interface value and turns it into structured |
@@ -577,36 +612,50 @@ func (d *Decoder) decode(name string, input any, outVal reflect.Value) error { |
577 | 612 |
|
578 | 613 | var err error |
579 | 614 | addMetaKey := true |
580 | | - switch outputKind { |
581 | | - case reflect.Bool: |
582 | | - err = d.decodeBool(name, input, outVal) |
583 | | - case reflect.Interface: |
584 | | - err = d.decodeBasic(name, input, outVal) |
585 | | - case reflect.String: |
586 | | - err = d.decodeString(name, input, outVal) |
587 | | - case reflect.Int: |
588 | | - err = d.decodeInt(name, input, outVal) |
589 | | - case reflect.Uint: |
590 | | - err = d.decodeUint(name, input, outVal) |
591 | | - case reflect.Float32: |
592 | | - err = d.decodeFloat(name, input, outVal) |
593 | | - case reflect.Complex64: |
594 | | - err = d.decodeComplex(name, input, outVal) |
595 | | - case reflect.Struct: |
596 | | - err = d.decodeStruct(name, input, outVal) |
597 | | - case reflect.Map: |
598 | | - err = d.decodeMap(name, input, outVal) |
599 | | - case reflect.Ptr: |
600 | | - addMetaKey, err = d.decodePtr(name, input, outVal) |
601 | | - case reflect.Slice: |
602 | | - err = d.decodeSlice(name, input, outVal) |
603 | | - case reflect.Array: |
604 | | - err = d.decodeArray(name, input, outVal) |
605 | | - case reflect.Func: |
606 | | - err = d.decodeFunc(name, input, outVal) |
607 | | - default: |
608 | | - // If we reached this point then we weren't able to decode it |
609 | | - return newDecodeError(name, fmt.Errorf("unsupported type: %s", outputKind)) |
| 615 | + |
| 616 | + // Check if the target implements Unmarshaler and use it if not disabled |
| 617 | + unmarshaled := false |
| 618 | + if !d.config.DisableUnmarshaler { |
| 619 | + if unmarshaler, ok := getUnmarshaler(outVal); ok { |
| 620 | + if err = unmarshaler.UnmarshalMapstructure(input); err != nil { |
| 621 | + err = newDecodeError(name, err) |
| 622 | + } |
| 623 | + unmarshaled = true |
| 624 | + } |
| 625 | + } |
| 626 | + |
| 627 | + if !unmarshaled { |
| 628 | + switch outputKind { |
| 629 | + case reflect.Bool: |
| 630 | + err = d.decodeBool(name, input, outVal) |
| 631 | + case reflect.Interface: |
| 632 | + err = d.decodeBasic(name, input, outVal) |
| 633 | + case reflect.String: |
| 634 | + err = d.decodeString(name, input, outVal) |
| 635 | + case reflect.Int: |
| 636 | + err = d.decodeInt(name, input, outVal) |
| 637 | + case reflect.Uint: |
| 638 | + err = d.decodeUint(name, input, outVal) |
| 639 | + case reflect.Float32: |
| 640 | + err = d.decodeFloat(name, input, outVal) |
| 641 | + case reflect.Complex64: |
| 642 | + err = d.decodeComplex(name, input, outVal) |
| 643 | + case reflect.Struct: |
| 644 | + err = d.decodeStruct(name, input, outVal) |
| 645 | + case reflect.Map: |
| 646 | + err = d.decodeMap(name, input, outVal) |
| 647 | + case reflect.Ptr: |
| 648 | + addMetaKey, err = d.decodePtr(name, input, outVal) |
| 649 | + case reflect.Slice: |
| 650 | + err = d.decodeSlice(name, input, outVal) |
| 651 | + case reflect.Array: |
| 652 | + err = d.decodeArray(name, input, outVal) |
| 653 | + case reflect.Func: |
| 654 | + err = d.decodeFunc(name, input, outVal) |
| 655 | + default: |
| 656 | + // If we reached this point then we weren't able to decode it |
| 657 | + return newDecodeError(name, fmt.Errorf("unsupported type: %s", outputKind)) |
| 658 | + } |
610 | 659 | } |
611 | 660 |
|
612 | 661 | // If we reached here, then we successfully decoded SOMETHING, so |
@@ -1860,3 +1909,40 @@ func splitTagNames(tagName string) []string { |
1860 | 1909 |
|
1861 | 1910 | return result |
1862 | 1911 | } |
| 1912 | + |
| 1913 | +// unmarshalerType is cached for performance |
| 1914 | +var unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() |
| 1915 | + |
| 1916 | +// getUnmarshaler checks if the value implements Unmarshaler and returns |
| 1917 | +// the Unmarshaler and a boolean indicating if it was found. It handles both |
| 1918 | +// pointer and value receivers. |
| 1919 | +func getUnmarshaler(val reflect.Value) (Unmarshaler, bool) { |
| 1920 | + // Skip invalid or nil values |
| 1921 | + if !val.IsValid() { |
| 1922 | + return nil, false |
| 1923 | + } |
| 1924 | + |
| 1925 | + switch val.Kind() { |
| 1926 | + case reflect.Pointer, reflect.Interface: |
| 1927 | + if val.IsNil() { |
| 1928 | + return nil, false |
| 1929 | + } |
| 1930 | + } |
| 1931 | + |
| 1932 | + // Check pointer receiver first (most common case) |
| 1933 | + if val.CanAddr() { |
| 1934 | + ptrVal := val.Addr() |
| 1935 | + // Quick check: if no methods, can't implement any interface |
| 1936 | + if ptrVal.Type().NumMethod() > 0 && ptrVal.Type().Implements(unmarshalerType) { |
| 1937 | + return ptrVal.Interface().(Unmarshaler), true |
| 1938 | + } |
| 1939 | + } |
| 1940 | + |
| 1941 | + // Check value receiver |
| 1942 | + // Quick check: if no methods, can't implement any interface |
| 1943 | + if val.Type().NumMethod() > 0 && val.CanInterface() && val.Type().Implements(unmarshalerType) { |
| 1944 | + return val.Interface().(Unmarshaler), true |
| 1945 | + } |
| 1946 | + |
| 1947 | + return nil, false |
| 1948 | +} |
0 commit comments