package thrembio

// 10% of AI generated code
// human made; AI details/comments
// FINAL

import (
	"WhspBrd/menc"
	"WhspBrd/owner"
	"WhspBrd/sfudp"
	"WhspBrd/thrembio/taskpool"
	"WhspBrd/typio/bit"
	"bytes"
	"crypto/mlkem"
	"crypto/sha256"
	"encoding/binary"
	"errors"
	"fmt"
	"log"
	"net"
	"os"
	"sync"
	"time"
)

// ErrServerAlreadyOpen is returned by Open. The remaining errors describe
// client mistakes; they reach the caller through Read as RT_UserError and are
// translated into a wire payload by userErrPayload.
var (
	ErrServerAlreadyOpen              = errors.New("server already open")
	ErrMalformedPacket                = errors.New("user provided malformed packet")
	ErrNoAssociatedRegisterTokenFound = errors.New("user provided register token ID that doesn't exist")
	ErrInvalidRegisterCheckHash       = errors.New("user provided invalid register check hash (binds token + identity)")
	ErrUserAlreadyRegistered          = errors.New("user is already registered")
	ErrNotRegisteredTriedToLogin      = errors.New("user that is not registered tried to login")
	ErrReplayAttackSuspected          = errors.New("replay attack suspected due to repeated packet")
	ErrUserNotLoggedIn                = errors.New("user not logged in tried to send data")
)

// ReadType tells the caller of Server.Read what kind of event was returned.
type ReadType bit.Bit8

const (
	RT_None ReadType = iota
	// New user registered
	RT_Register
	// User logged in
	RT_Login
	// Data from authenticated user
	RT_Data
	// Client-related error
	RT_UserError
	// Internal server error
	RT_InsideError
)

// serverFlag is a bit set that controls which events Read reports.
type serverFlag bit.Bit8

const (
	SF_None serverFlag = 0
	// Don't log user errors (like invalid register token, invalid login, etc.)
	SF_NoUserErrorLog serverFlag = 1 << 0
	// Don't log register events (successful ones)
	SF_NoRegisterLog serverFlag = 1 << 1
	// Don't log login events (successful ones)
	SF_NoLoginLog serverFlag = 1 << 2
)

// has reports whether any bit of flag is set in f.
func (f *serverFlag) has(flag serverFlag) bool {
	return (*f)&flag != 0
}

// reqPacket is one received request plus the scratch space the handlers use
// while parsing it. It is passed to handlers by value, so writes to its fields
// are local to the handler; the byte slices still alias the receive buffer.
type reqPacket struct {
	// from is the address of the client that sent the request.
	from net.UDPAddr
	// time is the timestamp read from the header; serve compares it against
	// the local clock in Unix milliseconds.
	time uint64
	// reqType is the type of the request (register, login, data, etc.).
	reqType PacketReqType
	// fullPacket is the full packet that was sent by the client, including
	// everything.
	fullPacket []byte
	// header is the common header (magic, version, timestamp, type).
	header []byte
	// packet is the packet without the common header, so just the
	// payload-ish.
	packet []byte

	// register
	_got_tokenID_B []byte
	_got_checkHash []byte
	_got_sign      []byte
	_got_tokenID   uint32
	_token         []byte

	// login
	_got_encapKey []byte
	//_got_sign []byte

	// data
	_got_user    []byte
	_got_seq_B   []byte
	_got_payload []byte
	_got_seq     uint32
	_sesh        [32]byte

	// temps
	__tempBool  bool
	__tempBytes []byte

	// for auto parts
	__offset int
}

// rP_ReadAll asks nextPart for everything that is left in the packet.
const rP_ReadAll = -1

// startParts rewinds the part cursor to the beginning of pkt.packet.
func (pkt *reqPacket) startParts() {
	pkt.__offset = 0
}

// nextPart returns the next size bytes of pkt.packet and advances the cursor.
// With rP_ReadAll it returns the (possibly empty, non-nil) remainder. It
// returns nil, without moving the cursor, when size is negative or more bytes
// are requested than are left.
func (pkt *reqPacket) nextPart(size int) []byte {
	start := pkt.__offset
	if start > len(pkt.packet) {
		return nil
	}
	var end int
	switch {
	case size == rP_ReadAll:
		end = len(pkt.packet)
	case size >= 0 && start+size <= len(pkt.packet):
		end = start + size
	default:
		return nil
	}
	pkt.__offset = end
	return pkt.packet[start:end]
}

