Skip to content

Commit 45d2ba3

Browse files
committed
python: fix
1 parent f7e8769 commit 45d2ba3

File tree

8 files changed

+95
-34
lines changed

8 files changed

+95
-34
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ jobs:
1313
strategy:
1414
matrix:
1515
os: [ubuntu-22.04, windows-latest, macos-14]
16+
fail-fast: false
1617

1718
runs-on: ${{ matrix.os }}
1819
timeout-minutes: 15

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ if(VISP_CI OR VISP_DEV)
145145
set_target_properties(vision-cli PROPERTIES INSTALL_RPATH "\$ORIGIN/../${VISP_LIB_INSTALL_DIR}")
146146
endif()
147147

148-
install(DIRECTORY bindings/python DESTINATION ${CMAKE_INSTALL_PREFIX} PATTERN "__pycache__" EXCLUDE)
148+
install(DIRECTORY bindings/python DESTINATION . PATTERN "__pycache__" EXCLUDE)
149149

150150
include(CMakePackageConfigHelpers)
151151

bindings/python/visioncpp/_lib.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,18 +107,21 @@ def _load():
107107
if path.exists():
108108
try:
109109
lib = ctypes.CDLL(str(path))
110-
return lib
110+
return lib, path
111111
except OSError as e:
112112
error = e
113113
continue
114114
raise OSError(f"Could not load vision.cpp library from paths: {error}")
115115

116116

117117
def init():
118-
lib = _load()
118+
lib, path = _load()
119119

120120
lib.visp_get_last_error.restype = c_char_p
121121

122+
lib.visp_backend_load_all.argtypes = [c_char_p]
123+
lib.visp_backend_load_all.restype = c_int32
124+
122125
lib.visp_image_destroy.argtypes = [ImageData]
123126
lib.visp_image_destroy.restype = None
124127

@@ -158,6 +161,12 @@ def init():
158161
]
159162
lib.visp_model_compute.restype = c_int32
160163

164+
# On Linux, libvisioncpp might be in lib/ and ggml backends in bin/
165+
if path.parent.name == "lib":
166+
bin_dir = path.parent.parent / "bin"
167+
if bin_dir.exists():
168+
lib.visp_backend_load_all(str(bin_dir).encode())
169+
161170
return lib
162171

163172

