Add Mistral3 multimodal support with Pixtral vision encoder #431
dai-yamashita wants to merge 1 commit into elixir-nx:main
Conversation
This adds support for Mistral3 multimodal models (vision + text):

- `Bumblebee.Vision.Pixtral`: Pixtral vision encoder with RoPE support
- `Bumblebee.Text.Mistral3`: Mistral3 text decoder with interleaved attention
- `Bumblebee.Multimodal.Mistral3`: vision-language model combining Pixtral and Mistral3 with a multimodal projector for image-conditioned generation
- Ministral/Ministral3 variant support with interleaved attention
- Devstral 2 (Ministral3) model support

Supported architectures:

- PixtralVisionModel
- Mistral3Model, Mistral3ForCausalLM, Mistral3ForSequenceClassification
- Mistral3ForConditionalGeneration (multimodal)
- Ministral3ForCausalLM
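For context, a minimal usage sketch of what this enables; the repository name below is a placeholder, not a checkpoint this PR pins its tests to, and `Bumblebee.load_model/1` is the existing entry point:

```elixir
# Hypothetical usage; "org/mistral3-checkpoint" is a placeholder repository.
{:ok, model_info} = Bumblebee.load_model({:hf, "org/mistral3-checkpoint"})

# Bumblebee picks the module/architecture pair from the checkpoint config,
# e.g. {Bumblebee.Multimodal.Mistral3, :for_conditional_generation}
IO.inspect(model_info.spec.architecture)
```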
jonatanklosko left a comment
Hey @dai-yamashita, this has multiple parts, so I would split it into multiple PRs, for example:
- Changes to the existing Mistral.
- Pixtral model.
- Mistral3 model.
I also dropped two high-level comments inline.
Also note that for this to work end-to-end, an important piece is all the Pixtral image processing. It may be worth first doing a proof of concept where actual generation with images works, and then submitting the pieces as PRs.
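A rough sketch of what such a proof of concept could look like, assuming an image-to-text style serving analogous to the existing `Bumblebee.Vision.image_to_text/4`; the Pixtral featurizer support would still have to be written, and the checkpoint name is illustrative:

```elixir
# Hypothetical end-to-end check, modeled on Bumblebee's existing
# image-to-text servings (e.g. BLIP); everything Pixtral-specific here
# is an assumption.
repo = {:hf, "mistral-community/pixtral-12b"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, featurizer} = Bumblebee.load_featurizer(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

serving =
  Bumblebee.Vision.image_to_text(model_info, featurizer, tokenizer, generation_config)

image = StbImage.read_file!("bumblebee.png")
Nx.Serving.run(serving, image)
#=> %{results: [%{text: "..."}]}
```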
| "Mistral3ForCausalLM" => {Bumblebee.Text.Mistral3, :for_causal_language_modeling}, | ||
| "Mistral3ForSequenceClassification" => | ||
| {Bumblebee.Text.Mistral3, :for_sequence_classification}, |
As far as I can tell, Mistral3ForCausalLM and Mistral3ForSequenceClassification don't exist in hf/transformers. Looking at modeling_mistral3.py, there's only Mistral3ForConditionalGeneration and the corresponding Mistral3Model, which should be a single module in Bumblebee.
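For illustration, the registry entries above would then collapse into something like the following; the exact module name and architecture atom are assumptions:

```elixir
# Hypothetical replacement: a single Mistral3 module covering the two
# architectures that actually exist in hf/transformers.
"Mistral3Model" => {Bumblebee.Multimodal.Mistral3, :base},
"Mistral3ForConditionalGeneration" =>
  {Bumblebee.Multimodal.Mistral3, :for_conditional_generation},
```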
```elixir
# Expected values from Bumblebee inference with model generated by:
# test/fixtures/scripts/generate_expected_values.py (torch.manual_seed(42))
assert_all_close(
  outputs.logits[[.., 1..3, 1..3]],
  Nx.tensor([
    [
      [3.5014779567718506, -3.962040662765503, -4.744167327880859],
```
We should not assert against values produced by Bumblebee itself, because that doesn't verify the current implementation is correct. We want those reference values to be generated using Python transformers; this way we know the implementation behaves the same. See #422 for an example of a complete PR with tests.
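For illustration, the intended shape of such a test, following the pattern in #422; the tiny checkpoint name is hypothetical, and the tensor values are placeholders that must be replaced with the numbers printed by running the same checkpoint in Python transformers:

```elixir
test ":base" do
  assert {:ok, %{model: model, params: params, spec: spec}} =
           Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-Mistral3Model"})

  assert spec.architecture == :base

  inputs = %{
    "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
    "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
  }

  outputs = Axon.predict(model, params, inputs)

  assert_all_close(
    outputs.hidden_state[[.., 1..3, 1..3]],
    # placeholders: paste the values from the transformers run here
    Nx.tensor([[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])
  )
end
```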
Thank you very much for your review.
We truly appreciate your thoughtful advice on the open issues and on how to proceed.
We will carefully address the points you raised and resubmit the PRs accordingly.
Summary
This PR adds support for Mistral3 multimodal models (Ministral series):

- New spec options: `attention_head_size`, `use_interleaved_attention`, `tie_word_embeddings`
- `Bumblebee.Vision.Pixtral`
- `Bumblebee.Text.Mistral3`
- `Bumblebee.Multimodal.Mistral3`

Supported architectures

- `PixtralVisionModel`
- `Mistral3Model`, `Mistral3ForCausalLM`, `Mistral3ForSequenceClassification`
- `Mistral3ForConditionalGeneration` (multimodal)
- `Ministral3ForCausalLM`

Key features

- `transformer.ex` changes for per-layer attention configuration (see the sketch after this list)

Test plan
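A minimal sketch of setting the new options on a spec, assuming they are exposed through the standard `Bumblebee.load_spec/1` and `Bumblebee.configure/2` flow; the checkpoint name and option values are illustrative:

```elixir
repo = {:hf, "mistralai/Ministral-8B-Instruct-2410"}

{:ok, spec} = Bumblebee.load_spec(repo)

# Option names come from this PR; the values here are illustrative.
spec =
  Bumblebee.configure(spec,
    attention_head_size: 128,
    use_interleaved_attention: true,
    tie_word_embeddings: true
  )

{:ok, model_info} = Bumblebee.load_model(repo, spec: spec)
```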