vx_spawn_threads implementation

Blaise Tine 2024-06-07 07:52:15 -07:00
parent 8c5a783477
commit 96cb381885
5 changed files with 208 additions and 152 deletions

View file

@@ -14,28 +14,54 @@
 #ifndef __VX_SPAWN_H__
 #define __VX_SPAWN_H__
 
 #include <VX_types.h>
 #include <vx_intrinsics.h>
+#include <stdint.h>
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-typedef void (*vx_spawn_tasks_cb)(int task_id, const void *arg);
-typedef void (*vx_spawn_task_groups_cb)(int local_task_id, int group_id, int local_group_id, int warps_per_group, const void *arg);
+typedef void (*vx_spawn_tasks_cb)(uint32_t task_id, const void *arg);
 typedef void (*vx_serial_cb)(const void *arg);
 
-void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, const void * arg);
+void vx_spawn_tasks(uint32_t num_tasks, vx_spawn_tasks_cb callback, const void * arg);
 
-void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_cb callback, const void * arg);
-
-void vx_serial(vx_serial_cb callback, void * arg);
-
-inline void* vx_local_malloc(int local_group_id, int size) {
-  return (int8_t*)csr_read(VX_CSR_LOCAL_MEM_BASE) + local_group_id * size;
-}
+void vx_serial(vx_serial_cb callback, const void * arg);
 
 ///////////////////////////////////////////////////////////////////////////////
 
+typedef union {
+  struct {
+    uint32_t x;
+    uint32_t y;
+    uint32_t z;
+  };
+  uint32_t m[3];
+} dim3_t;
+
+extern __thread dim3_t blockIdx;
+extern __thread dim3_t threadIdx;
+extern dim3_t gridDim;
+extern dim3_t blockDim;
+
+extern __thread uint32_t __local_group_id;
+extern uint32_t __groups_per_core;
+extern uint32_t __warps_per_group;
+
+typedef void (*vx_kernel_func_cb)(const void *arg);
+
+#define __local_mem(size) \
+  (void*)((int8_t*)csr_read(VX_CSR_LOCAL_MEM_BASE) + __local_group_id * size)
+
+#define __syncthreads() \
+  vx_barrier(__COUNTER__ * __groups_per_core + __local_group_id, __warps_per_group)
+
+int vx_spawn_threads(uint32_t dimension,
+                     const uint32_t* grid_dim,
+                     const uint32_t* block_dim,
+                     vx_kernel_func_cb kernel_func,
+                     const void* arg);
+
 #ifdef __cplusplus
 }
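Note: vx_spawn_threads replaces vx_spawn_task_groups. Instead of passing local_task_id / group_id / local_group_id / warps_per_group to the callback, the runtime now publishes the launch shape through the CUDA-style gridDim/blockDim globals and the per-thread blockIdx/threadIdx variables, and the kernel receives only its argument pointer. A minimal 1-D usage sketch (the kernel and buffer names are hypothetical; vx_spawn_threads, the dim3_t builtins, and the MSCRATCH argument convention are from this commit):

#include <vx_spawn.h>

// hypothetical kernel: each thread stores its global linear index
void fill_ids(uint32_t* out) {
  uint32_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  out[gid] = gid;
}

int main() {
  uint32_t grid_dim[1]  = {16};  // 16 thread groups
  uint32_t block_dim[1] = {64};  // 64 threads per group
  auto out = (uint32_t*)csr_read(VX_CSR_MSCRATCH);  // kernel argument, as in the tests
  return vx_spawn_threads(1, grid_dim, block_dim, (vx_kernel_func_cb)fill_ids, out);
}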

View file

@@ -26,29 +26,29 @@ extern "C" {
 typedef struct {
   vx_spawn_tasks_cb callback;
-  void* arg;
-  int all_tasks_offset;
-  int remain_tasks_offset;
-  int warp_batches;
-  int remaining_warps;
+  const void* arg;
+  uint32_t all_tasks_offset;
+  uint32_t remain_tasks_offset;
+  uint32_t warp_batches;
+  uint32_t remaining_warps;
 } wspawn_tasks_args_t;
 
 static void __attribute__ ((noinline)) process_all_tasks() {
   wspawn_tasks_args_t* targs = (wspawn_tasks_args_t*)csr_read(VX_CSR_MSCRATCH);
 
-  int threads_per_warp = vx_num_threads();
-  int warp_id = vx_warp_id();
-  int thread_id = vx_thread_id();
+  uint32_t threads_per_warp = vx_num_threads();
+  uint32_t warp_id = vx_warp_id();
+  uint32_t thread_id = vx_thread_id();
 
-  int start_warp = (warp_id * targs->warp_batches) + MIN(warp_id, targs->remaining_warps);
-  int iterations = targs->warp_batches + (warp_id < targs->remaining_warps);
+  uint32_t start_warp = (warp_id * targs->warp_batches) + MIN(warp_id, targs->remaining_warps);
+  uint32_t iterations = targs->warp_batches + (warp_id < targs->remaining_warps);
 
-  int start_task_id = targs->all_tasks_offset + (start_warp * threads_per_warp) + thread_id;
-  int end_task_id = start_task_id + iterations * threads_per_warp;
+  uint32_t start_task_id = targs->all_tasks_offset + (start_warp * threads_per_warp) + thread_id;
+  uint32_t end_task_id = start_task_id + iterations * threads_per_warp;
 
   vx_spawn_tasks_cb callback = targs->callback;
-  void* arg = targs->arg;
-  for (int task_id = start_task_id; task_id < end_task_id; task_id += threads_per_warp) {
+  const void* arg = targs->arg;
+  for (uint32_t task_id = start_task_id; task_id < end_task_id; task_id += threads_per_warp) {
     callback(task_id, arg);
   }
 }
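Worked example of the batching above: with threads_per_warp = 4, warp_batches = 2, and remaining_warps = 1, warp 0 gets start_warp = 0 and 3 iterations (task batches 0-2) while warp 1 gets start_warp = 3 and 2 iterations (batches 3-4), so the 2*2+1 = 5 warp-sized batches are each covered exactly once.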
@@ -56,8 +56,8 @@ static void __attribute__ ((noinline)) process_all_tasks() {
 static void __attribute__ ((noinline)) process_remaining_tasks() {
   wspawn_tasks_args_t* targs = (wspawn_tasks_args_t*)csr_read(VX_CSR_MSCRATCH);
 
-  int thread_id = vx_thread_id();
-  int task_id = targs->remain_tasks_offset + thread_id;
+  uint32_t thread_id = vx_thread_id();
+  uint32_t task_id = targs->remain_tasks_offset + thread_id;
 
   (targs->callback)(task_id, targs->arg);
 }
@@ -73,33 +73,33 @@ static void __attribute__ ((noinline)) process_all_tasks_stub() {
   vx_tmc_zero();
 }
 
-void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
+void vx_spawn_tasks(uint32_t num_tasks, vx_spawn_tasks_cb callback , const void * arg) {
   // device specifications
-  int num_cores = vx_num_cores();
-  int warps_per_core = vx_num_warps();
-  int threads_per_warp = vx_num_threads();
-  int core_id = vx_core_id();
+  uint32_t num_cores = vx_num_cores();
+  uint32_t warps_per_core = vx_num_warps();
+  uint32_t threads_per_warp = vx_num_threads();
+  uint32_t core_id = vx_core_id();
 
   // calculate necessary active cores
-  int threads_per_core = warps_per_core * threads_per_warp;
-  int needed_cores = (num_tasks + threads_per_core - 1) / threads_per_core;
-  int active_cores = MIN(needed_cores, num_cores);
+  uint32_t threads_per_core = warps_per_core * threads_per_warp;
+  uint32_t needed_cores = (num_tasks + threads_per_core - 1) / threads_per_core;
+  uint32_t active_cores = MIN(needed_cores, num_cores);
 
   // only active cores participate
   if (core_id >= active_cores)
     return;
 
   // number of tasks per core
-  int tasks_per_core = num_tasks / active_cores;
-  int remaining_tasks_per_core = num_tasks - tasks_per_core * active_cores;
+  uint32_t tasks_per_core = num_tasks / active_cores;
+  uint32_t remaining_tasks_per_core = num_tasks - tasks_per_core * active_cores;
   if (core_id < remaining_tasks_per_core)
     tasks_per_core++;
 
   // calculate number of warps to activate
-  int total_warps_per_core = tasks_per_core / threads_per_warp;
-  int remaining_tasks = tasks_per_core - total_warps_per_core * threads_per_warp;
-  int active_warps = total_warps_per_core;
-  int warp_batches = 1, remaining_warps = 0;
+  uint32_t total_warps_per_core = tasks_per_core / threads_per_warp;
+  uint32_t remaining_tasks = tasks_per_core - total_warps_per_core * threads_per_warp;
+  uint32_t active_warps = total_warps_per_core;
+  uint32_t warp_batches = 1, remaining_warps = 0;
   if (active_warps > warps_per_core) {
     active_warps = warps_per_core;
     warp_batches = total_warps_per_core / active_warps;
@@ -107,8 +107,8 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
   }
 
   // calculate offsets for task distribution
-  int all_tasks_offset = core_id * tasks_per_core + MIN(core_id, remaining_tasks_per_core);
-  int remain_tasks_offset = all_tasks_offset + (tasks_per_core - remaining_tasks);
+  uint32_t all_tasks_offset = core_id * tasks_per_core + MIN(core_id, remaining_tasks_per_core);
+  uint32_t remain_tasks_offset = all_tasks_offset + (tasks_per_core - remaining_tasks);
 
   // prepare scheduler arguments
   wspawn_tasks_args_t wspawn_args = {
@@ -137,7 +137,7 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
   if (remaining_tasks != 0) {
     // activate remaining threads
-    int tmask = (1 << remaining_tasks) - 1;
+    uint32_t tmask = (1 << remaining_tasks) - 1;
     vx_tmc(tmask);
 
     // process remaining tasks
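For example, remaining_tasks = 3 yields tmask = 0b111, enabling only threads 0-2 of the warp for the tail batch.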
@@ -153,51 +153,69 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
 ///////////////////////////////////////////////////////////////////////////////
 
+__thread dim3_t blockIdx;
+__thread dim3_t threadIdx;
+dim3_t gridDim;
+dim3_t blockDim;
+
+__thread uint32_t __local_group_id;
+uint32_t __groups_per_core;
+uint32_t __warps_per_group;
+
 typedef struct {
-  vx_spawn_task_groups_cb callback;
-  void* arg;
-  int group_offset;
-  int warp_batches;
-  int remaining_warps;
-  int warps_per_group;
-  int groups_per_core;
-  int remaining_mask;
+  vx_kernel_func_cb callback;
+  const void* arg;
+  uint32_t group_offset;
+  uint32_t warp_batches;
+  uint32_t remaining_warps;
+  uint32_t warps_per_group;
+  uint32_t groups_per_core;
+  uint32_t remaining_mask;
 } wspawn_task_groups_args_t;
 
 static void __attribute__ ((noinline)) process_all_task_groups() {
   wspawn_task_groups_args_t* targs = (wspawn_task_groups_args_t*)csr_read(VX_CSR_MSCRATCH);
 
-  int warps_per_group = targs->warps_per_group;
-  int groups_per_core = targs->groups_per_core;
-
-  int threads_per_warp = vx_num_threads();
-  int warp_id = vx_warp_id();
-  int thread_id = vx_thread_id();
+  uint32_t threads_per_warp = vx_num_threads();
+  uint32_t warp_id = vx_warp_id();
+  uint32_t thread_id = vx_thread_id();
 
-  int iterations = targs->warp_batches + (warp_id < targs->remaining_warps);
+  uint32_t warps_per_group = targs->warps_per_group;
+  uint32_t groups_per_core = targs->groups_per_core;
 
-  int local_group_id = warp_id / warps_per_group;
-  int group_warp_id = warp_id - local_group_id * warps_per_group;
-  int local_task_id = group_warp_id * threads_per_warp + thread_id;
+  uint32_t iterations = targs->warp_batches + (warp_id < targs->remaining_warps);
 
-  int start_group = targs->group_offset + local_group_id;
-  int end_group = start_group + iterations * groups_per_core;
+  uint32_t local_group_id = warp_id / warps_per_group;
+  uint32_t group_warp_id = warp_id - local_group_id * warps_per_group;
+  uint32_t local_task_id = group_warp_id * threads_per_warp + thread_id;
 
-  vx_spawn_task_groups_cb callback = targs->callback;
-  void* arg = targs->arg;
+  uint32_t start_group = targs->group_offset + local_group_id;
+  uint32_t end_group = start_group + iterations * groups_per_core;
 
-  for (int group_id = start_group; group_id < end_group; group_id += groups_per_core) {
-    callback(local_task_id, group_id, start_group, warps_per_group, arg);
+  __local_group_id = local_group_id;
+
+  threadIdx.x = local_task_id % blockDim.x;
+  threadIdx.y = (local_task_id / blockDim.x) % blockDim.y;
+  threadIdx.z = local_task_id / (blockDim.x * blockDim.y);
+
+  vx_kernel_func_cb callback = targs->callback;
+  const void* arg = targs->arg;
+
+  for (uint32_t group_id = start_group; group_id < end_group; group_id += groups_per_core) {
+    blockIdx.x = group_id % gridDim.x;
+    blockIdx.y = (group_id / gridDim.x) % gridDim.y;
+    blockIdx.z = group_id / (gridDim.x * gridDim.y);
+    callback(arg);
   }
 }
 
 static void __attribute__ ((noinline)) process_all_task_groups_stub() {
   wspawn_task_groups_args_t* targs = (wspawn_task_groups_args_t*)csr_read(VX_CSR_MSCRATCH);
 
-  int warps_per_group = targs->warps_per_group;
-  int remaining_mask = targs->remaining_mask;
-  int warp_id = vx_warp_id();
-  int group_warp_id = warp_id % warps_per_group;
-  int threads_mask = (group_warp_id == warps_per_group-1) ? remaining_mask : -1;
+  uint32_t warps_per_group = targs->warps_per_group;
+  uint32_t remaining_mask = targs->remaining_mask;
+  uint32_t warp_id = vx_warp_id();
+  uint32_t group_warp_id = warp_id % warps_per_group;
+  uint32_t threads_mask = (group_warp_id == warps_per_group-1) ? remaining_mask : -1;
 
   // activate threads
   vx_tmc(threads_mask);
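The threadIdx/blockIdx assignments above are plain row-major delinearization: threadIdx is computed once per hardware thread (its local_task_id never changes), while blockIdx is recomputed on every loop iteration as the warp steps through the groups assigned to its core. A standalone host-side sketch of the same mapping, for illustration only:

#include <stdint.h>
#include <stdio.h>

// row-major delinearization as in process_all_task_groups:
// linear = x + dim.x * (y + dim.y * z)
static void delinearize(uint32_t linear, const uint32_t dim[3], uint32_t out[3]) {
  out[0] = linear % dim[0];
  out[1] = (linear / dim[0]) % dim[1];
  out[2] = linear / (dim[0] * dim[1]);
}

int main() {
  uint32_t block_dim[3] = {8, 8, 1};
  uint32_t idx[3];
  delinearize(37, block_dim, idx);               // 37 = 5 + 8*4
  printf("%u %u %u\n", idx[0], idx[1], idx[2]);  // prints: 5 4 0
  return 0;
}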
@@ -209,46 +227,62 @@ static void __attribute__ ((noinline)) process_all_task_groups_stub() {
   vx_tmc(0 == vx_warp_id());
 }
 
-void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_cb callback, void * arg) {
+int vx_spawn_threads(uint32_t dimension,
+                     const uint32_t* grid_dim,
+                     const uint32_t* block_dim,
+                     vx_kernel_func_cb kernel_func,
+                     const void* arg) {
+  // calculate number of groups and group size
+  uint32_t num_groups = 1;
+  uint32_t group_size = 1;
+  for (uint32_t i = 0; i < 3; ++i) {
+    uint32_t gd = (i < dimension) ? grid_dim[i] : 1;
+    uint32_t bd = (i < dimension) ? block_dim[i] : 1;
+    num_groups *= gd;
+    group_size *= bd;
+    gridDim.m[i] = gd;
+    blockDim.m[i] = bd;
+  }
 
   // device specifications
-  int num_cores = vx_num_cores();
-  int warps_per_core = vx_num_warps();
-  int threads_per_warp = vx_num_threads();
-  int core_id = vx_core_id();
+  uint32_t num_cores = vx_num_cores();
+  uint32_t warps_per_core = vx_num_warps();
+  uint32_t threads_per_warp = vx_num_threads();
+  uint32_t core_id = vx_core_id();
 
   // check group size
-  int threads_per_core = warps_per_core * threads_per_warp;
+  uint32_t threads_per_core = warps_per_core * threads_per_warp;
   if (threads_per_core < group_size) {
-    vx_printf("error: group_size > threads_per_core (%d)\n", threads_per_core);
-    return;
+    vx_printf("error: group_size > threads_per_core (%d, %d)\n", group_size, threads_per_core);
+    return -1;
   }
 
-  int warps_per_group = group_size / threads_per_warp;
-  int remaining_threads = group_size - warps_per_group * threads_per_warp;
-  int remaining_mask = -1;
+  uint32_t warps_per_group = group_size / threads_per_warp;
+  uint32_t remaining_threads = group_size - warps_per_group * threads_per_warp;
+  uint32_t remaining_mask = -1;
   if (remaining_threads != 0) {
     remaining_mask = (1 << remaining_threads) - 1;
-    warps_per_group++;
+    ++warps_per_group;
   }
 
-  int needed_warps = num_groups * warps_per_group;
-  int needed_cores = (needed_warps + warps_per_core-1) / warps_per_core;
-  int active_cores = MIN(needed_cores, num_cores);
+  uint32_t needed_warps = num_groups * warps_per_group;
+  uint32_t needed_cores = (needed_warps + warps_per_core-1) / warps_per_core;
+  uint32_t active_cores = MIN(needed_cores, num_cores);
 
   // only active cores participate
   if (core_id >= active_cores)
-    return;
+    return 0;
 
-  int total_groups_per_core = num_groups / active_cores;
-  int remaining_groups_per_core = num_groups - active_cores * total_groups_per_core;
+  uint32_t total_groups_per_core = num_groups / active_cores;
+  uint32_t remaining_groups_per_core = num_groups - active_cores * total_groups_per_core;
   if (core_id < remaining_groups_per_core)
-    total_groups_per_core++;
+    ++total_groups_per_core;
 
   // calculate number of warps to activate
-  int groups_per_core = warps_per_core / warps_per_group;
-  int total_warps_per_core = total_groups_per_core * warps_per_group;
-  int active_warps = total_warps_per_core;
-  int warp_batches = 1, remaining_warps = 0;
+  uint32_t groups_per_core = warps_per_core / warps_per_group;
+  uint32_t total_warps_per_core = total_groups_per_core * warps_per_group;
+  uint32_t active_warps = total_warps_per_core;
+  uint32_t warp_batches = 1, remaining_warps = 0;
   if (active_warps > warps_per_core) {
     active_warps = groups_per_core * warps_per_group;
     warp_batches = total_warps_per_core / active_warps;
@@ -256,11 +290,11 @@ void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_c
   }
 
   // calculate offsets for group distribution
-  int group_offset = core_id * total_groups_per_core + MIN(core_id, remaining_groups_per_core);
+  uint32_t group_offset = core_id * total_groups_per_core + MIN(core_id, remaining_groups_per_core);
 
-  // prepare scheduler arguments
+  // set scheduler arguments
   wspawn_task_groups_args_t wspawn_args = {
-    callback,
+    kernel_func,
     arg,
     group_offset,
     warp_batches,
@@ -271,6 +305,10 @@ void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_c
   };
   csr_write(VX_CSR_MSCRATCH, &wspawn_args);
 
+  // set global variables
+  __groups_per_core = groups_per_core;
+  __warps_per_group = warps_per_group;
+
   // execute callback on other warps
   vx_wspawn(active_warps, process_all_task_groups_stub);
@@ -279,6 +317,8 @@ void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_c
   // wait for spawned tasks to complete
   vx_wspawn(1, 0);
 
+  return 0;
 }
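Worked example of the sizing above: with group_size = 64 and threads_per_warp = 16, warps_per_group = 4 with no remainder, so remaining_mask stays -1 (all threads enabled); with group_size = 60, warps_per_group is bumped to 4 with remaining_threads = 12, and the last warp of each group runs under remaining_mask = 0xFFF. The function returns -1 only when a single group cannot fit in one core; inactive cores return 0 immediately.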
#ifdef __cplusplus

View file

@@ -2,14 +2,14 @@
 #define _COMMON_H_
 
 #ifndef TYPE
-#define TYPE int
+#define TYPE float
 #endif
 
 typedef struct {
-  uint32_t num_groups;
-  uint32_t group_size;
   uint32_t size;
   uint32_t tile_size;
+  uint32_t grid_dim[2];
+  uint32_t block_dim[2];
   uint64_t A_addr;
   uint64_t B_addr;
   uint64_t C_addr;
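With these fields the host now describes the launch shape directly (a 2-D grid of blocks and a 2-D block of threads) instead of passing a pre-flattened num_groups/group_size pair; the values are filled in by the updated host code below.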

View file

@@ -1,58 +1,53 @@
 #include <stdint.h>
 #include <vx_intrinsics.h>
 #include <vx_spawn.h>
 #include <vx_print.h>
 #include "common.h"
 
-void kernel_body(int local_task_id, int group_id, int local_group_id, int warps_per_group, kernel_arg_t *arg) {
-  auto A_ptr = reinterpret_cast<TYPE*>(arg->A_addr);
-  auto B_ptr = reinterpret_cast<TYPE*>(arg->B_addr);
-  auto C_ptr = reinterpret_cast<TYPE*>(arg->C_addr);
-  auto size = arg->size;
+void kernel_body(kernel_arg_t *arg) {
+  // Setup buffer arguments
+  auto A_ptr = reinterpret_cast<TYPE*>(arg->A_addr);
+  auto B_ptr = reinterpret_cast<TYPE*>(arg->B_addr);
+  auto C_ptr = reinterpret_cast<TYPE*>(arg->C_addr);
+
+  // Allocate local memory for the tile of matrix A & B
+  auto local_ptr = __local_mem(2 * blockDim.x * blockDim.y * sizeof(TYPE));
+  auto local_A = (TYPE*)local_ptr;
+  auto local_B = (TYPE*)local_ptr + blockDim.x * blockDim.y;
+
+  auto size = arg->size;
   auto tile_size = arg->tile_size;
-  auto num_groups = arg->num_groups;
-  auto group_size = arg->group_size;
-  auto num_tiles = size / tile_size;
-  auto local_mem = vx_local_malloc(local_group_id, group_size * 2);
 
-  // Determine row and column indices of the current subtask
-  auto l_row = local_task_id / tile_size;
-  auto l_col = local_task_id % tile_size;
-
-  // Determine row and column indices of the current task
-  auto g_row = (group_id / num_tiles) * tile_size + l_row;
-  auto g_col = (group_id % num_tiles) * tile_size + l_col;
+  // Determine global row and column indices
+  auto g_row = blockIdx.x * blockDim.x + threadIdx.x;
+  auto g_col = blockIdx.y * blockDim.y + threadIdx.y;
 
-  // Allocate local memory for the tile of matrix A & B
-  auto local_A = (TYPE*)local_mem;
-  auto local_B = local_A + group_size;
+  // Determine local row and column indices
+  auto l_row = threadIdx.x;
+  auto l_col = threadIdx.y;
 
   TYPE sum(0);
 
   // Loop over tiles
   for (uint32_t k = 0; k < size; k += tile_size) {
     // Load tile of matrix A & B to local memory
     local_A[l_row * tile_size + l_col] = A_ptr[g_row * size + (k + l_col)];
     local_B[l_row * tile_size + l_col] = B_ptr[(k + l_row) * size + g_col];
 
     // Synchronize all warps in current group
-    vx_barrier(local_group_id * 2 + 0, warps_per_group);
+    __syncthreads();
 
     // Compute partial sum for the local tile
     for (uint32_t j = 0; j < tile_size; ++j) {
       sum += local_A[l_row * tile_size + j] * local_B[j * tile_size + l_col];
     }
 
     // Synchronize all warps in current group
-    vx_barrier(local_group_id * 2 + 1, warps_per_group);
+    __syncthreads();
   }
 
   // Store the computed sum into the result matrix C
   C_ptr[g_row * size + g_col] = sum;
 }
 
 int main() {
-  kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH);
-  vx_spawn_task_groups(arg->num_groups, arg->group_size, (vx_spawn_task_groups_cb)kernel_body, arg);
-  return 0;
+  auto arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH);
+  return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg);
 }
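Note on the two __syncthreads() calls per iteration: the first guarantees the whole tile is loaded into local memory before any thread reads it, and the second guarantees all threads are done reading before the next iteration overwrites the tile. Since __syncthreads() expands __COUNTER__ into the barrier id, the two call sites automatically get distinct barriers per group, replacing the manual local_group_id * 2 + 0/1 numbering the old code managed by hand.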

View file

@@ -29,9 +29,7 @@ public:
     return "integer";
   }
   static int generate() {
-    static int q(1);
-    return q++;
-    //return rand();
+    return rand();
   }
   static bool compare(int a, int b, int index, int errors) {
     if (a != b) {
@@ -80,7 +78,6 @@ static void matmul_cpu(TYPE* out, const TYPE* A, const TYPE* B, uint32_t width,
       TYPE b = B[e * width + col];
       TYPE c = a * b;
       sum += c;
-      //printf("out[%d][%d]=%d; a=%d, b=%d, c=%d\n", row, col, sum, a, b, c);
     }
     out[row * width + col] = sum;
   }
@@ -157,20 +154,18 @@ int main(int argc, char *argv[]) {
   uint32_t size_sq = size * size;
   uint32_t buf_size = size_sq * sizeof(TYPE);
   uint32_t group_size = tile_size * tile_size;
-  uint32_t num_groups = size_sq / group_size;
   uint32_t local_mem = 2 * group_size * sizeof(TYPE);
 
   std::cout << "data type: " << Comparator<TYPE>::type_str() << std::endl;
   std::cout << "matrix size: " << size << "x" << size << std::endl;
   std::cout << "tile size: " << tile_size << "x" << tile_size << std::endl;
   std::cout << "group size: " << group_size << std::endl;
-  std::cout << "number of groups: " << num_groups << std::endl;
   std::cout << "local memory: " << local_mem << " bytes" << std::endl;
 
-  kernel_arg.num_groups = num_groups;
-  kernel_arg.group_size = group_size;
+  kernel_arg.grid_dim[0] = size / tile_size;
+  kernel_arg.grid_dim[1] = size / tile_size;
+  kernel_arg.block_dim[0] = tile_size;
+  kernel_arg.block_dim[1] = tile_size;
   kernel_arg.size = size;
   kernel_arg.tile_size = tile_size;
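For example, size = 64 with tile_size = 8 gives group_size = 64, an 8x8 grid of 8x8 blocks, and local_mem = 2 * 64 * sizeof(float) = 512 bytes of local memory per group for the two tiles.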