Skip to content

Commit 0527744

Browse files
author
Gin
committed
add model cache update mechanism
1 parent 3a65dee commit 0527744

File tree

6 files changed

+107
-21
lines changed

6 files changed

+107
-21
lines changed

SerialPrograms/Source/ML/DataLabeling/ML_SegmentAnythingModel.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ namespace ML{
2424

2525

2626
SAMEmbedderSession::SAMEmbedderSession(const std::string& model_path)
27-
: session_options(create_session_option("SAMEmbedder"))
28-
, session{env, str_to_onnx_str(model_path).c_str(), session_options}
27+
: session{create_session(model_path, "SAMEmbedder")}
2928
, memory_info{Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)}
3029
, input_names{session.GetInputNames()}
3130
, output_names{session.GetOutputNames()}
@@ -64,8 +63,7 @@ void SAMEmbedderSession::run(cv::Mat& input_image, std::vector<float>& model_out
6463

6564

6665
SAMSession::SAMSession(const std::string& model_path)
67-
: session_options(create_session_option("SAM"))
68-
, session{env, str_to_onnx_str(model_path).c_str(), session_options}
66+
: session{create_session(model_path, "SAM")}
6967
, memory_info{Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)}
7068
, input_names{session.GetInputNames()}
7169
, output_names{session.GetOutputNames()}

