Refine JIT wrappers for new JAX for comaptiblity with jax>=0.8.2#809
Refine JIT wrappers for new JAX for comaptiblity with jax>=0.8.2#809chaoming0625 merged 3 commits intomasterfrom
jax>=0.8.2#809Conversation
Reviewer's GuideRefactors BrainPy to be compatible with newer JAX (>=0.8.2) by updating the JIT wrappers, cleaning up imports and style across many modules, and adding a small smoke test entry point, without changing core numerical logic. Sequence diagram for updated jit wrapper behaviorsequenceDiagram
participant User as UserCode
participant BPJit as brainpy_math_object_transform_jit
participant BSTransform as brainstate_transform
participant compiled_func as compiled_func
User->>BPJit: jit(func, static_argnums, static_argnames, donate_argnums, inline, keep_unused, **kwargs)
activate BPJit
BPJit->>BPJit: warp_to_no_state_input_output(func)
BPJit->>BSTransform: jit(wrapped_func, static_argnums, static_argnames, donate_argnums, inline, keep_unused, **kwargs)
activate BSTransform
BSTransform-->>BPJit: compiled_func
deactivate BSTransform
BPJit-->>User: compiled_func
deactivate BPJit
User->>compiled_func: call_compiled(*args, **kwargs)
activate compiled_func
compiled_func-->>User: results
deactivate compiled_func
Sequence diagram for cls_jit method wrappingsequenceDiagram
participant User as UserClassDefinition
participant BPJit as brainpy_math_object_transform_jit
participant BSTransform as brainstate_transform
User->>BPJit: cls_jit(func, static_argnums, static_argnames, inline, keep_unused, **kwargs)
activate BPJit
BPJit->>BPJit: wrap func to bind self as first argument
BPJit->>BPJit: call jit(wrapped_method, static_argnums, static_argnames, donate_argnums, inline, keep_unused, **kwargs)
BPJit->>BSTransform: jit(wrapped_method, ...)
activate BSTransform
BSTransform-->>BPJit: compiled_method
deactivate BSTransform
BPJit-->>User: compiled_method (to be attached as bound method)
deactivate BPJit
User->>User: attach compiled_method to class instances
Class diagram for jit and ProgressBar related utilitiesclassDiagram
class brainpy_math_object_transform_controls {
+_convert_progress_bar_to_pbar(progress_bar) brainstate_transform_ProgressBar
}
class brainpy_math_object_transform_jit {
+jit(func, static_argnums, static_argnames, donate_argnums, inline, keep_unused, kwargs) Callable
+cls_jit(func, static_argnums, static_argnames, inline, keep_unused, kwargs) Callable
}
class brainpy_math_object_transform__utils {
+warp_to_no_state_input_output(func) Callable
}
class brainstate_transform_ProgressBar {
+freq
}
class brainstate_transform {
+jit(func, static_argnums, static_argnames, donate_argnums, inline, keep_unused, kwargs) Callable
}
brainpy_math_object_transform_controls ..> brainstate_transform_ProgressBar : uses
brainpy_math_object_transform_jit ..> brainpy_math_object_transform__utils : uses warp_to_no_state_input_output
brainpy_math_object_transform_jit ..> brainstate_transform : wraps jit
brainpy_math_object_transform_controls ..> brainstate_transform : creates ProgressBar instances
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
|
@sourcery-ai title |
There was a problem hiding this comment.
Hey - I've found 1 issue, and left some high level feedback:
- In
brainpy/dyn/projections/utils.py,_get_return()still referencesbmbut thefrom brainpy import math as bmimport was removed, which will cause a NameError at runtime; restore or replace this import. - The new
brainpy/test_main.pytest function only imports and printsbrainpy.statewithout any assertions or checks; consider turning this into an actual smoke test by asserting basic properties of the import instead of printing. - Several added
if __name__ == '__main__': ...blocks (e.g., inmath/object_transform/__init__.py,brainpy/__init__.py,channels.py) simply reference symbols without doing anything and can be removed to avoid dead, confusing code paths.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `brainpy/dyn/projections/utils.py`, `_get_return()` still references `bm` but the `from brainpy import math as bm` import was removed, which will cause a NameError at runtime; restore or replace this import.
- The new `brainpy/test_main.py` test function only imports and prints `brainpy.state` without any assertions or checks; consider turning this into an actual smoke test by asserting basic properties of the import instead of printing.
- Several added `if __name__ == '__main__': ...` blocks (e.g., in `math/object_transform/__init__.py`, `brainpy/__init__.py`, `channels.py`) simply reference symbols without doing anything and can be removed to avoid dead, confusing code paths.
## Individual Comments
### Comment 1
<location> `brainpy/math/object_transform/__init__.py:43-44` </location>
<code_context>
-
-
if __name__ == '__main__':
connect
initialize, # weight initialization
</code_context>
<issue_to_address>
**suggestion:** The `__main__` block referencing `ProgressBar` appears to be dead code.
This `if __name__ == '__main__':` block has no side effects (the expressions are evaluated and discarded) and won’t run in normal library usage. If it isn’t needed, consider removing it; if it is, replace it with explicit calls or move the relevant logic to import time to make its purpose clear.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| if __name__ == '__main__': | ||
| ProgressBar |
There was a problem hiding this comment.
suggestion: The __main__ block referencing ProgressBar appears to be dead code.
This if __name__ == '__main__': block has no side effects (the expressions are evaluated and discarded) and won’t run in normal library usage. If it isn’t needed, consider removing it; if it is, replace it with explicit calls or move the relevant logic to import time to make its purpose clear.
jax>=0.8.2
Summary by Sourcery
Make minor refactors and formatting cleanups across the codebase while adjusting JIT wrappers for newer JAX versions and adding a basic import test.
Enhancements:
Tests: