From 2ed91a30f742975dc2b7a8b209ac9af898d5d71f Mon Sep 17 00:00:00 2001 From: Nas Date: Tue, 31 Mar 2026 07:57:51 +1100 Subject: [PATCH 1/3] Quote postgres column identifiers to handle reserved keywords --- .../implementations/_postgres_helpers.py | 39 +++++++- .../implementations/postgresql_loader.py | 7 +- tests/unit/test_postgres_helpers.py | 91 +++++++++++++++++++ 3 files changed, 131 insertions(+), 6 deletions(-) create mode 100644 tests/unit/test_postgres_helpers.py diff --git a/src/amp/loaders/implementations/_postgres_helpers.py b/src/amp/loaders/implementations/_postgres_helpers.py index eb0f71e..a0d4157 100644 --- a/src/amp/loaders/implementations/_postgres_helpers.py +++ b/src/amp/loaders/implementations/_postgres_helpers.py @@ -7,6 +7,31 @@ from pyarrow import csv +def _quote_identifier(name: str) -> str: + """Return a safely double-quoted PostgreSQL identifier. + + Wraps the name in double quotes and escapes any embedded double quotes + by doubling them (standard SQL identifier quoting rules). This prevents + syntax errors when column or table names collide with PostgreSQL reserved + keywords such as `to`, `end`, `index`, etc. + + Args: + name: Raw identifier (column or table name) + + Returns: + Double-quoted identifier safe for use in SQL DDL/DML statements + + Examples: + >>> _quote_identifier('to') + '"to"' + >>> _quote_identifier('block_num') + '"block_num"' + >>> _quote_identifier('weird"name') + '"weird\"\"name"' + """ + return '"' + name.replace('"', '""') + '"' + + def prepare_csv_data(data: Union[pa.RecordBatch, pa.Table]) -> Tuple[io.StringIO, List[str]]: """ Convert Arrow data to CSV format optimized for PostgreSQL COPY. @@ -43,7 +68,8 @@ def prepare_csv_data(data: Union[pa.RecordBatch, pa.Table]) -> Tuple[io.StringIO csv_buffer = io.StringIO(csv_data) - # Get column names from Arrow schema + # Get column names from Arrow schema (raw, unquoted — quoting is the + # caller's responsibility when constructing SQL identifiers) column_names = [field.name for field in data.schema] return csv_buffer, column_names @@ -103,12 +129,14 @@ def prepare_insert_data(data: Union[pa.RecordBatch, pa.Table]) -> Tuple[str, Lis # Convert Arrow data to Python objects data_dict = data.to_pydict() - # Get column names + # Get column names and quote each one so reserved keywords (e.g. `to`, + # `end`, `index`) do not cause syntax errors in the INSERT statement. column_names = [field.name for field in data.schema] + quoted_columns = [_quote_identifier(c) for c in column_names] # Create INSERT statement template placeholders = ', '.join(['%s'] * len(column_names)) - insert_sql = f'({", ".join(column_names)}) VALUES ({placeholders})' + insert_sql = f'({", ".join(quoted_columns)}) VALUES ({placeholders})' # Prepare data for insertion rows = [] @@ -123,6 +151,11 @@ def prepare_insert_data(data: Union[pa.RecordBatch, pa.Table]) -> Tuple[str, Lis return insert_sql, rows +def quote_column_names(schema: pa.Schema) -> List[str]: + """Return a list of double-quoted identifiers for all columns in schema.""" + return [_quote_identifier(field.name) for field in schema] + + def has_binary_columns(schema: pa.Schema) -> bool: """Check if schema contains any binary column types.""" return any( diff --git a/src/amp/loaders/implementations/postgresql_loader.py b/src/amp/loaders/implementations/postgresql_loader.py index 2591b52..b249d6c 100644 --- a/src/amp/loaders/implementations/postgresql_loader.py +++ b/src/amp/loaders/implementations/postgresql_loader.py @@ -7,7 +7,7 @@ from ...streaming.state import BatchIdentifier from ...streaming.types import BlockRange from ..base import DataLoader, LoadMode -from ._postgres_helpers import has_binary_columns, prepare_csv_data, prepare_insert_data +from ._postgres_helpers import has_binary_columns, prepare_csv_data, prepare_insert_data, quote_column_names @dataclass @@ -215,10 +215,11 @@ def _copy_arrow_data(self, cursor: Any, data: Union[pa.RecordBatch, pa.Table], t def _csv_copy_arrow_data(self, cursor: Any, data: Union[pa.RecordBatch, pa.Table], table_name: str) -> None: """Use CSV COPY for non-binary data.""" - csv_buffer, column_names = prepare_csv_data(data) + csv_buffer, _ = prepare_csv_data(data) + quoted_column_names = quote_column_names(data.schema) try: - cursor.copy_from(csv_buffer, table_name, columns=column_names, sep='\t', null='\\N') + cursor.copy_from(csv_buffer, table_name, columns=quoted_column_names, sep='\t', null='\\N') except Exception as e: if 'does not exist' in str(e): raise RuntimeError( diff --git a/tests/unit/test_postgres_helpers.py b/tests/unit/test_postgres_helpers.py new file mode 100644 index 0000000..eac467f --- /dev/null +++ b/tests/unit/test_postgres_helpers.py @@ -0,0 +1,91 @@ +""" +Unit tests for PostgreSQL SQL identifier quoting in _postgres_helpers.py. + +Verifies that reserved keyword column names (e.g. 'to', 'from') are double-quoted in +generated INSERT and COPY SQL to prevent syntax errors. +""" + +import pytest +import pyarrow as pa + +from amp.loaders.implementations._postgres_helpers import ( + prepare_insert_data, + quote_column_names, +) + + +@pytest.fixture +def eth_tx_batch(): + """Arrow RecordBatch modelling Ethereum transaction data with reserved keyword column names.""" + schema = pa.schema( + [ + pa.field('block_hash', pa.binary(32), nullable=False), + pa.field('block_num', pa.uint64(), nullable=False), + pa.field('tx_index', pa.uint32(), nullable=False), + pa.field('tx_hash', pa.binary(32), nullable=False), + pa.field('to', pa.binary(20), nullable=True), # reserved keyword; nullable for contract creation + pa.field('nonce', pa.uint64(), nullable=False), + pa.field('value', pa.decimal128(38, 0), nullable=False), + pa.field('from', pa.binary(20), nullable=False), # reserved keyword + ] + ) + return pa.RecordBatch.from_arrays( + [ + pa.array([b'\x01' * 32, b'\x02' * 32], type=pa.binary(32)), + pa.array([18_000_000, 18_000_001], type=pa.uint64()), + pa.array([0, 1], type=pa.uint32()), + pa.array([b'\x03' * 32, b'\x04' * 32], type=pa.binary(32)), + pa.array([b'\xaa' * 20, None], type=pa.binary(20)), # None = contract creation tx + pa.array([0, 1], type=pa.uint64()), + pa.array([1_000_000_000, 2_000_000_000], type=pa.decimal128(38, 0)), + pa.array([b'\xbb' * 20, b'\xcc' * 20], type=pa.binary(20)), + ], + schema=schema, + ) + + + + +@pytest.mark.unit +class TestInsertSqlIdentifierQuoting: + """ + prepare_insert_data() must double-quote all column names in the generated + INSERT SQL template to prevent reserved-keyword syntax errors. + """ + + def test_all_column_names_are_quoted_in_insert_sql(self, eth_tx_batch): + """Every column in the INSERT template must be wrapped in double quotes.""" + sql_template, _ = prepare_insert_data(eth_tx_batch) + + for col_name in eth_tx_batch.schema.names: + assert f'"{col_name}"' in sql_template, ( + f"Column '{col_name}' must be double-quoted in the INSERT SQL template.\n" + f"Generated template: {sql_template}" + ) + + def test_placeholder_count_matches_column_count(self, eth_tx_batch): + """The VALUES clause must have exactly one %s placeholder per column.""" + sql_template, _ = prepare_insert_data(eth_tx_batch) + + assert sql_template.count('%s') == len(eth_tx_batch.schema) + + def test_row_count_preserved(self, eth_tx_batch): + """The returned rows list must contain one tuple per input row.""" + _, rows = prepare_insert_data(eth_tx_batch) + + assert len(rows) == eth_tx_batch.num_rows + + +@pytest.mark.unit +class TestQuoteColumnNames: + """quote_column_names() must double-quote every column name in a schema.""" + + def test_all_columns_quoted(self): + schema = pa.schema([('block_num', pa.int64()), ('to', pa.string()), ('value', pa.string())]) + result = quote_column_names(schema) + assert result == ['"block_num"', '"to"', '"value"'] + + def test_embedded_double_quote_escaped(self): + schema = pa.schema([('weird"name', pa.string())]) + result = quote_column_names(schema) + assert result == ['"weird""name"'] From 3a387f177e915603fbd437394bd3c9fcb836fc6b Mon Sep 17 00:00:00 2001 From: Nas Date: Tue, 31 Mar 2026 08:17:10 +1100 Subject: [PATCH 2/3] Format and linting --- tests/unit/test_postgres_helpers.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/tests/unit/test_postgres_helpers.py b/tests/unit/test_postgres_helpers.py index eac467f..48dc887 100644 --- a/tests/unit/test_postgres_helpers.py +++ b/tests/unit/test_postgres_helpers.py @@ -5,8 +5,8 @@ generated INSERT and COPY SQL to prevent syntax errors. """ -import pytest import pyarrow as pa +import pytest from amp.loaders.implementations._postgres_helpers import ( prepare_insert_data, @@ -20,23 +20,23 @@ def eth_tx_batch(): schema = pa.schema( [ pa.field('block_hash', pa.binary(32), nullable=False), - pa.field('block_num', pa.uint64(), nullable=False), - pa.field('tx_index', pa.uint32(), nullable=False), - pa.field('tx_hash', pa.binary(32), nullable=False), - pa.field('to', pa.binary(20), nullable=True), # reserved keyword; nullable for contract creation - pa.field('nonce', pa.uint64(), nullable=False), - pa.field('value', pa.decimal128(38, 0), nullable=False), - pa.field('from', pa.binary(20), nullable=False), # reserved keyword + pa.field('block_num', pa.uint64(), nullable=False), + pa.field('tx_index', pa.uint32(), nullable=False), + pa.field('tx_hash', pa.binary(32), nullable=False), + pa.field('to', pa.binary(20), nullable=True), # reserved keyword; nullable for contract creation + pa.field('nonce', pa.uint64(), nullable=False), + pa.field('value', pa.decimal128(38, 0), nullable=False), + pa.field('from', pa.binary(20), nullable=False), # reserved keyword ] ) return pa.RecordBatch.from_arrays( [ pa.array([b'\x01' * 32, b'\x02' * 32], type=pa.binary(32)), - pa.array([18_000_000, 18_000_001], type=pa.uint64()), - pa.array([0, 1], type=pa.uint32()), + pa.array([18_000_000, 18_000_001], type=pa.uint64()), + pa.array([0, 1], type=pa.uint32()), pa.array([b'\x03' * 32, b'\x04' * 32], type=pa.binary(32)), - pa.array([b'\xaa' * 20, None], type=pa.binary(20)), # None = contract creation tx - pa.array([0, 1], type=pa.uint64()), + pa.array([b'\xaa' * 20, None], type=pa.binary(20)), # None = contract creation tx + pa.array([0, 1], type=pa.uint64()), pa.array([1_000_000_000, 2_000_000_000], type=pa.decimal128(38, 0)), pa.array([b'\xbb' * 20, b'\xcc' * 20], type=pa.binary(20)), ], @@ -44,8 +44,6 @@ def eth_tx_batch(): ) - - @pytest.mark.unit class TestInsertSqlIdentifierQuoting: """ @@ -60,7 +58,7 @@ def test_all_column_names_are_quoted_in_insert_sql(self, eth_tx_batch): for col_name in eth_tx_batch.schema.names: assert f'"{col_name}"' in sql_template, ( f"Column '{col_name}' must be double-quoted in the INSERT SQL template.\n" - f"Generated template: {sql_template}" + f'Generated template: {sql_template}' ) def test_placeholder_count_matches_column_count(self, eth_tx_batch): From 1004d20d62c39dec49c8e57312da65428f9a39ba Mon Sep 17 00:00:00 2001 From: Nas Date: Tue, 31 Mar 2026 09:55:02 +1100 Subject: [PATCH 3/3] psycopg2's copy_from columns param already handles quoting --- .../implementations/_postgres_helpers.py | 5 ----- .../implementations/postgresql_loader.py | 7 +++---- tests/unit/test_postgres_helpers.py | 20 +------------------ 3 files changed, 4 insertions(+), 28 deletions(-) diff --git a/src/amp/loaders/implementations/_postgres_helpers.py b/src/amp/loaders/implementations/_postgres_helpers.py index a0d4157..8b47635 100644 --- a/src/amp/loaders/implementations/_postgres_helpers.py +++ b/src/amp/loaders/implementations/_postgres_helpers.py @@ -151,11 +151,6 @@ def prepare_insert_data(data: Union[pa.RecordBatch, pa.Table]) -> Tuple[str, Lis return insert_sql, rows -def quote_column_names(schema: pa.Schema) -> List[str]: - """Return a list of double-quoted identifiers for all columns in schema.""" - return [_quote_identifier(field.name) for field in schema] - - def has_binary_columns(schema: pa.Schema) -> bool: """Check if schema contains any binary column types.""" return any( diff --git a/src/amp/loaders/implementations/postgresql_loader.py b/src/amp/loaders/implementations/postgresql_loader.py index b249d6c..2591b52 100644 --- a/src/amp/loaders/implementations/postgresql_loader.py +++ b/src/amp/loaders/implementations/postgresql_loader.py @@ -7,7 +7,7 @@ from ...streaming.state import BatchIdentifier from ...streaming.types import BlockRange from ..base import DataLoader, LoadMode -from ._postgres_helpers import has_binary_columns, prepare_csv_data, prepare_insert_data, quote_column_names +from ._postgres_helpers import has_binary_columns, prepare_csv_data, prepare_insert_data @dataclass @@ -215,11 +215,10 @@ def _copy_arrow_data(self, cursor: Any, data: Union[pa.RecordBatch, pa.Table], t def _csv_copy_arrow_data(self, cursor: Any, data: Union[pa.RecordBatch, pa.Table], table_name: str) -> None: """Use CSV COPY for non-binary data.""" - csv_buffer, _ = prepare_csv_data(data) - quoted_column_names = quote_column_names(data.schema) + csv_buffer, column_names = prepare_csv_data(data) try: - cursor.copy_from(csv_buffer, table_name, columns=quoted_column_names, sep='\t', null='\\N') + cursor.copy_from(csv_buffer, table_name, columns=column_names, sep='\t', null='\\N') except Exception as e: if 'does not exist' in str(e): raise RuntimeError( diff --git a/tests/unit/test_postgres_helpers.py b/tests/unit/test_postgres_helpers.py index 48dc887..7b27c33 100644 --- a/tests/unit/test_postgres_helpers.py +++ b/tests/unit/test_postgres_helpers.py @@ -8,10 +8,7 @@ import pyarrow as pa import pytest -from amp.loaders.implementations._postgres_helpers import ( - prepare_insert_data, - quote_column_names, -) +from amp.loaders.implementations._postgres_helpers import prepare_insert_data @pytest.fixture @@ -72,18 +69,3 @@ def test_row_count_preserved(self, eth_tx_batch): _, rows = prepare_insert_data(eth_tx_batch) assert len(rows) == eth_tx_batch.num_rows - - -@pytest.mark.unit -class TestQuoteColumnNames: - """quote_column_names() must double-quote every column name in a schema.""" - - def test_all_columns_quoted(self): - schema = pa.schema([('block_num', pa.int64()), ('to', pa.string()), ('value', pa.string())]) - result = quote_column_names(schema) - assert result == ['"block_num"', '"to"', '"value"'] - - def test_embedded_double_quote_escaped(self): - schema = pa.schema([('weird"name', pa.string())]) - result = quote_column_names(schema) - assert result == ['"weird""name"']