Skip to content

Commit 5bcb651

Browse files
committed
2 parents 9c44282 + f441a8f commit 5bcb651

13 files changed

+175
-157
lines changed

SerialPrograms/Source/CommonFramework/VideoPipeline/VideoSession.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
#include "VideoSources/VideoSource_Null.h"
1313
#include "VideoSession.h"
1414

15-
//#include <iostream>
16-
//using std::cout;
17-
//using std::endl;
15+
// #include <iostream>
16+
// using std::cout;
17+
// using std::endl;
1818

1919
namespace PokemonAutomation{
2020

SerialPrograms/Source/ML/DataLabeling/SegmentAnythingModel.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
#include <QDirIterator>
1010
#include <fstream>
1111
#include <iostream>
12+
#include <QMessageBox>
1213
#include <onnxruntime_cxx_api.h>
1314
#include <opencv2/imgcodecs.hpp>
15+
#include <opencv2/imgproc.hpp>
1416
#include "3rdParty/ONNX/OnnxToolsPA.h"
1517
#include "SegmentAnythingModel.h"
1618

@@ -207,7 +209,7 @@ void SAMSession::run(
207209

208210
// save the image embedding as a file with path <image_filepath>.embedding
209211
void save_image_embedding_to_disk(const std::string& image_filepath, const std::vector<float>& embedding){
210-
std::string embedding_path = image_filepath + ".embedding";
212+
const std::string embedding_path = image_filepath + ".embedding";
211213
std::ofstream fout(embedding_path, std::ios::binary);
212214
// write embedding shape
213215
fout.write(reinterpret_cast<const char*>(&SAM_EMBEDDER_OUTPUT_N_CHANNELS), sizeof(SAM_EMBEDDER_OUTPUT_N_CHANNELS));
@@ -248,7 +250,7 @@ bool load_image_embedding(const std::string& image_filepath, std::vector<float>&
248250
}
249251

250252

251-
void compute_embeddings_for_folder(const std::string& image_folder_path){
253+
void compute_embeddings_for_folder(const std::string& embedding_model_path, const std::string& image_folder_path){
252254
QDir image_dir(image_folder_path.c_str());
253255
if (!image_dir.exists()){
254256
std::cerr << "Error: input image folder path " << image_folder_path << " does not exist." << std::endl;
@@ -261,6 +263,48 @@ void compute_embeddings_for_folder(const std::string& image_folder_path){
261263
all_image_paths.emplace_back(image_file_iter.next().toStdString());
262264
}
263265
std::cout << "Found " << all_image_paths.size() << " images recursively in folder " << image_folder_path << std::endl;
266+
267+
SAMEmbedderSession embedding_session(embedding_model_path);
268+
std::vector<float> output_image_embedding;
269+
for (size_t i = 0; i < all_image_paths.size(); i++){
270+
const auto& image_path = all_image_paths[i];
271+
std::cout << (i+1) << "/" << all_image_paths.size() << ": ";
272+
const std::string embedding_path = image_path + ".embedding";
273+
if (std::filesystem::exists(embedding_path)){
274+
std::cout << "skip already computed embedding " << embedding_path << "." << std::endl;
275+
continue;
276+
}
277+
std::cout << "computing embedding for " << image_path << "..." << std::endl;
278+
cv::Mat image_bgr = cv::imread(image_path);
279+
if (image_bgr.empty()){
280+
std::cerr << "Error: image empty. Probably the file is not an image?" << std::endl;
281+
QMessageBox box;
282+
box.warning(nullptr, "Unable To Open Image",
283+
QString::fromStdString("Cannot open image file " + image_path + ". Probably not an actual image?"));
284+
return;
285+
}
286+
cv::Mat image;
287+
if (image_bgr.channels() == 4){
288+
cv::cvtColor(image_bgr, image, cv::COLOR_BGRA2RGB);
289+
} else if (image_bgr.channels() == 3){
290+
cv::cvtColor(image_bgr, image, cv::COLOR_BGR2RGB);
291+
} else{
292+
std::cerr << "Error: wrong image channels. Only work with RGB or RGBA images." << std::endl;
293+
QMessageBox box;
294+
box.warning(nullptr, "Wrong Image Channels",
295+
QString::fromStdString("Image has " + std::to_string(image_bgr.channels()) + " channels. Only support 3 or 4 channels."));
296+
return;
297+
}
298+
299+
cv::Mat resized_mat; // resize to the shape for the ML model input
300+
cv::resize(image, resized_mat, cv::Size(SAM_EMBEDDER_INPUT_IMAGE_WIDTH, SAM_EMBEDDER_INPUT_IMAGE_HEIGHT));
301+
302+
output_image_embedding.clear();
303+
embedding_session.run(resized_mat, output_image_embedding);
304+
save_image_embedding_to_disk(image_path, output_image_embedding);
305+
}
306+
307+
264308
}
265309

266310
}

SerialPrograms/Source/ML/DataLabeling/SegmentAnythingModel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ bool load_image_embedding(const std::string& image_filepath, std::vector<float>&
2929
// Save the image embedding as a file with path <image_filepath>.embedding.
3030
void save_image_embedding_to_disk(const std::string& image_filepath, const std::vector<float>& embedding);
3131

32-
// Compute embeddings for all images in a folder.
32+
// Compute embeddings for all images in a folder. Only supports the .png, .jpg and .jpeg filename extensions so far.
3333
// This can be very slow!
34-
void compute_embeddings_for_folder(const std::string& image_folder_path);
34+
void compute_embeddings_for_folder(const std::string& embedding_model_path, const std::string& image_folder_path);
3535

3636

3737
class SAMEmbedderSession{

SerialPrograms/Source/ML/Programs/ML_LabelImages.cpp

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ JsonValue LabelImages::to_json() const{
227227
JsonObject obj = std::move(*m_options.to_json().to_object());
228228
obj["ImageSetup"] = m_display_option.to_json();
229229

230+
save_annotation_to_file();
231+
return obj;
232+
}
233+
234+
void LabelImages::save_annotation_to_file() const{
230235
// m_annotation_file_path
231236
if (m_annotation_file_path.size() > 0 && !m_fail_to_load_annotation_file){
232237
JsonArray anno_json_arr;
@@ -236,8 +241,19 @@ JsonValue LabelImages::to_json() const{
236241
cout << "Saving annotation to " << m_annotation_file_path << endl;
237242
anno_json_arr.dump(m_annotation_file_path);
238243
}
239-
return obj;
240244
}
245+
246+
void LabelImages::clear_for_new_image(){
247+
source_image_height = source_image_height = 0;
248+
m_image_embedding.clear();
249+
m_output_boolean_mask.clear();
250+
m_mask_image = ImageRGB32();
251+
m_annotations.clear();
252+
m_last_object_idx = 0;
253+
m_annotation_file_path = "";
254+
m_fail_to_load_annotation_file = false;
255+
}
256+
241257
// Create the UI widget for this panel instance.
// The widget is heap-allocated with `parent` passed to its constructor and is
// bound to this LabelImages object and the given panel holder.
QWidget* LabelImages::make_widget(QWidget& parent, PanelHolder& holder){
    LabelImages_Widget* const widget = new LabelImages_Widget(parent, *this, holder);
    return widget;
}
@@ -253,6 +269,7 @@ void LabelImages::load_image_related_data(const std::string& image_path, size_t
253269
if (!embedding_loaded){
254270
return; // no embedding, then no way for us to annotate
255271
}
272+
256273
// see if we can load the previously created labels
257274
const std::string anno_filename = std::filesystem::path(image_path).filename().replace_extension(".json").string();
258275

@@ -410,14 +427,16 @@ void LabelImages::compute_mask(VideoOverlaySet& overlay_set){
410427
}
411428

412429
void LabelImages::compute_embeddings_for_folder(const std::string& image_folder_path){
413-
ML::compute_embeddings_for_folder(image_folder_path);
430+
std::string embedding_model_path = RESOURCE_PATH() + "ML/sam_embedder_cpu.onnx";
431+
std::cout << "Use SAM Embedding model " << embedding_model_path << std::endl;
432+
ML::compute_embeddings_for_folder(embedding_model_path, image_folder_path);
414433
}
415434

416435

417436

418437
LabelImages_Widget::~LabelImages_Widget(){
419438
m_program.FORM_LABEL.remove_listener(*this);
420-
delete m_switch_widget;
439+
delete m_image_display_widget;
421440
}
422441
LabelImages_Widget::LabelImages_Widget(
423442
QWidget& parent,
@@ -431,6 +450,9 @@ LabelImages_Widget::LabelImages_Widget(
431450
, m_drawn_box(*this, m_display_session.overlay())
432451
{
433452
m_program.FORM_LABEL.add_listener(*this);
453+
m_display_session.video_session().add_state_listener(*this);
454+
455+
m_embedding_info_label = new QLabel(this);
434456

435457
QVBoxLayout* layout = new QVBoxLayout(this);
436458
layout->setContentsMargins(0, 0, 0, 0);
@@ -445,8 +467,13 @@ LabelImages_Widget::LabelImages_Widget(
445467
QVBoxLayout* scroll_layout = new QVBoxLayout(scroll_inner);
446468
scroll_layout->setAlignment(Qt::AlignTop);
447469

448-
m_switch_widget = new ImageAnnotationDisplayWidget(*this, m_display_session, 0);
449-
scroll_layout->addWidget(m_switch_widget);
470+
m_image_display_widget = new ImageAnnotationDisplayWidget(*this, m_display_session, 0);
471+
scroll_layout->addWidget(m_image_display_widget);
472+
473+
QHBoxLayout* embedding_info_row = new QHBoxLayout();
474+
scroll_layout->addLayout(embedding_info_row);
475+
embedding_info_row->addWidget(new QLabel("<b>Image Embedding File:</b> ", this));
476+
embedding_info_row->addWidget(m_embedding_info_label);
450477

451478
QPushButton* button = new QPushButton("Delete Last Mask", scroll_inner);
452479
scroll_layout->addWidget(button);
@@ -476,17 +503,14 @@ LabelImages_Widget::LabelImages_Widget(
476503
}
477504
});
478505

479-
const auto cur_res = m_display_session.video_session().current_resolution();
480-
if (cur_res.width > 0 && cur_res.height > 0){
481-
const std::string& image_path = m_display_session.option().m_image_path;
482-
m_program.load_image_related_data(image_path, cur_res.width, cur_res.height);
483-
m_program.update_rendered_objects(m_overlay_set);
484-
}
485-
486-
487506
cout << "LabelImages_Widget built" << endl;
488507
}
489508

509+
void LabelImages_Widget::clear_for_new_image(){
510+
m_overlay_set.clear();
511+
m_program.clear_for_new_image();
512+
}
513+
490514
void LabelImages_Widget::on_config_value_changed(void* object){
491515
if (m_program.m_annotations.size() > 0 && m_program.m_last_object_idx < m_program.m_annotations.size()){
492516
std::string& cur_label = m_program.m_annotations[m_program.m_last_object_idx].label;
@@ -495,6 +519,40 @@ void LabelImages_Widget::on_config_value_changed(void* object){
495519
}
496520
}
497521

522+
// This callback function will be called whenever the display source (the image source) is loaded or reloaded:
523+
void LabelImages_Widget::post_startup(VideoSource* source){
524+
const std::string& image_path = m_display_session.option().m_image_path;
525+
526+
m_program.save_annotation_to_file(); // save the current annotation file
527+
clear_for_new_image();
528+
if (image_path.size() == 0){
529+
m_embedding_info_label->setText("");
530+
return;
531+
}
532+
533+
const std::string embedding_path = image_path + ".embedding";
534+
const std::string embedding_path_display = "<IMAGE_FOLDER>/" + std::filesystem::path(embedding_path).filename().string();
535+
if (!std::filesystem::exists(embedding_path)){
536+
m_embedding_info_label->setText(QString::fromStdString(embedding_path_display + " Dose Not Exist. Cannot Annotate The Image!"));
537+
m_embedding_info_label->setStyleSheet("color: red");
538+
return;
539+
}
540+
541+
m_embedding_info_label->setText(QString::fromStdString(embedding_path_display));
542+
m_embedding_info_label->setStyleSheet("color: blue");
543+
544+
const auto cur_res = m_display_session.video_session().current_resolution();
545+
if (cur_res.width == 0 || cur_res.height == 0){
546+
QMessageBox box;
547+
box.warning(nullptr, "Invalid Image Dimension",
548+
QString::fromStdString("Loaded image " + image_path + " has invalid dimension: " + cur_res.to_string()));
549+
return;
550+
}
551+
552+
m_program.load_image_related_data(image_path, cur_res.width, cur_res.height);
553+
m_program.update_rendered_objects(m_overlay_set);
554+
}
555+
498556

499557

500558
}

SerialPrograms/Source/ML/Programs/ML_LabelImages.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
class QGraphicsView;
2727
class QGraphicsPixmapItem;
28+
class QLabel;
2829

2930
namespace PokemonAutomation{
3031

@@ -72,6 +73,11 @@ class LabelImages : public PanelInstance{
7273
virtual void from_json(const JsonValue& json) override;
7374
virtual JsonValue to_json() const override;
7475

76+
void save_annotation_to_file() const;
77+
78+
// called after loading a new image, clean up all internal data
79+
void clear_for_new_image();
80+
7581
// Load image related data:
7682
// - Image SAM embedding data file, which has the same file path but with a name suffix ".embedding"
7783
// - Existing annotation file, which is stored in a pre-defined ML_ANNOTATION_PATH() and with the same filename as
@@ -145,7 +151,7 @@ class DrawnBoundingBox : public ConfigOption::Listener, public VideoOverlay::Mou
145151
};
146152

147153

148-
class LabelImages_Widget : public PanelWidget, public ConfigOption::Listener{
154+
class LabelImages_Widget : public PanelWidget, public ConfigOption::Listener, public VideoSession::StateListener{
149155
public:
150156
~LabelImages_Widget();
151157
LabelImages_Widget(
@@ -154,14 +160,24 @@ class LabelImages_Widget : public PanelWidget, public ConfigOption::Listener{
154160
PanelHolder& holder
155161
);
156162

163+
// called after loading a new image, clean up all internal data
164+
void clear_for_new_image();
165+
157166
virtual void on_config_value_changed(void* object) override;
158167

168+
// Overrides VideoSession::StateListener::post_startup().
169+
virtual void post_startup(VideoSource* source) override;
170+
159171
private:
160172
LabelImages& m_program;
161173
ImageAnnotationDisplaySession& m_display_session;
162-
ImageAnnotationDisplayWidget* m_switch_widget;
174+
175+
ImageAnnotationDisplayWidget* m_image_display_widget;
176+
163177
VideoOverlaySet m_overlay_set;
164178
DrawnBoundingBox m_drawn_box;
179+
180+
QLabel* m_embedding_info_label = nullptr;
165181
ConfigWidget* m_option_widget;
166182

167183
friend class DrawnBoundingBox;

SerialPrograms/Source/ML/UI/ML_ImageAnnotationCommandRow.cpp

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ ImageAnnotationCommandRow::ImageAnnotationCommandRow(
3333
{
3434
QHBoxLayout* command_row = new QHBoxLayout(this);
3535
command_row->setContentsMargins(0, 0, 0, 0);
36-
37-
command_row->addWidget(new QLabel("<b>Console Type:</b>", this), 2);
38-
command_row->addSpacing(5);
39-
4036
QHBoxLayout* row = new QHBoxLayout();
4137
command_row->addLayout(row, 12);
4238

@@ -80,12 +76,6 @@ ImageAnnotationCommandRow::ImageAnnotationCommandRow(
8076

8177
row->addSpacing(5);
8278

83-
m_load_profile_button = new QPushButton("Load Profile", this);
84-
row->addWidget(m_load_profile_button, 2);
85-
86-
m_save_profile_button = new QPushButton("Save Profile", this);
87-
row->addWidget(m_save_profile_button, 2);
88-
8979
// m_test_button = new QPushButton("Test Button", this);
9080
// row->addWidget(m_test_button, 3);
9181

@@ -132,30 +122,6 @@ ImageAnnotationCommandRow::ImageAnnotationCommandRow(
132122
this, [this](Qt::CheckState state){ m_session.set_enabled_stats(state == Qt::Checked); }
133123
);
134124
#endif
135-
connect(
136-
m_load_profile_button, &QPushButton::clicked,
137-
this, [this](bool) { emit load_profile(); }
138-
);
139-
connect(
140-
m_save_profile_button, &QPushButton::clicked,
141-
this, [this](bool) { emit save_profile(); }
142-
);
143-
144-
#if (QT_VERSION_MAJOR == 6) && (QT_VERSION_MINOR >= 8)
145-
if (IS_BETA_VERSION || PreloadSettings::instance().DEVELOPER_MODE){
146-
m_video_button = new QPushButton("Video Capture", this);
147-
command_row->addWidget(m_video_button, 2);
148-
if (GlobalSettings::instance().STREAM_HISTORY->enabled()){
149-
connect(
150-
m_video_button, &QPushButton::clicked,
151-
this, [this](bool){ emit video_requested(); }
152-
);
153-
}else{
154-
m_video_button->setEnabled(false);
155-
m_video_button->setToolTip("Please turn on Stream History to enable video capture.");
156-
}
157-
}
158-
#endif
159125

160126
m_session.add_listener(*this);
161127
}
@@ -177,11 +143,6 @@ void ImageAnnotationCommandRow::set_focus(bool focused){
177143
}
178144

179145
void ImageAnnotationCommandRow::update_ui(){
180-
// cout << "ImageAnnotationCommandRow::update_ui(): focus = " << m_last_known_focus << endl;
181-
182-
bool stopped = m_last_known_state == ProgramState::STOPPED;
183-
m_load_profile_button->setEnabled(stopped);
184-
185146
if (!m_last_known_focus){
186147
m_status->setText(
187148
QString::fromStdString(

SerialPrograms/Source/ML/UI/ML_ImageAnnotationCommandRow.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,6 @@ class ImageAnnotationCommandRow :
3939
void on_key_press(const QKeyEvent& key);
4040
void on_key_release(const QKeyEvent& key);
4141

42-
signals:
43-
void load_profile();
44-
void save_profile();
45-
void screenshot_requested();
46-
void video_requested();
47-
4842
public:
4943
void set_focus(bool focused);
5044
void update_ui();
@@ -69,9 +63,6 @@ class ImageAnnotationCommandRow :
6963
QCheckBox* m_overlay_boxes;
7064
QCheckBox* m_overlay_stats;
7165

72-
QPushButton* m_load_profile_button;
73-
QPushButton* m_save_profile_button;
74-
QPushButton* m_video_button;
7566
bool m_last_known_focus;
7667
ProgramState m_last_known_state;
7768
};

0 commit comments

Comments
 (0)