243 lines
5.5 KiB
Go
243 lines
5.5 KiB
Go
package relay
|
|
|
|
import (
|
|
"WhspBrd/owner"
|
|
"WhspBrd/thrembio"
|
|
"crypto/mlkem"
|
|
"encoding/binary"
|
|
"sync"
|
|
)
|
|
|
|
type Server interface {
|
|
Start() error
|
|
Close()
|
|
GetId() owner.Identity
|
|
AddRegisterToken(tokenId uint32, token []byte) error
|
|
GetRegisterTokens() map[uint32][]byte
|
|
}
|
|
|
|
type server struct {
|
|
db thrembio.ServerDB
|
|
th thrembio.Server
|
|
sc owner.Secret
|
|
messages map[owner.Identity][][]byte
|
|
userKems map[owner.Identity]*mlkem.EncapsulationKey1024
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
func NewServer(port uint16, secret owner.Secret, path string) (Server, error) {
|
|
db, err := thrembio.NewSQLiteDB(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
th, err := thrembio.NewServer(port, secret, db)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
th.SetFlags(thrembio.SF_NoUserErrorLog | thrembio.SF_NoRegisterLog | thrembio.SF_NoLoginLog)
|
|
|
|
return &server{
|
|
db: db,
|
|
th: th,
|
|
sc: secret,
|
|
messages: make(map[owner.Identity][][]byte),
|
|
userKems: make(map[owner.Identity]*mlkem.EncapsulationKey1024),
|
|
}, nil
|
|
}
|
|
|
|
func (s *server) Start() error {
|
|
err := s.th.Open()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.handleData()
|
|
}
|
|
|
|
func (s *server) handleData() error {
|
|
for {
|
|
rt, user, _, data, err := s.th.Read()
|
|
if err != nil {
|
|
if rt == thrembio.RT_UserError {
|
|
continue
|
|
}
|
|
return err
|
|
}
|
|
|
|
switch rt {
|
|
case thrembio.RT_Register:
|
|
// New user registered
|
|
continue
|
|
case thrembio.RT_Login:
|
|
// User logged in
|
|
continue
|
|
case thrembio.RT_Data:
|
|
// Parse relay protocol command
|
|
if len(data) < 6 {
|
|
s.th.Write([]byte("error"), user)
|
|
continue
|
|
}
|
|
|
|
cmd := string(data[:6])
|
|
payload := data[6:]
|
|
|
|
switch cmd {
|
|
case "pubkem":
|
|
s.handlePublishKem(user, payload)
|
|
case "getkem":
|
|
s.handleGetKem(user, payload)
|
|
case "send":
|
|
s.handleSend(user, payload)
|
|
case "msglen":
|
|
s.handleMessageLen(user)
|
|
case "msgget":
|
|
s.handleMessageGet(user, payload)
|
|
default:
|
|
s.th.Write([]byte("error"), user)
|
|
}
|
|
case thrembio.RT_UserError, thrembio.RT_InsideError:
|
|
// Handle errors or continue
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *server) handlePublishKem(user owner.Identity, payload []byte) {
|
|
// Expected: [serverPuha] [kek] [osign(serverPuha | kek)]
|
|
if len(payload) < owner.IdentitySize+mlkem.EncapsulationKeySize1024+owner.SignatureSize {
|
|
s.th.Write([]byte("error"), user)
|
|
return
|
|
}
|
|
|
|
serverPuha := owner.Identity(payload[:owner.IdentitySize])
|
|
kek := payload[owner.IdentitySize : owner.IdentitySize+mlkem.EncapsulationKeySize1024]
|
|
osign := payload[owner.IdentitySize+mlkem.EncapsulationKeySize1024:]
|
|
|
|
// Verify signature signed by the user
|
|
whosver, err := owner.Verify(append(serverPuha[:], kek...), osign)
|
|
if err != nil || !owner.IdentityEq(whosver, user) {
|
|
s.th.Write([]byte("error"), user)
|
|
return
|
|
}
|
|
|
|
// Store KEM for this user
|
|
kemKey, err := mlkem.NewEncapsulationKey1024(kek)
|
|
if err != nil {
|
|
s.th.Write([]byte("error"), user)
|
|
return
|
|
}
|
|
|
|
s.mu.Lock()
|
|
s.userKems[user] = kemKey
|
|
s.mu.Unlock()
|
|
|
|
s.th.Write([]byte("done"), user)
|
|
}
|
|
|
|
func (s *server) handleGetKem(user owner.Identity, payload []byte) {
|
|
// Expected: [receiverPuha]
|
|
if len(payload) < owner.IdentitySize {
|
|
s.th.Write([]byte("error"), user)
|
|
return
|
|
}
|
|
|
|
receiverPuha := owner.Identity(payload[:owner.IdentitySize])
|
|
|
|
// Get receiver's KEM
|
|
s.mu.RLock()
|
|
kemKey, exists := s.userKems[receiverPuha]
|
|
s.mu.RUnlock()
|
|
|
|
if !exists || kemKey == nil {
|
|
s.th.Write([]byte("error"), user)
|
|
return
|
|
}
|
|
|
|
// Return [serverPuha] [kek] [osign(serverPuha | kek)]
|
|
id := s.sc.Identity()
|
|
responseData := append(id[:], kemKey.Bytes()...)
|
|
osign, err := s.sc.Sign(responseData)
|
|
if err != nil {
|
|
s.th.Write([]byte("error"), user)
|
|
return
|
|
}
|
|
responseData = append(responseData, osign...)
|
|
|
|
s.th.Write(responseData, user)
|
|
}
|
|
|
|
func (s *server) handleSend(user owner.Identity, payload []byte) {
|
|
// Expected: [serverPuha] [receiverPuha] [kct] [time] [encData] [osign(...)]
|
|
// Minimum: owner.IdentitySize*2 + mlkem.CiphertextSize1024 + 8 + owner.SignatureSize
|
|
if len(payload) < owner.IdentitySize*2+mlkem.CiphertextSize1024+8+owner.SignatureSize {
|
|
s.th.Write([]byte("error"), user)
|
|
return
|
|
}
|
|
|
|
// Extract receiverPuha (sender is verified by thrembio layer)
|
|
receiverPuha := owner.Identity(payload[owner.IdentitySize : owner.IdentitySize*2])
|
|
|
|
// Store the full message payload for the receiver
|
|
s.mu.Lock()
|
|
s.messages[receiverPuha] = append(s.messages[receiverPuha], payload)
|
|
s.mu.Unlock()
|
|
|
|
// Acknowledge
|
|
s.th.Write([]byte("send"), user)
|
|
}
|
|
|
|
func (s *server) handleMessageLen(user owner.Identity) {
|
|
// Return message count as 4-byte big-endian
|
|
s.mu.RLock()
|
|
msgCount := len(s.messages[user])
|
|
s.mu.RUnlock()
|
|
|
|
lenB := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(lenB, uint32(msgCount))
|
|
s.th.Write(lenB, user)
|
|
}
|
|
|
|
func (s *server) handleMessageGet(user owner.Identity, payload []byte) {
|
|
// Expected: [id] (4 bytes)
|
|
if len(payload) < 4 {
|
|
s.th.Write([]byte("error"), user)
|
|
return
|
|
}
|
|
|
|
msgId := binary.BigEndian.Uint32(payload[:4])
|
|
|
|
s.mu.RLock()
|
|
msgs := s.messages[user]
|
|
if msgId >= uint32(len(msgs)) {
|
|
s.mu.RUnlock()
|
|
s.th.Write([]byte("error"), user)
|
|
return
|
|
}
|
|
|
|
// Return the full message payload: [serverPuha] [senderPuha] [kct] [time] [encData] [osign(...)]
|
|
msg := msgs[msgId]
|
|
s.mu.RUnlock()
|
|
|
|
s.th.Write(msg, user)
|
|
}
|
|
|
|
func (s *server) Close() {
|
|
s.th.Close()
|
|
}
|
|
|
|
func (s *server) GetId() owner.Identity {
|
|
return s.sc.Identity()
|
|
}
|
|
|
|
func (s *server) AddRegisterToken(tokenId uint32, token []byte) error {
|
|
return s.db.SetRegisterToken(tokenId, token)
|
|
}
|
|
|
|
func (s *server) GetRegisterTokens() map[uint32][]byte {
|
|
tokens, err := s.db.GetRegisterTokens()
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return tokens
|
|
}
|