diff --git a/docs/connection_string_allow_list_design.md b/docs/connection_string_allow_list_design.md new file mode 100644 index 000000000..abbb85522 --- /dev/null +++ b/docs/connection_string_allow_list_design.md @@ -0,0 +1,1577 @@ +# Connection String Allow-List Design for mssql-python + +**Date:** October 23, 2025 +**Author:** Engineering Team +**Status:** Design Proposal + +--- + +## Executive Summary + +This document outlines the design for implementing a **connection string parameter allow-list** in the mssql-python driver. Currently, the driver applies a limited allow-list in `_construct_connection_string()` to **kwargs only**, but passes the base connection string **as-is** to the ODBC driver. The new design implements a comprehensive parser that uses **lenient ODBC-style parsing** to extract parameters and validates **all** connection string parameters against an allow-list before passing them to ODBC Driver 18 for SQL Server. + +**Key Design Philosophy**: The parser follows **ODBC driver behavior** - it parses valid key=value pairs and **ignores** malformed entries (logging a warning for each), rather than raising exceptions. This matches the behavior observed in the Microsoft ODBC driver, so connection strings that the ODBC driver accepts keep working. + +This allow-list approach is necessary for three key reasons: + +1. **ODBC Feature Compatibility**: Some ODBC connection string parameters require additional configurations (e.g., Always Encrypted extensibility modules) for which Python doesn't have a first-class experience yet. Allowing these parameters without proper support would create confusion and support burden. + +2. **Future Driver Evolution**: The driver may evolve its underlying implementation over time. While ODBC parity is a goal, not all ODBC features may remain available as the driver evolves. By deliberately allow-listing parameters now, we can ensure a smoother evolution path and avoid the breaking change of removing previously exposed parameters later. 
It's easier to add parameters over time than to remove them once users depend on them. + +3. **Simplified Connection Experience**: ODBC connection strings have accumulated many parameter synonyms over decades of backward compatibility (e.g., "server", "address", "addr", "network address" all mean the same thing). A modern Python driver should provide a cleaner, simplified API by exposing only a curated set of parameters with clear, consistent naming. + +--- + +## Table of Contents + +1. [Problem Statement](#problem-statement) +2. [Current Implementation Analysis](#current-implementation-analysis) +3. [Design Goals](#design-goals) +4. [Architecture Overview](#architecture-overview) +5. [Connection String Parser Design](#connection-string-parser-design) +6. [Allow-List Strategy](#allow-list-strategy) +7. [Performance Considerations](#performance-considerations) +8. [Data Flow Diagrams](#data-flow-diagrams) +9. [Implementation Details](#implementation-details) +10. [Testing Strategy](#testing-strategy) +11. [Design Considerations](#design-considerations) +12. [Future Enhancements](#future-enhancements) + +--- + +## Problem Statement + +### Current Issues + +1. **Inconsistent Filtering**: The driver currently: + - Filters **kwargs** through an allow-list (only 6 parameters: Server, Uid, Pwd, Database, Encrypt, TrustServerCertificate) + - Passes the base `connection_str` parameter **directly** to ODBC without validation + +2. **ODBC Feature Compatibility**: Some ODBC connection string parameters require additional infrastructure: + - Always Encrypted with extensibility modules requires custom key store providers + - Column Encryption Key caching requires additional Python bindings + - These features don't have first-class Python API support yet + - Allowing these parameters creates user confusion and support issues + +3. 
**Future Driver Evolution**: The driver may evolve its underlying implementation: + - While ODBC parity is a goal, not all ODBC features may remain available as the driver evolves + - Some ODBC-specific parameters may not translate to future implementations + - Being deliberate about which parameters to expose avoids future breaking changes + - It's easier to add parameters over time than to remove them once users depend on them + - Gating parameters now prevents users from building dependencies on features that may not be available + +4. **Parameter Synonym Bloat**: ODBC connection strings have accumulated many synonyms: + - "Server", "Address", "Addr", "Network Address" all mean the same thing + - "Uid", "User", "User ID" all mean the same thing + - This creates confusion and inconsistent usage patterns + - A modern Python driver should have a clean, minimal API surface + +5. **No ODBC-Aware Parsing**: The current implementation uses simple string splitting on `;`, which doesn't handle: + - Braced (quoted) values such as `{a;b}` + - Escaped braces (`}}`) inside braced values + - Empty values + - Malformed connection strings + +6. **Parser Behavior Mismatch**: The original strict parser design would raise exceptions for malformed connection strings, but **ODBC drivers use lenient parsing** - they ignore malformed entries and continue parsing valid ones. This mismatch could break user code whose connection strings are accepted by the ODBC driver today. + + **Citations**: + - `mssql_python/helpers.py`, lines 28-30: `connection_attributes = connection_str.split(";")` - splits on semicolon without handling braced values. Passwords can contain characters, such as `;`, that act as delimiters in connection strings. 
+ - `mssql_python/helpers.py`, line 33: `if attribute.lower().split("=")[0] == "driver":` - splits on `=` without handling escaped or braced values + - `mssql_python/helpers.py`, lines 66-67: `for param in parameters:` / `if param.lower().startswith("app="):` - simple string operations, no ODBC-compliant parsing + - `mssql_python/helpers.py`, line 69: `key, _ = param.split("=", 1)` - splits on first `=` only, doesn't handle braces or escaping + - **ODBC Driver Investigation**: Research into the ODBC driver codebase (`/Sql/Ntdbms/sqlncli/odbc/sqlcconn.cpp`) confirms that ODBC uses lenient parsing - malformed entries without `=` are silently ignored with `hr = S_FALSE; goto RetExit`, and parsing continues for subsequent parameters. + +### Design Motivations + +1. **Controlled Feature Set**: By implementing an allow-list, we can: + - Only expose ODBC features that have proper Python API support + - Prevent users from attempting to use unsupported features + - Reduce the support burden by rejecting parameters we can't properly handle + +2. **Migration Path**: The allow-list provides: + - A stable API surface that will work across current and future driver implementations + - Clear documentation of what parameters are supported + - A deliberate, controlled approach to exposing parameters (easier to add than remove) + - Protection against breaking changes when evolving the driver + - Ability to achieve ODBC parity incrementally while maintaining backward compatibility + +3. **Simplified API**: By normalizing synonyms and exposing only canonical parameter names: + - Users have a consistent, predictable API + - Documentation is clearer + - Code examples are more uniform + - New Python developers aren't confused by legacy ODBC conventions + +4. 
**ODBC-Compatible Behavior**: By using lenient parsing that matches ODBC driver behavior: + - User code that works with ODBC/other Microsoft drivers continues to work + - Malformed connection string entries are handled gracefully (logged but not fatal) + - Maximum compatibility with existing connection string patterns + - Debugging is easier with warning logs for problematic entries + +### Requirements + +1. Parse the complete connection string (base + kwargs) using **lenient ODBC-style parsing** +2. Validate all parameters against an allow-list +3. Reconstruct a clean connection string with only allowed parameters +4. Maintain backward compatibility with existing code +5. Ensure high performance (sub-millisecond overhead) +6. Handle ODBC connection string syntax correctly +7. Normalize parameter synonyms to canonical names +8. Prepare for future driver enhancements +9. **Match ODBC driver behavior**: Silently ignore malformed entries with warning logs, never raise exceptions for syntax errors during parsing +10. **Provide diagnostic logging**: Log warnings for ignored/malformed entries to help users debug connection string issues + +--- + +## Current Implementation Analysis + +### Code Flow (Before This Design) + +``` +User Input: + ├─ connection_str: "Server=localhost;Database=mydb;SomeParam=value" + └─ kwargs: {encrypt: "yes", server: "override"} + +Current Flow: + 1. add_driver_name_to_app_parameter(connection_str) + ├─ Finds any "APP=" parameter (case-insensitive) + ├─ Overwrites value to "MSSQL-Python" (preserves key casing) + └─ Adds "APP=MSSQL-Python" if not present + + 2. add_driver_to_connection_str(connection_str) + ├─ Strips any existing "Driver=" params (always removed) + ├─ Adds "Driver={ODBC Driver 18 for SQL Server}" at position 0 + └─ Returns: "Driver={ODBC Driver 18 for SQL Server};Server=localhost;Database=mydb;SomeParam=value;APP=MSSQL-Python" + + 3. 
_construct_connection_string(connection_str, **kwargs) + ├─ Takes output from step 2 + ├─ Appends only ALLOW-LISTED kwargs: + │ ├─ server → "Server" + │ ├─ user/uid → "Uid" + │ ├─ password/pwd → "Pwd" + │ ├─ database → "Database" + │ ├─ encrypt → "Encrypt" + │ └─ trust_server_certificate → "TrustServerCertificate" + ├─ **else: continue** (filters out other kwargs) + └─ Returns: "Driver={ODBC Driver 18 for SQL Server};Server=localhost;Database=mydb;SomeParam=value;APP=MSSQL-Python;Encrypt=yes" + + 4. ddbc_bindings.Connection(connection_str, ...) + └─ Passes final string to ODBC (including "SomeParam=value" - UNFILTERED!) +``` + +### Key Observations + +**1. Deliberate Driver and APP Control** (by design): + +The driver **intentionally** controls these two parameters to ensure consistent behavior: + +- **Driver Parameter** (`helpers.py:38-49`): Any user-provided `Driver=` value is **stripped and replaced** with `{ODBC Driver 18 for SQL Server}`. This ensures the Python driver always uses the correct ODBC driver version. + +- **APP Parameter** (`helpers.py:99-109`): Any user-provided `APP=` value is **overwritten** to `MSSQL-Python`. This ensures proper application identification in SQL Server logs and monitoring tools, making it easy to identify connections from this Python driver. + +These are **intentional design choices** that will be preserved in the new allow-list implementation. + +**2. The base `connection_str` parameter bypasses all other filtering** (the problem): + +Only kwargs go through the allow-list check for non-Driver/APP parameters. This means: + +1. Users can pass unsupported ODBC parameters that the Python driver can't properly handle +2. Parameters that require additional infrastructure (like Always Encrypted extensibility) get passed to ODBC without validation +3. Parameters that may not be supported in future driver versions can create forward compatibility issues +4. Multiple synonyms for the same parameter create API inconsistency + +**3. 
The parsing is inadequate for ODBC connection strings**: + +The current parsing in `add_driver_to_connection_str()` (helpers.py:28-30) uses simple `split(";")`: + +```python +# Current implementation (helpers.py) +connection_attributes = connection_str.split(";") +for attribute in connection_attributes: + if attribute.lower().split("=")[0] == "driver": + continue +``` + +**This breaks with valid ODBC connection strings like**: +- `Server={local;host};Database=mydb` → incorrectly splits into 3 parts instead of 2 +- `PWD={p}}w{{d}` → doesn't strip the enclosing braces or unescape `}}` to `}` +- `Server=localhost;` → creates empty string element +- `Server=localhost` (no semicolon) → handled, but inconsistent with trailing semicolon case + +--- + +## Design Goals + +### Functional Requirements + +| ID | Requirement | Priority | +|----|-------------|----------| +| FR-1 | Parse complete ODBC connection strings correctly with **lenient parsing** | **P0** | +| FR-2 | Filter all parameters against an allow-list | **P0** | +| FR-3 | Support ODBC connection string syntax (`;`, `{}`, `=`) | **P0** | +| FR-4 | Merge kwargs with connection string parameters | **P0** | +| FR-5 | Preserve parameter values exactly (including special chars) | **P0** | +| FR-6 | Maintain backward compatibility | **P1** | +| FR-7 | Provide clear warning logs for malformed/rejected params | **P1** | +| FR-8 | **Never raise exceptions** for malformed connection string syntax (ODBC behavior) | **P0** | +| FR-9 | Log diagnostics for ignored entries to aid debugging | **P1** | + +### Non-Functional Requirements + +| ID | Requirement | Target | Priority | +|----|-------------|--------|----------| +| NFR-1 | Parsing overhead | < 1 millisecond | **P0** | +| NFR-2 | Memory efficiency | < 5 KB per connection | **P1** | +| NFR-3 | Code maintainability | Clear, documented, testable | **P1** | +| NFR-4 | Thread safety | Thread-safe parsing | **P1** | + +--- + +## Architecture Overview + +### High-Level Components + +``` 
+┌─────────────────────────────────────────────────────────────────┐ +│ Connection.__init__() │ +│ │ +│ Input: connection_str (str), **kwargs (dict) │ +└────────────────────┬────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ ConnectionStringParser │ +│ │ +│ ┌───────────────────────────────────────────────────────────┐ │ +│ │ 1. parse(connection_str) → Dict[str, str] │ │ +│ │ - Tokenize connection string │ │ +│ │ - Handle escaping/quoting │ │ +│ │ - Return key-value pairs │ │ +│ └───────────────────────────────────────────────────────────┘ │ +└────────────────────┬────────────────────────────────────────────┘ + │ Parsed params dict + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ ParameterAllowList │ +│ │ +│ ┌───────────────────────────────────────────────────────────┐ │ +│ │ 2. filter(params_dict) → Dict[str, str] │ │ +│ │ - Check each param against allow-list │ │ +│ │ - Normalize parameter names │ │ +│ │ - Log warnings for rejected params │ │ +│ └───────────────────────────────────────────────────────────┘ │ +└────────────────────┬────────────────────────────────────────────┘ + │ Filtered params dict + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ ConnectionStringBuilder │ +│ │ +│ ┌───────────────────────────────────────────────────────────┐ │ +│ │ 3. merge_kwargs(filtered_params, kwargs) │ │ +│ │ - Merge kwargs into filtered params │ │ +│ │ - kwargs override connection_str values │ │ +│ └───────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌───────────────────────────────────────────────────────────┐ │ +│ │ 4. add_driver(merged_params) │ │ +│ │ - Add Driver={ODBC Driver 18 for SQL Server} │ │ +│ │ - Add APP=MSSQL-Python │ │ +│ └───────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌───────────────────────────────────────────────────────────┐ │ +│ │ 5. 
build() → str │ │ +│ │ - Reconstruct connection string │ │ +│ │ - Proper escaping for values with special chars │ │ +│ └───────────────────────────────────────────────────────────┘ │ +└────────────────────┬────────────────────────────────────────────┘ + │ Final connection string + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ ddbc_bindings.Connection() │ +│ │ +│ Passes to ODBC Driver │ +└─────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Connection String Parser Design + +### ODBC Connection String Syntax + +Based on the official [ODBC Connection String specification (MS-ODBCSTR)](https://learn.microsoft.com/en-us/openspecs/sql_server_protocols/ms-odbcstr/55953f0e-2d30-4ad4-8e56-b4207e491409), specifically section 2.1.2 "ODBC Connection String Format": + +**Official ABNF Grammar:** + +```abnf +ODBCConnectionString = *(KeyValuePair SC) KeyValuePair [SC] +KeyValuePair = (Key EQ Value / SpaceStr) +Key = SpaceStr KeyName +KeyName = (nonSP-SC-EQ *nonEQ) +Value = (SpaceStr ValueFormat1 SpaceStr) / (ValueContent2) +ValueFormat1 = LCB ValueContent1 RCB +ValueContent1 = *(nonRCB / ESCAPEDRCB) +ValueContent2 = SpaceStr / SpaceStr (nonSP-LCB-SC) *nonSC + +; Character definitions +nonRCB = %x01-7C / %x7E-FFFF ; not "}" +nonSP-LCB-SC = %x01-1F / %x21-3A / %x3C-7A / %x7C-FFFF ; not space, "{" or ";" +nonSP-SC-EQ = %x01-1F / %x21-3A / %x3C / %x3E-FFFF ; not space, ";" or "=" +nonEQ = %x01-3C / %x3E-FFFF ; not "=" +nonSC = %x01-3A / %x3C-FFFF ; not ";" + +; Where: +SC = ";" ; semicolon +EQ = "=" ; equals +LCB = "{" ; left curly brace +RCB = "}" ; right curly brace +ESCAPEDRCB = "}}" ; escaped right curly brace +SpaceStr = *SP ; zero or more spaces +``` + +**Simplified Explanation:** + +- **Connection string** = Zero or more key-value pairs separated by semicolons, with optional trailing semicolon +- **Key-value pair** = `Key=Value` or just whitespace +- **Key** = Optional leading spaces + key name (non-space, 
non-semicolon, non-equals characters) +- **Value** can be in two formats: + - **Format 1 (Braced)**: `{content}` where `}` inside is escaped as `}}` + - **Format 2 (Unbraced)**: A value that does not begin with `{` and contains no semicolons +- **Escaping**: Only `}` needs escaping inside braced values, done by doubling: `}}` → `}` + +**Key Points:** +- Parameters separated by `;` +- Format: `KEY=VALUE` +- Values that contain `;` or begin with `{` must be wrapped in braces: `{value}` +- Right braces `}` in braced values are escaped by doubling: `}}` → `}` +- Left braces `{` in braced values are literal and are not escaped: `{{` stays `{{` (the grammar defines no left-brace escape) +- Trailing semicolons are allowed +- Whitespace around keys and values follows specific rules (see SpaceStr in grammar) + +### Parser Implementation + +**Lenient Parsing Philosophy** + +The parser implementation follows **ODBC driver behavior** - it uses **lenient parsing** that: +- **Never raises exceptions** for malformed connection string syntax +- **Ignores** malformed entries (entries without `=`, empty keys, unclosed braces) instead of failing +- **Logs warnings** for ignored entries to provide diagnostics +- **Returns partial results** - successfully parsed key-value pairs even when some entries are malformed + +This matches the behavior observed in Microsoft's ODBC driver (`sqlcconn.cpp` lines 4273-4285), which returns `S_FALSE` and continues parsing when encountering malformed entries. + +```python +from typing import Dict, Tuple + +class ConnectionStringParser: + """ + Parses ODBC connection strings into key-value dictionaries. + Handles ODBC-specific syntax including braces, escaping, and semicolons. 
+ + Uses LENIENT PARSING (matching ODBC driver behavior): + - Ignores malformed entries instead of raising exceptions + - Logs warnings for ignored/malformed entries + - Returns partial results (valid entries only) + """ + + def __init__(self): + """Initialize the parser.""" + self._logger = None # Lazy initialization + + def _get_logger(self): + """Get or create the logger instance (lazy initialization).""" + if self._logger is None: + from mssql_python.logging_config import get_logger + self._logger = get_logger() + return self._logger + + def _log_warning(self, message: str): + """Log a warning message using the configured logger.""" + logger = self._get_logger() + logger.warning(message) + + def parse(self, connection_str: str) -> Dict[str, str]: + """ + Parse a connection string into a dictionary of parameters. + + LENIENT PARSING: Ignores malformed entries (logs warnings instead of raising exceptions). + This matches ODBC driver behavior. + + Args: + connection_str: ODBC-format connection string + + Returns: + Dictionary mapping parameter names (lowercase) to values. + Returns empty dict if all entries are malformed. 
+ + Examples: + >>> parser.parse("Server=localhost;Database=mydb") + {'server': 'localhost', 'database': 'mydb'} + + >>> parser.parse("Server={;local;};PWD={p}}w{{d}") + {'server': ';local;', 'pwd': 'p}w{{d'} # only '}}' is an escape; '{' is literal + + >>> parser.parse("Server=localhost;InvalidEntry;Database=mydb") + # Logs: WARNING: Ignoring malformed connection string entry (no '=' found): 'InvalidEntry' + {'server': 'localhost', 'database': 'mydb'} # Partial result + + >>> parser.parse("Server=localhost;PWD={unclosed") + # Logs: WARNING: Ignoring malformed braced value: Unclosed braced value in connection string + {'server': 'localhost'} # Partial result + """ + # Example: "" or None → return empty dict + if not connection_str: + return {} + + # Example: " \t " → strip to "" → return empty dict + connection_str = connection_str.strip() + if not connection_str: + return {} + + # Dictionary to store parsed key=value pairs + # Example: will become {'server': 'localhost', 'database': 'mydb'} + params = {} + + # Track current position in the string as we parse character by character + # Example: for "Server=localhost", starts at 0 (the 'S') + current_pos = 0 + str_len = len(connection_str) + + # Main parsing loop - process one key=value pair per iteration + # Example: "Server=localhost;Database=mydb" → processes 2 pairs + while current_pos < str_len: + # Skip leading whitespace and semicolons + # Example: " ; Server=localhost" → skips to position of 'S' + # Example: "Server=localhost;;Database=mydb" → skips double semicolons + while current_pos < str_len and connection_str[current_pos] in ' \t;': + current_pos += 1 + + # If we've reached the end after skipping whitespace/semicolons, we're done + # Example: "Server=localhost; " → exits cleanly after trailing whitespace + if current_pos >= str_len: + break + + # Parse the key (everything before '=' or ';') + # Example: "Server=localhost" → key_start=0 + key_start = current_pos + + # Advance until we hit '=', ';', or end of string + # Example: "Server=localhost" → stops 
at '=' (position 6) + # Example: "InvalidEntry;Database=mydb" → stops at ';' (position 12) + while current_pos < str_len and connection_str[current_pos] not in '=;': + current_pos += 1 + + # Check if we found a valid '=' separator + # Example: "InvalidEntry;..." → current_pos points to ';', not '=' + if current_pos >= str_len or connection_str[current_pos] != '=': + # LENIENT: No '=' found, this is a malformed entry + # Example: "Server=localhost;BadEntry;Database=mydb" + # → "BadEntry" has no '=', so extract it and log warning + malformed_entry = connection_str[key_start:current_pos].strip() + if malformed_entry: # Only log if non-empty (avoid logging for just whitespace) + # Example: logs "Ignoring malformed connection string entry (no '=' found): 'BadEntry'" + self._log_warning( + f"Ignoring malformed connection string entry (no '=' found): '{malformed_entry}'" + ) + # Skip to next semicolon to continue parsing + # Example: "BadEntry;Database=mydb" → skip to ';' before "Database" + while current_pos < str_len and connection_str[current_pos] != ';': + current_pos += 1 + continue + + # Extract and normalize the key + # Example: "Server=localhost" → key = "server" (lowercase) + # Example: " SERVER =localhost" → key = "server" (stripped and lowercased) + key = connection_str[key_start:current_pos].strip().lower() + + # LENIENT: Skip entries with empty keys + # Example: "=somevalue;Server=localhost" → empty key before '=' + if not key: + # Example: logs "Ignoring connection string entry with empty key" + self._log_warning("Ignoring connection string entry with empty key") + current_pos += 1 # Skip the '=' + # Skip to next semicolon + # Example: "=badvalue;Server=localhost" → skip to ';' before "Server" + while current_pos < str_len and connection_str[current_pos] != ';': + current_pos += 1 + continue + + # Move past the '=' character + # Example: "Server=localhost" → current_pos now points to 'l' in "localhost" + current_pos += 1 + + # Parse the value (with lenient 
error handling for unclosed braces) + # Example: "Server=localhost" → value="localhost", current_pos=16 + # Example: "PWD={p;w}" → value="p;w", current_pos=9 + try: + value, current_pos = self._parse_value(connection_str, current_pos) + # Store the key=value pair (later occurrences override earlier ones) + # Example: "Server=old;Server=new" → params['server'] = 'new' + params[key] = value + except ValueError as e: + # LENIENT: Unclosed brace or other parsing error + # Example: "Server=localhost;PWD={unclosed;Database=mydb" + # → logs warning for PWD, continues to parse Database + self._log_warning(f"Ignoring malformed braced value: {e}") + # Skip to next semicolon to continue parsing other entries + # Example: skip from '{unclosed' to ';' before "Database" + while current_pos < str_len and connection_str[current_pos] != ';': + current_pos += 1 + + # Return all successfully parsed key=value pairs + # Example: "Server=localhost;BadEntry;Database=mydb" → {'server': 'localhost', 'database': 'mydb'} + return params + + def _parse_value(self, connection_str: str, start_pos: int) -> Tuple[str, int]: + """ + Parse a parameter value from the connection string. + + Handles both simple values and braced values with escaping. 
+ + Args: + connection_str: The connection string + start_pos: Starting position of the value + + Returns: + Tuple of (parsed_value, new_position) + """ + str_len = len(connection_str) + + # Skip leading whitespace before the value + # Example: "Server= localhost" → skip spaces, start_pos points to 'l' + while start_pos < str_len and connection_str[start_pos] in ' \t': + start_pos += 1 + + # If we've consumed the entire string or reached a semicolon, return empty value + # Example: "Server=" → empty value + # Example: "Server=;" → empty value + if start_pos >= str_len: + return '', start_pos + + # Determine if this is a braced value or simple value + # Braced value: starts with '{', requires special escape handling + # Simple value: everything else, read until semicolon + if connection_str[start_pos] == '{': + # Example: "PWD={p;w}" → delegate to _parse_braced_value + # Example: "Server={local;server}" → delegate to _parse_braced_value + return self._parse_braced_value(connection_str, start_pos) + else: + # Example: "Server=localhost" → delegate to _parse_simple_value + # Example: "Database=mydb" → delegate to _parse_simple_value + return self._parse_simple_value(connection_str, start_pos) + + def _parse_simple_value(self, connection_str: str, start_pos: int) -> Tuple[str, int]: + """ + Parse a simple (non-braced) value up to the next semicolon. + + Simple values cannot contain semicolons or opening braces. + + Args: + connection_str: The connection string + start_pos: Starting position of the value + + Returns: + Tuple of (parsed_value, new_position) + + Examples: + "Server=localhost;..." 
→ returns ("localhost", position_after_t) + "Database=mydb" → returns ("mydb", end_of_string) + """ + str_len = len(connection_str) + # Mark the start of the value + # Example: "Server=localhost;Database=mydb" + # ^value_start (position of 'l') + value_start = start_pos + + # Read characters until we hit a semicolon or end of string + # Example: "localhost;Database=mydb" → reads 'localhost', stops at ';' + # Example: "mydb" → reads 'mydb', stops at end of string + while start_pos < str_len and connection_str[start_pos] != ';': + start_pos += 1 + + # Extract the value and strip trailing whitespace + # Example: "localhost ;..." → value="localhost" (trailing spaces removed) + # Example: "mydb" → value="mydb" + value = connection_str[value_start:start_pos].rstrip() + + # Return the extracted value and the position after it + # Example: returns ("localhost", position_of_semicolon) + return value, start_pos + + def _parse_braced_value(self, connection_str: str, start_pos: int) -> Tuple[str, int]: + """ + Parse a braced value with proper handling of escaped braces. 
+ + Braced values: + - Start with '{' and end with '}' + - '}' inside the value is escaped as '}}' + - '{' inside the value is literal and needs no escaping ('{{' stays '{{') + - Can contain semicolons and other special characters + + Args: + connection_str: The connection string + start_pos: Starting position (should point to opening '{') + + Returns: + Tuple of (parsed_value, new_position) + + Raises: + ValueError: If the braced value is not closed (missing '}') + + Examples: + "{p}}w{{d}" → returns ("p}w{{d", position_after_closing_brace) + "{;local;}" → returns (";local;", position_after_closing_brace) + "{unclosed" → raises ValueError (caught by caller in lenient mode) + """ + str_len = len(connection_str) + + # Skip the opening '{' character + # Example: "{password}" → start_pos moves from '{' to 'p' + start_pos += 1 + + # Build the value character by character, handling escape sequences + # Example: will accumulate ['p', '}', 'w', '{', '{', 'd'] for "{p}}w{{d}" + value = [] + + # Process each character until we find the closing '}' or reach end of string + while start_pos < str_len: + # Get current character + # Example: 'p' in "password}", or '}' in "p}}w{{d}" + ch = connection_str[start_pos] + + if ch == '}': + # Found a '}' - could be escaped '}}' or the closing brace + # Check if next character is also '}' (escaped brace) + if start_pos + 1 < str_len and connection_str[start_pos + 1] == '}': + # Escaped right brace: '}}' → '}' + # Example: "{p}}word}" → '}}' becomes single '}' in output + value.append('}') + start_pos += 2 # Skip both '}' characters + else: + # Single '}' means end of braced value + # Example: "{password}" → found closing '}' + start_pos += 1 # Skip the closing '}' + # Join all accumulated characters and return + # Example: ['p', 'w', 'd'] → "pwd" + return ''.join(value), start_pos + else: + # Regular character (including '{', ';', '=', etc.) 
+ # Example: 'p', 'a', 's', 's' in "{password}" + # Example: ';' in "{local;server}" + # Note: '{{' is also handled here - first '{' is added to value, + # second '{' will be added on next iteration + value.append(ch) + start_pos += 1 + + # We've reached end of string without finding closing '}' + # Example: "PWD={unclosed;Server=localhost" + # → while loop exits because start_pos >= str_len + + # Raise ValueError - unclosed braced value + # NOTE: In lenient parsing mode, this exception is caught by parse() + # which logs a warning and continues parsing remaining entries + # Example: parse() will log "Ignoring malformed braced value: Unclosed braced value in connection string" + raise ValueError("Unclosed braced value in connection string") +``` + +--- + +## Allow-List Strategy + +### Allowed Parameters + +The allow-list is designed to include only parameters that: + +1. **Have Python API Support**: Parameters that the driver can properly handle and configure +2. **Are Runtime-Agnostic**: Parameters that will work with current and future driver implementations (or can be mapped appropriately) +3. **Are Essential for Connectivity**: Core parameters needed for database connections +4. **Have Clear Semantics**: Parameters with well-defined behavior and no ambiguity + +**Philosophy**: We take a deliberate, allow-list approach to exposing parameters because: +- It's easier to add parameters over time than to remove them once users depend on them +- This enables us to achieve ODBC parity incrementally while maintaining backward compatibility +- We can carefully evaluate each parameter's necessity and ensure proper Python API support before exposing it +- Users won't build dependencies on features that may not be available in future driver versions + +**Special Parameters** (handled outside the allow-list): +- **Driver**: Always hardcoded to `{ODBC Driver 18 for SQL Server}`. User-provided values are stripped and replaced to ensure driver consistency. 
+- **APP**: Always set to `MSSQL-Python`. User-provided values are overwritten to ensure proper application identification in SQL Server logs and monitoring. + +These special parameters maintain the existing behavior and ensure consistent driver operation. + +**Excluded Parameters** include: +- Always Encrypted extensibility parameters (no Python key store provider API yet) +- Advanced ODBC-specific features without Python bindings or TDS runtime equivalents +- Deprecated or legacy parameters +- Parameters with unclear behavior or side effects +- ODBC driver configuration parameters that don't translate to TDS runtime + +```python +# File: mssql_python/connection_string_allowlist.py + +class ConnectionStringAllowList: + """ + Manages the allow-list of permitted connection string parameters. + """ + + # Core connection parameters + ALLOWED_PARAMS = { + # Server identification + 'server': 'Server', + 'address': 'Server', + 'addr': 'Server', + 'network address': 'Server', + + # Authentication + 'uid': 'Uid', + 'user id': 'Uid', + 'user': 'Uid', + 'pwd': 'Pwd', + 'password': 'Pwd', + 'authentication': 'Authentication', + 'trusted_connection': 'Trusted_Connection', + + # Database + 'database': 'Database', + 'initial catalog': 'Database', + + # Driver (read-only - always set by mssql-python) + 'driver': 'Driver', + + # Encryption + 'encrypt': 'Encrypt', + 'trustservercertificate': 'TrustServerCertificate', + 'hostnameincertificate': 'HostNameInCertificate', + + # Connection behavior + 'connection timeout': 'Connection Timeout', + 'connect timeout': 'Connection Timeout', + 'timeout': 'Connection Timeout', + 'login timeout': 'Login Timeout', + 'multisubnetfailover': 'MultiSubnetFailover', + 'applicationintent': 'ApplicationIntent', + 'application intent': 'ApplicationIntent', + 'transparentnetworkipresolution': 'TransparentNetworkIPResolution', + + # Application identification + 'app': 'APP', + 'application name': 'APP', + 'workstation id': 'Workstation ID', + 'wsid': 
'Workstation ID', + + # MARS (Multiple Active Result Sets) + 'mars_connection': 'MARS_Connection', + 'multipleactiveresultsets': 'MARS_Connection', + + # Language/Regional + 'language': 'Language', + + # Connection Pooling (driver level) + 'pooling': 'Pooling', + + # Column Encryption + 'columnencryption': 'ColumnEncryption', + + # Attach database file + 'attachdbfilename': 'AttachDbFilename', + + # Failover + 'failover partner': 'Failover_Partner', + + # Application name / intent + 'application role': 'ApplicationRole', + + # Packet size + 'packet size': 'Packet Size', + } + + # Parameters that should be handled separately (not in allow-list) + BLOCKED_PARAMS = { + 'pwd', # Captured separately for logging sanitization + 'password', + } + + @classmethod + def normalize_key(cls, key: str) -> Optional[str]: + """ + Normalize a parameter key to its canonical form. + + Args: + key: Parameter key from connection string (case-insensitive) + + Returns: + Canonical parameter name if allowed, None otherwise + + Examples: + >>> ConnectionStringAllowList.normalize_key('SERVER') + 'Server' + >>> ConnectionStringAllowList.normalize_key('UnsupportedParam') + None + """ + key_lower = key.lower().strip() + return cls.ALLOWED_PARAMS.get(key_lower) + + @classmethod + def filter_params(cls, params: Dict[str, str], warn_rejected: bool = True) -> Dict[str, str]: + """ + Filter parameters against the allow-list. 
+ + Args: + params: Dictionary of connection string parameters + warn_rejected: Whether to log warnings for rejected parameters + + Returns: + Dictionary containing only allowed parameters with normalized keys + """ + from mssql_python.logging_config import get_logger + from mssql_python.helpers import sanitize_user_input + + logger = get_logger() + filtered = {} + rejected = [] + + for key, value in params.items(): + normalized_key = cls.normalize_key(key) + + if normalized_key: + # Parameter is allowed + filtered[normalized_key] = value + else: + # Parameter is not in allow-list + rejected.append(key) + if warn_rejected and logger: + safe_key = sanitize_user_input(key) + logger.warning( + f"Connection string parameter '{safe_key}' is not in the allow-list and will be ignored" + ) + + return filtered +``` + +### Allow-List Rationale + +| Parameter Category | Purpose | Include in Allow-List? | Rationale | +|--------------------|---------|------------------------|-----------| +| Server/Authentication | Core connection functionality | **Yes** | Essential, runtime-agnostic | +| Encryption (TLS/SSL) | TLS/SSL configuration | **Yes** | Essential for security, supported in all runtimes | +| Connection Behavior | Timeouts, failover, MARS | **Yes** | Core functionality, can be mapped across implementations | +| Application Identification | Logging, monitoring | **Yes** | Informational, no side effects | +| Always Encrypted Extensions | Custom key store providers | **No** | Requires Python key store provider API (not yet available) | +| ODBC Driver Internals | Driver-specific configuration | **No** | ODBC-specific, may not work in future implementations | +| Deprecated Parameters | Legacy ODBC parameters | **No** | Should not expose in modern Python API | +| Synonym Parameters | Alternative names for same parameter | **Normalize** | Accept but normalize to canonical name | + +**Normalization Strategy**: +- Accept common synonyms (e.g., "user", "uid", "user id") for ease of 
use +- Always normalize to a single canonical name (e.g., "Uid") +- This provides flexibility while maintaining consistency +- Prepares for potential Python-style naming in future (e.g., "user_id") + +--- + +## Performance Considerations + +### Optimization Strategies + +1. **Lazy Initialization** + - Parse connection string only once during connection initialization + - Cache the parsed dictionary + +2. **Early Termination** + - Simple parameter counting before full parse + - Skip parsing if connection string is empty + +3. **Minimal Allocations** + - Reuse string builders + - Single-pass parsing + +4. **Compiled Regex (if needed)** + - Pre-compile any regex patterns + - Use simple string operations where possible + +### Performance Targets + +```python +# Benchmark targets (on modern hardware) +Performance Metric Target Worst Case +──────────────────────────────────────────────────────────────── +Parse simple connection string < 0.1ms < 0.5ms +Parse complex connection string < 0.5ms < 2ms +Filter against allow-list < 0.1ms < 0.5ms +Rebuild connection string < 0.1ms < 0.5ms +──────────────────────────────────────────────────────────────── +Total overhead per connection < 1ms < 5ms +``` + +### Memory Usage + +```python +# Estimated memory per connection +Component Size +────────────────────────────────────────────── +Parsed params dict ~1-2 KB +Filtered params dict ~1-2 KB +Rebuilt connection string ~0.5-1 KB +────────────────────────────────────────────── +Total ~3-5 KB +``` + +--- + +## Data Flow Diagrams + +### Diagram 1: Current Flow (Before This Design) + +``` +┌─────────────────────────────────────────────────────┐ +│ User provides: │ +│ connection_str = "Server=myserver;Secret=value" │ +│ kwargs = {"encrypt": "yes"} │ +└──────────────────┬──────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────┐ +│ add_driver_to_connection_str() │ +│ ├─ Adds Driver={ODBC Driver 18 for SQL Server} │ +│ └─ NO PARAMETER FILTERING │ 
+└──────────────────┬──────────────────────────────────┘ + │ + │ Output: "Driver={ODBC Driver 18 for SQL Server}; + │ Server=myserver;Secret=value" + ▼ +┌─────────────────────────────────────────────────────┐ +│ _construct_connection_string() │ +│ ├─ Takes connection_str from above (UNFILTERED!) │ +│ └─ Appends FILTERED kwargs: │ +│ ├─ "encrypt" → "Encrypt=yes" │ +│ └─ Other kwargs rejected ✗ │ +└──────────────────┬──────────────────────────────────┘ + │ + │ Final: "Driver={ODBC Driver 18...}; + │ Server=myserver;Secret=value; + │ Encrypt=yes" + │ + │ ⚠️ "Secret=value" was NEVER FILTERED! ⚠️ + │ ⚠️ Could be unsupported ODBC parameter! ⚠️ + ▼ +┌─────────────────────────────────────────────────────┐ +│ ddbc_bindings.Connection(final_connection_str) │ +│ └─ Passes to ODBC driver (including "Secret") │ +└─────────────────────────────────────────────────────┘ +``` + +### Diagram 2: Proposed Flow (With Allow-List) + +``` +┌──────────────────────────────────────────────────────────────┐ +│ User provides: │ +│ connection_str = "Server=myserver;Secret=value;Encrypt=no" │ +│ kwargs = {"encrypt": "yes", "database": "mydb"} │ +└──────────────────┬───────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────┐ +│ ConnectionStringParser.parse(connection_str) │ +│ │ +│ Parse result: │ +│ { │ +│ 'server': 'myserver', ← Allowed │ +│ 'secret': 'value', ← NOT in allow-list ⚠️ | +│ 'encrypt': 'no' ← Allowed (will be │ +│ } overridden by kwargs) │ +└──────────────────┬───────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────┐ +│ ConnectionStringAllowList.filter_params(parsed_params) │ +│ │ +│ ├─ Check 'server' → Allowed → Normalize to 'Server' │ +│ ├─ Check 'secret' → ✗ REJECTED → Log warning, drop param │ +│ │ (Not in allow-list - may be unsupported ODBC parameter) │ +│ └─ Check 'encrypt' → Allowed → Normalize to 'Encrypt' │ +│ │ +│ Filtered result: │ +│ { │ +│ 'Server': 
'myserver', │ +│ 'Encrypt': 'no' │ +│ } │ +└──────────────────┬───────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────┐ +│ ConnectionStringBuilder.merge_kwargs(filtered, kwargs) │ +│ │ +│ ├─ Process kwargs through allow-list: │ +│ │ ├─ 'encrypt' → Normalize to 'Encrypt' │ +│ │ └─ 'database' → Normalize to 'Database' │ +│ │ │ +│ ├─ Merge (kwargs override connection_str): │ +│ │ ├─ 'Server': 'myserver' (from connection_str) │ +│ │ ├─ 'Encrypt': 'yes' ← OVERRIDES 'no' from conn_str │ +│ │ └─ 'Database': 'mydb' (from kwargs) │ +│ │ │ +│ └─ Add driver-specific params: │ +│ ├─ 'Driver': '{ODBC Driver 18 for SQL Server}' │ +│ └─ 'APP': 'MSSQL-Python' │ +│ │ +│ Merged result: │ +│ { │ +│ 'Driver': '{ODBC Driver 18 for SQL Server}', │ +│ 'Server': 'myserver', │ +│ 'Database': 'mydb', │ +│ 'Encrypt': 'yes', │ +│ 'APP': 'MSSQL-Python' │ +│ } │ +│ │ +│ Note: All parameters are validated and normalized │ +│ Only supported features are passed to ODBC │ +└──────────────────┬───────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────┐ +│ ConnectionStringBuilder.build() │ +│ │ +│ ├─ Reconstruct connection string: │ +│ │ ├─ Check each value for special chars (;, {, }) │ +│ │ ├─ Add braces if needed: {value} │ +│ │ ├─ Escape braces in value: } → }} │ +│ │ └─ Join with semicolons │ +│ │ │ +│ └─ Final string: │ +│ "Driver={ODBC Driver 18 for SQL Server};" │ +│ "Server=myserver;" │ +│ "Database=mydb;" │ +│ "Encrypt=yes;" │ +│ "APP=MSSQL-Python" │ +│ │ +│ Benefits of this approach: │ +│ - "Secret=value" was filtered out (unsupported param) │ +│ - Only parameters with Python API support are passed │ +│ - Forward compatible with future driver enhancements │ +│ - Synonyms normalized to canonical names │ +└──────────────────┬───────────────────────────────────────────┘ + │ + ▼ +┌──────────────────────────────────────────────────────────────┐ +│ 
ddbc_bindings.Connection(final_connection_str) │ +│ │ +│ - Passes ONLY ALLOWED parameters to ODBC driver │ +│ - All parameters have proper Python API support │ +│ - Forward compatible with future driver enhancements │ +└──────────────────────────────────────────────────────────────┘ +``` + +### Diagram 3: Parser State Machine + +``` + ┌─────────┐ + │ START │ + └────┬────┘ + │ + ▼ + ┌──────────────────────┐ + │ Skip whitespace/';' │◄─────────────┐ + └──────────┬───────────┘ │ + │ │ + ▼ │ + ┌──────────┐ │ + │ Parse │ │ + │ KEY │ │ + └────┬─────┘ │ + │ │ + │ Found '=' │ + ▼ │ + ┌──────────┐ │ + ┌────────┤ Check │ │ + │ │ next │ │ + │ │ char │ │ + │ └────┬─────┘ │ + │ │ │ + '{' ? │ │ Other │ + │ ▼ │ + │ ┌──────────────┐ │ + │ │ Parse SIMPLE │ │ + │ │ VALUE │ │ + │ │ (until ';') │ │ + │ └──────┬───────┘ │ + │ │ │ + │ └─────────┬──────────────────┘ + │ │ + │ │ Store key=value + │ │ More params? + │ │ + ▼ ▼ + ┌──────────────┐ ┌────────┐ + │ Parse BRACED │ │ END │ + │ VALUE │ └────────┘ + │ (handle '}}')| + └──────┬───────┘ + │ + └────────────────────┘ +``` + +--- + +## Implementation Details + +### File Structure + +``` +mssql_python/ +├── connection.py # Modified +├── helpers.py # Modified +├── connection_string_parser.py # NEW +├── connection_string_allowlist.py # NEW +└── connection_string_builder.py # NEW +``` + +### Modified: connection.py + +```python +def _construct_connection_string(self, connection_str: str = "", **kwargs) -> str: + """ + Construct the connection string by parsing, filtering, and merging parameters. + + This method: + 1. Parses the base connection_str into parameters + 2. Filters parameters against an allow-list + 3. Merges kwargs (which also go through allow-list) + 4. Adds driver and APP parameters + 5. 
Rebuilds the connection string + + Args: + connection_str: Base connection string from user + **kwargs: Additional key/value pairs for the connection string + + Returns: + str: The constructed and filtered connection string + """ + from mssql_python.connection_string_parser import ConnectionStringParser + from mssql_python.connection_string_allowlist import ConnectionStringAllowList + from mssql_python.connection_string_builder import ConnectionStringBuilder + from mssql_python.helpers import log, sanitize_connection_string + + # Step 1: Parse base connection string + parser = ConnectionStringParser() + parsed_params = parser.parse(connection_str) + + # Step 2: Filter against allow-list + filtered_params = ConnectionStringAllowList.filter_params( + parsed_params, + warn_rejected=True + ) + + # Step 3: Build connection string + builder = ConnectionStringBuilder(filtered_params) + + # Step 4: Add kwargs (they go through allow-list too) + for key, value in kwargs.items(): + normalized_key = ConnectionStringAllowList.normalize_key(key) + if normalized_key: + builder.add_param(normalized_key, value) + else: + log('warning', f"Ignoring unknown connection parameter from kwargs: {key}") + + # Step 5: Add Driver and APP parameters (always controlled by the driver) + # These maintain existing behavior: Driver is always hardcoded, APP is always MSSQL-Python + builder.add_param('Driver', '{ODBC Driver 18 for SQL Server}') + builder.add_param('APP', 'MSSQL-Python') # Always set, overrides any user value + + # Step 6: Build final string + conn_str = builder.build() + + log('info', "Final connection string: %s", sanitize_connection_string(conn_str)) + + return conn_str +``` + +**Key Design Note**: The new implementation **preserves** the existing behavior for `Driver` and `APP`: +- `Driver` is **always** set to `{ODBC Driver 18 for SQL Server}`, regardless of user input +- `APP` is **always** set to `MSSQL-Python`, regardless of user input +- Both parameters are set **after** 
allow-list filtering, ensuring they cannot be overridden +- This maintains backward compatibility and ensures consistent application identification + +### NEW: connection_string_parser.py + +```python +""" +ODBC connection string parser for mssql-python. + +Handles ODBC-specific syntax: +- Semicolon-separated key=value pairs +- Braced values: {value} +- Escaped braces: }} → } +""" + +from typing import Dict, Tuple, Optional + +class ConnectionStringParser: + # Implementation as shown in earlier section + pass +``` + +### NEW: connection_string_builder.py + +```python +""" +Connection string builder for mssql-python. + +Reconstructs ODBC connection strings from parameter dictionaries +with proper escaping and formatting. +""" + +from typing import Dict, Optional + +class ConnectionStringBuilder: + """ + Builds ODBC connection strings from parameter dictionaries. + """ + + def __init__(self, initial_params: Optional[Dict[str, str]] = None): + """ + Initialize the builder with optional initial parameters. + + Args: + initial_params: Dictionary of initial connection parameters + """ + self._params: Dict[str, str] = initial_params.copy() if initial_params else {} + + def add_param(self, key: str, value: str) -> 'ConnectionStringBuilder': + """ + Add or update a connection parameter. + + Args: + key: Parameter name (case-sensitive, should be normalized) + value: Parameter value + + Returns: + Self for method chaining + """ + self._params[key] = value + return self + + def has_param(self, key: str) -> bool: + """Check if a parameter exists.""" + return key in self._params + + def build(self) -> str: + """ + Build the final connection string. 
+ + Returns: + ODBC-formatted connection string + """ + parts = [] + + # Build in specific order: Driver first, then others + # The Driver value is set internally and is already braced, so emit it verbatim; + # passing it through _escape_value() would wrap it again and double its braces + if 'Driver' in self._params: + parts.append(f"Driver={self._params['Driver']}") + + # Add other parameters (sorted for consistency) + for key in sorted(self._params.keys()): + if key == 'Driver': + continue # Already added + + value = self._params[key] + escaped_value = self._escape_value(value) + parts.append(f"{key}={escaped_value}") + + # Join with semicolons + return ';'.join(parts) + + def _escape_value(self, value: str) -> str: + """ + Escape a parameter value if it contains special characters. + + Special characters that require bracing: ; { } + Braces inside braced values are doubled: } → }} and { → {{ + + Args: + value: Parameter value to escape + + Returns: + Escaped value (possibly wrapped in braces) + """ + if not value: + return value + + # Check if value contains special characters + needs_braces = any(ch in value for ch in ';{}') + + if needs_braces: + # Escape existing braces by doubling them + escaped = value.replace('}', '}}').replace('{', '{{') + return f'{{{escaped}}}' + else: + return value +``` + +--- + +## Testing Strategy + +### Unit Tests + +```python +# tests/test_connection_string_parser.py + +import logging + +from mssql_python.connection_string_parser import ConnectionStringParser +from mssql_python.connection_string_allowlist import ConnectionStringAllowList +from mssql_python.connection_string_builder import ConnectionStringBuilder + +class TestConnectionStringParser: + """Unit tests for ConnectionStringParser.""" + + def test_parse_empty_string(self): + """Test parsing empty connection string.""" + parser = ConnectionStringParser() + result = parser.parse("") + assert result == {} + + def test_parse_simple_params(self): + """Test parsing simple key=value pairs.""" + parser = ConnectionStringParser() + result = parser.parse("Server=localhost;Database=mydb") + assert result == { + 'server': 'localhost', + 'database': 'mydb' + } + + def test_parse_braced_values(self): + """Test parsing braced values.""" + parser = ConnectionStringParser() + result = parser.parse("Server={;local;};PWD={p}}w{{d}") + assert result == { + 'server': ';local;', + 'pwd': 'p}w{d' + } + 
+ def test_parse_trailing_semicolon(self): + """Test parsing with trailing semicolon.""" + parser = ConnectionStringParser() + result = parser.parse("Server=localhost;") + assert result == {'server': 'localhost'} + + def test_parse_malformed_no_equals(self): + """Test that malformed entries (no '=') are ignored with lenient parsing.""" + parser = ConnectionStringParser() + # "Invalid Entry" has no '=', so it's ignored + # Only valid entries are returned + result = parser.parse("Server=localhost;Invalid Entry;Database=mydb") + assert result == { + 'server': 'localhost', + 'database': 'mydb' + } + + def test_parse_unclosed_brace_ignored(self): + """Test that unclosed braces are ignored with lenient parsing.""" + parser = ConnectionStringParser() + # "PWD={unclosed" is malformed, so it's ignored + # Only valid entries are returned + result = parser.parse("Server=localhost;PWD={unclosed;Database=mydb") + assert result == { + 'server': 'localhost', + 'database': 'mydb' + } + + def test_parse_all_malformed_returns_empty(self): + """Test that all-malformed connection strings return empty dict.""" + parser = ConnectionStringParser() + result = parser.parse("NoEquals;AlsoNoEquals") + assert result == {} + + def test_parse_malformed_with_logging(self, caplog): + """Test that malformed entries generate warning logs.""" + parser = ConnectionStringParser() + with caplog.at_level(logging.WARNING): + result = parser.parse("Server=localhost;BadEntry") + + assert result == {'server': 'localhost'} + assert "Ignoring malformed connection string entry" in caplog.text + + +class TestConnectionStringAllowList: + """Unit tests for ConnectionStringAllowList.""" + + def test_normalize_key_allowed(self): + """Test normalization of allowed keys.""" + assert ConnectionStringAllowList.normalize_key('SERVER') == 'Server' + assert ConnectionStringAllowList.normalize_key('uid') == 'Uid' + + def test_normalize_key_not_allowed(self): + """Test normalization of disallowed keys.""" + assert 
ConnectionStringAllowList.normalize_key('BadParam') is None + + def test_filter_params_allows_good_params(self): + """Test filtering allows known parameters.""" + params = {'server': 'localhost', 'database': 'mydb'} + filtered = ConnectionStringAllowList.filter_params(params) + assert 'Server' in filtered + assert 'Database' in filtered + + def test_filter_params_rejects_bad_params(self): + """Test filtering rejects unknown parameters.""" + params = {'server': 'localhost', 'badparam': 'value'} + filtered = ConnectionStringAllowList.filter_params(params) + assert 'Server' in filtered + assert 'badparam' not in filtered + + +class TestConnectionStringBuilder: + """Unit tests for ConnectionStringBuilder.""" + + def test_build_simple(self): + """Test building simple connection string.""" + builder = ConnectionStringBuilder() + builder.add_param('Server', 'localhost') + builder.add_param('Database', 'mydb') + result = builder.build() + assert 'Server=localhost' in result + assert 'Database=mydb' in result + + def test_build_with_escaping(self): + """Test building with special characters requiring escaping.""" + builder = ConnectionStringBuilder() + builder.add_param('PWD', 'p;w{d}') + result = builder.build() + assert 'PWD={p;w{{d}}}' in result or 'PWD={p;w{d}}' in result +``` + +### Integration Tests + +```python +# tests/test_connection_integration.py + +class TestConnectionIntegration: + """Integration tests for the complete connection flow.""" + + def test_connection_with_filtered_params(self): + """Test that unknown parameters are filtered out.""" + # This should work (filtered params removed) + conn = Connection( + "Server=localhost;Database=mydb;BadParam=value", + encrypt="yes" + ) + # Verify connection string doesn't contain BadParam + assert 'badparam' not in conn.connection_str.lower() + + def test_kwargs_override_connection_str(self): + """Test that kwargs override connection_str parameters.""" + conn = Connection( + "Server=localhost;Encrypt=no", + 
encrypt="yes" + ) + # Verify Encrypt=yes is in final string + assert 'Encrypt=yes' in conn.connection_str or 'Encrypt = yes' in conn.connection_str +``` + +--- + +## Design Considerations + +### Privacy and Logging + +When filtering connection string parameters, proper handling of sensitive information is important: + +| Consideration | Implementation | +|---------------|----------------| +| **Password Handling** | Use `sanitize_connection_string()` before logging | +| **Credential Leakage** | Special handling for password parameters in logs | +| **Information Disclosure** | Sanitize connection strings in debug output | +| **Error Messages** | Don't include sensitive data in exception messages | + +### Logging Best Practices + +1. **Password Sanitization** + - Never log actual password values + - Use `sanitize_connection_string()` before logging connection strings + - Replace password values with `***` in debug output + +2. **Parameter Filtering** + - Log warnings for rejected parameters (after sanitization) + - Provide clear feedback about which parameters were filtered + +3. **Error Messages** + - Don't include connection string values in exception messages + - Use generic error messages for connection failures + +--- + +## Future Enhancements + +### Phase 2 Enhancements + +1. 
**Extended Parameter Support** + - Use the parsed key-value parameters from the connection string parser to support additional connection options + - Map allow-listed parameters to their appropriate configurations + - The same parser output will be used for current and future implementations, ensuring consistent behavior + +--- + +## Appendices + +### Appendix A: ODBC Connection String Specification + +**Reference**: [MS-ODBCSTR: ODBC Connection String Structure](https://learn.microsoft.com/en-us/openspecs/sql_server_protocols/ms-odbcstr/55953f0e-2d30-4ad4-8e56-b4207e491409), Section 2.1.2 + +**Official ABNF Grammar:** + +```abnf +ODBCConnectionString = *(KeyValuePair SC) KeyValuePair [SC] +KeyValuePair = (Key EQ Value / SpaceStr) +Key = SpaceStr KeyName +KeyName = (nonSP-SC-EQ *nonEQ) +Value = (SpaceStr ValueFormat1 SpaceStr) / (ValueContent2) +ValueFormat1 = LCB ValueContent1 RCB +ValueContent1 = *(nonRCB / ESCAPEDRCB) +ValueContent2 = SpaceStr / SpaceStr (nonSP-LCB-SC) *nonSC + +; Character class definitions +nonRCB = %x01-7C / %x7E-FFFF ; any character except "}" (0x7D) +nonSP-LCB-SC = %x01-1F / %x21-3A / %x3C-7A / %x7C-FFFF + ; any character except space (0x20), "{" (0x7B), or ";" (0x3B) +nonSP-SC-EQ = %x01-1F / %x21-3A / %x3C / %x3E-FFFF + ; any character except space (0x20), ";" (0x3B), or "=" (0x3D) +nonEQ = %x01-3C / %x3E-FFFF ; any character except "=" (0x3D) +nonSC = %x01-3A / %x3C-FFFF ; any character except ";" (0x3B) + +; Terminal symbols +SC = ";" ; semicolon separator +EQ = "=" ; equals sign +LCB = "{" ; left curly brace +RCB = "}" ; right curly brace +ESCAPEDRCB = "}}" ; escaped right curly brace +SpaceStr = *SP ; zero or more space characters (0x20) +``` + +**Key Implementation Notes:** + +1. **Key-Value Pairs**: Multiple pairs separated by semicolons (`;`) +2. **Braced Values**: Values containing special characters (`;`, `{`) must use braced format `{...}` +3. **Right Brace Escaping**: `}` inside braced values is escaped by doubling: `}}` → `}` (the `ESCAPEDRCB` rule in the grammar above) +4. 
**Left Brace Escaping**: `{` inside braced values is also escaped by doubling: `{{` → `{` +5. **Trailing Semicolons**: Optional trailing semicolon is allowed +6. **Whitespace**: Leading/trailing spaces in keys and certain value formats are significant + +**Examples:** +``` +Server=localhost;Database=mydb +Server={local;server};PWD={p}}w{{d} +Driver={ODBC Driver 18 for SQL Server};Encrypt=yes; +``` + +### Appendix B: Performance Benchmarks + +To be filled in during implementation with actual measurements. + +### Appendix C: References + +1. **[MS-ODBCSTR: ODBC Connection String Structure](https://learn.microsoft.com/en-us/openspecs/sql_server_protocols/ms-odbcstr/55953f0e-2d30-4ad4-8e56-b4207e491409)** - Official ODBC connection string specification with ABNF grammar +2. [ODBC Programmer's Reference](https://docs.microsoft.com/en-us/sql/odbc/reference/develop-app/connection-strings) - General ODBC documentation +3. [SQL Server Connection Strings](https://docs.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute) - SQL Server-specific connection string attributes +4. [ODBC Driver for SQL Server](https://docs.microsoft.com/en-us/sql/connect/odbc/microsoft-odbc-driver-for-sql-server) - Microsoft ODBC Driver 18 for SQL Server documentation + +--- + +## Revision History + +| Date | Version | Author | Changes | +|------|---------|--------|---------| +| 2025-10-23 | 1.0 | Engineering Team | Initial design document | + +--- + +**End of Document** diff --git a/docs/parser_state_machine.md b/docs/parser_state_machine.md new file mode 100644 index 000000000..241abb61d --- /dev/null +++ b/docs/parser_state_machine.md @@ -0,0 +1,215 @@ +# Connection String Parser State Machine + +This document describes the state machine for the ODBC connection string parser (`_ConnectionStringParser`). 
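As a rough illustration of the braced-value handling this document covers, the sketch below unescapes `}}` and `{{` and reports an unclosed brace. It is a minimal sketch, not the actual `_ConnectionStringParser` code: the function name `parse_braced_value` and its signature are hypothetical, and it raises immediately rather than collecting errors for batch reporting.

```python
from typing import Tuple


def parse_braced_value(text: str, start: int) -> Tuple[str, int]:
    """Parse a braced value that begins at text[start] == '{'.

    Hypothetical helper for illustration only. Returns the unescaped
    value and the index just past the closing brace.
    """
    if start >= len(text) or text[start] != "{":
        raise ValueError("braced value must start with '{'")

    chars = []
    pos = start + 1  # skip the opening brace
    while pos < len(text):
        ch = text[pos]
        nxt = text[pos + 1] if pos + 1 < len(text) else ""
        if ch == "}":
            if nxt == "}":
                # '}}' is an escaped right brace
                chars.append("}")
                pos += 2
                continue
            # A single '}' closes the value
            return "".join(chars), pos + 1
        if ch == "{" and nxt == "{":
            # '{{' is an escaped left brace
            chars.append("{")
            pos += 2
            continue
        # Regular character, including a lone '{', ';' or '='
        chars.append(ch)
        pos += 1

    # Reached end of input without a closing brace
    raise ValueError(f"Unclosed braced value starting at position {start}")
```

Calling `parse_braced_value("{p}}w{{d;test}", 0)` walks the same steps as the "Braced Value with Escaping" flow and yields the value `p}w{d;test`.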
+ +## Overview + +The parser processes ODBC connection strings character-by-character, handling: +- Semicolon-separated key=value pairs +- Simple values (unquoted) +- Braced values with escape sequences: `{value}`, `}}` → `}`, `{{` → `{` +- Whitespace normalization +- Error detection and collection + +## State Machine Diagram + +```mermaid +stateDiagram-v2 + [*] --> START: Begin parsing + + START --> SKIP_WHITESPACE: Start of new segment + + SKIP_WHITESPACE --> SKIP_WHITESPACE: Space, tab, semicolon + SKIP_WHITESPACE --> END: EOF + SKIP_WHITESPACE --> PARSE_KEY: Other char + + PARSE_KEY --> PARSE_KEY: Any char except equals or semicolon + PARSE_KEY --> ERROR_NO_EQUALS: EOF or semicolon found + PARSE_KEY --> ERROR_EMPTY_KEY: Equals found but key is empty + PARSE_KEY --> VALIDATE_KEY: Equals found and key exists + + ERROR_NO_EQUALS --> SKIP_TO_SEMICOLON: Record error + ERROR_EMPTY_KEY --> SKIP_TO_SEMICOLON: Record error + + VALIDATE_KEY --> PARSE_VALUE: Key validated + + PARSE_VALUE --> CHECK_VALUE_TYPE: Skip whitespace + + CHECK_VALUE_TYPE --> PARSE_SIMPLE_VALUE: First char not left brace + CHECK_VALUE_TYPE --> PARSE_BRACED_VALUE: First char is left brace + + PARSE_SIMPLE_VALUE --> PARSE_SIMPLE_VALUE: Any char except semicolon + PARSE_SIMPLE_VALUE --> STORE_PARAM: Semicolon or EOF + + PARSE_BRACED_VALUE --> PARSE_BRACED_VALUE: Regular char + PARSE_BRACED_VALUE --> CHECK_RIGHT_BRACE: Right brace encountered + PARSE_BRACED_VALUE --> CHECK_LEFT_BRACE: Left brace encountered + PARSE_BRACED_VALUE --> ERROR_UNCLOSED_BRACE: EOF without closing brace + + CHECK_RIGHT_BRACE --> PARSE_BRACED_VALUE: Double right brace (escaped) + CHECK_RIGHT_BRACE --> STORE_PARAM: Single right brace (end of value) + + CHECK_LEFT_BRACE --> PARSE_BRACED_VALUE: Double left brace (escaped) + CHECK_LEFT_BRACE --> PARSE_BRACED_VALUE: Single left brace (keep as-is) + + ERROR_UNCLOSED_BRACE --> SKIP_TO_SEMICOLON: Record error + + STORE_PARAM --> CHECK_DUPLICATE: Parameter extracted + + CHECK_DUPLICATE 
--> ERROR_DUPLICATE: Key seen before + CHECK_DUPLICATE --> VALIDATE_ALLOWLIST: New key + + ERROR_DUPLICATE --> SKIP_TO_SEMICOLON: Record error + + VALIDATE_ALLOWLIST --> CHECK_RESERVED: If allowlist provided + VALIDATE_ALLOWLIST --> SAVE_PARAM: No allowlist + + CHECK_RESERVED --> ERROR_RESERVED: Driver or APP keyword + CHECK_RESERVED --> CHECK_UNKNOWN: Not reserved + + CHECK_UNKNOWN --> ERROR_UNKNOWN: Unknown keyword + CHECK_UNKNOWN --> SAVE_PARAM: Known keyword + + ERROR_RESERVED --> SKIP_TO_SEMICOLON: Record error + ERROR_UNKNOWN --> SKIP_TO_SEMICOLON: Record error + + SAVE_PARAM --> SKIP_WHITESPACE: Continue parsing + SKIP_TO_SEMICOLON --> SKIP_WHITESPACE: Error recovery + + END --> RAISE_ERRORS: If errors collected + END --> RETURN_PARAMS: No errors + + RAISE_ERRORS --> [*]: Throw ConnectionStringParseError + RETURN_PARAMS --> [*]: Return dict of params +``` + +## States Description + +### Main States + +| State | Description | +|-------|-------------| +| **START** | Initial state at beginning of parsing | +| **SKIP_WHITESPACE** | Skip whitespace (spaces, tabs) and semicolons between parameters | +| **PARSE_KEY** | Extract parameter key up to '=' sign | +| **VALIDATE_KEY** | Check if key is non-empty | +| **PARSE_VALUE** | Determine value type and extract it | +| **CHECK_VALUE_TYPE** | Decide between simple or braced value parsing | +| **PARSE_SIMPLE_VALUE** | Extract unquoted value up to ';' or EOF | +| **PARSE_BRACED_VALUE** | Extract braced value with escape handling | +| **STORE_PARAM** | Prepare to store the key-value pair | +| **CHECK_DUPLICATE** | Verify key hasn't been seen before | +| **VALIDATE_ALLOWLIST** | Check parameter against allowlist (if provided) | +| **CHECK_RESERVED** | Verify parameter is not reserved (Driver, APP) | +| **CHECK_UNKNOWN** | Verify parameter is recognized | +| **SAVE_PARAM** | Store the parameter in results | +| **SKIP_TO_SEMICOLON** | Error recovery: advance to next ';' | +| **END** | Parsing complete | +| **RAISE_ERRORS** 
| Collect and throw all errors | +| **RETURN_PARAMS** | Return parsed parameters dictionary | + +### Error States + +| Error State | Trigger | Error Message | +|-------------|---------|---------------| +| **ERROR_NO_EQUALS** | Key without '=' separator | "Incomplete specification: keyword '{key}' has no value (missing '=')" | +| **ERROR_EMPTY_KEY** | '=' with no preceding key | "Empty keyword found (format: =value)" | +| **ERROR_DUPLICATE** | Same key appears twice | "Duplicate keyword '{key}' found" | +| **ERROR_UNCLOSED_BRACE** | '{' without matching '}' | "Unclosed braced value starting at position {pos}" | +| **ERROR_RESERVED** | User tries to set Driver or APP | "Reserved keyword '{key}' is controlled by the driver and cannot be specified by the user" | +| **ERROR_UNKNOWN** | Key not in allowlist | "Unknown keyword '{key}' is not recognized" | + +## Special Characters & Escaping + +### Braced Value Escape Sequences + +| Input | Parsed Output | Description | +|-------|---------------|-------------| +| `{value}` | `value` | Basic braced value | +| `{val;ue}` | `val;ue` | Semicolon allowed inside braces | +| `{val}}ue}` | `val}ue` | Escaped right brace: `}}` → `}` | +| `{val{{ue}` | `val{ue` | Escaped left brace: `{{` → `{` | +| `{a=b}` | `a=b` | Equals sign allowed inside braces | +| `{sp ace}` | `sp ace` | Spaces preserved inside braces | + +### Simple Value Rules + +- Read until semicolon (`;`) or end of string +- Leading whitespace after '=' is skipped +- Trailing whitespace is stripped from value +- Cannot contain semicolons (unescaped) + +## Examples + +### Valid Parsing Flow + +``` +Input: "Server=localhost;Database=mydb" + +START → SKIP_WHITESPACE → PARSE_KEY("Server") + → VALIDATE_KEY → PARSE_VALUE → CHECK_VALUE_TYPE + → PARSE_SIMPLE_VALUE("localhost") → STORE_PARAM + → CHECK_DUPLICATE → VALIDATE_ALLOWLIST → SAVE_PARAM + → SKIP_WHITESPACE → PARSE_KEY("Database") + → VALIDATE_KEY → PARSE_VALUE → CHECK_VALUE_TYPE + → PARSE_SIMPLE_VALUE("mydb") → 
STORE_PARAM + → CHECK_DUPLICATE → VALIDATE_ALLOWLIST → SAVE_PARAM + → END → RETURN_PARAMS +``` + +### Error Handling Flow + +``` +Input: "Server=localhost;Server=other" (duplicate) + +START → ... → SAVE_PARAM(Server=localhost) + → SKIP_WHITESPACE → PARSE_KEY("Server") + → VALIDATE_KEY → PARSE_VALUE → ... → STORE_PARAM + → CHECK_DUPLICATE → ERROR_DUPLICATE + → SKIP_TO_SEMICOLON → END → RAISE_ERRORS +``` + +### Braced Value with Escaping + +``` +Input: "PWD={p}}w{{d;test}" + +START → SKIP_WHITESPACE → PARSE_KEY("PWD") + → VALIDATE_KEY → PARSE_VALUE → CHECK_VALUE_TYPE + → PARSE_BRACED_VALUE + - Read 'p' + - Read '}' → CHECK_RIGHT_BRACE + - Next is '}' → Escaped: add '}', continue + - Read 'w' + - Read '{' → CHECK_LEFT_BRACE + - Next is '{' → Escaped: add '{', continue + - Read 'd', ';', 't', 'e', 's', 't' + - Read '}' → CHECK_RIGHT_BRACE + - Next is not '}' → End of value + → STORE_PARAM (value="p}w{d;test") + → ... → SAVE_PARAM → END → RETURN_PARAMS +``` + +## Parser Characteristics + +### Key Features + +1. **Error Collection**: Collects all errors before raising exception (batch error reporting) +2. **Case-Insensitive Keys**: All keys normalized to lowercase during parsing +3. **Duplicate Detection**: Tracks seen keys to prevent duplicates +4. **Reserved Keywords**: Blocks user from setting `Driver` and `APP` +5. **Allowlist Validation**: Optional validation against allowed parameters +6. **Escape Handling**: Proper ODBC brace escape sequences (`{{`, `}}`) +7. 
**Error Recovery**: Skips to next semicolon after errors to continue validation + +### Error Handling Strategy + +- **Non-fatal errors**: Continue parsing to collect all errors +- **Fatal errors**: Stop immediately (e.g., unclosed brace in value parsing) +- **Batch reporting**: All errors reported together in `ConnectionStringParseError` + +## References + +- MS-ODBCSTR Specification: [ODBC Connection String Format](https://learn.microsoft.com/en-us/openspecs/sql_server_protocols/ms-odbcstr/) +- Implementation: `mssql_python/connection_string_parser.py` +- Tests: `tests/test_010_connection_string_parser.py` diff --git a/eng/pipelines/build-whl-pipeline.yml b/eng/pipelines/build-whl-pipeline.yml index a22edd080..a6540c8aa 100644 --- a/eng/pipelines/build-whl-pipeline.yml +++ b/eng/pipelines/build-whl-pipeline.yml @@ -340,7 +340,7 @@ jobs: python -m pytest -v displayName: 'Run Pytest to validate bindings' env: - DB_CONNECTION_STRING: 'Driver=ODBC Driver 18 for SQL Server;Server=tcp:127.0.0.1,1433;Database=master;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' + DB_CONNECTION_STRING: 'Server=tcp:127.0.0.1,1433;Database=master;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' # Build wheel package for universal2 - script: | @@ -801,7 +801,7 @@ jobs: displayName: 'Test wheel installation and basic functionality on $(BASE_IMAGE)' env: - DB_CONNECTION_STRING: 'Driver=ODBC Driver 18 for SQL Server;Server=localhost;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' + DB_CONNECTION_STRING: 'Server=localhost;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' # Run pytest with source code while testing installed wheel - script: | @@ -856,7 +856,7 @@ jobs: " displayName: 'Run pytest suite on $(BASE_IMAGE) $(ARCH)' env: - DB_CONNECTION_STRING: 'Driver=ODBC Driver 18 for SQL Server;Server=localhost;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' + DB_CONNECTION_STRING: 
'Server=localhost;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' continueOnError: true # Don't fail pipeline if tests fail # Cleanup diff --git a/eng/pipelines/pr-validation-pipeline.yml b/eng/pipelines/pr-validation-pipeline.yml index d2ede2470..1f6ea8094 100644 --- a/eng/pipelines/pr-validation-pipeline.yml +++ b/eng/pipelines/pr-validation-pipeline.yml @@ -190,7 +190,7 @@ jobs: python -m pytest -v --junitxml=test-results.xml --cov=. --cov-report=xml --capture=tee-sys --cache-clear displayName: 'Run pytest with coverage' env: - DB_CONNECTION_STRING: 'Driver=ODBC Driver 18 for SQL Server;Server=tcp:127.0.0.1,1433;Database=master;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' + DB_CONNECTION_STRING: 'Server=tcp:127.0.0.1,1433;Database=master;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' DB_PASSWORD: $(DB_PASSWORD) - task: PublishTestResults@2 @@ -359,12 +359,12 @@ jobs: echo "SQL Server IP: $SQLSERVER_IP" docker exec \ - -e DB_CONNECTION_STRING="Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ + -e DB_CONNECTION_STRING="Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ -e DB_PASSWORD="$(DB_PASSWORD)" \ test-container-$(distroName) bash -c " source /opt/venv/bin/activate echo 'Build successful, running tests now on $(distroName)' - echo 'Using connection string: Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' + echo 'Using connection string: Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' python -m pytest -v --junitxml=test-results-$(distroName).xml --cov=. 
--cov-report=xml:coverage-$(distroName).xml --capture=tee-sys --cache-clear " displayName: 'Run pytest with coverage in $(distroName) container' @@ -570,13 +570,13 @@ jobs: echo "SQL Server IP: $SQLSERVER_IP" docker exec \ - -e DB_CONNECTION_STRING="Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ + -e DB_CONNECTION_STRING="Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ -e DB_PASSWORD="$(DB_PASSWORD)" \ test-container-$(distroName)-$(archName) bash -c " source /opt/venv/bin/activate echo 'Build successful, running tests now on $(distroName) ARM64' echo 'Architecture:' \$(uname -m) - echo 'Using connection string: Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' + echo 'Using connection string: Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' python main.py python -m pytest -v --junitxml=test-results-$(distroName)-$(archName).xml --cov=. 
--cov-report=xml:coverage-$(distroName)-$(archName).xml --capture=tee-sys --cache-clear " @@ -778,12 +778,12 @@ jobs: echo "SQL Server IP: $SQLSERVER_IP" docker exec \ - -e DB_CONNECTION_STRING="Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ + -e DB_CONNECTION_STRING="Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ -e DB_PASSWORD="$(DB_PASSWORD)" \ test-container-rhel9 bash -c " source myvenv/bin/activate echo 'Build successful, running tests now on RHEL 9' - echo 'Using connection string: Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' + echo 'Using connection string: Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' python main.py python -m pytest -v --junitxml=test-results-rhel9.xml --cov=. --cov-report=xml:coverage-rhel9.xml --capture=tee-sys --cache-clear " @@ -997,13 +997,13 @@ jobs: echo "SQL Server IP: $SQLSERVER_IP" docker exec \ - -e DB_CONNECTION_STRING="Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ + -e DB_CONNECTION_STRING="Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ -e DB_PASSWORD="$(DB_PASSWORD)" \ test-container-rhel9-arm64 bash -c " source myvenv/bin/activate echo 'Build successful, running tests now on RHEL 9 ARM64' echo 'Architecture:' \$(uname -m) - echo 'Using connection string: Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' + echo 'Using connection string: Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' python -m pytest -v --junitxml=test-results-rhel9-arm64.xml --cov=. 
--cov-report=xml:coverage-rhel9-arm64.xml --capture=tee-sys --cache-clear " displayName: 'Run pytest with coverage in RHEL 9 ARM64 container' @@ -1225,13 +1225,13 @@ jobs: echo "SQL Server IP: $SQLSERVER_IP" docker exec \ - -e DB_CONNECTION_STRING="Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ + -e DB_CONNECTION_STRING="Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ -e DB_PASSWORD="$(DB_PASSWORD)" \ test-container-alpine bash -c " echo 'Build successful, running tests now on Alpine x86_64' echo 'Architecture:' \$(uname -m) echo 'Alpine version:' \$(cat /etc/alpine-release) - echo 'Using connection string: Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' + echo 'Using connection string: Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' # Activate virtual environment source /workspace/venv/bin/activate @@ -1467,13 +1467,13 @@ jobs: echo "SQL Server IP: $SQLSERVER_IP" docker exec \ - -e DB_CONNECTION_STRING="Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ + -e DB_CONNECTION_STRING="Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes" \ -e DB_PASSWORD="$(DB_PASSWORD)" \ test-container-alpine-arm64 bash -c " echo 'Build successful, running tests now on Alpine ARM64' echo 'Architecture:' \$(uname -m) echo 'Alpine version:' \$(cat /etc/alpine-release) - echo 'Using connection string: Driver=ODBC Driver 18 for SQL Server;Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' + echo 'Using connection string: Server=$SQLSERVER_IP;Database=TestDB;Uid=SA;Pwd=***;TrustServerCertificate=yes' # Activate virtual environment source /workspace/venv/bin/activate @@ -1574,7 +1574,7 @@ jobs: lcov_cobertura total.info 
--output unified-coverage/coverage.xml displayName: 'Generate unified coverage (Python + C++)' env: - DB_CONNECTION_STRING: 'Driver=ODBC Driver 18 for SQL Server;Server=tcp:127.0.0.1,1433;Database=master;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' + DB_CONNECTION_STRING: 'Server=tcp:127.0.0.1,1433;Database=master;Uid=SA;Pwd=$(DB_PASSWORD);TrustServerCertificate=yes' DB_PASSWORD: $(DB_PASSWORD) - task: PublishTestResults@2 diff --git a/mssql_python/__init__.py b/mssql_python/__init__.py index cf510ca2a..27ebaaa8c 100644 --- a/mssql_python/__init__.py +++ b/mssql_python/__init__.py @@ -26,6 +26,9 @@ NotSupportedError, ) +# Connection string parser exceptions +from .connection_string_parser import ConnectionStringParseError + # Type Objects from .type import ( Date, diff --git a/mssql_python/connection.py b/mssql_python/connection.py index f0663d727..062898cb2 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -242,39 +242,66 @@ def _construct_connection_string( self, connection_str: str = "", **kwargs: Any ) -> str: """ - Construct the connection string by concatenating the connection string - with key/value pairs from kwargs. - + Construct the connection string by parsing, filtering, and merging parameters. + + This method: + 1. Parses the base connection_str into parameters + 2. Filters parameters against an allow-list + 3. Merges kwargs (which also go through allow-list) + 4. Adds Driver and APP parameters (always controlled by the driver) + 5. Rebuilds the connection string + Args: connection_str (str): The base connection string. **kwargs: Additional key/value pairs for the connection string. Returns: - str: The constructed connection string. + str: The constructed and filtered connection string. 
""" - # Add the driver attribute to the connection string - conn_str = add_driver_to_connection_str(connection_str) - - # Add additional key-value pairs to the connection string + from mssql_python.connection_string_parser import _ConnectionStringParser + from mssql_python.connection_string_allowlist import ConnectionStringAllowList + from mssql_python.connection_string_builder import _ConnectionStringBuilder + + # Step 1: Parse base connection string with allowlist validation + allowlist = ConnectionStringAllowList() + parser = _ConnectionStringParser(allowlist=allowlist) + parsed_params = parser.parse(connection_str) + + # Step 2: Filter against allow-list + filtered_params = ConnectionStringAllowList.filter_params( + parsed_params, + warn_rejected=True + ) + + # Step 3: Process kwargs and merge with filtered_params + # kwargs override connection string values (processed after, so they take precedence) for key, value in kwargs.items(): - if key.lower() == "host" or key.lower() == "server": - key = "Server" - elif key.lower() == "user" or key.lower() == "uid": - key = "Uid" - elif key.lower() == "password" or key.lower() == "pwd": - key = "Pwd" - elif key.lower() == "database": - key = "Database" - elif key.lower() == "encrypt": - key = "Encrypt" - elif key.lower() == "trust_server_certificate": - key = "TrustServerCertificate" + normalized_key = ConnectionStringAllowList.normalize_key(key) + if normalized_key: + # Driver and APP are reserved - raise error if user tries to set them + if normalized_key in ('Driver', 'APP'): + raise ValueError( + f"Connection parameter '{key}' is reserved and controlled by the driver. " + f"It cannot be set by the user." 
+ ) + # kwargs override any existing values from connection string + filtered_params[normalized_key] = str(value) else: - continue - conn_str += f"{key}={value};" - - log("info", "Final connection string: %s", sanitize_connection_string(conn_str)) - + log('warning', f"Ignoring unknown connection parameter from kwargs: {key}") + + # Step 4: Build connection string with merged params + builder = _ConnectionStringBuilder(filtered_params) + + # Step 5: Add Driver and APP parameters (always controlled by the driver) + # These maintain existing behavior: Driver is always hardcoded, APP is always MSSQL-Python + builder.add_param('Driver', 'ODBC Driver 18 for SQL Server') + builder.add_param('APP', 'MSSQL-Python') + + # Step 6: Build final string + conn_str = builder.build() + + log('info', "Final connection string: %s", sanitize_connection_string(conn_str)) + return conn_str @property diff --git a/mssql_python/connection_string_allowlist.py b/mssql_python/connection_string_allowlist.py new file mode 100644 index 000000000..1bcb25294 --- /dev/null +++ b/mssql_python/connection_string_allowlist.py @@ -0,0 +1,150 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Connection string parameter allow-list for mssql-python. + +Manages allowed connection string parameters and handles parameter +normalization, synonym mapping, and filtering. +""" + +from typing import Dict, Optional + + +class ConnectionStringAllowList: + """ + Manages the allow-list of permitted connection string parameters. 
+ + This class implements a deliberate allow-list approach to exposing + connection string parameters, enabling: + - Incremental ODBC parity while maintaining backward compatibility + - Forward compatibility with future driver enhancements + - Simplified API by normalizing parameter synonyms + """ + + # Core connection parameters with synonym mapping + # Maps lowercase parameter names to their canonical form + ALLOWED_PARAMS = { + # Server identification + 'server': 'Server', + 'host': 'Server', # Common synonym + 'address': 'Server', + 'addr': 'Server', + 'network address': 'Server', + + # Authentication + 'uid': 'Uid', + 'user id': 'Uid', + 'user': 'Uid', + 'pwd': 'Pwd', + 'password': 'Pwd', + 'authentication': 'Authentication', + 'trusted_connection': 'Trusted_Connection', + + # Database + 'database': 'Database', + 'initial catalog': 'Database', + + # Driver (always controlled by mssql-python) + 'driver': 'Driver', + + # Application name (always controlled by mssql-python) + 'app': 'APP', + 'application name': 'APP', + + # Encryption + 'encrypt': 'Encrypt', + 'trustservercertificate': 'TrustServerCertificate', + 'trust_server_certificate': 'TrustServerCertificate', # Python-style underscore synonym + 'trust server certificate': 'TrustServerCertificate', + 'hostnameincertificate': 'HostNameInCertificate', + + # Connection behavior + 'connection timeout': 'Connection Timeout', + 'connect timeout': 'Connection Timeout', + 'timeout': 'Connection Timeout', + 'login timeout': 'Login Timeout', + 'multisubnetfailover': 'MultiSubnetFailover', + 'multi subnet failover': 'MultiSubnetFailover', + 'applicationintent': 'ApplicationIntent', + 'application intent': 'ApplicationIntent', + + # Failover + 'failover partner': 'Failover_Partner', + 'failoverpartner': 'Failover_Partner', + + # Packet size + 'packet size': 'Packet Size', + 'packetsize': 'Packet Size', + } + + @classmethod + def normalize_key(cls, key: str) -> Optional[str]: + """ + Normalize a parameter key to its 
canonical form.
+
+        Args:
+            key: Parameter key from connection string (case-insensitive)
+
+        Returns:
+            Canonical parameter name if allowed, None otherwise
+
+        Examples:
+            >>> ConnectionStringAllowList.normalize_key('SERVER')
+            'Server'
+            >>> ConnectionStringAllowList.normalize_key('user')
+            'Uid'
+            >>> ConnectionStringAllowList.normalize_key('UnsupportedParam') is None
+            True
+        """
+        key_lower = key.lower().strip()
+        return cls.ALLOWED_PARAMS.get(key_lower)
+
+    @classmethod
+    def filter_params(cls, params: Dict[str, str], warn_rejected: bool = True) -> Dict[str, str]:
+        """
+        Filter parameters against the allow-list.
+
+        Args:
+            params: Dictionary of connection string parameters (keys should be lowercase)
+            warn_rejected: Whether to log warnings for rejected parameters
+
+        Returns:
+            Dictionary containing only allowed parameters with normalized keys
+
+        Note:
+            Driver and APP parameters are filtered here but will be set by
+            the driver in _construct_connection_string to maintain control.
+        """
+        # Import here to avoid circular dependency issues
+        try:
+            from mssql_python.logging_config import get_logger
+            from mssql_python.helpers import sanitize_user_input
+            logger = get_logger()
+        except ImportError:
+            logger = None
+            sanitize_user_input = lambda x: str(x)[:50]  # Simple fallback
+
+        filtered = {}
+        rejected = []  # Collected for diagnostics only; not returned
+
+        for key, value in params.items():
+            normalized_key = cls.normalize_key(key)
+
+            if normalized_key:
+                # Skip Driver and APP - these are controlled by the driver
+                if normalized_key in ('Driver', 'APP'):
+                    continue
+
+                # Parameter is allowed
+                filtered[normalized_key] = value
+            else:
+                # Parameter is not in allow-list
+                rejected.append(key)
+                if warn_rejected and logger:
+                    safe_key = sanitize_user_input(key)
+                    logger.warning(
+                        f"Connection string parameter '{safe_key}' is not in the allow-list and will be ignored"
+                    )
+
+        return filtered
diff --git a/mssql_python/connection_string_builder.py b/mssql_python/connection_string_builder.py
new file mode 100644
index 000000000..fb99cba00 --- /dev/null +++ b/mssql_python/connection_string_builder.py @@ -0,0 +1,125 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Connection string builder for mssql-python. + +Reconstructs ODBC connection strings from parameter dictionaries +with proper escaping and formatting per MS-ODBCSTR specification. +""" + +from typing import Dict, Optional + + +class _ConnectionStringBuilder: + """ + Internal builder for ODBC connection strings. Not part of public API. + + Handles proper escaping of special characters and reconstructs + connection strings in ODBC format. + """ + + def __init__(self, initial_params: Optional[Dict[str, str]] = None): + """ + Initialize the builder with optional initial parameters. + + Args: + initial_params: Dictionary of initial connection parameters + """ + self._params: Dict[str, str] = initial_params.copy() if initial_params else {} + + def add_param(self, key: str, value: str) -> '_ConnectionStringBuilder': + """ + Add or update a connection parameter. + + Args: + key: Parameter name (should be normalized canonical name) + value: Parameter value + + Returns: + Self for method chaining + """ + self._params[key] = str(value) + return self + + def has_param(self, key: str) -> bool: + """ + Check if a parameter exists. + + Args: + key: Parameter name to check + + Returns: + True if parameter exists, False otherwise + """ + return key in self._params + + def build(self) -> str: + """ + Build the final connection string. 
+ + Returns: + ODBC-formatted connection string with proper escaping + + Note: + - Driver parameter is placed first + - Other parameters are sorted for consistency + - Values are escaped if they contain special characters + """ + parts = [] + + # Build in specific order: Driver first, then others + if 'Driver' in self._params: + parts.append(f"Driver={self._escape_value(self._params['Driver'])}") + + # Add other parameters (sorted for consistency) + for key in sorted(self._params.keys()): + if key == 'Driver': + continue # Already added + + value = self._params[key] + escaped_value = self._escape_value(value) + parts.append(f"{key}={escaped_value}") + + # Join with semicolons + return ';'.join(parts) + + def _escape_value(self, value: str) -> str: + """ + Escape a parameter value if it contains special characters. + + Per MS-ODBCSTR specification: + - Values containing ';', '{', '}', '=', or spaces should be braced for safety + - '}' inside braced values is escaped as '}}' + - '{' inside braced values is escaped as '{{' + + Args: + value: Parameter value to escape + + Returns: + Escaped value (possibly wrapped in braces) + + Examples: + >>> builder = _ConnectionStringBuilder() + >>> builder._escape_value("localhost") + 'localhost' + >>> builder._escape_value("local;host") + '{local;host}' + >>> builder._escape_value("p}w{d") + '{p}}w{{d}' + >>> builder._escape_value("ODBC Driver 18 for SQL Server") + '{ODBC Driver 18 for SQL Server}' + """ + if not value: + return value + + # Check if value contains special characters that require bracing + # Include spaces and = for safety, even though technically not always required + needs_braces = any(ch in value for ch in ';{}= ') + + if needs_braces: + # Escape existing braces by doubling them + escaped = value.replace('}', '}}').replace('{', '{{') + return f'{{{escaped}}}' + else: + return value diff --git a/mssql_python/connection_string_parser.py b/mssql_python/connection_string_parser.py new file mode 100644 index 
000000000..7d841f090
--- /dev/null
+++ b/mssql_python/connection_string_parser.py
@@ -0,0 +1,312 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+ODBC connection string parser for mssql-python.
+
+Handles ODBC-specific syntax per MS-ODBCSTR specification:
+- Semicolon-separated key=value pairs
+- Braced values: {value}
+- Escaped braces: }} → }, {{ → {
+
+Parser behavior:
+- Validates all key=value pairs
+- Raises exceptions for malformed syntax (missing values, unknown keywords, duplicates)
+- Collects all errors and reports them together
+"""
+
+from typing import Dict, Tuple, Optional, List
+import logging
+
+
+class ConnectionStringParseError(Exception):
+    """Exception raised when connection string parsing fails."""
+
+    def __init__(self, errors: List[str]):
+        """
+        Initialize the error with a list of validation errors.
+
+        Args:
+            errors: List of error messages
+        """
+        self.errors = errors
+        message = "Connection string parsing failed:\n " + "\n ".join(errors)
+        super().__init__(message)
+
+
+class _ConnectionStringParser:
+    """
+    Internal parser for ODBC connection strings. Not part of public API.
+
+    Implements the ODBC Connection String format as specified in MS-ODBCSTR.
+    Handles braced values, escaped characters, and proper tokenization.
+
+    Validates connection strings and raises errors for:
+    - Unknown/unrecognized keywords
+    - Duplicate keywords
+    - Incomplete specifications (keyword with no value)
+
+    Reference: https://learn.microsoft.com/en-us/openspecs/sql_server_protocols/ms-odbcstr/55953f0e-2d30-4ad4-8e56-b4207e491409
+    """
+
+    def __init__(self, allowlist=None):
+        """
+        Initialize the parser.
+
+        Args:
+            allowlist: Optional ConnectionStringAllowList instance for keyword validation.
+                If None, no keyword validation is performed.
+        """
+        self._allowlist = allowlist
+
+    def parse(self, connection_str: str) -> Dict[str, str]:
+        """
+        Parse a connection string into a dictionary of parameters.
+
+        Validates the connection string and raises ConnectionStringParseError
+        if any issues are found (unknown keywords, duplicates, missing values).
+
+        Args:
+            connection_str: ODBC-format connection string
+
+        Returns:
+            Dictionary mapping parameter names (lowercase) to values
+
+        Raises:
+            ConnectionStringParseError: If validation errors are found
+
+        Examples:
+            >>> parser = _ConnectionStringParser()
+            >>> parser.parse("Server=localhost;Database=mydb")
+            {'server': 'localhost', 'database': 'mydb'}
+
+            >>> parser.parse("Server={;local;};PWD={p}}w{{d}")
+            {'server': ';local;', 'pwd': 'p}w{d'}
+
+            >>> parser.parse("Server=localhost;Server=other")
+            ConnectionStringParseError: Duplicate keyword 'server' found
+        """
+        if not connection_str:
+            return {}
+
+        connection_str = connection_str.strip()
+        if not connection_str:
+            return {}
+
+        # Collect all errors for batch reporting
+        errors = []
+
+        # Dictionary to store parsed key=value pairs
+        params = {}
+
+        # Track which keys we've seen to detect duplicates
+        seen_keys = {}  # Maps lowercased key -> True once the key has been seen
+
+        # Track current position in the string
+        current_pos = 0
+        str_len = len(connection_str)
+
+        # Main parsing loop
+        while current_pos < str_len:
+            # Skip leading whitespace and semicolons
+            while current_pos < str_len and connection_str[current_pos] in ' \t;':
+                current_pos += 1
+
+            if current_pos >= str_len:
+                break
+
+            # Parse the key
+            key_start = current_pos
+
+            # Advance until we hit '=', ';', or end of string
+            while current_pos < str_len and connection_str[current_pos] not in '=;':
+                current_pos += 1
+
+            # Check if we found a valid '=' separator
+            if current_pos >= str_len or connection_str[current_pos] != '=':
+                # ERROR: No '=' found - incomplete specification
+                incomplete_text = connection_str[key_start:current_pos].strip()
+                if incomplete_text:
+                    errors.append(f"Incomplete specification: keyword '{incomplete_text}' has no value (missing '=')")
+                # Skip to next semicolon
+                while
current_pos < str_len and connection_str[current_pos] != ';':
+                    current_pos += 1
+                continue
+
+            # Extract and normalize the key
+            key = connection_str[key_start:current_pos].strip().lower()
+
+            # ERROR: Empty key
+            if not key:
+                errors.append("Empty keyword found (format: =value)")
+                current_pos += 1  # Skip the '='
+                # Skip to next semicolon
+                while current_pos < str_len and connection_str[current_pos] != ';':
+                    current_pos += 1
+                continue
+
+            # Move past the '='
+            current_pos += 1
+
+            # Parse the value
+            try:
+                value, current_pos = self._parse_value(connection_str, current_pos)
+
+                # Check for duplicates
+                if key in seen_keys:
+                    errors.append(f"Duplicate keyword '{key}' found")
+                else:
+                    seen_keys[key] = True
+                    params[key] = value
+
+            except ValueError as e:
+                errors.append(f"Error parsing value for keyword '{key}': {e}")
+                # Skip to next semicolon
+                while current_pos < str_len and connection_str[current_pos] != ';':
+                    current_pos += 1
+
+        # Validate keywords against allowlist if provided
+        if self._allowlist:
+            unknown_keys = []
+            reserved_keys = []
+
+            for key in params.keys():
+                # Check if this key can be normalized (i.e., it's known)
+                normalized_key = self._allowlist.normalize_key(key)
+
+                if normalized_key is None:
+                    # Unknown keyword
+                    unknown_keys.append(key)
+                elif normalized_key in ('Driver', 'APP'):
+                    # Reserved keyword - user cannot set these
+                    reserved_keys.append(key)
+
+            if reserved_keys:
+                for key in reserved_keys:
+                    errors.append(
+                        f"Reserved keyword '{key}' is controlled by the driver and cannot be specified by the user"
+                    )
+
+            if unknown_keys:
+                for key in unknown_keys:
+                    errors.append(f"Unknown keyword '{key}' is not recognized")
+
+        # If we collected any errors, raise them all together
+        if errors:
+            raise ConnectionStringParseError(errors)
+
+        return params
+
+    def _parse_value(self, connection_str: str, start_pos: int) -> Tuple[str, int]:
+        """
+        Parse a parameter value from the connection string.
+
+        Handles both simple values and braced values with escaping.
+
+        Args:
+            connection_str: The connection string
+            start_pos: Starting position of the value
+
+        Returns:
+            Tuple of (parsed_value, new_position)
+
+        Raises:
+            ValueError: If braced value is not properly closed
+        """
+        str_len = len(connection_str)
+
+        # Skip leading whitespace before the value
+        while start_pos < str_len and connection_str[start_pos] in ' \t':
+            start_pos += 1
+
+        # If we've consumed the entire string, return an empty value (a value that starts at ';' is handled by _parse_simple_value, which returns '')
+        if start_pos >= str_len:
+            return '', start_pos
+
+        # Determine if this is a braced value or simple value
+        if connection_str[start_pos] == '{':
+            return self._parse_braced_value(connection_str, start_pos)
+        else:
+            return self._parse_simple_value(connection_str, start_pos)
+
+    def _parse_simple_value(self, connection_str: str, start_pos: int) -> Tuple[str, int]:
+        """
+        Parse a simple (non-braced) value up to the next semicolon.
+
+        Args:
+            connection_str: The connection string
+            start_pos: Starting position of the value
+
+        Returns:
+            Tuple of (parsed_value, new_position)
+        """
+        str_len = len(connection_str)
+        value_start = start_pos
+
+        # Read characters until we hit a semicolon or end of string
+        while start_pos < str_len and connection_str[start_pos] != ';':
+            start_pos += 1
+
+        # Extract the value and strip trailing whitespace
+        value = connection_str[value_start:start_pos].rstrip()
+        return value, start_pos
+
+    def _parse_braced_value(self, connection_str: str, start_pos: int) -> Tuple[str, int]:
+        """
+        Parse a braced value with proper handling of escaped braces.
+ + Braced values: + - Start with '{' and end with '}' + - '}' inside the value is escaped as '}}' + - '{' inside the value is escaped as '{{' + - Can contain semicolons and other special characters + + Args: + connection_str: The connection string + start_pos: Starting position (should point to opening '{') + + Returns: + Tuple of (parsed_value, new_position) + + Raises: + ValueError: If the braced value is not closed (missing '}') + """ + str_len = len(connection_str) + brace_start_pos = start_pos + + # Skip the opening '{' + start_pos += 1 + + # Build the value character by character + value = [] + + while start_pos < str_len: + ch = connection_str[start_pos] + + if ch == '}': + # Check if next character is also '}' (escaped brace) + if start_pos + 1 < str_len and connection_str[start_pos + 1] == '}': + # Escaped right brace: '}}' → '}' + value.append('}') + start_pos += 2 + else: + # Single '}' means end of braced value + start_pos += 1 + return ''.join(value), start_pos + elif ch == '{': + # Check if it's an escaped left brace + if start_pos + 1 < str_len and connection_str[start_pos + 1] == '{': + # Escaped left brace: '{{' → '{' + value.append('{') + start_pos += 2 + else: + # Single '{' inside braced value - keep it as is + value.append(ch) + start_pos += 1 + else: + # Regular character + value.append(ch) + start_pos += 1 + + # Reached end without finding closing '}' + raise ValueError(f"Unclosed braced value starting at position {brace_start_pos}") diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 9526d1584..c495d0500 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -124,104 +124,43 @@ def test_connection(db_connection): def test_construct_connection_string(db_connection): # Check if the connection string is constructed correctly with kwargs - conn_str = db_connection._construct_connection_string( - host="localhost", - user="me", - password="mypwd", - database="mydb", - encrypt="yes", - 
trust_server_certificate="yes", - ) - assert ( - "Server=localhost;" in conn_str - ), "Connection string should contain 'Server=localhost;'" - assert "Uid=me;" in conn_str, "Connection string should contain 'Uid=me;'" - assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'" - assert ( - "Database=mydb;" in conn_str - ), "Connection string should contain 'Database=mydb;'" - assert "Encrypt=yes;" in conn_str, "Connection string should contain 'Encrypt=yes;'" - assert ( - "TrustServerCertificate=yes;" in conn_str - ), "Connection string should contain 'TrustServerCertificate=yes;'" - assert ( - "APP=MSSQL-Python" in conn_str - ), "Connection string should contain 'APP=MSSQL-Python'" - assert ( - "Driver={ODBC Driver 18 for SQL Server}" in conn_str - ), "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'" - assert ( - "Driver={ODBC Driver 18 for SQL Server};;APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" - == conn_str - ), "Connection string is incorrect" - + conn_str = db_connection._construct_connection_string(host="localhost", user="me", password="mypwd", database="mydb", encrypt="yes", trust_server_certificate="yes") + # With the new allow-list implementation, parameters are sorted + assert "Server=localhost" in conn_str, "Connection string should contain 'Server=localhost'" + assert "Uid=me" in conn_str, "Connection string should contain 'Uid=me'" + assert "Pwd=mypwd" in conn_str, "Connection string should contain 'Pwd=mypwd'" + assert "Database=mydb" in conn_str, "Connection string should contain 'Database=mydb'" + assert "Encrypt=yes" in conn_str, "Connection string should contain 'Encrypt=yes'" + assert "TrustServerCertificate=yes" in conn_str, "Connection string should contain 'TrustServerCertificate=yes'" + assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'" + assert "Driver={ODBC Driver 18 for SQL Server}" in 
conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"

 def test_connection_string_with_attrs_before(db_connection):
     # Check if the connection string is constructed correctly with attrs_before
-    conn_str = db_connection._construct_connection_string(
-        host="localhost",
-        user="me",
-        password="mypwd",
-        database="mydb",
-        encrypt="yes",
-        trust_server_certificate="yes",
-        attrs_before={1256: "token"},
-    )
-    assert (
-        "Server=localhost;" in conn_str
-    ), "Connection string should contain 'Server=localhost;'"
-    assert "Uid=me;" in conn_str, "Connection string should contain 'Uid=me;'"
-    assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'"
-    assert (
-        "Database=mydb;" in conn_str
-    ), "Connection string should contain 'Database=mydb;'"
-    assert "Encrypt=yes;" in conn_str, "Connection string should contain 'Encrypt=yes;'"
-    assert (
-        "TrustServerCertificate=yes;" in conn_str
-    ), "Connection string should contain 'TrustServerCertificate=yes;'"
-    assert (
-        "APP=MSSQL-Python" in conn_str
-    ), "Connection string should contain 'APP=MSSQL-Python'"
-    assert (
-        "Driver={ODBC Driver 18 for SQL Server}" in conn_str
-    ), "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"
-    assert (
-        "{1256: token}" not in conn_str
-    ), "Connection string should not contain '{1256: token}'"
-
+    conn_str = db_connection._construct_connection_string(host="localhost", user="me", password="mypwd", database="mydb", encrypt="yes", trust_server_certificate="yes", attrs_before={1256: "token"})
+    # With the new allow-list implementation, parameters are sorted
+    assert "Server=localhost" in conn_str, "Connection string should contain 'Server=localhost'"
+    assert "Uid=me" in conn_str, "Connection string should contain 'Uid=me'"
+    assert "Pwd=mypwd" in conn_str, "Connection string should contain 'Pwd=mypwd'"
+    assert "Database=mydb" in conn_str, "Connection string should contain 'Database=mydb'"
+    assert "Encrypt=yes"
in conn_str, "Connection string should contain 'Encrypt=yes'"
+    assert "TrustServerCertificate=yes" in conn_str, "Connection string should contain 'TrustServerCertificate=yes'"
+    assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'"
+    assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"
+    assert "{1256: token}" not in conn_str, "Connection string should not contain '{1256: token}'"

 def test_connection_string_with_odbc_param(db_connection):
     # Check if the connection string is constructed correctly with ODBC parameters
-    conn_str = db_connection._construct_connection_string(
-        server="localhost",
-        uid="me",
-        pwd="mypwd",
-        database="mydb",
-        encrypt="yes",
-        trust_server_certificate="yes",
-    )
-    assert (
-        "Server=localhost;" in conn_str
-    ), "Connection string should contain 'Server=localhost;'"
-    assert "Uid=me;" in conn_str, "Connection string should contain 'Uid=me;'"
-    assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'"
-    assert (
-        "Database=mydb;" in conn_str
-    ), "Connection string should contain 'Database=mydb;'"
-    assert "Encrypt=yes;" in conn_str, "Connection string should contain 'Encrypt=yes;'"
-    assert (
-        "TrustServerCertificate=yes;" in conn_str
-    ), "Connection string should contain 'TrustServerCertificate=yes;'"
-    assert (
-        "APP=MSSQL-Python" in conn_str
-    ), "Connection string should contain 'APP=MSSQL-Python'"
-    assert (
-        "Driver={ODBC Driver 18 for SQL Server}" in conn_str
-    ), "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"
-    assert (
-        "Driver={ODBC Driver 18 for SQL Server};;APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;"
-        == conn_str
-    ), "Connection string is incorrect"
+    conn_str = db_connection._construct_connection_string(server="localhost", uid="me", pwd="mypwd", database="mydb", encrypt="yes",
trust_server_certificate="yes")
+    # With the new allow-list implementation, parameters are sorted
+    assert "Server=localhost" in conn_str, "Connection string should contain 'Server=localhost'"
+    assert "Uid=me" in conn_str, "Connection string should contain 'Uid=me'"
+    assert "Pwd=mypwd" in conn_str, "Connection string should contain 'Pwd=mypwd'"
+    assert "Database=mydb" in conn_str, "Connection string should contain 'Database=mydb'"
+    assert "Encrypt=yes" in conn_str, "Connection string should contain 'Encrypt=yes'"
+    assert "TrustServerCertificate=yes" in conn_str, "Connection string should contain 'TrustServerCertificate=yes'"
+    assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'"
+    assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"


 def test_autocommit_default(db_connection):
diff --git a/tests/test_006_exceptions.py b/tests/test_006_exceptions.py
index ef7056690..1b368838d 100644
--- a/tests/test_006_exceptions.py
+++ b/tests/test_006_exceptions.py
@@ -15,6 +15,7 @@
     raise_exception,
     truncate_error_message,
 )
+from mssql_python import ConnectionStringParseError


 def drop_table_if_exists(cursor, table_name):
@@ -193,15 +194,11 @@ def test_foreign_key_constraint_error(cursor, db_connection):


 def test_connection_error():
-    # RuntimeError is raised on Windows, while on MacOS it raises OperationalError
-    # In MacOS the error goes by "Client unable to establish connection"
-    # In Windows it goes by "Neither DSN nor SERVER keyword supplied"
-    # TODO: Make this test platform independent
-    with pytest.raises((RuntimeError, OperationalError)) as excinfo:
+    # The new connection string parser now validates the connection string before passing to ODBC
+    # Invalid strings like "InvalidConnectionString" (missing key=value format) will raise ConnectionStringParseError
+    with pytest.raises(ConnectionStringParseError) as excinfo:
         connect("InvalidConnectionString")
-    assert "Client unable to establish connection" in str(
-        excinfo.value
-    ) or "Neither DSN nor SERVER keyword supplied" in str(excinfo.value)
+    assert "Incomplete specification" in str(excinfo.value) or "has no value" in str(excinfo.value)


 def test_truncate_error_message_successful_cases():
diff --git a/tests/test_010_connection_string_parser.py b/tests/test_010_connection_string_parser.py
new file mode 100644
index 000000000..6b257e930
--- /dev/null
+++ b/tests/test_010_connection_string_parser.py
@@ -0,0 +1,309 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Unit tests for _ConnectionStringParser (internal).
+"""
+
+import pytest
+from mssql_python.connection_string_parser import _ConnectionStringParser, ConnectionStringParseError
+from mssql_python.connection_string_allowlist import ConnectionStringAllowList
+
+
+class TestConnectionStringParser:
+    """Unit tests for _ConnectionStringParser."""
+
+    def test_parse_empty_string(self):
+        """Test parsing an empty string returns empty dict."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("")
+        assert result == {}
+
+    def test_parse_whitespace_only(self):
+        """Test parsing whitespace-only connection string."""
+        parser = _ConnectionStringParser()
+        result = parser.parse(" \t ")
+        assert result == {}
+
+    def test_parse_simple_params(self):
+        """Test parsing simple key=value pairs."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Server=localhost;Database=mydb")
+        assert result == {
+            'server': 'localhost',
+            'database': 'mydb'
+        }
+
+    def test_parse_single_param(self):
+        """Test parsing a single parameter."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Server=localhost")
+        assert result == {'server': 'localhost'}
+
+    def test_parse_trailing_semicolon(self):
+        """Test parsing with trailing semicolon."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Server=localhost;")
+        assert result ==
{'server': 'localhost'}
+
+    def test_parse_multiple_semicolons(self):
+        """Test parsing with multiple consecutive semicolons."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Server=localhost;;Database=mydb")
+        assert result == {'server': 'localhost', 'database': 'mydb'}
+
+    def test_parse_braced_value_with_semicolon(self):
+        """Test parsing braced values containing semicolons."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Server={;local;host};Database=mydb")
+        assert result == {
+            'server': ';local;host',
+            'database': 'mydb'
+        }
+
+    def test_parse_braced_value_with_escaped_right_brace(self):
+        """Test parsing braced values with escaped }}."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("PWD={p}}w{{d}")
+        assert result == {'pwd': 'p}w{d'}
+
+    def test_parse_braced_value_with_all_escapes(self):
+        """Test parsing braced values with both {{ and }} escapes."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Value={test}}{{escape}")
+        assert result == {'value': 'test}{escape'}
+
+    def test_parse_empty_value(self):
+        """Test parsing parameter with empty value."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Server=;Database=mydb")
+        assert result == {'server': '', 'database': 'mydb'}
+
+    def test_parse_empty_braced_value(self):
+        """Test parsing parameter with empty braced value."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Server={};Database=mydb")
+        assert result == {'server': '', 'database': 'mydb'}
+
+    def test_parse_whitespace_around_key(self):
+        """Test parsing with whitespace around keys."""
+        parser = _ConnectionStringParser()
+        result = parser.parse(" Server =localhost; Database =mydb")
+        assert result == {'server': 'localhost', 'database': 'mydb'}
+
+    def test_parse_whitespace_in_simple_value(self):
+        """Test parsing simple value with trailing whitespace."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Server=localhost ;Database=mydb")
+        assert result == {'server': 'localhost', 'database': 'mydb'}
+
+    def test_parse_case_insensitive_keys(self):
+        """Test that keys are normalized to lowercase."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("SERVER=localhost;DatABase=mydb")
+        assert result == {'server': 'localhost', 'database': 'mydb'}
+
+    def test_parse_special_chars_in_simple_value(self):
+        """Test parsing simple values with special characters (not ; { })."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Server=server:1433;User=domain\\user")
+        assert result == {'server': 'server:1433', 'user': 'domain\\user'}
+
+    def test_parse_complex_connection_string(self):
+        """Test parsing a complex realistic connection string."""
+        parser = _ConnectionStringParser()
+        conn_str = "Server=tcp:server.database.windows.net,1433;Database=mydb;UID=user@server;PWD={p@ss;w}}rd};Encrypt=yes"
+        result = parser.parse(conn_str)
+        assert result == {
+            'server': 'tcp:server.database.windows.net,1433',
+            'database': 'mydb',
+            'uid': 'user@server',
+            'pwd': 'p@ss;w}rd',  # }} escapes to single }
+            'encrypt': 'yes'
+        }
+
+    def test_parse_driver_parameter(self):
+        """Test parsing Driver parameter with braced value."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Driver={ODBC Driver 18 for SQL Server};Server=localhost")
+        assert result == {
+            'driver': 'ODBC Driver 18 for SQL Server',
+            'server': 'localhost'
+        }
+
+    def test_parse_braced_value_with_left_brace(self):
+        """Test parsing braced value containing unescaped single {."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Value={test{value}")
+        assert result == {'value': 'test{value'}
+
+    def test_parse_braced_value_double_left_brace(self):
+        """Test parsing braced value with escaped {{ (left brace)."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Value={test{{value}")
+        assert result == {'value': 'test{value'}
+
+    def test_parse_unicode_characters(self):
+        """Test parsing values with unicode
characters."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Database=数据库;Server=сервер")
+        assert result == {'database': '数据库', 'server': 'сервер'}
+
+    def test_parse_equals_in_braced_value(self):
+        """Test parsing braced value containing equals sign."""
+        parser = _ConnectionStringParser()
+        result = parser.parse("Value={key=value}")
+        assert result == {'value': 'key=value'}
+
+
+class TestConnectionStringParserErrors:
+    """Test error handling in ConnectionStringParser."""
+
+    def test_error_duplicate_keys(self):
+        """Test that duplicate keys raise an error."""
+        parser = _ConnectionStringParser()
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Server=first;Server=second;Server=third")
+
+        assert "Duplicate keyword 'server'" in str(exc_info.value)
+        assert len(exc_info.value.errors) == 2  # Two duplicates (second and third)
+
+    def test_error_incomplete_specification_no_equals(self):
+        """Test that keyword without '=' raises an error."""
+        parser = _ConnectionStringParser()
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Server;Database=mydb")
+
+        assert "Incomplete specification" in str(exc_info.value)
+        assert "'server'" in str(exc_info.value).lower()
+
+    def test_error_incomplete_specification_trailing(self):
+        """Test that trailing keyword without value raises an error."""
+        parser = _ConnectionStringParser()
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Server=localhost;Database")
+
+        assert "Incomplete specification" in str(exc_info.value)
+        assert "'database'" in str(exc_info.value).lower()
+
+    def test_error_empty_key(self):
+        """Test that empty keyword raises an error."""
+        parser = _ConnectionStringParser()
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("=value;Server=localhost")
+
+        assert "Empty keyword" in str(exc_info.value)
+
+    def test_error_unclosed_braced_value(self):
+        """Test that unclosed braces raise an
error."""
+        parser = _ConnectionStringParser()
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("PWD={unclosed;Server=localhost")
+
+        assert "Unclosed braced value" in str(exc_info.value)
+
+    def test_error_multiple_issues_collected(self):
+        """Test that multiple errors are collected and reported together."""
+        parser = _ConnectionStringParser()
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Server=first;InvalidEntry;Server=second;Database")
+
+        # Should have: incomplete spec for InvalidEntry, duplicate Server, incomplete spec for Database
+        assert len(exc_info.value.errors) >= 3
+        assert "Incomplete specification" in str(exc_info.value)
+        assert "Duplicate keyword" in str(exc_info.value)
+
+    def test_error_unknown_keyword_with_allowlist(self):
+        """Test that unknown keywords are flagged when allowlist is provided."""
+        allowlist = ConnectionStringAllowList()
+        parser = _ConnectionStringParser(allowlist=allowlist)
+
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Server=localhost;UnknownParam=value")
+
+        assert "Unknown keyword 'unknownparam'" in str(exc_info.value)
+
+    def test_error_multiple_unknown_keywords(self):
+        """Test that multiple unknown keywords are all flagged."""
+        allowlist = ConnectionStringAllowList()
+        parser = _ConnectionStringParser(allowlist=allowlist)
+
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Server=localhost;Unknown1=val1;Database=mydb;Unknown2=val2")
+
+        errors_str = str(exc_info.value)
+        assert "Unknown keyword 'unknown1'" in errors_str
+        assert "Unknown keyword 'unknown2'" in errors_str
+
+    def test_error_combined_unknown_and_duplicate(self):
+        """Test that unknown keywords and duplicates are both flagged."""
+        allowlist = ConnectionStringAllowList()
+        parser = _ConnectionStringParser(allowlist=allowlist)
+
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Server=first;UnknownParam=value;Server=second")
+
+        errors_str = str(exc_info.value)
+        assert "Unknown keyword 'unknownparam'" in errors_str
+        assert "Duplicate keyword 'server'" in errors_str
+
+    def test_valid_with_allowlist(self):
+        """Test that valid keywords pass when allowlist is provided."""
+        allowlist = ConnectionStringAllowList()
+        parser = _ConnectionStringParser(allowlist=allowlist)
+
+        # These are all valid keywords in the allowlist
+        result = parser.parse("Server=localhost;Database=mydb;UID=user;PWD=pass")
+        assert result == {
+            'server': 'localhost',
+            'database': 'mydb',
+            'uid': 'user',
+            'pwd': 'pass'
+        }
+
+    def test_no_validation_without_allowlist(self):
+        """Test that unknown keywords are allowed when no allowlist is provided."""
+        parser = _ConnectionStringParser()  # No allowlist
+
+        # Should parse successfully even with unknown keywords
+        result = parser.parse("Server=localhost;MadeUpKeyword=value")
+        assert result == {
+            'server': 'localhost',
+            'madeupkeyword': 'value'
+        }
+
+
+class TestConnectionStringParserEdgeCases:
+    """Test edge cases and boundary conditions."""
+
+    def test_error_all_duplicates(self):
+        """Test string with only duplicates."""
+        parser = _ConnectionStringParser()
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Server=a;Server=b;Server=c")
+
+        # First occurrence is kept, other two are duplicates
+        assert len(exc_info.value.errors) == 2
+
+    def test_error_mixed_valid_and_errors(self):
+        """Test that valid params are parsed even when errors exist."""
+        parser = _ConnectionStringParser()
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Server=localhost;BadEntry;Database=mydb;Server=dup")
+
+        # Should detect incomplete and duplicate
+        assert len(exc_info.value.errors) >= 2
+
+    def test_normalization_still_works(self):
+        """Test that key normalization to lowercase still works."""
+        parser = _ConnectionStringParser()
+        result =
parser.parse("SERVER=srv;DaTaBaSe=db")
+        assert result == {'server': 'srv', 'database': 'db'}
+
+    def test_error_duplicate_after_normalization(self):
+        """Test that duplicates are detected after normalization."""
+        parser = _ConnectionStringParser()
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Server=first;SERVER=second")
+
+        assert "Duplicate keyword 'server'" in str(exc_info.value)
diff --git a/tests/test_011_connection_string_allowlist.py b/tests/test_011_connection_string_allowlist.py
new file mode 100644
index 000000000..4e1cd3728
--- /dev/null
+++ b/tests/test_011_connection_string_allowlist.py
@@ -0,0 +1,184 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Unit tests for ConnectionStringAllowList.
+"""
+
+import pytest
+from mssql_python.connection_string_allowlist import ConnectionStringAllowList
+
+
+class TestConnectionStringAllowList:
+    """Unit tests for ConnectionStringAllowList."""
+
+    def test_normalize_key_server(self):
+        """Test normalization of 'server' and its synonyms."""
+        assert ConnectionStringAllowList.normalize_key('server') == 'Server'
+        assert ConnectionStringAllowList.normalize_key('SERVER') == 'Server'
+        assert ConnectionStringAllowList.normalize_key('Server') == 'Server'
+        assert ConnectionStringAllowList.normalize_key('address') == 'Server'
+        assert ConnectionStringAllowList.normalize_key('addr') == 'Server'
+        assert ConnectionStringAllowList.normalize_key('network address') == 'Server'
+
+    def test_normalize_key_authentication(self):
+        """Test normalization of authentication parameters."""
+        assert ConnectionStringAllowList.normalize_key('uid') == 'Uid'
+        assert ConnectionStringAllowList.normalize_key('user id') == 'Uid'
+        assert ConnectionStringAllowList.normalize_key('user') == 'Uid'
+        assert ConnectionStringAllowList.normalize_key('pwd') == 'Pwd'
+        assert ConnectionStringAllowList.normalize_key('password') == 'Pwd'
+
+    def test_normalize_key_database(self):
+        """Test
normalization of database parameters."""
+        assert ConnectionStringAllowList.normalize_key('database') == 'Database'
+        assert ConnectionStringAllowList.normalize_key('initial catalog') == 'Database'
+
+    def test_normalize_key_encryption(self):
+        """Test normalization of encryption parameters."""
+        assert ConnectionStringAllowList.normalize_key('encrypt') == 'Encrypt'
+        assert ConnectionStringAllowList.normalize_key('trustservercertificate') == 'TrustServerCertificate'
+        assert ConnectionStringAllowList.normalize_key('trust server certificate') == 'TrustServerCertificate'
+
+    def test_normalize_key_timeout(self):
+        """Test normalization of timeout parameters."""
+        assert ConnectionStringAllowList.normalize_key('connection timeout') == 'Connection Timeout'
+        assert ConnectionStringAllowList.normalize_key('connect timeout') == 'Connection Timeout'
+        assert ConnectionStringAllowList.normalize_key('timeout') == 'Connection Timeout'
+        assert ConnectionStringAllowList.normalize_key('login timeout') == 'Login Timeout'
+
+    def test_normalize_key_mars(self):
+        """Test that MARS parameters are not in the allowlist."""
+        assert ConnectionStringAllowList.normalize_key('mars_connection') is None
+        assert ConnectionStringAllowList.normalize_key('mars connection') is None
+        assert ConnectionStringAllowList.normalize_key('multipleactiveresultsets') is None
+
+    def test_normalize_key_app(self):
+        """Test normalization of APP parameters."""
+        assert ConnectionStringAllowList.normalize_key('app') == 'APP'
+        assert ConnectionStringAllowList.normalize_key('application name') == 'APP'
+
+    def test_normalize_key_driver(self):
+        """Test normalization of Driver parameter."""
+        assert ConnectionStringAllowList.normalize_key('driver') == 'Driver'
+        assert ConnectionStringAllowList.normalize_key('DRIVER') == 'Driver'
+
+    def test_normalize_key_not_allowed(self):
+        """Test normalization of disallowed keys returns None."""
+        assert ConnectionStringAllowList.normalize_key('BadParam') is None
+        assert ConnectionStringAllowList.normalize_key('UnsupportedParameter') is None
+        assert ConnectionStringAllowList.normalize_key('RandomKey') is None
+
+    def test_normalize_key_whitespace(self):
+        """Test normalization handles whitespace."""
+        assert ConnectionStringAllowList.normalize_key(' server ') == 'Server'
+        assert ConnectionStringAllowList.normalize_key(' uid ') == 'Uid'
+
+    def test_filter_params_allows_good_params(self):
+        """Test filtering allows known parameters."""
+        params = {'server': 'localhost', 'database': 'mydb', 'encrypt': 'yes'}
+        filtered = ConnectionStringAllowList.filter_params(params, warn_rejected=False)
+        assert 'Server' in filtered
+        assert 'Database' in filtered
+        assert 'Encrypt' in filtered
+        assert filtered['Server'] == 'localhost'
+        assert filtered['Database'] == 'mydb'
+        assert filtered['Encrypt'] == 'yes'
+
+    def test_filter_params_rejects_bad_params(self):
+        """Test filtering rejects unknown parameters."""
+        params = {'server': 'localhost', 'badparam': 'value', 'anotherbad': 'test'}
+        filtered = ConnectionStringAllowList.filter_params(params, warn_rejected=False)
+        assert 'Server' in filtered
+        assert 'badparam' not in filtered
+        assert 'anotherbad' not in filtered
+
+    def test_filter_params_normalizes_keys(self):
+        """Test filtering normalizes parameter keys."""
+        params = {'server': 'localhost', 'uid': 'user', 'pwd': 'pass'}
+        filtered = ConnectionStringAllowList.filter_params(params, warn_rejected=False)
+        assert 'Server' in filtered
+        assert 'Uid' in filtered
+        assert 'Pwd' in filtered
+        assert 'server' not in filtered  # Original key should not be present
+
+    def test_filter_params_handles_synonyms(self):
+        """Test filtering handles parameter synonyms correctly."""
+        params = {
+            'address': 'server1',
+            'user': 'testuser',
+            'initial catalog': 'testdb',
+            'connection timeout': '30'
+        }
+        filtered = ConnectionStringAllowList.filter_params(params, warn_rejected=False)
+        assert filtered['Server'] == 'server1'
+        assert
filtered['Uid'] == 'testuser'
+        assert filtered['Database'] == 'testdb'
+        assert filtered['Connection Timeout'] == '30'
+
+    def test_filter_params_empty_dict(self):
+        """Test filtering empty parameter dictionary."""
+        filtered = ConnectionStringAllowList.filter_params({}, warn_rejected=False)
+        assert filtered == {}
+
+    def test_filter_params_removes_driver(self):
+        """Test that Driver parameter is filtered out (controlled by driver)."""
+        params = {'driver': '{Some Driver}', 'server': 'localhost'}
+        filtered = ConnectionStringAllowList.filter_params(params, warn_rejected=False)
+        assert 'Driver' not in filtered
+        assert 'Server' in filtered
+
+    def test_filter_params_removes_app(self):
+        """Test that APP parameter is filtered out (controlled by driver)."""
+        params = {'app': 'MyApp', 'server': 'localhost'}
+        filtered = ConnectionStringAllowList.filter_params(params, warn_rejected=False)
+        assert 'APP' not in filtered
+        assert 'Server' in filtered
+
+    def test_filter_params_mixed_case_keys(self):
+        """Test filtering with mixed case keys."""
+        params = {'SERVER': 'localhost', 'DataBase': 'mydb', 'EncRypt': 'yes'}
+        filtered = ConnectionStringAllowList.filter_params(params, warn_rejected=False)
+        assert 'Server' in filtered
+        assert 'Database' in filtered
+        assert 'Encrypt' in filtered
+
+    def test_filter_params_preserves_values(self):
+        """Test that filtering preserves original values unchanged."""
+        params = {
+            'server': 'localhost:1433',
+            'database': 'MyDatabase',
+            'pwd': 'P@ssw0rd!123'
+        }
+        filtered = ConnectionStringAllowList.filter_params(params, warn_rejected=False)
+        assert filtered['Server'] == 'localhost:1433'
+        assert filtered['Database'] == 'MyDatabase'
+        assert filtered['Pwd'] == 'P@ssw0rd!123'
+
+    def test_filter_params_application_intent(self):
+        """Test filtering application intent parameters."""
+        params = {'applicationintent': 'ReadOnly', 'application intent': 'ReadWrite'}
+        filtered = ConnectionStringAllowList.filter_params(params,
warn_rejected=False)
+        # Last one wins (application intent → ReadWrite)
+        assert filtered['ApplicationIntent'] == 'ReadWrite'
+
+    def test_filter_params_failover_partner(self):
+        """Test filtering failover partner parameters."""
+        params = {'failover partner': 'backup.server.com'}
+        filtered = ConnectionStringAllowList.filter_params(params, warn_rejected=False)
+        assert filtered['Failover_Partner'] == 'backup.server.com'
+
+    def test_filter_params_column_encryption(self):
+        """Test that column encryption parameter is not in the allowlist."""
+        params = {'columnencryption': 'Enabled', 'column encryption': 'Disabled'}
+        filtered = ConnectionStringAllowList.filter_params(params, warn_rejected=False)
+        # Column encryption is not in the allowlist, so it should be filtered out
+        assert 'ColumnEncryption' not in filtered
+        assert len(filtered) == 0
+
+    def test_filter_params_multisubnetfailover(self):
+        """Test filtering multi-subnet failover parameters."""
+        params = {'multisubnetfailover': 'yes', 'multi subnet failover': 'no'}
+        filtered = ConnectionStringAllowList.filter_params(params, warn_rejected=False)
+        # Last one wins
+        assert filtered['MultiSubnetFailover'] == 'no'
diff --git a/tests/test_012_connection_string_integration.py b/tests/test_012_connection_string_integration.py
new file mode 100644
index 000000000..c193b9cdd
--- /dev/null
+++ b/tests/test_012_connection_string_integration.py
@@ -0,0 +1,620 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Integration tests for connection string allow-list feature.
+
+These tests verify end-to-end behavior of the parser, filter, and builder pipeline.
+"""
+
+import pytest
+import os
+from unittest.mock import patch, MagicMock
+from mssql_python.connection_string_parser import _ConnectionStringParser, ConnectionStringParseError
+from mssql_python.connection_string_allowlist import ConnectionStringAllowList
+from mssql_python.connection_string_builder import _ConnectionStringBuilder
+from mssql_python import connect
+from mssql_python.connection import Connection
+from mssql_python.exceptions import DatabaseError, InterfaceError
+
+
+class TestConnectionStringIntegration:
+    """Integration tests for the complete connection string flow."""
+
+    def test_parse_filter_build_simple(self):
+        """Test complete flow with simple parameters."""
+        # Parse
+        parser = _ConnectionStringParser()
+        parsed = parser.parse("Server=localhost;Database=mydb;Encrypt=yes")
+
+        # Filter
+        filtered = ConnectionStringAllowList.filter_params(parsed, warn_rejected=False)
+
+        # Build
+        builder = _ConnectionStringBuilder(filtered)
+        builder.add_param('Driver', 'ODBC Driver 18 for SQL Server')
+        builder.add_param('APP', 'MSSQL-Python')
+        result = builder.build()
+
+        # Verify
+        assert 'Driver={ODBC Driver 18 for SQL Server}' in result
+        assert 'Server=localhost' in result
+        assert 'Database=mydb' in result
+        assert 'Encrypt=yes' in result
+        assert 'APP=MSSQL-Python' in result
+
+    def test_parse_filter_build_with_unsupported_param(self):
+        """Test that unsupported parameters are flagged as errors with allowlist."""
+        # Parse with allowlist
+        allowlist = ConnectionStringAllowList()
+        parser = _ConnectionStringParser(allowlist=allowlist)
+
+        # Should raise error for unknown keyword
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Server=localhost;Database=mydb;UnsupportedParam=value")
+
+        assert "Unknown keyword 'unsupportedparam'" in str(exc_info.value)
+
+    def test_parse_filter_build_with_braced_values(self):
+        """Test complete flow with braced values and special characters."""
+        # Parse
+        parser =
_ConnectionStringParser()
+        parsed = parser.parse("Server={local;host};PWD={p@ss;w}}rd}")
+
+        # Filter
+        filtered = ConnectionStringAllowList.filter_params(parsed, warn_rejected=False)
+
+        # Build
+        builder = _ConnectionStringBuilder(filtered)
+        builder.add_param('Driver', 'ODBC Driver 18 for SQL Server')
+        result = builder.build()
+
+        # Verify - values with special chars should be re-escaped
+        assert 'Driver={ODBC Driver 18 for SQL Server}' in result
+        assert 'Server={local;host}' in result
+        assert 'Pwd={p@ss;w}}rd}' in result or 'PWD={p@ss;w}}rd}' in result
+
+    def test_parse_filter_build_synonym_normalization(self):
+        """Test that parameter synonyms are normalized."""
+        # Parse
+        parser = _ConnectionStringParser()
+        parsed = parser.parse("address=server1;user=testuser;initial catalog=testdb")
+
+        # Filter (normalizes synonyms)
+        filtered = ConnectionStringAllowList.filter_params(parsed, warn_rejected=False)
+
+        # Build
+        builder = _ConnectionStringBuilder(filtered)
+        builder.add_param('Driver', 'ODBC Driver 18 for SQL Server')
+        result = builder.build()
+
+        # Verify - should use canonical names
+        assert 'Server=server1' in result
+        assert 'Uid=testuser' in result
+        assert 'Database=testdb' in result
+        # Original names should not appear
+        assert 'address' not in result.lower()
+        assert 'user=' not in result.lower()
+        assert 'initial catalog' not in result.lower()
+
+    def test_parse_filter_build_driver_and_app_reserved(self):
+        """Test that Driver and APP in connection string raise errors."""
+        # Parser should reject Driver and APP as reserved keywords
+        allowlist = ConnectionStringAllowList()
+        parser = _ConnectionStringParser(allowlist=allowlist)
+
+        # Test with APP
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("APP=UserApp;Server=localhost")
+        error_lower = str(exc_info.value).lower()
+        assert "reserved keyword" in error_lower
+        assert "'app'" in error_lower
+
+        # Test with Driver
+        with pytest.raises(ConnectionStringParseError)
as exc_info:
+            parser.parse("Driver={Some Other Driver};Server=localhost")
+        error_lower = str(exc_info.value).lower()
+        assert "reserved keyword" in error_lower
+        assert "'driver'" in error_lower
+
+        # Test with both
+        with pytest.raises(ConnectionStringParseError) as exc_info:
+            parser.parse("Driver={Some Other Driver};APP=UserApp;Server=localhost")
+        error_str = str(exc_info.value).lower()
+        assert "reserved keyword" in error_str
+        # Should have errors for both
+        assert len(exc_info.value.errors) == 2
+
+    def test_parse_filter_build_empty_input(self):
+        """Test complete flow with empty input."""
+        # Parse
+        parser = _ConnectionStringParser()
+        parsed = parser.parse("")
+
+        # Filter
+        filtered = ConnectionStringAllowList.filter_params(parsed, warn_rejected=False)
+
+        # Build
+        builder = _ConnectionStringBuilder(filtered)
+        builder.add_param('Driver', 'ODBC Driver 18 for SQL Server')
+        result = builder.build()
+
+        # Verify - should only have Driver
+        assert result == 'Driver={ODBC Driver 18 for SQL Server}'
+
+    def test_parse_filter_build_complex_realistic(self):
+        """Test complete flow with complex realistic connection string."""
+        # Parse
+        parser = _ConnectionStringParser()
+        conn_str = "Server=tcp:server.database.windows.net,1433;Database=mydb;UID=user@server;PWD={P@ss;w}}rd};Encrypt=yes;TrustServerCertificate=no;Connection Timeout=30"
+        parsed = parser.parse(conn_str)
+
+        # Filter
+        filtered = ConnectionStringAllowList.filter_params(parsed, warn_rejected=False)
+
+        # Build
+        builder = _ConnectionStringBuilder(filtered)
+        builder.add_param('Driver', 'ODBC Driver 18 for SQL Server')
+        builder.add_param('APP', 'MSSQL-Python')
+        result = builder.build()
+
+        # Verify key parameters are present
+        assert 'Driver={ODBC Driver 18 for SQL Server}' in result
+        assert 'Server=tcp:server.database.windows.net,1433' in result
+        assert 'Database=mydb' in result
+        assert 'Uid=user@server' in result
+        assert 'Pwd={P@ss;w}}rd}' in result or 'PWD={P@ss;w}}rd}' in result
+        assert
'Encrypt=yes' in result + assert 'TrustServerCertificate=no' in result + assert 'Connection Timeout=30' in result + assert 'APP=MSSQL-Python' in result + + def test_parse_error_incomplete_specification(self): + """Test that incomplete specifications raise errors.""" + parser = _ConnectionStringParser() + + # Incomplete specification raises error + with pytest.raises(ConnectionStringParseError) as exc_info: + parser.parse("Server localhost;Database=mydb") + + assert "Incomplete specification" in str(exc_info.value) + assert "'server localhost'" in str(exc_info.value).lower() + + def test_parse_error_unclosed_brace(self): + """Test that unclosed braces raise errors.""" + parser = _ConnectionStringParser() + + # Unclosed brace raises error + with pytest.raises(ConnectionStringParseError) as exc_info: + parser.parse("PWD={unclosed;Server=localhost") + + assert "Unclosed braced value" in str(exc_info.value) + + def test_parse_error_duplicate_keywords(self): + """Test that duplicate keywords raise errors.""" + parser = _ConnectionStringParser() + + # Duplicate keywords raise error + with pytest.raises(ConnectionStringParseError) as exc_info: + parser.parse("Server=first;Server=second") + + assert "Duplicate keyword 'server'" in str(exc_info.value) + + def test_round_trip_preserves_values(self): + """Test that parsing and rebuilding preserves parameter values.""" + original_params = { + 'server': 'localhost:1433', + 'database': 'TestDB', + 'uid': 'testuser', + 'pwd': 'Test@123', + 'encrypt': 'yes' + } + + # Filter + filtered = ConnectionStringAllowList.filter_params(original_params, warn_rejected=False) + + # Build + builder = _ConnectionStringBuilder(filtered) + builder.add_param('Driver', 'ODBC Driver 18 for SQL Server') + result = builder.build() + + # Parse back + parser = _ConnectionStringParser() + parsed = parser.parse(result) + + # Verify values are preserved (keys are normalized to lowercase in parsing) + assert parsed['server'] == 'localhost:1433' + assert 
parsed['database'] == 'TestDB' + assert parsed['uid'] == 'testuser' + assert parsed['pwd'] == 'Test@123' + assert parsed['encrypt'] == 'yes' + assert parsed['driver'] == 'ODBC Driver 18 for SQL Server' + + def test_builder_escaping_is_correct(self): + """Test that builder correctly escapes special characters.""" + builder = _ConnectionStringBuilder() + builder.add_param('Server', 'local;host') + builder.add_param('PWD', 'p}w{d') + builder.add_param('Value', 'test;{value}') + result = builder.build() + + # Parse back to verify escaping worked + parser = _ConnectionStringParser() + parsed = parser.parse(result) + + assert parsed['server'] == 'local;host' + assert parsed['pwd'] == 'p}w{d' + assert parsed['value'] == 'test;{value}' + + def test_multiple_errors_collected(self): + """Test that multiple errors are collected and reported together.""" + parser = _ConnectionStringParser() + + # Multiple errors: incomplete spec, duplicate + with pytest.raises(ConnectionStringParseError) as exc_info: + parser.parse("Server=first;InvalidEntry;Server=second;Database") + + # Should have multiple errors + assert len(exc_info.value.errors) >= 3 + assert "Incomplete specification" in str(exc_info.value) + assert "Duplicate keyword" in str(exc_info.value) + + def test_parser_without_allowlist_accepts_unknown(self): + """Test that parser without allowlist accepts unknown keywords.""" + parser = _ConnectionStringParser() # No allowlist + + # Should parse successfully even with unknown keywords + result = parser.parse("Server=localhost;MadeUpKeyword=value") + assert result == { + 'server': 'localhost', + 'madeupkeyword': 'value' + } + + def test_parser_with_allowlist_rejects_unknown(self): + """Test that parser with allowlist rejects unknown keywords.""" + allowlist = ConnectionStringAllowList() + parser = _ConnectionStringParser(allowlist=allowlist) + + # Should raise error for unknown keyword + with pytest.raises(ConnectionStringParseError) as exc_info: + 
parser.parse("Server=localhost;MadeUpKeyword=value") + + assert "Unknown keyword 'madeupkeyword'" in str(exc_info.value) + + +class TestConnectAPIIntegration: + """Integration tests for the connect() API with connection string validation.""" + + def test_connect_with_unknown_keyword_raises_error(self): + """Test that connect() raises error for unknown keywords.""" + # connect() uses allowlist validation internally + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server=localhost;Database=test;UnknownKeyword=value") + + assert "Unknown keyword 'unknownkeyword'" in str(exc_info.value) + + def test_connect_with_duplicate_keywords_raises_error(self): + """Test that connect() raises error for duplicate keywords.""" + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server=first;Server=second;Database=test") + + assert "Duplicate keyword 'server'" in str(exc_info.value) + + def test_connect_with_incomplete_specification_raises_error(self): + """Test that connect() raises error for incomplete specifications.""" + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server localhost;Database=test") + + assert "Incomplete specification" in str(exc_info.value) + + def test_connect_with_unclosed_brace_raises_error(self): + """Test that connect() raises error for unclosed braces.""" + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("PWD={unclosed;Server=localhost") + + assert "Unclosed braced value" in str(exc_info.value) + + def test_connect_with_multiple_errors_collected(self): + """Test that connect() collects multiple errors.""" + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server=first;InvalidEntry;Server=second;Database") + + # Should have multiple errors + assert len(exc_info.value.errors) >= 3 + error_str = str(exc_info.value) + assert "Incomplete specification" in error_str + assert "Duplicate keyword" in error_str + + 
@patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_kwargs_override_connection_string(self, mock_ddbc_conn): + """Test that kwargs override connection string parameters.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + conn = connect("Server=original;Database=originaldb", + Server="overridden", + Database="overriddendb") + + # Verify the override worked + assert "overridden" in conn.connection_str.lower() + assert "overriddendb" in conn.connection_str.lower() + # Original values should not be in the final connection string. + # "original" is a substring of "originaldb", so a single check covers both values. + assert "original" not in conn.connection_str.lower() + + conn.close() + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_app_parameter_in_connection_string_raises_error(self, mock_ddbc_conn): + """Test that APP parameter in connection string raises ConnectionStringParseError.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + # User tries to set APP in connection string - should raise error + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server=localhost;APP=UserApp;Database=test") + + # Verify error message + error_lower = str(exc_info.value).lower() + assert "reserved keyword" in error_lower + assert "'app'" in error_lower + assert "controlled by the driver" in error_lower + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_app_parameter_in_kwargs_raises_error(self, mock_ddbc_conn): + """Test that APP parameter in kwargs raises ValueError.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + # User tries to set APP via kwargs - should raise ValueError + with pytest.raises(ValueError) as exc_info: + connect("Server=localhost;Database=test", APP="UserApp") + + assert "reserved and controlled by the driver" in str(exc_info.value) + assert "APP" in str(exc_info.value) 
or "app" in str(exc_info.value).lower() + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_driver_parameter_in_connection_string_raises_error(self, mock_ddbc_conn): + """Test that Driver parameter in connection string raises ConnectionStringParseError.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + # User tries to set Driver in connection string - should raise error + with pytest.raises(ConnectionStringParseError) as exc_info: + connect("Server=localhost;Driver={Some Other Driver};Database=test") + + # Verify error message + error_lower = str(exc_info.value).lower() + assert "reserved keyword" in error_lower + assert "'driver'" in error_lower + assert "controlled by the driver" in error_lower + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_driver_parameter_in_kwargs_raises_error(self, mock_ddbc_conn): + """Test that Driver parameter in kwargs raises ValueError.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + # User tries to set Driver via kwargs - should raise ValueError + with pytest.raises(ValueError) as exc_info: + connect("Server=localhost;Database=test", Driver="Some Other Driver") + + assert "reserved and controlled by the driver" in str(exc_info.value) + assert "Driver" in str(exc_info.value) + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_synonym_normalization(self, mock_ddbc_conn): + """Test that connect() normalizes parameter synonyms.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + conn = connect("address=server1;user=testuser;initial catalog=testdb") + + # Synonyms should be normalized to canonical names + conn_str_lower = conn.connection_str.lower() + assert "server=server1" in conn_str_lower + assert "uid=testuser" in conn_str_lower + assert "database=testdb" in conn_str_lower + # Original names should not appear + assert 
"address=" not in conn_str_lower + assert "user=" not in conn_str_lower + assert "initial catalog=" not in conn_str_lower + + conn.close() + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_kwargs_unknown_parameter_warned(self, mock_ddbc_conn): + """Test that unknown kwargs are warned about but don't raise errors during parsing.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + # Unknown kwargs are filtered out with a warning, but don't cause parse errors + # because kwargs bypass the parser's allowlist validation + conn = connect("Server=localhost", Database="test", UnknownParam="value") + + # UnknownParam should be filtered out (warned but not included) + conn_str_lower = conn.connection_str.lower() + assert "database=test" in conn_str_lower + assert "unknownparam" not in conn_str_lower + + conn.close() + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_empty_connection_string(self, mock_ddbc_conn): + """Test that connect() works with empty connection string and kwargs.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + conn = connect("", Server="localhost", Database="test") + + # Should have Server and Database from kwargs + conn_str_lower = conn.connection_str.lower() + assert "server=localhost" in conn_str_lower + assert "database=test" in conn_str_lower + assert "driver=" in conn_str_lower # Driver is always added + assert "app=mssql-python" in conn_str_lower # APP is always added + + conn.close() + + @patch('mssql_python.connection.ddbc_bindings.Connection') + def test_connect_special_characters_in_values(self, mock_ddbc_conn): + """Test that connect() properly handles special characters in parameter values.""" + # Mock the underlying ODBC connection + mock_ddbc_conn.return_value = MagicMock() + + conn = connect("Server={local;host};PWD={p@ss;w}}rd};Database=test") + + # Special characters should be preserved through 
parsing and building + # The connection string should properly escape them + assert "local;host" in conn.connection_str or "{local;host}" in conn.connection_str + assert "p@ss;w}rd" in conn.connection_str or "{p@ss;w}}rd}" in conn.connection_str + + conn.close() + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_connect_with_real_database(self, conn_str): + """Test that connect() works with a real database connection.""" + # This test only runs if DB_CONNECTION_STRING is set + conn = connect(conn_str) + assert conn is not None + + # Verify connection string has required parameters + assert "Driver=" in conn.connection_str or "driver=" in conn.connection_str + assert "APP=MSSQL-Python" in conn.connection_str or "app=mssql-python" in conn.connection_str.lower() + + # Test basic query execution + cursor = conn.cursor() + cursor.execute("SELECT 1 AS test") + row = cursor.fetchone() + assert row[0] == 1 + cursor.close() + + conn.close() + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_connect_kwargs_override_with_real_database(self, conn_str): + """Test that kwargs override works with a real database connection.""" + # Parse the original connection string to extract server + parser = _ConnectionStringParser() + original_params = parser.parse(conn_str) + + # Get the server from original connection for reconnection + server = original_params.get('server', 'localhost') + + # Create connection with overridden autocommit + conn = connect(conn_str, autocommit=True) + + # Verify connection works and autocommit is set + assert conn.autocommit == True + + # Verify connection string still has all required params + assert "Driver=" in conn.connection_str or "driver=" in conn.connection_str + assert "APP=MSSQL-Python" in conn.connection_str or "app=mssql-python" in conn.connection_str.lower() + + conn.close() + + @pytest.mark.skipif(not 
os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_connect_reserved_params_in_connection_string_raise_error(self, conn_str): + """Test that reserved params (Driver, APP) in connection string raise error.""" + # Try to add Driver to connection string - should raise error + with pytest.raises(ConnectionStringParseError) as exc_info: + test_conn_str = conn_str + ";Driver={User Driver}" + connect(test_conn_str) + assert "reserved keyword" in str(exc_info.value).lower() + assert "driver" in str(exc_info.value).lower() + + # Try to add APP to connection string - should raise error + with pytest.raises(ConnectionStringParseError) as exc_info: + test_conn_str = conn_str + ";APP=UserApp" + connect(test_conn_str) + assert "reserved keyword" in str(exc_info.value).lower() + assert "app" in str(exc_info.value).lower() + + # Try Application Name synonym + with pytest.raises(ConnectionStringParseError) as exc_info: + test_conn_str = conn_str + ";Application Name=UserApp" + connect(test_conn_str) + assert "reserved keyword" in str(exc_info.value).lower() + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_connect_reserved_params_in_kwargs_raise_error(self, conn_str): + """Test that reserved params (Driver, APP) in kwargs raise ValueError.""" + # Try to override Driver via kwargs - should raise ValueError + with pytest.raises(ValueError) as exc_info: + connect(conn_str, Driver="User Driver") + assert "reserved and controlled by the driver" in str(exc_info.value) + + # Try to override APP via kwargs - should raise ValueError + with pytest.raises(ValueError) as exc_info: + connect(conn_str, APP="UserApp") + assert "reserved and controlled by the driver" in str(exc_info.value) + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_app_name_received_by_sql_server(self, conn_str): + """Test that SQL Server 
receives the driver-controlled APP name 'MSSQL-Python'.""" + # Connect to SQL Server + conn = connect(conn_str) + + try: + # Query SQL Server to get the application name it received + cursor = conn.cursor() + cursor.execute("SELECT APP_NAME() AS app_name") + row = cursor.fetchone() + cursor.close() + + # Verify SQL Server received the driver-controlled application name + assert row is not None, "Failed to get APP_NAME() from SQL Server" + app_name_received = row[0] + + # SQL Server should have received 'MSSQL-Python', not any user-provided value + assert app_name_received == 'MSSQL-Python', \ + f"Expected SQL Server to receive 'MSSQL-Python', but got '{app_name_received}'" + + print(f"\n SQL Server correctly received APP_NAME: '{app_name_received}'") + finally: + conn.close() + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_app_name_in_connection_string_raises_error(self, conn_str): + """Test that APP in connection string raises ConnectionStringParseError.""" + # Connection strings with an APP parameter now raise an error instead of being silently filtered + + # Try to add APP to connection string + test_conn_str = conn_str + ";APP=UserDefinedApp" + + # Should raise ConnectionStringParseError + with pytest.raises(ConnectionStringParseError) as exc_info: + connect(test_conn_str) + + error_lower = str(exc_info.value).lower() + assert "reserved keyword" in error_lower + assert "'app'" in error_lower + assert "controlled by the driver" in error_lower + + print("\n APP in connection string correctly raised ConnectionStringParseError") + + @pytest.mark.skipif(not os.getenv('DB_CONNECTION_STRING'), + reason="Requires database connection string") + def test_app_name_in_kwargs_rejected_before_sql_server(self, conn_str): + """Test that APP in kwargs raises ValueError before even attempting to connect to SQL Server.""" + # As with APP in a connection string (which raises ConnectionStringParseError), kwargs with APP should raise an 
error + # This prevents the connection attempt entirely + + with pytest.raises(ValueError) as exc_info: + connect(conn_str, APP="UserDefinedApp") + + assert "reserved and controlled by the driver" in str(exc_info.value) + assert "APP" in str(exc_info.value) or "app" in str(exc_info.value).lower() + + print("\n APP in kwargs correctly raised ValueError before connecting to SQL Server") + + + + + +