Skip to content

Commit 38b95b2

Browse files
committed
Adding model prediction dashboard
1 parent 15bdada commit 38b95b2

File tree

11 files changed

+158
-59
lines changed

11 files changed

+158
-59
lines changed

R/stan_vetiver.R

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,15 @@ handler_predict.brmsfit <- function(vetiver_model, ...) {
4444
newdata <- req$body
4545
newdata <- vetiver::vetiver_type_convert(newdata, ptype)
4646
newdata <- hardhat::scream(newdata, ptype)
47-
ret <- predict(vetiver_model$model, new_data = newdata, ...)
48-
list(.pred = ret[, "Estimate"],
49-
"Q2.5" = ret[, "Q2.5"],
50-
"Q97.5" = ret[, "Q97.5"])
47+
ret <- predict(vetiver_model$model, new_data = newdata, ndraws = 50, ...)
48+
49+
list(.pred = list(
50+
list(
51+
.pred = ret[, "Estimate"] |> mean(),
52+
.pred_lower = ret[, "Q2.5"] |> mean(),
53+
.pred_upper = ret[, "Q97.5"] |> mean()
54+
)
55+
))
5156
}
5257

5358
}

app/R/globals.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ library(jsonlite)
99
library(tibble)
1010
library(xml2)
1111
library(tools)
12+
library(dplyr)

app/api/R/vetiver_integration.R

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,22 @@ predict.int_conformal_split <- function(object, new_data, level = 0.95, ...) {
2121
}
2222

2323
handler_predict.int_conformal_split <- function(vetiver_model, ...) {
24-
ptype <- vetiver_model$prototype
25-
26-
function(req) {
27-
newdata <- req$body
28-
newdata <- vetiver::vetiver_type_convert(newdata, ptype)
29-
newdata <- hardhat::scream(newdata, ptype)
30-
ret <- predict(vetiver_model$model, new_data = newdata, ...)
31-
list(.pred = ret)
32-
}
33-
24+
ptype <- vetiver_model$prototype
25+
26+
function(req) {
27+
newdata <- req$body
28+
newdata <- vetiver::vetiver_type_convert(newdata, ptype)
29+
newdata <- hardhat::scream(newdata, ptype)
30+
ret <- predict(vetiver_model$model, new_data = newdata, ...)
31+
list(.pred = ret)
32+
}
33+
3434
}
3535

3636
vetiver_create_meta.int_conformal_split <-
3737
function(model, metadata) {
38-
vetiver::vetiver_meta(metadata, required_pkgs = c("probably", "workflows", "stacks", "workflowsets"))
38+
vetiver::vetiver_meta(metadata,
39+
required_pkgs = c("probably", "workflows", "stacks", "workflowsets"))
3940
}
4041

4142
vetiver_create_description.brmsfit <- function(model) {
@@ -57,36 +58,37 @@ handler_predict.brmsfit <- function(vetiver_model, ...) {
5758
newdata <- req$body
5859
newdata <- vetiver::vetiver_type_convert(newdata, ptype)
5960
newdata <- hardhat::scream(newdata, ptype)
60-
ret <- predict(vetiver_model$model, new_data = newdata, ...)
61-
list(
62-
.pred = ret[,"Estimate"],
63-
"Q2.5" = ret[, "Q2.5"],
64-
"Q97.5" = ret[, "Q97.5"]
65-
)
61+
ret <- predict(vetiver_model$model, new_data = newdata, ndraws = 50, ...)
62+
63+
list(.pred = list(
64+
list(
65+
.pred = ret[, "Estimate"] |> mean(),
66+
.pred_lower = ret[, "Q2.5"] |> mean(),
67+
.pred_upper = ret[, "Q97.5"] |> mean()
68+
)
69+
))
6670
}
6771

6872
}
6973

