Skip to content

Commit 52dacb8

Browse files
authored
Update custom marshaler and unmarshaler to accept context (#745)
* feat: Update custom marshaler and unmarshaler to accept context * Add tests for custom marshaler and unmarshaler with context
1 parent 680eea7 commit 52dacb8

File tree

7 files changed

+149
-14
lines changed

7 files changed

+149
-14
lines changed

decode.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ type Decoder struct {
3030
referenceReaders []io.Reader
3131
anchorNodeMap map[string]ast.Node
3232
anchorValueMap map[string]reflect.Value
33-
customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error
33+
customUnmarshalerMap map[reflect.Type]func(context.Context, interface{}, []byte) error
3434
commentMaps []CommentMap
3535
toCommentMap CommentMap
3636
opts []DecodeOption
@@ -54,7 +54,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
5454
reader: r,
5555
anchorNodeMap: map[string]ast.Node{},
5656
anchorValueMap: map[string]reflect.Value{},
57-
customUnmarshalerMap: map[reflect.Type]func(interface{}, []byte) error{},
57+
customUnmarshalerMap: map[reflect.Type]func(context.Context, interface{}, []byte) error{},
5858
opts: opts,
5959
referenceReaders: []io.Reader{},
6060
referenceFiles: []string{},
@@ -722,7 +722,7 @@ func (d *Decoder) existsTypeInCustomUnmarshalerMap(t reflect.Type) bool {
722722
return false
723723
}
724724

725-
func (d *Decoder) unmarshalerFromCustomUnmarshalerMap(t reflect.Type) (func(interface{}, []byte) error, bool) {
725+
func (d *Decoder) unmarshalerFromCustomUnmarshalerMap(t reflect.Type) (func(context.Context, interface{}, []byte) error, bool) {
726726
if unmarshaler, exists := d.customUnmarshalerMap[t]; exists {
727727
return unmarshaler, exists
728728
}
@@ -765,7 +765,7 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
765765
if err != nil {
766766
return err
767767
}
768-
if err := unmarshaler(ptrValue.Interface(), b); err != nil {
768+
if err := unmarshaler(ctx, ptrValue.Interface(), b); err != nil {
769769
return err
770770
}
771771
return nil

decode_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2173,6 +2173,29 @@ func TestDecoder_CustomUnmarshaler(t *testing.T) {
21732173
t.Fatalf("failed to switch to custom unmarshaler. got: %q", v.Foo)
21742174
}
21752175
})
2176+
t.Run("override bytes type with context", func(t *testing.T) {
2177+
type T struct {
2178+
Foo []byte `yaml:"foo"`
2179+
}
2180+
src := []byte(`foo: "bar"`)
2181+
var v T
2182+
ctx := context.WithValue(context.Background(), "plop", uint(42))
2183+
if err := yaml.UnmarshalContext(ctx, src, &v, yaml.CustomUnmarshalerContext[[]byte](func(ctx context.Context, dst *[]byte, b []byte) error {
2184+
if !bytes.Equal(b, []byte(`"bar"`)) {
2185+
t.Fatalf("failed to get target buffer: %q", b)
2186+
}
2187+
if ctx.Value("plop") != uint(42) {
2188+
t.Fatalf("context value is not correct")
2189+
}
2190+
*dst = []byte("bazbaz")
2191+
return nil
2192+
})); err != nil {
2193+
t.Fatal(err)
2194+
}
2195+
if !bytes.Equal(v.Foo, []byte("bazbaz")) {
2196+
t.Fatalf("failed to switch to custom unmarshaler. got: %q", v.Foo)
2197+
}
2198+
})
21762199
}
21772200

21782201
type unmarshalContext struct {

encode.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ type Encoder struct {
3737
anchorRefToName map[uintptr]string
3838
anchorNameMap map[string]struct{}
3939
anchorCallback func(*ast.AnchorNode, interface{}) error
40-
customMarshalerMap map[reflect.Type]func(interface{}) ([]byte, error)
40+
customMarshalerMap map[reflect.Type]func(context.Context, interface{}) ([]byte, error)
4141
omitZero bool
4242
omitEmpty bool
4343
autoInt bool
@@ -59,7 +59,7 @@ func NewEncoder(w io.Writer, opts ...EncodeOption) *Encoder {
5959
return &Encoder{
6060
writer: w,
6161
opts: opts,
62-
customMarshalerMap: map[reflect.Type]func(interface{}) ([]byte, error){},
62+
customMarshalerMap: map[reflect.Type]func(context.Context, interface{}) ([]byte, error){},
6363
line: 1,
6464
column: 1,
6565
offset: 0,
@@ -301,7 +301,7 @@ func (e *Encoder) existsTypeInCustomMarshalerMap(t reflect.Type) bool {
301301
return false
302302
}
303303

304-
func (e *Encoder) marshalerFromCustomMarshalerMap(t reflect.Type) (func(interface{}) ([]byte, error), bool) {
304+
func (e *Encoder) marshalerFromCustomMarshalerMap(t reflect.Type) (func(context.Context, interface{}) ([]byte, error), bool) {
305305
if marshaler, exists := e.customMarshalerMap[t]; exists {
306306
return marshaler, exists
307307
}
@@ -347,7 +347,7 @@ func (e *Encoder) encodeByMarshaler(ctx context.Context, v reflect.Value, column
347347
iface := v.Interface()
348348

349349
if marshaler, exists := e.marshalerFromCustomMarshalerMap(v.Type()); exists {
350-
doc, err := marshaler(iface)
350+
doc, err := marshaler(ctx, iface)
351351
if err != nil {
352352
return nil, err
353353
}

encode_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,6 +1492,27 @@ func TestEncoder_CustomMarshaler(t *testing.T) {
14921492
t.Fatalf("failed to switch to custom marshaler. got: %q", b)
14931493
}
14941494
})
1495+
t.Run("override bytes type with context", func(t *testing.T) {
1496+
type T struct {
1497+
Foo []byte `yaml:"foo"`
1498+
}
1499+
ctx := context.WithValue(context.Background(), "plop", uint(42))
1500+
b, err := yaml.MarshalContext(ctx, &T{Foo: []byte("bar")}, yaml.CustomMarshalerContext[[]byte](func(ctx context.Context, v []byte) ([]byte, error) {
1501+
if !bytes.Equal(v, []byte("bar")) {
1502+
t.Fatalf("failed to get src buffer: %q", v)
1503+
}
1504+
if ctx.Value("plop") != uint(42) {
1505+
t.Fatalf("context value is not correct")
1506+
}
1507+
return []byte(`override`), nil
1508+
}))
1509+
if err != nil {
1510+
t.Fatal(err)
1511+
}
1512+
if !bytes.Equal(b, []byte("foo: override\n")) {
1513+
t.Fatalf("failed to switch to custom marshaler. got: %q", b)
1514+
}
1515+
})
14951516
}
14961517

14971518
func TestEncoder_AutoInt(t *testing.T) {

option.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package yaml
22

33
import (
4+
"context"
45
"io"
56
"reflect"
67

@@ -101,13 +102,25 @@ func UseJSONUnmarshaler() DecodeOption {
101102
func CustomUnmarshaler[T any](unmarshaler func(*T, []byte) error) DecodeOption {
102103
return func(d *Decoder) error {
103104
var typ *T
104-
d.customUnmarshalerMap[reflect.TypeOf(typ)] = func(v interface{}, b []byte) error {
105+
d.customUnmarshalerMap[reflect.TypeOf(typ)] = func(ctx context.Context, v interface{}, b []byte) error {
105106
return unmarshaler(v.(*T), b)
106107
}
107108
return nil
108109
}
109110
}
110111

112+
// CustomUnmarshalerContext overrides any decoding process for the type specified in generics.
113+
// Similar to CustomUnmarshaler, but allows passing a context to the unmarshaler function.
114+
func CustomUnmarshalerContext[T any](unmarshaler func(context.Context, *T, []byte) error) DecodeOption {
115+
return func(d *Decoder) error {
116+
var typ *T
117+
d.customUnmarshalerMap[reflect.TypeOf(typ)] = func(ctx context.Context, v interface{}, b []byte) error {
118+
return unmarshaler(ctx, v.(*T), b)
119+
}
120+
return nil
121+
}
122+
}
123+
111124
// EncodeOption functional option type for Encoder
112125
type EncodeOption func(e *Encoder) error
113126

@@ -199,13 +212,25 @@ func UseJSONMarshaler() EncodeOption {
199212
func CustomMarshaler[T any](marshaler func(T) ([]byte, error)) EncodeOption {
200213
return func(e *Encoder) error {
201214
var typ T
202-
e.customMarshalerMap[reflect.TypeOf(typ)] = func(v interface{}) ([]byte, error) {
215+
e.customMarshalerMap[reflect.TypeOf(typ)] = func(ctx context.Context, v interface{}) ([]byte, error) {
203216
return marshaler(v.(T))
204217
}
205218
return nil
206219
}
207220
}
208221

222+
// CustomMarshalerContext overrides any encoding process for the type specified in generics.
223+
// Similar to CustomMarshaler, but allows passing a context to the marshaler function.
224+
func CustomMarshalerContext[T any](marshaler func(context.Context, T) ([]byte, error)) EncodeOption {
225+
return func(e *Encoder) error {
226+
var typ T
227+
e.customMarshalerMap[reflect.TypeOf(typ)] = func(ctx context.Context, v interface{}) ([]byte, error) {
228+
return marshaler(ctx, v.(T))
229+
}
230+
return nil
231+
}
232+
}
233+
209234
// AutoInt automatically converts floating-point numbers to integers when the fractional part is zero.
210235
// For example, a value of 1.0 will be encoded as 1.
211236
func AutoInt() EncodeOption {

testdata/yaml_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package yaml_test
22

33
import (
44
"bytes"
5+
"context"
56
"errors"
67
"strings"
78
"testing"
@@ -1084,6 +1085,26 @@ func TestRegisterCustomMarshaler(t *testing.T) {
10841085
}
10851086
}
10861087

1088+
func TestRegisterCustomMarshalerContext(t *testing.T) {
1089+
type T struct {
1090+
Foo []byte `yaml:"foo"`
1091+
}
1092+
yaml.RegisterCustomMarshalerContext[T](func(ctx context.Context, _ T) ([]byte, error) {
1093+
if ctx.Value("plop") != uint(42) {
1094+
t.Fatalf("context value is not correct")
1095+
}
1096+
return []byte(`"override"`), nil
1097+
})
1098+
ctx := context.WithValue(context.Background(), "plop", uint(42))
1099+
b, err := yaml.MarshalContext(ctx, &T{Foo: []byte("bar")})
1100+
if err != nil {
1101+
t.Fatal(err)
1102+
}
1103+
if !bytes.Equal(b, []byte("\"override\"\n")) {
1104+
t.Fatalf("failed to register custom marshaler. got: %q", b)
1105+
}
1106+
}
1107+
10871108
func TestRegisterCustomUnmarshaler(t *testing.T) {
10881109
type T struct {
10891110
Foo []byte `yaml:"foo"`
@@ -1100,3 +1121,24 @@ func TestRegisterCustomUnmarshaler(t *testing.T) {
11001121
t.Fatalf("failed to decode. got %q", v.Foo)
11011122
}
11021123
}
1124+
1125+
func TestRegisterCustomUnmarshalerContext(t *testing.T) {
1126+
type T struct {
1127+
Foo []byte `yaml:"foo"`
1128+
}
1129+
yaml.RegisterCustomUnmarshalerContext[T](func(ctx context.Context, v *T, _ []byte) error {
1130+
if ctx.Value("plop") != uint(42) {
1131+
t.Fatalf("context value is not correct")
1132+
}
1133+
v.Foo = []byte("override")
1134+
return nil
1135+
})
1136+
var v T
1137+
ctx := context.WithValue(context.Background(), "plop", uint(42))
1138+
if err := yaml.UnmarshalContext(ctx, []byte(`"foo": "bar"`), &v); err != nil {
1139+
t.Fatal(err)
1140+
}
1141+
if !bytes.Equal(v.Foo, []byte("override")) {
1142+
t.Fatalf("failed to decode. got %q", v.Foo)
1143+
}
1144+
}

yaml.go

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ func JSONToYAML(bytes []byte) ([]byte, error) {
266266
var (
267267
globalCustomMarshalerMu sync.Mutex
268268
globalCustomUnmarshalerMu sync.Mutex
269-
globalCustomMarshalerMap = map[reflect.Type]func(interface{}) ([]byte, error){}
270-
globalCustomUnmarshalerMap = map[reflect.Type]func(interface{}, []byte) error{}
269+
globalCustomMarshalerMap = map[reflect.Type]func(context.Context, interface{}) ([]byte, error){}
270+
globalCustomUnmarshalerMap = map[reflect.Type]func(context.Context, interface{}, []byte) error{}
271271
)
272272

273273
// RegisterCustomMarshaler overrides any encoding process for the type specified in generics.
@@ -281,11 +281,23 @@ func RegisterCustomMarshaler[T any](marshaler func(T) ([]byte, error)) {
281281
defer globalCustomMarshalerMu.Unlock()
282282

283283
var typ T
284-
globalCustomMarshalerMap[reflect.TypeOf(typ)] = func(v interface{}) ([]byte, error) {
284+
globalCustomMarshalerMap[reflect.TypeOf(typ)] = func(ctx context.Context, v interface{}) ([]byte, error) {
285285
return marshaler(v.(T))
286286
}
287287
}
288288

289+
// RegisterCustomMarshalerContext overrides any encoding process for the type specified in generics.
290+
// Similar to RegisterCustomMarshalerContext, but allows passing a context to the unmarshaler function.
291+
func RegisterCustomMarshalerContext[T any](marshaler func(context.Context, T) ([]byte, error)) {
292+
globalCustomMarshalerMu.Lock()
293+
defer globalCustomMarshalerMu.Unlock()
294+
295+
var typ T
296+
globalCustomMarshalerMap[reflect.TypeOf(typ)] = func(ctx context.Context, v interface{}) ([]byte, error) {
297+
return marshaler(ctx, v.(T))
298+
}
299+
}
300+
289301
// RegisterCustomUnmarshaler overrides any decoding process for the type specified in generics.
290302
// If you want to switch the behavior for each decoder, use `CustomUnmarshaler` defined as DecodeOption.
291303
//
@@ -296,7 +308,19 @@ func RegisterCustomUnmarshaler[T any](unmarshaler func(*T, []byte) error) {
296308
defer globalCustomUnmarshalerMu.Unlock()
297309

298310
var typ *T
299-
globalCustomUnmarshalerMap[reflect.TypeOf(typ)] = func(v interface{}, b []byte) error {
311+
globalCustomUnmarshalerMap[reflect.TypeOf(typ)] = func(ctx context.Context, v interface{}, b []byte) error {
300312
return unmarshaler(v.(*T), b)
301313
}
302314
}
315+
316+
// RegisterCustomUnmarshalerContext overrides any decoding process for the type specified in generics.
317+
// Similar to RegisterCustomUnmarshalerContext, but allows passing a context to the unmarshaler function.
318+
func RegisterCustomUnmarshalerContext[T any](unmarshaler func(context.Context, *T, []byte) error) {
319+
globalCustomUnmarshalerMu.Lock()
320+
defer globalCustomUnmarshalerMu.Unlock()
321+
322+
var typ *T
323+
globalCustomUnmarshalerMap[reflect.TypeOf(typ)] = func(ctx context.Context, v interface{}, b []byte) error {
324+
return unmarshaler(ctx, v.(*T), b)
325+
}
326+
}

0 commit comments

Comments
 (0)