simx tensor 844

tinebp 2025-05-22 05:00:39 -07:00
parent 84a4ede9c9
commit 06ad9197c3
22 changed files with 802 additions and 135 deletions

View file

@ -36,6 +36,25 @@ extern "C" {
#define RISCV_CUSTOM2 0x5B
#define RISCV_CUSTOM3 0x7B
#define RISCV_INSN_R(opcode7, funct3, funct7, rd, rs1, rs2) ( \
((funct7 & 0x7F) << 25) | \
((rs2 & 0x1F) << 20) | \
((rs1 & 0x1F) << 15) | \
((funct3 & 0x7) << 12) | \
((rd & 0x1F) << 7) | \
(opcode7 & 0x7F) \
)
#define RISCV_INSN_R4(opcode7, funct3, funct2, rd, rs1, rs2, rs3) ( \
((rs3 & 0x1F) << 27) | \
((funct2 & 0x3) << 25) | \
((rs2 & 0x1F) << 20) | \
((rs1 & 0x1F) << 15) | \
((funct3 & 0x7) << 12) | \
((rd & 0x1F) << 7) | \
(opcode7 & 0x7F) \
)
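For reference, a worked example of what the R-type encoder above produces, with field values chosen arbitrarily and assuming RISCV_CUSTOM0 is the standard custom-0 opcode 0x0B (illustration only, not part of the header):
// RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, /*rd=*/0, /*rs1=*/1, /*rs2=*/1)
//   = (2 << 25) | (1 << 20) | (1 << 15) | (0 << 12) | (0 << 7) | 0x0B
//   = 0x0410800B
// i.e. funct7 occupies bits [31:25], rs2 [24:20], rs1 [19:15],
//      funct3 [14:12], rd [11:7], and the 7-bit opcode [6:0].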
#define csr_read(csr) ({ \
size_t __r; \
__asm__ __volatile__ ("csrr %0, %1" : "=r" (__r) : "i" (csr) : "memory"); \
@ -221,6 +240,233 @@ inline void vx_fence() {
__asm__ volatile ("fence iorw, iorw");
}
typedef float mf32x8_t __attribute__((vector_size(8*4))); // 8 x f32 registers
#define MAKE_VX_WSETM_F32(f0, f1, f2, f3, f4, f5, f6, f7) \
mf32x8_t ret; \
register float fd0 __asm__(f0); \
register float fd1 __asm__(f1); \
register float fd2 __asm__(f2); \
register float fd3 __asm__(f3); \
register float fd4 __asm__(f4); \
register float fd5 __asm__(f5); \
register float fd6 __asm__(f6); \
register float fd7 __asm__(f7); \
__asm__ volatile("fmv.w.x %0, %1" : "=f"(fd0): "r"(value)); \
__asm__ volatile("fmv.w.x %0, %1" : "=f"(fd1): "r"(value)); \
__asm__ volatile("fmv.w.x %0, %1" : "=f"(fd2): "r"(value)); \
__asm__ volatile("fmv.w.x %0, %1" : "=f"(fd3): "r"(value)); \
__asm__ volatile("fmv.w.x %0, %1" : "=f"(fd4): "r"(value)); \
__asm__ volatile("fmv.w.x %0, %1" : "=f"(fd5): "r"(value)); \
__asm__ volatile("fmv.w.x %0, %1" : "=f"(fd6): "r"(value)); \
__asm__ volatile("fmv.w.x %0, %1" : "=f"(fd7): "r"(value)); \
ret = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; \
return ret
__attribute__((always_inline)) mf32x8_t vx_wsetm_a_f32(size_t value) {
MAKE_VX_WSETM_F32("f8", "f9", "f10", "f11", "f12", "f13", "f14", "f15");
}
__attribute__((always_inline)) mf32x8_t vx_wsetm_b_f32(size_t value) {
MAKE_VX_WSETM_F32("f24", "f25", "f26", "f27", "f28", "f29", "f30", "f31");
}
__attribute__((always_inline)) mf32x8_t vx_wsetm_c_f32(size_t value) {
MAKE_VX_WSETM_F32("f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7");
}
__attribute__((always_inline)) mf32x8_t vx_wsetm_d_f32(size_t value) {
MAKE_VX_WSETM_F32("f16", "f17", "f18", "f19", "f20", "f21", "f22", "f23");
}
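Taken together, these setters pin each matrix operand to a fixed FP register bank; a summary of the convention implied by the functions above (informational comment, not an authoritative spec):
// Operand-to-register-bank convention used throughout this header:
//   matrix C (accumulator input)  -> f0  .. f7
//   matrix A (row operand)        -> f8  .. f15
//   matrix D (accumulator output) -> f16 .. f23
//   matrix B (column operand)     -> f24 .. f31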
#define MAKE_VX_WLDM_D_F32(f0, f1, f2, f3, f4, f5, f6, f7) \
mf32x8_t ret; \
auto base = (const float*)src + row * ldm; \
register float fd0 __asm__(f0); \
register float fd1 __asm__(f1); \
register float fd2 __asm__(f2); \
register float fd3 __asm__(f3); \
register float fd4 __asm__(f4); \
register float fd5 __asm__(f5); \
register float fd6 __asm__(f6); \
register float fd7 __asm__(f7); \
__asm__ volatile ("flw %0, %1" : "=f"(fd0) : "m"(base[0])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd1) : "m"(base[1])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd2) : "m"(base[2])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd3) : "m"(base[3])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd4) : "m"(base[4])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd5) : "m"(base[5])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd6) : "m"(base[6])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd7) : "m"(base[7])); \
ret = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; \
return ret
#define MAKE_VX_WLDM_T_F32(f0, f1, f2, f3, f4, f5, f6, f7) \
mf32x8_t ret; \
auto base = (const float*)src + col; \
register float fd0 __asm__(f0); \
register float fd1 __asm__(f1); \
register float fd2 __asm__(f2); \
register float fd3 __asm__(f3); \
register float fd4 __asm__(f4); \
register float fd5 __asm__(f5); \
register float fd6 __asm__(f6); \
register float fd7 __asm__(f7); \
__asm__ volatile ("flw %0, %1" : "=f"(fd0) : "m"(base[0 * ldm])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd1) : "m"(base[1 * ldm])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd2) : "m"(base[2 * ldm])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd3) : "m"(base[3 * ldm])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd4) : "m"(base[4 * ldm])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd5) : "m"(base[5 * ldm])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd6) : "m"(base[6 * ldm])); \
__asm__ volatile ("flw %0, %1" : "=f"(fd7) : "m"(base[7 * ldm])); \
ret = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; \
return ret
__attribute__((always_inline)) mf32x8_t vx_wldm_ad_f32(const void* src, int row, size_t ldm) {
MAKE_VX_WLDM_D_F32("f8", "f9", "f10", "f11", "f12", "f13", "f14", "f15");
}
__attribute__((always_inline)) mf32x8_t vx_wldm_at_f32(const void* src, int col, size_t ldm) {
MAKE_VX_WLDM_T_F32("f8", "f9", "f10", "f11", "f12", "f13", "f14", "f15");
}
__attribute__((always_inline)) mf32x8_t vx_wldm_bd_f32(const void* src, int row, size_t ldm) {
MAKE_VX_WLDM_D_F32("f24", "f25", "f26", "f27", "f28", "f29", "f30", "f31");
}
__attribute__((always_inline)) mf32x8_t vx_wldm_bt_f32(const void* src, int col, size_t ldm) {
MAKE_VX_WLDM_T_F32("f24", "f25", "f26", "f27", "f28", "f29", "f30", "f31");
}
__attribute__((always_inline)) void vx_wstm_f32(void* dst, const mf32x8_t& src, int row, int col, size_t ldm) {
auto base = (float*)dst + row * ldm + col;
auto base_2row = base + 2 * ldm;
__asm__ volatile("fsw %0, %1" ::"f"(src[0]), "m"(base[0]));
__asm__ volatile("fsw %0, %1" ::"f"(src[1]), "m"(base[1]));
__asm__ volatile("fsw %0, %1" ::"f"(src[2]), "m"(base_2row[0]));
__asm__ volatile("fsw %0, %1" ::"f"(src[3]), "m"(base_2row[1]));
__asm__ volatile("fsw %0, %1" ::"f"(src[4]), "m"(base[4]));
__asm__ volatile("fsw %0, %1" ::"f"(src[5]), "m"(base[5]));
__asm__ volatile("fsw %0, %1" ::"f"(src[6]), "m"(base_2row[4]));
__asm__ volatile("fsw %0, %1" ::"f"(src[7]), "m"(base_2row[5]));
}
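For clarity, the element placement performed by vx_wstm_f32 above, relative to base = (float*)dst + row * ldm + col (derived directly from the stores; illustration only):
// src[0] -> (row+0, col+0)   src[4] -> (row+0, col+4)
// src[1] -> (row+0, col+1)   src[5] -> (row+0, col+5)
// src[2] -> (row+2, col+0)   src[6] -> (row+2, col+4)
// src[3] -> (row+2, col+1)   src[7] -> (row+2, col+5)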
#define MAKE_VX_HMMA_844_D_F32_STEP(fmt, step, rd_lo, rd_hi, rs1, rs2, rs3_lo, rs3_hi) \
__asm__ volatile (".word %1" : "=r"(rd_lo) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 0, 1)), "r"(rs1), "r"(rs2), "r"(rs3_lo)); \
__asm__ volatile (".word %1" : "=r"(rd_hi) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 1, 1)), "r"(rs1), "r"(rs2), "r"(rs3_hi))
#define MAKE_VX_HMMA_844_D_F32(fmt) \
mf32x8_t ret; \
register float fd0 __asm__("f16"); \
register float fd1 __asm__("f17"); \
register float fd2 __asm__("f18"); \
register float fd3 __asm__("f19"); \
register float fd4 __asm__("f20"); \
register float fd5 __asm__("f21"); \
register float fd6 __asm__("f22"); \
register float fd7 __asm__("f23"); \
register float fa0 __asm__("f8") = a[0]; \
register float fa1 __asm__("f9" ) = a[1]; \
register float fa2 __asm__("f10") = a[2]; \
register float fa3 __asm__("f11") = a[3]; \
register float fa4 __asm__("f12") = a[4]; \
register float fa5 __asm__("f13") = a[5]; \
register float fa6 __asm__("f14") = a[6]; \
register float fa7 __asm__("f15") = a[7]; \
register float fb0 __asm__("f24") = b[0]; \
register float fb1 __asm__("f25") = b[1]; \
register float fb2 __asm__("f26") = b[2]; \
register float fb3 __asm__("f27") = b[3]; \
register float fb4 __asm__("f28") = b[4]; \
register float fb5 __asm__("f29") = b[5]; \
register float fb6 __asm__("f30") = b[6]; \
register float fb7 __asm__("f31") = b[7]; \
register float fc0 __asm__("f0") = c[0]; \
register float fc1 __asm__("f1") = c[1]; \
register float fc2 __asm__("f2") = c[2]; \
register float fc3 __asm__("f3") = c[3]; \
register float fc4 __asm__("f4") = c[4]; \
register float fc5 __asm__("f5") = c[5]; \
register float fc6 __asm__("f6") = c[6]; \
register float fc7 __asm__("f7") = c[7]; \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 0, fd0, fd1, fa0, fb0, fc0, fc1); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 1, fd2, fd3, fa0, fb0, fc2, fc3); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 2, fd4, fd5, fa0, fb0, fc4, fc5); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 3, fd6, fd7, fa0, fb0, fc6, fc7); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 4, fd0, fd1, fa1, fb1, fc0, fc1); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 5, fd2, fd3, fa1, fb1, fc2, fc3); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 6, fd4, fd5, fa1, fb1, fc4, fc5); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 7, fd6, fd7, fa1, fb1, fc6, fc7); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 8, fd0, fd1, fa2, fb2, fc0, fc1); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 9, fd2, fd3, fa2, fb2, fc2, fc3); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 10, fd4, fd5, fa2, fb2, fc4, fc5); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 11, fd6, fd7, fa2, fb2, fc6, fc7); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 12, fd0, fd1, fa3, fb3, fc0, fc1); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 13, fd2, fd3, fa3, fb3, fc2, fc3); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 14, fd4, fd5, fa3, fb3, fc4, fc5); \
MAKE_VX_HMMA_844_D_F32_STEP(fmt, 15, fd6, fd7, fa3, fb3, fc6, fc7); \
ret = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; \
return ret
#define MAKE_VX_HMMA_844_C_F32_STEP(fmt, step, rd_lo, rd_hi, rs1, rs2) \
__asm__ volatile (".word %1" : "=r"(rd_lo) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 0, 0)), "r"(rs1), "r"(rs2), "r"(rd_lo)); \
__asm__ volatile (".word %1" : "=r"(rd_hi) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 1, 0)), "r"(rs1), "r"(rs2), "r"(rd_hi))
#define MAKE_VX_HMMA_844_C_F32(fmt) \
mf32x8_t ret; \
register float fa0 __asm__("f8") = a[0]; \
register float fa1 __asm__("f9" ) = a[1]; \
register float fa2 __asm__("f10") = a[2]; \
register float fa3 __asm__("f11") = a[3]; \
register float fa4 __asm__("f12") = a[4]; \
register float fa5 __asm__("f13") = a[5]; \
register float fa6 __asm__("f14") = a[6]; \
register float fa7 __asm__("f15") = a[7]; \
register float fb0 __asm__("f24") = b[0]; \
register float fb1 __asm__("f25") = b[1]; \
register float fb2 __asm__("f26") = b[2]; \
register float fb3 __asm__("f27") = b[3]; \
register float fb4 __asm__("f28") = b[4]; \
register float fb5 __asm__("f29") = b[5]; \
register float fb6 __asm__("f30") = b[6]; \
register float fb7 __asm__("f31") = b[7]; \
register float fc0 __asm__("f0") = c[0]; \
register float fc1 __asm__("f1") = c[1]; \
register float fc2 __asm__("f2") = c[2]; \
register float fc3 __asm__("f3") = c[3]; \
register float fc4 __asm__("f4") = c[4]; \
register float fc5 __asm__("f5") = c[5]; \
register float fc6 __asm__("f6") = c[6]; \
register float fc7 __asm__("f7") = c[7]; \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 0, fc0, fc1, fa0, fb0); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 1, fc2, fc3, fa0, fb0); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 2, fc4, fc5, fa0, fb0); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 3, fc6, fc7, fa0, fb0); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 4, fc0, fc1, fa1, fb1); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 5, fc2, fc3, fa1, fb1); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 6, fc4, fc5, fa1, fb1); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 7, fc6, fc7, fa1, fb1); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 8, fc0, fc1, fa2, fb2); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 9, fc2, fc3, fa2, fb2); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 10, fc4, fc5, fa2, fb2); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 11, fc6, fc7, fa2, fb2); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 12, fc0, fc1, fa3, fb3); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 13, fc2, fc3, fa3, fb3); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 14, fc4, fc5, fa3, fb3); \
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 15, fc6, fc7, fa3, fb3); \
ret = {fc0, fc1, fc2, fc3, fc4, fc5, fc6, fc7}; \
return ret
__attribute__((always_inline)) mf32x8_t vx_hmma_844_c_f16_f32(const mf32x8_t& a, const mf32x8_t& b, const mf32x8_t& c) {
MAKE_VX_HMMA_844_C_F32(0);
}
__attribute__((always_inline)) mf32x8_t vx_hmma_844_d_f16_f32(const mf32x8_t& a, const mf32x8_t& b, const mf32x8_t& c) {
MAKE_VX_HMMA_844_D_F32(0);
}
#ifdef __cplusplus
}
#endif

View file

@ -16,48 +16,181 @@
#include <stdint.h>
#include <vx_intrinsics.h>
#include <type_traits>
#include <hfloats.h>
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __cplusplus
}
#ifndef NUM_LANES
#define NUM_LANES 8
#endif
namespace tensor {
enum frag_layout_t { row_major, col_major };
enum mem_layout_t { mem_row_major, mem_col_major };
enum frag_use_t { matrix_d, matrix_a, matrix_b, matrix_c };
enum layout_t { row_major, col_major };
template <typename T, frag_layout_t L>
template <frag_use_t U, typename T, layout_t L>
struct fragment {
typedef T DType;
static const frag_layout_t Layout = L;
typedef T VType __attribute__((vector_size(8 * sizeof(void*))));
VType data;
typedef T Type;
static const frag_use_t Use = U;
static const layout_t Layout = L;
mf32x8_t data;
};
template <typename Frag>
void fill_fragment(Frag &frag, size_t value) {
// empty skeleton
__attribute__((always_inline)) void map_operand_ab_32lanes(int tid, int &row, int &col) {
int tg = tid / 4;
// A (row major)
// Figure 7(a) in paper
// row 0~ 3: threadgroups 0 and 2
// row 4~ 7: threadgroups 4 and 6
// row 8~11: threadgroups 1 and 3
// row 12~15: threadgroups 5 and 7
row = tid % 4;
row += (tg * 8) % 16;
row += (tg / 4) * 4;
// B (column major)
// NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the
// corrected mapping:
// col 0~ 3: threadgroups 0 and 1
// col 4~ 7: threadgroups 4 and 5
// col 8~11: threadgroups 2 and 3
// col 12~15: threadgroups 6 and 7
col = tid % 4;
col += ((tg % 4) / 2) * 8;
col += (tg / 4) * 4;
}
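A small standalone sketch that prints the (row, col) pairs produced by this 32-lane mapping, handy for cross-checking against the comments above (illustration only, not part of the header):
#include <cstdio>
int main() {
  for (int tid = 0; tid < 32; ++tid) {
    int tg = tid / 4;
    int row = tid % 4 + (tg * 8) % 16 + (tg / 4) * 4;       // A (row major)
    int col = tid % 4 + ((tg % 4) / 2) * 8 + (tg / 4) * 4;  // B (column major)
    std::printf("tid=%2d -> A row=%2d, B col=%2d\n", tid, row, col);
  }
  return 0;
}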
__attribute__((always_inline)) void map_operand_ab_8lanes(int tid, int &row, int &col) {
int tg = tid / 4;
// A (row major)
// row 0~ 3: threadgroup 0
// row 4~ 7: threadgroup 1
row = tid % 4;
row += tg * 4;
// B (column major)
// col 0~ 3: threadgroup 0
// col 4~ 7: threadgroup 1
col = tid % 4;
col += tg * 4;
}
__attribute__((always_inline)) void map_operand_c_32lanes(int tid, int &row, int &col) {
int tg = tid / 4;
// Figure 7(b), left
col = ((tg % 4) / 2) * 8;
row = (tg * 8) % 16;
row += (tg / 4) * 4;
// Figure 7(b), right
row += (tid % 4) % 2;
col += ((tid % 4) / 2) * 2;
}
__attribute__((always_inline)) void map_operand_c_8lanes(int tid, int &row, int &col) {
int tg = tid / 4;
// Figure 7(b), left
col = 0;
row = tg * 4;
// Figure 7(b), right
row += (tid % 4) % 2;
col += ((tid % 4) / 2) * 2;
}
__attribute__((always_inline)) void map_operand_ab(int tid, int &row, int &col) {
if constexpr (NUM_LANES == 32) {
map_operand_ab_32lanes(tid, row, col);
} else if constexpr (NUM_LANES == 8) {
map_operand_ab_8lanes(tid, row, col);
} else {
static_assert(NUM_LANES == 32 || NUM_LANES == 8, "NUM_LANES must be 8 or 32");
}
}
__attribute__((always_inline)) void map_operand_c(int tid, int &row, int &col) {
if constexpr (NUM_LANES == 32) {
map_operand_c_32lanes(tid, row, col);
} else if constexpr (NUM_LANES == 8) {
map_operand_c_8lanes(tid, row, col);
} else {
static_assert(NUM_LANES == 32 || NUM_LANES == 8, "NUM_LANES must be 8 or 32");
}
}
template <typename Frag>
void load_matrix_sync(Frag &frag, const void *ptr, size_t ld) {
// empty skeleton
__attribute__((always_inline)) void fill_fragment(Frag &dst, size_t value) {
if constexpr (Frag::Use == matrix_d) {
dst.data = vx_wsetm_d_f32(value);
} else if constexpr (Frag::Use == matrix_a) {
dst.data = vx_wsetm_a_f32(value);
} else if constexpr (Frag::Use == matrix_b) {
dst.data = vx_wsetm_b_f32(value);
} else if constexpr (Frag::Use == matrix_c) {
dst.data = vx_wsetm_c_f32(value);
}
}
template <layout_t mem_layout, typename Frag>
__attribute__((always_inline)) void load_matrix_sync(Frag &dst, const void *src, size_t ldm) {
int row, col;
int tid = vx_thread_id();
map_operand_ab(tid, row, col);
if constexpr (Frag::Use == matrix_a) {
if constexpr (Frag::Layout == mem_layout) {
dst.data = vx_wldm_ad_f32(src, row, ldm);
} else {
dst.data = vx_wldm_at_f32(src, col, ldm);
}
} else if constexpr (Frag::Use == matrix_b) {
if constexpr (Frag::Layout == mem_layout) {
dst.data = vx_wldm_bd_f32(src, row, ldm);
} else {
dst.data = vx_wldm_bt_f32(src, col, ldm);
}
} else {
static_assert(false, "Only matrix_a and matrix_b are supported!");
}
}
template <layout_t mem_layout, typename Frag>
__attribute__((always_inline)) void store_matrix_sync(void *dst, const Frag &src, size_t ldm) {
static_assert(Frag::Layout == mem_layout, "fragment layout should match memory!");
int row, col;
int tid = vx_thread_id();
map_operand_c(tid, row, col);
if constexpr (Frag::Use == matrix_c) {
vx_wstm_f32(dst, src.data, row, col, ldm);
} else if constexpr (Frag::Use == matrix_d) {
vx_wstm_f32(dst, src.data, row, col, ldm);
} else {
static_assert(false, "Only matrix_c or matrix_c are supported!");
}
}
// Perform the matrix multiply-accumulate: D = A * B + C
template <typename FragD, typename FragA, typename FragB, typename FragC>
void mma_sync(FragD &D, const FragA &A, const FragB &B, const FragC &C) {
// empty skeleton
}
__attribute__((always_inline)) void mma_sync(FragD &D, const FragA &A, const FragB &B, const FragC &C) {
static_assert(FragA::Use == matrix_a, "A must be matrix_a");
static_assert(FragB::Use == matrix_b, "B must be matrix_b");
static_assert(FragC::Use == matrix_c, "C must be matrix_c");
static_assert(FragD::Use == matrix_d || FragD::Use == matrix_c, "D must be matrix_d or matrix_c");
static_assert(std::is_same_v<typename FragA::Type, typename FragB::Type>, "A and B must have the same type");
static_assert(std::is_same_v<typename FragC::Type, typename FragD::Type>, "C and D must have the same type");
// Store a fragment result back to global memory
template <typename Type, typename Frag>
void store_matrix_sync(void *ptr, const Frag &frag, size_t ld, mem_layout_t layout) {
// empty skeleton
if constexpr (std::is_same_v<typename FragC::Type, float>
&& std::is_same_v<typename FragA::Type, vortex::half_t>) {
if constexpr (FragD::Use == matrix_d) {
D.data = vx_hmma_844_d_f16_f32(A.data, B.data, C.data);
} else {
D.data = vx_hmma_844_c_f16_f32(A.data, B.data, C.data);
}
} else {
static_assert(false, "Unsupported type!");
}
}
} // namespace tensor

sim/common/hfloats.h (new file, 109 lines added)
View file

@ -0,0 +1,109 @@
// Copyright © 2019-2023
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <cmath>
#include <cstring>
// A minimal IEEE 754 half-precision (16-bit) float implementation
// Provides conversion to/from 32-bit float and basic arithmetic operators.
namespace vortex {
struct half_t {
uint16_t bits;
half_t() = default;
// Construct from float
half_t(float f) { bits = float_to_half(f); }
// Convert to float
operator float() const { return half_to_float(bits); }
// Arithmetic operators
friend half_t operator+(half_t a, half_t b) { return half_t((float)a + (float)b); }
friend half_t operator-(half_t a, half_t b) { return half_t((float)a - (float)b); }
friend half_t operator*(half_t a, half_t b) { return half_t((float)a * (float)b); }
friend half_t operator/(half_t a, half_t b) { return half_t((float)a / (float)b); }
private:
static uint16_t float_to_half(float f) {
uint32_t x;
std::memcpy(&x, &f, sizeof(x));
uint32_t sign = (x >> 16) & 0x8000;
uint32_t mant = x & 0x007FFFFF;
uint32_t exp = x & 0x7F800000;
if (exp >= 0x47800000) {
// Inf or NaN
if (mant && exp == 0x7F800000) {
// NaN: return a quiet NaN, preserving a payload bit
return static_cast<uint16_t>(sign | 0x7C00 | 0x0200);
}
// Infinity
return static_cast<uint16_t>(sign | 0x7C00);
}
if (exp <= 0x38000000) {
// Subnormal or zero
if (exp < 0x33000000) {
// Too small: underflows to zero
return static_cast<uint16_t>(sign);
}
// Subnormal
mant |= 0x00800000;
int shift = 126 - (exp >> 23); // scale the 24-bit significand into the 10-bit subnormal range
mant = (mant >> shift) + ((mant >> (shift - 1)) & 1);
return static_cast<uint16_t>(sign | (mant & 0x03FF));
}
// Normalized number
uint16_t h_exp = static_cast<uint16_t>(((exp - 0x38000000) >> 13) & 0x7C00);
uint16_t h_mant = static_cast<uint16_t>(mant >> 13);
return static_cast<uint16_t>(sign | h_exp | h_mant);
}
static float half_to_float(uint16_t h) {
uint32_t sign = (h & 0x8000) << 16;
uint32_t exp = h & 0x7C00;
uint32_t mant = h & 0x03FF;
uint32_t f;
if (exp == 0x7C00) {
// Inf or NaN
f = sign | 0x7F800000 | (mant << 13);
} else if (exp != 0) {
// Normalized
uint32_t e = ((exp >> 10) + 112) << 23;
f = sign | e | (mant << 13);
} else if (mant != 0) {
// Subnormal
mant <<= 1;
int e = -1;
while (!(mant & 0x0400)) {
mant <<= 1;
e--;
}
mant &= 0x03FF;
uint32_t f_e = static_cast<uint32_t>(e + 1 + 112) << 23; // rebias: half bias 15 -> float bias 127
f = sign | f_e | (mant << 13);
} else {
// Zero
f = sign;
}
float result;
std::memcpy(&result, &f, sizeof(result));
return result;
}
};
} // namespace vortex
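A quick usage sketch for half_t (illustration only; assumes this header is reachable as <hfloats.h>, as in the kernels below):
#include <cstdio>
#include <hfloats.h>
int main() {
  vortex::half_t a(1.5f), b(0.25f);
  float sum = a + b;               // arithmetic round-trips through float
  std::printf("sum = %f\n", sum);  // prints 1.750000
  return 0;
}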

View file

@ -24,6 +24,11 @@ SRCS += $(SRC_DIR)/execute.cpp $(SRC_DIR)/func_unit.cpp
SRCS += $(SRC_DIR)/cache_sim.cpp $(SRC_DIR)/mem_sim.cpp $(SRC_DIR)/local_mem.cpp $(SRC_DIR)/mem_coalescer.cpp
SRCS += $(SRC_DIR)/dcrs.cpp $(SRC_DIR)/types.cpp
# Add TPU extension sources
ifneq ($(findstring -DEXT_TPU_ENABLE, $(CONFIGS)),)
SRCS += $(SRC_DIR)/tensor_unit.cpp
endif
# Add V extension sources
ifneq ($(findstring -DEXT_V_ENABLE, $(CONFIGS)),)
SRCS += $(SRC_DIR)/voperands.cpp

View file

@ -39,6 +39,9 @@ Core::Core(const SimContext& ctx,
, core_id_(core_id)
, socket_(socket)
, arch_(arch)
#ifdef EXT_TPU_ENABLE
, tensor_unit_(TensorUnit::Create("tpu", arch, this))
#endif
#ifdef EXT_V_ENABLE
, vec_unit_(VecUnit::Create("vpu", arch, this))
#endif
@ -136,6 +139,9 @@ Core::Core(const SimContext& ctx,
dispatchers_.at((int)FUType::FPU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_FPU_BLOCKS, NUM_FPU_LANES);
dispatchers_.at((int)FUType::LSU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_LSU_BLOCKS, NUM_LSU_LANES);
dispatchers_.at((int)FUType::SFU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_SFU_BLOCKS, NUM_SFU_LANES);
#ifdef EXT_TPU_ENABLE
dispatchers_.at((int)FUType::TPU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_VPU_BLOCKS, NUM_VPU_LANES);
#endif
#ifdef EXT_V_ENABLE
dispatchers_.at((int)FUType::VPU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_VPU_BLOCKS, NUM_VPU_LANES);
#endif
@ -145,6 +151,9 @@ Core::Core(const SimContext& ctx,
func_units_.at((int)FUType::FPU) = SimPlatform::instance().create_object<FpuUnit>(this);
func_units_.at((int)FUType::LSU) = SimPlatform::instance().create_object<LsuUnit>(this);
func_units_.at((int)FUType::SFU) = SimPlatform::instance().create_object<SfuUnit>(this);
#ifdef EXT_TPU_ENABLE
func_units_.at((int)FUType::TPU) = SimPlatform::instance().create_object<TpuUnit>(this);
#endif
#ifdef EXT_V_ENABLE
func_units_.at((int)FUType::VPU) = SimPlatform::instance().create_object<VpuUnit>(this);
#endif
@ -341,6 +350,9 @@ void Core::issue() {
default: assert(false);
}
} break;
#ifdef EXT_TPU_ENABLE
case FUType::TPU: ++perf_stats_.scrb_tpu; break;
#endif
#ifdef EXT_V_ENABLE
case FUType::VPU: ++perf_stats_.scrb_vpu; break;
#endif

View file

@ -59,6 +59,9 @@ public:
uint64_t scrb_sfu;
uint64_t scrb_csrs;
uint64_t scrb_wctl;
#ifdef EXT_TPU_ENABLE
uint64_t scrb_tpu;
#endif
#ifdef EXT_V_ENABLE
uint64_t vinstrs;
uint64_t scrb_vpu;
@ -83,6 +86,9 @@ public:
, scrb_sfu(0)
, scrb_csrs(0)
, scrb_wctl(0)
#ifdef EXT_TPU_ENABLE
, scrb_tpu(0)
#endif
#ifdef EXT_V_ENABLE
, vinstrs(0)
, scrb_vpu(0)
@ -155,6 +161,12 @@ public:
return emulator_.dcache_write(data, addr, size);
}
#ifdef EXT_TPU_ENABLE
TensorUnit::Ptr& tensor_unit() {
return tensor_unit_;
}
#endif
#ifdef EXT_V_ENABLE
VecUnit::Ptr& vec_unit() {
return vec_unit_;
@ -182,6 +194,10 @@ private:
Socket* socket_;
const Arch& arch_;
#ifdef EXT_TPU_ENABLE
TensorUnit::Ptr tensor_unit_;
#endif
#ifdef EXT_V_ENABLE
VecUnit::Ptr vec_unit_;
#endif

View file

@ -55,12 +55,12 @@ static const std::unordered_map<Opcode, InstType> sc_instTable = {
};
static const char* op_string(const Instr &instr) {
auto opcode = instr.getOpcode();
auto opcode = instr.getOpcode();
auto funct2 = instr.getFunct2();
auto funct3 = instr.getFunct3();
auto funct7 = instr.getFunct7();
auto rd = instr.getDestReg();
auto rs1 = instr.getSrcReg(1);
auto rd = instr.getDestReg();
auto rs1 = instr.getSrcReg(1);
auto imm = instr.getImm();
switch (opcode) {
@ -386,23 +386,29 @@ static const char* op_string(const Instr &instr) {
default:
std::abort();
}
case 1:
switch (funct3) {
case 0: // gfx reserved
std::abort();
default:
std::abort();
}
#ifdef EXT_TPU_ENABLE
case 2:
switch (funct3) {
case 0: return "HMMA844";
default:
std::abort();
}
#endif
default:
std::abort();
}
case Opcode::EXT2:
switch(funct3) {
case 0: // reserved
case 1: // reserved
case 0: // gfx reserved
case 1: // gfx reserved
std::abort();
case 2:
switch (funct2) {
case 0: return "MMADD.u4_i32";
case 1: return "MMADD.u8_i32";
case 2: return "MMADD.f16_f32";
case 3: return "MMADD.bf16_f32";
default:
std::abort();
}
default:
std::abort();
}
@ -468,7 +474,7 @@ std::ostream &operator<<(std::ostream &os, const Instr &instr) {
}
}
std::shared_ptr<Instr> Emulator::decode(uint32_t code) const {
Instr::Ptr Emulator::decode(uint32_t code) const {
auto instr = std::allocate_shared<Instr>(instr_pool_);
auto op = Opcode((code >> shift_opcode) & mask_opcode);
instr->setOpcode(op);
@ -574,6 +580,31 @@ std::shared_ptr<Instr> Emulator::decode(uint32_t code) const {
std::abort();
}
break;
#ifdef EXT_TPU_ENABLE
case 2: {
switch (funct3) {
case 0: { // HMMA844
uint32_t fmt = rd;
uint32_t steps = rs1 >> 1;
uint32_t step = steps % 4;
uint32_t set = steps / 4;
uint32_t rd_pair = rs1 & 0x1;
uint32_t use_d = rs2;
uint32_t base_rd = (use_d ? 16 : 0) + (step * 2 + rd_pair); // C/D
uint32_t base_rs1 = 8 + set; // A
uint32_t base_rs2 = 24 + set; // B
uint32_t base_rs3 = 0 + step; // C
instr->setImm((fmt << 2) + step); // fmt + step
instr->setDestReg(base_rd, RegType::Float);
instr->addSrcReg(base_rs1, RegType::Float);
instr->addSrcReg(base_rs2, RegType::Float);
instr->addSrcReg(base_rs3, RegType::Float);
} break;
default:
std::abort();
}
} break;
#endif
default:
std::abort();
}

View file

@ -79,6 +79,9 @@ Emulator::Emulator(const Arch &arch, const DCRS &dcrs, Core* core)
, warps_(arch.num_warps(), arch.num_threads())
, barriers_(arch.num_barriers(), 0)
, ipdom_size_(arch.num_threads()-1)
#ifdef EXT_TPU_ENABLE
, tensor_unit_(core->tensor_unit())
#endif
#ifdef EXT_V_ENABLE
, vec_unit_(core->vec_unit())
#endif
@ -162,7 +165,7 @@ instr_trace_t* Emulator::step() {
if (scheduled_warp == -1)
return nullptr;
// suspend warp until decode
// get scheduled warp
auto& warp = warps_.at(scheduled_warp);
assert(warp.tmask.any());

View file

@ -19,6 +19,10 @@
#include <stack>
#include <mem.h>
#include "types.h"
#include "instr.h"
#ifdef EXT_TPU_ENABLE
#include "tensor_unit.h"
#endif
#ifdef EXT_V_ENABLE
#include "vec_unit.h"
#endif
@ -105,7 +109,7 @@ public:
private:
std::shared_ptr<Instr> decode(uint32_t code) const;
Instr::Ptr decode(uint32_t code) const;
void execute(const Instr &instr, uint32_t wid, instr_trace_t *trace);
@ -146,6 +150,10 @@ private:
Word csr_mscratch_;
wspawn_t wspawn_;
#ifdef EXT_TPU_ENABLE
TensorUnit::Ptr tensor_unit_;
#endif
#ifdef EXT_V_ENABLE
VecUnit::Ptr vec_unit_;
#endif

View file

@ -1412,6 +1412,24 @@ void Emulator::execute(const Instr &instr, uint32_t wid, instr_trace_t *trace) {
std::abort();
}
} break;
#ifdef EXT_TPU_ENABLE
case 2: {
switch (funct3) {
case 0: { // HMMA844
trace->fu_type = FUType::TPU;
trace->tpu_type = TpuType::HMMA844;
auto trace_data = std::make_shared<TensorUnit::ExeTraceData>();
trace->data = trace_data;
uint32_t fmt = immsrc >> 2;
uint32_t step = immsrc & 0x3;
tensor_unit_->hmma844(wid, fmt, step, rs1_data, rs2_data, rs3_data, rd_data, trace_data.get());
rd_write = true;
} break;
default:
std::abort();
}
} break;
#endif
default:
std::abort();
}

View file

@ -317,6 +317,25 @@ void SfuUnit::tick() {
///////////////////////////////////////////////////////////////////////////////
#ifdef EXT_TPU_ENABLE
TpuUnit::TpuUnit(const SimContext& ctx, Core* core)
: FuncUnit(ctx, core, "tpu-unit")
{
// bind tensor unit
for (uint32_t iw = 0; iw < ISSUE_WIDTH; ++iw) {
this->Inputs.at(iw).bind(&core_->tensor_unit()->Inputs.at(iw));
core_->tensor_unit()->Outputs.at(iw).bind(&this->Outputs.at(iw));
}
}
void TpuUnit::tick() {
// use tensor_unit
}
#endif
///////////////////////////////////////////////////////////////////////////////
#ifdef EXT_V_ENABLE
VpuUnit::VpuUnit(const SimContext& ctx, Core* core)

View file

@ -110,6 +110,19 @@ public:
///////////////////////////////////////////////////////////////////////////////
#ifdef EXT_TPU_ENABLE
class TpuUnit : public FuncUnit {
public:
TpuUnit(const SimContext& ctx, Core*);
void tick() override;
};
#endif
///////////////////////////////////////////////////////////////////////////////
#ifdef EXT_V_ENABLE
class VpuUnit : public FuncUnit {

View file

@ -125,6 +125,8 @@ enum VectorAttrMask {
class Instr {
public:
using Ptr = std::shared_ptr<Instr>;
Instr()
: opcode_(Opcode::NONE)
, num_rsrcs_(0)

View file

@ -72,6 +72,9 @@ public:
AluType alu_type;
FpuType fpu_type;
SfuType sfu_type;
#ifdef EXT_TPU_ENABLE
TpuType tpu_type;
#endif
#ifdef EXT_V_ENABLE
VpuType vpu_type;
#endif

View file

@ -13,40 +13,20 @@
// limitations under the License.
#include "tensor_unit.h"
#include "core.h"
using namespace vortex;
template <typename T>
class FMAD : public SimObject<FMAD<T>> {
public:
SimPort<T> Input;
SimPort<T> Output;
FMAD(const SimContext &ctx, const char* name)
: SimObject<FMAD<T>>(ctx, name)
, Input(this)
, Output(this)
{}
virtual ~FMAD() {}
void reset() {
//--
}
void tick() {
//--
}
};
class TensorUnit::Impl {
public:
Impl(TensorUnit* simobject, const Config& config, Core* core)
Impl(TensorUnit* simobject, const Arch& arch, Core* core)
: simobject_(simobject)
, config_(config)
, core_(core)
, arch_(arch)
, perf_stats_()
{}
{
//--
}
~Impl() {
// Destructor logic if needed
@ -57,7 +37,48 @@ public:
}
void tick() {
// Implement the tick logic here
for (uint32_t iw = 0; iw < ISSUE_WIDTH; ++iw) {
auto& input = simobject_->Inputs.at(iw);
if (input.empty())
return;
auto trace = input.front();
int delay = 0;
switch (trace->tpu_type) {
case TpuType::HMMA844:
delay = 4;
break;
default:
std::abort();
}
simobject_->Outputs.at(iw).push(trace, 2 + delay);
DT(3, simobject_->name() << ": op=" << trace->tpu_type << ", " << *trace);
input.pop();
}
}
void hmma844(uint32_t wid,
uint32_t fmt, uint32_t step,
const std::vector<reg_data_t>& rs1_data,
const std::vector<reg_data_t>& rs2_data,
const std::vector<reg_data_t>& rs3_data,
std::vector<reg_data_t>& rd_data,
ExeTraceData* trace_data) {
uint32_t num_octets = arch_.num_threads() / 8;
uint32_t threadgroup_lane_offset = 4 * num_octets;
for (uint32_t i = 0; i < num_octets; ++i) {
std::vector<reg_data_t> octet_A(8);
std::vector<reg_data_t> octet_B(8);
std::vector<reg_data_t> octet_C(8);
std::vector<reg_data_t> octet_D(8);
for (uint32_t j = 0; j < 8; ++j) {
octet_A[j] = rs1_data[i * 8 + j];
octet_B[j] = rs2_data[i * 8 + j];
octet_C[j] = rs3_data[i * 8 + j];
octet_D[j] = rd_data[i * 8 + j];
}
}
}
const PerfStats& perf_stats() const {
@ -66,18 +87,18 @@ public:
private:
TensorUnit* simobject_;
Config config_;
Core* core_;
Arch arch_;
PerfStats perf_stats_;
};
///////////////////////////////////////////////////////////////////////////////
TensorUnit::TensorUnit(const SimContext &ctx, const char* name, const Config& config, Core* core)
TensorUnit::TensorUnit(const SimContext &ctx, const char* name, const Arch& arch, Core* core)
: SimObject<TensorUnit>(ctx, name)
, Inputs(config.num_ports, this)
, Outputs(config.num_ports, this)
, impl_(new Impl(this, config, core))
, Inputs(ISSUE_WIDTH, this)
, Outputs(ISSUE_WIDTH, this)
, impl_(new Impl(this, arch, core))
{}
TensorUnit::~TensorUnit() {
@ -94,4 +115,14 @@ void TensorUnit::tick() {
const TensorUnit::PerfStats &TensorUnit::perf_stats() const {
return impl_->perf_stats();
}
void TensorUnit::hmma844(uint32_t wid,
uint32_t fmt, uint32_t step,
const std::vector<reg_data_t>& rs1_data,
const std::vector<reg_data_t>& rs2_data,
const std::vector<reg_data_t>& rs3_data,
std::vector<reg_data_t>& rd_data,
ExeTraceData* trace_data) {
impl_->hmma844(wid, fmt, step, rs1_data, rs2_data, rs3_data, rd_data, trace_data);
}

View file

@ -22,15 +22,10 @@ class Core;
class TensorUnit : public SimObject<TensorUnit> {
public:
struct Config {
uint8_t num_ports;
uint8_t mac_latency;
Config()
: num_ports(0)
, mac_latency(0)
{}
};
struct ExeTraceData : public ITraceData {
using Ptr = std::shared_ptr<ExeTraceData>;
};
struct PerfStats {
uint64_t latency;
@ -48,14 +43,21 @@ public:
std::vector<SimPort<instr_trace_t*>> Inputs;
std::vector<SimPort<instr_trace_t*>> Outputs;
TensorUnit(const SimContext &ctx, const char* name, const Config& config, Core* core);
TensorUnit(const SimContext &ctx, const char* name, const Arch& arch, Core* core);
virtual ~TensorUnit();
virtual void reset();
virtual void tick();
void hmma844(uint32_t wid,
uint32_t fmt, uint32_t step,
const std::vector<reg_data_t>& rs1_data,
const std::vector<reg_data_t>& rs2_data,
const std::vector<reg_data_t>& rs3_data,
std::vector<reg_data_t>& rd_data,
ExeTraceData* trace_data);
const PerfStats& perf_stats() const;
private:

View file

@ -127,6 +127,9 @@ enum class FUType {
LSU,
FPU,
SFU,
#ifdef EXT_TPU_ENABLE
TPU,
#endif
#ifdef EXT_V_ENABLE
VPU,
#endif
@ -139,6 +142,9 @@ inline std::ostream &operator<<(std::ostream &os, const FUType& type) {
case FUType::LSU: os << "LSU"; break;
case FUType::FPU: os << "FPU"; break;
case FUType::SFU: os << "SFU"; break;
#ifdef EXT_TPU_ENABLE
case FUType::TPU: os << "TPU"; break;
#endif
#ifdef EXT_V_ENABLE
case FUType::VPU: os << "VPU"; break;
#endif
@ -286,6 +292,20 @@ inline std::ostream &operator<<(std::ostream &os, const SfuType& type) {
///////////////////////////////////////////////////////////////////////////////
enum class TpuType {
HMMA844 = 0,
};
inline std::ostream &operator<<(std::ostream &os, const TpuType& type) {
switch (type) {
case TpuType::HMMA844: os << "HMMA844"; break;
default: assert(false);
}
return os;
}
///////////////////////////////////////////////////////////////////////////////
enum class VpuType {
VSET = 0,

View file

@ -96,10 +96,6 @@ public:
DT(3, simobject_->name() << ": op=" << trace->vpu_type << ", " << *trace);
if (trace->eop && trace->fetch_stall) {
core_->resume(trace->wid);
}
input.pop();
}
}

View file

@ -56,11 +56,7 @@ public:
std::vector<SimPort<instr_trace_t*>> Inputs;
std::vector<SimPort<instr_trace_t*>> Outputs;
VecUnit(const SimContext& ctx,
const char* name,
const Arch& arch,
Core* core);
VecUnit(const SimContext& ctx, const char* name, const Arch& arch, Core* core);
~VecUnit();
void reset();

View file

@ -5,7 +5,7 @@
#include <hfloats.h>
#ifndef I_TYPE
#define I_TYPE vortex::half_t
#define I_TYPE float
#endif
#ifndef O_TYPE

View file

@ -1,46 +1,46 @@
#include "common.h"
#include <vx_spawn.h>
#include <vx_tensor.h>
#include "common.h"
void kernel_body(kernel_arg_t* __UNIFORM__ arg) {
auto A = reinterpret_cast<I_TYPE*>(arg->A_addr);
auto B = reinterpret_cast<I_TYPE*>(arg->B_addr);
auto C = reinterpret_cast<O_TYPE*>(arg->C_addr);
void kernel_body(kernel_arg_t *__UNIFORM__ arg) {
auto A = reinterpret_cast<I_TYPE *>(arg->A_addr);
auto B = reinterpret_cast<I_TYPE *>(arg->B_addr);
auto C = reinterpret_cast<O_TYPE *>(arg->C_addr);
tensor::fragment<tensor::half_t, tensor::row_major> fragA;
tensor::fragment<tensor::half_t, tensor::row_major> fragB;
tensor::fragment<float, tensor::row_major> fragC;
tensor::fragment<tensor::matrix_a, I_TYPE, tensor::row_major> fragA;
tensor::fragment<tensor::matrix_b, I_TYPE, tensor::col_major> fragB;
tensor::fragment<tensor::matrix_c, O_TYPE, tensor::row_major> fragC;
// calculate tile row & column based on block index
uint32_t tile_row = blockIdx.y * arg->tileM;
uint32_t tile_col = blockIdx.x * arg->tileN;
// calculate tile row & column based on block index
uint32_t tile_row = blockIdx.y * arg->tileM;
uint32_t tile_col = blockIdx.x * arg->tileN;
uint32_t N = arg->N;
uint32_t K = arg->K;
uint32_t tileK = arg->tileK;
uint32_t N = arg->N;
uint32_t K = arg->K;
uint32_t tileK = arg->tileK;
// Initialize accumulator tile to zero
tensor::fill_fragment(fragC, 0.0f);
// Initialize accumulator tile to zero
tensor::fill_fragment(fragC, 0);
for (int i = 0; i < K; i += tileK) {
// Load A tile
auto tileA = A + (tile_row * K + i);
tensor::load_matrix_sync(fragA, tileA, K);
for (int i = 0; i < K; i += tileK) {
// Load A tile
auto tileA = A + (tile_row * K + i);
tensor::load_matrix_sync<tensor::row_major>(fragA, tileA, K);
// Load B tile
auto tileB = B + (i * k + tile_col);
tensor::load_matrix_sync(fragB, tileB, K);
// Load B tile
auto tileB = B + (i * K + tile_col);
tensor::load_matrix_sync<tensor::row_major>(fragB, tileB, K);
// Matrix multiply-accumulate: c += a * b
tensor::mma_sync(fragC, fragA, fragB, fragC);
}
// Matrix multiply-accumulate: c += a * b
tensor::mma_sync(fragC, fragA, fragB, fragC);
}
// Store the computed C tile
auto tileC = C + (tile_row * N + tile_col);
tensor::store_matrix_sync(tileC, fragC, N, tensor::mem_row_major);
// Store the computed C tile
auto tileC = C + (tile_row * N + tile_col);
tensor::store_matrix_sync<tensor::row_major>(tileC, fragC, N);
}
int main() {
kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH);
return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg);
kernel_arg_t *arg = (kernel_arg_t *)csr_read(VX_CSR_MSCRATCH);
return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg);
}

View file

@ -192,11 +192,15 @@ int main(int argc, char *argv[]) {
uint64_t NT;
RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_THREADS, &NT));
std::cout << "GPU warp size: " << NT << std::endl;
if (NT < 4) {
std::cout << "Error: warp size must be at least 4 threads!" << std::endl;
return -1;
}
std::cout << "GPU warp size: " << NT << " threads" << std::endl;
uint64_t isa_flags;
RT_CHECK(vx_dev_caps(device, VX_CAPS_ISA_FLAGS, &isa_flags));
uint32_t XlenB = 4 * VX_ISA_ARCH(isa_flags);
uint32_t XlenB = VX_ISA_ARCH(isa_flags) / 8;
std::cout << "GPU XLEN: " << 8 * XlenB << std::endl;
// tile format ratio
@ -206,22 +210,22 @@ int main(int argc, char *argv[]) {
// determine tensor tile size
uint32_t logNT = log2(NT);
uint32_t tileM = 4 * (1 << (logNT / 2)) * o_ratio;
uint32_t tileN = (logNT % 2 == 0) ? tileM / 2 : tileN;
uint32_t tileN = (logNT % 2 == 0) ? (tileM / 2) : tileM;
uint32_t tileK = std::min(tileM, tileN) * i_ratio;
std::cout << "GPU tensor tileM=" << tileM << ", tileN=" << tileM << ", tileK=" << tileM << std::endl;
std::cout << "GPU tensor tileM=" << tileM << ", tileN=" << tileM << ", tileK=" << tileK << std::endl;
if ((M & (tileM - 1)) != 0) {
if ((M % tileM) != 0) {
std::cout << "Error: M must be a multiple of tensor tileM!" << std::endl;
return -1;
}
if ((N & (tileN - 1)) != 0) {
if ((N % tileN) != 0) {
std::cout << "Error: M must be a multiple of tensor tileN!" << std::endl;
return -1;
}
if ((K & (tileK - 1)) != 0) {
if ((K % tileK) != 0) {
std::cout << "Error: M must be a multiple of tensor tileK!" << std::endl;
return -1;
}
@ -234,8 +238,8 @@ int main(int argc, char *argv[]) {
size_t sizeB = K * N;
size_t sizeC = M * N;
std::cout << "input data type: " << Comparator<I_TYPE>::type_str() << std::endl;
std::cout << "output data type: " << Comparator<O_TYPE>::type_str() << std::endl;
std::cout << "input data type: " << Comparator<I_TYPE>::type_str() << " (" << sizeof(I_TYPE) << " bytes)" << std::endl;
std::cout << "output data type: " << Comparator<O_TYPE>::type_str() << " (" << sizeof(O_TYPE) << " bytes)" << std::endl;
std::cout << "matrix A: " << M << "x" << K << std::endl;
std::cout << "matrix B: " << K << "x" << N << std::endl;
std::cout << "matrix C: " << M << "x" << N << std::endl;