Lindenii Project Forge
Login

server

Lindenii Forge’s main backend daemon
Commit info
ID
774b00069f8f55b750a6e7f7b99d88ec76801d91
Author
Runxi Yu <me@runxiyu.org>
Author date
Sun, 06 Apr 2025 11:36:51 +0800
Committer
Runxi Yu <me@runxiyu.org>
Committer date
Sun, 06 Apr 2025 11:36:51 +0800
Actions
Replace PtrTo with PointerTo
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com>

package bare

import (
	"bytes"
	"errors"
	"fmt"
	"reflect"
	"sync"
)

// A type which implements this interface will be responsible for marshaling
// itself when encountered.
type Marshalable interface {
	Marshal(w *Writer) error
}

var encoderBufferPool = sync.Pool{
	New: func() interface{} {
		buf := &bytes.Buffer{}
		buf.Grow(32)
		return buf
	},
}

// Marshals a value (val, which must be a pointer) into a BARE message.
//
// The encoding of each struct field can be customized by the format string
// stored under the "bare" key in the struct field's tag.
//
// As a special case, if the field tag is "-", the field is always omitted.
func Marshal(val interface{}) ([]byte, error) {
	// reuse buffers from previous serializations
	b := encoderBufferPool.Get().(*bytes.Buffer)
	defer func() {
		b.Reset()
		encoderBufferPool.Put(b)
	}()

	w := NewWriter(b)
	err := MarshalWriter(w, val)

	msg := make([]byte, b.Len())
	copy(msg, b.Bytes())

	return msg, err
}

// Marshals a value (val, which must be a pointer) into a BARE message and
// writes it to a Writer. See Marshal for details.
func MarshalWriter(w *Writer, val interface{}) error {
	t := reflect.TypeOf(val)
	v := reflect.ValueOf(val)
	if t.Kind() != reflect.Ptr {
		return errors.New("Expected val to be pointer type")
	}

	return getEncoder(t.Elem())(w, v.Elem())
}

type encodeFunc func(w *Writer, v reflect.Value) error

var encodeFuncCache sync.Map // map[reflect.Type]encodeFunc

// get decoder from cache
func getEncoder(t reflect.Type) encodeFunc {
	if f, ok := encodeFuncCache.Load(t); ok {
		return f.(encodeFunc)
	}

	f := encoderFunc(t)
	encodeFuncCache.Store(t, f)
	return f
}

var marshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem()

