Lindenii Project Forge
Login

server

Lindenii Forge’s main backend daemon

Warning: Due to various recent migrations, viewing non-HEAD refs may be broken.

/forged/internal/bare/unmarshal.go (raw)

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com>

package bare

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"reflect"
	"sync"
)

// Unmarshalable is implemented by types that know how to decode themselves.
// When the decoder reaches a value whose pointer type implements this
// interface, it calls Unmarshal on that pointer instead of applying the
// reflection-based decoding rules.
type Unmarshalable interface {
	Unmarshal(r *Reader) error
}

// Unmarshals a BARE message into val, which must be a pointer to a value of
// the message type.
func Unmarshal(data []byte, val interface{}) error {
	b := bytes.NewReader(data)
	r := NewReader(b)
	return UnmarshalBareReader(r, val)
}

// UnmarshalReader decodes one BARE message read from r into val, which must
// be a pointer. The stream is wrapped with newLimitedReader before decoding;
// otherwise it behaves like Unmarshal.
func UnmarshalReader(r io.Reader, val interface{}) error {
	var limited io.Reader = newLimitedReader(r)
	br := NewReader(limited)
	return UnmarshalBareReader(br, val)
}

// decodeFunc reads a single value from r and stores it into v. The decoders
// in this file write through v (Set, SetUint, Index, ...), so v must be
// settable.
type decodeFunc func(r *Reader, v reflect.Value) error

// decodeFuncCache memoizes decoders per Go type. Keys are reflect.Type values
// and stored values are decodeFunc.
var decodeFuncCache sync.Map

// UnmarshalBareReader decodes one BARE message from r into val.
//
// val must be a non-nil pointer to a value of the message type. A nil
// interface, a non-pointer, or a nil pointer is reported as an error rather
// than causing a panic.
func UnmarshalBareReader(r *Reader, val interface{}) error {
	v := reflect.ValueOf(val)
	// reflect.ValueOf(nil) yields the zero Value, whose Kind is Invalid, so a
	// nil interface is rejected here instead of panicking on a nil Type.
	if v.Kind() != reflect.Ptr {
		return errors.New("Expected val to be pointer type")
	}
	// A typed nil pointer has no element to decode into; v.Elem() would be
	// the zero Value and the first Set inside a decoder would panic.
	if v.IsNil() {
		return errors.New("Expected val to be a non-nil pointer")
	}

	return getDecoder(v.Type().Elem())(r, v.Elem())
}

// getDecoder returns the decoder for t, building it with decoderFunc and
// caching it on first use.
//
// Building a decoder for a self-referential type (for example a struct that
// contains a slice or map of itself) asks for the decoder of t again while
// t's own decoder is still being built. Without care that recursion never
// terminates. To break the cycle, an indirect placeholder is published in the
// cache before the real decoder is constructed: nested lookups receive the
// placeholder, which waits for and then forwards to the real decoder. The
// same LoadOrStore also ensures concurrent first callers share one build.
func getDecoder(t reflect.Type) decodeFunc {
	if fi, ok := decodeFuncCache.Load(t); ok {
		return fi.(decodeFunc)
	}

	var (
		wg sync.WaitGroup
		f  decodeFunc
	)
	wg.Add(1)
	fi, loaded := decodeFuncCache.LoadOrStore(t, decodeFunc(func(r *Reader, v reflect.Value) error {
		// Block until the real decoder below has been assigned.
		wg.Wait()
		return f(r, v)
	}))
	if loaded {
		return fi.(decodeFunc)
	}

	f = decoderFunc(t)
	wg.Done()
	// Replace the placeholder so later lookups skip the indirection.
	decodeFuncCache.Store(t, f)
	return f
}

// unmarshalableInterface is the reflect.Type of Unmarshalable, used to detect
// types that provide their own decoding.
var unmarshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem()

// decoderFunc constructs a fresh (uncached) decoder for t. Resolution order:
// a custom Unmarshalable implementation on *t wins, then registered union
// interfaces, then the built-in rules selected by t's kind. Anything else
// yields a decoder that reports UnsupportedTypeError.
func decoderFunc(t reflect.Type) decodeFunc {
	if reflect.PointerTo(t).Implements(unmarshalableInterface) {
		return func(r *Reader, v reflect.Value) error {
			return v.Addr().Interface().(Unmarshalable).Unmarshal(r)
		}
	}

	kind := t.Kind()
	if kind == reflect.Interface && t.Implements(unionInterface) {
		return decodeUnion(t)
	}

	switch kind {
	case reflect.Ptr:
		return decodeOptional(t.Elem())
	case reflect.Struct:
		return decodeStruct(t)
	case reflect.Array:
		return decodeArray(t)
	case reflect.Slice:
		return decodeSlice(t)
	case reflect.Map:
		return decodeMap(t)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return decodeUint
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return decodeInt
	case reflect.Float32, reflect.Float64:
		return decodeFloat
	case reflect.Bool:
		return decodeBool
	case reflect.String:
		return decodeString
	default:
		return func(r *Reader, v reflect.Value) error {
			return &UnsupportedTypeError{v.Type()}
		}
	}
}

