Lindenii Project Forge
Login

go-lindenii-common

Common library for the Lindenii Project
Commit info
ID
b6c8304364d1457fa7454da257f7ad3d71131b91
Author
Runxi Yu <me@runxiyu.org>
Author date
Sun, 29 Dec 2024 18:40:22 +0000
Committer
Runxi Yu <me@runxiyu.org>
Committer date
Sun, 29 Dec 2024 18:40:22 +0000
Actions
Add ~emersion/go-scfg as a subdirectory
// Inspired by github.com/SaveTheRbtz/generic-sync-map-go but technically
// written from scratch with Go 1.23's sync.Map.
// Copyright 2024 Runxi Yu (porting it to generics)
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package lindenii_common
package cmap

import (
	"sync"
	"sync/atomic"
	"unsafe"
)

// Map[K comparable, V comparable] is like a Go map[K]V but is safe for concurrent use
// by multiple goroutines without additional locking or coordination.  Loads,
// stores, and deletes run in amortized constant time.
//
// The Map type is optimized for two common use cases: (1) when the entry for a given
// key is only ever written once but read many times, as in caches that only grow,
// or (2) when multiple goroutines read, write, and overwrite entries for disjoint
// sets of keys. In these two cases, use of a Map may significantly reduce lock
// contention compared to a Go map paired with a separate [Mutex] or [RWMutex].
//
// The zero Map is empty and ready for use. A Map must not be copied after first use.
//
// In the terminology of [the Go memory model], Map arranges that a write operation
// “synchronizes before” any read operation that observes the effect of the write, where
// read and write operations are defined as follows.
// [Map.Load], [Map.LoadAndDelete], [Map.LoadOrStore], [Map.Swap], [Map.CompareAndSwap],
// and [Map.CompareAndDelete] are read operations;
// [Map.Delete], [Map.LoadAndDelete], [Map.Store], and [Map.Swap] are write operations;
// [Map.LoadOrStore] is a write operation when it returns loaded set to false;
// [Map.CompareAndSwap] is a write operation when it returns swapped set to true;
// and [Map.CompareAndDelete] is a write operation when it returns deleted set to true.
//
// [the Go memory model]: https://go.dev/ref/mem
type Map[K comparable, V comparable] struct {
	mu sync.Mutex

	// read contains the portion of the map's contents that are safe for
	// concurrent access (with or without mu held).
	//
	// The read field itself is always safe to load, but must only be stored with
	// mu held.
	//
	// Entries stored in read may be updated concurrently without mu, but updating
	// a previously-expunged entry requires that the entry be copied to the dirty
	// map and unexpunged with mu held.
	read atomic.Pointer[readOnly[K, V]]

	// dirty contains the portion of the map's contents that require mu to be
	// held. To ensure that the dirty map can be promoted to the read map quickly,
	// it also includes all of the non-expunged entries in the read map.
	//
	// Expunged entries are not stored in the dirty map. An expunged entry in the
	// clean map must be unexpunged and added to the dirty map before a new value
	// can be stored to it.
	//
	// If the dirty map is nil, the next write to the map will initialize it by
	// making a shallow copy of the clean map, omitting stale entries.
	dirty map[K]*entry[V]

	// misses counts the number of loads since the read map was last updated that
	// needed to lock mu to determine whether the key was present.
	//
	// Once enough misses have occurred to cover the cost of copying the dirty
	// map, the dirty map will be promoted to the read map (in the unamended
	// state) and the next store to the map will make a new dirty copy.
	misses int
}

// readOnly is an immutable struct stored atomically in the Map.read field.
type readOnly[K comparable, V comparable] struct {
	m       map[K]*entry[V]
	amended bool // true if the dirty map contains some key not in m.
}

// expunged is an arbitrary pointer that marks entries which have been deleted
// from the dirty map.
var expunged = unsafe.Pointer(new(any))

// An entry is a slot in the map corresponding to a particular key.
type entry[V comparable] struct {
	// p points to the interface{} value stored for the entry.
	//
	// If p == nil, the entry has been deleted, and either m.dirty == nil or
	// m.dirty[key] is e.
	//
	// If p == expunged, the entry has been deleted, m.dirty != nil, and the entry
	// is missing from m.dirty.
	//
	// Otherwise, the entry is valid and recorded in m.read.m[key] and, if m.dirty
	// != nil, in m.dirty[key].
	//
	// An entry can be deleted by atomic replacement with nil: when m.dirty is
	// next created, it will atomically replace nil with expunged and leave
	// m.dirty[key] unset.
	//
	// An entry's associated value can be updated by atomic replacement, provided
	// p != expunged. If p == expunged, an entry's associated value can be updated
	// only after first setting m.dirty[key] = e so that lookups using the dirty
	// map find the entry.
	p unsafe.Pointer
}

func newEntry[V comparable](i V) *entry[V] {
	return &entry[V]{p: unsafe.Pointer(&i)}
}

func (m *Map[K, V]) loadReadOnly() readOnly[K, V] {
	if p := m.read.Load(); p != nil {
		return *p
	}
	return readOnly[K, V]{}
}

// Load returns the value stored in the map for a key, or nil if no
// value is present.
// The ok result indicates whether value was found in the map.
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
	read := m.loadReadOnly()
	e, ok := read.m[key]
	if !ok && read.amended {
		m.mu.Lock()
		// Avoid reporting a spurious miss if m.dirty got promoted while we were
		// blocked on m.mu. (If further loads of the same key will not miss, it's
		// not worth copying the dirty map for this key.)
		read = m.loadReadOnly()
		e, ok = read.m[key]
		if !ok && read.amended {
			e, ok = m.dirty[key]
			// Regardless of whether the entry was present, record a miss: this key
			// will take the slow path until the dirty map is promoted to the read
			// map.
			m.missLocked()
		}
		m.mu.Unlock()
	}
	if !ok {
		return *new(V), false
	}
	return e.load()
}

func (e *entry[V]) load() (value V, ok bool) {
	p := atomic.LoadPointer(&e.p)
	if p == nil || p == expunged {
		return value, false
	}
	return *(*V)(p), true
}

// Store sets the value for a key.
func (m *Map[K, V]) Store(key K, value V) {
	_, _ = m.Swap(key, value)
}