func encoderFunc(t reflect.Type) encodeFunc {
	if reflect.PtrTo(t).Implements(marshalableInterface) {
	if reflect.PointerTo(t).Implements(marshalableInterface) {
		return func(w *Writer, v reflect.Value) error {
			uv := v.Addr().Interface().(Marshalable)
			return uv.Marshal(w)
		}
	}

	if t.Kind() == reflect.Interface && t.Implements(unionInterface) {
		return encodeUnion(t)
	}

	switch t.Kind() {
	case reflect.Ptr:
		return encodeOptional(t.Elem())
	case reflect.Struct:
		return encodeStruct(t)
	case reflect.Array:
		return encodeArray(t)
	case reflect.Slice:
		return encodeSlice(t)
	case reflect.Map:
		return encodeMap(t)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return encodeUint
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return encodeInt
	case reflect.Float32, reflect.Float64:
		return encodeFloat
	case reflect.Bool:
		return encodeBool
	case reflect.String:
		return encodeString
	}

	return func(w *Writer, v reflect.Value) error {
		return &UnsupportedTypeError{v.Type()}
	}
}

func encodeOptional(t reflect.Type) encodeFunc {
	return func(w *Writer, v reflect.Value) error {
		if v.IsNil() {
			return w.WriteBool(false)
		}

		if err := w.WriteBool(true); err != nil {
			return err
		}

		return getEncoder(t)(w, v.Elem())
	}
}

func encodeStruct(t reflect.Type) encodeFunc {
	n := t.NumField()
	encoders := make([]encodeFunc, n)
	for i := 0; i < n; i++ {
		field := t.Field(i)
		if field.Tag.Get("bare") == "-" {
			continue
		}
		encoders[i] = getEncoder(field.Type)
	}

	return func(w *Writer, v reflect.Value) error {
		for i := 0; i < n; i++ {
			if encoders[i] == nil {
				continue
			}
			err := encoders[i](w, v.Field(i))
			if err != nil {
				return err
			}
		}
		return nil
	}
}

func encodeArray(t reflect.Type) encodeFunc {
	f := getEncoder(t.Elem())
	len := t.Len()

	return func(w *Writer, v reflect.Value) error {
		for i := 0; i < len; i++ {
			if err := f(w, v.Index(i)); err != nil {
				return err
			}
		}
		return nil
	}
}

func encodeSlice(t reflect.Type) encodeFunc {
	elem := t.Elem()
	f := getEncoder(elem)

	return func(w *Writer, v reflect.Value) error {
		if err := w.WriteUint(uint64(v.Len())); err != nil {
			return err
		}

		for i := 0; i < v.Len(); i++ {
			if err := f(w, v.Index(i)); err != nil {
				return err
			}
		}
		return nil
	}
}

func encodeMap(t reflect.Type) encodeFunc {
	keyType := t.Key()
	keyf := getEncoder(keyType)

	valueType := t.Elem()
	valf := getEncoder(valueType)

	return func(w *Writer, v reflect.Value) error {
		if err := w.WriteUint(uint64(v.Len())); err != nil {
			return err
		}

		iter := v.MapRange()
		for iter.Next() {
			if err := keyf(w, iter.Key()); err != nil {
				return err
			}
			if err := valf(w, iter.Value()); err != nil {
				return err
			}
		}
		return nil
	}
}

func encodeUnion(t reflect.Type) encodeFunc {
	ut, ok := unionRegistry[t]
	if !ok {
		return func(w *Writer, v reflect.Value) error {
			return fmt.Errorf("Union type %s is not registered", t.Name())
		}
	}

	encoders := make(map[uint64]encodeFunc)
	for tag, t := range ut.types {
		encoders[tag] = getEncoder(t)
	}

	return func(w *Writer, v reflect.Value) error {
		t := v.Elem().Type()
		if t.Kind() == reflect.Ptr {
			// If T is a valid union value type, *T is valid too.
			t = t.Elem()
			v = v.Elem()
		}
		tag, ok := ut.tags[t]
		if !ok {
			return fmt.Errorf("Invalid union value: %s", v.Elem().String())
		}

		if err := w.WriteUint(tag); err != nil {
			return err
		}

		return encoders[tag](w, v.Elem())
	}
}

func encodeUint(w *Writer, v reflect.Value) error {
	switch getIntKind(v.Type()) {
	case reflect.Uint:
		return w.WriteUint(v.Uint())

	case reflect.Uint8:
		return w.WriteU8(uint8(v.Uint()))

	case reflect.Uint16:
		return w.WriteU16(uint16(v.Uint()))

	case reflect.Uint32:
		return w.WriteU32(uint32(v.Uint()))

	case reflect.Uint64:
		return w.WriteU64(uint64(v.Uint()))
	}

	panic("not uint")
}

func encodeInt(w *Writer, v reflect.Value) error {
	switch getIntKind(v.Type()) {
	case reflect.Int:
		return w.WriteInt(v.Int())

	case reflect.Int8:
		return w.WriteI8(int8(v.Int()))

	case reflect.Int16:
		return w.WriteI16(int16(v.Int()))

	case reflect.Int32:
		return w.WriteI32(int32(v.Int()))

	case reflect.Int64:
		return w.WriteI64(int64(v.Int()))
	}

	panic("not int")
}

func encodeFloat(w *Writer, v reflect.Value) error {
	switch v.Type().Kind() {
	case reflect.Float32:
		return w.WriteF32(float32(v.Float()))
	case reflect.Float64:
		return w.WriteF64(v.Float())
	}

	panic("not float")
}

func encodeBool(w *Writer, v reflect.Value) error {
	return w.WriteBool(v.Bool())
}

func encodeString(w *Writer, v reflect.Value) error {
	if v.Kind() != reflect.String {
		panic("not string")
	}
	return w.WriteString(v.String())
}
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com>

package bare

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"reflect"
	"sync"
)

