Skip to content
23 changes: 21 additions & 2 deletions ipykernel/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,29 @@
class Heartbeat(Thread):
"""A simple ping-pong style heartbeat that runs in a thread."""

def __init__(self, context, addr=None):
"""Initialize the heartbeat thread."""
def __init__(self, context, addr=None, *, curve_publickey=None, curve_secretkey=None):
"""Initialize the heartbeat thread.

Parameters
----------
context : zmq.Context
addr : tuple, optional
(transport, ip, port)
curve_publickey : bytes, optional
CurveZMQ public key (Z85). When provided together with
*curve_secretkey*, the heartbeat socket will operate as a
CurveZMQ server so that only authenticated clients can connect.
curve_secretkey : bytes, optional
CurveZMQ secret key (Z85, paired with *curve_publickey*).
"""
if addr is None:
addr = ("tcp", localhost(), 0)
Thread.__init__(self, name="Heartbeat")
self.context = context
self.transport, self.ip, self.port = addr
self.original_port = self.port
self.curve_publickey = curve_publickey
self.curve_secretkey = curve_secretkey
Comment on lines +51 to +52
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also add types at class level for the new fields.

if self.original_port == 0:
self.pick_port()
self.addr = (self.ip, self.port)
Expand Down Expand Up @@ -94,6 +109,10 @@ def run(self):
self.name = "Heartbeat"
self.socket = self.context.socket(zmq.ROUTER)
self.socket.linger = 1000
if self.curve_secretkey is not None:
self.socket.curve_secretkey = self.curve_secretkey
self.socket.curve_publickey = self.curve_publickey
self.socket.curve_server = True
try:
self._bind_socket()
except Exception:
Expand Down
51 changes: 50 additions & 1 deletion ipykernel/kernelapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ def abs_connection_file(self):
""",
).tag(config=True)

enable_curve = Bool(
bool(int(os.environ.get("JUPYTER_ENABLE_CURVE", "0"))),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like the name it's non obvious, i think it's ok to have "curve" in inner variable names, maybe not in env-var and options. ; should we also trim(), the value of environ.get, or be stricter on it's format ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I think we could hide "curve" here as an implementation detail; an argument for keeping it as-is would be alignment with ipyparallel (ipython/ipyparallel#553).

I think an obvious choice would be enable_transport_encryption?

Same change would need to follow in jupyter-server/jupyter_server#1638.

CC @minrk in case if you have an opinion here

help="Enable CurveZMQ transport encryption and authentication. "
"When True, a keypair is generated at startup and stored in the "
"connection file so that clients can authenticate and encrypt "
"all ZMQ channels.",
).tag(config=True)

# Internal CurveZMQ keypair (Z85-encoded bytes); populated in init_sockets
# when enable_curve is True.
_curve_publickey: bytes | None = None
_curve_secretkey: bytes | None = None

# polling
parent_handle = Integer(
int(os.environ.get("JPY_PARENT_PID") or 0),
Expand All @@ -211,6 +224,17 @@ def excepthook(self, etype, evalue, tb):
# write uncaught traceback to 'real' stderr, not zmq-forwarder
traceback.print_exception(etype, evalue, tb, file=sys.__stderr__)

def _apply_curve_server_options(self, socket: zmq.Socket[t.Any]) -> None:
"""Set CurveZMQ server-side options on *socket* before it is bound.

