diff --git a/main.go b/main.go index 6ad1e1a..f2df281 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "fmt" "net" ) @@ -12,6 +13,11 @@ const ( StatusResponseID = 0x00 ) +const ( + SEGMENT_BITS = 0x7F + CONTINUE_BIT = 0x80 +) + type HandshakePacket struct { ProtocolVersion int ServerAddress string @@ -45,8 +51,13 @@ type Chat struct { Text string `json:"text"` } +type Reader struct { + data []byte + cursor int +} + func main() { - listener, err := net.Listen("tcp", ":25565") + listener, err := net.Listen("tcp", "localhost:25565") if err != nil { fmt.Println("Error starting server:", err) return @@ -67,32 +78,52 @@ func main() { } func handleConnection(conn net.Conn) { - var packetID byte - if err := readByte(conn, &packetID); err != nil { + r := &Reader{data: make([]byte, 1024)} + + // Print the whole packet + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + r.data = buf[:n] + + fmt.Println("Packet: ", r.data) + + _, err := r.readVarInt() + if err != nil { + fmt.Println("Error reading first number:", err) + return + } + + packetID, err := r.readVarInt() + if err != nil { fmt.Println("Error reading packet ID:", err) return } - if packetID == HandshakePacketID { - handleHandshake(conn) + fmt.Println("Packet ID: ", packetID) + + switch packetID { + case HandshakePacketID: + handleHandshake(r) + default: + fmt.Println("Unknown packet ID:", packetID) } } -func handleHandshake(conn net.Conn) { +func handleHandshake(r *Reader) { var handshake HandshakePacket - if err := readHandshakePacket(conn, &handshake); err != nil { + if err := readHandshakePacket(r, &handshake); err != nil { fmt.Println("Error reading handshake packet:", err) return } if handshake.NextState == 1 { - handleStatusRequest(conn) + handleStatusRequest(r) } } -func handleStatusRequest(conn net.Conn) { - var packetID byte - if err := readByte(conn, &packetID); err != nil { +func handleStatusRequest(r *Reader) { + packetID, err := r.readVarInt() + if err != nil { fmt.Println("Error reading packet ID:", err) return } @@ -119,46 +150,96 @@ func handleStatusRequest(conn net.Conn) { return } - if err := writeByte(conn, StatusResponseID); err != nil { + w := &Writer{data: make([]byte, 1024)} + + if err := w.writeVarInt(StatusResponseID); err != nil { fmt.Println("Error writing response ID:", err) return } - if err := writeVarInt(conn, len(response)); err != nil { + if err := w.writeVarInt(len(response)); err != nil { fmt.Println("Error writing response length:", err) return } - if _, err := conn.Write(response); err != nil { + if err := w.writeString(string(response)); err != nil { fmt.Println("Error writing response:", err) return } } } -func readByte(conn net.Conn, b *byte) error { - buf := make([]byte, 1) - _, err := conn.Read(buf) - if err != nil { - return err +func (r *Reader) readByte() (byte, error) { + if r.cursor >= len(r.data) { + return 0, errors.New("EOF") } - *b = buf[0] + b := r.data[r.cursor] + r.cursor++ + return b, nil +} + +func (r *Reader) readVarInt() (int, error) { + value := 0 + position := 0 + + for { + currentByte, err := r.readByte() + if err != nil { + return 0, err + } + + value |= int(currentByte&SEGMENT_BITS) << position + + if (currentByte & CONTINUE_BIT) == 0 { + break + } + + position += 7 + if position >= 32 { + return 0, errors.New("VarInt is too big") + } + } + + return value, nil +} + +func (r *Reader) readString() (string, error) { + length, err := r.readVarInt() + if err != nil { + return "", err + } + + if r.cursor+length > len(r.data) { + return "", errors.New("EOF: Reached String unexpectedly") + } + + str := string(r.data[r.cursor : r.cursor+length]) + r.cursor += length + return str, nil +} + +type Writer struct { + data []byte + cursor int +} + +func (w *Writer) writeByte(b byte) error { + if w.cursor >= len(w.data) { + return errors.New("EOF") + } + w.data[w.cursor] = b + w.cursor++ return nil } -func writeByte(conn net.Conn, b byte) error { - _, err := conn.Write([]byte{b}) - return err -} - -func writeVarInt(conn net.Conn, value int) error { +func (w *Writer) writeVarInt(value int) error { for { - temp := byte(value & 0x7F) + temp := byte(value & SEGMENT_BITS) value >>= 7 if value != 0 { - temp |= 0x80 + temp |= CONTINUE_BIT } - if _, err := conn.Write([]byte{temp}); err != nil { + if err := w.writeByte(temp); err != nil { return err } if value == 0 { @@ -168,27 +249,72 @@ func writeVarInt(conn net.Conn, value int) error { return nil } -func readVarInt(conn net.Conn, value *int) error { - var result int - var shift uint +func (w *Writer) writeString(s string) error { + length := len(s) + if err := w.writeVarInt(length); err != nil { + return err + } + if w.cursor+length > len(w.data) { + return errors.New("EOF") + } + copy(w.data[w.cursor:], s) + w.cursor += length + return nil +} + +func writeByte(conn net.Conn, b byte) error { + _, err := conn.Write([]byte{b}) + return err +} + +func writeVarInt(conn net.Conn, value int) error { + var buf [5]byte + n := 0 for { - var b byte - if err := readByte(conn, &b); err != nil { - return err + temp := byte(value & SEGMENT_BITS) + value >>= 7 + if value != 0 { + temp |= CONTINUE_BIT } - result |= int(b&0x7F) << shift - shift += 7 - if b&0x80 == 0 { + buf[n] = temp + n++ + if value == 0 { break } } - *value = result - return nil + _, err := conn.Write(buf[:n]) + return err } -func readHandshakePacket(conn net.Conn, packet *HandshakePacket) error { - if err := readVarInt(conn, &packet.ProtocolVersion); err != nil { +func readHandshakePacket(r *Reader, packet *HandshakePacket) error { + // Print the whole packet + fmt.Println("Packet: ", r.data) + fmt.Println("Works so far") + var err error + packet.ProtocolVersion, err = r.readVarInt() + fmt.Println("version: ", packet.ProtocolVersion) + if err != nil { + return err } + + packet.ServerAddress, err = r.readString() + fmt.Println(packet.ServerAddress) + if err != nil { + return err + } + + port, err := r.readByte() + if err != nil { + return err + } + packet.ServerPort = uint16(port) + + packet.NextState, err = r.readVarInt() + fmt.Println("Next state: ", packet.NextState) + if err != nil { + return err + } + return nil }