Skip to content

Commit 862b7dd

Browse files
author
Gin
committed
add annotation save load feature
1 parent 996ef64 commit 862b7dd

File tree

2 files changed

+161
-25
lines changed

2 files changed

+161
-25
lines changed

SerialPrograms/Source/ML/Programs/ML_LabelImages.cpp

Lines changed: 150 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
#include <QScrollArea>
1414
#include <QPushButton>
1515
#include <QResizeEvent>
16+
#include <QMessageBox>
1617
#include <iostream>
1718
#include <fstream>
1819
#include <filesystem>
1920
#include <cmath>
2021
#include "CommonFramework/Globals.h"
22+
#include "Common/Cpp/Json/JsonArray.h"
2123
#include "Common/Cpp/Json/JsonObject.h"
2224
#include "Common/Cpp/Json/JsonValue.h"
25+
#include "Common/Cpp/Json/JsonTools.h"
2326
#include "Common/Qt/CollapsibleGroupBox.h"
2427
#include "Pokemon/Resources/Pokemon_PokemonForms.h"
2528
#include "CommonFramework/VideoPipeline/VideoOverlayScopes.h"
@@ -40,7 +43,66 @@ namespace PokemonAutomation{
4043
namespace ML{
4144

4245

43-
ObjectAnnotation::ObjectAnnotation() {}
46+
ObjectAnnotation::ObjectAnnotation(): user_box(0,0,0,0), mask_box(0,0,0,0) {}
47+
48+
// if failed to pass, will throw JsonParseException
49+
ObjectAnnotation json_to_object_annotation(const JsonValue& value){
50+
ObjectAnnotation anno_obj;
51+
52+
const JsonObject& json_obj = value.to_object_throw();
53+
const JsonArray& user_box_array = json_obj.get_array_throw("UserBox");
54+
anno_obj.user_box = ImagePixelBox(
55+
size_t(user_box_array[0].to_integer_throw()),
56+
size_t(user_box_array[1].to_integer_throw()),
57+
size_t(user_box_array[2].to_integer_throw()),
58+
size_t(user_box_array[3].to_integer_throw())
59+
);
60+
const JsonArray& mask_box_array = json_obj.get_array_throw("MaskBox");
61+
anno_obj.mask_box = ImagePixelBox(
62+
size_t(mask_box_array[0].to_integer_throw()),
63+
size_t(mask_box_array[1].to_integer_throw()),
64+
size_t(mask_box_array[2].to_integer_throw()),
65+
size_t(mask_box_array[3].to_integer_throw())
66+
);
67+
size_t mask_width = anno_obj.mask_box.width(), mask_height = anno_obj.mask_box.height();
68+
anno_obj.mask.resize(mask_width * mask_height);
69+
const JsonArray& mask_values = json_obj.get_array_throw("Mask");
70+
for(size_t i = 0; i < anno_obj.mask.size(); i++){
71+
anno_obj.mask[i] = bool(mask_values[i].to_integer_throw());
72+
}
73+
74+
anno_obj.label = json_obj.get_string_throw("Label");
75+
76+
return anno_obj;
77+
}
78+
79+
JsonObject object_annotation_to_json(const ObjectAnnotation& object_annotation){
80+
JsonObject json_obj;
81+
JsonArray user_box_arr;
82+
user_box_arr.push_back(int64_t(object_annotation.user_box.min_x));
83+
user_box_arr.push_back(int64_t(object_annotation.user_box.min_y));
84+
user_box_arr.push_back(int64_t(object_annotation.user_box.max_x));
85+
user_box_arr.push_back(int64_t(object_annotation.user_box.max_y));
86+
json_obj["UserBox"] = std::move(user_box_arr);
87+
88+
JsonArray mask_box_arr;
89+
mask_box_arr.push_back(int64_t(object_annotation.mask_box.min_x));
90+
mask_box_arr.push_back(int64_t(object_annotation.mask_box.min_y));
91+
mask_box_arr.push_back(int64_t(object_annotation.mask_box.max_x));
92+
mask_box_arr.push_back(int64_t(object_annotation.mask_box.max_y));
93+
json_obj["MaskBox"] = std::move(mask_box_arr);
94+
95+
JsonArray mask_arr;
96+
for(size_t i = 0; i < object_annotation.mask.size(); i++){
97+
mask_arr.push_back(int64_t(object_annotation.mask[i]));
98+
}
99+
json_obj["Mask"] = std::move(mask_arr);
100+
101+
json_obj["Label"] = object_annotation.label;
102+
103+
return json_obj;
104+
}
105+
44106

45107
DrawnBoundingBox::DrawnBoundingBox(LabelImages_Widget& widget, VideoOverlay& overlay)
46108
: m_widget(widget)
@@ -62,7 +124,7 @@ DrawnBoundingBox::~DrawnBoundingBox(){
62124
void DrawnBoundingBox::on_config_value_changed(void* object){
63125
auto& program = m_widget.m_program;
64126
std::lock_guard<std::mutex> lg(m_lock);
65-
program.set_rendered_objects(m_widget.m_overlay_set);
127+
program.update_rendered_objects(m_widget.m_overlay_set);
66128
}
67129
void DrawnBoundingBox::on_mouse_press(double x, double y){
68130
auto& program = m_widget.m_program;
@@ -161,31 +223,99 @@ void LabelImages::from_json(const JsonValue& json){
161223
JsonValue LabelImages::to_json() const{
162224
JsonObject obj = std::move(*m_options.to_json().to_object());
163225
obj["SwitchSetup"] = m_switch_control_option.to_json();
226+
227+
// m_annotation_file_path
228+
if (m_annotation_file_path.size() > 0 && !m_fail_to_load_annotation_file){
229+
JsonArray anno_json_arr;
230+
for(const auto& anno_obj: m_annotated_objects){
231+
anno_json_arr.push_back(object_annotation_to_json(anno_obj));
232+
}
233+
cout << "Saving annotation to " << m_annotation_file_path << endl;
234+
anno_json_arr.dump(m_annotation_file_path);
235+
}
164236
return obj;
165237
}
166238
QWidget* LabelImages::make_widget(QWidget& parent, PanelHolder& holder){
167239
return new LabelImages_Widget(parent, *this, holder);
168240
}
169241

170-
void LabelImages::set_rendered_objects(VideoOverlaySet& overlay_set){
242+
void LabelImages::load_image_related_data(const std::string& image_path, size_t source_image_width, size_t source_image_height){
243+
this->source_image_height = source_image_height;
244+
this->source_image_width = source_image_width;
245+
246+
m_mask_image = ImageRGB32(source_image_width, source_image_height);
247+
cout << "Image source: " << image_path << ", " << source_image_width << " x " << source_image_height << endl;
248+
// if no such embedding file, m_iamge_embedding will be empty
249+
const bool embedding_loaded = load_image_embedding(image_path, m_image_embedding);
250+
if (!embedding_loaded){
251+
return; // no embedding, then no way for us to annotate
252+
}
253+
// see if we can load the previously created labels
254+
const std::string anno_filename = std::filesystem::path(image_path).filename().replace_extension(".json");
255+
256+
// ensure the folder exists
257+
std::filesystem::create_directory(ML_ANNOTATION_PATH());
258+
m_annotation_file_path = ML_ANNOTATION_PATH() + anno_filename;
259+
if (!std::filesystem::exists(m_annotation_file_path)){
260+
cout << "Annotataion output path, " << m_annotation_file_path << " does not exist yet" << endl;
261+
return;
262+
}
263+
std::string json_content;
264+
const bool anno_loaded = file_to_string(m_annotation_file_path, json_content);
265+
if (!anno_loaded){
266+
m_fail_to_load_annotation_file = true;
267+
QMessageBox box;
268+
box.warning(nullptr, "Unable to Load Annotation",
269+
QString::fromStdString("Cannot open annotation file " + m_annotation_file_path + ". Probably wrong permission?"));
270+
return;
271+
}
272+
273+
JsonValue loaded_json = parse_json(json_content);
274+
const JsonArray* json_array = loaded_json.to_array();
275+
if (json_array == nullptr){
276+
m_fail_to_load_annotation_file = true;
277+
QMessageBox box;
278+
box.warning(nullptr, "Unable to Load Annotation",
279+
QString::fromStdString("Cannot load annotation file " + m_annotation_file_path + ". Loaded json is not an array"));
280+
return;
281+
}
282+
283+
for(size_t i = 0; i < json_array->size(); i++){
284+
try{
285+
ObjectAnnotation anno_obj = json_to_object_annotation((*json_array)[i]);
286+
m_annotated_objects.emplace_back(std::move(anno_obj));
287+
} catch(JsonParseException & e){
288+
m_fail_to_load_annotation_file = true;
289+
QMessageBox box;
290+
box.warning(nullptr, "Unable to Load Annotation",
291+
QString::fromStdString("Cannot load annotation file " + m_annotation_file_path +
292+
". Parsing object " + std::to_string(i) + " failed."));
293+
}
294+
}
295+
m_last_object_idx = m_annotated_objects.size();
296+
cout << "Loaded existing annotation file " << m_annotation_file_path << endl;
297+
}
298+
299+
void LabelImages::update_rendered_objects(VideoOverlaySet& overlay_set){
171300
overlay_set.clear();
172301
overlay_set.add(COLOR_RED, {X, Y, WIDTH, HEIGHT});
173302

174-
for(const auto& obj : m_annotated_objects){
303+
for(size_t i_obj = 0; i_obj < m_annotated_objects.size(); i_obj++){
304+
const auto& obj = m_annotated_objects[i_obj];
175305
// overlayset.add(COLOR_RED, pixelbox_to_floatbox(source_image_width, source_image_height, obj.user_box));
176306
const auto mask_float_box = pixelbox_to_floatbox(source_image_width, source_image_height, obj.mask_box);
177307
std::string label = obj.label;
178308
const Pokemon::PokemonForm* form = Pokemon::get_pokemon_form(label);
179309
if (form != nullptr){
180310
label = form->display_name();
181311
}
182-
overlay_set.add(COLOR_BLUE, mask_float_box, label);
312+
Color mask_box_color = (i_obj == m_last_object_idx) ? COLOR_BLACK : COLOR_BLUE;
313+
overlay_set.add(mask_box_color, mask_float_box, label);
183314
size_t mask_width = obj.mask_box.width();
184315
size_t mask_height = obj.mask_box.height();
185316
ImageRGB32 mask_image(mask_width, mask_height);
186317
// cout << "in render, mask_box " << obj.mask_box.min_x << " " << obj.mask_box.min_y << " " << obj.mask_box.max_x << " " << obj.mask_box.max_y << endl;
187318

188-
// int count = 0;
189319
for (size_t y = 0; y < mask_height; y++){
190320
for (size_t x = 0; x < mask_width; x++){
191321
const bool mask = obj.mask[y*mask_width + x];
@@ -195,7 +325,6 @@ void LabelImages::set_rendered_objects(VideoOverlaySet& overlay_set){
195325
uint32_t color = 0;
196326
if (mask){
197327
color = (std::abs(int(x) - int(y)) % 4 <= 1) ? combine_argb(150, 30, 144, 255) : combine_argb(150, 0, 0, 60);
198-
// count++;
199328
}
200329
pixel = color;
201330
}
@@ -267,9 +396,10 @@ void LabelImages::compute_mask(VideoOverlaySet& overlay_set){
267396
}
268397

269398
annotation.label = label;
399+
m_last_object_idx = m_annotated_objects.size();
270400
m_annotated_objects.emplace_back(std::move(annotation));
271401

272-
set_rendered_objects(overlay_set);
402+
update_rendered_objects(overlay_set);
273403
}
274404
}
275405

@@ -309,10 +439,14 @@ LabelImages_Widget::LabelImages_Widget(
309439
QPushButton* button = new QPushButton("Delete Last Mask", scroll_inner);
310440
scroll_layout->addWidget(button);
311441
connect(button, &QPushButton::clicked, this, [this](bool){
312-
if (this->m_program.m_annotated_objects.size() > 0){
313-
this->m_program.m_annotated_objects.pop_back();
442+
auto& program = this->m_program;
443+
if (program.m_annotated_objects.size() > 0){
444+
program.m_annotated_objects.pop_back();
314445
}
315-
this->m_program.set_rendered_objects(this->m_overlay_set);
446+
if (program.m_annotated_objects.size() > 0){
447+
program.m_last_object_idx = program.m_annotated_objects.size() - 1;
448+
}
449+
program.update_rendered_objects(this->m_overlay_set);
316450
});
317451

318452
m_option_widget = instance.m_options.make_QtWidget(*scroll_inner);
@@ -324,24 +458,18 @@ LabelImages_Widget::LabelImages_Widget(
324458
const std::string image_path = image_source_desc->path();
325459
const size_t source_image_height = image_source_desc->source_image_height();
326460
const size_t source_image_width = image_source_desc->source_image_width();
327-
m_program.source_image_height = source_image_height;
328-
m_program.source_image_width = source_image_width;
329-
m_program.m_mask_image = ImageRGB32(source_image_width, source_image_height);
330-
cout << "Image source: " << image_path << ", " << source_image_width << " x " << source_image_height << endl;
331-
// if no such embedding file, m_iamge_embedding will be empty
332-
load_image_embedding(image_path, m_program.m_image_embedding);
461+
m_program.load_image_related_data(image_path, source_image_width, source_image_height);
462+
m_program.update_rendered_objects(m_overlay_set);
333463
}
334464

335465
cout << "LabelImages_Widget built" << endl;
336-
337-
// TODO: create a custom table to display the annotated bounding boxes
338466
}
339467

340468
void LabelImages_Widget::on_config_value_changed(void* object){
341-
if (m_program.m_annotated_objects.size() > 0){
342-
std::string& cur_label = m_program.m_annotated_objects.back().label;
469+
if (m_program.m_annotated_objects.size() > 0 && m_program.m_last_object_idx < m_program.m_annotated_objects.size()){
470+
std::string& cur_label = m_program.m_annotated_objects[m_program.m_last_object_idx].label;
343471
cur_label = m_program.FORM_LABEL.slug();
344-
m_program.set_rendered_objects(m_overlay_set);
472+
m_program.update_rendered_objects(m_overlay_set);
345473
}
346474
}
347475

SerialPrograms/Source/ML/Programs/ML_LabelImages.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ class LabelImages : public PanelInstance{
7070
virtual void from_json(const JsonValue& json) override;
7171
virtual JsonValue to_json() const override;
7272

73-
void set_rendered_objects(VideoOverlaySet& overlayset);
73+
void load_image_related_data(const std::string& image_path, const size_t source_image_width, const size_t source_image_height);
74+
75+
void update_rendered_objects(VideoOverlaySet& overlayset);
7476

7577
void compute_mask(VideoOverlaySet& overlay_set);
7678

@@ -94,9 +96,17 @@ class LabelImages : public PanelInstance{
9496
std::vector<float> m_image_embedding;
9597
std::vector<bool> m_output_boolean_mask;
9698

99+
// buffer to compute SAM mask on
97100
ImageRGB32 m_mask_image;
101+
98102
SAMSession m_sam_session;
99103
std::vector<ObjectAnnotation> m_annotated_objects;
104+
size_t m_last_object_idx = 0;
105+
std::string m_annotation_file_path;
106+
// if we find an annotation file that is supposed to be created by user in a previous session, but
107+
// we fail to load it, then we shouldn't overwrite this file to possibly erase the previous work.
108+
// so this flag is used to denote if we fail to load an annotation file
109+
bool m_fail_to_load_annotation_file = false;
100110
};
101111

102112

@@ -141,8 +151,6 @@ class LabelImages_Widget : public PanelWidget, public ConfigOption::Listener{
141151
ConfigWidget* m_option_widget;
142152

143153
friend class DrawnBoundingBox;
144-
// std::unique_ptr<ImageRGB32> m_image_mask;
145-
// std::unique_ptr<OverlayImage> m_overlay_image;
146154
};
147155

148156

0 commit comments

Comments
 (0)