Skip to content

Commit 10744c8

Browse files
committed
FIX: Encoding Decoding
1 parent ade9a05 commit 10744c8

File tree

5 files changed

+5619
-56
lines changed

5 files changed

+5619
-56
lines changed

mssql_python/connection.py

Lines changed: 135 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,12 @@
5454
INFO_TYPE_STRING_THRESHOLD: int = 10000
5555

5656
# UTF-16 encoding variants that should use SQL_WCHAR by default
57-
UTF16_ENCODINGS: frozenset[str] = frozenset(["utf-16", "utf-16le", "utf-16be"])
57+
# Note: "utf-16" with BOM is NOT included as it's problematic for SQL_WCHAR
58+
UTF16_ENCODINGS: frozenset[str] = frozenset(["utf-16le", "utf-16be"])
59+
60+
# Valid encoding characters (alphanumeric, dash, underscore only)
61+
import string
62+
VALID_ENCODING_CHARS: frozenset[str] = frozenset(string.ascii_letters + string.digits + '-_')
5863

5964

6065
def _validate_encoding(encoding: str) -> bool:
@@ -70,7 +75,17 @@ def _validate_encoding(encoding: str) -> bool:
7075
Note:
7176
Uses LRU cache to avoid repeated expensive codecs.lookup() calls.
7277
Cache size is limited to 128 entries which should cover most use cases.
78+
Also validates that encoding name only contains safe characters.
7379
"""
80+
# First check for dangerous characters (security validation)
81+
if not all(c in VALID_ENCODING_CHARS for c in encoding):
82+
return False
83+
84+
# Check length limit (prevent DOS)
85+
if len(encoding) > 100:
86+
return False
87+
88+
# Then check if it's a valid Python codec
7489
try:
7590
codecs.lookup(encoding)
7691
return True
@@ -226,6 +241,11 @@ def __init__(
226241
# Initialize output converters dictionary and its lock for thread safety
227242
self._output_converters = {}
228243
self._converters_lock = threading.Lock()
244+
245+
# Initialize encoding/decoding settings lock for thread safety
246+
# This lock protects both _encoding_settings and _decoding_settings dictionaries
247+
# to prevent race conditions when multiple threads are reading/writing encoding settings
248+
self._encoding_lock = threading.RLock() # RLock allows recursive locking
229249

230250
# Initialize search escape character
231251
self._searchescape = None
@@ -429,6 +449,20 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
429449
# Normalize encoding to casefold for more robust Unicode handling
430450
encoding = encoding.casefold()
431451
logger.debug("setencoding: Encoding normalized to %s", encoding)
452+
453+
# Reject 'utf-16' with BOM for SQL_WCHAR (ambiguous byte order)
454+
if encoding == "utf-16" and ctype == ConstantsDDBC.SQL_WCHAR.value:
455+
logger.debug(
456+
"warning",
457+
"utf-16 with BOM rejected for SQL_WCHAR",
458+
)
459+
raise ProgrammingError(
460+
driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
461+
ddbc_error=(
462+
"Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
463+
"Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
464+
),
465+
)
432466

433467
# Set default ctype based on encoding if not provided
434468
if ctype is None:
@@ -455,9 +489,34 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
455489
f"SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})"
456490
),
457491
)
492+
493+
# Validate that SQL_WCHAR ctype only used with UTF-16 encodings (not utf-16 with BOM)
494+
if ctype == ConstantsDDBC.SQL_WCHAR.value:
495+
if encoding == "utf-16":
496+
raise ProgrammingError(
497+
driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
498+
ddbc_error=(
499+
"Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
500+
"Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
501+
),
502+
)
503+
elif encoding not in UTF16_ENCODINGS:
504+
logger.debug(
505+
"warning",
506+
"Non-UTF-16 encoding %s attempted with SQL_WCHAR ctype",
507+
sanitize_user_input(encoding),
508+
)
509+
raise ProgrammingError(
510+
driver_error=f"SQL_WCHAR only supports UTF-16 encodings",
511+
ddbc_error=(
512+
f"Cannot use encoding '{encoding}' with SQL_WCHAR. "
513+
f"SQL_WCHAR requires UTF-16 encodings (utf-16le, utf-16be)"
514+
),
515+
)
458516

459-
# Store the encoding settings
460-
self._encoding_settings = {"encoding": encoding, "ctype": ctype}
517+
# Store the encoding settings (thread-safe with lock)
518+
with self._encoding_lock:
519+
self._encoding_settings = {"encoding": encoding, "ctype": ctype}
461520

462521
# Log with sanitized values for security
463522
logger.debug(
@@ -469,7 +528,7 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
469528

470529
def getencoding(self) -> Dict[str, Union[str, int]]:
471530
"""
472-
Gets the current text encoding settings.
531+
Gets the current text encoding settings (thread-safe).
473532
474533
Returns:
475534
dict: A dictionary containing 'encoding' and 'ctype' keys.
@@ -481,14 +540,19 @@ def getencoding(self) -> Dict[str, Union[str, int]]:
481540
settings = cnxn.getencoding()
482541
print(f"Current encoding: {settings['encoding']}")
483542
print(f"Current ctype: {settings['ctype']}")
543+
544+
Note:
545+
This method is thread-safe and can be called from multiple threads concurrently.
484546
"""
485547
if self._closed:
486548
raise InterfaceError(
487549
driver_error="Connection is closed",
488550
ddbc_error="Connection is closed",
489551
)
490552

491-
return self._encoding_settings.copy()
553+
# Thread-safe read with lock to prevent race conditions
554+
with self._encoding_lock:
555+
return self._encoding_settings.copy()
492556

493557
def setdecoding(
494558
self, sqltype: int, encoding: Optional[str] = None, ctype: Optional[int] = None
@@ -574,6 +638,38 @@ def setdecoding(
574638

575639
# Normalize encoding to lowercase for consistency
576640
encoding = encoding.lower()
641+
642+
# Reject 'utf-16' with BOM for SQL_WCHAR (ambiguous byte order)
643+
if sqltype == ConstantsDDBC.SQL_WCHAR.value and encoding == "utf-16":
644+
logger.debug(
645+
"warning",
646+
"utf-16 with BOM rejected for SQL_WCHAR",
647+
)
648+
raise ProgrammingError(
649+
driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
650+
ddbc_error=(
651+
"Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
652+
"Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
653+
),
654+
)
655+
656+
# Validate SQL_WCHAR only supports UTF-16 encodings (SQL_WMETADATA is more flexible)
657+
if sqltype == ConstantsDDBC.SQL_WCHAR.value and encoding not in UTF16_ENCODINGS:
658+
logger.debug(
659+
"warning",
660+
"Non-UTF-16 encoding %s attempted with SQL_WCHAR sqltype",
661+
sanitize_user_input(encoding),
662+
)
663+
raise ProgrammingError(
664+
driver_error=f"SQL_WCHAR only supports UTF-16 encodings",
665+
ddbc_error=(
666+
f"Cannot use encoding '{encoding}' with SQL_WCHAR. "
667+
f"SQL_WCHAR requires UTF-16 encodings (utf-16le, utf-16be)"
668+
),
669+
)
670+
671+
# SQL_WMETADATA can use any valid encoding (UTF-8, UTF-16, etc.)
672+
# No restriction needed here - let users configure as needed
577673

578674
# Set default ctype based on encoding if not provided
579675
if ctype is None:
@@ -597,9 +693,34 @@ def setdecoding(
597693
f"SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})"
598694
),
599695
)
696+
697+
# Validate that SQL_WCHAR ctype only used with UTF-16 encodings (not utf-16 with BOM)
698+
if ctype == ConstantsDDBC.SQL_WCHAR.value:
699+
if encoding == "utf-16":
700+
raise ProgrammingError(
701+
driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
702+
ddbc_error=(
703+
"Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
704+
"Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
705+
),
706+
)
707+
elif encoding not in UTF16_ENCODINGS:
708+
logger.debug(
709+
"warning",
710+
"Non-UTF-16 encoding %s attempted with SQL_WCHAR ctype",
711+
sanitize_user_input(encoding),
712+
)
713+
raise ProgrammingError(
714+
driver_error=f"SQL_WCHAR ctype only supports UTF-16 encodings",
715+
ddbc_error=(
716+
f"Cannot use encoding '{encoding}' with SQL_WCHAR ctype. "
717+
f"SQL_WCHAR requires UTF-16 encodings (utf-16le, utf-16be)"
718+
),
719+
)
600720

601-
# Store the decoding settings for the specified sqltype
602-
self._decoding_settings[sqltype] = {"encoding": encoding, "ctype": ctype}
721+
# Store the decoding settings for the specified sqltype (thread-safe with lock)
722+
with self._encoding_lock:
723+
self._decoding_settings[sqltype] = {"encoding": encoding, "ctype": ctype}
603724

604725
# Log with sanitized values for security
605726
sqltype_name = {
@@ -618,7 +739,7 @@ def setdecoding(
618739

619740
def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]:
620741
"""
621-
Gets the current text decoding settings for the specified SQL type.
742+
Gets the current text decoding settings for the specified SQL type (thread-safe).
622743
623744
Args:
624745
sqltype (int): The SQL type to get settings for: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA.
@@ -634,6 +755,9 @@ def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]:
634755
settings = cnxn.getdecoding(mssql_python.SQL_CHAR)
635756
print(f"SQL_CHAR encoding: {settings['encoding']}")
636757
print(f"SQL_CHAR ctype: {settings['ctype']}")
758+
759+
Note:
760+
This method is thread-safe and can be called from multiple threads concurrently.
637761
"""
638762
if self._closed:
639763
raise InterfaceError(
@@ -657,7 +781,9 @@ def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]:
657781
),
658782
)
659783

660-
return self._decoding_settings[sqltype].copy()
784+
# Thread-safe read with lock to prevent race conditions
785+
with self._encoding_lock:
786+
return self._decoding_settings[sqltype].copy()
661787

662788
def set_attr(self, attribute: int, value: Union[int, str, bytes, bytearray]) -> None:
663789
"""

