1+ from contextlib import nullcontext
2+
13import torch
24from torch import is_tensor , randn
35from torch .nn import Module , Parameter
@@ -40,7 +42,8 @@ def __init__(
4042
4143 def forward (
4244 self ,
43- video # (b c t h w)
45+ video , # (b c t h w)
46+ eval_with_no_grad = False
4447 ):
4548 add_time_pos_emb = self .add_time_pos_emb
4649 time = video .shape [2 ]
@@ -54,9 +57,17 @@ def forward(
5457
5558 video = rearrange (video , 'b t ... -> (b t) ...' )
5659
60+ # forward through image net for outputs
61+
5762 func = getattr (self .image_net , self .forward_function )
5863
59- outputs = func (video )
64+ if eval_with_no_grad :
65+ self .image_net .eval ()
66+
67+ context = torch .no_grad if eval_with_no_grad else nullcontext
68+
69+ with context ():
70+ outputs = func (video )
6071
6172 # handle multiple outputs, e.g. logits and embeddings returned from the extractor - also handle a reduced aux loss being returned
6273
@@ -111,7 +122,7 @@ def forward(
111122
112123 video_acceptor = AcceptVideoWrapper (v , add_time_pos_emb = True , output_pos_add_pos_emb = 1 , time_seq_len = 10 , dim_emb = 1024 )
113124
114- logits , embeddings = video_acceptor (videos ) # always (batch, channels, time, height, width) - time is always dimension 2
125+ logits , embeddings = video_acceptor (videos , eval_with_no_grad = True ) # always (batch, channels, time, height, width) - time is always dimension 2
115126
116127 assert logits .shape == (1 , 10 , 1000 )
117128 assert embeddings .shape == (1 , 10 , 65 , 1024 )
0 commit comments