diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index c240ec23..af7100fd 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -56,7 +56,7 @@ jobs: - name: Test with pytest env: MP_API_KEY: ${{ secrets[env.API_KEY_NAME] }} - #MP_API_ENDPOINT: https://api-preview.materialsproject.org/ + # MP_API_ENDPOINT: https://api-preview.materialsproject.org/ run: | pytest -n auto -x --cov=mp_api --cov-report=xml - uses: codecov/codecov-action@v1 diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 0545d78d..37b58aee 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -8,8 +8,10 @@ import gzip import inspect import itertools +import logging import os import platform +import shutil import sys import warnings from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait @@ -21,13 +23,17 @@ from json import JSONDecodeError from math import ceil from typing import TYPE_CHECKING, ForwardRef, Optional, get_args -from urllib.parse import quote +from urllib.parse import quote, urljoin import boto3 +import pyarrow as pa +import pyarrow.dataset as ds import requests from botocore import UNSIGNED from botocore.config import Config from botocore.exceptions import ClientError +from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake +from emmet.core.arrow import arrowize from emmet.core.utils import jsanitize from pydantic import BaseModel, create_model from requests.adapters import HTTPAdapter @@ -38,6 +44,7 @@ from mp_api.client.core.exceptions import MPRestError from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS from mp_api.client.core.utils import ( + MPDataset, load_json, validate_api_key, validate_endpoint, @@ -62,6 +69,15 @@ __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION") +hdlr = logging.StreamHandler() +fmt = logging.Formatter("%(name)s - %(levelname)s - %(message)s") +hdlr.setFormatter(fmt) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(hdlr) + + class _DictLikeAccess(BaseModel): """Define a pydantic mix-in which permits dict-like access to model fields.""" @@ -85,6 +101,7 @@ class BaseRester: suffix: str = "" document_model: type[BaseModel] | None = None primary_key: str = "material_id" + delta_backed: bool = False def __init__( self, @@ -98,6 +115,10 @@ def __init__( timeout: int = 20, headers: dict | None = None, mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, + local_dataset_cache: ( + str | os.PathLike + ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, + force_renew: bool = False, **kwargs, ): """Initialize the REST API helper class. @@ -129,6 +150,9 @@ def __init__( timeout: Time in seconds to wait until a request timeout error is thrown headers: Custom headers for localhost connections. mute_progress_bars: Whether to disable progress bars. + local_dataset_cache: Target directory for downloading full datasets. 
Defaults + to 'mp_datasets' in the user's home directory + force_renew: Option to overwrite existing local dataset **kwargs: access to legacy kwargs that may be in the process of being deprecated """ self.api_key = validate_api_key(api_key) @@ -141,7 +165,14 @@ def __init__( self.timeout = timeout self.headers = headers or {} self.mute_progress_bars = mute_progress_bars - self.db_version = BaseRester._get_database_version(self.base_endpoint) + + ( + self.db_version, + self.access_controlled_batch_ids, + ) = BaseRester._get_heartbeat_info(self.base_endpoint) + + self.local_dataset_cache = local_dataset_cache + self.force_renew = force_renew self._session = session self._s3_client = s3_client @@ -209,8 +240,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover @staticmethod @cache - def _get_database_version(endpoint): - """The Materials Project database is periodically updated and has a + def _get_heartbeat_info(endpoint) -> tuple[str, str]: + """DB version: + The Materials Project database is periodically updated and has a database version associated with it. When the database is updated, consolidated data (information about "a material") may and does change, while calculation data about a specific calculation task @@ -220,9 +252,24 @@ def _get_database_version(endpoint): where "_DD" may be optional. An additional numerical or `postN` suffix might be added if multiple releases happen on the same day. - Returns: database version as a string + Access Controlled Datasets: + Certain contributions to the Materials Project have access + control restrictions that require explicit agreement to the + Terms of Use for the respective datasets prior to access being + granted. + + A full list of the Terms of Use for all contributions in the + Materials Project are available at: + + https://next-gen.materialsproject.org/about/terms + + Returns: + tuple with database version as a string and a comma separated + string with all calculation batch identifiers that have access + restrictions """ - return requests.get(url=endpoint + "heartbeat").json()["db_version"] + response = requests.get(url=endpoint + "heartbeat").json() + return response["db_version"], response["access_controlled_batch_ids"] def _post_resource( self, @@ -353,10 +400,7 @@ def _patch_resource( raise MPRestError(str(ex)) def _query_open_data( - self, - bucket: str, - key: str, - decoder: Callable | None = None, + self, bucket: str, key: str, decoder: Callable | None = None ) -> tuple[list[dict] | list[bytes], int]: """Query and deserialize Materials Project AWS open data s3 buckets. @@ -460,6 +504,12 @@ def _query_resource( url = validate_endpoint(self.endpoint, suffix=suburl) if query_s3: + pbar_message = ( # type: ignore + f"Retrieving {self.document_model.__name__} documents" # type: ignore + if self.document_model is not None + else "Retrieving documents" + ) + if "/" not in self.suffix: suffix = self.suffix elif self.suffix == "molecules/summary": @@ -469,15 +519,177 @@ def _query_resource( suffix = infix if suffix == "core" else suffix suffix = suffix.replace("_", "-") - # Paginate over all entries in the bucket. 
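
> Reviewer note on the reworked heartbeat call above: a minimal sketch of the payload `_get_heartbeat_info` now consumes, assuming the production endpoint and that the response carries at least the two keys the client reads (any extra keys are ignored):
>
> ```python
> import requests
>
> # Hit the heartbeat route directly; the client does the same via
> # requests.get(url=endpoint + "heartbeat").json() and unpacks two keys.
> resp = requests.get("https://api.materialsproject.org/heartbeat").json()
>
> db_version = resp["db_version"]  # e.g. "2025_10_19" per the YYYY_MM_DD convention in the docstring
> controlled_batch_ids = resp["access_controlled_batch_ids"]  # batches gated by Terms of Use
>
> print(db_version, controlled_batch_ids)
> ```
>
> One inconsistency worth a follow-up: the docstring and the `tuple[str, str]` annotation describe the second element as a comma-separated string, but downstream code iterates it per batch id (`",".join(f"'{tag}'" for tag in self.access_controlled_batch_ids)`), so it behaves as a sequence of id strings.
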
- # TODO: change when a subset of entries needed from DB + # Check if user has access to GNoMe + # temp suppress tqdm + re_enable = not self.mute_progress_bars + self.mute_progress_bars = True + has_gnome_access = bool( + self._submit_requests( + url=urljoin(self.base_endpoint, "materials/summary/"), + criteria={ + "batch_id": "gnome_r2scan_statics", + "_fields": "material_id", + }, + use_document_model=False, + num_chunks=1, + chunk_size=1, + timeout=timeout, + ) + .get("meta", {}) + .get("total_doc", 0) + ) + self.mute_progress_bars = not re_enable + if "tasks" in suffix: - bucket_suffix, prefix = "parsed", "tasks_atomate2" + bucket_suffix, prefix = ("parsed", "core/tasks/") else: bucket_suffix = "build" prefix = f"collections/{self.db_version.replace('.', '-')}/{suffix}" bucket = f"materialsproject-{bucket_suffix}" + + if self.delta_backed: + target_path = str( + self.local_dataset_cache.joinpath(f"{bucket_suffix}/{prefix}") + ) + os.makedirs(target_path, exist_ok=True) + + if DeltaTable.is_deltatable(target_path): + if self.force_renew: + shutil.rmtree(target_path) + logger.warning( + f"Regenerating {suffix} dataset at {target_path}..." + ) + os.makedirs(target_path, exist_ok=True) + else: + logger.warning( + f"Dataset for {suffix} already exists at {target_path}, returning existing dataset." + ) + logger.info( + "Delete or move existing dataset or re-run search query with MPRester(force_renew=True) " + "to refresh local dataset.", + ) + + return { + "data": MPDataset( + path=target_path, + document_model=self.document_model, + use_document_model=self.use_document_model, + ) + } + + tbl = DeltaTable( + f"s3a://{bucket}/{prefix}", + storage_options={ + "AWS_SKIP_SIGNATURE": "true", + "AWS_REGION": "us-east-1", + }, + ) + + controlled_batch_str = ",".join( + [f"'{tag}'" for tag in self.access_controlled_batch_ids] + ) + + predicate = ( + f"WHERE batch_id NOT IN ({controlled_batch_str})" + if not has_gnome_access + else "" + ) + + builder = QueryBuilder().register("tbl", tbl) + + # Setup progress bar + num_docs_needed = tbl.count() + + if not has_gnome_access: + num_docs_needed = self.count( + {"batch_id_neq_any": self.access_controlled_batch_ids} + ) + + pbar = ( + tqdm( + desc=pbar_message, + total=num_docs_needed, + ) + if not self.mute_progress_bars + else None + ) + + iterator = builder.execute(f"SELECT * FROM tbl {predicate}") + + file_options = ds.ParquetFileFormat().make_write_options( + compression="zstd" + ) + + def _flush( + accumulator: list[pa.RecordBatch], group: int, schema: pa.Schema + ): + # somewhere post datafusion 51.0.0 and arrow-rs 57.0.0 + # casts to *View types began, need to cast back to base schema + # -> pyarrow is behind on implementation support for *View types + tbl = ( + pa.Table.from_batches(accumulator) + .select(schema.names) + .cast(target_schema=schema) + ) + + ds.write_dataset( + tbl, + base_dir=target_path, + format="parquet", + basename_template=f"group-{group}-" + + "part-{i}.zstd.parquet", + existing_data_behavior="overwrite_or_ignore", + max_rows_per_group=1024, + file_options=file_options, + ) + + group = 1 + size = 0 + accumulator = [] + schema = pa.schema(arrowize(self.document_model)) + for page in iterator: + # arro3 rb to pyarrow rb for compat w/ pyarrow ds writer + rg = pa.record_batch(page) + accumulator.append(rg) + page_size = page.num_rows + size += rg.get_total_buffer_size() + + if pbar is not None: + pbar.update(page_size) + + if size >= MAPI_CLIENT_SETTINGS.DATASET_FLUSH_THRESHOLD: + _flush(accumulator, group, schema) + group += 1 + 
size = 0 + accumulator.clear() + + if accumulator: + _flush(accumulator, group + 1, schema) + + if pbar is not None: + pbar.close() + + logger.info(f"Dataset for {suffix} written to {target_path}") + logger.info("Converting to DeltaTable...") + + convert_to_deltalake(target_path) + + logger.info( + "Consult the delta-rs and pyarrow documentation for advanced usage: " + "delta-io.github.io/delta-rs, arrow.apache.org/docs/python" + ) + + return { + "data": MPDataset( + path=target_path, + document_model=self.document_model, + use_document_model=self.use_document_model, + ) + } + + # Paginate over all entries in the bucket. + # TODO: change when a subset of entries needed from DB paginator = self.s3_client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket, Prefix=prefix) @@ -514,11 +726,6 @@ def _query_resource( } # Setup progress bar - pbar_message = ( # type: ignore - f"Retrieving {self.document_model.__name__} documents" # type: ignore - if self.document_model is not None - else "Retrieving documents" - ) num_docs_needed = int(self.count()) pbar = ( tqdm( @@ -1365,6 +1572,8 @@ def __getattr__(self, v: str): use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, + local_dataset_cache=self.local_dataset_cache, + force_renew=self.force_renew, ) return self.sub_resters[v] diff --git a/mp_api/client/core/exceptions.py b/mp_api/client/core/exceptions.py index fa9f8793..45cef802 100644 --- a/mp_api/client/core/exceptions.py +++ b/mp_api/client/core/exceptions.py @@ -1,4 +1,5 @@ """Define custom exceptions and warnings for the client.""" + from __future__ import annotations @@ -8,3 +9,15 @@ class MPRestError(Exception): class MPRestWarning(Warning): """Raised when a query is malformed but interpretable.""" + + +class MPDatasetIndexingWarning(Warning): + """Raised during sub-optimal indexing of MPDatasets.""" + + +class MPDatasetSlicingWarning(Warning): + """Raised during sub-optimal slicing of MPDatasets.""" + + +class MPDatasetIterationWarning(Warning): + """Raised during sub-optimal iteration of MPDatasets.""" diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index f5818d57..9a7e0852 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from emmet.core.settings import EmmetSettings from pydantic import Field, field_validator @@ -47,6 +48,7 @@ class MAPIClientSettings(BaseSettings): "condition_mixing_media", "condition_heating_atmosphere", "operations", + "batch_id_neq_any", "_fields", ], description="List API query parameters that do not support parallel requests.", @@ -101,6 +103,16 @@ class MAPIClientSettings(BaseSettings): description="Angle tolerance for structure matching in degrees.", ) + LOCAL_DATASET_CACHE: Path = Field( + Path("~/mp_datasets").expanduser(), + description="Target directory for downloading full datasets", + ) + + DATASET_FLUSH_THRESHOLD: int = Field( + int(2.75 * 1024**3), + description="Threshold bytes to accumulate in memory before flushing dataset to disk", + ) + model_config = SettingsConfigDict(env_prefix="MPRESTER_") @field_validator("ENDPOINT", mode="before") diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 935e1453..82996a93 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -2,22 +2,35 @@ import os import warnings +from functools import cached_property from importlib import import_module +from itertools import chain from typing 
import TYPE_CHECKING, Literal from urllib.parse import urljoin import orjson +import pyarrow.dataset as ds +from deltalake import DeltaTable from emmet.core import __version__ as _EMMET_CORE_VER from emmet.core.mpid_ext import validate_identifier from monty.json import MontyDecoder from packaging.version import parse as parse_version -from mp_api.client.core.exceptions import MPRestError, MPRestWarning +from mp_api.client.core.exceptions import ( + MPDatasetIndexingWarning, + MPDatasetIterationWarning, + MPDatasetSlicingWarning, + MPRestError, + MPRestWarning, +) from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS if TYPE_CHECKING: + from pathlib import Path from typing import Any + from pydantic._internal._model_construction import ModelMetaclass + def _compare_emmet_ver( ref_version: str, op: Literal["==", ">", ">=", "<", "<="] @@ -229,3 +242,115 @@ def __getattr__(self, v: str) -> Any: self._load() if hasattr(self._imported, v): return getattr(self._imported, v) + + +class MPDataset: + """Convenience wrapper for pyarrow datasets stored on disk.""" + + def __init__( + self, + path: Path, + document_model: ModelMetaclass, + use_document_model: bool, + ): + """Initialize a MPDataset. + + Parameters + ----------- + path: Path | str + A path-like string. + document_model: ModelMetaclass + Pydantic document model for use during de-serialization of arrow data + use_document_model: bool + Use 'document_model' during de-serialization of arrow data. + """ + self._start = 0 + self._path = path + self._document_model = document_model + self._dataset = ds.dataset(path) + self._row_groups = list( + chain.from_iterable( + [ + fragment.split_by_row_group() + for fragment in self._dataset.get_fragments() + ] + ) + ) + self._use_document_model = use_document_model + + @property + def pyarrow_dataset(self) -> ds.Dataset: + return self._dataset + + @property + def pydantic_model(self) -> ModelMetaclass: + return self._document_model + + @property + def use_document_model(self) -> bool: + return self._use_document_model + + @use_document_model.setter + def use_document_model(self, value: bool): + self._use_document_model = value + + @cached_property + def delta_table(self) -> DeltaTable: + return DeltaTable(self._path) + + @cached_property + def num_chunks(self) -> int: + return len(self._row_groups) + + def __getitem__(self, idx): + if isinstance(idx, slice): + warnings.warn( + """ + Pythonic slicing of arrow-based MPDatasets is sub-optimal, consider using + idiomatic arrow patterns. See MP's docs on MPDatasets for relevant examples: + docs.materialsproject.org/downloading-data/arrow-datasets + """, + MPDatasetSlicingWarning, + stacklevel=2, + ) + start, stop, step = idx.indices(len(self)) + _take = list(range(start, stop, step)) + ds_slice = self._dataset.take(_take).to_pylist(maps_as_pydicts="strict") + return ( + [self._document_model(**_row) for _row in ds_slice] + if self._use_document_model + else ds_slice + ) + + warnings.warn( + """ + Pythonic indexing into arrow-based MPDatasets is sub-optimal, consider using + idiomatic arrow patterns. 
See MP's docs on MPDatasets for relevant examples: + docs.materialsproject.org/downloading-data/arrow-datasets + """, + MPDatasetIndexingWarning, + stacklevel=2, + ) + _row = self._dataset.take([idx]).to_pylist(maps_as_pydicts="strict")[0] + return self._document_model(**_row) if self._use_document_model else _row + + def __len__(self) -> int: + return self._dataset.count_rows() + + def __iter__(self): + with warnings.catch_warnings( + action="ignore", category=MPDatasetIndexingWarning + ): + warnings.warn( + """ + Iterating through arrow-based MPDatasets is sub-optimal, consider using + idiomatic arrow patterns. See MP's docs on MPDatasets for relevant examples: + docs.materialsproject.org/downloading-data/arrow-datasets + """, + MPDatasetIterationWarning, + stacklevel=2, + ) + current = self._start + while current < len(self): + yield self[current] + current += 1 diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 1d9afc5c..3dd81b3c 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools +import os import warnings from collections import defaultdict from functools import cache, lru_cache @@ -82,6 +83,10 @@ def __init__( session: Session | None = None, headers: dict | None = None, mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, + local_dataset_cache: ( + str | os.PathLike + ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, + force_renew: bool = False, **kwargs, ): """Initialize the MPRester. @@ -116,6 +121,9 @@ def __init__( session: Session object to use. By default (None), the client will create one. headers: Custom headers for localhost connections. mute_progress_bars: Whether to mute progress bars. + local_dataset_cache: Target directory for downloading full datasets. 
Defaults + to "mp_datasets" in the user's home directory + force_renew: Option to overwrite existing local dataset **kwargs: access to legacy kwargs that may be in the process of being deprecated """ self.api_key = validate_api_key(api_key) @@ -131,6 +139,8 @@ def __init__( self._include_user_agent = include_user_agent self.use_document_model = use_document_model self.mute_progress_bars = mute_progress_bars + self.local_dataset_cache = local_dataset_cache + self.force_renew = force_renew self._contribs = None self._deprecated_attributes = [ @@ -205,6 +215,8 @@ def __init__( use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, + local_dataset_cache=self.local_dataset_cache, + force_renew=self.force_renew, ), ) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 4e8498c9..fdf391a9 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -3,8 +3,11 @@ from datetime import datetime from typing import TYPE_CHECKING +import pyarrow as pa +from deltalake import DeltaTable, QueryBuilder from emmet.core.mpid import MPID, AlphaID from emmet.core.tasks import CoreTaskDoc +from emmet.core.trajectory import RelaxTrajectory from mp_api.client.core import BaseRester, MPRestError from mp_api.client.core.utils import validate_ids @@ -19,8 +22,9 @@ class TaskRester(BaseRester): suffix: str = "materials/tasks" document_model: type[BaseModel] = CoreTaskDoc # type: ignore primary_key: str = "task_id" + delta_backed = True - def get_trajectory(self, task_id: MPID | AlphaID | str) -> list[dict[str, Any]]: + def get_trajectory(self, task_id: MPID | AlphaID | str) -> dict[str, Any]: """Returns a Trajectory object containing the geometry of the material throughout a calculation. This is most useful for observing how a material relaxes during a geometry optimization. 
@@ -29,20 +33,31 @@ def get_trajectory(self, task_id: MPID | AlphaID | str) -> list[dict[str, Any]]: task_id (str, MPID, AlphaID): Task ID Returns: - list of dict representing emmet.core.trajectory.Trajectory + dict representing emmet.core.trajectory.RelaxTrajectory """ - traj_data = self._query_resource_data( - {"task_ids": [AlphaID(task_id).string]}, - suburl="trajectory/", - use_document_model=False, - )[0].get( - "trajectories", None - ) # type: ignore - - if traj_data is None: + as_alpha = str(AlphaID(task_id, padlen=8)).split("-")[-1] + traj_tbl = DeltaTable( + "s3a://materialsproject-parsed/core/trajectories/", + storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"}, + ) + + traj_data = pa.table( + QueryBuilder() + .register("traj", traj_tbl) + .execute( + f""" + SELECT * + FROM traj + WHERE identifier='{as_alpha}' + """ + ) + .read_all() + ).to_pylist(maps_as_pydicts="strict") + + if not traj_data: raise MPRestError(f"No trajectory data for {task_id} found") - return traj_data + return RelaxTrajectory(**traj_data[0]).model_dump() def search( self, diff --git a/pyproject.toml b/pyproject.toml index d91f4029..a28c6122 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ dependencies = [ "emmet-core>=0.86.3", "boto3", "orjson >= 3.10,<4", + "pyarrow >= 20.0.0", + "deltalake >= 1.4.0", ] dynamic = ["version"] diff --git a/requirements/requirements-ubuntu-latest_py3.11.txt b/requirements/requirements-ubuntu-latest_py3.11.txt index b83dc92a..2769b72b 100644 --- a/requirements/requirements-ubuntu-latest_py3.11.txt +++ b/requirements/requirements-ubuntu-latest_py3.11.txt @@ -6,13 +6,15 @@ # annotated-types==0.7.0 # via pydantic +arro3-core==0.6.5 + # via deltalake bibtexparser==1.4.4 # via pymatgen blake3==1.0.8 # via emmet-core -boto3==1.42.38 +boto3==1.42.45 # via mp-api (pyproject.toml) -botocore==1.42.38 +botocore==1.42.45 # via # boto3 # s3transfer @@ -24,6 +26,10 @@ contourpy==1.3.3 # via matplotlib cycler==0.12.1 # via matplotlib +deltalake==1.4.2 + # via mp-api (pyproject.toml) +deprecated==1.3.1 + # via deltalake emmet-core==0.86.3 # via mp-api (pyproject.toml) fonttools==4.61.1 @@ -49,11 +55,11 @@ monty==2025.3.3 # pymatgen mpmath==1.3.0 # via sympy -narwhals==2.15.0 +narwhals==2.16.0 # via plotly networkx==3.6.1 # via pymatgen -numpy==2.4.1 +numpy==2.4.2 # via # contourpy # matplotlib @@ -63,7 +69,7 @@ numpy==2.4.1 # pymatgen-io-validation # scipy # spglib -orjson==3.11.6 +orjson==3.11.7 # via # mp-api (pyproject.toml) # pymatgen @@ -79,6 +85,8 @@ pillow==12.1.0 # via matplotlib plotly==6.5.2 # via pymatgen +pyarrow==23.0.0 + # via mp-api (pyproject.toml) pybtex==0.25.1 # via emmet-core pydantic==2.12.5 @@ -133,10 +141,11 @@ sympy==1.14.0 # via pymatgen tabulate==0.9.0 # via pymatgen -tqdm==4.67.1 +tqdm==4.67.3 # via pymatgen typing-extensions==4.15.0 # via + # arro3-core # blake3 # emmet-core # mp-api (pyproject.toml) @@ -154,3 +163,5 @@ urllib3==2.6.3 # via # botocore # requests +wrapt==2.1.1 + # via deprecated diff --git a/requirements/requirements-ubuntu-latest_py3.11_extras.txt b/requirements/requirements-ubuntu-latest_py3.11_extras.txt index ffc1a886..0f35bc7c 100644 --- a/requirements/requirements-ubuntu-latest_py3.11_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.11_extras.txt @@ -14,6 +14,8 @@ anyio==4.12.1 # mcp # sse-starlette # starlette +arro3-core==0.6.5 + # via deltalake arrow==1.4.0 # via isoduration ase==3.27.0 @@ -25,9 +27,9 @@ attrs==25.4.0 # cyclopts # jsonschema # referencing -authlib==1.6.6 +authlib==1.6.7 # 
via fastmcp -babel==2.17.0 +babel==2.18.0 # via sphinx backports-tarfile==1.2.0 # via jaraco-context @@ -41,9 +43,9 @@ blake3==1.0.8 # via emmet-core boltons==25.0.0 # via mpcontribs-client -boto3==1.42.38 +boto3==1.42.45 # via mp-api (pyproject.toml) -botocore==1.42.38 +botocore==1.42.45 # via # boto3 # s3transfer @@ -51,7 +53,7 @@ bravado==12.0.1 # via mpcontribs-client bravado-core==6.1.1 # via bravado -cachetools==6.2.6 +cachetools==7.0.0 # via # mpcontribs-client # py-key-value-aio @@ -74,8 +76,10 @@ cloudpickle==3.1.2 # via pydocket contourpy==1.3.3 # via matplotlib -coverage[toml]==7.13.2 +coverage[toml]==7.13.4 # via pytest-cov +croniter==6.0.0 + # via pydocket cryptography==46.0.4 # via # authlib @@ -89,6 +93,10 @@ cyclopts==4.5.1 # via fastmcp decorator==5.2.1 # via ipython +deltalake==1.4.2 + # via mp-api (pyproject.toml) +deprecated==1.3.1 + # via deltalake diskcache==5.6.3 # via py-key-value-aio distlib==0.4.0 @@ -116,7 +124,7 @@ executing==2.2.1 # via stack-data fakeredis[lua]==2.33.0 # via pydocket -fastmcp==2.14.4 +fastmcp==2.14.5 # via mp-api (pyproject.toml) filelock==3.20.3 # via virtualenv @@ -169,7 +177,7 @@ inflect==7.5.0 # via robocrys iniconfig==2.3.0 # via pytest -ipython==9.9.0 +ipython==9.10.0 # via mpcontribs-client ipython-pygments-lexers==1.1.1 # via ipython @@ -268,7 +276,7 @@ more-itertools==10.8.0 # jaraco-functools mp-pyrho==0.5.1 # via pymatgen-analysis-defects -mpcontribs-client==5.10.4 +mpcontribs-client==5.10.5 # via mp-api (pyproject.toml) mpmath==1.3.0 # via sympy @@ -282,7 +290,7 @@ mypy-extensions==1.1.0 # via # mp-api (pyproject.toml) # mypy -narwhals==2.15.0 +narwhals==2.16.0 # via plotly networkx==3.6.1 # via @@ -291,7 +299,7 @@ networkx==3.6.1 # scikit-image nodeenv==1.10.0 # via pre-commit -numpy==2.4.1 +numpy==2.4.2 # via # ase # contourpy @@ -317,23 +325,8 @@ numpy==2.4.1 openapi-pydantic==0.5.1 # via fastmcp opentelemetry-api==1.39.1 - # via - # opentelemetry-exporter-prometheus - # opentelemetry-instrumentation - # opentelemetry-sdk - # opentelemetry-semantic-conventions - # pydocket -opentelemetry-exporter-prometheus==0.60b1 # via pydocket -opentelemetry-instrumentation==0.60b1 - # via pydocket -opentelemetry-sdk==1.39.1 - # via opentelemetry-exporter-prometheus -opentelemetry-semantic-conventions==0.60b1 - # via - # opentelemetry-instrumentation - # opentelemetry-sdk -orjson==3.11.6 +orjson==3.11.7 # via # mp-api (pyproject.toml) # pymatgen @@ -342,7 +335,6 @@ packaging==26.0 # fastmcp # lazy-loader # matplotlib - # opentelemetry-instrumentation # plotly # pytest # scikit-image @@ -355,7 +347,7 @@ pandas==2.3.3 # mpcontribs-client # pymatgen # seaborn -parso==0.8.5 +parso==0.8.6 # via jedi pathable==0.4.4 # via jsonschema-path @@ -388,9 +380,7 @@ pluggy==1.6.0 pre-commit==4.5.1 # via mp-api (pyproject.toml) prometheus-client==0.24.1 - # via - # opentelemetry-exporter-prometheus - # pydocket + # via pydocket prompt-toolkit==3.0.52 # via ipython psutil==7.2.2 @@ -408,7 +398,9 @@ py-key-value-aio[disk,keyring,memory,redis]==0.3.0 py-key-value-shared==0.3.0 # via py-key-value-aio pyarrow==23.0.0 - # via emmet-core + # via + # emmet-core + # mp-api (pyproject.toml) pybtex==0.25.1 # via # emmet-core @@ -434,7 +426,7 @@ pydantic-settings==2.12.0 # emmet-core # mcp # pymatgen-io-validation -pydocket==0.16.6 +pydocket==0.17.5 # via fastmcp pyflakes==3.4.0 # via flake8 @@ -447,7 +439,7 @@ pygments==2.19.2 # sphinx pyisemail==2.0.1 # via mpcontribs-client -pyjwt[crypto]==2.10.1 +pyjwt[crypto]==2.11.0 # via mcp pymatgen==2025.10.7 # via @@ -500,6 
+492,7 @@ python-dateutil==2.9.0.post0 # botocore # bravado # bravado-core + # croniter # matplotlib # pandas python-dotenv==1.2.1 @@ -513,6 +506,7 @@ python-multipart==0.0.22 pytz==2025.2 # via # bravado-core + # croniter # pandas pyyaml==6.0.3 # via @@ -522,7 +516,7 @@ pyyaml==6.0.3 # pre-commit # pybtex # swagger-spec-validator -redis==7.1.0 +redis==7.1.1 # via # fakeredis # py-key-value-aio @@ -551,7 +545,7 @@ rfc3986-validator==0.1.1 # via jsonschema rfc3987-syntax==1.1.0 # via jsonschema -rich==14.3.1 +rich==14.3.2 # via # cyclopts # fastmcp @@ -591,10 +585,8 @@ seaborn==0.13.2 # via pymatgen-analysis-diffusion secretstorage==3.5.0 # via keyring -seekpath==2.2.0 +seekpath==2.2.1 # via emmet-core -semantic-version==2.10.0 - # via mpcontribs-client shapely==2.1.2 # via pymatgen-analysis-alloys shellingham==1.5.4 @@ -655,7 +647,7 @@ threadpoolctl==3.6.0 # via scikit-learn tifffile==2026.1.28 # via scikit-image -tqdm==4.67.1 +tqdm==4.67.3 # via # matminer # mpcontribs-client @@ -670,11 +662,12 @@ typer==0.21.1 # via pydocket types-requests==2.32.4.20260107 # via mp-api (pyproject.toml) -types-setuptools==80.10.0.20260124 +types-setuptools==81.0.0.20260209 # via mp-api (pyproject.toml) typing-extensions==4.15.0 # via # anyio + # arro3-core # blake3 # bravado # emmet-core @@ -686,8 +679,6 @@ typing-extensions==4.15.0 # mp-api (pyproject.toml) # mypy # opentelemetry-api - # opentelemetry-sdk - # opentelemetry-semantic-conventions # pint # py-key-value-shared # pydantic @@ -727,13 +718,13 @@ uvicorn==0.40.0 # mcp virtualenv==20.36.1 # via pre-commit -wcwidth==0.5.2 +wcwidth==0.6.0 # via prompt-toolkit webcolors==25.10.0 # via jsonschema websockets==16.0 # via fastmcp -wrapt==1.17.3 - # via opentelemetry-instrumentation +wrapt==2.1.1 + # via deprecated zipp==3.23.0 # via importlib-metadata diff --git a/requirements/requirements-ubuntu-latest_py3.12.txt b/requirements/requirements-ubuntu-latest_py3.12.txt index 9822afda..6bdfdf4a 100644 --- a/requirements/requirements-ubuntu-latest_py3.12.txt +++ b/requirements/requirements-ubuntu-latest_py3.12.txt @@ -6,13 +6,15 @@ # annotated-types==0.7.0 # via pydantic +arro3-core==0.6.5 + # via deltalake bibtexparser==1.4.4 # via pymatgen blake3==1.0.8 # via emmet-core -boto3==1.42.38 +boto3==1.42.45 # via mp-api (pyproject.toml) -botocore==1.42.38 +botocore==1.42.45 # via # boto3 # s3transfer @@ -24,6 +26,10 @@ contourpy==1.3.3 # via matplotlib cycler==0.12.1 # via matplotlib +deltalake==1.4.2 + # via mp-api (pyproject.toml) +deprecated==1.3.1 + # via deltalake emmet-core==0.86.3 # via mp-api (pyproject.toml) fonttools==4.61.1 @@ -49,11 +55,11 @@ monty==2025.3.3 # pymatgen mpmath==1.3.0 # via sympy -narwhals==2.15.0 +narwhals==2.16.0 # via plotly networkx==3.6.1 # via pymatgen -numpy==2.4.1 +numpy==2.4.2 # via # contourpy # matplotlib @@ -63,7 +69,7 @@ numpy==2.4.1 # pymatgen-io-validation # scipy # spglib -orjson==3.11.6 +orjson==3.11.7 # via # mp-api (pyproject.toml) # pymatgen @@ -79,6 +85,8 @@ pillow==12.1.0 # via matplotlib plotly==6.5.2 # via pymatgen +pyarrow==23.0.0 + # via mp-api (pyproject.toml) pybtex==0.25.1 # via emmet-core pydantic==2.12.5 @@ -133,7 +141,7 @@ sympy==1.14.0 # via pymatgen tabulate==0.9.0 # via pymatgen -tqdm==4.67.1 +tqdm==4.67.3 # via pymatgen typing-extensions==4.15.0 # via @@ -153,3 +161,5 @@ urllib3==2.6.3 # via # botocore # requests +wrapt==2.1.1 + # via deprecated diff --git a/requirements/requirements-ubuntu-latest_py3.12_extras.txt b/requirements/requirements-ubuntu-latest_py3.12_extras.txt index 6c0f7e0e..d6ae9c46 
100644 --- a/requirements/requirements-ubuntu-latest_py3.12_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.12_extras.txt @@ -14,6 +14,8 @@ anyio==4.12.1 # mcp # sse-starlette # starlette +arro3-core==0.6.5 + # via deltalake arrow==1.4.0 # via isoduration ase==3.27.0 @@ -25,9 +27,9 @@ attrs==25.4.0 # cyclopts # jsonschema # referencing -authlib==1.6.6 +authlib==1.6.7 # via fastmcp -babel==2.17.0 +babel==2.18.0 # via sphinx beartype==0.22.9 # via @@ -39,9 +41,9 @@ blake3==1.0.8 # via emmet-core boltons==25.0.0 # via mpcontribs-client -boto3==1.42.38 +boto3==1.42.45 # via mp-api (pyproject.toml) -botocore==1.42.38 +botocore==1.42.45 # via # boto3 # s3transfer @@ -49,7 +51,7 @@ bravado==12.0.1 # via mpcontribs-client bravado-core==6.1.1 # via bravado -cachetools==6.2.6 +cachetools==7.0.0 # via # mpcontribs-client # py-key-value-aio @@ -72,8 +74,10 @@ cloudpickle==3.1.2 # via pydocket contourpy==1.3.3 # via matplotlib -coverage[toml]==7.13.2 +coverage[toml]==7.13.4 # via pytest-cov +croniter==6.0.0 + # via pydocket cryptography==46.0.4 # via # authlib @@ -87,6 +91,10 @@ cyclopts==4.5.1 # via fastmcp decorator==5.2.1 # via ipython +deltalake==1.4.2 + # via mp-api (pyproject.toml) +deprecated==1.3.1 + # via deltalake diskcache==5.6.3 # via py-key-value-aio distlib==0.4.0 @@ -114,7 +122,7 @@ executing==2.2.1 # via stack-data fakeredis[lua]==2.33.0 # via pydocket -fastmcp==2.14.4 +fastmcp==2.14.5 # via mp-api (pyproject.toml) filelock==3.20.3 # via virtualenv @@ -165,7 +173,7 @@ inflect==7.5.0 # via robocrys iniconfig==2.3.0 # via pytest -ipython==9.9.0 +ipython==9.10.0 # via mpcontribs-client ipython-pygments-lexers==1.1.1 # via ipython @@ -264,7 +272,7 @@ more-itertools==10.8.0 # jaraco-functools mp-pyrho==0.5.1 # via pymatgen-analysis-defects -mpcontribs-client==5.10.4 +mpcontribs-client==5.10.5 # via mp-api (pyproject.toml) mpmath==1.3.0 # via sympy @@ -278,7 +286,7 @@ mypy-extensions==1.1.0 # via # mp-api (pyproject.toml) # mypy -narwhals==2.15.0 +narwhals==2.16.0 # via plotly networkx==3.6.1 # via @@ -287,7 +295,7 @@ networkx==3.6.1 # scikit-image nodeenv==1.10.0 # via pre-commit -numpy==2.4.1 +numpy==2.4.2 # via # ase # contourpy @@ -313,23 +321,8 @@ numpy==2.4.1 openapi-pydantic==0.5.1 # via fastmcp opentelemetry-api==1.39.1 - # via - # opentelemetry-exporter-prometheus - # opentelemetry-instrumentation - # opentelemetry-sdk - # opentelemetry-semantic-conventions - # pydocket -opentelemetry-exporter-prometheus==0.60b1 # via pydocket -opentelemetry-instrumentation==0.60b1 - # via pydocket -opentelemetry-sdk==1.39.1 - # via opentelemetry-exporter-prometheus -opentelemetry-semantic-conventions==0.60b1 - # via - # opentelemetry-instrumentation - # opentelemetry-sdk -orjson==3.11.6 +orjson==3.11.7 # via # mp-api (pyproject.toml) # pymatgen @@ -338,7 +331,6 @@ packaging==26.0 # fastmcp # lazy-loader # matplotlib - # opentelemetry-instrumentation # plotly # pytest # scikit-image @@ -351,7 +343,7 @@ pandas==2.3.3 # mpcontribs-client # pymatgen # seaborn -parso==0.8.5 +parso==0.8.6 # via jedi pathable==0.4.4 # via jsonschema-path @@ -384,9 +376,7 @@ pluggy==1.6.0 pre-commit==4.5.1 # via mp-api (pyproject.toml) prometheus-client==0.24.1 - # via - # opentelemetry-exporter-prometheus - # pydocket + # via pydocket prompt-toolkit==3.0.52 # via ipython psutil==7.2.2 @@ -404,7 +394,9 @@ py-key-value-aio[disk,keyring,memory,redis]==0.3.0 py-key-value-shared==0.3.0 # via py-key-value-aio pyarrow==23.0.0 - # via emmet-core + # via + # emmet-core + # mp-api (pyproject.toml) pybtex==0.25.1 # via # 
emmet-core @@ -430,7 +422,7 @@ pydantic-settings==2.12.0 # emmet-core # mcp # pymatgen-io-validation -pydocket==0.16.6 +pydocket==0.17.5 # via fastmcp pyflakes==3.4.0 # via flake8 @@ -443,7 +435,7 @@ pygments==2.19.2 # sphinx pyisemail==2.0.1 # via mpcontribs-client -pyjwt[crypto]==2.10.1 +pyjwt[crypto]==2.11.0 # via mcp pymatgen==2025.10.7 # via @@ -496,6 +488,7 @@ python-dateutil==2.9.0.post0 # botocore # bravado # bravado-core + # croniter # matplotlib # pandas python-dotenv==1.2.1 @@ -509,6 +502,7 @@ python-multipart==0.0.22 pytz==2025.2 # via # bravado-core + # croniter # pandas pyyaml==6.0.3 # via @@ -518,7 +512,7 @@ pyyaml==6.0.3 # pre-commit # pybtex # swagger-spec-validator -redis==7.1.0 +redis==7.1.1 # via # fakeredis # py-key-value-aio @@ -547,7 +541,7 @@ rfc3986-validator==0.1.1 # via jsonschema rfc3987-syntax==1.1.0 # via jsonschema -rich==14.3.1 +rich==14.3.2 # via # cyclopts # fastmcp @@ -587,10 +581,8 @@ seaborn==0.13.2 # via pymatgen-analysis-diffusion secretstorage==3.5.0 # via keyring -seekpath==2.2.0 +seekpath==2.2.1 # via emmet-core -semantic-version==2.10.0 - # via mpcontribs-client shapely==2.1.2 # via pymatgen-analysis-alloys shellingham==1.5.4 @@ -651,7 +643,7 @@ threadpoolctl==3.6.0 # via scikit-learn tifffile==2026.1.28 # via scikit-image -tqdm==4.67.1 +tqdm==4.67.3 # via # matminer # mpcontribs-client @@ -666,7 +658,7 @@ typer==0.21.1 # via pydocket types-requests==2.32.4.20260107 # via mp-api (pyproject.toml) -types-setuptools==80.10.0.20260124 +types-setuptools==81.0.0.20260209 # via mp-api (pyproject.toml) typing-extensions==4.15.0 # via @@ -680,8 +672,6 @@ typing-extensions==4.15.0 # mp-api (pyproject.toml) # mypy # opentelemetry-api - # opentelemetry-sdk - # opentelemetry-semantic-conventions # pint # py-key-value-shared # pydantic @@ -721,13 +711,13 @@ uvicorn==0.40.0 # mcp virtualenv==20.36.1 # via pre-commit -wcwidth==0.5.2 +wcwidth==0.6.0 # via prompt-toolkit webcolors==25.10.0 # via jsonschema websockets==16.0 # via fastmcp -wrapt==1.17.3 - # via opentelemetry-instrumentation +wrapt==2.1.1 + # via deprecated zipp==3.23.0 # via importlib-metadata diff --git a/tests/client/materials/test_tasks.py b/tests/client/materials/test_tasks.py index c89530a4..ea064989 100644 --- a/tests/client/materials/test_tasks.py +++ b/tests/client/materials/test_tasks.py @@ -1,12 +1,14 @@ import os -from ..conftest import client_search_testing, requires_api_key -import pytest +import pytest from emmet.core.mpid import MPID, AlphaID -from emmet.core.trajectory import Trajectory +from emmet.core.trajectory import RelaxTrajectory from emmet.core.utils import utcnow + from mp_api.client.routes.materials.tasks import TaskRester +from ..conftest import client_search_testing, requires_api_key + @pytest.fixture def rester(): @@ -56,11 +58,11 @@ def test_client(rester): @pytest.mark.parametrize("mpid", ["mp-149", MPID("mp-149"), AlphaID("mp-149")]) def test_get_trajectories(rester, mpid): - trajectories = [traj for traj in rester.get_trajectory(mpid)] + trajectory = rester.get_trajectory(mpid) expected_model_fields = { field_name - for field_name, field in Trajectory.model_fields.items() + for field_name, field in RelaxTrajectory.model_fields.items() if not field.exclude } - assert all(set(traj) == expected_model_fields for traj in trajectories) + assert set(trajectory) == expected_model_fields
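
> As a closing note on the `MPDataset` warnings added in `utils.py`: a sketch of the "idiomatic arrow patterns" they point to, operating on the on-disk dataset directly rather than through Python-level indexing or iteration. The cache layout (`<cache>/parsed/core/tasks/`) follows from `local_dataset_cache.joinpath(f"{bucket_suffix}/{prefix}")` above; the column names are illustrative assumptions about the task schema (`task_id` is the rester's primary key, `batch_id` appears in the access-control predicates):
>
> ```python
> from pathlib import Path
>
> import pyarrow.dataset as ds
>
> # Open the materialized parquet/Delta files written by the client.
> # pyarrow's default ignore_prefixes ('.', '_') skips the _delta_log directory.
> cache = Path("~/mp_datasets/parsed/core/tasks").expanduser()
> arrow_ds = ds.dataset(cache)
>
> # Project and filter at the parquet level instead of dataset[i] / slicing,
> # which MPDatasetIndexingWarning / MPDatasetSlicingWarning discourage.
> tbl = arrow_ds.to_table(
>     columns=["task_id", "batch_id"],
>     filter=ds.field("batch_id") == "some_batch_id",
> )
>
> # Stream record batches for bounded memory instead of row-by-row __iter__.
> for batch in arrow_ds.to_batches(columns=["task_id"]):
>     ...
> ```
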