mssql_python/cursor.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from mssql_python.helpers import check_error
2121
from mssql_python.logging import logger
2222
from mssql_python import ddbc_bindings
23-
from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError
23+
from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError, OperationalError, DatabaseError
2424
from mssql_python.row import Row
2525
from mssql_python import get_settings
2626

@@ -285,6 +285,53 @@ def _get_numeric_data(self, param: decimal.Decimal) -> Any:
285285
numeric_data.val = bytes(byte_array)
286286
return numeric_data
287287

288+
def _get_encoding_settings(self):
    """
    Get the encoding settings from the connection.

    Returns:
        dict: The dictionary returned by the connection's ``getencoding()``
        when that call succeeds. Otherwise a default of
        ``{'encoding': 'utf-16le', 'ctype': SQL_WCHAR}`` is returned, which
        happens when the connection has no ``getencoding`` attribute or when
        the call raises ``OperationalError`` / ``DatabaseError``.

    Note:
        Any other exception raised by ``getencoding()`` is not caught here
        and propagates to the caller.
    """
    if hasattr(self._connection, "getencoding"):
        try:
            return self._connection.getencoding()
        except (OperationalError, DatabaseError) as db_error:
            # Only database-related errors are tolerated; fall through to the
            # default below. Uses the module-level logger (same as the rest of
            # this file) with lazy %-formatting instead of a handler-local
            # import of a separate log helper.
            logger.warning(
                "Failed to get encoding settings from connection due to database error: %s",
                db_error,
            )

    # Single fallback shared by the "no getencoding" and "database error" paths.
    return {
        "encoding": "utf-16le",
        "ctype": ddbc_sql_const.SQL_WCHAR.value,
    }
