Mirror of https://github.com/vortexgpgpu/vortex.git, synced 2025-04-23 21:39:10 -04:00
vx_spawn_threads implementation

parent 8c5a783477
commit 96cb381885

5 changed files with 208 additions and 152 deletions
@@ -14,28 +14,54 @@

#ifndef __VX_SPAWN_H__
#define __VX_SPAWN_H__

#include <VX_types.h>
#include <vx_intrinsics.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

typedef void (*vx_spawn_tasks_cb)(int task_id, const void *arg);
typedef void (*vx_spawn_task_groups_cb)(int local_task_id, int group_id, int local_group_id, int warps_per_group, const void *arg);
typedef void (*vx_spawn_tasks_cb)(uint32_t task_id, const void *arg);

typedef void (*vx_serial_cb)(const void *arg);

void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, const void * arg);
void vx_spawn_tasks(uint32_t num_tasks, vx_spawn_tasks_cb callback, const void * arg);

void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_cb callback, const void * arg);
void vx_serial(vx_serial_cb callback, void * arg);

inline void* vx_local_malloc(int local_group_id, int size) {
  return (int8_t*)csr_read(VX_CSR_LOCAL_MEM_BASE) + local_group_id * size;
}

///////////////////////////////////////////////////////////////////////////////

void vx_serial(vx_serial_cb callback, const void * arg);

typedef union {
  struct {
    uint32_t x;
    uint32_t y;
    uint32_t z;
  };
  uint32_t m[3];
} dim3_t;

extern __thread dim3_t blockIdx;
extern __thread dim3_t threadIdx;
extern dim3_t gridDim;
extern dim3_t blockDim;

extern __thread uint32_t __local_group_id;
extern uint32_t __groups_per_core;
extern uint32_t __warps_per_group;

typedef void (*vx_kernel_func_cb)(const void *arg);

#define __local_mem(size) \
  (void*)((int8_t*)csr_read(VX_CSR_LOCAL_MEM_BASE) + __local_group_id * size)

#define __syncthreads() \
  vx_barrier(__COUNTER__ * __groups_per_core + __local_group_id, __warps_per_group)

int vx_spawn_threads(uint32_t dimension,
                     const uint32_t* grid_dim,
                     const uint32_t* block_dim,
                     vx_kernel_func_cb kernel_func,
                     const void* arg);

#ifdef __cplusplus
}
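The reworked header moves the kernel-facing API toward CUDA's model: a kernel reads its coordinates from blockIdx/threadIdx/gridDim/blockDim, claims per-group scratch space with __local_mem, and synchronizes the warps of its group with __syncthreads, while vx_spawn_threads replaces vx_spawn_task_groups as the launch entry point. A minimal usage sketch (a hypothetical vector-add kernel, not part of this commit; the vecadd_arg_t struct and the 64-thread block size are assumptions):

#include <vx_spawn.h>
#include <stdint.h>

// Hypothetical argument block; the host is assumed to have written its address
// into VX_CSR_MSCRATCH and to have filled in valid device buffer addresses.
typedef struct {
  uint32_t num_points;
  uint64_t src0_addr;
  uint64_t src1_addr;
  uint64_t dst_addr;
} vecadd_arg_t;

void vecadd_kernel(const vecadd_arg_t* arg) {
  // CUDA-style global index: one thread per output element.
  uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  auto a = reinterpret_cast<const float*>(arg->src0_addr);
  auto b = reinterpret_cast<const float*>(arg->src1_addr);
  auto c = reinterpret_cast<float*>(arg->dst_addr);
  c[i] = a[i] + b[i];
}

int main() {
  auto arg = (vecadd_arg_t*)csr_read(VX_CSR_MSCRATCH);
  uint32_t block_dim = 64;                          // assumed group size (must fit on one core)
  uint32_t grid_dim = arg->num_points / block_dim;  // assumes num_points is a multiple of 64
  return vx_spawn_threads(1, &grid_dim, &block_dim, (vx_kernel_func_cb)vecadd_kernel, arg);
}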
@@ -26,29 +26,29 @@ extern "C" {

typedef struct {
  vx_spawn_tasks_cb callback;
  void* arg;
  int all_tasks_offset;
  int remain_tasks_offset;
  int warp_batches;
  int remaining_warps;
  const void* arg;
  uint32_t all_tasks_offset;
  uint32_t remain_tasks_offset;
  uint32_t warp_batches;
  uint32_t remaining_warps;
} wspawn_tasks_args_t;

static void __attribute__ ((noinline)) process_all_tasks() {
  wspawn_tasks_args_t* targs = (wspawn_tasks_args_t*)csr_read(VX_CSR_MSCRATCH);

  int threads_per_warp = vx_num_threads();
  int warp_id = vx_warp_id();
  int thread_id = vx_thread_id();
  uint32_t threads_per_warp = vx_num_threads();
  uint32_t warp_id = vx_warp_id();
  uint32_t thread_id = vx_thread_id();

  int start_warp = (warp_id * targs->warp_batches) + MIN(warp_id, targs->remaining_warps);
  int iterations = targs->warp_batches + (warp_id < targs->remaining_warps);
  uint32_t start_warp = (warp_id * targs->warp_batches) + MIN(warp_id, targs->remaining_warps);
  uint32_t iterations = targs->warp_batches + (warp_id < targs->remaining_warps);

  int start_task_id = targs->all_tasks_offset + (start_warp * threads_per_warp) + thread_id;
  int end_task_id = start_task_id + iterations * threads_per_warp;
  uint32_t start_task_id = targs->all_tasks_offset + (start_warp * threads_per_warp) + thread_id;
  uint32_t end_task_id = start_task_id + iterations * threads_per_warp;

  vx_spawn_tasks_cb callback = targs->callback;
  void* arg = targs->arg;
  for (int task_id = start_task_id; task_id < end_task_id; task_id += threads_per_warp) {
  const void* arg = targs->arg;
  for (uint32_t task_id = start_task_id; task_id < end_task_id; task_id += threads_per_warp) {
    callback(task_id, arg);
  }
}
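In process_all_tasks, each hardware thread owns a strided range of task ids: its warp's first batch index (start_warp) fixes the start, and successive iterations advance by threads_per_warp so the same lane keeps processing the same slot of each batch. A standalone sketch of that index arithmetic with assumed parameters (4 warps of 8 threads, warp_batches = 2, one leftover warp batch):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed schedule, mirroring the wspawn_tasks_args_t fields.
  uint32_t threads_per_warp = 8;
  uint32_t warp_batches = 2, remaining_warps = 1, all_tasks_offset = 0;

  for (uint32_t warp_id = 0; warp_id < 4; ++warp_id) {
    uint32_t start_warp = warp_id * warp_batches + std::min(warp_id, remaining_warps);
    uint32_t iterations = warp_batches + (warp_id < remaining_warps);
    for (uint32_t thread_id = 0; thread_id < threads_per_warp; ++thread_id) {
      uint32_t start_task_id = all_tasks_offset + start_warp * threads_per_warp + thread_id;
      uint32_t end_task_id = start_task_id + iterations * threads_per_warp;
      // This (warp, thread) pair runs task ids start, start+8, ... up to (not including) end.
      printf("warp %u thread %u: tasks [%u, %u) step %u\n",
             warp_id, thread_id, start_task_id, end_task_id, threads_per_warp);
    }
  }
  return 0;
}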
@@ -56,8 +56,8 @@ static void __attribute__ ((noinline)) process_all_tasks() {

static void __attribute__ ((noinline)) process_remaining_tasks() {
  wspawn_tasks_args_t* targs = (wspawn_tasks_args_t*)csr_read(VX_CSR_MSCRATCH);

  int thread_id = vx_thread_id();
  int task_id = targs->remain_tasks_offset + thread_id;
  uint32_t thread_id = vx_thread_id();
  uint32_t task_id = targs->remain_tasks_offset + thread_id;

  (targs->callback)(task_id, targs->arg);
}
@@ -73,33 +73,33 @@ static void __attribute__ ((noinline)) process_all_tasks_stub() {

  vx_tmc_zero();
}

void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
void vx_spawn_tasks(uint32_t num_tasks, vx_spawn_tasks_cb callback , const void * arg) {
  // device specifications
  int num_cores = vx_num_cores();
  int warps_per_core = vx_num_warps();
  int threads_per_warp = vx_num_threads();
  int core_id = vx_core_id();
  uint32_t num_cores = vx_num_cores();
  uint32_t warps_per_core = vx_num_warps();
  uint32_t threads_per_warp = vx_num_threads();
  uint32_t core_id = vx_core_id();

  // calculate necessary active cores
  int threads_per_core = warps_per_core * threads_per_warp;
  int needed_cores = (num_tasks + threads_per_core - 1) / threads_per_core;
  int active_cores = MIN(needed_cores, num_cores);
  uint32_t threads_per_core = warps_per_core * threads_per_warp;
  uint32_t needed_cores = (num_tasks + threads_per_core - 1) / threads_per_core;
  uint32_t active_cores = MIN(needed_cores, num_cores);

  // only active cores participate
  if (core_id >= active_cores)
    return;

  // number of tasks per core
  int tasks_per_core = num_tasks / active_cores;
  int remaining_tasks_per_core = num_tasks - tasks_per_core * active_cores;
  uint32_t tasks_per_core = num_tasks / active_cores;
  uint32_t remaining_tasks_per_core = num_tasks - tasks_per_core * active_cores;
  if (core_id < remaining_tasks_per_core)
    tasks_per_core++;

  // calculate number of warps to activate
  int total_warps_per_core = tasks_per_core / threads_per_warp;
  int remaining_tasks = tasks_per_core - total_warps_per_core * threads_per_warp;
  int active_warps = total_warps_per_core;
  int warp_batches = 1, remaining_warps = 0;
  uint32_t total_warps_per_core = tasks_per_core / threads_per_warp;
  uint32_t remaining_tasks = tasks_per_core - total_warps_per_core * threads_per_warp;
  uint32_t active_warps = total_warps_per_core;
  uint32_t warp_batches = 1, remaining_warps = 0;
  if (active_warps > warps_per_core) {
    active_warps = warps_per_core;
    warp_batches = total_warps_per_core / active_warps;
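vx_spawn_tasks distributes work in two steps: a ceiling division picks how many cores are needed, the first remaining_tasks_per_core cores take one extra task, and each core then splits its share into full warps plus a sub-warp remainder. A host-side sketch of that split with assumed numbers (100 tasks on 4 cores, 4 warps of 8 threads per core):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed workload and device shape.
  uint32_t num_tasks = 100, num_cores = 4, warps_per_core = 4, threads_per_warp = 8;

  uint32_t threads_per_core = warps_per_core * threads_per_warp;                 // 32
  uint32_t needed_cores = (num_tasks + threads_per_core - 1) / threads_per_core; // ceil(100/32) = 4
  uint32_t active_cores = std::min(needed_cores, num_cores);                     // 4

  uint32_t tasks_per_core = num_tasks / active_cores;                            // 25
  uint32_t remaining_tasks_per_core = num_tasks - tasks_per_core * active_cores; // 0

  for (uint32_t core_id = 0; core_id < active_cores; ++core_id) {
    uint32_t my_tasks = tasks_per_core + (core_id < remaining_tasks_per_core);
    uint32_t full_warps = my_tasks / threads_per_warp;                           // 3 full warps
    uint32_t remainder = my_tasks - full_warps * threads_per_warp;               // 1 leftover task
    printf("core %u: %u tasks = %u full warps + %u remainder\n",
           core_id, my_tasks, full_warps, remainder);
  }
  return 0;
}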
@@ -107,8 +107,8 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {

  }

  // calculate offsets for task distribution
  int all_tasks_offset = core_id * tasks_per_core + MIN(core_id, remaining_tasks_per_core);
  int remain_tasks_offset = all_tasks_offset + (tasks_per_core - remaining_tasks);
  uint32_t all_tasks_offset = core_id * tasks_per_core + MIN(core_id, remaining_tasks_per_core);
  uint32_t remain_tasks_offset = all_tasks_offset + (tasks_per_core - remaining_tasks);

  // prepare scheduler arguments
  wspawn_tasks_args_t wspawn_args = {
@@ -137,7 +137,7 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {

  if (remaining_tasks != 0) {
    // activate remaining threads
    int tmask = (1 << remaining_tasks) - 1;
    uint32_t tmask = (1 << remaining_tasks) - 1;
    vx_tmc(tmask);

    // process remaining tasks
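Tasks left over after the full warps are handled by a single warp running with a partial thread mask, so only as many lanes as there are leftover tasks execute the callback. A tiny sketch of how that mask is formed:

#include <cstdint>
#include <cstdio>

int main() {
  for (uint32_t remaining_tasks = 1; remaining_tasks < 8; ++remaining_tasks) {
    uint32_t tmask = (1u << remaining_tasks) - 1;  // one bit per lane that stays active
    printf("remaining_tasks=%u -> tmask=0x%x\n", remaining_tasks, tmask);
  }
  return 0;
}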
@@ -153,51 +153,69 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {

///////////////////////////////////////////////////////////////////////////////

__thread dim3_t blockIdx;
__thread dim3_t threadIdx;
dim3_t gridDim;
dim3_t blockDim;

__thread uint32_t __local_group_id;
uint32_t __groups_per_core;
uint32_t __warps_per_group;

typedef struct {
  vx_spawn_task_groups_cb callback;
  void* arg;
  int group_offset;
  int warp_batches;
  int remaining_warps;
  int warps_per_group;
  int groups_per_core;
  int remaining_mask;
  vx_kernel_func_cb callback;
  const void* arg;
  uint32_t group_offset;
  uint32_t warp_batches;
  uint32_t remaining_warps;
  uint32_t warps_per_group;
  uint32_t groups_per_core;
  uint32_t remaining_mask;
} wspawn_task_groups_args_t;

static void __attribute__ ((noinline)) process_all_task_groups() {
  wspawn_task_groups_args_t* targs = (wspawn_task_groups_args_t*)csr_read(VX_CSR_MSCRATCH);

  int warps_per_group = targs->warps_per_group;
  int groups_per_core = targs->groups_per_core;
  uint32_t threads_per_warp = vx_num_threads();
  uint32_t warp_id = vx_warp_id();
  uint32_t thread_id = vx_thread_id();

  int threads_per_warp = vx_num_threads();
  int warp_id = vx_warp_id();
  int thread_id = vx_thread_id();
  uint32_t warps_per_group = targs->warps_per_group;
  uint32_t groups_per_core = targs->groups_per_core;

  int iterations = targs->warp_batches + (warp_id < targs->remaining_warps);
  uint32_t iterations = targs->warp_batches + (warp_id < targs->remaining_warps);

  int local_group_id = warp_id / warps_per_group;
  int group_warp_id = warp_id - local_group_id * warps_per_group;
  int local_task_id = group_warp_id * threads_per_warp + thread_id;
  uint32_t local_group_id = warp_id / warps_per_group;
  uint32_t group_warp_id = warp_id - local_group_id * warps_per_group;
  uint32_t local_task_id = group_warp_id * threads_per_warp + thread_id;

  int start_group = targs->group_offset + local_group_id;
  int end_group = start_group + iterations * groups_per_core;
  uint32_t start_group = targs->group_offset + local_group_id;
  uint32_t end_group = start_group + iterations * groups_per_core;

  vx_spawn_task_groups_cb callback = targs->callback;
  void* arg = targs->arg;
  __local_group_id = local_group_id;

  for (int group_id = start_group; group_id < end_group; group_id += groups_per_core) {
    callback(local_task_id, group_id, start_group, warps_per_group, arg);
  threadIdx.x = local_task_id % blockDim.x;
  threadIdx.y = (local_task_id / blockDim.x) % blockDim.y;
  threadIdx.z = local_task_id / (blockDim.x * blockDim.y);

  vx_kernel_func_cb callback = targs->callback;
  const void* arg = targs->arg;

  for (uint32_t group_id = start_group; group_id < end_group; group_id += groups_per_core) {
    blockIdx.x = group_id % gridDim.x;
    blockIdx.y = (group_id / gridDim.x) % gridDim.y;
    blockIdx.z = group_id / (gridDim.x * gridDim.y);
    callback(arg);
  }
}

static void __attribute__ ((noinline)) process_all_task_groups_stub() {
  wspawn_task_groups_args_t* targs = (wspawn_task_groups_args_t*)csr_read(VX_CSR_MSCRATCH);
  int warps_per_group = targs->warps_per_group;
  int remaining_mask = targs->remaining_mask;
  int warp_id = vx_warp_id();
  int group_warp_id = warp_id % warps_per_group;
  int threads_mask = (group_warp_id == warps_per_group-1) ? remaining_mask : -1;
  uint32_t warps_per_group = targs->warps_per_group;
  uint32_t remaining_mask = targs->remaining_mask;
  uint32_t warp_id = vx_warp_id();
  uint32_t group_warp_id = warp_id % warps_per_group;
  uint32_t threads_mask = (group_warp_id == warps_per_group-1) ? remaining_mask : -1;

  // activate threads
  vx_tmc(threads_mask);
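The thread-block scheduler no longer passes indices into the callback. Each hardware thread computes threadIdx once from its linear position inside the group, and blockIdx is recomputed from the linear group_id on every iteration, both by row-major decomposition (x fastest, then y, then z). A host-side sketch of that decomposition for an assumed 4x2x2 block:

#include <cstdint>
#include <cstdio>

struct dim3_t { uint32_t x, y, z; };

int main() {
  dim3_t blockDim{4, 2, 2};  // assumed group shape, group_size = 16
  for (uint32_t local_task_id = 0; local_task_id < 16; ++local_task_id) {
    dim3_t t;
    t.x = local_task_id % blockDim.x;                 // fastest-varying
    t.y = (local_task_id / blockDim.x) % blockDim.y;
    t.z = local_task_id / (blockDim.x * blockDim.y);  // slowest-varying
    printf("local_task_id %2u -> threadIdx (%u, %u, %u)\n", local_task_id, t.x, t.y, t.z);
  }
  return 0;
}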
@@ -209,46 +227,62 @@ static void __attribute__ ((noinline)) process_all_task_groups_stub() {

  vx_tmc(0 == vx_warp_id());
}

void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_cb callback, void * arg) {
int vx_spawn_threads(uint32_t dimension,
                     const uint32_t* grid_dim,
                     const uint32_t * block_dim,
                     vx_kernel_func_cb kernel_func,
                     const void* arg) {
  // calculate number of groups and group size
  uint32_t num_groups = 1;
  uint32_t group_size = 1;
  for (uint32_t i = 0; i < 3; ++i) {
    uint32_t gd = (i < dimension) ? grid_dim[i] : 1;
    uint32_t bd = (i < dimension) ? block_dim[i] : 1;
    num_groups *= gd;
    group_size *= bd;
    gridDim.m[i] = gd;
    blockDim.m[i] = bd;
  }

  // device specifications
  int num_cores = vx_num_cores();
  int warps_per_core = vx_num_warps();
  int threads_per_warp = vx_num_threads();
  int core_id = vx_core_id();
  uint32_t num_cores = vx_num_cores();
  uint32_t warps_per_core = vx_num_warps();
  uint32_t threads_per_warp = vx_num_threads();
  uint32_t core_id = vx_core_id();

  // check group size
  int threads_per_core = warps_per_core * threads_per_warp;
  uint32_t threads_per_core = warps_per_core * threads_per_warp;
  if (threads_per_core < group_size) {
    vx_printf("error: group_size > threads_per_core (%d)\n", threads_per_core);
    return;
    vx_printf("error: group_size > threads_per_core (%d, %d)\n", group_size, threads_per_core);
    return -1;
  }

  int warps_per_group = group_size / threads_per_warp;
  int remaining_threads = group_size - warps_per_group * threads_per_warp;
  int remaining_mask = -1;
  uint32_t warps_per_group = group_size / threads_per_warp;
  uint32_t remaining_threads = group_size - warps_per_group * threads_per_warp;
  uint32_t remaining_mask = -1;
  if (remaining_threads != 0) {
    remaining_mask = (1 << remaining_threads) - 1;
    warps_per_group++;
    ++warps_per_group;
  }

  int needed_warps = num_groups * warps_per_group;
  int needed_cores = (needed_warps + warps_per_core-1) / warps_per_core;
  int active_cores = MIN(needed_cores, num_cores);
  uint32_t needed_warps = num_groups * warps_per_group;
  uint32_t needed_cores = (needed_warps + warps_per_core-1) / warps_per_core;
  uint32_t active_cores = MIN(needed_cores, num_cores);

  // only active cores participate
  if (core_id >= active_cores)
    return;
    return 0;

  int total_groups_per_core = num_groups / active_cores;
  int remaining_groups_per_core = num_groups - active_cores * total_groups_per_core;
  uint32_t total_groups_per_core = num_groups / active_cores;
  uint32_t remaining_groups_per_core = num_groups - active_cores * total_groups_per_core;
  if (core_id < remaining_groups_per_core)
    total_groups_per_core++;
    ++total_groups_per_core;

  // calculate number of warps to activate
  int groups_per_core = warps_per_core / warps_per_group;
  int total_warps_per_core = total_groups_per_core * warps_per_group;
  int active_warps = total_warps_per_core;
  int warp_batches = 1, remaining_warps = 0;
  uint32_t groups_per_core = warps_per_core / warps_per_group;
  uint32_t total_warps_per_core = total_groups_per_core * warps_per_group;
  uint32_t active_warps = total_warps_per_core;
  uint32_t warp_batches = 1, remaining_warps = 0;
  if (active_warps > warps_per_core) {
    active_warps = groups_per_core * warps_per_group;
    warp_batches = total_warps_per_core / active_warps;
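vx_spawn_threads flattens the launch: group_size is the product of the block dimensions, num_groups the product of the grid dimensions, and each group occupies warps_per_group consecutive warps (rounded up when group_size is not a multiple of the warp width), which is the count __syncthreads later barriers on. A worked sketch with assumed dimensions (an 8x8 grid of 8x8 blocks on a core with 16 warps of 16 threads):

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed 2D launch and device shape.
  uint32_t grid_dim[2]  = {8, 8};
  uint32_t block_dim[2] = {8, 8};
  uint32_t threads_per_warp = 16, warps_per_core = 16;

  uint32_t num_groups = grid_dim[0] * grid_dim[1];    // 64 thread blocks
  uint32_t group_size = block_dim[0] * block_dim[1];  // 64 threads per block

  uint32_t warps_per_group = group_size / threads_per_warp;                     // 4
  uint32_t remaining_threads = group_size - warps_per_group * threads_per_warp; // 0
  if (remaining_threads != 0)
    ++warps_per_group;  // last warp of each group would run partially masked

  uint32_t groups_per_core = warps_per_core / warps_per_group;                  // 4 resident groups
  printf("num_groups=%u group_size=%u warps_per_group=%u groups_per_core=%u\n",
         num_groups, group_size, warps_per_group, groups_per_core);
  return 0;
}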
@@ -256,11 +290,11 @@ void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_c

  }

  // calculate offsets for group distribution
  int group_offset = core_id * total_groups_per_core + MIN(core_id, remaining_groups_per_core);
  uint32_t group_offset = core_id * total_groups_per_core + MIN(core_id, remaining_groups_per_core);

  // prepare scheduler arguments
  // set scheduler arguments
  wspawn_task_groups_args_t wspawn_args = {
    callback,
    kernel_func,
    arg,
    group_offset,
    warp_batches,
@@ -271,6 +305,10 @@ void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_c

  };
  csr_write(VX_CSR_MSCRATCH, &wspawn_args);

  // set global variables
  __groups_per_core = groups_per_core;
  __warps_per_group = warps_per_group;

  // execute callback on other warps
  vx_wspawn(active_warps, process_all_task_groups_stub);
@@ -279,6 +317,8 @@ void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_c

  // wait for spawned tasks to complete
  vx_wspawn(1, 0);

  return 0;
}

#ifdef __cplusplus
@@ -2,14 +2,14 @@

#define _COMMON_H_

#ifndef TYPE
#define TYPE int
#define TYPE float
#endif

typedef struct {
  uint32_t num_groups;
  uint32_t group_size;
  uint32_t size;
  uint32_t tile_size;
  uint32_t grid_dim[2];
  uint32_t block_dim[2];
  uint64_t A_addr;
  uint64_t B_addr;
  uint64_t C_addr;
@@ -1,58 +1,53 @@

#include <stdint.h>
#include <vx_intrinsics.h>
#include <vx_spawn.h>
#include <vx_print.h>
#include "common.h"

void kernel_body(int local_task_id, int group_id, int local_group_id, int warps_per_group, kernel_arg_t *arg) {
  auto A_ptr = reinterpret_cast<TYPE*>(arg->A_addr);
  auto B_ptr = reinterpret_cast<TYPE*>(arg->B_addr);
  auto C_ptr = reinterpret_cast<TYPE*>(arg->C_addr);
  auto size = arg->size;
void kernel_body(kernel_arg_t *arg) {
  // Setup buffer arguments
  auto A_ptr = reinterpret_cast<TYPE*>(arg->A_addr);
  auto B_ptr = reinterpret_cast<TYPE*>(arg->B_addr);
  auto C_ptr = reinterpret_cast<TYPE*>(arg->C_addr);

  // Allocate local memory for the tile of matrix A & B
  auto local_ptr = __local_mem(2 * blockDim.x * blockDim.y * sizeof(TYPE));
  auto local_A = (TYPE*)local_ptr;
  auto local_B = (TYPE*)local_ptr + blockDim.x * blockDim.y;

  auto size = arg->size;
  auto tile_size = arg->tile_size;
  auto num_groups = arg->num_groups;
  auto group_size = arg->group_size;
  auto num_tiles = size / tile_size;
  auto local_mem = vx_local_malloc(local_group_id, group_size * 2);

  // Determine row and column indices of the current subtask
  auto l_row = local_task_id / tile_size;
  auto l_col = local_task_id % tile_size;
  // Determine global row and column indices
  auto g_row = blockIdx.x * blockDim.x + threadIdx.x;
  auto g_col = blockIdx.y * blockDim.y + threadIdx.y;

  // Determine row and column indices of the current task
  auto g_row = (group_id / num_tiles) * tile_size + l_row;
  auto g_col = (group_id % num_tiles) * tile_size + l_col;
  // Determine local row and column indices
  auto l_row = threadIdx.x;
  auto l_col = threadIdx.y;

  // Allocate local memory for the tile of matrix A & B
  auto local_A = (TYPE*)local_mem;
  auto local_B = local_A + group_size;
  TYPE sum(0);

  TYPE sum(0);
  // Loop over tiles
  for (uint32_t k = 0; k < size; k += tile_size) {
    // Load tile of matrix A & B to local memory
    local_A[l_row * tile_size + l_col] = A_ptr[g_row * size + (k + l_col)];
    local_B[l_row * tile_size + l_col] = B_ptr[(k + l_row) * size + g_col];

  // Loop over tiles
  for (uint32_t k = 0; k < size; k += tile_size) {
    // Load tile of matrix A & B to local memory
    local_A[l_row * tile_size + l_col] = A_ptr[g_row * size + (k + l_col)];
    local_B[l_row * tile_size + l_col] = B_ptr[(k + l_row) * size + g_col];
    // Synchronize all warps in current group
    __syncthreads();

    // Synchronize all warps in current group
    vx_barrier(local_group_id * 2 + 0, warps_per_group);
    // Compute partial sum for the local tile
    for (uint32_t j = 0; j < tile_size; ++j) {
      sum += local_A[l_row * tile_size + j] * local_B[j * tile_size + l_col];
    }

    // Compute partial sum for the local tile
    for (uint32_t j = 0; j < tile_size; ++j) {
      sum += local_A[l_row * tile_size + j] * local_B[j * tile_size + l_col];
    }
    // Synchronize all warps in current group
    __syncthreads();
  }

  // Synchronize all warps in current group
  vx_barrier(local_group_id * 2 + 1, warps_per_group);
}

  // Store the computed sum into the result matrix C
  C_ptr[g_row * size + g_col] = sum;
  // Store the computed sum into the result matrix C
  C_ptr[g_row * size + g_col] = sum;
}

int main() {
  kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH);
  vx_spawn_task_groups(arg->num_groups, arg->group_size, (vx_spawn_task_groups_cb)kernel_body, arg);
  return 0;
  auto arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH);
  return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg);
}
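With blockDim.x = blockDim.y = tile_size, the rewritten kernel's indexing reduces to the old one: g_row = blockIdx.x * tile_size + threadIdx.x, and each k step accumulates a tile_size-long slice of the dot product from the staged tiles. A small host-side check (assumed 4x4 float matrices, tile_size 2) that the tiled accumulation matches the direct dot product used by matmul_cpu:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const uint32_t size = 4, tile_size = 2;  // assumed problem shape
  std::vector<float> A(size * size), B(size * size);
  for (uint32_t i = 0; i < size * size; ++i) {
    A[i] = float(i + 1);
    B[i] = float(2 * i + 1);
  }

  for (uint32_t g_row = 0; g_row < size; ++g_row) {
    for (uint32_t g_col = 0; g_col < size; ++g_col) {
      // Tiled accumulation, as in the kernel (the local-memory staging is elided on the host).
      float tiled = 0.0f;
      for (uint32_t k = 0; k < size; k += tile_size)
        for (uint32_t j = 0; j < tile_size; ++j)
          tiled += A[g_row * size + (k + j)] * B[(k + j) * size + g_col];

      // Direct dot product, as in matmul_cpu.
      float direct = 0.0f;
      for (uint32_t e = 0; e < size; ++e)
        direct += A[g_row * size + e] * B[e * size + g_col];

      assert(tiled == direct);  // same terms accumulated in the same order
    }
  }
  return 0;
}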
@@ -29,9 +29,7 @@ public:

    return "integer";
  }
  static int generate() {
    static int q(1);
    return q++;
    //return rand();
    return rand();
  }
  static bool compare(int a, int b, int index, int errors) {
    if (a != b) {
@@ -80,7 +78,6 @@ static void matmul_cpu(TYPE* out, const TYPE* A, const TYPE* B, uint32_t width,

      TYPE b = B[e * width + col];
      TYPE c = a * b;
      sum += c;
      //printf("out[%d][%d]=%d; a=%d, b=%d, c=%d\n", row, col, sum, a, b, c);
    }
    out[row * width + col] = sum;
  }
@@ -157,20 +154,18 @@ int main(int argc, char *argv[]) {

  uint32_t size_sq = size * size;
  uint32_t buf_size = size_sq * sizeof(TYPE);

  uint32_t group_size = tile_size * tile_size;
  uint32_t num_groups = size_sq / group_size;
  uint32_t local_mem = 2 * group_size * sizeof(TYPE);

  std::cout << "data type: " << Comparator<TYPE>::type_str() << std::endl;
  std::cout << "matrix size: " << size << "x" << size << std::endl;
  std::cout << "tile size: " << tile_size << "x" << tile_size << std::endl;
  std::cout << "group size: " << group_size << std::endl;
  std::cout << "number of groups: " << num_groups << std::endl;
  std::cout << "local memory: " << local_mem << " bytes" << std::endl;

  kernel_arg.num_groups = num_groups;
  kernel_arg.group_size = group_size;
  kernel_arg.grid_dim[0] = size / tile_size;
  kernel_arg.grid_dim[1] = size / tile_size;
  kernel_arg.block_dim[0] = tile_size;
  kernel_arg.block_dim[1] = tile_size;
  kernel_arg.size = size;
  kernel_arg.tile_size = tile_size;
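On the host side the launch geometry follows directly from the matrix and tile sizes: one block per output tile, tile_size x tile_size threads per block, and two tiles of local memory per group. A sketch of that arithmetic for an assumed 64x64 matrix with 8x8 tiles:

#include <cstdint>
#include <cstdio>

int main() {
  uint32_t size = 64, tile_size = 8;  // assumed problem shape

  uint32_t grid_dim[2]  = {size / tile_size, size / tile_size};  // 8 x 8 blocks
  uint32_t block_dim[2] = {tile_size, tile_size};                // 8 x 8 threads per block
  uint32_t group_size = tile_size * tile_size;                   // 64
  uint32_t num_groups = (size * size) / group_size;              // 64
  uint32_t local_mem  = 2 * group_size * sizeof(float);          // A and B tiles: 512 bytes

  printf("grid %ux%u, block %ux%u, groups=%u, local memory=%u bytes\n",
         grid_dim[0], grid_dim[1], block_dim[0], block_dim[1], num_groups, local_mem);
  return 0;
}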