// Clear deletes all the entries, resulting in an empty Map.
func (m *Map[K, V]) Clear() {
	read := m.loadReadOnly()
	if len(read.m) == 0 && !read.amended {
		// Avoid allocating a new readOnly when the map is already clear.
		return
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	read = m.loadReadOnly()
	if len(read.m) > 0 || read.amended {
		m.read.Store(&readOnly[K, V]{})
	}

	clear(m.dirty)
	// Don't immediately promote the newly-cleared dirty map on the next operation.
	m.misses = 0
}

// tryCompareAndSwap compare the entry with the given old value and swaps
// it with a new value if the entry is equal to the old value, and the entry
// has not been expunged.
//
// If the entry is expunged, tryCompareAndSwap returns false and leaves
// the entry unchanged.
func (e *entry[V]) tryCompareAndSwap(old V, new V) bool {
	p := atomic.LoadPointer(&e.p)
	if p == nil || p == expunged || *(*V)(p) != old { // XXX
		return false
	}

	// Copy the interface after the first load to make this method more amenable
	// to escape analysis: if the comparison fails from the start, we shouldn't
	// bother heap-allocating an interface value to store.
	nc := new
	for {
		if atomic.CompareAndSwapPointer(&e.p, p, unsafe.Pointer(&nc)) {
			return true
		}
		p = atomic.LoadPointer(&e.p)
		if p == nil || p == expunged || *(*V)(p) != old {
			return false
		}
	}
}

// unexpungeLocked ensures that the entry is not marked as expunged.
//
// If the entry was previously expunged, it must be added to the dirty map
// before m.mu is unlocked.
func (e *entry[V]) unexpungeLocked() (wasExpunged bool) {
	return atomic.CompareAndSwapPointer(&e.p, expunged, nil)
}

// swapLocked unconditionally swaps a value into the entry.
//
// The entry must be known not to be expunged.
func (e *entry[V]) swapLocked(i *V) *V {
	return (*V)(atomic.SwapPointer(&e.p, unsafe.Pointer(i)))
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
	// Avoid locking if it's a clean hit.
	read := m.loadReadOnly()
	if e, ok := read.m[key]; ok {
		actual, loaded, ok := e.tryLoadOrStore(value)
		if ok {
			return actual, loaded
		}
	}

	m.mu.Lock()
	read = m.loadReadOnly()
	if e, ok := read.m[key]; ok {
		if e.unexpungeLocked() {
			m.dirty[key] = e
		}
		actual, loaded, _ = e.tryLoadOrStore(value)
	} else if e, ok := m.dirty[key]; ok {
		actual, loaded, _ = e.tryLoadOrStore(value)
		m.missLocked()
	} else {
		if !read.amended {
			// We're adding the first new key to the dirty map.
			// Make sure it is allocated and mark the read-only map as incomplete.
			m.dirtyLocked()
			m.read.Store(&readOnly[K, V]{m: read.m, amended: true})
		}
		m.dirty[key] = newEntry(value)
		actual, loaded = value, false
	}
	m.mu.Unlock()

	return actual, loaded
}

// tryLoadOrStore atomically loads or stores a value if the entry is not
// expunged.
//
// If the entry is expunged, tryLoadOrStore leaves the entry unchanged and
// returns with ok==false.
func (e *entry[V]) tryLoadOrStore(i V) (actual V, loaded, ok bool) {
	p := atomic.LoadPointer(&e.p)
	if p == expunged {
		return actual, false, false
	}
	if p != nil {
		return *(*V)(p), true, true
	}

	// Copy the interface after the first load to make this method more amenable
	// to escape analysis: if we hit the "load" path or the entry is expunged, we
	// shouldn't bother heap-allocating.
	ic := i
	for {
		if atomic.CompareAndSwapPointer(&e.p, nil, unsafe.Pointer(&ic)) {
			return i, false, true
		}
		p = atomic.LoadPointer(&e.p)
		if p == expunged {
			return actual, false, false
		}
		if p != nil {
			return *(*V)(p), true, true
		}
	}
}

// LoadAndDelete deletes the value for a key, returning the previous value if any.
// The loaded result reports whether the key was present.
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
	read := m.loadReadOnly()
	e, ok := read.m[key]
	if !ok && read.amended {
		m.mu.Lock()
		read = m.loadReadOnly()
		e, ok = read.m[key]
		if !ok && read.amended {
			e, ok = m.dirty[key]
			delete(m.dirty, key)
			// Regardless of whether the entry was present, record a miss: this key
			// will take the slow path until the dirty map is promoted to the read
			// map.
			m.missLocked()
		}
		m.mu.Unlock()
	}
	if ok {
		return e.delete()
	}
	return value, false
}

// Delete deletes the value for a key.
func (m *Map[K, V]) Delete(key K) {
	m.LoadAndDelete(key)
}

func (e *entry[V]) delete() (value V, ok bool) {
	for {
		p := atomic.LoadPointer(&e.p)
		if p == nil || p == expunged {
			return value, false
		}
		if atomic.CompareAndSwapPointer(&e.p, p, nil) {
			return *(*V)(p), true
		}
	}
}

// trySwap swaps a value if the entry has not been expunged.
//
// If the entry is expunged, trySwap returns false and leaves the entry
// unchanged.
func (e *entry[V]) trySwap(i *V) (*V, bool) {
	for {
		p := atomic.LoadPointer(&e.p)
		if p == expunged {
			return nil, false
		}
		if atomic.CompareAndSwapPointer(&e.p, p, unsafe.Pointer(i)) {
			return (*V)(p), true
		}
	}
}

// Swap swaps the value for a key and returns the previous value if any.
// The loaded result reports whether the key was present.
func (m *Map[K, V]) Swap(key K, value V) (previous V, loaded bool) {
	read := m.loadReadOnly()
	if e, ok := read.m[key]; ok {
		if v, ok := e.trySwap(&value); ok {
			if v == nil {
				return previous, false
			}
			return *v, true
		}
	}

	m.mu.Lock()
	read = m.loadReadOnly()
	if e, ok := read.m[key]; ok {
		if e.unexpungeLocked() {
			// The entry was previously expunged, which implies that there is a
			// non-nil dirty map and this entry is not in it.
			m.dirty[key] = e
		}
		if v := e.swapLocked(&value); v != nil {
			loaded = true
			previous = *v
		}
	} else if e, ok := m.dirty[key]; ok {
		if v := e.swapLocked(&value); v != nil {
			loaded = true
			previous = *v
		}
	} else {
		if !read.amended {
			// We're adding the first new key to the dirty map.
			// Make sure it is allocated and mark the read-only map as incomplete.
			m.dirtyLocked()
			m.read.Store(&readOnly[K, V]{m: read.m, amended: true})
		}
		m.dirty[key] = newEntry(value)
	}
	m.mu.Unlock()
	return previous, loaded
}

