Skip to content

Commit b46233c

Browse files
committed
Allow invoking the wrapped image net in eval mode under torch.no_grad (new `eval_with_no_grad` option on forward)
1 parent 68e13a3 commit b46233c

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setup(
77
name = 'vit-pytorch',
88
packages = find_packages(exclude=['examples']),
9-
version = '1.11.1',
9+
version = '1.11.2',
1010
license='MIT',
1111
description = 'Vision Transformer (ViT) - Pytorch',
1212
long_description = long_description,

vit_pytorch/accept_video_wrapper.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from contextlib import nullcontext
2+
13
import torch
24
from torch import is_tensor, randn
35
from torch.nn import Module, Parameter
@@ -40,7 +42,8 @@ def __init__(
4042

4143
def forward(
4244
self,
43-
video # (b c t h w)
45+
video, # (b c t h w)
46+
eval_with_no_grad = False
4447
):
4548
add_time_pos_emb = self.add_time_pos_emb
4649
time = video.shape[2]
@@ -54,9 +57,17 @@ def forward(
5457

5558
video = rearrange(video, 'b t ... -> (b t) ...')
5659

60+
# forward through image net for outputs
61+
5762
func = getattr(self.image_net, self.forward_function)
5863

59-
outputs = func(video)
64+
if eval_with_no_grad:
65+
self.image_net.eval()
66+
67+
context = torch.no_grad if eval_with_no_grad else nullcontext
68+
69+
with context():
70+
outputs = func(video)
6071

6172
# handle multiple outputs, e.g. logits and embeddings returned from an extractor - also handle a reduced auxiliary loss being returned
6273

@@ -111,7 +122,7 @@ def forward(
111122

112123
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024)
113124

114-
logits, embeddings = video_acceptor(videos) # always (batch, channels, time, height, width) - time is always dimension 2
125+
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
115126

116127
assert logits.shape == (1, 10, 1000)
117128
assert embeddings.shape == (1, 10, 65, 1024)

0 commit comments

Comments
 (0)