7074
predict.cmdstanr_container <- function(model, new_data) {
7175
data <- list()
72-
variable_names <- model$model$variables()$data |>
76+
variable_names <- model$model$variables()$data |>
7377
names()
7478

75-
for(variable_name in variable_names) {
76-
if(variable_name == "N") {
79+
for (variable_name in variable_names) {
80+
if (variable_name == "N") {
7781
data["N"] <- nrow(new_data)
7882
} else {
7983
data[[variable_name]] <- new_data[[variable_name]]
8084
indx_var <- paste0(variable_name, "_J")
81-
if(indx_var %in% variable_names) {
85+
if (indx_var %in% variable_names) {
8286
data[[indx_var]] <- model$data[[indx_var]]
8387
}
8488
}
8589
}
8690

87-
model$model$generate_quantities(
88-
model$fit, data = data
89-
)$summary() |> dplyr::select(-variable)
91+
model$model$generate_quantities(model$fit, data = data)$summary() |> dplyr::select(-variable)
9092
}
9193

9294
vetiver_create_description.cmdstanr_container <- function(model) {
@@ -110,7 +112,7 @@ handler_predict.cmdstanr_container <- function(vetiver_model, ...) {
110112
newdata <- hardhat::scream(newdata, ptype)
111113
ret <- predict(vetiver_model$model, new_data = newdata, ...)
112114
list(
113-
.pred = ret[,"mean"],
115+
.pred = ret[, "mean"],
114116
"sd" = ret[, "sd"],
115117
"q5" = ret[, "q5"],
116118
"q95" = ret[, "q95"]

app/server.R

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ server <- function(input, output, session) {
2626
model_endpoints[[idx]]
2727
})
2828

29+
df_features <- reactive({
30+
paste(model_endpoint(), "prototype", sep = "/") |>
31+
GET() |>
32+
content("text") |>
33+
fromJSON()
34+
}) |>
35+
bindCache(input$model_choice)
36+
2937
df <- reactive({
3038
paste(model_endpoint(), "data", sep = "/") |>
3139
GET() |>
@@ -35,6 +43,8 @@ server <- function(input, output, session) {
3543
}) |>
3644
bindCache(input$model_choice)
3745

46+
prediction_df <- reactiveVal(tibble())
47+
3848
model_card <- reactive({
3949
paste(model_endpoint(), "card", sep = "/") |>
4050
GET() |>
@@ -95,6 +105,18 @@ server <- function(input, output, session) {
95105
model_card_content()
96106
})
97107

108+
output$dynamic_features <- renderUI({
109+
lapply(names(df_features()), function(feature_name) {
110+
feature_type <- df_features()[[feature_name]]$type
111+
switch(feature_type,
112+
character = textInput(feature_name, label = feature_name, value = "None"),
113+
numeric = numericInput(feature_name, label = feature_name, value = 0),
114+
select = selectInput(feature_name, label = feature_name, choices = c()),
115+
NULL
116+
)
117+
})
118+
})
119+
98120
updateSelectizeInput(
99121
session, "plot_choice", choices = plot_choices(), server = FALSE,
100122
selected = plot_choices()[1]
@@ -128,6 +150,51 @@ server <- function(input, output, session) {
128150
})
129151
})
130152

153+
observeEvent(input$clear_predict, {
154+
prediction_df(tibble())
155+
})
156+
157+
observeEvent(input$predict, {
158+
req_body <- reactiveValuesToList(input)[names(df_features())] |>
159+
as.tibble()
160+
161+
preds <- paste(model_endpoint(), "predict", sep = "/") |>
162+
POST(
163+
body = req_body,
164+
encode = "json"
165+
) |>
166+
content("text") |>
167+
fromJSON()
168+
169+
req_body$.pred <- preds$.pred$.pred
170+
req_body$.pred_lower <- preds$.pred$.pred_lower
171+
req_body$.pred_upper <- preds$.pred$.pred_upper
172+
print(req_body)
173+
prediction_df(bind_rows(prediction_df(), req_body))
174+
175+
output$predictions <- renderPlotly({
176+
axes_choices <- axes_choices()
177+
x_ax <- x_ax()
178+
179+
if((!x_ax %in% axes_choices)) {
180+
return()
181+
}
182+
183+
if(nrow(prediction_df()) == 0) {
184+
return()
185+
}
186+
187+
(
188+
prediction_df() |>
189+
ggplot(aes(x=.data[[x_ax]], y=.data[[".pred"]])) +
190+
geom_point() +
191+
geom_line() +
192+
geom_errorbar(aes(ymin = .data[[".pred_lower"]], ymax = .data[[".pred_upper"]]), width = 0.3)
193+
) |>
194+
ggplotly()
195+
})
196+
})
197+
131198
output$scatter <- renderPlotly({
132199
axes_choices <- axes_choices()
133200
x_ax <- x_ax()

app/ui.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,19 @@ ui <- page_sidebar(
4242
selectize = TRUE
4343
)
4444
),
45+
accordion_panel(
46+
"Predictions",
47+
tags$style(HTML("
48+
#predict {
49+
position: sticky;
50+
top: 10px;
51+
z-index: 1000;
52+
}
53+
")),
54+
actionButton("predict", "Predict"),
55+
htmlOutput("dynamic_features"),
56+
actionButton("clear_predict", "Clear Predictions")
57+
),
4558
accordion_panel(
4659
"Downloads",
4760
downloadButton("download_data", "Download Data"),
@@ -60,6 +73,13 @@ ui <- page_sidebar(
6073
full_screen = TRUE
6174
)
6275
),
76+
nav_panel(
77+
"Predictions",
78+
card(
79+
plotlyOutput("predictions"),
80+
full_screen = TRUE
81+
)
82+
),
6383
nav_panel(
6484
"Histograms",
6585
layout_columns(

deployments/k8s/pfr/r-deployment.yaml

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,6 @@ spec:
1818
app: r-pipelines
1919
spec:
2020
containers:
21-
- image: ghcr.io/plant-food-research-open/shiny-rpipelines:0.1.0
22-
imagePullPolicy: Always
23-
name: shiny
24-
ports:
25-
- containerPort: 3838
26-
protocol: TCP
27-
args:
28-
- R
29-
- -e
30-
- shiny::runApp('/app', launch.browser = FALSE, host = '0.0.0.0', port = 3838)
31-
env:
32-
- name: AWS_ACCESS_KEY_ID
33-
value: user
34-
- name: AWS_SECRET_ACCESS_KEY
35-
value: password
3621
- image: h2oai/h2o-open-source-k8s:3.44.0.3
3722
resources:
3823
requests:

deployments/k8s/pfr/r-ingress.yaml

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,6 @@ spec:
1717
app: r-pipelines
1818
ingressClassName: nginx
1919
rules:
20-
- host: shiny.r-pipelines.k8s.dev.pfr.co.nz
21-
http:
22-
paths:
23-
- backend:
24-
service:
25-
name: r-pipelines
26-
port:
27-
name: port-3838
28-
path: /
29-
pathType: Prefix
3020
- host: minio.r-pipelines.k8s.dev.pfr.co.nz
3121
http:
3222
paths:
@@ -69,7 +59,6 @@ spec:
6959
pathType: Prefix
7060
tls:
7161
- hosts:
72-
- "shiny.r-pipelines.k8s.dev.pfr.co.nz"
7362
- "minio.r-pipelines.k8s.dev.pfr.co.nz"
7463
- "minio-api.r-pipelines.k8s.dev.pfr.co.nz"
7564
- "h2o.r-pipelines.k8s.dev.pfr.co.nz"

deployments/k8s/pfr/r-models-deployment.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,25 @@ spec:
1818
app: r-pipelines
1919
spec:
2020
containers:
21+
- image: ghcr.io/plant-food-research-open/shiny-rpipelines:0.1.0
22+
imagePullPolicy: Always
23+
name: shiny
24+
ports:
25+
- containerPort: 3838
26+
protocol: TCP
27+
args:
28+
- R
29+
- -e
30+
- shiny::runApp('/app', launch.browser = FALSE, host = '0.0.0.0', port = 3838)
31+
env:
32+
- name: AWS_ACCESS_KEY_ID
33+
value: user
34+
- name: AWS_SECRET_ACCESS_KEY
35+
value: password
36+
- name: MODEL_CHOICES
37+
value: random_forest,gompertz
38+
- name: MODEL_ENDPOINTS
39+
value: https://random-forest.model.r-pipelines.k8s.dev.pfr.co.nz,https://gompertz.model.r-pipelines.k8s.dev.pfr.co.nz
2140
- image: ghcr.io/plant-food-research-open/shiny-rpipelines:0.1.0
2241
imagePullPolicy: Always
2342
name: random-forest

deployments/k8s/pfr/r-models-ingress.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@ spec:
1717
app: r-pipelines
1818
ingressClassName: nginx
1919
rules:
20+
- host: shiny.r-pipelines.k8s.dev.pfr.co.nz
21+
http:
22+
paths:
23+
- backend:
24+
service:
25+
name: r-pipelines-models
26+
port:
27+
name: port-3838
28+
path: /
29+
pathType: Prefix
2030
- host: random-forest.model.r-pipelines.k8s.dev.pfr.co.nz
2131
http:
2232
paths:
@@ -39,6 +49,7 @@ spec:
3949
pathType: Prefix
4050
tls:
4151
- hosts:
52+
- "shiny.r-pipelines.k8s.dev.pfr.co.nz"
4253
- "random-forest.model.r-pipelines.k8s.dev.pfr.co.nz"
4354
- "gompertz.model.r-pipelines.k8s.dev.pfr.co.nz"
4455
secretName: r-pipelines-models-tls-secret

deployments/k8s/pfr/r-models-services.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ spec:
1414
app: r-pipelines
1515
type: LoadBalancer
1616
ports:
17+
- name: port-3838
18+
protocol: TCP
19+
port: 3838
20+
targetPort: 3838
1721
- name: port-8089
1822
protocol: TCP
1923
port: 8089

0 commit comments

Comments
 (0)