// decodeOptional returns a decoder for an optional value of type t, held in
// Go as a pointer to t. A single tag byte is read first:
//   - 0: the value is absent and the destination pointer is set to nil;
//   - 1: the value is present, a new t is allocated and decoded into;
//   - anything else is rejected as invalid.
func decodeOptional(t reflect.Type) decodeFunc {
	return func(r *Reader, v reflect.Value) error {
		s, err := r.ReadU8()
		if err != nil {
			return err
		}

		switch s {
		case 0:
			// Clear any pointer left in a reused or pre-populated destination
			// so an absent optional always decodes to nil.
			if !v.IsNil() {
				v.Set(reflect.Zero(v.Type()))
			}
			return nil
		case 1:
			// The element decoder is resolved at decode time rather than when
			// this decoder is built, so constructing decoders for
			// self-referential types such as `type T struct{ Next *T }` does
			// not recurse.
			v.Set(reflect.New(t))
			return getDecoder(t)(r, v.Elem())
		default:
			return fmt.Errorf("Invalid optional value: %#x", s)
		}
	}
}

// decodeStruct builds a decoder for struct type t that decodes its fields one
// after another in declaration order. Fields tagged `bare:"-"` get no decoder
// and are left untouched.
func decodeStruct(t reflect.Type) decodeFunc {
	fieldDecoders := make([]decodeFunc, t.NumField())
	for i := range fieldDecoders {
		field := t.Field(i)
		if field.Tag.Get("bare") != "-" {
			fieldDecoders[i] = getDecoder(field.Type)
		}
	}

	return func(r *Reader, v reflect.Value) error {
		for i, decode := range fieldDecoders {
			if decode == nil {
				// Skipped field.
				continue
			}
			if err := decode(r, v.Field(i)); err != nil {
				return err
			}
		}
		return nil
	}
}

// decodeArray builds a decoder for the fixed-length array type t. No length
// is read from the stream; exactly t.Len() elements are decoded in order into
// v's elements.
func decodeArray(t reflect.Type) decodeFunc {
	elemDecoder := getDecoder(t.Elem())
	// Named n rather than len to avoid shadowing the builtin.
	n := t.Len()

	return func(r *Reader, v reflect.Value) error {
		for i := 0; i < n; i++ {
			if err := elemDecoder(r, v.Index(i)); err != nil {
				return err
			}
		}
		return nil
	}
}

// decodeSlice builds a decoder for slice type t. The element count is read
// with ReadUint and checked against maxArrayLength before the slice is
// allocated. Then that many elements are decoded in order. The destination is
// always replaced with a newly made slice of exactly that length.
func decodeSlice(t reflect.Type) decodeFunc {
	elemDecoder := getDecoder(t.Elem())

	return func(r *Reader, v reflect.Value) error {
		// Named n rather than len to avoid shadowing the builtin.
		n, err := r.ReadUint()
		if err != nil {
			return err
		}

		if n > maxArrayLength {
			return fmt.Errorf("Array length %d exceeds configured limit of %d", n, maxArrayLength)
		}

		count := int(n)
		v.Set(reflect.MakeSlice(t, count, count))

		for i := 0; i < count; i++ {
			if err := elemDecoder(r, v.Index(i)); err != nil {
				return err
			}
		}
		return nil
	}
}

// decodeMap builds a decoder for map type t. The entry count is read with
// ReadUint and checked against maxMapSize. Then count key/value pairs are
// decoded and inserted into a freshly made map. A key that appears twice is
// an error.
func decodeMap(t reflect.Type) decodeFunc {
	keyType := t.Key()
	keyf := getDecoder(keyType)

	valueType := t.Elem()
	valf := getDecoder(valueType)

	return func(r *Reader, v reflect.Value) error {
		size, err := r.ReadUint()
		if err != nil {
			return err
		}

		if size > maxMapSize {
			return fmt.Errorf("Map size %d exceeds configured limit of %d", size, maxMapSize)
		}

		v.Set(reflect.MakeMapWithSize(t, int(size)))

		for i := uint64(0); i < size; i++ {
			// Decode every entry into freshly zeroed storage. Reusing one
			// key/value across iterations would let state from the previous
			// entry survive in anything a decoder may leave untouched (e.g.
			// an absent optional, or a custom Unmarshal that does not reset
			// its receiver).
			key := reflect.New(keyType).Elem()
			if err := keyf(r, key); err != nil {
				return err
			}

			if v.MapIndex(key).IsValid() {
				return fmt.Errorf("Encountered duplicate map key: %v", key.Interface())
			}

			value := reflect.New(valueType).Elem()
			if err := valf(r, value); err != nil {
				return err
			}

			v.SetMapIndex(key, value)
		}
		return nil
	}
}

