@@ -1,88 +1,103 @@
 import transformers
-
-from retake.qwen2_vl import (
-    retake_Qwen2VLAttention_forward,
-    retake_Qwen2VLSdpaAttention_forward,
-    retake_Qwen2VLFlashAttention2_forward,
-    retake_Qwen2VLForConditionalGeneration_compress_video_tokens,
-    retake_Qwen2VLForConditionalGeneration_segment_input_ids,
-    retake_Qwen2VLForConditionalGeneration_get_chunk_size,
-    retake_Qwen2VLForConditionalGeneration_forge_input_chunks,
-    retake_Qwen2VLForConditionalGeneration_forward,
+from retake.llava_onevision import (
+    retake_LlavaOnevisionForConditionalGeneration_compress_video_tokens,
+    retake_LlavaOnevisionForConditionalGeneration_forge_input_chunks,
+    retake_LlavaOnevisionForConditionalGeneration_forward,
+    retake_LlavaOnevisionForConditionalGeneration_get_chunk_size,
+    retake_LlavaOnevisionForConditionalGeneration_segment_input_ids,
+    retake_Qwen2Attention_forward,
+    retake_Qwen2Attention_init,
 )
 from retake.qwen2_5_vl import (
     fixed_Qwen2_5_VLModel_prepare_4d_causal_attention_mask_with_cache_position,
     retake_Qwen2_5_VLAttention_forward,
-    retake_Qwen2_5_VLSdpaAttention_forward,
     retake_Qwen2_5_VLFlashAttention2_forward,
-    retake_Qwen2_5_VLForConditionalGeneration_segment_input_ids,
-    retake_Qwen2_5_VLForConditionalGeneration_get_chunk_size,
     retake_Qwen2_5_VLForConditionalGeneration_forge_input_chunks,
     retake_Qwen2_5_VLForConditionalGeneration_forward,
+    retake_Qwen2_5_VLForConditionalGeneration_get_chunk_size,
+    retake_Qwen2_5_VLForConditionalGeneration_segment_input_ids,
+    retake_Qwen2_5_VLSdpaAttention_forward,
 )
-from retake.llava_onevision import (
-    retake_Qwen2Attention_init,
-    retake_Qwen2Attention_forward,
-    retake_LlavaOnevisionForConditionalGeneration_get_chunk_size,
-    retake_LlavaOnevisionForConditionalGeneration_segment_input_ids,
-    retake_LlavaOnevisionForConditionalGeneration_compress_video_tokens,
-    retake_LlavaOnevisionForConditionalGeneration_forge_input_chunks,
-    retake_LlavaOnevisionForConditionalGeneration_forward,
+from retake.qwen2_vl import (
+    retake_Qwen2VLAttention_forward,
+    retake_Qwen2VLFlashAttention2_forward,
+    retake_Qwen2VLForConditionalGeneration_compress_video_tokens,
+    retake_Qwen2VLForConditionalGeneration_forge_input_chunks,
+    retake_Qwen2VLForConditionalGeneration_forward,
+    retake_Qwen2VLForConditionalGeneration_get_chunk_size,
+    retake_Qwen2VLForConditionalGeneration_segment_input_ids,
+    retake_Qwen2VLSdpaAttention_forward,
 )
 
 
 def patch_qwen2vl_config(config, exp_configs):
     # Rope Scaling
-    if 'scaling_factor' in exp_configs:
-        config.rope_scaling.pop('type')
-        config.rope_scaling['rope_type'] = 'yarn'
-        config.rope_scaling['factor'] = exp_configs['scaling_factor']
-        config.rope_scaling['beta_fast'] = 32.0
-        config.rope_scaling['beta_slow'] = 1.0
+    if "scaling_factor" in exp_configs:
+        config.rope_scaling.pop("type")
+        config.rope_scaling["rope_type"] = "yarn"
+        config.rope_scaling["factor"] = exp_configs["scaling_factor"]
+        config.rope_scaling["beta_fast"] = 32.0
+        config.rope_scaling["beta_slow"] = 1.0
     # ReTaKe
-    config.longvideo_kwargs = exp_configs.get('longvideo_kwargs', {})
+    config.longvideo_kwargs = exp_configs.get("longvideo_kwargs", {})
     return config
 
 
 def patch_qwen2_5_vl_config(config, exp_configs):
     # Rope Scaling
-    if 'scaling_factor' in exp_configs:
-        config.rope_scaling.pop('type')
-        config.rope_scaling['rope_type'] = 'yarn'
-        config.rope_scaling['factor'] = exp_configs['scaling_factor']
-        config.rope_scaling['beta_fast'] = 32.0
-        config.rope_scaling['beta_slow'] = 1.0
+    if "scaling_factor" in exp_configs:
+        config.rope_scaling.pop("type")
+        config.rope_scaling["rope_type"] = "yarn"
+        config.rope_scaling["factor"] = exp_configs["scaling_factor"]
+        config.rope_scaling["beta_fast"] = 32.0
+        config.rope_scaling["beta_slow"] = 1.0
     # ReTaKe
-    config.longvideo_kwargs = exp_configs.get('longvideo_kwargs', {})
+    config.longvideo_kwargs = exp_configs.get("longvideo_kwargs", {})
     return config
 
 
 def patch_llava_onevision_config(config, exp_configs):
     # Rope Scaling
-    if 'scaling_factor' in exp_configs:
+    if "scaling_factor" in exp_configs:
         config.text_config.rope_scaling = {
-            'rope_type': 'yarn',
-            'factor': exp_configs['scaling_factor'],
-            'beta_fast': 32.0,
-            'beta_slow': 1.0,
+            "rope_type": "yarn",
+            "factor": exp_configs["scaling_factor"],
+            "beta_fast": 32.0,
+            "beta_slow": 1.0,
         }
     # ReTaKe
-    config.longvideo_kwargs = exp_configs.get('longvideo_kwargs', {})
+    config.longvideo_kwargs = exp_configs.get("longvideo_kwargs", {})
     return config
 
 
 def patch_qwen2vl(method):
 
     if method == "retake":
         print("Using ReTaKe for Qwen2VLForConditionalGeneration!")
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLAttention.forward = retake_Qwen2VLAttention_forward
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLSdpaAttention.forward = retake_Qwen2VLSdpaAttention_forward
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLFlashAttention2.forward = retake_Qwen2VLFlashAttention2_forward
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.compress_video_tokens = retake_Qwen2VLForConditionalGeneration_compress_video_tokens
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.segment_input_ids = retake_Qwen2VLForConditionalGeneration_segment_input_ids
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.get_chunk_size = retake_Qwen2VLForConditionalGeneration_get_chunk_size
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forge_input_chunks = retake_Qwen2VLForConditionalGeneration_forge_input_chunks
-        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = retake_Qwen2VLForConditionalGeneration_forward
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLAttention.forward = (
+            retake_Qwen2VLAttention_forward
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLSdpaAttention.forward = (
+            retake_Qwen2VLSdpaAttention_forward
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLFlashAttention2.forward = (
+            retake_Qwen2VLFlashAttention2_forward
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.compress_video_tokens = (
+            retake_Qwen2VLForConditionalGeneration_compress_video_tokens
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.segment_input_ids = (
+            retake_Qwen2VLForConditionalGeneration_segment_input_ids
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.get_chunk_size = (
+            retake_Qwen2VLForConditionalGeneration_get_chunk_size
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forge_input_chunks = (
+            retake_Qwen2VLForConditionalGeneration_forge_input_chunks
+        )
+        transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = (
+            retake_Qwen2VLForConditionalGeneration_forward
+        )
     else:
         raise NotImplementedError
 
@@ -91,14 +106,30 @@ def patch_qwen2_5_vl(method):
 
     if method == "retake":
         print("Using ReTaKe for Qwen2_5_VLForConditionalGeneration!")
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLModel._prepare_4d_causal_attention_mask_with_cache_position = fixed_Qwen2_5_VLModel_prepare_4d_causal_attention_mask_with_cache_position
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLAttention.forward = retake_Qwen2_5_VLAttention_forward
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLSdpaAttention.forward = retake_Qwen2_5_VLSdpaAttention_forward
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLFlashAttention2.forward = retake_Qwen2_5_VLFlashAttention2_forward
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.segment_input_ids = retake_Qwen2_5_VLForConditionalGeneration_segment_input_ids
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.get_chunk_size = retake_Qwen2_5_VLForConditionalGeneration_get_chunk_size
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forge_input_chunks = retake_Qwen2_5_VLForConditionalGeneration_forge_input_chunks
-        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = retake_Qwen2_5_VLForConditionalGeneration_forward
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLModel._prepare_4d_causal_attention_mask_with_cache_position = (
+            fixed_Qwen2_5_VLModel_prepare_4d_causal_attention_mask_with_cache_position
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLAttention.forward = (
+            retake_Qwen2_5_VLAttention_forward
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLSdpaAttention.forward = (
+            retake_Qwen2_5_VLSdpaAttention_forward
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLFlashAttention2.forward = (
+            retake_Qwen2_5_VLFlashAttention2_forward
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.segment_input_ids = (
+            retake_Qwen2_5_VLForConditionalGeneration_segment_input_ids
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.get_chunk_size = (
+            retake_Qwen2_5_VLForConditionalGeneration_get_chunk_size
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forge_input_chunks = (
+            retake_Qwen2_5_VLForConditionalGeneration_forge_input_chunks
+        )
+        transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = (
+            retake_Qwen2_5_VLForConditionalGeneration_forward
+        )
     else:
         raise NotImplementedError
 
@@ -107,12 +138,26 @@ def patch_llava_onevision(method):
 
     if method == "retake":
         print("Using ReTaKe for LlavaOnevisionForConditionalGeneration!")
-        transformers.models.qwen2.modeling_qwen2.Qwen2Attention.__init__ = retake_Qwen2Attention_init
-        transformers.models.qwen2.modeling_qwen2.Qwen2Attention.forward = retake_Qwen2Attention_forward
-        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.get_chunk_size = retake_LlavaOnevisionForConditionalGeneration_get_chunk_size
-        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.segment_input_ids = retake_LlavaOnevisionForConditionalGeneration_segment_input_ids
-        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.compress_video_tokens = retake_LlavaOnevisionForConditionalGeneration_compress_video_tokens
-        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.forge_input_chunks = retake_LlavaOnevisionForConditionalGeneration_forge_input_chunks
-        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.forward = retake_LlavaOnevisionForConditionalGeneration_forward
+        transformers.models.qwen2.modeling_qwen2.Qwen2Attention.__init__ = (
+            retake_Qwen2Attention_init
+        )
+        transformers.models.qwen2.modeling_qwen2.Qwen2Attention.forward = (
+            retake_Qwen2Attention_forward
+        )
+        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.get_chunk_size = (
+            retake_LlavaOnevisionForConditionalGeneration_get_chunk_size
+        )
+        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.segment_input_ids = (
+            retake_LlavaOnevisionForConditionalGeneration_segment_input_ids
+        )
+        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.compress_video_tokens = (
+            retake_LlavaOnevisionForConditionalGeneration_compress_video_tokens
+        )
+        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.forge_input_chunks = (
+            retake_LlavaOnevisionForConditionalGeneration_forge_input_chunks
+        )
+        transformers.models.llava_onevision.modeling_llava_onevision.LlavaOnevisionForConditionalGeneration.forward = (
+            retake_LlavaOnevisionForConditionalGeneration_forward
+        )
     else:
         raise NotImplementedError
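
For reference, a minimal usage sketch of how these helpers fit together (the module name `patch`, the model id, and the example config values below are assumptions, not part of this commit): the attention and forward methods on the transformers classes are monkey-patched first, then the config is adjusted before the model is instantiated so the patched code can read `config.longvideo_kwargs` and the YaRN rope-scaling settings.

import transformers

from patch import patch_qwen2_5_vl, patch_qwen2_5_vl_config  # module name assumed

exp_configs = {
    "scaling_factor": 4.0,   # example value; switches rope_scaling to YaRN
    "longvideo_kwargs": {},  # example value; ReTaKe long-video options
}

# Replace the attention/forward methods on the transformers classes in place.
patch_qwen2_5_vl("retake")

# Adjust the config, then load the model with it.
config = transformers.AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
config = patch_qwen2_5_vl_config(config, exp_configs)
model = transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct", config=config
)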