
Add DeepSeek Engram layer #3010

Open
shuningjin wants to merge 1 commit into main from shuningjin-engram

Conversation

@shuningjin (Collaborator) commented Jan 26, 2026

Description


What this PR does

Add Engram layer: engram.py

  • NgramHashMapping (non-parametric): CompressedTokenizer + hashing logic; converts "input_id" to "ngram hash_token_id"
    • CompressedTokenizer (non-parametric): converts "input_id" to "compressed_input_id"
  • Engram (multi-branch): takes "ngram hash_token_id" and the transformer state as inputs; combines MultiHeadEmbedding (embedding lookup using the hash id as static memory), context-aware gating (dot product of static memory with the contextual state), and ShortConv (temporal smoothing); see the sketch after this list
    • MultiHeadEmbedding: converts "ngram hash_token_id" to an ngram embedding vector
    • ShortConv (multi-branch): depthwise (mixes time steps, not channels) and causal; "short" means the kernel size is small
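
For orientation, here is a minimal single-branch sketch of this data flow in plain JAX. It is not the PR's nnx module: the head layout, the exact gating formula, and the names (engram_branch, embedding_table, conv_kernel) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def engram_branch(hash_ids, hidden_state, embedding_table, conv_kernel):
    """One Engram branch, written out naively to show the data flow.

    hash_ids:        (B, S, H)     n-gram hash token ids, one per head
    hidden_state:    (B, S, D)     contextual state from the transformer backbone
    embedding_table: (H, V, D//H)  static memory, one lookup table per head
    conv_kernel:     (K, D)        small causal depthwise filter (ShortConv)
    """
    B, S, H = hash_ids.shape
    # MultiHeadEmbedding: look up each head's hash id in its own table,
    # then concatenate the heads back to the model dimension D.
    per_head = jnp.stack(
        [embedding_table[h][hash_ids[..., h]] for h in range(H)], axis=2
    )                                    # (B, S, H, D//H)
    memory = per_head.reshape(B, S, -1)  # (B, S, D)

    # Context-aware gating: the dot product of static memory with the
    # contextual state decides how much memory passes through per position.
    gate = jax.nn.sigmoid(
        jnp.sum(memory * hidden_state, axis=-1, keepdims=True)
        / jnp.sqrt(memory.shape[-1])
    )
    gated = gate * memory                # (B, S, D)

    # ShortConv: causal depthwise conv with a small kernel K (temporal
    # smoothing) -- mixes nearby time steps, never channels.
    K = conv_kernel.shape[0]
    padded = jnp.pad(gated, ((0, 0), (K - 1, 0), (0, 0)))  # causal left pad
    return jax.lax.conv_general_dilated(
        padded,
        conv_kernel[:, None, :],         # (K, 1, D): one filter per channel
        window_strides=(1,),
        padding="VALID",
        dimension_numbers=("NWC", "WIO", "NWC"),
        feature_group_count=gated.shape[-1],
    )
```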

Add unit test: tests.unit.engram_vs_reference_test

  • for each component, verify that the output matches the reference implementation (see the parity-check sketch below)
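
The comparison itself presumably follows the usual JAX-versus-PyTorch parity idiom; a minimal sketch, with the helper name and tolerances as assumptions:

```python
import numpy as np


def assert_matches_reference(jax_out, torch_out, atol=1e-5, rtol=1e-5):
    """Assert a JAX component's output matches the PyTorch reference output."""
    np.testing.assert_allclose(
        np.asarray(jax_out),               # JAX array -> NumPy
        torch_out.detach().cpu().numpy(),  # torch tensor -> NumPy
        atol=atol,
        rtol=rtol,
    )
```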

Implementation Notes

Placement of NgramHashMapping

  • NgramHashMapping converts vanilla token ids to hashed n-gram token ids, which Engram consumes for embedding lookup
  • Future: I would like NgramHashMapping and hash_input_ids generation to move into the data input pipeline, since they are CPU-intensive, just like the tokenizer and input_ids generation.

Multi-branch

  • Engram and ShortConv handle multi-branch input and multi-branch output (if hc_mult > 1), vectorized with nnx.vmap; see the sketch after this list
  • Future: to be integrated into a multi-branch backbone such as mHC.
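
A minimal sketch of the nnx.vmap pattern, with nnx.Linear standing in for the real Engram sub-modules and hc_mult and all shapes made up:

```python
import jax.numpy as jnp
from flax import nnx

hc_mult, in_dim, out_dim = 4, 16, 16  # hypothetical sizes


# Independent weights per branch: vmap module construction over the branch
# axis, splitting the RNG so each branch gets its own initialization.
@nnx.split_rngs(splits=hc_mult)
@nnx.vmap(in_axes=0, out_axes=0)
def make_branch_proj(rngs: nnx.Rngs):
    return nnx.Linear(in_dim, out_dim, rngs=rngs)


branch_proj = make_branch_proj(nnx.Rngs(0))


# Branched input: map both the stacked module and the input over the
# leading branch axis G.
@nnx.vmap(in_axes=(0, 0), out_axes=0)
def apply_branches(module, x):
    return module(x)


x = jnp.ones((hc_mult, 2, 8, in_dim))  # (G, B, S, in_dim)
y = apply_branches(branch_proj, x)     # (G, B, S, out_dim)
```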

Tests

Unit test against the reference:

python3 -m pytest -v --pyargs tests.unit.engram_vs_reference_test -rP -s

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Jan 26, 2026

Codecov Report

❌ Patch coverage is 0% with 205 lines in your changes missing coverage. Please review.

Files with missing lines        Patch %   Lines
src/MaxText/layers/engram.py    0.00%     205 Missing ⚠️


@shuningjin changed the title from "[DRAFT] do no merge" to "[DRAFT] engram" on Jan 29, 2026
@shuningjin force-pushed the shuningjin-engram branch 2 times, most recently from 93458cf to 21cec5f on January 30, 2026 at 17:52
@shuningjin changed the title from "[DRAFT] engram" to "Add DeepSeek Engram layer" on Feb 4, 2026
@shuningjin marked this pull request as ready for review on February 4, 2026 at 21:48
@shuningjin (Collaborator Author):

@gemini-cli /review

@RissyRan (Collaborator) left a comment

@gemini-cli /review

github-actions bot commented Feb 5, 2026

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

github-actions bot left a comment

📋 Review Summary

This pull request introduces a JAX implementation of the DeepSeek Engram layer, along with comprehensive unit tests that validate its behavior against a PyTorch reference. The code is well-structured and the implementation appears to be correct and thorough. The core logic is sound, and the use of vectorization with nnx.vmap is a good practice for performance.

🔍 General Feedback

  • Good Testing: The inclusion of unit tests comparing the JAX implementation to a PyTorch reference is excellent. This provides high confidence in the correctness of the implementation.
  • Clear Implementation: The code in engram.py is well-commented and organized, making it easy to follow the logic from the original paper.
  • TODOs: I've commented on the TODOs left in the code. Addressing them will improve the clarity and robustness of the implementation.

# Structure: {layer_id: [[2gram_head1, ..., 2gram_headH], ..., [Ngram_head1, ..., Ngram_headH]]}
self.vocab_size_across_layers = self._calculate_vocab_size_across_layers()

def _calculate_multipliers_across_layers(self, seed: int):

🟡 Your comment here is valid and raises a good question. Using `tokenizer.pad_id` directly would be more robust and less prone to configuration errors. If there's a specific reason to pass `pad_id` separately and then look it up in the `lookup_table`, it would be beneficial to document that reasoning here. Otherwise, I'd recommend simplifying this to use the tokenizer's padding ID directly.

github-actions bot commented Feb 5, 2026

🤖 I'm sorry @RissyRan, but I was unable to process your request. Please see the logs for more details.

@RissyRan (Collaborator) left a comment

I reviewed the test and CompressedTokenizer. I will continue reviewing the rest tomorrow.


"""
DeepSeek-AI, `Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models
<https://arxiv.org/pdf/2601.07372>`_, 2026
Collaborator:

is this extra "_" on purpose?

engram_head_dim: int = 32
engram_num_heads: int = 8 # num heads per n-gram
# Hashing
engram_pad_id: int = 2 # TODO(shuningjin): not the same as tokenizer.pad_id?
Collaborator:

Does this need to be defined by users?


@RissyRan (Collaborator) left a comment

Thanks for the change! I left some initial comments, and may need to go over multihead embedding and conv parts. It should be quick.

n-grams into fixed integer IDs. To handle the large combinatorial space, it uses:
1. Unique Prime Vocabularies: Per-head prime moduli to minimize collision overlap.
2. Sliding Window: Efficient shifting to generate n-gram views.
3. Lightweight Hashing: A multiplicative-XOR function (Rabin-Karp variant).
Collaborator:

Nice! I may have missed this part in the reference implementation. Did you add this optimization?
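
For context, here is a standalone NumPy sketch of the general multiplicative-XOR scheme described in the quoted docstring. The exact reference formula, window construction, and multiplier generation differ from this, so treat every detail as an assumption:

```python
import numpy as np


def hash_ngram_windows(ids: np.ndarray, multipliers: np.ndarray, prime: int) -> np.ndarray:
    """Map each sliding n-gram window to a hash id in [0, prime).

    ids:         (batch, seq) compressed token ids (int64)
    multipliers: (n,) per-position multipliers (int64), seeded per layer/head
    prime:       per-head prime modulus bounding that head's hash vocabulary
    """
    n = multipliers.shape[0]
    # Sliding window: left-pad so every position has a full n-token history
    # (0 here; the real code presumably uses a dedicated pad id such as
    # engram_pad_id), then take n shifted views of the sequence.
    padded = np.pad(ids, ((0, 0), (n - 1, 0)))
    windows = np.stack(
        [padded[:, k : k + ids.shape[1]] for k in range(n)], axis=-1
    )  # (batch, seq, n), window at position t covers tokens t-n+1 .. t
    # Lightweight hashing: multiply each position by its multiplier (int64
    # overflow wraps, which is acceptable for hashing), XOR-fold the
    # positions together, and reduce modulo the prime vocabulary size.
    mixed = windows * multipliers
    h = mixed[..., 0]
    for k in range(1, n):
        h = np.bitwise_xor(h, mixed[..., k])
    return h % prime
```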

A dictionary mapping layer_id to a list of `max_ngram_size` multipliers.
"""
# Pre-calculate bounds for random generation
max_long = np.iinfo(np.int64).max
Collaborator:

Could you cross-check whether we could update all np usage to jnp?

Collaborator Author:

If this is to live in the data input pipeline, it should probably align with that. Does the pipeline use np or jnp?

LAYER_PRIME_OFFSET = 10007

layer_multipliers = {}
for layer_id in self.layer_ids:
Collaborator:

Do you think we could update this block to use vectorized operations? The dimension depends on len(layer_ids), which is fixed at compile time.

quant: Optional[Quant] = None,
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
*,
hc_mult: int = 4,
Collaborator:

Shall we put params with default values at the very end?

Collaborator Author:

Thanks for the suggestions. I have updated the argument order.

axis=-1,
kernel_init=self.kernel_init,
# TODO(shuningjin): this needs to be actual logical axis? @reviewer
kernel_axes=("engram_dim", "embed"),
Collaborator:

You could add the sharding constraint into base.yml.

logical_axis_rules: [

Since this is a smaller dim compared to embed, we could shard it on tensor as a starting point. I see embed usually shards on fsdp, sequence, context, etc.

Collaborator Author:

Thanks for the suggestion! Added ['engram_dim', ['tensor']].

Shape annotation:
B: Batch Size
S: Sequence Length
G: hc_mult, Number of Branches
Collaborator:

Do you plan to separate this config or treat it the same as mhc_expansion_rate?

Collaborator Author:

I will reuse mhc_expansion_rate and have replaced all hc_mult.

# Norms (vectorized)
# Independent weights per branch, Branched input
@nnx.split_rngs(splits=hc_mult)
@nnx.vmap(in_axes=0, out_axes=0)
Collaborator:

Are you sharding on the batch dimension? Why is that? Similar comment for the other in_axes=0 vmap ops.

# Vectorized broadcast: apply each of the G key_projs to the SAME embeddings.
# in_axes: (0, None) -> 0 splits the Dense layers, None broadcasts embeddings
# out_axes: 2 -> Stack the results at axis 2 to get (B, S, G, D)
@nnx.vmap(in_axes=(0, None), out_axes=2)
Collaborator:

Do you know if this in_axes setup works properly? I see your unit test uses b=2. When I integrated flash attention with sparse attention, I had to change the unit test from b=2 to b=4 when sharding on fsdp; otherwise it would fail on a local v5p-8 machine.
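
For context, a standalone sketch of the broadcast pattern under discussion, with nnx.Linear standing in for the key projections and all sizes hypothetical:

```python
import jax.numpy as jnp
from flax import nnx

G, B, S, D_in, D_out = 4, 2, 8, 16, 32  # hypothetical sizes


# G stacked projections with independent weights.
@nnx.split_rngs(splits=G)
@nnx.vmap(in_axes=0, out_axes=0)
def make_key_projs(rngs: nnx.Rngs):
    return nnx.Linear(D_in, D_out, rngs=rngs)


key_projs = make_key_projs(nnx.Rngs(0))
embeddings = jnp.ones((B, S, D_in))


# in_axes=(0, None): 0 maps over the G stacked layers, None broadcasts the
# same embeddings to every branch; out_axes=2 stacks the results at axis 2.
@nnx.vmap(in_axes=(0, None), out_axes=2)
def project(proj, emb):
    return proj(emb)


out = project(key_projs, embeddings)  # (B, S, G, D_out)
```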

max_ngram_size: int,
engram_num_heads: int,
layer_ids: List[int],
tokenizer: HFTokenizer,
Collaborator:

When you say you would like to put the lookup table into the data pipeline, is that beneficial for structure or for performance? When we call Engram from the decoder layer, we need to pass this tokenizer. So you are thinking this Engram module will call/depend on the data pipeline for the lookup?

@shuningjin (Collaborator Author) commented Feb 6, 2026

I would like to put NgramHashMapping and hash_input_ids generation in the data input pipeline, since it is CPU-intensive, just like how we put the tokenizer and input_ids generation in the pipeline.

This is the overall structure: https://screenshot.googleplex.com/7YYxr4z7UqvBkpN

Also: b/478294696#comment5

self.backbone_config = BackBoneConfig(self.config)
tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True)
# input
batch, seq_len = 2, 3
Collaborator:

Could we set up a longer sequence, like 8, to test the overlap of 2/3-grams?

@shuningjin (Collaborator Author) commented Feb 6, 2026

Thanks for noticing that. I changed all tests to seq_len=32, and they still pass.
