wspawn fix

This commit is contained in:
Blaise Tine 2024-05-10 21:42:20 -07:00
parent d1ba02681e
commit 4e7bc9654b
2 changed files with 9 additions and 8 deletions

View file

@ -96,12 +96,14 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
tasks_per_core++;
// calculate number of warps to activate
int active_warps = tasks_per_core / threads_per_warp;
int remaining_tasks = tasks_per_core - active_warps * threads_per_warp;
int total_warps_per_core = tasks_per_core / threads_per_warp;
int remaining_tasks = tasks_per_core - total_warps_per_core * threads_per_warp;
int active_warps = total_warps_per_core;
int warp_batches = 1, remaining_warps = 0;
if (active_warps > warps_per_core) {
warp_batches = active_warps / warps_per_core;
remaining_warps = active_warps - warp_batches * warps_per_core;
active_warps = warps_per_core;
warp_batches = total_warps_per_core / active_warps;
remaining_warps = total_warps_per_core - warp_batches * active_warps;
}
// calculate offsets for task distribution
@ -121,8 +123,7 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
if (active_warps >= 1) {
// execute callback on other warps
int num_total_warps = MIN(active_warps, warps_per_core);
vx_wspawn(num_total_warps, process_all_tasks_stub);
vx_wspawn(active_warps, process_all_tasks_stub);
// activate all threads
vx_tmc(-1);
@ -252,8 +253,8 @@ void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_c
// calculate number of warps to activate
int groups_per_core = warps_per_core / warps_per_group;
int total_warps_per_core = total_groups_per_core * warps_per_group;
int warp_batches = 1, remaining_warps = 0;
int active_warps = total_warps_per_core;
int warp_batches = 1, remaining_warps = 0;
if (active_warps > warps_per_core) {
active_warps = groups_per_core * warps_per_group;
warp_batches = total_warps_per_core / active_warps;

View file

@ -226,7 +226,7 @@ void Emulator::resume(uint32_t wid) {
bool Emulator::wspawn(uint32_t num_warps, Word nextPC) {
num_warps = std::min<uint32_t>(num_warps, arch_.num_warps());
if (num_warps < 2)
if (num_warps < 2 && active_warps_.count() == 1)
return true;
wspawn_.valid = true;
wspawn_.num_warps = num_warps;