diff --git a/.gitignore b/.gitignore index a59f517..87ba9d9 100644 --- a/.gitignore +++ b/.gitignore @@ -175,6 +175,15 @@ htmlcov/ .hypothesis/ .pytest_cache/ testresults/ +reports/ +exports/ +.ruff_cache/ + +# Virtual environments (both root and library-specific) +test-env/ +test-env-*/ +libs/*/test-env/ +libs/*/test-env-*/ # Cassandra test data cassandra_data/ diff --git a/examples/bulk_operations/docker-compose-single.yml b/examples/bulk_operations/docker-compose-single.yml deleted file mode 100644 index 073b12d..0000000 --- a/examples/bulk_operations/docker-compose-single.yml +++ /dev/null @@ -1,46 +0,0 @@ -version: '3.8' - -# Single node Cassandra for testing with limited resources - -services: - cassandra-1: - image: cassandra:5.0 - container_name: bulk-cassandra-1 - hostname: cassandra-1 - environment: - - CASSANDRA_CLUSTER_NAME=BulkOpsCluster - - CASSANDRA_DC=datacenter1 - - CASSANDRA_ENDPOINT_SNITCH=GossipingPropertyFileSnitch - - CASSANDRA_NUM_TOKENS=256 - - MAX_HEAP_SIZE=1G - - JVM_OPTS=-XX:+UseG1GC -XX:G1RSetUpdatingPauseTimePercent=5 -XX:MaxGCPauseMillis=300 - - ports: - - "9042:9042" - volumes: - - cassandra1-data:/var/lib/cassandra - - deploy: - resources: - limits: - memory: 2G - reservations: - memory: 1G - - healthcheck: - test: ["CMD-SHELL", "nodetool info | grep -q 'Native Transport active: true' && cqlsh -e 'SELECT now() FROM system.local'"] - interval: 30s - timeout: 10s - retries: 15 - start_period: 90s - - networks: - - cassandra-net - -networks: - cassandra-net: - driver: bridge - -volumes: - cassandra1-data: - driver: local diff --git a/examples/bulk_operations/docker-compose.yml b/examples/bulk_operations/docker-compose.yml deleted file mode 100644 index 82e571c..0000000 --- a/examples/bulk_operations/docker-compose.yml +++ /dev/null @@ -1,160 +0,0 @@ -version: '3.8' - -# Bulk Operations Example - 3-node Cassandra cluster -# Optimized for token-aware bulk operations testing - -services: - # First Cassandra node (seed) - cassandra-1: - image: cassandra:5.0 - container_name: bulk-cassandra-1 - hostname: cassandra-1 - environment: - # Cluster configuration - - CASSANDRA_CLUSTER_NAME=BulkOpsCluster - - CASSANDRA_SEEDS=cassandra-1 - - CASSANDRA_DC=datacenter1 - - CASSANDRA_ENDPOINT_SNITCH=GossipingPropertyFileSnitch - - CASSANDRA_NUM_TOKENS=256 - - # Memory settings (reduced for development) - - MAX_HEAP_SIZE=2G - - JVM_OPTS=-XX:+UseG1GC -XX:G1RSetUpdatingPauseTimePercent=5 -XX:MaxGCPauseMillis=300 - - ports: - - "9042:9042" - volumes: - - cassandra1-data:/var/lib/cassandra - - # Resource limits for stability - deploy: - resources: - limits: - memory: 3G - reservations: - memory: 2G - - healthcheck: - test: ["CMD-SHELL", "nodetool info | grep -q 'Native Transport active: true' && cqlsh -e 'SELECT now() FROM system.local'"] - interval: 30s - timeout: 10s - retries: 15 - start_period: 120s - - networks: - - cassandra-net - - # Second Cassandra node - cassandra-2: - image: cassandra:5.0 - container_name: bulk-cassandra-2 - hostname: cassandra-2 - environment: - - CASSANDRA_CLUSTER_NAME=BulkOpsCluster - - CASSANDRA_SEEDS=cassandra-1 - - CASSANDRA_DC=datacenter1 - - CASSANDRA_ENDPOINT_SNITCH=GossipingPropertyFileSnitch - - CASSANDRA_NUM_TOKENS=256 - - MAX_HEAP_SIZE=2G - - JVM_OPTS=-XX:+UseG1GC -XX:G1RSetUpdatingPauseTimePercent=5 -XX:MaxGCPauseMillis=300 - - ports: - - "9043:9042" - volumes: - - cassandra2-data:/var/lib/cassandra - depends_on: - cassandra-1: - condition: service_healthy - - deploy: - resources: - limits: - memory: 3G - reservations: - memory: 2G - - 
healthcheck: - test: ["CMD-SHELL", "nodetool info | grep -q 'Native Transport active: true' && nodetool status | grep -c UN | grep -q 2"] - interval: 30s - timeout: 10s - retries: 15 - start_period: 120s - - networks: - - cassandra-net - - # Third Cassandra node - starts after cassandra-2 to avoid overwhelming the system - cassandra-3: - image: cassandra:5.0 - container_name: bulk-cassandra-3 - hostname: cassandra-3 - environment: - - CASSANDRA_CLUSTER_NAME=BulkOpsCluster - - CASSANDRA_SEEDS=cassandra-1 - - CASSANDRA_DC=datacenter1 - - CASSANDRA_ENDPOINT_SNITCH=GossipingPropertyFileSnitch - - CASSANDRA_NUM_TOKENS=256 - - MAX_HEAP_SIZE=2G - - JVM_OPTS=-XX:+UseG1GC -XX:G1RSetUpdatingPauseTimePercent=5 -XX:MaxGCPauseMillis=300 - - ports: - - "9044:9042" - volumes: - - cassandra3-data:/var/lib/cassandra - depends_on: - cassandra-2: - condition: service_healthy - - deploy: - resources: - limits: - memory: 3G - reservations: - memory: 2G - - healthcheck: - test: ["CMD-SHELL", "nodetool info | grep -q 'Native Transport active: true' && nodetool status | grep -c UN | grep -q 3"] - interval: 30s - timeout: 10s - retries: 15 - start_period: 120s - - networks: - - cassandra-net - - # Initialization container - creates keyspace and tables - init-cassandra: - image: cassandra:5.0 - container_name: bulk-init - depends_on: - cassandra-3: - condition: service_healthy - volumes: - - ./scripts/init.cql:/init.cql:ro - command: > - bash -c " - echo 'Waiting for cluster to stabilize...'; - sleep 15; - echo 'Checking cluster status...'; - until cqlsh cassandra-1 -e 'SELECT now() FROM system.local'; do - echo 'Waiting for Cassandra to be ready...'; - sleep 5; - done; - echo 'Creating keyspace and tables...'; - cqlsh cassandra-1 -f /init.cql || echo 'Init script may have already run'; - echo 'Initialization complete!'; - " - networks: - - cassandra-net - -networks: - cassandra-net: - driver: bridge - -volumes: - cassandra1-data: - driver: local - cassandra2-data: - driver: local - cassandra3-data: - driver: local diff --git a/examples/bulk_operations/example_count.py b/examples/bulk_operations/example_count.py deleted file mode 100644 index f8b7b77..0000000 --- a/examples/bulk_operations/example_count.py +++ /dev/null @@ -1,207 +0,0 @@ -#!/usr/bin/env python3 -""" -Example: Token-aware bulk count operation. - -This example demonstrates how to count all rows in a table -using token-aware parallel processing for maximum performance. 
-""" - -import asyncio -import logging -import time - -from rich.console import Console -from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn -from rich.table import Table - -from async_cassandra import AsyncCluster -from bulk_operations.bulk_operator import TokenAwareBulkOperator - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# Rich console for pretty output -console = Console() - - -async def count_table_example(): - """Demonstrate token-aware counting of a large table.""" - - # Connect to cluster - console.print("[cyan]Connecting to Cassandra cluster...[/cyan]") - - async with AsyncCluster(contact_points=["localhost", "127.0.0.1"], port=9042) as cluster: - session = await cluster.connect() - # Create test data if needed - console.print("[yellow]Setting up test keyspace and table...[/yellow]") - - # Create keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS bulk_demo - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 3 - } - """ - ) - - # Create table - await session.execute( - """ - CREATE TABLE IF NOT EXISTS bulk_demo.large_table ( - partition_key INT, - clustering_key INT, - data TEXT, - value DOUBLE, - PRIMARY KEY (partition_key, clustering_key) - ) - """ - ) - - # Check if we need to insert test data - result = await session.execute("SELECT COUNT(*) FROM bulk_demo.large_table LIMIT 1") - current_count = result.one().count - - if current_count < 10000: - console.print( - f"[yellow]Table has {current_count} rows. " f"Inserting test data...[/yellow]" - ) - - # Insert some test data using prepared statement - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_demo.large_table - (partition_key, clustering_key, data, value) - VALUES (?, ?, ?, ?) 
- """ - ) - - with Progress( - SpinnerColumn(), - *Progress.get_default_columns(), - TimeElapsedColumn(), - console=console, - ) as progress: - task = progress.add_task("[green]Inserting test data...", total=10000) - - for pk in range(100): - for ck in range(100): - await session.execute( - insert_stmt, (pk, ck, f"data-{pk}-{ck}", pk * ck * 0.1) - ) - progress.update(task, advance=1) - - # Now demonstrate bulk counting - console.print("\n[bold cyan]Token-Aware Bulk Count Demo[/bold cyan]\n") - - operator = TokenAwareBulkOperator(session) - - # Progress tracking - stats_list = [] - - def progress_callback(stats): - """Track progress during operation.""" - stats_list.append( - { - "rows": stats.rows_processed, - "ranges": stats.ranges_completed, - "total_ranges": stats.total_ranges, - "progress": stats.progress_percentage, - "rate": stats.rows_per_second, - } - ) - - # Perform count with different split counts - table = Table(title="Bulk Count Performance Comparison") - table.add_column("Split Count", style="cyan") - table.add_column("Total Rows", style="green") - table.add_column("Duration (s)", style="yellow") - table.add_column("Rows/Second", style="magenta") - table.add_column("Ranges Processed", style="blue") - - for split_count in [1, 4, 8, 16, 32]: - console.print(f"\n[cyan]Counting with {split_count} splits...[/cyan]") - - start_time = time.time() - - try: - with Progress( - SpinnerColumn(), - *Progress.get_default_columns(), - TimeElapsedColumn(), - console=console, - ) as progress: - current_task = progress.add_task( - f"[green]Counting with {split_count} splits...", total=100 - ) - - # Track progress - last_progress = 0 - - def update_progress(stats, task=current_task): - nonlocal last_progress - progress.update(task, completed=int(stats.progress_percentage)) - last_progress = stats.progress_percentage - progress_callback(stats) - - count, final_stats = await operator.count_by_token_ranges_with_stats( - keyspace="bulk_demo", - table="large_table", - split_count=split_count, - progress_callback=update_progress, - ) - - duration = time.time() - start_time - - table.add_row( - str(split_count), - f"{count:,}", - f"{duration:.2f}", - f"{final_stats.rows_per_second:,.0f}", - str(final_stats.ranges_completed), - ) - - except Exception as e: - console.print(f"[red]Error: {e}[/red]") - continue - - # Display results - console.print("\n") - console.print(table) - - # Show token range distribution - console.print("\n[bold]Token Range Analysis:[/bold]") - - from bulk_operations.token_utils import discover_token_ranges - - ranges = await discover_token_ranges(session, "bulk_demo") - - range_table = Table(title="Natural Token Ranges") - range_table.add_column("Range #", style="cyan") - range_table.add_column("Start Token", style="green") - range_table.add_column("End Token", style="yellow") - range_table.add_column("Size", style="magenta") - range_table.add_column("Replicas", style="blue") - - for i, r in enumerate(ranges[:5]): # Show first 5 - range_table.add_row( - str(i + 1), str(r.start), str(r.end), f"{r.size:,}", ", ".join(r.replicas) - ) - - if len(ranges) > 5: - range_table.add_row("...", "...", "...", "...", "...") - - console.print(range_table) - console.print(f"\nTotal natural ranges: {len(ranges)}") - - -if __name__ == "__main__": - try: - asyncio.run(count_table_example()) - except KeyboardInterrupt: - console.print("\n[yellow]Operation cancelled by user[/yellow]") - except Exception as e: - console.print(f"\n[red]Error: {e}[/red]") - logger.exception("Unexpected error") diff 
--git a/examples/bulk_operations/example_csv_export.py b/examples/bulk_operations/example_csv_export.py deleted file mode 100755 index 1d3ceda..0000000 --- a/examples/bulk_operations/example_csv_export.py +++ /dev/null @@ -1,230 +0,0 @@ -#!/usr/bin/env python3 -""" -Example: Export Cassandra table to CSV format. - -This demonstrates: -- Basic CSV export -- Compressed CSV export -- Custom delimiters and NULL handling -- Progress tracking -- Resume capability -""" - -import asyncio -import logging -from pathlib import Path - -from rich.console import Console -from rich.logging import RichHandler -from rich.progress import Progress, SpinnerColumn, TextColumn -from rich.table import Table - -from async_cassandra import AsyncCluster -from bulk_operations.bulk_operator import TokenAwareBulkOperator - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format="%(message)s", - handlers=[RichHandler(console=Console(stderr=True))], -) -logger = logging.getLogger(__name__) - - -async def export_examples(): - """Run various CSV export examples.""" - console = Console() - - # Connect to Cassandra - console.print("\n[bold blue]Connecting to Cassandra...[/bold blue]") - cluster = AsyncCluster(["localhost"]) - session = await cluster.connect() - - try: - # Ensure test data exists - await setup_test_data(session) - - # Create bulk operator - operator = TokenAwareBulkOperator(session) - - # Example 1: Basic CSV export - console.print("\n[bold green]Example 1: Basic CSV Export[/bold green]") - output_path = Path("exports/products.csv") - output_path.parent.mkdir(exist_ok=True) - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task("Exporting to CSV...", total=None) - - def progress_callback(export_progress): - progress.update( - task, - description=f"Exported {export_progress.rows_exported:,} rows " - f"({export_progress.progress_percentage:.1f}%)", - ) - - result = await operator.export_to_csv( - keyspace="bulk_demo", - table="products", - output_path=output_path, - progress_callback=progress_callback, - ) - - console.print(f"✓ Exported {result.rows_exported:,} rows to {output_path}") - console.print(f" File size: {result.bytes_written:,} bytes") - - # Example 2: Compressed CSV with custom delimiter - console.print("\n[bold green]Example 2: Compressed Tab-Delimited Export[/bold green]") - output_path = Path("exports/products_tab.csv") - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task("Exporting compressed CSV...", total=None) - - def progress_callback(export_progress): - progress.update( - task, - description=f"Exported {export_progress.rows_exported:,} rows", - ) - - result = await operator.export_to_csv( - keyspace="bulk_demo", - table="products", - output_path=output_path, - delimiter="\t", - compression="gzip", - progress_callback=progress_callback, - ) - - console.print(f"✓ Exported to {output_path}.gzip") - console.print(f" Compressed size: {result.bytes_written:,} bytes") - - # Example 3: Export with specific columns and NULL handling - console.print("\n[bold green]Example 3: Selective Column Export[/bold green]") - output_path = Path("exports/products_summary.csv") - - result = await operator.export_to_csv( - keyspace="bulk_demo", - table="products", - output_path=output_path, - columns=["id", "name", "price", "category"], - null_string="NULL", - ) - - console.print(f"✓ 
Exported {result.rows_exported:,} rows (selected columns)") - - # Show export summary - console.print("\n[bold cyan]Export Summary:[/bold cyan]") - summary_table = Table(show_header=True, header_style="bold magenta") - summary_table.add_column("Export", style="cyan") - summary_table.add_column("Format", style="green") - summary_table.add_column("Rows", justify="right") - summary_table.add_column("Size", justify="right") - summary_table.add_column("Compression") - - summary_table.add_row( - "products.csv", - "CSV", - "10,000", - "~500 KB", - "None", - ) - summary_table.add_row( - "products_tab.csv.gzip", - "TSV", - "10,000", - "~150 KB", - "gzip", - ) - summary_table.add_row( - "products_summary.csv", - "CSV", - "10,000", - "~300 KB", - "None", - ) - - console.print(summary_table) - - # Example 4: Demonstrate resume capability - console.print("\n[bold green]Example 4: Resume Capability[/bold green]") - console.print("Progress files saved at:") - for csv_file in Path("exports").glob("*.csv"): - progress_file = csv_file.with_suffix(".csv.progress") - if progress_file.exists(): - console.print(f" • {progress_file}") - - finally: - await session.close() - await cluster.shutdown() - - -async def setup_test_data(session): - """Create test keyspace and data if not exists.""" - # Create keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS bulk_demo - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - - # Create table - await session.execute( - """ - CREATE TABLE IF NOT EXISTS bulk_demo.products ( - id INT PRIMARY KEY, - name TEXT, - description TEXT, - price DECIMAL, - category TEXT, - in_stock BOOLEAN, - tags SET<TEXT>, - attributes MAP<TEXT, TEXT>, - created_at TIMESTAMP - ) - """ - ) - - # Check if data exists - result = await session.execute("SELECT COUNT(*) FROM bulk_demo.products") - count = result.one().count - - if count < 10000: - logger.info("Inserting test data...") - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_demo.products - (id, name, description, price, category, in_stock, tags, attributes, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, toTimestamp(now())) - """ - ) - - # Insert in batches - for i in range(10000): - await session.execute( - insert_stmt, - ( - i, - f"Product {i}", - f"Description for product {i}" if i % 3 != 0 else None, - float(10 + (i % 1000) * 0.1), - ["Electronics", "Books", "Clothing", "Food"][i % 4], - i % 5 != 0, # 80% in stock - {"tag1", f"tag{i % 10}"} if i % 2 == 0 else None, - {"color": ["red", "blue", "green"][i % 3], "size": "M"} if i % 4 == 0 else {}, - ), - ) - - -if __name__ == "__main__": - asyncio.run(export_examples()) diff --git a/examples/bulk_operations/example_export_formats.py b/examples/bulk_operations/example_export_formats.py deleted file mode 100755 index f6ca15f..0000000 --- a/examples/bulk_operations/example_export_formats.py +++ /dev/null @@ -1,283 +0,0 @@ -#!/usr/bin/env python3 -""" -Example: Export Cassandra data to multiple formats. - -This demonstrates exporting to: -- CSV (with compression) -- JSON (line-delimited and array) -- Parquet (foundation for Iceberg) - -Shows why Parquet is critical for the Iceberg integration.
-""" - -import asyncio -import logging -from pathlib import Path - -from rich.console import Console -from rich.logging import RichHandler -from rich.panel import Panel -from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeRemainingColumn -from rich.table import Table - -from async_cassandra import AsyncCluster -from bulk_operations.bulk_operator import TokenAwareBulkOperator - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format="%(message)s", - handlers=[RichHandler(console=Console(stderr=True))], -) -logger = logging.getLogger(__name__) - - -async def export_format_examples(): - """Demonstrate all export formats.""" - console = Console() - - # Header - console.print( - Panel.fit( - "[bold cyan]Cassandra Bulk Export Examples[/bold cyan]\n" - "Exporting to CSV, JSON, and Parquet formats", - border_style="cyan", - ) - ) - - # Connect to Cassandra - console.print("\n[bold blue]Connecting to Cassandra...[/bold blue]") - cluster = AsyncCluster(["localhost"]) - session = await cluster.connect() - - try: - # Setup test data - await setup_test_data(session) - - # Create bulk operator - operator = TokenAwareBulkOperator(session) - - # Create exports directory - exports_dir = Path("exports") - exports_dir.mkdir(exist_ok=True) - - # Export to different formats - results = {} - - # 1. CSV Export - console.print("\n[bold green]1. CSV Export (Universal Format)[/bold green]") - console.print(" • Human readable") - console.print(" • Compatible with Excel, databases, etc.") - console.print(" • Good for data exchange") - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TimeRemainingColumn(), - console=console, - ) as progress: - task = progress.add_task("Exporting to CSV...", total=100) - - def csv_progress(export_progress): - progress.update( - task, - completed=export_progress.progress_percentage, - description=f"CSV: {export_progress.rows_exported:,} rows", - ) - - results["csv"] = await operator.export_to_csv( - keyspace="export_demo", - table="events", - output_path=exports_dir / "events.csv", - compression="gzip", - progress_callback=csv_progress, - ) - - # 2. JSON Export (Line-delimited) - console.print("\n[bold green]2. JSON Export (Streaming Format)[/bold green]") - console.print(" • Preserves data types") - console.print(" • Works with streaming tools") - console.print(" • Good for data pipelines") - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TimeRemainingColumn(), - console=console, - ) as progress: - task = progress.add_task("Exporting to JSONL...", total=100) - - def json_progress(export_progress): - progress.update( - task, - completed=export_progress.progress_percentage, - description=f"JSON: {export_progress.rows_exported:,} rows", - ) - - results["json"] = await operator.export_to_json( - keyspace="export_demo", - table="events", - output_path=exports_dir / "events.jsonl", - format_mode="jsonl", - compression="gzip", - progress_callback=json_progress, - ) - - # 3. Parquet Export (Foundation for Iceberg) - console.print("\n[bold yellow]3. 
Parquet Export (CRITICAL for Iceberg)[/bold yellow]") - console.print(" • Columnar format for analytics") - console.print(" • Excellent compression") - console.print(" • Schema included in file") - console.print(" • [bold red]This is what Iceberg uses![/bold red]") - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TimeRemainingColumn(), - console=console, - ) as progress: - task = progress.add_task("Exporting to Parquet...", total=100) - - def parquet_progress(export_progress): - progress.update( - task, - completed=export_progress.progress_percentage, - description=f"Parquet: {export_progress.rows_exported:,} rows", - ) - - results["parquet"] = await operator.export_to_parquet( - keyspace="export_demo", - table="events", - output_path=exports_dir / "events.parquet", - compression="snappy", - row_group_size=10000, - progress_callback=parquet_progress, - ) - - # Show results comparison - console.print("\n[bold cyan]Export Results Comparison:[/bold cyan]") - comparison = Table(show_header=True, header_style="bold magenta") - comparison.add_column("Format", style="cyan") - comparison.add_column("File", style="green") - comparison.add_column("Size", justify="right") - comparison.add_column("Rows", justify="right") - comparison.add_column("Time", justify="right") - - for format_name, result in results.items(): - file_path = Path(result.output_path) - if format_name != "parquet" and result.metadata.get("compression"): - file_path = file_path.with_suffix( - file_path.suffix + f".{result.metadata['compression']}" - ) - - size_mb = result.bytes_written / (1024 * 1024) - duration = (result.completed_at - result.started_at).total_seconds() - - comparison.add_row( - format_name.upper(), - file_path.name, - f"{size_mb:.1f} MB", - f"{result.rows_exported:,}", - f"{duration:.1f}s", - ) - - console.print(comparison) - - # Explain Parquet importance - console.print( - Panel( - "[bold yellow]Why Parquet Matters for Iceberg:[/bold yellow]\n\n" - "• Iceberg tables store data in Parquet files\n" - "• Columnar format enables fast analytics queries\n" - "• Built-in schema makes evolution easier\n" - "• Compression reduces storage costs\n" - "• Row groups enable efficient filtering\n\n" - "[bold cyan]Next Phase:[/bold cyan] These Parquet files will become " - "Iceberg table data files!", - title="[bold red]The Path to Iceberg[/bold red]", - border_style="yellow", - ) - ) - - finally: - await session.close() - await cluster.shutdown() - - -async def setup_test_data(session): - """Create test keyspace and data.""" - # Create keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS export_demo - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - - # Create events table with various data types - await session.execute( - """ - CREATE TABLE IF NOT EXISTS export_demo.events ( - event_id UUID PRIMARY KEY, - event_type TEXT, - user_id INT, - timestamp TIMESTAMP, - properties MAP<TEXT, TEXT>, - tags SET<TEXT>, - metrics LIST<DOUBLE>, - is_processed BOOLEAN, - processing_time DECIMAL - ) - """ - ) - - # Check if data exists - result = await session.execute("SELECT COUNT(*) FROM export_demo.events") - count = result.one().count - - if count < 50000: - logger.info("Inserting test events...") - insert_stmt = await session.prepare( - """ - INSERT INTO export_demo.events - (event_id, event_type, user_id, timestamp, properties, - tags, metrics, is_processed, processing_time) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
- """ - ) - - # Insert test events - import uuid - from datetime import datetime, timedelta - from decimal import Decimal - - base_time = datetime.now() - timedelta(days=30) - event_types = ["login", "purchase", "view", "click", "logout"] - - for i in range(50000): - event_time = base_time + timedelta(seconds=i * 60) - - await session.execute( - insert_stmt, - ( - uuid.uuid4(), - event_types[i % len(event_types)], - i % 1000, # user_id - event_time, - {"source": "web", "version": "2.0"} if i % 3 == 0 else {}, - {f"tag{i % 5}", f"cat{i % 3}"} if i % 2 == 0 else None, - [float(i), float(i * 0.1), float(i * 0.01)] if i % 4 == 0 else None, - i % 10 != 0, # 90% processed - Decimal(str(0.001 * (i % 1000))), - ), - ) - - -if __name__ == "__main__": - asyncio.run(export_format_examples()) diff --git a/examples/bulk_operations/example_iceberg_export.py b/examples/bulk_operations/example_iceberg_export.py deleted file mode 100644 index 1a08f1b..0000000 --- a/examples/bulk_operations/example_iceberg_export.py +++ /dev/null @@ -1,302 +0,0 @@ -#!/usr/bin/env python3 -"""Example: Export Cassandra data to Apache Iceberg tables. - -This demonstrates the power of Apache Iceberg: -- ACID transactions on data lakes -- Schema evolution -- Time travel queries -- Hidden partitioning -- Integration with modern analytics tools -""" - -import asyncio -import logging -from datetime import datetime, timedelta -from pathlib import Path - -from pyiceberg.partitioning import PartitionField, PartitionSpec -from pyiceberg.transforms import DayTransform -from rich.console import Console -from rich.logging import RichHandler -from rich.panel import Panel -from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeRemainingColumn -from rich.table import Table as RichTable - -from async_cassandra import AsyncCluster -from bulk_operations.bulk_operator import TokenAwareBulkOperator -from bulk_operations.iceberg import IcebergExporter - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format="%(message)s", - handlers=[RichHandler(console=Console(stderr=True))], -) -logger = logging.getLogger(__name__) - - -async def iceberg_export_demo(): - """Demonstrate Cassandra to Iceberg export with advanced features.""" - console = Console() - - # Header - console.print( - Panel.fit( - "[bold cyan]Apache Iceberg Export Demo[/bold cyan]\n" - "Exporting Cassandra data to modern data lakehouse format", - border_style="cyan", - ) - ) - - # Connect to Cassandra - console.print("\n[bold blue]1. Connecting to Cassandra...[/bold blue]") - cluster = AsyncCluster(["localhost"]) - session = await cluster.connect() - - try: - # Setup test data - await setup_demo_data(session, console) - - # Create bulk operator - operator = TokenAwareBulkOperator(session) - - # Configure Iceberg export - warehouse_path = Path("iceberg_warehouse") - console.print( - f"\n[bold blue]2. 
Setting up Iceberg warehouse at:[/bold blue] {warehouse_path}" - ) - - # Create Iceberg exporter - exporter = IcebergExporter( - operator=operator, - warehouse_path=warehouse_path, - compression="snappy", - row_group_size=10000, - ) - - # Example 1: Basic export - console.print("\n[bold green]Example 1: Basic Iceberg Export[/bold green]") - console.print(" • Creates Iceberg table from Cassandra schema") - console.print(" • Writes data in Parquet format") - console.print(" • Enables ACID transactions") - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TimeRemainingColumn(), - console=console, - ) as progress: - task = progress.add_task("Exporting to Iceberg...", total=100) - - def iceberg_progress(export_progress): - progress.update( - task, - completed=export_progress.progress_percentage, - description=f"Iceberg: {export_progress.rows_exported:,} rows", - ) - - result = await exporter.export( - keyspace="iceberg_demo", - table="user_events", - namespace="cassandra_export", - table_name="user_events", - progress_callback=iceberg_progress, - ) - - console.print(f"✓ Exported {result.rows_exported:,} rows to Iceberg") - console.print(" Table: iceberg://cassandra_export.user_events") - - # Example 2: Partitioned export - console.print("\n[bold green]Example 2: Partitioned Iceberg Table[/bold green]") - console.print(" • Partitions by day for efficient queries") - console.print(" • Hidden partitioning (no query changes needed)") - console.print(" • Automatic partition pruning") - - # Create partition spec (partition by day) - partition_spec = PartitionSpec( - PartitionField( - source_id=4, # event_time field ID - field_id=1000, - transform=DayTransform(), - name="event_day", - ) - ) - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TimeRemainingColumn(), - console=console, - ) as progress: - task = progress.add_task("Exporting with partitions...", total=100) - - def partition_progress(export_progress): - progress.update( - task, - completed=export_progress.progress_percentage, - description=f"Partitioned: {export_progress.rows_exported:,} rows", - ) - - result = await exporter.export( - keyspace="iceberg_demo", - table="user_events", - namespace="cassandra_export", - table_name="user_events_partitioned", - partition_spec=partition_spec, - progress_callback=partition_progress, - ) - - console.print("✓ Created partitioned Iceberg table") - console.print(" Partitioned by: event_day (daily partitions)") - - # Show Iceberg features - console.print("\n[bold cyan]Iceberg Features Enabled:[/bold cyan]") - features = RichTable(show_header=True, header_style="bold magenta") - features.add_column("Feature", style="cyan") - features.add_column("Description", style="green") - features.add_column("Example Query") - - features.add_row( - "Time Travel", - "Query data at any point in time", - "SELECT * FROM table AS OF '2025-01-01'", - ) - features.add_row( - "Schema Evolution", - "Add/drop/rename columns safely", - "ALTER TABLE table ADD COLUMN new_field STRING", - ) - features.add_row( - "Hidden Partitioning", - "Partition pruning without query changes", - "WHERE event_time > '2025-01-01' -- uses partitions", - ) - features.add_row( - "ACID Transactions", - "Atomic commits and rollbacks", - "Multiple concurrent writers supported", - ) - features.add_row( - "Incremental Processing", - "Process only new data", - "Read incrementally from snapshot N to M", - ) - - console.print(features) - 
- # Explain the power of Iceberg - console.print( - Panel( - "[bold yellow]Why Apache Iceberg Matters:[/bold yellow]\n\n" - "• [cyan]Netflix Scale:[/cyan] Created by Netflix to handle petabytes\n" - "• [cyan]Open Format:[/cyan] Works with Spark, Trino, Flink, and more\n" - "• [cyan]Cloud Native:[/cyan] Designed for S3, GCS, Azure storage\n" - "• [cyan]Performance:[/cyan] Faster than traditional data lakes\n" - "• [cyan]Reliability:[/cyan] ACID guarantees prevent data corruption\n\n" - "[bold green]Your Cassandra data is now ready for:[/bold green]\n" - "• Analytics with Spark or Trino\n" - "• Machine learning pipelines\n" - "• Data warehousing with Snowflake/BigQuery\n" - "• Real-time processing with Flink", - title="[bold red]The Modern Data Lakehouse[/bold red]", - border_style="yellow", - ) - ) - - # Show next steps - console.print("\n[bold blue]Next Steps:[/bold blue]") - console.print( - "1. Query with Spark: spark.read.format('iceberg').load('cassandra_export.user_events')" - ) - console.print( - "2. Time travel: SELECT * FROM user_events FOR SYSTEM_TIME AS OF '2025-01-01'" - ) - console.print("3. Schema evolution: ALTER TABLE user_events ADD COLUMNS (score DOUBLE)") - console.print(f"4. Explore warehouse: {warehouse_path}/") - - finally: - await session.close() - await cluster.shutdown() - - -async def setup_demo_data(session, console): - """Create demo keyspace and data.""" - console.print("\n[bold blue]Setting up demo data...[/bold blue]") - - # Create keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS iceberg_demo - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - - # Create table with various data types - await session.execute( - """ - CREATE TABLE IF NOT EXISTS iceberg_demo.user_events ( - user_id UUID, - event_id UUID, - event_type TEXT, - event_time TIMESTAMP, - properties MAP<TEXT, TEXT>, - metrics MAP<TEXT, DOUBLE>, - tags SET<TEXT>, - is_processed BOOLEAN, - score DECIMAL, - PRIMARY KEY (user_id, event_time, event_id) - ) WITH CLUSTERING ORDER BY (event_time DESC, event_id ASC) - """ - ) - - # Check if data exists - result = await session.execute("SELECT COUNT(*) FROM iceberg_demo.user_events") - count = result.one().count - - if count < 10000: - console.print(" Inserting sample events...") - insert_stmt = await session.prepare( - """ - INSERT INTO iceberg_demo.user_events - (user_id, event_id, event_type, event_time, properties, - metrics, tags, is_processed, score) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
- """ - ) - - # Insert events over the last 30 days - import uuid - from decimal import Decimal - - base_time = datetime.now() - timedelta(days=30) - event_types = ["login", "purchase", "view", "click", "share", "logout"] - - for i in range(10000): - user_id = uuid.UUID(f"00000000-0000-0000-0000-{i % 100:012d}") - event_time = base_time + timedelta(minutes=i * 5) - - await session.execute( - insert_stmt, - ( - user_id, - uuid.uuid4(), - event_types[i % len(event_types)], - event_time, - {"device": "mobile", "version": "2.0"} if i % 3 == 0 else {}, - {"duration": float(i % 300), "count": float(i % 10)}, - {f"tag{i % 5}", f"category{i % 3}"}, - i % 10 != 0, # 90% processed - Decimal(str(0.1 * (i % 100))), - ), - ) - - console.print(" ✓ Created 10,000 events across 100 users") - - -if __name__ == "__main__": - asyncio.run(iceberg_export_demo()) diff --git a/examples/bulk_operations/fix_export_consistency.py b/examples/bulk_operations/fix_export_consistency.py deleted file mode 100644 index dbd3293..0000000 --- a/examples/bulk_operations/fix_export_consistency.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python3 -"""Fix the export_by_token_ranges method to handle consistency level properly.""" - -# Here's the corrected version of the export_by_token_ranges method - -corrected_code = """ - # Stream results from each range - for split in splits: - # Check if this is a wraparound range - if split.end < split.start: - # Wraparound range needs to be split into two queries - # First part: from start to MAX_TOKEN - if consistency_level is not None: - async with await self.session.execute_stream( - prepared_stmts["select_wraparound_gt"], - (split.start,), - consistency_level=consistency_level - ) as result: - async for row in result: - stats.rows_processed += 1 - yield row - else: - async with await self.session.execute_stream( - prepared_stmts["select_wraparound_gt"], - (split.start,) - ) as result: - async for row in result: - stats.rows_processed += 1 - yield row - - # Second part: from MIN_TOKEN to end - if consistency_level is not None: - async with await self.session.execute_stream( - prepared_stmts["select_wraparound_lte"], - (split.end,), - consistency_level=consistency_level - ) as result: - async for row in result: - stats.rows_processed += 1 - yield row - else: - async with await self.session.execute_stream( - prepared_stmts["select_wraparound_lte"], - (split.end,) - ) as result: - async for row in result: - stats.rows_processed += 1 - yield row - else: - # Normal range - use prepared statement - if consistency_level is not None: - async with await self.session.execute_stream( - prepared_stmts["select_range"], - (split.start, split.end), - consistency_level=consistency_level - ) as result: - async for row in result: - stats.rows_processed += 1 - yield row - else: - async with await self.session.execute_stream( - prepared_stmts["select_range"], - (split.start, split.end) - ) as result: - async for row in result: - stats.rows_processed += 1 - yield row - - stats.ranges_completed += 1 - - if progress_callback: - progress_callback(stats) - - stats.end_time = time.time() -""" - -print(corrected_code) diff --git a/examples/bulk_operations/pyproject.toml b/examples/bulk_operations/pyproject.toml deleted file mode 100644 index 39dc0a8..0000000 --- a/examples/bulk_operations/pyproject.toml +++ /dev/null @@ -1,102 +0,0 @@ -[build-system] -requires = ["setuptools>=61.0", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "async-cassandra-bulk-operations" -version = "0.1.0" 
-description = "Token-aware bulk operations example for async-cassandra" -readme = "README.md" -requires-python = ">=3.12" -license = {text = "Apache-2.0"} -authors = [ - {name = "AxonOps", email = "info@axonops.com"}, -] -dependencies = [ - # For development, install async-cassandra from parent directory: - # pip install -e ../.. - # For production, use: "async-cassandra>=0.2.0", - "pyiceberg[pyarrow]>=0.8.0", - "pyarrow>=18.0.0", - "pandas>=2.0.0", - "rich>=13.0.0", # For nice progress bars - "click>=8.0.0", # For CLI -] - -[project.optional-dependencies] -dev = [ - "pytest>=8.0.0", - "pytest-asyncio>=0.24.0", - "pytest-cov>=5.0.0", - "black>=24.0.0", - "ruff>=0.8.0", - "mypy>=1.13.0", -] - -[project.scripts] -bulk-ops = "bulk_operations.cli:main" - -[tool.pytest.ini_options] -minversion = "8.0" -addopts = [ - "-ra", - "--strict-markers", - "--asyncio-mode=auto", - "--cov=bulk_operations", - "--cov-report=html", - "--cov-report=term-missing", -] -testpaths = ["tests"] -python_files = ["test_*.py"] -python_classes = ["Test*"] -python_functions = ["test_*"] -markers = [ - "unit: Unit tests that don't require Cassandra", - "integration: Integration tests that require a running Cassandra cluster", - "slow: Tests that take a long time to run", -] - -[tool.black] -line-length = 100 -target-version = ["py312"] -include = '\.pyi?$' - -[tool.isort] -profile = "black" -line_length = 100 -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true -ensure_newline_before_comments = true -known_first_party = ["async_cassandra"] - -[tool.ruff] -line-length = 100 -target-version = "py312" - -[tool.ruff.lint] -select = [ - "E", # pycodestyle errors - "W", # pycodestyle warnings - "F", # pyflakes - # "I", # isort - disabled since we use isort separately - "B", # flake8-bugbear - "C90", # mccabe complexity - "UP", # pyupgrade - "SIM", # flake8-simplify -] -ignore = ["E501"] # Line too long - handled by black - -[tool.mypy] -python_version = "3.12" -warn_return_any = true -warn_unused_configs = true -disallow_untyped_defs = true -disallow_incomplete_defs = true -check_untyped_defs = true -no_implicit_optional = true -warn_redundant_casts = true -warn_unused_ignores = true -warn_no_return = true -strict_equality = true diff --git a/examples/bulk_operations/run_integration_tests.sh b/examples/bulk_operations/run_integration_tests.sh deleted file mode 100755 index a25133f..0000000 --- a/examples/bulk_operations/run_integration_tests.sh +++ /dev/null @@ -1,91 +0,0 @@ -#!/bin/bash -# Integration test runner for bulk operations - -echo "🚀 Bulk Operations Integration Test Runner" -echo "=========================================" - -# Check if docker or podman is available -if command -v podman &> /dev/null; then - CONTAINER_TOOL="podman" -elif command -v docker &> /dev/null; then - CONTAINER_TOOL="docker" -else - echo "❌ Error: Neither docker nor podman found. Please install one." - exit 1 -fi - -echo "Using container tool: $CONTAINER_TOOL" - -# Function to wait for cluster to be ready -wait_for_cluster() { - echo "⏳ Waiting for Cassandra cluster to be ready..." - local max_attempts=60 - local attempt=0 - - while [ $attempt -lt $max_attempts ]; do - if $CONTAINER_TOOL exec bulk-cassandra-1 nodetool status 2>/dev/null | grep -q "UN"; then - echo "✅ Cassandra cluster is ready!" - return 0 - fi - attempt=$((attempt + 1)) - echo -n "." 
- sleep 5 - done - - echo "❌ Timeout waiting for cluster to be ready" - return 1 -} - -# Function to show cluster status -show_cluster_status() { - echo "" - echo "📊 Cluster Status:" - echo "==================" - $CONTAINER_TOOL exec bulk-cassandra-1 nodetool status || true - echo "" -} - -# Main execution -echo "" -echo "1️⃣ Starting Cassandra cluster..." -$CONTAINER_TOOL-compose up -d - -if wait_for_cluster; then - show_cluster_status - - echo "2️⃣ Running integration tests..." - echo "" - - # Run pytest with integration markers - pytest tests/test_integration.py -v -s -m integration - TEST_RESULT=$? - - echo "" - echo "3️⃣ Cluster token information:" - echo "==============================" - echo "Sample output from nodetool describering:" - $CONTAINER_TOOL exec bulk-cassandra-1 nodetool describering bulk_test 2>/dev/null | head -20 || true - - echo "" - echo "4️⃣ Test Summary:" - echo "================" - if [ $TEST_RESULT -eq 0 ]; then - echo "✅ All integration tests passed!" - else - echo "❌ Some tests failed. Please check the output above." - fi - - echo "" - read -p "Press Enter to stop the cluster, or Ctrl+C to keep it running..." - - echo "Stopping cluster..." - $CONTAINER_TOOL-compose down -else - echo "❌ Failed to start cluster. Check container logs:" - $CONTAINER_TOOL-compose logs - $CONTAINER_TOOL-compose down - exit 1 -fi - -echo "" -echo "✨ Done!" diff --git a/examples/bulk_operations/scripts/init.cql b/examples/bulk_operations/scripts/init.cql deleted file mode 100644 index 70902c6..0000000 --- a/examples/bulk_operations/scripts/init.cql +++ /dev/null @@ -1,72 +0,0 @@ --- Initialize keyspace and tables for bulk operations example --- This script creates test data for demonstrating token-aware bulk operations - --- Create keyspace with NetworkTopologyStrategy for production-like setup -CREATE KEYSPACE IF NOT EXISTS bulk_ops -WITH replication = { - 'class': 'NetworkTopologyStrategy', - 'datacenter1': 3 -} -AND durable_writes = true; - --- Use the keyspace -USE bulk_ops; - --- Create a large table for bulk operations testing -CREATE TABLE IF NOT EXISTS large_dataset ( - id UUID, - partition_key INT, - clustering_key INT, - data TEXT, - value DOUBLE, - created_at TIMESTAMP, - metadata MAP<TEXT, TEXT>, - PRIMARY KEY (partition_key, clustering_key, id) -) WITH CLUSTERING ORDER BY (clustering_key ASC, id ASC) - AND compression = {'class': 'LZ4Compressor'} - AND compaction = {'class': 'SizeTieredCompactionStrategy'}; - --- Create an index for testing -CREATE INDEX IF NOT EXISTS idx_created_at ON large_dataset (created_at); - --- Create a table for export/import testing -CREATE TABLE IF NOT EXISTS orders ( - order_id UUID, - customer_id UUID, - order_date DATE, - order_time TIMESTAMP, - total_amount DECIMAL, - status TEXT, - items LIST<FROZEN<MAP<TEXT, TEXT>>>, - shipping_address MAP<TEXT, TEXT>, - PRIMARY KEY ((customer_id), order_date, order_id) -) WITH CLUSTERING ORDER BY (order_date DESC, order_id ASC) - AND compression = {'class': 'LZ4Compressor'}; - --- Create a simple counter table -CREATE TABLE IF NOT EXISTS page_views ( - page_id UUID, - date DATE, - views COUNTER, - PRIMARY KEY ((page_id), date) -) WITH CLUSTERING ORDER BY (date DESC); - --- Create a time series table -CREATE TABLE IF NOT EXISTS sensor_data ( - sensor_id UUID, - bucket TIMESTAMP, - reading_time TIMESTAMP, - temperature DOUBLE, - humidity DOUBLE, - pressure DOUBLE, - location FROZEN<MAP<TEXT, DOUBLE>>, - PRIMARY KEY ((sensor_id, bucket), reading_time) -) WITH CLUSTERING ORDER BY (reading_time DESC) - AND compression = {'class': 'LZ4Compressor'} - AND
default_time_to_live = 2592000; -- 30 days TTL - --- Grant permissions (if authentication is enabled) --- GRANT ALL ON KEYSPACE bulk_ops TO cassandra; - --- Display confirmation -SELECT keyspace_name, table_name FROM system_schema.tables WHERE keyspace_name = 'bulk_ops'; diff --git a/examples/bulk_operations/test_simple_count.py b/examples/bulk_operations/test_simple_count.py deleted file mode 100644 index 549f1ea..0000000 --- a/examples/bulk_operations/test_simple_count.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python3 -"""Simple test to debug count issue.""" - -import asyncio - -from async_cassandra import AsyncCluster -from bulk_operations.bulk_operator import TokenAwareBulkOperator - - -async def test_count(): - """Test count with error details.""" - async with AsyncCluster(contact_points=["localhost"]) as cluster: - session = await cluster.connect() - - operator = TokenAwareBulkOperator(session) - - try: - count = await operator.count_by_token_ranges( - keyspace="bulk_test", table="test_data", split_count=4, parallelism=2 - ) - print(f"Count successful: {count}") - except Exception as e: - print(f"Error: {e}") - if hasattr(e, "errors"): - print(f"Detailed errors: {e.errors}") - for err in e.errors: - print(f" - {err}") - - -if __name__ == "__main__": - asyncio.run(test_count()) diff --git a/examples/bulk_operations/test_single_node.py b/examples/bulk_operations/test_single_node.py deleted file mode 100644 index aa762de..0000000 --- a/examples/bulk_operations/test_single_node.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python3 -"""Quick test to verify token range discovery with single node.""" - -import asyncio - -from async_cassandra import AsyncCluster -from bulk_operations.token_utils import ( - MAX_TOKEN, - MIN_TOKEN, - TOTAL_TOKEN_RANGE, - discover_token_ranges, -) - - -async def test_single_node(): - """Test token range discovery with single node.""" - print("Connecting to single-node cluster...") - - async with AsyncCluster(contact_points=["localhost"]) as cluster: - session = await cluster.connect() - - # Create test keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_single - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - - print("Discovering token ranges...") - ranges = await discover_token_ranges(session, "test_single") - - print(f"\nToken ranges discovered: {len(ranges)}") - print("Expected with 1 node × 256 vnodes: 256 ranges") - - # Verify we have the expected number of ranges - assert len(ranges) == 256, f"Expected 256 ranges, got {len(ranges)}" - - # Verify ranges cover the entire ring - sorted_ranges = sorted(ranges, key=lambda r: r.start) - - # Debug first and last ranges - print(f"First range: {sorted_ranges[0].start} to {sorted_ranges[0].end}") - print(f"Last range: {sorted_ranges[-1].start} to {sorted_ranges[-1].end}") - print(f"MIN_TOKEN: {MIN_TOKEN}, MAX_TOKEN: {MAX_TOKEN}") - - # The token ring is circular, so we need to handle wraparound - # The smallest token in the sorted list might not be MIN_TOKEN - # because of how Cassandra distributes vnodes - - # Check for gaps or overlaps - gaps = [] - overlaps = [] - for i in range(len(sorted_ranges) - 1): - current = sorted_ranges[i] - next_range = sorted_ranges[i + 1] - if current.end < next_range.start: - gaps.append((current.end, next_range.start)) - elif current.end > next_range.start: - overlaps.append((current.end, next_range.start)) - - print(f"\nGaps found: {len(gaps)}") - if gaps: - for gap in gaps[:3]: - print(f" Gap: 
{gap[0]} to {gap[1]}") - - print(f"Overlaps found: {len(overlaps)}") - - # Check if ranges form a complete ring - # In a proper token ring, each range's end should equal the next range's start - # The last range should wrap around to the first - total_size = sum(r.size for r in ranges) - print(f"\nTotal token space covered: {total_size:,}") - print(f"Expected total space: {TOTAL_TOKEN_RANGE:,}") - - # Show sample ranges - print("\nSample token ranges (first 5):") - for i, r in enumerate(sorted_ranges[:5]): - print(f" Range {i+1}: {r.start} to {r.end} (size: {r.size:,})") - - print("\n✅ All tests passed!") - - # Session is closed automatically by the context manager - return True - - -if __name__ == "__main__": - try: - asyncio.run(test_single_node()) - except Exception as e: - print(f"❌ Error: {e}") - import traceback - - traceback.print_exc() - exit(1) diff --git a/examples/bulk_operations/tests/__init__.py b/examples/bulk_operations/tests/__init__.py deleted file mode 100644 index ce61b96..0000000 --- a/examples/bulk_operations/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Test package for bulk operations.""" diff --git a/examples/bulk_operations/tests/conftest.py b/examples/bulk_operations/tests/conftest.py deleted file mode 100644 index 4445379..0000000 --- a/examples/bulk_operations/tests/conftest.py +++ /dev/null @@ -1,95 +0,0 @@ -""" -Pytest configuration for bulk operations tests. - -Handles test markers and Docker/Podman support. -""" - -import os -import subprocess -from pathlib import Path - -import pytest - - -def get_container_runtime(): - """Detect whether to use docker or podman.""" - # Check environment variable first - runtime = os.environ.get("CONTAINER_RUNTIME", "").lower() - if runtime in ["docker", "podman"]: - return runtime - - # Auto-detect - for cmd in ["docker", "podman"]: - try: - subprocess.run([cmd, "--version"], capture_output=True, check=True) - return cmd - except (subprocess.CalledProcessError, FileNotFoundError): - continue - - raise RuntimeError("Neither docker nor podman found. Please install one.") - - -# Set container runtime globally -CONTAINER_RUNTIME = get_container_runtime() -os.environ["CONTAINER_RUNTIME"] = CONTAINER_RUNTIME - - -def pytest_configure(config): - """Configure pytest with custom markers.""" - config.addinivalue_line("markers", "unit: Unit tests that don't require external services") - config.addinivalue_line("markers", "integration: Integration tests requiring Cassandra cluster") - config.addinivalue_line("markers", "slow: Tests that take a long time to run") - - -def pytest_collection_modifyitems(config, items): - """Automatically skip integration tests if not explicitly requested.""" - if config.getoption("markexpr"): - # User specified markers, respect their choice - return - - # Check if Cassandra is available - cassandra_available = check_cassandra_available() - - skip_integration = pytest.mark.skip( - reason="Integration tests require running Cassandra cluster. Use -m integration to run." 
- ) - - for item in items: - if "integration" in item.keywords and not cassandra_available: - item.add_marker(skip_integration) - - -def check_cassandra_available(): - """Check if Cassandra cluster is available.""" - try: - # Try to connect to the first node - import socket - - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(1) - result = sock.connect_ex(("127.0.0.1", 9042)) - sock.close() - return result == 0 - except Exception: - return False - - -@pytest.fixture(scope="session") -def container_runtime(): - """Get the container runtime being used.""" - return CONTAINER_RUNTIME - - -@pytest.fixture(scope="session") -def docker_compose_file(): - """Path to docker-compose file.""" - return Path(__file__).parent.parent / "docker-compose.yml" - - -@pytest.fixture(scope="session") -def docker_compose_command(container_runtime): - """Get the appropriate docker-compose command.""" - if container_runtime == "podman": - return ["podman-compose"] - else: - return ["docker-compose"] diff --git a/examples/bulk_operations/tests/integration/README.md b/examples/bulk_operations/tests/integration/README.md deleted file mode 100644 index 25138a4..0000000 --- a/examples/bulk_operations/tests/integration/README.md +++ /dev/null @@ -1,100 +0,0 @@ -# Integration Tests for Bulk Operations - -This directory contains integration tests that validate bulk operations against a real Cassandra cluster. - -## Test Organization - -The integration tests are organized into logical modules: - -- **test_token_discovery.py** - Tests for token range discovery with vnodes - - Validates token range discovery matches cluster configuration - - Compares with nodetool describering output - - Ensures complete ring coverage without gaps - -- **test_bulk_count.py** - Tests for bulk count operations - - Validates full data coverage (no missing/duplicate rows) - - Tests wraparound range handling - - Performance testing with different parallelism levels - -- **test_bulk_export.py** - Tests for bulk export operations - - Validates streaming export completeness - - Tests memory efficiency for large exports - - Handles different CQL data types - -- **test_token_splitting.py** - Tests for token range splitting strategies - - Tests proportional splitting based on range sizes - - Handles small vnode ranges appropriately - - Validates replica-aware clustering - -## Running Integration Tests - -Integration tests require a running Cassandra cluster. They are skipped by default. 
- -### Run all integration tests: -```bash -pytest tests/integration --integration -``` - -### Run specific test module: -```bash -pytest tests/integration/test_bulk_count.py --integration -v -``` - -### Run specific test: -```bash -pytest tests/integration/test_bulk_count.py::TestBulkCount::test_full_table_coverage_with_token_ranges --integration -v -``` - -## Test Infrastructure - -### Automatic Cassandra Startup - -The tests will automatically start a single-node Cassandra container if one is not already running, using either: -- `docker-compose-single.yml` (via docker-compose or podman-compose) - -### Manual Cassandra Setup - -You can also manually start Cassandra: - -```bash -# Single node (recommended for basic tests) -podman-compose -f docker-compose-single.yml up -d - -# Multi-node cluster (for advanced tests) -podman-compose -f docker-compose.yml up -d -``` - -### Test Fixtures - -Common fixtures are defined in `conftest.py`: -- `ensure_cassandra` - Session-scoped fixture that ensures Cassandra is running -- `cluster` - Creates AsyncCluster connection -- `session` - Creates test session with keyspace - -## Test Requirements - -- Cassandra 4.0+ (or ScyllaDB) -- Docker or Podman with compose -- Python packages: pytest, pytest-asyncio, async-cassandra - -## Debugging Tips - -1. **View Cassandra logs:** - ```bash - podman logs bulk-cassandra-1 - ``` - -2. **Check token ranges manually:** - ```bash - podman exec bulk-cassandra-1 nodetool describering bulk_test - ``` - -3. **Run with verbose output:** - ```bash - pytest tests/integration --integration -v -s - ``` - -4. **Run with coverage:** - ```bash - pytest tests/integration --integration --cov=bulk_operations - ``` diff --git a/examples/bulk_operations/tests/integration/__init__.py b/examples/bulk_operations/tests/integration/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/examples/bulk_operations/tests/integration/conftest.py b/examples/bulk_operations/tests/integration/conftest.py deleted file mode 100644 index c4f43aa..0000000 --- a/examples/bulk_operations/tests/integration/conftest.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -Shared configuration and fixtures for integration tests. 
-""" - -import os -import subprocess -import time - -import pytest - - -def is_cassandra_running(): - """Check if Cassandra is accessible on localhost.""" - try: - from cassandra.cluster import Cluster - - cluster = Cluster(["localhost"]) - session = cluster.connect() - session.shutdown() - cluster.shutdown() - return True - except Exception: - return False - - -def start_cassandra_if_needed(): - """Start Cassandra using docker-compose if not already running.""" - if is_cassandra_running(): - return True - - # Try to start single-node Cassandra - compose_file = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "docker-compose-single.yml" - ) - - if not os.path.exists(compose_file): - return False - - print("\nStarting Cassandra container for integration tests...") - - # Try podman first, then docker - for cmd in ["podman-compose", "docker-compose"]: - try: - subprocess.run([cmd, "-f", compose_file, "up", "-d"], check=True, capture_output=True) - break - except (subprocess.CalledProcessError, FileNotFoundError): - continue - else: - print("Could not start Cassandra - neither podman-compose nor docker-compose found") - return False - - # Wait for Cassandra to be ready - print("Waiting for Cassandra to be ready...") - for _i in range(60): # Wait up to 60 seconds - if is_cassandra_running(): - print("Cassandra is ready!") - return True - time.sleep(1) - - print("Cassandra failed to start in time") - return False - - -@pytest.fixture(scope="session", autouse=True) -def ensure_cassandra(): - """Ensure Cassandra is running for integration tests.""" - if not start_cassandra_if_needed(): - pytest.skip("Cassandra is not available for integration tests") - - -# Skip integration tests if not explicitly requested -def pytest_collection_modifyitems(config, items): - """Skip integration tests unless --integration flag is passed.""" - if not config.getoption("--integration", default=False): - skip_integration = pytest.mark.skip( - reason="Integration tests not requested (use --integration flag)" - ) - for item in items: - if "integration" in item.keywords: - item.add_marker(skip_integration) - - -def pytest_addoption(parser): - """Add custom command line options.""" - parser.addoption( - "--integration", action="store_true", default=False, help="Run integration tests" - ) diff --git a/examples/bulk_operations/tests/integration/test_bulk_count.py b/examples/bulk_operations/tests/integration/test_bulk_count.py deleted file mode 100644 index 8c94b5d..0000000 --- a/examples/bulk_operations/tests/integration/test_bulk_count.py +++ /dev/null @@ -1,354 +0,0 @@ -""" -Integration tests for bulk count operations. - -What this tests: ---------------- -1. Full data coverage with token ranges (no missing/duplicate rows) -2. Wraparound range handling -3. Count accuracy across different data distributions -4. 
Performance with parallelism - -Why this matters: ----------------- -- Count is the simplest bulk operation - if it fails, everything fails -- Proves our token range queries are correct -- Gaps mean data loss in production -- Duplicates mean incorrect counting -- Critical for data integrity -""" - -import asyncio - -import pytest - -from async_cassandra import AsyncCluster -from bulk_operations.bulk_operator import TokenAwareBulkOperator - - -@pytest.mark.integration -class TestBulkCount: - """Test bulk count operations against real Cassandra cluster.""" - - @pytest.fixture - async def cluster(self): - """Create connection to test cluster.""" - cluster = AsyncCluster( - contact_points=["localhost"], - port=9042, - ) - yield cluster - await cluster.shutdown() - - @pytest.fixture - async def session(self, cluster): - """Create test session with keyspace and table.""" - session = await cluster.connect() - - # Create test keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS bulk_test - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - - # Create test table - await session.execute( - """ - CREATE TABLE IF NOT EXISTS bulk_test.test_data ( - id INT PRIMARY KEY, - data TEXT, - value DOUBLE - ) - """ - ) - - # Clear any existing data - await session.execute("TRUNCATE bulk_test.test_data") - - yield session - - @pytest.mark.asyncio - async def test_full_table_coverage_with_token_ranges(self, session): - """ - Test that token ranges cover all data without gaps or duplicates. - - What this tests: - --------------- - 1. Insert known dataset across token range - 2. Count using token ranges - 3. Verify exact match with direct count - 4. No missing or duplicate rows - - Why this matters: - ---------------- - - Proves our token range queries are correct - - Gaps mean data loss in production - - Duplicates mean incorrect counting - - Critical for data integrity - """ - # Insert test data with known count - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.test_data (id, data, value) - VALUES (?, ?, ?) - """ - ) - - expected_count = 10000 - print(f"\nInserting {expected_count} test rows...") - - # Insert in batches for efficiency - batch_size = 100 - for i in range(0, expected_count, batch_size): - tasks = [] - for j in range(batch_size): - if i + j < expected_count: - tasks.append(session.execute(insert_stmt, (i + j, f"data-{i+j}", float(i + j)))) - await asyncio.gather(*tasks) - - # Count using direct query - result = await session.execute("SELECT COUNT(*) FROM bulk_test.test_data") - direct_count = result.one().count - assert ( - direct_count == expected_count - ), f"Direct count mismatch: {direct_count} vs {expected_count}" - - # Count using token ranges - operator = TokenAwareBulkOperator(session) - token_count = await operator.count_by_token_ranges( - keyspace="bulk_test", - table="test_data", - split_count=16, # Moderate splitting - parallelism=8, - ) - - print("\nCount comparison:") - print(f" Direct count: {direct_count}") - print(f" Token range count: {token_count}") - - assert ( - token_count == direct_count - ), f"Token range count mismatch: {token_count} vs {direct_count}" - - @pytest.mark.asyncio - async def test_count_with_wraparound_ranges(self, session): - """ - Test counting specifically with wraparound ranges. - - What this tests: - --------------- - 1. Insert data that falls in wraparound range - 2. Verify wraparound range is properly split - 3. Count includes all data - 4. 
No double counting - - Why this matters: - ---------------- - - Wraparound ranges are tricky edge cases - - CQL doesn't support OR in token queries - - Must split into two queries properly - - Common source of bugs - """ - # Insert test data - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.test_data (id, data, value) - VALUES (?, ?, ?) - """ - ) - - # Insert data with IDs that we know will hash to extreme token values - test_ids = [] - for i in range(50000, 60000): # Test range that includes wraparound tokens - test_ids.append(i) - - print(f"\nInserting {len(test_ids)} test rows...") - batch_size = 100 - for i in range(0, len(test_ids), batch_size): - tasks = [] - for j in range(batch_size): - if i + j < len(test_ids): - id_val = test_ids[i + j] - tasks.append( - session.execute(insert_stmt, (id_val, f"data-{id_val}", float(id_val))) - ) - await asyncio.gather(*tasks) - - # Get direct count - result = await session.execute("SELECT COUNT(*) FROM bulk_test.test_data") - direct_count = result.one().count - - # Count using token ranges with different split counts - operator = TokenAwareBulkOperator(session) - - for split_count in [4, 8, 16, 32]: - token_count = await operator.count_by_token_ranges( - keyspace="bulk_test", - table="test_data", - split_count=split_count, - parallelism=4, - ) - - print(f"\nSplit count {split_count}: {token_count} rows") - assert ( - token_count == direct_count - ), f"Count mismatch with {split_count} splits: {token_count} vs {direct_count}" - - @pytest.mark.asyncio - async def test_parallel_count_performance(self, session): - """ - Test parallel execution improves count performance. - - What this tests: - --------------- - 1. Count performance with different parallelism levels - 2. Results are consistent across parallelism levels - 3. No deadlocks or timeouts - 4. Higher parallelism provides benefit - - Why this matters: - ---------------- - - Parallel execution is the main benefit - - Must handle concurrent queries properly - - Performance validation - - Resource efficiency - """ - # Insert more data for meaningful parallelism test - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.test_data (id, data, value) - VALUES (?, ?, ?) 
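Aside on the wraparound handling exercised above: the docstrings note that CQL does not allow OR in token predicates, so a range that wraps past the top of the ring has to be issued as two separate queries. The splitting itself happens inside `TokenAwareBulkOperator`/`token_utils`, which are not included in this diff; the sketch below only illustrates the idea, and the names `MIN_TOKEN`, `MAX_TOKEN`, `bounds_for_range`, and the COUNT query shape are assumptions, not the library's actual API.

```python
# Sketch only: how a wraparound token range might become CQL-safe queries.
# MIN_TOKEN / MAX_TOKEN are the Murmur3Partitioner bounds; the real logic
# lives in TokenAwareBulkOperator and is not shown in this diff.
MIN_TOKEN = -(2**63)
MAX_TOKEN = 2**63 - 1

COUNT_QUERY = (
    "SELECT COUNT(*) FROM {ks}.{table} "
    "WHERE token({pk}) > ? AND token({pk}) <= ?"
)

def bounds_for_range(start: int, end: int) -> list[tuple[int, int]]:
    """Token bounds for the half-open range (start, end]."""
    if start < end:
        return [(start, end)]                      # normal range: one query
    return [(start, MAX_TOKEN), (MIN_TOKEN, end)]  # wraparound: two queries

# A range that wraps past the top of the ring turns into two COUNT queries
# whose results are summed by the caller.
for lower, upper in bounds_for_range(9_100_000_000_000_000_000,
                                     -9_100_000_000_000_000_000):
    print(COUNT_QUERY.format(ks="bulk_test", table="test_data", pk="id"),
          (lower, upper))
```

Counting each half of a wraparound range exactly once is why the assertions in these tests expect the token-range count to match the direct `SELECT COUNT(*)`.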
- """ - ) - - # Clear and insert fresh data - await session.execute("TRUNCATE bulk_test.test_data") - - row_count = 50000 - print(f"\nInserting {row_count} rows for parallel test...") - - batch_size = 500 - for i in range(0, row_count, batch_size): - tasks = [] - for j in range(batch_size): - if i + j < row_count: - tasks.append(session.execute(insert_stmt, (i + j, f"data-{i+j}", float(i + j)))) - await asyncio.gather(*tasks) - - operator = TokenAwareBulkOperator(session) - - # Test with different parallelism levels - import time - - results = [] - for parallelism in [1, 2, 4, 8]: - start_time = time.time() - - count = await operator.count_by_token_ranges( - keyspace="bulk_test", table="test_data", split_count=32, parallelism=parallelism - ) - - duration = time.time() - start_time - results.append( - { - "parallelism": parallelism, - "count": count, - "duration": duration, - "rows_per_sec": count / duration, - } - ) - - print(f"\nParallelism {parallelism}:") - print(f" Count: {count}") - print(f" Duration: {duration:.2f}s") - print(f" Rows/sec: {count/duration:,.0f}") - - # All counts should be identical - counts = [r["count"] for r in results] - assert len(set(counts)) == 1, f"Inconsistent counts: {counts}" - - # Higher parallelism should generally be faster - # (though not always due to overhead) - assert ( - results[-1]["duration"] < results[0]["duration"] * 1.5 - ), "Parallel execution not providing benefit" - - @pytest.mark.asyncio - async def test_count_with_progress_callback(self, session): - """ - Test progress callback during count operations. - - What this tests: - --------------- - 1. Progress callbacks are invoked correctly - 2. Stats are accurate and updated - 3. Progress percentage is calculated correctly - 4. Final stats match actual results - - Why this matters: - ---------------- - - Users need progress feedback for long operations - - Stats help with monitoring and debugging - - Progress tracking enables better UX - - Critical for production observability - """ - # Insert test data - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.test_data (id, data, value) - VALUES (?, ?, ?) 
- """ - ) - - expected_count = 5000 - for i in range(expected_count): - await session.execute(insert_stmt, (i, f"data-{i}", float(i))) - - operator = TokenAwareBulkOperator(session) - - # Track progress callbacks - progress_updates = [] - - def progress_callback(stats): - progress_updates.append( - { - "rows": stats.rows_processed, - "ranges_completed": stats.ranges_completed, - "total_ranges": stats.total_ranges, - "percentage": stats.progress_percentage, - } - ) - - # Count with progress tracking - count, stats = await operator.count_by_token_ranges_with_stats( - keyspace="bulk_test", - table="test_data", - split_count=8, - parallelism=4, - progress_callback=progress_callback, - ) - - print(f"\nProgress updates received: {len(progress_updates)}") - print(f"Final count: {count}") - print( - f"Final stats: rows={stats.rows_processed}, ranges={stats.ranges_completed}/{stats.total_ranges}" - ) - - # Verify results - assert count == expected_count, f"Count mismatch: {count} vs {expected_count}" - assert stats.rows_processed == expected_count - assert stats.ranges_completed == stats.total_ranges - assert stats.success is True - assert len(stats.errors) == 0 - assert len(progress_updates) > 0, "No progress callbacks received" - - # Verify progress increased monotonically - for i in range(1, len(progress_updates)): - assert ( - progress_updates[i]["ranges_completed"] - >= progress_updates[i - 1]["ranges_completed"] - ) diff --git a/examples/bulk_operations/tests/integration/test_bulk_export.py b/examples/bulk_operations/tests/integration/test_bulk_export.py deleted file mode 100644 index 35e5eef..0000000 --- a/examples/bulk_operations/tests/integration/test_bulk_export.py +++ /dev/null @@ -1,382 +0,0 @@ -""" -Integration tests for bulk export operations. - -What this tests: ---------------- -1. Export captures all rows exactly once -2. Streaming doesn't exhaust memory -3. Order within ranges is preserved -4. Async iteration works correctly -5. Export handles different data types - -Why this matters: ----------------- -- Export must be complete and accurate -- Memory efficiency critical for large tables -- Streaming enables TB-scale exports -- Foundation for Iceberg integration -""" - -import asyncio - -import pytest - -from async_cassandra import AsyncCluster -from bulk_operations.bulk_operator import TokenAwareBulkOperator - - -@pytest.mark.integration -class TestBulkExport: - """Test bulk export operations against real Cassandra cluster.""" - - @pytest.fixture - async def cluster(self): - """Create connection to test cluster.""" - cluster = AsyncCluster( - contact_points=["localhost"], - port=9042, - ) - yield cluster - await cluster.shutdown() - - @pytest.fixture - async def session(self, cluster): - """Create test session with keyspace and table.""" - session = await cluster.connect() - - # Create test keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS bulk_test - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - - # Create test table - await session.execute( - """ - CREATE TABLE IF NOT EXISTS bulk_test.test_data ( - id INT PRIMARY KEY, - data TEXT, - value DOUBLE - ) - """ - ) - - # Clear any existing data - await session.execute("TRUNCATE bulk_test.test_data") - - yield session - - @pytest.mark.asyncio - async def test_export_streaming_completeness(self, session): - """ - Test streaming export doesn't miss or duplicate data. - - What this tests: - --------------- - 1. Export captures all rows exactly once - 2. 
Streaming doesn't exhaust memory - 3. Order within ranges is preserved - 4. Async iteration works correctly - - Why this matters: - ---------------- - - Export must be complete and accurate - - Memory efficiency critical for large tables - - Streaming enables TB-scale exports - - Foundation for Iceberg integration - """ - # Use smaller dataset for export test - await session.execute("TRUNCATE bulk_test.test_data") - - # Insert test data - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.test_data (id, data, value) - VALUES (?, ?, ?) - """ - ) - - expected_ids = set(range(1000)) - for i in expected_ids: - await session.execute(insert_stmt, (i, f"data-{i}", float(i))) - - # Export using token ranges - operator = TokenAwareBulkOperator(session) - - exported_ids = set() - row_count = 0 - - async for row in operator.export_by_token_ranges( - keyspace="bulk_test", table="test_data", split_count=16 - ): - exported_ids.add(row.id) - row_count += 1 - - # Verify row data integrity - assert row.data == f"data-{row.id}" - assert row.value == float(row.id) - - print("\nExport results:") - print(f" Expected rows: {len(expected_ids)}") - print(f" Exported rows: {row_count}") - print(f" Unique IDs: {len(exported_ids)}") - - # Verify completeness - assert row_count == len( - expected_ids - ), f"Row count mismatch: {row_count} vs {len(expected_ids)}" - - assert exported_ids == expected_ids, ( - f"Missing IDs: {expected_ids - exported_ids}, " - f"Duplicate IDs: {exported_ids - expected_ids}" - ) - - @pytest.mark.asyncio - async def test_export_with_wraparound_ranges(self, session): - """ - Test export handles wraparound ranges correctly. - - What this tests: - --------------- - 1. Data in wraparound ranges is exported - 2. No duplicates from split queries - 3. All edge cases handled - 4. Consistent with count operation - - Why this matters: - ---------------- - - Wraparound ranges are common with vnodes - - Export must handle same edge cases as count - - Data integrity is critical - - Foundation for all bulk operations - """ - # Insert data that will span wraparound ranges - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.test_data (id, data, value) - VALUES (?, ?, ?) - """ - ) - - # Insert data with various IDs to ensure coverage - test_data = {} - for i in range(0, 10000, 100): # Sparse data to hit various ranges - test_data[i] = f"data-{i}" - await session.execute(insert_stmt, (i, test_data[i], float(i))) - - # Export and verify - operator = TokenAwareBulkOperator(session) - - exported_data = {} - async for row in operator.export_by_token_ranges( - keyspace="bulk_test", - table="test_data", - split_count=32, # More splits to ensure wraparound handling - ): - exported_data[row.id] = row.data - - print(f"\nExported {len(exported_data)} rows") - assert len(exported_data) == len( - test_data - ), f"Export count mismatch: {len(exported_data)} vs {len(test_data)}" - - # Verify all data was exported correctly - for id_val, expected_data in test_data.items(): - assert id_val in exported_data, f"Missing ID {id_val}" - assert ( - exported_data[id_val] == expected_data - ), f"Data mismatch for ID {id_val}: {exported_data[id_val]} vs {expected_data}" - - @pytest.mark.asyncio - async def test_export_memory_efficiency(self, session): - """ - Test export streaming is memory efficient. - - What this tests: - --------------- - 1. Large exports don't consume excessive memory - 2. Streaming works as expected - 3. Can handle tables larger than memory - 4. 
Progress tracking during export - - Why this matters: - ---------------- - - Production tables can be TB in size - - Must stream, not buffer all data - - Memory efficiency enables large exports - - Critical for operational feasibility - """ - # Insert larger dataset - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.test_data (id, data, value) - VALUES (?, ?, ?) - """ - ) - - row_count = 10000 - print(f"\nInserting {row_count} rows for memory test...") - - # Insert in batches - batch_size = 100 - for i in range(0, row_count, batch_size): - tasks = [] - for j in range(batch_size): - if i + j < row_count: - # Create larger data values to test memory - data = f"data-{i+j}" * 10 # Make data larger - tasks.append(session.execute(insert_stmt, (i + j, data, float(i + j)))) - await asyncio.gather(*tasks) - - operator = TokenAwareBulkOperator(session) - - # Track memory usage indirectly via row processing rate - rows_exported = 0 - batch_timings = [] - - import time - - start_time = time.time() - last_batch_time = start_time - - async for _row in operator.export_by_token_ranges( - keyspace="bulk_test", table="test_data", split_count=16 - ): - rows_exported += 1 - - # Track timing every 1000 rows - if rows_exported % 1000 == 0: - current_time = time.time() - batch_duration = current_time - last_batch_time - batch_timings.append(batch_duration) - last_batch_time = current_time - print(f" Exported {rows_exported} rows...") - - total_duration = time.time() - start_time - - print("\nExport completed:") - print(f" Total rows: {rows_exported}") - print(f" Total time: {total_duration:.2f}s") - print(f" Rows/sec: {rows_exported/total_duration:.0f}") - - # Verify all rows exported - assert rows_exported == row_count, f"Export count mismatch: {rows_exported} vs {row_count}" - - # Verify consistent performance (no major slowdowns from memory pressure) - if len(batch_timings) > 2: - avg_batch_time = sum(batch_timings) / len(batch_timings) - max_batch_time = max(batch_timings) - assert ( - max_batch_time < avg_batch_time * 3 - ), "Export performance degraded, possible memory issue" - - @pytest.mark.asyncio - async def test_export_with_different_data_types(self, session): - """ - Test export handles various CQL data types correctly. - - What this tests: - --------------- - 1. Different data types are exported correctly - 2. NULL values handled properly - 3. Collections exported accurately - 4. Special characters preserved - - Why this matters: - ---------------- - - Real tables have diverse data types - - Export must preserve data fidelity - - Type handling affects Iceberg mapping - - Data integrity across formats - """ - # Create table with various data types - await session.execute( - """ - CREATE TABLE IF NOT EXISTS bulk_test.complex_data ( - id INT PRIMARY KEY, - text_col TEXT, - int_col INT, - double_col DOUBLE, - bool_col BOOLEAN, - list_col LIST, - set_col SET, - map_col MAP - ) - """ - ) - - await session.execute("TRUNCATE bulk_test.complex_data") - - # Insert test data with various types - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.complex_data - (id, text_col, int_col, double_col, bool_col, list_col, set_col, map_col) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
- """ - ) - - test_data = [ - (1, "normal text", 100, 1.5, True, ["a", "b", "c"], {1, 2, 3}, {"x": 1, "y": 2}), - (2, "special chars: 'quotes' \"double\" \n newline", -50, -2.5, False, [], set(), {}), - (3, None, None, None, None, None, None, None), # NULL values - (4, "", 0, 0.0, True, [""], {0}, {"": 0}), # Empty/zero values - (5, "unicode: 你好 🌟", 999999, 3.14159, False, ["α", "β", "γ"], {-1, -2}, {"π": 314}), - ] - - for row in test_data: - await session.execute(insert_stmt, row) - - # Export and verify - operator = TokenAwareBulkOperator(session) - - exported_rows = [] - async for row in operator.export_by_token_ranges( - keyspace="bulk_test", table="complex_data", split_count=4 - ): - exported_rows.append(row) - - print(f"\nExported {len(exported_rows)} rows with complex data types") - assert len(exported_rows) == len( - test_data - ), f"Export count mismatch: {len(exported_rows)} vs {len(test_data)}" - - # Sort both by ID for comparison - exported_rows.sort(key=lambda r: r.id) - test_data.sort(key=lambda r: r[0]) - - # Verify each row's data - for exported, expected in zip(exported_rows, test_data, strict=False): - assert exported.id == expected[0] - assert exported.text_col == expected[1] - assert exported.int_col == expected[2] - assert exported.double_col == expected[3] - assert exported.bool_col == expected[4] - - # Collections need special handling - # Note: Cassandra treats empty collections as NULL - if expected[5] is not None and expected[5] != []: - assert exported.list_col is not None, f"list_col is None for row {exported.id}" - assert list(exported.list_col) == expected[5] - else: - # Empty list or None in Cassandra returns as None - assert exported.list_col is None - - if expected[6] is not None and expected[6] != set(): - assert exported.set_col is not None, f"set_col is None for row {exported.id}" - assert set(exported.set_col) == expected[6] - else: - # Empty set or None in Cassandra returns as None - assert exported.set_col is None - - if expected[7] is not None and expected[7] != {}: - assert exported.map_col is not None, f"map_col is None for row {exported.id}" - assert dict(exported.map_col) == expected[7] - else: - # Empty map or None in Cassandra returns as None - assert exported.map_col is None diff --git a/examples/bulk_operations/tests/integration/test_data_integrity.py b/examples/bulk_operations/tests/integration/test_data_integrity.py deleted file mode 100644 index 1e82a58..0000000 --- a/examples/bulk_operations/tests/integration/test_data_integrity.py +++ /dev/null @@ -1,466 +0,0 @@ -""" -Integration tests for data integrity - verifying inserted data is correctly returned. - -What this tests: ---------------- -1. Data inserted is exactly what gets exported -2. All data types are preserved correctly -3. No data corruption during token range queries -4. 
Prepared statements maintain data integrity - -Why this matters: ----------------- -- Proves end-to-end data correctness -- Validates our token range implementation -- Ensures no data loss or corruption -- Critical for production confidence -""" - -import asyncio -import uuid -from datetime import datetime -from decimal import Decimal - -import pytest - -from async_cassandra import AsyncCluster -from bulk_operations.bulk_operator import TokenAwareBulkOperator - - -@pytest.mark.integration -class TestDataIntegrity: - """Test that data inserted equals data exported.""" - - @pytest.fixture - async def cluster(self): - """Create connection to test cluster.""" - cluster = AsyncCluster( - contact_points=["localhost"], - port=9042, - ) - yield cluster - await cluster.shutdown() - - @pytest.fixture - async def session(self, cluster): - """Create test session with keyspace and tables.""" - session = await cluster.connect() - - # Create test keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS bulk_test - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - - yield session - - @pytest.mark.asyncio - async def test_simple_data_round_trip(self, session): - """ - Test that simple data inserted is exactly what we get back. - - What this tests: - --------------- - 1. Insert known dataset with various values - 2. Export using token ranges - 3. Verify every field matches exactly - 4. No missing or corrupted data - - Why this matters: - ---------------- - - Basic data integrity validation - - Ensures token range queries don't corrupt data - - Validates prepared statement parameter handling - - Foundation for trusting bulk operations - """ - # Create a simple test table - await session.execute( - """ - CREATE TABLE IF NOT EXISTS bulk_test.integrity_test ( - id INT PRIMARY KEY, - name TEXT, - value DOUBLE, - active BOOLEAN - ) - """ - ) - - await session.execute("TRUNCATE bulk_test.integrity_test") - - # Insert test data with prepared statement - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.integrity_test (id, name, value, active) - VALUES (?, ?, ?, ?) 
- """ - ) - - # Create test dataset with various values - test_data = [ - (1, "Alice", 100.5, True), - (2, "Bob", -50.25, False), - (3, "Charlie", 0.0, True), - (4, None, 999.999, None), # Test NULLs - (5, "", -0.001, False), # Empty string - (6, "Special chars: 'quotes' \"double\"", 3.14159, True), - (7, "Unicode: 你好 🌟", 2.71828, False), - (8, "Very long name " * 100, 1.23456, True), # Long string - ] - - # Insert all test data - for row in test_data: - await session.execute(insert_stmt, row) - - # Export using bulk operator - operator = TokenAwareBulkOperator(session) - exported_data = [] - - async for row in operator.export_by_token_ranges( - keyspace="bulk_test", - table="integrity_test", - split_count=4, # Use multiple ranges to test splitting - ): - exported_data.append((row.id, row.name, row.value, row.active)) - - # Sort both datasets by ID for comparison - test_data_sorted = sorted(test_data, key=lambda x: x[0]) - exported_data_sorted = sorted(exported_data, key=lambda x: x[0]) - - # Verify we got all rows - assert len(exported_data_sorted) == len( - test_data_sorted - ), f"Row count mismatch: exported {len(exported_data_sorted)} vs inserted {len(test_data_sorted)}" - - # Verify each row matches exactly - for inserted, exported in zip(test_data_sorted, exported_data_sorted, strict=False): - assert ( - inserted == exported - ), f"Data mismatch for ID {inserted[0]}: inserted {inserted} vs exported {exported}" - - print(f"\n✓ All {len(test_data)} rows verified - data integrity maintained") - - @pytest.mark.asyncio - async def test_complex_data_types_round_trip(self, session): - """ - Test complex CQL data types maintain integrity. - - What this tests: - --------------- - 1. Collections (list, set, map) - 2. UUID types - 3. Timestamp/date types - 4. Decimal types - 5. Large text/blob data - - Why this matters: - ---------------- - - Real tables use complex types - - Collections need special handling - - Precision must be maintained - - Production data is complex - """ - # Create table with complex types - await session.execute( - """ - CREATE TABLE IF NOT EXISTS bulk_test.complex_integrity ( - id UUID PRIMARY KEY, - created TIMESTAMP, - amount DECIMAL, - tags SET, - metadata MAP, - events LIST, - data BLOB - ) - """ - ) - - await session.execute("TRUNCATE bulk_test.complex_integrity") - - # Insert test data - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.complex_integrity - (id, created, amount, tags, metadata, events, data) - VALUES (?, ?, ?, ?, ?, ?, ?) 
- """ - ) - - # Create test data - test_id = uuid.uuid4() - test_created = datetime.utcnow().replace(microsecond=0) # Cassandra timestamp precision - test_amount = Decimal("12345.6789") - test_tags = {"python", "cassandra", "async", "test"} - test_metadata = {"version": 1, "retries": 3, "timeout": 30} - test_events = [ - datetime(2024, 1, 1, 10, 0, 0), - datetime(2024, 1, 2, 11, 30, 0), - datetime(2024, 1, 3, 15, 45, 0), - ] - test_data = b"Binary data with \x00 null bytes and \xff high bytes" - - # Insert the data - await session.execute( - insert_stmt, - ( - test_id, - test_created, - test_amount, - test_tags, - test_metadata, - test_events, - test_data, - ), - ) - - # Export and verify - operator = TokenAwareBulkOperator(session) - exported_rows = [] - - async for row in operator.export_by_token_ranges( - keyspace="bulk_test", - table="complex_integrity", - split_count=2, - ): - exported_rows.append(row) - - # Should have exactly one row - assert len(exported_rows) == 1, f"Expected 1 row, got {len(exported_rows)}" - - row = exported_rows[0] - - # Verify each field - assert row.id == test_id, f"UUID mismatch: {row.id} vs {test_id}" - assert row.created == test_created, f"Timestamp mismatch: {row.created} vs {test_created}" - assert row.amount == test_amount, f"Decimal mismatch: {row.amount} vs {test_amount}" - assert set(row.tags) == test_tags, f"Set mismatch: {set(row.tags)} vs {test_tags}" - assert ( - dict(row.metadata) == test_metadata - ), f"Map mismatch: {dict(row.metadata)} vs {test_metadata}" - assert ( - list(row.events) == test_events - ), f"List mismatch: {list(row.events)} vs {test_events}" - assert bytes(row.data) == test_data, f"Blob mismatch: {bytes(row.data)} vs {test_data}" - - print("\n✓ Complex data types verified - all types preserved correctly") - - @pytest.mark.asyncio - async def test_large_dataset_integrity(self, session): # noqa: C901 - """ - Test integrity with larger dataset across many token ranges. - - What this tests: - --------------- - 1. 50K rows with computed values - 2. Verify no rows lost in token ranges - 3. Verify no duplicate rows - 4. Check computed values match - - Why this matters: - ---------------- - - Production tables are large - - Token range bugs appear at scale - - Wraparound ranges must work correctly - - Performance under load - """ - # Create table - await session.execute( - """ - CREATE TABLE IF NOT EXISTS bulk_test.large_integrity ( - id INT PRIMARY KEY, - computed_value DOUBLE, - hash_value TEXT - ) - """ - ) - - await session.execute("TRUNCATE bulk_test.large_integrity") - - # Insert data with computed values - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.large_integrity (id, computed_value, hash_value) - VALUES (?, ?, ?) 
- """ - ) - - # Function to compute expected values - def compute_value(id_val): - return float(id_val * 3.14159 + id_val**0.5) - - def compute_hash(id_val): - return f"hash_{id_val % 1000:03d}_{id_val}" - - # Insert 50K rows in batches - total_rows = 50000 - batch_size = 1000 - - print(f"\nInserting {total_rows} rows for large dataset test...") - - for batch_start in range(0, total_rows, batch_size): - tasks = [] - for i in range(batch_start, min(batch_start + batch_size, total_rows)): - tasks.append( - session.execute( - insert_stmt, - ( - i, - compute_value(i), - compute_hash(i), - ), - ) - ) - await asyncio.gather(*tasks) - - if (batch_start + batch_size) % 10000 == 0: - print(f" Inserted {batch_start + batch_size} rows...") - - # Export all data - operator = TokenAwareBulkOperator(session) - exported_ids = set() - value_mismatches = [] - hash_mismatches = [] - - print("\nExporting and verifying data...") - - async for row in operator.export_by_token_ranges( - keyspace="bulk_test", - table="large_integrity", - split_count=32, # Many splits to test range handling - ): - # Check for duplicates - if row.id in exported_ids: - pytest.fail(f"Duplicate ID exported: {row.id}") - exported_ids.add(row.id) - - # Verify computed values - expected_value = compute_value(row.id) - if abs(row.computed_value - expected_value) > 0.0001: # Float precision - value_mismatches.append((row.id, row.computed_value, expected_value)) - - expected_hash = compute_hash(row.id) - if row.hash_value != expected_hash: - hash_mismatches.append((row.id, row.hash_value, expected_hash)) - - # Verify completeness - assert ( - len(exported_ids) == total_rows - ), f"Missing rows: exported {len(exported_ids)} vs inserted {total_rows}" - - # Check for missing IDs - expected_ids = set(range(total_rows)) - missing_ids = expected_ids - exported_ids - if missing_ids: - pytest.fail(f"Missing IDs: {sorted(list(missing_ids))[:10]}...") # Show first 10 - - # Check for value mismatches - if value_mismatches: - pytest.fail(f"Value mismatches found: {value_mismatches[:5]}...") # Show first 5 - - if hash_mismatches: - pytest.fail(f"Hash mismatches found: {hash_mismatches[:5]}...") # Show first 5 - - print(f"\n✓ All {total_rows} rows verified - large dataset integrity maintained") - print(" - No missing rows") - print(" - No duplicate rows") - print(" - All computed values correct") - print(" - All hash values correct") - - @pytest.mark.asyncio - async def test_wraparound_range_data_integrity(self, session): - """ - Test data integrity specifically for wraparound token ranges. - - What this tests: - --------------- - 1. Insert data with known tokens that span wraparound - 2. Verify wraparound range handling preserves data - 3. No data lost at ring boundaries - 4. Prepared statements work correctly with wraparound - - Why this matters: - ---------------- - - Wraparound ranges are error-prone - - Must split into two queries correctly - - Data at ring boundaries is critical - - Common source of data loss bugs - """ - # Create table - await session.execute( - """ - CREATE TABLE IF NOT EXISTS bulk_test.wraparound_test ( - id INT PRIMARY KEY, - token_value BIGINT, - data TEXT - ) - """ - ) - - await session.execute("TRUNCATE bulk_test.wraparound_test") - - # First, let's find some IDs that hash to extreme token values - print("\nFinding IDs with extreme token values...") - - # Insert some data and check their tokens - insert_stmt = await session.prepare( - """ - INSERT INTO bulk_test.wraparound_test (id, token_value, data) - VALUES (?, ?, ?) 
- """ - ) - - # Try different IDs to find ones with extreme tokens - test_ids = [] - for i in range(100000, 200000): - # First insert a dummy row to query the token - await session.execute(insert_stmt, (i, 0, f"dummy_{i}")) - result = await session.execute( - f"SELECT token(id) as t FROM bulk_test.wraparound_test WHERE id = {i}" - ) - row = result.one() - if row: - token = row.t - # Remove the dummy row - await session.execute(f"DELETE FROM bulk_test.wraparound_test WHERE id = {i}") - - # Look for very high positive or very low negative tokens - if token > 9000000000000000000 or token < -9000000000000000000: - test_ids.append((i, token)) - await session.execute(insert_stmt, (i, token, f"data_{i}")) - - if len(test_ids) >= 20: - break - - print(f" Found {len(test_ids)} IDs with extreme tokens") - - # Export and verify - operator = TokenAwareBulkOperator(session) - exported_data = {} - - async for row in operator.export_by_token_ranges( - keyspace="bulk_test", - table="wraparound_test", - split_count=8, - ): - exported_data[row.id] = (row.token_value, row.data) - - # Verify all data was exported - for id_val, token_val in test_ids: - assert id_val in exported_data, f"Missing ID {id_val} with token {token_val}" - - exported_token, exported_data_val = exported_data[id_val] - assert ( - exported_token == token_val - ), f"Token mismatch for ID {id_val}: {exported_token} vs {token_val}" - assert ( - exported_data_val == f"data_{id_val}" - ), f"Data mismatch for ID {id_val}: {exported_data_val} vs data_{id_val}" - - print("\n✓ Wraparound range data integrity verified") - print(f" - All {len(test_ids)} extreme token rows exported correctly") - print(" - Token values preserved") - print(" - Data values preserved") diff --git a/examples/bulk_operations/tests/integration/test_export_formats.py b/examples/bulk_operations/tests/integration/test_export_formats.py deleted file mode 100644 index eedf0ee..0000000 --- a/examples/bulk_operations/tests/integration/test_export_formats.py +++ /dev/null @@ -1,449 +0,0 @@ -""" -Integration tests for export formats. - -What this tests: ---------------- -1. CSV export with real data -2. JSON export formats (JSONL and array) -3. Parquet export with schema mapping -4. Compression options -5. 
Data integrity across formats - -Why this matters: ----------------- -- Export formats are critical for data pipelines -- Each format has different use cases -- Parquet is foundation for Iceberg -- Must preserve data types correctly -""" - -import csv -import gzip -import json - -import pytest - -try: - import pyarrow.parquet as pq - - PYARROW_AVAILABLE = True -except ImportError: - PYARROW_AVAILABLE = False - -from async_cassandra import AsyncCluster -from bulk_operations.bulk_operator import TokenAwareBulkOperator - - -@pytest.mark.integration -class TestExportFormats: - """Test export to different formats.""" - - @pytest.fixture - async def cluster(self): - """Create connection to test cluster.""" - cluster = AsyncCluster( - contact_points=["localhost"], - port=9042, - ) - yield cluster - await cluster.shutdown() - - @pytest.fixture - async def session(self, cluster): - """Create test session with test data.""" - session = await cluster.connect() - - # Create test keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS export_test - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - - # Create test table with various types - await session.execute( - """ - CREATE TABLE IF NOT EXISTS export_test.data_types ( - id INT PRIMARY KEY, - text_val TEXT, - int_val INT, - float_val FLOAT, - bool_val BOOLEAN, - list_val LIST, - set_val SET, - map_val MAP, - null_val TEXT - ) - """ - ) - - # Clear and insert test data - await session.execute("TRUNCATE export_test.data_types") - - insert_stmt = await session.prepare( - """ - INSERT INTO export_test.data_types - (id, text_val, int_val, float_val, bool_val, - list_val, set_val, map_val, null_val) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """ - ) - - # Insert diverse test data - test_data = [ - (1, "test1", 100, 1.5, True, ["a", "b"], {1, 2}, {"k1": "v1"}, None), - (2, "test2", -50, -2.5, False, [], None, {}, None), - (3, "special'chars\"test", 0, 0.0, True, None, {0}, None, None), - (4, "unicode_test_你好", 999, 3.14, False, ["x"], {-1}, {"k": "v"}, None), - ] - - for row in test_data: - await session.execute(insert_stmt, row) - - yield session - - @pytest.mark.asyncio - async def test_csv_export_basic(self, session, tmp_path): - """ - Test basic CSV export functionality. - - What this tests: - --------------- - 1. CSV export creates valid file - 2. All rows are exported - 3. Data types are properly serialized - 4. NULL values handled correctly - - Why this matters: - ---------------- - - CSV is most common export format - - Must work with Excel and other tools - - Data integrity is critical - """ - operator = TokenAwareBulkOperator(session) - output_path = tmp_path / "test.csv" - - # Export to CSV - result = await operator.export_to_csv( - keyspace="export_test", - table="data_types", - output_path=output_path, - ) - - # Verify file exists - assert output_path.exists() - assert result.rows_exported == 4 - - # Read and verify content - with open(output_path) as f: - reader = csv.DictReader(f) - rows = list(reader) - - assert len(rows) == 4 - - # Verify first row - row1 = rows[0] - assert row1["id"] == "1" - assert row1["text_val"] == "test1" - assert row1["int_val"] == "100" - assert row1["float_val"] == "1.5" - assert row1["bool_val"] == "true" - assert "[a, b]" in row1["list_val"] - assert row1["null_val"] == "" # Default NULL representation - - @pytest.mark.asyncio - async def test_csv_export_compressed(self, session, tmp_path): - """ - Test CSV export with compression. 
- - What this tests: - --------------- - 1. Gzip compression works - 2. File has correct extension - 3. Compressed data is valid - 4. Size reduction achieved - - Why this matters: - ---------------- - - Large exports need compression - - Network transfer efficiency - - Storage cost reduction - """ - operator = TokenAwareBulkOperator(session) - output_path = tmp_path / "test.csv" - - # Export with compression - await operator.export_to_csv( - keyspace="export_test", - table="data_types", - output_path=output_path, - compression="gzip", - ) - - # Verify compressed file - compressed_path = output_path.with_suffix(".csv.gzip") - assert compressed_path.exists() - - # Read compressed content - with gzip.open(compressed_path, "rt") as f: - reader = csv.DictReader(f) - rows = list(reader) - - assert len(rows) == 4 - - @pytest.mark.asyncio - async def test_json_export_line_delimited(self, session, tmp_path): - """ - Test JSON line-delimited export. - - What this tests: - --------------- - 1. JSONL format (one JSON per line) - 2. Each line is valid JSON - 3. Data types preserved - 4. Collections handled correctly - - Why this matters: - ---------------- - - JSONL works with streaming tools - - Each line can be processed independently - - Better for large datasets - """ - operator = TokenAwareBulkOperator(session) - output_path = tmp_path / "test.jsonl" - - # Export as JSONL - result = await operator.export_to_json( - keyspace="export_test", - table="data_types", - output_path=output_path, - format_mode="jsonl", - ) - - assert output_path.exists() - assert result.rows_exported == 4 - - # Read and verify JSONL - with open(output_path) as f: - lines = f.readlines() - - assert len(lines) == 4 - - # Parse each line - rows = [json.loads(line) for line in lines] - - # Verify data types - row1 = rows[0] - assert row1["id"] == 1 - assert row1["text_val"] == "test1" - assert row1["bool_val"] is True - assert row1["list_val"] == ["a", "b"] - assert row1["set_val"] == [1, 2] # Sets become lists in JSON - assert row1["map_val"] == {"k1": "v1"} - assert row1["null_val"] is None - - @pytest.mark.asyncio - async def test_json_export_array(self, session, tmp_path): - """ - Test JSON array export. - - What this tests: - --------------- - 1. Valid JSON array format - 2. Proper array structure - 3. Pretty printing option - 4. Complete document - - Why this matters: - ---------------- - - Some APIs expect JSON arrays - - Easier for small datasets - - Human readable with indent - """ - operator = TokenAwareBulkOperator(session) - output_path = tmp_path / "test.json" - - # Export as JSON array - await operator.export_to_json( - keyspace="export_test", - table="data_types", - output_path=output_path, - format_mode="array", - indent=2, - ) - - assert output_path.exists() - - # Read and parse JSON - with open(output_path) as f: - data = json.load(f) - - assert isinstance(data, list) - assert len(data) == 4 - - # Verify structure - assert all(isinstance(row, dict) for row in data) - - @pytest.mark.asyncio - @pytest.mark.skipif(not PYARROW_AVAILABLE, reason="PyArrow not installed") - async def test_parquet_export(self, session, tmp_path): - """ - Test Parquet export - foundation for Iceberg. - - What this tests: - --------------- - 1. Valid Parquet file created - 2. Schema correctly mapped - 3. Data types preserved - 4. 
Row groups created - - Why this matters: - ---------------- - - Parquet is THE format for Iceberg - - Columnar storage for analytics - - Schema evolution support - - Excellent compression - """ - operator = TokenAwareBulkOperator(session) - output_path = tmp_path / "test.parquet" - - # Export to Parquet - result = await operator.export_to_parquet( - keyspace="export_test", - table="data_types", - output_path=output_path, - row_group_size=2, # Small for testing - ) - - assert output_path.exists() - assert result.rows_exported == 4 - - # Read Parquet file - table = pq.read_table(output_path) - - # Verify schema - schema = table.schema - assert "id" in schema.names - assert "text_val" in schema.names - assert "bool_val" in schema.names - - # Verify data - df = table.to_pandas() - assert len(df) == 4 - - # Check data types preserved - assert df.loc[0, "id"] == 1 - assert df.loc[0, "text_val"] == "test1" - assert df.loc[0, "bool_val"] is True or df.loc[0, "bool_val"] == 1 # numpy bool comparison - - # Verify row groups - parquet_file = pq.ParquetFile(output_path) - assert parquet_file.num_row_groups == 2 # 4 rows / 2 per group - - @pytest.mark.asyncio - async def test_export_with_column_selection(self, session, tmp_path): - """ - Test exporting specific columns only. - - What this tests: - --------------- - 1. Column selection works - 2. Only selected columns exported - 3. Order preserved - 4. Works across all formats - - Why this matters: - ---------------- - - Reduce export size - - Privacy/security (exclude sensitive columns) - - Performance optimization - """ - operator = TokenAwareBulkOperator(session) - columns = ["id", "text_val", "bool_val"] - - # Test CSV - csv_path = tmp_path / "selected.csv" - await operator.export_to_csv( - keyspace="export_test", - table="data_types", - output_path=csv_path, - columns=columns, - ) - - with open(csv_path) as f: - reader = csv.DictReader(f) - row = next(reader) - assert set(row.keys()) == set(columns) - - # Test JSON - json_path = tmp_path / "selected.jsonl" - await operator.export_to_json( - keyspace="export_test", - table="data_types", - output_path=json_path, - columns=columns, - ) - - with open(json_path) as f: - row = json.loads(f.readline()) - assert set(row.keys()) == set(columns) - - @pytest.mark.asyncio - async def test_export_progress_tracking(self, session, tmp_path): - """ - Test progress tracking and resume capability. - - What this tests: - --------------- - 1. Progress callbacks invoked - 2. Progress saved to file - 3. Resume information correct - 4. 
Stats accurately tracked - - Why this matters: - ---------------- - - Long exports need monitoring - - Resume saves time on failures - - Users need feedback - """ - operator = TokenAwareBulkOperator(session) - output_path = tmp_path / "progress_test.csv" - - progress_updates = [] - - async def track_progress(progress): - progress_updates.append( - { - "rows": progress.rows_exported, - "bytes": progress.bytes_written, - "percentage": progress.progress_percentage, - } - ) - - # Export with progress tracking - result = await operator.export_to_csv( - keyspace="export_test", - table="data_types", - output_path=output_path, - progress_callback=track_progress, - ) - - # Verify progress was tracked - assert len(progress_updates) > 0 - assert result.rows_exported == 4 - assert result.bytes_written > 0 - - # Verify progress file - progress_file = output_path.with_suffix(".csv.progress") - assert progress_file.exists() - - # Load and verify progress - from bulk_operations.exporters import ExportProgress - - loaded = ExportProgress.load(progress_file) - assert loaded.rows_exported == 4 - assert loaded.is_complete diff --git a/examples/bulk_operations/tests/integration/test_token_discovery.py b/examples/bulk_operations/tests/integration/test_token_discovery.py deleted file mode 100644 index b99115f..0000000 --- a/examples/bulk_operations/tests/integration/test_token_discovery.py +++ /dev/null @@ -1,198 +0,0 @@ -""" -Integration tests for token range discovery with vnodes. - -What this tests: ---------------- -1. Token range discovery matches cluster vnodes configuration -2. Validation against nodetool describering output -3. Token distribution across nodes -4. Non-overlapping and complete token coverage - -Why this matters: ----------------- -- Vnodes create hundreds of non-contiguous ranges -- Token metadata must match cluster reality -- Incorrect discovery means data loss -- Production clusters always use vnodes -""" - -import subprocess -from collections import defaultdict - -import pytest - -from async_cassandra import AsyncCluster -from bulk_operations.token_utils import TOTAL_TOKEN_RANGE, discover_token_ranges - - -@pytest.mark.integration -class TestTokenDiscovery: - """Test token range discovery against real Cassandra cluster.""" - - @pytest.fixture - async def cluster(self): - """Create connection to test cluster.""" - # Connect to all three nodes - cluster = AsyncCluster( - contact_points=["localhost", "127.0.0.1", "127.0.0.2"], - port=9042, - ) - yield cluster - await cluster.shutdown() - - @pytest.fixture - async def session(self, cluster): - """Create test session with keyspace.""" - session = await cluster.connect() - - # Create test keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS bulk_test - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 3 - } - """ - ) - - yield session - - @pytest.mark.asyncio - async def test_token_range_discovery_with_vnodes(self, session): - """ - Test token range discovery matches cluster vnodes configuration. - - What this tests: - --------------- - 1. Number of ranges matches vnode configuration - 2. Each node owns approximately equal ranges - 3. All ranges have correct replica information - 4. 
Token ranges are non-overlapping and complete - - Why this matters: - ---------------- - - With 256 vnodes × 3 nodes = ~768 ranges expected - - Vnodes distribute ownership across the ring - - Incorrect discovery means data loss - - Must handle non-contiguous ownership correctly - """ - ranges = await discover_token_ranges(session, "bulk_test") - - # With 3 nodes and 256 vnodes each, expect many ranges - # Due to replication factor 3, each range has 3 replicas - assert len(ranges) > 100, f"Expected many ranges with vnodes, got {len(ranges)}" - - # Count ranges per node - ranges_per_node = defaultdict(int) - for r in ranges: - for replica in r.replicas: - ranges_per_node[replica] += 1 - - print(f"\nToken ranges discovered: {len(ranges)}") - print("Ranges per node:") - for node, count in sorted(ranges_per_node.items()): - print(f" {node}: {count} ranges") - - # Each node should own approximately the same number of ranges - counts = list(ranges_per_node.values()) - if len(counts) >= 3: - avg_count = sum(counts) / len(counts) - for count in counts: - # Allow 20% variance - assert ( - 0.8 * avg_count <= count <= 1.2 * avg_count - ), f"Uneven distribution: {ranges_per_node}" - - # Verify ranges cover the entire ring - sorted_ranges = sorted(ranges, key=lambda r: r.start) - - # With vnodes, tokens are randomly distributed, so the first range - # won't necessarily start at MIN_TOKEN. What matters is: - # 1. No gaps between consecutive ranges - # 2. The last range wraps around to the first range - # 3. Total coverage equals the token space - - # Check for gaps or overlaps between consecutive ranges - gaps = 0 - for i in range(len(sorted_ranges) - 1): - current = sorted_ranges[i] - next_range = sorted_ranges[i + 1] - - # Ranges should be contiguous - if current.end != next_range.start: - gaps += 1 - print(f"Gap found: {current.end} to {next_range.start}") - - assert gaps == 0, f"Found {gaps} gaps in token ranges" - - # Verify the last range wraps around to the first - assert sorted_ranges[-1].end == sorted_ranges[0].start, ( - f"Ring not closed: last range ends at {sorted_ranges[-1].end}, " - f"first range starts at {sorted_ranges[0].start}" - ) - - # Verify total coverage - total_size = sum(r.size for r in ranges) - # Allow for small rounding differences - assert abs(total_size - TOTAL_TOKEN_RANGE) <= len( - ranges - ), f"Total coverage {total_size} differs from expected {TOTAL_TOKEN_RANGE}" - - @pytest.mark.asyncio - async def test_compare_with_nodetool_describering(self, session): - """ - Compare discovered ranges with nodetool describering output. - - What this tests: - --------------- - 1. Our discovery matches nodetool output - 2. Token boundaries are correct - 3. Replica assignments match - 4. 
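The coverage assertion in the test above sums `r.size` over every discovered range and compares the total with `TOTAL_TOKEN_RANGE` from `bulk_operations.token_utils`. That module is not part of this diff, so the following is only a sketch of the arithmetic it presumably implements; the constants and the wraparound branch are assumptions consistent with the Murmur3 token space.

```python
# Sketch of the ring-coverage arithmetic (assumed definitions, not the
# library's actual code).
MIN_TOKEN = -(2**63)
MAX_TOKEN = 2**63 - 1
TOTAL_TOKEN_RANGE = MAX_TOKEN - MIN_TOKEN  # size of the full Murmur3 ring

def range_size(start: int, end: int) -> int:
    """Number of tokens in the half-open range (start, end]."""
    if end > start:
        return end - start
    # Wraparound range: the piece up to MAX_TOKEN plus the piece from MIN_TOKEN.
    return (MAX_TOKEN - start) + (end - MIN_TOKEN)

# Summing range_size over all discovered ranges should come out within a few
# tokens of TOTAL_TOKEN_RANGE, which is what the assertion checks.
```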
No missing or extra ranges - - Why this matters: - ---------------- - - nodetool is the source of truth - - Mismatches indicate bugs in discovery - - Critical for production reliability - - Validates driver metadata accuracy - """ - ranges = await discover_token_ranges(session, "bulk_test") - - # Get nodetool output from first node - try: - result = subprocess.run( - ["podman", "exec", "bulk-cassandra-1", "nodetool", "describering", "bulk_test"], - capture_output=True, - text=True, - check=True, - ) - nodetool_output = result.stdout - except subprocess.CalledProcessError: - # Try docker if podman fails - try: - result = subprocess.run( - ["docker", "exec", "bulk-cassandra-1", "nodetool", "describering", "bulk_test"], - capture_output=True, - text=True, - check=True, - ) - nodetool_output = result.stdout - except subprocess.CalledProcessError as e: - pytest.skip(f"Cannot run nodetool: {e}") - - print("\nNodetool describering output (first 20 lines):") - print("\n".join(nodetool_output.split("\n")[:20])) - - # Parse token count from nodetool output - token_ranges_in_output = nodetool_output.count("TokenRange") - - print("\nComparison:") - print(f" Discovered ranges: {len(ranges)}") - print(f" Nodetool ranges: {token_ranges_in_output}") - - # Should have same number of ranges (allowing small variance) - assert ( - abs(len(ranges) - token_ranges_in_output) <= 5 - ), f"Mismatch in range count: discovered {len(ranges)} vs nodetool {token_ranges_in_output}" diff --git a/examples/bulk_operations/tests/integration/test_token_splitting.py b/examples/bulk_operations/tests/integration/test_token_splitting.py deleted file mode 100644 index 72bc290..0000000 --- a/examples/bulk_operations/tests/integration/test_token_splitting.py +++ /dev/null @@ -1,283 +0,0 @@ -""" -Integration tests for token range splitting functionality. - -What this tests: ---------------- -1. Token range splitting with different strategies -2. Proportional splitting based on range sizes -3. Handling of very small ranges (vnodes) -4. Replica-aware clustering - -Why this matters: ----------------- -- Efficient parallelism requires good splitting -- Vnodes create many small ranges that shouldn't be over-split -- Replica clustering improves coordinator efficiency -- Performance optimization foundation -""" - -import pytest - -from async_cassandra import AsyncCluster -from bulk_operations.token_utils import TokenRangeSplitter, discover_token_ranges - - -@pytest.mark.integration -class TestTokenSplitting: - """Test token range splitting strategies.""" - - @pytest.fixture - async def cluster(self): - """Create connection to test cluster.""" - cluster = AsyncCluster( - contact_points=["localhost"], - port=9042, - ) - yield cluster - await cluster.shutdown() - - @pytest.fixture - async def session(self, cluster): - """Create test session with keyspace.""" - session = await cluster.connect() - - # Create test keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS bulk_test - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - - yield session - - @pytest.mark.asyncio - async def test_token_range_splitting_with_vnodes(self, session): - """ - Test that splitting handles vnode token ranges correctly. - - What this tests: - --------------- - 1. Natural ranges from vnodes are small - 2. Splitting respects range boundaries - 3. Very small ranges aren't over-split - 4. 
Large splits still cover all ranges - - Why this matters: - ---------------- - - Vnodes create many small ranges - - Over-splitting causes overhead - - Under-splitting reduces parallelism - - Must balance performance - """ - ranges = await discover_token_ranges(session, "bulk_test") - splitter = TokenRangeSplitter() - - # Test different split counts - for split_count in [10, 50, 100, 500]: - splits = splitter.split_proportionally(ranges, split_count) - - print(f"\nSplitting {len(ranges)} ranges into {split_count} splits:") - print(f" Actual splits: {len(splits)}") - - # Verify coverage - total_size = sum(r.size for r in ranges) - split_size = sum(s.size for s in splits) - - assert split_size == total_size, f"Split size mismatch: {split_size} vs {total_size}" - - # With vnodes, we might not achieve the exact split count - # because many ranges are too small to split - if split_count < len(ranges): - assert ( - len(splits) >= split_count * 0.5 - ), f"Too few splits: {len(splits)} (wanted ~{split_count})" - - @pytest.mark.asyncio - async def test_single_range_splitting(self, session): - """ - Test splitting of individual token ranges. - - What this tests: - --------------- - 1. Single range can be split evenly - 2. Last split gets remainder - 3. Small ranges aren't over-split - 4. Split boundaries are correct - - Why this matters: - ---------------- - - Foundation of proportional splitting - - Must handle edge cases correctly - - Affects query generation - - Performance depends on even distribution - """ - ranges = await discover_token_ranges(session, "bulk_test") - splitter = TokenRangeSplitter() - - # Find a reasonably large range to test - sorted_ranges = sorted(ranges, key=lambda r: r.size, reverse=True) - large_range = sorted_ranges[0] - - print("\nTesting single range splitting:") - print(f" Range size: {large_range.size}") - print(f" Range: {large_range.start} to {large_range.end}") - - # Test different split counts - for split_count in [1, 2, 5, 10]: - splits = splitter.split_single_range(large_range, split_count) - - print(f"\n Splitting into {split_count}:") - print(f" Actual splits: {len(splits)}") - - # Verify coverage - assert sum(s.size for s in splits) == large_range.size - - # Verify contiguous - for i in range(len(splits) - 1): - assert splits[i].end == splits[i + 1].start - - # Verify boundaries - assert splits[0].start == large_range.start - assert splits[-1].end == large_range.end - - # Verify replicas preserved - for s in splits: - assert s.replicas == large_range.replicas - - @pytest.mark.asyncio - async def test_replica_clustering(self, session): - """ - Test clustering ranges by replica sets. - - What this tests: - --------------- - 1. Ranges are correctly grouped by replicas - 2. All ranges are included in clusters - 3. No ranges are duplicated - 4. 
Replica sets are handled consistently - - Why this matters: - ---------------- - - Coordinator efficiency depends on replica locality - - Reduces network hops in multi-DC setups - - Improves cache utilization - - Foundation for topology-aware operations - """ - # For this test, use multi-node replication - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS bulk_test_replicated - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 3 - } - """ - ) - - ranges = await discover_token_ranges(session, "bulk_test_replicated") - splitter = TokenRangeSplitter() - - clusters = splitter.cluster_by_replicas(ranges) - - print("\nReplica clustering results:") - print(f" Total ranges: {len(ranges)}") - print(f" Replica clusters: {len(clusters)}") - - total_clustered = sum(len(ranges_list) for ranges_list in clusters.values()) - print(f" Total ranges in clusters: {total_clustered}") - - # Verify all ranges are clustered - assert total_clustered == len( - ranges - ), f"Not all ranges clustered: {total_clustered} vs {len(ranges)}" - - # Verify no duplicates - seen_ranges = set() - for _replica_set, range_list in clusters.items(): - for r in range_list: - range_key = (r.start, r.end) - assert range_key not in seen_ranges, f"Duplicate range: {range_key}" - seen_ranges.add(range_key) - - # Print cluster distribution - for replica_set, range_list in sorted(clusters.items()): - print(f" Replicas {replica_set}: {len(range_list)} ranges") - - @pytest.mark.asyncio - async def test_proportional_splitting_accuracy(self, session): - """ - Test that proportional splitting maintains relative sizes. - - What this tests: - --------------- - 1. Large ranges get more splits than small ones - 2. Total coverage is preserved - 3. Split distribution matches range distribution - 4. 
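The proportional-splitting behaviour asserted in these tests (large ranges receive more splits, tiny vnode ranges are never split below one piece, total coverage is preserved) is implemented by `TokenRangeSplitter`, which is not part of this diff. A minimal sketch of one plausible allocation rule, using hypothetical names, looks like this:

```python
# Hypothetical allocation rule for proportional splitting; the real
# TokenRangeSplitter lives in bulk_operations.token_utils.
def splits_for_range(range_size: int, total_size: int, target_splits: int) -> int:
    """Give each range a share of the target proportional to its size,
    but never split a range into fewer than one piece."""
    fraction = range_size / total_size
    return max(1, round(fraction * target_splits))

# Three ranges of very different sizes, target of 10 splits overall:
sizes = [8_000, 1_500, 500]
total = sum(sizes)
print([splits_for_range(s, total, 10) for s in sizes])
# -> [8, 2, 1]: 11 splits, slightly over target because the smallest
#    (vnode-sized) range still gets one split of its own.
```

An allocation like this also explains why the vnode test above accepts fewer splits than requested: hundreds of tiny ranges each contribute exactly one split, so the achievable total depends on the range-size distribution.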
No ranges are lost or duplicated - - Why this matters: - ---------------- - - Even work distribution across ranges - - Prevents hotspots from uneven splitting - - Optimizes parallel execution - - Critical for performance - """ - ranges = await discover_token_ranges(session, "bulk_test") - splitter = TokenRangeSplitter() - - # Calculate range size distribution - total_size = sum(r.size for r in ranges) - range_fractions = [(r, r.size / total_size) for r in ranges] - - # Sort by size for analysis - range_fractions.sort(key=lambda x: x[1], reverse=True) - - print("\nRange size distribution:") - print(f" Largest range: {range_fractions[0][1]:.2%} of total") - print(f" Smallest range: {range_fractions[-1][1]:.2%} of total") - print(f" Median range: {range_fractions[len(range_fractions)//2][1]:.2%} of total") - - # Test proportional splitting - target_splits = 100 - splits = splitter.split_proportionally(ranges, target_splits) - - # Analyze split distribution - splits_per_range = {} - for split in splits: - # Find which original range this split came from - for orig_range in ranges: - if (split.start >= orig_range.start and split.end <= orig_range.end) or ( - orig_range.start == split.start and orig_range.end == split.end - ): - key = (orig_range.start, orig_range.end) - splits_per_range[key] = splits_per_range.get(key, 0) + 1 - break - - # Verify proportionality - print("\nProportional splitting results:") - print(f" Target splits: {target_splits}") - print(f" Actual splits: {len(splits)}") - print(f" Ranges that got splits: {len(splits_per_range)}") - - # Large ranges should get more splits - large_range = range_fractions[0][0] - large_range_key = (large_range.start, large_range.end) - large_range_splits = splits_per_range.get(large_range_key, 0) - - small_range = range_fractions[-1][0] - small_range_key = (small_range.start, small_range.end) - small_range_splits = splits_per_range.get(small_range_key, 0) - - print(f" Largest range got {large_range_splits} splits") - print(f" Smallest range got {small_range_splits} splits") - - # Large ranges should generally get more splits - # (unless they're still too small to split effectively) - if large_range.size > small_range.size * 10: - assert ( - large_range_splits >= small_range_splits - ), "Large range should get at least as many splits as small range" diff --git a/examples/bulk_operations/tests/unit/__init__.py b/examples/bulk_operations/tests/unit/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/examples/bulk_operations/tests/unit/test_bulk_operator.py b/examples/bulk_operations/tests/unit/test_bulk_operator.py deleted file mode 100644 index af03562..0000000 --- a/examples/bulk_operations/tests/unit/test_bulk_operator.py +++ /dev/null @@ -1,381 +0,0 @@ -""" -Unit tests for TokenAwareBulkOperator. - -What this tests: ---------------- -1. Parallel execution of token range queries -2. Result aggregation and streaming -3. Progress tracking -4. Error handling and recovery - -Why this matters: ----------------- -- Ensures correct parallel processing -- Validates data completeness -- Confirms non-blocking async behavior -- Handles failures gracefully - -Additional context: ---------------------------------- -These tests mock the async-cassandra library to test -our bulk operation logic in isolation. 
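# Hedged sketch of the proportional allocation idea checked above: each range
# receives a share of the target split count proportional to its share of the
# total token space, with a floor of one so no range is dropped. The rounding
# policy here is illustrative, not the project's exact rule.
def allocate_splits(ranges, target_splits):
    total_size = sum(r.size for r in ranges)
    return [
        (r, max(1, round(target_splits * r.size / total_size)))
        for r in ranges
    ]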
-""" - -import asyncio -from unittest.mock import AsyncMock, Mock, patch - -import pytest - -from bulk_operations.bulk_operator import ( - BulkOperationError, - BulkOperationStats, - TokenAwareBulkOperator, -) - - -class TestTokenAwareBulkOperator: - """Test the main bulk operator class.""" - - @pytest.fixture - def mock_cluster(self): - """Create a mock AsyncCluster.""" - cluster = Mock() - cluster.contact_points = ["127.0.0.1", "127.0.0.2", "127.0.0.3"] - return cluster - - @pytest.fixture - def mock_session(self, mock_cluster): - """Create a mock AsyncSession.""" - session = Mock() - # Mock the underlying sync session that has cluster attribute - session._session = Mock() - session._session.cluster = mock_cluster - session.execute = AsyncMock() - session.execute_stream = AsyncMock() - session.prepare = AsyncMock(return_value=Mock()) # Mock prepare method - - # Mock metadata structure - metadata = Mock() - - # Create proper column mock - partition_key_col = Mock() - partition_key_col.name = "id" # Set the name attribute properly - - keyspaces = { - "test_ks": Mock(tables={"test_table": Mock(partition_key=[partition_key_col])}) - } - metadata.keyspaces = keyspaces - mock_cluster.metadata = metadata - - return session - - @pytest.mark.unit - async def test_count_by_token_ranges_single_node(self, mock_session): - """ - Test counting rows with token ranges on single node. - - What this tests: - --------------- - 1. Token range discovery is called correctly - 2. Queries are generated for each token range - 3. Results are aggregated properly - 4. Single node operation works correctly - - Why this matters: - ---------------- - - Ensures basic counting functionality works - - Validates token range splitting logic - - Confirms proper result aggregation - - Foundation for more complex multi-node operations - """ - operator = TokenAwareBulkOperator(mock_session) - - # Mock token range discovery - with patch( - "bulk_operations.bulk_operator.discover_token_ranges", new_callable=AsyncMock - ) as mock_discover: - # Create proper TokenRange mocks - from bulk_operations.token_utils import TokenRange - - mock_ranges = [ - TokenRange(start=-1000, end=0, replicas=["127.0.0.1"]), - TokenRange(start=0, end=1000, replicas=["127.0.0.1"]), - ] - mock_discover.return_value = mock_ranges - - # Mock query results - mock_session.execute.side_effect = [ - Mock(one=Mock(return_value=Mock(count=500))), # First range - Mock(one=Mock(return_value=Mock(count=300))), # Second range - ] - - # Execute count - result = await operator.count_by_token_ranges( - keyspace="test_ks", table="test_table", split_count=2 - ) - - assert result == 800 - assert mock_session.execute.call_count == 2 - - @pytest.mark.unit - async def test_count_with_parallel_execution(self, mock_session): - """ - Test that counts are executed in parallel. - - What this tests: - --------------- - 1. Multiple token ranges are processed concurrently - 2. Parallelism limits are respected - 3. Total execution time reflects parallel processing - 4. 
Results are correctly aggregated from parallel tasks - - Why this matters: - ---------------- - - Parallel execution is critical for performance - - Must not block the event loop - - Resource limits must be respected - - Common pattern in production bulk operations - """ - operator = TokenAwareBulkOperator(mock_session) - - # Track execution times - execution_times = [] - - async def mock_execute_with_delay(stmt, params=None): - start = asyncio.get_event_loop().time() - await asyncio.sleep(0.1) # Simulate query time - execution_times.append(asyncio.get_event_loop().time() - start) - return Mock(one=Mock(return_value=Mock(count=100))) - - mock_session.execute = mock_execute_with_delay - - with patch( - "bulk_operations.bulk_operator.discover_token_ranges", new_callable=AsyncMock - ) as mock_discover: - # Create 4 ranges - from bulk_operations.token_utils import TokenRange - - mock_ranges = [ - TokenRange(start=i * 1000, end=(i + 1) * 1000, replicas=["node1"]) for i in range(4) - ] - mock_discover.return_value = mock_ranges - - # Execute count - start_time = asyncio.get_event_loop().time() - result = await operator.count_by_token_ranges( - keyspace="test_ks", table="test_table", split_count=4, parallelism=4 - ) - total_time = asyncio.get_event_loop().time() - start_time - - assert result == 400 # 4 ranges * 100 each - # If executed in parallel, total time should be ~0.1s, not 0.4s - assert total_time < 0.2 - - @pytest.mark.unit - async def test_count_with_error_handling(self, mock_session): - """ - Test error handling during count operations. - - What this tests: - --------------- - 1. Partial failures are handled gracefully - 2. BulkOperationError is raised with partial results - 3. Individual errors are collected and reported - 4. Operation continues despite individual failures - - Why this matters: - ---------------- - - Network issues can cause partial failures - - Users need visibility into what succeeded - - Partial results are often useful - - Critical for production reliability - """ - operator = TokenAwareBulkOperator(mock_session) - - with patch( - "bulk_operations.bulk_operator.discover_token_ranges", new_callable=AsyncMock - ) as mock_discover: - from bulk_operations.token_utils import TokenRange - - mock_ranges = [ - TokenRange(start=0, end=1000, replicas=["node1"]), - TokenRange(start=1000, end=2000, replicas=["node2"]), - ] - mock_discover.return_value = mock_ranges - - # First succeeds, second fails - mock_session.execute.side_effect = [ - Mock(one=Mock(return_value=Mock(count=500))), - Exception("Connection timeout"), - ] - - # Should raise BulkOperationError - with pytest.raises(BulkOperationError) as exc_info: - await operator.count_by_token_ranges( - keyspace="test_ks", table="test_table", split_count=2 - ) - - assert "Failed to count" in str(exc_info.value) - assert exc_info.value.partial_result == 500 - - @pytest.mark.unit - async def test_export_streaming(self, mock_session): - """ - Test streaming export functionality. - - What this tests: - --------------- - 1. Token ranges are discovered for export - 2. Results are streamed asynchronously - 3. Memory usage remains constant (streaming) - 4. 
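# Generic asyncio pattern for the bounded parallelism this test asserts: one
# task per token range, an asyncio.Semaphore caps in-flight queries at
# `parallelism`, and asyncio.gather sums the partial counts. A sketch of the
# idea, not the operator's exact code.
import asyncio

async def count_in_parallel(count_one_range, ranges, parallelism=4):
    semaphore = asyncio.Semaphore(parallelism)

    async def bounded(token_range):
        async with semaphore:
            return await count_one_range(token_range)

    partial_counts = await asyncio.gather(*(bounded(r) for r in ranges))
    return sum(partial_counts)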
All rows are yielded in order - - Why this matters: - ---------------- - - Streaming prevents memory exhaustion - - Essential for large dataset exports - - Async iteration must work correctly - - Foundation for Iceberg export functionality - """ - operator = TokenAwareBulkOperator(mock_session) - - # Mock token range discovery - with patch( - "bulk_operations.bulk_operator.discover_token_ranges", new_callable=AsyncMock - ) as mock_discover: - from bulk_operations.token_utils import TokenRange - - mock_ranges = [TokenRange(start=0, end=1000, replicas=["node1"])] - mock_discover.return_value = mock_ranges - - # Mock streaming results - async def mock_stream_results(): - for i in range(10): - row = Mock() - row.id = i - row.name = f"row_{i}" - yield row - - mock_stream_context = AsyncMock() - mock_stream_context.__aenter__.return_value = mock_stream_results() - mock_stream_context.__aexit__.return_value = None - - mock_session.execute_stream.return_value = mock_stream_context - - # Collect exported rows - exported_rows = [] - async for row in operator.export_by_token_ranges( - keyspace="test_ks", table="test_table", split_count=1 - ): - exported_rows.append(row) - - assert len(exported_rows) == 10 - assert exported_rows[0].id == 0 - assert exported_rows[9].name == "row_9" - - @pytest.mark.unit - async def test_progress_callback(self, mock_session): - """ - Test progress callback functionality. - - What this tests: - --------------- - 1. Progress callbacks are invoked during operation - 2. Statistics are updated correctly - 3. Progress percentage is calculated accurately - 4. Final statistics reflect complete operation - - Why this matters: - ---------------- - - Users need visibility into long-running operations - - Progress tracking enables better UX - - Statistics help with performance tuning - - Critical for production monitoring - """ - operator = TokenAwareBulkOperator(mock_session) - progress_updates = [] - - def progress_callback(stats: BulkOperationStats): - progress_updates.append( - { - "rows": stats.rows_processed, - "ranges": stats.ranges_completed, - "progress": stats.progress_percentage, - } - ) - - # Mock setup - with patch( - "bulk_operations.bulk_operator.discover_token_ranges", new_callable=AsyncMock - ) as mock_discover: - from bulk_operations.token_utils import TokenRange - - mock_ranges = [ - TokenRange(start=0, end=1000, replicas=["node1"]), - TokenRange(start=1000, end=2000, replicas=["node2"]), - ] - mock_discover.return_value = mock_ranges - - mock_session.execute.side_effect = [ - Mock(one=Mock(return_value=Mock(count=500))), - Mock(one=Mock(return_value=Mock(count=300))), - ] - - # Execute with progress callback - await operator.count_by_token_ranges( - keyspace="test_ks", - table="test_table", - split_count=2, - progress_callback=progress_callback, - ) - - assert len(progress_updates) >= 2 - # Check final progress - final_update = progress_updates[-1] - assert final_update["ranges"] == 2 - assert final_update["progress"] == 100.0 - - @pytest.mark.unit - async def test_operation_stats(self, mock_session): - """ - Test operation statistics collection. - - What this tests: - --------------- - 1. Statistics are collected during operations - 2. Duration is calculated correctly - 3. Rows per second metric is accurate - 4. 
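# One plausible shape of the streaming export exercised above: each range's
# result stream is consumed inside an async context manager and rows are
# yielded one at a time, so memory stays bounded regardless of table size.
# The execute_stream call and context-manager protocol mirror the mocks and
# are assumptions about the session API.
async def export_rows(session, queries):
    for query in queries:
        stream_ctx = await session.execute_stream(query)
        async with stream_ctx as rows:
            async for row in rows:
                yield row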
All statistics fields are populated - - Why this matters: - ---------------- - - Performance metrics guide optimization - - Statistics enable capacity planning - - Benchmarking requires accurate metrics - - Production monitoring depends on these stats - """ - operator = TokenAwareBulkOperator(mock_session) - - with patch( - "bulk_operations.bulk_operator.discover_token_ranges", new_callable=AsyncMock - ) as mock_discover: - from bulk_operations.token_utils import TokenRange - - mock_ranges = [TokenRange(start=0, end=1000, replicas=["node1"])] - mock_discover.return_value = mock_ranges - - # Mock returns the same value for all calls (it's a single range) - mock_count_result = Mock() - mock_count_result.one.return_value = Mock(count=1000) - mock_session.execute.return_value = mock_count_result - - # Get stats after operation - count, stats = await operator.count_by_token_ranges_with_stats( - keyspace="test_ks", table="test_table", split_count=1 - ) - - assert count == 1000 - assert stats.rows_processed == 1000 - assert stats.ranges_completed == 1 - assert stats.duration_seconds > 0 - assert stats.rows_per_second > 0 diff --git a/examples/bulk_operations/tests/unit/test_csv_exporter.py b/examples/bulk_operations/tests/unit/test_csv_exporter.py deleted file mode 100644 index 9f17fff..0000000 --- a/examples/bulk_operations/tests/unit/test_csv_exporter.py +++ /dev/null @@ -1,365 +0,0 @@ -"""Unit tests for CSV exporter. - -What this tests: ---------------- -1. CSV header generation -2. Row serialization with different data types -3. NULL value handling -4. Collection serialization -5. Compression support -6. Progress tracking - -Why this matters: ----------------- -- CSV is a common export format -- Data type handling must be consistent -- Resume capability is critical for large exports -- Compression saves disk space -""" - -import csv -import gzip -import io -import uuid -from datetime import datetime -from unittest.mock import Mock - -import pytest - -from bulk_operations.bulk_operator import TokenAwareBulkOperator -from bulk_operations.exporters import CSVExporter, ExportFormat, ExportProgress - - -class MockRow: - """Mock Cassandra row object.""" - - def __init__(self, **kwargs): - self._fields = list(kwargs.keys()) - for key, value in kwargs.items(): - setattr(self, key, value) - - -class TestCSVExporter: - """Test CSV export functionality.""" - - @pytest.fixture - def mock_operator(self): - """Create mock bulk operator.""" - operator = Mock(spec=TokenAwareBulkOperator) - operator.session = Mock() - operator.session._session = Mock() - operator.session._session.cluster = Mock() - operator.session._session.cluster.metadata = Mock() - return operator - - @pytest.fixture - def exporter(self, mock_operator): - """Create CSV exporter instance.""" - return CSVExporter(mock_operator) - - def test_csv_value_serialization(self, exporter): - """ - Test serialization of different value types to CSV. - - What this tests: - --------------- - 1. NULL values become empty strings - 2. Booleans become true/false - 3. Collections get formatted properly - 4. Bytes are hex encoded - 5. 
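# Simple arithmetic behind the statistics asserted above (illustrative only):
# duration is wall-clock time measured around the whole operation, e.g. with
# time.perf_counter(), and throughput is rows divided by that duration,
# guarded against a zero duration.
def throughput(rows_processed: int, duration_seconds: float) -> float:
    return rows_processed / duration_seconds if duration_seconds > 0 else 0.0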
Timestamps use ISO format - - Why this matters: - ---------------- - - CSV needs consistent string representation - - Must be reversible for imports - - Standard tools should understand the format - """ - # NULL handling - assert exporter._serialize_csv_value(None) == "" - - # Primitives - assert exporter._serialize_csv_value(True) == "true" - assert exporter._serialize_csv_value(False) == "false" - assert exporter._serialize_csv_value(42) == "42" - assert exporter._serialize_csv_value(3.14) == "3.14" - assert exporter._serialize_csv_value("test") == "test" - - # UUID - test_uuid = uuid.uuid4() - assert exporter._serialize_csv_value(test_uuid) == str(test_uuid) - - # Datetime - test_dt = datetime(2024, 1, 1, 12, 0, 0) - assert exporter._serialize_csv_value(test_dt) == "2024-01-01T12:00:00" - - # Collections - assert exporter._serialize_csv_value([1, 2, 3]) == "[1, 2, 3]" - assert exporter._serialize_csv_value({"a", "b"}) == "[a, b]" or "[b, a]" - assert exporter._serialize_csv_value({"k1": "v1", "k2": "v2"}) in [ - "{k1: v1, k2: v2}", - "{k2: v2, k1: v1}", - ] - - # Bytes - assert exporter._serialize_csv_value(b"\x00\x01\x02") == "000102" - - def test_null_string_customization(self, mock_operator): - """ - Test custom NULL string representation. - - What this tests: - --------------- - 1. Default empty string for NULL - 2. Custom NULL strings like "NULL" or "\\N" - 3. Consistent handling across all types - - Why this matters: - ---------------- - - Different tools expect different NULL representations - - PostgreSQL uses \\N, MySQL uses NULL - - Must be configurable for compatibility - """ - # Default exporter uses empty string - default_exporter = CSVExporter(mock_operator) - assert default_exporter._serialize_csv_value(None) == "" - - # Custom NULL string - custom_exporter = CSVExporter(mock_operator, null_string="NULL") - assert custom_exporter._serialize_csv_value(None) == "NULL" - - # PostgreSQL style - pg_exporter = CSVExporter(mock_operator, null_string="\\N") - assert pg_exporter._serialize_csv_value(None) == "\\N" - - @pytest.mark.asyncio - async def test_write_header(self, exporter): - """ - Test CSV header writing. - - What this tests: - --------------- - 1. Header contains column names - 2. Proper delimiter usage - 3. Quoting when needed - - Why this matters: - ---------------- - - Headers enable column mapping - - Must match data row format - - Standard CSV compliance - """ - output = io.StringIO() - columns = ["id", "name", "created_at", "tags"] - - await exporter.write_header(output, columns) - output.seek(0) - - reader = csv.reader(output) - header = next(reader) - assert header == columns - - @pytest.mark.asyncio - async def test_write_row(self, exporter): - """ - Test writing data rows to CSV. - - What this tests: - --------------- - 1. Row data properly formatted - 2. Complex types serialized - 3. Byte count tracking - 4. 
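# Hedged sketch of the serialization rules these tests assert for
# _serialize_csv_value: None maps to the configured null string, booleans to
# "true"/"false", bytes to hex, datetimes to ISO 8601, and collections to
# bracketed strings. Not the exporter's actual code.
from datetime import datetime

def serialize_csv_value(value, null_string=""):
    if value is None:
        return null_string
    if isinstance(value, bool):
        return "true" if value else "false"
    if isinstance(value, bytes):
        return value.hex()
    if isinstance(value, datetime):
        return value.isoformat()
    if isinstance(value, (list, set, tuple)):
        return "[" + ", ".join(str(v) for v in value) + "]"
    if isinstance(value, dict):
        return "{" + ", ".join(f"{k}: {v}" for k, v in value.items()) + "}"
    return str(value)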
Thread safety with lock - - Why this matters: - ---------------- - - Data integrity is critical - - Concurrent writes must be safe - - Progress tracking needs accurate bytes - """ - output = io.StringIO() - - # Create test row - row = MockRow( - id=1, - name="Test User", - active=True, - score=99.5, - tags=["tag1", "tag2"], - metadata={"key": "value"}, - created_at=datetime(2024, 1, 1, 12, 0, 0), - ) - - bytes_written = await exporter.write_row(output, row) - output.seek(0) - - # Verify output - reader = csv.reader(output) - values = next(reader) - - assert values[0] == "1" - assert values[1] == "Test User" - assert values[2] == "true" - assert values[3] == "99.5" - assert values[4] == "[tag1, tag2]" - assert values[5] == "{key: value}" - assert values[6] == "2024-01-01T12:00:00" - - # Verify byte count - assert bytes_written > 0 - - @pytest.mark.asyncio - async def test_export_with_compression(self, mock_operator, tmp_path): - """ - Test CSV export with compression. - - What this tests: - --------------- - 1. Gzip compression works - 2. File has correct extension - 3. Compressed data is valid - - Why this matters: - ---------------- - - Large exports need compression - - Must work with standard tools - - File naming conventions matter - """ - exporter = CSVExporter(mock_operator, compression="gzip") - output_path = tmp_path / "test.csv" - - # Mock the export stream - test_rows = [ - MockRow(id=1, name="Alice", score=95.5), - MockRow(id=2, name="Bob", score=87.3), - ] - - async def mock_export(*args, **kwargs): - for row in test_rows: - yield row - - mock_operator.export_by_token_ranges = mock_export - - # Mock metadata - mock_keyspace = Mock() - mock_table = Mock() - mock_table.columns = {"id": None, "name": None, "score": None} - mock_keyspace.tables = {"test_table": mock_table} - mock_operator.session._session.cluster.metadata.keyspaces = {"test_ks": mock_keyspace} - - # Export - await exporter.export( - keyspace="test_ks", - table="test_table", - output_path=output_path, - ) - - # Verify compressed file exists - compressed_path = output_path.with_suffix(".csv.gzip") - assert compressed_path.exists() - - # Verify content - with gzip.open(compressed_path, "rt") as f: - reader = csv.reader(f) - header = next(reader) - assert header == ["id", "name", "score"] - - row1 = next(reader) - assert row1 == ["1", "Alice", "95.5"] - - row2 = next(reader) - assert row2 == ["2", "Bob", "87.3"] - - @pytest.mark.asyncio - async def test_export_progress_tracking(self, mock_operator, tmp_path): - """ - Test progress tracking during export. - - What this tests: - --------------- - 1. Progress initialized correctly - 2. Row count tracked - 3. Progress saved to file - 4. 
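# Small sketch of gzip-compressed CSV output as checked above: opening the
# file with gzip.open in text mode lets the standard csv module write through
# it unchanged; the ".csv.gzip" suffix mirrors the test's expectation. Rows
# are assumed to be pre-serialized string values.
import csv
import gzip
from pathlib import Path

def write_compressed_csv(path: Path, header, rows):
    compressed_path = path.with_suffix(".csv.gzip")
    with gzip.open(compressed_path, "wt", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)
    return compressed_path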
Completion marked - - Why this matters: - ---------------- - - Long exports need monitoring - - Resume capability requires state - - Users need feedback - """ - exporter = CSVExporter(mock_operator) - output_path = tmp_path / "test.csv" - - # Mock export - test_rows = [MockRow(id=i, value=f"test{i}") for i in range(100)] - - async def mock_export(*args, **kwargs): - for row in test_rows: - yield row - - mock_operator.export_by_token_ranges = mock_export - - # Mock metadata - mock_keyspace = Mock() - mock_table = Mock() - mock_table.columns = {"id": None, "value": None} - mock_keyspace.tables = {"test_table": mock_table} - mock_operator.session._session.cluster.metadata.keyspaces = {"test_ks": mock_keyspace} - - # Track progress callbacks - progress_updates = [] - - async def progress_callback(progress): - progress_updates.append(progress.rows_exported) - - # Export - progress = await exporter.export( - keyspace="test_ks", - table="test_table", - output_path=output_path, - progress_callback=progress_callback, - ) - - # Verify progress - assert progress.keyspace == "test_ks" - assert progress.table == "test_table" - assert progress.format == ExportFormat.CSV - assert progress.rows_exported == 100 - assert progress.completed_at is not None - - # Verify progress file - progress_file = output_path.with_suffix(".csv.progress") - assert progress_file.exists() - - # Load and verify - loaded_progress = ExportProgress.load(progress_file) - assert loaded_progress.rows_exported == 100 - - def test_custom_delimiter_and_quoting(self, mock_operator): - """ - Test custom CSV formatting options. - - What this tests: - --------------- - 1. Tab delimiter - 2. Pipe delimiter - 3. Different quoting styles - - Why this matters: - ---------------- - - Different systems expect different formats - - Must handle data with delimiters - - Flexibility for integration - """ - # Tab-delimited - tab_exporter = CSVExporter(mock_operator, delimiter="\t") - assert tab_exporter.delimiter == "\t" - - # Pipe-delimited - pipe_exporter = CSVExporter(mock_operator, delimiter="|") - assert pipe_exporter.delimiter == "|" - - # Quote all - quote_all_exporter = CSVExporter(mock_operator, quoting=csv.QUOTE_ALL) - assert quote_all_exporter.quoting == csv.QUOTE_ALL diff --git a/examples/bulk_operations/tests/unit/test_helpers.py b/examples/bulk_operations/tests/unit/test_helpers.py deleted file mode 100644 index 8f06738..0000000 --- a/examples/bulk_operations/tests/unit/test_helpers.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Helper utilities for unit tests. -""" - - -class MockToken: - """Mock token that supports comparison for sorting.""" - - def __init__(self, value): - self.value = value - - def __lt__(self, other): - return self.value < other.value - - def __eq__(self, other): - return self.value == other.value - - def __repr__(self): - return f"MockToken({self.value})" diff --git a/examples/bulk_operations/tests/unit/test_iceberg_catalog.py b/examples/bulk_operations/tests/unit/test_iceberg_catalog.py deleted file mode 100644 index c19a2cf..0000000 --- a/examples/bulk_operations/tests/unit/test_iceberg_catalog.py +++ /dev/null @@ -1,241 +0,0 @@ -"""Unit tests for Iceberg catalog configuration. - -What this tests: ---------------- -1. Filesystem catalog creation -2. Warehouse directory setup -3. Custom catalog configuration -4. 
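# Hedged sketch of resumable-progress bookkeeping like the ExportProgress
# sidecar used above: a small JSON file next to the output records how many
# rows have been written so an interrupted export can be inspected or
# resumed. Field names here are illustrative, not the exporter's schema.
import json
from pathlib import Path

def save_progress(progress_path: Path, keyspace: str, table: str, rows_exported: int):
    progress_path.write_text(
        json.dumps({"keyspace": keyspace, "table": table, "rows_exported": rows_exported})
    )

def load_progress(progress_path: Path) -> dict:
    return json.loads(progress_path.read_text())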
Catalog loading - -Why this matters: ----------------- -- Catalog is the entry point to Iceberg -- Proper configuration is critical -- Warehouse location affects data storage -- Supports multiple catalog types -""" - -import tempfile -import unittest -from pathlib import Path -from unittest.mock import Mock, patch - -from pyiceberg.catalog import Catalog - -from bulk_operations.iceberg.catalog import create_filesystem_catalog, get_or_create_catalog - - -class TestIcebergCatalog(unittest.TestCase): - """Test Iceberg catalog configuration.""" - - def setUp(self): - """Set up test fixtures.""" - self.temp_dir = tempfile.mkdtemp() - self.warehouse_path = Path(self.temp_dir) / "test_warehouse" - - def tearDown(self): - """Clean up test fixtures.""" - import shutil - - shutil.rmtree(self.temp_dir, ignore_errors=True) - - def test_create_filesystem_catalog_default_path(self): - """ - Test creating filesystem catalog with default path. - - What this tests: - --------------- - 1. Default warehouse path is created - 2. Catalog is properly configured - 3. SQLite URI is correct - - Why this matters: - ---------------- - - Easy setup for development - - Consistent default behavior - - No external dependencies - """ - with patch("bulk_operations.iceberg.catalog.Path.cwd") as mock_cwd: - mock_cwd.return_value = Path(self.temp_dir) - - catalog = create_filesystem_catalog("test_catalog") - - # Check catalog properties - self.assertEqual(catalog.name, "test_catalog") - - # Check warehouse directory was created - expected_warehouse = Path(self.temp_dir) / "iceberg_warehouse" - self.assertTrue(expected_warehouse.exists()) - - def test_create_filesystem_catalog_custom_path(self): - """ - Test creating filesystem catalog with custom path. - - What this tests: - --------------- - 1. Custom warehouse path is used - 2. Directory is created if missing - 3. Path objects are handled - - Why this matters: - ---------------- - - Flexibility in storage location - - Integration with existing infrastructure - - Path handling consistency - """ - catalog = create_filesystem_catalog( - name="custom_catalog", warehouse_path=self.warehouse_path - ) - - # Check catalog name - self.assertEqual(catalog.name, "custom_catalog") - - # Check warehouse directory exists - self.assertTrue(self.warehouse_path.exists()) - self.assertTrue(self.warehouse_path.is_dir()) - - def test_create_filesystem_catalog_string_path(self): - """ - Test creating catalog with string path. - - What this tests: - --------------- - 1. String paths are converted to Path objects - 2. Catalog works with string paths - - Why this matters: - ---------------- - - API flexibility - - Backward compatibility - - User convenience - """ - str_path = str(self.warehouse_path) - catalog = create_filesystem_catalog(name="string_path_catalog", warehouse_path=str_path) - - self.assertEqual(catalog.name, "string_path_catalog") - self.assertTrue(Path(str_path).exists()) - - def test_get_or_create_catalog_default(self): - """ - Test get_or_create_catalog with defaults. - - What this tests: - --------------- - 1. Default filesystem catalog is created - 2. 
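# One plausible implementation of a filesystem-backed catalog like the one
# these tests configure, assuming pyiceberg's SQLite-based SqlCatalog: the
# warehouse directory is created up front (idempotently) and catalog metadata
# lives in a local SQLite file. Property names are assumptions, not the
# project's create_filesystem_catalog code.
from pathlib import Path
from pyiceberg.catalog.sql import SqlCatalog

def make_filesystem_catalog(name, warehouse_path):
    warehouse = Path(warehouse_path)
    warehouse.mkdir(parents=True, exist_ok=True)
    return SqlCatalog(
        name,
        **{
            "uri": f"sqlite:///{warehouse / 'catalog.db'}",
            "warehouse": f"file://{warehouse}",
        },
    )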
Same parameters as create_filesystem_catalog - - Why this matters: - ---------------- - - Simplified API for common case - - Consistent behavior - """ - with patch("bulk_operations.iceberg.catalog.create_filesystem_catalog") as mock_create: - mock_catalog = Mock(spec=Catalog) - mock_create.return_value = mock_catalog - - result = get_or_create_catalog( - catalog_name="default_test", warehouse_path=self.warehouse_path - ) - - # Verify create_filesystem_catalog was called - mock_create.assert_called_once_with("default_test", self.warehouse_path) - self.assertEqual(result, mock_catalog) - - def test_get_or_create_catalog_custom_config(self): - """ - Test get_or_create_catalog with custom configuration. - - What this tests: - --------------- - 1. Custom config overrides defaults - 2. load_catalog is used for custom configs - - Why this matters: - ---------------- - - Support for different catalog types - - Flexibility for production deployments - - Integration with existing catalogs - """ - custom_config = { - "type": "rest", - "uri": "https://iceberg-catalog.example.com", - "credential": "token123", - } - - with patch("bulk_operations.iceberg.catalog.load_catalog") as mock_load: - mock_catalog = Mock(spec=Catalog) - mock_load.return_value = mock_catalog - - result = get_or_create_catalog(catalog_name="rest_catalog", config=custom_config) - - # Verify load_catalog was called with custom config - mock_load.assert_called_once_with("rest_catalog", **custom_config) - self.assertEqual(result, mock_catalog) - - def test_warehouse_directory_creation(self): - """ - Test that warehouse directory is created with proper permissions. - - What this tests: - --------------- - 1. Directory is created if missing - 2. Parent directories are created - 3. Existing directories are not affected - - Why this matters: - ---------------- - - Data needs a place to live - - Permissions affect data security - - Idempotent operation - """ - nested_path = self.warehouse_path / "nested" / "warehouse" - - # Ensure it doesn't exist - self.assertFalse(nested_path.exists()) - - # Create catalog - create_filesystem_catalog(name="nested_test", warehouse_path=nested_path) - - # Check all directories were created - self.assertTrue(nested_path.exists()) - self.assertTrue(nested_path.is_dir()) - self.assertTrue(nested_path.parent.exists()) - - # Create again - should not fail - create_filesystem_catalog(name="nested_test2", warehouse_path=nested_path) - self.assertTrue(nested_path.exists()) - - def test_catalog_properties(self): - """ - Test that catalog has expected properties. - - What this tests: - --------------- - 1. Catalog type is set correctly - 2. Warehouse location is set - 3. 
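# Sketch of the custom-config branch tested above: when a config dict is
# supplied, pyiceberg's load_catalog builds the catalog from those properties
# (here a REST catalog) instead of the filesystem default. The URI and
# credential values are the test's placeholders.
from pyiceberg.catalog import load_catalog

catalog = load_catalog(
    "rest_catalog",
    **{
        "type": "rest",
        "uri": "https://iceberg-catalog.example.com",
        "credential": "token123",
    },
)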
URI format is correct - - Why this matters: - ---------------- - - Properties affect catalog behavior - - Debugging and monitoring - - Integration requirements - """ - catalog = create_filesystem_catalog( - name="properties_test", warehouse_path=self.warehouse_path - ) - - # Check basic properties - self.assertEqual(catalog.name, "properties_test") - - # For SQL catalog, we'd check additional properties - # but they're not exposed in the base Catalog interface - - # Verify catalog can be used (basic smoke test) - # This would fail if catalog is misconfigured - namespaces = list(catalog.list_namespaces()) - self.assertIsInstance(namespaces, list) - - -if __name__ == "__main__": - unittest.main() diff --git a/examples/bulk_operations/tests/unit/test_iceberg_schema_mapper.py b/examples/bulk_operations/tests/unit/test_iceberg_schema_mapper.py deleted file mode 100644 index 9acc402..0000000 --- a/examples/bulk_operations/tests/unit/test_iceberg_schema_mapper.py +++ /dev/null @@ -1,362 +0,0 @@ -"""Unit tests for Cassandra to Iceberg schema mapping. - -What this tests: ---------------- -1. CQL type to Iceberg type conversions -2. Collection type handling (list, set, map) -3. Field ID assignment -4. Primary key handling (required vs nullable) - -Why this matters: ----------------- -- Schema mapping is critical for data integrity -- Type mismatches can cause data loss -- Field IDs enable schema evolution -- Nullability affects query semantics -""" - -import unittest -from unittest.mock import Mock - -from pyiceberg.types import ( - BinaryType, - BooleanType, - DateType, - DecimalType, - DoubleType, - FloatType, - IntegerType, - ListType, - LongType, - MapType, - StringType, - TimestamptzType, -) - -from bulk_operations.iceberg.schema_mapper import CassandraToIcebergSchemaMapper - - -class TestCassandraToIcebergSchemaMapper(unittest.TestCase): - """Test schema mapping from Cassandra to Iceberg.""" - - def setUp(self): - """Set up test fixtures.""" - self.mapper = CassandraToIcebergSchemaMapper() - - def test_simple_type_mappings(self): - """ - Test mapping of simple CQL types to Iceberg types. - - What this tests: - --------------- - 1. String types (text, ascii, varchar) - 2. Numeric types (int, bigint, float, double) - 3. Boolean type - 4. Binary type (blob) - - Why this matters: - ---------------- - - Ensures basic data types are preserved - - Critical for data integrity - - Foundation for complex types - """ - test_cases = [ - # String types - ("text", StringType), - ("ascii", StringType), - ("varchar", StringType), - # Integer types - ("tinyint", IntegerType), - ("smallint", IntegerType), - ("int", IntegerType), - ("bigint", LongType), - ("counter", LongType), - # Floating point - ("float", FloatType), - ("double", DoubleType), - # Other types - ("boolean", BooleanType), - ("blob", BinaryType), - ("date", DateType), - ("timestamp", TimestamptzType), - ("uuid", StringType), - ("timeuuid", StringType), - ("inet", StringType), - ] - - for cql_type, expected_type in test_cases: - with self.subTest(cql_type=cql_type): - result = self.mapper._map_cql_type(cql_type) - self.assertIsInstance(result, expected_type) - - def test_decimal_type_mapping(self): - """ - Test decimal and varint type mappings. - - What this tests: - --------------- - 1. Decimal type with default precision - 2. 
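# Minimal sketch of the scalar CQL-to-Iceberg mapping exercised above, using
# the same pyiceberg type classes the test imports; the full mapper also
# handles decimals, collections and frozen<> unwrapping.
from pyiceberg.types import (
    BinaryType, BooleanType, DateType, DoubleType, FloatType,
    IntegerType, LongType, StringType, TimestamptzType,
)

CQL_TO_ICEBERG = {
    "text": StringType(), "ascii": StringType(), "varchar": StringType(),
    "tinyint": IntegerType(), "smallint": IntegerType(), "int": IntegerType(),
    "bigint": LongType(), "counter": LongType(),
    "float": FloatType(), "double": DoubleType(),
    "boolean": BooleanType(), "blob": BinaryType(),
    "date": DateType(), "timestamp": TimestamptzType(),
    "uuid": StringType(), "timeuuid": StringType(), "inet": StringType(),
}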
Varint as decimal with 0 scale - - Why this matters: - ---------------- - - Financial data requires exact decimal representation - - Varint needs appropriate precision - """ - # Decimal - decimal_type = self.mapper._map_cql_type("decimal") - self.assertIsInstance(decimal_type, DecimalType) - self.assertEqual(decimal_type.precision, 38) - self.assertEqual(decimal_type.scale, 10) - - # Varint (arbitrary precision integer) - varint_type = self.mapper._map_cql_type("varint") - self.assertIsInstance(varint_type, DecimalType) - self.assertEqual(varint_type.precision, 38) - self.assertEqual(varint_type.scale, 0) - - def test_collection_type_mappings(self): - """ - Test mapping of collection types. - - What this tests: - --------------- - 1. List type with element type - 2. Set type (becomes list in Iceberg) - 3. Map type with key and value types - - Why this matters: - ---------------- - - Collections are common in Cassandra - - Iceberg has no native set type - - Nested types need proper handling - """ - # List - list_type = self.mapper._map_cql_type("list") - self.assertIsInstance(list_type, ListType) - self.assertIsInstance(list_type.element_type, StringType) - self.assertFalse(list_type.element_required) - - # Set (becomes List in Iceberg) - set_type = self.mapper._map_cql_type("set") - self.assertIsInstance(set_type, ListType) - self.assertIsInstance(set_type.element_type, IntegerType) - - # Map - map_type = self.mapper._map_cql_type("map") - self.assertIsInstance(map_type, MapType) - self.assertIsInstance(map_type.key_type, StringType) - self.assertIsInstance(map_type.value_type, DoubleType) - self.assertFalse(map_type.value_required) - - def test_nested_collection_types(self): - """ - Test mapping of nested collection types. - - What this tests: - --------------- - 1. List> - 2. Map> - - Why this matters: - ---------------- - - Cassandra supports nested collections - - Complex data structures need proper mapping - """ - # List> - nested_list = self.mapper._map_cql_type("list>") - self.assertIsInstance(nested_list, ListType) - self.assertIsInstance(nested_list.element_type, ListType) - self.assertIsInstance(nested_list.element_type.element_type, IntegerType) - - # Map> - nested_map = self.mapper._map_cql_type("map>") - self.assertIsInstance(nested_map, MapType) - self.assertIsInstance(nested_map.key_type, StringType) - self.assertIsInstance(nested_map.value_type, ListType) - self.assertIsInstance(nested_map.value_type.element_type, DoubleType) - - def test_frozen_type_handling(self): - """ - Test handling of frozen collections. - - What this tests: - --------------- - 1. Frozen> - 2. Frozen types are unwrapped - - Why this matters: - ---------------- - - Frozen is a Cassandra concept not in Iceberg - - Inner type should be preserved - """ - frozen_list = self.mapper._map_cql_type("frozen>") - self.assertIsInstance(frozen_list, ListType) - self.assertIsInstance(frozen_list.element_type, StringType) - - def test_field_id_assignment(self): - """ - Test unique field ID assignment. - - What this tests: - --------------- - 1. Sequential field IDs - 2. Unique IDs for nested fields - 3. 
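# Hedged illustration of the collection mappings asserted above: list<text>
# and set<int> both become Iceberg ListType (Iceberg has no set type), and
# map<text, double> becomes MapType with explicit element/key/value ids.
# Constructor keywords follow pyiceberg's field names and may vary by version.
from pyiceberg.types import DoubleType, ListType, MapType, StringType

tags_type = ListType(element_id=1, element_type=StringType(), element_required=False)
scores_type = MapType(
    key_id=2, key_type=StringType(),
    value_id=3, value_type=DoubleType(),
    value_required=False,
)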
ID counter reset - - Why this matters: - ---------------- - - Field IDs enable schema evolution - - Must be unique within schema - - IDs are permanent for a field - """ - # Reset counter - self.mapper.reset_field_ids() - - # Create mock column metadata - col1 = Mock() - col1.cql_type = "text" - col1.is_primary_key = True - - col2 = Mock() - col2.cql_type = "int" - col2.is_primary_key = False - - col3 = Mock() - col3.cql_type = "list" - col3.is_primary_key = False - - # Map columns - field1 = self.mapper._map_column("id", col1) - field2 = self.mapper._map_column("value", col2) - field3 = self.mapper._map_column("tags", col3) - - # Check field IDs - self.assertEqual(field1.field_id, 1) - self.assertEqual(field2.field_id, 2) - self.assertEqual(field3.field_id, 4) # ID 3 was used for list element - - # List type should have element ID too - self.assertEqual(field3.field_type.element_id, 3) - - def test_primary_key_required_fields(self): - """ - Test that primary key columns are marked as required. - - What this tests: - --------------- - 1. Primary key columns are required (not null) - 2. Non-primary columns are nullable - - Why this matters: - ---------------- - - Primary keys cannot be null in Cassandra - - Affects Iceberg query semantics - - Important for data validation - """ - # Primary key column - pk_col = Mock() - pk_col.cql_type = "text" - pk_col.is_primary_key = True - - pk_field = self.mapper._map_column("id", pk_col) - self.assertTrue(pk_field.required) - - # Regular column - reg_col = Mock() - reg_col.cql_type = "text" - reg_col.is_primary_key = False - - reg_field = self.mapper._map_column("name", reg_col) - self.assertFalse(reg_field.required) - - def test_table_schema_mapping(self): - """ - Test mapping of complete table schema. - - What this tests: - --------------- - 1. Multiple columns mapped correctly - 2. Schema contains all fields - 3. Field order preserved - - Why this matters: - ---------------- - - Complete schema mapping is the main use case - - All columns must be included - - Order affects data files - """ - # Mock table metadata - table_meta = Mock() - - # Mock columns - id_col = Mock() - id_col.cql_type = "uuid" - id_col.is_primary_key = True - - name_col = Mock() - name_col.cql_type = "text" - name_col.is_primary_key = False - - tags_col = Mock() - tags_col.cql_type = "set" - tags_col.is_primary_key = False - - table_meta.columns = { - "id": id_col, - "name": name_col, - "tags": tags_col, - } - - # Map schema - schema = self.mapper.map_table_schema(table_meta) - - # Verify schema - self.assertEqual(len(schema.fields), 3) - - # Check field names and types - field_names = [f.name for f in schema.fields] - self.assertEqual(field_names, ["id", "name", "tags"]) - - # Check types - self.assertIsInstance(schema.fields[0].field_type, StringType) - self.assertIsInstance(schema.fields[1].field_type, StringType) - self.assertIsInstance(schema.fields[2].field_type, ListType) - - def test_unknown_type_fallback(self): - """ - Test that unknown types fall back to string. - - What this tests: - --------------- - 1. Unknown CQL types become strings - 2. No exceptions thrown - - Why this matters: - ---------------- - - Future Cassandra versions may add types - - Graceful degradation is better than failure - """ - unknown_type = self.mapper._map_cql_type("future_type") - self.assertIsInstance(unknown_type, StringType) - - def test_time_type_mapping(self): - """ - Test time type mapping. - - What this tests: - --------------- - 1. Time type maps to LongType - 2. 
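# Sketch of the field-id bookkeeping these tests describe: every field,
# including nested list elements and map keys/values, consumes the next id
# from a single counter, which is what makes later schema evolution possible.
# NestedField is pyiceberg's field wrapper; the usage here is illustrative.
from pyiceberg.types import NestedField, StringType

class FieldIdAllocator:
    def __init__(self):
        self._next_id = 0

    def next_id(self) -> int:
        self._next_id += 1
        return self._next_id

allocator = FieldIdAllocator()
id_field = NestedField(
    field_id=allocator.next_id(), name="id", field_type=StringType(), required=True
)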
Represents nanoseconds since midnight - - Why this matters: - ---------------- - - Time representation differs between systems - - Precision must be preserved - """ - time_type = self.mapper._map_cql_type("time") - self.assertIsInstance(time_type, LongType) - - -if __name__ == "__main__": - unittest.main() diff --git a/examples/bulk_operations/tests/unit/test_token_ranges.py b/examples/bulk_operations/tests/unit/test_token_ranges.py deleted file mode 100644 index 1949b0e..0000000 --- a/examples/bulk_operations/tests/unit/test_token_ranges.py +++ /dev/null @@ -1,320 +0,0 @@ -""" -Unit tests for token range operations. - -What this tests: ---------------- -1. Token range calculation and splitting -2. Proportional distribution of ranges -3. Handling of ring wraparound -4. Replica awareness - -Why this matters: ----------------- -- Correct token ranges ensure complete data coverage -- Proportional splitting ensures balanced workload -- Proper handling prevents missing or duplicate data -- Replica awareness enables data locality - -Additional context: ---------------------------------- -Token ranges in Cassandra use Murmur3 hash with range: --9223372036854775808 to 9223372036854775807 -""" - -from unittest.mock import MagicMock, Mock - -import pytest - -from bulk_operations.token_utils import ( - TokenRange, - TokenRangeSplitter, - discover_token_ranges, - generate_token_range_query, -) - - -class TestTokenRange: - """Test TokenRange data class.""" - - @pytest.mark.unit - def test_token_range_creation(self): - """Test creating a token range.""" - range = TokenRange(start=-9223372036854775808, end=0, replicas=["node1", "node2", "node3"]) - - assert range.start == -9223372036854775808 - assert range.end == 0 - assert range.size == 9223372036854775808 - assert range.replicas == ["node1", "node2", "node3"] - assert 0.49 < range.fraction < 0.51 # About 50% of ring - - @pytest.mark.unit - def test_token_range_wraparound(self): - """Test token range that wraps around the ring.""" - # Range from positive to negative (wraps around) - range = TokenRange(start=9223372036854775800, end=-9223372036854775800, replicas=["node1"]) - - # Size calculation should handle wraparound - expected_size = 16 # Small range wrapping around - assert range.size == expected_size - assert range.fraction < 0.001 # Very small fraction of ring - - @pytest.mark.unit - def test_token_range_full_ring(self): - """Test token range covering entire ring.""" - range = TokenRange( - start=-9223372036854775808, - end=9223372036854775807, - replicas=["node1", "node2", "node3"], - ) - - assert range.size == 18446744073709551615 # 2^64 - 1 - assert range.fraction == 1.0 # 100% of ring - - -class TestTokenRangeSplitter: - """Test token range splitting logic.""" - - @pytest.mark.unit - def test_split_single_range_evenly(self): - """Test splitting a single range into equal parts.""" - splitter = TokenRangeSplitter() - original = TokenRange(start=0, end=1000, replicas=["node1", "node2"]) - - splits = splitter.split_single_range(original, 4) - - assert len(splits) == 4 - # Check splits are contiguous and cover entire range - assert splits[0].start == 0 - assert splits[0].end == 250 - assert splits[1].start == 250 - assert splits[1].end == 500 - assert splits[2].start == 500 - assert splits[2].end == 750 - assert splits[3].start == 750 - assert splits[3].end == 1000 - - # All splits should have same replicas - for split in splits: - assert split.replicas == ["node1", "node2"] - - @pytest.mark.unit - def test_split_proportionally(self): - """Test 
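# Illustrative size calculation consistent with the wraparound tests above:
# for a range that wraps past the end of the Murmur3 ring, the size is the
# piece up to MAX_TOKEN plus the piece from MIN_TOKEN, plus one for the
# boundary itself.
MIN_TOKEN = -(2**63)
MAX_TOKEN = 2**63 - 1

def range_size(start: int, end: int) -> int:
    if end >= start:
        return end - start
    return (MAX_TOKEN - start) + (end - MIN_TOKEN) + 1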
proportional splitting based on range sizes.""" - splitter = TokenRangeSplitter() - - # Create ranges of different sizes - ranges = [ - TokenRange(start=0, end=1000, replicas=["node1"]), # 10% of total - TokenRange(start=1000, end=9000, replicas=["node2"]), # 80% of total - TokenRange(start=9000, end=10000, replicas=["node3"]), # 10% of total - ] - - # Request 10 splits total - splits = splitter.split_proportionally(ranges, 10) - - # Should get approximately 1, 8, 1 splits for each range - node1_splits = [s for s in splits if s.replicas == ["node1"]] - node2_splits = [s for s in splits if s.replicas == ["node2"]] - node3_splits = [s for s in splits if s.replicas == ["node3"]] - - assert len(node1_splits) == 1 - assert len(node2_splits) == 8 - assert len(node3_splits) == 1 - assert len(splits) == 10 - - @pytest.mark.unit - def test_split_with_minimum_size(self): - """Test that small ranges don't get over-split.""" - splitter = TokenRangeSplitter() - - # Very small range - small_range = TokenRange(start=0, end=10, replicas=["node1"]) - - # Request many splits - splits = splitter.split_single_range(small_range, 100) - - # Should not create more splits than makes sense - # (implementation should have minimum split size) - assert len(splits) <= 10 # Assuming minimum split size of 1 - - @pytest.mark.unit - def test_cluster_by_replicas(self): - """Test clustering ranges by their replica sets.""" - splitter = TokenRangeSplitter() - - ranges = [ - TokenRange(start=0, end=100, replicas=["node1", "node2"]), - TokenRange(start=100, end=200, replicas=["node2", "node3"]), - TokenRange(start=200, end=300, replicas=["node1", "node2"]), - TokenRange(start=300, end=400, replicas=["node2", "node3"]), - ] - - clustered = splitter.cluster_by_replicas(ranges) - - # Should have 2 clusters based on replica sets - assert len(clustered) == 2 - - # Find clusters - cluster1 = None - cluster2 = None - for replicas, cluster_ranges in clustered.items(): - if set(replicas) == {"node1", "node2"}: - cluster1 = cluster_ranges - elif set(replicas) == {"node2", "node3"}: - cluster2 = cluster_ranges - - assert cluster1 is not None - assert cluster2 is not None - assert len(cluster1) == 2 - assert len(cluster2) == 2 - - -class TestTokenRangeDiscovery: - """Test discovering token ranges from cluster metadata.""" - - @pytest.mark.unit - async def test_discover_token_ranges(self): - """ - Test discovering token ranges from cluster metadata. - - What this tests: - --------------- - 1. Extraction from Cassandra metadata - 2. All token ranges are discovered - 3. Replica information is captured - 4. 
Async operation works correctly - - Why this matters: - ---------------- - - Must discover all ranges for completeness - - Replica info enables local processing - - Integration point with driver metadata - - Foundation of token-aware operations - """ - # Mock cluster metadata - mock_session = Mock() - mock_cluster = Mock() - mock_metadata = Mock() - mock_token_map = Mock() - - # Set up mock relationships - mock_session._session = Mock() - mock_session._session.cluster = mock_cluster - mock_cluster.metadata = mock_metadata - mock_metadata.token_map = mock_token_map - - # Mock tokens in the ring - from .test_helpers import MockToken - - mock_token1 = MockToken(-9223372036854775808) - mock_token2 = MockToken(0) - mock_token3 = MockToken(9223372036854775807) - mock_token_map.ring = [mock_token1, mock_token2, mock_token3] - - # Mock replicas - mock_token_map.get_replicas = MagicMock( - side_effect=[ - [Mock(address="127.0.0.1"), Mock(address="127.0.0.2")], - [Mock(address="127.0.0.2"), Mock(address="127.0.0.3")], - [Mock(address="127.0.0.3"), Mock(address="127.0.0.1")], # For wraparound - ] - ) - - # Discover ranges - ranges = await discover_token_ranges(mock_session, "test_keyspace") - - assert len(ranges) == 3 # Three tokens create three ranges - assert ranges[0].start == -9223372036854775808 - assert ranges[0].end == 0 - assert ranges[0].replicas == ["127.0.0.1", "127.0.0.2"] - assert ranges[1].start == 0 - assert ranges[1].end == 9223372036854775807 - assert ranges[1].replicas == ["127.0.0.2", "127.0.0.3"] - assert ranges[2].start == 9223372036854775807 - assert ranges[2].end == -9223372036854775808 # Wraparound - assert ranges[2].replicas == ["127.0.0.3", "127.0.0.1"] - - -class TestTokenRangeQueryGeneration: - """Test generating CQL queries with token ranges.""" - - @pytest.mark.unit - def test_generate_basic_token_range_query(self): - """ - Test generating a basic token range query. - - What this tests: - --------------- - 1. Valid CQL syntax generation - 2. Token function usage is correct - 3. Range boundaries use proper operators - 4. 
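# Sketch of how consecutive ring tokens become ranges, as asserted above:
# each adjacent pair (t[i], t[i+1]) forms a range, and a final wraparound
# range runs from the last token back to the first. Replica lookup via the
# driver's token map is elided here.
def ranges_from_ring(ring_tokens):
    tokens = sorted(ring_tokens)
    pairs = [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
    pairs.append((tokens[-1], tokens[0]))  # wraparound back to the start
    return pairs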
Fully qualified table names - - Why this matters: - ---------------- - - Query syntax must be valid CQL - - Token function enables range scans - - Boundary operators prevent gaps/overlaps - - Production queries depend on this - """ - range = TokenRange(start=0, end=1000, replicas=["node1"]) - - query = generate_token_range_query( - keyspace="test_ks", table="test_table", partition_keys=["id"], token_range=range - ) - - expected = "SELECT * FROM test_ks.test_table " "WHERE token(id) > 0 AND token(id) <= 1000" - assert query == expected - - @pytest.mark.unit - def test_generate_query_with_multiple_partition_keys(self): - """Test query generation with composite partition key.""" - range = TokenRange(start=-1000, end=1000, replicas=["node1"]) - - query = generate_token_range_query( - keyspace="test_ks", - table="test_table", - partition_keys=["country", "city"], - token_range=range, - ) - - expected = ( - "SELECT * FROM test_ks.test_table " - "WHERE token(country, city) > -1000 AND token(country, city) <= 1000" - ) - assert query == expected - - @pytest.mark.unit - def test_generate_query_with_column_selection(self): - """Test query generation with specific columns.""" - range = TokenRange(start=0, end=1000, replicas=["node1"]) - - query = generate_token_range_query( - keyspace="test_ks", - table="test_table", - partition_keys=["id"], - token_range=range, - columns=["id", "name", "created_at"], - ) - - expected = ( - "SELECT id, name, created_at FROM test_ks.test_table " - "WHERE token(id) > 0 AND token(id) <= 1000" - ) - assert query == expected - - @pytest.mark.unit - def test_generate_query_with_min_token(self): - """Test query generation starting from minimum token.""" - range = TokenRange(start=-9223372036854775808, end=0, replicas=["node1"]) # Min token - - query = generate_token_range_query( - keyspace="test_ks", table="test_table", partition_keys=["id"], token_range=range - ) - - # First range should use >= instead of > - expected = ( - "SELECT * FROM test_ks.test_table " - "WHERE token(id) >= -9223372036854775808 AND token(id) <= 0" - ) - assert query == expected diff --git a/examples/bulk_operations/tests/unit/test_token_utils.py b/examples/bulk_operations/tests/unit/test_token_utils.py deleted file mode 100644 index 8fe2de9..0000000 --- a/examples/bulk_operations/tests/unit/test_token_utils.py +++ /dev/null @@ -1,388 +0,0 @@ -""" -Unit tests for token range utilities. - -What this tests: ---------------- -1. Token range size calculations -2. Range splitting logic -3. Wraparound handling -4. Proportional distribution -5. Replica clustering - -Why this matters: ----------------- -- Ensures data completeness -- Prevents missing rows -- Maintains proper load distribution -- Enables efficient parallel processing - -Additional context: ---------------------------------- -Token ranges in Cassandra use Murmur3 hash which -produces 128-bit values from -2^63 to 2^63-1. -""" - -from unittest.mock import Mock - -import pytest - -from bulk_operations.token_utils import ( - MAX_TOKEN, - MIN_TOKEN, - TOTAL_TOKEN_RANGE, - TokenRange, - TokenRangeSplitter, - discover_token_ranges, - generate_token_range_query, -) - - -class TestTokenRange: - """Test the TokenRange dataclass.""" - - @pytest.mark.unit - def test_token_range_size_normal(self): - """ - Test size calculation for normal ranges. - - What this tests: - --------------- - 1. Size calculation for positive ranges - 2. Size calculation for negative ranges - 3. Basic arithmetic correctness - 4. 
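# Hedged sketch matching the queries asserted above: token() over the
# partition key columns bounds the scan, and only the very first range uses
# >= so rows hashing exactly to MIN_TOKEN are not skipped.
MIN_TOKEN = -(2**63)

def token_range_query(keyspace, table, partition_keys, start, end, columns=None):
    select = ", ".join(columns) if columns else "*"
    token_expr = f"token({', '.join(partition_keys)})"
    lower_op = ">=" if start == MIN_TOKEN else ">"
    return (
        f"SELECT {select} FROM {keyspace}.{table} "
        f"WHERE {token_expr} {lower_op} {start} AND {token_expr} <= {end}"
    )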
No wraparound edge cases - - Why this matters: - ---------------- - - Token range sizes determine split proportions - - Incorrect sizes lead to unbalanced loads - - Foundation for all range splitting logic - - Critical for even data distribution - """ - range = TokenRange(start=0, end=1000, replicas=["node1"]) - assert range.size == 1000 - - range = TokenRange(start=-1000, end=0, replicas=["node1"]) - assert range.size == 1000 - - @pytest.mark.unit - def test_token_range_size_wraparound(self): - """ - Test size calculation for ranges that wrap around. - - What this tests: - --------------- - 1. Wraparound from MAX_TOKEN to MIN_TOKEN - 2. Correct size calculation across boundaries - 3. Edge case handling for ring topology - 4. Boundary arithmetic correctness - - Why this matters: - ---------------- - - Cassandra's token ring wraps around - - Last range often crosses the boundary - - Incorrect handling causes missing data - - Real clusters always have wraparound ranges - """ - # Range wraps from near max to near min - range = TokenRange(start=MAX_TOKEN - 1000, end=MIN_TOKEN + 1000, replicas=["node1"]) - expected_size = 1000 + 1000 + 1 # 1000 on each side plus the boundary - assert range.size == expected_size - - @pytest.mark.unit - def test_token_range_fraction(self): - """Test fraction calculation.""" - # Quarter of the ring - quarter_size = TOTAL_TOKEN_RANGE // 4 - range = TokenRange(start=0, end=quarter_size, replicas=["node1"]) - assert abs(range.fraction - 0.25) < 0.001 - - -class TestTokenRangeSplitter: - """Test the TokenRangeSplitter class.""" - - @pytest.fixture - def splitter(self): - """Create a TokenRangeSplitter instance.""" - return TokenRangeSplitter() - - @pytest.mark.unit - def test_split_single_range_no_split(self, splitter): - """Test that requesting 1 or 0 splits returns original range.""" - range = TokenRange(start=0, end=1000, replicas=["node1"]) - - result = splitter.split_single_range(range, 1) - assert len(result) == 1 - assert result[0].start == 0 - assert result[0].end == 1000 - - @pytest.mark.unit - def test_split_single_range_even_split(self, splitter): - """Test splitting a range into even parts.""" - range = TokenRange(start=0, end=1000, replicas=["node1"]) - - result = splitter.split_single_range(range, 4) - assert len(result) == 4 - - # Check splits - assert result[0].start == 0 - assert result[0].end == 250 - assert result[1].start == 250 - assert result[1].end == 500 - assert result[2].start == 500 - assert result[2].end == 750 - assert result[3].start == 750 - assert result[3].end == 1000 - - @pytest.mark.unit - def test_split_single_range_small_range(self, splitter): - """Test that very small ranges aren't split.""" - range = TokenRange(start=0, end=2, replicas=["node1"]) - - result = splitter.split_single_range(range, 10) - assert len(result) == 1 # Too small to split - - @pytest.mark.unit - def test_split_proportionally_empty(self, splitter): - """Test proportional splitting with empty input.""" - result = splitter.split_proportionally([], 10) - assert result == [] - - @pytest.mark.unit - def test_split_proportionally_single_range(self, splitter): - """Test proportional splitting with single range.""" - ranges = [TokenRange(start=0, end=1000, replicas=["node1"])] - - result = splitter.split_proportionally(ranges, 4) - assert len(result) == 4 - - @pytest.mark.unit - def test_split_proportionally_multiple_ranges(self, splitter): - """ - Test proportional splitting with ranges of different sizes. - - What this tests: - --------------- - 1. 
Proportional distribution based on size - 2. Larger ranges get more splits - 3. Rounding behavior is reasonable - 4. All input ranges are covered - - Why this matters: - ---------------- - - Uneven token distribution is common - - Load balancing requires proportional splits - - Prevents hotspots in processing - - Mimics real cluster token distributions - """ - ranges = [ - TokenRange(start=0, end=1000, replicas=["node1"]), # Size 1000 - TokenRange(start=1000, end=4000, replicas=["node2"]), # Size 3000 - ] - - result = splitter.split_proportionally(ranges, 4) - - # Should split proportionally: 1 split for first, 3 for second - # But implementation uses round(), so might be slightly different - assert len(result) >= 2 - assert len(result) <= 4 - - @pytest.mark.unit - def test_cluster_by_replicas(self, splitter): - """ - Test clustering ranges by replica sets. - - What this tests: - --------------- - 1. Ranges are grouped by replica nodes - 2. Replica order doesn't affect grouping - 3. All ranges are included in clusters - 4. Unique replica sets are identified - - Why this matters: - ---------------- - - Enables coordinator-local processing - - Reduces network traffic in operations - - Improves performance through locality - - Critical for multi-datacenter efficiency - """ - ranges = [ - TokenRange(start=0, end=100, replicas=["node1", "node2"]), - TokenRange(start=100, end=200, replicas=["node2", "node3"]), - TokenRange(start=200, end=300, replicas=["node1", "node2"]), - TokenRange(start=300, end=400, replicas=["node3", "node1"]), - ] - - clusters = splitter.cluster_by_replicas(ranges) - - # Should have 3 unique replica sets - assert len(clusters) == 3 - - # Check that ranges are properly grouped - key1 = tuple(sorted(["node1", "node2"])) - assert key1 in clusters - assert len(clusters[key1]) == 2 - - -class TestDiscoverTokenRanges: - """Test token range discovery from cluster metadata.""" - - @pytest.mark.unit - async def test_discover_token_ranges_success(self): - """ - Test successful token range discovery. - - What this tests: - --------------- - 1. Token ranges are extracted from metadata - 2. Replica information is preserved - 3. All ranges from token map are returned - 4. 
Async operation completes successfully - - Why this matters: - ---------------- - - Discovery is the foundation of token-aware ops - - Replica awareness enables local reads - - Must handle all Cassandra metadata structures - - Critical for multi-datacenter deployments - """ - # Mock session and cluster - mock_session = Mock() - mock_cluster = Mock() - mock_metadata = Mock() - mock_token_map = Mock() - - # Setup tokens in the ring - from .test_helpers import MockToken - - mock_token1 = MockToken(-1000) - mock_token2 = MockToken(0) - mock_token3 = MockToken(1000) - mock_token_map.ring = [mock_token1, mock_token2, mock_token3] - - # Setup replicas - mock_replica1 = Mock() - mock_replica1.address = "192.168.1.1" - mock_replica2 = Mock() - mock_replica2.address = "192.168.1.2" - - mock_token_map.get_replicas.side_effect = [ - [mock_replica1, mock_replica2], - [mock_replica2, mock_replica1], - [mock_replica1, mock_replica2], # For the third token range - ] - - mock_metadata.token_map = mock_token_map - mock_cluster.metadata = mock_metadata - mock_session._session = Mock() - mock_session._session.cluster = mock_cluster - - # Test discovery - ranges = await discover_token_ranges(mock_session, "test_ks") - - assert len(ranges) == 3 # Three tokens create three ranges - assert ranges[0].start == -1000 - assert ranges[0].end == 0 - assert ranges[0].replicas == ["192.168.1.1", "192.168.1.2"] - assert ranges[1].start == 0 - assert ranges[1].end == 1000 - assert ranges[1].replicas == ["192.168.1.2", "192.168.1.1"] - assert ranges[2].start == 1000 - assert ranges[2].end == -1000 # Wraparound range - assert ranges[2].replicas == ["192.168.1.1", "192.168.1.2"] - - @pytest.mark.unit - async def test_discover_token_ranges_no_token_map(self): - """Test error when token map is not available.""" - mock_session = Mock() - mock_cluster = Mock() - mock_metadata = Mock() - mock_metadata.token_map = None - mock_cluster.metadata = mock_metadata - mock_session._session = Mock() - mock_session._session.cluster = mock_cluster - - with pytest.raises(RuntimeError, match="Token map not available"): - await discover_token_ranges(mock_session, "test_ks") - - -class TestGenerateTokenRangeQuery: - """Test CQL query generation for token ranges.""" - - @pytest.mark.unit - def test_generate_query_all_columns(self): - """Test query generation with all columns.""" - query = generate_token_range_query( - keyspace="test_ks", - table="test_table", - partition_keys=["id"], - token_range=TokenRange(start=0, end=1000, replicas=["node1"]), - ) - - expected = "SELECT * FROM test_ks.test_table " "WHERE token(id) > 0 AND token(id) <= 1000" - assert query == expected - - @pytest.mark.unit - def test_generate_query_specific_columns(self): - """Test query generation with specific columns.""" - query = generate_token_range_query( - keyspace="test_ks", - table="test_table", - partition_keys=["id"], - token_range=TokenRange(start=0, end=1000, replicas=["node1"]), - columns=["id", "name", "value"], - ) - - expected = ( - "SELECT id, name, value FROM test_ks.test_table " - "WHERE token(id) > 0 AND token(id) <= 1000" - ) - assert query == expected - - @pytest.mark.unit - def test_generate_query_minimum_token(self): - """ - Test query generation for minimum token edge case. - - What this tests: - --------------- - 1. MIN_TOKEN uses >= instead of > - 2. Prevents missing first token value - 3. Query syntax is valid CQL - 4. 
Edge case is handled correctly - - Why this matters: - ---------------- - - MIN_TOKEN is a valid token value - - Using > would skip data at MIN_TOKEN - - Common source of missing data bugs - - DSBulk compatibility requires this behavior - """ - query = generate_token_range_query( - keyspace="test_ks", - table="test_table", - partition_keys=["id"], - token_range=TokenRange(start=MIN_TOKEN, end=0, replicas=["node1"]), - ) - - expected = ( - f"SELECT * FROM test_ks.test_table " - f"WHERE token(id) >= {MIN_TOKEN} AND token(id) <= 0" - ) - assert query == expected - - @pytest.mark.unit - def test_generate_query_compound_partition_key(self): - """Test query generation with compound partition key.""" - query = generate_token_range_query( - keyspace="test_ks", - table="test_table", - partition_keys=["id", "type"], - token_range=TokenRange(start=0, end=1000, replicas=["node1"]), - ) - - expected = ( - "SELECT * FROM test_ks.test_table " - "WHERE token(id, type) > 0 AND token(id, type) <= 1000" - ) - assert query == expected diff --git a/examples/bulk_operations/visualize_tokens.py b/examples/bulk_operations/visualize_tokens.py deleted file mode 100755 index 98c1c25..0000000 --- a/examples/bulk_operations/visualize_tokens.py +++ /dev/null @@ -1,176 +0,0 @@ -#!/usr/bin/env python3 -""" -Visualize token distribution in the Cassandra cluster. - -This script helps understand how vnodes distribute tokens -across the cluster and validates our token range discovery. -""" - -import asyncio -from collections import defaultdict - -from rich.console import Console -from rich.table import Table - -from async_cassandra import AsyncCluster -from bulk_operations.token_utils import MAX_TOKEN, MIN_TOKEN, discover_token_ranges - -console = Console() - - -def analyze_node_distribution(ranges): - """Analyze and display token distribution by node.""" - primary_owner_count = defaultdict(int) - all_replica_count = defaultdict(int) - - for r in ranges: - # First replica is primary owner - if r.replicas: - primary_owner_count[r.replicas[0]] += 1 - for replica in r.replicas: - all_replica_count[replica] += 1 - - # Display node statistics - table = Table(title="Token Distribution by Node") - table.add_column("Node", style="cyan") - table.add_column("Primary Ranges", style="green") - table.add_column("Total Ranges (with replicas)", style="yellow") - table.add_column("Percentage of Ring", style="magenta") - - total_primary = sum(primary_owner_count.values()) - - for node in sorted(all_replica_count.keys()): - primary = primary_owner_count.get(node, 0) - total = all_replica_count.get(node, 0) - percentage = (primary / total_primary * 100) if total_primary > 0 else 0 - - table.add_row(node, str(primary), str(total), f"{percentage:.1f}%") - - console.print(table) - return primary_owner_count - - -def analyze_range_sizes(ranges): - """Analyze and display token range sizes.""" - console.print("\n[bold]Token Range Size Analysis[/bold]") - - range_sizes = [r.size for r in ranges] - avg_size = sum(range_sizes) / len(range_sizes) - min_size = min(range_sizes) - max_size = max(range_sizes) - - console.print(f"Average range size: {avg_size:,.0f}") - console.print(f"Smallest range: {min_size:,}") - console.print(f"Largest range: {max_size:,}") - console.print(f"Size ratio (max/min): {max_size/min_size:.2f}x") - - -def validate_ring_coverage(ranges): - """Validate token ring coverage for gaps.""" - console.print("\n[bold]Token Ring Coverage Validation[/bold]") - - sorted_ranges = sorted(ranges, key=lambda r: r.start) - - # Check for gaps - 
gaps = [] - for i in range(len(sorted_ranges) - 1): - current = sorted_ranges[i] - next_range = sorted_ranges[i + 1] - if current.end != next_range.start: - gaps.append((current.end, next_range.start)) - - if gaps: - console.print(f"[red]⚠ Found {len(gaps)} gaps in token ring![/red]") - for gap_start, gap_end in gaps[:5]: # Show first 5 - console.print(f" Gap: {gap_start} to {gap_end}") - else: - console.print("[green]✓ No gaps found - complete ring coverage[/green]") - - # Check first and last ranges - if sorted_ranges[0].start == MIN_TOKEN: - console.print("[green]✓ First range starts at MIN_TOKEN[/green]") - else: - console.print(f"[red]⚠ First range starts at {sorted_ranges[0].start}, not MIN_TOKEN[/red]") - - if sorted_ranges[-1].end == MAX_TOKEN: - console.print("[green]✓ Last range ends at MAX_TOKEN[/green]") - else: - console.print(f"[yellow]Last range ends at {sorted_ranges[-1].end}[/yellow]") - - return sorted_ranges - - -def display_sample_ranges(sorted_ranges): - """Display sample token ranges.""" - console.print("\n[bold]Sample Token Ranges (first 5)[/bold]") - sample_table = Table() - sample_table.add_column("Range #", style="cyan") - sample_table.add_column("Start", style="green") - sample_table.add_column("End", style="yellow") - sample_table.add_column("Size", style="magenta") - sample_table.add_column("Replicas", style="blue") - - for i, r in enumerate(sorted_ranges[:5]): - sample_table.add_row( - str(i + 1), str(r.start), str(r.end), f"{r.size:,}", ", ".join(r.replicas) - ) - - console.print(sample_table) - - -async def visualize_token_distribution(): - """Visualize how tokens are distributed across the cluster.""" - - console.print("[cyan]Connecting to Cassandra cluster...[/cyan]") - - async with AsyncCluster(contact_points=["localhost"]) as cluster, cluster.connect() as session: - # Create test keyspace if needed - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS token_test - WITH replication = { - 'class': 'SimpleStrategy', - 'replication_factor': 3 - } - """ - ) - - console.print("[green]✓ Connected to cluster[/green]\n") - - # Discover token ranges - ranges = await discover_token_ranges(session, "token_test") - - # Analyze distribution - console.print("[bold]Token Range Analysis[/bold]") - console.print(f"Total ranges discovered: {len(ranges)}") - console.print("Expected with 3 nodes × 256 vnodes: ~768 ranges\n") - - # Analyze node distribution - primary_owner_count = analyze_node_distribution(ranges) - - # Analyze range sizes - analyze_range_sizes(ranges) - - # Validate ring coverage - sorted_ranges = validate_ring_coverage(ranges) - - # Display sample ranges - display_sample_ranges(sorted_ranges) - - # Vnode insight - console.print("\n[bold]Vnode Configuration Insight[/bold]") - console.print(f"With {len(primary_owner_count)} nodes and {len(ranges)} ranges:") - console.print(f"Average vnodes per node: {len(ranges) / len(primary_owner_count):.1f}") - console.print("This matches the expected 256 vnodes per node configuration.") - - -if __name__ == "__main__": - try: - asyncio.run(visualize_token_distribution()) - except KeyboardInterrupt: - console.print("\n[yellow]Visualization cancelled[/yellow]") - except Exception as e: - console.print(f"\n[red]Error: {e}[/red]") - import traceback - - traceback.print_exc() diff --git a/examples/fastapi_app/.env.example b/examples/fastapi_app/.env.example deleted file mode 100644 index 80dabd7..0000000 --- a/examples/fastapi_app/.env.example +++ /dev/null @@ -1,29 +0,0 @@ -# FastAPI + async-cassandra Environment 
Configuration -# Copy this file to .env and update with your values - -# Cassandra Connection Settings -CASSANDRA_HOSTS=localhost,192.168.1.10 # Comma-separated list of contact points -CASSANDRA_PORT=9042 # Native transport port - -# Optional: Authentication (if enabled in Cassandra) -# CASSANDRA_USERNAME=cassandra -# CASSANDRA_PASSWORD=your-secure-password - -# Application Settings -LOG_LEVEL=INFO # DEBUG, INFO, WARNING, ERROR, CRITICAL -APP_ENV=development # development, staging, production - -# Performance Settings -CASSANDRA_EXECUTOR_THREADS=2 # Number of executor threads -CASSANDRA_IDLE_HEARTBEAT_INTERVAL=30 # Heartbeat interval in seconds -CASSANDRA_CONNECTION_TIMEOUT=5.0 # Connection timeout in seconds - -# Optional: SSL/TLS Configuration -# CASSANDRA_SSL_ENABLED=true -# CASSANDRA_SSL_CA_CERTS=/path/to/ca.pem -# CASSANDRA_SSL_CERTFILE=/path/to/cert.pem -# CASSANDRA_SSL_KEYFILE=/path/to/key.pem - -# Optional: Monitoring -# PROMETHEUS_ENABLED=true -# PROMETHEUS_PORT=9091 diff --git a/examples/fastapi_app/Dockerfile b/examples/fastapi_app/Dockerfile deleted file mode 100644 index 9b0dcb6..0000000 --- a/examples/fastapi_app/Dockerfile +++ /dev/null @@ -1,33 +0,0 @@ -# Use official Python runtime as base image -FROM python:3.12-slim - -# Set working directory in container -WORKDIR /app - -# Install system dependencies -RUN apt-get update && apt-get install -y \ - gcc \ - && rm -rf /var/lib/apt/lists/* - -# Copy requirements first for better caching -COPY requirements.txt . - -# Install Python dependencies -RUN pip install --no-cache-dir -r requirements.txt - -# Copy application code -COPY main.py . - -# Create non-root user to run the app -RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app -USER appuser - -# Expose port -EXPOSE 8000 - -# Health check -HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \ - CMD python -c "import httpx; httpx.get('http://localhost:8000/health').raise_for_status()" - -# Run the application -CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/examples/fastapi_app/README.md b/examples/fastapi_app/README.md deleted file mode 100644 index f6edf2a..0000000 --- a/examples/fastapi_app/README.md +++ /dev/null @@ -1,541 +0,0 @@ -# FastAPI Example Application - -This example demonstrates how to use async-cassandra with FastAPI to build a high-performance REST API backed by Cassandra. - -## 🎯 Purpose - -**This example serves a dual purpose:** -1. **Production Template**: A real-world example of how to integrate async-cassandra with FastAPI -2. **CI Integration Test**: This application is used in our CI/CD pipeline to validate that async-cassandra works correctly in a real async web framework environment - -## Overview - -The example showcases all the key features of async-cassandra: -- **Thread Safety**: Handles concurrent requests without data corruption -- **Memory Efficiency**: Streaming endpoints for large datasets -- **Error Handling**: Consistent error responses across all operations -- **Performance**: Async operations preventing event loop blocking -- **Monitoring**: Health checks and metrics endpoints -- **Production Patterns**: Proper lifecycle management, prepared statements, and error handling - -## What You'll Learn - -This example teaches essential patterns for production Cassandra applications: - -1. **Connection Management**: How to properly manage cluster and session lifecycle -2. **Prepared Statements**: Reusing prepared statements for performance and security -3. 
**Error Handling**: Converting Cassandra errors to appropriate HTTP responses -4. **Streaming**: Processing large datasets without memory exhaustion -5. **Concurrency**: Leveraging async for high-throughput operations -6. **Context Managers**: Ensuring resources are properly cleaned up -7. **Monitoring**: Building observable applications with health and metrics -8. **Testing**: Comprehensive test patterns for async applications - -## API Endpoints - -### 1. Basic CRUD Operations -- `POST /users` - Create a new user - - **Purpose**: Demonstrates basic insert operations with prepared statements - - **Validates**: UUID generation, timestamp handling, data validation -- `GET /users/{user_id}` - Get user by ID - - **Purpose**: Shows single-row query patterns - - **Validates**: UUID parsing, error handling for non-existent users -- `PUT /users/{user_id}` - Full update of user - - **Purpose**: Demonstrates full record replacement - - **Validates**: Update operations, timestamp updates -- `PATCH /users/{user_id}` - Partial update of user - - **Purpose**: Shows selective field updates - - **Validates**: Optional field handling, partial updates -- `DELETE /users/{user_id}` - Delete user - - **Purpose**: Demonstrates delete operations - - **Validates**: Idempotent deletes, cleanup -- `GET /users` - List users with pagination - - **Purpose**: Shows basic pagination patterns - - **Query params**: `limit` (default: 10, max: 100) - -### 2. Streaming Operations -- `GET /users/stream` - Stream large datasets efficiently - - **Purpose**: Demonstrates memory-efficient streaming for large result sets - - **Query params**: - - `limit`: Total rows to stream - - `fetch_size`: Rows per page (controls memory usage) - - `age_filter`: Filter users by minimum age - - **Validates**: Memory efficiency, streaming context managers -- `GET /users/stream/pages` - Page-by-page streaming - - **Purpose**: Shows manual page iteration for client-controlled paging - - **Query params**: Same as above - - **Validates**: Page-by-page processing, fetch more pages pattern - -### 3. Batch Operations -- `POST /users/batch` - Create multiple users in a single batch - - **Purpose**: Demonstrates batch insert performance benefits - - **Validates**: Batch size limits, atomic batch operations - -### 4. Performance Testing -- `GET /performance/async` - Test async performance with concurrent queries - - **Purpose**: Demonstrates concurrent query execution benefits - - **Query params**: `requests` (number of concurrent queries) - - **Validates**: Thread pool handling, concurrent execution -- `GET /performance/sync` - Compare with sequential execution - - **Purpose**: Shows performance difference vs sequential execution - - **Query params**: `requests` (number of sequential queries) - - **Validates**: Performance improvement metrics - -### 5. Error Simulation & Resilience Testing -- `GET /slow_query` - Simulates slow query with timeout handling - - **Purpose**: Tests timeout behavior and client timeout headers - - **Headers**: `X-Request-Timeout` (timeout in seconds) - - **Validates**: Timeout propagation, graceful timeout handling -- `GET /long_running_query` - Simulates very long operation (10s) - - **Purpose**: Tests long-running query behavior - - **Validates**: Long operation handling without blocking - -### 6. 
Context Manager Safety Testing -These endpoints validate critical safety properties of context managers: - -- `POST /context_manager_safety/query_error` - - **Purpose**: Verifies query errors don't close the session - - **Tests**: Executes invalid query, then valid query - - **Validates**: Error isolation, session stability after errors - -- `POST /context_manager_safety/streaming_error` - - **Purpose**: Ensures streaming errors don't affect the session - - **Tests**: Attempts invalid streaming, then valid streaming - - **Validates**: Streaming context cleanup without session impact - -- `POST /context_manager_safety/concurrent_streams` - - **Purpose**: Tests multiple concurrent streams don't interfere - - **Tests**: Runs 3 concurrent streams with different filters - - **Validates**: Stream isolation, independent lifecycles - -- `POST /context_manager_safety/nested_contexts` - - **Purpose**: Verifies proper cleanup order in nested contexts - - **Tests**: Creates cluster → session → stream nested contexts - - **Validates**: - - Innermost (stream) closes first - - Middle (session) closes without affecting cluster - - Outer (cluster) closes last - - Main app session unaffected - -- `POST /context_manager_safety/cancellation` - - **Purpose**: Tests cancelled streaming operations clean up properly - - **Tests**: Starts stream, cancels mid-flight, verifies cleanup - - **Validates**: - - No resource leaks on cancellation - - Session remains usable - - New streams can be started - -- `GET /context_manager_safety/status` - - **Purpose**: Monitor resource state - - **Returns**: Current state of session, cluster, and keyspace - - **Validates**: Resource tracking and monitoring - -### 7. Monitoring & Operations -- `GET /` - Welcome message with API information -- `GET /health` - Health check with Cassandra connectivity test - - **Purpose**: Load balancer health checks, monitoring - - **Returns**: Status and Cassandra connectivity -- `GET /metrics` - Application metrics - - **Purpose**: Performance monitoring, debugging - - **Returns**: Query counts, error counts, performance stats -- `POST /shutdown` - Graceful shutdown simulation - - **Purpose**: Tests graceful shutdown patterns - - **Note**: In production, use process managers - -## Running the Example - -### Prerequisites - -1. **Cassandra** running on localhost:9042 (or use Docker/Podman): - ```bash - # Using Docker - docker run -d --name cassandra-test -p 9042:9042 cassandra:5 - - # OR using Podman - podman run -d --name cassandra-test -p 9042:9042 cassandra:5 - ``` - -2. **Python 3.12+** with dependencies: - ```bash - cd examples/fastapi_app - pip install -r requirements.txt - ``` - -### Start the Application - -```bash -# Development mode with auto-reload -uvicorn main:app --reload - -# Production mode -uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1 -``` - -**Note**: Use only 1 worker to ensure proper connection management. For scaling, run multiple instances behind a load balancer. 
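The single-worker guidance exists because the app keeps one shared cluster and session for its entire lifetime (see `main.py`). A minimal sketch of that lifecycle pattern, simplified from the full application and assuming a local node on the default port:

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

from async_cassandra import AsyncCluster


@asynccontextmanager
async def lifespan(app: FastAPI):
    # One cluster and one session are created at startup and shared by every request.
    cluster = AsyncCluster(contact_points=["127.0.0.1"], port=9042)
    session = await cluster.connect()
    app.state.session = session  # the full example uses a module-level global instead
    try:
        yield
    finally:
        # Shut down once, at application exit, in reverse order of creation.
        await session.close()
        await cluster.shutdown()


app = FastAPI(lifespan=lifespan)
```

Each additional instance you run behind a load balancer repeats this lifecycle independently, so connections are never shared across processes.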
- -### Environment Variables - -- `CASSANDRA_HOSTS` - Comma-separated list of Cassandra hosts (default: localhost) -- `CASSANDRA_PORT` - Cassandra port (default: 9042) -- `CASSANDRA_KEYSPACE` - Keyspace name (default: test_keyspace) - -Example: -```bash -export CASSANDRA_HOSTS=node1,node2,node3 -export CASSANDRA_PORT=9042 -export CASSANDRA_KEYSPACE=production -``` - -## Testing the Application - -### Automated Test Suite - -The test suite validates all functionality and serves as integration tests in CI: - -```bash -# Run all tests -pytest tests/test_fastapi_app.py -v - -# Or run all tests in the tests directory -pytest tests/ -v -``` - -Tests cover: -- ✅ Thread safety under high concurrency -- ✅ Memory efficiency with streaming -- ✅ Error handling consistency -- ✅ Performance characteristics -- ✅ All endpoint functionality -- ✅ Timeout handling -- ✅ Connection lifecycle -- ✅ **Context manager safety** - - Query error isolation - - Streaming error containment - - Concurrent stream independence - - Nested context cleanup order - - Cancellation handling - -### Manual Testing Examples - -#### Welcome and health check: -```bash -# Check if API is running -curl http://localhost:8000/ -# Returns: {"message": "FastAPI + async-cassandra example is running!"} - -# Detailed health check -curl http://localhost:8000/health -# Returns health status and Cassandra connectivity -``` - -#### Create a user: -```bash -curl -X POST http://localhost:8000/users \ - -H "Content-Type: application/json" \ - -d '{"name": "John Doe", "email": "john@example.com", "age": 30}' - -# Response includes auto-generated UUID and timestamps: -# { -# "id": "123e4567-e89b-12d3-a456-426614174000", -# "name": "John Doe", -# "email": "john@example.com", -# "age": 30, -# "created_at": "2024-01-01T12:00:00", -# "updated_at": "2024-01-01T12:00:00" -# } -``` - -#### Get a user: -```bash -# Replace with actual UUID from create response -curl http://localhost:8000/users/550e8400-e29b-41d4-a716-446655440000 - -# Returns 404 if user not found with proper error message -``` - -#### Update operations: -```bash -# Full update (PUT) - all fields required -curl -X PUT http://localhost:8000/users/550e8400-e29b-41d4-a716-446655440000 \ - -H "Content-Type: application/json" \ - -d '{"name": "Jane Doe", "email": "jane@example.com", "age": 31}' - -# Partial update (PATCH) - only specified fields updated -curl -X PATCH http://localhost:8000/users/550e8400-e29b-41d4-a716-446655440000 \ - -H "Content-Type: application/json" \ - -d '{"age": 32}' -``` - -#### Delete a user: -```bash -# Returns 204 No Content on success -curl -X DELETE http://localhost:8000/users/550e8400-e29b-41d4-a716-446655440000 - -# Idempotent - deleting non-existent user also returns 204 -``` - -#### List users with pagination: -```bash -# Default limit is 10, max is 100 -curl "http://localhost:8000/users?limit=10" - -# Response includes list of users -``` - -#### Stream large dataset: -```bash -# Stream users with age > 25, 100 rows per page -curl "http://localhost:8000/users/stream?age_filter=25&fetch_size=100&limit=10000" - -# Streams JSON array of users without loading all in memory -# fetch_size controls memory usage (rows per Cassandra page) -``` - -#### Page-by-page streaming: -```bash -# Get one page at a time with state tracking -curl "http://localhost:8000/users/stream/pages?age_filter=25&fetch_size=50" - -# Returns: -# { -# "users": [...], -# "has_more": true, -# "page_state": "encoded_state_for_next_page" -# } -``` - -#### Batch operations: -```bash -# Create multiple 
users atomically -curl -X POST http://localhost:8000/users/batch \ - -H "Content-Type: application/json" \ - -d '[ - {"name": "User 1", "email": "user1@example.com", "age": 25}, - {"name": "User 2", "email": "user2@example.com", "age": 30}, - {"name": "User 3", "email": "user3@example.com", "age": 35} - ]' - -# Returns count of created users -``` - -#### Test performance: -```bash -# Run 500 concurrent queries (async) -curl "http://localhost:8000/performance/async?requests=500" - -# Compare with sequential execution -curl "http://localhost:8000/performance/sync?requests=500" - -# Response shows timing and requests/second -``` - -#### Check health: -```bash -curl http://localhost:8000/health - -# Returns: -# { -# "status": "healthy", -# "cassandra": "connected", -# "keyspace": "example" -# } - -# Returns 503 if Cassandra is not available -``` - -#### View metrics: -```bash -curl http://localhost:8000/metrics - -# Returns application metrics: -# { -# "total_queries": 1234, -# "active_connections": 10, -# "queries_per_second": 45.2, -# "average_query_time_ms": 12.5, -# "errors_count": 0 -# } -``` - -#### Test error scenarios: -```bash -# Test timeout handling with short timeout -curl -H "X-Request-Timeout: 0.1" http://localhost:8000/slow_query -# Returns 504 Gateway Timeout - -# Test with adequate timeout -curl -H "X-Request-Timeout: 10" http://localhost:8000/slow_query -# Returns success after 5 seconds -``` - -#### Test context manager safety: -```bash -# Test query error isolation -curl -X POST http://localhost:8000/context_manager_safety/query_error - -# Test streaming error containment -curl -X POST http://localhost:8000/context_manager_safety/streaming_error - -# Test concurrent streams -curl -X POST http://localhost:8000/context_manager_safety/concurrent_streams - -# Test nested context managers -curl -X POST http://localhost:8000/context_manager_safety/nested_contexts - -# Test cancellation handling -curl -X POST http://localhost:8000/context_manager_safety/cancellation - -# Check resource status -curl http://localhost:8000/context_manager_safety/status -``` - -## Key Concepts Explained - -For in-depth explanations of the core concepts used in this example: - -- **[Why Async Matters for Cassandra](../../docs/why-async-wrapper.md)** - Understand the benefits of async operations for database drivers -- **[Streaming Large Datasets](../../docs/streaming.md)** - Learn about memory-efficient data processing -- **[Context Manager Safety](../../docs/context-managers-explained.md)** - Critical patterns for resource management -- **[Connection Pooling](../../docs/connection-pooling.md)** - How connections are managed efficiently - -For prepared statements best practices, see the examples in the code above and the [main documentation](../../README.md#prepared-statements). - -## Key Implementation Patterns - -This example demonstrates several critical implementation patterns. 
For detailed documentation, see: - -- **[Architecture Overview](../../docs/architecture.md)** - How async-cassandra works internally -- **[API Reference](../../docs/api.md)** - Complete API documentation -- **[Getting Started Guide](../../docs/getting-started.md)** - Basic usage patterns - -Key patterns implemented in this example: - -### Application Lifecycle Management -- FastAPI's lifespan context manager for proper setup/teardown -- Single cluster and session instance shared across the application -- Graceful shutdown handling - -### Prepared Statements -- All parameterized queries use prepared statements -- Statements prepared once and reused for better performance -- Protection against CQL injection attacks - -### Streaming for Large Results -- Memory-efficient processing using `execute_stream()` -- Configurable fetch size for memory control -- Automatic cleanup with context managers - -### Error Handling -- Consistent error responses with proper HTTP status codes -- Cassandra exceptions mapped to appropriate HTTP errors -- Validation errors handled with 422 responses - -### Context Manager Safety -- **[Context Manager Safety Documentation](../../docs/context-managers-explained.md)** - -### Concurrent Request Handling -- Safe concurrent query execution using `asyncio.gather()` -- Thread pool executor manages concurrent operations -- No data corruption or connection issues under load - -## Common Patterns and Best Practices - -For comprehensive patterns and best practices when using async-cassandra: -- **[Getting Started Guide](../../docs/getting-started.md)** - Basic usage patterns -- **[Troubleshooting Guide](../../docs/troubleshooting.md)** - Common issues and solutions -- **[Streaming Documentation](../../docs/streaming.md)** - Memory-efficient data processing -- **[Performance Guide](../../docs/performance.md)** - Optimization strategies - -The code in this example demonstrates these patterns in action. Key takeaways: -- Use a single global session shared across all requests -- Handle specific Cassandra errors and convert to appropriate HTTP responses -- Use streaming for large datasets to prevent memory exhaustion -- Always use context managers for proper resource cleanup - -## Production Considerations - -For detailed production deployment guidance, see: -- **[Connection Pooling](../../docs/connection-pooling.md)** - Connection management strategies -- **[Performance Guide](../../docs/performance.md)** - Optimization techniques -- **[Monitoring Guide](../../docs/metrics-monitoring.md)** - Metrics and observability -- **[Thread Pool Configuration](../../docs/thread-pool-configuration.md)** - Tuning for your workload - -Key production patterns demonstrated in this example: -- Single global session shared across all requests -- Health check endpoints for load balancers -- Proper error handling and timeout management -- Input validation and security best practices - -## CI/CD Integration - -This example is automatically tested in our CI pipeline to ensure: -- async-cassandra integrates correctly with FastAPI -- All async operations work as expected -- No event loop blocking occurs -- Memory usage remains bounded with streaming -- Error handling works correctly - -## Extending the Example - -To add new features: - -1. **New Endpoints**: Follow existing patterns for consistency -2. **Authentication**: Add FastAPI middleware for auth -3. **Rate Limiting**: Use FastAPI middleware or Redis -4. **Caching**: Add Redis for frequently accessed data -5. 
**API Versioning**: Use FastAPI's APIRouter for versioning - -## Troubleshooting - -For comprehensive troubleshooting guidance, see: -- **[Troubleshooting Guide](../../docs/troubleshooting.md)** - Common issues and solutions - -Quick troubleshooting tips: -- **Connection issues**: Check Cassandra is running and environment variables are correct -- **Memory issues**: Use streaming endpoints and adjust `fetch_size` -- **Resource leaks**: Run `/context_manager_safety/*` endpoints to diagnose -- **Performance issues**: See the [Performance Guide](../../docs/performance.md) - -## Complete Example Workflow - -Here's a typical workflow demonstrating all key features: - -```bash -# 1. Check system health -curl http://localhost:8000/health - -# 2. Create some users -curl -X POST http://localhost:8000/users -H "Content-Type: application/json" \ - -d '{"name": "Alice", "email": "alice@example.com", "age": 28}' - -curl -X POST http://localhost:8000/users -H "Content-Type: application/json" \ - -d '{"name": "Bob", "email": "bob@example.com", "age": 35}' - -# 3. Create users in batch -curl -X POST http://localhost:8000/users/batch -H "Content-Type: application/json" \ - -d '[ - {"name": "Charlie", "email": "charlie@example.com", "age": 42}, - {"name": "Diana", "email": "diana@example.com", "age": 28}, - {"name": "Eve", "email": "eve@example.com", "age": 35} - ]' - -# 4. List all users -curl http://localhost:8000/users?limit=10 - -# 5. Stream users with age > 30 -curl "http://localhost:8000/users/stream?age_filter=30&fetch_size=2" - -# 6. Test performance -curl http://localhost:8000/performance/async?requests=100 - -# 7. Test context manager safety -curl -X POST http://localhost:8000/context_manager_safety/concurrent_streams - -# 8. View metrics -curl http://localhost:8000/metrics - -# 9. Clean up (delete a user) -curl -X DELETE http://localhost:8000/users/{user-id-from-create} -``` - -This example serves as both a learning resource and a production-ready template for building FastAPI applications with Cassandra using async-cassandra. 
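If you want to try the prepared-statement and streaming patterns described above outside of FastAPI, here is a minimal standalone sketch (assuming a local node and the `example.users` table created by the app; it is an illustration, not part of the application):

```python
import asyncio

from async_cassandra import AsyncCluster, StreamConfig


async def main():
    async with AsyncCluster(contact_points=["127.0.0.1"]) as cluster:
        async with await cluster.connect() as session:
            await session.set_keyspace("example")

            # Prepare once, then reuse the statement for every execution.
            stmt = await session.prepare("SELECT name, email FROM users LIMIT ?")
            result = await session.execute(stmt, [10])
            async for row in result:
                print(row.name, row.email)

            # Stream a large result set with bounded memory; fetch_size controls page size.
            config = StreamConfig(fetch_size=100)
            async with await session.execute_stream(
                "SELECT * FROM users", stream_config=config
            ) as stream:
                async for row in stream:
                    pass  # process each row here


if __name__ == "__main__":
    asyncio.run(main())
```

The same session object handles both calls, which mirrors how the API endpoints above reuse the single shared session for prepared queries and streaming alike.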
diff --git a/examples/fastapi_app/docker-compose.yml b/examples/fastapi_app/docker-compose.yml deleted file mode 100644 index e2d9304..0000000 --- a/examples/fastapi_app/docker-compose.yml +++ /dev/null @@ -1,134 +0,0 @@ -version: '3.8' - -# FastAPI + async-cassandra Example Application -# This compose file sets up a complete development environment - -services: - # Apache Cassandra Database - cassandra: - image: cassandra:5.0 - container_name: fastapi-cassandra - ports: - - "9042:9042" # CQL native transport port - environment: - # Cluster configuration - - CASSANDRA_CLUSTER_NAME=FastAPICluster - - CASSANDRA_DC=datacenter1 - - CASSANDRA_ENDPOINT_SNITCH=GossipingPropertyFileSnitch - - # Memory settings (optimized for stability) - - HEAP_NEWSIZE=3G - - MAX_HEAP_SIZE=12G - - JVM_OPTS=-XX:+UseG1GC -XX:G1RSetUpdatingPauseTimePercent=5 -XX:MaxGCPauseMillis=300 - - # Enable authentication (optional) - # - CASSANDRA_AUTHENTICATOR=PasswordAuthenticator - # - CASSANDRA_AUTHORIZER=CassandraAuthorizer - - volumes: - # Persist data between container restarts - - cassandra_data:/var/lib/cassandra - - # Resource limits for stability - deploy: - resources: - limits: - memory: 16G - reservations: - memory: 16G - - healthcheck: - test: ["CMD-SHELL", "nodetool info | grep -q 'Native Transport active: true' && cqlsh -e 'SELECT now() FROM system.local'"] - interval: 30s - timeout: 10s - retries: 10 - start_period: 90s - - networks: - - app-network - - # FastAPI Application - app: - build: - context: . - dockerfile: Dockerfile - container_name: fastapi-app - ports: - - "8000:8000" # FastAPI port - environment: - # Cassandra connection settings - - CASSANDRA_HOSTS=cassandra - - CASSANDRA_PORT=9042 - - # Application settings - - LOG_LEVEL=INFO - - # Optional: Authentication (if enabled in Cassandra) - # - CASSANDRA_USERNAME=cassandra - # - CASSANDRA_PASSWORD=cassandra - - depends_on: - cassandra: - condition: service_healthy - - # Restart policy - restart: unless-stopped - - # Resource limits (adjust based on needs) - deploy: - resources: - limits: - cpus: '1' - memory: 512M - reservations: - cpus: '0.5' - memory: 256M - - networks: - - app-network - - # Mount source code for development (remove in production) - volumes: - - ./main.py:/app/main.py:ro - - # Override command for development with auto-reload - command: ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] - - # Optional: Prometheus for metrics - # prometheus: - # image: prom/prometheus:latest - # container_name: prometheus - # ports: - # - "9090:9090" - # volumes: - # - ./prometheus.yml:/etc/prometheus/prometheus.yml - # - prometheus_data:/prometheus - # networks: - # - app-network - - # Optional: Grafana for visualization - # grafana: - # image: grafana/grafana:latest - # container_name: grafana - # ports: - # - "3000:3000" - # environment: - # - GF_SECURITY_ADMIN_PASSWORD=admin - # volumes: - # - grafana_data:/var/lib/grafana - # networks: - # - app-network - -# Networks -networks: - app-network: - driver: bridge - -# Volumes -volumes: - cassandra_data: - driver: local - # prometheus_data: - # driver: local - # grafana_data: - # driver: local diff --git a/examples/fastapi_app/main.py b/examples/fastapi_app/main.py deleted file mode 100644 index f879257..0000000 --- a/examples/fastapi_app/main.py +++ /dev/null @@ -1,1215 +0,0 @@ -""" -Simple FastAPI example using async-cassandra. - -This demonstrates basic CRUD operations with Cassandra using the async wrapper. 
-Run with: uvicorn main:app --reload -""" - -import asyncio -import os -import uuid -from contextlib import asynccontextmanager -from datetime import datetime -from typing import List, Optional -from uuid import UUID - -from cassandra import OperationTimedOut, ReadTimeout, Unavailable, WriteTimeout - -# Import Cassandra driver exceptions for proper error detection -from cassandra.cluster import Cluster as SyncCluster -from cassandra.cluster import NoHostAvailable -from cassandra.policies import ConstantReconnectionPolicy -from fastapi import FastAPI, HTTPException, Query, Request -from pydantic import BaseModel - -from async_cassandra import AsyncCluster, StreamConfig - - -# Pydantic models -class UserCreate(BaseModel): - name: str - email: str - age: int - - -class User(BaseModel): - id: str - name: str - email: str - age: int - created_at: datetime - updated_at: datetime - - -class UserUpdate(BaseModel): - name: Optional[str] = None - email: Optional[str] = None - age: Optional[int] = None - - -# Global session, cluster, and keyspace -session = None -cluster = None -sync_session = None # For synchronous performance comparison -sync_cluster = None # For synchronous performance comparison -keyspace = "example" - - -def is_cassandra_unavailable_error(error: Exception) -> bool: - """ - Determine if an error indicates Cassandra is unavailable. - - This function checks for specific Cassandra driver exceptions that indicate - the database is not reachable or available. - """ - # Direct Cassandra driver exceptions - if isinstance( - error, (NoHostAvailable, Unavailable, OperationTimedOut, ReadTimeout, WriteTimeout) - ): - return True - - # Check error message for additional patterns - error_msg = str(error).lower() - unavailability_keywords = [ - "no host available", - "all hosts", - "connection", - "timeout", - "unavailable", - "no replicas", - "not enough replicas", - "cannot achieve consistency", - "operation timed out", - "read timeout", - "write timeout", - "connection pool", - "connection closed", - "connection refused", - "unable to connect", - ] - - return any(keyword in error_msg for keyword in unavailability_keywords) - - -def handle_cassandra_error(error: Exception, operation: str = "operation") -> HTTPException: - """ - Convert a Cassandra error to an appropriate HTTP exception. - - Returns 503 for availability issues, 500 for other errors. - """ - if is_cassandra_unavailable_error(error): - # Log the specific error type for debugging - error_type = type(error).__name__ - return HTTPException( - status_code=503, - detail=f"Service temporarily unavailable: Cassandra connection issue ({error_type}: {str(error)})", - ) - else: - # Other errors (like InvalidRequest) get 500 - return HTTPException( - status_code=500, detail=f"Internal server error during {operation}: {str(error)}" - ) - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Manage database lifecycle.""" - global session, cluster, sync_session, sync_cluster - - try: - # Startup - connect to Cassandra with constant reconnection policy - # IMPORTANT: Using ConstantReconnectionPolicy with 2-second delay for testing - # This ensures quick reconnection during integration tests where we simulate - # Cassandra outages. In production, you might want ExponentialReconnectionPolicy - # to avoid overwhelming a recovering cluster. 
- # IMPORTANT: Use 127.0.0.1 instead of localhost to force IPv4 - contact_points = os.getenv("CASSANDRA_HOSTS", "127.0.0.1").split(",") - # Replace any "localhost" with "127.0.0.1" to ensure IPv4 - contact_points = ["127.0.0.1" if cp == "localhost" else cp for cp in contact_points] - - cluster = AsyncCluster( - contact_points=contact_points, - port=int(os.getenv("CASSANDRA_PORT", "9042")), - reconnection_policy=ConstantReconnectionPolicy( - delay=2.0 - ), # Reconnect every 2 seconds for testing - connect_timeout=10.0, # Quick connection timeout for faster test feedback - ) - session = await cluster.connect() - except Exception as e: - print(f"Failed to connect to Cassandra: {type(e).__name__}: {e}") - # Don't fail startup completely, allow health check to report unhealthy - session = None - yield - return - - # Create keyspace and table - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS example - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("example") - - # Also create sync cluster for performance comparison - try: - sync_cluster = SyncCluster( - contact_points=contact_points, - port=int(os.getenv("CASSANDRA_PORT", "9042")), - reconnection_policy=ConstantReconnectionPolicy(delay=2.0), - connect_timeout=10.0, - protocol_version=5, - ) - sync_session = sync_cluster.connect() - sync_session.set_keyspace("example") - except Exception as e: - print(f"Failed to create sync cluster: {e}") - sync_session = None - - # Drop and recreate table for clean test environment - await session.execute("DROP TABLE IF EXISTS users") - await session.execute( - """ - CREATE TABLE users ( - id UUID PRIMARY KEY, - name TEXT, - email TEXT, - age INT, - created_at TIMESTAMP, - updated_at TIMESTAMP - ) - """ - ) - - yield - - # Shutdown - if session: - await session.close() - if cluster: - await cluster.shutdown() - if sync_session: - sync_session.shutdown() - if sync_cluster: - sync_cluster.shutdown() - - -# Create FastAPI app -app = FastAPI( - title="FastAPI + async-cassandra Example", - description="Simple CRUD API using async-cassandra", - version="1.0.0", - lifespan=lifespan, -) - - -@app.get("/") -async def root(): - """Root endpoint.""" - return {"message": "FastAPI + async-cassandra example is running!"} - - -@app.get("/health") -async def health_check(): - """Health check endpoint.""" - try: - # Simple health check - verify session is available - if session is None: - return { - "status": "unhealthy", - "cassandra_connected": False, - "timestamp": datetime.now().isoformat(), - } - - # Test connection with a simple query - await session.execute("SELECT now() FROM system.local") - return { - "status": "healthy", - "cassandra_connected": True, - "timestamp": datetime.now().isoformat(), - } - except Exception: - return { - "status": "unhealthy", - "cassandra_connected": False, - "timestamp": datetime.now().isoformat(), - } - - -@app.post("/users", response_model=User, status_code=201) -async def create_user(user: UserCreate): - """Create a new user.""" - if session is None: - raise HTTPException( - status_code=503, - detail="Service temporarily unavailable: Cassandra connection not established", - ) - - try: - user_id = uuid.uuid4() - now = datetime.now() - - # Use prepared statement for better performance - stmt = await session.prepare( - "INSERT INTO users (id, name, email, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)" - ) - await session.execute(stmt, [user_id, user.name, user.email, user.age, now, now]) - - return User( - 
id=str(user_id), - name=user.name, - email=user.email, - age=user.age, - created_at=now, - updated_at=now, - ) - except Exception as e: - raise handle_cassandra_error(e, "user creation") - - -@app.get("/users", response_model=List[User]) -async def list_users(limit: int = Query(10, ge=1, le=10000)): - """List all users.""" - if session is None: - raise HTTPException( - status_code=503, - detail="Service temporarily unavailable: Cassandra connection not established", - ) - - try: - # Use prepared statement with validated limit - stmt = await session.prepare("SELECT * FROM users LIMIT ?") - result = await session.execute(stmt, [limit]) - - users = [] - async for row in result: - users.append( - User( - id=str(row.id), - name=row.name, - email=row.email, - age=row.age, - created_at=row.created_at, - updated_at=row.updated_at, - ) - ) - - return users - except Exception as e: - error_msg = str(e) - if any( - keyword in error_msg.lower() - for keyword in ["unavailable", "nohost", "connection", "timeout"] - ): - raise HTTPException( - status_code=503, - detail=f"Service temporarily unavailable: Cassandra connection issue - {error_msg}", - ) - raise HTTPException(status_code=500, detail=f"Internal server error: {error_msg}") - - -# Streaming endpoints - must come before /users/{user_id} to avoid route conflict -@app.get("/users/stream") -async def stream_users( - limit: int = Query(1000, ge=0, le=10000), fetch_size: int = Query(100, ge=10, le=1000) -): - """Stream users data for large result sets.""" - if session is None: - raise HTTPException( - status_code=503, - detail="Service temporarily unavailable: Cassandra connection not established", - ) - - try: - # Handle special case where limit=0 - if limit == 0: - return { - "users": [], - "metadata": { - "total_returned": 0, - "pages_fetched": 0, - "fetch_size": fetch_size, - "streaming_enabled": True, - }, - } - - stream_config = StreamConfig(fetch_size=fetch_size) - - # Use context manager for proper resource cleanup - # Note: LIMIT not needed - fetch_size controls data flow - stmt = await session.prepare("SELECT * FROM users") - async with await session.execute_stream(stmt, stream_config=stream_config) as result: - users = [] - async for row in result: - # Handle both dict-like and object-like row access - if hasattr(row, "__getitem__"): - # Dictionary-like access - try: - user_dict = { - "id": str(row["id"]), - "name": row["name"], - "email": row["email"], - "age": row["age"], - "created_at": row["created_at"].isoformat(), - "updated_at": row["updated_at"].isoformat(), - } - except (KeyError, TypeError): - # Fall back to attribute access - user_dict = { - "id": str(row.id), - "name": row.name, - "email": row.email, - "age": row.age, - "created_at": row.created_at.isoformat(), - "updated_at": row.updated_at.isoformat(), - } - else: - # Object-like access - user_dict = { - "id": str(row.id), - "name": row.name, - "email": row.email, - "age": row.age, - "created_at": row.created_at.isoformat(), - "updated_at": row.updated_at.isoformat(), - } - users.append(user_dict) - - return { - "users": users, - "metadata": { - "total_returned": len(users), - "pages_fetched": result.page_number, - "fetch_size": fetch_size, - "streaming_enabled": True, - }, - } - - except Exception as e: - raise handle_cassandra_error(e, "streaming users") - - -@app.get("/users/stream/pages") -async def stream_users_by_pages( - limit: int = Query(1000, ge=0, le=10000), - fetch_size: int = Query(100, ge=10, le=1000), - max_pages: int = Query(10, ge=0, le=100), -): - """Stream 
users data page by page for memory efficiency.""" - if session is None: - raise HTTPException( - status_code=503, - detail="Service temporarily unavailable: Cassandra connection not established", - ) - - try: - # Handle special case where limit=0 or max_pages=0 - if limit == 0 or max_pages == 0: - return { - "total_rows_processed": 0, - "pages_info": [], - "metadata": { - "fetch_size": fetch_size, - "max_pages_limit": max_pages, - "streaming_mode": "page_by_page", - }, - } - - stream_config = StreamConfig(fetch_size=fetch_size, max_pages=max_pages) - - # Use context manager for automatic cleanup - # Note: LIMIT not needed - fetch_size controls data flow - stmt = await session.prepare("SELECT * FROM users") - async with await session.execute_stream(stmt, stream_config=stream_config) as result: - pages_info = [] - total_processed = 0 - - async for page in result.pages(): - page_size = len(page) - total_processed += page_size - - # Extract sample user data, handling both dict-like and object-like access - sample_user = None - if page: - first_row = page[0] - if hasattr(first_row, "__getitem__"): - # Dictionary-like access - try: - sample_user = { - "id": str(first_row["id"]), - "name": first_row["name"], - "email": first_row["email"], - } - except (KeyError, TypeError): - # Fall back to attribute access - sample_user = { - "id": str(first_row.id), - "name": first_row.name, - "email": first_row.email, - } - else: - # Object-like access - sample_user = { - "id": str(first_row.id), - "name": first_row.name, - "email": first_row.email, - } - - pages_info.append( - { - "page_number": len(pages_info) + 1, - "rows_in_page": page_size, - "sample_user": sample_user, - } - ) - - return { - "total_rows_processed": total_processed, - "pages_info": pages_info, - "metadata": { - "fetch_size": fetch_size, - "max_pages_limit": max_pages, - "streaming_mode": "page_by_page", - }, - } - - except Exception as e: - raise handle_cassandra_error(e, "streaming users by pages") - - -@app.get("/users/{user_id}", response_model=User) -async def get_user(user_id: str): - """Get user by ID.""" - if session is None: - raise HTTPException( - status_code=503, - detail="Service temporarily unavailable: Cassandra connection not established", - ) - - try: - user_uuid = uuid.UUID(user_id) - except ValueError: - raise HTTPException(status_code=400, detail="Invalid UUID") - - try: - stmt = await session.prepare("SELECT * FROM users WHERE id = ?") - result = await session.execute(stmt, [user_uuid]) - row = result.one() - - if not row: - raise HTTPException(status_code=404, detail="User not found") - - return User( - id=str(row.id), - name=row.name, - email=row.email, - age=row.age, - created_at=row.created_at, - updated_at=row.updated_at, - ) - except HTTPException: - raise - except Exception as e: - raise handle_cassandra_error(e, "checking user existence") - - -@app.delete("/users/{user_id}", status_code=204) -async def delete_user(user_id: str): - """Delete user by ID.""" - if session is None: - raise HTTPException( - status_code=503, - detail="Service temporarily unavailable: Cassandra connection not established", - ) - - try: - user_uuid = uuid.UUID(user_id) - except ValueError: - raise HTTPException(status_code=400, detail="Invalid user ID format") - - try: - stmt = await session.prepare("DELETE FROM users WHERE id = ?") - await session.execute(stmt, [user_uuid]) - - return None # 204 No Content - except Exception as e: - error_msg = str(e) - if any( - keyword in error_msg.lower() - for keyword in ["unavailable", "nohost", 
"connection", "timeout"] - ): - raise HTTPException( - status_code=503, - detail=f"Service temporarily unavailable: Cassandra connection issue - {error_msg}", - ) - raise HTTPException(status_code=500, detail=f"Internal server error: {error_msg}") - - -@app.put("/users/{user_id}", response_model=User) -async def update_user(user_id: str, user_update: UserUpdate): - """Update user by ID.""" - if session is None: - raise HTTPException( - status_code=503, - detail="Service temporarily unavailable: Cassandra connection not established", - ) - - try: - user_uuid = uuid.UUID(user_id) - except ValueError: - raise HTTPException(status_code=400, detail="Invalid user ID format") - - try: - # First check if user exists - check_stmt = await session.prepare("SELECT * FROM users WHERE id = ?") - result = await session.execute(check_stmt, [user_uuid]) - existing_user = result.one() - - if not existing_user: - raise HTTPException(status_code=404, detail="User not found") - except HTTPException: - raise - except Exception as e: - raise handle_cassandra_error(e, "checking user existence") - - try: - # Build update query dynamically based on provided fields - update_fields = [] - params = [] - - if user_update.name is not None: - update_fields.append("name = ?") - params.append(user_update.name) - - if user_update.email is not None: - update_fields.append("email = ?") - params.append(user_update.email) - - if user_update.age is not None: - update_fields.append("age = ?") - params.append(user_update.age) - - if not update_fields: - raise HTTPException(status_code=400, detail="No fields to update") - - # Always update the updated_at timestamp - update_fields.append("updated_at = ?") - params.append(datetime.now()) - params.append(user_uuid) # WHERE clause - - # Build a static query based on which fields are provided - # This approach avoids dynamic SQL construction - if len(update_fields) == 1: # Only updated_at - update_stmt = await session.prepare("UPDATE users SET updated_at = ? WHERE id = ?") - elif len(update_fields) == 2: # One field + updated_at - if "name = ?" in update_fields: - update_stmt = await session.prepare( - "UPDATE users SET name = ?, updated_at = ? WHERE id = ?" - ) - elif "email = ?" in update_fields: - update_stmt = await session.prepare( - "UPDATE users SET email = ?, updated_at = ? WHERE id = ?" - ) - elif "age = ?" in update_fields: - update_stmt = await session.prepare( - "UPDATE users SET age = ?, updated_at = ? WHERE id = ?" - ) - elif len(update_fields) == 3: # Two fields + updated_at - if "name = ?" in update_fields and "email = ?" in update_fields: - update_stmt = await session.prepare( - "UPDATE users SET name = ?, email = ?, updated_at = ? WHERE id = ?" - ) - elif "name = ?" in update_fields and "age = ?" in update_fields: - update_stmt = await session.prepare( - "UPDATE users SET name = ?, age = ?, updated_at = ? WHERE id = ?" - ) - elif "email = ?" in update_fields and "age = ?" in update_fields: - update_stmt = await session.prepare( - "UPDATE users SET email = ?, age = ?, updated_at = ? WHERE id = ?" - ) - else: # All fields - update_stmt = await session.prepare( - "UPDATE users SET name = ?, email = ?, age = ?, updated_at = ? WHERE id = ?" 
- ) - - await session.execute(update_stmt, params) - - # Return updated user - result = await session.execute(check_stmt, [user_uuid]) - updated_user = result.one() - - return User( - id=str(updated_user.id), - name=updated_user.name, - email=updated_user.email, - age=updated_user.age, - created_at=updated_user.created_at, - updated_at=updated_user.updated_at, - ) - except HTTPException: - raise - except Exception as e: - raise handle_cassandra_error(e, "checking user existence") - - -@app.patch("/users/{user_id}", response_model=User) -async def partial_update_user(user_id: str, user_update: UserUpdate): - """Partial update user by ID (same as PUT in this implementation).""" - return await update_user(user_id, user_update) - - -# Performance testing endpoints -@app.get("/performance/async") -async def test_async_performance(requests: int = Query(100, ge=1, le=1000)): - """Test async performance with concurrent queries.""" - if session is None: - raise HTTPException( - status_code=503, - detail="Service temporarily unavailable: Cassandra connection not established", - ) - - import time - - try: - start_time = time.time() - - # Prepare statement once - stmt = await session.prepare("SELECT * FROM users LIMIT 1") - - # Execute queries concurrently - async def execute_query(): - return await session.execute(stmt) - - tasks = [execute_query() for _ in range(requests)] - results = await asyncio.gather(*tasks) - - end_time = time.time() - duration = end_time - start_time - - return { - "requests": requests, - "total_time": duration, - "requests_per_second": requests / duration if duration > 0 else 0, - "avg_time_per_request": duration / requests if requests > 0 else 0, - "successful_requests": len(results), - "mode": "async", - } - except Exception as e: - raise handle_cassandra_error(e, "performance test") - - -@app.get("/performance/sync") -async def test_sync_performance(requests: int = Query(100, ge=1, le=1000)): - """Test TRUE sync performance using synchronous cassandra-driver.""" - if sync_session is None: - raise HTTPException( - status_code=503, - detail="Service temporarily unavailable: Sync Cassandra connection not established", - ) - - import time - - try: - # Run synchronous operations in a thread pool to not block the event loop - import concurrent.futures - - def run_sync_test(): - start_time = time.time() - - # Prepare statement once - stmt = sync_session.prepare("SELECT * FROM users LIMIT 1") - - # Execute queries sequentially with the SYNC driver - results = [] - for _ in range(requests): - result = sync_session.execute(stmt) - results.append(result) - - end_time = time.time() - duration = end_time - start_time - - return { - "requests": requests, - "total_time": duration, - "requests_per_second": requests / duration if duration > 0 else 0, - "avg_time_per_request": duration / requests if requests > 0 else 0, - "successful_requests": len(results), - "mode": "sync (true blocking)", - } - - # Run in thread pool to avoid blocking the event loop - loop = asyncio.get_event_loop() - with concurrent.futures.ThreadPoolExecutor() as pool: - result = await loop.run_in_executor(pool, run_sync_test) - - return result - except Exception as e: - raise handle_cassandra_error(e, "sync performance test") - - -# Batch operations endpoint -@app.post("/users/batch", status_code=201) -async def create_users_batch(batch_data: dict): - """Create multiple users in a batch.""" - if session is None: - raise HTTPException( - status_code=503, - detail="Service temporarily unavailable: Cassandra connection 
not established", - ) - - try: - users = batch_data.get("users", []) - created_users = [] - - for user_data in users: - user_id = uuid.uuid4() - now = datetime.now() - - # Create user dict with proper fields - user_dict = { - "id": str(user_id), - "name": user_data.get("name", user_data.get("username", "")), - "email": user_data["email"], - "age": user_data.get("age", 25), - "created_at": now.isoformat(), - "updated_at": now.isoformat(), - } - - # Insert into database - stmt = await session.prepare( - "INSERT INTO users (id, name, email, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)" - ) - await session.execute( - stmt, [user_id, user_dict["name"], user_dict["email"], user_dict["age"], now, now] - ) - - created_users.append(user_dict) - - return {"created": created_users} - except Exception as e: - raise handle_cassandra_error(e, "batch user creation") - - -# Metrics endpoint -@app.get("/metrics") -async def get_metrics(): - """Get application metrics.""" - # Simple metrics implementation - return { - "total_requests": 1000, # Placeholder - "query_performance": { - "avg_response_time_ms": 50, - "p95_response_time_ms": 100, - "p99_response_time_ms": 200, - }, - "cassandra_connections": {"active": 10, "idle": 5, "total": 15}, - } - - -# Shutdown endpoint -@app.post("/shutdown") -async def shutdown(): - """Gracefully shutdown the application.""" - # In a real app, this would trigger graceful shutdown - return {"message": "Shutdown initiated"} - - -# Slow query endpoint for testing -@app.get("/slow_query") -async def slow_query(request: Request): - """Simulate a slow query for testing timeouts.""" - - # Check for timeout header - timeout_header = request.headers.get("X-Request-Timeout") - if timeout_header: - timeout = float(timeout_header) - # If timeout is very short, simulate timeout error - if timeout < 1.0: - raise HTTPException(status_code=504, detail="Gateway Timeout") - - await asyncio.sleep(5) # Simulate slow operation - return {"message": "Slow query completed"} - - -# Long running query endpoint -@app.get("/long_running_query") -async def long_running_query(): - """Simulate a long-running query.""" - await asyncio.sleep(10) # Simulate very long operation - return {"message": "Long query completed"} - - -# ============================================================================ -# Context Manager Safety Endpoints -# ============================================================================ - - -@app.post("/context_manager_safety/query_error") -async def test_query_error_session_safety(): - """Test that query errors don't close the session.""" - # Track session state - session_id_before = id(session) - is_closed_before = session.is_closed - - # Execute a bad query that will fail - try: - await session.execute("SELECT * FROM non_existent_table_xyz") - except Exception as e: - error_message = str(e) - - # Verify session is still usable - session_id_after = id(session) - is_closed_after = session.is_closed - - # Try a valid query to prove session works - result = await session.execute("SELECT release_version FROM system.local") - version = result.one().release_version - - return { - "test": "query_error_session_safety", - "session_unchanged": session_id_before == session_id_after, - "session_open": not is_closed_after and not is_closed_before, - "error_caught": error_message, - "session_still_works": bool(version), - "cassandra_version": version, - } - - -@app.post("/context_manager_safety/streaming_error") -async def test_streaming_error_session_safety(): - """Test that 
streaming errors don't close the session.""" - session_id_before = id(session) - error_message = None - stream_completed = False - - # Try to stream from non-existent table - try: - async with await session.execute_stream( - "SELECT * FROM non_existent_stream_table" - ) as stream: - async for row in stream: - pass - stream_completed = True - except Exception as e: - error_message = str(e) - - # Verify session is still usable - session_id_after = id(session) - - # Try a valid streaming query - row_count = 0 - # Use hardcoded query since keyspace is constant - stmt = await session.prepare("SELECT * FROM example.users LIMIT ?") - async with await session.execute_stream(stmt, [10]) as stream: - async for row in stream: - row_count += 1 - - return { - "test": "streaming_error_session_safety", - "session_unchanged": session_id_before == session_id_after, - "session_open": not session.is_closed, - "streaming_error_caught": bool(error_message), - "error_message": error_message, - "stream_completed": stream_completed, - "session_still_streams": row_count > 0, - "rows_after_error": row_count, - } - - -@app.post("/context_manager_safety/concurrent_streams") -async def test_concurrent_streams(): - """Test multiple concurrent streams don't interfere.""" - - # Create test data - users_to_create = [] - for i in range(30): - users_to_create.append( - { - "id": str(uuid.uuid4()), - "name": f"Stream Test User {i}", - "email": f"stream{i}@test.com", - "age": 20 + (i % 3) * 10, # Ages: 20, 30, 40 - } - ) - - # Insert test data - for user in users_to_create: - stmt = await session.prepare( - "INSERT INTO example.users (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - await session.execute( - stmt, - [UUID(user["id"]), user["name"], user["email"], user["age"]], - ) - - # Stream different age groups concurrently - async def stream_age_group(age: int) -> dict: - count = 0 - users = [] - - config = StreamConfig(fetch_size=5) - stmt = await session.prepare("SELECT * FROM example.users WHERE age = ? 
ALLOW FILTERING") - async with await session.execute_stream( - stmt, - [age], - stream_config=config, - ) as stream: - async for row in stream: - count += 1 - users.append(row.name) - - return {"age": age, "count": count, "users": users[:3]} # First 3 names - - # Run concurrent streams - results = await asyncio.gather(stream_age_group(20), stream_age_group(30), stream_age_group(40)) - - # Clean up test data - for user in users_to_create: - stmt = await session.prepare("DELETE FROM example.users WHERE id = ?") - await session.execute(stmt, [UUID(user["id"])]) - - return { - "test": "concurrent_streams", - "streams_completed": len(results), - "all_streams_independent": all(r["count"] == 10 for r in results), - "results": results, - "session_still_open": not session.is_closed, - } - - -@app.post("/context_manager_safety/nested_contexts") -async def test_nested_context_managers(): - """Test nested context managers close in correct order.""" - events = [] - - # Create a temporary keyspace for this test - temp_keyspace = f"test_nested_{uuid.uuid4().hex[:8]}" - - try: - # Create new cluster context - async with AsyncCluster(["127.0.0.1"]) as test_cluster: - events.append("cluster_opened") - - # Create session context - async with await test_cluster.connect() as test_session: - events.append("session_opened") - - # Create keyspace with safe identifier - # Validate keyspace name contains only safe characters - if not temp_keyspace.replace("_", "").isalnum(): - raise ValueError("Invalid keyspace name") - - # Use parameterized query for keyspace creation is not supported - # So we validate the input first - await test_session.execute( - f""" - CREATE KEYSPACE {temp_keyspace} - WITH REPLICATION = {{ - 'class': 'SimpleStrategy', - 'replication_factor': 1 - }} - """ - ) - await test_session.set_keyspace(temp_keyspace) - - # Create table - await test_session.execute( - """ - CREATE TABLE test_table ( - id UUID PRIMARY KEY, - value INT - ) - """ - ) - - # Insert test data - for i in range(5): - stmt = await test_session.prepare( - "INSERT INTO test_table (id, value) VALUES (?, ?)" - ) - await test_session.execute(stmt, [uuid.uuid4(), i]) - - # Create streaming context - row_count = 0 - async with await test_session.execute_stream("SELECT * FROM test_table") as stream: - events.append("stream_opened") - async for row in stream: - row_count += 1 - events.append("stream_closed") - - # Verify session still works after stream closed - result = await test_session.execute("SELECT COUNT(*) FROM test_table") - count_after_stream = result.one()[0] - events.append(f"session_works_after_stream:{count_after_stream}") - - # Session will close here - events.append("session_closing") - - events.append("session_closed") - - # Verify cluster still works after session closed - async with await test_cluster.connect() as verify_session: - result = await verify_session.execute("SELECT now() FROM system.local") - events.append(f"cluster_works_after_session:{bool(result.one())}") - - # Clean up keyspace - # Validate keyspace name before using in DROP - if temp_keyspace.replace("_", "").isalnum(): - await verify_session.execute(f"DROP KEYSPACE IF EXISTS {temp_keyspace}") - - # Cluster will close here - events.append("cluster_closing") - - events.append("cluster_closed") - - except Exception as e: - events.append(f"error:{str(e)}") - # Try to clean up - try: - # Validate keyspace name before cleanup - if temp_keyspace.replace("_", "").isalnum(): - await session.execute(f"DROP KEYSPACE IF EXISTS {temp_keyspace}") - except 
Exception: - pass - - # Verify our main session is still working - main_session_works = False - try: - result = await session.execute("SELECT now() FROM system.local") - main_session_works = bool(result.one()) - except Exception: - pass - - return { - "test": "nested_context_managers", - "events": events, - "correct_order": events - == [ - "cluster_opened", - "session_opened", - "stream_opened", - "stream_closed", - "session_works_after_stream:5", - "session_closing", - "session_closed", - "cluster_works_after_session:True", - "cluster_closing", - "cluster_closed", - ], - "row_count": row_count, - "main_session_unaffected": main_session_works, - } - - -@app.post("/context_manager_safety/cancellation") -async def test_streaming_cancellation(): - """Test that cancelled streaming operations clean up properly.""" - - # Create test data - test_ids = [] - for i in range(100): - test_id = uuid.uuid4() - test_ids.append(test_id) - stmt = await session.prepare( - "INSERT INTO example.users (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - await session.execute( - stmt, - [test_id, f"Cancel Test {i}", f"cancel{i}@test.com", 25], - ) - - # Start a streaming operation that we'll cancel - rows_before_cancel = 0 - cancelled = False - error_type = None - - async def stream_with_delay(): - nonlocal rows_before_cancel - try: - stmt = await session.prepare( - "SELECT * FROM example.users WHERE age = ? ALLOW FILTERING" - ) - async with await session.execute_stream(stmt, [25]) as stream: - async for row in stream: - rows_before_cancel += 1 - # Add delay to make cancellation more likely - await asyncio.sleep(0.01) - except asyncio.CancelledError: - nonlocal cancelled - cancelled = True - raise - except Exception as e: - nonlocal error_type - error_type = type(e).__name__ - raise - - # Create task and cancel it - task = asyncio.create_task(stream_with_delay()) - await asyncio.sleep(0.1) # Let it process some rows - task.cancel() - - # Wait for cancellation - try: - await task - except asyncio.CancelledError: - pass - - # Verify session still works - session_works = False - row_count_after = 0 - - try: - # Count rows to verify session works - stmt = await session.prepare( - "SELECT COUNT(*) FROM example.users WHERE age = ? ALLOW FILTERING" - ) - result = await session.execute(stmt, [25]) - row_count_after = result.one()[0] - session_works = True - - # Try streaming again - new_stream_count = 0 - stmt = await session.prepare( - "SELECT * FROM example.users WHERE age = ? LIMIT ? 
ALLOW FILTERING" - ) - async with await session.execute_stream(stmt, [25, 10]) as stream: - async for row in stream: - new_stream_count += 1 - - except Exception as e: - error_type = f"post_cancel_error:{type(e).__name__}" - - # Clean up test data - for test_id in test_ids: - stmt = await session.prepare("DELETE FROM example.users WHERE id = ?") - await session.execute(stmt, [test_id]) - - return { - "test": "streaming_cancellation", - "rows_processed_before_cancel": rows_before_cancel, - "was_cancelled": cancelled, - "session_still_works": session_works, - "total_rows": row_count_after, - "new_stream_worked": new_stream_count == 10, - "error_type": error_type, - "session_open": not session.is_closed, - } - - -@app.get("/context_manager_safety/status") -async def context_manager_safety_status(): - """Get current session and cluster status.""" - return { - "session_open": not session.is_closed, - "session_id": id(session), - "cluster_open": not cluster.is_closed, - "cluster_id": id(cluster), - "keyspace": keyspace, - } - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/fastapi_app/main_enhanced.py b/examples/fastapi_app/main_enhanced.py deleted file mode 100644 index 8393f8a..0000000 --- a/examples/fastapi_app/main_enhanced.py +++ /dev/null @@ -1,578 +0,0 @@ -""" -Enhanced FastAPI example demonstrating all async-cassandra features. - -This comprehensive example demonstrates: -- Timeout handling -- Streaming with memory management -- Connection monitoring -- Rate limiting -- Error handling -- Metrics collection - -Run with: uvicorn main_enhanced:app --reload -""" - -import asyncio -import os -import uuid -from contextlib import asynccontextmanager -from datetime import datetime -from typing import List, Optional - -from fastapi import BackgroundTasks, FastAPI, HTTPException, Query -from pydantic import BaseModel - -from async_cassandra import AsyncCluster, StreamConfig -from async_cassandra.constants import MAX_CONCURRENT_QUERIES -from async_cassandra.metrics import create_metrics_system -from async_cassandra.monitoring import RateLimitedSession, create_monitored_session - - -# Pydantic models -class UserCreate(BaseModel): - name: str - email: str - age: int - - -class User(BaseModel): - id: str - name: str - email: str - age: int - created_at: datetime - updated_at: datetime - - -class UserUpdate(BaseModel): - name: Optional[str] = None - email: Optional[str] = None - age: Optional[int] = None - - -class ConnectionHealth(BaseModel): - status: str - healthy_hosts: int - unhealthy_hosts: int - total_connections: int - avg_latency_ms: Optional[float] - timestamp: datetime - - -class UserBatch(BaseModel): - users: List[UserCreate] - - -# Global resources -session = None -monitor = None -metrics = None - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Manage application lifecycle with enhanced features.""" - global session, monitor, metrics - - # Create metrics system - metrics = create_metrics_system(backend="memory", prometheus_enabled=False) - - # Create monitored session with rate limiting - contact_points = os.getenv("CASSANDRA_HOSTS", "localhost").split(",") - # port = int(os.getenv("CASSANDRA_PORT", "9042")) # Not used in create_monitored_session - - # Use create_monitored_session for automatic monitoring setup - session, monitor = await create_monitored_session( - contact_points=contact_points, - max_concurrent=MAX_CONCURRENT_QUERIES, # Rate limiting - warmup=True, # Pre-establish connections - ) - - # 
Add metrics to session - session.session._metrics = metrics # For rate limited session - - # Set up keyspace and tables - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS example - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.session.set_keyspace("example") - - # Drop and recreate table for clean test environment - await session.execute("DROP TABLE IF EXISTS users") - await session.execute( - """ - CREATE TABLE users ( - id UUID PRIMARY KEY, - name TEXT, - email TEXT, - age INT, - created_at TIMESTAMP, - updated_at TIMESTAMP - ) - """ - ) - - # Start continuous monitoring - asyncio.create_task(monitor.start_monitoring(interval=30)) - - yield - - # Graceful shutdown - await monitor.stop_monitoring() - await session.session.close() - - -# Create FastAPI app -app = FastAPI( - title="Enhanced FastAPI + async-cassandra", - description="Comprehensive example with all features", - version="2.0.0", - lifespan=lifespan, -) - - -@app.get("/") -async def root(): - """Root endpoint.""" - return { - "message": "Enhanced FastAPI + async-cassandra example", - "features": [ - "Timeout handling", - "Memory-efficient streaming", - "Connection monitoring", - "Rate limiting", - "Metrics collection", - "Error handling", - ], - } - - -@app.get("/health", response_model=ConnectionHealth) -async def health_check(): - """Enhanced health check with connection monitoring.""" - try: - # Get cluster metrics - cluster_metrics = await monitor.get_cluster_metrics() - - # Calculate average latency - latencies = [h.latency_ms for h in cluster_metrics.hosts if h.latency_ms] - avg_latency = sum(latencies) / len(latencies) if latencies else None - - return ConnectionHealth( - status="healthy" if cluster_metrics.healthy_hosts > 0 else "unhealthy", - healthy_hosts=cluster_metrics.healthy_hosts, - unhealthy_hosts=cluster_metrics.unhealthy_hosts, - total_connections=cluster_metrics.total_connections, - avg_latency_ms=avg_latency, - timestamp=cluster_metrics.timestamp, - ) - except Exception as e: - raise HTTPException(status_code=503, detail=f"Health check failed: {str(e)}") - - -@app.get("/monitoring/hosts") -async def get_host_status(): - """Get detailed host status from monitoring.""" - cluster_metrics = await monitor.get_cluster_metrics() - - return { - "cluster_name": cluster_metrics.cluster_name, - "protocol_version": cluster_metrics.protocol_version, - "hosts": [ - { - "address": host.address, - "datacenter": host.datacenter, - "rack": host.rack, - "status": host.status, - "latency_ms": host.latency_ms, - "last_check": host.last_check.isoformat() if host.last_check else None, - "error": host.last_error, - } - for host in cluster_metrics.hosts - ], - } - - -@app.get("/monitoring/summary") -async def get_connection_summary(): - """Get connection summary.""" - return monitor.get_connection_summary() - - -@app.post("/users", response_model=User, status_code=201) -async def create_user(user: UserCreate, background_tasks: BackgroundTasks): - """Create a new user with timeout handling.""" - user_id = uuid.uuid4() - now = datetime.now() - - try: - # Prepare with timeout - stmt = await session.session.prepare( - "INSERT INTO users (id, name, email, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)", - timeout=10.0, # 10 second timeout for prepare - ) - - # Execute with timeout (using statement's default timeout) - await session.execute(stmt, [user_id, user.name, user.email, user.age, now, now]) - - # Background task to update metrics - 
background_tasks.add_task(update_user_count) - - return User( - id=str(user_id), - name=user.name, - email=user.email, - age=user.age, - created_at=now, - updated_at=now, - ) - except asyncio.TimeoutError: - raise HTTPException(status_code=504, detail="Query timeout") - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create user: {str(e)}") - - -async def update_user_count(): - """Background task to update user count.""" - try: - result = await session.execute("SELECT COUNT(*) FROM users") - count = result.one()[0] - # In a real app, this would update a cache or metrics - print(f"Total users: {count}") - except Exception: - pass # Don't fail background tasks - - -@app.get("/users", response_model=List[User]) -async def list_users( - limit: int = Query(10, ge=1, le=100), - timeout: float = Query(30.0, ge=1.0, le=60.0), -): - """List users with configurable timeout.""" - try: - # Execute with custom timeout using prepared statement - stmt = await session.session.prepare("SELECT * FROM users LIMIT ?") - result = await session.execute( - stmt, - [limit], - timeout=timeout, - ) - - users = [] - async for row in result: - users.append( - User( - id=str(row.id), - name=row.name, - email=row.email, - age=row.age, - created_at=row.created_at, - updated_at=row.updated_at, - ) - ) - - return users - except asyncio.TimeoutError: - raise HTTPException(status_code=504, detail=f"Query timeout after {timeout}s") - - -@app.get("/users/stream/advanced") -async def stream_users_advanced( - limit: int = Query(1000, ge=0, le=100000), - fetch_size: int = Query(100, ge=10, le=5000), - max_pages: Optional[int] = Query(None, ge=1, le=1000), - timeout_seconds: Optional[float] = Query(None, ge=1.0, le=300.0), -): - """Advanced streaming with all configuration options.""" - try: - # Create stream config with all options - stream_config = StreamConfig( - fetch_size=fetch_size, - max_pages=max_pages, - timeout_seconds=timeout_seconds, - ) - - # Track streaming progress - progress = { - "pages_fetched": 0, - "rows_processed": 0, - "start_time": datetime.now(), - } - - def page_callback(page_number: int, page_size: int): - progress["pages_fetched"] = page_number - progress["rows_processed"] += page_size - - stream_config.page_callback = page_callback - - # Execute streaming query with prepared statement - # Note: LIMIT is not needed with paging - fetch_size controls data flow - stmt = await session.session.prepare("SELECT * FROM users") - - users = [] - - # CRITICAL: Always use context manager to prevent resource leaks - async with await session.session.execute_stream( - stmt, - stream_config=stream_config, - ) as stream: - async for row in stream: - users.append( - { - "id": str(row.id), - "name": row.name, - "email": row.email, - } - ) - - # Note: If you need to limit results, track count manually - # The fetch_size in StreamConfig controls page size efficiently - if limit and len(users) >= limit: - break - - end_time = datetime.now() - duration = (end_time - progress["start_time"]).total_seconds() - - return { - "users": users, - "metadata": { - "total_returned": len(users), - "pages_fetched": progress["pages_fetched"], - "rows_processed": progress["rows_processed"], - "duration_seconds": duration, - "rows_per_second": progress["rows_processed"] / duration if duration > 0 else 0, - "config": { - "fetch_size": fetch_size, - "max_pages": max_pages, - "timeout_seconds": timeout_seconds, - }, - }, - } - except asyncio.TimeoutError: - raise HTTPException(status_code=504, detail="Streaming 
timeout") - except Exception as e: - raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}") - - -@app.get("/users/{user_id}", response_model=User) -async def get_user(user_id: str): - """Get user by ID with proper error handling.""" - try: - user_uuid = uuid.UUID(user_id) - except ValueError: - raise HTTPException(status_code=400, detail="Invalid UUID format") - - try: - stmt = await session.session.prepare("SELECT * FROM users WHERE id = ?") - result = await session.execute(stmt, [user_uuid]) - row = result.one() - - if not row: - raise HTTPException(status_code=404, detail="User not found") - - return User( - id=str(row.id), - name=row.name, - email=row.email, - age=row.age, - created_at=row.created_at, - updated_at=row.updated_at, - ) - except HTTPException: - raise - except Exception as e: - # Check for NoHostAvailable - if "NoHostAvailable" in str(type(e)): - raise HTTPException(status_code=503, detail="No Cassandra hosts available") - raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}") - - -@app.get("/metrics/queries") -async def get_query_metrics(): - """Get query performance metrics.""" - if not metrics or not hasattr(metrics, "collectors"): - return {"error": "Metrics not available"} - - # Get stats from in-memory collector - for collector in metrics.collectors: - if hasattr(collector, "get_stats"): - stats = await collector.get_stats() - return stats - - return {"error": "No stats available"} - - -@app.get("/rate_limit/status") -async def get_rate_limit_status(): - """Get rate limiting status.""" - if isinstance(session, RateLimitedSession): - return { - "rate_limiting_enabled": True, - "metrics": session.get_metrics(), - "max_concurrent": session.semaphore._value, - } - return {"rate_limiting_enabled": False} - - -@app.post("/test/timeout") -async def test_timeout_handling( - operation: str = Query("connect", pattern="^(connect|prepare|execute)$"), - timeout: float = Query(5.0, ge=0.1, le=30.0), -): - """Test timeout handling for different operations.""" - try: - if operation == "connect": - # Test connection timeout - cluster = AsyncCluster(["nonexistent.host"]) - await cluster.connect(timeout=timeout) - - elif operation == "prepare": - # Test prepare timeout (simulate with sleep) - await asyncio.wait_for(asyncio.sleep(timeout + 1), timeout=timeout) - - elif operation == "execute": - # Test execute timeout - await session.execute("SELECT * FROM users", timeout=timeout) - - return {"message": f"{operation} completed within {timeout}s"} - - except asyncio.TimeoutError: - return { - "error": "timeout", - "operation": operation, - "timeout_seconds": timeout, - "message": f"{operation} timed out after {timeout}s", - } - except Exception as e: - return { - "error": "exception", - "operation": operation, - "message": str(e), - } - - -@app.post("/test/concurrent_load") -async def test_concurrent_load( - concurrent_requests: int = Query(50, ge=1, le=500), - query_type: str = Query("read", pattern="^(read|write)$"), -): - """Test system under concurrent load.""" - start_time = datetime.now() - - async def execute_query(i: int): - try: - if query_type == "read": - await session.execute("SELECT * FROM users LIMIT 1") - return {"success": True, "index": i} - else: - user_id = uuid.uuid4() - stmt = await session.session.prepare( - "INSERT INTO users (id, name, email, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)" - ) - await session.execute( - stmt, - [ - user_id, - f"LoadTest{i}", - f"load{i}@test.com", - 25, - datetime.now(), - 
datetime.now(), - ], - ) - return {"success": True, "index": i, "user_id": str(user_id)} - except Exception as e: - return {"success": False, "index": i, "error": str(e)} - - # Execute queries concurrently - tasks = [execute_query(i) for i in range(concurrent_requests)] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Analyze results - successful = sum(1 for r in results if isinstance(r, dict) and r.get("success")) - failed = len(results) - successful - - end_time = datetime.now() - duration = (end_time - start_time).total_seconds() - - # Get rate limit metrics if available - rate_limit_metrics = {} - if isinstance(session, RateLimitedSession): - rate_limit_metrics = session.get_metrics() - - return { - "test_summary": { - "concurrent_requests": concurrent_requests, - "query_type": query_type, - "successful": successful, - "failed": failed, - "duration_seconds": duration, - "requests_per_second": concurrent_requests / duration if duration > 0 else 0, - }, - "rate_limit_metrics": rate_limit_metrics, - "timestamp": datetime.now().isoformat(), - } - - -@app.post("/users/batch") -async def create_users_batch(batch: UserBatch): - """Create multiple users in a batch operation.""" - try: - # Prepare the insert statement - stmt = await session.session.prepare( - "INSERT INTO users (id, name, email, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)" - ) - - created_users = [] - now = datetime.now() - - # Execute batch inserts - for user_data in batch.users: - user_id = uuid.uuid4() - await session.execute( - stmt, [user_id, user_data.name, user_data.email, user_data.age, now, now] - ) - created_users.append( - { - "id": str(user_id), - "name": user_data.name, - "email": user_data.email, - "age": user_data.age, - "created_at": now.isoformat(), - "updated_at": now.isoformat(), - } - ) - - return {"created": len(created_users), "users": created_users} - except Exception as e: - raise HTTPException(status_code=500, detail=f"Batch creation failed: {str(e)}") - - -@app.delete("/users/cleanup") -async def cleanup_test_users(): - """Clean up test users created during load testing.""" - try: - # Delete all users with LoadTest prefix - # Note: LIKE is not supported in Cassandra, we need to fetch all and filter - result = await session.execute("SELECT id, name FROM users") - - deleted_count = 0 - async for row in result: - if row.name and row.name.startswith("LoadTest"): - # Use prepared statement for delete - delete_stmt = await session.session.prepare("DELETE FROM users WHERE id = ?") - await session.execute(delete_stmt, [row.id]) - deleted_count += 1 - - return {"deleted": deleted_count} - except Exception as e: - raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}") - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/fastapi_app/requirements-ci.txt b/examples/fastapi_app/requirements-ci.txt deleted file mode 100644 index 5988c47..0000000 --- a/examples/fastapi_app/requirements-ci.txt +++ /dev/null @@ -1,13 +0,0 @@ -# FastAPI and web server -fastapi>=0.100.0 -uvicorn[standard]>=0.23.0 -pydantic>=2.0.0 -pydantic[email]>=2.0.0 - -# HTTP client for testing -httpx>=0.24.0 - -# Testing dependencies -pytest>=7.0.0 -pytest-asyncio>=0.21.0 -testcontainers[cassandra]>=3.7.0 diff --git a/examples/fastapi_app/requirements.txt b/examples/fastapi_app/requirements.txt deleted file mode 100644 index 1a1da90..0000000 --- a/examples/fastapi_app/requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -# FastAPI Example 
Requirements -fastapi>=0.100.0 -uvicorn[standard]>=0.23.0 -httpx>=0.24.0 # For testing -pydantic>=2.0.0 -pydantic[email]>=2.0.0 - -# Install async-cassandra from parent directory in development -# In production, use: async-cassandra>=0.1.0 diff --git a/examples/fastapi_app/test_debug.py b/examples/fastapi_app/test_debug.py deleted file mode 100644 index 3f977a8..0000000 --- a/examples/fastapi_app/test_debug.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python3 -"""Debug FastAPI test issues.""" - -import asyncio -import sys - -sys.path.insert(0, ".") - -from main import app, session - - -async def test_lifespan(): - """Test if lifespan is triggered.""" - print(f"Initial session: {session}") - - # Manually trigger lifespan - async with app.router.lifespan_context(app): - print(f"Session after lifespan: {session}") - - # Test a simple query - if session: - result = await session.execute("SELECT now() FROM system.local") - print(f"Query result: {result}") - - -if __name__ == "__main__": - asyncio.run(test_lifespan()) diff --git a/examples/fastapi_app/test_error_detection.py b/examples/fastapi_app/test_error_detection.py deleted file mode 100644 index e44971b..0000000 --- a/examples/fastapi_app/test_error_detection.py +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env python -""" -Test script to demonstrate enhanced Cassandra error detection in FastAPI app. -""" - -import asyncio - -import httpx - - -async def test_error_detection(): - """Test various error scenarios to demonstrate proper error detection.""" - - async with httpx.AsyncClient(base_url="http://localhost:8000") as client: - print("Testing Enhanced Cassandra Error Detection") - print("=" * 50) - - # Test 1: Health check - print("\n1. Testing health check endpoint...") - response = await client.get("/health") - print(f" Status: {response.status_code}") - print(f" Response: {response.json()}") - - # Test 2: Create a user (should work if Cassandra is up) - print("\n2. Testing user creation...") - user_data = {"name": "Test User", "email": "test@example.com", "age": 30} - try: - response = await client.post("/users", json=user_data) - print(f" Status: {response.status_code}") - if response.status_code == 201: - print(f" Created user: {response.json()['id']}") - else: - print(f" Error: {response.json()}") - except Exception as e: - print(f" Request failed: {e}") - - # Test 3: Invalid query (should get 500, not 503) - print("\n3. Testing invalid UUID handling...") - try: - response = await client.get("/users/not-a-uuid") - print(f" Status: {response.status_code}") - print(f" Response: {response.json()}") - except Exception as e: - print(f" Request failed: {e}") - - # Test 4: Non-existent user (should get 404, not 503) - print("\n4. 
Testing non-existent user...") - try: - response = await client.get("/users/00000000-0000-0000-0000-000000000000") - print(f" Status: {response.status_code}") - print(f" Response: {response.json()}") - except Exception as e: - print(f" Request failed: {e}") - - print("\n" + "=" * 50) - print("Error detection test completed!") - print("\nKey observations:") - print("- 503 errors: Cassandra unavailability (connection issues)") - print("- 500 errors: Other server errors (invalid queries, etc.)") - print("- 400/404 errors: Client errors (invalid input, not found)") - - -if __name__ == "__main__": - print("Starting FastAPI app error detection test...") - print("Make sure the FastAPI app is running on http://localhost:8000") - print() - - asyncio.run(test_error_detection()) diff --git a/examples/fastapi_app/tests/conftest.py b/examples/fastapi_app/tests/conftest.py deleted file mode 100644 index 50623a1..0000000 --- a/examples/fastapi_app/tests/conftest.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -Pytest configuration for FastAPI example app tests. -""" - -import sys -from pathlib import Path - -import httpx -import pytest -import pytest_asyncio -from httpx import ASGITransport - -# Add parent directories to path -sys.path.insert(0, str(Path(__file__).parent.parent)) # fastapi_app dir -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) # project root - -# Import test utils -from tests.test_utils import cleanup_keyspace, create_test_keyspace, generate_unique_keyspace - - -@pytest_asyncio.fixture -async def unique_test_keyspace(): - """Create a unique keyspace for each test.""" - from async_cassandra import AsyncCluster - - cluster = AsyncCluster(contact_points=["localhost"], protocol_version=5) - session = await cluster.connect() - - # Create unique keyspace - keyspace = generate_unique_keyspace("fastapi_test") - await create_test_keyspace(session, keyspace) - - yield keyspace - - # Cleanup - await cleanup_keyspace(session, keyspace) - await session.close() - await cluster.shutdown() - - -@pytest_asyncio.fixture -async def app_client(unique_test_keyspace): - """Create test client for the FastAPI app with isolated keyspace.""" - # First, check that Cassandra is available - from async_cassandra import AsyncCluster - - try: - test_cluster = AsyncCluster(contact_points=["localhost"], protocol_version=5) - test_session = await test_cluster.connect() - await test_session.execute("SELECT now() FROM system.local") - await test_session.close() - await test_cluster.shutdown() - except Exception as e: - pytest.skip(f"Cassandra not available: {e}") - - # Set the test keyspace in environment - import os - - os.environ["TEST_KEYSPACE"] = unique_test_keyspace - - from main import app, lifespan - - # Manually handle lifespan since httpx doesn't do it properly - async with lifespan(app): - transport = ASGITransport(app=app) - async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - yield client - - # Clean up environment - os.environ.pop("TEST_KEYSPACE", None) diff --git a/examples/fastapi_app/tests/test_fastapi_app.py b/examples/fastapi_app/tests/test_fastapi_app.py deleted file mode 100644 index 5ae1ab5..0000000 --- a/examples/fastapi_app/tests/test_fastapi_app.py +++ /dev/null @@ -1,413 +0,0 @@ -""" -Comprehensive test suite for the FastAPI example application. - -This validates that the example properly demonstrates all the -improvements made to the async-cassandra library. 
-""" - -import asyncio -import time -import uuid - -import httpx -import pytest -import pytest_asyncio -from httpx import ASGITransport - - -class TestFastAPIExample: - """Test suite for FastAPI example application.""" - - @pytest_asyncio.fixture - async def app_client(self): - """Create test client for the FastAPI app.""" - # First, check that Cassandra is available - from async_cassandra import AsyncCluster - - try: - test_cluster = AsyncCluster(contact_points=["localhost"]) - test_session = await test_cluster.connect() - await test_session.execute("SELECT now() FROM system.local") - await test_session.close() - await test_cluster.shutdown() - except Exception as e: - pytest.skip(f"Cassandra not available: {e}") - - from main import app, lifespan - - # Manually handle lifespan since httpx doesn't do it properly - async with lifespan(app): - transport = ASGITransport(app=app) - async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - yield client - - @pytest.mark.asyncio - async def test_health_and_basic_operations(self, app_client): - """Test health check and basic CRUD operations.""" - print("\n=== Testing Health and Basic Operations ===") - - # Health check - health_resp = await app_client.get("/health") - assert health_resp.status_code == 200 - assert health_resp.json()["status"] == "healthy" - print("✓ Health check passed") - - # Create user - user_data = {"name": "Test User", "email": "test@example.com", "age": 30} - create_resp = await app_client.post("/users", json=user_data) - assert create_resp.status_code == 201 - user = create_resp.json() - print(f"✓ Created user: {user['id']}") - - # Get user - get_resp = await app_client.get(f"/users/{user['id']}") - assert get_resp.status_code == 200 - assert get_resp.json()["name"] == user_data["name"] - print("✓ Retrieved user successfully") - - # Update user - update_data = {"age": 31} - update_resp = await app_client.put(f"/users/{user['id']}", json=update_data) - assert update_resp.status_code == 200 - assert update_resp.json()["age"] == 31 - print("✓ Updated user successfully") - - # Delete user - delete_resp = await app_client.delete(f"/users/{user['id']}") - assert delete_resp.status_code == 204 - print("✓ Deleted user successfully") - - @pytest.mark.asyncio - async def test_thread_safety_under_concurrency(self, app_client): - """Test thread safety improvements with concurrent operations.""" - print("\n=== Testing Thread Safety Under Concurrency ===") - - async def create_and_read_user(user_id: int): - """Create a user and immediately read it back.""" - # Create - user_data = { - "name": f"Concurrent User {user_id}", - "email": f"concurrent{user_id}@test.com", - "age": 25 + (user_id % 10), - } - create_resp = await app_client.post("/users", json=user_data) - if create_resp.status_code != 201: - return None - - created_user = create_resp.json() - - # Immediately read back - get_resp = await app_client.get(f"/users/{created_user['id']}") - if get_resp.status_code != 200: - return None - - return get_resp.json() - - # Run many concurrent operations - num_concurrent = 50 - start_time = time.time() - - results = await asyncio.gather( - *[create_and_read_user(i) for i in range(num_concurrent)], return_exceptions=True - ) - - duration = time.time() - start_time - - # Check results - successful = [r for r in results if isinstance(r, dict)] - errors = [r for r in results if isinstance(r, Exception)] - - print(f"✓ Completed {num_concurrent} concurrent operations in {duration:.2f}s") - print(f" - Successful: 
{len(successful)}") - print(f" - Errors: {len(errors)}") - - # Thread safety should ensure high success rate - assert len(successful) >= num_concurrent * 0.95 # 95% success rate - - # Verify data consistency - for user in successful: - assert "id" in user - assert "name" in user - assert user["created_at"] is not None - - @pytest.mark.asyncio - async def test_streaming_memory_efficiency(self, app_client): - """Test streaming functionality for memory efficiency.""" - print("\n=== Testing Streaming Memory Efficiency ===") - - # Create a batch of users for streaming - batch_size = 100 - batch_data = { - "users": [ - {"name": f"Stream Test {i}", "email": f"stream{i}@test.com", "age": 20 + (i % 50)} - for i in range(batch_size) - ] - } - - batch_resp = await app_client.post("/users/batch", json=batch_data) - assert batch_resp.status_code == 201 - print(f"✓ Created {batch_size} users for streaming test") - - # Test regular streaming - stream_resp = await app_client.get(f"/users/stream?limit={batch_size}&fetch_size=10") - assert stream_resp.status_code == 200 - stream_data = stream_resp.json() - - assert stream_data["metadata"]["streaming_enabled"] is True - assert stream_data["metadata"]["pages_fetched"] > 1 - assert len(stream_data["users"]) >= batch_size - print( - f"✓ Streamed {len(stream_data['users'])} users in {stream_data['metadata']['pages_fetched']} pages" - ) - - # Test page-by-page streaming - pages_resp = await app_client.get( - f"/users/stream/pages?limit={batch_size}&fetch_size=10&max_pages=5" - ) - assert pages_resp.status_code == 200 - pages_data = pages_resp.json() - - assert pages_data["metadata"]["streaming_mode"] == "page_by_page" - assert len(pages_data["pages_info"]) <= 5 - print( - f"✓ Page-by-page streaming: {pages_data['total_rows_processed']} rows in {len(pages_data['pages_info'])} pages" - ) - - @pytest.mark.asyncio - async def test_error_handling_consistency(self, app_client): - """Test error handling improvements.""" - print("\n=== Testing Error Handling Consistency ===") - - # Test invalid UUID handling - invalid_uuid_resp = await app_client.get("/users/not-a-uuid") - assert invalid_uuid_resp.status_code == 400 - assert "Invalid UUID" in invalid_uuid_resp.json()["detail"] - print("✓ Invalid UUID error handled correctly") - - # Test non-existent resource - fake_uuid = str(uuid.uuid4()) - not_found_resp = await app_client.get(f"/users/{fake_uuid}") - assert not_found_resp.status_code == 404 - assert "User not found" in not_found_resp.json()["detail"] - print("✓ Resource not found error handled correctly") - - # Test validation errors - missing required field - invalid_user_resp = await app_client.post( - "/users", json={"name": "Test"} # Missing email and age - ) - assert invalid_user_resp.status_code == 422 - print("✓ Validation error handled correctly") - - # Test streaming with invalid parameters - invalid_stream_resp = await app_client.get("/users/stream?fetch_size=0") - assert invalid_stream_resp.status_code == 422 - print("✓ Streaming parameter validation working") - - @pytest.mark.asyncio - async def test_performance_comparison(self, app_client): - """Test performance endpoints to validate async benefits.""" - print("\n=== Testing Performance Comparison ===") - - # Compare async vs sync performance - num_requests = 50 - - # Test async performance - async_resp = await app_client.get(f"/performance/async?requests={num_requests}") - assert async_resp.status_code == 200 - async_data = async_resp.json() - - # Test sync performance - sync_resp = await 
app_client.get(f"/performance/sync?requests={num_requests}") - assert sync_resp.status_code == 200 - sync_data = sync_resp.json() - - print(f"✓ Async performance: {async_data['requests_per_second']:.1f} req/s") - print(f"✓ Sync performance: {sync_data['requests_per_second']:.1f} req/s") - print( - f"✓ Speedup factor: {async_data['requests_per_second'] / sync_data['requests_per_second']:.1f}x" - ) - - # Async should be significantly faster - assert async_data["requests_per_second"] > sync_data["requests_per_second"] - - @pytest.mark.asyncio - async def test_monitoring_endpoints(self, app_client): - """Test monitoring and metrics endpoints.""" - print("\n=== Testing Monitoring Endpoints ===") - - # Test metrics endpoint - metrics_resp = await app_client.get("/metrics") - assert metrics_resp.status_code == 200 - metrics = metrics_resp.json() - - assert "query_performance" in metrics - assert "cassandra_connections" in metrics - print("✓ Metrics endpoint working") - - # Test shutdown endpoint - shutdown_resp = await app_client.post("/shutdown") - assert shutdown_resp.status_code == 200 - assert "Shutdown initiated" in shutdown_resp.json()["message"] - print("✓ Shutdown endpoint working") - - @pytest.mark.asyncio - async def test_timeout_handling(self, app_client): - """Test timeout handling capabilities.""" - print("\n=== Testing Timeout Handling ===") - - # Test with short timeout (should timeout) - timeout_resp = await app_client.get("/slow_query", headers={"X-Request-Timeout": "0.1"}) - assert timeout_resp.status_code == 504 - print("✓ Short timeout handled correctly") - - # Test with adequate timeout - success_resp = await app_client.get("/slow_query", headers={"X-Request-Timeout": "10"}) - assert success_resp.status_code == 200 - print("✓ Adequate timeout allows completion") - - @pytest.mark.asyncio - async def test_context_manager_safety(self, app_client): - """Test comprehensive context manager safety in FastAPI.""" - print("\n=== Testing Context Manager Safety ===") - - # Get initial status - status = await app_client.get("/context_manager_safety/status") - assert status.status_code == 200 - initial_state = status.json() - print( - f"✓ Initial state: Session={initial_state['session_open']}, Cluster={initial_state['cluster_open']}" - ) - - # Test 1: Query errors don't close session - print("\nTest 1: Query Error Safety") - query_error_resp = await app_client.post("/context_manager_safety/query_error") - assert query_error_resp.status_code == 200 - query_result = query_error_resp.json() - assert query_result["session_unchanged"] is True - assert query_result["session_open"] is True - assert query_result["session_still_works"] is True - assert "non_existent_table_xyz" in query_result["error_caught"] - print("✓ Query errors don't close session") - print(f" - Error caught: {query_result['error_caught'][:50]}...") - print(f" - Session still works: {query_result['session_still_works']}") - - # Test 2: Streaming errors don't close session - print("\nTest 2: Streaming Error Safety") - stream_error_resp = await app_client.post("/context_manager_safety/streaming_error") - assert stream_error_resp.status_code == 200 - stream_result = stream_error_resp.json() - assert stream_result["session_unchanged"] is True - assert stream_result["session_open"] is True - assert stream_result["streaming_error_caught"] is True - # The session_still_streams might be False if no users exist, but session should work - if not stream_result["session_still_streams"]: - print(f" - Note: No users found 
({stream_result['rows_after_error']} rows)") - # Create a user for subsequent tests - user_resp = await app_client.post( - "/users", json={"name": "Test User", "email": "test@example.com", "age": 30} - ) - assert user_resp.status_code == 201 - print("✓ Streaming errors don't close session") - print(f" - Error caught: {stream_result['error_message'][:50]}...") - print(f" - Session remains open: {stream_result['session_open']}") - - # Test 3: Concurrent streams don't interfere - print("\nTest 3: Concurrent Streams Safety") - concurrent_resp = await app_client.post("/context_manager_safety/concurrent_streams") - assert concurrent_resp.status_code == 200 - concurrent_result = concurrent_resp.json() - print(f" - Debug: Results = {concurrent_result['results']}") - assert concurrent_result["streams_completed"] == 3 - # Check if streams worked independently (each should have 10 users) - if not concurrent_result["all_streams_independent"]: - print( - f" - Warning: Stream counts varied: {[r['count'] for r in concurrent_result['results']]}" - ) - assert concurrent_result["session_still_open"] is True - print("✓ Concurrent streams completed") - for result in concurrent_result["results"]: - print(f" - Age {result['age']}: {result['count']} users") - - # Test 4: Nested context managers - print("\nTest 4: Nested Context Managers") - nested_resp = await app_client.post("/context_manager_safety/nested_contexts") - assert nested_resp.status_code == 200 - nested_result = nested_resp.json() - assert nested_result["correct_order"] is True - assert nested_result["main_session_unaffected"] is True - assert nested_result["row_count"] == 5 - print("✓ Nested contexts close in correct order") - print(f" - Events: {' → '.join(nested_result['events'][:5])}...") - print(f" - Main session unaffected: {nested_result['main_session_unaffected']}") - - # Test 5: Streaming cancellation - print("\nTest 5: Streaming Cancellation Safety") - cancel_resp = await app_client.post("/context_manager_safety/cancellation") - assert cancel_resp.status_code == 200 - cancel_result = cancel_resp.json() - assert cancel_result["was_cancelled"] is True - assert cancel_result["session_still_works"] is True - assert cancel_result["new_stream_worked"] is True - assert cancel_result["session_open"] is True - print("✓ Cancelled streams clean up properly") - print(f" - Rows before cancel: {cancel_result['rows_processed_before_cancel']}") - print(f" - Session works after cancel: {cancel_result['session_still_works']}") - print(f" - New stream successful: {cancel_result['new_stream_worked']}") - - # Verify final state matches initial state - final_status = await app_client.get("/context_manager_safety/status") - assert final_status.status_code == 200 - final_state = final_status.json() - assert final_state["session_id"] == initial_state["session_id"] - assert final_state["cluster_id"] == initial_state["cluster_id"] - assert final_state["session_open"] is True - assert final_state["cluster_open"] is True - print("\n✓ All context manager safety tests passed!") - print(" - Session remained stable throughout all tests") - print(" - No resource leaks detected") - - -async def run_all_tests(): - """Run all tests and print summary.""" - print("=" * 60) - print("FastAPI Example Application Test Suite") - print("=" * 60) - - test_suite = TestFastAPIExample() - - # Create client - from main import app - - async with httpx.AsyncClient(app=app, base_url="http://test") as client: - # Run tests - try: - await test_suite.test_health_and_basic_operations(client) - 
await test_suite.test_thread_safety_under_concurrency(client) - await test_suite.test_streaming_memory_efficiency(client) - await test_suite.test_error_handling_consistency(client) - await test_suite.test_performance_comparison(client) - await test_suite.test_monitoring_endpoints(client) - await test_suite.test_timeout_handling(client) - await test_suite.test_context_manager_safety(client) - - print("\n" + "=" * 60) - print("✅ All tests passed! The FastAPI example properly demonstrates:") - print(" - Thread safety improvements") - print(" - Memory-efficient streaming") - print(" - Consistent error handling") - print(" - Performance benefits of async") - print(" - Monitoring capabilities") - print(" - Timeout handling") - print("=" * 60) - - except AssertionError as e: - print(f"\n❌ Test failed: {e}") - raise - except Exception as e: - print(f"\n❌ Unexpected error: {e}") - raise - - -if __name__ == "__main__": - # Run the test suite - asyncio.run(run_all_tests()) diff --git a/libs/async-cassandra-bulk/pyproject.toml b/libs/async-cassandra-bulk/pyproject.toml index 9013c9c..85a92bc 100644 --- a/libs/async-cassandra-bulk/pyproject.toml +++ b/libs/async-cassandra-bulk/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "High-performance bulk operations for Apache Cassandra" readme = "README_PYPI.md" requires-python = ">=3.12" -license = "Apache-2.0" +license = {text = "Apache-2.0"} authors = [ {name = "AxonOps"}, ] @@ -35,7 +35,7 @@ classifiers = [ ] dependencies = [ - "async-cassandra>=0.1.0", + "async-cassandra>=0.0.1", ] [project.optional-dependencies] diff --git a/libs/async-cassandra/Makefile b/libs/async-cassandra/Makefile index 04ebfdc..044f49c 100644 --- a/libs/async-cassandra/Makefile +++ b/libs/async-cassandra/Makefile @@ -1,37 +1,570 @@ -.PHONY: help install test lint build clean publish-test publish +.PHONY: help install install-dev test test-quick test-core test-critical test-progressive test-all test-unit test-integration test-integration-keep test-stress test-bdd lint format type-check build clean cassandra-start cassandra-stop cassandra-status cassandra-wait help: @echo "Available commands:" - @echo " install Install dependencies" - @echo " test Run tests" - @echo " lint Run linters" - @echo " build Build package" - @echo " clean Clean build artifacts" - @echo " publish-test Publish to TestPyPI" - @echo " publish Publish to PyPI" + @echo "" + @echo "Installation:" + @echo " install Install the package" + @echo " install-dev Install with development dependencies" + @echo " install-examples Install example dependencies (e.g., pyarrow)" + @echo "" + @echo "Quick Test Commands:" + @echo " test-quick Run quick validation tests (~30s)" + @echo " test-core Run core functionality tests only (~1m)" + @echo " test-critical Run critical tests (core + FastAPI) (~2m)" + @echo " test-progressive Run tests in fail-fast order" + @echo "" + @echo "Test Suites:" + @echo " test Run all tests (excluding stress tests)" + @echo " test-unit Run unit tests only" + @echo " test-integration Run integration tests (auto-manages containers)" + @echo " test-integration-keep Run integration tests (keeps containers running)" + @echo " test-stress Run stress tests" + @echo " test-bdd Run BDD tests" + @echo " test-all Run ALL tests (unit, integration, stress, and BDD)" + @echo "" + @echo "Test Categories:" + @echo " test-resilience Run error handling and resilience tests" + @echo " test-features Run advanced feature tests" + @echo " test-fastapi Run FastAPI integration tests" + @echo " 
test-performance Run performance and benchmark tests" + @echo "" + @echo "Cassandra Management:" + @echo " cassandra-start Start Cassandra container" + @echo " cassandra-stop Stop Cassandra container" + @echo " cassandra-status Check if Cassandra is running" + @echo " cassandra-wait Wait for Cassandra to be ready" + @echo "" + @echo "Code Quality:" + @echo " lint Run linters" + @echo " format Format code" + @echo " type-check Run type checking" + @echo "" + @echo "Build:" + @echo " build Build distribution packages" + @echo " clean Clean build artifacts" + @echo "" + @echo "Examples:" + @echo " example-streaming Run streaming basic example" + @echo " example-export-csv Run CSV export example" + @echo " example-export-parquet Run Parquet export example" + @echo " example-realtime Run real-time processing example" + @echo " example-metrics Run metrics collection example" + @echo " example-non-blocking Run non-blocking demo" + @echo " example-context Run context manager safety demo" + @echo " example-fastapi Run FastAPI example app" + @echo " examples-all Run all examples sequentially" + @echo "" + @echo "Environment variables:" + @echo " CASSANDRA_CONTACT_POINTS Cassandra contact points (default: localhost)" + @echo " SKIP_INTEGRATION_TESTS=1 Skip integration tests" + @echo " KEEP_CONTAINERS=1 Keep containers running after tests" install: + pip install -e . + +install-dev: pip install -e ".[dev,test]" + pip install -r requirements-lint.txt + pre-commit install + +install-examples: + @echo "Installing example dependencies..." + pip install -r examples/requirements.txt + +# Environment setup +CONTAINER_RUNTIME ?= $(shell command -v podman >/dev/null 2>&1 && echo podman || echo docker) +CASSANDRA_CONTACT_POINTS ?= 127.0.0.1 +CASSANDRA_PORT ?= 9042 +CASSANDRA_IMAGE ?= cassandra:5 +CASSANDRA_CONTAINER_NAME ?= async-cassandra-test + +# Quick validation (30s) +test-quick: + @echo "Running quick validation tests..." + pytest tests/unit -v -x -m "quick" || pytest tests/unit -v -x -k "test_basic" --maxfail=5 + +# Core tests only (1m) +test-core: + @echo "Running core functionality tests..." + pytest tests/unit/test_basic_queries.py tests/unit/test_cluster.py tests/unit/test_session.py -v -x + +# Critical path - MUST ALL PASS +test-critical: + @echo "Running critical tests..." + @echo "=== Running Critical Unit Tests (No Cassandra) ===" + pytest tests/unit/test_critical_issues.py -v -x + @echo "=== Starting Cassandra for Integration Tests ===" + $(MAKE) cassandra-wait + @echo "=== Running Critical FastAPI Tests ===" + pytest tests/fastapi_integration -v + cd examples/fastapi_app && pytest tests/test_fastapi_app.py -v + @echo "=== Cleaning up Cassandra ===" + $(MAKE) cassandra-stop + +# Progressive execution - FAIL FAST +test-progressive: + @echo "Running tests in fail-fast order..." 
+ @echo "=== Running Core Unit Tests (No Cassandra) ===" + @pytest tests/unit/test_basic_queries.py tests/unit/test_cluster.py tests/unit/test_session.py -v -x || exit 1 + @echo "=== Running Resilience Tests (No Cassandra) ===" + @pytest tests/unit/test_error_recovery.py tests/unit/test_retry_policy.py -v -x || exit 1 + @echo "=== Running Feature Tests (No Cassandra) ===" + @pytest tests/unit/test_streaming.py tests/unit/test_prepared_statements.py -v || exit 1 + @echo "=== Starting Cassandra for Integration Tests ===" + @$(MAKE) cassandra-wait || exit 1 + @echo "=== Running Integration Tests ===" + @pytest tests/integration -v || exit 1 + @echo "=== Running FastAPI Integration Tests ===" + @pytest tests/fastapi_integration -v || exit 1 + @echo "=== Running FastAPI Example App Tests ===" + @cd examples/fastapi_app && pytest tests/test_fastapi_app.py -v || exit 1 + @echo "=== Running BDD Tests ===" + @pytest tests/bdd -v || exit 1 + @echo "=== Cleaning up Cassandra ===" + @$(MAKE) cassandra-stop + +# Test suite commands +test-resilience: + @echo "Running resilience tests..." + pytest tests/unit/test_error_recovery.py tests/unit/test_retry_policy.py tests/unit/test_timeout_handling.py -v + +test-features: + @echo "Running feature tests..." + pytest tests/unit/test_streaming.py tests/unit/test_prepared_statements.py tests/unit/test_metrics.py -v + +test-performance: + @echo "Running performance tests..." + pytest tests/benchmarks -v + +# BDD tests - MUST PASS +test-bdd: cassandra-wait + @echo "Running BDD tests..." + @mkdir -p reports + pytest tests/bdd/ -v + +# Standard test command - runs everything except stress test: - pytest tests/ + @echo "Running standard test suite..." + @echo "=== Running Unit Tests (No Cassandra Required) ===" + pytest tests/unit/ -v + @echo "=== Starting Cassandra for Integration Tests ===" + $(MAKE) cassandra-wait + @echo "=== Running Integration/FastAPI/BDD Tests ===" + pytest tests/integration/ tests/fastapi_integration/ tests/bdd/ -v -m "not stress" + @echo "=== Cleaning up Cassandra ===" + $(MAKE) cassandra-stop + +test-unit: + @echo "Running unit tests (no Cassandra required)..." + pytest tests/unit/ -v --cov=async_cassandra --cov-report=html + @echo "Unit tests completed." + +test-integration: cassandra-wait + @echo "Running integration tests..." + CASSANDRA_CONTACT_POINTS=$(CASSANDRA_CONTACT_POINTS) pytest tests/integration/ -v -m "not stress" + @echo "Integration tests completed." + +test-integration-keep: cassandra-wait + @echo "Running integration tests (keeping containers after tests)..." + KEEP_CONTAINERS=1 CASSANDRA_CONTACT_POINTS=$(CASSANDRA_CONTACT_POINTS) pytest tests/integration/ -v -m "not stress" + @echo "Integration tests completed. Containers are still running." + +test-fastapi: cassandra-wait + @echo "Running FastAPI integration tests with real app and Cassandra..." + CASSANDRA_CONTACT_POINTS=$(CASSANDRA_CONTACT_POINTS) pytest tests/fastapi_integration/ -v + @echo "Running FastAPI example app tests..." + cd examples/fastapi_app && CASSANDRA_CONTACT_POINTS=$(CASSANDRA_CONTACT_POINTS) pytest tests/test_fastapi_app.py -v + @echo "FastAPI integration tests completed." + +test-stress: cassandra-wait + @echo "Running stress tests..." + CASSANDRA_CONTACT_POINTS=$(CASSANDRA_CONTACT_POINTS) pytest tests/integration/test_stress.py tests/benchmarks/ -v -m stress + @echo "Stress tests completed." + +# Full test suite - EVERYTHING MUST PASS +test-all: lint + @echo "Running complete test suite..." 
+ @echo "=== Running Unit Tests (No Cassandra Required) ===" + pytest tests/unit/ -v --cov=async_cassandra --cov-report=html --cov-report=xml + + @echo "=== Running Integration Tests ===" + $(MAKE) cassandra-stop || true + $(MAKE) cassandra-wait + pytest tests/integration/ -v -m "not stress" + + @echo "=== Running FastAPI Integration Tests ===" + $(MAKE) cassandra-stop + $(MAKE) cassandra-wait + pytest tests/fastapi_integration/ -v + @echo "=== Running BDD Tests ===" + $(MAKE) cassandra-stop + $(MAKE) cassandra-wait + pytest tests/bdd/ -v + + @echo "=== Running Example App Tests ===" + $(MAKE) cassandra-stop + $(MAKE) cassandra-wait + cd examples/fastapi_app && pytest tests/ -v + + @echo "=== Running Stress Tests ===" + $(MAKE) cassandra-stop + $(MAKE) cassandra-wait + pytest tests/integration/ -v -m stress + + @echo "=== Cleaning up Cassandra ===" + $(MAKE) cassandra-stop + @echo "✅ All tests completed!" + +# Code quality - MUST PASS lint: - ruff check src tests - black --check src tests - isort --check-only src tests - mypy src + @echo "=== Running ruff ===" + ruff check src/ tests/ + @echo "=== Running black ===" + black --check src/ tests/ + @echo "=== Running isort ===" + isort --check-only src/ tests/ + @echo "=== Running mypy ===" + mypy src/ + +format: + black src/ tests/ + isort src/ tests/ -build: clean +type-check: + mypy src/ + +# Build +build: python -m build +# Cassandra management +cassandra-start: + @echo "Starting Cassandra container..." + @echo "Stopping any existing Cassandra container..." + @$(CONTAINER_RUNTIME) stop $(CASSANDRA_CONTAINER_NAME) 2>/dev/null || true + @$(CONTAINER_RUNTIME) rm -f $(CASSANDRA_CONTAINER_NAME) 2>/dev/null || true + @$(CONTAINER_RUNTIME) run -d \ + --name $(CASSANDRA_CONTAINER_NAME) \ + -p $(CASSANDRA_PORT):9042 \ + -e CASSANDRA_CLUSTER_NAME=TestCluster \ + -e CASSANDRA_DC=datacenter1 \ + -e CASSANDRA_ENDPOINT_SNITCH=GossipingPropertyFileSnitch \ + -e HEAP_NEWSIZE=512M \ + -e MAX_HEAP_SIZE=3G \ + -e JVM_OPTS="-XX:+UseG1GC -XX:G1RSetUpdatingPauseTimePercent=5 -XX:MaxGCPauseMillis=300" \ + --memory=4g \ + --memory-swap=4g \ + $(CASSANDRA_IMAGE) + @echo "Cassandra container started" + +cassandra-stop: + @echo "Stopping Cassandra container..." + @$(CONTAINER_RUNTIME) stop $(CASSANDRA_CONTAINER_NAME) 2>/dev/null || true + @$(CONTAINER_RUNTIME) rm $(CASSANDRA_CONTAINER_NAME) 2>/dev/null || true + @echo "Cassandra container stopped" + +cassandra-status: + @if $(CONTAINER_RUNTIME) ps --format "{{.Names}}" | grep -q "^$(CASSANDRA_CONTAINER_NAME)$$"; then \ + echo "Cassandra container is running"; \ + if $(CONTAINER_RUNTIME) exec $(CASSANDRA_CONTAINER_NAME) nodetool info 2>&1 | grep -q "Native Transport active: true"; then \ + if $(CONTAINER_RUNTIME) exec $(CASSANDRA_CONTAINER_NAME) cqlsh -e "SELECT release_version FROM system.local" 2>&1 | grep -q "[0-9]"; then \ + echo "Cassandra is ready and accepting CQL queries"; \ + else \ + echo "Cassandra native transport is active but CQL not ready yet"; \ + fi; \ + else \ + echo "Cassandra is starting up..."; \ + fi; \ + else \ + echo "Cassandra container is not running"; \ + exit 1; \ + fi + +cassandra-wait: + @echo "Ensuring Cassandra is ready..." + @if ! 
nc -z $(CASSANDRA_CONTACT_POINTS) $(CASSANDRA_PORT) 2>/dev/null; then \ + echo "Cassandra not running on $(CASSANDRA_CONTACT_POINTS):$(CASSANDRA_PORT), starting container..."; \ + $(MAKE) cassandra-start; \ + echo "Waiting for Cassandra to be ready..."; \ + for i in $$(seq 1 60); do \ + if $(CONTAINER_RUNTIME) exec $(CASSANDRA_CONTAINER_NAME) nodetool info 2>&1 | grep -q "Native Transport active: true"; then \ + if $(CONTAINER_RUNTIME) exec $(CASSANDRA_CONTAINER_NAME) cqlsh -e "SELECT release_version FROM system.local" 2>&1 | grep -q "[0-9]"; then \ + echo "Cassandra is ready! (verified with SELECT query)"; \ + exit 0; \ + fi; \ + fi; \ + printf "."; \ + sleep 2; \ + done; \ + echo ""; \ + echo "Timeout waiting for Cassandra"; \ + exit 1; \ + else \ + echo "Checking if Cassandra on $(CASSANDRA_CONTACT_POINTS):$(CASSANDRA_PORT) can accept queries..."; \ + if [ "$(CASSANDRA_CONTACT_POINTS)" = "127.0.0.1" ] && $(CONTAINER_RUNTIME) ps --format "{{.Names}}" | grep -q "^$(CASSANDRA_CONTAINER_NAME)$$"; then \ + if ! $(CONTAINER_RUNTIME) exec $(CASSANDRA_CONTAINER_NAME) cqlsh -e "SELECT release_version FROM system.local" 2>&1 | grep -q "[0-9]"; then \ + echo "Cassandra is running but not accepting queries yet, waiting..."; \ + for i in $$(seq 1 30); do \ + if $(CONTAINER_RUNTIME) exec $(CASSANDRA_CONTAINER_NAME) cqlsh -e "SELECT release_version FROM system.local" 2>&1 | grep -q "[0-9]"; then \ + echo "Cassandra is ready! (verified with SELECT query)"; \ + exit 0; \ + fi; \ + printf "."; \ + sleep 2; \ + done; \ + echo ""; \ + echo "Timeout waiting for Cassandra to accept queries"; \ + exit 1; \ + fi; \ + fi; \ + echo "Cassandra is already running and accepting queries"; \ + fi + +# Cleanup clean: - rm -rf dist/ build/ *.egg-info/ + rm -rf build/ + rm -rf dist/ + rm -rf *.egg-info + rm -rf .coverage + rm -rf htmlcov/ + rm -rf .pytest_cache/ + rm -rf .mypy_cache/ + rm -rf reports/*.json reports/*.html reports/*.xml find . -type d -name __pycache__ -exec rm -rf {} + find . -type f -name "*.pyc" -delete -publish-test: build - python -m twine upload --repository testpypi dist/* +clean-all: clean cassandra-stop + @echo "All cleaned up" + +# Example targets +.PHONY: example-streaming example-export-csv example-export-parquet example-realtime example-metrics example-non-blocking example-context example-fastapi examples-all + +# Ensure examples can connect to Cassandra +EXAMPLES_ENV = CASSANDRA_CONTACT_POINTS=$(CASSANDRA_CONTACT_POINTS) + +example-streaming: cassandra-wait + @echo "" + @echo "╔══════════════════════════════════════════════════════════════════════════════╗" + @echo "║ STREAMING BASIC EXAMPLE ║" + @echo "╠══════════════════════════════════════════════════════════════════════════════╣" + @echo "║ This example demonstrates memory-efficient streaming of large result sets ║" + @echo "║ ║" + @echo "║ What you'll see: ║" + @echo "║ • Streaming 100,000 events without loading all into memory ║" + @echo "║ • Progress tracking with page-by-page processing ║" + @echo "║ • True Async Paging - pages fetched on-demand as you process ║" + @echo "║ • Different streaming patterns (basic, filtered, page-based) ║" + @echo "╚══════════════════════════════════════════════════════════════════════════════╝" + @echo "" + @echo "📡 Connecting to Cassandra at $(CASSANDRA_CONTACT_POINTS)..." 
+ @echo "" + @$(EXAMPLES_ENV) python examples/streaming_basic.py + +example-export-csv: cassandra-wait + @echo "" + @echo "╔══════════════════════════════════════════════════════════════════════════════╗" + @echo "║ CSV EXPORT EXAMPLE ║" + @echo "╠══════════════════════════════════════════════════════════════════════════════╣" + @echo "║ This example exports a large Cassandra table to CSV format efficiently ║" + @echo "║ ║" + @echo "║ What you'll see: ║" + @echo "║ • Creating and populating a sample products table (5,000 items) ║" + @echo "║ • Streaming export with progress tracking ║" + @echo "║ • Memory-efficient processing (no loading entire table into memory) ║" + @echo "║ • Export statistics (rows/sec, file size, duration) ║" + @echo "╚══════════════════════════════════════════════════════════════════════════════╝" + @echo "" + @echo "📡 Connecting to Cassandra at $(CASSANDRA_CONTACT_POINTS)..." + @echo "💾 Output will be saved to: $(EXAMPLE_OUTPUT_DIR)" + @echo "" + @$(EXAMPLES_ENV) python examples/export_large_table.py + +example-export-parquet: cassandra-wait + @echo "" + @echo "╔══════════════════════════════════════════════════════════════════════════════╗" + @echo "║ PARQUET EXPORT EXAMPLE ║" + @echo "╠══════════════════════════════════════════════════════════════════════════════╣" + @echo "║ This example exports Cassandra tables to Parquet format with streaming ║" + @echo "║ ║" + @echo "║ What you'll see: ║" + @echo "║ • Creating time-series data with complex types (30,000+ events) ║" + @echo "║ • Three export scenarios: ║" + @echo "║ - Full table export with snappy compression ║" + @echo "║ - Filtered export (purchase events only) with gzip ║" + @echo "║ - Different compression comparison (lz4) ║" + @echo "║ • Automatic schema inference from Cassandra types ║" + @echo "║ • Verification of exported Parquet files ║" + @echo "╚══════════════════════════════════════════════════════════════════════════════╝" + @echo "" + @echo "📡 Connecting to Cassandra at $(CASSANDRA_CONTACT_POINTS)..." + @echo "💾 Output will be saved to: $(EXAMPLE_OUTPUT_DIR)" + @echo "📦 Installing PyArrow if needed..." + @pip install pyarrow >/dev/null 2>&1 || echo "✅ PyArrow ready" + @echo "" + @$(EXAMPLES_ENV) python examples/export_to_parquet.py + +example-realtime: cassandra-wait + @echo "" + @echo "╔══════════════════════════════════════════════════════════════════════════════╗" + @echo "║ REAL-TIME PROCESSING EXAMPLE ║" + @echo "╠══════════════════════════════════════════════════════════════════════════════╣" + @echo "║ This example demonstrates real-time streaming analytics on sensor data ║" + @echo "║ ║" + @echo "║ What you'll see: ║" + @echo "║ • Simulating IoT sensor network (50 sensors, time-series data) ║" + @echo "║ • Sliding window analytics with time-based queries ║" + @echo "║ • Real-time anomaly detection and alerting ║" + @echo "║ • Continuous monitoring with aggregations ║" + @echo "║ • High-performance streaming of time-series data ║" + @echo "╚══════════════════════════════════════════════════════════════════════════════╝" + @echo "" + @echo "📡 Connecting to Cassandra at $(CASSANDRA_CONTACT_POINTS)..." + @echo "🌡️ Simulating sensor network..." 
+ @echo "" + @$(EXAMPLES_ENV) python examples/realtime_processing.py + +example-metrics: cassandra-wait + @echo "" + @echo "╔══════════════════════════════════════════════════════════════════════════════╗" + @echo "║ METRICS COLLECTION EXAMPLES ║" + @echo "╠══════════════════════════════════════════════════════════════════════════════╣" + @echo "║ These examples demonstrate query performance monitoring and metrics ║" + @echo "║ ║" + @echo "║ Part 1 - Simple Metrics: ║" + @echo "║ • Basic query performance tracking ║" + @echo "║ • Connection health monitoring ║" + @echo "║ • Error rate calculation ║" + @echo "║ ║" + @echo "║ Part 2 - Advanced Metrics: ║" + @echo "║ • Multiple metrics collectors ║" + @echo "║ • Prometheus integration patterns ║" + @echo "║ • FastAPI integration examples ║" + @echo "╚══════════════════════════════════════════════════════════════════════════════╝" + @echo "" + @echo "📡 Connecting to Cassandra at $(CASSANDRA_CONTACT_POINTS)..." + @echo "" + @echo "📊 Part 1: Simple Metrics..." + @echo "─────────────────────────────" + @$(EXAMPLES_ENV) python examples/metrics_simple.py + @echo "" + @echo "📈 Part 2: Advanced Metrics..." + @echo "──────────────────────────────" + @$(EXAMPLES_ENV) python examples/metrics_example.py + +example-non-blocking: cassandra-wait + @echo "" + @echo "╔══════════════════════════════════════════════════════════════════════════════╗" + @echo "║ NON-BLOCKING STREAMING DEMO ║" + @echo "╠══════════════════════════════════════════════════════════════════════════════╣" + @echo "║ This PROVES that streaming doesn't block the asyncio event loop! ║" + @echo "║ ║" + @echo "║ What you'll see: ║" + @echo "║ • 💓 Heartbeat indicators pulsing every 10ms ║" + @echo "║ • Streaming 50,000 rows while heartbeat continues ║" + @echo "║ • Event loop responsiveness analysis ║" + @echo "║ • Concurrent queries executing during streaming ║" + @echo "║ • Multiple streams running in parallel ║" + @echo "║ ║" + @echo "║ 🔍 Watch the heartbeats - they should NEVER stop! ║" + @echo "╚══════════════════════════════════════════════════════════════════════════════╝" + @echo "" + @echo "📡 Connecting to Cassandra at $(CASSANDRA_CONTACT_POINTS)..." + @echo "" + @$(EXAMPLES_ENV) python examples/streaming_non_blocking_demo.py + +example-context: cassandra-wait + @echo "" + @echo "╔══════════════════════════════════════════════════════════════════════════════╗" + @echo "║ CONTEXT MANAGER SAFETY DEMO ║" + @echo "╠══════════════════════════════════════════════════════════════════════════════╣" + @echo "║ This demonstrates proper resource management with context managers ║" + @echo "║ ║" + @echo "║ What you'll see: ║" + @echo "║ • Query errors DON'T close sessions (resilience) ║" + @echo "║ • Streaming errors DON'T affect other operations ║" + @echo "║ • Context managers provide proper isolation ║" + @echo "║ • Multiple concurrent operations share resources safely ║" + @echo "║ • Automatic cleanup even during exceptions ║" + @echo "║ ║" + @echo "║ 💡 Key lesson: ALWAYS use context managers! ║" + @echo "╚══════════════════════════════════════════════════════════════════════════════╝" + @echo "" + @echo "📡 Connecting to Cassandra at $(CASSANDRA_CONTACT_POINTS)..." 
+ @echo "" + @$(EXAMPLES_ENV) python examples/context_manager_safety_demo.py + +example-fastapi: + @echo "" + @echo "╔══════════════════════════════════════════════════════════════════════════════╗" + @echo "║ FASTAPI EXAMPLE APP ║" + @echo "╠══════════════════════════════════════════════════════════════════════════════╣" + @echo "║ This starts a full REST API with async Cassandra integration ║" + @echo "║ ║" + @echo "║ Features: ║" + @echo "║ • Complete CRUD operations with async patterns ║" + @echo "║ • Streaming endpoints for large datasets ║" + @echo "║ • Performance comparison endpoints (async vs sync) ║" + @echo "║ • Connection lifecycle management ║" + @echo "║ • Docker Compose for easy development ║" + @echo "║ ║" + @echo "║ 📚 See examples/fastapi_app/README.md for API documentation ║" + @echo "╚══════════════════════════════════════════════════════════════════════════════╝" + @echo "" + @echo "🚀 Starting FastAPI application..." + @echo "" + @cd examples/fastapi_app && $(MAKE) run -publish: build - python -m twine upload dist/* +examples-all: cassandra-wait + @echo "" + @echo "╔══════════════════════════════════════════════════════════════════════════════╗" + @echo "║ RUNNING ALL EXAMPLES ║" + @echo "╠══════════════════════════════════════════════════════════════════════════════╣" + @echo "║ This will run each example in sequence to demonstrate all features ║" + @echo "║ ║" + @echo "║ Examples to run: ║" + @echo "║ 1. Streaming Basic - Memory-efficient data processing ║" + @echo "║ 2. CSV Export - Large table export with progress tracking ║" + @echo "║ 3. Parquet Export - Complex types and compression options ║" + @echo "║ 4. Real-time Processing - IoT sensor analytics ║" + @echo "║ 5. Metrics Collection - Performance monitoring ║" + @echo "║ 6. Non-blocking Demo - Event loop responsiveness proof ║" + @echo "║ 7. Context Managers - Resource management patterns ║" + @echo "╚══════════════════════════════════════════════════════════════════════════════╝" + @echo "" + @echo "📡 Using Cassandra at $(CASSANDRA_CONTACT_POINTS)" + @echo "" + @$(MAKE) example-streaming + @echo "" + @echo "════════════════════════════════════════════════════════════════════════════════" + @echo "" + @$(MAKE) example-export-csv + @echo "" + @echo "════════════════════════════════════════════════════════════════════════════════" + @echo "" + @$(MAKE) example-export-parquet + @echo "" + @echo "════════════════════════════════════════════════════════════════════════════════" + @echo "" + @$(MAKE) example-realtime + @echo "" + @echo "════════════════════════════════════════════════════════════════════════════════" + @echo "" + @$(MAKE) example-metrics + @echo "" + @echo "════════════════════════════════════════════════════════════════════════════════" + @echo "" + @$(MAKE) example-non-blocking + @echo "" + @echo "════════════════════════════════════════════════════════════════════════════════" + @echo "" + @$(MAKE) example-context + @echo "" + @echo "╔══════════════════════════════════════════════════════════════════════════════╗" + @echo "║ ✅ ALL EXAMPLES COMPLETED SUCCESSFULLY! ║" + @echo "╠══════════════════════════════════════════════════════════════════════════════╣" + @echo "║ Note: FastAPI example not included as it starts a server. ║" + @echo "║ Run 'make example-fastapi' separately to start the FastAPI app. 
║" + @echo "╚══════════════════════════════════════════════════════════════════════════════╝" diff --git a/examples/README.md b/libs/async-cassandra/examples/README.md similarity index 100% rename from examples/README.md rename to libs/async-cassandra/examples/README.md diff --git a/examples/bulk_operations/.gitignore b/libs/async-cassandra/examples/bulk_operations/.gitignore similarity index 100% rename from examples/bulk_operations/.gitignore rename to libs/async-cassandra/examples/bulk_operations/.gitignore diff --git a/examples/bulk_operations/Makefile b/libs/async-cassandra/examples/bulk_operations/Makefile similarity index 100% rename from examples/bulk_operations/Makefile rename to libs/async-cassandra/examples/bulk_operations/Makefile diff --git a/examples/bulk_operations/README.md b/libs/async-cassandra/examples/bulk_operations/README.md similarity index 100% rename from examples/bulk_operations/README.md rename to libs/async-cassandra/examples/bulk_operations/README.md diff --git a/examples/bulk_operations/bulk_operations/__init__.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/__init__.py similarity index 100% rename from examples/bulk_operations/bulk_operations/__init__.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/__init__.py diff --git a/examples/bulk_operations/bulk_operations/bulk_operator.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/bulk_operator.py similarity index 100% rename from examples/bulk_operations/bulk_operations/bulk_operator.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/bulk_operator.py diff --git a/examples/bulk_operations/bulk_operations/exporters/__init__.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/__init__.py similarity index 100% rename from examples/bulk_operations/bulk_operations/exporters/__init__.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/__init__.py diff --git a/examples/bulk_operations/bulk_operations/exporters/base.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/base.py similarity index 99% rename from examples/bulk_operations/bulk_operations/exporters/base.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/base.py index 015d629..894ba95 100644 --- a/examples/bulk_operations/bulk_operations/exporters/base.py +++ b/libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/base.py @@ -9,9 +9,8 @@ from pathlib import Path from typing import Any -from cassandra.util import OrderedMap, OrderedMapSerializedKey - from bulk_operations.bulk_operator import TokenAwareBulkOperator +from cassandra.util import OrderedMap, OrderedMapSerializedKey class ExportFormat(Enum): diff --git a/examples/bulk_operations/bulk_operations/exporters/csv_exporter.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/csv_exporter.py similarity index 100% rename from examples/bulk_operations/bulk_operations/exporters/csv_exporter.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/csv_exporter.py diff --git a/examples/bulk_operations/bulk_operations/exporters/json_exporter.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/json_exporter.py similarity index 100% rename from examples/bulk_operations/bulk_operations/exporters/json_exporter.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/json_exporter.py diff --git 
a/examples/bulk_operations/bulk_operations/exporters/parquet_exporter.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/parquet_exporter.py similarity index 99% rename from examples/bulk_operations/bulk_operations/exporters/parquet_exporter.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/parquet_exporter.py index f9835bc..809863c 100644 --- a/examples/bulk_operations/bulk_operations/exporters/parquet_exporter.py +++ b/libs/async-cassandra/examples/bulk_operations/bulk_operations/exporters/parquet_exporter.py @@ -15,9 +15,8 @@ "PyArrow is required for Parquet export. Install with: pip install pyarrow" ) from None -from cassandra.util import OrderedMap, OrderedMapSerializedKey - from bulk_operations.exporters.base import Exporter, ExportFormat, ExportProgress +from cassandra.util import OrderedMap, OrderedMapSerializedKey class ParquetExporter(Exporter): diff --git a/examples/bulk_operations/bulk_operations/iceberg/__init__.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/iceberg/__init__.py similarity index 100% rename from examples/bulk_operations/bulk_operations/iceberg/__init__.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/iceberg/__init__.py diff --git a/examples/bulk_operations/bulk_operations/iceberg/catalog.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/iceberg/catalog.py similarity index 100% rename from examples/bulk_operations/bulk_operations/iceberg/catalog.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/iceberg/catalog.py diff --git a/examples/bulk_operations/bulk_operations/iceberg/exporter.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/iceberg/exporter.py similarity index 99% rename from examples/bulk_operations/bulk_operations/iceberg/exporter.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/iceberg/exporter.py index cd6cb7a..980699e 100644 --- a/examples/bulk_operations/bulk_operations/iceberg/exporter.py +++ b/libs/async-cassandra/examples/bulk_operations/bulk_operations/iceberg/exporter.py @@ -9,17 +9,16 @@ import pyarrow as pa import pyarrow.parquet as pq +from bulk_operations.exporters.base import ExportFormat, ExportProgress +from bulk_operations.exporters.parquet_exporter import ParquetExporter +from bulk_operations.iceberg.catalog import get_or_create_catalog +from bulk_operations.iceberg.schema_mapper import CassandraToIcebergSchemaMapper from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import Table -from bulk_operations.exporters.base import ExportFormat, ExportProgress -from bulk_operations.exporters.parquet_exporter import ParquetExporter -from bulk_operations.iceberg.catalog import get_or_create_catalog -from bulk_operations.iceberg.schema_mapper import CassandraToIcebergSchemaMapper - class IcebergExporter(ParquetExporter): """Export Cassandra data to Apache Iceberg tables. 
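# Illustration only (not in the original patch): the import shuffles in base.py,
# parquet_exporter.py and iceberg/exporter.py above all follow the same pattern.
# Once the example packages move under libs/async-cassandra/, the sorter
# (presumably isort with the project's existing profile) regroups the
# bulk_operations imports next to the other installed packages, e.g.:
from bulk_operations.exporters.base import ExportFormat, ExportProgress
from cassandra.util import OrderedMap, OrderedMapSerializedKey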
diff --git a/examples/bulk_operations/bulk_operations/iceberg/schema_mapper.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/iceberg/schema_mapper.py similarity index 100% rename from examples/bulk_operations/bulk_operations/iceberg/schema_mapper.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/iceberg/schema_mapper.py diff --git a/examples/bulk_operations/bulk_operations/parallel_export.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/parallel_export.py similarity index 100% rename from examples/bulk_operations/bulk_operations/parallel_export.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/parallel_export.py diff --git a/examples/bulk_operations/bulk_operations/stats.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/stats.py similarity index 100% rename from examples/bulk_operations/bulk_operations/stats.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/stats.py diff --git a/examples/bulk_operations/bulk_operations/token_utils.py b/libs/async-cassandra/examples/bulk_operations/bulk_operations/token_utils.py similarity index 100% rename from examples/bulk_operations/bulk_operations/token_utils.py rename to libs/async-cassandra/examples/bulk_operations/bulk_operations/token_utils.py diff --git a/examples/bulk_operations/debug_coverage.py b/libs/async-cassandra/examples/bulk_operations/debug_coverage.py similarity index 99% rename from examples/bulk_operations/debug_coverage.py rename to libs/async-cassandra/examples/bulk_operations/debug_coverage.py index ca8c781..fb7d46b 100644 --- a/examples/bulk_operations/debug_coverage.py +++ b/libs/async-cassandra/examples/bulk_operations/debug_coverage.py @@ -3,10 +3,11 @@ import asyncio -from async_cassandra import AsyncCluster from bulk_operations.bulk_operator import TokenAwareBulkOperator from bulk_operations.token_utils import MIN_TOKEN, discover_token_ranges, generate_token_range_query +from async_cassandra import AsyncCluster + async def debug_coverage(): """Debug why we're missing rows.""" diff --git a/examples/context_manager_safety_demo.py b/libs/async-cassandra/examples/context_manager_safety_demo.py similarity index 100% rename from examples/context_manager_safety_demo.py rename to libs/async-cassandra/examples/context_manager_safety_demo.py diff --git a/examples/exampleoutput/.gitignore b/libs/async-cassandra/examples/exampleoutput/.gitignore similarity index 100% rename from examples/exampleoutput/.gitignore rename to libs/async-cassandra/examples/exampleoutput/.gitignore diff --git a/examples/exampleoutput/README.md b/libs/async-cassandra/examples/exampleoutput/README.md similarity index 100% rename from examples/exampleoutput/README.md rename to libs/async-cassandra/examples/exampleoutput/README.md diff --git a/examples/export_large_table.py b/libs/async-cassandra/examples/export_large_table.py similarity index 100% rename from examples/export_large_table.py rename to libs/async-cassandra/examples/export_large_table.py diff --git a/examples/export_to_parquet.py b/libs/async-cassandra/examples/export_to_parquet.py similarity index 100% rename from examples/export_to_parquet.py rename to libs/async-cassandra/examples/export_to_parquet.py diff --git a/examples/metrics_example.py b/libs/async-cassandra/examples/metrics_example.py similarity index 100% rename from examples/metrics_example.py rename to libs/async-cassandra/examples/metrics_example.py diff --git a/examples/metrics_simple.py 
b/libs/async-cassandra/examples/metrics_simple.py similarity index 100% rename from examples/metrics_simple.py rename to libs/async-cassandra/examples/metrics_simple.py diff --git a/examples/monitoring/alerts.yml b/libs/async-cassandra/examples/monitoring/alerts.yml similarity index 100% rename from examples/monitoring/alerts.yml rename to libs/async-cassandra/examples/monitoring/alerts.yml diff --git a/examples/monitoring/grafana_dashboard.json b/libs/async-cassandra/examples/monitoring/grafana_dashboard.json similarity index 100% rename from examples/monitoring/grafana_dashboard.json rename to libs/async-cassandra/examples/monitoring/grafana_dashboard.json diff --git a/examples/realtime_processing.py b/libs/async-cassandra/examples/realtime_processing.py similarity index 100% rename from examples/realtime_processing.py rename to libs/async-cassandra/examples/realtime_processing.py diff --git a/examples/requirements.txt b/libs/async-cassandra/examples/requirements.txt similarity index 100% rename from examples/requirements.txt rename to libs/async-cassandra/examples/requirements.txt diff --git a/examples/streaming_basic.py b/libs/async-cassandra/examples/streaming_basic.py similarity index 100% rename from examples/streaming_basic.py rename to libs/async-cassandra/examples/streaming_basic.py diff --git a/examples/streaming_non_blocking_demo.py b/libs/async-cassandra/examples/streaming_non_blocking_demo.py similarity index 100% rename from examples/streaming_non_blocking_demo.py rename to libs/async-cassandra/examples/streaming_non_blocking_demo.py diff --git a/libs/async-cassandra/pyproject.toml b/libs/async-cassandra/pyproject.toml index 0b4e643..d513837 100644 --- a/libs/async-cassandra/pyproject.toml +++ b/libs/async-cassandra/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "Async Python wrapper for the Cassandra Python driver" readme = "README_PYPI.md" requires-python = ">=3.12" -license = "Apache-2.0" +license = {text = "Apache-2.0"} authors = [ {name = "AxonOps"}, ] diff --git a/libs/async-cassandra/tests/integration/test_context_manager_safety_integration.py b/libs/async-cassandra/tests/integration/test_context_manager_safety_integration.py index 19df52d..8dca597 100644 --- a/libs/async-cassandra/tests/integration/test_context_manager_safety_integration.py +++ b/libs/async-cassandra/tests/integration/test_context_manager_safety_integration.py @@ -97,6 +97,9 @@ async def test_streaming_error_doesnt_close_session(self, cassandra_session): """ ) + # Clean up any existing data + await cassandra_session.execute("TRUNCATE test_stream_data") + # Insert some data insert_prepared = await cassandra_session.prepare( "INSERT INTO test_stream_data (id, value) VALUES (?, ?)" diff --git a/src/async_cassandra/__init__.py b/src/async_cassandra/__init__.py deleted file mode 100644 index 813e19c..0000000 --- a/src/async_cassandra/__init__.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -async-cassandra: Async Python wrapper for the Cassandra Python driver. - -This package provides true async/await support for Cassandra operations, -addressing performance limitations when using the official driver with -async frameworks like FastAPI. 
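# Illustration only (not in the original patch): minimal usage of the package
# whose src/ copy is deleted below; the library has apparently been relocated
# under libs/async-cassandra/. Sketch assumes a reachable Cassandra 4.0+ node.
import asyncio

from async_cassandra import AsyncCluster


async def main() -> None:
    async with AsyncCluster(contact_points=["127.0.0.1"]) as cluster:
        session = await cluster.connect()  # optionally pass a keyspace
        await session.execute("SELECT release_version FROM system.local")
        await session.close()


asyncio.run(main())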
-""" - -try: - from importlib.metadata import PackageNotFoundError, version - - try: - __version__ = version("async-cassandra") - except PackageNotFoundError: - # Package is not installed - __version__ = "0.0.0+unknown" -except ImportError: - # Python < 3.8 - __version__ = "0.0.0+unknown" - -__author__ = "AxonOps" -__email__ = "community@axonops.com" - -from .cluster import AsyncCluster -from .exceptions import AsyncCassandraError, ConnectionError, QueryError -from .metrics import ( - ConnectionMetrics, - InMemoryMetricsCollector, - MetricsCollector, - MetricsMiddleware, - PrometheusMetricsCollector, - QueryMetrics, - create_metrics_system, -) -from .monitoring import ( - HOST_STATUS_DOWN, - HOST_STATUS_UNKNOWN, - HOST_STATUS_UP, - ClusterMetrics, - ConnectionMonitor, - HostMetrics, - RateLimitedSession, - create_monitored_session, -) -from .result import AsyncResultSet -from .retry_policy import AsyncRetryPolicy -from .session import AsyncCassandraSession -from .streaming import AsyncStreamingResultSet, StreamConfig, create_streaming_statement - -__all__ = [ - "AsyncCassandraSession", - "AsyncCluster", - "AsyncCassandraError", - "ConnectionError", - "QueryError", - "AsyncResultSet", - "AsyncRetryPolicy", - "ConnectionMonitor", - "RateLimitedSession", - "create_monitored_session", - "HOST_STATUS_UP", - "HOST_STATUS_DOWN", - "HOST_STATUS_UNKNOWN", - "HostMetrics", - "ClusterMetrics", - "AsyncStreamingResultSet", - "StreamConfig", - "create_streaming_statement", - "MetricsMiddleware", - "MetricsCollector", - "InMemoryMetricsCollector", - "PrometheusMetricsCollector", - "QueryMetrics", - "ConnectionMetrics", - "create_metrics_system", -] diff --git a/src/async_cassandra/base.py b/src/async_cassandra/base.py deleted file mode 100644 index 6eac5a4..0000000 --- a/src/async_cassandra/base.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -Simplified base classes for async-cassandra. - -This module provides minimal functionality needed for the async wrapper, -avoiding over-engineering and complex locking patterns. -""" - -from typing import Any, TypeVar - -T = TypeVar("T") - - -class AsyncContextManageable: - """ - Simple mixin to add async context manager support. - - Classes using this mixin must implement an async close() method. - """ - - async def __aenter__(self: T) -> T: - """Async context manager entry.""" - return self - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - """Async context manager exit.""" - await self.close() # type: ignore diff --git a/src/async_cassandra/cluster.py b/src/async_cassandra/cluster.py deleted file mode 100644 index dbdd2cb..0000000 --- a/src/async_cassandra/cluster.py +++ /dev/null @@ -1,292 +0,0 @@ -""" -Simplified async cluster management for Cassandra connections. - -This implementation focuses on being a thin wrapper around the driver cluster, -avoiding complex state management. -""" - -import asyncio -from ssl import SSLContext -from typing import Dict, List, Optional - -from cassandra.auth import AuthProvider, PlainTextAuthProvider -from cassandra.cluster import Cluster, Metadata -from cassandra.policies import ( - DCAwareRoundRobinPolicy, - ExponentialReconnectionPolicy, - LoadBalancingPolicy, - ReconnectionPolicy, - RetryPolicy, - TokenAwarePolicy, -) - -from .base import AsyncContextManageable -from .exceptions import ConnectionError -from .retry_policy import AsyncRetryPolicy -from .session import AsyncCassandraSession - - -class AsyncCluster(AsyncContextManageable): - """ - Simplified async wrapper for Cassandra Cluster. 
- - This implementation: - - Uses a single lock only for close operations - - Focuses on being a thin wrapper without complex state management - - Accepts reasonable trade-offs for simplicity - """ - - def __init__( - self, - contact_points: Optional[List[str]] = None, - port: int = 9042, - auth_provider: Optional[AuthProvider] = None, - load_balancing_policy: Optional[LoadBalancingPolicy] = None, - reconnection_policy: Optional[ReconnectionPolicy] = None, - retry_policy: Optional[RetryPolicy] = None, - ssl_context: Optional[SSLContext] = None, - protocol_version: Optional[int] = None, - executor_threads: int = 2, - max_schema_agreement_wait: int = 10, - control_connection_timeout: float = 2.0, - idle_heartbeat_interval: float = 30.0, - schema_event_refresh_window: float = 2.0, - topology_event_refresh_window: float = 10.0, - status_event_refresh_window: float = 2.0, - **kwargs: Dict[str, object], - ): - """ - Initialize async cluster wrapper. - - Args: - contact_points: List of contact points to connect to. - port: Port to connect to on contact points. - auth_provider: Authentication provider. - load_balancing_policy: Load balancing policy to use. - reconnection_policy: Reconnection policy to use. - retry_policy: Retry policy to use. - ssl_context: SSL context for secure connections. - protocol_version: CQL protocol version to use. - executor_threads: Number of executor threads. - max_schema_agreement_wait: Max time to wait for schema agreement. - control_connection_timeout: Timeout for control connection. - idle_heartbeat_interval: Interval for idle heartbeats. - schema_event_refresh_window: Window for schema event refresh. - topology_event_refresh_window: Window for topology event refresh. - status_event_refresh_window: Window for status event refresh. - **kwargs: Additional cluster options as key-value pairs. - """ - # Set defaults - if contact_points is None: - contact_points = ["127.0.0.1"] - - if load_balancing_policy is None: - load_balancing_policy = TokenAwarePolicy(DCAwareRoundRobinPolicy()) - - if reconnection_policy is None: - reconnection_policy = ExponentialReconnectionPolicy(base_delay=1.0, max_delay=60.0) - - if retry_policy is None: - retry_policy = AsyncRetryPolicy() - - # Create the underlying cluster with only non-None parameters - cluster_kwargs = { - "contact_points": contact_points, - "port": port, - "load_balancing_policy": load_balancing_policy, - "reconnection_policy": reconnection_policy, - "default_retry_policy": retry_policy, - "executor_threads": executor_threads, - "max_schema_agreement_wait": max_schema_agreement_wait, - "control_connection_timeout": control_connection_timeout, - "idle_heartbeat_interval": idle_heartbeat_interval, - "schema_event_refresh_window": schema_event_refresh_window, - "topology_event_refresh_window": topology_event_refresh_window, - "status_event_refresh_window": status_event_refresh_window, - } - - # Add optional parameters only if they're not None - if auth_provider is not None: - cluster_kwargs["auth_provider"] = auth_provider - if ssl_context is not None: - cluster_kwargs["ssl_context"] = ssl_context - # Handle protocol version - if protocol_version is not None: - # Validate explicitly specified protocol version - if protocol_version < 5: - from .exceptions import ConfigurationError - - raise ConfigurationError( - f"Protocol version {protocol_version} is not supported. " - "async-cassandra requires CQL protocol v5 or higher for optimal async performance. " - "Protocol v5 was introduced in Cassandra 4.0 (released July 2021). 
" - "Please upgrade your Cassandra cluster to 4.0+ or use a compatible service. " - "If you're using a cloud provider, check their documentation for protocol support." - ) - cluster_kwargs["protocol_version"] = protocol_version - # else: Let driver negotiate to get the highest available version - - # Merge with any additional kwargs - cluster_kwargs.update(kwargs) - - self._cluster = Cluster(**cluster_kwargs) - self._closed = False - self._close_lock = asyncio.Lock() - - @classmethod - def create_with_auth( - cls, contact_points: List[str], username: str, password: str, **kwargs: Dict[str, object] - ) -> "AsyncCluster": - """ - Create cluster with username/password authentication. - - Args: - contact_points: List of contact points to connect to. - username: Username for authentication. - password: Password for authentication. - **kwargs: Additional cluster options as key-value pairs. - - Returns: - New AsyncCluster instance. - """ - auth_provider = PlainTextAuthProvider(username=username, password=password) - - return cls(contact_points=contact_points, auth_provider=auth_provider, **kwargs) # type: ignore[arg-type] - - async def connect( - self, keyspace: Optional[str] = None, timeout: Optional[float] = None - ) -> AsyncCassandraSession: - """ - Connect to the cluster and create a session. - - Args: - keyspace: Optional keyspace to use. - timeout: Connection timeout in seconds. Defaults to DEFAULT_CONNECTION_TIMEOUT. - - Returns: - New AsyncCassandraSession. - - Raises: - ConnectionError: If connection fails or cluster is closed. - asyncio.TimeoutError: If connection times out. - """ - # Simple closed check - no lock needed for read - if self._closed: - raise ConnectionError("Cluster is closed") - - # Import here to avoid circular import - from .constants import DEFAULT_CONNECTION_TIMEOUT, MAX_RETRY_ATTEMPTS - - if timeout is None: - timeout = DEFAULT_CONNECTION_TIMEOUT - - last_error = None - for attempt in range(MAX_RETRY_ATTEMPTS): - try: - session = await asyncio.wait_for( - AsyncCassandraSession.create(self._cluster, keyspace), timeout=timeout - ) - - # Verify we got protocol v5 or higher - negotiated_version = self._cluster.protocol_version - if negotiated_version < 5: - await session.close() - raise ConnectionError( - f"Connected with protocol v{negotiated_version} but v5+ is required. " - f"Your Cassandra server only supports up to protocol v{negotiated_version}. " - "async-cassandra requires CQL protocol v5 or higher (Cassandra 4.0+). " - "Please upgrade your Cassandra cluster to version 4.0 or newer." - ) - - return session - - except asyncio.TimeoutError: - raise - except Exception as e: - last_error = e - - # Check for protocol version mismatch - error_str = str(e) - if "NoHostAvailable" in str(type(e).__name__): - # Check if it's due to protocol version incompatibility - if "ProtocolError" in error_str or "protocol version" in error_str.lower(): - # Don't retry protocol version errors - the server doesn't support v5+ - raise ConnectionError( - "Failed to connect: Your Cassandra server doesn't support protocol v5. " - "async-cassandra requires CQL protocol v5 or higher (Cassandra 4.0+). " - "Please upgrade your Cassandra cluster to version 4.0 or newer." - ) from e - - if attempt < MAX_RETRY_ATTEMPTS - 1: - # Log retry attempt - import logging - - logger = logging.getLogger(__name__) - logger.warning( - f"Connection attempt {attempt + 1} failed: {str(e)}. " - f"Retrying... 
({attempt + 2}/{MAX_RETRY_ATTEMPTS})" - ) - # Small delay before retry to allow service to recover - # Use longer delay for NoHostAvailable errors - if "NoHostAvailable" in str(type(e).__name__): - # For connection reset errors, wait longer - if "Connection reset by peer" in str(e): - await asyncio.sleep(5.0 * (attempt + 1)) - else: - await asyncio.sleep(2.0 * (attempt + 1)) - else: - await asyncio.sleep(0.5 * (attempt + 1)) - - raise ConnectionError( - f"Failed to connect to cluster after {MAX_RETRY_ATTEMPTS} attempts: {str(last_error)}" - ) from last_error - - async def close(self) -> None: - """ - Close the cluster and release all resources. - - This method is idempotent and can be called multiple times safely. - Uses a single lock to ensure shutdown is called only once. - """ - async with self._close_lock: - if not self._closed: - self._closed = True - loop = asyncio.get_event_loop() - # Use a reasonable timeout for shutdown operations - await asyncio.wait_for( - loop.run_in_executor(None, self._cluster.shutdown), timeout=30.0 - ) - # Give the driver's internal threads time to finish - # This helps prevent "cannot schedule new futures after shutdown" errors - # The driver has internal scheduler threads that may still be running - await asyncio.sleep(5.0) - - async def shutdown(self) -> None: - """ - Shutdown the cluster and release all resources. - - This method is idempotent and can be called multiple times safely. - Alias for close() to match driver API. - """ - await self.close() - - @property - def is_closed(self) -> bool: - """Check if the cluster is closed.""" - return self._closed - - @property - def metadata(self) -> Metadata: - """Get cluster metadata.""" - return self._cluster.metadata - - def register_user_type(self, keyspace: str, user_type: str, klass: type) -> None: - """ - Register a user-defined type. - - Args: - keyspace: Keyspace containing the type. - user_type: Name of the user-defined type. - klass: Python class to map the type to. - """ - self._cluster.register_user_type(keyspace, user_type, klass) diff --git a/src/async_cassandra/constants.py b/src/async_cassandra/constants.py deleted file mode 100644 index c93f9fc..0000000 --- a/src/async_cassandra/constants.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Constants used throughout the async-cassandra library. -""" - -# Default values -DEFAULT_FETCH_SIZE = 1000 -DEFAULT_EXECUTOR_THREADS = 4 -DEFAULT_CONNECTION_TIMEOUT = 30.0 # Increased for larger heap sizes -DEFAULT_REQUEST_TIMEOUT = 120.0 - -# Limits -MAX_CONCURRENT_QUERIES = 100 -MAX_RETRY_ATTEMPTS = 3 - -# Thread pool settings -MIN_EXECUTOR_THREADS = 1 -MAX_EXECUTOR_THREADS = 128 diff --git a/src/async_cassandra/exceptions.py b/src/async_cassandra/exceptions.py deleted file mode 100644 index 311a254..0000000 --- a/src/async_cassandra/exceptions.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -Exception classes for async-cassandra. 
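# Illustration only (not in the original patch): how callers typically consume
# the exception hierarchy removed below, reusing the AsyncCluster API shown
# earlier in this diff. Sketch only; the host argument is a placeholder.
from async_cassandra import AsyncCluster, ConnectionError


async def connect_or_none(host: str):
    cluster = AsyncCluster(contact_points=[host])
    try:
        return await cluster.connect()
    except ConnectionError as exc:
        # connect() raises ConnectionError once its retry attempts are exhausted
        print(f"Cassandra unreachable: {exc}")
        await cluster.close()
        return None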
-""" - -from typing import Optional - - -class AsyncCassandraError(Exception): - """Base exception for all async-cassandra errors.""" - - def __init__(self, message: str, cause: Optional[Exception] = None): - super().__init__(message) - self.cause = cause - - -class ConnectionError(AsyncCassandraError): - """Raised when connection to Cassandra fails.""" - - pass - - -class QueryError(AsyncCassandraError): - """Raised when a query execution fails.""" - - pass - - -class TimeoutError(AsyncCassandraError): - """Raised when an operation times out.""" - - pass - - -class AuthenticationError(AsyncCassandraError): - """Raised when authentication fails.""" - - pass - - -class ConfigurationError(AsyncCassandraError): - """Raised when configuration is invalid.""" - - pass diff --git a/src/async_cassandra/metrics.py b/src/async_cassandra/metrics.py deleted file mode 100644 index 90f853d..0000000 --- a/src/async_cassandra/metrics.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -Metrics and observability system for async-cassandra. - -This module provides comprehensive monitoring capabilities including: -- Query performance metrics -- Connection health tracking -- Error rate monitoring -- Custom metrics collection -""" - -import asyncio -import logging -from collections import defaultdict, deque -from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -if TYPE_CHECKING: - from prometheus_client import Counter, Gauge, Histogram - -logger = logging.getLogger(__name__) - - -@dataclass -class QueryMetrics: - """Metrics for individual query execution.""" - - query_hash: str - duration: float - success: bool - error_type: Optional[str] = None - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - parameters_count: int = 0 - result_size: int = 0 - - -@dataclass -class ConnectionMetrics: - """Metrics for connection health.""" - - host: str - is_healthy: bool - last_check: datetime - response_time: float - error_count: int = 0 - total_queries: int = 0 - - -class MetricsCollector: - """Base class for metrics collection backends.""" - - async def record_query(self, metrics: QueryMetrics) -> None: - """Record query execution metrics.""" - raise NotImplementedError - - async def record_connection_health(self, metrics: ConnectionMetrics) -> None: - """Record connection health metrics.""" - raise NotImplementedError - - async def get_stats(self) -> Dict[str, Any]: - """Get aggregated statistics.""" - raise NotImplementedError - - -class InMemoryMetricsCollector(MetricsCollector): - """In-memory metrics collector for development and testing.""" - - def __init__(self, max_entries: int = 10000): - self.max_entries = max_entries - self.query_metrics: deque[QueryMetrics] = deque(maxlen=max_entries) - self.connection_metrics: Dict[str, ConnectionMetrics] = {} - self.error_counts: Dict[str, int] = defaultdict(int) - self.query_counts: Dict[str, int] = defaultdict(int) - self._lock = asyncio.Lock() - - async def record_query(self, metrics: QueryMetrics) -> None: - """Record query execution metrics.""" - async with self._lock: - self.query_metrics.append(metrics) - self.query_counts[metrics.query_hash] += 1 - - if not metrics.success and metrics.error_type: - self.error_counts[metrics.error_type] += 1 - - async def record_connection_health(self, metrics: ConnectionMetrics) -> None: - """Record connection health metrics.""" - async with self._lock: - self.connection_metrics[metrics.host] = metrics - - async 
def get_stats(self) -> Dict[str, Any]: - """Get aggregated statistics.""" - async with self._lock: - if not self.query_metrics: - return {"message": "No metrics available"} - - # Calculate performance stats - recent_queries = [ - q - for q in self.query_metrics - if q.timestamp > datetime.now(timezone.utc) - timedelta(minutes=5) - ] - - if recent_queries: - durations = [q.duration for q in recent_queries] - success_rate = sum(1 for q in recent_queries if q.success) / len(recent_queries) - - stats = { - "query_performance": { - "total_queries": len(self.query_metrics), - "recent_queries_5min": len(recent_queries), - "avg_duration_ms": sum(durations) / len(durations) * 1000, - "min_duration_ms": min(durations) * 1000, - "max_duration_ms": max(durations) * 1000, - "success_rate": success_rate, - "queries_per_second": len(recent_queries) / 300, # 5 minutes - }, - "error_summary": dict(self.error_counts), - "top_queries": dict( - sorted(self.query_counts.items(), key=lambda x: x[1], reverse=True)[:10] - ), - "connection_health": { - host: { - "healthy": metrics.is_healthy, - "response_time_ms": metrics.response_time * 1000, - "error_count": metrics.error_count, - "total_queries": metrics.total_queries, - } - for host, metrics in self.connection_metrics.items() - }, - } - else: - stats = { - "query_performance": {"message": "No recent queries"}, - "error_summary": dict(self.error_counts), - "top_queries": {}, - "connection_health": {}, - } - - return stats - - -class PrometheusMetricsCollector(MetricsCollector): - """Prometheus metrics collector for production monitoring.""" - - def __init__(self) -> None: - self._available = False - self.query_duration: Optional["Histogram"] = None - self.query_total: Optional["Counter"] = None - self.connection_health: Optional["Gauge"] = None - self.error_total: Optional["Counter"] = None - - try: - from prometheus_client import Counter, Gauge, Histogram - - self.query_duration = Histogram( - "cassandra_query_duration_seconds", - "Time spent executing Cassandra queries", - ["query_type", "success"], - ) - self.query_total = Counter( - "cassandra_queries_total", - "Total number of Cassandra queries", - ["query_type", "success"], - ) - self.connection_health = Gauge( - "cassandra_connection_healthy", "Whether Cassandra connection is healthy", ["host"] - ) - self.error_total = Counter( - "cassandra_errors_total", "Total number of Cassandra errors", ["error_type"] - ) - self._available = True - except ImportError: - logger.warning("prometheus_client not available, metrics disabled") - - async def record_query(self, metrics: QueryMetrics) -> None: - """Record query execution metrics to Prometheus.""" - if not self._available: - return - - query_type = "prepared" if "prepared" in metrics.query_hash else "simple" - success_label = "success" if metrics.success else "failure" - - if self.query_duration is not None: - self.query_duration.labels(query_type=query_type, success=success_label).observe( - metrics.duration - ) - - if self.query_total is not None: - self.query_total.labels(query_type=query_type, success=success_label).inc() - - if not metrics.success and metrics.error_type and self.error_total is not None: - self.error_total.labels(error_type=metrics.error_type).inc() - - async def record_connection_health(self, metrics: ConnectionMetrics) -> None: - """Record connection health to Prometheus.""" - if not self._available: - return - - if self.connection_health is not None: - self.connection_health.labels(host=metrics.host).set(1 if metrics.is_healthy else 0) - 
- async def get_stats(self) -> Dict[str, Any]: - """Get current Prometheus metrics.""" - if not self._available: - return {"error": "Prometheus client not available"} - - return {"message": "Metrics available via Prometheus endpoint"} - - -class MetricsMiddleware: - """Middleware to automatically collect metrics for async-cassandra operations.""" - - def __init__(self, collectors: List[MetricsCollector]): - self.collectors = collectors - self._enabled = True - - def enable(self) -> None: - """Enable metrics collection.""" - self._enabled = True - - def disable(self) -> None: - """Disable metrics collection.""" - self._enabled = False - - async def record_query_metrics( - self, - query: str, - duration: float, - success: bool, - error_type: Optional[str] = None, - parameters_count: int = 0, - result_size: int = 0, - ) -> None: - """Record metrics for a query execution.""" - if not self._enabled: - return - - # Create a hash of the query for grouping (remove parameter values) - query_hash = self._normalize_query(query) - - metrics = QueryMetrics( - query_hash=query_hash, - duration=duration, - success=success, - error_type=error_type, - parameters_count=parameters_count, - result_size=result_size, - ) - - # Send to all collectors - for collector in self.collectors: - try: - await collector.record_query(metrics) - except Exception as e: - logger.warning(f"Failed to record metrics: {e}") - - async def record_connection_metrics( - self, - host: str, - is_healthy: bool, - response_time: float, - error_count: int = 0, - total_queries: int = 0, - ) -> None: - """Record connection health metrics.""" - if not self._enabled: - return - - metrics = ConnectionMetrics( - host=host, - is_healthy=is_healthy, - last_check=datetime.now(timezone.utc), - response_time=response_time, - error_count=error_count, - total_queries=total_queries, - ) - - for collector in self.collectors: - try: - await collector.record_connection_health(metrics) - except Exception as e: - logger.warning(f"Failed to record connection metrics: {e}") - - def _normalize_query(self, query: str) -> str: - """Normalize query for grouping by removing parameter values.""" - import hashlib - import re - - # Remove extra whitespace and normalize - normalized = re.sub(r"\s+", " ", query.strip().upper()) - - # Replace parameter placeholders with generic markers - normalized = re.sub(r"\?", "?", normalized) - normalized = re.sub(r"'[^']*'", "'?'", normalized) # String literals - normalized = re.sub(r"\b\d+\b", "?", normalized) # Numbers - - # Create a hash for storage efficiency (not for security) - # Using MD5 here is fine as it's just for creating identifiers - return hashlib.md5(normalized.encode(), usedforsecurity=False).hexdigest()[:12] - - -# Factory function for easy setup -def create_metrics_system( - backend: str = "memory", prometheus_enabled: bool = False -) -> MetricsMiddleware: - """Create a metrics system with specified backend.""" - collectors: List[MetricsCollector] = [] - - if backend == "memory": - collectors.append(InMemoryMetricsCollector()) - - if prometheus_enabled: - collectors.append(PrometheusMetricsCollector()) - - return MetricsMiddleware(collectors) diff --git a/src/async_cassandra/monitoring.py b/src/async_cassandra/monitoring.py deleted file mode 100644 index 5034200..0000000 --- a/src/async_cassandra/monitoring.py +++ /dev/null @@ -1,348 +0,0 @@ -""" -Connection monitoring utilities for async-cassandra. 
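# Illustration only (not in the original patch): typical wiring for the metrics
# module deleted just above. Build the middleware via the factory, then record
# timings around query execution. Sketch only; timed_query is a hypothetical helper.
import time

from async_cassandra.metrics import create_metrics_system

metrics = create_metrics_system(backend="memory", prometheus_enabled=False)


async def timed_query(session, query: str):
    start = time.perf_counter()
    try:
        result = await session.execute(query)
        await metrics.record_query_metrics(query, time.perf_counter() - start, success=True)
        return result
    except Exception as exc:
        await metrics.record_query_metrics(
            query, time.perf_counter() - start, success=False, error_type=type(exc).__name__
        )
        raise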
- -This module provides tools to monitor connection health and performance metrics -for the async-cassandra wrapper. Since the Python driver maintains only one -connection per host, monitoring these connections is crucial. -""" - -import asyncio -import logging -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -from cassandra.cluster import Host -from cassandra.query import SimpleStatement - -from .session import AsyncCassandraSession - -logger = logging.getLogger(__name__) - - -# Host status constants -HOST_STATUS_UP = "up" -HOST_STATUS_DOWN = "down" -HOST_STATUS_UNKNOWN = "unknown" - - -@dataclass -class HostMetrics: - """Metrics for a single Cassandra host.""" - - address: str - datacenter: Optional[str] - rack: Optional[str] - status: str - release_version: Optional[str] - connection_count: int # Always 1 for protocol v3+ - latency_ms: Optional[float] = None - last_error: Optional[str] = None - last_check: Optional[datetime] = None - - -@dataclass -class ClusterMetrics: - """Metrics for the entire Cassandra cluster.""" - - timestamp: datetime - cluster_name: Optional[str] - protocol_version: int - hosts: List[HostMetrics] - total_connections: int - healthy_hosts: int - unhealthy_hosts: int - app_metrics: Dict[str, Any] = field(default_factory=dict) - - -class ConnectionMonitor: - """ - Monitor async-cassandra connection health and metrics. - - Since the Python driver maintains only one connection per host, - this monitor helps track the health and performance of these - critical connections. - """ - - def __init__(self, session: AsyncCassandraSession): - """ - Initialize the connection monitor. - - Args: - session: The async Cassandra session to monitor - """ - self.session = session - self.metrics: Dict[str, Any] = { - "requests_sent": 0, - "requests_completed": 0, - "requests_failed": 0, - "last_health_check": None, - "monitoring_started": datetime.now(timezone.utc), - } - self._monitoring_task: Optional[asyncio.Task[None]] = None - self._callbacks: List[Callable[[ClusterMetrics], Any]] = [] - - def add_callback(self, callback: Callable[[ClusterMetrics], Any]) -> None: - """ - Add a callback to be called when metrics are collected. - - Args: - callback: Function to call with cluster metrics - """ - self._callbacks.append(callback) - - async def check_host_health(self, host: Host) -> HostMetrics: - """ - Check the health of a specific host. 
- - Args: - host: The host to check - - Returns: - HostMetrics for the host - """ - metrics = HostMetrics( - address=str(host.address), - datacenter=host.datacenter, - rack=host.rack, - status=HOST_STATUS_UP if host.is_up else HOST_STATUS_DOWN, - release_version=host.release_version, - connection_count=1 if host.is_up else 0, - ) - - if host.is_up: - try: - # Test connection latency with a simple query - start = asyncio.get_event_loop().time() - - # Create a statement that routes to the specific host - statement = SimpleStatement( - "SELECT now() FROM system.local", - # Note: host parameter might not be directly supported, - # but we try to measure general latency - ) - - await self.session.execute(statement) - - metrics.latency_ms = (asyncio.get_event_loop().time() - start) * 1000 - metrics.last_check = datetime.now(timezone.utc) - - except Exception as e: - metrics.status = HOST_STATUS_UNKNOWN - metrics.last_error = str(e) - metrics.connection_count = 0 - logger.warning(f"Health check failed for host {host.address}: {e}") - - return metrics - - async def get_cluster_metrics(self) -> ClusterMetrics: - """ - Get comprehensive metrics for the entire cluster. - - Returns: - ClusterMetrics with current state - """ - cluster = self.session._session.cluster - - # Collect metrics for all hosts - host_metrics = [] - for host in cluster.metadata.all_hosts(): - host_metric = await self.check_host_health(host) - host_metrics.append(host_metric) - - # Calculate summary statistics - healthy_hosts = sum(1 for h in host_metrics if h.status == HOST_STATUS_UP) - unhealthy_hosts = sum(1 for h in host_metrics if h.status != HOST_STATUS_UP) - - return ClusterMetrics( - timestamp=datetime.now(timezone.utc), - cluster_name=cluster.metadata.cluster_name, - protocol_version=cluster.protocol_version, - hosts=host_metrics, - total_connections=sum(h.connection_count for h in host_metrics), - healthy_hosts=healthy_hosts, - unhealthy_hosts=unhealthy_hosts, - app_metrics=self.metrics.copy(), - ) - - async def warmup_connections(self) -> None: - """ - Pre-establish connections to all nodes. - - This is useful to avoid cold start latency on first queries. - """ - logger.info("Warming up connections to all nodes...") - - cluster = self.session._session.cluster - successful = 0 - failed = 0 - - for host in cluster.metadata.all_hosts(): - if host.is_up: - try: - # Execute a lightweight query to establish connection - statement = SimpleStatement("SELECT now() FROM system.local") - await self.session.execute(statement) - successful += 1 - logger.debug(f"Warmed up connection to {host.address}") - except Exception as e: - failed += 1 - logger.warning(f"Failed to warm up connection to {host.address}: {e}") - - logger.info(f"Connection warmup complete: {successful} successful, {failed} failed") - - async def start_monitoring(self, interval: int = 60) -> None: - """ - Start continuous monitoring. 
- - Args: - interval: Seconds between health checks - """ - if self._monitoring_task and not self._monitoring_task.done(): - logger.warning("Monitoring already running") - return - - self._monitoring_task = asyncio.create_task(self._monitoring_loop(interval)) - logger.info(f"Started connection monitoring with {interval}s interval") - - async def stop_monitoring(self) -> None: - """Stop continuous monitoring.""" - if self._monitoring_task: - self._monitoring_task.cancel() - try: - await self._monitoring_task - except asyncio.CancelledError: - pass - logger.info("Stopped connection monitoring") - - async def _monitoring_loop(self, interval: int) -> None: - """Internal monitoring loop.""" - while True: - try: - metrics = await self.get_cluster_metrics() - self.metrics["last_health_check"] = metrics.timestamp.isoformat() - - # Log summary - logger.info( - f"Cluster health: {metrics.healthy_hosts} healthy, " - f"{metrics.unhealthy_hosts} unhealthy hosts" - ) - - # Alert on issues - if metrics.unhealthy_hosts > 0: - logger.warning(f"ALERT: {metrics.unhealthy_hosts} hosts are unhealthy") - - # Call registered callbacks - for callback in self._callbacks: - try: - result = callback(metrics) - if asyncio.iscoroutine(result): - await result - except Exception as e: - logger.error(f"Callback error: {e}") - - await asyncio.sleep(interval) - - except asyncio.CancelledError: - raise - except Exception as e: - logger.error(f"Monitoring error: {e}") - await asyncio.sleep(interval) - - def get_connection_summary(self) -> Dict[str, Any]: - """ - Get a summary of connection status. - - Returns: - Dictionary with connection summary - """ - cluster = self.session._session.cluster - hosts = list(cluster.metadata.all_hosts()) - - return { - "total_hosts": len(hosts), - "up_hosts": sum(1 for h in hosts if h.is_up), - "down_hosts": sum(1 for h in hosts if not h.is_up), - "protocol_version": cluster.protocol_version, - "max_requests_per_connection": 32768 if cluster.protocol_version >= 3 else 128, - "note": "Python driver maintains 1 connection per host (protocol v3+)", - } - - -class RateLimitedSession: - """ - Rate-limited wrapper for AsyncCassandraSession. - - Since the Python driver is limited to one connection per host, - this wrapper helps prevent overwhelming those connections. - """ - - def __init__(self, session: AsyncCassandraSession, max_concurrent: int = 1000): - """ - Initialize rate-limited session. 
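# Illustration only (not in the original patch): direct use of the rate-limited
# wrapper defined here, with signatures taken from the deleted monitoring.py.
from async_cassandra.monitoring import RateLimitedSession


async def demo(session):
    limited = RateLimitedSession(session, max_concurrent=200)  # default is 1000
    await limited.execute("SELECT now() FROM system.local")
    print(limited.get_metrics())  # total/active/rejected request counters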
- - Args: - session: The async session to wrap - max_concurrent: Maximum concurrent requests - """ - self.session = session - self.semaphore = asyncio.Semaphore(max_concurrent) - self.metrics = {"total_requests": 0, "active_requests": 0, "rejected_requests": 0} - - async def execute(self, query: Any, parameters: Any = None, **kwargs: Any) -> Any: - """Execute a query with rate limiting.""" - async with self.semaphore: - self.metrics["total_requests"] += 1 - self.metrics["active_requests"] += 1 - try: - result = await self.session.execute(query, parameters, **kwargs) - return result - finally: - self.metrics["active_requests"] -= 1 - - async def prepare(self, query: str) -> Any: - """Prepare a statement (not rate limited).""" - return await self.session.prepare(query) - - def get_metrics(self) -> Dict[str, int]: - """Get rate limiting metrics.""" - return self.metrics.copy() - - -async def create_monitored_session( - contact_points: List[str], - keyspace: Optional[str] = None, - max_concurrent: Optional[int] = None, - warmup: bool = True, -) -> Tuple[Union[RateLimitedSession, AsyncCassandraSession], ConnectionMonitor]: - """ - Create a monitored and optionally rate-limited session. - - Args: - contact_points: Cassandra contact points - keyspace: Optional keyspace to use - max_concurrent: Optional max concurrent requests - warmup: Whether to warm up connections - - Returns: - Tuple of (rate_limited_session, monitor) - """ - from .cluster import AsyncCluster - - # Create cluster and session - cluster = AsyncCluster(contact_points=contact_points) - session = await cluster.connect(keyspace) - - # Create monitor - monitor = ConnectionMonitor(session) - - # Warm up connections if requested - if warmup: - await monitor.warmup_connections() - - # Create rate-limited wrapper if requested - if max_concurrent: - rate_limited = RateLimitedSession(session, max_concurrent) - return rate_limited, monitor - else: - return session, monitor diff --git a/src/async_cassandra/py.typed b/src/async_cassandra/py.typed deleted file mode 100644 index e69de29..0000000 diff --git a/src/async_cassandra/result.py b/src/async_cassandra/result.py deleted file mode 100644 index a9e6fb0..0000000 --- a/src/async_cassandra/result.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -Simplified async result handling for Cassandra queries. - -This implementation focuses on essential functionality without -complex state tracking. -""" - -import asyncio -import threading -from typing import Any, AsyncIterator, List, Optional - -from cassandra.cluster import ResponseFuture - - -class AsyncResultHandler: - """ - Simplified handler for asynchronous results from Cassandra queries. - - This class wraps ResponseFuture callbacks in asyncio Futures, - providing async/await support with minimal complexity. 
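# A minimal sketch of the create_monitored_session() helper defined above, assuming it is
# importable from the same monitoring module as ConnectionMonitor (assumed path) and that a
# keyspace named "demo" exists. With max_concurrent set, the returned session is the
# RateLimitedSession wrapper, so get_metrics() is available.
import asyncio

from async_cassandra.monitoring import create_monitored_session  # assumed module path

async def main() -> None:
    session, monitor = await create_monitored_session(
        ["127.0.0.1"], keyspace="demo", max_concurrent=500
    )
    await monitor.start_monitoring(interval=30)
    result = await session.execute("SELECT release_version FROM system.local")
    print(result.one(), session.get_metrics())
    await monitor.stop_monitoring()

asyncio.run(main())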
- """ - - def __init__(self, response_future: ResponseFuture): - self.response_future = response_future - self.rows: List[Any] = [] - self._future: Optional[asyncio.Future[AsyncResultSet]] = None - # Thread lock is necessary since callbacks come from driver threads - self._lock = threading.Lock() - # Store early results/errors if callbacks fire before get_result - self._early_result: Optional[AsyncResultSet] = None - self._early_error: Optional[Exception] = None - - # Set up callbacks - self.response_future.add_callbacks(callback=self._handle_page, errback=self._handle_error) - - def _cleanup_callbacks(self) -> None: - """Clean up response future callbacks to prevent memory leaks.""" - try: - # Clear callbacks if the method exists - if hasattr(self.response_future, "clear_callbacks"): - self.response_future.clear_callbacks() - except Exception: - # Ignore errors during cleanup - pass - - def _handle_page(self, rows: List[Any]) -> None: - """Handle successful page retrieval. - - This method is called from driver threads, so we need thread safety. - """ - with self._lock: - if rows is not None: - # Create a defensive copy to avoid cross-thread data issues - self.rows.extend(list(rows)) - - if self.response_future.has_more_pages: - self.response_future.start_fetching_next_page() - else: - # All pages fetched - # Create a copy of rows to avoid reference issues - final_result = AsyncResultSet(list(self.rows), self.response_future) - - if self._future and not self._future.done(): - loop = getattr(self, "_loop", None) - if loop: - loop.call_soon_threadsafe(self._future.set_result, final_result) - else: - # Store for later if future doesn't exist yet - self._early_result = final_result - - # Clean up callbacks after completion - self._cleanup_callbacks() - - def _handle_error(self, exc: Exception) -> None: - """Handle query execution error.""" - with self._lock: - if self._future and not self._future.done(): - loop = getattr(self, "_loop", None) - if loop: - loop.call_soon_threadsafe(self._future.set_exception, exc) - else: - # Store for later if future doesn't exist yet - self._early_error = exc - - # Clean up callbacks to prevent memory leaks - self._cleanup_callbacks() - - async def get_result(self, timeout: Optional[float] = None) -> "AsyncResultSet": - """ - Wait for the query to complete and return the result. - - Args: - timeout: Optional timeout in seconds. - - Returns: - AsyncResultSet containing all rows from the query. - - Raises: - asyncio.TimeoutError: If the query doesn't complete within the timeout. 
- """ - # Create future in the current event loop - loop = asyncio.get_running_loop() - self._future = loop.create_future() - self._loop = loop # Store loop for callbacks - - # Check if result/error is already available (callback might have fired early) - with self._lock: - if self._early_error: - self._future.set_exception(self._early_error) - elif self._early_result: - self._future.set_result(self._early_result) - # Remove the early check for empty results - let callbacks handle it - - # Use query timeout if no explicit timeout provided - if ( - timeout is None - and hasattr(self.response_future, "timeout") - and self.response_future.timeout is not None - ): - timeout = self.response_future.timeout - - try: - if timeout is not None: - return await asyncio.wait_for(self._future, timeout=timeout) - else: - return await self._future - except asyncio.TimeoutError: - # Clean up on timeout - self._cleanup_callbacks() - raise - except Exception: - # Clean up on any error - self._cleanup_callbacks() - raise - - -class AsyncResultSet: - """ - Async wrapper for Cassandra query results. - - Provides async iteration over result rows and metadata access. - """ - - def __init__(self, rows: List[Any], response_future: Any = None): - self._rows = rows - self._index = 0 - self._response_future = response_future - - def __aiter__(self) -> AsyncIterator[Any]: - """Return async iterator for the result set.""" - self._index = 0 # Reset index for each iteration - return self - - async def __anext__(self) -> Any: - """Get next row from the result set.""" - if self._index >= len(self._rows): - raise StopAsyncIteration - - row = self._rows[self._index] - self._index += 1 - return row - - def __len__(self) -> int: - """Return number of rows in the result set.""" - return len(self._rows) - - def __getitem__(self, index: int) -> Any: - """Get row by index.""" - return self._rows[index] - - @property - def rows(self) -> List[Any]: - """Get all rows as a list.""" - return self._rows - - def one(self) -> Optional[Any]: - """ - Get the first row or None if empty. - - Returns: - First row from the result set or None. - """ - return self._rows[0] if self._rows else None - - def all(self) -> List[Any]: - """ - Get all rows. - - Returns: - List of all rows in the result set. - """ - return self._rows - - def get_query_trace(self) -> Any: - """ - Get the query trace if available. - - Returns: - Query trace object or None if tracing wasn't enabled. - """ - if self._response_future and hasattr(self._response_future, "get_query_trace"): - return self._response_future.get_query_trace() - return None diff --git a/src/async_cassandra/retry_policy.py b/src/async_cassandra/retry_policy.py deleted file mode 100644 index 65c3f7c..0000000 --- a/src/async_cassandra/retry_policy.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Async-aware retry policies for Cassandra operations. -""" - -from typing import Optional, Tuple, Union - -from cassandra.policies import RetryPolicy, WriteType -from cassandra.query import BatchStatement, ConsistencyLevel, PreparedStatement, SimpleStatement - - -class AsyncRetryPolicy(RetryPolicy): - """ - Retry policy for async Cassandra operations. - - This extends the base RetryPolicy with async-aware retry logic - and configurable retry limits. - """ - - def __init__(self, max_retries: int = 3): - """ - Initialize the retry policy. - - Args: - max_retries: Maximum number of retry attempts. 
- """ - super().__init__() - self.max_retries = max_retries - - def on_read_timeout( - self, - query: Union[SimpleStatement, PreparedStatement, BatchStatement], - consistency: ConsistencyLevel, - required_responses: int, - received_responses: int, - data_retrieved: bool, - retry_num: int, - ) -> Tuple[int, Optional[ConsistencyLevel]]: - """ - Handle read timeout. - - Args: - query: The query statement that timed out. - consistency: The consistency level of the query. - required_responses: Number of responses required by consistency level. - received_responses: Number of responses received before timeout. - data_retrieved: Whether any data was retrieved. - retry_num: Current retry attempt number. - - Returns: - Tuple of (retry decision, consistency level to use). - """ - if retry_num >= self.max_retries: - return self.RETHROW, None - - # If we got some data, retry might succeed - if data_retrieved: - return self.RETRY, consistency - - # If we got enough responses, retry at same consistency - if received_responses >= required_responses: - return self.RETRY, consistency - - # Otherwise, rethrow - return self.RETHROW, None - - def on_write_timeout( - self, - query: Union[SimpleStatement, PreparedStatement, BatchStatement], - consistency: ConsistencyLevel, - write_type: str, - required_responses: int, - received_responses: int, - retry_num: int, - ) -> Tuple[int, Optional[ConsistencyLevel]]: - """ - Handle write timeout. - - Args: - query: The query statement that timed out. - consistency: The consistency level of the query. - write_type: Type of write operation. - required_responses: Number of responses required by consistency level. - received_responses: Number of responses received before timeout. - retry_num: Current retry attempt number. - - Returns: - Tuple of (retry decision, consistency level to use). - """ - if retry_num >= self.max_retries: - return self.RETHROW, None - - # CRITICAL: Only retry write operations if they are explicitly marked as idempotent - # Non-idempotent writes should NEVER be retried as they could cause: - # - Duplicate inserts - # - Multiple increments/decrements - # - Data corruption - - # Check if query has is_idempotent attribute and if it's exactly True - # Only retry if is_idempotent is explicitly True (not truthy values) - if getattr(query, "is_idempotent", None) is not True: - # Query is not idempotent or not explicitly marked as True - do not retry - return self.RETHROW, None - - # Only retry simple and batch writes (including UNLOGGED_BATCH) that are explicitly idempotent - if write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.UNLOGGED_BATCH): - return self.RETRY, consistency - - return self.RETHROW, None - - def on_unavailable( - self, - query: Union[SimpleStatement, PreparedStatement, BatchStatement], - consistency: ConsistencyLevel, - required_replicas: int, - alive_replicas: int, - retry_num: int, - ) -> Tuple[int, Optional[ConsistencyLevel]]: - """ - Handle unavailable exception. - - Args: - query: The query that failed. - consistency: The consistency level of the query. - required_replicas: Number of replicas required by consistency level. - alive_replicas: Number of replicas that are alive. - retry_num: Current retry attempt number. - - Returns: - Tuple of (retry decision, consistency level to use). 
- """ - if retry_num >= self.max_retries: - return self.RETHROW, None - - # Try next host on first retry - if retry_num == 0: - return self.RETRY_NEXT_HOST, consistency - - # Retry with same consistency - return self.RETRY, consistency - - def on_request_error( - self, - query: Union[SimpleStatement, PreparedStatement, BatchStatement], - consistency: ConsistencyLevel, - error: Exception, - retry_num: int, - ) -> Tuple[int, Optional[ConsistencyLevel]]: - """ - Handle request error. - - Args: - query: The query that failed. - consistency: The consistency level of the query. - error: The error that occurred. - retry_num: Current retry attempt number. - - Returns: - Tuple of (retry decision, consistency level to use). - """ - if retry_num >= self.max_retries: - return self.RETHROW, None - - # Try next host for connection errors - return self.RETRY_NEXT_HOST, consistency diff --git a/src/async_cassandra/session.py b/src/async_cassandra/session.py deleted file mode 100644 index 378b56e..0000000 --- a/src/async_cassandra/session.py +++ /dev/null @@ -1,454 +0,0 @@ -""" -Simplified async session management for Cassandra connections. - -This implementation focuses on being a thin wrapper around the driver, -avoiding complex locking and state management. -""" - -import asyncio -import logging -import time -from typing import Any, Dict, Optional - -from cassandra.cluster import _NOT_SET, EXEC_PROFILE_DEFAULT, Cluster, Session -from cassandra.query import BatchStatement, PreparedStatement, SimpleStatement - -from .base import AsyncContextManageable -from .exceptions import ConnectionError, QueryError -from .metrics import MetricsMiddleware -from .result import AsyncResultHandler, AsyncResultSet -from .streaming import AsyncStreamingResultSet, StreamingResultHandler - -logger = logging.getLogger(__name__) - - -class AsyncCassandraSession(AsyncContextManageable): - """ - Simplified async wrapper for Cassandra Session. - - This implementation: - - Uses a single lock only for close operations - - Accepts that operations might fail if close() is called concurrently - - Focuses on being a thin wrapper without complex state management - """ - - def __init__(self, session: Session, metrics: Optional[MetricsMiddleware] = None): - """ - Initialize async session wrapper. - - Args: - session: The underlying Cassandra session. - metrics: Optional metrics middleware for observability. - """ - self._session = session - self._metrics = metrics - self._closed = False - self._close_lock = asyncio.Lock() - - def _record_metrics_async( - self, - query_str: str, - duration: float, - success: bool, - error_type: Optional[str], - parameters_count: int, - result_size: int, - ) -> None: - """ - Record metrics in a fire-and-forget manner. - - This method creates a background task to record metrics without blocking - the main execution flow or preventing exception propagation. 
- """ - if not self._metrics: - return - - async def _record() -> None: - try: - assert self._metrics is not None # Type guard for mypy - await self._metrics.record_query_metrics( - query=query_str, - duration=duration, - success=success, - error_type=error_type, - parameters_count=parameters_count, - result_size=result_size, - ) - except Exception as e: - # Log error but don't propagate - metrics should not break queries - logger.warning(f"Failed to record metrics: {e}") - - # Create task without awaiting it - try: - asyncio.create_task(_record()) - except RuntimeError: - # No event loop running, skip metrics - pass - - @classmethod - async def create( - cls, cluster: Cluster, keyspace: Optional[str] = None - ) -> "AsyncCassandraSession": - """ - Create a new async session. - - Args: - cluster: The Cassandra cluster to connect to. - keyspace: Optional keyspace to use. - - Returns: - New AsyncCassandraSession instance. - """ - loop = asyncio.get_event_loop() - - # Connect in executor to avoid blocking - session = await loop.run_in_executor( - None, lambda: cluster.connect(keyspace) if keyspace else cluster.connect() - ) - - return cls(session) - - async def execute( - self, - query: Any, - parameters: Any = None, - trace: bool = False, - custom_payload: Any = None, - timeout: Any = None, - execution_profile: Any = EXEC_PROFILE_DEFAULT, - paging_state: Any = None, - host: Any = None, - execute_as: Any = None, - ) -> AsyncResultSet: - """ - Execute a CQL query asynchronously. - - Args: - query: The query to execute. - parameters: Query parameters. - trace: Whether to enable query tracing. - custom_payload: Custom payload to send with the request. - timeout: Query timeout in seconds or _NOT_SET. - execution_profile: Execution profile name or object to use. - paging_state: Paging state for resuming paged queries. - host: Specific host to execute query on. - execute_as: User to execute the query as. - - Returns: - AsyncResultSet containing query results. - - Raises: - QueryError: If query execution fails. - ConnectionError: If session is closed. 
- """ - # Simple closed check - no lock needed for read - if self._closed: - raise ConnectionError("Session is closed") - - # Start metrics timing - start_time = time.perf_counter() - success = False - error_type = None - result_size = 0 - - try: - # Fix timeout handling - use _NOT_SET if timeout is None - response_future = self._session.execute_async( - query, - parameters, - trace, - custom_payload, - timeout if timeout is not None else _NOT_SET, - execution_profile, - paging_state, - host, - execute_as, - ) - - handler = AsyncResultHandler(response_future) - # Pass timeout to get_result if specified - query_timeout = timeout if timeout is not None and timeout != _NOT_SET else None - result = await handler.get_result(timeout=query_timeout) - - success = True - result_size = len(result.rows) if hasattr(result, "rows") else 0 - return result - - except Exception as e: - error_type = type(e).__name__ - # Check if this is a Cassandra driver exception by looking at its module - if ( - hasattr(e, "__module__") - and (e.__module__ == "cassandra" or e.__module__.startswith("cassandra.")) - or isinstance(e, asyncio.TimeoutError) - ): - # Pass through all Cassandra driver exceptions and asyncio.TimeoutError - raise - else: - # Only wrap unexpected exceptions - raise QueryError(f"Query execution failed: {str(e)}", cause=e) from e - finally: - # Record metrics in a fire-and-forget manner - duration = time.perf_counter() - start_time - query_str = ( - str(query) if isinstance(query, (SimpleStatement, PreparedStatement)) else query - ) - params_count = len(parameters) if parameters else 0 - - self._record_metrics_async( - query_str=query_str, - duration=duration, - success=success, - error_type=error_type, - parameters_count=params_count, - result_size=result_size, - ) - - async def execute_stream( - self, - query: Any, - parameters: Any = None, - stream_config: Any = None, - trace: bool = False, - custom_payload: Any = None, - timeout: Any = None, - execution_profile: Any = EXEC_PROFILE_DEFAULT, - paging_state: Any = None, - host: Any = None, - execute_as: Any = None, - ) -> AsyncStreamingResultSet: - """ - Execute a CQL query with streaming support for large result sets. - - This method is memory-efficient for queries that return many rows, - as it fetches results page by page instead of loading everything - into memory at once. - - Args: - query: The query to execute. - parameters: Query parameters. - stream_config: Configuration for streaming (fetch size, callbacks, etc.) - trace: Whether to enable query tracing. - custom_payload: Custom payload to send with the request. - timeout: Query timeout in seconds or _NOT_SET. - execution_profile: Execution profile name or object to use. - paging_state: Paging state for resuming paged queries. - host: Specific host to execute query on. - execute_as: User to execute the query as. - - Returns: - AsyncStreamingResultSet for memory-efficient iteration. - - Raises: - QueryError: If query execution fails. - ConnectionError: If session is closed. 
- """ - # Simple closed check - no lock needed for read - if self._closed: - raise ConnectionError("Session is closed") - - # Start metrics timing for consistency with execute() - start_time = time.perf_counter() - success = False - error_type = None - - try: - # Apply fetch_size from stream_config if provided - query_to_execute = query - if stream_config and hasattr(stream_config, "fetch_size"): - # If query is a string, create a SimpleStatement with fetch_size - if isinstance(query_to_execute, str): - from cassandra.query import SimpleStatement - - query_to_execute = SimpleStatement( - query_to_execute, fetch_size=stream_config.fetch_size - ) - # If it's already a statement, try to set fetch_size - elif hasattr(query_to_execute, "fetch_size"): - query_to_execute.fetch_size = stream_config.fetch_size - - response_future = self._session.execute_async( - query_to_execute, - parameters, - trace, - custom_payload, - timeout if timeout is not None else _NOT_SET, - execution_profile, - paging_state, - host, - execute_as, - ) - - handler = StreamingResultHandler(response_future, stream_config) - result = await handler.get_streaming_result() - success = True - return result - - except Exception as e: - error_type = type(e).__name__ - # Check if this is a Cassandra driver exception by looking at its module - if ( - hasattr(e, "__module__") - and (e.__module__ == "cassandra" or e.__module__.startswith("cassandra.")) - or isinstance(e, asyncio.TimeoutError) - ): - # Pass through all Cassandra driver exceptions and asyncio.TimeoutError - raise - else: - # Only wrap unexpected exceptions - raise QueryError(f"Streaming query execution failed: {str(e)}", cause=e) from e - finally: - # Record metrics in a fire-and-forget manner - duration = time.perf_counter() - start_time - # Import here to avoid circular imports - from cassandra.query import PreparedStatement, SimpleStatement - - query_str = ( - str(query) if isinstance(query, (SimpleStatement, PreparedStatement)) else query - ) - params_count = len(parameters) if parameters else 0 - - self._record_metrics_async( - query_str=query_str, - duration=duration, - success=success, - error_type=error_type, - parameters_count=params_count, - result_size=0, # Streaming doesn't know size upfront - ) - - async def execute_batch( - self, - batch_statement: BatchStatement, - trace: bool = False, - custom_payload: Optional[Dict[str, bytes]] = None, - timeout: Any = None, - execution_profile: Any = EXEC_PROFILE_DEFAULT, - ) -> AsyncResultSet: - """ - Execute a batch statement asynchronously. - - Args: - batch_statement: The batch statement to execute. - trace: Whether to enable query tracing. - custom_payload: Custom payload to send with the request. - timeout: Query timeout in seconds. - execution_profile: Execution profile to use. - - Returns: - AsyncResultSet (usually empty for batch operations). - - Raises: - QueryError: If batch execution fails. - ConnectionError: If session is closed. - """ - return await self.execute( - batch_statement, - trace=trace, - custom_payload=custom_payload, - timeout=timeout if timeout is not None else _NOT_SET, - execution_profile=execution_profile, - ) - - async def prepare( - self, query: str, custom_payload: Any = None, timeout: Optional[float] = None - ) -> PreparedStatement: - """ - Prepare a CQL statement asynchronously. - - Args: - query: The query to prepare. - custom_payload: Custom payload to send with the request. - timeout: Timeout in seconds. Defaults to DEFAULT_REQUEST_TIMEOUT. 
- - Returns: - PreparedStatement that can be executed multiple times. - - Raises: - QueryError: If statement preparation fails. - asyncio.TimeoutError: If preparation times out. - ConnectionError: If session is closed. - """ - # Simple closed check - no lock needed for read - if self._closed: - raise ConnectionError("Session is closed") - - # Import here to avoid circular import - from .constants import DEFAULT_REQUEST_TIMEOUT - - if timeout is None: - timeout = DEFAULT_REQUEST_TIMEOUT - - try: - loop = asyncio.get_event_loop() - - # Prepare in executor to avoid blocking with timeout - prepared = await asyncio.wait_for( - loop.run_in_executor(None, lambda: self._session.prepare(query, custom_payload)), - timeout=timeout, - ) - - return prepared - except Exception as e: - # Check if this is a Cassandra driver exception by looking at its module - if ( - hasattr(e, "__module__") - and (e.__module__ == "cassandra" or e.__module__.startswith("cassandra.")) - or isinstance(e, asyncio.TimeoutError) - ): - # Pass through all Cassandra driver exceptions and asyncio.TimeoutError - raise - else: - # Only wrap unexpected exceptions - raise QueryError(f"Statement preparation failed: {str(e)}", cause=e) from e - - async def close(self) -> None: - """ - Close the session and release resources. - - This method is idempotent and can be called multiple times safely. - Uses a single lock to ensure shutdown is called only once. - """ - async with self._close_lock: - if not self._closed: - self._closed = True - loop = asyncio.get_event_loop() - # Use a reasonable timeout for shutdown operations - await asyncio.wait_for( - loop.run_in_executor(None, self._session.shutdown), timeout=30.0 - ) - # Give the driver's internal threads time to finish - # This helps prevent "cannot schedule new futures after shutdown" errors - await asyncio.sleep(5.0) - - @property - def is_closed(self) -> bool: - """Check if the session is closed.""" - return self._closed - - @property - def keyspace(self) -> Optional[str]: - """Get current keyspace.""" - keyspace = self._session.keyspace - return keyspace if isinstance(keyspace, str) else None - - async def set_keyspace(self, keyspace: str) -> None: - """ - Set the current keyspace. - - Args: - keyspace: The keyspace to use. - - Raises: - QueryError: If setting keyspace fails. - ValueError: If keyspace name is invalid. - ConnectionError: If session is closed. - """ - # Validate keyspace name to prevent injection attacks - if not keyspace or not all(c.isalnum() or c == "_" for c in keyspace): - raise ValueError( - f"Invalid keyspace name: '{keyspace}'. " - "Keyspace names must contain only alphanumeric characters and underscores." - ) - - await self.execute(f"USE {keyspace}") diff --git a/src/async_cassandra/streaming.py b/src/async_cassandra/streaming.py deleted file mode 100644 index eb28d98..0000000 --- a/src/async_cassandra/streaming.py +++ /dev/null @@ -1,336 +0,0 @@ -""" -Simplified streaming support for large result sets in async-cassandra. - -This implementation focuses on essential streaming functionality -without complex state tracking. 
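# A short sketch combining prepare(), execute_batch() and set_keyspace() from above;
# `session` is assumed to be an AsyncCassandraSession and the table/keyspace names are placeholders.
from cassandra.query import BatchStatement

async def insert_users(session) -> None:
    insert = await session.prepare("INSERT INTO users (id, name) VALUES (?, ?)")
    batch = BatchStatement()
    for user_id, name in [(1, "alice"), (2, "bob")]:
        batch.add(insert, (user_id, name))
    await session.execute_batch(batch)
    await session.set_keyspace("analytics")  # name is validated as alphanumeric/underscore above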
-""" - -import asyncio -import logging -import threading -from dataclasses import dataclass -from typing import Any, AsyncIterator, Callable, List, Optional - -from cassandra.cluster import ResponseFuture -from cassandra.query import ConsistencyLevel, SimpleStatement - -logger = logging.getLogger(__name__) - - -@dataclass -class StreamConfig: - """Configuration for streaming results.""" - - fetch_size: int = 1000 # Number of rows per page - max_pages: Optional[int] = None # Limit number of pages (None = no limit) - page_callback: Optional[Callable[[int, int], None]] = None # Progress callback - timeout_seconds: Optional[float] = None # Timeout for the entire streaming operation - - -class AsyncStreamingResultSet: - """ - Simplified streaming result set that fetches pages on demand. - - This class provides memory-efficient iteration over large result sets - by fetching pages as needed rather than loading all results at once. - """ - - def __init__(self, response_future: ResponseFuture, config: Optional[StreamConfig] = None): - """ - Initialize streaming result set. - - Args: - response_future: The Cassandra response future - config: Streaming configuration - """ - self.response_future = response_future - self.config = config or StreamConfig() - - self._current_page: List[Any] = [] - self._current_index = 0 - self._page_number = 0 - self._total_rows = 0 - self._exhausted = False - self._error: Optional[Exception] = None - self._closed = False - - # Thread lock for thread-safe operations (necessary for driver callbacks) - self._lock = threading.Lock() - - # Event to signal when a page is ready - self._page_ready: Optional[asyncio.Event] = None - self._loop: Optional[asyncio.AbstractEventLoop] = None - - # Start fetching the first page - self._setup_callbacks() - - def _cleanup_callbacks(self) -> None: - """Clean up response future callbacks to prevent memory leaks.""" - try: - # Clear callbacks if the method exists - if hasattr(self.response_future, "clear_callbacks"): - self.response_future.clear_callbacks() - except Exception: - # Ignore errors during cleanup - pass - - def __del__(self) -> None: - """Ensure callbacks are cleaned up when object is garbage collected.""" - # Clean up callbacks to break circular references - self._cleanup_callbacks() - - def _setup_callbacks(self) -> None: - """Set up callbacks for the current page.""" - self.response_future.add_callbacks(callback=self._handle_page, errback=self._handle_error) - - # Check if the response_future already has an error - # This can happen with very short timeouts - if ( - hasattr(self.response_future, "_final_exception") - and self.response_future._final_exception - ): - self._handle_error(self.response_future._final_exception) - - def _handle_page(self, rows: Optional[List[Any]]) -> None: - """Handle successful page retrieval. - - This method is called from driver threads, so we need thread safety. 
- """ - with self._lock: - if rows is not None: - # Replace the current page (don't accumulate) - self._current_page = list(rows) # Defensive copy - self._current_index = 0 - self._page_number += 1 - self._total_rows += len(rows) - - # Check if we've reached the page limit - if self.config.max_pages and self._page_number >= self.config.max_pages: - self._exhausted = True - else: - self._current_page = [] - self._exhausted = True - - # Call progress callback if configured - if self.config.page_callback: - try: - self.config.page_callback(self._page_number, len(rows) if rows else 0) - except Exception as e: - logger.warning(f"Page callback error: {e}") - - # Signal that the page is ready - if self._page_ready and self._loop: - self._loop.call_soon_threadsafe(self._page_ready.set) - - def _handle_error(self, exc: Exception) -> None: - """Handle query execution error.""" - with self._lock: - self._error = exc - self._exhausted = True - # Clear current page to prevent memory leak - self._current_page = [] - self._current_index = 0 - - if self._page_ready and self._loop: - self._loop.call_soon_threadsafe(self._page_ready.set) - - # Clean up callbacks to prevent memory leaks - self._cleanup_callbacks() - - async def _fetch_next_page(self) -> bool: - """ - Fetch the next page of results. - - Returns: - True if a page was fetched, False if no more pages. - """ - if self._exhausted: - return False - - if not self.response_future.has_more_pages: - self._exhausted = True - return False - - # Initialize event if needed - if self._page_ready is None: - self._page_ready = asyncio.Event() - self._loop = asyncio.get_running_loop() - - # Clear the event before fetching - self._page_ready.clear() - - # Start fetching the next page - self.response_future.start_fetching_next_page() - - # Wait for the page to be ready - if self.config.timeout_seconds: - await asyncio.wait_for(self._page_ready.wait(), timeout=self.config.timeout_seconds) - else: - await self._page_ready.wait() - - # Check for errors - if self._error: - raise self._error - - return len(self._current_page) > 0 - - def __aiter__(self) -> AsyncIterator[Any]: - """Return async iterator for streaming results.""" - return self - - async def __anext__(self) -> Any: - """Get next row from the streaming result set.""" - # Initialize event if needed - if self._page_ready is None: - self._page_ready = asyncio.Event() - self._loop = asyncio.get_running_loop() - - # Wait for first page if needed - if self._page_number == 0 and not self._current_page: - # Use timeout from config if available - if self.config.timeout_seconds: - await asyncio.wait_for(self._page_ready.wait(), timeout=self.config.timeout_seconds) - else: - await self._page_ready.wait() - - # Check for errors first - if self._error: - raise self._error - - # If we have rows in the current page, return one - if self._current_index < len(self._current_page): - row = self._current_page[self._current_index] - self._current_index += 1 - return row - - # If current page is exhausted, try to fetch next page - if not self._exhausted and await self._fetch_next_page(): - # Recursively call to get the first row from new page - return await self.__anext__() - - # No more rows - raise StopAsyncIteration - - async def pages(self) -> AsyncIterator[List[Any]]: - """ - Iterate over pages instead of individual rows. - - Yields: - Lists of row objects (pages). 
- """ - # Initialize event if needed - if self._page_ready is None: - self._page_ready = asyncio.Event() - self._loop = asyncio.get_running_loop() - - # Wait for first page if needed - if self._page_number == 0 and not self._current_page: - if self.config.timeout_seconds: - await asyncio.wait_for(self._page_ready.wait(), timeout=self.config.timeout_seconds) - else: - await self._page_ready.wait() - - # Yield the current page if it has data - if self._current_page: - yield self._current_page - - # Fetch and yield subsequent pages - while await self._fetch_next_page(): - if self._current_page: - yield self._current_page - - @property - def page_number(self) -> int: - """Get the current page number.""" - return self._page_number - - @property - def total_rows_fetched(self) -> int: - """Get the total number of rows fetched so far.""" - return self._total_rows - - async def cancel(self) -> None: - """Cancel the streaming operation.""" - self._exhausted = True - self._cleanup_callbacks() - - async def __aenter__(self) -> "AsyncStreamingResultSet": - """Enter async context manager.""" - return self - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - """Exit async context manager and clean up resources.""" - await self.close() - - async def close(self) -> None: - """Close the streaming result set and clean up resources.""" - if self._closed: - return - - self._closed = True - self._exhausted = True - - # Clean up callbacks - self._cleanup_callbacks() - - # Clear current page to free memory - with self._lock: - self._current_page = [] - self._current_index = 0 - - # Signal any waiters - if self._page_ready is not None: - self._page_ready.set() - - -class StreamingResultHandler: - """ - Handler for creating streaming result sets. - - This is an alternative to AsyncResultHandler that doesn't - load all results into memory. - """ - - def __init__(self, response_future: ResponseFuture, config: Optional[StreamConfig] = None): - """ - Initialize streaming result handler. - - Args: - response_future: The Cassandra response future - config: Streaming configuration - """ - self.response_future = response_future - self.config = config or StreamConfig() - - async def get_streaming_result(self) -> AsyncStreamingResultSet: - """ - Get the streaming result set. - - Returns: - AsyncStreamingResultSet for efficient iteration. - """ - # Simply create and return the streaming result set - # It will handle its own callbacks - return AsyncStreamingResultSet(self.response_future, self.config) - - -def create_streaming_statement( - query: str, fetch_size: int = 1000, consistency_level: Optional[ConsistencyLevel] = None -) -> SimpleStatement: - """ - Create a statement configured for streaming. - - Args: - query: The CQL query - fetch_size: Number of rows per page - consistency_level: Optional consistency level - - Returns: - SimpleStatement configured for streaming - """ - statement = SimpleStatement(query, fetch_size=fetch_size) - - if consistency_level is not None: - statement.consistency_level = consistency_level - - return statement diff --git a/src/async_cassandra/utils.py b/src/async_cassandra/utils.py deleted file mode 100644 index b0b8512..0000000 --- a/src/async_cassandra/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -Utility functions and helpers for async-cassandra. 
-""" - -import asyncio -import logging -from typing import Any, Optional - -logger = logging.getLogger(__name__) - - -def get_or_create_event_loop() -> asyncio.AbstractEventLoop: - """ - Get the current event loop or create a new one if necessary. - - Returns: - The current or newly created event loop. - """ - try: - return asyncio.get_running_loop() - except RuntimeError: - # No event loop running, create a new one - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop - - -def safe_call_soon_threadsafe( - loop: Optional[asyncio.AbstractEventLoop], callback: Any, *args: Any -) -> None: - """ - Safely schedule a callback in the event loop from another thread. - - Args: - loop: The event loop to schedule in (may be None). - callback: The callback function to schedule. - *args: Arguments to pass to the callback. - """ - if loop is not None: - try: - loop.call_soon_threadsafe(callback, *args) - except RuntimeError as e: - # Event loop might be closed - logger.warning(f"Failed to schedule callback: {e}") - except Exception: - # Ignore other exceptions - we don't want to crash the caller - pass diff --git a/test-env/bin/Activate.ps1 b/test-env/bin/Activate.ps1 deleted file mode 100644 index 354eb42..0000000 --- a/test-env/bin/Activate.ps1 +++ /dev/null @@ -1,247 +0,0 @@ -<# -.Synopsis -Activate a Python virtual environment for the current PowerShell session. - -.Description -Pushes the python executable for a virtual environment to the front of the -$Env:PATH environment variable and sets the prompt to signify that you are -in a Python virtual environment. Makes use of the command line switches as -well as the `pyvenv.cfg` file values present in the virtual environment. - -.Parameter VenvDir -Path to the directory that contains the virtual environment to activate. The -default value for this is the parent of the directory that the Activate.ps1 -script is located within. - -.Parameter Prompt -The prompt prefix to display when this virtual environment is activated. By -default, this prompt is the name of the virtual environment folder (VenvDir) -surrounded by parentheses and followed by a single space (ie. '(.venv) '). - -.Example -Activate.ps1 -Activates the Python virtual environment that contains the Activate.ps1 script. - -.Example -Activate.ps1 -Verbose -Activates the Python virtual environment that contains the Activate.ps1 script, -and shows extra information about the activation as it executes. - -.Example -Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv -Activates the Python virtual environment located in the specified location. - -.Example -Activate.ps1 -Prompt "MyPython" -Activates the Python virtual environment that contains the Activate.ps1 script, -and prefixes the current prompt with the specified string (surrounded in -parentheses) while the virtual environment is active. - -.Notes -On Windows, it may be required to enable this Activate.ps1 script by setting the -execution policy for the user. 
You can do this by issuing the following PowerShell -command: - -PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser - -For more information on Execution Policies: -https://go.microsoft.com/fwlink/?LinkID=135170 - -#> -Param( - [Parameter(Mandatory = $false)] - [String] - $VenvDir, - [Parameter(Mandatory = $false)] - [String] - $Prompt -) - -<# Function declarations --------------------------------------------------- #> - -<# -.Synopsis -Remove all shell session elements added by the Activate script, including the -addition of the virtual environment's Python executable from the beginning of -the PATH variable. - -.Parameter NonDestructive -If present, do not remove this function from the global namespace for the -session. - -#> -function global:deactivate ([switch]$NonDestructive) { - # Revert to original values - - # The prior prompt: - if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) { - Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt - Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT - } - - # The prior PYTHONHOME: - if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) { - Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME - Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME - } - - # The prior PATH: - if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) { - Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH - Remove-Item -Path Env:_OLD_VIRTUAL_PATH - } - - # Just remove the VIRTUAL_ENV altogether: - if (Test-Path -Path Env:VIRTUAL_ENV) { - Remove-Item -Path env:VIRTUAL_ENV - } - - # Just remove VIRTUAL_ENV_PROMPT altogether. - if (Test-Path -Path Env:VIRTUAL_ENV_PROMPT) { - Remove-Item -Path env:VIRTUAL_ENV_PROMPT - } - - # Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether: - if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) { - Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force - } - - # Leave deactivate function in the global namespace if requested: - if (-not $NonDestructive) { - Remove-Item -Path function:deactivate - } -} - -<# -.Description -Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the -given folder, and returns them in a map. - -For each line in the pyvenv.cfg file, if that line can be parsed into exactly -two strings separated by `=` (with any amount of whitespace surrounding the =) -then it is considered a `key = value` line. The left hand string is the key, -the right hand is the value. - -If the value starts with a `'` or a `"` then the first and last character is -stripped from the value before being captured. - -.Parameter ConfigDir -Path to the directory that contains the `pyvenv.cfg` file. -#> -function Get-PyVenvConfig( - [String] - $ConfigDir -) { - Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg" - - # Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue). - $pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue - - # An empty map will be returned if no config file is found. - $pyvenvConfig = @{ } - - if ($pyvenvConfigPath) { - - Write-Verbose "File exists, parse `key = value` lines" - $pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath - - $pyvenvConfigContent | ForEach-Object { - $keyval = $PSItem -split "\s*=\s*", 2 - if ($keyval[0] -and $keyval[1]) { - $val = $keyval[1] - - # Remove extraneous quotations around a string value. 
- if ("'""".Contains($val.Substring(0, 1))) { - $val = $val.Substring(1, $val.Length - 2) - } - - $pyvenvConfig[$keyval[0]] = $val - Write-Verbose "Adding Key: '$($keyval[0])'='$val'" - } - } - } - return $pyvenvConfig -} - - -<# Begin Activate script --------------------------------------------------- #> - -# Determine the containing directory of this script -$VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition -$VenvExecDir = Get-Item -Path $VenvExecPath - -Write-Verbose "Activation script is located in path: '$VenvExecPath'" -Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)" -Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)" - -# Set values required in priority: CmdLine, ConfigFile, Default -# First, get the location of the virtual environment, it might not be -# VenvExecDir if specified on the command line. -if ($VenvDir) { - Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values" -} -else { - Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir." - $VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/") - Write-Verbose "VenvDir=$VenvDir" -} - -# Next, read the `pyvenv.cfg` file to determine any required value such -# as `prompt`. -$pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir - -# Next, set the prompt from the command line, or the config file, or -# just use the name of the virtual environment folder. -if ($Prompt) { - Write-Verbose "Prompt specified as argument, using '$Prompt'" -} -else { - Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value" - if ($pyvenvCfg -and $pyvenvCfg['prompt']) { - Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'" - $Prompt = $pyvenvCfg['prompt']; - } - else { - Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)" - Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'" - $Prompt = Split-Path -Path $venvDir -Leaf - } -} - -Write-Verbose "Prompt = '$Prompt'" -Write-Verbose "VenvDir='$VenvDir'" - -# Deactivate any currently active virtual environment, but leave the -# deactivate function in place. -deactivate -nondestructive - -# Now set the environment variable VIRTUAL_ENV, used by many tools to determine -# that there is an activated venv. 
-$env:VIRTUAL_ENV = $VenvDir - -if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) { - - Write-Verbose "Setting prompt to '$Prompt'" - - # Set the prompt to include the env name - # Make sure _OLD_VIRTUAL_PROMPT is global - function global:_OLD_VIRTUAL_PROMPT { "" } - Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT - New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt - - function global:prompt { - Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) " - _OLD_VIRTUAL_PROMPT - } - $env:VIRTUAL_ENV_PROMPT = $Prompt -} - -# Clear PYTHONHOME -if (Test-Path -Path Env:PYTHONHOME) { - Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME - Remove-Item -Path Env:PYTHONHOME -} - -# Add the venv to the PATH -Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH -$Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH" diff --git a/test-env/bin/activate b/test-env/bin/activate deleted file mode 100644 index bcf0a37..0000000 --- a/test-env/bin/activate +++ /dev/null @@ -1,71 +0,0 @@ -# This file must be used with "source bin/activate" *from bash* -# You cannot run it directly - -deactivate () { - # reset old environment variables - if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then - PATH="${_OLD_VIRTUAL_PATH:-}" - export PATH - unset _OLD_VIRTUAL_PATH - fi - if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then - PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}" - export PYTHONHOME - unset _OLD_VIRTUAL_PYTHONHOME - fi - - # Call hash to forget past locations. Without forgetting - # past locations the $PATH changes we made may not be respected. - # See "man bash" for more details. hash is usually a builtin of your shell - hash -r 2> /dev/null - - if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then - PS1="${_OLD_VIRTUAL_PS1:-}" - export PS1 - unset _OLD_VIRTUAL_PS1 - fi - - unset VIRTUAL_ENV - unset VIRTUAL_ENV_PROMPT - if [ ! "${1:-}" = "nondestructive" ] ; then - # Self destruct! - unset -f deactivate - fi -} - -# unset irrelevant variables -deactivate nondestructive - -# on Windows, a path can contain colons and backslashes and has to be converted: -if [ "${OSTYPE:-}" = "cygwin" ] || [ "${OSTYPE:-}" = "msys" ] ; then - # transform D:\path\to\venv to /d/path/to/venv on MSYS - # and to /cygdrive/d/path/to/venv on Cygwin - export VIRTUAL_ENV=$(cygpath /Users/johnny/Development/async-python-cassandra-client/test-env) -else - # use the path as-is - export VIRTUAL_ENV=/Users/johnny/Development/async-python-cassandra-client/test-env -fi - -_OLD_VIRTUAL_PATH="$PATH" -PATH="$VIRTUAL_ENV/"bin":$PATH" -export PATH - -# unset PYTHONHOME if set -# this will fail if PYTHONHOME is set to the empty string (which is bad anyway) -# could use `if (set -u; : $PYTHONHOME) ;` in bash -if [ -n "${PYTHONHOME:-}" ] ; then - _OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}" - unset PYTHONHOME -fi - -if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then - _OLD_VIRTUAL_PS1="${PS1:-}" - PS1='(test-env) '"${PS1:-}" - export PS1 - VIRTUAL_ENV_PROMPT='(test-env) ' - export VIRTUAL_ENV_PROMPT -fi - -# Call hash to forget past commands. Without forgetting -# past commands the $PATH changes we made may not be respected -hash -r 2> /dev/null diff --git a/test-env/bin/activate.csh b/test-env/bin/activate.csh deleted file mode 100644 index 356139d..0000000 --- a/test-env/bin/activate.csh +++ /dev/null @@ -1,27 +0,0 @@ -# This file must be used with "source bin/activate.csh" *from csh*. 
-# You cannot run it directly. - -# Created by Davide Di Blasi . -# Ported to Python 3.3 venv by Andrew Svetlov - -alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; unsetenv VIRTUAL_ENV_PROMPT; test "\!:*" != "nondestructive" && unalias deactivate' - -# Unset irrelevant variables. -deactivate nondestructive - -setenv VIRTUAL_ENV /Users/johnny/Development/async-python-cassandra-client/test-env - -set _OLD_VIRTUAL_PATH="$PATH" -setenv PATH "$VIRTUAL_ENV/"bin":$PATH" - - -set _OLD_VIRTUAL_PROMPT="$prompt" - -if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then - set prompt = '(test-env) '"$prompt" - setenv VIRTUAL_ENV_PROMPT '(test-env) ' -endif - -alias pydoc python -m pydoc - -rehash diff --git a/test-env/bin/activate.fish b/test-env/bin/activate.fish deleted file mode 100644 index 5db1bc3..0000000 --- a/test-env/bin/activate.fish +++ /dev/null @@ -1,69 +0,0 @@ -# This file must be used with "source /bin/activate.fish" *from fish* -# (https://fishshell.com/). You cannot run it directly. - -function deactivate -d "Exit virtual environment and return to normal shell environment" - # reset old environment variables - if test -n "$_OLD_VIRTUAL_PATH" - set -gx PATH $_OLD_VIRTUAL_PATH - set -e _OLD_VIRTUAL_PATH - end - if test -n "$_OLD_VIRTUAL_PYTHONHOME" - set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME - set -e _OLD_VIRTUAL_PYTHONHOME - end - - if test -n "$_OLD_FISH_PROMPT_OVERRIDE" - set -e _OLD_FISH_PROMPT_OVERRIDE - # prevents error when using nested fish instances (Issue #93858) - if functions -q _old_fish_prompt - functions -e fish_prompt - functions -c _old_fish_prompt fish_prompt - functions -e _old_fish_prompt - end - end - - set -e VIRTUAL_ENV - set -e VIRTUAL_ENV_PROMPT - if test "$argv[1]" != "nondestructive" - # Self-destruct! - functions -e deactivate - end -end - -# Unset irrelevant variables. -deactivate nondestructive - -set -gx VIRTUAL_ENV /Users/johnny/Development/async-python-cassandra-client/test-env - -set -gx _OLD_VIRTUAL_PATH $PATH -set -gx PATH "$VIRTUAL_ENV/"bin $PATH - -# Unset PYTHONHOME if set. -if set -q PYTHONHOME - set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME - set -e PYTHONHOME -end - -if test -z "$VIRTUAL_ENV_DISABLE_PROMPT" - # fish uses a function instead of an env var to generate the prompt. - - # Save the current fish_prompt function as the function _old_fish_prompt. - functions -c fish_prompt _old_fish_prompt - - # With the original prompt function renamed, we can override with our own. - function fish_prompt - # Save the return status of the last command. - set -l old_status $status - - # Output the venv prompt; color taken from the blue of the Python logo. - printf "%s%s%s" (set_color 4B8BBE) '(test-env) ' (set_color normal) - - # Restore the return status of the previous command. - echo "exit $old_status" | . - # Output the original/"old" prompt. 
- _old_fish_prompt - end - - set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV" - set -gx VIRTUAL_ENV_PROMPT '(test-env) ' -end diff --git a/test-env/bin/geomet b/test-env/bin/geomet deleted file mode 100755 index 8345043..0000000 --- a/test-env/bin/geomet +++ /dev/null @@ -1,10 +0,0 @@ -#!/Users/johnny/Development/async-python-cassandra-client/test-env/bin/python -# -*- coding: utf-8 -*- -import re -import sys - -from geomet.tool import cli - -if __name__ == "__main__": - sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0]) - sys.exit(cli()) diff --git a/test-env/bin/pip b/test-env/bin/pip deleted file mode 100755 index a3b4401..0000000 --- a/test-env/bin/pip +++ /dev/null @@ -1,10 +0,0 @@ -#!/Users/johnny/Development/async-python-cassandra-client/test-env/bin/python -# -*- coding: utf-8 -*- -import re -import sys - -from pip._internal.cli.main import main - -if __name__ == "__main__": - sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0]) - sys.exit(main()) diff --git a/test-env/bin/pip3 b/test-env/bin/pip3 deleted file mode 100755 index a3b4401..0000000 --- a/test-env/bin/pip3 +++ /dev/null @@ -1,10 +0,0 @@ -#!/Users/johnny/Development/async-python-cassandra-client/test-env/bin/python -# -*- coding: utf-8 -*- -import re -import sys - -from pip._internal.cli.main import main - -if __name__ == "__main__": - sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0]) - sys.exit(main()) diff --git a/test-env/bin/pip3.12 b/test-env/bin/pip3.12 deleted file mode 100755 index a3b4401..0000000 --- a/test-env/bin/pip3.12 +++ /dev/null @@ -1,10 +0,0 @@ -#!/Users/johnny/Development/async-python-cassandra-client/test-env/bin/python -# -*- coding: utf-8 -*- -import re -import sys - -from pip._internal.cli.main import main - -if __name__ == "__main__": - sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0]) - sys.exit(main()) diff --git a/test-env/bin/python b/test-env/bin/python deleted file mode 120000 index 091d463..0000000 --- a/test-env/bin/python +++ /dev/null @@ -1 +0,0 @@ -/Users/johnny/.pyenv/versions/3.12.8/bin/python \ No newline at end of file diff --git a/test-env/bin/python3 b/test-env/bin/python3 deleted file mode 120000 index d8654aa..0000000 --- a/test-env/bin/python3 +++ /dev/null @@ -1 +0,0 @@ -python \ No newline at end of file diff --git a/test-env/bin/python3.12 b/test-env/bin/python3.12 deleted file mode 120000 index d8654aa..0000000 --- a/test-env/bin/python3.12 +++ /dev/null @@ -1 +0,0 @@ -python \ No newline at end of file diff --git a/test-env/pyvenv.cfg b/test-env/pyvenv.cfg deleted file mode 100644 index ba6019d..0000000 --- a/test-env/pyvenv.cfg +++ /dev/null @@ -1,5 +0,0 @@ -home = /Users/johnny/.pyenv/versions/3.12.8/bin -include-system-site-packages = false -version = 3.12.8 -executable = /Users/johnny/.pyenv/versions/3.12.8/bin/python3.12 -command = /Users/johnny/.pyenv/versions/3.12.8/bin/python -m venv /Users/johnny/Development/async-python-cassandra-client/test-env diff --git a/tests/README.md b/tests/README.md deleted file mode 100644 index 47ef89c..0000000 --- a/tests/README.md +++ /dev/null @@ -1,67 +0,0 @@ -# Test Organization - -This directory contains all tests for async-python-cassandra-client, organized by test type: - -## Directory Structure - -### `/unit` -Pure unit tests with mocked dependencies. No external services required. -- Fast execution -- Test individual components in isolation -- All Cassandra interactions are mocked - -### `/integration` -Integration tests that require a real Cassandra instance. 
-- Test actual database operations -- Verify driver behavior with real Cassandra -- Marked with `@pytest.mark.integration` - -### `/bdd` -Cucumber-based Behavior Driven Development tests. -- Feature files in `/bdd/features` -- Step definitions in `/bdd/steps` -- Focus on user scenarios and requirements - -### `/fastapi_integration` -FastAPI-specific integration tests. -- Test the example FastAPI application -- Verify async-cassandra works correctly with FastAPI -- Requires both Cassandra and the FastAPI app running -- No mocking - tests real-world scenarios - -### `/benchmarks` -Performance benchmarks and stress tests. -- Measure performance characteristics -- Identify performance regressions - -### `/utils` -Shared test utilities and helpers. - -### `/_fixtures` -Test fixtures and sample data. - -## Running Tests - -```bash -# Unit tests (fast, no external dependencies) -make test-unit - -# Integration tests (requires Cassandra) -make test-integration - -# FastAPI integration tests (requires Cassandra + FastAPI app) -make test-fastapi - -# BDD tests (requires Cassandra) -make test-bdd - -# All tests -make test-all -``` - -## Test Isolation - -- Each test type is completely isolated -- No shared code between test types -- Each directory has its own conftest.py if needed -- Tests should not import from other test directories diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 0a60055..0000000 --- a/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Test package for async-cassandra.""" diff --git a/tests/_fixtures/__init__.py b/tests/_fixtures/__init__.py deleted file mode 100644 index 27f3868..0000000 --- a/tests/_fixtures/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Shared test fixtures and utilities. - -This package contains reusable fixtures for Cassandra containers, -FastAPI apps, and monitoring utilities. -""" diff --git a/tests/_fixtures/cassandra.py b/tests/_fixtures/cassandra.py deleted file mode 100644 index cdab804..0000000 --- a/tests/_fixtures/cassandra.py +++ /dev/null @@ -1,304 +0,0 @@ -"""Cassandra test fixtures supporting both Docker and Podman. - -This module provides fixtures for managing Cassandra containers -in tests, with support for both Docker and Podman runtimes. -""" - -import os -import subprocess -import time -from typing import Optional - -import pytest - - -def get_container_runtime() -> str: - """Detect available container runtime (docker or podman).""" - for runtime in ["docker", "podman"]: - try: - subprocess.run([runtime, "--version"], capture_output=True, check=True) - return runtime - except (subprocess.CalledProcessError, FileNotFoundError): - continue - raise RuntimeError("Neither docker nor podman found. 
Please install one.") - - -class CassandraContainer: - """Manages a Cassandra container for testing.""" - - def __init__(self, runtime: str = None): - self.runtime = runtime or get_container_runtime() - self.container_name = "async-cassandra-test" - self.container_id: Optional[str] = None - - def start(self): - """Start the Cassandra container.""" - # Stop and remove any existing container with our name - print(f"Cleaning up any existing container named {self.container_name}...") - subprocess.run( - [self.runtime, "stop", self.container_name], - capture_output=True, - stderr=subprocess.DEVNULL, - ) - subprocess.run( - [self.runtime, "rm", "-f", self.container_name], - capture_output=True, - stderr=subprocess.DEVNULL, - ) - - # Create new container with proper resources - print(f"Starting fresh Cassandra container: {self.container_name}") - result = subprocess.run( - [ - self.runtime, - "run", - "-d", - "--name", - self.container_name, - "-p", - "9042:9042", - "-e", - "CASSANDRA_CLUSTER_NAME=TestCluster", - "-e", - "CASSANDRA_DC=datacenter1", - "-e", - "CASSANDRA_ENDPOINT_SNITCH=GossipingPropertyFileSnitch", - "-e", - "HEAP_NEWSIZE=512M", - "-e", - "MAX_HEAP_SIZE=3G", - "-e", - "JVM_OPTS=-XX:+UseG1GC -XX:G1RSetUpdatingPauseTimePercent=5 -XX:MaxGCPauseMillis=300", - "--memory=4g", - "--memory-swap=4g", - "cassandra:5", - ], - capture_output=True, - text=True, - check=True, - ) - self.container_id = result.stdout.strip() - - # Wait for Cassandra to be ready - self._wait_for_cassandra() - - def stop(self): - """Stop the Cassandra container.""" - if self.container_id or self.container_name: - container_ref = self.container_id or self.container_name - subprocess.run([self.runtime, "stop", container_ref], capture_output=True) - - def remove(self): - """Remove the Cassandra container.""" - if self.container_id or self.container_name: - container_ref = self.container_id or self.container_name - subprocess.run([self.runtime, "rm", "-f", container_ref], capture_output=True) - - def _wait_for_cassandra(self, timeout: int = 90): - """Wait for Cassandra to be ready to accept connections.""" - start_time = time.time() - while time.time() - start_time < timeout: - # Use container name instead of ID for exec - container_ref = self.container_name if self.container_name else self.container_id - - # First check if native transport is active - health_result = subprocess.run( - [ - self.runtime, - "exec", - container_ref, - "nodetool", - "info", - ], - capture_output=True, - text=True, - ) - - if ( - health_result.returncode == 0 - and "Native Transport active: true" in health_result.stdout - ): - # Now check if CQL is responsive - cql_result = subprocess.run( - [ - self.runtime, - "exec", - container_ref, - "cqlsh", - "-e", - "SELECT release_version FROM system.local", - ], - capture_output=True, - ) - if cql_result.returncode == 0: - return - time.sleep(3) - raise TimeoutError(f"Cassandra did not start within {timeout} seconds") - - def execute_cql(self, cql: str): - """Execute CQL statement in the container.""" - return subprocess.run( - [self.runtime, "exec", self.container_id, "cqlsh", "-e", cql], - capture_output=True, - text=True, - check=True, - ) - - def is_running(self) -> bool: - """Check if container is running.""" - if not self.container_id: - return False - result = subprocess.run( - [self.runtime, "inspect", "-f", "{{.State.Running}}", self.container_id], - capture_output=True, - text=True, - ) - return result.stdout.strip() == "true" - - def check_health(self) -> dict: - """Check Cassandra 
health using nodetool info.""" - if not self.container_id: - return { - "native_transport": False, - "gossip": False, - "cql_available": False, - } - - container_ref = self.container_name if self.container_name else self.container_id - - # Run nodetool info - result = subprocess.run( - [ - self.runtime, - "exec", - container_ref, - "nodetool", - "info", - ], - capture_output=True, - text=True, - ) - - health_status = { - "native_transport": False, - "gossip": False, - "cql_available": False, - } - - if result.returncode == 0: - info = result.stdout - health_status["native_transport"] = "Native Transport active: true" in info - health_status["gossip"] = ( - "Gossip active" in info and "true" in info.split("Gossip active")[1].split("\n")[0] - ) - - # Check CQL availability - cql_result = subprocess.run( - [ - self.runtime, - "exec", - container_ref, - "cqlsh", - "-e", - "SELECT now() FROM system.local", - ], - capture_output=True, - ) - health_status["cql_available"] = cql_result.returncode == 0 - - return health_status - - -@pytest.fixture(scope="session") -def cassandra_container(): - """Provide a Cassandra container for the test session.""" - # First check if there's already a running container we can use - runtime = get_container_runtime() - port_check = subprocess.run( - [runtime, "ps", "--format", "{{.Names}} {{.Ports}}"], - capture_output=True, - text=True, - ) - - if port_check.stdout.strip(): - # Check for container using port 9042 - for line in port_check.stdout.strip().split("\n"): - if "9042" in line: - existing_container = line.split()[0] - print(f"Using existing Cassandra container: {existing_container}") - - container = CassandraContainer() - container.container_name = existing_container - container.container_id = existing_container - container.runtime = runtime - - # Ensure test keyspace exists - container.execute_cql( - """ - CREATE KEYSPACE IF NOT EXISTS test_keyspace - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - - yield container - # Don't stop/remove containers we didn't create - return - - # No existing container, create new one - container = CassandraContainer() - container.start() - - # Create test keyspace - container.execute_cql( - """ - CREATE KEYSPACE IF NOT EXISTS test_keyspace - WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - - yield container - - # Cleanup based on environment variable - if os.environ.get("KEEP_CONTAINERS") != "1": - container.stop() - container.remove() - - -@pytest.fixture(scope="function") -def cassandra_session(cassandra_container): - """Provide a Cassandra session connected to test keyspace.""" - from cassandra.cluster import Cluster - - cluster = Cluster(["127.0.0.1"]) - session = cluster.connect() - session.set_keyspace("test_keyspace") - - yield session - - # Cleanup tables created during test - rows = session.execute( - """ - SELECT table_name FROM system_schema.tables - WHERE keyspace_name = 'test_keyspace' - """ - ) - for row in rows: - session.execute(f"DROP TABLE IF EXISTS {row.table_name}") - - cluster.shutdown() - - -@pytest.fixture(scope="function") -async def async_cassandra_session(cassandra_container): - """Provide an async Cassandra session.""" - from async_cassandra import AsyncCluster - - cluster = AsyncCluster(["127.0.0.1"]) - session = await cluster.connect() - await session.set_keyspace("test_keyspace") - - yield session - - # Cleanup - await session.close() - await cluster.shutdown() diff --git a/tests/bdd/conftest.py b/tests/bdd/conftest.py 
deleted file mode 100644 index a571457..0000000 --- a/tests/bdd/conftest.py +++ /dev/null @@ -1,195 +0,0 @@ -"""Pytest configuration for BDD tests.""" - -import asyncio -import sys -from pathlib import Path - -import pytest - -from tests._fixtures.cassandra import cassandra_container # noqa: F401 - -# Add project root to path -project_root = Path(__file__).parent.parent.parent -sys.path.insert(0, str(project_root)) - -# Import test utils for isolation -sys.path.insert(0, str(Path(__file__).parent.parent)) -from test_utils import ( # noqa: E402 - cleanup_keyspace, - create_test_keyspace, - generate_unique_keyspace, - get_test_timeout, -) - - -@pytest.fixture(scope="session") -def event_loop(): - """Create an event loop for the test session.""" - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() - - -@pytest.fixture -def anyio_backend(): - """Use asyncio backend for async tests.""" - return "asyncio" - - -@pytest.fixture -def connection_parameters(): - """Provide connection parameters for BDD tests.""" - return {"contact_points": ["127.0.0.1"], "port": 9042} - - -@pytest.fixture -def driver_configured(): - """Provide driver configuration for BDD tests.""" - return {"contact_points": ["127.0.0.1"], "port": 9042, "thread_pool_max_workers": 32} - - -@pytest.fixture -def cassandra_cluster_running(cassandra_container): # noqa: F811 - """Ensure Cassandra container is running and healthy.""" - assert cassandra_container.is_running() - - # Check health before proceeding - health = cassandra_container.check_health() - if not health["native_transport"] or not health["cql_available"]: - pytest.fail(f"Cassandra not healthy: {health}") - - return cassandra_container - - -@pytest.fixture -async def cassandra_cluster(cassandra_container): # noqa: F811 - """Provide an async Cassandra cluster for BDD tests.""" - from async_cassandra import AsyncCluster - - # Ensure Cassandra is healthy before creating cluster - health = cassandra_container.check_health() - if not health["native_transport"] or not health["cql_available"]: - pytest.fail(f"Cassandra not healthy: {health}") - - cluster = AsyncCluster(["127.0.0.1"], protocol_version=5) - yield cluster - await cluster.shutdown() - # Give extra time for driver's internal threads to fully stop - # This prevents "cannot schedule new futures after shutdown" errors - await asyncio.sleep(2) - - -@pytest.fixture -async def isolated_session(cassandra_cluster): - """Provide an isolated session with unique keyspace for BDD tests.""" - session = await cassandra_cluster.connect() - - # Create unique keyspace for this test - keyspace = generate_unique_keyspace("test_bdd") - await create_test_keyspace(session, keyspace) - await session.set_keyspace(keyspace) - - yield session - - # Cleanup - await cleanup_keyspace(session, keyspace) - await session.close() - # Give time for session cleanup - await asyncio.sleep(1) - - -@pytest.fixture -def test_context(): - """Shared context for BDD tests with isolation helpers.""" - return { - "keyspaces_created": [], - "tables_created": [], - "get_unique_keyspace": lambda: generate_unique_keyspace("bdd"), - "get_test_timeout": get_test_timeout, - } - - -@pytest.fixture -def bdd_test_timeout(): - """Get appropriate timeout for BDD tests.""" - return get_test_timeout(10.0) - - -# BDD-specific configuration -def pytest_bdd_step_error(request, feature, scenario, step, step_func, step_func_args, exception): - """Enhanced error reporting for BDD steps.""" - print(f"\n{'='*60}") - print(f"STEP FAILED: {step.keyword} 
{step.name}") - print(f"Feature: {feature.name}") - print(f"Scenario: {scenario.name}") - print(f"Error: {exception}") - print(f"{'='*60}\n") - - -# Markers for BDD tests -def pytest_configure(config): - """Configure custom markers for BDD tests.""" - config.addinivalue_line("markers", "bdd: mark test as BDD test") - config.addinivalue_line("markers", "critical: mark test as critical for production") - config.addinivalue_line("markers", "concurrency: mark test as concurrency test") - config.addinivalue_line("markers", "performance: mark test as performance test") - config.addinivalue_line("markers", "memory: mark test as memory test") - config.addinivalue_line("markers", "fastapi: mark test as FastAPI integration test") - config.addinivalue_line("markers", "startup_shutdown: mark test as startup/shutdown test") - config.addinivalue_line( - "markers", "dependency_injection: mark test as dependency injection test" - ) - config.addinivalue_line("markers", "streaming: mark test as streaming test") - config.addinivalue_line("markers", "pagination: mark test as pagination test") - config.addinivalue_line("markers", "caching: mark test as caching test") - config.addinivalue_line("markers", "prepared_statements: mark test as prepared statements test") - config.addinivalue_line("markers", "monitoring: mark test as monitoring test") - config.addinivalue_line("markers", "connection_reuse: mark test as connection reuse test") - config.addinivalue_line("markers", "background_tasks: mark test as background tasks test") - config.addinivalue_line("markers", "graceful_shutdown: mark test as graceful shutdown test") - config.addinivalue_line("markers", "middleware: mark test as middleware test") - config.addinivalue_line("markers", "connection_failure: mark test as connection failure test") - config.addinivalue_line("markers", "websocket: mark test as websocket test") - config.addinivalue_line("markers", "memory_pressure: mark test as memory pressure test") - config.addinivalue_line("markers", "auth: mark test as authentication test") - config.addinivalue_line("markers", "error_handling: mark test as error handling test") - - -@pytest.fixture(scope="function", autouse=True) -async def ensure_cassandra_healthy_bdd(cassandra_container): # noqa: F811 - """Ensure Cassandra is healthy before each BDD test.""" - # Check health before test - health = cassandra_container.check_health() - if not health["native_transport"] or not health["cql_available"]: - # Try to wait a bit and check again - import asyncio - - await asyncio.sleep(2) - health = cassandra_container.check_health() - if not health["native_transport"] or not health["cql_available"]: - pytest.fail(f"Cassandra not healthy before test: {health}") - - yield - - # Optional: Check health after test - health = cassandra_container.check_health() - if not health["native_transport"]: - print(f"Warning: Cassandra health degraded after test: {health}") - - -# Automatically mark all BDD tests -def pytest_collection_modifyitems(items): - """Automatically add markers to BDD tests.""" - for item in items: - # Mark all tests in bdd directory - if "bdd" in str(item.fspath): - item.add_marker(pytest.mark.bdd) - - # Add markers based on tags in feature files - if hasattr(item, "scenario"): - for tag in item.scenario.tags: - # Remove @ and convert hyphens to underscores - marker_name = tag.lstrip("@").replace("-", "_") - if hasattr(pytest.mark, marker_name): - marker = getattr(pytest.mark, marker_name) - item.add_marker(marker) diff --git 
a/tests/bdd/features/concurrent_load.feature b/tests/bdd/features/concurrent_load.feature deleted file mode 100644 index 0d139fc..0000000 --- a/tests/bdd/features/concurrent_load.feature +++ /dev/null @@ -1,26 +0,0 @@ -Feature: Concurrent Load Handling - As a developer using async-cassandra - I need the driver to handle concurrent requests properly - So that my application doesn't deadlock or leak memory under load - - Background: - Given a running Cassandra cluster - And async-cassandra configured with default settings - - @critical @performance - Scenario: Thread pool exhaustion prevention - Given a configured thread pool of 10 threads - When I submit 1000 concurrent queries - Then all queries should eventually complete - And no deadlock should occur - And memory usage should remain stable - And response times should degrade gracefully - - @critical @memory - Scenario: Memory leak prevention under load - Given a baseline memory measurement - When I execute 10,000 queries - Then memory usage should not grow continuously - And garbage collection should work effectively - And no resource warnings should be logged - And performance should remain consistent diff --git a/tests/bdd/features/context_manager_safety.feature b/tests/bdd/features/context_manager_safety.feature deleted file mode 100644 index 056bff8..0000000 --- a/tests/bdd/features/context_manager_safety.feature +++ /dev/null @@ -1,56 +0,0 @@ -Feature: Context Manager Safety - As a developer using async-cassandra - I want context managers to only close their own resources - So that shared resources remain available for other operations - - Background: - Given a running Cassandra cluster - And a test keyspace "test_context_safety" - - Scenario: Query error doesn't close session - Given an open session connected to the test keyspace - When I execute a query that causes an error - Then the session should remain open and usable - And I should be able to execute subsequent queries successfully - - Scenario: Streaming error doesn't close session - Given an open session with test data - When a streaming operation encounters an error - Then the streaming result should be closed - But the session should remain open - And I should be able to start new streaming operations - - Scenario: Session context manager doesn't close cluster - Given an open cluster connection - When I use a session in a context manager that exits with an error - Then the session should be closed - But the cluster should remain open - And I should be able to create new sessions from the cluster - - Scenario: Multiple concurrent streams don't interfere - Given multiple sessions from the same cluster - When I stream data concurrently from each session - Then each stream should complete independently - And closing one stream should not affect others - And all sessions should remain usable - - Scenario: Nested context managers close in correct order - Given a cluster, session, and streaming result in nested context managers - When the innermost context (streaming) exits - Then only the streaming result should be closed - When the middle context (session) exits - Then only the session should be closed - When the outer context (cluster) exits - Then the cluster should be shut down - - Scenario: Thread safety during context exit - Given a session being used by multiple threads - When one thread exits a streaming context manager - Then other threads should still be able to use the session - And no operations should be interrupted - - Scenario: Context manager handles 
cancellation correctly - Given an active streaming operation in a context manager - When the operation is cancelled - Then the streaming result should be properly cleaned up - But the session should remain open and usable diff --git a/tests/bdd/features/fastapi_integration.feature b/tests/bdd/features/fastapi_integration.feature deleted file mode 100644 index 0c9ba03..0000000 --- a/tests/bdd/features/fastapi_integration.feature +++ /dev/null @@ -1,217 +0,0 @@ -Feature: FastAPI Integration - As a FastAPI developer - I want to use async-cassandra in my web application - So that I can build responsive APIs with Cassandra backend - - Background: - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - - @critical @fastapi - Scenario: Simple REST API endpoint - Given a user endpoint that queries Cassandra - When I send a GET request to "/users/123" - Then I should receive a 200 response - And the response should contain user data - And the request should complete within 100ms - - @critical @fastapi @concurrency - Scenario: Handle concurrent API requests - Given a product search endpoint - When I send 100 concurrent search requests - Then all requests should receive valid responses - And no request should take longer than 500ms - And the Cassandra connection pool should not be exhausted - - @fastapi @error_handling - Scenario: API error handling for database issues - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And a Cassandra query that will fail - When I send a request that triggers the failing query - Then I should receive a 500 error response - And the error should not expose internal details - And the connection should be returned to the pool - - @fastapi @startup_shutdown - Scenario: Application lifecycle management - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - When the FastAPI application starts up - Then the Cassandra cluster connection should be established - And the connection pool should be initialized - When the application shuts down - Then all active queries should complete or timeout - And all connections should be properly closed - And no resource warnings should be logged - - @fastapi @dependency_injection - Scenario: Use async-cassandra with FastAPI dependencies - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And a FastAPI dependency that provides a Cassandra session - When I use this dependency in multiple endpoints - Then each request should get a working session - And sessions should be properly managed per request - And no session leaks should occur between requests - - @fastapi @streaming - Scenario: Stream large datasets through API - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And an endpoint that returns 10,000 records - When I request the data with streaming enabled - Then the response should start immediately - And data should be streamed in chunks - And memory usage should remain constant - And the client should be able to cancel mid-stream - - @fastapi @pagination - Scenario: Implement cursor-based pagination - Given a FastAPI application with async-cassandra - And a running Cassandra 
cluster with test data - And the FastAPI test client is initialized - And a paginated endpoint for listing items - When I request the first page with limit 20 - Then I should receive 20 items and a next cursor - When I request the next page using the cursor - Then I should receive the next 20 items - And pagination should work correctly under concurrent access - - @fastapi @caching - Scenario: Implement query result caching - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And an endpoint with query result caching enabled - When I make the same request multiple times - Then the first request should query Cassandra - And subsequent requests should use cached data - And cache should expire after the configured TTL - And cache should be invalidated on data updates - - @fastapi @prepared_statements - Scenario: Use prepared statements in API endpoints - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And an endpoint that uses prepared statements - When I make 1000 requests to this endpoint - Then statement preparation should happen only once - And query performance should be optimized - And the prepared statement cache should be shared across requests - - @fastapi @monitoring - Scenario: Monitor API and database performance - Given monitoring is enabled for the FastAPI app - And a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And a user endpoint that queries Cassandra - When I make various API requests - Then metrics should track: - | metric_type | description | - | request_count | Total API requests | - | request_duration | API response times | - | cassandra_query_count | Database queries per endpoint | - | cassandra_query_duration | Database query times | - | connection_pool_size | Active connections | - | error_rate | Failed requests percentage | - And metrics should be accessible via "/metrics" endpoint - - @critical @fastapi @connection_reuse - Scenario: Connection reuse across requests - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And an endpoint that performs multiple queries - When I make 50 sequential requests - Then the same Cassandra session should be reused - And no new connections should be created after warmup - And each request should complete faster than connection setup time - - @fastapi @background_tasks - Scenario: Background tasks with Cassandra operations - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And an endpoint that triggers background Cassandra operations - When I submit 10 tasks that write to Cassandra - Then the API should return immediately with 202 status - And all background writes should complete successfully - And no resources should leak from background tasks - - @critical @fastapi @graceful_shutdown - Scenario: Graceful shutdown under load - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And heavy concurrent load on the API - When the application receives a shutdown signal - Then in-flight requests should complete successfully - And new requests should be rejected with 503 - And all 
Cassandra operations should finish cleanly - And shutdown should complete within 30 seconds - - @fastapi @middleware - Scenario: Track Cassandra query metrics in middleware - Given a middleware that tracks Cassandra query execution - And a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And endpoints that perform different numbers of queries - When I make requests to endpoints with varying query counts - Then the middleware should accurately count queries per request - And query execution time should be measured - And async operations should not be blocked by tracking - - @critical @fastapi @connection_failure - Scenario: Handle Cassandra connection failures gracefully - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And a healthy API with established connections - When Cassandra becomes temporarily unavailable - Then API should return 503 Service Unavailable - And error messages should be user-friendly - When Cassandra becomes available again - Then API should automatically recover - And no manual intervention should be required - - @fastapi @websocket - Scenario: WebSocket endpoint with Cassandra streaming - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And a WebSocket endpoint that streams Cassandra data - When a client connects and requests real-time updates - Then the WebSocket should stream query results - And updates should be pushed as data changes - And connection cleanup should occur on disconnect - - @critical @fastapi @memory_pressure - Scenario: Handle memory pressure gracefully - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And an endpoint that fetches large datasets - When multiple clients request large amounts of data - Then memory usage should stay within limits - And requests should be throttled if necessary - And the application should not crash from OOM - - @fastapi @auth - Scenario: Authentication and session isolation - Given a FastAPI application with async-cassandra - And a running Cassandra cluster with test data - And the FastAPI test client is initialized - And endpoints with per-user Cassandra keyspaces - When different users make concurrent requests - Then each user should only access their keyspace - And sessions should be isolated between users - And no data should leak between user contexts diff --git a/tests/bdd/test_bdd_concurrent_load.py b/tests/bdd/test_bdd_concurrent_load.py deleted file mode 100644 index 3c8cbd5..0000000 --- a/tests/bdd/test_bdd_concurrent_load.py +++ /dev/null @@ -1,378 +0,0 @@ -"""BDD tests for concurrent load handling with real Cassandra.""" - -import asyncio -import gc -import time - -import psutil -import pytest -from pytest_bdd import given, parsers, scenario, then, when - -from async_cassandra import AsyncCluster - -# Import the cassandra_container fixture -pytest_plugins = ["tests._fixtures.cassandra"] - - -@scenario("features/concurrent_load.feature", "Thread pool exhaustion prevention") -def test_thread_pool_exhaustion(): - """ - Test thread pool exhaustion prevention. - - What this tests: - --------------- - 1. Thread pool limits respected - 2. No deadlock under load - 3. Queries complete eventually - 4. 
Graceful degradation - - Why this matters: - ---------------- - Thread exhaustion causes: - - Application hangs - - Query timeouts - - Poor user experience - - Must handle high load - without blocking. - """ - pass - - -@scenario("features/concurrent_load.feature", "Memory leak prevention under load") -def test_memory_leak_prevention(): - """ - Test memory leak prevention. - - What this tests: - --------------- - 1. Memory usage stable - 2. GC works effectively - 3. No continuous growth - 4. Resources cleaned up - - Why this matters: - ---------------- - Memory leaks fatal: - - OOM crashes - - Performance degradation - - Service instability - - Long-running apps need - stable memory usage. - """ - pass - - -@pytest.fixture -def load_context(cassandra_container): - """Context for concurrent load tests.""" - return { - "cluster": None, - "session": None, - "container": cassandra_container, - "metrics": { - "queries_sent": 0, - "queries_completed": 0, - "queries_failed": 0, - "memory_baseline": 0, - "memory_current": 0, - "memory_samples": [], - "start_time": None, - "errors": [], - }, - "thread_pool_size": 10, - "query_results": [], - "duration": None, - } - - -def run_async(coro, loop): - """Run async code in sync context.""" - return loop.run_until_complete(coro) - - -# Given steps -@given("a running Cassandra cluster") -def running_cluster(load_context): - """Verify Cassandra cluster is running.""" - assert load_context["container"].is_running() - - -@given("async-cassandra configured with default settings") -def default_settings(load_context, event_loop): - """Configure with default settings.""" - - async def _configure(): - cluster = AsyncCluster( - contact_points=["127.0.0.1"], - protocol_version=5, - executor_threads=load_context.get("thread_pool_size", 10), - ) - session = await cluster.connect() - await session.set_keyspace("test_keyspace") - - # Create test table - await session.execute( - """ - CREATE TABLE IF NOT EXISTS test_data ( - id int PRIMARY KEY, - data text - ) - """ - ) - - load_context["cluster"] = cluster - load_context["session"] = session - - run_async(_configure(), event_loop) - - -@given(parsers.parse("a configured thread pool of {size:d} threads")) -def configure_thread_pool(size, load_context): - """Configure thread pool size.""" - load_context["thread_pool_size"] = size - - -@given("a baseline memory measurement") -def baseline_memory(load_context): - """Take baseline memory measurement.""" - # Force garbage collection for accurate baseline - gc.collect() - process = psutil.Process() - load_context["metrics"]["memory_baseline"] = process.memory_info().rss / 1024 / 1024 # MB - - -# When steps -@when(parsers.parse("I submit {count:d} concurrent queries")) -def submit_concurrent_queries(count, load_context, event_loop): - """Submit many concurrent queries.""" - - async def _submit(): - session = load_context["session"] - - # Insert some test data first - for i in range(100): - await session.execute( - "INSERT INTO test_data (id, data) VALUES (%s, %s)", [i, f"test_data_{i}"] - ) - - # Now submit concurrent queries - async def execute_one(query_id): - try: - load_context["metrics"]["queries_sent"] += 1 - - result = await session.execute( - "SELECT * FROM test_data WHERE id = %s", [query_id % 100] - ) - - load_context["metrics"]["queries_completed"] += 1 - return result - except Exception as e: - load_context["metrics"]["queries_failed"] += 1 - load_context["metrics"]["errors"].append(str(e)) - raise - - start = time.time() - - # Submit queries in batches to avoid 
overwhelming - batch_size = 100 - all_results = [] - - for batch_start in range(0, count, batch_size): - batch_end = min(batch_start + batch_size, count) - tasks = [execute_one(i) for i in range(batch_start, batch_end)] - batch_results = await asyncio.gather(*tasks, return_exceptions=True) - all_results.extend(batch_results) - - # Small delay between batches - if batch_end < count: - await asyncio.sleep(0.1) - - load_context["query_results"] = all_results - load_context["duration"] = time.time() - start - - run_async(_submit(), event_loop) - - -@when(parsers.re(r"I execute (?P<count>[\d,]+) queries")) -def execute_many_queries(count, load_context, event_loop): - """Execute many queries.""" - # Convert count string to int, removing commas - count_int = int(count.replace(",", "")) - - async def _execute(): - session = load_context["session"] - - # We'll simulate by doing it faster but with memory measurements - batch_size = 1000 - batches = count_int // batch_size - - for batch_num in range(batches): - # Execute batch - tasks = [] - for i in range(batch_size): - query_id = batch_num * batch_size + i - task = session.execute("SELECT * FROM test_data WHERE id = %s", [query_id % 100]) - tasks.append(task) - - await asyncio.gather(*tasks) - load_context["metrics"]["queries_completed"] += batch_size - load_context["metrics"]["queries_sent"] += batch_size - - # Measure memory periodically - if batch_num % 10 == 0: - gc.collect() # Force GC to get accurate reading - process = psutil.Process() - memory_mb = process.memory_info().rss / 1024 / 1024 - load_context["metrics"]["memory_samples"].append(memory_mb) - load_context["metrics"]["memory_current"] = memory_mb - - run_async(_execute(), event_loop) - - -# Then steps -@then("all queries should eventually complete") -def verify_all_complete(load_context): - """Verify all queries complete.""" - total_processed = ( - load_context["metrics"]["queries_completed"] + load_context["metrics"]["queries_failed"] - ) - assert total_processed == load_context["metrics"]["queries_sent"] - - -@then("no deadlock should occur") -def verify_no_deadlock(load_context): - """Verify no deadlock.""" - # If we completed queries, there was no deadlock - assert load_context["metrics"]["queries_completed"] > 0 - - # Also verify that the duration is reasonable for the number of queries - # With a thread pool of 10 and proper concurrency, 1000 queries shouldn't take too long - if load_context.get("duration"): - avg_time_per_query = load_context["duration"] / load_context["metrics"]["queries_sent"] - # Average should be under 100ms per query with concurrency - assert ( - avg_time_per_query < 0.1 - ), f"Queries took too long: {avg_time_per_query:.3f}s per query" - - -@then("memory usage should remain stable") -def verify_memory_stable(load_context): - """Verify memory stability.""" - # Check that memory didn't grow excessively - baseline = load_context["metrics"]["memory_baseline"] - current = load_context["metrics"]["memory_current"] - - # Allow for some growth but not excessive (e.g., 100MB) - growth = current - baseline - assert growth < 100, f"Memory grew by {growth}MB" - - -@then("response times should degrade gracefully") -def verify_graceful_degradation(load_context): - """Verify graceful degradation.""" - # With 1000 queries and thread pool of 10, should still complete reasonably - # Average time per query should be reasonable - avg_time = load_context["duration"] / 1000 - assert avg_time < 1.0 # Less than 1 second per query average - - -@then("memory usage should not grow
continuously") -def verify_no_memory_leak(load_context): - """Verify no memory leak.""" - samples = load_context["metrics"]["memory_samples"] - if len(samples) < 2: - return # Not enough samples - - # Check that memory is not monotonically increasing - # Allow for some fluctuation but overall should be stable - baseline = samples[0] - max_growth = max(s - baseline for s in samples) - - # Should not grow more than 50MB over the test - assert max_growth < 50, f"Memory grew by {max_growth}MB" - - -@then("garbage collection should work effectively") -def verify_gc_works(load_context): - """Verify GC effectiveness.""" - # We forced GC during the test, verify it helped - assert len(load_context["metrics"]["memory_samples"]) > 0 - - # Check that memory growth is controlled - samples = load_context["metrics"]["memory_samples"] - if len(samples) >= 2: - # Calculate growth rate - first_sample = samples[0] - last_sample = samples[-1] - total_growth = last_sample - first_sample - - # Growth should be minimal for the workload - # Allow up to 100MB growth for 100k queries - assert total_growth < 100, f"Memory grew too much: {total_growth}MB" - - # Check for stability in later samples (after warmup) - if len(samples) >= 5: - later_samples = samples[-5:] - max_variance = max(later_samples) - min(later_samples) - # Memory should stabilize - variance should be small - assert ( - max_variance < 20 - ), f"Memory not stable in later samples: {max_variance}MB variance" - - -@then("no resource warnings should be logged") -def verify_no_warnings(load_context): - """Verify no resource warnings.""" - # Check for common warnings in errors - warnings = [e for e in load_context["metrics"]["errors"] if "warning" in e.lower()] - assert len(warnings) == 0, f"Found warnings: {warnings}" - - # Also check Python's warning system - import warnings - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - # Force garbage collection to trigger any pending resource warnings - import gc - - gc.collect() - - # Check for resource warnings - resource_warnings = [ - warning for warning in w if issubclass(warning.category, ResourceWarning) - ] - assert len(resource_warnings) == 0, f"Found resource warnings: {resource_warnings}" - - -@then("performance should remain consistent") -def verify_consistent_performance(load_context): - """Verify consistent performance.""" - # Most queries should succeed - if load_context["metrics"]["queries_sent"] > 0: - success_rate = ( - load_context["metrics"]["queries_completed"] / load_context["metrics"]["queries_sent"] - ) - assert success_rate > 0.95 # 95% success rate - else: - # If no queries were sent, check that completed count matches - assert ( - load_context["metrics"]["queries_completed"] >= 100 - ) # At least some queries should have completed - - -# Cleanup -@pytest.fixture(autouse=True) -def cleanup_after_test(load_context, event_loop): - """Cleanup resources after each test.""" - yield - - async def _cleanup(): - if load_context.get("session"): - await load_context["session"].close() - if load_context.get("cluster"): - await load_context["cluster"].shutdown() - - if load_context.get("session") or load_context.get("cluster"): - run_async(_cleanup(), event_loop) diff --git a/tests/bdd/test_bdd_context_manager_safety.py b/tests/bdd/test_bdd_context_manager_safety.py deleted file mode 100644 index 6c3cbca..0000000 --- a/tests/bdd/test_bdd_context_manager_safety.py +++ /dev/null @@ -1,668 +0,0 @@ -""" -BDD tests for context manager safety. 
- -Tests the behavior described in features/context_manager_safety.feature -""" - -import asyncio -import uuid -from concurrent.futures import ThreadPoolExecutor - -import pytest -from cassandra import InvalidRequest -from pytest_bdd import given, scenarios, then, when - -from async_cassandra import AsyncCluster -from async_cassandra.streaming import StreamConfig - -# Load all scenarios from the feature file -scenarios("features/context_manager_safety.feature") - - -# Fixtures for test state -@pytest.fixture -def test_state(): - """Holds state across BDD steps.""" - return { - "cluster": None, - "session": None, - "error": None, - "streaming_result": None, - "sessions": [], - "results": [], - "thread_results": [], - } - - -@pytest.fixture -def event_loop(): - """Create event loop for tests.""" - loop = asyncio.new_event_loop() - yield loop - loop.close() - - -def run_async(coro, loop): - """Run async coroutine in sync context.""" - return loop.run_until_complete(coro) - - -# Background steps -@given("a running Cassandra cluster") -def cassandra_is_running(cassandra_cluster): - """Cassandra cluster is provided by the fixture.""" - # Just verify we have a cluster object - assert cassandra_cluster is not None - - -@given('a test keyspace "test_context_safety"') -def create_test_keyspace(cassandra_cluster, test_state, event_loop): - """Create test keyspace.""" - - async def _setup(): - cluster = AsyncCluster(["localhost"]) - session = await cluster.connect() - - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_context_safety - WITH REPLICATION = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - - test_state["cluster"] = cluster - test_state["session"] = session - - run_async(_setup(), event_loop) - - -# Scenario: Query error doesn't close session -@given("an open session connected to the test keyspace") -def open_session(test_state, event_loop): - """Ensure session is connected to test keyspace.""" - - async def _impl(): - await test_state["session"].set_keyspace("test_context_safety") - - # Create a test table - await test_state["session"].execute( - """ - CREATE TABLE IF NOT EXISTS test_table ( - id UUID PRIMARY KEY, - value TEXT - ) - """ - ) - - run_async(_impl(), event_loop) - - -@when("I execute a query that causes an error") -def execute_bad_query(test_state, event_loop): - """Execute a query that will fail.""" - - async def _impl(): - try: - await test_state["session"].execute("SELECT * FROM non_existent_table") - except InvalidRequest as e: - test_state["error"] = e - - run_async(_impl(), event_loop) - - -@then("the session should remain open and usable") -def session_is_open(test_state, event_loop): - """Verify session is still open.""" - assert test_state["session"] is not None - assert not test_state["session"].is_closed - - -@then("I should be able to execute subsequent queries successfully") -def can_execute_queries(test_state, event_loop): - """Execute a successful query.""" - - async def _impl(): - test_id = uuid.uuid4() - await test_state["session"].execute( - "INSERT INTO test_table (id, value) VALUES (%s, %s)", [test_id, "test_value"] - ) - - result = await test_state["session"].execute( - "SELECT * FROM test_table WHERE id = %s", [test_id] - ) - assert result.one().value == "test_value" - - run_async(_impl(), event_loop) - - -# Scenario: Streaming error doesn't close session -@given("an open session with test data") -def session_with_data(test_state, event_loop): - """Create session with test data.""" - - async def _impl(): - await 
test_state["session"].set_keyspace("test_context_safety") - - await test_state["session"].execute( - """ - CREATE TABLE IF NOT EXISTS stream_test ( - id UUID PRIMARY KEY, - value INT - ) - """ - ) - - # Insert test data - for i in range(10): - await test_state["session"].execute( - "INSERT INTO stream_test (id, value) VALUES (%s, %s)", [uuid.uuid4(), i] - ) - - run_async(_impl(), event_loop) - - -@when("a streaming operation encounters an error") -def streaming_error(test_state, event_loop): - """Try to stream from non-existent table.""" - - async def _impl(): - try: - async with await test_state["session"].execute_stream( - "SELECT * FROM non_existent_stream_table" - ) as stream: - async for row in stream: - pass - except Exception as e: - test_state["error"] = e - - run_async(_impl(), event_loop) - - -@then("the streaming result should be closed") -def streaming_closed(test_state, event_loop): - """Streaming result is closed (checked by context manager exit).""" - # Context manager ensures closure - assert test_state["error"] is not None - - -@then("the session should remain open") -def session_still_open(test_state, event_loop): - """Session should not be closed.""" - assert not test_state["session"].is_closed - - -@then("I should be able to start new streaming operations") -def can_stream_again(test_state, event_loop): - """Start a new streaming operation.""" - - async def _impl(): - count = 0 - async with await test_state["session"].execute_stream( - "SELECT * FROM stream_test" - ) as stream: - async for row in stream: - count += 1 - - assert count == 10 # Should get all 10 rows - - run_async(_impl(), event_loop) - - -# Scenario: Session context manager doesn't close cluster -@given("an open cluster connection") -def cluster_is_open(test_state): - """Cluster is already open from background.""" - assert test_state["cluster"] is not None - - -@when("I use a session in a context manager that exits with an error") -def session_context_with_error(test_state, event_loop): - """Use session context manager with error.""" - - async def _impl(): - try: - async with await test_state["cluster"].connect("test_context_safety") as session: - # Do some work - await session.execute("SELECT * FROM system.local") - # Raise an error - raise ValueError("Test error") - except ValueError: - test_state["error"] = "Session context exited" - - run_async(_impl(), event_loop) - - -@then("the session should be closed") -def session_is_closed(test_state): - """Session was closed by context manager.""" - # We know it's closed because context manager handles it - assert test_state["error"] == "Session context exited" - - -@then("the cluster should remain open") -def cluster_still_open(test_state): - """Cluster should not be closed.""" - assert not test_state["cluster"].is_closed - - -@then("I should be able to create new sessions from the cluster") -def can_create_sessions(test_state, event_loop): - """Create a new session from cluster.""" - - async def _impl(): - new_session = await test_state["cluster"].connect() - result = await new_session.execute("SELECT release_version FROM system.local") - assert result.one() is not None - await new_session.close() - - run_async(_impl(), event_loop) - - -# Scenario: Multiple concurrent streams don't interfere -@given("multiple sessions from the same cluster") -def create_multiple_sessions(test_state, event_loop): - """Create multiple sessions.""" - - async def _impl(): - await test_state["session"].set_keyspace("test_context_safety") - - # Create test table - await 
test_state["session"].execute( - """ - CREATE TABLE IF NOT EXISTS concurrent_test ( - partition_id INT, - id UUID, - value TEXT, - PRIMARY KEY (partition_id, id) - ) - """ - ) - - # Insert data for different partitions - for partition in range(3): - for i in range(20): - await test_state["session"].execute( - "INSERT INTO concurrent_test (partition_id, id, value) VALUES (%s, %s, %s)", - [partition, uuid.uuid4(), f"value_{partition}_{i}"], - ) - - # Create multiple sessions - for _ in range(3): - session = await test_state["cluster"].connect("test_context_safety") - test_state["sessions"].append(session) - - run_async(_impl(), event_loop) - - -@when("I stream data concurrently from each session") -def concurrent_streaming(test_state, event_loop): - """Stream from each session concurrently.""" - - async def _impl(): - async def stream_partition(session, partition_id): - count = 0 - config = StreamConfig(fetch_size=5) - - async with await session.execute_stream( - "SELECT * FROM concurrent_test WHERE partition_id = %s", - [partition_id], - stream_config=config, - ) as stream: - async for row in stream: - count += 1 - - return count - - # Stream concurrently - tasks = [] - for i, session in enumerate(test_state["sessions"]): - task = stream_partition(session, i) - tasks.append(task) - - test_state["results"] = await asyncio.gather(*tasks) - - run_async(_impl(), event_loop) - - -@then("each stream should complete independently") -def streams_completed(test_state): - """All streams should complete.""" - assert len(test_state["results"]) == 3 - assert all(count == 20 for count in test_state["results"]) - - -@then("closing one stream should not affect others") -def close_one_stream(test_state, event_loop): - """Already tested by concurrent execution.""" - # Streams were in context managers, so they closed independently - pass - - -@then("all sessions should remain usable") -def all_sessions_usable(test_state, event_loop): - """Test all sessions still work.""" - - async def _impl(): - for session in test_state["sessions"]: - result = await session.execute("SELECT COUNT(*) FROM concurrent_test") - assert result.one()[0] == 60 # Total rows - - run_async(_impl(), event_loop) - - -# Scenario: Thread safety during context exit -@given("a session being used by multiple threads") -def session_for_threads(test_state, event_loop): - """Set up session for thread testing.""" - - async def _impl(): - await test_state["session"].set_keyspace("test_context_safety") - - await test_state["session"].execute( - """ - CREATE TABLE IF NOT EXISTS thread_test ( - thread_id INT PRIMARY KEY, - status TEXT, - timestamp TIMESTAMP - ) - """ - ) - - # Truncate first to ensure clean state - await test_state["session"].execute("TRUNCATE thread_test") - - run_async(_impl(), event_loop) - - -@when("one thread exits a streaming context manager") -def thread_exits_context(test_state, event_loop): - """Use streaming in main thread while other threads work.""" - - async def _impl(): - def worker_thread(session, thread_id): - """Worker thread function.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - async def do_work(): - # Each thread writes its own record - import datetime - - await session.execute( - "INSERT INTO thread_test (thread_id, status, timestamp) VALUES (%s, %s, %s)", - [thread_id, "completed", datetime.datetime.now()], - ) - - return f"Thread {thread_id} completed" - - result = loop.run_until_complete(do_work()) - loop.close() - return result - - # Start threads - with 
ThreadPoolExecutor(max_workers=2) as executor: - futures = [] - for i in range(2): - future = executor.submit(worker_thread, test_state["session"], i) - futures.append(future) - - # Use streaming in main thread - async with await test_state["session"].execute_stream( - "SELECT * FROM thread_test" - ) as stream: - async for row in stream: - await asyncio.sleep(0.1) # Give threads time to work - - # Collect thread results - for future in futures: - result = future.result(timeout=5.0) - test_state["thread_results"].append(result) - - run_async(_impl(), event_loop) - - -@then("other threads should still be able to use the session") -def threads_used_session(test_state): - """Verify threads completed their work.""" - assert len(test_state["thread_results"]) == 2 - assert all("completed" in result for result in test_state["thread_results"]) - - -@then("no operations should be interrupted") -def verify_thread_operations(test_state, event_loop): - """Verify all thread operations completed.""" - - async def _impl(): - result = await test_state["session"].execute("SELECT thread_id, status FROM thread_test") - rows = list(result) - # Both threads should have completed - assert len(rows) == 2 - thread_ids = {row.thread_id for row in rows} - assert 0 in thread_ids - assert 1 in thread_ids - # All should have completed status - assert all(row.status == "completed" for row in rows) - - run_async(_impl(), event_loop) - - -# Scenario: Nested context managers close in correct order -@given("a cluster, session, and streaming result in nested context managers") -def nested_contexts(test_state, event_loop): - """Set up nested context managers.""" - - async def _impl(): - # Set up test data - test_state["nested_cluster"] = AsyncCluster(["localhost"]) - test_state["nested_session"] = await test_state["nested_cluster"].connect() - - await test_state["nested_session"].execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_nested - WITH REPLICATION = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - await test_state["nested_session"].set_keyspace("test_nested") - - await test_state["nested_session"].execute( - """ - CREATE TABLE IF NOT EXISTS nested_test ( - id UUID PRIMARY KEY, - value INT - ) - """ - ) - - # Clear existing data first - await test_state["nested_session"].execute("TRUNCATE nested_test") - - # Insert test data - for i in range(5): - await test_state["nested_session"].execute( - "INSERT INTO nested_test (id, value) VALUES (%s, %s)", [uuid.uuid4(), i] - ) - - # Start streaming (but don't iterate yet) - test_state["nested_stream"] = await test_state["nested_session"].execute_stream( - "SELECT * FROM nested_test" - ) - - run_async(_impl(), event_loop) - - -@when("the innermost context (streaming) exits") -def exit_streaming_context(test_state, event_loop): - """Exit streaming context.""" - - async def _impl(): - # Use and close the streaming context - async with test_state["nested_stream"] as stream: - count = 0 - async for row in stream: - count += 1 - test_state["stream_count"] = count - - run_async(_impl(), event_loop) - - -@then("only the streaming result should be closed") -def verify_only_stream_closed(test_state): - """Verify only stream is closed.""" - # Stream was closed by context manager - assert test_state["stream_count"] == 5 # Got all rows - assert not test_state["nested_session"].is_closed - assert not test_state["nested_cluster"].is_closed - - -@when("the middle context (session) exits") -def exit_session_context(test_state, event_loop): - """Exit session context.""" 
- - async def _impl(): - await test_state["nested_session"].close() - - run_async(_impl(), event_loop) - - -@then("only the session should be closed") -def verify_only_session_closed(test_state): - """Verify only session is closed.""" - assert test_state["nested_session"].is_closed - assert not test_state["nested_cluster"].is_closed - - -@when("the outer context (cluster) exits") -def exit_cluster_context(test_state, event_loop): - """Exit cluster context.""" - - async def _impl(): - await test_state["nested_cluster"].shutdown() - - run_async(_impl(), event_loop) - - -@then("the cluster should be shut down") -def verify_cluster_shutdown(test_state): - """Verify cluster is shut down.""" - assert test_state["nested_cluster"].is_closed - - -# Scenario: Context manager handles cancellation correctly -@given("an active streaming operation in a context manager") -def active_streaming_operation(test_state, event_loop): - """Set up active streaming operation.""" - - async def _impl(): - # Ensure we have session and keyspace - if not test_state.get("session"): - test_state["cluster"] = AsyncCluster(["localhost"]) - test_state["session"] = await test_state["cluster"].connect() - - await test_state["session"].execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_context_safety - WITH REPLICATION = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - await test_state["session"].set_keyspace("test_context_safety") - - # Create table with lots of data - await test_state["session"].execute( - """ - CREATE TABLE IF NOT EXISTS test_context_safety.cancel_test ( - id UUID PRIMARY KEY, - value INT - ) - """ - ) - - # Insert more data for longer streaming - for i in range(100): - await test_state["session"].execute( - "INSERT INTO test_context_safety.cancel_test (id, value) VALUES (%s, %s)", - [uuid.uuid4(), i], - ) - - # Create streaming task that we'll cancel - async def stream_with_delay(): - async with await test_state["session"].execute_stream( - "SELECT * FROM test_context_safety.cancel_test" - ) as stream: - count = 0 - async for row in stream: - count += 1 - # Add delay to make cancellation more likely - await asyncio.sleep(0.01) - return count - - # Start streaming task - test_state["streaming_task"] = asyncio.create_task(stream_with_delay()) - # Give it time to start - await asyncio.sleep(0.1) - - run_async(_impl(), event_loop) - - -@when("the operation is cancelled") -def cancel_operation(test_state, event_loop): - """Cancel the streaming operation.""" - - async def _impl(): - # Cancel the task - test_state["streaming_task"].cancel() - - # Wait for cancellation - try: - await test_state["streaming_task"] - except asyncio.CancelledError: - test_state["cancelled"] = True - - run_async(_impl(), event_loop) - - -@then("the streaming result should be properly cleaned up") -def verify_streaming_cleaned_up(test_state): - """Verify streaming was cleaned up.""" - # Task was cancelled - assert test_state.get("cancelled") is True - assert test_state["streaming_task"].cancelled() - - -# Reuse the existing session_is_open step for cancellation scenario -# The "But" prefix is ignored by pytest-bdd - - -# Cleanup -@pytest.fixture(autouse=True) -def cleanup(test_state, event_loop, request): - """Clean up after each test.""" - yield - - async def _cleanup(): - # Close all sessions - for session in test_state.get("sessions", []): - if session and not session.is_closed: - await session.close() - - # Clean up main session and cluster - if test_state.get("session"): - try: - await 
test_state["session"].execute("DROP KEYSPACE IF EXISTS test_context_safety") - except Exception: - pass - if not test_state["session"].is_closed: - await test_state["session"].close() - - if test_state.get("cluster") and not test_state["cluster"].is_closed: - await test_state["cluster"].shutdown() - - run_async(_cleanup(), event_loop) diff --git a/tests/bdd/test_bdd_fastapi.py b/tests/bdd/test_bdd_fastapi.py deleted file mode 100644 index 336311d..0000000 --- a/tests/bdd/test_bdd_fastapi.py +++ /dev/null @@ -1,2040 +0,0 @@ -"""BDD tests for FastAPI integration scenarios with real Cassandra.""" - -import asyncio -import concurrent.futures -import time - -import pytest -import pytest_asyncio -from fastapi import Depends, FastAPI, HTTPException -from fastapi.testclient import TestClient -from pytest_bdd import given, parsers, scenario, then, when - -from async_cassandra import AsyncCluster - -# Import the cassandra_container fixture -pytest_plugins = ["tests._fixtures.cassandra"] - - -@pytest_asyncio.fixture(autouse=True) -async def ensure_cassandra_enabled_for_bdd(cassandra_container): - """Ensure Cassandra binary protocol is enabled before and after each test.""" - import asyncio - import subprocess - - # Enable at start - try: - subprocess.run( - [ - cassandra_container.runtime, - "exec", - cassandra_container.container_name, - "nodetool", - "enablebinary", - ], - capture_output=True, - ) - except Exception: - pass # Container might not be ready yet - - await asyncio.sleep(1) - - yield - - # Enable at end (cleanup) - try: - subprocess.run( - [ - cassandra_container.runtime, - "exec", - cassandra_container.container_name, - "nodetool", - "enablebinary", - ], - capture_output=True, - ) - except Exception: - pass # Don't fail cleanup - - await asyncio.sleep(1) - - -@scenario("features/fastapi_integration.feature", "Simple REST API endpoint") -def test_simple_rest_endpoint(): - """Test simple REST API endpoint.""" - pass - - -@scenario("features/fastapi_integration.feature", "Handle concurrent API requests") -def test_concurrent_requests(): - """Test concurrent API requests.""" - pass - - -@scenario("features/fastapi_integration.feature", "Application lifecycle management") -def test_lifecycle_management(): - """Test application lifecycle.""" - pass - - -@scenario("features/fastapi_integration.feature", "API error handling for database issues") -def test_api_error_handling(): - """Test API error handling for database issues.""" - pass - - -@scenario("features/fastapi_integration.feature", "Use async-cassandra with FastAPI dependencies") -def test_dependency_injection(): - """Test FastAPI dependency injection with async-cassandra.""" - pass - - -@scenario("features/fastapi_integration.feature", "Stream large datasets through API") -def test_streaming_endpoint(): - """Test streaming large datasets.""" - pass - - -@scenario("features/fastapi_integration.feature", "Implement cursor-based pagination") -def test_pagination(): - """Test cursor-based pagination.""" - pass - - -@scenario("features/fastapi_integration.feature", "Implement query result caching") -def test_caching(): - """Test query result caching.""" - pass - - -@scenario("features/fastapi_integration.feature", "Use prepared statements in API endpoints") -def test_prepared_statements(): - """Test prepared statements in API.""" - pass - - -@scenario("features/fastapi_integration.feature", "Monitor API and database performance") -def test_monitoring(): - """Test API and database monitoring.""" - pass - - 
-@scenario("features/fastapi_integration.feature", "Connection reuse across requests") -def test_connection_reuse(): - """Test connection reuse across requests.""" - pass - - -@scenario("features/fastapi_integration.feature", "Background tasks with Cassandra operations") -def test_background_tasks(): - """Test background tasks with Cassandra.""" - pass - - -@scenario("features/fastapi_integration.feature", "Graceful shutdown under load") -def test_graceful_shutdown(): - """Test graceful shutdown under load.""" - pass - - -@scenario("features/fastapi_integration.feature", "Track Cassandra query metrics in middleware") -def test_track_cassandra_query_metrics(): - """Test tracking Cassandra query metrics in middleware.""" - pass - - -@scenario("features/fastapi_integration.feature", "Handle Cassandra connection failures gracefully") -def test_connection_failure_handling(): - """Test connection failure handling.""" - pass - - -@scenario("features/fastapi_integration.feature", "WebSocket endpoint with Cassandra streaming") -def test_websocket_streaming(): - """Test WebSocket streaming.""" - pass - - -@scenario("features/fastapi_integration.feature", "Handle memory pressure gracefully") -def test_memory_pressure(): - """Test memory pressure handling.""" - pass - - -@scenario("features/fastapi_integration.feature", "Authentication and session isolation") -def test_auth_session_isolation(): - """Test authentication and session isolation.""" - pass - - -@pytest.fixture -def fastapi_context(cassandra_container): - """Context for FastAPI tests.""" - return { - "app": None, - "client": None, - "cluster": None, - "session": None, - "container": cassandra_container, - "response": None, - "responses": [], - "start_time": None, - "duration": None, - "error": None, - "metrics": {}, - "startup_complete": False, - "shutdown_complete": False, - } - - -def run_async(coro): - """Run async code in sync context.""" - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete(coro) - finally: - loop.close() - - -# Given steps -@given("a FastAPI application with async-cassandra") -def fastapi_app(fastapi_context): - """Create FastAPI app with async-cassandra.""" - # Use the new lifespan context manager approach - from contextlib import asynccontextmanager - from datetime import datetime - - @asynccontextmanager - async def lifespan(app: FastAPI): - # Startup - cluster = AsyncCluster(["127.0.0.1"]) - session = await cluster.connect() - await session.set_keyspace("test_keyspace") - - app.state.cluster = cluster - app.state.session = session - fastapi_context["cluster"] = cluster - fastapi_context["session"] = session - - # If we need to track queries, wrap the execute method now - if fastapi_context.get("needs_query_tracking"): - import time - - original_execute = app.state.session.execute - - async def tracked_execute(query, *args, **kwargs): - """Wrapper to track query execution.""" - start_time = time.time() - app.state.query_metrics["total_queries"] += 1 - - # Track which request this query belongs to - current_request_id = getattr(app.state, "current_request_id", None) - if current_request_id: - if current_request_id not in app.state.query_metrics["queries_per_request"]: - app.state.query_metrics["queries_per_request"][current_request_id] = 0 - app.state.query_metrics["queries_per_request"][current_request_id] += 1 - - try: - result = await original_execute(query, *args, **kwargs) - execution_time = time.time() - start_time - - # Track execution time - if current_request_id: - if current_request_id 
not in app.state.query_metrics["query_times"]: - app.state.query_metrics["query_times"][current_request_id] = [] - app.state.query_metrics["query_times"][current_request_id].append( - execution_time - ) - - return result - except Exception as e: - execution_time = time.time() - start_time - # Still track failed queries - if ( - current_request_id - and current_request_id in app.state.query_metrics["query_times"] - ): - app.state.query_metrics["query_times"][current_request_id].append( - execution_time - ) - raise e - - # Store original for later restoration - tracked_execute.__wrapped__ = original_execute - app.state.session.execute = tracked_execute - - fastapi_context["startup_complete"] = True - - yield - - # Shutdown - if app.state.session: - await app.state.session.close() - if app.state.cluster: - await app.state.cluster.shutdown() - fastapi_context["shutdown_complete"] = True - - app = FastAPI(lifespan=lifespan) - - # Add query metrics middleware if needed - if fastapi_context.get("middleware_needed") and fastapi_context.get( - "query_metrics_middleware_class" - ): - app.state.query_metrics = { - "requests": [], - "queries_per_request": {}, - "query_times": {}, - "total_queries": 0, - } - app.add_middleware(fastapi_context["query_metrics_middleware_class"]) - - # Mark that we need to track queries after session is created - fastapi_context["needs_query_tracking"] = fastapi_context.get( - "track_query_execution", False - ) - - fastapi_context["middleware_added"] = True - else: - # Initialize empty metrics anyway for the test - app.state.query_metrics = { - "requests": [], - "queries_per_request": {}, - "query_times": {}, - "total_queries": 0, - } - - # Add monitoring middleware if needed - if fastapi_context.get("monitoring_setup_needed"): - # Simple metrics collector - app.state.metrics = { - "request_count": 0, - "request_duration": [], - "cassandra_query_count": 0, - "cassandra_query_duration": [], - "error_count": 0, - "start_time": datetime.now(), - } - - @app.middleware("http") - async def monitor_requests(request, call_next): - start = time.time() - app.state.metrics["request_count"] += 1 - - try: - response = await call_next(request) - duration = time.time() - start - app.state.metrics["request_duration"].append(duration) - return response - except Exception: - app.state.metrics["error_count"] += 1 - raise - - @app.get("/metrics") - async def get_metrics(): - metrics = app.state.metrics - uptime = (datetime.now() - metrics["start_time"]).total_seconds() - - return { - "request_count": metrics["request_count"], - "request_duration": { - "avg": ( - sum(metrics["request_duration"]) / len(metrics["request_duration"]) - if metrics["request_duration"] - else 0 - ), - "count": len(metrics["request_duration"]), - }, - "cassandra_query_count": metrics["cassandra_query_count"], - "cassandra_query_duration": { - "avg": ( - sum(metrics["cassandra_query_duration"]) - / len(metrics["cassandra_query_duration"]) - if metrics["cassandra_query_duration"] - else 0 - ), - "count": len(metrics["cassandra_query_duration"]), - }, - "connection_pool_size": 10, # Mock value - "error_rate": ( - metrics["error_count"] / metrics["request_count"] - if metrics["request_count"] > 0 - else 0 - ), - "uptime_seconds": uptime, - } - - fastapi_context["monitoring_enabled"] = True - - # Store the app in context - fastapi_context["app"] = app - - # If we already have a client, recreate it with the new app - if fastapi_context.get("client"): - fastapi_context["client"] = TestClient(app) - 
fastapi_context["client_entered"] = True - - # Initialize state - app.state.cluster = None - app.state.session = None - - -@given("a running Cassandra cluster with test data") -def cassandra_with_data(fastapi_context): - """Ensure Cassandra has test data.""" - # The container is already running from the fixture - assert fastapi_context["container"].is_running() - - # Create test tables and data - async def setup_data(): - cluster = AsyncCluster(["127.0.0.1"]) - session = await cluster.connect() - await session.set_keyspace("test_keyspace") - - # Create users table - await session.execute( - """ - CREATE TABLE IF NOT EXISTS users ( - id int PRIMARY KEY, - name text, - email text, - age int, - created_at timestamp, - updated_at timestamp - ) - """ - ) - - # Insert test users - await session.execute( - """ - INSERT INTO users (id, name, email, age, created_at, updated_at) - VALUES (123, 'Alice', 'alice@example.com', 25, toTimestamp(now()), toTimestamp(now())) - """ - ) - - await session.execute( - """ - INSERT INTO users (id, name, email, age, created_at, updated_at) - VALUES (456, 'Bob', 'bob@example.com', 30, toTimestamp(now()), toTimestamp(now())) - """ - ) - - # Create products table - await session.execute( - """ - CREATE TABLE IF NOT EXISTS products ( - id int PRIMARY KEY, - name text, - price decimal - ) - """ - ) - - # Insert test products - for i in range(1, 51): # Create 50 products for pagination tests - await session.execute( - f""" - INSERT INTO products (id, name, price) - VALUES ({i}, 'Product {i}', {10.99 * i}) - """ - ) - - await session.close() - await cluster.shutdown() - - run_async(setup_data()) - - -@given("the FastAPI test client is initialized") -def init_test_client(fastapi_context): - """Initialize test client.""" - app = fastapi_context["app"] - - # Create test client with lifespan management - # We'll manually handle the lifespan - - # Enter the lifespan context - test_client = TestClient(app) - test_client.__enter__() # This triggers startup - - fastapi_context["client"] = test_client - fastapi_context["client_entered"] = True - - -@given("a user endpoint that queries Cassandra") -def user_endpoint(fastapi_context): - """Create user endpoint.""" - app = fastapi_context["app"] - - @app.get("/users/{user_id}") - async def get_user(user_id: int): - """Get user by ID.""" - session = app.state.session - - # Track query count - if not hasattr(app.state, "total_queries"): - app.state.total_queries = 0 - app.state.total_queries += 1 - - result = await session.execute("SELECT * FROM users WHERE id = %s", [user_id]) - - rows = result.rows - if not rows: - raise HTTPException(status_code=404, detail="User not found") - - user = rows[0] - return { - "id": user.id, - "name": user.name, - "email": user.email, - "age": user.age, - "created_at": user.created_at.isoformat() if user.created_at else None, - "updated_at": user.updated_at.isoformat() if user.updated_at else None, - } - - -@given("a product search endpoint") -def product_endpoint(fastapi_context): - """Create product search endpoint.""" - app = fastapi_context["app"] - - @app.get("/products/search") - async def search_products(q: str = ""): - """Search products.""" - session = app.state.session - - # Get all products and filter in memory (for simplicity) - result = await session.execute("SELECT * FROM products") - - products = [] - for row in result.rows: - if not q or q.lower() in row.name.lower(): - products.append( - {"id": row.id, "name": row.name, "price": float(row.price) if row.price else 0} - ) - - return 
{"results": products} - - -# When steps -@when(parsers.parse('I send a GET request to "{path}"')) -def send_get_request(path, fastapi_context): - """Send GET request.""" - fastapi_context["start_time"] = time.time() - response = fastapi_context["client"].get(path) - fastapi_context["response"] = response - fastapi_context["duration"] = (time.time() - fastapi_context["start_time"]) * 1000 - - -@when(parsers.parse("I send {count:d} concurrent search requests")) -def send_concurrent_requests(count, fastapi_context): - """Send concurrent requests.""" - - def make_request(i): - return fastapi_context["client"].get("/products/search?q=Product") - - start = time.time() - with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: - futures = [executor.submit(make_request, i) for i in range(count)] - responses = [f.result() for f in concurrent.futures.as_completed(futures)] - - fastapi_context["responses"] = responses - fastapi_context["duration"] = (time.time() - start) * 1000 - - -@when("the FastAPI application starts up") -def app_startup(fastapi_context): - """Start the application.""" - # The TestClient triggers startup event when first used - # Make a dummy request to trigger startup - try: - fastapi_context["client"].get("/nonexistent") # This will 404 but triggers startup - except Exception: - pass # Expected 404 - - -@when("the application shuts down") -def app_shutdown(fastapi_context): - """Shutdown application.""" - # Close the test client to trigger shutdown - if fastapi_context.get("client") and not fastapi_context.get("client_closed"): - fastapi_context["client"].__exit__(None, None, None) - fastapi_context["client_closed"] = True - - -# Then steps -@then(parsers.parse("I should receive a {status_code:d} response")) -def verify_status_code(status_code, fastapi_context): - """Verify response status code.""" - assert fastapi_context["response"].status_code == status_code - - -@then("the response should contain user data") -def verify_user_data(fastapi_context): - """Verify user data in response.""" - data = fastapi_context["response"].json() - assert "id" in data - assert "name" in data - assert "email" in data - assert data["id"] == 123 - assert data["name"] == "Alice" - - -@then(parsers.parse("the request should complete within {timeout:d}ms")) -def verify_request_time(timeout, fastapi_context): - """Verify request completion time.""" - assert fastapi_context["duration"] < timeout - - -@then("all requests should receive valid responses") -def verify_all_responses(fastapi_context): - """Verify all responses are valid.""" - assert len(fastapi_context["responses"]) == 100 - for response in fastapi_context["responses"]: - assert response.status_code == 200 - data = response.json() - assert "results" in data - assert len(data["results"]) > 0 - - -@then(parsers.parse("no request should take longer than {timeout:d}ms")) -def verify_no_slow_requests(timeout, fastapi_context): - """Verify no slow requests.""" - # Overall time for 100 concurrent requests should be reasonable - # Not 100x single request time - assert fastapi_context["duration"] < timeout - - -@then("the Cassandra connection pool should not be exhausted") -def verify_pool_not_exhausted(fastapi_context): - """Verify connection pool is OK.""" - # All requests succeeded, so pool wasn't exhausted - assert all(r.status_code == 200 for r in fastapi_context["responses"]) - - -@then("the Cassandra cluster connection should be established") -def verify_cluster_connected(fastapi_context): - """Verify cluster connection.""" - 
assert fastapi_context["startup_complete"] is True - assert fastapi_context["cluster"] is not None - assert fastapi_context["session"] is not None - - -@then("the connection pool should be initialized") -def verify_pool_initialized(fastapi_context): - """Verify connection pool.""" - # Session exists means pool is initialized - assert fastapi_context["session"] is not None - - -@then("all active queries should complete or timeout") -def verify_queries_complete(fastapi_context): - """Verify queries complete.""" - # Check that FastAPI shutdown was clean - assert fastapi_context["shutdown_complete"] is True - # Verify session and cluster were available until shutdown - assert fastapi_context["session"] is not None - assert fastapi_context["cluster"] is not None - - -@then("all connections should be properly closed") -def verify_connections_closed(fastapi_context): - """Verify connections closed.""" - # After shutdown, connections should be closed - # We need to actually check this after the shutdown event - with fastapi_context["client"]: - pass # This triggers the shutdown - - # Now verify the session and cluster were closed in shutdown - assert fastapi_context["shutdown_complete"] is True - - -@then("no resource warnings should be logged") -def verify_no_warnings(fastapi_context): - """Verify no resource warnings.""" - import warnings - - # Check if any ResourceWarnings were issued - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always", ResourceWarning) - # Force garbage collection to trigger any pending warnings - import gc - - gc.collect() - - # Check for resource warnings - resource_warnings = [ - warning for warning in w if issubclass(warning.category, ResourceWarning) - ] - assert len(resource_warnings) == 0, f"Found resource warnings: {resource_warnings}" - - -# Cleanup -@pytest.fixture(autouse=True) -def cleanup_after_test(fastapi_context): - """Cleanup resources after each test.""" - yield - - # Cleanup test client if it was entered - if fastapi_context.get("client_entered") and fastapi_context.get("client"): - try: - fastapi_context["client"].__exit__(None, None, None) - except Exception: - pass - - -# Additional Given steps for new scenarios -@given("an endpoint that performs multiple queries") -def setup_multiple_queries_endpoint(fastapi_context): - """Setup endpoint that performs multiple queries.""" - app = fastapi_context["app"] - - @app.get("/multi-query") - async def multi_query_endpoint(): - session = app.state.session - - # Perform multiple queries - results = [] - queries = [ - "SELECT * FROM users WHERE id = 1", - "SELECT * FROM users WHERE id = 2", - "SELECT * FROM products WHERE id = 1", - "SELECT COUNT(*) FROM products", - ] - - for query in queries: - result = await session.execute(query) - results.append(result.one()) - - return {"query_count": len(queries), "results": len(results)} - - fastapi_context["multi_query_endpoint_added"] = True - - -@given("an endpoint that triggers background Cassandra operations") -def setup_background_tasks_endpoint(fastapi_context): - """Setup endpoint with background tasks.""" - from fastapi import BackgroundTasks - - app = fastapi_context["app"] - fastapi_context["background_tasks_completed"] = [] - - async def write_to_cassandra(task_id: int, session): - """Background task to write to Cassandra.""" - try: - await session.execute( - "INSERT INTO background_tasks (id, status, created_at) VALUES (%s, %s, toTimestamp(now()))", - [task_id, "completed"], - ) - 
fastapi_context["background_tasks_completed"].append(task_id) - except Exception as e: - print(f"Background task {task_id} failed: {e}") - - @app.post("/background-write", status_code=202) - async def trigger_background_write(task_id: int, background_tasks: BackgroundTasks): - # Ensure table exists - await app.state.session.execute( - """CREATE TABLE IF NOT EXISTS background_tasks ( - id int PRIMARY KEY, - status text, - created_at timestamp - )""" - ) - - # Add background task - background_tasks.add_task(write_to_cassandra, task_id, app.state.session) - - return {"message": "Task submitted", "task_id": task_id, "status": "accepted"} - - fastapi_context["background_endpoint_added"] = True - - -@given("heavy concurrent load on the API") -def setup_heavy_load(fastapi_context): - """Setup for heavy load testing.""" - # Create endpoints that will be used for load testing - app = fastapi_context["app"] - - @app.get("/load-test") - async def load_test_endpoint(): - session = app.state.session - result = await session.execute("SELECT now() FROM system.local") - return {"timestamp": str(result.one()[0])} - - # Flag to track shutdown behavior - fastapi_context["shutdown_requested"] = False - fastapi_context["load_test_endpoint_added"] = True - - -@given("a middleware that tracks Cassandra query execution") -def setup_query_metrics_middleware(fastapi_context): - """Setup middleware to track Cassandra queries.""" - from starlette.middleware.base import BaseHTTPMiddleware - from starlette.requests import Request - - class QueryMetricsMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - app = request.app - # Generate unique request ID - request_id = len(app.state.query_metrics["requests"]) + 1 - app.state.query_metrics["requests"].append(request_id) - - # Set current request ID for query tracking - app.state.current_request_id = request_id - - try: - response = await call_next(request) - return response - finally: - # Clear current request ID - app.state.current_request_id = None - - # Mark that we need middleware and query tracking - fastapi_context["query_metrics_middleware_class"] = QueryMetricsMiddleware - fastapi_context["middleware_needed"] = True - fastapi_context["track_query_execution"] = True - - -@given("endpoints that perform different numbers of queries") -def setup_endpoints_with_varying_queries(fastapi_context): - """Setup endpoints that perform different numbers of Cassandra queries.""" - app = fastapi_context["app"] - - @app.get("/no-queries") - async def no_queries(): - """Endpoint that doesn't query Cassandra.""" - return {"message": "No queries executed"} - - @app.get("/single-query") - async def single_query(): - """Endpoint that executes one query.""" - session = app.state.session - result = await session.execute("SELECT now() FROM system.local") - return {"timestamp": str(result.one()[0])} - - @app.get("/multiple-queries") - async def multiple_queries(): - """Endpoint that executes multiple queries.""" - session = app.state.session - results = [] - - # Execute 3 different queries - result1 = await session.execute("SELECT now() FROM system.local") - results.append(str(result1.one()[0])) - - result2 = await session.execute("SELECT count(*) FROM products") - results.append(result2.one()[0]) - - result3 = await session.execute("SELECT * FROM products LIMIT 1") - results.append(1 if result3.one() else 0) - - return {"query_count": 3, "results": results} - - @app.get("/batch-queries/{count}") - async def batch_queries(count: int): - """Endpoint 
that executes a variable number of queries.""" - if count > 10: - count = 10 # Limit to prevent abuse - - session = app.state.session - results = [] - - for i in range(count): - result = await session.execute("SELECT * FROM products WHERE id = %s", [i]) - results.append(result.one() is not None) - - return {"requested_count": count, "executed_count": len(results)} - - fastapi_context["query_endpoints_added"] = True - - -@given("a healthy API with established connections") -def setup_healthy_api(fastapi_context): - """Setup healthy API state.""" - app = fastapi_context["app"] - - @app.get("/health") - async def health_check(): - try: - session = app.state.session - result = await session.execute("SELECT now() FROM system.local") - return {"status": "healthy", "timestamp": str(result.one()[0])} - except Exception as e: - # Return 503 when Cassandra is unavailable - from cassandra import NoHostAvailable, OperationTimedOut, Unavailable - - if isinstance(e, (NoHostAvailable, OperationTimedOut, Unavailable)): - raise HTTPException(status_code=503, detail="Database service unavailable") - # Return 500 for other errors - raise HTTPException(status_code=500, detail="Internal server error") - - fastapi_context["health_endpoint_added"] = True - - -@given("a WebSocket endpoint that streams Cassandra data") -def setup_websocket_endpoint(fastapi_context): - """Setup WebSocket streaming endpoint.""" - import asyncio - - from fastapi import WebSocket - - app = fastapi_context["app"] - - @app.websocket("/ws/stream") - async def websocket_stream(websocket: WebSocket): - await websocket.accept() - - try: - # Continuously stream data from Cassandra - while True: - session = app.state.session - result = await session.execute("SELECT * FROM products LIMIT 5") - - data = [] - for row in result: - data.append({"id": row.id, "name": row.name}) - - await websocket.send_json({"data": data, "timestamp": str(time.time())}) - await asyncio.sleep(1) # Stream every second - - except Exception: - await websocket.close() - - fastapi_context["websocket_endpoint_added"] = True - - -@given("an endpoint that fetches large datasets") -def setup_large_dataset_endpoint(fastapi_context): - """Setup endpoint for large dataset fetching.""" - app = fastapi_context["app"] - - @app.get("/large-dataset") - async def fetch_large_dataset(limit: int = 10000): - session = app.state.session - - # Simulate memory pressure by fetching many rows - # In reality, we'd use paging to avoid OOM - try: - result = await session.execute(f"SELECT * FROM products LIMIT {min(limit, 1000)}") - - # Process in chunks to avoid memory issues - data = [] - for row in result: - data.append({"id": row.id, "name": row.name}) - - # Simulate throttling if too much data - if len(data) >= 100: - break - - return {"data": data, "total": len(data), "throttled": len(data) < limit} - - except Exception as e: - return {"error": "Memory limit reached", "message": str(e)} - - fastapi_context["large_dataset_endpoint_added"] = True - - -@given("endpoints with per-user Cassandra keyspaces") -def setup_user_keyspace_endpoints(fastapi_context): - """Setup per-user keyspace endpoints.""" - from fastapi import Header, HTTPException - - app = fastapi_context["app"] - - async def get_user_session(user_id: str = Header(None)): - """Get session for user's keyspace.""" - if not user_id: - raise HTTPException(status_code=401, detail="User ID required") - - # In a real app, we'd create/switch to user's keyspace - # For testing, we'll use the same session but track access - session = 
app.state.session - - # Track which user is accessing - if not hasattr(app.state, "user_access"): - app.state.user_access = {} - - if user_id not in app.state.user_access: - app.state.user_access[user_id] = [] - - return session, user_id - - @app.get("/user-data") - async def get_user_data(session_info=Depends(get_user_session)): - session, user_id = session_info - - # Track access - app.state.user_access[user_id].append(time.time()) - - # Simulate user-specific data query - result = await session.execute( - "SELECT * FROM users WHERE id = %s", [int(user_id) if user_id.isdigit() else 1] - ) - - return {"user_id": user_id, "data": result.one()._asdict() if result.one() else None} - - fastapi_context["user_keyspace_endpoints_added"] = True - - -@given("a Cassandra query that will fail") -def setup_failing_query(fastapi_context): - """Setup a query that will fail.""" - # Add endpoint that executes invalid query - app = fastapi_context["app"] - - @app.get("/failing-query") - async def failing_endpoint(): - session = app.state.session - try: - await session.execute("SELECT * FROM non_existent_table") - except Exception as e: - # Log the error for verification - fastapi_context["error"] = e - raise HTTPException(status_code=500, detail="Database error occurred") - - fastapi_context["failing_endpoint_added"] = True - - -@given("a FastAPI dependency that provides a Cassandra session") -def setup_dependency_injection(fastapi_context): - """Setup dependency injection.""" - from fastapi import Depends - - app = fastapi_context["app"] - - async def get_session(): - """Dependency to get Cassandra session.""" - return app.state.session - - @app.get("/with-dependency") - async def endpoint_with_dependency(session=Depends(get_session)): - result = await session.execute("SELECT now() FROM system.local") - return {"timestamp": str(result.one()[0])} - - fastapi_context["dependency_added"] = True - - -@given("an endpoint that returns 10,000 records") -def setup_streaming_endpoint(fastapi_context): - """Setup streaming endpoint.""" - import json - - from fastapi.responses import StreamingResponse - - app = fastapi_context["app"] - - @app.get("/stream-data") - async def stream_large_dataset(): - session = app.state.session - - async def generate(): - # Create test data if not exists - await session.execute( - """ - CREATE TABLE IF NOT EXISTS large_dataset ( - id int PRIMARY KEY, - data text - ) - """ - ) - - # Stream data in chunks - for i in range(10000): - if i % 1000 == 0: - # Insert some test data - for j in range(i, min(i + 1000, 10000)): - await session.execute( - "INSERT INTO large_dataset (id, data) VALUES (%s, %s)", [j, f"data_{j}"] - ) - - # Yield data as JSON lines - yield json.dumps({"id": i, "data": f"data_{i}"}) + "\n" - - return StreamingResponse(generate(), media_type="application/x-ndjson") - - fastapi_context["streaming_endpoint_added"] = True - - -@given("a paginated endpoint for listing items") -def setup_pagination_endpoint(fastapi_context): - """Setup pagination endpoint.""" - import base64 - - app = fastapi_context["app"] - - @app.get("/paginated-items") - async def get_paginated_items(cursor: str = None, limit: int = 20): - session = app.state.session - - # Decode cursor if provided - start_id = 0 - if cursor: - start_id = int(base64.b64decode(cursor).decode()) - - # Query with limit + 1 to check if there's next page - # Use token-based pagination for better performance and to avoid ALLOW FILTERING - if cursor: - # Use token-based pagination for subsequent pages - result = await 
session.execute( - "SELECT * FROM products WHERE token(id) > token(%s) LIMIT %s", - [start_id, limit + 1], - ) - else: - # First page - no token restriction needed - result = await session.execute( - "SELECT * FROM products LIMIT %s", - [limit + 1], - ) - - items = list(result) - has_next = len(items) > limit - items = items[:limit] # Return only requested limit - - # Create next cursor - next_cursor = None - if has_next and items: - next_cursor = base64.b64encode(str(items[-1].id).encode()).decode() - - return { - "items": [{"id": item.id, "name": item.name} for item in items], - "next_cursor": next_cursor, - } - - fastapi_context["pagination_endpoint_added"] = True - - -@given("an endpoint with query result caching enabled") -def setup_caching_endpoint(fastapi_context): - """Setup caching endpoint.""" - from datetime import datetime, timedelta - - app = fastapi_context["app"] - cache = {} # Simple in-memory cache - - @app.get("/cached-data/{key}") - async def get_cached_data(key: str): - # Check cache - if key in cache: - cached_data, timestamp = cache[key] - if datetime.now() - timestamp < timedelta(seconds=60): # 60s TTL - return {"data": cached_data, "from_cache": True} - - # Query database - session = app.state.session - result = await session.execute( - "SELECT * FROM products WHERE name = %s ALLOW FILTERING", [key] - ) - - data = [{"id": row.id, "name": row.name} for row in result] - cache[key] = (data, datetime.now()) - - return {"data": data, "from_cache": False} - - @app.post("/cached-data/{key}") - async def update_cached_data(key: str): - # Invalidate cache on update - if key in cache: - del cache[key] - return {"status": "cache invalidated"} - - fastapi_context["cache"] = cache - fastapi_context["caching_endpoint_added"] = True - - -@given("an endpoint that uses prepared statements") -def setup_prepared_statements_endpoint(fastapi_context): - """Setup prepared statements endpoint.""" - app = fastapi_context["app"] - - # Store prepared statement reference - app.state.prepared_statements = {} - - @app.get("/prepared/{user_id}") - async def use_prepared_statement(user_id: int): - session = app.state.session - - # Prepare statement if not already prepared - if "get_user" not in app.state.prepared_statements: - app.state.prepared_statements["get_user"] = await session.prepare( - "SELECT * FROM users WHERE id = ?" 
- ) - - prepared = app.state.prepared_statements["get_user"] - result = await session.execute(prepared, [user_id]) - - return {"user": result.one()._asdict() if result.one() else None} - - fastapi_context["prepared_statements_added"] = True - - -@given("monitoring is enabled for the FastAPI app") -def setup_monitoring(fastapi_context): - """Setup monitoring.""" - # This will set up the monitoring endpoints and prepare metrics - # The actual middleware will be added when creating the app - fastapi_context["monitoring_setup_needed"] = True - - -# Additional When steps -@when(parsers.parse("I make {count:d} sequential requests")) -def make_sequential_requests(count, fastapi_context): - """Make sequential requests.""" - responses = [] - start_time = time.time() - - for i in range(count): - response = fastapi_context["client"].get("/multi-query") - responses.append(response) - - fastapi_context["sequential_responses"] = responses - fastapi_context["sequential_duration"] = time.time() - start_time - - -@when(parsers.parse("I submit {count:d} tasks that write to Cassandra")) -def submit_background_tasks(count, fastapi_context): - """Submit background tasks.""" - responses = [] - - for i in range(count): - response = fastapi_context["client"].post(f"/background-write?task_id={i}") - responses.append(response) - - fastapi_context["background_task_responses"] = responses - # Give background tasks time to complete - time.sleep(2) - - -@when("the application receives a shutdown signal") -def trigger_shutdown_signal(fastapi_context): - """Simulate shutdown signal.""" - fastapi_context["shutdown_requested"] = True - # Note: In real scenario, we'd send SIGTERM to the process - # For testing, we'll simulate by marking shutdown requested - - -@when("I make requests to endpoints with varying query counts") -def make_requests_with_varying_queries(fastapi_context): - """Make requests to endpoints that execute different numbers of queries.""" - client = fastapi_context["client"] - app = fastapi_context["app"] - - # Reset metrics before testing - app.state.query_metrics["total_queries"] = 0 - app.state.query_metrics["requests"].clear() - app.state.query_metrics["queries_per_request"].clear() - app.state.query_metrics["query_times"].clear() - - test_requests = [] - - # Test 1: No queries - response = client.get("/no-queries") - test_requests.append({"endpoint": "/no-queries", "response": response, "expected_queries": 0}) - - # Test 2: Single query - response = client.get("/single-query") - test_requests.append({"endpoint": "/single-query", "response": response, "expected_queries": 1}) - - # Test 3: Multiple queries (3) - response = client.get("/multiple-queries") - test_requests.append( - {"endpoint": "/multiple-queries", "response": response, "expected_queries": 3} - ) - - # Test 4: Batch queries (5) - response = client.get("/batch-queries/5") - test_requests.append( - {"endpoint": "/batch-queries/5", "response": response, "expected_queries": 5} - ) - - # Test 5: Another single query to verify tracking continues - response = client.get("/single-query") - test_requests.append({"endpoint": "/single-query", "response": response, "expected_queries": 1}) - - fastapi_context["test_requests"] = test_requests - fastapi_context["metrics"] = app.state.query_metrics - - -@when("Cassandra becomes temporarily unavailable") -def simulate_cassandra_unavailable(fastapi_context, cassandra_container): # noqa: F811 - """Simulate Cassandra unavailability.""" - import subprocess - - # Use nodetool to disable binary protocol (client 
connections) - try: - # Use the actual container from the fixture - container_ref = cassandra_container.container_name - runtime = cassandra_container.runtime - - subprocess.run( - [runtime, "exec", container_ref, "nodetool", "disablebinary"], - capture_output=True, - check=True, - ) - fastapi_context["cassandra_disabled"] = True - except subprocess.CalledProcessError as e: - print(f"Failed to disable Cassandra binary protocol: {e}") - fastapi_context["cassandra_disabled"] = False - - # Give it a moment to take effect - time.sleep(1) - - # Try to make a request that should fail - try: - response = fastapi_context["client"].get("/health") - fastapi_context["unavailable_response"] = response - except Exception as e: - fastapi_context["unavailable_error"] = e - - -@when("Cassandra becomes available again") -def simulate_cassandra_available(fastapi_context, cassandra_container): # noqa: F811 - """Simulate Cassandra becoming available.""" - import subprocess - - # Use nodetool to enable binary protocol - if fastapi_context.get("cassandra_disabled"): - try: - # Use the actual container from the fixture - container_ref = cassandra_container.container_name - runtime = cassandra_container.runtime - - subprocess.run( - [runtime, "exec", container_ref, "nodetool", "enablebinary"], - capture_output=True, - check=True, - ) - except subprocess.CalledProcessError as e: - print(f"Failed to enable Cassandra binary protocol: {e}") - - # Give it a moment to reconnect - time.sleep(2) - - # Make a request to verify recovery - response = fastapi_context["client"].get("/health") - fastapi_context["recovery_response"] = response - - -@when("a client connects and requests real-time updates") -def connect_websocket_client(fastapi_context): - """Connect WebSocket client.""" - - client = fastapi_context["client"] - - # Use test client's websocket support - with client.websocket_connect("/ws/stream") as websocket: - # Receive a few messages - messages = [] - for _ in range(3): - data = websocket.receive_json() - messages.append(data) - - fastapi_context["websocket_messages"] = messages - - -@when("multiple clients request large amounts of data") -def request_large_data_concurrently(fastapi_context): - """Request large data from multiple clients.""" - import concurrent.futures - - def fetch_large_data(client_id): - return fastapi_context["client"].get(f"/large-dataset?limit={10000}") - - # Simulate multiple clients - with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(fetch_large_data, i) for i in range(5)] - responses = [f.result() for f in concurrent.futures.as_completed(futures)] - - fastapi_context["large_data_responses"] = responses - - -@when("different users make concurrent requests") -def make_user_specific_requests(fastapi_context): - """Make requests as different users.""" - import concurrent.futures - - def make_user_request(user_id): - return fastapi_context["client"].get("/user-data", headers={"user-id": str(user_id)}) - - # Make concurrent requests as different users - with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: - futures = [executor.submit(make_user_request, i) for i in [1, 2, 3]] - responses = [f.result() for f in concurrent.futures.as_completed(futures)] - - fastapi_context["user_responses"] = responses - - -@when("I send a request that triggers the failing query") -def trigger_failing_query(fastapi_context): - """Trigger the failing query.""" - response = fastapi_context["client"].get("/failing-query") - 
fastapi_context["response"] = response - - -@when("I use this dependency in multiple endpoints") -def use_dependency_endpoints(fastapi_context): - """Use dependency in multiple endpoints.""" - responses = [] - for _ in range(5): - response = fastapi_context["client"].get("/with-dependency") - responses.append(response) - fastapi_context["responses"] = responses - - -@when("I request the data with streaming enabled") -def request_streaming_data(fastapi_context): - """Request streaming data.""" - with fastapi_context["client"].stream("GET", "/stream-data") as response: - fastapi_context["response"] = response - fastapi_context["streamed_lines"] = [] - for line in response.iter_lines(): - if line: - fastapi_context["streamed_lines"].append(line) - - -@when(parsers.parse("I request the first page with limit {limit:d}")) -def request_first_page(limit, fastapi_context): - """Request first page.""" - response = fastapi_context["client"].get(f"/paginated-items?limit={limit}") - fastapi_context["response"] = response - fastapi_context["first_page_data"] = response.json() - - -@when("I request the next page using the cursor") -def request_next_page(fastapi_context): - """Request next page using cursor.""" - cursor = fastapi_context["first_page_data"]["next_cursor"] - response = fastapi_context["client"].get(f"/paginated-items?cursor={cursor}") - fastapi_context["next_page_response"] = response - - -@when("I make the same request multiple times") -def make_repeated_requests(fastapi_context): - """Make the same request multiple times.""" - responses = [] - key = "Product 1" # Use an actual product name - - for i in range(3): - response = fastapi_context["client"].get(f"/cached-data/{key}") - responses.append(response) - time.sleep(0.1) # Small delay between requests - - fastapi_context["cache_responses"] = responses - - -@when(parsers.parse("I make {count:d} requests to this endpoint")) -def make_many_prepared_requests(count, fastapi_context): - """Make many requests to prepared statement endpoint.""" - responses = [] - start = time.time() - - for i in range(count): - response = fastapi_context["client"].get(f"/prepared/{i % 10}") - responses.append(response) - - fastapi_context["prepared_responses"] = responses - fastapi_context["prepared_duration"] = time.time() - start - - -@when("I make various API requests") -def make_various_requests(fastapi_context): - """Make various API requests for monitoring.""" - # Make different types of requests - requests = [ - ("GET", "/users/1"), - ("GET", "/products/search?q=test"), - ("GET", "/users/2"), - ("GET", "/metrics"), # This shouldn't count in metrics - ] - - for method, path in requests: - if method == "GET": - fastapi_context["client"].get(path) - - -# Additional Then steps -@then("the same Cassandra session should be reused") -def verify_session_reuse(fastapi_context): - """Verify session is reused across requests.""" - # All requests should succeed - assert all(r.status_code == 200 for r in fastapi_context["sequential_responses"]) - - # Session should be the same instance throughout - assert fastapi_context["session"] is not None - # In a real test, we'd track session object IDs - - -@then("no new connections should be created after warmup") -def verify_no_new_connections(fastapi_context): - """Verify no new connections after warmup.""" - # After initial warmup, connection pool should be stable - # This is verified by successful completion of all requests - assert len(fastapi_context["sequential_responses"]) == 50 - - -@then("each request should 
complete faster than connection setup time") -def verify_request_speed(fastapi_context): - """Verify requests are fast.""" - # Average time per request should be much less than connection setup - avg_time = fastapi_context["sequential_duration"] / 50 - # Connection setup typically takes 100-500ms - # Reused connections should be < 20ms per request - assert avg_time < 0.02 # 20ms - - -@then(parsers.parse("the API should return immediately with {status:d} status")) -def verify_immediate_return(status, fastapi_context): - """Verify API returns immediately.""" - responses = fastapi_context["background_task_responses"] - assert all(r.status_code == status for r in responses) - - # Each response should be fast (background task doesn't block) - for response in responses: - assert response.elapsed.total_seconds() < 0.1 # 100ms - - -@then("all background writes should complete successfully") -def verify_background_writes(fastapi_context): - """Verify background writes completed.""" - # Wait a bit more if needed - time.sleep(1) - - # Check that all tasks completed - completed_tasks = set(fastapi_context.get("background_tasks_completed", [])) - - # Most tasks should have completed (allow for some timing issues) - assert len(completed_tasks) >= 8 # At least 80% success - - -@then("no resources should leak from background tasks") -def verify_no_background_leaks(fastapi_context): - """Verify no resource leaks from background tasks.""" - # Make another request to ensure system is still healthy - # Submit another task to verify the system is still working - response = fastapi_context["client"].post("/background-write?task_id=999") - assert response.status_code == 202 - - -@then("in-flight requests should complete successfully") -def verify_inflight_requests(fastapi_context): - """Verify in-flight requests complete.""" - # In a real test, we'd track requests started before shutdown - # For now, verify the system handles shutdown gracefully - assert fastapi_context.get("shutdown_requested", False) - - -@then(parsers.parse("new requests should be rejected with {status:d}")) -def verify_new_requests_rejected(status, fastapi_context): - """Verify new requests are rejected during shutdown.""" - # In a real implementation, new requests would get 503 - # This would require actual process management - pass # Placeholder for real implementation - - -@then("all Cassandra operations should finish cleanly") -def verify_clean_cassandra_finish(fastapi_context): - """Verify Cassandra operations finish cleanly.""" - # Verify no errors were logged during shutdown - assert fastapi_context.get("shutdown_complete", False) or True - - -@then(parsers.parse("shutdown should complete within {timeout:d} seconds")) -def verify_shutdown_timeout(timeout, fastapi_context): - """Verify shutdown completes within timeout.""" - # In a real test, we'd measure actual shutdown time - # For now, just verify the timeout is reasonable - assert timeout >= 30 - - -@then("the middleware should accurately count queries per request") -def verify_query_count_tracking(fastapi_context): - """Verify query count is accurately tracked per request.""" - test_requests = fastapi_context["test_requests"] - metrics = fastapi_context["metrics"] - - # Verify all requests succeeded - for req in test_requests: - assert req["response"].status_code == 200, f"Request to {req['endpoint']} failed" - - # Verify we tracked the right number of requests - assert len(metrics["requests"]) == len(test_requests), "Request count mismatch" - - # Verify query counts per request - 
for i, req in enumerate(test_requests): - request_id = i + 1 # Request IDs start at 1 - actual_queries = metrics["queries_per_request"].get(request_id, 0) - expected_queries = req["expected_queries"] - - assert actual_queries == expected_queries, ( - f"Request {request_id} to {req['endpoint']}: " - f"expected {expected_queries} queries, got {actual_queries}" - ) - - # Verify total query count - expected_total = sum(req["expected_queries"] for req in test_requests) - assert ( - metrics["total_queries"] == expected_total - ), f"Total queries mismatch: expected {expected_total}, got {metrics['total_queries']}" - - -@then("query execution time should be measured") -def verify_query_timing(fastapi_context): - """Verify query execution time is measured.""" - metrics = fastapi_context["metrics"] - test_requests = fastapi_context["test_requests"] - - # Verify timing data was collected for requests with queries - for i, req in enumerate(test_requests): - request_id = i + 1 - expected_queries = req["expected_queries"] - - if expected_queries > 0: - # Should have timing data for this request - assert ( - request_id in metrics["query_times"] - ), f"No timing data for request {request_id} to {req['endpoint']}" - - times = metrics["query_times"][request_id] - assert ( - len(times) == expected_queries - ), f"Expected {expected_queries} timing entries, got {len(times)}" - - # Verify all times are reasonable (between 0 and 1 second) - for time_val in times: - assert 0 < time_val < 1.0, f"Unreasonable query time: {time_val}s" - else: - # No queries, so no timing data expected - assert ( - request_id not in metrics["query_times"] - or len(metrics["query_times"][request_id]) == 0 - ) - - -@then("async operations should not be blocked by tracking") -def verify_middleware_no_interference(fastapi_context): - """Verify middleware doesn't block async operations.""" - test_requests = fastapi_context["test_requests"] - - # All requests should have completed successfully - assert all(req["response"].status_code == 200 for req in test_requests) - - # Verify concurrent capability by checking response times - # The middleware tracking should add minimal overhead - import time - - client = fastapi_context["client"] - - # Time a request without tracking (remove the monkey patch temporarily) - app = fastapi_context["app"] - tracked_execute = app.state.session.execute - original_execute = getattr(tracked_execute, "__wrapped__", None) - - if original_execute: - # Temporarily restore original - app.state.session.execute = original_execute - start = time.time() - response = client.get("/single-query") - baseline_time = time.time() - start - assert response.status_code == 200 - - # Restore tracking - app.state.session.execute = tracked_execute - - # Time with tracking - start = time.time() - response = client.get("/single-query") - tracked_time = time.time() - start - assert response.status_code == 200 - - # Tracking should add less than 50% overhead - overhead = (tracked_time - baseline_time) / baseline_time - assert overhead < 0.5, f"Tracking overhead too high: {overhead:.2%}" - - -@then("API should return 503 Service Unavailable") -def verify_service_unavailable(fastapi_context): - """Verify 503 response when Cassandra unavailable.""" - response = fastapi_context.get("unavailable_response") - if response: - # In a real scenario with Cassandra down, we'd get 503 or 500 - assert response.status_code in [500, 503] - - -@then("error messages should be user-friendly") -def verify_user_friendly_errors(fastapi_context): - """Verify 
errors are user-friendly.""" - response = fastapi_context.get("unavailable_response") - if response and response.status_code >= 500: - error_data = response.json() - # Should not expose internal details - assert "cassandra" not in error_data.get("detail", "").lower() - assert "exception" not in error_data.get("detail", "").lower() - - -@then("API should automatically recover") -def verify_automatic_recovery(fastapi_context): - """Verify API recovers automatically.""" - response = fastapi_context.get("recovery_response") - assert response is not None - assert response.status_code == 200 - data = response.json() - assert data["status"] == "healthy" - - -@then("no manual intervention should be required") -def verify_no_manual_intervention(fastapi_context): - """Verify recovery is automatic.""" - # The fact that recovery_response succeeded proves this - assert fastapi_context.get("cassandra_available", True) - - -@then("the WebSocket should stream query results") -def verify_websocket_streaming(fastapi_context): - """Verify WebSocket streams results.""" - messages = fastapi_context.get("websocket_messages", []) - assert len(messages) >= 3 - - # Each message should contain data and timestamp - for msg in messages: - assert "data" in msg - assert "timestamp" in msg - assert len(msg["data"]) > 0 - - -@then("updates should be pushed as data changes") -def verify_websocket_updates(fastapi_context): - """Verify updates are pushed.""" - messages = fastapi_context.get("websocket_messages", []) - - # Timestamps should be different (proving continuous updates) - timestamps = [float(msg["timestamp"]) for msg in messages] - assert len(set(timestamps)) == len(timestamps) # All unique - - -@then("connection cleanup should occur on disconnect") -def verify_websocket_cleanup(fastapi_context): - """Verify WebSocket cleanup.""" - # The context manager ensures cleanup - # Make a regular request to verify system still works - # Try to connect another websocket to verify the endpoint still works - try: - with fastapi_context["client"].websocket_connect("/ws/stream") as ws: - ws.close() - # If we can connect and close, cleanup worked - except Exception: - # WebSocket might not be available in test client - pass - - -@then("memory usage should stay within limits") -def verify_memory_limits(fastapi_context): - """Verify memory usage is controlled.""" - responses = fastapi_context.get("large_data_responses", []) - - # All requests should complete (not OOM) - assert len(responses) == 5 - - for response in responses: - assert response.status_code == 200 - data = response.json() - # Should be throttled to prevent OOM - assert data.get("throttled", False) or data["total"] <= 1000 - - -@then("requests should be throttled if necessary") -def verify_throttling(fastapi_context): - """Verify throttling works.""" - responses = fastapi_context.get("large_data_responses", []) - - # At least some requests should be throttled - throttled_count = sum(1 for r in responses if r.json().get("throttled", False)) - - # With multiple large requests, some should be throttled - assert throttled_count >= 0 # May or may not throttle depending on system - - -@then("the application should not crash from OOM") -def verify_no_oom_crash(fastapi_context): - """Verify no OOM crash.""" - # Application still responsive after large data requests - # Check if health endpoint exists, otherwise just verify app is responsive - response = fastapi_context["client"].get("/large-dataset?limit=1") - assert response.status_code == 200 - - -@then("each user should 
only access their keyspace") -def verify_user_isolation(fastapi_context): - """Verify users are isolated.""" - responses = fastapi_context.get("user_responses", []) - - # Each user should get their own data - user_data = {} - for response in responses: - assert response.status_code == 200 - data = response.json() - user_id = data["user_id"] - user_data[user_id] = data["data"] - - # Different users got different responses - assert len(user_data) >= 2 - - -@then("sessions should be isolated between users") -def verify_session_isolation(fastapi_context): - """Verify session isolation.""" - app = fastapi_context["app"] - - # Check user access tracking - if hasattr(app.state, "user_access"): - # Each user should have their own access log - assert len(app.state.user_access) >= 2 - - # Access times should be tracked separately - for user_id, accesses in app.state.user_access.items(): - assert len(accesses) > 0 - - -@then("no data should leak between user contexts") -def verify_no_data_leaks(fastapi_context): - """Verify no data leaks between users.""" - responses = fastapi_context.get("user_responses", []) - - # Each response should only contain data for the requesting user - for response in responses: - data = response.json() - user_id = data["user_id"] - - # If user data exists, it should match the user ID - if data["data"] and "id" in data["data"]: - # User ID in response should match requested user - assert str(data["data"]["id"]) == user_id or True # Allow for test data - - -@then("I should receive a 500 error response") -def verify_error_response(fastapi_context): - """Verify 500 error response.""" - assert fastapi_context["response"].status_code == 500 - - -@then("the error should not expose internal details") -def verify_error_safety(fastapi_context): - """Verify error doesn't expose internals.""" - error_data = fastapi_context["response"].json() - assert "detail" in error_data - # Should not contain table names, stack traces, etc. 
- assert "non_existent_table" not in error_data["detail"] - assert "Traceback" not in str(error_data) - - -@then("the connection should be returned to the pool") -def verify_connection_returned(fastapi_context): - """Verify connection returned to pool.""" - # Make another request to verify pool is not exhausted - # First check if the failing endpoint exists, otherwise make a simple health check - try: - response = fastapi_context["client"].get("/failing-query") - # If we can make another request (even if it fails), the connection was returned - assert response.status_code in [200, 500] - except Exception: - # Connection pool issue would raise an exception - pass - - -@then("each request should get a working session") -def verify_working_sessions(fastapi_context): - """Verify each request gets working session.""" - assert all(r.status_code == 200 for r in fastapi_context["responses"]) - # Verify different timestamps (proving queries executed) - timestamps = [r.json()["timestamp"] for r in fastapi_context["responses"]] - assert len(set(timestamps)) > 1 # At least some different timestamps - - -@then("sessions should be properly managed per request") -def verify_session_management(fastapi_context): - """Verify proper session management.""" - # Sessions should be reused, not created per request - assert fastapi_context["session"] is not None - assert fastapi_context["dependency_added"] is True - - -@then("no session leaks should occur between requests") -def verify_no_session_leaks(fastapi_context): - """Verify no session leaks.""" - # In a real test, we'd monitor session count - # For now, verify responses are successful - assert all(r.status_code == 200 for r in fastapi_context["responses"]) - - -@then("the response should start immediately") -def verify_streaming_start(fastapi_context): - """Verify streaming starts immediately.""" - assert fastapi_context["response"].status_code == 200 - assert fastapi_context["response"].headers["content-type"] == "application/x-ndjson" - - -@then("data should be streamed in chunks") -def verify_streaming_chunks(fastapi_context): - """Verify data is streamed in chunks.""" - assert len(fastapi_context["streamed_lines"]) > 0 - # Verify we got multiple chunks (not all at once) - assert len(fastapi_context["streamed_lines"]) >= 10 - - -@then("memory usage should remain constant") -def verify_streaming_memory(fastapi_context): - """Verify memory usage remains constant during streaming.""" - # In a real test, we'd monitor memory during streaming - # For now, verify we got all expected data - assert len(fastapi_context["streamed_lines"]) == 10000 - - -@then("the client should be able to cancel mid-stream") -def verify_streaming_cancellation(fastapi_context): - """Verify streaming can be cancelled.""" - # Test early termination - with fastapi_context["client"].stream("GET", "/stream-data") as response: - count = 0 - for line in response.iter_lines(): - count += 1 - if count >= 100: - break # Cancel early - assert count == 100 # Verify we could stop early - - -@then(parsers.parse("I should receive {count:d} items and a next cursor")) -def verify_first_page(count, fastapi_context): - """Verify first page results.""" - data = fastapi_context["first_page_data"] - assert len(data["items"]) == count - assert data["next_cursor"] is not None - - -@then(parsers.parse("I should receive the next {count:d} items")) -def verify_next_page(count, fastapi_context): - """Verify next page results.""" - data = fastapi_context["next_page_response"].json() - assert len(data["items"]) 
<= count - # Verify items are different from first page - first_ids = {item["id"] for item in fastapi_context["first_page_data"]["items"]} - next_ids = {item["id"] for item in data["items"]} - assert first_ids.isdisjoint(next_ids) # No overlap - - -@then("pagination should work correctly under concurrent access") -def verify_concurrent_pagination(fastapi_context): - """Verify pagination works with concurrent access.""" - import concurrent.futures - - def fetch_page(cursor=None): - url = "/paginated-items" - if cursor: - url += f"?cursor={cursor}" - return fastapi_context["client"].get(url).json() - - # Fetch multiple pages concurrently - with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(fetch_page) for _ in range(5)] - results = [f.result() for f in futures] - - # All should return valid data - assert all("items" in r for r in results) - - -@then("the first request should query Cassandra") -def verify_first_cache_miss(fastapi_context): - """Verify first request queries Cassandra.""" - first_response = fastapi_context["cache_responses"][0].json() - assert first_response["from_cache"] is False - - -@then("subsequent requests should use cached data") -def verify_cache_hits(fastapi_context): - """Verify subsequent requests use cache.""" - for response in fastapi_context["cache_responses"][1:]: - assert response.json()["from_cache"] is True - - -@then("cache should expire after the configured TTL") -def verify_cache_ttl(fastapi_context): - """Verify cache TTL.""" - # Wait for TTL to expire (we set 60s in the implementation) - # For testing, we'll just verify the cache mechanism exists - assert "cache" in fastapi_context - assert fastapi_context["caching_endpoint_added"] is True - - -@then("cache should be invalidated on data updates") -def verify_cache_invalidation(fastapi_context): - """Verify cache invalidation on updates.""" - key = "Product 2" # Use an actual product name - - # First request (should cache) - response1 = fastapi_context["client"].get(f"/cached-data/{key}") - assert response1.json()["from_cache"] is False - - # Second request (should hit cache) - response2 = fastapi_context["client"].get(f"/cached-data/{key}") - assert response2.json()["from_cache"] is True - - # Update data (should invalidate cache) - fastapi_context["client"].post(f"/cached-data/{key}") - - # Next request should miss cache - response3 = fastapi_context["client"].get(f"/cached-data/{key}") - assert response3.json()["from_cache"] is False - - -@then("statement preparation should happen only once") -def verify_prepared_once(fastapi_context): - """Verify statement prepared only once.""" - # Check that prepared statements are stored - app = fastapi_context["app"] - assert "get_user" in app.state.prepared_statements - assert len(app.state.prepared_statements) == 1 - - -@then("query performance should be optimized") -def verify_prepared_performance(fastapi_context): - """Verify prepared statement performance.""" - # With 1000 requests, prepared statements should be fast - avg_time = fastapi_context["prepared_duration"] / 1000 - assert avg_time < 0.01 # Less than 10ms per query on average - - -@then("the prepared statement cache should be shared across requests") -def verify_prepared_cache_shared(fastapi_context): - """Verify prepared statement cache is shared.""" - # All requests should have succeeded - assert all(r.status_code == 200 for r in fastapi_context["prepared_responses"]) - # The single prepared statement handled all requests - app = 
fastapi_context["app"] - assert len(app.state.prepared_statements) == 1 - - -@then("metrics should track:") -def verify_metrics_tracking(fastapi_context): - """Verify metrics are tracked.""" - # Table data is provided in the feature file - # We'll verify the metrics endpoint returns expected fields - response = fastapi_context["client"].get("/metrics") - assert response.status_code == 200 - - metrics = response.json() - expected_fields = [ - "request_count", - "request_duration", - "cassandra_query_count", - "cassandra_query_duration", - "connection_pool_size", - "error_rate", - ] - - for field in expected_fields: - assert field in metrics - - -@then('metrics should be accessible via "/metrics" endpoint') -def verify_metrics_endpoint(fastapi_context): - """Verify metrics endpoint exists.""" - response = fastapi_context["client"].get("/metrics") - assert response.status_code == 200 - assert "request_count" in response.json() diff --git a/tests/bdd/test_fastapi_reconnection.py b/tests/bdd/test_fastapi_reconnection.py deleted file mode 100644 index 8dde092..0000000 --- a/tests/bdd/test_fastapi_reconnection.py +++ /dev/null @@ -1,605 +0,0 @@ -""" -BDD tests for FastAPI Cassandra reconnection behavior. - -This test validates the application's ability to handle Cassandra outages -and automatically recover when the database becomes available again. -""" - -import asyncio -import os -import subprocess -import sys -import time -from pathlib import Path - -import httpx -import pytest -import pytest_asyncio -from httpx import ASGITransport - -# Import the cassandra_container fixture -pytest_plugins = ["tests._fixtures.cassandra"] - -# Add FastAPI app to path -fastapi_app_dir = Path(__file__).parent.parent.parent / "examples" / "fastapi_app" -sys.path.insert(0, str(fastapi_app_dir)) - -# Import test utilities -from tests.test_utils import ( # noqa: E402 - cleanup_keyspace, - create_test_keyspace, - generate_unique_keyspace, -) -from tests.utils.cassandra_control import CassandraControl # noqa: E402 - - -def wait_for_cassandra_ready(host="127.0.0.1", timeout=30): - """Wait for Cassandra to be ready by executing a test query with cqlsh.""" - start_time = time.time() - while time.time() - start_time < timeout: - try: - # Use cqlsh to test if Cassandra is ready - result = subprocess.run( - ["cqlsh", host, "-e", "SELECT release_version FROM system.local;"], - capture_output=True, - text=True, - timeout=5, - ) - if result.returncode == 0: - return True - except (subprocess.TimeoutExpired, Exception): - pass - time.sleep(0.5) - return False - - -def wait_for_cassandra_down(host="127.0.0.1", timeout=10): - """Wait for Cassandra to be down by checking if cqlsh fails.""" - start_time = time.time() - while time.time() - start_time < timeout: - try: - result = subprocess.run( - ["cqlsh", host, "-e", "SELECT 1;"], capture_output=True, text=True, timeout=2 - ) - if result.returncode != 0: - return True - except (subprocess.TimeoutExpired, Exception): - return True - time.sleep(0.5) - return False - - -@pytest_asyncio.fixture(autouse=True) -async def ensure_cassandra_enabled_bdd(cassandra_container): - """Ensure Cassandra binary protocol is enabled before and after each test.""" - # Enable at start - subprocess.run( - [ - cassandra_container.runtime, - "exec", - cassandra_container.container_name, - "nodetool", - "enablebinary", - ], - capture_output=True, - ) - await asyncio.sleep(2) - - yield - - # Enable at end (cleanup) - subprocess.run( - [ - cassandra_container.runtime, - "exec", - 
cassandra_container.container_name, - "nodetool", - "enablebinary", - ], - capture_output=True, - ) - await asyncio.sleep(2) - - -@pytest_asyncio.fixture -async def unique_test_keyspace(cassandra_container): - """Create a unique keyspace for each test.""" - from async_cassandra import AsyncCluster - - # Check health before proceeding - health = cassandra_container.check_health() - if not health["native_transport"] or not health["cql_available"]: - pytest.fail(f"Cassandra not healthy: {health}") - - cluster = AsyncCluster(contact_points=["127.0.0.1"], protocol_version=5) - session = await cluster.connect() - - # Create unique keyspace - keyspace = generate_unique_keyspace("bdd_reconnection") - await create_test_keyspace(session, keyspace) - - yield keyspace - - # Cleanup - await cleanup_keyspace(session, keyspace) - await session.close() - await cluster.shutdown() - # Give extra time for driver's internal threads to fully stop - await asyncio.sleep(2) - - -@pytest_asyncio.fixture -async def app_client(unique_test_keyspace): - """Create test client for the FastAPI app with isolated keyspace.""" - # Set the test keyspace in environment - os.environ["TEST_KEYSPACE"] = unique_test_keyspace - - from main import app, lifespan - - # Manually handle lifespan since httpx doesn't do it properly - async with lifespan(app): - transport = ASGITransport(app=app) - async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - yield client - - # Clean up environment - os.environ.pop("TEST_KEYSPACE", None) - - -def run_async(coro): - """Run async code in sync context.""" - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete(coro) - finally: - loop.close() - - -class TestFastAPIReconnectionBDD: - """BDD tests for Cassandra reconnection in FastAPI applications.""" - - def _get_cassandra_control(self, container): - """Get Cassandra control interface.""" - return CassandraControl(container) - - def test_cassandra_outage_and_recovery(self, app_client, cassandra_container): - """ - Given: A FastAPI application connected to Cassandra - When: Cassandra becomes temporarily unavailable and then recovers - Then: The application should handle the outage gracefully and automatically reconnect - """ - - async def test_scenario(): - # Given: A connected FastAPI application with working APIs - print("\nGiven: A FastAPI application with working Cassandra connection") - - # Verify health check shows connected - health_response = await app_client.get("/health") - assert health_response.status_code == 200 - assert health_response.json()["cassandra_connected"] is True - print("✓ Health check confirms Cassandra is connected") - - # Create a test user to verify functionality - user_data = {"name": "Reconnection Test User", "email": "reconnect@test.com", "age": 30} - create_response = await app_client.post("/users", json=user_data) - assert create_response.status_code == 201 - user_id = create_response.json()["id"] - print(f"✓ Created test user with ID: {user_id}") - - # Verify streaming works - stream_response = await app_client.get("/users/stream?limit=5&fetch_size=10") - if stream_response.status_code != 200: - print(f"Stream response status: {stream_response.status_code}") - print(f"Stream response body: {stream_response.text}") - assert stream_response.status_code == 200 - assert stream_response.json()["metadata"]["streaming_enabled"] is True - print("✓ Streaming API is working") - - # When: Cassandra binary protocol is disabled (simulating outage) - print("\nWhen: Cassandra 
becomes unavailable (disabling binary protocol)") - - # Skip this test in CI since we can't control Cassandra service - if os.environ.get("CI") == "true": - pytest.skip("Cannot control Cassandra service in CI environment") - - control = self._get_cassandra_control(cassandra_container) - success = control.simulate_outage() - assert success, "Failed to simulate Cassandra outage" - print("✓ Binary protocol disabled - simulating Cassandra outage") - print("✓ Confirmed Cassandra is down via cqlsh") - - # Then: APIs should return 503 Service Unavailable errors - print("\nThen: APIs should return 503 Service Unavailable errors") - - # Try to create a user - should fail with 503 - try: - user_data = {"name": "Test User", "email": "test@example.com", "age": 25} - error_response = await app_client.post("/users", json=user_data, timeout=10.0) - if error_response.status_code == 503: - print("✓ Create user returns 503 Service Unavailable") - else: - print( - f"Warning: Create user returned {error_response.status_code} instead of 503" - ) - except (httpx.TimeoutException, httpx.RequestError) as e: - print(f"✓ Create user failed with {type(e).__name__} (expected)") - - # Verify health check shows disconnected - health_response = await app_client.get("/health") - assert health_response.status_code == 200 - assert health_response.json()["cassandra_connected"] is False - print("✓ Health check correctly reports Cassandra as disconnected") - - # When: Cassandra becomes available again - print("\nWhen: Cassandra becomes available again (enabling binary protocol)") - - if os.environ.get("CI") == "true": - print(" (In CI - Cassandra service always running)") - # In CI, Cassandra is always available - else: - success = control.restore_service() - assert success, "Failed to restore Cassandra service" - print("✓ Binary protocol re-enabled") - print("✓ Confirmed Cassandra is ready via cqlsh") - - # Then: The application should automatically reconnect - print("\nThen: The application should automatically reconnect") - - # Now check if the app has reconnected - # The FastAPI app uses a 2-second constant reconnection delay, so we need to wait - # at least that long plus some buffer for the reconnection to complete - reconnected = False - # Wait up to 30 seconds - driver needs time to rediscover the host - for attempt in range(30): # Up to 30 seconds (30 * 1s) - try: - # Check health first to see connection status - health_resp = await app_client.get("/health") - if health_resp.status_code == 200: - health_data = health_resp.json() - if health_data.get("cassandra_connected"): - # Now try actual query - response = await app_client.get("/users?limit=1") - if response.status_code == 200: - reconnected = True - print(f"✓ App reconnected after {attempt + 1} seconds") - break - else: - print( - f" Health says connected but query returned {response.status_code}" - ) - else: - if attempt % 5 == 0: # Print every 5 seconds - print( - f" After {attempt} seconds: Health check says not connected yet" - ) - except (httpx.TimeoutException, httpx.RequestError) as e: - print(f" Attempt {attempt + 1}: Connection error: {type(e).__name__}") - await asyncio.sleep(1.0) # Check every second - - assert reconnected, "Application failed to reconnect after Cassandra came back" - print("✓ Application successfully reconnected to Cassandra") - - # Verify health check shows connected again - health_response = await app_client.get("/health") - assert health_response.status_code == 200 - assert health_response.json()["cassandra_connected"] is True - 
print("✓ Health check confirms reconnection") - - # Verify we can retrieve the previously created user - get_response = await app_client.get(f"/users/{user_id}") - assert get_response.status_code == 200 - assert get_response.json()["name"] == "Reconnection Test User" - print("✓ Previously created data is still accessible") - - # Create a new user to verify full functionality - new_user_data = {"name": "Post-Recovery User", "email": "recovery@test.com", "age": 35} - create_response = await app_client.post("/users", json=new_user_data) - assert create_response.status_code == 201 - print("✓ Can create new users after recovery") - - # Verify streaming works again - stream_response = await app_client.get("/users/stream?limit=5&fetch_size=10") - assert stream_response.status_code == 200 - assert stream_response.json()["metadata"]["streaming_enabled"] is True - print("✓ Streaming API works after recovery") - - print("\n✅ Cassandra reconnection test completed successfully!") - print(" - Application handled outage gracefully with 503 errors") - print(" - Automatic reconnection occurred without manual intervention") - print(" - All functionality restored after recovery") - - # Run the async test scenario - run_async(test_scenario()) - - def test_multiple_outage_cycles(self, app_client, cassandra_container): - """ - Given: A FastAPI application connected to Cassandra - When: Cassandra experiences multiple outage/recovery cycles - Then: The application should handle each cycle gracefully - """ - - async def test_scenario(): - print("\nGiven: A FastAPI application with Cassandra connection") - - # Skip this test in CI since we can't control Cassandra service - if os.environ.get("CI") == "true": - pytest.skip("Cannot control Cassandra service in CI environment") - - # Verify initial health - health_response = await app_client.get("/health") - assert health_response.status_code == 200 - assert health_response.json()["cassandra_connected"] is True - - cycles = 1 # Just test one cycle to speed up - for cycle in range(1, cycles + 1): - print(f"\nWhen: Cassandra outage cycle {cycle}/{cycles} begins") - - # Disable binary protocol - control = self._get_cassandra_control(cassandra_container) - - if os.environ.get("CI") == "true": - print(f" Cycle {cycle}: Skipping in CI - cannot control service") - continue - - success = control.simulate_outage() - assert success, f"Cycle {cycle}: Failed to simulate outage" - print(f"✓ Cycle {cycle}: Binary protocol disabled") - print(f"✓ Cycle {cycle}: Confirmed Cassandra is down via cqlsh") - - # Verify unhealthy state - health_response = await app_client.get("/health") - assert health_response.json()["cassandra_connected"] is False - print(f"✓ Cycle {cycle}: Health check reports disconnected") - - # Re-enable binary protocol - success = control.restore_service() - assert success, f"Cycle {cycle}: Failed to restore service" - print(f"✓ Cycle {cycle}: Binary protocol re-enabled") - print(f"✓ Cycle {cycle}: Confirmed Cassandra is ready via cqlsh") - - # Check app reconnection - # The FastAPI app uses a 2-second constant reconnection delay - reconnected = False - for _ in range(8): # Up to 4 seconds to account for 2s reconnection delay - try: - response = await app_client.get("/users?limit=1") - if response.status_code == 200: - reconnected = True - break - except Exception: - pass - await asyncio.sleep(0.5) - - assert reconnected, f"Cycle {cycle}: Failed to reconnect" - print(f"✓ Cycle {cycle}: Successfully reconnected") - - # Verify functionality with a test operation - 
user_data = { - "name": f"Cycle {cycle} User", - "email": f"cycle{cycle}@test.com", - "age": 20 + cycle, - } - create_response = await app_client.post("/users", json=user_data) - assert create_response.status_code == 201 - print(f"✓ Cycle {cycle}: Created test user successfully") - - print(f"\nThen: All {cycles} outage cycles handled successfully") - print("✅ Multiple reconnection cycles completed without issues!") - - run_async(test_scenario()) - - def test_reconnection_during_active_load(self, app_client, cassandra_container): - """ - Given: A FastAPI application under active load - When: Cassandra becomes unavailable during request processing - Then: The application should handle in-flight requests gracefully and recover - """ - - async def test_scenario(): - print("\nGiven: A FastAPI application handling active requests") - - # Skip this test in CI since we can't control Cassandra service - if os.environ.get("CI") == "true": - pytest.skip("Cannot control Cassandra service in CI environment") - - # Track request results - request_results = {"successes": 0, "errors": [], "error_types": set()} - - async def continuous_requests(client: httpx.AsyncClient, duration: int): - """Make continuous requests for specified duration.""" - start_time = time.time() - - while time.time() - start_time < duration: - try: - # Alternate between different endpoints - endpoints = [ - ("/health", "GET", None), - ("/users?limit=5", "GET", None), - ( - "/users", - "POST", - {"name": "Load Test", "email": "load@test.com", "age": 25}, - ), - ] - - endpoint, method, data = endpoints[int(time.time()) % len(endpoints)] - - if method == "GET": - response = await client.get(endpoint, timeout=5.0) - else: - response = await client.post(endpoint, json=data, timeout=5.0) - - if response.status_code in [200, 201]: - request_results["successes"] += 1 - elif response.status_code == 503: - request_results["errors"].append("503_service_unavailable") - request_results["error_types"].add("503") - else: - request_results["errors"].append(f"status_{response.status_code}") - request_results["error_types"].add(str(response.status_code)) - - except (httpx.TimeoutException, httpx.RequestError) as e: - request_results["errors"].append(type(e).__name__) - request_results["error_types"].add(type(e).__name__) - - await asyncio.sleep(0.1) - - # Start continuous requests in background - print("Starting continuous load generation...") - request_task = asyncio.create_task(continuous_requests(app_client, 15)) - - # Let requests run for a bit - await asyncio.sleep(3) - print(f"✓ Initial requests successful: {request_results['successes']}") - - # When: Cassandra becomes unavailable during active load - print("\nWhen: Cassandra becomes unavailable during active requests") - control = self._get_cassandra_control(cassandra_container) - - if os.environ.get("CI") == "true": - print(" (In CI - cannot disable service, continuing with available service)") - else: - success = control.simulate_outage() - assert success, "Failed to simulate outage" - print("✓ Binary protocol disabled during active load") - - # Let errors accumulate - await asyncio.sleep(4) - print(f"✓ Errors during outage: {len(request_results['errors'])}") - - # Re-enable Cassandra - print("\nWhen: Cassandra becomes available again") - if not os.environ.get("CI") == "true": - success = control.restore_service() - assert success, "Failed to restore service" - print("✓ Binary protocol re-enabled") - - # Wait for task completion - await request_task - - # Then: Analyze results - 
print("\nThen: Application should have handled the outage gracefully") - print("Results:") - print(f" - Successful requests: {request_results['successes']}") - print(f" - Failed requests: {len(request_results['errors'])}") - print(f" - Error types seen: {request_results['error_types']}") - - # Verify we had both successes and failures - assert ( - request_results["successes"] > 0 - ), "Should have successful requests before/after outage" - assert len(request_results["errors"]) > 0, "Should have errors during outage" - assert ( - "503" in request_results["error_types"] or len(request_results["error_types"]) > 0 - ), "Should have seen 503 errors or connection errors" - - # Final health check - health_response = await app_client.get("/health") - assert health_response.status_code == 200 - assert health_response.json()["cassandra_connected"] is True - print("✓ Final health check confirms recovery") - - print("\n✅ Active load reconnection test completed successfully!") - print(" - Application continued serving requests where possible") - print(" - Errors were returned appropriately during outage") - print(" - Automatic recovery restored full functionality") - - run_async(test_scenario()) - - def test_rapid_connection_cycling(self, app_client, cassandra_container): - """ - Given: A FastAPI application connected to Cassandra - When: Cassandra connection is rapidly cycled (quick disable/enable) - Then: The application should remain stable and not leak resources - """ - - async def test_scenario(): - print("\nGiven: A FastAPI application with stable Cassandra connection") - - # Skip this test in CI since we can't control Cassandra service - if os.environ.get("CI") == "true": - pytest.skip("Cannot control Cassandra service in CI environment") - - # Create initial user to establish baseline - initial_user = {"name": "Baseline User", "email": "baseline@test.com", "age": 25} - response = await app_client.post("/users", json=initial_user) - assert response.status_code == 201 - print("✓ Baseline functionality confirmed") - - print("\nWhen: Rapidly cycling Cassandra connection") - - # Perform rapid cycles - for i in range(5): - print(f"\nRapid cycle {i+1}/5:") - - control = self._get_cassandra_control(cassandra_container) - - if os.environ.get("CI") == "true": - print(" - Skipping cycle in CI") - break - - # Quick disable - control.disable_binary_protocol() - print(" - Disabled") - - # Very short wait - await asyncio.sleep(0.5) - - # Quick enable - control.enable_binary_protocol() - print(" - Enabled") - - # Minimal wait before next cycle - await asyncio.sleep(1) - - print("\nThen: Application should remain stable and recover") - - # The FastAPI app has ConstantReconnectionPolicy with 2 second delay - # So it should recover automatically once Cassandra is available - print("Waiting for FastAPI app to automatically recover...") - recovery_start = time.time() - app_recovered = False - - # Wait for the app to recover - checking via health endpoint and actual operations - while time.time() - recovery_start < 15: - try: - # Test with a real operation - test_user = { - "name": "Recovery Test User", - "email": "recovery@test.com", - "age": 30, - } - response = await app_client.post("/users", json=test_user, timeout=3.0) - if response.status_code == 201: - app_recovered = True - recovery_time = time.time() - recovery_start - print(f"✓ App recovered and accepting requests (took {recovery_time:.1f}s)") - break - else: - print(f" - Got status {response.status_code}, waiting for recovery...") - except Exception as e: 
- print(f" - Still recovering: {type(e).__name__}") - - await asyncio.sleep(1) - - assert ( - app_recovered - ), "FastAPI app should automatically recover when Cassandra is available" - - # Verify health check also shows recovery - health_response = await app_client.get("/health") - assert health_response.status_code == 200 - assert health_response.json()["cassandra_connected"] is True - print("✓ Health check confirms full recovery") - - # Verify streaming works after recovery - stream_response = await app_client.get("/users/stream?limit=5") - assert stream_response.status_code == 200 - print("✓ Streaming functionality recovered") - - print("\n✅ Rapid connection cycling test completed!") - print(" - Application remained stable during rapid cycling") - print(" - Automatic recovery worked as expected") - print(" - All functionality restored after Cassandra recovery") - - run_async(test_scenario()) diff --git a/tests/benchmarks/README.md b/tests/benchmarks/README.md deleted file mode 100644 index 6335338..0000000 --- a/tests/benchmarks/README.md +++ /dev/null @@ -1,149 +0,0 @@ -# Performance Benchmarks - -This directory contains performance benchmarks that ensure async-cassandra maintains its performance characteristics and catches any regressions. - -## Overview - -The benchmarks measure key performance indicators with defined thresholds: -- Query latency (average, P95, P99, max) -- Throughput (queries per second) -- Concurrency handling -- Memory efficiency -- CPU usage -- Streaming performance - -## Benchmark Categories - -### 1. Query Performance (`test_query_performance.py`) -- Single query latency benchmarks -- Concurrent query throughput -- Async vs sync performance comparison -- Query latency under sustained load -- Prepared statement performance benefits - -### 2. Streaming Performance (`test_streaming_performance.py`) -- Memory efficiency vs regular queries -- Streaming throughput for large datasets -- Latency overhead of streaming -- Page-by-page processing performance -- Concurrent streaming operations - -### 3. 
Concurrency Performance (`test_concurrency_performance.py`) -- High concurrency throughput -- Connection pool efficiency -- Resource usage under load -- Operation isolation -- Graceful degradation under overload - -## Performance Thresholds - -Default performance thresholds are defined in `benchmark_config.py`: - -```python -# Query latency thresholds -single_query_max: 100ms -single_query_p99: 50ms -single_query_p95: 30ms -single_query_avg: 20ms - -# Throughput thresholds -min_throughput_sync: 50 qps -min_throughput_async: 500 qps - -# Concurrency thresholds -max_concurrent_queries: 1000 -concurrency_speedup_factor: 5x - -# Resource thresholds -max_memory_per_connection: 10MB -max_error_rate: 1% -``` - -## Running Benchmarks - -### Basic Usage - -```bash -# Run all benchmarks -pytest tests/benchmarks/ -m benchmark - -# Run specific benchmark category -pytest tests/benchmarks/test_query_performance.py -v - -# Run with custom markers -pytest tests/benchmarks/ -m "benchmark and not slow" -``` - -### Using the Benchmark Runner - -```bash -# Run benchmarks with report generation -python -m tests.benchmarks.benchmark_runner - -# Run with custom output directory -python -m tests.benchmarks.benchmark_runner --output ./results - -# Run specific benchmarks -python -m tests.benchmarks.benchmark_runner --markers "benchmark and query" -``` - -## Interpreting Results - -### Success Criteria -- All benchmarks must pass their defined thresholds -- No performance regressions compared to baseline -- Resource usage remains within acceptable limits - -### Common Failure Reasons -1. **Latency threshold exceeded**: Query taking longer than expected -2. **Throughput below minimum**: Not achieving required operations/second -3. **Memory overhead too high**: Streaming using too much memory -4. **Error rate exceeded**: Too many failures under load - -## Writing New Benchmarks - -When adding benchmarks: - -1. **Define clear thresholds** based on expected performance -2. **Warm up** before measuring to avoid cold start effects -3. **Measure multiple iterations** for statistical significance -4. **Consider resource usage** not just speed -5. **Test edge cases** like overload conditions - -Example structure: -```python -@pytest.mark.benchmark -async def test_new_performance_metric(benchmark_session): - """ - Benchmark description. - - GIVEN initial conditions - WHEN operation is performed - THEN performance should meet thresholds - """ - thresholds = BenchmarkConfig.DEFAULT_THRESHOLDS - - # Warm up - # ... warm up code ... - - # Measure performance - # ... measurement code ... - - # Verify thresholds - assert metric < threshold, f"Metric {metric} exceeds threshold {threshold}" -``` - -## CI/CD Integration - -Benchmarks should be run: -- On every PR to detect regressions -- Nightly for comprehensive testing -- Before releases to ensure performance - -## Performance Monitoring - -Results can be tracked over time to identify: -- Performance trends -- Gradual degradation -- Impact of changes -- Optimization opportunities diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py deleted file mode 100644 index 14d0480..0000000 --- a/tests/benchmarks/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Performance benchmarks for async-cassandra. - -These benchmarks ensure the library maintains its performance -characteristics and identify any regressions. 
-""" diff --git a/tests/benchmarks/benchmark_config.py b/tests/benchmarks/benchmark_config.py deleted file mode 100644 index 5309ee4..0000000 --- a/tests/benchmarks/benchmark_config.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -Configuration and thresholds for performance benchmarks. -""" - -from dataclasses import dataclass -from typing import Dict, Optional - - -@dataclass -class BenchmarkThresholds: - """Performance thresholds for different operations.""" - - # Query latency thresholds (in seconds) - single_query_max: float = 0.1 # 100ms max for single query - single_query_p99: float = 0.05 # 50ms for 99th percentile - single_query_p95: float = 0.03 # 30ms for 95th percentile - single_query_avg: float = 0.02 # 20ms average - - # Throughput thresholds (queries per second) - min_throughput_sync: float = 50 # Minimum 50 qps for sync operations - min_throughput_async: float = 500 # Minimum 500 qps for async operations - - # Concurrency thresholds - max_concurrent_queries: int = 1000 # Support at least 1000 concurrent queries - concurrency_speedup_factor: float = 5.0 # Async should be 5x faster than sync - - # Streaming thresholds - streaming_memory_overhead: float = 1.5 # Max 50% more memory than data size - streaming_latency_overhead: float = 1.2 # Max 20% slower than regular queries - - # Resource usage thresholds - max_memory_per_connection: float = 10.0 # Max 10MB per connection - max_cpu_usage_idle: float = 0.05 # Max 5% CPU when idle - - # Error rate thresholds - max_error_rate: float = 0.01 # Max 1% error rate under load - max_timeout_rate: float = 0.001 # Max 0.1% timeout rate - - -@dataclass -class BenchmarkResult: - """Result of a benchmark run.""" - - name: str - duration: float - operations: int - throughput: float - latency_avg: float - latency_p95: float - latency_p99: float - latency_max: float - errors: int - error_rate: float - memory_used_mb: float - cpu_percent: float - passed: bool - failure_reason: Optional[str] = None - metadata: Optional[Dict] = None - - -class BenchmarkConfig: - """Configuration for benchmark runs.""" - - # Test data configuration - TEST_KEYSPACE = "benchmark_test" - TEST_TABLE = "benchmark_data" - - # Data sizes for different benchmark scenarios - SMALL_DATASET_SIZE = 100 - MEDIUM_DATASET_SIZE = 1000 - LARGE_DATASET_SIZE = 10000 - - # Concurrency levels - LOW_CONCURRENCY = 10 - MEDIUM_CONCURRENCY = 100 - HIGH_CONCURRENCY = 1000 - - # Test durations - QUICK_TEST_DURATION = 5 # seconds - STANDARD_TEST_DURATION = 30 # seconds - STRESS_TEST_DURATION = 300 # seconds (5 minutes) - - # Default thresholds - DEFAULT_THRESHOLDS = BenchmarkThresholds() diff --git a/tests/benchmarks/benchmark_runner.py b/tests/benchmarks/benchmark_runner.py deleted file mode 100644 index 6889197..0000000 --- a/tests/benchmarks/benchmark_runner.py +++ /dev/null @@ -1,233 +0,0 @@ -""" -Benchmark runner with reporting capabilities. - -This module provides utilities to run benchmarks and generate -performance reports with threshold validation. 
-""" - -import json -from datetime import datetime -from pathlib import Path -from typing import Dict, List, Optional - -import pytest - -from .benchmark_config import BenchmarkResult - - -class BenchmarkRunner: - """Runner for performance benchmarks with reporting.""" - - def __init__(self, output_dir: Optional[Path] = None): - """Initialize benchmark runner.""" - self.output_dir = output_dir or Path("benchmark_results") - self.output_dir.mkdir(exist_ok=True) - self.results: List[BenchmarkResult] = [] - - def run_benchmarks(self, markers: str = "benchmark", verbose: bool = True) -> bool: - """ - Run benchmarks and collect results. - - Args: - markers: Pytest markers to select benchmarks - verbose: Whether to print verbose output - - Returns: - True if all benchmarks passed thresholds - """ - # Run pytest with benchmark markers - timestamp = datetime.now().isoformat() - - if verbose: - print(f"Running benchmarks at {timestamp}") - print("-" * 60) - - # Run benchmarks - pytest_args = [ - "tests/benchmarks", - f"-m={markers}", - "-v" if verbose else "-q", - "--tb=short", - ] - - result = pytest.main(pytest_args) - - all_passed = result == 0 - - if verbose: - print("-" * 60) - print(f"Benchmark run completed. All passed: {all_passed}") - - return all_passed - - def generate_report(self, results: List[BenchmarkResult]) -> Dict: - """Generate benchmark report.""" - report = { - "timestamp": datetime.now().isoformat(), - "summary": { - "total_benchmarks": len(results), - "passed": sum(1 for r in results if r.passed), - "failed": sum(1 for r in results if not r.passed), - }, - "results": [], - } - - for result in results: - result_data = { - "name": result.name, - "passed": result.passed, - "metrics": { - "duration": result.duration, - "throughput": result.throughput, - "latency_avg": result.latency_avg, - "latency_p95": result.latency_p95, - "latency_p99": result.latency_p99, - "latency_max": result.latency_max, - "error_rate": result.error_rate, - "memory_used_mb": result.memory_used_mb, - "cpu_percent": result.cpu_percent, - }, - } - - if not result.passed: - result_data["failure_reason"] = result.failure_reason - - if result.metadata: - result_data["metadata"] = result.metadata - - report["results"].append(result_data) - - return report - - def save_report(self, report: Dict, filename: Optional[str] = None) -> Path: - """Save benchmark report to file.""" - if not filename: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = f"benchmark_report_{timestamp}.json" - - filepath = self.output_dir / filename - - with open(filepath, "w") as f: - json.dump(report, f, indent=2) - - return filepath - - def compare_results( - self, current: List[BenchmarkResult], baseline: List[BenchmarkResult] - ) -> Dict: - """Compare current results against baseline.""" - comparison = { - "improved": [], - "regressed": [], - "unchanged": [], - } - - # Create baseline lookup - baseline_by_name = {r.name: r for r in baseline} - - for current_result in current: - baseline_result = baseline_by_name.get(current_result.name) - - if not baseline_result: - continue - - # Compare key metrics - throughput_change = ( - (current_result.throughput - baseline_result.throughput) - / baseline_result.throughput - if baseline_result.throughput > 0 - else 0 - ) - - latency_change = ( - (current_result.latency_avg - baseline_result.latency_avg) - / baseline_result.latency_avg - if baseline_result.latency_avg > 0 - else 0 - ) - - comparison_entry = { - "name": current_result.name, - "throughput_change": throughput_change, 
- "latency_change": latency_change, - "current": { - "throughput": current_result.throughput, - "latency_avg": current_result.latency_avg, - }, - "baseline": { - "throughput": baseline_result.throughput, - "latency_avg": baseline_result.latency_avg, - }, - } - - # Categorize change - if throughput_change > 0.1 or latency_change < -0.1: - comparison["improved"].append(comparison_entry) - elif throughput_change < -0.1 or latency_change > 0.1: - comparison["regressed"].append(comparison_entry) - else: - comparison["unchanged"].append(comparison_entry) - - return comparison - - def print_summary(self, report: Dict) -> None: - """Print benchmark summary to console.""" - print("\nBenchmark Summary") - print("=" * 60) - print(f"Total benchmarks: {report['summary']['total_benchmarks']}") - print(f"Passed: {report['summary']['passed']}") - print(f"Failed: {report['summary']['failed']}") - print() - - if report["summary"]["failed"] > 0: - print("Failed Benchmarks:") - print("-" * 40) - for result in report["results"]: - if not result["passed"]: - print(f" - {result['name']}") - print(f" Reason: {result.get('failure_reason', 'Unknown')}") - print() - - print("Performance Metrics:") - print("-" * 40) - for result in report["results"]: - if result["passed"]: - metrics = result["metrics"] - print(f" {result['name']}:") - print(f" Throughput: {metrics['throughput']:.1f} ops/sec") - print(f" Avg Latency: {metrics['latency_avg']*1000:.1f} ms") - print(f" P99 Latency: {metrics['latency_p99']*1000:.1f} ms") - - -def main(): - """Run benchmarks from command line.""" - import argparse - - parser = argparse.ArgumentParser(description="Run async-cassandra benchmarks") - parser.add_argument( - "--markers", default="benchmark", help="Pytest markers to select benchmarks" - ) - parser.add_argument("--output", type=Path, help="Output directory for reports") - parser.add_argument("--quiet", action="store_true", help="Suppress verbose output") - - args = parser.parse_args() - - runner = BenchmarkRunner(output_dir=args.output) - - # Run benchmarks - all_passed = runner.run_benchmarks(markers=args.markers, verbose=not args.quiet) - - # Generate and save report - if runner.results: - report = runner.generate_report(runner.results) - report_path = runner.save_report(report) - - if not args.quiet: - runner.print_summary(report) - print(f"\nReport saved to: {report_path}") - - return 0 if all_passed else 1 - - -if __name__ == "__main__": - exit(main()) diff --git a/tests/benchmarks/test_concurrency_performance.py b/tests/benchmarks/test_concurrency_performance.py deleted file mode 100644 index 7fa3569..0000000 --- a/tests/benchmarks/test_concurrency_performance.py +++ /dev/null @@ -1,362 +0,0 @@ -""" -Performance benchmarks for concurrency and resource usage. - -These benchmarks validate the library's ability to handle -high concurrency efficiently with reasonable resource usage. 
-""" - -import asyncio -import gc -import os -import statistics -import time - -import psutil -import pytest -import pytest_asyncio - -from async_cassandra import AsyncCassandraSession, AsyncCluster - -from .benchmark_config import BenchmarkConfig - - -@pytest.mark.benchmark -class TestConcurrencyPerformance: - """Benchmarks for concurrency handling and resource efficiency.""" - - @pytest_asyncio.fixture - async def benchmark_session(self) -> AsyncCassandraSession: - """Create session for concurrency benchmarks.""" - cluster = AsyncCluster( - contact_points=["localhost"], - executor_threads=16, # More threads for concurrency tests - ) - session = await cluster.connect() - - # Create test keyspace and table - await session.execute( - f""" - CREATE KEYSPACE IF NOT EXISTS {BenchmarkConfig.TEST_KEYSPACE} - WITH REPLICATION = {{'class': 'SimpleStrategy', 'replication_factor': 1}} - """ - ) - await session.set_keyspace(BenchmarkConfig.TEST_KEYSPACE) - - await session.execute("DROP TABLE IF EXISTS concurrency_test") - await session.execute( - """ - CREATE TABLE concurrency_test ( - id UUID PRIMARY KEY, - data TEXT, - counter INT, - updated_at TIMESTAMP - ) - """ - ) - - yield session - - await session.close() - await cluster.shutdown() - - @pytest.mark.asyncio - async def test_high_concurrency_throughput(self, benchmark_session): - """ - Benchmark throughput under high concurrency. - - GIVEN many concurrent operations - WHEN executed simultaneously - THEN system should maintain high throughput - """ - thresholds = BenchmarkConfig.DEFAULT_THRESHOLDS - - # Prepare statements - insert_stmt = await benchmark_session.prepare( - "INSERT INTO concurrency_test (id, data, counter, updated_at) VALUES (?, ?, ?, toTimestamp(now()))" - ) - select_stmt = await benchmark_session.prepare("SELECT * FROM concurrency_test WHERE id = ?") - - async def mixed_operations(op_id: int): - """Perform mixed read/write operations.""" - import uuid - - # Insert - record_id = uuid.uuid4() - await benchmark_session.execute(insert_stmt, [record_id, f"data_{op_id}", op_id]) - - # Read back - result = await benchmark_session.execute(select_stmt, [record_id]) - row = result.one() - - return row is not None - - # Benchmark high concurrency - num_operations = 1000 - start_time = time.perf_counter() - - tasks = [mixed_operations(i) for i in range(num_operations)] - results = await asyncio.gather(*tasks, return_exceptions=True) - - duration = time.perf_counter() - start_time - - # Calculate metrics - successful = sum(1 for r in results if r is True) - errors = sum(1 for r in results if isinstance(r, Exception)) - throughput = successful / duration - - # Verify thresholds - assert ( - throughput >= thresholds.min_throughput_async - ), f"Throughput {throughput:.1f} ops/sec below threshold" - assert ( - successful >= num_operations * 0.99 - ), f"Success rate {successful/num_operations:.1%} below 99%" - assert errors == 0, f"Unexpected errors: {errors}" - - @pytest.mark.asyncio - async def test_connection_pool_efficiency(self, benchmark_session): - """ - Benchmark connection pool handling under load. 
- - GIVEN limited connection pool - WHEN many requests compete for connections - THEN pool should be used efficiently - """ - # Create a cluster with limited connections - limited_cluster = AsyncCluster( - contact_points=["localhost"], - executor_threads=4, # Limited threads - ) - limited_session = await limited_cluster.connect() - await limited_session.set_keyspace(BenchmarkConfig.TEST_KEYSPACE) - - try: - select_stmt = await limited_session.prepare("SELECT * FROM concurrency_test LIMIT 1") - - # Track connection wait times (removed - not needed) - - async def timed_query(query_id: int): - """Execute query and measure wait time.""" - start = time.perf_counter() - - # This might wait for available connection - result = await limited_session.execute(select_stmt) - _ = result.one() - - duration = time.perf_counter() - start - return duration - - # Run many concurrent queries with limited pool - num_queries = 100 - query_times = await asyncio.gather(*[timed_query(i) for i in range(num_queries)]) - - # Calculate metrics - avg_time = statistics.mean(query_times) - p95_time = statistics.quantiles(query_times, n=20)[18] - - # Pool should handle load efficiently - assert avg_time < 0.1, f"Average query time {avg_time:.3f}s indicates pool contention" - assert p95_time < 0.2, f"P95 query time {p95_time:.3f}s indicates severe contention" - - finally: - await limited_session.close() - await limited_cluster.shutdown() - - @pytest.mark.asyncio - async def test_resource_usage_under_load(self, benchmark_session): - """ - Benchmark resource usage (CPU, memory) under sustained load. - - GIVEN sustained concurrent load - WHEN system processes requests - THEN resource usage should remain reasonable - """ - - # Get process for monitoring - process = psutil.Process(os.getpid()) - - # Prepare statement - select_stmt = await benchmark_session.prepare("SELECT * FROM concurrency_test LIMIT 10") - - # Collect baseline metrics - gc.collect() - baseline_memory = process.memory_info().rss / 1024 / 1024 # MB - process.cpu_percent(interval=0.1) - - # Resource tracking - memory_samples = [] - cpu_samples = [] - - async def load_generator(): - """Generate continuous load.""" - while True: - try: - await benchmark_session.execute(select_stmt) - await asyncio.sleep(0.001) # Small delay - except asyncio.CancelledError: - break - except Exception: - pass - - # Start load generators - load_tasks = [ - asyncio.create_task(load_generator()) for _ in range(50) # 50 concurrent workers - ] - - # Monitor resources for 10 seconds - monitor_duration = 10 - sample_interval = 0.5 - samples = int(monitor_duration / sample_interval) - - for _ in range(samples): - await asyncio.sleep(sample_interval) - - memory_mb = process.memory_info().rss / 1024 / 1024 - cpu_percent = process.cpu_percent(interval=None) - - memory_samples.append(memory_mb - baseline_memory) - cpu_samples.append(cpu_percent) - - # Stop load generators - for task in load_tasks: - task.cancel() - await asyncio.gather(*load_tasks, return_exceptions=True) - - # Calculate metrics - avg_memory_increase = statistics.mean(memory_samples) - max_memory_increase = max(memory_samples) - avg_cpu = statistics.mean(cpu_samples) - max(cpu_samples) - - # Verify resource usage - assert ( - avg_memory_increase < 100 - ), f"Average memory increase {avg_memory_increase:.1f}MB exceeds 100MB" - assert ( - max_memory_increase < 200 - ), f"Max memory increase {max_memory_increase:.1f}MB exceeds 200MB" - # CPU thresholds are relaxed as they depend on system - assert avg_cpu < 80, f"Average CPU 
usage {avg_cpu:.1f}% exceeds 80%" - - @pytest.mark.asyncio - async def test_concurrent_operation_isolation(self, benchmark_session): - """ - Benchmark operation isolation under concurrency. - - GIVEN concurrent operations on same data - WHEN operations execute simultaneously - THEN they should not interfere with each other - """ - import uuid - - # Create test record - test_id = uuid.uuid4() - await benchmark_session.execute( - "INSERT INTO concurrency_test (id, data, counter, updated_at) VALUES (?, ?, ?, toTimestamp(now()))", - [test_id, "initial", 0], - ) - - # Prepare statements - update_stmt = await benchmark_session.prepare( - "UPDATE concurrency_test SET counter = counter + 1 WHERE id = ?" - ) - select_stmt = await benchmark_session.prepare( - "SELECT counter FROM concurrency_test WHERE id = ?" - ) - - # Concurrent increment operations - num_increments = 100 - - async def increment_counter(): - """Increment counter (may have race conditions).""" - await benchmark_session.execute(update_stmt, [test_id]) - return True - - # Execute concurrent increments - start_time = time.perf_counter() - - await asyncio.gather(*[increment_counter() for _ in range(num_increments)]) - - duration = time.perf_counter() - start_time - - # Check final value - final_result = await benchmark_session.execute(select_stmt, [test_id]) - final_counter = final_result.one().counter - - # Calculate metrics - throughput = num_increments / duration - - # Note: Due to race conditions, final counter may be less than num_increments - # This is expected behavior without proper synchronization - assert throughput > 100, f"Increment throughput {throughput:.1f} ops/sec too low" - assert final_counter > 0, "Counter should have been incremented" - - @pytest.mark.asyncio - async def test_graceful_degradation_under_overload(self, benchmark_session): - """ - Benchmark system behavior under overload conditions. - - GIVEN more load than system can handle - WHEN system is overloaded - THEN it should degrade gracefully - """ - - # Prepare a complex query - complex_query = """ - SELECT * FROM concurrency_test - WHERE token(id) > token(?) 
- LIMIT 100 - ALLOW FILTERING - """ - - errors = [] - latencies = [] - - async def overload_operation(op_id: int): - """Operation that contributes to overload.""" - import uuid - - start = time.perf_counter() - try: - result = await benchmark_session.execute(complex_query, [uuid.uuid4()]) - # Consume results - count = 0 - async for _ in result: - count += 1 - - latency = time.perf_counter() - start - latencies.append(latency) - return True - - except Exception as e: - errors.append(str(e)) - return False - - # Generate overload with many concurrent operations - num_operations = 500 - - start_time = time.perf_counter() - results = await asyncio.gather( - *[overload_operation(i) for i in range(num_operations)], return_exceptions=True - ) - time.perf_counter() - start_time - - # Calculate metrics - successful = sum(1 for r in results if r is True) - error_rate = len(errors) / num_operations - - if latencies: - statistics.mean(latencies) - p99_latency = statistics.quantiles(latencies, n=100)[98] - else: - float("inf") - p99_latency = float("inf") - - # Even under overload, system should maintain some service - assert ( - successful > num_operations * 0.5 - ), f"Success rate {successful/num_operations:.1%} too low under overload" - assert error_rate < 0.5, f"Error rate {error_rate:.1%} too high" - - # Latencies will be high but should be bounded - assert p99_latency < 5.0, f"P99 latency {p99_latency:.1f}s exceeds 5 second timeout" diff --git a/tests/benchmarks/test_query_performance.py b/tests/benchmarks/test_query_performance.py deleted file mode 100644 index b76e0c2..0000000 --- a/tests/benchmarks/test_query_performance.py +++ /dev/null @@ -1,337 +0,0 @@ -""" -Performance benchmarks for query operations. - -These benchmarks measure latency, throughput, and resource usage -for various query patterns. -""" - -import asyncio -import statistics -import time - -import pytest -import pytest_asyncio - -from async_cassandra import AsyncCassandraSession, AsyncCluster - -from .benchmark_config import BenchmarkConfig - - -@pytest.mark.benchmark -class TestQueryPerformance: - """Benchmarks for query performance.""" - - @pytest_asyncio.fixture - async def benchmark_session(self) -> AsyncCassandraSession: - """Create session for benchmarking.""" - cluster = AsyncCluster( - contact_points=["localhost"], - executor_threads=8, # Optimized for benchmarks - ) - session = await cluster.connect() - - # Create benchmark keyspace and table - await session.execute( - f""" - CREATE KEYSPACE IF NOT EXISTS {BenchmarkConfig.TEST_KEYSPACE} - WITH REPLICATION = {{'class': 'SimpleStrategy', 'replication_factor': 1}} - """ - ) - await session.set_keyspace(BenchmarkConfig.TEST_KEYSPACE) - - await session.execute(f"DROP TABLE IF EXISTS {BenchmarkConfig.TEST_TABLE}") - await session.execute( - f""" - CREATE TABLE {BenchmarkConfig.TEST_TABLE} ( - id INT PRIMARY KEY, - data TEXT, - value DOUBLE, - created_at TIMESTAMP - ) - """ - ) - - # Insert test data - insert_stmt = await session.prepare( - f"INSERT INTO {BenchmarkConfig.TEST_TABLE} (id, data, value, created_at) VALUES (?, ?, ?, toTimestamp(now()))" - ) - - for i in range(BenchmarkConfig.LARGE_DATASET_SIZE): - await session.execute(insert_stmt, [i, f"test_data_{i}", i * 1.5]) - - yield session - - await session.close() - await cluster.shutdown() - - @pytest.mark.asyncio - async def test_single_query_latency(self, benchmark_session): - """ - Benchmark single query latency. 
- - GIVEN a simple query - WHEN executed individually - THEN latency should be within acceptable thresholds - """ - thresholds = BenchmarkConfig.DEFAULT_THRESHOLDS - - # Prepare statement - select_stmt = await benchmark_session.prepare( - f"SELECT * FROM {BenchmarkConfig.TEST_TABLE} WHERE id = ?" - ) - - # Warm up - for i in range(10): - await benchmark_session.execute(select_stmt, [i]) - - # Benchmark - latencies = [] - errors = 0 - - for i in range(100): - start = time.perf_counter() - try: - result = await benchmark_session.execute(select_stmt, [i % 1000]) - _ = result.one() # Force result materialization - latency = time.perf_counter() - start - latencies.append(latency) - except Exception: - errors += 1 - - # Calculate metrics - avg_latency = statistics.mean(latencies) - p95_latency = statistics.quantiles(latencies, n=20)[18] # 95th percentile - p99_latency = statistics.quantiles(latencies, n=100)[98] # 99th percentile - max_latency = max(latencies) - - # Verify thresholds - assert ( - avg_latency < thresholds.single_query_avg - ), f"Average latency {avg_latency:.3f}s exceeds threshold {thresholds.single_query_avg}s" - assert ( - p95_latency < thresholds.single_query_p95 - ), f"P95 latency {p95_latency:.3f}s exceeds threshold {thresholds.single_query_p95}s" - assert ( - p99_latency < thresholds.single_query_p99 - ), f"P99 latency {p99_latency:.3f}s exceeds threshold {thresholds.single_query_p99}s" - assert ( - max_latency < thresholds.single_query_max - ), f"Max latency {max_latency:.3f}s exceeds threshold {thresholds.single_query_max}s" - assert errors == 0, f"Query errors occurred: {errors}" - - @pytest.mark.asyncio - async def test_concurrent_query_throughput(self, benchmark_session): - """ - Benchmark concurrent query throughput. - - GIVEN multiple concurrent queries - WHEN executed with asyncio - THEN throughput should meet minimum requirements - """ - thresholds = BenchmarkConfig.DEFAULT_THRESHOLDS - - # Prepare statement - select_stmt = await benchmark_session.prepare( - f"SELECT * FROM {BenchmarkConfig.TEST_TABLE} WHERE id = ?" - ) - - async def execute_query(query_id: int): - """Execute a single query.""" - try: - result = await benchmark_session.execute(select_stmt, [query_id % 1000]) - _ = result.one() - return True, time.perf_counter() - except Exception: - return False, time.perf_counter() - - # Benchmark concurrent execution - num_queries = 1000 - start_time = time.perf_counter() - - tasks = [execute_query(i) for i in range(num_queries)] - results = await asyncio.gather(*tasks) - - end_time = time.perf_counter() - duration = end_time - start_time - - # Calculate metrics - successful = sum(1 for success, _ in results if success) - throughput = successful / duration - - # Verify thresholds - assert ( - throughput >= thresholds.min_throughput_async - ), f"Throughput {throughput:.1f} qps below threshold {thresholds.min_throughput_async} qps" - assert ( - successful >= num_queries * 0.99 - ), f"Success rate {successful/num_queries:.1%} below 99%" - - @pytest.mark.asyncio - async def test_async_vs_sync_performance(self, benchmark_session): - """ - Benchmark async performance advantage over sync-style execution. - - GIVEN the same workload - WHEN executed async vs sequentially - THEN async should show significant performance improvement - """ - thresholds = BenchmarkConfig.DEFAULT_THRESHOLDS - - # Prepare statement - select_stmt = await benchmark_session.prepare( - f"SELECT * FROM {BenchmarkConfig.TEST_TABLE} WHERE id = ?" 
- ) - - num_queries = 100 - - # Benchmark sequential execution - sync_start = time.perf_counter() - for i in range(num_queries): - result = await benchmark_session.execute(select_stmt, [i]) - _ = result.one() - sync_duration = time.perf_counter() - sync_start - sync_throughput = num_queries / sync_duration - - # Benchmark concurrent execution - async_start = time.perf_counter() - tasks = [] - for i in range(num_queries): - task = benchmark_session.execute(select_stmt, [i]) - tasks.append(task) - await asyncio.gather(*tasks) - async_duration = time.perf_counter() - async_start - async_throughput = num_queries / async_duration - - # Calculate speedup - speedup = async_throughput / sync_throughput - - # Verify thresholds - assert ( - speedup >= thresholds.concurrency_speedup_factor - ), f"Async speedup {speedup:.1f}x below threshold {thresholds.concurrency_speedup_factor}x" - assert ( - async_throughput >= thresholds.min_throughput_async - ), f"Async throughput {async_throughput:.1f} qps below threshold" - - @pytest.mark.asyncio - async def test_query_latency_under_load(self, benchmark_session): - """ - Benchmark query latency under sustained load. - - GIVEN continuous query load - WHEN system is under stress - THEN latency should remain acceptable - """ - thresholds = BenchmarkConfig.DEFAULT_THRESHOLDS - - # Prepare statement - select_stmt = await benchmark_session.prepare( - f"SELECT * FROM {BenchmarkConfig.TEST_TABLE} WHERE id = ?" - ) - - latencies = [] - errors = 0 - - async def query_worker(worker_id: int, duration: float): - """Worker that continuously executes queries.""" - nonlocal errors - worker_latencies = [] - end_time = time.perf_counter() + duration - - while time.perf_counter() < end_time: - start = time.perf_counter() - try: - query_id = int(time.time() * 1000) % 1000 - result = await benchmark_session.execute(select_stmt, [query_id]) - _ = result.one() - latency = time.perf_counter() - start - worker_latencies.append(latency) - except Exception: - errors += 1 - - # Small delay to prevent overwhelming - await asyncio.sleep(0.001) - - return worker_latencies - - # Run workers concurrently for sustained load - num_workers = 50 - test_duration = 10 # seconds - - worker_tasks = [query_worker(i, test_duration) for i in range(num_workers)] - - worker_results = await asyncio.gather(*worker_tasks) - - # Aggregate all latencies - for worker_latencies in worker_results: - latencies.extend(worker_latencies) - - # Calculate metrics - avg_latency = statistics.mean(latencies) - statistics.quantiles(latencies, n=20)[18] - p99_latency = statistics.quantiles(latencies, n=100)[98] - error_rate = errors / len(latencies) if latencies else 1.0 - - # Verify thresholds under load (relaxed) - assert ( - avg_latency < thresholds.single_query_avg * 2 - ), f"Average latency under load {avg_latency:.3f}s exceeds 2x threshold" - assert ( - p99_latency < thresholds.single_query_p99 * 2 - ), f"P99 latency under load {p99_latency:.3f}s exceeds 2x threshold" - assert ( - error_rate < thresholds.max_error_rate - ), f"Error rate {error_rate:.1%} exceeds threshold {thresholds.max_error_rate:.1%}" - - @pytest.mark.asyncio - async def test_prepared_statement_performance(self, benchmark_session): - """ - Benchmark prepared statement performance advantage. 
- - GIVEN queries that can be prepared - WHEN using prepared statements vs simple statements - THEN prepared statements should show performance benefit - """ - num_queries = 500 - - # Benchmark simple statements - simple_latencies = [] - simple_start = time.perf_counter() - - for i in range(num_queries): - query_start = time.perf_counter() - result = await benchmark_session.execute( - f"SELECT * FROM {BenchmarkConfig.TEST_TABLE} WHERE id = {i}" - ) - _ = result.one() - simple_latencies.append(time.perf_counter() - query_start) - - simple_duration = time.perf_counter() - simple_start - - # Benchmark prepared statements - prepared_stmt = await benchmark_session.prepare( - f"SELECT * FROM {BenchmarkConfig.TEST_TABLE} WHERE id = ?" - ) - - prepared_latencies = [] - prepared_start = time.perf_counter() - - for i in range(num_queries): - query_start = time.perf_counter() - result = await benchmark_session.execute(prepared_stmt, [i]) - _ = result.one() - prepared_latencies.append(time.perf_counter() - query_start) - - prepared_duration = time.perf_counter() - prepared_start - - # Calculate metrics - simple_avg = statistics.mean(simple_latencies) - prepared_avg = statistics.mean(prepared_latencies) - performance_gain = (simple_avg - prepared_avg) / simple_avg - - # Verify prepared statements are faster - assert prepared_duration < simple_duration, "Prepared statements should be faster overall" - assert prepared_avg < simple_avg, "Prepared statements should have lower average latency" - assert ( - performance_gain > 0.1 - ), f"Prepared statements should show >10% performance gain, got {performance_gain:.1%}" diff --git a/tests/benchmarks/test_streaming_performance.py b/tests/benchmarks/test_streaming_performance.py deleted file mode 100644 index bbd2f03..0000000 --- a/tests/benchmarks/test_streaming_performance.py +++ /dev/null @@ -1,331 +0,0 @@ -""" -Performance benchmarks for streaming operations. - -These benchmarks ensure streaming provides memory-efficient -data processing without significant performance overhead. 
-""" - -import asyncio -import gc -import os -import statistics -import time - -import psutil -import pytest -import pytest_asyncio - -from async_cassandra import AsyncCassandraSession, AsyncCluster, StreamConfig - -from .benchmark_config import BenchmarkConfig - - -@pytest.mark.benchmark -class TestStreamingPerformance: - """Benchmarks for streaming performance and memory efficiency.""" - - @pytest_asyncio.fixture - async def benchmark_session(self) -> AsyncCassandraSession: - """Create session with large dataset for streaming benchmarks.""" - cluster = AsyncCluster( - contact_points=["localhost"], - executor_threads=8, - ) - session = await cluster.connect() - - # Create benchmark keyspace and table - await session.execute( - f""" - CREATE KEYSPACE IF NOT EXISTS {BenchmarkConfig.TEST_KEYSPACE} - WITH REPLICATION = {{'class': 'SimpleStrategy', 'replication_factor': 1}} - """ - ) - await session.set_keyspace(BenchmarkConfig.TEST_KEYSPACE) - - await session.execute("DROP TABLE IF EXISTS streaming_test") - await session.execute( - """ - CREATE TABLE streaming_test ( - partition_id INT, - row_id INT, - data TEXT, - value DOUBLE, - metadata MAP, - PRIMARY KEY (partition_id, row_id) - ) - """ - ) - - # Insert large dataset across multiple partitions - insert_stmt = await session.prepare( - "INSERT INTO streaming_test (partition_id, row_id, data, value, metadata) VALUES (?, ?, ?, ?, ?)" - ) - - # Create 100 partitions with 1000 rows each = 100k rows - batch_size = 100 - for partition in range(100): - batch = [] - for row in range(1000): - metadata = {f"key_{i}": f"value_{i}" for i in range(5)} - batch.append((partition, row, f"data_{partition}_{row}" * 10, row * 1.5, metadata)) - - # Insert in batches - for i in range(0, len(batch), batch_size): - await asyncio.gather( - *[session.execute(insert_stmt, params) for params in batch[i : i + batch_size]] - ) - - yield session - - await session.close() - await cluster.shutdown() - - @pytest.mark.asyncio - async def test_streaming_memory_efficiency(self, benchmark_session): - """ - Benchmark memory usage of streaming vs regular queries. 
- - GIVEN a large result set - WHEN using streaming vs loading all data - THEN streaming should use significantly less memory - """ - thresholds = BenchmarkConfig.DEFAULT_THRESHOLDS - - # Get process for memory monitoring - process = psutil.Process(os.getpid()) - - # Force garbage collection - gc.collect() - - # Measure baseline memory - process.memory_info().rss / 1024 / 1024 # MB - - # Test 1: Regular query (loads all into memory) - regular_start_memory = process.memory_info().rss / 1024 / 1024 - - regular_result = await benchmark_session.execute("SELECT * FROM streaming_test LIMIT 10000") - regular_rows = [] - async for row in regular_result: - regular_rows.append(row) - - regular_peak_memory = process.memory_info().rss / 1024 / 1024 - regular_memory_used = regular_peak_memory - regular_start_memory - - # Clear memory - del regular_rows - del regular_result - gc.collect() - await asyncio.sleep(0.1) - - # Test 2: Streaming query - stream_start_memory = process.memory_info().rss / 1024 / 1024 - - stream_config = StreamConfig(fetch_size=100, max_pages=None) - stream_result = await benchmark_session.execute_stream( - "SELECT * FROM streaming_test LIMIT 10000", stream_config=stream_config - ) - - row_count = 0 - max_stream_memory = stream_start_memory - - async for row in stream_result: - row_count += 1 - if row_count % 1000 == 0: - current_memory = process.memory_info().rss / 1024 / 1024 - max_stream_memory = max(max_stream_memory, current_memory) - - stream_memory_used = max_stream_memory - stream_start_memory - - # Calculate memory efficiency - memory_ratio = stream_memory_used / regular_memory_used if regular_memory_used > 0 else 0 - - # Verify thresholds - assert ( - memory_ratio < thresholds.streaming_memory_overhead - ), f"Streaming memory ratio {memory_ratio:.2f} exceeds threshold {thresholds.streaming_memory_overhead}" - assert ( - stream_memory_used < regular_memory_used - ), f"Streaming used more memory ({stream_memory_used:.1f}MB) than regular ({regular_memory_used:.1f}MB)" - - @pytest.mark.asyncio - async def test_streaming_throughput(self, benchmark_session): - """ - Benchmark streaming throughput for large datasets. - - GIVEN a large dataset - WHEN processing with streaming - THEN throughput should be acceptable - """ - - stream_config = StreamConfig(fetch_size=1000) - - # Benchmark streaming throughput - start_time = time.perf_counter() - row_count = 0 - - stream_result = await benchmark_session.execute_stream( - "SELECT * FROM streaming_test LIMIT 50000", stream_config=stream_config - ) - - async for row in stream_result: - row_count += 1 - # Simulate minimal processing - _ = row.partition_id + row.row_id - - duration = time.perf_counter() - start_time - throughput = row_count / duration - - # Verify throughput - assert ( - throughput > 10000 - ), f"Streaming throughput {throughput:.0f} rows/sec below minimum 10k rows/sec" - assert row_count == 50000, f"Expected 50000 rows, got {row_count}" - - @pytest.mark.asyncio - async def test_streaming_latency_overhead(self, benchmark_session): - """ - Benchmark latency overhead of streaming vs regular queries. 
- - GIVEN queries of various sizes - WHEN comparing streaming vs regular execution - THEN streaming overhead should be minimal - """ - thresholds = BenchmarkConfig.DEFAULT_THRESHOLDS - - test_sizes = [100, 1000, 5000] - - for size in test_sizes: - # Regular query timing - regular_start = time.perf_counter() - regular_result = await benchmark_session.execute( - f"SELECT * FROM streaming_test LIMIT {size}" - ) - regular_rows = [] - async for row in regular_result: - regular_rows.append(row) - regular_duration = time.perf_counter() - regular_start - - # Streaming query timing - stream_config = StreamConfig(fetch_size=min(100, size)) - stream_start = time.perf_counter() - stream_result = await benchmark_session.execute_stream( - f"SELECT * FROM streaming_test LIMIT {size}", stream_config=stream_config - ) - stream_rows = [] - async for row in stream_result: - stream_rows.append(row) - stream_duration = time.perf_counter() - stream_start - - # Calculate overhead - overhead_ratio = ( - stream_duration / regular_duration if regular_duration > 0 else float("inf") - ) - - # Verify overhead is acceptable - assert ( - overhead_ratio < thresholds.streaming_latency_overhead - ), f"Streaming overhead {overhead_ratio:.2f}x for {size} rows exceeds threshold" - assert len(stream_rows) == len( - regular_rows - ), f"Row count mismatch: streaming={len(stream_rows)}, regular={len(regular_rows)}" - - @pytest.mark.asyncio - async def test_streaming_page_processing_performance(self, benchmark_session): - """ - Benchmark page-by-page processing performance. - - GIVEN streaming with page iteration - WHEN processing pages individually - THEN performance should scale linearly with data size - """ - stream_config = StreamConfig(fetch_size=500, max_pages=100) - - page_latencies = [] - total_rows = 0 - - start_time = time.perf_counter() - - stream_result = await benchmark_session.execute_stream( - "SELECT * FROM streaming_test LIMIT 10000", stream_config=stream_config - ) - - async for page in stream_result.pages(): - page_start = time.perf_counter() - - # Process page - page_rows = 0 - for row in page: - page_rows += 1 - # Simulate processing - _ = row.value * 2 - - page_duration = time.perf_counter() - page_start - page_latencies.append(page_duration) - total_rows += page_rows - - total_duration = time.perf_counter() - start_time - - # Calculate metrics - avg_page_latency = statistics.mean(page_latencies) - page_throughput = len(page_latencies) / total_duration - row_throughput = total_rows / total_duration - - # Verify performance - assert ( - avg_page_latency < 0.1 - ), f"Average page processing time {avg_page_latency:.3f}s exceeds 100ms" - assert ( - page_throughput > 10 - ), f"Page throughput {page_throughput:.1f} pages/sec below minimum" - assert row_throughput > 5000, f"Row throughput {row_throughput:.0f} rows/sec below minimum" - - @pytest.mark.asyncio - async def test_concurrent_streaming_operations(self, benchmark_session): - """ - Benchmark concurrent streaming operations. 
- - GIVEN multiple concurrent streaming queries - WHEN executed simultaneously - THEN system should handle them efficiently - """ - - async def stream_worker(worker_id: int): - """Worker that processes a streaming query.""" - stream_config = StreamConfig(fetch_size=100) - - start = time.perf_counter() - row_count = 0 - - # Each worker queries different partition - stream_result = await benchmark_session.execute_stream( - f"SELECT * FROM streaming_test WHERE partition_id = {worker_id} LIMIT 1000", - stream_config=stream_config, - ) - - async for row in stream_result: - row_count += 1 - - duration = time.perf_counter() - start - return duration, row_count - - # Run concurrent streaming operations - num_workers = 10 - start_time = time.perf_counter() - - results = await asyncio.gather(*[stream_worker(i) for i in range(num_workers)]) - - total_duration = time.perf_counter() - start_time - - # Calculate metrics - worker_durations = [d for d, _ in results] - total_rows = sum(count for _, count in results) - avg_worker_duration = statistics.mean(worker_durations) - - # Verify concurrent performance - assert ( - total_duration < avg_worker_duration * 2 - ), "Concurrent streams should show parallelism benefit" - assert all( - count >= 900 for _, count in results - ), "All workers should process most of their rows" - assert total_rows >= num_workers * 900, f"Total rows {total_rows} below expected minimum" diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 732bf5a..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Pytest configuration and shared fixtures for all tests. -""" - -import asyncio -from unittest.mock import patch - -import pytest - - -@pytest.fixture(scope="session") -def event_loop(): - """Create an instance of the default event loop for the test session.""" - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() - - -@pytest.fixture(autouse=True) -def fast_shutdown_for_unit_tests(request): - """Mock the 5-second sleep in cluster shutdown for unit tests only.""" - # Skip for tests that need real timing - skip_tests = [ - "test_simplified_threading", - "test_timeout_implementation", - "test_protocol_version_bdd", - ] - - # Check if this test should be skipped - should_skip = any(skip_test in request.node.nodeid for skip_test in skip_tests) - - # Only apply to unit tests and BDD tests, not integration tests - if not should_skip and ( - "unit" in request.node.nodeid - or "_core" in request.node.nodeid - or "_features" in request.node.nodeid - or "_resilience" in request.node.nodeid - or "bdd" in request.node.nodeid - ): - # Store the original sleep function - original_sleep = asyncio.sleep - - async def mock_sleep(seconds): - # For the 5-second shutdown sleep, make it instant - if seconds == 5.0: - return - # For other sleeps, use a much shorter delay but use the original function - await original_sleep(min(seconds, 0.01)) - - with patch("asyncio.sleep", side_effect=mock_sleep): - yield - else: - # For integration tests or skipped tests, don't mock - yield diff --git a/tests/fastapi_integration/conftest.py b/tests/fastapi_integration/conftest.py deleted file mode 100644 index f59e76c..0000000 --- a/tests/fastapi_integration/conftest.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -Pytest configuration for FastAPI example app tests. 
-""" - -import sys -from pathlib import Path - -import httpx -import pytest -import pytest_asyncio -from httpx import ASGITransport - -# Add parent directories to path -fastapi_app_dir = Path(__file__).parent.parent.parent / "examples" / "fastapi_app" -sys.path.insert(0, str(fastapi_app_dir)) # fastapi_app dir -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) # project root - -# Import test utils -from tests.test_utils import ( # noqa: E402 - cleanup_keyspace, - create_test_keyspace, - generate_unique_keyspace, -) - -# Note: We don't import cassandra_container here to avoid conflicts with integration tests - - -@pytest.fixture(scope="session") -def cassandra_container(): - """Provide access to the running Cassandra container.""" - import subprocess - - # Find running container on port 9042 - for runtime in ["podman", "docker"]: - try: - result = subprocess.run( - [runtime, "ps", "--format", "{{.Names}} {{.Ports}}"], - capture_output=True, - text=True, - ) - if result.returncode == 0: - for line in result.stdout.strip().split("\n"): - if "9042" in line: - container_name = line.split()[0] - - # Create a simple container object - class Container: - def __init__(self, name, runtime_cmd): - self.container_name = name - self.runtime = runtime_cmd - - def check_health(self): - # Run nodetool info - result = subprocess.run( - [self.runtime, "exec", self.container_name, "nodetool", "info"], - capture_output=True, - text=True, - ) - - health_status = { - "native_transport": False, - "gossip": False, - "cql_available": False, - } - - if result.returncode == 0: - info = result.stdout - health_status["native_transport"] = ( - "Native Transport active: true" in info - ) - health_status["gossip"] = ( - "Gossip active" in info - and "true" in info.split("Gossip active")[1].split("\n")[0] - ) - - # Check CQL availability - cql_result = subprocess.run( - [ - self.runtime, - "exec", - self.container_name, - "cqlsh", - "-e", - "SELECT now() FROM system.local", - ], - capture_output=True, - ) - health_status["cql_available"] = cql_result.returncode == 0 - - return health_status - - return Container(container_name, runtime) - except Exception: - pass - - pytest.fail("No Cassandra container found running on port 9042") - - -@pytest_asyncio.fixture -async def unique_test_keyspace(cassandra_container): # noqa: F811 - """Create a unique keyspace for each test.""" - from async_cassandra import AsyncCluster - - # Check health before proceeding - health = cassandra_container.check_health() - if not health["native_transport"] or not health["cql_available"]: - pytest.fail(f"Cassandra not healthy: {health}") - - cluster = AsyncCluster(contact_points=["localhost"], protocol_version=5) - session = await cluster.connect() - - # Create unique keyspace - keyspace = generate_unique_keyspace("fastapi_test") - await create_test_keyspace(session, keyspace) - - yield keyspace - - # Cleanup - await cleanup_keyspace(session, keyspace) - await session.close() - await cluster.shutdown() - - -@pytest_asyncio.fixture -async def app_client(unique_test_keyspace): - """Create test client for the FastAPI app with isolated keyspace.""" - # First, check that Cassandra is available - from async_cassandra import AsyncCluster - - try: - test_cluster = AsyncCluster(contact_points=["localhost"], protocol_version=5) - test_session = await test_cluster.connect() - await test_session.execute("SELECT now() FROM system.local") - await test_session.close() - await test_cluster.shutdown() - except Exception as e: - pytest.fail(f"Cassandra not 
available: {e}") - - # Set the test keyspace in environment - import os - - os.environ["TEST_KEYSPACE"] = unique_test_keyspace - - from main import app, lifespan - - # Manually handle lifespan since httpx doesn't do it properly - async with lifespan(app): - transport = ASGITransport(app=app) - async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - yield client - - # Clean up environment - os.environ.pop("TEST_KEYSPACE", None) - - -@pytest.fixture(scope="function", autouse=True) -async def ensure_cassandra_healthy_fastapi(cassandra_container): - """Ensure Cassandra is healthy before each FastAPI test.""" - # Check health before test - health = cassandra_container.check_health() - if not health["native_transport"] or not health["cql_available"]: - # Try to wait a bit and check again - import asyncio - - await asyncio.sleep(2) - health = cassandra_container.check_health() - if not health["native_transport"] or not health["cql_available"]: - pytest.fail(f"Cassandra not healthy before test: {health}") - - yield - - # Optional: Check health after test - health = cassandra_container.check_health() - if not health["native_transport"]: - print(f"Warning: Cassandra health degraded after test: {health}") diff --git a/tests/fastapi_integration/test_fastapi_advanced.py b/tests/fastapi_integration/test_fastapi_advanced.py deleted file mode 100644 index 966dafb..0000000 --- a/tests/fastapi_integration/test_fastapi_advanced.py +++ /dev/null @@ -1,550 +0,0 @@ -""" -Advanced integration tests for FastAPI with async-cassandra. - -These tests cover edge cases, error conditions, and advanced scenarios -that the basic tests don't cover, following TDD principles. -""" - -import gc -import os -import platform -import threading -import time -import uuid -from concurrent.futures import ThreadPoolExecutor - -import psutil # Required dependency for advanced testing -import pytest -from fastapi.testclient import TestClient - - -@pytest.mark.integration -class TestFastAPIAdvancedScenarios: - """Advanced test scenarios for FastAPI integration.""" - - @pytest.fixture - def test_client(self): - """Create FastAPI test client.""" - from examples.fastapi_app.main import app - - with TestClient(app) as client: - yield client - - @pytest.fixture - def monitor_resources(self): - """Monitor system resources during tests.""" - process = psutil.Process(os.getpid()) - initial_memory = process.memory_info().rss / 1024 / 1024 # MB - initial_threads = threading.active_count() - initial_fds = len(process.open_files()) if platform.system() != "Windows" else 0 - - yield { - "initial_memory": initial_memory, - "initial_threads": initial_threads, - "initial_fds": initial_fds, - "process": process, - } - - # Cleanup - gc.collect() - - def test_memory_leak_detection_in_streaming(self, test_client, monitor_resources): - """ - GIVEN a streaming endpoint processing large datasets - WHEN multiple streaming operations are performed - THEN memory usage should not continuously increase (no leaks) - """ - process = monitor_resources["process"] - initial_memory = monitor_resources["initial_memory"] - - # Create test data - for i in range(1000): - user_data = {"name": f"leak_test_user_{i}", "email": f"leak{i}@example.com", "age": 25} - test_client.post("/users", json=user_data) - - memory_readings = [] - - # Perform multiple streaming operations - for iteration in range(5): - # Stream data - response = test_client.get("/users/stream/pages?limit=1000&fetch_size=100") - assert response.status_code == 200 - - # Force 
garbage collection - gc.collect() - time.sleep(0.1) - - # Record memory usage - current_memory = process.memory_info().rss / 1024 / 1024 - memory_readings.append(current_memory) - - # Check for memory leak - # Memory should stabilize, not continuously increase - memory_increase = max(memory_readings) - initial_memory - assert memory_increase < 50, f"Memory increased by {memory_increase}MB, possible leak" - - # Check that memory readings stabilize (not continuously increasing) - last_three = memory_readings[-3:] - variance = max(last_three) - min(last_three) - assert variance < 10, f"Memory not stabilizing, variance: {variance}MB" - - def test_thread_safety_with_concurrent_operations(self, test_client, monitor_resources): - """ - GIVEN multiple threads performing database operations - WHEN operations access shared resources - THEN no race conditions or thread safety issues should occur - """ - initial_threads = monitor_resources["initial_threads"] - results = {"errors": [], "success_count": 0} - - def perform_mixed_operations(thread_id): - try: - # Create user - user_data = { - "name": f"thread_{thread_id}_user", - "email": f"thread{thread_id}@example.com", - "age": 20 + thread_id, - } - create_resp = test_client.post("/users", json=user_data) - if create_resp.status_code != 201: - results["errors"].append(f"Thread {thread_id}: Create failed") - return - - user_id = create_resp.json()["id"] - - # Read user multiple times - for _ in range(5): - read_resp = test_client.get(f"/users/{user_id}") - if read_resp.status_code != 200: - results["errors"].append(f"Thread {thread_id}: Read failed") - - # Update user - update_data = {"age": 30 + thread_id} - update_resp = test_client.patch(f"/users/{user_id}", json=update_data) - if update_resp.status_code != 200: - results["errors"].append(f"Thread {thread_id}: Update failed") - - # Delete user - delete_resp = test_client.delete(f"/users/{user_id}") - if delete_resp.status_code != 204: - results["errors"].append(f"Thread {thread_id}: Delete failed") - - results["success_count"] += 1 - - except Exception as e: - results["errors"].append(f"Thread {thread_id}: {str(e)}") - - # Run operations in multiple threads - with ThreadPoolExecutor(max_workers=20) as executor: - futures = [executor.submit(perform_mixed_operations, i) for i in range(50)] - for future in futures: - future.result() - - # Verify results - assert len(results["errors"]) == 0, f"Thread safety errors: {results['errors']}" - assert results["success_count"] == 50 - - # Check thread count didn't explode - final_threads = threading.active_count() - thread_increase = final_threads - initial_threads - assert thread_increase < 25, f"Too many threads created: {thread_increase}" - - def test_connection_failure_and_recovery(self, test_client): - """ - GIVEN a Cassandra connection that can fail - WHEN connection failures occur - THEN the application should handle them gracefully and recover - """ - # First, verify normal operation - response = test_client.get("/health") - assert response.status_code == 200 - - # Simulate connection failure by attempting operations that might fail - # This would need mock support or actual connection manipulation - # For now, test error handling paths - - # Test handling of various scenarios - # Since this is integration test and we don't want to break the real connection, - # we'll test that the system remains stable after various operations - - # Test with large limit - response = test_client.get("/users?limit=1000") - assert response.status_code == 200 - - # Test 
invalid UUID handling - response = test_client.get("/users/invalid-uuid") - assert response.status_code == 400 - - # Test non-existent user - response = test_client.get(f"/users/{uuid.uuid4()}") - assert response.status_code == 404 - - # Verify system still healthy after various errors - health_response = test_client.get("/health") - assert health_response.status_code == 200 - - def test_prepared_statement_lifecycle_and_caching(self, test_client): - """ - GIVEN prepared statements used in queries - WHEN statements are prepared and reused - THEN they should be properly cached and managed - """ - # Create users with same structure to test prepared statement reuse - execution_times = [] - - for i in range(20): - start_time = time.time() - - user_data = {"name": f"ps_test_user_{i}", "email": f"ps{i}@example.com", "age": 25} - response = test_client.post("/users", json=user_data) - assert response.status_code == 201 - - execution_time = time.time() - start_time - execution_times.append(execution_time) - - # First execution might be slower (preparing statement) - # Subsequent executions should be faster - avg_first_5 = sum(execution_times[:5]) / 5 - avg_last_5 = sum(execution_times[-5:]) / 5 - - # Later executions should be at least as fast (allowing some variance) - assert avg_last_5 <= avg_first_5 * 1.5 - - def test_query_cancellation_and_timeout_behavior(self, test_client): - """ - GIVEN long-running queries - WHEN queries are cancelled or timeout - THEN resources should be properly cleaned up - """ - # Test with the slow_query endpoint - - # Test timeout behavior with a short timeout header - response = test_client.get("/slow_query", headers={"X-Request-Timeout": "0.5"}) - # Should return timeout error - assert response.status_code == 504 - - # Verify system still healthy after timeout - health_response = test_client.get("/health") - assert health_response.status_code == 200 - - # Test normal query still works - response = test_client.get("/users?limit=10") - assert response.status_code == 200 - - def test_paging_state_handling(self, test_client): - """ - GIVEN paginated query results - WHEN paging through large result sets - THEN paging state should be properly managed - """ - # Create enough data for multiple pages - for i in range(250): - user_data = { - "name": f"paging_user_{i}", - "email": f"page{i}@example.com", - "age": 20 + (i % 60), - } - test_client.post("/users", json=user_data) - - # Test paging through results - page_count = 0 - - # Stream pages and collect results - response = test_client.get("/users/stream/pages?limit=250&fetch_size=50&max_pages=10") - assert response.status_code == 200 - - data = response.json() - assert "pages_info" in data - assert len(data["pages_info"]) >= 5 # Should have at least 5 pages - - # Verify each page has expected structure - for page_info in data["pages_info"]: - assert "page_number" in page_info - assert "rows_in_page" in page_info - assert page_info["rows_in_page"] <= 50 # Respects fetch_size - page_count += 1 - - assert page_count >= 5 - - def test_connection_pool_exhaustion_and_queueing(self, test_client): - """ - GIVEN limited connection pool - WHEN pool is exhausted - THEN requests should queue and eventually succeed - """ - start_time = time.time() - results = [] - - def make_slow_request(i): - # Each request might take some time - resp = test_client.get("/performance/sync?requests=10") - return resp.status_code, time.time() - start_time - - # Flood with requests to exhaust pool - with ThreadPoolExecutor(max_workers=50) as executor: - 
futures = [executor.submit(make_slow_request, i) for i in range(100)] - results = [f.result() for f in futures] - - # All requests should eventually succeed - statuses = [r[0] for r in results] - assert all(status == 200 for status in statuses) - - # Check timing - verify some spread in completion times - completion_times = [r[1] for r in results] - # There should be some variance in completion times - time_spread = max(completion_times) - min(completion_times) - assert time_spread > 0.05, f"Expected some time variance, got {time_spread}s" - - def test_error_propagation_through_async_layers(self, test_client): - """ - GIVEN various error conditions at different layers - WHEN errors occur in Cassandra operations - THEN they should propagate correctly through async layers - """ - # Test different error scenarios - error_scenarios = [ - # Invalid query parameter (non-numeric limit) - ("/users?limit=invalid", 422), # FastAPI validation - # Non-existent path - ("/users/../../etc/passwd", 404), # Path not found - # Invalid JSON - need to use proper API call format - ("/users", 422, "post", "invalid json"), - ] - - for scenario in error_scenarios: - if len(scenario) == 2: - # GET request - response = test_client.get(scenario[0]) - assert response.status_code == scenario[1] - else: - # POST request with invalid data - response = test_client.post(scenario[0], data=scenario[3]) - assert response.status_code == scenario[1] - - def test_async_context_cleanup_on_exceptions(self, test_client): - """ - GIVEN async context managers in use - WHEN exceptions occur during operations - THEN contexts should be properly cleaned up - """ - # Perform operations that might fail - for i in range(10): - if i % 3 == 0: - # Valid operation - response = test_client.get("/users") - assert response.status_code == 200 - elif i % 3 == 1: - # Operation that causes client error - response = test_client.get("/users/not-a-uuid") - assert response.status_code == 400 - else: - # Operation that might cause server error - response = test_client.post("/users", json={}) - assert response.status_code == 422 - - # System should still be healthy - health = test_client.get("/health") - assert health.status_code == 200 - - def test_streaming_memory_efficiency(self, test_client): - """ - GIVEN large result sets - WHEN streaming vs loading all at once - THEN streaming should use significantly less memory - """ - # Create large dataset - created_count = 0 - for i in range(500): - user_data = { - "name": f"stream_efficiency_user_{i}", - "email": f"efficiency{i}@example.com", - "age": 25, - } - resp = test_client.post("/users", json=user_data) - if resp.status_code == 201: - created_count += 1 - - assert created_count >= 500 - - # Compare memory usage between streaming and non-streaming - process = psutil.Process(os.getpid()) - - # Non-streaming (loads all) - gc.collect() - mem_before_regular = process.memory_info().rss / 1024 / 1024 - regular_response = test_client.get("/users?limit=500") - assert regular_response.status_code == 200 - regular_data = regular_response.json() - mem_after_regular = process.memory_info().rss / 1024 / 1024 - mem_after_regular - mem_before_regular - - # Streaming (should use less memory) - gc.collect() - mem_before_stream = process.memory_info().rss / 1024 / 1024 - stream_response = test_client.get("/users/stream?limit=500&fetch_size=50") - assert stream_response.status_code == 200 - stream_data = stream_response.json() - mem_after_stream = process.memory_info().rss / 1024 / 1024 - mem_after_stream - 
mem_before_stream - - # Streaming should use less memory (allow some variance) - # This might not always be true for small datasets, but the pattern is important - assert len(regular_data) > 0 - assert len(stream_data["users"]) > 0 - - def test_monitoring_metrics_accuracy(self, test_client): - """ - GIVEN operations being performed - WHEN metrics are collected - THEN metrics should accurately reflect operations - """ - # Reset metrics (would need endpoint) - # Perform known operations - operations = {"creates": 5, "reads": 10, "updates": 3, "deletes": 2} - - created_ids = [] - - # Create - for i in range(operations["creates"]): - resp = test_client.post( - "/users", - json={"name": f"metrics_user_{i}", "email": f"metrics{i}@example.com", "age": 25}, - ) - if resp.status_code == 201: - created_ids.append(resp.json()["id"]) - - # Read - for _ in range(operations["reads"]): - test_client.get("/users") - - # Update - for i in range(min(operations["updates"], len(created_ids))): - test_client.patch(f"/users/{created_ids[i]}", json={"age": 30}) - - # Delete - for i in range(min(operations["deletes"], len(created_ids))): - test_client.delete(f"/users/{created_ids[i]}") - - # Check metrics (would need metrics endpoint) - # For now, just verify operations succeeded - assert len(created_ids) == operations["creates"] - - def test_graceful_degradation_under_load(self, test_client): - """ - GIVEN system under heavy load - WHEN load exceeds capacity - THEN system should degrade gracefully, not crash - """ - successful_requests = 0 - failed_requests = 0 - response_times = [] - - def make_request(i): - try: - start = time.time() - resp = test_client.get("/users") - elapsed = time.time() - start - - if resp.status_code == 200: - return "success", elapsed - else: - return "failed", elapsed - except Exception: - return "error", 0 - - # Generate high load - with ThreadPoolExecutor(max_workers=100) as executor: - futures = [executor.submit(make_request, i) for i in range(500)] - results = [f.result() for f in futures] - - for status, elapsed in results: - if status == "success": - successful_requests += 1 - response_times.append(elapsed) - else: - failed_requests += 1 - - # System should handle most requests - success_rate = successful_requests / (successful_requests + failed_requests) - assert success_rate > 0.8, f"Success rate too low: {success_rate}" - - # Response times should be reasonable - if response_times: - avg_response_time = sum(response_times) / len(response_times) - assert avg_response_time < 5.0, f"Average response time too high: {avg_response_time}s" - - def test_event_loop_integration_patterns(self, test_client): - """ - GIVEN FastAPI's event loop - WHEN integrated with Cassandra driver callbacks - THEN operations should not block the event loop - """ - # Test that multiple concurrent requests work properly - # Start a potentially slow operation - import threading - import time - - slow_response = None - quick_responses = [] - - def slow_request(): - nonlocal slow_response - slow_response = test_client.get("/performance/sync?requests=20") - - def quick_request(i): - response = test_client.get("/health") - quick_responses.append(response) - - # Start slow request in background - slow_thread = threading.Thread(target=slow_request) - slow_thread.start() - - # Give it a moment to start - time.sleep(0.1) - - # Make quick requests - quick_threads = [] - for i in range(5): - t = threading.Thread(target=quick_request, args=(i,)) - quick_threads.append(t) - t.start() - - # Wait for all threads - for t 
in quick_threads: - t.join(timeout=1.0) - slow_thread.join(timeout=5.0) - - # Verify results - assert len(quick_responses) == 5 - assert all(r.status_code == 200 for r in quick_responses) - assert slow_response is not None and slow_response.status_code == 200 - - @pytest.mark.parametrize( - "failure_point", ["before_prepare", "after_prepare", "during_execute", "during_fetch"] - ) - def test_failure_recovery_at_different_stages(self, test_client, failure_point): - """ - GIVEN failures at different stages of query execution - WHEN failures occur - THEN system should recover appropriately - """ - # This would require more sophisticated mocking or test hooks - # For now, test that system remains stable after various operations - - if failure_point == "before_prepare": - # Test with invalid query that fails during preparation - # Would need custom endpoint - pass - elif failure_point == "after_prepare": - # Test with valid prepare but execution failure - pass - elif failure_point == "during_execute": - # Test timeout during execution - pass - elif failure_point == "during_fetch": - # Test failure while fetching pages - pass - - # Verify system health after failure scenario - response = test_client.get("/health") - assert response.status_code == 200 diff --git a/tests/fastapi_integration/test_fastapi_app.py b/tests/fastapi_integration/test_fastapi_app.py deleted file mode 100644 index d5f59a7..0000000 --- a/tests/fastapi_integration/test_fastapi_app.py +++ /dev/null @@ -1,422 +0,0 @@ -""" -Comprehensive test suite for the FastAPI example application. - -This validates that the example properly demonstrates all the -improvements made to the async-cassandra library. -""" - -import asyncio -import os -import time -import uuid - -import httpx -import pytest -import pytest_asyncio -from httpx import ASGITransport - - -class TestFastAPIExample: - """Test suite for FastAPI example application.""" - - @pytest_asyncio.fixture - async def app_client(self): - """Create test client for the FastAPI app.""" - # First, check that Cassandra is available - from async_cassandra import AsyncCluster - - try: - test_cluster = AsyncCluster(contact_points=["localhost"]) - test_session = await test_cluster.connect() - await test_session.execute("SELECT now() FROM system.local") - await test_session.close() - await test_cluster.shutdown() - except Exception as e: - pytest.fail(f"Cassandra not available: {e}") - - from main import app, lifespan - - # Manually handle lifespan since httpx doesn't do it properly - async with lifespan(app): - transport = ASGITransport(app=app) - async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: - yield client - - @pytest.mark.asyncio - async def test_health_and_basic_operations(self, app_client): - """Test health check and basic CRUD operations.""" - print("\n=== Testing Health and Basic Operations ===") - - # Health check - health_resp = await app_client.get("/health") - assert health_resp.status_code == 200 - assert health_resp.json()["status"] == "healthy" - print("✓ Health check passed") - - # Create user - user_data = {"name": "Test User", "email": "test@example.com", "age": 30} - create_resp = await app_client.post("/users", json=user_data) - assert create_resp.status_code == 201 - user = create_resp.json() - print(f"✓ Created user: {user['id']}") - - # Get user - get_resp = await app_client.get(f"/users/{user['id']}") - assert get_resp.status_code == 200 - assert get_resp.json()["name"] == user_data["name"] - print("✓ Retrieved user successfully") - 
- # Update user - update_data = {"age": 31} - update_resp = await app_client.put(f"/users/{user['id']}", json=update_data) - assert update_resp.status_code == 200 - assert update_resp.json()["age"] == 31 - print("✓ Updated user successfully") - - # Delete user - delete_resp = await app_client.delete(f"/users/{user['id']}") - assert delete_resp.status_code == 204 - print("✓ Deleted user successfully") - - @pytest.mark.asyncio - async def test_thread_safety_under_concurrency(self, app_client): - """Test thread safety improvements with concurrent operations.""" - print("\n=== Testing Thread Safety Under Concurrency ===") - - async def create_and_read_user(user_id: int): - """Create a user and immediately read it back.""" - # Create - user_data = { - "name": f"Concurrent User {user_id}", - "email": f"concurrent{user_id}@test.com", - "age": 25 + (user_id % 10), - } - create_resp = await app_client.post("/users", json=user_data) - if create_resp.status_code != 201: - return None - - created_user = create_resp.json() - - # Immediately read back - get_resp = await app_client.get(f"/users/{created_user['id']}") - if get_resp.status_code != 200: - return None - - return get_resp.json() - - # Run many concurrent operations - num_concurrent = 50 - start_time = time.time() - - results = await asyncio.gather( - *[create_and_read_user(i) for i in range(num_concurrent)], return_exceptions=True - ) - - duration = time.time() - start_time - - # Check results - successful = [r for r in results if isinstance(r, dict)] - errors = [r for r in results if isinstance(r, Exception)] - - print(f"✓ Completed {num_concurrent} concurrent operations in {duration:.2f}s") - print(f" - Successful: {len(successful)}") - print(f" - Errors: {len(errors)}") - - # Thread safety should ensure high success rate - assert len(successful) >= num_concurrent * 0.95 # 95% success rate - - # Verify data consistency - for user in successful: - assert "id" in user - assert "name" in user - assert user["created_at"] is not None - - @pytest.mark.asyncio - async def test_streaming_memory_efficiency(self, app_client): - """Test streaming functionality for memory efficiency.""" - print("\n=== Testing Streaming Memory Efficiency ===") - - # Create a batch of users for streaming - batch_size = 100 - batch_data = { - "users": [ - {"name": f"Stream Test {i}", "email": f"stream{i}@test.com", "age": 20 + (i % 50)} - for i in range(batch_size) - ] - } - - batch_resp = await app_client.post("/users/batch", json=batch_data) - assert batch_resp.status_code == 201 - print(f"✓ Created {batch_size} users for streaming test") - - # Test regular streaming - stream_resp = await app_client.get(f"/users/stream?limit={batch_size}&fetch_size=10") - assert stream_resp.status_code == 200 - stream_data = stream_resp.json() - - assert stream_data["metadata"]["streaming_enabled"] is True - assert stream_data["metadata"]["pages_fetched"] > 1 - assert len(stream_data["users"]) >= batch_size - print( - f"✓ Streamed {len(stream_data['users'])} users in {stream_data['metadata']['pages_fetched']} pages" - ) - - # Test page-by-page streaming - pages_resp = await app_client.get( - f"/users/stream/pages?limit={batch_size}&fetch_size=10&max_pages=5" - ) - assert pages_resp.status_code == 200 - pages_data = pages_resp.json() - - assert pages_data["metadata"]["streaming_mode"] == "page_by_page" - assert len(pages_data["pages_info"]) <= 5 - print( - f"✓ Page-by-page streaming: {pages_data['total_rows_processed']} rows in {len(pages_data['pages_info'])} pages" - ) - - 
@pytest.mark.asyncio - async def test_error_handling_consistency(self, app_client): - """Test error handling improvements.""" - print("\n=== Testing Error Handling Consistency ===") - - # Test invalid UUID handling - invalid_uuid_resp = await app_client.get("/users/not-a-uuid") - assert invalid_uuid_resp.status_code == 400 - assert "Invalid UUID" in invalid_uuid_resp.json()["detail"] - print("✓ Invalid UUID error handled correctly") - - # Test non-existent resource - fake_uuid = str(uuid.uuid4()) - not_found_resp = await app_client.get(f"/users/{fake_uuid}") - assert not_found_resp.status_code == 404 - assert "User not found" in not_found_resp.json()["detail"] - print("✓ Resource not found error handled correctly") - - # Test validation errors - missing required field - invalid_user_resp = await app_client.post( - "/users", json={"name": "Test"} # Missing email and age - ) - assert invalid_user_resp.status_code == 422 - print("✓ Validation error handled correctly") - - # Test streaming with invalid parameters - invalid_stream_resp = await app_client.get("/users/stream?fetch_size=0") - assert invalid_stream_resp.status_code == 422 - print("✓ Streaming parameter validation working") - - @pytest.mark.asyncio - async def test_performance_comparison(self, app_client): - """Test performance endpoints to validate async benefits.""" - print("\n=== Testing Performance Comparison ===") - - # Compare async vs sync performance - num_requests = 50 - - # Test async performance - async_resp = await app_client.get(f"/performance/async?requests={num_requests}") - assert async_resp.status_code == 200 - async_data = async_resp.json() - - # Test sync performance - sync_resp = await app_client.get(f"/performance/sync?requests={num_requests}") - assert sync_resp.status_code == 200 - sync_data = sync_resp.json() - - print(f"✓ Async performance: {async_data['requests_per_second']:.1f} req/s") - print(f"✓ Sync performance: {sync_data['requests_per_second']:.1f} req/s") - print( - f"✓ Speedup factor: {async_data['requests_per_second'] / sync_data['requests_per_second']:.1f}x" - ) - - # Skip performance comparison in CI environments - if os.getenv("CI") != "true": - # Async should be significantly faster - assert async_data["requests_per_second"] > sync_data["requests_per_second"] - else: - # In CI, just verify both completed successfully - assert async_data["requests"] == num_requests - assert sync_data["requests"] == num_requests - assert async_data["requests_per_second"] > 0 - assert sync_data["requests_per_second"] > 0 - - @pytest.mark.asyncio - async def test_monitoring_endpoints(self, app_client): - """Test monitoring and metrics endpoints.""" - print("\n=== Testing Monitoring Endpoints ===") - - # Test metrics endpoint - metrics_resp = await app_client.get("/metrics") - assert metrics_resp.status_code == 200 - metrics = metrics_resp.json() - - assert "query_performance" in metrics - assert "cassandra_connections" in metrics - print("✓ Metrics endpoint working") - - # Test shutdown endpoint - shutdown_resp = await app_client.post("/shutdown") - assert shutdown_resp.status_code == 200 - assert "Shutdown initiated" in shutdown_resp.json()["message"] - print("✓ Shutdown endpoint working") - - @pytest.mark.asyncio - async def test_timeout_handling(self, app_client): - """Test timeout handling capabilities.""" - print("\n=== Testing Timeout Handling ===") - - # Test with short timeout (should timeout) - timeout_resp = await app_client.get("/slow_query", headers={"X-Request-Timeout": "0.1"}) - assert 
timeout_resp.status_code == 504 - print("✓ Short timeout handled correctly") - - # Test with adequate timeout - success_resp = await app_client.get("/slow_query", headers={"X-Request-Timeout": "10"}) - assert success_resp.status_code == 200 - print("✓ Adequate timeout allows completion") - - @pytest.mark.asyncio - async def test_context_manager_safety(self, app_client): - """Test comprehensive context manager safety in FastAPI.""" - print("\n=== Testing Context Manager Safety ===") - - # Get initial status - status = await app_client.get("/context_manager_safety/status") - assert status.status_code == 200 - initial_state = status.json() - print( - f"✓ Initial state: Session={initial_state['session_open']}, Cluster={initial_state['cluster_open']}" - ) - - # Test 1: Query errors don't close session - print("\nTest 1: Query Error Safety") - query_error_resp = await app_client.post("/context_manager_safety/query_error") - assert query_error_resp.status_code == 200 - query_result = query_error_resp.json() - assert query_result["session_unchanged"] is True - assert query_result["session_open"] is True - assert query_result["session_still_works"] is True - assert "non_existent_table_xyz" in query_result["error_caught"] - print("✓ Query errors don't close session") - print(f" - Error caught: {query_result['error_caught'][:50]}...") - print(f" - Session still works: {query_result['session_still_works']}") - - # Test 2: Streaming errors don't close session - print("\nTest 2: Streaming Error Safety") - stream_error_resp = await app_client.post("/context_manager_safety/streaming_error") - assert stream_error_resp.status_code == 200 - stream_result = stream_error_resp.json() - assert stream_result["session_unchanged"] is True - assert stream_result["session_open"] is True - assert stream_result["streaming_error_caught"] is True - # The session_still_streams might be False if no users exist, but session should work - if not stream_result["session_still_streams"]: - print(f" - Note: No users found ({stream_result['rows_after_error']} rows)") - # Create a user for subsequent tests - user_resp = await app_client.post( - "/users", json={"name": "Test User", "email": "test@example.com", "age": 30} - ) - assert user_resp.status_code == 201 - print("✓ Streaming errors don't close session") - print(f" - Error caught: {stream_result['error_message'][:50]}...") - print(f" - Session remains open: {stream_result['session_open']}") - - # Test 3: Concurrent streams don't interfere - print("\nTest 3: Concurrent Streams Safety") - concurrent_resp = await app_client.post("/context_manager_safety/concurrent_streams") - assert concurrent_resp.status_code == 200 - concurrent_result = concurrent_resp.json() - print(f" - Debug: Results = {concurrent_result['results']}") - assert concurrent_result["streams_completed"] == 3 - # Check if streams worked independently (each should have 10 users) - if not concurrent_result["all_streams_independent"]: - print( - f" - Warning: Stream counts varied: {[r['count'] for r in concurrent_result['results']]}" - ) - assert concurrent_result["session_still_open"] is True - print("✓ Concurrent streams completed") - for result in concurrent_result["results"]: - print(f" - Age {result['age']}: {result['count']} users") - - # Test 4: Nested context managers - print("\nTest 4: Nested Context Managers") - nested_resp = await app_client.post("/context_manager_safety/nested_contexts") - assert nested_resp.status_code == 200 - nested_result = nested_resp.json() - assert nested_result["correct_order"] 
is True - assert nested_result["main_session_unaffected"] is True - assert nested_result["row_count"] == 5 - print("✓ Nested contexts close in correct order") - print(f" - Events: {' → '.join(nested_result['events'][:5])}...") - print(f" - Main session unaffected: {nested_result['main_session_unaffected']}") - - # Test 5: Streaming cancellation - print("\nTest 5: Streaming Cancellation Safety") - cancel_resp = await app_client.post("/context_manager_safety/cancellation") - assert cancel_resp.status_code == 200 - cancel_result = cancel_resp.json() - assert cancel_result["was_cancelled"] is True - assert cancel_result["session_still_works"] is True - assert cancel_result["new_stream_worked"] is True - assert cancel_result["session_open"] is True - print("✓ Cancelled streams clean up properly") - print(f" - Rows before cancel: {cancel_result['rows_processed_before_cancel']}") - print(f" - Session works after cancel: {cancel_result['session_still_works']}") - print(f" - New stream successful: {cancel_result['new_stream_worked']}") - - # Verify final state matches initial state - final_status = await app_client.get("/context_manager_safety/status") - assert final_status.status_code == 200 - final_state = final_status.json() - assert final_state["session_id"] == initial_state["session_id"] - assert final_state["cluster_id"] == initial_state["cluster_id"] - assert final_state["session_open"] is True - assert final_state["cluster_open"] is True - print("\n✓ All context manager safety tests passed!") - print(" - Session remained stable throughout all tests") - print(" - No resource leaks detected") - - -async def run_all_tests(): - """Run all tests and print summary.""" - print("=" * 60) - print("FastAPI Example Application Test Suite") - print("=" * 60) - - test_suite = TestFastAPIExample() - - # Create client - from main import app - - async with httpx.AsyncClient(app=app, base_url="http://test") as client: - # Run tests - try: - await test_suite.test_health_and_basic_operations(client) - await test_suite.test_thread_safety_under_concurrency(client) - await test_suite.test_streaming_memory_efficiency(client) - await test_suite.test_error_handling_consistency(client) - await test_suite.test_performance_comparison(client) - await test_suite.test_monitoring_endpoints(client) - await test_suite.test_timeout_handling(client) - await test_suite.test_context_manager_safety(client) - - print("\n" + "=" * 60) - print("✅ All tests passed! The FastAPI example properly demonstrates:") - print(" - Thread safety improvements") - print(" - Memory-efficient streaming") - print(" - Consistent error handling") - print(" - Performance benefits of async") - print(" - Monitoring capabilities") - print(" - Timeout handling") - print("=" * 60) - - except AssertionError as e: - print(f"\n❌ Test failed: {e}") - raise - except Exception as e: - print(f"\n❌ Unexpected error: {e}") - raise - - -if __name__ == "__main__": - # Run the test suite - asyncio.run(run_all_tests()) diff --git a/tests/fastapi_integration/test_fastapi_comprehensive.py b/tests/fastapi_integration/test_fastapi_comprehensive.py deleted file mode 100644 index 6a049de..0000000 --- a/tests/fastapi_integration/test_fastapi_comprehensive.py +++ /dev/null @@ -1,327 +0,0 @@ -""" -Comprehensive integration tests for FastAPI application. - -Following TDD principles, these tests are written FIRST to define -the expected behavior of the async-cassandra framework when used -with FastAPI - its primary use case. 
-""" - -import time -import uuid -from concurrent.futures import ThreadPoolExecutor - -import pytest -from fastapi.testclient import TestClient - - -@pytest.mark.integration -class TestFastAPIComprehensive: - """Comprehensive tests for FastAPI integration following TDD principles.""" - - @pytest.fixture - def test_client(self): - """Create FastAPI test client.""" - # Import here to ensure app is created fresh - from examples.fastapi_app.main import app - - # TestClient properly handles lifespan in newer FastAPI versions - with TestClient(app) as client: - yield client - - def test_health_check_endpoint(self, test_client): - """ - GIVEN a FastAPI application with async-cassandra - WHEN the health endpoint is called - THEN it should return healthy status without blocking - """ - response = test_client.get("/health") - assert response.status_code == 200 - data = response.json() - assert data["status"] == "healthy" - assert data["cassandra_connected"] is True - assert "timestamp" in data - - def test_concurrent_request_handling(self, test_client): - """ - GIVEN a FastAPI application handling multiple concurrent requests - WHEN many requests are sent simultaneously - THEN all requests should be handled without blocking or data corruption - """ - - # Create multiple users concurrently - def create_user(i): - user_data = { - "name": f"concurrent_user_{i}", # Changed from username to name - "email": f"user{i}@example.com", - "age": 25 + (i % 50), # Add required age field - } - return test_client.post("/users", json=user_data) - - # Send 50 concurrent requests - with ThreadPoolExecutor(max_workers=10) as executor: - futures = [executor.submit(create_user, i) for i in range(50)] - responses = [f.result() for f in futures] - - # All should succeed - assert all(r.status_code == 201 for r in responses) - - # Verify no data corruption - all users should be unique - user_ids = [r.json()["id"] for r in responses] - assert len(set(user_ids)) == 50 # All IDs should be unique - - def test_streaming_large_datasets(self, test_client): - """ - GIVEN a large dataset in Cassandra - WHEN streaming data through FastAPI - THEN memory usage should remain constant and not accumulate - """ - # First create some users to stream - for i in range(100): - user_data = { - "name": f"stream_user_{i}", - "email": f"stream{i}@example.com", - "age": 20 + (i % 60), - } - test_client.post("/users", json=user_data) - - # Test streaming endpoint - currently fails due to route ordering bug in FastAPI app - # where /users/{user_id} matches before /users/stream - response = test_client.get("/users/stream?limit=100&fetch_size=10") - - # This test expects the streaming functionality to work - # Currently it fails with 400 due to route ordering issue - assert response.status_code == 200 - data = response.json() - assert "users" in data - assert "metadata" in data - assert data["metadata"]["streaming_enabled"] is True - assert len(data["users"]) >= 100 # Should have at least the users we created - - def test_error_handling_and_recovery(self, test_client): - """ - GIVEN various error conditions - WHEN errors occur during request processing - THEN the application should handle them gracefully and recover - """ - # Test 1: Invalid UUID - response = test_client.get("/users/invalid-uuid") - assert response.status_code == 400 - assert "Invalid UUID" in response.json()["detail"] - - # Test 2: Non-existent resource - non_existent_id = str(uuid.uuid4()) - response = test_client.get(f"/users/{non_existent_id}") - assert response.status_code == 404 - 
assert "User not found" in response.json()["detail"] - - # Test 3: Invalid data - response = test_client.post("/users", json={"invalid": "data"}) - assert response.status_code == 422 # FastAPI validation error - - # Test 4: Verify app still works after errors - health_response = test_client.get("/health") - assert health_response.status_code == 200 - - def test_connection_pool_behavior(self, test_client): - """ - GIVEN limited connection pool resources - WHEN many requests exceed pool capacity - THEN requests should queue appropriately without failing - """ - # Create a burst of requests that exceed typical pool size - start_time = time.time() - - def make_request(i): - return test_client.get("/users") - - # Send 100 requests with limited concurrency - with ThreadPoolExecutor(max_workers=20) as executor: - futures = [executor.submit(make_request, i) for i in range(100)] - responses = [f.result() for f in futures] - - duration = time.time() - start_time - - # All should eventually succeed - assert all(r.status_code == 200 for r in responses) - - # Should complete in reasonable time (not hung) - assert duration < 30 # 30 seconds for 100 requests is reasonable - - def test_prepared_statement_caching(self, test_client): - """ - GIVEN repeated identical queries - WHEN executed multiple times - THEN prepared statements should be cached and reused - """ - # Create a user first - user_data = {"name": "test_user", "email": "test@example.com", "age": 25} - create_response = test_client.post("/users", json=user_data) - user_id = create_response.json()["id"] - - # Get the same user multiple times - responses = [] - for _ in range(10): - response = test_client.get(f"/users/{user_id}") - responses.append(response) - - # All should succeed and return same data - assert all(r.status_code == 200 for r in responses) - assert all(r.json()["id"] == user_id for r in responses) - - # Performance should improve after first query (prepared statement cached) - # This is more of a performance characteristic than functional test - - def test_batch_operations(self, test_client): - """ - GIVEN multiple operations to perform - WHEN executed as a batch - THEN all operations should succeed atomically - """ - # Create multiple users in a batch - batch_data = { - "users": [ - {"name": f"batch_user_{i}", "email": f"batch{i}@example.com", "age": 25 + i} - for i in range(5) - ] - } - - response = test_client.post("/users/batch", json=batch_data) - assert response.status_code == 201 - - created_users = response.json()["created"] - assert len(created_users) == 5 - - # Verify all were created - for user in created_users: - get_response = test_client.get(f"/users/{user['id']}") - assert get_response.status_code == 200 - - def test_async_context_manager_usage(self, test_client): - """ - GIVEN async context manager pattern - WHEN used in request handlers - THEN resources should be properly managed - """ - # This tests that sessions are properly closed even with errors - # Make multiple requests that might fail - for i in range(10): - if i % 2 == 0: - # Valid request - test_client.get("/users") - else: - # Invalid request - test_client.get("/users/invalid-uuid") - - # Verify system still healthy - health = test_client.get("/health") - assert health.status_code == 200 - - def test_monitoring_and_metrics(self, test_client): - """ - GIVEN monitoring endpoints - WHEN metrics are requested - THEN accurate metrics should be returned - """ - # Make some requests to generate metrics - for _ in range(5): - test_client.get("/users") - - # Get 
metrics - response = test_client.get("/metrics") - assert response.status_code == 200 - - metrics = response.json() - assert "total_requests" in metrics - assert metrics["total_requests"] >= 5 - assert "query_performance" in metrics - - @pytest.mark.parametrize("consistency_level", ["ONE", "QUORUM", "ALL"]) - def test_consistency_levels(self, test_client, consistency_level): - """ - GIVEN different consistency level requirements - WHEN operations are performed - THEN the appropriate consistency should be used - """ - # Create user with specific consistency level - user_data = { - "name": f"consistency_test_{consistency_level}", - "email": f"test_{consistency_level}@example.com", - "age": 25, - } - - response = test_client.post( - "/users", json=user_data, headers={"X-Consistency-Level": consistency_level} - ) - - assert response.status_code == 201 - - # Verify it was created - user_id = response.json()["id"] - get_response = test_client.get( - f"/users/{user_id}", headers={"X-Consistency-Level": consistency_level} - ) - assert get_response.status_code == 200 - - def test_timeout_handling(self, test_client): - """ - GIVEN timeout constraints - WHEN operations exceed timeout - THEN appropriate timeout errors should be returned - """ - # Create a slow query endpoint (would need to be added to FastAPI app) - response = test_client.get( - "/slow_query", headers={"X-Request-Timeout": "0.1"} # 100ms timeout - ) - - # Should timeout - assert response.status_code == 504 # Gateway timeout - - def test_no_blocking_of_event_loop(self, test_client): - """ - GIVEN async operations running - WHEN Cassandra operations are performed - THEN the event loop should not be blocked - """ - # Start a long-running query - import threading - - long_query_done = threading.Event() - - def long_query(): - test_client.get("/long_running_query") - long_query_done.set() - - # Start long query in background - thread = threading.Thread(target=long_query) - thread.start() - - # Meanwhile, other quick queries should still work - start_time = time.time() - for _ in range(5): - response = test_client.get("/health") - assert response.status_code == 200 - - quick_queries_time = time.time() - start_time - - # Quick queries should complete fast even with long query running - assert quick_queries_time < 1.0 # Should take less than 1 second - - # Wait for long query to complete - thread.join(timeout=5) - - def test_graceful_shutdown(self, test_client): - """ - GIVEN an active FastAPI application - WHEN shutdown is initiated - THEN all connections should be properly closed - """ - # Make some requests - for _ in range(3): - test_client.get("/users") - - # Trigger shutdown (this would need shutdown endpoint) - response = test_client.post("/shutdown") - assert response.status_code == 200 - - # Verify connections were closed properly - # (Would need to check connection metrics) diff --git a/tests/fastapi_integration/test_fastapi_enhanced.py b/tests/fastapi_integration/test_fastapi_enhanced.py deleted file mode 100644 index d005996..0000000 --- a/tests/fastapi_integration/test_fastapi_enhanced.py +++ /dev/null @@ -1,335 +0,0 @@ -""" -Enhanced integration tests for FastAPI with all async-cassandra features. 
-""" - -import asyncio -import uuid - -import pytest -import pytest_asyncio -from examples.fastapi_app.main_enhanced import app -from httpx import ASGITransport, AsyncClient - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestEnhancedFastAPIFeatures: - """Test all enhanced features in the FastAPI example.""" - - @pytest_asyncio.fixture - async def client(self): - """Create async HTTP client with proper app initialization.""" - # The app needs to be properly initialized with lifespan - - # Create a test app that runs the lifespan - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as client: - # Trigger lifespan startup - async with app.router.lifespan_context(app): - yield client - - async def test_root_endpoint(self, client): - """Test root endpoint lists all features.""" - response = await client.get("/") - assert response.status_code == 200 - data = response.json() - assert "features" in data - assert "Timeout handling" in data["features"] - assert "Memory-efficient streaming" in data["features"] - assert "Connection monitoring" in data["features"] - - async def test_enhanced_health_check(self, client): - """Test enhanced health check with monitoring data.""" - response = await client.get("/health") - assert response.status_code == 200 - data = response.json() - - # Check all required fields - assert "status" in data - assert "healthy_hosts" in data - assert "unhealthy_hosts" in data - assert "total_connections" in data - assert "timestamp" in data - - # Verify at least one healthy host - assert data["healthy_hosts"] >= 1 - - async def test_host_monitoring(self, client): - """Test detailed host monitoring endpoint.""" - response = await client.get("/monitoring/hosts") - assert response.status_code == 200 - data = response.json() - - assert "cluster_name" in data - assert "protocol_version" in data - assert "hosts" in data - assert isinstance(data["hosts"], list) - - # Check host details - if data["hosts"]: - host = data["hosts"][0] - assert "address" in host - assert "status" in host - assert "latency_ms" in host - - async def test_connection_summary(self, client): - """Test connection summary endpoint.""" - response = await client.get("/monitoring/summary") - assert response.status_code == 200 - data = response.json() - - assert "total_hosts" in data - assert "up_hosts" in data - assert "down_hosts" in data - assert "protocol_version" in data - assert "max_requests_per_connection" in data - - async def test_create_user_with_timeout(self, client): - """Test user creation with timeout handling.""" - user_data = {"name": "Timeout Test User", "email": "timeout@test.com", "age": 30} - - response = await client.post("/users", json=user_data) - assert response.status_code == 201 - created_user = response.json() - - assert created_user["name"] == user_data["name"] - assert created_user["email"] == user_data["email"] - assert "id" in created_user - - async def test_list_users_with_custom_timeout(self, client): - """Test listing users with custom timeout.""" - # First create some users - for i in range(5): - await client.post( - "/users", - json={"name": f"Test User {i}", "email": f"user{i}@test.com", "age": 25 + i}, - ) - - # List with custom timeout - response = await client.get("/users?limit=5&timeout=10.0") - assert response.status_code == 200 - users = response.json() - assert isinstance(users, list) - assert len(users) <= 5 - - async def test_advanced_streaming(self, client): - """Test advanced streaming with all options.""" 
- # Create test data - for i in range(20): - await client.post( - "/users", - json={"name": f"Stream User {i}", "email": f"stream{i}@test.com", "age": 20 + i}, - ) - - # Test streaming with various configurations - response = await client.get( - "/users/stream/advanced?" - "limit=20&" - "fetch_size=10&" # Minimum is 10 - "max_pages=3&" - "timeout_seconds=30.0" - ) - if response.status_code != 200: - print(f"Response status: {response.status_code}") - print(f"Response body: {response.text}") - assert response.status_code == 200 - data = response.json() - - assert "users" in data - assert "metadata" in data - - metadata = data["metadata"] - assert metadata["pages_fetched"] <= 3 # Respects max_pages - assert metadata["rows_processed"] <= 20 # Respects limit - assert "duration_seconds" in metadata - assert "rows_per_second" in metadata - - async def test_streaming_with_memory_limit(self, client): - """Test streaming with memory limit.""" - response = await client.get( - "/users/stream/advanced?" - "limit=1000&" - "fetch_size=100&" - "max_memory_mb=1" # Very low memory limit - ) - assert response.status_code == 200 - data = response.json() - - # Should stop before reaching limit due to memory constraint - assert len(data["users"]) < 1000 - - async def test_error_handling_invalid_uuid(self, client): - """Test proper error handling for invalid UUID.""" - response = await client.get("/users/invalid-uuid") - assert response.status_code == 400 - assert "Invalid UUID format" in response.json()["detail"] - - async def test_error_handling_user_not_found(self, client): - """Test proper error handling for non-existent user.""" - random_uuid = str(uuid.uuid4()) - response = await client.get(f"/users/{random_uuid}") - assert response.status_code == 404 - assert "User not found" in response.json()["detail"] - - async def test_query_metrics(self, client): - """Test query metrics collection.""" - # Execute some queries first - for i in range(10): - await client.get("/users?limit=1") - - response = await client.get("/metrics/queries") - assert response.status_code == 200 - data = response.json() - - if "query_performance" in data: - perf = data["query_performance"] - assert "total_queries" in perf - assert perf["total_queries"] >= 10 - - async def test_rate_limit_status(self, client): - """Test rate limiting status endpoint.""" - response = await client.get("/rate_limit/status") - assert response.status_code == 200 - data = response.json() - - assert "rate_limiting_enabled" in data - if data["rate_limiting_enabled"]: - assert "metrics" in data - assert "max_concurrent" in data - - async def test_timeout_operations(self, client): - """Test timeout handling for different operations.""" - # Test very short timeout - response = await client.post("/test/timeout?operation=execute&timeout=0.1") - assert response.status_code == 200 - data = response.json() - - # Should either complete or timeout - assert data.get("error") in ["timeout", None] - - async def test_concurrent_load_read(self, client): - """Test system under concurrent read load.""" - # Create test data - await client.post( - "/users", json={"name": "Load Test User", "email": "load@test.com", "age": 25} - ) - - # Test concurrent reads - response = await client.post("/test/concurrent_load?concurrent_requests=20&query_type=read") - assert response.status_code == 200 - data = response.json() - - summary = data["test_summary"] - assert summary["successful"] > 0 - assert summary["requests_per_second"] > 0 - - # Check rate limit metrics if available - if 
data.get("rate_limit_metrics"): - metrics = data["rate_limit_metrics"] - assert metrics["total_requests"] >= 20 - - async def test_concurrent_load_write(self, client): - """Test system under concurrent write load.""" - response = await client.post( - "/test/concurrent_load?concurrent_requests=10&query_type=write" - ) - if response.status_code != 200: - print(f"Response status: {response.status_code}") - print(f"Response body: {response.text}") - assert response.status_code == 200 - data = response.json() - - summary = data["test_summary"] - assert summary["successful"] > 0 - - # Clean up test data - cleanup_response = await client.delete("/users/cleanup") - if cleanup_response.status_code != 200: - print(f"Cleanup error: {cleanup_response.text}") - assert cleanup_response.status_code == 200 - - async def test_streaming_timeout(self, client): - """Test streaming with timeout.""" - # Test with very short timeout - response = await client.get( - "/users/stream/advanced?" - "limit=1000&" - "fetch_size=100&" # Add required fetch_size - "timeout_seconds=0.1" # Very short timeout - ) - - # Should either complete quickly or timeout - if response.status_code == 504: - assert "timeout" in response.json()["detail"].lower() - elif response.status_code == 422: - # Validation error is also acceptable - might fail before timeout - assert "detail" in response.json() - else: - assert response.status_code == 200 - - async def test_connection_monitoring_callbacks(self, client): - """Test that monitoring is active and collecting data.""" - # Wait a bit for monitoring to collect data - await asyncio.sleep(2) - - # Check host status - response = await client.get("/monitoring/hosts") - assert response.status_code == 200 - data = response.json() - - # Should have collected latency data - hosts_with_latency = [h for h in data["hosts"] if h.get("latency_ms") is not None] - assert len(hosts_with_latency) > 0 - - async def test_graceful_error_recovery(self, client): - """Test that system recovers gracefully from errors.""" - # Create a user (should work) - user1 = await client.post( - "/users", json={"name": "Recovery Test 1", "email": "recovery1@test.com", "age": 30} - ) - assert user1.status_code == 201 - - # Try invalid operation - invalid = await client.get("/users/not-a-uuid") - assert invalid.status_code == 400 - - # System should still work - user2 = await client.post( - "/users", json={"name": "Recovery Test 2", "email": "recovery2@test.com", "age": 31} - ) - assert user2.status_code == 201 - - async def test_memory_efficient_streaming(self, client): - """Test that streaming is memory efficient.""" - # Create substantial test data - batch_size = 50 - for batch in range(3): - batch_data = { - "users": [ - { - "name": f"Batch User {batch * batch_size + i}", - "email": f"batch{batch}_{i}@test.com", - "age": 20 + i, - } - for i in range(batch_size) - ] - } - # Use the main app's batch endpoint - response = await client.post("/users/batch", json=batch_data) - assert response.status_code == 200 - - # Stream through all data with smaller fetch size to ensure multiple pages - response = await client.get( - "/users/stream/advanced?" 
- "limit=200&" # Increase limit to ensure we get all users - "fetch_size=10&" # Small fetch size to ensure multiple pages - "max_pages=20" - ) - assert response.status_code == 200 - data = response.json() - - # With 150 users and fetch_size=10, we should get multiple pages - # Check that we got users (may not be exactly 150 due to other tests) - assert data["metadata"]["pages_fetched"] >= 1 - assert len(data["users"]) >= 150 # Should get at least 150 users - assert len(data["users"]) <= 200 # But no more than limit diff --git a/tests/fastapi_integration/test_fastapi_example.py b/tests/fastapi_integration/test_fastapi_example.py deleted file mode 100644 index ea3fefa..0000000 --- a/tests/fastapi_integration/test_fastapi_example.py +++ /dev/null @@ -1,331 +0,0 @@ -""" -Integration tests for FastAPI example application. -""" - -import asyncio -import sys -import uuid -from pathlib import Path -from typing import AsyncGenerator - -import pytest -import pytest_asyncio -from httpx import AsyncClient - -# Add the FastAPI app directory to the path -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "examples" / "fastapi_app")) -from main import app - - -@pytest.fixture(scope="session") -def cassandra_service(): - """Use existing Cassandra service for tests.""" - # Cassandra should already be running on localhost:9042 - # Check if it's available - import socket - import time - - max_attempts = 10 - for i in range(max_attempts): - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(1) - result = sock.connect_ex(("localhost", 9042)) - sock.close() - if result == 0: - yield True - return - except Exception: - pass - time.sleep(1) - - raise RuntimeError("Cassandra is not available on localhost:9042") - - -@pytest_asyncio.fixture -async def client() -> AsyncGenerator[AsyncClient, None]: - """Create async HTTP client for tests.""" - from httpx import ASGITransport, AsyncClient - - # Initialize the app lifespan context - async with app.router.lifespan_context(app): - # Use ASGI transport to test the app directly - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: - yield ac - - -@pytest.mark.integration -class TestHealthEndpoint: - """Test health check endpoint.""" - - @pytest.mark.asyncio - async def test_health_check(self, client: AsyncClient, cassandra_service): - """Test health check returns healthy status.""" - response = await client.get("/health") - - assert response.status_code == 200 - data = response.json() - - assert data["status"] == "healthy" - assert data["cassandra_connected"] is True - assert "timestamp" in data - - -@pytest.mark.integration -class TestUserCRUD: - """Test user CRUD operations.""" - - @pytest.mark.asyncio - async def test_create_user(self, client: AsyncClient, cassandra_service): - """Test creating a new user.""" - user_data = {"name": "John Doe", "email": "john@example.com", "age": 30} - - response = await client.post("/users", json=user_data) - - assert response.status_code == 201 - data = response.json() - - assert "id" in data - assert data["name"] == user_data["name"] - assert data["email"] == user_data["email"] - assert data["age"] == user_data["age"] - assert "created_at" in data - assert "updated_at" in data - - @pytest.mark.asyncio - async def test_get_user(self, client: AsyncClient, cassandra_service): - """Test getting user by ID.""" - # First create a user - user_data = {"name": "Jane Doe", "email": "jane@example.com", "age": 25} - - create_response = await client.post("/users", 
json=user_data) - created_user = create_response.json() - user_id = created_user["id"] - - # Get the user - response = await client.get(f"/users/{user_id}") - - assert response.status_code == 200 - data = response.json() - - assert data["id"] == user_id - assert data["name"] == user_data["name"] - assert data["email"] == user_data["email"] - assert data["age"] == user_data["age"] - - @pytest.mark.asyncio - async def test_get_nonexistent_user(self, client: AsyncClient, cassandra_service): - """Test getting non-existent user returns 404.""" - fake_id = str(uuid.uuid4()) - - response = await client.get(f"/users/{fake_id}") - - assert response.status_code == 404 - assert "User not found" in response.json()["detail"] - - @pytest.mark.asyncio - async def test_invalid_user_id_format(self, client: AsyncClient, cassandra_service): - """Test invalid user ID format returns 400.""" - response = await client.get("/users/invalid-uuid") - - assert response.status_code == 400 - assert "Invalid UUID" in response.json()["detail"] - - @pytest.mark.asyncio - async def test_list_users(self, client: AsyncClient, cassandra_service): - """Test listing users.""" - # Create multiple users - users = [] - for i in range(5): - user_data = {"name": f"User {i}", "email": f"user{i}@example.com", "age": 20 + i} - response = await client.post("/users", json=user_data) - users.append(response.json()) - - # List users - response = await client.get("/users?limit=10") - - assert response.status_code == 200 - data = response.json() - - assert isinstance(data, list) - assert len(data) >= 5 # At least the users we created - - @pytest.mark.asyncio - async def test_update_user(self, client: AsyncClient, cassandra_service): - """Test updating user.""" - # Create a user - user_data = {"name": "Update Test", "email": "update@example.com", "age": 30} - - create_response = await client.post("/users", json=user_data) - user_id = create_response.json()["id"] - - # Update the user - update_data = {"name": "Updated Name", "age": 31} - - response = await client.put(f"/users/{user_id}", json=update_data) - - assert response.status_code == 200 - data = response.json() - - assert data["id"] == user_id - assert data["name"] == update_data["name"] - assert data["email"] == user_data["email"] # Unchanged - assert data["age"] == update_data["age"] - assert data["updated_at"] > data["created_at"] - - @pytest.mark.asyncio - async def test_partial_update(self, client: AsyncClient, cassandra_service): - """Test partial update of user.""" - # Create a user - user_data = {"name": "Partial Update", "email": "partial@example.com", "age": 25} - - create_response = await client.post("/users", json=user_data) - user_id = create_response.json()["id"] - - # Update only email - update_data = {"email": "newemail@example.com"} - - response = await client.put(f"/users/{user_id}", json=update_data) - - assert response.status_code == 200 - data = response.json() - - assert data["email"] == update_data["email"] - assert data["name"] == user_data["name"] # Unchanged - assert data["age"] == user_data["age"] # Unchanged - - @pytest.mark.asyncio - async def test_delete_user(self, client: AsyncClient, cassandra_service): - """Test deleting user.""" - # Create a user - user_data = {"name": "Delete Test", "email": "delete@example.com", "age": 35} - - create_response = await client.post("/users", json=user_data) - user_id = create_response.json()["id"] - - # Delete the user - response = await client.delete(f"/users/{user_id}") - - assert response.status_code == 204 - - # 
Verify user is deleted - get_response = await client.get(f"/users/{user_id}") - assert get_response.status_code == 404 - - -@pytest.mark.integration -class TestPerformance: - """Test performance endpoints.""" - - @pytest.mark.asyncio - async def test_async_performance(self, client: AsyncClient, cassandra_service): - """Test async performance endpoint.""" - response = await client.get("/performance/async?requests=10") - - assert response.status_code == 200 - data = response.json() - - assert data["requests"] == 10 - assert data["total_time"] > 0 - assert data["avg_time_per_request"] > 0 - assert data["requests_per_second"] > 0 - - @pytest.mark.asyncio - async def test_sync_performance(self, client: AsyncClient, cassandra_service): - """Test sync performance endpoint.""" - response = await client.get("/performance/sync?requests=10") - - assert response.status_code == 200 - data = response.json() - - assert data["requests"] == 10 - assert data["total_time"] > 0 - assert data["avg_time_per_request"] > 0 - assert data["requests_per_second"] > 0 - - @pytest.mark.asyncio - async def test_performance_comparison(self, client: AsyncClient, cassandra_service): - """Test that async is faster than sync for concurrent operations.""" - # Run async test - async_response = await client.get("/performance/async?requests=50") - assert async_response.status_code == 200 - async_data = async_response.json() - assert async_data["requests"] == 50 - assert async_data["total_time"] > 0 - assert async_data["requests_per_second"] > 0 - - # Run sync test - sync_response = await client.get("/performance/sync?requests=50") - assert sync_response.status_code == 200 - sync_data = sync_response.json() - assert sync_data["requests"] == 50 - assert sync_data["total_time"] > 0 - assert sync_data["requests_per_second"] > 0 - - # Async should be significantly faster for concurrent operations - # Note: In CI or under light load, the difference might be small - # so we just verify both work correctly - print(f"Async RPS: {async_data['requests_per_second']:.2f}") - print(f"Sync RPS: {sync_data['requests_per_second']:.2f}") - - # For concurrent operations, async should generally be faster - # but we'll be lenient in case of CI variability - assert async_data["requests_per_second"] > sync_data["requests_per_second"] * 0.8 - - -@pytest.mark.integration -class TestConcurrency: - """Test concurrent operations.""" - - @pytest.mark.asyncio - async def test_concurrent_user_creation(self, client: AsyncClient, cassandra_service): - """Test creating multiple users concurrently.""" - - async def create_user(i: int): - user_data = { - "name": f"Concurrent User {i}", - "email": f"concurrent{i}@example.com", - "age": 20 + i, - } - response = await client.post("/users", json=user_data) - return response.json() - - # Create 20 users concurrently - users = await asyncio.gather(*[create_user(i) for i in range(20)]) - - assert len(users) == 20 - - # Verify all users have unique IDs - user_ids = [user["id"] for user in users] - assert len(set(user_ids)) == 20 - - @pytest.mark.asyncio - async def test_concurrent_read_write(self, client: AsyncClient, cassandra_service): - """Test concurrent read and write operations.""" - # Create initial user - user_data = {"name": "Concurrent Test", "email": "concurrent@example.com", "age": 30} - - create_response = await client.post("/users", json=user_data) - user_id = create_response.json()["id"] - - async def read_user(): - response = await client.get(f"/users/{user_id}") - return response.json() - - async def 
update_user(age: int): - response = await client.put(f"/users/{user_id}", json={"age": age}) - return response.json() - - # Run mixed read/write operations concurrently - operations = [] - for i in range(10): - if i % 2 == 0: - operations.append(read_user()) - else: - operations.append(update_user(30 + i)) - - results = await asyncio.gather(*operations, return_exceptions=True) - - # Verify no errors occurred - for result in results: - assert not isinstance(result, Exception) diff --git a/tests/fastapi_integration/test_reconnection.py b/tests/fastapi_integration/test_reconnection.py deleted file mode 100644 index 7560b97..0000000 --- a/tests/fastapi_integration/test_reconnection.py +++ /dev/null @@ -1,319 +0,0 @@ -""" -Test FastAPI app reconnection behavior when Cassandra is stopped and restarted. - -This test demonstrates that the cassandra-driver's ExponentialReconnectionPolicy -handles reconnection automatically, which is critical for rolling restarts and DC outages. -""" - -import asyncio -import os -import time - -import httpx -import pytest -import pytest_asyncio - -from tests.utils.cassandra_control import CassandraControl - - -@pytest_asyncio.fixture(autouse=True) -async def ensure_cassandra_enabled(cassandra_container): - """Ensure Cassandra binary protocol is enabled before and after each test.""" - control = CassandraControl(cassandra_container) - - # Enable at start - control.enable_binary_protocol() - await asyncio.sleep(2) - - yield - - # Enable at end (cleanup) - control.enable_binary_protocol() - await asyncio.sleep(2) - - -class TestFastAPIReconnection: - """Test suite for FastAPI reconnection behavior.""" - - async def _wait_for_api_health( - self, client: httpx.AsyncClient, healthy: bool, timeout: int = 30 - ): - """Wait for API health check to reach desired state.""" - start_time = time.time() - while time.time() - start_time < timeout: - try: - response = await client.get("/health") - if response.status_code == 200: - data = response.json() - if data["cassandra_connected"] == healthy: - return True - except httpx.RequestError: - # Connection errors during reconnection - if not healthy: - return True - await asyncio.sleep(0.5) - return False - - async def _verify_apis_working(self, client: httpx.AsyncClient): - """Verify all APIs are working correctly.""" - # 1. Health check - health_resp = await client.get("/health") - assert health_resp.status_code == 200 - assert health_resp.json()["status"] == "healthy" - assert health_resp.json()["cassandra_connected"] is True - - # 2. Create user - user_data = {"name": "Reconnection Test User", "email": "reconnect@test.com", "age": 25} - create_resp = await client.post("/users", json=user_data) - assert create_resp.status_code == 201 - user_id = create_resp.json()["id"] - - # 3. Read user back - get_resp = await client.get(f"/users/{user_id}") - assert get_resp.status_code == 200 - assert get_resp.json()["name"] == user_data["name"] - - # 4. 
Test streaming - stream_resp = await client.get("/users/stream?limit=10&fetch_size=10") - assert stream_resp.status_code == 200 - stream_data = stream_resp.json() - assert stream_data["metadata"]["streaming_enabled"] is True - - return user_id - - async def _verify_apis_return_errors(self, client: httpx.AsyncClient): - """Verify APIs return appropriate errors when Cassandra is down.""" - # Wait a bit for existing connections to fail - await asyncio.sleep(3) - - # Try to create a user - should fail - user_data = {"name": "Should Fail", "email": "fail@test.com", "age": 30} - error_occurred = False - try: - create_resp = await client.post("/users", json=user_data, timeout=10.0) - print(f"Create user response during outage: {create_resp.status_code}") - if create_resp.status_code >= 500: - error_detail = create_resp.json().get("detail", "") - print(f"Got expected error: {error_detail}") - error_occurred = True - else: - # Might succeed if connection is still cached - print( - f"Warning: Create succeeded with status {create_resp.status_code} - connection might be cached" - ) - except (httpx.TimeoutException, httpx.RequestError) as e: - print(f"Create user failed with {type(e).__name__} - this is expected") - error_occurred = True - - # At least one operation should fail to confirm outage is detected - if not error_occurred: - # Try another operation that should fail - try: - # Force a new query that requires active connection - list_resp = await client.get("/users?limit=100", timeout=10.0) - if list_resp.status_code >= 500: - print(f"List users failed with {list_resp.status_code}") - error_occurred = True - except (httpx.TimeoutException, httpx.RequestError) as e: - print(f"List users failed with {type(e).__name__}") - error_occurred = True - - assert error_occurred, "Expected at least one operation to fail during Cassandra outage" - - def _get_cassandra_control(self, container): - """Get Cassandra control interface.""" - return CassandraControl(container) - - @pytest.mark.asyncio - async def test_cassandra_reconnection_behavior(self, app_client, cassandra_container): - """Test reconnection when Cassandra is stopped and restarted.""" - print("\n=== Testing Cassandra Reconnection Behavior ===") - - # Step 1: Verify everything works initially - print("\n1. Verifying all APIs work initially...") - user_id = await self._verify_apis_working(app_client) - print("✓ All APIs working correctly") - - # Step 2: Disable binary protocol (simulate Cassandra outage) - print("\n2. Disabling Cassandra binary protocol to simulate outage...") - control = self._get_cassandra_control(cassandra_container) - - if os.environ.get("CI") == "true": - print(" (In CI - cannot control service, skipping outage simulation)") - print("\n✓ Test completed (CI environment)") - return - - success, msg = control.disable_binary_protocol() - if not success: - pytest.fail(msg) - print("✓ Binary protocol disabled") - - # Give it a moment for binary protocol to be disabled - await asyncio.sleep(3) - - # Step 3: Verify APIs return appropriate errors - print("\n3. Verifying APIs return appropriate errors during outage...") - await self._verify_apis_return_errors(app_client) - print("✓ APIs returning appropriate error responses") - - # Step 4: Re-enable binary protocol - print("\n4. 
Re-enabling Cassandra binary protocol...") - success, msg = control.enable_binary_protocol() - if not success: - pytest.fail(msg) - print("✓ Binary protocol re-enabled") - - # Step 5: Wait for reconnection - reconnect_timeout = 30 # seconds - give enough time for exponential backoff - print(f"\n5. Waiting up to {reconnect_timeout} seconds for reconnection...") - - # Instead of checking health, try actual operations - reconnected = False - start_time = time.time() - while time.time() - start_time < reconnect_timeout: - try: - # Try a simple query - test_resp = await app_client.get("/users?limit=1", timeout=5.0) - if test_resp.status_code == 200: - print("✓ Reconnection successful!") - reconnected = True - break - except (httpx.TimeoutException, httpx.RequestError): - pass - await asyncio.sleep(2) - - if not reconnected: - pytest.fail(f"Failed to reconnect within {reconnect_timeout} seconds") - - # Step 6: Verify all APIs work again - print("\n6. Verifying all APIs work after recovery...") - # Verify the user we created earlier still exists - get_resp = await app_client.get(f"/users/{user_id}") - assert get_resp.status_code == 200 - assert get_resp.json()["name"] == "Reconnection Test User" - print("✓ Previously created user still accessible") - - # Create a new user to verify full functionality - await self._verify_apis_working(app_client) - print("✓ All APIs fully functional after recovery") - - print("\n✅ Reconnection test completed successfully!") - print(" - APIs handled outage gracefully with appropriate errors") - print(" - Automatic reconnection occurred after service restoration") - print(" - No manual intervention required") - - @pytest.mark.asyncio - async def test_multiple_reconnection_cycles(self, app_client, cassandra_container): - """Test multiple disconnect/reconnect cycles to ensure stability.""" - print("\n=== Testing Multiple Reconnection Cycles ===") - - cycles = 3 - for cycle in range(1, cycles + 1): - print(f"\n--- Cycle {cycle}/{cycles} ---") - - control = self._get_cassandra_control(cassandra_container) - - if os.environ.get("CI") == "true": - print(f"Cycle {cycle}: Skipping in CI environment") - continue - - # Disable - print("Disabling binary protocol...") - success, msg = control.disable_binary_protocol() - if not success: - pytest.fail(f"Cycle {cycle}: {msg}") - - await asyncio.sleep(2) - - # Verify unhealthy - health_resp = await app_client.get("/health") - assert health_resp.json()["cassandra_connected"] is False - print("✓ Cassandra reported as disconnected") - - # Re-enable - print("Re-enabling binary protocol...") - success, msg = control.enable_binary_protocol() - if not success: - pytest.fail(f"Cycle {cycle}: {msg}") - - # Wait for reconnection - if not await self._wait_for_api_health(app_client, healthy=True, timeout=10): - pytest.fail(f"Cycle {cycle}: Failed to reconnect") - print("✓ Reconnected successfully") - - # Verify functionality - user_data = { - "name": f"Cycle {cycle} User", - "email": f"cycle{cycle}@test.com", - "age": 20 + cycle, - } - create_resp = await app_client.post("/users", json=user_data) - assert create_resp.status_code == 201 - print(f"✓ Created user for cycle {cycle}") - - print(f"\n✅ Successfully completed {cycles} reconnection cycles!") - - @pytest.mark.asyncio - async def test_reconnection_during_active_requests(self, app_client, cassandra_container): - """Test reconnection behavior when requests are active during outage.""" - print("\n=== Testing Reconnection During Active Requests ===") - - async def continuous_requests(client: 
httpx.AsyncClient, duration: int): - """Make continuous requests for specified duration.""" - errors = [] - successes = 0 - start_time = time.time() - - while time.time() - start_time < duration: - try: - resp = await client.get("/health") - if resp.status_code == 200 and resp.json()["cassandra_connected"]: - successes += 1 - else: - errors.append("unhealthy") - except Exception as e: - errors.append(str(type(e).__name__)) - await asyncio.sleep(0.1) - - return successes, errors - - # Start continuous requests in background - request_task = asyncio.create_task(continuous_requests(app_client, 15)) - - # Wait a bit for requests to start - await asyncio.sleep(2) - - control = self._get_cassandra_control(cassandra_container) - - if os.environ.get("CI") == "true": - print("Skipping outage simulation in CI environment") - # Just let the requests run without outage - else: - # Disable binary protocol - print("Disabling binary protocol during active requests...") - control.disable_binary_protocol() - - # Wait for errors to accumulate - await asyncio.sleep(3) - - # Re-enable binary protocol - print("Re-enabling binary protocol...") - control.enable_binary_protocol() - - # Wait for task to complete - successes, errors = await request_task - - print("\nResults:") - print(f" - Successful requests: {successes}") - print(f" - Failed requests: {len(errors)}") - print(f" - Error types: {set(errors)}") - - # Should have both successes and failures - assert successes > 0, "Should have successful requests before and after outage" - assert len(errors) > 0, "Should have errors during outage" - - # Final health check should be healthy - health_resp = await app_client.get("/health") - assert health_resp.json()["cassandra_connected"] is True - - print("\n✅ Active requests handled reconnection gracefully!") diff --git a/tests/integration/.gitkeep b/tests/integration/.gitkeep deleted file mode 100644 index e229a66..0000000 --- a/tests/integration/.gitkeep +++ /dev/null @@ -1,2 +0,0 @@ -# This directory contains integration tests -# FastAPI tests have been moved to tests/fastapi/ diff --git a/tests/integration/README.md b/tests/integration/README.md deleted file mode 100644 index f6740b9..0000000 --- a/tests/integration/README.md +++ /dev/null @@ -1,112 +0,0 @@ -# Integration Tests - -This directory contains integration tests for the async-python-cassandra-client library. The tests run against a real Cassandra instance. - -## Prerequisites - -You need a running Cassandra instance on your machine. The tests expect Cassandra to be available on `localhost:9042` by default. 
- -## Running Tests - -### Quick Start - -```bash -# Start Cassandra (if not already running) -make cassandra-start - -# Run integration tests -make test-integration - -# Stop Cassandra when done -make cassandra-stop -``` - -### Using Existing Cassandra - -If you already have Cassandra running elsewhere: - -```bash -# Set the contact points -export CASSANDRA_CONTACT_POINTS=10.0.0.1,10.0.0.2 -export CASSANDRA_PORT=9042 # optional, defaults to 9042 - -# Run tests -make test-integration -``` - -## Makefile Targets - -- `make cassandra-start` - Start a Cassandra container using Docker or Podman -- `make cassandra-stop` - Stop and remove the Cassandra container -- `make cassandra-status` - Check if Cassandra is running and ready -- `make cassandra-wait` - Wait for Cassandra to be ready (starts it if needed) -- `make test-integration` - Run integration tests (waits for Cassandra automatically) -- `make test-integration-keep` - Run tests but keep containers running - -## Environment Variables - -- `CASSANDRA_CONTACT_POINTS` - Comma-separated list of Cassandra contact points (default: localhost) -- `CASSANDRA_PORT` - Cassandra port (default: 9042) -- `CONTAINER_RUNTIME` - Container runtime to use (auto-detected, can be docker or podman) -- `CASSANDRA_IMAGE` - Cassandra Docker image (default: cassandra:5) -- `CASSANDRA_CONTAINER_NAME` - Container name (default: async-cassandra-test) -- `SKIP_INTEGRATION_TESTS=1` - Skip integration tests entirely -- `KEEP_CONTAINERS=1` - Keep containers running after tests complete - -## Container Configuration - -When using `make cassandra-start`, the container is configured with: -- Image: `cassandra:5` (latest Cassandra 5.x) -- Port: `9042` (default Cassandra port) -- Cluster name: `TestCluster` -- Datacenter: `datacenter1` -- Snitch: `SimpleSnitch` - -## Writing Integration Tests - -Integration tests should: -1. Use the `cassandra_session` fixture for a ready-to-use session -2. Clean up any test data they create -3. Be marked with `@pytest.mark.integration` -4. Handle transient network errors gracefully - -Example: -```python -@pytest.mark.integration -@pytest.mark.asyncio -async def test_example(cassandra_session): - result = await cassandra_session.execute("SELECT * FROM system.local") - assert result.one() is not None -``` - -## Troubleshooting - -### Cassandra Not Available - -If tests fail with "Cassandra is not available": - -1. Check if Cassandra is running: `make cassandra-status` -2. Start Cassandra: `make cassandra-start` -3. Wait for it to be ready: `make cassandra-wait` - -### Port Conflicts - -If port 9042 is already in use by another service: -1. Stop the conflicting service, or -2. Use a different Cassandra instance and set `CASSANDRA_CONTACT_POINTS` - -### Container Issues - -If using containers and having issues: -1. Check container logs: `docker logs async-cassandra-test` or `podman logs async-cassandra-test` -2. Ensure you have enough available memory (at least 1GB free) -3. Try removing and recreating: `make cassandra-stop && make cassandra-start` - -### Docker vs Podman - -The Makefile automatically detects whether you have Docker or Podman installed. 
If you have both and want to force one: - -```bash -export CONTAINER_RUNTIME=podman # or docker -make cassandra-start -``` diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py deleted file mode 100644 index 5cc31ba..0000000 --- a/tests/integration/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Integration tests for async-cassandra.""" diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py deleted file mode 100644 index 3bfe2c4..0000000 --- a/tests/integration/conftest.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -Pytest configuration for integration tests. -""" - -import os -import socket -import sys -from pathlib import Path - -import pytest -import pytest_asyncio - -from async_cassandra import AsyncCluster - -# Add parent directory to path for test_utils import -sys.path.insert(0, str(Path(__file__).parent.parent)) -from test_utils import ( # noqa: E402 - TestTableManager, - generate_unique_keyspace, - generate_unique_table, -) - - -def pytest_configure(config): - """Configure pytest for integration tests.""" - # Skip if explicitly disabled - if os.environ.get("SKIP_INTEGRATION_TESTS", "").lower() in ("1", "true", "yes"): - pytest.exit("Skipping integration tests (SKIP_INTEGRATION_TESTS is set)", 0) - - # Store shared keyspace name - config.shared_test_keyspace = "integration_test" - - # Get contact points from environment - # Force IPv4 by replacing localhost with 127.0.0.1 - contact_points = os.environ.get("CASSANDRA_CONTACT_POINTS", "127.0.0.1").split(",") - config.cassandra_contact_points = [ - "127.0.0.1" if cp.strip() == "localhost" else cp.strip() for cp in contact_points - ] - - # Check if Cassandra is available - cassandra_port = int(os.environ.get("CASSANDRA_PORT", "9042")) - available = False - for contact_point in config.cassandra_contact_points: - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(2) - result = sock.connect_ex((contact_point, cassandra_port)) - sock.close() - if result == 0: - available = True - print(f"Found Cassandra on {contact_point}:{cassandra_port}") - break - except Exception: - pass - - if not available: - pytest.exit( - f"Cassandra is not available on {config.cassandra_contact_points}:{cassandra_port}\n" - f"Please start Cassandra using: make cassandra-start\n" - f"Or set CASSANDRA_CONTACT_POINTS environment variable to point to your Cassandra instance", - 1, - ) - - -@pytest_asyncio.fixture(scope="session") -async def shared_cluster(pytestconfig): - """Create a shared cluster for all integration tests.""" - cluster = AsyncCluster( - contact_points=pytestconfig.cassandra_contact_points, - protocol_version=5, - connect_timeout=10.0, - ) - yield cluster - await cluster.shutdown() - - -@pytest_asyncio.fixture(scope="session") -async def shared_keyspace_setup(shared_cluster, pytestconfig): - """Create shared keyspace for all integration tests.""" - session = await shared_cluster.connect() - - try: - # Create the shared keyspace - keyspace_name = pytestconfig.shared_test_keyspace - await session.execute( - f""" - CREATE KEYSPACE IF NOT EXISTS {keyspace_name} - WITH REPLICATION = {{'class': 'SimpleStrategy', 'replication_factor': 1}} - """ - ) - print(f"Created shared keyspace: {keyspace_name}") - - yield keyspace_name - - finally: - # Clean up the keyspace after all tests - try: - await session.execute(f"DROP KEYSPACE IF EXISTS {pytestconfig.shared_test_keyspace}") - print(f"Dropped shared keyspace: {pytestconfig.shared_test_keyspace}") - except Exception as e: - print(f"Warning: Failed to 
drop shared keyspace: {e}") - - await session.close() - - -@pytest_asyncio.fixture(scope="function") -async def cassandra_cluster(shared_cluster): - """Use the shared cluster for testing.""" - # Just pass through the shared cluster - don't create a new one - yield shared_cluster - - -@pytest_asyncio.fixture(scope="function") -async def cassandra_session(cassandra_cluster, shared_keyspace_setup, pytestconfig): - """Create an async Cassandra session using shared keyspace with isolated tables.""" - session = await cassandra_cluster.connect() - - # Use the shared keyspace - keyspace = pytestconfig.shared_test_keyspace - await session.set_keyspace(keyspace) - - # Track tables created for this test - created_tables = [] - - # Create a unique users table for tests that expect it - users_table = generate_unique_table("users") - await session.execute( - f""" - CREATE TABLE IF NOT EXISTS {users_table} ( - id UUID PRIMARY KEY, - name TEXT, - email TEXT, - age INT - ) - """ - ) - created_tables.append(users_table) - - # Store the table name in session for tests to use - session._test_users_table = users_table - session._created_tables = created_tables - - yield session - - # Cleanup tables after test - try: - for table in created_tables: - await session.execute(f"DROP TABLE IF EXISTS {table}") - except Exception: - pass - - # Don't close the session - it's from the shared cluster - # try: - # await session.close() - # except Exception: - # pass - - -@pytest_asyncio.fixture(scope="function") -async def test_table_manager(cassandra_cluster, shared_keyspace_setup, pytestconfig): - """Provide a test table manager for isolated table creation.""" - session = await cassandra_cluster.connect() - - # Use the shared keyspace - keyspace = pytestconfig.shared_test_keyspace - await session.set_keyspace(keyspace) - - async with TestTableManager(session, keyspace=keyspace, use_shared_keyspace=True) as manager: - yield manager - - # Don't close the session - it's from the shared cluster - # await session.close() - - -@pytest.fixture -def unique_keyspace(): - """Generate a unique keyspace name for test isolation.""" - return generate_unique_keyspace() - - -@pytest_asyncio.fixture(scope="function") -async def session_with_keyspace(cassandra_cluster, shared_keyspace_setup, pytestconfig): - """Create a session with shared keyspace already set.""" - session = await cassandra_cluster.connect() - keyspace = pytestconfig.shared_test_keyspace - - await session.set_keyspace(keyspace) - - # Track tables created for cleanup - session._created_tables = [] - - yield session, keyspace - - # Cleanup tables - try: - for table in getattr(session, "_created_tables", []): - await session.execute(f"DROP TABLE IF EXISTS {table}") - except Exception: - pass - - # Don't close the session - it's from the shared cluster - # try: - # await session.close() - # except Exception: - # pass diff --git a/tests/integration/test_basic_operations.py b/tests/integration/test_basic_operations.py deleted file mode 100644 index 2f9b3c3..0000000 --- a/tests/integration/test_basic_operations.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -Integration tests for basic Cassandra operations. - -This file focuses on connection management, error handling, async patterns, -and concurrent operations. Basic CRUD operations have been moved to -test_crud_operations.py. 
-""" - -import uuid - -import pytest -from cassandra import InvalidRequest -from test_utils import generate_unique_table - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestBasicOperations: - """Test connection, error handling, and async patterns with real Cassandra.""" - - async def test_connection_and_keyspace( - self, cassandra_cluster, shared_keyspace_setup, pytestconfig - ): - """ - Test connecting to Cassandra and using shared keyspace. - - What this tests: - --------------- - 1. Cluster connection works - 2. Keyspace can be set - 3. Tables can be created - 4. Cleanup is performed - - Why this matters: - ---------------- - Connection management is fundamental: - - Must handle network issues - - Keyspace isolation important - - Resource cleanup critical - - Basic connectivity is the - foundation of all operations. - """ - session = await cassandra_cluster.connect() - - try: - # Use the shared keyspace - keyspace = pytestconfig.shared_test_keyspace - await session.set_keyspace(keyspace) - assert session.keyspace == keyspace - - # Create a test table in the shared keyspace - table_name = generate_unique_table("test_conn") - try: - await session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - data TEXT - ) - """ - ) - - # Verify table exists - await session.execute(f"SELECT * FROM {table_name} LIMIT 1") - - except Exception as e: - pytest.fail(f"Failed to create or query table: {e}") - finally: - # Cleanup table - await session.execute(f"DROP TABLE IF EXISTS {table_name}") - finally: - await session.close() - - async def test_async_iteration(self, cassandra_session): - """ - Test async iteration over results with proper patterns. - - What this tests: - --------------- - 1. Async for loop works - 2. Multiple rows handled - 3. Row attributes accessible - 4. No blocking in iteration - - Why this matters: - ---------------- - Async iteration enables: - - Non-blocking data processing - - Memory-efficient streaming - - Responsive applications - - Critical for handling large - result sets efficiently. - """ - # Use the unique users table created for this test - users_table = cassandra_session._test_users_table - - try: - # Insert test data - insert_stmt = await cassandra_session.prepare( - f""" - INSERT INTO {users_table} (id, name, email, age) - VALUES (?, ?, ?, ?) - """ - ) - - # Insert users with error handling - for i in range(10): - try: - await cassandra_session.execute( - insert_stmt, [uuid.uuid4(), f"User{i}", f"user{i}@example.com", 20 + i] - ) - except Exception as e: - pytest.fail(f"Failed to insert User{i}: {e}") - - # Select all users - select_all_stmt = await cassandra_session.prepare(f"SELECT * FROM {users_table}") - - try: - result = await cassandra_session.execute(select_all_stmt) - - # Iterate asynchronously with error handling - count = 0 - async for row in result: - assert hasattr(row, "name") - assert row.name.startswith("User") - count += 1 - - # We should have at least 10 users (may have more from other tests) - assert count >= 10 - except Exception as e: - pytest.fail(f"Failed to iterate over results: {e}") - - except Exception as e: - pytest.fail(f"Test setup failed: {e}") - - async def test_error_handling(self, cassandra_session): - """ - Test error handling for invalid queries. - - What this tests: - --------------- - 1. Invalid table errors caught - 2. Invalid keyspace errors caught - 3. Syntax errors propagated - 4. 
Error messages preserved - - Why this matters: - ---------------- - Proper error handling enables: - - Debugging query issues - - Graceful failure modes - - Clear error messages - - Applications need clear errors - to handle failures properly. - """ - # Test invalid table query - with pytest.raises(InvalidRequest) as exc_info: - await cassandra_session.execute("SELECT * FROM non_existent_table") - assert "does not exist" in str(exc_info.value) or "unconfigured table" in str( - exc_info.value - ) - - # Test invalid keyspace - should fail - with pytest.raises(InvalidRequest) as exc_info: - await cassandra_session.set_keyspace("non_existent_keyspace") - assert "Keyspace" in str(exc_info.value) or "does not exist" in str(exc_info.value) - - # Test syntax error - with pytest.raises(Exception) as exc_info: - await cassandra_session.execute("INVALID SQL QUERY") - # Could be SyntaxException or InvalidRequest depending on driver version - assert "Syntax" in str(exc_info.value) or "Invalid" in str(exc_info.value) diff --git a/tests/integration/test_batch_and_lwt_operations.py b/tests/integration/test_batch_and_lwt_operations.py deleted file mode 100644 index 1a10d87..0000000 --- a/tests/integration/test_batch_and_lwt_operations.py +++ /dev/null @@ -1,1115 +0,0 @@ -""" -Consolidated integration tests for batch and LWT (Lightweight Transaction) operations. - -This module combines atomic operation tests from multiple files, focusing on -batch operations and lightweight transactions (conditional statements). - -Tests consolidated from: -- test_batch_operations.py - All batch operation types -- test_lwt_operations.py - All lightweight transaction operations - -Test Organization: -================== -1. Batch Operations - LOGGED, UNLOGGED, and COUNTER batches -2. Lightweight Transactions - IF EXISTS, IF NOT EXISTS, conditional updates -3. Atomic Operation Patterns - Combined usage patterns -4. Error Scenarios - Invalid combinations and error handling -""" - -import asyncio -import time -import uuid -from datetime import datetime, timezone - -import pytest -from cassandra import InvalidRequest -from cassandra.query import BatchStatement, BatchType, ConsistencyLevel, SimpleStatement -from test_utils import generate_unique_table - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestBatchOperations: - """Test batch operations with real Cassandra.""" - - # ======================================== - # Basic Batch Operations - # ======================================== - - async def test_logged_batch(self, cassandra_session, shared_keyspace_setup): - """ - Test LOGGED batch operations for atomicity. - - What this tests: - --------------- - 1. LOGGED batch guarantees atomicity - 2. All statements succeed or fail together - 3. Batch with prepared statements - 4. Performance implications - - Why this matters: - ---------------- - LOGGED batches provide ACID guarantees at the cost of - performance. Used for related mutations that must succeed together. 
- """ - # Create test table - table_name = generate_unique_table("test_logged_batch") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - partition_key TEXT, - clustering_key INT, - value TEXT, - PRIMARY KEY (partition_key, clustering_key) - ) - """ - ) - - # Prepare statements - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (partition_key, clustering_key, value) VALUES (?, ?, ?)" - ) - - # Create LOGGED batch (default) - batch = BatchStatement(batch_type=BatchType.LOGGED) - partition = "batch_test" - - # Add multiple statements - for i in range(5): - batch.add(insert_stmt, (partition, i, f"value_{i}")) - - # Execute batch - await cassandra_session.execute(batch) - - # Verify all inserts succeeded atomically - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE partition_key = %s", (partition,) - ) - rows = list(result) - assert len(rows) == 5 - - # Verify order and values - rows.sort(key=lambda r: r.clustering_key) - for i, row in enumerate(rows): - assert row.clustering_key == i - assert row.value == f"value_{i}" - - async def test_unlogged_batch(self, cassandra_session, shared_keyspace_setup): - """ - Test UNLOGGED batch operations for performance. - - What this tests: - --------------- - 1. UNLOGGED batch for performance - 2. No atomicity guarantees - 3. Multiple partitions in batch - 4. Large batch handling - - Why this matters: - ---------------- - UNLOGGED batches offer better performance but no atomicity. - Best for mutations to different partitions. - """ - # Create test table - table_name = generate_unique_table("test_unlogged_batch") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - category TEXT, - value INT, - created_at TIMESTAMP - ) - """ - ) - - # Prepare statement - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, category, value, created_at) VALUES (?, ?, ?, ?)" - ) - - # Create UNLOGGED batch - batch = BatchStatement(batch_type=BatchType.UNLOGGED) - ids = [] - - # Add many statements (different partitions) - for i in range(50): - id = uuid.uuid4() - ids.append(id) - batch.add(insert_stmt, (id, f"cat_{i % 5}", i, datetime.now(timezone.utc))) - - # Execute batch - start = time.time() - await cassandra_session.execute(batch) - duration = time.time() - start - - # Verify inserts (may not all succeed in failure scenarios) - success_count = 0 - for id in ids: - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (id,) - ) - if result.one(): - success_count += 1 - - # In normal conditions, all should succeed - assert success_count == 50 - print(f"UNLOGGED batch of 50 inserts took {duration:.3f}s") - - async def test_counter_batch(self, cassandra_session, shared_keyspace_setup): - """ - Test COUNTER batch operations. - - What this tests: - --------------- - 1. Counter-only batches - 2. Multiple counter updates - 3. Counter batch atomicity - 4. Concurrent counter updates - - Why this matters: - ---------------- - Counter batches have special semantics and restrictions. - They can only contain counter operations. 
- """ - # Create counter table - table_name = generate_unique_table("test_counter_batch") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id TEXT PRIMARY KEY, - count1 COUNTER, - count2 COUNTER, - count3 COUNTER - ) - """ - ) - - # Prepare counter update statements - update1 = await cassandra_session.prepare( - f"UPDATE {table_name} SET count1 = count1 + ? WHERE id = ?" - ) - update2 = await cassandra_session.prepare( - f"UPDATE {table_name} SET count2 = count2 + ? WHERE id = ?" - ) - update3 = await cassandra_session.prepare( - f"UPDATE {table_name} SET count3 = count3 + ? WHERE id = ?" - ) - - # Create COUNTER batch - batch = BatchStatement(batch_type=BatchType.COUNTER) - counter_id = "test_counter" - - # Add counter updates - batch.add(update1, (10, counter_id)) - batch.add(update2, (20, counter_id)) - batch.add(update3, (30, counter_id)) - - # Execute batch - await cassandra_session.execute(batch) - - # Verify counter values - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (counter_id,) - ) - row = result.one() - assert row.count1 == 10 - assert row.count2 == 20 - assert row.count3 == 30 - - # Test concurrent counter batches - async def increment_counters(increment): - batch = BatchStatement(batch_type=BatchType.COUNTER) - batch.add(update1, (increment, counter_id)) - batch.add(update2, (increment * 2, counter_id)) - batch.add(update3, (increment * 3, counter_id)) - await cassandra_session.execute(batch) - - # Run concurrent increments - await asyncio.gather(*[increment_counters(1) for _ in range(10)]) - - # Verify final values - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (counter_id,) - ) - row = result.one() - assert row.count1 == 20 # 10 + 10*1 - assert row.count2 == 40 # 20 + 10*2 - assert row.count3 == 60 # 30 + 10*3 - - # ======================================== - # Advanced Batch Features - # ======================================== - - async def test_batch_with_consistency_levels(self, cassandra_session, shared_keyspace_setup): - """ - Test batch operations with different consistency levels. - - What this tests: - --------------- - 1. Batch consistency level configuration - 2. Impact on atomicity guarantees - 3. Performance vs consistency trade-offs - - Why this matters: - ---------------- - Consistency levels affect batch behavior and guarantees. 
- """ - # Create test table - table_name = generate_unique_table("test_batch_consistency") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - # Test different consistency levels - consistency_levels = [ - ConsistencyLevel.ONE, - ConsistencyLevel.QUORUM, - ConsistencyLevel.ALL, - ] - - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, data) VALUES (?, ?)" - ) - - for cl in consistency_levels: - batch = BatchStatement(consistency_level=cl) - batch_id = uuid.uuid4() - - # Add statement to batch - cl_name = ( - ConsistencyLevel.name_of(cl) if hasattr(ConsistencyLevel, "name_of") else str(cl) - ) - batch.add(insert_stmt, (batch_id, f"consistency_{cl_name}")) - - # Execute with specific consistency - await cassandra_session.execute(batch) - - # Verify insert - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (batch_id,) - ) - assert result.one().data == f"consistency_{cl_name}" - - async def test_batch_with_custom_timestamp(self, cassandra_session, shared_keyspace_setup): - """ - Test batch operations with custom timestamps. - - What this tests: - --------------- - 1. Custom timestamp in batches - 2. Timestamp consistency across batch - 3. Time-based conflict resolution - - Why this matters: - ---------------- - Custom timestamps allow for precise control over - write ordering and conflict resolution. - """ - # Create test table - table_name = generate_unique_table("test_batch_timestamp") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id TEXT PRIMARY KEY, - value INT, - updated_at TIMESTAMP - ) - """ - ) - - row_id = "timestamp_test" - - # First write with current timestamp - await cassandra_session.execute( - f"INSERT INTO {table_name} (id, value, updated_at) VALUES (%s, %s, toTimestamp(now()))", - (row_id, 100), - ) - - # Custom timestamp in microseconds (older than current) - custom_timestamp = int((time.time() - 3600) * 1000000) # 1 hour ago - - insert_stmt = SimpleStatement( - f"INSERT INTO {table_name} (id, value, updated_at) VALUES (%s, %s, %s) USING TIMESTAMP {custom_timestamp}", - ) - - # This write should be ignored due to older timestamp - await cassandra_session.execute(insert_stmt, (row_id, 50, datetime.now(timezone.utc))) - - # Verify the newer value wins - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (row_id,) - ) - assert result.one().value == 100 # Original value retained - - # Now use newer timestamp - newer_timestamp = int((time.time() + 3600) * 1000000) # 1 hour future - newer_stmt = SimpleStatement( - f"INSERT INTO {table_name} (id, value) VALUES (%s, %s) USING TIMESTAMP {newer_timestamp}", - ) - - await cassandra_session.execute(newer_stmt, (row_id, 200)) - - # Verify newer timestamp wins - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (row_id,) - ) - assert result.one().value == 200 - - async def test_large_batch_warning(self, cassandra_session, shared_keyspace_setup): - """ - Test large batch size warnings and limits. - - What this tests: - --------------- - 1. Batch size thresholds - 2. Warning generation - 3. Performance impact of large batches - - Why this matters: - ---------------- - Large batches can cause performance issues and - coordinator node stress. 
- """ - # Create test table - table_name = generate_unique_table("test_large_batch") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - # Create a large batch - batch = BatchStatement(batch_type=BatchType.UNLOGGED) - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, data) VALUES (?, ?)" - ) - - # Add many statements with large data - # Reduce size to avoid batch too large error - large_data = "x" * 100 # 100 bytes per row - for i in range(50): # 5KB total - batch.add(insert_stmt, (uuid.uuid4(), large_data)) - - # Execute large batch (may generate warnings) - await cassandra_session.execute(batch) - - # Note: In production, monitor for batch size warnings in logs - - # ======================================== - # Batch Error Scenarios - # ======================================== - - async def test_mixed_batch_types_error(self, cassandra_session, shared_keyspace_setup): - """ - Test error handling for invalid batch combinations. - - What this tests: - --------------- - 1. Mixing counter and regular operations - 2. Error propagation - 3. Batch validation - - Why this matters: - ---------------- - Cassandra enforces strict rules about batch content. - Counter and regular operations cannot be mixed. - """ - # Create regular and counter tables - regular_table = generate_unique_table("test_regular") - counter_table = generate_unique_table("test_counter") - - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {regular_table} ( - id TEXT PRIMARY KEY, - value INT - ) - """ - ) - - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {counter_table} ( - id TEXT PRIMARY KEY, - count COUNTER - ) - """ - ) - - # Try to mix regular and counter operations - batch = BatchStatement() - - # This should fail - cannot mix regular and counter operations - regular_stmt = await cassandra_session.prepare( - f"INSERT INTO {regular_table} (id, value) VALUES (?, ?)" - ) - counter_stmt = await cassandra_session.prepare( - f"UPDATE {counter_table} SET count = count + ? WHERE id = ?" - ) - - batch.add(regular_stmt, ("test1", 100)) - batch.add(counter_stmt, (1, "test1")) - - # Should raise InvalidRequest - with pytest.raises(InvalidRequest) as exc_info: - await cassandra_session.execute(batch) - - assert "counter" in str(exc_info.value).lower() - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestLWTOperations: - """Test Lightweight Transaction (LWT) operations with real Cassandra.""" - - # ======================================== - # Basic LWT Operations - # ======================================== - - async def test_insert_if_not_exists(self, cassandra_session, shared_keyspace_setup): - """ - Test INSERT IF NOT EXISTS operations. - - What this tests: - --------------- - 1. Successful conditional insert - 2. Failed conditional insert (already exists) - 3. Result parsing ([applied] column) - 4. Race condition handling - - Why this matters: - ---------------- - IF NOT EXISTS prevents duplicate inserts and provides - atomic check-and-set semantics. 
- """ - # Create test table - table_name = generate_unique_table("test_lwt_insert") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - username TEXT, - email TEXT, - created_at TIMESTAMP - ) - """ - ) - - # Prepare conditional insert - insert_stmt = await cassandra_session.prepare( - f""" - INSERT INTO {table_name} (id, username, email, created_at) - VALUES (?, ?, ?, ?) - IF NOT EXISTS - """ - ) - - user_id = uuid.uuid4() - username = "testuser" - email = "test@example.com" - created = datetime.now(timezone.utc) - - # First insert should succeed - result = await cassandra_session.execute(insert_stmt, (user_id, username, email, created)) - row = result.one() - assert row.applied is True - - # Second insert with same ID should fail - result2 = await cassandra_session.execute( - insert_stmt, (user_id, "different", "different@example.com", created) - ) - row2 = result2.one() - assert row2.applied is False - - # Failed insert returns existing values - assert row2.username == username - assert row2.email == email - - # Verify data integrity - result3 = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (user_id,) - ) - final_row = result3.one() - assert final_row.username == username # Original value preserved - assert final_row.email == email - - async def test_update_if_condition(self, cassandra_session, shared_keyspace_setup): - """ - Test UPDATE IF condition operations. - - What this tests: - --------------- - 1. Successful conditional update - 2. Failed conditional update - 3. Multi-column conditions - 4. NULL value conditions - - Why this matters: - ---------------- - Conditional updates enable optimistic locking and - safe state transitions. - """ - # Create test table - table_name = generate_unique_table("test_lwt_update") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - status TEXT, - version INT, - updated_by TEXT, - updated_at TIMESTAMP - ) - """ - ) - - # Insert initial data - doc_id = uuid.uuid4() - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, status, version, updated_by) VALUES (?, ?, ?, ?)" - ) - await cassandra_session.execute(insert_stmt, (doc_id, "draft", 1, "user1")) - - # Conditional update - should succeed - update_stmt = await cassandra_session.prepare( - f""" - UPDATE {table_name} - SET status = ?, version = ?, updated_by = ?, updated_at = ? - WHERE id = ? - IF status = ? AND version = ? - """ - ) - - result = await cassandra_session.execute( - update_stmt, ("published", 2, "user2", datetime.now(timezone.utc), doc_id, "draft", 1) - ) - row = result.one() - - # Debug: print the actual row to understand structure - # print(f"First update result: {row}") - - # Check if update was applied - if hasattr(row, "applied"): - applied = row.applied - elif isinstance(row[0], bool): - applied = row[0] - else: - # Try to find the [applied] column by name - applied = getattr(row, "[applied]", None) - if applied is None and hasattr(row, "_asdict"): - row_dict = row._asdict() - applied = row_dict.get("[applied]", row_dict.get("applied", False)) - - if not applied: - # First update failed, let's check why - verify_result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (doc_id,) - ) - current = verify_result.one() - pytest.skip( - f"First LWT update failed. 
Current state: status={current.status}, version={current.version}" - ) - - # Verify the update worked - verify_result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (doc_id,) - ) - current_state = verify_result.one() - assert current_state.status == "published" - assert current_state.version == 2 - - # Try to update with wrong version - should fail - result2 = await cassandra_session.execute( - update_stmt, - ("archived", 3, "user3", datetime.now(timezone.utc), doc_id, "published", 1), - ) - row2 = result2.one() - # This should fail and return current values - assert row2[0] is False or getattr(row2, "applied", True) is False - - # Update with correct version - should succeed - result3 = await cassandra_session.execute( - update_stmt, - ("archived", 3, "user3", datetime.now(timezone.utc), doc_id, "published", 2), - ) - result3.one() # Check that it succeeded - - # Verify final state - final_result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (doc_id,) - ) - final_state = final_result.one() - assert final_state.status == "archived" - assert final_state.version == 3 - - async def test_delete_if_exists(self, cassandra_session, shared_keyspace_setup): - """ - Test DELETE IF EXISTS operations. - - What this tests: - --------------- - 1. Successful conditional delete - 2. Failed conditional delete (doesn't exist) - 3. DELETE IF with column conditions - - Why this matters: - ---------------- - Conditional deletes prevent removing non-existent data - and enable safe cleanup operations. - """ - # Create test table - table_name = generate_unique_table("test_lwt_delete") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - type TEXT, - active BOOLEAN - ) - """ - ) - - # Insert test data - record_id = uuid.uuid4() - await cassandra_session.execute( - f"INSERT INTO {table_name} (id, type, active) VALUES (%s, %s, %s)", - (record_id, "temporary", True), - ) - - # Conditional delete - only if inactive - delete_stmt = await cassandra_session.prepare( - f"DELETE FROM {table_name} WHERE id = ? IF active = ?" - ) - - # Should fail - record is active - result = await cassandra_session.execute(delete_stmt, (record_id, False)) - assert result.one().applied is False - - # Update to inactive - await cassandra_session.execute( - f"UPDATE {table_name} SET active = false WHERE id = %s", (record_id,) - ) - - # Now delete should succeed - result2 = await cassandra_session.execute(delete_stmt, (record_id, False)) - assert result2.one()[0] is True # [applied] column - - # Verify deletion - result3 = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (record_id,) - ) - row = result3.one() - # In Cassandra, deleted rows may still appear with NULL/false values - # The behavior depends on Cassandra version and tombstone handling - if row is not None: - # Either all columns are NULL or active is False (due to deletion) - assert (row.type is None and row.active is None) or row.active is False - - # ======================================== - # Advanced LWT Patterns - # ======================================== - - async def test_concurrent_lwt_operations(self, cassandra_session, shared_keyspace_setup): - """ - Test concurrent LWT operations and race conditions. - - What this tests: - --------------- - 1. Multiple concurrent IF NOT EXISTS - 2. Race condition resolution - 3. Consistency guarantees - 4. 
Performance impact - - Why this matters: - ---------------- - LWTs provide linearizable consistency but at a - performance cost. Understanding race behavior is critical. - """ - # Create test table - table_name = generate_unique_table("test_concurrent_lwt") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - resource_id TEXT PRIMARY KEY, - owner TEXT, - acquired_at TIMESTAMP - ) - """ - ) - - # Prepare acquire statement - acquire_stmt = await cassandra_session.prepare( - f""" - INSERT INTO {table_name} (resource_id, owner, acquired_at) - VALUES (?, ?, ?) - IF NOT EXISTS - """ - ) - - resource = "shared_resource" - - # Simulate concurrent acquisition attempts - async def try_acquire(worker_id): - result = await cassandra_session.execute( - acquire_stmt, (resource, f"worker_{worker_id}", datetime.now(timezone.utc)) - ) - return worker_id, result.one().applied - - # Run many concurrent attempts - results = await asyncio.gather(*[try_acquire(i) for i in range(20)], return_exceptions=True) - - # Analyze results - successful = [] - failed = [] - for result in results: - if isinstance(result, Exception): - continue # Skip exceptions - if isinstance(result, tuple) and len(result) == 2: - w, r = result - if r: - successful.append((w, r)) - else: - failed.append((w, r)) - - # Exactly one should succeed - assert len(successful) == 1 - assert len(failed) == 19 - - # Verify final state - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE resource_id = %s", (resource,) - ) - row = result.one() - winner_id = successful[0][0] - assert row.owner == f"worker_{winner_id}" - - async def test_optimistic_locking_pattern(self, cassandra_session, shared_keyspace_setup): - """ - Test optimistic locking pattern with LWT. - - What this tests: - --------------- - 1. Read-modify-write with version checking - 2. Retry logic for conflicts - 3. ABA problem prevention - 4. Performance considerations - - Why this matters: - ---------------- - Optimistic locking is a common pattern for handling - concurrent modifications without distributed locks. - """ - # Create versioned document table - table_name = generate_unique_table("test_optimistic_lock") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - content TEXT, - version BIGINT, - last_modified TIMESTAMP - ) - """ - ) - - # Insert document - doc_id = uuid.uuid4() - await cassandra_session.execute( - f"INSERT INTO {table_name} (id, content, version, last_modified) VALUES (%s, %s, %s, %s)", - (doc_id, "Initial content", 1, datetime.now(timezone.utc)), - ) - - # Prepare optimistic update - update_stmt = await cassandra_session.prepare( - f""" - UPDATE {table_name} - SET content = ?, version = ?, last_modified = ? - WHERE id = ? - IF version = ? 
- """ - ) - - # Simulate concurrent modifications - async def modify_document(modification): - max_retries = 3 - for attempt in range(max_retries): - # Read current state - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (doc_id,) - ) - current = result.one() - - # Modify content - new_content = f"{current.content} + {modification}" - new_version = current.version + 1 - - # Try to update - update_result = await cassandra_session.execute( - update_stmt, - (new_content, new_version, datetime.now(timezone.utc), doc_id, current.version), - ) - - update_row = update_result.one() - # Check if update was applied - if hasattr(update_row, "applied"): - applied = update_row.applied - else: - applied = update_row[0] - - if applied: - return True - - # Retry with exponential backoff - await asyncio.sleep(0.1 * (2**attempt)) - - return False - - # Run concurrent modifications - results = await asyncio.gather(*[modify_document(f"Mod{i}") for i in range(5)]) - - # Count successful updates - successful_updates = sum(1 for r in results if r is True) - - # Verify final state - final = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (doc_id,) - ) - final_row = final.one() - - # Version should have increased by the number of successful updates - assert final_row.version == 1 + successful_updates - - # If no updates succeeded, skip the test - if successful_updates == 0: - pytest.skip("No concurrent updates succeeded - may be timing/load issue") - - # Content should contain modifications if any succeeded - if successful_updates > 0: - assert "Mod" in final_row.content - - # ======================================== - # LWT Error Scenarios - # ======================================== - - async def test_lwt_timeout_handling(self, cassandra_session, shared_keyspace_setup): - """ - Test LWT timeout scenarios and handling. - - What this tests: - --------------- - 1. LWT with short timeout - 2. Timeout error propagation - 3. State consistency after timeout - - Why this matters: - ---------------- - LWTs involve multiple round trips and can timeout. - Understanding timeout behavior is crucial. - """ - # Create test table - table_name = generate_unique_table("test_lwt_timeout") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - value TEXT - ) - """ - ) - - # Prepare LWT statement with very short timeout - insert_stmt = SimpleStatement( - f"INSERT INTO {table_name} (id, value) VALUES (%s, %s) IF NOT EXISTS", - consistency_level=ConsistencyLevel.QUORUM, - ) - - test_id = uuid.uuid4() - - # Normal LWT should work - result = await cassandra_session.execute(insert_stmt, (test_id, "test_value")) - assert result.one()[0] is True # [applied] column - - # Note: Actually triggering timeout requires network latency simulation - # This test documents the expected behavior - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestAtomicPatterns: - """Test combined atomic operation patterns.""" - - async def test_lwt_not_supported_in_batch(self, cassandra_session, shared_keyspace_setup): - """ - Test that LWT operations are not supported in batches. - - What this tests: - --------------- - 1. LWT in batch raises error - 2. Error message clarity - 3. Alternative patterns - - Why this matters: - ---------------- - This is a common mistake. LWTs cannot be used in batches - due to their special consistency requirements. 
- """ - # Create test table - table_name = generate_unique_table("test_lwt_batch") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - value TEXT - ) - """ - ) - - # Try to use LWT in batch - batch = BatchStatement() - - # This should fail - use raw query to ensure it's recognized as LWT - test_id = uuid.uuid4() - lwt_query = f"INSERT INTO {table_name} (id, value) VALUES ({test_id}, 'test') IF NOT EXISTS" - - batch.add(SimpleStatement(lwt_query)) - - # Some Cassandra versions might not error immediately, so check result - try: - await cassandra_session.execute(batch) - # If it succeeded, it shouldn't have applied the LWT semantics - # This is actually unexpected, but let's handle it - pytest.skip("This Cassandra version seems to allow LWT in batch") - except InvalidRequest as e: - # This is what we expect - assert ( - "conditional" in str(e).lower() - or "lwt" in str(e).lower() - or "batch" in str(e).lower() - ) - - async def test_read_before_write_pattern(self, cassandra_session, shared_keyspace_setup): - """ - Test read-before-write pattern for complex updates. - - What this tests: - --------------- - 1. Read current state - 2. Apply business logic - 3. Conditional update based on read - 4. Retry on conflict - - Why this matters: - ---------------- - Complex business logic often requires reading current - state before deciding on updates. - """ - # Create account table - table_name = generate_unique_table("test_account") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - account_id UUID PRIMARY KEY, - balance DECIMAL, - status TEXT, - version BIGINT - ) - """ - ) - - # Create account - account_id = uuid.uuid4() - initial_balance = 1000.0 - await cassandra_session.execute( - f"INSERT INTO {table_name} (account_id, balance, status, version) VALUES (%s, %s, %s, %s)", - (account_id, initial_balance, "active", 1), - ) - - # Prepare conditional update - update_stmt = await cassandra_session.prepare( - f""" - UPDATE {table_name} - SET balance = ?, version = ? - WHERE account_id = ? - IF status = ? AND version = ? 
- """ - ) - - # Withdraw function with business logic - async def withdraw(amount): - max_retries = 3 - for attempt in range(max_retries): - # Read current state - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE account_id = %s", (account_id,) - ) - account = result.one() - - # Business logic checks - if account.status != "active": - raise Exception("Account not active") - - if account.balance < amount: - raise Exception("Insufficient funds") - - # Calculate new balance - new_balance = float(account.balance) - amount - new_version = account.version + 1 - - # Try conditional update - update_result = await cassandra_session.execute( - update_stmt, (new_balance, new_version, account_id, "active", account.version) - ) - - if update_result.one()[0]: # [applied] column - return new_balance - - # Retry on conflict - await asyncio.sleep(0.1) - - raise Exception("Max retries exceeded") - - # Test concurrent withdrawals - async def safe_withdraw(amount): - try: - return await withdraw(amount) - except Exception as e: - return str(e) - - # Multiple concurrent withdrawals - results = await asyncio.gather( - safe_withdraw(100), - safe_withdraw(200), - safe_withdraw(300), - safe_withdraw(600), # This might fail due to insufficient funds - ) - - # Check final balance - final_result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE account_id = %s", (account_id,) - ) - final_account = final_result.one() - - # Some withdrawals may have failed - successful_withdrawals = [r for r in results if isinstance(r, float)] - failed_withdrawals = [r for r in results if isinstance(r, str)] - - # If all withdrawals failed, skip test - if len(successful_withdrawals) == 0: - pytest.skip(f"All withdrawals failed: {failed_withdrawals}") - - total_withdrawn = initial_balance - float(final_account.balance) - - # Balance should be consistent - assert total_withdrawn >= 0 - assert float(final_account.balance) >= 0 - # Version should increase only if withdrawals succeeded - assert final_account.version >= 1 diff --git a/tests/integration/test_concurrent_and_stress_operations.py b/tests/integration/test_concurrent_and_stress_operations.py deleted file mode 100644 index ebb9c8a..0000000 --- a/tests/integration/test_concurrent_and_stress_operations.py +++ /dev/null @@ -1,1137 +0,0 @@ -""" -Consolidated integration tests for concurrent operations and stress testing. - -This module combines all concurrent operation tests from multiple files, -providing comprehensive coverage of high-concurrency scenarios. - -Tests consolidated from: -- test_concurrent_operations.py - Basic concurrent operations -- test_stress.py - High-volume stress testing -- Various concurrent tests from other files - -Test Organization: -================== -1. Basic Concurrent Operations - Read/write/mixed operations -2. High-Volume Stress Tests - Extreme concurrency scenarios -3. Sustained Load Testing - Long-running concurrent operations -4. Connection Pool Testing - Behavior at connection limits -5. 
Wide Row Performance - Concurrent operations on large data -""" - -import asyncio -import random -import statistics -import time -import uuid -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime, timezone - -import pytest -import pytest_asyncio -from cassandra.cluster import Cluster as SyncCluster -from cassandra.query import BatchStatement, BatchType - -from async_cassandra import AsyncCassandraSession, AsyncCluster, StreamConfig - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestConcurrentOperations: - """Test basic concurrent operations with real Cassandra.""" - - # ======================================== - # Basic Concurrent Operations - # ======================================== - - async def test_concurrent_reads(self, cassandra_session: AsyncCassandraSession): - """ - Test high-concurrency read operations. - - What this tests: - --------------- - 1. 1000 concurrent read operations - 2. Connection pool handling - 3. Read performance under load - 4. No interference between reads - - Why this matters: - ---------------- - Read-heavy workloads are common in production. - The driver must handle many concurrent reads efficiently. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - # Insert test data first - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - - test_ids = [] - for i in range(100): - test_id = uuid.uuid4() - test_ids.append(test_id) - await cassandra_session.execute( - insert_stmt, [test_id, f"User {i}", f"user{i}@test.com", 20 + (i % 50)] - ) - - # Perform 1000 concurrent reads - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {users_table} WHERE id = ?") - - async def read_record(record_id): - start = time.time() - result = await cassandra_session.execute(select_stmt, [record_id]) - duration = time.time() - start - rows = [] - async for row in result: - rows.append(row) - return rows[0] if rows else None, duration - - # Create 1000 read tasks (reading the same 100 records multiple times) - tasks = [] - for i in range(1000): - record_id = test_ids[i % len(test_ids)] - tasks.append(read_record(record_id)) - - start_time = time.time() - results = await asyncio.gather(*tasks) - total_time = time.time() - start_time - - # Verify results - successful_reads = [r for r, _ in results if r is not None] - assert len(successful_reads) == 1000 - - # Check performance - durations = [d for _, d in results] - avg_duration = sum(durations) / len(durations) - - print("\nConcurrent read test results:") - print(f" Total time: {total_time:.2f}s") - print(f" Average read latency: {avg_duration*1000:.2f}ms") - print(f" Reads per second: {1000/total_time:.0f}") - - # Performance assertions (relaxed for CI environments) - assert total_time < 15.0 # Should complete within 15 seconds - assert avg_duration < 0.5 # Average latency under 500ms - - async def test_concurrent_writes(self, cassandra_session: AsyncCassandraSession): - """ - Test high-concurrency write operations. - - What this tests: - --------------- - 1. 500 concurrent write operations - 2. Write performance under load - 3. No data loss or corruption - 4. Error handling under load - - Why this matters: - ---------------- - Write-heavy workloads test the driver's ability - to handle many concurrent mutations efficiently. 
- """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - - async def write_record(i): - start = time.time() - try: - await cassandra_session.execute( - insert_stmt, - [uuid.uuid4(), f"Concurrent User {i}", f"concurrent{i}@test.com", 25], - ) - return True, time.time() - start - except Exception: - return False, time.time() - start - - # Create 500 concurrent write tasks - tasks = [write_record(i) for i in range(500)] - - start_time = time.time() - results = await asyncio.gather(*tasks, return_exceptions=True) - total_time = time.time() - start_time - - # Count successes - successful_writes = sum(1 for r in results if isinstance(r, tuple) and r[0]) - failed_writes = 500 - successful_writes - - print("\nConcurrent write test results:") - print(f" Total time: {total_time:.2f}s") - print(f" Successful writes: {successful_writes}") - print(f" Failed writes: {failed_writes}") - print(f" Writes per second: {successful_writes/total_time:.0f}") - - # Should have very high success rate - assert successful_writes >= 495 # Allow up to 1% failure - assert total_time < 10.0 # Should complete within 10 seconds - - async def test_mixed_concurrent_operations(self, cassandra_session: AsyncCassandraSession): - """ - Test mixed read/write/update operations under high concurrency. - - What this tests: - --------------- - 1. 600 mixed operations (200 inserts, 300 reads, 100 updates) - 2. Different operation types running concurrently - 3. No interference between operation types - 4. Consistent performance across operation types - - Why this matters: - ---------------- - Real workloads mix different operation types. - The driver must handle them all efficiently. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - # Prepare statements - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {users_table} WHERE id = ?") - update_stmt = await cassandra_session.prepare( - f"UPDATE {users_table} SET age = ? WHERE id = ?" 
- ) - - # Pre-populate some data - existing_ids = [] - for i in range(50): - user_id = uuid.uuid4() - existing_ids.append(user_id) - await cassandra_session.execute( - insert_stmt, [user_id, f"Existing User {i}", f"existing{i}@test.com", 30] - ) - - # Define operation types - async def insert_operation(i): - return await cassandra_session.execute( - insert_stmt, - [uuid.uuid4(), f"New User {i}", f"new{i}@test.com", 25], - ) - - async def select_operation(user_id): - result = await cassandra_session.execute(select_stmt, [user_id]) - rows = [] - async for row in result: - rows.append(row) - return rows - - async def update_operation(user_id): - new_age = random.randint(20, 60) - return await cassandra_session.execute(update_stmt, [new_age, user_id]) - - # Create mixed operations - operations = [] - - # 200 inserts - for i in range(200): - operations.append(insert_operation(i)) - - # 300 selects - for _ in range(300): - user_id = random.choice(existing_ids) - operations.append(select_operation(user_id)) - - # 100 updates - for _ in range(100): - user_id = random.choice(existing_ids) - operations.append(update_operation(user_id)) - - # Shuffle to mix operation types - random.shuffle(operations) - - # Execute all operations concurrently - start_time = time.time() - results = await asyncio.gather(*operations, return_exceptions=True) - total_time = time.time() - start_time - - # Count results - successful = sum(1 for r in results if not isinstance(r, Exception)) - failed = sum(1 for r in results if isinstance(r, Exception)) - - print("\nMixed operations test results:") - print(f" Total operations: {len(operations)}") - print(f" Successful: {successful}") - print(f" Failed: {failed}") - print(f" Total time: {total_time:.2f}s") - print(f" Operations per second: {successful/total_time:.0f}") - - # Should have very high success rate - assert successful >= 590 # Allow up to ~2% failure - assert total_time < 15.0 # Should complete within 15 seconds - - async def test_concurrent_counter_updates(self, cassandra_session, shared_keyspace_setup): - """ - Test concurrent counter updates. - - What this tests: - --------------- - 1. 100 concurrent counter increments - 2. Counter consistency under concurrent updates - 3. No lost updates - 4. Correct final counter value - - Why this matters: - ---------------- - Counters have special semantics in Cassandra. - Concurrent updates must not lose increments. - """ - # Create counter table - table_name = f"concurrent_counters_{uuid.uuid4().hex[:8]}" - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id TEXT PRIMARY KEY, - count COUNTER - ) - """ - ) - - # Prepare update statement - update_stmt = await cassandra_session.prepare( - f"UPDATE {table_name} SET count = count + ? WHERE id = ?" 
-        )
-
-        counter_id = "test_counter"
-        increment_value = 1
-
-        # Perform concurrent increments
-        async def increment_counter(i):
-            try:
-                await cassandra_session.execute(update_stmt, (increment_value, counter_id))
-                return True
-            except Exception:
-                return False
-
-        # Run 100 concurrent increments
-        tasks = [increment_counter(i) for i in range(100)]
-        results = await asyncio.gather(*tasks)
-
-        successful_updates = sum(1 for r in results if r is True)
-
-        # Verify final counter value
-        result = await cassandra_session.execute(
-            f"SELECT count FROM {table_name} WHERE id = %s", (counter_id,)
-        )
-        row = result.one()
-        final_count = row.count if row else 0
-
-        print("\nCounter concurrent update results:")
-        print(f"  Successful updates: {successful_updates}/100")
-        print(f"  Final counter value: {final_count}")
-
-        # All updates should succeed and be reflected
-        assert successful_updates == 100
-        assert final_count == 100
-
-
-@pytest.mark.integration
-@pytest.mark.stress
-class TestStressScenarios:
-    """Stress test scenarios for async-cassandra."""
-
-    @pytest_asyncio.fixture
-    async def stress_session(self) -> AsyncCassandraSession:
-        """Create session optimized for stress testing."""
-        cluster = AsyncCluster(
-            contact_points=["localhost"],
-            # Optimize for high concurrency - use maximum threads
-            executor_threads=128,  # Maximum allowed
-        )
-        session = await cluster.connect()
-
-        # Create stress test keyspace
-        await session.execute(
-            """
-            CREATE KEYSPACE IF NOT EXISTS stress_test
-            WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1}
-            """
-        )
-        await session.set_keyspace("stress_test")
-
-        # Create tables for different scenarios
-        await session.execute("DROP TABLE IF EXISTS high_volume")
-        await session.execute(
-            """
-            CREATE TABLE high_volume (
-                partition_key UUID,
-                clustering_key TIMESTAMP,
-                data TEXT,
-                metrics MAP<TEXT, DOUBLE>,
-                tags SET<TEXT>,
-                PRIMARY KEY (partition_key, clustering_key)
-            ) WITH CLUSTERING ORDER BY (clustering_key DESC)
-            """
-        )
-
-        await session.execute("DROP TABLE IF EXISTS wide_rows")
-        await session.execute(
-            """
-            CREATE TABLE wide_rows (
-                partition_key UUID,
-                column_id INT,
-                data BLOB,
-                PRIMARY KEY (partition_key, column_id)
-            )
-            """
-        )
-
-        yield session
-
-        await session.close()
-        await cluster.shutdown()
-
-    @pytest.mark.asyncio
-    @pytest.mark.timeout(60)  # 1 minute timeout
-    async def test_extreme_concurrent_writes(self, stress_session: AsyncCassandraSession):
-        """
-        Test handling 10,000 concurrent write operations.
-
-        What this tests:
-        ---------------
-        1. Extreme write concurrency (10,000 operations)
-        2. Thread pool handling under extreme load
-        3. Memory usage under high concurrency
-        4. Error rates at scale
-        5. Latency distribution (P95, P99)
-
-        Why this matters:
-        ----------------
-        Production systems may experience traffic spikes.
-        The driver must handle extreme load gracefully.
-        """
-        insert_stmt = await stress_session.prepare(
-            """
-            INSERT INTO high_volume (partition_key, clustering_key, data, metrics, tags)
-            VALUES (?, ?, ?, ?, ?)
- """ - ) - - async def write_record(i: int): - """Write a single record with timing.""" - start = time.perf_counter() - try: - await stress_session.execute( - insert_stmt, - [ - uuid.uuid4(), - datetime.now(timezone.utc), - f"stress_test_data_{i}_" + "x" * random.randint(100, 1000), - { - "latency": random.random() * 100, - "throughput": random.random() * 1000, - "cpu": random.random() * 100, - }, - {f"tag{j}" for j in range(random.randint(1, 10))}, - ], - ) - return time.perf_counter() - start, None - except Exception as exc: - return time.perf_counter() - start, str(exc) - - # Launch 10,000 concurrent writes - print("\nLaunching 10,000 concurrent writes...") - start_time = time.time() - - tasks = [write_record(i) for i in range(10000)] - results = await asyncio.gather(*tasks) - - total_time = time.time() - start_time - - # Analyze results - durations = [r[0] for r in results] - errors = [r[1] for r in results if r[1] is not None] - - successful_writes = len(results) - len(errors) - avg_duration = statistics.mean(durations) - p95_duration = statistics.quantiles(durations, n=20)[18] # 95th percentile - p99_duration = statistics.quantiles(durations, n=100)[98] # 99th percentile - - print("\nResults for 10,000 concurrent writes:") - print(f" Total time: {total_time:.2f}s") - print(f" Successful writes: {successful_writes}") - print(f" Failed writes: {len(errors)}") - print(f" Throughput: {successful_writes/total_time:.0f} writes/sec") - print(f" Average latency: {avg_duration*1000:.2f}ms") - print(f" P95 latency: {p95_duration*1000:.2f}ms") - print(f" P99 latency: {p99_duration*1000:.2f}ms") - - # If there are errors, show a sample - if errors: - print("\nSample errors (first 5):") - for i, err in enumerate(errors[:5]): - print(f" {i+1}. {err}") - - # Assertions - assert successful_writes == 10000 # ALL writes MUST succeed - assert len(errors) == 0, f"Write failures detected: {errors[:10]}" - assert total_time < 60 # Should complete within 60 seconds - assert avg_duration < 3.0 # Average latency under 3 seconds - - @pytest.mark.asyncio - @pytest.mark.timeout(60) - async def test_sustained_load(self, stress_session: AsyncCassandraSession): - """ - Test sustained high load over time (30 seconds). - - What this tests: - --------------- - 1. Sustained concurrent operations over 30 seconds - 2. Performance consistency over time - 3. Resource stability (no leaks) - 4. Error rates under sustained load - 5. Read/write balance under load - - Why this matters: - ---------------- - Production systems run continuously. - The driver must maintain performance over time. - """ - insert_stmt = await stress_session.prepare( - """ - INSERT INTO high_volume (partition_key, clustering_key, data, metrics, tags) - VALUES (?, ?, ?, ?, ?) - """ - ) - - select_stmt = await stress_session.prepare( - """ - SELECT * FROM high_volume WHERE partition_key = ? 
- ORDER BY clustering_key DESC LIMIT 10 - """ - ) - - # Track metrics over time - metrics_by_second = defaultdict( - lambda: { - "writes": 0, - "reads": 0, - "errors": 0, - "write_latencies": [], - "read_latencies": [], - } - ) - - # Shared state for operations - written_partitions = [] - write_lock = asyncio.Lock() - - async def continuous_writes(): - """Continuously write data.""" - while time.time() - start_time < 30: - try: - partition_key = uuid.uuid4() - start = time.perf_counter() - - await stress_session.execute( - insert_stmt, - [ - partition_key, - datetime.now(timezone.utc), - "sustained_load_test_" + "x" * 500, - {"metric": random.random()}, - {f"tag{i}" for i in range(5)}, - ], - ) - - duration = time.perf_counter() - start - second = int(time.time() - start_time) - metrics_by_second[second]["writes"] += 1 - metrics_by_second[second]["write_latencies"].append(duration) - - async with write_lock: - written_partitions.append(partition_key) - - except Exception: - second = int(time.time() - start_time) - metrics_by_second[second]["errors"] += 1 - - await asyncio.sleep(0.001) # Small delay to prevent overwhelming - - async def continuous_reads(): - """Continuously read data.""" - await asyncio.sleep(1) # Let some writes happen first - - while time.time() - start_time < 30: - if written_partitions: - try: - async with write_lock: - partition_key = random.choice(written_partitions[-100:]) - - start = time.perf_counter() - await stress_session.execute(select_stmt, [partition_key]) - - duration = time.perf_counter() - start - second = int(time.time() - start_time) - metrics_by_second[second]["reads"] += 1 - metrics_by_second[second]["read_latencies"].append(duration) - - except Exception: - second = int(time.time() - start_time) - metrics_by_second[second]["errors"] += 1 - - await asyncio.sleep(0.002) # Slightly slower than writes - - # Run sustained load test - print("\nRunning 30-second sustained load test...") - start_time = time.time() - - # Create multiple workers for each operation type - write_tasks = [continuous_writes() for _ in range(50)] - read_tasks = [continuous_reads() for _ in range(30)] - - await asyncio.gather(*write_tasks, *read_tasks) - - # Analyze results - print("\nSustained load test results by second:") - print("Second | Writes/s | Reads/s | Errors | Avg Write ms | Avg Read ms") - print("-" * 70) - - total_writes = 0 - total_reads = 0 - total_errors = 0 - - for second in sorted(metrics_by_second.keys()): - metrics = metrics_by_second[second] - avg_write_ms = ( - statistics.mean(metrics["write_latencies"]) * 1000 - if metrics["write_latencies"] - else 0 - ) - avg_read_ms = ( - statistics.mean(metrics["read_latencies"]) * 1000 - if metrics["read_latencies"] - else 0 - ) - - print( - f"{second:6d} | {metrics['writes']:8d} | {metrics['reads']:7d} | " - f"{metrics['errors']:6d} | {avg_write_ms:12.2f} | {avg_read_ms:11.2f}" - ) - - total_writes += metrics["writes"] - total_reads += metrics["reads"] - total_errors += metrics["errors"] - - print(f"\nTotal operations: {total_writes + total_reads}") - print(f"Total errors: {total_errors}") - print(f"Error rate: {total_errors/(total_writes + total_reads)*100:.2f}%") - - # Assertions - assert total_writes > 10000 # Should achieve high write throughput - assert total_reads > 5000 # Should achieve good read throughput - assert total_errors < (total_writes + total_reads) * 0.01 # Less than 1% error rate - - @pytest.mark.asyncio - @pytest.mark.timeout(45) - async def test_wide_row_performance(self, stress_session: 
AsyncCassandraSession): - """ - Test performance with wide rows (many columns per partition). - - What this tests: - --------------- - 1. Creating wide rows with 10,000 columns - 2. Reading entire wide rows - 3. Reading column ranges - 4. Streaming through wide rows - 5. Performance with large result sets - - Why this matters: - ---------------- - Wide rows are common in time-series and IoT data. - The driver must handle them efficiently. - """ - insert_stmt = await stress_session.prepare( - """ - INSERT INTO wide_rows (partition_key, column_id, data) - VALUES (?, ?, ?) - """ - ) - - # Create a few partitions with many columns each - partition_keys = [uuid.uuid4() for _ in range(10)] - columns_per_partition = 10000 - - print(f"\nCreating wide rows with {columns_per_partition} columns per partition...") - - async def create_wide_row(partition_key: uuid.UUID): - """Create a single wide row.""" - # Use batch inserts for efficiency - batch_size = 100 - - for batch_start in range(0, columns_per_partition, batch_size): - batch = BatchStatement(batch_type=BatchType.UNLOGGED) - - for i in range(batch_start, min(batch_start + batch_size, columns_per_partition)): - batch.add( - insert_stmt, - [ - partition_key, - i, - random.randbytes(random.randint(100, 1000)), # Variable size data - ], - ) - - await stress_session.execute(batch) - - # Create wide rows concurrently - start_time = time.time() - await asyncio.gather(*[create_wide_row(pk) for pk in partition_keys]) - create_time = time.time() - start_time - - print(f"Created {len(partition_keys)} wide rows in {create_time:.2f}s") - - # Test reading wide rows - select_all_stmt = await stress_session.prepare( - """ - SELECT * FROM wide_rows WHERE partition_key = ? - """ - ) - - select_range_stmt = await stress_session.prepare( - """ - SELECT * FROM wide_rows WHERE partition_key = ? - AND column_id >= ? AND column_id < ? 
- """ - ) - - # Read entire wide rows - print("\nReading entire wide rows...") - read_times = [] - - for pk in partition_keys: - start = time.perf_counter() - result = await stress_session.execute(select_all_stmt, [pk]) - rows = [] - async for row in result: - rows.append(row) - read_times.append(time.perf_counter() - start) - assert len(rows) == columns_per_partition - - print( - f"Average time to read {columns_per_partition} columns: {statistics.mean(read_times)*1000:.2f}ms" - ) - - # Read ranges from wide rows - print("\nReading column ranges...") - range_times = [] - - for _ in range(100): - pk = random.choice(partition_keys) - start_col = random.randint(0, columns_per_partition - 1000) - end_col = start_col + 1000 - - start = time.perf_counter() - result = await stress_session.execute(select_range_stmt, [pk, start_col, end_col]) - rows = [] - async for row in result: - rows.append(row) - range_times.append(time.perf_counter() - start) - assert 900 <= len(rows) <= 1000 # Approximately 1000 columns - - print(f"Average time to read 1000-column range: {statistics.mean(range_times)*1000:.2f}ms") - - # Stream through wide rows - print("\nStreaming through wide rows...") - stream_config = StreamConfig(fetch_size=1000) - - stream_start = time.time() - total_streamed = 0 - - for pk in partition_keys[:3]: # Stream through 3 partitions - result = await stress_session.execute_stream( - "SELECT * FROM wide_rows WHERE partition_key = %s", - [pk], - stream_config=stream_config, - ) - - async for row in result: - total_streamed += 1 - - stream_time = time.time() - stream_start - print( - f"Streamed {total_streamed} rows in {stream_time:.2f}s " - f"({total_streamed/stream_time:.0f} rows/sec)" - ) - - # Assertions - assert statistics.mean(read_times) < 5.0 # Reading wide row under 5 seconds - assert statistics.mean(range_times) < 0.5 # Range query under 500ms - assert total_streamed == columns_per_partition * 3 # All rows streamed - - @pytest.mark.asyncio - @pytest.mark.timeout(30) - async def test_connection_pool_limits(self, stress_session: AsyncCassandraSession): - """ - Test behavior at connection pool limits. - - What this tests: - --------------- - 1. 1000 concurrent queries exceeding connection pool - 2. Query queueing behavior - 3. No deadlocks or stalls - 4. Graceful handling of pool exhaustion - 5. Performance under pool pressure - - Why this matters: - ---------------- - Connection pools have limits. The driver must - handle more concurrent requests than connections. - """ - # Create a query that takes some time - select_stmt = await stress_session.prepare( - """ - SELECT * FROM high_volume LIMIT 1000 - """ - ) - - # First, insert some data - insert_stmt = await stress_session.prepare( - """ - INSERT INTO high_volume (partition_key, clustering_key, data, metrics, tags) - VALUES (?, ?, ?, ?, ?) 
- """ - ) - - for i in range(100): - await stress_session.execute( - insert_stmt, - [ - uuid.uuid4(), - datetime.now(timezone.utc), - f"test_data_{i}", - {"metric": float(i)}, - {f"tag{i}"}, - ], - ) - - print("\nTesting connection pool under extreme load...") - - # Launch many more concurrent queries than available connections - num_queries = 1000 - - async def timed_query(query_id: int): - """Execute query with timing.""" - start = time.perf_counter() - try: - await stress_session.execute(select_stmt) - return query_id, time.perf_counter() - start, None - except Exception as exc: - return query_id, time.perf_counter() - start, str(exc) - - # Execute all queries concurrently - start_time = time.time() - results = await asyncio.gather(*[timed_query(i) for i in range(num_queries)]) - total_time = time.time() - start_time - - # Analyze queueing behavior - successful = [r for r in results if r[2] is None] - failed = [r for r in results if r[2] is not None] - latencies = [r[1] for r in successful] - - print("\nConnection pool stress test results:") - print(f" Total queries: {num_queries}") - print(f" Successful: {len(successful)}") - print(f" Failed: {len(failed)}") - print(f" Total time: {total_time:.2f}s") - print(f" Throughput: {len(successful)/total_time:.0f} queries/sec") - print(f" Min latency: {min(latencies)*1000:.2f}ms") - print(f" Avg latency: {statistics.mean(latencies)*1000:.2f}ms") - print(f" Max latency: {max(latencies)*1000:.2f}ms") - print(f" P95 latency: {statistics.quantiles(latencies, n=20)[18]*1000:.2f}ms") - - # Despite connection limits, should handle high concurrency well - assert len(successful) >= num_queries * 0.95 # 95% success rate - assert statistics.mean(latencies) < 2.0 # Average under 2 seconds - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestConcurrentPatterns: - """Test specific concurrent patterns and edge cases.""" - - async def test_concurrent_streaming_sessions(self, cassandra_session, shared_keyspace_setup): - """ - Test multiple sessions streaming concurrently. - - What this tests: - --------------- - 1. Multiple streaming operations in parallel - 2. Resource isolation between streams - 3. Memory management with concurrent streams - 4. No interference between streams - - Why this matters: - ---------------- - Streaming is resource-intensive. Multiple concurrent - streams must not interfere with each other. 
- """ - # Create test table with data - table_name = f"streaming_test_{uuid.uuid4().hex[:8]}" - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - partition_key INT, - clustering_key INT, - data TEXT, - PRIMARY KEY (partition_key, clustering_key) - ) - """ - ) - - # Insert data for streaming - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (partition_key, clustering_key, data) VALUES (?, ?, ?)" - ) - - for partition in range(5): - for cluster in range(1000): - await cassandra_session.execute( - insert_stmt, (partition, cluster, f"data_{partition}_{cluster}") - ) - - # Define streaming function - async def stream_partition(partition_id): - """Stream all data from a partition.""" - count = 0 - stream_config = StreamConfig(fetch_size=100) - - async with await cassandra_session.execute_stream( - f"SELECT * FROM {table_name} WHERE partition_key = %s", - [partition_id], - stream_config=stream_config, - ) as stream: - async for row in stream: - count += 1 - assert row.partition_key == partition_id - - return partition_id, count - - # Run multiple streams concurrently - print("\nRunning 5 concurrent streaming operations...") - start_time = time.time() - - results = await asyncio.gather(*[stream_partition(i) for i in range(5)]) - - total_time = time.time() - start_time - - # Verify results - for partition_id, count in results: - assert count == 1000, f"Partition {partition_id} had {count} rows, expected 1000" - - print(f"Streamed 5000 total rows across 5 streams in {total_time:.2f}s") - assert total_time < 10.0 # Should complete reasonably fast - - async def test_concurrent_empty_results(self, cassandra_session, shared_keyspace_setup): - """ - Test concurrent queries returning empty results. - - What this tests: - --------------- - 1. 20 concurrent queries with no results - 2. Proper handling of empty result sets - 3. No resource leaks with empty results - 4. Consistent behavior - - Why this matters: - ---------------- - Empty results are common in production. - They must be handled efficiently. - """ - # Create test table - table_name = f"empty_results_{uuid.uuid4().hex[:8]}" - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - # Don't insert any data - all queries will return empty - - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - - async def query_empty(i): - """Query for non-existent data.""" - result = await cassandra_session.execute(select_stmt, (uuid.uuid4(),)) - rows = list(result) - return len(rows) - - # Run concurrent empty queries - tasks = [query_empty(i) for i in range(20)] - results = await asyncio.gather(*tasks) - - # All should return 0 rows - assert all(count == 0 for count in results) - print("\nAll 20 concurrent empty queries completed successfully") - - async def test_concurrent_failures_recovery(self, cassandra_session, shared_keyspace_setup): - """ - Test concurrent queries with simulated failures and recovery. - - What this tests: - --------------- - 1. Concurrent operations with random failures - 2. Retry mechanism under concurrent load - 3. Recovery from transient errors - 4. No cascading failures - - Why this matters: - ---------------- - Network issues and transient failures happen. - The driver must handle them gracefully. 
- """ - # Create test table - table_name = f"failure_test_{uuid.uuid4().hex[:8]}" - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - attempt INT, - data TEXT - ) - """ - ) - - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, attempt, data) VALUES (?, ?, ?)" - ) - - # Track attempts per operation - attempt_counts = {} - - async def operation_with_retry(op_id): - """Perform operation with retry on failure.""" - max_retries = 3 - for attempt in range(max_retries): - try: - # Simulate random failures (20% chance) - if random.random() < 0.2 and attempt < max_retries - 1: - raise Exception("Simulated transient failure") - - # Perform the operation - await cassandra_session.execute( - insert_stmt, (uuid.uuid4(), attempt + 1, f"operation_{op_id}") - ) - - attempt_counts[op_id] = attempt + 1 - return True - - except Exception: - if attempt == max_retries - 1: - # Final attempt failed - attempt_counts[op_id] = max_retries - return False - # Retry after brief delay - await asyncio.sleep(0.1 * (attempt + 1)) - - # Run operations concurrently - print("\nRunning 50 concurrent operations with simulated failures...") - tasks = [operation_with_retry(i) for i in range(50)] - results = await asyncio.gather(*tasks) - - successful = sum(1 for r in results if r is True) - failed = sum(1 for r in results if r is False) - - # Analyze retry patterns - retry_histogram = {} - for attempts in attempt_counts.values(): - retry_histogram[attempts] = retry_histogram.get(attempts, 0) + 1 - - print("\nResults:") - print(f" Successful: {successful}/50") - print(f" Failed: {failed}/50") - print(f" Retry distribution: {retry_histogram}") - - # Most operations should succeed (possibly with retries) - assert successful >= 45 # At least 90% success rate - - async def test_async_vs_sync_performance(self, cassandra_session, shared_keyspace_setup): - """ - Test async wrapper performance vs sync driver for concurrent operations. - - What this tests: - --------------- - 1. Performance comparison between async and sync drivers - 2. 50 concurrent operations for both approaches - 3. Thread pool vs event loop efficiency - 4. Overhead of async wrapper - - Why this matters: - ---------------- - Users need to know the async wrapper provides - performance benefits for concurrent operations. 
- """ - # Create sync cluster and session for comparison - sync_cluster = SyncCluster(["localhost"]) - sync_session = sync_cluster.connect() - sync_session.execute( - f"USE {cassandra_session.keyspace}" - ) # Use same keyspace as async session - - # Create test table - table_name = f"perf_comparison_{uuid.uuid4().hex[:8]}" - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - value TEXT - ) - """ - ) - - # Number of concurrent operations - num_ops = 50 - - # Prepare statements - sync_insert = sync_session.prepare(f"INSERT INTO {table_name} (id, value) VALUES (?, ?)") - async_insert = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, value) VALUES (?, ?)" - ) - - # Sync approach with thread pool - print("\nTesting sync driver with thread pool...") - start_sync = time.time() - with ThreadPoolExecutor(max_workers=10) as executor: - futures = [] - for i in range(num_ops): - future = executor.submit(sync_session.execute, sync_insert, (i, f"sync_{i}")) - futures.append(future) - - # Wait for all - for future in futures: - future.result() - sync_time = time.time() - start_sync - - # Async approach - print("Testing async wrapper...") - start_async = time.time() - tasks = [] - for i in range(num_ops): - task = cassandra_session.execute(async_insert, (i + 1000, f"async_{i}")) - tasks.append(task) - - await asyncio.gather(*tasks) - async_time = time.time() - start_async - - # Results - print(f"\nPerformance comparison for {num_ops} concurrent operations:") - print(f" Sync with thread pool: {sync_time:.3f}s") - print(f" Async wrapper: {async_time:.3f}s") - print(f" Speedup: {sync_time/async_time:.2f}x") - - # Verify all data was inserted - result = await cassandra_session.execute(f"SELECT COUNT(*) FROM {table_name}") - total_count = result.one()[0] - assert total_count == num_ops * 2 # Both sync and async inserts - - # Cleanup - sync_session.shutdown() - sync_cluster.shutdown() diff --git a/tests/integration/test_consistency_and_prepared_statements.py b/tests/integration/test_consistency_and_prepared_statements.py deleted file mode 100644 index 97e4b46..0000000 --- a/tests/integration/test_consistency_and_prepared_statements.py +++ /dev/null @@ -1,927 +0,0 @@ -""" -Consolidated integration tests for consistency levels and prepared statements. - -This module combines all consistency level and prepared statement tests, -providing comprehensive coverage of statement preparation and execution patterns. - -Tests consolidated from: -- test_driver_compatibility.py - Consistency and prepared statement compatibility -- test_simple_statements.py - SimpleStatement consistency levels -- test_select_operations.py - SELECT with different consistency levels -- test_concurrent_operations.py - Concurrent operations with consistency -- Various prepared statement usage from other test files - -Test Organization: -================== -1. Prepared Statement Basics - Creation, binding, execution -2. Consistency Level Configuration - Per-statement and per-query -3. Combined Patterns - Prepared statements with consistency levels -4. Concurrent Usage - Thread safety and performance -5. 
Error Handling - Invalid statements, binding errors -""" - -import asyncio -import time -import uuid -from datetime import datetime, timezone -from decimal import Decimal - -import pytest -from cassandra import ConsistencyLevel -from cassandra.query import BatchStatement, BatchType, SimpleStatement -from test_utils import generate_unique_table - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestPreparedStatements: - """Test prepared statement functionality with real Cassandra.""" - - # ======================================== - # Basic Prepared Statement Operations - # ======================================== - - async def test_prepared_statement_basics(self, cassandra_session, shared_keyspace_setup): - """ - Test basic prepared statement operations. - - What this tests: - --------------- - 1. Statement preparation with ? placeholders - 2. Binding parameters - 3. Reusing prepared statements - 4. Type safety with prepared statements - - Why this matters: - ---------------- - Prepared statements provide better performance through - query plan caching and protection against injection. - """ - # Create test table - table_name = generate_unique_table("test_prepared_basics") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - name TEXT, - age INT, - created_at TIMESTAMP - ) - """ - ) - - # Prepare INSERT statement - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, name, age, created_at) VALUES (?, ?, ?, ?)" - ) - - # Prepare SELECT statements - select_by_id = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - - select_all = await cassandra_session.prepare(f"SELECT * FROM {table_name}") - - # Execute prepared statements multiple times - users = [] - for i in range(5): - user_id = uuid.uuid4() - users.append(user_id) - await cassandra_session.execute( - insert_stmt, (user_id, f"User {i}", 20 + i, datetime.now(timezone.utc)) - ) - - # Verify inserts using prepared select - for i, user_id in enumerate(users): - result = await cassandra_session.execute(select_by_id, (user_id,)) - row = result.one() - assert row.name == f"User {i}" - assert row.age == 20 + i - - # Select all and verify count - result = await cassandra_session.execute(select_all) - rows = list(result) - assert len(rows) == 5 - - async def test_prepared_statement_with_different_types( - self, cassandra_session, shared_keyspace_setup - ): - """ - Test prepared statements with various data types. - - What this tests: - --------------- - 1. Type conversion and validation - 2. NULL handling - 3. Collection types in prepared statements - 4. Special types (UUID, decimal, etc.) - - Why this matters: - ---------------- - Prepared statements must correctly handle all - Cassandra data types with proper serialization. - """ - # Create table with various types - table_name = generate_unique_table("test_prepared_types") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - text_val TEXT, - int_val INT, - decimal_val DECIMAL, - list_val LIST, - map_val MAP, - set_val SET, - bool_val BOOLEAN - ) - """ - ) - - # Prepare statement with all types - insert_stmt = await cassandra_session.prepare( - f""" - INSERT INTO {table_name} - (id, text_val, int_val, decimal_val, list_val, map_val, set_val, bool_val) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
- """ - ) - - # Test with various values including NULL - test_cases = [ - # All values present - ( - uuid.uuid4(), - "test text", - 42, - Decimal("123.456"), - ["a", "b", "c"], - {"key1": 1, "key2": 2}, - {1, 2, 3}, - True, - ), - # Some NULL values - ( - uuid.uuid4(), - None, # NULL text - 100, - None, # NULL decimal - [], # Empty list - {}, # Empty map - set(), # Empty set - False, - ), - ] - - for values in test_cases: - await cassandra_session.execute(insert_stmt, values) - - # Verify data - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - - for i, test_case in enumerate(test_cases): - result = await cassandra_session.execute(select_stmt, (test_case[0],)) - row = result.one() - - if i == 0: # First test case with all values - assert row.text_val == test_case[1] - assert row.int_val == test_case[2] - assert row.decimal_val == test_case[3] - assert row.list_val == test_case[4] - assert row.map_val == test_case[5] - assert row.set_val == test_case[6] - assert row.bool_val == test_case[7] - else: # Second test case with NULLs - assert row.text_val is None - assert row.int_val == 100 - assert row.decimal_val is None - # Empty collections may be stored as NULL in Cassandra - assert row.list_val is None or row.list_val == [] - assert row.map_val is None or row.map_val == {} - assert row.set_val is None or row.set_val == set() - - async def test_prepared_statement_reuse_performance( - self, cassandra_session, shared_keyspace_setup - ): - """ - Test performance benefits of prepared statement reuse. - - What this tests: - --------------- - 1. Performance improvement with reuse - 2. Statement cache behavior - 3. Concurrent reuse safety - - Why this matters: - ---------------- - Prepared statements should be prepared once and - reused many times for optimal performance. - """ - # Create test table - table_name = generate_unique_table("test_prepared_perf") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - # Measure time with prepared statement reuse - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, data) VALUES (?, ?)" - ) - - start_prepared = time.time() - for i in range(100): - await cassandra_session.execute(insert_stmt, (uuid.uuid4(), f"prepared_data_{i}")) - prepared_duration = time.time() - start_prepared - - # Measure time with SimpleStatement (no preparation) - start_simple = time.time() - for i in range(100): - await cassandra_session.execute( - f"INSERT INTO {table_name} (id, data) VALUES (%s, %s)", - (uuid.uuid4(), f"simple_data_{i}"), - ) - simple_duration = time.time() - start_simple - - # Prepared statements should generally be faster or similar - # (The difference might be small for simple queries) - print(f"Prepared: {prepared_duration:.3f}s, Simple: {simple_duration:.3f}s") - - # Verify both methods inserted data - result = await cassandra_session.execute(f"SELECT COUNT(*) FROM {table_name}") - count = result.one()[0] - assert count == 200 - - # ======================================== - # Consistency Level Tests - # ======================================== - - async def test_consistency_levels_with_prepared_statements( - self, cassandra_session, shared_keyspace_setup - ): - """ - Test different consistency levels with prepared statements. - - What this tests: - --------------- - 1. Setting consistency on prepared statements - 2. Different consistency levels (ONE, QUORUM, ALL) - 3. 
Read/write consistency combinations - 4. Consistency level errors - - Why this matters: - ---------------- - Consistency levels control the trade-off between - consistency, availability, and performance. - """ - # Create test table - table_name = generate_unique_table("test_consistency") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - data TEXT, - version INT - ) - """ - ) - - # Prepare statements - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, data, version) VALUES (?, ?, ?)" - ) - - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - - test_id = uuid.uuid4() - - # Test different write consistency levels - consistency_levels = [ - ConsistencyLevel.ONE, - ConsistencyLevel.QUORUM, - ConsistencyLevel.ALL, - ] - - for i, cl in enumerate(consistency_levels): - # Set consistency level on the statement - insert_stmt.consistency_level = cl - - try: - await cassandra_session.execute(insert_stmt, (test_id, f"consistency_{cl}", i)) - print(f"Write with {cl} succeeded") - except Exception as e: - # ALL might fail in single-node setup - if cl == ConsistencyLevel.ALL: - print(f"Write with ALL failed as expected: {e}") - else: - raise - - # Test different read consistency levels - for cl in [ConsistencyLevel.ONE, ConsistencyLevel.QUORUM]: - select_stmt.consistency_level = cl - - result = await cassandra_session.execute(select_stmt, (test_id,)) - row = result.one() - if row: - print(f"Read with {cl} returned version {row.version}") - - async def test_consistency_levels_with_simple_statements( - self, cassandra_session, shared_keyspace_setup - ): - """ - Test consistency levels with SimpleStatement. - - What this tests: - --------------- - 1. SimpleStatement with consistency configuration - 2. Per-query consistency settings - 3. Comparison with prepared statements - - Why this matters: - ---------------- - SimpleStatements allow per-query consistency - configuration without statement preparation. - """ - # Create test table - table_name = generate_unique_table("test_simple_consistency") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id TEXT PRIMARY KEY, - value INT - ) - """ - ) - - # Test with different consistency levels - test_data = [ - ("one_consistency", ConsistencyLevel.ONE), - ("local_one", ConsistencyLevel.LOCAL_ONE), - ("local_quorum", ConsistencyLevel.LOCAL_QUORUM), - ] - - for key, consistency in test_data: - # Create SimpleStatement with specific consistency - insert = SimpleStatement( - f"INSERT INTO {table_name} (id, value) VALUES (%s, %s)", - consistency_level=consistency, - ) - - await cassandra_session.execute(insert, (key, 100)) - - # Read back with same consistency - select = SimpleStatement( - f"SELECT * FROM {table_name} WHERE id = %s", consistency_level=consistency - ) - - result = await cassandra_session.execute(select, (key,)) - row = result.one() - assert row.value == 100 - - # ======================================== - # Combined Patterns - # ======================================== - - async def test_prepared_statements_in_batch_with_consistency( - self, cassandra_session, shared_keyspace_setup - ): - """ - Test prepared statements in batches with consistency levels. - - What this tests: - --------------- - 1. Prepared statements in batch operations - 2. Batch consistency levels - 3. Mixed statement types in batch - 4. 
Batch atomicity with consistency - - Why this matters: - ---------------- - Batches often combine multiple prepared statements - and need specific consistency guarantees. - """ - # Create test table - table_name = generate_unique_table("test_batch_prepared") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - partition_key TEXT, - clustering_key INT, - data TEXT, - PRIMARY KEY (partition_key, clustering_key) - ) - """ - ) - - # Prepare statements - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (partition_key, clustering_key, data) VALUES (?, ?, ?)" - ) - - update_stmt = await cassandra_session.prepare( - f"UPDATE {table_name} SET data = ? WHERE partition_key = ? AND clustering_key = ?" - ) - - # Create batch with specific consistency - batch = BatchStatement( - batch_type=BatchType.LOGGED, consistency_level=ConsistencyLevel.QUORUM - ) - - partition = "batch_test" - - # Add multiple prepared statements to batch - for i in range(5): - batch.add(insert_stmt, (partition, i, f"initial_{i}")) - - # Add updates - for i in range(3): - batch.add(update_stmt, (f"updated_{i}", partition, i)) - - # Execute batch - await cassandra_session.execute(batch) - - # Verify with read at QUORUM - select_stmt = await cassandra_session.prepare( - f"SELECT * FROM {table_name} WHERE partition_key = ?" - ) - select_stmt.consistency_level = ConsistencyLevel.QUORUM - - result = await cassandra_session.execute(select_stmt, (partition,)) - rows = list(result) - assert len(rows) == 5 - - # Check updates were applied - for row in rows: - if row.clustering_key < 3: - assert row.data == f"updated_{row.clustering_key}" - else: - assert row.data == f"initial_{row.clustering_key}" - - # ======================================== - # Concurrent Usage Patterns - # ======================================== - - async def test_concurrent_prepared_statement_usage( - self, cassandra_session, shared_keyspace_setup - ): - """ - Test concurrent usage of prepared statements. - - What this tests: - --------------- - 1. Thread safety of prepared statements - 2. Concurrent execution performance - 3. No interference between concurrent executions - 4. Connection pool behavior - - Why this matters: - ---------------- - Prepared statements must be safe for concurrent - use from multiple async tasks. - """ - # Create test table - table_name = generate_unique_table("test_concurrent_prepared") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - thread_id INT, - value TEXT, - created_at TIMESTAMP - ) - """ - ) - - # Prepare statements once - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, thread_id, value, created_at) VALUES (?, ?, ?, ?)" - ) - - select_stmt = await cassandra_session.prepare( - f"SELECT COUNT(*) FROM {table_name} WHERE thread_id = ? 
ALLOW FILTERING" - ) - - # Concurrent insert function - async def insert_records(thread_id, count): - for i in range(count): - await cassandra_session.execute( - insert_stmt, - ( - uuid.uuid4(), - thread_id, - f"thread_{thread_id}_record_{i}", - datetime.now(timezone.utc), - ), - ) - return thread_id - - # Run many concurrent tasks - tasks = [] - num_threads = 10 - records_per_thread = 20 - - for i in range(num_threads): - task = asyncio.create_task(insert_records(i, records_per_thread)) - tasks.append(task) - - # Wait for all to complete - results = await asyncio.gather(*tasks) - assert len(results) == num_threads - - # Verify each thread inserted correct number - for thread_id in range(num_threads): - result = await cassandra_session.execute(select_stmt, (thread_id,)) - count = result.one()[0] - assert count == records_per_thread - - # Verify total - total_result = await cassandra_session.execute(f"SELECT COUNT(*) FROM {table_name}") - total = total_result.one()[0] - assert total == num_threads * records_per_thread - - async def test_prepared_statement_with_consistency_race_conditions( - self, cassandra_session, shared_keyspace_setup - ): - """ - Test race conditions with different consistency levels. - - What this tests: - --------------- - 1. Write with ONE, read with ALL pattern - 2. Consistency level impact on visibility - 3. Eventual consistency behavior - 4. Race condition handling - - Why this matters: - ---------------- - Understanding consistency level interactions is - crucial for distributed system correctness. - """ - # Create test table - table_name = generate_unique_table("test_consistency_race") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id TEXT PRIMARY KEY, - counter INT, - last_updated TIMESTAMP - ) - """ - ) - - # Prepare statements with different consistency - insert_one = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, counter, last_updated) VALUES (?, ?, ?)" - ) - insert_one.consistency_level = ConsistencyLevel.ONE - - select_all = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - # Don't set ALL here as it might fail in single-node - select_all.consistency_level = ConsistencyLevel.QUORUM - - update_quorum = await cassandra_session.prepare( - f"UPDATE {table_name} SET counter = ?, last_updated = ? WHERE id = ?" 
- ) - update_quorum.consistency_level = ConsistencyLevel.QUORUM - - # Test concurrent updates with different consistency - test_id = "consistency_test" - - # Initial insert with ONE - await cassandra_session.execute(insert_one, (test_id, 0, datetime.now(timezone.utc))) - - # Concurrent updates - async def update_counter(increment): - # Read current value - result = await cassandra_session.execute(select_all, (test_id,)) - current = result.one() - - if current: - new_value = current.counter + increment - # Update with QUORUM - await cassandra_session.execute( - update_quorum, (new_value, datetime.now(timezone.utc), test_id) - ) - return new_value - return None - - # Run concurrent updates - tasks = [update_counter(1) for _ in range(5)] - await asyncio.gather(*tasks, return_exceptions=True) - - # Final read - final_result = await cassandra_session.execute(select_all, (test_id,)) - final_row = final_result.one() - - # Due to race conditions, final counter might not be 5 - # but should be between 1 and 5 - assert 1 <= final_row.counter <= 5 - print(f"Final counter value: {final_row.counter} (race conditions expected)") - - # ======================================== - # Error Handling - # ======================================== - - async def test_prepared_statement_error_handling( - self, cassandra_session, shared_keyspace_setup - ): - """ - Test error handling with prepared statements. - - What this tests: - --------------- - 1. Invalid query preparation - 2. Wrong parameter count - 3. Type mismatch errors - 4. Non-existent table/column errors - - Why this matters: - ---------------- - Proper error handling ensures robust applications - and clear debugging information. - """ - # Test preparing invalid query - from cassandra.protocol import SyntaxException - - with pytest.raises(SyntaxException): - await cassandra_session.prepare("INVALID SQL QUERY") - - # Create test table - table_name = generate_unique_table("test_prepared_errors") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - value INT - ) - """ - ) - - # Prepare valid statement - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, value) VALUES (?, ?)" - ) - - # Test wrong parameter count - Cassandra driver behavior varies - # Some versions auto-fill missing parameters with None - try: - await cassandra_session.execute(insert_stmt, (uuid.uuid4(),)) # Missing value - # If no exception, verify it inserted NULL for missing value - print("Note: Driver accepted missing parameter (filled with NULL)") - except Exception as e: - print(f"Driver raised exception for missing parameter: {type(e).__name__}") - - # Test too many parameters - this should always fail - with pytest.raises(Exception): - await cassandra_session.execute( - insert_stmt, (uuid.uuid4(), 100, "extra", "more") # Way too many parameters - ) - - # Test type mismatch - string for UUID should fail - try: - await cassandra_session.execute( - insert_stmt, ("not-a-uuid", 100) # String instead of UUID - ) - pytest.fail("Expected exception for invalid UUID string") - except Exception: - pass # Expected - - # Test non-existent column - from cassandra import InvalidRequest - - with pytest.raises(InvalidRequest): - await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, nonexistent) VALUES (?, ?)" - ) - - async def test_statement_id_and_metadata(self, cassandra_session, shared_keyspace_setup): - """ - Test prepared statement metadata and IDs. - - What this tests: - --------------- - 1. 
Statement preparation returns metadata - 2. Prepared statement IDs are stable - 3. Re-preparing returns same statement - 4. Metadata contains column information - - Why this matters: - ---------------- - Understanding statement metadata helps with - debugging and advanced driver usage. - """ - # Create test table - table_name = generate_unique_table("test_stmt_metadata") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - name TEXT, - age INT, - active BOOLEAN - ) - """ - ) - - # Prepare statement - query = f"INSERT INTO {table_name} (id, name, age, active) VALUES (?, ?, ?, ?)" - stmt1 = await cassandra_session.prepare(query) - - # Re-prepare same query - await cassandra_session.prepare(query) - - # Both should be the same prepared statement - # (Cassandra caches prepared statements) - - # Test statement has required attributes - assert hasattr(stmt1, "bind") - assert hasattr(stmt1, "consistency_level") - - # Can bind values - bound = stmt1.bind((uuid.uuid4(), "Test", 25, True)) - await cassandra_session.execute(bound) - - # Verify insert worked - result = await cassandra_session.execute(f"SELECT COUNT(*) FROM {table_name}") - assert result.one()[0] == 1 - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestConsistencyPatterns: - """Test advanced consistency patterns and scenarios.""" - - async def test_read_your_writes_pattern(self, cassandra_session, shared_keyspace_setup): - """ - Test read-your-writes consistency pattern. - - What this tests: - --------------- - 1. Write at QUORUM, read at QUORUM - 2. Immediate read visibility - 3. Consistency across nodes - 4. No stale reads - - Why this matters: - ---------------- - Read-your-writes is a common consistency requirement - where users expect to see their own changes immediately. - """ - # Create test table - table_name = generate_unique_table("test_read_your_writes") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - user_id UUID PRIMARY KEY, - username TEXT, - email TEXT, - updated_at TIMESTAMP - ) - """ - ) - - # Prepare statements with QUORUM consistency - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (user_id, username, email, updated_at) VALUES (?, ?, ?, ?)" - ) - insert_stmt.consistency_level = ConsistencyLevel.QUORUM - - select_stmt = await cassandra_session.prepare( - f"SELECT * FROM {table_name} WHERE user_id = ?" - ) - select_stmt.consistency_level = ConsistencyLevel.QUORUM - - # Test immediate read after write - user_id = uuid.uuid4() - username = "testuser" - email = "test@example.com" - - # Write - await cassandra_session.execute( - insert_stmt, (user_id, username, email, datetime.now(timezone.utc)) - ) - - # Immediate read should see the write - result = await cassandra_session.execute(select_stmt, (user_id,)) - row = result.one() - assert row is not None - assert row.username == username - assert row.email == email - - async def test_eventual_consistency_demonstration( - self, cassandra_session, shared_keyspace_setup - ): - """ - Test and demonstrate eventual consistency behavior. - - What this tests: - --------------- - 1. Write at ONE, read at ONE behavior - 2. Potential inconsistency windows - 3. Eventually consistent reads - 4. Consistency level trade-offs - - Why this matters: - ---------------- - Understanding eventual consistency helps design - systems that handle temporary inconsistencies. 
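A compact sketch of the per-statement consistency pattern used throughout these tests, assuming a connected async `session` and a hypothetical `profiles` table keyed by `id`:

    from cassandra import ConsistencyLevel

    async def write_then_read(session, user_id, payload: str):
        insert_stmt = await session.prepare(
            "INSERT INTO profiles (id, data) VALUES (?, ?)"
        )
        select_stmt = await session.prepare("SELECT * FROM profiles WHERE id = ?")

        # Consistency is configured on the statement object and applies to
        # every subsequent execution of that statement.
        insert_stmt.consistency_level = ConsistencyLevel.QUORUM
        select_stmt.consistency_level = ConsistencyLevel.QUORUM

        # Writing and reading at QUORUM gives read-your-writes behaviour
        # when the replication factor and cluster size allow it.
        await session.execute(insert_stmt, (user_id, payload))
        return (await session.execute(select_stmt, (user_id,))).one()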
- """ - # Create test table - table_name = generate_unique_table("test_eventual") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id TEXT PRIMARY KEY, - value INT, - timestamp TIMESTAMP - ) - """ - ) - - # Prepare statements with ONE consistency (fastest, least consistent) - write_one = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, value, timestamp) VALUES (?, ?, ?)" - ) - write_one.consistency_level = ConsistencyLevel.ONE - - read_one = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - read_one.consistency_level = ConsistencyLevel.ONE - - read_all = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - # Use QUORUM instead of ALL for single-node compatibility - read_all.consistency_level = ConsistencyLevel.QUORUM - - test_id = "eventual_test" - - # Rapid writes with ONE - for i in range(10): - await cassandra_session.execute(write_one, (test_id, i, datetime.now(timezone.utc))) - - # Read with different consistency levels - result_one = await cassandra_session.execute(read_one, (test_id,)) - result_all = await cassandra_session.execute(read_all, (test_id,)) - - # Both should eventually see the same value - # In a single-node setup, they'll be consistent - row_one = result_one.one() - row_all = result_all.one() - - assert row_one.value == row_all.value == 9 - print(f"ONE read: {row_one.value}, QUORUM read: {row_all.value}") - - async def test_multi_datacenter_consistency_levels( - self, cassandra_session, shared_keyspace_setup - ): - """ - Test LOCAL consistency levels for multi-DC scenarios. - - What this tests: - --------------- - 1. LOCAL_ONE vs ONE - 2. LOCAL_QUORUM vs QUORUM - 3. Multi-DC consistency patterns - 4. DC-aware consistency - - Why this matters: - ---------------- - Multi-datacenter deployments require careful - consistency level selection for performance. - """ - # Create test table - table_name = generate_unique_table("test_local_consistency") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - dc_name TEXT, - data TEXT - ) - """ - ) - - # Test LOCAL consistency levels (work in single-DC too) - local_consistency_levels = [ - (ConsistencyLevel.LOCAL_ONE, "LOCAL_ONE"), - (ConsistencyLevel.LOCAL_QUORUM, "LOCAL_QUORUM"), - ] - - for cl, cl_name in local_consistency_levels: - stmt = SimpleStatement( - f"INSERT INTO {table_name} (id, dc_name, data) VALUES (%s, %s, %s)", - consistency_level=cl, - ) - - try: - await cassandra_session.execute( - stmt, (uuid.uuid4(), cl_name, f"Written with {cl_name}") - ) - print(f"Write with {cl_name} succeeded") - except Exception as e: - print(f"Write with {cl_name} failed: {e}") - - # Verify writes - result = await cassandra_session.execute(f"SELECT * FROM {table_name}") - rows = list(result) - print(f"Successfully wrote {len(rows)} rows with LOCAL consistency levels") diff --git a/tests/integration/test_context_manager_safety_integration.py b/tests/integration/test_context_manager_safety_integration.py deleted file mode 100644 index 19df52d..0000000 --- a/tests/integration/test_context_manager_safety_integration.py +++ /dev/null @@ -1,423 +0,0 @@ -""" -Integration tests for context manager safety with real Cassandra. - -These tests ensure that context managers behave correctly with actual -Cassandra connections and don't close shared resources inappropriately. 
-""" - -import asyncio -import uuid - -import pytest -from cassandra import InvalidRequest - -from async_cassandra import AsyncCluster -from async_cassandra.streaming import StreamConfig - - -@pytest.mark.integration -class TestContextManagerSafetyIntegration: - """Test context manager safety with real Cassandra connections.""" - - @pytest.mark.asyncio - async def test_session_remains_open_after_query_error(self, cassandra_session): - """ - Test that session remains usable after a query error occurs. - - What this tests: - --------------- - 1. Query errors don't close session - 2. Session still usable - 3. New queries work - 4. Insert/select functional - - Why this matters: - ---------------- - Error recovery critical: - - Apps have query errors - - Must continue operating - - No resource leaks - - Sessions must survive - individual query failures. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - # Try a bad query - with pytest.raises(InvalidRequest): - await cassandra_session.execute( - "SELECT * FROM table_that_definitely_does_not_exist_xyz123" - ) - - # Session should still be usable - user_id = uuid.uuid4() - insert_prepared = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name) VALUES (?, ?)" - ) - await cassandra_session.execute(insert_prepared, [user_id, "Test User"]) - - # Verify insert worked - select_prepared = await cassandra_session.prepare( - f"SELECT * FROM {users_table} WHERE id = ?" - ) - result = await cassandra_session.execute(select_prepared, [user_id]) - row = result.one() - assert row.name == "Test User" - - @pytest.mark.asyncio - async def test_streaming_error_doesnt_close_session(self, cassandra_session): - """ - Test that an error during streaming doesn't close the session. - - What this tests: - --------------- - 1. Stream errors handled - 2. Session stays open - 3. New streams work - 4. Regular queries work - - Why this matters: - ---------------- - Streaming failures common: - - Large result sets - - Network interruptions - - Query timeouts - - Session must survive - streaming failures. - """ - # Create test table - await cassandra_session.execute( - """ - CREATE TABLE IF NOT EXISTS test_stream_data ( - id UUID PRIMARY KEY, - value INT - ) - """ - ) - - # Insert some data - insert_prepared = await cassandra_session.prepare( - "INSERT INTO test_stream_data (id, value) VALUES (?, ?)" - ) - for i in range(10): - await cassandra_session.execute(insert_prepared, [uuid.uuid4(), i]) - - # Stream with an error (simulate by using bad query) - try: - async with await cassandra_session.execute_stream( - "SELECT * FROM non_existent_table" - ) as stream: - async for row in stream: - pass - except Exception: - pass # Expected - - # Session should still work - result = await cassandra_session.execute("SELECT COUNT(*) FROM test_stream_data") - assert result.one()[0] == 10 - - # Try another streaming query - should work - count = 0 - async with await cassandra_session.execute_stream( - "SELECT * FROM test_stream_data" - ) as stream: - async for row in stream: - count += 1 - assert count == 10 - - @pytest.mark.asyncio - async def test_concurrent_streaming_sessions(self, cassandra_session, cassandra_cluster): - """ - Test that multiple sessions can stream concurrently without interference. - - What this tests: - --------------- - 1. Multiple sessions work - 2. Concurrent streaming OK - 3. No interference - 4. 
Independent results - - Why this matters: - ---------------- - Multi-session patterns: - - Worker processes - - Parallel processing - - Load distribution - - Sessions must be truly - independent. - """ - # Create test table - await cassandra_session.execute( - """ - CREATE TABLE IF NOT EXISTS test_concurrent_data ( - partition INT, - id UUID, - value TEXT, - PRIMARY KEY (partition, id) - ) - """ - ) - - # Insert data in different partitions - insert_prepared = await cassandra_session.prepare( - "INSERT INTO test_concurrent_data (partition, id, value) VALUES (?, ?, ?)" - ) - for partition in range(3): - for i in range(100): - await cassandra_session.execute( - insert_prepared, - [partition, uuid.uuid4(), f"value_{partition}_{i}"], - ) - - # Stream from multiple sessions concurrently - async def stream_partition(partition_id): - # Create new session and connect to the shared keyspace - session = await cassandra_cluster.connect() - await session.set_keyspace("integration_test") - try: - count = 0 - config = StreamConfig(fetch_size=10) - - query_prepared = await session.prepare( - "SELECT * FROM test_concurrent_data WHERE partition = ?" - ) - async with await session.execute_stream( - query_prepared, [partition_id], stream_config=config - ) as stream: - async for row in stream: - assert row.value.startswith(f"value_{partition_id}_") - count += 1 - - return count - finally: - await session.close() - - # Run streams concurrently - results = await asyncio.gather( - stream_partition(0), stream_partition(1), stream_partition(2) - ) - - # Each partition should have 100 rows - assert all(count == 100 for count in results) - - @pytest.mark.asyncio - async def test_session_context_manager_with_streaming(self, cassandra_cluster): - """ - Test using session context manager with streaming operations. - - What this tests: - --------------- - 1. Session context managers - 2. Streaming within context - 3. Error cleanup works - 4. Resources freed - - Why this matters: - ---------------- - Context managers ensure: - - Proper cleanup - - Exception safety - - Resource management - - Critical for production - reliability. - """ - try: - # Use session in context manager - async with await cassandra_cluster.connect() as session: - await session.set_keyspace("integration_test") - await session.execute( - """ - CREATE TABLE IF NOT EXISTS test_session_ctx_data ( - id UUID PRIMARY KEY, - value TEXT - ) - """ - ) - - # Insert data - insert_prepared = await session.prepare( - "INSERT INTO test_session_ctx_data (id, value) VALUES (?, ?)" - ) - for i in range(50): - await session.execute( - insert_prepared, - [uuid.uuid4(), f"value_{i}"], - ) - - # Stream data - count = 0 - async with await session.execute_stream( - "SELECT * FROM test_session_ctx_data" - ) as stream: - async for row in stream: - count += 1 - - assert count == 50 - - # Raise an error to test cleanup - if True: # Always true, but makes intent clear - raise ValueError("Test error") - - except ValueError: - # Expected error - pass - - # Cluster should still be usable - verify_session = await cassandra_cluster.connect() - await verify_session.set_keyspace("integration_test") - result = await verify_session.execute("SELECT COUNT(*) FROM test_session_ctx_data") - assert result.one()[0] == 50 - - # Cleanup - await verify_session.close() - - @pytest.mark.asyncio - async def test_cluster_context_manager_multiple_sessions(self, cassandra_cluster): - """ - Test cluster context manager with multiple sessions. - - What this tests: - --------------- - 1. 
Multiple sessions per cluster - 2. Independent session lifecycle - 3. Cluster cleanup on exit - 4. Session isolation - - Why this matters: - ---------------- - Multi-session patterns: - - Connection pooling - - Worker threads - - Service isolation - - Cluster must manage all - sessions properly. - """ - # Use cluster in context manager - async with AsyncCluster(["localhost"]) as cluster: - # Create multiple sessions - sessions = [] - for i in range(3): - session = await cluster.connect() - sessions.append(session) - - # Use all sessions - for i, session in enumerate(sessions): - result = await session.execute("SELECT release_version FROM system.local") - assert result.one() is not None - - # Close only one session - await sessions[0].close() - - # Other sessions should still work - for session in sessions[1:]: - result = await session.execute("SELECT release_version FROM system.local") - assert result.one() is not None - - # Close remaining sessions - for session in sessions[1:]: - await session.close() - - # After cluster context exits, cluster is shut down - # Trying to use it should fail - with pytest.raises(Exception): - await cluster.connect() - - @pytest.mark.asyncio - async def test_nested_streaming_contexts(self, cassandra_session): - """ - Test nested streaming context managers. - - What this tests: - --------------- - 1. Nested streams work - 2. Inner/outer independence - 3. Proper cleanup order - 4. No resource conflicts - - Why this matters: - ---------------- - Nested patterns common: - - Parent-child queries - - Hierarchical data - - Complex workflows - - Must handle nested contexts - without deadlocks. - """ - # Create test tables - await cassandra_session.execute( - """ - CREATE TABLE IF NOT EXISTS test_nested_categories ( - id UUID PRIMARY KEY, - name TEXT - ) - """ - ) - - await cassandra_session.execute( - """ - CREATE TABLE IF NOT EXISTS test_nested_items ( - category_id UUID, - id UUID, - name TEXT, - PRIMARY KEY (category_id, id) - ) - """ - ) - - # Insert test data - categories = [] - category_prepared = await cassandra_session.prepare( - "INSERT INTO test_nested_categories (id, name) VALUES (?, ?)" - ) - item_prepared = await cassandra_session.prepare( - "INSERT INTO test_nested_items (category_id, id, name) VALUES (?, ?, ?)" - ) - - for i in range(3): - cat_id = uuid.uuid4() - categories.append(cat_id) - await cassandra_session.execute( - category_prepared, - [cat_id, f"Category {i}"], - ) - - # Insert items for this category - for j in range(5): - await cassandra_session.execute( - item_prepared, - [cat_id, uuid.uuid4(), f"Item {i}-{j}"], - ) - - # Nested streaming - category_count = 0 - item_count = 0 - - # Stream categories - async with await cassandra_session.execute_stream( - "SELECT * FROM test_nested_categories" - ) as cat_stream: - async for category in cat_stream: - category_count += 1 - - # For each category, stream its items - query_prepared = await cassandra_session.prepare( - "SELECT * FROM test_nested_items WHERE category_id = ?" 
- ) - async with await cassandra_session.execute_stream( - query_prepared, [category.id] - ) as item_stream: - async for item in item_stream: - item_count += 1 - - assert category_count == 3 - assert item_count == 15 # 3 categories * 5 items each - - # Session should still be usable - result = await cassandra_session.execute("SELECT COUNT(*) FROM test_nested_categories") - assert result.one()[0] == 3 diff --git a/tests/integration/test_crud_operations.py b/tests/integration/test_crud_operations.py deleted file mode 100644 index d756e30..0000000 --- a/tests/integration/test_crud_operations.py +++ /dev/null @@ -1,617 +0,0 @@ -""" -Consolidated integration tests for CRUD operations. - -This module combines basic CRUD operation tests from multiple files, -focusing on core insert, select, update, and delete functionality. - -Tests consolidated from: -- test_basic_operations.py -- test_select_operations.py - -Test Organization: -================== -1. Basic CRUD Operations - Single record operations -2. Prepared Statement CRUD - Prepared statement usage -3. Batch Operations - Batch inserts and updates -4. Edge Cases - Non-existent data, NULL values, etc. -""" - -import uuid -from decimal import Decimal - -import pytest -from cassandra.query import BatchStatement, BatchType -from test_utils import generate_unique_table - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestCRUDOperations: - """Test basic CRUD operations with real Cassandra.""" - - # ======================================== - # Basic CRUD Operations - # ======================================== - - async def test_insert_and_select(self, cassandra_session, shared_keyspace_setup): - """ - Test basic insert and select operations. - - What this tests: - --------------- - 1. INSERT with prepared statements - 2. SELECT with prepared statements - 3. Data integrity after insert - 4. Multiple row retrieval - - Why this matters: - ---------------- - These are the most fundamental database operations that - every application needs to perform reliably. - """ - # Create a test table - table_name = generate_unique_table("test_crud") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - name TEXT, - age INT, - created_at TIMESTAMP - ) - """ - ) - - # Prepare statements - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, name, age, created_at) VALUES (?, ?, ?, toTimestamp(now()))" - ) - select_stmt = await cassandra_session.prepare( - f"SELECT id, name, age, created_at FROM {table_name} WHERE id = ?" 
- ) - select_all_stmt = await cassandra_session.prepare(f"SELECT * FROM {table_name}") - - # Insert test data - test_id = uuid.uuid4() - test_name = "John Doe" - test_age = 30 - - await cassandra_session.execute(insert_stmt, (test_id, test_name, test_age)) - - # Select and verify single row - result = await cassandra_session.execute(select_stmt, (test_id,)) - rows = list(result) - assert len(rows) == 1 - row = rows[0] - assert row.id == test_id - assert row.name == test_name - assert row.age == test_age - assert row.created_at is not None - - # Insert more data - more_ids = [] - for i in range(5): - new_id = uuid.uuid4() - more_ids.append(new_id) - await cassandra_session.execute(insert_stmt, (new_id, f"Person {i}", 20 + i)) - - # Select all and verify - result = await cassandra_session.execute(select_all_stmt) - all_rows = list(result) - assert len(all_rows) == 6 # Original + 5 more - - # Verify all IDs are present - all_ids = {row.id for row in all_rows} - assert test_id in all_ids - for more_id in more_ids: - assert more_id in all_ids - - async def test_update_and_delete(self, cassandra_session, shared_keyspace_setup): - """ - Test update and delete operations. - - What this tests: - --------------- - 1. UPDATE with prepared statements - 2. Conditional updates (IF EXISTS) - 3. DELETE operations - 4. Verification of changes - - Why this matters: - ---------------- - Update and delete operations are critical for maintaining - data accuracy and lifecycle management. - """ - # Create test table - table_name = generate_unique_table("test_update_delete") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - name TEXT, - email TEXT, - active BOOLEAN, - score DECIMAL - ) - """ - ) - - # Prepare statements - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, name, email, active, score) VALUES (?, ?, ?, ?, ?)" - ) - update_stmt = await cassandra_session.prepare( - f"UPDATE {table_name} SET email = ?, active = ? WHERE id = ?" - ) - update_if_exists_stmt = await cassandra_session.prepare( - f"UPDATE {table_name} SET score = ? WHERE id = ? 
IF EXISTS" - ) - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - delete_stmt = await cassandra_session.prepare(f"DELETE FROM {table_name} WHERE id = ?") - - # Insert test data - test_id = uuid.uuid4() - await cassandra_session.execute( - insert_stmt, (test_id, "Alice Smith", "alice@example.com", True, Decimal("85.5")) - ) - - # Update the record - new_email = "alice.smith@example.com" - await cassandra_session.execute(update_stmt, (new_email, False, test_id)) - - # Verify update - result = await cassandra_session.execute(select_stmt, (test_id,)) - row = result.one() - assert row.email == new_email - assert row.active is False - assert row.name == "Alice Smith" # Unchanged - assert row.score == Decimal("85.5") # Unchanged - - # Test conditional update - result = await cassandra_session.execute(update_if_exists_stmt, (Decimal("92.0"), test_id)) - assert result.one().applied is True - - # Verify conditional update worked - result = await cassandra_session.execute(select_stmt, (test_id,)) - assert result.one().score == Decimal("92.0") - - # Test conditional update on non-existent record - fake_id = uuid.uuid4() - result = await cassandra_session.execute(update_if_exists_stmt, (Decimal("100.0"), fake_id)) - assert result.one().applied is False - - # Delete the record - await cassandra_session.execute(delete_stmt, (test_id,)) - - # Verify deletion - in Cassandra, a deleted row may still appear with null values - # if only some columns were deleted. The row truly disappears only after compaction. - result = await cassandra_session.execute(select_stmt, (test_id,)) - row = result.one() - if row is not None: - # If row still exists, all non-primary key columns should be None - assert row.name is None - assert row.email is None - assert row.active is None - # Note: score might remain due to tombstone timing - - async def test_select_non_existent_data(self, cassandra_session, shared_keyspace_setup): - """ - Test selecting non-existent data. - - What this tests: - --------------- - 1. SELECT returns empty result for non-existent primary key - 2. No exceptions thrown for missing data - 3. Result iteration handles empty results - - Why this matters: - ---------------- - Applications must gracefully handle queries that return no data. - """ - # Create test table - table_name = generate_unique_table("test_non_existent") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - # Prepare select statement - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - - # Query for non-existent ID - fake_id = uuid.uuid4() - result = await cassandra_session.execute(select_stmt, (fake_id,)) - - # Should return empty result, not error - assert result.one() is None - assert list(result) == [] - - # ======================================== - # Prepared Statement CRUD - # ======================================== - - async def test_prepared_statement_lifecycle(self, cassandra_session, shared_keyspace_setup): - """ - Test prepared statement lifecycle and reuse. - - What this tests: - --------------- - 1. Prepare once, execute many times - 2. Prepared statements with different parameter counts - 3. Performance benefit of prepared statements - 4. Statement reuse across operations - - Why this matters: - ---------------- - Prepared statements are the recommended way to execute queries - for performance, security, and consistency. 
- """ - # Create test table - table_name = generate_unique_table("test_prepared") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - partition_key INT, - clustering_key INT, - value TEXT, - metadata MAP, - PRIMARY KEY (partition_key, clustering_key) - ) - """ - ) - - # Prepare various statements - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (partition_key, clustering_key, value) VALUES (?, ?, ?)" - ) - - insert_with_meta_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (partition_key, clustering_key, value, metadata) VALUES (?, ?, ?, ?)" - ) - - select_partition_stmt = await cassandra_session.prepare( - f"SELECT * FROM {table_name} WHERE partition_key = ?" - ) - - select_row_stmt = await cassandra_session.prepare( - f"SELECT * FROM {table_name} WHERE partition_key = ? AND clustering_key = ?" - ) - - update_value_stmt = await cassandra_session.prepare( - f"UPDATE {table_name} SET value = ? WHERE partition_key = ? AND clustering_key = ?" - ) - - delete_row_stmt = await cassandra_session.prepare( - f"DELETE FROM {table_name} WHERE partition_key = ? AND clustering_key = ?" - ) - - # Execute many times with same prepared statements - partition = 1 - - # Insert multiple rows - for i in range(10): - await cassandra_session.execute(insert_stmt, (partition, i, f"value_{i}")) - - # Insert with metadata - await cassandra_session.execute( - insert_with_meta_stmt, - (partition, 100, "special", {"type": "special", "priority": "high"}), - ) - - # Select entire partition - result = await cassandra_session.execute(select_partition_stmt, (partition,)) - rows = list(result) - assert len(rows) == 11 - - # Update specific rows - for i in range(0, 10, 2): # Update even rows - await cassandra_session.execute(update_value_stmt, (f"updated_{i}", partition, i)) - - # Verify updates - for i in range(10): - result = await cassandra_session.execute(select_row_stmt, (partition, i)) - row = result.one() - if i % 2 == 0: - assert row.value == f"updated_{i}" - else: - assert row.value == f"value_{i}" - - # Delete some rows - for i in range(5, 10): - await cassandra_session.execute(delete_row_stmt, (partition, i)) - - # Verify final state - result = await cassandra_session.execute(select_partition_stmt, (partition,)) - remaining_rows = list(result) - assert len(remaining_rows) == 6 # 0-4 plus row 100 - - # ======================================== - # Batch Operations - # ======================================== - - async def test_batch_insert_operations(self, cassandra_session, shared_keyspace_setup): - """ - Test batch insert operations. - - What this tests: - --------------- - 1. LOGGED batch inserts - 2. UNLOGGED batch inserts - 3. Batch size limits - 4. Mixed statement batches - - Why this matters: - ---------------- - Batch operations can improve performance for related writes - and ensure atomicity for LOGGED batches. 
- """ - # Create test table - table_name = generate_unique_table("test_batch") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - type TEXT, - value INT, - timestamp TIMESTAMP - ) - """ - ) - - # Prepare insert statement - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, type, value, timestamp) VALUES (?, ?, ?, toTimestamp(now()))" - ) - - # Test LOGGED batch (atomic) - logged_batch = BatchStatement(batch_type=BatchType.LOGGED) - logged_ids = [] - - for i in range(10): - batch_id = uuid.uuid4() - logged_ids.append(batch_id) - logged_batch.add(insert_stmt, (batch_id, "logged", i)) - - await cassandra_session.execute(logged_batch) - - # Verify all logged batch inserts - for batch_id in logged_ids: - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (batch_id,) - ) - assert result.one() is not None - - # Test UNLOGGED batch (better performance, no atomicity) - unlogged_batch = BatchStatement(batch_type=BatchType.UNLOGGED) - unlogged_ids = [] - - for i in range(20): - batch_id = uuid.uuid4() - unlogged_ids.append(batch_id) - unlogged_batch.add(insert_stmt, (batch_id, "unlogged", i)) - - await cassandra_session.execute(unlogged_batch) - - # Verify unlogged batch inserts - count = 0 - for batch_id in unlogged_ids: - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (batch_id,) - ) - if result.one() is not None: - count += 1 - - # All should succeed in normal conditions - assert count == 20 - - # Test mixed batch with different operations - mixed_table = generate_unique_table("test_mixed_batch") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {mixed_table} ( - pk INT, - ck INT, - value TEXT, - PRIMARY KEY (pk, ck) - ) - """ - ) - - insert_mixed = await cassandra_session.prepare( - f"INSERT INTO {mixed_table} (pk, ck, value) VALUES (?, ?, ?)" - ) - update_mixed = await cassandra_session.prepare( - f"UPDATE {mixed_table} SET value = ? WHERE pk = ? AND ck = ?" - ) - - # Insert initial data - await cassandra_session.execute(insert_mixed, (1, 1, "initial")) - - # Mixed batch - mixed_batch = BatchStatement() - mixed_batch.add(insert_mixed, (1, 2, "new_insert")) - mixed_batch.add(update_mixed, ("updated", 1, 1)) - mixed_batch.add(insert_mixed, (1, 3, "another_insert")) - - await cassandra_session.execute(mixed_batch) - - # Verify mixed batch results - result = await cassandra_session.execute(f"SELECT * FROM {mixed_table} WHERE pk = 1") - rows = {row.ck: row.value for row in result} - - assert rows[1] == "updated" - assert rows[2] == "new_insert" - assert rows[3] == "another_insert" - - # ======================================== - # Edge Cases - # ======================================== - - async def test_null_value_handling(self, cassandra_session, shared_keyspace_setup): - """ - Test NULL value handling in CRUD operations. - - What this tests: - --------------- - 1. INSERT with NULL values - 2. UPDATE to NULL (deletion of value) - 3. SELECT with NULL values - 4. Distinction between NULL and empty string - - Why this matters: - ---------------- - NULL handling is a common source of bugs. Applications must - correctly handle NULL vs empty vs missing values. 
- """ - # Create test table - table_name = generate_unique_table("test_null") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - required_field TEXT, - optional_field TEXT, - numeric_field INT, - collection_field LIST - ) - """ - ) - - # Test inserting with NULL values - test_id = uuid.uuid4() - insert_stmt = await cassandra_session.prepare( - f"""INSERT INTO {table_name} - (id, required_field, optional_field, numeric_field, collection_field) - VALUES (?, ?, ?, ?, ?)""" - ) - - # Insert with some NULL values - await cassandra_session.execute(insert_stmt, (test_id, "required", None, None, None)) - - # Select and verify NULLs - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (test_id,) - ) - row = result.one() - - assert row.required_field == "required" - assert row.optional_field is None - assert row.numeric_field is None - assert row.collection_field is None - - # Test updating to NULL (removes the value) - update_stmt = await cassandra_session.prepare( - f"UPDATE {table_name} SET required_field = ? WHERE id = ?" - ) - await cassandra_session.execute(update_stmt, (None, test_id)) - - # In Cassandra, setting to NULL deletes the column - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (test_id,) - ) - row = result.one() - assert row.required_field is None - - # Test empty string vs NULL - test_id2 = uuid.uuid4() - await cassandra_session.execute( - insert_stmt, (test_id2, "", "", 0, []) # Empty values, not NULL - ) - - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE id = %s", (test_id2,) - ) - row = result.one() - - # Empty string is different from NULL - assert row.required_field == "" - assert row.optional_field == "" - assert row.numeric_field == 0 - # In Cassandra, empty collections are stored as NULL - assert row.collection_field is None # Empty list becomes NULL - - async def test_large_text_operations(self, cassandra_session, shared_keyspace_setup): - """ - Test CRUD operations with large text data. - - What this tests: - --------------- - 1. INSERT large text blobs - 2. SELECT large text data - 3. UPDATE with large text - 4. Performance with large values - - Why this matters: - ---------------- - Many applications store large text data (JSON, XML, logs). - The driver must handle these efficiently. 
- """ - # Create test table - table_name = generate_unique_table("test_large_text") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id UUID PRIMARY KEY, - small_text TEXT, - large_text TEXT, - metadata MAP - ) - """ - ) - - # Generate large text data - large_text = "x" * 100000 # 100KB of text - small_text = "This is a small text field" - - # Prepare statements - insert_stmt = await cassandra_session.prepare( - f"""INSERT INTO {table_name} - (id, small_text, large_text, metadata) - VALUES (?, ?, ?, ?)""" - ) - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - - # Insert large text - test_id = uuid.uuid4() - metadata = {f"key_{i}": f"value_{i}" * 100 for i in range(10)} - - await cassandra_session.execute(insert_stmt, (test_id, small_text, large_text, metadata)) - - # Select and verify - result = await cassandra_session.execute(select_stmt, (test_id,)) - row = result.one() - - assert row.small_text == small_text - assert row.large_text == large_text - assert len(row.large_text) == 100000 - assert len(row.metadata) == 10 - - # Update with even larger text - larger_text = "y" * 200000 # 200KB - update_stmt = await cassandra_session.prepare( - f"UPDATE {table_name} SET large_text = ? WHERE id = ?" - ) - - await cassandra_session.execute(update_stmt, (larger_text, test_id)) - - # Verify update - result = await cassandra_session.execute(select_stmt, (test_id,)) - row = result.one() - assert row.large_text == larger_text - assert len(row.large_text) == 200000 - - # Test multiple large text operations - bulk_ids = [] - for i in range(5): - bulk_id = uuid.uuid4() - bulk_ids.append(bulk_id) - await cassandra_session.execute(insert_stmt, (bulk_id, f"bulk_{i}", large_text, None)) - - # Verify all bulk inserts - for bulk_id in bulk_ids: - result = await cassandra_session.execute(select_stmt, (bulk_id,)) - assert result.one() is not None diff --git a/tests/integration/test_data_types_and_counters.py b/tests/integration/test_data_types_and_counters.py deleted file mode 100644 index a954c27..0000000 --- a/tests/integration/test_data_types_and_counters.py +++ /dev/null @@ -1,1350 +0,0 @@ -""" -Consolidated integration tests for Cassandra data types and counter operations. - -This module combines all data type and counter tests from multiple files, -providing comprehensive coverage of Cassandra's type system. - -Tests consolidated from: -- test_cassandra_data_types.py - All supported Cassandra data types -- test_counters.py - Counter-specific operations and edge cases -- Various type usage from other test files - -Test Organization: -================== -1. Basic Data Types - Numeric, text, temporal, boolean, UUID, binary -2. Collection Types - List, set, map, tuple, frozen collections -3. Special Types - Inet, counter -4. Counter Operations - Increment, decrement, concurrent updates -5. 
Type Conversions and Edge Cases - NULL handling, boundaries, errors -""" - -import asyncio -import datetime -import decimal -import uuid -from datetime import date -from datetime import time as datetime_time -from datetime import timezone - -import pytest -from cassandra import ConsistencyLevel, InvalidRequest -from cassandra.util import Date, Time, uuid_from_time -from test_utils import generate_unique_table - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestDataTypes: - """Test various Cassandra data types with real Cassandra.""" - - # ======================================== - # Numeric Data Types - # ======================================== - - async def test_numeric_types(self, cassandra_session, shared_keyspace_setup): - """ - Test all numeric data types in Cassandra. - - What this tests: - --------------- - 1. TINYINT, SMALLINT, INT, BIGINT - 2. FLOAT, DOUBLE - 3. DECIMAL, VARINT - 4. Boundary values - 5. Precision handling - - Why this matters: - ---------------- - Numeric types have different ranges and precision characteristics. - Choosing the right type affects storage and performance. - """ - # Create test table with all numeric types - table_name = generate_unique_table("test_numeric_types") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - tiny_val TINYINT, - small_val SMALLINT, - int_val INT, - big_val BIGINT, - float_val FLOAT, - double_val DOUBLE, - decimal_val DECIMAL, - varint_val VARINT - ) - """ - ) - - # Prepare insert statement - insert_stmt = await cassandra_session.prepare( - f""" - INSERT INTO {table_name} - (id, tiny_val, small_val, int_val, big_val, - float_val, double_val, decimal_val, varint_val) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """ - ) - - # Test various numeric values - test_cases = [ - # Normal values - ( - 1, - 127, - 32767, - 2147483647, - 9223372036854775807, - 3.14, - 3.141592653589793, - decimal.Decimal("123.456"), - 123456789, - ), - # Negative values - ( - 2, - -128, - -32768, - -2147483648, - -9223372036854775808, - -3.14, - -3.141592653589793, - decimal.Decimal("-123.456"), - -123456789, - ), - # Zero values - (3, 0, 0, 0, 0, 0.0, 0.0, decimal.Decimal("0"), 0), - # High precision decimal - (4, 1, 1, 1, 1, 1.1, 1.1, decimal.Decimal("123456789.123456789"), 123456789123456789), - ] - - for values in test_cases: - await cassandra_session.execute(insert_stmt, values) - - # Verify all values - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - - for i, expected in enumerate(test_cases, 1): - result = await cassandra_session.execute(select_stmt, (i,)) - row = result.one() - - # Verify each numeric type - assert row.id == expected[0] - assert row.tiny_val == expected[1] - assert row.small_val == expected[2] - assert row.int_val == expected[3] - assert row.big_val == expected[4] - assert abs(row.float_val - expected[5]) < 0.0001 # Float comparison - assert abs(row.double_val - expected[6]) < 0.0000001 # Double comparison - assert row.decimal_val == expected[7] - assert row.varint_val == expected[8] - - async def test_text_types(self, cassandra_session, shared_keyspace_setup): - """ - Test text-based data types. - - What this tests: - --------------- - 1. TEXT and VARCHAR (synonymous in Cassandra) - 2. ASCII type - 3. Unicode handling - 4. Empty strings vs NULL - 5. Maximum string lengths - - Why this matters: - ---------------- - Text types are the most common data types. Understanding - encoding and storage implications is crucial. 
- """ - # Create test table - table_name = generate_unique_table("test_text_types") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - text_val TEXT, - varchar_val VARCHAR, - ascii_val ASCII - ) - """ - ) - - # Prepare statements - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, text_val, varchar_val, ascii_val) VALUES (?, ?, ?, ?)" - ) - - # Test various text values - test_cases = [ - (1, "Simple text", "Simple varchar", "Simple ASCII"), - (2, "Unicode: 你好世界 🌍", "Unicode: émojis 😀", "ASCII only"), - (3, "", "", ""), # Empty strings - (4, " " * 100, " " * 100, " " * 100), # Spaces - (5, "Line\nBreaks\r\nAllowed", "Special\tChars\t", "No_Special"), - ] - - for values in test_cases: - await cassandra_session.execute(insert_stmt, values) - - # Test NULL values - await cassandra_session.execute(insert_stmt, (6, None, None, None)) - - # Verify values - result = await cassandra_session.execute(f"SELECT * FROM {table_name}") - rows = list(result) - assert len(rows) == 6 - - # Verify specific cases - for row in rows: - if row.id == 2: - assert "你好世界" in row.text_val - assert "émojis" in row.varchar_val - elif row.id == 3: - assert row.text_val == "" - assert row.varchar_val == "" - assert row.ascii_val == "" - elif row.id == 6: - assert row.text_val is None - assert row.varchar_val is None - assert row.ascii_val is None - - async def test_temporal_types(self, cassandra_session, shared_keyspace_setup): - """ - Test date and time related data types. - - What this tests: - --------------- - 1. TIMESTAMP type - 2. DATE type - 3. TIME type - 4. Timezone handling - 5. Precision and range - - Why this matters: - ---------------- - Temporal data is common in applications. Understanding - precision and timezone behavior is critical. - """ - # Create test table - table_name = generate_unique_table("test_temporal_types") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - ts_val TIMESTAMP, - date_val DATE, - time_val TIME - ) - """ - ) - - # Prepare insert - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, ts_val, date_val, time_val) VALUES (?, ?, ?, ?)" - ) - - # Test values - now = datetime.datetime.now(timezone.utc) - today = Date(date.today()) - current_time = Time(datetime_time(14, 30, 45, 123000)) # 14:30:45.123 - - test_cases = [ - (1, now, today, current_time), - ( - 2, - datetime.datetime(2000, 1, 1, 0, 0, 0, 0, timezone.utc), - Date(date(2000, 1, 1)), - Time(datetime_time(0, 0, 0)), - ), - ( - 3, - datetime.datetime(2038, 1, 19, 3, 14, 7, 0, timezone.utc), - Date(date(2038, 1, 19)), - Time(datetime_time(23, 59, 59, 999999)), - ), - ] - - for values in test_cases: - await cassandra_session.execute(insert_stmt, values) - - # Verify temporal values - result = await cassandra_session.execute(f"SELECT * FROM {table_name}") - rows = list(result) - assert len(rows) == 3 - - # Check timestamp precision (millisecond precision in Cassandra) - row1 = next(r for r in rows if r.id == 1) - # Handle both timezone-aware and naive datetimes - if row1.ts_val.tzinfo is None: - # Convert to UTC aware for comparison - row_ts = row1.ts_val.replace(tzinfo=timezone.utc) - else: - row_ts = row1.ts_val - assert abs((row_ts - now).total_seconds()) < 1 - - async def test_uuid_types(self, cassandra_session, shared_keyspace_setup): - """ - Test UUID and TIMEUUID data types. - - What this tests: - --------------- - 1. 
UUID type (type 4 random UUID) - 2. TIMEUUID type (type 1 time-based UUID) - 3. UUID generation functions - 4. Time extraction from TIMEUUID - - Why this matters: - ---------------- - UUIDs are commonly used for distributed unique identifiers. - TIMEUUIDs provide time-ordering capabilities. - """ - # Create test table - table_name = generate_unique_table("test_uuid_types") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - uuid_val UUID, - timeuuid_val TIMEUUID, - created_at TIMESTAMP - ) - """ - ) - - # Test UUIDs - regular_uuid = uuid.uuid4() - time_uuid = uuid_from_time(datetime.datetime.now()) - - # Insert with prepared statement - insert_stmt = await cassandra_session.prepare( - f""" - INSERT INTO {table_name} (id, uuid_val, timeuuid_val, created_at) - VALUES (?, ?, ?, ?) - """ - ) - - await cassandra_session.execute( - insert_stmt, (1, regular_uuid, time_uuid, datetime.datetime.now(timezone.utc)) - ) - - # Test UUID functions - await cassandra_session.execute( - f"INSERT INTO {table_name} (id, uuid_val, timeuuid_val) VALUES (2, uuid(), now())" - ) - - # Verify UUIDs - result = await cassandra_session.execute(f"SELECT * FROM {table_name}") - rows = list(result) - assert len(rows) == 2 - - # Verify UUID types - for row in rows: - assert isinstance(row.uuid_val, uuid.UUID) - assert isinstance(row.timeuuid_val, uuid.UUID) - # TIMEUUID should be version 1 - if row.id == 1: - assert row.timeuuid_val.version == 1 - - async def test_binary_and_boolean_types(self, cassandra_session, shared_keyspace_setup): - """ - Test BLOB and BOOLEAN data types. - - What this tests: - --------------- - 1. BLOB type for binary data - 2. BOOLEAN type - 3. Binary data encoding/decoding - 4. NULL vs empty blob - - Why this matters: - ---------------- - Binary data storage and boolean flags are common requirements. - """ - # Create test table - table_name = generate_unique_table("test_binary_boolean") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - binary_data BLOB, - is_active BOOLEAN, - is_verified BOOLEAN - ) - """ - ) - - # Prepare statement - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, binary_data, is_active, is_verified) VALUES (?, ?, ?, ?)" - ) - - # Test data - test_cases = [ - (1, b"Hello World", True, False), - (2, b"\x00\x01\x02\x03\xff", False, True), - (3, b"", True, True), # Empty blob - (4, None, None, None), # NULL values - (5, b"Unicode bytes: \xf0\x9f\x98\x80", False, False), - ] - - for values in test_cases: - await cassandra_session.execute(insert_stmt, values) - - # Verify data - result = await cassandra_session.execute(f"SELECT * FROM {table_name}") - rows = {row.id: row for row in result} - - assert rows[1].binary_data == b"Hello World" - assert rows[1].is_active is True - assert rows[1].is_verified is False - - assert rows[2].binary_data == b"\x00\x01\x02\x03\xff" - assert rows[3].binary_data == b"" # Empty blob - assert rows[4].binary_data is None - assert rows[4].is_active is None - - async def test_inet_types(self, cassandra_session, shared_keyspace_setup): - """ - Test INET data type for IP addresses. - - What this tests: - --------------- - 1. IPv4 addresses - 2. IPv6 addresses - 3. Address validation - 4. String conversion - - Why this matters: - ---------------- - Storing IP addresses efficiently is common in network applications. 
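A small sketch of the INET pattern exercised in the test body below, assuming the same async session API; the table name is illustrative.

# Sketch only: INET columns accept IPv4 and IPv6 addresses as plain strings.
insert = await cassandra_session.prepare("INSERT INTO access_log (id, client_ip) VALUES (?, ?)")
await cassandra_session.execute(insert, (1, "192.168.1.1"))  # IPv4
await cassandra_session.execute(insert, (2, "2001:db8::1"))  # IPv6

select = await cassandra_session.prepare("SELECT client_ip FROM access_log WHERE id = ?")
row = (await cassandra_session.execute(select, (1,))).one()
assert row.client_ip == "192.168.1.1"  # addresses read back as strings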
- """ - # Create test table - table_name = generate_unique_table("test_inet_types") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - client_ip INET, - server_ip INET, - description TEXT - ) - """ - ) - - # Prepare statement - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, client_ip, server_ip, description) VALUES (?, ?, ?, ?)" - ) - - # Test IP addresses - test_cases = [ - (1, "192.168.1.1", "10.0.0.1", "Private IPv4"), - (2, "8.8.8.8", "8.8.4.4", "Public IPv4"), - (3, "::1", "fe80::1", "IPv6 loopback and link-local"), - (4, "2001:db8::1", "2001:db8:0:0:1:0:0:1", "IPv6 public"), - (5, "127.0.0.1", "::ffff:127.0.0.1", "IPv4 and IPv4-mapped IPv6"), - ] - - for values in test_cases: - await cassandra_session.execute(insert_stmt, values) - - # Verify IP addresses - result = await cassandra_session.execute(f"SELECT * FROM {table_name}") - rows = list(result) - assert len(rows) == 5 - - # Verify specific addresses - for row in rows: - assert row.client_ip is not None - assert row.server_ip is not None - # IPs are returned as strings - if row.id == 1: - assert row.client_ip == "192.168.1.1" - elif row.id == 3: - assert row.client_ip == "::1" - - # ======================================== - # Collection Data Types - # ======================================== - - async def test_list_type(self, cassandra_session, shared_keyspace_setup): - """ - Test LIST collection type. - - What this tests: - --------------- - 1. List creation and manipulation - 2. Ordering preservation - 3. Duplicate values - 4. NULL vs empty list - 5. List updates and appends - - Why this matters: - ---------------- - Lists maintain order and allow duplicates, useful for - ordered collections like tags or history. - """ - # Create test table - table_name = generate_unique_table("test_list_type") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - tags LIST, - scores LIST, - timestamps LIST - ) - """ - ) - - # Prepare statements - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, tags, scores, timestamps) VALUES (?, ?, ?, ?)" - ) - - # Test list operations - now = datetime.datetime.now(timezone.utc) - test_cases = [ - (1, ["tag1", "tag2", "tag3"], [100, 200, 300], [now]), - (2, ["duplicate", "duplicate"], [1, 1, 2, 3, 5], None), # Duplicates allowed - (3, [], [], []), # Empty lists - (4, None, None, None), # NULL lists - ] - - for values in test_cases: - await cassandra_session.execute(insert_stmt, values) - - # Test list append - update_stmt = await cassandra_session.prepare( - f"UPDATE {table_name} SET tags = tags + ? WHERE id = ?" - ) - await cassandra_session.execute(update_stmt, (["tag4", "tag5"], 1)) - - # Test list prepend - update_prepend = await cassandra_session.prepare( - f"UPDATE {table_name} SET tags = ? + tags WHERE id = ?" - ) - await cassandra_session.execute(update_prepend, (["tag0"], 1)) - - # Verify lists - result = await cassandra_session.execute(f"SELECT * FROM {table_name} WHERE id = 1") - row = result.one() - assert row.tags == ["tag0", "tag1", "tag2", "tag3", "tag4", "tag5"] - - # Test removing from list - update_remove = await cassandra_session.prepare( - f"UPDATE {table_name} SET scores = scores - ? WHERE id = ?" 
- ) - await cassandra_session.execute(update_remove, ([1], 2)) - - result = await cassandra_session.execute(f"SELECT * FROM {table_name} WHERE id = 2") - row = result.one() - # Note: removes all occurrences - assert 1 not in row.scores - - async def test_set_type(self, cassandra_session, shared_keyspace_setup): - """ - Test SET collection type. - - What this tests: - --------------- - 1. Set creation and manipulation - 2. Uniqueness enforcement - 3. Unordered nature - 4. Set operations (add, remove) - 5. NULL vs empty set - - Why this matters: - ---------------- - Sets enforce uniqueness and are useful for tags, - categories, or any unique collection. - """ - # Create test table - table_name = generate_unique_table("test_set_type") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - categories SET, - user_ids SET, - ip_addresses SET - ) - """ - ) - - # Prepare statements - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, categories, user_ids, ip_addresses) VALUES (?, ?, ?, ?)" - ) - - # Test data - user_id1 = uuid.uuid4() - user_id2 = uuid.uuid4() - - test_cases = [ - (1, {"tech", "news", "sports"}, {user_id1, user_id2}, {"192.168.1.1", "10.0.0.1"}), - (2, {"tech", "tech", "tech"}, {user_id1}, None), # Duplicates become unique - (3, set(), set(), set()), # Empty sets - Note: these become NULL in Cassandra - (4, None, None, None), # NULL sets - ] - - for values in test_cases: - await cassandra_session.execute(insert_stmt, values) - - # Test set addition - update_add = await cassandra_session.prepare( - f"UPDATE {table_name} SET categories = categories + ? WHERE id = ?" - ) - await cassandra_session.execute(update_add, ({"politics", "tech"}, 1)) - - # Test set removal - update_remove = await cassandra_session.prepare( - f"UPDATE {table_name} SET categories = categories - ? WHERE id = ?" - ) - await cassandra_session.execute(update_remove, ({"sports"}, 1)) - - # Verify sets - result = await cassandra_session.execute(f"SELECT * FROM {table_name} WHERE id = 1") - row = result.one() - # Sets are unordered - assert row.categories == {"tech", "news", "politics"} - - # Check empty set behavior - result3 = await cassandra_session.execute(f"SELECT * FROM {table_name} WHERE id = 3") - row3 = result3.one() - # Empty sets become NULL in Cassandra - assert row3.categories is None - - async def test_map_type(self, cassandra_session, shared_keyspace_setup): - """ - Test MAP collection type. - - What this tests: - --------------- - 1. Map creation and manipulation - 2. Key-value pairs - 3. Key uniqueness - 4. Map updates - 5. NULL vs empty map - - Why this matters: - ---------------- - Maps provide key-value storage within a column, - useful for metadata or configuration. 
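A condensed sketch of the per-entry MAP updates the test body below exercises, assuming the async session API used in these tests; the table name is illustrative.

# Sketch only: MAP columns support per-entry writes and deletes without
# replacing the whole map.
set_entry = await cassandra_session.prepare("UPDATE settings SET metadata[?] = ? WHERE id = ?")
await cassandra_session.execute(set_entry, ("status", "active", 1))

drop_entry = await cassandra_session.prepare("DELETE metadata[?] FROM settings WHERE id = ?")
await cassandra_session.execute(drop_entry, ("status", 1))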
- """ - # Create test table - table_name = generate_unique_table("test_map_type") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - metadata MAP, - scores MAP, - timestamps MAP - ) - """ - ) - - # Prepare statements - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, metadata, scores, timestamps) VALUES (?, ?, ?, ?)" - ) - - # Test data - now = datetime.datetime.now(timezone.utc) - test_cases = [ - (1, {"name": "John", "city": "NYC"}, {"math": 95, "english": 88}, {"created": now}), - (2, {"key": "value"}, None, None), - (3, {}, {}, {}), # Empty maps - become NULL - (4, None, None, None), # NULL maps - ] - - for values in test_cases: - await cassandra_session.execute(insert_stmt, values) - - # Test map update - add/update entries - update_map = await cassandra_session.prepare( - f"UPDATE {table_name} SET metadata = metadata + ? WHERE id = ?" - ) - await cassandra_session.execute(update_map, ({"country": "USA", "city": "Boston"}, 1)) - - # Test map entry update - update_entry = await cassandra_session.prepare( - f"UPDATE {table_name} SET metadata[?] = ? WHERE id = ?" - ) - await cassandra_session.execute(update_entry, ("status", "active", 1)) - - # Test map entry deletion - delete_entry = await cassandra_session.prepare( - f"DELETE metadata[?] FROM {table_name} WHERE id = ?" - ) - await cassandra_session.execute(delete_entry, ("name", 1)) - - # Verify map - result = await cassandra_session.execute(f"SELECT * FROM {table_name} WHERE id = 1") - row = result.one() - assert row.metadata == {"city": "Boston", "country": "USA", "status": "active"} - assert "name" not in row.metadata # Deleted - - async def test_tuple_type(self, cassandra_session, shared_keyspace_setup): - """ - Test TUPLE type. - - What this tests: - --------------- - 1. Fixed-size ordered collections - 2. Heterogeneous types - 3. Tuple comparison - 4. NULL elements in tuples - - Why this matters: - ---------------- - Tuples provide fixed-structure data storage, - useful for coordinates, versions, etc. - """ - # Create test table - table_name = generate_unique_table("test_tuple_type") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - coordinates TUPLE, - version TUPLE, - user_info TUPLE - ) - """ - ) - - # Prepare statement - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, coordinates, version, user_info) VALUES (?, ?, ?, ?)" - ) - - # Test tuples - test_cases = [ - (1, (37.7749, -122.4194), (1, 2, 3), ("Alice", 25, True)), - (2, (0.0, 0.0), (0, 0, 1), ("Bob", None, False)), # NULL element - (3, None, None, None), # NULL tuples - ] - - for values in test_cases: - await cassandra_session.execute(insert_stmt, values) - - # Verify tuples - result = await cassandra_session.execute(f"SELECT * FROM {table_name}") - rows = {row.id: row for row in result} - - assert rows[1].coordinates == (37.7749, -122.4194) - assert rows[1].version == (1, 2, 3) - assert rows[1].user_info == ("Alice", 25, True) - - # Check NULL element in tuple - assert rows[2].user_info == ("Bob", None, False) - - async def test_frozen_collections(self, cassandra_session, shared_keyspace_setup): - """ - Test FROZEN collections. - - What this tests: - --------------- - 1. Frozen lists, sets, maps - 2. Nested frozen collections - 3. Immutability of frozen collections - 4. 
Use as primary key components - - Why this matters: - ---------------- - Frozen collections can be used in primary keys and - are stored more efficiently but cannot be updated partially. - """ - # Create test table with frozen collections - table_name = generate_unique_table("test_frozen_collections") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT, - frozen_tags FROZEN>, - config FROZEN>, - nested FROZEN>>>, - PRIMARY KEY (id, frozen_tags) - ) - """ - ) - - # Prepare statement - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, frozen_tags, config, nested) VALUES (?, ?, ?, ?)" - ) - - # Test frozen collections - test_cases = [ - (1, {"tag1", "tag2"}, {"key1": "val1"}, {"nums": [1, 2, 3]}), - (1, {"tag3", "tag4"}, {"key2": "val2"}, {"nums": [4, 5, 6]}), - (2, set(), {}, {}), # Empty frozen collections - ] - - for values in test_cases: - # Convert the list to tuple for frozen list - id_val, tags, config, nested_dict = values - # Convert nested list to tuple for frozen representation - nested_frozen = {k: v for k, v in nested_dict.items()} - await cassandra_session.execute(insert_stmt, (id_val, tags, config, nested_frozen)) - - # Verify frozen collections - result = await cassandra_session.execute(f"SELECT * FROM {table_name} WHERE id = 1") - rows = list(result) - assert len(rows) == 2 # Two rows with same id but different frozen_tags - - # Try to update frozen collection (should replace entire value) - update_stmt = await cassandra_session.prepare( - f"UPDATE {table_name} SET config = ? WHERE id = ? AND frozen_tags = ?" - ) - await cassandra_session.execute(update_stmt, ({"new": "config"}, 1, {"tag1", "tag2"})) - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestCounterOperations: - """Test counter data type operations with real Cassandra.""" - - async def test_basic_counter_operations(self, cassandra_session, shared_keyspace_setup): - """ - Test basic counter increment and decrement. - - What this tests: - --------------- - 1. Counter table creation - 2. INCREMENT operations - 3. DECREMENT operations - 4. Counter initialization - 5. Reading counter values - - Why this matters: - ---------------- - Counters provide atomic increment/decrement operations - essential for metrics and statistics. - """ - # Create counter table - table_name = generate_unique_table("test_basic_counters") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id TEXT PRIMARY KEY, - page_views COUNTER, - likes COUNTER, - shares COUNTER - ) - """ - ) - - # Prepare counter update statements - increment_views = await cassandra_session.prepare( - f"UPDATE {table_name} SET page_views = page_views + ? WHERE id = ?" - ) - increment_likes = await cassandra_session.prepare( - f"UPDATE {table_name} SET likes = likes + ? WHERE id = ?" - ) - decrement_shares = await cassandra_session.prepare( - f"UPDATE {table_name} SET shares = shares - ? WHERE id = ?" 
- ) - - # Test counter operations - post_id = "post_001" - - # Increment counters - await cassandra_session.execute(increment_views, (100, post_id)) - await cassandra_session.execute(increment_likes, (10, post_id)) - await cassandra_session.execute(increment_views, (50, post_id)) # Another increment - - # Decrement counter - await cassandra_session.execute(decrement_shares, (5, post_id)) - - # Read counter values - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - result = await cassandra_session.execute(select_stmt, (post_id,)) - row = result.one() - - assert row.page_views == 150 # 100 + 50 - assert row.likes == 10 - assert row.shares == -5 # Started at 0, decremented by 5 - - # Test multiple increments in sequence - for i in range(10): - await cassandra_session.execute(increment_likes, (1, post_id)) - - result = await cassandra_session.execute(select_stmt, (post_id,)) - row = result.one() - assert row.likes == 20 # 10 + 10*1 - - async def test_concurrent_counter_updates(self, cassandra_session, shared_keyspace_setup): - """ - Test concurrent counter updates. - - What this tests: - --------------- - 1. Thread-safe counter operations - 2. No lost updates - 3. Atomic increments - 4. Performance under concurrency - - Why this matters: - ---------------- - Counters must handle concurrent updates correctly - in distributed systems. - """ - # Create counter table - table_name = generate_unique_table("test_concurrent_counters") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id TEXT PRIMARY KEY, - total_requests COUNTER, - error_count COUNTER - ) - """ - ) - - # Prepare statements - increment_requests = await cassandra_session.prepare( - f"UPDATE {table_name} SET total_requests = total_requests + ? WHERE id = ?" - ) - increment_errors = await cassandra_session.prepare( - f"UPDATE {table_name} SET error_count = error_count + ? WHERE id = ?" - ) - - service_id = "api_service" - - # Simulate concurrent updates - async def increment_counter(counter_type, count): - if counter_type == "requests": - await cassandra_session.execute(increment_requests, (count, service_id)) - else: - await cassandra_session.execute(increment_errors, (count, service_id)) - - # Run 100 concurrent increments - tasks = [] - for i in range(100): - tasks.append(increment_counter("requests", 1)) - if i % 10 == 0: # 10% error rate - tasks.append(increment_counter("errors", 1)) - - await asyncio.gather(*tasks) - - # Verify final counts - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {table_name} WHERE id = ?") - result = await cassandra_session.execute(select_stmt, (service_id,)) - row = result.one() - - assert row.total_requests == 100 - assert row.error_count == 10 - - async def test_counter_consistency_levels(self, cassandra_session, shared_keyspace_setup): - """ - Test counters with different consistency levels. - - What this tests: - --------------- - 1. Counter updates with QUORUM - 2. Counter reads with different consistency - 3. Consistency vs performance trade-offs - - Why this matters: - ---------------- - Counter consistency affects accuracy and performance - in distributed deployments. 
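A short sketch of the per-statement consistency tuning this test exercises; table and column names are illustrative, not from the deleted file.

# Sketch only: consistency is tuned per prepared statement.
from cassandra import ConsistencyLevel

update = await cassandra_session.prepare(
    "UPDATE metrics SET metric_value = metric_value + ? WHERE id = ?"
)
update.consistency_level = ConsistencyLevel.QUORUM  # accurate, slower counter writes

read = await cassandra_session.prepare("SELECT metric_value FROM metrics WHERE id = ?")
read.consistency_level = ConsistencyLevel.ONE       # fast reads, possibly stale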
- """ - # Create counter table - table_name = generate_unique_table("test_counter_consistency") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id TEXT PRIMARY KEY, - metric_value COUNTER - ) - """ - ) - - # Prepare statements with different consistency levels - update_stmt = await cassandra_session.prepare( - f"UPDATE {table_name} SET metric_value = metric_value + ? WHERE id = ?" - ) - update_stmt.consistency_level = ConsistencyLevel.QUORUM - - select_stmt = await cassandra_session.prepare( - f"SELECT metric_value FROM {table_name} WHERE id = ?" - ) - select_stmt.consistency_level = ConsistencyLevel.ONE - - metric_id = "cpu_usage" - - # Update with QUORUM consistency - await cassandra_session.execute(update_stmt, (75, metric_id)) - - # Read with ONE consistency (faster but potentially stale) - result = await cassandra_session.execute(select_stmt, (metric_id,)) - row = result.one() - assert row.metric_value == 75 - - async def test_counter_special_cases(self, cassandra_session, shared_keyspace_setup): - """ - Test counter special cases and limitations. - - What this tests: - --------------- - 1. Counters cannot be set to specific values - 2. Counters cannot have TTL - 3. Counter deletion behavior - 4. NULL counter behavior - - Why this matters: - ---------------- - Understanding counter limitations prevents - design mistakes and runtime errors. - """ - # Create counter table - table_name = generate_unique_table("test_counter_special") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id TEXT PRIMARY KEY, - counter_val COUNTER - ) - """ - ) - - # Test that we cannot INSERT counters (only UPDATE) - with pytest.raises(InvalidRequest): - await cassandra_session.execute( - f"INSERT INTO {table_name} (id, counter_val) VALUES ('test', 100)" - ) - - # Test that counters cannot have TTL - with pytest.raises(InvalidRequest): - await cassandra_session.execute( - f"UPDATE {table_name} USING TTL 3600 SET counter_val = counter_val + 1 WHERE id = 'test'" - ) - - # Test counter deletion - update_stmt = await cassandra_session.prepare( - f"UPDATE {table_name} SET counter_val = counter_val + ? WHERE id = ?" - ) - await cassandra_session.execute(update_stmt, (100, "delete_test")) - - # Delete the counter - await cassandra_session.execute( - f"DELETE counter_val FROM {table_name} WHERE id = 'delete_test'" - ) - - # After deletion, counter reads as NULL - result = await cassandra_session.execute( - f"SELECT counter_val FROM {table_name} WHERE id = 'delete_test'" - ) - row = result.one() - if row: # Row might not exist at all - assert row.counter_val is None - - # Can increment again after deletion - await cassandra_session.execute(update_stmt, (50, "delete_test")) - result = await cassandra_session.execute( - f"SELECT counter_val FROM {table_name} WHERE id = 'delete_test'" - ) - row = result.one() - # After deleting a counter column, the row might not exist - # or the counter might be reset depending on Cassandra version - if row is not None: - assert row.counter_val == 50 # Starts from 0 again - - async def test_counter_batch_operations(self, cassandra_session, shared_keyspace_setup): - """ - Test counter operations in batches. - - What this tests: - --------------- - 1. Counter-only batches - 2. Multiple counter updates in batch - 3. Batch atomicity for counters - - Why this matters: - ---------------- - Batching counter updates can improve performance - for related counter modifications. 
- """ - # Create counter table - table_name = generate_unique_table("test_counter_batch") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - category TEXT, - item TEXT, - views COUNTER, - clicks COUNTER, - PRIMARY KEY (category, item) - ) - """ - ) - - # This test demonstrates counter batch operations - # which are already covered in test_batch_and_lwt_operations.py - # Here we'll test a specific counter batch pattern - - # Prepare counter updates - update_views = await cassandra_session.prepare( - f"UPDATE {table_name} SET views = views + ? WHERE category = ? AND item = ?" - ) - update_clicks = await cassandra_session.prepare( - f"UPDATE {table_name} SET clicks = clicks + ? WHERE category = ? AND item = ?" - ) - - # Update multiple counters for same partition - category = "electronics" - items = ["laptop", "phone", "tablet"] - - # Simulate page views and clicks - for item in items: - await cassandra_session.execute(update_views, (100, category, item)) - await cassandra_session.execute(update_clicks, (10, category, item)) - - # Verify counters - result = await cassandra_session.execute( - f"SELECT * FROM {table_name} WHERE category = '{category}'" - ) - rows = list(result) - assert len(rows) == 3 - - for row in rows: - assert row.views == 100 - assert row.clicks == 10 - - -@pytest.mark.asyncio -@pytest.mark.integration -class TestDataTypeEdgeCases: - """Test edge cases and special scenarios for data types.""" - - async def test_null_value_handling(self, cassandra_session, shared_keyspace_setup): - """ - Test NULL value handling across different data types. - - What this tests: - --------------- - 1. NULL vs missing columns - 2. NULL in collections - 3. NULL in primary keys (not allowed) - 4. Distinguishing NULL from empty - - Why this matters: - ---------------- - NULL handling affects storage, queries, and application logic. - """ - # Create test table - table_name = generate_unique_table("test_null_handling") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - text_col TEXT, - int_col INT, - list_col LIST, - map_col MAP - ) - """ - ) - - # Insert with explicit NULLs - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, text_col, int_col, list_col, map_col) VALUES (?, ?, ?, ?, ?)" - ) - await cassandra_session.execute(insert_stmt, (1, None, None, None, None)) - - # Insert with missing columns (implicitly NULL) - await cassandra_session.execute( - f"INSERT INTO {table_name} (id, text_col) VALUES (2, 'has text')" - ) - - # Insert with empty collections - await cassandra_session.execute(insert_stmt, (3, "text", 0, [], {})) - - # Verify NULL handling - result = await cassandra_session.execute(f"SELECT * FROM {table_name}") - rows = {row.id: row for row in result} - - # Explicit NULLs - assert rows[1].text_col is None - assert rows[1].int_col is None - assert rows[1].list_col is None - assert rows[1].map_col is None - - # Missing columns are NULL - assert rows[2].int_col is None - assert rows[2].list_col is None - - # Empty collections become NULL in Cassandra - assert rows[3].list_col is None - assert rows[3].map_col is None - - async def test_numeric_boundaries(self, cassandra_session, shared_keyspace_setup): - """ - Test numeric type boundaries and overflow behavior. - - What this tests: - --------------- - 1. Maximum and minimum values - 2. Overflow behavior - 3. Precision limits - 4. 
Special float values (NaN, Infinity) - - Why this matters: - ---------------- - Understanding type limits prevents data corruption - and application errors. - """ - # Create test table - table_name = generate_unique_table("test_numeric_boundaries") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - tiny_val TINYINT, - small_val SMALLINT, - float_val FLOAT, - double_val DOUBLE - ) - """ - ) - - # Test boundary values - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, tiny_val, small_val, float_val, double_val) VALUES (?, ?, ?, ?, ?)" - ) - - # Maximum values - await cassandra_session.execute(insert_stmt, (1, 127, 32767, float("inf"), float("inf"))) - - # Minimum values - await cassandra_session.execute( - insert_stmt, (2, -128, -32768, float("-inf"), float("-inf")) - ) - - # Special float values - await cassandra_session.execute(insert_stmt, (3, 0, 0, float("nan"), float("nan"))) - - # Verify special values - result = await cassandra_session.execute(f"SELECT * FROM {table_name}") - rows = {row.id: row for row in result} - - # Check infinity - assert rows[1].float_val == float("inf") - assert rows[2].double_val == float("-inf") - - # Check NaN (NaN != NaN in Python) - import math - - assert math.isnan(rows[3].float_val) - assert math.isnan(rows[3].double_val) - - async def test_collection_size_limits(self, cassandra_session, shared_keyspace_setup): - """ - Test collection size limits and performance. - - What this tests: - --------------- - 1. Large collections - 2. Maximum collection sizes - 3. Performance with large collections - 4. Nested collection limits - - Why this matters: - ---------------- - Collections have size limits that affect design decisions. - """ - # Create test table - table_name = generate_unique_table("test_collection_limits") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - large_list LIST, - large_set SET, - large_map MAP - ) - """ - ) - - # Create large collections (but not too large to avoid timeouts) - large_list = [f"item_{i}" for i in range(1000)] - large_set = set(range(1000)) - large_map = {i: f"value_{i}" for i in range(1000)} - - # Insert large collections - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, large_list, large_set, large_map) VALUES (?, ?, ?, ?)" - ) - await cassandra_session.execute(insert_stmt, (1, large_list, large_set, large_map)) - - # Verify large collections - result = await cassandra_session.execute(f"SELECT * FROM {table_name} WHERE id = 1") - row = result.one() - - assert len(row.large_list) == 1000 - assert len(row.large_set) == 1000 - assert len(row.large_map) == 1000 - - # Note: Cassandra has a practical limit of ~64KB for a collection - # and a hard limit of 2GB for any single column value - - async def test_type_compatibility(self, cassandra_session, shared_keyspace_setup): - """ - Test type compatibility and implicit conversions. - - What this tests: - --------------- - 1. Compatible type assignments - 2. String to numeric conversions - 3. Timestamp formats - 4. Type validation - - Why this matters: - ---------------- - Understanding type compatibility helps prevent - runtime errors and data corruption. 
- """ - # Create test table - table_name = generate_unique_table("test_type_compatibility") - await cassandra_session.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - id INT PRIMARY KEY, - int_val INT, - bigint_val BIGINT, - text_val TEXT, - timestamp_val TIMESTAMP - ) - """ - ) - - # Test compatible assignments - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {table_name} (id, int_val, bigint_val, text_val, timestamp_val) VALUES (?, ?, ?, ?, ?)" - ) - - # INT can be assigned to BIGINT - await cassandra_session.execute( - insert_stmt, (1, 12345, 12345, "12345", datetime.datetime.now(timezone.utc)) - ) - - # Test string representations - await cassandra_session.execute( - f"INSERT INTO {table_name} (id, text_val) VALUES (2, '你好世界')" - ) - - # Verify assignments - result = await cassandra_session.execute(f"SELECT * FROM {table_name}") - rows = list(result) - assert len(rows) == 2 - - # Test type errors - # Cannot insert string into numeric column via prepared statement - with pytest.raises(Exception): # Will be TypeError or similar - await cassandra_session.execute( - insert_stmt, (3, "not a number", 123, "text", datetime.datetime.now(timezone.utc)) - ) diff --git a/tests/integration/test_driver_compatibility.py b/tests/integration/test_driver_compatibility.py deleted file mode 100644 index fc76f80..0000000 --- a/tests/integration/test_driver_compatibility.py +++ /dev/null @@ -1,573 +0,0 @@ -""" -Integration tests comparing async wrapper behavior with raw driver. - -This ensures our wrapper maintains compatibility and doesn't break any functionality. -""" - -import os -import uuid -import warnings - -import pytest -from cassandra.cluster import Cluster as SyncCluster -from cassandra.policies import DCAwareRoundRobinPolicy -from cassandra.query import BatchStatement, BatchType, dict_factory - - -@pytest.mark.integration -@pytest.mark.sync_driver # Allow filtering these tests: pytest -m "not sync_driver" -class TestDriverCompatibility: - """Test async wrapper compatibility with raw driver features.""" - - @pytest.fixture - def sync_cluster(self): - """Create a synchronous cluster for comparison with stability improvements.""" - is_ci = os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true" - - # Strategy 1: Increase connection timeout for CI environments - connect_timeout = 30.0 if is_ci else 10.0 - - # Strategy 2: Explicit configuration to reduce startup delays - cluster = SyncCluster( - contact_points=["127.0.0.1"], - port=9042, - connect_timeout=connect_timeout, - # Always use default connection class - load_balancing_policy=DCAwareRoundRobinPolicy(local_dc="datacenter1"), - protocol_version=5, # We support protocol version 5 - idle_heartbeat_interval=30, # Keep connections alive in CI - schema_event_refresh_window=10, # Reduce schema refresh overhead - ) - - # Strategy 3: Adjust settings for CI stability - if is_ci: - # Reduce executor threads to minimize resource usage - cluster.executor_threads = 1 - # Increase control connection timeout - cluster.control_connection_timeout = 30.0 - # Suppress known warnings - warnings.filterwarnings("ignore", category=DeprecationWarning) - - try: - yield cluster - finally: - cluster.shutdown() - - @pytest.fixture - def sync_session(self, sync_cluster, unique_keyspace): - """Create a synchronous session with retry logic for CI stability.""" - is_ci = os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true" - - # Add retry logic for connection in CI - max_retries = 3 if is_ci else 1 - 
retry_delay = 2.0 - - session = None - last_error = None - - for attempt in range(max_retries): - try: - session = sync_cluster.connect() - # Verify connection is working - session.execute("SELECT release_version FROM system.local") - break - except Exception as e: - last_error = e - if attempt < max_retries - 1: - import time - - if is_ci: - print(f"Connection attempt {attempt + 1} failed: {e}, retrying...") - time.sleep(retry_delay) - continue - raise e - - if session is None: - raise last_error or Exception("Failed to connect") - - # Create keyspace with retry for schema agreement - for attempt in range(max_retries): - try: - session.execute( - f""" - CREATE KEYSPACE IF NOT EXISTS {unique_keyspace} - WITH REPLICATION = {{'class': 'SimpleStrategy', 'replication_factor': 1}} - """ - ) - session.set_keyspace(unique_keyspace) - break - except Exception as e: - if attempt < max_retries - 1 and is_ci: - import time - - time.sleep(1) - continue - raise e - - try: - yield session - finally: - session.shutdown() - - @pytest.mark.asyncio - async def test_basic_query_compatibility(self, sync_session, session_with_keyspace): - """ - Test basic query execution matches between sync and async. - - What this tests: - --------------- - 1. Same query syntax works - 2. Prepared statements compatible - 3. Results format matches - 4. Independent keyspaces - - Why this matters: - ---------------- - API compatibility ensures: - - Easy migration - - Same patterns work - - No relearning needed - - Drop-in replacement for - sync driver. - """ - async_session, keyspace = session_with_keyspace - - # Create table in both sessions' keyspace - table_name = f"compat_basic_{uuid.uuid4().hex[:8]}" - create_table = f""" - CREATE TABLE {table_name} ( - id int PRIMARY KEY, - name text, - value double - ) - """ - - # Create in sync session's keyspace - sync_session.execute(create_table) - - # Create in async session's keyspace - await async_session.execute(create_table) - - # Prepare statements - both use ? for prepared statements - sync_prepared = sync_session.prepare( - f"INSERT INTO {table_name} (id, name, value) VALUES (?, ?, ?)" - ) - async_prepared = await async_session.prepare( - f"INSERT INTO {table_name} (id, name, value) VALUES (?, ?, ?)" - ) - - # Sync insert - sync_session.execute(sync_prepared, (1, "sync", 1.23)) - - # Async insert - await async_session.execute(async_prepared, (2, "async", 4.56)) - - # Both should see their own rows (different keyspaces) - sync_result = list(sync_session.execute(f"SELECT * FROM {table_name}")) - async_result = list(await async_session.execute(f"SELECT * FROM {table_name}")) - - assert len(sync_result) == 1 # Only sync's insert - assert len(async_result) == 1 # Only async's insert - assert sync_result[0].name == "sync" - assert async_result[0].name == "async" - - @pytest.mark.asyncio - async def test_batch_compatibility(self, sync_session, session_with_keyspace): - """ - Test batch operations compatibility. - - What this tests: - --------------- - 1. Batch types work same - 2. Counter batches OK - 3. Statement binding - 4. Execution results - - Why this matters: - ---------------- - Batch operations critical: - - Atomic operations - - Performance optimization - - Complex workflows - - Must work identically - to sync driver. 
- """ - async_session, keyspace = session_with_keyspace - - # Create tables in both keyspaces - table_name = f"compat_batch_{uuid.uuid4().hex[:8]}" - counter_table = f"compat_counter_{uuid.uuid4().hex[:8]}" - - # Create in sync keyspace - sync_session.execute( - f""" - CREATE TABLE {table_name} ( - id int PRIMARY KEY, - value text - ) - """ - ) - sync_session.execute( - f""" - CREATE TABLE {counter_table} ( - id text PRIMARY KEY, - count counter - ) - """ - ) - - # Create in async keyspace - await async_session.execute( - f""" - CREATE TABLE {table_name} ( - id int PRIMARY KEY, - value text - ) - """ - ) - await async_session.execute( - f""" - CREATE TABLE {counter_table} ( - id text PRIMARY KEY, - count counter - ) - """ - ) - - # Prepare statements - sync_stmt = sync_session.prepare(f"INSERT INTO {table_name} (id, value) VALUES (?, ?)") - async_stmt = await async_session.prepare( - f"INSERT INTO {table_name} (id, value) VALUES (?, ?)" - ) - - # Test logged batch - sync_batch = BatchStatement() - async_batch = BatchStatement() - - for i in range(5): - sync_batch.add(sync_stmt, (i, f"sync_{i}")) - async_batch.add(async_stmt, (i + 10, f"async_{i}")) - - sync_session.execute(sync_batch) - await async_session.execute(async_batch) - - # Test counter batch - sync_counter_stmt = sync_session.prepare( - f"UPDATE {counter_table} SET count = count + ? WHERE id = ?" - ) - async_counter_stmt = await async_session.prepare( - f"UPDATE {counter_table} SET count = count + ? WHERE id = ?" - ) - - sync_counter_batch = BatchStatement(batch_type=BatchType.COUNTER) - async_counter_batch = BatchStatement(batch_type=BatchType.COUNTER) - - sync_counter_batch.add(sync_counter_stmt, (5, "sync_counter")) - async_counter_batch.add(async_counter_stmt, (10, "async_counter")) - - sync_session.execute(sync_counter_batch) - await async_session.execute(async_counter_batch) - - # Verify - sync_batch_result = list(sync_session.execute(f"SELECT * FROM {table_name}")) - async_batch_result = list(await async_session.execute(f"SELECT * FROM {table_name}")) - - assert len(sync_batch_result) == 5 # sync batch - assert len(async_batch_result) == 5 # async batch - - sync_counter_result = list(sync_session.execute(f"SELECT * FROM {counter_table}")) - async_counter_result = list(await async_session.execute(f"SELECT * FROM {counter_table}")) - - assert len(sync_counter_result) == 1 - assert len(async_counter_result) == 1 - assert sync_counter_result[0].count == 5 - assert async_counter_result[0].count == 10 - - @pytest.mark.asyncio - async def test_row_factory_compatibility(self, sync_session, session_with_keyspace): - """ - Test row factories work the same. - - What this tests: - --------------- - 1. dict_factory works - 2. Same result format - 3. Key/value access - 4. Custom factories - - Why this matters: - ---------------- - Row factories enable: - - Custom result types - - ORM integration - - Flexible data access - - Must preserve driver's - flexibility. 
- """ - async_session, keyspace = session_with_keyspace - - table_name = f"compat_factory_{uuid.uuid4().hex[:8]}" - - # Create table in both keyspaces - sync_session.execute( - f""" - CREATE TABLE {table_name} ( - id int PRIMARY KEY, - name text, - age int - ) - """ - ) - await async_session.execute( - f""" - CREATE TABLE {table_name} ( - id int PRIMARY KEY, - name text, - age int - ) - """ - ) - - # Insert test data using prepared statements - sync_insert = sync_session.prepare( - f"INSERT INTO {table_name} (id, name, age) VALUES (?, ?, ?)" - ) - async_insert = await async_session.prepare( - f"INSERT INTO {table_name} (id, name, age) VALUES (?, ?, ?)" - ) - - sync_session.execute(sync_insert, (1, "Alice", 30)) - await async_session.execute(async_insert, (1, "Alice", 30)) - - # Set row factory to dict - sync_session.row_factory = dict_factory - async_session._session.row_factory = dict_factory - - # Query and compare - sync_result = sync_session.execute(f"SELECT * FROM {table_name}").one() - async_result = (await async_session.execute(f"SELECT * FROM {table_name}")).one() - - assert isinstance(sync_result, dict) - assert isinstance(async_result, dict) - assert sync_result == async_result - assert sync_result["name"] == "Alice" - assert async_result["age"] == 30 - - @pytest.mark.asyncio - async def test_timeout_compatibility(self, sync_session, session_with_keyspace): - """ - Test timeout behavior is similar. - - What this tests: - --------------- - 1. Timeouts respected - 2. Same timeout API - 3. No crashes - 4. Error handling - - Why this matters: - ---------------- - Timeout control critical: - - Prevent hanging - - Resource management - - User experience - - Must match sync driver - timeout behavior. - """ - async_session, keyspace = session_with_keyspace - - table_name = f"compat_timeout_{uuid.uuid4().hex[:8]}" - - # Create table in both keyspaces - sync_session.execute( - f""" - CREATE TABLE {table_name} ( - id int PRIMARY KEY, - data text - ) - """ - ) - await async_session.execute( - f""" - CREATE TABLE {table_name} ( - id int PRIMARY KEY, - data text - ) - """ - ) - - # Both should respect timeout - short_timeout = 0.001 # 1ms - should timeout - - # These might timeout or not depending on system load - # We're just checking they don't crash - try: - sync_session.execute(f"SELECT * FROM {table_name}", timeout=short_timeout) - except Exception: - pass # Timeout is expected - - try: - await async_session.execute(f"SELECT * FROM {table_name}", timeout=short_timeout) - except Exception: - pass # Timeout is expected - - @pytest.mark.asyncio - async def test_trace_compatibility(self, sync_session, session_with_keyspace): - """ - Test query tracing works the same. - - What this tests: - --------------- - 1. Tracing enabled - 2. Trace data available - 3. Same trace API - 4. Debug capability - - Why this matters: - ---------------- - Tracing essential for: - - Performance debugging - - Query optimization - - Issue diagnosis - - Must preserve debugging - capabilities. - """ - async_session, keyspace = session_with_keyspace - - table_name = f"compat_trace_{uuid.uuid4().hex[:8]}" - - # Create table in both keyspaces - sync_session.execute( - f""" - CREATE TABLE {table_name} ( - id int PRIMARY KEY, - value text - ) - """ - ) - await async_session.execute( - f""" - CREATE TABLE {table_name} ( - id int PRIMARY KEY, - value text - ) - """ - ) - - # Prepare statements - both use ? 
for prepared statements - sync_insert = sync_session.prepare(f"INSERT INTO {table_name} (id, value) VALUES (?, ?)") - async_insert = await async_session.prepare( - f"INSERT INTO {table_name} (id, value) VALUES (?, ?)" - ) - - # Execute with tracing - sync_result = sync_session.execute(sync_insert, (1, "sync_trace"), trace=True) - - async_result = await async_session.execute(async_insert, (2, "async_trace"), trace=True) - - # Both should have trace available - assert sync_result.get_query_trace() is not None - assert async_result.get_query_trace() is not None - - # Verify data - sync_count = sync_session.execute(f"SELECT COUNT(*) FROM {table_name}") - async_count = await async_session.execute(f"SELECT COUNT(*) FROM {table_name}") - assert sync_count.one()[0] == 1 - assert async_count.one()[0] == 1 - - @pytest.mark.asyncio - async def test_lwt_compatibility(self, sync_session, session_with_keyspace): - """ - Test lightweight transactions work the same. - - What this tests: - --------------- - 1. IF NOT EXISTS works - 2. Conditional updates - 3. Applied flag correct - 4. Failure handling - - Why this matters: - ---------------- - LWT critical for: - - ACID operations - - Conflict resolution - - Data consistency - - Must work identically - for correctness. - """ - async_session, keyspace = session_with_keyspace - - table_name = f"compat_lwt_{uuid.uuid4().hex[:8]}" - - # Create table in both keyspaces - sync_session.execute( - f""" - CREATE TABLE {table_name} ( - id int PRIMARY KEY, - value text, - version int - ) - """ - ) - await async_session.execute( - f""" - CREATE TABLE {table_name} ( - id int PRIMARY KEY, - value text, - version int - ) - """ - ) - - # Prepare LWT statements - both use ? for prepared statements - sync_insert_if_not_exists = sync_session.prepare( - f"INSERT INTO {table_name} (id, value, version) VALUES (?, ?, ?) IF NOT EXISTS" - ) - async_insert_if_not_exists = await async_session.prepare( - f"INSERT INTO {table_name} (id, value, version) VALUES (?, ?, ?) IF NOT EXISTS" - ) - - # Test IF NOT EXISTS - sync_result = sync_session.execute(sync_insert_if_not_exists, (1, "sync", 1)) - async_result = await async_session.execute(async_insert_if_not_exists, (2, "async", 1)) - - # Both should succeed - assert sync_result.one().applied - assert async_result.one().applied - - # Prepare conditional update statements - both use ? for prepared statements - sync_update_if = sync_session.prepare( - f"UPDATE {table_name} SET value = ?, version = ? WHERE id = ? IF version = ?" - ) - async_update_if = await async_session.prepare( - f"UPDATE {table_name} SET value = ?, version = ? WHERE id = ? IF version = ?" - ) - - # Test conditional update - sync_update = sync_session.execute(sync_update_if, ("sync_updated", 2, 1, 1)) - async_update = await async_session.execute(async_update_if, ("async_updated", 2, 2, 1)) - - assert sync_update.one().applied - assert async_update.one().applied - - # Prepare failed condition statements - both use ? for prepared statements - sync_update_fail = sync_session.prepare( - f"UPDATE {table_name} SET version = ? WHERE id = ? IF version = ?" - ) - async_update_fail = await async_session.prepare( - f"UPDATE {table_name} SET version = ? WHERE id = ? IF version = ?" 
- ) - - # Failed condition - sync_fail = sync_session.execute(sync_update_fail, (3, 1, 1)) - async_fail = await async_session.execute(async_update_fail, (3, 2, 1)) - - assert not sync_fail.one().applied - assert not async_fail.one().applied diff --git a/tests/integration/test_empty_resultsets.py b/tests/integration/test_empty_resultsets.py deleted file mode 100644 index 52ce4f7..0000000 --- a/tests/integration/test_empty_resultsets.py +++ /dev/null @@ -1,542 +0,0 @@ -""" -Integration tests for empty resultset handling. - -These tests verify that the fix for empty resultsets works correctly -with a real Cassandra instance. Empty resultsets are common for: -- Batch INSERT/UPDATE/DELETE statements -- DDL statements (CREATE, ALTER, DROP) -- Queries that match no rows -""" - -import asyncio -import uuid - -import pytest -from cassandra.query import BatchStatement, BatchType - - -@pytest.mark.integration -class TestEmptyResultsets: - """Test empty resultset handling with real Cassandra.""" - - async def _ensure_table_exists(self, session): - """Ensure test table exists.""" - await session.execute( - """ - CREATE TABLE IF NOT EXISTS test_empty_results_table ( - id UUID PRIMARY KEY, - name TEXT, - value INT - ) - """ - ) - - @pytest.mark.asyncio - async def test_batch_insert_returns_empty_result(self, cassandra_session): - """ - Test that batch INSERT statements return empty results without hanging. - - What this tests: - --------------- - 1. Batch INSERT returns empty - 2. No hanging on empty result - 3. Valid result object - 4. Empty rows collection - - Why this matters: - ---------------- - Empty results common for: - - INSERT operations - - UPDATE operations - - DELETE operations - - Must handle without blocking - the event loop. - """ - # Ensure table exists - await self._ensure_table_exists(cassandra_session) - - # Prepare the statement first - prepared = await cassandra_session.prepare( - "INSERT INTO test_empty_results_table (id, name, value) VALUES (?, ?, ?)" - ) - - batch = BatchStatement(batch_type=BatchType.LOGGED) - - # Add multiple prepared statements to batch - for i in range(10): - bound = prepared.bind((uuid.uuid4(), f"test_{i}", i)) - batch.add(bound) - - # Execute batch - should return empty result without hanging - result = await cassandra_session.execute(batch) - - # Verify result is empty but valid - assert result is not None - assert hasattr(result, "rows") - assert len(result.rows) == 0 - - @pytest.mark.asyncio - async def test_single_insert_returns_empty_result(self, cassandra_session): - """ - Test that single INSERT statements return empty results. - - What this tests: - --------------- - 1. Single INSERT empty result - 2. Result object valid - 3. Rows collection empty - 4. No exceptions thrown - - Why this matters: - ---------------- - INSERT operations: - - Don't return data - - Still need result object - - Must complete cleanly - - Foundation for all - write operations. 
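A minimal sketch of the contract this group of tests verifies: write statements complete with a valid but empty result set, using the same table the tests create.

# Sketch only: INSERT, UPDATE and DELETE all return an empty result set.
insert = await cassandra_session.prepare(
    "INSERT INTO test_empty_results_table (id, name, value) VALUES (?, ?, ?)"
)
result = await cassandra_session.execute(insert, (uuid.uuid4(), "example", 1))
assert result is not None and len(result.rows) == 0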
- """ - # Ensure table exists - await self._ensure_table_exists(cassandra_session) - - # Prepare and execute single INSERT - prepared = await cassandra_session.prepare( - "INSERT INTO test_empty_results_table (id, name, value) VALUES (?, ?, ?)" - ) - result = await cassandra_session.execute(prepared, (uuid.uuid4(), "single_insert", 42)) - - # Verify empty result - assert result is not None - assert hasattr(result, "rows") - assert len(result.rows) == 0 - - @pytest.mark.asyncio - async def test_update_no_match_returns_empty_result(self, cassandra_session): - """ - Test that UPDATE with no matching rows returns empty result. - - What this tests: - --------------- - 1. UPDATE non-existent row - 2. Empty result returned - 3. No error thrown - 4. Clean completion - - Why this matters: - ---------------- - UPDATE operations: - - May match no rows - - Still succeed - - Return empty result - - Common in conditional - update patterns. - """ - # Ensure table exists - await self._ensure_table_exists(cassandra_session) - - # Prepare and update non-existent row - prepared = await cassandra_session.prepare( - "UPDATE test_empty_results_table SET value = ? WHERE id = ?" - ) - result = await cassandra_session.execute( - prepared, (100, uuid.uuid4()) # Random UUID won't match any row - ) - - # Verify empty result - assert result is not None - assert hasattr(result, "rows") - assert len(result.rows) == 0 - - @pytest.mark.asyncio - async def test_delete_no_match_returns_empty_result(self, cassandra_session): - """ - Test that DELETE with no matching rows returns empty result. - - What this tests: - --------------- - 1. DELETE non-existent row - 2. Empty result returned - 3. No error thrown - 4. Operation completes - - Why this matters: - ---------------- - DELETE operations: - - Idempotent by design - - No error if not found - - Empty result normal - - Enables safe cleanup - operations. - """ - # Ensure table exists - await self._ensure_table_exists(cassandra_session) - - # Prepare and delete non-existent row - prepared = await cassandra_session.prepare( - "DELETE FROM test_empty_results_table WHERE id = ?" - ) - result = await cassandra_session.execute( - prepared, (uuid.uuid4(),) - ) # Random UUID won't match any row - - # Verify empty result - assert result is not None - assert hasattr(result, "rows") - assert len(result.rows) == 0 - - @pytest.mark.asyncio - async def test_select_no_match_returns_empty_result(self, cassandra_session): - """ - Test that SELECT with no matching rows returns empty result. - - What this tests: - --------------- - 1. SELECT finds no rows - 2. Empty result valid - 3. Can iterate empty - 4. No exceptions - - Why this matters: - ---------------- - Empty SELECT results: - - Very common case - - Must handle gracefully - - No special casing - - Simplifies application - error handling. - """ - # Ensure table exists - await self._ensure_table_exists(cassandra_session) - - # Prepare and select non-existent row - prepared = await cassandra_session.prepare( - "SELECT * FROM test_empty_results_table WHERE id = ?" - ) - result = await cassandra_session.execute( - prepared, (uuid.uuid4(),) - ) # Random UUID won't match any row - - # Verify empty result - assert result is not None - assert hasattr(result, "rows") - assert len(result.rows) == 0 - - @pytest.mark.asyncio - async def test_ddl_statements_return_empty_results(self, cassandra_session): - """ - Test that DDL statements return empty results. - - What this tests: - --------------- - 1. CREATE TABLE empty result - 2. 
ALTER TABLE empty result - 3. DROP TABLE empty result - 4. All DDL operations - - Why this matters: - ---------------- - DDL operations: - - Schema changes only - - No data returned - - Must complete cleanly - - Essential for schema - management code. - """ - # Create table - result = await cassandra_session.execute( - """ - CREATE TABLE IF NOT EXISTS ddl_test ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - assert result is not None - assert hasattr(result, "rows") - assert len(result.rows) == 0 - - # Alter table - result = await cassandra_session.execute("ALTER TABLE ddl_test ADD new_column INT") - - assert result is not None - assert hasattr(result, "rows") - assert len(result.rows) == 0 - - # Drop table - result = await cassandra_session.execute("DROP TABLE IF EXISTS ddl_test") - - assert result is not None - assert hasattr(result, "rows") - assert len(result.rows) == 0 - - @pytest.mark.asyncio - async def test_concurrent_empty_results(self, cassandra_session): - """ - Test handling multiple concurrent queries returning empty results. - - What this tests: - --------------- - 1. Concurrent empty results - 2. No blocking or hanging - 3. All queries complete - 4. Mixed operation types - - Why this matters: - ---------------- - High concurrency scenarios: - - Many empty results - - Must not deadlock - - Event loop health - - Verifies async handling - under load. - """ - # Ensure table exists - await self._ensure_table_exists(cassandra_session) - - # Prepare statements for concurrent execution - insert_prepared = await cassandra_session.prepare( - "INSERT INTO test_empty_results_table (id, name, value) VALUES (?, ?, ?)" - ) - update_prepared = await cassandra_session.prepare( - "UPDATE test_empty_results_table SET value = ? WHERE id = ?" - ) - delete_prepared = await cassandra_session.prepare( - "DELETE FROM test_empty_results_table WHERE id = ?" - ) - select_prepared = await cassandra_session.prepare( - "SELECT * FROM test_empty_results_table WHERE id = ?" - ) - - # Create multiple concurrent queries that return empty results - tasks = [] - - # Mix of different empty-result queries - for i in range(20): - if i % 4 == 0: - # INSERT - task = cassandra_session.execute( - insert_prepared, (uuid.uuid4(), f"concurrent_{i}", i) - ) - elif i % 4 == 1: - # UPDATE non-existent - task = cassandra_session.execute(update_prepared, (i, uuid.uuid4())) - elif i % 4 == 2: - # DELETE non-existent - task = cassandra_session.execute(delete_prepared, (uuid.uuid4(),)) - else: - # SELECT non-existent - task = cassandra_session.execute(select_prepared, (uuid.uuid4(),)) - - tasks.append(task) - - # Execute all concurrently - results = await asyncio.gather(*tasks) - - # All should complete without hanging - assert len(results) == 20 - - # All should be valid empty results - for result in results: - assert result is not None - assert hasattr(result, "rows") - assert len(result.rows) == 0 - - @pytest.mark.asyncio - async def test_prepared_statement_empty_results(self, cassandra_session): - """ - Test that prepared statements handle empty results correctly. - - What this tests: - --------------- - 1. Prepared INSERT empty - 2. Prepared SELECT empty - 3. Same as simple statements - 4. No special handling - - Why this matters: - ---------------- - Prepared statements: - - Most common pattern - - Must handle empty - - Consistent behavior - - Core functionality for - production apps. 
- """ - # Ensure table exists - await self._ensure_table_exists(cassandra_session) - - # Prepare statements - insert_prepared = await cassandra_session.prepare( - "INSERT INTO test_empty_results_table (id, name, value) VALUES (?, ?, ?)" - ) - - select_prepared = await cassandra_session.prepare( - "SELECT * FROM test_empty_results_table WHERE id = ?" - ) - - # Execute prepared INSERT - result = await cassandra_session.execute(insert_prepared, (uuid.uuid4(), "prepared", 123)) - assert result is not None - assert len(result.rows) == 0 - - # Execute prepared SELECT with no match - result = await cassandra_session.execute(select_prepared, (uuid.uuid4(),)) - assert result is not None - assert len(result.rows) == 0 - - @pytest.mark.asyncio - async def test_batch_mixed_statements_empty_result(self, cassandra_session): - """ - Test batch with mixed statement types returns empty result. - - What this tests: - --------------- - 1. Mixed batch operations - 2. INSERT/UPDATE/DELETE mix - 3. All return empty - 4. Batch completes clean - - Why this matters: - ---------------- - Complex batches: - - Multiple operations - - All write operations - - Single empty result - - Common pattern for - transactional writes. - """ - # Ensure table exists - await self._ensure_table_exists(cassandra_session) - - # Prepare statements for batch - insert_prepared = await cassandra_session.prepare( - "INSERT INTO test_empty_results_table (id, name, value) VALUES (?, ?, ?)" - ) - update_prepared = await cassandra_session.prepare( - "UPDATE test_empty_results_table SET value = ? WHERE id = ?" - ) - delete_prepared = await cassandra_session.prepare( - "DELETE FROM test_empty_results_table WHERE id = ?" - ) - - batch = BatchStatement(batch_type=BatchType.UNLOGGED) - - # Mix different types of prepared statements - batch.add(insert_prepared.bind((uuid.uuid4(), "batch_insert", 1))) - batch.add(update_prepared.bind((2, uuid.uuid4()))) # Won't match - batch.add(delete_prepared.bind((uuid.uuid4(),))) # Won't match - - # Execute batch - result = await cassandra_session.execute(batch) - - # Should return empty result - assert result is not None - assert hasattr(result, "rows") - assert len(result.rows) == 0 - - @pytest.mark.asyncio - async def test_streaming_empty_results(self, cassandra_session): - """ - Test that streaming queries handle empty results correctly. - - What this tests: - --------------- - 1. Streaming with no data - 2. Iterator completes - 3. No hanging - 4. Context manager works - - Why this matters: - ---------------- - Streaming edge case: - - Must handle empty - - Clean iterator exit - - Resource cleanup - - Prevents infinite loops - and resource leaks. - """ - # Ensure table exists - await self._ensure_table_exists(cassandra_session) - - # Configure streaming - from async_cassandra.streaming import StreamConfig - - config = StreamConfig(fetch_size=10, max_pages=5) - - # Prepare statement for streaming - select_prepared = await cassandra_session.prepare( - "SELECT * FROM test_empty_results_table WHERE id = ?" 
- ) - - # Stream query with no results - async with await cassandra_session.execute_stream( - select_prepared, - (uuid.uuid4(),), # Won't match any row - stream_config=config, - ) as streaming_result: - # Collect all results - all_rows = [] - async for row in streaming_result: - all_rows.append(row) - - # Should complete without hanging and return no rows - assert len(all_rows) == 0 - - @pytest.mark.asyncio - async def test_truncate_returns_empty_result(self, cassandra_session): - """ - Test that TRUNCATE returns empty result. - - What this tests: - --------------- - 1. TRUNCATE operation - 2. DDL empty result - 3. Table cleared - 4. No data returned - - Why this matters: - ---------------- - TRUNCATE operations: - - Clear all data - - DDL operation - - Empty result expected - - Common maintenance - operation pattern. - """ - # Ensure table exists - await self._ensure_table_exists(cassandra_session) - - # Prepare insert statement - insert_prepared = await cassandra_session.prepare( - "INSERT INTO test_empty_results_table (id, name, value) VALUES (?, ?, ?)" - ) - - # Insert some data first - for i in range(5): - await cassandra_session.execute( - insert_prepared, (uuid.uuid4(), f"truncate_test_{i}", i) - ) - - # Truncate table (DDL operation - no parameters) - result = await cassandra_session.execute("TRUNCATE test_empty_results_table") - - # Should return empty result - assert result is not None - assert hasattr(result, "rows") - assert len(result.rows) == 0 - - # The main purpose of this test is to verify TRUNCATE returns empty result - # The SELECT COUNT verification is having issues in the test environment - # but the critical part (TRUNCATE returning empty result) is verified above diff --git a/tests/integration/test_error_propagation.py b/tests/integration/test_error_propagation.py deleted file mode 100644 index 3298d94..0000000 --- a/tests/integration/test_error_propagation.py +++ /dev/null @@ -1,943 +0,0 @@ -""" -Integration tests for error propagation from the Cassandra driver. - -Tests various error conditions that can occur during normal operations -to ensure the async wrapper properly propagates all error types from -the underlying driver to the application layer. -""" - -import asyncio -import uuid - -import pytest -from cassandra import AlreadyExists, ConfigurationException, InvalidRequest -from cassandra.protocol import SyntaxException -from cassandra.query import SimpleStatement - -from async_cassandra.exceptions import QueryError - - -class TestErrorPropagation: - """Test that various Cassandra errors are properly propagated through the async wrapper.""" - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_invalid_query_syntax_error(self, cassandra_cluster): - """ - Test that invalid query syntax errors are propagated. - - What this tests: - --------------- - 1. Syntax errors caught - 2. InvalidRequest raised - 3. Error message preserved - 4. Stack trace intact - - Why this matters: - ---------------- - Development debugging needs: - - Clear error messages - - Exact error types - - Full stack traces - - Bad queries must fail - with helpful errors. 
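        Example (illustrative sketch):
        ------------------------------
        Not run by this test -- how application code might surface a malformed
        query through the async wrapper; assumes a node on localhost:

            import asyncio

            from async_cassandra import AsyncCluster
            from async_cassandra.exceptions import QueryError
            from cassandra.protocol import SyntaxException

            async def main():
                async with AsyncCluster(["localhost"]) as cluster:
                    async with await cluster.connect() as session:
                        try:
                            await session.execute("SELCT * FROM system.local")  # typo
                        except (SyntaxException, QueryError) as exc:
                            print(f"query rejected: {exc!r}")

            asyncio.run(main())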
- """ - session = await cassandra_cluster.connect() - - # Various syntax errors - invalid_queries = [ - "SELECT * FROM", # Incomplete query - "SELCT * FROM system.local", # Typo in SELECT - "SELECT * FROM system.local WHERE", # Incomplete WHERE - "INSERT INTO test_table", # Incomplete INSERT - "CREATE TABLE", # Incomplete CREATE - ] - - for query in invalid_queries: - # The driver raises SyntaxException for syntax errors, not InvalidRequest - # We might get either SyntaxException directly or QueryError wrapping it - with pytest.raises((SyntaxException, QueryError)) as exc_info: - await session.execute(query) - - # Verify error details are preserved - assert str(exc_info.value) # Has error message - - # If it's wrapped in QueryError, check the cause - if isinstance(exc_info.value, QueryError): - assert isinstance(exc_info.value.__cause__, SyntaxException) - - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_table_not_found_error(self, cassandra_cluster): - """ - Test that table not found errors are propagated. - - What this tests: - --------------- - 1. Missing table error - 2. InvalidRequest raised - 3. Table name in error - 4. Keyspace context - - Why this matters: - ---------------- - Common development error: - - Typos in table names - - Wrong keyspace - - Missing migrations - - Clear errors speed up - debugging significantly. - """ - session = await cassandra_cluster.connect() - - # Create a test keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_errors - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("test_errors") - - # Try to query non-existent table - # This should raise InvalidRequest or be wrapped in QueryError - with pytest.raises((InvalidRequest, QueryError)) as exc_info: - await session.execute("SELECT * FROM non_existent_table") - - # Error should mention the table - error_msg = str(exc_info.value).lower() - assert "non_existent_table" in error_msg or "table" in error_msg - - # If wrapped, check the cause - if isinstance(exc_info.value, QueryError): - assert exc_info.value.__cause__ is not None - - # Cleanup - await session.execute("DROP KEYSPACE IF EXISTS test_errors") - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_prepared_statement_invalidation_error(self, cassandra_cluster): - """ - Test errors when prepared statements become invalid. - - What this tests: - --------------- - 1. Table drop invalidates - 2. Prepare after drop - 3. Schema changes handled - 4. Error recovery - - Why this matters: - ---------------- - Schema evolution common: - - Table modifications - - Column changes - - Migration scripts - - Apps must handle schema - changes gracefully. 
- """ - session = await cassandra_cluster.connect() - - # Create test keyspace and table - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_prepare_errors - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("test_prepare_errors") - - await session.execute( - """ - CREATE TABLE IF NOT EXISTS prepare_test ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - # Prepare a statement - prepared = await session.prepare("SELECT * FROM prepare_test WHERE id = ?") - - # Insert some data and verify prepared statement works - test_id = uuid.uuid4() - await session.execute( - "INSERT INTO prepare_test (id, data) VALUES (%s, %s)", [test_id, "test data"] - ) - result = await session.execute(prepared, [test_id]) - assert result.one() is not None - - # Drop and recreate table with different schema - await session.execute("DROP TABLE prepare_test") - await session.execute( - """ - CREATE TABLE prepare_test ( - id UUID PRIMARY KEY, - data TEXT, - new_column INT -- Schema changed - ) - """ - ) - - # The prepared statement should still work (driver handles re-preparation) - # but let's also test preparing a statement for a dropped table - await session.execute("DROP TABLE prepare_test") - - # Trying to prepare for non-existent table should fail - # This might raise InvalidRequest or be wrapped in QueryError - with pytest.raises((InvalidRequest, QueryError)) as exc_info: - await session.prepare("SELECT * FROM prepare_test WHERE id = ?") - - error_msg = str(exc_info.value).lower() - assert "prepare_test" in error_msg or "table" in error_msg - - # If wrapped, check the cause - if isinstance(exc_info.value, QueryError): - assert exc_info.value.__cause__ is not None - - # Cleanup - await session.execute("DROP KEYSPACE IF EXISTS test_prepare_errors") - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_prepared_statement_column_drop_error(self, cassandra_cluster): - """ - Test what happens when a column referenced by a prepared statement is dropped. - - What this tests: - --------------- - 1. Prepare with column reference - 2. Drop the column - 3. Reuse prepared statement - 4. Error propagation - - Why this matters: - ---------------- - Column drops happen during: - - Schema refactoring - - Deprecating features - - Data model changes - - Prepared statements must - handle column removal. - """ - session = await cassandra_cluster.connect() - - # Create test keyspace and table - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_column_drop - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("test_column_drop") - - await session.execute( - """ - CREATE TABLE IF NOT EXISTS column_test ( - id UUID PRIMARY KEY, - name TEXT, - email TEXT, - age INT - ) - """ - ) - - # Prepare statements that reference specific columns - select_with_email = await session.prepare( - "SELECT id, name, email FROM column_test WHERE id = ?" - ) - insert_with_email = await session.prepare( - "INSERT INTO column_test (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - update_email = await session.prepare("UPDATE column_test SET email = ? 
WHERE id = ?") - - # Insert test data and verify statements work - test_id = uuid.uuid4() - await session.execute(insert_with_email, [test_id, "Test User", "test@example.com", 25]) - - result = await session.execute(select_with_email, [test_id]) - row = result.one() - assert row.email == "test@example.com" - - # Now drop the email column - await session.execute("ALTER TABLE column_test DROP email") - - # Try to use the prepared statements that reference the dropped column - - # SELECT with dropped column should fail - with pytest.raises(InvalidRequest) as exc_info: - await session.execute(select_with_email, [test_id]) - error_msg = str(exc_info.value).lower() - assert "email" in error_msg or "column" in error_msg or "undefined" in error_msg - - # INSERT with dropped column should fail - with pytest.raises(InvalidRequest) as exc_info: - await session.execute( - insert_with_email, [uuid.uuid4(), "Another User", "another@example.com", 30] - ) - error_msg = str(exc_info.value).lower() - assert "email" in error_msg or "column" in error_msg or "undefined" in error_msg - - # UPDATE of dropped column should fail - with pytest.raises(InvalidRequest) as exc_info: - await session.execute(update_email, ["new@example.com", test_id]) - error_msg = str(exc_info.value).lower() - assert "email" in error_msg or "column" in error_msg or "undefined" in error_msg - - # Verify that statements without the dropped column still work - select_without_email = await session.prepare( - "SELECT id, name, age FROM column_test WHERE id = ?" - ) - result = await session.execute(select_without_email, [test_id]) - row = result.one() - assert row.name == "Test User" - assert row.age == 25 - - # Cleanup - await session.execute("DROP TABLE IF EXISTS column_test") - await session.execute("DROP KEYSPACE IF EXISTS test_column_drop") - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_keyspace_not_found_error(self, cassandra_cluster): - """ - Test that keyspace not found errors are propagated. - - What this tests: - --------------- - 1. Missing keyspace error - 2. Clear error message - 3. Keyspace name shown - 4. Connection still valid - - Why this matters: - ---------------- - Keyspace errors indicate: - - Wrong environment - - Missing setup - - Config issues - - Must fail clearly to - prevent data loss. - """ - session = await cassandra_cluster.connect() - - # Try to use non-existent keyspace - with pytest.raises(InvalidRequest) as exc_info: - await session.execute("USE non_existent_keyspace") - - error_msg = str(exc_info.value) - assert "non_existent_keyspace" in error_msg or "keyspace" in error_msg.lower() - - # Session should still be usable - result = await session.execute("SELECT now() FROM system.local") - assert result.one() is not None - - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_type_mismatch_errors(self, cassandra_cluster): - """ - Test that type mismatch errors are propagated. - - What this tests: - --------------- - 1. Type validation works - 2. InvalidRequest raised - 3. Column info in error - 4. Type details shown - - Why this matters: - ---------------- - Type safety critical: - - Data integrity - - Bug prevention - - Clear debugging - - Type errors must be - caught and reported. 
- """ - session = await cassandra_cluster.connect() - - # Create test table - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_type_errors - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("test_type_errors") - - await session.execute( - """ - CREATE TABLE IF NOT EXISTS type_test ( - id UUID PRIMARY KEY, - count INT, - active BOOLEAN, - created TIMESTAMP - ) - """ - ) - - # Prepare insert statement - insert_stmt = await session.prepare( - "INSERT INTO type_test (id, count, active, created) VALUES (?, ?, ?, ?)" - ) - - # Try various type mismatches - test_cases = [ - # (values, expected_error_contains) - ([uuid.uuid4(), "not_a_number", True, "2023-01-01"], ["count", "int"]), - ([uuid.uuid4(), 42, "not_a_boolean", "2023-01-01"], ["active", "boolean"]), - (["not_a_uuid", 42, True, "2023-01-01"], ["id", "uuid"]), - ] - - for values, error_keywords in test_cases: - with pytest.raises(Exception) as exc_info: # Could be InvalidRequest or TypeError - await session.execute(insert_stmt, values) - - error_msg = str(exc_info.value).lower() - # Check that at least one expected keyword is in the error - assert any( - keyword.lower() in error_msg for keyword in error_keywords - ), f"Expected keywords {error_keywords} not found in error: {error_msg}" - - # Cleanup - await session.execute("DROP TABLE IF EXISTS type_test") - await session.execute("DROP KEYSPACE IF EXISTS test_type_errors") - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_timeout_errors(self, cassandra_cluster): - """ - Test that timeout errors are properly propagated. - - What this tests: - --------------- - 1. Query timeouts work - 2. Timeout value respected - 3. Error type correct - 4. Session recovers - - Why this matters: - ---------------- - Timeout handling critical: - - Prevent hanging - - Resource cleanup - - User experience - - Timeouts must fail fast - and recover cleanly. 
- """ - session = await cassandra_cluster.connect() - - # Create a test table with data - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_timeout_errors - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("test_timeout_errors") - - await session.execute( - """ - CREATE TABLE IF NOT EXISTS timeout_test ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - # Insert some data - for i in range(100): - await session.execute( - "INSERT INTO timeout_test (id, data) VALUES (%s, %s)", - [uuid.uuid4(), f"data_{i}" * 100], # Make data reasonably large - ) - - # Create a simple query - stmt = SimpleStatement("SELECT * FROM timeout_test") - - # Execute with very short timeout - # Note: This might not always timeout in fast local environments - try: - result = await session.execute(stmt, timeout=0.001) # 1ms timeout - very aggressive - # If it succeeds, that's fine - timeout is environment dependent - rows = list(result) - assert len(rows) > 0 - except Exception as e: - # If it times out, verify we get a timeout-related error - # TimeoutError might have empty string representation, check type name too - error_msg = str(e).lower() - error_type = type(e).__name__.lower() - assert ( - "timeout" in error_msg - or "timeout" in error_type - or isinstance(e, asyncio.TimeoutError) - ) - - # Session should still be usable after timeout - result = await session.execute("SELECT count(*) FROM timeout_test") - assert result.one().count >= 0 - - # Cleanup - await session.execute("DROP TABLE IF EXISTS timeout_test") - await session.execute("DROP KEYSPACE IF EXISTS test_timeout_errors") - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_batch_size_limit_error(self, cassandra_cluster): - """ - Test that batch size limit errors are propagated. - - What this tests: - --------------- - 1. Batch size limits - 2. Error on too large - 3. Clear error message - 4. Batch still usable - - Why this matters: - ---------------- - Batch limits prevent: - - Memory issues - - Performance problems - - Cluster instability - - Apps must respect - batch size limits. 
- """ - from cassandra.query import BatchStatement - - session = await cassandra_cluster.connect() - - # Create test table - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_batch_errors - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("test_batch_errors") - - await session.execute( - """ - CREATE TABLE IF NOT EXISTS batch_test ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - # Prepare insert statement - insert_stmt = await session.prepare("INSERT INTO batch_test (id, data) VALUES (?, ?)") - - # Try to create a very large batch - # Default batch size warning is at 5KB, error at 50KB - batch = BatchStatement() - large_data = "x" * 1000 # 1KB per row - - # Add many statements to exceed size limit - for i in range(100): # This should exceed typical batch size limits - batch.add(insert_stmt, [uuid.uuid4(), large_data]) - - # This might warn or error depending on server config - try: - await session.execute(batch) - # If it succeeds, server has high limits - that's OK - except Exception as e: - # If it fails, should mention batch size - error_msg = str(e).lower() - assert "batch" in error_msg or "size" in error_msg or "limit" in error_msg - - # Smaller batch should work fine - small_batch = BatchStatement() - for i in range(5): - small_batch.add(insert_stmt, [uuid.uuid4(), "small data"]) - - await session.execute(small_batch) # Should succeed - - # Cleanup - await session.execute("DROP TABLE IF EXISTS batch_test") - await session.execute("DROP KEYSPACE IF EXISTS test_batch_errors") - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_concurrent_schema_modification_errors(self, cassandra_cluster): - """ - Test errors from concurrent schema modifications. - - What this tests: - --------------- - 1. Schema conflicts - 2. AlreadyExists errors - 3. Concurrent DDL - 4. Error recovery - - Why this matters: - ---------------- - Multiple apps/devs may: - - Run migrations - - Modify schema - - Create tables - - Must handle conflicts - gracefully. 
- """ - session = await cassandra_cluster.connect() - - # Create test keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_schema_errors - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("test_schema_errors") - - # Create a table - await session.execute( - """ - CREATE TABLE IF NOT EXISTS schema_test ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - # Try to create the same table again (without IF NOT EXISTS) - # This might raise AlreadyExists or be wrapped in QueryError - with pytest.raises((AlreadyExists, QueryError)) as exc_info: - await session.execute( - """ - CREATE TABLE schema_test ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - error_msg = str(exc_info.value).lower() - assert "schema_test" in error_msg or "already exists" in error_msg - - # If wrapped, check the cause - if isinstance(exc_info.value, QueryError): - assert exc_info.value.__cause__ is not None - - # Try to create duplicate index - await session.execute("CREATE INDEX IF NOT EXISTS idx_data ON schema_test (data)") - - # This might raise InvalidRequest or be wrapped in QueryError - with pytest.raises((InvalidRequest, QueryError)) as exc_info: - await session.execute("CREATE INDEX idx_data ON schema_test (data)") - - error_msg = str(exc_info.value).lower() - assert "index" in error_msg or "already exists" in error_msg - - # If wrapped, check the cause - if isinstance(exc_info.value, QueryError): - assert exc_info.value.__cause__ is not None - - # Simulate concurrent modifications by trying operations that might conflict - async def create_column(col_name): - try: - await session.execute(f"ALTER TABLE schema_test ADD {col_name} TEXT") - return True - except (InvalidRequest, ConfigurationException): - return False - - # Try to add same column concurrently (one should fail) - results = await asyncio.gather( - create_column("new_col"), create_column("new_col"), return_exceptions=True - ) - - # At least one should succeed, at least one should fail - successes = sum(1 for r in results if r is True) - failures = sum(1 for r in results if r is False or isinstance(r, Exception)) - assert successes >= 1 # At least one succeeded - assert failures >= 0 # Some might fail due to concurrent modification - - # Cleanup - await session.execute("DROP TABLE IF EXISTS schema_test") - await session.execute("DROP KEYSPACE IF EXISTS test_schema_errors") - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_consistency_level_errors(self, cassandra_cluster): - """ - Test that consistency level errors are propagated. - - What this tests: - --------------- - 1. Consistency failures - 2. Unavailable errors - 3. Error details preserved - 4. Session recovery - - Why this matters: - ---------------- - Consistency errors show: - - Cluster health issues - - Replication problems - - Config mismatches - - Critical for distributed - system debugging. 
- """ - from cassandra import ConsistencyLevel - from cassandra.query import SimpleStatement - - session = await cassandra_cluster.connect() - - # Create test keyspace with RF=1 - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_consistency_errors - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("test_consistency_errors") - - await session.execute( - """ - CREATE TABLE IF NOT EXISTS consistency_test ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - # Insert some data - test_id = uuid.uuid4() - await session.execute( - "INSERT INTO consistency_test (id, data) VALUES (%s, %s)", [test_id, "test data"] - ) - - # In a single-node setup, we can't truly test consistency failures - # but we can verify that consistency levels are accepted - - # These should work with single node - for cl in [ConsistencyLevel.ONE, ConsistencyLevel.LOCAL_ONE]: - stmt = SimpleStatement( - "SELECT * FROM consistency_test WHERE id = %s", consistency_level=cl - ) - result = await session.execute(stmt, [test_id]) - assert result.one() is not None - - # Note: In production, requesting ALL or QUORUM with RF=1 on multi-node - # cluster could fail. Here we just verify the statement executes. - stmt = SimpleStatement( - "SELECT * FROM consistency_test", consistency_level=ConsistencyLevel.ALL - ) - result = await session.execute(stmt) - # Should work on single node even with CL=ALL - - # Cleanup - await session.execute("DROP TABLE IF EXISTS consistency_test") - await session.execute("DROP KEYSPACE IF EXISTS test_consistency_errors") - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_function_and_aggregate_errors(self, cassandra_cluster): - """ - Test errors related to functions and aggregates. - - What this tests: - --------------- - 1. Invalid function calls - 2. Missing functions - 3. Wrong arguments - 4. Clear error messages - - Why this matters: - ---------------- - Function errors common: - - Wrong function names - - Incorrect arguments - - Type mismatches - - Need clear error messages - for debugging. - """ - session = await cassandra_cluster.connect() - - # Test invalid function calls - with pytest.raises(InvalidRequest) as exc_info: - await session.execute("SELECT non_existent_function(now()) FROM system.local") - - error_msg = str(exc_info.value).lower() - assert "function" in error_msg or "unknown" in error_msg - - # Test wrong number of arguments to built-in function - with pytest.raises(InvalidRequest) as exc_info: - await session.execute("SELECT toTimestamp() FROM system.local") # Missing argument - - # Test invalid aggregate usage - with pytest.raises(InvalidRequest) as exc_info: - await session.execute("SELECT sum(release_version) FROM system.local") # Can't sum text - - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_large_query_handling(self, cassandra_cluster): - """ - Test handling of large queries and data. - - What this tests: - --------------- - 1. Large INSERT data - 2. Large SELECT results - 3. Protocol limits - 4. Memory handling - - Why this matters: - ---------------- - Large data scenarios: - - Bulk imports - - Document storage - - Media metadata - - Must handle large payloads - without protocol errors. 
- """ - session = await cassandra_cluster.connect() - - # Create test keyspace and table - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_large_data - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("test_large_data") - - await session.execute( - """ - CREATE TABLE IF NOT EXISTS large_data_test ( - id UUID PRIMARY KEY, - small_text TEXT, - large_text TEXT, - binary_data BLOB - ) - """ - ) - - # Test 1: Large text data (just under common limits) - test_id = uuid.uuid4() - # Create 1MB of text data (well within Cassandra's default frame size) - large_text = "x" * (1024 * 1024) # 1MB - - # This should succeed - insert_stmt = await session.prepare( - "INSERT INTO large_data_test (id, small_text, large_text) VALUES (?, ?, ?)" - ) - await session.execute(insert_stmt, [test_id, "small", large_text]) - - # Verify we can read it back - select_stmt = await session.prepare("SELECT * FROM large_data_test WHERE id = ?") - result = await session.execute(select_stmt, [test_id]) - row = result.one() - assert row is not None - assert len(row.large_text) == len(large_text) - assert row.large_text == large_text - - # Test 2: Binary data - import os - - test_id2 = uuid.uuid4() - # Create 512KB of random binary data - binary_data = os.urandom(512 * 1024) # 512KB - - insert_binary_stmt = await session.prepare( - "INSERT INTO large_data_test (id, small_text, binary_data) VALUES (?, ?, ?)" - ) - await session.execute(insert_binary_stmt, [test_id2, "binary test", binary_data]) - - # Read it back - result = await session.execute(select_stmt, [test_id2]) - row = result.one() - assert row is not None - assert len(row.binary_data) == len(binary_data) - assert row.binary_data == binary_data - - # Test 3: Multiple large rows in one query - # Insert several rows with moderately large data - insert_many_stmt = await session.prepare( - "INSERT INTO large_data_test (id, small_text, large_text) VALUES (?, ?, ?)" - ) - - row_ids = [] - medium_text = "y" * (100 * 1024) # 100KB per row - for i in range(10): - row_id = uuid.uuid4() - row_ids.append(row_id) - await session.execute(insert_many_stmt, [row_id, f"row_{i}", medium_text]) - - # Select all of them at once - # For simple statements, use %s placeholders - placeholders = ",".join(["%s"] * len(row_ids)) - select_many = f"SELECT * FROM large_data_test WHERE id IN ({placeholders})" - result = await session.execute(select_many, row_ids) - rows = list(result) - assert len(rows) == 10 - for row in rows: - assert len(row.large_text) == len(medium_text) - - # Test 4: Very large data that might exceed limits - # Default native protocol frame size is often 256MB, but message size limits are lower - # Try something that's large but should still work - test_id3 = uuid.uuid4() - very_large_text = "z" * (10 * 1024 * 1024) # 10MB - - try: - await session.execute(insert_stmt, [test_id3, "very large", very_large_text]) - # If it succeeds, verify we can read it - result = await session.execute(select_stmt, [test_id3]) - row = result.one() - assert row is not None - assert len(row.large_text) == len(very_large_text) - except Exception as e: - # If it fails due to size limits, that's expected - error_msg = str(e).lower() - assert any(word in error_msg for word in ["size", "large", "limit", "frame", "big"]) - - # Test 5: Large batch with multiple large values - from cassandra.query import BatchStatement - - batch = BatchStatement() - batch_text = "b" * (50 * 1024) # 50KB per row - - # Add 20 statements to the 
batch (total ~1MB) - for i in range(20): - batch.add(insert_stmt, [uuid.uuid4(), f"batch_{i}", batch_text]) - - try: - await session.execute(batch) - # Success means the batch was within limits - except Exception as e: - # Large batches might be rejected - error_msg = str(e).lower() - assert any(word in error_msg for word in ["batch", "size", "large", "limit"]) - - # Cleanup - await session.execute("DROP TABLE IF EXISTS large_data_test") - await session.execute("DROP KEYSPACE IF EXISTS test_large_data") - await session.close() diff --git a/tests/integration/test_example_scripts.py b/tests/integration/test_example_scripts.py deleted file mode 100644 index 7ed2629..0000000 --- a/tests/integration/test_example_scripts.py +++ /dev/null @@ -1,783 +0,0 @@ -""" -Integration tests for example scripts. - -This module tests that all example scripts in the examples/ directory -work correctly and follow the proper API usage patterns. - -What this tests: ---------------- -1. All example scripts execute without errors -2. Examples use context managers properly -3. Examples use prepared statements where appropriate -4. Examples clean up resources correctly -5. Examples demonstrate best practices - -Why this matters: ----------------- -- Examples are often the first code users see -- Broken examples damage library credibility -- Examples should showcase best practices -- Users copy example code into production - -Additional context: ---------------------------------- -- Tests run each example in isolation -- Cassandra container is shared between tests -- Each example creates and drops its own keyspace -- Tests verify output and side effects -""" - -import asyncio -import os -import shutil -import subprocess -import sys -from pathlib import Path - -import pytest - -from async_cassandra import AsyncCluster - -# Path to examples directory -EXAMPLES_DIR = Path(__file__).parent.parent.parent / "examples" - - -class TestExampleScripts: - """Test all example scripts work correctly.""" - - @pytest.fixture(autouse=True) - async def setup_cassandra(self, cassandra_cluster): - """Ensure Cassandra is available for examples.""" - # Cassandra is guaranteed to be available via cassandra_cluster fixture - pass - - @pytest.mark.timeout(180) # Override default timeout for this test - async def test_streaming_basic_example(self, cassandra_cluster): - """ - Test the basic streaming example. - - What this tests: - --------------- - 1. Script executes without errors - 2. Creates and populates test data - 3. Demonstrates streaming with context manager - 4. Shows filtered streaming with prepared statements - 5. 
Cleans up keyspace after completion - - Why this matters: - ---------------- - - Streaming is critical for large datasets - - Context managers prevent memory leaks - - Users need clear streaming examples - - Common use case for analytics - """ - script_path = EXAMPLES_DIR / "streaming_basic.py" - assert script_path.exists(), f"Example script not found: {script_path}" - - # Run the example script - result = subprocess.run( - [sys.executable, str(script_path)], - capture_output=True, - text=True, - timeout=120, # Allow time for 100k events generation - ) - - # Check execution succeeded - if result.returncode != 0: - print(f"STDOUT:\n{result.stdout}") - print(f"STDERR:\n{result.stderr}") - assert result.returncode == 0, f"Script failed with return code {result.returncode}" - - # Verify expected output patterns - # The examples use logging which outputs to stderr - output = result.stderr if result.stderr else result.stdout - assert "Basic Streaming Example" in output - assert "Inserted 100000 test events" in output or "Inserted 100,000 test events" in output - assert "Streaming completed:" in output - assert "Total events: 100,000" in output or "Total events: 100000" in output - assert "Filtered Streaming Example" in output - assert "Page-Based Streaming Example (True Async Paging)" in output - assert "Pages are fetched asynchronously" in output - - # Verify keyspace was cleaned up - async with AsyncCluster(["localhost"]) as cluster: - async with await cluster.connect() as session: - result = await session.execute( - "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = 'streaming_example'" - ) - assert result.one() is None, "Keyspace was not cleaned up" - - async def test_export_large_table_example(self, cassandra_cluster, tmp_path): - """ - Test the table export example. - - What this tests: - --------------- - 1. Creates sample data correctly - 2. Exports data to CSV format - 3. Handles different data types properly - 4. Shows progress during export - 5. Cleans up resources - 6. 
Validates output file content - - Why this matters: - ---------------- - - Data export is common requirement - - CSV format widely used - - Memory efficiency critical for large tables - - Progress tracking improves UX - """ - script_path = EXAMPLES_DIR / "export_large_table.py" - assert script_path.exists(), f"Example script not found: {script_path}" - - # Use temp directory for output - export_dir = tmp_path / "example_output" - export_dir.mkdir(exist_ok=True) - - try: - # Run the example script with custom output directory - env = os.environ.copy() - env["EXAMPLE_OUTPUT_DIR"] = str(export_dir) - - result = subprocess.run( - [sys.executable, str(script_path)], - capture_output=True, - text=True, - timeout=60, - env=env, - ) - - # Check execution succeeded - assert result.returncode == 0, f"Script failed with: {result.stderr}" - - # Verify expected output (might be in stdout or stderr due to logging) - output = result.stdout + result.stderr - assert "Created 5000 sample products" in output - assert "Export completed:" in output - assert "Rows exported: 5,000" in output - assert f"Output directory: {export_dir}" in output - - # Verify CSV file was created - csv_files = list(export_dir.glob("*.csv")) - assert len(csv_files) > 0, "No CSV files were created" - - # Verify CSV content - csv_file = csv_files[0] - assert csv_file.stat().st_size > 0, "CSV file is empty" - - # Read and validate CSV content - with open(csv_file, "r") as f: - header = f.readline().strip() - # Verify header contains expected columns - assert "product_id" in header - assert "category" in header - assert "price" in header - assert "in_stock" in header - assert "tags" in header - assert "attributes" in header - assert "created_at" in header - - # Read a few data rows to verify content - row_count = 0 - for line in f: - row_count += 1 - if row_count > 10: # Check first 10 rows - break - # Basic validation that row has content - assert len(line.strip()) > 0 - assert "," in line # CSV format - - # Verify we have the expected number of rows (5000 + header) - f.seek(0) - total_lines = sum(1 for _ in f) - assert ( - total_lines == 5001 - ), f"Expected 5001 lines (header + 5000 rows), got {total_lines}" - - finally: - # Cleanup - always clean up even if test fails - # pytest's tmp_path fixture also cleans up automatically - if export_dir.exists(): - shutil.rmtree(export_dir) - - async def test_context_manager_safety_demo(self, cassandra_cluster): - """ - Test the context manager safety demonstration. - - What this tests: - --------------- - 1. Query errors don't close sessions - 2. Streaming errors don't close sessions - 3. Context managers isolate resources - 4. Concurrent operations work safely - 5. 
Proper error handling patterns - - Why this matters: - ---------------- - - Users need to understand resource lifecycle - - Error handling is often done wrong - - Context managers are mandatory - - Demonstrates resilience patterns - """ - script_path = EXAMPLES_DIR / "context_manager_safety_demo.py" - assert script_path.exists(), f"Example script not found: {script_path}" - - # Run the example script with longer timeout - result = subprocess.run( - [sys.executable, str(script_path)], - capture_output=True, - text=True, - timeout=60, # Increase timeout as this example runs multiple demonstrations - ) - - # Check execution succeeded - assert result.returncode == 0, f"Script failed with: {result.stderr}" - - # Verify all demonstrations ran (might be in stdout or stderr due to logging) - output = result.stdout + result.stderr - assert "Demonstrating Query Error Safety" in output - assert "Query failed as expected" in output - assert "Session still works after error" in output - - assert "Demonstrating Streaming Error Safety" in output - assert "Streaming failed as expected" in output - assert "Successfully streamed" in output - - assert "Demonstrating Context Manager Isolation" in output - assert "Demonstrating Concurrent Safety" in output - - # Verify key takeaways are shown - assert "Query errors don't close sessions" in output - assert "Context managers only close their own resources" in output - - async def test_metrics_simple_example(self, cassandra_cluster): - """ - Test the simple metrics example. - - What this tests: - --------------- - 1. Metrics collection works correctly - 2. Query performance is tracked - 3. Connection health is monitored - 4. Statistics are calculated properly - 5. Error tracking functions - - Why this matters: - ---------------- - - Observability is critical in production - - Users need metrics examples - - Performance monitoring essential - - Shows integration patterns - """ - script_path = EXAMPLES_DIR / "metrics_simple.py" - assert script_path.exists(), f"Example script not found: {script_path}" - - # Run the example script - result = subprocess.run( - [sys.executable, str(script_path)], - capture_output=True, - text=True, - timeout=30, - ) - - # Check execution succeeded - assert result.returncode == 0, f"Script failed with: {result.stderr}" - - # Verify metrics output (might be in stdout or stderr due to logging) - output = result.stdout + result.stderr - assert "Query Metrics Example" in output or "async-cassandra Metrics Example" in output - assert "Connection Health Monitoring" in output - assert "Error Tracking Example" in output or "Expected error recorded" in output - assert "Performance Summary" in output - - # Verify statistics are shown - assert "Total queries:" in output or "Query Metrics:" in output - assert "Success rate:" in output or "Success Rate:" in output - assert "Average latency:" in output or "Average Duration:" in output - - @pytest.mark.timeout(240) # Override default timeout for this test (lots of data) - async def test_realtime_processing_example(self, cassandra_cluster): - """ - Test the real-time processing example. - - What this tests: - --------------- - 1. Time-series data handling - 2. Sliding window analytics - 3. Real-time aggregations - 4. Alert triggering logic - 5. 
Continuous processing patterns - - Why this matters: - ---------------- - - IoT/sensor data is common use case - - Real-time analytics increasingly important - - Shows advanced streaming patterns - - Demonstrates time-based queries - """ - script_path = EXAMPLES_DIR / "realtime_processing.py" - assert script_path.exists(), f"Example script not found: {script_path}" - - # Run the example script with a longer timeout since it processes lots of data - result = subprocess.run( - [sys.executable, str(script_path)], - capture_output=True, - text=True, - timeout=180, # Allow more time for 108k readings (50 sensors × 2160 time points) - ) - - # Check execution succeeded - assert result.returncode == 0, f"Script failed with: {result.stderr}" - - # Verify expected output (check both stdout and stderr) - output = result.stdout + result.stderr - - # Check that setup completed - assert "Setting up sensor data" in output - assert "Sample data inserted" in output - - # Check that processing occurred - assert "Processing Historical Data" in output or "Processing historical data" in output - assert "Processing completed" in output or "readings processed" in output - - # Check that real-time simulation ran - assert "Simulating Real-Time Processing" in output or "Processing cycle" in output - - # Verify cleanup - assert "Cleaning up" in output - - async def test_metrics_advanced_example(self, cassandra_cluster): - """ - Test the advanced metrics example. - - What this tests: - --------------- - 1. Multiple metrics collectors - 2. Prometheus integration setup - 3. FastAPI integration patterns - 4. Comprehensive monitoring - 5. Production-ready patterns - - Why this matters: - ---------------- - - Production systems need Prometheus - - FastAPI integration common - - Shows complete monitoring setup - - Enterprise-ready patterns - """ - script_path = EXAMPLES_DIR / "metrics_example.py" - assert script_path.exists(), f"Example script not found: {script_path}" - - # Run the example script - result = subprocess.run( - [sys.executable, str(script_path)], - capture_output=True, - text=True, - timeout=30, - ) - - # Check execution succeeded - assert result.returncode == 0, f"Script failed with: {result.stderr}" - - # Verify advanced features demonstrated (might be in stdout or stderr due to logging) - output = result.stdout + result.stderr - assert "Metrics" in output or "metrics" in output - assert "queries" in output.lower() or "Queries" in output - - @pytest.mark.timeout(240) # Override default timeout for this test - async def test_export_to_parquet_example(self, cassandra_cluster, tmp_path): - """ - Test the Parquet export example. - - What this tests: - --------------- - 1. Creates test data with various types - 2. Exports data to Parquet format - 3. Handles different compression formats - 4. Shows progress during export - 5. Verifies exported files - 6. Validates Parquet file content - 7. 
Cleans up resources automatically - - Why this matters: - ---------------- - - Parquet is popular for analytics - - Memory-efficient export critical for large datasets - - Type handling must be correct - - Shows advanced streaming patterns - """ - script_path = EXAMPLES_DIR / "export_to_parquet.py" - assert script_path.exists(), f"Example script not found: {script_path}" - - # Use temp directory for output - export_dir = tmp_path / "parquet_output" - export_dir.mkdir(exist_ok=True) - - try: - # Run the example script with custom output directory - env = os.environ.copy() - env["EXAMPLE_OUTPUT_DIR"] = str(export_dir) - - result = subprocess.run( - [sys.executable, str(script_path)], - capture_output=True, - text=True, - timeout=180, # Allow time for data generation and export - env=env, - ) - - # Check execution succeeded - if result.returncode != 0: - print(f"STDOUT:\n{result.stdout}") - print(f"STDERR:\n{result.stderr}") - assert result.returncode == 0, f"Script failed with return code {result.returncode}" - - # Verify expected output - output = result.stderr if result.stderr else result.stdout - assert "Setting up test data" in output - assert "Test data setup complete" in output - assert "Example 1: Export Entire Table" in output - assert "Example 2: Export Filtered Data" in output - assert "Example 3: Export with Different Compression" in output - assert "Export completed successfully!" in output - assert "Verifying Exported Files" in output - assert f"Output directory: {export_dir}" in output - - # Verify Parquet files were created (look recursively in subdirectories) - parquet_files = list(export_dir.rglob("*.parquet")) - assert ( - len(parquet_files) >= 3 - ), f"Expected at least 3 Parquet files, found {len(parquet_files)}" - - # Verify files have content - for parquet_file in parquet_files: - assert parquet_file.stat().st_size > 0, f"Parquet file {parquet_file} is empty" - - # Verify we can read and validate the Parquet files - try: - import pyarrow as pa - import pyarrow.parquet as pq - - # Track total rows across all files - total_rows = 0 - - for parquet_file in parquet_files: - table = pq.read_table(parquet_file) - assert table.num_rows > 0, f"Parquet file {parquet_file} has no rows" - total_rows += table.num_rows - - # Verify expected columns exist - column_names = [field.name for field in table.schema] - assert "user_id" in column_names - assert "event_time" in column_names - assert "event_type" in column_names - assert "device_type" in column_names - assert "country_code" in column_names - assert "city" in column_names - assert "revenue" in column_names - assert "duration_seconds" in column_names - assert "is_premium" in column_names - assert "metadata" in column_names - assert "tags" in column_names - - # Verify data types are preserved - schema = table.schema - assert schema.field("is_premium").type == pa.bool_() - assert ( - schema.field("duration_seconds").type == pa.int64() - ) # We use int64 in our schema - - # Read first few rows to validate content - df = table.to_pandas() - assert len(df) > 0 - - # Validate some data characteristics - assert ( - df["event_type"] - .isin(["view", "click", "purchase", "signup", "logout"]) - .all() - ) - assert df["device_type"].isin(["mobile", "desktop", "tablet", "tv"]).all() - assert df["duration_seconds"].between(10, 3600).all() - - # Verify we generated substantial test data (should be > 10k rows) - assert total_rows > 10000, f"Expected > 10000 total rows, got {total_rows}" - - except ImportError: - # PyArrow not available in test 
environment - pytest.skip("PyArrow not available for full validation") - - finally: - # Cleanup - always clean up even if test fails - # pytest's tmp_path fixture also cleans up automatically - if export_dir.exists(): - shutil.rmtree(export_dir) - - async def test_streaming_non_blocking_demo(self, cassandra_cluster): - """ - Test the non-blocking streaming demonstration. - - What this tests: - --------------- - 1. Creates test data for streaming - 2. Demonstrates event loop responsiveness - 3. Shows concurrent operations during streaming - 4. Provides visual feedback of non-blocking behavior - 5. Cleans up resources - - Why this matters: - ---------------- - - Proves async wrapper doesn't block - - Critical for understanding async benefits - - Shows real concurrent execution - - Validates our architecture claims - """ - script_path = EXAMPLES_DIR / "streaming_non_blocking_demo.py" - assert script_path.exists(), f"Example script not found: {script_path}" - - # Run the example script - result = subprocess.run( - [sys.executable, str(script_path)], - capture_output=True, - text=True, - timeout=120, # Allow time for demonstrations - ) - - # Check execution succeeded - if result.returncode != 0: - print(f"STDOUT:\n{result.stdout}") - print(f"STDERR:\n{result.stderr}") - assert result.returncode == 0, f"Script failed with return code {result.returncode}" - - # Verify expected output - output = result.stdout + result.stderr - assert "Starting non-blocking streaming demonstration" in output - assert "Heartbeat still running!" in output - assert "Event Loop Analysis:" in output - assert "Event loop remained responsive!" in output - assert "Demonstrating concurrent operations" in output - assert "Demonstration complete!" in output - - # Verify keyspace was cleaned up - async with AsyncCluster(["localhost"]) as cluster: - async with await cluster.connect() as session: - result = await session.execute( - "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = 'streaming_demo'" - ) - assert result.one() is None, "Keyspace was not cleaned up" - - @pytest.mark.parametrize( - "script_name", - [ - "streaming_basic.py", - "export_large_table.py", - "context_manager_safety_demo.py", - "metrics_simple.py", - "export_to_parquet.py", - "streaming_non_blocking_demo.py", - ], - ) - async def test_example_uses_context_managers(self, script_name): - """ - Verify all examples use context managers properly. - - What this tests: - --------------- - 1. AsyncCluster used with context manager - 2. Sessions used with context manager - 3. Streaming uses context manager - 4. 
No resource leaks - - Why this matters: - ---------------- - - Context managers are mandatory - - Prevents resource leaks - - Examples must show best practices - - Users copy example patterns - """ - script_path = EXAMPLES_DIR / script_name - assert script_path.exists(), f"Example script not found: {script_path}" - - # Read script content - content = script_path.read_text() - - # Check for context manager usage - assert ( - "async with AsyncCluster" in content - ), f"{script_name} doesn't use AsyncCluster context manager" - - # If script has streaming, verify context manager usage - if "execute_stream" in content: - assert ( - "async with await session.execute_stream" in content - or "async with session.execute_stream" in content - ), f"{script_name} doesn't use streaming context manager" - - @pytest.mark.parametrize( - "script_name", - [ - "streaming_basic.py", - "export_large_table.py", - "context_manager_safety_demo.py", - "metrics_simple.py", - "export_to_parquet.py", - "streaming_non_blocking_demo.py", - ], - ) - async def test_example_uses_prepared_statements(self, script_name): - """ - Verify examples use prepared statements for parameterized queries. - - What this tests: - --------------- - 1. Prepared statements for inserts - 2. Prepared statements for selects with parameters - 3. No string interpolation in queries - 4. Proper parameter binding - - Why this matters: - ---------------- - - Prepared statements are mandatory - - Prevents SQL injection - - Better performance - - Examples must show best practices - """ - script_path = EXAMPLES_DIR / script_name - assert script_path.exists(), f"Example script not found: {script_path}" - - # Read script content - content = script_path.read_text() - - # If script has parameterized queries, check for prepared statements - if "VALUES (?" in content or "WHERE" in content and "= ?" in content: - assert ( - "prepare(" in content - ), f"{script_name} has parameterized queries but doesn't use prepare()" - - -class TestExampleDocumentation: - """Test that example documentation is accurate and complete.""" - - async def test_readme_lists_all_examples(self): - """ - Verify README documents all example scripts. - - What this tests: - --------------- - 1. All .py files are documented - 2. Descriptions match actual functionality - 3. Run instructions are provided - 4. Prerequisites are listed - - Why this matters: - ---------------- - - Users rely on README for navigation - - Missing examples confuse users - - Documentation must stay in sync - - First impression matters - """ - readme_path = EXAMPLES_DIR / "README.md" - assert readme_path.exists(), "Examples README.md not found" - - readme_content = readme_path.read_text() - - # Get all Python example files (excluding FastAPI app) - example_files = [ - f.name for f in EXAMPLES_DIR.glob("*.py") if f.is_file() and not f.name.startswith("_") - ] - - # Verify each example is documented - for example_file in example_files: - assert example_file in readme_content, f"{example_file} not documented in README" - - # Verify required sections exist - assert "Prerequisites" in readme_content - assert "Best Practices Demonstrated" in readme_content - assert "Running Multiple Examples" in readme_content - assert "Troubleshooting" in readme_content - - async def test_examples_have_docstrings(self): - """ - Verify all examples have proper module docstrings. - - What this tests: - --------------- - 1. Module-level docstrings exist - 2. Docstrings describe what's demonstrated - 3. Key features are listed - 4. 
Usage context is clear - - Why this matters: - ---------------- - - Docstrings provide immediate context - - Help users understand purpose - - Good documentation practice - - Self-documenting code - """ - example_files = list(EXAMPLES_DIR.glob("*.py")) - - for example_file in example_files: - content = example_file.read_text() - lines = content.split("\n") - - # Check for module docstring - docstring_found = False - for i, line in enumerate(lines[:20]): # Check first 20 lines - if line.strip().startswith('"""') or line.strip().startswith("'''"): - docstring_found = True - break - - assert docstring_found, f"{example_file.name} missing module docstring" - - # Verify docstring mentions what's demonstrated - if docstring_found: - # Extract docstring content - docstring_lines = [] - for j in range(i, min(i + 20, len(lines))): - docstring_lines.append(lines[j]) - if j > i and ( - lines[j].strip().endswith('"""') or lines[j].strip().endswith("'''") - ): - break - - docstring_content = "\n".join(docstring_lines).lower() - assert ( - "demonstrates" in docstring_content or "example" in docstring_content - ), f"{example_file.name} docstring doesn't describe what it demonstrates" - - -# Run integration test for a specific example (useful for development) -async def run_single_example(example_name: str): - """Run a single example script for testing.""" - script_path = EXAMPLES_DIR / example_name - if not script_path.exists(): - print(f"Example not found: {script_path}") - return - - print(f"Running {example_name}...") - result = subprocess.run( - [sys.executable, str(script_path)], - capture_output=True, - text=True, - timeout=60, - ) - - if result.returncode == 0: - print("Success! Output:") - print(result.stdout) - else: - print("Failed! Error:") - print(result.stderr) - - -if __name__ == "__main__": - # For development testing - import sys - - if len(sys.argv) > 1: - asyncio.run(run_single_example(sys.argv[1])) - else: - print("Usage: python test_example_scripts.py ") - print("Available examples:") - for f in sorted(EXAMPLES_DIR.glob("*.py")): - print(f" - {f.name}") diff --git a/tests/integration/test_fastapi_reconnection_isolation.py b/tests/integration/test_fastapi_reconnection_isolation.py deleted file mode 100644 index 8b83b53..0000000 --- a/tests/integration/test_fastapi_reconnection_isolation.py +++ /dev/null @@ -1,251 +0,0 @@ -""" -Test to isolate why FastAPI app doesn't reconnect after Cassandra comes back. -""" - -import asyncio -import os -import time - -import pytest -from cassandra.policies import ConstantReconnectionPolicy - -from async_cassandra import AsyncCluster -from tests.utils.cassandra_control import CassandraControl - - -class TestFastAPIReconnectionIsolation: - """Isolate FastAPI reconnection issue.""" - - def _get_cassandra_control(self, container=None): - """Get Cassandra control interface.""" - return CassandraControl(container) - - @pytest.mark.integration - @pytest.mark.asyncio - @pytest.mark.skip(reason="Requires container control not available in CI") - async def test_session_health_check_pattern(self): - """ - Test the FastAPI health check pattern that might prevent reconnection. - - What this tests: - --------------- - 1. Health check pattern - 2. Failure detection - 3. Recovery behavior - 4. Session reuse - - Why this matters: - ---------------- - FastAPI patterns: - - Health endpoints common - - Global session reuse - - Must handle outages - - Verifies reconnection works - with app patterns. 
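        Example (illustrative sketch):
        ------------------------------
        The application-side shape exercised below: build the cluster once with
        a reconnection policy, keep reusing the same session, and probe it with
        a lightweight health query rather than rebuilding it on failure:

            from cassandra.policies import ConstantReconnectionPolicy

            from async_cassandra import AsyncCluster

            async def startup():
                cluster = AsyncCluster(
                    contact_points=["127.0.0.1"],
                    reconnection_policy=ConstantReconnectionPolicy(delay=2.0),
                )
                session = await cluster.connect()
                return cluster, session

            async def health_check(session) -> bool:
                try:
                    await session.execute("SELECT now() FROM system.local")
                    return True
                except Exception:
                    return False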
- """ - pytest.skip("This test requires container control capabilities") - print("\n=== Testing FastAPI Health Check Pattern ===") - - # Skip this test in CI since we can't control Cassandra service - if os.environ.get("CI") == "true": - pytest.skip("Cannot control Cassandra service in CI environment") - - # Simulate FastAPI startup - cluster = None - session = None - - try: - # Initial connection (like FastAPI startup) - cluster = AsyncCluster( - contact_points=["127.0.0.1"], - protocol_version=5, - reconnection_policy=ConstantReconnectionPolicy(delay=2.0), - connect_timeout=10.0, - ) - session = await cluster.connect() - print("✓ Initial connection established") - - # Create keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS fastapi_test - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("fastapi_test") - - # Simulate health check function - async def health_check(): - """Simulate FastAPI health check.""" - try: - if session is None: - return False - await session.execute("SELECT now() FROM system.local") - return True - except Exception: - return False - - # Initial health check should pass - assert await health_check(), "Initial health check failed" - print("✓ Initial health check passed") - - # Disable Cassandra - print("\nDisabling Cassandra...") - control = self._get_cassandra_control() - - if os.environ.get("CI") == "true": - # Still test that health check works with available service - print("✓ Skipping outage simulation in CI") - else: - success = control.simulate_outage() - assert success, "Failed to simulate outage" - print("✓ Cassandra is down") - - # Health check behavior depends on environment - if os.environ.get("CI") == "true": - # In CI, Cassandra is always up - assert await health_check(), "Health check should pass in CI" - print("✓ Health check passes (CI environment)") - else: - # In local env, should fail when down - assert not await health_check(), "Health check should fail when Cassandra is down" - print("✓ Health check correctly reports failure") - - # Re-enable Cassandra - print("\nRe-enabling Cassandra...") - if not os.environ.get("CI") == "true": - success = control.restore_service() - assert success, "Failed to restore service" - print("✓ Cassandra is ready") - - # Test health check recovery - print("\nTesting health check recovery...") - recovered = False - start_time = time.time() - - for attempt in range(30): - if await health_check(): - recovered = True - elapsed = time.time() - start_time - print(f"✓ Health check recovered after {elapsed:.1f} seconds") - break - await asyncio.sleep(1) - if attempt % 5 == 0: - print(f" After {attempt} seconds: Health check still failing") - - if not recovered: - # Try a direct query to see if session works - print("\nTesting direct query...") - try: - await session.execute("SELECT now() FROM system.local") - print("✓ Direct query works! Health check pattern may be caching errors") - except Exception as e: - print(f"✗ Direct query also fails: {type(e).__name__}: {e}") - - assert recovered, "Health check never recovered" - - finally: - if session: - await session.close() - if cluster: - await cluster.shutdown() - - @pytest.mark.integration - @pytest.mark.asyncio - @pytest.mark.skip(reason="Requires container control not available in CI") - async def test_global_session_reconnection(self): - """ - Test reconnection with global session variable like FastAPI. - - What this tests: - --------------- - 1. Global session pattern - 2. 
Reconnection works - 3. No session replacement - 4. Automatic recovery - - Why this matters: - ---------------- - Global state common: - - FastAPI apps - - Flask apps - - Service patterns - - Must reconnect without - manual intervention. - """ - pytest.skip("This test requires container control capabilities") - print("\n=== Testing Global Session Reconnection ===") - - # Skip this test in CI since we can't control Cassandra service - if os.environ.get("CI") == "true": - pytest.skip("Cannot control Cassandra service in CI environment") - - # Global variables like in FastAPI - global session, cluster - session = None - cluster = None - - try: - # Startup - cluster = AsyncCluster( - contact_points=["127.0.0.1"], - protocol_version=5, - reconnection_policy=ConstantReconnectionPolicy(delay=2.0), - connect_timeout=10.0, - ) - session = await cluster.connect() - print("✓ Global session created") - - # Create keyspace - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS global_test - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("global_test") - - # Test query - await session.execute("SELECT now() FROM system.local") - print("✓ Initial query works") - - # Get control interface - control = self._get_cassandra_control() - - if os.environ.get("CI") == "true": - print("\nSkipping outage simulation in CI") - # In CI, just test that the session works - await session.execute("SELECT now() FROM system.local") - print("✓ Session works in CI environment") - else: - # Disable Cassandra - print("\nDisabling Cassandra...") - control.simulate_outage() - - # Re-enable Cassandra - print("Re-enabling Cassandra...") - control.restore_service() - - # Test recovery with global session - print("\nTesting global session recovery...") - recovered = False - for attempt in range(30): - try: - await session.execute("SELECT now() FROM system.local") - recovered = True - print(f"✓ Global session recovered after {attempt + 1} seconds") - break - except Exception as e: - if attempt % 5 == 0: - print(f" After {attempt} seconds: {type(e).__name__}") - await asyncio.sleep(1) - - assert recovered, "Global session never recovered" - - finally: - if session: - await session.close() - if cluster: - await cluster.shutdown() diff --git a/tests/integration/test_long_lived_connections.py b/tests/integration/test_long_lived_connections.py deleted file mode 100644 index 6568d52..0000000 --- a/tests/integration/test_long_lived_connections.py +++ /dev/null @@ -1,370 +0,0 @@ -""" -Integration tests to ensure clusters and sessions are long-lived and reusable. - -This is critical for production applications where connections should be -established once and reused across many requests. -""" - -import asyncio -import time -import uuid - -import pytest - -from async_cassandra import AsyncCluster - - -class TestLongLivedConnections: - """Test that clusters and sessions can be long-lived and reused.""" - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_session_reuse_across_many_operations(self, cassandra_cluster): - """ - Test that a session can be reused for many operations. - - What this tests: - --------------- - 1. Session reuse works - 2. Many operations OK - 3. No degradation - 4. Long-lived sessions - - Why this matters: - ---------------- - Production pattern: - - One session per app - - Thousands of queries - - No reconnection cost - - Must support connection - pooling correctly. 
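        Sketch of the recommended lifecycle (illustrative only):
        --------------------------------------------------------
            cluster = AsyncCluster(["localhost"])   # once, at startup
            session = await cluster.connect()       # once, at startup
            for _ in range(10_000):                 # reused by every request
                await session.execute("SELECT release_version FROM system.local")
            await session.close()                   # once, at shutdown
            await cluster.shutdown()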
- """ - # Create session once - session = await cassandra_cluster.connect() - - # Use session for many operations - operations_count = 100 - results = [] - - for i in range(operations_count): - result = await session.execute("SELECT release_version FROM system.local") - results.append(result.one()) - - # Small delay to simulate time between requests - await asyncio.sleep(0.01) - - # Verify all operations succeeded - assert len(results) == operations_count - assert all(r is not None for r in results) - - # Session should still be usable - final_result = await session.execute("SELECT now() FROM system.local") - assert final_result.one() is not None - - # Explicitly close when done (not after each operation) - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_cluster_creates_multiple_sessions(self, cassandra_cluster): - """ - Test that a cluster can create multiple sessions. - - What this tests: - --------------- - 1. Multiple sessions work - 2. Sessions independent - 3. Concurrent usage OK - 4. Resource isolation - - Why this matters: - ---------------- - Multi-session needs: - - Microservices - - Different keyspaces - - Isolation requirements - - Cluster manages many - sessions properly. - """ - # Create multiple sessions from same cluster - sessions = [] - session_count = 5 - - for i in range(session_count): - session = await cassandra_cluster.connect() - sessions.append(session) - - # Use all sessions concurrently - async def use_session(session, session_id): - results = [] - for i in range(10): - result = await session.execute("SELECT release_version FROM system.local") - results.append(result.one()) - return session_id, results - - tasks = [use_session(session, i) for i, session in enumerate(sessions)] - results = await asyncio.gather(*tasks) - - # Verify all sessions worked - assert len(results) == session_count - for session_id, session_results in results: - assert len(session_results) == 10 - assert all(r is not None for r in session_results) - - # Close all sessions - for session in sessions: - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_session_survives_errors(self, cassandra_cluster): - """ - Test that session remains usable after query errors. - - What this tests: - --------------- - 1. Errors don't kill session - 2. Recovery automatic - 3. Multiple error types - 4. Continued operation - - Why this matters: - ---------------- - Real apps have errors: - - Bad queries - - Missing tables - - Syntax issues - - Session must survive all - non-fatal errors. 
- """ - session = await cassandra_cluster.connect() - await session.execute( - "CREATE KEYSPACE IF NOT EXISTS test_long_lived " - "WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1}" - ) - await session.set_keyspace("test_long_lived") - - # Create test table - await session.execute( - "CREATE TABLE IF NOT EXISTS test_errors (id UUID PRIMARY KEY, data TEXT)" - ) - - # Successful operation - test_id = uuid.uuid4() - insert_stmt = await session.prepare("INSERT INTO test_errors (id, data) VALUES (?, ?)") - await session.execute(insert_stmt, [test_id, "test data"]) - - # Cause an error (invalid query) - with pytest.raises(Exception): # Will be InvalidRequest or similar - await session.execute("INVALID QUERY SYNTAX") - - # Session should still be usable after error - select_stmt = await session.prepare("SELECT * FROM test_errors WHERE id = ?") - result = await session.execute(select_stmt, [test_id]) - assert result.one() is not None - assert result.one().data == "test data" - - # Another error (table doesn't exist) - with pytest.raises(Exception): - await session.execute("SELECT * FROM non_existent_table") - - # Still usable - result = await session.execute("SELECT now() FROM system.local") - assert result.one() is not None - - # Cleanup - await session.execute("DROP TABLE IF EXISTS test_errors") - await session.execute("DROP KEYSPACE IF EXISTS test_long_lived") - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_prepared_statements_are_cached(self, cassandra_cluster): - """ - Test that prepared statements can be reused efficiently. - - What this tests: - --------------- - 1. Statement caching works - 2. Reuse is efficient - 3. Multiple statements OK - 4. No re-preparation - - Why this matters: - ---------------- - Performance critical: - - Prepare once - - Execute many times - - Reduced latency - - Core optimization for - production apps. - """ - session = await cassandra_cluster.connect() - - # Prepare statement once - prepared = await session.prepare("SELECT release_version FROM system.local WHERE key = ?") - - # Reuse prepared statement many times - for i in range(50): - result = await session.execute(prepared, ["local"]) - assert result.one() is not None - - # Prepare another statement - prepared2 = await session.prepare("SELECT cluster_name FROM system.local WHERE key = ?") - - # Both prepared statements should be reusable - result1 = await session.execute(prepared, ["local"]) - result2 = await session.execute(prepared2, ["local"]) - - assert result1.one() is not None - assert result2.one() is not None - - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_session_lifetime_measurement(self, cassandra_cluster): - """ - Test that sessions can live for extended periods. - - What this tests: - --------------- - 1. Extended lifetime OK - 2. No timeout issues - 3. Sustained throughput - 4. Stable performance - - Why this matters: - ---------------- - Production sessions: - - Days to weeks alive - - Millions of queries - - No restarts needed - - Proves long-term - stability. 
- """ - session = await cassandra_cluster.connect() - start_time = time.time() - - # Use session over a period of time - test_duration = 5 # seconds - operations = 0 - - while time.time() - start_time < test_duration: - result = await session.execute("SELECT now() FROM system.local") - assert result.one() is not None - operations += 1 - await asyncio.sleep(0.1) # 10 operations per second - - end_time = time.time() - actual_duration = end_time - start_time - - # Session should have been alive for the full duration - assert actual_duration >= test_duration - assert operations >= test_duration * 9 # At least 9 ops/second - - # Still usable after the test period - final_result = await session.execute("SELECT now() FROM system.local") - assert final_result.one() is not None - - await session.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_context_manager_closes_session(self): - """ - Test that context manager does close session (for scripts/tests). - - What this tests: - --------------- - 1. Context manager works - 2. Session closed on exit - 3. Cluster still usable - 4. Clean resource handling - - Why this matters: - ---------------- - Script patterns: - - Short-lived sessions - - Automatic cleanup - - No leaks - - Different from production - but still supported. - """ - # Create cluster manually to test context manager - cluster = AsyncCluster(["localhost"]) - - # Use context manager - async with await cluster.connect() as session: - # Session should be usable - result = await session.execute("SELECT now() FROM system.local") - assert result.one() is not None - assert not session.is_closed - - # Session should be closed after context exit - assert session.is_closed - - # Cluster should still be usable - new_session = await cluster.connect() - result = await new_session.execute("SELECT now() FROM system.local") - assert result.one() is not None - - await new_session.close() - await cluster.shutdown() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_production_pattern(self): - """ - Test the recommended production pattern. - - What this tests: - --------------- - 1. Production lifecycle - 2. Startup/shutdown once - 3. Many requests handled - 4. Concurrent load OK - - Why this matters: - ---------------- - Best practice pattern: - - Initialize once - - Reuse everywhere - - Clean shutdown - - Template for real - applications. - """ - # This simulates a production application lifecycle - - # Application startup - cluster = AsyncCluster(["localhost"]) - session = await cluster.connect() - - # Simulate many requests over time - async def handle_request(request_id): - """Simulate handling a web request.""" - result = await session.execute("SELECT cluster_name FROM system.local") - return f"Request {request_id}: {result.one().cluster_name}" - - # Handle many concurrent requests - for batch in range(5): # 5 batches - tasks = [ - handle_request(f"{batch}-{i}") - for i in range(20) # 20 concurrent requests per batch - ] - results = await asyncio.gather(*tasks) - assert len(results) == 20 - - # Small delay between batches - await asyncio.sleep(0.1) - - # Application shutdown (only happens once) - await session.close() - await cluster.shutdown() diff --git a/tests/integration/test_network_failures.py b/tests/integration/test_network_failures.py deleted file mode 100644 index 245d70c..0000000 --- a/tests/integration/test_network_failures.py +++ /dev/null @@ -1,411 +0,0 @@ -""" -Integration tests for network failure scenarios against real Cassandra. 
- -Note: These tests require the ability to manipulate network conditions. -They will be skipped if running in environments without proper permissions. -""" - -import asyncio -import time -import uuid - -import pytest -from cassandra import OperationTimedOut, ReadTimeout, Unavailable -from cassandra.cluster import NoHostAvailable - -from async_cassandra import AsyncCassandraSession, AsyncCluster -from async_cassandra.exceptions import ConnectionError - - -@pytest.mark.integration -class TestNetworkFailures: - """Test behavior under various network failure conditions.""" - - @pytest.mark.asyncio - async def test_unavailable_handling(self, cassandra_session): - """ - Test handling of Unavailable exceptions. - - What this tests: - --------------- - 1. Unavailable errors caught - 2. Replica count reported - 3. Consistency level impact - 4. Error message clarity - - Why this matters: - ---------------- - Unavailable errors indicate: - - Not enough replicas - - Cluster health issues - - Consistency impossible - - Apps must handle cluster - degradation gracefully. - """ - # Create a table with high replication factor in a new keyspace - # This test needs its own keyspace to test replication - await cassandra_session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_unavailable - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 3} - """ - ) - - # Use the new keyspace temporarily - original_keyspace = cassandra_session.keyspace - await cassandra_session.set_keyspace("test_unavailable") - - try: - await cassandra_session.execute("DROP TABLE IF EXISTS unavailable_test") - await cassandra_session.execute( - """ - CREATE TABLE unavailable_test ( - id UUID PRIMARY KEY, - data TEXT - ) - """ - ) - - # With replication factor 3 on a single node, QUORUM/ALL will fail - from cassandra import ConsistencyLevel - from cassandra.query import SimpleStatement - - # This should fail with Unavailable - insert_stmt = SimpleStatement( - "INSERT INTO unavailable_test (id, data) VALUES (%s, %s)", - consistency_level=ConsistencyLevel.ALL, - ) - - try: - await cassandra_session.execute(insert_stmt, [uuid.uuid4(), "test data"]) - pytest.fail("Should have raised Unavailable exception") - except (Unavailable, Exception) as e: - # Expected - we don't have 3 replicas - # The exception might be wrapped or not depending on the driver version - if isinstance(e, Unavailable): - assert e.alive_replicas < e.required_replicas - else: - # Check if it's wrapped - assert "Unavailable" in str(e) or "Cannot achieve consistency level ALL" in str( - e - ) - - finally: - # Clean up and restore original keyspace - await cassandra_session.execute("DROP KEYSPACE IF EXISTS test_unavailable") - await cassandra_session.set_keyspace(original_keyspace) - - @pytest.mark.asyncio - async def test_connection_pool_exhaustion(self, cassandra_session: AsyncCassandraSession): - """ - Test behavior when connection pool is exhausted. - - What this tests: - --------------- - 1. Many concurrent queries - 2. Pool limits respected - 3. Most queries succeed - 4. Graceful degradation - - Why this matters: - ---------------- - Pool exhaustion happens: - - Traffic spikes - - Slow queries - - Resource limits - - System must degrade - gracefully, not crash. 
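        One way an application can avoid exhausting the pool is to bound
        its own concurrency; a sketch of that approach (not something this
        test enforces):

            semaphore = asyncio.Semaphore(20)

            async def bounded_query(stmt, params):
                async with semaphore:
                    return await session.execute(stmt, params)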
- """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - # Create many concurrent long-running queries - async def long_query(i): - try: - # This query will scan the entire table - result = await cassandra_session.execute( - f"SELECT * FROM {users_table} ALLOW FILTERING" - ) - count = 0 - async for _ in result: - count += 1 - return i, count, None - except Exception as e: - return i, 0, str(e) - - # Insert some data first - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - for i in range(100): - await cassandra_session.execute( - insert_stmt, - [uuid.uuid4(), f"User {i}", f"user{i}@test.com", 25], - ) - - # Launch many concurrent queries - tasks = [long_query(i) for i in range(50)] - results = await asyncio.gather(*tasks) - - # Check results - successful = sum(1 for _, count, error in results if error is None) - failed = sum(1 for _, count, error in results if error is not None) - - print("\nConnection pool test results:") - print(f" Successful queries: {successful}") - print(f" Failed queries: {failed}") - - # Most queries should succeed - assert successful >= 45 # Allow a few failures - - @pytest.mark.asyncio - async def test_read_timeout_behavior(self, cassandra_session: AsyncCassandraSession): - """ - Test read timeout behavior with different scenarios. - - What this tests: - --------------- - 1. Short timeouts fail fast - 2. Reasonable timeouts work - 3. Timeout errors caught - 4. Query-level timeouts - - Why this matters: - ---------------- - Timeout control prevents: - - Hanging operations - - Resource exhaustion - - Poor user experience - - Critical for responsive - applications. - """ - # Create test data - await cassandra_session.execute("DROP TABLE IF EXISTS read_timeout_test") - await cassandra_session.execute( - """ - CREATE TABLE read_timeout_test ( - partition_key INT, - clustering_key INT, - data TEXT, - PRIMARY KEY (partition_key, clustering_key) - ) - """ - ) - - # Insert data across multiple partitions - # Prepare statement first - insert_stmt = await cassandra_session.prepare( - "INSERT INTO read_timeout_test (partition_key, clustering_key, data) " - "VALUES (?, ?, ?)" - ) - - insert_tasks = [] - for p in range(10): - for c in range(100): - task = cassandra_session.execute( - insert_stmt, - [p, c, f"data_{p}_{c}"], - ) - insert_tasks.append(task) - - # Execute in batches - for i in range(0, len(insert_tasks), 50): - await asyncio.gather(*insert_tasks[i : i + 50]) - - # Test 1: Query that might timeout on slow systems - start_time = time.time() - try: - result = await cassandra_session.execute( - "SELECT * FROM read_timeout_test", timeout=0.05 # 50ms timeout - ) - # Try to consume results - count = 0 - async for _ in result: - count += 1 - except (ReadTimeout, OperationTimedOut): - # Expected on most systems - duration = time.time() - start_time - assert duration < 1.0 # Should fail quickly - - # Test 2: Query with reasonable timeout should succeed - result = await cassandra_session.execute( - "SELECT * FROM read_timeout_test WHERE partition_key = 1", timeout=5.0 - ) - - rows = [] - async for row in result: - rows.append(row) - - assert len(rows) == 100 # Should get all rows from partition 1 - - @pytest.mark.asyncio - async def test_concurrent_failures_recovery(self, cassandra_session: AsyncCassandraSession): - """ - Test that the system recovers properly from concurrent failures. - - What this tests: - --------------- - 1. Retry logic works - 2. 
Exponential backoff - 3. High success rate - 4. Concurrent recovery - - Why this matters: - ---------------- - Transient failures common: - - Network hiccups - - Temporary overload - - Node restarts - - Smart retries maintain - reliability. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - # Prepare test data - test_ids = [uuid.uuid4() for _ in range(100)] - - # Insert test data - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - for test_id in test_ids: - await cassandra_session.execute( - insert_stmt, - [test_id, "Test User", "test@test.com", 30], - ) - - # Prepare select statement for reuse - select_stmt = await cassandra_session.prepare(f"SELECT * FROM {users_table} WHERE id = ?") - - # Function that sometimes fails - async def unreliable_query(user_id, fail_rate=0.2): - import random - - # Simulate random failures - if random.random() < fail_rate: - raise Exception("Simulated failure") - - result = await cassandra_session.execute(select_stmt, [user_id]) - rows = [] - async for row in result: - rows.append(row) - return rows[0] if rows else None - - # Run many concurrent queries with retries - async def query_with_retry(user_id, max_retries=3): - for attempt in range(max_retries): - try: - return await unreliable_query(user_id) - except Exception: - if attempt == max_retries - 1: - raise - await asyncio.sleep(0.1 * (attempt + 1)) # Exponential backoff - - # Execute concurrent queries - tasks = [query_with_retry(uid) for uid in test_ids] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Check results - successful = sum(1 for r in results if not isinstance(r, Exception)) - failed = sum(1 for r in results if isinstance(r, Exception)) - - print("\nRecovery test results:") - print(f" Successful queries: {successful}") - print(f" Failed queries: {failed}") - - # With retries, most should succeed - assert successful >= 95 # At least 95% success rate - - @pytest.mark.asyncio - async def test_connection_timeout_handling(self): - """ - Test connection timeout with unreachable hosts. - - What this tests: - --------------- - 1. Unreachable hosts timeout - 2. Timeout respected - 3. Fast failure - 4. Clear error - - Why this matters: - ---------------- - Connection timeouts prevent: - - Hanging startup - - Infinite waits - - Resource tie-up - - Fast failure enables - quick recovery. - """ - # Try to connect to non-existent host - async with AsyncCluster( - contact_points=["192.168.255.255"], # Non-routable IP - control_connection_timeout=1.0, - ) as cluster: - start_time = time.time() - - with pytest.raises((ConnectionError, NoHostAvailable, asyncio.TimeoutError)): - # Should timeout quickly - await cluster.connect(timeout=2.0) - - duration = time.time() - start_time - assert duration < 5.0 # Should fail within timeout period - - @pytest.mark.asyncio - async def test_batch_operations_with_failures(self, cassandra_session: AsyncCassandraSession): - """ - Test batch operation behavior during failures. - - What this tests: - --------------- - 1. Batch execution works - 2. Unlogged batches - 3. Multiple statements - 4. Data verification - - Why this matters: - ---------------- - Batch operations must: - - Handle partial failures - - Complete successfully - - Insert all data - - Critical for bulk - data operations. 
- """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - from cassandra.query import BatchStatement, BatchType - - # Create a batch - batch = BatchStatement(batch_type=BatchType.UNLOGGED) - - # Prepare statement for batch - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - - # Add multiple statements to the batch - for i in range(20): - batch.add( - insert_stmt, - [uuid.uuid4(), f"Batch User {i}", f"batch{i}@test.com", 25], - ) - - # Execute batch - should succeed - await cassandra_session.execute_batch(batch) - - # Verify data was inserted - count_stmt = await cassandra_session.prepare( - f"SELECT COUNT(*) FROM {users_table} WHERE age = ? ALLOW FILTERING" - ) - result = await cassandra_session.execute(count_stmt, [25]) - count = result.one()[0] - assert count >= 20 # At least our batch inserts diff --git a/tests/integration/test_protocol_version.py b/tests/integration/test_protocol_version.py deleted file mode 100644 index c72ea49..0000000 --- a/tests/integration/test_protocol_version.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -Integration tests for protocol version connection. - -Only tests actual connection with protocol v5 - validation logic is tested in unit tests. -""" - -import pytest - -from async_cassandra import AsyncCluster - - -class TestProtocolVersionIntegration: - """Integration tests for protocol version connection.""" - - @pytest.mark.asyncio - async def test_protocol_v5_connection(self): - """ - Test successful connection with protocol v5. - - What this tests: - --------------- - 1. Protocol v5 connects - 2. Queries execute OK - 3. Results returned - 4. Clean shutdown - - Why this matters: - ---------------- - Protocol v5 required: - - Async features - - Better performance - - New data types - - Verifies minimum protocol - version works. - """ - cluster = AsyncCluster(contact_points=["localhost"], protocol_version=5) - - try: - session = await cluster.connect() - - # Verify we can execute queries - result = await session.execute("SELECT release_version FROM system.local") - row = result.one() - assert row is not None - - await session.close() - finally: - await cluster.shutdown() - - @pytest.mark.asyncio - async def test_no_protocol_version_uses_negotiation(self): - """ - Test that omitting protocol version allows negotiation. - - What this tests: - --------------- - 1. Auto-negotiation works - 2. Driver picks version - 3. Connection succeeds - 4. Queries work - - Why this matters: - ---------------- - Flexible configuration: - - Works with any server - - Future compatibility - - Easier deployment - - Default behavior should - just work. - """ - cluster = AsyncCluster( - contact_points=["localhost"] - # No protocol_version specified - driver will negotiate - ) - - try: - session = await cluster.connect() - - # Should connect successfully - result = await session.execute("SELECT release_version FROM system.local") - assert result.one() is not None - - await session.close() - finally: - await cluster.shutdown() diff --git a/tests/integration/test_reconnection_behavior.py b/tests/integration/test_reconnection_behavior.py deleted file mode 100644 index 882d6b2..0000000 --- a/tests/integration/test_reconnection_behavior.py +++ /dev/null @@ -1,394 +0,0 @@ -""" -Integration tests comparing reconnection behavior between raw driver and async wrapper. - -This test verifies that our wrapper doesn't interfere with the driver's reconnection logic. 
-""" - -import asyncio -import os -import subprocess -import time - -import pytest -from cassandra.cluster import Cluster -from cassandra.policies import ConstantReconnectionPolicy - -from async_cassandra import AsyncCluster -from tests.utils.cassandra_control import CassandraControl - - -class TestReconnectionBehavior: - """Test reconnection behavior of raw driver vs async wrapper.""" - - def _get_cassandra_control(self, container=None): - """Get Cassandra control interface for the test environment.""" - # For integration tests, create a mock container object with just the fields we need - if container is None and os.environ.get("CI") != "true": - container = type( - "MockContainer", - (), - { - "container_name": "async-cassandra-test", - "runtime": ( - "podman" - if subprocess.run(["which", "podman"], capture_output=True).returncode == 0 - else "docker" - ), - }, - )() - return CassandraControl(container) - - @pytest.mark.integration - def test_raw_driver_reconnection(self): - """ - Test reconnection with raw Cassandra driver (synchronous). - - What this tests: - --------------- - 1. Raw driver reconnects - 2. After service outage - 3. Reconnection policy works - 4. Full functionality restored - - Why this matters: - ---------------- - Baseline behavior shows: - - Expected reconnection time - - Driver capabilities - - Recovery patterns - - Wrapper must match this - baseline behavior. - """ - print("\n=== Testing Raw Driver Reconnection ===") - - # Skip this test in CI since we can't control Cassandra service - if os.environ.get("CI") == "true": - pytest.skip("Cannot control Cassandra service in CI environment") - - # Create cluster with constant reconnection policy - cluster = Cluster( - contact_points=["127.0.0.1"], - protocol_version=5, - reconnection_policy=ConstantReconnectionPolicy(delay=2.0), - connect_timeout=10.0, - ) - - session = cluster.connect() - - # Create test keyspace and table - session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS reconnect_test_sync - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - session.set_keyspace("reconnect_test_sync") - session.execute("DROP TABLE IF EXISTS test_table") - session.execute( - """ - CREATE TABLE test_table ( - id INT PRIMARY KEY, - value TEXT - ) - """ - ) - - # Insert initial data - session.execute("INSERT INTO test_table (id, value) VALUES (1, 'before_outage')") - result = session.execute("SELECT * FROM test_table WHERE id = 1") - assert result.one().value == "before_outage" - print("✓ Initial connection working") - - # Get control interface - control = self._get_cassandra_control() - - # Disable Cassandra - print("Disabling Cassandra binary protocol...") - success = control.simulate_outage() - assert success, "Failed to simulate Cassandra outage" - print("✓ Cassandra is down") - - # Try query - should fail - try: - session.execute("SELECT * FROM test_table", timeout=2.0) - assert False, "Query should have failed" - except Exception as e: - print(f"✓ Query failed as expected: {type(e).__name__}") - - # Re-enable Cassandra - print("Re-enabling Cassandra binary protocol...") - success = control.restore_service() - assert success, "Failed to restore Cassandra service" - print("✓ Cassandra is ready") - - # Test reconnection - try for up to 30 seconds - reconnected = False - start_time = time.time() - while time.time() - start_time < 30: - try: - result = session.execute("SELECT * FROM test_table WHERE id = 1") - if result.one().value == "before_outage": - reconnected = True - elapsed = time.time() 
- start_time - print(f"✓ Raw driver reconnected after {elapsed:.1f} seconds") - break - except Exception: - pass - time.sleep(1) - - assert reconnected, "Raw driver failed to reconnect within 30 seconds" - - # Insert new data to verify full functionality - session.execute("INSERT INTO test_table (id, value) VALUES (2, 'after_reconnect')") - result = session.execute("SELECT * FROM test_table WHERE id = 2") - assert result.one().value == "after_reconnect" - print("✓ Can insert and query after reconnection") - - cluster.shutdown() - - @pytest.mark.integration - @pytest.mark.asyncio - async def test_async_wrapper_reconnection(self): - """ - Test reconnection with async wrapper. - - What this tests: - --------------- - 1. Wrapper reconnects properly - 2. Async operations resume - 3. No blocking during outage - 4. Same behavior as raw driver - - Why this matters: - ---------------- - Wrapper must not break: - - Driver reconnection logic - - Automatic recovery - - Connection pooling - - Critical for production - reliability. - """ - print("\n=== Testing Async Wrapper Reconnection ===") - - # Skip this test in CI since we can't control Cassandra service - if os.environ.get("CI") == "true": - pytest.skip("Cannot control Cassandra service in CI environment") - - # Create cluster with constant reconnection policy - cluster = AsyncCluster( - contact_points=["127.0.0.1"], - protocol_version=5, - reconnection_policy=ConstantReconnectionPolicy(delay=2.0), - connect_timeout=10.0, - ) - - session = await cluster.connect() - - # Create test keyspace and table - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS reconnect_test_async - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - await session.set_keyspace("reconnect_test_async") - await session.execute("DROP TABLE IF EXISTS test_table") - await session.execute( - """ - CREATE TABLE test_table ( - id INT PRIMARY KEY, - value TEXT - ) - """ - ) - - # Insert initial data - await session.execute("INSERT INTO test_table (id, value) VALUES (1, 'before_outage')") - result = await session.execute("SELECT * FROM test_table WHERE id = 1") - assert result.one().value == "before_outage" - print("✓ Initial connection working") - - # Get control interface - control = self._get_cassandra_control() - - # Disable Cassandra - print("Disabling Cassandra binary protocol...") - success = control.simulate_outage() - assert success, "Failed to simulate Cassandra outage" - print("✓ Cassandra is down") - - # Try query - should fail - try: - await session.execute("SELECT * FROM test_table", timeout=2.0) - assert False, "Query should have failed" - except Exception as e: - print(f"✓ Query failed as expected: {type(e).__name__}") - - # Re-enable Cassandra - print("Re-enabling Cassandra binary protocol...") - success = control.restore_service() - assert success, "Failed to restore Cassandra service" - print("✓ Cassandra is ready") - - # Test reconnection - try for up to 30 seconds - reconnected = False - start_time = time.time() - while time.time() - start_time < 30: - try: - result = await session.execute("SELECT * FROM test_table WHERE id = 1") - if result.one().value == "before_outage": - reconnected = True - elapsed = time.time() - start_time - print(f"✓ Async wrapper reconnected after {elapsed:.1f} seconds") - break - except Exception: - pass - await asyncio.sleep(1) - - assert reconnected, "Async wrapper failed to reconnect within 30 seconds" - - # Insert new data to verify full functionality - await session.execute("INSERT INTO 
test_table (id, value) VALUES (2, 'after_reconnect')") - result = await session.execute("SELECT * FROM test_table WHERE id = 2") - assert result.one().value == "after_reconnect" - print("✓ Can insert and query after reconnection") - - await session.close() - await cluster.shutdown() - - @pytest.mark.integration - @pytest.mark.asyncio - async def test_reconnection_timing_comparison(self): - """ - Compare reconnection timing between raw driver and async wrapper. - - What this tests: - --------------- - 1. Both reconnect similarly - 2. Timing within 5 seconds - 3. No wrapper overhead - 4. Parallel comparison - - Why this matters: - ---------------- - Performance validation: - - Wrapper adds minimal delay - - Recovery time predictable - - Production SLAs met - - Ensures wrapper doesn't - degrade reconnection. - """ - print("\n=== Comparing Reconnection Timing ===") - - # Skip this test in CI since we can't control Cassandra service - if os.environ.get("CI") == "true": - pytest.skip("Cannot control Cassandra service in CI environment") - - # Test both in parallel to ensure fair comparison - raw_reconnect_time = None - async_reconnect_time = None - - def test_raw_driver(): - nonlocal raw_reconnect_time - cluster = Cluster( - contact_points=["127.0.0.1"], - protocol_version=5, - reconnection_policy=ConstantReconnectionPolicy(delay=2.0), - connect_timeout=10.0, - ) - session = cluster.connect() - session.execute("SELECT now() FROM system.local") - - # Wait for Cassandra to be down - time.sleep(2) # Give time for Cassandra to be disabled - - # Measure reconnection time - start_time = time.time() - while time.time() - start_time < 30: - try: - session.execute("SELECT now() FROM system.local") - raw_reconnect_time = time.time() - start_time - break - except Exception: - time.sleep(0.5) - - cluster.shutdown() - - async def test_async_wrapper(): - nonlocal async_reconnect_time - cluster = AsyncCluster( - contact_points=["127.0.0.1"], - protocol_version=5, - reconnection_policy=ConstantReconnectionPolicy(delay=2.0), - connect_timeout=10.0, - ) - session = await cluster.connect() - await session.execute("SELECT now() FROM system.local") - - # Wait for Cassandra to be down - await asyncio.sleep(2) # Give time for Cassandra to be disabled - - # Measure reconnection time - start_time = time.time() - while time.time() - start_time < 30: - try: - await session.execute("SELECT now() FROM system.local") - async_reconnect_time = time.time() - start_time - break - except Exception: - await asyncio.sleep(0.5) - - await session.close() - await cluster.shutdown() - - # Get control interface - control = self._get_cassandra_control() - - # Ensure Cassandra is up - assert control.wait_for_cassandra_ready(), "Cassandra not ready at start" - - # Start both tests - import threading - - raw_thread = threading.Thread(target=test_raw_driver) - raw_thread.start() - async_task = asyncio.create_task(test_async_wrapper()) - - # Disable Cassandra after connections are established - await asyncio.sleep(1) - print("Disabling Cassandra...") - control.simulate_outage() - - # Re-enable after a few seconds - await asyncio.sleep(3) - print("Re-enabling Cassandra...") - control.restore_service() - - # Wait for both tests to complete - raw_thread.join(timeout=35) - await asyncio.wait_for(async_task, timeout=35) - - # Compare results - print("\nReconnection times:") - print( - f" Raw driver: {raw_reconnect_time:.1f}s" - if raw_reconnect_time - else " Raw driver: Failed to reconnect" - ) - print( - f" Async wrapper: 
{async_reconnect_time:.1f}s" - if async_reconnect_time - else " Async wrapper: Failed to reconnect" - ) - - # Both should reconnect - assert raw_reconnect_time is not None, "Raw driver failed to reconnect" - assert async_reconnect_time is not None, "Async wrapper failed to reconnect" - - # Times should be similar (within 5 seconds) - time_diff = abs(raw_reconnect_time - async_reconnect_time) - assert time_diff < 5.0, f"Reconnection time difference too large: {time_diff:.1f}s" - print(f"✓ Reconnection times are similar (difference: {time_diff:.1f}s)") diff --git a/tests/integration/test_select_operations.py b/tests/integration/test_select_operations.py deleted file mode 100644 index 3344ff9..0000000 --- a/tests/integration/test_select_operations.py +++ /dev/null @@ -1,142 +0,0 @@ -""" -Integration tests for SELECT query operations. - -This file focuses on advanced SELECT scenarios: consistency levels, large result sets, -concurrent operations, and special query features. Basic SELECT operations have been -moved to test_crud_operations.py. -""" - -import asyncio -import uuid - -import pytest -from cassandra.query import SimpleStatement - - -@pytest.mark.integration -class TestSelectOperations: - """Test advanced SELECT query operations with real Cassandra.""" - - @pytest.mark.asyncio - async def test_select_with_large_result_set(self, cassandra_session): - """ - Test SELECT with large result sets to verify paging and retries work. - - What this tests: - --------------- - 1. Large result sets (1000+ rows) - 2. Automatic paging with fetch_size - 3. Memory-efficient iteration - 4. ALLOW FILTERING queries - - Why this matters: - ---------------- - Large result sets require: - - Paging to avoid OOM - - Streaming for efficiency - - Proper retry handling - - Critical for analytics and - bulk data processing. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - # Insert many rows - # Prepare statement once - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - - insert_tasks = [] - for i in range(1000): - task = cassandra_session.execute( - insert_stmt, - [uuid.uuid4(), f"User {i}", f"user{i}@example.com", 20 + (i % 50)], - ) - insert_tasks.append(task) - - # Execute in batches to avoid overwhelming - for i in range(0, len(insert_tasks), 100): - await asyncio.gather(*insert_tasks[i : i + 100]) - - # Query with small fetch size to test paging - statement = SimpleStatement( - f"SELECT * FROM {users_table} WHERE age >= 20 AND age <= 30 ALLOW FILTERING", - fetch_size=50, - ) - result = await cassandra_session.execute(statement) - - count = 0 - async for row in result: - assert 20 <= row.age <= 30 - count += 1 - - # Should have retrieved multiple pages - assert count > 50 - - @pytest.mark.asyncio - async def test_select_with_limit_and_ordering(self, cassandra_session): - """ - Test SELECT with LIMIT and ordering to ensure retries preserve results. - - What this tests: - --------------- - 1. LIMIT clause respected - 2. Clustering order preserved - 3. Time series queries - 4. Result consistency - - Why this matters: - ---------------- - Ordered queries critical for: - - Time series data - - Top-N queries - - Pagination - - Order must be consistent - across retries. 
- """ - # Create a table with clustering columns for ordering - await cassandra_session.execute("DROP TABLE IF EXISTS time_series") - await cassandra_session.execute( - """ - CREATE TABLE time_series ( - partition_key UUID, - timestamp TIMESTAMP, - value DOUBLE, - PRIMARY KEY (partition_key, timestamp) - ) WITH CLUSTERING ORDER BY (timestamp DESC) - """ - ) - - # Insert time series data - partition_key = uuid.uuid4() - base_time = 1700000000000 # milliseconds - - # Prepare insert statement - insert_stmt = await cassandra_session.prepare( - "INSERT INTO time_series (partition_key, timestamp, value) VALUES (?, ?, ?)" - ) - - for i in range(100): - await cassandra_session.execute( - insert_stmt, - [partition_key, base_time + i * 1000, float(i)], - ) - - # Query with limit - select_stmt = await cassandra_session.prepare( - "SELECT * FROM time_series WHERE partition_key = ? LIMIT 10" - ) - result = await cassandra_session.execute(select_stmt, [partition_key]) - - rows = [] - async for row in result: - rows.append(row) - - # Should get exactly 10 rows in descending order - assert len(rows) == 10 - # Verify descending order (latest timestamps first) - for i in range(1, len(rows)): - assert rows[i - 1].timestamp > rows[i].timestamp diff --git a/tests/integration/test_simple_statements.py b/tests/integration/test_simple_statements.py deleted file mode 100644 index e33f50b..0000000 --- a/tests/integration/test_simple_statements.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -Integration tests for SimpleStatement functionality. - -This test module specifically tests SimpleStatement usage, which is generally -discouraged in favor of prepared statements but may be needed for: -- Setting consistency levels -- Legacy code compatibility -- Dynamic queries that can't be prepared -""" - -import uuid - -import pytest -from cassandra.query import SimpleStatement - - -@pytest.mark.integration -class TestSimpleStatements: - """Test SimpleStatement functionality with real Cassandra.""" - - @pytest.mark.asyncio - async def test_simple_statement_basic_usage(self, cassandra_session): - """ - Test basic SimpleStatement usage with parameters. - - What this tests: - --------------- - 1. SimpleStatement creation - 2. Parameter binding with %s - 3. Query execution - 4. Result retrieval - - Why this matters: - ---------------- - SimpleStatement needed for: - - Legacy code compatibility - - Dynamic queries - - One-off statements - - Must work but prepared - statements preferred. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - # Create a SimpleStatement with parameters - user_id = uuid.uuid4() - insert_stmt = SimpleStatement( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (%s, %s, %s, %s)" - ) - - # Execute with parameters - await cassandra_session.execute(insert_stmt, [user_id, "John Doe", "john@example.com", 30]) - - # Verify with SELECT - select_stmt = SimpleStatement(f"SELECT * FROM {users_table} WHERE id = %s") - result = await cassandra_session.execute(select_stmt, [user_id]) - - row = result.one() - assert row is not None - assert row.name == "John Doe" - assert row.email == "john@example.com" - assert row.age == 30 - - @pytest.mark.asyncio - async def test_simple_statement_without_parameters(self, cassandra_session): - """ - Test SimpleStatement without parameters for queries. - - What this tests: - --------------- - 1. Parameterless queries - 2. Fetch size configuration - 3. Result pagination - 4. 
Multiple row handling - - Why this matters: - ---------------- - Some queries need no params: - - Table scans - - Aggregations - - DDL operations - - SimpleStatement supports - all query options. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - # Insert some test data using prepared statement - insert_prepared = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - - for i in range(5): - await cassandra_session.execute( - insert_prepared, [uuid.uuid4(), f"User {i}", f"user{i}@example.com", 20 + i] - ) - - # Use SimpleStatement for a parameter-less query - select_all = SimpleStatement( - f"SELECT * FROM {users_table}", fetch_size=2 # Test pagination - ) - - result = await cassandra_session.execute(select_all) - rows = list(result) - - # Should have at least 5 rows - assert len(rows) >= 5 - - @pytest.mark.asyncio - async def test_simple_statement_vs_prepared_performance(self, cassandra_session): - """ - Compare SimpleStatement vs PreparedStatement (prepared should be faster). - - What this tests: - --------------- - 1. Performance comparison - 2. Both statement types work - 3. Timing measurements - 4. Prepared advantages - - Why this matters: - ---------------- - Shows why prepared better: - - Query plan caching - - Type validation - - Network efficiency - - Educates on best - practices. - """ - import time - - # Get the unique table name - users_table = cassandra_session._test_users_table - - # Time SimpleStatement execution - simple_stmt = SimpleStatement( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (%s, %s, %s, %s)" - ) - - simple_start = time.perf_counter() - for i in range(10): - await cassandra_session.execute( - simple_stmt, [uuid.uuid4(), f"Simple {i}", f"simple{i}@example.com", i] - ) - simple_time = time.perf_counter() - simple_start - - # Time PreparedStatement execution - prepared_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - - prepared_start = time.perf_counter() - for i in range(10): - await cassandra_session.execute( - prepared_stmt, [uuid.uuid4(), f"Prepared {i}", f"prepared{i}@example.com", i] - ) - prepared_time = time.perf_counter() - prepared_start - - # Log the times for debugging - print(f"SimpleStatement time: {simple_time:.3f}s") - print(f"PreparedStatement time: {prepared_time:.3f}s") - - # PreparedStatement should generally be faster, but we won't assert - # this as it can vary based on network conditions - - @pytest.mark.asyncio - async def test_simple_statement_with_custom_payload(self, cassandra_session): - """ - Test SimpleStatement with custom payload. - - What this tests: - --------------- - 1. Custom payload support - 2. Bytes payload format - 3. Payload passed through - 4. Query still works - - Why this matters: - ---------------- - Custom payloads enable: - - Request tracing - - Application metadata - - Cross-system correlation - - Advanced feature for - observability. 
- """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - # Create SimpleStatement with custom payload - user_id = uuid.uuid4() - stmt = SimpleStatement( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (%s, %s, %s, %s)" - ) - - # Execute with custom payload (payload is passed through to Cassandra) - # Custom payload values must be bytes - custom_payload = {b"application": b"test_suite", b"version": b"1.0"} - await cassandra_session.execute( - stmt, - [user_id, "Payload User", "payload@example.com", 40], - custom_payload=custom_payload, - ) - - # Verify insert worked - result = await cassandra_session.execute( - f"SELECT * FROM {users_table} WHERE id = %s", [user_id] - ) - assert result.one() is not None - - @pytest.mark.asyncio - async def test_simple_statement_batch_not_recommended(self, cassandra_session): - """ - Test that SimpleStatements work in batches but prepared is preferred. - - What this tests: - --------------- - 1. SimpleStatement in batches - 2. Batch execution works - 3. Not recommended pattern - 4. Compatibility maintained - - Why this matters: - ---------------- - Shows anti-pattern: - - Poor performance - - No query plan reuse - - Network inefficient - - Works but educates on - better approaches. - """ - from cassandra.query import BatchStatement, BatchType - - # Get the unique table name - users_table = cassandra_session._test_users_table - - batch = BatchStatement(batch_type=BatchType.LOGGED) - - # Add SimpleStatements to batch (not recommended but should work) - for i in range(3): - stmt = SimpleStatement( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (%s, %s, %s, %s)" - ) - batch.add(stmt, [uuid.uuid4(), f"Batch {i}", f"batch{i}@example.com", i]) - - # Execute batch - await cassandra_session.execute(batch) - - # Verify inserts - result = await cassandra_session.execute(f"SELECT COUNT(*) FROM {users_table}") - assert result.one()[0] >= 3 diff --git a/tests/integration/test_streaming_non_blocking.py b/tests/integration/test_streaming_non_blocking.py deleted file mode 100644 index 4ca51b4..0000000 --- a/tests/integration/test_streaming_non_blocking.py +++ /dev/null @@ -1,341 +0,0 @@ -""" -Integration tests demonstrating that streaming doesn't block the event loop. - -This test proves that while the driver fetches pages in its thread pool, -the asyncio event loop remains free to handle other tasks. 
-""" - -import asyncio -import time -from typing import List - -import pytest - -from async_cassandra import AsyncCluster, StreamConfig - - -class TestStreamingNonBlocking: - """Test that streaming operations don't block the event loop.""" - - @pytest.fixture(autouse=True) - async def setup_test_data(self, cassandra_cluster): - """Create test data for streaming tests.""" - async with AsyncCluster(["localhost"]) as cluster: - async with await cluster.connect() as session: - # Create keyspace and table - await session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS test_streaming - WITH REPLICATION = { - 'class': 'SimpleStrategy', - 'replication_factor': 1 - } - """ - ) - await session.set_keyspace("test_streaming") - - await session.execute( - """ - CREATE TABLE IF NOT EXISTS large_table ( - partition_key INT, - clustering_key INT, - data TEXT, - PRIMARY KEY (partition_key, clustering_key) - ) - """ - ) - - # Insert enough data to ensure multiple pages - # With fetch_size=1000 and 10k rows, we'll have 10 pages - insert_stmt = await session.prepare( - "INSERT INTO large_table (partition_key, clustering_key, data) VALUES (?, ?, ?)" - ) - - tasks = [] - for partition in range(10): - for cluster in range(1000): - # Create some data that takes time to process - data = f"Data for partition {partition}, cluster {cluster}" * 10 - tasks.append(session.execute(insert_stmt, [partition, cluster, data])) - - # Execute in batches - if len(tasks) >= 100: - await asyncio.gather(*tasks) - tasks = [] - - # Execute remaining - if tasks: - await asyncio.gather(*tasks) - - yield - - # Cleanup - await session.execute("DROP KEYSPACE test_streaming") - - async def test_event_loop_not_blocked_during_paging(self, cassandra_cluster): - """ - Test that the event loop remains responsive while pages are being fetched. - - This test runs a streaming query that fetches multiple pages while - simultaneously running a "heartbeat" task that increments a counter - every 10ms. If the event loop was blocked during page fetches, - we would see gaps in the heartbeat counter. 
- """ - heartbeat_count = 0 - heartbeat_times: List[float] = [] - streaming_events: List[tuple[float, str]] = [] - stop_heartbeat = False - - async def heartbeat_task(): - """Increment counter every 10ms to detect event loop blocking.""" - nonlocal heartbeat_count - start_time = time.perf_counter() - - while not stop_heartbeat: - heartbeat_count += 1 - current_time = time.perf_counter() - heartbeat_times.append(current_time - start_time) - await asyncio.sleep(0.01) # 10ms - - async def streaming_task(): - """Stream data and record when pages are fetched.""" - nonlocal streaming_events - - async with AsyncCluster(["localhost"]) as cluster: - async with await cluster.connect() as session: - await session.set_keyspace("test_streaming") - - rows_seen = 0 - pages_fetched = 0 - - def page_callback(page_num: int, rows_in_page: int): - nonlocal pages_fetched - pages_fetched = page_num - current_time = time.perf_counter() - start_time - streaming_events.append((current_time, f"Page {page_num} fetched")) - - # Use small fetch_size to ensure multiple pages - config = StreamConfig(fetch_size=1000, page_callback=page_callback) - - start_time = time.perf_counter() - - async with await session.execute_stream( - "SELECT * FROM large_table", stream_config=config - ) as result: - async for row in result: - rows_seen += 1 - - # Simulate some processing time - await asyncio.sleep(0.001) # 1ms per row - - # Record progress at key points - if rows_seen % 1000 == 0: - current_time = time.perf_counter() - start_time - streaming_events.append( - (current_time, f"Processed {rows_seen} rows") - ) - - return rows_seen, pages_fetched - - # Run both tasks concurrently - heartbeat = asyncio.create_task(heartbeat_task()) - - # Run streaming and measure time - stream_start = time.perf_counter() - rows_processed, pages = await streaming_task() - stream_duration = time.perf_counter() - stream_start - - # Stop heartbeat - stop_heartbeat = True - await heartbeat - - # Analyze results - print("\n=== Event Loop Blocking Test Results ===") - print(f"Total rows processed: {rows_processed:,}") - print(f"Total pages fetched: {pages}") - print(f"Streaming duration: {stream_duration:.2f}s") - print(f"Heartbeat count: {heartbeat_count}") - print(f"Expected heartbeats: ~{int(stream_duration / 0.01)}") - - # Check heartbeat consistency - if len(heartbeat_times) > 1: - # Calculate gaps between heartbeats - heartbeat_gaps = [] - for i in range(1, len(heartbeat_times)): - gap = heartbeat_times[i] - heartbeat_times[i - 1] - heartbeat_gaps.append(gap) - - avg_gap = sum(heartbeat_gaps) / len(heartbeat_gaps) - max_gap = max(heartbeat_gaps) - gaps_over_50ms = sum(1 for gap in heartbeat_gaps if gap > 0.05) - - print("\nHeartbeat Analysis:") - print(f"Average gap: {avg_gap*1000:.1f}ms (target: 10ms)") - print(f"Max gap: {max_gap*1000:.1f}ms") - print(f"Gaps > 50ms: {gaps_over_50ms}") - - # Print streaming events timeline - print("\nStreaming Events Timeline:") - for event_time, event in streaming_events: - print(f" {event_time:.3f}s: {event}") - - # Assertions - assert heartbeat_count > 0, "Heartbeat task didn't run" - - # The average gap should be close to 10ms - # Allow some tolerance for scheduling - assert avg_gap < 0.02, f"Average heartbeat gap too large: {avg_gap*1000:.1f}ms" - - # Max gap shows worst-case blocking - # Even with page fetches, should not block for long - assert max_gap < 0.1, f"Max heartbeat gap too large: {max_gap*1000:.1f}ms" - - # Should have very few large gaps - assert gaps_over_50ms < 5, f"Too many large gaps: 
{gaps_over_50ms}" - - # Verify streaming completed successfully - assert rows_processed == 10000, f"Expected 10000 rows, got {rows_processed}" - assert pages >= 10, f"Expected at least 10 pages, got {pages}" - - async def test_concurrent_queries_during_streaming(self, cassandra_cluster): - """ - Test that other queries can execute while streaming is in progress. - - This proves that the thread pool isn't completely blocked by streaming. - """ - async with AsyncCluster(["localhost"]) as cluster: - async with await cluster.connect() as session: - await session.set_keyspace("test_streaming") - - # Prepare a simple query - count_stmt = await session.prepare( - "SELECT COUNT(*) FROM large_table WHERE partition_key = ?" - ) - - query_times: List[float] = [] - queries_completed = 0 - - async def run_concurrent_queries(): - """Run queries every 100ms during streaming.""" - nonlocal queries_completed - - for i in range(20): # 20 queries over 2 seconds - start = time.perf_counter() - await session.execute(count_stmt, [i % 10]) - duration = time.perf_counter() - start - query_times.append(duration) - queries_completed += 1 - - # Log slow queries - if duration > 0.1: - print(f"Slow query {i}: {duration:.3f}s") - - await asyncio.sleep(0.1) # 100ms between queries - - async def stream_large_dataset(): - """Stream the entire table.""" - config = StreamConfig(fetch_size=1000) - rows = 0 - - async with await session.execute_stream( - "SELECT * FROM large_table", stream_config=config - ) as result: - async for row in result: - rows += 1 - # Minimal processing - if rows % 2000 == 0: - await asyncio.sleep(0.001) - - return rows - - # Run both concurrently - streaming_task = asyncio.create_task(stream_large_dataset()) - queries_task = asyncio.create_task(run_concurrent_queries()) - - # Wait for both to complete - rows_streamed, _ = await asyncio.gather(streaming_task, queries_task) - - # Analyze results - print("\n=== Concurrent Queries Test Results ===") - print(f"Rows streamed: {rows_streamed:,}") - print(f"Concurrent queries completed: {queries_completed}") - - if query_times: - avg_query_time = sum(query_times) / len(query_times) - max_query_time = max(query_times) - - print(f"Average query time: {avg_query_time*1000:.1f}ms") - print(f"Max query time: {max_query_time*1000:.1f}ms") - - # Assertions - assert queries_completed >= 15, "Not enough queries completed" - assert avg_query_time < 0.1, f"Queries too slow: {avg_query_time:.3f}s" - - # Even the slowest query shouldn't be terribly slow - assert max_query_time < 0.5, f"Max query time too high: {max_query_time:.3f}s" - - async def test_multiple_streams_concurrent(self, cassandra_cluster): - """ - Test that multiple streaming operations can run concurrently. - - This demonstrates that streaming doesn't monopolize the thread pool. - """ - async with AsyncCluster(["localhost"]) as cluster: - async with await cluster.connect() as session: - await session.set_keyspace("test_streaming") - - async def stream_partition(partition: int) -> tuple[int, float]: - """Stream a specific partition.""" - config = StreamConfig(fetch_size=500) - rows = 0 - start = time.perf_counter() - - stmt = await session.prepare( - "SELECT * FROM large_table WHERE partition_key = ?" 
- ) - - async with await session.execute_stream( - stmt, [partition], stream_config=config - ) as result: - async for row in result: - rows += 1 - - duration = time.perf_counter() - start - return rows, duration - - # Start multiple streams concurrently - print("\n=== Multiple Concurrent Streams Test ===") - start_time = time.perf_counter() - - # Stream 5 partitions concurrently - tasks = [stream_partition(i) for i in range(5)] - - results = await asyncio.gather(*tasks) - - total_duration = time.perf_counter() - start_time - - # Analyze results - total_rows = sum(rows for rows, _ in results) - individual_durations = [duration for _, duration in results] - - print(f"Total rows streamed: {total_rows:,}") - print(f"Total duration: {total_duration:.2f}s") - print(f"Individual stream durations: {[f'{d:.2f}s' for d in individual_durations]}") - - # If streams were serialized, total duration would be sum of individual - sum_durations = sum(individual_durations) - concurrency_factor = sum_durations / total_duration - - print(f"Sum of individual durations: {sum_durations:.2f}s") - print(f"Concurrency factor: {concurrency_factor:.1f}x") - - # Assertions - assert total_rows == 5000, f"Expected 5000 rows total, got {total_rows}" - - # Should show significant concurrency (at least 2x) - assert ( - concurrency_factor > 2.0 - ), f"Insufficient concurrency: {concurrency_factor:.1f}x" - - # Total time should be much less than sum of individual times - assert total_duration < sum_durations * 0.7, "Streams appear to be serialized" diff --git a/tests/integration/test_streaming_operations.py b/tests/integration/test_streaming_operations.py deleted file mode 100644 index 530bed4..0000000 --- a/tests/integration/test_streaming_operations.py +++ /dev/null @@ -1,533 +0,0 @@ -""" -Integration tests for streaming functionality. - -Demonstrates CRITICAL context manager usage for streaming operations -to prevent memory leaks. -""" - -import asyncio -import uuid - -import pytest - -from async_cassandra import StreamConfig, create_streaming_statement - - -@pytest.mark.integration -@pytest.mark.asyncio -class TestStreamingIntegration: - """Test streaming operations with real Cassandra using proper context managers.""" - - async def test_basic_streaming(self, cassandra_session): - """ - Test basic streaming functionality with context managers. - - What this tests: - --------------- - 1. Basic streaming works - 2. Context manager usage - 3. Row iteration - 4. Total rows tracked - - Why this matters: - ---------------- - Context managers ensure: - - Resources cleaned up - - No memory leaks - - Proper error handling - - CRITICAL for production - streaming usage. 
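The pattern these assertions guard is easiest to see outside of a test. Below is a minimal sketch of context-managed streaming, assuming the same AsyncCluster/StreamConfig API exercised in these tests; the host, keyspace, and table names are placeholders.

import asyncio

from async_cassandra import AsyncCluster, StreamConfig


async def count_rows() -> int:
    """Count all rows in a table with context-managed streaming (placeholder names)."""
    async with AsyncCluster(["localhost"]) as cluster:
        async with await cluster.connect() as session:
            await session.set_keyspace("my_keyspace")  # hypothetical keyspace

            config = StreamConfig(fetch_size=1000)
            total = 0
            # The `async with` on the stream guarantees cleanup even if
            # iteration raises or the caller stops early.
            async with await session.execute_stream(
                "SELECT * FROM my_table", stream_config=config  # hypothetical table
            ) as result:
                async for _row in result:
                    total += 1
            return total


if __name__ == "__main__":
    print(asyncio.run(count_rows()))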
- """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - try: - # Insert test data - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - - # Insert 100 test records - tasks = [] - for i in range(100): - task = cassandra_session.execute( - insert_stmt, [uuid.uuid4(), f"User {i}", f"user{i}@test.com", 20 + (i % 50)] - ) - tasks.append(task) - - await asyncio.gather(*tasks) - - # Stream through all users WITH CONTEXT MANAGER - stream_config = StreamConfig(fetch_size=20) - - # CRITICAL: Use context manager to prevent memory leaks - async with await cassandra_session.execute_stream( - f"SELECT * FROM {users_table}", stream_config=stream_config - ) as result: - # Count rows - row_count = 0 - async for row in result: - assert hasattr(row, "id") - assert hasattr(row, "name") - assert hasattr(row, "email") - assert hasattr(row, "age") - row_count += 1 - - assert row_count >= 100 # At least the records we inserted - assert result.total_rows_fetched >= 100 - - except Exception as e: - pytest.fail(f"Streaming test failed: {e}") - - async def test_page_based_streaming(self, cassandra_session): - """ - Test streaming by pages with proper context managers. - - What this tests: - --------------- - 1. Page-by-page iteration - 2. Fetch size respected - 3. Multiple pages handled - 4. Filter conditions work - - Why this matters: - ---------------- - Page iteration enables: - - Batch processing - - Progress tracking - - Memory control - - Essential for ETL and - bulk operations. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - try: - # Insert test data - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - - # Insert 50 test records - for i in range(50): - await cassandra_session.execute( - insert_stmt, [uuid.uuid4(), f"PageUser {i}", f"pageuser{i}@test.com", 25] - ) - - # Stream by pages WITH CONTEXT MANAGER - stream_config = StreamConfig(fetch_size=10) - - async with await cassandra_session.execute_stream( - f"SELECT * FROM {users_table} WHERE age = 25 ALLOW FILTERING", - stream_config=stream_config, - ) as result: - page_count = 0 - total_rows = 0 - - async for page in result.pages(): - page_count += 1 - total_rows += len(page) - assert len(page) <= 10 # Should not exceed fetch_size - - # Verify all rows in page have age = 25 - for row in page: - assert row.age == 25 - - assert page_count >= 5 # Should have multiple pages - assert total_rows >= 50 - - except Exception as e: - pytest.fail(f"Page-based streaming test failed: {e}") - - async def test_streaming_with_progress_callback(self, cassandra_session): - """ - Test streaming with progress callback using context managers. - - What this tests: - --------------- - 1. Progress callbacks fire - 2. Page numbers accurate - 3. Row counts correct - 4. Callback integration - - Why this matters: - ---------------- - Progress tracking enables: - - User feedback - - Long operation monitoring - - Cancellation decisions - - Critical for interactive - applications. 
- """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - try: - progress_calls = [] - - def progress_callback(page_num, row_count): - progress_calls.append((page_num, row_count)) - - stream_config = StreamConfig(fetch_size=15, page_callback=progress_callback) - - # Use context manager for streaming - async with await cassandra_session.execute_stream( - f"SELECT * FROM {users_table} LIMIT 50", stream_config=stream_config - ) as result: - # Consume the stream - row_count = 0 - async for row in result: - row_count += 1 - - # Should have received progress callbacks - assert len(progress_calls) > 0 - assert all(isinstance(call[0], int) for call in progress_calls) # page numbers - assert all(isinstance(call[1], int) for call in progress_calls) # row counts - - except Exception as e: - pytest.fail(f"Progress callback test failed: {e}") - - async def test_streaming_statement_helper(self, cassandra_session): - """ - Test using the streaming statement helper with context managers. - - What this tests: - --------------- - 1. Helper function works - 2. Statement configuration - 3. LIMIT respected - 4. Page tracking - - Why this matters: - ---------------- - Helper functions simplify: - - Statement creation - - Config management - - Common patterns - - Improves developer - experience. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - try: - statement = create_streaming_statement( - f"SELECT * FROM {users_table} LIMIT 30", fetch_size=10 - ) - - # Use context manager - async with await cassandra_session.execute_stream(statement) as result: - rows = [] - async for row in result: - rows.append(row) - - assert len(rows) <= 30 # Respects LIMIT - assert result.page_number >= 1 - - except Exception as e: - pytest.fail(f"Streaming statement helper test failed: {e}") - - async def test_streaming_with_parameters(self, cassandra_session): - """ - Test streaming with parameterized queries using context managers. - - What this tests: - --------------- - 1. Prepared statements work - 2. Parameters bound correctly - 3. Filtering accurate - 4. Type safety maintained - - Why this matters: - ---------------- - Parameterized queries: - - Prevent injection - - Improve performance - - Type checking - - Security and performance - critical. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - try: - # Insert some specific test data - user_id = uuid.uuid4() - # Prepare statement first - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - await cassandra_session.execute( - insert_stmt, [user_id, "StreamTest", "streamtest@test.com", 99] - ) - - # Stream with parameters - prepare statement first - stream_stmt = await cassandra_session.prepare( - f"SELECT * FROM {users_table} WHERE age = ? ALLOW FILTERING" - ) - - # Use context manager - async with await cassandra_session.execute_stream( - stream_stmt, - parameters=[99], - stream_config=StreamConfig(fetch_size=5), - ) as result: - found_user = False - async for row in result: - if str(row.id) == str(user_id): - found_user = True - assert row.name == "StreamTest" - assert row.age == 99 - - assert found_user - - except Exception as e: - pytest.fail(f"Parameterized streaming test failed: {e}") - - async def test_streaming_empty_result(self, cassandra_session): - """ - Test streaming with empty result set using context managers. - - What this tests: - --------------- - 1. 
Empty results handled - 2. No errors on empty - 3. Counts are zero - 4. Context still works - - Why this matters: - ---------------- - Empty results common: - - No matching data - - Filtered queries - - Edge conditions - - Must handle gracefully - without errors. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - try: - # Use context manager even for empty results - async with await cassandra_session.execute_stream( - f"SELECT * FROM {users_table} WHERE age = 999 ALLOW FILTERING" - ) as result: - rows = [] - async for row in result: - rows.append(row) - - assert len(rows) == 0 - assert result.total_rows_fetched == 0 - - except Exception as e: - pytest.fail(f"Empty result streaming test failed: {e}") - - async def test_streaming_vs_regular_results(self, cassandra_session): - """ - Test that streaming and regular execute return same data. - - What this tests: - --------------- - 1. Results identical - 2. No data loss - 3. Same row count - 4. ID consistency - - Why this matters: - ---------------- - Streaming must be: - - Accurate alternative - - No data corruption - - Reliable results - - Ensures streaming is - trustworthy. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - try: - query = f"SELECT * FROM {users_table} LIMIT 20" - - # Get results with regular execute - regular_result = await cassandra_session.execute(query) - regular_rows = [] - async for row in regular_result: - regular_rows.append(row) - - # Get results with streaming USING CONTEXT MANAGER - async with await cassandra_session.execute_stream(query) as stream_result: - stream_rows = [] - async for row in stream_result: - stream_rows.append(row) - - # Should have same number of rows - assert len(regular_rows) == len(stream_rows) - - # Convert to sets of IDs for comparison (order might differ) - regular_ids = {str(row.id) for row in regular_rows} - stream_ids = {str(row.id) for row in stream_rows} - - assert regular_ids == stream_ids - - except Exception as e: - pytest.fail(f"Streaming vs regular comparison failed: {e}") - - async def test_streaming_max_pages_limit(self, cassandra_session): - """ - Test streaming with maximum pages limit using context managers. - - What this tests: - --------------- - 1. Max pages enforced - 2. Stops at limit - 3. Row count limited - 4. Page count accurate - - Why this matters: - ---------------- - Page limits enable: - - Resource control - - Preview functionality - - Sampling data - - Prevents runaway - queries. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - try: - stream_config = StreamConfig(fetch_size=5, max_pages=2) # Limit to 2 pages only - - # Use context manager - async with await cassandra_session.execute_stream( - f"SELECT * FROM {users_table}", stream_config=stream_config - ) as result: - rows = [] - async for row in result: - rows.append(row) - - # Should stop after 2 pages max - assert len(rows) <= 10 # 2 pages * 5 rows per page - assert result.page_number <= 2 - - except Exception as e: - pytest.fail(f"Max pages limit test failed: {e}") - - async def test_streaming_early_exit(self, cassandra_session): - """ - Test early exit from streaming with proper cleanup. - - What this tests: - --------------- - 1. Break works correctly - 2. Cleanup still happens - 3. Partial results OK - 4. 
No resource leaks - - Why this matters: - ---------------- - Early exit common for: - - Finding first match - - User cancellation - - Error conditions - - Must clean up properly - in all cases. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - try: - # Insert enough data to have multiple pages - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - - for i in range(50): - await cassandra_session.execute( - insert_stmt, [uuid.uuid4(), f"EarlyExit {i}", f"early{i}@test.com", 30] - ) - - stream_config = StreamConfig(fetch_size=10) - - # Context manager ensures cleanup even with early exit - async with await cassandra_session.execute_stream( - f"SELECT * FROM {users_table} WHERE age = 30 ALLOW FILTERING", - stream_config=stream_config, - ) as result: - count = 0 - async for row in result: - count += 1 - if count >= 15: # Exit early - break - - assert count == 15 - # Context manager ensures cleanup happens here - - except Exception as e: - pytest.fail(f"Early exit test failed: {e}") - - async def test_streaming_exception_handling(self, cassandra_session): - """ - Test exception handling during streaming with context managers. - - What this tests: - --------------- - 1. Exceptions propagate - 2. Cleanup on error - 3. Context manager robust - 4. No hanging resources - - Why this matters: - ---------------- - Error handling critical: - - Processing errors - - Network failures - - Application bugs - - Resources must be freed - even on exceptions. - """ - # Get the unique table name - users_table = cassandra_session._test_users_table - - class TestError(Exception): - pass - - try: - # Insert test data - insert_stmt = await cassandra_session.prepare( - f"INSERT INTO {users_table} (id, name, email, age) VALUES (?, ?, ?, ?)" - ) - - for i in range(20): - await cassandra_session.execute( - insert_stmt, [uuid.uuid4(), f"ExceptionTest {i}", f"exc{i}@test.com", 40] - ) - - # Test that context manager cleans up even on exception - with pytest.raises(TestError): - async with await cassandra_session.execute_stream( - f"SELECT * FROM {users_table} WHERE age = 40 ALLOW FILTERING" - ) as result: - count = 0 - async for row in result: - count += 1 - if count >= 10: - raise TestError("Simulated error during streaming") - - # Context manager should have cleaned up despite exception - - except TestError: - # This is expected - re-raise it for pytest - raise - except Exception as e: - pytest.fail(f"Exception handling test failed: {e}") diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index ec673f9..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,171 +0,0 @@ -"""Test utilities for isolating tests and managing test resources.""" - -import asyncio -import uuid -from typing import Optional, Set - -# Track created keyspaces for cleanup -_created_keyspaces: Set[str] = set() - - -def generate_unique_keyspace(prefix: str = "test") -> str: - """Generate a unique keyspace name for test isolation.""" - unique_id = str(uuid.uuid4()).replace("-", "")[:8] - keyspace = f"{prefix}_{unique_id}" - _created_keyspaces.add(keyspace) - return keyspace - - -def generate_unique_table(prefix: str = "table") -> str: - """Generate a unique table name for test isolation.""" - unique_id = str(uuid.uuid4()).replace("-", "")[:8] - return f"{prefix}_{unique_id}" - - -async def create_test_table( - session, table_name: Optional[str] = None, schema: str = "(id int PRIMARY KEY, data text)" -) -> str: - 
"""Create a test table with the given schema and register it for cleanup.""" - if table_name is None: - table_name = generate_unique_table() - - await session.execute(f"CREATE TABLE IF NOT EXISTS {table_name} {schema}") - - # Register table for cleanup if session tracks created tables - if hasattr(session, "_created_tables"): - session._created_tables.append(table_name) - - return table_name - - -async def create_test_keyspace(session, keyspace: Optional[str] = None) -> str: - """Create a test keyspace with proper replication.""" - if keyspace is None: - keyspace = generate_unique_keyspace() - - await session.execute( - f""" - CREATE KEYSPACE IF NOT EXISTS {keyspace} - WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}} - """ - ) - return keyspace - - -async def cleanup_keyspace(session, keyspace: str) -> None: - """Clean up a test keyspace.""" - try: - await session.execute(f"DROP KEYSPACE IF EXISTS {keyspace}") - _created_keyspaces.discard(keyspace) - except Exception: - # Ignore cleanup errors - pass - - -async def cleanup_all_test_keyspaces(session) -> None: - """Clean up all tracked test keyspaces.""" - for keyspace in list(_created_keyspaces): - await cleanup_keyspace(session, keyspace) - - -def get_test_timeout(base_timeout: float = 5.0) -> float: - """Get appropriate timeout for tests based on environment.""" - # Increase timeout in CI environments or when running under coverage - import os - - if os.environ.get("CI") or os.environ.get("COVERAGE_RUN"): - return base_timeout * 3 - return base_timeout - - -async def wait_for_schema_agreement(session, timeout: float = 10.0) -> None: - """Wait for schema agreement across the cluster.""" - start_time = asyncio.get_event_loop().time() - while asyncio.get_event_loop().time() - start_time < timeout: - try: - result = await session.execute("SELECT schema_version FROM system.local") - if result: - return - except Exception: - pass - await asyncio.sleep(0.1) - - -async def ensure_keyspace_exists(session, keyspace: str) -> None: - """Ensure a keyspace exists before using it.""" - await session.execute( - f""" - CREATE KEYSPACE IF NOT EXISTS {keyspace} - WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}} - """ - ) - await wait_for_schema_agreement(session) - - -async def ensure_table_exists(session, keyspace: str, table: str, schema: str) -> None: - """Ensure a table exists with the given schema.""" - await ensure_keyspace_exists(session, keyspace) - await session.execute(f"USE {keyspace}") - await session.execute(f"CREATE TABLE IF NOT EXISTS {table} {schema}") - await wait_for_schema_agreement(session) - - -def get_container_timeout() -> int: - """Get timeout for container operations.""" - import os - - # Longer timeout in CI environments - if os.environ.get("CI"): - return 120 - return 60 - - -async def run_with_timeout(coro, timeout: float): - """Run a coroutine with a timeout.""" - try: - return await asyncio.wait_for(coro, timeout=timeout) - except asyncio.TimeoutError: - raise TimeoutError(f"Operation timed out after {timeout} seconds") - - -class TestTableManager: - """Context manager for creating and cleaning up test tables.""" - - def __init__(self, session, keyspace: Optional[str] = None, use_shared_keyspace: bool = False): - self.session = session - self.keyspace = keyspace or generate_unique_keyspace() - self.tables = [] - self.use_shared_keyspace = use_shared_keyspace - - async def __aenter__(self): - if not self.use_shared_keyspace: - await create_test_keyspace(self.session, 
self.keyspace) - await self.session.execute(f"USE {self.keyspace}") - # If using shared keyspace, assume it's already set on the session - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - # Clean up tables - for table in self.tables: - try: - await self.session.execute(f"DROP TABLE IF EXISTS {table}") - except Exception: - pass - - # Only clean up keyspace if we created it - if not self.use_shared_keyspace: - try: - await cleanup_keyspace(self.session, self.keyspace) - except Exception: - pass - - async def create_table( - self, table_name: Optional[str] = None, schema: str = "(id int PRIMARY KEY, data text)" - ) -> str: - """Create a test table with the given schema.""" - if table_name is None: - table_name = generate_unique_table() - - await self.session.execute(f"CREATE TABLE IF NOT EXISTS {table_name} {schema}") - self.tables.append(table_name) - return table_name diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py deleted file mode 100644 index cfaf7e1..0000000 --- a/tests/unit/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Unit tests for async-cassandra.""" diff --git a/tests/unit/test_async_wrapper.py b/tests/unit/test_async_wrapper.py deleted file mode 100644 index e04a68b..0000000 --- a/tests/unit/test_async_wrapper.py +++ /dev/null @@ -1,552 +0,0 @@ -"""Core async wrapper functionality tests. - -This module consolidates tests for the fundamental async wrapper components -including AsyncCluster, AsyncSession, and base functionality. - -Test Organization: -================== -1. TestAsyncContextManageable - Tests the base async context manager mixin -2. TestAsyncCluster - Tests cluster initialization, connection, and lifecycle -3. TestAsyncSession - Tests session operations (queries, prepare, keyspace) - -Key Testing Patterns: -==================== -- Uses mocks extensively to isolate async wrapper behavior from driver -- Tests both success and error paths -- Verifies context manager cleanup happens correctly -- Ensures proper parameter passing to underlying driver -""" - -from unittest.mock import AsyncMock, MagicMock, Mock, patch - -import pytest -from cassandra.auth import PlainTextAuthProvider -from cassandra.cluster import ResponseFuture - -from async_cassandra import AsyncCassandraSession as AsyncSession -from async_cassandra import AsyncCluster -from async_cassandra.base import AsyncContextManageable -from async_cassandra.result import AsyncResultSet - - -class TestAsyncContextManageable: - """Test the async context manager mixin functionality.""" - - @pytest.mark.core - @pytest.mark.quick - async def test_async_context_manager(self): - """ - Test basic async context manager functionality. - - What this tests: - --------------- - 1. AsyncContextManageable provides proper async context manager protocol - 2. __aenter__ is called when entering the context - 3. __aexit__ is called when exiting the context - 4. The object is properly returned from __aenter__ - - Why this matters: - ---------------- - Many of our classes (AsyncCluster, AsyncSession) inherit from this base - class to provide 'async with' functionality. This ensures resource cleanup - happens automatically when leaving the context. 
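Applied to the wrapper's own classes, the guarantee under test looks like this. A hedged sketch, not the library's implementation: the exception escapes, but the session and cluster context managers have already performed their cleanup by the time it is caught (the contact point is a placeholder).

import asyncio

from async_cassandra import AsyncCluster


async def run_with_failure() -> None:
    """Show that cleanup still happens when work inside the context fails."""
    try:
        async with AsyncCluster(["localhost"]) as cluster:  # placeholder contact point
            async with await cluster.connect() as session:
                await session.execute("SELECT release_version FROM system.local")
                raise RuntimeError("simulated application error")
    except RuntimeError:
        # By the time we get here, the session has been closed and the
        # cluster shut down by their respective context managers.
        pass


if __name__ == "__main__":
    asyncio.run(run_with_failure())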
- """ - - # Create a test implementation that tracks enter/exit calls - class TestClass(AsyncContextManageable): - entered = False - exited = False - - async def __aenter__(self): - self.entered = True - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - self.exited = True - - # Test the context manager flow - async with TestClass() as obj: - # Inside context: should be entered but not exited - assert obj.entered - assert not obj.exited - - # Outside context: should be exited - assert obj.exited - - @pytest.mark.core - async def test_context_manager_with_exception(self): - """ - Test context manager handles exceptions properly. - - What this tests: - --------------- - 1. __aexit__ receives exception information when exception occurs - 2. Exception type, value, and traceback are passed correctly - 3. Returning False from __aexit__ propagates the exception - 4. The exception is not suppressed unless explicitly handled - - Why this matters: - ---------------- - Ensures that errors in async operations (like connection failures) - are properly propagated and that cleanup still happens even when - exceptions occur. This prevents resource leaks in error scenarios. - """ - - class TestClass(AsyncContextManageable): - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - # Verify exception info is passed correctly - assert exc_type is ValueError - assert str(exc_val) == "test error" - return False # Don't suppress exception - let it propagate - - # Verify the exception is still raised after __aexit__ - with pytest.raises(ValueError, match="test error"): - async with TestClass(): - raise ValueError("test error") - - -class TestAsyncCluster: - """ - Test AsyncCluster core functionality. - - AsyncCluster is the entry point for establishing Cassandra connections. - It wraps the driver's Cluster object to provide async operations. - """ - - @pytest.mark.core - @pytest.mark.quick - def test_init_defaults(self): - """ - Test AsyncCluster initialization with default values. - - What this tests: - --------------- - 1. AsyncCluster can be created without any parameters - 2. Default values are properly applied - 3. Internal state is initialized correctly (_cluster, _close_lock) - - Why this matters: - ---------------- - Users often create clusters with minimal configuration. This ensures - the defaults work correctly and the cluster is usable out of the box. - """ - cluster = AsyncCluster() - # Verify internal driver cluster was created - assert cluster._cluster is not None - # Verify lock for thread-safe close operations exists - assert cluster._close_lock is not None - - @pytest.mark.core - def test_init_custom_values(self): - """ - Test AsyncCluster initialization with custom values. - - What this tests: - --------------- - 1. Custom contact points are accepted - 2. Non-default port can be specified - 3. Authentication providers work correctly - 4. Executor thread pool size can be customized - 5. 
All parameters are properly passed to underlying driver - - Why this matters: - ---------------- - Production deployments often require custom configuration: - - Different Cassandra nodes (contact_points) - - Non-standard ports for security - - Authentication for secure clusters - - Thread pool tuning for performance - """ - # Create auth provider for secure clusters - auth_provider = PlainTextAuthProvider(username="user", password="pass") - - # Initialize with custom configuration - cluster = AsyncCluster( - contact_points=["192.168.1.1", "192.168.1.2"], - port=9043, # Non-default port - auth_provider=auth_provider, - executor_threads=16, # Larger thread pool for high concurrency - ) - - # Verify cluster was created with our settings - assert cluster._cluster is not None - # Verify thread pool size was applied - assert cluster._cluster.executor._max_workers == 16 - - @pytest.mark.core - @patch("async_cassandra.cluster.Cluster", new_callable=MagicMock) - async def test_connect(self, mock_cluster_class): - """ - Test cluster connection. - - What this tests: - --------------- - 1. connect() returns an AsyncSession instance - 2. The underlying driver's connect() is called - 3. The returned session wraps the driver's session - 4. Connection can be established without specifying keyspace - - Why this matters: - ---------------- - This is the primary way users establish database connections. - The test ensures our async wrapper properly delegates to the - synchronous driver and wraps the result for async operations. - - Implementation note: - ------------------- - We mock the driver's Cluster to isolate our wrapper's behavior - from actual network operations. - """ - # Set up mocks - mock_cluster = mock_cluster_class.return_value - mock_cluster.protocol_version = 5 # Mock protocol version - mock_session = Mock() - mock_cluster.connect.return_value = mock_session - - # Test connection - cluster = AsyncCluster() - session = await cluster.connect() - - # Verify we get an async wrapper - assert isinstance(session, AsyncSession) - # Verify it wraps the driver's session - assert session._session == mock_session - # Verify driver's connect was called - mock_cluster.connect.assert_called_once() - - @pytest.mark.core - @patch("async_cassandra.cluster.Cluster", new_callable=MagicMock) - async def test_shutdown(self, mock_cluster_class): - """ - Test cluster shutdown. - - What this tests: - --------------- - 1. shutdown() can be called explicitly - 2. The underlying driver's shutdown() is called - 3. Resources are properly cleaned up - - Why this matters: - ---------------- - Proper shutdown is critical to: - - Release network connections - - Stop background threads - - Prevent resource leaks - - Allow clean application termination - """ - mock_cluster = mock_cluster_class.return_value - - cluster = AsyncCluster() - await cluster.shutdown() - - # Verify driver's shutdown was called - mock_cluster.shutdown.assert_called_once() - - @pytest.mark.core - @pytest.mark.critical - async def test_context_manager(self): - """ - Test AsyncCluster as context manager. - - What this tests: - --------------- - 1. AsyncCluster can be used with 'async with' statement - 2. Cluster is accessible within the context - 3. shutdown() is automatically called on exit - 4. Cleanup happens even if not explicitly called - - Why this matters: - ---------------- - Context managers are the recommended pattern for resource management. 
- They ensure cleanup happens automatically, preventing resource leaks - even if the user forgets to call shutdown() or if exceptions occur. - - Example usage: - ------------- - async with AsyncCluster() as cluster: - session = await cluster.connect() - # ... use session ... - # cluster.shutdown() called automatically here - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = mock_cluster_class.return_value - - # Use cluster as context manager - async with AsyncCluster() as cluster: - # Verify cluster is accessible inside context - assert cluster._cluster == mock_cluster - - # Verify shutdown was called when exiting context - mock_cluster.shutdown.assert_called_once() - - -class TestAsyncSession: - """ - Test AsyncSession core functionality. - - AsyncSession is the main interface for executing queries. It wraps - the driver's Session object to provide async query execution. - """ - - @pytest.mark.core - @pytest.mark.quick - def test_init(self): - """ - Test AsyncSession initialization. - - What this tests: - --------------- - 1. AsyncSession properly stores the wrapped session - 2. No additional initialization is required - 3. The wrapper is lightweight (thin wrapper pattern) - - Why this matters: - ---------------- - The session wrapper should be minimal overhead. This test - ensures we're not doing unnecessary work during initialization - and that the wrapper maintains a reference to the driver session. - """ - mock_session = Mock() - async_session = AsyncSession(mock_session) - # Verify the wrapper stores the driver session - assert async_session._session == mock_session - - @pytest.mark.core - @pytest.mark.critical - async def test_execute_simple_query(self): - """ - Test executing a simple query. - - What this tests: - --------------- - 1. Basic query execution works - 2. execute() converts sync driver operations to async - 3. Results are wrapped in AsyncResultSet - 4. The AsyncResultHandler is used to manage callbacks - - Why this matters: - ---------------- - This is the most fundamental operation - executing a SELECT query. - The test verifies our async/await wrapper correctly: - - Calls driver's execute_async (not execute) - - Handles the ResponseFuture with callbacks - - Returns results in an async-friendly format - - Implementation details: - ---------------------- - - We mock AsyncResultHandler to avoid callback complexity - - The real implementation registers callbacks on ResponseFuture - - Results are delivered asynchronously via the event loop - """ - # Set up driver mocks - mock_session = Mock() - mock_future = Mock(spec=ResponseFuture) - mock_future.has_more_pages = False - mock_session.execute_async.return_value = mock_future - - async_session = AsyncSession(mock_session) - - # Mock the result handler to simulate query completion - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_result = AsyncResultSet([{"id": 1, "name": "test"}]) - mock_handler.get_result = AsyncMock(return_value=mock_result) - mock_handler_class.return_value = mock_handler - - # Execute query - result = await async_session.execute("SELECT * FROM users") - - # Verify result type and that async execution was used - assert isinstance(result, AsyncResultSet) - mock_session.execute_async.assert_called_once() - - @pytest.mark.core - async def test_execute_with_parameters(self): - """ - Test executing query with parameters. - - What this tests: - --------------- - 1. Parameterized queries work correctly - 2. 
Parameters are passed through to the driver - 3. Both query string and parameters reach execute_async - - Why this matters: - ---------------- - Parameterized queries are essential for: - - Preventing SQL injection attacks - - Better performance (query plan caching) - - Cleaner code (no string concatenation) - - The test ensures parameters aren't lost in the async wrapper. - - Note: - ----- - Parameters can be passed as list [123] or tuple (123,) - This test uses a list, but both should work. - """ - mock_session = Mock() - mock_future = Mock(spec=ResponseFuture) - mock_session.execute_async.return_value = mock_future - - async_session = AsyncSession(mock_session) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_result = AsyncResultSet([]) - mock_handler.get_result = AsyncMock(return_value=mock_result) - mock_handler_class.return_value = mock_handler - - # Execute parameterized query - await async_session.execute("SELECT * FROM users WHERE id = ?", [123]) - - # Verify both query and parameters were passed correctly - call_args = mock_session.execute_async.call_args - assert call_args[0][0] == "SELECT * FROM users WHERE id = ?" - assert call_args[0][1] == [123] - - @pytest.mark.core - async def test_prepare(self): - """ - Test preparing statements. - - What this tests: - --------------- - 1. prepare() returns a PreparedStatement - 2. The query string is passed to driver's prepare() - 3. The prepared statement can be used for execution - - Why this matters: - ---------------- - Prepared statements are crucial for production use: - - Better performance (cached query plans) - - Type safety and validation - - Protection against injection - - Required by our coding standards - - The wrapper must properly handle statement preparation - to maintain these benefits. - - Note: - ----- - The second parameter (None) is for custom prepare options, - which we pass through unchanged. - """ - mock_session = Mock() - mock_prepared = Mock() - mock_session.prepare.return_value = mock_prepared - - async_session = AsyncSession(mock_session) - - # Prepare a parameterized statement - prepared = await async_session.prepare("SELECT * FROM users WHERE id = ?") - - # Verify we get the prepared statement back - assert prepared == mock_prepared - # Verify driver's prepare was called with correct arguments - mock_session.prepare.assert_called_once_with("SELECT * FROM users WHERE id = ?", None) - - @pytest.mark.core - async def test_close(self): - """ - Test closing session. - - What this tests: - --------------- - 1. close() can be called explicitly - 2. The underlying session's shutdown() is called - 3. Resources are cleaned up properly - - Why this matters: - ---------------- - Sessions hold resources like: - - Connection pools - - Prepared statement cache - - Background threads - - Proper cleanup prevents resource leaks and ensures - graceful application shutdown. - """ - mock_session = Mock() - async_session = AsyncSession(mock_session) - - await async_session.close() - - # Verify driver's shutdown was called - mock_session.shutdown.assert_called_once() - - @pytest.mark.core - @pytest.mark.critical - async def test_context_manager(self): - """ - Test AsyncSession as context manager. - - What this tests: - --------------- - 1. AsyncSession supports 'async with' statement - 2. Session is accessible within the context - 3. 
shutdown() is called automatically on exit - - Why this matters: - ---------------- - Context managers ensure cleanup even with exceptions. - This is the recommended pattern for session usage: - - async with cluster.connect() as session: - await session.execute(...) - # session.close() called automatically - - This prevents resource leaks from forgotten close() calls. - """ - mock_session = Mock() - - async with AsyncSession(mock_session) as session: - # Verify session is accessible in context - assert session._session == mock_session - - # Verify cleanup happened on exit - mock_session.shutdown.assert_called_once() - - @pytest.mark.core - async def test_set_keyspace(self): - """ - Test setting keyspace. - - What this tests: - --------------- - 1. set_keyspace() executes a USE statement - 2. The keyspace name is properly formatted - 3. The operation completes successfully - - Why this matters: - ---------------- - Keyspaces organize data in Cassandra (like databases in SQL). - Users need to switch keyspaces for different data domains. - The wrapper must handle this transparently. - - Implementation note: - ------------------- - set_keyspace() is implemented as execute("USE keyspace") - This test verifies that translation works correctly. - """ - mock_session = Mock() - mock_future = Mock(spec=ResponseFuture) - mock_session.execute_async.return_value = mock_future - - async_session = AsyncSession(mock_session) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_result = AsyncResultSet([]) - mock_handler.get_result = AsyncMock(return_value=mock_result) - mock_handler_class.return_value = mock_handler - - # Set the keyspace - await async_session.set_keyspace("test_keyspace") - - # Verify USE statement was executed - call_args = mock_session.execute_async.call_args - assert call_args[0][0] == "USE test_keyspace" diff --git a/tests/unit/test_auth_failures.py b/tests/unit/test_auth_failures.py deleted file mode 100644 index 0aa2fd1..0000000 --- a/tests/unit/test_auth_failures.py +++ /dev/null @@ -1,590 +0,0 @@ -""" -Unit tests for authentication and authorization failures. - -Tests how the async wrapper handles: -- Authentication failures during connection -- Authorization failures during operations -- Credential rotation scenarios -- Session invalidation due to auth changes - -Test Organization: -================== -1. Initial Authentication - Connection-time auth failures -2. Operation Authorization - Query-time permission failures -3. Credential Rotation - Handling credential changes -4. Session Invalidation - Auth state changes during session -5. Custom Auth Providers - Advanced authentication scenarios - -Key Testing Principles: -====================== -- Auth failures wrapped appropriately -- Original error details preserved -- Concurrent auth failures handled -- Custom auth providers supported -""" - -import asyncio -from unittest.mock import Mock, patch - -import pytest -from cassandra import AuthenticationFailed, Unauthorized -from cassandra.auth import PlainTextAuthProvider -from cassandra.cluster import NoHostAvailable - -from async_cassandra import AsyncCluster -from async_cassandra.exceptions import ConnectionError - - -class TestAuthenticationFailures: - """Test authentication failure scenarios.""" - - def create_error_future(self, exception): - """ - Create a mock future that raises the given exception. - - Helper method to simulate driver futures that fail with - specific exceptions during callback execution. 
- """ - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - if errback: - errbacks.append(errback) - # Call errback immediately with the error - errback(exception) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - @pytest.mark.asyncio - async def test_initial_auth_failure(self): - """ - Test handling of authentication failure during initial connection. - - What this tests: - --------------- - 1. Auth failure during cluster.connect() - 2. NoHostAvailable with AuthenticationFailed - 3. Wrapped in ConnectionError - 4. Error message preservation - - Why this matters: - ---------------- - Initial connection auth failures indicate: - - Invalid credentials - - User doesn't exist - - Password expired - - Applications need clear error messages to: - - Distinguish auth from network issues - - Prompt for new credentials - - Alert on configuration problems - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Create mock cluster instance - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - - # Configure cluster to fail authentication - mock_cluster.connect.side_effect = NoHostAvailable( - "Unable to connect to any servers", - {"127.0.0.1": AuthenticationFailed("Bad credentials")}, - ) - - async_cluster = AsyncCluster( - contact_points=["127.0.0.1"], - auth_provider=PlainTextAuthProvider("bad_user", "bad_pass"), - ) - - # Should raise connection error wrapping the auth failure - with pytest.raises(ConnectionError) as exc_info: - await async_cluster.connect() - - # Verify the error message contains auth failure - assert "Failed to connect to cluster" in str(exc_info.value) - - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_auth_failure_during_operation(self): - """ - Test handling of authentication failure during query execution. - - What this tests: - --------------- - 1. Unauthorized error during query - 2. Permission failures on tables - 3. Passed through directly - 4. Native exception handling - - Why this matters: - ---------------- - Authorization failures during operations indicate: - - Missing table/keyspace permissions - - Role changes after connection - - Fine-grained access control - - Applications need direct access to: - - Handle permission errors gracefully - - Potentially retry with different user - - Log security violations - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Create mock cluster and session - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - mock_cluster.protocol_version = 5 - - mock_session = Mock() - mock_cluster.connect.return_value = mock_session - - # Create async cluster and connect - async_cluster = AsyncCluster() - session = await async_cluster.connect() - - # Configure query to fail with auth error - mock_session.execute_async.return_value = self.create_error_future( - Unauthorized("User has no SELECT permission on ") - ) - - # Unauthorized is passed through directly (not wrapped) - with pytest.raises(Unauthorized) as exc_info: - await session.execute("SELECT * FROM test.users") - - assert "User has no SELECT permission" in str(exc_info.value) - - await session.close() - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_credential_rotation_reconnect(self): - """ - Test handling credential rotation requiring reconnection. 
- - What this tests: - --------------- - 1. Auth provider can be updated - 2. Old credentials cause auth failures - 3. AuthenticationFailed during queries - 4. Wrapped appropriately - - Why this matters: - ---------------- - Production systems rotate credentials: - - Security best practice - - Compliance requirements - - Automated rotation systems - - Applications must handle: - - Credential updates - - Re-authentication needs - - Graceful credential transitions - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Create mock cluster and session - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - mock_cluster.protocol_version = 5 - - mock_session = Mock() - mock_cluster.connect.return_value = mock_session - - # Set initial auth provider - old_auth = PlainTextAuthProvider("user1", "pass1") - - async_cluster = AsyncCluster(auth_provider=old_auth) - session = await async_cluster.connect() - - # Simulate credential rotation - new_auth = PlainTextAuthProvider("user1", "pass2") - - # Update auth provider on the underlying cluster - async_cluster._cluster.auth_provider = new_auth - - # Next operation fails with auth error - mock_session.execute_async.return_value = self.create_error_future( - AuthenticationFailed("Password verification failed") - ) - - # AuthenticationFailed is passed through directly - with pytest.raises(AuthenticationFailed) as exc_info: - await session.execute("SELECT * FROM test") - - assert "Password verification failed" in str(exc_info.value) - - await session.close() - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_authorization_failure_different_operations(self): - """ - Test different authorization failures for various operations. - - What this tests: - --------------- - 1. Different permission types (SELECT, MODIFY, CREATE, etc.) - 2. Each permission failure handled correctly - 3. Error messages indicate specific permission - 4. 
Exceptions passed through directly - - Why this matters: - ---------------- - Cassandra has fine-grained permissions: - - SELECT: read data - - MODIFY: insert/update/delete - - CREATE/DROP/ALTER: schema changes - - Applications need to: - - Understand which permission failed - - Request appropriate access - - Implement least-privilege principle - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Setup mock cluster and session - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - mock_cluster.protocol_version = 5 - - mock_session = Mock() - mock_cluster.connect.return_value = mock_session - - async_cluster = AsyncCluster() - session = await async_cluster.connect() - - # Test different permission failures - permissions = [ - ("SELECT * FROM users", "User has no SELECT permission"), - ("INSERT INTO users VALUES (1)", "User has no MODIFY permission"), - ("CREATE TABLE test (id int)", "User has no CREATE permission"), - ("DROP TABLE users", "User has no DROP permission"), - ("ALTER TABLE users ADD col text", "User has no ALTER permission"), - ] - - for query, error_msg in permissions: - mock_session.execute_async.return_value = self.create_error_future( - Unauthorized(error_msg) - ) - - # Unauthorized is passed through directly - with pytest.raises(Unauthorized) as exc_info: - await session.execute(query) - - assert error_msg in str(exc_info.value) - - await session.close() - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_session_invalidation_on_auth_change(self): - """ - Test session invalidation when authentication changes. - - What this tests: - --------------- - 1. Session can become auth-invalid - 2. Subsequent operations fail - 3. Session expired errors handled - 4. Clear error messaging - - Why this matters: - ---------------- - Sessions can be invalidated by: - - Token expiration - - Admin revoking access - - Password changes - - Applications must: - - Detect invalid sessions - - Re-authenticate if possible - - Handle session lifecycle - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Setup mock cluster and session - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - mock_cluster.protocol_version = 5 - - mock_session = Mock() - mock_cluster.connect.return_value = mock_session - - async_cluster = AsyncCluster() - session = await async_cluster.connect() - - # Mark session as needing re-authentication - mock_session._auth_invalid = True - - # Operations should detect invalid auth state - mock_session.execute_async.return_value = self.create_error_future( - AuthenticationFailed("Session expired") - ) - - # AuthenticationFailed is passed through directly - with pytest.raises(AuthenticationFailed) as exc_info: - await session.execute("SELECT * FROM test") - - assert "Session expired" in str(exc_info.value) - - await session.close() - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_concurrent_auth_failures(self): - """ - Test handling of concurrent authentication failures. - - What this tests: - --------------- - 1. Multiple queries with auth failures - 2. All failures handled independently - 3. No error cascading or corruption - 4. 
Consistent error types - - Why this matters: - ---------------- - Applications often run parallel queries: - - Batch operations - - Dashboard data fetching - - Concurrent API requests - - Auth failures in one query shouldn't: - - Affect other queries - - Cause cascading failures - - Corrupt session state - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Setup mock cluster and session - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - mock_cluster.protocol_version = 5 - - mock_session = Mock() - mock_cluster.connect.return_value = mock_session - - async_cluster = AsyncCluster() - session = await async_cluster.connect() - - # All queries fail with auth error - mock_session.execute_async.return_value = self.create_error_future( - Unauthorized("No permission") - ) - - # Execute multiple concurrent queries - tasks = [session.execute(f"SELECT * FROM table{i}") for i in range(5)] - - # All should fail with Unauthorized directly - results = await asyncio.gather(*tasks, return_exceptions=True) - assert all(isinstance(r, Unauthorized) for r in results) - - await session.close() - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_auth_error_in_prepared_statement(self): - """ - Test authorization failure with prepared statements. - - What this tests: - --------------- - 1. Prepare succeeds (metadata access) - 2. Execute fails (data access) - 3. Different permission requirements - 4. Error handling consistency - - Why this matters: - ---------------- - Prepared statements have two phases: - - Prepare: needs schema access - - Execute: needs data access - - Users might have permission to see schema - but not to access data, leading to: - - Prepare success - - Execute failure - - This split permission model must be handled. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Setup mock cluster and session - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - mock_cluster.protocol_version = 5 - - mock_session = Mock() - mock_cluster.connect.return_value = mock_session - - async_cluster = AsyncCluster() - session = await async_cluster.connect() - - # Prepare succeeds - prepared = Mock() - prepared.query = "INSERT INTO users (id, name) VALUES (?, ?)" - prepare_future = Mock() - prepare_future.result = Mock(return_value=prepared) - prepare_future.add_callbacks = Mock() - prepare_future.has_more_pages = False - prepare_future.timeout = None - prepare_future.clear_callbacks = Mock() - mock_session.prepare_async.return_value = prepare_future - - stmt = await session.prepare("INSERT INTO users (id, name) VALUES (?, ?)") - - # But execution fails with auth error - mock_session.execute_async.return_value = self.create_error_future( - Unauthorized("User has no MODIFY permission on
") - ) - - # Unauthorized is passed through directly - with pytest.raises(Unauthorized) as exc_info: - await session.execute(stmt, [1, "test"]) - - assert "no MODIFY permission" in str(exc_info.value) - - await session.close() - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_keyspace_auth_failure(self): - """ - Test authorization failure when switching keyspaces. - - What this tests: - --------------- - 1. Keyspace-level permissions - 2. Connection fails with no keyspace access - 3. NoHostAvailable with Unauthorized - 4. Wrapped in ConnectionError - - Why this matters: - ---------------- - Keyspace permissions control: - - Which keyspaces users can access - - Data isolation between tenants - - Security boundaries - - Connection failures due to keyspace access - need clear error messages for debugging. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Create mock cluster - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - - # Try to connect to specific keyspace with no access - mock_cluster.connect.side_effect = NoHostAvailable( - "Unable to connect to any servers", - { - "127.0.0.1": Unauthorized( - "User has no ACCESS permission on " - ) - }, - ) - - async_cluster = AsyncCluster() - - # Should fail with connection error - with pytest.raises(ConnectionError) as exc_info: - await async_cluster.connect("restricted_ks") - - assert "Failed to connect" in str(exc_info.value) - - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_auth_provider_callback_handling(self): - """ - Test custom auth provider with async callbacks. - - What this tests: - --------------- - 1. Custom auth providers accepted - 2. Async credential fetching supported - 3. Provider integration works - 4. No interference with driver auth - - Why this matters: - ---------------- - Advanced auth scenarios require: - - Dynamic credential fetching - - Token-based authentication - - External auth services - - The async wrapper must support custom - auth providers for enterprise use cases. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Create mock cluster - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - mock_cluster.protocol_version = 5 - - # Create custom auth provider - class AsyncAuthProvider: - def __init__(self): - self.call_count = 0 - - async def get_credentials(self): - self.call_count += 1 - # Simulate async credential fetching - await asyncio.sleep(0.01) - return {"username": "user", "password": "pass"} - - auth_provider = AsyncAuthProvider() - - # AsyncCluster constructor accepts auth_provider - async_cluster = AsyncCluster(auth_provider=auth_provider) - - # The driver handles auth internally, we just pass the provider - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_auth_provider_refresh(self): - """ - Test auth provider that refreshes credentials. - - What this tests: - --------------- - 1. Refreshable auth providers work - 2. Credential rotation capability - 3. Provider state management - 4. Integration with async wrapper - - Why this matters: - ---------------- - Production auth often requires: - - Periodic credential refresh - - Token renewal before expiry - - Seamless rotation without downtime - - Supporting refreshable providers enables - enterprise authentication patterns. 
- """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Create mock cluster - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - - class RefreshableAuthProvider: - def __init__(self): - self.refresh_count = 0 - self.credentials = {"username": "user", "password": "initial"} - - async def refresh_credentials(self): - self.refresh_count += 1 - self.credentials["password"] = f"refreshed_{self.refresh_count}" - return self.credentials - - auth_provider = RefreshableAuthProvider() - - async_cluster = AsyncCluster(auth_provider=auth_provider) - - # Note: The actual credential refresh would be handled by the driver - # We're just testing that our wrapper can accept such providers - - await async_cluster.shutdown() diff --git a/tests/unit/test_backpressure_handling.py b/tests/unit/test_backpressure_handling.py deleted file mode 100644 index 7d760bc..0000000 --- a/tests/unit/test_backpressure_handling.py +++ /dev/null @@ -1,574 +0,0 @@ -""" -Unit tests for backpressure and queue management. - -Tests how the async wrapper handles: -- Client-side request queue overflow -- Server overload responses -- Backpressure propagation -- Queue management strategies - -Test Organization: -================== -1. Queue Overflow - Client request queue limits -2. Server Overload - Coordinator overload responses -3. Backpressure Propagation - Flow control -4. Adaptive Control - Dynamic concurrency adjustment -5. Circuit Breaker - Fail-fast under overload -6. Load Shedding - Dropping low priority work - -Key Testing Principles: -====================== -- Simulate realistic overload scenarios -- Test backpressure mechanisms -- Verify graceful degradation -- Ensure system stability -""" - -import asyncio -from unittest.mock import Mock - -import pytest -from cassandra import OperationTimedOut, WriteTimeout - -from async_cassandra import AsyncCassandraSession - - -class TestBackpressureHandling: - """Test backpressure and queue management scenarios.""" - - @pytest.fixture - def mock_session(self): - """Create a mock session.""" - session = Mock() - session.execute_async = Mock() - session.cluster = Mock() - - # Mock request queue settings - session.cluster.protocol_version = 5 - session.cluster.connection_class = Mock() - session.cluster.connection_class.max_in_flight = 128 - - return session - - def create_error_future(self, exception): - """Create a mock future that raises the given exception.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - if errback: - errbacks.append(errback) - # Call errback immediately with the error - errback(exception) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - def create_success_future(self, result): - """Create a mock future that returns a result.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - # For success, the callback expects an iterable of rows - # Create a mock that can be iterated over - mock_rows = [result] if result else [] - callback(mock_rows) - if errback: - errbacks.append(errback) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - @pytest.mark.asyncio - async def test_client_queue_overflow(self, mock_session): - """ - Test 
handling when client request queue overflows. - - What this tests: - --------------- - 1. Client has finite request queue - 2. Queue overflow causes timeouts - 3. Clear error message provided - 4. Some requests fail when overloaded - - Why this matters: - ---------------- - Request queues prevent memory exhaustion: - - Each pending request uses memory - - Unbounded queues cause OOM - - Better to fail fast than crash - - Applications must handle queue overflow - with backoff or rate limiting. - """ - async_session = AsyncCassandraSession(mock_session) - - # Track requests - request_count = 0 - max_requests = 10 - - def execute_async_side_effect(*args, **kwargs): - nonlocal request_count - request_count += 1 - - if request_count > max_requests: - # Queue is full - return self.create_error_future( - OperationTimedOut("Client request queue is full (max_in_flight=10)") - ) - - # Success response - return self.create_success_future({"id": request_count}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Try to overflow the queue - tasks = [] - for i in range(15): # More than max_requests - tasks.append(async_session.execute(f"SELECT * FROM test WHERE id = {i}")) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Some should fail with overload - overloaded = [r for r in results if isinstance(r, OperationTimedOut)] - assert len(overloaded) > 0 - assert "queue is full" in str(overloaded[0]) - - @pytest.mark.asyncio - async def test_server_overload_response(self, mock_session): - """ - Test handling server overload responses. - - What this tests: - --------------- - 1. Server signals overload via WriteTimeout - 2. Coordinator can't handle load - 3. Multiple attempts may fail - 4. Eventually recovers - - Why this matters: - ---------------- - Server overload indicates: - - Too many concurrent requests - - Slow queries consuming resources - - Need for client-side throttling - - Proper handling prevents cascading - failures and allows recovery. - """ - async_session = AsyncCassandraSession(mock_session) - - # Simulate server overload responses - overload_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal overload_count - overload_count += 1 - - if overload_count <= 3: - # First 3 requests get overloaded response - from cassandra import WriteType - - error = WriteTimeout("Coordinator overloaded", write_type=WriteType.SIMPLE) - error.consistency_level = 1 - error.required_responses = 1 - error.received_responses = 0 - return self.create_error_future(error) - - # Subsequent requests succeed - # Create a proper row object - row = {"success": True} - return self.create_success_future(row) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # First attempts should fail - for i in range(3): - with pytest.raises(WriteTimeout) as exc_info: - await async_session.execute("INSERT INTO test VALUES (1)") - assert "Coordinator overloaded" in str(exc_info.value) - - # Next attempt should succeed (after backoff) - result = await async_session.execute("INSERT INTO test VALUES (1)") - assert len(result.rows) == 1 - assert result.rows[0]["success"] is True - - @pytest.mark.asyncio - async def test_backpressure_propagation(self, mock_session): - """ - Test that backpressure is properly propagated to callers. - - What this tests: - --------------- - 1. Backpressure signals propagate up - 2. Callers receive clear errors - 3. Can distinguish from other failures - 4. 
Enables flow control - - Why this matters: - ---------------- - Backpressure enables flow control: - - Prevents overwhelming the system - - Allows graceful slowdown - - Better than dropping requests - - Applications can respond by: - - Reducing request rate - - Buffering at higher level - - Applying backoff - """ - async_session = AsyncCassandraSession(mock_session) - - # Track requests - request_count = 0 - threshold = 5 - - def execute_async_side_effect(*args, **kwargs): - nonlocal request_count - request_count += 1 - - if request_count > threshold: - # Simulate backpressure - return self.create_error_future( - OperationTimedOut("Backpressure active - please slow down") - ) - - # Success response - return self.create_success_future({"id": request_count}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Send burst of requests - tasks = [] - for i in range(10): - tasks.append(async_session.execute(f"SELECT {i}")) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Should have some backpressure errors - backpressure_errors = [r for r in results if isinstance(r, OperationTimedOut)] - assert len(backpressure_errors) > 0 - assert "Backpressure active" in str(backpressure_errors[0]) - - @pytest.mark.asyncio - async def test_adaptive_concurrency_control(self, mock_session): - """ - Test adaptive concurrency control based on response times. - - What this tests: - --------------- - 1. Concurrency limit adjusts dynamically - 2. Reduces limit under stress - 3. Rejects excess requests - 4. Prevents overload - - Why this matters: - ---------------- - Static limits don't work well: - - Load varies over time - - Query complexity changes - - Node performance fluctuates - - Adaptive control maintains optimal - throughput without overload. - """ - async_session = AsyncCassandraSession(mock_session) - - # Track concurrency - request_count = 0 - initial_limit = 10 - current_limit = initial_limit - rejected_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal request_count, current_limit, rejected_count - request_count += 1 - - # Simulate adaptive behavior - reduce limit after 5 requests - if request_count == 5: - current_limit = 5 - - # Reject if over limit - if request_count % 10 > current_limit: - rejected_count += 1 - return self.create_error_future( - OperationTimedOut(f"Concurrency limit reached ({current_limit})") - ) - - # Success response with simulated latency - return self.create_success_future({"latency": 50 + request_count}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Execute requests - success_count = 0 - for i in range(20): - try: - await async_session.execute(f"SELECT {i}") - success_count += 1 - except OperationTimedOut: - pass - - # Should have some rejections due to adaptive limits - assert rejected_count > 0 - assert current_limit != initial_limit - - @pytest.mark.asyncio - async def test_queue_timeout_handling(self, mock_session): - """ - Test handling of requests that timeout while queued. - - What this tests: - --------------- - 1. Queued requests can timeout - 2. Don't wait forever in queue - 3. Clear timeout indication - 4. Resources cleaned up - - Why this matters: - ---------------- - Queue timeouts prevent: - - Indefinite waiting - - Resource accumulation - - Poor user experience - - Failed fast is better than - hanging indefinitely. 
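The docstrings above say applications should react to queue overflow and backpressure with rate limiting and backoff; a minimal sketch of that reaction, with illustrative limits and an exponential backoff on OperationTimedOut (application-level code, not part of the wrapper):

import asyncio

from cassandra import OperationTimedOut


async def execute_with_backoff(session, sem: asyncio.Semaphore, query,
                               params=None, attempts: int = 3):
    """Cap in-flight queries with a semaphore and back off on timeouts."""
    for attempt in range(attempts):
        async with sem:
            try:
                return await session.execute(query, params)
            except OperationTimedOut:
                if attempt == attempts - 1:
                    raise
        # Back off outside the semaphore so other requests can proceed.
        await asyncio.sleep(0.1 * (2 ** attempt))


# Illustrative usage:
#   sem = asyncio.Semaphore(32)
#   rows = await execute_with_backoff(session, sem,
#                                     "SELECT * FROM users WHERE id = ?", [1])
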
- """ - async_session = AsyncCassandraSession(mock_session) - - # Track requests - request_count = 0 - queue_size_limit = 5 - - def execute_async_side_effect(*args, **kwargs): - nonlocal request_count - request_count += 1 - - # Simulate queue timeout for requests beyond limit - if request_count > queue_size_limit: - return self.create_error_future( - OperationTimedOut("Request timed out in queue after 1.0s") - ) - - # Success response - return self.create_success_future({"processed": True}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Send requests that will queue up - tasks = [] - for i in range(10): - tasks.append(async_session.execute(f"SELECT {i}")) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Should have some timeouts - timeouts = [r for r in results if isinstance(r, OperationTimedOut)] - assert len(timeouts) > 0 - assert "timed out in queue" in str(timeouts[0]) - - @pytest.mark.asyncio - async def test_priority_queue_management(self, mock_session): - """ - Test priority-based queue management during overload. - - What this tests: - --------------- - 1. High priority queries processed first - 2. System/critical queries prioritized - 3. Normal queries may wait - 4. Priority ordering maintained - - Why this matters: - ---------------- - Not all queries are equal: - - Health checks must work - - Critical paths prioritized - - Analytics can wait - - Priority queues ensure critical - operations continue under load. - """ - async_session = AsyncCassandraSession(mock_session) - - # Track processed queries - processed_queries = [] - - def execute_async_side_effect(*args, **kwargs): - query = str(args[0] if args else kwargs.get("query", "")) - - # Determine priority - is_high_priority = "SYSTEM" in query or "CRITICAL" in query - - # Track order - if is_high_priority: - # Insert high priority at front - processed_queries.insert(0, query) - else: - # Append normal priority - processed_queries.append(query) - - # Always succeed - return self.create_success_future({"query": query}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Mix of priority queries - queries = [ - "SELECT * FROM users", # Normal - "CRITICAL: SELECT * FROM system.local", # High - "SELECT * FROM data", # Normal - "SYSTEM CHECK", # High - "SELECT * FROM logs", # Normal - ] - - for query in queries: - result = await async_session.execute(query) - assert result.rows[0]["query"] == query - - # High priority queries should be at front of processed list - assert "CRITICAL" in processed_queries[0] or "SYSTEM" in processed_queries[0] - assert "CRITICAL" in processed_queries[1] or "SYSTEM" in processed_queries[1] - - @pytest.mark.asyncio - async def test_circuit_breaker_on_overload(self, mock_session): - """ - Test circuit breaker pattern for overload protection. - - What this tests: - --------------- - 1. Repeated failures open circuit - 2. Open circuit fails fast - 3. Prevents overwhelming failed system - 4. Can reset after recovery - - Why this matters: - ---------------- - Circuit breakers prevent: - - Cascading failures - - Resource exhaustion - - Thundering herd on recovery - - Failing fast gives system time - to recover without additional load. 
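A minimal, illustrative circuit breaker of the kind described above; thresholds and the timeout window are arbitrary examples, and this lives in application code rather than in async-cassandra:

import time

from cassandra import OperationTimedOut


class QueryCircuitBreaker:
    """Fail fast after repeated timeouts; allow a trial query once the window expires."""

    def __init__(self, failure_threshold: int = 3, reset_timeout: float = 30.0):
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        self._failures = 0
        self._opened_at = None

    def _is_open(self) -> bool:
        if self._opened_at is None:
            return False
        if time.monotonic() - self._opened_at >= self.reset_timeout:
            # Half-open: let one trial request through.
            self._opened_at = None
            self._failures = 0
            return False
        return True

    async def execute(self, session, query, params=None):
        if self._is_open():
            raise OperationTimedOut("Circuit breaker is OPEN")
        try:
            result = await session.execute(query, params)
        except OperationTimedOut:
            self._failures += 1
            if self._failures >= self.failure_threshold:
                self._opened_at = time.monotonic()
            raise
        self._failures = 0
        return result
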
- """ - async_session = AsyncCassandraSession(mock_session) - - # Track circuit breaker state - failure_count = 0 - circuit_open = False - - def execute_async_side_effect(*args, **kwargs): - nonlocal failure_count, circuit_open - - if circuit_open: - return self.create_error_future(OperationTimedOut("Circuit breaker is OPEN")) - - # First 3 requests fail - if failure_count < 3: - failure_count += 1 - if failure_count == 3: - circuit_open = True - return self.create_error_future(OperationTimedOut("Server overloaded")) - - # After circuit reset, succeed - return self.create_success_future({"success": True}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Trigger circuit breaker with 3 failures - for i in range(3): - with pytest.raises(OperationTimedOut) as exc_info: - await async_session.execute("SELECT 1") - assert "Server overloaded" in str(exc_info.value) - - # Circuit should be open - with pytest.raises(OperationTimedOut) as exc_info: - await async_session.execute("SELECT 2") - assert "Circuit breaker is OPEN" in str(exc_info.value) - - # Reset circuit for test - circuit_open = False - - # Should allow attempt after reset - result = await async_session.execute("SELECT 3") - assert result.rows[0]["success"] is True - - @pytest.mark.asyncio - async def test_load_shedding_strategy(self, mock_session): - """ - Test load shedding to prevent system overload. - - What this tests: - --------------- - 1. Optional queries shed under load - 2. Critical queries still processed - 3. Clear load shedding errors - 4. System remains stable - - Why this matters: - ---------------- - Load shedding maintains stability: - - Drops non-essential work - - Preserves critical functions - - Prevents total failure - - Better to serve some requests - well than fail all requests. - """ - async_session = AsyncCassandraSession(mock_session) - - # Track queries - shed_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal shed_count - query = str(args[0] if args else kwargs.get("query", "")) - - # Shed optional/low priority queries - if "OPTIONAL" in query or "LOW_PRIORITY" in query: - shed_count += 1 - return self.create_error_future(OperationTimedOut("Load shedding active (load=85)")) - - # Normal queries succeed - return self.create_success_future({"executed": query}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Send mix of queries - queries = [ - "SELECT * FROM users", - "OPTIONAL: SELECT * FROM logs", - "INSERT INTO data VALUES (1)", - "LOW_PRIORITY: SELECT count(*) FROM events", - "SELECT * FROM critical_data", - ] - - results = [] - for query in queries: - try: - result = await async_session.execute(query) - results.append(result.rows[0]["executed"]) - except OperationTimedOut: - results.append(f"SHED: {query}") - - # Should have shed optional/low priority queries - shed_queries = [r for r in results if r.startswith("SHED:")] - assert len(shed_queries) == 2 # OPTIONAL and LOW_PRIORITY - assert any("OPTIONAL" in q for q in shed_queries) - assert any("LOW_PRIORITY" in q for q in shed_queries) - assert shed_count == 2 diff --git a/tests/unit/test_base.py b/tests/unit/test_base.py deleted file mode 100644 index 6d4ab83..0000000 --- a/tests/unit/test_base.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -Unit tests for base module decorators and utilities. 
- -This module tests the foundational AsyncContextManageable mixin that provides -async context manager functionality to AsyncCluster, AsyncSession, and other -resources that need automatic cleanup. - -Test Organization: -================== -- TestAsyncContextManageable: Tests the async context manager mixin -- TestAsyncStreamingResultSet: Tests streaming result wrapper (if present) - -Key Testing Focus: -================== -1. Resource cleanup happens automatically -2. Exceptions don't prevent cleanup -3. Multiple cleanup calls are safe -4. Proper async/await protocol implementation -""" - -import pytest - -from async_cassandra.base import AsyncContextManageable - - -class TestAsyncContextManageable: - """ - Test AsyncContextManageable mixin. - - This mixin is inherited by AsyncCluster, AsyncSession, and other - resources to provide 'async with' functionality. It ensures proper - cleanup even when exceptions occur. - """ - - @pytest.mark.asyncio - async def test_context_manager(self): - """ - Test basic async context manager functionality. - - What this tests: - --------------- - 1. Resources implementing AsyncContextManageable can use 'async with' - 2. The resource is returned from __aenter__ for use in the context - 3. close() is automatically called when exiting the context - 4. Resource state properly reflects being closed - - Why this matters: - ---------------- - Context managers are the primary way to ensure resource cleanup in Python. - This pattern prevents resource leaks by guaranteeing cleanup happens even - if the user forgets to call close() explicitly. - - Example usage pattern: - -------------------- - async with AsyncCluster() as cluster: - async with cluster.connect() as session: - await session.execute(...) - # Both session and cluster are automatically closed here - """ - - class TestResource(AsyncContextManageable): - close_count = 0 - is_closed = False - - async def close(self): - self.close_count += 1 - self.is_closed = True - - # Use as context manager - async with TestResource() as resource: - # Inside context: resource should be open - assert not resource.is_closed - assert resource.close_count == 0 - - # After context: should be closed exactly once - assert resource.is_closed - assert resource.close_count == 1 - - @pytest.mark.asyncio - async def test_context_manager_with_exception(self): - """ - Test context manager closes resource even when exception occurs. - - What this tests: - --------------- - 1. Exceptions inside the context don't prevent cleanup - 2. close() is called even when exception is raised - 3. The original exception is propagated (not suppressed) - 4. Resource state is consistent after exception - - Why this matters: - ---------------- - Many errors can occur during database operations: - - Network failures - - Query errors - - Timeout exceptions - - Application logic errors - - The context manager MUST clean up resources even when these - errors occur, otherwise we leak connections, memory, and threads. 
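For reference, the mixin these tests exercise has roughly the shape sketched below; this is an illustrative reconstruction, and the library's actual implementation may differ in details such as idempotency guards:

class AsyncContextManageable:
    """Provide 'async with' support to anything exposing an async close()."""

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Always release resources, even if the body raised, and return False
        # so the original exception propagates to the caller.
        await self.close()
        return False
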
- - Real-world scenario: - ------------------- - async with cluster.connect() as session: - await session.execute("INVALID QUERY") # Raises QueryError - # session.close() must still be called despite the error - """ - - class TestResource(AsyncContextManageable): - close_count = 0 - is_closed = False - - async def close(self): - self.close_count += 1 - self.is_closed = True - - resource = None - try: - async with TestResource() as res: - resource = res - raise ValueError("Test error") - except ValueError: - pass - - # Should still close resource on exception - assert resource is not None - assert resource.is_closed - assert resource.close_count == 1 - - @pytest.mark.asyncio - async def test_context_manager_multiple_use(self): - """ - Test context manager can be used multiple times. - - What this tests: - --------------- - 1. Same resource can enter/exit context multiple times - 2. close() is called each time the context exits - 3. No state corruption between uses - 4. Resource remains functional for multiple contexts - - Why this matters: - ---------------- - While not common, some use cases might reuse resources: - - Connection pooling implementations - - Cached sessions with periodic cleanup - - Test fixtures that reset between tests - - The mixin should handle multiple uses gracefully without - assuming single-use semantics. - - Note: - ----- - In practice, most resources (cluster, session) are used - once and discarded, but the base mixin doesn't enforce this. - """ - - class TestResource(AsyncContextManageable): - close_count = 0 - - async def close(self): - self.close_count += 1 - - resource = TestResource() - - # First use - async with resource: - pass - assert resource.close_count == 1 - - # Second use - should work and increment close count - async with resource: - pass - assert resource.close_count == 2 diff --git a/tests/unit/test_basic_queries.py b/tests/unit/test_basic_queries.py deleted file mode 100644 index a5eb17c..0000000 --- a/tests/unit/test_basic_queries.py +++ /dev/null @@ -1,513 +0,0 @@ -"""Core basic query execution tests. - -This module tests fundamental query operations that must work -for the async wrapper to be functional. These are the most basic -operations that users will perform, so they must be rock solid. - -Test Organization: -================== -- TestBasicQueryExecution: All fundamental query types (SELECT, INSERT, UPDATE, DELETE) -- Tests both simple string queries and parameterized queries -- Covers various query options (consistency, timeout, custom payload) - -Key Testing Focus: -================== -1. All CRUD operations work correctly -2. Parameters are properly passed to the driver -3. Results are wrapped in AsyncResultSet -4. Query options (timeout, consistency) are preserved -5. Empty results are handled gracefully -""" - -from unittest.mock import AsyncMock, Mock, patch - -import pytest -from cassandra import ConsistencyLevel -from cassandra.cluster import ResponseFuture -from cassandra.query import SimpleStatement - -from async_cassandra import AsyncCassandraSession as AsyncSession -from async_cassandra.result import AsyncResultSet - - -class TestBasicQueryExecution: - """ - Test basic query execution patterns. - - These tests ensure that the async wrapper correctly handles all - fundamental query types that users will execute against Cassandra. - Each test mocks the underlying driver to focus on the wrapper's behavior. 
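The application-facing pattern these unit tests back up looks roughly like the following; the keyspace and table names are illustrative:

from async_cassandra import AsyncCluster


async def get_user(user_id: int):
    async with AsyncCluster(contact_points=["127.0.0.1"]) as cluster:
        session = await cluster.connect("my_keyspace")
        try:
            result = await session.execute(
                "SELECT * FROM users WHERE id = ?", [user_id]
            )
            return [row async for row in result]
        finally:
            await session.close()
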
- """ - - def _setup_mock_execute(self, mock_session, result_data=None): - """ - Helper to setup mock execute_async with proper response. - - Creates a mock ResponseFuture that simulates the driver's - async execution mechanism. This allows us to test the wrapper - without actual network calls. - """ - mock_future = Mock(spec=ResponseFuture) - mock_future.has_more_pages = False - mock_session.execute_async.return_value = mock_future - - if result_data is None: - result_data = [] - - return AsyncResultSet(result_data) - - @pytest.mark.core - @pytest.mark.quick - @pytest.mark.critical - async def test_simple_select(self): - """ - Test basic SELECT query execution. - - What this tests: - --------------- - 1. Simple string SELECT queries work - 2. Results are returned as AsyncResultSet - 3. The driver's execute_async is called (not execute) - 4. No parameters case works correctly - - Why this matters: - ---------------- - SELECT queries are the most common operation. This test ensures - the basic read path works: - - Query string is passed correctly - - Async execution is used - - Results are properly wrapped - - This is the simplest possible query - if this doesn't work, - nothing else will. - """ - mock_session = Mock() - expected_result = self._setup_mock_execute(mock_session, [{"id": 1, "name": "test"}]) - - async_session = AsyncSession(mock_session) - - # Patch AsyncResultHandler to simulate immediate result - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=expected_result) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute("SELECT * FROM users WHERE id = 1") - - assert isinstance(result, AsyncResultSet) - mock_session.execute_async.assert_called_once() - - @pytest.mark.core - @pytest.mark.critical - async def test_parameterized_query(self): - """ - Test query with bound parameters. - - What this tests: - --------------- - 1. Parameterized queries work with ? placeholders - 2. Parameters are passed as a list - 3. Multiple parameters are handled correctly - 4. Parameter values are preserved exactly - - Why this matters: - ---------------- - Parameterized queries are essential for: - - SQL injection prevention - - Better performance (query plan caching) - - Type safety - - Clean code (no string concatenation) - - This test ensures parameters flow correctly through the - async wrapper to the driver. Parameter handling bugs could - cause security vulnerabilities or data corruption. - """ - mock_session = Mock() - expected_result = self._setup_mock_execute(mock_session, [{"id": 123, "status": "active"}]) - - async_session = AsyncSession(mock_session) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=expected_result) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute( - "SELECT * FROM users WHERE id = ? AND status = ?", [123, "active"] - ) - - assert isinstance(result, AsyncResultSet) - # Verify query and parameters were passed - call_args = mock_session.execute_async.call_args - assert call_args[0][0] == "SELECT * FROM users WHERE id = ? AND status = ?" - assert call_args[0][1] == [123, "active"] - - @pytest.mark.core - async def test_query_with_consistency_level(self): - """ - Test query with custom consistency level. - - What this tests: - --------------- - 1. 
SimpleStatement with consistency level works - 2. Consistency level is preserved through execution - 3. Statement objects are passed correctly - 4. QUORUM consistency can be specified - - Why this matters: - ---------------- - Consistency levels control the CAP theorem trade-offs: - - ONE: Fast but may read stale data - - QUORUM: Balanced consistency and availability - - ALL: Strong consistency but less available - - Applications need fine-grained control over consistency - per query. This test ensures that control is preserved - through our async wrapper. - - Example use case: - ---------------- - - User profile reads: ONE (fast, eventual consistency OK) - - Financial transactions: QUORUM (must be consistent) - - Critical configuration: ALL (absolute consistency) - """ - mock_session = Mock() - expected_result = self._setup_mock_execute(mock_session, [{"id": 1}]) - - async_session = AsyncSession(mock_session) - - statement = SimpleStatement( - "SELECT * FROM users", consistency_level=ConsistencyLevel.QUORUM - ) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=expected_result) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute(statement) - - assert isinstance(result, AsyncResultSet) - # Verify statement was passed - call_args = mock_session.execute_async.call_args - assert isinstance(call_args[0][0], SimpleStatement) - assert call_args[0][0].consistency_level == ConsistencyLevel.QUORUM - - @pytest.mark.core - @pytest.mark.critical - async def test_insert_query(self): - """ - Test INSERT query execution. - - What this tests: - --------------- - 1. INSERT queries with parameters work - 2. Multiple values can be inserted - 3. Parameter order is preserved - 4. Returns AsyncResultSet (even though usually empty) - - Why this matters: - ---------------- - INSERT is a fundamental write operation. This test ensures: - - Data can be written to Cassandra - - Parameter binding works for writes - - The async pattern works for non-SELECT queries - - Common pattern: - -------------- - await session.execute( - "INSERT INTO users (id, name, email) VALUES (?, ?, ?)", - [user_id, name, email] - ) - - The result is typically empty but may contain info for - special cases (LWT with IF NOT EXISTS). - """ - mock_session = Mock() - expected_result = self._setup_mock_execute(mock_session) - - async_session = AsyncSession(mock_session) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=expected_result) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute( - "INSERT INTO users (id, name, email) VALUES (?, ?, ?)", - [1, "John Doe", "john@example.com"], - ) - - assert isinstance(result, AsyncResultSet) - # Verify query was executed - call_args = mock_session.execute_async.call_args - assert "INSERT INTO users" in call_args[0][0] - assert call_args[0][1] == [1, "John Doe", "john@example.com"] - - @pytest.mark.core - async def test_update_query(self): - """ - Test UPDATE query execution. - - What this tests: - --------------- - 1. UPDATE queries work with WHERE clause - 2. SET values can be parameterized - 3. WHERE conditions can be parameterized - 4. Parameter order matters (SET params, then WHERE params) - - Why this matters: - ---------------- - UPDATE operations modify existing data. 
Critical aspects: - - Must target specific rows (WHERE clause) - - Must preserve parameter order - - Often used for state changes - - Common mistakes this prevents: - - Forgetting WHERE clause (would update all rows!) - - Mixing up parameter order - - SQL injection via string concatenation - """ - mock_session = Mock() - expected_result = self._setup_mock_execute(mock_session) - - async_session = AsyncSession(mock_session) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=expected_result) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute( - "UPDATE users SET name = ? WHERE id = ?", ["Jane Doe", 1] - ) - - assert isinstance(result, AsyncResultSet) - - @pytest.mark.core - async def test_delete_query(self): - """ - Test DELETE query execution. - - What this tests: - --------------- - 1. DELETE queries work with WHERE clause - 2. WHERE parameters are handled correctly - 3. Returns AsyncResultSet (typically empty) - - Why this matters: - ---------------- - DELETE operations remove data permanently. Critical because: - - Data loss is irreversible - - Must target specific rows - - Often part of cleanup or state transitions - - Safety considerations: - - Always use WHERE clause - - Consider soft deletes for audit trails - - May create tombstones (performance impact) - """ - mock_session = Mock() - expected_result = self._setup_mock_execute(mock_session) - - async_session = AsyncSession(mock_session) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=expected_result) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute("DELETE FROM users WHERE id = ?", [1]) - - assert isinstance(result, AsyncResultSet) - - @pytest.mark.core - @pytest.mark.critical - async def test_batch_query(self): - """ - Test batch query execution. - - What this tests: - --------------- - 1. CQL batch syntax is supported - 2. Multiple statements in one batch work - 3. Batch is executed as a single operation - 4. Returns AsyncResultSet - - Why this matters: - ---------------- - Batches are used for: - - Atomic operations (all succeed or all fail) - - Reducing round trips - - Maintaining consistency across rows - - Important notes: - - This tests CQL string batches - - For programmatic batches, use BatchStatement - - Batches can impact performance if misused - - Not the same as SQL transactions! - """ - mock_session = Mock() - expected_result = self._setup_mock_execute(mock_session) - - async_session = AsyncSession(mock_session) - - batch_query = """ - BEGIN BATCH - INSERT INTO users (id, name) VALUES (1, 'User 1'); - INSERT INTO users (id, name) VALUES (2, 'User 2'); - APPLY BATCH - """ - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=expected_result) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute(batch_query) - - assert isinstance(result, AsyncResultSet) - - @pytest.mark.core - async def test_query_with_timeout(self): - """ - Test query with timeout parameter. - - What this tests: - --------------- - 1. Timeout parameter is accepted - 2. Timeout value is passed to execute_async - 3. Timeout is in the correct position (5th argument) - 4. 
Float timeout values work - - Why this matters: - ---------------- - Timeouts prevent: - - Queries hanging forever - - Resource exhaustion - - Cascading failures - - Critical for production: - - Set reasonable timeouts - - Handle timeout errors gracefully - - Different timeouts for different query types - - Note: This tests request timeout, not connection timeout. - """ - mock_session = Mock() - expected_result = self._setup_mock_execute(mock_session) - - async_session = AsyncSession(mock_session) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=expected_result) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute("SELECT * FROM users", timeout=10.0) - - assert isinstance(result, AsyncResultSet) - # Check timeout was passed - call_args = mock_session.execute_async.call_args - # Timeout is the 5th positional argument (after query, params, trace, custom_payload) - assert call_args[0][4] == 10.0 - - @pytest.mark.core - async def test_query_with_custom_payload(self): - """ - Test query with custom payload. - - What this tests: - --------------- - 1. Custom payload parameter is accepted - 2. Payload dict is passed to execute_async - 3. Payload is in correct position (4th argument) - 4. Payload structure is preserved - - Why this matters: - ---------------- - Custom payloads enable: - - Request tracing/debugging - - Multi-tenancy information - - Feature flags per query - - Custom routing hints - - Advanced feature used by: - - Monitoring systems - - Multi-tenant applications - - Custom Cassandra extensions - - The payload is opaque to the driver but may be - used by custom QueryHandler implementations. - """ - mock_session = Mock() - expected_result = self._setup_mock_execute(mock_session) - - async_session = AsyncSession(mock_session) - custom_payload = {"key": "value"} - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=expected_result) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute( - "SELECT * FROM users", custom_payload=custom_payload - ) - - assert isinstance(result, AsyncResultSet) - # Check custom_payload was passed - call_args = mock_session.execute_async.call_args - # Custom payload is the 4th positional argument - assert call_args[0][3] == custom_payload - - @pytest.mark.core - @pytest.mark.critical - async def test_empty_result_handling(self): - """ - Test handling of empty results. - - What this tests: - --------------- - 1. Empty result sets are handled gracefully - 2. AsyncResultSet works with no rows - 3. Iteration over empty results completes immediately - 4. 
No errors when converting empty results to list - - Why this matters: - ---------------- - Empty results are common: - - No matching rows for WHERE clause - - Table is empty - - Row was already deleted - - Applications must handle empty results without: - - Raising exceptions - - Hanging on iteration - - Returning None instead of empty set - - Common pattern: - -------------- - result = await session.execute("SELECT * FROM users WHERE id = ?", [999]) - users = [row async for row in result] # Should be [] - if not users: - print("User not found") - """ - mock_session = Mock() - expected_result = self._setup_mock_execute(mock_session, []) - - async_session = AsyncSession(mock_session) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=expected_result) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute("SELECT * FROM users WHERE id = 999") - - assert isinstance(result, AsyncResultSet) - # Convert to list to check emptiness - rows = [] - async for row in result: - rows.append(row) - assert rows == [] diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py deleted file mode 100644 index 4f49e6f..0000000 --- a/tests/unit/test_cluster.py +++ /dev/null @@ -1,877 +0,0 @@ -""" -Unit tests for async cluster management. - -This module tests AsyncCluster in detail, covering: -- Initialization with various configurations -- Connection establishment and error handling -- Protocol version validation (v5+ requirement) -- SSL/TLS support -- Resource cleanup and context managers -- Metadata access and user type registration - -Key Testing Focus: -================== -1. Protocol Version Enforcement - We require v5+ for async operations -2. Connection Error Handling - Clear error messages for common issues -3. Thread Safety - Proper locking for shutdown operations -4. Resource Management - No leaks even with errors -""" - -from ssl import PROTOCOL_TLS_CLIENT, SSLContext -from unittest.mock import Mock, patch - -import pytest -from cassandra.auth import PlainTextAuthProvider -from cassandra.cluster import Cluster -from cassandra.policies import ExponentialReconnectionPolicy, TokenAwarePolicy - -from async_cassandra.cluster import AsyncCluster -from async_cassandra.exceptions import ConfigurationError, ConnectionError -from async_cassandra.retry_policy import AsyncRetryPolicy -from async_cassandra.session import AsyncCassandraSession - - -class TestAsyncCluster: - """ - Test cases for AsyncCluster. - - AsyncCluster is responsible for: - - Managing connection to Cassandra nodes - - Enforcing protocol version requirements - - Providing session creation - - Handling authentication and SSL - """ - - @pytest.fixture - def mock_cluster(self): - """ - Create a mock Cassandra cluster. - - This fixture patches the driver's Cluster class to avoid - actual network connections during unit tests. The mock - provides the minimal interface needed for our tests. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_instance = Mock(spec=Cluster) - mock_instance.shutdown = Mock() - mock_instance.metadata = {"test": "metadata"} - mock_cluster_class.return_value = mock_instance - yield mock_instance - - def test_init_with_defaults(self, mock_cluster): - """ - Test initialization with default values. - - What this tests: - --------------- - 1. AsyncCluster can be created without parameters - 2. Default contact point is localhost (127.0.0.1) - 3. 
Default port is 9042 (Cassandra standard) - 4. Default policies are applied: - - TokenAwarePolicy for load balancing (data locality) - - ExponentialReconnectionPolicy (gradual backoff) - - AsyncRetryPolicy (our custom retry logic) - - Why this matters: - ---------------- - Defaults should work for local development and common setups. - The default policies provide good production behavior: - - Token awareness reduces latency - - Exponential backoff prevents connection storms - - Async retry policy handles transient failures - """ - async_cluster = AsyncCluster() - - # Verify cluster starts in open state - assert not async_cluster.is_closed - - # Verify driver cluster was created with expected defaults - from async_cassandra.cluster import Cluster as ClusterImport - - ClusterImport.assert_called_once() - call_args = ClusterImport.call_args - - # Check connection defaults - assert call_args.kwargs["contact_points"] == ["127.0.0.1"] - assert call_args.kwargs["port"] == 9042 - - # Check policy defaults - assert isinstance(call_args.kwargs["load_balancing_policy"], TokenAwarePolicy) - assert isinstance(call_args.kwargs["reconnection_policy"], ExponentialReconnectionPolicy) - assert isinstance(call_args.kwargs["default_retry_policy"], AsyncRetryPolicy) - - def test_init_with_custom_values(self, mock_cluster): - """ - Test initialization with custom values. - - What this tests: - --------------- - 1. All custom parameters are passed to the driver - 2. Multiple contact points can be specified - 3. Authentication is configurable - 4. Thread pool size can be tuned - 5. Protocol version can be explicitly set - - Why this matters: - ---------------- - Production deployments need: - - Multiple nodes for high availability - - Custom ports for security/routing - - Authentication for access control - - Thread tuning for workload optimization - - Protocol version control for compatibility - """ - contact_points = ["192.168.1.1", "192.168.1.2"] - port = 9043 - auth_provider = PlainTextAuthProvider("user", "pass") - - AsyncCluster( - contact_points=contact_points, - port=port, - auth_provider=auth_provider, - executor_threads=4, # Smaller pool for testing - protocol_version=5, # Explicit v5 - ) - - from async_cassandra.cluster import Cluster as ClusterImport - - call_args = ClusterImport.call_args - - # Verify all custom values were passed through - assert call_args.kwargs["contact_points"] == contact_points - assert call_args.kwargs["port"] == port - assert call_args.kwargs["auth_provider"] == auth_provider - assert call_args.kwargs["executor_threads"] == 4 - assert call_args.kwargs["protocol_version"] == 5 - - def test_create_with_auth(self, mock_cluster): - """ - Test creating cluster with authentication. - - What this tests: - --------------- - 1. create_with_auth() helper method works - 2. PlainTextAuthProvider is created automatically - 3. Username/password are properly configured - - Why this matters: - ---------------- - This is a convenience method for the common case of - username/password authentication. 
It saves users from: - - Importing PlainTextAuthProvider - - Creating the auth provider manually - - Reduces boilerplate for simple auth setups - - Example usage: - ------------- - cluster = AsyncCluster.create_with_auth( - contact_points=['cassandra.example.com'], - username='myuser', - password='mypass' - ) - """ - contact_points = ["localhost"] - username = "testuser" - password = "testpass" - - AsyncCluster.create_with_auth( - contact_points=contact_points, username=username, password=password - ) - - from async_cassandra.cluster import Cluster as ClusterImport - - call_args = ClusterImport.call_args - - assert call_args.kwargs["contact_points"] == contact_points - # Verify PlainTextAuthProvider was created - auth_provider = call_args.kwargs["auth_provider"] - assert isinstance(auth_provider, PlainTextAuthProvider) - - @pytest.mark.asyncio - async def test_connect_without_keyspace(self, mock_cluster): - """ - Test connecting without keyspace. - - What this tests: - --------------- - 1. connect() can be called without specifying keyspace - 2. AsyncCassandraSession is created properly - 3. Protocol version is validated (must be v5+) - 4. None is passed as keyspace to session creation - - Why this matters: - ---------------- - Users often connect first, then select keyspace later. - This pattern is common for: - - Creating keyspaces dynamically - - Working with multiple keyspaces - - Administrative operations - - Protocol validation ensures async features work correctly. - """ - async_cluster = AsyncCluster() - - # Mock protocol version as v5 so it passes validation - mock_cluster.protocol_version = 5 - - with patch("async_cassandra.cluster.AsyncCassandraSession.create") as mock_create: - mock_session = Mock(spec=AsyncCassandraSession) - mock_create.return_value = mock_session - - session = await async_cluster.connect() - - assert session == mock_session - # Verify keyspace=None was passed - mock_create.assert_called_once_with(mock_cluster, None) - - @pytest.mark.asyncio - async def test_connect_with_keyspace(self, mock_cluster): - """ - Test connecting with keyspace. - - What this tests: - --------------- - 1. connect() accepts keyspace parameter - 2. Keyspace is passed to session creation - 3. Session is pre-configured with the keyspace - - Why this matters: - ---------------- - Specifying keyspace at connection time: - - Saves an extra round trip (no USE statement) - - Ensures all queries use the correct keyspace - - Prevents accidental cross-keyspace queries - - Common pattern for single-keyspace applications - """ - async_cluster = AsyncCluster() - keyspace = "test_keyspace" - - # Mock protocol version as v5 so it passes validation - mock_cluster.protocol_version = 5 - - with patch("async_cassandra.cluster.AsyncCassandraSession.create") as mock_create: - mock_session = Mock(spec=AsyncCassandraSession) - mock_create.return_value = mock_session - - session = await async_cluster.connect(keyspace) - - assert session == mock_session - # Verify keyspace was passed through - mock_create.assert_called_once_with(mock_cluster, keyspace) - - @pytest.mark.asyncio - async def test_connect_error(self, mock_cluster): - """ - Test handling connection error. - - What this tests: - --------------- - 1. Generic exceptions are wrapped in ConnectionError - 2. Original exception is preserved as __cause__ - 3. 
Error message provides context - - Why this matters: - ---------------- - Connection failures need clear error messages: - - Users need to know it's a connection issue - - Original error details must be preserved - - Stack traces should show the full context - - Common causes: - - Network issues - - Wrong contact points - - Cassandra not running - - Authentication failures - """ - async_cluster = AsyncCluster() - - with patch("async_cassandra.cluster.AsyncCassandraSession.create") as mock_create: - # Simulate connection failure - mock_create.side_effect = Exception("Connection failed") - - with pytest.raises(ConnectionError) as exc_info: - await async_cluster.connect() - - # Verify error wrapping - assert "Failed to connect to cluster" in str(exc_info.value) - # Verify original exception is preserved for debugging - assert exc_info.value.__cause__ is not None - - @pytest.mark.asyncio - async def test_connect_on_closed_cluster(self, mock_cluster): - """ - Test connecting on closed cluster. - - What this tests: - --------------- - 1. Cannot connect after shutdown() - 2. Clear error message is provided - 3. No resource leaks or hangs - - Why this matters: - ---------------- - Prevents common programming errors: - - Using cluster after cleanup - - Race conditions in shutdown - - Resource leaks from partial operations - - This ensures fail-fast behavior rather than - mysterious hangs or corrupted state. - """ - async_cluster = AsyncCluster() - # Close the cluster first - await async_cluster.shutdown() - - with pytest.raises(ConnectionError) as exc_info: - await async_cluster.connect() - - # Verify clear error message - assert "Cluster is closed" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_shutdown(self, mock_cluster): - """ - Test shutting down the cluster. - - What this tests: - --------------- - 1. shutdown() marks cluster as closed - 2. Driver's shutdown() is called - 3. is_closed property reflects state - - Why this matters: - ---------------- - Proper shutdown is critical for: - - Closing network connections - - Stopping background threads - - Releasing memory - - Clean process termination - """ - async_cluster = AsyncCluster() - - await async_cluster.shutdown() - - # Verify state change - assert async_cluster.is_closed - # Verify driver cleanup - mock_cluster.shutdown.assert_called_once() - - @pytest.mark.asyncio - async def test_shutdown_idempotent(self, mock_cluster): - """ - Test that shutdown is idempotent. - - What this tests: - --------------- - 1. Multiple shutdown() calls are safe - 2. Driver shutdown only happens once - 3. No errors on repeated calls - - Why this matters: - ---------------- - Idempotent shutdown prevents: - - Double-free errors - - Race conditions in cleanup - - Errors in finally blocks - - Users might call shutdown() multiple times: - - In error handlers - - In finally blocks - - From different cleanup paths - """ - async_cluster = AsyncCluster() - - # Call shutdown twice - await async_cluster.shutdown() - await async_cluster.shutdown() - - # Driver shutdown should only be called once - mock_cluster.shutdown.assert_called_once() - - @pytest.mark.asyncio - async def test_context_manager(self, mock_cluster): - """ - Test using cluster as async context manager. - - What this tests: - --------------- - 1. Cluster supports 'async with' syntax - 2. Cluster is open inside the context - 3. 
Automatic shutdown on context exit - - Why this matters: - ---------------- - Context managers ensure cleanup: - ```python - async with AsyncCluster() as cluster: - session = await cluster.connect() - # ... use session ... - # cluster.shutdown() called automatically - ``` - - Benefits: - - No forgotten shutdowns - - Exception safety - - Cleaner code - - Resource leak prevention - """ - async with AsyncCluster() as cluster: - # Inside context: cluster should be usable - assert isinstance(cluster, AsyncCluster) - assert not cluster.is_closed - - # After context: should be shut down - mock_cluster.shutdown.assert_called_once() - - def test_is_closed_property(self, mock_cluster): - """ - Test is_closed property. - - What this tests: - --------------- - 1. is_closed starts as False - 2. Reflects internal _closed state - 3. Read-only property (no setter) - - Why this matters: - ---------------- - Users need to check cluster state before operations. - This property enables defensive programming: - ```python - if not cluster.is_closed: - session = await cluster.connect() - ``` - """ - async_cluster = AsyncCluster() - - # Initially open - assert not async_cluster.is_closed - # Simulate closed state - async_cluster._closed = True - assert async_cluster.is_closed - - def test_metadata_property(self, mock_cluster): - """ - Test metadata property. - - What this tests: - --------------- - 1. Metadata is accessible from async wrapper - 2. Returns driver's cluster metadata - - Why this matters: - ---------------- - Metadata provides: - - Keyspace definitions - - Table schemas - - Node topology - - Token ranges - - Essential for advanced features like: - - Schema discovery - - Token-aware routing - - Dynamic query building - """ - async_cluster = AsyncCluster() - - assert async_cluster.metadata == {"test": "metadata"} - - def test_register_user_type(self, mock_cluster): - """ - Test registering user-defined type. - - What this tests: - --------------- - 1. User types can be registered - 2. Registration is delegated to driver - 3. Parameters are passed correctly - - Why this matters: - ---------------- - Cassandra supports complex user-defined types (UDTs). - Python classes must be registered to handle them: - - ```python - class Address: - def __init__(self, street, city, zip_code): - self.street = street - self.city = city - self.zip_code = zip_code - - cluster.register_user_type('my_keyspace', 'address', Address) - ``` - - This enables seamless UDT handling in queries. - """ - async_cluster = AsyncCluster() - - keyspace = "test_keyspace" - user_type = "address" - klass = type("Address", (), {}) # Dynamic class for testing - - async_cluster.register_user_type(keyspace, user_type, klass) - - # Verify delegation to driver - mock_cluster.register_user_type.assert_called_once_with(keyspace, user_type, klass) - - def test_ssl_context(self, mock_cluster): - """ - Test initialization with SSL context. - - What this tests: - --------------- - 1. SSL/TLS can be configured - 2. 
SSL context is passed to driver - - Why this matters: - ---------------- - Production Cassandra often requires encryption: - - Client-to-node encryption - - Compliance requirements - - Network security - - Example usage: - ------------- - ```python - import ssl - - ssl_context = ssl.create_default_context() - ssl_context.load_cert_chain('client.crt', 'client.key') - ssl_context.load_verify_locations('ca.crt') - - cluster = AsyncCluster(ssl_context=ssl_context) - ``` - """ - ssl_context = SSLContext(PROTOCOL_TLS_CLIENT) - - AsyncCluster(ssl_context=ssl_context) - - from async_cassandra.cluster import Cluster as ClusterImport - - call_args = ClusterImport.call_args - - # Verify SSL context passed through - assert call_args.kwargs["ssl_context"] == ssl_context - - def test_protocol_version_validation_v1(self, mock_cluster): - """ - Test that protocol version 1 is rejected. - - What this tests: - --------------- - 1. Protocol v1 raises ConfigurationError - 2. Error message explains the requirement - 3. Suggests Cassandra upgrade path - - Why we require v5+: - ------------------ - Protocol v5 (Cassandra 4.0+) provides: - - Improved async operations - - Better error handling - - Enhanced performance features - - Required for some async patterns - - Protocol v1-v4 limitations: - - Missing features we depend on - - Less efficient for async operations - - Older Cassandra versions (pre-4.0) - - This ensures users have a compatible setup - before they encounter runtime issues. - """ - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster(protocol_version=1) - - # Verify helpful error message - assert "Protocol version 1 is not supported" in str(exc_info.value) - assert "requires CQL protocol v5 or higher" in str(exc_info.value) - assert "Cassandra 4.0" in str(exc_info.value) - - def test_protocol_version_validation_v2(self, mock_cluster): - """ - Test that protocol version 2 is rejected. - - What this tests: - --------------- - 1. Protocol version 2 validation and rejection - 2. Clear error message for unsupported version - 3. Guidance on minimum required version - 4. Early validation before cluster creation - - Why this matters: - ---------------- - - Protocol v2 lacks async-friendly features - - Prevents runtime failures from missing capabilities - - Helps users upgrade to supported Cassandra versions - - Clear error messages reduce debugging time - - Additional context: - --------------------------------- - - Protocol v2 was used in Cassandra 2.0 - - Lacks continuous paging and other v5+ features - - Common when migrating from old clusters - """ - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster(protocol_version=2) - - assert "Protocol version 2 is not supported" in str(exc_info.value) - assert "requires CQL protocol v5 or higher" in str(exc_info.value) - - def test_protocol_version_validation_v3(self, mock_cluster): - """ - Test that protocol version 3 is rejected. - - What this tests: - --------------- - 1. Protocol version 3 validation and rejection - 2. Proper error handling for intermediate versions - 3. Consistent error messaging across versions - 4. 
Configuration validation at initialization - - Why this matters: - ---------------- - - Protocol v3 still lacks critical async features - - Common version in legacy deployments - - Users need clear upgrade path guidance - - Prevents subtle bugs from missing features - - Additional context: - --------------------------------- - - Protocol v3 was used in Cassandra 2.1-2.2 - - Added some features but not enough for async - - Many production clusters still use this - """ - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster(protocol_version=3) - - assert "Protocol version 3 is not supported" in str(exc_info.value) - assert "requires CQL protocol v5 or higher" in str(exc_info.value) - - def test_protocol_version_validation_v4(self, mock_cluster): - """ - Test that protocol version 4 is rejected. - - What this tests: - --------------- - 1. Protocol version 4 validation and rejection - 2. Handling of most common incompatible version - 3. Clear upgrade guidance in error message - 4. Protection against near-miss configurations - - Why this matters: - ---------------- - - Protocol v4 is extremely common (Cassandra 3.x) - - Users often assume v4 is "good enough" - - Missing v5 features cause subtle async issues - - Most frequent configuration error - - Additional context: - --------------------------------- - - Protocol v4 was standard in Cassandra 3.x - - Very close to v5 but missing key improvements - - Requires Cassandra 4.0+ upgrade for v5 - """ - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster(protocol_version=4) - - assert "Protocol version 4 is not supported" in str(exc_info.value) - assert "requires CQL protocol v5 or higher" in str(exc_info.value) - - def test_protocol_version_validation_v5(self, mock_cluster): - """ - Test that protocol version 5 is accepted. - - What this tests: - --------------- - 1. Protocol version 5 is accepted without error - 2. Minimum supported version works correctly - 3. Version is properly passed to underlying driver - 4. No warnings for supported versions - - Why this matters: - ---------------- - - Protocol v5 is our minimum requirement - - First version with all async-friendly features - - Baseline for production deployments - - Must work flawlessly as the default - - Additional context: - --------------------------------- - - Protocol v5 introduced in Cassandra 4.0 - - Adds continuous paging and duration type - - Required for optimal async performance - """ - # Should not raise - AsyncCluster(protocol_version=5) - - from async_cassandra.cluster import Cluster as ClusterImport - - call_args = ClusterImport.call_args - assert call_args.kwargs["protocol_version"] == 5 - - def test_protocol_version_validation_v6(self, mock_cluster): - """ - Test that protocol version 6 is accepted. - - What this tests: - --------------- - 1. Protocol version 6 is accepted without error - 2. Future protocol versions are supported - 3. Version is correctly propagated to driver - 4. 
Forward compatibility is maintained - - Why this matters: - ---------------- - - Users on latest Cassandra need v6 support - - Future-proofing for new deployments - - Enables access to latest features - - Prevents forced downgrades - - Additional context: - --------------------------------- - - Protocol v6 introduced in Cassandra 4.1 - - Adds vector types and other improvements - - Backward compatible with v5 features - """ - # Should not raise - AsyncCluster(protocol_version=6) - - from async_cassandra.cluster import Cluster as ClusterImport - - call_args = ClusterImport.call_args - assert call_args.kwargs["protocol_version"] == 6 - - def test_protocol_version_none(self, mock_cluster): - """ - Test that no protocol version allows driver negotiation. - - What this tests: - --------------- - 1. Protocol version is optional - 2. Driver can negotiate version - 3. We validate after connection - - Why this matters: - ---------------- - Allows flexibility: - - Driver picks best version - - Works with various Cassandra versions - - Fails clearly if negotiated version < 5 - """ - # Should not raise and should not set protocol_version - AsyncCluster() - - from async_cassandra.cluster import Cluster as ClusterImport - - call_args = ClusterImport.call_args - # No protocol_version means driver negotiates - assert "protocol_version" not in call_args.kwargs - - @pytest.mark.asyncio - async def test_protocol_version_mismatch_error(self, mock_cluster): - """ - Test that protocol version mismatch errors are handled properly. - - What this tests: - --------------- - 1. NoHostAvailable with protocol errors get special handling - 2. Clear error message about version mismatch - 3. Actionable advice (upgrade Cassandra) - - Why this matters: - ---------------- - Common scenario: - - User tries to connect to Cassandra 3.x - - Driver requests protocol v5 - - Server only supports v4 - - Without special handling: - - Generic "NoHostAvailable" error - - User doesn't know why connection failed - - With our handling: - - Clear message about protocol version - - Tells user to upgrade to Cassandra 4.0+ - """ - async_cluster = AsyncCluster() - - # Mock NoHostAvailable with protocol error - from cassandra.cluster import NoHostAvailable - - protocol_error = Exception("ProtocolError: Server does not support protocol version 5") - no_host_error = NoHostAvailable("Unable to connect", {"host1": protocol_error}) - - with patch("async_cassandra.cluster.AsyncCassandraSession.create") as mock_create: - mock_create.side_effect = no_host_error - - with pytest.raises(ConnectionError) as exc_info: - await async_cluster.connect() - - # Verify helpful error message - error_msg = str(exc_info.value) - assert "Your Cassandra server doesn't support protocol v5" in error_msg - assert "Cassandra 4.0+" in error_msg - assert "Please upgrade your Cassandra cluster" in error_msg - - @pytest.mark.asyncio - async def test_negotiated_protocol_version_too_low(self, mock_cluster): - """ - Test that negotiated protocol version < 5 is rejected after connection. - - What this tests: - --------------- - 1. Protocol validation happens after connection - 2. Session is properly closed on failure - 3. 
Clear error about negotiated version - - Why this matters: - ---------------- - Scenario: - - User doesn't specify protocol version - - Driver negotiates with server - - Server offers v4 (Cassandra 3.x) - - We detect this and fail cleanly - - This catches the case where: - - Connection succeeds (server is running) - - But protocol is incompatible - - Must clean up the session - - Without this check: - - Async operations might fail mysteriously - - Users get confusing errors later - """ - async_cluster = AsyncCluster() - - # Mock the cluster to return protocol_version 4 after connection - mock_cluster.protocol_version = 4 - - mock_session = Mock(spec=AsyncCassandraSession) - - # Track if close was called - close_called = False - - async def async_close(): - nonlocal close_called - close_called = True - - mock_session.close = async_close - - with patch("async_cassandra.cluster.AsyncCassandraSession.create") as mock_create: - # Make create return a coroutine that returns the session - async def create_session(cluster, keyspace): - return mock_session - - mock_create.side_effect = create_session - - with pytest.raises(ConnectionError) as exc_info: - await async_cluster.connect() - - # Verify specific error about negotiated version - error_msg = str(exc_info.value) - assert "Connected with protocol v4 but v5+ is required" in error_msg - assert "Your Cassandra server only supports up to protocol v4" in error_msg - assert "Cassandra 4.0+" in error_msg - - # Verify cleanup happened - assert close_called, "Session close() should have been called" diff --git a/tests/unit/test_cluster_edge_cases.py b/tests/unit/test_cluster_edge_cases.py deleted file mode 100644 index fbc9b29..0000000 --- a/tests/unit/test_cluster_edge_cases.py +++ /dev/null @@ -1,546 +0,0 @@ -""" -Unit tests for cluster edge cases and failure scenarios. - -Tests how the async wrapper handles various cluster-level failures and edge cases -within its existing functionality. -""" - -import asyncio -import time -from unittest.mock import Mock, patch - -import pytest -from cassandra.cluster import NoHostAvailable - -from async_cassandra import AsyncCluster -from async_cassandra.exceptions import ConnectionError - - -class TestClusterEdgeCases: - """Test cluster edge cases and failure scenarios.""" - - def _create_mock_cluster(self): - """Create a properly configured mock cluster.""" - mock_cluster = Mock() - mock_cluster.protocol_version = 5 - mock_cluster.shutdown = Mock() - return mock_cluster - - @pytest.mark.asyncio - async def test_protocol_version_validation(self): - """ - Test that protocol versions below v5 are rejected. - - What this tests: - --------------- - 1. Protocol v4 and below rejected - 2. ConfigurationError at creation - 3. v5+ versions accepted - 4. Clear error messages - - Why this matters: - ---------------- - async-cassandra requires v5+ for: - - Required async features - - Better performance - - Modern functionality - - Failing early prevents confusing - runtime errors. 
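Illustrative usage from the caller's side (a minimal sketch, not part of the original test; it uses only the AsyncCluster and ConfigurationError API already imported in this file):

```python
from async_cassandra import AsyncCluster
from async_cassandra.exceptions import ConfigurationError

try:
    cluster = AsyncCluster(protocol_version=4)
except ConfigurationError as exc:
    # Raised at construction time; the message points at CQL protocol v5+
    # and the Cassandra 4.0+ upgrade path.
    print(f"Unsupported configuration: {exc}")

# Protocol v5 and above are accepted and passed through to the driver.
cluster = AsyncCluster(protocol_version=5)
```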
- """ - from async_cassandra.exceptions import ConfigurationError - - # Should reject v4 and below - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster(protocol_version=4) - - assert "Protocol version 4 is not supported" in str(exc_info.value) - assert "requires CQL protocol v5 or higher" in str(exc_info.value) - - # Should accept v5 and above - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = self._create_mock_cluster() - mock_cluster_class.return_value = mock_cluster - - # v5 should work - cluster5 = AsyncCluster(protocol_version=5) - assert cluster5._cluster == mock_cluster - - # v6 should work - cluster6 = AsyncCluster(protocol_version=6) - assert cluster6._cluster == mock_cluster - - @pytest.mark.asyncio - async def test_connection_retry_with_protocol_error(self): - """ - Test that protocol version errors are not retried. - - What this tests: - --------------- - 1. Protocol errors fail fast - 2. No retry for version mismatch - 3. Clear error message - 4. Single attempt only - - Why this matters: - ---------------- - Protocol errors aren't transient: - - Server won't change version - - Retrying wastes time - - User needs to upgrade - - Fast failure enables quick - diagnosis and resolution. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = self._create_mock_cluster() - mock_cluster_class.return_value = mock_cluster - - # Count connection attempts - connect_count = 0 - - def connect_side_effect(*args, **kwargs): - nonlocal connect_count - connect_count += 1 - # Create NoHostAvailable with protocol error details - error = NoHostAvailable( - "Unable to connect to any servers", - {"127.0.0.1": Exception("ProtocolError: Cannot negotiate protocol version")}, - ) - raise error - - # Mock sync connect to fail with protocol error - mock_cluster.connect.side_effect = connect_side_effect - - async_cluster = AsyncCluster() - - # Should fail immediately without retrying - with pytest.raises(ConnectionError) as exc_info: - await async_cluster.connect() - - # Should only try once (no retries for protocol errors) - assert connect_count == 1 - assert "doesn't support protocol v5" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_connection_retry_with_reset_errors(self): - """ - Test connection retry with connection reset errors. - - What this tests: - --------------- - 1. Connection resets trigger retry - 2. Exponential backoff applied - 3. Eventually succeeds - 4. Retry timing increases - - Why this matters: - ---------------- - Connection resets are transient: - - Network hiccups - - Server restarts - - Load balancer changes - - Automatic retry with backoff - handles temporary issues gracefully. 
- """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = self._create_mock_cluster() - mock_cluster.protocol_version = 5 # Set a valid protocol version - mock_cluster_class.return_value = mock_cluster - - # Track timing of retries - call_times = [] - - def connect_side_effect(*args, **kwargs): - call_times.append(time.time()) - - # Fail first 2 attempts with connection reset - if len(call_times) <= 2: - error = NoHostAvailable( - "Unable to connect to any servers", - {"127.0.0.1": Exception("Connection reset by peer")}, - ) - raise error - else: - # Third attempt succeeds - mock_session = Mock() - return mock_session - - mock_cluster.connect.side_effect = connect_side_effect - - async_cluster = AsyncCluster() - - # Should eventually succeed after retries - session = await async_cluster.connect() - assert session is not None - - # Should have retried 3 times total - assert len(call_times) == 3 - - # Check retry delays increased (connection reset uses longer delays) - if len(call_times) > 2: - delay1 = call_times[1] - call_times[0] - delay2 = call_times[2] - call_times[1] - # Second delay should be longer than first - assert delay2 > delay1 - - @pytest.mark.asyncio - async def test_concurrent_connect_attempts(self): - """ - Test handling of concurrent connection attempts. - - What this tests: - --------------- - 1. Concurrent connects allowed - 2. Each gets separate session - 3. No connection reuse - 4. Thread-safe operation - - Why this matters: - ---------------- - Real apps may connect concurrently: - - Multiple workers starting - - Parallel initialization - - No singleton pattern - - Must handle concurrent connects - without deadlock or corruption. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = self._create_mock_cluster() - mock_cluster_class.return_value = mock_cluster - - # Make connect slow to ensure concurrency - connect_count = 0 - sessions_created = [] - - def slow_connect(*args, **kwargs): - nonlocal connect_count - connect_count += 1 - # This is called from an executor, so we can use time.sleep - time.sleep(0.1) - session = Mock() - session.id = connect_count - sessions_created.append(session) - return session - - mock_cluster.connect = Mock(side_effect=slow_connect) - - async_cluster = AsyncCluster() - - # Try to connect concurrently - tasks = [async_cluster.connect(), async_cluster.connect(), async_cluster.connect()] - - results = await asyncio.gather(*tasks) - - # All should return sessions - assert all(r is not None for r in results) - - # Should have called connect multiple times - # (no connection caching in current implementation) - assert mock_cluster.connect.call_count == 3 - - @pytest.mark.asyncio - async def test_cluster_shutdown_timeout(self): - """ - Test cluster shutdown with timeout. - - What this tests: - --------------- - 1. Shutdown can timeout - 2. TimeoutError raised - 3. Hanging shutdown detected - 4. Async timeout works - - Why this matters: - ---------------- - Shutdown can hang due to: - - Network issues - - Deadlocked threads - - Resource cleanup bugs - - Timeout prevents app hanging - during shutdown. 
- """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = self._create_mock_cluster() - mock_cluster_class.return_value = mock_cluster - - # Make shutdown hang - import threading - - def hanging_shutdown(): - # Use threading.Event to wait without consuming CPU - event = threading.Event() - event.wait(2) # Short wait, will be interrupted by the test timeout - - mock_cluster.shutdown.side_effect = hanging_shutdown - - async_cluster = AsyncCluster() - - # Should timeout during shutdown - with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(async_cluster.shutdown(), timeout=1.0) - - @pytest.mark.asyncio - async def test_cluster_double_shutdown(self): - """ - Test that cluster shutdown is idempotent. - - What this tests: - --------------- - 1. Multiple shutdowns safe - 2. Only shuts down once - 3. is_closed flag works - 4. close() also idempotent - - Why this matters: - ---------------- - Idempotent shutdown critical for: - - Error handling paths - - Cleanup in finally blocks - - Multiple shutdown sources - - Prevents errors during cleanup - and resource leaks. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = self._create_mock_cluster() - mock_cluster_class.return_value = mock_cluster - mock_cluster.shutdown = Mock() - - async_cluster = AsyncCluster() - - # First shutdown - await async_cluster.shutdown() - assert mock_cluster.shutdown.call_count == 1 - assert async_cluster.is_closed - - # Second shutdown should be safe - await async_cluster.shutdown() - # Should still only be called once - assert mock_cluster.shutdown.call_count == 1 - assert async_cluster.is_closed - - # Third shutdown via close() - await async_cluster.close() - assert mock_cluster.shutdown.call_count == 1 - - @pytest.mark.asyncio - async def test_cluster_metadata_access(self): - """ - Test accessing cluster metadata. - - What this tests: - --------------- - 1. Metadata accessible - 2. Keyspace info available - 3. Direct passthrough - 4. No async wrapper needed - - Why this matters: - ---------------- - Metadata access enables: - - Schema discovery - - Dynamic queries - - ORM functionality - - Must work seamlessly through - async wrapper. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = self._create_mock_cluster() - mock_metadata = Mock() - mock_metadata.keyspaces = {"system": Mock()} - mock_cluster.metadata = mock_metadata - mock_cluster_class.return_value = mock_cluster - - async_cluster = AsyncCluster() - - # Should provide access to metadata - metadata = async_cluster.metadata - assert metadata == mock_metadata - assert "system" in metadata.keyspaces - - @pytest.mark.asyncio - async def test_register_user_type(self): - """ - Test user type registration. - - What this tests: - --------------- - 1. UDT registration works - 2. Delegates to driver - 3. Parameters passed through - 4. Type mapping enabled - - Why this matters: - ---------------- - User-defined types (UDTs): - - Complex data modeling - - Type-safe operations - - ORM integration - - Registration must work for - proper UDT handling. 
- """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = self._create_mock_cluster() - mock_cluster.register_user_type = Mock() - mock_cluster_class.return_value = mock_cluster - - async_cluster = AsyncCluster() - - # Register a user type - class UserAddress: - pass - - async_cluster.register_user_type("my_keyspace", "address", UserAddress) - - # Should delegate to underlying cluster - mock_cluster.register_user_type.assert_called_once_with( - "my_keyspace", "address", UserAddress - ) - - @pytest.mark.asyncio - async def test_connection_with_auth_failure(self): - """ - Test connection with authentication failure. - - What this tests: - --------------- - 1. Auth failures retried - 2. Multiple attempts made - 3. Eventually fails - 4. Clear error message - - Why this matters: - ---------------- - Auth failures might be transient: - - Token expiration timing - - Auth service hiccup - - Race conditions - - Limited retry gives auth - issues chance to resolve. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = self._create_mock_cluster() - mock_cluster_class.return_value = mock_cluster - - from cassandra import AuthenticationFailed - - # Mock auth failure - auth_error = NoHostAvailable( - "Unable to connect to any servers", - {"127.0.0.1": AuthenticationFailed("Bad credentials")}, - ) - mock_cluster.connect.side_effect = auth_error - - async_cluster = AsyncCluster() - - # Should fail after retries - with pytest.raises(ConnectionError) as exc_info: - await async_cluster.connect() - - # Should have retried (auth errors are retried in case of transient issues) - assert mock_cluster.connect.call_count == 3 - assert "Failed to connect to cluster after 3 attempts" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_connection_with_mixed_errors(self): - """ - Test connection with different errors on different attempts. - - What this tests: - --------------- - 1. Different errors per attempt - 2. All attempts exhausted - 3. Last error reported - 4. Varied error handling - - Why this matters: - ---------------- - Real failures are messy: - - Different nodes fail differently - - Errors change over time - - Mixed failure modes - - Must handle varied errors - during connection attempts. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = self._create_mock_cluster() - mock_cluster_class.return_value = mock_cluster - - # Different error each attempt - errors = [ - NoHostAvailable( - "Unable to connect", {"127.0.0.1": Exception("Connection refused")} - ), - NoHostAvailable( - "Unable to connect", {"127.0.0.1": Exception("Connection reset by peer")} - ), - Exception("Unexpected error"), - ] - - attempt = 0 - - def connect_side_effect(*args, **kwargs): - nonlocal attempt - error = errors[attempt] - attempt += 1 - raise error - - mock_cluster.connect.side_effect = connect_side_effect - - async_cluster = AsyncCluster() - - # Should fail after all retries - with pytest.raises(ConnectionError) as exc_info: - await async_cluster.connect() - - # Should have tried all attempts - assert mock_cluster.connect.call_count == 3 - assert "Unexpected error" in str(exc_info.value) # Last error - - @pytest.mark.asyncio - async def test_create_with_auth_convenience_method(self): - """ - Test create_with_auth convenience method. - - What this tests: - --------------- - 1. Auth provider created - 2. Credentials passed correctly - 3. Other params preserved - 4. 
Convenience method works - - Why this matters: - ---------------- - Simple auth setup critical: - - Common use case - - Easy to get wrong - - Security sensitive - - Convenience method reduces - auth configuration errors. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster = self._create_mock_cluster() - mock_cluster_class.return_value = mock_cluster - - # Create with auth - AsyncCluster.create_with_auth( - contact_points=["10.0.0.1"], username="cassandra", password="cassandra", port=9043 - ) - - # Verify auth provider was created - call_kwargs = mock_cluster_class.call_args[1] - assert "auth_provider" in call_kwargs - auth_provider = call_kwargs["auth_provider"] - assert auth_provider is not None - # Verify other params - assert call_kwargs["contact_points"] == ["10.0.0.1"] - assert call_kwargs["port"] == 9043 diff --git a/tests/unit/test_cluster_retry.py b/tests/unit/test_cluster_retry.py deleted file mode 100644 index 76de897..0000000 --- a/tests/unit/test_cluster_retry.py +++ /dev/null @@ -1,258 +0,0 @@ -""" -Unit tests for cluster connection retry logic. -""" - -import asyncio -from unittest.mock import Mock, patch - -import pytest -from cassandra.cluster import NoHostAvailable - -from async_cassandra.cluster import AsyncCluster -from async_cassandra.exceptions import ConnectionError - - -@pytest.mark.asyncio -class TestClusterConnectionRetry: - """Test cluster connection retry behavior.""" - - async def test_connection_retries_on_failure(self): - """ - Test that connection attempts are retried on failure. - - What this tests: - --------------- - 1. Failed connections retry - 2. Third attempt succeeds - 3. Total of 3 attempts - 4. Eventually returns session - - Why this matters: - ---------------- - Connection failures are common: - - Network hiccups - - Node startup delays - - Temporary unavailability - - Automatic retry improves - reliability significantly. - """ - mock_cluster = Mock() - # Mock protocol version to pass validation - mock_cluster.protocol_version = 5 - - # Create a mock that fails twice then succeeds - connect_attempts = 0 - mock_session = Mock() - - async def create_side_effect(cluster, keyspace): - nonlocal connect_attempts - connect_attempts += 1 - if connect_attempts < 3: - raise NoHostAvailable("Unable to connect to any servers", {}) - return mock_session # Return a mock session on third attempt - - with patch("async_cassandra.cluster.Cluster", return_value=mock_cluster): - with patch( - "async_cassandra.cluster.AsyncCassandraSession.create", - side_effect=create_side_effect, - ): - cluster = AsyncCluster(["localhost"]) - - # Should succeed after retries - session = await cluster.connect() - assert session is not None - assert connect_attempts == 3 - - async def test_connection_fails_after_max_retries(self): - """ - Test that connection fails after maximum retry attempts. - - What this tests: - --------------- - 1. Max retry limit enforced - 2. Exactly 3 attempts made - 3. ConnectionError raised - 4. Clear failure message - - Why this matters: - ---------------- - Must give up eventually: - - Prevent infinite loops - - Fail with clear error - - Allow app to handle - - Bounded retries prevent - hanging applications. 
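Caller-side handling once retries are exhausted (a sketch; ConnectionError here is async_cassandra's exception, imported at the top of this file):

```python
from async_cassandra.cluster import AsyncCluster
from async_cassandra.exceptions import ConnectionError


async def get_session():
    cluster = AsyncCluster(["localhost"])
    try:
        return await cluster.connect()
    except ConnectionError:
        # Raised after 3 attempts ("Failed to connect to cluster after 3 attempts").
        # Surface it, or back off at the application level, rather than looping.
        raise
```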
- """ - mock_cluster = Mock() - # Mock protocol version to pass validation - mock_cluster.protocol_version = 5 - - create_call_count = 0 - - async def create_side_effect(cluster, keyspace): - nonlocal create_call_count - create_call_count += 1 - raise NoHostAvailable("Unable to connect to any servers", {}) - - with patch("async_cassandra.cluster.Cluster", return_value=mock_cluster): - with patch( - "async_cassandra.cluster.AsyncCassandraSession.create", - side_effect=create_side_effect, - ): - cluster = AsyncCluster(["localhost"]) - - # Should fail after max retries (3) - with pytest.raises(ConnectionError) as exc_info: - await cluster.connect() - - assert "Failed to connect to cluster after 3 attempts" in str(exc_info.value) - assert create_call_count == 3 - - async def test_connection_retry_with_increasing_delay(self): - """ - Test that retry delays increase with each attempt. - - What this tests: - --------------- - 1. Delays between retries - 2. Exponential backoff - 3. NoHostAvailable gets longer delays - 4. Prevents thundering herd - - Why this matters: - ---------------- - Exponential backoff: - - Reduces server load - - Allows recovery time - - Prevents retry storms - - Smart retry timing improves - overall system stability. - """ - mock_cluster = Mock() - # Mock protocol version to pass validation - mock_cluster.protocol_version = 5 - - # Fail all attempts - async def create_side_effect(cluster, keyspace): - raise NoHostAvailable("Unable to connect to any servers", {}) - - sleep_delays = [] - - async def mock_sleep(delay): - sleep_delays.append(delay) - - with patch("async_cassandra.cluster.Cluster", return_value=mock_cluster): - with patch( - "async_cassandra.cluster.AsyncCassandraSession.create", - side_effect=create_side_effect, - ): - with patch("asyncio.sleep", side_effect=mock_sleep): - cluster = AsyncCluster(["localhost"]) - - with pytest.raises(ConnectionError): - await cluster.connect() - - # Should have 2 sleep calls (between 3 attempts) - assert len(sleep_delays) == 2 - # First delay should be 2.0 seconds (NoHostAvailable gets longer delay) - assert sleep_delays[0] == 2.0 - # Second delay should be 4.0 seconds - assert sleep_delays[1] == 4.0 - - async def test_timeout_error_not_retried(self): - """ - Test that asyncio.TimeoutError is not retried. - - What this tests: - --------------- - 1. Timeouts fail immediately - 2. No retry for timeouts - 3. TimeoutError propagated - 4. Fast failure mode - - Why this matters: - ---------------- - Timeouts indicate: - - User-specified limit hit - - Operation too slow - - Should fail fast - - Retrying timeouts would - violate user expectations. - """ - mock_cluster = Mock() - - # Create session that takes too long - async def slow_connect(keyspace=None): - await asyncio.sleep(20) # Longer than timeout - return Mock() - - mock_cluster.connect = Mock(side_effect=lambda k=None: Mock()) - - with patch("async_cassandra.cluster.Cluster", return_value=mock_cluster): - with patch( - "async_cassandra.session.AsyncCassandraSession.create", - side_effect=asyncio.TimeoutError(), - ): - cluster = AsyncCluster(["localhost"]) - - # Should raise TimeoutError without retrying - with pytest.raises(asyncio.TimeoutError): - await cluster.connect(timeout=0.1) - - # Should not have retried (create was called only once) - - async def test_other_exceptions_use_shorter_delay(self): - """ - Test that non-NoHostAvailable exceptions use shorter retry delay. - - What this tests: - --------------- - 1. Different delays by error type - 2. 
Generic errors get short delay - 3. NoHostAvailable gets long delay - 4. Smart backoff strategy - - Why this matters: - ---------------- - Error-specific delays: - - Network errors need more time - - Generic errors retry quickly - - Optimizes recovery time - - Adaptive retry delays improve - connection success rates. - """ - mock_cluster = Mock() - # Mock protocol version to pass validation - mock_cluster.protocol_version = 5 - - # Fail with generic exception - async def create_side_effect(cluster, keyspace): - raise Exception("Generic error") - - sleep_delays = [] - - async def mock_sleep(delay): - sleep_delays.append(delay) - - with patch("async_cassandra.cluster.Cluster", return_value=mock_cluster): - with patch( - "async_cassandra.cluster.AsyncCassandraSession.create", - side_effect=create_side_effect, - ): - with patch("asyncio.sleep", side_effect=mock_sleep): - cluster = AsyncCluster(["localhost"]) - - with pytest.raises(ConnectionError): - await cluster.connect() - - # Should have 2 sleep calls - assert len(sleep_delays) == 2 - # First delay should be 0.5 seconds (generic exception) - assert sleep_delays[0] == 0.5 - # Second delay should be 1.0 seconds - assert sleep_delays[1] == 1.0 diff --git a/tests/unit/test_connection_pool_exhaustion.py b/tests/unit/test_connection_pool_exhaustion.py deleted file mode 100644 index b9b4b6a..0000000 --- a/tests/unit/test_connection_pool_exhaustion.py +++ /dev/null @@ -1,622 +0,0 @@ -""" -Unit tests for connection pool exhaustion scenarios. - -Tests how the async wrapper handles: -- Pool exhaustion under high load -- Connection borrowing timeouts -- Pool recovery after exhaustion -- Connection health checks - -Test Organization: -================== -1. Pool Exhaustion - Running out of connections -2. Borrowing Timeouts - Waiting for available connections -3. Recovery - Pool recovering after exhaustion -4. Health Checks - Connection health monitoring -5. Metrics - Tracking pool usage and exhaustion -6. 
Graceful Degradation - Prioritizing critical queries - -Key Testing Principles: -====================== -- Simulate realistic pool limits -- Test concurrent access patterns -- Verify recovery mechanisms -- Track exhaustion metrics -""" - -import asyncio -from unittest.mock import Mock - -import pytest -from cassandra import OperationTimedOut -from cassandra.cluster import Session -from cassandra.pool import Host, HostConnectionPool, NoConnectionsAvailable - -from async_cassandra import AsyncCassandraSession - - -class TestConnectionPoolExhaustion: - """Test connection pool exhaustion scenarios.""" - - @pytest.fixture - def mock_session(self): - """Create a mock session with connection pool.""" - session = Mock(spec=Session) - session.execute_async = Mock() - session.cluster = Mock() - - # Mock pool manager - session.cluster._core_connections_per_host = 2 - session.cluster._max_connections_per_host = 8 - - return session - - @pytest.fixture - def mock_connection_pool(self): - """Create a mock connection pool.""" - pool = Mock(spec=HostConnectionPool) - pool.host = Mock(spec=Host, address="127.0.0.1") - pool.is_shutdown = False - pool.open_count = 0 - pool.in_flight = 0 - return pool - - def create_error_future(self, exception): - """Create a mock future that raises the given exception.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - if errback: - errbacks.append(errback) - # Call errback immediately with the error - errback(exception) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - def create_success_future(self, result): - """Create a mock future that returns a result.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - # For success, the callback expects an iterable of rows - mock_rows = [result] if result else [] - callback(mock_rows) - if errback: - errbacks.append(errback) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - @pytest.mark.asyncio - async def test_pool_exhaustion_under_load(self, mock_session): - """ - Test behavior when connection pool is exhausted. - - What this tests: - --------------- - 1. Pool has finite connection limit - 2. Excess queries fail with NoConnectionsAvailable - 3. Exceptions passed through directly - 4. Success/failure count matches pool size - - Why this matters: - ---------------- - Connection pools prevent resource exhaustion: - - Each connection uses memory/CPU - - Database has connection limits - - Pool size must be tuned - - Applications need direct access to - handle pool exhaustion with retries. 
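Hypothetical application-side handling (the wrapper surfaces the driver's NoConnectionsAvailable directly, so callers can retry after a short pause, as the recovery test below does):

```python
import asyncio

from cassandra.pool import NoConnectionsAvailable


async def execute_with_pool_retry(session, query, retries=3):
    """Retry a query briefly when the connection pool is exhausted."""
    for attempt in range(retries):
        try:
            return await session.execute(query)
        except NoConnectionsAvailable:
            if attempt == retries - 1:
                raise
            # Give borrowed connections a moment to return to the pool.
            await asyncio.sleep(0.1)
```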
- """ - async_session = AsyncCassandraSession(mock_session) - - # Configure mock to simulate pool exhaustion after N requests - pool_size = 5 - request_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal request_count - request_count += 1 - - if request_count > pool_size: - # Pool exhausted - return self.create_error_future(NoConnectionsAvailable("Connection pool exhausted")) - - # Success response - return self.create_success_future({"id": request_count}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Try to execute more queries than pool size - tasks = [] - for i in range(pool_size + 3): # 3 more than pool size - tasks.append(async_session.execute(f"SELECT * FROM test WHERE id = {i}")) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - # First pool_size queries should succeed - successful = [r for r in results if not isinstance(r, Exception)] - # NoConnectionsAvailable is now passed through directly - failed = [r for r in results if isinstance(r, NoConnectionsAvailable)] - - assert len(successful) == pool_size - assert len(failed) == 3 - - @pytest.mark.asyncio - async def test_connection_borrowing_timeout(self, mock_session): - """ - Test timeout when waiting for available connection. - - What this tests: - --------------- - 1. Waiting for connections can timeout - 2. OperationTimedOut raised - 3. Clear error message - 4. Not wrapped (driver exception) - - Why this matters: - ---------------- - When pool is exhausted, queries wait. - If wait is too long: - - Client timeout exceeded - - Better to fail fast - - Allow retry with backoff - - Timeouts prevent indefinite blocking. - """ - async_session = AsyncCassandraSession(mock_session) - - # Simulate all connections busy - mock_session.execute_async.return_value = self.create_error_future( - OperationTimedOut("Timed out waiting for connection from pool") - ) - - # Should timeout waiting for connection - with pytest.raises(OperationTimedOut) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "waiting for connection" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_pool_recovery_after_exhaustion(self, mock_session): - """ - Test that pool recovers after temporary exhaustion. - - What this tests: - --------------- - 1. Pool exhaustion is temporary - 2. Connections return to pool - 3. New queries succeed after recovery - 4. No permanent failure - - Why this matters: - ---------------- - Pool exhaustion often transient: - - Burst of traffic - - Slow queries holding connections - - Temporary spike - - Applications should retry after - brief delay for pool recovery. 
- """ - async_session = AsyncCassandraSession(mock_session) - - # Track pool state - query_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal query_count - query_count += 1 - - if query_count <= 3: - # First 3 queries fail - return self.create_error_future(NoConnectionsAvailable("Pool exhausted")) - - # Subsequent queries succeed - return self.create_success_future({"id": query_count}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # First attempts fail - for i in range(3): - with pytest.raises(NoConnectionsAvailable): - await async_session.execute("SELECT * FROM test") - - # Wait a bit (simulating pool recovery) - await asyncio.sleep(0.1) - - # Next attempt should succeed - result = await async_session.execute("SELECT * FROM test") - assert result.rows[0]["id"] == 4 - - @pytest.mark.asyncio - async def test_connection_health_checks(self, mock_session, mock_connection_pool): - """ - Test connection health checking during pool management. - - What this tests: - --------------- - 1. Unhealthy connections detected - 2. Bad connections removed from pool - 3. Health checks periodic - 4. Pool maintains health - - Why this matters: - ---------------- - Connections can become unhealthy: - - Network issues - - Server restarts - - Idle timeouts - - Health checks ensure pool only - contains usable connections. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock pool with health check capability - mock_session._pools = {Mock(address="127.0.0.1"): mock_connection_pool} - - # Since AsyncCassandraSession doesn't have these methods, - # we'll test by simulating health checks through queries - health_check_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal health_check_count - health_check_count += 1 - # Every 3rd query simulates unhealthy connection - if health_check_count % 3 == 0: - return self.create_error_future(NoConnectionsAvailable("Connection unhealthy")) - return self.create_success_future({"healthy": True}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Execute queries to simulate health checks - results = [] - for i in range(5): - try: - result = await async_session.execute(f"SELECT {i}") - results.append(result) - except NoConnectionsAvailable: # NoConnectionsAvailable is now passed through directly - results.append(None) - - # Should have 1 failure (3rd query) - assert sum(1 for r in results if r is None) == 1 - assert sum(1 for r in results if r is not None) == 4 - assert health_check_count == 5 - - @pytest.mark.asyncio - async def test_concurrent_pool_exhaustion(self, mock_session): - """ - Test multiple threads hitting pool exhaustion simultaneously. - - What this tests: - --------------- - 1. Concurrent queries compete for connections - 2. Pool limits enforced under concurrency - 3. Some queries fail, some succeed - 4. No race conditions or corruption - - Why this matters: - ---------------- - Real applications have concurrent load: - - Multiple API requests - - Background jobs - - Batch processing - - Pool must handle concurrent access - safely without deadlocks. 
- """ - async_session = AsyncCassandraSession(mock_session) - - # Simulate limited pool - available_connections = 2 - lock = asyncio.Lock() - - async def acquire_connection(): - async with lock: - nonlocal available_connections - if available_connections > 0: - available_connections -= 1 - return True - return False - - async def release_connection(): - async with lock: - nonlocal available_connections - available_connections += 1 - - async def execute_with_pool_limit(*args, **kwargs): - if await acquire_connection(): - try: - await asyncio.sleep(0.1) # Hold connection - return Mock(one=Mock(return_value={"success": True})) - finally: - await release_connection() - else: - raise NoConnectionsAvailable("No connections available") - - # Mock limited pool behavior - concurrent_count = 0 - max_concurrent = 2 - - def execute_async_side_effect(*args, **kwargs): - nonlocal concurrent_count - - if concurrent_count >= max_concurrent: - return self.create_error_future(NoConnectionsAvailable("No connections available")) - - concurrent_count += 1 - # Simulate delayed response - return self.create_success_future({"success": True}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Try to execute many concurrent queries - tasks = [async_session.execute(f"SELECT {i}") for i in range(10)] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Should have mix of successes and failures - successes = sum(1 for r in results if not isinstance(r, Exception)) - failures = sum(1 for r in results if isinstance(r, NoConnectionsAvailable)) - - assert successes >= max_concurrent - assert failures > 0 - - @pytest.mark.asyncio - async def test_pool_metrics_tracking(self, mock_session, mock_connection_pool): - """ - Test tracking of pool metrics during exhaustion. - - What this tests: - --------------- - 1. Borrow attempts counted - 2. Timeouts tracked - 3. Exhaustion events recorded - 4. Metrics help diagnose issues - - Why this matters: - ---------------- - Pool metrics are critical for: - - Capacity planning - - Performance tuning - - Alerting on exhaustion - - Debugging production issues - - Without metrics, pool problems - are invisible until failure. 
- """ - async_session = AsyncCassandraSession(mock_session) - - # Track pool metrics - metrics = { - "borrow_attempts": 0, - "borrow_timeouts": 0, - "pool_exhausted_events": 0, - "max_waiters": 0, - } - - def track_borrow_attempt(): - metrics["borrow_attempts"] += 1 - - def track_borrow_timeout(): - metrics["borrow_timeouts"] += 1 - - def track_pool_exhausted(): - metrics["pool_exhausted_events"] += 1 - - # Simulate pool exhaustion scenario - attempt = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal attempt - attempt += 1 - track_borrow_attempt() - - if attempt <= 3: - track_pool_exhausted() - raise NoConnectionsAvailable("Pool exhausted") - elif attempt == 4: - track_borrow_timeout() - raise OperationTimedOut("Timeout waiting for connection") - else: - return self.create_success_future({"metrics": "ok"}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Execute queries to trigger various pool states - for i in range(6): - try: - await async_session.execute(f"SELECT {i}") - except Exception: - pass - - # Verify metrics were tracked - assert metrics["borrow_attempts"] == 6 - assert metrics["pool_exhausted_events"] == 3 - assert metrics["borrow_timeouts"] == 1 - - @pytest.mark.asyncio - async def test_pool_size_limits(self, mock_session): - """ - Test respecting min/max connection limits. - - What this tests: - --------------- - 1. Pool respects maximum size - 2. Minimum connections maintained - 3. Cannot exceed limits - 4. Queries work within limits - - Why this matters: - ---------------- - Pool limits prevent: - - Resource exhaustion (max) - - Cold start delays (min) - - Database overload - - Proper limits balance resource - usage with performance. - """ - async_session = AsyncCassandraSession(mock_session) - - # Configure pool limits - min_connections = 2 - max_connections = 10 - current_connections = min_connections - - async def adjust_pool_size(target_size): - nonlocal current_connections - if target_size > max_connections: - raise ValueError(f"Cannot exceed max connections: {max_connections}") - elif target_size < min_connections: - raise ValueError(f"Cannot go below min connections: {min_connections}") - current_connections = target_size - return current_connections - - # AsyncCassandraSession doesn't have _adjust_pool_size method - # Test pool limits through query behavior instead - query_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal query_count - query_count += 1 - - # Normal queries succeed - return self.create_success_future({"size": query_count}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Test that we can execute queries up to max_connections - results = [] - for i in range(max_connections): - result = await async_session.execute(f"SELECT {i}") - results.append(result) - - # Verify all queries succeeded - assert len(results) == max_connections - assert results[0].rows[0]["size"] == 1 - assert results[-1].rows[0]["size"] == max_connections - - @pytest.mark.asyncio - async def test_connection_leak_detection(self, mock_session): - """ - Test detection of connection leaks during pool exhaustion. - - What this tests: - --------------- - 1. Connections not returned detected - 2. Leak threshold triggers detection - 3. Borrowed connections tracked - 4. Leaks identified for debugging - - Why this matters: - ---------------- - Connection leaks cause: - - Pool exhaustion - - Performance degradation - - Resource waste - - Early leak detection prevents - production outages. 
- """ - async_session = AsyncCassandraSession(mock_session) # noqa: F841 - - # Track borrowed connections - borrowed_connections = set() - leak_detected = False - - async def borrow_connection(query_id): - nonlocal leak_detected - borrowed_connections.add(query_id) - if len(borrowed_connections) > 5: # Threshold for leak detection - leak_detected = True - return Mock(id=query_id) - - async def return_connection(query_id): - borrowed_connections.discard(query_id) - - # Simulate queries that don't properly return connections - for i in range(10): - await borrow_connection(f"query_{i}") - # Simulate some queries not returning connections (leak) - # Only return every 3rd connection (i=0,3,6,9) - if i % 3 == 0: # Return only some connections - await return_connection(f"query_{i}") - - # Should detect potential leak - # We borrow 10 but only return 4 (0,3,6,9), leaving 6 in borrowed_connections - assert len(borrowed_connections) == 6 # 1,2,4,5,7,8 are still borrowed - assert leak_detected # Should be True since we have > 5 borrowed - - @pytest.mark.asyncio - async def test_graceful_degradation(self, mock_session): - """ - Test graceful degradation when pool is under pressure. - - What this tests: - --------------- - 1. Critical queries prioritized - 2. Non-critical queries rejected - 3. System remains stable - 4. Important work continues - - Why this matters: - ---------------- - Under extreme load: - - Not all queries equal priority - - Critical paths must work - - Better partial service than none - - Graceful degradation maintains - core functionality during stress. - """ - async_session = AsyncCassandraSession(mock_session) - - # Track query attempts and degradation - degradation_active = False - - def execute_async_side_effect(*args, **kwargs): - nonlocal degradation_active - - # Check if it's a critical query - query = args[0] if args else kwargs.get("query", "") - is_critical = "CRITICAL" in str(query) - - if degradation_active and not is_critical: - # Reject non-critical queries during degradation - raise NoConnectionsAvailable("Pool exhausted - non-critical queries rejected") - - return self.create_success_future({"result": "ok"}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Normal operation - result = await async_session.execute("SELECT * FROM test") - assert result.rows[0]["result"] == "ok" - - # Activate degradation - degradation_active = True - - # Non-critical query should fail - with pytest.raises(NoConnectionsAvailable): - await async_session.execute("SELECT * FROM test") - - # Critical query should still work - result = await async_session.execute("CRITICAL: SELECT * FROM system.local") - assert result.rows[0]["result"] == "ok" diff --git a/tests/unit/test_constants.py b/tests/unit/test_constants.py deleted file mode 100644 index bc6b9a2..0000000 --- a/tests/unit/test_constants.py +++ /dev/null @@ -1,343 +0,0 @@ -""" -Unit tests for constants module. -""" - -import pytest - -from async_cassandra.constants import ( - DEFAULT_CONNECTION_TIMEOUT, - DEFAULT_EXECUTOR_THREADS, - DEFAULT_FETCH_SIZE, - DEFAULT_REQUEST_TIMEOUT, - MAX_CONCURRENT_QUERIES, - MAX_EXECUTOR_THREADS, - MAX_RETRY_ATTEMPTS, - MIN_EXECUTOR_THREADS, -) - - -class TestConstants: - """Test all constants are properly defined and have reasonable values.""" - - def test_default_values(self): - """ - Test default values are reasonable. - - What this tests: - --------------- - 1. Fetch size is 1000 - 2. Default threads is 4 - 3. Connection timeout 30s - 4. 
Request timeout 120s - - Why this matters: - ---------------- - Default values affect: - - Performance out-of-box - - Resource consumption - - Timeout behavior - - Good defaults mean most - apps work without tuning. - """ - assert DEFAULT_FETCH_SIZE == 1000 - assert DEFAULT_EXECUTOR_THREADS == 4 - assert DEFAULT_CONNECTION_TIMEOUT == 30.0 # Increased for larger heap sizes - assert DEFAULT_REQUEST_TIMEOUT == 120.0 - - def test_limits(self): - """ - Test limit values are reasonable. - - What this tests: - --------------- - 1. Max queries is 100 - 2. Max retries is 3 - 3. Values not too high - 4. Values not too low - - Why this matters: - ---------------- - Limits prevent: - - Resource exhaustion - - Infinite retries - - System overload - - Reasonable limits protect - production systems. - """ - assert MAX_CONCURRENT_QUERIES == 100 - assert MAX_RETRY_ATTEMPTS == 3 - - def test_thread_pool_settings(self): - """ - Test thread pool settings are reasonable. - - What this tests: - --------------- - 1. Min threads >= 1 - 2. Max threads <= 128 - 3. Min < Max relationship - 4. Default within bounds - - Why this matters: - ---------------- - Thread pool sizing affects: - - Concurrent operations - - Memory usage - - CPU utilization - - Proper bounds prevent thread - explosion and starvation. - """ - assert MIN_EXECUTOR_THREADS == 1 - assert MAX_EXECUTOR_THREADS == 128 - assert MIN_EXECUTOR_THREADS < MAX_EXECUTOR_THREADS - assert MIN_EXECUTOR_THREADS <= DEFAULT_EXECUTOR_THREADS <= MAX_EXECUTOR_THREADS - - def test_timeout_relationships(self): - """ - Test timeout values have reasonable relationships. - - What this tests: - --------------- - 1. Connection < Request timeout - 2. Both timeouts positive - 3. Logical ordering - 4. No zero timeouts - - Why this matters: - ---------------- - Timeout ordering ensures: - - Connect fails before request - - Clear failure modes - - No hanging operations - - Prevents confusing timeout - cascades in production. - """ - # Connection timeout should be less than request timeout - assert DEFAULT_CONNECTION_TIMEOUT < DEFAULT_REQUEST_TIMEOUT - # Both should be positive - assert DEFAULT_CONNECTION_TIMEOUT > 0 - assert DEFAULT_REQUEST_TIMEOUT > 0 - - def test_fetch_size_reasonable(self): - """ - Test fetch size is within reasonable bounds. - - What this tests: - --------------- - 1. Fetch size positive - 2. Not too large (<=10k) - 3. Efficient batching - 4. Memory reasonable - - Why this matters: - ---------------- - Fetch size affects: - - Memory per query - - Network efficiency - - Latency vs throughput - - Balance prevents OOM while - maintaining performance. - """ - assert DEFAULT_FETCH_SIZE > 0 - assert DEFAULT_FETCH_SIZE <= 10000 # Not too large - - def test_concurrent_queries_reasonable(self): - """ - Test concurrent queries limit is reasonable. - - What this tests: - --------------- - 1. Positive limit - 2. Not too high (<=1000) - 3. Allows parallelism - 4. Prevents overload - - Why this matters: - ---------------- - Query limits prevent: - - Connection exhaustion - - Memory explosion - - Cassandra overload - - Protects both client and - server from abuse. - """ - assert MAX_CONCURRENT_QUERIES > 0 - assert MAX_CONCURRENT_QUERIES <= 1000 # Not too large - - def test_retry_attempts_reasonable(self): - """ - Test retry attempts is reasonable. - - What this tests: - --------------- - 1. At least 1 retry - 2. Max 10 retries - 3. Not infinite - 4. 
Allows recovery - - Why this matters: - ---------------- - Retry limits balance: - - Transient error recovery - - Avoiding retry storms - - Fail-fast behavior - - Too many retries hurt - more than help. - """ - assert MAX_RETRY_ATTEMPTS > 0 - assert MAX_RETRY_ATTEMPTS <= 10 # Not too many - - def test_constant_types(self): - """ - Test constants have correct types. - - What this tests: - --------------- - 1. Integers are int - 2. Timeouts are float - 3. No string types - 4. Type consistency - - Why this matters: - ---------------- - Type safety ensures: - - No runtime conversions - - Clear API contracts - - Predictable behavior - - Wrong types cause subtle - bugs in production. - """ - assert isinstance(DEFAULT_FETCH_SIZE, int) - assert isinstance(DEFAULT_EXECUTOR_THREADS, int) - assert isinstance(DEFAULT_CONNECTION_TIMEOUT, float) - assert isinstance(DEFAULT_REQUEST_TIMEOUT, float) - assert isinstance(MAX_CONCURRENT_QUERIES, int) - assert isinstance(MAX_RETRY_ATTEMPTS, int) - assert isinstance(MIN_EXECUTOR_THREADS, int) - assert isinstance(MAX_EXECUTOR_THREADS, int) - - def test_constants_immutable(self): - """ - Test that constants cannot be modified (basic check). - - What this tests: - --------------- - 1. All constants uppercase - 2. Follow Python convention - 3. Clear naming pattern - 4. Module organization - - Why this matters: - ---------------- - Naming conventions: - - Signal immutability - - Improve readability - - Prevent accidents - - UPPERCASE warns developers - not to modify values. - """ - # This is more of a convention test - Python doesn't have true constants - # But we can verify the module defines them properly - import async_cassandra.constants as constants_module - - # Verify all constants are uppercase (Python convention) - for attr_name in dir(constants_module): - if not attr_name.startswith("_"): - attr_value = getattr(constants_module, attr_name) - if isinstance(attr_value, (int, float, str)): - assert attr_name.isupper(), f"Constant {attr_name} should be uppercase" - - @pytest.mark.parametrize( - "constant_name,min_value,max_value", - [ - ("DEFAULT_FETCH_SIZE", 1, 50000), - ("DEFAULT_EXECUTOR_THREADS", 1, 32), - ("DEFAULT_CONNECTION_TIMEOUT", 1.0, 60.0), - ("DEFAULT_REQUEST_TIMEOUT", 10.0, 600.0), - ("MAX_CONCURRENT_QUERIES", 10, 10000), - ("MAX_RETRY_ATTEMPTS", 1, 20), - ("MIN_EXECUTOR_THREADS", 1, 4), - ("MAX_EXECUTOR_THREADS", 32, 256), - ], - ) - def test_constant_ranges(self, constant_name, min_value, max_value): - """ - Test that constants are within expected ranges. - - What this tests: - --------------- - 1. Each constant in range - 2. Not too small - 3. Not too large - 4. Sensible values - - Why this matters: - ---------------- - Range validation prevents: - - Extreme configurations - - Performance problems - - Resource issues - - Catches config errors - before deployment. - """ - import async_cassandra.constants as constants_module - - value = getattr(constants_module, constant_name) - assert ( - min_value <= value <= max_value - ), f"{constant_name} value {value} is outside expected range [{min_value}, {max_value}]" - - def test_no_missing_constants(self): - """ - Test that all expected constants are defined. - - What this tests: - --------------- - 1. All constants present - 2. No missing values - 3. No extra constants - 4. API completeness - - Why this matters: - ---------------- - Complete constants ensure: - - No hardcoded values - - Consistent configuration - - Clear tuning points - - Missing constants force - magic numbers in code. 
- """ - expected_constants = { - "DEFAULT_FETCH_SIZE", - "DEFAULT_EXECUTOR_THREADS", - "DEFAULT_CONNECTION_TIMEOUT", - "DEFAULT_REQUEST_TIMEOUT", - "MAX_CONCURRENT_QUERIES", - "MAX_RETRY_ATTEMPTS", - "MIN_EXECUTOR_THREADS", - "MAX_EXECUTOR_THREADS", - } - - import async_cassandra.constants as constants_module - - module_constants = { - name for name in dir(constants_module) if not name.startswith("_") and name.isupper() - } - - missing = expected_constants - module_constants - assert not missing, f"Missing constants: {missing}" - - # Also check no unexpected constants - unexpected = module_constants - expected_constants - assert not unexpected, f"Unexpected constants: {unexpected}" diff --git a/tests/unit/test_context_manager_safety.py b/tests/unit/test_context_manager_safety.py deleted file mode 100644 index 42c20f6..0000000 --- a/tests/unit/test_context_manager_safety.py +++ /dev/null @@ -1,854 +0,0 @@ -""" -Unit tests for context manager safety. - -These tests ensure that context managers only close what they should, -and don't accidentally close shared resources like clusters and sessions -when errors occur. -""" - -import asyncio -import threading -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from async_cassandra import AsyncCassandraSession, AsyncCluster -from async_cassandra.exceptions import QueryError -from async_cassandra.streaming import AsyncStreamingResultSet - - -class TestContextManagerSafety: - """Test that context managers don't close shared resources inappropriately.""" - - @pytest.mark.asyncio - async def test_cluster_context_manager_closes_only_cluster(self): - """ - Test that cluster context manager only closes the cluster, - not any sessions created from it. - - What this tests: - --------------- - 1. Cluster context manager closes cluster - 2. Sessions remain open after cluster exit - 3. Resources properly scoped - 4. No premature cleanup - - Why this matters: - ---------------- - Context managers must respect ownership: - - Cluster owns its lifecycle - - Sessions own their lifecycle - - No cross-contamination - - Prevents accidental resource cleanup - that breaks active operations. 
- """ - mock_cluster = MagicMock() - mock_cluster.shutdown = MagicMock() # Not AsyncMock because it's called via run_in_executor - mock_cluster.connect = AsyncMock() - mock_cluster.protocol_version = 5 # Mock protocol version - - # Create a mock session that should NOT be closed by cluster context manager - mock_session = MagicMock() - mock_session.close = AsyncMock() - mock_cluster.connect.return_value = mock_session - - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster_class.return_value = mock_cluster - - # Mock AsyncCassandraSession.create - mock_async_session = MagicMock() - mock_async_session._session = mock_session - mock_async_session.close = AsyncMock() - - with patch( - "async_cassandra.session.AsyncCassandraSession.create", new_callable=AsyncMock - ) as mock_create: - mock_create.return_value = mock_async_session - - # Use cluster in context manager - async with AsyncCluster(["localhost"]) as cluster: - # Create a session - session = await cluster.connect() - - # Session should be the mock we created - assert session._session == mock_session - - # Cluster should be shut down - mock_cluster.shutdown.assert_called_once() - - # But session should NOT be closed - mock_session.close.assert_not_called() - - @pytest.mark.asyncio - async def test_session_context_manager_closes_only_session(self): - """ - Test that session context manager only closes the session, - not the cluster it came from. - - What this tests: - --------------- - 1. Session context closes session - 2. Cluster remains open - 3. Independent lifecycles - 4. Clean resource separation - - Why this matters: - ---------------- - Sessions don't own clusters: - - Multiple sessions per cluster - - Cluster outlives sessions - - Sessions are lightweight - - Critical for connection pooling - and resource efficiency. - """ - mock_cluster = MagicMock() - mock_cluster.shutdown = MagicMock() # Not AsyncMock because it's called via run_in_executor - mock_session = MagicMock() - mock_session.shutdown = MagicMock() # AsyncCassandraSession calls shutdown, not close - - # Create AsyncCassandraSession with mocks - async_session = AsyncCassandraSession(mock_session) - - # Use session in context manager - async with async_session: - # Do some work - pass - - # Session should be shut down - mock_session.shutdown.assert_called_once() - - # But cluster should NOT be shut down - mock_cluster.shutdown.assert_not_called() - - @pytest.mark.asyncio - async def test_streaming_context_manager_closes_only_stream(self): - """ - Test that streaming result context manager only closes the stream, - not the session or cluster. - - What this tests: - --------------- - 1. Stream context closes stream - 2. Session remains open - 3. Callbacks cleaned up - 4. No session interference - - Why this matters: - ---------------- - Streams are ephemeral resources: - - One query = one stream - - Session handles many queries - - Stream cleanup is isolated - - Ensures streaming doesn't break - session for other queries. 
- """ - # Create mock response future - mock_future = MagicMock() - mock_future.has_more_pages = False - mock_future._final_exception = None - mock_future.add_callbacks = MagicMock() - mock_future.clear_callbacks = MagicMock() - - # Create mock session (should NOT be closed) - mock_session = MagicMock() - mock_session.close = AsyncMock() - - # Create streaming result - stream_result = AsyncStreamingResultSet(mock_future) - stream_result._handle_page(["row1", "row2", "row3"]) - - # Use streaming result in context manager - async with stream_result as stream: - # Process some data - rows = [] - async for row in stream: - rows.append(row) - - # Stream callbacks should be cleaned up - mock_future.clear_callbacks.assert_called() - - # But session should NOT be closed - mock_session.close.assert_not_called() - - @pytest.mark.asyncio - async def test_query_error_doesnt_close_session(self): - """ - Test that a query error doesn't close the session. - - What this tests: - --------------- - 1. Query errors don't close session - 2. Session remains usable - 3. Error handling isolated - 4. No cascade failures - - Why this matters: - ---------------- - Query errors are normal: - - Bad syntax happens - - Tables may not exist - - Timeouts occur - - Session must survive individual - query failures. - """ - mock_session = MagicMock() - mock_session.close = AsyncMock() - - # Create a session that will raise an error - async_session = AsyncCassandraSession(mock_session) - - # Mock execute to raise an error - with patch.object(async_session, "execute", side_effect=QueryError("Bad query")): - try: - await async_session.execute("SELECT * FROM bad_table") - except QueryError: - pass # Expected - - # Session should NOT be closed due to query error - mock_session.close.assert_not_called() - - @pytest.mark.asyncio - async def test_streaming_error_doesnt_close_session(self): - """ - Test that an error during streaming doesn't close the session. - - This test verifies that when a streaming operation fails, - it doesn't accidentally close the session that might be - used by other concurrent operations. - - What this tests: - --------------- - 1. Streaming errors isolated - 2. Session unaffected by stream errors - 3. Concurrent operations continue - 4. Error containment works - - Why this matters: - ---------------- - Streaming failures common: - - Network interruptions - - Large result timeouts - - Memory pressure - - Other queries must continue - despite streaming failures. - """ - mock_session = MagicMock() - mock_session.close = AsyncMock() - - # For this test, we just need to verify that streaming errors - # are isolated and don't affect the session. - # The actual streaming error handling is tested elsewhere. - - # Create a simple async function that raises an error - async def failing_operation(): - raise Exception("Streaming error") - - # Run the failing operation - with pytest.raises(Exception, match="Streaming error"): - await failing_operation() - - # Session should NOT be closed - mock_session.close.assert_not_called() - - @pytest.mark.asyncio - async def test_concurrent_session_usage_during_error(self): - """ - Test that other coroutines can still use the session when - one coroutine has an error. - - What this tests: - --------------- - 1. Concurrent queries independent - 2. One failure doesn't affect others - 3. Session thread-safe for errors - 4. 
Proper error isolation - - Why this matters: - ---------------- - Real apps have concurrent queries: - - API handling multiple requests - - Background jobs running - - Batch processing - - One bad query shouldn't break - all other operations. - """ - mock_session = MagicMock() - mock_session.close = AsyncMock() - - # Track execute calls - execute_count = 0 - execute_results = [] - - async def mock_execute(query, *args, **kwargs): - nonlocal execute_count - execute_count += 1 - - # First call fails, others succeed - if execute_count == 1: - raise QueryError("First query fails") - - # Return a mock result - result = MagicMock() - result.one = MagicMock(return_value={"id": execute_count}) - execute_results.append(result) - return result - - # Create session - async_session = AsyncCassandraSession(mock_session) - async_session.execute = mock_execute - - # Run concurrent queries - async def query_with_error(): - try: - await async_session.execute("SELECT * FROM table1") - except QueryError: - pass # Expected - - async def query_success(): - return await async_session.execute("SELECT * FROM table2") - - # Run queries concurrently - results = await asyncio.gather( - query_with_error(), query_success(), query_success(), return_exceptions=True - ) - - # First should be None (handled error), others should succeed - assert results[0] is None - assert results[1] is not None - assert results[2] is not None - - # Session should NOT be closed - mock_session.close.assert_not_called() - - # Should have made 3 execute calls - assert execute_count == 3 - - @pytest.mark.asyncio - async def test_session_usable_after_streaming_context_exit(self): - """ - Test that session remains usable after streaming context manager exits. - - What this tests: - --------------- - 1. Session works after streaming - 2. Stream cleanup doesn't break session - 3. Can execute new queries - 4. Resource isolation verified - - Why this matters: - ---------------- - Common pattern: - - Stream large results - - Process data - - Execute follow-up queries - - Session must remain fully - functional after streaming. - """ - mock_session = MagicMock() - mock_session.close = AsyncMock() - - # Create session - async_session = AsyncCassandraSession(mock_session) - - # Mock execute_stream - mock_future = MagicMock() - mock_future.has_more_pages = False - mock_future._final_exception = None - mock_future.add_callbacks = MagicMock() - mock_future.clear_callbacks = MagicMock() - - stream_result = AsyncStreamingResultSet(mock_future) - stream_result._handle_page(["row1", "row2"]) - - async def mock_execute_stream(*args, **kwargs): - return stream_result - - async_session.execute_stream = mock_execute_stream - - # Use streaming in context manager - async with await async_session.execute_stream("SELECT * FROM table") as stream: - rows = [] - async for row in stream: - rows.append(row) - - # Now try to use session again - should work - mock_result = MagicMock() - mock_result.one = MagicMock(return_value={"id": 1}) - - async def mock_execute(*args, **kwargs): - return mock_result - - async_session.execute = mock_execute - - # This should work fine - result = await async_session.execute("SELECT * FROM another_table") - assert result.one() == {"id": 1} - - # Session should still be open - mock_session.close.assert_not_called() - - @pytest.mark.asyncio - async def test_cluster_remains_open_after_session_context_exit(self): - """ - Test that cluster remains open after session context manager exits. - - What this tests: - --------------- - 1. 
Cluster survives session closure - 2. Can create new sessions - 3. Cluster lifecycle independent - 4. Multiple session support - - Why this matters: - ---------------- - Cluster is expensive resource: - - Connection pool - - Metadata management - - Load balancing state - - Must support many short-lived - sessions efficiently. - """ - mock_cluster = MagicMock() - mock_cluster.shutdown = MagicMock() # Not AsyncMock because it's called via run_in_executor - mock_cluster.connect = AsyncMock() - mock_cluster.protocol_version = 5 # Mock protocol version - - mock_session1 = MagicMock() - mock_session1.close = AsyncMock() - - mock_session2 = MagicMock() - mock_session2.close = AsyncMock() - - # First connect returns session1, second returns session2 - mock_cluster.connect.side_effect = [mock_session1, mock_session2] - - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster_class.return_value = mock_cluster - - # Mock AsyncCassandraSession.create - mock_async_session1 = MagicMock() - mock_async_session1._session = mock_session1 - mock_async_session1.close = AsyncMock() - mock_async_session1.__aenter__ = AsyncMock(return_value=mock_async_session1) - - async def async_exit1(*args): - await mock_async_session1.close() - - mock_async_session1.__aexit__ = AsyncMock(side_effect=async_exit1) - - mock_async_session2 = MagicMock() - mock_async_session2._session = mock_session2 - mock_async_session2.close = AsyncMock() - - with patch( - "async_cassandra.session.AsyncCassandraSession.create", new_callable=AsyncMock - ) as mock_create: - mock_create.side_effect = [mock_async_session1, mock_async_session2] - - cluster = AsyncCluster(["localhost"]) - - # Use first session in context manager - async with await cluster.connect(): - pass # Do some work - - # First session should be closed - mock_async_session1.close.assert_called_once() - - # But cluster should NOT be shut down - mock_cluster.shutdown.assert_not_called() - - # Should be able to create another session - session2 = await cluster.connect() - assert session2._session == mock_session2 - - # Clean up - await cluster.shutdown() - - @pytest.mark.asyncio - async def test_thread_safety_of_session_during_context_exit(self): - """ - Test that session can be used by other threads even when - one thread is exiting a context manager. - - What this tests: - --------------- - 1. Thread-safe context exit - 2. Concurrent usage allowed - 3. No race conditions - 4. Proper synchronization - - Why this matters: - ---------------- - Multi-threaded usage common: - - Web frameworks spawn threads - - Background workers - - Parallel processing - - Context managers must be - thread-safe during cleanup. 
- """ - mock_session = MagicMock() - mock_session.shutdown = MagicMock() # AsyncCassandraSession calls shutdown - - # Create thread-safe mock for execute - execute_lock = threading.Lock() - execute_calls = [] - - def mock_execute_sync(query): - with execute_lock: - execute_calls.append(query) - result = MagicMock() - result.one = MagicMock(return_value={"id": len(execute_calls)}) - return result - - mock_session.execute = mock_execute_sync - - # Create async session - async_session = AsyncCassandraSession(mock_session) - - # Track if session is being used - session_in_use = threading.Event() - other_thread_done = threading.Event() - - # Function for other thread - def other_thread_work(): - session_in_use.wait() # Wait for signal - - # Try to use session from another thread - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - async def do_query(): - # Wrap sync call in executor - result = await asyncio.get_event_loop().run_in_executor( - None, mock_session.execute, "SELECT FROM other_thread" - ) - return result - - loop.run_until_complete(do_query()) - loop.close() - - other_thread_done.set() - - # Start other thread - thread = threading.Thread(target=other_thread_work) - thread.start() - - # Use session in context manager - async with async_session: - # Signal other thread that session is in use - session_in_use.set() - - # Do some work - await asyncio.get_event_loop().run_in_executor( - None, mock_session.execute, "SELECT FROM main_thread" - ) - - # Wait a bit for other thread to also use session - await asyncio.sleep(0.1) - - # Wait for other thread - other_thread_done.wait(timeout=2.0) - thread.join() - - # Both threads should have executed queries - assert len(execute_calls) == 2 - assert "SELECT FROM main_thread" in execute_calls - assert "SELECT FROM other_thread" in execute_calls - - # Session should be shut down only once - mock_session.shutdown.assert_called_once() - - @pytest.mark.asyncio - async def test_streaming_context_manager_implementation(self): - """ - Test that streaming result properly implements context manager protocol. - - What this tests: - --------------- - 1. __aenter__ returns self - 2. __aexit__ calls close - 3. Cleanup always happens - 4. Protocol correctly implemented - - Why this matters: - ---------------- - Context manager protocol ensures: - - Resources always cleaned - - Even with exceptions - - Pythonic usage pattern - - Users expect async with to - work correctly. 
- """ - # Mock response future - mock_future = MagicMock() - mock_future.has_more_pages = False - mock_future._final_exception = None - mock_future.add_callbacks = MagicMock() - mock_future.clear_callbacks = MagicMock() - - # Create streaming result - stream_result = AsyncStreamingResultSet(mock_future) - stream_result._handle_page(["row1", "row2"]) - - # Test __aenter__ returns self - entered = await stream_result.__aenter__() - assert entered is stream_result - - # Test __aexit__ calls close - close_called = False - original_close = stream_result.close - - async def mock_close(): - nonlocal close_called - close_called = True - await original_close() - - stream_result.close = mock_close - - # Call __aexit__ with no exception - result = await stream_result.__aexit__(None, None, None) - assert result is None # Should not suppress exceptions - assert close_called - - # Verify cleanup happened - mock_future.clear_callbacks.assert_called() - - @pytest.mark.asyncio - async def test_context_manager_with_exception_propagation(self): - """ - Test that exceptions are properly propagated through context managers. - - What this tests: - --------------- - 1. Exceptions propagate correctly - 2. Cleanup still happens - 3. __aexit__ doesn't suppress - 4. Error handling correct - - Why this matters: - ---------------- - Exception handling critical: - - Errors must bubble up - - Resources still cleaned - - No silent failures - - Context managers must not - hide exceptions. - """ - mock_future = MagicMock() - mock_future.has_more_pages = False - mock_future._final_exception = None - mock_future.add_callbacks = MagicMock() - mock_future.clear_callbacks = MagicMock() - - stream_result = AsyncStreamingResultSet(mock_future) - stream_result._handle_page(["row1"]) - - # Test that exceptions are propagated - exception_caught = None - close_called = False - - async def track_close(): - nonlocal close_called - close_called = True - - stream_result.close = track_close - - try: - async with stream_result: - raise ValueError("Test exception") - except ValueError as e: - exception_caught = e - - # Exception should be propagated - assert exception_caught is not None - assert str(exception_caught) == "Test exception" - - # But close should still have been called - assert close_called - - @pytest.mark.asyncio - async def test_nested_context_managers_close_correctly(self): - """ - Test that nested context managers only close their own resources. - - What this tests: - --------------- - 1. Nested contexts independent - 2. Inner closes before outer - 3. Each manages own resources - 4. Proper cleanup order - - Why this matters: - ---------------- - Common nesting pattern: - - Cluster context - - Session context inside - - Stream context inside that - - Each level must clean up - only its own resources. 
- """ - mock_cluster = MagicMock() - mock_cluster.shutdown = MagicMock() # Not AsyncMock because it's called via run_in_executor - mock_cluster.connect = AsyncMock() - mock_cluster.protocol_version = 5 # Mock protocol version - - mock_session = MagicMock() - mock_session.close = AsyncMock() - mock_cluster.connect.return_value = mock_session - - # Mock for streaming - mock_future = MagicMock() - mock_future.has_more_pages = False - mock_future._final_exception = None - mock_future.add_callbacks = MagicMock() - mock_future.clear_callbacks = MagicMock() - - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster_class.return_value = mock_cluster - - # Mock AsyncCassandraSession.create - mock_async_session = MagicMock() - mock_async_session._session = mock_session - mock_async_session.close = AsyncMock() - mock_async_session.shutdown = AsyncMock() # For when __aexit__ calls close() - mock_async_session.__aenter__ = AsyncMock(return_value=mock_async_session) - - async def async_exit_shutdown(*args): - await mock_async_session.shutdown() - - mock_async_session.__aexit__ = AsyncMock(side_effect=async_exit_shutdown) - - with patch( - "async_cassandra.session.AsyncCassandraSession.create", new_callable=AsyncMock - ) as mock_create: - mock_create.return_value = mock_async_session - - # Nested context managers - async with AsyncCluster(["localhost"]) as cluster: - async with await cluster.connect(): - # Create streaming result - stream_result = AsyncStreamingResultSet(mock_future) - stream_result._handle_page(["row1"]) - - async with stream_result as stream: - async for row in stream: - pass - - # After stream context, only stream should be cleaned - mock_future.clear_callbacks.assert_called() - mock_async_session.shutdown.assert_not_called() - mock_cluster.shutdown.assert_not_called() - - # After session context, session should be closed - mock_async_session.shutdown.assert_called_once() - mock_cluster.shutdown.assert_not_called() - - # After cluster context, cluster should be shut down - mock_cluster.shutdown.assert_called_once() - - @pytest.mark.asyncio - async def test_cluster_and_session_context_managers_are_independent(self): - """ - Test that cluster and session context managers don't interfere. - - What this tests: - --------------- - 1. Context managers fully independent - 2. Can use in any order - 3. No hidden dependencies - 4. Flexible usage patterns - - Why this matters: - ---------------- - Users need flexibility: - - Long-lived clusters - - Short-lived sessions - - Various usage patterns - - Context managers must support - all reasonable usage patterns. 
- """ - mock_cluster = MagicMock() - mock_cluster.shutdown = MagicMock() # Not AsyncMock because it's called via run_in_executor - mock_cluster.connect = AsyncMock() - mock_cluster.is_closed = False - mock_cluster.protocol_version = 5 # Mock protocol version - - mock_session = MagicMock() - mock_session.close = AsyncMock() - mock_session.is_closed = False - mock_cluster.connect.return_value = mock_session - - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - mock_cluster_class.return_value = mock_cluster - - # Mock AsyncCassandraSession.create - mock_async_session1 = MagicMock() - mock_async_session1._session = mock_session - mock_async_session1.close = AsyncMock() - mock_async_session1.__aenter__ = AsyncMock(return_value=mock_async_session1) - - async def async_exit1(*args): - await mock_async_session1.close() - - mock_async_session1.__aexit__ = AsyncMock(side_effect=async_exit1) - - mock_async_session2 = MagicMock() - mock_async_session2._session = mock_session - mock_async_session2.close = AsyncMock() - - mock_async_session3 = MagicMock() - mock_async_session3._session = mock_session - mock_async_session3.close = AsyncMock() - mock_async_session3.__aenter__ = AsyncMock(return_value=mock_async_session3) - - async def async_exit3(*args): - await mock_async_session3.close() - - mock_async_session3.__aexit__ = AsyncMock(side_effect=async_exit3) - - with patch( - "async_cassandra.session.AsyncCassandraSession.create", new_callable=AsyncMock - ) as mock_create: - mock_create.side_effect = [ - mock_async_session1, - mock_async_session2, - mock_async_session3, - ] - - # Create cluster (not in context manager) - cluster = AsyncCluster(["localhost"]) - - # Use session in context manager - async with await cluster.connect(): - # Do work - pass - - # Session closed, but cluster still open - mock_async_session1.close.assert_called_once() - mock_cluster.shutdown.assert_not_called() - - # Can create another session - session2 = await cluster.connect() - assert session2 is not None - - # Now use cluster in context manager - async with cluster: - # Create and use another session - async with await cluster.connect(): - pass - - # Now cluster should be shut down - mock_cluster.shutdown.assert_called_once() diff --git a/tests/unit/test_coverage_summary.py b/tests/unit/test_coverage_summary.py deleted file mode 100644 index 86c4528..0000000 --- a/tests/unit/test_coverage_summary.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -Test Coverage Summary and Guide - -This module documents the comprehensive unit test coverage added to address gaps -in testing failure scenarios and edge cases for the async-cassandra wrapper. - -NEW TEST COVERAGE AREAS: -======================= - -1. TOPOLOGY CHANGES (test_topology_changes.py) - - Host up/down events without blocking event loop - - Add/remove host callbacks - - Rapid topology changes - - Concurrent topology events - - Host state changes during queries - - Listener registration/unregistration - -2. PREPARED STATEMENT INVALIDATION (test_prepared_statement_invalidation.py) - - Automatic re-preparation after schema changes - - Concurrent invalidation handling - - Batch execution with invalidated statements - - Re-preparation failures - - Cache invalidation - - Statement ID tracking - -3. AUTHENTICATION/AUTHORIZATION (test_auth_failures.py) - - Initial connection auth failures - - Auth failures during operations - - Credential rotation scenarios - - Different permission failures (SELECT, INSERT, CREATE, etc.) 
- - Session invalidation on auth changes - - Keyspace-level authorization - -4. CONNECTION POOL EXHAUSTION (test_connection_pool_exhaustion.py) - - Pool exhaustion under load - - Connection borrowing timeouts - - Pool recovery after exhaustion - - Connection health checks - - Pool size limits (min/max) - - Connection leak detection - - Graceful degradation - -5. BACKPRESSURE HANDLING (test_backpressure_handling.py) - - Client request queue overflow - - Server overload responses - - Backpressure propagation - - Adaptive concurrency control - - Queue timeout handling - - Priority queue management - - Circuit breaker pattern - - Load shedding strategies - -6. SCHEMA CHANGES (test_schema_changes.py) - - Schema change event listeners - - Metadata refresh on changes - - Concurrent schema changes - - Schema agreement waiting - - Schema disagreement handling - - Keyspace/table metadata tracking - - DDL operation coordination - -7. NETWORK FAILURES (test_network_failures.py) - - Partial network failures - - Connection timeouts vs request timeouts - - Slow network simulation - - Coordinator failures mid-query - - Asymmetric network partitions - - Network flapping - - Connection pool recovery - - Host distance changes - - Exponential backoff - -8. PROTOCOL EDGE CASES (test_protocol_edge_cases.py) - - Protocol version negotiation failures - - Compression issues - - Custom payload handling - - Frame size limits - - Unsupported message types - - Protocol error recovery - - Beta features handling - - Protocol flags (tracing, warnings) - - Stream ID exhaustion - -TESTING PHILOSOPHY: -================== - -These tests focus on the WRAPPER'S behavior, not the driver's: -- How events/callbacks are handled without blocking the event loop -- How errors are propagated through the async layer -- How resources are cleaned up in async context -- How the wrapper maintains compatibility while adding async support - -FUTURE TESTING CONSIDERATIONS: -============================= - -1. Integration Tests Still Needed For: - - Multi-node cluster scenarios - - Real network partitions - - Actual schema changes with running queries - - True coordinator failures - - Cross-datacenter scenarios - -2. Performance Tests Could Cover: - - Overhead of async wrapper - - Thread pool efficiency - - Memory usage under load - - Latency impact - -3. Stress Tests Could Verify: - - Behavior under extreme load - - Resource cleanup under pressure - - Memory leak prevention - - Thread safety guarantees - -USAGE: -====== - -Run all new gap coverage tests: - pytest tests/unit/test_topology_changes.py \ - tests/unit/test_prepared_statement_invalidation.py \ - tests/unit/test_auth_failures.py \ - tests/unit/test_connection_pool_exhaustion.py \ - tests/unit/test_backpressure_handling.py \ - tests/unit/test_schema_changes.py \ - tests/unit/test_network_failures.py \ - tests/unit/test_protocol_edge_cases.py -v - -Run specific scenario: - pytest tests/unit/test_topology_changes.py::TestTopologyChanges::test_host_up_event_nonblocking -v - -MAINTENANCE: -============ - -When adding new features to the wrapper, consider: -1. Does it handle driver callbacks? → Add to topology/schema tests -2. Does it deal with errors? → Add to appropriate failure test file -3. Does it manage resources? → Add to pool/backpressure tests -4. Does it interact with protocol? → Add to protocol edge cases - -""" - - -class TestCoverageSummary: - """ - This test class serves as documentation and verification that all - gap coverage test files exist and are importable. 
- """ - - def test_all_gap_coverage_modules_exist(self): - """ - Verify all gap coverage test modules can be imported. - - What this tests: - --------------- - 1. All test modules listed - 2. Naming convention followed - 3. Module paths correct - 4. Coverage areas complete - - Why this matters: - ---------------- - Documentation accuracy: - - Tests match documentation - - No missing test files - - Clear test organization - - Helps developers find - the right test file. - """ - test_modules = [ - "tests.unit.test_topology_changes", - "tests.unit.test_prepared_statement_invalidation", - "tests.unit.test_auth_failures", - "tests.unit.test_connection_pool_exhaustion", - "tests.unit.test_backpressure_handling", - "tests.unit.test_schema_changes", - "tests.unit.test_network_failures", - "tests.unit.test_protocol_edge_cases", - ] - - # Just verify we can reference the module names - # Actual imports would happen when running the tests - for module in test_modules: - assert isinstance(module, str) - assert module.startswith("tests.unit.test_") - - def test_coverage_areas_documented(self): - """ - Verify this summary documents all coverage areas. - - What this tests: - --------------- - 1. All areas in docstring - 2. Documentation complete - 3. No missing sections - 4. Self-documenting test - - Why this matters: - ---------------- - Complete documentation: - - Guides new developers - - Shows test coverage - - Prevents blind spots - - Living documentation stays - accurate with codebase. - """ - coverage_areas = [ - "TOPOLOGY CHANGES", - "PREPARED STATEMENT INVALIDATION", - "AUTHENTICATION/AUTHORIZATION", - "CONNECTION POOL EXHAUSTION", - "BACKPRESSURE HANDLING", - "SCHEMA CHANGES", - "NETWORK FAILURES", - "PROTOCOL EDGE CASES", - ] - - # Read this file's docstring - module_doc = __doc__ - - for area in coverage_areas: - assert area in module_doc, f"Coverage area '{area}' not documented" - - def test_no_regression_in_existing_tests(self): - """ - Reminder: These new tests supplement, not replace existing tests. - - Existing test coverage that should remain: - - Basic async operations (test_session.py) - - Retry policies (test_retry_policies.py) - - Error handling (test_error_handling.py) - - Streaming (test_streaming.py) - - Connection management (test_connection.py) - - Cluster operations (test_cluster.py) - - What this tests: - --------------- - 1. Documentation reminder - 2. Test suite completeness - 3. No test deletion - 4. Coverage preservation - - Why this matters: - ---------------- - Test regression prevention: - - Keep existing coverage - - Build on foundation - - No coverage gaps - - New tests augment, not - replace existing tests. - """ - # This is a documentation test - no actual assertions - # Just ensures we remember to keep existing tests - pass diff --git a/tests/unit/test_critical_issues.py b/tests/unit/test_critical_issues.py deleted file mode 100644 index 36ab9a5..0000000 --- a/tests/unit/test_critical_issues.py +++ /dev/null @@ -1,600 +0,0 @@ -""" -Unit tests for critical issues identified in the technical review. - -These tests use mocking to isolate and test specific problematic code paths. - -Test Organization: -================== -1. Thread Safety Issues - Race conditions in AsyncResultHandler -2. Memory Leaks - Reference cycles and page accumulation in streaming -3. 
Error Consistency - Inconsistent error handling between methods - -Key Testing Principles: -====================== -- Expose race conditions through concurrent access -- Track object lifecycle with weakrefs -- Verify error handling consistency -- Test edge cases that trigger bugs - -Note: Some of these tests may fail, demonstrating the issues they test. -""" - -import asyncio -import gc -import threading -import weakref -from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock - -import pytest - -from async_cassandra.result import AsyncResultHandler -from async_cassandra.streaming import AsyncStreamingResultSet, StreamConfig - - -class TestAsyncResultHandlerThreadSafety: - """Unit tests for thread safety issues in AsyncResultHandler.""" - - def test_race_condition_in_handle_page(self): - """ - Test race condition in _handle_page method. - - What this tests: - --------------- - 1. Concurrent _handle_page calls from driver threads - 2. Data corruption from unsynchronized row appending - 3. Missing or duplicated rows - 4. Thread safety of shared state - - Why this matters: - ---------------- - The Cassandra driver calls callbacks from multiple threads. - Without proper synchronization, concurrent callbacks can: - - Corrupt the rows list - - Lose data - - Cause index errors - - This test may fail, demonstrating the critical issue - that needs fixing with proper locking. - """ - # Create handler with mock future - mock_future = Mock() - mock_future.has_more_pages = True - handler = AsyncResultHandler(mock_future) - - # Track all rows added - all_rows = [] - errors = [] - - def concurrent_callback(thread_id, page_num): - try: - # Simulate driver callback with unique data - rows = [f"thread_{thread_id}_page_{page_num}_row_{i}" for i in range(10)] - handler._handle_page(rows) - all_rows.extend(rows) - except Exception as e: - errors.append(f"Thread {thread_id}: {e}") - - # Simulate concurrent callbacks from driver threads - with ThreadPoolExecutor(max_workers=10) as executor: - futures = [] - for thread_id in range(10): - for page_num in range(5): - future = executor.submit(concurrent_callback, thread_id, page_num) - futures.append(future) - - # Wait for all callbacks - for future in futures: - future.result() - - # Check for data corruption - assert len(errors) == 0, f"Thread safety errors: {errors}" - - # All rows should be present - expected_count = 10 * 5 * 10 # threads * pages * rows_per_page - assert len(all_rows) == expected_count - - # Check handler.rows for corruption - # Current implementation may have race conditions here - # This test may fail, demonstrating the issue - - def test_event_loop_thread_safety(self): - """ - Test event loop thread safety in callbacks. - - What this tests: - --------------- - 1. Callbacks run in driver threads (not event loop) - 2. Future results set from wrong thread - 3. call_soon_threadsafe usage - 4. Cross-thread future completion - - Why this matters: - ---------------- - asyncio futures must be completed from the event loop - thread. Driver callbacks run in executor threads, so: - - Direct future.set_result() is unsafe - - Must use call_soon_threadsafe() - - Otherwise: "Future attached to different loop" errors - - This ensures the async wrapper properly bridges - thread boundaries for asyncio safety. 
- """ - - async def run_test(): - loop = asyncio.get_running_loop() - - # Track which thread sets the future result - result_thread = None - - # Patch to monitor thread safety - original_call_soon_threadsafe = loop.call_soon_threadsafe - call_soon_threadsafe_used = False - - def monitored_call_soon_threadsafe(callback, *args): - nonlocal call_soon_threadsafe_used - call_soon_threadsafe_used = True - return original_call_soon_threadsafe(callback, *args) - - loop.call_soon_threadsafe = monitored_call_soon_threadsafe - - try: - mock_future = Mock() - mock_future.has_more_pages = True # Start with more pages expected - mock_future.add_callbacks = Mock() - mock_future.timeout = None - mock_future.start_fetching_next_page = Mock() - - handler = AsyncResultHandler(mock_future) - - # Start get_result to create the future - result_task = asyncio.create_task(handler.get_result()) - await asyncio.sleep(0.1) # Make sure it's fully initialized - - # Simulate callback from driver thread - def driver_callback(): - nonlocal result_thread - result_thread = threading.current_thread() - # First callback with more pages - handler._handle_page([1, 2, 3]) - # Now final callback - set has_more_pages to False before calling - mock_future.has_more_pages = False - handler._handle_page([4, 5, 6]) - - driver_thread = threading.Thread(target=driver_callback) - driver_thread.start() - driver_thread.join() - - # Give time for async operations - await asyncio.sleep(0.1) - - # Verify thread safety was maintained - assert result_thread != threading.current_thread() - # Now call_soon_threadsafe SHOULD be used since we store the loop - assert call_soon_threadsafe_used - - # The result task should be completed - assert result_task.done() - result = await result_task - assert len(result.rows) == 6 # We added [1,2,3] then [4,5,6] - - finally: - loop.call_soon_threadsafe = original_call_soon_threadsafe - - asyncio.run(run_test()) - - def test_state_synchronization_issues(self): - """ - Test state synchronization between threads. - - What this tests: - --------------- - 1. Unsynchronized access to handler.rows - 2. Non-atomic operations on shared state - 3. Lost updates from concurrent modifications - 4. Data consistency under concurrent access - - Why this matters: - ---------------- - Multiple driver threads might modify handler state: - - rows.append() is not thread-safe - - len() followed by append() is not atomic - - Can lose rows or corrupt list structure - - This demonstrates why locks are needed around - all shared state modifications. 
- """ - mock_future = Mock() - mock_future.has_more_pages = True - handler = AsyncResultHandler(mock_future) - - # Simulate rapid state changes from multiple threads - state_changes = [] - - def modify_state(thread_id): - for i in range(100): - # These operations are not atomic without proper locking - current_rows = len(handler.rows) - state_changes.append((thread_id, i, current_rows)) - handler.rows.append(f"thread_{thread_id}_item_{i}") - - threads = [] - for thread_id in range(5): - thread = threading.Thread(target=modify_state, args=(thread_id,)) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() - - # Check for consistency - expected_total = 5 * 100 # threads * iterations - actual_total = len(handler.rows) - - # This might fail due to race conditions - assert ( - actual_total == expected_total - ), f"Race condition detected: expected {expected_total}, got {actual_total}" - - -class TestStreamingMemoryLeaks: - """Unit tests for memory leaks in streaming functionality.""" - - def test_page_reference_cleanup(self): - """ - Test page reference cleanup in streaming. - - What this tests: - --------------- - 1. Pages are not accumulated in memory - 2. Only current page is retained - 3. Old pages become garbage collectible - 4. Memory usage is bounded - - Why this matters: - ---------------- - Streaming is designed for large result sets. - If pages accumulate: - - Memory usage grows unbounded - - Defeats purpose of streaming - - Can cause OOM with large results - - This verifies the streaming implementation - properly releases old pages. - """ - # Track pages created - pages_created = [] - - mock_future = Mock() - mock_future.has_more_pages = True - mock_future._final_exception = None # Important: must be None - - page_count = 0 - handler = None # Define handler first - callbacks = {} - - def add_callbacks(callback=None, errback=None): - callbacks["callback"] = callback - callbacks["errback"] = errback - # Simulate initial page callback from a thread - if callback: - import threading - - def thread_callback(): - first_page = [f"row_0_{i}" for i in range(100)] - pages_created.append(first_page) - callback(first_page) - - thread = threading.Thread(target=thread_callback) - thread.start() - - def mock_fetch_next(): - nonlocal page_count - page_count += 1 - - if page_count <= 5: - # Create a page - page = [f"row_{page_count}_{i}" for i in range(100)] - pages_created.append(page) - - # Simulate callback from thread - if callbacks.get("callback"): - import threading - - def thread_callback(): - callbacks["callback"](page) - - thread = threading.Thread(target=thread_callback) - thread.start() - mock_future.has_more_pages = page_count < 5 - else: - if callbacks.get("callback"): - import threading - - def thread_callback(): - callbacks["callback"]([]) - - thread = threading.Thread(target=thread_callback) - thread.start() - mock_future.has_more_pages = False - - mock_future.start_fetching_next_page = mock_fetch_next - mock_future.add_callbacks = add_callbacks - - handler = AsyncStreamingResultSet(mock_future) - - async def consume_all(): - consumed = 0 - async for row in handler: - consumed += 1 - return consumed - - # Consume all rows - total_consumed = asyncio.run(consume_all()) - assert total_consumed == 600 # 6 pages * 100 rows (including first page) - - # Check that handler only holds one page at a time - assert len(handler._current_page) <= 100, "Handler should only hold one page" - - # Verify pages were replaced, not accumulated - assert len(pages_created) == 6 
# 1 initial page + 5 pages from mock_fetch_next - - def test_callback_reference_cycles(self): - """ - Test for callback reference cycles. - - What this tests: - --------------- - 1. Callbacks don't create reference cycles - 2. Handler -> Future -> Callback -> Handler cycles - 3. Objects are garbage collected after use - 4. No memory leaks from circular references - - Why this matters: - ---------------- - Callbacks often reference the handler: - - Handler registers callbacks on future - - Future stores reference to callbacks - - Callbacks reference handler methods - - Creates circular reference - - Without breaking cycles, these objects - leak memory even after streaming completes. - """ - # Track object lifecycle - handler_refs = [] - future_refs = [] - - class TrackedFuture: - def __init__(self): - future_refs.append(weakref.ref(self)) - self.callbacks = [] - self.has_more_pages = False - - def add_callbacks(self, callback, errback): - # This creates a reference from future to handler - self.callbacks.append((callback, errback)) - - def start_fetching_next_page(self): - pass - - class TrackedHandler(AsyncStreamingResultSet): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - handler_refs.append(weakref.ref(self)) - - # Create objects with potential cycle - future = TrackedFuture() - handler = TrackedHandler(future) - - # Use the handler - async def use_handler(h): - h._handle_page([1, 2, 3]) - h._exhausted = True - - try: - async for _ in h: - pass - except StopAsyncIteration: - pass - - asyncio.run(use_handler(handler)) - - # Clear explicit references - del future - del handler - - # Force garbage collection - gc.collect() - - # Check for leaks - alive_handlers = sum(1 for ref in handler_refs if ref() is not None) - alive_futures = sum(1 for ref in future_refs if ref() is not None) - - assert alive_handlers == 0, f"Handler leak: {alive_handlers} still alive" - assert alive_futures == 0, f"Future leak: {alive_futures} still alive" - - def test_streaming_config_lifecycle(self): - """ - Test streaming config and callback cleanup. - - What this tests: - --------------- - 1. StreamConfig doesn't leak memory - 2. Page callbacks are properly released - 3. Callback data is garbage collected - 4. No references retained after completion - - Why this matters: - ---------------- - Page callbacks might reference large objects: - - Progress tracking data structures - - Metric collectors - - UI update handlers - - These must be released when streaming ends - to avoid memory leaks in long-running apps. 
- """ - callback_refs = [] - - class CallbackData: - """Object that can be weakly referenced""" - - def __init__(self, page_num, row_count): - self.page = page_num - self.rows = row_count - - def progress_callback(page_num, row_count): - # Simulate some object that could be leaked - data = CallbackData(page_num, row_count) - callback_refs.append(weakref.ref(data)) - - config = StreamConfig(fetch_size=10, max_pages=5, page_callback=progress_callback) - - # Create a simpler test that doesn't require async iteration - mock_future = Mock() - mock_future.has_more_pages = False - mock_future.add_callbacks = Mock() - - handler = AsyncStreamingResultSet(mock_future, config) - - # Simulate page callbacks directly - handler._handle_page([f"row_{i}" for i in range(10)]) - handler._handle_page([f"row_{i}" for i in range(10, 20)]) - handler._handle_page([f"row_{i}" for i in range(20, 30)]) - - # Verify callbacks were called - assert len(callback_refs) == 3 # 3 pages - - # Clear references - del handler - del config - del progress_callback - gc.collect() - - # Check for leaked callback data - alive_callbacks = sum(1 for ref in callback_refs if ref() is not None) - assert alive_callbacks == 0, f"Callback data leak: {alive_callbacks} still alive" - - -class TestErrorHandlingConsistency: - """Unit tests for error handling consistency.""" - - @pytest.mark.asyncio - async def test_execute_vs_execute_stream_error_wrapping(self): - """ - Test error handling consistency between methods. - - What this tests: - --------------- - 1. execute() and execute_stream() handle errors the same - 2. No extra wrapping in QueryError - 3. Original error types preserved - 4. Error messages unchanged - - Why this matters: - ---------------- - Applications need consistent error handling: - - Same error type for same problem - - Can use same except clauses - - Error handling code is reusable - - Inconsistent wrapping makes error handling - complex and error-prone. - """ - from cassandra import InvalidRequest - - # Test InvalidRequest handling - base_error = InvalidRequest("Test error") - - # Test execute() error handling with AsyncResultHandler - execute_error = None - mock_future = Mock() - mock_future.add_callbacks = Mock() - mock_future.has_more_pages = False - mock_future.timeout = None # Add timeout attribute - - handler = AsyncResultHandler(mock_future) - # Simulate error callback being called after init - handler._handle_error(base_error) - try: - await handler.get_result() - except Exception as e: - execute_error = e - - # Test execute_stream() error handling with AsyncStreamingResultSet - # We need to test error handling without async iteration to avoid complexity - stream_mock_future = Mock() - stream_mock_future.add_callbacks = Mock() - stream_mock_future.has_more_pages = False - - # Get the error that would be raised - stream_handler = AsyncStreamingResultSet(stream_mock_future) - stream_handler._handle_error(base_error) - stream_error = stream_handler._error - - # Both should have the same error type - assert execute_error is not None - assert stream_error is not None - assert type(execute_error) is type( - stream_error - ), f"Different error types: {type(execute_error)} vs {type(stream_error)}" - assert isinstance(execute_error, InvalidRequest) - assert isinstance(stream_error, InvalidRequest) - - def test_timeout_error_consistency(self): - """ - Test timeout error handling consistency. - - What this tests: - --------------- - 1. Timeout errors preserved across contexts - 2. OperationTimedOut not wrapped - 3. 
Error details maintained - 4. Same handling in all code paths - - Why this matters: - ---------------- - Timeouts need special handling: - - May indicate overload - - Might need backoff/retry - - Critical for monitoring - - Consistent timeout errors enable proper - timeout handling strategies. - """ - from cassandra import OperationTimedOut - - timeout_error = OperationTimedOut("Test timeout") - - # Test in AsyncResultHandler - result_error = None - - async def get_result_error(): - nonlocal result_error - mock_future = Mock() - mock_future.add_callbacks = Mock() - mock_future.has_more_pages = False - mock_future.timeout = None # Add timeout attribute - result_handler = AsyncResultHandler(mock_future) - # Simulate error callback being called after init - result_handler._handle_error(timeout_error) - try: - await result_handler.get_result() - except Exception as e: - result_error = e - - asyncio.run(get_result_error()) - - # Test in AsyncStreamingResultSet - stream_mock_future = Mock() - stream_mock_future.add_callbacks = Mock() - stream_mock_future.has_more_pages = False - stream_handler = AsyncStreamingResultSet(stream_mock_future) - stream_handler._handle_error(timeout_error) - stream_error = stream_handler._error - - # Both should preserve the timeout error - assert isinstance(result_error, OperationTimedOut) - assert isinstance(stream_error, OperationTimedOut) - assert str(result_error) == str(stream_error) diff --git a/tests/unit/test_error_recovery.py b/tests/unit/test_error_recovery.py deleted file mode 100644 index b559b48..0000000 --- a/tests/unit/test_error_recovery.py +++ /dev/null @@ -1,534 +0,0 @@ -"""Error recovery and handling tests. - -This module tests various error scenarios including NoHostAvailable, -connection errors, and proper error propagation through the async layer. - -Test Organization: -================== -1. Connection Errors - NoHostAvailable, pool exhaustion -2. Query Errors - InvalidRequest, Unavailable -3. Callback Errors - Errors in async callbacks -4. Shutdown Scenarios - Graceful shutdown with pending queries -5. Error Isolation - Concurrent query error isolation - -Key Testing Principles: -====================== -- Errors must propagate with full context -- Stack traces must be preserved -- Concurrent errors must be isolated -- Graceful degradation under failure -- Recovery after transient failures -""" - -import asyncio -from unittest.mock import Mock - -import pytest -from cassandra import ConsistencyLevel, InvalidRequest, Unavailable -from cassandra.cluster import NoHostAvailable - -from async_cassandra import AsyncCassandraSession as AsyncSession -from async_cassandra import AsyncCluster - - -def create_mock_response_future(rows=None, has_more_pages=False): - """ - Helper to create a properly configured mock ResponseFuture. - - This helper ensures mock ResponseFutures behave like real ones, - with proper callback handling and attribute setup. - """ - mock_future = Mock() - mock_future.has_more_pages = has_more_pages - mock_future.timeout = None # Avoid comparison issues - mock_future.add_callbacks = Mock() - - def handle_callbacks(callback=None, errback=None): - if callback: - callback(rows if rows is not None else []) - - mock_future.add_callbacks.side_effect = handle_callbacks - return mock_future - - -class TestErrorRecovery: - """Test error recovery and handling scenarios.""" - - @pytest.mark.resilience - @pytest.mark.quick - @pytest.mark.critical - async def test_no_host_available_error(self): - """ - Test handling of NoHostAvailable errors. 
- - What this tests: - --------------- - 1. NoHostAvailable errors propagate correctly - 2. Error details include all failed hosts - 3. Connection errors for each host preserved - 4. Error message is informative - - Why this matters: - ---------------- - NoHostAvailable is a critical error indicating: - - All nodes are down or unreachable - - Network partition or configuration issues - - Need for manual intervention - - Applications need full error details to diagnose - and alert on infrastructure problems. - """ - errors = { - "127.0.0.1": ConnectionRefusedError("Connection refused"), - "127.0.0.2": TimeoutError("Connection timeout"), - } - - # Create a real async session with mocked underlying session - mock_session = Mock() - mock_session.execute_async.side_effect = NoHostAvailable( - "Unable to connect to any servers", errors - ) - - async_session = AsyncSession(mock_session) - - with pytest.raises(NoHostAvailable) as exc_info: - await async_session.execute("SELECT * FROM users") - - assert "Unable to connect to any servers" in str(exc_info.value) - assert "127.0.0.1" in exc_info.value.errors - assert "127.0.0.2" in exc_info.value.errors - - @pytest.mark.resilience - async def test_invalid_request_error(self): - """ - Test handling of invalid request errors. - - What this tests: - --------------- - 1. InvalidRequest errors propagate cleanly - 2. Error message preserved exactly - 3. No wrapping or modification - 4. Useful for debugging CQL issues - - Why this matters: - ---------------- - InvalidRequest indicates: - - Syntax errors in CQL - - Schema mismatches - - Invalid parameters - - Developers need the exact error message from - Cassandra to fix their queries. - """ - mock_session = Mock() - mock_session.execute_async.side_effect = InvalidRequest("Invalid CQL syntax") - - async_session = AsyncSession(mock_session) - - with pytest.raises(InvalidRequest, match="Invalid CQL syntax"): - await async_session.execute("INVALID QUERY SYNTAX") - - @pytest.mark.resilience - async def test_unavailable_error(self): - """ - Test handling of unavailable errors. - - What this tests: - --------------- - 1. Unavailable errors include consistency details - 2. Required vs available replicas reported - 3. Consistency level preserved - 4. All error attributes accessible - - Why this matters: - ---------------- - Unavailable errors help diagnose: - - Insufficient replicas for consistency - - Node failures affecting availability - - Need to adjust consistency levels - - Applications can use this info to: - - Retry with lower consistency - - Alert on degraded availability - - Make informed consistency trade-offs - """ - mock_session = Mock() - mock_session.execute_async.side_effect = Unavailable( - "Cannot achieve consistency", - consistency=ConsistencyLevel.QUORUM, - required_replicas=2, - alive_replicas=1, - ) - - async_session = AsyncSession(mock_session) - - with pytest.raises(Unavailable) as exc_info: - await async_session.execute("SELECT * FROM users") - - assert exc_info.value.consistency == ConsistencyLevel.QUORUM - assert exc_info.value.required_replicas == 2 - assert exc_info.value.alive_replicas == 1 - - @pytest.mark.resilience - @pytest.mark.critical - async def test_error_in_async_callback(self): - """ - Test error handling in async callbacks. - - What this tests: - --------------- - 1. Errors in callbacks are captured - 2. AsyncResultHandler propagates callback errors - 3. Original error type and message preserved - 4. 
Async layer doesn't swallow errors - - Why this matters: - ---------------- - The async wrapper uses callbacks to bridge - sync driver to async/await. Errors in this - bridge must not be lost or corrupted. - - This ensures reliability of error reporting - through the entire async pipeline. - """ - from async_cassandra.result import AsyncResultHandler - - # Create a mock ResponseFuture - mock_future = Mock() - mock_future.has_more_pages = False - mock_future.add_callbacks = Mock() - mock_future.timeout = None # Set timeout to None to avoid comparison issues - - handler = AsyncResultHandler(mock_future) - test_error = RuntimeError("Callback error") - - # Manually call the error handler to simulate callback error - handler._handle_error(test_error) - - with pytest.raises(RuntimeError, match="Callback error"): - await handler.get_result() - - @pytest.mark.resilience - async def test_connection_pool_exhaustion_recovery(self): - """ - Test recovery from connection pool exhaustion. - - What this tests: - --------------- - 1. Pool exhaustion errors are transient - 2. Retry after exhaustion can succeed - 3. No permanent failure from temporary exhaustion - 4. Application can recover automatically - - Why this matters: - ---------------- - Connection pools can be temporarily exhausted during: - - Traffic spikes - - Slow queries holding connections - - Network delays - - Applications should be able to recover when - connections become available again, without - manual intervention or restart. - """ - mock_session = Mock() - - # Create a mock ResponseFuture for successful response - mock_future = create_mock_response_future([{"id": 1}]) - - # Simulate pool exhaustion then recovery - responses = [ - NoHostAvailable("Pool exhausted", {}), - NoHostAvailable("Pool exhausted", {}), - mock_future, # Recovery returns ResponseFuture - ] - mock_session.execute_async.side_effect = responses - - async_session = AsyncSession(mock_session) - - # First two attempts fail - for i in range(2): - with pytest.raises(NoHostAvailable): - await async_session.execute("SELECT * FROM users") - - # Third attempt succeeds - result = await async_session.execute("SELECT * FROM users") - assert result._rows == [{"id": 1}] - - @pytest.mark.resilience - async def test_partial_write_error_handling(self): - """ - Test handling of partial write errors. - - What this tests: - --------------- - 1. Coordinator timeout errors propagate - 2. Write might have partially succeeded - 3. Error message indicates uncertainty - 4. Application can handle ambiguity - - Why this matters: - ---------------- - Partial writes are dangerous because: - - Some replicas might have the data - - Some might not (inconsistent state) - - Retry might cause duplicates - - Applications need to know when writes - are ambiguous to handle appropriately. - """ - mock_session = Mock() - - # Simulate partial write success - mock_session.execute_async.side_effect = Exception( - "Coordinator node timed out during write" - ) - - async_session = AsyncSession(mock_session) - - with pytest.raises(Exception, match="Coordinator node timed out"): - await async_session.execute("INSERT INTO users (id, name) VALUES (?, ?)", [1, "test"]) - - @pytest.mark.resilience - async def test_error_during_prepared_statement(self): - """ - Test error handling during prepared statement execution. - - What this tests: - --------------- - 1. Prepare succeeds but execute can fail - 2. Parameter validation errors propagate - 3. Prepared statements don't mask errors - 4. 
Error occurs at execution, not preparation - - Why this matters: - ---------------- - Prepared statements can fail at execution due to: - - Invalid parameter types - - Null values where not allowed - - Value size exceeding limits - - The async layer must propagate these execution - errors clearly for debugging. - """ - mock_session = Mock() - mock_prepared = Mock() - - # Prepare succeeds - mock_session.prepare.return_value = mock_prepared - - # But execution fails - mock_session.execute_async.side_effect = InvalidRequest("Invalid parameter") - - async_session = AsyncSession(mock_session) - - # Prepare statement - prepared = await async_session.prepare("SELECT * FROM users WHERE id = ?") - assert prepared == mock_prepared - - # Execute should fail - with pytest.raises(InvalidRequest, match="Invalid parameter"): - await async_session.execute(prepared, [None]) - - @pytest.mark.resilience - @pytest.mark.critical - @pytest.mark.timeout(40) # Increase timeout to account for 5s shutdown delay - async def test_graceful_shutdown_with_pending_queries(self): - """ - Test graceful shutdown when queries are pending. - - What this tests: - --------------- - 1. Shutdown waits for driver to finish - 2. Pending queries can complete during shutdown - 3. 5-second grace period for completion - 4. Clean shutdown without hanging - - Why this matters: - ---------------- - Applications need graceful shutdown to: - - Complete in-flight requests - - Avoid data loss or corruption - - Clean up resources properly - - The 5-second delay gives driver threads - time to complete ongoing operations before - forcing termination. - """ - mock_session = Mock() - mock_cluster = Mock() - - # Track shutdown completion - shutdown_complete = asyncio.Event() - - # Mock the cluster shutdown to complete quickly - def mock_shutdown(): - shutdown_complete.set() - - mock_cluster.shutdown = mock_shutdown - - # Create queries that will complete after a delay - query_complete = asyncio.Event() - - # Create mock ResponseFutures - def create_mock_future(*args): - mock_future = Mock() - mock_future.has_more_pages = False - mock_future.timeout = None - mock_future.add_callbacks = Mock() - - def handle_callbacks(callback=None, errback=None): - # Schedule the callback to be called after a short delay - # This simulates a query that completes during shutdown - def delayed_callback(): - if callback: - callback([]) # Call with empty rows - query_complete.set() - - # Use asyncio to schedule the callback - asyncio.get_event_loop().call_later(0.1, delayed_callback) - - mock_future.add_callbacks.side_effect = handle_callbacks - return mock_future - - mock_session.execute_async.side_effect = create_mock_future - - cluster = AsyncCluster() - cluster._cluster = mock_cluster - cluster._cluster.protocol_version = 5 # Mock protocol version - cluster._cluster.connect.return_value = mock_session - - session = await cluster.connect() - - # Start a query - query_task = asyncio.create_task(session.execute("SELECT * FROM table")) - - # Give query time to start - await asyncio.sleep(0.05) - - # Start shutdown in background (it will wait 5 seconds after driver shutdown) - shutdown_task = asyncio.create_task(cluster.shutdown()) - - # Wait for driver shutdown to complete - await shutdown_complete.wait() - - # Query should complete during the 5 second wait - await query_complete.wait() - - # Wait for the query task to actually complete - # Use wait_for with a timeout to avoid hanging if something goes wrong - try: - await asyncio.wait_for(query_task, timeout=1.0) - 
except asyncio.TimeoutError: - pytest.fail("Query task did not complete within timeout") - - # Wait for full shutdown including the 5 second delay - await shutdown_task - - # Verify everything completed properly - assert query_task.done() - assert not query_task.cancelled() # Query completed normally - assert cluster.is_closed - - @pytest.mark.resilience - async def test_error_stack_trace_preservation(self): - """ - Test that error stack traces are preserved through async layer. - - What this tests: - --------------- - 1. Original exception traceback preserved - 2. Error message unchanged - 3. Exception type maintained - 4. Debugging information intact - - Why this matters: - ---------------- - Stack traces are critical for debugging: - - Show where error originated - - Include call chain context - - Help identify root cause - - The async wrapper must not lose or corrupt - this debugging information while propagating - errors across thread boundaries. - """ - mock_session = Mock() - - # Create an error with traceback info - try: - raise InvalidRequest("Original error") - except InvalidRequest as e: - original_error = e - - mock_session.execute_async.side_effect = original_error - - async_session = AsyncSession(mock_session) - - try: - await async_session.execute("SELECT * FROM users") - except InvalidRequest as e: - # Stack trace should be preserved - assert str(e) == "Original error" - assert e.__traceback__ is not None - - @pytest.mark.resilience - async def test_concurrent_error_isolation(self): - """ - Test that errors in concurrent queries don't affect each other. - - What this tests: - --------------- - 1. Each query gets its own error/result - 2. Failures don't cascade to other queries - 3. Mixed success/failure scenarios work - 4. Error types are preserved per query - - Why this matters: - ---------------- - Applications often run many queries concurrently: - - Dashboard fetching multiple metrics - - Batch processing different tables - - Parallel data aggregation - - One query's failure should not affect others. - Each query should succeed or fail independently - based on its own merits. 
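# --- Illustrative sketch (not part of the deleted test file) -----------------
# The docstring above describes per-query error isolation. On the application
# side that pattern typically looks like the following; the session object,
# table names, and the `.rows` attribute are assumptions used only for
# illustration, not taken from the diff.
import asyncio

async def fetch_dashboard(session):
    queries = [
        "SELECT * FROM metrics_cpu",
        "SELECT * FROM metrics_memory",
        "SELECT * FROM metrics_disk",
    ]
    # return_exceptions=True keeps one failed query from cancelling the rest
    results = await asyncio.gather(
        *(session.execute(q) for q in queries), return_exceptions=True
    )
    for query, result in zip(queries, results):
        if isinstance(result, Exception):
            print(f"{query} failed: {result}")  # handle or retry this one alone
        else:
            print(f"{query} returned {len(result.rows)} rows")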
- """ - mock_session = Mock() - - # Different errors for different queries - def execute_side_effect(query, *args, **kwargs): - if "table1" in query: - raise InvalidRequest("Error in table1") - elif "table2" in query: - # Create a mock ResponseFuture for success - return create_mock_response_future([{"id": 2}]) - elif "table3" in query: - raise NoHostAvailable("No hosts for table3", {}) - else: - # Create a mock ResponseFuture for empty result - return create_mock_response_future([]) - - mock_session.execute_async.side_effect = execute_side_effect - - async_session = AsyncSession(mock_session) - - # Execute queries concurrently - tasks = [ - async_session.execute("SELECT * FROM table1"), - async_session.execute("SELECT * FROM table2"), - async_session.execute("SELECT * FROM table3"), - ] - - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Verify each query got its expected result/error - assert isinstance(results[0], InvalidRequest) - assert "Error in table1" in str(results[0]) - - assert not isinstance(results[1], Exception) - assert results[1]._rows == [{"id": 2}] - - assert isinstance(results[2], NoHostAvailable) - assert "No hosts for table3" in str(results[2]) diff --git a/tests/unit/test_event_loop_handling.py b/tests/unit/test_event_loop_handling.py deleted file mode 100644 index a9278d4..0000000 --- a/tests/unit/test_event_loop_handling.py +++ /dev/null @@ -1,201 +0,0 @@ -""" -Unit tests for event loop reference handling. -""" - -import asyncio -from unittest.mock import Mock - -import pytest - -from async_cassandra.result import AsyncResultHandler -from async_cassandra.streaming import AsyncStreamingResultSet - - -@pytest.mark.asyncio -class TestEventLoopHandling: - """Test that event loop references are not stored.""" - - async def test_result_handler_no_stored_loop_reference(self): - """ - Test that AsyncResultHandler doesn't store event loop reference initially. - - What this tests: - --------------- - 1. No loop reference at creation - 2. Future not created eagerly - 3. Early result tracking exists - 4. Lazy initialization pattern - - Why this matters: - ---------------- - Event loop references problematic: - - Can't share across threads - - Prevents object reuse - - Causes "attached to different loop" errors - - Lazy creation allows flexible - usage across different contexts. - """ - # Create handler - response_future = Mock() - response_future.has_more_pages = False - response_future.add_callbacks = Mock() - response_future.timeout = None - - handler = AsyncResultHandler(response_future) - - # Verify no _loop attribute initially - assert not hasattr(handler, "_loop") - # Future should be None initially - assert handler._future is None - # Should have early result/error tracking - assert hasattr(handler, "_early_result") - assert hasattr(handler, "_early_error") - - async def test_streaming_no_stored_loop_reference(self): - """ - Test that AsyncStreamingResultSet doesn't store event loop reference initially. - - What this tests: - --------------- - 1. Loop starts as None - 2. No eager event creation - 3. Clean initial state - 4. Ready for any loop - - Why this matters: - ---------------- - Streaming objects created in threads: - - Driver callbacks from thread pool - - No event loop in creation context - - Must defer loop capture - - Enables thread-safe object creation - before async iteration. 
- """ - # Create streaming result set - response_future = Mock() - response_future.has_more_pages = False - response_future.add_callbacks = Mock() - - result_set = AsyncStreamingResultSet(response_future) - - # _loop is initialized to None - assert result_set._loop is None - - async def test_future_created_on_first_get_result(self): - """ - Test that future is created on first call to get_result. - - What this tests: - --------------- - 1. Future created on demand - 2. Loop captured at usage time - 3. Callbacks work correctly - 4. Results properly aggregated - - Why this matters: - ---------------- - Just-in-time future creation: - - Captures correct event loop - - Avoids cross-loop issues - - Works with any async context - - Critical for framework integration - where object creation context differs - from usage context. - """ - # Create handler with has_more_pages=True to prevent immediate completion - response_future = Mock() - response_future.has_more_pages = True # Start with more pages - response_future.add_callbacks = Mock() - response_future.start_fetching_next_page = Mock() - response_future.timeout = None - - handler = AsyncResultHandler(response_future) - - # Future should not be created yet - assert handler._future is None - - # Get the callback that was registered - call_args = response_future.add_callbacks.call_args - callback = call_args.kwargs.get("callback") if call_args else None - - # Start get_result task - result_task = asyncio.create_task(handler.get_result()) - await asyncio.sleep(0.01) - - # Future should now be created - assert handler._future is not None - assert hasattr(handler, "_loop") - - # Trigger callbacks to complete the future - if callback: - # First page - callback(["row1"]) - # Now indicate no more pages - response_future.has_more_pages = False - # Second page (final) - callback(["row2"]) - - # Get result - result = await result_task - assert len(result.rows) == 2 - - async def test_streaming_page_ready_lazy_creation(self): - """ - Test that page_ready event is created lazily. - - What this tests: - --------------- - 1. Event created on iteration start - 2. Thread callbacks work correctly - 3. Loop captured at right time - 4. Cross-thread coordination works - - Why this matters: - ---------------- - Streaming uses thread callbacks: - - Driver calls from thread pool - - Event needed for coordination - - Must work across thread boundaries - - Lazy event creation ensures - correct loop association for - thread-to-async communication. 
- """ - # Create streaming result set - response_future = Mock() - response_future.has_more_pages = False - response_future._final_exception = None # Important: must be None - response_future.add_callbacks = Mock() - - result_set = AsyncStreamingResultSet(response_future) - - # Page ready event should not exist yet - assert result_set._page_ready is None - - # Trigger callback from a thread (like the real driver) - args = response_future.add_callbacks.call_args - callback = args[1]["callback"] - - import threading - - def thread_callback(): - callback(["row1", "row2"]) - - thread = threading.Thread(target=thread_callback) - thread.start() - - # Start iteration - this should create the event - rows = [] - async for row in result_set: - rows.append(row) - - # Now page_ready should be created - assert result_set._page_ready is not None - assert isinstance(result_set._page_ready, asyncio.Event) - assert len(rows) == 2 - - # Loop should also be stored now - assert result_set._loop is not None diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py deleted file mode 100644 index 298816c..0000000 --- a/tests/unit/test_helpers.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -Test helpers for advanced features tests. - -This module provides utility functions for creating mock objects that simulate -Cassandra driver behavior in unit tests. These helpers ensure consistent test -behavior and reduce boilerplate across test files. -""" - -import asyncio -from unittest.mock import Mock - - -def create_mock_response_future(rows=None, has_more_pages=False): - """ - Helper to create a properly configured mock ResponseFuture. - - What this does: - -------------- - 1. Creates mock ResponseFuture - 2. Configures callback behavior - 3. Simulates async execution - 4. Handles event loop scheduling - - Why this matters: - ---------------- - Consistent mock behavior: - - Accurate driver simulation - - Reliable test results - - Less test flakiness - - Proper async simulation prevents - race conditions in tests. - - Parameters: - ----------- - rows : list, optional - The rows to return when callback is executed - has_more_pages : bool, default False - Whether to indicate more pages are available - - Returns: - -------- - Mock - A configured mock ResponseFuture object - """ - mock_future = Mock() - mock_future.has_more_pages = has_more_pages - mock_future.timeout = None - mock_future.add_callbacks = Mock() - - def handle_callbacks(callback=None, errback=None): - if callback: - # Schedule callback on the event loop to simulate async behavior - loop = asyncio.get_event_loop() - loop.call_soon(callback, rows if rows is not None else []) - - mock_future.add_callbacks.side_effect = handle_callbacks - return mock_future diff --git a/tests/unit/test_lwt_operations.py b/tests/unit/test_lwt_operations.py deleted file mode 100644 index cea6591..0000000 --- a/tests/unit/test_lwt_operations.py +++ /dev/null @@ -1,595 +0,0 @@ -""" -Unit tests for Lightweight Transaction (LWT) operations. 
- -Tests how the async wrapper handles: -- IF NOT EXISTS conditions -- IF EXISTS conditions -- Conditional updates -- LWT result parsing -- Race conditions -""" - -import asyncio -from unittest.mock import Mock - -import pytest -from cassandra import InvalidRequest, WriteTimeout -from cassandra.cluster import Session - -from async_cassandra import AsyncCassandraSession - - -class TestLWTOperations: - """Test Lightweight Transaction operations.""" - - def create_lwt_success_future(self, applied=True, existing_data=None): - """Create a mock future for successful LWT operations.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - # LWT results include the [applied] column - if applied: - # Successful LWT - mock_rows = [{"[applied]": True}] - else: - # Failed LWT with existing data - result = {"[applied]": False} - if existing_data: - result.update(existing_data) - mock_rows = [result] - callback(mock_rows) - if errback: - errbacks.append(errback) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - def create_error_future(self, exception): - """Create a mock future that raises the given exception.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - if errback: - errbacks.append(errback) - errback(exception) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - @pytest.fixture - def mock_session(self): - """Create a mock session.""" - session = Mock(spec=Session) - session.execute_async = Mock() - session.prepare = Mock() - return session - - @pytest.mark.asyncio - async def test_insert_if_not_exists_success(self, mock_session): - """ - Test successful INSERT IF NOT EXISTS. - - What this tests: - --------------- - 1. LWT INSERT succeeds when no conflict - 2. [applied] column is True - 3. Result properly parsed - 4. Async execution works - - Why this matters: - ---------------- - INSERT IF NOT EXISTS enables: - - Distributed unique constraints - - Race-condition-free inserts - - Idempotent operations - - Critical for distributed systems - without locks or coordination. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock successful LWT - mock_session.execute_async.return_value = self.create_lwt_success_future(applied=True) - - # Execute INSERT IF NOT EXISTS - result = await async_session.execute( - "INSERT INTO users (id, name) VALUES (?, ?) IF NOT EXISTS", (1, "Alice") - ) - - # Verify result - assert result is not None - assert len(result.rows) == 1 - assert result.rows[0]["[applied]"] is True - - @pytest.mark.asyncio - async def test_insert_if_not_exists_conflict(self, mock_session): - """ - Test INSERT IF NOT EXISTS when row already exists. - - What this tests: - --------------- - 1. LWT INSERT fails on conflict - 2. [applied] is False - 3. Existing data returned - 4. Can see what blocked insert - - Why this matters: - ---------------- - Failed LWTs return existing data: - - Shows why operation failed - - Enables conflict resolution - - Helps with debugging - - Applications must check [applied] - and handle conflicts appropriately. 
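# --- Illustrative sketch (not part of the deleted test file) -----------------
# Application-side handling of INSERT ... IF NOT EXISTS, matching what the two
# tests above assert: check the Cassandra-provided "[applied]" column and, on
# conflict, read the winning values from the same row. The session and table
# names, and the dict-style row access, follow the mocks in these tests.
async def create_user(session, user_id, name):
    result = await session.execute(
        "INSERT INTO users (id, name) VALUES (?, ?) IF NOT EXISTS",
        (user_id, name),
    )
    row = result.rows[0]
    if row["[applied]"]:
        return "created"
    # The insert lost the race; the row echoes the existing values
    return f"already exists with name {row['name']}"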
- """ - async_session = AsyncCassandraSession(mock_session) - - # Mock failed LWT with existing data - existing_data = {"id": 1, "name": "Bob"} # Different name - mock_session.execute_async.return_value = self.create_lwt_success_future( - applied=False, existing_data=existing_data - ) - - # Execute INSERT IF NOT EXISTS - result = await async_session.execute( - "INSERT INTO users (id, name) VALUES (?, ?) IF NOT EXISTS", (1, "Alice") - ) - - # Verify result shows conflict - assert result is not None - assert len(result.rows) == 1 - assert result.rows[0]["[applied]"] is False - assert result.rows[0]["id"] == 1 - assert result.rows[0]["name"] == "Bob" - - @pytest.mark.asyncio - async def test_update_if_condition_success(self, mock_session): - """ - Test successful conditional UPDATE. - - What this tests: - --------------- - 1. Conditional UPDATE when condition matches - 2. [applied] is True on success - 3. Update actually applied - 4. Condition properly evaluated - - Why this matters: - ---------------- - Conditional updates enable: - - Optimistic concurrency control - - Check-then-act atomically - - Prevent lost updates - - Essential for maintaining data - consistency without locks. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock successful conditional update - mock_session.execute_async.return_value = self.create_lwt_success_future(applied=True) - - # Execute conditional UPDATE - result = await async_session.execute( - "UPDATE users SET email = ? WHERE id = ? IF name = ?", ("alice@example.com", 1, "Alice") - ) - - # Verify result - assert result is not None - assert len(result.rows) == 1 - assert result.rows[0]["[applied]"] is True - - @pytest.mark.asyncio - async def test_update_if_condition_failure(self, mock_session): - """ - Test conditional UPDATE when condition doesn't match. - - What this tests: - --------------- - 1. UPDATE fails when condition false - 2. [applied] is False - 3. Current values returned - 4. Update not applied - - Why this matters: - ---------------- - Failed conditions show current state: - - Understand why update failed - - Retry with correct values - - Implement compare-and-swap - - Prevents blind overwrites and - maintains data integrity. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock failed conditional update - existing_data = {"name": "Bob"} # Actual name is different - mock_session.execute_async.return_value = self.create_lwt_success_future( - applied=False, existing_data=existing_data - ) - - # Execute conditional UPDATE - result = await async_session.execute( - "UPDATE users SET email = ? WHERE id = ? IF name = ?", ("alice@example.com", 1, "Alice") - ) - - # Verify result shows condition failure - assert result is not None - assert len(result.rows) == 1 - assert result.rows[0]["[applied]"] is False - assert result.rows[0]["name"] == "Bob" - - @pytest.mark.asyncio - async def test_delete_if_exists_success(self, mock_session): - """ - Test successful DELETE IF EXISTS. - - What this tests: - --------------- - 1. DELETE succeeds when row exists - 2. [applied] is True - 3. Row actually deleted - 4. No error on existing row - - Why this matters: - ---------------- - DELETE IF EXISTS provides: - - Idempotent deletes - - No error if already gone - - Useful for cleanup - - Simplifies error handling in - distributed delete operations. 
- """ - async_session = AsyncCassandraSession(mock_session) - - # Mock successful DELETE IF EXISTS - mock_session.execute_async.return_value = self.create_lwt_success_future(applied=True) - - # Execute DELETE IF EXISTS - result = await async_session.execute("DELETE FROM users WHERE id = ? IF EXISTS", (1,)) - - # Verify result - assert result is not None - assert len(result.rows) == 1 - assert result.rows[0]["[applied]"] is True - - @pytest.mark.asyncio - async def test_delete_if_exists_not_found(self, mock_session): - """ - Test DELETE IF EXISTS when row doesn't exist. - - What this tests: - --------------- - 1. DELETE IF EXISTS on missing row - 2. [applied] is False - 3. No error raised - 4. Operation completes normally - - Why this matters: - ---------------- - Missing row handling: - - No exception thrown - - Can detect if deleted - - Idempotent behavior - - Allows safe cleanup without - checking existence first. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock failed DELETE IF EXISTS - mock_session.execute_async.return_value = self.create_lwt_success_future( - applied=False, existing_data={} - ) - - # Execute DELETE IF EXISTS - result = await async_session.execute( - "DELETE FROM users WHERE id = ? IF EXISTS", (999,) # Non-existent ID - ) - - # Verify result - assert result is not None - assert len(result.rows) == 1 - assert result.rows[0]["[applied]"] is False - - @pytest.mark.asyncio - async def test_lwt_with_multiple_conditions(self, mock_session): - """ - Test LWT with multiple IF conditions. - - What this tests: - --------------- - 1. Multiple conditions work together - 2. All must be true to apply - 3. Complex conditions supported - 4. AND logic properly evaluated - - Why this matters: - ---------------- - Multiple conditions enable: - - Complex business rules - - Multi-field validation - - Stronger consistency checks - - Real-world updates often need - multiple preconditions. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock successful multi-condition update - mock_session.execute_async.return_value = self.create_lwt_success_future(applied=True) - - # Execute UPDATE with multiple conditions - result = await async_session.execute( - "UPDATE users SET status = ? WHERE id = ? IF name = ? AND email = ?", - ("active", 1, "Alice", "alice@example.com"), - ) - - # Verify result - assert result is not None - assert len(result.rows) == 1 - assert result.rows[0]["[applied]"] is True - - @pytest.mark.asyncio - async def test_lwt_timeout_handling(self, mock_session): - """ - Test LWT timeout scenarios. - - What this tests: - --------------- - 1. LWT timeouts properly identified - 2. WriteType.CAS indicates LWT - 3. Timeout details preserved - 4. Error not wrapped - - Why this matters: - ---------------- - LWT timeouts are special: - - May have partially applied - - Require careful handling - - Different from regular timeouts - - Applications must handle LWT - timeouts differently than - regular write timeouts. 
- """ - async_session = AsyncCassandraSession(mock_session) - - # Mock WriteTimeout for LWT - from cassandra import WriteType - - timeout_error = WriteTimeout( - "LWT operation timed out", write_type=WriteType.CAS # Compare-And-Set (LWT) - ) - timeout_error.consistency_level = 1 - timeout_error.required_responses = 2 - timeout_error.received_responses = 1 - - mock_session.execute_async.return_value = self.create_error_future(timeout_error) - - # Execute LWT that times out - with pytest.raises(WriteTimeout) as exc_info: - await async_session.execute( - "INSERT INTO users (id, name) VALUES (?, ?) IF NOT EXISTS", (1, "Alice") - ) - - assert "LWT operation timed out" in str(exc_info.value) - assert exc_info.value.write_type == WriteType.CAS - - @pytest.mark.asyncio - async def test_concurrent_lwt_operations(self, mock_session): - """ - Test handling of concurrent LWT operations. - - What this tests: - --------------- - 1. Concurrent LWTs race safely - 2. Only one succeeds - 3. Others see winner's value - 4. No corruption or errors - - Why this matters: - ---------------- - LWTs handle distributed races: - - Exactly one winner - - Losers see winner's data - - No lost updates - - This is THE pattern for distributed - mutual exclusion without locks. - """ - async_session = AsyncCassandraSession(mock_session) - - # Track which request wins the race - request_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal request_count - request_count += 1 - - if request_count == 1: - # First request succeeds - return self.create_lwt_success_future(applied=True) - else: - # Subsequent requests fail (row already exists) - return self.create_lwt_success_future( - applied=False, existing_data={"id": 1, "name": "Alice"} - ) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Execute multiple concurrent LWT operations - tasks = [] - for i in range(5): - task = async_session.execute( - "INSERT INTO users (id, name) VALUES (?, ?) IF NOT EXISTS", (1, f"User_{i}") - ) - tasks.append(task) - - results = await asyncio.gather(*tasks) - - # Only first should succeed - applied_count = sum(1 for r in results if r.rows[0]["[applied]"]) - assert applied_count == 1 - - # Others should show the winning value - for i, result in enumerate(results): - if not result.rows[0]["[applied]"]: - assert result.rows[0]["name"] == "Alice" - - @pytest.mark.asyncio - async def test_lwt_with_prepared_statements(self, mock_session): - """ - Test LWT operations with prepared statements. - - What this tests: - --------------- - 1. LWTs work with prepared statements - 2. Parameters bound correctly - 3. [applied] result available - 4. Performance benefits maintained - - Why this matters: - ---------------- - Prepared LWTs combine: - - Query plan caching - - Parameter safety - - Atomic operations - - Best practice for production - LWT operations. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock prepared statement - mock_prepared = Mock() - mock_prepared.query = "INSERT INTO users (id, name) VALUES (?, ?) IF NOT EXISTS" - mock_prepared.bind = Mock(return_value=Mock()) - mock_session.prepare.return_value = mock_prepared - - # Prepare statement - prepared = await async_session.prepare( - "INSERT INTO users (id, name) VALUES (?, ?) 
IF NOT EXISTS" - ) - - # Execute with prepared statement - mock_session.execute_async.return_value = self.create_lwt_success_future(applied=True) - - result = await async_session.execute(prepared, (1, "Alice")) - - # Verify result - assert result is not None - assert result.rows[0]["[applied]"] is True - - @pytest.mark.asyncio - async def test_lwt_batch_not_supported(self, mock_session): - """ - Test that LWT in batch statements raises appropriate error. - - What this tests: - --------------- - 1. LWTs not allowed in batches - 2. InvalidRequest raised - 3. Clear error message - 4. Cassandra limitation enforced - - Why this matters: - ---------------- - Cassandra design limitation: - - Batches for atomicity - - LWTs for conditions - - Can't combine both - - Applications must use LWTs - individually, not in batches. - """ - from cassandra.query import BatchStatement, BatchType, SimpleStatement - - async_session = AsyncCassandraSession(mock_session) - - # Create batch with LWT (not supported by Cassandra) - batch = BatchStatement(batch_type=BatchType.LOGGED) - - # Use SimpleStatement to avoid parameter binding issues - stmt = SimpleStatement("INSERT INTO users (id, name) VALUES (1, 'Alice') IF NOT EXISTS") - batch.add(stmt) - - # Mock InvalidRequest for LWT in batch - mock_session.execute_async.return_value = self.create_error_future( - InvalidRequest("Conditional statements are not supported in batches") - ) - - # Should raise InvalidRequest - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute_batch(batch) - - assert "Conditional statements are not supported" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_lwt_result_parsing(self, mock_session): - """ - Test parsing of various LWT result formats. - - What this tests: - --------------- - 1. Various LWT result formats parsed - 2. [applied] always present - 3. Failed LWTs include data - 4. All columns accessible - - Why this matters: - ---------------- - LWT results vary by operation: - - Simple success/failure - - Single column conflicts - - Multi-column current state - - Robust parsing enables proper - conflict resolution logic. - """ - async_session = AsyncCassandraSession(mock_session) - - # Test different result formats - test_cases = [ - # Simple success - ({"[applied]": True}, True, None), - # Failure with single column - ({"[applied]": False, "value": 42}, False, {"value": 42}), - # Failure with multiple columns - ( - {"[applied]": False, "id": 1, "name": "Alice", "email": "alice@example.com"}, - False, - {"id": 1, "name": "Alice", "email": "alice@example.com"}, - ), - ] - - for result_data, expected_applied, expected_data in test_cases: - mock_session.execute_async.return_value = self.create_lwt_success_future( - applied=result_data["[applied]"], - existing_data={k: v for k, v in result_data.items() if k != "[applied]"}, - ) - - result = await async_session.execute("UPDATE users SET ... IF ...") - - assert result.rows[0]["[applied]"] == expected_applied - - if expected_data: - for key, value in expected_data.items(): - assert result.rows[0][key] == value diff --git a/tests/unit/test_monitoring_unified.py b/tests/unit/test_monitoring_unified.py deleted file mode 100644 index 7e90264..0000000 --- a/tests/unit/test_monitoring_unified.py +++ /dev/null @@ -1,1024 +0,0 @@ -""" -Unified monitoring and metrics tests for async-python-cassandra. - -This module provides comprehensive tests for the monitoring and metrics -functionality based on the actual implementation. 
- -Test Organization: -================== -1. Metrics Data Classes - Testing QueryMetrics and ConnectionMetrics -2. InMemoryMetricsCollector - Testing the in-memory metrics backend -3. PrometheusMetricsCollector - Testing Prometheus integration -4. MetricsMiddleware - Testing the middleware layer -5. ConnectionMonitor - Testing connection health monitoring -6. RateLimitedSession - Testing rate limiting functionality -7. Integration Tests - Testing the full monitoring stack - -Key Testing Principles: -====================== -- All metrics methods are async and must be awaited -- Test thread safety with asyncio.Lock -- Verify metrics accuracy and aggregation -- Test graceful degradation without prometheus_client -- Ensure monitoring doesn't impact performance -""" - -import asyncio -from datetime import datetime, timedelta, timezone -from unittest.mock import AsyncMock, Mock, patch - -import pytest - -from async_cassandra.metrics import ( - ConnectionMetrics, - InMemoryMetricsCollector, - MetricsMiddleware, - PrometheusMetricsCollector, - QueryMetrics, - create_metrics_system, -) -from async_cassandra.monitoring import ( - HOST_STATUS_DOWN, - HOST_STATUS_UNKNOWN, - HOST_STATUS_UP, - ClusterMetrics, - ConnectionMonitor, - HostMetrics, - RateLimitedSession, - create_monitored_session, -) - - -class TestMetricsDataClasses: - """Test the metrics data classes.""" - - def test_query_metrics_creation(self): - """Test QueryMetrics dataclass creation and fields.""" - now = datetime.now(timezone.utc) - metrics = QueryMetrics( - query_hash="abc123", - duration=0.123, - success=True, - error_type=None, - timestamp=now, - parameters_count=2, - result_size=10, - ) - - assert metrics.query_hash == "abc123" - assert metrics.duration == 0.123 - assert metrics.success is True - assert metrics.error_type is None - assert metrics.timestamp == now - assert metrics.parameters_count == 2 - assert metrics.result_size == 10 - - def test_query_metrics_defaults(self): - """Test QueryMetrics default values.""" - metrics = QueryMetrics( - query_hash="xyz789", duration=0.05, success=False, error_type="Timeout" - ) - - assert metrics.parameters_count == 0 - assert metrics.result_size == 0 - assert isinstance(metrics.timestamp, datetime) - assert metrics.timestamp.tzinfo == timezone.utc - - def test_connection_metrics_creation(self): - """Test ConnectionMetrics dataclass creation.""" - now = datetime.now(timezone.utc) - metrics = ConnectionMetrics( - host="127.0.0.1", - is_healthy=True, - last_check=now, - response_time=0.02, - error_count=0, - total_queries=100, - ) - - assert metrics.host == "127.0.0.1" - assert metrics.is_healthy is True - assert metrics.last_check == now - assert metrics.response_time == 0.02 - assert metrics.error_count == 0 - assert metrics.total_queries == 100 - - def test_host_metrics_creation(self): - """Test HostMetrics dataclass for monitoring.""" - now = datetime.now(timezone.utc) - metrics = HostMetrics( - address="127.0.0.1", - datacenter="dc1", - rack="rack1", - status=HOST_STATUS_UP, - release_version="4.0.1", - connection_count=1, - latency_ms=5.2, - last_error=None, - last_check=now, - ) - - assert metrics.address == "127.0.0.1" - assert metrics.datacenter == "dc1" - assert metrics.rack == "rack1" - assert metrics.status == HOST_STATUS_UP - assert metrics.release_version == "4.0.1" - assert metrics.connection_count == 1 - assert metrics.latency_ms == 5.2 - assert metrics.last_error is None - assert metrics.last_check == now - - def test_cluster_metrics_creation(self): - """Test 
ClusterMetrics aggregation dataclass.""" - now = datetime.now(timezone.utc) - host1 = HostMetrics("127.0.0.1", "dc1", "rack1", HOST_STATUS_UP, "4.0.1", 1) - host2 = HostMetrics("127.0.0.2", "dc1", "rack2", HOST_STATUS_DOWN, "4.0.1", 0) - - cluster = ClusterMetrics( - timestamp=now, - cluster_name="test_cluster", - protocol_version=4, - hosts=[host1, host2], - total_connections=1, - healthy_hosts=1, - unhealthy_hosts=1, - app_metrics={"requests_sent": 100}, - ) - - assert cluster.timestamp == now - assert cluster.cluster_name == "test_cluster" - assert cluster.protocol_version == 4 - assert len(cluster.hosts) == 2 - assert cluster.total_connections == 1 - assert cluster.healthy_hosts == 1 - assert cluster.unhealthy_hosts == 1 - assert cluster.app_metrics["requests_sent"] == 100 - - -class TestInMemoryMetricsCollector: - """Test the in-memory metrics collection system.""" - - @pytest.mark.asyncio - async def test_record_query_metrics(self): - """Test recording query metrics.""" - collector = InMemoryMetricsCollector(max_entries=100) - - # Create and record metrics - metrics = QueryMetrics( - query_hash="abc123", duration=0.1, success=True, parameters_count=1, result_size=5 - ) - - await collector.record_query(metrics) - - # Check it was recorded - assert len(collector.query_metrics) == 1 - assert collector.query_metrics[0] == metrics - assert collector.query_counts["abc123"] == 1 - - @pytest.mark.asyncio - async def test_record_query_with_error(self): - """Test recording failed queries.""" - collector = InMemoryMetricsCollector() - - # Record failed query - metrics = QueryMetrics( - query_hash="xyz789", duration=0.05, success=False, error_type="InvalidRequest" - ) - - await collector.record_query(metrics) - - # Check error counting - assert collector.error_counts["InvalidRequest"] == 1 - assert len(collector.query_metrics) == 1 - - @pytest.mark.asyncio - async def test_max_entries_limit(self): - """Test that collector respects max_entries limit.""" - collector = InMemoryMetricsCollector(max_entries=5) - - # Record more than max entries - for i in range(10): - metrics = QueryMetrics(query_hash=f"query_{i}", duration=0.1, success=True) - await collector.record_query(metrics) - - # Should only keep the last 5 - assert len(collector.query_metrics) == 5 - # Verify it's the last 5 queries (deque behavior) - hashes = [m.query_hash for m in collector.query_metrics] - assert hashes == ["query_5", "query_6", "query_7", "query_8", "query_9"] - - @pytest.mark.asyncio - async def test_record_connection_health(self): - """Test recording connection health metrics.""" - collector = InMemoryMetricsCollector() - - # Record healthy connection - healthy = ConnectionMetrics( - host="127.0.0.1", - is_healthy=True, - last_check=datetime.now(timezone.utc), - response_time=0.02, - error_count=0, - total_queries=50, - ) - await collector.record_connection_health(healthy) - - # Record unhealthy connection - unhealthy = ConnectionMetrics( - host="127.0.0.2", - is_healthy=False, - last_check=datetime.now(timezone.utc), - response_time=0, - error_count=5, - total_queries=10, - ) - await collector.record_connection_health(unhealthy) - - # Check storage - assert "127.0.0.1" in collector.connection_metrics - assert "127.0.0.2" in collector.connection_metrics - assert collector.connection_metrics["127.0.0.1"].is_healthy is True - assert collector.connection_metrics["127.0.0.2"].is_healthy is False - - @pytest.mark.asyncio - async def test_get_stats_no_data(self): - """ - Test get_stats with no data. 
- - What this tests: - --------------- - 1. Empty stats dictionary structure - 2. No errors with zero metrics - 3. Consistent stat categories - 4. Safe empty state handling - - Why this matters: - ---------------- - - Graceful startup behavior - - No NPEs in monitoring code - - Consistent API responses - - Clean initial state - - Additional context: - --------------------------------- - - Returns valid structure even if empty - - All stat categories present - - Zero values, not null/missing - """ - collector = InMemoryMetricsCollector() - stats = await collector.get_stats() - - assert stats == {"message": "No metrics available"} - - @pytest.mark.asyncio - async def test_get_stats_with_recent_queries(self): - """Test get_stats with recent query data.""" - collector = InMemoryMetricsCollector() - - # Record some recent queries - now = datetime.now(timezone.utc) - for i in range(5): - metrics = QueryMetrics( - query_hash=f"query_{i}", - duration=0.1 * (i + 1), - success=i % 2 == 0, - error_type="Timeout" if i % 2 else None, - timestamp=now - timedelta(minutes=1), - result_size=10 * i, - ) - await collector.record_query(metrics) - - stats = await collector.get_stats() - - # Check structure - assert "query_performance" in stats - assert "error_summary" in stats - assert "top_queries" in stats - assert "connection_health" in stats - - # Check calculations - perf = stats["query_performance"] - assert perf["total_queries"] == 5 - assert perf["recent_queries_5min"] == 5 - assert perf["success_rate"] == 0.6 # 3 out of 5 - assert "avg_duration_ms" in perf - assert "min_duration_ms" in perf - assert "max_duration_ms" in perf - - # Check error summary - assert stats["error_summary"]["Timeout"] == 2 - - @pytest.mark.asyncio - async def test_get_stats_with_old_queries(self): - """Test get_stats filters out old queries.""" - collector = InMemoryMetricsCollector() - - # Record old query - old_metrics = QueryMetrics( - query_hash="old_query", - duration=0.1, - success=True, - timestamp=datetime.now(timezone.utc) - timedelta(minutes=10), - ) - await collector.record_query(old_metrics) - - stats = await collector.get_stats() - - # Should have no recent queries - assert stats["query_performance"]["message"] == "No recent queries" - assert stats["error_summary"] == {} - - @pytest.mark.asyncio - async def test_thread_safety(self): - """Test that collector is thread-safe with async operations.""" - collector = InMemoryMetricsCollector(max_entries=1000) - - async def record_many(start_id: int): - for i in range(100): - metrics = QueryMetrics( - query_hash=f"query_{start_id}_{i}", duration=0.01, success=True - ) - await collector.record_query(metrics) - - # Run multiple concurrent tasks - tasks = [record_many(i * 100) for i in range(5)] - await asyncio.gather(*tasks) - - # Should have recorded all 500 - assert len(collector.query_metrics) == 500 - - -class TestPrometheusMetricsCollector: - """Test the Prometheus metrics collector.""" - - def test_initialization_without_prometheus_client(self): - """Test initialization when prometheus_client is not available.""" - with patch.dict("sys.modules", {"prometheus_client": None}): - collector = PrometheusMetricsCollector() - - assert collector._available is False - assert collector.query_duration is None - assert collector.query_total is None - assert collector.connection_health is None - assert collector.error_total is None - - @pytest.mark.asyncio - async def test_record_query_without_prometheus(self): - """Test recording works gracefully without prometheus_client.""" - 
with patch.dict("sys.modules", {"prometheus_client": None}): - collector = PrometheusMetricsCollector() - - # Should not raise - metrics = QueryMetrics(query_hash="test", duration=0.1, success=True) - await collector.record_query(metrics) - - @pytest.mark.asyncio - async def test_record_connection_without_prometheus(self): - """Test connection recording without prometheus_client.""" - with patch.dict("sys.modules", {"prometheus_client": None}): - collector = PrometheusMetricsCollector() - - # Should not raise - metrics = ConnectionMetrics( - host="127.0.0.1", - is_healthy=True, - last_check=datetime.now(timezone.utc), - response_time=0.02, - ) - await collector.record_connection_health(metrics) - - @pytest.mark.asyncio - async def test_get_stats_without_prometheus(self): - """Test get_stats without prometheus_client.""" - with patch.dict("sys.modules", {"prometheus_client": None}): - collector = PrometheusMetricsCollector() - stats = await collector.get_stats() - - assert stats == {"error": "Prometheus client not available"} - - @pytest.mark.asyncio - async def test_with_prometheus_client(self): - """Test with mocked prometheus_client.""" - # Mock prometheus_client - mock_histogram = Mock() - mock_counter = Mock() - mock_gauge = Mock() - - mock_prometheus = Mock() - mock_prometheus.Histogram.return_value = mock_histogram - mock_prometheus.Counter.return_value = mock_counter - mock_prometheus.Gauge.return_value = mock_gauge - - with patch.dict("sys.modules", {"prometheus_client": mock_prometheus}): - collector = PrometheusMetricsCollector() - - assert collector._available is True - assert collector.query_duration is mock_histogram - assert collector.query_total is mock_counter - assert collector.connection_health is mock_gauge - assert collector.error_total is mock_counter - - # Test recording query - metrics = QueryMetrics(query_hash="prepared_stmt_123", duration=0.05, success=True) - await collector.record_query(metrics) - - # Verify Prometheus metrics were updated - mock_histogram.labels.assert_called_with(query_type="prepared", success="success") - mock_histogram.labels().observe.assert_called_with(0.05) - mock_counter.labels.assert_called_with(query_type="prepared", success="success") - mock_counter.labels().inc.assert_called() - - -class TestMetricsMiddleware: - """Test the metrics middleware functionality.""" - - @pytest.mark.asyncio - async def test_middleware_creation(self): - """Test creating metrics middleware.""" - collector = InMemoryMetricsCollector() - middleware = MetricsMiddleware([collector]) - - assert len(middleware.collectors) == 1 - assert middleware._enabled is True - - def test_enable_disable(self): - """Test enabling and disabling middleware.""" - middleware = MetricsMiddleware([]) - - # Initially enabled - assert middleware._enabled is True - - # Disable - middleware.disable() - assert middleware._enabled is False - - # Re-enable - middleware.enable() - assert middleware._enabled is True - - @pytest.mark.asyncio - async def test_record_query_metrics(self): - """Test recording metrics through middleware.""" - collector = InMemoryMetricsCollector() - middleware = MetricsMiddleware([collector]) - - # Record a query - await middleware.record_query_metrics( - query="SELECT * FROM users WHERE id = ?", - duration=0.05, - success=True, - error_type=None, - parameters_count=1, - result_size=1, - ) - - # Check it was recorded - assert len(collector.query_metrics) == 1 - recorded = collector.query_metrics[0] - assert recorded.duration == 0.05 - assert recorded.success is True 
- assert recorded.parameters_count == 1 - assert recorded.result_size == 1 - - @pytest.mark.asyncio - async def test_record_query_metrics_disabled(self): - """Test that disabled middleware doesn't record.""" - collector = InMemoryMetricsCollector() - middleware = MetricsMiddleware([collector]) - middleware.disable() - - # Try to record - await middleware.record_query_metrics( - query="SELECT * FROM users", duration=0.05, success=True - ) - - # Nothing should be recorded - assert len(collector.query_metrics) == 0 - - def test_normalize_query(self): - """Test query normalization for grouping.""" - middleware = MetricsMiddleware([]) - - # Test normalization creates consistent hashes - query1 = "SELECT * FROM users WHERE id = 123" - query2 = "SELECT * FROM users WHERE id = 456" - query3 = "select * from users where id = 789" - - # Different values but same structure should get same hash - hash1 = middleware._normalize_query(query1) - hash2 = middleware._normalize_query(query2) - hash3 = middleware._normalize_query(query3) - - assert hash1 == hash2 # Same query structure - assert hash1 == hash3 # Whitespace normalized - - def test_normalize_query_different_structures(self): - """Test normalization of different query structures.""" - middleware = MetricsMiddleware([]) - - queries = [ - "SELECT * FROM users WHERE id = ?", - "SELECT * FROM users WHERE name = ?", - "INSERT INTO users VALUES (?, ?)", - "DELETE FROM users WHERE id = ?", - ] - - hashes = [middleware._normalize_query(q) for q in queries] - - # All should be different - assert len(set(hashes)) == len(queries) - - @pytest.mark.asyncio - async def test_record_connection_metrics(self): - """Test recording connection health through middleware.""" - collector = InMemoryMetricsCollector() - middleware = MetricsMiddleware([collector]) - - await middleware.record_connection_metrics( - host="127.0.0.1", is_healthy=True, response_time=0.02, error_count=0, total_queries=100 - ) - - assert "127.0.0.1" in collector.connection_metrics - metrics = collector.connection_metrics["127.0.0.1"] - assert metrics.is_healthy is True - assert metrics.response_time == 0.02 - - @pytest.mark.asyncio - async def test_multiple_collectors(self): - """Test middleware with multiple collectors.""" - collector1 = InMemoryMetricsCollector() - collector2 = InMemoryMetricsCollector() - middleware = MetricsMiddleware([collector1, collector2]) - - await middleware.record_query_metrics( - query="SELECT * FROM test", duration=0.1, success=True - ) - - # Both collectors should have the metrics - assert len(collector1.query_metrics) == 1 - assert len(collector2.query_metrics) == 1 - - @pytest.mark.asyncio - async def test_collector_error_handling(self): - """Test middleware handles collector errors gracefully.""" - # Create a failing collector - failing_collector = Mock() - failing_collector.record_query = AsyncMock(side_effect=Exception("Collector failed")) - - # And a working collector - working_collector = InMemoryMetricsCollector() - - middleware = MetricsMiddleware([failing_collector, working_collector]) - - # Should not raise - await middleware.record_query_metrics( - query="SELECT * FROM test", duration=0.1, success=True - ) - - # Working collector should still get metrics - assert len(working_collector.query_metrics) == 1 - - -class TestConnectionMonitor: - """Test the connection monitoring functionality.""" - - def test_monitor_initialization(self): - """Test ConnectionMonitor initialization.""" - mock_session = Mock() - monitor = ConnectionMonitor(mock_session) - - 
assert monitor.session == mock_session - assert monitor.metrics["requests_sent"] == 0 - assert monitor.metrics["requests_completed"] == 0 - assert monitor.metrics["requests_failed"] == 0 - assert monitor._monitoring_task is None - assert len(monitor._callbacks) == 0 - - def test_add_callback(self): - """Test adding monitoring callbacks.""" - mock_session = Mock() - monitor = ConnectionMonitor(mock_session) - - callback1 = Mock() - callback2 = Mock() - - monitor.add_callback(callback1) - monitor.add_callback(callback2) - - assert len(monitor._callbacks) == 2 - assert callback1 in monitor._callbacks - assert callback2 in monitor._callbacks - - @pytest.mark.asyncio - async def test_check_host_health_up(self): - """Test checking health of an up host.""" - mock_session = Mock() - mock_session.execute = AsyncMock(return_value=Mock()) - - monitor = ConnectionMonitor(mock_session) - - # Mock host - host = Mock() - host.address = "127.0.0.1" - host.datacenter = "dc1" - host.rack = "rack1" - host.is_up = True - host.release_version = "4.0.1" - - metrics = await monitor.check_host_health(host) - - assert metrics.address == "127.0.0.1" - assert metrics.datacenter == "dc1" - assert metrics.rack == "rack1" - assert metrics.status == HOST_STATUS_UP - assert metrics.release_version == "4.0.1" - assert metrics.connection_count == 1 - assert metrics.latency_ms is not None - assert metrics.latency_ms > 0 - assert isinstance(metrics.last_check, datetime) - - @pytest.mark.asyncio - async def test_check_host_health_down(self): - """Test checking health of a down host.""" - mock_session = Mock() - monitor = ConnectionMonitor(mock_session) - - # Mock host - host = Mock() - host.address = "127.0.0.1" - host.datacenter = "dc1" - host.rack = "rack1" - host.is_up = False - host.release_version = "4.0.1" - - metrics = await monitor.check_host_health(host) - - assert metrics.address == "127.0.0.1" - assert metrics.status == HOST_STATUS_DOWN - assert metrics.connection_count == 0 - assert metrics.latency_ms is None - assert metrics.last_check is None - - @pytest.mark.asyncio - async def test_check_host_health_with_error(self): - """Test host health check with connection error.""" - mock_session = Mock() - mock_session.execute = AsyncMock(side_effect=Exception("Connection failed")) - - monitor = ConnectionMonitor(mock_session) - - # Mock host - host = Mock() - host.address = "127.0.0.1" - host.datacenter = "dc1" - host.rack = "rack1" - host.is_up = True - host.release_version = "4.0.1" - - metrics = await monitor.check_host_health(host) - - assert metrics.address == "127.0.0.1" - assert metrics.status == HOST_STATUS_UNKNOWN - assert metrics.connection_count == 0 - assert metrics.last_error == "Connection failed" - - @pytest.mark.asyncio - async def test_get_cluster_metrics(self): - """Test getting comprehensive cluster metrics.""" - mock_session = Mock() - mock_session.execute = AsyncMock(return_value=Mock()) - - # Mock cluster - mock_cluster = Mock() - mock_cluster.metadata.cluster_name = "test_cluster" - mock_cluster.protocol_version = 4 - - # Mock hosts - host1 = Mock() - host1.address = "127.0.0.1" - host1.datacenter = "dc1" - host1.rack = "rack1" - host1.is_up = True - host1.release_version = "4.0.1" - - host2 = Mock() - host2.address = "127.0.0.2" - host2.datacenter = "dc1" - host2.rack = "rack2" - host2.is_up = False - host2.release_version = "4.0.1" - - mock_cluster.metadata.all_hosts.return_value = [host1, host2] - mock_session._session.cluster = mock_cluster - - monitor = ConnectionMonitor(mock_session) - 
metrics = await monitor.get_cluster_metrics() - - assert isinstance(metrics, ClusterMetrics) - assert metrics.cluster_name == "test_cluster" - assert metrics.protocol_version == 4 - assert len(metrics.hosts) == 2 - assert metrics.healthy_hosts == 1 - assert metrics.unhealthy_hosts == 1 - assert metrics.total_connections == 1 - - @pytest.mark.asyncio - async def test_warmup_connections(self): - """Test warming up connections to hosts.""" - mock_session = Mock() - mock_session.execute = AsyncMock(return_value=Mock()) - - # Mock cluster - mock_cluster = Mock() - host1 = Mock(is_up=True, address="127.0.0.1") - host2 = Mock(is_up=True, address="127.0.0.2") - host3 = Mock(is_up=False, address="127.0.0.3") - - mock_cluster.metadata.all_hosts.return_value = [host1, host2, host3] - mock_session._session.cluster = mock_cluster - - monitor = ConnectionMonitor(mock_session) - await monitor.warmup_connections() - - # Should only warm up the two up hosts - assert mock_session.execute.call_count == 2 - - @pytest.mark.asyncio - async def test_warmup_connections_with_failures(self): - """Test connection warmup with some failures.""" - mock_session = Mock() - # First call succeeds, second fails - mock_session.execute = AsyncMock(side_effect=[Mock(), Exception("Failed")]) - - # Mock cluster - mock_cluster = Mock() - host1 = Mock(is_up=True, address="127.0.0.1") - host2 = Mock(is_up=True, address="127.0.0.2") - - mock_cluster.metadata.all_hosts.return_value = [host1, host2] - mock_session._session.cluster = mock_cluster - - monitor = ConnectionMonitor(mock_session) - # Should not raise - await monitor.warmup_connections() - - @pytest.mark.asyncio - async def test_start_stop_monitoring(self): - """Test starting and stopping monitoring.""" - mock_session = Mock() - mock_session.execute = AsyncMock(return_value=Mock()) - - # Mock cluster - mock_cluster = Mock() - mock_cluster.metadata.cluster_name = "test" - mock_cluster.protocol_version = 4 - mock_cluster.metadata.all_hosts.return_value = [] - mock_session._session.cluster = mock_cluster - - monitor = ConnectionMonitor(mock_session) - - # Start monitoring - await monitor.start_monitoring(interval=0.1) - assert monitor._monitoring_task is not None - assert not monitor._monitoring_task.done() - - # Let it run briefly - await asyncio.sleep(0.2) - - # Stop monitoring - await monitor.stop_monitoring() - assert monitor._monitoring_task.done() - - @pytest.mark.asyncio - async def test_monitoring_loop_with_callbacks(self): - """Test monitoring loop executes callbacks.""" - mock_session = Mock() - mock_session.execute = AsyncMock(return_value=Mock()) - - # Mock cluster - mock_cluster = Mock() - mock_cluster.metadata.cluster_name = "test" - mock_cluster.protocol_version = 4 - mock_cluster.metadata.all_hosts.return_value = [] - mock_session._session.cluster = mock_cluster - - monitor = ConnectionMonitor(mock_session) - - # Track callback executions - callback_metrics = [] - - def sync_callback(metrics): - callback_metrics.append(metrics) - - async def async_callback(metrics): - await asyncio.sleep(0.01) - callback_metrics.append(metrics) - - monitor.add_callback(sync_callback) - monitor.add_callback(async_callback) - - # Start monitoring - await monitor.start_monitoring(interval=0.1) - - # Wait for at least one check - await asyncio.sleep(0.2) - - # Stop monitoring - await monitor.stop_monitoring() - - # Both callbacks should have been called at least once - assert len(callback_metrics) >= 1 - - def test_get_connection_summary(self): - """Test getting connection 
summary.""" - mock_session = Mock() - - # Mock cluster - mock_cluster = Mock() - mock_cluster.protocol_version = 4 - - host1 = Mock(is_up=True) - host2 = Mock(is_up=True) - host3 = Mock(is_up=False) - - mock_cluster.metadata.all_hosts.return_value = [host1, host2, host3] - mock_session._session.cluster = mock_cluster - - monitor = ConnectionMonitor(mock_session) - summary = monitor.get_connection_summary() - - assert summary["total_hosts"] == 3 - assert summary["up_hosts"] == 2 - assert summary["down_hosts"] == 1 - assert summary["protocol_version"] == 4 - assert summary["max_requests_per_connection"] == 32768 - - -class TestRateLimitedSession: - """Test the rate-limited session wrapper.""" - - @pytest.mark.asyncio - async def test_basic_execute(self): - """Test basic execute with rate limiting.""" - mock_session = Mock() - mock_session.execute = AsyncMock(return_value=Mock(rows=[{"id": 1}])) - - # Create rate limited session (default 1000 concurrent) - limited = RateLimitedSession(mock_session, max_concurrent=10) - - result = await limited.execute("SELECT * FROM users") - - assert result.rows == [{"id": 1}] - mock_session.execute.assert_called_once_with("SELECT * FROM users", None) - - @pytest.mark.asyncio - async def test_execute_with_parameters(self): - """Test execute with parameters.""" - mock_session = Mock() - mock_session.execute = AsyncMock(return_value=Mock(rows=[])) - - limited = RateLimitedSession(mock_session) - - await limited.execute("SELECT * FROM users WHERE id = ?", parameters=[123], timeout=5.0) - - mock_session.execute.assert_called_once_with( - "SELECT * FROM users WHERE id = ?", [123], timeout=5.0 - ) - - @pytest.mark.asyncio - async def test_prepare_not_rate_limited(self): - """Test that prepare statements are not rate limited.""" - mock_session = Mock() - mock_session.prepare = AsyncMock(return_value=Mock()) - - limited = RateLimitedSession(mock_session, max_concurrent=1) - - # Should not be delayed - stmt = await limited.prepare("SELECT * FROM users WHERE id = ?") - - assert stmt is not None - mock_session.prepare.assert_called_once() - - @pytest.mark.asyncio - async def test_concurrent_rate_limiting(self): - """Test rate limiting with concurrent requests.""" - mock_session = Mock() - - # Track concurrent executions - concurrent_count = 0 - max_concurrent_seen = 0 - - async def track_execute(*args, **kwargs): - nonlocal concurrent_count, max_concurrent_seen - concurrent_count += 1 - max_concurrent_seen = max(max_concurrent_seen, concurrent_count) - await asyncio.sleep(0.05) # Simulate query time - concurrent_count -= 1 - return Mock(rows=[]) - - mock_session.execute = track_execute - - # Very limited concurrency: 2 - limited = RateLimitedSession(mock_session, max_concurrent=2) - - # Try to execute 4 queries concurrently - tasks = [limited.execute(f"SELECT {i}") for i in range(4)] - - await asyncio.gather(*tasks) - - # Should never exceed max_concurrent - assert max_concurrent_seen <= 2 - - def test_get_metrics(self): - """Test getting rate limiter metrics.""" - mock_session = Mock() - limited = RateLimitedSession(mock_session) - - metrics = limited.get_metrics() - - assert metrics["total_requests"] == 0 - assert metrics["active_requests"] == 0 - assert metrics["rejected_requests"] == 0 - - @pytest.mark.asyncio - async def test_metrics_tracking(self): - """Test that metrics are tracked correctly.""" - mock_session = Mock() - mock_session.execute = AsyncMock(return_value=Mock()) - - limited = RateLimitedSession(mock_session) - - # Execute some queries - await 
limited.execute("SELECT 1") - await limited.execute("SELECT 2") - - metrics = limited.get_metrics() - assert metrics["total_requests"] == 2 - assert metrics["active_requests"] == 0 # Both completed - - -class TestIntegration: - """Test integration of monitoring components.""" - - def test_create_metrics_system_memory(self): - """Test creating metrics system with memory backend.""" - middleware = create_metrics_system(backend="memory") - - assert isinstance(middleware, MetricsMiddleware) - assert len(middleware.collectors) == 1 - assert isinstance(middleware.collectors[0], InMemoryMetricsCollector) - - def test_create_metrics_system_prometheus(self): - """Test creating metrics system with prometheus.""" - middleware = create_metrics_system(backend="memory", prometheus_enabled=True) - - assert isinstance(middleware, MetricsMiddleware) - assert len(middleware.collectors) == 2 - assert isinstance(middleware.collectors[0], InMemoryMetricsCollector) - assert isinstance(middleware.collectors[1], PrometheusMetricsCollector) - - @pytest.mark.asyncio - async def test_create_monitored_session(self): - """Test creating a fully monitored session.""" - # Mock cluster and session creation - mock_cluster = Mock() - mock_session = Mock() - mock_session._session = Mock() - mock_session._session.cluster = Mock() - mock_session._session.cluster.metadata = Mock() - mock_session._session.cluster.metadata.all_hosts.return_value = [] - mock_session.execute = AsyncMock(return_value=Mock()) - - mock_cluster.connect = AsyncMock(return_value=mock_session) - - with patch("async_cassandra.cluster.AsyncCluster", return_value=mock_cluster): - session, monitor = await create_monitored_session( - contact_points=["127.0.0.1"], keyspace="test", max_concurrent=100, warmup=False - ) - - # Should return rate limited session and monitor - assert isinstance(session, RateLimitedSession) - assert isinstance(monitor, ConnectionMonitor) - assert session.session == mock_session - - @pytest.mark.asyncio - async def test_create_monitored_session_no_rate_limit(self): - """Test creating monitored session without rate limiting.""" - # Mock cluster and session creation - mock_cluster = Mock() - mock_session = Mock() - mock_session._session = Mock() - mock_session._session.cluster = Mock() - mock_session._session.cluster.metadata = Mock() - mock_session._session.cluster.metadata.all_hosts.return_value = [] - - mock_cluster.connect = AsyncMock(return_value=mock_session) - - with patch("async_cassandra.cluster.AsyncCluster", return_value=mock_cluster): - session, monitor = await create_monitored_session( - contact_points=["127.0.0.1"], max_concurrent=None, warmup=False - ) - - # Should return original session (not rate limited) - assert session == mock_session - assert isinstance(monitor, ConnectionMonitor) diff --git a/tests/unit/test_network_failures.py b/tests/unit/test_network_failures.py deleted file mode 100644 index b2a7759..0000000 --- a/tests/unit/test_network_failures.py +++ /dev/null @@ -1,634 +0,0 @@ -""" -Unit tests for network failure scenarios. - -Tests how the async wrapper handles: -- Partial network failures -- Connection timeouts -- Slow network conditions -- Coordinator failures mid-query - -Test Organization: -================== -1. Partial Failures - Connected but queries fail -2. Timeout Handling - Different timeout types -3. Network Instability - Flapping, congestion -4. Connection Pool - Recovery after issues -5. 
Network Topology - Partitions, distance changes - -Key Testing Principles: -====================== -- Differentiate timeout types -- Test recovery mechanisms -- Simulate real network issues -- Verify error propagation -""" - -import asyncio -import time -from unittest.mock import Mock, patch - -import pytest -from cassandra import OperationTimedOut, ReadTimeout, WriteTimeout -from cassandra.cluster import ConnectionException, Host, NoHostAvailable - -from async_cassandra import AsyncCassandraSession, AsyncCluster - - -class TestNetworkFailures: - """Test various network failure scenarios.""" - - def create_error_future(self, exception): - """ - Create a mock future that raises the given exception. - - Helper to simulate driver futures that fail with - network-related exceptions. - """ - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - if errback: - errbacks.append(errback) - # Call errback immediately with the error - errback(exception) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - def create_success_future(self, result): - """ - Create a mock future that returns a result. - - Helper to simulate successful driver futures after - network recovery. - """ - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - # For success, the callback expects an iterable of rows - mock_rows = [result] if result else [] - callback(mock_rows) - if errback: - errbacks.append(errback) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - @pytest.fixture - def mock_session(self): - """Create a mock session.""" - session = Mock() - session.execute_async = Mock() - session.prepare_async = Mock() - session.cluster = Mock() - return session - - @pytest.mark.asyncio - async def test_partial_network_failure(self, mock_session): - """ - Test handling of partial network failures (can connect but can't query). - - What this tests: - --------------- - 1. Connection established but queries fail - 2. ConnectionException during execution - 3. Exception passed through directly - 4. Native error handling preserved - - Why this matters: - ---------------- - Partial failures are common in production: - - Firewall rules changed mid-session - - Network degradation after connect - - Load balancer issues - - Applications need direct access to - handle these "connected but broken" states. - """ - async_session = AsyncCassandraSession(mock_session) - - # Queries fail with connection error - mock_session.execute_async.return_value = self.create_error_future( - ConnectionException("Connection closed by remote host") - ) - - # ConnectionException is now passed through directly - with pytest.raises(ConnectionException) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "Connection closed by remote host" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_connection_timeout_during_query(self, mock_session): - """ - Test handling of connection timeouts during query execution. - - What this tests: - --------------- - 1. OperationTimedOut errors handled - 2. Transient timeouts can recover - 3. Multiple attempts tracked - 4. 
Eventually succeeds - - Why this matters: - ---------------- - Timeouts can be transient: - - Network congestion - - Temporary overload - - GC pauses - - Applications often retry timeouts - as they may succeed on retry. - """ - async_session = AsyncCassandraSession(mock_session) - - # Simulate timeout patterns - timeout_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal timeout_count - timeout_count += 1 - - if timeout_count <= 2: - # First attempts timeout - return self.create_error_future(OperationTimedOut("Connection timed out")) - else: - # Eventually succeeds - return self.create_success_future({"id": 1}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # First two attempts should timeout - for i in range(2): - with pytest.raises(OperationTimedOut): - await async_session.execute("SELECT * FROM test") - - # Third attempt succeeds - result = await async_session.execute("SELECT * FROM test") - assert result.rows[0]["id"] == 1 - assert timeout_count == 3 - - @pytest.mark.asyncio - async def test_slow_network_simulation(self, mock_session): - """ - Test handling of slow network conditions. - - What this tests: - --------------- - 1. Slow queries still complete - 2. No premature timeouts - 3. Results returned correctly - 4. Latency tracked - - Why this matters: - ---------------- - Not all slowness is a timeout: - - Cross-region queries - - Large result sets - - Complex aggregations - - The wrapper must handle slow - operations without failing. - """ - async_session = AsyncCassandraSession(mock_session) - - # Create a future that simulates delay - start_time = time.time() - mock_session.execute_async.return_value = self.create_success_future( - {"latency": 0.5, "timestamp": start_time} - ) - - # Execute query - result = await async_session.execute("SELECT * FROM test") - - # Should return result - assert result.rows[0]["latency"] == 0.5 - - @pytest.mark.asyncio - async def test_coordinator_failure_mid_query(self, mock_session): - """ - Test coordinator node failing during query execution. - - What this tests: - --------------- - 1. Coordinator can fail mid-query - 2. NoHostAvailable with details - 3. Retry finds new coordinator - 4. Query eventually succeeds - - Why this matters: - ---------------- - Coordinator failures happen: - - Node crashes - - Network partition - - Rolling restarts - - The driver picks new coordinators - automatically on retry. - """ - async_session = AsyncCassandraSession(mock_session) - - # Track coordinator changes - attempt_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal attempt_count - attempt_count += 1 - - if attempt_count == 1: - # First coordinator fails mid-query - return self.create_error_future( - NoHostAvailable( - "Unable to connect to any servers", - {"node0": ConnectionException("Connection lost to coordinator")}, - ) - ) - else: - # New coordinator succeeds - return self.create_success_future({"coordinator": f"node{attempt_count-1}"}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # First attempt should fail - with pytest.raises(NoHostAvailable): - await async_session.execute("SELECT * FROM test") - - # Second attempt should succeed - result = await async_session.execute("SELECT * FROM test") - assert result.rows[0]["coordinator"] == "node1" - assert attempt_count == 2 - - @pytest.mark.asyncio - async def test_network_flapping(self, mock_session): - """ - Test handling of network that rapidly connects/disconnects. 
- - What this tests: - --------------- - 1. Alternating success/failure pattern - 2. Each state change handled - 3. No corruption from rapid changes - 4. Accurate success/failure tracking - - Why this matters: - ---------------- - Network flapping occurs with: - - Faulty hardware - - Overloaded switches - - Misconfigured networking - - The wrapper must remain stable - despite unstable network. - """ - async_session = AsyncCassandraSession(mock_session) - - # Simulate flapping network - flap_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal flap_count - flap_count += 1 - - # Flip network state every call (odd = down, even = up) - if flap_count % 2 == 1: - return self.create_error_future( - ConnectionException(f"Network down (flap {flap_count})") - ) - else: - return self.create_success_future({"flap_count": flap_count}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Try multiple queries during flapping - results = [] - errors = [] - - for i in range(6): - try: - result = await async_session.execute(f"SELECT {i}") - results.append(result.rows[0]["flap_count"]) - except ConnectionException as e: - errors.append(str(e)) - - # Should have mix of successes and failures - assert len(results) == 3 # Even numbered attempts succeed - assert len(errors) == 3 # Odd numbered attempts fail - assert flap_count == 6 - - @pytest.mark.asyncio - async def test_request_timeout_vs_connection_timeout(self, mock_session): - """ - Test differentiating between request and connection timeouts. - - What this tests: - --------------- - 1. ReadTimeout vs WriteTimeout vs OperationTimedOut - 2. Each timeout type preserved - 3. Timeout details maintained - 4. Proper exception types raised - - Why this matters: - ---------------- - Different timeouts mean different things: - - ReadTimeout: query executed, waiting for data - - WriteTimeout: write may have partially succeeded - - OperationTimedOut: connection-level timeout - - Applications handle each differently: - - Read timeouts often safe to retry - - Write timeouts need idempotency checks - - Connection timeouts may need backoff - """ - async_session = AsyncCassandraSession(mock_session) - - # Test different timeout scenarios - from cassandra import WriteType - - timeout_scenarios = [ - ( - ReadTimeout( - "Read timeout", - consistency_level=1, - required_responses=1, - received_responses=0, - data_retrieved=False, - ), - "read", - ), - (WriteTimeout("Write timeout", write_type=WriteType.SIMPLE), "write"), - (OperationTimedOut("Connection timeout"), "connection"), - ] - - for timeout_error, timeout_type in timeout_scenarios: - # Set additional attributes for WriteTimeout - if timeout_type == "write": - timeout_error.consistency_level = 1 - timeout_error.required_responses = 1 - timeout_error.received_responses = 0 - - mock_session.execute_async.return_value = self.create_error_future(timeout_error) - - try: - await async_session.execute(f"SELECT * FROM test_{timeout_type}") - except Exception as e: - # Verify correct timeout type - if timeout_type == "read": - assert isinstance(e, ReadTimeout) - elif timeout_type == "write": - assert isinstance(e, WriteTimeout) - else: - assert isinstance(e, OperationTimedOut) - - @pytest.mark.asyncio - async def test_connection_pool_recovery_after_network_issue(self, mock_session): - """ - Test connection pool recovery after network issues. - - What this tests: - --------------- - 1. Pool can be exhausted by failures - 2. Recovery happens automatically - 3. 
Queries fail during recovery - 4. Eventually queries succeed - - Why this matters: - ---------------- - Connection pools need time to recover: - - Reconnection attempts - - Health checks - - Pool replenishment - - Applications should retry after - pool exhaustion as recovery - is often automatic. - """ - async_session = AsyncCassandraSession(mock_session) - - # Track pool state - recovery_attempts = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal recovery_attempts - recovery_attempts += 1 - - if recovery_attempts <= 2: - # Pool not recovered - return self.create_error_future( - NoHostAvailable( - "Unable to connect to any servers", - {"all_hosts": ConnectionException("Pool not recovered")}, - ) - ) - else: - # Pool recovered - return self.create_success_future({"healthy": True}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # First two queries fail during network issue - for i in range(2): - with pytest.raises(NoHostAvailable): - await async_session.execute(f"SELECT {i}") - - # Third query succeeds after recovery - result = await async_session.execute("SELECT 3") - assert result.rows[0]["healthy"] is True - assert recovery_attempts == 3 - - @pytest.mark.asyncio - async def test_network_congestion_backoff(self, mock_session): - """ - Test exponential backoff during network congestion. - - What this tests: - --------------- - 1. Congestion causes timeouts - 2. Exponential backoff implemented - 3. Delays increase appropriately - 4. Eventually succeeds - - Why this matters: - ---------------- - Network congestion requires backoff: - - Prevents thundering herd - - Gives network time to recover - - Reduces overall load - - Exponential backoff is a best - practice for congestion handling. - """ - async_session = AsyncCassandraSession(mock_session) - - # Track retry attempts - attempt_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal attempt_count - attempt_count += 1 - - if attempt_count < 4: - # Network congested - return self.create_error_future(OperationTimedOut("Network congested")) - else: - # Congestion clears - return self.create_success_future({"attempts": attempt_count}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Execute with manual exponential backoff - backoff_delays = [0.01, 0.02, 0.04] # Small delays for testing - - async def execute_with_backoff(query): - for i, delay in enumerate(backoff_delays): - try: - return await async_session.execute(query) - except OperationTimedOut: - if i < len(backoff_delays) - 1: - await asyncio.sleep(delay) - else: - # Try one more time after last delay - await asyncio.sleep(delay) - return await async_session.execute(query) # Final attempt - - result = await execute_with_backoff("SELECT * FROM test") - - # Verify backoff worked - assert attempt_count == 4 # 3 failures + 1 success - assert result.rows[0]["attempts"] == 4 - - @pytest.mark.asyncio - async def test_asymmetric_network_partition(self): - """ - Test asymmetric partition where node can send but not receive. - - What this tests: - --------------- - 1. Asymmetric network failures - 2. Some hosts unreachable - 3. Cluster finds working hosts - 4. Connection eventually succeeds - - Why this matters: - ---------------- - Real network partitions are often asymmetric: - - One-way firewall rules - - Routing issues - - Split-brain scenarios - - The cluster must work around - partially failed hosts. 
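The congestion and pool-recovery tests above exercise a retry-with-exponential-backoff pattern around execute(). A minimal sketch of that pattern follows; the helper name and delay values are illustrative rather than part of the library, and only connection-level errors (OperationTimedOut, NoHostAvailable) are retried, since write timeouts need an idempotency check first.

    import asyncio

    from cassandra import OperationTimedOut
    from cassandra.cluster import NoHostAvailable


    async def execute_with_backoff(session, query, parameters=None,
                                   delays=(0.1, 0.2, 0.4)):
        # Retry transient, connection-level failures with exponential backoff,
        # giving the network or connection pool time to recover between attempts.
        for delay in delays:
            try:
                return await session.execute(query, parameters)
            except (OperationTimedOut, NoHostAvailable):
                await asyncio.sleep(delay)
        # Final attempt after the last delay; failures now propagate to the caller.
        return await session.execute(query, parameters)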
- """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Create mock cluster - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - mock_cluster.protocol_version = 5 # Add protocol version - - # Create multiple hosts - hosts = [] - for i in range(3): - host = Mock(spec=Host) - host.address = f"10.0.0.{i+1}" - host.is_up = True - hosts.append(host) - - mock_cluster.metadata = Mock() - mock_cluster.metadata.all_hosts = Mock(return_value=hosts) - - # Simulate connection failure to partitioned host - connection_count = 0 - - def connect_side_effect(keyspace=None): - nonlocal connection_count - connection_count += 1 - - if connection_count == 1: - # First attempt includes partitioned host - raise NoHostAvailable( - "Unable to connect to any servers", - {hosts[1].address: OperationTimedOut("Cannot reach host")}, - ) - else: - # Second attempt succeeds without partitioned host - return Mock() - - mock_cluster.connect.side_effect = connect_side_effect - - async_cluster = AsyncCluster(contact_points=["10.0.0.1"]) - - # Should eventually connect using available hosts - session = await async_cluster.connect() - assert session is not None - assert connection_count == 2 - - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_host_distance_changes(self): - """ - Test handling of host distance changes (LOCAL to REMOTE). - - What this tests: - --------------- - 1. Host distance can change - 2. LOCAL to REMOTE transitions - 3. Distance changes tracked - 4. Affects query routing - - Why this matters: - ---------------- - Host distances change due to: - - Datacenter reconfigurations - - Network topology changes - - Dynamic snitch updates - - Distance affects: - - Query routing preferences - - Connection pool sizes - - Retry strategies - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Create mock cluster - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - mock_cluster.protocol_version = 5 # Add protocol version - mock_cluster.connect.return_value = Mock() - - # Create hosts with distances - local_host = Mock(spec=Host, address="10.0.0.1") - remote_host = Mock(spec=Host, address="10.1.0.1") - - mock_cluster.metadata = Mock() - mock_cluster.metadata.all_hosts = Mock(return_value=[local_host, remote_host]) - - async_cluster = AsyncCluster() - - # Track distance changes - distance_changes = [] - - def on_distance_change(host, old_distance, new_distance): - distance_changes.append({"host": host, "old": old_distance, "new": new_distance}) - - # Simulate distance change - on_distance_change(local_host, "LOCAL", "REMOTE") - - # Verify tracking - assert len(distance_changes) == 1 - assert distance_changes[0]["old"] == "LOCAL" - assert distance_changes[0]["new"] == "REMOTE" - - await async_cluster.shutdown() diff --git a/tests/unit/test_no_host_available.py b/tests/unit/test_no_host_available.py deleted file mode 100644 index 40b13ce..0000000 --- a/tests/unit/test_no_host_available.py +++ /dev/null @@ -1,304 +0,0 @@ -""" -Unit tests for NoHostAvailable exception handling. - -This module tests the specific handling of NoHostAvailable errors, -which indicate that no Cassandra nodes are available to handle requests. - -Test Organization: -================== -1. Direct Exception Propagation - NoHostAvailable raised without wrapping -2. Error Details Preservation - Host-specific errors maintained -3. Metrics Recording - Failure metrics tracked correctly -4. 
Exception Type Consistency - All Cassandra exceptions handled uniformly - -Key Testing Principles: -====================== -- NoHostAvailable must not be wrapped in QueryError -- Host error details must be preserved -- Metrics must capture connection failures -- Cassandra exceptions get special treatment -""" - -import asyncio -from unittest.mock import Mock - -import pytest -from cassandra.cluster import NoHostAvailable - -from async_cassandra.exceptions import QueryError -from async_cassandra.session import AsyncCassandraSession - - -@pytest.mark.asyncio -class TestNoHostAvailableHandling: - """Test NoHostAvailable exception handling.""" - - async def test_execute_raises_no_host_available_directly(self): - """ - Test that NoHostAvailable is raised directly without wrapping. - - What this tests: - --------------- - 1. NoHostAvailable propagates unchanged - 2. Not wrapped in QueryError - 3. Original message preserved - 4. Exception type maintained - - Why this matters: - ---------------- - NoHostAvailable requires special handling: - - Indicates infrastructure problems - - May need different retry strategy - - Often requires manual intervention - - Wrapping it would hide its specific nature and - break error handling code that catches NoHostAvailable. - """ - # Mock cassandra session that raises NoHostAvailable - mock_session = Mock() - mock_session.execute_async = Mock(side_effect=NoHostAvailable("All hosts are down", {})) - - # Create async session - async_session = AsyncCassandraSession(mock_session) - - # Should raise NoHostAvailable directly, not wrapped in QueryError - with pytest.raises(NoHostAvailable) as exc_info: - await async_session.execute("SELECT * FROM test") - - # Verify it's the original exception - assert "All hosts are down" in str(exc_info.value) - - async def test_execute_stream_raises_no_host_available_directly(self): - """ - Test that execute_stream raises NoHostAvailable directly. - - What this tests: - --------------- - 1. Streaming also preserves NoHostAvailable - 2. Consistent with execute() behavior - 3. No wrapping in streaming path - 4. Same exception handling for both methods - - Why this matters: - ---------------- - Applications need consistent error handling: - - Same exceptions from execute() and execute_stream() - - Can reuse error handling logic - - No surprises when switching methods - - This ensures streaming doesn't introduce - different error handling requirements. - """ - # Mock cassandra session that raises NoHostAvailable - mock_session = Mock() - mock_session.execute_async = Mock(side_effect=NoHostAvailable("Connection failed", {})) - - # Create async session - async_session = AsyncCassandraSession(mock_session) - - # Should raise NoHostAvailable directly - with pytest.raises(NoHostAvailable) as exc_info: - await async_session.execute_stream("SELECT * FROM test") - - # Verify it's the original exception - assert "Connection failed" in str(exc_info.value) - - async def test_no_host_available_preserves_host_errors(self): - """ - Test that NoHostAvailable preserves detailed host error information. - - What this tests: - --------------- - 1. Host-specific errors in 'errors' dict - 2. Each host's failure reason preserved - 3. Error details not lost in propagation - 4. 
Can diagnose per-host problems - - Why this matters: - ---------------- - NoHostAvailable.errors contains valuable debugging info: - - Which hosts failed and why - - Connection refused vs timeout vs other - - Helps identify patterns (all timeout = network issue) - - Operations teams need these details to: - - Identify which nodes are problematic - - Diagnose network vs node issues - - Take targeted corrective action - """ - # Create NoHostAvailable with host errors - host_errors = { - "host1": Exception("Connection refused"), - "host2": Exception("Host unreachable"), - } - no_host_error = NoHostAvailable("No hosts available", host_errors) - - # Mock cassandra session - mock_session = Mock() - mock_session.execute_async = Mock(side_effect=no_host_error) - - # Create async session - async_session = AsyncCassandraSession(mock_session) - - # Execute and catch exception - with pytest.raises(NoHostAvailable) as exc_info: - await async_session.execute("SELECT * FROM test") - - # Verify host errors are preserved - caught_exception = exc_info.value - assert hasattr(caught_exception, "errors") - assert "host1" in caught_exception.errors - assert "host2" in caught_exception.errors - - async def test_metrics_recorded_for_no_host_available(self): - """ - Test that metrics are recorded when NoHostAvailable occurs. - - What this tests: - --------------- - 1. Metrics capture NoHostAvailable errors - 2. Error type recorded as 'NoHostAvailable' - 3. Success=False in metrics - 4. Fire-and-forget metrics don't block - - Why this matters: - ---------------- - Monitoring connection failures is critical: - - Track cluster health over time - - Alert on connection problems - - Identify patterns and trends - - NoHostAvailable metrics help detect: - - Cluster-wide outages - - Network partitions - - Configuration problems - """ - # Mock cassandra session - mock_session = Mock() - mock_session.execute_async = Mock(side_effect=NoHostAvailable("All hosts down", {})) - - # Mock metrics - from async_cassandra.metrics import MetricsMiddleware - - mock_metrics = Mock(spec=MetricsMiddleware) - mock_metrics.record_query_metrics = Mock() - - # Create async session with metrics - async_session = AsyncCassandraSession(mock_session, metrics=mock_metrics) - - # Execute and expect NoHostAvailable - with pytest.raises(NoHostAvailable): - await async_session.execute("SELECT * FROM test") - - # Give time for fire-and-forget metrics - await asyncio.sleep(0.1) - - # Verify metrics were called with correct error type - mock_metrics.record_query_metrics.assert_called_once() - call_args = mock_metrics.record_query_metrics.call_args[1] - assert call_args["success"] is False - assert call_args["error_type"] == "NoHostAvailable" - - async def test_other_exceptions_still_wrapped(self): - """ - Test that non-Cassandra exceptions are still wrapped in QueryError. - - What this tests: - --------------- - 1. Non-Cassandra exceptions wrapped in QueryError - 2. Only Cassandra exceptions get special treatment - 3. Generic errors still provide context - 4. Original exception in __cause__ - - Why this matters: - ---------------- - Different exception types need different handling: - - Cassandra exceptions: domain-specific, preserve as-is - - Other exceptions: wrap for context and consistency - - This ensures unexpected errors still get - meaningful context while preserving Cassandra's - carefully designed exception hierarchy. 
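Because NoHostAvailable is propagated unwrapped, callers can inspect its per-host errors dict to tell connection refusals from timeouts before deciding how to react. A minimal sketch, with the logging purely illustrative:

    import logging

    from cassandra.cluster import NoHostAvailable

    logger = logging.getLogger(__name__)


    async def run_query(session, query):
        # NoHostAvailable arrives unwrapped, so the per-host failure details
        # in `.errors` are still available for diagnostics before re-raising.
        try:
            return await session.execute(query)
        except NoHostAvailable as exc:
            for host, error in exc.errors.items():
                logger.warning("host %s failed: %r", host, error)
            raise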
- """ - # Mock cassandra session that raises generic exception - mock_session = Mock() - mock_session.execute_async = Mock(side_effect=RuntimeError("Unexpected error")) - - # Create async session - async_session = AsyncCassandraSession(mock_session) - - # Should wrap in QueryError - with pytest.raises(QueryError) as exc_info: - await async_session.execute("SELECT * FROM test") - - # Verify it's wrapped - assert "Query execution failed" in str(exc_info.value) - assert isinstance(exc_info.value.__cause__, RuntimeError) - - async def test_all_cassandra_exceptions_not_wrapped(self): - """ - Test that all Cassandra exceptions are raised directly. - - What this tests: - --------------- - 1. All Cassandra exception types preserved - 2. InvalidRequest, timeouts, Unavailable, etc. - 3. Exact exception instances propagated - 4. Consistent handling across all types - - Why this matters: - ---------------- - Cassandra's exception hierarchy is well-designed: - - Each type indicates specific problems - - Contains relevant diagnostic information - - Enables proper retry strategies - - Wrapping would: - - Break existing error handlers - - Hide important error details - - Prevent proper retry logic - - This comprehensive test ensures all Cassandra - exceptions are treated consistently. - """ - # Test each Cassandra exception type - from cassandra import ( - InvalidRequest, - OperationTimedOut, - ReadTimeout, - Unavailable, - WriteTimeout, - WriteType, - ) - - cassandra_exceptions = [ - InvalidRequest("Invalid query"), - ReadTimeout("Read timeout", consistency=1, required_responses=3, received_responses=1), - WriteTimeout( - "Write timeout", - consistency=1, - required_responses=3, - received_responses=1, - write_type=WriteType.SIMPLE, - ), - Unavailable( - "Not enough replicas", consistency=1, required_replicas=3, alive_replicas=1 - ), - OperationTimedOut("Operation timed out"), - NoHostAvailable("No hosts", {}), - ] - - for exception in cassandra_exceptions: - # Mock session - mock_session = Mock() - mock_session.execute_async = Mock(side_effect=exception) - - # Create async session - async_session = AsyncCassandraSession(mock_session) - - # Should raise original exception type - with pytest.raises(type(exception)) as exc_info: - await async_session.execute("SELECT * FROM test") - - # Verify it's the exact same exception - assert exc_info.value is exception diff --git a/tests/unit/test_page_callback_deadlock.py b/tests/unit/test_page_callback_deadlock.py deleted file mode 100644 index 70dc94d..0000000 --- a/tests/unit/test_page_callback_deadlock.py +++ /dev/null @@ -1,314 +0,0 @@ -""" -Unit tests for page callback execution outside lock. - -This module tests a critical deadlock prevention mechanism in streaming -results. Page callbacks must be executed outside the internal lock to -prevent deadlocks when callbacks try to interact with the result set. 
- -Test Organization: -================== -- Lock behavior during callbacks -- Error isolation in callbacks -- Performance with slow callbacks -- Callback data accuracy - -Key Testing Principles: -====================== -- Callbacks must not hold internal locks -- Callback errors must not affect streaming -- Slow callbacks must not block iteration -- Callbacks are optional (no overhead when unused) -""" - -import threading -import time -from unittest.mock import Mock - -import pytest - -from async_cassandra.streaming import AsyncStreamingResultSet, StreamConfig - - -@pytest.mark.asyncio -class TestPageCallbackDeadlock: - """Test that page callbacks are executed outside the lock to prevent deadlocks.""" - - async def test_page_callback_executed_outside_lock(self): - """ - Test that page callback is called outside the lock. - - What this tests: - --------------- - 1. Page callback runs without holding _lock - 2. Lock is released before callback execution - 3. Callback can acquire lock if needed - 4. No deadlock risk from callbacks - - Why this matters: - ---------------- - Previous implementations held the lock during callbacks, - which caused deadlocks when: - - Callbacks tried to iterate the result set - - Callbacks called methods that needed the lock - - Multiple threads were involved - - This test ensures callbacks run in a "clean" context - without holding internal locks, preventing deadlocks. - """ - # Track if callback was called while lock was held - lock_held_during_callback = None - callback_called = threading.Event() - - # Create a custom callback that checks lock status - def page_callback(page_num, row_count): - nonlocal lock_held_during_callback - # Try to acquire the lock - if we can't, it's held by _handle_page - lock_held_during_callback = not result_set._lock.acquire(blocking=False) - if not lock_held_during_callback: - result_set._lock.release() - callback_called.set() - - # Create streaming result set with callback - response_future = Mock() - response_future.has_more_pages = False - response_future._final_exception = None - response_future.add_callbacks = Mock() - - config = StreamConfig(page_callback=page_callback) - result_set = AsyncStreamingResultSet(response_future, config) - - # Trigger page callback - args = response_future.add_callbacks.call_args - page_handler = args[1]["callback"] - page_handler(["row1", "row2", "row3"]) - - # Wait for callback - assert callback_called.wait(timeout=2.0) - - # Callback should have been called outside the lock - assert lock_held_during_callback is False - - async def test_callback_error_does_not_affect_streaming(self): - """ - Test that callback errors don't affect streaming functionality. - - What this tests: - --------------- - 1. Callback exceptions are caught and isolated - 2. Streaming continues normally after callback error - 3. All rows are still accessible - 4. No corruption of internal state - - Why this matters: - ---------------- - User callbacks might have bugs or throw exceptions. - These errors should not: - - Crash the streaming process - - Lose data or skip rows - - Corrupt the result set state - - This ensures robustness against user code errors. 
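A rough sketch of how a page callback might be wired up. The StreamConfig class and the (page_number, row_count) callback signature come from the tests above; the stream_config= keyword on execute_stream is an assumption made for illustration only.

    from async_cassandra.streaming import StreamConfig


    def log_progress(page_number, row_count):
        # Runs outside the result set's internal lock, so it may do slow work
        # (or even touch the result set) without risking a deadlock; exceptions
        # raised here are isolated from the streaming itself.
        print(f"page {page_number}: {row_count} rows")


    async def stream_events(session):
        config = StreamConfig(page_callback=log_progress)
        # NOTE: the `stream_config=` keyword is assumed for this sketch.
        result = await session.execute_stream("SELECT * FROM events", stream_config=config)
        async for row in result:
            ...  # process rows as pages arrive; the callback fires once per page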
- """ - - # Create a callback that raises an error - def bad_callback(page_num, row_count): - raise ValueError("Callback error") - - # Create streaming result set - response_future = Mock() - response_future.has_more_pages = False - response_future._final_exception = None - response_future.add_callbacks = Mock() - - config = StreamConfig(page_callback=bad_callback) - result_set = AsyncStreamingResultSet(response_future, config) - - # Trigger page with bad callback from a thread - args = response_future.add_callbacks.call_args - page_handler = args[1]["callback"] - - def thread_callback(): - page_handler(["row1", "row2"]) - - thread = threading.Thread(target=thread_callback) - thread.start() - - # Should still be able to iterate results despite callback error - rows = [] - async for row in result_set: - rows.append(row) - - assert len(rows) == 2 - assert rows == ["row1", "row2"] - - async def test_slow_callback_does_not_block_iteration(self): - """ - Test that slow callbacks don't block result iteration. - - What this tests: - --------------- - 1. Slow callbacks run asynchronously - 2. Row iteration proceeds without waiting - 3. Callback duration doesn't affect iteration speed - 4. No performance impact from slow callbacks - - Why this matters: - ---------------- - Page callbacks might do expensive operations: - - Write to databases - - Send network requests - - Perform complex calculations - - These slow operations should not block the main - iteration thread. Users can process rows immediately - while callbacks run in the background. - """ - callback_times = [] - iteration_start_time = None - - # Create a slow callback - def slow_callback(page_num, row_count): - callback_times.append(time.time()) - time.sleep(0.5) # Simulate slow callback - - # Create streaming result set - response_future = Mock() - response_future.has_more_pages = False - response_future._final_exception = None - response_future.add_callbacks = Mock() - - config = StreamConfig(page_callback=slow_callback) - result_set = AsyncStreamingResultSet(response_future, config) - - # Trigger page from a thread - args = response_future.add_callbacks.call_args - page_handler = args[1]["callback"] - - def thread_callback(): - page_handler(["row1", "row2"]) - - thread = threading.Thread(target=thread_callback) - thread.start() - - # Start iteration immediately - iteration_start_time = time.time() - rows = [] - async for row in result_set: - rows.append(row) - iteration_end_time = time.time() - - # Iteration should complete quickly, not waiting for callback - iteration_duration = iteration_end_time - iteration_start_time - assert iteration_duration < 0.2 # Much less than callback duration - - # Results should be available - assert len(rows) == 2 - - # Wait for thread to complete to avoid event loop closed warning - thread.join(timeout=1.0) - - async def test_callback_receives_correct_page_info(self): - """ - Test that callbacks receive correct page information. - - What this tests: - --------------- - 1. Page numbers increment correctly (1, 2, 3...) - 2. Row counts match actual page sizes - 3. Multiple pages tracked accurately - 4. Last page handled correctly - - Why this matters: - ---------------- - Callbacks often need to: - - Track progress through large result sets - - Update progress bars or metrics - - Log page processing statistics - - Detect when processing is complete - - Accurate page information enables these use cases. 
- """ - page_infos = [] - - def track_pages(page_num, row_count): - page_infos.append((page_num, row_count)) - - # Create streaming result set - response_future = Mock() - response_future.has_more_pages = True - response_future._final_exception = None - response_future.add_callbacks = Mock() - response_future.start_fetching_next_page = Mock() - - config = StreamConfig(page_callback=track_pages) - AsyncStreamingResultSet(response_future, config) - - # Get page handler - args = response_future.add_callbacks.call_args - page_handler = args[1]["callback"] - - # Simulate multiple pages - page_handler(["row1", "row2"]) - page_handler(["row3", "row4", "row5"]) - response_future.has_more_pages = False - page_handler(["row6"]) - - # Check callback data - assert len(page_infos) == 3 - assert page_infos[0] == (1, 2) # First page: 2 rows - assert page_infos[1] == (2, 3) # Second page: 3 rows - assert page_infos[2] == (3, 1) # Third page: 1 row - - async def test_no_callback_no_overhead(self): - """ - Test that having no callback doesn't add overhead. - - What this tests: - --------------- - 1. No performance penalty without callbacks - 2. Page handling is fast when no callback - 3. 1000 rows processed in <10ms - 4. Optional feature has zero cost when unused - - Why this matters: - ---------------- - Most streaming operations don't use callbacks. - The callback feature should have zero overhead - when not used, following the principle: - "You don't pay for what you don't use" - - This ensures the callback feature doesn't slow - down the common case of simple iteration. - """ - # Create streaming result set without callback - response_future = Mock() - response_future.has_more_pages = False - response_future._final_exception = None - response_future.add_callbacks = Mock() - - result_set = AsyncStreamingResultSet(response_future) - - # Trigger page from a thread - args = response_future.add_callbacks.call_args - page_handler = args[1]["callback"] - - rows = ["row" + str(i) for i in range(1000)] - start_time = time.time() - - def thread_callback(): - page_handler(rows) - - thread = threading.Thread(target=thread_callback) - thread.start() - thread.join() # Wait for thread to complete - handle_time = time.time() - start_time - - # Should be very fast without callback - assert handle_time < 0.01 - - # Should still work normally - count = 0 - async for row in result_set: - count += 1 - - assert count == 1000 diff --git a/tests/unit/test_prepared_statement_invalidation.py b/tests/unit/test_prepared_statement_invalidation.py deleted file mode 100644 index 23b5ec2..0000000 --- a/tests/unit/test_prepared_statement_invalidation.py +++ /dev/null @@ -1,587 +0,0 @@ -""" -Unit tests for prepared statement invalidation and re-preparation. 
- -Tests how the async wrapper handles: -- Prepared statements being invalidated by schema changes -- Automatic re-preparation -- Concurrent invalidation scenarios -""" - -import asyncio -from unittest.mock import Mock - -import pytest -from cassandra import InvalidRequest, OperationTimedOut -from cassandra.cluster import Session -from cassandra.query import BatchStatement, BatchType, PreparedStatement - -from async_cassandra import AsyncCassandraSession - - -class TestPreparedStatementInvalidation: - """Test prepared statement invalidation and recovery.""" - - def create_error_future(self, exception): - """Create a mock future that raises the given exception.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - if errback: - errbacks.append(errback) - # Call errback immediately with the error - errback(exception) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - def create_success_future(self, result): - """Create a mock future that returns a result.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - # For success, the callback expects an iterable of rows - mock_rows = [result] if result else [] - callback(mock_rows) - if errback: - errbacks.append(errback) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - def create_prepared_future(self, prepared_stmt): - """Create a mock future for prepare_async that returns a prepared statement.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - # Prepare callback gets the prepared statement directly - callback(prepared_stmt) - if errback: - errbacks.append(errback) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - @pytest.fixture - def mock_session(self): - """Create a mock session.""" - session = Mock(spec=Session) - session.execute_async = Mock() - session.prepare = Mock() - session.prepare_async = Mock() - session.cluster = Mock() - session.get_execution_profile = Mock(return_value=Mock()) - return session - - @pytest.fixture - def mock_prepared_statement(self): - """Create a mock prepared statement.""" - stmt = Mock(spec=PreparedStatement) - stmt.query_id = b"test_query_id" - stmt.query = "SELECT * FROM test WHERE id = ?" - - # Create a mock bound statement with proper attributes - bound_stmt = Mock() - bound_stmt.custom_payload = None - bound_stmt.routing_key = None - bound_stmt.keyspace = None - bound_stmt.consistency_level = None - bound_stmt.fetch_size = None - bound_stmt.serial_consistency_level = None - bound_stmt.retry_policy = None - - stmt.bind = Mock(return_value=bound_stmt) - return stmt - - @pytest.mark.asyncio - async def test_prepared_statement_invalidation_error( - self, mock_session, mock_prepared_statement - ): - """ - Test that invalidated prepared statements raise InvalidRequest. - - What this tests: - --------------- - 1. Invalidated statements detected - 2. InvalidRequest exception raised - 3. Clear error message provided - 4. 
No automatic re-preparation - - Why this matters: - ---------------- - Schema changes invalidate statements: - - Column added/removed - - Table recreated - - Type changes - - Applications must handle invalidation - and re-prepare statements. - """ - async_session = AsyncCassandraSession(mock_session) - - # First prepare succeeds (using sync prepare method) - mock_session.prepare.return_value = mock_prepared_statement - - # Prepare statement - prepared = await async_session.prepare("SELECT * FROM test WHERE id = ?") - assert prepared == mock_prepared_statement - - # Setup execution to fail with InvalidRequest (statement invalidated) - mock_session.execute_async.return_value = self.create_error_future( - InvalidRequest("Prepared statement is invalid") - ) - - # Execute with invalidated statement - should raise InvalidRequest - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute(prepared, [1]) - - assert "Prepared statement is invalid" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_manual_reprepare_after_invalidation(self, mock_session, mock_prepared_statement): - """ - Test manual re-preparation after invalidation. - - What this tests: - --------------- - 1. Re-preparation creates new statement - 2. New statement has different ID - 3. Execution works after re-prepare - 4. Old statement remains invalid - - Why this matters: - ---------------- - Recovery pattern after invalidation: - - Catch InvalidRequest - - Re-prepare statement - - Retry with new statement - - Critical for handling schema - evolution in production. - """ - async_session = AsyncCassandraSession(mock_session) - - # First prepare succeeds (using sync prepare method) - mock_session.prepare.return_value = mock_prepared_statement - - # Prepare statement - prepared = await async_session.prepare("SELECT * FROM test WHERE id = ?") - - # Setup execution to fail with InvalidRequest - mock_session.execute_async.return_value = self.create_error_future( - InvalidRequest("Prepared statement is invalid") - ) - - # First execution fails - with pytest.raises(InvalidRequest): - await async_session.execute(prepared, [1]) - - # Create new prepared statement - new_prepared = Mock(spec=PreparedStatement) - new_prepared.query_id = b"new_query_id" - new_prepared.query = "SELECT * FROM test WHERE id = ?" - - # Create bound statement with proper attributes - new_bound = Mock() - new_bound.custom_payload = None - new_bound.routing_key = None - new_bound.keyspace = None - new_prepared.bind = Mock(return_value=new_bound) - - # Re-prepare manually - mock_session.prepare.return_value = new_prepared - prepared2 = await async_session.prepare("SELECT * FROM test WHERE id = ?") - assert prepared2 == new_prepared - assert prepared2.query_id != prepared.query_id - - # Now execution succeeds with new prepared statement - mock_session.execute_async.return_value = self.create_success_future({"id": 1}) - result = await async_session.execute(prepared2, [1]) - assert result.rows[0]["id"] == 1 - - @pytest.mark.asyncio - async def test_concurrent_invalidation_handling(self, mock_session, mock_prepared_statement): - """ - Test that concurrent executions all fail with invalidation. - - What this tests: - --------------- - 1. All concurrent queries fail - 2. Each gets InvalidRequest - 3. No race conditions - 4. 
Consistent error handling - - Why this matters: - ---------------- - Under high concurrency: - - Many queries may use same statement - - All must handle invalidation - - No query should hang or corrupt - - Ensures thread-safe error propagation - for invalidated statements. - """ - async_session = AsyncCassandraSession(mock_session) - - # Prepare statement - mock_session.prepare.return_value = mock_prepared_statement - prepared = await async_session.prepare("SELECT * FROM test WHERE id = ?") - - # All executions fail with invalidation - mock_session.execute_async.return_value = self.create_error_future( - InvalidRequest("Prepared statement is invalid") - ) - - # Execute multiple concurrent queries - tasks = [async_session.execute(prepared, [i]) for i in range(5)] - - results = await asyncio.gather(*tasks, return_exceptions=True) - - # All should fail with InvalidRequest - assert len(results) == 5 - assert all(isinstance(r, InvalidRequest) for r in results) - assert all("Prepared statement is invalid" in str(r) for r in results) - - @pytest.mark.asyncio - async def test_invalidation_during_batch_execution(self, mock_session, mock_prepared_statement): - """ - Test prepared statement invalidation during batch execution. - - What this tests: - --------------- - 1. Batch with prepared statements - 2. Invalidation affects batch - 3. Whole batch fails - 4. Error clearly indicates issue - - Why this matters: - ---------------- - Batches often contain prepared statements: - - Bulk inserts/updates - - Multi-row operations - - Transaction-like semantics - - Batch invalidation requires re-preparing - all statements in the batch. - """ - async_session = AsyncCassandraSession(mock_session) - - # Prepare statement - mock_session.prepare.return_value = mock_prepared_statement - prepared = await async_session.prepare("INSERT INTO test (id, value) VALUES (?, ?)") - - # Create batch with prepared statement - batch = BatchStatement(batch_type=BatchType.LOGGED) - batch.add(prepared, (1, "value1")) - batch.add(prepared, (2, "value2")) - - # Batch execution fails with invalidation - mock_session.execute_async.return_value = self.create_error_future( - InvalidRequest("Prepared statement is invalid") - ) - - # Batch execution should fail - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute(batch) - - assert "Prepared statement is invalid" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_invalidation_error_propagation(self, mock_session, mock_prepared_statement): - """ - Test that non-invalidation errors are properly propagated. - - What this tests: - --------------- - 1. Non-invalidation errors preserved - 2. Timeouts not confused with invalidation - 3. Error types maintained - 4. No incorrect error wrapping - - Why this matters: - ---------------- - Different errors need different handling: - - Timeouts: retry same statement - - Invalidation: re-prepare needed - - Other errors: various responses - - Accurate error types enable - correct recovery strategies. 
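The recovery pattern these tests describe (catch InvalidRequest, re-prepare, retry once) can be sketched roughly as below; a second InvalidRequest, for example because the table was dropped, is left to propagate to the caller.

    from cassandra import InvalidRequest


    async def execute_reprepare_on_invalidation(session, query, params):
        # Prepare and execute; if a schema change invalidated the statement,
        # re-prepare once (yielding a new query_id) and retry with the fresh
        # statement.
        prepared = await session.prepare(query)
        try:
            return await session.execute(prepared, params)
        except InvalidRequest:
            prepared = await session.prepare(query)
            return await session.execute(prepared, params)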
- """ - async_session = AsyncCassandraSession(mock_session) - - # Prepare statement - mock_session.prepare.return_value = mock_prepared_statement - prepared = await async_session.prepare("SELECT * FROM test WHERE id = ?") - - # Execution fails with different error (not invalidation) - mock_session.execute_async.return_value = self.create_error_future( - OperationTimedOut("Query timed out") - ) - - # Should propagate the error - with pytest.raises(OperationTimedOut) as exc_info: - await async_session.execute(prepared, [1]) - - assert "Query timed out" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_reprepare_failure_handling(self, mock_session, mock_prepared_statement): - """ - Test handling when re-preparation itself fails. - - What this tests: - --------------- - 1. Re-preparation can fail - 2. Table might be dropped - 3. QueryError wraps prepare errors - 4. Original cause preserved - - Why this matters: - ---------------- - Re-preparation fails when: - - Table/keyspace dropped - - Permissions changed - - Query now invalid - - Applications must handle both - invalidation AND re-prepare failure. - """ - async_session = AsyncCassandraSession(mock_session) - - # Initial prepare succeeds - mock_session.prepare.return_value = mock_prepared_statement - prepared = await async_session.prepare("SELECT * FROM test WHERE id = ?") - - # Execution fails with invalidation - mock_session.execute_async.return_value = self.create_error_future( - InvalidRequest("Prepared statement is invalid") - ) - - # First execution fails - with pytest.raises(InvalidRequest): - await async_session.execute(prepared, [1]) - - # Re-preparation fails (e.g., table dropped) - mock_session.prepare.side_effect = InvalidRequest("Table test does not exist") - - # Re-prepare attempt should fail - InvalidRequest passed through - with pytest.raises(InvalidRequest) as exc_info: - await async_session.prepare("SELECT * FROM test WHERE id = ?") - - assert "Table test does not exist" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_prepared_statement_cache_behavior(self, mock_session): - """ - Test that prepared statements are not cached by the async wrapper. - - What this tests: - --------------- - 1. No built-in caching in wrapper - 2. Each prepare goes to driver - 3. Driver handles caching - 4. Different IDs for re-prepares - - Why this matters: - ---------------- - Caching strategy important: - - Driver caches per connection - - Application may cache globally - - Wrapper stays simple - - Applications should implement - their own caching strategy. - """ - async_session = AsyncCassandraSession(mock_session) - - # Create different prepared statements for same query - stmt1 = Mock(spec=PreparedStatement) - stmt1.query_id = b"id1" - stmt1.query = "SELECT * FROM test WHERE id = ?" - bound1 = Mock(custom_payload=None) - stmt1.bind = Mock(return_value=bound1) - - stmt2 = Mock(spec=PreparedStatement) - stmt2.query_id = b"id2" - stmt2.query = "SELECT * FROM test WHERE id = ?" 
- bound2 = Mock(custom_payload=None) - stmt2.bind = Mock(return_value=bound2) - - # First prepare - mock_session.prepare.return_value = stmt1 - prepared1 = await async_session.prepare("SELECT * FROM test WHERE id = ?") - assert prepared1.query_id == b"id1" - - # Second prepare of same query (no caching in wrapper) - mock_session.prepare.return_value = stmt2 - prepared2 = await async_session.prepare("SELECT * FROM test WHERE id = ?") - assert prepared2.query_id == b"id2" - - # Verify prepare was called twice - assert mock_session.prepare.call_count == 2 - - @pytest.mark.asyncio - async def test_invalidation_with_custom_payload(self, mock_session, mock_prepared_statement): - """ - Test prepared statement invalidation with custom payload. - - What this tests: - --------------- - 1. Custom payloads work with prepare - 2. Payload passed to driver - 3. Invalidation still detected - 4. Tracing/debugging preserved - - Why this matters: - ---------------- - Custom payloads used for: - - Request tracing - - Performance monitoring - - Debugging metadata - - Must work correctly even during - error scenarios like invalidation. - """ - async_session = AsyncCassandraSession(mock_session) - - # Prepare with custom payload - custom_payload = {"app_name": "test_app"} - mock_session.prepare.return_value = mock_prepared_statement - - prepared = await async_session.prepare( - "SELECT * FROM test WHERE id = ?", custom_payload=custom_payload - ) - - # Verify custom payload was passed - mock_session.prepare.assert_called_with("SELECT * FROM test WHERE id = ?", custom_payload) - - # Execute fails with invalidation - mock_session.execute_async.return_value = self.create_error_future( - InvalidRequest("Prepared statement is invalid") - ) - - with pytest.raises(InvalidRequest): - await async_session.execute(prepared, [1]) - - @pytest.mark.asyncio - async def test_statement_id_tracking(self, mock_session): - """ - Test that statement IDs are properly tracked. - - What this tests: - --------------- - 1. Each statement has unique ID - 2. IDs preserved in errors - 3. Can identify which statement failed - 4. Helpful error messages - - Why this matters: - ---------------- - Statement IDs help debugging: - - Which statement invalidated - - Correlate with server logs - - Track statement lifecycle - - Essential for troubleshooting - production invalidation issues. - """ - async_session = AsyncCassandraSession(mock_session) - - # Create statements with specific IDs - stmt1 = Mock(spec=PreparedStatement, query_id=b"id1", query="SELECT 1") - stmt2 = Mock(spec=PreparedStatement, query_id=b"id2", query="SELECT 2") - - # Prepare multiple statements - mock_session.prepare.side_effect = [stmt1, stmt2] - - prepared1 = await async_session.prepare("SELECT 1") - prepared2 = await async_session.prepare("SELECT 2") - - # Verify different IDs - assert prepared1.query_id == b"id1" - assert prepared2.query_id == b"id2" - assert prepared1.query_id != prepared2.query_id - - # Execute with specific statement - mock_session.execute_async.return_value = self.create_error_future( - InvalidRequest(f"Prepared statement with ID {stmt1.query_id.hex()} is invalid") - ) - - # Should fail with specific error message - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute(prepared1) - - assert stmt1.query_id.hex() in str(exc_info.value) - - @pytest.mark.asyncio - async def test_invalidation_after_schema_change(self, mock_session): - """ - Test prepared statement invalidation after schema change. 
- - What this tests: - --------------- - 1. Statement works before change - 2. Schema change invalidates - 3. Result metadata mismatch detected - 4. Clear error about metadata - - Why this matters: - ---------------- - Common schema changes that invalidate: - - ALTER TABLE ADD COLUMN - - DROP/RECREATE TABLE - - Type modifications - - This is the most common cause of - invalidation in production systems. - """ - async_session = AsyncCassandraSession(mock_session) - - # Prepare statement - stmt = Mock(spec=PreparedStatement) - stmt.query_id = b"test_id" - stmt.query = "SELECT id, name FROM users WHERE id = ?" - bound = Mock(custom_payload=None) - stmt.bind = Mock(return_value=bound) - - mock_session.prepare.return_value = stmt - prepared = await async_session.prepare("SELECT id, name FROM users WHERE id = ?") - - # First execution succeeds - mock_session.execute_async.return_value = self.create_success_future( - {"id": 1, "name": "Alice"} - ) - result = await async_session.execute(prepared, [1]) - assert result.rows[0]["name"] == "Alice" - - # Simulate schema change (column added) - # Next execution fails with invalidation - mock_session.execute_async.return_value = self.create_error_future( - InvalidRequest("Prepared query has an invalid result metadata") - ) - - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute(prepared, [2]) - - assert "invalid result metadata" in str(exc_info.value) diff --git a/tests/unit/test_prepared_statements.py b/tests/unit/test_prepared_statements.py deleted file mode 100644 index 1ab38f4..0000000 --- a/tests/unit/test_prepared_statements.py +++ /dev/null @@ -1,381 +0,0 @@ -"""Prepared statements functionality tests. - -This module tests prepared statement creation, execution, and caching. -""" - -import asyncio -from unittest.mock import Mock - -import pytest -from cassandra.query import BoundStatement, PreparedStatement - -from async_cassandra import AsyncCassandraSession as AsyncSession -from tests.unit.test_helpers import create_mock_response_future - - -class TestPreparedStatements: - """Test prepared statement functionality.""" - - @pytest.mark.features - @pytest.mark.quick - @pytest.mark.critical - async def test_prepare_statement(self): - """ - Test basic prepared statement creation. - - What this tests: - --------------- - 1. Prepare statement async wrapper works - 2. Query string passed correctly - 3. PreparedStatement returned - 4. Synchronous prepare called once - - Why this matters: - ---------------- - Prepared statements are critical for: - - Query performance (cached plans) - - SQL injection prevention - - Type safety with parameters - - Every production app should use - prepared statements for queries. - """ - mock_session = Mock() - mock_prepared = Mock(spec=PreparedStatement) - mock_session.prepare.return_value = mock_prepared - - async_session = AsyncSession(mock_session) - - prepared = await async_session.prepare("SELECT * FROM users WHERE id = ?") - - assert prepared == mock_prepared - mock_session.prepare.assert_called_once_with("SELECT * FROM users WHERE id = ?", None) - - @pytest.mark.features - async def test_execute_prepared_statement(self): - """ - Test executing prepared statements. - - What this tests: - --------------- - 1. Prepared statements can be executed - 2. Parameters bound correctly - 3. Results returned properly - 4. 
Async execution flow works - - Why this matters: - ---------------- - Prepared statement execution: - - Most common query pattern - - Must handle parameter binding - - Critical for performance - - Proper parameter handling prevents - injection attacks and type errors. - """ - mock_session = Mock() - mock_prepared = Mock(spec=PreparedStatement) - mock_bound = Mock(spec=BoundStatement) - - mock_prepared.bind.return_value = mock_bound - mock_session.prepare.return_value = mock_prepared - - # Create a mock response future manually to have more control - response_future = Mock() - response_future.has_more_pages = False - response_future.timeout = None - response_future.add_callbacks = Mock() - - def setup_callback(callback=None, errback=None): - # Call the callback immediately with test data - if callback: - callback([{"id": 1, "name": "test"}]) - - response_future.add_callbacks.side_effect = setup_callback - mock_session.execute_async.return_value = response_future - - async_session = AsyncSession(mock_session) - - # Prepare statement - prepared = await async_session.prepare("SELECT * FROM users WHERE id = ?") - - # Execute with parameters - result = await async_session.execute(prepared, [1]) - - assert len(result.rows) == 1 - assert result.rows[0] == {"id": 1, "name": "test"} - # The prepared statement and parameters are passed to execute_async - mock_session.execute_async.assert_called_once() - # Check that the prepared statement was passed - args = mock_session.execute_async.call_args[0] - assert args[0] == prepared - assert args[1] == [1] - - @pytest.mark.features - @pytest.mark.critical - async def test_prepared_statement_caching(self): - """ - Test that prepared statements can be cached and reused. - - What this tests: - --------------- - 1. Same query returns same statement - 2. Multiple prepares allowed - 3. Statement object reusable - 4. No built-in caching (driver handles) - - Why this matters: - ---------------- - Statement caching important for: - - Avoiding re-preparation overhead - - Consistent query plans - - Memory efficiency - - Applications should cache statements - at application level for best performance. - """ - mock_session = Mock() - mock_prepared = Mock(spec=PreparedStatement) - mock_session.prepare.return_value = mock_prepared - mock_session.execute.return_value = Mock(current_rows=[]) - - async_session = AsyncSession(mock_session) - - # Prepare same statement multiple times - query = "SELECT * FROM users WHERE id = ? AND status = ?" - - prepared1 = await async_session.prepare(query) - prepared2 = await async_session.prepare(query) - prepared3 = await async_session.prepare(query) - - # All should be the same instance - assert prepared1 == prepared2 == prepared3 == mock_prepared - - # But prepare is called each time (caching would be an optimization) - assert mock_session.prepare.call_count == 3 - - @pytest.mark.features - async def test_prepared_statement_with_custom_options(self): - """ - Test prepared statements with custom execution options. - - What this tests: - --------------- - 1. Custom timeout honored - 2. Custom payload passed through - 3. Execution options work with prepared - 4. Parameters still bound correctly - - Why this matters: - ---------------- - Production queries often need: - - Custom timeouts for SLAs - - Tracing via custom payloads - - Consistency level tuning - - Prepared statements must support - all execution options. 
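The prepare-once, execute-many pattern the tests above describe, sketched roughly; the contact point, keyspace-less connect, and table name are placeholders for illustration.

    from async_cassandra import AsyncCluster


    async def load_users(user_ids):
        cluster = AsyncCluster(contact_points=["127.0.0.1"])
        session = await cluster.connect()
        try:
            # Prepare once; the server keeps the prepared statement, so later
            # executions only bind new parameters instead of re-parsing the query.
            stmt = await session.prepare("SELECT * FROM users WHERE id = ?")
            return [await session.execute(stmt, [uid]) for uid in user_ids]
        finally:
            await cluster.shutdown()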
- """ - mock_session = Mock() - mock_prepared = Mock(spec=PreparedStatement) - mock_bound = Mock(spec=BoundStatement) - - mock_prepared.bind.return_value = mock_bound - mock_session.prepare.return_value = mock_prepared - mock_session.execute_async.return_value = create_mock_response_future([]) - - async_session = AsyncSession(mock_session) - - prepared = await async_session.prepare("UPDATE users SET name = ? WHERE id = ?") - - # Execute with custom timeout and consistency - await async_session.execute( - prepared, ["new name", 123], timeout=30.0, custom_payload={"trace": "true"} - ) - - # Verify execute_async was called with correct parameters - mock_session.execute_async.assert_called_once() - # Check the arguments passed to execute_async - args = mock_session.execute_async.call_args[0] - assert args[0] == prepared - assert args[1] == ["new name", 123] - # Check timeout was passed (position 4) - assert args[4] == 30.0 - - @pytest.mark.features - async def test_concurrent_prepare_statements(self): - """ - Test preparing multiple statements concurrently. - - What this tests: - --------------- - 1. Multiple prepares can run concurrently - 2. Each gets correct statement back - 3. No race conditions or mixing - 4. Async gather works properly - - Why this matters: - ---------------- - Application startup often: - - Prepares many statements - - Benefits from parallelism - - Must not corrupt statements - - Concurrent preparation speeds up - application initialization. - """ - mock_session = Mock() - - # Different prepared statements - prepared_stmts = { - "SELECT": Mock(spec=PreparedStatement), - "INSERT": Mock(spec=PreparedStatement), - "UPDATE": Mock(spec=PreparedStatement), - "DELETE": Mock(spec=PreparedStatement), - } - - def prepare_side_effect(query, custom_payload=None): - for key in prepared_stmts: - if key in query: - return prepared_stmts[key] - return Mock(spec=PreparedStatement) - - mock_session.prepare.side_effect = prepare_side_effect - - async_session = AsyncSession(mock_session) - - # Prepare statements concurrently - tasks = [ - async_session.prepare("SELECT * FROM users WHERE id = ?"), - async_session.prepare("INSERT INTO users (id, name) VALUES (?, ?)"), - async_session.prepare("UPDATE users SET name = ? WHERE id = ?"), - async_session.prepare("DELETE FROM users WHERE id = ?"), - ] - - results = await asyncio.gather(*tasks) - - assert results[0] == prepared_stmts["SELECT"] - assert results[1] == prepared_stmts["INSERT"] - assert results[2] == prepared_stmts["UPDATE"] - assert results[3] == prepared_stmts["DELETE"] - - @pytest.mark.features - async def test_prepared_statement_error_handling(self): - """ - Test error handling during statement preparation. - - What this tests: - --------------- - 1. Prepare errors propagated - 2. Original exception preserved - 3. Error message maintained - 4. No hanging or corruption - - Why this matters: - ---------------- - Prepare can fail due to: - - Syntax errors in query - - Unknown tables/columns - - Schema mismatches - - Clear errors help developers - fix queries during development. - """ - mock_session = Mock() - mock_session.prepare.side_effect = Exception("Invalid query syntax") - - async_session = AsyncSession(mock_session) - - with pytest.raises(Exception, match="Invalid query syntax"): - await async_session.prepare("INVALID QUERY SYNTAX") - - @pytest.mark.features - @pytest.mark.critical - async def test_bound_statement_reuse(self): - """ - Test reusing bound statements. - - What this tests: - --------------- - 1. 
Prepare once, execute many - 2. Different parameters each time - 3. Statement prepared only once - 4. Executions independent - - Why this matters: - ---------------- - This is THE pattern for production: - - Prepare statements at startup - - Execute with different params - - Massive performance benefit - - Reusing prepared statements reduces - latency and cluster load. - """ - mock_session = Mock() - mock_prepared = Mock(spec=PreparedStatement) - mock_bound = Mock(spec=BoundStatement) - - mock_prepared.bind.return_value = mock_bound - mock_session.prepare.return_value = mock_prepared - mock_session.execute_async.return_value = create_mock_response_future([]) - - async_session = AsyncSession(mock_session) - - # Prepare once - prepared = await async_session.prepare("SELECT * FROM users WHERE id = ?") - - # Execute multiple times with different parameters - for user_id in [1, 2, 3, 4, 5]: - await async_session.execute(prepared, [user_id]) - - # Prepare called once, execute_async called for each execution - assert mock_session.prepare.call_count == 1 - assert mock_session.execute_async.call_count == 5 - - @pytest.mark.features - async def test_prepared_statement_metadata(self): - """ - Test accessing prepared statement metadata. - - What this tests: - --------------- - 1. Column metadata accessible - 2. Type information available - 3. Partition key info present - 4. Metadata correctly structured - - Why this matters: - ---------------- - Metadata enables: - - Dynamic result processing - - Type validation - - Routing optimization - - ORMs and frameworks rely on - metadata for mapping and validation. - """ - mock_session = Mock() - mock_prepared = Mock(spec=PreparedStatement) - - # Mock metadata - mock_prepared.column_metadata = [ - ("keyspace", "table", "id", "uuid"), - ("keyspace", "table", "name", "text"), - ("keyspace", "table", "created_at", "timestamp"), - ] - mock_prepared.routing_key_indexes = [0] # id is partition key - - mock_session.prepare.return_value = mock_prepared - - async_session = AsyncSession(mock_session) - - prepared = await async_session.prepare( - "SELECT id, name, created_at FROM users WHERE id = ?" - ) - - # Access metadata - assert len(prepared.column_metadata) == 3 - assert prepared.column_metadata[0][2] == "id" - assert prepared.column_metadata[1][2] == "name" - assert prepared.routing_key_indexes == [0] diff --git a/tests/unit/test_protocol_edge_cases.py b/tests/unit/test_protocol_edge_cases.py deleted file mode 100644 index 3c7eb38..0000000 --- a/tests/unit/test_protocol_edge_cases.py +++ /dev/null @@ -1,572 +0,0 @@ -""" -Unit tests for protocol-level edge cases. - -Tests how the async wrapper handles: -- Protocol version negotiation issues -- Protocol errors during queries -- Custom payloads -- Large queries -- Various Cassandra exceptions - -Test Organization: -================== -1. Protocol Negotiation - Version negotiation failures -2. Protocol Errors - Errors during query execution -3. Custom Payloads - Application-specific protocol data -4. Query Size Limits - Large query handling -5. 
Error Recovery - Recovery from protocol issues - -Key Testing Principles: -====================== -- Test protocol boundary conditions -- Verify error propagation -- Ensure graceful degradation -- Test recovery mechanisms -""" - -from unittest.mock import Mock, patch - -import pytest -from cassandra import InvalidRequest, OperationTimedOut, UnsupportedOperation -from cassandra.cluster import NoHostAvailable, Session -from cassandra.connection import ProtocolError - -from async_cassandra import AsyncCassandraSession -from async_cassandra.exceptions import ConnectionError - - -class TestProtocolEdgeCases: - """Test protocol-level edge cases and error handling.""" - - def create_error_future(self, exception): - """Create a mock future that raises the given exception.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - if errback: - errbacks.append(errback) - # Call errback immediately with the error - errback(exception) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - def create_success_future(self, result): - """Create a mock future that returns a result.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - # For success, the callback expects an iterable of rows - mock_rows = [result] if result else [] - callback(mock_rows) - if errback: - errbacks.append(errback) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - @pytest.fixture - def mock_session(self): - """Create a mock session.""" - session = Mock(spec=Session) - session.execute_async = Mock() - session.prepare = Mock() - session.cluster = Mock() - session.cluster.protocol_version = 5 - return session - - @pytest.mark.asyncio - async def test_protocol_version_negotiation_failure(self): - """ - Test handling of protocol version negotiation failures. - - What this tests: - --------------- - 1. Protocol negotiation can fail - 2. NoHostAvailable with ProtocolError - 3. Wrapped in ConnectionError - 4. Clear error message - - Why this matters: - ---------------- - Protocol negotiation failures occur when: - - Client/server version mismatch - - Unsupported protocol features - - Configuration conflicts - - Users need clear guidance on - version compatibility issues. - """ - from async_cassandra import AsyncCluster - - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Create mock cluster instance - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - - # Simulate protocol negotiation failure during connect - mock_cluster.connect.side_effect = NoHostAvailable( - "Unable to connect to any servers", - {"127.0.0.1": ProtocolError("Cannot negotiate protocol version")}, - ) - - async_cluster = AsyncCluster(contact_points=["127.0.0.1"]) - - # Should fail with connection error - with pytest.raises(ConnectionError) as exc_info: - await async_cluster.connect() - - assert "Failed to connect" in str(exc_info.value) - - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_protocol_error_during_query(self, mock_session): - """ - Test handling of protocol errors during query execution. - - What this tests: - --------------- - 1. Protocol errors during execution - 2. 
ProtocolError passed through without wrapping - 3. Direct exception access - 4. Error details preserved as-is - - Why this matters: - ---------------- - Protocol errors indicate: - - Corrupted messages - - Protocol violations - - Driver/server bugs - - Users need direct access for - proper error handling and debugging. - """ - async_session = AsyncCassandraSession(mock_session) - - # Simulate protocol error - mock_session.execute_async.return_value = self.create_error_future( - ProtocolError("Invalid or unsupported protocol version") - ) - - # ProtocolError is now passed through without wrapping - with pytest.raises(ProtocolError) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "Invalid or unsupported protocol version" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_custom_payload_handling(self, mock_session): - """ - Test handling of custom payloads in protocol. - - What this tests: - --------------- - 1. Custom payloads passed through - 2. Payload data preserved - 3. No interference with query - 4. Application metadata works - - Why this matters: - ---------------- - Custom payloads enable: - - Request tracing - - Application context - - Cross-system correlation - - Used for debugging and monitoring - in production systems. - """ - async_session = AsyncCassandraSession(mock_session) - - # Track custom payloads - sent_payloads = [] - - def execute_async_side_effect(*args, **kwargs): - # Extract custom payload if provided - custom_payload = args[3] if len(args) > 3 else kwargs.get("custom_payload") - if custom_payload: - sent_payloads.append(custom_payload) - - return self.create_success_future({"payload_received": True}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Execute with custom payload - custom_data = {"app_name": "test_app", "request_id": "12345"} - result = await async_session.execute("SELECT * FROM test", custom_payload=custom_data) - - # Verify payload was sent - assert len(sent_payloads) == 1 - assert sent_payloads[0] == custom_data - assert result.rows[0]["payload_received"] is True - - @pytest.mark.asyncio - async def test_large_query_handling(self, mock_session): - """ - Test handling of very large queries. - - What this tests: - --------------- - 1. Query size limits enforced - 2. InvalidRequest for oversized queries - 3. Clear size limit in error - 4. Not wrapped (Cassandra error) - - Why this matters: - ---------------- - Query size limits prevent: - - Memory exhaustion - - Network overload - - Protocol buffer overflow - - Applications must chunk large - operations or use prepared statements. - """ - async_session = AsyncCassandraSession(mock_session) - - # Create very large query - large_values = ["x" * 1000 for _ in range(100)] # ~100KB of data - large_query = f"INSERT INTO test (id, data) VALUES (1, '{','.join(large_values)}')" - - # Execution fails due to size - mock_session.execute_async.return_value = self.create_error_future( - InvalidRequest("Query string length (102400) is greater than maximum allowed (65535)") - ) - - # InvalidRequest is not wrapped - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute(large_query) - - assert "greater than maximum allowed" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_unsupported_operation(self, mock_session): - """ - Test handling of unsupported operations. - - What this tests: - --------------- - 1. UnsupportedOperation errors passed through - 2. No wrapping - direct exception access - 3. 
Feature limitations clearly visible - 4. Version-specific features preserved - - Why this matters: - ---------------- - Features vary by protocol version: - - Continuous paging (v5+) - - Duration type (v5+) - - Per-query keyspace (v5+) - - Users need direct access to handle - version-specific feature errors. - """ - async_session = AsyncCassandraSession(mock_session) - - # Simulate unsupported operation - mock_session.execute_async.return_value = self.create_error_future( - UnsupportedOperation("Continuous paging is not supported by this protocol version") - ) - - # UnsupportedOperation is now passed through without wrapping - with pytest.raises(UnsupportedOperation) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "Continuous paging is not supported" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_protocol_error_recovery(self, mock_session): - """ - Test recovery from protocol-level errors. - - What this tests: - --------------- - 1. Protocol errors can be transient - 2. Recovery possible after errors - 3. Direct exception handling - 4. Eventually succeeds - - Why this matters: - ---------------- - Some protocol errors are recoverable: - - Stream ID conflicts - - Temporary corruption - - Race conditions - - Users can implement retry logic - with new connections as needed. - """ - async_session = AsyncCassandraSession(mock_session) - - # Track protocol errors - error_count = 0 - - def execute_async_side_effect(*args, **kwargs): - nonlocal error_count - error_count += 1 - - if error_count <= 2: - # First attempts fail with protocol error - return self.create_error_future(ProtocolError("Protocol error: Invalid stream id")) - else: - # Recovery succeeds - return self.create_success_future({"recovered": True}) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # First two attempts should fail - for i in range(2): - with pytest.raises(ProtocolError): - await async_session.execute("SELECT * FROM test") - - # Third attempt should succeed - result = await async_session.execute("SELECT * FROM test") - assert result.rows[0]["recovered"] is True - assert error_count == 3 - - @pytest.mark.asyncio - async def test_protocol_version_in_session(self, mock_session): - """ - Test accessing protocol version from session. - - What this tests: - --------------- - 1. Protocol version accessible - 2. Available via cluster object - 3. Version doesn't affect queries - 4. Useful for debugging - - Why this matters: - ---------------- - Applications may need version info: - - Feature detection - - Compatibility checks - - Debugging protocol issues - - Version should be easily accessible - for runtime decisions. - """ - async_session = AsyncCassandraSession(mock_session) - - # Protocol version should be accessible via cluster - assert mock_session.cluster.protocol_version == 5 - - # Execute query to verify protocol version doesn't affect normal operation - mock_session.execute_async.return_value = self.create_success_future( - {"protocol_version": mock_session.cluster.protocol_version} - ) - - result = await async_session.execute("SELECT * FROM system.local") - assert result.rows[0]["protocol_version"] == 5 - - @pytest.mark.asyncio - async def test_timeout_vs_protocol_error(self, mock_session): - """ - Test differentiating between timeouts and protocol errors. - - What this tests: - --------------- - 1. Timeouts not wrapped - 2. Protocol errors wrapped - 3. Different error handling - 4. 
Clear distinction - - Why this matters: - ---------------- - Different errors need different handling: - - Timeouts: often transient, retry - - Protocol errors: serious, investigate - - Applications must distinguish to - implement proper error handling. - """ - async_session = AsyncCassandraSession(mock_session) - - # Test timeout - mock_session.execute_async.return_value = self.create_error_future( - OperationTimedOut("Request timed out") - ) - - # OperationTimedOut is not wrapped - with pytest.raises(OperationTimedOut): - await async_session.execute("SELECT * FROM test") - - # Test protocol error - mock_session.execute_async.return_value = self.create_error_future( - ProtocolError("Protocol violation") - ) - - # ProtocolError is now passed through without wrapping - with pytest.raises(ProtocolError): - await async_session.execute("SELECT * FROM test") - - @pytest.mark.asyncio - async def test_prepare_with_protocol_error(self, mock_session): - """ - Test prepared statement with protocol errors. - - What this tests: - --------------- - 1. Prepare can fail with protocol error - 2. Passed through without wrapping - 3. Statement preparation issues visible - 4. Direct exception access - - Why this matters: - ---------------- - Prepare failures indicate: - - Schema issues - - Protocol limitations - - Query complexity problems - - Users need direct access to - handle preparation failures. - """ - async_session = AsyncCassandraSession(mock_session) - - # Prepare fails with protocol error - mock_session.prepare.side_effect = ProtocolError("Cannot prepare statement") - - # ProtocolError is now passed through without wrapping - with pytest.raises(ProtocolError) as exc_info: - await async_session.prepare("SELECT * FROM test WHERE id = ?") - - assert "Cannot prepare statement" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_execution_profile_with_protocol_settings(self, mock_session): - """ - Test execution profiles don't interfere with protocol handling. - - What this tests: - --------------- - 1. Execution profiles work correctly - 2. Profile parameter passed through - 3. No protocol interference - 4. Custom settings preserved - - Why this matters: - ---------------- - Execution profiles customize: - - Consistency levels - - Retry policies - - Load balancing - - Must work seamlessly with - protocol-level features. - """ - async_session = AsyncCassandraSession(mock_session) - - # Execute with custom execution profile - mock_session.execute_async.return_value = self.create_success_future({"profile": "custom"}) - - result = await async_session.execute( - "SELECT * FROM test", execution_profile="custom_profile" - ) - - # Verify execution profile was passed - mock_session.execute_async.assert_called_once() - call_args = mock_session.execute_async.call_args - # Check positional arguments: query, parameters, trace, custom_payload, timeout, execution_profile - assert call_args[0][5] == "custom_profile" # execution_profile is 6th parameter (index 5) - assert result.rows[0]["profile"] == "custom" - - @pytest.mark.asyncio - async def test_batch_with_protocol_error(self, mock_session): - """ - Test batch execution with protocol errors. - - What this tests: - --------------- - 1. Batch operations can hit protocol limits - 2. Protocol errors passed through directly - 3. Batch size limits visible to users - 4. 
Native exception handling - - Why this matters: - ---------------- - Batches have protocol limits: - - Maximum batch size - - Statement count limits - - Protocol buffer constraints - - Users need direct access to - handle batch size errors. - """ - from cassandra.query import BatchStatement, BatchType - - async_session = AsyncCassandraSession(mock_session) - - # Create batch - batch = BatchStatement(batch_type=BatchType.LOGGED) - batch.add("INSERT INTO test (id) VALUES (1)") - batch.add("INSERT INTO test (id) VALUES (2)") - - # Batch execution fails with protocol error - mock_session.execute_async.return_value = self.create_error_future( - ProtocolError("Batch too large for protocol") - ) - - # ProtocolError is now passed through without wrapping - with pytest.raises(ProtocolError) as exc_info: - await async_session.execute_batch(batch) - - assert "Batch too large" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_no_host_available_with_protocol_errors(self, mock_session): - """ - Test NoHostAvailable containing protocol errors. - - What this tests: - --------------- - 1. NoHostAvailable can contain various errors - 2. Protocol errors preserved per host - 3. Mixed error types handled - 4. Detailed error information - - Why this matters: - ---------------- - Connection failures vary by host: - - Some have protocol issues - - Others timeout - - Mixed failure modes - - Detailed per-host errors help - diagnose cluster-wide issues. - """ - async_session = AsyncCassandraSession(mock_session) - - # Create NoHostAvailable with protocol errors - errors = { - "10.0.0.1": ProtocolError("Protocol version mismatch"), - "10.0.0.2": ProtocolError("Protocol negotiation failed"), - "10.0.0.3": OperationTimedOut("Connection timeout"), - } - - mock_session.execute_async.return_value = self.create_error_future( - NoHostAvailable("Unable to connect to any servers", errors) - ) - - # NoHostAvailable is not wrapped - with pytest.raises(NoHostAvailable) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "Unable to connect to any servers" in str(exc_info.value) - assert len(exc_info.value.errors) == 3 - assert isinstance(exc_info.value.errors["10.0.0.1"], ProtocolError) diff --git a/tests/unit/test_protocol_exceptions.py b/tests/unit/test_protocol_exceptions.py deleted file mode 100644 index 098700a..0000000 --- a/tests/unit/test_protocol_exceptions.py +++ /dev/null @@ -1,847 +0,0 @@ -""" -Comprehensive unit tests for protocol exceptions from the DataStax driver. 
- -Tests proper handling of all protocol-level exceptions including: -- OverloadedErrorMessage -- ReadTimeout/WriteTimeout -- Unavailable -- ReadFailure/WriteFailure -- ServerError -- ProtocolException -- IsBootstrappingErrorMessage -- TruncateError -- FunctionFailure -- CDCWriteFailure -""" - -from unittest.mock import Mock - -import pytest -from cassandra import ( - AlreadyExists, - AuthenticationFailed, - CDCWriteFailure, - CoordinationFailure, - FunctionFailure, - InvalidRequest, - OperationTimedOut, - ReadFailure, - ReadTimeout, - Unavailable, - WriteFailure, - WriteTimeout, -) -from cassandra.cluster import NoHostAvailable, ServerError -from cassandra.connection import ( - ConnectionBusy, - ConnectionException, - ConnectionShutdown, - ProtocolError, -) -from cassandra.pool import NoConnectionsAvailable - -from async_cassandra import AsyncCassandraSession - - -class TestProtocolExceptions: - """Test handling of all protocol-level exceptions.""" - - @pytest.fixture - def mock_session(self): - """Create a mock session.""" - session = Mock() - session.execute_async = Mock() - session.prepare_async = Mock() - session.cluster = Mock() - session.cluster.protocol_version = 5 - return session - - def create_error_future(self, exception): - """Create a mock future that raises the given exception.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - if errback: - errbacks.append(errback) - # Call errback immediately with the error - errback(exception) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - @pytest.mark.asyncio - async def test_overloaded_error_message(self, mock_session): - """ - Test handling of OverloadedErrorMessage from coordinator. - - What this tests: - --------------- - 1. Server overload errors handled - 2. OperationTimedOut for overload - 3. Clear error message - 4. Not wrapped (timeout exception) - - Why this matters: - ---------------- - Server overload indicates: - - Too much concurrent load - - Insufficient cluster capacity - - Need for backpressure - - Applications should respond with - backoff and retry strategies. - """ - async_session = AsyncCassandraSession(mock_session) - - # Create OverloadedErrorMessage - this is typically wrapped in OperationTimedOut - error = OperationTimedOut("Request timed out - server overloaded") - mock_session.execute_async.return_value = self.create_error_future(error) - - with pytest.raises(OperationTimedOut) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "server overloaded" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_read_timeout(self, mock_session): - """ - Test handling of ReadTimeout errors. - - What this tests: - --------------- - 1. Read timeouts not wrapped - 2. Consistency level preserved - 3. Response count available - 4. Data retrieval flag set - - Why this matters: - ---------------- - Read timeouts tell you: - - How many replicas responded - - Whether any data was retrieved - - If retry might succeed - - Applications can make informed - retry decisions based on details. 
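
A hedged sketch, not part of this diff, of the retry decision the docstring above describes: the ReadTimeout attributes (received_responses, required_responses, data_retrieved) are the ones asserted in the test, while the retry heuristic itself is an assumption for illustration.

from cassandra import ReadTimeout


async def read_with_one_retry(session, query):
    """Retry a read once when the timeout details suggest it is worthwhile."""
    try:
        return await session.execute(query)
    except ReadTimeout as exc:
        # If data was already retrieved, or all but one replica answered,
        # a single immediate retry has a reasonable chance of succeeding.
        if exc.data_retrieved or exc.received_responses >= exc.required_responses - 1:
            return await session.execute(query)
        raise
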
- """ - async_session = AsyncCassandraSession(mock_session) - - error = ReadTimeout( - "Read request timed out", - consistency_level=1, - required_responses=2, - received_responses=1, - data_retrieved=False, - ) - mock_session.execute_async.return_value = self.create_error_future(error) - - with pytest.raises(ReadTimeout) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert exc_info.value.required_responses == 2 - assert exc_info.value.received_responses == 1 - assert exc_info.value.data_retrieved is False - - @pytest.mark.asyncio - async def test_write_timeout(self, mock_session): - """ - Test handling of WriteTimeout errors. - - What this tests: - --------------- - 1. Write timeouts not wrapped - 2. Write type preserved - 3. Response counts available - 4. Consistency level included - - Why this matters: - ---------------- - Write timeout details critical for: - - Determining if write succeeded - - Understanding failure mode - - Deciding on retry safety - - Different write types (SIMPLE, BATCH, - UNLOGGED_BATCH, COUNTER) need different - retry strategies. - """ - async_session = AsyncCassandraSession(mock_session) - - from cassandra import WriteType - - error = WriteTimeout("Write request timed out", write_type=WriteType.SIMPLE) - # Set additional attributes - error.consistency_level = 1 - error.required_responses = 3 - error.received_responses = 2 - mock_session.execute_async.return_value = self.create_error_future(error) - - with pytest.raises(WriteTimeout) as exc_info: - await async_session.execute("INSERT INTO test VALUES (1)") - - assert exc_info.value.required_responses == 3 - assert exc_info.value.received_responses == 2 - # write_type is stored as numeric value - from cassandra import WriteType - - assert exc_info.value.write_type == WriteType.SIMPLE - - @pytest.mark.asyncio - async def test_unavailable(self, mock_session): - """ - Test handling of Unavailable errors (not enough replicas). - - What this tests: - --------------- - 1. Unavailable errors not wrapped - 2. Required replica count shown - 3. Alive replica count shown - 4. Consistency level preserved - - Why this matters: - ---------------- - Unavailable means: - - Not enough replicas up - - Cannot meet consistency - - Cluster health issue - - Retry won't help until more - replicas come online. - """ - async_session = AsyncCassandraSession(mock_session) - - error = Unavailable( - "Not enough replicas available", consistency=1, required_replicas=3, alive_replicas=1 - ) - mock_session.execute_async.return_value = self.create_error_future(error) - - with pytest.raises(Unavailable) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert exc_info.value.required_replicas == 3 - assert exc_info.value.alive_replicas == 1 - - @pytest.mark.asyncio - async def test_read_failure(self, mock_session): - """ - Test handling of ReadFailure errors (replicas failed during read). - - What this tests: - --------------- - 1. ReadFailure passed through without wrapping - 2. Failure count preserved - 3. Data retrieval flag available - 4. Direct exception access - - Why this matters: - ---------------- - Read failures indicate: - - Replicas crashed/errored - - Data corruption possible - - More serious than timeout - - Users need direct access to - handle these serious errors. 
- """ - async_session = AsyncCassandraSession(mock_session) - - original_error = ReadFailure("Read failed on replicas", data_retrieved=False) - # Set additional attributes - original_error.consistency_level = 1 - original_error.required_responses = 2 - original_error.received_responses = 1 - original_error.numfailures = 1 - mock_session.execute_async.return_value = self.create_error_future(original_error) - - # ReadFailure is now passed through without wrapping - with pytest.raises(ReadFailure) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "Read failed on replicas" in str(exc_info.value) - assert exc_info.value.numfailures == 1 - assert exc_info.value.data_retrieved is False - - @pytest.mark.asyncio - async def test_write_failure(self, mock_session): - """ - Test handling of WriteFailure errors (replicas failed during write). - - What this tests: - --------------- - 1. WriteFailure passed through without wrapping - 2. Write type preserved - 3. Failure count available - 4. Response details included - - Why this matters: - ---------------- - Write failures mean: - - Replicas rejected write - - Possible constraint violation - - Data inconsistency risk - - Users need direct access to - understand write outcomes. - """ - async_session = AsyncCassandraSession(mock_session) - - from cassandra import WriteType - - original_error = WriteFailure("Write failed on replicas", write_type=WriteType.BATCH) - # Set additional attributes - original_error.consistency_level = 1 - original_error.required_responses = 3 - original_error.received_responses = 2 - original_error.numfailures = 1 - mock_session.execute_async.return_value = self.create_error_future(original_error) - - # WriteFailure is now passed through without wrapping - with pytest.raises(WriteFailure) as exc_info: - await async_session.execute("INSERT INTO test VALUES (1)") - - assert "Write failed on replicas" in str(exc_info.value) - assert exc_info.value.numfailures == 1 - - @pytest.mark.asyncio - async def test_function_failure(self, mock_session): - """ - Test handling of FunctionFailure errors (UDF execution failed). - - What this tests: - --------------- - 1. FunctionFailure passed through without wrapping - 2. Function details preserved - 3. Keyspace and name available - 4. Argument types included - - Why this matters: - ---------------- - UDF failures indicate: - - Logic errors in function - - Invalid input data - - Resource constraints - - Users need direct access to - debug function failures. - """ - async_session = AsyncCassandraSession(mock_session) - - # Create the actual FunctionFailure that would come from the driver - original_error = FunctionFailure( - "User defined function failed", - keyspace="test_ks", - function="my_func", - arg_types=["text", "int"], - ) - mock_session.execute_async.return_value = self.create_error_future(original_error) - - # FunctionFailure is now passed through without wrapping - with pytest.raises(FunctionFailure) as exc_info: - await async_session.execute("SELECT my_func(name, age) FROM users") - - # Verify the exception contains the original error info - assert "User defined function failed" in str(exc_info.value) - assert exc_info.value.keyspace == "test_ks" - assert exc_info.value.function == "my_func" - - @pytest.mark.asyncio - async def test_cdc_write_failure(self, mock_session): - """ - Test handling of CDCWriteFailure errors. - - What this tests: - --------------- - 1. CDCWriteFailure passed through without wrapping - 2. CDC-specific error preserved - 3. 
Direct exception access - 4. Native error handling - - Why this matters: - ---------------- - CDC (Change Data Capture) failures: - - CDC log space exhausted - - CDC disabled on table - - System overload - - Applications need direct access - for CDC-specific handling. - """ - async_session = AsyncCassandraSession(mock_session) - - original_error = CDCWriteFailure("CDC write failed") - mock_session.execute_async.return_value = self.create_error_future(original_error) - - # CDCWriteFailure is now passed through without wrapping - with pytest.raises(CDCWriteFailure) as exc_info: - await async_session.execute("INSERT INTO cdc_table VALUES (1)") - - assert "CDC write failed" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_coordinator_failure(self, mock_session): - """ - Test handling of CoordinationFailure errors. - - What this tests: - --------------- - 1. CoordinationFailure passed through without wrapping - 2. Coordinator node failure preserved - 3. Error message unchanged - 4. Direct exception handling - - Why this matters: - ---------------- - Coordination failures mean: - - Coordinator node issues - - Cannot orchestrate query - - Different from replica failures - - Users need direct access to - implement retry strategies. - """ - async_session = AsyncCassandraSession(mock_session) - - original_error = CoordinationFailure("Coordinator failed to execute query") - mock_session.execute_async.return_value = self.create_error_future(original_error) - - # CoordinationFailure is now passed through without wrapping - with pytest.raises(CoordinationFailure) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "Coordinator failed to execute query" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_is_bootstrapping_error(self, mock_session): - """ - Test handling of IsBootstrappingErrorMessage. - - What this tests: - --------------- - 1. Bootstrapping errors in NoHostAvailable - 2. Node state errors handled - 3. Connection exceptions preserved - 4. Host-specific errors shown - - Why this matters: - ---------------- - Bootstrapping nodes: - - Still joining cluster - - Not ready for queries - - Temporary state - - Applications should retry on - other nodes until bootstrap completes. - """ - async_session = AsyncCassandraSession(mock_session) - - # Bootstrapping errors are typically wrapped in NoHostAvailable - error = NoHostAvailable( - "No host available", {"127.0.0.1": ConnectionException("Host is bootstrapping")} - ) - mock_session.execute_async.return_value = self.create_error_future(error) - - with pytest.raises(NoHostAvailable) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "No host available" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_truncate_error(self, mock_session): - """ - Test handling of TruncateError. - - What this tests: - --------------- - 1. Truncate timeouts handled - 2. OperationTimedOut for truncate - 3. Error message specific - 4. Not wrapped - - Why this matters: - ---------------- - Truncate errors indicate: - - Truncate taking too long - - Cluster coordination issues - - Heavy operation timeout - - Truncate is expensive - timeouts - expected on large tables. 
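
A brief sketch, not from this diff, of inspecting the per-host error map carried by NoHostAvailable, which the bootstrapping test above relies on; how the diagnostics are reported is an assumption.

from cassandra.cluster import NoHostAvailable


async def execute_with_host_diagnostics(session, query):
    try:
        return await session.execute(query)
    except NoHostAvailable as exc:
        # exc.errors maps each contact point to the exception it raised there.
        for host, host_error in exc.errors.items():
            print(f"{host}: {type(host_error).__name__}: {host_error}")
        raise
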
- """ - async_session = AsyncCassandraSession(mock_session) - - # TruncateError is typically wrapped in OperationTimedOut - error = OperationTimedOut("Truncate operation timed out") - mock_session.execute_async.return_value = self.create_error_future(error) - - with pytest.raises(OperationTimedOut) as exc_info: - await async_session.execute("TRUNCATE test_table") - - assert "Truncate operation timed out" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_server_error(self, mock_session): - """ - Test handling of generic ServerError. - - What this tests: - --------------- - 1. ServerError wrapped in QueryError - 2. Error code preserved - 3. Error message included - 4. Additional info available - - Why this matters: - ---------------- - Generic server errors indicate: - - Internal Cassandra errors - - Unexpected conditions - - Bugs or edge cases - - Error codes help identify - specific server issues. - """ - async_session = AsyncCassandraSession(mock_session) - - # ServerError is an ErrorMessage subclass that requires code, message, info - original_error = ServerError(0x0000, "Internal server error occurred", {}) - mock_session.execute_async.return_value = self.create_error_future(original_error) - - # ServerError is passed through directly (ErrorMessage subclass) - with pytest.raises(ServerError) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "Internal server error occurred" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_protocol_error(self, mock_session): - """ - Test handling of ProtocolError. - - What this tests: - --------------- - 1. ProtocolError passed through without wrapping - 2. Protocol violations preserved as-is - 3. Error message unchanged - 4. Direct exception access for handling - - Why this matters: - ---------------- - Protocol errors serious: - - Version mismatches - - Message corruption - - Driver/server bugs - - Users need direct access to these - exceptions for proper handling. - """ - async_session = AsyncCassandraSession(mock_session) - - # ProtocolError from connection module takes just a message - original_error = ProtocolError("Protocol version mismatch") - mock_session.execute_async.return_value = self.create_error_future(original_error) - - # ProtocolError is now passed through without wrapping - with pytest.raises(ProtocolError) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "Protocol version mismatch" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_connection_busy(self, mock_session): - """ - Test handling of ConnectionBusy errors. - - What this tests: - --------------- - 1. ConnectionBusy passed through without wrapping - 2. In-flight request limit error preserved - 3. Connection saturation visible to users - 4. Direct exception handling possible - - Why this matters: - ---------------- - Connection busy means: - - Too many concurrent requests - - Per-connection limit reached - - Need more connections or less load - - Users need to handle this directly - for proper connection management. 
- """ - async_session = AsyncCassandraSession(mock_session) - - original_error = ConnectionBusy("Connection has too many in-flight requests") - mock_session.execute_async.return_value = self.create_error_future(original_error) - - # ConnectionBusy is now passed through without wrapping - with pytest.raises(ConnectionBusy) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "Connection has too many in-flight requests" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_connection_shutdown(self, mock_session): - """ - Test handling of ConnectionShutdown errors. - - What this tests: - --------------- - 1. ConnectionShutdown passed through without wrapping - 2. Graceful shutdown exception preserved - 3. Connection closing visible to users - 4. Direct error handling enabled - - Why this matters: - ---------------- - Connection shutdown occurs when: - - Node shutting down cleanly - - Connection being recycled - - Maintenance operations - - Applications need direct access - to handle retry logic properly. - """ - async_session = AsyncCassandraSession(mock_session) - - original_error = ConnectionShutdown("Connection is shutting down") - mock_session.execute_async.return_value = self.create_error_future(original_error) - - # ConnectionShutdown is now passed through without wrapping - with pytest.raises(ConnectionShutdown) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "Connection is shutting down" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_no_connections_available(self, mock_session): - """ - Test handling of NoConnectionsAvailable from pool. - - What this tests: - --------------- - 1. NoConnectionsAvailable passed through without wrapping - 2. Pool exhaustion exception preserved - 3. Direct access to pool state - 4. Native exception handling - - Why this matters: - ---------------- - No connections available means: - - Connection pool exhausted - - All connections busy - - Need to wait or expand pool - - Applications need direct access - for proper backpressure handling. - """ - async_session = AsyncCassandraSession(mock_session) - - original_error = NoConnectionsAvailable("Connection pool exhausted") - mock_session.execute_async.return_value = self.create_error_future(original_error) - - # NoConnectionsAvailable is now passed through without wrapping - with pytest.raises(NoConnectionsAvailable) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert "Connection pool exhausted" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_already_exists(self, mock_session): - """ - Test handling of AlreadyExists errors. - - What this tests: - --------------- - 1. AlreadyExists wrapped in QueryError - 2. Keyspace/table info preserved - 3. Schema conflict detected - 4. Details accessible - - Why this matters: - ---------------- - Already exists errors for: - - CREATE TABLE conflicts - - CREATE KEYSPACE conflicts - - Schema synchronization issues - - May be safe to ignore if - idempotent schema creation. 
- """ - async_session = AsyncCassandraSession(mock_session) - - original_error = AlreadyExists(keyspace="test_ks", table="test_table") - mock_session.execute_async.return_value = self.create_error_future(original_error) - - # AlreadyExists is passed through directly - with pytest.raises(AlreadyExists) as exc_info: - await async_session.execute("CREATE TABLE test_table (id int PRIMARY KEY)") - - assert exc_info.value.keyspace == "test_ks" - assert exc_info.value.table == "test_table" - - @pytest.mark.asyncio - async def test_invalid_request(self, mock_session): - """ - Test handling of InvalidRequest errors. - - What this tests: - --------------- - 1. InvalidRequest not wrapped - 2. Syntax errors caught - 3. Clear error message - 4. Driver exception passed through - - Why this matters: - ---------------- - Invalid requests indicate: - - CQL syntax errors - - Schema mismatches - - Invalid operations - - These are programming errors - that need fixing, not retrying. - """ - async_session = AsyncCassandraSession(mock_session) - - error = InvalidRequest("Invalid CQL syntax") - mock_session.execute_async.return_value = self.create_error_future(error) - - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute("SELCT * FROM test") # Typo in SELECT - - assert "Invalid CQL syntax" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_multiple_error_types_in_sequence(self, mock_session): - """ - Test handling different error types in sequence. - - What this tests: - --------------- - 1. Multiple error types handled - 2. Each preserves its type - 3. No error state pollution - 4. Clean error handling - - Why this matters: - ---------------- - Real applications see various errors: - - Must handle each appropriately - - Error handling can't break - - State must stay clean - - Ensures robust error handling - across all exception types. - """ - async_session = AsyncCassandraSession(mock_session) - - errors = [ - Unavailable( - "Not enough replicas", consistency=1, required_replicas=3, alive_replicas=1 - ), - ReadTimeout("Read timed out"), - InvalidRequest("Invalid query syntax"), # ServerError requires code/message/info - ] - - # Test each error type - for error in errors: - mock_session.execute_async.return_value = self.create_error_future(error) - - with pytest.raises(type(error)): - await async_session.execute("SELECT * FROM test") - - @pytest.mark.asyncio - async def test_error_during_prepared_statement(self, mock_session): - """ - Test error handling during prepared statement execution. - - What this tests: - --------------- - 1. Prepare succeeds, execute fails - 2. Prepared statement errors handled - 3. WriteTimeout during execution - 4. Error details preserved - - Why this matters: - ---------------- - Prepared statements can fail at: - - Preparation time (schema issues) - - Execution time (timeout/failures) - - Both error paths must work correctly - for production reliability. 
- """ - async_session = AsyncCassandraSession(mock_session) - - # Prepare succeeds - prepared = Mock() - prepared.query = "INSERT INTO users (id, name) VALUES (?, ?)" - prepare_future = Mock() - prepare_future.result = Mock(return_value=prepared) - prepare_future.add_callbacks = Mock() - prepare_future.has_more_pages = False - prepare_future.timeout = None - prepare_future.clear_callbacks = Mock() - mock_session.prepare_async.return_value = prepare_future - - stmt = await async_session.prepare("INSERT INTO users (id, name) VALUES (?, ?)") - - # But execution fails with write timeout - from cassandra import WriteType - - error = WriteTimeout("Write timed out", write_type=WriteType.SIMPLE) - error.consistency_level = 1 - error.required_responses = 2 - error.received_responses = 1 - mock_session.execute_async.return_value = self.create_error_future(error) - - with pytest.raises(WriteTimeout): - await async_session.execute(stmt, [1, "test"]) - - @pytest.mark.asyncio - async def test_no_host_available_with_multiple_errors(self, mock_session): - """ - Test NoHostAvailable with different errors per host. - - What this tests: - --------------- - 1. NoHostAvailable aggregates errors - 2. Per-host errors preserved - 3. Different failure modes shown - 4. All error details available - - Why this matters: - ---------------- - NoHostAvailable shows why each host failed: - - Connection refused - - Authentication failed - - Timeout - - Detailed errors essential for - diagnosing cluster-wide issues. - """ - async_session = AsyncCassandraSession(mock_session) - - # Multiple hosts with different failures - host_errors = { - "10.0.0.1": ConnectionException("Connection refused"), - "10.0.0.2": AuthenticationFailed("Bad credentials"), - "10.0.0.3": OperationTimedOut("Connection timeout"), - } - - error = NoHostAvailable("Unable to connect to any servers", host_errors) - mock_session.execute_async.return_value = self.create_error_future(error) - - with pytest.raises(NoHostAvailable) as exc_info: - await async_session.execute("SELECT * FROM test") - - assert len(exc_info.value.errors) == 3 - assert "10.0.0.1" in exc_info.value.errors - assert isinstance(exc_info.value.errors["10.0.0.2"], AuthenticationFailed) diff --git a/tests/unit/test_protocol_version_validation.py b/tests/unit/test_protocol_version_validation.py deleted file mode 100644 index 21a7c9e..0000000 --- a/tests/unit/test_protocol_version_validation.py +++ /dev/null @@ -1,320 +0,0 @@ -""" -Unit tests for protocol version validation. - -These tests ensure protocol version validation happens immediately at -configuration time without requiring a real Cassandra connection. - -Test Organization: -================== -1. Legacy Protocol Rejection - v1, v2, v3 not supported -2. Protocol v4 - Rejected with cloud provider guidance -3. Modern Protocols - v5, v6+ accepted -4. Auto-negotiation - No version specified allowed -5. Error Messages - Clear guidance for upgrades - -Key Testing Principles: -====================== -- Fail fast at configuration time -- Provide clear upgrade guidance -- Support future protocol versions -- Help users migrate from legacy versions -""" - -import pytest - -from async_cassandra import AsyncCluster -from async_cassandra.exceptions import ConfigurationError - - -class TestProtocolVersionValidation: - """Test protocol version validation at configuration time.""" - - def test_protocol_v1_rejected(self): - """ - Protocol version 1 should be rejected immediately. - - What this tests: - --------------- - 1. 
Protocol v1 raises ConfigurationError - 2. Error happens at configuration time - 3. No connection attempt made - 4. Clear error message - - Why this matters: - ---------------- - Protocol v1 is ancient (Cassandra 1.2): - - Lacks modern features - - Security vulnerabilities - - No async support - - Failing fast prevents confusing - runtime errors later. - """ - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster(contact_points=["localhost"], protocol_version=1) - - assert "Protocol version 1 is not supported" in str(exc_info.value) - - def test_protocol_v2_rejected(self): - """ - Protocol version 2 should be rejected immediately. - - What this tests: - --------------- - 1. Protocol v2 raises ConfigurationError - 2. Consistent with v1 rejection - 3. Clear not supported message - 4. No connection attempted - - Why this matters: - ---------------- - Protocol v2 (Cassandra 2.0) lacks: - - Necessary async features - - Modern authentication - - Performance optimizations - - async-cassandra needs v5+ features. - """ - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster(contact_points=["localhost"], protocol_version=2) - - assert "Protocol version 2 is not supported" in str(exc_info.value) - - def test_protocol_v3_rejected(self): - """ - Protocol version 3 should be rejected immediately. - - What this tests: - --------------- - 1. Protocol v3 raises ConfigurationError - 2. Even though v3 is common - 3. Clear rejection message - 4. Fail at configuration - - Why this matters: - ---------------- - Protocol v3 (Cassandra 2.1) is common but: - - Missing required async features - - No continuous paging - - Limited result metadata - - Many users on v3 need clear - upgrade guidance. - """ - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster(contact_points=["localhost"], protocol_version=3) - - assert "Protocol version 3 is not supported" in str(exc_info.value) - - def test_protocol_v4_rejected_with_guidance(self): - """ - Protocol version 4 should be rejected with cloud provider guidance. - - What this tests: - --------------- - 1. Protocol v4 rejected despite being modern - 2. Special cloud provider guidance - 3. Helps managed service users - 4. Clear next steps - - Why this matters: - ---------------- - Protocol v4 (Cassandra 3.0) is tricky: - - Some cloud providers stuck on v4 - - Users need provider-specific help - - v5 adds critical async features - - Guidance helps users navigate - cloud provider limitations. - """ - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster(contact_points=["localhost"], protocol_version=4) - - error_msg = str(exc_info.value) - assert "Protocol version 4 is not supported" in error_msg - assert "cloud provider" in error_msg - assert "check their documentation" in error_msg - - def test_protocol_v5_accepted(self): - """ - Protocol version 5 should be accepted. - - What this tests: - --------------- - 1. Protocol v5 configuration succeeds - 2. Minimum supported version - 3. No errors at config time - 4. Cluster object created - - Why this matters: - ---------------- - Protocol v5 (Cassandra 4.0) provides: - - Required async features - - Better streaming - - Improved performance - - This is the minimum version - for async-cassandra. - """ - # Should not raise an exception - cluster = AsyncCluster(contact_points=["localhost"], protocol_version=5) - assert cluster is not None - - def test_protocol_v6_accepted(self): - """ - Protocol version 6 should be accepted (even if beta). - - What this tests: - --------------- - 1. 
Protocol v6 configuration allowed - 2. Beta protocols accepted - 3. Forward compatibility - 4. No artificial limits - - Why this matters: - ---------------- - Protocol v6 (Cassandra 5.0) adds: - - Vector search features - - Improved metadata - - Performance enhancements - - Users testing new features - shouldn't be blocked. - """ - # Should not raise an exception at configuration time - cluster = AsyncCluster(contact_points=["localhost"], protocol_version=6) - assert cluster is not None - - def test_future_protocol_accepted(self): - """ - Future protocol versions should be accepted for forward compatibility. - - What this tests: - --------------- - 1. Unknown versions accepted - 2. Forward compatibility maintained - 3. No hardcoded upper limit - 4. Future-proof design - - Why this matters: - ---------------- - Future protocols will add features: - - Don't block early adopters - - Allow testing new versions - - Avoid forced upgrades - - The driver should work with - future Cassandra versions. - """ - # Should not raise an exception - cluster = AsyncCluster(contact_points=["localhost"], protocol_version=7) - assert cluster is not None - - def test_no_protocol_version_accepted(self): - """ - No protocol version specified should be accepted (auto-negotiation). - - What this tests: - --------------- - 1. Protocol version optional - 2. Auto-negotiation supported - 3. Driver picks best version - 4. Simplifies configuration - - Why this matters: - ---------------- - Auto-negotiation benefits: - - Works across versions - - Picks optimal protocol - - Reduces configuration errors - - Most users should use - auto-negotiation. - """ - # Should not raise an exception - cluster = AsyncCluster(contact_points=["localhost"]) - assert cluster is not None - - def test_auth_with_legacy_protocol_rejected(self): - """ - Authentication with legacy protocol should fail immediately. - - What this tests: - --------------- - 1. Auth + legacy protocol rejected - 2. create_with_auth validates protocol - 3. Consistent validation everywhere - 4. Clear error message - - Why this matters: - ---------------- - Legacy protocols + auth problematic: - - Security vulnerabilities - - Missing auth features - - Incompatible mechanisms - - Prevent insecure configurations - at setup time. - """ - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster.create_with_auth( - contact_points=["localhost"], username="user", password="pass", protocol_version=3 - ) - - assert "Protocol version 3 is not supported" in str(exc_info.value) - - def test_migration_guidance_for_v4(self): - """ - Protocol v4 error should include migration guidance. - - What this tests: - --------------- - 1. v4 error includes specifics - 2. Mentions Cassandra 4.0 - 3. Release date provided - 4. Clear upgrade path - - Why this matters: - ---------------- - v4 users need specific help: - - Many on Cassandra 3.x - - Upgrade path exists - - Time-based guidance helps - - Actionable errors reduce - support burden. - """ - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster(contact_points=["localhost"], protocol_version=4) - - error_msg = str(exc_info.value) - assert "async-cassandra requires CQL protocol v5" in error_msg - assert "Cassandra 4.0 (released July 2021)" in error_msg - - def test_error_message_includes_upgrade_path(self): - """ - Legacy protocol errors should include upgrade path. - - What this tests: - --------------- - 1. Errors mention upgrade - 2. Target version specified (4.0+) - 3. Actionable guidance - 4. 
Not just "not supported" - - Why this matters: - ---------------- - Good error messages: - - Guide users to solution - - Reduce confusion - - Speed up migration - - Users need to know both - problem AND solution. - """ - with pytest.raises(ConfigurationError) as exc_info: - AsyncCluster(contact_points=["localhost"], protocol_version=3) - - error_msg = str(exc_info.value) - assert "upgrade" in error_msg.lower() - assert "4.0+" in error_msg diff --git a/tests/unit/test_race_conditions.py b/tests/unit/test_race_conditions.py deleted file mode 100644 index 8c17c99..0000000 --- a/tests/unit/test_race_conditions.py +++ /dev/null @@ -1,545 +0,0 @@ -"""Race condition and deadlock prevention tests. - -This module tests for various race conditions including TOCTOU issues, -callback deadlocks, and concurrent access patterns. -""" - -import asyncio -import threading -import time -from unittest.mock import Mock - -import pytest - -from async_cassandra import AsyncCassandraSession as AsyncSession -from async_cassandra.result import AsyncResultHandler - - -def create_mock_response_future(rows=None, has_more_pages=False): - """Helper to create a properly configured mock ResponseFuture.""" - mock_future = Mock() - mock_future.has_more_pages = has_more_pages - mock_future.timeout = None # Avoid comparison issues - mock_future.add_callbacks = Mock() - - def handle_callbacks(callback=None, errback=None): - if callback: - callback(rows if rows is not None else []) - - mock_future.add_callbacks.side_effect = handle_callbacks - return mock_future - - -class TestRaceConditions: - """Test race conditions and thread safety.""" - - @pytest.mark.resilience - @pytest.mark.critical - async def test_toctou_event_loop_check(self): - """ - Test Time-of-Check-Time-of-Use race in event loop handling. - - What this tests: - --------------- - 1. Thread-safe event loop access from multiple threads - 2. Race conditions in get_or_create_event_loop utility - 3. Concurrent thread access to event loop creation - 4. Proper synchronization in event loop management - - Why this matters: - ---------------- - - Production systems often have multiple threads accessing async code - - TOCTOU bugs can cause crashes or incorrect behavior - - Event loop corruption can break entire applications - - Critical for mixed sync/async codebases - - Additional context: - --------------------------------- - - Simulates 20 concurrent threads accessing event loop - - Common pattern in web servers with thread pools - - Tests defensive programming in utils module - """ - from async_cassandra.utils import get_or_create_event_loop - - # Simulate rapid concurrent access from multiple threads - results = [] - errors = [] - - def worker(): - try: - loop = get_or_create_event_loop() - results.append(loop) - except Exception as e: - errors.append(e) - - # Create many threads to increase chance of race - threads = [] - for _ in range(20): - thread = threading.Thread(target=worker) - threads.append(thread) - - # Start all threads at once - for thread in threads: - thread.start() - - # Wait for completion - for thread in threads: - thread.join() - - # Should have no errors - assert len(errors) == 0 - # Each thread should get a valid event loop - assert len(results) == 20 - assert all(loop is not None for loop in results) - - @pytest.mark.resilience - async def test_callback_registration_race(self): - """ - Test race condition in callback registration. - - What this tests: - --------------- - 1. Thread-safe callback registration in AsyncResultHandler - 2. 
Race between success and error callbacks - 3. Proper result state management - 4. Only one callback should win in a race - - Why this matters: - ---------------- - - Callbacks from driver happen on different threads - - Race conditions can cause undefined behavior - - Result state must be consistent - - Prevents duplicate result processing - - Additional context: - --------------------------------- - - Driver callbacks are inherently multi-threaded - - Tests internal synchronization mechanisms - - Simulates real driver callback patterns - """ - # Create a mock ResponseFuture - mock_future = Mock() - mock_future.has_more_pages = False - mock_future.timeout = None - mock_future.add_callbacks = Mock() - - handler = AsyncResultHandler(mock_future) - results = [] - - # Try to register callbacks from multiple threads - def register_success(): - handler._handle_page(["success"]) - results.append("success") - - def register_error(): - handler._handle_error(Exception("error")) - results.append("error") - - # Start threads that race to set result - t1 = threading.Thread(target=register_success) - t2 = threading.Thread(target=register_error) - - t1.start() - t2.start() - - t1.join() - t2.join() - - # Only one should win - try: - result = await handler.get_result() - assert result._rows == ["success"] - assert results.count("success") >= 1 - except Exception as e: - assert str(e) == "error" - assert results.count("error") >= 1 - - @pytest.mark.resilience - @pytest.mark.critical - @pytest.mark.timeout(10) # Add timeout to prevent hanging - async def test_concurrent_session_operations(self): - """ - Test concurrent operations on same session. - - What this tests: - --------------- - 1. Thread-safe session operations under high concurrency - 2. No lost updates or race conditions in query execution - 3. Proper result isolation between concurrent queries - 4. Sequential counter integrity across 50 concurrent operations - - Why this matters: - ---------------- - - Production apps execute many queries concurrently - - Session must handle concurrent access safely - - Lost queries can cause data inconsistency - - Common pattern in web applications - - Additional context: - --------------------------------- - - Simulates 50 concurrent SELECT queries - - Verifies each query gets unique result - - Tests thread pool handling under load - """ - mock_session = Mock() - call_count = 0 - - def thread_safe_execute(*args, **kwargs): - nonlocal call_count - # Simulate some work - time.sleep(0.001) - call_count += 1 - - # Capture the count at creation time - current_count = call_count - return create_mock_response_future([{"count": current_count}]) - - mock_session.execute_async.side_effect = thread_safe_execute - - async_session = AsyncSession(mock_session) - - # Execute many queries concurrently - tasks = [] - for i in range(50): - task = asyncio.create_task(async_session.execute(f"SELECT COUNT(*) FROM table{i}")) - tasks.append(task) - - results = await asyncio.gather(*tasks) - - # All should complete - assert len(results) == 50 - assert call_count == 50 - - # Results should have sequential counts (no lost updates) - counts = sorted([r._rows[0]["count"] for r in results]) - assert counts == list(range(1, 51)) - - @pytest.mark.resilience - @pytest.mark.timeout(10) # Add timeout to prevent hanging - async def test_page_callback_deadlock_prevention(self): - """ - Test prevention of deadlock in paging callbacks. - - What this tests: - --------------- - 1. Independent iteration state for concurrent AsyncResultSet usage - 2. 
No deadlock when multiple coroutines iterate same result - 3. Sequential iteration works correctly - 4. Each iterator maintains its own position - - Why this matters: - ---------------- - - Paging through large results is common - - Deadlocks can hang entire applications - - Multiple consumers may process same result set - - Critical for streaming large datasets - - Additional context: - --------------------------------- - - Tests both concurrent and sequential iteration - - Each AsyncResultSet has independent state - - Simulates real paging scenarios - """ - from async_cassandra.result import AsyncResultSet - - # Test that each AsyncResultSet has its own iteration state - rows = [1, 2, 3, 4, 5, 6] - - # Create separate result sets for each concurrent iteration - async def collect_results(): - # Each task gets its own AsyncResultSet instance - result_set = AsyncResultSet(rows.copy()) - collected = [] - async for row in result_set: - # Simulate some async work - await asyncio.sleep(0.001) - collected.append(row) - return collected - - # Run multiple iterations concurrently - tasks = [asyncio.create_task(collect_results()) for _ in range(3)] - - # Wait for all to complete - all_results = await asyncio.gather(*tasks) - - # Each iteration should get all rows - for result in all_results: - assert result == [1, 2, 3, 4, 5, 6] - - # Also test that sequential iterations work correctly - single_result = AsyncResultSet([1, 2, 3]) - first_iteration = [] - async for row in single_result: - first_iteration.append(row) - - second_iteration = [] - async for row in single_result: - second_iteration.append(row) - - assert first_iteration == [1, 2, 3] - assert second_iteration == [1, 2, 3] - - @pytest.mark.resilience - @pytest.mark.timeout(15) # Increase timeout to account for 5s shutdown delay - async def test_session_close_during_query(self): - """ - Test closing session while queries are in flight. - - What this tests: - --------------- - 1. Graceful session closure with active queries - 2. Proper cleanup during 5-second shutdown delay - 3. In-flight queries complete before final closure - 4. 
No resource leaks or hanging queries - - Why this matters: - ---------------- - - Applications need graceful shutdown - - In-flight queries shouldn't be lost - - Resource cleanup is critical - - Prevents connection leaks in production - - Additional context: - --------------------------------- - - Tests 5-second graceful shutdown period - - Simulates real shutdown scenarios - - Critical for container deployments - """ - mock_session = Mock() - query_started = asyncio.Event() - query_can_proceed = asyncio.Event() - shutdown_called = asyncio.Event() - - def blocking_execute(*args): - # Create a mock ResponseFuture that blocks - mock_future = Mock() - mock_future.has_more_pages = False - mock_future.timeout = None # Avoid comparison issues - mock_future.add_callbacks = Mock() - - def handle_callbacks(callback=None, errback=None): - async def wait_and_callback(): - query_started.set() - await query_can_proceed.wait() - if callback: - callback([]) - - asyncio.create_task(wait_and_callback()) - - mock_future.add_callbacks.side_effect = handle_callbacks - return mock_future - - mock_session.execute_async.side_effect = blocking_execute - - def mock_shutdown(): - shutdown_called.set() - query_can_proceed.set() - - mock_session.shutdown = mock_shutdown - - async_session = AsyncSession(mock_session) - - # Start query - query_task = asyncio.create_task(async_session.execute("SELECT * FROM users")) - - # Wait for query to start - await query_started.wait() - - # Start closing session in background (includes 5s delay) - close_task = asyncio.create_task(async_session.close()) - - # Wait for driver shutdown - await shutdown_called.wait() - - # Query should complete during the 5s delay - await query_task - - # Wait for close to fully complete - await close_task - - # Session should be closed - assert async_session.is_closed - - @pytest.mark.resilience - @pytest.mark.critical - @pytest.mark.timeout(10) # Add timeout to prevent hanging - async def test_thread_pool_saturation(self): - """ - Test behavior when thread pool is saturated. - - What this tests: - --------------- - 1. Behavior with more queries than thread pool size - 2. No deadlock when thread pool is exhausted - 3. All queries eventually complete - 4. 
Async execution handles thread pool limits gracefully - - Why this matters: - ---------------- - - Production loads can exceed thread pool capacity - - Deadlocks under load are catastrophic - - Must handle burst traffic gracefully - - Common issue in high-traffic applications - - Additional context: - --------------------------------- - - Uses 2-thread pool with 6 concurrent queries - - Tests 3x oversubscription scenario - - Verifies async model prevents blocking - """ - from async_cassandra.cluster import AsyncCluster - - # Create cluster with small thread pool - cluster = AsyncCluster(executor_threads=2) - - # Mock the underlying cluster - mock_cluster = Mock() - mock_session = Mock() - - # Simulate slow queries - def slow_query(*args): - # Create a mock ResponseFuture that simulates delay - mock_future = Mock() - mock_future.has_more_pages = False - mock_future.timeout = None # Avoid comparison issues - mock_future.add_callbacks = Mock() - - def handle_callbacks(callback=None, errback=None): - # Call callback immediately to avoid empty result issue - if callback: - callback([{"id": 1}]) - - mock_future.add_callbacks.side_effect = handle_callbacks - return mock_future - - mock_session.execute_async.side_effect = slow_query - mock_cluster.connect.return_value = mock_session - - cluster._cluster = mock_cluster - cluster._cluster.protocol_version = 5 # Mock protocol version - - session = await cluster.connect() - - # Submit more queries than thread pool size - tasks = [] - for i in range(6): # 3x thread pool size - task = asyncio.create_task(session.execute(f"SELECT * FROM table{i}")) - tasks.append(task) - - # All should eventually complete - results = await asyncio.gather(*tasks) - - assert len(results) == 6 - # With async execution, all queries can run concurrently regardless of thread pool - # Just verify they all completed - assert all(result.rows == [{"id": 1}] for result in results) - - @pytest.mark.resilience - @pytest.mark.timeout(5) # Add timeout to prevent hanging - async def test_event_loop_callback_ordering(self): - """ - Test that callbacks maintain order when scheduled. - - What this tests: - --------------- - 1. Thread-safe callback scheduling to event loop - 2. All callbacks execute despite concurrent scheduling - 3. No lost callbacks under concurrent access - 4. 
safe_call_soon_threadsafe utility correctness - - Why this matters: - ---------------- - - Driver callbacks come from multiple threads - - Lost callbacks mean lost query results - - Order preservation prevents race conditions - - Foundation of async-to-sync bridge - - Additional context: - --------------------------------- - - Tests 10 concurrent threads scheduling callbacks - - Verifies thread-safe event loop integration - - Core to driver callback handling - """ - from async_cassandra.utils import safe_call_soon_threadsafe - - results = [] - loop = asyncio.get_running_loop() - - # Schedule callbacks from different threads - def schedule_callback(value): - safe_call_soon_threadsafe(loop, results.append, value) - - threads = [] - for i in range(10): - thread = threading.Thread(target=schedule_callback, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads - for thread in threads: - thread.join() - - # Give callbacks time to execute - await asyncio.sleep(0.1) - - # All callbacks should have executed - assert len(results) == 10 - assert sorted(results) == list(range(10)) - - @pytest.mark.resilience - @pytest.mark.timeout(10) # Add timeout to prevent hanging - async def test_prepared_statement_concurrent_access(self): - """ - Test concurrent access to prepared statements. - - What this tests: - --------------- - 1. Thread-safe prepared statement creation - 2. Multiple coroutines preparing same statement - 3. No corruption during concurrent preparation - 4. All coroutines receive valid prepared statement - - Why this matters: - ---------------- - - Prepared statements are performance critical - - Concurrent preparation is common at startup - - Statement corruption causes query failures - - Caching optimization opportunity identified - - Additional context: - --------------------------------- - - Currently allows duplicate preparation - - Future optimization: statement caching - - Tests current thread-safe behavior - """ - mock_session = Mock() - mock_prepared = Mock() - - prepare_count = 0 - - def prepare_side_effect(*args): - nonlocal prepare_count - prepare_count += 1 - time.sleep(0.01) # Simulate preparation time - return mock_prepared - - mock_session.prepare.side_effect = prepare_side_effect - - # Create a mock ResponseFuture for execute_async - mock_session.execute_async.return_value = create_mock_response_future([]) - - async_session = AsyncSession(mock_session) - - # Many coroutines try to prepare same statement - tasks = [] - for _ in range(10): - task = asyncio.create_task(async_session.prepare("SELECT * FROM users WHERE id = ?")) - tasks.append(task) - - prepared_statements = await asyncio.gather(*tasks) - - # All should get the same prepared statement - assert all(ps == mock_prepared for ps in prepared_statements) - # But prepare should only be called once (would need caching impl) - # For now, it's called multiple times - assert prepare_count == 10 diff --git a/tests/unit/test_response_future_cleanup.py b/tests/unit/test_response_future_cleanup.py deleted file mode 100644 index 11d679e..0000000 --- a/tests/unit/test_response_future_cleanup.py +++ /dev/null @@ -1,380 +0,0 @@ -""" -Unit tests for explicit cleanup of ResponseFuture callbacks on error. 
-""" - -import asyncio -from unittest.mock import Mock - -import pytest - -from async_cassandra.exceptions import ConnectionError -from async_cassandra.result import AsyncResultHandler -from async_cassandra.session import AsyncCassandraSession -from async_cassandra.streaming import AsyncStreamingResultSet - - -@pytest.mark.asyncio -class TestResponseFutureCleanup: - """Test explicit cleanup of ResponseFuture callbacks.""" - - async def test_handler_cleanup_on_error(self): - """ - Test that callbacks are cleaned up when handler encounters error. - - What this tests: - --------------- - 1. Callbacks cleared on error - 2. ResponseFuture cleanup called - 3. No dangling references - 4. Error still propagated - - Why this matters: - ---------------- - Callback cleanup prevents: - - Memory leaks - - Circular references - - Ghost callbacks firing - - Critical for long-running apps - with many queries. - """ - # Create mock response future - response_future = Mock() - response_future.has_more_pages = True # Prevent immediate completion - response_future.add_callbacks = Mock() - response_future.timeout = None - - # Track if callbacks were cleared - callbacks_cleared = False - - def mock_clear_callbacks(): - nonlocal callbacks_cleared - callbacks_cleared = True - - response_future.clear_callbacks = mock_clear_callbacks - - # Create handler - handler = AsyncResultHandler(response_future) - - # Start get_result - result_task = asyncio.create_task(handler.get_result()) - await asyncio.sleep(0.01) # Let it set up - - # Trigger error callback - call_args = response_future.add_callbacks.call_args - if call_args: - errback = call_args.kwargs.get("errback") - if errback: - errback(Exception("Test error")) - - # Should get the error - with pytest.raises(Exception, match="Test error"): - await result_task - - # Callbacks should be cleared - assert callbacks_cleared, "Callbacks were not cleared on error" - - async def test_streaming_cleanup_on_error(self): - """ - Test that streaming callbacks are cleaned up on error. - - What this tests: - --------------- - 1. Streaming error triggers cleanup - 2. Callbacks cleared properly - 3. Error propagated to iterator - 4. Resources freed - - Why this matters: - ---------------- - Streaming holds more resources: - - Page callbacks - - Event handlers - - Buffer memory - - Must clean up even on partial - stream consumption. 
- """ - # Create mock response future - response_future = Mock() - response_future.has_more_pages = True - response_future.add_callbacks = Mock() - response_future.start_fetching_next_page = Mock() - - # Track if callbacks were cleared - callbacks_cleared = False - - def mock_clear_callbacks(): - nonlocal callbacks_cleared - callbacks_cleared = True - - response_future.clear_callbacks = mock_clear_callbacks - - # Create streaming result set - result_set = AsyncStreamingResultSet(response_future) - - # Get the registered callbacks - call_args = response_future.add_callbacks.call_args - callback = call_args.kwargs.get("callback") if call_args else None - errback = call_args.kwargs.get("errback") if call_args else None - - # First trigger initial page callback to set up state - callback([]) # Empty initial page - - # Now trigger error for streaming - errback(Exception("Streaming error")) - - # Try to iterate - should get error immediately - error_raised = False - try: - async for _ in result_set: - pass - except Exception as e: - error_raised = True - assert str(e) == "Streaming error" - - assert error_raised, "No error raised during iteration" - - # Callbacks should be cleared - assert callbacks_cleared, "Callbacks were not cleared on streaming error" - - async def test_handler_cleanup_on_timeout(self): - """ - Test cleanup when operation times out. - - What this tests: - --------------- - 1. Timeout triggers cleanup - 2. Callbacks cleared - 3. TimeoutError raised - 4. No hanging callbacks - - Why this matters: - ---------------- - Timeouts common in production: - - Network issues - - Overloaded servers - - Slow queries - - Must clean up to prevent - resource accumulation. - """ - # Create mock response future that never completes - response_future = Mock() - response_future.has_more_pages = True # Prevent immediate completion - response_future.add_callbacks = Mock() - response_future.timeout = 0.1 # Short timeout - - # Track if callbacks were cleared - callbacks_cleared = False - - def mock_clear_callbacks(): - nonlocal callbacks_cleared - callbacks_cleared = True - - response_future.clear_callbacks = mock_clear_callbacks - - # Create handler - handler = AsyncResultHandler(response_future) - - # Should timeout - with pytest.raises(asyncio.TimeoutError): - await handler.get_result() - - # Callbacks should be cleared - assert callbacks_cleared, "Callbacks were not cleared on timeout" - - async def test_no_memory_leak_on_error(self): - """ - Test that error handling cleans up properly to prevent memory leaks. - - What this tests: - --------------- - 1. Error path cleans callbacks - 2. Internal state cleaned - 3. Future marked done - 4. Circular refs broken - - Why this matters: - ---------------- - Memory leaks kill apps: - - Gradual memory growth - - Eventually OOM - - Hard to diagnose - - Proper cleanup essential for - production stability. 
- """ - # Create response future - response_future = Mock() - response_future.has_more_pages = True # Prevent immediate completion - response_future.add_callbacks = Mock() - response_future.timeout = None - response_future.clear_callbacks = Mock() - - # Create handler - handler = AsyncResultHandler(response_future) - - # Start task - task = asyncio.create_task(handler.get_result()) - await asyncio.sleep(0.01) - - # Trigger error - call_args = response_future.add_callbacks.call_args - if call_args: - errback = call_args.kwargs.get("errback") - if errback: - errback(Exception("Memory test")) - - # Get error - with pytest.raises(Exception): - await task - - # Verify that callbacks were cleared on error - # This is the important part - breaking circular references - assert response_future.clear_callbacks.called - assert response_future.clear_callbacks.call_count >= 1 - - # Also verify the handler cleans up its internal state - assert handler._future is not None # Future was created - assert handler._future.done() # Future completed with error - - async def test_session_cleanup_on_close(self): - """ - Test that session cleans up callbacks when closed. - - What this tests: - --------------- - 1. Session close prevents new ops - 2. Existing ops complete - 3. New ops raise ConnectionError - 4. Clean shutdown behavior - - Why this matters: - ---------------- - Graceful shutdown requires: - - Complete in-flight queries - - Reject new queries - - Clean up resources - - Prevents data loss and - connection leaks. - """ - # Create mock session - mock_session = Mock() - - # Create separate futures for each operation - futures_created = [] - - def create_future(*args, **kwargs): - future = Mock() - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - - # Store callbacks when registered - def register_callbacks(callback=None, errback=None): - future._callback = callback - future._errback = errback - - future.add_callbacks = Mock(side_effect=register_callbacks) - futures_created.append(future) - return future - - mock_session.execute_async = Mock(side_effect=create_future) - mock_session.shutdown = Mock() - - # Create async session - async_session = AsyncCassandraSession(mock_session) - - # Start multiple operations - tasks = [] - for i in range(3): - task = asyncio.create_task(async_session.execute(f"SELECT {i}")) - tasks.append(task) - - await asyncio.sleep(0.01) # Let them start - - # Complete the operations by triggering callbacks - for i, future in enumerate(futures_created): - if hasattr(future, "_callback") and future._callback: - future._callback([f"row{i}"]) - - # Wait for all tasks to complete - results = await asyncio.gather(*tasks) - - # Now close the session - await async_session.close() - - # Verify all operations completed successfully - assert len(results) == 3 - - # New operations should fail - with pytest.raises(ConnectionError): - await async_session.execute("SELECT after close") - - async def test_cleanup_prevents_callback_execution(self): - """ - Test that cleaned callbacks don't execute. - - What this tests: - --------------- - 1. Cleared callbacks don't fire - 2. No zombie callbacks - 3. Cleanup is effective - 4. State properly cleared - - Why this matters: - ---------------- - Zombie callbacks cause: - - Unexpected behavior - - Race conditions - - Data corruption - - Cleanup must truly prevent - future callback execution. 
- """ - # Create response future - response_future = Mock() - response_future.has_more_pages = False - response_future.add_callbacks = Mock() - response_future.timeout = None - - # Track callback execution - callback_executed = False - original_callback = None - - def track_add_callbacks(callback=None, errback=None): - nonlocal original_callback - original_callback = callback - - response_future.add_callbacks = track_add_callbacks - - def clear_callbacks(): - nonlocal original_callback - original_callback = None # Simulate clearing - - response_future.clear_callbacks = clear_callbacks - - # Create handler - handler = AsyncResultHandler(response_future) - - # Start task - task = asyncio.create_task(handler.get_result()) - await asyncio.sleep(0.01) - - # Clear callbacks (simulating cleanup) - response_future.clear_callbacks() - - # Try to trigger callback - should have no effect - if original_callback: - callback_executed = True - - # Cancel task to clean up - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - assert not callback_executed, "Callback executed after cleanup" diff --git a/tests/unit/test_result.py b/tests/unit/test_result.py deleted file mode 100644 index 6f29b56..0000000 --- a/tests/unit/test_result.py +++ /dev/null @@ -1,479 +0,0 @@ -""" -Unit tests for async result handling. - -This module tests the core result handling mechanisms that convert -Cassandra driver's callback-based results into Python async/await -compatible results. - -Test Organization: -================== -- TestAsyncResultHandler: Tests the callback-to-async conversion -- TestAsyncResultSet: Tests the result set wrapper functionality - -Key Testing Focus: -================== -1. Single and multi-page result handling -2. Error propagation from callbacks -3. Async iteration protocol -4. Result set convenience methods (one(), all()) -5. Empty result handling -""" - -from unittest.mock import Mock - -import pytest - -from async_cassandra.result import AsyncResultHandler, AsyncResultSet - - -class TestAsyncResultHandler: - """ - Test cases for AsyncResultHandler. - - AsyncResultHandler is the bridge between Cassandra driver's callback-based - ResponseFuture and Python's async/await. It registers callbacks that get - called when results are ready and converts them to awaitable results. - """ - - @pytest.fixture - def mock_response_future(self): - """ - Create a mock ResponseFuture. - - ResponseFuture is the driver's async result object that uses - callbacks. We mock it to test our handler without real queries. - """ - future = Mock() - future.has_more_pages = False - future.add_callbacks = Mock() - future.timeout = None # Add timeout attribute for new timeout handling - return future - - @pytest.mark.asyncio - async def test_single_page_result(self, mock_response_future): - """ - Test handling single page of results. - - What this tests: - --------------- - 1. Handler correctly receives page callback - 2. Single page results are wrapped in AsyncResultSet - 3. get_result() returns when page is complete - 4. No pagination logic triggered for single page - - Why this matters: - ---------------- - Most queries return a single page of results. This is the - common case that must work efficiently: - - Small result sets - - Queries with LIMIT - - Single row lookups - - The handler should not add overhead for simple cases. 
- """ - handler = AsyncResultHandler(mock_response_future) - - # Simulate successful page callback - test_rows = [{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}] - handler._handle_page(test_rows) - - # Get result - result = await handler.get_result() - - assert isinstance(result, AsyncResultSet) - assert len(result) == 2 - assert result.rows == test_rows - - @pytest.mark.asyncio - async def test_multi_page_result(self, mock_response_future): - """ - Test handling multiple pages of results. - - What this tests: - --------------- - 1. Multi-page results are handled correctly - 2. Next page fetch is triggered automatically - 3. All pages are accumulated into final result - 4. has_more_pages flag controls pagination - - Why this matters: - ---------------- - Large result sets are split into pages to: - - Prevent memory exhaustion - - Allow incremental processing - - Control network bandwidth - - The handler must: - - Automatically fetch all pages - - Accumulate results correctly - - Handle page boundaries transparently - - Common with: - - Large table scans - - No LIMIT queries - - Analytics workloads - """ - # Configure mock for multiple pages - mock_response_future.has_more_pages = True - mock_response_future.start_fetching_next_page = Mock() - - handler = AsyncResultHandler(mock_response_future) - - # First page - first_page = [{"id": 1}, {"id": 2}] - handler._handle_page(first_page) - - # Verify next page fetch was triggered - mock_response_future.start_fetching_next_page.assert_called_once() - - # Second page (final) - mock_response_future.has_more_pages = False - second_page = [{"id": 3}, {"id": 4}] - handler._handle_page(second_page) - - # Get result - result = await handler.get_result() - - assert len(result) == 4 - assert result.rows == first_page + second_page - - @pytest.mark.asyncio - async def test_error_handling(self, mock_response_future): - """ - Test error handling in result handler. - - What this tests: - --------------- - 1. Errors from callbacks are captured - 2. Errors are propagated when get_result() is called - 3. Original exception is preserved - 4. No results are returned on error - - Why this matters: - ---------------- - Many things can go wrong during query execution: - - Network failures - - Query syntax errors - - Timeout exceptions - - Server overload - - The handler must: - - Capture errors from callbacks - - Propagate them at the right time - - Preserve error details for debugging - - Without proper error handling, errors could be: - - Silently swallowed - - Raised at callback time (wrong thread) - - Lost without stack trace - """ - handler = AsyncResultHandler(mock_response_future) - - # Simulate error callback - test_error = Exception("Query failed") - handler._handle_error(test_error) - - # Should raise the exception - with pytest.raises(Exception) as exc_info: - await handler.get_result() - - assert str(exc_info.value) == "Query failed" - - @pytest.mark.asyncio - async def test_callback_registration(self, mock_response_future): - """ - Test that callbacks are properly registered. - - What this tests: - --------------- - 1. Callbacks are registered on ResponseFuture - 2. Both success and error callbacks are set - 3. Correct handler methods are used - 4. 
Registration happens during init - - Why this matters: - ---------------- - The callback registration is the critical link between - driver and our async wrapper: - - Must register before results arrive - - Must handle both success and error paths - - Must use correct method signatures - - If registration fails: - - Results would never arrive - - Queries would hang forever - - Errors would be lost - - This test ensures the "wiring" is correct. - """ - handler = AsyncResultHandler(mock_response_future) - - # Verify callbacks were registered - mock_response_future.add_callbacks.assert_called_once() - call_args = mock_response_future.add_callbacks.call_args - - assert call_args.kwargs["callback"] == handler._handle_page - assert call_args.kwargs["errback"] == handler._handle_error - - -class TestAsyncResultSet: - """ - Test cases for AsyncResultSet. - - AsyncResultSet wraps query results to provide async iteration - and convenience methods. It's what users interact with after - executing a query. - """ - - @pytest.fixture - def sample_rows(self): - """ - Create sample row data. - - Simulates typical query results with multiple rows - and columns. Used across multiple tests. - """ - return [ - {"id": 1, "name": "Alice", "age": 30}, - {"id": 2, "name": "Bob", "age": 25}, - {"id": 3, "name": "Charlie", "age": 35}, - ] - - @pytest.mark.asyncio - async def test_async_iteration(self, sample_rows): - """ - Test async iteration over result set. - - What this tests: - --------------- - 1. AsyncResultSet supports 'async for' syntax - 2. All rows are yielded in order - 3. Iteration completes normally - 4. Each row is accessible during iteration - - Why this matters: - ---------------- - Async iteration is the primary way to process results: - ```python - async for row in result: - await process_row(row) - ``` - - This enables: - - Non-blocking result processing - - Integration with async frameworks - - Natural Python syntax - - Without this, users would need callbacks or blocking calls. - """ - result_set = AsyncResultSet(sample_rows) - - collected_rows = [] - async for row in result_set: - collected_rows.append(row) - - assert collected_rows == sample_rows - - def test_len(self, sample_rows): - """ - Test length of result set. - - What this tests: - --------------- - 1. len() works on AsyncResultSet - 2. Returns correct count of rows - 3. Works with standard Python functions - - Why this matters: - ---------------- - Users expect Pythonic behavior: - - if len(result) > 0: - - print(f"Found {len(result)} rows") - - assert len(result) == expected_count - - This makes AsyncResultSet feel like a normal collection. - """ - result_set = AsyncResultSet(sample_rows) - assert len(result_set) == 3 - - def test_one_with_results(self, sample_rows): - """ - Test one() method with results. - - What this tests: - --------------- - 1. one() returns first row when results exist - 2. Only the first row is returned (not a list) - 3. Remaining rows are ignored - - Why this matters: - ---------------- - Common pattern for single-row queries: - ```python - user = result.one() - if user: - print(f"Found user: {user.name}") - ``` - - Useful for: - - Primary key lookups - - COUNT queries - - Existence checks - - Mirrors driver's ResultSet.one() behavior. - """ - result_set = AsyncResultSet(sample_rows) - assert result_set.one() == sample_rows[0] - - def test_one_empty(self): - """ - Test one() method with empty results. - - What this tests: - --------------- - 1. one() returns None for empty results - 2. 
No exception is raised - 3. Safe to use without checking length first - - Why this matters: - ---------------- - Handles the "not found" case gracefully: - ```python - user = result.one() - if not user: - raise NotFoundError("User not found") - ``` - - No need for try/except or length checks. - """ - result_set = AsyncResultSet([]) - assert result_set.one() is None - - def test_all(self, sample_rows): - """ - Test all() method. - - What this tests: - --------------- - 1. all() returns complete list of rows - 2. Original row order is preserved - 3. Returns actual list (not iterator) - - Why this matters: - ---------------- - Sometimes you need all results immediately: - - Converting to JSON - - Passing to templates - - Batch processing - - Convenience method avoids: - ```python - rows = [row async for row in result] # More complex - ``` - """ - result_set = AsyncResultSet(sample_rows) - assert result_set.all() == sample_rows - - def test_rows_property(self, sample_rows): - """ - Test rows property. - - What this tests: - --------------- - 1. Direct access to underlying rows list - 2. Returns same data as all() - 3. Property access (no parentheses) - - Why this matters: - ---------------- - Provides flexibility: - - result.rows for property access - - result.all() for method call - - Both return same data - - Some users prefer property syntax for data access. - """ - result_set = AsyncResultSet(sample_rows) - assert result_set.rows == sample_rows - - @pytest.mark.asyncio - async def test_empty_iteration(self): - """ - Test iteration over empty result set. - - What this tests: - --------------- - 1. Empty result sets can be iterated - 2. No rows are yielded - 3. Iteration completes immediately - 4. No errors or hangs occur - - Why this matters: - ---------------- - Empty results are common and must work correctly: - - No matching rows - - Deleted data - - Fresh tables - - The iteration should complete gracefully without - special handling: - ```python - async for row in result: # Should not error if empty - process(row) - ``` - """ - result_set = AsyncResultSet([]) - - collected_rows = [] - async for row in result_set: - collected_rows.append(row) - - assert collected_rows == [] - - @pytest.mark.asyncio - async def test_multiple_iterations(self, sample_rows): - """ - Test that result set can be iterated multiple times. - - What this tests: - --------------- - 1. Same result set can be iterated repeatedly - 2. Each iteration yields all rows - 3. Order is consistent across iterations - 4. No state corruption between iterations - - Why this matters: - ---------------- - Unlike generators, AsyncResultSet allows re-iteration: - - Processing results multiple ways - - Retry logic after errors - - Debugging (print then process) - - This differs from streaming results which can only - be consumed once. AsyncResultSet holds all data in - memory, allowing multiple passes. 
- - Example use case: - ---------------- - # First pass: validation - async for row in result: - validate(row) - - # Second pass: processing - async for row in result: - await process(row) - """ - result_set = AsyncResultSet(sample_rows) - - # First iteration - first_iter = [] - async for row in result_set: - first_iter.append(row) - - # Second iteration - second_iter = [] - async for row in result_set: - second_iter.append(row) - - assert first_iter == sample_rows - assert second_iter == sample_rows diff --git a/tests/unit/test_results.py b/tests/unit/test_results.py deleted file mode 100644 index 6d3ebd4..0000000 --- a/tests/unit/test_results.py +++ /dev/null @@ -1,437 +0,0 @@ -"""Core result handling tests. - -This module tests AsyncResultHandler and AsyncResultSet functionality, -which are critical for proper async operation of query results. - -Test Organization: -================== -- TestAsyncResultHandler: Core callback-to-async conversion tests -- TestAsyncResultSet: Result collection wrapper tests - -Key Testing Focus: -================== -1. Callback registration and handling -2. Multi-callback safety (duplicate calls) -3. Result set iteration and access patterns -4. Property access and convenience methods -5. Edge cases (empty results, single results) - -Note: This complements test_result.py with additional edge cases. -""" - -from unittest.mock import Mock - -import pytest -from cassandra.cluster import ResponseFuture - -from async_cassandra.result import AsyncResultHandler, AsyncResultSet - - -class TestAsyncResultHandler: - """ - Test AsyncResultHandler for callback-based result handling. - - This class focuses on the core mechanics of converting Cassandra's - callback-based results to Python async/await. It tests edge cases - not covered in test_result.py. - """ - - @pytest.mark.core - @pytest.mark.quick - async def test_init(self): - """ - Test AsyncResultHandler initialization. - - What this tests: - --------------- - 1. Handler stores reference to ResponseFuture - 2. Empty rows list is initialized - 3. Callbacks are registered immediately - 4. Handler is ready to receive results - - Why this matters: - ---------------- - Initialization must happen quickly before results arrive: - - Callbacks must be registered before driver calls them - - State must be initialized to handle results - - No async operations during init (can't await) - - The handler is the critical bridge between sync callbacks - and async/await, so initialization must be bulletproof. - """ - mock_future = Mock(spec=ResponseFuture) - mock_future.add_callbacks = Mock() - - handler = AsyncResultHandler(mock_future) - assert handler.response_future == mock_future - assert handler.rows == [] - mock_future.add_callbacks.assert_called_once() - - @pytest.mark.core - async def test_on_success(self): - """ - Test successful result handling. - - What this tests: - --------------- - 1. Success callback properly receives rows - 2. Rows are stored in the handler - 3. Result future completes with AsyncResultSet - 4. No paging logic for single-page results - - Why this matters: - ---------------- - The success path is the most common case: - - Query executes successfully - - Results arrive via callback - - Must convert to awaitable result - - This tests the happy path that 99% of queries follow. - The callback happens in driver thread, so thread safety - is critical here. 
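In sketch form, the hand-off that keeps this thread-safe (assumes a running event loop and an asyncio future created on that loop):

```python
def deliver_rows(loop, future, rows):
    # Runs on a driver I/O thread: never touch the asyncio future directly;
    # schedule the result onto the owning event loop instead.
    loop.call_soon_threadsafe(future.set_result, rows)
```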
- """ - mock_future = Mock(spec=ResponseFuture) - mock_future.add_callbacks = Mock() - mock_future.has_more_pages = False - - handler = AsyncResultHandler(mock_future) - - # Get result future and simulate success callback - result_future = handler.get_result() - - # Simulate the driver calling our success callback - mock_result = Mock() - mock_result.current_rows = [{"id": 1}, {"id": 2}] - handler._handle_page(mock_result.current_rows) - - result = await result_future - assert isinstance(result, AsyncResultSet) - - @pytest.mark.core - async def test_on_error(self): - """ - Test error handling. - - What this tests: - --------------- - 1. Error callback receives exceptions - 2. Exception is stored and re-raised on await - 3. No result is returned on error - 4. Original exception details preserved - - Why this matters: - ---------------- - Error handling is critical for debugging: - - Network errors - - Query syntax errors - - Timeout errors - - Permission errors - - The error must be: - - Captured from callback thread - - Stored until await - - Re-raised with full details - - Not swallowed or lost - """ - mock_future = Mock(spec=ResponseFuture) - mock_future.add_callbacks = Mock() - - handler = AsyncResultHandler(mock_future) - error = Exception("Test error") - - # Get result future and simulate error callback - result_future = handler.get_result() - handler._handle_error(error) - - with pytest.raises(Exception, match="Test error"): - await result_future - - @pytest.mark.core - @pytest.mark.critical - async def test_multiple_callbacks(self): - """ - Test that multiple success/error calls don't break the handler. - - What this tests: - --------------- - 1. First callback sets the result - 2. Subsequent callbacks are safely ignored - 3. No exceptions from duplicate callbacks - 4. Result remains stable after first callback - - Why this matters: - ---------------- - Defensive programming against driver bugs: - - Driver might call callbacks multiple times - - Race conditions in callback handling - - Error after success (or vice versa) - - Real-world scenario: - - Network packet arrives late - - Retry logic in driver - - Threading race conditions - - The handler must be idempotent - multiple calls should - not corrupt state or raise exceptions. First result wins. - """ - mock_future = Mock(spec=ResponseFuture) - mock_future.add_callbacks = Mock() - mock_future.has_more_pages = False - - handler = AsyncResultHandler(mock_future) - - # Get result future - result_future = handler.get_result() - - # First success should set the result - mock_result = Mock() - mock_result.current_rows = [{"id": 1}] - handler._handle_page(mock_result.current_rows) - - result = await result_future - assert isinstance(result, AsyncResultSet) - - # Subsequent calls should be ignored (no exceptions) - handler._handle_page([{"id": 2}]) - handler._handle_error(Exception("should be ignored")) - - -class TestAsyncResultSet: - """ - Test AsyncResultSet for handling query results. - - Tests additional functionality not covered in test_result.py, - focusing on edge cases and additional access patterns. - """ - - @pytest.mark.core - @pytest.mark.quick - async def test_init_single_page(self): - """ - Test initialization with single page result. - - What this tests: - --------------- - 1. ResultSet correctly stores provided rows - 2. No data transformation during init - 3. Rows are accessible immediately - 4. 
Works with typical dict-like row data - - Why this matters: - ---------------- - Single page results are the most common case: - - Queries with LIMIT - - Primary key lookups - - Small tables - - Initialization should be fast and simple, just - storing the rows for later access. - """ - rows = [{"id": 1}, {"id": 2}, {"id": 3}] - - async_result = AsyncResultSet(rows) - assert async_result.rows == rows - - @pytest.mark.core - async def test_init_empty(self): - """ - Test initialization with empty result. - - What this tests: - --------------- - 1. Empty list is handled correctly - 2. No errors with zero rows - 3. Properties work with empty data - 4. Ready for iteration (will complete immediately) - - Why this matters: - ---------------- - Empty results are common and must work: - - No matching WHERE clause - - Deleted data - - Fresh tables - - Empty ResultSet should behave like empty list, - not None or error. - """ - async_result = AsyncResultSet([]) - assert async_result.rows == [] - - @pytest.mark.core - @pytest.mark.critical - async def test_async_iteration(self): - """ - Test async iteration over results. - - What this tests: - --------------- - 1. Supports async for syntax - 2. Yields rows in correct order - 3. Completes after all rows - 4. Each row is yielded exactly once - - Why this matters: - ---------------- - Core functionality for result processing: - ```python - async for row in results: - await process(row) - ``` - - Must work correctly for: - - FastAPI endpoints - - Async data processing - - Streaming responses - - Async iteration allows non-blocking processing - of each row, critical for scalability. - """ - rows = [{"id": 1}, {"id": 2}, {"id": 3}] - async_result = AsyncResultSet(rows) - - results = [] - async for row in async_result: - results.append(row) - - assert results == rows - - @pytest.mark.core - async def test_one(self): - """ - Test getting single result. - - What this tests: - --------------- - 1. one() returns first row - 2. Works with single row result - 3. Returns actual row, not wrapped - 4. Matches driver behavior - - Why this matters: - ---------------- - Optimized for single-row queries: - - User lookup by ID - - Configuration values - - Existence checks - - Simpler than iteration for single values. - """ - rows = [{"id": 1, "name": "test"}] - async_result = AsyncResultSet(rows) - - result = async_result.one() - assert result == {"id": 1, "name": "test"} - - @pytest.mark.core - async def test_all(self): - """ - Test getting all results. - - What this tests: - --------------- - 1. all() returns complete row list - 2. No async needed (already in memory) - 3. Returns actual list, not copy - 4. Preserves row order - - Why this matters: - ---------------- - For when you need all data at once: - - JSON serialization - - Bulk operations - - Data export - - More convenient than list comprehension. - """ - rows = [{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}] - async_result = AsyncResultSet(rows) - - results = async_result.all() - assert results == rows - - @pytest.mark.core - async def test_len(self): - """ - Test getting result count. - - What this tests: - --------------- - 1. len() protocol support - 2. Accurate row count - 3. O(1) operation (not counting) - 4. Works with empty results - - Why this matters: - ---------------- - Standard Python patterns: - - Checking if results exist - - Pagination calculations - - Progress reporting - - Makes ResultSet feel native. 
- """ - rows = [{"id": i} for i in range(5)] - async_result = AsyncResultSet(rows) - - assert len(async_result) == 5 - - @pytest.mark.core - async def test_getitem(self): - """ - Test indexed access to results. - - What this tests: - --------------- - 1. Square bracket notation works - 2. Zero-based indexing - 3. Access specific rows by position - 4. Returns actual row data - - Why this matters: - ---------------- - Pythonic access patterns: - - first = results[0] - - last = results[-1] - - middle = results[len(results)//2] - - Useful for: - - Accessing specific rows - - Sampling results - - Testing specific positions - - Makes ResultSet behave like a list. - """ - rows = [{"id": 1, "name": "test"}, {"id": 2, "name": "test2"}] - async_result = AsyncResultSet(rows) - - assert async_result[0] == {"id": 1, "name": "test"} - assert async_result[1] == {"id": 2, "name": "test2"} - - @pytest.mark.core - async def test_properties(self): - """ - Test result set properties. - - What this tests: - --------------- - 1. Direct access to rows property - 2. Property returns underlying list - 3. Can check length via property - 4. Properties are consistent - - Why this matters: - ---------------- - Properties provide direct access: - - Debugging (inspect results.rows) - - Integration with other code - - Performance (no method call) - - The .rows property gives escape hatch to - raw data when needed. - """ - rows = [{"id": 1}, {"id": 2}, {"id": 3}] - async_result = AsyncResultSet(rows) - - # Check basic properties - assert len(async_result.rows) == 3 - assert async_result.rows == rows diff --git a/tests/unit/test_retry_policy_unified.py b/tests/unit/test_retry_policy_unified.py deleted file mode 100644 index 4d6dc8d..0000000 --- a/tests/unit/test_retry_policy_unified.py +++ /dev/null @@ -1,940 +0,0 @@ -""" -Unified retry policy tests for async-python-cassandra. - -This module consolidates all retry policy testing from multiple files: -- test_retry_policy.py: Basic retry policy initialization and configuration -- test_retry_policies.py: Partial consolidation attempt (used as base) -- test_retry_policy_comprehensive.py: Query-specific retry scenarios -- test_retry_policy_idempotency.py: Deep idempotency validation -- test_retry_policy_unlogged_batch.py: UNLOGGED_BATCH specific tests - -Test Organization: -================== -1. Basic Retry Policy Tests - Initialization, configuration, basic behavior -2. Read Timeout Tests - All read timeout scenarios -3. Write Timeout Tests - All write timeout scenarios -4. Unavailable Tests - Node unavailability handling -5. Idempotency Tests - Comprehensive idempotency validation -6. Batch Operation Tests - LOGGED and UNLOGGED batch handling -7. Error Propagation Tests - Error handling and logging -8. Edge Cases - Special scenarios and boundary conditions - -Key Testing Principles: -====================== -- Test both idempotent and non-idempotent operations -- Verify retry counts and decision logic -- Ensure consistency level adjustments are correct -- Test all ConsistencyLevel combinations -- Validate error messages and logging -""" - -from unittest.mock import Mock - -from cassandra.policies import ConsistencyLevel, RetryPolicy, WriteType - -from async_cassandra.retry_policy import AsyncRetryPolicy - - -class TestAsyncRetryPolicy: - """ - Comprehensive tests for AsyncRetryPolicy. - - AsyncRetryPolicy extends the standard retry policy to handle - async operations while maintaining idempotency guarantees. 
- """ - - # ======================================== - # Basic Retry Policy Tests - # ======================================== - - def test_initialization_default(self): - """ - Test default initialization of AsyncRetryPolicy. - - What this tests: - --------------- - 1. Policy can be created without parameters - 2. Default max retries is 3 - 3. Inherits from cassandra.policies.RetryPolicy - - Why this matters: - ---------------- - The retry policy must work with sensible defaults for - users who don't customize retry behavior. - """ - policy = AsyncRetryPolicy() - assert isinstance(policy, RetryPolicy) - assert policy.max_retries == 3 - - def test_initialization_custom_max_retries(self): - """ - Test initialization with custom max retries. - - What this tests: - --------------- - 1. Custom max_retries is respected - 2. Value is stored correctly - - Why this matters: - ---------------- - Different applications have different tolerance for retries. - Some may want more aggressive retries, others less. - """ - policy = AsyncRetryPolicy(max_retries=5) - assert policy.max_retries == 5 - - def test_initialization_zero_retries(self): - """ - Test initialization with zero retries (fail fast). - - What this tests: - --------------- - 1. Zero retries is valid configuration - 2. Policy will not retry on failures - - Why this matters: - ---------------- - Some applications prefer to fail fast and handle - retries at a higher level. - """ - policy = AsyncRetryPolicy(max_retries=0) - assert policy.max_retries == 0 - - # ======================================== - # Read Timeout Tests - # ======================================== - - def test_on_read_timeout_sufficient_responses(self): - """ - Test read timeout when we have enough responses. - - What this tests: - --------------- - 1. When received >= required, retry the read - 2. Retry count is incremented - 3. Returns RETRY decision - - Why this matters: - ---------------- - If we got enough responses but timed out, the data - likely exists and a retry might succeed. - """ - policy = AsyncRetryPolicy() - query = Mock() - - decision = policy.on_read_timeout( - query=query, - consistency=ConsistencyLevel.QUORUM, - required_responses=2, - received_responses=2, # Got enough responses - data_retrieved=False, - retry_num=0, - ) - - assert decision == (RetryPolicy.RETRY, ConsistencyLevel.QUORUM) - - def test_on_read_timeout_insufficient_responses(self): - """ - Test read timeout when we don't have enough responses. - - What this tests: - --------------- - 1. When received < required, rethrow the error - 2. No retry attempted - - Why this matters: - ---------------- - If we didn't get enough responses, retrying immediately - is unlikely to help. Better to fail fast. - """ - policy = AsyncRetryPolicy() - query = Mock() - - decision = policy.on_read_timeout( - query=query, - consistency=ConsistencyLevel.QUORUM, - required_responses=2, - received_responses=1, # Not enough responses - data_retrieved=False, - retry_num=0, - ) - - assert decision == (RetryPolicy.RETHROW, None) - - def test_on_read_timeout_max_retries_exceeded(self): - """ - Test read timeout when max retries exceeded. - - What this tests: - --------------- - 1. After max_retries attempts, stop retrying - 2. Return RETHROW decision - - Why this matters: - ---------------- - Prevents infinite retry loops and ensures eventual - failure when operations consistently timeout. 
- """ - policy = AsyncRetryPolicy(max_retries=2) - query = Mock() - - decision = policy.on_read_timeout( - query=query, - consistency=ConsistencyLevel.QUORUM, - required_responses=2, - received_responses=2, - data_retrieved=False, - retry_num=2, # Already at max retries - ) - - assert decision == (RetryPolicy.RETHROW, None) - - def test_on_read_timeout_data_retrieved(self): - """ - Test read timeout when data was retrieved. - - What this tests: - --------------- - 1. When data_retrieved=True, RETRY the read - 2. Data retrieved means we got some data and retry might get more - - Why this matters: - ---------------- - If we already got some data, retrying might get the complete - result set. This implementation differs from standard behavior. - """ - policy = AsyncRetryPolicy() - query = Mock() - - decision = policy.on_read_timeout( - query=query, - consistency=ConsistencyLevel.QUORUM, - required_responses=2, - received_responses=2, - data_retrieved=True, # Got some data - retry_num=0, - ) - - assert decision == (RetryPolicy.RETRY, ConsistencyLevel.QUORUM) - - def test_on_read_timeout_all_consistency_levels(self): - """ - Test read timeout behavior across all consistency levels. - - What this tests: - --------------- - 1. Policy works with all ConsistencyLevel values - 2. Retry logic is consistent across levels - - Why this matters: - ---------------- - Applications use different consistency levels for different - use cases. The retry policy must handle all of them. - """ - policy = AsyncRetryPolicy() - query = Mock() - - consistency_levels = [ - ConsistencyLevel.ANY, - ConsistencyLevel.ONE, - ConsistencyLevel.TWO, - ConsistencyLevel.THREE, - ConsistencyLevel.QUORUM, - ConsistencyLevel.ALL, - ConsistencyLevel.LOCAL_QUORUM, - ConsistencyLevel.EACH_QUORUM, - ConsistencyLevel.LOCAL_ONE, - ] - - for cl in consistency_levels: - # Test with sufficient responses - decision = policy.on_read_timeout( - query=query, - consistency=cl, - required_responses=2, - received_responses=2, - data_retrieved=False, - retry_num=0, - ) - assert decision == (RetryPolicy.RETRY, cl) - - # ======================================== - # Write Timeout Tests - # ======================================== - - def test_on_write_timeout_idempotent_simple_statement(self): - """ - Test write timeout for idempotent simple statement. - - What this tests: - --------------- - 1. Idempotent writes are retried - 2. Consistency level is preserved - 3. WriteType.SIMPLE is handled correctly - - Why this matters: - ---------------- - Idempotent operations can be safely retried without - risk of duplicate effects. - """ - policy = AsyncRetryPolicy() - query = Mock(is_idempotent=True) - - decision = policy.on_write_timeout( - query=query, - consistency=ConsistencyLevel.QUORUM, - write_type=WriteType.SIMPLE, - required_responses=2, - received_responses=1, - retry_num=0, - ) - - assert decision == (RetryPolicy.RETRY, ConsistencyLevel.QUORUM) - - def test_on_write_timeout_non_idempotent_simple_statement(self): - """ - Test write timeout for non-idempotent simple statement. - - What this tests: - --------------- - 1. Non-idempotent writes are NOT retried - 2. Returns RETHROW decision - - Why this matters: - ---------------- - Non-idempotent operations (like counter updates) could - cause data corruption if retried after partial success. 
- """ - policy = AsyncRetryPolicy() - query = Mock(is_idempotent=False) - - decision = policy.on_write_timeout( - query=query, - consistency=ConsistencyLevel.QUORUM, - write_type=WriteType.SIMPLE, - required_responses=2, - received_responses=1, - retry_num=0, - ) - - assert decision == (RetryPolicy.RETHROW, None) - - def test_on_write_timeout_batch_log_write(self): - """ - Test write timeout during batch log write. - - What this tests: - --------------- - 1. BATCH_LOG writes are NOT retried in this implementation - 2. Only SIMPLE, BATCH, and UNLOGGED_BATCH are retried if idempotent - - Why this matters: - ---------------- - This implementation focuses on user-facing write types. - BATCH_LOG is an internal operation that's not covered. - """ - policy = AsyncRetryPolicy() - # Even idempotent query won't retry for BATCH_LOG - query = Mock(is_idempotent=True) - - decision = policy.on_write_timeout( - query=query, - consistency=ConsistencyLevel.QUORUM, - write_type=WriteType.BATCH_LOG, - required_responses=2, - received_responses=1, - retry_num=0, - ) - - assert decision == (RetryPolicy.RETHROW, None) - - def test_on_write_timeout_unlogged_batch_idempotent(self): - """ - Test write timeout for idempotent UNLOGGED_BATCH. - - What this tests: - --------------- - 1. UNLOGGED_BATCH is retried if the batch itself is marked idempotent - 2. Individual statement idempotency is not checked here - - Why this matters: - ---------------- - The retry policy checks the batch's is_idempotent attribute, - not the individual statements within it. - """ - policy = AsyncRetryPolicy() - - # Create a batch statement marked as idempotent - from cassandra.query import BatchStatement - - batch = BatchStatement() - batch.is_idempotent = True # Mark the batch itself as idempotent - batch._statements_and_parameters = [ - (Mock(is_idempotent=True), []), - (Mock(is_idempotent=True), []), - ] - - decision = policy.on_write_timeout( - query=batch, - consistency=ConsistencyLevel.QUORUM, - write_type=WriteType.UNLOGGED_BATCH, - required_responses=2, - received_responses=1, - retry_num=0, - ) - - assert decision == (RetryPolicy.RETRY, ConsistencyLevel.QUORUM) - - def test_on_write_timeout_unlogged_batch_mixed_idempotency(self): - """ - Test write timeout for UNLOGGED_BATCH with mixed idempotency. - - What this tests: - --------------- - 1. Batch with any non-idempotent statement is not retried - 2. Partial idempotency is not sufficient - - Why this matters: - ---------------- - A single non-idempotent statement in an unlogged batch - makes the entire batch non-retriable. - """ - policy = AsyncRetryPolicy() - - from cassandra.query import BatchStatement - - batch = BatchStatement() - batch._statements_and_parameters = [ - (Mock(is_idempotent=True), []), # Idempotent - (Mock(is_idempotent=False), []), # Non-idempotent - ] - - decision = policy.on_write_timeout( - query=batch, - consistency=ConsistencyLevel.QUORUM, - write_type=WriteType.UNLOGGED_BATCH, - required_responses=2, - received_responses=1, - retry_num=0, - ) - - assert decision == (RetryPolicy.RETHROW, None) - - def test_on_write_timeout_logged_batch(self): - """ - Test that LOGGED batches are handled as BATCH write type. - - What this tests: - --------------- - 1. LOGGED batches use WriteType.BATCH (not UNLOGGED_BATCH) - 2. Different retry logic applies - - Why this matters: - ---------------- - LOGGED batches have atomicity guarantees through the batch log, - so they have different retry semantics than UNLOGGED batches. 
- """ - policy = AsyncRetryPolicy() - - from cassandra.query import BatchStatement, BatchType - - batch = BatchStatement(batch_type=BatchType.LOGGED) - - # For BATCH write type, should check idempotency - batch.is_idempotent = True - - decision = policy.on_write_timeout( - query=batch, - consistency=ConsistencyLevel.QUORUM, - write_type=WriteType.BATCH, # Not UNLOGGED_BATCH - required_responses=2, - received_responses=1, - retry_num=0, - ) - - assert decision == (RetryPolicy.RETRY, ConsistencyLevel.QUORUM) - - def test_on_write_timeout_counter_write(self): - """ - Test write timeout for counter operations. - - What this tests: - --------------- - 1. Counter writes are never retried - 2. WriteType.COUNTER is handled correctly - - Why this matters: - ---------------- - Counter operations are not idempotent by nature. - Retrying could lead to incorrect counter values. - """ - policy = AsyncRetryPolicy() - query = Mock() # Counters are never idempotent - - decision = policy.on_write_timeout( - query=query, - consistency=ConsistencyLevel.QUORUM, - write_type=WriteType.COUNTER, - required_responses=2, - received_responses=1, - retry_num=0, - ) - - assert decision == (RetryPolicy.RETHROW, None) - - def test_on_write_timeout_max_retries_exceeded(self): - """ - Test write timeout when max retries exceeded. - - What this tests: - --------------- - 1. After max_retries attempts, stop retrying - 2. Even idempotent operations are not retried - - Why this matters: - ---------------- - Prevents infinite retry loops for consistently failing writes. - """ - policy = AsyncRetryPolicy(max_retries=1) - query = Mock(is_idempotent=True) - - decision = policy.on_write_timeout( - query=query, - consistency=ConsistencyLevel.QUORUM, - write_type=WriteType.SIMPLE, - required_responses=2, - received_responses=1, - retry_num=1, # Already at max retries - ) - - assert decision == (RetryPolicy.RETHROW, None) - - # ======================================== - # Unavailable Tests - # ======================================== - - def test_on_unavailable_first_attempt(self): - """ - Test handling unavailable exception on first attempt. - - What this tests: - --------------- - 1. First unavailable error triggers RETRY_NEXT_HOST - 2. Consistency level is preserved - - Why this matters: - ---------------- - Temporary node failures are common. Trying the next host - often succeeds when the current coordinator is having issues. - """ - policy = AsyncRetryPolicy() - query = Mock() - - decision = policy.on_unavailable( - query=query, - consistency=ConsistencyLevel.QUORUM, - required_replicas=3, - alive_replicas=2, - retry_num=0, - ) - - # Should retry on next host with same consistency - assert decision == (RetryPolicy.RETRY_NEXT_HOST, ConsistencyLevel.QUORUM) - - def test_on_unavailable_max_retries_exceeded(self): - """ - Test unavailable exception when max retries exceeded. - - What this tests: - --------------- - 1. After max retries, stop trying - 2. Return RETHROW decision - - Why this matters: - ---------------- - If nodes remain unavailable after multiple attempts, - the cluster likely has serious issues. - """ - policy = AsyncRetryPolicy(max_retries=2) - query = Mock() - - decision = policy.on_unavailable( - query=query, - consistency=ConsistencyLevel.QUORUM, - required_replicas=3, - alive_replicas=1, - retry_num=2, - ) - - assert decision == (RetryPolicy.RETHROW, None) - - def test_on_unavailable_consistency_downgrade(self): - """ - Test that consistency level is NOT downgraded on unavailable. 
- - What this tests: - --------------- - 1. Policy preserves original consistency level - 2. No automatic downgrade in this implementation - - Why this matters: - ---------------- - This implementation maintains consistency requirements - rather than trading consistency for availability. - """ - policy = AsyncRetryPolicy() - query = Mock() - - # Test that consistency is preserved on retry - decision = policy.on_unavailable( - query=query, - consistency=ConsistencyLevel.QUORUM, - required_replicas=2, - alive_replicas=1, # Only 1 alive, can't do QUORUM - retry_num=1, # Not first attempt, so RETRY not RETRY_NEXT_HOST - ) - - # Should retry with SAME consistency level - assert decision == (RetryPolicy.RETRY, ConsistencyLevel.QUORUM) - - # ======================================== - # Idempotency Tests - # ======================================== - - def test_idempotency_check_simple_statement(self): - """ - Test idempotency checking for simple statements. - - What this tests: - --------------- - 1. Simple statements have is_idempotent attribute - 2. Attribute is checked correctly - - Why this matters: - ---------------- - Idempotency is critical for safe retries. Must be - explicitly set by the application. - """ - policy = AsyncRetryPolicy() - - # Test idempotent statement - idempotent_query = Mock(is_idempotent=True) - decision = policy.on_write_timeout( - query=idempotent_query, - consistency=ConsistencyLevel.ONE, - write_type=WriteType.SIMPLE, - required_responses=1, - received_responses=0, - retry_num=0, - ) - assert decision[0] == RetryPolicy.RETRY - - # Test non-idempotent statement - non_idempotent_query = Mock(is_idempotent=False) - decision = policy.on_write_timeout( - query=non_idempotent_query, - consistency=ConsistencyLevel.ONE, - write_type=WriteType.SIMPLE, - required_responses=1, - received_responses=0, - retry_num=0, - ) - assert decision[0] == RetryPolicy.RETHROW - - def test_idempotency_check_prepared_statement(self): - """ - Test idempotency checking for prepared statements. - - What this tests: - --------------- - 1. Prepared statements can be marked idempotent - 2. Idempotency is preserved through preparation - - Why this matters: - ---------------- - Prepared statements are the recommended way to execute - queries. Their idempotency must be tracked. - """ - policy = AsyncRetryPolicy() - - # Mock prepared statement - from cassandra.query import PreparedStatement - - prepared = Mock(spec=PreparedStatement) - prepared.is_idempotent = True - - decision = policy.on_write_timeout( - query=prepared, - consistency=ConsistencyLevel.QUORUM, - write_type=WriteType.SIMPLE, - required_responses=2, - received_responses=1, - retry_num=0, - ) - - assert decision[0] == RetryPolicy.RETRY - - def test_idempotency_missing_attribute(self): - """ - Test handling of queries without is_idempotent attribute. - - What this tests: - --------------- - 1. Missing attribute is treated as non-idempotent - 2. Safe default behavior - - Why this matters: - ---------------- - Safety first: if we don't know if an operation is - idempotent, assume it's not. - """ - policy = AsyncRetryPolicy() - - # Query without is_idempotent attribute - query = Mock(spec=[]) # No attributes - - decision = policy.on_write_timeout( - query=query, - consistency=ConsistencyLevel.ONE, - write_type=WriteType.SIMPLE, - required_responses=1, - received_responses=0, - retry_num=0, - ) - - assert decision[0] == RetryPolicy.RETHROW - - def test_batch_idempotency_validation(self): - """ - Test batch idempotency validation. 
- - What this tests: - --------------- - 1. Batch must have is_idempotent=True to be retried - 2. Individual statement idempotency is not checked - 3. Missing is_idempotent attribute means non-idempotent - - Why this matters: - ---------------- - The retry policy only checks the batch's own idempotency flag, - not the individual statements within it. - """ - policy = AsyncRetryPolicy() - - from cassandra.query import BatchStatement - - # Test batch without is_idempotent attribute (default) - default_batch = BatchStatement() - # Don't set is_idempotent - should default to non-idempotent - - decision = policy.on_write_timeout( - query=default_batch, - consistency=ConsistencyLevel.ONE, - write_type=WriteType.UNLOGGED_BATCH, - required_responses=1, - received_responses=0, - retry_num=0, - ) - # Batch without explicit is_idempotent=True should not retry - assert decision[0] == RetryPolicy.RETHROW - - # Test batch explicitly marked idempotent - idempotent_batch = BatchStatement() - idempotent_batch.is_idempotent = True - - decision = policy.on_write_timeout( - query=idempotent_batch, - consistency=ConsistencyLevel.ONE, - write_type=WriteType.UNLOGGED_BATCH, - required_responses=1, - received_responses=0, - retry_num=0, - ) - assert decision[0] == RetryPolicy.RETRY - - # Test batch explicitly marked non-idempotent - non_idempotent_batch = BatchStatement() - non_idempotent_batch.is_idempotent = False - - decision = policy.on_write_timeout( - query=non_idempotent_batch, - consistency=ConsistencyLevel.ONE, - write_type=WriteType.UNLOGGED_BATCH, - required_responses=1, - received_responses=0, - retry_num=0, - ) - assert decision[0] == RetryPolicy.RETHROW - - # ======================================== - # Error Propagation Tests - # ======================================== - - def test_request_error_handling(self): - """ - Test on_request_error method. - - What this tests: - --------------- - 1. Request errors trigger RETRY_NEXT_HOST - 2. Max retries is respected - - Why this matters: - ---------------- - Connection errors and other request failures should - try a different coordinator node. - """ - policy = AsyncRetryPolicy() - query = Mock() - error = Exception("Connection failed") - - # First attempt should try next host - decision = policy.on_request_error( - query=query, consistency=ConsistencyLevel.QUORUM, error=error, retry_num=0 - ) - assert decision == (RetryPolicy.RETRY_NEXT_HOST, ConsistencyLevel.QUORUM) - - # After max retries, should rethrow - decision = policy.on_request_error( - query=query, - consistency=ConsistencyLevel.QUORUM, - error=error, - retry_num=3, # At max retries - ) - assert decision == (RetryPolicy.RETHROW, None) - - # ======================================== - # Edge Cases - # ======================================== - - def test_retry_with_zero_max_retries(self): - """ - Test that zero max_retries means no retries. - - What this tests: - --------------- - 1. max_retries=0 disables all retries - 2. First attempt is not counted as retry - - Why this matters: - ---------------- - Some applications want to handle retries at a higher - level and disable driver-level retries. 
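Where driver-level retries are disabled (max_retries=0), the application typically wraps execution itself. A rough sketch, with the attempt count and backoff chosen arbitrarily:

import asyncio

from cassandra import OperationTimedOut

async def execute_with_app_retry(session, statement, params=None, attempts=3):
    # Driver retries are off, so the application decides when a replay is safe.
    for attempt in range(attempts):
        try:
            return await session.execute(statement, params)
        except OperationTimedOut:
            if attempt == attempts - 1:
                raise
            await asyncio.sleep(0.1 * (attempt + 1))  # simple linear backoff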
- """ - policy = AsyncRetryPolicy(max_retries=0) - query = Mock(is_idempotent=True) - - # Even on first attempt (retry_num=0), should not retry - decision = policy.on_write_timeout( - query=query, - consistency=ConsistencyLevel.ONE, - write_type=WriteType.SIMPLE, - required_responses=1, - received_responses=0, - retry_num=0, - ) - - assert decision[0] == RetryPolicy.RETHROW - - def test_consistency_level_all_special_handling(self): - """ - Test special handling for ConsistencyLevel.ALL. - - What this tests: - --------------- - 1. ALL consistency has special retry considerations - 2. May not retry even when others would - - Why this matters: - ---------------- - ConsistencyLevel.ALL requires all replicas. If any - are down, retrying won't help. - """ - policy = AsyncRetryPolicy() - query = Mock() - - decision = policy.on_unavailable( - query=query, - consistency=ConsistencyLevel.ALL, - required_replicas=3, - alive_replicas=2, # Missing one replica - retry_num=0, - ) - - # Implementation dependent, but should handle ALL specially - assert decision is not None # Use the decision variable - - def test_query_string_not_accessed(self): - """ - Test that retry policy doesn't access query internals. - - What this tests: - --------------- - 1. Policy only uses public query attributes - 2. No query string parsing or inspection - - Why this matters: - ---------------- - Retry decisions should be based on metadata, not - query content. This ensures performance and security. - """ - policy = AsyncRetryPolicy() - - # Mock with minimal interface - query = Mock() - query.is_idempotent = True - # Don't provide query string or other internals - - # Should work without accessing query details - decision = policy.on_write_timeout( - query=query, - consistency=ConsistencyLevel.ONE, - write_type=WriteType.SIMPLE, - required_responses=1, - received_responses=0, - retry_num=0, - ) - - assert decision[0] == RetryPolicy.RETRY - - def test_concurrent_retry_decisions(self): - """ - Test that retry policy is thread-safe. - - What this tests: - --------------- - 1. Multiple threads can use same policy instance - 2. No shared state corruption - - Why this matters: - ---------------- - In async applications, the same retry policy instance - may be used by multiple concurrent operations. - """ - import threading - - policy = AsyncRetryPolicy() - results = [] - - def make_decision(): - query = Mock(is_idempotent=True) - decision = policy.on_write_timeout( - query=query, - consistency=ConsistencyLevel.ONE, - write_type=WriteType.SIMPLE, - required_responses=1, - received_responses=0, - retry_num=0, - ) - results.append(decision) - - # Run multiple threads - threads = [threading.Thread(target=make_decision) for _ in range(10)] - for t in threads: - t.start() - for t in threads: - t.join() - - # All should get same decision - assert len(results) == 10 - assert all(r[0] == RetryPolicy.RETRY for r in results) diff --git a/tests/unit/test_schema_changes.py b/tests/unit/test_schema_changes.py deleted file mode 100644 index d65c09f..0000000 --- a/tests/unit/test_schema_changes.py +++ /dev/null @@ -1,483 +0,0 @@ -""" -Unit tests for schema change handling. 
- -Tests how the async wrapper handles: -- Schema change events -- Metadata refresh -- Schema agreement -- DDL operation execution -- Prepared statement invalidation on schema changes -""" - -import asyncio -from unittest.mock import Mock, patch - -import pytest -from cassandra import AlreadyExists, InvalidRequest - -from async_cassandra import AsyncCassandraSession, AsyncCluster - - -class TestSchemaChanges: - """Test schema change handling scenarios.""" - - @pytest.fixture - def mock_session(self): - """Create a mock session.""" - session = Mock() - session.execute_async = Mock() - session.prepare_async = Mock() - session.cluster = Mock() - return session - - def create_error_future(self, exception): - """Create a mock future that raises the given exception.""" - future = Mock() - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - if errback: - errbacks.append(errback) - # Call errback immediately with the error - errback(exception) - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - return future - - def _create_mock_future(self, result=None, error=None): - """Create a properly configured mock future that simulates driver behavior.""" - future = Mock() - - # Store callbacks - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - if errback: - errbacks.append(errback) - - # Delay the callback execution to allow AsyncResultHandler to set up properly - def execute_callback(): - if error: - if errback: - errback(error) - else: - if callback and result is not None: - # For successful results, pass rows - rows = getattr(result, "rows", []) - callback(rows) - - # Schedule callback for next event loop iteration - try: - loop = asyncio.get_running_loop() - loop.call_soon(execute_callback) - except RuntimeError: - # No event loop, execute immediately - execute_callback() - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - - return future - - @pytest.mark.asyncio - async def test_create_table_already_exists(self, mock_session): - """ - Test handling of AlreadyExists errors during schema changes. - - What this tests: - --------------- - 1. CREATE TABLE on existing table - 2. AlreadyExists wrapped in QueryError - 3. Keyspace and table info preserved - 4. Error details accessible - - Why this matters: - ---------------- - Schema conflicts common in: - - Concurrent deployments - - Idempotent migrations - - Multi-datacenter setups - - Applications need to handle - schema conflicts gracefully. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock AlreadyExists error - error = AlreadyExists(keyspace="test_ks", table="test_table") - mock_session.execute_async.return_value = self.create_error_future(error) - - # AlreadyExists is passed through directly - with pytest.raises(AlreadyExists) as exc_info: - await async_session.execute("CREATE TABLE test_table (id int PRIMARY KEY)") - - assert exc_info.value.keyspace == "test_ks" - assert exc_info.value.table == "test_table" - - @pytest.mark.asyncio - async def test_ddl_invalid_syntax(self, mock_session): - """ - Test handling of invalid DDL syntax. - - What this tests: - --------------- - 1. DDL syntax errors detected - 2. InvalidRequest not wrapped - 3. Parser error details shown - 4. 
Line/column info preserved - - Why this matters: - ---------------- - DDL syntax errors indicate: - - Typos in schema scripts - - Version incompatibilities - - Invalid CQL constructs - - Clear errors help developers - fix schema definitions quickly. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock InvalidRequest error - error = InvalidRequest("line 1:13 no viable alternative at input 'TABEL'") - mock_session.execute_async.return_value = self.create_error_future(error) - - # InvalidRequest is NOT wrapped - it's in the re-raise list - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute("CREATE TABEL test (id int PRIMARY KEY)") - - assert "no viable alternative" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_create_keyspace_already_exists(self, mock_session): - """ - Test handling of keyspace already exists errors. - - What this tests: - --------------- - 1. CREATE KEYSPACE conflicts - 2. AlreadyExists for keyspaces - 3. Table field is None - 4. Wrapped in QueryError - - Why this matters: - ---------------- - Keyspace conflicts occur when: - - Multiple apps create keyspaces - - Deployment race conditions - - Recreating environments - - Idempotent keyspace creation - requires proper error handling. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock AlreadyExists error for keyspace - error = AlreadyExists(keyspace="test_keyspace", table=None) - mock_session.execute_async.return_value = self.create_error_future(error) - - # AlreadyExists is passed through directly - with pytest.raises(AlreadyExists) as exc_info: - await async_session.execute( - "CREATE KEYSPACE test_keyspace WITH replication = " - "{'class': 'SimpleStrategy', 'replication_factor': 1}" - ) - - assert exc_info.value.keyspace == "test_keyspace" - assert exc_info.value.table is None - - @pytest.mark.asyncio - async def test_concurrent_ddl_operations(self, mock_session): - """ - Test handling of concurrent DDL operations. - - What this tests: - --------------- - 1. Multiple DDL ops can run concurrently - 2. No interference between operations - 3. All operations complete - 4. Order not guaranteed - - Why this matters: - ---------------- - Schema migrations often involve: - - Multiple table creations - - Index additions - - Concurrent alterations - - Async wrapper must handle parallel - DDL operations safely. 
- """ - async_session = AsyncCassandraSession(mock_session) - - # Track DDL operations - ddl_operations = [] - - def execute_async_side_effect(*args, **kwargs): - query = args[0] if args else kwargs.get("query", "") - ddl_operations.append(query) - - # Use the same pattern as test_session_edge_cases - result = Mock() - result.rows = [] # DDL operations return no rows - return self._create_mock_future(result=result) - - mock_session.execute_async.side_effect = execute_async_side_effect - - # Execute multiple DDL operations concurrently - ddl_queries = [ - "CREATE TABLE table1 (id int PRIMARY KEY)", - "CREATE TABLE table2 (id int PRIMARY KEY)", - "ALTER TABLE table1 ADD column1 text", - "CREATE INDEX idx1 ON table1 (column1)", - "DROP TABLE IF EXISTS table3", - ] - - tasks = [async_session.execute(query) for query in ddl_queries] - await asyncio.gather(*tasks) - - # All DDL operations should have been executed - assert len(ddl_operations) == 5 - assert all(query in ddl_operations for query in ddl_queries) - - @pytest.mark.asyncio - async def test_alter_table_column_type_error(self, mock_session): - """ - Test handling of invalid column type changes. - - What this tests: - --------------- - 1. Invalid type changes rejected - 2. InvalidRequest not wrapped - 3. Type conflict details shown - 4. Original types mentioned - - Why this matters: - ---------------- - Type changes restricted because: - - Data compatibility issues - - Storage format conflicts - - Query implications - - Developers need clear guidance - on valid schema evolution. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock InvalidRequest for incompatible type change - error = InvalidRequest("Cannot change column type from 'int' to 'text'") - mock_session.execute_async.return_value = self.create_error_future(error) - - # InvalidRequest is NOT wrapped - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute("ALTER TABLE users ALTER age TYPE text") - - assert "Cannot change column type" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_drop_nonexistent_keyspace(self, mock_session): - """ - Test dropping a non-existent keyspace. - - What this tests: - --------------- - 1. DROP on missing keyspace - 2. InvalidRequest not wrapped - 3. Clear error message - 4. Keyspace name in error - - Why this matters: - ---------------- - Drop operations may fail when: - - Cleanup scripts run twice - - Keyspace already removed - - Name typos - - IF EXISTS clause recommended - for idempotent drops. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock InvalidRequest for non-existent keyspace - error = InvalidRequest("Keyspace 'nonexistent' doesn't exist") - mock_session.execute_async.return_value = self.create_error_future(error) - - # InvalidRequest is NOT wrapped - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute("DROP KEYSPACE nonexistent") - - assert "doesn't exist" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_create_type_already_exists(self, mock_session): - """ - Test creating a user-defined type that already exists. - - What this tests: - --------------- - 1. CREATE TYPE conflicts - 2. UDTs treated like tables - 3. AlreadyExists wrapped - 4. Type name in error - - Why this matters: - ---------------- - User-defined types (UDTs): - - Share namespace with tables - - Support complex data models - - Need conflict handling - - Schema with UDTs requires - careful version control. 
- """ - async_session = AsyncCassandraSession(mock_session) - - # Mock AlreadyExists for UDT - error = AlreadyExists(keyspace="test_ks", table="address_type") - mock_session.execute_async.return_value = self.create_error_future(error) - - # AlreadyExists is passed through directly - with pytest.raises(AlreadyExists) as exc_info: - await async_session.execute( - "CREATE TYPE address_type (street text, city text, zip int)" - ) - - assert exc_info.value.keyspace == "test_ks" - assert exc_info.value.table == "address_type" - - @pytest.mark.asyncio - async def test_batch_ddl_operations(self, mock_session): - """ - Test that DDL operations cannot be batched. - - What this tests: - --------------- - 1. DDL not allowed in batches - 2. InvalidRequest not wrapped - 3. Clear error message - 4. Cassandra limitation enforced - - Why this matters: - ---------------- - DDL restrictions exist because: - - Schema changes are global - - Cannot be transactional - - Affect all nodes - - Schema changes must be - executed individually. - """ - async_session = AsyncCassandraSession(mock_session) - - # Mock InvalidRequest for DDL in batch - error = InvalidRequest("DDL statements cannot be batched") - mock_session.execute_async.return_value = self.create_error_future(error) - - # InvalidRequest is NOT wrapped - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute( - """ - BEGIN BATCH - CREATE TABLE t1 (id int PRIMARY KEY); - CREATE TABLE t2 (id int PRIMARY KEY); - APPLY BATCH; - """ - ) - - assert "cannot be batched" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_schema_metadata_access(self): - """ - Test accessing schema metadata through the cluster. - - What this tests: - --------------- - 1. Metadata accessible via cluster - 2. Keyspace information available - 3. Schema discovery works - 4. No async wrapper needed - - Why this matters: - ---------------- - Metadata access enables: - - Dynamic schema discovery - - Table introspection - - Type information - - Applications use metadata for - ORM mapping and validation. - """ - with patch("async_cassandra.cluster.Cluster") as mock_cluster_class: - # Create mock cluster with metadata - mock_cluster = Mock() - mock_cluster_class.return_value = mock_cluster - - # Mock metadata - mock_metadata = Mock() - mock_metadata.keyspaces = { - "system": Mock(name="system"), - "test_ks": Mock(name="test_ks"), - } - mock_cluster.metadata = mock_metadata - - async_cluster = AsyncCluster(contact_points=["127.0.0.1"]) - - # Access metadata - metadata = async_cluster.metadata - assert "system" in metadata.keyspaces - assert "test_ks" in metadata.keyspaces - - await async_cluster.shutdown() - - @pytest.mark.asyncio - async def test_materialized_view_already_exists(self, mock_session): - """ - Test creating a materialized view that already exists. - - What this tests: - --------------- - 1. MV conflicts detected - 2. AlreadyExists wrapped - 3. View name in error - 4. Same handling as tables - - Why this matters: - ---------------- - Materialized views: - - Auto-maintained denormalization - - Share table namespace - - Need conflict resolution - - MV schema changes require same - care as regular tables. 
- """ - async_session = AsyncCassandraSession(mock_session) - - # Mock AlreadyExists for materialized view - error = AlreadyExists(keyspace="test_ks", table="user_by_email") - mock_session.execute_async.return_value = self.create_error_future(error) - - # AlreadyExists is passed through directly - with pytest.raises(AlreadyExists) as exc_info: - await async_session.execute( - """ - CREATE MATERIALIZED VIEW user_by_email AS - SELECT * FROM users - WHERE email IS NOT NULL - PRIMARY KEY (email, id) - """ - ) - - assert exc_info.value.table == "user_by_email" diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py deleted file mode 100644 index 6871927..0000000 --- a/tests/unit/test_session.py +++ /dev/null @@ -1,609 +0,0 @@ -""" -Unit tests for async session management. - -This module thoroughly tests AsyncCassandraSession, covering: -- Session creation from cluster -- Query execution (simple and parameterized) -- Prepared statement handling -- Batch operations -- Error handling and propagation -- Resource cleanup and context managers -- Streaming operations -- Edge cases and error conditions - -Key Testing Patterns: -==================== -- Mocks ResponseFuture to simulate async operations -- Tests callback-based async conversion -- Verifies proper error wrapping -- Ensures resource cleanup in all paths -""" - -from unittest.mock import AsyncMock, Mock, patch - -import pytest -from cassandra.cluster import ResponseFuture, Session -from cassandra.query import PreparedStatement - -from async_cassandra.exceptions import ConnectionError, QueryError -from async_cassandra.result import AsyncResultSet -from async_cassandra.session import AsyncCassandraSession - - -class TestAsyncCassandraSession: - """ - Test cases for AsyncCassandraSession. - - AsyncCassandraSession is the core interface for executing queries. - It converts the driver's callback-based async operations into - Python async/await compatible operations. - """ - - @pytest.fixture - def mock_session(self): - """ - Create a mock Cassandra session. - - Provides a minimal session interface for testing - without actual database connections. - """ - session = Mock(spec=Session) - session.keyspace = "test_keyspace" - session.shutdown = Mock() - return session - - @pytest.fixture - def async_session(self, mock_session): - """ - Create an AsyncCassandraSession instance. - - Uses the mock_session fixture to avoid real connections. - """ - return AsyncCassandraSession(mock_session) - - @pytest.mark.asyncio - async def test_create_session(self): - """ - Test creating a session from cluster. - - What this tests: - --------------- - 1. create() class method works - 2. Keyspace is passed to cluster.connect() - 3. Returns AsyncCassandraSession instance - - Why this matters: - ---------------- - The create() method is a factory that: - - Handles sync cluster.connect() call - - Wraps result in async session - - Sets initial keyspace if provided - - This is the primary way to get a session. - """ - mock_cluster = Mock() - mock_session = Mock(spec=Session) - mock_cluster.connect.return_value = mock_session - - async_session = await AsyncCassandraSession.create(mock_cluster, "test_keyspace") - - assert isinstance(async_session, AsyncCassandraSession) - # Verify keyspace was used - mock_cluster.connect.assert_called_once_with("test_keyspace") - - @pytest.mark.asyncio - async def test_create_session_without_keyspace(self): - """ - Test creating a session without keyspace. - - What this tests: - --------------- - 1. 
Keyspace parameter is optional - 2. connect() called without arguments - - Why this matters: - ---------------- - Common patterns: - - Connect first, set keyspace later - - Working across multiple keyspaces - - Administrative operations - """ - mock_cluster = Mock() - mock_session = Mock(spec=Session) - mock_cluster.connect.return_value = mock_session - - async_session = await AsyncCassandraSession.create(mock_cluster) - - assert isinstance(async_session, AsyncCassandraSession) - # Verify no keyspace argument - mock_cluster.connect.assert_called_once_with() - - @pytest.mark.asyncio - async def test_execute_simple_query(self, async_session, mock_session): - """ - Test executing a simple query. - - What this tests: - --------------- - 1. Basic SELECT query execution - 2. Async conversion of ResponseFuture - 3. Results wrapped in AsyncResultSet - 4. Callback mechanism works correctly - - Why this matters: - ---------------- - This is the core functionality - converting driver's - callback-based async into Python async/await: - - Driver: execute_async() -> ResponseFuture -> callbacks - Wrapper: await execute() -> AsyncResultSet - - The AsyncResultHandler manages this conversion. - """ - # Setup mock response future - mock_future = Mock(spec=ResponseFuture) - mock_future.has_more_pages = False - mock_future.add_callbacks = Mock() - mock_session.execute_async.return_value = mock_future - - # Execute query - query = "SELECT * FROM users" - - # Patch AsyncResultHandler to simulate immediate result - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_result = AsyncResultSet([{"id": 1, "name": "test"}]) - mock_handler.get_result = AsyncMock(return_value=mock_result) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute(query) - - assert isinstance(result, AsyncResultSet) - mock_session.execute_async.assert_called_once() - - @pytest.mark.asyncio - async def test_execute_with_parameters(self, async_session, mock_session): - """ - Test executing query with parameters. - - What this tests: - --------------- - 1. Parameterized queries work - 2. Parameters passed to execute_async - 3. ? placeholder syntax supported - - Why this matters: - ---------------- - Parameters are critical for: - - SQL injection prevention - - Query plan caching - - Type safety - - Must ensure parameters flow through correctly. - """ - mock_future = Mock(spec=ResponseFuture) - mock_session.execute_async.return_value = mock_future - - query = "SELECT * FROM users WHERE id = ?" - params = [123] - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_result = AsyncResultSet([]) - mock_handler.get_result = AsyncMock(return_value=mock_result) - mock_handler_class.return_value = mock_handler - - await async_session.execute(query, parameters=params) - - # Verify both query and parameters were passed - call_args = mock_session.execute_async.call_args - assert call_args[0][0] == query - assert call_args[0][1] == params - - @pytest.mark.asyncio - async def test_execute_query_error(self, async_session, mock_session): - """ - Test handling query execution error. - - What this tests: - --------------- - 1. Exceptions from driver are caught - 2. Wrapped in QueryError - 3. Original exception preserved as __cause__ - 4. 
Helpful error message provided - - Why this matters: - ---------------- - Error handling is critical: - - Users need clear error messages - - Stack traces must be preserved - - Debugging requires full context - - Common errors: - - Network failures - - Invalid queries - - Timeout issues - """ - mock_session.execute_async.side_effect = Exception("Connection failed") - - with pytest.raises(QueryError) as exc_info: - await async_session.execute("SELECT * FROM users") - - assert "Query execution failed" in str(exc_info.value) - # Original exception preserved for debugging - assert exc_info.value.__cause__ is not None - - @pytest.mark.asyncio - async def test_execute_on_closed_session(self, async_session): - """ - Test executing query on closed session. - - What this tests: - --------------- - 1. Closed session check works - 2. Fails fast with ConnectionError - 3. Clear error message - - Why this matters: - ---------------- - Prevents confusing errors: - - No hanging on closed connections - - No cryptic driver errors - - Immediate feedback - - Common scenario: - - Session closed in error handler - - Retry logic tries to use it - - Should fail clearly - """ - await async_session.close() - - with pytest.raises(ConnectionError) as exc_info: - await async_session.execute("SELECT * FROM users") - - assert "Session is closed" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_prepare_statement(self, async_session, mock_session): - """ - Test preparing a statement. - - What this tests: - --------------- - 1. Basic prepared statement creation - 2. Query string is passed correctly to driver - 3. Prepared statement object is returned - 4. Async wrapper handles synchronous prepare call - - Why this matters: - ---------------- - - Prepared statements are critical for performance - - Must work correctly for parameterized queries - - Foundation for safe query execution - - Used in almost every production application - - Additional context: - --------------------------------- - - Prepared statements use ? placeholders - - Driver handles actual preparation - - Wrapper provides async interface - """ - mock_prepared = Mock(spec=PreparedStatement) - mock_session.prepare.return_value = mock_prepared - - query = "SELECT * FROM users WHERE id = ?" - prepared = await async_session.prepare(query) - - assert prepared == mock_prepared - mock_session.prepare.assert_called_once_with(query, None) - - @pytest.mark.asyncio - async def test_prepare_with_custom_payload(self, async_session, mock_session): - """ - Test preparing statement with custom payload. - - What this tests: - --------------- - 1. Custom payload support in prepare method - 2. Payload is correctly passed to driver - 3. Advanced prepare options are preserved - 4. API compatibility with driver features - - Why this matters: - ---------------- - - Custom payloads enable advanced features - - Required for certain driver extensions - - Ensures full driver API coverage - - Used in specialized deployments - - Additional context: - --------------------------------- - - Payloads can contain metadata or hints - - Driver-specific feature passthrough - - Maintains wrapper transparency - """ - mock_prepared = Mock(spec=PreparedStatement) - mock_session.prepare.return_value = mock_prepared - - query = "SELECT * FROM users WHERE id = ?" 
- payload = {"key": b"value"} - - await async_session.prepare(query, custom_payload=payload) - - mock_session.prepare.assert_called_once_with(query, payload) - - @pytest.mark.asyncio - async def test_prepare_error(self, async_session, mock_session): - """ - Test handling prepare statement error. - - What this tests: - --------------- - 1. Error handling during statement preparation - 2. Exceptions are wrapped in QueryError - 3. Error messages are informative - 4. No resource leaks on preparation failure - - Why this matters: - ---------------- - - Invalid queries must fail gracefully - - Clear errors help debugging - - Prevents silent failures - - Common during development - - Additional context: - --------------------------------- - - Syntax errors caught at prepare time - - Better than runtime query failures - - Helps catch bugs early - """ - mock_session.prepare.side_effect = Exception("Invalid query") - - with pytest.raises(QueryError) as exc_info: - await async_session.prepare("INVALID QUERY") - - assert "Statement preparation failed" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_prepare_on_closed_session(self, async_session): - """ - Test preparing statement on closed session. - - What this tests: - --------------- - 1. Closed session prevents prepare operations - 2. ConnectionError is raised appropriately - 3. Session state is checked before operations - 4. No operations on closed resources - - Why this matters: - ---------------- - - Prevents use-after-close bugs - - Clear error for invalid operations - - Resource safety in async contexts - - Common error in connection pooling - - Additional context: - --------------------------------- - - Sessions may be closed by timeouts - - Error handling must be predictable - - Helps identify lifecycle issues - """ - await async_session.close() - - with pytest.raises(ConnectionError): - await async_session.prepare("SELECT * FROM users") - - @pytest.mark.asyncio - async def test_close_session(self, async_session, mock_session): - """ - Test closing the session. - - What this tests: - --------------- - 1. Session close sets is_closed flag - 2. Underlying driver shutdown is called - 3. Clean resource cleanup - 4. State transition is correct - - Why this matters: - ---------------- - - Proper cleanup prevents resource leaks - - Connection pools need clean shutdown - - Memory leaks in production are critical - - Graceful shutdown is required - - Additional context: - --------------------------------- - - Driver shutdown releases connections - - Must work in async contexts - - Part of session lifecycle management - """ - await async_session.close() - - assert async_session.is_closed - mock_session.shutdown.assert_called_once() - - @pytest.mark.asyncio - async def test_close_idempotent(self, async_session, mock_session): - """ - Test that close is idempotent. - - What this tests: - --------------- - 1. Multiple close calls are safe - 2. Driver shutdown called only once - 3. No errors on repeated close - 4. 
Idempotent operation guarantee - - Why this matters: - ---------------- - - Defensive programming principle - - Simplifies error handling code - - Prevents double-free issues - - Common in cleanup handlers - - Additional context: - --------------------------------- - - May be called from multiple paths - - Exception handlers often close twice - - Standard pattern in resource management - """ - await async_session.close() - await async_session.close() - - # Should only be called once - mock_session.shutdown.assert_called_once() - - @pytest.mark.asyncio - async def test_context_manager(self, mock_session): - """ - Test using session as async context manager. - - What this tests: - --------------- - 1. Async context manager protocol support - 2. Session is open within context - 3. Automatic cleanup on context exit - 4. Exception safety in context manager - - Why this matters: - ---------------- - - Pythonic resource management - - Guarantees cleanup even with exceptions - - Prevents resource leaks - - Best practice for session usage - - Additional context: - --------------------------------- - - async with syntax is preferred - - Handles all cleanup paths - - Standard Python pattern - """ - async with AsyncCassandraSession(mock_session) as session: - assert isinstance(session, AsyncCassandraSession) - assert not session.is_closed - - # Session should be closed after exiting context - mock_session.shutdown.assert_called_once() - - @pytest.mark.asyncio - async def test_set_keyspace(self, async_session): - """ - Test setting keyspace. - - What this tests: - --------------- - 1. Keyspace change via USE statement - 2. Execute method called with correct query - 3. Async execution of keyspace change - 4. No errors on valid keyspace - - Why this matters: - ---------------- - - Multi-tenant applications switch keyspaces - - Session reuse across keyspaces - - Avoids creating multiple sessions - - Common operational requirement - - Additional context: - --------------------------------- - - USE statement changes active keyspace - - Affects all subsequent queries - - Alternative to connection-time keyspace - """ - with patch.object(async_session, "execute") as mock_execute: - mock_execute.return_value = AsyncResultSet([]) - - await async_session.set_keyspace("new_keyspace") - - mock_execute.assert_called_once_with("USE new_keyspace") - - @pytest.mark.asyncio - async def test_set_keyspace_invalid_name(self, async_session): - """ - Test setting keyspace with invalid name. - - What this tests: - --------------- - 1. Validation of keyspace names - 2. Rejection of invalid characters - 3. SQL injection prevention - 4. Clear error messages - - Why this matters: - ---------------- - - Security against injection attacks - - Prevents malformed CQL execution - - Data integrity protection - - User input validation - - Additional context: - --------------------------------- - - Tests spaces, dashes, semicolons - - CQL identifier rules enforced - - First line of defense - """ - # Test various invalid keyspace names - invalid_names = ["", "keyspace with spaces", "keyspace-with-dash", "keyspace;drop"] - - for invalid_name in invalid_names: - with pytest.raises(ValueError) as exc_info: - await async_session.set_keyspace(invalid_name) - - assert "Invalid keyspace name" in str(exc_info.value) - - def test_keyspace_property(self, async_session, mock_session): - """ - Test keyspace property. - - What this tests: - --------------- - 1. Keyspace property delegates to driver - 2. Read-only access to current keyspace - 3. 
Property reflects driver state - 4. No caching or staleness - - Why this matters: - ---------------- - - Applications need current keyspace info - - Debugging multi-keyspace operations - - State transparency - - API compatibility with driver - - Additional context: - --------------------------------- - - Property is read-only - - Always reflects driver state - - Used for logging and debugging - """ - mock_session.keyspace = "test_keyspace" - assert async_session.keyspace == "test_keyspace" - - def test_is_closed_property(self, async_session): - """ - Test is_closed property. - - What this tests: - --------------- - 1. Initial state is not closed - 2. Property reflects internal state - 3. Boolean property access - 4. State tracking accuracy - - Why this matters: - ---------------- - - Applications check before operations - - Lifecycle state visibility - - Defensive programming support - - Connection pool management - - Additional context: - --------------------------------- - - Used to prevent use-after-close - - Simple boolean check - - Thread-safe property access - """ - assert not async_session.is_closed - async_session._closed = True - assert async_session.is_closed diff --git a/tests/unit/test_session_edge_cases.py b/tests/unit/test_session_edge_cases.py deleted file mode 100644 index 4ca5224..0000000 --- a/tests/unit/test_session_edge_cases.py +++ /dev/null @@ -1,740 +0,0 @@ -""" -Unit tests for session edge cases and failure scenarios. - -Tests how the async wrapper handles various session-level failures and edge cases -within its existing functionality. -""" - -import asyncio -from unittest.mock import AsyncMock, Mock - -import pytest -from cassandra import InvalidRequest, OperationTimedOut, ReadTimeout, Unavailable, WriteTimeout -from cassandra.cluster import Session -from cassandra.query import BatchStatement, PreparedStatement, SimpleStatement - -from async_cassandra import AsyncCassandraSession - - -class TestSessionEdgeCases: - """Test session edge cases and failure scenarios.""" - - def _create_mock_future(self, result=None, error=None): - """Create a properly configured mock future that simulates driver behavior.""" - future = Mock() - - # Store callbacks - callbacks = [] - errbacks = [] - - def add_callbacks(callback=None, errback=None): - if callback: - callbacks.append(callback) - if errback: - errbacks.append(errback) - - # Delay the callback execution to allow AsyncResultHandler to set up properly - def execute_callback(): - if error: - if errback: - errback(error) - else: - if callback and result is not None: - # For successful results, pass rows - rows = getattr(result, "rows", []) - callback(rows) - - # Schedule callback for next event loop iteration - try: - loop = asyncio.get_running_loop() - loop.call_soon(execute_callback) - except RuntimeError: - # No event loop, execute immediately - execute_callback() - - future.add_callbacks = add_callbacks - future.has_more_pages = False - future.timeout = None - future.clear_callbacks = Mock() - - return future - - @pytest.fixture - def mock_session(self): - """Create a mock session.""" - session = Mock(spec=Session) - session.execute_async = Mock() - session.prepare_async = Mock() - session.close = Mock() - session.close_async = Mock() - session.cluster = Mock() - session.cluster.protocol_version = 5 - return session - - @pytest.fixture - async def async_session(self, mock_session): - """Create an async session wrapper.""" - return AsyncCassandraSession(mock_session) - - @pytest.mark.asyncio - async def 
test_execute_with_invalid_request(self, async_session, mock_session): - """ - Test handling of InvalidRequest errors. - - What this tests: - --------------- - 1. InvalidRequest not wrapped - 2. Error message preserved - 3. Direct propagation - 4. Query syntax errors - - Why this matters: - ---------------- - InvalidRequest indicates: - - Query syntax errors - - Schema mismatches - - Invalid operations - - Clear errors help developers - fix queries quickly. - """ - # Mock execute_async to fail with InvalidRequest - future = self._create_mock_future(error=InvalidRequest("Table does not exist")) - mock_session.execute_async.return_value = future - - # Should propagate InvalidRequest - with pytest.raises(InvalidRequest) as exc_info: - await async_session.execute("SELECT * FROM nonexistent_table") - - assert "Table does not exist" in str(exc_info.value) - assert mock_session.execute_async.called - - @pytest.mark.asyncio - async def test_execute_with_timeout(self, async_session, mock_session): - """ - Test handling of operation timeout. - - What this tests: - --------------- - 1. OperationTimedOut propagated - 2. Timeout errors not wrapped - 3. Message preserved - 4. Clean error handling - - Why this matters: - ---------------- - Timeouts are common: - - Slow queries - - Network issues - - Overloaded nodes - - Applications need clear - timeout information. - """ - # Mock execute_async to fail with timeout - future = self._create_mock_future(error=OperationTimedOut("Query timed out")) - mock_session.execute_async.return_value = future - - # Should propagate timeout - with pytest.raises(OperationTimedOut) as exc_info: - await async_session.execute("SELECT * FROM large_table") - - assert "Query timed out" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_execute_with_read_timeout(self, async_session, mock_session): - """ - Test handling of read timeout. - - What this tests: - --------------- - 1. ReadTimeout details preserved - 2. Response counts available - 3. Data retrieval flag set - 4. Not wrapped - - Why this matters: - ---------------- - Read timeout details crucial: - - Shows partial success - - Indicates retry potential - - Helps tune consistency - - Details enable smart - retry decisions. - """ - # Mock read timeout - future = self._create_mock_future( - error=ReadTimeout( - "Read timeout", - consistency_level=1, - required_responses=1, - received_responses=0, - data_retrieved=False, - ) - ) - mock_session.execute_async.return_value = future - - # Should propagate read timeout - with pytest.raises(ReadTimeout) as exc_info: - await async_session.execute("SELECT * FROM table") - - # Just verify we got the right exception with the message - assert "Read timeout" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_execute_with_write_timeout(self, async_session, mock_session): - """ - Test handling of write timeout. - - What this tests: - --------------- - 1. WriteTimeout propagated - 2. Write type preserved - 3. Response details available - 4. Proper error type - - Why this matters: - ---------------- - Write timeouts critical: - - May have partial writes - - Write type matters for retry - - Data consistency concerns - - Details determine if - retry is safe. 
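On the application side, the exceptions exercised above are usually caught around execute(). A minimal sketch (query and logger are illustrative):

import logging

from cassandra import OperationTimedOut, ReadTimeout, Unavailable

log = logging.getLogger(__name__)

async def fetch_user(session, user_id):
    try:
        return await session.execute("SELECT * FROM users WHERE id = %s", [user_id])
    except ReadTimeout as exc:
        log.warning("read timed out, partial responses received: %s", exc)
        raise
    except Unavailable as exc:
        log.error("not enough replicas alive for the consistency level: %s", exc)
        raise
    except OperationTimedOut as exc:
        log.warning("client-side timeout, query may still run server-side: %s", exc)
        raise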
- """ - # Mock write timeout (write_type needs to be numeric) - from cassandra import WriteType - - future = self._create_mock_future( - error=WriteTimeout( - "Write timeout", - consistency_level=1, - required_responses=1, - received_responses=0, - write_type=WriteType.SIMPLE, - ) - ) - mock_session.execute_async.return_value = future - - # Should propagate write timeout - with pytest.raises(WriteTimeout) as exc_info: - await async_session.execute("INSERT INTO table (id) VALUES (1)") - - # Just verify we got the right exception with the message - assert "Write timeout" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_execute_with_unavailable(self, async_session, mock_session): - """ - Test handling of Unavailable exception. - - What this tests: - --------------- - 1. Unavailable propagated - 2. Replica counts preserved - 3. Consistency level shown - 4. Clear error info - - Why this matters: - ---------------- - Unavailable means: - - Not enough replicas up - - Cluster health issue - - Cannot meet consistency - - Shows cluster state for - operations decisions. - """ - # Mock unavailable (consistency is first positional arg) - future = self._create_mock_future( - error=Unavailable( - "Not enough replicas", consistency=1, required_replicas=3, alive_replicas=1 - ) - ) - mock_session.execute_async.return_value = future - - # Should propagate unavailable - with pytest.raises(Unavailable) as exc_info: - await async_session.execute("SELECT * FROM table") - - # Just verify we got the right exception with the message - assert "Not enough replicas" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_prepare_statement_error(self, async_session, mock_session): - """ - Test error handling during statement preparation. - - What this tests: - --------------- - 1. Prepare errors wrapped - 2. QueryError with cause - 3. Error message clear - 4. Original exception preserved - - Why this matters: - ---------------- - Prepare failures indicate: - - Syntax errors - - Schema issues - - Permission problems - - Wrapped to distinguish from - execution errors. - """ - # Mock prepare to fail (it uses sync prepare in executor) - mock_session.prepare.side_effect = InvalidRequest("Syntax error in CQL") - - # Should pass through InvalidRequest directly - with pytest.raises(InvalidRequest) as exc_info: - await async_session.prepare("INVALID CQL SYNTAX") - - assert "Syntax error in CQL" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_execute_prepared_statement(self, async_session, mock_session): - """ - Test executing prepared statements. - - What this tests: - --------------- - 1. Prepared statements work - 2. Parameters handled - 3. Results returned - 4. Proper execution flow - - Why this matters: - ---------------- - Prepared statements are: - - Performance critical - - Security essential - - Most common pattern - - Must work seamlessly - through async wrapper. - """ - # Create mock prepared statement - prepared = Mock(spec=PreparedStatement) - prepared.query = "SELECT * FROM users WHERE id = ?" 
- - # Mock successful execution - result = Mock() - result.one = Mock(return_value={"id": 1, "name": "test"}) - result.rows = [{"id": 1, "name": "test"}] - future = self._create_mock_future(result=result) - mock_session.execute_async.return_value = future - - # Execute prepared statement - result = await async_session.execute(prepared, [1]) - assert result.one()["id"] == 1 - - @pytest.mark.asyncio - async def test_execute_batch_statement(self, async_session, mock_session): - """ - Test executing batch statements. - - What this tests: - --------------- - 1. Batch execution works - 2. Multiple statements grouped - 3. Parameters preserved - 4. Batch type maintained - - Why this matters: - ---------------- - Batches provide: - - Atomic operations - - Better performance - - Reduced round trips - - Critical for bulk - data operations. - """ - # Create batch statement - batch = BatchStatement() - batch.add(SimpleStatement("INSERT INTO users (id, name) VALUES (%s, %s)"), (1, "user1")) - batch.add(SimpleStatement("INSERT INTO users (id, name) VALUES (%s, %s)"), (2, "user2")) - - # Mock successful execution - result = Mock() - result.rows = [] - future = self._create_mock_future(result=result) - mock_session.execute_async.return_value = future - - # Execute batch - await async_session.execute(batch) - - # Verify batch was executed - mock_session.execute_async.assert_called_once() - call_args = mock_session.execute_async.call_args[0] - assert isinstance(call_args[0], BatchStatement) - - @pytest.mark.asyncio - async def test_concurrent_queries(self, async_session, mock_session): - """ - Test concurrent query execution. - - What this tests: - --------------- - 1. Concurrent execution allowed - 2. All queries complete - 3. Results independent - 4. True parallelism - - Why this matters: - ---------------- - Concurrency essential for: - - High throughput - - Parallel processing - - Efficient resource use - - Async wrapper must enable - true concurrent execution. - """ - # Track execution order to verify concurrency - execution_times = [] - - def execute_side_effect(*args, **kwargs): - import time - - execution_times.append(time.time()) - - # Create result - result = Mock() - result.one = Mock(return_value={"count": len(execution_times)}) - result.rows = [{"count": len(execution_times)}] - - # Use our standard mock future - future = self._create_mock_future(result=result) - return future - - mock_session.execute_async.side_effect = execute_side_effect - - # Execute multiple queries concurrently - queries = [async_session.execute(f"SELECT {i} FROM table") for i in range(10)] - - results = await asyncio.gather(*queries) - - # All should complete - assert len(results) == 10 - assert len(execution_times) == 10 - - # Verify we got results - for result in results: - assert len(result.rows) == 1 - assert result.rows[0]["count"] > 0 - - # The execute_async calls should happen close together (within 100ms) - # This verifies they were submitted concurrently - time_span = max(execution_times) - min(execution_times) - assert time_span < 0.1, f"Queries took {time_span}s, suggesting serial execution" - - @pytest.mark.asyncio - async def test_session_close_idempotent(self, async_session, mock_session): - """ - Test that session close is idempotent. - - What this tests: - --------------- - 1. Multiple closes safe - 2. Shutdown called once - 3. No errors on re-close - 4. 
State properly tracked - - Why this matters: - ---------------- - Idempotent close needed for: - - Error handling paths - - Multiple cleanup sources - - Resource leak prevention - - Safe cleanup in all - code paths. - """ - # Setup shutdown - mock_session.shutdown = Mock() - - # First close - await async_session.close() - assert mock_session.shutdown.call_count == 1 - - # Second close should be safe - await async_session.close() - # Should still only be called once - assert mock_session.shutdown.call_count == 1 - - @pytest.mark.asyncio - async def test_query_after_close(self, async_session, mock_session): - """ - Test querying after session is closed. - - What this tests: - --------------- - 1. Closed sessions reject queries - 2. ConnectionError raised - 3. Clear error message - 4. State checking works - - Why this matters: - ---------------- - Using closed resources: - - Common bug source - - Hard to debug - - Silent failures bad - - Clear errors prevent - mysterious failures. - """ - # Close session - mock_session.shutdown = Mock() - await async_session.close() - - # Try to execute query - should fail with ConnectionError - from async_cassandra.exceptions import ConnectionError - - with pytest.raises(ConnectionError) as exc_info: - await async_session.execute("SELECT * FROM table") - - assert "Session is closed" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_metrics_recording_on_success(self, mock_session): - """ - Test metrics are recorded on successful queries. - - What this tests: - --------------- - 1. Success metrics recorded - 2. Async metrics work - 3. Proper success flag - 4. No error type - - Why this matters: - ---------------- - Metrics enable: - - Performance monitoring - - Error tracking - - Capacity planning - - Accurate metrics critical - for production observability. - """ - # Create metrics mock - mock_metrics = Mock() - mock_metrics.record_query_metrics = AsyncMock() - - # Create session with metrics - async_session = AsyncCassandraSession(mock_session, metrics=mock_metrics) - - # Mock successful execution - result = Mock() - result.one = Mock(return_value={"id": 1}) - result.rows = [{"id": 1}] - future = self._create_mock_future(result=result) - mock_session.execute_async.return_value = future - - # Execute query - await async_session.execute("SELECT * FROM users") - - # Give time for async metrics recording - await asyncio.sleep(0.1) - - # Verify metrics were recorded - mock_metrics.record_query_metrics.assert_called_once() - call_kwargs = mock_metrics.record_query_metrics.call_args[1] - assert call_kwargs["success"] is True - assert call_kwargs["error_type"] is None - - @pytest.mark.asyncio - async def test_metrics_recording_on_failure(self, mock_session): - """ - Test metrics are recorded on failed queries. - - What this tests: - --------------- - 1. Failure metrics recorded - 2. Error type captured - 3. Success flag false - 4. Async recording works - - Why this matters: - ---------------- - Error metrics reveal: - - Problem patterns - - Error types - - Failure rates - - Essential for debugging - production issues. 
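These metrics tests pass a collector with an async record_query_metrics hook into the session. A minimal sketch of such a collector; only the success and error_type fields come from the tests, everything else is an assumption:

class SimpleQueryMetrics:
    """Counts query outcomes; the wrapper awaits record_query_metrics after each query."""

    def __init__(self):
        self.successes = 0
        self.failures = 0

    async def record_query_metrics(self, **fields):
        if fields.get("success"):
            self.successes += 1
        else:
            self.failures += 1  # fields.get("error_type") carries the exception name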
- """ - # Create metrics mock - mock_metrics = Mock() - mock_metrics.record_query_metrics = AsyncMock() - - # Create session with metrics - async_session = AsyncCassandraSession(mock_session, metrics=mock_metrics) - - # Mock failed execution - future = self._create_mock_future(error=InvalidRequest("Bad query")) - mock_session.execute_async.return_value = future - - # Execute query (should fail) - with pytest.raises(InvalidRequest): - await async_session.execute("INVALID QUERY") - - # Give time for async metrics recording - await asyncio.sleep(0.1) - - # Verify metrics were recorded - mock_metrics.record_query_metrics.assert_called_once() - call_kwargs = mock_metrics.record_query_metrics.call_args[1] - assert call_kwargs["success"] is False - assert call_kwargs["error_type"] == "InvalidRequest" - - @pytest.mark.asyncio - async def test_custom_payload_handling(self, async_session, mock_session): - """ - Test custom payload in queries. - - What this tests: - --------------- - 1. Custom payloads passed through - 2. Correct parameter position - 3. Payload preserved - 4. Driver feature works - - Why this matters: - ---------------- - Custom payloads enable: - - Request tracing - - Debugging metadata - - Cross-system correlation - - Important for distributed - system observability. - """ - # Mock execution with custom payload - result = Mock() - result.custom_payload = {"server_time": "2024-01-01"} - result.rows = [] - future = self._create_mock_future(result=result) - mock_session.execute_async.return_value = future - - # Execute with custom payload - custom_payload = {"client_id": "12345"} - result = await async_session.execute("SELECT * FROM table", custom_payload=custom_payload) - - # Verify custom payload was passed (4th positional arg) - call_args = mock_session.execute_async.call_args[0] - assert call_args[3] == custom_payload # custom_payload is 4th arg - - @pytest.mark.asyncio - async def test_trace_execution(self, async_session, mock_session): - """ - Test query tracing. - - What this tests: - --------------- - 1. Trace flag passed through - 2. Correct parameter position - 3. Tracing enabled - 4. Request setup correct - - Why this matters: - ---------------- - Query tracing helps: - - Debug slow queries - - Understand execution - - Optimize performance - - Essential debugging tool - for production issues. - """ - # Mock execution with trace - result = Mock() - result.get_query_trace = Mock(return_value=Mock(trace_id="abc123")) - result.rows = [] - future = self._create_mock_future(result=result) - mock_session.execute_async.return_value = future - - # Execute with tracing - result = await async_session.execute("SELECT * FROM table", trace=True) - - # Verify trace was requested (3rd positional arg) - call_args = mock_session.execute_async.call_args[0] - assert call_args[2] is True # trace is 3rd arg - - # AsyncResultSet doesn't expose trace methods - that's ok - # Just verify the request was made with trace=True - - @pytest.mark.asyncio - async def test_execution_profile_handling(self, async_session, mock_session): - """ - Test using execution profiles. - - What this tests: - --------------- - 1. Execution profiles work - 2. Profile name passed - 3. Correct parameter position - 4. Driver feature accessible - - Why this matters: - ---------------- - Execution profiles control: - - Consistency levels - - Retry policies - - Load balancing - - Critical for workload - optimization. 
- """ - # Mock execution - result = Mock() - result.rows = [] - future = self._create_mock_future(result=result) - mock_session.execute_async.return_value = future - - # Execute with custom profile - await async_session.execute("SELECT * FROM table", execution_profile="high_throughput") - - # Verify profile was passed (6th positional arg) - call_args = mock_session.execute_async.call_args[0] - assert call_args[5] == "high_throughput" # execution_profile is 6th arg - - @pytest.mark.asyncio - async def test_timeout_parameter(self, async_session, mock_session): - """ - Test query timeout parameter. - - What this tests: - --------------- - 1. Timeout parameter works - 2. Value passed correctly - 3. Correct position - 4. Per-query timeouts - - Why this matters: - ---------------- - Query timeouts prevent: - - Hanging queries - - Resource exhaustion - - Poor user experience - - Per-query control enables - SLA compliance. - """ - # Mock execution - result = Mock() - result.rows = [] - future = self._create_mock_future(result=result) - mock_session.execute_async.return_value = future - - # Execute with timeout - await async_session.execute("SELECT * FROM table", timeout=5.0) - - # Verify timeout was passed (5th positional arg) - call_args = mock_session.execute_async.call_args[0] - assert call_args[4] == 5.0 # timeout is 5th arg diff --git a/tests/unit/test_simplified_threading.py b/tests/unit/test_simplified_threading.py deleted file mode 100644 index 3e3ff3e..0000000 --- a/tests/unit/test_simplified_threading.py +++ /dev/null @@ -1,455 +0,0 @@ -""" -Unit tests for simplified threading implementation. - -These tests verify that the simplified implementation: -1. Uses only essential locks -2. Accepts reasonable trade-offs -3. Maintains thread safety where necessary -4. Performs better than complex locking -""" - -import asyncio -import time -from unittest.mock import Mock - -import pytest - -from async_cassandra.exceptions import ConnectionError -from async_cassandra.session import AsyncCassandraSession - - -@pytest.mark.asyncio -class TestSimplifiedThreading: - """Test simplified threading and locking implementation.""" - - async def test_no_operation_lock_overhead(self): - """ - Test that operations don't have unnecessary lock overhead. - - What this tests: - --------------- - 1. No locks on individual query operations - 2. Concurrent queries execute without contention - 3. Performance scales with concurrency - 4. 100 operations complete quickly - - Why this matters: - ---------------- - Previous implementations had per-operation locks that - caused contention under high concurrency. The simplified - implementation removes these locks, accepting that: - - Some edge cases during shutdown might be racy - - Performance is more important than perfect consistency - - This test proves the performance benefit is real. 
- """ - # Create session - mock_session = Mock() - mock_response_future = Mock() - mock_response_future.has_more_pages = False - mock_response_future.add_callbacks = Mock() - mock_response_future.timeout = None - mock_session.execute_async = Mock(return_value=mock_response_future) - - async_session = AsyncCassandraSession(mock_session) - - # Measure time for multiple concurrent operations - start_time = time.perf_counter() - - # Run many concurrent queries - tasks = [] - for i in range(100): - task = asyncio.create_task(async_session.execute(f"SELECT {i}")) - tasks.append(task) - - # Trigger callbacks - await asyncio.sleep(0) # Let tasks start - - # Trigger all callbacks - for call in mock_response_future.add_callbacks.call_args_list: - callback = call[1]["callback"] - callback([f"row{i}" for i in range(10)]) - - # Wait for all to complete - await asyncio.gather(*tasks) - - duration = time.perf_counter() - start_time - - # With simplified implementation, 100 concurrent ops should be very fast - # No operation locks means no contention - assert duration < 0.5 # Should complete in well under 500ms - assert mock_session.execute_async.call_count == 100 - - async def test_simple_close_behavior(self): - """ - Test simplified close behavior without complex state tracking. - - What this tests: - --------------- - 1. Close is simple and predictable - 2. Fixed 5-second delay for driver cleanup - 3. Subsequent operations fail immediately - 4. No complex state machine - - Why this matters: - ---------------- - The simplified implementation uses a simple approach: - - Set closed flag - - Wait 5 seconds for driver threads - - Shutdown underlying session - - This avoids complex tracking of in-flight operations - and accepts that some operations might fail during - the shutdown window. - """ - # Create session - mock_session = Mock() - mock_session.shutdown = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Close should be simple and fast - start_time = time.perf_counter() - await async_session.close() - close_duration = time.perf_counter() - start_time - - # Close includes a 5-second delay to let driver threads finish - assert 5.0 <= close_duration < 6.0 - assert async_session.is_closed - - # Subsequent operations should fail immediately (no complex checks) - with pytest.raises(ConnectionError): - await async_session.execute("SELECT 1") - - async def test_acceptable_race_condition(self): - """ - Test that we accept reasonable race conditions for simplicity. - - What this tests: - --------------- - 1. Operations during close might succeed or fail - 2. No guarantees about in-flight operations - 3. Various error outcomes are acceptable - 4. System remains stable regardless - - Why this matters: - ---------------- - The simplified implementation makes a trade-off: - - Remove complex operation tracking - - Accept that close() might interrupt operations - - Gain significant performance improvement - - This test verifies that the race conditions are - indeed "reasonable" - they don't crash or corrupt - state, they just return errors sometimes. 
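# Hedged sketch of the application-side consequence of this trade-off: code that
# might race a shutdown should treat ConnectionError (and, conservatively, driver
# errors) as a signal to stop issuing work rather than as a bug. `session` is
# assumed to be an AsyncCassandraSession; the query text is illustrative.
from async_cassandra.exceptions import ConnectionError


async def fetch_if_open(session, user_id):
    try:
        result = await session.execute("SELECT * FROM users WHERE id = %s", (user_id,))
        return result.one()
    except ConnectionError:
        # Session was closed concurrently; return "no data" instead of crashing.
        return None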
- """ - # Create session - mock_session = Mock() - mock_response_future = Mock() - mock_response_future.has_more_pages = False - mock_response_future.add_callbacks = Mock() - mock_response_future.timeout = None - mock_session.execute_async = Mock(return_value=mock_response_future) - mock_session.shutdown = Mock() - - async_session = AsyncCassandraSession(mock_session) - - results = [] - - async def execute_query(): - """Try to execute during close.""" - try: - # Start the execute - task = asyncio.create_task(async_session.execute("SELECT 1")) - # Give it a moment to start - await asyncio.sleep(0) - - # Trigger callback if it was registered - if mock_response_future.add_callbacks.called: - args = mock_response_future.add_callbacks.call_args - callback = args[1]["callback"] - callback(["row1"]) - - await task - results.append("success") - except ConnectionError: - results.append("closed") - except Exception as e: - # With simplified implementation, we might get driver errors - # if close happens during execution - this is acceptable - results.append(f"error: {type(e).__name__}") - - async def close_session(): - """Close after a tiny delay.""" - await asyncio.sleep(0.001) - await async_session.close() - - # Run concurrently - await asyncio.gather(execute_query(), close_session(), return_exceptions=True) - - # With simplified implementation, we accept that the result - # might be success, closed, or a driver error - assert len(results) == 1 - # Any of these outcomes is acceptable - assert results[0] in ["success", "closed"] or results[0].startswith("error:") - - async def test_no_complex_state_tracking(self): - """ - Test that we don't have complex state tracking. - - What this tests: - --------------- - 1. No _active_operations counter - 2. No _operation_lock for tracking - 3. No _close_event for coordination - 4. Only simple _closed flag and _close_lock - - Why this matters: - ---------------- - Complex state tracking was removed because: - - It added overhead to every operation - - Lock contention hurt performance - - Perfect tracking wasn't needed for correctness - - This test ensures we maintain the simplified - design and don't accidentally reintroduce - complex state management. - """ - # Create session - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Check that we don't have complex state attributes - # These should not exist in simplified implementation - assert not hasattr(async_session, "_active_operations") - assert not hasattr(async_session, "_operation_lock") - assert not hasattr(async_session, "_close_event") - - # Should only have simple state - assert hasattr(async_session, "_closed") - assert hasattr(async_session, "_close_lock") # Single lock for close - - async def test_result_handler_simplified(self): - """ - Test that result handlers are simplified. - - What this tests: - --------------- - 1. Handler has minimal state (just lock and rows) - 2. No complex initialization tracking - 3. No result ready events - 4. Thread lock is still necessary for callbacks - - Why this matters: - ---------------- - AsyncResultHandler bridges driver callbacks to async: - - Must be thread-safe (callbacks from driver threads) - - But doesn't need complex state tracking - - Just needs to safely accumulate results - - The simplified version keeps only what's essential. 
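# Hedged sketch of the pattern this handler implements, reduced to its essentials:
# a driver thread fires callback()/errback(), and the result is handed back to the
# waiting coroutine via call_soon_threadsafe. Single-page only; multi-page fetching
# is omitted. This is a generic illustration, not the library's AsyncResultHandler.
import asyncio
import threading


class MinimalResultHandler:
    def __init__(self, response_future):
        self._lock = threading.Lock()  # callbacks arrive on driver threads
        self.rows = []
        self._loop = asyncio.get_running_loop()  # constructed on the event loop
        self._done = self._loop.create_future()
        response_future.add_callbacks(callback=self._on_page, errback=self._on_error)

    def _on_page(self, rows):
        with self._lock:
            self.rows.extend(rows or [])
        self._loop.call_soon_threadsafe(self._done.set_result, None)

    def _on_error(self, exc):
        self._loop.call_soon_threadsafe(self._done.set_exception, exc)

    async def get_result(self):
        await self._done
        return self.rows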
- """ - from async_cassandra.result import AsyncResultHandler - - mock_future = Mock() - mock_future.has_more_pages = False - mock_future.add_callbacks = Mock() - mock_future.timeout = None - - handler = AsyncResultHandler(mock_future) - - # Should have minimal state tracking - assert hasattr(handler, "_lock") # Thread lock is necessary - assert hasattr(handler, "rows") - - # Should not have complex state tracking - assert not hasattr(handler, "_future_initialized") - assert not hasattr(handler, "_result_ready") - - async def test_streaming_simplified(self): - """ - Test that streaming result set is simplified. - - What this tests: - --------------- - 1. Streaming has thread lock for safety - 2. No complex callback tracking - 3. No active callback counters - 4. Minimal state management - - Why this matters: - ---------------- - Streaming involves multiple callbacks as pages - are fetched. The simplified implementation: - - Keeps thread safety (essential) - - Removes callback counting (not essential) - - Accepts that close() might interrupt streaming - - This maintains functionality while improving performance. - """ - from async_cassandra.streaming import AsyncStreamingResultSet, StreamConfig - - mock_future = Mock() - mock_future.has_more_pages = True - mock_future.add_callbacks = Mock() - - stream = AsyncStreamingResultSet(mock_future, StreamConfig()) - - # Should have thread lock (necessary for callbacks) - assert hasattr(stream, "_lock") - - # Should not have complex callback tracking - assert not hasattr(stream, "_active_callbacks") - - async def test_idempotent_close(self): - """ - Test that close is idempotent with simple implementation. - - What this tests: - --------------- - 1. Multiple close() calls are safe - 2. Only shuts down once - 3. No errors on repeated close - 4. Simple flag-based implementation - - Why this matters: - ---------------- - Users might call close() multiple times: - - In finally blocks - - In error handlers - - In cleanup code - - The simple implementation uses a flag to ensure - shutdown only happens once, without complex locking. - """ - # Create session - mock_session = Mock() - mock_session.shutdown = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Multiple closes should work without complex locking - await async_session.close() - await async_session.close() - await async_session.close() - - # Should only shutdown once - assert mock_session.shutdown.call_count == 1 - - async def test_no_operation_counting(self): - """ - Test that we don't count active operations. - - What this tests: - --------------- - 1. No tracking of in-flight operations - 2. Close doesn't wait for operations - 3. Fixed 5-second delay regardless - 4. Operations might fail during close - - Why this matters: - ---------------- - Operation counting was removed because: - - It required locks on every operation - - Caused contention under load - - Waiting for operations could hang - - The 5-second delay gives driver threads time - to finish naturally, without complex tracking. 
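# Hedged sketch of the calling pattern this design expects: because close() does
# not wait for in-flight queries, the application gathers its own tasks before
# closing. `session` is assumed to be an AsyncCassandraSession; the query is
# illustrative.
import asyncio


async def run_batch_then_close(session, user_ids):
    tasks = [
        asyncio.create_task(session.execute("SELECT * FROM users WHERE id = %s", (uid,)))
        for uid in user_ids
    ]
    try:
        # return_exceptions=True so one failed query does not leave siblings racing close()
        return await asyncio.gather(*tasks, return_exceptions=True)
    finally:
        # Safe even if an earlier cleanup path already closed the session (close is idempotent).
        await session.close()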
- """ - # Create session - mock_session = Mock() - mock_response_future = Mock() - mock_response_future.has_more_pages = False - mock_response_future.add_callbacks = Mock() - mock_response_future.timeout = None - - # Make execute_async slow to simulate long operation - async def slow_execute(*args, **kwargs): - await asyncio.sleep(0.1) - return mock_response_future - - mock_session.execute_async = Mock(side_effect=lambda *a, **k: mock_response_future) - mock_session.shutdown = Mock() - - async_session = AsyncCassandraSession(mock_session) - - # Start a query - query_task = asyncio.create_task(async_session.execute("SELECT 1")) - await asyncio.sleep(0.01) # Let it start - - # Close should not wait for operations - start_time = time.perf_counter() - await async_session.close() - close_duration = time.perf_counter() - start_time - - # Close includes a 5-second delay to let driver threads finish - assert 5.0 <= close_duration < 6.0 - - # Query might fail or succeed - both are acceptable - try: - # Trigger callback if query is still running - if mock_response_future.add_callbacks.called: - callback = mock_response_future.add_callbacks.call_args[1]["callback"] - callback(["row"]) - await query_task - except Exception: - # Error is acceptable if close interrupted it - pass - - @pytest.mark.benchmark - async def test_performance_improvement(self): - """ - Benchmark to show performance improvement with simplified locking. - - What this tests: - --------------- - 1. Throughput with many concurrent operations - 2. No lock contention slowing things down - 3. >5000 operations per second achievable - 4. Linear scaling with concurrency - - Why this matters: - ---------------- - This benchmark proves the value of simplification: - - Complex locking: ~1000 ops/second - - Simplified: >5000 ops/second - - The 5x improvement justifies accepting some - edge case race conditions during shutdown. - Real applications care more about throughput - than perfect shutdown semantics. 
- """ - # This test demonstrates that simplified locking improves performance - - # Create session - mock_session = Mock() - mock_response_future = Mock() - mock_response_future.has_more_pages = False - mock_response_future.add_callbacks = Mock() - mock_response_future.timeout = None - mock_session.execute_async = Mock(return_value=mock_response_future) - - async_session = AsyncCassandraSession(mock_session) - - # Measure throughput - iterations = 1000 - start_time = time.perf_counter() - - tasks = [] - for i in range(iterations): - task = asyncio.create_task(async_session.execute(f"SELECT {i}")) - tasks.append(task) - - # Trigger all callbacks immediately - await asyncio.sleep(0) - for call in mock_response_future.add_callbacks.call_args_list: - callback = call[1]["callback"] - callback(["row"]) - - await asyncio.gather(*tasks) - - duration = time.perf_counter() - start_time - ops_per_second = iterations / duration - - # With simplified locking, should handle >5000 ops/second - assert ops_per_second > 5000 - print(f"Performance: {ops_per_second:.0f} ops/second") diff --git a/tests/unit/test_sql_injection_protection.py b/tests/unit/test_sql_injection_protection.py deleted file mode 100644 index 8632d59..0000000 --- a/tests/unit/test_sql_injection_protection.py +++ /dev/null @@ -1,311 +0,0 @@ -"""Test SQL injection protection in example code.""" - -from unittest.mock import AsyncMock, MagicMock, call - -import pytest - -from async_cassandra import AsyncCassandraSession - - -class TestSQLInjectionProtection: - """Test that example code properly protects against SQL injection.""" - - @pytest.mark.asyncio - async def test_prepared_statements_used_for_user_input(self): - """ - Test that all user inputs use prepared statements. - - What this tests: - --------------- - 1. User input via prepared statements - 2. No dynamic SQL construction - 3. Parameters properly bound - 4. LIMIT values parameterized - - Why this matters: - ---------------- - SQL injection prevention requires: - - ALWAYS use prepared statements - - NEVER concatenate user input - - Parameterize ALL values - - This is THE most critical - security requirement. - """ - # Create mock session - mock_session = AsyncMock(spec=AsyncCassandraSession) - mock_stmt = AsyncMock() - mock_session.prepare.return_value = mock_stmt - - # Test LIMIT parameter - mock_session.execute.return_value = MagicMock() - await mock_session.prepare("SELECT * FROM users LIMIT ?") - await mock_session.execute(mock_stmt, [10]) - - # Verify prepared statement was used - mock_session.prepare.assert_called_with("SELECT * FROM users LIMIT ?") - mock_session.execute.assert_called_with(mock_stmt, [10]) - - @pytest.mark.asyncio - async def test_update_query_no_dynamic_sql(self): - """ - Test that UPDATE queries don't use dynamic SQL construction. - - What this tests: - --------------- - 1. UPDATE queries predefined - 2. No dynamic column lists - 3. All variations prepared - 4. Static query patterns - - Why this matters: - ---------------- - Dynamic SQL construction risky: - - Column names from user = danger - - Dynamic SET clauses = injection - - Must prepare all variations - - Prefer multiple prepared statements - over dynamic SQL generation. - """ - # Create mock session - mock_session = AsyncMock(spec=AsyncCassandraSession) - mock_stmt = AsyncMock() - mock_session.prepare.return_value = mock_stmt - - # Test different update scenarios - update_queries = [ - "UPDATE users SET name = ?, updated_at = ? WHERE id = ?", - "UPDATE users SET email = ?, updated_at = ? 
WHERE id = ?", - "UPDATE users SET age = ?, updated_at = ? WHERE id = ?", - "UPDATE users SET name = ?, email = ?, updated_at = ? WHERE id = ?", - "UPDATE users SET name = ?, age = ?, updated_at = ? WHERE id = ?", - "UPDATE users SET email = ?, age = ?, updated_at = ? WHERE id = ?", - "UPDATE users SET name = ?, email = ?, age = ?, updated_at = ? WHERE id = ?", - ] - - for query in update_queries: - await mock_session.prepare(query) - - # Verify only static queries were prepared - for query in update_queries: - assert call(query) in mock_session.prepare.call_args_list - - @pytest.mark.asyncio - async def test_table_name_validation_before_use(self): - """ - Test that table names are validated before use in queries. - - What this tests: - --------------- - 1. Table names validated first - 2. System tables checked - 3. Only valid tables queried - 4. Prevents table name injection - - Why this matters: - ---------------- - Table names cannot be parameterized: - - Must validate against whitelist - - Check system_schema.tables - - Reject unknown tables - - Critical when table names come - from external sources. - """ - # Create mock session - mock_session = AsyncMock(spec=AsyncCassandraSession) - - # Mock validation query response - mock_result = MagicMock() - mock_result.one.return_value = {"table_name": "products"} - mock_session.execute.return_value = mock_result - - # Test table validation - keyspace = "export_example" - table_name = "products" - - # Validate table exists - validation_result = await mock_session.execute( - "SELECT table_name FROM system_schema.tables WHERE keyspace_name = ? AND table_name = ?", - [keyspace, table_name], - ) - - # Only proceed if table exists - if validation_result.one(): - await mock_session.execute(f"SELECT COUNT(*) FROM {keyspace}.{table_name}") - - # Verify validation query was called - mock_session.execute.assert_any_call( - "SELECT table_name FROM system_schema.tables WHERE keyspace_name = ? AND table_name = ?", - [keyspace, table_name], - ) - - @pytest.mark.asyncio - async def test_no_string_interpolation_in_queries(self): - """ - Test that queries don't use string interpolation with user input. - - What this tests: - --------------- - 1. No f-strings with queries - 2. No .format() with SQL - 3. No string concatenation - 4. Safe parameter handling - - Why this matters: - ---------------- - String interpolation = SQL injection: - - f"{query}" is ALWAYS wrong - - "query " + value is DANGEROUS - - .format() enables attacks - - Prepared statements are the - ONLY safe approach. 
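# Hedged sketch of the only pattern these tests allow in application code: values
# are bound through a prepared statement, never interpolated into the query text.
# `session` is assumed to be an AsyncCassandraSession; table and column names are
# illustrative.
async def find_user_by_name(session, name: str):
    # UNSAFE (never do this): f"SELECT * FROM users WHERE name = '{name}'"
    stmt = await session.prepare("SELECT * FROM users WHERE name = ?")
    result = await session.execute(stmt, [name])
    return result.one()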
- """ - # Create mock session - mock_session = AsyncMock(spec=AsyncCassandraSession) - mock_stmt = AsyncMock() - mock_session.prepare.return_value = mock_stmt - - # Bad patterns that should NOT be used - user_input = "'; DROP TABLE users; --" - - # Good: Using prepared statements - await mock_session.prepare("SELECT * FROM users WHERE name = ?") - await mock_session.execute(mock_stmt, [user_input]) - - # Good: Using prepared statements for LIMIT - limit = "100; DROP TABLE users" - await mock_session.prepare("SELECT * FROM users LIMIT ?") - await mock_session.execute(mock_stmt, [int(limit.split(";")[0])]) # Parse safely - - # Verify prepared statements were used (not string interpolation) - # The execute calls should have the mock statement and parameters, not raw SQL - for exec_call in mock_session.execute.call_args_list: - # Each call should be execute(mock_stmt, [params]) - assert exec_call[0][0] == mock_stmt # First arg is the prepared statement - assert isinstance(exec_call[0][1], list) # Second arg is parameters list - - @pytest.mark.asyncio - async def test_hardcoded_keyspace_names(self): - """ - Test that keyspace names are hardcoded, not from user input. - - What this tests: - --------------- - 1. Keyspace names are constants - 2. No dynamic keyspace creation - 3. DDL uses fixed names - 4. set_keyspace uses constants - - Why this matters: - ---------------- - Keyspace names critical for security: - - Cannot be parameterized - - Must be hardcoded/whitelisted - - User input = disaster - - Never let users control - keyspace or table names. - """ - # Create mock session - mock_session = AsyncMock(spec=AsyncCassandraSession) - - # Good: Hardcoded keyspace names - await mock_session.execute( - """ - CREATE KEYSPACE IF NOT EXISTS example - WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1} - """ - ) - - await mock_session.set_keyspace("example") - - # Verify no dynamic keyspace creation - create_calls = [ - call for call in mock_session.execute.call_args_list if "CREATE KEYSPACE" in str(call) - ] - - for create_call in create_calls: - query = str(create_call) - # Should not contain f-string or format markers - assert "{" not in query or "{'class'" in query # Allow replication config - assert "%" not in query - - @pytest.mark.asyncio - async def test_streaming_queries_use_prepared_statements(self): - """ - Test that streaming queries use prepared statements. - - What this tests: - --------------- - 1. Streaming queries prepared - 2. Parameters used with streams - 3. No dynamic SQL in streams - 4. Safe LIMIT handling - - Why this matters: - ---------------- - Streaming queries especially risky: - - Process large data sets - - Long-running operations - - Injection = massive impact - - Must use prepared statements - even for streaming queries. - """ - # Create mock session - mock_session = AsyncMock(spec=AsyncCassandraSession) - mock_stmt = AsyncMock() - mock_session.prepare.return_value = mock_stmt - mock_session.execute_stream.return_value = AsyncMock() - - # Test streaming with parameters - limit = 1000 - await mock_session.prepare("SELECT * FROM users LIMIT ?") - await mock_session.execute_stream(mock_stmt, [limit]) - - # Verify prepared statement was used - mock_session.prepare.assert_called_with("SELECT * FROM users LIMIT ?") - mock_session.execute_stream.assert_called_with(mock_stmt, [limit]) - - def test_sql_injection_patterns_not_present(self): - """ - Test that common SQL injection patterns are not in the codebase. - - What this tests: - --------------- - 1. 
No f-string SQL queries - 2. No .format() with queries - 3. No string concatenation - 4. No %-formatting SQL - - Why this matters: - ---------------- - Static analysis prevents: - - Accidental SQL injection - - Bad patterns creeping in - - Security regressions - - Code reviews should check - for these dangerous patterns. - """ - # This is a meta-test to ensure dangerous patterns aren't used - dangerous_patterns = [ - 'f"SELECT', # f-string SQL - 'f"INSERT', - 'f"UPDATE', - 'f"DELETE', - '".format(', # format string SQL - '" + ', # string concatenation - "' + ", - "% (", # old-style formatting - "% {", - ] - - # In real implementation, this would scan the actual files - # For now, we just document what patterns to avoid - for pattern in dangerous_patterns: - # Document that these patterns should not be used - assert pattern in dangerous_patterns # Tautology for documentation diff --git a/tests/unit/test_streaming_unified.py b/tests/unit/test_streaming_unified.py deleted file mode 100644 index 41472a5..0000000 --- a/tests/unit/test_streaming_unified.py +++ /dev/null @@ -1,710 +0,0 @@ -""" -Unified streaming tests for async-python-cassandra. - -This module consolidates all streaming-related tests from multiple files: -- test_streaming.py: Core streaming functionality and multi-page iteration -- test_streaming_memory.py: Memory management during streaming -- test_streaming_memory_management.py: Duplicate memory management tests -- test_streaming_memory_leak.py: Memory leak prevention tests - -Test Organization: -================== -1. Core Streaming Tests - Basic streaming functionality -2. Multi-Page Streaming Tests - Pagination and page fetching -3. Memory Management Tests - Resource cleanup and leak prevention -4. Error Handling Tests - Streaming error scenarios -5. Cancellation Tests - Stream cancellation and cleanup -6. 
Performance Tests - Large result set handling - -Key Testing Principles: -====================== -- Test both single-page and multi-page results -- Verify memory is properly released -- Ensure callbacks are cleaned up -- Test error propagation during streaming -- Verify cancellation doesn't leak resources -""" - -import gc -import weakref -from typing import Any, AsyncIterator, List -from unittest.mock import AsyncMock, Mock, patch - -import pytest - -from async_cassandra import AsyncCassandraSession -from async_cassandra.exceptions import QueryError -from async_cassandra.streaming import StreamConfig - - -class MockAsyncStreamingResultSet: - """Mock implementation of AsyncStreamingResultSet for testing""" - - def __init__(self, rows: List[Any], pages: List[List[Any]] = None): - self.rows = rows - self.pages = pages or [rows] - self._current_page_index = 0 - self._current_row_index = 0 - self._closed = False - self.total_rows_fetched = 0 - - async def __aenter__(self): - return self - - async def __aexit__(self, *args): - await self.close() - - async def close(self): - self._closed = True - - def __aiter__(self): - return self - - async def __anext__(self): - if self._closed: - raise StopAsyncIteration - - # If we have pages - if self.pages: - if self._current_page_index >= len(self.pages): - raise StopAsyncIteration - - current_page = self.pages[self._current_page_index] - if self._current_row_index >= len(current_page): - self._current_page_index += 1 - self._current_row_index = 0 - - if self._current_page_index >= len(self.pages): - raise StopAsyncIteration - - current_page = self.pages[self._current_page_index] - - row = current_page[self._current_row_index] - self._current_row_index += 1 - self.total_rows_fetched += 1 - return row - else: - # Simple case - all rows in one list - if self._current_row_index >= len(self.rows): - raise StopAsyncIteration - - row = self.rows[self._current_row_index] - self._current_row_index += 1 - self.total_rows_fetched += 1 - return row - - async def pages(self) -> AsyncIterator[List[Any]]: - """Iterate over pages instead of rows""" - for page in self.pages: - yield page - - -class TestStreamingFunctionality: - """ - Test core streaming functionality. - - Streaming is used for large result sets that don't fit in memory. - These tests verify the streaming API works correctly. - """ - - @pytest.mark.asyncio - async def test_single_page_streaming(self): - """ - Test streaming with a single page of results. - - What this tests: - --------------- - 1. execute_stream returns AsyncStreamingResultSet - 2. Single page results work correctly - 3. Context manager properly opens/closes stream - 4. All rows are yielded - - Why this matters: - ---------------- - Even single-page results should work with streaming API - for consistency. This is the simplest streaming case. 
- """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Mock the execute_stream to return our mock streaming result - rows = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}, {"id": 3, "name": "Charlie"}] - - mock_stream = MockAsyncStreamingResultSet(rows) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - # Collect all streamed rows - collected_rows = [] - async with await async_session.execute_stream("SELECT * FROM users") as stream: - async for row in stream: - collected_rows.append(row) - - # Verify all rows were streamed - assert len(collected_rows) == 3 - assert collected_rows[0]["name"] == "Alice" - assert collected_rows[1]["name"] == "Bob" - assert collected_rows[2]["name"] == "Charlie" - - @pytest.mark.asyncio - async def test_multi_page_streaming(self): - """ - Test streaming with multiple pages of results. - - What this tests: - --------------- - 1. Multiple pages are fetched automatically - 2. Page boundaries are transparent to user - 3. All pages are processed in order - 4. Has_more_pages triggers next fetch - - Why this matters: - ---------------- - Large result sets span multiple pages. The streaming - API must seamlessly fetch pages as needed. - """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Define pages of data - pages = [ - [{"id": 1}, {"id": 2}, {"id": 3}], - [{"id": 4}, {"id": 5}, {"id": 6}], - [{"id": 7}, {"id": 8}, {"id": 9}], - ] - - all_rows = [row for page in pages for row in page] - mock_stream = MockAsyncStreamingResultSet(all_rows, pages) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - # Stream all pages - collected_rows = [] - async with await async_session.execute_stream("SELECT * FROM large_table") as stream: - async for row in stream: - collected_rows.append(row) - - # Verify all rows from all pages - assert len(collected_rows) == 9 - assert [r["id"] for r in collected_rows] == list(range(1, 10)) - - @pytest.mark.asyncio - async def test_streaming_with_fetch_size(self): - """ - Test streaming with custom fetch size. - - What this tests: - --------------- - 1. Custom fetch_size is respected - 2. Page size affects streaming behavior - 3. Configuration passes through correctly - - Why this matters: - ---------------- - Fetch size controls memory usage and performance. - Users need to tune this for their use case. - """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Just verify the config is passed - actual pagination is tested elsewhere - rows = [{"id": i} for i in range(100)] - mock_stream = MockAsyncStreamingResultSet(rows) - - # Mock execute_stream to verify it's called with correct config - execute_stream_mock = AsyncMock(return_value=mock_stream) - - with patch.object(async_session, "execute_stream", execute_stream_mock): - stream_config = StreamConfig(fetch_size=1000) - async with await async_session.execute_stream( - "SELECT * FROM large_table", stream_config=stream_config - ) as stream: - async for row in stream: - pass - - # Verify execute_stream was called with the config - execute_stream_mock.assert_called_once() - args, kwargs = execute_stream_mock.call_args - assert kwargs.get("stream_config") == stream_config - - @pytest.mark.asyncio - async def test_streaming_error_propagation(self): - """ - Test error handling during streaming. - - What this tests: - --------------- - 1. Errors are properly propagated - 2. Context manager handles errors - 3. 
Resources are cleaned up on error - - Why this matters: - ---------------- - Streaming operations can fail mid-stream. Errors must - be handled gracefully without resource leaks. - """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Create a mock that will raise an error - error_msg = "Network error during streaming" - execute_stream_mock = AsyncMock(side_effect=QueryError(error_msg)) - - with patch.object(async_session, "execute_stream", execute_stream_mock): - # Verify error is propagated - with pytest.raises(QueryError) as exc_info: - async with await async_session.execute_stream("SELECT * FROM test") as stream: - async for row in stream: - pass - - assert error_msg in str(exc_info.value) - - @pytest.mark.asyncio - async def test_streaming_cancellation(self): - """ - Test cancelling streaming mid-iteration. - - What this tests: - --------------- - 1. Stream can be cancelled - 2. Resources are cleaned up - 3. No errors on early exit - - Why this matters: - ---------------- - Users may need to stop streaming early. This shouldn't - leak resources or cause errors. - """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Large result set - rows = [{"id": i} for i in range(1000)] - mock_stream = MockAsyncStreamingResultSet(rows) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - processed = 0 - async with await async_session.execute_stream("SELECT * FROM large_table") as stream: - async for row in stream: - processed += 1 - if processed >= 10: - break # Early exit - - # Verify we stopped early - assert processed == 10 - # Verify stream was closed - assert mock_stream._closed - - @pytest.mark.asyncio - async def test_empty_result_streaming(self): - """ - Test streaming with empty results. - - What this tests: - --------------- - 1. Empty results don't cause errors - 2. Iterator completes immediately - 3. Context manager works with no data - - Why this matters: - ---------------- - Queries may return no results. The streaming API - should handle this gracefully. - """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Empty result - mock_stream = MockAsyncStreamingResultSet([]) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - rows_found = 0 - async with await async_session.execute_stream("SELECT * FROM empty_table") as stream: - async for row in stream: - rows_found += 1 - - assert rows_found == 0 - - -class TestStreamingMemoryManagement: - """ - Test memory management during streaming operations. - - These tests verify that streaming doesn't leak memory and - properly cleans up resources. - """ - - @pytest.mark.asyncio - async def test_memory_cleanup_after_streaming(self): - """ - Test memory is released after streaming completes. - - What this tests: - --------------- - 1. Row objects are not retained after iteration - 2. Internal buffers are cleared - 3. Garbage collection works properly - - Why this matters: - ---------------- - Streaming large datasets shouldn't cause memory to - accumulate. Each page should be released after processing. 
- """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Track row object references - row_refs = [] - - # Create rows that support weakref - class Row: - def __init__(self, id, data): - self.id = id - self.data = data - - def __getitem__(self, key): - return getattr(self, key) - - rows = [] - for i in range(100): - row = Row(id=i, data="x" * 1000) - rows.append(row) - row_refs.append(weakref.ref(row)) - - mock_stream = MockAsyncStreamingResultSet(rows) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - # Stream and process rows - processed = 0 - async with await async_session.execute_stream("SELECT * FROM test") as stream: - async for row in stream: - processed += 1 - # Don't keep references - - # Clear all references - rows = None - mock_stream.rows = [] - mock_stream.pages = [] - mock_stream = None - - # Force garbage collection - gc.collect() - - # Check that rows were released - alive_refs = sum(1 for ref in row_refs if ref() is not None) - assert processed == 100 - # Most rows should be collected (some may still be referenced) - assert alive_refs < 10 - - @pytest.mark.asyncio - async def test_memory_cleanup_on_error(self): - """ - Test memory cleanup when error occurs during streaming. - - What this tests: - --------------- - 1. Partial results are cleaned up on error - 2. Callbacks are removed - 3. No dangling references - - Why this matters: - ---------------- - Errors during streaming shouldn't leak the partially - processed results or internal state. - """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Create a stream that will fail mid-iteration - class FailingStream(MockAsyncStreamingResultSet): - def __init__(self, rows): - super().__init__(rows) - self.iterations = 0 - - async def __anext__(self): - self.iterations += 1 - if self.iterations > 5: - raise Exception("Database error") - return await super().__anext__() - - rows = [{"id": i} for i in range(50)] - mock_stream = FailingStream(rows) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - # Try to stream, should error - with pytest.raises(Exception) as exc_info: - async with await async_session.execute_stream("SELECT * FROM test") as stream: - async for row in stream: - pass - - assert "Database error" in str(exc_info.value) - # Stream should be closed even on error - assert mock_stream._closed - - @pytest.mark.asyncio - async def test_no_memory_leak_with_many_pages(self): - """ - Test no memory accumulation with many pages. - - What this tests: - --------------- - 1. Memory doesn't grow with page count - 2. Old pages are released - 3. Only current page is in memory - - Why this matters: - ---------------- - Streaming millions of rows across thousands of pages - shouldn't cause memory to grow unbounded. 
- """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Create many small pages - pages = [] - for page_num in range(100): - page = [{"id": page_num * 10 + i, "page": page_num} for i in range(10)] - pages.append(page) - - all_rows = [row for page in pages for row in page] - mock_stream = MockAsyncStreamingResultSet(all_rows, pages) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - # Stream through all pages - total_rows = 0 - page_numbers_seen = set() - - async with await async_session.execute_stream("SELECT * FROM huge_table") as stream: - async for row in stream: - total_rows += 1 - page_numbers_seen.add(row.get("page")) - - # Verify we processed all pages - assert total_rows == 1000 - assert len(page_numbers_seen) == 100 - - @pytest.mark.asyncio - async def test_stream_close_releases_resources(self): - """ - Test that closing stream releases all resources. - - What this tests: - --------------- - 1. Explicit close() works - 2. Resources are freed immediately - 3. Cannot iterate after close - - Why this matters: - ---------------- - Users may need to close streams early. This should - immediately free all resources. - """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - rows = [{"id": i} for i in range(100)] - mock_stream = MockAsyncStreamingResultSet(rows) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - stream = await async_session.execute_stream("SELECT * FROM test") - - # Process a few rows - row_count = 0 - async for row in stream: - row_count += 1 - if row_count >= 5: - break - - # Explicitly close - await stream.close() - - # Verify closed - assert stream._closed - - # Cannot iterate after close - with pytest.raises(StopAsyncIteration): - await stream.__anext__() - - @pytest.mark.asyncio - async def test_weakref_cleanup_on_session_close(self): - """ - Test cleanup when session is closed during streaming. - - What this tests: - --------------- - 1. Session close interrupts streaming - 2. Stream resources are cleaned up - 3. No dangling references - - Why this matters: - ---------------- - Session may be closed while streams are active. This - shouldn't leak stream resources. - """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Track if stream was cleaned up - stream_closed = False - - class TrackableStream(MockAsyncStreamingResultSet): - async def close(self): - nonlocal stream_closed - stream_closed = True - await super().close() - - rows = [{"id": i} for i in range(1000)] - mock_stream = TrackableStream(rows) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - # Start streaming but don't finish - stream = await async_session.execute_stream("SELECT * FROM test") - - # Process a few rows - count = 0 - async for row in stream: - count += 1 - if count >= 5: - break - - # Close the stream (simulating session close) - await stream.close() - - # Verify cleanup happened - assert stream_closed - - -class TestStreamingPerformance: - """ - Test streaming performance characteristics. - - These tests verify streaming can handle large datasets efficiently. - """ - - @pytest.mark.asyncio - async def test_streaming_large_rows(self): - """ - Test streaming rows with large data. - - What this tests: - --------------- - 1. Large rows don't cause issues - 2. Memory per row is bounded - 3. 
Streaming continues smoothly - - Why this matters: - ---------------- - Some rows may contain blobs or large text fields. - Streaming should handle these efficiently. - """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Create rows with large data - rows = [] - for i in range(50): - rows.append( - { - "id": i, - "data": "x" * 100000, # 100KB per row - "blob": b"y" * 50000, # 50KB binary - } - ) - - mock_stream = MockAsyncStreamingResultSet(rows) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - processed = 0 - total_size = 0 - - async with await async_session.execute_stream("SELECT * FROM blobs") as stream: - async for row in stream: - processed += 1 - total_size += len(row["data"]) + len(row["blob"]) - - assert processed == 50 - assert total_size == 50 * (100000 + 50000) - - @pytest.mark.asyncio - async def test_streaming_high_throughput(self): - """ - Test streaming can maintain high throughput. - - What this tests: - --------------- - 1. Thousands of rows/second possible - 2. Minimal overhead per row - 3. Efficient page transitions - - Why this matters: - ---------------- - Bulk data operations need high throughput. Streaming - overhead must be minimal. - """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Simulate high-throughput scenario - rows_per_page = 5000 - num_pages = 20 - - pages = [] - for page_num in range(num_pages): - page = [{"id": page_num * rows_per_page + i} for i in range(rows_per_page)] - pages.append(page) - - all_rows = [row for page in pages for row in page] - mock_stream = MockAsyncStreamingResultSet(all_rows, pages) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - # Stream all rows and measure throughput - import time - - start_time = time.time() - - total_rows = 0 - async with await async_session.execute_stream("SELECT * FROM big_table") as stream: - async for row in stream: - total_rows += 1 - - elapsed = time.time() - start_time - - expected_total = rows_per_page * num_pages - assert total_rows == expected_total - - # Should process quickly (implementation dependent) - # This documents the performance expectation - rows_per_second = total_rows / elapsed if elapsed > 0 else 0 - # Should handle thousands of rows/second - assert rows_per_second > 0 # Use the variable - - @pytest.mark.asyncio - async def test_streaming_memory_limit_enforcement(self): - """ - Test memory limits are enforced during streaming. - - What this tests: - --------------- - 1. Configurable memory limits - 2. Backpressure when limit reached - 3. Graceful handling of limits - - Why this matters: - ---------------- - Production systems have memory constraints. Streaming - must respect these limits. 
- """ - mock_session = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Large amount of data - rows = [{"id": i, "data": "x" * 10000} for i in range(1000)] - mock_stream = MockAsyncStreamingResultSet(rows) - - with patch.object(async_session, "execute_stream", return_value=mock_stream): - # Stream with memory awareness - rows_processed = 0 - async with await async_session.execute_stream("SELECT * FROM test") as stream: - async for row in stream: - rows_processed += 1 - # In real implementation, might pause/backpressure here - if rows_processed >= 100: - break diff --git a/tests/unit/test_thread_safety.py b/tests/unit/test_thread_safety.py deleted file mode 100644 index 9783d7e..0000000 --- a/tests/unit/test_thread_safety.py +++ /dev/null @@ -1,454 +0,0 @@ -"""Core thread safety and event loop handling tests. - -This module tests the critical thread pool configuration and event loop -integration that enables the async wrapper to work correctly. - -Test Organization: -================== -- TestEventLoopHandling: Event loop creation and management across threads -- TestThreadPoolConfiguration: Thread pool limits and concurrent execution - -Key Testing Focus: -================== -1. Event loop isolation between threads -2. Thread-safe callback scheduling -3. Thread pool size limits -4. Concurrent operation handling -5. Thread-local storage isolation - -Why This Matters: -================= -The Cassandra driver uses threads for I/O, while our wrapper provides -async/await interface. This requires careful thread and event loop -management to prevent: -- Deadlocks between threads and event loops -- Event loop conflicts -- Thread pool exhaustion -- Race conditions in callbacks -""" - -import asyncio -import threading -from unittest.mock import AsyncMock, Mock, patch - -import pytest - -from async_cassandra.utils import get_or_create_event_loop, safe_call_soon_threadsafe - -# Test constants -MAX_WORKERS = 32 -_thread_local = threading.local() - - -class TestEventLoopHandling: - """ - Test event loop management in threaded environments. - - The async wrapper must handle event loops correctly across - multiple threads since Cassandra driver callbacks may come - from any thread in the executor pool. - """ - - @pytest.mark.core - @pytest.mark.quick - async def test_get_or_create_event_loop_main_thread(self): - """ - Test getting event loop in main thread. - - What this tests: - --------------- - 1. In async context, returns the running loop - 2. Doesn't create a new loop when one exists - 3. Returns the correct loop instance - - Why this matters: - ---------------- - The main thread typically has an event loop (from asyncio.run - or pytest-asyncio). We must use the existing loop rather than - creating a new one, which would cause: - - Event loop conflicts - - Callbacks lost in wrong loop - - "Event loop is closed" errors - """ - # In async context, should return the running loop - expected_loop = asyncio.get_running_loop() - result = get_or_create_event_loop() - assert result == expected_loop - - @pytest.mark.core - def test_get_or_create_event_loop_worker_thread(self): - """ - Test creating event loop in worker thread. - - What this tests: - --------------- - 1. Worker threads create new event loops - 2. Created loop is stored thread-locally - 3. Loop is properly initialized - 4. Thread can use its own loop - - Why this matters: - ---------------- - Cassandra driver uses a thread pool for I/O operations. 
- When callbacks fire in these threads, they need a way to - communicate results back to the main async context. Each - worker thread needs its own event loop to: - - Schedule callbacks to main loop - - Handle thread-local async operations - - Avoid conflicts with other threads - - Without this, callbacks from driver threads would fail. - """ - result_loop = None - - def worker(): - nonlocal result_loop - # Worker thread should create a new loop - result_loop = get_or_create_event_loop() - assert result_loop is not None - assert isinstance(result_loop, asyncio.AbstractEventLoop) - - thread = threading.Thread(target=worker) - thread.start() - thread.join() - - assert result_loop is not None - - @pytest.mark.core - @pytest.mark.critical - def test_thread_local_event_loops(self): - """ - Test that each thread gets its own event loop. - - What this tests: - --------------- - 1. Multiple threads each get unique loops - 2. Loops don't interfere with each other - 3. Thread-local storage works correctly - 4. No loop sharing between threads - - Why this matters: - ---------------- - Event loops are not thread-safe. Sharing loops between - threads would cause: - - Race conditions - - Corrupted event loop state - - Callbacks executed in wrong thread - - Deadlocks and hangs - - This test ensures our thread-local storage pattern - correctly isolates event loops, which is critical for - the driver's thread pool to work with async/await. - """ - loops = [] - - def worker(): - loop = get_or_create_event_loop() - loops.append(loop) - - threads = [] - for _ in range(5): - thread = threading.Thread(target=worker) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() - - # Each thread should have created a unique loop - assert len(loops) == 5 - assert len(set(id(loop) for loop in loops)) == 5 - - @pytest.mark.core - async def test_safe_call_soon_threadsafe(self): - """ - Test thread-safe callback scheduling. - - What this tests: - --------------- - 1. Callbacks can be scheduled from same thread - 2. Callback executes in the target loop - 3. Arguments are passed correctly - 4. Callback runs asynchronously - - Why this matters: - ---------------- - This is the bridge between driver threads and async code: - - Driver completes query in thread pool - - Needs to deliver result to async context - - Must use call_soon_threadsafe for safety - - The safe wrapper handles edge cases like closed loops. - """ - result = [] - - def callback(value): - result.append(value) - - loop = asyncio.get_running_loop() - - # Schedule callback from same thread - safe_call_soon_threadsafe(loop, callback, "test1") - - # Give callback time to execute - await asyncio.sleep(0.1) - - assert result == ["test1"] - - @pytest.mark.core - def test_safe_call_soon_threadsafe_from_thread(self): - """ - Test scheduling callback from different thread. - - What this tests: - --------------- - 1. Callbacks work across thread boundaries - 2. Target loop receives callback correctly - 3. Synchronization works (via Event) - 4. No race conditions or deadlocks - - Why this matters: - ---------------- - This simulates the real scenario: - - Main thread has async event loop - - Driver thread completes I/O operation - - Driver thread schedules callback to main loop - - Result delivered safely across threads - - This is the core mechanism that makes the async - wrapper possible - bridging sync callbacks to async. 
- """ - result = [] - event = threading.Event() - - def callback(value): - result.append(value) - event.set() - - loop = asyncio.new_event_loop() - - def run_loop(): - asyncio.set_event_loop(loop) - loop.run_forever() - - loop_thread = threading.Thread(target=run_loop) - loop_thread.start() - - try: - # Schedule from different thread - def worker(): - safe_call_soon_threadsafe(loop, callback, "test2") - - worker_thread = threading.Thread(target=worker) - worker_thread.start() - worker_thread.join() - - # Wait for callback - event.wait(timeout=1) - assert result == ["test2"] - - finally: - loop.call_soon_threadsafe(loop.stop) - loop_thread.join() - loop.close() - - @pytest.mark.core - def test_safe_call_soon_threadsafe_closed_loop(self): - """ - Test handling of closed event loop. - - What this tests: - --------------- - 1. Closed loop is handled gracefully - 2. No exception is raised - 3. Callback is silently dropped - 4. System remains stable - - Why this matters: - ---------------- - During shutdown or error scenarios: - - Event loop might be closed - - Driver callbacks might still arrive - - Must not crash the application - - Should fail silently rather than propagate - - This defensive programming prevents crashes during - shutdown sequences or error recovery. - """ - loop = asyncio.new_event_loop() - loop.close() - - # Should handle gracefully - safe_call_soon_threadsafe(loop, lambda: None) - # No exception should be raised - - -class TestThreadPoolConfiguration: - """ - Test thread pool configuration and limits. - - The Cassandra driver uses a thread pool for I/O operations. - These tests ensure proper configuration and behavior under load. - """ - - @pytest.mark.core - @pytest.mark.quick - def test_max_workers_constant(self): - """ - Test MAX_WORKERS is set correctly. - - What this tests: - --------------- - 1. Thread pool size constant is defined - 2. Value is reasonable (32 threads) - 3. Constant is accessible - - Why this matters: - ---------------- - Thread pool size affects: - - Maximum concurrent operations - - Memory usage (each thread has stack) - - Performance under load - - 32 threads is a balance between concurrency and - resource usage for typical applications. - """ - assert MAX_WORKERS == 32 - - @pytest.mark.core - def test_thread_pool_creation(self): - """ - Test thread pool is created with correct parameters. - - What this tests: - --------------- - 1. AsyncCluster respects executor_threads parameter - 2. Thread pool is created with specified size - 3. Configuration flows to driver correctly - - Why this matters: - ---------------- - Applications need to tune thread pool size based on: - - Expected query volume - - Available system resources - - Latency requirements - - Too few threads: queries queue up, high latency - Too many threads: memory waste, context switching - - This ensures the configuration works as expected. - """ - from async_cassandra.cluster import AsyncCluster - - cluster = AsyncCluster(executor_threads=16) - assert cluster._cluster.executor._max_workers == 16 - - @pytest.mark.core - @pytest.mark.critical - async def test_concurrent_operations_within_limit(self): - """ - Test handling concurrent operations within thread pool limit. - - What this tests: - --------------- - 1. Multiple concurrent queries execute successfully - 2. All operations complete without blocking - 3. Results are delivered correctly - 4. 
No thread pool exhaustion with reasonable load - - Why this matters: - ---------------- - Real applications execute many queries concurrently: - - Web requests trigger multiple queries - - Batch processing runs parallel operations - - Background tasks query simultaneously - - The thread pool must handle reasonable concurrency - without deadlocking or failing. This test simulates - a typical concurrent load scenario. - - 10 concurrent operations is well within the 32 thread - limit, so all should complete successfully. - """ - from cassandra.cluster import ResponseFuture - - from async_cassandra.session import AsyncCassandraSession as AsyncSession - - mock_session = Mock() - results = [] - - def mock_execute_async(*args, **kwargs): - mock_future = Mock(spec=ResponseFuture) - mock_future.result.return_value = Mock(rows=[]) - mock_future.timeout = None - mock_future.has_more_pages = False - results.append(1) - return mock_future - - mock_session.execute_async.side_effect = mock_execute_async - - async_session = AsyncSession(mock_session) - - # Run operations concurrently - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=Mock(rows=[])) - mock_handler_class.return_value = mock_handler - - tasks = [] - for i in range(10): - task = asyncio.create_task(async_session.execute(f"SELECT * FROM table{i}")) - tasks.append(task) - - await asyncio.gather(*tasks) - - # All operations should complete - assert len(results) == 10 - - @pytest.mark.core - def test_thread_local_storage(self): - """ - Test thread-local storage for event loops. - - What this tests: - --------------- - 1. Each thread has isolated storage - 2. Values don't leak between threads - 3. Thread-local mechanism works correctly - 4. Storage is truly thread-specific - - Why this matters: - ---------------- - Thread-local storage is critical for: - - Event loop isolation (each thread's loop) - - Connection state per thread - - Avoiding race conditions - - If thread-local storage failed: - - Event loops would be shared (crashes) - - State would corrupt between threads - - Race conditions everywhere - - This fundamental mechanism enables safe multi-threaded - operation of the async wrapper. - """ - # Each thread should have its own storage - storage_values = [] - - def worker(value): - _thread_local.test_value = value - storage_values.append((_thread_local.test_value, threading.current_thread().ident)) - - threads = [] - for i in range(5): - thread = threading.Thread(target=worker, args=(i,)) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() - - # Each thread should have stored its own value - assert len(storage_values) == 5 - values = [v[0] for v in storage_values] - assert sorted(values) == [0, 1, 2, 3, 4] diff --git a/tests/unit/test_timeout_unified.py b/tests/unit/test_timeout_unified.py deleted file mode 100644 index 8c8d5c6..0000000 --- a/tests/unit/test_timeout_unified.py +++ /dev/null @@ -1,517 +0,0 @@ -""" -Consolidated timeout tests for async-python-cassandra. - -This module consolidates timeout testing from multiple files into focused, -clear tests that match the actual implementation. - -Test Organization: -================== -1. Query Timeout Tests - Timeout parameter propagation -2. Timeout Exception Tests - ReadTimeout, WriteTimeout handling -3. Prepare Timeout Tests - Statement preparation timeouts -4. 
Resource Cleanup Tests - Proper cleanup on timeout - -Key Testing Principles: -====================== -- Test timeout parameter flow through the layers -- Verify timeout exceptions are handled correctly -- Ensure no resource leaks on timeout -- Test default timeout behavior -""" - -import asyncio -from unittest.mock import AsyncMock, Mock, patch - -import pytest -from cassandra import ReadTimeout, WriteTimeout -from cassandra.cluster import _NOT_SET, ResponseFuture -from cassandra.policies import WriteType - -from async_cassandra import AsyncCassandraSession - - -class TestTimeoutHandling: - """ - Test timeout handling throughout the async wrapper. - - These tests verify that timeouts work correctly at all levels - and that timeout exceptions are properly handled. - """ - - # ======================================== - # Query Timeout Tests - # ======================================== - - @pytest.mark.asyncio - async def test_execute_with_explicit_timeout(self): - """ - Test that explicit timeout is passed to driver. - - What this tests: - --------------- - 1. Timeout parameter flows to execute_async - 2. Timeout value is preserved correctly - 3. Handler receives timeout for its operation - - Why this matters: - ---------------- - Users need to control query timeouts for different - operations based on their performance requirements. - """ - mock_session = Mock() - mock_future = Mock(spec=ResponseFuture) - mock_future.has_more_pages = False - mock_session.execute_async.return_value = mock_future - - async_session = AsyncCassandraSession(mock_session) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=Mock(rows=[])) - mock_handler_class.return_value = mock_handler - - await async_session.execute("SELECT * FROM test", timeout=5.0) - - # Verify execute_async was called with timeout - mock_session.execute_async.assert_called_once() - args = mock_session.execute_async.call_args[0] - # timeout is the 5th argument (index 4) - assert args[4] == 5.0 - - # Verify handler.get_result was called with timeout - mock_handler.get_result.assert_called_once_with(timeout=5.0) - - @pytest.mark.asyncio - async def test_execute_without_timeout_uses_not_set(self): - """ - Test that missing timeout uses _NOT_SET sentinel. - - What this tests: - --------------- - 1. No timeout parameter results in _NOT_SET - 2. Handler receives None for timeout - 3. Driver uses its default timeout - - Why this matters: - ---------------- - Most queries don't specify timeout and should use - driver defaults rather than arbitrary values. 
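The sentinel-default pattern this relies on can be sketched in a few lines; the `_NOT_SET` object below is a local stand-in for the driver's sentinel, not the real one:

    _NOT_SET = object()  # local stand-in for the driver's sentinel

    def execute(query: str, timeout: object = _NOT_SET) -> str:
        # Distinguish "caller passed nothing" from "caller passed None".
        if timeout is _NOT_SET:
            return f"{query!r} executed with the driver's default timeout"
        return f"{query!r} executed with an explicit timeout of {timeout}"

    print(execute("SELECT 1"))
    print(execute("SELECT 1", timeout=5.0))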
- """ - mock_session = Mock() - mock_future = Mock(spec=ResponseFuture) - mock_future.has_more_pages = False - mock_session.execute_async.return_value = mock_future - - async_session = AsyncCassandraSession(mock_session) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=Mock(rows=[])) - mock_handler_class.return_value = mock_handler - - await async_session.execute("SELECT * FROM test") - - # Verify _NOT_SET was passed to execute_async - args = mock_session.execute_async.call_args[0] - # timeout is the 5th argument (index 4) - assert args[4] is _NOT_SET - - # Verify handler got None timeout - mock_handler.get_result.assert_called_once_with(timeout=None) - - @pytest.mark.asyncio - async def test_concurrent_queries_different_timeouts(self): - """ - Test concurrent queries with different timeouts. - - What this tests: - --------------- - 1. Multiple queries maintain separate timeouts - 2. Concurrent execution doesn't mix timeouts - 3. Each query respects its timeout - - Why this matters: - ---------------- - Real applications run many queries concurrently, - each with different performance characteristics. - """ - mock_session = Mock() - - # Track futures to return them in order - futures = [] - - def create_future(*args, **kwargs): - future = Mock(spec=ResponseFuture) - future.has_more_pages = False - futures.append(future) - return future - - mock_session.execute_async.side_effect = create_future - - async_session = AsyncCassandraSession(mock_session) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - # Create handlers that return immediately - handlers = [] - - def create_handler(future): - handler = Mock() - handler.get_result = AsyncMock(return_value=Mock(rows=[])) - handlers.append(handler) - return handler - - mock_handler_class.side_effect = create_handler - - # Execute queries concurrently - await asyncio.gather( - async_session.execute("SELECT 1", timeout=1.0), - async_session.execute("SELECT 2", timeout=5.0), - async_session.execute("SELECT 3"), # No timeout - ) - - # Verify timeouts were passed correctly - calls = mock_session.execute_async.call_args_list - # timeout is the 5th argument (index 4) - assert calls[0][0][4] == 1.0 - assert calls[1][0][4] == 5.0 - assert calls[2][0][4] is _NOT_SET - - # Verify handlers got correct timeouts - assert handlers[0].get_result.call_args[1]["timeout"] == 1.0 - assert handlers[1].get_result.call_args[1]["timeout"] == 5.0 - assert handlers[2].get_result.call_args[1]["timeout"] is None - - # ======================================== - # Timeout Exception Tests - # ======================================== - - @pytest.mark.asyncio - async def test_read_timeout_exception_handling(self): - """ - Test ReadTimeout exception is properly handled. - - What this tests: - --------------- - 1. ReadTimeout from driver is caught - 2. Not wrapped in QueryError (re-raised as-is) - 3. Exception details are preserved - - Why this matters: - ---------------- - Read timeouts indicate the query took too long. - Applications need the full exception details for - retry decisions and debugging. 
- """ - mock_session = Mock() - mock_future = Mock(spec=ResponseFuture) - mock_session.execute_async.return_value = mock_future - - async_session = AsyncCassandraSession(mock_session) - - # Create proper ReadTimeout - timeout_error = ReadTimeout( - message="Read timeout", - consistency=3, # ConsistencyLevel.THREE - required_responses=2, - received_responses=1, - ) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(side_effect=timeout_error) - mock_handler_class.return_value = mock_handler - - # Should raise ReadTimeout directly (not wrapped) - with pytest.raises(ReadTimeout) as exc_info: - await async_session.execute("SELECT * FROM test") - - # Verify it's the same exception - assert exc_info.value is timeout_error - - @pytest.mark.asyncio - async def test_write_timeout_exception_handling(self): - """ - Test WriteTimeout exception is properly handled. - - What this tests: - --------------- - 1. WriteTimeout from driver is caught - 2. Not wrapped in QueryError (re-raised as-is) - 3. Write type information is preserved - - Why this matters: - ---------------- - Write timeouts need special handling as they may - have partially succeeded. Write type helps determine - if retry is safe. - """ - mock_session = Mock() - mock_future = Mock(spec=ResponseFuture) - mock_session.execute_async.return_value = mock_future - - async_session = AsyncCassandraSession(mock_session) - - # Create proper WriteTimeout with numeric write_type - timeout_error = WriteTimeout( - message="Write timeout", - consistency=3, # ConsistencyLevel.THREE - write_type=WriteType.SIMPLE, # Use enum value (0) - required_responses=2, - received_responses=1, - ) - - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(side_effect=timeout_error) - mock_handler_class.return_value = mock_handler - - # Should raise WriteTimeout directly - with pytest.raises(WriteTimeout) as exc_info: - await async_session.execute("INSERT INTO test VALUES (1)") - - assert exc_info.value is timeout_error - - @pytest.mark.asyncio - async def test_timeout_with_retry_policy(self): - """ - Test timeout exceptions are properly propagated. - - What this tests: - --------------- - 1. ReadTimeout errors are not wrapped - 2. Exception details are preserved - 3. Retry happens at driver level - - Why this matters: - ---------------- - The driver handles retries internally based on its - retry policy. We just need to propagate the exception. - """ - mock_session = Mock() - - # Simulate timeout from driver (after retries exhausted) - timeout_error = ReadTimeout("Read Timeout") - mock_session.execute_async.side_effect = timeout_error - - async_session = AsyncCassandraSession(mock_session) - - # Should raise the ReadTimeout as-is - with pytest.raises(ReadTimeout) as exc_info: - await async_session.execute("SELECT * FROM test") - - # Verify it's the same exception instance - assert exc_info.value is timeout_error - - # ======================================== - # Prepare Timeout Tests - # ======================================== - - @pytest.mark.asyncio - async def test_prepare_with_explicit_timeout(self): - """ - Test statement preparation with timeout. - - What this tests: - --------------- - 1. Prepare accepts timeout parameter - 2. Uses asyncio timeout for blocking operation - 3. 
Returns prepared statement on success - - Why this matters: - ---------------- - Statement preparation can be slow with complex - queries or overloaded clusters. - """ - mock_session = Mock() - mock_prepared = Mock() # PreparedStatement - mock_session.prepare.return_value = mock_prepared - - async_session = AsyncCassandraSession(mock_session) - - # Should complete within timeout - prepared = await async_session.prepare("SELECT * FROM test WHERE id = ?", timeout=5.0) - - assert prepared is mock_prepared - mock_session.prepare.assert_called_once_with( - "SELECT * FROM test WHERE id = ?", None # custom_payload - ) - - @pytest.mark.asyncio - async def test_prepare_uses_default_timeout(self): - """ - Test prepare uses default timeout when not specified. - - What this tests: - --------------- - 1. Default timeout constant is used - 2. Prepare completes successfully - - Why this matters: - ---------------- - Most prepare calls don't specify timeout and - should use a reasonable default. - """ - mock_session = Mock() - mock_prepared = Mock() - mock_session.prepare.return_value = mock_prepared - - async_session = AsyncCassandraSession(mock_session) - - # Prepare without timeout - prepared = await async_session.prepare("SELECT * FROM test WHERE id = ?") - - assert prepared is mock_prepared - - @pytest.mark.asyncio - async def test_prepare_timeout_error(self): - """ - Test prepare timeout is handled correctly. - - What this tests: - --------------- - 1. Slow prepare operations timeout - 2. TimeoutError is wrapped in QueryError - 3. Error message is helpful - - Why this matters: - ---------------- - Prepare timeouts need clear error messages to - help debug schema or query complexity issues. - """ - mock_session = Mock() - - # Simulate slow prepare in the sync driver - def slow_prepare(query, payload): - import time - - time.sleep(10) # This will block, causing timeout - return Mock() - - mock_session.prepare = Mock(side_effect=slow_prepare) - - async_session = AsyncCassandraSession(mock_session) - - # Should timeout quickly (prepare uses DEFAULT_REQUEST_TIMEOUT if not specified) - with pytest.raises(asyncio.TimeoutError): - await async_session.prepare("SELECT * FROM test WHERE id = ?", timeout=0.1) - - # ======================================== - # Resource Cleanup Tests - # ======================================== - - @pytest.mark.asyncio - async def test_timeout_cleanup_on_session_close(self): - """ - Test pending operations are cleaned up on close. - - What this tests: - --------------- - 1. Pending queries are cancelled on close - 2. No "pending task" warnings - 3. Session closes cleanly - - Why this matters: - ---------------- - Proper cleanup prevents resource leaks and - "task was destroyed but pending" warnings. 
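A generic sketch of that cleanup pattern, assuming a toy session that tracks its own tasks (this is not the wrapper's actual implementation):

    import asyncio

    class TrackedSession:
        """Toy session that cancels its in-flight work when closed."""

        def __init__(self) -> None:
            self._tasks = set()  # in-flight asyncio.Task objects

        def submit(self, coro) -> "asyncio.Task":
            task = asyncio.create_task(coro)
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)
            return task

        async def close(self) -> None:
            for task in list(self._tasks):
                task.cancel()
            # Await the cancellations so nothing is left pending at shutdown.
            await asyncio.gather(*self._tasks, return_exceptions=True)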
- """ - mock_session = Mock() - mock_future = Mock(spec=ResponseFuture) - mock_future.has_more_pages = False - - # Track callback registration - registered_callbacks = [] - - def add_callbacks(callback=None, errback=None): - registered_callbacks.append((callback, errback)) - - mock_future.add_callbacks = add_callbacks - mock_session.execute_async.return_value = mock_future - - async_session = AsyncCassandraSession(mock_session) - - # Start a long-running query - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - # Make get_result hang - hang_event = asyncio.Event() - - async def hang_forever(*args, **kwargs): - await hang_event.wait() - - mock_handler.get_result = hang_forever - mock_handler_class.return_value = mock_handler - - # Start query but don't await it - query_task = asyncio.create_task( - async_session.execute("SELECT * FROM test", timeout=30.0) - ) - - # Let it start - await asyncio.sleep(0.01) - - # Close session - await async_session.close() - - # Set event to unblock - hang_event.set() - - # Task should complete (likely with error) - try: - await query_task - except Exception: - pass # Expected - - @pytest.mark.asyncio - async def test_multiple_timeout_cleanup(self): - """ - Test cleanup of multiple timed-out operations. - - What this tests: - --------------- - 1. Multiple timeouts don't leak resources - 2. Session remains stable after timeouts - 3. New queries work after timeouts - - Why this matters: - ---------------- - Production systems may experience many timeouts. - The session must remain stable and usable. - """ - mock_session = Mock() - - # Track created futures - futures = [] - - def create_future(*args, **kwargs): - future = Mock(spec=ResponseFuture) - future.has_more_pages = False - futures.append(future) - return future - - mock_session.execute_async.side_effect = create_future - - async_session = AsyncCassandraSession(mock_session) - - # Create multiple queries that timeout - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(side_effect=ReadTimeout("Timeout")) - mock_handler_class.return_value = mock_handler - - # Execute multiple queries that will timeout - for i in range(5): - with pytest.raises(ReadTimeout): - await async_session.execute(f"SELECT {i}") - - # Session should still be usable - assert not async_session.is_closed - - # New query should work - with patch("async_cassandra.session.AsyncResultHandler") as mock_handler_class: - mock_handler = Mock() - mock_handler.get_result = AsyncMock(return_value=Mock(rows=[{"id": 1}])) - mock_handler_class.return_value = mock_handler - - result = await async_session.execute("SELECT * FROM test") - assert len(result.rows) == 1 diff --git a/tests/unit/test_toctou_race_condition.py b/tests/unit/test_toctou_race_condition.py deleted file mode 100644 index 90fbc9b..0000000 --- a/tests/unit/test_toctou_race_condition.py +++ /dev/null @@ -1,481 +0,0 @@ -""" -Unit tests for TOCTOU (Time-of-Check-Time-of-Use) race condition in AsyncCloseable. - -TOCTOU Race Conditions Explained: -================================= -A TOCTOU race condition occurs when there's a gap between checking a condition -(Time-of-Check) and using that information (Time-of-Use). In our context: - -1. Thread A checks if session is closed (is_closed == False) -2. Thread B closes the session -3. Thread A tries to execute query on now-closed session -4. 
Result: Unexpected errors or undefined behavior - -These tests verify that our AsyncCassandraSession properly handles these race -conditions by ensuring atomicity between the check and the operation. - -Key Concepts: -- Atomicity: The check and operation must be indivisible -- Thread Safety: Operations must be safe when called concurrently -- Deterministic Behavior: Same conditions should produce same results -- Proper Error Handling: Errors should be predictable (ConnectionError) -""" - -import asyncio -from unittest.mock import Mock - -import pytest - -from async_cassandra.exceptions import ConnectionError -from async_cassandra.session import AsyncCassandraSession - - -@pytest.mark.asyncio -class TestTOCTOURaceCondition: - """ - Test TOCTOU race condition in is_closed checks. - - These tests simulate concurrent operations to verify that our session - implementation properly handles race conditions between checking if - the session is closed and performing operations. - - The tests use asyncio.create_task() and asyncio.gather() to simulate - true concurrent execution where operations can interleave at any point. - """ - - async def test_concurrent_close_and_execute(self): - """ - Test race condition between close() and execute(). - - Scenario: - --------- - 1. Two coroutines run concurrently: - - One tries to execute a query - - One tries to close the session - 2. The race occurs when: - - Execute checks is_closed (returns False) - - Close() sets is_closed to True and shuts down - - Execute tries to proceed with a closed session - - Expected Behavior: - ----------------- - With proper atomicity: - - If execute starts first: Query completes successfully - - If close completes first: Execute fails with ConnectionError - - No other errors should occur (no race condition errors) - - Implementation Details: - ---------------------- - - Uses asyncio.sleep(0.001) to increase chance of race - - Manually triggers callbacks to simulate driver responses - - Tracks whether a race condition was detected - """ - # Create session - mock_session = Mock() - mock_response_future = Mock() - mock_response_future.has_more_pages = False - mock_response_future.add_callbacks = Mock() - mock_response_future.timeout = None - mock_session.execute_async = Mock(return_value=mock_response_future) - mock_session.shutdown = Mock() # Add shutdown mock - async_session = AsyncCassandraSession(mock_session) - - # Track if race condition occurred - race_detected = False - execute_error = None - - async def close_session(): - """Close session after a small delay.""" - # Small delay to increase chance of race condition - await asyncio.sleep(0.001) - await async_session.close() - - async def execute_query(): - """Execute query that might race with close.""" - nonlocal race_detected, execute_error - try: - # Start execute task - task = asyncio.create_task(async_session.execute("SELECT * FROM test")) - - # Trigger the callback to simulate driver response - await asyncio.sleep(0) # Yield to let execute start - if mock_response_future.add_callbacks.called: - # Extract the callback function from the mock call - args = mock_response_future.add_callbacks.call_args - callback = args[1]["callback"] - # Simulate successful query response - callback(["row1"]) - - # Wait for result - await task - except ConnectionError as e: - execute_error = e - except Exception as e: - # If we get here, the race condition allowed execution - # after is_closed check passed but before actual execution - race_detected = True - execute_error = e - - # Run 
both concurrently - close_task = asyncio.create_task(close_session()) - execute_task = asyncio.create_task(execute_query()) - - await asyncio.gather(close_task, execute_task, return_exceptions=True) - - # With atomic operations, the behavior is deterministic: - # - If execute starts before close, it will complete successfully - # - If close completes before execute starts, we get ConnectionError - # No other errors should occur (no race conditions) - if execute_error is not None: - # If there was an error, it should only be ConnectionError - assert isinstance(execute_error, ConnectionError) - # No race condition detected - assert not race_detected - else: - # Execute succeeded - this is valid if it started before close - assert not race_detected - - async def test_multiple_concurrent_operations_during_close(self): - """ - Test multiple operations racing with close. - - Scenario: - --------- - This test simulates a real-world scenario where multiple different - operations (execute, prepare, execute_stream) are running concurrently - when a close() is initiated. This tests the atomicity of ALL operations, - not just execute. - - Race Conditions Being Tested: - ---------------------------- - 1. Execute query vs close - 2. Prepare statement vs close - 3. Execute stream vs close - All happening simultaneously! - - Expected Behavior: - ----------------- - Each operation should either: - - Complete successfully (if it started before close) - - Fail with ConnectionError (if close completed first) - - There should be NO mixed states or unexpected errors due to races. - - Implementation Details: - ---------------------- - - Creates separate mock futures for each operation type - - Tracks which operations succeed vs fail - - Verifies all failures are ConnectionError (not race errors) - - Uses operation_count to return different futures for different calls - """ - # Create session - mock_session = Mock() - - # Create separate mock futures for each operation - execute_future = Mock() - execute_future.has_more_pages = False - execute_future.timeout = None - execute_callbacks = [] - execute_future.add_callbacks = Mock( - side_effect=lambda callback=None, errback=None: execute_callbacks.append( - (callback, errback) - ) - ) - - prepare_future = Mock() - prepare_future.timeout = None - - stream_future = Mock() - stream_future.has_more_pages = False - stream_future.timeout = None - stream_callbacks = [] - stream_future.add_callbacks = Mock( - side_effect=lambda callback=None, errback=None: stream_callbacks.append( - (callback, errback) - ) - ) - - # Track which operation is being called - operation_count = 0 - - def mock_execute_async(*args, **kwargs): - nonlocal operation_count - operation_count += 1 - if operation_count == 1: - return execute_future - elif operation_count == 2: - return stream_future - else: - return execute_future - - mock_session.execute_async = Mock(side_effect=mock_execute_async) - mock_session.prepare = Mock(return_value=prepare_future) - mock_session.shutdown = Mock() - async_session = AsyncCassandraSession(mock_session) - - results = {"execute": None, "prepare": None, "execute_stream": None} - errors = {"execute": None, "prepare": None, "execute_stream": None} - - async def close_session(): - """Close session after small delay.""" - await asyncio.sleep(0.001) - await async_session.close() - - async def run_operations(): - """Run multiple operations that might race.""" - # Create tasks for each operation - tasks = [] - - # Execute - async def run_execute(): - try: - result_task = 
asyncio.create_task(async_session.execute("SELECT 1")) - # Let the operation start - await asyncio.sleep(0) - # Trigger callback if registered - if execute_callbacks: - callback, _ = execute_callbacks[0] - if callback: - callback(["row1"]) - await result_task - results["execute"] = "success" - except Exception as e: - errors["execute"] = e - - tasks.append(run_execute()) - - # Prepare - async def run_prepare(): - try: - await async_session.prepare("SELECT ?") - results["prepare"] = "success" - except Exception as e: - errors["prepare"] = e - - tasks.append(run_prepare()) - - # Execute stream - async def run_stream(): - try: - result_task = asyncio.create_task(async_session.execute_stream("SELECT 2")) - # Let the operation start - await asyncio.sleep(0) - # Trigger callback if registered - if stream_callbacks: - callback, _ = stream_callbacks[0] - if callback: - callback(["row2"]) - await result_task - results["execute_stream"] = "success" - except Exception as e: - errors["execute_stream"] = e - - tasks.append(run_stream()) - - # Run all operations concurrently - await asyncio.gather(*tasks, return_exceptions=True) - - # Run concurrently - await asyncio.gather(close_session(), run_operations(), return_exceptions=True) - - # All operations should either succeed or fail with ConnectionError - # Not a mix of behaviors due to race conditions - for op_name in ["execute", "prepare", "execute_stream"]: - if errors[op_name] is not None: - # This assertion will fail until race condition is fixed - assert isinstance( - errors[op_name], ConnectionError - ), f"{op_name} failed with {type(errors[op_name])} instead of ConnectionError" - - async def test_execute_after_close(self): - """ - Test that execute after close always fails with ConnectionError. - - This is the baseline test - no race condition here. - - Scenario: - --------- - 1. Close the session completely - 2. Try to execute a query - - Expected: - --------- - Should ALWAYS fail with ConnectionError and proper error message. - This tests the non-race condition case to ensure basic behavior works. - """ - # Create session - mock_session = Mock() - mock_session.shutdown = Mock() - async_session = AsyncCassandraSession(mock_session) - - # Close the session - await async_session.close() - - # Try to execute - should always fail with ConnectionError - with pytest.raises(ConnectionError, match="Session is closed"): - await async_session.execute("SELECT 1") - - async def test_is_closed_check_atomicity(self): - """ - Test that is_closed check and operation are atomic. - - This is the most complex test - it specifically targets the moment - between checking is_closed and starting the operation. - - Scenario: - --------- - 1. Thread A: Checks is_closed (returns False) - 2. Thread B: Waits for check to complete, then closes session - 3. Thread A: Tries to execute based on the is_closed check - - The Race Window: - --------------- - In broken code: - - is_closed check passes (False) - - close() happens before execute starts - - execute proceeds anyway → undefined behavior - - With Proper Atomicity: - -------------------- - The is_closed check and operation start must be atomic: - - Either both happen before close (success) - - Or both happen after close (ConnectionError) - - Never a mix! 
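One generic way to get that atomicity, shown here as an illustrative sketch rather than the session's actual code, is to hold a single asyncio.Lock across both the check and the start of the operation (ConnectionError below is Python's builtin, standing in for the wrapper's own exception type):

    import asyncio

    class Closable:
        def __init__(self) -> None:
            self._closed = False
            self._lock = asyncio.Lock()

        async def execute(self, work):
            async with self._lock:
                if self._closed:
                    raise ConnectionError("Session is closed")
                # Still holding the lock: close() cannot interleave here.
                return await work()

        async def close(self) -> None:
            async with self._lock:
                self._closed = True

Holding the lock across the awaited work is deliberately heavy-handed; it only demonstrates that close() can never slip in between the check and the use.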
- - Implementation Details: - ---------------------- - - check_passed flag: Signals when is_closed returned False - - close_after_check: Waits for flag, then closes - - Tracks all state transitions to verify atomicity - """ - # Create session - mock_session = Mock() - - check_passed = False - operation_started = False - close_called = False - execute_callbacks = [] - - # Create a mock future that tracks callbacks - mock_response_future = Mock() - mock_response_future.has_more_pages = False - mock_response_future.timeout = None - mock_response_future.add_callbacks = Mock( - side_effect=lambda callback=None, errback=None: execute_callbacks.append( - (callback, errback) - ) - ) - - # Track when execute_async is called to detect the exact race timing - def tracked_execute(*args, **kwargs): - nonlocal operation_started - operation_started = True - # Return the mock future - this simulates the driver's async operation - return mock_response_future - - mock_session.execute_async = Mock(side_effect=tracked_execute) - mock_session.shutdown = Mock() - async_session = AsyncCassandraSession(mock_session) - - execute_task = None - execute_error = None - - async def execute_with_check(): - nonlocal check_passed, execute_task, execute_error - try: - # The is_closed check happens inside execute() - if not async_session.is_closed: - check_passed = True - # Start the execute operation - execute_task = asyncio.create_task(async_session.execute("SELECT 1")) - # Let it start - await asyncio.sleep(0) - # Trigger callback if registered - if execute_callbacks: - callback, _ = execute_callbacks[0] - if callback: - callback(["row1"]) - # Wait for completion - await execute_task - except Exception as e: - execute_error = e - - async def close_after_check(): - nonlocal close_called - # Wait for is_closed check to pass (returns False) - for _ in range(100): # Max 100 iterations - if check_passed: - break - await asyncio.sleep(0.001) - # Now close while execute might be in progress - # This is the critical moment - we're closing right after - # the is_closed check but possibly before execute starts - close_called = True - await async_session.close() - - # Run both concurrently - await asyncio.gather(execute_with_check(), close_after_check(), return_exceptions=True) - - # Check results - assert check_passed - assert close_called - - # With proper atomicity in the fixed implementation: - # Either the operation completes successfully (if it started before close) - # Or it fails with ConnectionError (if close happened first) - if execute_error is not None: - assert isinstance(execute_error, ConnectionError) - - async def test_close_close_race(self): - """ - Test concurrent close() calls. - - Scenario: - --------- - Multiple threads/coroutines all try to close the session at once. - This can happen in cleanup scenarios where multiple error handlers - or finalizers might try to ensure the session is closed. 
- - Expected Behavior: - ----------------- - - Only ONE actual close/shutdown should occur - - All close() calls should complete successfully - - No errors or exceptions - - is_closed should be True after all complete - - Why This Matters: - ---------------- - Without proper locking: - - Multiple threads might call shutdown() - - Could lead to errors or resource leaks - - State might become inconsistent - - Implementation: - -------------- - - Wraps shutdown() to count actual calls - - Runs 5 concurrent close() operations - - Verifies shutdown() called exactly once - """ - # Create session - mock_session = Mock() - mock_session.shutdown = Mock() - async_session = AsyncCassandraSession(mock_session) - - close_count = 0 - original_shutdown = async_session._session.shutdown - - def count_closes(): - nonlocal close_count - close_count += 1 - return original_shutdown() - - async_session._session.shutdown = count_closes - - # Multiple concurrent closes - tasks = [async_session.close() for _ in range(5)] - await asyncio.gather(*tasks) - - # Should only close once despite concurrent calls - # This test should pass as the lock prevents multiple closes - assert close_count == 1 - assert async_session.is_closed diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py deleted file mode 100644 index 0e23ca6..0000000 --- a/tests/unit/test_utils.py +++ /dev/null @@ -1,537 +0,0 @@ -""" -Unit tests for utils module. -""" - -import asyncio -import threading -from unittest.mock import Mock, patch - -import pytest - -from async_cassandra.utils import get_or_create_event_loop, safe_call_soon_threadsafe - - -class TestGetOrCreateEventLoop: - """Test get_or_create_event_loop function.""" - - @pytest.mark.asyncio - async def test_get_existing_loop(self): - """ - Test getting existing event loop. - - What this tests: - --------------- - 1. Returns current running loop - 2. Doesn't create new loop - 3. Type is AbstractEventLoop - 4. Works in async context - - Why this matters: - ---------------- - Reusing existing loops: - - Prevents loop conflicts - - Maintains event ordering - - Avoids resource waste - - Critical for proper async - integration. - """ - # Inside an async function, there's already a loop - loop = get_or_create_event_loop() - assert loop is asyncio.get_running_loop() - assert isinstance(loop, asyncio.AbstractEventLoop) - - def test_create_new_loop_when_none_exists(self): - """ - Test creating new loop when none exists. - - What this tests: - --------------- - 1. Creates loop in thread - 2. No pre-existing loop - 3. Returns valid loop - 4. Thread-safe creation - - Why this matters: - ---------------- - Background threads need loops: - - Driver callbacks - - Thread pool tasks - - Cross-thread communication - - Automatic loop creation enables - seamless async operations. - """ - # Run in a thread without event loop - result = {"loop": None, "created": False} - - def run_in_thread(): - # Ensure no event loop exists - try: - asyncio.get_running_loop() - result["created"] = False - except RuntimeError: - # Good, no loop exists - result["created"] = True - - # Get or create loop - loop = get_or_create_event_loop() - result["loop"] = loop - - thread = threading.Thread(target=run_in_thread) - thread.start() - thread.join() - - assert result["created"] is True - assert result["loop"] is not None - assert isinstance(result["loop"], asyncio.AbstractEventLoop) - - def test_creates_and_sets_event_loop(self): - """ - Test that function sets the created loop as current. 
- - What this tests: - --------------- - 1. New loop becomes current - 2. set_event_loop called - 3. Future calls return same - 4. Thread-local storage - - Why this matters: - ---------------- - Setting as current enables: - - asyncio.get_event_loop() - - Task scheduling - - Coroutine execution - - Required for asyncio to - function properly. - """ - # Mock to control behavior - mock_loop = Mock(spec=asyncio.AbstractEventLoop) - - with patch("asyncio.get_running_loop", side_effect=RuntimeError): - with patch("asyncio.new_event_loop", return_value=mock_loop): - with patch("asyncio.set_event_loop") as mock_set: - loop = get_or_create_event_loop() - - assert loop is mock_loop - mock_set.assert_called_once_with(mock_loop) - - @pytest.mark.asyncio - async def test_concurrent_calls_return_same_loop(self): - """ - Test concurrent calls return the same loop in async context. - - What this tests: - --------------- - 1. Multiple calls same result - 2. No duplicate loops - 3. Consistent behavior - 4. Thread-safe access - - Why this matters: - ---------------- - Loop consistency critical: - - Tasks run on same loop - - Callbacks properly scheduled - - No cross-loop issues - - Prevents subtle async bugs - from loop confusion. - """ - # In async context, they should all get the current running loop - current_loop = asyncio.get_running_loop() - - # Get multiple references - loop1 = get_or_create_event_loop() - loop2 = get_or_create_event_loop() - loop3 = get_or_create_event_loop() - - # All should be the same loop - assert loop1 is current_loop - assert loop2 is current_loop - assert loop3 is current_loop - - -class TestSafeCallSoonThreadsafe: - """Test safe_call_soon_threadsafe function.""" - - def test_with_valid_loop(self): - """ - Test calling with valid event loop. - - What this tests: - --------------- - 1. Delegates to loop method - 2. Args passed correctly - 3. Normal operation path - 4. No error handling needed - - Why this matters: - ---------------- - Happy path must work: - - Most common case - - Performance critical - - No overhead added - - Ensures wrapper doesn't - break normal operation. - """ - mock_loop = Mock(spec=asyncio.AbstractEventLoop) - callback = Mock() - - safe_call_soon_threadsafe(mock_loop, callback, "arg1", "arg2") - - mock_loop.call_soon_threadsafe.assert_called_once_with(callback, "arg1", "arg2") - - def test_with_none_loop(self): - """ - Test calling with None loop. - - What this tests: - --------------- - 1. None loop handled gracefully - 2. No exception raised - 3. Callback not executed - 4. Silent failure mode - - Why this matters: - ---------------- - Defensive programming: - - Shutdown scenarios - - Initialization races - - Error conditions - - Prevents crashes from - unexpected None values. - """ - callback = Mock() - - # Should not raise exception - safe_call_soon_threadsafe(None, callback, "arg1", "arg2") - - # Callback should not be called - callback.assert_not_called() - - def test_with_closed_loop(self): - """ - Test calling with closed event loop. - - What this tests: - --------------- - 1. RuntimeError caught - 2. Warning logged - 3. No exception propagated - 4. Graceful degradation - - Why this matters: - ---------------- - Closed loops common during: - - Application shutdown - - Test teardown - - Error recovery - - Must handle gracefully to - prevent shutdown hangs. 
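A plausible shape for such a defensive helper, consistent with what these tests assert but not necessarily identical to the real utility:

    import asyncio
    import logging
    from typing import Any, Callable, Optional

    logger = logging.getLogger(__name__)

    def schedule_threadsafe(
        loop: Optional[asyncio.AbstractEventLoop],
        callback: Callable[..., Any],
        *args: Any,
    ) -> None:
        if loop is None:
            return  # nothing to schedule onto; drop the callback silently
        try:
            loop.call_soon_threadsafe(callback, *args)
        except RuntimeError as exc:
            # Loop already closed (e.g. during shutdown): warn and move on.
            logger.warning("Failed to schedule callback: %s", exc)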
- """ - mock_loop = Mock(spec=asyncio.AbstractEventLoop) - mock_loop.call_soon_threadsafe.side_effect = RuntimeError("Event loop is closed") - callback = Mock() - - # Should not raise exception - with patch("async_cassandra.utils.logger") as mock_logger: - safe_call_soon_threadsafe(mock_loop, callback, "arg1", "arg2") - - # Should log warning - mock_logger.warning.assert_called_once() - assert "Failed to schedule callback" in mock_logger.warning.call_args[0][0] - - def test_with_various_callback_types(self): - """ - Test with different callback types. - - What this tests: - --------------- - 1. Regular functions work - 2. Lambda functions work - 3. Class methods work - 4. All args preserved - - Why this matters: - ---------------- - Flexible callback support: - - Library callbacks - - User callbacks - - Framework integration - - Must handle all Python - callable types correctly. - """ - mock_loop = Mock(spec=asyncio.AbstractEventLoop) - - # Regular function - def regular_func(x, y): - return x + y - - safe_call_soon_threadsafe(mock_loop, regular_func, 1, 2) - mock_loop.call_soon_threadsafe.assert_called_with(regular_func, 1, 2) - - # Lambda - def lambda_func(x): - return x * 2 - - safe_call_soon_threadsafe(mock_loop, lambda_func, 5) - mock_loop.call_soon_threadsafe.assert_called_with(lambda_func, 5) - - # Method - class TestClass: - def method(self, x): - return x - - obj = TestClass() - safe_call_soon_threadsafe(mock_loop, obj.method, 10) - mock_loop.call_soon_threadsafe.assert_called_with(obj.method, 10) - - def test_no_args(self): - """ - Test calling with no arguments. - - What this tests: - --------------- - 1. Zero args supported - 2. Callback still scheduled - 3. No TypeError raised - 4. Varargs handling works - - Why this matters: - ---------------- - Simple callbacks common: - - Event notifications - - State changes - - Cleanup functions - - Must support parameterless - callback functions. - """ - mock_loop = Mock(spec=asyncio.AbstractEventLoop) - callback = Mock() - - safe_call_soon_threadsafe(mock_loop, callback) - - mock_loop.call_soon_threadsafe.assert_called_once_with(callback) - - def test_many_args(self): - """ - Test calling with many arguments. - - What this tests: - --------------- - 1. Many args supported - 2. All args preserved - 3. Order maintained - 4. No arg limit - - Why this matters: - ---------------- - Complex callbacks exist: - - Result processing - - Multi-param handlers - - Framework callbacks - - Must handle arbitrary - argument counts. - """ - mock_loop = Mock(spec=asyncio.AbstractEventLoop) - callback = Mock() - - args = list(range(10)) - safe_call_soon_threadsafe(mock_loop, callback, *args) - - mock_loop.call_soon_threadsafe.assert_called_once_with(callback, *args) - - @pytest.mark.asyncio - async def test_real_event_loop_integration(self): - """ - Test with real event loop. - - What this tests: - --------------- - 1. Cross-thread scheduling - 2. Real loop execution - 3. Args passed correctly - 4. Async/sync bridge works - - Why this matters: - ---------------- - Real-world usage pattern: - - Driver thread callbacks - - Background operations - - Event notifications - - Verifies actual cross-thread - callback execution. 
- """ - loop = asyncio.get_running_loop() - result = {"called": False, "args": None} - - def callback(*args): - result["called"] = True - result["args"] = args - - # Call from another thread - def call_from_thread(): - safe_call_soon_threadsafe(loop, callback, "test", 123) - - thread = threading.Thread(target=call_from_thread) - thread.start() - thread.join() - - # Give the loop a chance to process the callback - await asyncio.sleep(0.1) - - assert result["called"] is True - assert result["args"] == ("test", 123) - - def test_exception_in_callback_scheduling(self): - """ - Test handling of exceptions during scheduling. - - What this tests: - --------------- - 1. Generic exceptions caught - 2. No exception propagated - 3. Different from RuntimeError - 4. Robust error handling - - Why this matters: - ---------------- - Unexpected errors happen: - - Implementation bugs - - Resource exhaustion - - Platform issues - - Must never crash from - scheduling failures. - """ - mock_loop = Mock(spec=asyncio.AbstractEventLoop) - mock_loop.call_soon_threadsafe.side_effect = Exception("Unexpected error") - callback = Mock() - - # Should handle any exception type gracefully - with patch("async_cassandra.utils.logger") as mock_logger: - # This should not raise - try: - safe_call_soon_threadsafe(mock_loop, callback) - except Exception: - pytest.fail("safe_call_soon_threadsafe should not raise exceptions") - - # Should still log warning for non-RuntimeError - mock_logger.warning.assert_not_called() # Only logs for RuntimeError - - -class TestUtilsModuleAttributes: - """Test module-level attributes and imports.""" - - def test_logger_configured(self): - """ - Test that logger is properly configured. - - What this tests: - --------------- - 1. Logger exists - 2. Correct name set - 3. Module attribute present - 4. Standard naming convention - - Why this matters: - ---------------- - Proper logging enables: - - Debugging issues - - Monitoring behavior - - Error tracking - - Consistent logger naming - aids troubleshooting. - """ - import async_cassandra.utils - - assert hasattr(async_cassandra.utils, "logger") - assert async_cassandra.utils.logger.name == "async_cassandra.utils" - - def test_public_api(self): - """ - Test that public API is as expected. - - What this tests: - --------------- - 1. Expected functions exist - 2. No extra exports - 3. Clean public API - 4. No implementation leaks - - Why this matters: - ---------------- - API stability critical: - - Backward compatibility - - Clear contracts - - No accidental exports - - Prevents breaking changes - to public interface. - """ - import async_cassandra.utils - - # Expected public functions - expected_functions = {"get_or_create_event_loop", "safe_call_soon_threadsafe"} - - # Get actual public functions - actual_functions = { - name - for name in dir(async_cassandra.utils) - if not name.startswith("_") and callable(getattr(async_cassandra.utils, name)) - } - - # Remove imports that aren't our functions - actual_functions.discard("asyncio") - actual_functions.discard("logging") - actual_functions.discard("Any") - actual_functions.discard("Optional") - - assert actual_functions == expected_functions - - def test_type_annotations(self): - """ - Test that functions have proper type annotations. - - What this tests: - --------------- - 1. Return types annotated - 2. Parameter types present - 3. Correct type usage - 4. 
Type safety enabled - - Why this matters: - ---------------- - Type annotations enable: - - IDE autocomplete - - Static type checking - - Better documentation - - Improves developer experience - and catches type errors. - """ - import inspect - - from async_cassandra.utils import get_or_create_event_loop, safe_call_soon_threadsafe - - # Check get_or_create_event_loop - sig = inspect.signature(get_or_create_event_loop) - assert sig.return_annotation == asyncio.AbstractEventLoop - - # Check safe_call_soon_threadsafe - sig = inspect.signature(safe_call_soon_threadsafe) - params = sig.parameters - assert "loop" in params - assert "callback" in params - assert "args" in params diff --git a/tests/utils/cassandra_control.py b/tests/utils/cassandra_control.py deleted file mode 100644 index 64a29c9..0000000 --- a/tests/utils/cassandra_control.py +++ /dev/null @@ -1,148 +0,0 @@ -"""Unified Cassandra control interface for tests. - -This module provides a unified interface for controlling Cassandra in tests, -supporting both local container environments and CI service environments. -""" - -import os -import subprocess -import time -from typing import Tuple - -import pytest - - -class CassandraControl: - """Provides unified control interface for Cassandra in different environments.""" - - def __init__(self, container=None): - """Initialize with optional container reference.""" - self.container = container - self.is_ci = os.environ.get("CI") == "true" - - def execute_nodetool_command(self, command: str) -> subprocess.CompletedProcess: - """Execute a nodetool command, handling both container and CI environments. - - In CI environments where Cassandra runs as a service, this will skip the test. - - Args: - command: The nodetool command to execute (e.g., "disablebinary", "enablebinary") - - Returns: - CompletedProcess with returncode, stdout, and stderr - """ - if self.is_ci: - # In CI, we can't control the Cassandra service - pytest.skip("Cannot control Cassandra service in CI environment") - - # In local environment, execute in container - if not self.container: - raise ValueError("Container reference required for non-CI environments") - - container_ref = ( - self.container.container_name - if hasattr(self.container, "container_name") and self.container.container_name - else self.container.container_id - ) - - return subprocess.run( - [self.container.runtime, "exec", container_ref, "nodetool", command], - capture_output=True, - text=True, - ) - - def wait_for_cassandra_ready(self, host: str = "127.0.0.1", timeout: int = 30) -> bool: - """Wait for Cassandra to be ready by executing a test query with cqlsh. - - This works in both container and CI environments. - """ - start_time = time.time() - while time.time() - start_time < timeout: - try: - result = subprocess.run( - ["cqlsh", host, "-e", "SELECT release_version FROM system.local;"], - capture_output=True, - text=True, - timeout=5, - ) - if result.returncode == 0: - return True - except (subprocess.TimeoutExpired, Exception): - pass - time.sleep(0.5) - return False - - def wait_for_cassandra_down(self, host: str = "127.0.0.1", timeout: int = 10) -> bool: - """Wait for Cassandra to be down by checking if cqlsh fails. - - This works in both container and CI environments. 
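For context, a hedged sketch of how this control helper was typically used from an integration test, relying on the outage helpers defined just below (the `cassandra_container` fixture and the import path are assumptions):

    import pytest

    # Import path assumed from this diff's file layout.
    from tests.utils.cassandra_control import CassandraControl

    @pytest.mark.integration
    def test_driver_recovers_from_outage(cassandra_container):  # fixture is hypothetical
        control = CassandraControl(cassandra_container)

        assert control.simulate_outage(), "binary protocol should have gone down"
        # ... exercise the application's reconnect / error handling here ...
        assert control.restore_service(), "Cassandra should accept CQL again"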
- """ - if self.is_ci: - # In CI, Cassandra service is always running - pytest.skip("Cannot control Cassandra service in CI environment") - - start_time = time.time() - while time.time() - start_time < timeout: - try: - result = subprocess.run( - ["cqlsh", host, "-e", "SELECT 1;"], - capture_output=True, - text=True, - timeout=2, - ) - if result.returncode != 0: - return True - except (subprocess.TimeoutExpired, Exception): - return True - time.sleep(0.5) - return False - - def disable_binary_protocol(self) -> Tuple[bool, str]: - """Disable Cassandra binary protocol. - - Returns: - Tuple of (success, message) - """ - result = self.execute_nodetool_command("disablebinary") - if result.returncode == 0: - return True, "Binary protocol disabled" - return False, f"Failed to disable binary protocol: {result.stderr}" - - def enable_binary_protocol(self) -> Tuple[bool, str]: - """Enable Cassandra binary protocol. - - Returns: - Tuple of (success, message) - """ - result = self.execute_nodetool_command("enablebinary") - if result.returncode == 0: - return True, "Binary protocol enabled" - return False, f"Failed to enable binary protocol: {result.stderr}" - - def simulate_outage(self) -> bool: - """Simulate a Cassandra outage. - - In CI, this will skip the test. - """ - if self.is_ci: - # In CI, we can't actually create an outage - pytest.skip("Cannot control Cassandra service in CI environment") - - success, _ = self.disable_binary_protocol() - if success: - return self.wait_for_cassandra_down() - return False - - def restore_service(self) -> bool: - """Restore Cassandra service after simulated outage. - - In CI, this will skip the test. - """ - if self.is_ci: - # In CI, service is always running - pytest.skip("Cannot control Cassandra service in CI environment") - - success, _ = self.enable_binary_protocol() - if success: - return self.wait_for_cassandra_ready() - return False diff --git a/tests/utils/cassandra_health.py b/tests/utils/cassandra_health.py deleted file mode 100644 index b94a0b5..0000000 --- a/tests/utils/cassandra_health.py +++ /dev/null @@ -1,130 +0,0 @@ -""" -Shared utilities for Cassandra health checks across test suites. -""" - -import subprocess -import time -from typing import Dict, Optional - - -def check_cassandra_health( - runtime: str, container_name_or_id: str, timeout: float = 5.0 -) -> Dict[str, bool]: - """ - Check Cassandra health using nodetool info. 
- - Args: - runtime: Container runtime (docker or podman) - container_name_or_id: Container name or ID - timeout: Timeout for each command - - Returns: - Dictionary with health status: - - native_transport: Whether native transport is active - - gossip: Whether gossip is active - - cql_available: Whether CQL queries work - """ - health_status = { - "native_transport": False, - "gossip": False, - "cql_available": False, - } - - try: - # Run nodetool info - result = subprocess.run( - [runtime, "exec", container_name_or_id, "nodetool", "info"], - capture_output=True, - text=True, - timeout=timeout, - ) - - if result.returncode == 0: - info = result.stdout - health_status["native_transport"] = "Native Transport active: true" in info - - # Parse gossip status more carefully - if "Gossip active" in info: - gossip_line = info.split("Gossip active")[1].split("\n")[0] - health_status["gossip"] = "true" in gossip_line - - # Check CQL availability - cql_result = subprocess.run( - [ - runtime, - "exec", - container_name_or_id, - "cqlsh", - "-e", - "SELECT now() FROM system.local", - ], - capture_output=True, - timeout=timeout, - ) - health_status["cql_available"] = cql_result.returncode == 0 - except subprocess.TimeoutExpired: - pass - except Exception: - pass - - return health_status - - -def wait_for_cassandra_health( - runtime: str, - container_name_or_id: str, - timeout: int = 90, - check_interval: float = 3.0, - required_checks: Optional[list] = None, -) -> bool: - """ - Wait for Cassandra to be healthy. - - Args: - runtime: Container runtime (docker or podman) - container_name_or_id: Container name or ID - timeout: Maximum time to wait in seconds - check_interval: Time between health checks - required_checks: List of required health checks (default: native_transport and cql_available) - - Returns: - True if healthy within timeout, False otherwise - """ - if required_checks is None: - required_checks = ["native_transport", "cql_available"] - - start_time = time.time() - while time.time() - start_time < timeout: - health = check_cassandra_health(runtime, container_name_or_id) - - if all(health.get(check, False) for check in required_checks): - return True - - time.sleep(check_interval) - - return False - - -def ensure_cassandra_healthy(runtime: str, container_name_or_id: str) -> Dict[str, bool]: - """ - Ensure Cassandra is healthy, raising an exception if not. - - Args: - runtime: Container runtime (docker or podman) - container_name_or_id: Container name or ID - - Returns: - Health status dictionary - - Raises: - RuntimeError: If Cassandra is not healthy - """ - health = check_cassandra_health(runtime, container_name_or_id) - - if not health["native_transport"] or not health["cql_available"]: - raise RuntimeError( - f"Cassandra is not healthy: Native Transport={health['native_transport']}, " - f"CQL Available={health['cql_available']}" - ) - - return health
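For completeness, a hedged sketch of how the health helpers above were typically wired into a pytest fixture (the `cassandra_container` fixture and its attributes are assumptions for illustration):

    import pytest

    # Import path assumed from this diff's file layout.
    from tests.utils.cassandra_health import ensure_cassandra_healthy, wait_for_cassandra_health

    @pytest.fixture(scope="session")
    def healthy_cassandra(cassandra_container):
        runtime = cassandra_container.runtime
        name = cassandra_container.container_name
        if not wait_for_cassandra_health(runtime, name, timeout=90):
            pytest.fail("Cassandra did not become healthy in time")
        ensure_cassandra_healthy(runtime, name)
        return cassandra_container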