// A type which implements this interface will be responsible for unmarshaling
// itself when encountered.
type Unmarshalable interface {
	Unmarshal(r *Reader) error
}

// Unmarshals a BARE message into val, which must be a pointer to a value of
// the message type.
func Unmarshal(data []byte, val interface{}) error {
	b := bytes.NewReader(data)
	r := NewReader(b)
	return UnmarshalBareReader(r, val)
}

// Unmarshals a BARE message into value (val, which must be a pointer), from a
// reader. See Unmarshal for details.
func UnmarshalReader(r io.Reader, val interface{}) error {
	r = newLimitedReader(r)
	return UnmarshalBareReader(NewReader(r), val)
}

type decodeFunc func(r *Reader, v reflect.Value) error

var decodeFuncCache sync.Map // map[reflect.Type]decodeFunc

func UnmarshalBareReader(r *Reader, val interface{}) error {
	t := reflect.TypeOf(val)
	v := reflect.ValueOf(val)
	if t.Kind() != reflect.Ptr {
		return errors.New("Expected val to be pointer type")
	}

	return getDecoder(t.Elem())(r, v.Elem())
}

// get decoder from cache
func getDecoder(t reflect.Type) decodeFunc {
	if f, ok := decodeFuncCache.Load(t); ok {
		return f.(decodeFunc)
	}

	f := decoderFunc(t)
	decodeFuncCache.Store(t, f)
	return f
}

var unmarshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem()