SerialPrograms/Source/ML/DataLabeling/ML_SegmentAnythingModel.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ class SAMEmbedderSession{
3636
void run(cv::Mat& input_image, std::vector<float>& output_image_embedding);
3737

3838
private:
39-
Ort::Env env;
40-
Ort::SessionOptions session_options;
4139
Ort::Session session;
4240
Ort::MemoryInfo memory_info;
4341
Ort::RunOptions run_options;
@@ -71,8 +69,6 @@ class SAMSession{
7169
const std::vector<int>& input_box,
7270
std::vector<bool>& output_boolean_mask);
7371
private:
74-
Ort::Env env;
75-
Ort::SessionOptions session_options;
7672
Ort::Session session;
7773
Ort::MemoryInfo memory_info;
7874
Ort::RunOptions run_options;

SerialPrograms/Source/ML/Models/ML_ONNXRuntimeHelpers.cpp

Lines changed: 99 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,53 @@
55
* Helper functions to work with ONNX Runtime library
66
*/
77

8+
#include <QString>
9+
#include <QFile>
10+
#include <QCryptographicHash>
11+
#include <QByteArray>
12+
813
#include <iostream>
14+
#include <string>
15+
#include <filesystem>
16+
#include <fstream>
917
#include <onnxruntime_cxx_api.h>
18+
#include "3rdParty/ONNX/OnnxToolsPA.h"
1019
#include "Common/Compiler.h"
1120
#include "ML_ONNXRuntimeHelpers.h"
1221

22+
namespace fs = std::filesystem;
23+
1324
namespace PokemonAutomation{
1425
namespace ML{
1526

16-
Ort::SessionOptions create_session_option(const std::string& cache_folder_name){
17-
Ort::SessionOptions so;
27+
// Root folder for model caches. create_session() appends the per-model cache folder name to this prefix.
const char* MODEL_CACHE_FOLDER = "./ModelCache/";
28+
29+
// Computes the cryptographic hash of a file.
30+
std::string create_file_hash(const std::string& filepath){
31+
QFile file(QString::fromStdString(filepath));
32+
if (!file.open(QIODevice::ReadOnly)) {
33+
return "";
34+
}
35+
36+
QCryptographicHash hash(QCryptographicHash::Sha256);
37+
if (hash.addData(&file)) {
38+
return hash.result().toHex(0).toStdString();
39+
} else {
40+
return "";
41+
}
42+
}
43+
1844

45+
Ort::SessionOptions create_session_options(const std::string& model_cache_path){
46+
Ort::SessionOptions so;
47+
std::cout << "Set potential model cache path in session options: " << model_cache_path << std::endl;
1948
#if __APPLE__
2049
// create session using Apple ML acceleration library CoreML
2150
std::unordered_map<std::string, std::string> provider_options;
2251
// See for provider options: https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html
2352
// "NeuralNetwork" is a faster ModelFormat than "MLProgram".
2453
provider_options["ModelFormat"] = std::string("NeuralNetwork");
25-
// TODO: need to make sure the cache works
26-
provider_options["ModelCacheDirectory"] = "./ModelCache/" + cache_folder_name;
54+
provider_options["ModelCacheDirectory"] = model_cache_path;
2755
// provider_options["MLComputeUnits"] = "ALL";
2856
// provider_options["RequireStaticInputShapes"] = "0";
2957
// provider_options["EnableOnSubgraphs"] = "0";
@@ -34,6 +62,73 @@ Ort::SessionOptions create_session_option(const std::string& cache_folder_name){
3462
return so;
3563
}
3664

65+
// Check the model file cache integrity by checking the existence of a flag file and the model hash stored
66+
// in the flag file. If the flag does not exist, we assume the file cache does not exist or is broken.
67+
// If the hash stored in the flag file does not match the model file, the model file is a new model, delete
68+
// the old cache.
69+
// Return
70+
// - bool: whether to write flag file after cache is created
71+
// - string: the model file hash to write into the flag file after Ort::Session is built and the cache is created.
72+
//
73+
// model_cache_path: Folder path to store model cache. This name is better to be unique for each model for
74+
// easier file management.
75+
// model_path: the model path to load the ML model. This is needed to ensure we delete the old model cache
76+
// when a new model is used.
77+
std::pair<bool, std::string> clean_up_old_model_cache(const std::string& model_cache_path, const std::string& model_path){
78+
std::string file_hash = create_file_hash(model_path);
79+
if (file_hash.size() == 0){
80+
// the model file cannot be loaded
81+
return {true, ""};
82+
}
83+
84+
if (!fs::exists(fs::path(model_cache_path))){
85+
return {true, file_hash};
86+
}
87+
88+
const std::string flag_file_path = model_cache_path + "/HASH.txt";
89+
if (fs::exists(fs::path(flag_file_path))){
90+
std::ifstream fin(flag_file_path);
91+
if (fin){
92+
std::string line;
93+
fin >> line;
94+
if (line == file_hash){
95+
// hash match!
96+
return {false, file_hash};
97+
}
98+
}
99+
}
100+
// remove everything from model_cache_path
101+
fs::remove_all(fs::path(model_cache_path));
102+
return {true, file_hash};
103+
}
104+
105+
106+
void write_cache_flag_file(const std::string& model_cache_path, const std::string& hash){
107+
if (!fs::exists(fs::path(model_cache_path))){
108+
return;
109+
}
110+
const std::string flag_file_path = model_cache_path + "/HASH.txt";
111+
std::ofstream fout(flag_file_path);
112+
fout << hash;
113+
}
114+
115+
116+
// Creates an ONNX Runtime session for the model at `model_path`.
//
// model_path: path of the model file to load.
// cache_folder_name: folder name under MODEL_CACHE_FOLDER used as the model cache location.
//
// Before the session is built, a cache whose stored hash does not match the model file is removed.
// After the session is built, the model hash is recorded in the cache folder's flag file when needed.
Ort::Session create_session(const std::string& model_path, const std::string& cache_folder_name){
    const std::string model_cache_path = MODEL_CACHE_FOLDER + cache_folder_name;
    Ort::SessionOptions so = create_session_options(model_cache_path);

    const auto [write_flag_file, file_hash] = clean_up_old_model_cache(model_cache_path, model_path);

    // The environment has to stay alive for as long as the sessions created from it are in use.
    // A function-local Ort::Env would be destroyed on return while the returned session lives on,
    // so a single process-wide environment is shared by all sessions instead.
    static Ort::Env env;
    Ort::Session session{env, str_to_onnx_str(model_path).c_str(), so};

    // When Ort::Session is created, if possible, it will create a model cache.
    // Only record a flag file when there is a real hash to store: an empty hash means the model file
    // could not be hashed and must not mark the cache as validated.
    if (write_flag_file && !file_hash.empty()){
        write_cache_flag_file(model_cache_path, file_hash);
    }
    return session;
}
131+
37132

38133
void print_model_input_output_info(const Ort::Session& session){
39134
const auto input_names = session.GetInputNames();

SerialPrograms/Source/ML/Models/ML_ONNXRuntimeHelpers.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
namespace PokemonAutomation{
1717
namespace ML{
1818

19-
20-
// Create an ONNX Runtime session options object.
19+
// Create an ONNX Session.
20+
// cache_folder_name: the folder name under ./ModelCache/ to store model caches. This name is better
21+
// to be unique for each model for easier file management.
22+
//
2123
// If on macOS, will use CoreML as the backend.
2224
// Otherwise, use CPU to run the model.
2325
// TODO: add Cuda backend for Windows machine.
24-
// cache_folder_name: the folder name in under ./ModelCache/ to store model caches. This name is better
25-
// to be unique for each model for easier file management.
26-
Ort::SessionOptions create_session_option(const std::string& cache_folder_name);
26+
Ort::Session create_session(const std::string& model_path, const std::string& cache_folder_name);
2727

2828
// Handy function to create an ONNX Runtime tensor view class from a vector-like `buffer` object holding
2929
// the tensor data and an array-like `shape` object that represents the dimension of the tensor.

SerialPrograms/Source/ML/Models/ML_YOLOv5Model.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ std::tuple<int, int, double, double> resize_image_with_border(
5858

5959
YOLOv5Session::YOLOv5Session(const std::string& model_path, std::vector<std::string> label_names)
6060
: m_label_names(std::move(label_names))
61-
, m_session_options(create_session_option("YOLOv5"))
62-
, m_session{m_env, str_to_onnx_str(model_path).c_str(), m_session_options}
61+
, m_session{create_session(model_path, "YOLOv5")}
6362
, m_memory_info{Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)}
6463
, m_input_names{m_session.GetInputNames()}
6564
, m_output_names{m_session.GetOutputNames()}

SerialPrograms/Source/ML/Models/ML_YOLOv5Model.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ class YOLOv5Session{
3636

3737
std::vector<std::string> m_label_names;
3838

39-
Ort::Env m_env;
40-
Ort::SessionOptions m_session_options;
4139
Ort::Session m_session;
4240
Ort::MemoryInfo m_memory_info;
4341
Ort::RunOptions m_run_options;

0 commit comments

Comments
 (0)