enabling MSCRATCH CSR

This commit is contained in:
Blaise Tine 2024-04-09 02:01:17 -07:00
parent 7784dfe9b7
commit dd461468d3
3 changed files with 42 additions and 54 deletions

View file

@ -51,6 +51,7 @@
`define VX_CSR_MIDELEG 12'h303
`define VX_CSR_MIE 12'h304
`define VX_CSR_MTVEC 12'h305
`define VX_CSR_MSCRATCH 12'h340
`define VX_CSR_MEPC 12'h341

View file

@ -14,6 +14,7 @@
#ifndef __VX_INTRINSICS_H__
#define __VX_INTRINSICS_H__
#include <stddef.h>
#include <VX_config.h>
#include <VX_types.h>
@ -39,74 +40,67 @@ extern "C" {
#define RISCV_CUSTOM3 0x7B
#define csr_read(csr) ({ \
unsigned __r; \
__asm__ __volatile__ ("csrr %0, %1" : "=r" (__r) : "i" (csr)); \
size_t __r; \
__asm__ __volatile__ ("csrr %0, %1" : "=r" (__r) : "i" (csr) : "memory"); \
__r; \
})
#define csr_write(csr, val) ({ \
unsigned __v = (unsigned)(val); \
size_t __v = (size_t)(val); \
if (__builtin_constant_p(val) && __v < 32) \
__asm__ __volatile__ ("csrw %0, %1" :: "i" (csr), "i" (__v)); \
__asm__ __volatile__ ("csrw %0, %1" :: "i" (csr), "i" (__v) : "memory"); \
else \
__asm__ __volatile__ ("csrw %0, %1" :: "i" (csr), "r" (__v)); \
__asm__ __volatile__ ("csrw %0, %1" :: "i" (csr), "r" (__v) : "memory"); \
})
#define csr_swap(csr, val) ({ \
unsigned __r; \
unsigned __v = (unsigned)(val); \
size_t __r; \
size_t __v = (size_t)(val); \
if (__builtin_constant_p(val) && __v < 32) \
__asm__ __volatile__ ("csrrw %0, %1, %2" : "=r" (__r) : "i" (csr), "i" (__v)); \
__asm__ __volatile__ ("csrrw %0, %1, %2" : "=r" (__r) : "i" (csr), "i" (__v) : "memory"); \
else \
__asm__ __volatile__ ("csrrw %0, %1, %2" : "=r" (__r) : "i" (csr), "r" (__v)); \
__asm__ __volatile__ ("csrrw %0, %1, %2" : "=r" (__r) : "i" (csr), "r" (__v) : "memory"); \
__r; \
})
#define csr_read_set(csr, val) ({ \
unsigned __r; \
unsigned __v = (unsigned)(val); \
size_t __r; \
size_t __v = (size_t)(val); \
if (__builtin_constant_p(val) && __v < 32) \
__asm__ __volatile__ ("csrrs %0, %1, %2" : "=r" (__r) : "i" (csr), "i" (__v)); \
__asm__ __volatile__ ("csrrs %0, %1, %2" : "=r" (__r) : "i" (csr), "i" (__v) : "memory"); \
else \
__asm__ __volatile__ ("csrrs %0, %1, %2" : "=r" (__r) : "i" (csr), "r" (__v)); \
__asm__ __volatile__ ("csrrs %0, %1, %2" : "=r" (__r) : "i" (csr), "r" (__v) : "memory"); \
__r; \
})
#define csr_set(csr, val) ({ \
unsigned __v = (unsigned)(val); \
size_t __v = (size_t)(val); \
if (__builtin_constant_p(val) && __v < 32) \
__asm__ __volatile__ ("csrs %0, %1" :: "i" (csr), "i" (__v)); \
__asm__ __volatile__ ("csrs %0, %1" :: "i" (csr), "i" (__v) : "memory"); \
else \
__asm__ __volatile__ ("csrs %0, %1" :: "i" (csr), "r" (__v)); \
__asm__ __volatile__ ("csrs %0, %1" :: "i" (csr), "r" (__v) : "memory"); \
})
#define csr_read_clear(csr, val) ({ \
unsigned __r; \
unsigned __v = (unsigned)(val); \
size_t __r; \
size_t __v = (size_t)(val); \
if (__builtin_constant_p(val) && __v < 32) \
__asm__ __volatile__ ("csrrc %0, %1, %2" : "=r" (__r) : "i" (csr), "i" (__v)); \
__asm__ __volatile__ ("csrrc %0, %1, %2" : "=r" (__r) : "i" (csr), "i" (__v) : "memory"); \
else \
__asm__ __volatile__ ("csrrc %0, %1, %2" : "=r" (__r) : "i" (csr), "r" (__v)); \
__asm__ __volatile__ ("csrrc %0, %1, %2" : "=r" (__r) : "i" (csr), "r" (__v) : "memory"); \
__r; \
})
#define csr_clear(csr, val) ({ \
unsigned __v = (unsigned)(val); \
size_t __v = (size_t)(val); \
if (__builtin_constant_p(val) && __v < 32) \
__asm__ __volatile__ ("csrc %0, %1" :: "i" (csr), "i" (__v)); \
__asm__ __volatile__ ("csrc %0, %1" :: "i" (csr), "i" (__v) : "memory"); \
else \
__asm__ __volatile__ ("csrc %0, %1" :: "i" (csr), "r" (__v)); \
__asm__ __volatile__ ("csrc %0, %1" :: "i" (csr), "r" (__v) : "memory"); \
})
// Conditional move — presumably selects between t and f based on c in a single
// custom instruction, avoiding a branch; TODO confirm exact select semantics
// (c != 0 ? t : f) against the Vortex ISA documentation.
inline unsigned vx_cmov(unsigned c, unsigned t, unsigned f) {
unsigned ret;
// .insn r4 emits an R4-type (3-source) instruction on the RISCV_CUSTOM1 opcode:
// %0 = rd (ret), %2 = rs1 (c), %3 = rs2 (t), %4 = rs3 (f); funct3=1, funct2=0.
asm volatile (".insn r4 %1, 1, 0, %0, %2, %3, %4" : "=r"(ret) : "i"(RISCV_CUSTOM1), "r"(c), "r"(t), "r"(f));
return ret;
}
// Set thread mask
inline void vx_tmc(unsigned thread_mask) {
inline void vx_tmc(size_t thread_mask) {
asm volatile (".insn r %0, 0, 0, x0, %1, x0" :: "i"(RISCV_CUSTOM0), "r"(thread_mask));
}
@ -119,37 +113,36 @@ inline void vx_tmc_zero() {
inline void vx_tmc_one() {
asm volatile (
"li a0, 1\n\t" // Load immediate value 1 into a0 (x10) register
".insn r %0, 0, 0, x0, a0, x0" :: "i"(RISCV_CUSTOM0)
: "a0" // Indicate that a0 (x10) is clobbered
".insn r %0, 0, 0, x0, a0, x0" :: "i"(RISCV_CUSTOM0) : "a0"
);
}
// Set thread predicate
inline void vx_pred(unsigned condition, unsigned thread_mask) {
inline void vx_pred(int condition, int thread_mask) {
asm volatile (".insn r %0, 5, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(condition), "r"(thread_mask));
}
typedef void (*vx_wspawn_pfn)();
// Spawn warps
inline void vx_wspawn(unsigned num_warps, vx_wspawn_pfn func_ptr) {
inline void vx_wspawn(size_t num_warps, vx_wspawn_pfn func_ptr) {
asm volatile (".insn r %0, 1, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(num_warps), "r"(func_ptr));
}
// Split on a predicate
inline unsigned vx_split(unsigned predicate) {
unsigned ret;
inline int vx_split(int predicate) {
size_t ret;
asm volatile (".insn r %1, 2, 0, %0, %2, x0" : "=r"(ret) : "i"(RISCV_CUSTOM0), "r"(predicate));
return ret;
}
// Join
inline void vx_join(unsigned stack_ptr) {
inline void vx_join(int stack_ptr) {
asm volatile (".insn r %0, 3, 0, x0, %1, x0" :: "i"(RISCV_CUSTOM0), "r"(stack_ptr));
}
// Warp Barrier
inline void vx_barrier(unsigned barried_id, unsigned num_warps) {
inline void vx_barrier(int barried_id, int num_warps) {
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps));
}
@ -181,8 +174,8 @@ inline int vx_thread_mask() {
return ret;
}
// Return number of active warps
inline int vx_active_warps() {
// Return active warps mask
inline int vx_warp_mask() {
int ret;
asm volatile ("csrr %0, %1" : "=r"(ret) : "i"(VX_CSR_WARP_MASK));
return ret;

View file

@ -49,8 +49,6 @@ typedef struct {
char log2X;
} wspawn_pocl_kernel_args_t;
void* g_wspawn_args[NUM_CORES_MAX];
// Return 1 if x is a power of two (exactly one bit set), else 0.
//
// Fixes two defects in the original `((x & (x-1)) == 0)` form:
//  - it returned 1 for x == 0, which is not a power of two;
//  - for negative x it depended on two's-complement bit patterns, accepting
//    e.g. INT_MIN.
// The x > 0 guard restricts the bit trick to its valid domain.
// `static` makes this a proper file-local helper (a bare `inline` definition
// in a .c file emits no external symbol and can fail to link).
static inline char is_log2(int x) {
return (x > 0) && ((x & (x - 1)) == 0);
}
@ -61,11 +59,10 @@ inline int log2_fast(int x) {
static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
int NT = vx_num_threads();
int cid = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)csr_read(VX_CSR_MSCRATCH);
int wK = (p_wspawn_args->FWs * wid) + MIN(p_wspawn_args->RWs, wid);
int tK = p_wspawn_args->FWs + (wid < p_wspawn_args->RWs);
@ -79,10 +76,9 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
}
static void __attribute__ ((noinline)) spawn_tasks_rem_stub() {
int cid = vx_core_id();
int tid = vx_thread_id();
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)csr_read(VX_CSR_MSCRATCH);
int task_id = p_wspawn_args->remain + tid;
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
}
@ -136,8 +132,8 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
int offset = core_id * tasks_per_core;
int remain = offset + (tasks_per_core_n1 - rT);
wspawn_tasks_args_t wspawn_args = { callback, arg, offset, remain, fW, rW};
g_wspawn_args[core_id] = &wspawn_args;
wspawn_tasks_args_t wspawn_args = {callback, arg, offset, remain, fW, rW};
csr_write(VX_CSR_MSCRATCH, &wspawn_args);
if (TW >= 1) {
// execute callback on other warps
@ -174,11 +170,10 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
static void __attribute__ ((noinline)) spawn_pocl_kernel_all_stub() {
int NT = vx_num_threads();
int cid = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
wspawn_pocl_kernel_args_t* p_wspawn_args = (wspawn_pocl_kernel_args_t*)g_wspawn_args[cid];
wspawn_pocl_kernel_args_t* p_wspawn_args = (wspawn_pocl_kernel_args_t*)csr_read(VX_CSR_MSCRATCH);
pocl_kernel_context_t* ctx = p_wspawn_args->ctx;
void* arg = p_wspawn_args->arg;
@ -212,10 +207,9 @@ static void __attribute__ ((noinline)) spawn_pocl_kernel_all_stub() {
}
static void __attribute__ ((noinline)) spawn_pocl_kernel_rem_stub() {
int cid = vx_core_id();
int tid = vx_thread_id();
wspawn_pocl_kernel_args_t* p_wspawn_args = (wspawn_pocl_kernel_args_t*)g_wspawn_args[cid];
wspawn_pocl_kernel_args_t* p_wspawn_args = (wspawn_pocl_kernel_args_t*)csr_read(VX_CSR_MSCRATCH);
pocl_kernel_context_t* ctx = p_wspawn_args->ctx;
void* arg = p_wspawn_args->arg;
@ -306,7 +300,7 @@ void vx_spawn_pocl_kernel(pocl_kernel_context_t * ctx, pocl_kernel_cb callback,
wspawn_pocl_kernel_args_t wspawn_args = {
ctx, callback, arg, local_size, offset, remain, fW, rW, isXYpow2, log2XY, log2X
};
g_wspawn_args[core_id] = &wspawn_args;
csr_write(VX_CSR_MSCRATCH, &wspawn_args);
if (TW >= 1) {
// execute callback on other warps