@@ -60,9 +60,9 @@ def binary_func(lhs, rhs, c_func):
6060
6161 if (is_number (rhs )):
6262 ldims = dim4_tuple (lhs .dims ())
63- lty = lhs . type ( )
63+ rty = number_dtype ( rhs )
6464 other = array ()
65- other .arr = constant_array (rhs , ldims [0 ], ldims [1 ], ldims [2 ], ldims [3 ], lty )
65+ other .arr = constant_array (rhs , ldims [0 ], ldims [1 ], ldims [2 ], ldims [3 ], rty )
6666 elif not isinstance (rhs , array ):
6767 raise TypeError ("Invalid parameter to binary function" )
6868
@@ -76,9 +76,9 @@ def binary_funcr(lhs, rhs, c_func):
7676
7777 if (is_number (lhs )):
7878 rdims = dim4_tuple (rhs .dims ())
79- rty = rhs . type ( )
79+ lty = number_dtype ( lhs )
8080 other = array ()
81- other .arr = constant_array (lhs , rdims [0 ], rdims [1 ], rdims [2 ], rdims [3 ], rty )
81+ other .arr = constant_array (lhs , rdims [0 ], rdims [1 ], rdims [2 ], rdims [3 ], lty )
8282 elif not isinstance (lhs , array ):
8383 raise TypeError ("Invalid parameter to binary function" )
8484
@@ -179,13 +179,18 @@ def __init__(self, src=None, dims=(0,), type_char=None):
179179
180180 def copy (self ):
181181 out = array ()
182- safe_call (clib .af_retain_array (ct .pointer (out .arr ), self .arr ))
182+ safe_call (clib .af_copy_array (ct .pointer (out .arr ), self .arr ))
183183 return out
184184
185185 def __del__ (self ):
186186 if (self .arr .value != 0 ):
187187 clib .af_release_array (self .arr )
188188
189+ def device_ptr (self ):
190+ ptr = ctypes .c_void_p (0 )
191+ clib .af_get_device_ptr (ct .pointer (ptr ), self .arr )
192+ return ptr .value
193+
189194 def elements (self ):
190195 num = ct .c_ulonglong (0 )
191196 safe_call (clib .af_get_elements (ct .pointer (num ), self .arr ))
0 commit comments