// readPacket is one event queued for Read: read type, user, sender address,
// data and error, in the order Read returns them.
type readPacket struct {
	rt ReadType
	u  owner.Identity
	fa net.UDPAddr
	d  []byte
	e  error
}

// Server accepts register, login and data requests over UDP and reports the
// outcome of each processed request through Read.
type Server interface {
	// Open starts listening and serving. It returns ErrServerAlreadyOpen when
	// the server is already open.
	Open() error
	// Close stops the listener and the worker pool.
	Close()
	// SetFlags replaces the current flag set.
	SetFlags(flags serverFlag)
	/*
		Read waits for incoming packets until a valid one is processed.
		Invalid packets are ignored.

		Returns one of:

			RT_Register:    (RT_Register, newUser, fromAddr, nil, nil)
				Ignored if SF_NoRegisterLog is enabled.
			RT_Login:       (RT_Login, user, fromAddr, nil, nil)
				Ignored if SF_NoLoginLog is enabled.
			RT_Data:        (RT_Data, user, fromAddr, requestData, nil)
			RT_UserError:   (RT_UserError, emptyUser, fromAddr, nil, error)
				Ignored if SF_NoUserErrorLog is enabled.
			RT_InsideError: (RT_InsideError, emptyUser, emptyAddr, nil, error)
				With net.ErrClosed once the server has stopped and every
				queued event has been delivered.
	*/
	Read() (readType ReadType, user owner.Identity, fromAddr net.UDPAddr, data []byte, err error)
	// Write encrypts data with the user's active session key and sends it to
	// the user's last known address.
	Write(data []byte, to owner.Identity) error
}

type server struct {
	port   uint16
	secret owner.Secret
	db     ServerDB
	flags  serverFlag

	listener *sfudp.SFUDPConn
	debug    bool

	bufPool  sync.Pool
	taskPool *taskpool.Pool[reqPacket]

	packets chan readPacket
	done    chan struct{}
	// stop closes done exactly once; it is recreated by every Open.
	stop   func()
	doneMu sync.Once
}

// NewServer builds a server for the given port. It returns ErrSecretCantBeNil
// when secret is nil. Debug logging of user errors is enabled when the
// WHSPBRD_DEBUG environment variable is non-empty.
func NewServer(port uint16, secret owner.Secret, db ServerDB) (Server, error) {
	if secret == nil {
		return nil, ErrSecretCantBeNil
	}

	srv := &server{
		port:     port,
		secret:   secret,
		db:       db,
		listener: nil,
		taskPool: taskpool.New[reqPacket](0, 0),
		bufPool: sync.Pool{
			New: func() any {
				b := make([]byte, 16384)
				return &b
			},
		},
		debug: os.Getenv("WHSPBRD_DEBUG") != "",
	}
	return srv, nil
}

// SetFlags replaces the current flag set.
// NOTE(review): flags is read by worker goroutines without synchronization;
// set it before Open.
func (s *server) SetFlags(flags serverFlag) {
	s.flags = flags
}

// Open binds the UDP listener, starts the worker pool and launches the receive
// loop. It returns ErrServerAlreadyOpen when a listener is already set.
func (s *server) Open() error {
	if s.listener != nil {
		return ErrServerAlreadyOpen
	}

	l, err := sfudp.ListenSFUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: int(s.port)})
	if err != nil {
		return err
	}

	done := make(chan struct{})
	var once sync.Once

	s.packets = make(chan readPacket, 1024)
	s.done = done
	// The closure owns its own Once and channel, so a receive loop left over
	// from an earlier session can never close the channel of a newer one.
	s.stop = func() { once.Do(func() { close(done) }) }
	// A fresh Once lets Close work for this session even if Close already ran
	// before Open or for a previous session.
	s.doneMu = sync.Once{}
	s.listener = l

	s.taskPool.Open()
	go s.serve(l, done, stop(s))
	return nil
}

// stop returns the server's current stop function; it exists so that Open can
// hand the per-session function to serve as a plain value.
func stop(s *server) func() {
	return s.stop
}

