whspbrd-final/thrembio/server.go
2026-05-02 22:09:19 +02:00

604 lines
14 KiB
Go

package thrembio
// 10% of AI generated code
// human made; AI details/comments
// FINAL
import (
"WhspBrd/menc"
"WhspBrd/owner"
"WhspBrd/sfudp"
"WhspBrd/thrembio/taskpool"
"WhspBrd/typio/bit"
"bytes"
"crypto/mlkem"
"crypto/sha256"
"encoding/binary"
"errors"
"log"
"net"
"os"
"sync"
"time"
)
// Sentinel errors produced by the server. Match them with errors.Is.
// Client-caused ones are reported through Read as RT_UserError.
var (
	// ErrServerAlreadyOpen is returned by Open when a listener already exists.
	ErrServerAlreadyOpen = errors.New("server already open")
	// ErrMalformedPacket: the request payload does not have the expected layout.
	ErrMalformedPacket = errors.New("user provided malformed packet")
	// ErrNoAssociatedRegisterTokenFound: the register token ID is unknown or already consumed.
	ErrNoAssociatedRegisterTokenFound = errors.New("user provided register token ID that doesn't exist")
	// ErrInvalidRegisterCheckHash: the check hash does not bind the token to the signing identity.
	ErrInvalidRegisterCheckHash = errors.New("user provided invalid register check hash (binds token + identity)")
	// ErrUserAlreadyRegistered: the identity is already registered.
	ErrUserAlreadyRegistered = errors.New("user is already registered")
	// ErrNotRegisteredTriedToLogin: a login was signed by an unregistered identity.
	ErrNotRegisteredTriedToLogin = errors.New("user that is not registered tried to login")
	// ErrReplayAttackSuspected: the DB's anti-replay check flagged the packet.
	ErrReplayAttackSuspected = errors.New("replay attack suspected due to repeated packet")
	// ErrUserNotLoggedIn: a data packet arrived for a user without an active session.
	ErrUserNotLoggedIn = errors.New("user not logged in tried to send data")
)
// ReadType tells the caller of Server.Read what kind of event it received.
type ReadType bit.Bit8

// ReadType values, numbered explicitly.
const (
	// RT_None is the zero value; emitResult never assigns it.
	RT_None ReadType = 0
	// RT_Register: a new user registered.
	RT_Register ReadType = 1
	// RT_Login: a user logged in.
	RT_Login ReadType = 2
	// RT_Data: decrypted data from an authenticated user.
	RT_Data ReadType = 3
	// RT_UserError: a request failed because of the client.
	RT_UserError ReadType = 4
	// RT_InsideError: a request failed because of the server (DB, network, ...).
	RT_InsideError ReadType = 5
)

// serverFlag is a bit set of SF_* options passed to Server.SetFlags.
type serverFlag bit.Bit8

// serverFlag bits. They only suppress events delivered through Read.
const (
	// SF_None disables every option.
	SF_None serverFlag = 0
	// SF_NoUserErrorLog: do not deliver RT_UserError events (invalid register token, invalid login, ...).
	SF_NoUserErrorLog serverFlag = 1 << 0
	// SF_NoRegisterLog: do not deliver successful RT_Register events.
	SF_NoRegisterLog serverFlag = 1 << 1
	// SF_NoLoginLog: do not deliver successful RT_Login events.
	SF_NoLoginLog serverFlag = 1 << 2
)

