     ModelOutput,
 )
 from ..model_utils import PretrainedModel, register_base_model
+from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ..tensor_parallel_utils import model_parallel_dropout
 from .configuration import PaddleOCRVisionConfig, PaddleOCRVLConfig
 
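The new import wires this file to the shared rope utilities: `ROPE_INIT_FUNCTIONS`, a mapping from rope-type names to initializer functions consumed in the constructor further down this diff, and `dynamic_rope_update`, a decorator applied to `forward` in the last hunk. A toy sketch of the decorator shape this usage implies; the real behavior lives in `..modeling_rope_utils`, so everything below is an assumption, not that module's code:

```python
import functools

# Hedged sketch: a dynamic_rope_update-style decorator wraps forward() so that
# inv_freq can be refreshed before cos/sin are computed. This toy version
# performs no refresh; it only shows where such logic would hook in.
def toy_dynamic_rope_update(forward_fn):
    @functools.wraps(forward_fn)
    def wrapper(self, x, position_ids):
        # A real implementation would re-derive self.inv_freq here when the
        # configured rope type requires it (e.g. dynamic scaling variants).
        return forward_fn(self, x, position_ids)

    return wrapper
```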
@@ -100,6 +101,9 @@ def apply_rotary_pos_emb_vision(q, k, cos, sin):
 
 
 def apply_fused_rope(query_states, key_states, rope_theta):
+    # b h l d -> b l h d
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
     _, _, num_heads, _ = query_states.shape
     _, kv_seq_len, num_key_value_heads, _ = key_states.shape
     if num_heads != num_key_value_heads:
@@ -112,7 +116,7 @@ def apply_fused_rope(query_states, key_states, rope_theta):
         None,
         rotary_emb_base=rope_theta,
     )
-    return query_states, key_states
+    return query_states.transpose(1, 2), key_states.transpose(1, 2)
 
 
 def inbatch_pack_offset_to_attn_mask_start_row_indices(inbatch_pack_offset):
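The two hunks above change `apply_fused_rope` to accept and return `[batch, heads, seq, dim]` tensors while the fused kernel itself consumes `[batch, seq, heads, dim]`: axes 1 and 2 are swapped on entry and swapped back on return, so callers keep the `[b, h, l, d]` interface unchanged. A minimal shape-only sketch of that round-trip, not part of the diff; the shapes are illustrative and `paddle.transpose` with an explicit `perm` is used here for clarity:

```python
import paddle

# Illustrative shapes only: batch=2, heads=8, seq_len=16, head_dim=64.
q = paddle.randn([2, 8, 16, 64])  # [b, h, l, d], as produced by the attention layer

# Swap the head and sequence axes so the fused kernel sees [b, l, h, d] ...
q_blhd = paddle.transpose(q, perm=[0, 2, 1, 3])
assert q_blhd.shape == [2, 16, 8, 64]

# ... the fused rope kernel would run here ...

# ... then swap back to [b, h, l, d] on return, which is what the new
# `return query_states.transpose(1, 2), key_states.transpose(1, 2)` line does.
q_back = paddle.transpose(q_blhd, perm=[0, 2, 1, 3])
assert q_back.shape == q.shape
```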
@@ -1056,28 +1060,46 @@ def forward(self, image_features, image_grid_thw):
 class KeyeRotaryEmbedding(nn.Layer):
     def __init__(self, config: PaddleOCRVLConfig):
         super().__init__()
-        self.rope_kwargs = {}
-
-        # # BC: "rope_type" was originally "type"
-        # if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-        #     self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        # else:
-        #     self.rope_type = "default"
-        rope_parameters = config.rope_parameters
-        self.rope_type = rope_parameters.get("rope_type", rope_parameters.get("type", "default"))
+        self.config = config
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
 
-        if self.rope_type == "default":
-            dim = config.head_dim
-            inv_freq = 1.0 / (config.rope_theta ** (paddle.arange(0, dim, 2, dtype="int64").astype("float32") / dim))
-            self.attention_scaling = 1.0
-        else:
-            raise ValueError(f"Unsupported rope type: {self.rope_type}")
+        rope_parameters = self.config.rope_parameters
+        self.rope_type = rope_parameters.get("rope_type", rope_parameters.get("type", "default"))
+        rope_init_fn = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config)
 
         self.register_buffer("inv_freq", inv_freq, persistable=False)
-        self.original_inv_freq = self.inv_freq
+        self.original_inv_freq = inv_freq
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[PaddleOCRVLConfig] = None,
+        seq_len: Optional[int] = None,
+    ) -> tuple["paddle.Tensor", float]:
+        """
+        Computes the inverse frequencies according to the original RoPE implementation.
+        Args:
+            config ([`PaddleOCRVLConfig`]):
+                The model configuration.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`paddle.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+        """
+        base = config.rope_parameters["rope_theta"]
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+        attention_factor = 1.0  # Unused in this type of RoPE
+
+        # Compute the inverse frequencies
+        inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype=paddle.int64).astype(dtype=paddle.float32) / dim))
+        return inv_freq, attention_factor
 
+    @dynamic_rope_update
     @paddle.no_grad()
     def forward(self, x, position_ids):
         # Core RoPE block. In contrast to other models, Keye has different position ids for the grids
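Taken together, the constructor now resolves its initializer from `rope_type`, falling back to the `ROPE_INIT_FUNCTIONS` registry for non-default types, while `compute_default_rope_parameters` reproduces the standard RoPE frequency schedule `inv_freq[i] = rope_theta ** (-2i / dim)`. A self-contained sketch of both steps, assuming hypothetical values (`rope_theta=10000.0`, `dim=128`) and a toy registry standing in for the imported `ROPE_INIT_FUNCTIONS`:

```python
import paddle


def default_rope_init(base: float, dim: int):
    """Standard RoPE schedule: inv_freq[i] = base ** (-2i / dim), i = 0 .. dim//2 - 1."""
    exponents = paddle.arange(0, dim, 2, dtype=paddle.int64).astype(paddle.float32) / dim
    inv_freq = 1.0 / (base**exponents)
    return inv_freq, 1.0  # attention_scaling stays 1.0 for the default rope type


# Toy registry standing in for ROPE_INIT_FUNCTIONS; keys and entries are illustrative.
TOY_ROPE_INIT = {"linear": default_rope_init, "dynamic": default_rope_init}

rope_type = "default"  # hypothetical value read from config.rope_parameters
rope_init_fn = default_rope_init if rope_type == "default" else TOY_ROPE_INIT[rope_type]

inv_freq, attention_scaling = rope_init_fn(base=10000.0, dim=128)
print(inv_freq.shape)       # [64]: one frequency per even channel index
print(float(inv_freq[0]))   # 1.0 for channel 0
print(float(inv_freq[-1]))  # ~1.2e-4, approaching 1 / rope_theta
```

For the default type the scaling factor is 1.0; registry-provided variants would return a different `attention_scaling`, which the class applies to the computed cos/sin, and the `@dynamic_rope_update` decorator on `forward` is presumably what re-derives `inv_freq` at call time for rope types that need it.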