312+
313+
def _get_decoding_settings(self, sql_type):
    """
    Get decoding settings for a specific SQL type.

    Args:
        sql_type: SQL type constant (SQL_CHAR, SQL_WCHAR, etc.) passed
            through unchanged to the connection's ``getdecoding()``.

    Returns:
        dict: The dictionary returned by ``getdecoding(sql_type)`` when that
        call succeeds. If it raises ``OperationalError`` / ``DatabaseError``,
        a default is returned instead: ``utf-16le`` / SQL_WCHAR when
        ``sql_type`` equals SQL_WCHAR, otherwise ``utf-8`` / SQL_CHAR.

    Note:
        Any other exception raised by ``getdecoding()`` is not caught here
        and propagates to the caller.
    """
    try:
        # Get decoding settings from connection for this SQL type
        return self._connection.getdecoding(sql_type)
    except (OperationalError, DatabaseError) as db_error:
        # Only expected database-related errors are handled. Uses the
        # module-level logger (same as the rest of this file) with lazy
        # %-formatting instead of a handler-local import of a log helper.
        logger.warning(
            "Failed to get decoding settings for SQL type %s due to database error: %s",
            sql_type,
            db_error,
        )

    # Fallback defaults: wide-character data decodes as UTF-16LE, everything
    # else as UTF-8.
    if sql_type == ddbc_sql_const.SQL_WCHAR.value:
        return {"encoding": "utf-16le", "ctype": ddbc_sql_const.SQL_WCHAR.value}
    return {"encoding": "utf-8", "ctype": ddbc_sql_const.SQL_CHAR.value}
