Commit 26c82db
[JAX] Fix incorrect calculation of segment pos from segment ids in user-facing API (#2523)
* Fix incorrect calculation of segment pos from segment ids for thd cases and load balanced cases in from_segment_ids_and_pos. Enforce passing of segment_pos for THD cases and lod balanced cases
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Correct the assert condition
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Modify fused attn tests to pass new args to from_segment_ids_and_pos()
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Calculate seg ids before pos
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* 1. Change the signature for from_segment_ids_and_pos()
2. Add support for THD in from_segment_ids_and_pos()
3. Assert if load balanced segment_ids is passed to generate a segment_pos
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Pass keyword-only args by name
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
* nit: Fix typo to use seg_ids instead of segment_ids
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
* nit: Fix comments
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
* Modify the function call to differentiate between load balancing and actually reordered segment_ids and segment_pos
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Fix the is_segment_ids_reordered to be set only when CP and load balancing
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Fix comments for from_segment_ids_and_pos()
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* Code clean up
for more information, see https://pre-commit.ci
Fix lint errors
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>1 parent 5ba01fa commit 26c82db
File tree
2 files changed
+90
-11
lines changed- tests/jax
- transformer_engine/jax
2 files changed
+90
-11
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
668 | 668 | | |
669 | 669 | | |
670 | 670 | | |
| 671 | + | |
| 672 | + | |
671 | 673 | | |
672 | 674 | | |
673 | 675 | | |
674 | 676 | | |
675 | 677 | | |
676 | 678 | | |
677 | | - | |
678 | | - | |
| 679 | + | |
| 680 | + | |
| 681 | + | |
| 682 | + | |
| 683 | + | |
| 684 | + | |
| 685 | + | |
| 686 | + | |
| 687 | + | |
| 688 | + | |
679 | 689 | | |
680 | 690 | | |
681 | 691 | | |
| |||
704 | 714 | | |
705 | 715 | | |
706 | 716 | | |
| 717 | + | |
| 718 | + | |
707 | 719 | | |
708 | 720 | | |
709 | 721 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
658 | 658 | | |
659 | 659 | | |
660 | 660 | | |
661 | | - | |
| 661 | + | |
662 | 662 | | |
663 | 663 | | |
664 | 664 | | |
| |||
796 | 796 | | |
797 | 797 | | |
798 | 798 | | |
| 799 | + | |
| 800 | + | |
| 801 | + | |
799 | 802 | | |
800 | 803 | | |
801 | | - | |
| 804 | + | |
| 805 | + | |
| 806 | + | |
802 | 807 | | |
803 | 808 | | |
804 | 809 | | |
| |||
812 | 817 | | |
813 | 818 | | |
814 | 819 | | |
| 820 | + | |
| 821 | + | |
| 822 | + | |
815 | 823 | | |
816 | 824 | | |
817 | 825 | | |
818 | 826 | | |
819 | 827 | | |
820 | | - | |
821 | | - | |
822 | | - | |
823 | | - | |
824 | | - | |
825 | | - | |
826 | | - | |
| 828 | + | |
| 829 | + | |
| 830 | + | |
| 831 | + | |
| 832 | + | |
| 833 | + | |
| 834 | + | |
| 835 | + | |
| 836 | + | |
| 837 | + | |
| 838 | + | |
| 839 | + | |
| 840 | + | |
| 841 | + | |
| 842 | + | |
| 843 | + | |
| 844 | + | |
| 845 | + | |
| 846 | + | |
| 847 | + | |
| 848 | + | |
| 849 | + | |
| 850 | + | |
| 851 | + | |
| 852 | + | |
| 853 | + | |
| 854 | + | |
| 855 | + | |
| 856 | + | |
| 857 | + | |
| 858 | + | |
| 859 | + | |
| 860 | + | |
| 861 | + | |
| 862 | + | |
| 863 | + | |
| 864 | + | |
| 865 | + | |
| 866 | + | |
| 867 | + | |
| 868 | + | |
| 869 | + | |
| 870 | + | |
| 871 | + | |
| 872 | + | |
| 873 | + | |
| 874 | + | |
| 875 | + | |
| 876 | + | |
| 877 | + | |
| 878 | + | |
| 879 | + | |
| 880 | + | |
| 881 | + | |
| 882 | + | |
| 883 | + | |
| 884 | + | |
| 885 | + | |
| 886 | + | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
827 | 891 | | |
828 | 892 | | |
829 | 893 | | |
830 | 894 | | |
| 895 | + | |
| 896 | + | |
| 897 | + | |
831 | 898 | | |
832 | 899 | | |
833 | 900 | | |
| |||
0 commit comments