package packet

// This file implements the pushdown automaton (PDA) from PGPainless (Paul Schaub)
// to verify pgp packet sequences. See Paul's blogpost for more details:
// https://blog.jabberhead.tk/2022/10/26/implementing-packet-sequence-validation-using-pushdown-automata/

import (
	"fmt"

	"github.com/ProtonMail/go-crypto/openpgp/errors"
)

// NewErrMalformedMessage builds an errors.ErrMalformedMessage whose text
// contains the numeric values of the given state, input symbol and stack
// symbol.
func NewErrMalformedMessage(from State, input InputSymbol, stackSymbol StackSymbol) errors.ErrMalformedMessage {
	msg := fmt.Sprintf("state %d, input symbol %d, stack symbol %d ", from, input, stackSymbol)
	return errors.ErrMalformedMessage(msg)
}

// InputSymbol is a symbol of the input alphabet of the PDA.
type InputSymbol uint8

// Input alphabet of the PDA. The numeric values appear in error messages
// produced by NewErrMalformedMessage, so the order must stay stable.
const (
	LDSymbol InputSymbol = iota
	SigSymbol
	OPSSymbol
	CompSymbol
	ESKSymbol
	EncSymbol
	EOSSymbol
	UnknownSymbol
)

// StackSymbol is a symbol of the stack alphabet of the PDA.
type StackSymbol int8

// Stack alphabet of the PDA. The numeric values appear in error messages
// produced by NewErrMalformedMessage, so the order must stay stable.
const (
	MsgStackSymbol StackSymbol = iota
	OpsStackSymbol
	KeyStackSymbol
	EndStackSymbol
	EmptyStackSymbol
)

// State is a state of the PDA.
type State int8

// States of the PDA. OpenPGPMessage is the zero value.
const (
	OpenPGPMessage State = iota
	ESKMessage
	LiteralMessage
	CompressedMessage
	EncryptedMessage
	ValidMessage
)

// transition represents a state transition in the PDA. Given the input
// symbol and the symbol popped from the stack, it returns the next state, the
// stack symbols to push, a flag requesting that the same input be processed
// again in the next state, and an error if the transition is not allowed.
type transition func(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error)

// SequenceVerifier is a pushdown automaton to verify
// PGP messages packet sequences according to rfc4880.
type SequenceVerifier struct {
	stack []StackSymbol
	state State
}

// Next performs a state transition with the given input symbol.
// If the transition fails, an ErrMalformedMessage is returned.
func (sv *SequenceVerifier) Next(input InputSymbol) error { for { stackSymbol := sv.popStack() transitionFunc := getTransition(sv.state) nextState, newStackSymbols, redo, err := transitionFunc(input, stackSymbol) if err != nil { return err } if redo { sv.pushStack(stackSymbol) } for _, newStackSymbol := range newStackSymbols { sv.pushStack(newStackSymbol) } sv.state = nextState if !redo { break } } return nil } // Valid returns true if RDA is in a valid state. func (sv *SequenceVerifier) Valid() bool { return sv.state == ValidMessage && len(sv.stack) == 0 } func (sv *SequenceVerifier) AssertValid() error { if !sv.Valid() { return errors.ErrMalformedMessage("invalid message") } return nil } func NewSequenceVerifier() *SequenceVerifier { return &SequenceVerifier{ stack: []StackSymbol{EndStackSymbol, MsgStackSymbol}, state: OpenPGPMessage, } } func (sv *SequenceVerifier) popStack() StackSymbol { if len(sv.stack) == 0 { return EmptyStackSymbol } elemIndex := len(sv.stack) - 1 stackSymbol := sv.stack[elemIndex] sv.stack = sv.stack[:elemIndex] return stackSymbol } func (sv *SequenceVerifier) pushStack(stackSymbol StackSymbol) { sv.stack = append(sv.stack, stackSymbol) } func getTransition(from State) transition { switch from { case OpenPGPMessage: return fromOpenPGPMessage case LiteralMessage: return fromLiteralMessage case CompressedMessage: return fromCompressedMessage case EncryptedMessage: return fromEncryptedMessage case ESKMessage: return fromESKMessage case ValidMessage: return fromValidMessage } return nil } // fromOpenPGPMessage is the transition for the state OpenPGPMessage. 
func fromOpenPGPMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { if stackSymbol != MsgStackSymbol { return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol) } switch input { case LDSymbol: return LiteralMessage, nil, false, nil case SigSymbol: return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, false, nil case OPSSymbol: return OpenPGPMessage, []StackSymbol{OpsStackSymbol, MsgStackSymbol}, false, nil case CompSymbol: return CompressedMessage, nil, false, nil case ESKSymbol: return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil case EncSymbol: return EncryptedMessage, nil, false, nil } return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol) } // fromESKMessage is the transition for the state ESKMessage. func fromESKMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { if stackSymbol != KeyStackSymbol { return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol) } switch input { case ESKSymbol: return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil case EncSymbol: return EncryptedMessage, nil, false, nil } return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol) } // fromLiteralMessage is the transition for the state LiteralMessage. func fromLiteralMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { switch input { case SigSymbol: if stackSymbol == OpsStackSymbol { return LiteralMessage, nil, false, nil } case EOSSymbol: if stackSymbol == EndStackSymbol { return ValidMessage, nil, false, nil } } return 0, nil, false, NewErrMalformedMessage(LiteralMessage, input, stackSymbol) } // fromLiteralMessage is the transition for the state CompressedMessage. 
func fromCompressedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { switch input { case SigSymbol: if stackSymbol == OpsStackSymbol { return CompressedMessage, nil, false, nil } case EOSSymbol: if stackSymbol == EndStackSymbol { return ValidMessage, nil, false, nil } } return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil } // fromEncryptedMessage is the transition for the state EncryptedMessage. func fromEncryptedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { switch input { case SigSymbol: if stackSymbol == OpsStackSymbol { return EncryptedMessage, nil, false, nil } case EOSSymbol: if stackSymbol == EndStackSymbol { return ValidMessage, nil, false, nil } } return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil } // fromValidMessage is the transition for the state ValidMessage. func fromValidMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { return 0, nil, false, NewErrMalformedMessage(ValidMessage, input, stackSymbol) }