mirror of
https://github.com/vortexgpgpu/vortex.git
synced 2025-04-23 21:39:10 -04:00
spawn_tasks_ex optimization
This commit is contained in:
parent
0003926d01
commit
b6aa44f39f
4 changed files with 18 additions and 48 deletions
|
@ -23,7 +23,7 @@ extern "C" {
|
|||
|
||||
// Function-pointer type for the `callback` parameter of vx_spawn_tasks().
// The callee receives an integer task id and an opaque `void *arg`
// (presumably the `arg` handed to vx_spawn_tasks() -- confirm in vx_spawn.c).
typedef void (*vx_spawn_tasks_cb)(int task_id, void *arg);
|
||||
|
||||
typedef void (*vx_spawn_tasks_ex_cb)(int local_task_id, int group_id, void *arg);
|
||||
// Function-pointer type for the `callback` parameter of vx_spawn_tasks_ex().
// Arguments, in order: local_task_id, group_id, local_group_id,
// warps_per_group, and an opaque `void *arg`.
// NOTE(review): local_group_id and warps_per_group are supplied by the spawn
// runtime; kernels in this change pass them on to vx_barrier() -- confirm the
// exact meaning of local_group_id against process_all_tasks_ex().
typedef void (*vx_spawn_tasks_ex_cb)(int local_task_id, int group_id, int local_group_id, int warps_per_group, void *arg);
|
||||
|
||||
// Function-pointer type for the `callback` parameter of vx_serial().
// The callee receives a single opaque `void *arg`.
typedef void (*vx_serial_cb)(void *arg);
|
||||
|
||||
|
@ -31,8 +31,6 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
|
|||
|
||||
void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb callback, void * arg);
|
||||
|
||||
void vx_syncthreads(int barrier_id);
|
||||
|
||||
void vx_serial(vx_serial_cb callback, void * arg);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -161,7 +161,6 @@ typedef struct {
|
|||
int warps_per_group;
|
||||
int groups_per_core;
|
||||
int remaining_mask;
|
||||
int barrier_enabled;
|
||||
} wspawn_tasks_ex_args_t;
|
||||
|
||||
static void __attribute__ ((noinline)) process_all_tasks_ex() {
|
||||
|
@ -187,7 +186,7 @@ static void __attribute__ ((noinline)) process_all_tasks_ex() {
|
|||
void* arg = targs->arg;
|
||||
|
||||
for (int group_id = start_group; group_id < end_group; group_id += groups_per_core) {
|
||||
callback(local_task_id, group_id, arg);
|
||||
callback(local_task_id, group_id, start_group, warps_per_group, arg);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -211,21 +210,8 @@ static void __attribute__ ((noinline)) process_all_tasks_ex_stub() {
|
|||
|
||||
void vx_syncthreads(int barrier_id) {
|
||||
wspawn_tasks_ex_args_t* targs = (wspawn_tasks_ex_args_t*)csr_read(VX_CSR_MSCRATCH);
|
||||
int barrier_enabled = targs->barrier_enabled;
|
||||
if (!barrier_enabled)
|
||||
return; // no need to synchronize
|
||||
int warps_per_group = targs->warps_per_group;
|
||||
int groups_per_core = targs->groups_per_core;
|
||||
int num_barriers = vx_num_barriers();
|
||||
int warp_id = vx_warp_id();
|
||||
int local_group_id = warp_id / warps_per_group;
|
||||
int id = barrier_id * groups_per_core + local_group_id;
|
||||
// check barrier resource
|
||||
if (id >= num_barriers) {
|
||||
vx_printf("error: out of barrier resource (%d:%d)\n", id+1, num_barriers);
|
||||
return;
|
||||
}
|
||||
vx_barrier(id, warps_per_group);
|
||||
vx_barrier(barrier_id, warps_per_group);
|
||||
}
|
||||
|
||||
void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb callback, void * arg) {
|
||||
|
@ -277,9 +263,6 @@ void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb call
|
|||
// calculate offsets for group distribution
|
||||
int group_offset = core_id * total_groups_per_core + MIN(core_id, remaining_groups_per_core);
|
||||
|
||||
// check if warp barriers are needed
|
||||
int barrier_enabled = (group_size > threads_per_warp);
|
||||
|
||||
// prepare scheduler arguments
|
||||
wspawn_tasks_ex_args_t wspawn_args = {
|
||||
callback,
|
||||
|
@ -289,8 +272,7 @@ void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb call
|
|||
remaining_warps,
|
||||
warps_per_group,
|
||||
groups_per_core,
|
||||
remaining_mask,
|
||||
barrier_enabled
|
||||
remaining_mask
|
||||
};
|
||||
csr_write(VX_CSR_MSCRATCH, &wspawn_args);
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include <vx_print.h>
|
||||
#include "common.h"
|
||||
|
||||
void sgemm_kernel(int local_task_id, int group_id, kernel_arg_t *arg) {
|
||||
void sgemm_kernel(int local_task_id, int group_id, int local_group_id, int warps_per_group, kernel_arg_t *arg) {
|
||||
auto local_ptr = reinterpret_cast<TYPE*>(arg->local_addr);
|
||||
auto A_ptr = reinterpret_cast<TYPE*>(arg->A_addr);
|
||||
auto B_ptr = reinterpret_cast<TYPE*>(arg->B_addr);
|
||||
|
@ -24,8 +24,8 @@ void sgemm_kernel(int local_task_id, int group_id, kernel_arg_t *arg) {
|
|||
auto g_col = (group_id % num_tiles) * tile_size + l_col;
|
||||
|
||||
// Allocate local memory for the tile of matrix A & B
|
||||
auto local_A = local_ptr + group_id * group_size;
|
||||
auto local_B = local_A + num_groups * group_size;
|
||||
auto local_A = local_ptr + local_group_id * group_size * 2;
|
||||
auto local_B = local_A + group_size;
|
||||
|
||||
TYPE sum(0);
|
||||
|
||||
|
@ -35,16 +35,16 @@ void sgemm_kernel(int local_task_id, int group_id, kernel_arg_t *arg) {
|
|||
local_A[l_row * tile_size + l_col] = A_ptr[g_row * size + (k + l_col)];
|
||||
local_B[l_row * tile_size + l_col] = B_ptr[(k + l_row) * size + g_col];
|
||||
|
||||
// Synchronize all threads in current group
|
||||
vx_syncthreads(0);
|
||||
// Synchronize all warps in current group
|
||||
vx_barrier(local_group_id * 2 + 0, warps_per_group);
|
||||
|
||||
// Compute partial sum for the local tile
|
||||
for (uint32_t j = 0; j < tile_size; ++j) {
|
||||
sum += local_A[l_row * tile_size + j] * local_B[j * tile_size + l_col];
|
||||
}
|
||||
|
||||
// Synchronize all threads in current group
|
||||
vx_syncthreads(1);
|
||||
// Synchronize all warps in current group
|
||||
vx_barrier(local_group_id * 2 + 1, warps_per_group);
|
||||
}
|
||||
|
||||
// Store the computed sum into the result matrix C
|
||||
|
|
|
@ -160,7 +160,7 @@ int main(int argc, char *argv[]) {
|
|||
|
||||
uint32_t group_size = tile_size * tile_size;
|
||||
uint32_t num_groups = (size * size) / group_size;
|
||||
uint32_t local_mem = 2 * num_groups * group_size * sizeof(TYPE);
|
||||
uint32_t local_mem = 2 * group_size * sizeof(TYPE);
|
||||
|
||||
std::cout << "data type: " << Comparator<TYPE>::type_str() << std::endl;
|
||||
std::cout << "matrix size: " << size << "x" << size << std::endl;
|
||||
|
@ -174,23 +174,15 @@ int main(int argc, char *argv[]) {
|
|||
kernel_arg.size = size;
|
||||
kernel_arg.tile_size = tile_size;
|
||||
|
||||
// check work group capacity
|
||||
uint64_t num_warps, num_threads;
|
||||
RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_WARPS, &num_warps));
|
||||
RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_THREADS, &num_threads));
|
||||
uint32_t threads_per_core = num_warps * num_threads;
|
||||
RT_CHECK(threads_per_core < group_size);
|
||||
|
||||
// check local memory capacity
|
||||
uint64_t max_local_mem;
|
||||
RT_CHECK(vx_dev_caps(device, VX_CAPS_LOCAL_MEM_SIZE, &max_local_mem));
|
||||
RT_CHECK(max_local_mem < local_mem);
|
||||
|
||||
// acquire local memory address
|
||||
RT_CHECK(vx_dev_caps(device, VX_CAPS_LOCAL_MEM_ADDR, &kernel_arg.local_addr));
|
||||
// check work group occupancy
|
||||
uint32_t max_barriers, max_localmem;
|
||||
RT_CHECK(vx_check_occupancy(device, group_size, &max_barriers, &max_localmem));
|
||||
RT_CHECK(max_barriers < 2);
|
||||
RT_CHECK(max_localmem < local_mem);
|
||||
|
||||
// allocate device memory
|
||||
std::cout << "allocate device memory" << std::endl;
|
||||
RT_CHECK(vx_dev_caps(device, VX_CAPS_LOCAL_MEM_ADDR, &kernel_arg.local_addr));
|
||||
RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_READ, &A_buffer));
|
||||
RT_CHECK(vx_mem_address(A_buffer, &kernel_arg.A_addr));
|
||||
RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_READ, &B_buffer));
|
||||
|
@ -212,8 +204,6 @@ int main(int argc, char *argv[]) {
|
|||
// generate source data
|
||||
for (uint32_t i = 0; i < size_sq; ++i) {
|
||||
h_A[i] = Comparator<TYPE>::generate();
|
||||
}
|
||||
for (uint32_t i = 0; i < size_sq; ++i) {
|
||||
h_B[i] = Comparator<TYPE>::generate();
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue