Skip to content

Commit ca4d670

Browse files
gshtras and JArnoldAMD authored
Library versions bump (ROCm#343)
* Updated library versions
* Simple num_stages fix without re-tuning for performance
* Tuning script adaptation for the new triton
* navi lib versions
* Update MI300X fused_moe configs for Triton 3.2 (ROCm#344)

---------

Co-authored-by: Jeremy Arnold <103538711+JArnoldAMD@users.noreply.github.com>
1 parent 1dcd9fe commit ca4d670

File tree

58 files changed

+3576
-1320
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+3576
-1320
lines changed

Dockerfile.base

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3-complete
2-
ARG HIPBLASLT_BRANCH="507a649"
2+
ARG HIPBLASLT_BRANCH="4d40e36"
33
ARG LEGACY_HIPBLASLT_OPTION=
4-
ARG RCCL_BRANCH="dfe4a3e"
4+
ARG RCCL_BRANCH="648a58d"
55
ARG RCCL_REPO="https://github.com/ROCm/rccl"
6-
ARG TRITON_BRANCH="e192dba"
6+
ARG TRITON_BRANCH="e5be006"
77
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
8-
ARG PYTORCH_BRANCH="8bc4033"
8+
ARG PYTORCH_BRANCH="8d4926e"
99
ARG PYTORCH_VISION_BRANCH="v0.19.1"
1010
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
1111
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
12-
ARG FA_BRANCH="c555642"
12+
ARG FA_BRANCH="b7d29fb"
1313
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
1414

1515
FROM ${BASE_IMAGE} AS base

Dockerfile.base_navi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3-complete
2-
ARG HIPBLASLT_BRANCH="d43d84a"
2+
ARG HIPBLASLT_BRANCH="4d40e36"
33
ARG LEGACY_HIPBLASLT_OPTION=
4-
ARG RCCL_BRANCH="dfe4a3e"
4+
ARG RCCL_BRANCH="648a58d"
55
ARG RCCL_REPO="https://github.com/ROCm/rccl"
6-
ARG TRITON_BRANCH="e192dba"
6+
ARG TRITON_BRANCH="e5be006"
77
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
8-
ARG PYTORCH_BRANCH="8bc4033"
8+
ARG PYTORCH_BRANCH="8d4926e"
99
ARG PYTORCH_VISION_BRANCH="v0.19.1"
1010
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
1111
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"

benchmarks/kernels/benchmark_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def get_rocm_tuning_space(use_fp16):
155155
# For now we see better perf with num_stages=0 for all gemm configs we care
156156
# But keep this explicit so that we do not forget we may need to set it to
157157
# other values in the future
158-
num_stage_range = [0]
158+
num_stage_range = [2]
159159
waves_per_eu_range = [0]
160160
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
161161
kpack_range = [1, 2] if use_fp16 else []
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE_M": 16,
4+
"BLOCK_SIZE_N": 64,
5+
"BLOCK_SIZE_K": 256,
6+
"GROUP_SIZE_M": 1,
7+
"num_warps": 4,
8+
"num_stages": 2,
9+
"waves_per_eu": 0
10+
},
11+
"2": {
12+
"BLOCK_SIZE_M": 16,
13+
"BLOCK_SIZE_N": 64,
14+
"BLOCK_SIZE_K": 256,
15+
"GROUP_SIZE_M": 1,
16+
"num_warps": 4,
17+
"num_stages": 2,
18+
"waves_per_eu": 0
19+
},
20+
"4": {
21+
"BLOCK_SIZE_M": 16,
22+
"BLOCK_SIZE_N": 32,
23+
"BLOCK_SIZE_K": 128,
24+
"GROUP_SIZE_M": 1,
25+
"num_warps": 2,
26+
"num_stages": 2,
27+
"waves_per_eu": 0
28+
},
29+
"8": {
30+
"BLOCK_SIZE_M": 16,
31+
"BLOCK_SIZE_N": 64,
32+
"BLOCK_SIZE_K": 256,
33+
"GROUP_SIZE_M": 1,
34+
"num_warps": 2,
35+
"num_stages": 2,
36+
"waves_per_eu": 0
37+
},
38+
"16": {
39+
"BLOCK_SIZE_M": 16,
40+
"BLOCK_SIZE_N": 64,
41+
"BLOCK_SIZE_K": 256,
42+
"GROUP_SIZE_M": 1,
43+
"num_warps": 2,
44+
"num_stages": 2,
45+
"waves_per_eu": 0
46+
},
47+
"24": {
48+
"BLOCK_SIZE_M": 16,
49+
"BLOCK_SIZE_N": 64,
50+
"BLOCK_SIZE_K": 256,
51+
"GROUP_SIZE_M": 1,
52+
"num_warps": 4,
53+
"num_stages": 2,
54+
"waves_per_eu": 0
55+
},
56+
"32": {
57+
"BLOCK_SIZE_M": 16,
58+
"BLOCK_SIZE_N": 64,
59+
"BLOCK_SIZE_K": 256,
60+
"GROUP_SIZE_M": 4,
61+
"num_warps": 4,
62+
"num_stages": 2,
63+
"waves_per_eu": 0
64+
},
65+
"48": {
66+
"BLOCK_SIZE_M": 16,
67+
"BLOCK_SIZE_N": 64,
68+
"BLOCK_SIZE_K": 256,
69+
"GROUP_SIZE_M": 1,
70+
"num_warps": 4,
71+
"num_stages": 2,
72+
"waves_per_eu": 0
73+
},
74+
"64": {
75+
"BLOCK_SIZE_M": 32,
76+
"BLOCK_SIZE_N": 64,
77+
"BLOCK_SIZE_K": 256,
78+
"GROUP_SIZE_M": 4,
79+
"num_warps": 2,
80+
"num_stages": 2,
81+
"waves_per_eu": 0
82+
},
83+
"96": {
84+
"BLOCK_SIZE_M": 32,
85+
"BLOCK_SIZE_N": 64,
86+
"BLOCK_SIZE_K": 256,
87+
"GROUP_SIZE_M": 1,
88+
"num_warps": 2,
89+
"num_stages": 2,
90+
"waves_per_eu": 0
91+
},
92+
"128": {
93+
"BLOCK_SIZE_M": 64,
94+
"BLOCK_SIZE_N": 64,
95+
"BLOCK_SIZE_K": 256,
96+
"GROUP_SIZE_M": 4,
97+
"num_warps": 4,
98+
"num_stages": 2,
99+
"waves_per_eu": 0
100+
},
101+
"256": {
102+
"BLOCK_SIZE_M": 128,
103+
"BLOCK_SIZE_N": 128,
104+
"BLOCK_SIZE_K": 256,
105+
"GROUP_SIZE_M": 4,
106+
"num_warps": 8,
107+
"num_stages": 2,
108+
"waves_per_eu": 0
109+
},
110+
"512": {
111+
"BLOCK_SIZE_M": 256,
112+
"BLOCK_SIZE_N": 128,
113+
"BLOCK_SIZE_K": 128,
114+
"GROUP_SIZE_M": 4,
115+
"num_warps": 8,
116+
"num_stages": 2,
117+
"waves_per_eu": 0
118+
},
119+
"1024": {
120+
"BLOCK_SIZE_M": 128,
121+
"BLOCK_SIZE_N": 128,
122+
"BLOCK_SIZE_K": 256,
123+
"GROUP_SIZE_M": 1,
124+
"num_warps": 8,
125+
"num_stages": 2,
126+
"waves_per_eu": 0
127+
},
128+
"1536": {
129+
"BLOCK_SIZE_M": 128,
130+
"BLOCK_SIZE_N": 256,
131+
"BLOCK_SIZE_K": 128,
132+
"GROUP_SIZE_M": 1,
133+
"num_warps": 8,
134+
"num_stages": 2,
135+
"waves_per_eu": 0
136+
},
137+
"2048": {
138+
"BLOCK_SIZE_M": 128,
139+
"BLOCK_SIZE_N": 256,
140+
"BLOCK_SIZE_K": 128,
141+
"GROUP_SIZE_M": 1,
142+
"num_warps": 8,
143+
"num_stages": 2,
144+
"waves_per_eu": 0
145+
},
146+
"3072": {
147+
"BLOCK_SIZE_M": 128,
148+
"BLOCK_SIZE_N": 256,
149+
"BLOCK_SIZE_K": 128,
150+
"GROUP_SIZE_M": 1,
151+
"num_warps": 8,
152+
"num_stages": 2,
153+
"waves_per_eu": 0
154+
},
155+
"4096": {
156+
"BLOCK_SIZE_M": 256,
157+
"BLOCK_SIZE_N": 256,
158+
"BLOCK_SIZE_K": 64,
159+
"GROUP_SIZE_M": 1,
160+
"num_warps": 8,
161+
"num_stages": 2,
162+
"waves_per_eu": 0
163+
}
164+
}

0 commit comments

Comments
 (0)