whspbrd-final/thrembio/client.go
2026-05-02 22:09:19 +02:00

376 lines
7.4 KiB
Go

package thrembio
// Roughly 80% of this code is AI-generated;
// the design is human, and the code has been partially reviewed.
import (
"WhspBrd/menc"
"WhspBrd/owner"
"WhspBrd/sfudp"
"WhspBrd/typio/bit"
"bytes"
"crypto/mlkem"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"log"
"net"
"os"
"strconv"
"time"
)
// Client is a request/response session with a thrembio server over SFUDP.
//
// The implementation returned by NewClient mutates per-request state (the
// last packet sent, the write sequence number) without any locking, so a
// Client should not be shared between goroutines without external
// synchronization.
type Client interface {
	// GetServerId returns the server identity recorded from the server's
	// responses (first set by the ping performed in NewClient).
	GetServerId() owner.Identity

	// Register sends a signed registration request proving knowledge of the
	// token identified by tokenId.
	Register(tokenId uint32, token []byte) error
	// Login runs an ML-KEM-1024 key exchange and installs the AES-GCM
	// session used by Write and Read.
	Login() error

	// Write sends data encrypted under the Login session and waits for an ack.
	Write(data []byte) error
	// Read waits for the next packet and decrypts it under the Login session.
	Read() ([]byte, error)

	// Close releases the underlying connection.
	Close() error

	ping() error
}
// client is the SFUDP-backed implementation of Client.
//
// Field order is kept as-is in case unkeyed composite literals elsewhere in
// the package depend on it.
type client struct {
// conn is the datagram connection to the server.
conn *sfudp.SFUDPConn
// secret signs outgoing requests and supplies this client's identity.
secret owner.Secret
// serverId is the identity taken from the server's responses (set by ping,
// refreshed by later responses).
serverId owner.Identity
// aes is the session cipher installed by Login; nil means "not logged in".
aes *menc.AESGCM_AutoNonce
// sequence is the next Write sequence number; reset to 0 by Login and
// advanced only when the server acks a Write.
sequence uint32
// last is the most recent packet sent; rawRead checks the reply's hash
// against it and Read uses it as AES-GCM additional data.
last []byte
// debug enables receive logging (WHSPBRD_DEBUG non-empty).
debug bool
}
// NewClient dials the thrembio server at addr over SFUDP and pings it to learn
// the server identity before returning.
//
// secret must be non-nil (ErrSecretCantBeNil otherwise); it signs requests and
// supplies the client identity. Setting WHSPBRD_DEBUG to any non-empty value
// enables receive logging. If the initial ping fails, the freshly dialed
// connection is closed and the ping error is returned.
func NewClient(addr *net.UDPAddr, secret owner.Secret) (Client, error) {
	if secret == nil {
		return nil, ErrSecretCantBeNil
	}
	conn, err := sfudp.DialSFUDP("udp", nil, addr)
	if err != nil {
		return nil, err
	}
	c := &client{
		conn:   conn,
		secret: secret,
		debug:  os.Getenv("WHSPBRD_DEBUG") != "",
	}
	if err := c.ping(); err != nil {
		// The caller never receives c and so cannot close it; release the
		// socket here instead of leaking it. The ping error is the one worth
		// reporting, so a secondary close error is deliberately dropped.
		_ = conn.Close()
		return nil, err
	}
	return c, nil
}
// Close closes the underlying SFUDP connection and reports its error, if any.
func (c *client) Close() error {
	if err := c.conn.Close(); err != nil {
		return err
	}
	return nil
}
// ping sends an empty Rq_Ping request and expects a signed Rs_Ack reply; the
// identity carried by that reply becomes c.serverId.
//
// NOTE(review): if owner.Identity is a fixed-size array, len(id) is a
// compile-time constant and the "expected valid id" check can never fire —
// confirm the type.
func (c *client) ping() error {
	if err := c.rawWrite(c.rawHeader(Rq_Ping), none); err != nil {
		return err
	}
	kind, _, serverID, err := c.receive()
	switch {
	case err != nil:
		return err
	case kind != Rs_Ack:
		return fmt.Errorf("expected ack")
	case len(serverID) == 0:
		return fmt.Errorf("expected valid id")
	}
	c.serverId = serverID
	return nil
}
// receive arms the read deadline, performs a single c.conn.Read of up to
// 16384 bytes, optionally logs the byte count, and hands the bytes to rawRead
// for parsing and authentication against the last request sent.
func (c *client) receive() (PacketResType, []byte, owner.Identity, error) {
	const maxPacket = 16384
	if err := c.setReadDeadline(); err != nil {
		return 0, nil, owner.Identity{}, err
	}
	packet := make([]byte, maxPacket)
	got, err := c.conn.Read(packet)
	if err != nil {
		return 0, nil, owner.Identity{}, err
	}
	if c.debug {
		log.Printf("thrembio receive bytes=%d", got)
	}
	return c.rawRead(packet[:got])
}
// setReadDeadline arms the read deadline for the next receive, using
// readTimeoutMs (milliseconds from now).
//
// A timeout of zero or less means "no deadline". In that case any deadline
// left over from an earlier call is cleared with the zero time.Time. Deadlines
// are absolute, so a stale one would otherwise make every later read fail
// with a timeout.
func (c *client) setReadDeadline() error {
	ms := readTimeoutMs()
	udp := c.conn.GetUDP()
	if ms <= 0 {
		return udp.SetReadDeadline(time.Time{})
	}
	return udp.SetReadDeadline(time.Now().Add(time.Duration(ms) * time.Millisecond))
}
// readTimeoutMs returns the receive timeout in milliseconds taken from the
// WHSPBRD_TIMEOUT_MS environment variable. An unset/empty or non-integer value
// yields the 5000 ms default; any integer (including zero or negative, which
// callers treat as "no timeout") is returned as-is.
func readTimeoutMs() int {
	const fallbackMs = 5000
	if raw := os.Getenv("WHSPBRD_TIMEOUT_MS"); raw != "" {
		if parsed, err := strconv.Atoi(raw); err == nil {
			return parsed
		}
	}
	return fallbackMs
}
// GetServerId returns the server identity cached by ping and refreshed by
// later responses that carried one.
func (c *client) GetServerId() owner.Identity {
	id := c.serverId
	return id
}
// Register sends an Rq_Register request proving knowledge of the token
// identified by tokenId, on behalf of c.secret's identity.
//
// The token itself is never transmitted. The payload is
//
//	tokenId (big-endian) || SHA-256(header || tokenId || token || identity) || signature
//
// where the signature covers header || tokenId || hash. Rs_Ack means success;
// Rs_Error is reported as "register failed: <server message>".
func (c *client) Register(tokenId uint32, token []byte) error {
	header := c.rawHeader(Rq_Register)
	tokenIDBytes := make([]byte, bit.SizeUint32_B)
	binary.BigEndian.PutUint32(tokenIDBytes, tokenId)
	self := c.secret.Identity()

	proof := sha256.Sum256(bytes.Join([][]byte{header, tokenIDBytes, token, self[:]}, nil))
	signature, err := c.secret.Sign(bytes.Join([][]byte{header, tokenIDBytes, proof[:]}, nil))
	if err != nil {
		return err
	}
	body := bytes.Join([][]byte{tokenIDBytes, proof[:], signature}, nil)
	if err := c.rawWrite(header, body); err != nil {
		return err
	}

	kind, reply, serverID, err := c.receive()
	if err != nil {
		return err
	}
	if len(serverID) > 0 {
		c.serverId = serverID
	}
	if kind == Rs_Ack {
		return nil
	}
	if kind == Rs_Error {
		return fmt.Errorf("register failed: %s", string(reply))
	}
	return fmt.Errorf("unexpected response type %d", kind)
}
// Login performs an ML-KEM-1024 key exchange with the server and installs the
// resulting AES-GCM session used by Write and Read.
//
// The request carries a fresh encapsulation key followed by c.secret's
// signature over header || encapsulation key. On an Rs_Data reply the
// ciphertext is decapsulated, SHA-256 of the shared secret becomes the AES
// key, and the write sequence restarts at 0. Rs_Error is reported as
// "login failed: <server message>"; any other reply type is rejected.
//
// Session state (c.aes, c.sequence) is replaced only after the new cipher has
// been built successfully, so a failed Login leaves a previous session intact.
//
// NOTE(review): rawRead accepts Rs_Data packets even without a valid
// signature, so the KEM ciphertext consumed here may be unauthenticated. A
// party that observed the request could compute the expected reply hash and
// encapsulate to the (public) key itself. Consider requiring a signature from
// c.serverId for this reply; confirm against the server implementation.
func (c *client) Login() error {
	decKey, err := mlkem.GenerateKey1024()
	if err != nil {
		return err
	}
	encapKey := decKey.EncapsulationKey().Bytes()
	header := c.rawHeader(Rq_Login)

	signData := make([]byte, 0, len(header)+len(encapKey))
	signData = append(signData, header...)
	signData = append(signData, encapKey...)
	sign, err := c.secret.Sign(signData)
	if err != nil {
		return err
	}

	payload := make([]byte, 0, len(encapKey)+len(sign))
	payload = append(payload, encapKey...)
	payload = append(payload, sign...)
	if err := c.rawWrite(header, payload); err != nil {
		return err
	}

	t, data, id, err := c.receive()
	if err != nil {
		return err
	}
	if len(id) > 0 {
		c.serverId = id
	}
	switch t {
	case Rs_Error:
		return fmt.Errorf("login failed: %s", string(data))
	case Rs_Data:
		shared, err := decKey.Decapsulate(data)
		if err != nil {
			return fmt.Errorf("login: decapsulating server ciphertext: %w", err)
		}
		key := sha256.Sum256(shared)
		// Build into a local first: assigning straight to c.aes would wipe a
		// previously working session if cipher construction failed.
		session, err := menc.NewAESGCM_AutoNonce(key[:])
		if err != nil {
			return fmt.Errorf("login: creating session cipher: %w", err)
		}
		c.aes = session
		c.sequence = 0
		return nil
	default:
		return fmt.Errorf("unexpected response type %d", t)
	}
}
// Write encrypts data under the Login session, sends it as an Rq_Data request
// and waits for the server's reply.
//
// Payload layout: sender identity || big-endian uint32 sequence
// (bit.SizeUint32_B bytes) || ciphertext. The AES-GCM additional data is
// header || sequence. c.sequence advances only when the server answers
// Rs_Ack; Rs_Error is reported as "write failed: <server message>".
func (c *client) Write(data []byte) error {
	if c.aes == nil {
		return errors.New("not logged in")
	}
	header := c.rawHeader(Rq_Data)
	sender := c.secret.Identity()
	seq := make([]byte, bit.SizeUint32_B)
	binary.BigEndian.PutUint32(seq, c.sequence)

	sealed, err := c.aes.Encrypt(data, bytes.Join([][]byte{header, seq}, nil))
	if err != nil {
		return err
	}
	if err := c.rawWrite(header, bytes.Join([][]byte{sender[:], seq, sealed}, nil)); err != nil {
		return err
	}

	kind, reply, serverID, err := c.receive()
	if err != nil {
		return err
	}
	if len(serverID) > 0 {
		c.serverId = serverID
	}
	switch kind {
	case Rs_Ack:
		c.sequence++
		return nil
	case Rs_Error:
		return fmt.Errorf("write failed: %s", string(reply))
	default:
		return fmt.Errorf("unexpected response type %d", kind)
	}
}
// Read waits for the next packet (it sends nothing itself) and returns its
// decrypted contents.
//
// Only Rs_Data is accepted; it is decrypted with the Login session using the
// last request sent (c.last) as AES-GCM additional data. Rs_Error is reported
// as "server error: <server message>".
func (c *client) Read() ([]byte, error) {
	if c.aes == nil {
		return nil, errors.New("not logged in")
	}
	kind, body, serverID, err := c.receive()
	if err != nil {
		return nil, err
	}
	if len(serverID) > 0 {
		c.serverId = serverID
	}
	if kind == Rs_Error {
		return nil, fmt.Errorf("server error: %s", string(body))
	}
	if kind != Rs_Data {
		return nil, fmt.Errorf("unexpected response type %d", kind)
	}
	if len(c.last) == 0 {
		return nil, errors.New("missing last request for decryption")
	}
	opened, err := c.aes.Decrypt(body, c.last)
	if err != nil {
		return nil, err
	}
	return opened, nil
}
// rawHeader builds a fresh request header of commonHeaderSize bytes:
//
//	[0:3]  magicBytes[0..2]
//	[3]    version
//	[4:12] current Unix time in milliseconds, little-endian
//	[12]   reqType
//
// Any remaining bytes are zero. The returned slice has len == cap, so callers
// that append to it always get a new backing array.
func (c *client) rawHeader(reqType PacketReqType) []byte {
	hdr := make([]byte, commonHeaderSize)
	for i := 0; i < 3; i++ {
		hdr[i] = magicBytes[i]
	}
	hdr[3] = version
	stamp := time.Now().UnixMilli()
	binary.LittleEndian.PutUint64(hdr[4:12], uint64(stamp))
	hdr[12] = byte(reqType)
	return hdr
}
// rawWrite sends header || data as one packet and remembers it in c.last.
// rawRead hashes c.last to match the reply, and Read uses it as AES-GCM
// additional data.
//
// The packet is assembled in a freshly allocated buffer. With
// append(header, data...), a header that had spare capacity would be written
// into in place, and c.last would share its backing array with the caller's
// slice, so later appends by the caller could silently change c.last.
//
// c.last is updated even when the write fails.
func (c *client) rawWrite(header, data []byte) error {
	packet := make([]byte, 0, len(header)+len(data))
	packet = append(packet, header...)
	packet = append(packet, data...)
	_, err := c.conn.Write(packet)
	c.last = packet
	return err
}
// rawRead parses and authenticates a response packet.
//
// Layout (byte offsets):
//
//	[0:3]  magic, must equal magicBytes[0..2]
//	[3]    protocol version, must equal version
//	[4]    response type, must be below Rs_Unknown
//	[5:37] SHA-256 of the last request sent (c.last)
//	[37:]  body
//
// If the body is at least owner.SignatureSize bytes, its trailing
// owner.SignatureSize bytes are tried as a signature over hash || payload.
// When owner.Verify accepts it, the payload (body minus signature) and the
// returned identity are passed back. Otherwise only Rs_Data is accepted,
// returned with the full body and a zero owner.Identity; anything else is a
// "bad packet".
//
// NOTE(review): because of that fallback, Rs_Data replies need no signature;
// their authenticity must come from the caller (e.g. AES-GCM in Read).
func (c *client) rawRead(p []byte) (PacketResType, []byte, owner.Identity, error) {
	const (
		typeOffset = 4
		hashOffset = typeOffset + 1
		bodyOffset = hashOffset + sha256.Size // 37
	)
	wellFormed := len(p) >= bodyOffset &&
		p[0] == magicBytes[0] &&
		p[1] == magicBytes[1] &&
		p[2] == magicBytes[2] &&
		p[3] == version
	if !wellFormed {
		return 0, nil, owner.Identity{}, fmt.Errorf("bad packet")
	}

	kind := PacketResType(p[typeOffset])
	if kind >= Rs_Unknown {
		return 0, nil, owner.Identity{}, fmt.Errorf("bad type")
	}

	echoed := p[hashOffset:bodyOffset]
	body := p[bodyOffset:]
	want := sha256.Sum256(c.last)
	if !bytes.Equal(echoed, want[:]) {
		return 0, nil, owner.Identity{}, fmt.Errorf("hash mismatch")
	}

	if cut := len(body) - owner.SignatureSize; cut >= 0 {
		msg, sig := body[:cut], body[cut:]
		signed := make([]byte, 0, len(echoed)+len(msg))
		signed = append(signed, echoed...)
		signed = append(signed, msg...)
		if signer, err := owner.Verify(signed, sig); err == nil {
			return kind, msg, signer, nil
		}
	}

	if kind != Rs_Data {
		return 0, nil, owner.Identity{}, fmt.Errorf("bad packet")
	}
	return kind, body, owner.Identity{}, nil
}