Lindenii Project Forge
Warning: Due to various recent migrations, viewing non-HEAD refs may be broken.
/forged/internal/bare/unmarshal.go (raw)
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com>
package bare
import (
"bytes"
"errors"
"fmt"
"io"
"reflect"
"sync"
)
// Unmarshalable is implemented by types that decode themselves from a BARE
// stream instead of using the reflection-based decoder.
//
// The decoder checks whether a pointer to the target type implements this
// interface and, if so, calls Unmarshal on the value's address, so the method
// may be declared on the pointer receiver.
type Unmarshalable interface {
	// Unmarshal reads the receiver's encoding from r and stores it into the
	// receiver.
	Unmarshal(r *Reader) error
}
// Unmarshals a BARE message into val, which must be a pointer to a value of
// the message type.
func Unmarshal(data []byte, val interface{}) error {
b := bytes.NewReader(data)
r := NewReader(b)
return UnmarshalBareReader(r, val)
}
// UnmarshalReader decodes one BARE message read from r into val, which must
// be a pointer. The stream is first wrapped with newLimitedReader; see
// Unmarshal for the remaining details.
func UnmarshalReader(r io.Reader, val interface{}) error {
	var limited io.Reader = newLimitedReader(r)
	return UnmarshalBareReader(NewReader(limited), val)
}
// decodeFunc reads one value from a BARE stream and stores it into the
// settable reflect.Value it is given.
type decodeFunc func(*Reader, reflect.Value) error