// serve is the receive loop. It validates the common header of every datagram
// (magic, version, timestamp window, request type) and dispatches accepted
// packets to the worker pool. It ends when done is closed or the listener
// returns an error, and calls stop on the way out so that Read callers and
// handlers blocked in emitResult are released.
//
// The packets channel is deliberately never closed: handlers may still be
// sending to it after the loop has ended.
func (s *server) serve(l *sfudp.SFUDPConn, done <-chan struct{}, stop func()) {
	// headerLen is the number of leading bytes this loop reads as the common
	// header: magic (3), version (1), timestamp (8), request type (1).
	const headerLen = 13

	defer stop()

	for {
		select {
		case <-done:
			return
		default:
		}

		bufPtr := s.bufPool.Get().(*[]byte)
		// Always read into the full capacity, whatever length the slice had
		// when it was put back.
		buf := (*bufPtr)[:cap(*bufPtr)]

		length, addr, err := l.ReadFromSFUDP(buf)
		if err != nil {
			s.bufPool.Put(bufPtr)
			return
		}

		if length < commonHeaderSize || length < headerLen ||
			buf[0] != magicBytes[0] ||
			buf[1] != magicBytes[1] ||
			buf[2] != magicBytes[2] ||
			buf[3] != version {
			s.bufPool.Put(bufPtr)
			continue
		}

		// Accept only timestamps from the last 8 seconds, never the future.
		ts := binary.LittleEndian.Uint64(buf[4:12])
		nowMs := uint64(time.Now().UnixMilli())
		if ts > nowMs || nowMs-ts > 8_000 {
			s.bufPool.Put(bufPtr)
			continue
		}

		reqType := PacketReqType(buf[12])
		if reqType >= Rq_Unknown {
			s.bufPool.Put(bufPtr)
			continue
		}

		pkt := reqPacket{
			from:       *addr,
			time:       ts,
			reqType:    reqType,
			fullPacket: buf[:length],
		}
		// Cap the header so that appending to it can never overwrite the
		// payload that follows it in the same buffer.
		pkt.header = pkt.fullPacket[:headerLen:headerLen]
		pkt.packet = pkt.fullPacket[headerLen:]

		// From here on the handler owns the buffer and returns it to the pool.
		s.taskPool.Dispatch(taskpool.Task[reqPacket]{
			Value:  pkt,
			Handle: s.handlePacket,
		})
	}
}

// releaseBuf returns a receive buffer to the pool at its full capacity, so the
// next read is not limited to the length of the previous packet.
func (s *server) releaseBuf(b []byte) {
	b = b[:cap(b)]
	s.bufPool.Put(&b)
}

// handlePacket runs on a pool worker: it routes the packet to the handler for
// its request type and reports the outcome through emitResult.
func (s *server) handlePacket(pkt reqPacket) {
	defer s.releaseBuf(pkt.fullPacket)

	var (
		user    owner.Identity
		data    []byte
		userErr error
		err     error
	)

	switch pkt.reqType {
	case Rq_Ping:
		// Ping is best-effort: a failed ack is not reported anywhere.
		// NOTE(review): this signs a response for any unauthenticated sender.
		_ = s.rawWriteAck(&pkt.from, pkt.fullPacket, true)
		return
	case Rq_Register:
		user, userErr, err = s.handleRegister(pkt)
	case Rq_Login:
		user, userErr, err = s.handleLogin(pkt)
	case Rq_Data:
		user, data, userErr, err = s.handleData(pkt)
	default:
		return
	}

	s.emitResult(pkt, user, data, userErr, err)
}

// emitResult turns a handler outcome into an event for Read.
//
//   - err != nil: RT_InsideError; nothing is sent to the client.
//   - userErr != nil: the client gets one signed error response; the event is
//     queued unless SF_NoUserErrorLog is set.
//   - otherwise: the sender address is stored as the user's last address and
//     the event is queued unless the matching SF_No*Log flag is set.
func (s *server) emitResult(pkt reqPacket, user owner.Identity, data []byte, userErr error, err error) {
	var rp readPacket

	switch {
	case err != nil:
		rp = readPacket{RT_InsideError, owner.Identity{}, net.UDPAddr{}, nil, err}

	case userErr != nil:
		if s.debug {
			log.Printf("user error: %v", userErr)
		}
		// Best-effort: the client error is what gets reported, not a failure
		// to tell the client about it.
		_ = s.rawWriteError(&pkt.from, pkt.fullPacket, userErrPayload(userErr), true)
		if s.flags.has(SF_NoUserErrorLog) {
			return
		}
		rp = readPacket{RT_UserError, owner.Identity{}, pkt.from, nil, userErr}

	default:
		// Record the address before looking at the log flags: Write needs it
		// even when register/login events are not reported.
		s.db.SetLastIp(user, &pkt.from)

		switch pkt.reqType {
		case Rq_Register:
			if s.flags.has(SF_NoRegisterLog) {
				return
			}
			rp = readPacket{RT_Register, user, pkt.from, nil, nil}
		case Rq_Login:
			if s.flags.has(SF_NoLoginLog) {
				return
			}
			rp = readPacket{RT_Login, user, pkt.from, nil, nil}
		case Rq_Data:
			rp = readPacket{RT_Data, user, pkt.from, data, nil}
		}
	}

	select {
	case s.packets <- rp:
	case <-s.done:
	}
}

