44def exists (val ):
55 return val is not None
66
7+ def identity (t ):
8+ return t
9+
10+ def clone_and_detach (t ):
11+ return t .clone ().detach ()
12+
713def apply_tuple_or_single (fn , val ):
814 if isinstance (val , tuple ):
915 return tuple (map (fn , val ))
@@ -17,7 +23,8 @@ def __init__(
1723 layer = None ,
1824 layer_name = 'transformer' ,
1925 layer_save_input = False ,
20- return_embeddings_only = False
26+ return_embeddings_only = False ,
27+ detach = True
2128 ):
2229 super ().__init__ ()
2330 self .vit = vit
@@ -34,9 +41,11 @@ def __init__(
3441 self .layer_save_input = layer_save_input # whether to save input or output of layer
3542 self .return_embeddings_only = return_embeddings_only
3643
44+ self .detach_fn = clone_and_detach if detach else identity
45+
3746 def _hook (self , _ , inputs , output ):
3847 layer_output = inputs if self .layer_save_input else output
39- self .latents = apply_tuple_or_single (lambda t : t . clone (). detach () , layer_output )
48+ self .latents = apply_tuple_or_single (self . detach_fn , layer_output )
4049
4150 def _register_hook (self ):
4251 if not exists (self .layer ):
0 commit comments