per-workgroup local memory fix

This commit is contained in:
Blaise Tine 2024-04-06 02:05:51 -07:00
parent 3534175d43
commit 351aa48f6e
2 changed files with 65 additions and 65 deletions

View file

@ -25,18 +25,19 @@ typedef struct {
uint32_t num_groups[3];
uint32_t global_offset[3];
uint32_t local_size[3];
char * printf_buffer;
uint32_t *printf_buffer_position;
uint32_t printf_buffer;
uint32_t printf_buffer_position;
uint32_t printf_buffer_capacity;
uint32_t work_dim;
} context_t;
} pocl_kernel_context_t;
typedef void (*vx_spawn_kernel_cb) (
typedef void (*pocl_kernel_cb) (
const void * /* arg */,
const context_t * /* context */,
const pocl_kernel_context_t * /* context */,
uint32_t /* group_x */,
uint32_t /* group_y */,
uint32_t /* group_z */
uint32_t /* group_z */,
uint32_t /* local_offset */
);
typedef void (*vx_spawn_tasks_cb)(int task_id, void *arg);
@ -45,7 +46,7 @@ typedef void (*vx_serial_cb)(void *arg);
void vx_wspawn_wait();
void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg);
void vx_spawn_pocl_kernel(pocl_kernel_context_t * ctx, pocl_kernel_cb callback, void * arg);
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);

View file