// Close signals shutdown, closes the listener and the worker pool. Calling it
// on a server that is not open does nothing.
// NOTE(review): Open, Close and Write are not synchronized against each
// other; callers must not run them concurrently.
func (s *server) Close() {
	s.doneMu.Do(func() {
		if s.listener == nil {
			return
		}
		s.stop()
		s.listener.Close()
		s.taskPool.Close()
		s.listener = nil
	})
}

// Read blocks until an event is available and returns it. After shutdown it
// first hands out the events that are still queued and then returns
// RT_InsideError with net.ErrClosed.
func (s *server) Read() (ReadType, owner.Identity, net.UDPAddr, []byte, error) {
	select {
	case p := <-s.packets:
		return p.rt, p.u, p.fa, p.d, p.e
	case <-s.done:
	}

	// Shut down: drain what is left without blocking.
	select {
	case p := <-s.packets:
		return p.rt, p.u, p.fa, p.d, p.e
	default:
		return RT_InsideError, owner.Identity{}, net.UDPAddr{}, nil, net.ErrClosed
	}
}

// Write sends data to user. It needs the user's last address, last request and
// active session from the database; the data is encrypted with the session key
// (the last request is passed as the third argument to the encryption helper,
// presumably as associated data) and sent unsigned, bound to the hash of that
// last request.
func (s *server) Write(data []byte, user owner.Identity) error {
	ip, ex := s.db.GetLastIp(user)
	if !ex {
		return errors.New("no known IP for user")
	}

	req, ex := s.db.GetLastReq(user)
	if !ex {
		return errors.New("no known last request for user")
	}

	key, ex, err := s.db.GetActiveSession(user)
	if err != nil {
		return err
	}
	if !ex {
		return errors.New("no active session for user")
	}

	ciphertext, err := menc.AESGCM_Quick_Encrypt(key[:], data, req)
	if err != nil {
		return err
	}
	return s.rawWriteData(ip, req, ciphertext, false)
}

// handleRegister processes a register request. Payload layout:
//
//	token ID (uint32, big endian) | check hash (SHA-256) | signature
//
// The check hash must equal SHA-256(header | token ID | token | identity).
// On success the user is added to the database and a signed ack is sent.
// userErr reports a client mistake, err an internal failure.
//
// NOTE(review): the token is consumed before the signature and check hash are
// verified, so a request that fails verification still uses up the token.
func (s *server) handleRegister(pkt reqPacket) (newUser owner.Identity, userErr error, err error) {
	pkt.startParts()
	pkt._got_tokenID_B = pkt.nextPart(bit.SizeUint32_B)
	pkt._got_checkHash = pkt.nextPart(bit.SizeSha256_B)
	pkt._got_sign = pkt.nextPart(owner.SignatureSize)
	trailing := pkt.nextPart(rP_ReadAll)

	if pkt._got_tokenID_B == nil ||
		pkt._got_checkHash == nil ||
		pkt._got_sign == nil ||
		len(trailing) != 0 {
		return owner.Identity{}, ErrMalformedPacket, nil
	}

	pkt._got_tokenID = binary.BigEndian.Uint32(pkt._got_tokenID_B)

	// --- ATOMIC token consumption ---
	var found bool
	pkt._token, found, err = s.db.ConsumeRegisterToken(pkt._got_tokenID)
	if err != nil {
		return owner.Identity{}, nil, err
	}
	if !found {
		return owner.Identity{}, ErrNoAssociatedRegisterTokenFound, nil
	}

	// --- verify identity signature ---
	// The signature covers everything before it, header included.
	newUser, userErr = owner.Verify(
		pkt.fullPacket[:len(pkt.fullPacket)-owner.SignatureSize],
		pkt._got_sign,
	)
	if userErr != nil {
		return owner.Identity{}, userErr, nil
	}

	// --- check hash binds token + identity ---
	sha := sha256.New()
	sha.Write(pkt.header)
	sha.Write(pkt._got_tokenID_B)
	sha.Write(pkt._token)
	sha.Write(newUser[:])
	if !bytes.Equal(sha.Sum(nil), pkt._got_checkHash) {
		return owner.Identity{}, ErrInvalidRegisterCheckHash, nil
	}

	// --- ATOMIC register user ---
	var alreadyExists bool
	alreadyExists, err = s.db.AddRegisteredUserAtomic(newUser)
	if err != nil {
		return owner.Identity{}, nil, err
	}
	if alreadyExists {
		return owner.Identity{}, ErrUserAlreadyRegistered, nil
	}

	return newUser, nil, s.rawWriteAck(&pkt.from, pkt.fullPacket, true)
}

