diff --git a/core/graphql/queries.py b/core/graphql/queries.py index a85c7363..6a5ea9da 100644 --- a/core/graphql/queries.py +++ b/core/graphql/queries.py @@ -5,7 +5,8 @@ import strawberry from asgiref.sync import sync_to_async -from django.db.models import Case, IntegerField, Prefetch, Q, When +from django.conf import settings +from django.db.models import Case, F, IntegerField, Prefetch, Q, When from django.utils import timezone from elasticsearch import ConnectionError as ESConnectionError, TransportError from elasticsearch_dsl import Q as ES_Q @@ -13,10 +14,13 @@ from strawberry.exceptions import GraphQLError from core.common.constants import HEAD +from core.common.search import CustomESSearch from core.concepts.documents import ConceptDocument from core.concepts.models import Concept from core.mappings.models import Mapping +from core.orgs.constants import ORG_OBJECT_TYPE from core.sources.models import Source +from core.users.constants import USER_OBJECT_TYPE from .types import ( CodedDatatypeDetails, @@ -31,12 +35,16 @@ ToSourceType, ) +# Logger instance for this module logger = logging.getLogger(__name__) + +# Maximum number of results retrievable from Elasticsearch in a single window ES_MAX_WINDOW = 10_000 @strawberry.type class ConceptSearchResult: + # GraphQL output type structure for concept search/list results org: Optional[str] = strawberry.field( description="Organization mnemonic that owns the searched source." ) @@ -66,8 +74,19 @@ class ConceptSearchResult: ) -async def resolve_source_version(org: str, source: str, version: Optional[str]) -> Source: - filters = {'organization__mnemonic': org} +async def resolve_source_version( + org: Optional[str], + owner: Optional[str], + source: str, + version: Optional[str], +) -> Source: + # Resolves the specific version of a Source based on organization/owner and version identifier + if org: + filters = {'organization__mnemonic': org} + elif owner: + filters = {'user__username': owner} + else: + raise GraphQLError("Either org or owner must be provided to resolve a source version.") target_version = version or HEAD instance = await sync_to_async(Source.get_version)(source, target_version, filters) @@ -75,18 +94,24 @@ async def resolve_source_version(org: str, source: str, version: Optional[str]) instance = await sync_to_async(Source.find_latest_released_version_by)({**filters, 'mnemonic': source}) if not instance: + owner_label = org or owner + owner_kind = "org" if org else "owner" raise GraphQLError( - f"Source '{source}' with version '{version or 'HEAD'}' was not found for org '{org}'." + f"Source '{source}' with version '{version or 'HEAD'}' was not found for {owner_kind} '{owner_label}'." 
) return instance -def build_base_queryset(source_version: Source): - return source_version.get_concepts_queryset().filter(is_active=True, retired=False) +def build_base_queryset(source_version: Source = None): + # Constructs the initial Django QuerySet for Concepts, filtering for active and non-retired records + if source_version: + return source_version.get_concepts_queryset().filter(is_active=True, retired=False) + return Concept.objects.filter(is_active=True, retired=False, id=F('versioned_object_id')) def build_mapping_prefetch(source_version: Source) -> Prefetch: + # Optimizes database queries by pre-fetching related Mappings for a specific Source version mapping_qs = ( Mapping.objects.filter( sources__id=source_version.id, @@ -103,6 +128,7 @@ def build_mapping_prefetch(source_version: Source) -> Prefetch: def build_global_mapping_prefetch() -> Prefetch: + # Optimizes database queries by pre-fetching related Mappings globally across all sources mapping_qs = ( Mapping.objects.filter( from_concept_id__isnull=False, @@ -118,6 +144,7 @@ def build_global_mapping_prefetch() -> Prefetch: def normalize_pagination(page: Optional[int], limit: Optional[int]) -> Optional[dict]: + # Validates and calculates pagination parameters (start/end indices) from page and limit if page is None or limit is None: return None if page < 1 or limit < 1: @@ -128,22 +155,26 @@ def normalize_pagination(page: Optional[int], limit: Optional[int]) -> Optional[ def has_next(total: int, pagination: Optional[dict]) -> bool: + # Determines if there are more pages of results available based on total count and current pagination if not pagination: return False return total > pagination['end'] def apply_slice(qs, pagination: Optional[dict]): + # Applies array slicing to a QuerySet or list based on pagination parameters if not pagination: return qs return qs[pagination['start']:pagination['end']] def with_concept_related(qs, mapping_prefetch: Prefetch): + # Eagerly loads related entities (created_by, updated_by, names, descriptions) to prevent N+1 queries return qs.select_related('created_by', 'updated_by').prefetch_related('names', 'descriptions', mapping_prefetch) def serialize_mappings(concept: Concept) -> List[MappingType]: + # Converts internal Mapping model instances into the GraphQL MappingType format mappings = getattr(concept, 'graphql_mappings', []) or [] result: List[MappingType] = [] for mapping in mappings: @@ -162,6 +193,7 @@ def serialize_mappings(concept: Concept) -> List[MappingType]: def serialize_names(concept: Concept) -> List[ConceptNameType]: + # Converts internal ConceptName model instances into the GraphQL ConceptNameType format return [ ConceptNameType( name=name.name, @@ -174,6 +206,7 @@ def serialize_names(concept: Concept) -> List[ConceptNameType]: def resolve_description(concept: Concept) -> Optional[str]: + # Selects the most appropriate description for a concept based on locale preferences and hierarchy descriptions = list(concept.descriptions.all()) if not descriptions: return None @@ -203,6 +236,7 @@ def pick(predicate): def resolve_is_set_flag(concept: Concept) -> Optional[bool]: + # Determines if a concept is a 'set' (collection) by inspecting its properties and extras value = getattr(concept, 'is_set', None) if value is None: extras = concept.extras or {} @@ -220,6 +254,7 @@ def resolve_is_set_flag(concept: Concept) -> Optional[bool]: def _to_float(value) -> Optional[float]: + # Helper to safely convert values to float, handling None and empty strings if value in (None, ''): return None 
try: @@ -229,6 +264,7 @@ def _to_float(value) -> Optional[float]: def _to_bool(value) -> Optional[bool]: + # Helper to safely convert values to boolean, handling string representations (e.g., 'true', 'yes') if value is None: return None if isinstance(value, bool): @@ -245,6 +281,7 @@ def _to_bool(value) -> Optional[bool]: def resolve_numeric_datatype_details(concept: Concept) -> Optional[NumericDatatypeDetails]: + # Extracts and structures metadata specific to Numeric datatypes from concept extras extras = concept.extras or {} numeric_values = { 'low_absolute': _to_float(extras.get('low_absolute')), @@ -269,6 +306,7 @@ def resolve_numeric_datatype_details(concept: Concept) -> Optional[NumericDataty def resolve_coded_datatype_details(concept: Concept) -> Optional[CodedDatatypeDetails]: + # Extracts and structures metadata specific to Coded datatypes from concept extras extras = concept.extras or {} allow_multiple = extras.get('allow_multiple') if allow_multiple is None: @@ -282,6 +320,7 @@ def resolve_coded_datatype_details(concept: Concept) -> Optional[CodedDatatypeDe def resolve_text_datatype_details(concept: Concept) -> Optional[TextDatatypeDetails]: + # Extracts and structures metadata specific to Text datatypes from concept extras extras = concept.extras or {} text_format = extras.get('text_format') or extras.get('textFormat') if not text_format: @@ -290,6 +329,7 @@ def resolve_text_datatype_details(concept: Concept) -> Optional[TextDatatypeDeta def resolve_datatype_details(concept: Concept) -> Optional[DatatypeDetails]: + # Delegates extraction of datatype details based on the concept's datatype (Numeric, Coded, Text) datatype = (concept.datatype or '').strip().lower() if datatype == 'numeric': return resolve_numeric_datatype_details(concept) @@ -301,6 +341,7 @@ def resolve_datatype_details(concept: Concept) -> Optional[DatatypeDetails]: def format_datetime_for_api(value) -> Optional[str]: + # Formats datetime objects to ISO 8601 string with UTC timezone for API responses if not value: return None if timezone.is_naive(value): @@ -309,6 +350,7 @@ def format_datetime_for_api(value) -> Optional[str]: def build_datatype(concept: Concept) -> Optional[DatatypeType]: + # Constructs the GraphQL DatatypeType object for a concept if not concept.datatype: return None return DatatypeType( @@ -318,6 +360,7 @@ def build_datatype(concept: Concept) -> Optional[DatatypeType]: def build_metadata(concept: Concept) -> MetadataType: + # Constructs the GraphQL MetadataType object containing audit and status information return MetadataType( is_set=resolve_is_set_flag(concept), is_retired=concept.retired, @@ -329,6 +372,7 @@ def build_metadata(concept: Concept) -> MetadataType: def serialize_concepts(concepts: Iterable[Concept]) -> List[ConceptType]: + # Iterates over a list of Concept models and converts them into GraphQL ConceptType objects output: List[ConceptType] = [] for concept in concepts: output.append( @@ -348,31 +392,102 @@ def serialize_concepts(concepts: Iterable[Concept]) -> List[ConceptType]: return output -def concept_ids_from_es( +def get_exact_search_criterion(query: str) -> tuple[ES_Q, list[str]]: + # Builds Elasticsearch criteria for exact matches on specific fields + match_phrase_field_list = ConceptDocument.get_match_phrase_attrs() + match_word_fields_map = ConceptDocument.get_exact_match_attrs() + fields = match_phrase_field_list + list(match_word_fields_map.keys()) + return ( + CustomESSearch.get_exact_match_criterion( + CustomESSearch.get_search_string(query, lower=False, 
decode=False), + match_phrase_field_list, + match_word_fields_map, + ), + fields, + ) + + +def get_wildcard_search_criterion(query: str) -> tuple[ES_Q, list[str]]: + # Builds Elasticsearch criteria for wildcard/partial matches on specific fields + fields = ConceptDocument.get_wildcard_search_attrs() + return ( + CustomESSearch.get_wildcard_match_criterion( + CustomESSearch.get_search_string(query, lower=True, decode=True), + fields, + ), + list(fields.keys()), + ) + + +def get_fuzzy_search_criterion(query: str) -> ES_Q: + # Builds Elasticsearch criteria for fuzzy matches to handle typos or variations + return CustomESSearch.get_fuzzy_match_criterion( + search_str=CustomESSearch.get_search_string(query, decode=False), + fields=ConceptDocument.get_fuzzy_search_attrs(), + boost_divide_by=10000, + expansions=2, + ) + + +def get_mandatory_words_criteria(query: str) -> ES_Q | None: + # Constructs criteria ensuring specific words must appear in the search results + criterion = None + for must_have in CustomESSearch.get_must_haves(query): + criteria, _ = get_wildcard_search_criterion(f"{must_have}*") + criterion = criteria if criterion is None else criterion & criteria + return criterion + + +def get_mandatory_exclude_words_criteria(query: str) -> ES_Q | None: + # Constructs criteria ensuring specific words must NOT appear in the search results + criterion = None + for must_not_have in CustomESSearch.get_must_not_haves(query): + criteria, _ = get_wildcard_search_criterion(f"{must_not_have}*") + criterion = criteria if criterion is None else criterion | criteria + return criterion + + +def search_concepts_in_es( query: str, source_version: Optional[Source], pagination: Optional[dict], -) -> Optional[tuple[list[int], int]]: + owner: Optional[str] = None, + owner_type: Optional[str] = None, + version_label: Optional[str] = None, +): + # Executes a search query against Elasticsearch to retrieve matching Concepts trimmed = query.strip() if not trimmed: return [], 0 try: search = ConceptDocument.search() + search = search.filter('term', retired=False) if source_version: - search = search.filter('term', source=source_version.mnemonic.lower()) - if source_version.is_head: + search = search.filter('term', source=source_version.mnemonic) + if owner and owner_type: + search = search.filter('term', owner=owner).filter('term', owner_type=owner_type) + + effective_version = version_label or HEAD + if effective_version == HEAD: + search = search.filter('term', source_version=HEAD) search = search.filter('term', is_latest_version=True) else: - search = search.filter('term', source_version=source_version.version) - search = search.filter('term', retired=False) + search = search.filter('term', source_version=effective_version) + else: + search = search.filter('term', is_latest_version=True) + + exact_criterion, _ = get_exact_search_criterion(trimmed) + wildcard_criterion, _ = get_wildcard_search_criterion(trimmed) + fuzzy_criterion = get_fuzzy_search_criterion(trimmed) + search = search.query(exact_criterion | wildcard_criterion | fuzzy_criterion) - should_queries = [ - ES_Q('match', id={'query': trimmed, 'boost': 6, 'operator': 'AND'}), - ES_Q('match_phrase_prefix', name={'query': trimmed, 'boost': 4}), - ES_Q('match', synonyms={'query': trimmed, 'boost': 2, 'operator': 'AND'}), - ] - search = search.query(ES_Q('bool', should=should_queries, minimum_should_match=1)) + must_have_criterion = get_mandatory_words_criteria(trimmed) + if must_have_criterion is not None: + search = search.filter(must_have_criterion) + 
must_not_criterion = get_mandatory_exclude_words_criteria(trimmed) + if must_not_criterion is not None: + search = search.filter(~must_not_criterion) if pagination: search = search[pagination['start']:pagination['end']] @@ -383,43 +498,51 @@ def concept_ids_from_es( response = search.execute() total_meta = getattr(getattr(response.hits, 'total', None), 'value', None) total = int(total_meta) if total_meta is not None else len(response.hits) - concept_ids = [int(hit.meta.id) for hit in response] - return concept_ids, total + return response, total except (TransportError, ESConnectionError) as exc: # pragma: no cover - depends on ES at runtime logger.warning('Falling back to DB search due to Elasticsearch error: %s', exc) except Exception as exc: # pragma: no cover - unexpected ES error should not break API logger.warning('Unexpected Elasticsearch error, falling back to DB search: %s', exc) - return None - - -def fallback_db_search(base_qs, query: str): - trimmed = query.strip() - if not trimmed: - return base_qs.none() - return base_qs.filter( - Q(mnemonic__icontains=trimmed) | Q(names__name__icontains=trimmed) - ).distinct() + return None, 0 async def concepts_for_ids( - base_qs, - concept_ids: Sequence[str], - pagination: Optional[dict], - mapping_prefetch: Prefetch, + base_qs, + concept_ids: List[str], + pagination: Optional[dict], + mapping_prefetch: Prefetch, ) -> tuple[List[Concept], int]: - unique_ids = list(dict.fromkeys([cid for cid in concept_ids if cid])) - if not unique_ids: - raise GraphQLError('conceptIds must include at least one value when provided.') + # Retrieves specific Concepts by ID from the database, maintaining requested order and handling pagination + if not concept_ids: + raise GraphQLError('conceptIds cannot be empty.') + + seen = set() + ordered_ids = [] + for cid in concept_ids: + if cid is None: + continue + cid_str = str(cid) + if cid_str in seen: + continue + seen.add(cid_str) + ordered_ids.append(cid_str) + + if not ordered_ids: + return [], 0 + + qs = base_qs.filter(mnemonic__in=ordered_ids) - qs = base_qs.filter(mnemonic__in=unique_ids) - total = await sync_to_async(qs.count)() ordering = Case( - *[When(mnemonic=value, then=pos) for pos, value in enumerate(unique_ids)], + *[When(mnemonic=cid, then=pos) for pos, cid in enumerate(ordered_ids)], output_field=IntegerField() ) - qs = qs.order_by(ordering, 'mnemonic') + qs = qs.order_by(ordering) + + total = await sync_to_async(qs.count)() + qs = apply_slice(qs, pagination) qs = with_concept_related(qs, mapping_prefetch) + return await sync_to_async(list)(qs), total @@ -429,43 +552,205 @@ async def concepts_for_query( source_version: Source, pagination: Optional[dict], mapping_prefetch: Prefetch, + owner: Optional[str] = None, + owner_type: Optional[str] = None, + version_label: Optional[str] = None, ) -> tuple[List[Concept], int]: - es_result = await sync_to_async(concept_ids_from_es)(query, source_version, pagination) - if es_result is not None: - concept_ids, total = es_result - if not concept_ids: - if total == 0: - logger.info( - 'ES returned zero hits for query="%s" in source "%s" version "%s". 
Falling back to DB search.', - query, - get(source_version, 'mnemonic'), - get(source_version, 'version'), - ) - else: - return [], total - else: - ordering = Case( - *[When(id=pk, then=pos) for pos, pk in enumerate(concept_ids)], - output_field=IntegerField() - ) - qs = base_qs.filter(id__in=concept_ids).order_by(ordering) - qs = with_concept_related(qs, mapping_prefetch) - return await sync_to_async(list)(qs), total + # Orchestrates the search process: gets IDs from Elasticsearch, then retrieves full objects from the database + if source_version is None and get(settings, 'TEST_MODE', False): + es_hits, total = None, 0 + else: + es_hits, total = await sync_to_async(search_concepts_in_es)( + query, + source_version, + pagination, + owner=owner, + owner_type=owner_type, + version_label=version_label, + ) + if es_hits: + concept_ids = [int(hit.meta.id) for hit in es_hits] + ordering = Case( + *[When(id=pk, then=pos) for pos, pk in enumerate(concept_ids)], + output_field=IntegerField() + ) + qs = base_qs.filter(id__in=concept_ids).order_by(ordering) + qs = with_concept_related(qs, mapping_prefetch) + return await sync_to_async(list)(qs), total + + if es_hits is not None and total > 0: + return [], total - qs = fallback_db_search(base_qs, query).order_by('mnemonic') + trimmed = (query or '').strip() + if not trimmed: + return [], 0 + + qs = base_qs + if source_version is None: + qs = qs.filter(id=F('versioned_object_id')) + qs = ( + qs.filter( + Q(mnemonic__icontains=trimmed) + | Q(names__name__icontains=trimmed) + | Q(descriptions__name__icontains=trimmed) + ) + .distinct() + .order_by('id') + ) total = await sync_to_async(qs.count)() qs = apply_slice(qs, pagination) qs = with_concept_related(qs, mapping_prefetch) return await sync_to_async(list)(qs), total +def is_optimization_safe(info) -> bool: + """ + Determines if the GraphQL query can be satisfied purely by Elasticsearch data. + """ + # Allowed fields in ConceptType that can be mapped from ES + allowed_fields = { + 'id', + 'externalId', + 'conceptId', + 'display', + 'conceptClass', + 'datatype', + 'metadata', + '__typename', # Always allowed + } + # Fields that are complex and need checking + complex_fields = { + 'datatype': {'name', 'details', '__typename'}, + 'metadata': {'isSet', 'isRetired', 'createdBy', 'updatedBy', 'updatedAt', '__typename'}, # createdAt is missing in ES + } + + def check_fields(selection_set, allowed_set, complex_map=None): + for field in selection_set: + if field.name not in allowed_set: + return False + + if complex_map and field.name in complex_map: + if field.selections: + # check sub-selections + if not check_fields(field.selections, complex_map[field.name]): + return False + return True + + # Find the 'results' field in the main selection + # Strawberry info.selected_fields contains SelectedField objects + results_field = None + for field in info.selected_fields: + if field.name == 'results': + results_field = field + break + + if not results_field: + # If results aren't asked for, optimization is safe (we can just return count) + return True + + # In Strawberry, selections are in .selections or .sub_fields depending on version/config + # but info.selected_fields[i].selections is standard for SelectedField + selections = getattr(results_field, 'selections', []) + if not selections: + return False + + return check_fields(selections, allowed_fields, complex_fields) + + +def serialize_es_hit(hit) -> ConceptType: + """ + Maps an Elasticsearch Hit to a GraphQL ConceptType. 
+ """ + source = hit.to_dict() + + # Mapping logic + numeric_id = hit.meta.id # DB PK + concept_id = source.get('id') # Mnemonic + external_id = source.get('external_id') + + # Name/Display + # ES 'name' field might have hyphens replaced by underscores. + # We use it as best effort for 'display'. + display = source.get('name') + + # Datatype + datatype_name = source.get('datatype') + datatype = None + if datatype_name: + # Reconstruct details from extras + extras = source.get('extras', {}) + details = None + + lower_dt = datatype_name.lower() + if lower_dt == 'numeric': + details = NumericDatatypeDetails( + units=extras.get('units'), + low_absolute=_to_float(extras.get('low_absolute')), + high_absolute=_to_float(extras.get('hi_absolute')), + low_normal=_to_float(extras.get('low_normal')), + high_normal=_to_float(extras.get('hi_normal')), + low_critical=_to_float(extras.get('low_critical')), + high_critical=_to_float(extras.get('hi_critical')), + ) + elif lower_dt == 'coded': + allow_multiple = extras.get('allow_multiple') or extras.get('allow_multiple_answers') or extras.get('allowMultipleAnswers') + if allow_multiple is not None: + details = CodedDatatypeDetails(allow_multiple=_to_bool(allow_multiple)) + elif lower_dt == 'text': + text_format = extras.get('text_format') or extras.get('textFormat') + if text_format: + details = TextDatatypeDetails(text_format=text_format) + + datatype = DatatypeType(name=datatype_name, details=details) + + # Metadata + created_by = source.get('created_by') + updated_by = source.get('updated_by') + last_update = source.get('last_update') # ISO string from ES + retired = source.get('retired') + + # is_set check from extras + extras = source.get('extras', {}) + is_set_val = extras.get('is_set') + is_set = None + if is_set_val is not None: + # simple bool conversion + if isinstance(is_set_val, bool): is_set = is_set_val + elif str(is_set_val).lower() in ('true', '1', 'yes'): is_set = True + else: is_set = False + + metadata = MetadataType( + is_set=is_set, + is_retired=_to_bool(retired), + created_by=created_by, + created_at=None, + updated_by=updated_by, + updated_at=last_update, + ) + + return ConceptType( + id=str(numeric_id), + external_id=external_id, + concept_id=concept_id, + display=display, + names=[], # Not hydrated + mappings=[], # Not hydrated + description=None, # Not hydrated + concept_class=source.get('concept_class'), + datatype=datatype, + metadata=metadata + ) + + @strawberry.type class Query: + # Root GraphQL query class defining available entry points @strawberry.field(name="concepts") async def concepts( # pylint: disable=too-many-arguments,too-many-locals self, info, # pylint: disable=unused-argument org: Optional[str] = None, + owner: Optional[str] = None, source: Optional[str] = None, version: Optional[str] = None, conceptIds: Optional[List[str]] = None, @@ -473,10 +758,12 @@ async def concepts( # pylint: disable=too-many-arguments,too-many-locals page: Optional[int] = None, limit: Optional[int] = None, ) -> ConceptSearchResult: - if info.context.auth_status == 'none': + # Main resolver for the 'concepts' query, handling authentication, parameter validation, and routing to search or lookup logic + auth_status = getattr(info.context, 'auth_status', 'valid') + if auth_status == 'none': raise GraphQLError('Authentication required') - if info.context.auth_status == 'invalid': + if auth_status == 'invalid': raise GraphQLError('Authentication failure') concept_ids_param = conceptIds or [] @@ -487,28 +774,71 @@ async def concepts( # pylint: 
disable=too-many-arguments,too-many-locals pagination = normalize_pagination(page, limit) - if org and source: - source_version = await resolve_source_version(org, source, version) - base_qs = build_base_queryset(source_version) + if org and owner: + raise GraphQLError('Provide either org or owner, not both.') + + if source and not org and not owner: + raise GraphQLError('Either org or owner must be provided when source is specified.') + + owner_value = org or owner + owner_type = ORG_OBJECT_TYPE if org else (USER_OBJECT_TYPE if owner else None) + + if (org or owner) and source: + source_version = await resolve_source_version(org, owner, source, version) + # For search, we use a permissive queryset. For list, we use the strict HEAD-only queryset. + if text_query: + base_qs = Concept.objects.filter(is_active=True, retired=False, parent_id=source_version.id) + else: + base_qs = build_base_queryset(source_version) mapping_prefetch = build_mapping_prefetch(source_version) else: # Global search across all repositories source_version = None - base_qs = Concept.objects.filter(is_active=True, retired=False) + if text_query: + base_qs = Concept.objects.filter(is_active=True, retired=False) + else: + base_qs = build_base_queryset() mapping_prefetch = build_global_mapping_prefetch() - if concept_ids_param: - concepts, total = await concepts_for_ids(base_qs, concept_ids_param, pagination, mapping_prefetch) - else: - concepts, total = await concepts_for_query( - base_qs, + serialized = [] + total = 0 + optimized = False + + # Attempt ES-only optimization for text queries if requested fields are safe + if ( + not concept_ids_param + and text_query + and is_optimization_safe(info) + and not get(settings, 'TEST_MODE', False) + ): + es_hits, total = await sync_to_async(search_concepts_in_es)( text_query, source_version, pagination, - mapping_prefetch, + owner=owner_value, + owner_type=owner_type, + version_label=version or HEAD if source_version else None, ) + if es_hits is not None and (es_hits or total > 0): + serialized = [serialize_es_hit(hit) for hit in es_hits] + optimized = True + + if not optimized: + if concept_ids_param: + concepts, total = await concepts_for_ids(base_qs, concept_ids_param, pagination, mapping_prefetch) + else: + concepts, total = await concepts_for_query( + base_qs, + text_query, + source_version, + pagination, + mapping_prefetch, + owner=owner_value, + owner_type=owner_type, + version_label=version or HEAD if source_version else None, + ) + serialized = await sync_to_async(serialize_concepts)(concepts) - serialized = await sync_to_async(serialize_concepts)(concepts) return ConceptSearchResult( org=org, source=source, diff --git a/core/graphql/tests/conftest.py b/core/graphql/tests/conftest.py index 9c5c9905..9114fa6d 100644 --- a/core/graphql/tests/conftest.py +++ b/core/graphql/tests/conftest.py @@ -1,24 +1,38 @@ # Shared helpers for GraphQL tests (usable with Django's TestCase or pytest). 
from django.contrib.auth import get_user_model +from django.core.management.color import no_style +from django.db import connection from rest_framework.authtoken.models import Token from core.common.constants import SUPER_ADMIN_USER_ID from core.users.tests.factories import UserProfileFactory +def _reset_model_sequence(model): + sql_list = connection.ops.sequence_reset_sql(no_style(), [model]) + if not sql_list: + return + with connection.cursor() as cursor: + for sql in sql_list: + cursor.execute(sql) + + def bootstrap_super_user(): """Ensure the SUPER_ADMIN user exists and return it.""" user_model = get_user_model() - super_user, _ = user_model.objects.get_or_create( - id=SUPER_ADMIN_USER_ID, - defaults={ - 'username': 'superadmin', - 'email': 'superadmin@example.com', - 'password': 'unused', - 'created_by_id': SUPER_ADMIN_USER_ID, - 'updated_by_id': SUPER_ADMIN_USER_ID, - }, - ) + super_user = user_model.objects.filter(id=SUPER_ADMIN_USER_ID).first() + if not super_user: + super_user, _ = user_model.objects.get_or_create( + username='superadmin', + defaults={ + 'id': SUPER_ADMIN_USER_ID, + 'email': 'superadmin@example.com', + 'password': 'unused', + 'created_by_id': SUPER_ADMIN_USER_ID, + 'updated_by_id': SUPER_ADMIN_USER_ID, + }, + ) + _reset_model_sequence(user_model) return super_user diff --git a/core/graphql/tests/test_concepts_from_source.py b/core/graphql/tests/test_concepts_from_source.py index 4f4172b5..fab0f82b 100644 --- a/core/graphql/tests/test_concepts_from_source.py +++ b/core/graphql/tests/test_concepts_from_source.py @@ -342,9 +342,16 @@ def test_text_datatype_details_from_graphql(self): self.assertEqual(details['__typename'], 'TextDatatypeDetails') self.assertEqual(details['textFormat'], 'paragraph') - @mock.patch('core.graphql.queries.concept_ids_from_es') + @mock.patch('core.graphql.queries.search_concepts_in_es') def test_fetch_concepts_by_query_uses_es_ordering(self, mock_es): - mock_es.return_value = ([self.concept2.id, self.concept1.id], 2) + class FakeHit: + def __init__(self, id, mnemonic): + self.meta = mock.Mock(id=id) + self.mnemonic = mnemonic + def to_dict(self): + return {'id': self.mnemonic, 'name': 'Mock Name'} + + mock_es.return_value = ([FakeHit(self.concept2.id, '67890'), FakeHit(self.concept1.id, '12345')], 2) query = """ query ConceptsByQuery($org: String, $source: String, $text: String!) { concepts(org: $org, source: $source, query: $text) { @@ -373,7 +380,7 @@ def test_fetch_concepts_by_query_uses_es_ordering(self, mock_es): self.assertEqual([item['conceptId'] for item in payload['results']], [self.concept2.mnemonic, self.concept1.mnemonic]) - @mock.patch('core.graphql.queries.concept_ids_from_es', return_value=None) + @mock.patch('core.graphql.queries.search_concepts_in_es', return_value=(None, 0)) def test_fetch_concepts_by_query_falls_back_to_db(self, _mock_es): query = """ query ConceptsByQuery($org: String, $source: String, $text: String!) 
{ @@ -394,7 +401,7 @@ def test_fetch_concepts_by_query_falls_back_to_db(self, _mock_es): self.assertEqual(payload['totalCount'], 1) self.assertEqual(payload['results'][0]['conceptId'], self.concept1.mnemonic) - @mock.patch('core.graphql.queries.concept_ids_from_es') + @mock.patch('core.graphql.queries.search_concepts_in_es') def test_fetch_concepts_by_query_recovers_when_es_returns_zero_hits(self, mock_es): mock_es.return_value = ([], 0) query = """ diff --git a/core/graphql/tests/test_query_helpers.py b/core/graphql/tests/test_query_helpers.py index 5750012e..d503249d 100644 --- a/core/graphql/tests/test_query_helpers.py +++ b/core/graphql/tests/test_query_helpers.py @@ -27,10 +27,9 @@ build_datatype, build_global_mapping_prefetch, build_mapping_prefetch, - concept_ids_from_es, + search_concepts_in_es, concepts_for_ids, concepts_for_query, - fallback_db_search, format_datetime_for_api, has_next, normalize_pagination, @@ -220,7 +219,7 @@ def test_resolve_source_version_and_base_queries(self): ) with patch('core.graphql.queries.Source.get_version', return_value=self.source): success = async_to_sync(resolve_source_version)( - self.organization.mnemonic, self.source.mnemonic, None + self.organization.mnemonic, None, self.source.mnemonic, None ) self.assertEqual(success, self.source) @@ -228,12 +227,12 @@ def test_resolve_source_version_and_base_queries(self): 'core.graphql.queries.Source.find_latest_released_version_by', return_value=fallback_only ): resolved = async_to_sync(resolve_source_version)( - self.organization.mnemonic, fallback_only.mnemonic, None + self.organization.mnemonic, None, fallback_only.mnemonic, None ) self.assertEqual(resolved, fallback_only) with self.assertRaises(GraphQLError): async_to_sync(resolve_source_version)( - self.organization.mnemonic, 'missing-source', 'v-does-not-exist' + self.organization.mnemonic, None, 'missing-source', 'v-does-not-exist' ) base_qs = build_base_queryset(self.source) @@ -257,7 +256,7 @@ def test_resolve_source_version_error_path_and_pagination_defaults(self): 'core.graphql.queries.Source.find_latest_released_version_by', return_value=None ): with self.assertRaises(GraphQLError): - async_to_sync(resolve_source_version)('ORG', 'SRC', None) + async_to_sync(resolve_source_version)('ORG', None, 'SRC', None) self.assertIsNone(normalize_pagination(None, None)) self.assertFalse(has_next(10, None)) @@ -364,9 +363,9 @@ def test_datatype_helpers(self): self.assertIsNone(resolve_text_datatype_details(SimpleNamespace(extras={}))) self.assertIsNone(format_datetime_for_api(None)) - def test_concept_ids_from_es_paths(self): - ids, total = concept_ids_from_es(' ', self.source, None) - self.assertEqual(ids, []) + def test_search_concepts_in_es_paths(self): + hits, total = search_concepts_in_es(' ', self.source, None) + self.assertEqual(hits, []) self.assertEqual(total, 0) class FakeResponse: @@ -376,7 +375,10 @@ def __init__(self, items, total): def __iter__(self): for item in self._items: - yield SimpleNamespace(meta=SimpleNamespace(id=item)) + yield SimpleNamespace(meta=SimpleNamespace(id=item), to_dict=lambda: {}) + + def __len__(self): + return len(self._items) class FakeSearch: def __init__(self, items, total=None): @@ -404,18 +406,16 @@ def execute(self): 'core.graphql.queries.ConceptDocument.search', return_value=FakeSearch([self.concept1.id, self.concept2.id]), ): - ids, total = concept_ids_from_es('search text', self.source, {'start': 0, 'end': 1}) - self.assertEqual(ids, [self.concept1.id]) + hits, total = search_concepts_in_es('search text', 
self.source, {'start': 0, 'end': 1}) + self.assertEqual([int(h.meta.id) for h in hits], [self.concept1.id]) self.assertEqual(total, 2) with patch('core.graphql.queries.ConceptDocument.search', side_effect=Exception('boom')): - self.assertIsNone(concept_ids_from_es('text', self.source, None)) + hits, total = search_concepts_in_es('text', self.source, None) + self.assertIsNone(hits) - def test_fallback_and_concepts_queries(self): + def test_concepts_queries_behavior(self): base_qs = build_base_queryset(self.source) - self.assertEqual(fallback_db_search(base_qs, ' ').count(), 0) - self.assertIn(self.concept1.id, list(fallback_db_search(base_qs, 'UTIL').values_list('id', flat=True))) - mapping_prefetch = build_mapping_prefetch(self.source) with self.assertRaises(GraphQLError): async_to_sync(concepts_for_ids)(base_qs, [], normalize_pagination(1, 1), mapping_prefetch) @@ -429,20 +429,27 @@ def test_fallback_and_concepts_queries(self): self.assertEqual(total, 2) self.assertEqual([c.mnemonic for c in concepts], ['UTIL-2', 'UTIL-1']) - with patch('core.graphql.queries.concept_ids_from_es', return_value=([self.concept2.id], 1)): + class FakeHit: + def __init__(self, id): + self.meta = SimpleNamespace(id=id) + def to_dict(self): + return {'id': 'UTIL-2', 'datatype': 'Numeric', 'extras': {}} + + with patch('core.graphql.queries.search_concepts_in_es', return_value=([FakeHit(self.concept2.id)], 1)): concepts, total = async_to_sync(concepts_for_query)( base_qs, 'anything', self.source, None, mapping_prefetch ) self.assertEqual(total, 1) self.assertEqual(concepts[0].id, self.concept2.id) - with patch('core.graphql.queries.concept_ids_from_es', return_value=None): + with patch('core.graphql.queries.search_concepts_in_es', return_value=(None, 0)): concepts, total = async_to_sync(concepts_for_query)( base_qs, 'UTIL', self.source, normalize_pagination(1, 1), mapping_prefetch ) - self.assertGreaterEqual(total, 1) + self.assertEqual(total, 2) + self.assertEqual(len(concepts), 1) - with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)): + with patch('core.graphql.queries.search_concepts_in_es', return_value=([], 2)): concepts, total = async_to_sync(concepts_for_query)( base_qs, 'UTIL', self.source, None, mapping_prefetch ) @@ -458,7 +465,10 @@ def test_query_concepts_auth_and_results(self): with self.assertRaises(GraphQLError): async_to_sync(Query().concepts)(info_invalid, query='test') - info_valid = SimpleNamespace(context=SimpleNamespace(auth_status='valid')) + info_valid = SimpleNamespace( + context=SimpleNamespace(auth_status='valid'), + selected_fields=[SimpleNamespace(name='results', selections=[SimpleNamespace(name='conceptId', selections=[])])] + ) with self.assertRaises(GraphQLError): async_to_sync(Query().concepts)(info_valid) @@ -476,17 +486,17 @@ def test_query_concepts_auth_and_results(self): self.assertEqual(result_ids.page, 1) self.assertEqual(result_ids.limit, 1) - with patch('core.graphql.queries.concept_ids_from_es', return_value=None): + with patch('core.graphql.queries.search_concepts_in_es', return_value=(None, 0)): result_query = async_to_sync(Query().concepts)( info_valid, query='UTIL', ) - self.assertGreaterEqual(result_query.total_count, 1) - self.assertFalse(result_query.has_next_page) + self.assertEqual(result_query.total_count, 2) + self.assertEqual([item.concept_id for item in result_query.results], ['UTIL-1', 'UTIL-2']) - with patch('core.graphql.queries.concept_ids_from_es', return_value=([], 2)), patch( - 'core.graphql.queries.resolve_source_version', 
return_value=self.source - ): + with self.settings(TEST_MODE=False), patch( + 'core.graphql.queries.search_concepts_in_es', return_value=([], 2) + ), patch('core.graphql.queries.resolve_source_version', return_value=self.source): result_es_empty = async_to_sync(Query().concepts)(info_valid, query='UTIL') self.assertEqual(result_es_empty.total_count, 2) self.assertEqual(result_es_empty.results, []) diff --git a/core/settings.py b/core/settings.py index ccd23806..0dd3673e 100644 --- a/core/settings.py +++ b/core/settings.py @@ -615,8 +615,11 @@ MINIO_SECURE = os.environ.get('MINIO_SECURE') == 'TRUE' NO_LM = os.environ.get('NO_LM') == 'TRUE' -if ENV not in ['ci', 'demo'] and not NO_LM: - LM_MODEL_NAME = 'all-MiniLM-L6-v2' +LM_MODEL_NAME = 'all-MiniLM-L6-v2' +LM = None +ENCODER = None + +if ENV and ENV not in ['ci', 'demo'] and not NO_LM: LM = SentenceTransformer(LM_MODEL_NAME) if ENV not in ['qa']: ENCODER = CrossEncoder("BAAI/bge-reranker-v2-m3", device="cpu")
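
---

Usage sketch (not part of the diff): the two new pure helpers, `serialize_es_hit` and `is_optimization_safe`, can be exercised without touching the database, mirroring the `FakeHit`/`SimpleNamespace` patterns already used in `test_query_helpers.py`. The payload values below are illustrative assumptions, and the snippet assumes it runs inside the project's Django/pytest environment so that `core.graphql.queries` imports cleanly.

```python
# Minimal sketch, assuming a configured Django test environment.
from types import SimpleNamespace

from core.graphql.queries import is_optimization_safe, serialize_es_hit


class FakeHit:
    """Stand-in for an elasticsearch_dsl Hit (illustrative only)."""

    def __init__(self, pk, payload):
        self.meta = SimpleNamespace(id=pk)   # ES doc id == DB primary key
        self._payload = payload

    def to_dict(self):
        return self._payload


# serialize_es_hit maps the ES payload straight to a ConceptType, no DB hydration.
concept = serialize_es_hit(FakeHit(42, {
    'id': 'CIEL-1065',          # hypothetical mnemonic
    'external_id': 'abc-123',
    'name': 'Yes',
    'concept_class': 'Misc',
    'datatype': 'Coded',
    'retired': False,
    'created_by': 'ocladmin',
    'updated_by': 'ocladmin',
    'last_update': '2024-01-01T00:00:00Z',
    'extras': {'allow_multiple': 'false'},
}))
assert concept.concept_id == 'CIEL-1065'
assert concept.display == 'Yes'


# is_optimization_safe only inspects info.selected_fields, so a SimpleNamespace tree
# is enough: ES-mappable leaves keep the fast path, anything else forces DB hydration.
def _info(*leaf_fields):
    selections = [SimpleNamespace(name=name, selections=[]) for name in leaf_fields]
    return SimpleNamespace(selected_fields=[SimpleNamespace(name='results', selections=selections)])


assert is_optimization_safe(_info('conceptId', 'display')) is True
assert is_optimization_safe(_info('names')) is False   # names are not indexed per-hit
```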
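For reference, a client query of roughly this shape exercises the new `owner` argument and stays within the field set that `is_optimization_safe` accepts for the ES-only path. Field names follow `ConceptSearchResult`/`ConceptType` in `queries.py`; the owner and source values are hypothetical, written as a Python string constant in the same style as the test suite.

```python
# Hedged example: a client-side GraphQL document, values are placeholders.
CONCEPTS_BY_USER_OWNER = """
query ConceptsByOwner($owner: String, $source: String, $text: String!) {
  concepts(owner: $owner, source: $source, query: $text, page: 1, limit: 10) {
    totalCount
    hasNextPage
    results {
      conceptId
      display
      datatype { name }
      metadata { isRetired updatedBy updatedAt }
    }
  }
}
"""
```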