// has reports whether at least one bit of flag is set in f.
func (f *serverFlag) has(flag serverFlag) bool {
	masked := *f & flag
	return masked != 0
}
// reqPacket is one request that passed serve's header checks. It carries the
// parsed common header plus scratch fields that the per-type handlers
// (handleRegister / handleLogin / handleData) fill while parsing.
//
// fullPacket is a view of a receive buffer borrowed from server.bufPool, and
// handlePacket returns that buffer to the pool when it finishes. header,
// packet and every _got_* slice alias fullPacket, so they are only valid until
// then.
type reqPacket struct {
	from       net.UDPAddr   // The address of the client that sent the request.
	time       uint64        // Header timestamp (bytes 4..12, little-endian); serve compares it with Unix milliseconds.
	reqType    PacketReqType // The type of the request (register, login, data, etc.)
	fullPacket []byte        // The full packet that was sent by the client, including everything.
	header     []byte        // The 13-byte common header (magic, version, timestamp, type).
	packet     []byte        // The packet without the common header (magic, version, timestamp, type), so just the payload-ish.
	// register
	_got_tokenID_B []byte // register token ID, 4 bytes big-endian
	_got_checkHash []byte // SHA-256 binding header, token ID, token and identity
	_got_sign      []byte // signature over everything before it (also used by login)
	_got_tokenID   uint32 // decoded _got_tokenID_B
	_token         []byte // token value returned by ConsumeRegisterToken
	// login
	_got_encapKey []byte // client's ML-KEM-1024 encapsulation key
	// login reuses _got_sign from the register section.
	// data
	_got_user    []byte   // claimed user identity (owner.IdentitySize bytes)
	_got_seq_B   []byte   // sequence number, 4 bytes big-endian
	_got_payload []byte   // AES-GCM ciphertext (everything after the sequence number)
	_got_seq     uint32   // decoded _got_seq_B
	_sesh        [32]byte // the user's active session key from the DB
	// temps: scratch values the handlers reuse for intermediate results
	__tempBool  bool
	__tempBytes []byte
	// for auto parts: read cursor into packet, used by startParts/nextPart
	__offset int
}
// rP_ReadAll, passed as the size argument of nextPart, asks for every
// remaining byte of the packet.
const rP_ReadAll = -1

// startParts rewinds the nextPart cursor to the start of pkt.packet.
func (pkt *reqPacket) startParts() {
	pkt.__offset = 0
}

// nextPart returns the next size bytes of pkt.packet and advances the cursor
// past them.
//
// If size is rP_ReadAll, it returns everything left. That can be an empty but
// non-nil slice.
//
// It returns nil and leaves the cursor where it was when size is negative
// (and not rP_ReadAll), when fewer than size bytes remain, or when the cursor
// is already past the end.
func (pkt *reqPacket) nextPart(size int) []byte {
	from := pkt.__offset
	if from > len(pkt.packet) {
		return nil
	}
	to := from + size
	switch {
	case size == rP_ReadAll:
		to = len(pkt.packet)
	case size < 0, to > len(pkt.packet):
		return nil
	}
	pkt.__offset = to
	return pkt.packet[from:to]
}
// readPacket is one event queued by emitResult for Read. Its fields map
// one-to-one onto Read's return values.
type readPacket struct {
	rt ReadType       // event kind
	u  owner.Identity // user the event concerns; empty for RT_UserError and RT_InsideError
	fa net.UDPAddr    // client address; empty for RT_InsideError
	d  []byte         // decrypted payload for RT_Data, nil otherwise
	e  error          // cause for RT_UserError / RT_InsideError, nil otherwise
}

