fix split/join hardware

This commit is contained in:
Blaise Tine 2021-06-15 16:06:54 -04:00
parent f7bc14a2ec
commit 288f0c976b

View file

@ -21,7 +21,7 @@ module VX_warp_sched #(
`UNUSED_PARAM (CORE_ID)
wire join_fall;
wire join_else;
wire [31:0] join_pc;
wire [`NUM_THREADS-1:0] join_tm;
@ -45,8 +45,6 @@ module VX_warp_sched #(
reg [`NW_BITS-1:0] scheduled_warp;
wire warp_scheduled;
reg didnt_split;
wire ifetch_rsp_fire = ifetch_rsp_if.valid && ifetch_rsp_if.ready;
always @(*) begin
@ -82,7 +80,6 @@ module VX_warp_sched #(
schedule_table[0] <= 1; // set first warp as ready
thread_masks[0] <= 1; // Activating first thread in first warp
stalled_warps <= 0;
didnt_split <= 0;
fetch_lock <= 0;
for (integer i = 1; i < `NUM_WARPS; i++) begin
@ -107,19 +104,10 @@ module VX_warp_sched #(
end else if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask;
stalled_warps[warp_ctl_if.wid] <= 0;
end else if (join_if.valid && !didnt_split) begin
if (!join_fall) begin
warp_pcs[join_if.wid] <= join_pc;
end
thread_masks[join_if.wid] <= join_tm;
didnt_split <= 0;
end else if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin
stalled_warps[warp_ctl_if.wid] <= 0;
if (warp_ctl_if.split.diverged) begin
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_mask;
didnt_split <= 0;
end else begin
didnt_split <= 1;
end
end
@ -150,6 +138,14 @@ module VX_warp_sched #(
warp_pcs[ifetch_rsp_if.wid] <= ifetch_rsp_if.PC + 4;
end
// join handling
if (join_if.valid) begin
if (join_else) begin
warp_pcs[join_if.wid] <= join_pc;
end
thread_masks[join_if.wid] <= join_tm;
end
active_warps <= active_warps_n;
// reset 'schedule_table' when it goes to zero
@ -174,22 +170,21 @@ module VX_warp_sched #(
end
end
// split/join stack management
// split/join stack management
wire [(1+32+`NUM_THREADS-1):0] ipdom [`NUM_WARPS-1:0];
wire [(1+32+`NUM_THREADS-1):0] q1 = {1'b1, 32'b0, thread_masks[warp_ctl_if.wid]};
wire [(1+32+`NUM_THREADS-1):0] q2 = {1'b0, warp_ctl_if.split.pc, warp_ctl_if.split.else_mask};
assign {join_fall, join_pc, join_tm} = ipdom [join_if.wid];
for (genvar i = 0; i < `NUM_WARPS; i++) begin
wire push = warp_ctl_if.valid
&& warp_ctl_if.split.valid
&& warp_ctl_if.split.diverged
&& warp_ctl_if.split.valid
&& (i == warp_ctl_if.wid);
wire pop = join_if.valid && (i == join_if.wid);
wire [`NUM_THREADS-1:0] else_mask = warp_ctl_if.split.diverged ? warp_ctl_if.split.else_mask : thread_masks[warp_ctl_if.wid];
wire [(1+32+`NUM_THREADS-1):0] q_end = {1'b0, 32'b0, thread_masks[warp_ctl_if.wid]};
wire [(1+32+`NUM_THREADS-1):0] q_else = {1'b1, warp_ctl_if.split.pc, else_mask};
VX_ipdom_stack #(
.WIDTH (1+32+`NUM_THREADS),
.DEPTH (2 ** (`NT_BITS+1))
@ -198,14 +193,16 @@ module VX_warp_sched #(
.reset (reset),
.push (push),
.pop (pop),
.q1 (q1),
.q2 (q2),
.q1 (q_end),
.q2 (q_else),
.d (ipdom[i]),
`UNUSED_PIN (empty),
`UNUSED_PIN (full)
);
end
assign {join_else, join_pc, join_tm} = ipdom [join_if.wid];
// calculate next warp schedule
reg [`NUM_THREADS-1:0] thread_mask;