func decoderFunc(t reflect.Type) decodeFunc {
	if reflect.PtrTo(t).Implements(unmarshalableInterface) {
	if reflect.PointerTo(t).Implements(unmarshalableInterface) {
		return func(r *Reader, v reflect.Value) error {
			uv := v.Addr().Interface().(Unmarshalable)
			return uv.Unmarshal(r)
		}
	}

	if t.Kind() == reflect.Interface && t.Implements(unionInterface) {
		return decodeUnion(t)
	}

	switch t.Kind() {
	case reflect.Ptr:
		return decodeOptional(t.Elem())
	case reflect.Struct:
		return decodeStruct(t)
	case reflect.Array:
		return decodeArray(t)
	case reflect.Slice:
		return decodeSlice(t)
	case reflect.Map:
		return decodeMap(t)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return decodeUint
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return decodeInt
	case reflect.Float32, reflect.Float64:
		return decodeFloat
	case reflect.Bool:
		return decodeBool
	case reflect.String:
		return decodeString
	}

	return func(r *Reader, v reflect.Value) error {
		return &UnsupportedTypeError{v.Type()}
	}
}

func decodeOptional(t reflect.Type) decodeFunc {
	return func(r *Reader, v reflect.Value) error {
		s, err := r.ReadU8()
		if err != nil {
			return err
		}

		if s > 1 {
			return fmt.Errorf("Invalid optional value: %#x", s)
		}

		if s == 0 {
			return nil
		}

		v.Set(reflect.New(t))
		return getDecoder(t)(r, v.Elem())
	}
}

func decodeStruct(t reflect.Type) decodeFunc {
	n := t.NumField()
	decoders := make([]decodeFunc, n)
	for i := 0; i < n; i++ {
		field := t.Field(i)
		if field.Tag.Get("bare") == "-" {
			continue
		}
		decoders[i] = getDecoder(field.Type)
	}

	return func(r *Reader, v reflect.Value) error {
		for i := 0; i < n; i++ {
			if decoders[i] == nil {
				continue
			}
			err := decoders[i](r, v.Field(i))
			if err != nil {
				return err
			}
		}
		return nil
	}
}

func decodeArray(t reflect.Type) decodeFunc {
	f := getDecoder(t.Elem())
	len := t.Len()

	return func(r *Reader, v reflect.Value) error {
		for i := 0; i < len; i++ {
			err := f(r, v.Index(i))
			if err != nil {
				return err
			}
		}
		return nil
	}
}

func decodeSlice(t reflect.Type) decodeFunc {
	elem := t.Elem()
	f := getDecoder(elem)

	return func(r *Reader, v reflect.Value) error {
		len, err := r.ReadUint()
		if err != nil {
			return err
		}

		if len > maxArrayLength {
			return fmt.Errorf("Array length %d exceeds configured limit of %d", len, maxArrayLength)
		}

		v.Set(reflect.MakeSlice(t, int(len), int(len)))

		for i := 0; i < int(len); i++ {
			if err := f(r, v.Index(i)); err != nil {
				return err
			}
		}
		return nil
	}
}

func decodeMap(t reflect.Type) decodeFunc {
	keyType := t.Key()
	keyf := getDecoder(keyType)

	valueType := t.Elem()
	valf := getDecoder(valueType)

	return func(r *Reader, v reflect.Value) error {
		size, err := r.ReadUint()
		if err != nil {
			return err
		}

		if size > maxMapSize {
			return fmt.Errorf("Map size %d exceeds configured limit of %d", size, maxMapSize)
		}

		v.Set(reflect.MakeMapWithSize(t, int(size)))

		key := reflect.New(keyType).Elem()
		value := reflect.New(valueType).Elem()

		for i := uint64(0); i < size; i++ {
			if err := keyf(r, key); err != nil {
				return err
			}

			if v.MapIndex(key).Kind() > reflect.Invalid {
				return fmt.Errorf("Encountered duplicate map key: %v", key.Interface())
			}

			if err := valf(r, value); err != nil {
				return err
			}

			v.SetMapIndex(key, value)
		}
		return nil
	}
}

func decodeUnion(t reflect.Type) decodeFunc {
	ut, ok := unionRegistry[t]
	if !ok {
		return func(r *Reader, v reflect.Value) error {
			return fmt.Errorf("Union type %s is not registered", t.Name())
		}
	}

	decoders := make(map[uint64]decodeFunc)
	for tag, t := range ut.types {
		t := t
		f := getDecoder(t)

		decoders[tag] = func(r *Reader, v reflect.Value) error {
			nv := reflect.New(t)
			if err := f(r, nv.Elem()); err != nil {
				return err
			}

			v.Set(nv)
			return nil
		}
	}

	return func(r *Reader, v reflect.Value) error {
		tag, err := r.ReadUint()
		if err != nil {
			return err
		}

		if f, ok := decoders[tag]; ok {
			return f(r, v)
		}

		return fmt.Errorf("Invalid union tag %d for type %s", tag, t.Name())
	}
}

func decodeUint(r *Reader, v reflect.Value) error {
	var err error
	switch getIntKind(v.Type()) {
	case reflect.Uint:
		var u uint64
		u, err = r.ReadUint()
		v.SetUint(u)

	case reflect.Uint8:
		var u uint8
		u, err = r.ReadU8()
		v.SetUint(uint64(u))

	case reflect.Uint16:
		var u uint16
		u, err = r.ReadU16()
		v.SetUint(uint64(u))
	case reflect.Uint32:
		var u uint32
		u, err = r.ReadU32()
		v.SetUint(uint64(u))

	case reflect.Uint64:
		var u uint64
		u, err = r.ReadU64()
		v.SetUint(uint64(u))

	default:
		panic("not an uint")
	}

	return err
}

func decodeInt(r *Reader, v reflect.Value) error {
	var err error
	switch getIntKind(v.Type()) {
	case reflect.Int:
		var i int64
		i, err = r.ReadInt()
		v.SetInt(i)

	case reflect.Int8:
		var i int8
		i, err = r.ReadI8()
		v.SetInt(int64(i))

	case reflect.Int16:
		var i int16
		i, err = r.ReadI16()
		v.SetInt(int64(i))
	case reflect.Int32:
		var i int32
		i, err = r.ReadI32()
		v.SetInt(int64(i))

	case reflect.Int64:
		var i int64
		i, err = r.ReadI64()
		v.SetInt(int64(i))

	default:
		panic("not an int")
	}

	return err
}

func decodeFloat(r *Reader, v reflect.Value) error {
	var err error
	switch v.Type().Kind() {
	case reflect.Float32:
		var f float32
		f, err = r.ReadF32()
		v.SetFloat(float64(f))
	case reflect.Float64:
		var f float64
		f, err = r.ReadF64()
		v.SetFloat(f)
	default:
		panic("not a float")
	}
	return err
}

func decodeBool(r *Reader, v reflect.Value) error {
	b, err := r.ReadBool()
	v.SetBool(b)
	return err
}

func decodeString(r *Reader, v reflect.Value) error {
	s, err := r.ReadString()
	v.SetString(s)
	return err
}
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: Copyright (c) 2020 Simon Ser <https://emersion.fr>
// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org>

package scfg

import (
	"encoding"
	"fmt"
	"io"
	"reflect"
	"strconv"
)

// Decoder reads and decodes an scfg document from an input stream.
type Decoder struct {
	r                 io.Reader
	unknownDirectives []*Directive
}

// NewDecoder returns a new decoder which reads from r.
func NewDecoder(r io.Reader) *Decoder {
	return &Decoder{r: r}
}

// UnknownDirectives returns a slice of all unknown directives encountered
// during Decode.
func (dec *Decoder) UnknownDirectives() []*Directive {
	return dec.unknownDirectives
}

// Decode reads scfg document from the input and stores it in the value pointed
// to by v.
//
// If v is nil or not a pointer, Decode returns an error.
//
// Blocks can be unmarshaled to:
//
//   - Maps. Each directive is unmarshaled into a map entry. The map key must
//     be a string.
//   - Structs. Each directive is unmarshaled into a struct field.
//
// Duplicate directives are not allowed, unless the struct field or map value
// is a slice of values representing a directive: structs or maps.
//
// Directives can be unmarshaled to:
//
//   - Maps. The children block is unmarshaled into the map. Parameters are not
//     allowed.
//   - Structs. The children block is unmarshaled into the struct. Parameters
//     are allowed if one of the struct fields contains the "param" option in
//     its tag.
//   - Slices. Parameters are unmarshaled into the slice. Children blocks are
//     not allowed.
//   - Arrays. Parameters are unmarshaled into the array. The number of
//     parameters must match exactly the length of the array. Children blocks
//     are not allowed.
//   - Strings, booleans, integers, floating-point values, values implementing
//     encoding.TextUnmarshaler. Only a single parameter is allowed and is
//     unmarshaled into the value. Children blocks are not allowed.
//
// The decoding of each struct field can be customized by the format string
// stored under the "scfg" key in the struct field's tag. The tag contains the
// name of the field possibly followed by a comma-separated list of options.
// The name may be empty in order to specify options without overriding the
// default field name. As a special case, if the field name is "-", the field
// is ignored. The "param" option specifies that directive parameters are
// stored in this field (the name must be empty).
func (dec *Decoder) Decode(v interface{}) error {
	block, err := Read(dec.r)
	if err != nil {
		return err
	}

	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return fmt.Errorf("scfg: invalid value for unmarshaling")
	}

	return dec.unmarshalBlock(block, rv)
}