This is a no-op when enable_curve is False or keys have not been
generated yet, so it is safe to call unconditionally.
"""
if self.enable_curve and self._curve_secretkey is not None:
socket.curve_secretkey = self._curve_secretkey
socket.curve_publickey = self._curve_publickey
socket.curve_server = True

def init_poller(self):
"""Initialize the poller."""
if sys.platform == "win32":
Expand Down Expand Up @@ -274,6 +298,10 @@ def write_connection_file(self, **kwargs: Any) -> None:
iopub_port=self.iopub_port,
control_port=self.control_port,
)
if self.enable_curve and self._curve_publickey is not None:
# write_connection_file() in jupyter-client handles JSON-safe key serialization
connection_info["curve_publickey"] = self._curve_publickey
connection_info["curve_secretkey"] = self._curve_secretkey
if Path(cf).exists():
# If the file exists, merge our info into it. For example, if the
# original file had port number 0, we update with the actual port
Expand Down Expand Up @@ -328,13 +356,27 @@ def init_sockets(self):
self.context = context = zmq.Context()
atexit.register(self.close)

if self.enable_curve:
self._curve_publickey, self._curve_secretkey = zmq.curve_keypair()
self.log.debug("CurveZMQ enabled; generated server keypair")
Comment thread
krassowski marked this conversation as resolved.
elif self.transport == "tcp":
self.log.warning(
"Kernel is running over TCP without encryption."
" All communication (including code and outputs) is sent in plain text"
" and is susceptible to eavesdropping."
" Use IPC transport or set IPKernelApp.enable_curve=True to enable"
" CurveZMQ encryption."
)
Comment on lines +363 to +369
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nbclient downstream tests are failing due to addition of this warning, see:

  E       AssertionError: assert '[IPKernelApp] WARNING | Kernel is running over TCP without encryption. All communication (including code and outputs) is sent in plain text and is susceptible to eavesdropping. Use IPC transport or set IPKernelApp.enable_curve=True to enable CurveZMQ encryption.\n[IPKernelApp] WARNING | Kernel is running over TCP without encryption. All communication (including code and outputs) is sent in plain text and is susceptible to eavesdropping. Use IPC transport or set IPKernelApp.enable_curve=True to enable CurveZMQ encryption.' == ''

I believe we should keep it and update nbclient tests, any objections?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to fix nbclient.


self.shell_socket = context.socket(zmq.ROUTER)
self.shell_socket.linger = 1000
self._apply_curve_server_options(self.shell_socket)
self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
self.log.debug("shell ROUTER Channel on port: %i", self.shell_port)

self.stdin_socket = context.socket(zmq.ROUTER)
self.stdin_socket.linger = 1000
self._apply_curve_server_options(self.stdin_socket)
self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
self.log.debug("stdin ROUTER Channel on port: %i", self.stdin_port)

Expand All @@ -351,6 +393,7 @@ def init_control(self, context):
"""Initialize the control channel."""
self.control_socket = context.socket(zmq.ROUTER)
self.control_socket.linger = 1000
self._apply_curve_server_options(self.control_socket)
self.control_port = self._bind_socket(self.control_socket, self.control_port)
self.log.debug("control ROUTER Channel on port: %i", self.control_port)

Expand Down Expand Up @@ -379,6 +422,7 @@ def init_iopub(self, context):
"""Initialize the iopub channel."""
self.iopub_socket = context.socket(zmq.XPUB)
self.iopub_socket.linger = 1000
self._apply_curve_server_options(self.iopub_socket)
self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
self.log.debug("iopub PUB Channel on port: %i", self.iopub_port)
self.configure_tornado_logger()
Expand All @@ -392,7 +436,12 @@ def init_heartbeat(self):
# heartbeat doesn't share context, because it mustn't be blocked
# by the GIL, which is accessed by libzmq when freeing zero-copy messages
hb_ctx = zmq.Context()
self.heartbeat = Heartbeat(hb_ctx, (self.transport, self.ip, self.hb_port))
self.heartbeat = Heartbeat(
hb_ctx,
(self.transport, self.ip, self.hb_port),
curve_publickey=self._curve_publickey if self.enable_curve else None,
curve_secretkey=self._curve_secretkey if self.enable_curve else None,
)
self.hb_port = self.heartbeat.port
self.log.debug("Heartbeat REP Channel on port: %i", self.hb_port)
self.heartbeat.start()
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = [
"ipython>=7.23.1",
"comm>=0.1.1",
"traitlets>=5.4.0",
"jupyter_client>=8.8.0",
"jupyter_client @ git+https://github.com/krassowski/jupyter_client.git@add-curve-encryption",
"jupyter_core>=5.1,!=6.0.*",
# For tk event loop support only.
"nest_asyncio2>=1.7.0",
Expand Down Expand Up @@ -71,6 +71,9 @@ cov = [
pyqt5 = ["pyqt5"]
pyside6 = ["pyside6"]

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.version]
path = "ipykernel/_version.py"

Expand Down
178 changes: 178 additions & 0 deletions tests/test_curve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.

import json
import os
import time

import pytest
import zmq

from ipykernel.kernelapp import IPKernelApp


@pytest.fixture
def temp_folder_path(tmp_path):
return str(tmp_path)
Comment on lines +14 to +16
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just fix _make_app? It's the only utils that will receive the result of this fixture ?



@pytest.fixture
def curve_disabled_kernel_app(temp_folder_path):
app, connection_file_path = _make_app(temp_folder_path, enable_curve=False)
try:
yield app, connection_file_path
finally:
app.close()


@pytest.fixture
def curve_enabled_kernel_app(temp_folder_path):
app, connection_file_path = _make_app(temp_folder_path, enable_curve=True)
try:
yield app, connection_file_path
finally:
app.close()


def test_curve_disabled_by_default():
"""CurveZMQ must be off by default so existing kernels are unaffected."""
app = IPKernelApp()
assert app.enable_curve is False


def test_connection_file_no_curve_keys_by_default(curve_disabled_kernel_app):
"""Connection file must not contain curve keys when Curve is disabled."""
app, connection_file_path = curve_disabled_kernel_app
app.init_sockets()
app.init_heartbeat()
app.write_connection_file()
with open(connection_file_path) as f:
info = json.load(f)
assert "curve_publickey" not in info
assert "curve_secretkey" not in info


def test_curve_connection_file_has_keys(curve_enabled_kernel_app):
"""When Curve is enabled the connection file must carry both keys."""
app, connection_file_path = curve_enabled_kernel_app
app.init_sockets()
app.init_heartbeat()
app.write_connection_file()
with open(connection_file_path) as f:
info = json.load(f)
assert "curve_publickey" in info, "curve_publickey missing from connection file"
assert "curve_secretkey" in info, "curve_secretkey missing from connection file"
# Keys are Z85-encoded ASCII strings - always exactly 40 characters.
assert len(info["curve_publickey"]) == 40
assert len(info["curve_secretkey"]) == 40
# Existing fields must still be present (backward-compat check).
assert "key" in info
assert "shell_port" in info


def test_curve_keys_are_stable_per_startup(curve_enabled_kernel_app):
"""Keys generated at startup stay the same throughout the process lifetime."""
app, _connection_file_path = curve_enabled_kernel_app
app.init_sockets()
pub1 = app._curve_publickey
# Writing the file twice should not regenerate keys.
app.init_heartbeat()
app.write_connection_file()
assert app._curve_publickey == pub1


def test_curve_socket_server_options(curve_enabled_kernel_app):
"""Bound sockets must have CURVE_SERVER=True when Curve is enabled."""
app, _connection_file_path = curve_enabled_kernel_app
app.init_sockets()
# shell and stdin are ROUTER sockets configured directly.
assert app.shell_socket.curve_server, "shell_socket missing curve_server"
assert app.stdin_socket.curve_server, "stdin_socket missing curve_server"
assert app.control_socket.curve_server, "control_socket missing curve_server"
# Key material is write-only in pyzmq; we verify it was applied
# through the curve_server flag and the reject test below.


def test_no_curve_socket_options_when_disabled(curve_disabled_kernel_app):
"""No CURVE options are set when Curve is disabled (default)."""
app, _connection_file_path = curve_disabled_kernel_app
app.init_sockets()
# curve_server defaults to 0/False; key options are write-only.
assert not app.shell_socket.curve_server


def test_curve_unauthenticated_socket_messages_dropped(curve_enabled_kernel_app):
"""With CurveZMQ, frames from a socket without the server key are dropped.

