warp scheduler optimization

This commit is contained in:
Blaise Tine 2021-08-07 23:45:01 -07:00
parent 5b8e58e15e
commit b1eef0fb7c
6 changed files with 66 additions and 66 deletions

View file

@ -221,6 +221,7 @@ module VX_decode #(
use_rd = 1;
use_imm = 1;
use_PC = 1;
is_wstall = 1;
imm = 32'd4;
`USED_IREG (rd);
end
@ -406,8 +407,9 @@ module VX_decode #(
assign join_if.valid = ifetch_rsp_fire && is_join;
assign join_if.wid = ifetch_rsp_if.wid;
assign wstall_if.valid = ifetch_rsp_fire && is_wstall;
assign wstall_if.valid = ifetch_rsp_fire;
assign wstall_if.wid = ifetch_rsp_if.wid;
assign wstall_if.stalled = is_wstall;
assign ifetch_rsp_if.ready = decode_if.ready;

View file

@ -44,7 +44,6 @@ module VX_fetch #(
.branch_ctl_if (branch_ctl_if),
.ifetch_req_if (ifetch_req_if),
.ifetch_rsp_if (ifetch_rsp_if),
.fetch_to_csr_if (fetch_to_csr_if),

View file

@ -38,7 +38,7 @@ task print_ex_op (
`BR_MRET: $write("MRET");
`BR_SRET: $write("SRET");
`BR_DRET: $write("DRET");
default: $write("?");
default: $write("?");
endcase
end else if (`ALU_IS_MUL(op_mod)) begin
case (`MUL_BITS'(op_type))

View file

@ -13,7 +13,6 @@ module VX_warp_sched #(
VX_join_if join_if,
VX_branch_ctl_if branch_ctl_if,
VX_ifetch_rsp_if ifetch_rsp_if,
VX_ifetch_req_if ifetch_req_if,
VX_fetch_to_csr_if fetch_to_csr_if,
@ -30,26 +29,25 @@ module VX_warp_sched #(
reg [`NUM_WARPS-1:0] active_warps, active_warps_n; // real active warps (updated when a warp is activated or disabled)
reg [`NUM_WARPS-1:0] stalled_warps; // asserted when a branch/gpgpu instruction is issued
// Lock warp until instruction decode to resolve branches
reg [`NUM_WARPS-1:0] fetch_lock;
reg [`NUM_WARPS-1:0][`NUM_THREADS-1:0] thread_masks;
reg [`NUM_WARPS-1:0][31:0] warp_pcs, warp_next_pcs;
// barriers
reg [`NUM_BARRIERS-1:0][`NUM_WARPS-1:0] barrier_stall_mask; // warps waiting on barrier
reg [`NUM_BARRIERS-1:0][`NUM_WARPS-1:0] barrier_masks; // warps waiting on barrier
wire reached_barrier_limit; // the expected number of warps reached the barrier
// wspawn
reg [31:0] use_wspawn_pc;
reg [31:0] wspawn_pc;
reg [`NUM_WARPS-1:0] use_wspawn;
wire [`NW_BITS-1:0] schedule_warp;
wire [`NW_BITS-1:0] schedule_wid;
wire [`NUM_THREADS-1:0] schedule_tmask;
wire [31:0] schedule_pc;
wire schedule_valid;
wire warp_scheduled;
wire ifetch_req_fire = ifetch_req_if.valid && ifetch_req_if.ready;
wire ifetch_rsp_fire = ifetch_rsp_if.valid && ifetch_rsp_if.ready;
wire tmc_active = (warp_ctl_if.tmc.tmask != 0);
always @(*) begin
@ -64,55 +62,44 @@ module VX_warp_sched #(
always @(posedge clk) begin
if (reset) begin
for (integer i = 0; i < `NUM_BARRIERS; i++) begin
barrier_stall_mask[i] <= 0;
end
barrier_masks <= 0;
use_wspawn <= 0;
stalled_warps <= 0;
warp_pcs <= '0;
active_warps <= '0;
thread_masks <= '0;
use_wspawn_pc <= 0;
use_wspawn <= 0;
warp_pcs[0] <= `STARTUP_ADDR;
active_warps[0] <= 1; // Activating first warp
thread_masks[0] <= 1; // Activating first thread in first warp
stalled_warps <= 0;
fetch_lock <= 0;
for (integer i = 1; i < `NUM_WARPS; i++) begin
warp_pcs[i] <= 0;
active_warps[i] <= 0;
thread_masks[i] <= 0;
end
// activate first warp
warp_pcs[0] <= `STARTUP_ADDR;
active_warps[0] <= '1;
thread_masks[0] <= '1;
end else begin
if (warp_ctl_if.valid && warp_ctl_if.wspawn.valid) begin
use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1));
use_wspawn_pc <= warp_ctl_if.wspawn.pc;
use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1));
wspawn_pc <= warp_ctl_if.wspawn.pc;
end
if (warp_ctl_if.valid && warp_ctl_if.barrier.valid) begin
stalled_warps[warp_ctl_if.wid] <= 0;
if (reached_barrier_limit) begin
barrier_stall_mask[warp_ctl_if.barrier.id] <= 0;
barrier_masks[warp_ctl_if.barrier.id] <= 0;
end else begin
barrier_stall_mask[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1;
barrier_masks[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1;
end
end else if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
end
if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask;
stalled_warps[warp_ctl_if.wid] <= 0;
end else if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin
end
if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin
stalled_warps[warp_ctl_if.wid] <= 0;
if (warp_ctl_if.split.diverged) begin
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_tmask;
end
end
if (use_wspawn[schedule_warp] && warp_scheduled) begin
use_wspawn[schedule_warp] <= 0;
thread_masks[schedule_warp] <= 1;
end
// Stalling the scheduling of warps
if (wstall_if.valid) begin
stalled_warps[wstall_if.wid] <= 1;
end
// Branch
if (branch_ctl_if.valid) begin
@ -122,18 +109,24 @@ module VX_warp_sched #(
stalled_warps[branch_ctl_if.wid] <= 0;
end
// Lock warp until instruction decode to resolve branches
if (warp_scheduled) begin
fetch_lock[schedule_warp] <= 1;
// stall the warp until decode stage
stalled_warps[schedule_wid] <= 1;
// release wspawn
use_wspawn[schedule_wid] <= 0;
if (use_wspawn[schedule_wid]) begin
thread_masks[schedule_wid] <= 1;
end
end
if (ifetch_req_fire) begin
warp_next_pcs[ifetch_req_if.wid] <= ifetch_req_if.PC + 4;
end
if (ifetch_rsp_fire) begin
fetch_lock[ifetch_rsp_if.wid] <= 0;
warp_pcs[ifetch_rsp_if.wid] <= warp_next_pcs[ifetch_rsp_if.wid];
if (wstall_if.valid) begin
stalled_warps[wstall_if.wid] <= wstall_if.stalled;
warp_pcs[wstall_if.wid] <= warp_next_pcs[wstall_if.wid];
end
// join handling
@ -156,15 +149,15 @@ module VX_warp_sched #(
`IGNORE_UNUSED_BEGIN
wire [`NW_BITS:0] active_barrier_count;
`IGNORE_UNUSED_END
assign active_barrier_count = $countones(barrier_stall_mask[warp_ctl_if.barrier.id]);
assign active_barrier_count = $countones(barrier_masks[warp_ctl_if.barrier.id]);
assign reached_barrier_limit = (active_barrier_count[`NW_BITS-1:0] == warp_ctl_if.barrier.size_m1);
reg [`NUM_WARPS-1:0] total_barrier_stall;
reg [`NUM_WARPS-1:0] barrier_stalls;
always @(*) begin
total_barrier_stall = barrier_stall_mask[0];
barrier_stalls = barrier_masks[0];
for (integer i = 1; i < `NUM_BARRIERS; ++i) begin
total_barrier_stall |= barrier_stall_mask[i];
barrier_stalls |= barrier_masks[i];
end
end
@ -205,22 +198,27 @@ module VX_warp_sched #(
// round-robin warp scheduling
wire schedule_valid;
wire [`NUM_WARPS-1:0] ready_warps = active_warps & ~(stalled_warps | barrier_stalls);
VX_rr_arbiter #(
.NUM_REQS (`NUM_WARPS)
) rr_arbiter (
.clk (clk),
.reset (reset),
.requests (active_warps & ~(stalled_warps | total_barrier_stall | fetch_lock)),
.grant_index (schedule_warp),
.requests (ready_warps),
.grant_index (schedule_wid),
.grant_valid (schedule_valid),
`UNUSED_PIN (grant_onehot),
`UNUSED_PIN (enable)
);
wire [`NUM_THREADS-1:0] thread_mask = use_wspawn[schedule_warp] ? `NUM_THREADS'(1) : thread_masks[schedule_warp];
wire [31:0] warp_pc = use_wspawn[schedule_warp] ? use_wspawn_pc : warp_pcs[schedule_warp];
wire [`NUM_WARPS-1:0][(`NUM_THREADS + 32)-1:0] schedule_data;
for (genvar i = 0; i < `NUM_WARPS; ++i) begin
assign schedule_data[i] = {(use_wspawn[i] ? `NUM_THREADS'(1) : thread_masks[i]),
(use_wspawn[i] ? wspawn_pc : warp_pcs[i])};
end
assign {schedule_tmask, schedule_pc} = schedule_data[schedule_wid];
wire stall_out = ~ifetch_req_if.ready && ifetch_req_if.valid;
@ -233,17 +231,17 @@ module VX_warp_sched #(
.clk (clk),
.reset (reset),
.enable (!stall_out),
.data_in ({schedule_valid, thread_mask, warp_pc, schedule_warp}),
.data_in ({schedule_valid, schedule_tmask, schedule_pc, schedule_wid}),
.data_out ({ifetch_req_if.valid, ifetch_req_if.tmask, ifetch_req_if.PC, ifetch_req_if.wid})
);
assign busy = (active_warps != 0);
`SCOPE_ASSIGN (wsched_scheduled_warp, warp_scheduled);
`SCOPE_ASSIGN (wsched_scheduled, warp_scheduled);
`SCOPE_ASSIGN (wsched_active_warps, active_warps);
`SCOPE_ASSIGN (wsched_schedule_table, schedule_table);
`SCOPE_ASSIGN (wsched_schedule_ready, schedule_ready);
`SCOPE_ASSIGN (wsched_warp_to_schedule, schedule_warp);
`SCOPE_ASSIGN (wsched_warp_pc, warp_pc);
`SCOPE_ASSIGN (wsched_stalled_warps, stalled_warps);
`SCOPE_ASSIGN (wsched_schedule_wid, schedule_wid);
`SCOPE_ASSIGN (wsched_schedule_tmask, schedule_tmask);
`SCOPE_ASSIGN (wsched_schedule_pc, schedule_pc);
endmodule

View file

@ -7,6 +7,7 @@ interface VX_wstall_if();
wire valid;
wire [`NW_BITS-1:0] wid;
wire stalled;
endinterface

View file

@ -140,9 +140,9 @@
"afu/vortex/cluster/core/pipeline/fetch/warp_sched": {
"?wsched_scheduled_warp": 1,
"wsched_active_warps": "`NUM_WARPS",
"wsched_schedule_table": "`NUM_WARPS",
"wsched_schedule_ready": "`NUM_WARPS",
"wsched_warp_to_schedule": "`NW_BITS",
"wsched_stalled_warps": "`NUM_WARPS",
"wsched_schedule_tmask": "`NUM_THREADS",
"wsched_schedule_wid": "`NW_BITS",
"wsched_warp_pc": "32"
},
"afu/vortex/cluster/core/pipeline/execute/gpu_unit": {