@ -14,6 +14,7 @@
#include <vx_spawn.h>
#include <vx_intrinsics.h>
#include <inttypes.h>
#include <vx_print.h>
#ifdef __cplusplus
extern "C" {
@ -29,21 +30,24 @@ typedef struct {
vx_spawn_tasks_cb callback;
void* arg;
int offset; // task offset
int remain; // remaining offset
int FWs; // number of NW batches where NW=<total warps per core>.
int RWs; // number of remaining warps in the core
int RWs; // number of remaining warps in the core
} wspawn_tasks_args_t;
typedef struct {
context_t * ctx;
vx_spawn_kernel_cb callback;
pocl_kernel_context_t * ctx;
pocl_kernel_cb callback;
void* arg;
int offset; // task offset
int local_size;
int offset; // task offset
int remain; // remaining offset
int FWs; // number of NW batches where NW=<total warps per core>.
int RWs; // number of remaining warps in the core
int RWs; // number of remaining warps in the core
char isXYpow2;
char log2XY;
char log2X;
} wspawn_kernel_args_t;
} wspawn_pocl_kernel_args_t;
void* g_wspawn_args[NUM_CORES_MAX];
@ -79,7 +83,7 @@ static void __attribute__ ((noinline)) spawn_tasks_rem_stub() {
int tid = vx_thread_id();
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
int task_id = p_wspawn_args->offset + tid;
int task_id = p_wspawn_args->remain + tid;
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
}
@ -129,7 +133,10 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
rW = TW - fW * NW; // remaining warps
}
wspawn_tasks_args_t wspawn_args = { callback, arg, core_id * tasks_per_core, fW, rW };
int offset = core_id * tasks_per_core;
int remain = offset + (tasks_per_core_n1 - rT);
wspawn_tasks_args_t wspawn_args = { callback, arg, offset, remain, fW, rW};
g_wspawn_args[core_id] = &wspawn_args;
if (TW >= 1) {
@ -137,23 +144,11 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
int nw = MIN(TW, NW);
vx_wspawn(nw, spawn_tasks_all_cb);
// activate all threads
vx_tmc(-1);
// call stub routine
spawn_tasks_all_stub();
// back to single-threaded
vx_tmc_one();
// wait for spawn warps to terminate
vx_wspawn_wait();
// execute callback on warp 0
spawn_tasks_all_cb();
}
if (rT != 0) {
// adjust offset
wspawn_args.offset += (tasks_per_core_n1 - rT);
// activate remaining threads
int tmask = (1 << rT) - 1;
vx_tmc(tmask);
@ -164,24 +159,29 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
// back to single-threaded
vx_tmc_one();
}
// wait for spawn warps to terminate
vx_wspawn_wait();
}
///////////////////////////////////////////////////////////////////////////////
static void __attribute__ ((noinline)) spawn_kernel_all_stub() {
static void __attribute__ ((noinline)) spawn_pocl_kernel_all_stub() {
int NT = vx_num_threads();
int cid = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
wspawn_kernel_args_t* p_wspawn_args = (wspawn_kernel_args_t*)g_wspawn_args[cid];
wspawn_pocl_kernel_args_t* p_wspawn_args = (wspawn_pocl_kernel_args_t*)g_wspawn_args[cid];
pocl_kernel_context_t* ctx = p_wspawn_args->ctx;
void* arg = p_wspawn_args->arg;
int wK = (p_wspawn_args->FWs * wid) + MIN(p_wspawn_args->RWs, wid);
int tK = p_wspawn_args->FWs + (wid < p_wspawn_args->RWs);
int offset = p_wspawn_args->offset + (wK * NT) + tid;
int X = p_wspawn_args->ctx->num_groups[0];
int Y = p_wspawn_args->ctx->num_groups[1];
int X = ctx->num_groups[0];
int Y = ctx->num_groups[1];
int XY = X * Y;
if (p_wspawn_args->isXYpow2) {
@ -190,7 +190,8 @@ static void __attribute__ ((noinline)) spawn_kernel_all_stub() {
int wg_2d = wg_id - k * XY;
int j = wg_2d >> p_wspawn_args->log2X;
int i = wg_2d - j * X;
(p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, i, j, k);
int local_offset = wg_id * p_wspawn_args->local_size;
(p_wspawn_args->callback)(arg, ctx, i, j, k, local_offset);
}
} else {
for (int wg_id = offset, N = wg_id + tK * NT; wg_id < N; wg_id += NT ) {
@ -198,50 +199,54 @@ static void __attribute__ ((noinline)) spawn_kernel_all_stub() {
int wg_2d = wg_id - k * XY;
int j = wg_2d / X;
int i = wg_2d - j * X;
(p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, i, j, k);
int local_offset = wg_id * p_wspawn_args->local_size;
(p_wspawn_args->callback)(arg, ctx, i, j, k, local_offset);
}
}
}
static void __attribute__ ((noinline)) spawn_kernel_rem_stub() {
static void __attribute__ ((noinline)) spawn_pocl_kernel_rem_stub() {
int cid = vx_core_id();
int tid = vx_thread_id();
wspawn_kernel_args_t* p_wspawn_args = (wspawn_kernel_args_t*)g_wspawn_args[cid];
wspawn_pocl_kernel_args_t* p_wspawn_args = (wspawn_pocl_kernel_args_t*)g_wspawn_args[cid];
pocl_kernel_context_t* ctx = p_wspawn_args->ctx;
void* arg = p_wspawn_args->arg;
int wg_id = p_wspawn_args->offset + tid;
int X = p_wspawn_args->ctx->num_groups[0];
int Y = p_wspawn_args->ctx->num_groups[1];
int X = ctx->num_groups[0];
int Y = ctx->num_groups[1];
int XY = X * Y;
int wg_id = p_wspawn_args->remain + tid;
int local_offset = wg_id * p_wspawn_args->local_size;
if (p_wspawn_args->isXYpow2) {
int k = wg_id >> p_wspawn_args->log2XY;
int wg_2d = wg_id - k * XY;
int j = wg_2d >> p_wspawn_args->log2X;
int i = wg_2d - j * X;
(p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, i, j, k);
(p_wspawn_args->callback)(arg, ctx, i, j, k, local_offset);
} else {
int k = wg_id / XY;
int wg_2d = wg_id - k * XY;
int j = wg_2d / X;
int i = wg_2d - j * X;
(p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, i, j, k);
(p_wspawn_args->callback)(arg, ctx, i, j, k, local_offset);
}
}
static void __attribute__ ((noinline)) spawn_kernel_all_cb() {
static void __attribute__ ((noinline)) spawn_pocl_kernel_all_cb() {
// activate all threads
vx_tmc(-1);
// call stub routine
spawn_kernel_all_stub();
spawn_pocl_kernel_all_stub();
// disable warp
vx_tmc_zero();
}
void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg) {
void vx_spawn_pocl_kernel(pocl_kernel_context_t * ctx, pocl_kernel_cb callback, void * arg) {
// total number of WGs
int X = ctx->num_groups[0];
int Y = ctx->num_groups[1];
@ -288,44 +293,38 @@ void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg) {
char log2XY = log2_fast(XY);
char log2X = log2_fast(X);
wspawn_kernel_args_t wspawn_args = {
ctx, callback, arg, core_id * tasks_per_core, fW, rW, isXYpow2, log2XY, log2X
int local_size = ctx->local_size[0] * ctx->local_size[1] * ctx->local_size[2];
int offset = core_id * tasks_per_core;
int remain = offset + (tasks_per_core_n1 - rT);
wspawn_pocl_kernel_args_t wspawn_args = {
ctx, callback, arg, local_size, offset, remain, fW, rW, isXYpow2, log2XY, log2X
};
g_wspawn_args[core_id] = &wspawn_args;
if (TW >= 1) {
// execute callback on other warps
int nw = MIN(TW, NW);
vx_wspawn(nw, spawn_kernel_all_cb);
vx_wspawn(nw, spawn_pocl_kernel_all_cb);
// activate all threads
vx_tmc(-1);
// call stub routine
asm volatile("" ::: "memory");
spawn_kernel_all_stub();
// back to single-threaded
vx_tmc_one();
// wait for spawn warps to terminate
vx_wspawn_wait();
// execute callback on warp 0
spawn_pocl_kernel_all_cb();
}
if (rT != 0) {
// adjust offset
wspawn_args.offset += (tasks_per_core_n1 - rT);
// activate remaining threads
int tmask = (1 << rT) - 1;
vx_tmc(tmask);
// call stub routine
spawn_kernel_rem_stub();
spawn_pocl_kernel_rem_stub();
// back to single-threaded
vx_tmc_one();
}
// wait for spawn warps to terminate
vx_wspawn_wait();
}
#ifdef __cplusplus