Skip to content

Commit f86e052

Browse files
committed
offer way for extractor to return latents without detaching them
1 parent 2fa2b62 commit f86e052

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vit-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '0.35.7',
6+
version = '0.35.8',
77
license='MIT',
88
description = 'Vision Transformer (ViT) - Pytorch',
99
long_description_content_type = 'text/markdown',

vit_pytorch/extractor.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
def exists(val):
55
return val is not None
66

7+
def identity(t):
8+
return t
9+
10+
def clone_and_detach(t):
11+
return t.clone().detach()
12+
713
def apply_tuple_or_single(fn, val):
814
if isinstance(val, tuple):
915
return tuple(map(fn, val))
@@ -17,7 +23,8 @@ def __init__(
1723
layer = None,
1824
layer_name = 'transformer',
1925
layer_save_input = False,
20-
return_embeddings_only = False
26+
return_embeddings_only = False,
27+
detach = True
2128
):
2229
super().__init__()
2330
self.vit = vit
@@ -34,9 +41,11 @@ def __init__(
3441
self.layer_save_input = layer_save_input # whether to save input or output of layer
3542
self.return_embeddings_only = return_embeddings_only
3643

44+
self.detach_fn = clone_and_detach if detach else identity
45+
3746
def _hook(self, _, inputs, output):
3847
layer_output = inputs if self.layer_save_input else output
39-
self.latents = apply_tuple_or_single(lambda t: t.clone().detach(), layer_output)
48+
self.latents = apply_tuple_or_single(self.detach_fn, layer_output)
4049

4150
def _register_hook(self):
4251
if not exists(self.layer):

0 commit comments

Comments
 (0)