func (dec *Decoder) unmarshalBlock(block Block, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	dirsByName := make(map[string][]*Directive, len(block))
	for _, dir := range block {
		dirsByName[dir.Name] = append(dirsByName[dir.Name], dir)
	}

	switch v.Kind() {
	case reflect.Map:
		if t.Key().Kind() != reflect.String {
			return fmt.Errorf("scfg: map key type must be string")
		}
		if v.IsNil() {
			v.Set(reflect.MakeMap(t))
		} else if v.Len() > 0 {
			clearMap(v)
		}

		for name, dirs := range dirsByName {
			mv := reflect.New(t.Elem()).Elem()
			if err := dec.unmarshalDirectiveList(dirs, mv); err != nil {
				return err
			}
			v.SetMapIndex(reflect.ValueOf(name), mv)
		}

	case reflect.Struct:
		si, err := getStructInfo(t)
		if err != nil {
			return err
		}

		seen := make(map[int]bool)

		for name, dirs := range dirsByName {
			fieldIndex, ok := si.children[name]
			if !ok {
				dec.unknownDirectives = append(dec.unknownDirectives, dirs...)
				continue
			}
			fv := v.Field(fieldIndex)
			if err := dec.unmarshalDirectiveList(dirs, fv); err != nil {
				return err
			}
			seen[fieldIndex] = true
		}

		for name, fieldIndex := range si.children {
			if fieldIndex == si.param {
				continue
			}
			if _, ok := seen[fieldIndex]; !ok {
				return fmt.Errorf("scfg: missing required directive %q", name)
			}
		}

	default:
		return fmt.Errorf("scfg: unsupported type for unmarshaling blocks: %v", t)
	}

	return nil
}

