
Commit e563a54: "correct ci" (1 parent: 9a255ac)

19 files changed: +2487 additions, -1377 deletions

Files with large diffs (not rendered inline):

ucm/sandbox/sparse/retake/demo.py: 123 additions, 73 deletions
ucm/sandbox/sparse/retake/retake/dataset_utils.py: 218 additions, 155 deletions
ucm/sandbox/sparse/retake/retake/infer_eval.py: 194 additions, 122 deletions
ucm/sandbox/sparse/retake/retake/llava_onevision.py: 265 additions, 131 deletions
ucm/sandbox/sparse/retake/retake/longvideo_cache.py: 471 additions, 230 deletions
One file's diff (109 additions, 64 deletions) is rendered in full below:
@@ -1,88 +1,103 @@
 import transformers
-
-from retake.qwen2_vl import (
-    retake_Qwen2VLAttention_forward,
-    retake_Qwen2VLSdpaAttention_forward,
-    retake_Qwen2VLFlashAttention2_forward,
-    retake_Qwen2VLForConditionalGeneration_compress_video_tokens,
-    retake_Qwen2VLForConditionalGeneration_segment_input_ids,
-    retake_Qwen2VLForConditionalGeneration_get_chunk_size,
-    retake_Qwen2VLForConditionalGeneration_forge_input_chunks,
-    retake_Qwen2VLForConditionalGeneration_forward,
+from retake.llava_onevision import (
+    retake_LlavaOnevisionForConditionalGeneration_compress_video_tokens,
+    retake_LlavaOnevisionForConditionalGeneration_forge_input_chunks,
+    retake_LlavaOnevisionForConditionalGeneration_forward,
+    retake_LlavaOnevisionForConditionalGeneration_get_chunk_size,
+    retake_LlavaOnevisionForConditionalGeneration_segment_input_ids,
+    retake_Qwen2Attention_forward,
+    retake_Qwen2Attention_init,
 )
 from retake.qwen2_5_vl import (
     fixed_Qwen2_5_VLModel_prepare_4d_causal_attention_mask_with_cache_position,
     retake_Qwen2_5_VLAttention_forward,
-    retake_Qwen2_5_VLSdpaAttention_forward,
     retake_Qwen2_5_VLFlashAttention2_forward,
-    retake_Qwen2_5_VLForConditionalGeneration_segment_input_ids,
-    retake_Qwen2_5_VLForConditionalGeneration_get_chunk_size,
     retake_Qwen2_5_VLForConditionalGeneration_forge_input_chunks,
     retake_Qwen2_5_VLForConditionalGeneration_forward,
+    retake_Qwen2_5_VLForConditionalGeneration_get_chunk_size,
+    retake_Qwen2_5_VLForConditionalGeneration_segment_input_ids,
+    retake_Qwen2_5_VLSdpaAttention_forward,
 )
-from retake.llava_onevision import (
-    retake_Qwen2Attention_init,
-    retake_Qwen2Attention_forward,
-    retake_LlavaOnevisionForConditionalGeneration_get_chunk_size,
-    retake_LlavaOnevisionForConditionalGeneration_segment_input_ids,
-    retake_LlavaOnevisionForConditionalGeneration_compress_video_tokens,
-    retake_LlavaOnevisionForConditionalGeneration_forge_input_chunks,
-    retake_LlavaOnevisionForConditionalGeneration_forward,
+from retake.qwen2_vl import (
+    retake_Qwen2VLAttention_forward,
+    retake_Qwen2VLFlashAttention2_forward,
+    retake_Qwen2VLForConditionalGeneration_compress_video_tokens,
+    retake_Qwen2VLForConditionalGeneration_forge_input_chunks,
+    retake_Qwen2VLForConditionalGeneration_forward,
+    retake_Qwen2VLForConditionalGeneration_get_chunk_size,
+    retake_Qwen2VLForConditionalGeneration_segment_input_ids,
+    retake_Qwen2VLSdpaAttention_forward,
 )
 
 
 def patch_qwen2vl_config(config, exp_configs):
     # Rope Scaling
-    if 'scaling_factor' in exp_configs:
-        config.rope_scaling.pop('type')
-        config.rope_scaling['rope_type'] = 'yarn'
-        config.rope_scaling['factor'] = exp_configs['scaling_factor']
-        config.rope_scaling['beta_fast'] = 32.0
-        config.rope_scaling['beta_slow'] = 1.0
+    if "scaling_factor" in exp_configs:
+        config.rope_scaling.pop("type")
+        config.rope_scaling["rope_type"] = "yarn"
+        config.rope_scaling["factor"] = exp_configs["scaling_factor"]
+        config.rope_scaling["beta_fast"] = 32.0
+        config.rope_scaling["beta_slow"] = 1.0
     # ReTaKe
-    config.longvideo_kwargs = exp_configs.get('longvideo_kwargs', {})
+    config.longvideo_kwargs = exp_configs.get("longvideo_kwargs", {})
     return config
 
 
 def patch_qwen2_5_vl_config(config, exp_configs):
     # Rope Scaling
-    if 'scaling_factor' in exp_configs:
-        config.rope_scaling.pop('type')
-        config.rope_scaling['rope_type'] = 'yarn'
-        config.rope_scaling['factor'] = exp_configs['scaling_factor']
-        config.rope_scaling['beta_fast'] = 32.0
-        config.rope_scaling['beta_slow'] = 1.0
+    if "scaling_factor" in exp_configs:
+        config.rope_scaling.pop("type")
+        config.rope_scaling["rope_type"] = "yarn"
+        config.rope_scaling["factor"] = exp_configs["scaling_factor"]
+        config.rope_scaling["beta_fast"] = 32.0
+        config.rope_scaling["beta_slow"] = 1.0
     # ReTaKe
-    config.longvideo_kwargs = exp_configs.get('longvideo_kwargs', {})
+    config.longvideo_kwargs = exp_configs.get("longvideo_kwargs", {})
     return config
 
 
 def patch_llava_onevision_config(config, exp_configs):
     # Rope Scaling
-    if 'scaling_factor' in exp_configs:
+    if "scaling_factor" in exp_configs:
         config.text_config.rope_scaling = {
-            'rope_type': 'yarn',
-            'factor': exp_configs['scaling_factor'],
-            'beta_fast': 32.0,
-            'beta_slow': 1.0,
+            "rope_type": "yarn",
+            "factor": exp_configs["scaling_factor"],
+            "beta_fast": 32.0,
+            "beta_slow": 1.0,
         }
     # ReTaKe
-    config.longvideo_kwargs = exp_configs.get('longvideo_kwargs', {})
+    config.longvideo_kwargs = exp_configs.get("longvideo_kwargs", {})
     return config
 
 
 def patch_qwen2vl(method):
 
     if method == "retake":
         print("Using ReTaKe for Qwen2VLForConditionalGeneration!")
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLAttention.forward = retake_Qwen2VLAttention_forward
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLSdpaAttention.forward = retake_Qwen2VLSdpaAttention_forward
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLFlashAttention2.forward = retake_Qwen2VLFlashAttention2_forward
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.compress_video_tokens = retake_Qwen2VLForConditionalGeneration_compress_video_tokens
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.segment_input_ids = retake_Qwen2VLForConditionalGeneration_segment_input_ids
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.get_chunk_size = retake_Qwen2VLForConditionalGeneration_get_chunk_size
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forge_input_chunks = retake_Qwen2VLForConditionalGeneration_forge_input_chunks
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = retake_Qwen2VLForConditionalGeneration_forward
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLAttention.forward = (
+            retake_Qwen2VLAttention_forward
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLSdpaAttention.forward = (
+            retake_Qwen2VLSdpaAttention_forward
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLFlashAttention2.forward = (
+            retake_Qwen2VLFlashAttention2_forward
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.compress_video_tokens = (
+            retake_Qwen2VLForConditionalGeneration_compress_video_tokens
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.segment_input_ids = (
+            retake_Qwen2VLForConditionalGeneration_segment_input_ids
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.get_chunk_size = (
+            retake_Qwen2VLForConditionalGeneration_get_chunk_size
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forge_input_chunks = (
+            retake_Qwen2VLForConditionalGeneration_forge_input_chunks
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = (
+            retake_Qwen2VLForConditionalGeneration_forward
+        )
     else:
         raise NotImplementedError
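For context: the three patch_*_config helpers above switch the model's rope_scaling to YaRN when a scaling_factor is supplied, and always attach longvideo_kwargs to the config. A minimal worked example of the Qwen2-VL variant follows; the starting rope_scaling dict and the longvideo_kwargs content are illustrative assumptions, not values from this commit:

    from types import SimpleNamespace

    # Assumed starting point: a config whose rope_scaling is a plain dict
    # with a "type" key, as Qwen2-VL checkpoints typically ship.
    config = SimpleNamespace(
        rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]}
    )
    # Placeholder values; "chunk_size" is a hypothetical ReTaKe option.
    exp_configs = {"scaling_factor": 4.0, "longvideo_kwargs": {"chunk_size": 512}}

    config = patch_qwen2vl_config(config, exp_configs)

    # "type" has been popped and the YaRN fields added in its place.
    assert config.rope_scaling == {
        "mrope_section": [16, 24, 24],
        "rope_type": "yarn",
        "factor": 4.0,
        "beta_fast": 32.0,
        "beta_slow": 1.0,
    }
    assert config.longvideo_kwargs == {"chunk_size": 512}  # passed through verbatim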

@@ -91,14 +106,30 @@ def patch_qwen2_5_vl(method):
 
     if method == "retake":
         print("Using ReTaKe for Qwen2_5_VLForConditionalGeneration!")
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLModel._prepare_4d_causal_attention_mask_with_cache_position = fixed_Qwen2_5_VLModel_prepare_4d_causal_attention_mask_with_cache_position
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLAttention.forward = retake_Qwen2_5_VLAttention_forward
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLSdpaAttention.forward = retake_Qwen2_5_VLSdpaAttention_forward
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLFlashAttention2.forward = retake_Qwen2_5_VLFlashAttention2_forward
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.segment_input_ids = retake_Qwen2_5_VLForConditionalGeneration_segment_input_ids
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.get_chunk_size = retake_Qwen2_5_VLForConditionalGeneration_get_chunk_size
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forge_input_chunks = retake_Qwen2_5_VLForConditionalGeneration_forge_input_chunks
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = retake_Qwen2_5_VLForConditionalGeneration_forward
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLModel._prepare_4d_causal_attention_mask_with_cache_position = (
+            fixed_Qwen2_5_VLModel_prepare_4d_causal_attention_mask_with_cache_position
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLAttention.forward = (
+            retake_Qwen2_5_VLAttention_forward
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLSdpaAttention.forward = (
+            retake_Qwen2_5_VLSdpaAttention_forward
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLFlashAttention2.forward = (
+            retake_Qwen2_5_VLFlashAttention2_forward
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.segment_input_ids = (
+            retake_Qwen2_5_VLForConditionalGeneration_segment_input_ids
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.get_chunk_size = (
+            retake_Qwen2_5_VLForConditionalGeneration_get_chunk_size
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forge_input_chunks = (
+            retake_Qwen2_5_VLForConditionalGeneration_forge_input_chunks
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = (
+            retake_Qwen2_5_VLForConditionalGeneration_forward
+        )
     else:
         raise NotImplementedError
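The assignments in these patch functions rely on ordinary Python attribute mechanics: assigning a plain function to a class attribute rebinds that method for every instance, including instances created earlier. The one exception is an __init__ patch (used for LLaVA-OneVision below), which only affects construction and therefore must be installed before the model is built. A self-contained illustration with made-up names:

    class Attention:
        def forward(self, x):
            return x

    def retake_forward(self, x):
        # stand-in for a ReTaKe attention override
        return 2 * x

    layer = Attention()
    Attention.forward = retake_forward  # same pattern as the assignments above
    assert layer.forward(3) == 6        # existing instances see the patch too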

@@ -107,12 +138,26 @@ def patch_llava_onevision(method):
 
     if method == "retake":
         print("Using ReTaKe for LlavaOnevisionForConditionalGeneration!")
-        transformers.models.qwen2.modeling_qwen2.Qwen2Attention.__init__ = retake_Qwen2Attention_init
-        transformers.models.qwen2.modeling_qwen2.Qwen2Attention.forward = retake_Qwen2Attention_forward
-        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.get_chunk_size = retake_LlavaOnevisionForConditionalGeneration_get_chunk_size
-        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.segment_input_ids = retake_LlavaOnevisionForConditionalGeneration_segment_input_ids
-        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.compress_video_tokens = retake_LlavaOnevisionForConditionalGeneration_compress_video_tokens
-        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.forge_input_chunks = retake_LlavaOnevisionForConditionalGeneration_forge_input_chunks
-        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.forward = retake_LlavaOnevisionForConditionalGeneration_forward
+        transformers.models.qwen2.modeling_qwen2.Qwen2Attention.__init__ = (
+            retake_Qwen2Attention_init
+        )
+        transformers.models.qwen2.modeling_qwen2.Qwen2Attention.forward = (
+            retake_Qwen2Attention_forward
+        )
+        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.get_chunk_size = (
+            retake_LlavaOnevisionForConditionalGeneration_get_chunk_size
+        )
+        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.segment_input_ids = (
+            retake_LlavaOnevisionForConditionalGeneration_segment_input_ids
+        )
+        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.compress_video_tokens = (
+            retake_LlavaOnevisionForConditionalGeneration_compress_video_tokens
+        )
+        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.forge_input_chunks = (
+            retake_LlavaOnevisionForConditionalGeneration_forge_input_chunks
+        )
+        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.forward = (
+            retake_LlavaOnevisionForConditionalGeneration_forward
+        )
     else:
         raise NotImplementedError
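End to end, these entry points are meant to be called before the model is constructed. A minimal usage sketch follows; the module path retake.monkeypatch, the checkpoint name, and the exp_configs values are assumptions for illustration, since the patched file's module name does not appear in this rendered diff:

    from transformers import AutoConfig, Qwen2_5_VLForConditionalGeneration

    # Hypothetical import path: the diff does not show which module
    # defines these helpers.
    from retake.monkeypatch import patch_qwen2_5_vl, patch_qwen2_5_vl_config

    patch_qwen2_5_vl("retake")  # install the class-level patches first

    exp_configs = {"scaling_factor": 4.0, "longvideo_kwargs": {}}  # placeholders
    config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")  # placeholder checkpoint
    config = patch_qwen2_5_vl_config(config, exp_configs)

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct", config=config
    )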
