tensor api + test

This commit is contained in:
tinebp 2025-05-18 16:15:46 -07:00
parent df68e4359a
commit dd500fb5af
9 changed files with 656 additions and 30 deletions

View file

@ -0,0 +1,46 @@
#include <vx_spawn.h>
#include <vx_tensor.h>
#include "common.h"
void kernel_body(kernel_arg_t* __UNIFORM__ arg) {
auto A = reinterpret_cast<I_TYPE*>(arg->A_addr);
auto B = reinterpret_cast<I_TYPE*>(arg->B_addr);
auto C = reinterpret_cast<O_TYPE*>(arg->C_addr);
tensor::fragment<tensor::half_t, tensor::row_major> fragA;
tensor::fragment<tensor::half_t, tensor::row_major> fragB;
tensor::fragment<float, tensor::row_major> fragC;
// calculate tile row & column based on block index
uint32_t tile_row = blockIdx.y * arg->tileM;
uint32_t tile_col = blockIdx.x * arg->tileN;
uint32_t N = arg->N;
uint32_t K = arg->K;
uint32_t tileK = arg->tileK;
// Initialize accumulator tile to zero
tensor::fill_fragment(fragC, 0.0f);
for (int i = 0; i < K; i += tileK) {
// Load A tile
auto tileA = A + (tile_row * K + i);
tensor::load_matrix_sync(fragA, tileA, K);
// Load B tile
auto tileB = B + (i * k + tile_col);
tensor::load_matrix_sync(fragB, tileB, K);
// Matrix multiply-accumulate: c += a * b
tensor::mma_sync(fragC, fragA, fragB, fragC);
}
// Store the computed C tile
auto tileC = C + (tile_row * N + tile_col);
tensor::store_matrix_sync(tileC, fragC, N, tensor::mem_row_major);
}
int main() {
kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH);
return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg);
}