Skip to content

Commit e84cbb7

Browse files
committed
Adding device helper functions
1 parent f728d7a commit e84cbb7

File tree

4 files changed

+85
-5
lines changed

4 files changed

+85
-5
lines changed

arrayfire/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,4 @@
22
from .data import *
33
from .util import *
44
from .algorithm import *
5-
6-
def info():
7-
clib.af_info()
5+
from .device import *

arrayfire/device.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from .library import *
2+
from ctypes import *
3+
from .util import (safe_call, to_str)
4+
5+
def info():
6+
safe_call(clib.af_info())
7+
8+
def device_info():
9+
c_char_256 = c_char * 256
10+
device_name = c_char_256()
11+
backend_name = c_char_256()
12+
toolkit = c_char_256()
13+
compute = c_char_256()
14+
15+
safe_call(clib.af_device_info(pointer(device_name), pointer(backend_name), \
16+
pointer(toolkit), pointer(compute)))
17+
dev_info = {}
18+
dev_info['device'] = to_str(device_name)
19+
dev_info['backend'] = to_str(backend_name)
20+
dev_info['toolkit'] = to_str(toolkit)
21+
dev_info['compute'] = to_str(compute)
22+
23+
return dev_info
24+
25+
def get_device_count():
26+
c_num = c_int(0)
27+
safe_call(clib.af_get_device_count(pointer(c_num)))
28+
return c_num.value
29+
30+
def get_device():
31+
c_dev = c_int(0)
32+
safe_call(clib.af_get_device(pointer(c_dev)))
33+
return c_dev.value
34+
35+
def set_device(num):
36+
safe_call(clib.af_set_device(num))
37+
38+
def is_dbl_supported(device=None):
39+
dev = device if device is not None else get_device()
40+
res = c_bool(False)
41+
safe_call(clib.af_get_dbl_support(pointer(res), dev))
42+
return res.value
43+
44+
def sync(device=None):
45+
dev = device if device is not None else get_device()
46+
safe_call(clib.af_sync(dev))
47+
48+
def device_mem_info():
49+
alloc_bytes = c_size_t(0)
50+
alloc_buffers = c_size_t(0)
51+
lock_bytes = c_size_t(0)
52+
lock_buffers = c_size_t(0)
53+
safe_call(clib.af_device_mem_info(pointer(alloc_bytes), pointer(alloc_buffers),\
54+
pointer(lock_bytes), pointer(lock_buffers)))
55+
mem_info = {}
56+
mem_info['alloc'] = {'buffers' : alloc_buffers.value, 'bytes' : alloc_bytes.value}
57+
mem_info['lock'] = {'buffers' : lock_buffers.value, 'bytes' : lock_bytes.value}
58+
return mem_info

arrayfire/util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import inspect
21
from .library import *
32

43
def dim4(d0=1, d1=1, d2=1, d3=1):
@@ -22,9 +21,12 @@ def dim4_tuple(dims):
2221
def is_valid_scalar(a):
2322
return isinstance(a, float) or isinstance(a, int) or isinstance(a, complex)
2423

24+
def to_str(c_str):
25+
return str(c_str.value.decode('utf-8'))
26+
2527
def safe_call(af_error):
2628
if (af_error != AF_SUCCESS.value):
2729
c_err_str = c_char_p(0)
2830
c_err_len = c_longlong(0)
2931
clib.af_get_last_error(pointer(c_err_str), pointer(c_err_len))
30-
raise RuntimeError(c_err_str.value, af_error)
32+
raise RuntimeError('test', to_str(c_err_str), af_error)

tests/simple_device.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/usr/bin/python
2+
import arrayfire as af
3+
4+
af.info()
5+
print(af.device_info())
6+
print(af.get_device_count())
7+
print(af.is_dbl_supported())
8+
af.sync()
9+
10+
print('starting the loop')
11+
for k in range(af.get_device_count()):
12+
af.set_device(k)
13+
dev = af.get_device()
14+
assert(k == dev)
15+
16+
print(af.is_dbl_supported(k))
17+
18+
a = af.randu(100, 100)
19+
af.sync(dev)
20+
mem_info = af.device_mem_info()
21+
assert(mem_info['alloc']['buffers'] == 1)
22+
assert(mem_info[ 'lock']['buffers'] == 1)

0 commit comments

Comments
 (0)