Skip to content

Commit 011535b

Browse files
committed
Perf: lazy-load heavy modules to speed up import datajoint
1 parent 92bc557 commit 011535b

File tree

2 files changed

+89
-20
lines changed

2 files changed

+89
-20
lines changed

datajoint/__init__.py

Lines changed: 83 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,25 +54,88 @@
5454
"logger",
5555
"cli",
5656
]
57+
import importlib
58+
from typing import TYPE_CHECKING
5759

58-
from . import errors
59-
from .admin import kill, set_password
60-
from .attribute_adapter import AttributeAdapter
61-
from .blob import MatCell, MatStruct
62-
from .cli import cli
63-
from .connection import Connection, conn
64-
from .diagram import Diagram
60+
from . import errors
6561
from .errors import DataJointError
66-
from .expression import AndList, Not, Top, U
67-
from .fetch import key
68-
from .hash import key_hash
69-
from .logging import logger
70-
from .schemas import Schema, VirtualModule, list_schemas
71-
from .settings import config
72-
from .table import FreeTable, Table
73-
from .user_tables import Computed, Imported, Lookup, Manual, Part
74-
from .version import __version__
75-
76-
ERD = Di = Diagram # Aliases for Diagram
77-
schema = Schema # Aliases for Schema
78-
create_virtual_module = VirtualModule # Aliases for VirtualModule
62+
from .logging import logger
63+
from .settings import config
64+
from .version import __version__
65+
66+
from .connection import Connection, conn
67+
68+
if TYPE_CHECKING:
69+
from .admin import kill, set_password
70+
from .attribute_adapter import AttributeAdapter
71+
from .blob import MatCell, MatStruct
72+
from .cli import cli
73+
from .diagram import Diagram
74+
from .expression import AndList, Not, Top, U
75+
from .fetch import key
76+
from .hash import key_hash
77+
from .schemas import Schema, VirtualModule, list_schemas
78+
from .table import FreeTable, Table
79+
from .user_tables import Computed, Imported, Lookup, Manual, Part
80+
81+
82+
_LAZY: dict[str, tuple[str, str]] = {
83+
# admin
84+
"kill": ("datajoint.admin", "kill"),
85+
"set_password": ("datajoint.admin", "set_password"),
86+
87+
# core objects
88+
"Schema": ("datajoint.schemas", "Schema"),
89+
"VirtualModule": ("datajoint.schemas", "VirtualModule"),
90+
"list_schemas": ("datajoint.schemas", "list_schemas"),
91+
92+
# tables
93+
"Table": ("datajoint.table", "Table"),
94+
"FreeTable": ("datajoint.table", "FreeTable"),
95+
"Manual": ("datajoint.user_tables", "Manual"),
96+
"Lookup": ("datajoint.user_tables", "Lookup"),
97+
"Imported": ("datajoint.user_tables", "Imported"),
98+
"Computed": ("datajoint.user_tables", "Computed"),
99+
"Part": ("datajoint.user_tables", "Part"),
100+
101+
# diagram
102+
"Diagram": ("datajoint.diagram", "Diagram"),
103+
104+
# expressions
105+
"Not": ("datajoint.expression", "Not"),
106+
"AndList": ("datajoint.expression", "AndList"),
107+
"Top": ("datajoint.expression", "Top"),
108+
"U": ("datajoint.expression", "U"),
109+
110+
# misc utilities
111+
"MatCell": ("datajoint.blob", "MatCell"),
112+
"MatStruct": ("datajoint.blob", "MatStruct"),
113+
"AttributeAdapter": ("datajoint.attribute_adapter", "AttributeAdapter"),
114+
"key": ("datajoint.fetch", "key"),
115+
"key_hash": ("datajoint.hash", "key_hash"),
116+
"cli": ("datajoint.cli", "cli"),
117+
}
118+
_ALIAS: dict[str, str] = {
119+
"ERD": "Diagram",
120+
"Di": "Diagram",
121+
"schema": "Schema",
122+
"create_virtual_module": "VirtualModule",
123+
}
124+
125+
126+
def __getattr__(name: str):
127+
if name in _ALIAS:
128+
target = _ALIAS[name]
129+
value = getattr(importlib.import_module(_LAZY[target][0]), _LAZY[target][1])
130+
globals()[target] = value
131+
globals()[name] = value
132+
return value
133+
134+
if name in _LAZY:
135+
module_name, attr = _LAZY[name]
136+
module = importlib.import_module(module_name)
137+
value = getattr(module, attr)
138+
globals()[name] = value # cache
139+
return value
140+
141+
raise AttributeError(f"module 'datajoint' has no attribute {name}")

tests/test_import_performance.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def test_import_does_not_eager_load_heavy_deps():
2+
import sys
3+
import datajoint # noqa: F401
4+
5+
assert "datajoint.diagram" not in sys.modules
6+
assert "pandas" not in sys.modules

0 commit comments

Comments
 (0)