spawn_tasks_ex optimization

This commit is contained in:
Blaise Tine 2024-05-07 23:40:38 -07:00
parent 0003926d01
commit b6aa44f39f
4 changed files with 18 additions and 48 deletions

View file

@ -23,7 +23,7 @@ extern "C" {
typedef void (*vx_spawn_tasks_cb)(int task_id, void *arg);
typedef void (*vx_spawn_tasks_ex_cb)(int local_task_id, int group_id, void *arg);
typedef void (*vx_spawn_tasks_ex_cb)(int local_task_id, int group_id, int local_group_id, int warps_per_group, void *arg);
typedef void (*vx_serial_cb)(void *arg);
@ -31,8 +31,6 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb callback, void * arg);
void vx_syncthreads(int barrier_id);
void vx_serial(vx_serial_cb callback, void * arg);
#ifdef __cplusplus

View file

@ -161,7 +161,6 @@ typedef struct {
int warps_per_group;
int groups_per_core;
int remaining_mask;
int barrier_enabled;
} wspawn_tasks_ex_args_t;
static void __attribute__ ((noinline)) process_all_tasks_ex() {
@ -187,7 +186,7 @@ static void __attribute__ ((noinline)) process_all_tasks_ex() {
void* arg = targs->arg;
for (int group_id = start_group; group_id < end_group; group_id += groups_per_core) {
callback(local_task_id, group_id, arg);
callback(local_task_id, group_id, start_group, warps_per_group, arg);
}
}
@ -211,21 +210,8 @@ static void __attribute__ ((noinline)) process_all_tasks_ex_stub() {
void vx_syncthreads(int barrier_id) {
wspawn_tasks_ex_args_t* targs = (wspawn_tasks_ex_args_t*)csr_read(VX_CSR_MSCRATCH);
int barrier_enabled = targs->barrier_enabled;
if (!barrier_enabled)
return; // no need to synchronize
int warps_per_group = targs->warps_per_group;
int groups_per_core = targs->groups_per_core;
int num_barriers = vx_num_barriers();
int warp_id = vx_warp_id();
int local_group_id = warp_id / warps_per_group;
int id = barrier_id * groups_per_core + local_group_id;
// check barrier resource
if (id >= num_barriers) {
vx_printf("error: out of barrier resource (%d:%d)\n", id+1, num_barriers);
return;
}
vx_barrier(id, warps_per_group);
vx_barrier(barrier_id, warps_per_group);
}
void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb callback, void * arg) {
@ -277,9 +263,6 @@ void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb call
// calculate offsets for group distribution
int group_offset = core_id * total_groups_per_core + MIN(core_id, remaining_groups_per_core);
// check if warp barriers are needed
int barrier_enabled = (group_size > threads_per_warp);
// prepare scheduler arguments
wspawn_tasks_ex_args_t wspawn_args = {
callback,
@ -289,8 +272,7 @@ void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb call
remaining_warps,
warps_per_group,
groups_per_core,
remaining_mask,
barrier_enabled
remaining_mask
};
csr_write(VX_CSR_MSCRATCH, &wspawn_args);

View file

@ -4,7 +4,7 @@
#include <vx_print.h>
#include "common.h"
void sgemm_kernel(int local_task_id, int group_id, kernel_arg_t *arg) {
void sgemm_kernel(int local_task_id, int group_id, int local_group_id, int warps_per_group, kernel_arg_t *arg) {
auto local_ptr = reinterpret_cast<TYPE*>(arg->local_addr);
auto A_ptr = reinterpret_cast<TYPE*>(arg->A_addr);
auto B_ptr = reinterpret_cast<TYPE*>(arg->B_addr);
@ -24,8 +24,8 @@ void sgemm_kernel(int local_task_id, int group_id, kernel_arg_t *arg) {
auto g_col = (group_id % num_tiles) * tile_size + l_col;
// Allocate local memory for the tile of matrix A & B
auto local_A = local_ptr + group_id * group_size;
auto local_B = local_A + num_groups * group_size;
auto local_A = local_ptr + local_group_id * group_size * 2;
auto local_B = local_A + group_size;
TYPE sum(0);
@ -35,16 +35,16 @@ void sgemm_kernel(int local_task_id, int group_id, kernel_arg_t *arg) {
local_A[l_row * tile_size + l_col] = A_ptr[g_row * size + (k + l_col)];
local_B[l_row * tile_size + l_col] = B_ptr[(k + l_row) * size + g_col];
// Synchronize all threads in current group
vx_syncthreads(0);
// Synchronize all warps in current group
vx_barrier(local_group_id * 2 + 0, warps_per_group);
// Compute partial sum for the local tile
for (uint32_t j = 0; j < tile_size; ++j) {
sum += local_A[l_row * tile_size + j] * local_B[j * tile_size + l_col];
}
// Synchronize all threads in current group
vx_syncthreads(1);
// Synchronize all warps in current group
vx_barrier(local_group_id * 2 + 1, warps_per_group);
}
// Store the computed sum into the result matrix C

View file

@ -160,7 +160,7 @@ int main(int argc, char *argv[]) {
uint32_t group_size = tile_size * tile_size;
uint32_t num_groups = (size * size) / group_size;
uint32_t local_mem = 2 * num_groups * group_size * sizeof(TYPE);
uint32_t local_mem = 2 * group_size * sizeof(TYPE);
std::cout << "data type: " << Comparator<TYPE>::type_str() << std::endl;
std::cout << "matrix size: " << size << "x" << size << std::endl;
@ -174,23 +174,15 @@ int main(int argc, char *argv[]) {
kernel_arg.size = size;
kernel_arg.tile_size = tile_size;
// check work group capacity
uint64_t num_warps, num_threads;
RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_WARPS, &num_warps));
RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_THREADS, &num_threads));
uint32_t threads_per_core = num_warps * num_threads;
RT_CHECK(threads_per_core < group_size);
// check local memory capacity
uint64_t max_local_mem;
RT_CHECK(vx_dev_caps(device, VX_CAPS_LOCAL_MEM_SIZE, &max_local_mem));
RT_CHECK(max_local_mem < local_mem);
// acquire local memory address
RT_CHECK(vx_dev_caps(device, VX_CAPS_LOCAL_MEM_ADDR, &kernel_arg.local_addr));
// check work group occupancy
uint32_t max_barriers, max_localmem;
RT_CHECK(vx_check_occupancy(device, group_size, &max_barriers, &max_localmem));
RT_CHECK(max_barriers < 2);
RT_CHECK(max_localmem < local_mem);
// allocate device memory
std::cout << "allocate device memory" << std::endl;
RT_CHECK(vx_dev_caps(device, VX_CAPS_LOCAL_MEM_ADDR, &kernel_arg.local_addr));
RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_READ, &A_buffer));
RT_CHECK(vx_mem_address(A_buffer, &kernel_arg.A_addr));
RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_READ, &B_buffer));
@ -212,8 +204,6 @@ int main(int argc, char *argv[]) {
// generate source data
for (uint32_t i = 0; i < size_sq; ++i) {
h_A[i] = Comparator<TYPE>::generate();
}
for (uint32_t i = 0; i < size_sq; ++i) {
h_B[i] = Comparator<TYPE>::generate();
}