diff --git a/main.go b/main.go index 950a289..5515d58 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,8 @@ import ( "log" "log/slog" "os" + "os/signal" + "syscall" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -700,15 +702,38 @@ func main() { registerAllTools(s, appInstance, debugLogger) - // Cleanup on exit - defer func() { + // Cleanup function + cleanup := func() { + debugLogger.Info("Starting cleanup...") if err := appInstance.Disconnect(); err != nil { debugLogger.Error("Failed to disconnect from database", "error", err) + } else { + debugLogger.Info("Database connection closed successfully") } + debugLogger.Info("Server shutdown complete") + } + defer cleanup() + + // Set up signal handling for graceful shutdown + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT) + + go func() { + sig := <-sigChan + debugLogger.Info("Received shutdown signal", "signal", sig.String()) + cancel() }() // Start the stdio server - if err := server.ServeStdio(s); err != nil { + debugLogger.Info("Starting PostgreSQL MCP Server", "version", version) + + // Create a custom StdioServer with context support + stdioServer := server.NewStdioServer(s) + if err := stdioServer.Listen(ctx, os.Stdin, os.Stdout); err != nil && err != context.Canceled { + debugLogger.Error("Server error", "error", err) fmt.Fprintf(os.Stderr, "Server error: %v\n", err) return }