// CompareAndSwap swaps the old and new values for key
// if the value stored in the map is equal to old.
// The old value must be of a comparable type.
func (m *Map[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
	read := m.loadReadOnly()
	if e, ok := read.m[key]; ok {
		return e.tryCompareAndSwap(old, new)
	} else if !read.amended {
		return false // No existing value for key.
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	read = m.loadReadOnly()
	swapped = false
	if e, ok := read.m[key]; ok {
		swapped = e.tryCompareAndSwap(old, new)
	} else if e, ok := m.dirty[key]; ok {
		swapped = e.tryCompareAndSwap(old, new)
		// We needed to lock mu in order to load the entry for key,
		// and the operation didn't change the set of keys in the map
		// (so it would be made more efficient by promoting the dirty
		// map to read-only).
		// Count it as a miss so that we will eventually switch to the
		// more efficient steady state.
		m.missLocked()
	}
	return swapped
}

// CompareAndDelete deletes the entry for key if its value is equal to old.
// The old value must be of a comparable type.
//
// If there is no current value for key in the map, CompareAndDelete
// returns false (even if the old value is the nil interface value).
func (m *Map[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
	read := m.loadReadOnly()
	e, ok := read.m[key]
	if !ok && read.amended {
		m.mu.Lock()
		read = m.loadReadOnly()
		e, ok = read.m[key]
		if !ok && read.amended {
			e, ok = m.dirty[key]
			// Don't delete key from m.dirty: we still need to do the “compare” part
			// of the operation. The entry will eventually be expunged when the
			// dirty map is promoted to the read map.
			//
			// Regardless of whether the entry was present, record a miss: this key
			// will take the slow path until the dirty map is promoted to the read
			// map.
			m.missLocked()
		}
		m.mu.Unlock()
	}
	for ok {
		p := atomic.LoadPointer(&e.p)
		if p == nil || p == expunged || *(*V)(p) != old {
			return false
		}
		if atomic.CompareAndSwapPointer(&e.p, p, nil) {
			return true
		}
	}
	return false
}

// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot of the Map's
// contents: no key will be visited more than once, but if the value for any key
// is stored or deleted concurrently (including by f), Range may reflect any
// mapping for that key from any point during the Range call. Range does not
// block other methods on the receiver; even f itself may call any method on m.
//
// Range may be O(N) with the number of elements in the map even if f returns
// false after a constant number of calls.
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
	// We need to be able to iterate over all of the keys that were already
	// present at the start of the call to Range.
	// If read.amended is false, then read.m satisfies that property without
	// requiring us to hold m.mu for a long time.
	read := m.loadReadOnly()
	if read.amended {
		// m.dirty contains keys not in read.m. Fortunately, Range is already O(N)
		// (assuming the caller does not break out early), so a call to Range
		// amortizes an entire copy of the map: we can promote the dirty copy
		// immediately!
		m.mu.Lock()
		read = m.loadReadOnly()
		if read.amended {
			read = readOnly[K, V]{m: m.dirty}
			copyRead := read
			m.read.Store(&copyRead)
			m.dirty = nil
			m.misses = 0
		}
		m.mu.Unlock()
	}

	for k, e := range read.m {
		v, ok := e.load()
		if !ok {
			continue
		}
		if !f(k, v) {
			break
		}
	}
}

func (m *Map[K, V]) missLocked() {
	m.misses++
	if m.misses < len(m.dirty) {
		return
	}
	m.read.Store(&readOnly[K, V]{m: m.dirty})
	m.dirty = nil
	m.misses = 0
}

func (m *Map[K, V]) dirtyLocked() {
	if m.dirty != nil {
		return
	}

	read := m.loadReadOnly()
	m.dirty = make(map[K]*entry[V], len(read.m))
	for k, e := range read.m {
		if !e.tryExpungeLocked() {
			m.dirty[k] = e
		}
	}
}

func (e *entry[V]) tryExpungeLocked() (isExpunged bool) {
	p := atomic.LoadPointer(&e.p)
	for p == nil {
		if atomic.CompareAndSwapPointer(&e.p, nil, expunged) {
			return true
		}
		p = atomic.LoadPointer(&e.p)
	}
	return p == expunged
}
// This is directly ported from github.com/SaveTheRbtz/generic-sync-map-go
// and does not test the new CompareAndSwap and related functions yet.
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package lindenii_common_test
package cmap_test

import (
	"math/rand"
	"runtime"
	"sync"
	"testing"

	lindenii_common "go.lindenii.runxiyu.org/lindenii-common"
	"go.lindenii.runxiyu.org/lindenii-common/cmap"
)

func TestConcurrentRange(t *testing.T) {
	const mapSize = 1 << 10

	m := new(lindenii_common.Map[int64, int64])
	m := new(cmap.Map[int64, int64])
	for n := int64(1); n <= mapSize; n++ {
		m.Store(n, int64(n))
	}

	done := make(chan struct{})
	var wg sync.WaitGroup
	defer func() {
		close(done)
		wg.Wait()
	}()
	for g := int64(runtime.GOMAXPROCS(0)); g > 0; g-- {
		r := rand.New(rand.NewSource(g))
		wg.Add(1)
		go func(g int64) {
			defer wg.Done()
			for i := int64(0); ; i++ {
				select {
				case <-done:
					return
				default:
				}
				for n := int64(1); n < mapSize; n++ {
					if r.Int63n(mapSize) == 0 {
						m.Store(n, n*i*g)
					} else {
						m.Load(n)
					}
				}
			}
		}(g)
	}

	iters := 1 << 10
	if testing.Short() {
		iters = 16
	}
	for n := iters; n > 0; n-- {
		seen := make(map[int64]bool, mapSize)

		m.Range(func(k, v int64) bool {
			if v%k != 0 {
				t.Fatalf("while Storing multiples of %v, Range saw value %v", k, v)
			}
			if seen[k] {
				t.Fatalf("Range visited key %v twice", k)
			}
			seen[k] = true
			return true
		})

		if len(seen) != mapSize {
			t.Fatalf("Range visited %v elements of %v-element Map", len(seen), mapSize)
		}
	}
}
package scfg

import (
	"bufio"
	"fmt"
	"io"
	"os"
	"strings"
)

// This limits the max block nesting depth to prevent stack overflows.
const maxNestingDepth = 1000

// Load loads a configuration file.
func Load(path string) (Block, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	return Read(f)
}

// Read parses a configuration file from an io.Reader.
func Read(r io.Reader) (Block, error) {
	scanner := bufio.NewScanner(r)

	dec := decoder{scanner: scanner}
	block, closingBrace, err := dec.readBlock()
	if err != nil {
		return nil, err
	} else if closingBrace {
		return nil, fmt.Errorf("line %v: unexpected '}'", dec.lineno)
	}

	return block, scanner.Err()
}

type decoder struct {
	scanner    *bufio.Scanner
	lineno     int
	blockDepth int
}

// readBlock reads a block. closingBrace is true if parsing stopped on '}'
// (otherwise, it stopped on Scanner.Scan).
func (dec *decoder) readBlock() (block Block, closingBrace bool, err error) {
	dec.blockDepth++
	defer func() {
		dec.blockDepth--
	}()

	if dec.blockDepth >= maxNestingDepth {
		return nil, false, fmt.Errorf("exceeded max block depth")
	}

	for dec.scanner.Scan() {
		dec.lineno++

		l := dec.scanner.Text()
		words, err := splitWords(l)
		if err != nil {
			return nil, false, fmt.Errorf("line %v: %v", dec.lineno, err)
		} else if len(words) == 0 {
			continue
		}

		if len(words) == 1 && l[len(l)-1] == '}' {
			closingBrace = true
			break
		}

		var d *Directive
		if words[len(words)-1] == "{" && l[len(l)-1] == '{' {
			words = words[:len(words)-1]

			var name string
			params := words
			if len(words) > 0 {
				name, params = words[0], words[1:]
			}

			startLineno := dec.lineno
			childBlock, childClosingBrace, err := dec.readBlock()
			if err != nil {
				return nil, false, err
			} else if !childClosingBrace {
				return nil, false, fmt.Errorf("line %v: unterminated block", startLineno)
			}

			// Allows callers to tell apart "no block" and "empty block"
			if childBlock == nil {
				childBlock = Block{}
			}

			d = &Directive{Name: name, Params: params, Children: childBlock, lineno: dec.lineno}
		} else {
			d = &Directive{Name: words[0], Params: words[1:], lineno: dec.lineno}
		}
		block = append(block, d)
	}

	return block, closingBrace, nil
}

func splitWords(l string) ([]string, error) {
	var (
		words   []string
		sb      strings.Builder
		escape  bool
		quote   rune
		wantWSP bool
	)
	for _, ch := range l {
		switch {
		case escape:
			sb.WriteRune(ch)
			escape = false
		case wantWSP && (ch != ' ' && ch != '\t'):
			return words, fmt.Errorf("atom not allowed after quoted string")
		case ch == '\\':
			escape = true
		case quote != 0 && ch == quote:
			quote = 0
			wantWSP = true
			if sb.Len() == 0 {
				words = append(words, "")
			}
		case quote == 0 && len(words) == 0 && sb.Len() == 0 && ch == '#':
			return nil, nil
		case quote == 0 && (ch == '\'' || ch == '"'):
			if sb.Len() > 0 {
				return words, fmt.Errorf("quoted string not allowed after atom")
			}
			quote = ch
		case quote == 0 && (ch == ' ' || ch == '\t'):
			if sb.Len() > 0 {
				words = append(words, sb.String())
			}
			sb.Reset()
			wantWSP = false
		default:
			sb.WriteRune(ch)
		}
	}
	if quote != 0 {
		return words, fmt.Errorf("unterminated quoted string")
	}
	if sb.Len() > 0 {
		words = append(words, sb.String())
	}
	return words, nil
}
package scfg

import (
	"fmt"
	"reflect"
	"strings"
	"testing"
)

var readTests = []struct {
	name string
	src  string
	want Block
}{
	{
		name: "flat",
		src: `dir1 param1 param2 "" param3
dir2
dir3 param1

# comment
dir4 "param 1" 'param 2'`,
		want: Block{
			{Name: "dir1", Params: []string{"param1", "param2", "", "param3"}},
			{Name: "dir2", Params: []string{}},
			{Name: "dir3", Params: []string{"param1"}},
			{Name: "dir4", Params: []string{"param 1", "param 2"}},
		},
	},
	{
		name: "simpleBlocks",
		src: `block1 {
	dir1 param1 param2
	dir2 param1
}

block2 {
}

block3 {
	# comment
}

block4 param1 "param2" {
	dir1
}`,
		want: Block{
			{
				Name:   "block1",
				Params: []string{},
				Children: Block{
					{Name: "dir1", Params: []string{"param1", "param2"}},
					{Name: "dir2", Params: []string{"param1"}},
				},
			},
			{Name: "block2", Params: []string{}, Children: Block{}},
			{Name: "block3", Params: []string{}, Children: Block{}},
			{
				Name:   "block4",
				Params: []string{"param1", "param2"},
				Children: Block{
					{Name: "dir1", Params: []string{}},
				},
			},
		},
	},
	{
		name: "nested",
		src: `block1 {
	block2 {
		dir1 param1
	}

	block3 {
	}
}

block4 {
	block5 {
		block6 param1 {
			dir1
		}
	}

	dir1
}`,
		want: Block{
			{
				Name:   "block1",
				Params: []string{},
				Children: Block{
					{
						Name:   "block2",
						Params: []string{},
						Children: Block{
							{Name: "dir1", Params: []string{"param1"}},
						},
					},
					{
						Name:     "block3",
						Params:   []string{},
						Children: Block{},
					},
				},
			},
			{
				Name:   "block4",
				Params: []string{},
				Children: Block{
					{
						Name:   "block5",
						Params: []string{},
						Children: Block{{
							Name:   "block6",
							Params: []string{"param1"},
							Children: Block{{
								Name:   "dir1",
								Params: []string{},
							}},
						}},
					},
					{
						Name:   "dir1",
						Params: []string{},
					},
				},
			},
		},
	},
	{
		name: "quotes",
		src:  `"a \b ' \" c" 'd \e \' " f' a\"b`,
		want: Block{
			{Name: "a b ' \" c", Params: []string{"d e ' \" f", "a\"b"}},
		},
	},
	{
		name: "quotes-2",
		src:  `dir arg1 "arg2" ` + `\"\"`,
		want: Block{
			{Name: "dir", Params: []string{"arg1", "arg2", "\"\""}},
		},
	},
	{
		name: "quotes-3",
		src:  `dir arg1 "\"\"\"\"" arg3`,
		want: Block{
			{Name: "dir", Params: []string{"arg1", `"` + "\"\"" + `"`, "arg3"}},
		},
	},
}

func TestRead(t *testing.T) {
	for _, test := range readTests {
		t.Run(test.name, func(t *testing.T) {
			r := strings.NewReader(test.src)
			got, err := Read(r)
			if err != nil {
				t.Fatalf("Read() = %v", err)
			}
			stripLineno(got)
			if !reflect.DeepEqual(got, test.want) {
				t.Error(fmt.Sprintf("Read() = %#v but want %#v", got, test.want))
			}
		})
	}
}

func stripLineno(block Block) {
	for _, dir := range block {
		dir.lineno = 0
		stripLineno(dir.Children)
	}
}
// Package scfg parses and formats configuration files.
package scfg

import (
	"fmt"
)

// Block is a list of directives.
type Block []*Directive

// GetAll returns a list of directives with the provided name.
func (blk Block) GetAll(name string) []*Directive {
	l := make([]*Directive, 0, len(blk))
	for _, child := range blk {
		if child.Name == name {
			l = append(l, child)
		}
	}
	return l
}

// Get returns the first directive with the provided name.
func (blk Block) Get(name string) *Directive {
	for _, child := range blk {
		if child.Name == name {
			return child
		}
	}
	return nil
}

// Directive is a configuration directive.
type Directive struct {
	Name   string
	Params []string

	Children Block

	lineno int
}

// ParseParams extracts parameters from the directive. It errors out if the
// user hasn't provided enough parameters.
func (d *Directive) ParseParams(params ...*string) error {
	if len(d.Params) < len(params) {
		return fmt.Errorf("directive %q: want %v params, got %v", d.Name, len(params), len(d.Params))
	}
	for i, ptr := range params {
		if ptr == nil {
			continue
		}
		*ptr = d.Params[i]
	}
	return nil
}
package scfg

import (
	"fmt"
	"reflect"
	"strings"
	"sync"
)

// structInfo contains scfg metadata for structs.
type structInfo struct {
	param    int            // index of field storing parameters
	children map[string]int // indices of fields storing child directives
}

var (
	structCacheMutex sync.Mutex
	structCache      = make(map[reflect.Type]*structInfo)
)

func getStructInfo(t reflect.Type) (*structInfo, error) {
	structCacheMutex.Lock()
	defer structCacheMutex.Unlock()

	if info := structCache[t]; info != nil {
		return info, nil
	}

	info := &structInfo{
		param:    -1,
		children: make(map[string]int),
	}

	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.Anonymous {
			return nil, fmt.Errorf("scfg: anonymous struct fields are not supported")
		} else if !f.IsExported() {
			continue
		}

		tag := f.Tag.Get("scfg")
		parts := strings.Split(tag, ",")
		k, options := parts[0], parts[1:]
		if k == "-" {
			continue
		} else if k == "" {
			k = f.Name
		}

		isParam := false
		for _, opt := range options {
			switch opt {
			case "param":
				isParam = true
			default:
				return nil, fmt.Errorf("scfg: invalid option %q in struct tag", opt)
			}
		}

		if isParam {
			if info.param >= 0 {
				return nil, fmt.Errorf("scfg: param option specified multiple times in struct tag in %v", t)
			}
			if parts[0] != "" {
				return nil, fmt.Errorf("scfg: name must be empty when param option is specified in struct tag in %v", t)
			}
			info.param = i
		} else {
			if _, ok := info.children[k]; ok {
				return nil, fmt.Errorf("scfg: key %q specified multiple times in struct tag in %v", k, t)
			}
			info.children[k] = i
		}
	}

	structCache[t] = info
	return info, nil
}
package scfg

import (
	"encoding"
	"fmt"
	"io"
	"reflect"
	"strconv"
)

// Decoder reads and decodes an scfg document from an input stream.
type Decoder struct {
	r io.Reader
}

// NewDecoder returns a new decoder which reads from r.
func NewDecoder(r io.Reader) *Decoder {
	return &Decoder{r}
}

// Decode reads scfg document from the input and stores it in the value pointed
// to by v.
//
// If v is nil or not a pointer, Decode returns an error.
//
// Blocks can be unmarshaled to:
//
//   - Maps. Each directive is unmarshaled into a map entry. The map key must
//     be a string.
//   - Structs. Each directive is unmarshaled into a struct field.
//
// Duplicate directives are not allowed, unless the struct field or map value
// is a slice of values representing a directive: structs or maps.
//
// Directives can be unmarshaled to:
//
//   - Maps. The children block is unmarshaled into the map. Parameters are not
//     allowed.
//   - Structs. The children block is unmarshaled into the struct. Parameters
//     are allowed if one of the struct fields contains the "param" option in
//     its tag.
//   - Slices. Parameters are unmarshaled into the slice. Children blocks are
//     not allowed.
//   - Arrays. Parameters are unmarshaled into the array. The number of
//     parameters must match exactly the length of the array. Children blocks
//     are not allowed.
//   - Strings, booleans, integers, floating-point values, values implementing
//     encoding.TextUnmarshaler. Only a single parameter is allowed and is
//     unmarshaled into the value. Children blocks are not allowed.
//
// The decoding of each struct field can be customized by the format string
// stored under the "scfg" key in the struct field's tag. The tag contains the
// name of the field possibly followed by a comma-separated list of options.
// The name may be empty in order to specify options without overriding the
// default field name. As a special case, if the field name is "-", the field
// is ignored. The "param" option specifies that directive parameters are
// stored in this field (the name must be empty).
func (dec *Decoder) Decode(v interface{}) error {
	block, err := Read(dec.r)
	if err != nil {
		return err
	}

	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return fmt.Errorf("scfg: invalid value for unmarshaling")
	}

	return unmarshalBlock(block, rv)
}

