Skip to content

Commit 4a63c0d

Browse files
author
Frankie Robertson
committed
Stop storing item_bank separately in StatefulCatConfig
1 parent 806a221 commit 4a63c0d

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

src/Stateful.jl

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,61 +56,65 @@ end
5656
## TODO: Materialise the cat into a decsision tree
5757

5858
## Implementation for CatConfig
59-
struct StatefulCatConfig{ItemBankT <: AbstractItemBank} <: StatefulCat
59+
struct StatefulCatConfig{TrackedResponsesT <: TrackedResponses} <: StatefulCat
6060
rules::CatRules
61-
tracked_responses::TrackedResponses
62-
item_bank::Ref{ItemBankT}
61+
tracked_responses::Ref{TrackedResponsesT}
6362
end
6463

65-
function StatefulCatConfig(rules, item_bank)
64+
function StatefulCatConfig(rules::CatRules, item_bank::AbstractItemBank)
6665
bare_responses = BareResponses(ResponseType(item_bank))
6766
tracked_responses = TrackedResponses(
6867
bare_responses,
6968
item_bank,
7069
rules.ability_tracker
7170
)
72-
return StatefulCatConfig(rules, tracked_responses, Ref(item_bank))
71+
return StatefulCatConfig(rules, Ref(tracked_responses))
7372
end
7473

7574
function next_item(config::StatefulCatConfig)
76-
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank[])
75+
return best_item(config.rules.next_item, config.tracked_responses[])
7776
end
7877

7978
function ranked_items(config::StatefulCatConfig)
8079
return sortperm(compute_criteria(
81-
config.rules.next_item, config.tracked_responses, config.item_bank[]))
80+
config.rules.next_item, config.tracked_responses[]))
8281
end
8382

8483
function item_criteria(config::StatefulCatConfig)
8584
return compute_criteria(
86-
config.rules.next_item, config.tracked_responses, config.item_bank[])
85+
config.rules.next_item, config.tracked_responses[])
8786
end
8887

8988
function add_response!(config::StatefulCatConfig, index, response)
89+
tracked_responses = config.tracked_responses[]
9090
Aggregators.add_response!(
91-
config.tracked_responses, Response(
92-
ResponseType(config.item_bank[]), index, response))
91+
tracked_responses, Response(
92+
ResponseType(tracked_responses.item_bank), index, response))
9393
end
9494

9595
function rollback!(config::StatefulCatConfig)
96-
pop_response!(config.tracked_responses)
96+
pop_response!(config.tracked_responses[])
9797
end
9898

9999
function reset!(config::StatefulCatConfig)
100-
empty!(config.tracked_responses)
100+
empty!(config.tracked_responses[])
101101
end
102102

103103
function set_item_bank!(config::StatefulCatConfig, item_bank)
104-
reset!(config)
105-
config.item_bank[] = item_bank
104+
bare_responses = BareResponses(ResponseType(item_bank))
105+
config.tracked_responses[] = TrackedResponses(
106+
bare_responses,
107+
item_bank,
108+
config.rules.ability_tracker
109+
)
106110
end
107111

108112
function get_responses(config::StatefulCatConfig)
109-
return config.tracked_responses.responses
113+
return config.tracked_responses[].responses
110114
end
111115

112116
function get_ability(config::StatefulCatConfig)
113-
return (config.rules.ability_estimator(config.tracked_responses), nothing)
117+
return (config.rules.ability_estimator(config.tracked_responses[]), nothing)
114118
end
115119

116120
## TODO: Implementation for MaterializedDecisionTree

0 commit comments

Comments
 (0)