
Commit cb9a2c5

postgresql loader: Add reorg-aware streaming support
1 parent cc5730f commit cb9a2c5

File tree

3 files changed: +315 −8 lines

pyproject.toml

Lines changed: 0 additions & 3 deletions

@@ -98,9 +98,6 @@ addopts = [
     "--tb=short",
     "--strict-markers",
 ]
-# Timeout configuration for longer-running integration tests
-timeout = 300 # 5 minutes per test
-timeout_method = "thread"
 
 markers = [
     "unit: Unit tests (fast, no external dependencies)",

src/amp/loaders/implementations/postgresql_loader.py

Lines changed: 91 additions & 5 deletions
@@ -1,9 +1,10 @@
 from dataclasses import dataclass
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import pyarrow as pa
 from psycopg2.pool import ThreadedConnectionPool
 
+from ...streaming.types import BlockRange
 from ..base import DataLoader, LoadMode
 from ._postgres_helpers import has_binary_columns, prepare_csv_data, prepare_insert_data

@@ -120,7 +121,8 @@ def _clear_table(self, table_name: str) -> None:
 
     def _copy_arrow_data(self, cursor: Any, data: Union[pa.RecordBatch, pa.Table], table_name: str) -> None:
         """Copy Arrow data to PostgreSQL using optimal method based on data types."""
-        if has_binary_columns(data.schema):
+        # Use INSERT for data with binary columns OR metadata columns (JSONB/range types need special handling)
+        if has_binary_columns(data.schema) or '_meta_block_ranges' in data.schema.names:
            self._insert_arrow_data(cursor, data, table_name)
         else:
            self._csv_copy_arrow_data(cursor, data, table_name)
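The split matters because COPY with CSV has no clean way to hand PostgreSQL a JSONB value, while a parameterized INSERT can adapt Python objects directly. A minimal sketch of that idea, assuming psycopg2's stock Json adapter; the actual _insert_arrow_data implementation is not shown in this diff, and the helper below is hypothetical:

from psycopg2.extras import Json, execute_values

def insert_rows_with_jsonb(cursor, table_name, columns, rows):
    """Hypothetical helper: parameterized INSERT that adapts dict/list values to JSONB."""
    adapted = [
        tuple(Json(v) if isinstance(v, (dict, list)) else v for v in row)
        for row in rows
    ]
    quoted_cols = ', '.join(f'"{c}"' for c in columns)
    # execute_values expands the single %s into an efficient multi-row VALUES clause
    execute_values(cursor, f'INSERT INTO {table_name} ({quoted_cols}) VALUES %s', adapted)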
@@ -160,7 +162,7 @@ def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
         # Check if table already exists to avoid unnecessary work
         cursor.execute(
             """
-            SELECT 1 FROM information_schema.tables 
+            SELECT 1 FROM information_schema.tables
             WHERE table_name = %s AND table_schema = 'public'
             """,
             (table_name,),
@@ -205,9 +207,18 @@ def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
 
         # Build CREATE TABLE statement
         columns = []
+        # Check if this is streaming data with metadata columns
+        has_metadata = any(field.name.startswith('_meta_') for field in schema)
+
         for field in schema:
+            # Skip generic metadata columns - we'll use _meta_block_ranges instead
+            if field.name in ('_meta_range_start', '_meta_range_end'):
+                continue
+            # Special handling for the JSONB metadata column
+            elif field.name == '_meta_block_ranges':
+                pg_type = 'JSONB'
             # Handle complex types
-            if pa.types.is_timestamp(field.type):
+            elif pa.types.is_timestamp(field.type):
                 # Handle timezone-aware timestamps
                 if field.type.tz is not None:
                     pg_type = 'TIMESTAMPTZ'
@@ -246,6 +257,14 @@ def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
             # Quote column name for safety (important for blockchain field names)
             columns.append(f'"{field.name}" {pg_type}{nullable}')
 
+        # Add metadata columns for streaming/reorg support if this is streaming data,
+        # but only if they don't already exist in the schema
+        if has_metadata:
+            schema_field_names = [field.name for field in schema]
+            if '_meta_block_ranges' not in schema_field_names:
+                # Use JSONB for multi-network block ranges with GIN index support
+                columns.append('"_meta_block_ranges" JSONB')
+
         # Create the table - Fixed: use proper identifier quoting
         create_sql = f"""
             CREATE TABLE IF NOT EXISTS {table_name} (
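Concretely, for a streaming batch like the ones in the tests below, the assembled DDL would come out roughly as follows. The table and column names are hypothetical, and the GIN index is not created by this diff; it is the index type the comment above alludes to for querying the JSONB ranges:

# Roughly what _create_table_from_schema assembles for a streaming batch (assumed type mapping)
create_sql = """
    CREATE TABLE IF NOT EXISTS eth_transfers (
        "block_number" BIGINT,
        "transaction_hash" TEXT,
        "value" DOUBLE PRECISION,
        "_meta_block_ranges" JSONB
    )
"""

# Assumed follow-up (not part of this diff): a GIN index so the reorg EXISTS queries
# over _meta_block_ranges don't have to scan every row.
index_sql = 'CREATE INDEX IF NOT EXISTS eth_transfers_meta_ranges_gin ON eth_transfers USING GIN ("_meta_block_ranges")'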
@@ -272,7 +291,7 @@ def get_table_schema(self, table_name: str) -> Optional[pa.Schema]:
         cur.execute(
             """
             SELECT column_name, data_type, is_nullable
-            FROM information_schema.columns 
+            FROM information_schema.columns
             WHERE table_name = %s
             ORDER BY ordinal_position
             """,
@@ -328,3 +347,70 @@ def _pg_type_to_arrow(self, pg_type: str) -> pa.DataType:
             return pa.decimal128(18, 6)  # Default precision/scale
 
         return type_mapping.get(pg_type, pa.string())  # Default to string
+
+    def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str) -> None:
+        """
+        Handle a blockchain reorganization by deleting affected rows using PostgreSQL JSONB operations.
+
+        In blockchain reorgs, if block N gets reorganized, ALL blocks >= N become invalid
+        because the chain has forked from that point. This method deletes all data
+        from the reorg point forward for each affected network, including ranges that
+        overlap the reorg point.
+
+        Args:
+            invalidation_ranges: List of block ranges to invalidate (reorg points)
+            table_name: The table containing the data to invalidate
+        """
+        if not invalidation_ranges:
+            return
+
+        conn = self.pool.getconn()
+        try:
+            with conn.cursor() as cur:
+                # Build the WHERE clause using JSONB operators for multi-network support.
+                # For blockchain reorgs: if a reorg starts at block N, delete all data that
+                # either starts >= N or overlaps N (range_end >= N).
+                where_conditions = []
+                params = []
+
+                for range_obj in invalidation_ranges:
+                    # Delete all data from the reorg point forward for this network.
+                    # Match rows whose JSONB array contains any range where:
+                    #   1. the network matches, and
+                    #   2. range end >= reorg start (catches both overlap and forward cases)
+                    where_conditions.append("""
+                        EXISTS (
+                            SELECT 1 FROM jsonb_array_elements("_meta_block_ranges") AS range_elem
+                            WHERE range_elem->>'network' = %s
+                            AND (range_elem->>'end')::int >= %s
+                        )
+                    """)
+                    params.extend(
+                        [
+                            range_obj.network,
+                            range_obj.start,  # Delete everything where range_end >= reorg_start
+                        ]
+                    )
+
+                # Combine conditions with OR (if any network has a reorg, delete the row)
+                where_clause = ' OR '.join(where_conditions)
+
+                # Execute the deletion
+                delete_sql = f'DELETE FROM {table_name} WHERE {where_clause}'
+
+                self.logger.info(
+                    f'Executing blockchain reorg deletion for {len(invalidation_ranges)} networks '
+                    f"in table '{table_name}'"
+                )
+                self.logger.debug(f'Delete SQL: {delete_sql} with params: {params}')
+
+                cur.execute(delete_sql, params)
+                deleted_rows = cur.rowcount
+                conn.commit()
+
+                self.logger.info(f"Blockchain reorg deleted {deleted_rows} rows from table '{table_name}'")
+
+        except Exception as e:
+            self.logger.error(f"Failed to handle blockchain reorg for table '{table_name}': {str(e)}")
+            raise
+        finally:
+            self.pool.putconn(conn)
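The tests below call loader._add_metadata_columns, which is defined elsewhere and not part of this diff. A plausible sketch of what it does, assuming BlockRange exposes network/start/end and that every row carries the batch's full range list as a JSON string for PostgreSQL to cast into the JSONB column; the helper name is real, everything else here is an assumption:

import json
from typing import List

import pyarrow as pa

from src.amp.streaming.types import BlockRange

def add_metadata_columns(batch: pa.RecordBatch, ranges: List[BlockRange]) -> pa.RecordBatch:
    """Sketch: append a _meta_block_ranges column holding the batch's block ranges."""
    ranges_json = json.dumps(
        [{'network': r.network, 'start': r.start, 'end': r.end} for r in ranges]
    )
    # Every row in the batch gets the same JSON payload; on load, PostgreSQL casts
    # the string into the JSONB column created by _create_table_from_schema.
    meta_col = pa.array([ranges_json] * batch.num_rows, type=pa.string())
    return pa.RecordBatch.from_arrays(
        batch.columns + [meta_col],
        names=batch.schema.names + ['_meta_block_ranges'],
    )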

tests/integration/test_postgresql_loader.py

Lines changed: 224 additions & 0 deletions
@@ -425,3 +425,227 @@ def test_large_data_loading(self, postgresql_test_config, test_table_name, cleanup_tables):
                     assert count == 50000
             finally:
                 loader.pool.putconn(conn)
+
+
+@pytest.mark.integration
+@pytest.mark.postgresql
+class TestPostgreSQLLoaderStreaming:
+    """Integration tests for PostgreSQL loader streaming functionality"""
+
+    def test_streaming_metadata_columns(self, postgresql_test_config, test_table_name, cleanup_tables):
+        """Test that streaming data creates tables with metadata columns"""
+        cleanup_tables.append(test_table_name)
+
+        # Import streaming types
+        from src.amp.streaming.types import BlockRange
+
+        # Create test data with metadata
+        data = {
+            'block_number': [100, 101, 102],
+            'transaction_hash': ['0xabc', '0xdef', '0x123'],
+            'value': [1.0, 2.0, 3.0],
+        }
+        batch = pa.RecordBatch.from_pydict(data)
+
+        # Create metadata with block ranges
+        block_ranges = [BlockRange(network='ethereum', start=100, end=102)]
+
+        loader = PostgreSQLLoader(postgresql_test_config)
+
+        with loader:
+            # Add metadata columns (simulating what load_stream_continuous does)
+            batch_with_metadata = loader._add_metadata_columns(batch, block_ranges)
+
+            # Load the batch
+            result = loader.load_batch(batch_with_metadata, test_table_name, create_table=True)
+            assert result.success == True
+            assert result.rows_loaded == 3
+
+            # Verify metadata columns were created in the table
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    # Check table schema includes metadata columns
+                    cur.execute(
+                        """
+                        SELECT column_name, data_type
+                        FROM information_schema.columns
+                        WHERE table_name = %s
+                        ORDER BY ordinal_position
+                        """,
+                        (test_table_name,),
+                    )
+
+                    columns = cur.fetchall()
+                    column_names = [col[0] for col in columns]
+
+                    # Should have original columns plus metadata columns
+                    assert '_meta_block_ranges' in column_names
+
+                    # Verify metadata column types
+                    column_types = {col[0]: col[1] for col in columns}
+                    assert 'jsonb' in column_types['_meta_block_ranges'].lower()
+
+                    # Verify data was stored correctly
+                    cur.execute(f'SELECT "_meta_block_ranges" FROM {test_table_name} LIMIT 1')
+                    meta_row = cur.fetchone()
+
+                    # PostgreSQL JSONB automatically parses to Python objects
+                    ranges_data = meta_row[0]  # Already parsed by psycopg2
+                    assert len(ranges_data) == 1
+                    assert ranges_data[0]['network'] == 'ethereum'
+                    assert ranges_data[0]['start'] == 100
+                    assert ranges_data[0]['end'] == 102
+
+            finally:
+                loader.pool.putconn(conn)
+
+    def test_handle_reorg_deletion(self, postgresql_test_config, test_table_name, cleanup_tables):
+        """Test that _handle_reorg correctly deletes invalidated ranges"""
+        cleanup_tables.append(test_table_name)
+
+        from src.amp.streaming.types import BlockRange
+
+        loader = PostgreSQLLoader(postgresql_test_config)
+
+        with loader:
+            # Create table and load test data with multiple block ranges
+            data_batch1 = {
+                'tx_hash': ['0x100', '0x101', '0x102'],
+                'block_num': [100, 101, 102],
+                'value': [10.0, 11.0, 12.0],
+            }
+            batch1 = pa.RecordBatch.from_pydict(data_batch1)
+            ranges1 = [BlockRange(network='ethereum', start=100, end=102)]
+            batch1_with_meta = loader._add_metadata_columns(batch1, ranges1)
+
+            data_batch2 = {'tx_hash': ['0x200', '0x201'], 'block_num': [103, 104], 'value': [12.0, 33.0]}
+            batch2 = pa.RecordBatch.from_pydict(data_batch2)
+            ranges2 = [BlockRange(network='ethereum', start=103, end=104)]
+            batch2_with_meta = loader._add_metadata_columns(batch2, ranges2)
+
+            data_batch3 = {'tx_hash': ['0x300', '0x301'], 'block_num': [105, 106], 'value': [7.0, 9.0]}
+            batch3 = pa.RecordBatch.from_pydict(data_batch3)
+            ranges3 = [BlockRange(network='ethereum', start=105, end=106)]
+            batch3_with_meta = loader._add_metadata_columns(batch3, ranges3)
+
+            data_batch4 = {'tx_hash': ['0x400', '0x401'], 'block_num': [107, 108], 'value': [6.0, 73.0]}
+            batch4 = pa.RecordBatch.from_pydict(data_batch4)
+            ranges4 = [BlockRange(network='ethereum', start=107, end=108)]
+            batch4_with_meta = loader._add_metadata_columns(batch4, ranges4)
+
+            # Load all batches
+            result1 = loader.load_batch(batch1_with_meta, test_table_name, create_table=True)
+            result2 = loader.load_batch(batch2_with_meta, test_table_name, create_table=False)
+            result3 = loader.load_batch(batch3_with_meta, test_table_name, create_table=False)
+            result4 = loader.load_batch(batch4_with_meta, test_table_name, create_table=False)
+
+            assert all([result1.success, result2.success, result3.success, result4.success])
+
+            # Verify initial data count
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    initial_count = cur.fetchone()[0]
+                    assert initial_count == 9  # 3 + 2 + 2 + 2
+
+                    # Test reorg deletion - invalidate blocks 104-108 on ethereum
+                    invalidation_ranges = [BlockRange(network='ethereum', start=104, end=108)]
+                    loader._handle_reorg(invalidation_ranges, test_table_name)
+
+                    # Should delete batch2, batch3 and batch4, leaving only the 3 rows from batch1
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    after_reorg_count = cur.fetchone()[0]
+                    assert after_reorg_count == 3
+
+            finally:
+                loader.pool.putconn(conn)
+
+    def test_reorg_with_overlapping_ranges(self, postgresql_test_config, test_table_name, cleanup_tables):
+        """Test reorg deletion with overlapping block ranges"""
+        cleanup_tables.append(test_table_name)
+
+        from src.amp.streaming.types import BlockRange
+
+        loader = PostgreSQLLoader(postgresql_test_config)
+
+        with loader:
+            # Load data with overlapping ranges that should be invalidated
+            data = {'tx_hash': ['0x150', '0x175', '0x250'], 'block_num': [150, 175, 250], 'value': [15.0, 17.5, 25.0]}
+            batch = pa.RecordBatch.from_pydict(data)
+            ranges = [BlockRange(network='ethereum', start=150, end=175)]
+            batch_with_meta = loader._add_metadata_columns(batch, ranges)
+
+            result = loader.load_batch(batch_with_meta, test_table_name, create_table=True)
+            assert result.success == True
+
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    # Verify initial data
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    assert cur.fetchone()[0] == 3
+
+                    # Test partial overlap invalidation (160-180)
+                    # This should invalidate our range [150, 175] because they overlap
+                    invalidation_ranges = [BlockRange(network='ethereum', start=160, end=180)]
+                    loader._handle_reorg(invalidation_ranges, test_table_name)
+
+                    # All data should be deleted due to overlap
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    assert cur.fetchone()[0] == 0
+
+            finally:
+                loader.pool.putconn(conn)
+
+    def test_reorg_preserves_different_networks(self, postgresql_test_config, test_table_name, cleanup_tables):
+        """Test that reorg only affects the specified network"""
+        cleanup_tables.append(test_table_name)
+
+        from src.amp.streaming.types import BlockRange
+
+        loader = PostgreSQLLoader(postgresql_test_config)
+
+        with loader:
+            # Load data from multiple networks with the same block ranges
+            data_eth = {'tx_hash': ['0x100_eth'], 'network_id': ['ethereum'], 'block_num': [100], 'value': [10.0]}
+            batch_eth = pa.RecordBatch.from_pydict(data_eth)
+            ranges_eth = [BlockRange(network='ethereum', start=100, end=100)]
+            batch_eth_with_meta = loader._add_metadata_columns(batch_eth, ranges_eth)
+
+            data_poly = {'tx_hash': ['0x100_poly'], 'network_id': ['polygon'], 'block_num': [100], 'value': [10.0]}
+            batch_poly = pa.RecordBatch.from_pydict(data_poly)
+            ranges_poly = [BlockRange(network='polygon', start=100, end=100)]
+            batch_poly_with_meta = loader._add_metadata_columns(batch_poly, ranges_poly)
+
+            # Load both batches
+            result1 = loader.load_batch(batch_eth_with_meta, test_table_name, create_table=True)
+            result2 = loader.load_batch(batch_poly_with_meta, test_table_name, create_table=False)
+
+            assert result1.success and result2.success
+
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    # Verify both networks' data exists
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    assert cur.fetchone()[0] == 2
+
+                    # Invalidate only the ethereum network
+                    invalidation_ranges = [BlockRange(network='ethereum', start=100, end=100)]
+                    loader._handle_reorg(invalidation_ranges, test_table_name)
+
+                    # Should only delete ethereum data; polygon should remain
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    assert cur.fetchone()[0] == 1
+
+                    # Verify remaining data is from polygon
+                    cur.execute(f'SELECT "_meta_block_ranges" FROM {test_table_name}')
+                    # PostgreSQL JSONB is already parsed to Python objects by psycopg2
+                    ranges_data = cur.fetchone()[0]
+                    assert ranges_data[0]['network'] == 'polygon'
+
+            finally:
+                loader.pool.putconn(conn)
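To make the deletion semantics concrete: for the single invalidation range in test_handle_reorg_deletion (ethereum, reorg at block 104), _handle_reorg builds one EXISTS condition, so the statement it executes is effectively the following. The table name is hypothetical and the bound parameters are inlined as literals for readability:

# Effective SQL for invalidation_ranges = [BlockRange(network='ethereum', start=104, end=108)].
# The real code binds 'ethereum' and 104 as parameters; note the range's `end` is unused,
# since a reorg at block N invalidates everything from N forward.
delete_sql = """
    DELETE FROM test_table WHERE
        EXISTS (
            SELECT 1 FROM jsonb_array_elements("_meta_block_ranges") AS range_elem
            WHERE range_elem->>'network' = 'ethereum'
            AND (range_elem->>'end')::int >= 104
        )
"""
# Rows tagged [100, 102] survive (102 < 104); rows whose tagged ranges end at or after
# block 104 are deleted, which is why only batch1's three rows remain.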