func unmarshalBlock(block Block, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	dirsByName := make(map[string][]*Directive, len(block))
	for _, dir := range block {
		dirsByName[dir.Name] = append(dirsByName[dir.Name], dir)
	}

	switch v.Kind() {
	case reflect.Map:
		if t.Key().Kind() != reflect.String {
			return fmt.Errorf("scfg: map key type must be string")
		}
		if v.IsNil() {
			v.Set(reflect.MakeMap(t))
		} else if v.Len() > 0 {
			clearMap(v)
		}

		for name, dirs := range dirsByName {
			mv := reflect.New(t.Elem()).Elem()
			if err := unmarshalDirectiveList(dirs, mv); err != nil {
				return err
			}
			v.SetMapIndex(reflect.ValueOf(name), mv)
		}
	case reflect.Struct:
		si, err := getStructInfo(t)
		if err != nil {
			return err
		}

		for name, dirs := range dirsByName {
			fieldIndex, ok := si.children[name]
			if !ok {
				return newUnmarshalDirectiveError(dirs[0], "unknown directive")
			}
			fv := v.Field(fieldIndex)
			if err := unmarshalDirectiveList(dirs, fv); err != nil {
				return err
			}
		}
	default:
		return fmt.Errorf("scfg: unsupported type for unmarshaling blocks: %v", t)
	}

	return nil
}

