diff --git a/pkg/cmd/console.go b/pkg/cmd/console.go index 5f45198..f5db7bf 100644 --- a/pkg/cmd/console.go +++ b/pkg/cmd/console.go @@ -46,7 +46,7 @@ func (l *LoggerQueue) Error(message string) { } func (l *LoggerQueue) processLogs(ctx context.Context) { - ticker := time.NewTicker(10 * time.Millisecond) + ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() for { @@ -113,7 +113,7 @@ loop: Logger.Info(fmt.Sprintf("Unknown command: %s", input)) } input = "" - case keyboard.KeyBackspace | keyboard.KeyBackspace2: + case keyboard.KeyBackspace, keyboard.KeyBackspace2: if len(input) > 0 { input = input[:len(input)-1] } diff --git a/pkg/net/conn.go b/pkg/net/conn.go new file mode 100644 index 0000000..6bf18db --- /dev/null +++ b/pkg/net/conn.go @@ -0,0 +1,821 @@ +package net + +import ( + "bytes" + "cimeyclust.com/steel/pkg/net/packets" + "cimeyclust.com/steel/pkg/utils" + "context" + "fmt" + "net" + "sync" + "sync/atomic" + "time" +) + +const ( + // currentProtocol is the current RakNet protocol version. This is Minecraft + // specific. + currentProtocol byte = 11 + + maxMTUSize = 1400 + maxWindowSize = 2048 +) + +// Conn represents a connection to a specific client. It is not a real +// connection, as UDP is connectionless, but rather a connection emulated using +// RakNet. Methods may be called on Conn from multiple goroutines +// simultaneously. +type Conn struct { + // rtt is the last measured round-trip time between both ends of the + // connection. The rtt is measured in nanoseconds. + rtt atomic.Int64 + + closing atomic.Int64 + + conn net.PacketConn + addr net.Addr + limits bool + + once sync.Once + closed, connected chan struct{} + close func() + + mu sync.Mutex + buf *bytes.Buffer + + ackBuf, nackBuf *bytes.Buffer + + pk *packets.Packet + + seq, orderIndex, messageIndex utils.Uint24 + splitID uint32 + + // mtuSize is the MTU size of the connection. Packets longer than this size + // must be split into fragments for them to arrive at the client without + // losing bytes. + mtuSize uint16 + + // splits is a map of slices indexed by split IDs. The length of each of the + // slices is equal to the split count, and packets are positioned in that + // slice indexed by the split index. + splits map[uint16][][]byte + + // win is an ordered queue used to track which datagrams were received and + // which datagrams were missing, so that we can send NACKs to request + // missing datagrams. + win *datagramWindow + + ackMu sync.Mutex + // ackSlice is a slice containing sequence numbers of datagrams that were + // received over the last second. When ticked, all of these packets are sent + // in an ACK and the slice is cleared. + ackSlice []utils.Uint24 + + // packetQueue is an ordered queue containing packets indexed by their order + // index. + packetQueue *packetQueue + // packets is a channel containing content of packets that were fully + // processed. Calling Conn.Read() consumes a value from this channel. + packets chan *bytes.Buffer + + // retransmission is a queue filled with packets that were sent with a given + // datagram sequence number. + retransmission *resendMap + + // readDeadline is a channel that receives a time.Time after a specific + // time. It is used to listen for timeouts in Read after calling + // SetReadDeadline. + readDeadline <-chan time.Time + + lastActivity atomic.Pointer[time.Time] +} + +// newConn constructs a new connection specifically dedicated to the address +// passed. +func newConn(conn net.PacketConn, addr net.Addr, mtuSize uint16) *Conn { + return newConnWithLimits(conn, addr, mtuSize, true) +} + +// newConnWithLimits returns a Conn for the net.Addr passed with a specific mtu +// size. The limits bool passed specifies if the connection should limit the +// bounds of things such as the size of packets. This is generally recommended +// for connections coming from a client. +func newConnWithLimits(conn net.PacketConn, addr net.Addr, mtuSize uint16, limits bool) *Conn { + if mtuSize < 500 || mtuSize > 1500 { + mtuSize = maxMTUSize + } + c := &Conn{ + addr: addr, + conn: conn, + limits: limits, + mtuSize: mtuSize, + pk: new(packets.Packet), + closed: make(chan struct{}), + connected: make(chan struct{}), + packets: make(chan *bytes.Buffer, 512), + splits: make(map[uint16][][]byte), + win: newDatagramWindow(), + packetQueue: newPacketQueue(), + retransmission: newRecoveryQueue(), + buf: bytes.NewBuffer(make([]byte, 0, mtuSize)), + ackBuf: bytes.NewBuffer(make([]byte, 0, 256)), + nackBuf: bytes.NewBuffer(make([]byte, 0, 256)), + } + t := time.Now() + c.lastActivity.Store(&t) + go c.startTicking() + return c +} + +// startTicking makes the connection start ticking, sending ACKs and pings to +// the other end where necessary and checking if the connection should be timed +// out. +func (conn *Conn) startTicking() { + var ( + interval = time.Second / 10 + ticker = time.NewTicker(interval) + i int64 + acksLeft int + ) + defer ticker.Stop() + for { + select { + case t := <-ticker.C: + i++ + conn.flushACKs() + if i%2 == 0 { + // We send a connected ping to calculate the rtt and let the + // other side know we haven't timed out. + conn.sendPing() + } + if i%3 == 0 { + conn.checkResend(t) + } + if i%5 == 0 { + conn.mu.Lock() + if t.Sub(*conn.lastActivity.Load()) > time.Second*5+conn.retransmission.rtt()*2 { + // No activity for too long: Start timeout. + _ = conn.Close() + } + conn.mu.Unlock() + } + if unix := conn.closing.Load(); unix != 0 { + before := acksLeft + conn.mu.Lock() + acksLeft = len(conn.retransmission.unacknowledged) + conn.mu.Unlock() + + if before != 0 && acksLeft == 0 { + _ = conn.Close() + } + + since := time.Since(time.Unix(unix, 0)) + if (acksLeft == 0 && since > time.Second) || since > time.Second*8 { + conn.closeImmediately() + } + } + case <-conn.closed: + return + } + } +} + +// flushACKs flushes all pending datagram acknowledgements. +func (conn *Conn) flushACKs() { + conn.ackMu.Lock() + defer conn.ackMu.Unlock() + + if len(conn.ackSlice) > 0 { + // Write an ACK packet to the connection containing all datagram + // sequence numbers that we received since the last tick. + if err := conn.sendACK(conn.ackSlice...); err != nil { + return + } + conn.ackSlice = conn.ackSlice[:0] + } +} + +// checkResend checks if the connection needs to resend any packets. It sends +// an ACK for packets it has received and sends any packets that have been +// pending for too long. +func (conn *Conn) checkResend(now time.Time) { + conn.mu.Lock() + defer conn.mu.Unlock() + + var ( + resend []utils.Uint24 + rtt = conn.retransmission.rtt() + delay = rtt + rtt/2 + ) + conn.rtt.Store(int64(rtt)) + + for seq, t := range conn.retransmission.unacknowledged { + // These packets have not been acknowledged for too long: We resend them + // by ourselves, even though no NACK has been issued yet. + if now.Sub(t.timestamp) > delay { + resend = append(resend, seq) + } + } + _ = conn.resend(resend) +} + +// Write writes a buffer b over the RakNet connection. The amount of bytes +// written n is always equal to the length of the bytes written if writing was +// successful. If not, an error is returned and n is 0. Write may be called +// simultaneously from multiple goroutines, but will write one by one. +func (conn *Conn) Write(b []byte) (n int, err error) { + select { + case <-conn.closed: + return 0, conn.wrap(net.ErrClosed, "write") + default: + conn.mu.Lock() + defer conn.mu.Unlock() + n, err := conn.write(b) + return n, conn.wrap(err, "write") + } +} + +// write writes a buffer b over the RakNet connection. The amount of bytes +// written n is always equal to the length of the bytes written if the write +// was successful. If not, an error is returned and n is 0. Write may be called +// simultaneously from multiple goroutines, but will write one by one. Unlike +// Write, write will not lock. +func (conn *Conn) write(b []byte) (n int, err error) { + fragments := conn.split(b) + orderIndex := conn.orderIndex + conn.orderIndex++ + + splitID := uint16(conn.splitID) + split := len(fragments) > 1 + if split { + conn.splitID++ + } + for splitIndex, content := range fragments { + sequenceNumber := conn.seq + conn.seq++ + messageIndex := conn.messageIndex + conn.messageIndex++ + + conn.buf.WriteByte(packets.BitFlagDatagram | packets.BitFlagNeedsBAndAS) + utils.WriteUint24(conn.buf, sequenceNumber) + pk := packetPool.Get().(*packets.Packet) + if cap(pk.Content) < len(content) { + pk.Content = make([]byte, len(content)) + } + // We set the actual slice size to the same size as the content. It + // might be bigger than the previous size, in which case it will grow, + // which is fine as the underlying array will always be big enough. + pk.Content = pk.Content[:len(content)] + copy(pk.Content, content) + + pk.OrderIndex = orderIndex + pk.MessageIndex = messageIndex + + pk.Split = split + if split { + // If there were more than one fragment, the pk was split, so we + // need to make sure we set the appropriate fields. + pk.SplitCount = uint32(len(fragments)) + pk.SplitIndex = uint32(splitIndex) + pk.SplitID = splitID + } + pk.Write(conn.buf) + // We then send the pk to the connection. + if _, err := conn.conn.WriteTo(conn.buf.Bytes(), conn.addr); err != nil { + return 0, net.ErrClosed + } + + // We reset the buffer so that we can re-use it for each fragment + // created when splitting the packet. + conn.buf.Reset() + + // Finally we add the pk to the recovery queue. + conn.retransmission.add(sequenceNumber, pk) + n += len(content) + } + return +} + +// Read reads from the connection into the byte slice passed. If successful, +// the amount of bytes read n is returned, and the error returned will be nil. +// Read blocks until a packet is received over the connection, or until the +// session is closed or the read times out, in which case an error is returned. +func (conn *Conn) Read(b []byte) (n int, err error) { + select { + case pk := <-conn.packets: + if len(b) < pk.Len() { + err = conn.wrap(errBufferTooSmall, "read") + } + return copy(b, pk.Bytes()), err + case <-conn.closed: + return 0, conn.wrap(net.ErrClosed, "read") + case <-conn.readDeadline: + return 0, conn.wrap(context.DeadlineExceeded, "read") + } +} + +// ReadPacket attempts to read the next packet as a byte slice. ReadPacket +// blocks until a packet is received over the connection, or until the session +// is closed or the read times out, in which case an error is returned. +func (conn *Conn) ReadPacket() (b []byte, err error) { + select { + case packet := <-conn.packets: + return packet.Bytes(), err + case <-conn.closed: + return nil, conn.wrap(net.ErrClosed, "read") + case <-conn.readDeadline: + return nil, conn.wrap(context.DeadlineExceeded, "read") + } +} + +// Close closes the connection. All blocking Read or Write actions are +// cancelled and will return an error, as soon as the closing of the connection +// is acknowledged by the client. +func (conn *Conn) Close() error { + conn.closing.CompareAndSwap(0, time.Now().Unix()) + return nil +} + +// closeImmediately sends a Disconnect notification to the other end of the +// connection and closes the underlying UDP connection immediately. +func (conn *Conn) closeImmediately() { + conn.once.Do(func() { + _, _ = conn.Write([]byte{packets.IDDisconnectNotification}) + close(conn.closed) + if conn.close != nil { + conn.close() + conn.close = nil + } + }) +} + +// RemoteAddr returns the remote address of the connection, meaning the address +// this connection leads to. +func (conn *Conn) RemoteAddr() net.Addr { + return conn.addr +} + +// LocalAddr returns the local address of the connection, which is always the +// same as the listener's. +func (conn *Conn) LocalAddr() net.Addr { + return conn.conn.LocalAddr() +} + +// SetReadDeadline sets the read deadline of the connection. An error is +// returned only if the time passed is before time.Now(). Calling +// SetReadDeadline means the next Read call that exceeds the deadline will fail +// and return an error. Setting the read deadline to the default value of +// time.Time removes the deadline. +func (conn *Conn) SetReadDeadline(t time.Time) error { + if t.IsZero() { + conn.readDeadline = make(chan time.Time) + return nil + } + if t.Before(time.Now()) { + panic(fmt.Errorf("read deadline cannot be before now")) + } + conn.readDeadline = time.After(time.Until(t)) + return nil +} + +// SetWriteDeadline has no behaviour. It is merely there to satisfy the +// net.Conn interface. +func (conn *Conn) SetWriteDeadline(time.Time) error { + return nil +} + +// SetDeadline sets the deadline of the connection for both Read and Write. +// SetDeadline is equivalent to calling both SetReadDeadline and +// SetWriteDeadline. +func (conn *Conn) SetDeadline(t time.Time) error { + return conn.SetReadDeadline(t) +} + +// Latency returns a rolling average of rtt between the sending and the +// receiving end of the connection. The rtt returned is updated continuously +// and is half the average round trip time (RTT). +func (conn *Conn) Latency() time.Duration { + return time.Duration(conn.rtt.Load() / 2) +} + +// sendPing pings the connection, updating the rtt of the Conn if successful. +func (conn *Conn) sendPing() { + b := bytes.NewBuffer(nil) + (&packets.ConnectedPing{ClientTimestamp: timestamp()}).Write(b) + _, _ = conn.Write(b.Bytes()) +} + +// packetPool is a sync.Pool used to pool packets that encapsulate their +// content. +var packetPool = sync.Pool{ + New: func() interface{} { + return &packets.Packet{Reliability: packets.ReliabilityReliableOrdered} + }, +} + +const ( + // Datagram header + + // Datagram sequence number + + // Packet header + + // Packet content length + + // Packet message index + + // Packet order index + + // Packet order channel + packetAdditionalSize = 1 + 3 + 1 + 2 + 3 + 3 + 1 + // Packet split count + + // Packet split ID + + // Packet split index + splitAdditionalSize = 4 + 2 + 4 +) + +// split splits a content buffer in smaller buffers so that they do not exceed +// the MTU size that the connection holds. +func (conn *Conn) split(b []byte) [][]byte { + maxSize := int(conn.mtuSize-packetAdditionalSize) - 28 + contentLength := len(b) + if contentLength > maxSize { + // If the content size is bigger than the maximum size here, it means + // the packet will get split. This means that the packet will get even + // bigger because a split packet uses 4 + 2 + 4 more bytes. + maxSize -= splitAdditionalSize + } + fragmentCount := contentLength / maxSize + if contentLength%maxSize != 0 { + // If the content length can't be divided by maxSize perfectly, we need + // to reserve another fragment for the last bit of the packet. + fragmentCount++ + } + fragments := make([][]byte, fragmentCount) + + buf := bytes.NewBuffer(b) + for i := 0; i < fragmentCount; i++ { + // Take a piece out of the content with the size of maxSize. + fragments[i] = buf.Next(maxSize) + } + return fragments +} + +// receive receives a packet from the connection, handling it as appropriate. +// If not successful, an error is returned. +func (conn *Conn) receive(b *bytes.Buffer) error { + headerFlags, err := b.ReadByte() + if err != nil { + return fmt.Errorf("error reading datagram header flags: %v", err) + } + if headerFlags&packets.BitFlagDatagram == 0 { + // Ignore packets that do not have the datagram bitflag. + return nil + } + t := time.Now() + conn.lastActivity.Store(&t) + switch { + case headerFlags&packets.BitFlagACK != 0: + return conn.handleACK(b) + case headerFlags&packets.BitFlagNACK != 0: + return conn.handleNACK(b) + default: + return conn.receiveDatagram(b) + } +} + +// receiveDatagram handles the receiving of a datagram found in buffer b. If +// successful, all packets inside the datagram are handled. if not, an error is +// returned. +func (conn *Conn) receiveDatagram(b *bytes.Buffer) error { + seq, err := utils.ReadUint24(b) + if err != nil { + return fmt.Errorf("error reading datagram sequence number: %v", err) + } + conn.ackMu.Lock() + // Add this sequence number to the received datagrams, so that it is + // included in an ACK. + conn.ackSlice = append(conn.ackSlice, seq) + conn.ackMu.Unlock() + + if !conn.win.new(seq) { + // Datagram was already received, this might happen if a packet took a long time to arrive, and we already sent + // a NACK for it. This is expected to happen sometimes under normal circumstances, so no reason to return an + // error. + return nil + } + conn.win.add(seq) + if conn.win.shift() == 0 { + // Datagram window couldn't be shifted up, so we're still missing + // packets. + rtt := time.Duration(conn.rtt.Load()) + if missing := conn.win.missing(rtt + rtt/2); len(missing) > 0 { + if err = conn.sendNACK(missing); err != nil { + return fmt.Errorf("error sending NACK to request datagrams: %v", err) + } + } + } + if conn.win.size() > maxWindowSize && conn.limits { + return fmt.Errorf("datagram receive queue window size is too big (%v-%v)", conn.win.lowest, conn.win.highest) + } + return conn.handleDatagram(b) +} + +// handleDatagram handles the contents of a datagram encoded in a bytes.Buffer. +func (conn *Conn) handleDatagram(b *bytes.Buffer) error { + for b.Len() > 0 { + if err := conn.pk.Read(b); err != nil { + return fmt.Errorf("error decoding datagram packet: %v", err) + } + handle := conn.receivePacket + if conn.pk.Split { + handle = conn.receiveSplitPacket + } + if err := handle(conn.pk); err != nil { + return fmt.Errorf("error handling packet in datagram: %v", err) + } + } + return nil +} + +// receivePacket handles the receiving of a packet. It puts the packet in the +// queue and takes out all packets that were obtainable after that, and handles +// them. +func (conn *Conn) receivePacket(packet *packets.Packet) error { + if packet.Reliability != packets.ReliabilityReliableOrdered { + // If it isn't a reliable ordered packet, handle it immediately. + return conn.handlePacket(packet.Content) + } + if !conn.packetQueue.put(packet.OrderIndex, packet.Content) { + // An ordered packet arrived twice. + return nil + } + if conn.packetQueue.WindowSize() > maxWindowSize && conn.limits { + return fmt.Errorf("packet queue window size is too big (%v-%v)", conn.packetQueue.lowest, conn.packetQueue.highest) + } + for _, content := range conn.packetQueue.fetch() { + if err := conn.handlePacket(content); err != nil { + return fmt.Errorf("error handling packet: %v", err) + } + } + return nil +} + +// handlePacket handles a packet serialised in byte slice b. If not successful, +// an error is returned. If the packet was not handled by RakNet, it is sent to +// the packet channel. +func (conn *Conn) handlePacket(b []byte) error { + buffer := bytes.NewBuffer(b) + id, err := buffer.ReadByte() + if err != nil { + return fmt.Errorf("error reading packet ID: %v", err) + } + + switch id { + case packets.IDConnectionRequest: + return conn.handleConnectionRequest(buffer) + case packets.IDConnectionRequestAccepted: + return conn.handleConnectionRequestAccepted(buffer) + case packets.IDNewIncomingConnection: + select { + case <-conn.connected: + default: + close(conn.connected) + } + case packets.IDConnectedPing: + return conn.handleConnectedPing(buffer) + case packets.IDConnectedPong: + return conn.handleConnectedPong(buffer) + case packets.IDDisconnectNotification: + conn.closeImmediately() + case packets.IDDetectLostConnections: + // Let the other end know the connection is still alive. + conn.sendPing() + default: + _ = buffer.UnreadByte() + // Insert the packet contents the packet queue could release in the + // channel so that Conn.Read() can get a hold of them, but always first + // try to escape if the connection was closed. + select { + case <-conn.closed: + case conn.packets <- buffer: + } + } + return nil +} + +// handleConnectedPing handles a connected ping packet inside of buffer b. An +// error is returned if the packet was invalid. +func (conn *Conn) handleConnectedPing(b *bytes.Buffer) error { + packet := &packets.ConnectedPing{} + if err := packet.Read(b); err != nil { + return fmt.Errorf("error reading connected ping: %v", err) + } + b.Reset() + + // Respond with a connected pong that has the ping timestamp found in the + // connected ping, and our own timestamp for the pong timestamp. + (&packets.ConnectedPong{ClientTimestamp: packet.ClientTimestamp, ServerTimestamp: timestamp()}).Write(b) + _, err := conn.Write(b.Bytes()) + return err +} + +// handleConnectedPong handles a connected pong packet inside of buffer b. An +// error is returned if the packet was invalid. +func (conn *Conn) handleConnectedPong(b *bytes.Buffer) error { + packet := &packets.ConnectedPong{} + if err := packet.Read(b); err != nil { + return fmt.Errorf("error reading connected pong: %v", err) + } + if packet.ClientTimestamp > timestamp() { + return fmt.Errorf("error measuring rtt: ping timestamp is in the future") + } + // We don't actually use the ConnectedPong to measure rtt. It is too + // unreliable and doesn't give a good idea of the connection quality. + return nil +} + +// handleConnectionRequest handles a connection request packet inside of buffer +// b. An error is returned if the packet was invalid. +func (conn *Conn) handleConnectionRequest(b *bytes.Buffer) error { + packet := &packets.ConnectionRequest{} + if err := packet.Read(b); err != nil { + return fmt.Errorf("error reading connection request: %v", err) + } + b.Reset() + (&packets.ConnectionRequestAccepted{ClientAddress: *conn.addr.(*net.UDPAddr), RequestTimestamp: packet.RequestTimestamp, AcceptedTimestamp: timestamp()}).Write(b) + _, err := conn.Write(b.Bytes()) + return err +} + +// handleConnectionRequestAccepted handles a serialised connection request +// accepted packet in b, and returns an error if not successful. +func (conn *Conn) handleConnectionRequestAccepted(b *bytes.Buffer) error { + packet := &packets.ConnectionRequestAccepted{} + _ = packet.Read(b) + b.Reset() + + (&packets.NewIncomingConnection{ServerAddress: *conn.addr.(*net.UDPAddr), RequestTimestamp: packet.RequestTimestamp, AcceptedTimestamp: packet.AcceptedTimestamp, SystemAddresses: packet.SystemAddresses}).Write(b) + _, err := conn.Write(b.Bytes()) + + select { + case <-conn.connected: + default: + close(conn.connected) + } + return err +} + +// receiveSplitPacket handles a passed split packet. If it is the last split +// packet of its sequence, it will continue handling the full packet as it +// otherwise would. An error is returned if the packet was not valid. +func (conn *Conn) receiveSplitPacket(p *packets.Packet) error { + const maxSplitCount = 256 + if (p.SplitCount > maxSplitCount || len(conn.splits) > maxSplitCount) && conn.limits { + return fmt.Errorf("split count %v (%v active) exceeds the maximum %v", p.SplitCount, len(conn.splits), maxSplitCount) + } + m, ok := conn.splits[p.SplitID] + if !ok { + m = make([][]byte, p.SplitCount) + conn.splits[p.SplitID] = m + } + if p.SplitIndex > uint32(len(m)-1) { + // The split index was either negative or was bigger than the slice + // size, meaning the packet is invalid. + return fmt.Errorf("error handing split packet: split index %v is out of range (0 - %v)", p.SplitIndex, len(m)-1) + } + m[p.SplitIndex] = p.Content + + size := 0 + for _, fragment := range m { + if len(fragment) == 0 { + // We haven't yet received all split fragments, so we cannot add the packets together yet. + return nil + } + // First we calculate the total size required to hold the content of the + // combined content. + size += len(fragment) + } + + content := make([]byte, 0, size) + for _, fragment := range m { + content = append(content, fragment...) + } + + delete(conn.splits, p.SplitID) + + p.Content = content + return conn.receivePacket(p) +} + +// sendACK sends an acknowledgement packet containing the packet sequence +// numbers passed. If not successful, an error is returned. +func (conn *Conn) sendACK(received ...utils.Uint24) error { + defer conn.ackBuf.Reset() + return conn.sendAcknowledgement(received, packets.BitFlagACK, conn.ackBuf) +} + +// sendNACK sends an acknowledgement packet containing the packet sequence +// numbers passed. If not successful, an error is returned. +func (conn *Conn) sendNACK(received []utils.Uint24) error { + defer conn.nackBuf.Reset() + return conn.sendAcknowledgement(received, packets.BitFlagNACK, conn.nackBuf) +} + +// sendAcknowledgement sends an acknowledgement packet with the packets passed, +// potentially sending multiple if too many packets are passed. The bitflag is +// added to the header byte. +func (conn *Conn) sendAcknowledgement(received []utils.Uint24, bitflag byte, buf *bytes.Buffer) error { + ack := &packets.Acknowledgement{Packets: received} + + for len(ack.Packets) != 0 { + buf.WriteByte(bitflag | packets.BitFlagDatagram) + n, err := ack.Write(buf, conn.mtuSize) + if err != nil { + panic(fmt.Sprintf("error encoding ACK packet: %v", err)) + } + // We managed to write n packets in the ACK with this MTU size, write + // the next of the packets in a new ACK. + ack.Packets = ack.Packets[n:] + if _, err := conn.conn.WriteTo(buf.Bytes(), conn.addr); err != nil { + return fmt.Errorf("error sending ACK packet: %v", err) + } + buf.Reset() + } + return nil +} + +// handleACK handles an acknowledgement packet from the other end of the +// connection. These mean that a datagram was successfully received by the +// other end. +func (conn *Conn) handleACK(b *bytes.Buffer) error { + conn.mu.Lock() + defer conn.mu.Unlock() + + ack := &packets.Acknowledgement{} + if err := ack.Read(b); err != nil { + return fmt.Errorf("error reading ACK: %v", err) + } + for _, sequenceNumber := range ack.Packets { + // Take out all stored packets from the recovery queue. + p, ok := conn.retransmission.acknowledge(sequenceNumber) + if ok { + // Clear the packet and return it to the pool so that it may be + // re-used. + p.Content = nil + packetPool.Put(p) + } + } + return nil +} + +// handleNACK handles a negative acknowledgment packet from the other end of +// the connection. These mean that a datagram was found missing. +func (conn *Conn) handleNACK(b *bytes.Buffer) error { + conn.mu.Lock() + defer conn.mu.Unlock() + + nack := &packets.Acknowledgement{} + if err := nack.Read(b); err != nil { + return fmt.Errorf("error reading NACK: %v", err) + } + return conn.resend(nack.Packets) +} + +// resend sends all datagrams currently in the recovery queue with the sequence +// numbers passed. +func (conn *Conn) resend(sequenceNumbers []utils.Uint24) (err error) { + for _, sequenceNumber := range sequenceNumbers { + pk, ok := conn.retransmission.retransmit(sequenceNumber) + if !ok { + // We could not resend this datagram. Maybe it was already resent + // before at the request of the client. This is generally expected + // so we just continue. + continue + } + + // We first write a new datagram header using a new send sequence number + // that we find. + if err := conn.buf.WriteByte(packets.BitFlagDatagram | packets.BitFlagNeedsBAndAS); err != nil { + return fmt.Errorf("error writing recovered datagram header: %v", err) + } + newSeqNum := conn.seq + conn.seq++ + utils.WriteUint24(conn.buf, newSeqNum) + pk.Write(conn.buf) + + // We then send the pk to the connection. + if _, err := conn.conn.WriteTo(conn.buf.Bytes(), conn.addr); err != nil { + return fmt.Errorf("error sending pk to addr %v: %v", conn.addr, err) + } + // We then re-add the pk to the recovery queue in case the new one gets + // lost too, in which case we need to resend it again. + conn.retransmission.add(newSeqNum, pk) + conn.buf.Reset() + } + return nil +} + +// requestConnection requests the connection from the server, provided this +// connection operates as a client. An error occurs if the request was not +// successful. +func (conn *Conn) requestConnection(id int64) error { + b := bytes.NewBuffer(nil) + (&packets.ConnectionRequest{ClientGUID: id, RequestTimestamp: timestamp()}).Write(b) + _, err := conn.Write(b.Bytes()) + return err +} diff --git a/pkg/net/datagram_window.go b/pkg/net/datagram_window.go new file mode 100644 index 0000000..6fe54ca --- /dev/null +++ b/pkg/net/datagram_window.go @@ -0,0 +1,81 @@ +package net + +import ( + "cimeyclust.com/steel/pkg/utils" + "time" +) + +// datagramWindow is a queue for incoming datagrams. +type datagramWindow struct { + lowest, highest utils.Uint24 + queue map[utils.Uint24]time.Time +} + +// newDatagramWindow returns a new initialised datagram window. +func newDatagramWindow() *datagramWindow { + return &datagramWindow{queue: make(map[utils.Uint24]time.Time)} +} + +// new checks if the index passed is new to the datagramWindow. +func (win *datagramWindow) new(index utils.Uint24) bool { + if index < win.lowest { + return true + } + _, ok := win.queue[index] + return !ok +} + +// add puts an index in the window. +func (win *datagramWindow) add(index utils.Uint24) { + if index >= win.highest { + win.highest = index + 1 + } + win.queue[index] = time.Now() +} + +// shift attempts to delete as many indices from the queue as possible, +// increasing the lowest index if and when possible. +func (win *datagramWindow) shift() (n int) { + var index utils.Uint24 + for index = win.lowest; index < win.highest; index++ { + if _, ok := win.queue[index]; !ok { + break + } + delete(win.queue, index) + n++ + } + win.lowest = index + return n +} + +// missing returns a slice of all indices in the datagram queue that weren't +// set using add while within the window of lowest and highest index. The queue +// is shifted after this call. +func (win *datagramWindow) missing(since time.Duration) (indices []utils.Uint24) { + var ( + missing = false + ) + for index := int(win.highest) - 1; index >= int(win.lowest); index-- { + i := utils.Uint24(index) + t, ok := win.queue[i] + if ok { + if time.Since(t) >= since { + // All packets before this one took too long to arrive, so we + // mark them as missing. + missing = true + } + continue + } + if missing { + indices = append(indices, i) + win.queue[i] = time.Time{} + } + } + win.shift() + return indices +} + +// size returns the size of the datagramWindow. +func (win *datagramWindow) size() utils.Uint24 { + return win.highest - win.lowest +} diff --git a/pkg/net/err.go b/pkg/net/err.go new file mode 100644 index 0000000..76efd31 --- /dev/null +++ b/pkg/net/err.go @@ -0,0 +1,36 @@ +package net + +import ( + "errors" + "net" + "strings" +) + +var ( + errBufferTooSmall = errors.New("a message sent was larger than the buffer used to receive the message into") + errListenerClosed = errors.New("use of closed listener") +) + +// ErrConnectionClosed checks if the error passed was an error caused by +// reading from a Conn of which the connection was closed. +func ErrConnectionClosed(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), net.ErrClosed.Error()) +} + +// wrap wraps the error passed into a net.OpError with the op as operation and +// returns it, or nil if the error passed is nil. +func (conn *Conn) wrap(err error, op string) error { + if err == nil { + return nil + } + return &net.OpError{ + Op: op, + Net: "raknet", + Source: conn.LocalAddr(), + Addr: conn.RemoteAddr(), + Err: err, + } +} diff --git a/pkg/net/listener.go b/pkg/net/listener.go new file mode 100644 index 0000000..95de00c --- /dev/null +++ b/pkg/net/listener.go @@ -0,0 +1,311 @@ +package net + +import ( + "bytes" + "cimeyclust.com/steel/pkg/net/packets" + "fmt" + "log/slog" + "math" + "math/rand" + "net" + "sync" + "sync/atomic" + "time" +) + +// UpstreamPacketListener allows for a custom PacketListener implementation. +type UpstreamPacketListener interface { + ListenPacket(network, address string) (net.PacketConn, error) +} + +// ListenConfig may be used to pass additional configuration to a Listener. +type ListenConfig struct { + // ErrorLog is a logger that errors from packet decoding are logged to. It + // may be set to a logger that simply discards the messages. The default + // value is slog.Default(). + ErrorLog *slog.Logger + + // UpstreamPacketListener adds an abstraction for net.ListenPacket. + UpstreamPacketListener UpstreamPacketListener +} + +// Listener implements a RakNet connection listener. It follows the same +// methods as those implemented by the TCPListener in the net package. Listener +// implements the net.Listener interface. +type Listener struct { + once sync.Once + closed chan struct{} + + // log is a logger that errors from packet decoding are logged to. It may be + // set to a logger that simply discards the messages. + log *slog.Logger + + conn net.PacketConn + // incoming is a channel of incoming connections. Connections that end up in + // here will also end up in the connections map. + incoming chan *Conn + + // connections is a map of currently active connections, indexed by their + // address. + connections sync.Map + + // id is a random server ID generated upon starting listening. It is used + // several times throughout the connection sequence of RakNet. + id int64 + + // pongData is a byte slice of data that is sent in an unconnected pong + // packet each time the client sends and unconnected ping to the server. + pongData atomic.Pointer[[]byte] +} + +// listenerID holds the next ID to use for a Listener. +var listenerID atomic.Int64 + +func init() { + listenerID.Store(rand.New(rand.NewSource(time.Now().Unix())).Int63()) +} + +// Listen listens on the address passed and returns a listener that may be used +// to accept connections. If not successful, an error is returned. The address +// follows the same rules as those defined in the net.TCPListen() function. +// Specific features of the listener may be modified once it is returned, such +// as the used log and/or the accepted protocol. +func (l ListenConfig) Listen(address string) (*Listener, error) { + var conn net.PacketConn + var err error + + if l.UpstreamPacketListener == nil { + conn, err = net.ListenPacket("udp", address) + } else { + conn, err = l.UpstreamPacketListener.ListenPacket("udp", address) + } + if err != nil { + return nil, &net.OpError{Op: "listen", Net: "raknet", Source: nil, Addr: nil, Err: err} + } + listener := &Listener{ + conn: conn, + incoming: make(chan *Conn), + closed: make(chan struct{}), + log: slog.Default(), + id: listenerID.Add(1), + } + if l.ErrorLog != nil { + listener.log = l.ErrorLog + } + + go listener.listen() + return listener, nil +} + +// Listen listens on the address passed and returns a listener that may be used +// to accept connections. If not successful, an error is returned. The address +// follows the same rules as those defined in the net.TCPListen() function. +// Specific features of the listener may be modified once it is returned, such +// as the used log and/or the accepted protocol. +func Listen(address string) (*Listener, error) { + var lc ListenConfig + return lc.Listen(address) +} + +// Accept blocks until a connection can be accepted by the listener. If +// successful, Accept returns a connection that is ready to send and receive +// data. If not successful, a nil listener is returned and an error describing +// the problem. +func (listener *Listener) Accept() (net.Conn, error) { + conn, ok := <-listener.incoming + if !ok { + return nil, &net.OpError{Op: "accept", Net: "raknet", Source: nil, Addr: nil, Err: errListenerClosed} + } + return conn, nil +} + +// Addr returns the address the Listener is bound to and listening for +// connections on. +func (listener *Listener) Addr() net.Addr { + return listener.conn.LocalAddr() +} + +// Close closes the listener so that it may be cleaned up. It makes sure the +// goroutine handling incoming packets is able to be freed. +func (listener *Listener) Close() error { + var err error + listener.once.Do(func() { + close(listener.closed) + err = listener.conn.Close() + }) + return err +} + +// PongData sets the pong data that is used to respond with when a client sends +// a ping. It usually holds game specific data that is used to display in a +// server list. If a data slice is set with a size bigger than math.MaxInt16, +// the function panics. +func (listener *Listener) PongData(data []byte) { + if len(data) > math.MaxInt16 { + panic(fmt.Sprintf("error setting pong data: pong data must not be longer than %v", math.MaxInt16)) + } + listener.pongData.Store(&data) +} + +// ID returns the unique ID of the listener. This ID is usually used by a +// client to identify a specific server during a single session. +func (listener *Listener) ID() int64 { + return listener.id +} + +// listen continuously reads from the listener's UDP connection, until closed +// has a value in it. +func (listener *Listener) listen() { + // Create a buffer with the maximum size a UDP packet sent over RakNet is + // allowed to have. We can re-use this buffer for each packet. + b := make([]byte, 1500) + buf := bytes.NewBuffer(b[:0]) + for { + n, addr, err := listener.conn.ReadFrom(b) + if err != nil { + close(listener.incoming) + return + } + _, _ = buf.Write(b[:n]) + + // Technically we should not re-use the same byte slice after its + // ownership has been taken by the buffer, but we can do this anyway + // because we copy the data later. + if err := listener.handle(buf, addr); err != nil { + listener.log.Error("listener: handle packet: "+err.Error(), "address", addr.String()) + } + buf.Reset() + } +} + +// handle handles an incoming packet in buffer b from the address passed. If +// not successful, an error is returned describing the issue. +func (listener *Listener) handle(b *bytes.Buffer, addr net.Addr) error { + value, found := listener.connections.Load(addr.String()) + if !found { + // If there was no session yet, it means the packet is an offline + // message. It is not contained in a datagram. + packetID, err := b.ReadByte() + if err != nil { + return fmt.Errorf("error reading packet ID byte: %v", err) + } + switch packetID { + case packets.IDUnconnectedPing, packets.IDUnconnectedPingOpenConnections: + return listener.handleUnconnectedPing(b, addr) + case packets.IDOpenConnectionRequest1: + return listener.handleOpenConnectionRequest1(b, addr) + case packets.IDOpenConnectionRequest2: + return listener.handleOpenConnectionRequest2(b, addr) + default: + // In some cases, the client will keep trying to send datagrams + // while it has already timed out. In this case, we should not print + // an error. + if packetID&packets.BitFlagDatagram == 0 { + return fmt.Errorf("unknown packet received (%x): %x", packetID, b.Bytes()) + } + } + return nil + } + conn := value.(*Conn) + select { + case <-conn.closed: + // Connection was closed already. + return nil + default: + err := conn.receive(b) + if err != nil { + conn.closeImmediately() + } + return err + } +} + +// handleOpenConnectionRequest2 handles an open connection request 2 packet +// stored in buffer b, coming from an address addr. +func (listener *Listener) handleOpenConnectionRequest2(b *bytes.Buffer, addr net.Addr) error { + packet := &packets.OpenConnectionRequest2{} + if err := packet.Read(b); err != nil { + return fmt.Errorf("error reading open connection request 2: %v", err) + } + b.Reset() + + mtuSize := packet.ClientPreferredMTUSize + if mtuSize > maxMTUSize { + mtuSize = maxMTUSize + } + + (&packets.OpenConnectionReply2{ServerGUID: listener.id, ClientAddress: *addr.(*net.UDPAddr), MTUSize: mtuSize}).Write(b) + if _, err := listener.conn.WriteTo(b.Bytes(), addr); err != nil { + return fmt.Errorf("error sending open connection reply 2: %v", err) + } + + conn := newConn(listener.conn, addr, packet.ClientPreferredMTUSize) + conn.close = func() { + // Make sure to remove the connection from the Listener once the Conn is + // closed. + listener.connections.Delete(addr.String()) + } + listener.connections.Store(addr.String(), conn) + + go func() { + t := time.NewTimer(time.Second * 10) + defer t.Stop() + select { + case <-conn.connected: + // Add the connection to the incoming channel so that a caller of + // Accept() can receive it. + listener.incoming <- conn + case <-listener.closed: + _ = conn.Close() + case <-t.C: + // It took too long to complete this connection. We closed it and go + // back to accepting. + _ = conn.Close() + } + }() + + return nil +} + +// handleOpenConnectionRequest1 handles an open connection request 1 packet +// stored in buffer b, coming from an address addr. +func (listener *Listener) handleOpenConnectionRequest1(b *bytes.Buffer, addr net.Addr) error { + packet := &packets.OpenConnectionRequest1{} + if err := packet.Read(b); err != nil { + return fmt.Errorf("error reading open connection request 1: %v", err) + } + b.Reset() + mtuSize := packet.MaximumSizeNotDropped + if mtuSize > maxMTUSize { + mtuSize = maxMTUSize + } + + if packet.Protocol != currentProtocol { + (&packets.IncompatibleProtocolVersion{ServerGUID: listener.id, ServerProtocol: currentProtocol}).Write(b) + _, _ = listener.conn.WriteTo(b.Bytes(), addr) + return fmt.Errorf("error handling open connection request 1: incompatible protocol version %v (listener protocol = %v)", packet.Protocol, currentProtocol) + } + + (&packets.OpenConnectionReply1{ServerGUID: listener.id, Secure: false, ServerPreferredMTUSize: mtuSize}).Write(b) + _, err := listener.conn.WriteTo(b.Bytes(), addr) + return err +} + +// handleUnconnectedPing handles an unconnected ping packet stored in buffer b, +// coming from an address addr. +func (listener *Listener) handleUnconnectedPing(b *bytes.Buffer, addr net.Addr) error { + pk := &packets.UnconnectedPing{} + if err := pk.Read(b); err != nil { + return fmt.Errorf("error reading unconnected ping: %v", err) + } + b.Reset() + + (&packets.UnconnectedPong{ServerGUID: listener.id, SendTimestamp: pk.SendTimestamp, Data: *listener.pongData.Load()}).Write(b) + _, err := listener.conn.WriteTo(b.Bytes(), addr) + return err +} + +// timestamp returns a timestamp in milliseconds. +func timestamp() int64 { + return time.Now().UnixNano() / int64(time.Second) +} diff --git a/pkg/net/network.go b/pkg/net/network.go index e86e27e..67294eb 100644 --- a/pkg/net/network.go +++ b/pkg/net/network.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "net" - "time" ) // Start Starts the TCP server on the specified address. @@ -14,9 +13,9 @@ func Run(baseCtx context.Context, addr string) { defer cancel() // Start listening on the specified address - listener, err := net.Listen("tcp", addr) + listener, err := Listen(addr) if err != nil { - cmd.Logger.Error("Error starting TCP server: %v") + cmd.Logger.Error(fmt.Sprintf("Error starting UDP server: %v", err)) return } @@ -25,21 +24,6 @@ func Run(baseCtx context.Context, addr string) { cmd.Logger.Info(fmt.Sprintf("Listening on %s", addr)) - // var wg sync.WaitGroup - - go func() { - for { - // Print test every second - select { - case <-ctx.Done(): - return - default: - cmd.Logger.Info("Test") - } - time.Sleep(1 * time.Second) - } - }() - // Listen for an incoming connection in a goroutine. go func() { go func() { @@ -50,7 +34,9 @@ func Run(baseCtx context.Context, addr string) { listener.Close() }() for { + cmd.Logger.Info("Waiting for connection...") conn, err := listener.Accept() + cmd.Logger.Info("Connection accepted") if err != nil { select { case <-ctx.Done(): @@ -75,16 +61,8 @@ func handleRequest(conn net.Conn) { // Close the connection when you're done with it. defer conn.Close() - // Make a buffer to hold incoming data. - buf := make([]byte, 1024) - - // Read the incoming connection into the buffer. - _, err := conn.Read(buf) - if err != nil { - cmd.Logger.Error(fmt.Sprintf("Error reading: %v", err)) - return - } - - // Send a response back to person contacting us. - conn.Write([]byte("Hello, World!\n")) + b := make([]byte, 1024*1024*4) + _, _ = conn.Read(b) + cmd.Logger.Info(fmt.Sprintf("Received: %v", b)) + _, _ = conn.Write([]byte{1, 2, 3}) } diff --git a/pkg/net/packet_queue.go b/pkg/net/packet_queue.go new file mode 100644 index 0000000..42749bb --- /dev/null +++ b/pkg/net/packet_queue.go @@ -0,0 +1,54 @@ +package net + +import "cimeyclust.com/steel/pkg/utils" + +// packetQueue is an ordered queue for reliable ordered packets. +type packetQueue struct { + lowest utils.Uint24 + highest utils.Uint24 + queue map[utils.Uint24][]byte +} + +// newPacketQueue returns a new initialised ordered queue. +func newPacketQueue() *packetQueue { + return &packetQueue{queue: make(map[utils.Uint24][]byte)} +} + +// put puts a value at the index passed. If the index was already occupied +// once, false is returned. +func (queue *packetQueue) put(index utils.Uint24, packet []byte) bool { + if index < queue.lowest { + return false + } + if _, ok := queue.queue[index]; ok { + return false + } + if index >= queue.highest { + queue.highest = index + 1 + } + queue.queue[index] = packet + return true +} + +// fetch attempts to take out as many values from the ordered queue as +// possible. Upon encountering an index that has no value yet, the function +// returns all values that it did find and takes them out. +func (queue *packetQueue) fetch() (packets [][]byte) { + index := queue.lowest + for index < queue.highest { + packet, ok := queue.queue[index] + if !ok { + break + } + delete(queue.queue, index) + packets = append(packets, packet) + index++ + } + queue.lowest = index + return +} + +// WindowSize returns the size of the window held by the packet queue. +func (queue *packetQueue) WindowSize() utils.Uint24 { + return queue.highest - queue.lowest +} diff --git a/pkg/net/packets/connected_ping.go b/pkg/net/packets/connected_ping.go new file mode 100644 index 0000000..20bee15 --- /dev/null +++ b/pkg/net/packets/connected_ping.go @@ -0,0 +1,19 @@ +package packets + +import ( + "bytes" + "encoding/binary" +) + +type ConnectedPing struct { + ClientTimestamp int64 +} + +func (pk *ConnectedPing) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDConnectedPing) + _ = binary.Write(buf, binary.BigEndian, pk.ClientTimestamp) +} + +func (pk *ConnectedPing) Read(buf *bytes.Buffer) error { + return binary.Read(buf, binary.BigEndian, &pk.ClientTimestamp) +} diff --git a/pkg/net/packets/connected_pong.go b/pkg/net/packets/connected_pong.go new file mode 100644 index 0000000..4cd8284 --- /dev/null +++ b/pkg/net/packets/connected_pong.go @@ -0,0 +1,22 @@ +package packets + +import ( + "bytes" + "encoding/binary" +) + +type ConnectedPong struct { + ClientTimestamp int64 + ServerTimestamp int64 +} + +func (pk *ConnectedPong) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDConnectedPong) + _ = binary.Write(buf, binary.BigEndian, pk.ClientTimestamp) + _ = binary.Write(buf, binary.BigEndian, pk.ServerTimestamp) +} + +func (pk *ConnectedPong) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.ClientTimestamp) + return binary.Read(buf, binary.BigEndian, &pk.ServerTimestamp) +} diff --git a/pkg/net/packets/connection_request.go b/pkg/net/packets/connection_request.go new file mode 100644 index 0000000..677afa4 --- /dev/null +++ b/pkg/net/packets/connection_request.go @@ -0,0 +1,25 @@ +package packets + +import ( + "bytes" + "encoding/binary" +) + +type ConnectionRequest struct { + ClientGUID int64 + RequestTimestamp int64 + Secure bool +} + +func (pk *ConnectionRequest) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDConnectionRequest) + _ = binary.Write(buf, binary.BigEndian, pk.ClientGUID) + _ = binary.Write(buf, binary.BigEndian, pk.RequestTimestamp) + _ = binary.Write(buf, binary.BigEndian, pk.Secure) +} + +func (pk *ConnectionRequest) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.ClientGUID) + _ = binary.Read(buf, binary.BigEndian, &pk.RequestTimestamp) + return binary.Read(buf, binary.BigEndian, &pk.Secure) +} diff --git a/pkg/net/packets/connection_request_accepted.go b/pkg/net/packets/connection_request_accepted.go new file mode 100644 index 0000000..cf53e29 --- /dev/null +++ b/pkg/net/packets/connection_request_accepted.go @@ -0,0 +1,38 @@ +package packets + +import ( + "bytes" + "encoding/binary" + "net" +) + +type ConnectionRequestAccepted struct { + ClientAddress net.UDPAddr + SystemAddresses [20]net.UDPAddr + RequestTimestamp int64 + AcceptedTimestamp int64 +} + +func (pk *ConnectionRequestAccepted) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDConnectionRequestAccepted) + writeAddr(buf, pk.ClientAddress) + _ = binary.Write(buf, binary.BigEndian, int16(0)) + for _, addr := range pk.SystemAddresses { + writeAddr(buf, addr) + } + _ = binary.Write(buf, binary.BigEndian, pk.RequestTimestamp) + _ = binary.Write(buf, binary.BigEndian, pk.AcceptedTimestamp) +} + +func (pk *ConnectionRequestAccepted) Read(buf *bytes.Buffer) error { + _ = readAddr(buf, &pk.ClientAddress) + buf.Next(2) + for i := 0; i < 20; i++ { + _ = readAddr(buf, &pk.SystemAddresses[i]) + if buf.Len() == 16 { + break + } + } + _ = binary.Read(buf, binary.BigEndian, &pk.RequestTimestamp) + return binary.Read(buf, binary.BigEndian, &pk.AcceptedTimestamp) +} diff --git a/pkg/net/packets/incompatible_protocol_version.go b/pkg/net/packets/incompatible_protocol_version.go new file mode 100644 index 0000000..e73fca2 --- /dev/null +++ b/pkg/net/packets/incompatible_protocol_version.go @@ -0,0 +1,25 @@ +package packets + +import ( + "bytes" + "encoding/binary" +) + +type IncompatibleProtocolVersion struct { + Magic [16]byte + ServerProtocol byte + ServerGUID int64 +} + +func (pk *IncompatibleProtocolVersion) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDIncompatibleProtocolVersion) + _ = binary.Write(buf, binary.BigEndian, pk.ServerProtocol) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, pk.ServerGUID) +} + +func (pk *IncompatibleProtocolVersion) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.ServerProtocol) + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + return binary.Read(buf, binary.BigEndian, &pk.ServerGUID) +} diff --git a/pkg/net/packets/new_incoming_connection.go b/pkg/net/packets/new_incoming_connection.go new file mode 100644 index 0000000..4c854f7 --- /dev/null +++ b/pkg/net/packets/new_incoming_connection.go @@ -0,0 +1,36 @@ +package packets + +import ( + "bytes" + "encoding/binary" + "net" +) + +type NewIncomingConnection struct { + ServerAddress net.UDPAddr + SystemAddresses [20]net.UDPAddr + RequestTimestamp int64 + AcceptedTimestamp int64 +} + +func (pk *NewIncomingConnection) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDNewIncomingConnection) + writeAddr(buf, pk.ServerAddress) + for _, addr := range pk.SystemAddresses { + writeAddr(buf, addr) + } + _ = binary.Write(buf, binary.BigEndian, pk.RequestTimestamp) + _ = binary.Write(buf, binary.BigEndian, pk.AcceptedTimestamp) +} + +func (pk *NewIncomingConnection) Read(buf *bytes.Buffer) error { + _ = readAddr(buf, &pk.ServerAddress) + for i := 0; i < 20; i++ { + _ = readAddr(buf, &pk.SystemAddresses[i]) + if buf.Len() == 16 { + break + } + } + _ = binary.Read(buf, binary.BigEndian, &pk.RequestTimestamp) + return binary.Read(buf, binary.BigEndian, &pk.AcceptedTimestamp) +} diff --git a/pkg/net/packets/open_connection_reply_1.go b/pkg/net/packets/open_connection_reply_1.go new file mode 100644 index 0000000..6541408 --- /dev/null +++ b/pkg/net/packets/open_connection_reply_1.go @@ -0,0 +1,28 @@ +package packets + +import ( + "bytes" + "encoding/binary" +) + +type OpenConnectionReply1 struct { + Magic [16]byte + ServerGUID int64 + Secure bool + ServerPreferredMTUSize uint16 +} + +func (pk *OpenConnectionReply1) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDOpenConnectionReply1) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, pk.ServerGUID) + _ = binary.Write(buf, binary.BigEndian, pk.Secure) + _ = binary.Write(buf, binary.BigEndian, pk.ServerPreferredMTUSize) +} + +func (pk *OpenConnectionReply1) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + _ = binary.Read(buf, binary.BigEndian, &pk.ServerGUID) + _ = binary.Read(buf, binary.BigEndian, &pk.Secure) + return binary.Read(buf, binary.BigEndian, &pk.ServerPreferredMTUSize) +} diff --git a/pkg/net/packets/open_connection_reply_2.go b/pkg/net/packets/open_connection_reply_2.go new file mode 100644 index 0000000..7dea214 --- /dev/null +++ b/pkg/net/packets/open_connection_reply_2.go @@ -0,0 +1,32 @@ +package packets + +import ( + "bytes" + "encoding/binary" + "net" +) + +type OpenConnectionReply2 struct { + Magic [16]byte + ServerGUID int64 + ClientAddress net.UDPAddr + MTUSize uint16 + Secure bool +} + +func (pk *OpenConnectionReply2) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDOpenConnectionReply2) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, pk.ServerGUID) + writeAddr(buf, pk.ClientAddress) + _ = binary.Write(buf, binary.BigEndian, pk.MTUSize) + _ = binary.Write(buf, binary.BigEndian, pk.Secure) +} + +func (pk *OpenConnectionReply2) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + _ = binary.Read(buf, binary.BigEndian, &pk.ServerGUID) + _ = readAddr(buf, &pk.ClientAddress) + _ = binary.Read(buf, binary.BigEndian, &pk.MTUSize) + return binary.Read(buf, binary.BigEndian, &pk.Secure) +} diff --git a/pkg/net/packets/open_connection_request_1.go b/pkg/net/packets/open_connection_request_1.go new file mode 100644 index 0000000..e036fac --- /dev/null +++ b/pkg/net/packets/open_connection_request_1.go @@ -0,0 +1,25 @@ +package packets + +import ( + "bytes" + "encoding/binary" +) + +type OpenConnectionRequest1 struct { + Magic [16]byte + Protocol byte + MaximumSizeNotDropped uint16 +} + +func (pk *OpenConnectionRequest1) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDOpenConnectionRequest1) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, pk.Protocol) + _, _ = buf.Write(make([]byte, pk.MaximumSizeNotDropped-uint16(buf.Len()+28))) +} + +func (pk *OpenConnectionRequest1) Read(buf *bytes.Buffer) error { + pk.MaximumSizeNotDropped = uint16(buf.Len()+1) + 28 + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + return binary.Read(buf, binary.BigEndian, &pk.Protocol) +} diff --git a/pkg/net/packets/open_connection_request_2.go b/pkg/net/packets/open_connection_request_2.go new file mode 100644 index 0000000..d84ddb8 --- /dev/null +++ b/pkg/net/packets/open_connection_request_2.go @@ -0,0 +1,29 @@ +package packets + +import ( + "bytes" + "encoding/binary" + "net" +) + +type OpenConnectionRequest2 struct { + Magic [16]byte + ServerAddress net.UDPAddr + ClientPreferredMTUSize uint16 + ClientGUID int64 +} + +func (pk *OpenConnectionRequest2) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDOpenConnectionRequest2) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + writeAddr(buf, pk.ServerAddress) + _ = binary.Write(buf, binary.BigEndian, pk.ClientPreferredMTUSize) + _ = binary.Write(buf, binary.BigEndian, pk.ClientGUID) +} + +func (pk *OpenConnectionRequest2) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + _ = readAddr(buf, &pk.ServerAddress) + _ = binary.Read(buf, binary.BigEndian, &pk.ClientPreferredMTUSize) + return binary.Read(buf, binary.BigEndian, &pk.ClientGUID) +} diff --git a/pkg/net/packets/packet.go b/pkg/net/packets/packet.go new file mode 100644 index 0000000..54e8d7c --- /dev/null +++ b/pkg/net/packets/packet.go @@ -0,0 +1,419 @@ +package packets + +import ( + "bytes" + "cimeyclust.com/steel/pkg/utils" + "encoding/binary" + "fmt" + "net" + "sort" +) + +const ( + IDConnectedPing byte = 0x00 + IDUnconnectedPing byte = 0x01 + IDUnconnectedPingOpenConnections byte = 0x02 + IDConnectedPong byte = 0x03 + IDDetectLostConnections byte = 0x04 + IDOpenConnectionRequest1 byte = 0x05 + IDOpenConnectionReply1 byte = 0x06 + IDOpenConnectionRequest2 byte = 0x07 + IDOpenConnectionReply2 byte = 0x08 + IDConnectionRequest byte = 0x09 + IDConnectionRequestAccepted byte = 0x10 + IDNewIncomingConnection byte = 0x13 + IDDisconnectNotification byte = 0x15 + + IDIncompatibleProtocolVersion byte = 0x19 + + IDUnconnectedPong byte = 0x1c +) + +// unconnectedMessageSequence is a sequence of bytes which is found in every unconnected message sent in +// RakNet. +var unconnectedMessageSequence = [16]byte{0x00, 0xff, 0xff, 0x00, 0xfe, 0xfe, 0xfe, 0xfe, 0xfd, 0xfd, 0xfd, 0xfd, 0x12, 0x34, 0x56, 0x78} + +// writeAddr writes a UDP address to the buffer passed. +func writeAddr(buffer *bytes.Buffer, addr net.UDPAddr) { + var ver byte = 6 + if addr.IP.To4() != nil { + ver = 4 + } + if addr.IP == nil { + addr.IP = make([]byte, 16) + } + _ = buffer.WriteByte(ver) + if ver == 4 { + ipBytes := addr.IP.To4() + + _ = buffer.WriteByte(^ipBytes[0]) + _ = buffer.WriteByte(^ipBytes[1]) + _ = buffer.WriteByte(^ipBytes[2]) + _ = buffer.WriteByte(^ipBytes[3]) + _ = binary.Write(buffer, binary.BigEndian, uint16(addr.Port)) + } else { + _ = binary.Write(buffer, binary.LittleEndian, int16(23)) // syscall.AF_INET6 on Windows. + _ = binary.Write(buffer, binary.BigEndian, uint16(addr.Port)) + // The IPv6 address is enclosed in two 0 integers. + _ = binary.Write(buffer, binary.BigEndian, int32(0)) + _, _ = buffer.Write(addr.IP.To16()) + _ = binary.Write(buffer, binary.BigEndian, int32(0)) + } +} + +// readAddr decodes a RakNet address from the buffer passed. If not successful, an error is returned. +func readAddr(buffer *bytes.Buffer, addr *net.UDPAddr) error { + ver, err := buffer.ReadByte() + if err != nil { + return err + } + if ver == 4 { + ipBytes := make([]byte, 4) + if _, err := buffer.Read(ipBytes); err != nil { + return fmt.Errorf("error reading raknet address ipv4 bytes: %v", err) + } + // Construct an IPv4 out of the 4 bytes we just read. + addr.IP = net.IPv4((-ipBytes[0]-1)&0xff, (-ipBytes[1]-1)&0xff, (-ipBytes[2]-1)&0xff, (-ipBytes[3]-1)&0xff) + var port uint16 + if err := binary.Read(buffer, binary.BigEndian, &port); err != nil { + return fmt.Errorf("error reading raknet address port: %v", err) + } + addr.Port = int(port) + } else { + buffer.Next(2) + var port uint16 + if err := binary.Read(buffer, binary.LittleEndian, &port); err != nil { + return fmt.Errorf("error reading raknet address port: %v", err) + } + addr.Port = int(port) + buffer.Next(4) + addr.IP = make([]byte, 16) + if _, err := buffer.Read(addr.IP); err != nil { + return fmt.Errorf("error reading raknet address ipv6 bytes: %v", err) + } + buffer.Next(4) + } + return nil +} + +const ( + // BitFlagDatagram is set for every valid datagram. It is used to identify + // packets that are datagrams. + BitFlagDatagram = 0x80 + // BitFlagACK is set for every ACK Packet. + BitFlagACK = 0x40 + // BitFlagNACK is set for every NACK Packet. + BitFlagNACK = 0x20 + // BitFlagNeedsBAndAS is set for every datagram with Packet data, but is not + // actually used. + BitFlagNeedsBAndAS = 0x04 +) + +// noinspection GoUnusedConst +const ( + // ReliabilityUnreliable means that the Packet sent could arrive out of + // order, be duplicated, or just not arrive at all. It is usually used for + // high frequency packets of which the order does not matter. + // lint:ignore U1000 While this constant is unused, it is here for the sake + // of having all reliabilities documented. + ReliabilityUnreliable byte = iota + // ReliabilityUnreliableSequenced means that the Packet sent could be + // duplicated or not arrive at all, but ensures that it is always handled in + // the right order. + ReliabilityUnreliableSequenced + // ReliabilityReliable means that the Packet sent could not arrive, or + // arrive out of order, but ensures that the Packet is not duplicated. + ReliabilityReliable + // ReliabilityReliableOrdered means that every Packet sent arrives, arrives + // in the right order and is not duplicated. + ReliabilityReliableOrdered + // ReliabilityReliableSequenced means that the Packet sent could not arrive, + // but ensures that the Packet will be in the right order and not be + // duplicated. + ReliabilityReliableSequenced + + // SplitFlag is set in the header if the Packet was split. If so, the + // encapsulation contains additional data about the fragment. + SplitFlag = 0x10 +) + +// Packet is an encapsulation around every Packet sent after the connection is +// established. It is +type Packet struct { + Reliability byte + + Content []byte + MessageIndex utils.Uint24 + sequenceIndex utils.Uint24 + OrderIndex utils.Uint24 + + Split bool + SplitCount uint32 + SplitIndex uint32 + SplitID uint16 +} + +// Write writes the Packet and its Content to the buffer passed. +func (packet *Packet) Write(b *bytes.Buffer) { + header := packet.Reliability << 5 + if packet.Split { + header |= SplitFlag + } + b.WriteByte(header) + _ = binary.Write(b, binary.BigEndian, uint16(len(packet.Content))<<3) + if packet.reliable() { + utils.WriteUint24(b, packet.MessageIndex) + } + if packet.sequenced() { + utils.WriteUint24(b, packet.sequenceIndex) + } + if packet.sequencedOrOrdered() { + utils.WriteUint24(b, packet.OrderIndex) + // Order channel, we don't care about this. + b.WriteByte(0) + } + if packet.Split { + _ = binary.Write(b, binary.BigEndian, packet.SplitCount) + _ = binary.Write(b, binary.BigEndian, packet.SplitID) + _ = binary.Write(b, binary.BigEndian, packet.SplitIndex) + } + b.Write(packet.Content) +} + +// Read reads a Packet and its Content from the buffer passed. +func (packet *Packet) Read(b *bytes.Buffer) error { + header, err := b.ReadByte() + if err != nil { + return fmt.Errorf("error reading Packet header: %v", err) + } + packet.Split = (header & SplitFlag) != 0 + packet.Reliability = (header & 224) >> 5 + var packetLength uint16 + if err := binary.Read(b, binary.BigEndian, &packetLength); err != nil { + return fmt.Errorf("error reading Packet length: %v", err) + } + packetLength >>= 3 + if packetLength == 0 { + return fmt.Errorf("invalid Packet length: cannot be 0") + } + + if packet.reliable() { + packet.MessageIndex, err = utils.ReadUint24(b) + if err != nil { + return fmt.Errorf("error reading Packet message index: %v", err) + } + } + + if packet.sequenced() { + packet.sequenceIndex, err = utils.ReadUint24(b) + if err != nil { + return fmt.Errorf("error reading Packet sequence index: %v", err) + } + } + + if packet.sequencedOrOrdered() { + packet.OrderIndex, err = utils.ReadUint24(b) + if err != nil { + return fmt.Errorf("error reading Packet order index: %v", err) + } + // Order channel (byte), we don't care about this. + b.Next(1) + } + + if packet.Split { + if err := binary.Read(b, binary.BigEndian, &packet.SplitCount); err != nil { + return fmt.Errorf("error reading Packet split count: %v", err) + } + if err := binary.Read(b, binary.BigEndian, &packet.SplitID); err != nil { + return fmt.Errorf("error reading Packet split ID: %v", err) + } + if err := binary.Read(b, binary.BigEndian, &packet.SplitIndex); err != nil { + return fmt.Errorf("error reading Packet split index: %v", err) + } + } + + packet.Content = make([]byte, packetLength) + if n, err := b.Read(packet.Content); err != nil || n != int(packetLength) { + return fmt.Errorf("not enough data in Packet: %v bytes read but need %v", n, packetLength) + } + return nil +} + +func (packet *Packet) reliable() bool { + switch packet.Reliability { + case ReliabilityReliable, + ReliabilityReliableOrdered, + ReliabilityReliableSequenced: + return true + } + return false +} + +func (packet *Packet) sequencedOrOrdered() bool { + switch packet.Reliability { + case ReliabilityUnreliableSequenced, + ReliabilityReliableOrdered, + ReliabilityReliableSequenced: + return true + } + return false +} + +func (packet *Packet) sequenced() bool { + switch packet.Reliability { + case ReliabilityUnreliableSequenced, + ReliabilityReliableSequenced: + return true + } + return false +} + +const ( + // packetRange indicates a range of packets, followed by the first and the + // last Packet in the range. + packetRange = iota + // packetSingle indicates a single Packet, followed by its sequence number. + packetSingle +) + +// Acknowledgement is an Acknowledgement Packet that may either be an ACK or a +// NACK, depending on the purpose that it is sent with. +type Acknowledgement struct { + Packets []utils.Uint24 +} + +// Write encodes an Acknowledgement Packet and returns an error if not +// successful. +func (ack *Acknowledgement) Write(b *bytes.Buffer, mtu uint16) (n int, err error) { + packets := ack.Packets + if len(packets) == 0 { + return 0, binary.Write(b, binary.BigEndian, int16(0)) + } + buffer := bytes.NewBuffer(nil) + // Sort packets before encoding to ensure packets are encoded correctly. + sort.Slice(packets, func(i, j int) bool { + return packets[i] < packets[j] + }) + + var firstPacketInRange utils.Uint24 + var lastPacketInRange utils.Uint24 + var recordCount int16 + + for index, packet := range packets { + if buffer.Len() >= int(mtu-10) { + // We must make sure the final Packet length doesn't exceed the MTU + // size. + break + } + n++ + if index == 0 { + // The first Packet, set the first and last Packet to it. + firstPacketInRange = packet + lastPacketInRange = packet + continue + } + if packet == lastPacketInRange+1 { + // Packet is still part of the current range, as it's sequenced + // properly with the last Packet. Set the last Packet in range to + // the Packet and continue to the next Packet. + lastPacketInRange = packet + continue + } else { + // We got to the end of a range/single Packet. We need to write + // those down now. + if firstPacketInRange == lastPacketInRange { + // First Packet equals last Packet, so we have a single Packet + // record. Write down the Packet, and set the first and last + // Packet to the current Packet. + if err := buffer.WriteByte(packetSingle); err != nil { + return 0, err + } + utils.WriteUint24(buffer, firstPacketInRange) + + firstPacketInRange = packet + lastPacketInRange = packet + } else { + // There's a gap between the first and last Packet, so we have a + // range of packets. Write the first and last Packet of the + // range and set both to the current Packet. + if err := buffer.WriteByte(packetRange); err != nil { + return 0, err + } + utils.WriteUint24(buffer, firstPacketInRange) + utils.WriteUint24(buffer, lastPacketInRange) + + firstPacketInRange = packet + lastPacketInRange = packet + } + // Keep track of the amount of records as we need to write that + // first. + recordCount++ + } + } + + // Make sure the last single Packet/range is written, as we always need to + // know one Packet ahead to know how we should write the current. + if firstPacketInRange == lastPacketInRange { + if err := buffer.WriteByte(packetSingle); err != nil { + return 0, err + } + utils.WriteUint24(buffer, firstPacketInRange) + } else { + if err := buffer.WriteByte(packetRange); err != nil { + return 0, err + } + utils.WriteUint24(buffer, firstPacketInRange) + utils.WriteUint24(buffer, lastPacketInRange) + } + recordCount++ + if err := binary.Write(b, binary.BigEndian, recordCount); err != nil { + return 0, err + } + if _, err := b.Write(buffer.Bytes()); err != nil { + return 0, err + } + return n, nil +} + +// Read decodes an Acknowledgement Packet and returns an error if not +// successful. +func (ack *Acknowledgement) Read(b *bytes.Buffer) error { + const maxAcknowledgementPackets = 8192 + var recordCount int16 + if err := binary.Read(b, binary.BigEndian, &recordCount); err != nil { + return err + } + for i := int16(0); i < recordCount; i++ { + recordType, err := b.ReadByte() + if err != nil { + return err + } + switch recordType { + case packetRange: + start, err := utils.ReadUint24(b) + if err != nil { + return err + } + end, err := utils.ReadUint24(b) + if err != nil { + return err + } + for pack := start; pack <= end; pack++ { + ack.Packets = append(ack.Packets, pack) + if len(ack.Packets) > maxAcknowledgementPackets { + return fmt.Errorf("maximum amount of packets in acknowledgement exceeded") + } + } + case packetSingle: + packet, err := utils.ReadUint24(b) + if err != nil { + return err + } + ack.Packets = append(ack.Packets, packet) + if len(ack.Packets) > maxAcknowledgementPackets { + return fmt.Errorf("maximum amount of packets in acknowledgement exceeded") + } + } + } + return nil +} diff --git a/pkg/net/packets/unconnected_ping.go b/pkg/net/packets/unconnected_ping.go new file mode 100644 index 0000000..add148e --- /dev/null +++ b/pkg/net/packets/unconnected_ping.go @@ -0,0 +1,25 @@ +package packets + +import ( + "bytes" + "encoding/binary" +) + +type UnconnectedPing struct { + Magic [16]byte + SendTimestamp int64 + ClientGUID int64 +} + +func (pk *UnconnectedPing) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDUnconnectedPing) + _ = binary.Write(buf, binary.BigEndian, pk.SendTimestamp) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, pk.ClientGUID) +} + +func (pk *UnconnectedPing) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.SendTimestamp) + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + return binary.Read(buf, binary.BigEndian, &pk.ClientGUID) +} diff --git a/pkg/net/packets/unconnected_pong.go b/pkg/net/packets/unconnected_pong.go new file mode 100644 index 0000000..9e77ef5 --- /dev/null +++ b/pkg/net/packets/unconnected_pong.go @@ -0,0 +1,33 @@ +package packets + +import ( + "bytes" + "encoding/binary" +) + +type UnconnectedPong struct { + Magic [16]byte + SendTimestamp int64 + ServerGUID int64 + Data []byte +} + +func (pk *UnconnectedPong) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDUnconnectedPong) + _ = binary.Write(buf, binary.BigEndian, pk.SendTimestamp) + _ = binary.Write(buf, binary.BigEndian, pk.ServerGUID) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, int16(len(pk.Data))) + _ = binary.Write(buf, binary.BigEndian, pk.Data) +} + +func (pk *UnconnectedPong) Read(buf *bytes.Buffer) error { + var l int16 + _ = binary.Read(buf, binary.BigEndian, &pk.SendTimestamp) + _ = binary.Read(buf, binary.BigEndian, &pk.ServerGUID) + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + _ = binary.Read(buf, binary.BigEndian, &l) + pk.Data = make([]byte, l) + _, err := buf.Read(pk.Data) + return err +} diff --git a/pkg/net/resend_map.go b/pkg/net/resend_map.go new file mode 100644 index 0000000..60bb6fb --- /dev/null +++ b/pkg/net/resend_map.go @@ -0,0 +1,84 @@ +package net + +import ( + "cimeyclust.com/steel/pkg/net/packets" + "cimeyclust.com/steel/pkg/utils" + "time" +) + +// resendMap is a map of packets, used to recover datagrams if the other end of +// the connection ended up not having them. +type resendMap struct { + unacknowledged map[utils.Uint24]resendRecord + delays map[time.Time]time.Duration +} + +// resendRecord represents a single packet with a timestamp from when it was +// initially sent. It may be either acknowledged or NACKed by the other end. +type resendRecord struct { + pk *packets.Packet + timestamp time.Time +} + +// newRecoveryQueue returns a new initialised recovery queue. +func newRecoveryQueue() *resendMap { + return &resendMap{ + delays: make(map[time.Time]time.Duration), + unacknowledged: make(map[utils.Uint24]resendRecord), + } +} + +// add puts a packet at the index passed and records the current time. +func (m *resendMap) add(index utils.Uint24, pk *packets.Packet) { + m.unacknowledged[index] = resendRecord{pk: pk, timestamp: time.Now()} +} + +// acknowledge marks a packet with the index passed as acknowledged. The packet +// is removed from the resendMap and returned if found. +func (m *resendMap) acknowledge(index utils.Uint24) (*packets.Packet, bool) { + return m.remove(index, 1) +} + +// retransmit looks up a packet with an index from the resendMap so that it may +// be resent. +func (m *resendMap) retransmit(index utils.Uint24) (*packets.Packet, bool) { + return m.remove(index, 2) +} + +// remove deletes an index from the resendMap and adds the time since the +// packet was originally sent multiplied by mul to the delays slice. +func (m *resendMap) remove(index utils.Uint24, mul int) (*packets.Packet, bool) { + record, ok := m.unacknowledged[index] + if !ok { + return nil, false + } + delete(m.unacknowledged, index) + + now := time.Now() + m.delays[now] = now.Sub(record.timestamp) * time.Duration(mul) + return record.pk, true +} + +// rtt returns the average round trip time between the putting of the value +// into the recovery queue and the taking out of it again. It is measured over +// the last delayRecordCount values add in. +func (m *resendMap) rtt() time.Duration { + const averageDuration = time.Second * 5 + var ( + total, records time.Duration + now = time.Now() + ) + for t, rtt := range m.delays { + if now.Sub(t) > averageDuration { + delete(m.delays, t) + continue + } + total += rtt + records++ + } + if records == 0 { + // No records yet, generally should not happen. Just return a reasonable amount of time. + return time.Millisecond * 50 + } + return total / records +} diff --git a/pkg/utils/binary.go b/pkg/utils/binary.go new file mode 100644 index 0000000..0191d7a --- /dev/null +++ b/pkg/utils/binary.go @@ -0,0 +1,30 @@ +package utils + +import ( + "bytes" + "fmt" +) + +// Uint24 represents an integer existing out of 3 bytes. It is actually a +// uint32, but is an alias for the sake of clarity. +type Uint24 uint32 + +// ReadUint24 reads 3 bytes from the buffer passed and combines it into a +// uint24. If there were no 3 bytes to read, an error is returned. +func ReadUint24(b *bytes.Buffer) (Uint24, error) { + ba, _ := b.ReadByte() + bb, _ := b.ReadByte() + bc, err := b.ReadByte() + if err != nil { + return 0, fmt.Errorf("error reading uint24: %v", err) + } + return Uint24(ba) | (Uint24(bb) << 8) | (Uint24(bc) << 16), nil +} + +// WriteUint24 writes a uint24 to the buffer passed as 3 bytes. If not +// successful, an error is returned. +func WriteUint24(b *bytes.Buffer, value Uint24) { + b.WriteByte(byte(value)) + b.WriteByte(byte(value >> 8)) + b.WriteByte(byte(value >> 16)) +} diff --git a/steel.go b/steel.go index 668b76c..38b9d44 100644 --- a/steel.go +++ b/steel.go @@ -36,7 +36,7 @@ func main() { go func() { defer wg.Done() - net.Run(ctx, "localhost:8080") + net.Run(ctx, "localhost:19132") }() // Handles Ctrl+C