diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 464cc13f..d5517af5 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -408,15 +408,15 @@ jobs: RAW=$(printf "%s" "$RAW" | sed -n '$p') # Validate RAW is JSON + # Validate JSON; do NOT exit — allow retry if ! printf '%s' "$RAW" | python3 -c 'import sys,json; json.load(sys.stdin)' >/dev/null 2>&1; then - echo "Token endpoint did not return valid JSON:" - printf '%s\n' "$RAW" - exit 1 + echo "Token endpoint did not return valid JSON, retrying..." + TOKEN="" + else + # Extract token only if JSON is valid + TOKEN=$(printf '%s' "$RAW" | python3 -c 'import sys,json; print(json.load(sys.stdin).get("access_token", ""))') fi - # Extract token (without printing it) - TOKEN=$(printf '%s' "$RAW" | python3 -c 'import sys,json; print(json.load(sys.stdin).get("access_token", ""))') - if [ -n "$TOKEN" ] && [ "$TOKEN" != "null" ]; then echo "Access token retrieved successfully." break diff --git a/setup.py b/setup.py index 7b5a7fe9..9bfe1e8f 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ # version should use the format 'x.x.x' (instead of 'vx.x.x') setup( name='vertica-python', - version='1.4.0', + version='1.5.0', description='Official native Python client for the Vertica database.', long_description="vertica-python is the official Vertica database client for the Python programming language. Please check the [project homepage](https://github.com/vertica/vertica-python) for the details.", long_description_content_type='text/markdown', @@ -59,6 +59,7 @@ python_requires=">=3.8", install_requires=[ 'python-dateutil>=1.5', + 'pyotp>=2.9.0', ], classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/vertica_python/tests/integration_tests/test_authentication.py b/vertica_python/tests/integration_tests/test_authentication.py index 6080480c..85503b54 100644 --- a/vertica_python/tests/integration_tests/test_authentication.py +++ b/vertica_python/tests/integration_tests/test_authentication.py @@ -123,6 +123,121 @@ def test_oauth_access_token(self): cur.execute("SELECT authentication_method FROM sessions WHERE session_id=(SELECT current_session())") res = cur.fetchone() self.assertEqual(res[0], 'OAuth') + # ------------------------------- + # TOTP Authentication Test for Vertica-Python Driver + # ------------------------------- + import os + import pyotp + from io import StringIO + import sys -exec(AuthenticationTestCase.createPrepStmtClass()) + # Positive TOTP Test (Like SHA512 format) + def totp_positive_scenario(self): + with self._connect() as conn: + cur = conn.cursor() + + cur.execute("DROP USER IF EXISTS totp_user") + cur.execute("DROP AUTHENTICATION IF EXISTS totp_auth CASCADE") + + try: + # Create user with MFA + cur.execute("CREATE USER totp_user IDENTIFIED BY 'password' ENFORCEMFA") + + # Grant authentication + # Note: METHOD is 'trusted' or 'password' depending on how MFA is enforced in Vertica + cur.execute("CREATE AUTHENTICATION totp_auth METHOD 'password' HOST '0.0.0.0/0'") + cur.execute("GRANT AUTHENTICATION totp_auth TO totp_user") + + # Generate TOTP + TOTP_SECRET = "O5D7DQICJTM34AZROWHSAO4O53ELRJN3" + totp_code = pyotp.TOTP(TOTP_SECRET).now() + + # Set connection info + self._conn_info['user'] = 'totp_user' + self._conn_info['password'] = 'password' + self._conn_info['totp'] = totp_code + + # Try connection + with self._connect() as totp_conn: + c = totp_conn.cursor() + c.execute("SELECT 1") + res = c.fetchone() + self.assertEqual(res[0], 1) + + finally: + cur.execute("DROP USER IF EXISTS totp_user") + cur.execute("DROP AUTHENTICATION IF EXISTS totp_auth CASCADE") + + # Negative Test: Missing TOTP + def totp_missing_code_scenario(self): + with self._connect() as conn: + cur = conn.cursor() + + cur.execute("DROP USER IF EXISTS totp_user") + cur.execute("DROP AUTHENTICATION IF EXISTS totp_auth CASCADE") + + try: + cur.execute("CREATE USER totp_user IDENTIFIED BY 'password' ENFORCEMFA") + cur.execute("CREATE AUTHENTICATION totp_auth METHOD 'password' HOST '0.0.0.0/0'") + cur.execute("GRANT AUTHENTICATION totp_auth TO totp_user") + + self._conn_info['user'] = 'totp_user' + self._conn_info['password'] = 'password' + self._conn_info.pop('totp', None) # No TOTP + + err_msg = "TOTP was requested but not provided" + self.assertConnectionFail(err_msg=err_msg) + + finally: + cur.execute("DROP USER IF EXISTS totp_user") + cur.execute("DROP AUTHENTICATION IF EXISTS totp_auth CASCADE") + + # Negative Test: Invalid TOTP Format + def totp_invalid_format_scenario(self): + with self._connect() as conn: + cur = conn.cursor() + + cur.execute("DROP USER IF EXISTS totp_user") + cur.execute("DROP AUTHENTICATION IF EXISTS totp_auth CASCADE") + + try: + cur.execute("CREATE USER totp_user IDENTIFIED BY 'password' ENFORCEMFA") + cur.execute("CREATE AUTHENTICATION totp_auth METHOD 'password' HOST '0.0.0.0/0'") + cur.execute("GRANT AUTHENTICATION totp_auth TO totp_user") + + self._conn_info['user'] = 'totp_user' + self._conn_info['password'] = 'password' + self._conn_info['totp'] = "123" # Invalid + + err_msg = "Invalid TOTP format" + self.assertConnectionFail(err_msg=err_msg) + + finally: + cur.execute("DROP USER IF EXISTS totp_user") + cur.execute("DROP AUTHENTICATION IF EXISTS totp_auth CASCADE") + + # Negative Test: Wrong TOTP (Valid format, wrong value) + def totp_wrong_code_scenario(self): + with self._connect() as conn: + cur = conn.cursor() + + cur.execute("DROP USER IF EXISTS totp_user") + cur.execute("DROP AUTHENTICATION IF EXISTS totp_auth CASCADE") + + try: + cur.execute("CREATE USER totp_user IDENTIFIED BY 'password' ENFORCEMFA") + cur.execute("CREATE AUTHENTICATION totp_auth METHOD 'password' HOST '0.0.0.0/0'") + cur.execute("GRANT AUTHENTICATION totp_auth TO totp_user") + + self._conn_info['user'] = 'totp_user' + self._conn_info['password'] = 'password' + self._conn_info['totp'] = "999999" # Wrong OTP + + err_msg = "Invalid TOTP" + self.assertConnectionFail(err_msg=err_msg) + + finally: + cur.execute("DROP USER IF EXISTS totp_user") + cur.execute("DROP AUTHENTICATION IF EXISTS totp_auth CASCADE") + diff --git a/vertica_python/vertica/connection.py b/vertica_python/vertica/connection.py index 94159502..0d0e6a54 100644 --- a/vertica_python/vertica/connection.py +++ b/vertica_python/vertica/connection.py @@ -44,6 +44,11 @@ import ssl import uuid import warnings +import re +import time +import signal +import select +import sys from collections import deque from struct import unpack @@ -303,6 +308,13 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: self.address_list = _AddressList(self.options['host'], self.options['port'], self.options['backup_server_node'], self._logger) + # TOTP support + self.totp = self.options.get('totp') + if self.totp is not None: + if not isinstance(self.totp, str): + raise TypeError('The value of connection option "totp" should be a string') + self._logger.info('TOTP received in connection options') + # OAuth authentication setup self.options.setdefault('oauth_access_token', DEFAULT_OAUTH_ACCESS_TOKEN) if not isinstance(self.options['oauth_access_token'], str): @@ -918,16 +930,112 @@ def startup_connection(self) -> None: else: auth_category = '' - self.write(messages.Startup(user, database, session_label, os_user_name, autocommit, binary_transfer, - request_complex_types, oauth_access_token, workload, auth_category)) + # Check if user has provided TOTP in options + totp = self.options.get("totp", None) + retried_totp = False + + def send_startup(totp_value=None): + self.write(messages.Startup( + user, database, session_label, os_user_name, + autocommit, binary_transfer, request_complex_types, + oauth_access_token, workload, auth_category, + totp_value + )) + + send_startup(totp_value=totp) # ✅ First attempt while True: message = self.read_message() - + self._logger.debug(f"Received message: {type(message).__name__}") + self._logger.debug(f"Message code: {getattr(message, 'code', None)}") if isinstance(message, messages.Authentication): if message.code == messages.Authentication.OK: self._logger.info("User {} successfully authenticated" .format(self.options['user'])) + # 🔁 Continue reading messages after successful authentication + while True: + message = self.read_message() + self._logger.debug(f"Post-auth message: {type(message).__name__}") + if isinstance(message, messages.ReadyForQuery): + self.transaction_status = message.transaction_status + # self.session_id = message.session_id + self._logger.info("Connection is ready") + break + elif isinstance(message, messages.ParameterStatus): + self.parameters[message.key] = message.value + elif isinstance(message, messages.BackendKeyData): + self.backend_pid = message.pid + self.backend_key = message.key + elif isinstance(message, messages.ErrorResponse): + error_msg = message.error_message() + + # Extract only the "Message: ..." part + match = re.search(r'Message: (.+?)(?:, Sqlstate|$)', error_msg, re.DOTALL) + short_msg = match.group(1).strip() if match else error_msg.strip() + + if "Invalid TOTP" in short_msg: + print("Authentication failed: Invalid TOTP token.") + self._logger.error("Authentication failed: Invalid TOTP token.") + self.close_socket() + raise errors.ConnectionError("Authentication failed: Invalid TOTP token.") + + # Generic error fallback + print(f"Authentication failed: {short_msg}") + self._logger.error(short_msg) + raise errors.ConnectionError(f"Authentication failed: {short_msg}") + else: + self._logger.warning(f"Unexpected message type: {type(message).__name__}") + + break + elif message.code == messages.Authentication.TOTP: + if retried_totp: + raise errors.ConnectionError("TOTP authentication failed.") + + # ✅ If TOTP not provided initially, prompt only once + if not totp: + timeout_seconds = 30 # 5 minutes timeout + try: + print("Enter TOTP: ", end="", flush=True) + ready, _, _ = select.select([sys.stdin], [], [], timeout_seconds) + if ready: + totp_input = sys.stdin.readline().strip() + + # ❌ Blank TOTP entered + if not totp_input: + self._logger.error("Invalid TOTP: Cannot be empty.") + raise errors.ConnectionError("Invalid TOTP: Cannot be empty.") + + # ❌ Validate TOTP format (must be 6 digits) + if not totp_input.isdigit() or len(totp_input) != 6: + print("Invalid TOTP format. Please enter a 6-digit code.") + self._logger.error("Invalid TOTP format entered.") + raise errors.ConnectionError("Invalid TOTP format: Must be a 6-digit number.") + # ✅ Valid TOTP — retry connection + totp = totp_input + self.close_socket() + self.socket = self.establish_socket_connection(self.address_list) + self._logger.info(f"Retrying with TOTP: '{totp}'") + + # ✅ Re-init required attributes + self.backend_pid = 0 + self.backend_key = 0 + self.transaction_status = None + self.session_id = None + + self._logger.debug("Startup message sent with TOTP.") + send_startup(totp_value=totp) + + else: + self._logger.error("Session timeout: No TOTP entered within time limit.") + self.close_socket() + raise errors.ConnectionError("Session timeout: No TOTP entered within time limit.") + except (KeyboardInterrupt, EOFError): + raise errors.ConnectionError("TOTP input cancelled.") + else: + raise errors.ConnectionError("TOTP was requested but not provided.") + retried_totp = True + continue + elif message.code == messages.Authentication.CHANGE_PASSWORD: msg = "The password for user {} has expired".format(self.options['user']) self._logger.error(msg)