func unmarshalDirectiveList(dirs []*Directive, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	if v.Kind() != reflect.Slice || !isDirectiveType(t.Elem()) {
		if len(dirs) > 1 {
			return newUnmarshalDirectiveError(dirs[1], "directive must not be specified more than once")
		}
		return unmarshalDirective(dirs[0], v)
	}

	sv := reflect.MakeSlice(t, len(dirs), len(dirs))
	for i, dir := range dirs {
		if err := unmarshalDirective(dir, sv.Index(i)); err != nil {
			return err
		}
	}
	v.Set(sv)
	return nil
}

// isDirectiveType checks whether a type can only be unmarshaled as a
// directive, not as a parameter. Accepting too many types here would result in
// ambiguities, see:
// https://lists.sr.ht/~emersion/public-inbox/%3C20230629132458.152205-1-contact%40emersion.fr%3E#%3Ch4Y2peS_YBqY3ar4XlmPDPiNBFpYGns3EBYUx3_6zWEhV2o8_-fBQveRujGADWYhVVCucHBEryFGoPtpC3d3mQ-x10pWnFogfprbQTSvtxc=@emersion.fr%3E
func isDirectiveType(t reflect.Type) bool {
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}

	textUnmarshalerType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
	if reflect.PtrTo(t).Implements(textUnmarshalerType) {
		return false
	}

	switch t.Kind() {
	case reflect.Struct, reflect.Map:
		return true
	default:
		return false
	}
}