src/visp/c-api.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ VISP_API void visp_image_destroy(image_data* img) {
156156

157157
// device
158158

159+
VISP_API int32_t visp_backend_load_all(char const* dir) {
160+
ggml_backend_load_all_from_path(dir);
161+
return (int32_t)ggml_backend_reg_count();
162+
}
163+
159164
VISP_API int32_t visp_device_init(int32_t type, backend_device** out_device) {
160165
return handle_errors([&]() {
161166
if (type == 0) {
@@ -188,7 +193,7 @@ VISP_API char const* visp_device_description(backend_device const* d) {
188193

189194
// models
190195

191-
struct any_model {};
196+
struct any_model;
192197

193198
VISP_API int32_t visp_model_detect_family(char const* filepath, int32_t* out_family) {
194199
return handle_errors([&]() {

src/visp/ml.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,10 @@ backend_device backend_init() {
5959
load_ggml_backends();
6060
backend_device b;
6161
b.handle.reset(ggml_backend_init_best());
62+
if (!b.handle) {
63+
throw except("Failed to initialize backend, no suitable device available");
64+
}
6265
b.device = ggml_backend_get_device(b.handle.get());
63-
ASSERT(b.handle, "Failed to initialize backend");
6466
return b;
6567
}
6668

tests/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ target_link_libraries(vision-workbench PRIVATE visioncpp ggml ${VISP_FMT_LINK})
5151
if(VISP_CI)
5252
set(PYTHON_TESTS_ARGS "--ci")
5353
endif()
54-
add_test(NAME python COMMAND uv run pytest -vs tests ${PYTHON_TESTS_ARGS})
54+
add_test(NAME python
55+
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
56+
COMMAND uv run pytest -vs tests ${PYTHON_TESTS_ARGS})
5557

5658
#
5759
# Benchmarks

tests/test_primitives.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33

44
from . import workbench
5-
from .workbench import input_tensor, to_nchw, to_nhwc
5+
from .workbench import input_tensor, to_nchw, to_nhwc, tensors_match
66

77

88
def test_linear():
@@ -13,7 +13,7 @@ def test_linear():
1313
result = workbench.invoke_test("linear", x, dict(weight=weight, bias=bias))
1414

1515
expected = torch.nn.functional.linear(x, weight, bias)
16-
assert torch.allclose(result, expected)
16+
assert tensors_match(result, expected)
1717

1818

1919
@pytest.mark.parametrize("scenario", ["stride_1_pad_0", "stride_2_pad_1", "dilation_2_pad_2"])
@@ -48,7 +48,7 @@ def test_conv_2d_depthwise(scenario: str, memory_layout: str, batch: str, backen
4848
if memory_layout == "nhwc":
4949
result = to_nchw(result)
5050

51-
assert torch.allclose(result, expected)
51+
assert tensors_match(result, expected)
5252

5353

5454
@pytest.mark.parametrize("scenario", ["3x3", "5x5", "stride2", "nhwc"])
@@ -76,7 +76,7 @@ def test_conv_transpose_2d(scenario: str):
7676
if scenario == "nhwc":
7777
result = to_nchw(result)
7878

79-
assert torch.allclose(result, expected, rtol=1e-2)
79+
assert tensors_match(result, expected, rtol=1e-2)
8080

8181

8282
# def test_batch_norm_2d():
@@ -106,7 +106,7 @@ def test_layer_norm():
106106
result = workbench.invoke_test("layer_norm", x, dict(weight=weight, bias=bias))
107107

108108
expected = torch.nn.functional.layer_norm(x, [dim], weight, bias, eps=1e-5)
109-
assert torch.allclose(result, expected, atol=1e-6)
109+
assert tensors_match(result, expected, atol=1e-6)
110110

111111

112112
@pytest.mark.parametrize("backend", ["cpu", "vulkan"])
@@ -133,7 +133,7 @@ def test_window_partition(backend: str):
133133

134134
result = workbench.invoke_test("sam_window_partition", x, {}, backend=backend)
135135

136-
assert torch.allclose(result, expected)
136+
assert tensors_match(result, expected)
137137

138138

139139
@pytest.mark.parametrize("shift", [(0, 2, -1, 0), (0, -2, 0, 3)])
@@ -147,7 +147,7 @@ def test_roll(shift: tuple[int, int, int, int], backend: str):
147147
params = dict(s0=shift[3], s1=shift[2], s2=shift[1], s3=shift[0])
148148
result = workbench.invoke_test("roll", x, {}, params, backend)
149149

150-
assert torch.allclose(result, expected)
150+
assert tensors_match(result, expected)
151151

152152

153153
@pytest.mark.parametrize("mode", ["bilinear", "bicubic"])
@@ -169,4 +169,4 @@ def test_interpolate(mode: str, align_corners: bool, size: str, scale: float, ba
169169

170170
params = dict(mode=mode, h=target[0], w=target[1], align_corners=1 if align_corners else 0)
171171
result = workbench.invoke_test("interpolate", x, {}, params, backend)
172-
assert torch.allclose(result, expected)
172+
assert tensors_match(result, expected)

tests/workbench.py

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import ctypes
2-
from functools import reduce
3-
from typing import Mapping
42
import torch
53
import os
6-
4+
import platform
5+
from functools import reduce
6+
from typing import Mapping
77
from pathlib import Path
8+
from torch import Tensor
89

910
float_ptr = ctypes.POINTER(ctypes.c_float)
1011

@@ -90,20 +91,40 @@ def encode_params(params: Mapping[str, str | int | float]):
9091
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
9192

9293
root_dir = Path(__file__).parent.parent
93-
bin_dir = root_dir / "build" / "bin"
94-
95-
lib = ctypes.CDLL(str(bin_dir / "vision-workbench.dll"))
96-
lib.visp_workbench.argtypes = [
97-
ctypes.c_char_p,
98-
ctypes.POINTER(RawTensor),
99-
ctypes.c_int32,
100-
ctypes.POINTER(RawParam),
101-
ctypes.c_int32,
102-
ctypes.POINTER(ctypes.POINTER(RawTensor)),
103-
ctypes.POINTER(ctypes.c_int32),
104-
ctypes.c_int32,
105-
]
106-
lib.visp_workbench.restype = ctypes.c_int32
94+
95+
def _load_library():
96+
system = platform.system().lower()
97+
if system == "windows":
98+
prefix = ""
99+
suffix = ".dll"
100+
libdir = "bin"
101+
elif system == "darwin":
102+
prefix = "lib"
103+
suffix = ".dylib"
104+
libdir = "lib"
105+
else: # assume Linux / Unix
106+
prefix = "lib"
107+
suffix = ".so"
108+
libdir = "lib"
109+
lib_path = root_dir / "build" / libdir / f"{prefix}vision-workbench{suffix}"
110+
return ctypes.CDLL(str(lib_path))
111+
112+
try:
113+
lib = _load_library()
114+
115+
lib.visp_workbench.argtypes = [
116+
ctypes.c_char_p,
117+
ctypes.POINTER(RawTensor),
118+
ctypes.c_int32,
119+
ctypes.POINTER(RawParam),
120+
ctypes.c_int32,
121+
ctypes.POINTER(ctypes.POINTER(RawTensor)),
122+
ctypes.POINTER(ctypes.c_int32),
123+
ctypes.c_int32,
124+
]
125+
lib.visp_workbench.restype = ctypes.c_int32
126+
except OSError as e:
127+
print(f"Error loading vision-workbench library: {e}")
107128

108129

109130
def invoke_test(
@@ -168,11 +189,14 @@ def to_nhwc(tensor: torch.Tensor):
168189
return tensor.permute(0, 2, 3, 1).contiguous()
169190

170191

171-
def to_nchw(tensor: torch.Tensor):
192+
def to_nchw(tensor: Tensor|list[Tensor]|None):
193+
assert tensor is not None
194+
if isinstance(tensor, list):
195+
return [t.permute(0, 3, 1, 2).contiguous() for t in tensor]
172196
return tensor.permute(0, 3, 1, 2).contiguous()
173197

174198

175-
def convert_to_nhwc(state: dict[str, torch.Tensor], key=""):
199+
def convert_to_nhwc(state: dict[str, Tensor], key=""):
176200
for k, v in state.items():
177201
is_conv = (
178202
v.ndim == 4
@@ -250,7 +274,25 @@ def fuse_conv_2d_batch_norm(
250274
return False # no match
251275

252276

253-
def print_results(result: torch.Tensor, expected: torch.Tensor):
277+
def print_results(result: Tensor, expected: Tensor):
254278
print("\ntorch seed:", torch.initial_seed())
255279
print("\nresult -----", result, sep="\n")
256280
print("\nexpected ---", expected, sep="\n")
281+
282+
283+
def tensors_match(
284+
result: Tensor | list[Tensor] | None,
285+
expected: Tensor | list[Tensor],
286+
rtol=1e-3,
287+
atol=1e-5,
288+
show=False,
289+
):
290+
assert result is not None, "No result returned"
291+
if isinstance(expected, list):
292+
assert isinstance(result, list), "Result is not a list"
293+
assert len(result) == len(expected), f"Expected {len(expected)} tensors, got {len(result)}"
294+
return all(tensors_match(r, e, rtol, atol, show) for r, e in zip(result, expected))
295+
assert isinstance(result, Tensor), "Result is not a tensor"
296+
if show:
297+
print_results(result, expected)
298+
return torch.allclose(result, expected, rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)