func (dec *Decoder) unmarshalDirectiveList(dirs []*Directive, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	if v.Kind() != reflect.Slice || !isDirectiveType(t.Elem()) {
		if len(dirs) > 1 {
			return newUnmarshalDirectiveError(dirs[1], "directive must not be specified more than once")
		}
		return dec.unmarshalDirective(dirs[0], v)
	}

	sv := reflect.MakeSlice(t, len(dirs), len(dirs))
	for i, dir := range dirs {
		if err := dec.unmarshalDirective(dir, sv.Index(i)); err != nil {
			return err
		}
	}
	v.Set(sv)
	return nil
}

// isDirectiveType checks whether a type can only be unmarshaled as a
// directive, not as a parameter. Accepting too many types here would result in
// ambiguities, see:
// https://lists.sr.ht/~emersion/public-inbox/%3C20230629132458.152205-1-contact%40emersion.fr%3E#%3Ch4Y2peS_YBqY3ar4XlmPDPiNBFpYGns3EBYUx3_6zWEhV2o8_-fBQveRujGADWYhVVCucHBEryFGoPtpC3d3mQ-x10pWnFogfprbQTSvtxc=@emersion.fr%3E
func isDirectiveType(t reflect.Type) bool {
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}

	textUnmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
	if reflect.PtrTo(t).Implements(textUnmarshalerType) {
	if reflect.PointerTo(t).Implements(textUnmarshalerType) {
		return false
	}

	switch t.Kind() {
	case reflect.Struct, reflect.Map:
		return true
	default:
		return false
	}
}

func (dec *Decoder) unmarshalDirective(dir *Directive, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	if v.CanAddr() {
		if _, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok {
			if len(dir.Children) != 0 {
				return newUnmarshalDirectiveError(dir, "directive requires zero children")
			}
			return unmarshalParamList(dir, v)
		}
	}

	switch v.Kind() {
	case reflect.Map:
		if len(dir.Params) > 0 {
			return newUnmarshalDirectiveError(dir, "directive requires zero parameters")
		}
		if err := dec.unmarshalBlock(dir.Children, v); err != nil {
			return err
		}
	case reflect.Struct:
		si, err := getStructInfo(t)
		if err != nil {
			return err
		}

		if si.param >= 0 {
			if err := unmarshalParamList(dir, v.Field(si.param)); err != nil {
				return err
			}
		} else {
			if len(dir.Params) > 0 {
				return newUnmarshalDirectiveError(dir, "directive requires zero parameters")
			}
		}

		if err := dec.unmarshalBlock(dir.Children, v); err != nil {
			return err
		}
	default:
		if len(dir.Children) != 0 {
			return newUnmarshalDirectiveError(dir, "directive requires zero children")
		}
		if err := unmarshalParamList(dir, v); err != nil {
			return err
		}
	}
	return nil
}

