#include #include #include "common.h" void kernel_body(kernel_arg_t* __UNIFORM__ arg) { auto A = reinterpret_cast(arg->A_addr); auto B = reinterpret_cast(arg->B_addr); auto C = reinterpret_cast(arg->C_addr); tensor::fragment fragA; tensor::fragment fragB; tensor::fragment fragC; // calculate tile row & column based on block index uint32_t tile_row = blockIdx.y * arg->tileM; uint32_t tile_col = blockIdx.x * arg->tileN; uint32_t N = arg->N; uint32_t K = arg->K; uint32_t tileK = arg->tileK; // Initialize accumulator tile to zero tensor::fill_fragment(fragC, 0.0f); for (int i = 0; i < K; i += tileK) { // Load A tile auto tileA = A + (tile_row * K + i); tensor::load_matrix_sync(fragA, tileA, K); // Load B tile auto tileB = B + (i * k + tile_col); tensor::load_matrix_sync(fragB, tileB, K); // Matrix multiply-accumulate: c += a * b tensor::mma_sync(fragC, fragA, fragB, fragC); } // Store the computed C tile auto tileC = C + (tile_row * N + tile_col); tensor::store_matrix_sync(tileC, fragC, N, tensor::mem_row_major); } int main() { kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH); return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg); }