Skip to content

Commit 1b19127

Browse files
author
Gin
committed
writing image embedding computation feature
1 parent d135e08 commit 1b19127

File tree

7 files changed

+65
-20
lines changed

7 files changed

+65
-20
lines changed

SerialPrograms/Source/ML/DataLabeling/SegmentAnythingModel.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
* Run Segment Anything Model (SAM) to segment objects on images
66
*/
77

8+
#include <QDir>
9+
#include <QDirIterator>
810
#include <fstream>
911
#include <iostream>
1012
#include <onnxruntime_cxx_api.h>
@@ -246,5 +248,20 @@ bool load_image_embedding(const std::string& image_filepath, std::vector<float>&
246248
}
247249

248250

251+
void compute_embeddings_for_folder(const std::string& image_folder_path){
252+
QDir image_dir(image_folder_path.c_str());
253+
if (!image_dir.exists()){
254+
std::cerr << "Error: input image folder path " << image_folder_path << " does not exist." << std::endl;
255+
return;
256+
}
257+
258+
QDirIterator image_file_iter(image_dir.absolutePath(), {"*.png", "*.jpg", "*.jpeg"}, QDir::Files, QDirIterator::Subdirectories);
259+
std::vector<std::string> all_image_paths;
260+
while (image_file_iter.hasNext()){
261+
all_image_paths.emplace_back(image_file_iter.next().toStdString());
262+
}
263+
std::cout << "Found " << all_image_paths.size() << " images recursively in folder " << image_folder_path << std::endl;
264+
}
265+
249266
}
250267
}

