whspbrd-final/relay/server.go
2026-05-02 22:09:19 +02:00

243 lines
5.5 KiB
Go

package relay
import (
"WhspBrd/owner"
"WhspBrd/thrembio"
"crypto/mlkem"
"encoding/binary"
"sync"
)
// Server is the public surface of a WhspBrd relay: lifecycle control, the
// relay's own identity, and management of registration tokens.
type Server interface {
	// Start opens the transport and then serves requests; in this package's
	// implementation it blocks until the receive loop returns an error.
	Start() error
	// Close shuts down the underlying transport.
	Close()
	// GetId returns the identity of the relay's own secret.
	GetId() owner.Identity
	// AddRegisterToken stores token under tokenId in the relay database.
	AddRegisterToken(tokenId uint32, token []byte) error
	// GetRegisterTokens returns the stored registration tokens keyed by id.
	GetRegisterTokens() map[uint32][]byte
}
// server is the concrete relay behind Server. Published KEM keys and queued
// messages are held only in the in-memory maps below; nothing in this file
// persists them to db or removes entries from them.
type server struct {
	db thrembio.ServerDB // database handed to thrembio.NewServer; also backs the register-token methods
	th thrembio.Server   // authenticated transport; every reply is sent with th.Write
	sc owner.Secret      // the relay's own secret; signs "getkem" responses

	// mu guards messages and userKems.
	mu       sync.RWMutex
	messages map[owner.Identity][][]byte                    // receiver identity -> queued "send" frames
	userKems map[owner.Identity]*mlkem.EncapsulationKey1024 // user identity -> key published via "pubkem"
}
// NewServer opens the relay database at path with thrembio.NewSQLiteDB and
// builds a thrembio server on port that uses secret and that database. It
// returns the relay wrapped as a Server. The transport is not opened here;
// call Start for that.
//
// If creating the thrembio server fails, the already-opened database is
// closed before the error is returned. This happens only when the database
// value has a Close() error method. Without it, a failed constructor would
// leak the database handle.
func NewServer(port uint16, secret owner.Secret, path string) (Server, error) {
	db, err := thrembio.NewSQLiteDB(path)
	if err != nil {
		return nil, err
	}
	th, err := thrembio.NewServer(port, secret, db)
	if err != nil {
		// Best-effort cleanup (same shape as io.Closer). The constructor error
		// is the one the caller needs, so a Close error is deliberately dropped.
		if c, ok := any(db).(interface{ Close() error }); ok {
			_ = c.Close()
		}
		return nil, err
	}
	// Judging by the flag names, this suppresses thrembio's logging of user
	// errors, registrations and logins.
	th.SetFlags(thrembio.SF_NoUserErrorLog | thrembio.SF_NoRegisterLog | thrembio.SF_NoLoginLog)
	return &server{
		db:       db,
		th:       th,
		sc:       secret,
		messages: make(map[owner.Identity][][]byte),
		userKems: make(map[owner.Identity]*mlkem.EncapsulationKey1024),
	}, nil
}
// Start opens the thrembio transport and then runs the receive loop. It
// returns the Open error, or otherwise whatever error ends handleData.
func (s *server) Start() error {
	if err := s.th.Open(); err != nil {
		return err
	}
	return s.handleData()
}
// handleData is the relay's receive loop. It blocks on s.th.Read and
// dispatches every RT_Data frame to the matching command handler. It returns
// only when Read fails with a read type other than RT_UserError.
//
// A data frame is a command name immediately followed by that command's
// payload. The command names are "pubkem", "getkem", "msglen", "msgget"
// (6 bytes) and "send" (4 bytes). Frames that start with no known command
// are answered with "error".
func (s *server) handleData() error {
	for {
		rt, user, _, data, err := s.th.Read()
		if err != nil {
			// A per-user failure must not stop the relay for everyone else.
			if rt == thrembio.RT_UserError {
				continue
			}
			return err
		}
		switch rt {
		case thrembio.RT_Register, thrembio.RT_Login:
			// Registration and login need no relay-level action.
		case thrembio.RT_Data:
			cmd, payload, ok := splitRelayCommand(data)
			if !ok {
				s.th.Write([]byte("error"), user)
				continue
			}
			switch cmd {
			case "pubkem":
				s.handlePublishKem(user, payload)
			case "getkem":
				s.handleGetKem(user, payload)
			case "send":
				s.handleSend(user, payload)
			case "msglen":
				s.handleMessageLen(user)
			case "msgget":
				s.handleMessageGet(user, payload)
			}
		case thrembio.RT_UserError, thrembio.RT_InsideError:
			// Reported without a Read error; nothing to do, keep serving.
		}
	}
}

