Skip to content

Commit 4f569f1

Browse files
author
Gin
committed
add handling for missing ML model files
1 parent bb46c90 commit 4f569f1

File tree

5 files changed

+36
-6
lines changed

5 files changed

+36
-6
lines changed

SerialPrograms/Source/ML/DataLabeling/ML_SegmentAnythingModel.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,23 @@ void compute_embeddings_for_folder(const std::string& embedding_model_path, cons
199199
return;
200200
}
201201

202+
if (!std::filesystem::exists(embedding_model_path)){
203+
std::cerr << "Error: no such embedding model path " << embedding_model_path << "." << std::endl;
204+
QMessageBox box;
205+
box.critical(nullptr, "Embedding Model Does Not Exist",
206+
QString::fromStdString("Embedding model path" + embedding_model_path + " does not exist."));
207+
return;
208+
}
209+
// since the embedding model has too many weights, onnx created a .data file to contain weights.
210+
auto embedding_model_data_path = embedding_model_path + ".data";
211+
if (!std::filesystem::exists(embedding_model_data_path)){
212+
std::cerr << "Error: no such embedding model data path " << embedding_model_data_path << "." << std::endl;
213+
QMessageBox box;
214+
box.critical(nullptr, "Embedding Model Data File Does Not Exist",
215+
QString::fromStdString("Embedding model data file path" + embedding_model_data_path + " does not exist."));
216+
return;
217+
}
218+
202219
SAMEmbedderSession embedding_session(embedding_model_path);
203220
std::vector<float> output_image_embedding;
204221
for (size_t i = 0; i < all_image_paths.size(); i++){
@@ -238,7 +255,7 @@ void compute_embeddings_for_folder(const std::string& embedding_model_path, cons
238255
embedding_session.run(resized_mat, output_image_embedding);
239256
save_image_embedding_to_disk(image_path, output_image_embedding);
240257
}
241-
258+
std::cout << "Done computing embeddings for images in folder " << image_folder_path << "." << std::endl;
242259

243260
}
244261

SerialPrograms/Source/ML/Programs/ML_LabelImages.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,25 @@ LabelImages::LabelImages(const LabelImages_Descriptor& descriptor)
205205
, WIDTH("<b>Width:</b>", LockMode::UNLOCK_WHILE_RUNNING, 0.4, 0.0, 1.0)
206206
, HEIGHT("<b>Height:</b>", LockMode::UNLOCK_WHILE_RUNNING, 0.4, 0.0, 1.0)
207207
, FORM_LABEL("bulbasaur")
208-
, m_sam_session{RESOURCE_PATH() + "ML/sam_cpu.onnx"}
209208
{
210209
ADD_OPTION(X);
211210
ADD_OPTION(Y);
212211
ADD_OPTION(WIDTH);
213212
ADD_OPTION(HEIGHT);
214213
ADD_OPTION(FORM_LABEL);
214+
215+
// , m_sam_session{RESOURCE_PATH() + "ML/sam_cpu.onnx"}
216+
const std::string sam_model_path = RESOURCE_PATH() + "ML/sam_cpu.onnx";
217+
if (std::filesystem::exists(sam_model_path)){
218+
m_sam_session = std::make_unique<SAMSession>(sam_model_path);
219+
} else{
220+
std::cerr << "Error: no such SAM model path " << sam_model_path << "." << std::endl;
221+
QMessageBox box;
222+
box.critical(nullptr, "SAM Model Does Not Exist",
223+
QString::fromStdString("SAM model path" + sam_model_path + " does not exist."));
224+
}
215225
}
226+
216227
void LabelImages::from_json(const JsonValue& json){
217228
const JsonObject* obj = json.to_object();
218229
if (obj == nullptr){
@@ -370,11 +381,11 @@ void LabelImages::compute_mask(VideoOverlaySet& overlay_set){
370381
return;
371382
}
372383

373-
if (m_image_embedding.size() == 0){
384+
if (!m_sam_session || m_image_embedding.size() == 0){
374385
// no embedding file loaded
375386
return;
376387
}
377-
m_sam_session.run(
388+
m_sam_session->run(
378389
m_image_embedding,
379390
(int)source_height, (int)source_width, {}, {},
380391
{box_x, box_y, box_x + box_width, box_y + box_height},

SerialPrograms/Source/ML/Programs/ML_LabelImages.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define PokemonAutomation_ML_LabelImages_H
99

1010
#include <QGraphicsScene>
11+
#include <memory>
1112
#include "Common/Cpp/Options/BatchOption.h"
1213
#include "Common/Cpp/Options/FloatingPointOption.h"
1314
#include "CommonFramework/Panels/PanelInstance.h"
@@ -120,7 +121,7 @@ class LabelImages : public PanelInstance{
120121
// buffer to compute SAM mask on
121122
ImageRGB32 m_mask_image;
122123

123-
SAMSession m_sam_session;
124+
std::unique_ptr<SAMSession> m_sam_session;
124125
std::vector<ObjectAnnotation> m_annotations;
125126
size_t m_last_object_idx = 0;
126127
std::string m_annotation_file_path;

SerialPrograms/Source/ML/UI/ML_ImageAnnotationDisplayWidget.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ ImageAnnotationDisplayWidget::ImageAnnotationDisplayWidget(
4545
: QWidget(&parent)
4646
, m_session(session)
4747
{
48+
4849
QVBoxLayout* layout = new QVBoxLayout(this);
4950
layout->setContentsMargins(0, 0, 0, 0);
5051
layout->setAlignment(Qt::AlignTop);

SerialPrograms/Source/ML/UI/ML_ImageAnnotationSourceSelectorWidget.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ ImageAnnotationSourceSelectorWidget::ImageAnnotationSourceSelectorWidget(ImageAn
7373
folder_info_row->addSpacing(2);
7474
folder_info_row->addWidget(next_image_button, 2);
7575
folder_info_row->addSpacing(10);
76-
folder_info_row->addWidget(new QLabel(" ", this), 10);
76+
folder_info_row->addWidget(new QLabel(" ", this), 10); // empty label to push the buttons above to the left
7777

7878

7979
// Set the action for the video reset button

0 commit comments

Comments
 (0)