244 lines
4.7 KiB
Go
244 lines
4.7 KiB
Go
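// Package sfudp implements SFUDP (Safely Fragmented UDP): a UDP connection that transparently fragments large outgoing messages and reassembles them on receipt.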
|
|
package sfudp
|
|
|
|
// About 80% of this code was AI-generated.
|
|
// Human-designed, but not human-reviewed.
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"errors"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
const (
	// MaxFragmentSize is the largest datagram, header included, that one
	// fragment occupies on the wire.
	MaxFragmentSize = 1024

	// HeaderSize is the fixed per-fragment header length. Layout, with
	// multi-byte fields little-endian:
	//
	//	[4B message ID][2B fragment index][1B fragment count][1B reserved]
	HeaderSize = 8

	// MaxPayload is how many message bytes a single fragment carries.
	MaxPayload = MaxFragmentSize - HeaderSize

	// MaxFragments caps the number of fragments per message, because the
	// count travels in a single header byte. The largest message is therefore
	// MaxFragments*MaxPayload bytes.
	MaxFragments = 255

	// FragmentTTL is how long an incomplete reassembly may go without
	// receiving a fragment before the background collector discards it.
	FragmentTTL = 20 * time.Second
)
|
|
|
|
// fragmentSet tracks the partial reassembly of one fragmented message.
type fragmentSet struct {
	// timestamp is refreshed whenever a fragment for this message arrives.
	// The garbage collector drops the set once it has been idle longer
	// than FragmentTTL.
	timestamp time.Time

	// frags has one slot per fragment, indexed by fragment index. A nil
	// slot means that fragment has not been received yet.
	frags [][]byte

	// total is the fragment count announced in the header. received counts
	// how many distinct slots of frags have been filled.
	total    int
	received int

	// totalSize is the combined length of every received fragment.
	totalSize int
}
|
|
|
|
// SFUDPConn = classic UDPConn + transparent fragmentation.
|
|
type SFUDPConn struct {
|
|
udp *net.UDPConn
|
|
fragments map[uint32]*fragmentSet
|
|
mu sync.Mutex
|
|
|
|
nextID uint32
|
|
stopGC chan struct{}
|
|
readBuf []byte
|
|
}
|
|
|
|
// GetUDP returns the wrapped *net.UDPConn. Data read from or written to it
// directly bypasses SFUDP fragmentation and headers.
func (c *SFUDPConn) GetUDP() *net.UDPConn {
	return c.udp
}
|
|
|
|
func ListenSFUDP(network string, laddr *net.UDPAddr) (*SFUDPConn, error) {
|
|
u, err := net.ListenUDP(network, laddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return newConn(u), nil
|
|
}
|
|
|
|
func DialSFUDP(network string, laddr, raddr *net.UDPAddr) (*SFUDPConn, error) {
|
|
u, err := net.DialUDP(network, laddr, raddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return newConn(u), nil
|
|
}
|
|
|
|
func newConn(u *net.UDPConn) *SFUDPConn {
|
|
c := &SFUDPConn{
|
|
udp: u,
|
|
fragments: make(map[uint32]*fragmentSet),
|
|
stopGC: make(chan struct{}),
|
|
readBuf: make([]byte, MaxFragmentSize),
|
|
}
|
|
go c.gcLoop()
|
|
return c
|
|
}
|
|
|
|
func (c *SFUDPConn) Close() error {
|
|
close(c.stopGC)
|
|
return c.udp.Close()
|
|
}
|
|
|
|
// periodic cleanup of stale fragment sets
|
|
func (c *SFUDPConn) gcLoop() {
|
|
t := time.NewTicker(FragmentTTL / 2)
|
|
defer t.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-t.C:
|
|
now := time.Now()
|
|
c.mu.Lock()
|
|
for id, fs := range c.fragments {
|
|
if now.Sub(fs.timestamp) > FragmentTTL {
|
|
delete(c.fragments, id)
|
|
}
|
|
}
|
|
c.mu.Unlock()
|
|
case <-c.stopGC:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// --- Writing ---
|
|
|
|
func (c *SFUDPConn) Write(data []byte) (int, error) {
|
|
return c.writeFragments(data, nil)
|
|
}
|
|
|
|
func (c *SFUDPConn) WriteToSFUDP(data []byte, addr *net.UDPAddr) (int, error) {
|
|
return c.writeFragments(data, addr)
|
|
}
|
|
|
|
func (c *SFUDPConn) writeFragments(data []byte, addr *net.UDPAddr) (int, error) {
|
|
if len(data) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
write := func(b []byte) (int, error) {
|
|
if addr != nil {
|
|
return c.udp.WriteToUDP(b, addr)
|
|
}
|
|
return c.udp.Write(b)
|
|
}
|
|
|
|
// single packet
|
|
if len(data) <= MaxPayload {
|
|
pkt := make([]byte, HeaderSize+len(data))
|
|
binary.LittleEndian.PutUint32(pkt[0:4], 0)
|
|
binary.LittleEndian.PutUint16(pkt[4:6], 0)
|
|
pkt[6] = 1
|
|
copy(pkt[HeaderSize:], data)
|
|
_, err := write(pkt)
|
|
return len(data), err
|
|
}
|
|
|
|
fragCount := (len(data) + MaxPayload - 1) / MaxPayload
|
|
if fragCount > MaxFragments {
|
|
return 0, errors.New("message too large")
|
|
}
|
|
|
|
id := atomic.AddUint32(&c.nextID, 1)
|
|
total := 0
|
|
|
|
for i := 0; i < fragCount; i++ {
|
|
start := i * MaxPayload
|
|
end := start + MaxPayload
|
|
if end > len(data) {
|
|
end = len(data)
|
|
}
|
|
|
|
pkt := make([]byte, HeaderSize+(end-start))
|
|
binary.LittleEndian.PutUint32(pkt[0:4], id)
|
|
binary.LittleEndian.PutUint16(pkt[4:6], uint16(i))
|
|
pkt[6] = uint8(fragCount)
|
|
|
|
copy(pkt[HeaderSize:], data[start:end])
|
|
|
|
if _, err := write(pkt); err != nil {
|
|
return total, err
|
|
}
|
|
total += end - start
|
|
}
|
|
|
|
return total, nil
|
|
}
|
|
|
|
// --- Reading ---
|
|
|
|
func (c *SFUDPConn) Read(b []byte) (int, error) {
|
|
n, _, err := c.ReadFromSFUDP(b)
|
|
return n, err
|
|
}
|
|
|
|
func (c *SFUDPConn) ReadFromSFUDP(b []byte) (int, *net.UDPAddr, error) {
|
|
for {
|
|
n, addr, err := c.udp.ReadFromUDP(c.readBuf)
|
|
if err != nil {
|
|
return 0, nil, err
|
|
}
|
|
if n < HeaderSize {
|
|
continue
|
|
}
|
|
|
|
h := c.readBuf[:HeaderSize]
|
|
id := binary.LittleEndian.Uint32(h[0:4])
|
|
index := binary.LittleEndian.Uint16(h[4:6])
|
|
count := h[6]
|
|
|
|
if count == 0 || int(index) >= int(count) {
|
|
continue
|
|
}
|
|
|
|
payload := c.readBuf[HeaderSize:n]
|
|
|
|
// single fragment
|
|
if id == 0 && count == 1 {
|
|
if len(payload) > len(b) {
|
|
return 0, addr, errors.New("buffer too small")
|
|
}
|
|
copy(b, payload)
|
|
return len(payload), addr, nil
|
|
}
|
|
|
|
c.mu.Lock()
|
|
fs := c.fragments[id]
|
|
if fs == nil {
|
|
fs = &fragmentSet{
|
|
frags: make([][]byte, count),
|
|
total: int(count),
|
|
timestamp: time.Now(),
|
|
}
|
|
c.fragments[id] = fs
|
|
}
|
|
fs.timestamp = time.Now()
|
|
c.mu.Unlock()
|
|
|
|
if fs.frags[index] == nil {
|
|
cp := make([]byte, len(payload))
|
|
copy(cp, payload)
|
|
fs.frags[index] = cp
|
|
fs.received++
|
|
fs.totalSize += len(cp)
|
|
}
|
|
|
|
if fs.received == fs.total {
|
|
if fs.totalSize > len(b) {
|
|
c.mu.Lock()
|
|
delete(c.fragments, id)
|
|
c.mu.Unlock()
|
|
return 0, addr, errors.New("buffer too small")
|
|
}
|
|
|
|
off := 0
|
|
for _, f := range fs.frags {
|
|
copy(b[off:], f)
|
|
off += len(f)
|
|
}
|
|
|
|
c.mu.Lock()
|
|
delete(c.fragments, id)
|
|
c.mu.Unlock()
|
|
|
|
return off, addr, nil
|
|
}
|
|
}
|
|
}
|