minor update: rename vx_spawn_tasks_ex to vx_spawn_task_groups

This commit is contained in:
Blaise Tine 2024-05-09 22:59:47 -07:00
parent df95c7c4c6
commit 82a417f1f0
4 changed files with 21 additions and 19 deletions

View file

@ -11,6 +11,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// The intrinsics implemented use RISC-V assembler pseudo-directives defined here:
// https://sourceware.org/binutils/docs/as/RISC_002dV_002dFormats.html
#ifndef __VX_INTRINSICS_H__
#define __VX_INTRINSICS_H__
@ -126,9 +129,8 @@ inline void vx_pred_n(int condition, int thread_mask) {
asm volatile (".insn r %0, 5, 0, x1, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(condition), "r"(thread_mask));
}
typedef void (*vx_wspawn_pfn)();
// Spawn warps
typedef void (*vx_wspawn_pfn)();
// Emits one custom R-type instruction via the assembler's `.insn r` directive:
// opcode RISCV_CUSTOM0, funct3 = 1, funct7 = 0, rd = x0 (no result register),
// rs1 = num_warps, rs2 = func_ptr.
// NOTE(review): presumably the hardware starts `num_warps` warps executing at
// `func_ptr` - confirm against the ISA extension specification.
inline void vx_wspawn(size_t num_warps, vx_wspawn_pfn func_ptr) {
asm volatile (".insn r %0, 1, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(num_warps), "r"(func_ptr));
}

View file

@ -23,13 +23,13 @@ extern "C" {
// Callback signatures accepted by the spawn/serial entry points declared below.
// The grouped variants receive the task's local id, its group id, the group's
// local id and the number of warps per group in addition to `arg`.
// NOTE(review): `arg` is presumably the pointer given to the matching entry
// point, forwarded unchanged - confirm in the implementation.
typedef void (*vx_spawn_tasks_cb)(int task_id, void *arg);
typedef void (*vx_spawn_tasks_ex_cb)(int local_task_id, int group_id, int local_group_id, int warps_per_group, void *arg);
typedef void (*vx_spawn_task_groups_cb)(int local_task_id, int group_id, int local_group_id, int warps_per_group, void *arg);
typedef void (*vx_serial_cb)(void *arg);
// Entry points: each takes the callback to run plus an opaque `arg` pointer.
// The grouped variants additionally take the number of groups and the group size.
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb callback, void * arg);
void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_cb callback, void * arg);
void vx_serial(vx_serial_cb callback, void * arg);

View file

@ -153,7 +153,7 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
///////////////////////////////////////////////////////////////////////////////
typedef struct {
vx_spawn_tasks_ex_cb callback;
vx_spawn_task_groups_cb callback;
void* arg;
int group_offset;
int warp_batches;
@ -161,10 +161,10 @@ typedef struct {
int warps_per_group;
int groups_per_core;
int remaining_mask;
} wspawn_tasks_ex_args_t;
} wspawn_task_groups_args_t;
static void __attribute__ ((noinline)) process_all_tasks_ex() {
wspawn_tasks_ex_args_t* targs = (wspawn_tasks_ex_args_t*)csr_read(VX_CSR_MSCRATCH);
static void __attribute__ ((noinline)) process_all_task_groups() {
wspawn_task_groups_args_t* targs = (wspawn_task_groups_args_t*)csr_read(VX_CSR_MSCRATCH);
int warps_per_group = targs->warps_per_group;
int groups_per_core = targs->groups_per_core;
@ -182,7 +182,7 @@ static void __attribute__ ((noinline)) process_all_tasks_ex() {
int start_group = targs->group_offset + local_group_id;
int end_group = start_group + iterations * groups_per_core;
vx_spawn_tasks_ex_cb callback = targs->callback;
vx_spawn_task_groups_cb callback = targs->callback;
void* arg = targs->arg;
for (int group_id = start_group; group_id < end_group; group_id += groups_per_core) {
@ -190,8 +190,8 @@ static void __attribute__ ((noinline)) process_all_tasks_ex() {
}
}
static void __attribute__ ((noinline)) process_all_tasks_ex_stub() {
wspawn_tasks_ex_args_t* targs = (wspawn_tasks_ex_args_t*)csr_read(VX_CSR_MSCRATCH);
static void __attribute__ ((noinline)) process_all_task_groups_stub() {
wspawn_task_groups_args_t* targs = (wspawn_task_groups_args_t*)csr_read(VX_CSR_MSCRATCH);
int warps_per_group = targs->warps_per_group;
int remaining_mask = targs->remaining_mask;
int warp_id = vx_warp_id();
@ -202,19 +202,19 @@ static void __attribute__ ((noinline)) process_all_tasks_ex_stub() {
vx_tmc(threads_mask);
// process all tasks
process_all_tasks_ex();
process_all_task_groups();
// disable all warps except warp0
vx_tmc(0 == warp_id);
}
// Reads the scheduler-arguments pointer from VX_CSR_MSCRATCH, takes its
// warps_per_group field, and calls vx_barrier(barrier_id, warps_per_group).
// NOTE(review): relies on mscratch still holding the address of the scheduler
// arguments written by the spawn routine - presumably only valid when called
// from inside a spawned callback; confirm.
void vx_syncthreads(int barrier_id) {
wspawn_tasks_ex_args_t* targs = (wspawn_tasks_ex_args_t*)csr_read(VX_CSR_MSCRATCH);
wspawn_task_groups_args_t* targs = (wspawn_task_groups_args_t*)csr_read(VX_CSR_MSCRATCH);
int warps_per_group = targs->warps_per_group;
vx_barrier(barrier_id, warps_per_group);
}
void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb callback, void * arg) {
void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_cb callback, void * arg) {
// device specifications
int num_cores = vx_num_cores();
int warps_per_core = vx_num_warps();
@ -264,7 +264,7 @@ void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb call
int group_offset = core_id * total_groups_per_core + MIN(core_id, remaining_groups_per_core);
// prepare scheduler arguments
wspawn_tasks_ex_args_t wspawn_args = {
wspawn_task_groups_args_t wspawn_args = {
callback,
arg,
group_offset,
@ -277,10 +277,10 @@ void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb call
csr_write(VX_CSR_MSCRATCH, &wspawn_args);
// execute callback on other warps
vx_wspawn(active_warps, process_all_tasks_ex_stub);
vx_wspawn(active_warps, process_all_task_groups_stub);
// execute callback on warp0
process_all_tasks_ex_stub();
process_all_task_groups_stub();
// wait for spawned tasks to complete
vx_wspawn(1, 0);

View file

@ -4,7 +4,7 @@
#include <vx_print.h>
#include "common.h"
void sgemm_kernel(int local_task_id, int group_id, int local_group_id, int warps_per_group, kernel_arg_t *arg) {
void kernel_body(int local_task_id, int group_id, int local_group_id, int warps_per_group, kernel_arg_t *arg) {
auto local_ptr = reinterpret_cast<TYPE*>(arg->local_addr);
auto A_ptr = reinterpret_cast<TYPE*>(arg->A_addr);
auto B_ptr = reinterpret_cast<TYPE*>(arg->B_addr);
@ -53,6 +53,6 @@ void sgemm_kernel(int local_task_id, int group_id, int local_group_id, int warps
// Kernel entry point: fetches the kernel_arg_t pointer from VX_CSR_MSCRATCH,
// hands the kernel function to the group-spawn routine with arg->num_groups
// and arg->group_size, then returns 0.
int main() {
kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH);
// The kernel's last parameter is a typed kernel_arg_t* while the callback type
// expects void*, hence the explicit cast to the callback pointer type.
vx_spawn_tasks_ex(arg->num_groups, arg->group_size, (vx_spawn_tasks_ex_cb)sgemm_kernel, arg);
vx_spawn_task_groups(arg->num_groups, arg->group_size, (vx_spawn_task_groups_cb)kernel_body, arg);
return 0;
}