Skip to content

Commit f382419

Browse files
committed
Generalizing af.constant to handle complex and 64 bit data types
1 parent 4cb7daa commit f382419

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

arrayfire/data.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,22 @@ def constant(val, d0, d1=None, d2=None, d3=None, dtype=f32):
4242

4343
out = array()
4444
dims = dim4(d0, d1, d2, d3)
45-
c_val = c_double(val)
46-
clib.af_constant(pointer(out.arr), c_val, 4, pointer(dims), dtype)
45+
46+
if isinstance(val, complex):
47+
c_real = c_double(val.real)
48+
c_imag = c_double(val.imag)
49+
50+
if (dtype != c32 and dtype != c64):
51+
dtype = c32
52+
53+
clib.af_constant_complex(pointer(out.arr), c_real, c_imag, 4, pointer(dims), dtype)
54+
elif dtype == s64:
55+
c_val = c_longlong(val.real)
56+
clib.af_constant_long(pointer(out.arr), c_val, 4, pointer(dims))
57+
elif dtype == u64:
58+
c_val = c_ulonglong(val.real)
59+
clib.af_constant_ulong(pointer(out.arr), c_val, 4, pointer(dims))
60+
else:
61+
c_val = c_double(val)
62+
clib.af_constant(pointer(out.arr), c_val, 4, pointer(dims), dtype)
4763
return out

0 commit comments

Comments
 (0)