604 lines
14 KiB
Go
604 lines
14 KiB
Go
package thrembio
|
|
|
|
// Roughly 10% of this file is AI-generated code.

// Human-made; AI contributed details and comments.

// FINAL
|
|
|
|
import (
|
|
"WhspBrd/menc"
|
|
"WhspBrd/owner"
|
|
"WhspBrd/sfudp"
|
|
"WhspBrd/thrembio/taskpool"
|
|
"WhspBrd/typio/bit"
|
|
"bytes"
|
|
"crypto/mlkem"
|
|
"crypto/sha256"
|
|
"encoding/binary"
|
|
"errors"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Server lifecycle errors.
var ErrServerAlreadyOpen = errors.New("server already open")

// Client-caused errors. These are reported as RT_UserError events and
// mapped to short wire codes by userErrPayload.
var (
	// ErrMalformedPacket: the payload did not match the expected layout.
	ErrMalformedPacket = errors.New("user provided malformed packet")
	// ErrNoAssociatedRegisterTokenFound: unknown or already-consumed token ID.
	ErrNoAssociatedRegisterTokenFound = errors.New("user provided register token ID that doesn't exist")
	// ErrInvalidRegisterCheckHash: the check hash did not bind token and identity.
	ErrInvalidRegisterCheckHash = errors.New("user provided invalid register check hash (binds token + identity)")
	// ErrUserAlreadyRegistered: the identity is already registered.
	ErrUserAlreadyRegistered = errors.New("user is already registered")
	// ErrNotRegisteredTriedToLogin: login attempt by an unregistered identity.
	ErrNotRegisteredTriedToLogin = errors.New("user that is not registered tried to login")
	// ErrReplayAttackSuspected: the packet repeated the user's last request.
	ErrReplayAttackSuspected = errors.New("replay attack suspected due to repeated packet")
	// ErrUserNotLoggedIn: data sent by an identity without an active session.
	ErrUserNotLoggedIn = errors.New("user not logged in tried to send data")
)
|
|
|
|
type ReadType bit.Bit8
|
|
|
|
const (
|
|
RT_None ReadType = iota
|
|
// New user registered
|
|
RT_Register
|
|
// User logged in
|
|
RT_Login
|
|
// Data from authenticated user
|
|
RT_Data
|
|
// Client-related error
|
|
RT_UserError
|
|
// Internal server error
|
|
RT_InsideError
|
|
)
|
|
|
|
type serverFlag bit.Bit8
|
|
|
|
const (
|
|
SF_None serverFlag = 0
|
|
// Don't log user errors (like invalid register token, invalid login, etc.)
|
|
SF_NoUserErrorLog serverFlag = 1 << 0
|
|
// Don't log register events (successful ones)
|
|
SF_NoRegisterLog serverFlag = 1 << 1
|
|
// Don't log login events (successful ones)
|
|
SF_NoLoginLog serverFlag = 1 << 2
|
|
)
|
|
|
|
func (f *serverFlag) has(flag serverFlag) bool {
|
|
return (*f)&flag != 0
|
|
}
|
|
|
|
type reqPacket struct {
|
|
from net.UDPAddr // The address of the client that sent the request.
|
|
|
|
time uint64
|
|
reqType PacketReqType // The type of the request (register, login, data, etc.)
|
|
|
|
fullPacket []byte // The full packet that was sent by the client, including everything.
|
|
header []byte
|
|
packet []byte // The packet without the common header (magic, version, timestamp, type), so just the payload-ish.
|
|
|
|
// register
|
|
_got_tokenID_B []byte
|
|
_got_checkHash []byte
|
|
_got_sign []byte
|
|
_got_tokenID uint32
|
|
_token []byte
|
|
|
|
// login
|
|
_got_encapKey []byte
|
|
//_got_sign []byte
|
|
|
|
// data
|
|
_got_user []byte
|
|
_got_seq_B []byte
|
|
_got_payload []byte
|
|
_got_seq uint32
|
|
_sesh [32]byte
|
|
|
|
// temps
|
|
__tempBool bool
|
|
__tempBytes []byte
|
|
|
|
// for auto parts
|
|
__offset int
|
|
}
|
|
|
|
const rP_ReadAll = -1
|
|
|
|
func (pkt *reqPacket) startParts() {
|
|
pkt.__offset = 0
|
|
}
|
|
|
|
func (pkt *reqPacket) nextPart(size int) []byte {
|
|
start := pkt.__offset
|
|
if start > len(pkt.packet) {
|
|
return nil
|
|
}
|
|
|
|
var end int
|
|
switch {
|
|
case size == rP_ReadAll:
|
|
end = len(pkt.packet)
|
|
case size >= 0 && start+size <= len(pkt.packet):
|
|
end = start + size
|
|
default:
|
|
return nil
|
|
}
|
|
|
|
pkt.__offset = end
|
|
return pkt.packet[start:end]
|
|
}
|
|
|
|
type readPacket struct {
|
|
rt ReadType
|
|
u owner.Identity
|
|
fa net.UDPAddr
|
|
d []byte
|
|
e error
|
|
}
|
|
|
|
type Server interface {
|
|
Open() error
|
|
Close()
|
|
SetFlags(flags serverFlag)
|
|
|
|
/*
|
|
Read waits for incoming packets until a valid one is processed.
|
|
Invalid packets are ignored.
|
|
|
|
Returns one of:
|
|
RT_Register:
|
|
(RT_Register, newUser, fromAddr, nil, nil)
|
|
Ignored if SF_NoRegisterLog is enabled.
|
|
|
|
RT_Login:
|
|
(RT_Login, user, fromAddr, nil, nil)
|
|
Ignored if SF_NoLoginLog is enabled.
|
|
|
|
RT_Data:
|
|
(RT_Data, user, fromAddr, requestData, nil)
|
|
|
|
RT_UserError:
|
|
(RT_UserError, ?user, fromAddr, nil, error)
|
|
Ignored if SF_NoUserError is enabled.
|
|
|
|
RT_InsideError:
|
|
(RT_InsideError, emptyUser, emptyAddr, nil, error)
|
|
*/
|
|
Read() (readType ReadType, user owner.Identity, fromAddr net.UDPAddr, data []byte, err error)
|
|
Write(data []byte, to owner.Identity) error
|
|
}
|
|
|
|
type server struct {
|
|
port uint16
|
|
secret owner.Secret
|
|
db ServerDB
|
|
|
|
flags serverFlag
|
|
listener *sfudp.SFUDPConn
|
|
debug bool
|
|
|
|
bufPool sync.Pool
|
|
taskPool *taskpool.Pool[reqPacket]
|
|
|
|
packets chan readPacket
|
|
|
|
done chan struct{}
|
|
doneMu sync.Once
|
|
}
|
|
|
|
func NewServer(port uint16, secret owner.Secret, db ServerDB) (Server, error) {
|
|
if secret == nil {
|
|
return nil, ErrSecretCantBeNil
|
|
}
|
|
|
|
srv := &server{
|
|
port: port,
|
|
secret: secret,
|
|
db: db,
|
|
listener: nil,
|
|
taskPool: taskpool.New[reqPacket](0, 0),
|
|
bufPool: sync.Pool{
|
|
New: func() any {
|
|
b := make([]byte, 16384)
|
|
return &b
|
|
},
|
|
},
|
|
debug: os.Getenv("WHSPBRD_DEBUG") != "",
|
|
}
|
|
|
|
return srv, nil
|
|
}
|
|
|
|
func (s *server) SetFlags(flags serverFlag) {
|
|
s.flags = flags
|
|
}
|
|
|
|
func (s *server) Open() error {
|
|
if s.listener != nil {
|
|
return ErrServerAlreadyOpen
|
|
}
|
|
|
|
l, err := sfudp.ListenSFUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: int(s.port)})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.listener = l
|
|
s.packets = make(chan readPacket, 1024)
|
|
s.done = make(chan struct{})
|
|
|
|
s.taskPool.Open()
|
|
|
|
go s.serve()
|
|
return nil
|
|
}
|
|
|
|
func (s *server) serve() {
|
|
defer close(s.packets)
|
|
|
|
for {
|
|
select {
|
|
case <-s.done:
|
|
return
|
|
default:
|
|
}
|
|
|
|
bufPtr := s.bufPool.Get().(*[]byte)
|
|
buf := *bufPtr
|
|
|
|
length, addr, err := s.listener.ReadFromSFUDP(buf[:])
|
|
if err != nil {
|
|
s.bufPool.Put(bufPtr)
|
|
return
|
|
}
|
|
|
|
if length < commonHeaderSize ||
|
|
buf[0] != magicBytes[0] ||
|
|
buf[1] != magicBytes[1] ||
|
|
buf[2] != magicBytes[2] ||
|
|
buf[3] != version {
|
|
continue
|
|
}
|
|
|
|
ts := binary.LittleEndian.Uint64(buf[4:12])
|
|
now := time.Now()
|
|
nowMs := uint64(now.UnixMilli())
|
|
if ts > nowMs || nowMs-ts > 8_000 {
|
|
continue
|
|
}
|
|
|
|
reqType := PacketReqType(buf[12])
|
|
if reqType >= Rq_Unknown {
|
|
continue
|
|
}
|
|
|
|
pkt := reqPacket{
|
|
from: *addr,
|
|
time: ts,
|
|
reqType: reqType,
|
|
fullPacket: buf[:length],
|
|
}
|
|
pkt.header = pkt.fullPacket[:13]
|
|
pkt.packet = pkt.fullPacket[13:]
|
|
|
|
s.taskPool.Dispatch(taskpool.Task[reqPacket]{
|
|
Value: pkt,
|
|
Handle: s.handlePacket,
|
|
})
|
|
}
|
|
}
|
|
|
|
func (s *server) handlePacket(pkt reqPacket) {
|
|
defer s.bufPool.Put(&pkt.fullPacket)
|
|
|
|
var (
|
|
user owner.Identity
|
|
data []byte
|
|
userErr error
|
|
err error
|
|
)
|
|
|
|
switch pkt.reqType {
|
|
|
|
case Rq_Ping:
|
|
s.rawWriteAck(&pkt.from, pkt.fullPacket, true)
|
|
return
|
|
case Rq_Register:
|
|
user, userErr, err = s.handleRegister(pkt)
|
|
case Rq_Login:
|
|
user, userErr, err = s.handleLogin(pkt)
|
|
case Rq_Data:
|
|
user, data, userErr, err = s.handleData(pkt)
|
|
default:
|
|
return
|
|
}
|
|
|
|
s.emitResult(pkt, user, data, userErr, err)
|
|
}
|
|
|
|
func (s *server) emitResult(
|
|
pkt reqPacket,
|
|
user owner.Identity,
|
|
data []byte,
|
|
userErr error,
|
|
err error,
|
|
) {
|
|
var rp readPacket
|
|
|
|
if err != nil {
|
|
rp = readPacket{RT_InsideError, owner.Identity{}, net.UDPAddr{}, nil, err}
|
|
} else if userErr != nil {
|
|
if s.debug {
|
|
log.Printf("user error: %v", userErr)
|
|
}
|
|
if err == nil {
|
|
_ = s.rawWriteError(&pkt.from, pkt.fullPacket, userErrPayload(userErr), true)
|
|
}
|
|
if s.flags.has(SF_NoUserErrorLog) {
|
|
return
|
|
}
|
|
rp = readPacket{RT_UserError, owner.Identity{}, pkt.from, nil, userErr}
|
|
} else {
|
|
switch pkt.reqType {
|
|
case Rq_Register:
|
|
if s.flags.has(SF_NoRegisterLog) {
|
|
return
|
|
}
|
|
rp = readPacket{RT_Register, user, pkt.from, nil, nil}
|
|
case Rq_Login:
|
|
if s.flags.has(SF_NoLoginLog) {
|
|
return
|
|
}
|
|
rp = readPacket{RT_Login, user, pkt.from, nil, nil}
|
|
case Rq_Data:
|
|
rp = readPacket{RT_Data, user, pkt.from, data, nil}
|
|
}
|
|
s.db.SetLastIp(user, &pkt.from)
|
|
}
|
|
|
|
select {
|
|
case s.packets <- rp:
|
|
case <-s.done:
|
|
}
|
|
}
|
|
|
|
func (s *server) Close() {
|
|
s.doneMu.Do(func() {
|
|
if s.listener == nil {
|
|
return
|
|
}
|
|
close(s.done)
|
|
s.listener.Close()
|
|
s.taskPool.Close()
|
|
s.listener = nil
|
|
})
|
|
}
|
|
|
|
func (s *server) Read() (ReadType, owner.Identity, net.UDPAddr, []byte, error) {
|
|
p, ok := <-s.packets
|
|
if !ok {
|
|
return RT_InsideError, owner.Identity{}, net.UDPAddr{}, nil, net.ErrClosed
|
|
}
|
|
return p.rt, p.u, p.fa, p.d, p.e
|
|
}
|
|
|
|
func (s *server) Write(data []byte, user owner.Identity) error {
|
|
ip, ex := s.db.GetLastIp(user)
|
|
if !ex {
|
|
return errors.New("no known IP for user")
|
|
}
|
|
req, ex := s.db.GetLastReq(user)
|
|
if !ex {
|
|
return errors.New("no known last request for user")
|
|
}
|
|
key, ex, err := s.db.GetActiveSession(user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !ex {
|
|
return errors.New("no active session for user")
|
|
}
|
|
ciphertext, err := menc.AESGCM_Quick_Encrypt(key[:], data, req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.rawWriteData(ip, req, ciphertext, false)
|
|
}
|
|
|
|
func (s *server) handleRegister(pkt reqPacket) (newUser owner.Identity, userErr error, err error) {
|
|
pkt.startParts()
|
|
pkt._got_tokenID_B = pkt.nextPart(bit.SizeUint32_B)
|
|
pkt._got_checkHash = pkt.nextPart(bit.SizeSha256_B)
|
|
pkt._got_sign = pkt.nextPart(owner.SignatureSize)
|
|
pkt.__tempBytes = pkt.nextPart(rP_ReadAll)
|
|
if pkt._got_tokenID_B == nil || pkt._got_checkHash == nil || pkt._got_sign == nil || len(pkt.__tempBytes) != 0 {
|
|
userErr = ErrMalformedPacket
|
|
return
|
|
}
|
|
pkt._got_tokenID = binary.BigEndian.Uint32(pkt._got_tokenID_B)
|
|
|
|
// --- ATOMIC token consumption ---
|
|
pkt._token, pkt.__tempBool, err = s.db.ConsumeRegisterToken(pkt._got_tokenID)
|
|
if err != nil {
|
|
return
|
|
} else if !pkt.__tempBool {
|
|
userErr = ErrNoAssociatedRegisterTokenFound
|
|
return
|
|
}
|
|
|
|
// --- verify identity signature ---
|
|
newUser, userErr = owner.Verify(
|
|
pkt.fullPacket[:len(pkt.fullPacket)-owner.SignatureSize],
|
|
pkt._got_sign,
|
|
)
|
|
if userErr != nil {
|
|
return
|
|
}
|
|
|
|
// --- check hash binds token + identity ---
|
|
sha := sha256.New()
|
|
sha.Write(pkt.header)
|
|
sha.Write(pkt._got_tokenID_B)
|
|
sha.Write(pkt._token)
|
|
sha.Write(newUser[:])
|
|
if !bytes.Equal(sha.Sum(nil), pkt._got_checkHash) {
|
|
userErr = ErrInvalidRegisterCheckHash
|
|
return
|
|
}
|
|
|
|
// --- ATOMIC register user ---
|
|
alreadyExists, err := s.db.AddRegisteredUserAtomic(newUser)
|
|
if err != nil {
|
|
return
|
|
} else if alreadyExists {
|
|
userErr = ErrUserAlreadyRegistered
|
|
return
|
|
}
|
|
|
|
err = s.rawWriteAck(&pkt.from, pkt.fullPacket, true)
|
|
return
|
|
}
|
|
|
|
func (s *server) handleLogin(pkt reqPacket) (user owner.Identity, userErr error, err error) {
|
|
pkt.startParts()
|
|
pkt._got_encapKey = pkt.nextPart(mlkem.EncapsulationKeySize1024)
|
|
pkt._got_sign = pkt.nextPart(owner.SignatureSize)
|
|
pkt.__tempBytes = pkt.nextPart(rP_ReadAll)
|
|
|
|
if pkt._got_encapKey == nil || pkt._got_sign == nil || len(pkt.__tempBytes) != 0 {
|
|
userErr = ErrMalformedPacket
|
|
return
|
|
}
|
|
|
|
// --- verify user signature ---
|
|
user, userErr = owner.Verify(
|
|
pkt.fullPacket[:len(pkt.fullPacket)-owner.SignatureSize],
|
|
pkt._got_sign,
|
|
)
|
|
if userErr != nil {
|
|
return
|
|
}
|
|
|
|
// --- check registered user ---
|
|
pkt.__tempBool, err = s.db.HasRegisteredUser(user)
|
|
if err != nil {
|
|
return
|
|
} else if !pkt.__tempBool {
|
|
err = s.rawWriteError(&pkt.from, pkt.fullPacket, notRegisteredB, true)
|
|
userErr = ErrNotRegisteredTriedToLogin
|
|
return
|
|
}
|
|
|
|
// --- ATOMIC anti replay check ---
|
|
pkt.__tempBool, err = s.db.CheckAndSetLastReq(user, pkt.fullPacket)
|
|
if err != nil {
|
|
return
|
|
} else if pkt.__tempBool {
|
|
userErr = ErrReplayAttackSuspected
|
|
return
|
|
}
|
|
|
|
// --- generate encapsulation key ---
|
|
encKey, userErr := mlkem.NewEncapsulationKey1024(pkt._got_encapKey)
|
|
if userErr != nil {
|
|
return
|
|
}
|
|
shared, ciphertext := encKey.Encapsulate()
|
|
|
|
// --- set active session ---
|
|
err = s.db.SetActiveSession(user, sha256.Sum256(shared))
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
err = s.rawWriteData(&pkt.from, pkt.fullPacket, ciphertext, true)
|
|
return
|
|
}
|
|
|
|
func (s *server) handleData(pkt reqPacket) (user owner.Identity, data []byte, userErr error, err error) {
|
|
pkt.startParts()
|
|
pkt._got_user = pkt.nextPart(owner.IdentitySize)
|
|
pkt._got_seq_B = pkt.nextPart(bit.SizeUint32_B)
|
|
pkt._got_payload = pkt.nextPart(rP_ReadAll)
|
|
if pkt._got_user == nil || pkt._got_seq_B == nil || pkt._got_payload == nil {
|
|
userErr = ErrMalformedPacket
|
|
return
|
|
}
|
|
pkt._got_seq = binary.BigEndian.Uint32(pkt._got_seq_B)
|
|
|
|
// --- check logged in user ---
|
|
pkt._sesh, pkt.__tempBool, err = s.db.GetActiveSession(owner.Identity(pkt._got_user))
|
|
if err != nil {
|
|
return
|
|
} else if !pkt.__tempBool {
|
|
err = s.rawWriteError(&pkt.from, pkt.fullPacket, notLoggedB, true)
|
|
userErr = ErrUserNotLoggedIn
|
|
return
|
|
}
|
|
|
|
data, userErr = menc.AESGCM_Quick_Decrypt(pkt._sesh[:], pkt._got_payload, append(pkt.header, pkt._got_seq_B...))
|
|
if userErr != nil {
|
|
return
|
|
}
|
|
|
|
// --- ATOMIC anti replay check ---
|
|
pkt.__tempBool, err = s.db.CheckAndSetLastReq(user, pkt.fullPacket)
|
|
if err != nil {
|
|
return
|
|
} else if pkt.__tempBool {
|
|
userErr = ErrReplayAttackSuspected
|
|
return
|
|
}
|
|
|
|
err = s.rawWriteAck(&pkt.from, pkt.fullPacket, true)
|
|
return
|
|
}
|
|
|
|
func (s *server) rawWriteAck(to *net.UDPAddr, tohash []byte, signed bool) error {
|
|
return s.rawWrite(to, Rs_Ack, tohash, none, signed)
|
|
}
|
|
|
|
func (s *server) rawWriteData(to *net.UDPAddr, tohash []byte, data []byte, signed bool) error {
|
|
return s.rawWrite(to, Rs_Data, tohash, data, signed)
|
|
}
|
|
|
|
func (s *server) rawWriteError(to *net.UDPAddr, tohash []byte, err []byte, signed bool) error {
|
|
return s.rawWrite(to, Rs_Error, tohash, err, signed)
|
|
}
|
|
|
|
func (s *server) rawWrite(to *net.UDPAddr, tp PacketResType, tohash []byte, payload []byte, signed bool) error {
|
|
hash := sha256.Sum256(tohash)
|
|
data := append(hash[:], payload...)
|
|
if signed {
|
|
signature, err := s.secret.Sign(data)
|
|
if err != nil {
|
|
panic("klokotek")
|
|
}
|
|
data = append(data, signature...)
|
|
}
|
|
data = append(magicWithVersion, append([]byte{byte(tp)}, data...)...)
|
|
_, err := s.listener.WriteToSFUDP(data, to)
|
|
return err
|
|
|
|
}
|
|
|
|
func userErrPayload(err error) []byte {
|
|
switch {
|
|
case errors.Is(err, ErrMalformedPacket):
|
|
return []byte("malformed")
|
|
case errors.Is(err, ErrNoAssociatedRegisterTokenFound):
|
|
return []byte("no_token")
|
|
case errors.Is(err, ErrInvalidRegisterCheckHash):
|
|
return []byte("invalid_check")
|
|
case errors.Is(err, ErrUserAlreadyRegistered):
|
|
return []byte("already_registered")
|
|
case errors.Is(err, ErrNotRegisteredTriedToLogin):
|
|
return []byte("not_registered")
|
|
case errors.Is(err, ErrUserNotLoggedIn):
|
|
return []byte("not_logged")
|
|
case errors.Is(err, ErrReplayAttackSuspected):
|
|
return []byte("replay")
|
|
default:
|
|
return []byte("user_error")
|
|
}
|
|
}
|