func unmarshalDirective(dir *Directive, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	if v.CanAddr() {
		if _, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok {
			if len(dir.Children) != 0 {
				return newUnmarshalDirectiveError(dir, "directive requires zero children")
			}
			return unmarshalParamList(dir, v)
		}
	}

	switch v.Kind() {
	case reflect.Map:
		if len(dir.Params) > 0 {
			return newUnmarshalDirectiveError(dir, "directive requires zero parameters")
		}
		if err := unmarshalBlock(dir.Children, v); err != nil {
			return err
		}
	case reflect.Struct:
		si, err := getStructInfo(t)
		if err != nil {
			return err
		}

		if si.param >= 0 {
			if err := unmarshalParamList(dir, v.Field(si.param)); err != nil {
				return err
			}
		} else {
			if len(dir.Params) > 0 {
				return newUnmarshalDirectiveError(dir, "directive requires zero parameters")
			}
		}

		if err := unmarshalBlock(dir.Children, v); err != nil {
			return err
		}
	default:
		if len(dir.Children) != 0 {
			return newUnmarshalDirectiveError(dir, "directive requires zero children")
		}
		if err := unmarshalParamList(dir, v); err != nil {
			return err
		}
	}
	return nil
}

func unmarshalParamList(dir *Directive, v reflect.Value) error {
	switch v.Kind() {
	case reflect.Slice:
		t := v.Type()
		sv := reflect.MakeSlice(t, len(dir.Params), len(dir.Params))
		for i, param := range dir.Params {
			if err := unmarshalParam(param, sv.Index(i)); err != nil {
				return newUnmarshalParamError(dir, i, err)
			}
		}
		v.Set(sv)
	case reflect.Array:
		if len(dir.Params) != v.Len() {
			return newUnmarshalDirectiveError(dir, fmt.Sprintf("directive requires exactly %v parameters", v.Len()))
		}
		for i, param := range dir.Params {
			if err := unmarshalParam(param, v.Index(i)); err != nil {
				return newUnmarshalParamError(dir, i, err)
			}
		}
	default:
		if len(dir.Params) != 1 {
			return newUnmarshalDirectiveError(dir, "directive requires exactly one parameter")
		}
		if err := unmarshalParam(dir.Params[0], v); err != nil {
			return newUnmarshalParamError(dir, 0, err)
		}
	}

	return nil
}

