diff --git a/pkg/cmd/console.go b/pkg/cmd/console.go index 54a46e5..fc4a385 100644 --- a/pkg/cmd/console.go +++ b/pkg/cmd/console.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "github.com/eiannone/keyboard" "sync" @@ -44,7 +45,7 @@ func (l *LoggerQueue) Error(message string) { l.log("ERROR", message, "") } -func (l *LoggerQueue) processLogs(stopChan chan struct{}) { +func (l *LoggerQueue) processLogs(ctx context.Context) { ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() @@ -52,9 +53,8 @@ func (l *LoggerQueue) processLogs(stopChan chan struct{}) { select { case logEntry := <-l.queue: displayLog(logEntry) - case <-stopChan: + case <-ctx.Done(): fmt.Print("\033[2K\r") - println("Stopping") return case <-ticker.C: refreshInputDisplay(l.currentInput) @@ -71,53 +71,62 @@ func refreshInputDisplay(input string) { fmt.Print("\033[2K\r> " + input) } -func Start(wg *sync.WaitGroup, stopChan chan struct{}) { - defer wg.Done() - - if err := keyboard.Open(); err != nil { - panic(err) - } - defer func() { - go keyboard.Close() - }() - - go Logger.processLogs(stopChan) +func Run(baseCtx context.Context) { + ctx, cancel := context.WithCancel(baseCtx) + defer cancel() input := "" - for { - char, key, err := keyboard.GetKey() - if err != nil { - Logger.Error(fmt.Sprintf("Keyboard error: %v", err)) - break - } - - switch { - case key == keyboard.KeyEnter: - switch input { - case "": - // Do nothing - case "stop": - Logger.Info("Received stop command.") - close(stopChan) - return - default: - Logger.Info(fmt.Sprintf("Unknown command: %s", input)) - - } - input = "" - case key == keyboard.KeyBackspace || key == keyboard.KeyBackspace2: - if len(input) > 0 { - input = input[:len(input)-1] - } - case key == keyboard.KeySpace: - input += " " - case key == keyboard.KeyCtrlC: - Logger.Info("Received stop command.") - close(stopChan) - return - case char != 0: - input += string(char) - } - Logger.currentInput = input // Update stored user input + event, err := keyboard.GetKeys(10) + if err != nil { + Logger.Error(fmt.Sprintf("Keyboard error: %v", err)) + return } + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + + Logger.processLogs(ctx) + }() +loop: + for { + select { + case <-ctx.Done(): + cancel() + break loop + case event := <-event: + if event.Rune != 0 { + input += string(event.Rune) + } + switch event.Key { + case keyboard.KeyEnter: + switch input { + case "": + // Do nothing + case "stop": + Logger.Info("Received stop command.") + cancel() + break loop + default: + Logger.Info(fmt.Sprintf("Unknown command: %s", input)) + } + input = "" + case keyboard.KeyBackspace | keyboard.KeyBackspace2: + if len(input) > 0 { + input = input[:len(input)-1] + } + case keyboard.KeySpace: + input += " " + case keyboard.KeyCtrlC: + Logger.Info("Received stop command.") + cancel() + break loop + } + Logger.currentInput = input // Update stored user input + } + } + + wg.Wait() } diff --git a/pkg/net/network.go b/pkg/net/network.go index a0b871b..9663f7f 100644 --- a/pkg/net/network.go +++ b/pkg/net/network.go @@ -2,39 +2,40 @@ package net import ( "cimeyclust.com/steel/pkg/cmd" + "context" "fmt" "net" - "sync" "time" ) // Start Starts the TCP server on the specified address. -func Start(addr string, stopChan <-chan struct{}, wg *sync.WaitGroup) { - defer wg.Done() +func Run(baseCtx context.Context, addr string) { + ctx, cancel := context.WithCancel(baseCtx) + defer cancel() // Start listening on the specified address listener, err := net.Listen("tcp", addr) if err != nil { cmd.Logger.Error("Error starting TCP server: %v") + return } // Close the listener when the application closes. - defer func(listener net.Listener) { - err := listener.Close() - if err != nil { - return - } - }(listener) + defer listener.Close() - // slog.Info(fmt.Sprintf("Listening on %s", addr)) cmd.Logger.Info(fmt.Sprintf("Listening on %s", addr)) + // var wg sync.WaitGroup + go func() { + go func() { + <-ctx.Done() + }() for { // Print test every second select { - case <-stopChan: - return // Safe exit if stop has been signalled + case <-ctx.Done(): + return default: cmd.Logger.Info("Test") } @@ -43,13 +44,18 @@ func Start(addr string, stopChan <-chan struct{}, wg *sync.WaitGroup) { }() // Listen for an incoming connection in a goroutine. + // wg.Add(1) go func() { + go func() { + <-ctx.Done() + listener.Close() + }() for { conn, err := listener.Accept() if err != nil { select { - case <-stopChan: - return // Safe exit if stop has been signalled + case <-ctx.Done(): + return default: cmd.Logger.Error(fmt.Sprintf("Error while accepting conn: %v", err)) } @@ -61,8 +67,8 @@ func Start(addr string, stopChan <-chan struct{}, wg *sync.WaitGroup) { } }() - // Block until we receive a stop signal - <-stopChan + // Wait for all components to finish + <-ctx.Done() } // handleRequest handles incoming requests. diff --git a/steel.go b/steel.go index 5daecb2..576a970 100644 --- a/steel.go +++ b/steel.go @@ -3,21 +3,34 @@ package main import ( "cimeyclust.com/steel/pkg/cmd" "cimeyclust.com/steel/pkg/net" + "context" "sync" ) func main() { // Channel for stopping the program - stopChan := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var wg sync.WaitGroup // Start the console handler wg.Add(1) - go cmd.Start(&wg, stopChan) + go func() { + defer wg.Done() + + defer cancel() + + cmd.Run(ctx) + }() // Start the network server wg.Add(1) - go net.Start("localhost:8080", stopChan, &wg) + go func() { + defer wg.Done() + + net.Run(ctx, "localhost:8080") + }() cmd.Logger.Info("Started")