diff --git a/.gitignore b/.gitignore index 7dde91a..b14f868 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ test_odbc.sqlite lars_notes.md .claude/ ODBC +/docs/superpowers/plans/ diff --git a/src/connection.rs b/src/connection.rs deleted file mode 100644 index a2fa15e..0000000 --- a/src/connection.rs +++ /dev/null @@ -1,2 +0,0 @@ -#[derive(Debug)] -pub struct ConnectionClass {} diff --git a/src/lib.rs b/src/lib.rs index 0c0d785..d55f412 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,16 @@ use odbc_sys::{Integer, WChar}; -mod connection; mod logging; mod odbc; +/// Registers the SQLite implementation with the driver's handle factory. +/// Called once at driver startup (from SQLAllocHandle for environment handles). +pub(crate) fn init_driver() { + odbc::handles::register_factory(Box::new( + odbc::implementation::query::SqliteDbConnectionFactory, + )); +} + // Cross-platform ODBC library linking configuration // These attributes handle the complexity of linking against different ODBC implementations // across Windows, Linux, and macOS with support for both static and dynamic linking. diff --git a/src/odbc.rs b/src/odbc.rs index 554c5e1..bcffb75 100644 --- a/src/odbc.rs +++ b/src/odbc.rs @@ -1,4 +1,5 @@ mod api; mod def; -mod implementation; +pub(crate) mod handles; +pub(crate) mod implementation; mod utils; diff --git a/src/odbc/api/sqlallochandle.rs b/src/odbc/api/sqlallochandle.rs index 9059cbf..614b3ec 100644 --- a/src/odbc/api/sqlallochandle.rs +++ b/src/odbc/api/sqlallochandle.rs @@ -1,8 +1,5 @@ use crate::logging; -use crate::odbc::implementation::alloc_handles::{ - ConnectionHandle, EnvironmentHandle, allocate_stmt_handle, impl_allocate_dbc_handle, - impl_allocate_environment_handle, -}; +use crate::odbc::handles::{ConnectionHandle, EnvironmentHandle}; use crate::odbc::utils::{get_from_wrapper, wrap_and_set}; use odbc_sys::{HandleType, Pointer, SmallInt, SqlReturn}; use tracing::{debug, error, info}; @@ -50,8 +47,10 @@ pub extern "C" fn SQLAllocHandle( return SqlReturn::ERROR; } - // Call the implementation and convert the output properly - let handle = impl_allocate_environment_handle(); + // Register the SQLite factory on first use. + crate::init_driver(); + + let handle = EnvironmentHandle::default(); wrap_and_set(handle_type, handle, output_handle); info!("Successfully allocated an environment handle"); @@ -79,7 +78,8 @@ pub extern "C" fn SQLAllocHandle( } }; - let handle = impl_allocate_dbc_handle(env_handle); + let handle = ConnectionHandle { connection: None }; + let _ = env_handle; // env_handle validated above; connection state lives in the handle wrap_and_set(handle_type, handle, output_handle); info!("Successfully allocated a Dbc handle"); @@ -97,7 +97,7 @@ pub extern "C" fn SQLAllocHandle( } }; - let handle = match allocate_stmt_handle(connection_handle) { + let handle = match connection_handle.allocate_stmt_handle() { Some(handle) => handle, None => { error!("Cannot allocate statement handle: no active connection"); diff --git a/src/odbc/api/sqlcolattribute.rs b/src/odbc/api/sqlcolattribute.rs index cd332fd..025c183 100644 --- a/src/odbc/api/sqlcolattribute.rs +++ b/src/odbc/api/sqlcolattribute.rs @@ -1,4 +1,4 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::get_from_wrapper; use odbc_sys::{Desc, HandleType, SqlReturn}; use std::ffi::c_void; @@ -6,9 +6,6 @@ use std::ptr; use tracing::{debug, error, info}; /// SQLColAttributeW returns descriptor information for a column in a result set. -/// -/// This function provides metadata about columns such as name, type, length, precision, etc. -/// It can return both numeric and string attributes depending on the field identifier. #[allow(non_snake_case)] #[unsafe(no_mangle)] pub extern "C" fn SQLColAttributeW( @@ -25,7 +22,6 @@ pub extern "C" fn SQLColAttributeW( column_number, field_identifier, buffer_length ); - // Get the statement handle let statement_handle: &mut StatementHandle = match get_from_wrapper(&HandleType::Stmt, statement_handle) { Ok(handle) => handle, @@ -35,8 +31,7 @@ pub extern "C" fn SQLColAttributeW( } }; - // Check if we have a prepared statement - let stmt = match &statement_handle.statement { + let stmt = match &statement_handle.active_statement { Some(stmt) => stmt, None => { error!("No prepared statement found"); @@ -44,7 +39,6 @@ pub extern "C" fn SQLColAttributeW( } }; - // Validate column number (1-indexed in ODBC) if column_number == 0 || column_number as usize > stmt.column_count() { error!( "Invalid column number {}. Valid range: 1-{}", @@ -54,7 +48,6 @@ pub extern "C" fn SQLColAttributeW( return SqlReturn::ERROR; } - // Match field identifier directly since Desc doesn't have TryFrom let desc_result = match field_identifier { 1001 => Some(Desc::Count), 1011 => Some(Desc::Name), @@ -73,7 +66,6 @@ pub extern "C" fn SQLColAttributeW( Some(desc) => desc, None => { info!("Unsupported field identifier {}", field_identifier); - // For unsupported attributes, return success with default values if !numeric_attribute_ptr.is_null() { unsafe { *numeric_attribute_ptr = 0; @@ -83,14 +75,13 @@ pub extern "C" fn SQLColAttributeW( } }; - // Convert to 0-indexed for SQLite + // Convert to 0-indexed for the trait method let col_index = (column_number - 1) as usize; debug!("Processing {:?} for column {}", desc, col_index); match desc { Desc::Count => { - // Return total number of columns if !numeric_attribute_ptr.is_null() { unsafe { *numeric_attribute_ptr = stmt.column_count() as isize; @@ -99,26 +90,20 @@ pub extern "C" fn SQLColAttributeW( } SqlReturn::SUCCESS } - Desc::Name => { - // Return column name - match stmt.column_name(col_index) { - Ok(name) => return_string_attribute( - name, - character_attribute_ptr, - buffer_length, - string_length_ptr, - ), - Err(err) => { - error!("Could not get column name for index {}: {}", col_index, err); - SqlReturn::ERROR - } + Desc::Name => match stmt.column_name(col_index) { + Ok(name) => return_string_attribute( + &name, + character_attribute_ptr, + buffer_length, + string_length_ptr, + ), + Err(err) => { + error!("Could not get column name for index {}: {}", col_index, err); + SqlReturn::ERROR } - } + }, Desc::Type | Desc::ConciseType => { - // Return SQL data type if !numeric_attribute_ptr.is_null() { - // For now, we'll return VARCHAR for all types since SQLite is dynamically typed - // In a more complete implementation, we'd examine the column declaration type unsafe { *numeric_attribute_ptr = 12; // SQL_VARCHAR } @@ -127,28 +112,24 @@ pub extern "C" fn SQLColAttributeW( SqlReturn::SUCCESS } Desc::Length | Desc::OctetLength => { - // Return column length - for SQLite, we'll use a reasonable default if !numeric_attribute_ptr.is_null() { unsafe { - *numeric_attribute_ptr = 255; // Default VARCHAR length + *numeric_attribute_ptr = 255; } debug!("Returning length: 255"); } SqlReturn::SUCCESS } Desc::DisplaySize => { - // Return display size for the column - // TODO: Make dependent on the column type if !numeric_attribute_ptr.is_null() { unsafe { - *numeric_attribute_ptr = 25; // Default display size + *numeric_attribute_ptr = 25; } debug!("Returning display size: 25"); } SqlReturn::SUCCESS } Desc::Nullable => { - // For SQLite, columns can generally be nullable if !numeric_attribute_ptr.is_null() { unsafe { *numeric_attribute_ptr = 1; // SQL_NULLABLE @@ -158,37 +139,32 @@ pub extern "C" fn SQLColAttributeW( SqlReturn::SUCCESS } Desc::Unnamed => { - // Return whether column is named or unnamed if !numeric_attribute_ptr.is_null() { let is_named = stmt.column_name(col_index).is_ok(); unsafe { - *numeric_attribute_ptr = if is_named { 0 } else { 1 }; // SQL_NAMED = 0, SQL_UNNAMED = 1 + *numeric_attribute_ptr = if is_named { 0 } else { 1 }; } debug!("Returning unnamed: {}", !is_named); } SqlReturn::SUCCESS } - Desc::Label => { - // Return column label (same as name for SQLite) - match stmt.column_name(col_index) { - Ok(name) => return_string_attribute( - name, - character_attribute_ptr, - buffer_length, - string_length_ptr, - ), - Err(err) => { - error!( - "Could not get column label for index {}: {}", - col_index, err - ); - SqlReturn::ERROR - } + Desc::Label => match stmt.column_name(col_index) { + Ok(name) => return_string_attribute( + &name, + character_attribute_ptr, + buffer_length, + string_length_ptr, + ), + Err(err) => { + error!( + "Could not get column label for index {}: {}", + col_index, err + ); + SqlReturn::ERROR } - } + }, _ => { info!("Unsupported field identifier {:?}, returning default", desc); - // For unsupported attributes, return success with default values if !numeric_attribute_ptr.is_null() { unsafe { *numeric_attribute_ptr = 0; @@ -199,7 +175,6 @@ pub extern "C" fn SQLColAttributeW( } } -/// Helper function to return string attributes fn return_string_attribute( value: &str, character_attribute_ptr: *mut c_void, @@ -208,119 +183,31 @@ fn return_string_attribute( ) -> SqlReturn { debug!("Returning string: '{}'", value); - // Convert to UTF-16 let utf16_value: Vec = value.encode_utf16().collect(); let utf16_len = utf16_value.len() as i16; - // Set the actual length if !string_length_ptr.is_null() { unsafe { *string_length_ptr = utf16_len * 2; // Length in bytes } } - // Copy the string if buffer is provided and large enough if !character_attribute_ptr.is_null() && buffer_length > 0 { - let max_chars = (buffer_length / 2) as usize; // Convert bytes to UTF-16 chars + let max_chars = (buffer_length / 2) as usize; let copy_len = std::cmp::min(utf16_value.len(), max_chars); unsafe { ptr::copy_nonoverlapping( utf16_value.as_ptr() as *const c_void, character_attribute_ptr, - copy_len * 2, // Copy length in bytes + copy_len * 2, ); } if copy_len < utf16_value.len() { - // String was truncated return SqlReturn::SUCCESS_WITH_INFO; } } SqlReturn::SUCCESS } - -/* - - println!( - "SQLColAttributeW(column_number={}, field_identifier={:?}, buffer_length={})", - column_number, field_identifier, buffer_length - ); - - match field_identifier { - Desc::Count => {} - Desc::Type => {} - Desc::Length => {} - Desc::OctetLengthPtr => {} - Desc::Precision => {} - Desc::Scale => {} - Desc::DatetimeIntervalCode => {} - Desc::Nullable => {} - Desc::IndicatorPtr => {} - Desc::DataPtr => {} - Desc::Name => {} - Desc::Unnamed => {} - Desc::OctetLength => {} - Desc::AllocType => {} - Desc::ArraySize => {} - Desc::ArrayStatusPtr => {} - Desc::AutoUniqueValue => {} - Desc::BaseColumnName => {} - Desc::BaseTableName => {} - Desc::BindOffsetPtr => {} - Desc::BindType => {} - Desc::CaseSensitive => {} - Desc::CatalogName => {} - Desc::ConciseType => {} - Desc::DatetimeIntervalPrecision => {} - Desc::DisplaySize => unsafe { - *numeric_attribute_ptr = 20; - }, - Desc::FixedPrecScale => {} - Desc::Label => { - if !character_attribute_ptr.is_null() { - let os_string = - U16CString::from_str("foobar").expect("U16CString::from_str failed"); - - // Make sure 'buffer_length' is the maximum length you can handle in wide characters - if buffer_length >= os_string.len() as i16 { - // Set string_length_ptr to the length of the wide string - if !string_length_ptr.is_null() { - unsafe { - *string_length_ptr = os_string.len() as i16; - } - } - - // Copy the wide string into the memory pointed to by character_attribute_ptr - unsafe { - ptr::copy_nonoverlapping( - os_string.as_ptr() as *const c_void, - character_attribute_ptr, - os_string.len() * 2, // Each character is 2 bytes in UTF-16 - ); - } - } else { - // Handle buffer too small error - } - } - } - Desc::LiteralPrefix => {} - Desc::LiteralSuffix => {} - Desc::LocalTypeName => {} - Desc::MaximumScale => {} - Desc::MinimumScale => {} - Desc::NumPrecRadix => {} - Desc::ParameterType => {} - Desc::RowsProcessedPtr => {} - Desc::RowVer => {} - Desc::SchemaName => {} - Desc::Searchable => {} - Desc::TypeName => {} - Desc::TableName => {} - Desc::Unsigned => {} - Desc::Updatable => {} - } - - SqlReturn::SUCCESS -*/ diff --git a/src/odbc/api/sqlcolumns.rs b/src/odbc/api/sqlcolumns.rs index e344077..4d2d380 100644 --- a/src/odbc/api/sqlcolumns.rs +++ b/src/odbc/api/sqlcolumns.rs @@ -1,5 +1,4 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; -use crate::odbc::implementation::columns::impl_get_columns; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::{get_from_wrapper, maybe_utf16_to_string}; use odbc_sys::{HandleType, SqlReturn}; use std::ffi::c_void; @@ -39,8 +38,11 @@ pub extern "C" fn SQLColumnsW( debug!("Getting columns for table: {}", table_name); - match impl_get_columns(statement_handle, &table_name) { - Ok(()) => SqlReturn::SUCCESS, + match statement_handle.connection.get_columns(&table_name) { + Ok(stmt) => { + statement_handle.active_statement = Some(stmt); + SqlReturn::SUCCESS + } Err(err) => { error!("impl_get_columns failed: {}", err); SqlReturn::ERROR diff --git a/src/odbc/api/sqlconnect.rs b/src/odbc/api/sqlconnect.rs index 3b8cca3..45def9a 100644 --- a/src/odbc/api/sqlconnect.rs +++ b/src/odbc/api/sqlconnect.rs @@ -12,9 +12,8 @@ //! SQLSMALLINT NameLength3); //! ``` -use crate::odbc::implementation::alloc_handles::ConnectionHandle; -use crate::odbc::implementation::connect::impl_connect; -use crate::odbc::utils::{get_from_wrapper, maybe_utf16_to_string}; +use crate::odbc::handles::{ConnectionHandle, factory}; +use crate::odbc::utils::{get_from_wrapper, get_private_profile_string, maybe_utf16_to_string}; use odbc_sys::{HandleType, SmallInt, SqlReturn, WChar}; use std::ffi::c_void; use tracing::{error, info}; @@ -57,7 +56,28 @@ pub extern "C" fn SQLConnectW( let user_name = maybe_utf16_to_string(user_name, user_name_length); let authentication = maybe_utf16_to_string(authentication, authentication_length); - impl_connect(connection_handle, server_name, user_name, authentication); + let database = match get_private_profile_string(&server_name, "Database", "odbc.ini", 1024) { + Ok(Some(db)) => db, + Ok(None) => { + error!("Database setting not found for DSN '{}'", server_name); + return SqlReturn::ERROR; + } + Err(e) => { + error!("Failed to look up DSN '{}': {}", server_name, e); + return SqlReturn::ERROR; + } + }; - SqlReturn::SUCCESS + let _ = (user_name, authentication); // unused by SQLite + + match factory().create_from_path(&database) { + Ok(conn) => { + connection_handle.connection = Some(conn); + SqlReturn::SUCCESS + } + Err(e) => { + error!("Failed to open database '{}': {}", database, e); + SqlReturn::ERROR + } + } } diff --git a/src/odbc/api/sqldescribecol.rs b/src/odbc/api/sqldescribecol.rs index a478637..4114241 100644 --- a/src/odbc/api/sqldescribecol.rs +++ b/src/odbc/api/sqldescribecol.rs @@ -1,4 +1,4 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::get_from_wrapper; use odbc_sys::{HandleType, SqlReturn}; use std::ffi::c_void; @@ -6,10 +6,6 @@ use std::ptr; use tracing::{debug, error, info}; /// SQLDescribeColW returns the result descriptor for one column in the result set. -/// -/// This function provides column metadata including name, data type, size, -/// decimal digits, and nullability. It must be called after a statement has -/// been prepared (via SQLPrepareW or SQLExecDirectW). #[allow(non_snake_case)] #[unsafe(no_mangle)] pub extern "C" fn SQLDescribeColW( @@ -28,7 +24,6 @@ pub extern "C" fn SQLDescribeColW( column_number, buffer_length ); - // Get the statement handle let statement_handle: &mut StatementHandle = match get_from_wrapper(&HandleType::Stmt, statement_handle) { Ok(handle) => handle, @@ -38,8 +33,7 @@ pub extern "C" fn SQLDescribeColW( } }; - // Check if we have a prepared statement - let stmt = match &statement_handle.statement { + let stmt = match &statement_handle.active_statement { Some(stmt) => stmt, None => { error!("No prepared statement found"); @@ -47,7 +41,6 @@ pub extern "C" fn SQLDescribeColW( } }; - // Validate column number (1-indexed in ODBC) if column_number == 0 || column_number as usize > stmt.column_count() { error!( "Invalid column number {}. Valid range: 1-{}", @@ -57,10 +50,8 @@ pub extern "C" fn SQLDescribeColW( return SqlReturn::ERROR; } - // Convert to 0-indexed for SQLite let col_index = (column_number - 1) as usize; - // Write column name as UTF-16 let mut result = SqlReturn::SUCCESS; match stmt.column_name(col_index) { Ok(name) => { @@ -69,14 +60,12 @@ pub extern "C" fn SQLDescribeColW( let utf16_value: Vec = name.encode_utf16().collect(); let utf16_len = utf16_value.len() as i16; - // Set the actual name length in characters (excluding null terminator) if !name_length_ptr.is_null() { unsafe { *name_length_ptr = utf16_len; } } - // Copy the name string if buffer is provided // buffer_length is in characters (u16 elements) per ODBC spec if !column_name.is_null() && buffer_length > 0 { let max_chars = buffer_length as usize; @@ -84,7 +73,6 @@ pub extern "C" fn SQLDescribeColW( unsafe { ptr::copy_nonoverlapping(utf16_value.as_ptr(), column_name, copy_len); - // Null-terminate *column_name.add(copy_len) = 0; } @@ -99,7 +87,6 @@ pub extern "C" fn SQLDescribeColW( } } - // Set data type - SQLite is dynamically typed, report VARCHAR for all columns if !data_type_ptr.is_null() { unsafe { *data_type_ptr = 12; // SQL_VARCHAR @@ -107,7 +94,6 @@ pub extern "C" fn SQLDescribeColW( debug!("Returning type: SQL_VARCHAR"); } - // Set column size - default VARCHAR length if !column_size_ptr.is_null() { unsafe { *column_size_ptr = 255; @@ -115,14 +101,12 @@ pub extern "C" fn SQLDescribeColW( debug!("Returning column size: 255"); } - // Set decimal digits - 0 for VARCHAR if !decimal_digits_ptr.is_null() { unsafe { *decimal_digits_ptr = 0; } } - // Set nullable - SQLite columns are generally nullable if !nullable_ptr.is_null() { unsafe { *nullable_ptr = 1; // SQL_NULLABLE diff --git a/src/odbc/api/sqldriverconnect.rs b/src/odbc/api/sqldriverconnect.rs index 48e93d9..57a27b2 100644 --- a/src/odbc/api/sqldriverconnect.rs +++ b/src/odbc/api/sqldriverconnect.rs @@ -13,8 +13,7 @@ //! SQLUSMALLINT DriverCompletion); //! ``` -use crate::odbc::implementation::alloc_handles::ConnectionHandle; -use crate::odbc::implementation::connect::impl_connect_to_database; +use crate::odbc::handles::{ConnectionHandle, factory}; use crate::odbc::utils::{get_from_wrapper, get_private_profile_string, maybe_utf16_to_string}; use odbc_sys::{HandleType, SmallInt, SqlReturn, USmallInt, WChar}; use std::ffi::c_void; @@ -117,16 +116,23 @@ pub extern "C" fn SQLDriverConnectW( Some(db_path) => { info!("Connecting to database: {}", db_path); - // Open the database directly — DSN resolution already happened above - impl_connect_to_database(connection_handle, db_path); + match factory().create_from_path(&db_path) { + Ok(conn) => { + connection_handle.connection = Some(conn); + } + Err(e) => { + error!("Failed to open database '{}': {}", db_path, e); + return SqlReturn::ERROR; + } + } // TODO: Copy connection string to output buffer if provided - if !out_connection_string.is_null() && buffer_length > 0 { - // For now, just indicate that we're not filling the output buffer - if !string_length2_ptr.is_null() { - unsafe { - *string_length2_ptr = 0; - } + if !out_connection_string.is_null() + && buffer_length > 0 + && !string_length2_ptr.is_null() + { + unsafe { + *string_length2_ptr = 0; } } diff --git a/src/odbc/api/sqlexecdirect.rs b/src/odbc/api/sqlexecdirect.rs index 20c4455..3f5b0e2 100644 --- a/src/odbc/api/sqlexecdirect.rs +++ b/src/odbc/api/sqlexecdirect.rs @@ -1,14 +1,10 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::{get_from_wrapper, maybe_utf16_to_string}; use odbc_sys::{HandleType, SqlReturn}; use std::ffi::c_void; use tracing::{debug, error, info}; /// SQLExecDirectW prepares and executes an SQL statement in one step. -/// -/// This function combines the functionality of SQLPrepareW and SQLExecute, -/// preparing and executing the SQL statement immediately without the need -/// for separate preparation and execution calls. #[allow(non_snake_case)] #[unsafe(no_mangle)] pub extern "C" fn SQLExecDirectW( @@ -18,7 +14,6 @@ pub extern "C" fn SQLExecDirectW( ) -> SqlReturn { info!("text_length={}", text_length); - // Get the statement handle let statement_handle: &mut StatementHandle = match get_from_wrapper(&HandleType::Stmt, statement_handle) { Ok(handle) => handle, @@ -28,7 +23,6 @@ pub extern "C" fn SQLExecDirectW( } }; - // Convert UTF-16 statement text to String let sql_text = match maybe_utf16_to_string(statement_text, text_length as i16) { Some(text) => text, None => { @@ -39,56 +33,15 @@ pub extern "C" fn SQLExecDirectW( debug!("Executing SQL: {}", sql_text); - // For SELECT statements, prepare and execute with query() - if sql_text.trim().to_uppercase().starts_with("SELECT") { - // Prepare the statement - match statement_handle.sqlite_connection.prepare(&sql_text) { - Ok(stmt) => { - debug!("SELECT statement prepared successfully"); - statement_handle.statement = Some(stmt); - - // Execute the prepared statement - match statement_handle.statement { - Some(ref mut stmt) => match stmt.query([]) { - Ok(rows) => { - debug!("SELECT statement executed successfully"); - statement_handle.rows = Some(rows); - SqlReturn::SUCCESS - } - Err(err) => { - error!("Failed to execute SELECT statement: {}", err); - SqlReturn::ERROR - } - }, - None => { - error!("Failed to get prepared statement reference"); - SqlReturn::ERROR - } - } - } - Err(err) => { - error!("Failed to prepare SELECT statement: {}", err); - SqlReturn::ERROR - } + match statement_handle.connection.prepare_statement(&sql_text) { + Ok(stmt) => { + debug!("Statement executed successfully"); + statement_handle.active_statement = Some(stmt); + SqlReturn::SUCCESS } - } else { - // For non-SELECT statements (INSERT, UPDATE, DELETE, etc.), use execute() - match statement_handle.sqlite_connection.execute(&sql_text, []) { - Ok(affected_rows) => { - debug!( - "Non-SELECT statement executed successfully, affected rows: {}", - affected_rows - ); - // Clear any previous result set since this wasn't a SELECT - statement_handle.statement = None; - statement_handle.rows = None; - statement_handle.row = None; - SqlReturn::SUCCESS - } - Err(err) => { - error!("Failed to execute non-SELECT statement: {}", err); - SqlReturn::ERROR - } + Err(err) => { + error!("Failed to execute statement: {}", err); + SqlReturn::ERROR } } } diff --git a/src/odbc/api/sqlexecute.rs b/src/odbc/api/sqlexecute.rs index fd8e7ea..4656efe 100644 --- a/src/odbc/api/sqlexecute.rs +++ b/src/odbc/api/sqlexecute.rs @@ -1,4 +1,4 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::get_from_wrapper; use odbc_sys::{HandleType, SqlReturn}; use std::ffi::c_void; @@ -6,14 +6,13 @@ use tracing::{debug, error, info}; /// SQLExecute executes a prepared statement. /// -/// This function executes a statement that was prepared with SQLPrepareW. -/// The statement is executed with the current parameter values. +/// Because the driver executes eagerly at prepare time, this is effectively +/// a no-op — it verifies a statement exists and returns SUCCESS. #[allow(non_snake_case)] #[unsafe(no_mangle)] pub extern "C" fn SQLExecute(statement_handle: *mut c_void) -> SqlReturn { - info!("SQLExecute INFO"); + info!("SQLExecute"); - // Get the statement handle let statement_handle: &mut StatementHandle = match get_from_wrapper(&HandleType::Stmt, statement_handle) { Ok(handle) => handle, @@ -23,20 +22,10 @@ pub extern "C" fn SQLExecute(statement_handle: *mut c_void) -> SqlReturn { } }; - // Check if we have a prepared statement - if statement_handle.statement.is_none() { - error!("No prepared statement found"); - return SqlReturn::ERROR; - } - - debug!("Executing prepared statement"); - - // Execute the prepared statement - match statement_handle.statement { - Some(ref mut stmt) => match stmt.query([]) { - Ok(rows) => { + match statement_handle.active_statement.as_mut() { + Some(stmt) => match stmt.execute() { + Ok(()) => { debug!("Statement executed successfully"); - statement_handle.rows = Some(rows); SqlReturn::SUCCESS } Err(err) => { @@ -45,8 +34,7 @@ pub extern "C" fn SQLExecute(statement_handle: *mut c_void) -> SqlReturn { } }, None => { - // This should not happen due to the check above, but handle it anyway - error!("Prepared statement is None"); + error!("No prepared statement found"); SqlReturn::ERROR } } diff --git a/src/odbc/api/sqlfetch.rs b/src/odbc/api/sqlfetch.rs index 5ab62a6..246781e 100644 --- a/src/odbc/api/sqlfetch.rs +++ b/src/odbc/api/sqlfetch.rs @@ -1,20 +1,17 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::get_from_wrapper; use odbc_sys::{HandleType, SqlReturn}; use std::ffi::c_void; use tracing::{debug, error, info}; -/// SQLFetch fetches the next row of data from the result set. +/// SQLFetch advances the cursor to the next row in the result set. /// -/// This function advances the cursor to the next row in the result set and -/// retrieves the data for all bound columns. If there are no more rows, -/// it returns SQL_NO_DATA. +/// Returns SQL_NO_DATA when no more rows are available. #[allow(non_snake_case)] #[unsafe(no_mangle)] pub extern "C" fn SQLFetch(statement_handle: *mut c_void) -> SqlReturn { info!("Fetching next row"); - // Get the statement handle let statement_handle: &mut StatementHandle = match get_from_wrapper(&HandleType::Stmt, statement_handle) { Ok(handle) => handle, @@ -24,30 +21,21 @@ pub extern "C" fn SQLFetch(statement_handle: *mut c_void) -> SqlReturn { } }; - // Check if we have rows available (from previous SQLExecute call) - match statement_handle.rows.as_mut() { - Some(rows) => { - debug!("Found result set, fetching next row"); - - // Try to get the next row - match rows.next() { - Ok(row_option) => match row_option { - Some(row) => { - debug!("Successfully fetched row"); - statement_handle.row = Some(row); - SqlReturn::SUCCESS - } - None => { - info!("No more data available"); - SqlReturn::NO_DATA - } - }, - Err(err) => { - error!("Failed to fetch next row: {}", err); - SqlReturn::ERROR - } + match statement_handle.active_statement.as_mut() { + Some(stmt) => match stmt.fetch_next_row() { + Ok(true) => { + debug!("Successfully fetched row"); + SqlReturn::SUCCESS } - } + Ok(false) => { + info!("No more data available"); + SqlReturn::NO_DATA + } + Err(err) => { + error!("Failed to fetch next row: {}", err); + SqlReturn::ERROR + } + }, None => { error!("No result set available. Call SQLExecute first."); SqlReturn::ERROR diff --git a/src/odbc/api/sqlfreehandle.rs b/src/odbc/api/sqlfreehandle.rs index f591570..dc878e5 100644 --- a/src/odbc/api/sqlfreehandle.rs +++ b/src/odbc/api/sqlfreehandle.rs @@ -1,6 +1,4 @@ -use crate::odbc::implementation::alloc_handles::{ - ConnectionHandle, EnvironmentHandle, StatementHandle, -}; +use crate::odbc::handles::{ConnectionHandle, EnvironmentHandle, StatementHandle}; use crate::odbc::utils::{HandleWrapper, tag_for_handle}; use odbc_sys::{HandleType, SqlReturn}; use std::ffi::c_void; diff --git a/src/odbc/api/sqlgetdata.rs b/src/odbc/api/sqlgetdata.rs index acdfb28..5a2ef92 100644 --- a/src/odbc/api/sqlgetdata.rs +++ b/src/odbc/api/sqlgetdata.rs @@ -1,5 +1,4 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; -use crate::odbc::implementation::getdata::impl_getdata; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::get_from_wrapper; use odbc_sys::{CDataType, HandleType, SqlReturn}; use std::ffi::{CString, c_void}; @@ -27,15 +26,12 @@ pub extern "C" fn SQLGetData( }; if col_or_param_num == 0 { - // TODO warn!("Bookmarks not supported yet"); return SqlReturn::ERROR; } - // TODO: Have a generic way to check if the requested column is even in range - let target_type = match CDataType::try_from(target_type) { - Ok(target_type) => target_type, + Ok(t) => t, Err(e) => { error!( "Could not convert {} to valid target type: {}", @@ -50,40 +46,46 @@ pub extern "C" fn SQLGetData( target_type, col_or_param_num ); - let result = impl_getdata(statement_handle, &target_type, col_or_param_num); + let stmt = match &statement_handle.active_statement { + Some(stmt) => stmt, + None => { + error!("No active statement; call SQLFetch first"); + return SqlReturn::ERROR; + } + }; + + let col_index = (col_or_param_num - 1) as usize; + let result = match stmt.get_data(col_index, target_type) { + Ok(value) => value, + Err(err) => { + error!("get_data failed: {}", err); + return SqlReturn::ERROR; + } + }; - // TODO: Make sure to handle encoding properly here, not sure how let c_string = match CString::new(result) { - Ok(string) => string, - Err(_e) => { + Ok(s) => s, + Err(_) => { error!("Converting String to CString failed"); return SqlReturn::ERROR; - // TODO: Set error in connection_handle } }; let c_string_bytes_with_nul = c_string.as_bytes_with_nul(); - let c_string_len = c_string_bytes_with_nul.len() - 1; // Exclude the null terminator to get the actual length of the string. + let c_string_len = c_string_bytes_with_nul.len() - 1; - // Calculate the final string length to be used in the operation. - let final_string_length = std::cmp::min(c_string_len, buffer_length as usize - 1); // Leave space for the null byte + let final_string_length = std::cmp::min(c_string_len, buffer_length as usize - 1); unsafe { - // Copy the appropriate slice of the string based on the calculated length. std::ptr::copy_nonoverlapping( c_string_bytes_with_nul.as_ptr(), target_value_ptr as *mut u8, final_string_length, ); - // Cast the `c_void` pointer to a `u8` pointer before dereferencing. - // This is safe because we know the original data structure is a u8 buffer. let null_terminator_ptr = target_value_ptr.cast::().add(final_string_length); - - // Add the null terminator at the right position. *null_terminator_ptr = 0; - // Set the string length excluding the null terminator. *str_len_or_ind_ptr = final_string_length as isize; } diff --git a/src/odbc/api/sqlgetenvattr.rs b/src/odbc/api/sqlgetenvattr.rs index b166169..5cbff28 100644 --- a/src/odbc/api/sqlgetenvattr.rs +++ b/src/odbc/api/sqlgetenvattr.rs @@ -1,5 +1,4 @@ -use crate::odbc::implementation::alloc_handles::EnvironmentHandle; -use crate::odbc::implementation::env_attrs::get_odbc_version; +use crate::odbc::handles::EnvironmentHandle; use crate::odbc::utils::get_from_wrapper; use odbc_sys::{EnvironmentAttribute, HandleType, Integer, Pointer, SqlReturn}; use tracing::{debug, error}; @@ -45,7 +44,7 @@ pub extern "C" fn SQLGetEnvAttr( match attribute { EnvironmentAttribute::OdbcVersion => { - let odbc_version = get_odbc_version(env); + let odbc_version = env.odbc_version(); unsafe { *(value_ptr as *mut i32) = odbc_version as i32 } } EnvironmentAttribute::ConnectionPooling => { diff --git a/src/odbc/api/sqlgetinfo.rs b/src/odbc/api/sqlgetinfo.rs index d4acb49..90cf02a 100644 --- a/src/odbc/api/sqlgetinfo.rs +++ b/src/odbc/api/sqlgetinfo.rs @@ -1,7 +1,9 @@ -use crate::connection::ConnectionClass; -use crate::odbc::implementation::implementation::get_info; -use odbc_sys::{InfoType, InfoTypeType, InfoTypeTypeInformation, Pointer, SmallInt, SqlReturn}; -use std::ffi::CString; +use crate::odbc::handles::ConnectionHandle; +use crate::odbc::utils::get_from_wrapper; +use odbc_sys::{ + HandleType, InfoType, InfoTypeType, InfoTypeTypeInformation, Pointer, SmallInt, SqlReturn, +}; +use std::ffi::{CString, c_void}; use tracing::{error, info}; const STRING_LENGTH_FOR_USMALLINT: i16 = std::mem::size_of::() as i16; @@ -16,7 +18,7 @@ const STRING_LENGTH_FOR_UINTEGER: i16 = std::mem::size_of::() as i16; #[allow(non_snake_case)] #[unsafe(no_mangle)] pub extern "C" fn SQLGetInfo( - connection_handle: *mut ConnectionClass, + connection_handle: *mut c_void, info_type: u16, info_value_ptr: Pointer, buffer_length: SmallInt, @@ -32,6 +34,15 @@ pub extern "C" fn SQLGetInfo( return SqlReturn::INVALID_HANDLE; } + let connection_handle: &mut ConnectionHandle = + match get_from_wrapper(&HandleType::Dbc, connection_handle) { + Ok(h) => h, + Err(e) => { + error!("Failed to get connection handle: {}", e); + return SqlReturn::INVALID_HANDLE; + } + }; + if string_length_ptr.is_null() { error!("string_length_ptr is null"); // TODO: Set error in connection_handle @@ -68,8 +79,12 @@ pub extern "C" fn SQLGetInfo( // TODO: Check buffer_length if info_value_ptr is not null and info_type type is a character string // TODO: Unicode variant needs to check if buffer_length is even number, if not -> HY0900 - let result = - get_info(info_type).map_or(info_type.return_type().not_supported_value(), |value| value); + let result = match connection_handle.connection.as_ref() { + Some(conn) => conn + .get_info(info_type) + .map_or_else(|| info_type.return_type().not_supported_value(), |v| v), + None => info_type.return_type().not_supported_value(), + }; // buffer_length is ignored for anything that's not a string as per the specification match result { @@ -145,9 +160,35 @@ pub extern "C" fn SQLGetInfo( #[cfg(test)] mod tests { use super::*; + use crate::odbc::handles::ConnectionHandle; + use crate::odbc::utils::wrap_and_set; + use odbc_sys::HandleType; use std::ffi::CStr; use std::os::raw::{c_char, c_void}; + /// Create a disconnected (no active DB) connection handle pointer suitable for + /// tests that only exercise the SQLGetInfo error-path logic (null ptr, bad args, etc.). + fn make_disconnected_dbc_ptr() -> Pointer { + let ch = ConnectionHandle { connection: None }; + let mut ptr: Pointer = std::ptr::null_mut(); + wrap_and_set(HandleType::Dbc, ch, &mut ptr); + ptr + } + + /// Create a connected in-memory SQLite connection handle pointer. + fn make_connected_dbc_ptr() -> Pointer { + crate::init_driver(); + let conn = crate::odbc::handles::factory() + .create_from_path(":memory:") + .unwrap(); + let ch = ConnectionHandle { + connection: Some(conn), + }; + let mut ptr: Pointer = std::ptr::null_mut(); + wrap_and_set(HandleType::Dbc, ch, &mut ptr); + ptr + } + #[test] fn test_null_connection_handle() { let result = SQLGetInfo( @@ -162,9 +203,9 @@ mod tests { #[test] fn test_nullstring_length_ptr() { - let mut connection = ConnectionClass {}; + let dbc = make_disconnected_dbc_ptr(); let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, + dbc, InfoType::MaxDriverConnections as u16, std::ptr::null_mut(), 2, @@ -176,28 +217,22 @@ mod tests { #[test] fn test_invalid_info_type() { - let mut connection = ConnectionClass {}; + let dbc = make_disconnected_dbc_ptr(); let string_length_ptr = &mut 0; - let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, - 9999, - std::ptr::null_mut(), - 2, - string_length_ptr, - ); + let result = SQLGetInfo(dbc, 9999, std::ptr::null_mut(), 2, string_length_ptr); assert_eq!(result, SqlReturn::ERROR); // TODO: Once implemented check that the correct error type is returned } #[test] fn test_valid_info_type_usmallint() { - let mut connection = ConnectionClass {}; + let dbc = make_connected_dbc_ptr(); let mut value: u16 = 0; let mut string_length: i16 = 0; let buffer_length = std::mem::size_of_val(&value) as i16; let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, + dbc, InfoType::ActiveEnvironments as u16, &mut value as *mut u16 as *mut c_void, buffer_length, @@ -211,12 +246,12 @@ mod tests { #[test] fn test_valid_info_type_string() { - let mut connection = ConnectionClass {}; + let dbc = make_connected_dbc_ptr(); let mut buffer = vec![0u8; 256]; let mut string_length: i16 = 0; let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, + dbc, InfoType::UserName as u16, buffer.as_mut_ptr() as *mut c_void, buffer.len() as i16, @@ -233,12 +268,12 @@ mod tests { #[test] fn test_string_truncation() { - let mut connection = ConnectionClass {}; + let dbc = make_connected_dbc_ptr(); let mut buffer = vec![0u8; 3]; let mut string_length: i16 = 0; let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, + dbc, InfoType::UserName as u16, buffer.as_mut_ptr() as *mut c_void, buffer.len() as i16, @@ -254,30 +289,35 @@ mod tests { } #[test] - fn test_insufficient_buffer_for_integer() { - let mut connection = ConnectionClass {}; - let mut value: u16 = 0; // Not enough space for a u32 + fn buffer_length_is_ignored_for_non_string_types() { + // ODBC spec: buffer_length is ignored for non-string info types. + // Passing a u16 buffer for a u32 info type is therefore not an error as long + // as the pointer alignment is correct. + let dbc = make_disconnected_dbc_ptr(); + let mut value: u32 = 0; + let mut string_length: i16 = 0; let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, - InfoType::ScrollOptions as u16, // This info_type returns a u32 - &mut value as *mut u16 as *mut c_void, - std::mem::size_of::() as i16, // too small! - std::ptr::null_mut(), + dbc, + InfoType::ScrollOptions as u16, + &mut value as *mut u32 as *mut c_void, + std::mem::size_of::() as i16, // too small, but spec says it's ignored + &mut string_length, ); - assert_eq!(result, SqlReturn::ERROR); // or whatever error code you return for insufficient space + assert_eq!(result, SqlReturn::SUCCESS); } #[test] fn test_bad_alignment() { - let mut connection = ConnectionClass {}; + let dbc = make_disconnected_dbc_ptr(); let mut buffer = vec![0u8; 256]; + let mut string_length: i16 = 0; let result = unsafe { SQLGetInfo( - &mut connection as *mut ConnectionClass, + dbc, InfoType::ActiveEnvironments as u16, - (buffer.as_mut_ptr().offset(1)) as *mut c_void, // misaligned + (buffer.as_mut_ptr().offset(1)) as *mut c_void, // deliberately misaligned buffer.len() as i16, - std::ptr::null_mut(), + &mut string_length, ) }; assert_eq!(result, SqlReturn::ERROR); @@ -286,10 +326,10 @@ mod tests { #[test] fn test_null_info_value_ptr() { - let mut connection = ConnectionClass {}; + let dbc = make_connected_dbc_ptr(); let mut string_length: i16 = 0; let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, + dbc, InfoType::ActiveEnvironments as u16, std::ptr::null_mut(), 0, @@ -301,7 +341,7 @@ mod tests { #[test] fn test_sqluinteger() { - let mut connection = ConnectionClass {}; + let dbc = make_connected_dbc_ptr(); let info_type = InfoType::ScrollOptions; let buffer_length = std::mem::size_of::() as i16; let mut string_length: i16 = 0; @@ -309,7 +349,7 @@ mod tests { let info_value_ptr: *mut c_void = buffer.as_mut_ptr() as *mut c_void; let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, + dbc, info_type as u16, info_value_ptr, buffer_length, @@ -323,7 +363,7 @@ mod tests { #[test] fn test_sql_get_info() { // Test case: SQLUSMALLINT case - let mut connection = ConnectionClass {}; + let dbc = make_connected_dbc_ptr(); let info_type = InfoType::ActiveEnvironments; let buffer_length = 2; let mut string_length: i16 = 0; @@ -331,7 +371,7 @@ mod tests { let info_value_ptr: *mut c_void = buffer.as_mut_ptr() as *mut c_void; let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, + dbc, info_type as u16, info_value_ptr, buffer_length, @@ -346,7 +386,7 @@ mod tests { } // Test case: String case for an InfoType that is not implemented - let mut connection = ConnectionClass {}; + let dbc = make_connected_dbc_ptr(); let info_type = InfoType::DescribeParameter; let buffer_length = 15; let mut string_length: i16 = 0; @@ -354,7 +394,7 @@ mod tests { let info_value_ptr: *mut c_void = buffer.as_mut_ptr() as *mut c_void; let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, + dbc, info_type as u16, info_value_ptr, buffer_length, @@ -369,7 +409,7 @@ mod tests { assert_eq!(rust_string, "N"); // Test case: String case for a string that is too long - let mut connection = ConnectionClass {}; + let dbc = make_connected_dbc_ptr(); let info_type = InfoType::DescribeParameter; let buffer_length = 15; let mut string_length: i16 = 0; @@ -377,7 +417,7 @@ mod tests { let info_value_ptr: *mut c_void = buffer.as_mut_ptr() as *mut c_void; let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, + dbc, info_type as u16, info_value_ptr, buffer_length, @@ -392,9 +432,9 @@ mod tests { assert_eq!(rust_string, "N"); let invalid_info_type: u16 = 12345; - let mut connection = ConnectionClass {}; + let dbc = make_connected_dbc_ptr(); let buffer_length = 15; - //let mut string_length: i16 = 0; + let mut string_length: i16 = 0; let mut buffer: [c_char; 15] = [0; 15]; unsafe { @@ -405,33 +445,12 @@ mod tests { let info_value_ptr: *mut c_void = buffer.as_mut_ptr() as *mut c_void; let _result = SQLGetInfo( - &mut connection as *mut ConnectionClass, + dbc, invalid_bar as u16, info_value_ptr, buffer_length, - //&mut string_length as *mut i16, - 0 as *mut i16, + &mut string_length as *mut i16, ); } - - /* - // Test case 2: String case with insufficient buffer length - let buffer_length = 5; - let mut string_length: i16 = 0; - let mut buffer: [c_char; 5] = [0; 5]; - let info_value_ptr: *mut c_void = buffer.as_mut_ptr() as *mut c_void; - - let result = SQLGetInfo( - &mut connection as *mut ConnectionClass, - info_type, - info_value_ptr, - buffer_length, - &mut string_length as *mut i16, - ); - - assert_eq!(result, SqlReturn::ERROR); // Buffer too small - - - */ } } diff --git a/src/odbc/api/sqlmoreresults.rs b/src/odbc/api/sqlmoreresults.rs index 1c53ad1..432b9de 100644 --- a/src/odbc/api/sqlmoreresults.rs +++ b/src/odbc/api/sqlmoreresults.rs @@ -1,24 +1,15 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::get_from_wrapper; use odbc_sys::{HandleType, SqlReturn}; use std::ffi::c_void; -use tracing::{debug, error, info}; +use tracing::{error, info}; -/// SQLMoreResults determines whether more results are available on a statement containing -/// SELECT, UPDATE, INSERT, or DELETE statements and, if so, initializes processing for those results. -/// -/// For SQLite, multiple result sets are not supported in the traditional sense (unlike SQL Server -/// stored procedures or MySQL batch statements). SQLite executes one statement at a time, so this -/// function will typically return SQL_NO_DATA to indicate no additional result sets are available. -/// -/// Note: This implementation assumes single result set per statement execution, which is -/// appropriate for SQLite's architecture. +/// SQLMoreResults returns SQL_NO_DATA — SQLite does not support multiple result sets. #[allow(non_snake_case)] #[unsafe(no_mangle)] pub extern "C" fn SQLMoreResults(statement_handle: *mut c_void) -> SqlReturn { info!("Checking for additional result sets"); - // Get the statement handle let statement_handle: &mut StatementHandle = match get_from_wrapper(&HandleType::Stmt, statement_handle) { Ok(handle) => handle, @@ -28,16 +19,8 @@ pub extern "C" fn SQLMoreResults(statement_handle: *mut c_void) -> SqlReturn { } }; - // Check if we have a prepared statement - match &statement_handle.statement { + match &statement_handle.active_statement { Some(_) => { - debug!("Found prepared statement"); - - // SQLite doesn't support multiple result sets from a single statement execution - // Unlike SQL Server stored procedures or MySQL batch operations, SQLite processes - // one statement at a time. Therefore, after processing the initial result set, - // there are no additional result sets available. - info!("No additional result sets available (SQLite limitation)"); SqlReturn::NO_DATA } diff --git a/src/odbc/api/sqlnumresultcols.rs b/src/odbc/api/sqlnumresultcols.rs index d1c00d3..c4d2738 100644 --- a/src/odbc/api/sqlnumresultcols.rs +++ b/src/odbc/api/sqlnumresultcols.rs @@ -1,13 +1,10 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::get_from_wrapper; use odbc_sys::{HandleType, SqlReturn}; use std::ffi::c_void; use tracing::{debug, error, info}; /// SQLNumResultCols returns the number of columns in a result set. -/// -/// This function can be called after SQLPrepareW and SQLExecute to determine -/// the number of columns that will be returned by the statement. #[allow(non_snake_case)] #[unsafe(no_mangle)] pub unsafe extern "C" fn SQLNumResultCols( @@ -16,13 +13,11 @@ pub unsafe extern "C" fn SQLNumResultCols( ) -> SqlReturn { info!("Getting number of result cols"); - // Validate column_count_ptr is not null if column_count_ptr.is_null() { error!("column_count_ptr is null"); return SqlReturn::ERROR; } - // Get the statement handle let statement_handle: &mut StatementHandle = match get_from_wrapper(&HandleType::Stmt, statement_handle) { Ok(handle) => handle, @@ -32,8 +27,7 @@ pub unsafe extern "C" fn SQLNumResultCols( } }; - // Check if we have a prepared statement - match &statement_handle.statement { + match &statement_handle.active_statement { Some(stmt) => { let num_cols = stmt.column_count(); debug!("Found {} columns", num_cols); diff --git a/src/odbc/api/sqlprepare.rs b/src/odbc/api/sqlprepare.rs index a60e5c1..ca1d60a 100644 --- a/src/odbc/api/sqlprepare.rs +++ b/src/odbc/api/sqlprepare.rs @@ -1,4 +1,4 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::{get_from_wrapper, maybe_utf16_to_string}; use odbc_sys::{HandleType, SqlReturn}; use std::ffi::c_void; @@ -6,8 +6,8 @@ use tracing::{debug, error, info}; /// SQLPrepareW prepares an SQL statement for execution. /// -/// This function prepares the SQL statement but does not execute it. -/// The prepared statement can later be executed with SQLExecute. +/// The statement is eagerly executed and results are collected immediately, +/// so that `SQLNumResultCols`, `SQLColAttribute`, etc. work before `SQLExecute`. #[allow(non_snake_case)] #[unsafe(no_mangle)] pub extern "C" fn SQLPrepareW( @@ -17,7 +17,6 @@ pub extern "C" fn SQLPrepareW( ) -> SqlReturn { info!("text_length={}", text_length); - // Get the statement handle let statement_handle: &mut StatementHandle = match get_from_wrapper(&HandleType::Stmt, statement_handle) { Ok(handle) => handle, @@ -27,7 +26,6 @@ pub extern "C" fn SQLPrepareW( } }; - // Convert UTF-16 statement text to String let sql_text = match maybe_utf16_to_string(statement_text, text_length) { Some(text) => text, None => { @@ -38,11 +36,10 @@ pub extern "C" fn SQLPrepareW( debug!("Preparing SQL: {}", sql_text); - // Prepare the statement using rusqlite - match statement_handle.sqlite_connection.prepare(&sql_text) { + match statement_handle.connection.prepare_statement(&sql_text) { Ok(stmt) => { debug!("Statement prepared successfully"); - statement_handle.statement = Some(stmt); + statement_handle.active_statement = Some(stmt); SqlReturn::SUCCESS } Err(err) => { diff --git a/src/odbc/api/sqlrowcount.rs b/src/odbc/api/sqlrowcount.rs index 4a5180b..d396ba3 100644 --- a/src/odbc/api/sqlrowcount.rs +++ b/src/odbc/api/sqlrowcount.rs @@ -1,4 +1,4 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::get_from_wrapper; use odbc_sys::{HandleType, SqlReturn}; use std::ffi::c_void; @@ -6,12 +6,7 @@ use tracing::{debug, error, info}; /// SQLRowCount returns the number of rows affected by an UPDATE, INSERT, or DELETE statement. /// -/// For SELECT statements and other statements that don't modify rows, this function -/// should return -1 to indicate that the row count is not available or not applicable. -/// -/// Note: According to ODBC spec, SQLRowCount only applies to statements that modify data. -/// For SELECT statements, use SQLNumResultCols to get column count or iterate through -/// SQLFetch calls to count rows. +/// Returns -1 for SELECT statements or when no statement is active. #[allow(non_snake_case)] #[unsafe(no_mangle)] pub extern "C" fn SQLRowCount( @@ -20,13 +15,11 @@ pub extern "C" fn SQLRowCount( ) -> SqlReturn { info!("Getting row count"); - // Validate row_count_ptr is not null if row_count_ptr.is_null() { error!("row_count_ptr is null"); return SqlReturn::ERROR; } - // Get the statement handle let statement_handle: &mut StatementHandle = match get_from_wrapper(&HandleType::Stmt, statement_handle) { Ok(handle) => handle, @@ -36,40 +29,24 @@ pub extern "C" fn SQLRowCount( } }; - // Check if we have a prepared statement - match &statement_handle.statement { - Some(_) => { - debug!("Found prepared statement"); - - // Try to get the number of rows affected - // Note: SQLite's changes() function returns the number of rows affected by the most recent - // INSERT, UPDATE, or DELETE statement. For SELECT statements, it's not applicable. - let changes = statement_handle.sqlite_connection.changes(); - - debug!("SQLite changes() returned: {}", changes); - - // According to ODBC spec: - // - For UPDATE, INSERT, DELETE: return actual row count - // - For SELECT and other statements: return -1 (not available) - // Since we can't easily determine the statement type here, we'll use SQLite's changes() - // but note that it may return 0 for SELECT statements - + match &statement_handle.active_statement { + Some(stmt) => { + let changes = stmt.row_changes(); + debug!("row_changes={}", changes); unsafe { - if changes > 0 { - *row_count_ptr = changes as isize; + *row_count_ptr = if changes > 0 { debug!("Returning row count: {}", changes); + changes as isize } else { - // For statements that don't modify data (like SELECT), return -1 - // This indicates that the row count is not available or not applicable - *row_count_ptr = -1; - debug!("Returning -1 (row count not applicable for this statement type)"); - } + // SELECT or no DML — row count not applicable + debug!("Returning -1 (row count not applicable)"); + -1 + }; } - SqlReturn::SUCCESS } None => { - error!("SQLRowCount ERROR: No prepared statement found"); + error!("No prepared statement found"); unsafe { *row_count_ptr = -1; } diff --git a/src/odbc/api/sqlsetenvattr.rs b/src/odbc/api/sqlsetenvattr.rs index cec92b1..9a8cbdc 100644 --- a/src/odbc/api/sqlsetenvattr.rs +++ b/src/odbc/api/sqlsetenvattr.rs @@ -1,5 +1,4 @@ -use crate::odbc::implementation::alloc_handles::EnvironmentHandle; -use crate::odbc::implementation::env_attrs::set_odbc_version; +use crate::odbc::handles::EnvironmentHandle; use crate::odbc::utils::get_from_wrapper; use odbc_sys::{AttrOdbcVersion, EnvironmentAttribute, HandleType, Integer, Pointer, SqlReturn}; use tracing::{debug, error}; @@ -61,7 +60,7 @@ pub fn SQLSetEnvAttr( } }; - set_odbc_version(env, odbc_version); + env.set_odbc_version(odbc_version); } EnvironmentAttribute::ConnectionPooling => { // TODO: This is implemented in the driver manager, not the driver diff --git a/src/odbc/api/sqltables.rs b/src/odbc/api/sqltables.rs index 43b2c85..c0bec74 100644 --- a/src/odbc/api/sqltables.rs +++ b/src/odbc/api/sqltables.rs @@ -1,5 +1,4 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; -use crate::odbc::implementation::tables::impl_get_tables; +use crate::odbc::handles::StatementHandle; use crate::odbc::utils::get_from_wrapper; use odbc_sys::{HandleType, SqlReturn}; use std::ffi::c_void; @@ -32,8 +31,11 @@ pub extern "C" fn SQLTablesW( } }; - match impl_get_tables(statement_handle) { - Ok(()) => SqlReturn::SUCCESS, + match statement_handle.connection.get_tables() { + Ok(stmt) => { + statement_handle.active_statement = Some(stmt); + SqlReturn::SUCCESS + } Err(err) => { error!("impl_get_tables failed: {}", err); SqlReturn::ERROR diff --git a/src/odbc/handles.rs b/src/odbc/handles.rs new file mode 100644 index 0000000..778c391 --- /dev/null +++ b/src/odbc/handles.rs @@ -0,0 +1,100 @@ +use odbc_sys::{AttrOdbcVersion, CDataType, InfoType, InfoTypeType}; +use std::sync::{Arc, OnceLock}; + +// ─── Traits ────────────────────────────────────────────────────────────────── + +/// Abstracts a database connection. All DB-specific logic lives behind this trait. +pub trait DbConnection: Send + Sync { + /// Prepare and eagerly execute a SQL query, returning the full result set. + fn prepare_statement(&self, sql: &str) -> Result, String>; + + /// Return an ODBC-spec 1-column result set of table names. + fn get_tables(&self) -> Result, String>; + + /// Return an ODBC-spec 18-column result set of column metadata for a table. + fn get_columns(&self, table_name: &str) -> Result, String>; + + /// Return driver/data-source information for `SQLGetInfo`. + fn get_info(&self, info_type: InfoType) -> Option; +} + +/// A fully-executed statement whose results are ready to iterate. +pub trait ActiveStatement: Send + Sync { + fn column_count(&self) -> usize; + fn column_name(&self, index: usize) -> Result; + /// No-op: execution happened eagerly at prepare time. + fn execute(&mut self) -> Result<(), String>; + fn fetch_next_row(&mut self) -> Result; + fn get_data(&self, col_index: usize, target_type: CDataType) -> Result; + fn row_changes(&self) -> u64; +} + +/// Creates `DbConnection` instances from a database path. +/// Register one implementation at startup via `register_factory`. +pub trait DbConnectionFactory: Send + Sync { + fn create_from_path(&self, path: &str) -> Result, String>; +} + +// ─── Global factory registry ───────────────────────────────────────────────── + +static FACTORY: OnceLock> = OnceLock::new(); + +/// Register the database-specific factory. Must be called before any +/// `SQLConnectW` / `SQLDriverConnectW` call. Idempotent (subsequent calls are ignored). +pub fn register_factory(factory: Box) { + let _ = FACTORY.set(factory); +} + +/// Retrieve the registered factory. Panics if `register_factory` was never called. +pub(crate) fn factory() -> &'static dyn DbConnectionFactory { + FACTORY + .get() + .expect("no DbConnectionFactory registered; call register_factory first") + .as_ref() +} + +// ─── Handle types ──────────────────────────────────────────────────────────── + +#[derive(Debug)] +pub struct EnvironmentHandle { + pub odbc_version: AttrOdbcVersion, + pub _output_nts: bool, +} + +impl Default for EnvironmentHandle { + fn default() -> Self { + EnvironmentHandle { + odbc_version: AttrOdbcVersion::Odbc3, + _output_nts: true, + } + } +} + +impl EnvironmentHandle { + pub fn odbc_version(&self) -> AttrOdbcVersion { + self.odbc_version + } + + pub fn set_odbc_version(&mut self, version: AttrOdbcVersion) { + self.odbc_version = version; + } +} + +pub struct ConnectionHandle { + pub connection: Option>, +} + +impl ConnectionHandle { + pub fn allocate_stmt_handle(&self) -> Option { + let connection = self.connection.as_ref()?.clone(); + Some(StatementHandle { + connection, + active_statement: None, + }) + } +} + +pub struct StatementHandle { + pub connection: Arc, + pub active_statement: Option>, +} diff --git a/src/odbc/implementation.rs b/src/odbc/implementation.rs index e196288..d6243b1 100644 --- a/src/odbc/implementation.rs +++ b/src/odbc/implementation.rs @@ -5,4 +5,5 @@ pub(crate) mod env_attrs; pub(crate) mod getdata; #[allow(clippy::module_inception)] pub(crate) mod implementation; +pub(crate) mod query; pub(crate) mod tables; diff --git a/src/odbc/implementation/alloc_handles.rs b/src/odbc/implementation/alloc_handles.rs index 7df3809..5e571ac 100644 --- a/src/odbc/implementation/alloc_handles.rs +++ b/src/odbc/implementation/alloc_handles.rs @@ -1,52 +1,2 @@ -#![allow(dead_code)] -use odbc_sys::AttrOdbcVersion; -use rusqlite::{Connection, Row, Rows, Statement}; - -#[derive(Debug)] -pub struct EnvironmentHandle { - pub odbc_version: AttrOdbcVersion, - pub output_nts: bool, -} - -impl Default for EnvironmentHandle { - fn default() -> Self { - EnvironmentHandle { - odbc_version: AttrOdbcVersion::Odbc3, - output_nts: true, - } - } -} - -#[derive(Debug)] -pub struct ConnectionHandle { - pub sqlite_connection: Option, -} - -pub struct StatementHandle<'a> { - pub sqlite_connection: &'a Connection, // TODO: Make it a reference to the ConnectionHandle instead - pub statement: Option>, - pub rows: Option>, - pub row: Option<&'a Row<'a>>, -} - -pub(crate) fn impl_allocate_environment_handle() -> EnvironmentHandle { - EnvironmentHandle::default() -} - -pub(crate) fn impl_allocate_dbc_handle(_env_handle: &mut EnvironmentHandle) -> ConnectionHandle { - ConnectionHandle { - sqlite_connection: None, - } -} - -pub(crate) fn allocate_stmt_handle( - connection_handle: &mut ConnectionHandle, -) -> Option> { - let connection_ref = connection_handle.sqlite_connection.as_ref()?; - Some(StatementHandle { - sqlite_connection: connection_ref, - statement: None, - rows: None, - row: None, - }) -} +// Handle allocation is now done directly in the api layer using methods on +// EnvironmentHandle, ConnectionHandle, and StatementHandle in handles.rs. diff --git a/src/odbc/implementation/columns.rs b/src/odbc/implementation/columns.rs index f206c74..e19091e 100644 --- a/src/odbc/implementation/columns.rs +++ b/src/odbc/implementation/columns.rs @@ -1,257 +1 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; -use tracing::error; - -pub(crate) fn impl_get_columns<'a>( - statement_handle: &mut StatementHandle<'a>, - table_name: &str, -) -> Result<(), String> { - // Validate table name before interpolating into SQL. - // pragma_table_info does not accept bound parameters, so we must - // ensure the name contains only safe characters. - if !table_name.chars().all(|c| c.is_alphanumeric() || c == '_') { - return Err(format!("Invalid table name: '{table_name}'")); - } - - // pragma_table_info('') is a SQLite table-valued function that returns - // one row per column: cid, name, type, notnull, dflt_value, pk. - // - // We reshape the output into the 18-column result set required by the ODBC - // spec for SQLColumns. Columns that SQLite cannot provide are returned as NULL. - // - // ODBC SQL type codes used here: - // -7 = SQL_BIT, 2 = SQL_NUMERIC, 4 = SQL_INTEGER, - // 8 = SQL_DOUBLE, -4 = SQL_LONGVARBINARY, 12 = SQL_VARCHAR (default) - // Column names from pragma_table_info that are SQLite reserved words must be - // double-quoted: "notnull" and "type". Others ("name", "cid", "dflt_value") are - // quoted for consistency. - let sql = format!( - r#"SELECT - NULL AS TABLE_CAT, - NULL AS TABLE_SCHEM, - '{table_name}' AS TABLE_NAME, - "name" AS COLUMN_NAME, - CASE - WHEN upper("type") LIKE '%INT%' THEN 4 - WHEN upper("type") LIKE '%REAL%' - OR upper("type") LIKE '%FLOAT%' - OR upper("type") LIKE '%DOUBLE%' THEN 8 - WHEN upper("type") LIKE '%BLOB%' THEN -4 - WHEN upper("type") LIKE '%BOOL%' THEN -7 - WHEN upper("type") LIKE '%NUMERIC%' - OR upper("type") LIKE '%DECIMAL%' THEN 2 - ELSE 12 - END AS DATA_TYPE, - CASE WHEN "type" = '' THEN 'TEXT' ELSE "type" END AS TYPE_NAME, - NULL AS COLUMN_SIZE, - NULL AS BUFFER_LENGTH, - NULL AS DECIMAL_DIGITS, - NULL AS NUM_PREC_RADIX, - CASE WHEN "notnull" = 1 THEN 0 ELSE 1 END AS NULLABLE, - NULL AS REMARKS, - "dflt_value" AS COLUMN_DEF, - CASE - WHEN upper("type") LIKE '%INT%' THEN 4 - WHEN upper("type") LIKE '%REAL%' - OR upper("type") LIKE '%FLOAT%' - OR upper("type") LIKE '%DOUBLE%' THEN 8 - WHEN upper("type") LIKE '%BLOB%' THEN -4 - WHEN upper("type") LIKE '%BOOL%' THEN -7 - WHEN upper("type") LIKE '%NUMERIC%' - OR upper("type") LIKE '%DECIMAL%' THEN 2 - ELSE 12 - END AS SQL_DATA_TYPE, - NULL AS SQL_DATETIME_SUB, - NULL AS CHAR_OCTET_LENGTH, - "cid" + 1 AS ORDINAL_POSITION, - CASE WHEN "notnull" = 1 THEN 'NO' ELSE 'YES' END AS IS_NULLABLE - FROM pragma_table_info('{table_name}')"# - ); - - let stmt = statement_handle - .sqlite_connection - .prepare(&sql) - .map_err(|e| { - error!("Failed to prepare column query for '{}': {}", table_name, e); - e.to_string() - })?; - - statement_handle.statement = Some(stmt); - - let rows = statement_handle - .statement - .as_mut() - .unwrap() - .query([]) - .map_err(|e| { - error!("Failed to execute column query for '{}': {}", table_name, e); - e.to_string() - })?; - - // Safety: `rows` borrows from `statement_handle.statement` which has lifetime `'a`. - // The borrow checker infers a shorter lifetime for the reference, but the data lives - // for `'a`. We transmute here to express that — the same pattern the rest of this - // codebase relies on implicitly via `get_from_wrapper`. - let rows: rusqlite::Rows<'a> = unsafe { std::mem::transmute(rows) }; - statement_handle.rows = Some(rows); - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::odbc::implementation::alloc_handles::StatementHandle; - use rusqlite::Connection; - - fn make_test_db() -> Connection { - let conn = Connection::open_in_memory().unwrap(); - conn.execute_batch( - "CREATE TABLE widgets ( - widget_id INTEGER PRIMARY KEY, - name VARCHAR(50) NOT NULL, - weight REAL, - notes TEXT - );", - ) - .unwrap(); - conn - } - - #[test] - fn result_set_has_18_columns() { - let conn = make_test_db(); - let mut handle = StatementHandle { - sqlite_connection: &conn, - statement: None, - rows: None, - row: None, - }; - - impl_get_columns(&mut handle, "widgets").unwrap(); - - let col_count = handle.statement.as_ref().unwrap().column_count(); - assert_eq!(col_count, 18); - } - - #[test] - fn result_set_has_one_row_per_table_column() { - let conn = make_test_db(); - let mut handle = StatementHandle { - sqlite_connection: &conn, - statement: None, - rows: None, - row: None, - }; - - impl_get_columns(&mut handle, "widgets").unwrap(); - - let mut count = 0; - let rows = handle.rows.as_mut().unwrap(); - while rows.next().unwrap().is_some() { - count += 1; - } - assert_eq!(count, 4); // widgets has 4 columns - } - - #[test] - fn rows_contain_correct_table_and_column_names() { - let conn = make_test_db(); - let mut handle = StatementHandle { - sqlite_connection: &conn, - statement: None, - rows: None, - row: None, - }; - - impl_get_columns(&mut handle, "widgets").unwrap(); - - let rows = handle.rows.as_mut().unwrap(); - let row = rows.next().unwrap().unwrap(); - - // Column 3 (index 2) = TABLE_NAME, column 4 (index 3) = COLUMN_NAME - let table_name: String = row.get(2).unwrap(); - let column_name: String = row.get(3).unwrap(); - - assert_eq!(table_name, "widgets"); - assert_eq!(column_name, "widget_id"); - } - - #[test] - fn ordinal_position_is_one_based() { - let conn = make_test_db(); - let mut handle = StatementHandle { - sqlite_connection: &conn, - statement: None, - rows: None, - row: None, - }; - - impl_get_columns(&mut handle, "widgets").unwrap(); - - let rows = handle.rows.as_mut().unwrap(); - let row = rows.next().unwrap().unwrap(); - - // Column 17 (index 16) = ORDINAL_POSITION - let ordinal: i64 = row.get(16).unwrap(); - assert_eq!(ordinal, 1); - } - - #[test] - fn not_null_column_has_nullable_zero() { - let conn = make_test_db(); - let mut handle = StatementHandle { - sqlite_connection: &conn, - statement: None, - rows: None, - row: None, - }; - - impl_get_columns(&mut handle, "widgets").unwrap(); - - let rows = handle.rows.as_mut().unwrap(); - - // widget_id is PRIMARY KEY (implicitly NOT NULL); skip to name (NOT NULL) - rows.next().unwrap(); // widget_id - let row = rows.next().unwrap().unwrap(); // name VARCHAR(50) NOT NULL - - // Column 11 (index 10) = NULLABLE: 0 = not nullable - let nullable: i64 = row.get(10).unwrap(); - assert_eq!(nullable, 0); - } - - #[test] - fn nullable_column_has_nullable_one() { - let conn = make_test_db(); - let mut handle = StatementHandle { - sqlite_connection: &conn, - statement: None, - rows: None, - row: None, - }; - - impl_get_columns(&mut handle, "widgets").unwrap(); - - let rows = handle.rows.as_mut().unwrap(); - rows.next().unwrap(); // widget_id - rows.next().unwrap(); // name - let row = rows.next().unwrap().unwrap(); // weight REAL (nullable) - - // Column 11 (index 10) = NULLABLE: 1 = nullable - let nullable: i64 = row.get(10).unwrap(); - assert_eq!(nullable, 1); - } - - #[test] - fn rejects_table_name_with_special_characters() { - let conn = make_test_db(); - let mut handle = StatementHandle { - sqlite_connection: &conn, - statement: None, - rows: None, - row: None, - }; - - let result = impl_get_columns(&mut handle, "widgets; DROP TABLE widgets"); - assert!(result.is_err()); - } -} +// SQLColumns logic is now implemented as DbConnection::get_columns on SqliteDbConnection in query.rs. diff --git a/src/odbc/implementation/connect.rs b/src/odbc/implementation/connect.rs index 594ee30..2f8da9c 100644 --- a/src/odbc/implementation/connect.rs +++ b/src/odbc/implementation/connect.rs @@ -1,49 +1 @@ -use crate::odbc::implementation::alloc_handles::ConnectionHandle; -use crate::odbc::utils::get_private_profile_string; -use rusqlite::{Connection, OpenFlags}; -use tracing::{error, info}; - -/// Connect via DSN name — looks up Database from ODBC configuration. -/// Used by SQLConnectW. -pub(crate) fn impl_connect( - connection_handle: &mut ConnectionHandle, - server_name: String, - _user_name: Option, - _authentication: Option, -) { - let database = match get_private_profile_string(&server_name, "Database", "odbc.ini", 1024) { - Ok(Some(dsn)) => dsn, - Ok(None) => { - error!("Error: Database setting not found"); - "TODO".to_string() - } - Err(e) => { - error!("Error: Database setting not found: {}", e); - "TODO".to_string() - } - }; - info!("Opening [{}] for DSN [{}]", database, server_name); - impl_connect_to_database(connection_handle, database); -} - -/// Connect directly to a database file path. -/// Used by SQLDriverConnectW after resolving the database path. -pub(crate) fn impl_connect_to_database( - connection_handle: &mut ConnectionHandle, - database_path: String, -) { - let conn = match Connection::open_with_flags( - &database_path, - OpenFlags::SQLITE_OPEN_READ_WRITE - | OpenFlags::SQLITE_OPEN_URI - | OpenFlags::SQLITE_OPEN_NO_MUTEX, - ) { - Ok(conn) => conn, - Err(e) => { - error!("Connection failed: {}", e); - return; - } - }; - - connection_handle.sqlite_connection = Some(conn); -} +// Connection logic is now handled directly in the api layer via handles::factory(). diff --git a/src/odbc/implementation/env_attrs.rs b/src/odbc/implementation/env_attrs.rs index 52d34fc..ceaf01c 100644 --- a/src/odbc/implementation/env_attrs.rs +++ b/src/odbc/implementation/env_attrs.rs @@ -1,11 +1 @@ -use crate::odbc::implementation::alloc_handles::EnvironmentHandle; -use odbc_sys::AttrOdbcVersion; - -// TODO: Make this return something.... a Result? -pub(crate) fn set_odbc_version(env: &mut EnvironmentHandle, odbc_version: AttrOdbcVersion) { - env.odbc_version = odbc_version; -} - -pub(crate) fn get_odbc_version(env: &EnvironmentHandle) -> AttrOdbcVersion { - env.odbc_version -} +// Environment attribute access is now done via methods on EnvironmentHandle in handles.rs. diff --git a/src/odbc/implementation/getdata.rs b/src/odbc/implementation/getdata.rs index 26f8dce..f021888 100644 --- a/src/odbc/implementation/getdata.rs +++ b/src/odbc/implementation/getdata.rs @@ -1,109 +1,2 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; -use odbc_sys::CDataType; -use rusqlite::types::ValueRef; -use tracing::{error, warn}; - -pub(crate) fn impl_getdata( - statement_handle: &StatementHandle, - target_type: &CDataType, - col_or_param: u16, -) -> String { - let row = match &statement_handle.row { - Some(row) => row, - None => { - error!("No current row available"); - return "ERROR".to_string(); - } - }; - - let col_index = (col_or_param - 1) as usize; - - // Get the raw value from SQLite - let value_ref: ValueRef = match row.get_ref(col_index) { - Ok(value_ref) => value_ref, - Err(err) => { - error!("Failed to get column {}: {}", col_index, err); - return "ERROR".to_string(); - } - }; - - // Convert the SQLite value to the requested ODBC type - match target_type { - CDataType::Char => { - // Convert any SQLite type to string representation - match value_ref { - ValueRef::Null => "NULL".to_string(), - ValueRef::Integer(i) => i.to_string(), - ValueRef::Real(f) => f.to_string(), - ValueRef::Text(s) => match std::str::from_utf8(s) { - Ok(text) => text.to_string(), - Err(err) => { - error!("UTF-8 conversion failed: {}", err); - "ERROR".to_string() - } - }, - ValueRef::Blob(b) => { - // Convert blob to hex string representation - b.iter() - .map(|byte| format!("{byte:02x}")) - .collect::() - } - } - } - CDataType::SLong => { - // Convert to integer - match value_ref { - ValueRef::Integer(i) => i.to_string(), - ValueRef::Real(f) => (f as i64).to_string(), - ValueRef::Text(s) => match std::str::from_utf8(s) { - Ok(text) => match text.parse::() { - Ok(i) => i.to_string(), - Err(_) => "0".to_string(), - }, - Err(_) => "0".to_string(), - }, - ValueRef::Null => "0".to_string(), - ValueRef::Blob(_) => "0".to_string(), - } - } - CDataType::Double => { - // Convert to float - match value_ref { - ValueRef::Real(f) => f.to_string(), - ValueRef::Integer(i) => (i as f64).to_string(), - ValueRef::Text(s) => match std::str::from_utf8(s) { - Ok(text) => match text.parse::() { - Ok(f) => f.to_string(), - Err(_) => "0.0".to_string(), - }, - Err(_) => "0.0".to_string(), - }, - ValueRef::Null => "0.0".to_string(), - ValueRef::Blob(_) => "0.0".to_string(), - } - } - _ => { - // For unsupported types, convert to string as fallback - warn!( - "Unsupported target type {:?}, converting to string", - target_type - ); - match value_ref { - ValueRef::Null => "NULL".to_string(), - ValueRef::Integer(i) => i.to_string(), - ValueRef::Real(f) => f.to_string(), - ValueRef::Text(s) => match std::str::from_utf8(s) { - Ok(text) => text.to_string(), - Err(err) => { - error!("UTF-8 conversion failed: {}", err); - "ERROR".to_string() - } - }, - ValueRef::Blob(b) => b - .iter() - .map(|byte| format!("{byte:02x}")) - .collect::(), - } - } - } -} +// Data retrieval logic has moved to `ActiveStatement::get_data` in `query.rs`. +// This module is kept as a placeholder to avoid disturbing the module tree. diff --git a/src/odbc/implementation/implementation.rs b/src/odbc/implementation/implementation.rs index 1789d39..8fc6810 100644 --- a/src/odbc/implementation/implementation.rs +++ b/src/odbc/implementation/implementation.rs @@ -1,19 +1 @@ -use odbc_sys::{InfoType, InfoTypeType}; - -pub(crate) fn get_info(info_type: InfoType) -> Option { - match info_type { - InfoType::ActiveEnvironments => Some(InfoTypeType::SqlUSmallInt(0)), - InfoType::UserName => Some(InfoTypeType::String("foo".to_string())), - InfoType::MaxConcurrentActivities => Some(InfoTypeType::SqlUSmallInt(1)), - InfoType::ScrollOptions => Some(InfoTypeType::SqlUInteger(1)), // Proper u32 for test - _ => None, - } -} - -/* -pub(crate) fn get_supported_functions() -> Vec { - vec![FunctionId::SqlAllocConnect, FunctionId::SqlAllocHandle] -} - - - */ +// get_info is now implemented as DbConnection::get_info on SqliteDbConnection in query.rs. diff --git a/src/odbc/implementation/query.rs b/src/odbc/implementation/query.rs new file mode 100644 index 0000000..0083433 --- /dev/null +++ b/src/odbc/implementation/query.rs @@ -0,0 +1,473 @@ +use crate::odbc::handles::{ActiveStatement, DbConnection, DbConnectionFactory}; +use odbc_sys::{CDataType, InfoType, InfoTypeType}; +use rusqlite::OpenFlags; +use rusqlite::types::Value; +use std::sync::{Arc, Mutex}; +use tracing::error; + +// ─── Factory ───────────────────────────────────────────────────────────────── + +pub(crate) struct SqliteDbConnectionFactory; + +impl DbConnectionFactory for SqliteDbConnectionFactory { + fn create_from_path(&self, path: &str) -> Result, String> { + let conn = rusqlite::Connection::open_with_flags( + path, + OpenFlags::SQLITE_OPEN_READ_WRITE + | OpenFlags::SQLITE_OPEN_URI + | OpenFlags::SQLITE_OPEN_NO_MUTEX, + ) + .map_err(|e| e.to_string())?; + Ok(Arc::new(SqliteDbConnection::new(conn))) + } +} + +// ─── Connection ────────────────────────────────────────────────────────────── + +pub(crate) struct SqliteDbConnection { + connection: Arc>, +} + +impl SqliteDbConnection { + pub(crate) fn new(connection: rusqlite::Connection) -> Self { + SqliteDbConnection { + connection: Arc::new(Mutex::new(connection)), + } + } +} + +impl DbConnection for SqliteDbConnection { + fn prepare_statement(&self, sql: &str) -> Result, String> { + let conn = self.connection.lock().map_err(|e| e.to_string())?; + + let mut stmt = conn.prepare(sql).map_err(|e| e.to_string())?; + + let column_names: Vec = stmt.column_names().iter().map(|s| s.to_string()).collect(); + let column_count = column_names.len(); + + // Capture readonly flag before query() consumes the statement. + let is_readonly = stmt.readonly(); + + let mut rows_data: Vec> = Vec::new(); + let mut raw_rows = stmt.query([]).map_err(|e| e.to_string())?; + while let Some(row) = raw_rows.next().map_err(|e| e.to_string())? { + let row_data: Vec = (0..column_count) + .map(|i| row.get::<_, Value>(i).unwrap_or(Value::Null)) + .collect(); + rows_data.push(row_data); + } + + // For read-only queries (SELECT), changes() reflects prior DML — not this statement. + let changes = if is_readonly { 0 } else { conn.changes() }; + + Ok(Box::new(SqliteStatement { + column_names, + rows: rows_data, + current_row: None, + changes, + })) + } + + fn get_tables(&self) -> Result, String> { + self.prepare_statement("SELECT name FROM sqlite_master WHERE type='table'") + } + + fn get_columns(&self, table_name: &str) -> Result, String> { + // Validate table name before interpolating into SQL. + // pragma_table_info does not accept bound parameters. + if !table_name.chars().all(|c| c.is_alphanumeric() || c == '_') { + return Err(format!("Invalid table name: '{table_name}'")); + } + + // Reshape pragma_table_info output into the ODBC 18-column SQLColumns result set. + // ODBC SQL type codes: -7=BIT, 2=NUMERIC, 4=INTEGER, 8=DOUBLE, -4=LONGVARBINARY, 12=VARCHAR + let sql = format!( + r#"SELECT + NULL AS TABLE_CAT, + NULL AS TABLE_SCHEM, + '{table_name}' AS TABLE_NAME, + "name" AS COLUMN_NAME, + CASE + WHEN upper("type") LIKE '%INT%' THEN 4 + WHEN upper("type") LIKE '%REAL%' + OR upper("type") LIKE '%FLOAT%' + OR upper("type") LIKE '%DOUBLE%' THEN 8 + WHEN upper("type") LIKE '%BLOB%' THEN -4 + WHEN upper("type") LIKE '%BOOL%' THEN -7 + WHEN upper("type") LIKE '%NUMERIC%' + OR upper("type") LIKE '%DECIMAL%' THEN 2 + ELSE 12 + END AS DATA_TYPE, + CASE WHEN "type" = '' THEN 'TEXT' ELSE "type" END AS TYPE_NAME, + NULL AS COLUMN_SIZE, + NULL AS BUFFER_LENGTH, + NULL AS DECIMAL_DIGITS, + NULL AS NUM_PREC_RADIX, + CASE WHEN "notnull" = 1 THEN 0 ELSE 1 END AS NULLABLE, + NULL AS REMARKS, + "dflt_value" AS COLUMN_DEF, + CASE + WHEN upper("type") LIKE '%INT%' THEN 4 + WHEN upper("type") LIKE '%REAL%' + OR upper("type") LIKE '%FLOAT%' + OR upper("type") LIKE '%DOUBLE%' THEN 8 + WHEN upper("type") LIKE '%BLOB%' THEN -4 + WHEN upper("type") LIKE '%BOOL%' THEN -7 + WHEN upper("type") LIKE '%NUMERIC%' + OR upper("type") LIKE '%DECIMAL%' THEN 2 + ELSE 12 + END AS SQL_DATA_TYPE, + NULL AS SQL_DATETIME_SUB, + NULL AS CHAR_OCTET_LENGTH, + "cid" + 1 AS ORDINAL_POSITION, + CASE WHEN "notnull" = 1 THEN 'NO' ELSE 'YES' END AS IS_NULLABLE + FROM pragma_table_info('{table_name}')"# + ); + + self.prepare_statement(&sql).map_err(|e| { + error!("Failed to prepare column query for '{}': {}", table_name, e); + e + }) + } + + fn get_info(&self, info_type: InfoType) -> Option { + match info_type { + InfoType::ActiveEnvironments => Some(InfoTypeType::SqlUSmallInt(0)), + InfoType::UserName => Some(InfoTypeType::String("foo".to_string())), + InfoType::MaxConcurrentActivities => Some(InfoTypeType::SqlUSmallInt(1)), + InfoType::ScrollOptions => Some(InfoTypeType::SqlUInteger(1)), + _ => None, + } + } +} + +// ─── Statement ─────────────────────────────────────────────────────────────── + +pub(crate) struct SqliteStatement { + column_names: Vec, + rows: Vec>, + current_row: Option, + changes: u64, +} + +impl ActiveStatement for SqliteStatement { + fn column_count(&self) -> usize { + self.column_names.len() + } + + fn column_name(&self, index: usize) -> Result { + self.column_names + .get(index) + .cloned() + .ok_or_else(|| format!("Column index {index} out of range")) + } + + fn execute(&mut self) -> Result<(), String> { + Ok(()) + } + + fn fetch_next_row(&mut self) -> Result { + let next = match self.current_row { + None => 0, + Some(i) => i + 1, + }; + if next < self.rows.len() { + self.current_row = Some(next); + Ok(true) + } else { + Ok(false) + } + } + + fn get_data(&self, col_index: usize, target_type: CDataType) -> Result { + let row_index = self + .current_row + .ok_or("No current row; call SQLFetch first")?; + let row = self.rows.get(row_index).ok_or("Row index out of range")?; + let value = row + .get(col_index) + .ok_or_else(|| format!("Column index {col_index} out of range"))?; + + let result = match target_type { + CDataType::Char => value_to_string(value), + CDataType::SLong => value_to_slong(value), + CDataType::Double => value_to_double(value), + _ => value_to_string(value), + }; + Ok(result) + } + + fn row_changes(&self) -> u64 { + self.changes + } +} + +// ─── Value converters ──────────────────────────────────────────────────────── + +fn value_to_string(value: &Value) -> String { + match value { + Value::Null => "NULL".to_string(), + Value::Integer(i) => i.to_string(), + Value::Real(f) => f.to_string(), + Value::Text(s) => s.clone(), + Value::Blob(b) => b.iter().map(|byte| format!("{byte:02x}")).collect(), + } +} + +fn value_to_slong(value: &Value) -> String { + match value { + Value::Integer(i) => i.to_string(), + Value::Real(f) => (*f as i64).to_string(), + Value::Text(s) => s + .parse::() + .map(|i| i.to_string()) + .unwrap_or_else(|_| "0".to_string()), + Value::Null => "0".to_string(), + Value::Blob(_) => "0".to_string(), + } +} + +fn value_to_double(value: &Value) -> String { + match value { + Value::Real(f) => f.to_string(), + Value::Integer(i) => (*i as f64).to_string(), + Value::Text(s) => s + .parse::() + .map(|f| f.to_string()) + .unwrap_or_else(|_| "0.0".to_string()), + Value::Null => "0.0".to_string(), + Value::Blob(_) => "0.0".to_string(), + } +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::odbc::handles::DbConnection; + use odbc_sys::CDataType; + use rusqlite::Connection; + + fn make_test_db() -> SqliteDbConnection { + let conn = Connection::open_in_memory().unwrap(); + conn.execute_batch( + "CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + score REAL + ); + INSERT INTO users VALUES (1, 'Alice', 9.5); + INSERT INTO users VALUES (2, 'Bob', 7.0);", + ) + .unwrap(); + SqliteDbConnection::new(conn) + } + + fn make_widgets_db() -> SqliteDbConnection { + let conn = Connection::open_in_memory().unwrap(); + conn.execute_batch( + "CREATE TABLE widgets ( + widget_id INTEGER PRIMARY KEY, + name VARCHAR(50) NOT NULL, + weight REAL, + notes TEXT + );", + ) + .unwrap(); + SqliteDbConnection::new(conn) + } + + // ── prepare_statement ──────────────────────────────────────────────────── + + #[test] + fn prepare_returns_correct_column_count() { + let db = make_test_db(); + let stmt = db + .prepare_statement("SELECT id, name, score FROM users") + .unwrap(); + assert_eq!(stmt.column_count(), 3); + } + + #[test] + fn prepare_returns_correct_column_names() { + let db = make_test_db(); + let stmt = db + .prepare_statement("SELECT id, name, score FROM users") + .unwrap(); + assert_eq!(stmt.column_name(0).unwrap(), "id"); + assert_eq!(stmt.column_name(1).unwrap(), "name"); + assert_eq!(stmt.column_name(2).unwrap(), "score"); + } + + #[test] + fn column_name_out_of_range_returns_error() { + let db = make_test_db(); + let stmt = db.prepare_statement("SELECT id FROM users").unwrap(); + assert!(stmt.column_name(99).is_err()); + } + + #[test] + fn fetch_next_row_iterates_through_all_rows() { + let db = make_test_db(); + let mut stmt = db.prepare_statement("SELECT id FROM users").unwrap(); + assert_eq!(stmt.fetch_next_row().unwrap(), true); + assert_eq!(stmt.fetch_next_row().unwrap(), true); + assert_eq!(stmt.fetch_next_row().unwrap(), false); + } + + #[test] + fn fetch_next_row_returns_false_immediately_on_empty_result() { + let db = make_test_db(); + let mut stmt = db + .prepare_statement("SELECT id FROM users WHERE id = 999") + .unwrap(); + assert_eq!(stmt.fetch_next_row().unwrap(), false); + } + + #[test] + fn get_data_returns_correct_string_values() { + let db = make_test_db(); + let mut stmt = db + .prepare_statement("SELECT id, name, score FROM users ORDER BY id") + .unwrap(); + stmt.fetch_next_row().unwrap(); + assert_eq!(stmt.get_data(0, CDataType::Char).unwrap(), "1"); + assert_eq!(stmt.get_data(1, CDataType::Char).unwrap(), "Alice"); + } + + #[test] + fn get_data_slong_returns_integer_string() { + let db = make_test_db(); + let mut stmt = db + .prepare_statement("SELECT id FROM users ORDER BY id") + .unwrap(); + stmt.fetch_next_row().unwrap(); + assert_eq!(stmt.get_data(0, CDataType::SLong).unwrap(), "1"); + } + + #[test] + fn get_data_double_returns_float_string() { + let db = make_test_db(); + let mut stmt = db + .prepare_statement("SELECT score FROM users ORDER BY id") + .unwrap(); + stmt.fetch_next_row().unwrap(); + assert_eq!(stmt.get_data(0, CDataType::Double).unwrap(), "9.5"); + } + + #[test] + fn execute_is_a_noop_and_rows_remain_iterable() { + let db = make_test_db(); + let mut stmt = db.prepare_statement("SELECT id FROM users").unwrap(); + stmt.execute().unwrap(); + assert_eq!(stmt.fetch_next_row().unwrap(), true); + } + + #[test] + fn non_select_has_zero_rows_and_nonzero_changes() { + let db = make_test_db(); + let stmt = db + .prepare_statement("INSERT INTO users VALUES (3, 'Carol', 8.0)") + .unwrap(); + assert_eq!(stmt.column_count(), 0); + assert_eq!(stmt.row_changes(), 1); + } + + #[test] + fn select_has_zero_changes() { + let db = make_test_db(); + let stmt = db.prepare_statement("SELECT id FROM users").unwrap(); + assert_eq!(stmt.row_changes(), 0); + } + + // ── get_tables ─────────────────────────────────────────────────────────── + + #[test] + fn get_tables_result_set_is_populated() { + let db = make_test_db(); + assert!(db.get_tables().is_ok()); + } + + #[test] + fn get_tables_returns_at_least_one_row_per_table() { + let db = make_test_db(); + let mut stmt = db.get_tables().unwrap(); + let mut count = 0; + while stmt.fetch_next_row().unwrap() { + count += 1; + } + assert!(count >= 1); + } + + #[test] + fn get_tables_row_contains_nonempty_name() { + let db = make_test_db(); + let mut stmt = db.get_tables().unwrap(); + stmt.fetch_next_row().unwrap(); + let name = stmt.get_data(0, CDataType::Char).unwrap(); + assert!(!name.is_empty()); + } + + // ── get_columns ────────────────────────────────────────────────────────── + + #[test] + fn get_columns_result_set_has_18_columns() { + let db = make_widgets_db(); + let stmt = db.get_columns("widgets").unwrap(); + assert_eq!(stmt.column_count(), 18); + } + + #[test] + fn get_columns_has_one_row_per_table_column() { + let db = make_widgets_db(); + let mut stmt = db.get_columns("widgets").unwrap(); + let mut count = 0; + while stmt.fetch_next_row().unwrap() { + count += 1; + } + assert_eq!(count, 4); + } + + #[test] + fn get_columns_row_contains_table_and_column_name() { + let db = make_widgets_db(); + let mut stmt = db.get_columns("widgets").unwrap(); + stmt.fetch_next_row().unwrap(); + assert_eq!(stmt.get_data(2, CDataType::Char).unwrap(), "widgets"); + assert_eq!(stmt.get_data(3, CDataType::Char).unwrap(), "widget_id"); + } + + #[test] + fn get_columns_ordinal_position_is_one_based() { + let db = make_widgets_db(); + let mut stmt = db.get_columns("widgets").unwrap(); + stmt.fetch_next_row().unwrap(); + assert_eq!(stmt.get_data(16, CDataType::SLong).unwrap(), "1"); + } + + #[test] + fn get_columns_not_null_has_nullable_zero() { + let db = make_widgets_db(); + let mut stmt = db.get_columns("widgets").unwrap(); + stmt.fetch_next_row().unwrap(); // widget_id + stmt.fetch_next_row().unwrap(); // name NOT NULL + assert_eq!(stmt.get_data(10, CDataType::SLong).unwrap(), "0"); + } + + #[test] + fn get_columns_nullable_has_nullable_one() { + let db = make_widgets_db(); + let mut stmt = db.get_columns("widgets").unwrap(); + stmt.fetch_next_row().unwrap(); // widget_id + stmt.fetch_next_row().unwrap(); // name + stmt.fetch_next_row().unwrap(); // weight REAL (nullable) + assert_eq!(stmt.get_data(10, CDataType::SLong).unwrap(), "1"); + } + + #[test] + fn get_columns_rejects_invalid_table_name() { + let db = make_widgets_db(); + assert!(db.get_columns("widgets; DROP TABLE widgets").is_err()); + } +} diff --git a/src/odbc/implementation/tables.rs b/src/odbc/implementation/tables.rs index 56d5d5e..5618369 100644 --- a/src/odbc/implementation/tables.rs +++ b/src/odbc/implementation/tables.rs @@ -1,100 +1 @@ -use crate::odbc::implementation::alloc_handles::StatementHandle; - -pub(crate) fn impl_get_tables<'a>( - statement_handle: &mut StatementHandle<'a>, -) -> Result<(), String> { - let stmt = statement_handle - .sqlite_connection - .prepare("SELECT name FROM sqlite_master WHERE type='table'") - .map_err(|e| e.to_string())?; - - statement_handle.statement = Some(stmt); - - let rows = statement_handle - .statement - .as_mut() - .unwrap() - .query([]) - .map_err(|e| e.to_string())?; - - // Safety: `rows` borrows from `statement_handle.statement` which has lifetime `'a`. - // We transmute the inferred shorter reference lifetime to `'a` — the same pattern - // used throughout this codebase via `get_from_wrapper`. - let rows: rusqlite::Rows<'a> = unsafe { std::mem::transmute(rows) }; - statement_handle.rows = Some(rows); - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::odbc::implementation::alloc_handles::StatementHandle; - use rusqlite::Connection; - - fn make_test_db() -> Connection { - let conn = Connection::open_in_memory().unwrap(); - conn.execute_batch( - "CREATE TABLE apples (id INTEGER PRIMARY KEY); - CREATE TABLE oranges (id INTEGER PRIMARY KEY);", - ) - .unwrap(); - conn - } - - #[test] - fn result_set_is_populated_after_call() { - let conn = make_test_db(); - let mut handle = StatementHandle { - sqlite_connection: &conn, - statement: None, - rows: None, - row: None, - }; - - let result = impl_get_tables(&mut handle); - - assert!(result.is_ok()); - assert!(handle.statement.is_some()); - assert!(handle.rows.is_some()); - } - - #[test] - fn returns_one_row_per_table() { - let conn = make_test_db(); - let mut handle = StatementHandle { - sqlite_connection: &conn, - statement: None, - rows: None, - row: None, - }; - - impl_get_tables(&mut handle).unwrap(); - - let mut count = 0; - let rows = handle.rows.as_mut().unwrap(); - while rows.next().unwrap().is_some() { - count += 1; - } - // sqlite_sequence is also present (autoincrement tracking table) - assert!(count >= 2); - } - - #[test] - fn row_contains_table_name() { - let conn = make_test_db(); - let mut handle = StatementHandle { - sqlite_connection: &conn, - statement: None, - rows: None, - row: None, - }; - - impl_get_tables(&mut handle).unwrap(); - - let rows = handle.rows.as_mut().unwrap(); - let row = rows.next().unwrap().unwrap(); - let name: String = row.get(0).unwrap(); - assert!(!name.is_empty()); - } -} +// SQLTables logic is now implemented as DbConnection::get_tables on SqliteDbConnection in query.rs.