mirror of
https://github.com/vortexgpgpu/vortex.git
synced 2025-04-23 21:39:10 -04:00
minor update
This commit is contained in:
parent
df95c7c4c6
commit
82a417f1f0
4 changed files with 21 additions and 19 deletions
|
@ -11,6 +11,9 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// The intrinsics implemented use RISC-V assembler pseudo-directives defined here:
|
||||
// https://sourceware.org/binutils/docs/as/RISC_002dV_002dFormats.html
|
||||
|
||||
#ifndef __VX_INTRINSICS_H__
|
||||
#define __VX_INTRINSICS_H__
|
||||
|
||||
|
@ -126,9 +129,8 @@ inline void vx_pred_n(int condition, int thread_mask) {
|
|||
asm volatile (".insn r %0, 5, 0, x1, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(condition), "r"(thread_mask));
|
||||
}
|
||||
|
||||
typedef void (*vx_wspawn_pfn)();
|
||||
|
||||
// Spawn warps
|
||||
typedef void (*vx_wspawn_pfn)();
|
||||
inline void vx_wspawn(size_t num_warps, vx_wspawn_pfn func_ptr) {
|
||||
asm volatile (".insn r %0, 1, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(num_warps), "r"(func_ptr));
|
||||
}
|
||||
|
|
|
@ -23,13 +23,13 @@ extern "C" {
|
|||
|
||||
typedef void (*vx_spawn_tasks_cb)(int task_id, void *arg);
|
||||
|
||||
typedef void (*vx_spawn_tasks_ex_cb)(int local_task_id, int group_id, int local_group_id, int warps_per_group, void *arg);
|
||||
typedef void (*vx_spawn_task_groups_cb)(int local_task_id, int group_id, int local_group_id, int warps_per_group, void *arg);
|
||||
|
||||
typedef void (*vx_serial_cb)(void *arg);
|
||||
|
||||
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
|
||||
|
||||
void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb callback, void * arg);
|
||||
void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_cb callback, void * arg);
|
||||
|
||||
void vx_serial(vx_serial_cb callback, void * arg);
|
||||
|
||||
|
|
|
@ -153,7 +153,7 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
|
|||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
typedef struct {
|
||||
vx_spawn_tasks_ex_cb callback;
|
||||
vx_spawn_task_groups_cb callback;
|
||||
void* arg;
|
||||
int group_offset;
|
||||
int warp_batches;
|
||||
|
@ -161,10 +161,10 @@ typedef struct {
|
|||
int warps_per_group;
|
||||
int groups_per_core;
|
||||
int remaining_mask;
|
||||
} wspawn_tasks_ex_args_t;
|
||||
} wspawn_task_groups_args_t;
|
||||
|
||||
static void __attribute__ ((noinline)) process_all_tasks_ex() {
|
||||
wspawn_tasks_ex_args_t* targs = (wspawn_tasks_ex_args_t*)csr_read(VX_CSR_MSCRATCH);
|
||||
static void __attribute__ ((noinline)) process_all_task_groups() {
|
||||
wspawn_task_groups_args_t* targs = (wspawn_task_groups_args_t*)csr_read(VX_CSR_MSCRATCH);
|
||||
|
||||
int warps_per_group = targs->warps_per_group;
|
||||
int groups_per_core = targs->groups_per_core;
|
||||
|
@ -182,7 +182,7 @@ static void __attribute__ ((noinline)) process_all_tasks_ex() {
|
|||
int start_group = targs->group_offset + local_group_id;
|
||||
int end_group = start_group + iterations * groups_per_core;
|
||||
|
||||
vx_spawn_tasks_ex_cb callback = targs->callback;
|
||||
vx_spawn_task_groups_cb callback = targs->callback;
|
||||
void* arg = targs->arg;
|
||||
|
||||
for (int group_id = start_group; group_id < end_group; group_id += groups_per_core) {
|
||||
|
@ -190,8 +190,8 @@ static void __attribute__ ((noinline)) process_all_tasks_ex() {
|
|||
}
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) process_all_tasks_ex_stub() {
|
||||
wspawn_tasks_ex_args_t* targs = (wspawn_tasks_ex_args_t*)csr_read(VX_CSR_MSCRATCH);
|
||||
static void __attribute__ ((noinline)) process_all_task_groups_stub() {
|
||||
wspawn_task_groups_args_t* targs = (wspawn_task_groups_args_t*)csr_read(VX_CSR_MSCRATCH);
|
||||
int warps_per_group = targs->warps_per_group;
|
||||
int remaining_mask = targs->remaining_mask;
|
||||
int warp_id = vx_warp_id();
|
||||
|
@ -202,19 +202,19 @@ static void __attribute__ ((noinline)) process_all_tasks_ex_stub() {
|
|||
vx_tmc(threads_mask);
|
||||
|
||||
// process all tasks
|
||||
process_all_tasks_ex();
|
||||
process_all_task_groups();
|
||||
|
||||
// disable all warps except warp0
|
||||
vx_tmc(0 == warp_id);
|
||||
}
|
||||
|
||||
void vx_syncthreads(int barrier_id) {
|
||||
wspawn_tasks_ex_args_t* targs = (wspawn_tasks_ex_args_t*)csr_read(VX_CSR_MSCRATCH);
|
||||
wspawn_task_groups_args_t* targs = (wspawn_task_groups_args_t*)csr_read(VX_CSR_MSCRATCH);
|
||||
int warps_per_group = targs->warps_per_group;
|
||||
vx_barrier(barrier_id, warps_per_group);
|
||||
}
|
||||
|
||||
void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb callback, void * arg) {
|
||||
void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_cb callback, void * arg) {
|
||||
// device specifications
|
||||
int num_cores = vx_num_cores();
|
||||
int warps_per_core = vx_num_warps();
|
||||
|
@ -264,7 +264,7 @@ void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb call
|
|||
int group_offset = core_id * total_groups_per_core + MIN(core_id, remaining_groups_per_core);
|
||||
|
||||
// prepare scheduler arguments
|
||||
wspawn_tasks_ex_args_t wspawn_args = {
|
||||
wspawn_task_groups_args_t wspawn_args = {
|
||||
callback,
|
||||
arg,
|
||||
group_offset,
|
||||
|
@ -277,10 +277,10 @@ void vx_spawn_tasks_ex(int num_groups, int group_size, vx_spawn_tasks_ex_cb call
|
|||
csr_write(VX_CSR_MSCRATCH, &wspawn_args);
|
||||
|
||||
// execute callback on other warps
|
||||
vx_wspawn(active_warps, process_all_tasks_ex_stub);
|
||||
vx_wspawn(active_warps, process_all_task_groups_stub);
|
||||
|
||||
// execute callback on warp0
|
||||
process_all_tasks_ex_stub();
|
||||
process_all_task_groups_stub();
|
||||
|
||||
// wait for spawned tasks to complete
|
||||
vx_wspawn(1, 0);
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include <vx_print.h>
|
||||
#include "common.h"
|
||||
|
||||
void sgemm_kernel(int local_task_id, int group_id, int local_group_id, int warps_per_group, kernel_arg_t *arg) {
|
||||
void kernel_body(int local_task_id, int group_id, int local_group_id, int warps_per_group, kernel_arg_t *arg) {
|
||||
auto local_ptr = reinterpret_cast<TYPE*>(arg->local_addr);
|
||||
auto A_ptr = reinterpret_cast<TYPE*>(arg->A_addr);
|
||||
auto B_ptr = reinterpret_cast<TYPE*>(arg->B_addr);
|
||||
|
@ -53,6 +53,6 @@ void sgemm_kernel(int local_task_id, int group_id, int local_group_id, int warps
|
|||
|
||||
int main() {
|
||||
kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH);
|
||||
vx_spawn_tasks_ex(arg->num_groups, arg->group_size, (vx_spawn_tasks_ex_cb)sgemm_kernel, arg);
|
||||
vx_spawn_task_groups(arg->num_groups, arg->group_size, (vx_spawn_task_groups_cb)kernel_body, arg);
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue