@@ -21,21 +21,22 @@ predict.int_conformal_split <- function(object, new_data, level = 0.95, ...) {
2121}
2222
2323handler_predict.int_conformal_split <- function (vetiver_model , ... ) {
24- ptype <- vetiver_model $ prototype
25-
26- function (req ) {
27- newdata <- req $ body
28- newdata <- vetiver :: vetiver_type_convert(newdata , ptype )
29- newdata <- hardhat :: scream(newdata , ptype )
30- ret <- predict(vetiver_model $ model , new_data = newdata , ... )
31- list (.pred = ret )
32- }
33-
24+ ptype <- vetiver_model $ prototype
25+
26+ function (req ) {
27+ newdata <- req $ body
28+ newdata <- vetiver :: vetiver_type_convert(newdata , ptype )
29+ newdata <- hardhat :: scream(newdata , ptype )
30+ ret <- predict(vetiver_model $ model , new_data = newdata , ... )
31+ list (.pred = ret )
32+ }
33+
3434}
3535
3636vetiver_create_meta.int_conformal_split <-
3737 function (model , metadata ) {
38- vetiver :: vetiver_meta(metadata , required_pkgs = c(" probably" , " workflows" , " stacks" , " workflowsets" ))
38+ vetiver :: vetiver_meta(metadata ,
39+ required_pkgs = c(" probably" , " workflows" , " stacks" , " workflowsets" ))
3940 }
4041
4142vetiver_create_description.brmsfit <- function (model ) {
@@ -57,36 +58,37 @@ handler_predict.brmsfit <- function(vetiver_model, ...) {
5758 newdata <- req $ body
5859 newdata <- vetiver :: vetiver_type_convert(newdata , ptype )
5960 newdata <- hardhat :: scream(newdata , ptype )
60- ret <- predict(vetiver_model $ model , new_data = newdata , ... )
61- list (
62- .pred = ret [," Estimate" ],
63- " Q2.5" = ret [, " Q2.5" ],
64- " Q97.5" = ret [, " Q97.5" ]
65- )
61+ ret <- predict(vetiver_model $ model , new_data = newdata , ndraws = 50 , ... )
62+
63+ list (.pred = list (
64+ list (
65+ .pred = ret [, " Estimate" ] | > mean(),
66+ .pred_lower = ret [, " Q2.5" ] | > mean(),
67+ .pred_upper = ret [, " Q97.5" ] | > mean()
68+ )
69+ ))
6670 }
6771
6872}
6973
7074predict.cmdstanr_container <- function (model , new_data ) {
7175 data <- list ()
72- variable_names <- model $ model $ variables()$ data | >
76+ variable_names <- model $ model $ variables()$ data | >
7377 names()
7478
75- for (variable_name in variable_names ) {
76- if (variable_name == " N" ) {
79+ for (variable_name in variable_names ) {
80+ if (variable_name == " N" ) {
7781 data [" N" ] <- nrow(new_data )
7882 } else {
7983 data [[variable_name ]] <- new_data [[variable_name ]]
8084 indx_var <- paste0(variable_name , " _J" )
81- if (indx_var %in% variable_names ) {
85+ if (indx_var %in% variable_names ) {
8286 data [[indx_var ]] <- model $ data [[indx_var ]]
8387 }
8488 }
8589 }
8690
87- model $ model $ generate_quantities(
88- model $ fit , data = data
89- )$ summary() | > dplyr :: select(- variable )
91+ model $ model $ generate_quantities(model $ fit , data = data )$ summary() | > dplyr :: select(- variable )
9092}
9193
9294vetiver_create_description.cmdstanr_container <- function (model ) {
@@ -110,7 +112,7 @@ handler_predict.cmdstanr_container <- function(vetiver_model, ...) {
110112 newdata <- hardhat :: scream(newdata , ptype )
111113 ret <- predict(vetiver_model $ model , new_data = newdata , ... )
112114 list (
113- .pred = ret [," mean" ],
115+ .pred = ret [, " mean" ],
114116 " sd" = ret [, " sd" ],
115117 " q5" = ret [, " q5" ],
116118 " q95" = ret [, " q95" ]
0 commit comments