minor update

This commit is contained in:
Blaise Tine 2024-06-08 21:36:28 -07:00
parent e38187acb5
commit 8b63305201
3 changed files with 17 additions and 17 deletions

View file

@@ -21,13 +21,13 @@
extern "C" {
#endif
typedef void (*vx_spawn_tasks_cb)(uint32_t task_id, const void *arg);
typedef void (*vx_spawn_tasks_cb)(uint32_t task_id, void *arg);
typedef void (*vx_serial_cb)(const void *arg);
typedef void (*vx_serial_cb)(void *arg);
void vx_spawn_tasks(uint32_t num_tasks, vx_spawn_tasks_cb callback, const void * arg);
void vx_serial(vx_serial_cb callback, void * arg);
void vx_serial(vx_serial_cb callback, const void * arg);
///////////////////////////////////////////////////////////////////////////////
@@ -49,7 +49,7 @@ extern __thread uint32_t __local_group_id;
extern uint32_t __groups_per_core;
extern uint32_t __warps_per_group;
typedef void (*vx_kernel_func_cb)(const void *arg);
typedef void (*vx_kernel_func_cb)(void *arg);
#define __local_mem(size) \
(void*)((int8_t*)csr_read(VX_CSR_LOCAL_MEM_BASE) + __local_group_id * size)

View file

@@ -58,13 +58,13 @@ static void __putfloat_cb(const putfloat_arg_t* arg) {
float value = arg->value;
int precision = arg->precision;
int ipart = (int)value;
vx_putint(ipart, 10);
if (precision != 0) {
vx_putchar('.');
vx_putint(ipart, 10);
if (precision != 0) {
vx_putchar('.');
float frac = value - (float)ipart;
float fscaled = frac * pow(10, precision);
vx_putint((int)fscaled, 10);
}
float fscaled = frac * pow(10, precision);
vx_putint((int)fscaled, 10);
}
}
static void __vprintf_cb(printf_arg_t* arg) {
@@ -90,7 +90,7 @@ int vx_vprintf(const char* format, va_list va) {
arg.format = format;
arg.va = &va;
vx_serial((vx_serial_cb)__vprintf_cb, &arg);
return arg.ret;
return arg.ret;
}
int vx_printf(const char * format, ...) {
@@ -99,7 +99,7 @@ int vx_printf(const char * format, ...) {
va_start(va, format);
ret = vx_vprintf(format, va);
va_end(va);
return ret;
return ret;
}
#ifdef __cplusplus

View file

@@ -49,7 +49,7 @@ static void __attribute__ ((noinline)) process_all_tasks() {
vx_spawn_tasks_cb callback = targs->callback;
const void* arg = targs->arg;
for (uint32_t task_id = start_task_id; task_id < end_task_id; task_id += threads_per_warp) {
callback(task_id, arg);
callback(task_id, (void*)arg);
}
}
@@ -59,7 +59,7 @@ static void __attribute__ ((noinline)) process_remaining_tasks() {
uint32_t thread_id = vx_thread_id();
uint32_t task_id = targs->remain_tasks_offset + thread_id;
(targs->callback)(task_id, targs->arg);
(targs->callback)(task_id, (void*)targs->arg);
}
static void __attribute__ ((noinline)) process_all_tasks_stub() {
@@ -205,7 +205,7 @@ static void __attribute__ ((noinline)) process_all_task_groups() {
blockIdx.x = group_id % gridDim.x;
blockIdx.y = (group_id / gridDim.x) % gridDim.y;
blockIdx.z = group_id / (gridDim.x * gridDim.y);
callback(arg);
callback((void*)arg);
}
}
@@ -236,8 +236,8 @@ int vx_spawn_threads(uint32_t dimension,
uint32_t num_groups = 1;
uint32_t group_size = 1;
for (uint32_t i = 0; i < 3; ++i) {
uint32_t gd = (i < dimension) ? grid_dim[i] : 1;
uint32_t bd = (i < dimension) ? block_dim[i] : 1;
uint32_t gd = (grid_dim && (i < dimension)) ? grid_dim[i] : 1;
uint32_t bd = (block_dim && (i < dimension)) ? block_dim[i] : 1;
num_groups *= gd;
group_size *= bd;
gridDim.m[i] = gd;