1313from .util import *
1414
1515def constant (val , d0 , d1 = None , d2 = None , d3 = None , dtype = f32 ):
16- out = array ()
16+ out = Array ()
1717 out .arr = constant_array (val , d0 , d1 , d2 , d3 , dtype )
1818 return out
1919
@@ -28,7 +28,7 @@ def range(d0, d1=None, d2=None, d3=None, dim=-1, dtype=f32):
2828 else :
2929 raise TypeError ("Invalid dtype" )
3030
31- out = array ()
31+ out = Array ()
3232 dims = dim4 (d0 , d1 , d2 , d3 )
3333
3434 safe_call (clib .af_range (ct .pointer (out .arr ), 4 , ct .pointer (dims ), dim , dtype ))
@@ -42,7 +42,7 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=f32):
4242 else :
4343 raise TypeError ("Invalid dtype" )
4444
45- out = array ()
45+ out = Array ()
4646 dims = dim4 (d0 , d1 , d2 , d3 )
4747 td = [1 ]* 4
4848
@@ -63,7 +63,7 @@ def randu(d0, d1=None, d2=None, d3=None, dtype=f32):
6363 else :
6464 raise TypeError ("Invalid dtype" )
6565
66- out = array ()
66+ out = Array ()
6767 dims = dim4 (d0 , d1 , d2 , d3 )
6868
6969 safe_call (clib .af_randu (ct .pointer (out .arr ), 4 , ct .pointer (dims ), dtype ))
@@ -77,7 +77,7 @@ def randn(d0, d1=None, d2=None, d3=None, dtype=f32):
7777 else :
7878 raise TypeError ("Invalid dtype" )
7979
80- out = array ()
80+ out = Array ()
8181 dims = dim4 (d0 , d1 , d2 , d3 )
8282
8383 safe_call (clib .af_randn (ct .pointer (out .arr ), 4 , ct .pointer (dims ), dtype ))
@@ -99,22 +99,22 @@ def identity(d0, d1=None, d2=None, d3=None, dtype=f32):
9999 else :
100100 raise TypeError ("Invalid dtype" )
101101
102- out = array ()
102+ out = Array ()
103103 dims = dim4 (d0 , d1 , d2 , d3 )
104104
105105 safe_call (clib .af_identity (ct .pointer (out .arr ), 4 , ct .pointer (dims ), dtype ))
106106 return out
107107
108108def diag (a , num = 0 , extract = True ):
109- out = array ()
109+ out = Array ()
110110 if extract :
111111 safe_call (clib .af_diag_extract (ct .pointer (out .arr ), a .arr , ct .c_int (num )))
112112 else :
113113 safe_call (clib .af_diag_create (ct .pointer (out .arr ), a .arr , ct .c_int (num )))
114114 return out
115115
116116def join (dim , first , second , third = None , fourth = None ):
117- out = array ()
117+ out = Array ()
118118 if (third is None and fourth is None ):
119119 safe_call (clib .af_join (ct .pointer (out .arr ), dim , first .arr , second .arr ))
120120 else :
@@ -131,43 +131,43 @@ def join(dim, first, second, third=None, fourth=None):
131131
132132
133133def tile (a , d0 , d1 = 1 , d2 = 1 , d3 = 1 ):
134- out = array ()
134+ out = Array ()
135135 safe_call (clib .af_tile (ct .pointer (out .arr ), a .arr , d0 , d1 , d2 , d3 ))
136136 return out
137137
138138
139139def reorder (a , d0 = 1 , d1 = 0 , d2 = 2 , d3 = 3 ):
140- out = array ()
140+ out = Array ()
141141 safe_call (clib .af_reorder (ct .pointer (out .arr ), a .arr , d0 , d1 , d2 , d3 ))
142142 return out
143143
144144def shift (a , d0 , d1 = 0 , d2 = 0 , d3 = 0 ):
145- out = array ()
145+ out = Array ()
146146 safe_call (clib .af_shift (ct .pointer (out .arr ), a .arr , d0 , d1 , d2 , d3 ))
147147 return out
148148
149149def moddims (a , d0 , d1 = 1 , d2 = 1 , d3 = 1 ):
150- out = array ()
150+ out = Array ()
151151 dims = dim4 (d0 , d1 , d2 , d3 )
152152 safe_call (clib .af_moddims (ct .pointer (out .arr ), a .arr , 4 , ct .pointer (dims )))
153153 return out
154154
155155def flat (a ):
156- out = array ()
156+ out = Array ()
157157 safe_call (clib .af_flat (ct .pointer (out .arr ), a .arr ))
158158 return out
159159
160160def flip (a , dim = 0 ):
161- out = array ()
161+ out = Array ()
162162 safe_call (clib .af_flip (ct .pointer (out .arr ), a .arr , ct .c_int (dim )))
163163 return out
164164
165165def lower (a , is_unit_diag = False ):
166- out = array ()
166+ out = Array ()
167167 safe_call (clib .af_lower (ct .pointer (out .arr ), a .arr , is_unit_diag ))
168168 return out
169169
170170def upper (a , is_unit_diag = False ):
171- out = array ()
171+ out = Array ()
172172 safe_call (clib .af_upper (ct .pointer (out .arr ), a .arr , is_unit_diag ))
173173 return out
0 commit comments