func unmarshalParamList(dir *Directive, v reflect.Value) error {
	switch v.Kind() {
	case reflect.Slice:
		t := v.Type()
		sv := reflect.MakeSlice(t, len(dir.Params), len(dir.Params))
		for i, param := range dir.Params {
			if err := unmarshalParam(param, sv.Index(i)); err != nil {
				return newUnmarshalParamError(dir, i, err)
			}
		}
		v.Set(sv)
	case reflect.Array:
		if len(dir.Params) != v.Len() {
			return newUnmarshalDirectiveError(dir, fmt.Sprintf("directive requires exactly %v parameters", v.Len()))
		}
		for i, param := range dir.Params {
			if err := unmarshalParam(param, v.Index(i)); err != nil {
				return newUnmarshalParamError(dir, i, err)
			}
		}
	default:
		if len(dir.Params) != 1 {
			return newUnmarshalDirectiveError(dir, "directive requires exactly one parameter")
		}
		if err := unmarshalParam(dir.Params[0], v); err != nil {
			return newUnmarshalParamError(dir, 0, err)
		}
	}

	return nil
}

func unmarshalParam(param string, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	// TODO: improve our logic following:
	// https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/encoding/json/decode.go;drc=b9b8cecbfc72168ca03ad586cc2ed52b0e8db409;l=421
	if v.CanAddr() {
		if v, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok {
			return v.UnmarshalText([]byte(param))
		}
	}

	switch v.Kind() {
	case reflect.String:
		v.Set(reflect.ValueOf(param))
	case reflect.Bool:
		switch param {
		case "true":
			v.Set(reflect.ValueOf(true))
		case "false":
			v.Set(reflect.ValueOf(false))
		default:
			return fmt.Errorf("invalid bool parameter %q", param)
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		i, err := strconv.ParseInt(param, 10, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid %v parameter: %v", t, err)
		}
		v.Set(reflect.ValueOf(i).Convert(t))
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		u, err := strconv.ParseUint(param, 10, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid %v parameter: %v", t, err)
		}
		v.Set(reflect.ValueOf(u).Convert(t))
	case reflect.Float32, reflect.Float64:
		f, err := strconv.ParseFloat(param, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid %v parameter: %v", t, err)
		}
		v.Set(reflect.ValueOf(f).Convert(t))
	default:
		return fmt.Errorf("unsupported type for unmarshaling parameter: %v", t)
	}

	return nil
}

func unwrapPointers(v reflect.Value) reflect.Value {
	for v.Kind() == reflect.Ptr {
		if v.IsNil() {
			v.Set(reflect.New(v.Type().Elem()))
		}
		v = v.Elem()
	}
	return v
}

func clearMap(v reflect.Value) {
	for _, k := range v.MapKeys() {
		v.SetMapIndex(k, reflect.Value{})
	}
}

type unmarshalDirectiveError struct {
	lineno int
	name   string
	msg    string
}

func newUnmarshalDirectiveError(dir *Directive, msg string) *unmarshalDirectiveError {
	return &unmarshalDirectiveError{
		name:   dir.Name,
		lineno: dir.lineno,
		msg:    msg,
	}
}

func (err *unmarshalDirectiveError) Error() string {
	return fmt.Sprintf("line %v, directive %q: %v", err.lineno, err.name, err.msg)
}

type unmarshalParamError struct {
	lineno     int
	directive  string
	paramIndex int
	err        error
}

func newUnmarshalParamError(dir *Directive, paramIndex int, err error) *unmarshalParamError {
	return &unmarshalParamError{
		directive:  dir.Name,
		lineno:     dir.lineno,
		paramIndex: paramIndex,
		err:        err,
	}
}

func (err *unmarshalParamError) Error() string {
	return fmt.Sprintf("line %v, directive %q, parameter %v: %v", err.lineno, err.directive, err.paramIndex+1, err.err)
}