Commit 54b2b34
ENH: Force lax.scan sequential operation to run on CPU
Add device=cpu to the qm_jax function decorator to avoid the known
XLA limitation where lax.scan with millions of lightweight iterations
performs poorly on GPU due to CPU-GPU synchronization overhead.
Added explanatory note about this pattern.
Co-authored-by: HumphreyYang <Humphrey.Yang@anu.edu.au>

1 parent 10627ef commit 54b2b34
1 file changed: +8 -1 lines changed

| Original file lines | Diff lines | Diff line change |
|---|---|---|
| 512 | 512-514 | 1 line removed, 3 lines added |
| after 523 | 526-530 | 5 lines added |