376 lines
7.4 KiB
Go
376 lines
7.4 KiB
Go
package thrembio
|
|
|
|
// Roughly 80% of this code is AI-generated;
// the design is human, and the code has been partially human-reviewed.
|
|
|
|
import (
|
|
"WhspBrd/menc"
|
|
"WhspBrd/owner"
|
|
"WhspBrd/sfudp"
|
|
"WhspBrd/typio/bit"
|
|
"bytes"
|
|
"crypto/mlkem"
|
|
"crypto/sha256"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"strconv"
|
|
"time"
|
|
)
|
|
|
|
type Client interface {
|
|
Close() error
|
|
ping() error
|
|
GetServerId() owner.Identity
|
|
|
|
Register(tokenId uint32, token []byte) error
|
|
Login() error
|
|
Write(data []byte) error
|
|
Read() ([]byte, error)
|
|
}
|
|
|
|
type client struct {
|
|
conn *sfudp.SFUDPConn
|
|
secret owner.Secret
|
|
|
|
serverId owner.Identity
|
|
aes *menc.AESGCM_AutoNonce
|
|
sequence uint32
|
|
|
|
last []byte
|
|
|
|
debug bool
|
|
}
|
|
|
|
func NewClient(addr *net.UDPAddr, secret owner.Secret) (Client, error) {
|
|
if secret == nil {
|
|
return nil, ErrSecretCantBeNil
|
|
}
|
|
|
|
conn, err := sfudp.DialSFUDP("udp", nil, addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c := &client{
|
|
conn: conn,
|
|
secret: secret,
|
|
debug: os.Getenv("WHSPBRD_DEBUG") != "",
|
|
}
|
|
|
|
err = c.ping()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func (c *client) Close() error {
|
|
return c.conn.Close()
|
|
}
|
|
|
|
func (c *client) ping() error {
|
|
err := c.rawWrite(c.rawHeader(Rq_Ping), none)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
t, _, id, err := c.receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if t != Rs_Ack {
|
|
return fmt.Errorf("expected ack")
|
|
}
|
|
if len(id) <= 0 {
|
|
return fmt.Errorf("expected valid id")
|
|
}
|
|
c.serverId = id
|
|
return nil
|
|
}
|
|
|
|
func (c *client) receive() (PacketResType, []byte, owner.Identity, error) {
|
|
buf := make([]byte, 16384)
|
|
if err := c.setReadDeadline(); err != nil {
|
|
return 0, nil, owner.Identity{}, err
|
|
}
|
|
n, err := c.conn.Read(buf)
|
|
if err != nil {
|
|
return 0, nil, owner.Identity{}, err
|
|
}
|
|
if c.debug {
|
|
log.Printf("thrembio receive bytes=%d", n)
|
|
}
|
|
return c.rawRead(buf[:n])
|
|
}
|
|
|
|
func (c *client) setReadDeadline() error {
|
|
ms := readTimeoutMs()
|
|
if ms <= 0 {
|
|
return nil
|
|
}
|
|
return c.conn.GetUDP().SetReadDeadline(time.Now().Add(time.Duration(ms) * time.Millisecond))
|
|
}
|
|
|
|
// readTimeoutMs returns the read timeout in milliseconds from the
// WHSPBRD_TIMEOUT_MS environment variable.
//
// An empty/unset or non-integer value yields the 5000 ms default. Any value
// that parses as an int is returned unchanged, including 0 and negatives.
func readTimeoutMs() int {
	const fallbackMs = 5000

	raw := os.Getenv("WHSPBRD_TIMEOUT_MS")
	if raw == "" {
		return fallbackMs
	}
	if ms, err := strconv.Atoi(raw); err == nil {
		return ms
	}
	return fallbackMs
}
|
|
|
|
func (c *client) GetServerId() owner.Identity {
|
|
return c.serverId
|
|
}
|
|
|
|
func (c *client) Register(tokenId uint32, token []byte) error {
|
|
header := c.rawHeader(Rq_Register)
|
|
tokenIdB := make([]byte, bit.SizeUint32_B)
|
|
binary.BigEndian.PutUint32(tokenIdB, tokenId)
|
|
|
|
identity := c.secret.Identity()
|
|
checkHash := sha256.Sum256(append(append(append(header, tokenIdB...), token...), identity[:]...))
|
|
|
|
signData := make([]byte, 0, len(header)+len(tokenIdB)+len(checkHash))
|
|
signData = append(signData, header...)
|
|
signData = append(signData, tokenIdB...)
|
|
signData = append(signData, checkHash[:]...)
|
|
|
|
sign, err := c.secret.Sign(signData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
payload := make([]byte, 0, len(tokenIdB)+len(checkHash)+len(sign))
|
|
payload = append(payload, tokenIdB...)
|
|
payload = append(payload, checkHash[:]...)
|
|
payload = append(payload, sign...)
|
|
|
|
if err := c.rawWrite(header, payload); err != nil {
|
|
return err
|
|
}
|
|
|
|
t, data, id, err := c.receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(id) > 0 {
|
|
c.serverId = id
|
|
}
|
|
|
|
switch t {
|
|
case Rs_Ack:
|
|
return nil
|
|
case Rs_Error:
|
|
return fmt.Errorf("register failed: %s", string(data))
|
|
default:
|
|
return fmt.Errorf("unexpected response type %d", t)
|
|
}
|
|
}
|
|
|
|
func (c *client) Login() error {
|
|
decKey, err := mlkem.GenerateKey1024()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
encapKey := decKey.EncapsulationKey().Bytes()
|
|
header := c.rawHeader(Rq_Login)
|
|
|
|
signData := make([]byte, 0, len(header)+len(encapKey))
|
|
signData = append(signData, header...)
|
|
signData = append(signData, encapKey...)
|
|
|
|
sign, err := c.secret.Sign(signData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
payload := make([]byte, 0, len(encapKey)+len(sign))
|
|
payload = append(payload, encapKey...)
|
|
payload = append(payload, sign...)
|
|
|
|
if err := c.rawWrite(header, payload); err != nil {
|
|
return err
|
|
}
|
|
|
|
t, data, id, err := c.receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(id) > 0 {
|
|
c.serverId = id
|
|
}
|
|
|
|
switch t {
|
|
case Rs_Error:
|
|
return fmt.Errorf("login failed: %s", string(data))
|
|
case Rs_Data:
|
|
shared, err := decKey.Decapsulate(data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
key := sha256.Sum256(shared)
|
|
c.aes, err = menc.NewAESGCM_AutoNonce(key[:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.sequence = 0
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("unexpected response type %d", t)
|
|
}
|
|
}
|
|
|
|
func (c *client) Write(data []byte) error {
|
|
if c.aes == nil {
|
|
return errors.New("not logged in")
|
|
}
|
|
|
|
header := c.rawHeader(Rq_Data)
|
|
userID := c.secret.Identity()
|
|
seqB := make([]byte, bit.SizeUint32_B)
|
|
binary.BigEndian.PutUint32(seqB, c.sequence)
|
|
|
|
aad := make([]byte, 0, len(header)+len(seqB))
|
|
aad = append(aad, header...)
|
|
aad = append(aad, seqB...)
|
|
|
|
ciphertext, err := c.aes.Encrypt(data, aad)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
payload := make([]byte, 0, len(userID)+len(seqB)+len(ciphertext))
|
|
payload = append(payload, userID[:]...)
|
|
payload = append(payload, seqB...)
|
|
payload = append(payload, ciphertext...)
|
|
|
|
if err := c.rawWrite(header, payload); err != nil {
|
|
return err
|
|
}
|
|
|
|
t, data, id, err := c.receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(id) > 0 {
|
|
c.serverId = id
|
|
}
|
|
|
|
switch t {
|
|
case Rs_Ack:
|
|
c.sequence++
|
|
return nil
|
|
case Rs_Error:
|
|
return fmt.Errorf("write failed: %s", string(data))
|
|
default:
|
|
return fmt.Errorf("unexpected response type %d", t)
|
|
}
|
|
}
|
|
|
|
func (c *client) Read() ([]byte, error) {
|
|
if c.aes == nil {
|
|
return nil, errors.New("not logged in")
|
|
}
|
|
|
|
t, data, id, err := c.receive()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(id) > 0 {
|
|
c.serverId = id
|
|
}
|
|
|
|
switch t {
|
|
case Rs_Data:
|
|
if len(c.last) == 0 {
|
|
return nil, errors.New("missing last request for decryption")
|
|
}
|
|
plain, err := c.aes.Decrypt(data, c.last)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return plain, nil
|
|
case Rs_Error:
|
|
return nil, fmt.Errorf("server error: %s", string(data))
|
|
default:
|
|
return nil, fmt.Errorf("unexpected response type %d", t)
|
|
}
|
|
}
|
|
|
|
func (c *client) rawHeader(reqType PacketReqType) []byte {
|
|
buf := make([]byte, commonHeaderSize)
|
|
|
|
buf[0] = magicBytes[0]
|
|
buf[1] = magicBytes[1]
|
|
buf[2] = magicBytes[2]
|
|
buf[3] = version
|
|
|
|
nowMs := uint64(time.Now().UnixMilli())
|
|
binary.LittleEndian.PutUint64(buf[4:12], nowMs)
|
|
|
|
buf[12] = byte(reqType)
|
|
|
|
return buf
|
|
}
|
|
|
|
func (c *client) rawWrite(header, data []byte) error {
|
|
last := append(header, data...)
|
|
_, err := c.conn.Write(last)
|
|
c.last = last
|
|
return err
|
|
}
|
|
|
|
func (c *client) rawRead(p []byte) (PacketResType, []byte, owner.Identity, error) {
|
|
if len(p) < 37 ||
|
|
p[0] != magicBytes[0] ||
|
|
p[1] != magicBytes[1] ||
|
|
p[2] != magicBytes[2] ||
|
|
p[3] != version {
|
|
return 0, nil, owner.Identity{}, fmt.Errorf("bad packet")
|
|
}
|
|
|
|
tp := PacketResType(p[4])
|
|
if tp >= Rs_Unknown {
|
|
return 0, nil, owner.Identity{}, fmt.Errorf("bad type")
|
|
}
|
|
|
|
hash := p[5:37]
|
|
data := p[37:]
|
|
|
|
expectedHash := sha256.Sum256(c.last)
|
|
if !bytes.Equal(hash, expectedHash[:]) {
|
|
return 0, nil, owner.Identity{}, fmt.Errorf("hash mismatch")
|
|
}
|
|
|
|
if len(data) >= owner.SignatureSize {
|
|
payload := data[:len(data)-owner.SignatureSize]
|
|
sig := data[len(data)-owner.SignatureSize:]
|
|
|
|
buf := make([]byte, len(hash)+len(payload))
|
|
copy(buf, hash)
|
|
copy(buf[len(hash):], payload)
|
|
|
|
id, err := owner.Verify(buf, sig)
|
|
if err == nil {
|
|
return tp, payload, id, nil
|
|
}
|
|
}
|
|
|
|
if tp == Rs_Data {
|
|
return tp, data, owner.Identity{}, nil
|
|
}
|
|
|
|
return 0, nil, owner.Identity{}, fmt.Errorf("bad packet")
|
|
}
|