// splitRelayCommand splits data into its leading command name and the rest
// of the frame. It reports false when data starts with no known command.
//
// Commands have different lengths ("send" is 4 bytes, the rest are 6), so
// they are matched by prefix rather than by slicing a fixed 6-byte header.
// A fixed 6-byte header made "send" impossible to match. No 6-byte command
// begins with "send", so the order of checks cannot cause a mismatch.
func splitRelayCommand(data []byte) (string, []byte, bool) {
	for _, cmd := range [...]string{"pubkem", "getkem", "msglen", "msgget", "send"} {
		if len(data) >= len(cmd) && string(data[:len(cmd)]) == cmd {
			return cmd, data[len(cmd):], true
		}
	}
	return "", nil, false
}
// handlePublishKem records user's ML-KEM-1024 encapsulation key.
//
// The payload layout is [serverPuha][kek][osign]. osign must be a signature
// over serverPuha||kek, and owner.Verify must attribute it to user itself.
// Any trailing bytes after kek are passed to Verify as part of osign. The
// reply is "done" on success. It is "error" for a short payload, a failed or
// foreign signature, or a key mlkem rejects.
func (s *server) handlePublishKem(user owner.Identity, payload []byte) {
	fail := func() { s.th.Write([]byte("error"), user) }

	signedLen := owner.IdentitySize + mlkem.EncapsulationKeySize1024
	if len(payload) < signedLen+owner.SignatureSize {
		fail()
		return
	}

	// Copy the signed region (serverPuha||kek) into its own buffer before
	// verifying.
	// NOTE(review): serverPuha is covered by the signature but is never
	// compared with s.sc.Identity(); confirm whether that binding is intended.
	signed := append([]byte(nil), payload[:signedLen]...)
	osign := payload[signedLen:]

	signer, err := owner.Verify(signed, osign)
	if err != nil || !owner.IdentityEq(signer, user) {
		fail()
		return
	}

	encapKey, err := mlkem.NewEncapsulationKey1024(payload[owner.IdentitySize:signedLen])
	if err != nil {
		fail()
		return
	}

	s.mu.Lock()
	s.userKems[user] = encapKey
	s.mu.Unlock()

	s.th.Write([]byte("done"), user)
}
// handleGetKem answers a request for another user's published KEM key.
//
// The payload is [receiverPuha]; any bytes beyond the identity are ignored.
// On success the reply is [serverPuha][kek][osign]. Here serverPuha is this
// relay's identity, and osign is produced by the relay's secret over
// serverPuha||kek. The reply is "error" when the payload is short, no key is
// stored for receiverPuha, or signing fails.
func (s *server) handleGetKem(user owner.Identity, payload []byte) {
	if len(payload) < owner.IdentitySize {
		s.th.Write([]byte("error"), user)
		return
	}
	target := owner.Identity(payload[:owner.IdentitySize])

	s.mu.RLock()
	encapKey, found := s.userKems[target]
	s.mu.RUnlock()
	if !found || encapKey == nil {
		s.th.Write([]byte("error"), user)
		return
	}

	serverID := s.sc.Identity()
	reply := append(serverID[:], encapKey.Bytes()...)
	sig, err := s.sc.Sign(reply)
	if err != nil {
		s.th.Write([]byte("error"), user)
		return
	}
	s.th.Write(append(reply, sig...), user)
}
// handleSend queues an end-to-end encrypted message for its receiver.
//
// The payload layout is [serverPuha][receiverPuha][kct][time][encData][osign].
// Here kct is an ML-KEM-1024 ciphertext and time takes 8 bytes. Only the
// receiverPuha field is interpreted. The whole frame is queued unchanged
// under that identity, to be fetched later with "msgget". The reply is
// "send" on success and "error" when the payload is shorter than the fixed
// fields.
//
// NOTE(review): osign is not verified here and the sender (user) is not
// checked against it; only the transport-level login authenticates user.
// NOTE(review): queues are unbounded and never drained in this file, so a
// client can grow server memory without limit.
func (s *server) handleSend(user owner.Identity, payload []byte) {
	if len(payload) < owner.IdentitySize*2+mlkem.CiphertextSize1024+8+owner.SignatureSize {
		s.th.Write([]byte("error"), user)
		return
	}
	receiverPuha := owner.Identity(payload[owner.IdentitySize : owner.IdentitySize*2])

	// payload is a sub-slice of the buffer returned by s.th.Read, but the
	// queued message must outlive this request. Copy it so that reuse of the
	// transport's read buffer cannot corrupt stored messages.
	msg := append([]byte(nil), payload...)

	s.mu.Lock()
	s.messages[receiverPuha] = append(s.messages[receiverPuha], msg)
	s.mu.Unlock()

	s.th.Write([]byte("send"), user)
}
// handleMessageLen replies with the number of messages queued for user, as a
// 4-byte big-endian unsigned integer.
func (s *server) handleMessageLen(user owner.Identity) {
	s.mu.RLock()
	count := uint32(len(s.messages[user]))
	s.mu.RUnlock()

	s.th.Write(binary.BigEndian.AppendUint32(nil, count), user)
}
func (s *server) handleMessageGet(user owner.Identity, payload []byte) {
// Expected: [id] (4 bytes)
if len(payload) < 4 {
s.th.Write([]byte("error"), user)
return
}
msgId := binary.BigEndian.Uint32(payload[:4])
s.mu.RLock()
msgs := s.messages[user]
if msgId >= uint32(len(msgs)) {
s.mu.RUnlock()
s.th.Write([]byte("error"), user)
return
}
// Return the full message payload: [serverPuha] [senderPuha] [kct] [time] [encData] [osign(...)]
msg := msgs[msgId]
s.mu.RUnlock()
s.th.Write(msg, user)
}
// Close shuts down the thrembio transport.
//
// NOTE(review): s.db is not closed here; confirm whether s.th.Close releases
// the database handed to it in NewServer.
func (s *server) Close() {
	transport := s.th
	transport.Close()
}
// GetId reports the identity derived from the relay's own secret, which is
// the same identity that signs "getkem" responses.
func (s *server) GetId() owner.Identity {
	id := s.sc.Identity()
	return id
}
// AddRegisterToken stores token under tokenId in the relay database and
// returns the database's error unchanged.
func (s *server) AddRegisterToken(tokenId uint32, token []byte) error {
	store := s.db
	return store.SetRegisterToken(tokenId, token)
}
// GetRegisterTokens returns the stored registration tokens keyed by token id.
//
// The Server interface has no error result, so a database failure is
// reported only as a nil map. The underlying error is discarded.
func (s *server) GetRegisterTokens() map[uint32][]byte {
	tokens, err := s.db.GetRegisterTokens()
	if err == nil {
		return tokens
	}
	return nil
}