runtime update

This commit is contained in:
Blaise Tine 2021-10-07 13:58:21 -04:00
parent efb70b21df
commit cfed87f416
5 changed files with 132 additions and 58 deletions

View file

@ -9,7 +9,10 @@ extern "C" {
int vx_vprintf(const char* format, va_list va);
int vx_printf(const char * format, ...);
int vx_putchar(int c);
void vx_putchar(int c);
void vx_putint(int value, int base);
void vx_putfloat(float value, int precision);
#ifdef __cplusplus
}

View file

@ -8,7 +8,7 @@
extern "C" {
#endif
struct context_t {
typedef struct {
uint32_t num_groups[3];
uint32_t global_offset[3];
uint32_t local_size[3];
@ -16,11 +16,11 @@ struct context_t {
uint32_t *printf_buffer_position;
uint32_t printf_buffer_capacity;
uint32_t work_dim;
};
} context_t;
typedef void (*vx_spawn_kernel_cb) (
const void * /* arg */,
const struct context_t * /* context */,
const context_t * /* context */,
uint32_t /* group_x */,
uint32_t /* group_y */,
uint32_t /* group_z */
@ -28,9 +28,9 @@ typedef void (*vx_spawn_kernel_cb) (
typedef void (*vx_spawn_tasks_cb)(int task_id, void *arg);
typedef void (*vx_serial_cb)(int task_id, void *arg);
typedef void (*vx_serial_cb)(void *arg);
void vx_spawn_kernel(struct context_t * ctx, vx_spawn_kernel_cb callback, void * arg);
void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg);
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);

View file

@ -4,28 +4,37 @@
#include <stdlib.h>
#include <stdbool.h>
#include <stdio.h>
#include <math.h>
#ifdef __cplusplus
extern "C" {
#endif
struct printf_arg_t {
typedef struct {
const char* format;
va_list va;
va_list* va;
int ret;
};
} printf_arg_t;
static void __printf_callback(int task_id, void* arg) {
struct printf_arg_t* p_arg = (struct printf_arg_t*)(arg);
p_arg->ret = vprintf(p_arg->format, p_arg->va);
typedef struct {
int value;
int base;
} putint_arg_t;
typedef struct {
float value;
int precision;
} putfloat_arg_t;
static void __printf_cb(printf_arg_t* arg) {
arg->ret = vprintf(arg->format, *arg->va);
}
int vx_vprintf(const char* format, va_list va) {
// need to execute 'vprintf' single-threaded due to potential thread-data dependency
struct printf_arg_t arg;
printf_arg_t arg;
arg.format = format;
arg.va = va;
vx_serial(__printf_callback, &arg);
arg.va = &va;
vx_serial(__printf_cb, &arg);
return arg.ret;
}
@ -38,6 +47,45 @@ int vx_printf(const char * format, ...) {
return ret;
}
static void __putint_cb(const putint_arg_t* arg) {
char tmp[33];
float value = arg->value;
int base = arg->base;
itoa(value, tmp, base);
for (int i = 0; i < 33; ++i) {
int c = tmp[i];
if (!c) break;
vx_putchar(c);
}
}
void vx_putint(int value, int base) {
putint_arg_t arg;
arg.value = value;
arg.base = base;
vx_serial(__putint_cb, &arg);
}
static void __putfloat_cb(const putfloat_arg_t* arg) {
float value = arg->value;
int precision = arg->precision;
int ipart = (int)value;
vx_putint(ipart, 10);
if (precision != 0) {
vx_putchar('.');
float frac = value - (float)ipart;
float fscaled = frac * pow(10, precision);
vx_putint((int)fscaled, 10);
}
}
void vx_putfloat(float value, int precision) {
putfloat_arg_t arg;
arg.value = value;
arg.precision = precision;
vx_serial(__putfloat_cb, &arg);
}
#ifdef __cplusplus
}
#endif

View file

@ -1,3 +1,5 @@
#include <VX_config.h>
.type vx_serial, @function
.global vx_serial
vx_serial:
@ -8,23 +10,22 @@ vx_serial:
sw s2, 8(sp)
sw s1, 4(sp)
sw s0, 0(sp)
mv s4, a0 # callback
mv s3, a1 # arg
csrr s2, 0xfc0 # NT
csrr s1, 0xcc0 # tid
li s0, 0 # index
mv s4, a0 # s4 <- callback
mv s3, a1 # s3 <- arg
csrr s2, CSR_NT # s2 <- NT
csrr s1, CSR_WTID # s1 <- tid
li s0, 0 # s0 <- index
label_loop:
sub t0, s0, s1
snez t0, t0
.insn s 0x6b, 2, x0, 0(t0) # split t0
seqz t1, t0 # (index != tid)
.insn s 0x6b, 2, x0, 0(t1) # split t0
bnez t0, label_join
mv a0, s0 # a0 <- index
mv a1, s3 # a1 <- arg
jalr s4 # callback(index, arg)
mv a0, s3 # a0 <- arg
jalr s4 # callback(arg)
label_join:
.insn s 0x6b, 3, x0, 0(x0) # join
addi s0, s0, 1
blt s0, s2, label_loop
addi s0, s0, 1 # index++
blt s0, s2, label_loop # loop back
lw ra, 20(sp)
lw s4, 16(sp)
lw s3, 12(sp)

View file

@ -20,7 +20,7 @@ typedef struct {
} wspawn_tasks_args_t;
typedef struct {
struct context_t * ctx;
context_t * ctx;
vx_spawn_kernel_cb callback;
void * arg;
int offset;
@ -44,10 +44,7 @@ inline int fast_log2(int x) {
return (*(int*)(&f)>>23) - 127;
}
static void spawn_tasks_callback() {
// activate all threads
vx_tmc(-1);
static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
int core_id = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
@ -65,15 +62,9 @@ static void spawn_tasks_callback() {
// wait for all warps to complete
vx_barrier(0, p_wspawn_args->NW);
// set warp0 to single-threaded and stop other warps
vx_tmc(0 == wid);
}
void spawn_remaining_tasks_callback(int thread_mask) {
// activate threads
vx_tmc(thread_mask);
static void __attribute__ ((noinline)) spawn_tasks_rem_stub() {
int core_id = vx_core_id();
int tid = vx_thread_gid();
@ -81,6 +72,26 @@ void spawn_remaining_tasks_callback(int thread_mask) {
int task_id = p_wspawn_args->offset + tid;
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
}
static void spawn_tasks_all_cb() {
// activate all threads
vx_tmc(-1);
// call stub routine
spawn_tasks_all_stub();
// set warp0 to single-threaded and stop other warps
int wid = vx_warp_id();
vx_tmc(0 == wid);
}
static void spawn_tasks_rem_cb(int thread_mask) {
// activate threads
vx_tmc(thread_mask);
// call stub routine
spawn_tasks_rem_stub();
// back to single-threaded
vx_tmc(1);
@ -128,24 +139,21 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
if (nW >= 1) {
int nw = MIN(nW, NW);
wspawn_args.NW = nw;
vx_wspawn(nw, spawn_tasks_callback);
spawn_tasks_callback();
vx_wspawn(nw, spawn_tasks_all_cb);
spawn_tasks_all_cb();
}
//--
if (rT != 0) {
wspawn_args.offset = tasks_per_core0 - rT;
int tmask = (1 << rT) - 1;
spawn_remaining_tasks_callback(tmask);
spawn_tasks_rem_cb(tmask);
}
}
///////////////////////////////////////////////////////////////////////////////
static void spawn_kernel_callback() {
// activate all threads
vx_tmc(-1);
static void __attribute__ ((noinline)) spawn_kernel_all_stub() {
int core_id = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
@ -176,15 +184,9 @@ static void spawn_kernel_callback() {
// wait for all warps to complete
vx_barrier(0, p_wspawn_args->NW);
// set warp0 to single-threaded and stop other warps
vx_tmc(0 == wid);
}
static void spawn_kernel_remaining_callback(int thread_mask) {
// activate threads
vx_tmc(thread_mask);
static void __attribute__ ((noinline)) spawn_kernel_rem_stub() {
int core_id = vx_core_id();
int tid = vx_thread_gid();
@ -206,12 +208,32 @@ static void spawn_kernel_remaining_callback(int thread_mask) {
int gid2 = p_wspawn_args->ctx->global_offset[2] + k;
(p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2);
}
static void spawn_kernel_all_cb() {
// activate all threads
vx_tmc(-1);
// call stub routine
spawn_kernel_all_stub();
// set warp0 to single-threaded and stop other warps
int wid = vx_warp_id();
vx_tmc(0 == wid);
}
static void spawn_kernel_rem_cb(int thread_mask) {
// activate threads
vx_tmc(thread_mask);
// call stub routine
spawn_kernel_rem_stub();
// back to single-threaded
vx_tmc(1);
}
void vx_spawn_kernel(struct context_t * ctx, vx_spawn_kernel_cb callback, void * arg) {
void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg) {
// total number of WGs
int X = ctx->num_groups[0];
int Y = ctx->num_groups[1];
@ -268,15 +290,15 @@ void vx_spawn_kernel(struct context_t * ctx, vx_spawn_kernel_cb callback, void *
if (nW >= 1) {
int nw = MIN(nW, NW);
wspawn_args.NW = nw;
vx_wspawn(nw, spawn_kernel_callback);
spawn_kernel_callback();
vx_wspawn(nw, spawn_kernel_all_cb);
spawn_kernel_all_cb();
}
//--
if (rT != 0) {
wspawn_args.offset = wgs_per_core0 - rT;
int tmask = (1 << rT) - 1;
spawn_kernel_remaining_callback(tmask);
spawn_kernel_rem_cb(tmask);
}
}