// Server is a UDP endpoint that handles register, login and encrypted data
// requests from users and lets the owner read events and write data back.
type Server interface {
	// Open binds the UDP socket and starts processing packets.
	// It returns ErrServerAlreadyOpen if the server is already open.
	Open() error
	// Close stops processing and closes the socket. Once the receive loop has
	// exited and queued events are drained, Read returns net.ErrClosed.
	Close()
	// SetFlags replaces the current SF_* flag set. The flags only suppress
	// events delivered through Read.
	SetFlags(flags serverFlag)
	/*
		Read waits for the next event produced by a processed packet.
		Packets that fail the cheap header checks are dropped silently.
		Returns one of:
			RT_Register:
				(RT_Register, newUser, fromAddr, nil, nil)
				Not delivered if SF_NoRegisterLog is set.
			RT_Login:
				(RT_Login, user, fromAddr, nil, nil)
				Not delivered if SF_NoLoginLog is set.
			RT_Data:
				(RT_Data, user, fromAddr, requestData, nil)
			RT_UserError:
				(RT_UserError, emptyUser, fromAddr, nil, error)
				Not delivered if SF_NoUserErrorLog is set. The client still
				receives an error response either way.
			RT_InsideError:
				(RT_InsideError, emptyUser, emptyAddr, nil, error)
				This is also returned with net.ErrClosed once the server has
				shut down and no events remain.
	*/
	Read() (readType ReadType, user owner.Identity, fromAddr net.UDPAddr, data []byte, err error)
	// Write encrypts data with the user's active session key and sends it to
	// the user's last known address, in reply to the user's last request.
	Write(data []byte, to owner.Identity) error
}
// server is the Server implementation built on an sfudp listener.
type server struct {
	port     uint16                    // UDP port Open listens on (all IPv4 interfaces)
	secret   owner.Secret              // signs outgoing responses in rawWrite
	db       ServerDB                  // users, register tokens, sessions, last IP and last request
	flags    serverFlag                // SF_* event-suppression flags; read by handlers without locking
	listener *sfudp.SFUDPConn          // nil before Open and after Close
	debug    bool                      // log user errors; set when WHSPBRD_DEBUG is non-empty
	bufPool  sync.Pool                 // *[]byte receive buffers (16 KiB when created by New)
	taskPool *taskpool.Pool[reqPacket] // runs handlePacket for packets dispatched by serve
	packets  chan readPacket           // events for Read (buffer 1024); closed when serve exits
	done     chan struct{}             // closed by Close to stop serve and unblock emitResult
	doneMu   sync.Once                 // runs Close's shutdown at most once
}
// NewServer prepares a server that will listen on UDP port, sign its
// responses with secret and keep its state in db. It does not open the
// socket; call Open for that. It fails with ErrSecretCantBeNil when secret is
// nil. Debug logging of user errors is enabled when the WHSPBRD_DEBUG
// environment variable is non-empty.
func NewServer(port uint16, secret owner.Secret, db ServerDB) (Server, error) {
	if secret == nil {
		return nil, ErrSecretCantBeNil
	}
	s := new(server)
	s.port = port
	s.secret = secret
	s.db = db
	s.taskPool = taskpool.New[reqPacket](0, 0)
	// Receive buffers are pooled as *[]byte so Get/Put do not allocate a
	// fresh slice header on every packet.
	s.bufPool.New = func() any {
		b := make([]byte, 16384)
		return &b
	}
	s.debug = os.Getenv("WHSPBRD_DEBUG") != ""
	return s, nil
}

