redesign of predicate extension to handle complex code optimizations

This commit is contained in:
Blaise Tine 2023-07-10 20:06:48 -04:00
parent 2b8ba7382a
commit e2461108d2
14 changed files with 110 additions and 37 deletions

View file

@ -27,7 +27,7 @@
`define NR_BITS `CLOG2(`NUM_REGS)
`define PD_STACK_SIZE `UP(`NT_BITS)
`define PD_STACK_SIZE `UP(`NUM_THREADS-1)
`define PD_STACK_SIZEW `UP(`CLOG2(`PD_STACK_SIZE))
`define PERF_CTR_BITS 44

View file

@ -216,7 +216,7 @@ module Vortex (
.MEM_OUT_REG (3),
.NC_ENABLE (1),
.PASSTHRU (!`L3_ENABLED)
) l3cache_wrap (
) l3cache (
.clk (clk),
.reset (l3_reset),

View file

@ -423,8 +423,8 @@ module VX_decode #(
7'h00: begin
ex_type = `EX_GPU;
case (func3)
3'h0: begin // TMC, PRED
op_type = rs2[0] ? `INST_OP_BITS'(`INST_GPU_PRED) : `INST_OP_BITS'(`INST_GPU_TMC);
3'h0: begin // TMC
op_type = `INST_OP_BITS'(`INST_GPU_TMC);
is_wstall = 1;
`USED_IREG (rs1);
end
@ -451,6 +451,12 @@ module VX_decode #(
`USED_IREG (rs1);
`USED_IREG (rs2);
end
3'h5: begin // PRED
op_type = `INST_OP_BITS'(`INST_GPU_PRED);
is_wstall = 1;
`USED_IREG (rs1);
`USED_IREG (rs2);
end
default:;
endcase
end
@ -458,9 +464,9 @@ module VX_decode #(
case (func3)
`ifdef EXT_RASTER_ENABLE
3'h0: begin // RASTER
ex_type = `EX_GPU;
op_type = `INST_OP_BITS'(`INST_GPU_RASTER);
use_rd = 1;
ex_type = `EX_GPU;
op_type = `INST_OP_BITS'(`INST_GPU_RASTER);
use_rd = 1;
`USED_IREG (rd);
end
`endif

View file

@ -53,10 +53,10 @@ module VX_wctl_unit #(
assign warp_ctl_if.sjoin = sjoin;
assign warp_ctl_if.barrier = barrier;
// tmc
// tmc / pred
wire [`NUM_THREADS-1:0] then_tmask = gpu_exe_if.tmask & taken;
wire [`NUM_THREADS-1:0] pred_mask = (then_tmask != 0) ? then_tmask : gpu_exe_if.tmask;
wire [`NUM_THREADS-1:0] pred_taken = taken & gpu_exe_if.tmask;
wire [`NUM_THREADS-1:0] pred_mask = (pred_taken != 0) ? pred_taken : rs2_data[`NUM_THREADS-1:0];
assign tmc.valid = gpu_exe_fire && (is_tmc || is_pred);
assign tmc.tmask = is_pred ? pred_mask : rs1_data[`NUM_THREADS-1:0];

View file

