@@ -59,7 +59,7 @@ def numba_funcify_Alloc(op, node, **kwargs):
5959 shape_var_item_names = [f"{ name } _item" for name in shape_var_names ]
6060 shapes_to_items_src = indent (
6161 "\n " .join (
62- f"{ item_name } = { shape_name } .item( )"
62+ f"{ item_name } = int( { shape_name } )"
6363 for item_name , shape_name in zip (
6464 shape_var_item_names , shape_var_names , strict = True
6565 )
@@ -74,24 +74,26 @@ def numba_funcify_Alloc(op, node, **kwargs):
7474 f'if val.shape[{ - i - 1 } ] == 1 and scalar_shape[{ - i - 1 } ] != 1: raise ValueError("{ Alloc ._runtime_broadcast_error_msg } ")'
7575 )
7676 check_runtime_broadcast_src = indent ("\n " .join (check_runtime_broadcast ), " " * 4 )
77-
77+ dtype = node . inputs [ 0 ]. type . dtype
7878 alloc_def_src = f"""
7979def alloc(val, { ", " .join (shape_var_names )} ):
8080{ shapes_to_items_src }
8181 scalar_shape = { create_tuple_string (shape_var_item_names )}
8282{ check_runtime_broadcast_src }
83- res = np.empty(scalar_shape, dtype=val. dtype)
83+ res = np.empty(scalar_shape, dtype=np. { dtype } )
8484 res[...] = val
8585 return res
8686 """
8787 alloc_fn = compile_numba_function_src (
8888 alloc_def_src ,
8989 "alloc" ,
9090 globals () | {"np" : np },
91+ write_to_disk = True ,
9192 )
9293
94+ cache_version = - 1
9395 cache_key = sha256 (
94- str ((type (op ), node .inputs [0 ].type .broadcastable )).encode ()
96+ str ((type (op ), node .inputs [0 ].type .broadcastable , cache_version )).encode ()
9597 ).hexdigest ()
9698 return numba_basic .numba_njit (alloc_fn ), cache_key
9799
0 commit comments