// handleLogin processes a login request. Payload layout:
//
//	ML-KEM-1024 encapsulation key | signature
//
// On success a shared secret is encapsulated against the client's key, its
// SHA-256 is stored as the active session and the ciphertext is sent back in a
// signed data response. userErr reports a client mistake, err an internal
// failure.
func (s *server) handleLogin(pkt reqPacket) (user owner.Identity, userErr error, err error) {
	pkt.startParts()
	pkt._got_encapKey = pkt.nextPart(mlkem.EncapsulationKeySize1024)
	pkt._got_sign = pkt.nextPart(owner.SignatureSize)
	trailing := pkt.nextPart(rP_ReadAll)

	if pkt._got_encapKey == nil ||
		pkt._got_sign == nil ||
		len(trailing) != 0 {
		return owner.Identity{}, ErrMalformedPacket, nil
	}

	// --- verify user signature ---
	user, userErr = owner.Verify(
		pkt.fullPacket[:len(pkt.fullPacket)-owner.SignatureSize],
		pkt._got_sign,
	)
	if userErr != nil {
		return owner.Identity{}, userErr, nil
	}

	// --- check registered user ---
	var registered bool
	registered, err = s.db.HasRegisteredUser(user)
	if err != nil {
		return owner.Identity{}, nil, err
	}
	if !registered {
		// emitResult sends the error response; sending one here as well made
		// the client receive two error datagrams for one request.
		return owner.Identity{}, ErrNotRegisteredTriedToLogin, nil
	}

	// --- ATOMIC anti replay check ---
	// The request is cloned because pkt.fullPacket is a pooled buffer that is
	// reused as soon as this handler returns.
	var replayed bool
	replayed, err = s.db.CheckAndSetLastReq(user, bytes.Clone(pkt.fullPacket))
	if err != nil {
		return owner.Identity{}, nil, err
	}
	if replayed {
		return owner.Identity{}, ErrReplayAttackSuspected, nil
	}

	// --- generate encapsulation key ---
	encKey, keyErr := mlkem.NewEncapsulationKey1024(pkt._got_encapKey)
	if keyErr != nil {
		return owner.Identity{}, keyErr, nil
	}
	shared, ciphertext := encKey.Encapsulate()

	// --- set active session ---
	if err = s.db.SetActiveSession(user, sha256.Sum256(shared)); err != nil {
		return owner.Identity{}, nil, err
	}

	return user, nil, s.rawWriteData(&pkt.from, pkt.fullPacket, ciphertext, true)
}