// SetFlags replaces the whole flag set; it does not OR into it. Packet
// handlers read s.flags without synchronization, so call this before Open.
func (s *server) SetFlags(flags serverFlag) {
	s.flags = flags
}
// Open binds the UDP socket on all IPv4 interfaces at s.port. It then creates
// the event and shutdown channels, opens the task pool and starts the receive
// loop (serve) in its own goroutine. It returns ErrServerAlreadyOpen if a
// listener already exists, or the listen error unchanged.
//
// NOTE(review): Close guards its body with a sync.Once. A server that is
// reopened after Close therefore cannot be closed a second time.
func (s *server) Open() error {
	if s.listener != nil {
		return ErrServerAlreadyOpen
	}
	bindAddr := &net.UDPAddr{IP: net.IPv4zero, Port: int(s.port)}
	conn, err := sfudp.ListenSFUDP("udp", bindAddr)
	if err != nil {
		return err
	}
	s.listener = conn
	s.done = make(chan struct{})
	s.packets = make(chan readPacket, 1024)
	s.taskPool.Open()
	go s.serve()
	return nil
}
// serve is the receive loop started by Open. It reads datagrams from the
// listener and runs cheap, stateless checks on each one: magic bytes,
// version, freshness window and request type. Packets that pass are
// dispatched to the task pool, where handlePacket processes them.
//
// Each datagram is read into a buffer borrowed from s.bufPool. When a packet
// is dispatched, ownership of that buffer passes to handlePacket, which
// returns it to the pool. When a packet is dropped, serve returns the buffer
// to the pool itself.
//
// The loop ends when s.done is closed or when a read fails (presumably
// because Close closed the listener). On exit it closes s.packets so that
// Read unblocks with net.ErrClosed.
//
// NOTE(review): handlers still running on the task pool may send on
// s.packets after it has been closed here, which panics. Whether that can
// happen depends on taskpool.Pool.Close draining workers — confirm.
func (s *server) serve() {
	defer close(s.packets)
	for {
		select {
		case <-s.done:
			return
		default:
		}
		bufPtr := s.bufPool.Get().(*[]byte)
		// Read into the full capacity. Pooled buffers can come back resliced
		// to an earlier packet's length, and reading into that shorter view
		// would silently truncate larger datagrams.
		buf := (*bufPtr)[:cap(*bufPtr)]
		length, addr, err := s.listener.ReadFromSFUDP(buf)
		if err != nil {
			s.bufPool.Put(bufPtr)
			return
		}
		if length < commonHeaderSize ||
			buf[0] != magicBytes[0] ||
			buf[1] != magicBytes[1] ||
			buf[2] != magicBytes[2] ||
			buf[3] != version {
			s.bufPool.Put(bufPtr)
			continue
		}
		// Freshness window: reject timestamps from the future and anything
		// older than 8 seconds.
		ts := binary.LittleEndian.Uint64(buf[4:12])
		nowMs := uint64(time.Now().UnixMilli())
		if ts > nowMs || nowMs-ts > 8_000 {
			s.bufPool.Put(bufPtr)
			continue
		}
		reqType := PacketReqType(buf[12])
		if reqType >= Rq_Unknown {
			s.bufPool.Put(bufPtr)
			continue
		}
		full := buf[:length]
		s.taskPool.Dispatch(taskpool.Task[reqPacket]{
			Value: reqPacket{
				from:       *addr,
				time:       ts,
				reqType:    reqType,
				fullPacket: full,
				// Cap the header at its own length. Appending to it (for
				// example to build AEAD associated data) then allocates a
				// new array instead of overwriting the payload bytes that
				// follow the header in this buffer.
				header: full[:13:13],
				packet: full[13:],
			},
			Handle: s.handlePacket,
		})
	}
}
// handlePacket runs on a task-pool worker for every packet that serve
// dispatched. It routes the request to its handler, reports the outcome
// through emitResult, and finally returns the packet's receive buffer to
// s.bufPool.
//
// A ping gets a signed ack and is not reported to Read. Unknown request
// types are ignored.
func (s *server) handlePacket(pkt reqPacket) {
	// Return the whole buffer at full capacity. pkt.fullPacket is only the
	// packet-length view, and pooling that short view would make later reads
	// truncate bigger datagrams.
	buf := pkt.fullPacket[:cap(pkt.fullPacket)]
	defer s.bufPool.Put(&buf)

	var (
		user    owner.Identity
		data    []byte
		userErr error
		err     error
	)
	switch pkt.reqType {
	case Rq_Ping:
		// Best effort: a failed ping ack is not reported anywhere.
		_ = s.rawWriteAck(&pkt.from, pkt.fullPacket, true)
		return
	case Rq_Register:
		user, userErr, err = s.handleRegister(pkt)
	case Rq_Login:
		user, userErr, err = s.handleLogin(pkt)
	case Rq_Data:
		user, data, userErr, err = s.handleData(pkt)
	default:
		return
	}
	s.emitResult(pkt, user, data, userErr, err)
}
// emitResult turns a handler's outcome into a Read event and performs the
// follow-up side effects:
//
//   - err (server-side failure): reported as RT_InsideError with an empty
//     user and address. Nothing is sent to the client.
//   - userErr (client-side failure): the client gets a signed error response
//     with userErrPayload(userErr). The error is logged when s.debug is set.
//     It is reported as RT_UserError unless SF_NoUserErrorLog is set.
//   - success: pkt.from is recorded as the user's last IP, which Write needs
//     to reach them. The event is then reported as RT_Register, RT_Login or
//     RT_Data, subject to SF_NoRegisterLog and SF_NoLoginLog.
//
// The send to s.packets blocks until Read takes the event or s.done is closed.
func (s *server) emitResult(
	pkt reqPacket,
	user owner.Identity,
	data []byte,
	userErr error,
	err error,
) {
	var rp readPacket
	switch {
	case err != nil:
		rp = readPacket{RT_InsideError, owner.Identity{}, net.UDPAddr{}, nil, err}
	case userErr != nil:
		if s.debug {
			log.Printf("user error: %v", userErr)
		}
		// Best effort: a failed error reply is not surfaced separately.
		_ = s.rawWriteError(&pkt.from, pkt.fullPacket, userErrPayload(userErr), true)
		if s.flags.has(SF_NoUserErrorLog) {
			return
		}
		rp = readPacket{RT_UserError, owner.Identity{}, pkt.from, nil, userErr}
	default:
		// Record the address before any log-suppression early return. The
		// SF_No*Log flags only silence reporting, and skipping this update
		// would leave Write without a route to a freshly logged-in user.
		s.db.SetLastIp(user, &pkt.from)
		switch pkt.reqType {
		case Rq_Register:
			if s.flags.has(SF_NoRegisterLog) {
				return
			}
			rp = readPacket{RT_Register, user, pkt.from, nil, nil}
		case Rq_Login:
			if s.flags.has(SF_NoLoginLog) {
				return
			}
			rp = readPacket{RT_Login, user, pkt.from, nil, nil}
		case Rq_Data:
			rp = readPacket{RT_Data, user, pkt.from, data, nil}
		}
	}
	select {
	case s.packets <- rp:
	case <-s.done:
	}
}
// Close shuts the server down. In order, it:
//
//   - closes s.done, which stops serve's loop and unblocks emitResult;
//   - closes the listener, which presumably makes serve's pending read fail
//     so the loop exits;
//   - closes the task pool;
//   - clears s.listener.
//
// The body runs at most once.
//
// NOTE(review): the sync.Once is consumed even when the server was never
// opened (listener == nil). Calling Close before Open therefore turns every
// later Close into a no-op.
func (s *server) Close() {
	s.doneMu.Do(func() {
		l := s.listener
		if l == nil {
			return
		}
		close(s.done)
		l.Close()
		s.taskPool.Close()
		s.listener = nil
	})
}
// Read blocks until the next event is available and returns its fields. See
// the Server interface for what each ReadType carries. After the receive loop
// has closed the event channel and it is drained, Read returns RT_InsideError
// with net.ErrClosed.
//
// NOTE(review): s.packets is created in Open. Calling Read before Open
// receives from a nil channel and blocks forever.
func (s *server) Read() (ReadType, owner.Identity, net.UDPAddr, []byte, error) {
	if ev, open := <-s.packets; open {
		return ev.rt, ev.u, ev.fa, ev.d, ev.e
	}
	return RT_InsideError, owner.Identity{}, net.UDPAddr{}, nil, net.ErrClosed
}
// Write sends data to user as an unsigned Rs_Data response.
//
// It needs three pieces of per-user state from the DB:
//
//   - the last known address;
//   - the last accepted request packet, which is passed to
//     menc.AESGCM_Quick_Encrypt as its third argument and hashed into the
//     response header;
//   - the active session key established at login.
//
// If any of these is missing, Write returns a descriptive error. DB,
// encryption and network errors are returned unchanged.
func (s *server) Write(data []byte, user owner.Identity) error {
	addr, found := s.db.GetLastIp(user)
	if !found {
		return errors.New("no known IP for user")
	}
	lastReq, found := s.db.GetLastReq(user)
	if !found {
		return errors.New("no known last request for user")
	}
	sessionKey, found, err := s.db.GetActiveSession(user)
	switch {
	case err != nil:
		return err
	case !found:
		return errors.New("no active session for user")
	}
	sealed, err := menc.AESGCM_Quick_Encrypt(sessionKey[:], data, lastReq)
	if err != nil {
		return err
	}
	return s.rawWriteData(addr, lastReq, sealed, false)
}
// handleRegister processes an Rq_Register request with this payload:
//
//	tokenID (uint32, big-endian) | checkHash (SHA-256) | signature
//
// owner.Verify checks the signature over everything before it and yields the
// new user's identity. checkHash must equal
// SHA-256(header | tokenID | token | identity), where token is the secret
// that the DB stores for tokenID. This binds the one-time register token to
// that identity. On success the user is added and a signed ack is sent.
//
// Client-caused failures are returned in userErr. DB and network failures are
// returned in err.
func (s *server) handleRegister(pkt reqPacket) (newUser owner.Identity, userErr error, err error) {
	pkt.startParts()
	pkt._got_tokenID_B = pkt.nextPart(bit.SizeUint32_B)
	pkt._got_checkHash = pkt.nextPart(bit.SizeSha256_B)
	pkt._got_sign = pkt.nextPart(owner.SignatureSize)
	pkt.__tempBytes = pkt.nextPart(rP_ReadAll)
	if pkt._got_tokenID_B == nil || pkt._got_checkHash == nil || pkt._got_sign == nil || len(pkt.__tempBytes) != 0 {
		userErr = ErrMalformedPacket
		return
	}
	pkt._got_tokenID = binary.BigEndian.Uint32(pkt._got_tokenID_B)

	// --- verify identity signature ---
	// This runs before the token is touched, so unsigned or garbled packets
	// cannot burn a register token.
	newUser, userErr = owner.Verify(
		pkt.fullPacket[:len(pkt.fullPacket)-owner.SignatureSize],
		pkt._got_sign,
	)
	if userErr != nil {
		return
	}

	// --- ATOMIC token consumption ---
	// NOTE(review): anyone can self-sign, so a request with a guessed token
	// ID and a wrong checkHash still consumes the token. A non-consuming
	// lookup, followed by consuming only after the hash check, would close
	// that gap.
	pkt._token, pkt.__tempBool, err = s.db.ConsumeRegisterToken(pkt._got_tokenID)
	if err != nil {
		return
	}
	if !pkt.__tempBool {
		userErr = ErrNoAssociatedRegisterTokenFound
		return
	}

	// --- check hash binds token + identity ---
	sha := sha256.New()
	sha.Write(pkt.header)
	sha.Write(pkt._got_tokenID_B)
	sha.Write(pkt._token)
	sha.Write(newUser[:])
	if !bytes.Equal(sha.Sum(nil), pkt._got_checkHash) {
		userErr = ErrInvalidRegisterCheckHash
		return
	}

	// --- ATOMIC register user ---
	alreadyExists, err := s.db.AddRegisteredUserAtomic(newUser)
	if err != nil {
		return
	}
	if alreadyExists {
		userErr = ErrUserAlreadyRegistered
		return
	}
	err = s.rawWriteAck(&pkt.from, pkt.fullPacket, true)
	return
}
// handleLogin processes an Rq_Login request with this payload:
//
//	ML-KEM-1024 encapsulation key | signature
//
// The steps are:
//
//  1. The signature, over everything before it, identifies the user.
//  2. The user must be registered.
//  3. The packet must pass the DB's anti-replay check.
//  4. The server encapsulates a fresh shared secret to the client's key and
//     stores SHA-256(shared) as the user's active session key.
//  5. It replies with the ML-KEM ciphertext in a signed Rs_Data response, so
//     the client can decapsulate the same secret.
//
// Client-caused failures are returned in userErr. DB and network failures are
// returned in err.
func (s *server) handleLogin(pkt reqPacket) (user owner.Identity, userErr error, err error) {
	pkt.startParts()
	pkt._got_encapKey = pkt.nextPart(mlkem.EncapsulationKeySize1024)
	pkt._got_sign = pkt.nextPart(owner.SignatureSize)
	pkt.__tempBytes = pkt.nextPart(rP_ReadAll)
	if pkt._got_encapKey == nil || pkt._got_sign == nil || len(pkt.__tempBytes) != 0 {
		userErr = ErrMalformedPacket
		return
	}

	// --- verify user signature ---
	user, userErr = owner.Verify(
		pkt.fullPacket[:len(pkt.fullPacket)-owner.SignatureSize],
		pkt._got_sign,
	)
	if userErr != nil {
		return
	}

	// --- check registered user ---
	pkt.__tempBool, err = s.db.HasRegisteredUser(user)
	if err != nil {
		return
	}
	if !pkt.__tempBool {
		// NOTE(review): if this write succeeds, emitResult also answers the
		// returned userErr with its own error packet, so the client receives
		// two error responses.
		err = s.rawWriteError(&pkt.from, pkt.fullPacket, notRegisteredB, true)
		userErr = ErrNotRegisteredTriedToLogin
		return
	}

	// --- ATOMIC anti replay check ---
	// The DB keeps this packet as the user's last request. pkt.fullPacket
	// lives in a pooled receive buffer that is recycled once handlePacket
	// returns, so the DB gets its own copy.
	pkt.__tempBool, err = s.db.CheckAndSetLastReq(user, bytes.Clone(pkt.fullPacket))
	if err != nil {
		return
	}
	if pkt.__tempBool {
		userErr = ErrReplayAttackSuspected
		return
	}

	// --- derive the session key via ML-KEM-1024 ---
	encKey, keyErr := mlkem.NewEncapsulationKey1024(pkt._got_encapKey)
	if keyErr != nil {
		userErr = keyErr
		return
	}
	shared, ciphertext := encKey.Encapsulate()

	// --- set active session ---
	err = s.db.SetActiveSession(user, sha256.Sum256(shared))
	if err != nil {
		return
	}
	err = s.rawWriteData(&pkt.from, pkt.fullPacket, ciphertext, true)
	return
}
// handleData processes an Rq_Data request with this payload:
//
//	user identity | seq (uint32, big-endian) | AES-GCM ciphertext
//
// The claimed identity must have an active session. The ciphertext is
// decrypted with that session key, with header | seq passed as the third
// argument to menc.AESGCM_Quick_Decrypt (presumably the associated data —
// confirm). A packet that decrypts can only come from the holder of the
// session key, and that is what authenticates the claimed user. After the
// anti-replay check the request is acked with a signed response, and the
// plaintext is returned in data together with the user.
//
// Client-caused failures are returned in userErr. DB and network failures are
// returned in err.
func (s *server) handleData(pkt reqPacket) (user owner.Identity, data []byte, userErr error, err error) {
	pkt.startParts()
	pkt._got_user = pkt.nextPart(owner.IdentitySize)
	pkt._got_seq_B = pkt.nextPart(bit.SizeUint32_B)
	pkt._got_payload = pkt.nextPart(rP_ReadAll)
	if pkt._got_user == nil || pkt._got_seq_B == nil || pkt._got_payload == nil {
		userErr = ErrMalformedPacket
		return
	}
	pkt._got_seq = binary.BigEndian.Uint32(pkt._got_seq_B)
	// Set the user result. It was previously left empty, so the replay
	// check, the last-IP update and the RT_Data event all used the zero
	// identity.
	user = owner.Identity(pkt._got_user)

	// --- check logged in user ---
	pkt._sesh, pkt.__tempBool, err = s.db.GetActiveSession(user)
	if err != nil {
		return
	}
	if !pkt.__tempBool {
		// NOTE(review): if this write succeeds, emitResult also answers the
		// returned userErr with its own error packet, so the client receives
		// two error responses.
		err = s.rawWriteError(&pkt.from, pkt.fullPacket, notLoggedB, true)
		userErr = ErrUserNotLoggedIn
		return
	}

	// Build header | seq in a fresh slice. pkt.header may have the rest of
	// the receive buffer as spare capacity, and appending to it directly
	// would overwrite the identity bytes of pkt.fullPacket. That packet is
	// later stored as the last request and hashed into the ack.
	aad := make([]byte, 0, len(pkt.header)+len(pkt._got_seq_B))
	aad = append(aad, pkt.header...)
	aad = append(aad, pkt._got_seq_B...)
	// NOTE(review): if menc decrypts in place, plain aliases the pooled
	// receive buffer, which is reused after handlePacket returns — confirm.
	plain, decErr := menc.AESGCM_Quick_Decrypt(pkt._sesh[:], pkt._got_payload, aad)
	if decErr != nil {
		userErr = decErr
		return
	}

	// --- ATOMIC anti replay check ---
	// Give the DB its own copy: pkt.fullPacket is recycled via the buffer pool.
	pkt.__tempBool, err = s.db.CheckAndSetLastReq(user, bytes.Clone(pkt.fullPacket))
	if err != nil {
		return
	}
	if pkt.__tempBool {
		userErr = ErrReplayAttackSuspected
		return
	}
	data = plain
	err = s.rawWriteAck(&pkt.from, pkt.fullPacket, true)
	return
}
// rawWriteAck answers the request tohash with an Rs_Ack response whose
// payload is the shared `none` value.
func (s *server) rawWriteAck(to *net.UDPAddr, tohash []byte, signed bool) error {
	empty := none
	return s.rawWrite(to, Rs_Ack, tohash, empty, signed)
}

