diff --git a/CMakeLists.txt b/CMakeLists.txt
index ef72053..e913dcf 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.28)
-project(vision.cpp VERSION 0.1.0 LANGUAGES CXX)
+project(vision.cpp VERSION 0.2.0 LANGUAGES CXX)
option(VISP_VULKAN "Enable Vulkan support" OFF)
option(VISP_DEV "Enable development mode" OFF)
@@ -30,7 +30,7 @@ elseif(CMAKE_BUILD_TYPE)
endif()
endif()
-# Configure address sanitizer (Clang only)
+# Configure address sanitizer
if(VISP_ASAN)
if(MSVC)
diff --git a/README.md b/README.md
index 34a501a..12dc260 100644
--- a/README.md
+++ b/README.md
@@ -12,14 +12,17 @@ Based on [ggml](https://github.com/ggml-org/ggml) similar to the [llama.cpp](htt
### Features
-| Model | Task | Backends |
-| :-------------------------- | :--------------- | :---------- |
-| [**MobileSAM**](#mobilesam) | Segmentation | CPU, Vulkan |
-| [**BiRefNet**](#birefnet) | Segmentation | CPU, Vulkan |
-| [**MI-GAN**](#mi-gan) | Inpainting | CPU, Vulkan |
-| [**ESRGAN**](#real-esrgan) | Super-resolution | CPU, Vulkan |
+| Model | Task | Backends |
+| :--------------------------------------- | :----------------------- | :---------- |
+| [**MobileSAM**](#mobilesam) | Promptable segmentation | CPU, Vulkan |
+| [**BiRefNet**](#birefnet) | Dichotomous segmentation | CPU, Vulkan |
+| [**Depth-Anything**](#depth-anything-v2) | Depth estimation | CPU, Vulkan |
+| [**MI-GAN**](#mi-gan) | Inpainting | CPU, Vulkan |
+| [**ESRGAN**](#real-esrgan) | Super-resolution | CPU, Vulkan |
| [_Implement a model [**Guide**]_](docs/model-implementation-guide.md) | | |
+**Backbones:** SWIN (v1), DINO (v2), TinyViT
+
## Get Started
Get the library and executables:
@@ -92,6 +95,16 @@ vision-cli sam -m MobileSAM-F16.gguf -i input.png -p 300 200 -o mask.png --compo
vision-cli birefnet -m BiRefNet-lite-F16.gguf -i input.png -o mask.png --composite comp.png
```
+#### Depth-Anything V2
+
+
+
+[Model download](https://huggingface.co/Acly/Depth-Anything-V2-GGUF/tree/main) | [Paper (arXiv)](https://arxiv.org/abs/2406.09414) | [Repository (GitHub)](https://github.com/DepthAnything/Depth-Anything-V2) | License: Apache-2.0 / CC-BY-NC-4.0
+
+```sh
+vision-cli depth-anything -m Depth-Anything-V2-Small-F16.gguf -i input.png -o depth.png
+```
+
#### MI-GAN
@@ -191,10 +204,17 @@ as other frameworks for inference speed, but with:
| Model | | | _vision.cpp_ | PyTorch | ONNX Runtime |
| :---- | :--- | :--- | -----------: | -------: | -----------: |
-| Full | cpu | f32 | 16333 ms | 18800 ms | |
-| Full | gpu | f16 | 243 ms | 140 ms | |
+| Full | cpu | f32 | 16333 ms | 18290 ms | |
+| Full | gpu | f16 | 208 ms | 190 ms | |
| Lite | cpu | f32 | 4505 ms | 10900 ms | 6978 ms |
-| Lite | gpu | f16 | 86 ms | 59 ms | |
+| Lite | gpu | f16 | 85 ms | 84 ms | |
+
+#### Depth-Anything, 518x714
+
+| Model | | | _vision.cpp_ | PyTorch |
+| :---- | :--- | :--- | -----------: | ------: |
+| Small | gpu | f16 | 11 ms | 10 ms |
+| Base | gpu | f16 | 24 ms | 22 ms |
#### MI-GAN, 512x512
@@ -205,7 +225,7 @@ as other frameworks for inference speed, but with:
#### Setup
-* vision.cpp: using vision-bench, GPU via Vulkan, eg. `vision-bench -m sam -b cpu`
+* vision.cpp: using vision-bench, GPU via Vulkan, e.g. `vision-bench -m sam`
* PyTorch: v2.7.1+cu128, eager eval, GPU via CUDA, average n iterations after warm-up
## Dependencies (integrated)
diff --git a/depend/ggml b/depend/ggml
index 96840f1..7d1a4d8 160000
--- a/depend/ggml
+++ b/depend/ggml
@@ -1 +1 @@
-Subproject commit 96840f15c3d0aa61a901c05003efd1976df4e5a8
+Subproject commit 7d1a4d803cb807b45beb9c4c6605013d9a8354f7
diff --git a/include/visp/image.h b/include/visp/image.h
index cb766cb..ddc2596 100644
--- a/include/visp/image.h
+++ b/include/visp/image.h
@@ -169,6 +169,12 @@ VISP_API void image_alpha_composite(
VISP_API image_data image_alpha_composite(
image_view const& fg, image_view const& bg, image_view const& mask);
+// Rescale pixel values such that the minimum value over all pixels becomes `min` and
+// the maximum becomes `max`. Channels are processed independently.
+VISP_API void image_normalize(
+ image_view const& src, image_span const& dst, float min = 0, float max = 1);
+VISP_API image_data image_normalize(image_view const& img, float min = 0, float max = 1);
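+// Example (illustrative, variable names hypothetical): image_data depth01 = image_normalize(raw_depth); // each channel rescaled to [0, 1]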
+
// Compute root-mean-square difference between two images
VISP_API float image_difference_rms(image_view const& a, image_view const& b);
diff --git a/include/visp/ml.h b/include/visp/ml.h
index ed108a4..93a0af1 100644
--- a/include/visp/ml.h
+++ b/include/visp/ml.h
@@ -65,7 +65,8 @@ enum class model_build_flag {
conv_2d_direct_cwhn = 1 << 1,
concat_n = 1 << 2,
f16_conv_transpose = 1 << 3,
- window_partition = 1 << 4
+ window_partition = 1 << 4,
+ flash_attention = 1 << 5
}; // clang-format on
using model_build_flags = flags;
@@ -87,6 +88,7 @@ struct model_file {
VISP_API int64_t key(char const* name) const;
VISP_API int get_int(char const* name) const;
VISP_API std::string_view get_string(char const* name) const;
+ VISP_API void get_array(char const* name, span out_values) const;
};
// Opens a .gguf file and reads its contents into memory.
@@ -216,8 +218,10 @@ struct VISP_API tensor_data {
span as_f32();
span as_i32();
+ span as_bytes();
span as_f32() const;
span as_i32() const;
+ span as_bytes() const;
};
// Allocates data for a tensor in main memory, outside of context and backend buffers.
@@ -225,6 +229,7 @@ VISP_API tensor_data tensor_alloc(tensor x);
// Loads tensor data from a file storing raw numbers as binary.
VISP_API tensor_data tensor_load(tensor x, char const* filepath);
+VISP_API void tensor_save(tensor x, char const* filepath);
// Copies data to the tensor's backend buffer (which should already be allocated).
VISP_API void transfer_to_backend(tensor_data const&);
@@ -274,28 +279,6 @@ VISP_API tensor concat(model_ref const&, std::array src, i
// Up- or downsample a 2D tensor (WHCN) to target width x height.
VISP_API tensor interpolate(model_ref const&, tensor x, i64x2 target, int32_t mode);
-//
-// SWIN Transformer
-
-struct swin_layer_t {
- int depth;
- int n_heads;
- int n_features;
- bool downsample;
-};
-
-struct swin_params {
- static constexpr int n_layers = 4;
-
- int embed_dim;
- int window_size;
- std::array layers;
-};
-
-extern swin_params const swin_t_params;
-extern swin_params const swin_l_params;
-VISP_API swin_params swin_detect_params(model_file const&);
-
//
// implementation
diff --git a/include/visp/vision.h b/include/visp/vision.h
index 4daeaab..1e22096 100644
--- a/include/visp/vision.h
+++ b/include/visp/vision.h
@@ -57,8 +57,9 @@
// 7. Run the compute graph.
// 8. Transfer the output to the host and post-process it.
//
-// Custom pipelines are simply functions which call the individual steps and extend them
-// where needed. The implementation of the high-level API functions is a good starting point.
+// Custom pipelines can be created simply by writing a function that calls the
+// individual steps. As a starting point, check out or copy the implementation
+// of the high-level API functions. Then adapt them as needed.
// This allows to:
// * load model weights from a different source
// * control exactly when allocation happens
@@ -76,9 +77,46 @@
#include
#include
+#include
namespace visp {
+// SWIN v1 - vision transformer for feature extraction
+
+constexpr int swin_n_layers = 4;
+
+struct swin_layer_t {
+ int depth;
+ int n_heads;
+ int n_features;
+};
+
+struct swin_params {
+ int embed_dim;
+ int window_size;
+ std::array<swin_layer_t, swin_n_layers> layers;
+};
+
+using swin_buffers = std::array<tensor_data, swin_n_layers + 2>;
+using swin_result = std::array<tensor, swin_n_layers>;
+
+VISP_API swin_params swin_detect_params(model_file const&);
+VISP_API swin_buffers swin_precompute(model_ref, i32x2 image_extent, swin_params const&);
+VISP_API swin_result swin_encode(model_ref, tensor image, swin_params const&);
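+//
+// Note: swin_precompute creates the relative-position index and per-resolution attention masks
+// as named constant tensors; swin_encode looks them up by name, so it expects swin_precompute
+// to have been called for the same graph and image extent.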
+
+// DINO v2 - vision transformer for feature extraction
+
+struct dino_params {
+ int patch_size = 16;
+ int embed_dim = 768;
+ int n_layers = 12;
+ int n_heads = 12;
+};
+
+VISP_API dino_params dino_detect_params(model_file const&);
+VISP_API std::vector dino_get_intermediate_layers(
+ model_ref, tensor image, span layers_ids, dino_params const&);
+
//
// Mobile SAM - image segmentation with prompt (point or box)
@@ -133,7 +171,9 @@ VISP_API image_data sam_process_mask(
struct birefnet_model;
// Loads a BiRefNet model from GGUF file onto the backend device.
-// * supports BiRefNet, BiRefNet_lite, BiRefNet_Matting variants at 1024px resolution
+// * supports BiRefNet, BiRefNet-lite, BiRefNet-Matting variants at 1024px resolution
+// * supports BiRefNet-HR variant at 2048px resolution
+// * supports BiRefNet-dynamic variant at arbitrary resolution
VISP_API birefnet_model birefnet_load_model(char const* filepath, backend_device const&);
// Takes RGB input and computes an alpha mask with foreground as 1.0 and background as 0.0.
@@ -148,7 +188,7 @@ struct birefnet_params {
swin_params encoder;
};
-using birefnet_buffers = std::array;
+using birefnet_buffers = swin_buffers;
VISP_API birefnet_params birefnet_detect_params(
model_file const&, i32x2 dynamic_extent = {}, size_t max_alloc = SIZE_MAX);
@@ -162,6 +202,39 @@ VISP_API image_data birefnet_process_output(
VISP_API tensor birefnet_predict(model_ref, tensor image, birefnet_params const&);
+//
+// Depth Anything - depth estimation
+
+struct depthany_model;
+
+// Loads a Depth Anything V2 model from GGUF file onto the backend device.
+// * supports Small/Base/Large variants with flexible input resolution
+VISP_API depthany_model depthany_load_model(char const* filepath, backend_device const&);
+
+// Takes RGB input and computes estimated depth (distance from camera).
+// Output is a single-channel float32 image in range [0, 1.0].
+VISP_API image_data depthany_compute(depthany_model&, image_view image);
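+//
+// Minimal usage sketch (assumes a backend_device `device` is already initialized; names are
+// illustrative only):
+//   depthany_model model = depthany_load_model("Depth-Anything-V2-Small-F16.gguf", device);
+//   image_data input = image_load("input.png");
+//   image_data depth = depthany_compute(model, input);
+//   image_save(image_f32_to_u8(depth, image_format::alpha_u8), "depth.png");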
+
+// --- Depth Anything pipeline
+
+struct depthany_params {
+ int image_size = 518;
+ int image_multiple = 14;
+ i32x2 image_extent = {518, 518};
+ float max_depth = 1;
+ std::array feature_layers = {2, 5, 8, 11};
+ dino_params dino;
+};
+
+VISP_API depthany_params depthany_detect_params(model_file const&, i32x2 input_extent = {});
+VISP_API i32x2 depthany_image_extent(i32x2 input_extent, depthany_params const&);
+
+VISP_API image_data depthany_process_input(image_view image, depthany_params const&);
+VISP_API image_data depthany_process_output(
+ std::span output_data, i32x2 target_extent, depthany_params const&);
+
+VISP_API tensor depthany_predict(model_ref, tensor image, depthany_params const&);
+
//
// MI-GAN - image inpainting
@@ -246,6 +319,17 @@ struct birefnet_model {
tensor output = nullptr;
};
+// internal
+struct depthany_model {
+ backend_device const* backend = nullptr;
+ model_weights weights;
+ depthany_params params;
+
+ compute_graph graph;
+ tensor input = nullptr;
+ tensor output = nullptr;
+};
+
// internal
struct migan_model {
backend_device const* backend = nullptr;
diff --git a/models/CMakeLists.txt b/models/CMakeLists.txt
index d1afb96..a5ad052 100644
--- a/models/CMakeLists.txt
+++ b/models/CMakeLists.txt
@@ -14,6 +14,13 @@ file(DOWNLOAD
EXPECTED_HASH "SHA256=7b5397a2c98d66677f8f74317774bbeac49dbb321b8a3dc744af913db71d4fa5"
SHOW_PROGRESS
)
+message(STATUS "Checking for models/Depth-Anything-V2-Small-F16.gguf")
+file(DOWNLOAD
+ "https://huggingface.co/Acly/Depth-Anything-V2-GGUF/resolve/main/Depth-Anything-V2-Small-F16.gguf"
+ ${CMAKE_CURRENT_LIST_DIR}/Depth-Anything-V2-Small-F16.gguf
+ EXPECTED_HASH "SHA256=0f83332d6a8b4375cd7fdcc168f3e3636f474f8e84b0959e903f513aace782f5"
+ SHOW_PROGRESS
+)
message(STATUS "Checking for models/MIGAN-512-places2-F16.gguf")
file(DOWNLOAD
"https://huggingface.co/Acly/MIGAN-GGUF/resolve/main/MIGAN-512-places2-F16.gguf"
diff --git a/scripts/convert.py b/scripts/convert.py
index 054bf42..cc91d63 100644
--- a/scripts/convert.py
+++ b/scripts/convert.py
@@ -93,6 +93,14 @@ def add_conv2d_weight_indices(self):
self.add_array(f"{self.arch}.conv2d_weights", self.conv2d_weights)
+def load_model(path: Path) -> dict[str, Tensor]:
+ if path.suffix in [".safetensors", ".safetensor"]:
+ weights = safetensors.safe_open(path, "pt")
+ return {k: weights.get_tensor(k) for k in weights.keys()}
+ else:
+ return torch.load(path, map_location="cpu", weights_only=True)
+
+
batch_norm_eps = 1e-5
@@ -100,7 +108,7 @@ def is_conv_2d(name: str, tensor: Tensor):
return (
tensor.ndim == 4
and tensor.shape[2] == tensor.shape[3]
- and tensor.shape[2] in (1, 3, 4, 7)
+ and tensor.shape[2] in (1, 3, 4, 7, 14)
and name.endswith("weight")
)
@@ -341,6 +349,60 @@ def convert_birefnet(input_filepath: Path, writer: Writer):
writer.add_tensor(name, tensor)
+#
+# Depth-Anything
+
+
+def convert_depth_anything(input_filepath: Path, writer: Writer):
+ if "small" in input_filepath.name.lower():
+ writer.add_license("apache-2.0")
+ else:
+ writer.add_license("cc-by-nc-4.0")
+ writer.set_tensor_layout_default(TensorLayout.nchw)
+
+ model: dict[str, Tensor] = load_model(input_filepath)
+
+ if "pretrained.cls_token" in model:
+ print("The converter is written for the transformers (.safetensors) version of the model.")
+ print("The original weights (.pth) are currently not supported.")
+ raise ValueError("Weights not supported")
+
+ shape = model["backbone.embeddings.patch_embeddings.projection.weight"].shape
+ writer.add_int32("dino.patch_size", shape[2])
+ writer.add_int32("dino.embed_dim", shape[0])
+ writer.add_int32("depthanything.image_size", 518)
+ match shape[0]:
+ case 384: # Small
+ writer.add_int32("dino.n_heads", 6)
+ writer.add_int32("dino.n_layers", 12)
+ writer.add_array("depthanything.feature_layers", [2, 5, 8, 11])
+ case 768: # Base
+ writer.add_int32("dino.n_heads", 12)
+ writer.add_int32("dino.n_layers", 12)
+ writer.add_array("depthanything.feature_layers", [2, 5, 8, 11])
+ case 1024: # Large
+ writer.add_int32("dino.n_heads", 16)
+ writer.add_int32("dino.n_layers", 24)
+ writer.add_array("depthanything.feature_layers", [4, 11, 17, 23])
+
+ for key, tensor in model.items():
+ name = key
+
+ if is_conv_2d(name, tensor):
+ if "patch_embeddings" in name or ("projection" in name and "fusion" not in name):
+ tensor = conv_2d_to_nhwc(tensor)
+ elif "0.resize" in name or "1.resize" in name:
+ pass # ConvTranspose2D, don't change layout
+ else:
+ tensor = writer.convert_tensor_2d(tensor)
+
+ if "position_embeddings" in name or "cls_token" in name:
+ writer.add_tensor(name, tensor, "f32")
+ continue
+
+ writer.add_tensor(name, tensor)
+
+
#
# MI-GAN
@@ -400,6 +462,7 @@ def convert_esrgan(input_filepath: Path, writer: Writer):
arch_names = {
"sam": "mobile-sam",
"birefnet": "birefnet",
+ "depth-anything": "depthanything",
"migan": "migan",
"esrgan": "esrgan",
}
@@ -448,6 +511,8 @@ def convert_esrgan(input_filepath: Path, writer: Writer):
convert_sam(input_path, writer)
case "birefnet":
convert_birefnet(input_path, writer)
+ case "depthany" | "depth-anything":
+ convert_depth_anything(input_path, writer)
case "migan":
convert_migan(input_path, writer)
case "esrgan":
diff --git a/src/cli/cli.cpp b/src/cli/cli.cpp
index 3e37434..fc7f2a1 100644
--- a/src/cli/cli.cpp
+++ b/src/cli/cli.cpp
@@ -13,7 +13,7 @@
namespace visp {
using std::filesystem::path;
-enum class cli_command { none, sam, birefnet, migan, esrgan };
+enum class cli_command { none, sam, birefnet, depth_anything, migan, esrgan };
struct cli_args {
cli_command command = cli_command::none;
@@ -38,6 +38,7 @@ Usage: vision-cli [options]
Commands:
sam - MobileSAM image segmentation
birefnet - BirefNet background removal
+ depthany - Depth-Anything depth estimation
migan - MI-GAN inpainting
esrgan - ESRGAN/Real-ESRGAN upscaling
@@ -119,6 +120,8 @@ cli_args cli_parse(int argc, char** argv) {
r.command = cli_command::sam;
} else if (arg1 == "birefnet") {
r.command = cli_command::birefnet;
+ } else if (arg1 == "depthany" || arg1 == "depth-anything") {
+ r.command = cli_command::depth_anything;
} else if (arg1 == "migan") {
r.command = cli_command::migan;
} else if (arg1 == "esrgan") {
@@ -162,6 +165,7 @@ cli_args cli_parse(int argc, char** argv) {
void run_sam(cli_args const&);
void run_birefnet(cli_args const&);
+void run_depth_anything(cli_args const&);
void run_migan(cli_args const&);
void run_esrgan(cli_args const&);
@@ -179,6 +183,7 @@ int main(int argc, char** argv) {
switch (args.command) {
case cli_command::sam: run_sam(args); break;
case cli_command::birefnet: run_birefnet(args); break;
+ case cli_command::depth_anything: run_depth_anything(args); break;
case cli_command::migan: run_migan(args); break;
case cli_command::esrgan: run_esrgan(args); break;
case cli_command::none: break;
@@ -266,6 +271,11 @@ std::tuple load_model_weights(
return {std::move(file), std::move(weights)};
}
+void print_model_flags(model_ref const& m) {
+ bool flash_attn = !!(m.flags & model_build_flag::flash_attention);
+ printf("- flash attention: %s\n", flash_attn ? "on" : "off");
+}
+
void compute_timed(compute_graph const& g, backend_device const& b) {
timer t;
printf("Running inference... ");
@@ -409,6 +419,7 @@ void run_birefnet(cli_args const& args) {
compute_graph graph = compute_graph_init(6 * 1024);
model_ref m(weights, graph);
+ print_model_flags(m);
birefnet_buffers buffers = birefnet_precompute(m, params);
tensor input = compute_graph_input(m, GGML_TYPE_F32, {3, extent[0], extent[1], 1});
@@ -432,6 +443,42 @@ void run_birefnet(cli_args const& args) {
composite_image_with_mask(image, mask_resized, args.composite);
}
+//
+// Depth Anything
+
+void run_depth_anything(cli_args const& args) {
+ backend_device backend = backend_init(args);
+ auto [file, weights] = load_model_weights(
+ args, backend, "models/Depth-Anything-V2-Small-F16.gguf", 0, backend.preferred_layout());
+
+ require_inputs(args.inputs, 1, "");
+ image_data image = image_load(args.inputs[0]);
+ depthany_params params = depthany_detect_params(file, image.extent);
+ image_data input_data = depthany_process_input(image, params);
+
+ i32x2 extent = params.image_extent;
+ printf("- model image size: %d\n", params.image_size);
+ printf("- inference image size: %dx%d\n", params.image_extent[0], params.image_extent[1]);
+
+ compute_graph graph = compute_graph_init();
+ model_ref m(weights, graph);
+ print_model_flags(m);
+
+ tensor input = compute_graph_input(m, GGML_TYPE_F32, {3, extent[0], extent[1], 1});
+ tensor output = depthany_predict(m, input, params);
+
+ compute_graph_allocate(graph, backend);
+ transfer_to_backend(input, input_data);
+
+ compute_timed(graph, backend);
+
+ tensor_data output_data = transfer_from_backend(output);
+ image_data depth_raw = depthany_process_output(output_data.as_f32(), image.extent, params);
+ image_data depth_image = image_f32_to_u8(depth_raw, image_format::alpha_u8);
+ image_save(depth_image, args.output);
+ printf("-> depth image saved to %s\n", args.output);
+}
+
//
// MI-GAN
diff --git a/src/util/math.h b/src/util/math.h
index 835229d..ed4dd24 100644
--- a/src/util/math.h
+++ b/src/util/math.h
@@ -57,7 +57,12 @@ constexpr i32x2 operator/(i32x2 a, int32_t b) { return {a[0] / b, a[1] / b}; }
constexpr i32x2 div_ceil(i32x2 a, i32x2 b) { return {div_ceil(a[0], b[0]), div_ceil(a[1], b[1])}; }
constexpr i32x2 div_ceil(i32x2 a, int32_t b) { return div_ceil(a, i32x2{b, b}); }
+constexpr i32x2 next_multiple(i32x2 x, int32_t mult) { return div_ceil(x, mult) * mult; }
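+// e.g. next_multiple({30, 28}, 14) == {42, 28}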
constexpr i32x2 min(i32x2 a, i32x2 b) { return {std::min(a[0], b[0]), std::min(a[1], b[1])}; }
+// i64x2 operations
+constexpr i64x2 operator*(i64x2 a, int64_t b) { return {a[0] * b, a[1] * b}; }
+constexpr i64x2 operator/(i64x2 a, int64_t b) { return {a[0] / b, a[1] / b}; }
+
// clang-format on
} // namespace visp
\ No newline at end of file
diff --git a/src/visp/CMakeLists.txt b/src/visp/CMakeLists.txt
index 5cdbd54..14d7964 100644
--- a/src/visp/CMakeLists.txt
+++ b/src/visp/CMakeLists.txt
@@ -2,9 +2,12 @@ add_library(visioncpp SHARED)
target_sources(visioncpp PRIVATE
arch/birefnet.cpp
+ arch/depth-anything.cpp
+ arch/dino.cpp
arch/esrgan.cpp
arch/migan.cpp
arch/mobile-sam.cpp
+ arch/swin.cpp
image.cpp
ml.cpp
nn.cpp
diff --git a/src/visp/arch/birefnet.cpp b/src/visp/arch/birefnet.cpp
index b294b87..37915db 100644
--- a/src/visp/arch/birefnet.cpp
+++ b/src/visp/arch/birefnet.cpp
@@ -1,6 +1,7 @@
#include "visp/arch/birefnet.h"
#include "util/math.h"
#include "util/string.h"
+#include "visp/arch/swin.h"
#include "visp/nn.h"
#include "visp/vision.h"
@@ -9,290 +10,9 @@
namespace visp {
namespace birefnet {
-tensor mlp(model_ref m, tensor x) {
- x = linear(m["fc1"], x);
- x = ggml_gelu_inplace(m, x);
- x = linear(m["fc2"], x);
- return named(m, x);
-}
-
-// Ensures that the tensor's data is not overwritten during computation.
-tensor make_constant(tensor x, tensor_name name) {
- ggml_set_name(x, name.c_str());
- ggml_set_input(x); // allocate at the beginning of the graph buffer
- ggml_set_output(x); // don't reuse memory for computations
- return x;
-}
-
-void compute_relative_position_index(span dst, int window_size) {
- int n = window_size;
- int n2 = n * n;
- int n4 = n2 * n2;
- for (int i = 0; i < n4; ++i) {
- int x0 = i % n;
- int y0 = (i / n) % n;
- int x1 = (i / n2) % n;
- int y1 = (i / n2 / n) % n;
- dst[i] = (y1 - y0 + n - 1) * (2 * n - 1) + (x1 - x0 + n - 1);
- }
-}
-
-tensor_data create_relative_position_index(ggml_context* ctx, int window_size) {
- int n = window_size;
- auto result = tensor_alloc(ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n * n * n * n));
- auto name = format("window_attention_{}.rel_pos_index", n);
- compute_relative_position_index(result.as_i32(), n);
- make_constant(result.x, name);
- return result;
-}
-
-tensor window_partition(model_ref m, tensor x, int window) {
- auto [c, w, h, b] = nelements(x);
- ASSERT(w % window == 0 && h % window == 0, "Expecting padded input");
-
- x = ggml_reshape_4d(m, x, c * window, w / window, window, (h / window) * b);
- x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
- x = ggml_reshape_3d(m, x, c, window * window, (w / window) * (h / window) * b);
- return x;
-}
-
-tensor window_reverse(model_ref m, tensor x, int64_t w, int64_t h, int window) {
- int64_t c = x->ne[0];
- int64_t b = x->ne[2] / (w / window) / (h / window);
- ASSERT(x->ne[2] % (w / window) == 0, "Expecting ne[2] to be multiple of window count");
-
- x = ggml_reshape_4d(m, x, c * window, window, w / window, (h / window) * b);
- x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
- x = ggml_reshape_4d(m, x, c, w, h, b);
- return x;
-}
-
-tensor window_attention(model_ref m, tensor x, tensor mask, int num_heads, int window) {
- auto [c, n, b, _] = nelements(x);
-
- tensor qkv = linear(m["qkv"], x);
- qkv = ggml_reshape_4d(m, qkv, c / num_heads, num_heads, 3, n * b);
- qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2));
-
- auto split = [=](tensor tensor, size_t index, bool transpose = false) mutable {
- tensor = slice(m, tensor, {}, {}, {}, index);
- tensor = ggml_reshape_4d(m, tensor, c / num_heads, num_heads, n, b);
- if (transpose) {
- tensor = ggml_cont(m, ggml_permute(m, tensor, 1, 2, 0, 3));
- } else {
- tensor = ggml_cont(m, ggml_permute(m, tensor, 0, 2, 1, 3));
- }
- return tensor;
- };
- tensor q = split(qkv, 0);
- tensor k = split(qkv, 1);
- tensor v = split(qkv, 2, true);
-
- q = ggml_scale_inplace(m, q, 1.0f / std::sqrt(float(c / num_heads)));
-
- tensor attn = ggml_mul_mat(m, k, q);
-
- tensor_name rel_pos_name = format("window_attention_{}.rel_pos_index", window);
- tensor rel_pos_index = ggml_get_tensor(m, rel_pos_name.c_str());
- tensor rel_pos_table = m.weights("relative_position_bias_table");
- tensor rel_pos_bias = ggml_get_rows(m, rel_pos_table, rel_pos_index);
- rel_pos_bias = ggml_reshape_4d(m, rel_pos_bias, num_heads, window * window, window * window, 1);
- rel_pos_bias = ggml_cont(m, ggml_permute(m, rel_pos_bias, 2, 0, 1, 3));
- attn = ggml_add_inplace(m, attn, rel_pos_bias);
-
- if (mask) {
- int64_t nw = mask->ne[2];
- attn = ggml_reshape_4d(m, attn, n * n, num_heads, nw, b / nw);
- mask = ggml_reshape_4d(m, mask, n * n, 1, nw, 1);
- attn = ggml_add_inplace(m, attn, mask);
- attn = ggml_reshape_4d(m, attn, n, n, num_heads, b);
- }
- attn = ggml_soft_max(m, attn);
-
- x = ggml_mul_mat(m, v, attn);
- x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
- x = ggml_reshape_3d(m, x, c, n, b);
-
- x = linear(m["proj"], x);
- return named(m, x);
-}
-
-tensor swin_block(model_ref m, tensor x, tensor mask, swin_block_params const& p) {
- auto [c, n, b, _] = nelements(x);
- auto [num_heads, window, w, h, shift] = p;
- ASSERT(n == w * h && "Spatial dimensions do not match");
-
- tensor shortcut = x;
- x = layer_norm(m["norm1"], x);
- x = ggml_reshape_4d(m, x, c, w, h, b);
-
- int pad_r = (window - w % window) % window;
- int pad_b = (window - h % window) % window;
- if (pad_r > 0 || pad_b > 0) {
- x = ggml_pad(m, x, 0, pad_r, pad_b, 0);
- }
-
- ASSERT(shift == 0 || mask != nullptr);
- if (shift > 0) {
- x = ggml_roll(m, x, 0, -shift, -shift, 0);
- }
-
- x = window_partition(m, x, window);
- x = window_attention(m["attn"], x, mask, num_heads, window);
- x = window_reverse(m, x, w + pad_r, h + pad_b, window);
-
- if (shift > 0) { // undo shift
- x = ggml_roll(m, x, 0, shift, shift, 0);
- }
-
- if (pad_r > 0 || pad_b > 0) { // undo padding
- x = ggml_reshape_4d(m, x, c, w + pad_r, h + pad_b, b);
- x = slice(m, x, {}, {0, w}, {0, h}, {});
- x = ggml_cont(m, x);
- }
-
- x = ggml_reshape_3d(m, x, c, n, b);
- x = ggml_add_inplace(m, x, shortcut);
-
- tensor x_mlp = layer_norm(m["norm2"], x);
- x_mlp = mlp(m["mlp"], x_mlp);
- x = ggml_add_inplace(m, x, x_mlp);
-
- return named(m, x);
-}
-
-tensor patch_merging(model_ref m, tensor x, int64_t w, int64_t h) {
- auto [c, n, b, _] = nelements(x);
- ASSERT(n == w * h, "Spatial dimensions do not match");
- ASSERT(w % 2 == 0 && h % 2 == 0, "Expecting even spatial dimensions");
-
- x = ggml_reshape_4d(m, x, c, w, h, b);
- // clang-format off
- x = concat(m, {
- slice(m, x, {}, {0, w, 2}, {0, h, 2}, {}),
- slice(m, x, {}, {0, w, 2}, {1, h, 2}, {}),
- slice(m, x, {}, {1, w, 2}, {0, h, 2}, {}),
- slice(m, x, {}, {1, w, 2}, {1, h, 2}, {})}, 0);
- // clang-format on
- x = ggml_reshape_3d(m, x, c * 4, n / 4, b);
-
- x = layer_norm(m["norm"], x);
- x = linear(m["reduction"], x);
- return named(m, x);
-}
-
-void compute_attention_mask(span out, int64_t w, int64_t h, int window_size) {
- int n = window_size;
- int n2 = n * n;
- int n4 = n2 * n2;
- int shift = window_size / 2;
- int64_t nw_x = (w + n - 1) / n;
- int64_t nw_y = (h + n - 1) / n;
- int64_t w_pad = nw_x * n;
- int64_t h_pad = nw_y * n;
-
- std::fill(out.begin(), out.end(), 0.0f);
-
- for (int iw_y = 0; iw_y < nw_y; ++iw_y) {
- for (int iw_x = 0; iw_x < nw_x; ++iw_x) {
- // Skip all windows that aren't at the right or bottom edges of the image
- if (iw_y < nw_y - 1 && iw_x < nw_x - 1) {
- continue;
- }
- int64_t base = iw_y * nw_x * n4 + iw_x * n4;
-
- for (int y0 = 0; y0 < n; ++y0) {
- for (int x0 = 0; x0 < n; ++x0) {
- for (int y1 = 0; y1 < n; ++y1) {
- for (int x1 = 0; x1 < n; ++x1) {
- // Window-local coordinates to global image coordinates
- int yy0 = iw_y * n + y0;
- int xx0 = iw_x * n + x0;
- int yy1 = iw_y * n + y1;
- int xx1 = iw_x * n + x1;
- // Check if two patches being matched belong to the same window
- // that is: they are both in the shift zone, or both outside
- bool match_y = (yy0 < h_pad - shift) == (yy1 < h_pad - shift);
- bool match_x = (xx0 < w_pad - shift) == (xx1 < w_pad - shift);
- // If not, set mask to -100 (added to attention before softmax)
- if (!match_y || !match_x) {
- int64_t idx = base + (y0 * n + x0) * n2 + (y1 * n + x1);
- out[idx] = -100.f;
- }
- }
- }
- }
- }
- }
- }
-}
-
-tensor_data create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int window_size) {
- int n = window_size;
- int64_t nw_x = (w + n - 1) / n;
- int64_t nw_y = (h + n - 1) / n;
- auto result = tensor_alloc(ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n * n, n * n, nw_x * nw_y));
- auto name = format("swin_layer_{}x{}.attn_mask", w, h);
- compute_attention_mask(result.as_f32(), w, h, window_size);
- make_constant(result.x, name);
- return result;
-}
-
-swin_layer_result swin_layer(
- model_ref m, tensor x, int64_t w, int64_t h, swin_layer_t const& p, int window_size) {
- // Attention masks need to be precomputed
- tensor_name attn_mask_name = format("swin_layer_{}x{}.attn_mask", w, h);
- tensor attn_mask = ggml_get_tensor(m, attn_mask_name.c_str());
-
- model_ref blocks = m["blocks"];
- for (int i = 0; i < p.depth; ++i) {
- x = swin_block(
- blocks[i], x, attn_mask,
- {.n_heads = p.n_heads,
- .window_size = window_size,
- .w = w,
- .h = h,
- .shift = i % 2 == 0 ? 0 : window_size / 2});
- }
- if (p.downsample) {
- tensor x_down = patch_merging(m["downsample"], x, w, h);
- return {x, w, h, x_down, (w + 1) / 2, (h + 1) / 2};
- }
- return {x, w, h, x, w, h};
-}
-
-tensor patch_embed(model_ref m, tensor x, int patch_size) {
- ASSERT(x->ne[1] % patch_size == 0 && x->ne[2] % patch_size == 0);
-
- m.flags |= model_build_flag::cwhn;
- x = conv_2d(m["proj"], x, patch_size);
- auto [c, ww, wh, b] = nelements(x);
- x = ggml_reshape_3d(m, x, c, ww * wh, b);
- x = layer_norm(m["norm"], x);
- x = ggml_reshape_4d(m, x, c, ww, wh, b);
- return named(m, x);
-}
-
-swin_result swin_transformer(model_ref m, tensor x, swin_params const& p) {
- x = patch_embed(m["patch_embed"], x, 4);
-
- auto [c, w, h, b] = nelements(x);
- x = ggml_reshape_3d(m, x, c, w * h, b);
-
- swin_layer_result r{x, w, h, x, w, h};
- swin_result outs = {};
-
- for (int i = 0; i < swin_params::n_layers; ++i) {
- model_ref layer = m["layers"][i];
- r = swin_layer(layer, r.x_down, r.w_down, r.h_down, p.layers[i], p.window_size);
-
- tensor_name norm_layer = format("norm{}", i);
- tensor out = layer_norm(m[norm_layer], r.x_out);
- out = ggml_reshape_4d(m, out, p.layers[i].n_features, r.w_out, r.h_out, b);
- outs[i] = out;
- }
- return outs;
-}
+//
+// Encoder
+//
constexpr int32_t bilinear_align_corners = GGML_SCALE_MODE_BILINEAR |
(int)GGML_SCALE_FLAG_ALIGN_CORNERS;
@@ -345,9 +65,9 @@ swin_result encode_concat(model_ref m, swin_result& xs, swin_result& xs_low) {
}
swin_result encode(model_ref m, tensor x, swin_params const& p) {
- auto xs = swin_transformer(m["bb"], x, p);
+ auto xs = swin_encode(m["bb"], x, p);
auto x_low = downscale_by(m, x, 2);
- auto xs_low = swin_transformer(m["bb"], x_low, p);
+ auto xs_low = swin_encode(m["bb"], x_low, p);
encode_concat(m, xs, xs_low);
return xs;
}
@@ -531,7 +251,7 @@ tensor decode(model_ref m, tensor x, swin_result const& features) {
tensor birefnet_predict(model_ref m, tensor image, birefnet_params const& p) {
// Encoder
- birefnet::swin_result features = birefnet::encode(m, image, p.encoder);
+ swin_result features = birefnet::encode(m, image, p.encoder);
// Squeeze block
features[3] = birefnet::basic_decoder_block(m["squeeze_module.0"], features[3]);
// Decoder
@@ -565,52 +285,6 @@ image_data birefnet_process_output(
return image_f32_to_u8(mask_output, image_format::alpha_u8);
}
-birefnet_buffers birefnet_precompute(model_ref m, birefnet_params const& params) {
- int w = params.encoder.window_size;
- int width = params.image_extent[0] / 4;
- int height = params.image_extent[1] / 4;
-
- birefnet_buffers b;
- b[0] = birefnet::create_relative_position_index(m, w);
- for (int i = 0; i < swin_params::n_layers + 1; ++i) {
- b[i + 1] = birefnet::create_attention_mask(m, width >> i, height >> i, w);
- }
- return b;
-}
-
-// clang-format off
-const swin_params swin_t_params = {
- .embed_dim = 96,
- .window_size = 7,
- .layers = {
- // depth n_heads n_features downsample
- swin_layer_t{2, 3, 96 * 1, true},
- swin_layer_t{2, 6, 96 * 2, true},
- swin_layer_t{6, 12, 96 * 4, true},
- swin_layer_t{2, 24, 96 * 8, false}}};
-
-const swin_params swin_l_params = {
- .embed_dim = 192,
- .window_size = 12,
- .layers = {
- // depth n_heads n_features downsample
- swin_layer_t{2, 6, 192 * 1, true},
- swin_layer_t{2, 12, 192 * 2, true},
- swin_layer_t{18, 24, 192 * 4, true},
- swin_layer_t{2, 48, 192 * 8, false}}};
-// clang-format on
-
-swin_params swin_detect_params(model_file const& f) {
- int embed_dim = f.get_int("swin.embed_dim");
- if (embed_dim == 96) {
- return swin_t_params;
- } else if (embed_dim == 192) {
- return swin_l_params;
- } else {
- throw except("Unsupported Swin Transformer embed dim: {}", embed_dim);
- }
-}
-
i32x2 birefnet_image_extent(i32x2 input_extent, birefnet_params const& p, size_t max_alloc) {
i32x2 extent{p.image_size, p.image_size};
if (p.image_size == -1) {
@@ -644,4 +318,8 @@ birefnet_params birefnet_detect_params(
return p;
}
+birefnet_buffers birefnet_precompute(model_ref m, birefnet_params const& p) {
+ return swin_precompute(m, p.image_extent, p.encoder);
+}
+
} // namespace visp
diff --git a/src/visp/arch/birefnet.h b/src/visp/arch/birefnet.h
index 7f109ad..90d855c 100644
--- a/src/visp/arch/birefnet.h
+++ b/src/visp/arch/birefnet.h
@@ -1,50 +1,10 @@
#pragma once
-#include "visp/ml.h"
#include "visp/image.h"
+#include "visp/ml.h"
+#include "visp/vision.h"
-#include
-
-namespace visp {
-
-namespace birefnet {
-
-// SWIN Transformer
-
-struct swin_block_params {
- int n_heads = 6;
- int window_size = 7;
- int64_t w = 0;
- int64_t h = 0;
- int shift = 0;
-};
-
-struct swin_layer_result {
- tensor x_out;
- int64_t w_out;
- int64_t h_out;
- tensor x_down;
- int64_t w_down;
- int64_t h_down;
-};
-
-using swin_result = std::array;
-
-void compute_relative_position_index(span dst, int window_size);
-tensor_data create_relative_position_index(ggml_context* ctx, int window_size);
-void compute_attention_mask(std::span out, int64_t w, int64_t h, int window_size);
-tensor_data create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int window_size);
-
-tensor mlp(model_ref m, tensor x);
-tensor patch_merging(model_ref m, tensor x, int64_t w, int64_t h);
-tensor patch_embed(model_ref m, tensor x, int patch_size = 4);
-tensor window_partition(model_ref m, tensor x, int window);
-tensor window_reverse(model_ref m, tensor x, int w, int h, int window);
-tensor window_attention(model_ref m, tensor x, tensor mask, int num_heads, int window);
-tensor swin_block(model_ref m, tensor x, tensor mask, swin_block_params const&);
-swin_layer_result swin_layer(
- model_ref m, tensor x, int64_t w, int64_t h, swin_layer_t const&, int window_size);
-swin_result swin_transformer(model_ref m, tensor x, swin_params const& p);
+namespace visp::birefnet {
// Encoder
@@ -64,5 +24,4 @@ tensor image_to_patches(model_ref m, tensor x, int64_t out_w, int64_t out_h);
tensor gdt_conv(model_ref m, tensor x);
tensor decode(model_ref m, tensor x, swin_result const& features);
-} // namespace birefnet
-} // namespace visp
\ No newline at end of file
+} // namespace visp::birefnet
\ No newline at end of file
diff --git a/src/visp/arch/depth-anything.cpp b/src/visp/arch/depth-anything.cpp
new file mode 100644
index 0000000..22a4127
--- /dev/null
+++ b/src/visp/arch/depth-anything.cpp
@@ -0,0 +1,151 @@
+
+#include "visp/arch/depth-anything.h"
+#include "util/math.h"
+#include "util/string.h"
+#include "visp/arch/dino.h"
+#include "visp/ml.h"
+#include "visp/nn.h"
+
+namespace visp {
+namespace dpt {
+
+int32_t const bilinear_align_corners = int32_t(GGML_SCALE_MODE_BILINEAR) |
+ GGML_SCALE_FLAG_ALIGN_CORNERS;
+
+tensor residual_conv(model_ref m, tensor x) {
+ tensor out = x;
+ out = ggml_relu(m, out);
+ out = conv_2d(m["convolution1"], out, 1, 1);
+ out = ggml_relu(m, out);
+ out = conv_2d(m["convolution2"], out, 1, 1);
+ x = ggml_add_inplace(m, x, out);
+ return named(m, x);
+}
+
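+// Fuses the previously fused features (x0) with an optional lateral encoder feature (x1) through
+// residual conv blocks, then upsamples to `size` (or 2x when no size is given) and projects.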
+tensor feature_fusion(model_ref m, tensor x0, tensor x1, int64_t const* size) {
+ tensor x = x0;
+ if (x1) {
+ tensor res = residual_conv(m["residual_layer1"], x1);
+ x = ggml_add_inplace(m, x, res);
+ }
+ x = residual_conv(m["residual_layer2"], x);
+
+ int const dim = m.flags & model_build_flag::cwhn ? 1 : 0;
+ int64_t w = size ? size[dim + 0] : x->ne[dim + 0] * 2;
+ int64_t h = size ? size[dim + 1] : x->ne[dim + 1] * 2;
+ x = contiguous_2d_to_whcn(m, x);
+ x = interpolate(m, x, {w, h}, bilinear_align_corners);
+ x = whcn_to_contiguous_2d(m, x);
+
+ x = conv_2d(m["projection"], x);
+ return named(m, x);
+}
+
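+// DPT-style neck: reassembles the four DINO token outputs into 2D feature maps at different
+// scales (4x/2x upsample, identity, 2x downsample), then fuses them coarse-to-fine.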
+tensor neck(model_ref m, span features, int64_t patch_w, int64_t patch_h) {
+ ASSERT(features.size() == 4);
+ std::array<tensor, 4> layer;
+
+ model_ref reassemble = m["reassemble_stage.layers"];
+ for (int i = 0; i < 4; ++i) {
+ tensor x = features[i];
+ x = slice(m, x, {}, {1, x->ne[1]}, {}, {});
+ x = ggml_reshape_4d(m, x, x->ne[0], patch_w, patch_h, x->ne[3]);
+
+ model_ref proj = reassemble[i]["projection"];
+ proj.flags |= model_build_flag::cwhn;
+ x = conv_2d(proj, x); // 1x1 conv, keep CWHN layout and directly use mul_mat
+
+ x = cwhn_to_contiguous_2d(m, x);
+ switch (i) {
+ case 0: x = conv_transpose_2d(reassemble[i]["resize"], x, 4); break;
+ case 1: x = conv_transpose_2d(reassemble[i]["resize"], x, 2); break;
+ case 3: x = conv_2d(reassemble[i]["resize"], x, 2, 1); break;
+ }
+ layer[i] = x;
+ }
+
+ model_ref convs = m["convs"];
+ for (int i = 0; i < 4; ++i) {
+ layer[i] = conv_2d(convs[i], layer[i], 1, 1);
+ }
+
+ model_ref fusion = m["fusion_stage.layers"];
+ tensor fused;
+ fused = feature_fusion(fusion[0], layer[3], nullptr, layer[2]->ne);
+ fused = feature_fusion(fusion[1], fused, layer[2], layer[1]->ne);
+ fused = feature_fusion(fusion[2], fused, layer[1], layer[0]->ne);
+ fused = feature_fusion(fusion[3], fused, layer[0]);
+ return fused;
+}
+
+tensor head(model_ref m, tensor x, int64_t w, int64_t h, float max_depth) {
+ tensor out = conv_2d(m["conv1"], x, 1, 1);
+ out = contiguous_2d_to_whcn(m, out);
+ out = interpolate(m, out, {w, h}, bilinear_align_corners);
+ out = whcn_to_contiguous_2d(m, out);
+
+ out = conv_2d(m["conv2"], out, 1, 1);
+ out = ggml_relu_inplace(m, out);
+ out = conv_2d(m["conv3"], out);
+ out = ggml_relu_inplace(m, out);
+
+ if (max_depth != 1) {
+ out = ggml_scale(m, out, max_depth);
+ }
+ return out;
+}
+
+} // namespace dpt
+
+tensor depthany_predict(model_ref m, tensor image, depthany_params const& p) {
+ auto [c, w, h, n] = nelements(image);
+ int64_t w_patch = w / p.dino.patch_size;
+ int64_t h_patch = h / p.dino.patch_size;
+
+ auto features = dino_get_intermediate_layers(m["backbone"], image, p.feature_layers, p.dino);
+ tensor fused = dpt::neck(m["neck"], features, w_patch, h_patch);
+ tensor depth = dpt::head(m["head"], fused, w, h, p.max_depth);
+
+ return compute_graph_output(m, depth);
+}
+
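+// Scale the input extent so its shorter side becomes at least `image_size` (keeping the aspect
+// ratio roughly intact), then round both sides up to multiples of `image_multiple`.
+// Illustration with the defaults (518, 14): a 480x640 input maps to 518x700.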
+i32x2 depthany_image_extent(i32x2 extent, depthany_params const& p) {
+ int min_side = std::min(extent[0], extent[1]);
+ int tgt_side = std::max(p.image_size, next_multiple(min_side, p.image_multiple));
+ i32x2 target = extent * tgt_side / min_side;
+ return next_multiple(target, p.image_multiple);
+}
+
+depthany_params depthany_detect_params(model_file const& file, i32x2 input_extent) {
+ depthany_params p;
+ p.dino = dino_detect_params(file);
+ p.image_size = file.get_int("depthanything.image_size");
+ file.get_array("depthanything.feature_layers", p.feature_layers);
+ if (input_extent[0] > 0 && input_extent[1] > 0) {
+ p.image_extent = depthany_image_extent(input_extent, p);
+ }
+ return p;
+}
+
+image_data depthany_process_input(image_view image, depthany_params const& p) {
+ constexpr f32x4 mean = f32x4{0.485f, 0.456f, 0.406f, 0.f};
+ constexpr f32x4 std = f32x4{0.229f, 0.224f, 0.225f, 1.f};
+
+ image_data resized;
+ if (image.extent != p.image_extent) {
+ resized = image_scale(image, p.image_extent);
+ image = image_view(resized);
+ }
+ return image_u8_to_f32(image, image_format::rgb_f32, -mean, 1.f / std);
+}
+
+image_data depthany_process_output(span data, i32x2 extent, depthany_params const& p) {
+ image_view depth_output(p.image_extent, data);
+ image_data normalized = image_normalize(depth_output);
+ if (normalized.extent != extent) {
+ return image_scale(normalized, extent);
+ }
+ return normalized;
+}
+
+} // namespace visp
\ No newline at end of file
diff --git a/src/visp/arch/depth-anything.h b/src/visp/arch/depth-anything.h
new file mode 100644
index 0000000..cc8a0c3
--- /dev/null
+++ b/src/visp/arch/depth-anything.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "visp/ml.h"
+#include "visp/vision.h"
+
+namespace visp::dpt {
+
+tensor residual_conv(model_ref m, tensor x);
+tensor feature_fusion(model_ref m, tensor x0, tensor x1, int64_t const* size = nullptr);
+tensor neck(model_ref m, span features, int64_t patch_w, int64_t patch_h);
+tensor head(model_ref m, tensor fused, int64_t w, int64_t h, float max_depth);
+
+} // namespace visp::dpt
diff --git a/src/visp/arch/dino.cpp b/src/visp/arch/dino.cpp
new file mode 100644
index 0000000..a1717c4
--- /dev/null
+++ b/src/visp/arch/dino.cpp
@@ -0,0 +1,147 @@
+#include "visp/arch/dino.h"
+#include "util/math.h"
+#include "visp/ml.h"
+#include "visp/nn.h"
+
+namespace visp {
+namespace dino {
+
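+// Adapts the learned position embeddings (stored for a square training-time patch grid) to the
+// current patch grid via bicubic interpolation; the class-token embedding passes through unchanged.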
+tensor interpolate_pos_encoding(model_ref m, tensor x, int64_t w, int64_t h, int patch_size) {
+ tensor pos_embed = m.weights("position_embeddings");
+ int64_t n_patch = x->ne[1] - 1;
+ int64_t n = pos_embed->ne[1] - 1;
+ if (n_patch == n && w == h) {
+ return pos_embed;
+ }
+
+ tensor class_embed = slice(m, pos_embed, {}, {0}, {}, {});
+ tensor patch_embed = slice(m, pos_embed, {}, {1, n + 1}, {}, {});
+ int64_t dim = x->ne[0];
+ i64x2 target = i64x2{w, h} / patch_size;
+ int64_t sqrt_n = int64_t(std::sqrt(float(n)) + 0.01f);
+
+ patch_embed = ggml_reshape_4d(m, patch_embed, dim, sqrt_n, sqrt_n, 1);
+ patch_embed = ggml_cont(m, permute_cwhn_to_whcn(m, patch_embed));
+ patch_embed = interpolate(m, patch_embed, target, GGML_SCALE_MODE_BICUBIC);
+ patch_embed = ggml_cont(m, permute_whcn_to_cwhn(m, patch_embed));
+ patch_embed = ggml_reshape_3d(m, patch_embed, dim, target[0] * target[1], 1);
+ return concat(m, {class_embed, patch_embed}, 1);
+}
+
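+// Converts the input image into the transformer token sequence: patch embedding, prepending the
+// class token, and adding the (resized) position encoding.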
+tensor prepare_tokens(model_ref m, tensor x, int patch_size) {
+ auto [c, w, h, n] = nelements(x);
+ x = patch_embed(m["patch_embeddings"], x, patch_size);
+ x = ggml_reshape_3d(m, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]);
+
+ tensor cls_token = m.weights("cls_token");
+ if (cls_token->ne[2] != n) {
+ cls_token = ggml_repeat_4d(m, cls_token, cls_token->ne[0], 1, n, 1);
+ }
+ x = concat(m, {cls_token, x}, 1);
+
+ tensor pos_enc = interpolate_pos_encoding(m, x, w, h, patch_size);
+ x = ggml_add_inplace(m, x, pos_enc);
+ return x;
+}
+
+tensor layer_scale(model_ref m, tensor x) {
+ return ggml_mul(m, x, m.weights("lambda1"));
+}
+
+tensor mlp(model_ref m, tensor x) {
+ x = linear(m["fc1"], x);
+ x = ggml_gelu(m, x);
+ x = linear(m["fc2"], x);
+ return x;
+}
+
+tensor attention(model_ref m, tensor x, int n_heads) {
+ auto [c, n, b, _] = nelements(x);
+ float scale = 1.0f / std::sqrt(float(c) / float(n_heads));
+ bool flash_attn = bool(m.flags & model_build_flag::flash_attention);
+ ggml_type kv_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+ auto split = [=](model_ref m, tensor x, ggml_type type, bool transpose = false) mutable {
+ x = linear(m, x);
+ x = ggml_reshape_4d(m, x, c / n_heads, n_heads, n, b);
+ x = transpose ? ggml_permute(m, x, 1, 2, 0, 3) : ggml_permute(m, x, 0, 2, 1, 3);
+ return ggml_cast(m, x, type);
+ };
+
+ tensor q = split(m["attention.query"], x, GGML_TYPE_F32);
+ tensor k = split(m["attention.key"], x, kv_type);
+ tensor v = split(m["attention.value"], x, kv_type, !flash_attn);
+
+ if (flash_attn) {
+ x = ggml_flash_attn_ext(m, q, k, v, nullptr, scale, 0.0f, 0.0f);
+ } else {
+ tensor attn = ggml_mul_mat(m, k, q);
+ attn = ggml_soft_max_ext(m, attn, nullptr, scale, 0.0f);
+
+ x = ggml_mul_mat(m, v, attn);
+ x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
+ }
+
+ x = ggml_reshape_3d(m, x, c, n, b);
+ x = linear(m["output.dense"], x);
+ return named(m, x);
+}
+
+tensor layer(model_ref m, tensor x, dino_params const& p) {
+ tensor attn = x;
+ attn = layer_norm(m["norm1"], attn, 1e-6f);
+ attn = attention(m["attention"], attn, p.n_heads);
+ attn = layer_scale(m["layer_scale1"], attn);
+ x = ggml_add(m, x, attn);
+
+ tensor ffn = x;
+ ffn = layer_norm(m["norm2"], ffn, 1e-6f);
+ ffn = mlp(m["mlp"], ffn);
+ ffn = layer_scale(m["layer_scale2"], ffn);
+ x = ggml_add(m, x, ffn);
+
+ return named(m, x);
+}
+
+template
+bool contains(std::span r, T const& value) {
+ return std::find(r.begin(), r.end(), value) != r.end();
+}
+
+std::vector get_intermediate_layers(
+ model_ref m, tensor x, std::span layers, dino_params const& p) {
+
+ x = prepare_tokens(m["embeddings"], x, p.patch_size);
+
+ std::vector outputs;
+ model_ref encoder = m["encoder.layer"];
+ for (int i = 0; i < p.n_layers; ++i) {
+ x = layer(encoder[i], x, p);
+
+ if (contains(layers, i)) {
+ tensor out = layer_norm(m["layernorm"], x, 1e-6f);
+ ggml_format_name(out, "dino_layer_%d", i);
+ ggml_build_forward_expand(m.graph, out);
+ outputs.push_back(out);
+ }
+ }
+ return outputs;
+}
+
+} // namespace dino
+
+std::vector dino_get_intermediate_layers(
+ model_ref m, tensor x, std::span layers, dino_params const& p) {
+ return dino::get_intermediate_layers(m, x, layers, p);
+}
+
+dino_params dino_detect_params(model_file const& file) {
+ dino_params p{};
+ p.patch_size = file.get_int("dino.patch_size");
+ p.embed_dim = file.get_int("dino.embed_dim");
+ p.n_heads = file.get_int("dino.n_heads");
+ p.n_layers = file.get_int("dino.n_layers");
+ return p;
+}
+
+} // namespace visp
diff --git a/src/visp/arch/dino.h b/src/visp/arch/dino.h
new file mode 100644
index 0000000..43d915b
--- /dev/null
+++ b/src/visp/arch/dino.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#include "util/math.h"
+#include "visp/ml.h"
+#include "visp/vision.h"
+
+#include
+
+namespace visp::dino {
+
+tensor interpolate_pos_encoding(model_ref m, tensor x, int64_t w, int64_t h, int patch_size);
+tensor prepare_tokens(model_ref m, tensor x, int patch_size);
+tensor layer_scale(model_ref m, tensor x);
+tensor mlp(model_ref m, tensor x);
+tensor attention(model_ref m, tensor x, int n_heads);
+tensor layer(model_ref m, tensor x, dino_params const& p);
+
+std::vector get_intermediate_layers(
+ model_ref m, tensor x, std::span layers, dino_params const& p);
+
+} // namespace visp::dino
diff --git a/src/visp/arch/swin.cpp b/src/visp/arch/swin.cpp
new file mode 100644
index 0000000..b46483d
--- /dev/null
+++ b/src/visp/arch/swin.cpp
@@ -0,0 +1,344 @@
+#include "visp/arch/swin.h"
+#include "util/string.h"
+#include "visp/nn.h"
+
+namespace visp {
+namespace swin {
+
+tensor mlp(model_ref m, tensor x) {
+ x = linear(m["fc1"], x);
+ x = ggml_gelu_inplace(m, x);
+ x = linear(m["fc2"], x);
+ return named(m, x);
+}
+
+// Ensures that the tensor's data is not overwritten during computation.
+tensor make_constant(tensor x, tensor_name name) {
+ ggml_set_name(x, name.c_str());
+ ggml_set_input(x); // allocate at the beginning of the graph buffer
+ ggml_set_output(x); // don't reuse memory for computations
+ return x;
+}
+
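+// For every pair of positions within a window, computes an index into the relative position bias
+// table: (dy + n-1) * (2n-1) + (dx + n-1), where dy/dx are the offsets between the two positions.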
+void compute_relative_position_index(span dst, int window_size) {
+ int n = window_size;
+ int n2 = n * n;
+ int n4 = n2 * n2;
+ for (int i = 0; i < n4; ++i) {
+ int x0 = i % n;
+ int y0 = (i / n) % n;
+ int x1 = (i / n2) % n;
+ int y1 = (i / n2 / n) % n;
+ dst[i] = (y1 - y0 + n - 1) * (2 * n - 1) + (x1 - x0 + n - 1);
+ }
+}
+
+tensor_data create_relative_position_index(ggml_context* ctx, int window_size) {
+ int n = window_size;
+ auto result = tensor_alloc(ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n * n * n * n));
+ auto name = format("window_attention_{}.rel_pos_index", n);
+ compute_relative_position_index(result.as_i32(), n);
+ make_constant(result.x, name);
+ return result;
+}
+
+tensor window_partition(model_ref m, tensor x, int window) {
+ auto [c, w, h, b] = nelements(x);
+ ASSERT(w % window == 0 && h % window == 0, "Expecting padded input");
+
+ x = ggml_reshape_4d(m, x, c * window, w / window, window, (h / window) * b);
+ x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
+ x = ggml_reshape_3d(m, x, c, window * window, (w / window) * (h / window) * b);
+ return x;
+}
+
+tensor window_reverse(model_ref m, tensor x, int64_t w, int64_t h, int window) {
+ int64_t c = x->ne[0];
+ int64_t b = x->ne[2] / (w / window) / (h / window);
+ ASSERT(x->ne[2] % (w / window) == 0, "Expecting ne[2] to be multiple of window count");
+
+ x = ggml_reshape_4d(m, x, c * window, window, w / window, (h / window) * b);
+ x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
+ x = ggml_reshape_4d(m, x, c, w, h, b);
+ return x;
+}
+
+tensor window_attention(model_ref m, tensor x, tensor mask, int n_heads, int window) {
+ auto [c, n, b, _] = nelements(x);
+ float scale = 1.0f / std::sqrt(float(c / n_heads));
+ bool flash_attn = bool(m.flags & model_build_flag::flash_attention);
+ ggml_type kv_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
+
+ tensor qkv = linear(m["qkv"], x);
+ qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b);
+ qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2));
+
+ auto split = [=](tensor t, size_t index, ggml_type type, bool transpose = false) mutable {
+ t = slice(m, t, {}, {}, {}, index);
+ t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b);
+ t = transpose ? ggml_permute(m, t, 1, 2, 0, 3) : ggml_permute(m, t, 0, 2, 1, 3);
+ t = ggml_cast(m, t, type); // TODO: future flash attention supports f32 and permutations
+ return t;
+ };
+ tensor q = split(qkv, 0, GGML_TYPE_F32);
+ tensor k = split(qkv, 1, kv_type);
+ tensor v = split(qkv, 2, kv_type, !flash_attn);
+
+ tensor_name rel_pos_name = format("window_attention_{}.rel_pos_index", window);
+ tensor rel_pos_index = ggml_get_tensor(m, rel_pos_name.c_str());
+ tensor rel_pos_table = m.weights("relative_position_bias_table");
+ tensor rel_pos_bias = ggml_get_rows(m, rel_pos_table, rel_pos_index);
+ rel_pos_bias = ggml_reshape_4d(m, rel_pos_bias, n_heads, n, n, 1);
+ rel_pos_bias = ggml_permute(m, rel_pos_bias, 2, 0, 1, 3); // [n, n, n_heads, 1]
+ rel_pos_bias = ggml_cast(m, rel_pos_bias, GGML_TYPE_F16); // get_rows result is always f32
+
+ tensor attn_mask = rel_pos_bias;
+ if (mask) {
+ int64_t n_windows = mask->ne[2];
+ if (b > n_windows) { // if there are multiple images in the batch
+ mask = ggml_reshape_4d(m, mask, n, n, n_windows, 1);
+ mask = ggml_repeat_4d(m, mask, n, n, n_windows, b / n_windows);
+ }
+ mask = ggml_reshape_4d(m, mask, n, n, 1, b);
+ mask = ggml_repeat_4d(m, mask, n, n, n_heads, b); // can only broadcast one operand in add
+ attn_mask = ggml_add(m, mask, attn_mask); // [n, n, n_heads, b] + [n, n, n_heads, 1]
+ }
+
+ if (flash_attn) {
+ x = ggml_flash_attn_ext(m, q, k, v, attn_mask, scale, 0.0f, 0.0f);
+ ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
+ } else {
+ tensor attn = ggml_mul_mat(m, k, q);
+ attn = ggml_soft_max_ext(m, attn, attn_mask, scale, 0.0f);
+
+ x = ggml_mul_mat(m, v, attn);
+ x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
+ }
+
+ x = ggml_reshape_3d(m, x, c, n, b);
+ x = linear(m["proj"], x);
+ return named(m, x);
+}
+
+tensor block(model_ref m, tensor x, tensor mask, block_params const& p) {
+ auto [c, n, b, _] = nelements(x);
+ auto [num_heads, window, w, h, shift] = p;
+ ASSERT(n == w * h && "Spatial dimensions do not match");
+
+ tensor shortcut = x;
+ x = layer_norm(m["norm1"], x);
+ x = ggml_reshape_4d(m, x, c, w, h, b);
+
+ int pad_r = (window - w % window) % window;
+ int pad_b = (window - h % window) % window;
+ if (pad_r > 0 || pad_b > 0) {
+ x = ggml_pad(m, x, 0, pad_r, pad_b, 0);
+ }
+
+ ASSERT(shift == 0 || mask != nullptr);
+ if (shift > 0) {
+ x = ggml_roll(m, x, 0, -shift, -shift, 0);
+ }
+
+ x = window_partition(m, x, window);
+ x = window_attention(m["attn"], x, mask, num_heads, window);
+ x = window_reverse(m, x, w + pad_r, h + pad_b, window);
+
+ if (shift > 0) { // undo shift
+ x = ggml_roll(m, x, 0, shift, shift, 0);
+ }
+
+ if (pad_r > 0 || pad_b > 0) { // undo padding
+ x = ggml_reshape_4d(m, x, c, w + pad_r, h + pad_b, b);
+ x = slice(m, x, {}, {0, w}, {0, h}, {});
+ x = ggml_cont(m, x);
+ }
+
+ x = ggml_reshape_3d(m, x, c, n, b);
+ x = ggml_add_inplace(m, x, shortcut);
+
+ tensor x_mlp = layer_norm(m["norm2"], x);
+ x_mlp = mlp(m["mlp"], x_mlp);
+ x = ggml_add_inplace(m, x, x_mlp);
+
+ return named(m, x);
+}
+
+tensor patch_merging(model_ref m, tensor x, int64_t w, int64_t h) {
+ auto [c, n, b, _] = nelements(x);
+ ASSERT(n == w * h, "Spatial dimensions do not match");
+ ASSERT(w % 2 == 0 && h % 2 == 0, "Expecting even spatial dimensions");
+
+ x = ggml_reshape_4d(m, x, c, w, h, b);
+ // clang-format off
+ x = concat(m, {
+ slice(m, x, {}, {0, w, 2}, {0, h, 2}, {}),
+ slice(m, x, {}, {0, w, 2}, {1, h, 2}, {}),
+ slice(m, x, {}, {1, w, 2}, {0, h, 2}, {}),
+ slice(m, x, {}, {1, w, 2}, {1, h, 2}, {})}, 0);
+ // clang-format on
+ x = ggml_reshape_3d(m, x, c * 4, n / 4, b);
+
+ x = layer_norm(m["norm"], x);
+ x = linear(m["reduction"], x);
+ return named(m, x);
+}
+
+constexpr uint16_t neg_inf_f16 = 0xfc00; // -infinity in IEEE 754 half-precision
+
+void compute_attention_mask(span out_bytes, int64_t w, int64_t h, int window_size) {
+ uint16_t* out = reinterpret_cast(out_bytes.data());
+ int n = window_size;
+ int n2 = n * n;
+ int n4 = n2 * n2;
+ int shift = window_size / 2;
+ int64_t nw_x = (w + n - 1) / n;
+ int64_t nw_y = (h + n - 1) / n;
+ int64_t w_pad = nw_x * n;
+ int64_t h_pad = nw_y * n;
+
+ std::memset(out, 0, out_bytes.size());
+
+ for (int iw_y = 0; iw_y < nw_y; ++iw_y) {
+ for (int iw_x = 0; iw_x < nw_x; ++iw_x) {
+ // Skip all windows that aren't at the right or bottom edges of the image
+ if (iw_y < nw_y - 1 && iw_x < nw_x - 1) {
+ continue;
+ }
+ int64_t base = iw_y * nw_x * n4 + iw_x * n4;
+
+ for (int y0 = 0; y0 < n; ++y0) {
+ for (int x0 = 0; x0 < n; ++x0) {
+ for (int y1 = 0; y1 < n; ++y1) {
+ for (int x1 = 0; x1 < n; ++x1) {
+ // Window-local coordinates to global image coordinates
+ int yy0 = iw_y * n + y0;
+ int xx0 = iw_x * n + x0;
+ int yy1 = iw_y * n + y1;
+ int xx1 = iw_x * n + x1;
+ // Check if two patches being matched belong to the same window
+ // that is: they are both in the shift zone, or both outside
+ bool match_y = (yy0 < h_pad - shift) == (yy1 < h_pad - shift);
+ bool match_x = (xx0 < w_pad - shift) == (xx1 < w_pad - shift);
+ // If not, set attention mask to -inf so it is ignored by softmax
+ if (!match_y || !match_x) {
+ int64_t idx = base + (y0 * n + x0) * n2 + (y1 * n + x1);
+ out[idx] = neg_inf_f16;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+tensor_data create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int window_size) {
+ int n = window_size;
+ int64_t nw_x = (w + n - 1) / n;
+ int64_t nw_y = (h + n - 1) / n;
+ auto result = tensor_alloc(ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n * n, n * n, nw_x * nw_y));
+ auto name = format("swin_layer_{}x{}.attn_mask", w, h);
+ compute_attention_mask(result.as_bytes(), w, h, window_size);
+ make_constant(result.x, name);
+ return result;
+}
+
+layer_result layer(
+ model_ref m, tensor x, int64_t w, int64_t h, swin_layer_t const& p, int window, bool down) {
+ // Attention masks need to be precomputed
+ tensor_name attn_mask_name = format("swin_layer_{}x{}.attn_mask", w, h);
+ tensor attn_mask = ggml_get_tensor(m, attn_mask_name.c_str());
+
+ model_ref blocks = m["blocks"];
+ for (int i = 0; i < p.depth; ++i) {
+ x = block(
+ blocks[i], x, attn_mask,
+ {.n_heads = p.n_heads,
+ .window_size = window,
+ .w = w,
+ .h = h,
+ .shift = i % 2 == 0 ? 0 : window / 2});
+ }
+ if (down) {
+ tensor x_down = patch_merging(m["downsample"], x, w, h);
+ return {x, w, h, x_down, (w + 1) / 2, (h + 1) / 2};
+ }
+ return {x, w, h, x, w, h};
+}
+
+swin_result encode(model_ref m, tensor x, swin_params const& p) {
+ x = patch_embed(m["patch_embed"], x, 4);
+
+ auto [c, w, h, b] = nelements(x);
+ x = ggml_reshape_3d(m, x, c, w * h, b);
+
+ layer_result r{x, w, h, x, w, h};
+ swin_result outs = {};
+
+ for (int i = 0; i < swin_n_layers; ++i) {
+ bool downsample = (i < swin_n_layers - 1);
+ r = layer(
+ m["layers"][i], r.x_down, r.w_down, r.h_down, p.layers[i], p.window_size, downsample);
+
+ tensor_name norm_layer = format("norm{}", i);
+ tensor out = layer_norm(m[norm_layer], r.x_out);
+ out = ggml_reshape_4d(m, out, p.layers[i].n_features, r.w_out, r.h_out, b);
+ outs[i] = out;
+ }
+ return outs;
+}
+
+} // namespace swin
+
+// clang-format off
+const swin_params swin_t_params = {
+ .embed_dim = 96,
+ .window_size = 7,
+ .layers = {
+ // depth n_heads n_features
+ swin_layer_t{2, 3, 96 * 1},
+ swin_layer_t{2, 6, 96 * 2},
+ swin_layer_t{6, 12, 96 * 4},
+ swin_layer_t{2, 24, 96 * 8}}};
+
+const swin_params swin_l_params = {
+ .embed_dim = 192,
+ .window_size = 12,
+ .layers = {
+ // depth n_heads n_features
+ swin_layer_t{2, 6, 192 * 1},
+ swin_layer_t{2, 12, 192 * 2},
+ swin_layer_t{18, 24, 192 * 4},
+ swin_layer_t{2, 48, 192 * 8}}};
+// clang-format on
+
+swin_params swin_detect_params(model_file const& f) {
+ int embed_dim = f.get_int("swin.embed_dim");
+ if (embed_dim == 96) {
+ return swin_t_params;
+ } else if (embed_dim == 192) {
+ return swin_l_params;
+ } else {
+ throw except("Unsupported Swin Transformer embed dim: {}", embed_dim);
+ }
+}
+
+swin_buffers swin_precompute(model_ref m, i32x2 image_extent, swin_params const& p) {
+ int w = p.window_size;
+ int width = image_extent[0] / 4;
+ int height = image_extent[1] / 4;
+
+ swin_buffers b;
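+    // b[0] holds the relative position index; the remaining entries hold attention masks
+    // for each successively halved resolution.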
+ b[0] = swin::create_relative_position_index(m, w);
+ for (int i = 0; i < swin_n_layers + 1; ++i) {
+ b[i + 1] = swin::create_attention_mask(m, width >> i, height >> i, w);
+ }
+ return b;
+}
+
+swin_result swin_encode(model_ref m, tensor image, swin_params const& p) {
+ return swin::encode(m, image, p);
+}
+
+} // namespace visp
\ No newline at end of file
diff --git a/src/visp/arch/swin.h b/src/visp/arch/swin.h
new file mode 100644
index 0000000..6b1195b
--- /dev/null
+++ b/src/visp/arch/swin.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include "visp/ml.h"
+#include "visp/vision.h"
+
+namespace visp::swin {
+
+struct block_params {
+ int n_heads = 6;
+ int window_size = 7;
+ int64_t w = 0;
+ int64_t h = 0;
+ int shift = 0;
+};
+
+struct layer_result {
+ tensor x_out;
+ int64_t w_out;
+ int64_t h_out;
+ tensor x_down;
+ int64_t w_down;
+ int64_t h_down;
+};
+
+void compute_relative_position_index(span<int32_t> dst, int window_size);
+tensor_data create_relative_position_index(ggml_context* ctx, int window_size);
+void compute_attention_mask(std::span<byte> out, int64_t w, int64_t h, int window_size);
+tensor_data create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int window_size);
+
+tensor mlp(model_ref m, tensor x);
+tensor patch_merging(model_ref m, tensor x, int64_t w, int64_t h);
+tensor window_partition(model_ref m, tensor x, int window);
+tensor window_reverse(model_ref m, tensor x, int w, int h, int window);
+tensor window_attention(model_ref m, tensor x, tensor mask, int num_heads, int window);
+tensor block(model_ref m, tensor x, tensor mask, block_params const&);
+layer_result layer(
+ model_ref, tensor, int64_t w, int64_t h, swin_layer_t const&, int window_size, bool downsample);
+
+} // namespace visp::swin
\ No newline at end of file
diff --git a/src/visp/image.cpp b/src/visp/image.cpp
index f230876..77cb42c 100644
--- a/src/visp/image.cpp
+++ b/src/visp/image.cpp
@@ -197,7 +197,7 @@ image_data image_load(char const* filepath) {
void image_save(image_view const& img, char const* filepath) {
ASSERT(img.extent[0] > 0 && img.extent[1] > 0);
-
+
if (!(img.format == image_format::alpha_u8 || img.format == image_format::rgb_u8 ||
img.format == image_format::rgba_u8)) {
throw except("Unsupported image format [{}]", int(img.format));
@@ -534,6 +534,53 @@ void image_erosion(image_view const& src, image_span const& dst, int radius) {
}
}
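+// Linearly rescale each channel of src independently so its values span [min, max].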
+void image_normalize(image_view const& src, image_span const& dst, float min, float max) {
+ ASSERT(src.extent == dst.extent);
+ ASSERT(is_float(src.format) && is_float(dst.format));
+ ASSERT(min < max);
+
+    float const fmax = std::numeric_limits<float>::max();
+ int const channels = n_channels(src);
+ float const* src_data = (float const*)src.data;
+ float* dst_data = (float*)dst.data;
+
+ f32x4 min_val = {fmax, fmax, fmax, fmax};
+ f32x4 max_val = {-fmax, -fmax, -fmax, -fmax};
+
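+    // First pass: find the per-channel minimum and maximum of the source image.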
+ for (int y = 0; y < src.extent[1]; ++y) {
+ for (int x = 0; x < src.extent[0]; ++x) {
+ for (int c = 0; c < channels; ++c) {
+ float v = src_data[y * src.stride / 4 + x * channels + c];
+ min_val[c] = std::min(min_val[c], v);
+ max_val[c] = std::max(max_val[c], v);
+ }
+ }
+ }
+
+ f32x4 delta = max_val - min_val;
+ for (int c = 0; c < channels; ++c) {
+ delta[c] = delta[c] < 1e-5f ? 1.0f : delta[c];
+ }
+ f32x4 scale = f32x4{max - min} / delta;
+ f32x4 offset = -min_val * scale + f32x4{min};
+
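+    // Second pass: rescale every value so each channel spans [min, max].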
+ for (int y = 0; y < src.extent[1]; ++y) {
+ for (int x = 0; x < src.extent[0]; ++x) {
+ for (int c = 0; c < channels; ++c) {
+ float v = src_data[y * src.stride / 4 + x * channels + c];
+ v = v * scale[c] + offset[c];
+ dst_data[y * dst.stride / 4 + x * channels + c] = v;
+ }
+ }
+ }
+}
+
+image_data image_normalize(image_view const& img, float min, float max) {
+ image_data dst = image_alloc(img.extent, img.format);
+ image_normalize(img, dst, min, max);
+ return dst;
+}
+
template
float difference_rms(image_source a, image_source b) {
float sum_sq_diff = 0.0f;
diff --git a/src/visp/ml.cpp b/src/visp/ml.cpp
index 8a85888..ad5ae9e 100644
--- a/src/visp/ml.cpp
+++ b/src/visp/ml.cpp
@@ -138,12 +138,23 @@ void backend_set_n_threads(backend_device& b, int n_threads) {
//
// model_build_flags
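+// The VISP_FLASH_ATTENTION environment variable overrides the default: "1" forces flash
+// attention on, "0" forces it off.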
+model_build_flags flash_attn_flag(bool default_enabled) {
+ static char const* const env = getenv("VISP_FLASH_ATTENTION");
+ if (env && env[0] == '1') {
+ return model_build_flag::flash_attention;
+ } else if (env && env[0] == '0') {
+ return model_build_flags{};
+ }
+ return default_enabled ? model_build_flag::flash_attention : model_build_flags{};
+}
+
model_build_flags backend_default_flags(backend_type type) {
using enum model_build_flag;
switch (type) {
case backend_type::cpu:
- return conv_2d_direct_cwhn | concat_n | f16_conv_transpose | window_partition;
- case backend_type::gpu: return {};
+ return conv_2d_direct_cwhn | concat_n | f16_conv_transpose | window_partition |
+ flash_attn_flag(false);
+ case backend_type::gpu: return flash_attn_flag(true);
}
return {};
}
@@ -199,6 +210,19 @@ int model_file::get_int(char const* key_name) const {
return gguf_get_val_i32(gguf.get(), key(key_name));
}
+void model_file::get_array(char const* key_name, span<int32_t> out_values) const {
+ int64_t key_id = key(key_name);
+ if (gguf_get_arr_n(gguf.get(), key_id) != out_values.size()) {
+ throw except("Array size mismatch for key '{}' in model file {}", key_name, path);
+ }
+ if (gguf_get_arr_type(gguf.get(), key_id) != GGUF_TYPE_INT32) {
+ throw except(
+ "Array type mismatch for key '{}' in model file {}, expected int32", key_name, path);
+ }
+ auto ptr = (int const*)gguf_get_arr_data(gguf.get(), key_id);
+ std::copy(ptr, ptr + out_values.size(), out_values.data());
+}
+
std::string_view model_file::arch() const {
return get_string("general.architecture");
}
@@ -587,6 +611,18 @@ tensor_data tensor_load(tensor x, char const* filepath) {
return result;
}
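+// Write the raw tensor bytes to a file, without any header or shape information.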
+void tensor_save(tensor x, char const* filepath) {
+ FILE* file = fopen(filepath, "wb");
+ if (!file) {
+ throw except("Failed to open file for writing: {}", filepath);
+ }
+ size_t written = fwrite(x->data, 1, ggml_nbytes(x), file);
+ fclose(file);
+ if (written != ggml_nbytes(x)) {
+ throw except("Failed to write tensor data to file: {}", filepath);
+ }
+}
+
std::span<float> tensor_data::as_f32() {
    ASSERT(x->type == GGML_TYPE_F32);
    return span(reinterpret_cast<float*>(data.get()), ggml_nelements(x));
@@ -607,6 +643,14 @@ std::span tensor_data::as_i32() const {
return span(reinterpret_cast(data.get()), ggml_nelements(x));
}
+std::span<byte> tensor_data::as_bytes() {
+ return span(data.get(), ggml_nbytes(x));
+}
+
+std::span<byte const> tensor_data::as_bytes() const {
+ return span(data.get(), ggml_nbytes(x));
+}
+
void transfer_to_backend(tensor_data const& d) {
ggml_backend_tensor_set(d.x, d.data.get(), 0, ggml_nbytes(d.x));
}
diff --git a/src/visp/nn.cpp b/src/visp/nn.cpp
index 7b6065b..6d3268c 100644
--- a/src/visp/nn.cpp
+++ b/src/visp/nn.cpp
@@ -3,7 +3,6 @@
namespace visp {
-
tensor linear(model_ref m, tensor x) {
x = ggml_mul_mat(m, m.weights("weight"), x);
if (tensor bias = m.find("bias")) {
@@ -88,16 +87,10 @@ tensor conv_2d(model_ref m, tensor x, int stride, int pad) {
x = permute_whcn_to_cwhn(m, x);
} else {
- x = permute_cwhn_to_whcn(m, x);
- tensor permuted_weight = permute_cwhn_to_whcn(m, weight);
- tensor cols = ggml_im2col(
- m, permuted_weight, x, stride, stride, pad, pad, 1, 1, true, GGML_TYPE_F32);
- tensor a = ggml_reshape_2d(
- m, cols, cols->ne[0], cols->ne[1] * cols->ne[2] * cols->ne[3]);
- tensor b = ggml_reshape_2d(
- m, weight, weight->ne[0] * weight->ne[1] * weight->ne[2], weight->ne[3]);
- x = ggml_mul_mat(m, b, a);
- x = ggml_reshape_4d(m, x, weight->ne[3], cols->ne[1], cols->ne[2], cols->ne[3]);
+ weight = ggml_cont(m, permute_cwhn_to_whcn(m, weight));
+ x = ggml_cont(m, permute_cwhn_to_whcn(m, x));
+ x = ggml_conv_2d(m, weight, x, stride, stride, pad, pad, 1, 1);
+ x = ggml_cont(m, permute_whcn_to_cwhn(m, x));
}
} else { // WHCN layout
x = ggml_conv_2d_direct(m, weight, x, stride, stride, pad, pad, 1, 1);
@@ -174,4 +167,20 @@ tensor batch_norm_2d(model_ref m, tensor x) {
return named(m, x);
}
+tensor patch_embed(model_ref m, tensor x, int patch_size) {
+ ASSERT(x->ne[1] % patch_size == 0 && x->ne[2] % patch_size == 0);
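+    // Model files may store the embedding convolution as either "proj" or "projection".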
+ char const* proj = m.find("proj.weight") ? "proj" : "projection";
+
+ m.flags |= model_build_flag::cwhn;
+ x = conv_2d(m[proj], x, patch_size);
+
+ if (m.find("norm.weight")) {
+ auto [c, w, h, b] = nelements(x);
+ x = ggml_reshape_3d(m, x, c, w * h, b);
+ x = layer_norm(m["norm"], x);
+ x = ggml_reshape_4d(m, x, c, w, h, b);
+ }
+ return named(m, x);
+}
+
} // namespace visp
\ No newline at end of file
diff --git a/src/visp/nn.h b/src/visp/nn.h
index eb8c106..9b7e762 100644
--- a/src/visp/nn.h
+++ b/src/visp/nn.h
@@ -38,4 +38,7 @@ tensor conv_2d_deform(
tensor conv_transpose_2d(model_ref m, tensor x, int stride);
tensor batch_norm_2d(model_ref, tensor x);
+// 2D image to patch embedding using convolution and optional norm. CWHN input and output.
+tensor patch_embed(model_ref, tensor x, int patch_size);
+
} // namespace visp
diff --git a/src/visp/vision.cpp b/src/visp/vision.cpp
index bd8216e..36d324c 100644
--- a/src/visp/vision.cpp
+++ b/src/visp/vision.cpp
@@ -115,6 +115,41 @@ image_data birefnet_compute(birefnet_model& model, image_view image) {
return birefnet_process_output(mask_data.as_f32(), image.extent, model.params);
}
+//
+// Depth Anything
+
+depthany_model depthany_load_model(char const* filepath, backend_device const& dev) {
+ depthany_model model;
+ model.backend = &dev;
+ model_file file = model_load(filepath);
+ model.params = depthany_detect_params(file);
+ model.weights = model_init(file.n_tensors());
+ model_transfer(file, model.weights, dev, dev.preferred_float_type(), dev.preferred_layout());
+ return model;
+}
+
+image_data depthany_compute(depthany_model& model, image_view image) {
+ i32x2 res = depthany_image_extent(image.extent, model.params);
+
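+    // (Re)build and allocate the compute graph on first use or when the input resolution changes.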
+ if (!model.graph || res != model.params.image_extent) {
+ model.params.image_extent = res;
+ model.graph = compute_graph_init();
+
+ model_ref m(model.weights, model.graph);
+ model.input = compute_graph_input(m, GGML_TYPE_F32, {3, res[0], res[1], 1});
+ model.output = depthany_predict(m, model.input, model.params);
+ compute_graph_allocate(model.graph, *model.backend);
+ }
+
+ image_data img_data = depthany_process_input(image, model.params);
+ transfer_to_backend(model.input, img_data);
+
+ compute(model.graph, *model.backend);
+
+ tensor_data output_data = transfer_from_backend(model.output);
+ return depthany_process_output(output_data.as_f32(), image.extent, model.params);
+}
+
//
// MI-GAN
diff --git a/tests/benchmark.cpp b/tests/benchmark.cpp
index a75bd13..d10bcfb 100644
--- a/tests/benchmark.cpp
+++ b/tests/benchmark.cpp
@@ -93,6 +93,17 @@ bench_timings benchmark_birefnet(path model_path, backend_device& backend) {
return run_benchmark(model.graph, backend, 8, {{model.input, input_data}});
}
+bench_timings benchmark_depth_anything(path model_path, backend_device& backend) {
+ path input_path = test_dir().input / "wardrobe.jpg";
+
+ depthany_model model = depthany_load_model(model_path.string().c_str(), backend);
+ image_data input = image_load(input_path.string().c_str());
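+    // Warm-up run: builds and allocates the compute graph, which run_benchmark then reuses.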
+ depthany_compute(model, input);
+
+ image_data input_data = depthany_process_input(input, model.params);
+ return run_benchmark(model.graph, backend, 12, {{model.input, input_data}});
+}
+
bench_timings benchmark_migan(path model_path, backend_device& backend) {
path image_path = test_dir().input / "bench-image.jpg";
path mask_path = test_dir().input / "bench-mask.png";
@@ -172,6 +183,10 @@ bench_result benchmark_model(
path model_path = select_model(model, "BiRefNet-lite-F16.gguf");
result.time = benchmark_birefnet(model_path, backend);
+ } else if (arch == "depthany") {
+ path model_path = select_model(model, "Depth-Anything-V2-Small-F16.gguf");
+ result.time = benchmark_depth_anything(model_path, backend);
+
} else if (arch == "migan") {
path model_path = select_model(model, "MIGAN-512-places2-F16.gguf");
result.time = benchmark_migan(model_path, backend);
diff --git a/tests/reference-images.cmake b/tests/reference-images.cmake
index d2d0a0b..bdc1cdc 100644
--- a/tests/reference-images.cmake
+++ b/tests/reference-images.cmake
@@ -1,6 +1,8 @@
file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/birefnet-cpu.png/c8663d4c985f94b29fcca3c3c5d2058c53447f19c521b7c5f97276cace68bb09" "tests/reference/birefnet-cpu.png" EXPECTED_HASH SHA256=c8663d4c985f94b29fcca3c3c5d2058c53447f19c521b7c5f97276cace68bb09)
-file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/birefnet-dynamic.png/720bf20140f6f93c3c3953ed2e28a9cb395def8426f53c031d58a8393784227f" "tests/reference/birefnet-dynamic.png" EXPECTED_HASH SHA256=720bf20140f6f93c3c3953ed2e28a9cb395def8426f53c031d58a8393784227f)
-file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/birefnet-gpu.png/c8663d4c985f94b29fcca3c3c5d2058c53447f19c521b7c5f97276cace68bb09" "tests/reference/birefnet-gpu.png" EXPECTED_HASH SHA256=c8663d4c985f94b29fcca3c3c5d2058c53447f19c521b7c5f97276cace68bb09)
+file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/birefnet-dynamic.png/5ef6a13855c566609de54e08112c4308c97a0f6740b410e8639bc993b2273c7c" "tests/reference/birefnet-dynamic.png" EXPECTED_HASH SHA256=5ef6a13855c566609de54e08112c4308c97a0f6740b410e8639bc993b2273c7c)
+file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/birefnet-gpu.png/1d55cdcb0f3648c32830ad1247d768b867e34e20cdbcf08ed166859b55f75aad" "tests/reference/birefnet-gpu.png" EXPECTED_HASH SHA256=1d55cdcb0f3648c32830ad1247d768b867e34e20cdbcf08ed166859b55f75aad)
+file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/depth-anything-cpu.png/36adde57ebd2589fe37bf7c0efbf9d3a013f98f7d7a45bb19fd2c492c8ade7a9" "tests/reference/depth-anything-cpu.png" EXPECTED_HASH SHA256=36adde57ebd2589fe37bf7c0efbf9d3a013f98f7d7a45bb19fd2c492c8ade7a9)
+file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/depth-anything-gpu.png/b3639c0e049081ea35d2fdc37c12634457d52c320a6b839f4d6099319103464b" "tests/reference/depth-anything-gpu.png" EXPECTED_HASH SHA256=b3639c0e049081ea35d2fdc37c12634457d52c320a6b839f4d6099319103464b)
file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/esrgan-cpu.png/481dcc0eb617feb9f8f7403ce179e77e2eba2c7a067f4a1ea90e0fb47083d814" "tests/reference/esrgan-cpu.png" EXPECTED_HASH SHA256=481dcc0eb617feb9f8f7403ce179e77e2eba2c7a067f4a1ea90e0fb47083d814)
file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/esrgan-gpu.png/a8bfab0e07aeca16b737872bb3dbbe0e6b76cfff5616d2f02f2b0465cc7a0937" "tests/reference/esrgan-gpu.png" EXPECTED_HASH SHA256=a8bfab0e07aeca16b737872bb3dbbe0e6b76cfff5616d2f02f2b0465cc7a0937)
file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/migan-cpu.png/9fb32419246e3e073c73df8f4a0fefd334934ffddf8a157535b8b2fc3c1d93ee" "tests/reference/migan-cpu.png" EXPECTED_HASH SHA256=9fb32419246e3e073c73df8f4a0fefd334934ffddf8a157535b8b2fc3c1d93ee)
diff --git a/tests/test-image.cpp b/tests/test-image.cpp
index 85a94c6..83837e3 100644
--- a/tests/test-image.cpp
+++ b/tests/test-image.cpp
@@ -280,6 +280,26 @@ VISP_TEST(image_erosion) {
CHECK_IMAGES_EQUAL(output, expected);
}
+VISP_TEST(image_normalize) {
+ constexpr i32x2 extent{2, 2};
+    std::array<f32x3, 4> input_data = {
+ f32x3{-1.0f, 4.2f, 0.5f}, f32x3{5.0f, 4.2f, 0.0f}, //
+ f32x3{-5.0f, 4.2f, 0.6f}, f32x3{1.0f, 4.2f, 1.0f}, //
+ };
+    std::array<f32x3, 4> expected_data = {
+ f32x3{0.4f, 0.0f, 0.5f}, f32x3{1.0f, 0.0f, 0.0f}, //
+ f32x3{0.0f, 0.0f, 0.6f}, f32x3{0.6f, 0.0f, 1.0f}, //
+ };
+    std::array<f32x3, 4> output_data{};
+
+ auto input = image_view(extent, input_data);
+ auto output = image_span(extent, output_data);
+ image_normalize(input, output);
+
+ auto expected = image_view(extent, expected_data);
+ CHECK_IMAGES_EQUAL(output, expected);
+}
+
VISP_TEST(tile_merge) {
std::array, 4> tiles;
for (int t = 0; t < 4; ++t) {
diff --git a/tests/test-models.cpp b/tests/test-models.cpp
index 3f7b803..2ca1cd6 100644
--- a/tests/test-models.cpp
+++ b/tests/test-models.cpp
@@ -70,6 +70,22 @@ VISP_TEST(test_birefnet_dynamic) {
compare_images("birefnet-dynamic.png", output2, 0.015f);
}
+VISP_BACKEND_TEST(test_depth_anything)(backend_type bt) {
+ path model_path = test_dir().models / "Depth-Anything-V2-Small-F16.gguf";
+ path input_path = test_dir().input / "wardrobe.jpg";
+ std::string name = "depth-anything";
+ name += bt == backend_type::cpu ? "-cpu.png" : "-gpu.png";
+
+ backend_device b = backend_init(bt);
+ depthany_model model = depthany_load_model(model_path.string().c_str(), b);
+ image_data input = image_load(input_path.string().c_str());
+ image_data depth = depthany_compute(model, input);
+ image_data output = image_f32_to_u8(depth, image_format::alpha_u8);
+
+ float tolerance = bt == backend_type::cpu ? 0.01f : 0.015f;
+ compare_images(name, output, tolerance);
+}
+
VISP_BACKEND_TEST(test_migan)(backend_type bt) {
path model_path = test_dir().models / "MIGAN-512-places2-F16.gguf";
path image_path = test_dir().input / "bench-image.jpg";
diff --git a/tests/test_birefnet.py b/tests/test_birefnet.py
index 353bb0d..b57586a 100644
--- a/tests/test_birefnet.py
+++ b/tests/test_birefnet.py
@@ -118,7 +118,9 @@ def test_relative_position_index():
@pytest.mark.parametrize("masking", ["mask", "no_mask"])
-def test_window_attention(masking: bool):
+@pytest.mark.parametrize("backend", ["cpu", "gpu"])
+@pytest.mark.parametrize("attn", ["default", "flash_attn"])
+def test_window_attention(masking: bool, backend: str, attn: str):
num_heads = 2
window_attention = WindowAttention(dim=8, window_size=(3, 3), num_heads=num_heads)
state = generate_state(window_attention.state_dict())
@@ -132,9 +134,13 @@ def test_window_attention(masking: bool):
state["mask"] = mask
expected = window_attention(x, mask)
- result = workbench.invoke_test("biref_window_attention", x, state)
+ del state["relative_position_index"] # computed in C++
+ if mask is not None:
+ state["mask"] = mask.half()
+ state["relative_position_bias_table"] = state["relative_position_bias_table"].half()
+ result = workbench.invoke_test("biref_window_attention", x, state, {"attn": attn}, backend)
- assert torch.allclose(result, expected)
+ assert torch.allclose(result, expected, rtol=1e-3)
def window_partition(x, window_size):
@@ -740,8 +746,8 @@ def test_encode():
expected = forward_enc(x, xs, xs_low)
state = {}
- state.update({f"input{i}": to_nhwc(xs[i]) for i in range(4)})
- state.update({f"input_low{i}": to_nhwc(xs_low[i]) for i in range(4)})
+ state.update({f"xs{i}": to_nhwc(xs[i]) for i in range(4)})
+ state.update({f"xs_low{i}": to_nhwc(xs_low[i]) for i in range(4)})
results = workbench.invoke_test("biref_encode", x, state, nhwc_layout)
diff --git a/tests/test_primitives.py b/tests/test_primitives.py
index 08c6414..a151a40 100644
--- a/tests/test_primitives.py
+++ b/tests/test_primitives.py
@@ -16,9 +16,7 @@ def test_linear():
assert torch.allclose(result, expected)
-@pytest.mark.parametrize(
- "scenario", ["stride_1_pad_0", "stride_2_pad_1", "dilation_2_pad_2"]
-)
+@pytest.mark.parametrize("scenario", ["stride_1_pad_0", "stride_2_pad_1", "dilation_2_pad_2"])
@pytest.mark.parametrize("memory_layout", ["nchw", "nhwc"])
@pytest.mark.parametrize("batch", ["single", "batch"])
@pytest.mark.parametrize("backend", ["cpu", "vulkan"])
@@ -128,9 +126,7 @@ def test_window_partition(backend: str):
nW = pW // win
# window partition
expected = (
- expected.view(B, nH, win, nW, win, C)
- .transpose(2, 3)
- .reshape(B * nH * nW, win * win, C)
+ expected.view(B, nH, win, nW, win, C).transpose(2, 3).reshape(B * nH * nW, win * win, C)
)
result = workbench.invoke_test("sam_window_partition", x, {}, backend=backend)
@@ -150,3 +146,24 @@ def test_roll(shift: tuple[int, int, int, int], backend: str):
result = workbench.invoke_test("roll", x, {}, params, backend)
assert torch.allclose(result, expected)
+
+
+@pytest.mark.parametrize("mode", ["bilinear", "bicubic"])
+@pytest.mark.parametrize("align_corners", [True, False])
+@pytest.mark.parametrize("size", ["small", "large"])
+@pytest.mark.parametrize("scale", [0.6, 2.0])
+@pytest.mark.parametrize("backend", ["cpu", "vulkan"])
+def test_interpolate(mode: str, align_corners: bool, size: str, scale: float, backend: str):
+ b, c, h, w = {
+ "small": (1, 3, 2, 3),
+ "large": (4, 19, 20, 30),
+ }[size]
+ target = (round(h * scale), round(w * scale))
+ x = torch.arange(b * c * h * w).reshape(b, c, h, w).float()
+ expected = torch.nn.functional.interpolate(
+ x, size=target, mode=mode, align_corners=align_corners
+ )
+
+ params = dict(mode=mode, h=target[0], w=target[1], align_corners=1 if align_corners else 0)
+ result = workbench.invoke_test("interpolate", x, {}, params, backend)
+ assert torch.allclose(result, expected)
diff --git a/tests/workbench.cpp b/tests/workbench.cpp
index f31e83d..b3dc1d4 100644
--- a/tests/workbench.cpp
+++ b/tests/workbench.cpp
@@ -1,8 +1,11 @@
#include "util/string.h"
#include "visp/arch/birefnet.h"
+#include "visp/arch/depth-anything.h"
+#include "visp/arch/dino.h"
#include "visp/arch/esrgan.h"
#include "visp/arch/migan.h"
#include "visp/arch/mobile-sam.h"
+#include "visp/arch/swin.h"
#include "visp/nn.h"
#include
@@ -116,6 +119,17 @@ DEF(linear)(model_ref m, span input, param_dict const& p) {
return {linear(m, input[0])};
}
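+// Wraps ggml_interpolate so the Python tests can compare bilinear/bicubic resampling
+// (with optional align_corners) against PyTorch.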
+DEF(interpolate)(model_ref m, span input, param_dict const& p) {
+ int w = p.get("w", 8);
+ int h = p.get("h", 8);
+ uint32_t mode = p.get("mode", "bilinear") == "bilinear"sv ? GGML_SCALE_MODE_BILINEAR
+ : GGML_SCALE_MODE_BICUBIC;
+ if (p.get("align_corners", 0)) {
+ mode |= GGML_SCALE_FLAG_ALIGN_CORNERS;
+ }
+ return {ggml_interpolate(m, input[0], w, h, input[0]->ne[2], input[0]->ne[3], mode)};
+}
+
//
// Mobile SAM
@@ -229,45 +243,50 @@ DEF(sam_predict_masks)(model_ref m, span input, param_dict const& p) {
// BiRefNet
DEF(biref_patch_embed)(model_ref m, span input, param_dict const& p) {
- return {birefnet::patch_embed(m, input[0])};
+ return {patch_embed(m, input[0], 4)};
}
DEF(biref_relative_position_index)(model_ref m, span input, param_dict const& p) {
    auto dst = span(reinterpret_cast<int32_t*>(input[0]->data), ggml_nelements(input[0]));
- birefnet::compute_relative_position_index(dst, 3);
+ swin::compute_relative_position_index(dst, 3);
return {input[0]};
}
DEF(biref_window_attention)(model_ref m, span input, param_dict const& p) {
+ if (p.get("attn", "default") == "flash_attn"sv) {
+ m.flags = m.flags | model_build_flag::flash_attention;
+ } else {
+ m.flags = m.flags & ~model_build_flag::flash_attention;
+ }
int window_size = 3;
tensor mask = m.find("mask");
- auto rel_pos_index = birefnet::create_relative_position_index(m, window_size);
+ auto rel_pos_index = swin::create_relative_position_index(m, window_size);
ggml_backend_alloc_ctx_tensors(m, workbench_backend());
transfer_to_backend(rel_pos_index);
- return {birefnet::window_attention(m, input[0], mask, 2, window_size)};
+ return {swin::window_attention(m, input[0], mask, 2, window_size)};
}
DEF(biref_swin_block)(model_ref m, span input, param_dict const& p) {
- birefnet::swin_block_params block;
+ swin::block_params block;
block.n_heads = 2;
block.window_size = 3;
block.w = 6;
block.h = 6;
block.shift = 0;
tensor mask = m.find("mask");
- auto rel_pos_index = birefnet::create_relative_position_index(m, 3);
+ auto rel_pos_index = swin::create_relative_position_index(m, 3);
ggml_backend_alloc_ctx_tensors(m, workbench_backend());
transfer_to_backend(rel_pos_index);
- return {birefnet::swin_block(m, input[0], mask, block)};
+ return {swin::block(m, input[0], mask, block)};
}
DEF(biref_patch_merging)(model_ref m, span input, param_dict const& p) {
- return {birefnet::patch_merging(m, input[0], 6, 4)};
+ return {swin::patch_merging(m, input[0], 6, 4)};
}
DEF(biref_attention_mask)(model_ref m, span input, param_dict const& p) {
- auto dst = span((float*)input[0]->data, ggml_nelements(input[0]));
- birefnet::compute_attention_mask(dst, 18, 18, 6);
+ auto dst = span((byte*)input[0]->data, ggml_nbytes(input[0]));
+ swin::compute_attention_mask(dst, 18, 18, 6);
return {input[0]};
}
@@ -276,13 +295,12 @@ DEF(biref_swin_layer)(model_ref m, span input, param_dict const& p) {
layer.depth = 2;
layer.n_heads = 2;
layer.n_features = 8;
- layer.downsample = true;
- auto rel_pos_index = birefnet::create_relative_position_index(m, 3);
- auto attn_mask = birefnet::create_attention_mask(m, 6, 6, 3);
+ auto rel_pos_index = swin::create_relative_position_index(m, 3);
+ auto attn_mask = swin::create_attention_mask(m, 6, 6, 3);
ggml_backend_alloc_ctx_tensors(m, workbench_backend());
transfer_to_backend(rel_pos_index);
transfer_to_backend(attn_mask);
- auto result = birefnet::swin_layer(m, input[0], 6, 6, layer, 3);
+ auto result = swin::layer(m, input[0], 6, 6, layer, 3, true);
ASSERT(result.w_down == 3 && result.h_down == 3);
return {result.x_down};
}
@@ -292,29 +310,29 @@ DEF(biref_swin_transformer)(model_ref m, span input, param_dict const& p
.embed_dim = 8,
.window_size = 3,
.layers = {
- swin_layer_t{2, 2, 8 * 1, true},
- swin_layer_t{2, 2, 8 * 2, true},
- swin_layer_t{2, 4, 8 * 4, true},
- swin_layer_t{2, 2, 8 * 8, false},
+ swin_layer_t{2, 2, 8 * 1},
+ swin_layer_t{2, 2, 8 * 2},
+ swin_layer_t{2, 4, 8 * 4},
+ swin_layer_t{2, 2, 8 * 8},
}};
- auto rel_pos_index = birefnet::create_relative_position_index(m, 3);
+ auto rel_pos_index = swin::create_relative_position_index(m, 3);
auto attn_masks = std::array{
- birefnet::create_attention_mask(m, 8, 8, 3), birefnet::create_attention_mask(m, 4, 4, 3),
- birefnet::create_attention_mask(m, 2, 2, 3), birefnet::create_attention_mask(m, 1, 1, 3)};
+ swin::create_attention_mask(m, 8, 8, 3), swin::create_attention_mask(m, 4, 4, 3),
+ swin::create_attention_mask(m, 2, 2, 3), swin::create_attention_mask(m, 1, 1, 3)};
ggml_backend_alloc_ctx_tensors(m, workbench_backend());
transfer_to_backend(rel_pos_index);
for (auto&& attn_mask : attn_masks) {
transfer_to_backend(attn_mask);
}
- auto result = birefnet::swin_transformer(m, input[0], swinp);
+ auto result = swin_encode(m, input[0], swinp);
return {result[0], result[1], result[2], result[3]};
}
DEF(biref_encode)(model_ref m, span input, param_dict const& p) {
- birefnet::swin_result xs, xs_low;
+ swin_result xs, xs_low;
for (int i = 0; i < 4; ++i) {
- xs[i] = m.find(format("input{}", i).c_str());
- xs_low[i] = m.find(format("input_low{}", i).c_str());
+ xs[i] = m.find(format("xs{}", i).c_str());
+ xs_low[i] = m.find(format("xs_low{}", i).c_str());
}
birefnet::encode_concat(m, xs, xs_low);
return std::vector{xs[0], xs[1], xs[2], xs[3]};
@@ -341,7 +359,7 @@ DEF(biref_image_to_patches_2)(model_ref m, span input, param_dict const&
}
DEF(biref_decode)(model_ref m, span input, param_dict const& p) {
- birefnet::swin_result features;
+ swin_result features;
for (int i = 0; i < 4; ++i) {
features[i] = m.find(format("x{}", i + 1).c_str());
}
@@ -402,6 +420,63 @@ DEF(esrgan_rrdbnet)(model_ref m, span input, param_dict const& p) {
return {esrgan_generate(m, input[0], params)};
}
+//
+// DINO
+
+DEF(dino_interpolate_pos_encoding)(model_ref m, span input, param_dict const& p) {
+ int s = p.get("img_size", 64);
+ int patch_size = p.get("patch_size", 16);
+ return {dino::interpolate_pos_encoding(m, input[0], s, s, patch_size)};
+}
+
+DEF(dino_prepare_tokens)(model_ref m, span input, param_dict const& p) {
+ return {dino::prepare_tokens(m, input[0], 4)};
+}
+
+DEF(dino_attention)(model_ref m, span input, param_dict const& p) {
+ if (p.get("flash_attn", 0) != 0) {
+ m.flags |= model_build_flag::flash_attention;
+ }
+ return {dino::attention(m, input[0], p.get("n_heads", 8))};
+}
+
+DEF(dino_block)(model_ref m, span input, param_dict const& p) {
+ dino_params params{};
+ params.n_heads = p.get("n_heads", 8);
+ return {dino::layer(m, input[0], params)};
+}
+
+DEF(dino_intermediate_layers)(model_ref m, span input, param_dict const& p) {
+ dino_params params{};
+ params.patch_size = 4;
+ params.embed_dim = 6;
+ params.n_layers = 4;
+ params.n_heads = 3;
+ auto layers = std::array{0, 1, 2, 3};
+ return dino::get_intermediate_layers(m, input[0], layers, params);
+}
+
+//
+// Depth Anything
+
+DEF(depthany_feature_fusion)(model_ref m, span input, param_dict const& p) {
+ if (input.size() == 1) {
+ int64_t size[] = {8, 8, 6, 1};
+ return {dpt::feature_fusion(m, input[0], nullptr, size)};
+ } else {
+ ASSERT(input.size() == 2);
+ return {dpt::feature_fusion(m, input[0], input[1])};
+ }
+}
+
+DEF(depthany_head)(model_ref m, span input, param_dict const& p) {
+ int patch_w = p.get("patch_w", 8);
+ int patch_h = p.get("patch_h", 8);
+ tensor fused = dpt::neck(m, input, patch_w, patch_h);
+ tensor depth = dpt::head(m, fused, patch_w * 14, patch_h * 14, 1.0f);
+ return {depth};
+}
+
//
// Workbench implementation
//
@@ -419,19 +494,19 @@ param_dict build_dict(span raw_params) {
param.name = raw.name;
switch (param_type(raw.type)) {
- case param_type::int32:
- param.type = param_type::int32;
- param.value.i = std::stoi(raw.value);
- break;
- case param_type::float32:
- param.type = param_type::float32;
- param.value.f = std::stof(raw.value);
- break;
- case param_type::string:
- param.type = param_type::string;
- param.value.s = raw.value;
- break;
- default: throw except("Unknown parameter type");
+ case param_type::int32:
+ param.type = param_type::int32;
+ param.value.i = std::stoi(raw.value);
+ break;
+ case param_type::float32:
+ param.type = param_type::float32;
+ param.value.f = std::stof(raw.value);
+ break;
+ case param_type::string:
+ param.type = param_type::string;
+ param.value.s = raw.value;
+ break;
+ default: throw except("Unknown parameter type");
}
dict.params.push_back(param);
}
@@ -470,7 +545,7 @@ char const* param_dict::get(char const* name, char const* default_value) const {
struct raw_tensor {
char const* name;
- float* data;
+ byte* data;
int32_t type_;
int32_t ne[4];
@@ -479,7 +554,6 @@ struct raw_tensor {
size_t size_bytes() const { return size() * ggml_type_size(type()); }
};
-
struct test_case {
char const* name;
test_function func;
@@ -533,16 +607,15 @@ void workbench_run(
for (raw_tensor const& raw : tensors) {
auto tensor = ggml_new_tensor_4d(
m.weights_context, raw.type(), raw.ne[0], raw.ne[1], raw.ne[2], raw.ne[3]);
- if (raw.name && raw.name[0] != '\0' && raw.name != std::string_view("input")) {
- ggml_set_name(tensor, raw.name);
- } else {
+ ggml_set_name(tensor, raw.name);
+ if (std::string_view(raw.name).starts_with("input")) {
inputs.push_back(tensor);
}
}
model_allocate(weights, w.current_backend);
for (raw_tensor const& raw : tensors) {
- transfer_to_backend(m.weights(raw.name), span(raw.data, raw.size()));
+ transfer_to_backend(m.weights(raw.name), span(raw.data, raw.size_bytes()));
}
param_dict test_params = build_dict(params);
@@ -576,7 +649,7 @@ void workbench_run(
ggml_backend_tensor_get(outputs[i], data_ptr, 0, ggml_nbytes(outputs[i]));
output_raw[i].name = ggml_get_name(outputs[i]);
- output_raw[i].data = reinterpret_cast(data_ptr);
+ output_raw[i].data = reinterpret_cast(data_ptr);
output_raw[i].type_ = int32_t(outputs[i]->type);
output_raw[i].ne[0] = outputs[i]->ne[0];
output_raw[i].ne[1] = outputs[i]->ne[1];
@@ -594,7 +667,8 @@ extern "C" {
#ifdef _MSC_VER
__declspec(dllexport)
#endif
-int32_t visp_workbench(
+int32_t
+visp_workbench(
char const* testcase,
visp::raw_tensor const* inputs,
int32_t n_inputs,
diff --git a/tests/workbench.py b/tests/workbench.py
index 7e7da42..0095fd0 100644
--- a/tests/workbench.py
+++ b/tests/workbench.py
@@ -32,6 +32,7 @@ class RawParam(ctypes.Structure):
def torch_to_raw_tensor(name: str, tensor: torch.Tensor):
tensor_types = {
torch.float32: 0, # GGML_TYPE_F32
+ torch.float16: 1, # GGML_TYPE_F16
torch.int32: 26, # GGML_TYPE_I32
}
t = tensor.contiguous()
@@ -112,7 +113,7 @@ def invoke_test(
backend: str = "cpu",
):
input = input if isinstance(input, list) else [input]
- raw_inputs = [torch_to_raw_tensor("", tensor) for tensor in input]
+ raw_inputs = [torch_to_raw_tensor(f"input{i}", tensor) for i, tensor in enumerate(input)]
raw_inputs += [torch_to_raw_tensor(name, tensor) for name, tensor in state.items()]
input_tensors = [t for _, t in raw_inputs]
input_tensors # keep the tensors alive