    SequenceClassifierOutputWithPast,
)
from ..model_utils import PretrainedModel, register_base_model
-from ..modeling_rope_utils import dynamic_rope_update
+from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ..moe_gate import PretrainedMoEGate
from ..moe_layer import MoEFlexTokenLayer
from .configuration import DeepseekV3Config
@@ -137,81 +137,6 @@ def yarn_get_mscale(scale, mscale=1):
    return 0.1 * mscale * math.log(scale) + 1.0


-def _compute_yarn_parameters(
-    config,
-    seq_len=None,
-):
-    base = config["rope_theta"]
-    rope_parameters_dict = config["rope_parameters"]
-    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
-    dim = int(head_dim * partial_rotary_factor)
-
-    factor = rope_parameters_dict["factor"]
-    attention_factor = rope_parameters_dict.get("attention_factor", None)
-    mscale = rope_parameters_dict.get("mscale")
-    mscale_all_dim = rope_parameters_dict.get("mscale_all_dim")
-
-    # NOTE: DeepSeek-V3 (and potentially other models) modifies `max_position_embeddings` and keeps an
-    # `original_max_position_embeddings` field containing the pretrained value. The ratio between these two
-    # values is used to compute the default attention scaling factor, instead of using `factor`.
-    if "original_max_position_embeddings" in rope_parameters_dict:
-        original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
-        factor = config.max_position_embeddings / original_max_position_embeddings
-    else:
-        original_max_position_embeddings = config.max_position_embeddings
-
-    # Set the attention factor as suggested in the YaRN paper
-    if attention_factor is None:
-        if mscale and mscale_all_dim:
-            attention_factor = float(yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim))
-        else:
-            attention_factor = yarn_get_mscale(factor)
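-    # e.g. factor=4 with the default mscale gives yarn_get_mscale(4) = 0.1 * ln(4) + 1 ≈ 1.139;
-    # the cos/sin produced by the rotary embedding are later scaled by this factor.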
-
-    # Optional config options
-    # beta_fast/beta_slow: as suggested in the paper, default to 32 and 1 respectively
-    beta_fast = rope_parameters_dict.get("beta_fast") or 32
-    beta_slow = rope_parameters_dict.get("beta_slow") or 1
-
-    # Compute the inverse frequencies
-    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
-        """Inverse dimension formula to find the dimension based on the number of rotations"""
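-        # A rotary dimension pair d completes max_position_embeddings / (2*pi*base^(2d/dim))
-        # rotations over the full context; solving that count for d gives the expression below.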
-        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
-
-    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
-        """Find dimension range bounds based on rotations"""
-        low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
-        high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
-        if truncate:
-            low = math.floor(low)
-            high = math.ceil(high)
-        return max(low, 0), min(high, dim - 1)
-
-    def linear_ramp_factor(min, max, dim):
-        if min == max:
-            max += 0.001  # Prevent singularity
-
-        linear_func = (paddle.arange(dim, dtype=paddle.float32) - min) / (max - min)
-        ramp_func = paddle.clip(linear_func, 0, 1)
-        return ramp_func
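-    # The ramp is 0 for indices <= min and 1 for indices >= max, linear in between;
-    # it drives the per-dimension blend of extrapolated and interpolated frequencies.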
-
-    pos_freqs = base ** (paddle.arange(0, dim, 2).astype(paddle.float32) / dim)
-    inv_freq_extrapolation = 1.0 / pos_freqs
-    inv_freq_interpolation = 1.0 / (factor * pos_freqs)
-
-    # truncate = config.rope_parameters.get("truncate", True)
-    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, True)
-
-    # Get n-dimensional rotational scaling corrected for extrapolation
-    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).astype(paddle.float32)
-
-    inv_freq = (
-        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
-        + inv_freq_extrapolation * inv_freq_extrapolation_factor
-    )
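-    # Fast-rotating dims (index < low) keep their pretrained frequencies (pure extrapolation),
-    # slow-rotating dims (index > high) are divided by `factor` (pure interpolation),
-    # and the ramp blends the two in between, following the YaRN paper.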
-    return inv_freq, attention_factor
-
-
class DeepseekV3YarnRotaryEmbedding(nn.Layer):
    def __init__(self, config: DeepseekV3Config, device=None):
        super().__init__()
@@ -221,11 +146,38 @@ def __init__(self, config: DeepseekV3Config, device=None):

        rope_parameters = self.config.rope_parameters
        self.rope_type = rope_parameters.get("rope_type", rope_parameters.get("type", "default"))
-        assert self.rope_type == "yarn"
+        rope_init_fn = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
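+        # ROPE_INIT_FUNCTIONS presumably maps e.g. "yarn" to the shared initializer that
+        # replaces the module-local _compute_yarn_parameters removed above.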
+        inv_freq, self.attention_scaling = rope_init_fn(self.config)
+
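+        # Keep an untouched copy so @dynamic_rope_update can restore the pretrained
+        # frequencies once the running sequence length shrinks again (assumption: the
+        # decorator follows the usual dynamic-RoPE reset pattern).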
+        self.register_buffer("inv_freq", inv_freq, persistable=False)
+        self.original_inv_freq = inv_freq
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[DeepseekV3Config] = None,
+        seq_len: Optional[int] = None,
+    ) -> tuple["paddle.Tensor", float]:
+        """
+        Computes the inverse frequencies according to the original RoPE implementation.
+        Args:
+            config ([`DeepseekV3Config`]):
+                The model configuration.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`paddle.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+        """
+        base = config.rope_parameters["rope_theta"]
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+        attention_factor = 1.0  # Unused in this type of RoPE

-        self.inv_freq, self.attention_scaling = _compute_yarn_parameters(config)
-        self.register_buffer("inv_freq", self.inv_freq, persistable=False)
-        # self.original_inv_freq = self.inv_freq
+        # Compute the inverse frequencies
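+        # inv_freq[i] = base ** (-2i / dim), i.e. the theta_i = 10000^(-2i/d) schedule of the original RoPE paper.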
+        inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype=paddle.int64).astype(paddle.float32) / dim))
+        return inv_freq, attention_factor

    @dynamic_rope_update
    def forward(self, x, position_ids):
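A minimal sketch of the refactored "default" path, runnable in this module's namespace. The SimpleNamespace stand-in below is hypothetical and carries only the fields compute_default_rope_parameters actually reads; a real caller would pass a DeepseekV3Config:

from types import SimpleNamespace

# Hypothetical minimal config: no "rope_type" key, so __init__ would take the "default" branch.
config = SimpleNamespace(
    rope_parameters={"rope_theta": 10000.0},
    hidden_size=64,
    num_attention_heads=4,
)

# head_dim is absent, so dim = 64 // 4 = 16 and inv_freq holds 16 // 2 = 8 frequencies.
inv_freq, scaling = DeepseekV3YarnRotaryEmbedding.compute_default_rope_parameters(config)
print(inv_freq.shape)  # [8]
print(scaling)         # 1.0 -- default RoPE applies no cos/sin scaling

With rope_parameters={"rope_type": "yarn", ...}, the constructor would instead dispatch through ROPE_INIT_FUNCTIONS["yarn"].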