diff --git a/docs/source/publishing/ogcapi-features.rst b/docs/source/publishing/ogcapi-features.rst index 5646f1a7b..df40d27bd 100644 --- a/docs/source/publishing/ogcapi-features.rst +++ b/docs/source/publishing/ogcapi-features.rst @@ -627,6 +627,18 @@ Must have PostGIS installed. geom_field: foo_geom count: true # Optional; Default true; Enable/disable count for improved performance. +This can be represented as a connection dictionary or as a connection string as follows: + +.. code-block:: yaml + + providers: + - type: feature + name: PostgreSQL + data: postgresql://postgres:postgres@127.0.0.1:3010/test + id_field: osm_id + table: hotosm_bdi_waterways + geom_field: foo_geom + A number of database connection options can be also configured in the provider in order to adjust properly the sqlalchemy engine client. These are optional and if not specified, the default from the engine will be used. Please see also `SQLAlchemy docs `_. diff --git a/pygeoapi/process/manager/postgresql.py b/pygeoapi/process/manager/postgresql.py index bf5033eef..82ed5eb91 100644 --- a/pygeoapi/process/manager/postgresql.py +++ b/pygeoapi/process/manager/postgresql.py @@ -46,7 +46,6 @@ from typing import Any, Tuple from sqlalchemy import insert, update, delete -from sqlalchemy.engine import make_url from sqlalchemy.orm import Session from pygeoapi.api import FORMAT_TYPES, F_JSON, F_JSONLD @@ -56,7 +55,9 @@ ProcessorGenericError ) from pygeoapi.process.manager.base import BaseManager -from pygeoapi.provider.sql import get_engine, get_table_model +from pygeoapi.provider.sql import ( + get_engine, get_table_model, store_db_parameters +) from pygeoapi.util import JobStatus @@ -66,13 +67,15 @@ class PostgreSQLManager(BaseManager): """PostgreSQL Manager""" + default_port = 5432 + def __init__(self, manager_def: dict): """ Initialize object :param manager_def: manager definition - :returns: `pygeoapi.process.manager.postgresqs.PostgreSQLManager` + :returns: `pygeoapi.process.manager.postgresql.PostgreSQLManager` """ super().__init__(manager_def) @@ -81,30 +84,18 @@ def __init__(self, manager_def: dict): self.supports_subscribing = True self.connection = manager_def['connection'] - try: - self.db_search_path = tuple(self.connection.get('search_path', - ['public'])) - except Exception: - self.db_search_path = ('public',) - - try: - LOGGER.debug('Connecting to database') - if isinstance(self.connection, str): - _url = make_url(self.connection) - self._engine = get_engine( - 'postgresql+psycopg2', - _url.host, - _url.port, - _url.database, - _url.username, - _url.password) - else: - self._engine = get_engine('postgresql+psycopg2', - **self.connection) - except Exception as err: - msg = 'Test connecting to DB failed' - LOGGER.error(f'{msg}: {err}') - raise ProcessorGenericError(msg) + options = manager_def.get('options', {}) + store_db_parameters(self, manager_def['connection'], options) + self._engine = get_engine( + 'postgresql+psycopg2', + self.db_host, + self.db_port, + self.db_name, + self.db_user, + self._db_password, + self.db_conn, + **self.db_options + ) try: LOGGER.debug('Getting table model') diff --git a/pygeoapi/provider/sql.py b/pygeoapi/provider/sql.py index a955f06db..19cc35ca8 100644 --- a/pygeoapi/provider/sql.py +++ b/pygeoapi/provider/sql.py @@ -39,25 +39,12 @@ # # ================================================================= -# Testing local postgis with docker: -# docker run --name "postgis" \ -# -v postgres_data:/var/lib/postgresql -p 5432:5432 \ -# -e ALLOW_IP_RANGE=0.0.0.0/0 \ -# -e POSTGRES_USER=postgres \ -# -e POSTGRES_PASS=postgres \ -# -e POSTGRES_DBNAME=test \ -# -d -t kartoza/postgis - -# Import dump: -# gunzip < tests/data/hotosm_bdi_waterways.sql.gz | -# psql -U postgres -h 127.0.0.1 -p 5432 test - from copy import deepcopy from datetime import datetime from decimal import Decimal import functools import logging -from typing import Optional +from typing import Optional, Any from geoalchemy2 import Geometry # noqa - this isn't used explicitly but is needed to process Geometry columns from geoalchemy2.functions import ST_MakeEnvelope, ST_Intersects @@ -73,7 +60,7 @@ desc, delete ) -from sqlalchemy.engine import URL +from sqlalchemy.engine import URL, Engine from sqlalchemy.exc import ( ConstraintColumnNotFoundError, InvalidRequestError, @@ -82,6 +69,7 @@ from sqlalchemy.ext.automap import automap_base from sqlalchemy.orm import Session, load_only from sqlalchemy.sql.expression import and_ +from sqlalchemy.schema import Table from pygeoapi.crs import get_transform_from_spec, get_srid from pygeoapi.provider.base import ( @@ -135,8 +123,8 @@ def __init__( LOGGER.debug(f'Configured Storage CRS: {self.storage_crs}') # Read table information from database - options = provider_def.get('options', {}) - self._store_db_parameters(provider_def['data'], options) + options = provider_def.get('options', {}) | extra_conn_args + store_db_parameters(self, provider_def['data'], options) self._engine = get_engine( driver_name, self.db_host, @@ -144,13 +132,13 @@ def __init__( self.db_name, self.db_user, self._db_password, - **self.db_options | extra_conn_args + self.db_conn, + **self.db_options ) self.table_model = get_table_model( self.table, self.id_field, self.db_search_path, self._engine ) - LOGGER.debug(f'DB connection: {repr(self._engine.url)}') self.get_fields() def query( @@ -426,22 +414,6 @@ def delete(self, identifier): return result.rowcount > 0 - def _store_db_parameters(self, parameters, options): - self.db_user = parameters.get('user') - self.db_host = parameters.get('host') - self.db_port = parameters.get('port', self.default_port) - self.db_name = parameters.get('dbname') - # db_search_path gets converted to a tuple here in order to ensure it - # is hashable - which allows us to use functools.cache() when - # reflecting the table definition from the DB - self.db_search_path = tuple(parameters.get('search_path', ['public'])) - self._db_password = parameters.get('password') - self.db_options = { - k: v - for k, v in options.items() - if not isinstance(v, dict) - } - def _sqlalchemy_to_feature(self, item, crs_transform_out=None, select_properties=[]): """ @@ -602,6 +574,48 @@ def _select_properties_clause(self, select_properties, skip_geometry): return selected_properties_clause +def store_db_parameters( + self: GenericSQLProvider | Any, + connection_data: str | dict[str], + options: dict[str, str] +) -> None: + """ + Store database connection parameters + + :self: instance of provider or manager class + :param connection_data: connection string or dict of connection params + :param options: additional connection options + + :returns: None + """ + if isinstance(connection_data, str): + self.db_conn = connection_data + connection_data = {} + else: + self.db_conn = None + # OR + self.db_user = connection_data.get('user') + self.db_host = connection_data.get('host') + self.db_port = connection_data.get('port', self.default_port) + self.db_name = ( + connection_data.get('dbname') or connection_data.get('database') + ) + self.db_query = connection_data.get('query') + self._db_password = connection_data.get('password') + # db_search_path gets converted to a tuple here in order to ensure it + # is hashable - which allows us to use functools.cache() when + # reflecting the table definition from the DB + self.db_search_path = tuple( + connection_data.get('search_path') or + options.pop('search_path', ['public']) + ) + self.db_options = { + k: v + for k, v in options.items() + if not isinstance(v, dict) + } + + @functools.cache def get_engine( driver_name: str, @@ -610,20 +624,38 @@ def get_engine( database: str, user: str, password: str, + conn_str: Optional[str] = None, **connect_args -): - """Create SQL Alchemy engine.""" - conn_str = URL.create( - drivername=driver_name, - username=user, - password=password, - host=host, - port=int(port), - database=database - ) +) -> Engine: + """ + Get SQL Alchemy engine. + + :param driver_name: database driver name + :param host: database host + :param port: database port + :param database: database name + :param user: database user + :param password: database password + :param conn_str: optional connection URL + :param connect_args: custom connection arguments to pass to create_engine() + + :returns: SQL Alchemy engine + """ + if conn_str is None: + conn_str = URL.create( + drivername=driver_name, + username=user, + password=password, + host=host, + port=int(port), + database=database + ) + engine = create_engine( conn_str, connect_args=connect_args, pool_pre_ping=True ) + + LOGGER.debug(f'Created engine for {repr(engine.url)}.') return engine @@ -632,14 +664,25 @@ def get_table_model( table_name: str, id_field: str, db_search_path: tuple[str], - engine -): - """Reflect table.""" + engine: Engine +) -> Table: + """ + Reflect table using SQLAlchemy Automap. + + :param table_name: name of table to reflect + :param id_field: name of primary key field + :param db_search_path: tuple of database schemas to search for the table + :param engine: SQLAlchemy engine to use for reflection + + :returns: SQLAlchemy model of the reflected table + """ + LOGGER.debug('Reflecting table definition from database') metadata = MetaData() # Look for table in the first schema in the search path schema = db_search_path[0] try: + LOGGER.debug(f'Looking for table {table_name} in schema {schema}') metadata.reflect( bind=engine, schema=schema, only=[table_name], views=True ) diff --git a/tests/provider/test_mysql_provider.py b/tests/provider/test_mysql_provider.py index 0f470d750..1ac10e1c1 100644 --- a/tests/provider/test_mysql_provider.py +++ b/tests/provider/test_mysql_provider.py @@ -37,44 +37,45 @@ PASSWORD = os.environ.get('MYSQL_PASSWORD', 'mysql') -""" -For local testing, a MySQL database can be spun up with docker -compose as follows: - -services: - - mysql: - image: mysql:8 - ports: - - 3306:3306 - environment: - MYSQL_ROOT_PASSWORD: mysql - MYSQL_USER: pygeoapi - MYSQL_PASSWORD: mysql - MYSQL_DATABASE: test_geo_app - volumes: - - ./tests/data/mysql_data.sql:/docker-entrypoint-initdb.d/init.sql:ro -""" - - -@pytest.fixture() -def config(): - return { +# Testing local MySQL with docker: +''' +docker run --name mysql-test \ + -e MYSQL_ROOT_PASSWORD=mysql \ + -e MYSQL_USER=pygeoapi \ + -e MYSQL_PASSWORD=mysql \ + -e MYSQL_DATABASE=test_geo_app \ + -p 3306:3306 \ + -v ./tests/data/mysql_data.sql:/docker-entrypoint-initdb.d/init.sql:ro \ + -d mysql:8 +''' + + +@pytest.fixture(params=['default', 'connection_string']) +def config(request): + config_ = { 'name': 'MySQL', 'type': 'feature', - 'data': { + 'options': {'connect_timeout': 10}, + 'id_field': 'locationID', + 'table': 'location', + 'geom_field': 'locationCoordinates' + } + if request.param == 'default': + config_['data'] = { 'host': 'localhost', 'dbname': 'test_geo_app', 'user': 'root', 'port': 3306, 'password': PASSWORD, 'search_path': ['test_geo_app'] - }, - 'options': {'connect_timeout': 10}, - 'id_field': 'locationID', - 'table': 'location', - 'geom_field': 'locationCoordinates' - } + } + elif request.param == 'connection_string': + config_['data'] = ( + f'mysql+pymysql://root:{PASSWORD}@localhost:3306/test_geo_app' + ) + config_['options']['search_path'] = ['test_geo_app'] + + return config_ def test_valid_connection_options(config): @@ -87,7 +88,8 @@ def test_valid_connection_options(config): 'keepalives', 'keepalives_idle', 'keepalives_count', - 'keepalives_interval' + 'keepalives_interval', + 'search_path' ] diff --git a/tests/provider/test_postgresql_provider.py b/tests/provider/test_postgresql_provider.py index c27660caf..eb0b8760c 100644 --- a/tests/provider/test_postgresql_provider.py +++ b/tests/provider/test_postgresql_provider.py @@ -37,7 +37,17 @@ # ================================================================= # Needs to be run like: python3 -m pytest -# See pygeoapi/provider/postgresql.py for instructions on setting up +# Testing local postgis with docker: +''' +docker run --name postgis \ + --rm \ + -p 5432:5432 \ + -e ALLOW_IP_RANGE=0.0.0.0/0 \ + -e POSTGRES_USER=postgres \ + -e POSTGRES_PASS=postgres \ + -e POSTGRES_DBNAME=test \ + -d -t kartoza/postgis +''' # test database in Docker from http import HTTPStatus @@ -69,44 +79,58 @@ PASSWORD = os.environ.get('POSTGRESQL_PASSWORD', 'postgres') -@pytest.fixture() -def config(): - return { +@pytest.fixture(params=['default', 'connection_string']) +def config(request): + config_ = { 'name': 'PostgreSQL', 'type': 'feature', - 'data': {'host': '127.0.0.1', - 'dbname': 'test', - 'user': 'postgres', - 'password': PASSWORD, - 'search_path': ['osm', 'public'] - }, - 'options': { - 'connect_timeout': 10 - }, + 'options': {'connect_timeout': 10}, 'id_field': 'osm_id', 'table': 'hotosm_bdi_waterways', 'geom_field': 'foo_geom' } + if request.param == 'default': + config_['data'] = { + 'host': '127.0.0.1', + 'dbname': 'test', + 'user': 'postgres', + 'password': PASSWORD, + 'search_path': ['osm', 'public'] + } + elif request.param == 'connection_string': + config_['data'] = ( + f'postgresql://postgres:{PASSWORD}@127.0.0.1:5432/test' + ) + config_['options']['search_path'] = ['osm', 'public'] + return config_ -@pytest.fixture() -def config_types(): - return { + +@pytest.fixture(params=['default', 'connection_string']) +def config_types(request): + config_ = { 'name': 'PostgreSQL', 'type': 'feature', - 'data': {'host': '127.0.0.1', - 'dbname': 'test', - 'user': 'postgres', - 'password': PASSWORD, - 'search_path': ['public'] - }, - 'options': { - 'connect_timeout': 10 - }, + 'options': {'connect_timeout': 10}, 'id_field': 'id', 'table': 'foo', 'geom_field': 'the_geom' } + if request.param == 'default': + config_['data'] = { + 'host': '127.0.0.1', + 'dbname': 'test', + 'user': 'postgres', + 'password': PASSWORD, + 'search_path': ['public', 'osm'] + } + elif request.param == 'connection_string': + config_['data'] = ( + f'postgresql://postgres:{PASSWORD}@127.0.0.1:5432/test' + ) + config_['options']['search_path'] = ['public', 'osm'] + + return config_ @pytest.fixture() @@ -148,14 +172,20 @@ def test_valid_connection_options(config): for key in keys: assert key in ['connect_timeout', 'tcp_user_timeout', 'keepalives', 'keepalives_idle', 'keepalives_count', - 'keepalives_interval'] + 'keepalives_interval', 'search_path'] def test_schema_path_search(config): - config['data']['search_path'] = ['public', 'osm'] + if isinstance(config['data'], dict): + config['data']['search_path'] = ['public', 'osm'] + else: + config['options']['search_path'] = ['public', 'osm'] PostgreSQLProvider(config) - config['data']['search_path'] = ['public', 'notosm'] + if isinstance(config['data'], dict): + config['data']['search_path'] = ['public', 'notosm'] + else: + config['options']['search_path'] = ['public', 'notosm'] with pytest.raises(ProviderQueryError): PostgreSQLProvider(config) @@ -189,13 +219,13 @@ def test_query_materialised_view(config): provider = PostgreSQLProvider(config_materialised_view) # Only ID, width and depth properties should be available - assert set(provider.get_fields().keys()) == {"osm_id", "width", "depth"} + assert set(provider.get_fields().keys()) == {'osm_id', 'width', 'depth'} def test_query_with_property_filter(config): """Test query valid features when filtering by property""" p = PostgreSQLProvider(config) - feature_collection = p.query(properties=[("waterway", "stream")]) + feature_collection = p.query(properties=[('waterway', 'stream')]) features = feature_collection.get('features') stream_features = list( filter(lambda feature: feature['properties']['waterway'] == 'stream', @@ -246,19 +276,19 @@ def test_query_with_config_properties(config): feature = result.get('features')[0] properties = feature.get('properties') for property_name in properties.keys(): - assert property_name in config["properties"] + assert property_name in config['properties'] -@pytest.mark.parametrize("property_filter, expected", [ +@pytest.mark.parametrize('property_filter, expected', [ ([], 14776), - ([("waterway", "stream")], 13930), - ([("waterway", "this does not exist")], 0), + ([('waterway', 'stream')], 13930), + ([('waterway', 'this does not exist')], 0), ]) def test_query_hits_with_property_filter(config, property_filter, expected): """Test query resulttype=hits""" provider = PostgreSQLProvider(config) - results = provider.query(properties=property_filter, resulttype="hits") - assert results["numberMatched"] == expected + results = provider.query(properties=property_filter, resulttype='hits') + assert results['numberMatched'] == expected def test_query_bbox(config): @@ -337,7 +367,7 @@ def test_get_with_config_properties(config): result = provider.get(80835483) properties = result.get('properties') for property_name in properties.keys(): - assert property_name in config["properties"] + assert property_name in config['properties'] def test_get_not_existing_item_raise_exception(config): @@ -376,7 +406,7 @@ def test_query_cql(config, cql, expected_ids): assert feature_collection.get('type') == 'FeatureCollection' features = feature_collection.get('features') - ids = [feature["id"] for feature in features] + ids = [feature['id'] for feature in features] assert ids == expected_ids @@ -385,7 +415,7 @@ def test_query_cql_properties_bbox_filters(config): # Arrange properties = [('waterway', 'stream')] bbox = [29, -2.8, 29.2, -2.9] - filterq = parse("osm_id BETWEEN 80800000 AND 80900000") + filterq = parse('osm_id BETWEEN 80800000 AND 80900000') expected_ids = [80835470] # Act @@ -395,7 +425,7 @@ def test_query_cql_properties_bbox_filters(config): bbox=bbox) # Assert - ids = [feature["id"] for feature in feature_collection.get('features')] + ids = [feature['id'] for feature in feature_collection.get('features')] assert ids == expected_ids @@ -457,9 +487,9 @@ def test_instantiation(config): provider = PostgreSQLProvider(config) # Assert - assert provider.name == "PostgreSQL" - assert provider.table == "hotosm_bdi_waterways" - assert provider.id_field == "osm_id" + assert provider.name == 'PostgreSQL' + assert provider.table == 'hotosm_bdi_waterways' + assert provider.id_field == 'osm_id' @pytest.mark.parametrize('bad_data, exception, match', [ @@ -484,8 +514,14 @@ def test_instantiation_with_bad_config(config, bad_data, exception, match): def test_instantiation_with_bad_credentials(config): # Arrange - config['data'].update({'user': 'bad_user'}) - match = r'Could not connect to .*bad_user:\*\*\*@' + if isinstance(config['data'], dict): + config['data'].update({'user': 'bad_user'}) + match = r'Could not connect to .*bad_user:\*\*\*@' + + else: + config['data'] = config['data'].replace('postgres:', 'bad_user:') + match = r'Could not connect to .*bad_user:\*\*\*@' + # Make sure we don't use a cached connection in the tests postgresql_provider_module._ENGINE_STORE = {} @@ -505,7 +541,7 @@ def test_engine_and_table_model_stores(config): # Same database connection details, but different table different_table = config.copy() - different_table.update(table="hotosm_bdi_drains") + different_table.update(table='hotosm_bdi_drains') provider2 = PostgreSQLProvider(different_table) assert repr(provider2._engine) == repr(provider0._engine) assert provider2._engine is provider0._engine @@ -515,7 +551,11 @@ def test_engine_and_table_model_stores(config): # and also a different table_model, as two databases may have different # tables with the same name different_host = config.copy() - different_host["data"]["host"] = "localhost" + if isinstance(config['data'], dict): + different_host['data']['host'] = 'localhost' + else: + different_host['data'] = config['data'].replace( + '127.0.0.1', 'localhost') provider3 = PostgreSQLProvider(different_host) assert provider3._engine is not provider0._engine assert provider3.table_model is not provider0.table_model @@ -584,7 +624,7 @@ def test_get_collection_items_postgresql_cql_invalid_filter_language(pg_api_): assert error_response['description'] == 'Invalid filter language' -@pytest.mark.parametrize("bad_cql", [ +@pytest.mark.parametrize('bad_cql', [ 'id IN (1, ~)', 'id EATS (1, 2)', # Valid CQL relations only 'id IN (1, 2' # At some point this may return UnexpectedEOF @@ -664,7 +704,7 @@ def test_get_collection_items_postgresql_cql_json_invalid_filter_language(pg_api """ # Arrange # CQL should never be parsed - cql = {"in": {"value": {"property": "id"}, "list": [1, 2]}} + cql = {'in': {'value': {'property': 'id'}, 'list': [1, 2]}} headers = {'CONTENT_TYPE': 'application/query-cql-json'} # Act @@ -681,9 +721,9 @@ def test_get_collection_items_postgresql_cql_json_invalid_filter_language(pg_api assert error_response['description'] == 'Bad CQL JSON' -@pytest.mark.parametrize("bad_cql", [ +@pytest.mark.parametrize('bad_cql', [ # Valid CQL relations only - {"eats": {"value": {"property": "id"}, "list": [1, 2]}}, + {'eats': {'value': {'property': 'id'}, 'list': [1, 2]}}, # At some point this may return UnexpectedEOF '{"in": {"value": {"property": "id"}, "list": [1, 2}}' ]) @@ -939,7 +979,7 @@ def test_provider_count_false_with_resulttype_hits(config): provider = PostgreSQLProvider(config) # Act - results = provider.query(resulttype="hits") + results = provider.query(resulttype='hits') # Assert assert results['numberMatched'] == 14776