SerialPrograms/Source/ML/DataLabeling/SegmentAnythingModel.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,18 @@ namespace PokemonAutomation{
2121
namespace ML{
2222

2323

24-
// load pre-computed image embedding from disk
25-
// return true if there is the embedding file
24+
// Load pre-computed image embedding from disk
25+
// Return true if there is the embedding file.
26+
// The embedding is stored in a file in the same folder as the image, having the same name but with a suffix ".embedding".
2627
bool load_image_embedding(const std::string& image_filepath, std::vector<float>& image_embedding);
2728

28-
// save the image embedding as a file with path <image_filepath>.embedding
29+
// Save the image embedding as a file with path <image_filepath>.embedding.
2930
void save_image_embedding_to_disk(const std::string& image_filepath, const std::vector<float>& embedding);
3031

32+
// Compute embeddings for all images in a folder.
33+
// This can be very slow!
34+
void compute_embeddings_for_folder(const std::string& image_folder_path);
35+
3136

3237
class SAMEmbedderSession{
3338
public:

SerialPrograms/Source/ML/Programs/ML_LabelImages.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
*
55
*/
66

7+
#include <QFileDialog>
78
#include <QLabel>
89
#include <QDir>
10+
#include <QDirIterator>
911
#include <QVBoxLayout>
1012
#include <QGraphicsView>
1113
#include <QGraphicsScene>
@@ -407,6 +409,12 @@ void LabelImages::compute_mask(VideoOverlaySet& overlay_set){
407409
}
408410
}
409411

412+
void LabelImages::compute_embeddings_for_folder(const std::string& image_folder_path){
413+
ML::compute_embeddings_for_folder(image_folder_path);
414+
}
415+
416+
417+
410418
LabelImages_Widget::~LabelImages_Widget(){
411419
m_program.FORM_LABEL.remove_listener(*this);
412420
delete m_switch_widget;
@@ -453,9 +461,21 @@ LabelImages_Widget::LabelImages_Widget(
453461
program.update_rendered_objects(this->m_overlay_set);
454462
});
455463

464+
// Add all option UI elements defined by LabelImage program.
456465
m_option_widget = program.m_options.make_QtWidget(*scroll_inner);
457466
scroll_layout->addWidget(&m_option_widget->widget());
458467

468+
button = new QPushButton("Compute Image Embeddings (SLOW!)", scroll_inner);
469+
scroll_layout->addWidget(button);
470+
connect(button, &QPushButton::clicked, this, [this](bool){
471+
std::string folder_path = QFileDialog::getExistingDirectory(
472+
nullptr, "Open image folder", ".").toStdString();
473+
474+
if (folder_path.size() > 0){
475+
this->m_program.compute_embeddings_for_folder(folder_path);
476+
}
477+
});
478+
459479
const auto cur_res = m_display_session.video_session().current_resolution();
460480
if (cur_res.width > 0 && cur_res.height > 0){
461481
const std::string& image_path = m_display_session.option().m_image_path;

SerialPrograms/Source/ML/Programs/ML_LabelImages.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,23 +60,34 @@ class LabelImages_Descriptor : public PanelDescriptor{
6060
LabelImages_Descriptor();
6161
};
6262

63-
// label image program
63+
64+
// Program to annoatation images for training ML models
6465
class LabelImages : public PanelInstance{
6566
public:
6667
LabelImages(const LabelImages_Descriptor& descriptor);
6768
virtual QWidget* make_widget(QWidget& parent, PanelHolder& holder) override;
6869

6970
public:
70-
// Serialization
71+
// Serialization
7172
virtual void from_json(const JsonValue& json) override;
7273
virtual JsonValue to_json() const override;
7374

75+
// Load image related data:
76+
// - Image SAM embedding data file, which has the same file path but with a name suffix ".embedding"
77+
// - Existing annotation file, which is stored in a pre-defined ML_ANNOTATION_PATH() and with the same filename as
78+
// the image but with name extension replaced to be ".json".
7479
void load_image_related_data(const std::string& image_path, const size_t source_image_width, const size_t source_image_height);
7580

81+
// Update rendering data reflect the current annotation
7682
void update_rendered_objects(VideoOverlaySet& overlayset);
7783

84+
// Use user currently drawn box to compute per-pixel masks on the image using SAM model
7885
void compute_mask(VideoOverlaySet& overlay_set);
7986

87+
// Compute embeddings for all images in a folder.
88+
// This can be very slow!
89+
void compute_embeddings_for_folder(const std::string& image_folder);
90+
8091
private:
8192
friend class LabelImages_Widget;
8293
friend class DrawnBoundingBox;

SerialPrograms/Source/ML/UI/ML_ImageAnnotationCommandRow.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,6 @@ ImageAnnotationCommandRow::ImageAnnotationCommandRow(
8686
m_save_profile_button = new QPushButton("Save Profile", this);
8787
row->addWidget(m_save_profile_button, 2);
8888

89-
m_screenshot_button = new QPushButton("Screenshot", this);
90-
// m_screenshot_button->setToolTip("Take a screenshot of the console and save to disk.");
91-
row->addWidget(m_screenshot_button, 2);
92-
93-
9489
// m_test_button = new QPushButton("Test Button", this);
9590
// row->addWidget(m_test_button, 3);
9691

@@ -145,11 +140,7 @@ ImageAnnotationCommandRow::ImageAnnotationCommandRow(
145140
m_save_profile_button, &QPushButton::clicked,
146141
this, [this](bool) { emit save_profile(); }
147142
);
148-
connect(
149-
m_screenshot_button, &QPushButton::clicked,
150-
this, [this](bool){ emit screenshot_requested(); }
151-
);
152-
143+
153144
#if (QT_VERSION_MAJOR == 6) && (QT_VERSION_MINOR >= 8)
154145
if (IS_BETA_VERSION || PreloadSettings::instance().DEVELOPER_MODE){
155146
m_video_button = new QPushButton("Video Capture", this);

SerialPrograms/Source/ML/UI/ML_ImageAnnotationCommandRow.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ class ImageAnnotationCommandRow :
7171

7272
QPushButton* m_load_profile_button;
7373
QPushButton* m_save_profile_button;
74-
QPushButton* m_screenshot_button;
7574
QPushButton* m_video_button;
7675
bool m_last_known_focus;
7776
ProgramState m_last_known_state;

SerialPrograms/Source/ML/UI/ML_ImageAnnotationSourceSelectorWidget.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,13 @@ ImageAnnotationSourceSelectorWidget::ImageAnnotationSourceSelectorWidget(ImageAn
5555
m_reset_button, &QPushButton::clicked,
5656
this, [this](bool){
5757
std::string path = QFileDialog::getOpenFileName(
58-
nullptr, "Open image file", ".", "*.png *.jpg"
58+
nullptr, "Open image file", ".", "*.png *.jpg *.jpeg"
5959
).toStdString();
60-
61-
m_source_file_path_label->setText(QString::fromStdString(path));
62-
m_session.set_image_source(path);
60+
61+
if (path.size() > 0){
62+
m_source_file_path_label->setText(QString::fromStdString(path));
63+
m_session.set_image_source(path);
64+
}
6365
}
6466
);
6567

0 commit comments

Comments
 (0)