311 lines
7.2 KiB
Go
311 lines
7.2 KiB
Go
package thrembio
|
|
|
|
// Roughly 90% of this code was AI-generated;
// the design is human-made and the code has been human-reviewed.
|
|
|
|
import (
|
|
"WhspBrd/owner"
|
|
"bytes"
|
|
"database/sql"
|
|
"net"
|
|
"sync"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
type ServerDB interface {
|
|
// Registration tokens (atomic usage)
|
|
GetRegisterTokens() (map[uint32][]byte, error) // optional read-only snapshot
|
|
SetRegisterToken(id uint32, token []byte) error // add new token
|
|
ConsumeRegisterToken(id uint32) ([]byte, bool, error) // atomic get + remove
|
|
|
|
// Registered users
|
|
GetRegisteredUsers() ([]owner.Identity, error) // optional snapshot
|
|
HasRegisteredUser(user owner.Identity) (bool, error)
|
|
AddRegisteredUserAtomic(user owner.Identity) (alreadyExists bool, err error) // atomic insert
|
|
RemoveRegisteredUser(user owner.Identity) error
|
|
|
|
// Active sessions
|
|
GetActiveSessions() (map[owner.Identity][32]byte, error)
|
|
SetActiveSession(user owner.Identity, secret [32]byte) error
|
|
GetActiveSession(user owner.Identity) ([32]byte, bool, error)
|
|
RemoveActiveSession(user owner.Identity) error
|
|
|
|
/*
|
|
Replay prevention (atomic)
|
|
Not db persistent.
|
|
*/
|
|
CheckAndSetLastReq(user owner.Identity, reqData []byte) (isReplay bool, err error)
|
|
CheckAndSetLastSeq(user owner.Identity, seq uint32) (isReplay bool, err error)
|
|
GetLastReq(user owner.Identity) ([]byte, bool)
|
|
|
|
/*
|
|
Not db persistent
|
|
*/
|
|
SetLastIp(user owner.Identity, ip *net.UDPAddr)
|
|
GetLastIp(user owner.Identity) (*net.UDPAddr, bool)
|
|
|
|
// Close DB
|
|
Close() error
|
|
}
|
|
|
|
type serverDB struct {
|
|
db *sql.DB
|
|
|
|
lastReqsMu sync.Mutex
|
|
lastReqs map[owner.Identity][]byte
|
|
|
|
lastSeqsMu sync.Mutex
|
|
lastSeqs map[owner.Identity]uint32
|
|
|
|
lastIpsMu sync.Mutex
|
|
lastIps map[owner.Identity]*net.UDPAddr
|
|
}
|
|
|
|
func NewSQLiteDB(path string) (ServerDB, error) {
|
|
db, err := sql.Open("sqlite3", path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s := &serverDB{
|
|
db: db,
|
|
lastReqs: make(map[owner.Identity][]byte),
|
|
lastSeqs: make(map[owner.Identity]uint32),
|
|
lastIps: make(map[owner.Identity]*net.UDPAddr),
|
|
}
|
|
|
|
if err := s.init(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
func (s *serverDB) init() error {
|
|
_, err := s.db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS register_tokens (
|
|
id INTEGER PRIMARY KEY,
|
|
token BLOB NOT NULL
|
|
);
|
|
CREATE TABLE IF NOT EXISTS registered_users (
|
|
username BLOB PRIMARY KEY
|
|
);
|
|
CREATE TABLE IF NOT EXISTS active_sessions (
|
|
username BLOB PRIMARY KEY,
|
|
session_id BLOB NOT NULL
|
|
);`)
|
|
return err
|
|
}
|
|
|
|
// --- helpers ---
|
|
func idToBytes(id owner.Identity) []byte {
|
|
b := make([]byte, len(id))
|
|
copy(b, id[:])
|
|
return b
|
|
}
|
|
|
|
func bytesToID(b []byte) owner.Identity {
|
|
var id owner.Identity
|
|
copy(id[:], b)
|
|
return id
|
|
}
|
|
|
|
// --- register tokens ---
|
|
func (s *serverDB) GetRegisterTokens() (map[uint32][]byte, error) {
|
|
rows, err := s.db.Query(`SELECT id, token FROM register_tokens`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
tokens := make(map[uint32][]byte)
|
|
for rows.Next() {
|
|
var id uint32
|
|
var token []byte
|
|
if err := rows.Scan(&id, &token); err != nil {
|
|
return nil, err
|
|
}
|
|
tokens[id] = token
|
|
}
|
|
return tokens, nil
|
|
}
|
|
|
|
func (s *serverDB) SetRegisterToken(id uint32, token []byte) error {
|
|
_, err := s.db.Exec(`
|
|
INSERT INTO register_tokens(id, token)
|
|
VALUES(?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET token = excluded.token
|
|
`, id, token)
|
|
return err
|
|
}
|
|
|
|
// --- ATOMIC: consume token ---
|
|
func (s *serverDB) ConsumeRegisterToken(id uint32) ([]byte, bool, error) {
|
|
row := s.db.QueryRow(`
|
|
DELETE FROM register_tokens
|
|
WHERE id = ?
|
|
RETURNING token
|
|
`, id)
|
|
|
|
var token []byte
|
|
err := row.Scan(&token)
|
|
if err == sql.ErrNoRows {
|
|
return nil, false, nil
|
|
}
|
|
return token, err == nil, err
|
|
}
|
|
|
|
// --- registered users ---
|
|
func (s *serverDB) GetRegisteredUsers() ([]owner.Identity, error) {
|
|
rows, err := s.db.Query(`SELECT username FROM registered_users`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var users []owner.Identity
|
|
for rows.Next() {
|
|
var u []byte
|
|
if err := rows.Scan(&u); err != nil {
|
|
return nil, err
|
|
}
|
|
users = append(users, bytesToID(u))
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (s *serverDB) HasRegisteredUser(user owner.Identity) (bool, error) {
|
|
var exists int
|
|
err := s.db.QueryRow(
|
|
`SELECT EXISTS(SELECT 1 FROM registered_users WHERE username = ?)`,
|
|
idToBytes(user),
|
|
).Scan(&exists)
|
|
return exists == 1, err
|
|
}
|
|
|
|
// --- ATOMIC: add user if not exists ---
|
|
func (s *serverDB) AddRegisteredUserAtomic(user owner.Identity) (alreadyExists bool, err error) {
|
|
res, err := s.db.Exec(`
|
|
INSERT OR IGNORE INTO registered_users(username) VALUES(?)
|
|
`, idToBytes(user))
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
rows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return rows == 0, nil
|
|
}
|
|
|
|
func (s *serverDB) RemoveRegisteredUser(user owner.Identity) error {
|
|
_, err := s.db.Exec(`DELETE FROM registered_users WHERE username = ?`, idToBytes(user))
|
|
return err
|
|
}
|
|
|
|
// --- active sessions ---
|
|
func (s *serverDB) GetActiveSessions() (map[owner.Identity][32]byte, error) {
|
|
rows, err := s.db.Query(`SELECT username, session_id FROM active_sessions`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
sessions := make(map[owner.Identity][32]byte)
|
|
for rows.Next() {
|
|
var username []byte
|
|
var key [32]byte
|
|
if err := rows.Scan(&username, &key); err != nil {
|
|
return nil, err
|
|
}
|
|
sessions[bytesToID(username)] = key
|
|
}
|
|
return sessions, nil
|
|
}
|
|
|
|
func (s *serverDB) SetActiveSession(user owner.Identity, key [32]byte) error {
|
|
_, err := s.db.Exec(`
|
|
INSERT INTO active_sessions(username, session_id)
|
|
VALUES(?, ?)
|
|
ON CONFLICT(username) DO UPDATE SET session_id = excluded.session_id
|
|
`, idToBytes(user), key[:])
|
|
return err
|
|
}
|
|
|
|
func (s *serverDB) GetActiveSession(user owner.Identity) ([32]byte, bool, error) {
|
|
var key [32]byte
|
|
err := s.db.QueryRow(
|
|
`SELECT session_id FROM active_sessions WHERE username = ?`,
|
|
idToBytes(user),
|
|
).Scan(&key)
|
|
if err == sql.ErrNoRows {
|
|
return [32]byte{}, false, nil
|
|
}
|
|
return key, err == nil, err
|
|
}
|
|
|
|
func (s *serverDB) RemoveActiveSession(user owner.Identity) error {
|
|
_, err := s.db.Exec(`DELETE FROM active_sessions WHERE username = ?`, idToBytes(user))
|
|
return err
|
|
}
|
|
|
|
// --- ATOMIC: login replay prevention ---
|
|
func (s *serverDB) CheckAndSetLastReq(user owner.Identity, reqData []byte) (isReplay bool, err error) {
|
|
s.lastReqsMu.Lock()
|
|
defer s.lastReqsMu.Unlock()
|
|
|
|
last, exists := s.lastReqs[user]
|
|
if exists && bytes.Equal(last, reqData) {
|
|
return true, nil
|
|
}
|
|
|
|
s.lastReqs[user] = reqData
|
|
return false, nil
|
|
}
|
|
|
|
func (s *serverDB) CheckAndSetLastSeq(user owner.Identity, seq uint32) (isReplay bool, err error) {
|
|
s.lastSeqsMu.Lock()
|
|
defer s.lastSeqsMu.Unlock()
|
|
|
|
last, exists := s.lastSeqs[user]
|
|
if exists && seq <= last {
|
|
return true, nil
|
|
}
|
|
|
|
s.lastSeqs[user] = seq
|
|
return false, nil
|
|
}
|
|
|
|
func (s *serverDB) GetLastReq(user owner.Identity) ([]byte, bool) {
|
|
s.lastReqsMu.Lock()
|
|
defer s.lastReqsMu.Unlock()
|
|
|
|
req, exists := s.lastReqs[user]
|
|
return req, exists
|
|
}
|
|
|
|
// just ip save
|
|
func (s *serverDB) SetLastIp(user owner.Identity, ip *net.UDPAddr) {
|
|
s.lastIpsMu.Lock()
|
|
defer s.lastIpsMu.Unlock()
|
|
|
|
s.lastIps[user] = ip
|
|
}
|
|
|
|
func (s *serverDB) GetLastIp(user owner.Identity) (*net.UDPAddr, bool) {
|
|
s.lastIpsMu.Lock()
|
|
defer s.lastIpsMu.Unlock()
|
|
|
|
ip, exists := s.lastIps[user]
|
|
return ip, exists
|
|
}
|
|
|
|
// --- close ---
|
|
func (s *serverDB) Close() error { return s.db.Close() }
|