// decodeUnion builds a decoder for the union interface type t. The tag is
// read with ReadUint and selects a member type from t's unionRegistry entry.
// A new value of that member type is decoded and stored (as a pointer) in the
// interface. The registry is consulted only when the decoder is built. If t
// is not registered at that point, the returned decoder always fails.
func decodeUnion(t reflect.Type) decodeFunc {
	ut, registered := unionRegistry[t]
	if !registered {
		return func(r *Reader, v reflect.Value) error {
			return fmt.Errorf("Union type %s is not registered", t.Name())
		}
	}

	decoders := make(map[uint64]decodeFunc)
	for tag, memberType := range ut.types {
		memberType := memberType // per-iteration copy for the closure (pre-Go 1.22)
		memberDecoder := getDecoder(memberType)

		decoders[tag] = func(r *Reader, v reflect.Value) error {
			ptr := reflect.New(memberType)
			if err := memberDecoder(r, ptr.Elem()); err != nil {
				return err
			}
			v.Set(ptr)
			return nil
		}
	}

	return func(r *Reader, v reflect.Value) error {
		tag, err := r.ReadUint()
		if err != nil {
			return err
		}

		decode, known := decoders[tag]
		if !known {
			return fmt.Errorf("Invalid union tag %d for type %s", tag, t.Name())
		}
		return decode(r, v)
	}
}

// decodeUint decodes an unsigned integer into v. The wire width is chosen by
// getIntKind(v.Type()) (defined elsewhere; presumably it maps the Go type to
// its BARE integer encoding — confirm). reflect.Uint selects the
// variable-length ReadUint; the others select fixed-width reads.
//
// v is written only after a successful read. A decoded value that does not
// fit v's actual Go type is rejected instead of being silently truncated by
// SetUint. This matters for the variable-length case on platforms where
// uint is 32 bits.
func decodeUint(r *Reader, v reflect.Value) error {
	var (
		u   uint64
		err error
	)
	switch getIntKind(v.Type()) {
	case reflect.Uint:
		u, err = r.ReadUint()
	case reflect.Uint8:
		var x uint8
		x, err = r.ReadU8()
		u = uint64(x)
	case reflect.Uint16:
		var x uint16
		x, err = r.ReadU16()
		u = uint64(x)
	case reflect.Uint32:
		var x uint32
		x, err = r.ReadU32()
		u = uint64(x)
	case reflect.Uint64:
		u, err = r.ReadU64()
	default:
		panic("not an uint")
	}
	if err != nil {
		return err
	}

	if v.OverflowUint(u) {
		return fmt.Errorf("Value %d overflows %s", u, v.Type())
	}
	v.SetUint(u)
	return nil
}

// decodeInt decodes a signed integer into v. The wire width is chosen by
// getIntKind(v.Type()) (defined elsewhere; presumably it maps the Go type to
// its BARE integer encoding — confirm). reflect.Int selects the
// variable-length ReadInt; the others select fixed-width reads.
//
// v is written only after a successful read. A decoded value that does not
// fit v's actual Go type is rejected instead of being silently truncated by
// SetInt. This matters for the variable-length case on platforms where int
// is 32 bits.
func decodeInt(r *Reader, v reflect.Value) error {
	var (
		i   int64
		err error
	)
	switch getIntKind(v.Type()) {
	case reflect.Int:
		i, err = r.ReadInt()
	case reflect.Int8:
		var x int8
		x, err = r.ReadI8()
		i = int64(x)
	case reflect.Int16:
		var x int16
		x, err = r.ReadI16()
		i = int64(x)
	case reflect.Int32:
		var x int32
		x, err = r.ReadI32()
		i = int64(x)
	case reflect.Int64:
		i, err = r.ReadI64()
	default:
		panic("not an int")
	}
	if err != nil {
		return err
	}

	if v.OverflowInt(i) {
		return fmt.Errorf("Value %d overflows %s", i, v.Type())
	}
	v.SetInt(i)
	return nil
}

// decodeFloat reads a 32- or 64-bit float, matching v's kind, and stores it
// in v. The value is stored even when the read reports an error, and that
// error is returned.
func decodeFloat(r *Reader, v reflect.Value) error {
	if v.Kind() == reflect.Float32 {
		f32, err := r.ReadF32()
		v.SetFloat(float64(f32))
		return err
	}
	if v.Kind() == reflect.Float64 {
		f64, err := r.ReadF64()
		v.SetFloat(f64)
		return err
	}
	panic("not a float")
}

// decodeBool reads a boolean with ReadBool and stores it in v, returning any
// read error after the store.
func decodeBool(r *Reader, v reflect.Value) error {
	flag, readErr := r.ReadBool()
	v.SetBool(flag)
	return readErr
}

// decodeString reads a string with ReadString and stores it in v, returning
// any read error after the store.
func decodeString(r *Reader, v reflect.Value) error {
	str, readErr := r.ReadString()
	v.SetString(str)
	return readErr
}