@@ -12,8 +12,12 @@ def constant(val, d0, d1=None, d2=None, d3=None, dtype=f32):
1212brange = range
1313
1414def range (d0 , d1 = None , d2 = None , d3 = None , dim = - 1 , dtype = f32 ):
15+
1516 if not isinstance (dtype , c_int ):
16- raise TypeError ("Invalid dtype" )
17+ if isinstance (dtype , int ):
18+ dtype = c_int (dtype )
19+ else :
20+ raise TypeError ("Invalid dtype" )
1721
1822 out = array ()
1923 dims = dim4 (d0 , d1 , d2 , d3 )
@@ -24,7 +28,10 @@ def range(d0, d1=None, d2=None, d3=None, dim=-1, dtype=f32):
2428
2529def iota (d0 , d1 = None , d2 = None , d3 = None , dim = - 1 , tile_dims = None , dtype = f32 ):
2630 if not isinstance (dtype , c_int ):
27- raise TypeError ("Invalid dtype" )
31+ if isinstance (dtype , int ):
32+ dtype = c_int (dtype )
33+ else :
34+ raise TypeError ("Invalid dtype" )
2835
2936 out = array ()
3037 dims = dim4 (d0 , d1 , d2 , d3 )
@@ -42,7 +49,10 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=f32):
4249def randu (d0 , d1 = None , d2 = None , d3 = None , dtype = f32 ):
4350
4451 if not isinstance (dtype , c_int ):
45- raise TypeError ("Invalid dtype" )
52+ if isinstance (dtype , int ):
53+ dtype = c_int (dtype )
54+ else :
55+ raise TypeError ("Invalid dtype" )
4656
4757 out = array ()
4858 dims = dim4 (d0 , d1 , d2 , d3 )
@@ -53,7 +63,10 @@ def randu(d0, d1=None, d2=None, d3=None, dtype=f32):
5363def randn (d0 , d1 = None , d2 = None , d3 = None , dtype = f32 ):
5464
5565 if not isinstance (dtype , c_int ):
56- raise TypeError ("Invalid dtype" )
66+ if isinstance (dtype , int ):
67+ dtype = c_int (dtype )
68+ else :
69+ raise TypeError ("Invalid dtype" )
5770
5871 out = array ()
5972 dims = dim4 (d0 , d1 , d2 , d3 )
@@ -72,7 +85,10 @@ def get_seed():
7285def identity (d0 , d1 = None , d2 = None , d3 = None , dtype = f32 ):
7386
7487 if not isinstance (dtype , c_int ):
75- raise TypeError ("Invalid dtype" )
88+ if isinstance (dtype , int ):
89+ dtype = c_int (dtype )
90+ else :
91+ raise TypeError ("Invalid dtype" )
7692
7793 out = array ()
7894 dims = dim4 (d0 , d1 , d2 , d3 )
0 commit comments