55 * Helper functions to work with ONNX Runtime library
66 */
77
#include <QString>
#include <QFile>
#include <QCryptographicHash>
#include <QByteArray>

#include <iostream>
#include <string>
#include <filesystem>
#include <fstream>
#include <system_error>
#include <utility>

#include <onnxruntime_cxx_api.h>

#include "3rdParty/ONNX/OnnxToolsPA.h"
#include "Common/Compiler.h"
#include "ML_ONNXRuntimeHelpers.h"
1221
22+ namespace fs = std::filesystem;
23+
1324namespace PokemonAutomation {
1425namespace ML {
1526
16- Ort::SessionOptions create_session_option (const std::string& cache_folder_name){
17- Ort::SessionOptions so;
// Root folder (relative to the working directory) that holds the per-model
// ONNX Runtime compiled-model caches. constexpr: the pointer must never be reseated.
constexpr const char* MODEL_CACHE_FOLDER = "./ModelCache/";
28+
29+ // Computes the cryptographic hash of a file.
30+ std::string create_file_hash (const std::string& filepath){
31+ QFile file (QString::fromStdString (filepath));
32+ if (!file.open (QIODevice::ReadOnly)) {
33+ return " " ;
34+ }
35+
36+ QCryptographicHash hash (QCryptographicHash::Sha256);
37+ if (hash.addData (&file)) {
38+ return hash.result ().toHex (0 ).toStdString ();
39+ } else {
40+ return " " ;
41+ }
42+ }
43+
1844
45+ Ort::SessionOptions create_session_options (const std::string& model_cache_path){
46+ Ort::SessionOptions so;
47+ std::cout << " Set potential model cache path in session options: " << model_cache_path << std::endl;
1948#if __APPLE__
2049 // create session using Apple ML acceleration library CoreML
2150 std::unordered_map<std::string, std::string> provider_options;
2251 // See for provider options: https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html
2352 // "NeuralNetwork" is a faster ModelFormat than "MLProgram".
2453 provider_options[" ModelFormat" ] = std::string (" NeuralNetwork" );
25- // TODO: need to make sure the cache works
26- provider_options[" ModelCacheDirectory" ] = " ./ModelCache/" + cache_folder_name;
54+ provider_options[" ModelCacheDirectory" ] = model_cache_path;
2755 // provider_options["MLComputeUnits"] = "ALL";
2856 // provider_options["RequireStaticInputShapes"] = "0";
2957 // provider_options["EnableOnSubgraphs"] = "0";
@@ -34,6 +62,73 @@ Ort::SessionOptions create_session_option(const std::string& cache_folder_name){
3462 return so;
3563}
3664
65+ // Check the model file cache integrity by checking the existence of a flag file and the model hash stored
66+ // in the flag file. If the flag does not exist, we assume the file cache does not exist or is broken.
67+ // If the hash stored in the flag file does not match the model file, the model file is a new model, delete
68+ // the old cache.
69+ // Return
70+ // - bool: whether to write flag file after cache is created
71+ // - string: the model file hash to write into the flag file after Ort::Session is built and the cache is created.
72+ //
73+ // model_cache_path: Folder path to store model cache. This name is better to be unique for each model for
74+ // easier file management.
75+ // model_path: the model path to load the ML model. This is needed to ensure we delete the old model cache
76+ // when a new model
77+ std::pair<bool , std::string> clean_up_old_model_cache (const std::string& model_cache_path, const std::string& model_path){
78+ std::string file_hash = create_file_hash (model_path);
79+ if (file_hash.size () == 0 ){
80+ // the model file cannot be loaded
81+ return {true , " " };
82+ }
83+
84+ if (!fs::exists (fs::path (model_cache_path))){
85+ return {true , file_hash};
86+ }
87+
88+ const std::string flag_file_path = model_cache_path + " /HASH.txt" ;
89+ if (fs::exists (fs::path (flag_file_path))){
90+ std::ifstream fin (flag_file_path);
91+ if (fin){
92+ std::string line;
93+ fin >> line;
94+ if (line == file_hash){
95+ // hash match!
96+ return {false , file_hash};
97+ }
98+ }
99+ }
100+ // remove everything from model_cache_path
101+ fs::remove_all (fs::path (model_cache_path));
102+ return {true , file_hash};
103+ }
104+
105+
// Writes the model file hash into the cache folder's flag file (HASH.txt) so a later
// run can verify the cache matches the model (see clean_up_old_model_cache()).
// Does nothing if the cache folder does not exist (i.e. no cache was created).
void write_cache_flag_file(const std::string& model_cache_path, const std::string& hash){
    // error_code overload: a filesystem failure in this best-effort write must not throw.
    std::error_code ec;
    if (!std::filesystem::exists(std::filesystem::path(model_cache_path), ec)){
        return;
    }
    const std::string flag_file_path = model_cache_path + "/HASH.txt";
    std::ofstream fout(flag_file_path);
    if (!fout){
        // Could not create the flag file; the next run will simply treat the cache as broken.
        return;
    }
    fout << hash;
}
114+
115+
116+ Ort::Session create_session (const std::string& model_path, const std::string& cache_folder_name){
117+ const std::string model_cache_path = MODEL_CACHE_FOLDER + cache_folder_name;
118+ Ort::SessionOptions so = create_session_options (model_cache_path);
119+ bool write_flag_file = true ;
120+ std::string file_hash;
121+ std::tie (write_flag_file, file_hash) = clean_up_old_model_cache (model_cache_path, model_path);
122+
123+ Ort::Env env;
124+ Ort::Session session{env, str_to_onnx_str (model_path).c_str (), so};
125+ // when Ort::Ssssion is created, if possible, it will create a model cache
126+ if (write_flag_file){
127+ write_cache_flag_file (model_cache_path, file_hash);
128+ }
129+ return session;
130+ }
131+
37132
38133void print_model_input_output_info (const Ort::Session& session){
39134 const auto input_names = session.GetInputNames ();
0 commit comments