334+
288335
def _map_sql_type( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-return-statements,too-many-branches
289336
self,
290337
param: Any,
@@ -1132,6 +1179,9 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state
11321179
# Clear any previous messages
11331180
self.messages = []
11341181

1182+
# Getting encoding setting
1183+
encoding_settings = self._get_encoding_settings()
1184+
11351185
# Apply timeout if set (non-zero)
11361186
if self._timeout > 0:
11371187
logger.debug("execute: Setting query timeout=%d seconds", self._timeout)
@@ -1202,6 +1252,7 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state
12021252
parameters_type,
12031253
self.is_stmt_prepared,
12041254
use_prepare,
1255+
encoding_settings
12051256
)
12061257
# Check return code
12071258
try:
@@ -2027,6 +2078,9 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s
20272078
# Now transpose the processed parameters
20282079
columnwise_params, row_count = self._transpose_rowwise_to_columnwise(processed_parameters)
20292080

2081+
# Get encoding settings
2082+
encoding_settings = self._get_encoding_settings()
2083+
20302084
# Add debug logging
20312085
logger.debug(
20322086
"Executing batch query with %d parameter sets:\n%s",
@@ -2038,7 +2092,7 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s
20382092
)
20392093

20402094
ret = ddbc_bindings.SQLExecuteMany(
2041-
self.hstmt, operation, columnwise_params, parameters_type, row_count
2095+
self.hstmt, operation, columnwise_params, parameters_type, row_count, encoding_settings
20422096
)
20432097

20442098
# Capture any diagnostic messages after execution
@@ -2070,10 +2124,13 @@ def fetchone(self) -> Union[None, Row]:
20702124
"""
20712125
self._check_closed() # Check if the cursor is closed
20722126

2127+
char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
2128+
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)
2129+
20732130
# Fetch raw data
20742131
row_data = []
20752132
try:
2076-
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data)
2133+
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))
20772134

20782135
if self.hstmt:
20792136
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
@@ -2121,10 +2178,13 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]:
21212178
if size <= 0:
21222179
return []
21232180

2181+
char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
2182+
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)
2183+
21242184
# Fetch raw data
21252185
rows_data = []
21262186
try:
2127-
_ = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size)
2187+
ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))
21282188

21292189
if self.hstmt:
21302190
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
@@ -2164,10 +2224,13 @@ def fetchall(self) -> List[Row]:
21642224
if not self._has_result_set and self.description:
21652225
self._reset_rownumber()
21662226

2227+
char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
2228+
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)
2229+
21672230
# Fetch raw data
21682231
rows_data = []
21692232
try:
2170-
_ = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)
2233+
ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))
21712234

21722235
if self.hstmt:
21732236
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

0 commit comments

Comments
 (0)