@@ -41,6 +41,8 @@ cdef class patch:
4141 functions_dict = {}
4242
4343 def __cinit__ (self ):
44+ cdef int pi, oi
45+
4446 umaths = [i for i in dir (mu) if isinstance (getattr (mu, i), np.ufunc)]
4547 self .functions_count = 0
4648 for umath in umaths:
@@ -51,23 +53,31 @@ cdef class patch:
5153
5254 func_number = 0
5355 for umath in umaths:
54- mkl_umath = getattr (mu, umath)
55- np_umath = getattr (nu, umath)
56- c_mkl_umath = < cnp.ufunc> mkl_umath
57- c_np_umath = < cnp.ufunc> np_umath
58- for type in mkl_umath.types:
59- np_index = np_umath.types.index(type )
60- self .functions[func_number].original_function = c_np_umath.functions[np_index]
61- mkl_index = mkl_umath.types.index(type )
62- self .functions[func_number].patch_function = c_mkl_umath.functions[mkl_index]
63-
64- nargs = c_mkl_umath.nargs
65- self .functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
66- for i in range (nargs):
67- self .functions[func_number].signature[i] = c_mkl_umath.types[mkl_index* nargs + i]
68-
69- self .functions_dict[(umath, type )] = func_number
70- func_number = func_number + 1
56+ patch_umath = getattr (mu, umath)
57+ c_patch_umath = < cnp.ufunc> patch_umath
58+ c_orig_umath = < cnp.ufunc> getattr (nu, umath)
59+ nargs = c_patch_umath.nargs
60+ for pi in range (c_patch_umath.ntypes):
61+ oi = 0
62+ while oi < c_orig_umath.ntypes:
63+ found = True
64+ for i in range (c_patch_umath.nargs):
65+ if c_patch_umath.types[pi * nargs + i] != c_orig_umath.types[oi * nargs + i]:
66+ found = False
67+ break
68+ if found == True :
69+ break
70+ oi = oi + 1
71+ if oi < c_orig_umath.ntypes:
72+ self .functions[func_number].original_function = c_orig_umath.functions[oi]
73+ self .functions[func_number].patch_function = c_patch_umath.functions[pi]
74+ self .functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
75+ for i in range (nargs):
76+ self .functions[func_number].signature[i] = c_patch_umath.types[pi * nargs + i]
77+ self .functions_dict[(umath, patch_umath.types[pi])] = func_number
78+ func_number = func_number + 1
79+ else :
80+ raise RuntimeError (" Unable to find original function for: " + umath + " " + patch_umath.types[pi])
7181
7282 def __dealloc__ (self ):
7383 for i in range (self .functions_count):
0 commit comments