1+ from sys import version_info
12from .library import *
23from .array import *
34from .util import *
45
6+
7+ def constant (val , d0 , d1 = None , d2 = None , d3 = None , dtype = f32 ):
8+
9+ if not isinstance (dtype , c_int ):
10+ raise TypeError ("Invalid dtype" )
11+
12+ out = array ()
13+ dims = dim4 (d0 , d1 , d2 , d3 )
14+
15+ if isinstance (val , complex ):
16+ c_real = c_double (val .real )
17+ c_imag = c_double (val .imag )
18+
19+ if (dtype != c32 and dtype != c64 ):
20+ dtype = c32
21+
22+ clib .af_constant_complex (pointer (out .arr ), c_real , c_imag , 4 , pointer (dims ), dtype )
23+ elif dtype == s64 :
24+ c_val = c_longlong (val .real )
25+ clib .af_constant_long (pointer (out .arr ), c_val , 4 , pointer (dims ))
26+ elif dtype == u64 :
27+ c_val = c_ulonglong (val .real )
28+ clib .af_constant_ulong (pointer (out .arr ), c_val , 4 , pointer (dims ))
29+ else :
30+ c_val = c_double (val )
31+ clib .af_constant (pointer (out .arr ), c_val , 4 , pointer (dims ), dtype )
32+ return out
33+
34+ # Store builtin range function to be used later
35+ brange = range
36+
37+ def range (d0 , d1 = None , d2 = None , d3 = None , dim = - 1 , dtype = f32 ):
38+ if not isinstance (dtype , c_int ):
39+ raise TypeError ("Invalid dtype" )
40+
41+ out = array ()
42+ dims = dim4 (d0 , d1 , d2 , d3 )
43+
44+ clib .af_range (pointer (out .arr ), 4 , pointer (dims ), dim , dtype )
45+ return out
46+
47+
48+ def iota (d0 , d1 = None , d2 = None , d3 = None , dim = - 1 , tile_dims = None , dtype = f32 ):
49+ if not isinstance (dtype , c_int ):
50+ raise TypeError ("Invalid dtype" )
51+
52+ out = array ()
53+ dims = dim4 (d0 , d1 , d2 , d3 )
54+ td = [1 ]* 4
55+
56+ if tile_dims is not None :
57+ for i in brange (len (tile_dims )):
58+ td [i ] = tile_dims [i ]
59+
60+ tdims = dim4 (td [0 ], td [1 ], td [2 ], td [3 ])
61+
62+ clib .af_iota (pointer (out .arr ), 4 , pointer (dims ), 4 , pointer (tdims ), dtype )
63+ return out
64+
565def randu (d0 , d1 = None , d2 = None , d3 = None , dtype = f32 ):
666
767 if not isinstance (dtype , c_int ):
@@ -24,6 +84,14 @@ def randn(d0, d1=None, d2=None, d3=None, dtype=f32):
2484 clib .af_randn (pointer (out .arr ), 4 , pointer (dims ), dtype )
2585 return out
2686
87+ def set_seed (seed = 0 ):
88+ clib .af_set_seed (c_ulonglong (seed ))
89+
90+ def get_seed ():
91+ seed = c_ulonglong (0 )
92+ clib .af_get_seed (pointer (seed ))
93+ return seed .value
94+
2795def identity (d0 , d1 = None , d2 = None , d3 = None , dtype = f32 ):
2896
2997 if not isinstance (dtype , c_int ):
@@ -35,29 +103,69 @@ def identity(d0, d1=None, d2=None, d3=None, dtype=f32):
35103 clib .af_identity (pointer (out .arr ), 4 , pointer (dims ), dtype )
36104 return out
37105
38- def constant (val , d0 , d1 = None , d2 = None , d3 = None , dtype = f32 ):
106+ def diag (a , num = 0 , extract = True ):
107+ out = array ()
108+ if extract :
109+ clib .af_diag_extract (pointer (out .arr ), a .arr , c_int (num ))
110+ else :
111+ clib .af_diag_create (pointer (out .arr ), a .arr , c_int (num ))
112+ return out
39113
40- if not isinstance (dtype , c_int ):
41- raise TypeError ("Invalid dtype" )
114+ def join (dim , first , second , third = None , fourth = None ):
115+ out = array ()
116+ if (third is None and fourth is None ):
117+ clib .af_join (pointer (out .arr ), dim , first .arr , second .arr )
118+ else :
119+ c_array_vec = dim4 (first , second , 0 , 0 )
120+ num = 2
121+ if third is not None :
122+ c_array_vec [num ] = third .arr
123+ num += 1
124+ if fourth is not None :
125+ c_array_vec [num ] = fourth .arr
126+ num += 1
127+
128+ clib .af_join_many (pointer (out .arr ), dim , num , pointer (c_array_vec ))
42129
130+
131+ def tile (a , d0 , d1 = 1 , d2 = 1 , d3 = 1 ):
132+ out = array ()
133+ clib .af_tile (pointer (out .arr ), a .arr , d0 , d1 , d2 , d3 )
134+ return out
135+
136+
137+ def reorder (a , d0 = 1 , d1 = 0 , d2 = 2 , d3 = 3 ):
138+ out = array ()
139+ clib .af_reorder (pointer (out .arr ), a .arr , d0 , d1 , d2 , d3 )
140+ return out
141+
142+ def shift (a , d0 , d1 = 0 , d2 = 0 , d3 = 0 ):
143+ out = array ()
144+ clib .af_shift (pointer (out .arr ), a .arr , d0 , d1 , d2 , d3 )
145+ return out
146+
147+ def moddims (a , d0 , d1 = 1 , d2 = 1 , d3 = 1 ):
43148 out = array ()
44149 dims = dim4 (d0 , d1 , d2 , d3 )
150+ clib .af_moddims (pointer (out .arr ), a .arr , 4 , pointer (dims ))
151+ return out
45152
46- if isinstance (val , complex ):
47- c_real = c_double (val .real )
48- c_imag = c_double (val .imag )
153+ def flat (a ):
154+ out = array ()
155+ clib .af_flat (pointer (out .arr ), a .arr )
156+ return out
49157
50- if (dtype != c32 and dtype != c64 ):
51- dtype = c32
158+ def flip (a , dim = 0 ):
159+ out = array ()
160+ clib .af_flip (pointer (out .arr ), a .arr , c_int (dim ))
161+ return out
52162
53- clib .af_constant_complex (pointer (out .arr ), c_real , c_imag , 4 , pointer (dims ), dtype )
54- elif dtype == s64 :
55- c_val = c_longlong (val .real )
56- clib .af_constant_long (pointer (out .arr ), c_val , 4 , pointer (dims ))
57- elif dtype == u64 :
58- c_val = c_ulonglong (val .real )
59- clib .af_constant_ulong (pointer (out .arr ), c_val , 4 , pointer (dims ))
60- else :
61- c_val = c_double (val )
62- clib .af_constant (pointer (out .arr ), c_val , 4 , pointer (dims ), dtype )
163+ def lower (a , is_unit_diag = False ):
164+ out = array ()
165+ clib .af_lower (pointer (out .arr ), a .arr , is_unit_diag )
166+ return out
167+
168+ def upper (a , is_unit_diag = False ):
169+ out = array ()
170+ clib .af_upper (pointer (out .arr ), a .arr , is_unit_diag )
63171 return out
0 commit comments