Skip to content

Commit e05cd6d

Browse files
committed
some models only return embeddings with some kwarg on forward
1 parent b46233c commit e05cd6d

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setup(
77
name = 'vit-pytorch',
88
packages = find_packages(exclude=['examples']),
9-
version = '1.11.2',
9+
version = '1.11.3',
1010
license='MIT',
1111
description = 'Vision Transformer (ViT) - Pytorch',
1212
long_description = long_description,

vit_pytorch/accept_video_wrapper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def __init__(
4343
def forward(
4444
self,
4545
video, # (b c t h w)
46-
eval_with_no_grad = False
46+
eval_with_no_grad = False,
47+
forward_kwargs = dict()
4748
):
4849
add_time_pos_emb = self.add_time_pos_emb
4950
time = video.shape[2]
@@ -67,7 +68,7 @@ def forward(
6768
context = torch.no_grad if eval_with_no_grad else nullcontext
6869

6970
with context():
70-
outputs = func(video)
71+
outputs = func(video, **forward_kwargs)
7172

7273
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
7374

0 commit comments

Comments
 (0)