Skip to content

Commit 8d9b0b9

Browse files
committed
Adding tests for array class
- Also changed checks for data types in data creation
1 parent 80ab970 commit 8d9b0b9

File tree

3 files changed

+42
-7
lines changed

3 files changed

+42
-7
lines changed

arrayfire/array.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@ def create_array(buf, numdims, idims, dtype):
1111
return out_arr
1212

1313
def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
14+
1415
if not isinstance(dtype, c_int):
15-
raise TypeError("Invalid dtype")
16+
if isinstance(dtype, int):
17+
dtype = c_int(dtype)
18+
else:
19+
raise TypeError("Invalid dtype")
1620

1721
out = c_longlong(0)
1822
dims = dim4(d0, d1, d2, d3)
@@ -130,7 +134,7 @@ def dims(self):
130134
def type(self):
131135
dty = c_int(f32.value)
132136
safe_call(clib.af_get_type(pointer(dty), self.arr))
133-
return dty
137+
return dty.value
134138

135139
def __add__(self, other):
136140
return binary_func(self, other, clib.af_add)

arrayfire/data.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@ def constant(val, d0, d1=None, d2=None, d3=None, dtype=f32):
1212
brange = range
1313

1414
def range(d0, d1=None, d2=None, d3=None, dim=-1, dtype=f32):
15+
1516
if not isinstance(dtype, c_int):
16-
raise TypeError("Invalid dtype")
17+
if isinstance(dtype, int):
18+
dtype = c_int(dtype)
19+
else:
20+
raise TypeError("Invalid dtype")
1721

1822
out = array()
1923
dims = dim4(d0, d1, d2, d3)
@@ -24,7 +28,10 @@ def range(d0, d1=None, d2=None, d3=None, dim=-1, dtype=f32):
2428

2529
def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=f32):
2630
if not isinstance(dtype, c_int):
27-
raise TypeError("Invalid dtype")
31+
if isinstance(dtype, int):
32+
dtype = c_int(dtype)
33+
else:
34+
raise TypeError("Invalid dtype")
2835

2936
out = array()
3037
dims = dim4(d0, d1, d2, d3)
@@ -42,7 +49,10 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=f32):
4249
def randu(d0, d1=None, d2=None, d3=None, dtype=f32):
4350

4451
if not isinstance(dtype, c_int):
45-
raise TypeError("Invalid dtype")
52+
if isinstance(dtype, int):
53+
dtype = c_int(dtype)
54+
else:
55+
raise TypeError("Invalid dtype")
4656

4757
out = array()
4858
dims = dim4(d0, d1, d2, d3)
@@ -53,7 +63,10 @@ def randu(d0, d1=None, d2=None, d3=None, dtype=f32):
5363
def randn(d0, d1=None, d2=None, d3=None, dtype=f32):
5464

5565
if not isinstance(dtype, c_int):
56-
raise TypeError("Invalid dtype")
66+
if isinstance(dtype, int):
67+
dtype = c_int(dtype)
68+
else:
69+
raise TypeError("Invalid dtype")
5770

5871
out = array()
5972
dims = dim4(d0, d1, d2, d3)
@@ -72,7 +85,10 @@ def get_seed():
7285
def identity(d0, d1=None, d2=None, d3=None, dtype=f32):
7386

7487
if not isinstance(dtype, c_int):
75-
raise TypeError("Invalid dtype")
88+
if isinstance(dtype, int):
89+
dtype = c_int(dtype)
90+
else:
91+
raise TypeError("Invalid dtype")
7692

7793
out = array()
7894
dims = dim4(d0, d1, d2, d3)

tests/simple_array.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/usr/bin/python
2+
import arrayfire as af
3+
import array as host
4+
5+
a = af.array([1, 2, 3])
6+
af.print_array(a)
7+
print(a.numdims(), a.dims(), a.type())
8+
9+
a = af.array(host.array('d', [4, 5, 6]))
10+
af.print_array(a)
11+
print(a.numdims(), a.dims(), a.type())
12+
13+
a = af.array(host.array('l', [7, 8, 9] * 4), (2, 5))
14+
af.print_array(a)
15+
print(a.numdims(), a.dims(), a.type())

0 commit comments

Comments
 (0)