// decodeFuncCache memoizes decoders: keys are reflect.Type, values are the
// decodeFunc built for that type.
var decodeFuncCache sync.Map
// UnmarshalBareReader decodes one BARE message from r into val, which must be
// a non-nil pointer to a value of the message type.
//
// It returns an error, rather than panicking, when val is nil, is not a
// pointer, or is a nil pointer.
func UnmarshalBareReader(r *Reader, val interface{}) error {
	v := reflect.ValueOf(val)
	// reflect.ValueOf(nil) is the zero Value, whose Kind is Invalid, so an
	// untyped nil is rejected here together with every other non-pointer.
	if v.Kind() != reflect.Ptr {
		return errors.New("Expected val to be pointer type")
	}
	if v.IsNil() {
		// A typed nil pointer has no storage to decode into: v.Elem() would be
		// the zero Value and the decoder would panic when setting it.
		return errors.New("Expected val to be a non-nil pointer")
	}
	return getDecoder(v.Type().Elem())(r, v.Elem())
}
// getDecoder returns the decodeFunc for t, building and caching it on first
// use.
//
// The builders for structs, arrays, slices, maps and unions look up their
// element decoders eagerly while they are being built. A self-referential
// type would therefore recurse until the stack overflows. Two examples are
// `type Node struct{ Children []Node }` and a union member holding a field of
// the union's own interface type.
//
// To break the cycle, an indirect placeholder is published in the cache
// before the real decoder is built. Recursive lookups receive the
// placeholder, which forwards to the real decoder once construction has
// finished.
func getDecoder(t reflect.Type) decodeFunc {
	if f, ok := decodeFuncCache.Load(t); ok {
		return f.(decodeFunc)
	}

	var (
		ready sync.WaitGroup
		built decodeFunc
	)
	ready.Add(1)
	placeholder := decodeFunc(func(r *Reader, v reflect.Value) error {
		// This only blocks if another goroutine decodes with the placeholder
		// before construction below has completed.
		ready.Wait()
		return built(r, v)
	})
	if f, loaded := decodeFuncCache.LoadOrStore(t, placeholder); loaded {
		// Another goroutine already claimed t; use whatever it published.
		return f.(decodeFunc)
	}

	built = decoderFunc(t)
	ready.Done()
	// Replace the placeholder so later lookups skip the indirection.
	decodeFuncCache.Store(t, built)
	return built
}
// unmarshalableInterface is the reflect.Type of the Unmarshalable interface.
// It is used to detect types that decode themselves.
var unmarshalableInterface = reflect.TypeOf(new(Unmarshalable)).Elem()
// decoderFunc builds, without caching, the decodeFunc for t. It checks in this
// order:
//   - types whose pointer implements Unmarshalable;
//   - interface types implementing unionInterface;
//   - the built-in kinds.
//
// Any other type yields a decoder that reports an UnsupportedTypeError.
func decoderFunc(t reflect.Type) decodeFunc {
	switch {
	case reflect.PointerTo(t).Implements(unmarshalableInterface):
		return func(r *Reader, v reflect.Value) error {
			return v.Addr().Interface().(Unmarshalable).Unmarshal(r)
		}
	case t.Kind() == reflect.Interface && t.Implements(unionInterface):
		return decodeUnion(t)
	}

	switch t.Kind() {
	case reflect.Ptr:
		return decodeOptional(t.Elem())
	case reflect.Struct:
		return decodeStruct(t)
	case reflect.Array:
		return decodeArray(t)
	case reflect.Slice:
		return decodeSlice(t)
	case reflect.Map:
		return decodeMap(t)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return decodeUint
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return decodeInt
	case reflect.Float32, reflect.Float64:
		return decodeFloat
	case reflect.Bool:
		return decodeBool
	case reflect.String:
		return decodeString
	default:
		return func(r *Reader, v reflect.Value) error {
			return &UnsupportedTypeError{v.Type()}
		}
	}
}
// decodeOptional returns a decoder for a BARE optional value mapped onto a Go
// pointer to t. The encoding starts with one byte: 0 means absent, and 1 means
// a value of type t follows. Any other byte is an error.
//
// An absent value sets the destination pointer to nil. Decoding into a
// previously populated value therefore does not leave a stale pointer behind.
func decodeOptional(t reflect.Type) decodeFunc {
	return func(r *Reader, v reflect.Value) error {
		s, err := r.ReadU8()
		if err != nil {
			return err
		}
		switch s {
		case 0:
			v.Set(reflect.Zero(v.Type()))
			return nil
		case 1:
			v.Set(reflect.New(t))
			// Looked up lazily so pointer-recursive types do not recurse while
			// their decoder is being built.
			return getDecoder(t)(r, v.Elem())
		default:
			return fmt.Errorf("Invalid optional value: %#x", s)
		}
	}
}
// decodeStruct returns a decoder that reads the fields of struct type t one
// after another in declaration order. Fields tagged `bare:"-"` are skipped
// entirely and left untouched.
func decodeStruct(t reflect.Type) decodeFunc {
	type fieldDecoder struct {
		index  int
		decode decodeFunc
	}

	var fields []fieldDecoder
	for i := 0; i < t.NumField(); i++ {
		sf := t.Field(i)
		if sf.Tag.Get("bare") == "-" {
			continue
		}
		fields = append(fields, fieldDecoder{index: i, decode: getDecoder(sf.Type)})
	}

	return func(r *Reader, v reflect.Value) error {
		for _, fd := range fields {
			if err := fd.decode(r, v.Field(fd.index)); err != nil {
				return err
			}
		}
		return nil
	}
}
// decodeArray returns a decoder for the fixed-length array type t. It reads
// exactly t.Len() elements back to back, with no length prefix, and stops at
// the first error.
func decodeArray(t reflect.Type) decodeFunc {
	elemDecode := getDecoder(t.Elem())
	count := t.Len()
	return func(r *Reader, v reflect.Value) error {
		for i := 0; i < count; i++ {
			if err := elemDecode(r, v.Index(i)); err != nil {
				return err
			}
		}
		return nil
	}
}
// decodeSlice returns a decoder for slice type t. The encoding is a length
// read with ReadUint, rejected when it exceeds maxArrayLength, followed by
// that many elements. The destination is replaced with a freshly allocated
// slice of exactly that length.
func decodeSlice(t reflect.Type) decodeFunc {
	elemDecode := getDecoder(t.Elem())
	return func(r *Reader, v reflect.Value) error {
		length, err := r.ReadUint()
		if err != nil {
			return err
		}
		if length > maxArrayLength {
			return fmt.Errorf("Array length %d exceeds configured limit of %d", length, maxArrayLength)
		}
		n := int(length)
		v.Set(reflect.MakeSlice(t, n, n))
		for i := 0; i < n; i++ {
			if err := elemDecode(r, v.Index(i)); err != nil {
				return err
			}
		}
		return nil
	}
}
// decodeMap returns a decoder for map type t. The encoding is an entry count
// read with ReadUint, rejected when it exceeds maxMapSize, followed by that
// many key/value pairs. A key that appears more than once is an error.
func decodeMap(t reflect.Type) decodeFunc {
	keyType := t.Key()
	keyf := getDecoder(keyType)
	valueType := t.Elem()
	valf := getDecoder(valueType)
	return func(r *Reader, v reflect.Value) error {
		size, err := r.ReadUint()
		if err != nil {
			return err
		}
		if size > maxMapSize {
			return fmt.Errorf("Map size %d exceeds configured limit of %d", size, maxMapSize)
		}
		v.Set(reflect.MakeMapWithSize(t, int(size)))

		// key and value are scratch storage reused for every entry. The map
		// stores copies, so each entry starts from the zero value. Anything a
		// decoder does not overwrite (for example an optional left untouched)
		// therefore cannot carry over from the previous entry.
		key := reflect.New(keyType).Elem()
		value := reflect.New(valueType).Elem()
		zeroKey := reflect.Zero(keyType)
		zeroValue := reflect.Zero(valueType)
		for i := uint64(0); i < size; i++ {
			key.Set(zeroKey)
			value.Set(zeroValue)
			if err := keyf(r, key); err != nil {
				return err
			}
			if v.MapIndex(key).IsValid() {
				return fmt.Errorf("Encountered duplicate map key: %v", key.Interface())
			}
			if err := valf(r, value); err != nil {
				return err
			}
			v.SetMapIndex(key, value)
		}
		return nil
	}
}
// decodeUnion returns a decoder for the union interface type t, based on the
// member types recorded for it in unionRegistry.
//
// Each encoded value is a tag read with ReadUint, followed by that member's
// encoding. On success the destination is set to a pointer to a freshly
// decoded member. An unregistered union, or a tag with no registered member,
// produces an error.
func decodeUnion(t reflect.Type) decodeFunc {
	ut, registered := unionRegistry[t]
	if !registered {
		return func(r *Reader, v reflect.Value) error {
			return fmt.Errorf("Union type %s is not registered", t.Name())
		}
	}

	// memberDecoder decodes one member type into a new *member and stores that
	// pointer in the destination interface.
	memberDecoder := func(member reflect.Type) decodeFunc {
		decode := getDecoder(member)
		return func(r *Reader, v reflect.Value) error {
			ptr := reflect.New(member)
			if err := decode(r, ptr.Elem()); err != nil {
				return err
			}
			v.Set(ptr)
			return nil
		}
	}

	byTag := make(map[uint64]decodeFunc)
	for tag, member := range ut.types {
		byTag[tag] = memberDecoder(member)
	}

	return func(r *Reader, v reflect.Value) error {
		tag, err := r.ReadUint()
		if err != nil {
			return err
		}
		decode, ok := byTag[tag]
		if !ok {
			return fmt.Errorf("Invalid union tag %d for type %s", tag, t.Name())
		}
		return decode(r, v)
	}
}
// decodeUint reads an unsigned integer into v. The read is chosen by
// getIntKind(v.Type()): reflect.Uint uses ReadUint, and each sized kind uses
// the matching ReadU8/16/32/64.
//
// v is assigned whatever the reader returned, even when it also returned an
// error. Any other kind panics.
func decodeUint(r *Reader, v reflect.Value) error {
	var (
		out uint64
		err error
	)
	switch getIntKind(v.Type()) {
	case reflect.Uint:
		out, err = r.ReadUint()
	case reflect.Uint8:
		var n uint8
		n, err = r.ReadU8()
		out = uint64(n)
	case reflect.Uint16:
		var n uint16
		n, err = r.ReadU16()
		out = uint64(n)
	case reflect.Uint32:
		var n uint32
		n, err = r.ReadU32()
		out = uint64(n)
	case reflect.Uint64:
		out, err = r.ReadU64()
	default:
		panic("not an uint")
	}
	v.SetUint(out)
	return err
}
// decodeInt reads a signed integer into v. The read is chosen by
// getIntKind(v.Type()): reflect.Int uses ReadInt, and each sized kind uses
// the matching ReadI8/16/32/64.
//
// v is assigned whatever the reader returned, even when it also returned an
// error. Any other kind panics.
func decodeInt(r *Reader, v reflect.Value) error {
	var (
		out int64
		err error
	)
	switch getIntKind(v.Type()) {
	case reflect.Int:
		out, err = r.ReadInt()
	case reflect.Int8:
		var n int8
		n, err = r.ReadI8()
		out = int64(n)
	case reflect.Int16:
		var n int16
		n, err = r.ReadI16()
		out = int64(n)
	case reflect.Int32:
		var n int32
		n, err = r.ReadI32()
		out = int64(n)
	case reflect.Int64:
		out, err = r.ReadI64()
	default:
		panic("not an int")
	}
	v.SetInt(out)
	return err
}
// decodeFloat reads an IEEE float into v, using ReadF32 for float32 kinds and
// ReadF64 for float64 kinds. v is assigned whatever the reader returned, even
// when it also returned an error. Any other kind panics.
func decodeFloat(r *Reader, v reflect.Value) error {
	var (
		out float64
		err error
	)
	switch v.Type().Kind() {
	case reflect.Float32:
		var single float32
		single, err = r.ReadF32()
		out = float64(single)
	case reflect.Float64:
		out, err = r.ReadF64()
	default:
		panic("not a float")
	}
	v.SetFloat(out)
	return err
}
// decodeBool reads a bool into v. v is assigned whatever the reader returned,
// even when it also returned an error.
func decodeBool(r *Reader, v reflect.Value) error {
	var (
		flag bool
		err  error
	)
	flag, err = r.ReadBool()
	v.SetBool(flag)
	return err
}
// decodeString reads a string into v. v is assigned whatever the reader
// returned, even when it also returned an error.
func decodeString(r *Reader, v reflect.Value) error {
	var (
		str string
		err error
	)
	str, err = r.ReadString()
	v.SetString(str)
	return err
}