// rawWriteData answers the request tohash with an Rs_Data response carrying data.
func (s *server) rawWriteData(to *net.UDPAddr, tohash []byte, data []byte, signed bool) error {
	return s.rawWrite(
		to,
		Rs_Data,
		tohash,
		data,
		signed,
	)
}

// rawWriteError answers the request tohash with an Rs_Error response. err is
// the error payload sent to the client, not a Go error.
func (s *server) rawWriteError(to *net.UDPAddr, tohash []byte, err []byte, signed bool) error {
	return s.rawWrite(
		to,
		Rs_Error,
		tohash,
		err,
		signed,
	)
}
// rawWrite sends one response datagram to `to`:
//
//	magicWithVersion | tp | SHA-256(tohash) | payload | [signature]
//
// tohash is the request being answered. Its digest presumably lets the client
// match the response to that request. When signed is true, s.secret signs
// SHA-256(tohash) | payload and the signature is appended.
//
// It returns net.ErrClosed when the server has no listener (not opened, or
// closed), the signing error if signing fails, or the listener's write error.
func (s *server) rawWrite(to *net.UDPAddr, tp PacketResType, tohash []byte, payload []byte, signed bool) error {
	listener := s.listener
	if listener == nil {
		return net.ErrClosed
	}
	digest := sha256.Sum256(tohash)
	body := make([]byte, 0, len(digest)+len(payload))
	body = append(body, digest[:]...)
	body = append(body, payload...)
	if signed {
		signature, err := s.secret.Sign(body)
		if err != nil {
			// Return the failure instead of panicking: a panic here would
			// crash the whole process from a task-pool worker.
			return err
		}
		body = append(body, signature...)
	}
	// Assemble the datagram in a fresh slice instead of appending onto the
	// package-level magicWithVersion. If that slice ever had spare capacity,
	// concurrent handlers would write into its shared backing array.
	out := make([]byte, 0, len(magicWithVersion)+1+len(body))
	out = append(out, magicWithVersion...)
	out = append(out, byte(tp))
	out = append(out, body...)
	_, err := listener.WriteToSFUDP(out, to)
	return err
}
// userErrPayload maps a client-caused error to the short ASCII code sent back
// in an Rs_Error response. Candidates are tried in order with errors.Is, so
// wrapped sentinels match too. Anything unrecognised, including nil, maps to
// "user_error". A new slice is returned on every call.
func userErrPayload(err error) []byte {
	codes := [...]struct {
		target error
		code   string
	}{
		{ErrMalformedPacket, "malformed"},
		{ErrNoAssociatedRegisterTokenFound, "no_token"},
		{ErrInvalidRegisterCheckHash, "invalid_check"},
		{ErrUserAlreadyRegistered, "already_registered"},
		{ErrNotRegisteredTriedToLogin, "not_registered"},
		{ErrUserNotLoggedIn, "not_logged"},
		{ErrReplayAttackSuspected, "replay"},
	}
	for _, c := range codes {
		if errors.Is(err, c.target) {
			return []byte(c.code)
		}
	}
	return []byte("user_error")
}