mirror of https://github.com/vortexgpgpu/vortex.git (synced 2025-06-28 09:37:38 -04:00)
spawn_tasks_ex optimization
parent 0003926d01 · commit b6aa44f39f
4 changed files with 18 additions and 48 deletions
@@ -23,7 +23,7 @@ extern "C" {

 typedef void (*vx_spawn_tasks_cb)(int task_id, void *arg);

-typedef void (*vx_spawn_tasks_ex_cb)(int local_task_id, int group_id, void *arg);
+typedef void (*vx_spawn_tasks_ex_cb)(int local_task_id, int group_id, int local_group_id, int warps_per_group, void *arg);

 typedef void (*vx_serial_cb)(void *arg);

@@ -31,8 +31,6 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);

 void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb callback, void * arg);

-void vx_syncthreads(int barrier_id);
-
 void vx_serial(vx_serial_cb callback, void * arg);

 #ifdef __cplusplus
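The widened callback now hands each task its local_group_id and warps_per_group directly, so a kernel can derive per-group barrier IDs and local-memory slots without extra bookkeeping. Below is a minimal sketch of how a user kernel might adopt the new signature; the argument struct, kernel body, and launch wrapper are hypothetical, and the include is only assumed to be the header that declares these functions.

    #include <vx_spawn.h>   // assumed: Vortex spawn API header (vx_barrier may come from a separate intrinsics header)

    typedef struct {
      float* data;          // hypothetical per-kernel payload
    } my_arg_t;

    // Hypothetical kernel using the widened vx_spawn_tasks_ex_cb signature.
    static void my_kernel(int local_task_id, int group_id, int local_group_id,
                          int warps_per_group, void* arg) {
      my_arg_t* kargs = (my_arg_t*)arg;
      // phase 1: each task writes its share of kargs->data ...
      // per-group barrier, mirroring the sgemm kernel change further below
      vx_barrier(local_group_id, warps_per_group);
      // phase 2: read values produced by other tasks in the same group ...
      (void)local_task_id; (void)group_id; (void)kargs;
    }

    void launch(int num_groups, int group_size, my_arg_t* arg) {
      vx_spawn_tasks_ex(num_groups, group_size, my_kernel, arg);
    }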
@@ -161,7 +161,6 @@ typedef struct {
   int warps_per_group;
   int groups_per_core;
   int remaining_mask;
-  int barrier_enabled;
 } wspawn_tasks_ex_args_t;

 static void __attribute__ ((noinline)) process_all_tasks_ex() {

@@ -187,7 +186,7 @@ static void __attribute__ ((noinline)) process_all_tasks_ex() {
   void* arg = targs->arg;

   for (int group_id = start_group; group_id < end_group; group_id += groups_per_core) {
-    callback(local_task_id, group_id, arg);
+    callback(local_task_id, group_id, start_group, warps_per_group, arg);
   }
 }
@@ -211,21 +210,8 @@ static void __attribute__ ((noinline)) process_all_tasks_ex_stub() {

 void vx_syncthreads(int barrier_id) {
   wspawn_tasks_ex_args_t* targs = (wspawn_tasks_ex_args_t*)csr_read(VX_CSR_MSCRATCH);
-  int barrier_enabled = targs->barrier_enabled;
-  if (!barrier_enabled)
-    return; // no need to synchronize
   int warps_per_group = targs->warps_per_group;
-  int groups_per_core = targs->groups_per_core;
-  int num_barriers = vx_num_barriers();
-  int warp_id = vx_warp_id();
-  int local_group_id = warp_id / warps_per_group;
-  int id = barrier_id * groups_per_core + local_group_id;
-  // check barrier resource
-  if (id >= num_barriers) {
-    vx_printf("error: out of barrier resource (%d:%d)\n", id+1, num_barriers);
-    return;
-  }
-  vx_barrier(id, warps_per_group);
+  vx_barrier(barrier_id, warps_per_group);
 }

 void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb callback, void * arg) {
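Reconstructed from the hunk above, the whole of vx_syncthreads after this change reduces to a single hardware barrier across the group's warps; the CSR read only recovers warps_per_group from the scheduler arguments published by vx_spawn_tasks_ex:

    void vx_syncthreads(int barrier_id) {
      // scheduler arguments written to the MSCRATCH CSR by vx_spawn_tasks_ex
      wspawn_tasks_ex_args_t* targs = (wspawn_tasks_ex_args_t*)csr_read(VX_CSR_MSCRATCH);
      int warps_per_group = targs->warps_per_group;
      vx_barrier(barrier_id, warps_per_group);
    }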
@@ -277,9 +263,6 @@ void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb call
   // calculate offsets for group distribution
   int group_offset = core_id * total_groups_per_core + MIN(core_id, remaining_groups_per_core);

-  // check if warp barriers are needed
-  int barrier_enabled = (group_size > threads_per_warp);
-
   // prepare scheduler arguments
   wspawn_tasks_ex_args_t wspawn_args = {
     callback,

@@ -289,8 +272,7 @@ void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb call
     remaining_warps,
     warps_per_group,
     groups_per_core,
-    remaining_mask,
-    barrier_enabled
+    remaining_mask
   };
   csr_write(VX_CSR_MSCRATCH, &wspawn_args);
@@ -4,7 +4,7 @@
 #include <vx_print.h>
 #include "common.h"

-void sgemm_kernel(int local_task_id, int group_id, kernel_arg_t *arg) {
+void sgemm_kernel(int local_task_id, int group_id, int local_group_id, int warps_per_group, kernel_arg_t *arg) {
   auto local_ptr = reinterpret_cast<TYPE*>(arg->local_addr);
   auto A_ptr = reinterpret_cast<TYPE*>(arg->A_addr);
   auto B_ptr = reinterpret_cast<TYPE*>(arg->B_addr);

@@ -24,8 +24,8 @@ void sgemm_kernel(int local_task_id, int group_id, kernel_arg_t *arg) {
   auto g_col = (group_id % num_tiles) * tile_size + l_col;

   // Allocate local memory for the tile of matrix A & B
-  auto local_A = local_ptr + group_id * group_size;
-  auto local_B = local_A + num_groups * group_size;
+  auto local_A = local_ptr + local_group_id * group_size * 2;
+  auto local_B = local_A + group_size;

   TYPE sum(0);

@@ -35,16 +35,16 @@ void sgemm_kernel(int local_task_id, int group_id, kernel_arg_t *arg) {
     local_A[l_row * tile_size + l_col] = A_ptr[g_row * size + (k + l_col)];
     local_B[l_row * tile_size + l_col] = B_ptr[(k + l_row) * size + g_col];

-    // Synchronize all threads in current group
-    vx_syncthreads(0);
+    // Synchronize all warps in current group
+    vx_barrier(local_group_id * 2 + 0, warps_per_group);

     // Compute partial sum for the local tile
     for (uint32_t j = 0; j < tile_size; ++j) {
       sum += local_A[l_row * tile_size + j] * local_B[j * tile_size + l_col];
     }

-    // Synchronize all threads in current group
-    vx_syncthreads(1);
+    // Synchronize all warps in current group
+    vx_barrier(local_group_id * 2 + 1, warps_per_group);
   }

   // Store the computed sum into the result matrix C
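Pieced together from the three kernel hunks above, the per-tile loop now synchronizes with explicit hardware barriers, two per resident group (IDs 2*local_group_id and 2*local_group_id + 1), instead of calling vx_syncthreads. The outer loop header is not part of the diff and is assumed here from the k offsets used in the tile loads:

    // assumed loop over tiles along the shared dimension
    for (uint32_t k = 0; k < size; k += tile_size) {
      // load this group's A and B tiles into its private local-memory slot
      local_A[l_row * tile_size + l_col] = A_ptr[g_row * size + (k + l_col)];
      local_B[l_row * tile_size + l_col] = B_ptr[(k + l_row) * size + g_col];

      // wait until all warps in the group have finished loading
      vx_barrier(local_group_id * 2 + 0, warps_per_group);

      // accumulate the partial dot product for this tile
      for (uint32_t j = 0; j < tile_size; ++j) {
        sum += local_A[l_row * tile_size + j] * local_B[j * tile_size + l_col];
      }

      // wait again so the next iteration does not overwrite tiles still in use
      vx_barrier(local_group_id * 2 + 1, warps_per_group);
    }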
@@ -160,7 +160,7 @@ int main(int argc, char *argv[]) {

   uint32_t group_size = tile_size * tile_size;
   uint32_t num_groups = (size * size) / group_size;
-  uint32_t local_mem = 2 * num_groups * group_size * sizeof(TYPE);
+  uint32_t local_mem = 2 * group_size * sizeof(TYPE);

   std::cout << "data type: " << Comparator<TYPE>::type_str() << std::endl;
   std::cout << "matrix size: " << size << "x" << size << std::endl;
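Since each resident group now gets its own local-memory slot, the host sizes local memory per group rather than for the whole grid. A quick worked example, assuming TYPE = float, tile_size = 16, and a 64x64 matrix (illustrative values, not from the diff):

    uint32_t size       = 64;                              // assumed matrix dimension
    uint32_t tile_size  = 16;                              // assumed tile dimension
    uint32_t group_size = tile_size * tile_size;           // 256 tasks per group
    uint32_t num_groups = (size * size) / group_size;      // 16 groups in total
    uint32_t local_mem  = 2 * group_size * sizeof(float);  // new: 2 KB per group
    // old formula: 2 * num_groups * group_size * sizeof(float) = 32 KB for the whole grid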
@@ -174,23 +174,15 @@ int main(int argc, char *argv[]) {
   kernel_arg.size = size;
   kernel_arg.tile_size = tile_size;

-  // check work group capacity
-  uint64_t num_warps, num_threads;
-  RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_WARPS, &num_warps));
-  RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_THREADS, &num_threads));
-  uint32_t threads_per_core = num_warps * num_threads;
-  RT_CHECK(threads_per_core < group_size);
-
-  // check local memory capacity
-  uint64_t max_local_mem;
-  RT_CHECK(vx_dev_caps(device, VX_CAPS_LOCAL_MEM_SIZE, &max_local_mem));
-  RT_CHECK(max_local_mem < local_mem);
-
-  // acquire local memory address
-  RT_CHECK(vx_dev_caps(device, VX_CAPS_LOCAL_MEM_ADDR, &kernel_arg.local_addr));
+  // check work group occupancy
+  uint32_t max_barriers, max_localmem;
+  RT_CHECK(vx_check_occupancy(device, group_size, &max_barriers, &max_localmem));
+  RT_CHECK(max_barriers < 2);
+  RT_CHECK(max_localmem < local_mem);

   // allocate device memory
   std::cout << "allocate device memory" << std::endl;
+  RT_CHECK(vx_dev_caps(device, VX_CAPS_LOCAL_MEM_ADDR, &kernel_arg.local_addr));
   RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_READ, &A_buffer));
   RT_CHECK(vx_mem_address(A_buffer, &kernel_arg.A_addr));
   RT_CHECK(vx_mem_alloc(device, buf_size, VX_MEM_READ, &B_buffer));
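Reconstructed from the hunk above, the host-side check collapses three separate device-caps queries into one vx_check_occupancy call. Assuming RT_CHECK treats a non-zero value as an error (its usual role in these tests), the run now aborts unless the device reports at least two barriers for a group of this size and enough per-group local memory:

    // check work group occupancy
    uint32_t max_barriers, max_localmem;
    RT_CHECK(vx_check_occupancy(device, group_size, &max_barriers, &max_localmem));
    RT_CHECK(max_barriers < 2);         // requires max_barriers >= 2, matching the two per-group barriers in the kernel
    RT_CHECK(max_localmem < local_mem); // requires room for both tiles of one group in local memory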
@@ -212,8 +204,6 @@ int main(int argc, char *argv[]) {
   // generate source data
   for (uint32_t i = 0; i < size_sq; ++i) {
     h_A[i] = Comparator<TYPE>::generate();
-  }
-  for (uint32_t i = 0; i < size_sq; ++i) {
     h_B[i] = Comparator<TYPE>::generate();
   }