func unmarshalParam(param string, v reflect.Value) error {
	v = unwrapPointers(v)
	t := v.Type()

	// TODO: improve our logic following:
	// https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/encoding/json/decode.go;drc=b9b8cecbfc72168ca03ad586cc2ed52b0e8db409;l=421
	if v.CanAddr() {
		if v, ok := v.Addr().Interface().(encoding.TextUnmarshaler); ok {
			return v.UnmarshalText([]byte(param))
		}
	}

	switch v.Kind() {
	case reflect.String:
		v.Set(reflect.ValueOf(param))
	case reflect.Bool:
		switch param {
		case "true":
			v.Set(reflect.ValueOf(true))
		case "false":
			v.Set(reflect.ValueOf(false))
		default:
			return fmt.Errorf("invalid bool parameter %q", param)
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		i, err := strconv.ParseInt(param, 10, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid %v parameter: %v", t, err)
		}
		v.Set(reflect.ValueOf(i).Convert(t))
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		u, err := strconv.ParseUint(param, 10, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid %v parameter: %v", t, err)
		}
		v.Set(reflect.ValueOf(u).Convert(t))
	case reflect.Float32, reflect.Float64:
		f, err := strconv.ParseFloat(param, t.Bits())
		if err != nil {
			return fmt.Errorf("invalid %v parameter: %v", t, err)
		}
		v.Set(reflect.ValueOf(f).Convert(t))
	default:
		return fmt.Errorf("unsupported type for unmarshaling parameter: %v", t)
	}

	return nil
}

func unwrapPointers(v reflect.Value) reflect.Value {
	for v.Kind() == reflect.Ptr {
		if v.IsNil() {
			v.Set(reflect.New(v.Type().Elem()))
		}
		v = v.Elem()
	}
	return v
}

func clearMap(v reflect.Value) {
	for _, k := range v.MapKeys() {
		v.SetMapIndex(k, reflect.Value{})
	}
}

type unmarshalDirectiveError struct {
	lineno int
	name   string
	msg    string
}

func newUnmarshalDirectiveError(dir *Directive, msg string) *unmarshalDirectiveError {
	return &unmarshalDirectiveError{
		name:   dir.Name,
		lineno: dir.lineno,
		msg:    msg,
	}
}

func (err *unmarshalDirectiveError) Error() string {
	return fmt.Sprintf("line %v, directive %q: %v", err.lineno, err.name, err.msg)
}

type unmarshalParamError struct {
	lineno     int
	directive  string
	paramIndex int
	err        error
}

func newUnmarshalParamError(dir *Directive, paramIndex int, err error) *unmarshalParamError {
	return &unmarshalParamError{
		directive:  dir.Name,
		lineno:     dir.lineno,
		paramIndex: paramIndex,
		err:        err,
	}
}

func (err *unmarshalParamError) Error() string {
	return fmt.Sprintf("line %v, directive %q, parameter %v: %v", err.lineno, err.directive, err.paramIndex+1, err.err)
}
package scfg_test

import (
	"fmt"
	"log"
	"reflect"
	"strings"
	"testing"

	"go.lindenii.runxiyu.org/lindenii-common/scfg"
)

func ExampleDecoder() {
	var data struct {
		Foo int `scfg:"foo"`
		Bar struct {
			Param string `scfg:",param"`
			Baz   string `scfg:"baz"`
		} `scfg:"bar"`
	}

	raw := `foo 42
bar asdf {
	baz hello
}
`

	r := strings.NewReader(raw)
	if err := scfg.NewDecoder(r).Decode(&data); err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Foo = %v\n", data.Foo)
	fmt.Printf("Bar.Param = %v\n", data.Bar.Param)
	fmt.Printf("Bar.Baz = %v\n", data.Bar.Baz)

	// Output:
	// Foo = 42
	// Bar.Param = asdf
	// Bar.Baz = hello
}

type nestedStructInner struct {
	Bar string `scfg:"bar"`
}

type structParams struct {
	Params []string `scfg:",param"`
	Bar    string
}

type textUnmarshaler struct {
	text string
}

func (tu *textUnmarshaler) UnmarshalText(text []byte) error {
	tu.text = string(text)
	return nil
}

type textUnmarshalerParams struct {
	Params []textUnmarshaler `scfg:",param"`
}

var barStr = "bar"

var unmarshalTests = []struct {
	name string
	raw  string
	want interface{}
}{
	{
		name: "stringMap",
		raw: `hello world
foo bar`,
		want: map[string]string{
			"hello": "world",
			"foo":   "bar",
		},
	},
	{
		name: "simpleStruct",
		raw: `MyString asdf
MyBool true
MyInt -42
MyUint 42
MyFloat 3.14`,
		want: struct {
			MyString string
			MyBool   bool
			MyInt    int
			MyUint   uint
			MyFloat  float32
		}{
			MyString: "asdf",
			MyBool:   true,
			MyInt:    -42,
			MyUint:   42,
			MyFloat:  3.14,
		},
	},
	{
		name: "simpleStructTag",
		raw:  `foo bar`,
		want: struct {
			Foo string `scfg:"foo"`
		}{
			Foo: "bar",
		},
	},
	{
		name: "sliceParams",
		raw:  `Foo a s d f`,
		want: struct {
			Foo []string
		}{
			Foo: []string{"a", "s", "d", "f"},
		},
	},
	{
		name: "arrayParams",
		raw:  `Foo a s d f`,
		want: struct {
			Foo [4]string
		}{
			Foo: [4]string{"a", "s", "d", "f"},
		},
	},
	{
		name: "pointers",
		raw:  `Foo bar`,
		want: struct {
			Foo *string
		}{
			Foo: &barStr,
		},
	},
	{
		name: "nestedMap",
		raw: `foo {
	bar baz
}`,
		want: struct {
			Foo map[string]string `scfg:"foo"`
		}{
			Foo: map[string]string{"bar": "baz"},
		},
	},
	{
		name: "nestedStruct",
		raw: `foo {
	bar baz
}`,
		want: struct {
			Foo nestedStructInner `scfg:"foo"`
		}{
			Foo: nestedStructInner{
				Bar: "baz",
			},
		},
	},
	{
		name: "structParams",
		raw: `Foo param1 param2 {
	Bar baz
}`,
		want: struct {
			Foo structParams
		}{
			Foo: structParams{
				Params: []string{"param1", "param2"},
				Bar:    "baz",
			},
		},
	},
	{
		name: "textUnmarshaler",
		raw: `Foo param1
Bar param2
Baz param3`,
		want: struct {
			Foo []textUnmarshaler
			Bar *textUnmarshaler
			Baz textUnmarshalerParams
		}{
			Foo: []textUnmarshaler{{"param1"}},
			Bar: &textUnmarshaler{"param2"},
			Baz: textUnmarshalerParams{
				Params: []textUnmarshaler{{"param3"}},
			},
		},
	},
	{
		name: "directiveStructSlice",
		raw: `Foo param1 param2 {
	Bar baz
}
Foo param3 param4`,
		want: struct {
			Foo []structParams
		}{
			Foo: []structParams{
				{
					Params: []string{"param1", "param2"},
					Bar:    "baz",
				},
				{
					Params: []string{"param3", "param4"},
				},
			},
		},
	},
	{
		name: "directiveMapSlice",
		raw: `Foo {
	key1 param1
}
Foo {
	key2 param2
}`,
		want: struct {
			Foo []map[string]string
		}{
			Foo: []map[string]string{
				{"key1": "param1"},
				{"key2": "param2"},
			},
		},
	},
}

func TestUnmarshal(t *testing.T) {
	for _, tc := range unmarshalTests {
		tc := tc // capture variable
		t.Run(tc.name, func(t *testing.T) {
			testUnmarshal(t, tc.raw, tc.want)
		})
	}
}

func testUnmarshal(t *testing.T, raw string, want interface{}) {
	out := reflect.New(reflect.TypeOf(want))
	r := strings.NewReader(raw)
	if err := scfg.NewDecoder(r).Decode(out.Interface()); err != nil {
		t.Fatalf("Decode() = %v", err)
	}
	got := out.Elem().Interface()
	if !reflect.DeepEqual(got, want) {
		t.Errorf("Decode() = \n%#v\n but want \n%#v", got, want)
	}
}
package scfg

import (
	"errors"
	"io"
	"strings"
)

var (
	errDirEmptyName = errors.New("scfg: directive with empty name")
)

// Write writes a parsed configuration to the provided io.Writer.
func Write(w io.Writer, blk Block) error {
	enc := newEncoder(w)
	err := enc.encodeBlock(blk)
	return err
}

// encoder write SCFG directives to an output stream.
type encoder struct {
	w   io.Writer
	lvl int
	err error
}

// newEncoder returns a new encoder that writes to w.
func newEncoder(w io.Writer) *encoder {
	return &encoder{w: w}
}

func (enc *encoder) push() {
	enc.lvl++
}

func (enc *encoder) pop() {
	enc.lvl--
}

func (enc *encoder) writeIndent() {
	for i := 0; i < enc.lvl; i++ {
		enc.write([]byte("\t"))
	}
}

func (enc *encoder) write(p []byte) {
	if enc.err != nil {
		return
	}
	_, enc.err = enc.w.Write(p)
}

func (enc *encoder) encodeBlock(blk Block) error {
	for _, dir := range blk {
		enc.encodeDir(*dir)
	}
	return enc.err
}

func (enc *encoder) encodeDir(dir Directive) error {
	if enc.err != nil {
		return enc.err
	}

	if dir.Name == "" {
		enc.err = errDirEmptyName
		return enc.err
	}

	enc.writeIndent()
	enc.write([]byte(maybeQuote(dir.Name)))
	for _, p := range dir.Params {
		enc.write([]byte(" "))
		enc.write([]byte(maybeQuote(p)))
	}

	if len(dir.Children) > 0 {
		enc.write([]byte(" {\n"))
		enc.push()
		enc.encodeBlock(dir.Children)
		enc.pop()

		enc.writeIndent()
		enc.write([]byte("}"))
	}
	enc.write([]byte("\n"))

	return enc.err
}

const specialChars = "\"\\\r\n'{} \t"

func maybeQuote(s string) string {
	if s == "" || strings.ContainsAny(s, specialChars) {
		var sb strings.Builder
		sb.WriteByte('"')
		for _, ch := range s {
			if strings.ContainsRune(`"\`, ch) {
				sb.WriteByte('\\')
			}
			sb.WriteRune(ch)
		}
		sb.WriteByte('"')
		return sb.String()
	}
	return s
}
package scfg

import (
	"bytes"
	"testing"
)

func TestWrite(t *testing.T) {
	for _, tc := range []struct {
		src  Block
		want string
		err  error
	}{
		{
			src:  Block{},
			want: "",
		},
		{
			src: Block{{
				Name: "dir",
				Children: Block{{
					Name:   "blk1",
					Params: []string{"p1", `"p2"`},
					Children: Block{
						{
							Name:   "sub1",
							Params: []string{"arg11", "arg12"},
						},
						{
							Name:   "sub2",
							Params: []string{"arg21", "arg22"},
						},
						{
							Name:   "sub3",
							Params: []string{"arg31", "arg32"},
							Children: Block{
								{
									Name: "sub-sub1",
								},
								{
									Name:   "sub-sub2",
									Params: []string{"arg321", "arg322"},
								},
							},
						},
					},
				}},
			}},
			want: `dir {
	blk1 p1 "\"p2\"" {
		sub1 arg11 arg12
		sub2 arg21 arg22
		sub3 arg31 arg32 {
			sub-sub1
			sub-sub2 arg321 arg322
		}
	}
}
`,
		},
		{
			src:  Block{{Name: "dir1"}},
			want: "dir1\n",
		},
		{
			src:  Block{{Name: "dir\"1"}},
			want: "\"dir\\\"1\"\n",
		},
		{
			src:  Block{{Name: "dir'1"}},
			want: "\"dir'1\"\n",
		},
		{
			src:  Block{{Name: "dir:}"}},
			want: "\"dir:}\"\n",
		},
		{
			src:  Block{{Name: "dir:{"}},
			want: "\"dir:{\"\n",
		},
		{
			src:  Block{{Name: "dir\t1"}},
			want: `"dir` + "\t" + `1"` + "\n",
		},
		{
			src:  Block{{Name: "dir 1"}},
			want: "\"dir 1\"\n",
		},
		{
			src:  Block{{Name: "dir1", Params: []string{"arg1", "arg2", `"arg3"`}}},
			want: "dir1 arg1 arg2 " + `"\"arg3\""` + "\n",
		},
		{
			src:  Block{{Name: "dir1", Params: []string{"arg1", "arg 2", "arg'3"}}},
			want: "dir1 arg1 \"arg 2\" \"arg'3\"\n",
		},
		{
			src:  Block{{Name: "dir1", Params: []string{"arg1", "", "arg3"}}},
			want: "dir1 arg1 \"\" arg3\n",
		},
		{
			src:  Block{{Name: "dir1", Params: []string{"arg1", `"` + "\"\"" + `"`, "arg3"}}},
			want: "dir1 arg1 " + `"\"\"\"\""` + " arg3\n",
		},
		{
			src: Block{{
				Name: "dir1",
				Children: Block{
					{Name: "sub1"},
					{Name: "sub2", Params: []string{"arg1", "arg2"}},
				},
			}},
			want: `dir1 {
	sub1
	sub2 arg1 arg2
}
`,
		},
		{
			src: Block{{Name: ""}},
			err: errDirEmptyName,
		},
		{
			src: Block{{
				Name: "dir",
				Children: Block{
					{Name: "sub1"},
					{Name: "", Children: Block{{Name: "sub21"}}},
				},
			}},
			err: errDirEmptyName,
		},
	} {
		t.Run("", func(t *testing.T) {
			var buf bytes.Buffer
			err := Write(&buf, tc.src)
			switch {
			case err != nil && tc.err != nil:
				if got, want := err.Error(), tc.err.Error(); got != want {
					t.Fatalf("invalid error:\ngot= %q\nwant=%q", got, want)
				}
				return
			case err == nil && tc.err != nil:
				t.Fatalf("got err=nil, want=%q", tc.err.Error())
			case err != nil && tc.err == nil:
				t.Fatalf("could not marshal: %+v", err)
			case err == nil && tc.err == nil:
				// ok.
			}
			if got, want := buf.String(), tc.want; got != want {
				t.Fatalf(
					"invalid marshal representation:\ngot:\n%s\nwant:\n%s\n---",
					got, want,
				)
			}
		})
	}
}