Skip to content

Commit 381dac9

Browse files
committed
Porting all functions from data.h
1 parent 5acbfe1 commit 381dac9

File tree

1 file changed

+126
-18
lines changed

1 file changed

+126
-18
lines changed

arrayfire/data.py

Lines changed: 126 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,67 @@
1+
from sys import version_info
12
from .library import *
23
from .array import *
34
from .util import *
45

6+
7+
def constant(val, d0, d1=None, d2=None, d3=None, dtype=f32):
8+
9+
if not isinstance(dtype, c_int):
10+
raise TypeError("Invalid dtype")
11+
12+
out = array()
13+
dims = dim4(d0, d1, d2, d3)
14+
15+
if isinstance(val, complex):
16+
c_real = c_double(val.real)
17+
c_imag = c_double(val.imag)
18+
19+
if (dtype != c32 and dtype != c64):
20+
dtype = c32
21+
22+
clib.af_constant_complex(pointer(out.arr), c_real, c_imag, 4, pointer(dims), dtype)
23+
elif dtype == s64:
24+
c_val = c_longlong(val.real)
25+
clib.af_constant_long(pointer(out.arr), c_val, 4, pointer(dims))
26+
elif dtype == u64:
27+
c_val = c_ulonglong(val.real)
28+
clib.af_constant_ulong(pointer(out.arr), c_val, 4, pointer(dims))
29+
else:
30+
c_val = c_double(val)
31+
clib.af_constant(pointer(out.arr), c_val, 4, pointer(dims), dtype)
32+
return out
33+
34+
# Store builtin range function to be used later
35+
brange = range
36+
37+
def range(d0, d1=None, d2=None, d3=None, dim=-1, dtype=f32):
38+
if not isinstance(dtype, c_int):
39+
raise TypeError("Invalid dtype")
40+
41+
out = array()
42+
dims = dim4(d0, d1, d2, d3)
43+
44+
clib.af_range(pointer(out.arr), 4, pointer(dims), dim, dtype)
45+
return out
46+
47+
48+
def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=f32):
49+
if not isinstance(dtype, c_int):
50+
raise TypeError("Invalid dtype")
51+
52+
out = array()
53+
dims = dim4(d0, d1, d2, d3)
54+
td=[1]*4
55+
56+
if tile_dims is not None:
57+
for i in brange(len(tile_dims)):
58+
td[i] = tile_dims[i]
59+
60+
tdims = dim4(td[0], td[1], td[2], td[3])
61+
62+
clib.af_iota(pointer(out.arr), 4, pointer(dims), 4, pointer(tdims), dtype)
63+
return out
64+
565
def randu(d0, d1=None, d2=None, d3=None, dtype=f32):
666

767
if not isinstance(dtype, c_int):
@@ -24,6 +84,14 @@ def randn(d0, d1=None, d2=None, d3=None, dtype=f32):
2484
clib.af_randn(pointer(out.arr), 4, pointer(dims), dtype)
2585
return out
2686

87+
def set_seed(seed=0):
88+
clib.af_set_seed(c_ulonglong(seed))
89+
90+
def get_seed():
91+
seed = c_ulonglong(0)
92+
clib.af_get_seed(pointer(seed))
93+
return seed.value
94+
2795
def identity(d0, d1=None, d2=None, d3=None, dtype=f32):
2896

2997
if not isinstance(dtype, c_int):
@@ -35,29 +103,69 @@ def identity(d0, d1=None, d2=None, d3=None, dtype=f32):
35103
clib.af_identity(pointer(out.arr), 4, pointer(dims), dtype)
36104
return out
37105

38-
def constant(val, d0, d1=None, d2=None, d3=None, dtype=f32):
106+
def diag(a, num=0, extract=True):
107+
out = array()
108+
if extract:
109+
clib.af_diag_extract(pointer(out.arr), a.arr, c_int(num))
110+
else:
111+
clib.af_diag_create(pointer(out.arr), a.arr, c_int(num))
112+
return out
39113

40-
if not isinstance(dtype, c_int):
41-
raise TypeError("Invalid dtype")
114+
def join(dim, first, second, third=None, fourth=None):
115+
out = array()
116+
if (third is None and fourth is None):
117+
clib.af_join(pointer(out.arr), dim, first.arr, second.arr)
118+
else:
119+
c_array_vec = dim4(first, second, 0, 0)
120+
num = 2
121+
if third is not None:
122+
c_array_vec[num] = third.arr
123+
num+=1
124+
if fourth is not None:
125+
c_array_vec[num] = fourth.arr
126+
num+=1
127+
128+
clib.af_join_many(pointer(out.arr), dim, num, pointer(c_array_vec))
42129

130+
131+
def tile(a, d0, d1=1, d2=1, d3=1):
132+
out = array()
133+
clib.af_tile(pointer(out.arr), a.arr, d0, d1, d2, d3)
134+
return out
135+
136+
137+
def reorder(a, d0=1, d1=0, d2=2, d3=3):
138+
out = array()
139+
clib.af_reorder(pointer(out.arr), a.arr, d0, d1, d2, d3)
140+
return out
141+
142+
def shift(a, d0, d1=0, d2=0, d3=0):
143+
out = array()
144+
clib.af_shift(pointer(out.arr), a.arr, d0, d1, d2, d3)
145+
return out
146+
147+
def moddims(a, d0, d1=1, d2=1, d3=1):
43148
out = array()
44149
dims = dim4(d0, d1, d2, d3)
150+
clib.af_moddims(pointer(out.arr), a.arr, 4, pointer(dims))
151+
return out
45152

46-
if isinstance(val, complex):
47-
c_real = c_double(val.real)
48-
c_imag = c_double(val.imag)
153+
def flat(a):
154+
out = array()
155+
clib.af_flat(pointer(out.arr), a.arr)
156+
return out
49157

50-
if (dtype != c32 and dtype != c64):
51-
dtype = c32
158+
def flip(a, dim=0):
159+
out = array()
160+
clib.af_flip(pointer(out.arr), a.arr, c_int(dim))
161+
return out
52162

53-
clib.af_constant_complex(pointer(out.arr), c_real, c_imag, 4, pointer(dims), dtype)
54-
elif dtype == s64:
55-
c_val = c_longlong(val.real)
56-
clib.af_constant_long(pointer(out.arr), c_val, 4, pointer(dims))
57-
elif dtype == u64:
58-
c_val = c_ulonglong(val.real)
59-
clib.af_constant_ulong(pointer(out.arr), c_val, 4, pointer(dims))
60-
else:
61-
c_val = c_double(val)
62-
clib.af_constant(pointer(out.arr), c_val, 4, pointer(dims), dtype)
163+
def lower(a, is_unit_diag=False):
164+
out = array()
165+
clib.af_lower(pointer(out.arr), a.arr, is_unit_diag)
166+
return out
167+
168+
def upper(a, is_unit_diag=False):
169+
out = array()
170+
clib.af_upper(pointer(out.arr), a.arr, is_unit_diag)
63171
return out

0 commit comments

Comments
 (0)