@ -431,7 +431,7 @@ module VX_mem_unit # (
.MEM_OUT_REG (3),
.NC_ENABLE (1),
.PASSTHRU (!`L2_ENABLED)
) l2cache_wrap (
) l2cache (
.clk (clk),
.reset (l2_reset),
`ifdef PERF_ENABLE

View file

@ -118,8 +118,8 @@ inline void vx_tmc(unsigned thread_mask) {
}
// Set thread predicate
inline void vx_pred(unsigned condition) {
asm volatile (".insn r %0, 0, 0, x0, %1, x1" :: "i"(RISCV_CUSTOM0), "r"(condition));
inline void vx_pred(unsigned condition, unsigned thread_mask) {
asm volatile (".insn r %0, 5, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(condition), "r"(thread_mask));
}
typedef void (*vx_wspawn_pfn)();

View file

@ -31,7 +31,7 @@ public:
, num_regs_(32)
, num_csrs_(4096)
, num_barriers_(NUM_BARRIERS)
, ipdom_size_(log2ceil(num_threads) * 2)
, ipdom_size_((num_threads-1) * 2)
{}
uint16_t vsize() const {

View file

@ -385,11 +385,12 @@ static const char* op_string(const Instr &instr) {
switch (func7) {
case 0:
switch (func3) {
case 0: return rs2 ? "PRED" : "TMC";
case 0: return "TMC";
case 1: return "WSPAWN";
case 2: return "SPLIT";
case 3: return "JOIN";
case 4: return "BAR";
case 5: return "PRED";
default:
std::abort();
}

View file

@ -1299,25 +1299,14 @@ void Warp::execute(const Instr &instr, pipeline_trace_t *trace) {
case 0: {
switch (func3) {
case 0: {
// TMC
// TMC
trace->exe_type = ExeType::GPU;
trace->gpu_type = GpuType::TMC;
trace->used_iregs.set(rsrc0);
trace->fetch_stall = true;
if (rsrc1) {
// predicate mode
ThreadMask pred;
for (uint32_t t = 0; t < num_threads; ++t) {
pred[t] = tmask_.test(t) ? (ireg_file_.at(t).at(rsrc0) != 0) : 0;
}
if (pred.any()) {
next_tmask &= pred;
}
} else {
next_tmask.reset();
for (uint32_t t = 0; t < num_threads; ++t) {
next_tmask.set(t, rsdata.at(thread_start)[0].i & (1 << t));
}
next_tmask.reset();
for (uint32_t t = 0; t < num_threads; ++t) {
next_tmask.set(t, rsdata.at(thread_start)[0].i & (1 << t));
}
} break;
case 1: {
@ -1348,7 +1337,7 @@ void Warp::execute(const Instr &instr, pipeline_trace_t *trace) {
if (then_tmask.count() != tmask_.count()
&& else_tmask.count() != tmask_.count()) {
if (ipdom_stack_.size() == arch_.ipdom_size()) {
std::cout << "IPDOM stack is full! (size=" << std::dec << ipdom_stack_.size() << ")\n" << std::flush;
std::cout << "IPDOM stack is full! size=" << std::dec << ipdom_stack_.size() << ", PC=" << std::hex << PC_ << " (#" << std::dec << trace->uuid << ")\n" << std::dec << std::flush;
std::abort();
}
if (then_tmask.count() >= else_tmask.count()) {
@ -1401,6 +1390,23 @@ void Warp::execute(const Instr &instr, pipeline_trace_t *trace) {
trace->fetch_stall = true;
trace->data = std::make_shared<GPUTraceData>(rsdata[thread_start][0].i, rsdata[thread_start][1].i);
} break;
case 5: {
// PRED
trace->exe_type = ExeType::GPU;
trace->gpu_type = GpuType::TMC;
trace->used_iregs.set(rsrc0);
trace->used_iregs.set(rsrc1);
trace->fetch_stall = true;
ThreadMask pred;
for (uint32_t t = 0; t < num_threads; ++t) {
pred[t] = tmask_.test(t) && (ireg_file_.at(t).at(rsrc0) & 0x1);
}
if (pred.any()) {
next_tmask &= pred;
} else {
next_tmask = ireg_file_.at(thread_start).at(rsrc1);
}
} break;
default:
std::abort();
}

View file

@ -32,7 +32,7 @@ LLVM_POCL ?= /opt/llvm-pocl
K_CFLAGS += -v -O3 --sysroot=$(RISCV_SYSROOT) --gcc-toolchain=$(RISCV_TOOLCHAIN_PATH) -Xclang -target-feature -Xclang +vortex
K_CFLAGS += -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections
K_CFLAGS += -I$(VORTEX_KN_PATH)/include
K_CFLAGS += -I$(VORTEX_KN_PATH)/include -DNDEBUG -DLLVM_VORTEX
K_LDFLAGS += -Wl,-Bstatic,--gc-sections,-T$(VORTEX_KN_PATH)/linker/vx_link$(XLEN).ld,--defsym=STARTUP_ADDR=$(STARTUP_ADDR) $(VORTEX_KN_PATH)/libvortexrt.a -lm
CXXFLAGS += -std=c++11 -Wall -Wextra -Wfatal-errors

View file

@ -27,7 +27,7 @@ LLVM_VORTEX ?= /opt/llvm-vortex
LLVM_CFLAGS += --sysroot=$(RISCV_SYSROOT)
LLVM_CFLAGS += --gcc-toolchain=$(RISCV_TOOLCHAIN_PATH)
LLVM_CFLAGS += -Xclang -target-feature -Xclang +vortex
#LLVM_CFLAGS += -mllvm -vortex-branch-divergence=2
#LLVM_CFLAGS += -mllvm -vortex-branch-divergence=2
#LLVM_CFLAGS += -mllvm -print-after-all
#LLVM_CFLAGS += -I$(RISCV_SYSROOT)/include/c++/9.2.0/$(RISCV_PREFIX)
#LLVM_CFLAGS += -I$(RISCV_SYSROOT)/include/c++/9.2.0
@ -47,7 +47,7 @@ VX_CP = $(LLVM_VORTEX)/bin/llvm-objcopy
VX_CFLAGS += -v -O3 -std=c++17
VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections
VX_CFLAGS += -I$(VORTEX_KN_PATH)/include -I$(VORTEX_KN_PATH)/../hw
VX_CFLAGS += -DLLVM_VORTEX
VX_CFLAGS += -DNDEBUG -DLLVM_VORTEX
VX_LDFLAGS += -Wl,-Bstatic,--gc-sections,-T,$(VORTEX_KN_PATH)/linker/vx_link$(XLEN).ld,--defsym=STARTUP_ADDR=$(STARTUP_ADDR) $(VORTEX_KN_PATH)/libvortexrt.a
@ -76,7 +76,7 @@ endif
endif
all: $(PROJECT) kernel.bin kernel.dump
kernel.dump: kernel.elf
$(VX_DP) -D kernel.elf > kernel.dump

View file

@ -1,4 +1,6 @@
#include <stdint.h>
#include <assert.h>
#include <algorithm>
#include <vx_intrinsics.h>
#include <vx_spawn.h>
#include "common.h"
@ -43,7 +45,33 @@ void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) {
// loop
for (int i = 0, n = task_id; i < n; ++i) {
value += src_ptr[i];
}
}
// switch
switch (task_id) {
case 0:
value += 1;
break;
case 1:
value -= 1;
break;
case 2:
value *= 3;
break;
case 3:
value *= 5;
break;
default:
assert(task_id < arg->num_points);
break;
}
// select
value += (task_id >= 0) ? ((task_id > 5) ? src_ptr[0] : task_id) : ((task_id < 5) ? src_ptr[1] : -task_id);
// min/max
value += std::min(src_ptr[task_id], value);
value += std::max(src_ptr[task_id], value);
dst_ptr[task_id] = value;
}

View file

@ -3,6 +3,7 @@
#include <string.h>
#include <vortex.h>
#include <vector>
#include <assert.h>
#include "common.h"
#define RT_CHECK(_expr) \
@ -114,9 +115,34 @@ void gen_ref_data(uint32_t num_points) {
for (int j = 0, n = i; j < n; ++j) {
value += src_data.at(j);
}
// switch
switch (i) {
case 0:
value += 1;
break;
case 1:
value -= 1;
break;
case 2:
value *= 3;
break;
case 3:
value *= 5;
break;
default:
assert(i < (int)num_points);
break;
}
// select
value += (i >= 0) ? ((i > 5) ? src_data.at(0) : i) : ((i < 5) ? src_data.at(1) : -i);
// min/max
value += std::min(src_data.at(i), value);
value += std::max(src_data.at(i), value);
ref_data[i] = value;
//std::cout << std::dec << i << ": result=0x" << std::hex << value << std::endl;
}
}

View file

@ -13,3 +13,9 @@ SRCS = main.cpp $(VORTEX_KN_PATH)/../sim/common/gfxutil.cpp
VX_SRCS = kernel.cpp $(VORTEX_KN_PATH)/../sim/common/graphics.cpp
include ../common.mk
graphics.ll: $(VX_SRCS)
$(VX_CXX) $(VX_CFLAGS) -mllvm -debug-pass=Arguments $(VX_SRCS) $(VX_LDFLAGS) -S -emit-llvm
graphics.pass: graphics.ll
$(LLVM_VORTEX)/bin/llc -O3 -march=riscv32 -target-abi=ilp32f -mcpu=generic-rv32 -mattr=+m,+f,+vortex -float-abi=hard -code-model=small -print-after-all -debug-pass=Executions graphics.ll > graphics.pass 2>&1