// handleData processes a data request. Payload layout:
//
//	identity | sequence (4 bytes) | encrypted payload
//
// The payload is decrypted with the user's active session key, with
// header | sequence as the third argument of the decryption helper. On success
// the decrypted data is returned and a signed ack is sent. userErr reports a
// client mistake, err an internal failure.
//
// NOTE(review): the sequence bytes are only fed into decryption; nothing here
// checks that they increase. Replay protection rests on CheckAndSetLastReq and
// the timestamp window enforced by serve.
// NOTE(review): if the decryption helper decrypts in place, data aliases the
// pooled receive buffer and must be copied before it is queued - confirm.
func (s *server) handleData(pkt reqPacket) (user owner.Identity, data []byte, userErr error, err error) {
	pkt.startParts()
	pkt._got_user = pkt.nextPart(owner.IdentitySize)
	pkt._got_seq_B = pkt.nextPart(bit.SizeUint32_B)
	pkt._got_payload = pkt.nextPart(rP_ReadAll)

	if pkt._got_user == nil ||
		pkt._got_seq_B == nil ||
		pkt._got_payload == nil {
		return owner.Identity{}, nil, ErrMalformedPacket, nil
	}

	// The claimed identity; it is only trusted once decryption with that
	// user's session key has succeeded.
	user = owner.Identity(pkt._got_user)

	// --- check logged in user ---
	var active bool
	pkt._sesh, active, err = s.db.GetActiveSession(user)
	if err != nil {
		return owner.Identity{}, nil, nil, err
	}
	if !active {
		// emitResult sends the error response; sending one here as well made
		// the client receive two error datagrams for one request.
		return owner.Identity{}, nil, ErrUserNotLoggedIn, nil
	}

	// Build header | sequence in its own slice. Appending to pkt.header
	// directly would write into the receive buffer and overwrite the first
	// bytes of the identity inside pkt.fullPacket.
	aad := make([]byte, 0, len(pkt.header)+len(pkt._got_seq_B))
	aad = append(aad, pkt.header...)
	aad = append(aad, pkt._got_seq_B...)

	data, userErr = menc.AESGCM_Quick_Decrypt(pkt._sesh[:], pkt._got_payload, aad)
	if userErr != nil {
		return owner.Identity{}, nil, userErr, nil
	}

	// --- ATOMIC anti replay check ---
	// The request is cloned because pkt.fullPacket is a pooled buffer that is
	// reused as soon as this handler returns.
	var replayed bool
	replayed, err = s.db.CheckAndSetLastReq(user, bytes.Clone(pkt.fullPacket))
	if err != nil {
		return owner.Identity{}, nil, nil, err
	}
	if replayed {
		return owner.Identity{}, nil, ErrReplayAttackSuspected, nil
	}

	return user, data, nil, s.rawWriteAck(&pkt.from, pkt.fullPacket, true)
}

// rawWriteAck sends an Rs_Ack response bound to tohash.
func (s *server) rawWriteAck(to *net.UDPAddr, tohash []byte, signed bool) error {
	return s.rawWrite(to, Rs_Ack, tohash, none, signed)
}

// rawWriteData sends an Rs_Data response carrying data, bound to tohash.
func (s *server) rawWriteData(to *net.UDPAddr, tohash []byte, data []byte, signed bool) error {
	return s.rawWrite(to, Rs_Data, tohash, data, signed)
}

// rawWriteError sends an Rs_Error response carrying err, bound to tohash.
func (s *server) rawWriteError(to *net.UDPAddr, tohash []byte, err []byte, signed bool) error {
	return s.rawWrite(to, Rs_Error, tohash, err, signed)
}

// rawWrite builds and sends one response datagram:
//
//	magicWithVersion | type | SHA-256(tohash) | payload | [signature]
//
// When signed is true the signature covers SHA-256(tohash) | payload. It
// returns net.ErrClosed when the server has no listener, and the signing or
// write error otherwise.
func (s *server) rawWrite(to *net.UDPAddr, tp PacketResType, tohash []byte, payload []byte, signed bool) error {
	l := s.listener
	if l == nil {
		return net.ErrClosed
	}

	hash := sha256.Sum256(tohash)
	body := make([]byte, 0, len(hash)+len(payload))
	body = append(body, hash[:]...)
	body = append(body, payload...)

	if signed {
		signature, err := s.secret.Sign(body)
		if err != nil {
			return fmt.Errorf("signing response: %w", err)
		}
		body = append(body, signature...)
	}

	// Assemble into a fresh slice; appending to the package-level
	// magicWithVersion could write into its spare capacity from several
	// goroutines at once.
	out := make([]byte, 0, len(magicWithVersion)+1+len(body))
	out = append(out, magicWithVersion...)
	out = append(out, byte(tp))
	out = append(out, body...)

	_, err := l.WriteToSFUDP(out, to)
	return err
}

// userErrPayload maps a user error to the payload of the error response sent
// to the client. Errors it does not know map to "user_error".
func userErrPayload(err error) []byte {
	switch {
	case errors.Is(err, ErrMalformedPacket):
		return []byte("malformed")
	case errors.Is(err, ErrNoAssociatedRegisterTokenFound):
		return []byte("no_token")
	case errors.Is(err, ErrInvalidRegisterCheckHash):
		return []byte("invalid_check")
	case errors.Is(err, ErrUserAlreadyRegistered):
		return []byte("already_registered")
	case errors.Is(err, ErrNotRegisteredTriedToLogin):
		// Same payload the login handler used to send itself.
		return notRegisteredB
	case errors.Is(err, ErrUserNotLoggedIn):
		// Same payload the data handler used to send itself.
		return notLoggedB
	case errors.Is(err, ErrReplayAttackSuspected):
		return []byte("replay")
	default:
		return []byte("user_error")
	}
}