Skip to content

Commit f9b4da8

Browse files
authored
feat: new macro @quickloop and @reaction_eq (#6)
* feat: new macro `@quickloop` and `@reaction_eq` * Bump version to `0.1.4`
1 parent c3f31fa commit f9b4da8

File tree

7 files changed

+684
-111
lines changed

7 files changed

+684
-111
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
name = "EvolutionaryModelingTools"
22
uuid = "faf572ed-2b53-4324-a741-daa175e50348"
33
authors = ["Long Wang <wangl.cc@outlook.com>"]
4-
version = "0.1.3"
4+
version = "0.1.4"
55

66
[deps]
77
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
88
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
99
RecordedArrays = "1040807a-3a2f-4266-b2c1-805b33c7034a"
10+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1011

1112
[compat]
13+
LoopVectorization = "0.12.107"
1214
MacroTools = "0.5"
1315
RecordedArrays = "0.4.1"
1416
julia = "1.2"
1517

1618
[extras]
19+
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1720
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1821

1922
[targets]
20-
test = ["Test"]
23+
test = ["Test", "LoopVectorization"]

example/competition.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using RecordedArrays
22
using RecordedArrays.ResizingTools
33
using Random
4+
using LoopVectorization
45
using EvolutionaryModelingTools
56
using EvolutionaryModelingTools: sample
67

@@ -13,13 +14,13 @@ const m = ones(1, 1)
1314

1415
# simulate with this package
1516
## define reactions
16-
@reaction growth begin
17-
r * (1 - μ) * v
18-
v[ind] += 1
19-
end
2017

18+
# define growth with @reaction_eq, loop vectorization is disabled but inbounds, fastmath and offset are enable
19+
@reaction_eq growth r * (1 - μ) v[i] 2v[i] avx=false inbounds=true fastmath=true offset=true
20+
21+
# define mutation with @reaction and @quickloop with given index
2122
@reaction mutation begin
22-
@. r * μ * v
23+
@quickloop r * μ * v[i]
2324
begin
2425
push!(v, 1)
2526
n = length(v)
@@ -30,11 +31,9 @@ end
3031
end
3132
end
3233

34+
# define competition with @reaction and @quickloop
3335
@reaction competition begin
34-
begin
35-
c = r / K # baseline competition coefficient
36-
@. c * v * m * v'
37-
end
36+
@quickloop (r / K) * v[i] * m[i, j] * v[j] turbo # force enable loop vectorization
3837
begin
3938
i = ind[1]
4039
v[i] -= 1

example/sir.jl

Lines changed: 154 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,49 +6,89 @@ using EvolutionaryModelingTools.Scalar
66
# parameters
77
const T = 100.0
88
const β = 0.001
9-
const γ = 0.01
9+
const ν = 0.01
10+
const α = 0.2
11+
const r = 0.5
12+
const d = 0.1
13+
const c = 0.001
1014
const S = 100
1115
const I = 1
1216
const R = 0
1317

1418
## simulation with this package
15-
@reaction infection begin
16-
β * S * I
17-
(S[] -= 1; I[] += 1)
18-
end
19-
20-
@reaction recovery begin
21-
γ * I
22-
(I[] -= 1; R[] += 1)
19+
# epidemic dynamics
20+
@reaction_eq infection β S + I --> I + I
21+
@reaction_eq recovery ν I --> R
22+
@reaction_eq virulence α I --> 0 # death of infection host caused by virus
23+
# demography dynamics
24+
# generate reactions with @eval
25+
for sym in (:S, :I, :R)
26+
r_name = Symbol(:growth_, sym)
27+
d_name = Symbol(:death_, sym)
28+
@eval @reaction_eq $r_name r $sym --> S + $sym # growth
29+
@eval @reaction_eq $d_name d $sym --> 0 # death
30+
for sym′ in (:S, :I, :R)
31+
c_name = Symbol(:competition_, sym, sym′)
32+
@eval @reaction_eq $c_name c $sym + $sym′ --> $sym′ # competition
33+
end
2334
end
2435

25-
run_gillespie(rng=Random.GLOBAL_RNG, T=T, β=β, γ=γ, S=S, I=I, R=R) = gillespie(
26-
rng,
27-
T,
28-
(; β, γ, S=scalar(S), I=scalar(I), R=scalar(R)),
29-
(infection, recovery),
36+
const REACTIONS = (
37+
infection, recovery, virulence,
38+
growth_S, growth_I, growth_R,
39+
death_S, death_I, death_R,
40+
competition_SS, competition_SI, competition_SR,
41+
competition_IS, competition_II, competition_IR,
42+
competition_RS, competition_RI, competition_RR
3043
)
3144

32-
function run_gillespie_record(rng=Random.GLOBAL_RNG, T=T, β=β, γ=γ, S=S, I=I, R=R)
45+
run_gillespie(rng=Random.GLOBAL_RNG, T=T, β=β, ν=ν, α=α, r=r, d=d, c=c, S=S, I=I, R=R) =
46+
gillespie(rng, T, (; β, ν, α, r, d, c, S=scalar(S), I=scalar(I), R=scalar(R)), REACTIONS)
47+
48+
function run_gillespie_record(rng=Random.GLOBAL_RNG, T=T, β=β, ν=ν, α=α,
49+
r=r, d=d, c=c, S=S, I=I, R=R)
3350
clock = ContinuousClock(T)
34-
ps = (;
35-
β,
36-
γ,
51+
ps = (;β, ν, α, r, d, c,
3752
S=recorded(DynamicEntry, clock, S),
3853
I=recorded(DynamicEntry, clock, I),
3954
R=recorded(DynamicEntry, clock, R),
4055
)
41-
return gillespie(rng, clock, ps, (infection, recovery))
56+
return gillespie(rng, clock, ps, REACTIONS)
4257
end
4358

4459
## simulation manually
45-
function run_manually(rng=Random.GLOBAL_RNG, T=T, β=β, γ=γ, S=S, I=I, R=R)
60+
### most of the code is generated by Github Copilot
61+
### a useful tool for generating code like this
62+
function run_manually(rng=Random.GLOBAL_RNG, T=T, β=β, ν=ν, α=α,
63+
r=r, d=d, c=c, S=S, I=I, R=R)
4664
t = 0
4765
while t <= T
4866
# calculate rates
4967
infection_rate = β * S * I
50-
recovery_rate = γ * I
51-
summed = infection_rate + recovery_rate
68+
recovery_rate = ν * I
69+
virulence_rate = α * I
70+
growth_S_rate = r * S
71+
growth_I_rate = r * I
72+
growth_R_rate = r * R
73+
death_S_rate = d * S
74+
death_I_rate = d * I
75+
death_R_rate = d * R
76+
competition_SS_rate = c * S * S
77+
competition_SI_rate = c * S * I
78+
competition_SR_rate = c * S * R
79+
competition_IS_rate = c * I * S
80+
competition_II_rate = c * I * I
81+
competition_IR_rate = c * I * R
82+
competition_RS_rate = c * R * S
83+
competition_RI_rate = c * R * I
84+
competition_RR_rate = c * R * R
85+
# summed rate
86+
summed = infection_rate + recovery_rate + virulence_rate +
87+
growth_S_rate + growth_I_rate + growth_R_rate +
88+
death_S_rate + death_I_rate + death_R_rate +
89+
competition_SS_rate + competition_SI_rate + competition_SR_rate +
90+
competition_IS_rate + competition_II_rate + competition_IR_rate +
91+
competition_RS_rate + competition_RI_rate + competition_RR_rate
5292
# break if summed is zero
5393
summed == 0 && break
5494
# update current time
@@ -58,24 +98,79 @@ function run_manually(rng=Random.GLOBAL_RNG, T=T, β=β, γ=γ, S=S, I=I, R=R)
5898
if rn < infection_rate
5999
S -= 1
60100
I += 1
61-
else
101+
elseif (rn -= infection_rate) < recovery_rate
62102
I -= 1
63103
R += 1
104+
elseif (rn -= recovery_rate) < virulence_rate
105+
I -= 1
106+
elseif (rn -= virulence_rate) < growth_S_rate
107+
S += 1
108+
elseif (rn -= growth_S_rate) < growth_I_rate
109+
S += 1
110+
elseif (rn -= growth_I_rate) < growth_R_rate
111+
S += 1
112+
elseif (rn -= growth_R_rate) < death_S_rate
113+
S -= 1
114+
elseif (rn -= death_S_rate) < death_I_rate
115+
I -= 1
116+
elseif (rn -= death_I_rate) < death_R_rate
117+
R -= 1
118+
elseif (rn -= death_R_rate) < competition_SS_rate
119+
S -= 1
120+
elseif (rn -= competition_SS_rate) < competition_SI_rate
121+
S -= 1
122+
elseif (rn -= competition_SI_rate) < competition_SR_rate
123+
S -= 1
124+
elseif (rn -= competition_SR_rate) < competition_IS_rate
125+
I -= 1
126+
elseif (rn -= competition_IS_rate) < competition_II_rate
127+
I -= 1
128+
elseif (rn -= competition_II_rate) < competition_IR_rate
129+
I -= 1
130+
elseif (rn -= competition_IR_rate) < competition_RS_rate
131+
R -= 1
132+
elseif (rn -= competition_RS_rate) < competition_RI_rate
133+
R -= 1
134+
else
135+
R -= 1
64136
end
65137
end
66138
return (; t, S, I, R)
67139
end
68140

69-
function run_manually_record(rng=Random.GLOBAL_RNG, T=T, β=β, γ=γ, S=S, I=I, R=R)
141+
function run_manually_record(rng=Random.GLOBAL_RNG, T=T, β=β, ν=ν, α=α,
142+
r=r, d=d, c=c, S=S, I=I, R=R)
70143
clock = ContinuousClock(T)
71144
S′ = recorded(DynamicEntry, clock, S)
72145
I′ = recorded(DynamicEntry, clock, I)
73146
R′ = recorded(DynamicEntry, clock, R)
74147
for _ in clock
75148
# calculate rates
76149
infection_rate = β * S′ * I′
77-
recovery_rate = γ * I′
78-
summed = infection_rate + recovery_rate
150+
recovery_rate = ν * I′
151+
virulence_rate = α * I′
152+
growth_S_rate = r * S′
153+
growth_I_rate = r * I′
154+
growth_R_rate = r * R′
155+
death_S_rate = d * S′
156+
death_I_rate = d * I′
157+
death_R_rate = d * R′
158+
competition_SS_rate = c * S′ * S′
159+
competition_SI_rate = c * S′ * I′
160+
competition_SR_rate = c * S′ * R′
161+
competition_IS_rate = c * I′ * S′
162+
competition_II_rate = c * I′ * I′
163+
competition_IR_rate = c * I′ * R′
164+
competition_RS_rate = c * R′ * S′
165+
competition_RI_rate = c * R′ * I′
166+
competition_RR_rate = c * R′ * R′
167+
# summed rate
168+
summed = infection_rate + recovery_rate + virulence_rate +
169+
growth_S_rate + growth_I_rate + growth_R_rate +
170+
death_S_rate + death_I_rate + death_R_rate +
171+
competition_SS_rate + competition_SI_rate + competition_SR_rate +
172+
competition_IS_rate + competition_II_rate + competition_IR_rate +
173+
competition_RS_rate + competition_RI_rate + competition_RR_rate
79174
# break if summed is zero
80175
summed == 0 && break
81176
# update current time
@@ -85,9 +180,41 @@ function run_manually_record(rng=Random.GLOBAL_RNG, T=T, β=β, γ=γ, S=S, I=I,
85180
if rn < infection_rate
86181
S′[] -= 1
87182
I′[] += 1
88-
else
183+
elseif (rn -= infection_rate) < recovery_rate
89184
I′[] -= 1
90185
R′[] += 1
186+
elseif (rn -= recovery_rate) < virulence_rate
187+
I′[] -= 1
188+
elseif (rn -= virulence_rate) < growth_S_rate
189+
S′[] += 1
190+
elseif (rn -= growth_S_rate) < growth_I_rate
191+
S′[] += 1
192+
elseif (rn -= growth_I_rate) < growth_R_rate
193+
S′[] += 1
194+
elseif (rn -= growth_R_rate) < death_S_rate
195+
S′[] -= 1
196+
elseif (rn -= death_S_rate) < death_I_rate
197+
I′[] -= 1
198+
elseif (rn -= death_I_rate) < death_R_rate
199+
R′[] -= 1
200+
elseif (rn -= death_R_rate) < competition_SS_rate
201+
S′[] -= 1
202+
elseif (rn -= competition_SS_rate) < competition_SI_rate
203+
S′[] -= 1
204+
elseif (rn -= competition_SI_rate) < competition_SR_rate
205+
S′[] -= 1
206+
elseif (rn -= competition_SR_rate) < competition_IS_rate
207+
I′[] -= 1
208+
elseif (rn -= competition_IS_rate) < competition_II_rate
209+
I′[] -= 1
210+
elseif (rn -= competition_II_rate) < competition_IR_rate
211+
I′[] -= 1
212+
elseif (rn -= competition_IR_rate) < competition_RS_rate
213+
R′[] -= 1
214+
elseif (rn -= competition_RS_rate) < competition_RI_rate
215+
R′[] -= 1
216+
else
217+
R′[] -= 1
91218
end
92219
end
93220
return (; clock, S=S′, I=I′, R=R′)

src/EvolutionaryModelingTools.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
module EvolutionaryModelingTools
22

33
using RecordedArrays
4+
using RecordedArrays.ArrayInterface: indices
45
using Random
56
using MacroTools
7+
using Requires
68

7-
export @cfunc, @ufunc, @reaction
8-
export Reaction, gillespie, gillespie!
9+
export @cfunc, @ufunc, @reaction, @quickloop, @reaction_eq
10+
export Reaction, gillespie, gillespie!, indices
911

1012
include("tools.jl")
1113
include("sample.jl")
1214
include("model.jl")
1315
include("scalar.jl")
1416

17+
function __init__()
18+
@require LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" begin
19+
DEFAULT_LOOP_MODE[1] = true
20+
end
21+
end
22+
1523
end # module

0 commit comments

Comments
 (0)