SIMT stack fix

This commit is contained in:
Blaise Tine 2024-05-01 20:50:21 -07:00
parent 896aca0c62
commit b5ca7a999c
3 changed files with 21 additions and 6 deletions

View file

@ -57,6 +57,7 @@ module VX_schedule import VX_gpu_pkg::*; #(
wire schedule_ready;
// split/join
wire [`NUM_THREADS-1:0] split_tmask;
wire join_valid;
wire join_is_dvg;
wire join_is_else;
@ -137,7 +138,7 @@ module VX_schedule import VX_gpu_pkg::*; #(
// split handling
if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin
if (warp_ctl_if.split.is_dvg) begin
thread_masks_n[warp_ctl_if.wid] = warp_ctl_if.split.then_tmask;
thread_masks_n[warp_ctl_if.wid] = split_tmask;
end
stalled_warps_n[warp_ctl_if.wid] = 0; // unlock warp
end
@ -295,6 +296,7 @@ module VX_schedule import VX_gpu_pkg::*; #(
.wid (warp_ctl_if.wid),
.split (warp_ctl_if.split),
.sjoin (warp_ctl_if.sjoin),
.split_tmask(split_tmask),
.join_valid (join_valid),
.join_is_dvg(join_is_dvg),
.join_is_else(join_is_else),

View file

@ -22,6 +22,7 @@ module VX_split_join import VX_gpu_pkg::*; #(
input wire [`NW_WIDTH-1:0] wid,
input split_t split,
input join_t sjoin,
output wire [`NUM_THREADS-1:0] split_tmask,
output wire join_valid,
output wire join_is_dvg,
output wire join_is_else,
@ -37,8 +38,15 @@ module VX_split_join import VX_gpu_pkg::*; #(
wire [`DV_STACK_SIZEW-1:0] ipdom_q_ptr [`NUM_WARPS-1:0];
wire ipdom_set [`NUM_WARPS-1:0];
wire [`CLOG2(`NUM_THREADS+1)-1:0] then_tmask_cnt, else_tmask_cnt;
`POP_COUNT(then_tmask_cnt, split.then_tmask);
`POP_COUNT(else_tmask_cnt, split.else_tmask);
wire then_first = (then_tmask_cnt >= else_tmask_cnt);
assign split_tmask = then_first ? split.then_tmask : split.else_tmask;
wire [`NUM_THREADS-1:0] ntaken_tmask = then_first ? split.else_tmask : split.then_tmask;
wire [(`XLEN+`NUM_THREADS)-1:0] ipdom_q0 = {split.then_tmask | split.else_tmask, `XLEN'(0)};
wire [(`XLEN+`NUM_THREADS)-1:0] ipdom_q1 = {split.else_tmask, split.next_pc};
wire [(`XLEN+`NUM_THREADS)-1:0] ipdom_q1 = {ntaken_tmask, split.next_pc};
wire sjoin_is_dvg = (sjoin.stack_ptr != ipdom_q_ptr[wid]);

View file

@ -1325,12 +1325,17 @@ void Emulator::execute(const Instr &instr, uint32_t wid, instr_trace_t *trace) {
std::cout << "IPDOM stack is full! size=" << std::dec << stack_size << ", PC=0x" << std::hex << warp.PC << " (#" << std::dec << trace->uuid << ")\n" << std::flush;
std::abort();
}
// set new thread mask
next_tmask = then_tmask;
// set new thread mask to the larger set
if (then_tmask.count() >= else_tmask.count()) {
next_tmask = then_tmask;
} else {
next_tmask = else_tmask;
}
// push reconvergence thread mask onto the stack
warp.ipdom_stack.emplace(warp.tmask);
// push else's thread mask onto the stack
warp.ipdom_stack.emplace(else_tmask, next_pc);
// push not taken thread mask onto the stack
auto ntaken_tmask = ~next_tmask & warp.tmask;
warp.ipdom_stack.emplace(ntaken_tmask, next_pc);
}
// return divergent state
for (uint32_t t = thread_start; t < num_threads; ++t) {