wspawn thread index reordering

This commit is contained in:
Blaise Tine 2024-03-04 20:28:45 -08:00
parent 589e351832
commit de8453d0be

View file

@ -29,7 +29,7 @@ typedef struct {
vx_spawn_tasks_cb callback;
void* arg;
int offset; // task offset
int NWs; // number of NW batches where NW=<total warps per core>.
int FWs; // number of NW batches where NW=<total warps per core>.
int RWs; // number of remaining warps in the core
} wspawn_tasks_args_t;
@ -38,7 +38,7 @@ typedef struct {
vx_spawn_kernel_cb callback;
void* arg;
int offset; // task offset
int NWs; // number of NW batches where NW=<total warps per core>.
int FWs; // number of NW batches where NW=<total warps per core>.
int RWs; // number of remaining warps in the core
char isXYpow2;
char log2XY;
@ -63,13 +63,13 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
int wK = (p_wspawn_args->NWs * wid) + MIN(p_wspawn_args->RWs, wid);
int tK = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs);
int offset = p_wspawn_args->offset + (wK * NT) + (tid * tK);
int wK = (p_wspawn_args->FWs * wid) + MIN(p_wspawn_args->RWs, wid);
int tK = p_wspawn_args->FWs + (wid < p_wspawn_args->RWs);
int offset = p_wspawn_args->offset + (wK * NT) + tid;
vx_spawn_tasks_cb callback = p_wspawn_args->callback;
void* arg = p_wspawn_args->arg;
for (int task_id = offset, N = task_id + tK; task_id < N; ++task_id) {
for (int task_id = offset, N = offset + tK * NT; task_id < N; task_id += NT) {
callback(task_id, arg);
}
}
@ -176,16 +176,16 @@ static void __attribute__ ((noinline)) spawn_kernel_all_stub() {
wspawn_kernel_args_t* p_wspawn_args = (wspawn_kernel_args_t*)g_wspawn_args[cid];
int wK = (p_wspawn_args->NWs * wid) + MIN(p_wspawn_args->RWs, wid);
int tK = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs);
int offset = p_wspawn_args->offset + (wK * NT) + (tid * tK);
int wK = (p_wspawn_args->FWs * wid) + MIN(p_wspawn_args->RWs, wid);
int tK = p_wspawn_args->FWs + (wid < p_wspawn_args->RWs);
int offset = p_wspawn_args->offset + (wK * NT) + tid;
int X = p_wspawn_args->ctx->num_groups[0];
int Y = p_wspawn_args->ctx->num_groups[1];
int XY = X * Y;
if (p_wspawn_args->isXYpow2) {
for (int wg_id = offset, N = wg_id + tK; wg_id < N; ++wg_id) {
for (int wg_id = offset, N = wg_id + tK * NT; wg_id < N; wg_id += NT ) {
int k = wg_id >> p_wspawn_args->log2XY;
int wg_2d = wg_id - k * XY;
int j = wg_2d >> p_wspawn_args->log2X;
@ -193,7 +193,7 @@ static void __attribute__ ((noinline)) spawn_kernel_all_stub() {
(p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, i, j, k);
}
} else {
for (int wg_id = offset, N = wg_id + tK; wg_id < N; ++wg_id) {
for (int wg_id = offset, N = wg_id + tK * NT; wg_id < N; wg_id += NT ) {
int k = wg_id / XY;
int wg_2d = wg_id - k * XY;
int j = wg_2d / X;