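minor update: move split/join (IPDOM stack) handling out of VX_schedule into a new VX_split_join module, gate each warp-control valid with gpu_exe_fire in VX_wctl_unit (removing warp_ctl_if.valid), and clamp the IPDOM stack pointer width with `UP()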

This commit is contained in:
Blaise Tine 2023-07-03 17:13:41 -04:00
parent f4adc4b4fe
commit 43f2daf7d3
6 changed files with 114 additions and 79 deletions

View file

@ -28,7 +28,7 @@
`define NR_BITS `CLOG2(`NUM_REGS)
`define PD_STACK_SIZE `UP(`NT_BITS)
`define PD_STACK_SIZEW `CLOG2(`PD_STACK_SIZE)
`define PD_STACK_SIZEW `UP(`CLOG2(`PD_STACK_SIZE))
`define PERF_CTR_BITS 44

View file

@ -3,7 +3,7 @@
module VX_ipdom_stack #(
parameter WIDTH = 1,
parameter DEPTH = 1,
parameter ADDRW = $clog2(DEPTH)
parameter ADDRW = `UP($clog2(DEPTH))
) (
input wire clk,
input wire reset,

View file

@ -60,9 +60,9 @@ module VX_schedule #(
wire schedule_ready;
// split/join
wire split_is_divergent;
wire [`NUM_THREADS-1:0] split_tmask0;
wire join_is_divergent;
wire split_is_dvg;
wire [`NUM_THREADS-1:0] split_tmask;
wire join_is_dvg;
wire join_is_else;
wire [`NUM_THREADS-1:0] join_tmask;
wire [`XLEN-1:0] join_pc;
@ -79,10 +79,10 @@ module VX_schedule #(
always @(*) begin
active_warps_n = active_warps;
if (warp_ctl_if.valid && warp_ctl_if.wspawn.valid) begin
if (warp_ctl_if.wspawn.valid) begin
active_warps_n = warp_ctl_if.wspawn.wmask;
end
if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
if (warp_ctl_if.tmc.valid) begin
active_warps_n[warp_ctl_if.wid] = tmc_active;
end
end
@ -107,9 +107,9 @@ module VX_schedule #(
thread_masks[0] <= 1;
end else begin
// join handling
if (warp_ctl_if.valid && warp_ctl_if.sjoin.valid) begin
if (warp_ctl_if.sjoin.valid) begin
stalled_warps[warp_ctl_if.wid] <= 0;
if (join_is_divergent) begin
if (join_is_dvg) begin
if (join_is_else) begin
warp_pcs[warp_ctl_if.wid] <= `XLEN'(join_pc);
end
@ -117,13 +117,13 @@ module VX_schedule #(
end
end
if (warp_ctl_if.valid && warp_ctl_if.wspawn.valid) begin
if (warp_ctl_if.wspawn.valid) begin
use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1));
wspawn_pc <= warp_ctl_if.wspawn.pc;
end
// barrier handling
if (warp_ctl_if.valid && warp_ctl_if.barrier.valid) begin
if (warp_ctl_if.barrier.valid) begin
stalled_warps[warp_ctl_if.wid] <= 0;
if (warp_ctl_if.barrier.is_global
&& (curr_barrier_mask_n == active_warps)) begin
@ -144,16 +144,16 @@ module VX_schedule #(
end
// TMC handling
if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
if (warp_ctl_if.tmc.valid) begin
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask;
stalled_warps[warp_ctl_if.wid] <= 0;
end
// split handling
if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin
if (warp_ctl_if.split.valid) begin
stalled_warps[warp_ctl_if.wid] <= 0;
if (split_is_divergent) begin
thread_masks[warp_ctl_if.wid] <= split_tmask0;
if (split_is_dvg) begin
thread_masks[warp_ctl_if.wid] <= split_tmask;
end
end
@ -225,59 +225,22 @@ module VX_schedule #(
// split/join handling
wire [(`XLEN+`NUM_THREADS)-1:0] ipdom_data [`NUM_WARPS-1:0];
wire [`PD_STACK_SIZEW-1:0] ipdom_q_ptr [`NUM_WARPS-1:0];
wire ipdom_index [`NUM_WARPS-1:0];
wire [`NUM_THREADS-1:0] then_tmask;
wire [`NUM_THREADS-1:0] else_tmask;
for (genvar i = 0; i < `NUM_THREADS; ++i) begin
assign then_tmask[i] = warp_ctl_if.split.tmask[i] && warp_ctl_if.split.taken[i];
assign else_tmask[i] = warp_ctl_if.split.tmask[i] && ~warp_ctl_if.split.taken[i];
end
wire [`CLOG2(`NUM_THREADS+1)-1:0] then_tmask_cnt, else_tmask_cnt;
`POP_COUNT(then_tmask_cnt, then_tmask);
`POP_COUNT(else_tmask_cnt, else_tmask);
wire then_first = (then_tmask_cnt >= else_tmask_cnt);
assign split_is_divergent = (then_tmask != 0) && (else_tmask != 0);
assign split_tmask0 = then_first ? then_tmask : else_tmask;
assign warp_ctl_if.split_ret = ipdom_q_ptr[warp_ctl_if.wid];
assign join_is_divergent = (warp_ctl_if.sjoin.stack_ptr != ipdom_q_ptr[warp_ctl_if.wid]);
assign {join_pc, join_tmask} = ipdom_data[warp_ctl_if.wid];
assign join_is_else = (ipdom_index[warp_ctl_if.wid] == 0);
wire [`NUM_THREADS-1:0] split_tmask1 = then_first ? else_tmask : then_tmask;
wire [(`XLEN+`NUM_THREADS)-1:0] ipdom_q0 = {warp_ctl_if.split.next_pc, split_tmask1};
wire [(`XLEN+`NUM_THREADS)-1:0] ipdom_q1 = {`XLEN'(0), warp_ctl_if.split.tmask};
wire ipdom_push = warp_ctl_if.valid && warp_ctl_if.split.valid && split_is_divergent;
wire ipdom_pop = warp_ctl_if.valid && warp_ctl_if.sjoin.valid && join_is_divergent;
`RESET_RELAY (ipdom_reset, reset);
for (genvar i = 0; i < `NUM_WARPS; ++i) begin
VX_ipdom_stack #(
.WIDTH (`XLEN+`NUM_THREADS),
.DEPTH (`PD_STACK_SIZE)
) ipdom_stack (
.clk (clk),
.reset (ipdom_reset),
.push (ipdom_push && (i == warp_ctl_if.wid)),
.pop (ipdom_pop && (i == warp_ctl_if.wid)),
.q0 (ipdom_q0),
.q1 (ipdom_q1),
.d (ipdom_data[i]),
.d_idx (ipdom_index[i]),
.q_ptr (ipdom_q_ptr[i]),
`UNUSED_PIN (d_ptr),
`UNUSED_PIN (empty),
`UNUSED_PIN (full)
);
end
VX_split_join #(
.CORE_ID (CORE_ID)
) split_join (
.clk (clk),
.reset (reset),
.wid (warp_ctl_if.wid),
.split (warp_ctl_if.split),
.sjoin (warp_ctl_if.sjoin),
.split_is_dvg (split_is_dvg),
.split_tmask (split_tmask),
.split_ret (warp_ctl_if.split_ret),
.join_is_dvg (join_is_dvg),
.join_is_else (join_is_else),
.join_tmask (join_tmask),
.join_pc (join_pc)
);
// schedule the next ready warp

View file

@ -0,0 +1,75 @@
`include "VX_platform.vh"
// VX_split_join
//
// Split/join bookkeeping for thread divergence. One VX_ipdom_stack is
// instantiated per warp; `wid` selects which warp's stack a split pushes to,
// a join pops from, and whose outputs drive this module's join_* / split_ret
// ports.
//
// Ports:
//   wid           - warp the current split/join request applies to
//   split         - split request (valid, tmask, taken, next_pc are read here)
//   sjoin         - join request (valid, stack_ptr are read here)
//   split_is_dvg  - high when both the taken and the not-taken groups are
//                   non-empty
//   split_tmask   - mask of the larger group (the taken group wins a tie)
//   split_ret     - q_ptr of the selected warp's stack
//   join_is_dvg   - high when sjoin.stack_ptr differs from the selected
//                   warp's q_ptr
//   join_is_else  - high when the selected stack reports d_idx == 0
//   join_tmask,
//   join_pc       - lower NUM_THREADS bits / upper XLEN bits of the selected
//                   stack's `d` output
module VX_split_join #(
    parameter CORE_ID = 0
) (
    input  wire                       clk,
    input  wire                       reset,
    input  wire [`UP(`NW_BITS)-1:0]   wid,
    input  gpu_split_t                split,
    input  gpu_join_t                 sjoin,
    output wire                       split_is_dvg,
    output wire [`NUM_THREADS-1:0]    split_tmask,
    output wire [`PD_STACK_SIZEW-1:0] split_ret,
    output wire                       join_is_dvg,
    output wire                       join_is_else,
    output wire [`NUM_THREADS-1:0]    join_tmask,
    output wire [`XLEN-1:0]           join_pc
);
    `UNUSED_PARAM (CORE_ID)

    // Per-warp stack outputs (one entry per VX_ipdom_stack instance).
    wire [(`XLEN+`NUM_THREADS)-1:0] stack_d     [`NUM_WARPS-1:0];
    wire [`PD_STACK_SIZEW-1:0]      stack_q_ptr [`NUM_WARPS-1:0];
    wire                            stack_d_idx [`NUM_WARPS-1:0];

    // Partition the active threads by their per-thread `taken` bit.
    wire [`NUM_THREADS-1:0] taken_tmask;
    wire [`NUM_THREADS-1:0] ntaken_tmask;

    for (genvar t = 0; t < `NUM_THREADS; ++t) begin
        assign taken_tmask[t]  = split.tmask[t] && split.taken[t];
        assign ntaken_tmask[t] = split.tmask[t] && ~split.taken[t];
    end

    // Size of each group; the larger one is issued first.
    wire [`CLOG2(`NUM_THREADS+1)-1:0] taken_cnt;
    wire [`CLOG2(`NUM_THREADS+1)-1:0] ntaken_cnt;
    `POP_COUNT(taken_cnt, taken_tmask);
    `POP_COUNT(ntaken_cnt, ntaken_tmask);

    wire taken_first;
    assign taken_first = (ntaken_cnt <= taken_cnt);

    // Mask of the group that is not issued first; it is stored with next_pc.
    wire [`NUM_THREADS-1:0] deferred_tmask;
    assign deferred_tmask = taken_first ? ntaken_tmask : taken_tmask;

    // Split results.
    assign split_is_dvg = (ntaken_tmask != 0) && (taken_tmask != 0);
    assign split_tmask  = taken_first ? taken_tmask : ntaken_tmask;
    assign split_ret    = stack_q_ptr[wid];

    // Join results, all read from the stack selected by `wid`.
    assign join_is_dvg           = (sjoin.stack_ptr != stack_q_ptr[wid]);
    assign join_is_else          = (stack_d_idx[wid] == 1'b0);
    assign {join_pc, join_tmask} = stack_d[wid];

    // Data presented on the stack's q0/q1 inputs on a push:
    //   q0 = {next_pc, deferred group mask}
    //   q1 = {zero PC, full pre-split mask}
    // NOTE(review): how VX_ipdom_stack orders q0/q1 on pop is defined in that
    // module, which is not visible here - confirm there.
    wire [(`XLEN+`NUM_THREADS)-1:0] push_q0;
    wire [(`XLEN+`NUM_THREADS)-1:0] push_q1;
    assign push_q0 = {split.next_pc, deferred_tmask};
    assign push_q1 = {`XLEN'(0), split.tmask};

    // Only divergent splits/joins touch the stack.
    wire do_push;
    wire do_pop;
    assign do_push = split.valid && split_is_dvg;
    assign do_pop  = sjoin.valid && join_is_dvg;

    `RESET_RELAY (ipdom_reset, reset);

    for (genvar w = 0; w < `NUM_WARPS; ++w) begin
        VX_ipdom_stack #(
            .WIDTH (`XLEN+`NUM_THREADS),
            .DEPTH (`PD_STACK_SIZE)
        ) ipdom_stack (
            .clk   (clk),
            .reset (ipdom_reset),
            // push/pop are steered to the stack whose index matches `wid`
            .push  (do_push && (w == wid)),
            .pop   (do_pop && (w == wid)),
            .q0    (push_q0),
            .q1    (push_q1),
            .d     (stack_d[w]),
            .d_idx (stack_d_idx[w]),
            .q_ptr (stack_q_ptr[w]),
            `UNUSED_PIN (d_ptr),
            `UNUSED_PIN (empty),
            `UNUSED_PIN (full)
        );
    end

endmodule

View file

@ -37,6 +37,8 @@ module VX_wctl_unit #(
assign taken[i] = gpu_exe_if.rs1_data[i][0];
end
wire gpu_exe_fire = gpu_exe_if.valid && gpu_exe_if.ready;
wire is_wspawn = (gpu_exe_if.op_type == `INST_GPU_WSPAWN);
wire is_tmc = (gpu_exe_if.op_type == `INST_GPU_TMC);
wire is_pred = (gpu_exe_if.op_type == `INST_GPU_PRED);
@ -44,7 +46,6 @@ module VX_wctl_unit #(
wire is_join = (gpu_exe_if.op_type == `INST_GPU_JOIN);
wire is_bar = (gpu_exe_if.op_type == `INST_GPU_BAR);
assign warp_ctl_if.valid = gpu_exe_if.valid && gpu_exe_if.ready;
assign warp_ctl_if.wid = gpu_exe_if.wid;
assign warp_ctl_if.tmc = tmc;
assign warp_ctl_if.wspawn = wspawn;
@ -57,34 +58,33 @@ module VX_wctl_unit #(
wire [`NUM_THREADS-1:0] then_tmask = gpu_exe_if.tmask & taken;
wire [`NUM_THREADS-1:0] pred_mask = (then_tmask != 0) ? then_tmask : gpu_exe_if.tmask;
assign tmc.valid = is_tmc || is_pred;
assign tmc.valid = gpu_exe_fire && (is_tmc || is_pred);
assign tmc.tmask = is_pred ? pred_mask : rs1_data[`NUM_THREADS-1:0];
// wspawn
wire [`XLEN-1:0] wspawn_pc = rs2_data;
wire [`NUM_WARPS-1:0] wspawn_wmask;
for (genvar i = 0; i < `NUM_WARPS; ++i) begin
assign wspawn_wmask[i] = (i < rs1_data[31:0]);
end
assign wspawn.valid = is_wspawn;
assign wspawn.valid = gpu_exe_fire && is_wspawn;
assign wspawn.wmask = wspawn_wmask;
assign wspawn.pc = wspawn_pc;
assign wspawn.pc = rs2_data;
// split
assign split.valid = is_split;
assign split.valid = gpu_exe_fire && is_split;
assign split.taken = taken;
assign split.tmask = gpu_exe_if.tmask;
assign split.next_pc = gpu_exe_if.next_PC;
// join
assign sjoin.valid = is_join;
assign sjoin.valid = gpu_exe_fire && is_join;
assign sjoin.stack_ptr = `PD_STACK_SIZEW'(rs1_data);
// barrier
assign barrier.valid = is_bar;
assign barrier.valid = gpu_exe_fire && is_bar;
assign barrier.id = rs1_data[`NB_BITS-1:0];
assign barrier.is_global = rs1_data[31];
assign barrier.size_m1 = $bits(barrier.size_m1)'(rs2_data[31:0] - 1);

View file

@ -7,7 +7,6 @@ import VX_gpu_types::*;
interface VX_warp_ctl_if ();
wire valid;
wire [`UP(`NW_BITS)-1:0] wid;
gpu_tmc_t tmc;
gpu_wspawn_t wspawn;
@ -17,7 +16,6 @@ interface VX_warp_ctl_if ();
wire [`PD_STACK_SIZEW-1:0] split_ret;
modport master (
output valid,
output wid,
output wspawn,
output tmc,
@ -28,7 +26,6 @@ interface VX_warp_ctl_if ();
);
modport slave (
input valid,
input wid,
input wspawn,
input tmc,