This commit is contained in:
Blaise Tine 2021-09-08 02:27:53 -07:00
commit 81bee3ac45
7 changed files with 56 additions and 32 deletions

View file

@ -349,7 +349,7 @@ module VX_decode #(
ex_type = `EX_GPU;
case (func3)
3'h0: begin
op_type = `INST_OP_BITS'(`INST_GPU_TMC);
op_type = rs2[0] ? `INST_OP_BITS'(`INST_GPU_PRED) : `INST_OP_BITS'(`INST_GPU_TMC);
is_wstall = 1;
`USED_IREG (rs1);
end

View file

@ -185,7 +185,7 @@
`define INST_GPU_SPLIT 3'h2
`define INST_GPU_JOIN 3'h3
`define INST_GPU_BAR 3'h4
`define INST_GPU_OTHER 3'h7
`define INST_GPU_PRED 3'h5
`define INST_GPU_BITS 3
///////////////////////////////////////////////////////////////////////////////

View file

@ -29,11 +29,18 @@ module VX_gpu_unit #(
wire is_tmc = (gpu_req_if.op_type == `INST_GPU_TMC);
wire is_split = (gpu_req_if.op_type == `INST_GPU_SPLIT);
wire is_bar = (gpu_req_if.op_type == `INST_GPU_BAR);
wire is_pred = (gpu_req_if.op_type == `INST_GPU_PRED);
// tmc
assign tmc.valid = is_tmc;
assign tmc.tmask = `NUM_THREADS'(gpu_req_if.rs1_data[gpu_req_if.tid]);
wire [`NUM_THREADS-1:0] pred_cond;
for (genvar i = 0; i < `NUM_THREADS; i++) begin
assign pred_cond[i] = gpu_req_if.tmask[i] && gpu_req_if.rs1_data[i][0];
end
wire [`NUM_THREADS-1:0] pred = (pred_cond != 0) ? pred_cond : gpu_req_if.tmask;
assign tmc.valid = is_tmc || is_pred;
assign tmc.tmask = is_pred ? pred : `NUM_THREADS'(gpu_req_if.rs1_data[gpu_req_if.tid]);
// wspawn

View file

@ -136,6 +136,7 @@ task print_ex_op (
`INST_GPU_SPLIT: dpi_trace("SPLIT");
`INST_GPU_JOIN: dpi_trace("JOIN");
`INST_GPU_BAR: dpi_trace("BAR");
`INST_GPU_PRED: dpi_trace("PRED");
default: dpi_trace("?");
endcase
end

View file

@ -74,33 +74,33 @@ module VX_warp_sched #(
active_warps[0] <= '1;
thread_masks[0] <= '1;
end else begin
if (warp_ctl_if.valid && warp_ctl_if.wspawn.valid) begin
use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1));
wspawn_pc <= warp_ctl_if.wspawn.pc;
end
if (warp_ctl_if.valid && warp_ctl_if.barrier.valid) begin
stalled_warps[warp_ctl_if.wid] <= 0;
if (reached_barrier_limit) begin
barrier_masks[warp_ctl_if.barrier.id] <= 0;
if (warp_ctl_if.valid) begin
if (warp_ctl_if.wspawn.valid) begin
use_wspawn <= warp_ctl_if.wspawn.wmask & (~`NUM_WARPS'(1));
wspawn_pc <= warp_ctl_if.wspawn.pc;
end else begin
barrier_masks[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1;
stalled_warps[warp_ctl_if.wid] <= 0;
end
end
if (warp_ctl_if.valid && warp_ctl_if.tmc.valid) begin
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask;
stalled_warps[warp_ctl_if.wid] <= 0;
end
if (warp_ctl_if.valid && warp_ctl_if.split.valid) begin
stalled_warps[warp_ctl_if.wid] <= 0;
if (warp_ctl_if.split.diverged) begin
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_tmask;
if (warp_ctl_if.barrier.valid) begin
if (reached_barrier_limit) begin
barrier_masks[warp_ctl_if.barrier.id] <= 0;
end else begin
barrier_masks[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1;
end
end
if (warp_ctl_if.tmc.valid) begin
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.tmc.tmask;
end
if (warp_ctl_if.split.valid) begin
if (warp_ctl_if.split.diverged) begin
thread_masks[warp_ctl_if.wid] <= warp_ctl_if.split.then_tmask;
end
end
end
// Branch
if (branch_ctl_if.valid) begin
if (branch_ctl_if.taken) begin

View file

@ -53,8 +53,13 @@ extern "C" {
})
// Set thread mask
inline void vx_tmc(unsigned num_threads) {
asm volatile (".insn s 0x6b, 0, x0, 0(%0)" :: "r"(num_threads));
inline void vx_tmc(unsigned mask) {
asm volatile (".insn s 0x6b, 0, x0, 0(%0)" :: "r"(mask));
}
// Set thread predicate
inline void vx_pred(unsigned condition) {
asm volatile (".insn s 0x6b, 0, x1, 0(%0)" :: "r"(condition));
}
typedef void (*vx_wspawn_pfn)();

View file

@ -816,10 +816,21 @@ void Warp::execute(const Instr &instr, Pipeline *pipeline) {
case GPGPU:
switch (func3) {
case 0: {
// TMC
tmask_.reset();
for (int i = 0; i < num_threads; ++i) {
tmask_[i] = rsdata[0] & (1 << i);
// TMC
if (rsrc1) {
// predicate mode
ThreadMask pred;
for (int i = 0; i < num_threads; ++i) {
pred[i] = tmask_[i] ? (iRegFile_[i][rsrc0] != 0) : 0;
}
if (pred.any()) {
tmask_ &= pred;
}
} else {
tmask_.reset();
for (int i = 0; i < num_threads; ++i) {
tmask_[i] = rsdata[0] & (1 << i);
}
}
D(3, "*** TMC " << tmask_);
active_ = tmask_.any();