Preface
I've recently been planning to write a set of high-performance operators based on GLSL compute shaders, matrix multiplication included. To stand on the shoulders of giants, in this series of articles I will first take a close look at how the ggml and ncnn frameworks write their GLSL matrix multiplications, and then try to write a matmul operator of my own whose performance gets reasonably close to cutlass 3.x.
ggml
Let's start with ggml, the tensor library underlying llama.cpp. It implements GEMM in three forms: SIMT, coopmat1, and coopmat2.
SIMT
SIMT here refers to the traditional implementation that does not use Tensor Cores. The corresponding source lives in mul_mm.comp.
The author put the SIMT and Tensor Core implementations into the same file and switches between them with macros, which makes for a very messy read; combined with head-scratching abbreviations like ne02, the code's defense stat is well and truly maxed out. Below I have pulled the SIMT part out on its own, removed the MoE- and batch-related code (the MoE part is guarded by the MUL_MAT_ID macro), asked Claude to annotate the meaning of some key variables, and finished with a manual editing pass.
To make the code easier to read, here is a quick note on ggml's naming convention, using ne13 as an example:

- ne stands for "number of elements", so ne13 counts elements, not bytes.
- The first digit, 1, refers to tensor number 1, i.e. matrix B (tensor 0 is matrix A; everything in ggml is counted from 0).
- The second digit, 3, refers to the size of dimension 3, the outermost dimension N in [N, C, H, W].

So ne13 means the batch count of matrix B (tensor 1, dimension 3, i.e. the size of the N dimension). Likewise, for the row-major matrix A, ne00 is the number of columns of A (tensor 0, dimension 0, the size of the W dimension).
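To make the convention concrete, here is a small sketch (the flat_index helper below is hypothetical and not part of ggml): for a contiguous, row-major 4-D tensor, dimension 0 is the fastest-varying one, which is exactly why ne00 is the column count of the row-major A.

```glsl
#version 450
layout(local_size_x = 1) in;

// Illustration of the ne naming convention only; flat_index() is a hypothetical helper,
// not ggml code. Tensor T (0 = A, 1 = B), stored contiguously and row-major with shape
// [neT3][neT2][neT1][neT0], keeps element (i3, i2, i1, i0) at:
uint flat_index(uint i0, uint i1, uint i2, uint i3,
                uint ne0, uint ne1, uint ne2) {
    return ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0;
}
// e.g. for matrix B (tensor 1): ne10 = row length, ne11 = number of rows,
// ne12 = number of channels, ne13 = number of batches.

void main() {}
```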
```glsl
#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

layout(local_size_x_id = 0, local_size_y = 1) in;

layout (binding = 0) readonly buffer A {f16vec4 data_a[];};
layout (binding = 1) readonly buffer B {f16vec4 data_b[];};
layout (binding = 2) writeonly buffer D {float data_d[];};

layout (push_constant) uniform parameter {
    uint M;
    uint N;
    uint K;
    uint stride_a;
    uint stride_b;
    uint stride_d;
    uint k_split;
} p;

layout (constant_id = 0) const uint BLOCK_SIZE = 64;  // threads per workgroup
layout (constant_id = 1) const uint BM = 64;          // block tile height (along M)
layout (constant_id = 2) const uint BN = 64;          // block tile width (along N)
layout (constant_id = 3) const uint BK = 16;          // K slice consumed per main-loop iteration
layout (constant_id = 4) const uint WM = 32;          // warp tile height
layout (constant_id = 5) const uint WN = 32;          // warp tile width
layout (constant_id = 6) const uint WMITER = 2;       // warp sub-tile iterations along M
layout (constant_id = 7) const uint TM = 4;           // thread tile height
layout (constant_id = 8) const uint TN = 2;           // thread tile width
layout (constant_id = 9) const uint TK = 1;
layout (constant_id = 10) const uint WARP = 32;       // threads per warp

// BK/2 f16vec2 per shared-memory row, +1 to reduce bank conflicts
#define SHMEM_STRIDE (BK / 2 + 1)

shared f16vec2 buf_a[BM * SHMEM_STRIDE];
shared f16vec2 buf_b[BN * SHMEM_STRIDE];

#define NUM_WARPS (BLOCK_SIZE / WARP)

// idx_m / block / end_k are unused here: the bounds checks they fed were stripped along with the macros.
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
    const uint idx = pos_a + col * p.stride_a / 4 + row;    // f16vec4 index into data_a
    const uint buf_idx = col * SHMEM_STRIDE + row * 4 / 2;  // f16vec2 index into buf_a
    f16vec4 aa = data_a[idx];
    buf_a[buf_idx]     = aa.xy;
    buf_a[buf_idx + 1] = aa.zw;
}

void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
    const uint idx = pos_b + col * p.stride_b / 4 + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * 4 / 2;
    f16vec4 bb = data_b[idx];
    buf_b[buf_idx + 0] = bb.xy;
    buf_b[buf_idx + 1] = bb.zw;
}

void main() {
    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;  // block-tile row index along M
    const uint ik = gl_WorkGroupID.x / blocks_m;  // which K split this workgroup handles
    const uint ic = gl_WorkGroupID.y;             // block-tile column index along N

    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);  // warp sub-tile iterations along N
    const uint WSUBM = WM / WMITER;               // warp sub-tile height
    const uint WSUBN = WN / WNITER;               // warp sub-tile width

    const uint warp_i = gl_LocalInvocationID.x / WARP;  // warp index within the workgroup
    const uint tiw = gl_LocalInvocationID.x % WARP;     // lane index within the warp

    const uint tiwr = tiw % (WSUBM / TM);   // thread tile row within the warp sub-tile
    const uint tiwc = tiw / (WSUBM / TM);   // thread tile column within the warp sub-tile

    const uint warp_r = warp_i % (BM / WM); // warp tile row within the block tile
    const uint warp_c = warp_i / (BM / WM); // warp tile column within the block tile

    // Cooperative-load coordinates: each thread fetches one f16vec4 (4 halves) per pass
    const uint loadr_a = gl_LocalInvocationID.x % (BK / 4);
    const uint loadc_a = gl_LocalInvocationID.x / (BK / 4);
    const uint loadr_b = gl_LocalInvocationID.x % (BK / 4);
    const uint loadc_b = gl_LocalInvocationID.x / (BK / 4);

    // Rows of the A/B tile covered by one cooperative-load pass
    const uint loadstride_a = gl_WorkGroupSize.x * 4 / BK;
    const uint loadstride_b = gl_WorkGroupSize.x * 4 / BK;

    // K range assigned to this workgroup (split-K)
    const uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);

    // Base offsets into data_a / data_b, in f16vec4 units
    uint pos_a = (ir * BM * p.stride_a + start_k) / 4;
    uint pos_b = (ic * BN * p.stride_b + start_k) / 4;

    float sums[WMITER * TM * WNITER * TN];  // per-thread accumulators
    f16vec2 cache_a[WMITER * TM];           // register cache for the A fragment
    f16vec2 cache_b[TN];                    // register cache for the B fragment

    [[unroll]] for (uint i = 0; i < WMITER * TM * WNITER * TN; i++) {
        sums[i] = float(0.0f);
    }

    for (uint block = start_k; block < end_k; block += BK) {
        // Cooperatively stage a BM x BK tile of A and a BN x BK tile of B in shared memory
        [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
            load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k);
        }
        [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k);
        }

        barrier();

        pos_a += BK / 4;
        pos_b += BK / 4;

        // Walk the BK slice two halves at a time (each f16vec2 holds two K elements)
        [[unroll]] for (uint i = 0; i < BK / 2; i++) {
            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (uint j = 0; j < TM; j++) {
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
                }
            }
            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (uint j = 0; j < TN; j++) {
                    cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
                }

                // Outer product of the cached A and B fragments into the accumulators
                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
                            sums[sums_idx] = fma(float(cache_a[wsir * TM + cr].x), float(cache_b[cc].x),
                                             fma(float(cache_a[wsir * TM + cr].y), float(cache_b[cc].y), sums[sums_idx]));
                        }
                    }
                }
            }
        }

        barrier();
    }

    // Bounds-checked write-back of the per-thread accumulators
    const uint dr = ir * BM + warp_r * WM;
    const uint dc = ic * BN + warp_c * WN;

    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] =
                            float(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
                    }
                }
            }
        }
    }
}
```
The overall approach is very similar to the warp-tiling optimization demonstrated in Simon's blog post: a three-level Block-Warp-Thread tiling, as shown in the figure below.

I have also reproduced this ggml version in my demo repository.
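To get a concrete feel for the three tiling levels, here is one self-consistent choice of the specialization constants. It is illustrative only; ggml picks the actual values per device when creating the pipeline, and the defaults in the listing above are placeholders too.

```glsl
#version 450
// Worked example of the Block / Warp / Thread hierarchy, using one hypothetical but
// self-consistent set of specialization constants for the SIMT shader above.
layout(local_size_x = 128) in;               // BLOCK_SIZE = 128 -> 4 warps of 32 threads

const uint WARP = 32;
const uint BM = 64, BN = 64, BK = 16;        // block tile: one 64x64 tile of D, K consumed 16 at a time
const uint WM = 32, WN = 32;                 // warp tile: (BM/WM)*(BN/WN) = 2*2 = 4 warp tiles per block
const uint TM = 4, TN = 2, WMITER = 2;       // thread tile parameters
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);  // = 2
const uint WSUBM  = WM / WMITER;             // = 16
const uint WSUBN  = WN / WNITER;             // = 16

void main() {
    // Threads per warp tile: (WSUBM / TM) * (WSUBN / TN) = 4 * 8 = 32 = WARP.
    // Each thread accumulates a (WMITER*TM) x (WNITER*TN) = 8 x 4 patch of D,
    // so one warp covers 32 threads * 32 outputs = a full 32x32 warp tile,
    // and the 4 warps together cover the 64x64 block tile.
}
```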
coopmat1
TODO
No discussion of matrix multiplication can avoid Tensor Cores. In fact, among the many acceleration features NVIDIA has introduced, Tensor Cores are just about the only one Vulkan supports: others such as TMA, warp specialization, thread block clusters, and asynchronous loads/stores that bypass the L1/L2 caches are all out of reach for current Vulkan. Vulkan exposes Tensor Cores through two extensions, VK_KHR_cooperative_matrix (hereafter coopmat1) and VK_NV_cooperative_matrix2 (hereafter coopmat2).
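Before digging into ggml's version, here is a minimal sketch of the coopmat1 programming model, just to show the shape of the API. The tile size, buffer layout and dispatch scheme are my own illustrative choices (one subgroup computing only the top-left 16x16 tile of D), not what ggml does:

```glsl
#version 450
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

// Illustrative sketch: one subgroup computes the top-left 16x16 tile of D = A * B.
layout(local_size_x = 32) in;

layout(binding = 0) readonly buffer A { float16_t data_a[]; };  // assumed row-major, row stride = stride_a
layout(binding = 1) readonly buffer B { float16_t data_b[]; };  // assumed column-major, column stride = stride_b
layout(binding = 2) writeonly buffer D { float data_d[]; };

layout(push_constant) uniform parameter { uint K; uint stride_a; uint stride_b; uint stride_d; } p;

void main() {
    // The accumulator tile is owned cooperatively by the whole subgroup.
    coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> acc =
        coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);

    for (uint k = 0; k < p.K; k += 16) {
        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> mat_a;
        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> mat_b;

        // Load one 16x16 slice of A and of B straight from the buffers.
        coopMatLoad(mat_a, data_a, k, p.stride_a, gl_CooperativeMatrixLayoutRowMajor);
        coopMatLoad(mat_b, data_b, k, p.stride_b, gl_CooperativeMatrixLayoutColumnMajor);

        // Per-tile matrix multiply-accumulate; this is what maps onto Tensor Cores.
        acc = coopMatMulAdd(mat_a, mat_b, acc);
    }

    coopMatStore(acc, data_d, 0, p.stride_d, gl_CooperativeMatrixLayoutRowMajor);
}
```

The contrast with the SIMT version is already visible in this toy sketch: the explicit shared-memory staging and register-level outer product are replaced by coopMatLoad and a single coopMatMulAdd per tile, and the driver decides how the tile is distributed across the subgroup's threads.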
In this subsection we will read through the GEMM written against the coopmat1 extension. As in the previous subsection, we first expand the macros in the source; otherwise it is simply too hard to read.