package thrembio // 80% of AI generated code // human designed & reviewed-ish import ( "WhspBrd/menc" "WhspBrd/owner" "WhspBrd/sfudp" "WhspBrd/typio/bit" "bytes" "crypto/mlkem" "crypto/sha256" "encoding/binary" "errors" "fmt" "log" "net" "os" "strconv" "time" ) type Client interface { Close() error ping() error GetServerId() owner.Identity Register(tokenId uint32, token []byte) error Login() error Write(data []byte) error Read() ([]byte, error) } type client struct { conn *sfudp.SFUDPConn secret owner.Secret serverId owner.Identity aes *menc.AESGCM_AutoNonce sequence uint32 last []byte debug bool } func NewClient(addr *net.UDPAddr, secret owner.Secret) (Client, error) { if secret == nil { return nil, ErrSecretCantBeNil } conn, err := sfudp.DialSFUDP("udp", nil, addr) if err != nil { return nil, err } c := &client{ conn: conn, secret: secret, debug: os.Getenv("WHSPBRD_DEBUG") != "", } err = c.ping() if err != nil { return nil, err } return c, nil } func (c *client) Close() error { return c.conn.Close() } func (c *client) ping() error { err := c.rawWrite(c.rawHeader(Rq_Ping), none) if err != nil { return err } t, _, id, err := c.receive() if err != nil { return err } if t != Rs_Ack { return fmt.Errorf("expected ack") } if len(id) <= 0 { return fmt.Errorf("expected valid id") } c.serverId = id return nil } func (c *client) receive() (PacketResType, []byte, owner.Identity, error) { buf := make([]byte, 16384) if err := c.setReadDeadline(); err != nil { return 0, nil, owner.Identity{}, err } n, err := c.conn.Read(buf) if err != nil { return 0, nil, owner.Identity{}, err } if c.debug { log.Printf("thrembio receive bytes=%d", n) } return c.rawRead(buf[:n]) } func (c *client) setReadDeadline() error { ms := readTimeoutMs() if ms <= 0 { return nil } return c.conn.GetUDP().SetReadDeadline(time.Now().Add(time.Duration(ms) * time.Millisecond)) } func readTimeoutMs() int { val := os.Getenv("WHSPBRD_TIMEOUT_MS") if val == "" { return 5000 } ms, err := strconv.Atoi(val) if err != nil { return 5000 } return ms } func (c *client) GetServerId() owner.Identity { return c.serverId } func (c *client) Register(tokenId uint32, token []byte) error { header := c.rawHeader(Rq_Register) tokenIdB := make([]byte, bit.SizeUint32_B) binary.BigEndian.PutUint32(tokenIdB, tokenId) identity := c.secret.Identity() checkHash := sha256.Sum256(append(append(append(header, tokenIdB...), token...), identity[:]...)) signData := make([]byte, 0, len(header)+len(tokenIdB)+len(checkHash)) signData = append(signData, header...) signData = append(signData, tokenIdB...) signData = append(signData, checkHash[:]...) sign, err := c.secret.Sign(signData) if err != nil { return err } payload := make([]byte, 0, len(tokenIdB)+len(checkHash)+len(sign)) payload = append(payload, tokenIdB...) payload = append(payload, checkHash[:]...) payload = append(payload, sign...) if err := c.rawWrite(header, payload); err != nil { return err } t, data, id, err := c.receive() if err != nil { return err } if len(id) > 0 { c.serverId = id } switch t { case Rs_Ack: return nil case Rs_Error: return fmt.Errorf("register failed: %s", string(data)) default: return fmt.Errorf("unexpected response type %d", t) } } func (c *client) Login() error { decKey, err := mlkem.GenerateKey1024() if err != nil { return err } encapKey := decKey.EncapsulationKey().Bytes() header := c.rawHeader(Rq_Login) signData := make([]byte, 0, len(header)+len(encapKey)) signData = append(signData, header...) signData = append(signData, encapKey...) sign, err := c.secret.Sign(signData) if err != nil { return err } payload := make([]byte, 0, len(encapKey)+len(sign)) payload = append(payload, encapKey...) payload = append(payload, sign...) if err := c.rawWrite(header, payload); err != nil { return err } t, data, id, err := c.receive() if err != nil { return err } if len(id) > 0 { c.serverId = id } switch t { case Rs_Error: return fmt.Errorf("login failed: %s", string(data)) case Rs_Data: shared, err := decKey.Decapsulate(data) if err != nil { return err } key := sha256.Sum256(shared) c.aes, err = menc.NewAESGCM_AutoNonce(key[:]) if err != nil { return err } c.sequence = 0 return nil default: return fmt.Errorf("unexpected response type %d", t) } } func (c *client) Write(data []byte) error { if c.aes == nil { return errors.New("not logged in") } header := c.rawHeader(Rq_Data) userID := c.secret.Identity() seqB := make([]byte, bit.SizeUint32_B) binary.BigEndian.PutUint32(seqB, c.sequence) aad := make([]byte, 0, len(header)+len(seqB)) aad = append(aad, header...) aad = append(aad, seqB...) ciphertext, err := c.aes.Encrypt(data, aad) if err != nil { return err } payload := make([]byte, 0, len(userID)+len(seqB)+len(ciphertext)) payload = append(payload, userID[:]...) payload = append(payload, seqB...) payload = append(payload, ciphertext...) if err := c.rawWrite(header, payload); err != nil { return err } t, data, id, err := c.receive() if err != nil { return err } if len(id) > 0 { c.serverId = id } switch t { case Rs_Ack: c.sequence++ return nil case Rs_Error: return fmt.Errorf("write failed: %s", string(data)) default: return fmt.Errorf("unexpected response type %d", t) } } func (c *client) Read() ([]byte, error) { if c.aes == nil { return nil, errors.New("not logged in") } t, data, id, err := c.receive() if err != nil { return nil, err } if len(id) > 0 { c.serverId = id } switch t { case Rs_Data: if len(c.last) == 0 { return nil, errors.New("missing last request for decryption") } plain, err := c.aes.Decrypt(data, c.last) if err != nil { return nil, err } return plain, nil case Rs_Error: return nil, fmt.Errorf("server error: %s", string(data)) default: return nil, fmt.Errorf("unexpected response type %d", t) } } func (c *client) rawHeader(reqType PacketReqType) []byte { buf := make([]byte, commonHeaderSize) buf[0] = magicBytes[0] buf[1] = magicBytes[1] buf[2] = magicBytes[2] buf[3] = version nowMs := uint64(time.Now().UnixMilli()) binary.LittleEndian.PutUint64(buf[4:12], nowMs) buf[12] = byte(reqType) return buf } func (c *client) rawWrite(header, data []byte) error { last := append(header, data...) _, err := c.conn.Write(last) c.last = last return err } func (c *client) rawRead(p []byte) (PacketResType, []byte, owner.Identity, error) { if len(p) < 37 || p[0] != magicBytes[0] || p[1] != magicBytes[1] || p[2] != magicBytes[2] || p[3] != version { return 0, nil, owner.Identity{}, fmt.Errorf("bad packet") } tp := PacketResType(p[4]) if tp >= Rs_Unknown { return 0, nil, owner.Identity{}, fmt.Errorf("bad type") } hash := p[5:37] data := p[37:] expectedHash := sha256.Sum256(c.last) if !bytes.Equal(hash, expectedHash[:]) { return 0, nil, owner.Identity{}, fmt.Errorf("hash mismatch") } if len(data) >= owner.SignatureSize { payload := data[:len(data)-owner.SignatureSize] sig := data[len(data)-owner.SignatureSize:] buf := make([]byte, len(hash)+len(payload)) copy(buf, hash) copy(buf[len(hash):], payload) id, err := owner.Verify(buf, sig) if err == nil { return tp, payload, id, nil } } if tp == Rs_Data { return tp, data, owner.Identity{}, nil } return 0, nil, owner.Identity{}, fmt.Errorf("bad packet") }