Conversation
@gemini-cli /review
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This pull request introduces a JAX implementation of the DeepSeek Engram layer, along with comprehensive unit tests that validate its behavior against a PyTorch reference. The code is well-structured and the implementation appears to be correct and thorough. The core logic is sound, and the use of vectorization with nnx.vmap is a good practice for performance.
🔍 General Feedback
- Good Testing: The inclusion of unit tests comparing the JAX implementation to a PyTorch reference is excellent. This provides high confidence in the correctness of the implementation.
- Clear Implementation: The code in engram.py is well-commented and organized, making it easy to follow the logic from the original paper.
- TODOs: I've commented on the TODOs left in the code. Addressing them will improve the clarity and robustness of the implementation.
src/MaxText/layers/engram.py
Outdated
# Structure: {layer_id: [[2gram_head1, ..., 2gram_headH], ..., [Ngram_head1, ..., Ngram_headH]]}
self.vocab_size_across_layers = self._calculate_vocab_size_across_layers()

def _calculate_multipliers_across_layers(self, seed: int):
🤖 I'm sorry @RissyRan, but I was unable to process your request. Please see the logs for more details.
RissyRan left a comment
I reviewed the test and CompressedTokenizer. I will continue to review the rest tomorrow.
| """ | ||
| DeepSeek-AI, `Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models | ||
| <https://arxiv.org/pdf/2601.07372>`_, 2026 |
is this extra "_" on purpose?
engram_head_dim: int = 32
engram_num_heads: int = 8  # num heads per n-gram
# Hashing
engram_pad_id: int = 2  # TODO(shuningjin): not the same as tokenizer.pad_id?
Does this need to be defined by users?
In the reference code, it is a hyper-parameter. I am not sure why.
RissyRan left a comment
Thanks for the change! I left some initial comments, and may need to go over the multihead embedding and conv parts. It should be quick.
src/MaxText/layers/engram.py
Outdated
n-grams into fixed integer IDs. To handle the large combinatorial space, it uses:
1. Unique Prime Vocabularies: Per-head prime moduli to minimize collision overlap.
2. Sliding Window: Efficient shifting to generate n-gram views.
3. Lightweight Hashing: A multiplicative-XOR function (Rabin-Karp variant).
Nice! I may have missed this part in the reference implementation. Did you add this optimization?
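For readers following along, here is a minimal NumPy sketch of the scheme the docstring describes (sliding n-gram windows, a multiplicative-XOR mix, and a per-head prime modulus). The function name, multiplier values, and exact mixing order are illustrative assumptions, not the PR's actual implementation.

```python
import numpy as np

def hash_ngrams(input_ids: np.ndarray, n: int, multipliers: list[int], prime: int) -> np.ndarray:
  """input_ids: (B, S) int64 token ids; returns (B, S) hashed n-gram ids for one head."""
  B, S = input_ids.shape
  # Sliding window: stack n right-shifted views so position t sees tokens [t-n+1, ..., t].
  views = np.stack(
      [np.pad(input_ids, ((0, 0), (k, 0)))[:, :S] for k in range(n)], axis=-1
  )  # (B, S, n)
  # Multiplicative-XOR mix: multiply each window position by its own constant, then XOR-fold.
  mixed = views * np.asarray(multipliers[:n], dtype=np.int64)  # (B, S, n)
  h = mixed[..., 0]
  for k in range(1, n):
    h = np.bitwise_xor(h, mixed[..., k])
  # Per-head prime modulus keeps each head's hash vocabulary (and its collisions) distinct.
  return h % prime
```

The per-head prime is what makes the "unique prime vocabularies" point above work: two heads rarely collide on the same n-gram at the same time.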
A dictionary mapping layer_id to a list of `max_ngram_size` multipliers.
"""
# Pre-calculate bounds for random generation
max_long = np.iinfo(np.int64).max
Could you cross check if we could update all np to jnp?
If this is to be in the data input pipeline, maybe it should align with that. Does it use np or jnp?
LAYER_PRIME_OFFSET = 10007

layer_multipliers = {}
for layer_id in self.layer_ids:
Do you think we could update this block using a vectorized operation? The dim will depend on len(layer_ids), which is fixed at compile time.
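A possible shape for that change, sketched below under the assumption that the loop body draws max_ngram_size multipliers per layer with NumPy. Note the single-generator seeding here differs from the per-layer LAYER_PRIME_OFFSET seeding in the snippet above, so this is only an illustration of the vectorized draw, not a drop-in replacement.

```python
import numpy as np

def multipliers_vectorized(layer_ids: list[int], max_ngram_size: int, seed: int) -> dict[int, list[int]]:
  max_long = np.iinfo(np.int64).max
  rng = np.random.default_rng(seed)
  # One batched draw of shape (num_layers, max_ngram_size) instead of a Python loop per layer.
  draws = rng.integers(1, max_long // 2, size=(len(layer_ids), max_ngram_size), dtype=np.int64)
  multipliers = draws * 2 + 1  # odd multipliers are a common choice for multiplicative hashing
  return {layer_id: multipliers[i].tolist() for i, layer_id in enumerate(layer_ids)}
```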
src/MaxText/layers/engram.py
Outdated
quant: Optional[Quant] = None,
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
*,
hc_mult: int = 4,
shall we put params with default value at the very end?
Thanks for the suggestions. I have updated the argument order.
axis=-1,
kernel_init=self.kernel_init,
# TODO(shuningjin): this needs to be actual logical axis? @reviewer
kernel_axes=("engram_dim", "embed"),
You could add the sharding constraint into base.yml (maxtext/src/MaxText/configs/base.yml, line 402 at 352dd58).
I see it is a smaller dim compared to embed, so we could shard on tensor as a starting point. I see embed is usually sharded on fsdp, sequence, context, etc.
Thanks for the suggestion! Added ['engram_dim', ['tensor']].
src/MaxText/layers/engram.py
Outdated
Shape annotation:
B: Batch Size
S: Sequence Length
G: hc_mult, Number of Branches
Do you plan to separate this config or treat it the same as mhc_expansion_rate?
I will reuse mhc_expansion_rate. Have replaced all hc_mult.
# Norms (vectorized)
# Independent weights per branch, Branched input
@nnx.split_rngs(splits=hc_mult)
@nnx.vmap(in_axes=0, out_axes=0)
Are you sharding on the batch dimension? Why is that? Similar comment for the other in_axes=0 vmap ops.
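For context on the question above, here is a small self-contained sketch of the pattern in the quoted code: in_axes=0 in these nnx.vmap calls maps over the branch axis (the hc_mult stacked copies of the norm and the G-way split of the input), not the batch axis. Shapes and the RMSNorm choice below are illustrative assumptions.

```python
import jax.numpy as jnp
from flax import nnx

G, B, S, D = 4, 2, 8, 16  # branches, batch, seq, feature

@nnx.split_rngs(splits=G)            # give each branch its own RNG stream
@nnx.vmap(in_axes=0, out_axes=0)     # axis 0 = branch axis, so weights are independent per branch
def make_norms(rngs: nnx.Rngs):
  return nnx.RMSNorm(D, rngs=rngs)

norms = make_norms(nnx.Rngs(0))      # parameters carry a leading G axis

@nnx.vmap(in_axes=0, out_axes=0)     # apply branch-i norm to branch-i slice of the input
def apply_norms(norm, x):            # x: (G, B, S, D) -> out: (G, B, S, D)
  return norm(x)

y = apply_norms(norms, jnp.ones((G, B, S, D)))
```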
# Vectorized broadcast: apply each of the G key_projs to the SAME embeddings.
# in_axes: (0, None) -> 0 splits the Dense layers, None broadcasts embeddings
# out_axes: 2 -> Stack the results at axis 2 to get (B, S, G, D)
@nnx.vmap(in_axes=(0, None), out_axes=2)
Do you know if this in_axes is working properly? I see your unit test has a b=2 setup. When I integrated flash attn with sparse attn, I had to change the unit test from b=2 to b=4 when sharding on fsdp; otherwise it would fail on a v5p-8 local machine.
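A shape-level sketch of the pattern being asked about, assuming G stacked Dense-like layers created with nnx.vmap: in_axes=(0, None) maps over the stacked kernels and broadcasts the same embeddings to every branch, and out_axes=2 inserts the branch axis at position 2 to give (B, S, G, D). The batch axis is untouched by the vmap itself, so batch size should only matter for mesh divisibility under fsdp/data sharding; the names and shapes below are assumptions.

```python
import jax.numpy as jnp
from flax import nnx

G, B, S, E, D = 4, 2, 8, 64, 32

@nnx.split_rngs(splits=G)
@nnx.vmap(in_axes=0, out_axes=0)
def make_key_projs(rngs: nnx.Rngs):
  return nnx.Linear(E, D, rngs=rngs)   # stand-in for the Dense key projection

key_projs = make_key_projs(nnx.Rngs(0))

@nnx.vmap(in_axes=(0, None), out_axes=2)
def project(proj, emb):                # emb: (B, S, E) broadcast to all G branches
  return proj(emb)                     # each branch output: (B, S, D)

out = project(key_projs, jnp.ones((B, S, E)))
assert out.shape == (B, S, G, D)       # branch axis stacked at position 2
```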
max_ngram_size: int,
engram_num_heads: int,
layer_ids: List[int],
tokenizer: HFTokenizer,
When you say you would like to put the lookup table into the data pipeline, is this beneficial for structure or for performance? When we call the engram from the decoder layer, we need to pass this tokenizer. So you are thinking this engram module will call/depend on the data pipeline to do the lookup?
I would like to put NgramHashMapping and hash_input_ids generation in the data input pipeline, which is CPU intensive, just like how we put the tokenizer and input_ids generation in the pipeline.
This is the overall structure: https://screenshot.googleplex.com/7YYxr4z7UqvBkpN
Also: b/478294696#comment5
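To make the proposal concrete, here is a hedged sketch of that placement: hashed n-gram ids are computed on CPU in the input pipeline and shipped with the batch, mirroring how tokenization produces input_ids today. The feature name "hash_input_ids", the map-style hook, and the call interface of NgramHashMapping are assumptions, not the PR's API.

```python
import numpy as np

def add_hash_ids(example: dict, ngram_hash_mapping) -> dict:
  # example["inputs"]: (S,) or (B, S) token ids produced by the tokenizer stage.
  input_ids = np.asarray(example["inputs"])
  # NgramHashMapping is non-parametric, so this runs on host CPU with no device state;
  # its exact call signature here is a placeholder.
  example["hash_input_ids"] = ngram_hash_mapping(input_ids)
  return example

# Hypothetical pipeline hook:
# dataset = dataset.map(lambda ex: add_hash_ids(ex, ngram_hash_mapping))
```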
self.backbone_config = BackBoneConfig(self.config)
tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True)
# input
batch, seq_len = 2, 3
Could we set up a longer sequence, like 8, to test the overlap of 2/3-grams?
Thanks for noticing that. I changed all tests to seq_len=32. Tests still pass.
Description
Background
What this PR does
Add Engram layer: engram.py
- NgramHashMapping (non-parametric): CompressedTokenizer + hashing logic; converts "input_id" to "ngram hash_token_id"
- CompressedTokenizer (non-parametric): converts "input_id" to "compressed_input_id"
- Engram (multi-branch): inputs are "ngram hash_token_id" and "transformer state"; MultiHeadEmbedding (lookup embedding using hash id as static memory) + context-aware gating (dot product of static memory with contextual state) + ShortConv (temporal smoothing)
- MultiHeadEmbedding: converts ngram hash_token_id to ngram embedding vector
- ShortConv (multi-branch): depthwise (mixes time steps, not channels), causal; "short" means the kernel size is small (see the sketch below)
Add unit test: tests.unit.engram_vs_reference_test
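To illustrate the ShortConv component described in the list above, here is a minimal nnx sketch of a depthwise, causal, short-kernel convolution: each channel gets its own small filter over the time axis, and left-only padding keeps position t from seeing future positions. Kernel size and initialization are illustrative, not the PR's defaults.

```python
import jax.numpy as jnp
from flax import nnx

class ShortConvSketch(nnx.Module):
  def __init__(self, channels: int, kernel_size: int = 4, *, rngs: nnx.Rngs):
    # feature_group_count=channels makes the convolution depthwise:
    # time steps are mixed within each channel, channels are not mixed.
    self.conv = nnx.Conv(
        in_features=channels,
        out_features=channels,
        kernel_size=(kernel_size,),
        feature_group_count=channels,
        padding=[(kernel_size - 1, 0)],  # left-only padding -> causal
        rngs=rngs,
    )

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    # x: (batch, seq, channels) -> (batch, seq, channels)
    return self.conv(x)

y = ShortConvSketch(channels=16, rngs=nnx.Rngs(0))(jnp.ones((2, 32, 16)))
assert y.shape == (2, 32, 16)
```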
Implementation Notes
Placement of NgramHashMapping:
- NgramHashMapping converts vanilla token-ids to hashed ngram token-ids, which Engram consumes for embedding lookup.
- NgramHashMapping and hash_input_ids generation should be put in the data input pipeline, which is CPU intensive, just like how we put the tokenizer and input_ids generation in the pipeline.
Multi-branch:
- Engram and ShortConv handle multi-branch input and multi-branch output (if hc_mult > 1), optimized with nnx.vmap.
Tests
unit test against reference
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.