This is the core security property: a raw DEALER socket that connects to
a CURVE_SERVER-enabled ROUTER cannot deliver messages to it. Compare
with test_transport_security.py in jupyter-client which shows the *absence*
of this property today.
"""
app, _connection_file_path = curve_enabled_kernel_app
app.init_sockets()

# Build the endpoint URL from the bound port.
endpoint = f"tcp://{app.ip}:{app.shell_port}"

ctx = zmq.Context()
unauth = ctx.socket(zmq.DEALER)
try:
unauth.connect(endpoint)
# ZMQ delivers the connect synchronously, but the curve
# handshake silently drops the message.
unauth.send(b"probe", flags=zmq.NOBLOCK)

poller = zmq.Poller()
poller.register(app.shell_socket, zmq.POLLIN)
events = dict(poller.poll(timeout=300))
assert app.shell_socket not in events, (
"Unauthenticated message reached the kernel socket - CurveZMQ should have dropped it"
)
finally:
unauth.close(linger=0)
ctx.term()


def test_curve_authenticated_socket_can_communicate(curve_enabled_kernel_app):
"""With CurveZMQ, a correctly-keyed client socket can reach the kernel."""
app, _connection_file_path = curve_enabled_kernel_app
app.init_sockets()

endpoint = f"tcp://{app.ip}:{app.shell_port}"
server_public = app._curve_publickey

ctx = zmq.Context()
auth_client = ctx.socket(zmq.DEALER)
# Client uses the server's public key as CURVE_SERVERKEY; its own
# keypair is used only for encryption, not for access control.
client_pub, client_sec = zmq.curve_keypair()
auth_client.curve_secretkey = client_sec
auth_client.curve_publickey = client_pub
auth_client.curve_serverkey = server_public
try:
auth_client.connect(endpoint)
# Allow the handshake to complete.
time.sleep(0.05)
auth_client.send(b"probe", flags=zmq.NOBLOCK)

poller = zmq.Poller()
poller.register(app.shell_socket, zmq.POLLIN)
events = dict(poller.poll(timeout=1000))
assert app.shell_socket in events, (
"Authenticated client message was not received by kernel socket"
)
finally:
auth_client.close(linger=0)
ctx.term()


def _make_app(temp_folder_path, **kwargs):
"""Return a minimal IPKernelApp rooted in temporary directory *temp_folder_path*."""
connection_file_path = os.path.join(temp_folder_path, "kernel.json")
app = IPKernelApp(connection_file=connection_file_path, **kwargs)
# Replicate the subset of initialize() that sets up connection info
# without importing IPython shell machinery.
super(IPKernelApp, app).initialize(argv=[""])
app.init_connection_file()
return app, connection_file_path
24 changes: 24 additions & 0 deletions tests/test_kernelapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,27 @@ def test_trio_loop():
app.io_loop.add_callback(app.io_loop.stop)
app.kernel.destroy()
app.close()


def test_init_sockets_curve_enabled_logs_debug():
app = IPKernelApp(enable_curve=True)
with patch.object(app.log, "debug") as mock_debug:
app.init_sockets()
app.cleanup_connection_file()
app.close()
messages = [str(call) for call in mock_debug.call_args_list]
assert any("CurveZMQ enabled" in m for m in messages), (
"Expected a debug log mentioning CurveZMQ when enable_curve=True"
)


def test_init_sockets_tcp_without_curve_logs_warning():
app = IPKernelApp(transport="tcp", enable_curve=False)
with patch.object(app.log, "warning") as mock_warning:
app.init_sockets()
app.cleanup_connection_file()
app.close()
messages = [str(call) for call in mock_warning.call_args_list]
assert any("Kernel is running over TCP without encryption" in m for m in messages), (
"Expected a warning about missing encryption when transport=tcp and enable_curve=False"
)
Loading