whspbrd-final/sfudp/sfudp.go
2026-05-02 22:09:19 +02:00

244 lines
4.7 KiB
Go

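// Package sfudp implements SFUDP ("Safely Fragmented UDP"): a wrapper around
// net.UDPConn that transparently splits large messages into fragments on send
// and reassembles them on receive.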
package sfudp
// Roughly 80% of this code is AI-generated.
// The design is human-authored; the code has not been human-reviewed.
import (
"encoding/binary"
"errors"
"net"
"sync"
"sync/atomic"
"time"
)
// Wire-format limits and timing for SFUDP.
const (
	// MaxFragmentSize is the largest datagram, header included, that SFUDP sends.
	MaxFragmentSize = 1 << 10
	// HeaderSize is the fixed per-datagram header:
	// [4B ID][2B Index][1B Count][1B Reserved].
	HeaderSize = 4 + 2 + 1 + 1
	// MaxPayload is how many message bytes fit in a single datagram.
	MaxPayload = MaxFragmentSize - HeaderSize
	// MaxFragments caps the fragments per message; the count travels in one byte.
	MaxFragments = 1<<8 - 1
	// FragmentTTL is the minimum time a partially reassembled message is kept
	// after its latest fragment arrived; the collector sweeps every FragmentTTL/2.
	FragmentTTL = time.Second * 20
)
// fragmentSet is the reassembly state of one multi-fragment message.
type fragmentSet struct {
	// frags[i] holds fragment i's payload; it stays nil until that fragment arrives.
	frags [][]byte
	// received counts distinct fragments stored, total is the expected fragment
	// count, and totalSize is the sum of the stored payload lengths.
	received, total, totalSize int
	// timestamp records when the latest fragment for this message arrived;
	// the GC loop drops sets older than FragmentTTL.
	timestamp time.Time
}
// SFUDPConn wraps a *net.UDPConn, transparently fragmenting outgoing messages
// larger than MaxPayload and reassembling incoming fragmented messages.
type SFUDPConn struct {
	udp *net.UDPConn // underlying socket

	mu        sync.Mutex              // guards fragments
	fragments map[uint32]*fragmentSet // in-progress reassembly, keyed by message ID

	nextID  uint32        // last message ID handed out; advanced atomically
	stopGC  chan struct{} // closed by Close to stop gcLoop
	readBuf []byte        // datagram scratch buffer used by ReadFromSFUDP
}
// GetUDP returns the underlying *net.UDPConn. Datagrams written or read
// through it directly bypass SFUDP framing and reassembly.
func (c *SFUDPConn) GetUDP() *net.UDPConn {
	underlying := c.udp
	return underlying
}
// ListenSFUDP opens a UDP socket on laddr via net.ListenUDP and wraps it in an
// SFUDPConn. The returned conn runs a background collector goroutine until
// Close is called.
func ListenSFUDP(network string, laddr *net.UDPAddr) (*SFUDPConn, error) {
	conn, listenErr := net.ListenUDP(network, laddr)
	if listenErr != nil {
		return nil, listenErr
	}
	wrapped := newConn(conn)
	return wrapped, nil
}
// DialSFUDP creates a UDP socket connected to raddr via net.DialUDP and wraps
// it in an SFUDPConn, so Write can be used without a destination address.
// The returned conn runs a background collector goroutine until Close is called.
func DialSFUDP(network string, laddr, raddr *net.UDPAddr) (*SFUDPConn, error) {
	conn, dialErr := net.DialUDP(network, laddr, raddr)
	if dialErr != nil {
		return nil, dialErr
	}
	wrapped := newConn(conn)
	return wrapped, nil
}
// newConn wraps u in an SFUDPConn with empty reassembly state and a
// MaxFragmentSize read buffer, then starts gcLoop (stopped by Close).
func newConn(u *net.UDPConn) *SFUDPConn {
	conn := new(SFUDPConn)
	conn.udp = u
	conn.fragments = map[uint32]*fragmentSet{}
	conn.stopGC = make(chan struct{})
	conn.readBuf = make([]byte, MaxFragmentSize)
	go conn.gcLoop()
	return conn
}
// Close stops the background fragment collector and closes the underlying
// UDP socket. Calling Close more than once, or from several goroutines, does
// not panic: only the first call stops the collector, and every call returns
// whatever (*net.UDPConn).Close reports.
func (c *SFUDPConn) Close() error {
	// Closing an already-closed channel panics, so check-and-close stopGC
	// under c.mu to make it happen exactly once.
	c.mu.Lock()
	select {
	case <-c.stopGC:
		// Already stopped by an earlier Close.
	default:
		close(c.stopGC)
	}
	c.mu.Unlock()
	return c.udp.Close()
}
// gcLoop runs until stopGC is closed. Every FragmentTTL/2 it removes
// reassembly state for messages whose latest fragment arrived more than
// FragmentTTL ago.
func (c *SFUDPConn) gcLoop() {
	ticker := time.NewTicker(FragmentTTL / 2)
	defer ticker.Stop()
	for {
		select {
		case <-c.stopGC:
			return
		case <-ticker.C:
		}
		cutoff := time.Now()
		c.mu.Lock()
		for id, set := range c.fragments {
			if cutoff.Sub(set.timestamp) <= FragmentTTL {
				continue
			}
			delete(c.fragments, id)
		}
		c.mu.Unlock()
	}
}
// --- Writing ---
// Write sends data through the underlying socket's Write, fragmenting it when
// it exceeds MaxPayload. Because no destination is supplied, this presumably
// requires a connected socket (e.g. one from DialSFUDP).
func (c *SFUDPConn) Write(data []byte) (n int, err error) {
	n, err = c.writeFragments(data, nil)
	return n, err
}
// WriteToSFUDP sends data to addr, fragmenting it when it exceeds MaxPayload.
// A nil addr falls back to the underlying socket's Write.
func (c *SFUDPConn) WriteToSFUDP(data []byte, addr *net.UDPAddr) (n int, err error) {
	n, err = c.writeFragments(data, addr)
	return n, err
}
// writeFragments encodes data as one or more SFUDP datagrams and sends them,
// using WriteToUDP(addr) when addr is non-nil and the socket's Write otherwise.
//
// Every datagram is [4B ID][2B index][1B count][1B reserved=0] followed by at
// most MaxPayload payload bytes, integers little-endian. A message that fits in
// one datagram is sent with ID 0 and count 1; larger messages take the next ID
// from c.nextID.
//
// It returns the number of payload bytes whose datagrams were written
// successfully, plus the first write error. Empty data sends nothing. Messages
// needing more than MaxFragments datagrams fail with "message too large"
// before anything is sent.
func (c *SFUDPConn) writeFragments(data []byte, addr *net.UDPAddr) (int, error) {
	if len(data) == 0 {
		return 0, nil
	}
	fragCount := (len(data) + MaxPayload - 1) / MaxPayload
	if fragCount > MaxFragments {
		return 0, errors.New("message too large")
	}
	write := func(b []byte) (int, error) {
		if addr != nil {
			return c.udp.WriteToUDP(b, addr)
		}
		return c.udp.Write(b)
	}
	// ID 0 marks an unfragmented datagram (the receiver's fast path); only
	// multi-fragment messages consume an ID from the counter.
	var id uint32
	if fragCount > 1 {
		id = atomic.AddUint32(&c.nextID, 1)
	}
	// One scratch datagram is reused for every fragment: writers must not
	// retain the slice they are given, and the reserved byte pkt[7] is never
	// written, so it stays zero.
	pkt := make([]byte, MaxFragmentSize)
	sent := 0
	for i := 0; i < fragCount; i++ {
		chunk := data[i*MaxPayload:]
		if len(chunk) > MaxPayload {
			chunk = chunk[:MaxPayload]
		}
		binary.LittleEndian.PutUint32(pkt[0:4], id)
		binary.LittleEndian.PutUint16(pkt[4:6], uint16(i))
		pkt[6] = uint8(fragCount)
		n := copy(pkt[HeaderSize:], chunk)
		if _, err := write(pkt[:HeaderSize+n]); err != nil {
			// Count only fragments already on the wire, so a failed
			// single-datagram send reports 0 bytes rather than len(data).
			return sent, err
		}
		sent += n
	}
	return sent, nil
}
// --- Reading ---
// Read receives the next complete SFUDP message into b and returns its
// length, discarding the sender's address (see ReadFromSFUDP).
func (c *SFUDPConn) Read(b []byte) (n int, err error) {
	n, _, err = c.ReadFromSFUDP(b)
	return n, err
}
// ReadFromSFUDP blocks until a complete SFUDP message has been received,
// copies it into b, and returns its length and the address of the datagram
// that completed it.
//
// Datagrams shorter than HeaderSize, or whose header is invalid (count 0 or
// index >= count), are dropped silently. So is a fragment whose count
// disagrees with an in-progress message of the same ID. A complete message
// larger than len(b) is discarded and reported as "buffer too small". Errors
// from the underlying ReadFromUDP are returned unchanged.
//
// NOTE(review): c.readBuf is shared, so concurrent ReadFromSFUDP/Read calls on
// one conn are not supported. Reassembly is keyed by message ID only, and every
// conn's ID counter starts at 1. Fragments from two peers sending to one
// listener at the same time can therefore collide; consider keying by
// (addr, id).
func (c *SFUDPConn) ReadFromSFUDP(b []byte) (int, *net.UDPAddr, error) {
	for {
		n, addr, err := c.udp.ReadFromUDP(c.readBuf)
		if err != nil {
			return 0, nil, err
		}
		if n < HeaderSize {
			continue // too short to carry an SFUDP header
		}
		id := binary.LittleEndian.Uint32(c.readBuf[0:4])
		index := int(binary.LittleEndian.Uint16(c.readBuf[4:6]))
		count := int(c.readBuf[6])
		if count == 0 || index >= count {
			continue // malformed header
		}
		payload := c.readBuf[HeaderSize:n]

		// Unfragmented message: deliver directly without touching shared state.
		if id == 0 && count == 1 {
			if len(payload) > len(b) {
				return 0, addr, errors.New("buffer too small")
			}
			return copy(b, payload), addr, nil
		}

		frags, size, done := c.addFragment(id, index, count, payload)
		if !done {
			continue
		}
		if size > len(b) {
			return 0, addr, errors.New("buffer too small")
		}
		off := 0
		for _, f := range frags {
			off += copy(b[off:], f)
		}
		return off, addr, nil
	}
}

// addFragment records one header-validated fragment of message id while
// holding c.mu, and refreshes the message's timestamp.
//
// Once all count fragments are present, it removes the message from
// c.fragments and returns the fragments in index order, their combined length,
// and done=true. Otherwise done is false.
//
// A fragment whose count differs from the in-progress message's total is
// dropped: its index may lie beyond that message's slots and would panic.
func (c *SFUDPConn) addFragment(id uint32, index, count int, payload []byte) (frags [][]byte, size int, done bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	fs := c.fragments[id]
	switch {
	case fs == nil:
		fs = &fragmentSet{frags: make([][]byte, count), total: count}
		c.fragments[id] = fs
	case fs.total != count:
		return nil, 0, false
	}
	fs.timestamp = time.Now()

	if fs.frags[index] == nil {
		// payload aliases the shared read buffer, so keep a private copy.
		// make (not append to nil) keeps empty payloads non-nil, so they
		// still count as received.
		cp := make([]byte, len(payload))
		copy(cp, payload)
		fs.frags[index] = cp
		fs.received++
		fs.totalSize += len(cp)
	}
	if fs.received < fs.total {
		return nil, 0, false
	}
	delete(c.fragments, id)
	return fs.frags, fs.totalSize, true
}