Preface
I have recently been planning to write a set of high-performance operators based on GLSL compute shaders, matrix multiplication among them. To stand on the shoulders of giants, in this series of articles I will first take a close look at how the ggml and ncnn frameworks write their GLSL matrix-multiplication kernels, and then try to write my own GEMM operator whose performance gets reasonably close to cutlass 3.x.
ggml
Let's start with ggml, the tensor library underlying llama.cpp. It implements GEMM in three forms: SIMT, coopmat1, and coopmat2.
SIMT
"SIMT" here refers to the traditional implementation that does not use Tensor Cores. The corresponding source lives in mul_mm.comp.
The author of this kernel put the SIMT and Tensor Core implementations in the same file and switches between them with macros, which makes the code rather messy. Add cryptic abbreviations such as ne02 on top of that, and the file becomes genuinely hard to approach. Below I have pulled out just the SIMT portion, stripped the MoE-related code (guarded by the MUL_MAT_ID macro), and then handed it to Claude to annotate the meaning of some key variables, to make it easier to read.
```glsl
#include "types.glsl"

#ifndef LOAD_VEC_A
#define LOAD_VEC_A 1
#endif
#ifndef LOAD_VEC_B
#define LOAD_VEC_B 1
#endif

// With unaligned f32/f16/bf16 data, each shared-memory load call covers two
// elements/vectors along K.
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
#define LOAD_VEC_BATCH_A 2
#else
#define LOAD_VEC_BATCH_A 1
#endif
#if !defined(ALIGNED)
#define LOAD_VEC_BATCH_B 2
#else
#define LOAD_VEC_BATCH_B 1
#endif

#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
#endif

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

layout (push_constant) uniform parameter
{
    uint M;              // M dimension of the output
    uint N;              // N dimension of the output
    uint K;              // reduction (K) dimension
    uint stride_a;       // stride between consecutive rows (M index) of A, in elements
    uint stride_b;       // stride between consecutive rows (N index) of B, in elements
    uint stride_d;       // stride between consecutive columns (N index) of D, in elements

    uint batch_stride_a; // elements between consecutive batches of A
    uint batch_stride_b; // elements between consecutive batches of B
    uint batch_stride_d; // elements between consecutive batches of D

    uint k_split;        // K-chunk size when the reduction is split across workgroups (split-K)
    uint ne02;           // batch dimension (dim 2) of A
    uint ne12;           // batch dimension (dim 2) of B
    uint broadcast2;     // broadcast ratio of dim 2 (ne12 / ne02)
    uint broadcast3;     // broadcast ratio of dim 3 (ne13 / ne03)
} p;

// Tiling parameters, supplied by the host as specialization constants.
layout (constant_id = 0) const uint BLOCK_SIZE = 64;  // threads per workgroup
layout (constant_id = 1) const uint BM = 64;          // block tile rows
layout (constant_id = 2) const uint BN = 64;          // block tile columns
layout (constant_id = 3) const uint BK = 16;          // block tile depth (K)
layout (constant_id = 4) const uint WM = 32;          // warp tile rows
layout (constant_id = 5) const uint WN = 32;          // warp tile columns
layout (constant_id = 6) const uint WMITER = 2;       // warp sub-tile iterations along M
layout (constant_id = 7) const uint TM = 4;           // thread tile rows
layout (constant_id = 8) const uint TN = 2;           // thread tile columns
layout (constant_id = 9) const uint TK = 1;           // thread tile depth (not used in this excerpt)
layout (constant_id = 10) const uint WARP = 32;       // warp / subgroup size

// Shared-memory row stride in vec2 units; the +1 padding avoids bank conflicts.
#define SHMEM_STRIDE (BK / 2 + 1)

// Shared-memory tiles; each entry packs two consecutive K values into one vec2.
shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];

#define NUM_WARPS (BLOCK_SIZE / WARP)

#include "mul_mm_funcs.glsl" // provides load_a_to_shmem / load_b_to_shmem

void main() {
    const uint batch_idx = gl_GlobalInvocationID.z;

    // Decompose the flat batch index into dim-3 / dim-2 indices of B ...
    const uint i13 = batch_idx / p.ne12;
    const uint i12 = batch_idx % p.ne12;

    // ... and map them onto A's (possibly broadcast) batch indices.
    const uint i03 = i13 / p.broadcast3;
    const uint i02 = i12 / p.broadcast2;

    const uint batch_idx_a = i03 * p.ne02 + i02;

    // ir/ic: block-tile row/column; ik: split-K chunk handled by this workgroup.
    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;
    const uint ik = gl_WorkGroupID.x / blocks_m;
    const uint ic = gl_WorkGroupID.y;

    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); // warp sub-tile iterations along N
    const uint WSUBM = WM / WMITER;                            // warp sub-tile rows
    const uint WSUBN = WN / WNITER;                            // warp sub-tile columns

    const uint warp_i = gl_LocalInvocationID.x / WARP; // warp index within the workgroup
    const uint tiw = gl_LocalInvocationID.x % WARP;    // thread index within the warp

    const uint tiwr = tiw % (WSUBM / TM); // thread row inside a warp sub-tile
    const uint tiwc = tiw / (WSUBM / TM); // thread column inside a warp sub-tile

    const uint warp_r = warp_i % (BM / WM); // warp row within the block tile
    const uint warp_c = warp_i / (BM / WM); // warp column within the block tile

    // Per-thread coordinates for cooperatively staging the A/B tiles into shared memory.
    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);

    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;

    // Split-K range handled by this workgroup.
    const uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);

    // Base offsets into A/B (in LOAD_VEC units) for this block tile and K chunk.
    uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
    uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;

    ACC_TYPE sums[WMITER * TM * WNITER * TN]; // per-thread accumulator registers
    FLOAT_TYPE_VEC2 cache_a[WMITER * TM];     // A fragment held in registers
    FLOAT_TYPE_VEC2 cache_b[TN];              // B fragment held in registers

    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = ACC_TYPE(0.0f);
    }

    for (uint block = start_k; block < end_k; block += BK) {
        // Cooperatively load the BM x BK tile of A and the BN x BK tile of B into shared memory.
        [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
            load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k);
        }
        [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k);
        }

        barrier();

        pos_a += BK / LOAD_VEC_A;
        pos_b += BK / LOAD_VEC_B;

        // Walk the K dimension two elements at a time (one FLOAT_TYPE_VEC2 per step).
        [[unroll]] for (uint i = 0; i < BK / 2; i++) {
            // Load this thread's A fragment from shared memory into registers.
            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (uint j = 0; j < TM; j++) {
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
                }
            }
            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                // Load this thread's B fragment ...
                [[unroll]] for (uint j = 0; j < TN; j++) {
                    cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
                }

                // ... and accumulate the TM x TN outer product.
                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
                            sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x),
                                             fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
                        }
                    }
                }
            }
        }

        barrier();
    }

    // Write back. Each split-K chunk writes to its own slice of D,
    // which is combined afterwards by a separate reduction pass.
    const uint dr = ir * BM + warp_r * WM;
    const uint dc = ic * BN + warp_c * WN;

    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;

    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] =
                            D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
                    }
                }
            }
        }
    }
}
```
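The helpers load_a_to_shmem and load_b_to_shmem live in mul_mm_funcs.glsl and are not shown above; they handle all of the quantized formats as well. As a mental model only, and not ggml's actual code, here is my own sketch of what the A-side load boils down to for the plain unaligned F32 path (assuming LOAD_VEC_A == 1 and LOAD_VEC_BATCH_A == 2): each call reads two consecutive K elements of one row, zero-fills anything past end_k or p.M, and packs them into one FLOAT_TYPE_VEC2 slot of buf_a.

```glsl
// Hypothetical sketch of the A-side shared-memory load (not the real mul_mm_funcs.glsl),
// assuming LOAD_VEC_A == 1 and LOAD_VEC_BATCH_A == 2, i.e. plain unaligned f32 input.
void load_a_to_shmem_sketch(const uint pos_a, const uint row, const uint col,
                            const uint idx_m, const uint block, const uint end_k) {
    // pos_a already points at (batch, ir * BM, current K block), so only the row
    // offset within the tile and the two K elements need to be added on top.
    const uint k0   = block + row * 2u;                // absolute K index of the first element
    const uint base = pos_a + col * p.stride_a + row * 2u;

    const bool in_m = idx_m < p.M;
    const FLOAT_TYPE v0 = (in_m && k0      < end_k) ? FLOAT_TYPE(data_a[base])      : FLOAT_TYPE(0.0f);
    const FLOAT_TYPE v1 = (in_m && k0 + 1u < end_k) ? FLOAT_TYPE(data_a[base + 1u]) : FLOAT_TYPE(0.0f);

    // buf_a is laid out as BM rows of SHMEM_STRIDE vec2 slots;
    // slot `row` of a shared-memory row holds the K pair (2*row, 2*row+1).
    buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(v0, v1);
}
```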
The overall approach is very close to the warp-tiling optimization demonstrated in Simon's blog post: both perform three-level Block-Warp-Thread tiling, as shown in the figure below.

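To make the hierarchy concrete, here is one worked set of numbers. These values are only an illustration chosen to be self-consistent; the real values are specialization constants pushed in by the host, and the in-shader defaults above are merely placeholders.

```glsl
// Illustrative (hypothetical) specialization constants:
//   BLOCK_SIZE = 128, WARP = 32   -> NUM_WARPS = 4 warps per workgroup
//   BM = BN = 64, BK = 16         -> one workgroup computes a 64x64 output tile
//   WM = WN = 32                  -> each warp owns a 32x32 sub-tile; 2x2 warps cover 64x64
//   WMITER = 2, TM = 4, TN = 2
//
// Values derived inside main():
//   WNITER = (WM * WN) / (WARP * TM * TN * WMITER) = 1024 / 512 = 2
//   WSUBM  = WM / WMITER = 16
//   WSUBN  = WN / WNITER = 16
//
// Per-thread work: WMITER * TM = 8 rows x WNITER * TN = 4 columns = 32 accumulators,
// so one warp covers 32 threads x 32 accumulators = 1024 outputs = its 32x32 warp tile.
```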
coopmat1
TODO
When it comes to matrix multiplication, Tensor Cores are unavoidable. In fact, of all the acceleration features NVIDIA has built, Tensor Cores are just about the only one Vulkan can use. Other features such as TMA, warp specialization, thread block clusters, and asynchronous loads/stores that bypass the L1/L2 caches are not supported by today's Vulkan. Vulkan accesses Tensor Cores through two extensions: VK_KHR_cooperative_matrix (hereafter coopmat1) and VK_NV_cooperative_matrix2 (hereafter coopmat2).
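Before diving into ggml's version, here is a minimal sketch of what the coopmat1 programming model looks like on the GLSL side. This is my own illustration, not ggml's shader: a subgroup-scoped coopmat object is distributed across the threads of a subgroup, and coopMatLoad / coopMatMulAdd / coopMatStore operate on the whole tile at once, which is what ultimately maps onto Tensor Core MMA instructions. The 16x16x16 tile shape and the one-tile-per-workgroup layout are assumptions chosen for brevity.

```glsl
#version 450
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable

// Hypothetical minimal kernel: one subgroup computes one 16x16 tile of D = A * B.
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) readonly  buffer A { float16_t data_a[]; }; // M x K, row-major
layout(binding = 1) readonly  buffer B { float16_t data_b[]; }; // K x N, row-major
layout(binding = 2) writeonly buffer D { float     data_d[]; }; // M x N, row-major

layout(push_constant) uniform parameter { uint M; uint N; uint K; } p;

void main() {
    const uint tile_m = gl_WorkGroupID.x * 16;
    const uint tile_n = gl_WorkGroupID.y * 16;

    // A cooperative matrix is collectively owned by all invocations of the subgroup.
    coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> mat_a;
    coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> mat_b;
    coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> acc =
        coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);

    for (uint k = 0; k < p.K; k += 16) {
        // The load and the multiply-add are cooperative: one call per subgroup,
        // which the driver lowers to Tensor Core MMA instructions where available.
        coopMatLoad(mat_a, data_a, tile_m * p.K + k, p.K, gl_CooperativeMatrixLayoutRowMajor);
        coopMatLoad(mat_b, data_b, k * p.N + tile_n, p.N, gl_CooperativeMatrixLayoutRowMajor);
        acc = coopMatMulAdd(mat_a, mat_b, acc);
    }

    coopMatStore(acc, data_d, tile_m * p.N + tile_n, p.N, gl_CooperativeMatrixLayoutRowMajor);
}
```

Roughly speaking, ggml's coopmat1 path in mul_mm.comp keeps the same BM/BN/BK block tiling and shared-memory staging as the SIMT code above, and the cooperative matrices take over the role of the cache_a/cache_b register fragments and the innermost FMA loops.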
In this subsection we will read through the GEMM written against the coopmat1 extension. As in the previous subsection, we first expand the macros in the source; otherwise it is simply too hard to read.