diff --git a/pkg/unikontainers/ipc.go b/pkg/unikontainers/ipc.go index 86d54e4d..b48957a3 100644 --- a/pkg/unikontainers/ipc.go +++ b/pkg/unikontainers/ipc.go @@ -17,6 +17,7 @@ package unikontainers import ( "errors" "fmt" + "io" "io/fs" "net" "os" @@ -81,9 +82,21 @@ func SockAddrExists(sockAddr string) bool { // SendIPCMessage creates a new connection to socketAddress, sends the message and closes the connection func SendIPCMessage(socketAddress string, message IPCMessage) error { - conn, err := net.Dial("unix", socketAddress) + var conn net.Conn + var err error + + // FIX #405: Backoff retry loop to handle IPC socket race conditions during reexec + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + conn, err = net.DialTimeout("unix", socketAddress, 100*time.Millisecond) + if err == nil { + break + } + time.Sleep(10 * time.Millisecond) + } + if err != nil { - return err + return fmt.Errorf("timeout waiting for ipc socket %s: %w", socketAddress, err) } defer conn.Close() @@ -145,27 +158,51 @@ func createListener(socketAddress string, mustBeValid bool) (*net.UnixListener, return listener, nil } -// awaitMessage opens a new connection to socketAddress -// and waits for a given message +// AwaitMessage accepts a connection from the listener and waits for the +// expected IPC message. It implements a 10-second timeout to prevent +// the process from blocking indefinitely (Fixes #405) func AwaitMessage(listener *net.UnixListener, expectedMessage IPCMessage) error { + timeout := 10 * time.Second + deadline := time.Now().Add(timeout) + + // Set deadline for the initial connection (Accept) + if err := listener.SetDeadline(deadline); err != nil { + return fmt.Errorf("failed to set listener deadline: %w", err) + } + conn, err := listener.AcceptUnix() if err != nil { - return err + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return fmt.Errorf("IPC handshake timeout: no connection received within %v", timeout) + } + return fmt.Errorf("failed to accept IPC connection: %w", err) } defer func() { - err = conn.Close() - if err != nil { - logrus.WithError(err).Error("failed to close connection") + if cerr := conn.Close(); cerr != nil { + logrus.WithError(cerr).Error("failed to close IPC connection") } }() + + // Set deadline for the actual data transfer (Read) + if err := conn.SetDeadline(deadline); err != nil { + return fmt.Errorf("failed to set connection deadline: %w", err) + } + + // io.ReadFull ensures we don't return early with a partial message buf := make([]byte, len(expectedMessage)) - n, err := conn.Read(buf) - if err != nil { - return fmt.Errorf("failed to read from socket: %w", err) + if _, err := io.ReadFull(conn, buf); err != nil { + if errors.Is(err, io.ErrUnexpectedEOF) { + return fmt.Errorf("connection closed before full message was received") + } + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return fmt.Errorf("IPC handshake timeout: message not received within %v", timeout) + } + return fmt.Errorf("failed to read from IPC socket: %w", err) } - msg := string(buf[0:n]) - if msg != string(expectedMessage) { - return fmt.Errorf("received unexpected message: %s (expected %s)", msg, expectedMessage) + + if string(buf) != string(expectedMessage) { + return fmt.Errorf("received unexpected message: %q (expected %q)", string(buf), expectedMessage) } + return nil }