Skip to content

Commit 8b583be

Browse files
author
Gin
committed
writing image embedding loading code
1 parent c76a3b6 commit 8b583be

File tree

5 files changed

+77
-9
lines changed

5 files changed

+77
-9
lines changed

SerialPrograms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,8 @@ file(GLOB MAIN_SOURCES
871871
Source/Kernels/Waterfill/Kernels_Waterfill_Session.h
872872
Source/Kernels/Waterfill/Kernels_Waterfill_Session.tpp
873873
Source/Kernels/Waterfill/Kernels_Waterfill_Types.h
874+
Source/ML/DataLabeling/SegmentAnythingEmbedding.cpp
875+
Source/ML/DataLabeling/SegmentAnythingEmbedding.h
874876
Source/ML/ML_Panels.cpp
875877
Source/ML/ML_Panels.h
876878
Source/ML/Programs/ML_LabelImages.cpp
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/* ML Segment Anything Embedding
2+
*
3+
* From: https://github.com/PokemonAutomation/
4+
*
5+
* Creating and loading image embedding used by Segment Anything Model
6+
*/
7+
8+
#include <fstream>
9+
#include <iostream>
10+
#include "SegmentAnythingEmbedding.h"
11+
12+
13+
// Load the pre-computed Segment Anything image embedding stored next to an image.
//
// The embedding is read from the file `image_filepath` + ".embedding", laid out as:
//   int n_channels, int height, int width   (raw native ints, as written by the producer)
//   float data[n_channels * height * width]
//
// Parameters:
//   image_filepath   path of the image; ".embedding" is appended to locate the file.
//   image_embedding  receives the loaded values on success. It is left untouched
//                    when the function returns false or throws.
//
// Returns true when the embedding was loaded, false when the file cannot be opened.
// Throws std::runtime_error when the file opens but is malformed: unreadable or
// non-positive dimensions, or fewer payload bytes than the header promises.
bool load_image_embedding(const std::string& image_filepath, std::vector<float>& image_embedding){
    const std::string embedding_path = image_filepath + ".embedding";
    std::ifstream fin(embedding_path, std::ios::binary);
    if (!fin.is_open()){
        std::cout << "No embedding for image " << image_filepath << std::endl;
        return false;
    }

    // Log the problem and abort the load; every malformed-file path goes through here.
    const auto fail = [&embedding_path](const std::string& what){
        const std::string err_msg = what + embedding_path;
        std::cerr << err_msg << std::endl;
        throw std::runtime_error(err_msg);
    };

    // Header: {n_channels, height, width}. Zero-initialised so that a short read
    // can never be mistaken for a valid shape.
    int dims[3] = {0, 0, 0};
    fin.read(reinterpret_cast<char*>(dims), sizeof(dims));
    const bool header_ok = static_cast<bool>(fin);

    std::cout << "Image embedding shape [" << dims[0] << ", " << dims[1]
              << ", " << dims[2] << "]" << std::endl;
    if (!header_ok || dims[0] <= 0 || dims[1] <= 0 || dims[2] <= 0){
        fail("Image embedding wrong dimension from ");
    }

    // Work out how many floats the file can actually supply. Bounding the element
    // count by this both rejects truncated files and keeps a corrupt header from
    // overflowing the multiplication or requesting an absurd allocation.
    const std::streamoff payload_begin = fin.tellg();
    fin.seekg(0, std::ios::end);
    const std::streamoff file_end = fin.tellg();
    fin.seekg(payload_begin, std::ios::beg);

    std::size_t max_elems = 0;
    if (payload_begin >= 0 && file_end > payload_begin){
        max_elems = static_cast<std::size_t>(file_end - payload_begin) / sizeof(float);
    }
    if (max_elems > image_embedding.max_size()){
        max_elems = image_embedding.max_size();
    }

    // count * d <= max_elems  <=>  d <= max_elems / count  (count >= 1 throughout),
    // so the product is never formed unless it is known to fit.
    std::size_t count = 1;
    for (const int d : dims){
        if (static_cast<std::size_t>(d) > max_elems / count){
            fail("Image embedding data truncated in ");
        }
        count *= static_cast<std::size_t>(d);
    }

    // Read into a scratch buffer so the caller's vector only changes on success.
    std::vector<float> loaded(count);
    fin.read(reinterpret_cast<char*>(loaded.data()),
             static_cast<std::streamsize>(sizeof(float) * count));
    if (!fin){
        fail("Image embedding failed to read data from ");
    }

    image_embedding.swap(loaded);
    std::cout << "Loaded image embedding from " << embedding_path << std::endl;
    return true;
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/* ML Segment Anything Embedding
2+
*
3+
* From: https://github.com/PokemonAutomation/
4+
*
5+
* Creating and loading image embedding used by Segment Anything Model
6+
*/
7+
8+
#ifndef PokemonAutomation_ML_SEGMENTANYTHINGEMBEDDING_H
9+
#define PokemonAutomation_ML_SEGMENTANYTHINGEMBEDDING_H
10+
11+
#include <string>
12+
#include <vector>
13+
14+
// Load a pre-computed image embedding from disk.
// The embedding is read from the file `image_filepath` + ".embedding", which
// stores three ints (channels, height, width) followed by
// channels * height * width floats.
15+
// Returns true if the embedding file could be opened and was loaded into
// `image_embedding`; returns false if the file could not be opened.
// Throws std::runtime_error if the file opens but its contents are invalid
// (e.g. a non-positive dimension).
16+
bool load_image_embedding(const std::string& image_filepath, std::vector<float>& image_embedding);
17+
18+
19+
#endif

SerialPrograms/Source/ML/Programs/ML_LabelImages.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <QPushButton>
1515
#include <QResizeEvent>
1616
#include <iostream>
17+
#include <fstream>
18+
#include <filesystem>
1719
#include "Common/Cpp/Json/JsonObject.h"
1820
#include "Common/Cpp/Json/JsonValue.h"
1921
#include "Common/Qt/CollapsibleGroupBox.h"
@@ -24,6 +26,7 @@
2426
#include "ML_LabelImages.h"
2527
#include "Pokemon/Pokemon_Strings.h"
2628
#include "Common/Qt/Options/ConfigWidget.h"
29+
#include "ML/DataLabeling/SegmentAnythingEmbedding.h"
2730

2831

2932
using std::cout;
@@ -52,7 +55,7 @@ DrawnBoundingBox::~DrawnBoundingBox(){
5255
void DrawnBoundingBox::on_config_value_changed(void* object){
5356
std::lock_guard<std::mutex> lg(m_lock);
5457
m_overlay_set.clear();
55-
m_overlay_set.add(COLOR_RED, {m_parent.X, m_parent.Y, m_parent.WIDTH, m_parent.HEIGHT});
58+
m_overlay_set.add(COLOR_RED, {m_parent.X, m_parent.Y, m_parent.WIDTH, m_parent.HEIGHT}, "Unknown");
5659
}
5760
void DrawnBoundingBox::on_mouse_press(double x, double y){
5861
m_parent.WIDTH.set(0);
@@ -109,6 +112,7 @@ LabelImages_Descriptor::LabelImages_Descriptor()
109112
{}
110113

111114

115+
112116
#define ADD_OPTION(x) m_options.add_option(x, #x)
113117

114118
LabelImages::LabelImages(const LabelImages_Descriptor& descriptor)
@@ -177,20 +181,23 @@ LabelImages_Widget::LabelImages_Widget(
177181
QPushButton* button = new QPushButton("This is a button", scroll_inner);
178182
scroll_layout->addWidget(button);
179183
connect(button, &QPushButton::clicked, this, [&instance](bool){
180-
const VideoSourceDescriptor* videoSource = instance.m_switch_control_option.m_video.descriptor().get();
181-
auto imageSource = dynamic_cast<const VideoSourceDescriptor_StillImage*>(videoSource);
182-
if (imageSource != nullptr){
183-
cout << "Image source: " << imageSource->path() << endl;
184+
const VideoSourceDescriptor* video_source = instance.m_switch_control_option.m_video.descriptor().get();
185+
auto image_source = dynamic_cast<const VideoSourceDescriptor_StillImage*>(video_source);
186+
if (image_source != nullptr){
187+
cout << "Image source: " << image_source->path() << endl;
184188
}
185189
});
186190

187191
m_option_widget = instance.m_options.make_QtWidget(*scroll_inner);
188192
scroll_layout->addWidget(&m_option_widget->widget());
189193

190-
const VideoSourceDescriptor* videoSource = instance.m_switch_control_option.m_video.descriptor().get();
191-
auto imageSource = dynamic_cast<const VideoSourceDescriptor_StillImage*>(videoSource);
192-
if (imageSource != nullptr){
193-
cout << "Image source: " << imageSource->path() << endl;
194+
const VideoSourceDescriptor* video_source = instance.m_switch_control_option.m_video.descriptor().get();
195+
auto image_source = dynamic_cast<const VideoSourceDescriptor_StillImage*>(video_source);
196+
if (image_source != nullptr){
197+
std::string image_path = image_source->path();
198+
cout << "Image source: " << image_path << endl;
199+
// if no such embedding file, m_image_embedding will be empty
200+
load_image_embedding(image_path, m_image_embedding);
194201
}
195202
cout << "LabelImages_Widget built" << endl;
196203
}

SerialPrograms/Source/ML/Programs/ML_LabelImages.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ class LabelImages_Widget : public PanelWidget{
100100
NintendoSwitch::SwitchSystemWidget* m_switch_widget;
101101
DrawnBoundingBox m_drawn_box;
102102
ConfigWidget* m_option_widget;
103+
std::vector<float> m_image_embedding;
103104
};
104105

105106

0 commit comments

Comments
 (0)