diff --git a/fms_extras/utils/generation.py b/fms_extras/utils/generation.py
index a31c125..6e6e218 100644
--- a/fms_extras/utils/generation.py
+++ b/fms_extras/utils/generation.py
@@ -357,24 +357,80 @@ def __extract_decode_output(
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]
-            the un-flattened next tokens per candidate per sequence, and the
-            un-flattened output embedding vector
+            the un-flattened logit scores per token per candidate per sequence,
+            and the un-flattened output embedding vectors
     """
     logits, _, embeds = model_output  # 1 n' v, 1 n' d OR bk 1+h v, bk 1+h d
-    next_vals = torch.argmax(logits, dim=-1)  # 1 n' OR bk 1+h
 
     # If we used batch flattening / tree attention, unflatten the outputs
     if unflat_indices is not None:
-        next_vals = apply_index_map(next_vals[0], unflat_indices)  # b k 1+h
+        logits = apply_index_map(logits[0], unflat_indices)  # b k 1+h v
         embeds = apply_index_map(embeds[0], unflat_indices)  # b k 1+h d
     else:
-        next_vals = next_vals.view(
-            batch_size, n_candidates, decode_seq_length
-        )  # b k 1+h
+        logits = logits.view(
+            batch_size, n_candidates, decode_seq_length, logits.size(2)
+        )  # b k 1+h v
         embeds = embeds.view(
             batch_size, n_candidates, decode_seq_length, embeds.size(2)
         )  # b k 1+h d
-    return next_vals, embeds
+    return logits, embeds
+
+
+def __generate_targets(
+    logits: torch.Tensor,
+    do_sample: torch.Tensor,
+    temperature: float = 1.0,
+    top_k: int = 5,
+) -> torch.Tensor:
+    """
+    Extracts ground-truth tokens from a set of logits. If performing greedy decoding,
+    simply returns the most confident tokens. Otherwise, implements consistent multinomial
+    sampling - two identical distributions will always produce the same (randomized) sample.
+    Thus by induction, two candidates with identical prefixes will receive the same ground
+    truth sample up to the point their inputs diverge. This allows us to ensure that at least
+    one candidate will be accepted, so long as the candidate set covers the top_k options.
+
+    For example, if the base model predicts tokens A and B with equal 50% probability, and the
+    speculator produces one candidate with A and another with B, with independent sampling there's
+    a 25% chance of rejecting both, even though one must be correct. Consistent sampling allows us
+    to avoid this.
+
+    Args:
+        logits: torch.Tensor
+            Probability logits for a set of candidate sequences. Expects size
+            bsize x n_candidates x seq_len x vocab_size
+        do_sample: torch.Tensor
+            A tensor of booleans enabling/disabling non-greedy decoding with consistent
+            sampling, for each of bsize input sequences
+        temperature: float
+            Degree of smoothing on softmax sampling distribution
+        top_k: int
+            Sample only among the top_k most confident tokens
+
+    Returns:
+        torch.Tensor
+            Tensor of chosen token values for each sequence
+    """
+
+    # Get sample distributions
+    logits = logits / temperature
+    v, _ = logits.topk(top_k)
+    logits[logits < v[:, :, :, [-1]]] = -float("inf")
+    probs = logits.softmax(-1)  # b k 1+h v
+
+    # Sample candidate-consistent ground truths: partition number line in [0,1]
+    # according to given multinomial distribution. Pick a random location
+    # on that line, return interval containing that location.
+    key = torch.rand(1, 1, logits.size(2), 1, device=probs.device)
+    a = (
+        probs.cumsum(3).sub(key).sign()
+    )  # Sign flips on probability interval containing key
+    samples = a.sub(1).div(-2).sum(3)  # Get index of sign-flip
+
+    # Composite greedy and non greedy outputs
+    greedy = logits.argmax(-1)
+    samples = samples.to(dtype=greedy.dtype)
+    return torch.where(do_sample[:, None, None], samples, greedy)
 
 
 def speculative_generate(
@@ -390,6 +446,9 @@ def speculative_generate(
     decode_model: Optional[Union[Callable, torch.nn.Module]] = None,
     # todo: This is a WIP to enable cudagraphs, currently its only for batch_size=1
     cudagraphs: bool = False,
+    do_sample: bool = False,
+    temperature: float = 1.0,
+    top_k: int = 5,
 ):
     """
     A reference implementation of speculative decoding generation.
@@ -433,6 +492,15 @@ def speculative_generate(
             if True, cudagraphs is used and all metadata will be padded, otherwise
             metadata will not be padded unless required. Note: This is a WIP and
             only works for batch_size=1
+        do_sample: bool
+            non-deterministic, multinomial output sampling. False for greedy.
+            Provides output diversity, but lowers speculative decoding speedup.
+        temperature: float
+            temperature of softmax when sampling. Lowering this should provide
+            better speculative decoding speedup when do_sample=True.
+        top_k: int
+            only search among top k tokens. Lowering this should provide
+            better speculative decoding speedup when do_sample=True.
     Returns:
         result: List of id tensors, possibly different lengths if batching.
         n_steps: Number of foward passes used to generate provided tokens.
@@ -518,10 +586,20 @@ def speculative_generate(
             use_cache=True,
         )  # 1 n' v OR bk 1+h v
 
-        next_vals, embeds = __extract_decode_output(
+        logits, embeds = __extract_decode_output(
             output, unflat_indices, bsize, n_candidates, inp_len
         )
 
+        if do_sample:
+            do_sample_vector = torch.ones(bsize, device=logits.device, dtype=torch.bool)
+        else:
+            do_sample_vector = torch.zeros(
+                bsize, device=logits.device, dtype=torch.bool
+            )
+        next_vals = __generate_targets(
+            logits, do_sample_vector, temperature=temperature, top_k=top_k
+        )
+
         next_vals_list, embeds, parent_sequence_ids = __prune_candidates(
             input_ids, next_vals, embeds, kv_cache_manager, child_sequence_ids_list
         )
diff --git a/scripts/paged_speculative_inference.py b/scripts/paged_speculative_inference.py
index 6d5ae0d..c529763 100644
--- a/scripts/paged_speculative_inference.py
+++ b/scripts/paged_speculative_inference.py
@@ -100,7 +100,21 @@
     action="store_true",
     help="use a batch of prompts as input (note this is still wip for reduce-overhead=True)",
 )
-# top_k_tokens_per_head
+parser.add_argument(
+    "--top_k",
+    type=int,
+    default=10,
+    help="sample only among top k most confident tokens (ignored if do_sample=False)",
+)
+parser.add_argument(
+    "--temperature",
+    type=float,
+    default=1.0,
+    help="degree of smoothing for sampling distribution (ignored if do_sample=False)",
+)
+parser.add_argument(
+    "--do_sample", action="store_true", help="enable non-greedy generation"
+)
 parser.add_argument(
     "--top_k_tokens_per_head",
     type=lambda s: list(map(int, s.split(","))),
@@ -252,6 +266,9 @@ def infer(ids, warmup):
             # todo: we can only reduce-overhead for now when batch size is 1
             flattening=not (args.compile and compile_mode == "reduce-overhead"),
             cudagraphs=cudagraphs,
+            do_sample=args.do_sample,
+            temperature=args.temperature,
+            top_k=args.top_k,
             threshes=args.top_k_tokens_per_head,
         )
     else:
@@ -261,9 +278,11 @@ def infer(ids, warmup):
             kv_cache_manager,
             max_new_tokens=100,
             max_seq_len=model.config.max_expected_seq_len,
-            do_sample=False,
             decode_model=decode_model,
             cudagraphs=cudagraphs,
+            do_sample=args.do_sample,
+            temperature=args.temperature,
+            top_k=args.top_k,
         )
     if not warmup:
         total_tokens = 0
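
Note for reviewers: the core of this change is the candidate-consistent sampling in __generate_targets. The snippet below is a minimal standalone sketch of that trick for illustration only; it is not part of the patch and not the fms_extras API, and the helper name consistent_sample and the toy shapes are assumptions. It shares one uniform key per token position across all candidates, so candidates whose distributions agree at a position always receive the same sampled "ground truth" token there.

import torch


def consistent_sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 5) -> torch.Tensor:
    # logits: n_candidates x seq_len x vocab scores for a single input sequence.
    logits = logits / temperature
    v, _ = logits.topk(top_k, dim=-1)
    logits = logits.masked_fill(logits < v[..., [-1]], -float("inf"))  # keep only the top_k tokens
    probs = logits.softmax(-1)

    # One uniform key per position, broadcast over candidates and vocab:
    # identical distributions land in the same interval of the CDF.
    key = torch.rand(1, logits.size(1), 1)
    flips = probs.cumsum(-1).sub(key).sign()    # -1 left of the key, +1 right of it
    return flips.sub(1).div(-2).sum(-1).long()  # count of -1s = index of the chosen token


# Two candidates that share the position-0 distribution always agree on the
# position-0 sample, so a correct candidate cannot be rejected by bad luck there.
torch.manual_seed(0)
shared = torch.randn(1, 1, 32)
cand_a = torch.cat([shared, torch.randn(1, 2, 32)], dim=1)
cand_b = torch.cat([shared, torch.randn(1, 2, 32)], dim=1)
samples = consistent_sample(torch.cat([cand_a, cand_b], dim=0))
assert samples[0, 0] == samples[1, 0]

On the script side, the new flags would presumably be exercised with something along the lines of python scripts/paged_speculative_inference.py --do_sample --temperature 0.8 --top_k 10 (other required arguments omitted, and 0.8 is only an example value); per the docstring above, lowering temperature and top_k trades output diversity for a higher acceptance rate and therefore better speculative decoding speedup.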