mirror of https://github.com/vortexgpgpu/vortex.git
synced 2025-06-28 09:37:38 -04:00

simx tensor 844

This commit is contained in:
parent 84a4ede9c9
commit 06ad9197c3

22 changed files with 802 additions and 135 deletions
@@ -36,6 +36,25 @@ extern "C" {

#define RISCV_CUSTOM2 0x5B
#define RISCV_CUSTOM3 0x7B

#define RISCV_INSN_R(opcode7, funct3, funct7, rd, rs1, rs2) ( \
  ((funct7 & 0x7F) << 25) | \
  ((rs2 & 0x1F) << 20) | \
  ((rs1 & 0x1F) << 15) | \
  ((funct3 & 0x7) << 12) | \
  ((rd & 0x1F) << 7) | \
  (opcode7 & 0x7F) \
)

#define RISCV_INSN_R4(opcode7, funct3, funct2, rd, rs1, rs2, rs3) ( \
  ((rs3 & 0x1F) << 27) | \
  ((funct2 & 0x3) << 25) | \
  ((rs2 & 0x1F) << 20) | \
  ((rs1 & 0x1F) << 15) | \
  ((funct3 & 0x7) << 12) | \
  ((rd & 0x1F) << 7) | \
  (opcode7 & 0x7F) \
)

#define csr_read(csr) ({ \
  size_t __r; \
  __asm__ __volatile__ ("csrr %0, %1" : "=r" (__r) : "i" (csr) : "memory"); \
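As a quick illustration of how the encoder macro composes its fields, the snippet below (not part of the commit; the field values are made up, and RISCV_CUSTOM0 = 0x0B is assumed from the standard RISC-V custom-0 opcode, since this hunk only shows CUSTOM2/CUSTOM3) builds the 32-bit word for an R-type instruction of the same shape the HMMA844 sequences below emit via `.word`:

  #include <cstdint>
  #include <cstdio>

  #define RISCV_CUSTOM0 0x0B // assumption: standard custom-0 opcode

  int main() {
    // funct3=0, funct7=2, rd=4, rs1=8, rs2=1
    uint32_t word = RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, 4, 8, 1);
    std::printf("0x%08x\n", word); // fields packed as funct7|rs2|rs1|funct3|rd|opcode
  }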
@@ -221,6 +240,233 @@ inline void vx_fence() {
  __asm__ volatile ("fence iorw, iorw");
}

typedef float mf32x8_t __attribute__((vector_size(8*4))); // 8 x f32 registers

#define MAKE_VX_WSETM_F32(f0, f1, f2, f3, f4, f5, f6, f7) \
  mf32x8_t ret; \
  register float fd0 __asm__(f0); \
  register float fd1 __asm__(f1); \
  register float fd2 __asm__(f2); \
  register float fd3 __asm__(f3); \
  register float fd4 __asm__(f4); \
  register float fd5 __asm__(f5); \
  register float fd6 __asm__(f6); \
  register float fd7 __asm__(f7); \
  __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd0) : "r"(value)); \
  __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd1) : "r"(value)); \
  __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd2) : "r"(value)); \
  __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd3) : "r"(value)); \
  __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd4) : "r"(value)); \
  __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd5) : "r"(value)); \
  __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd6) : "r"(value)); \
  __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd7) : "r"(value)); \
  ret = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; \
  return ret

__attribute__((always_inline)) mf32x8_t vx_wsetm_a_f32(size_t value) {
  MAKE_VX_WSETM_F32("f8", "f9", "f10", "f11", "f12", "f13", "f14", "f15");
}

__attribute__((always_inline)) mf32x8_t vx_wsetm_b_f32(size_t value) {
  MAKE_VX_WSETM_F32("f24", "f25", "f26", "f27", "f28", "f29", "f30", "f31");
}

__attribute__((always_inline)) mf32x8_t vx_wsetm_c_f32(size_t value) {
  MAKE_VX_WSETM_F32("f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7");
}

__attribute__((always_inline)) mf32x8_t vx_wsetm_d_f32(size_t value) {
  MAKE_VX_WSETM_F32("f16", "f17", "f18", "f19", "f20", "f21", "f22", "f23");
}

#define MAKE_VX_WLDM_D_F32(f0, f1, f2, f3, f4, f5, f6, f7) \
  mf32x8_t ret; \
  auto base = (const float*)src + row * ldm; \
  register float fd0 __asm__(f0); \
  register float fd1 __asm__(f1); \
  register float fd2 __asm__(f2); \
  register float fd3 __asm__(f3); \
  register float fd4 __asm__(f4); \
  register float fd5 __asm__(f5); \
  register float fd6 __asm__(f6); \
  register float fd7 __asm__(f7); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd0) : "m"(base[0])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd1) : "m"(base[1])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd2) : "m"(base[2])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd3) : "m"(base[3])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd4) : "m"(base[4])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd5) : "m"(base[5])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd6) : "m"(base[6])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd7) : "m"(base[7])); \
  ret = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; \
  return ret

#define MAKE_VX_WLDM_T_F32(f0, f1, f2, f3, f4, f5, f6, f7) \
  mf32x8_t ret; \
  auto base = (const float*)src + col; \
  register float fd0 __asm__(f0); \
  register float fd1 __asm__(f1); \
  register float fd2 __asm__(f2); \
  register float fd3 __asm__(f3); \
  register float fd4 __asm__(f4); \
  register float fd5 __asm__(f5); \
  register float fd6 __asm__(f6); \
  register float fd7 __asm__(f7); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd0) : "m"(base[0 * ldm])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd1) : "m"(base[1 * ldm])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd2) : "m"(base[2 * ldm])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd3) : "m"(base[3 * ldm])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd4) : "m"(base[4 * ldm])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd5) : "m"(base[5 * ldm])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd6) : "m"(base[6 * ldm])); \
  __asm__ volatile ("flw %0, %1" : "=f"(fd7) : "m"(base[7 * ldm])); \
  ret = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; \
  return ret

__attribute__((always_inline)) mf32x8_t vx_wldm_ad_f32(const void* src, int row, size_t ldm) {
  MAKE_VX_WLDM_D_F32("f8", "f9", "f10", "f11", "f12", "f13", "f14", "f15");
}

__attribute__((always_inline)) mf32x8_t vx_wldm_at_f32(const void* src, int col, size_t ldm) {
  MAKE_VX_WLDM_T_F32("f8", "f9", "f10", "f11", "f12", "f13", "f14", "f15");
}

__attribute__((always_inline)) mf32x8_t vx_wldm_bd_f32(const void* src, int row, size_t ldm) {
  MAKE_VX_WLDM_D_F32("f24", "f25", "f26", "f27", "f28", "f29", "f30", "f31");
}

__attribute__((always_inline)) mf32x8_t vx_wldm_bt_f32(const void* src, int col, size_t ldm) {
  MAKE_VX_WLDM_T_F32("f24", "f25", "f26", "f27", "f28", "f29", "f30", "f31");
}

__attribute__((always_inline)) void vx_wstm_f32(void* dst, const mf32x8_t& src, int row, int col, size_t ldm) {
  mf32x8_t ret;
  auto base = (float*)dst + row * ldm + col;
  auto base_2row = base + 2 * ldm;
  __asm__ volatile("fsw %0, %1" :: "f"(src[0]), "m"(base[0]));
  __asm__ volatile("fsw %0, %1" :: "f"(src[1]), "m"(base[1]));
  __asm__ volatile("fsw %0, %1" :: "f"(src[2]), "m"(base_2row[0]));
  __asm__ volatile("fsw %0, %1" :: "f"(src[3]), "m"(base_2row[1]));
  __asm__ volatile("fsw %0, %1" :: "f"(src[4]), "m"(base[4]));
  __asm__ volatile("fsw %0, %1" :: "f"(src[5]), "m"(base[5]));
  __asm__ volatile("fsw %0, %1" :: "f"(src[6]), "m"(base_2row[4]));
  __asm__ volatile("fsw %0, %1" :: "f"(src[7]), "m"(base_2row[5]));
}

#define MAKE_VX_HMMA_844_D_F32_STEP(fmt, step, rd_lo, rd_hi, rs1, rs2, rs3_lo, rs3_hi) \
  __asm__ volatile (".word %1" : "=r"(rd_lo) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 0, 1)), "r"(rs1), "r"(rs2), "r"(rs3_lo)); \
  __asm__ volatile (".word %1" : "=r"(rd_hi) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 1, 1)), "r"(rs1), "r"(rs2), "r"(rs3_hi))

#define MAKE_VX_HMMA_844_D_F32(fmt) \
  mf32x8_t ret; \
  register float fd0 __asm__("f16"); \
  register float fd1 __asm__("f17"); \
  register float fd2 __asm__("f18"); \
  register float fd3 __asm__("f19"); \
  register float fd4 __asm__("f20"); \
  register float fd5 __asm__("f21"); \
  register float fd6 __asm__("f22"); \
  register float fd7 __asm__("f23"); \
  register float fa0 __asm__("f8")  = a[0]; \
  register float fa1 __asm__("f9")  = a[1]; \
  register float fa2 __asm__("f10") = a[2]; \
  register float fa3 __asm__("f11") = a[3]; \
  register float fa4 __asm__("f12") = a[4]; \
  register float fa5 __asm__("f13") = a[5]; \
  register float fa6 __asm__("f14") = a[6]; \
  register float fa7 __asm__("f15") = a[7]; \
  register float fb0 __asm__("f24") = b[0]; \
  register float fb1 __asm__("f25") = b[1]; \
  register float fb2 __asm__("f26") = b[2]; \
  register float fb3 __asm__("f27") = b[3]; \
  register float fb4 __asm__("f28") = b[4]; \
  register float fb5 __asm__("f29") = b[5]; \
  register float fb6 __asm__("f30") = b[6]; \
  register float fb7 __asm__("f31") = b[7]; \
  register float fc0 __asm__("f0") = c[0]; \
  register float fc1 __asm__("f1") = c[1]; \
  register float fc2 __asm__("f2") = c[2]; \
  register float fc3 __asm__("f3") = c[3]; \
  register float fc4 __asm__("f4") = c[4]; \
  register float fc5 __asm__("f5") = c[5]; \
  register float fc6 __asm__("f6") = c[6]; \
  register float fc7 __asm__("f7") = c[7]; \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 0,  fd0, fd1, fa0, fb0, fc0, fc1); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 1,  fd2, fd3, fa0, fb0, fc2, fc3); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 2,  fd4, fd5, fa0, fb0, fc4, fc5); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 3,  fd6, fd7, fa0, fb0, fc6, fc7); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 4,  fd0, fd1, fa1, fb1, fc0, fc1); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 5,  fd2, fd3, fa1, fb1, fc2, fc3); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 6,  fd4, fd5, fa1, fb1, fc4, fc5); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 7,  fd6, fd7, fa1, fb1, fc6, fc7); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 8,  fd0, fd1, fa2, fb2, fc0, fc1); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 9,  fd2, fd3, fa2, fb2, fc2, fc3); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 10, fd4, fd5, fa2, fb2, fc4, fc5); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 11, fd6, fd7, fa2, fb2, fc6, fc7); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 12, fd0, fd1, fa3, fb3, fc0, fc1); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 13, fd2, fd3, fa3, fb3, fc2, fc3); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 14, fd4, fd5, fa3, fb3, fc4, fc5); \
  MAKE_VX_HMMA_844_D_F32_STEP(fmt, 15, fd6, fd7, fa3, fb3, fc6, fc7); \
  ret = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; \
  return ret

#define MAKE_VX_HMMA_844_C_F32_STEP(fmt, step, rd_lo, rd_hi, rs1, rs2) \
  __asm__ volatile (".word %1" : "=r"(rd_lo) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 0, 0)), "r"(rs1), "r"(rs2), "r"(rd_lo)); \
  __asm__ volatile (".word %1" : "=r"(rd_hi) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 1, 0)), "r"(rs1), "r"(rs2), "r"(rd_hi))

#define MAKE_VX_HMMA_844_C_F32(fmt) \
  mf32x8_t ret; \
  register float fa0 __asm__("f8")  = a[0]; \
  register float fa1 __asm__("f9")  = a[1]; \
  register float fa2 __asm__("f10") = a[2]; \
  register float fa3 __asm__("f11") = a[3]; \
  register float fa4 __asm__("f12") = a[4]; \
  register float fa5 __asm__("f13") = a[5]; \
  register float fa6 __asm__("f14") = a[6]; \
  register float fa7 __asm__("f15") = a[7]; \
  register float fb0 __asm__("f24") = b[0]; \
  register float fb1 __asm__("f25") = b[1]; \
  register float fb2 __asm__("f26") = b[2]; \
  register float fb3 __asm__("f27") = b[3]; \
  register float fb4 __asm__("f28") = b[4]; \
  register float fb5 __asm__("f29") = b[5]; \
  register float fb6 __asm__("f30") = b[6]; \
  register float fb7 __asm__("f31") = b[7]; \
  register float fc0 __asm__("f0") = c[0]; \
  register float fc1 __asm__("f1") = c[1]; \
  register float fc2 __asm__("f2") = c[2]; \
  register float fc3 __asm__("f3") = c[3]; \
  register float fc4 __asm__("f4") = c[4]; \
  register float fc5 __asm__("f5") = c[5]; \
  register float fc6 __asm__("f6") = c[6]; \
  register float fc7 __asm__("f7") = c[7]; \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 0,  fc0, fc1, fa0, fb0); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 1,  fc2, fc3, fa0, fb0); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 2,  fc4, fc5, fa0, fb0); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 3,  fc6, fc7, fa0, fb0); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 4,  fc0, fc1, fa1, fb1); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 5,  fc2, fc3, fa1, fb1); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 6,  fc4, fc5, fa1, fb1); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 7,  fc6, fc7, fa1, fb1); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 8,  fc0, fc1, fa2, fb2); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 9,  fc2, fc3, fa2, fb2); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 10, fc4, fc5, fa2, fb2); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 11, fc6, fc7, fa2, fb2); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 12, fc0, fc1, fa3, fb3); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 13, fc2, fc3, fa3, fb3); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 14, fc4, fc5, fa3, fb3); \
  MAKE_VX_HMMA_844_C_F32_STEP(fmt, 15, fc6, fc7, fa3, fb3); \
  ret = {fc0, fc1, fc2, fc3, fc4, fc5, fc6, fc7}; \
  return ret

__attribute__((always_inline)) mf32x8_t vx_hmma_844_c_f16_f32(const mf32x8_t& a, const mf32x8_t& b, const mf32x8_t& c) {
  MAKE_VX_HMMA_844_C_F32(0);
}

__attribute__((always_inline)) mf32x8_t vx_hmma_844_d_f16_f32(const mf32x8_t& a, const mf32x8_t& b, const mf32x8_t& c) {
  MAKE_VX_HMMA_844_D_F32(0);
}

#ifdef __cplusplus
}
#endif
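Taken together, these intrinsics expose the warp's FP register file as four 8-register banks: A in f8-f15, B in f24-f31, C in f0-f7 and D in f16-f23. A minimal per-thread sketch of one k-slice of D = A x B + C (illustrative only; the pointer, row/col and stride arguments would come from the surrounding kernel):

  #include <vx_intrinsics.h>

  inline mf32x8_t mma_slice(const void* a_tile, const void* b_tile, int row, int col, size_t ldm) {
    mf32x8_t a = vx_wldm_ad_f32(a_tile, row, ldm); // fill the A bank in direct (row-major) order
    mf32x8_t b = vx_wldm_bt_f32(b_tile, col, ldm); // fill the B bank in transposed (strided) order
    mf32x8_t c = vx_wsetm_c_f32(0);                // clear the C accumulator bank
    return vx_hmma_844_d_f16_f32(a, b, c);         // run the 16-step HMMA sequence into the D bank
  }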
@@ -16,48 +16,181 @@

#include <stdint.h>
#include <vx_intrinsics.h>
#include <type_traits>
#include <hfloats.h>

-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#ifdef __cplusplus
-}
+#ifndef NUM_LANES
+#define NUM_LANES 8
#endif

namespace tensor {

-enum frag_layout_t { row_major, col_major };
-enum mem_layout_t { mem_row_major, mem_col_major };
+enum frag_use_t { matrix_d, matrix_a, matrix_b, matrix_c };
+enum layout_t { row_major, col_major };

-template <typename T, frag_layout_t L>
+template <frag_use_t U, typename T, layout_t L>
struct fragment {
-  typedef T DType;
-  static const frag_layout_t Layout = L;
-  typedef T VType __attribute__((vector_size(8 * sizeof(void*))));
-  VType data;
+  typedef T Type;
+  static const frag_use_t Use = U;
+  static const layout_t Layout = L;
+  mf32x8_t data;
};

-template <typename Frag>
-void fill_fragment(Frag &frag, size_t value) {
-  // empty skeleton
__attribute__((always_inline)) void map_operand_ab_32lanes(int tid, int &row, int &col) {
  int tg = tid / 4;

  // A (row major)
  // Figure 7(a) in paper
  // row  0~ 3: threadgroups 0 and 2
  // row  4~ 7: threadgroups 4 and 6
  // row  8~11: threadgroups 1 and 3
  // row 12~15: threadgroups 5 and 7
  row = tid % 4;
  row += (tg * 8) % 16;
  row += (tg / 4) * 4;

  // B (column major)
  // NOTE: the Matrix B mapping in Figure 7(a) is incorrect; below is the
  // corrected mapping:
  // col  0~ 3: threadgroups 0 and 1
  // col  4~ 7: threadgroups 4 and 5
  // col  8~11: threadgroups 2 and 3
  // col 12~15: threadgroups 6 and 7
  col = tid % 4;
  col += ((tg % 4) / 2) * 8;
  col += (tg / 4) * 4;
}

__attribute__((always_inline)) void map_operand_ab_8lanes(int tid, int &row, int &col) {
  int tg = tid / 4;

  // A (row major)
  // row 0~3: threadgroup 0
  // row 4~7: threadgroup 1
  row = tid % 4;
  row += tg * 4;

  // B (column major)
  // col 0~3: threadgroup 0
  // col 4~7: threadgroup 1
  col = tid % 4;
  col += tg * 4;
}

__attribute__((always_inline)) void map_operand_c_32lanes(int tid, int &row, int &col) {
  int tg = tid / 4;

  // Figure 7(b), left
  col = ((tg % 4) / 2) * 8;
  row = (tg * 8) % 16;
  row += (tg / 4) * 4;

  // Figure 7(b), right
  row += (tid % 4) % 2;
  col += ((tid % 4) / 2) * 2;
}

__attribute__((always_inline)) void map_operand_c_8lanes(int tid, int &row, int &col) {
  int tg = tid / 4;

  // Figure 7(b), left
  col = 0;
  row = tg * 4;

  // Figure 7(b), right
  row += (tid % 4) % 2;
  col += ((tid % 4) / 2) * 2;
}

__attribute__((always_inline)) void map_operand_ab(int tid, int &row, int &col) {
  if constexpr (NUM_LANES == 32) {
    map_operand_ab_32lanes(tid, row, col);
  } else if constexpr (NUM_LANES == 8) {
    map_operand_ab_8lanes(tid, row, col);
  } else {
    static_assert(NUM_LANES == 32 || NUM_LANES == 8, "NUM_LANES must be 8 or 32");
  }
}

__attribute__((always_inline)) void map_operand_c(int tid, int &row, int &col) {
  if constexpr (NUM_LANES == 32) {
    map_operand_c_32lanes(tid, row, col);
  } else if constexpr (NUM_LANES == 8) {
    map_operand_c_8lanes(tid, row, col);
  } else {
    static_assert(NUM_LANES == 32 || NUM_LANES == 8, "NUM_LANES must be 8 or 32");
  }
}

template <typename Frag>
-void load_matrix_sync(Frag &frag, const void *ptr, size_t ld) {
-  // empty skeleton
+__attribute__((always_inline)) void fill_fragment(Frag &dst, size_t value) {
  if constexpr (Frag::Use == matrix_d) {
    dst.data = vx_wsetm_d_f32(value);
  } else if constexpr (Frag::Use == matrix_a) {
    dst.data = vx_wsetm_a_f32(value);
  } else if constexpr (Frag::Use == matrix_b) {
    dst.data = vx_wsetm_b_f32(value);
  } else if constexpr (Frag::Use == matrix_c) {
    dst.data = vx_wsetm_c_f32(value);
  }
}

template <layout_t mem_layout, typename Frag>
__attribute__((always_inline)) void load_matrix_sync(Frag &dst, const void *src, size_t ldm) {
  int row, col;
  int tid = vx_thread_id();
  map_operand_ab(tid, row, col);
  if constexpr (Frag::Use == matrix_a) {
    if constexpr (Frag::Layout == mem_layout) {
      dst.data = vx_wldm_ad_f32(src, row, ldm);
    } else {
      dst.data = vx_wldm_at_f32(src, col, ldm);
    }
  } else if constexpr (Frag::Use == matrix_b) {
    if constexpr (Frag::Layout == mem_layout) {
      dst.data = vx_wldm_bd_f32(src, row, ldm);
    } else {
      dst.data = vx_wldm_bt_f32(src, col, ldm);
    }
  } else {
    static_assert(false, "Only matrix_a and matrix_b are supported!");
  }
}

template <layout_t mem_layout, typename Frag>
__attribute__((always_inline)) void store_matrix_sync(void *dst, const Frag &src, size_t ldm) {
  static_assert(Frag::Layout == mem_layout, "fragment layout should match memory!");
  int row, col;
  int tid = vx_thread_id();
  map_operand_c(tid, row, col);
  if constexpr (Frag::Use == matrix_c) {
    vx_wstm_f32(dst, src.data, row, col, ldm);
  } else if constexpr (Frag::Use == matrix_d) {
    vx_wstm_f32(dst, src.data, row, col, ldm);
  } else {
    static_assert(false, "Only matrix_c or matrix_d are supported!");
  }
}

// Perform the matrix multiply-accumulate: D = A * B + C
template <typename FragD, typename FragA, typename FragB, typename FragC>
-void mma_sync(FragD &D, const FragA &A, const FragB &B, const FragC &C) {
-  // empty skeleton
-}
+__attribute__((always_inline)) void mma_sync(FragD &D, const FragA &A, const FragB &B, const FragC &C) {
  static_assert(FragA::Use == matrix_a, "A must be matrix_a");
  static_assert(FragB::Use == matrix_b, "B must be matrix_b");
  static_assert(FragC::Use == matrix_c, "C must be matrix_c");
  static_assert(FragD::Use == matrix_d || FragD::Use == matrix_c, "D must be matrix_d or matrix_c");
  static_assert(std::is_same_v<typename FragA::Type, typename FragB::Type>, "A and B must have the same type");
  static_assert(std::is_same_v<typename FragC::Type, typename FragD::Type>, "C and D must have the same type");

-  // Store a fragment result back to global memory
-  template <typename Type, typename Frag>
-  void store_matrix_sync(void *ptr, const Frag &frag, size_t ld, mem_layout_t layout) {
-    // empty skeleton
  if constexpr (std::is_same_v<typename FragC::Type, float>
             && std::is_same_v<typename FragA::Type, vortex::half_t>) {
    if constexpr (FragD::Use == matrix_d) {
      D.data = vx_hmma_844_d_f16_f32(A.data, B.data, C.data);
    } else {
      D.data = vx_hmma_844_c_f16_f32(A.data, B.data, C.data);
    }
  } else {
    static_assert(false, "Unsupported type!");
  }
}

} // namespace tensor
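The mapping helpers encode which matrix element each hardware lane owns. For the 8-lane configuration the pattern is easy to check on the host with a standalone copy of the arithmetic (not the device code itself):

  #include <cstdio>

  int main() {
    for (int tid = 0; tid < 8; ++tid) {
      int tg  = tid / 4;
      int row = tid % 4 + tg * 4; // same arithmetic as map_operand_ab_8lanes
      std::printf("tid %d -> A row %d / B col %d\n", tid, row, row);
    }
    // prints rows/cols 0..3 for threadgroup 0 and 4..7 for threadgroup 1
  }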
sim/common/hfloats.h (new file, 109 lines)

@@ -0,0 +1,109 @@
// Copyright © 2019-2023
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <stdint.h>
#include <cmath>
#include <cstring>

// A minimal IEEE 754 half-precision (16-bit) float implementation.
// Provides conversion to/from 32-bit float and basic arithmetic operators.

namespace vortex {

struct half_t {
  uint16_t bits;

  half_t() = default;
  // Construct from float
  half_t(float f) { bits = float_to_half(f); }
  // Convert to float
  operator float() const { return half_to_float(bits); }

  // Arithmetic operators
  friend half_t operator+(half_t a, half_t b) { return half_t((float)a + (float)b); }
  friend half_t operator-(half_t a, half_t b) { return half_t((float)a - (float)b); }
  friend half_t operator*(half_t a, half_t b) { return half_t((float)a * (float)b); }
  friend half_t operator/(half_t a, half_t b) { return half_t((float)a / (float)b); }

private:
  static uint16_t float_to_half(float f) {
    uint32_t x;
    std::memcpy(&x, &f, sizeof(x));
    uint32_t sign = (x >> 16) & 0x8000;
    uint32_t mant = x & 0x007FFFFF;
    uint32_t exp  = x & 0x7F800000;

    if (exp >= 0x47800000) {
      // Inf or NaN
      if (mant && exp == 0x7F800000) {
        // NaN: quiet NaN with some payload
        return static_cast<uint16_t>(sign | 0x7C00 | 0x0200);
      }
      // Infinity
      return static_cast<uint16_t>(sign | 0x7C00);
    }
    if (exp <= 0x38000000) {
      // Subnormal or zero
      if (exp < 0x33000000) {
        // Too small: underflows to zero
        return static_cast<uint16_t>(sign);
      }
      // Subnormal
      mant |= 0x00800000;
      int shift = 126 - (exp >> 23);
      mant = (mant >> shift) + ((mant >> (shift - 1)) & 1);
      return static_cast<uint16_t>(sign | (mant & 0x03FF));
    }
    // Normalized number
    uint16_t h_exp  = static_cast<uint16_t>(((exp - 0x38000000) >> 13) & 0x7C00);
    uint16_t h_mant = static_cast<uint16_t>(mant >> 13);
    return static_cast<uint16_t>(sign | h_exp | h_mant);
  }

  static float half_to_float(uint16_t h) {
    uint32_t sign = (h & 0x8000) << 16;
    uint32_t exp  = h & 0x7C00;
    uint32_t mant = h & 0x03FF;
    uint32_t f;

    if (exp == 0x7C00) {
      // Inf or NaN
      f = sign | 0x7F800000 | (mant << 13);
    } else if (exp != 0) {
      // Normalized
      uint32_t e = ((exp >> 10) + 112) << 23;
      f = sign | e | (mant << 13);
    } else if (mant != 0) {
      // Subnormal: renormalize the mantissa
      mant <<= 1;
      int e = -1;
      while (!(mant & 0x0400)) {
        mant <<= 1;
        e--;
      }
      mant &= 0x03FF;
      uint32_t f_e = static_cast<uint32_t>(e + 1 + 112) << 23;
      f = sign | f_e | (mant << 13);
    } else {
      // Zero
      f = sign;
    }
    float result;
    std::memcpy(&result, &f, sizeof(result));
    return result;
  }
};

} // namespace vortex
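A small host-side check of the conversions (values worked out by hand from the bit layout, assuming sim/common/ is on the include path): 1.5f has sign 0, biased exponent 127 and mantissa 0x400000, so its half encoding is 0x3E00, and the round trip is exact:

  #include <cassert>
  #include "hfloats.h"

  int main() {
    vortex::half_t h(1.5f);
    assert(h.bits == 0x3E00);    // 0 | 01111 | 1000000000
    assert((float)h == 1.5f);    // round-trip is exact for this value
    vortex::half_t sum = h + vortex::half_t(0.25f);
    assert((float)sum == 1.75f); // arithmetic goes through float
  }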
@@ -24,6 +24,11 @@ SRCS += $(SRC_DIR)/execute.cpp $(SRC_DIR)/func_unit.cpp
SRCS += $(SRC_DIR)/cache_sim.cpp $(SRC_DIR)/mem_sim.cpp $(SRC_DIR)/local_mem.cpp $(SRC_DIR)/mem_coalescer.cpp
SRCS += $(SRC_DIR)/dcrs.cpp $(SRC_DIR)/types.cpp

# Add TPU extension sources
ifneq ($(findstring -DEXT_TPU_ENABLE, $(CONFIGS)),)
  SRCS += $(SRC_DIR)/tensor_unit.cpp
endif

# Add V extension sources
ifneq ($(findstring -DEXT_V_ENABLE, $(CONFIGS)),)
  SRCS += $(SRC_DIR)/voperands.cpp
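With this guard, tensor_unit.cpp is only compiled into the simulator when the TPU define is present in the configuration, e.g. by passing CONFIGS="-DEXT_TPU_ENABLE" to the build (invocation shown as an illustration; the exact entry point depends on the surrounding Makefiles).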
@@ -39,6 +39,9 @@ Core::Core(const SimContext& ctx,
  , core_id_(core_id)
  , socket_(socket)
  , arch_(arch)
#ifdef EXT_TPU_ENABLE
  , tensor_unit_(TensorUnit::Create("tpu", arch, this))
#endif
#ifdef EXT_V_ENABLE
  , vec_unit_(VecUnit::Create("vpu", arch, this))
#endif

@@ -136,6 +139,9 @@ Core::Core(const SimContext& ctx,
  dispatchers_.at((int)FUType::FPU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_FPU_BLOCKS, NUM_FPU_LANES);
  dispatchers_.at((int)FUType::LSU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_LSU_BLOCKS, NUM_LSU_LANES);
  dispatchers_.at((int)FUType::SFU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_SFU_BLOCKS, NUM_SFU_LANES);
#ifdef EXT_TPU_ENABLE
  dispatchers_.at((int)FUType::TPU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_VPU_BLOCKS, NUM_VPU_LANES);
#endif
#ifdef EXT_V_ENABLE
  dispatchers_.at((int)FUType::VPU) = SimPlatform::instance().create_object<Dispatcher>(this, 2, NUM_VPU_BLOCKS, NUM_VPU_LANES);
#endif

@@ -145,6 +151,9 @@ Core::Core(const SimContext& ctx,
  func_units_.at((int)FUType::FPU) = SimPlatform::instance().create_object<FpuUnit>(this);
  func_units_.at((int)FUType::LSU) = SimPlatform::instance().create_object<LsuUnit>(this);
  func_units_.at((int)FUType::SFU) = SimPlatform::instance().create_object<SfuUnit>(this);
#ifdef EXT_TPU_ENABLE
  func_units_.at((int)FUType::TPU) = SimPlatform::instance().create_object<TpuUnit>(this);
#endif
#ifdef EXT_V_ENABLE
  func_units_.at((int)FUType::VPU) = SimPlatform::instance().create_object<VpuUnit>(this);
#endif

@@ -341,6 +350,9 @@ void Core::issue() {
      default: assert(false);
      }
    } break;
#ifdef EXT_TPU_ENABLE
    case FUType::TPU: ++perf_stats_.scrb_tpu; break;
#endif
#ifdef EXT_V_ENABLE
    case FUType::VPU: ++perf_stats_.scrb_vpu; break;
#endif
@@ -59,6 +59,9 @@ public:
  uint64_t scrb_sfu;
  uint64_t scrb_csrs;
  uint64_t scrb_wctl;
#ifdef EXT_TPU_ENABLE
  uint64_t scrb_tpu;
#endif
#ifdef EXT_V_ENABLE
  uint64_t vinstrs;
  uint64_t scrb_vpu;

@@ -83,6 +86,9 @@ public:
  , scrb_sfu(0)
  , scrb_csrs(0)
  , scrb_wctl(0)
#ifdef EXT_TPU_ENABLE
  , scrb_tpu(0)
#endif
#ifdef EXT_V_ENABLE
  , vinstrs(0)
  , scrb_vpu(0)

@@ -155,6 +161,12 @@ public:
    return emulator_.dcache_write(data, addr, size);
  }

#ifdef EXT_TPU_ENABLE
  TensorUnit::Ptr& tensor_unit() {
    return tensor_unit_;
  }
#endif

#ifdef EXT_V_ENABLE
  VecUnit::Ptr& vec_unit() {
    return vec_unit_;

@@ -182,6 +194,10 @@ private:
  Socket* socket_;
  const Arch& arch_;

#ifdef EXT_TPU_ENABLE
  TensorUnit::Ptr tensor_unit_;
#endif

#ifdef EXT_V_ENABLE
  VecUnit::Ptr vec_unit_;
#endif
@@ -55,12 +55,12 @@ static const std::unordered_map<Opcode, InstType> sc_instTable = {
};

static const char* op_string(const Instr &instr) {
  auto opcode = instr.getOpcode();
  auto funct2 = instr.getFunct2();
  auto funct3 = instr.getFunct3();
  auto funct7 = instr.getFunct7();
  auto rd  = instr.getDestReg();
  auto rs1 = instr.getSrcReg(1);
  auto imm = instr.getImm();

  switch (opcode) {

@@ -386,23 +386,29 @@ static const char* op_string(const Instr &instr) {
      default:
        std::abort();
      }
    case 1:
      switch (funct3) {
      case 0: // gfx reserved
        std::abort();
      default:
        std::abort();
      }
#ifdef EXT_TPU_ENABLE
    case 2:
      switch (funct3) {
      case 0: return "HMMA844";
      default:
        std::abort();
      }
#endif
    default:
      std::abort();
    }
  case Opcode::EXT2:
    switch (funct3) {
-   case 0: // reserved
-   case 1: // reserved
+   case 0: // gfx reserved
+   case 1: // gfx reserved
      std::abort();
    case 2:
      switch (funct2) {
      case 0: return "MMADD.u4_i32";
      case 1: return "MMADD.u8_i32";
      case 2: return "MMADD.f16_f32";
      case 3: return "MMADD.bf16_f32";
      default:
        std::abort();
      }
    default:
      std::abort();
    }

@@ -468,7 +474,7 @@ std::ostream &operator<<(std::ostream &os, const Instr &instr) {
  }
}

-std::shared_ptr<Instr> Emulator::decode(uint32_t code) const {
+Instr::Ptr Emulator::decode(uint32_t code) const {
  auto instr = std::allocate_shared<Instr>(instr_pool_);
  auto op = Opcode((code >> shift_opcode) & mask_opcode);
  instr->setOpcode(op);

@@ -574,6 +580,31 @@ std::shared_ptr<Instr> Emulator::decode(uint32_t code) const {
      std::abort();
    }
    break;
#ifdef EXT_TPU_ENABLE
  case 2: {
    switch (funct3) {
    case 0: { // HMMA844
      uint32_t fmt = rd;
      uint32_t steps = rs1 >> 1;
      uint32_t step = steps % 4;
      uint32_t set = steps / 4;
      uint32_t rd_pair = rs1 & 0x1;
      uint32_t use_d = rs2;
      uint32_t base_rd  = (use_d ? 16 : 0) + (step * 2 + rd_pair); // C/D
      uint32_t base_rs1 = 8 + set;  // A
      uint32_t base_rs2 = 24 + set; // B
      uint32_t base_rs3 = 0 + step; // C
      instr->setImm((fmt << 2) + step); // fmt + step
      instr->setDestReg(base_rd, RegType::Float);
      instr->addSrcReg(base_rs1, RegType::Float);
      instr->addSrcReg(base_rs2, RegType::Float);
      instr->addSrcReg(base_rs3, RegType::Float);
    } break;
    default:
      std::abort();
    }
  } break;
#endif
  default:
    std::abort();
  }
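Reading this decode case together with the encoder macros in vx_intrinsics.h: the raw rd field carries the data format, rs1 packs the step index and the low/high pair bit, and rs2 selects the C or D accumulator bank. A host-side sanity check of that unpacking (field values chosen to match the second instruction emitted by MAKE_VX_HMMA_844_D_F32_STEP with macro step 5):

  #include <cassert>

  int main() {
    unsigned rd = 0, rs1 = 5 * 2 + 1 /* = 11 */, rs2 = 1;
    unsigned fmt     = rd;           // data format
    unsigned steps   = rs1 >> 1;     // 5
    unsigned step    = steps % 4;    // 1 -> accumulator pair within the set
    unsigned set     = steps / 4;    // 1 -> second A/B register
    unsigned rd_pair = rs1 & 0x1;    // 1 -> high register of the pair
    unsigned base_rd = (rs2 ? 16 : 0) + (step * 2 + rd_pair); // 19 -> f19 (fd3)
    assert(base_rd == 19 && set == 1 && step == 1 && fmt == 0);
  }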
@@ -79,6 +79,9 @@ Emulator::Emulator(const Arch &arch, const DCRS &dcrs, Core* core)
  , warps_(arch.num_warps(), arch.num_threads())
  , barriers_(arch.num_barriers(), 0)
  , ipdom_size_(arch.num_threads()-1)
#ifdef EXT_TPU_ENABLE
  , tensor_unit_(core->tensor_unit())
#endif
#ifdef EXT_V_ENABLE
  , vec_unit_(core->vec_unit())
#endif

@@ -162,7 +165,7 @@ instr_trace_t* Emulator::step() {
  if (scheduled_warp == -1)
    return nullptr;

-  // suspend warp until decode
+  // get scheduled warp
  auto& warp = warps_.at(scheduled_warp);
  assert(warp.tmask.any());
@@ -19,6 +19,10 @@
#include <stack>
#include <mem.h>
#include "types.h"
#include "instr.h"
#ifdef EXT_TPU_ENABLE
#include "tensor_unit.h"
#endif
#ifdef EXT_V_ENABLE
#include "vec_unit.h"
#endif

@@ -105,7 +109,7 @@ public:

private:

-  std::shared_ptr<Instr> decode(uint32_t code) const;
+  Instr::Ptr decode(uint32_t code) const;

  void execute(const Instr &instr, uint32_t wid, instr_trace_t *trace);

@@ -146,6 +150,10 @@ private:
  Word csr_mscratch_;
  wspawn_t wspawn_;

#ifdef EXT_TPU_ENABLE
  TensorUnit::Ptr tensor_unit_;
#endif

#ifdef EXT_V_ENABLE
  VecUnit::Ptr vec_unit_;
#endif
@@ -1412,6 +1412,24 @@ void Emulator::execute(const Instr &instr, uint32_t wid, instr_trace_t *trace) {
      std::abort();
    }
  } break;
#ifdef EXT_TPU_ENABLE
  case 2: {
    switch (funct3) {
    case 0: { // HMMA844
      trace->fu_type = FUType::TPU;
      trace->tpu_type = TpuType::HMMA844;
      auto trace_data = std::make_shared<TensorUnit::ExeTraceData>();
      trace->data = trace_data;
      uint32_t fmt = immsrc >> 2;
      uint32_t step = immsrc & 0x3;
      tensor_unit_->hmma844(wid, fmt, step, rs1_data, rs2_data, rs3_data, rd_data, trace_data.get());
      rd_write = true;
    } break;
    default:
      std::abort();
    }
  } break;
#endif
  default:
    std::abort();
  }
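The immediate written at decode time packs both fields as imm = (fmt << 2) + step, so execute recovers them with a shift and a mask. A tiny standalone check, using hypothetical fmt = 1 and step = 3:

  #include <cassert>
  #include <cstdint>

  int main() {
    uint32_t immsrc = (1u << 2) + 3u; // packed at decode: (fmt << 2) + step
    assert((immsrc >> 2) == 1);       // fmt
    assert((immsrc & 0x3) == 3);      // step
  }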
@@ -317,6 +317,25 @@ void SfuUnit::tick() {

///////////////////////////////////////////////////////////////////////////////

#ifdef EXT_TPU_ENABLE

TpuUnit::TpuUnit(const SimContext& ctx, Core* core)
  : FuncUnit(ctx, core, "tpu-unit")
{
  // bind tensor unit
  for (uint32_t iw = 0; iw < ISSUE_WIDTH; ++iw) {
    this->Inputs.at(iw).bind(&core_->tensor_unit()->Inputs.at(iw));
    core_->tensor_unit()->Outputs.at(iw).bind(&this->Outputs.at(iw));
  }
}

void TpuUnit::tick() {
  // use tensor_unit
}

#endif

///////////////////////////////////////////////////////////////////////////////

#ifdef EXT_V_ENABLE

VpuUnit::VpuUnit(const SimContext& ctx, Core* core)
@@ -110,6 +110,19 @@ public:

///////////////////////////////////////////////////////////////////////////////

#ifdef EXT_TPU_ENABLE

class TpuUnit : public FuncUnit {
public:
  TpuUnit(const SimContext& ctx, Core*);

  void tick() override;
};

#endif

///////////////////////////////////////////////////////////////////////////////

#ifdef EXT_V_ENABLE

class VpuUnit : public FuncUnit {
@@ -125,6 +125,8 @@ enum VectorAttrMask {

class Instr {
public:
  using Ptr = std::shared_ptr<Instr>;

  Instr()
    : opcode_(Opcode::NONE)
    , num_rsrcs_(0)
@@ -72,6 +72,9 @@ public:
  AluType alu_type;
  FpuType fpu_type;
  SfuType sfu_type;
#ifdef EXT_TPU_ENABLE
  TpuType tpu_type;
#endif
#ifdef EXT_V_ENABLE
  VpuType vpu_type;
#endif
@@ -13,40 +13,20 @@
// limitations under the License.

#include "tensor_unit.h"
#include "core.h"

using namespace vortex;

-template <typename T>
-class FMAD : public SimObject<FMAD<T>> {
-public:
-  SimPort<T> Input;
-  SimPort<T> Output;
-
-  FMAD(const SimContext &ctx, const char* name)
-    : SimObject<FMAD<T>>(ctx, name)
-    , Input(this)
-    , Output(this)
-  {}
-
-  virtual ~FMAD() {}
-
-  void reset() {
-    //--
-  }
-
-  void tick() {
-    //--
-  }
-};
-
class TensorUnit::Impl {
public:
-  Impl(TensorUnit* simobject, const Config& config, Core* core)
+  Impl(TensorUnit* simobject, const Arch& arch, Core* core)
    : simobject_(simobject)
-    , config_(config)
    , core_(core)
+    , arch_(arch)
    , perf_stats_()
-  {}
+  {
+    //--
+  }

  ~Impl() {
    // Destructor logic if needed

@@ -57,7 +37,48 @@ public:
  }

  void tick() {
-    // Implement the tick logic here
    for (uint32_t iw = 0; iw < ISSUE_WIDTH; ++iw) {
      auto& input = simobject_->Inputs.at(iw);
      if (input.empty())
        return;
      auto trace = input.front();
      int delay = 0;
      switch (trace->tpu_type) {
      case TpuType::HMMA844:
        delay = 4;
        break;
      default:
        std::abort();
      }
      simobject_->Outputs.at(iw).push(trace, 2 + delay);
      DT(3, simobject_->name() << ": op=" << trace->tpu_type << ", " << *trace);
      input.pop();
    }
  }

  void hmma844(uint32_t wid,
               uint32_t fmt, uint32_t step,
               const std::vector<reg_data_t>& rs1_data,
               const std::vector<reg_data_t>& rs2_data,
               const std::vector<reg_data_t>& rs3_data,
               std::vector<reg_data_t>& rd_data,
               ExeTraceData* trace_data) {
    uint32_t num_octects = arch_.num_threads() / 8;
    uint32_t threadgroup_lane_offset = 4 * num_octects;
    for (uint32_t i = 0; i < num_octects; ++i) {
      std::vector<reg_data_t> octet_A(8);
      std::vector<reg_data_t> octet_B(8);
      std::vector<reg_data_t> octet_C(8);
      std::vector<reg_data_t> octet_D(8);

      // gather this octet's operand registers; the multiply-accumulate
      // itself is not performed yet at this point in the commit
      for (uint32_t j = 0; j < 8; ++j) {
        octet_A[j] = rs1_data[i * 8 + j];
        octet_B[j] = rs2_data[i * 8 + j];
        octet_C[j] = rs3_data[i * 8 + j];
        octet_D[j] = rd_data[i * 8 + j];
      }
    }
  }

  const PerfStats& perf_stats() const {

@@ -66,18 +87,18 @@ public:

private:
  TensorUnit* simobject_;
-  Config config_;
  Core* core_;
+  Arch arch_;
  PerfStats perf_stats_;
};

///////////////////////////////////////////////////////////////////////////////

-TensorUnit::TensorUnit(const SimContext &ctx, const char* name, const Config& config, Core* core)
+TensorUnit::TensorUnit(const SimContext &ctx, const char* name, const Arch& arch, Core* core)
  : SimObject<TensorUnit>(ctx, name)
-  , Inputs(config.num_ports, this)
-  , Outputs(config.num_ports, this)
-  , impl_(new Impl(this, config, core))
+  , Inputs(ISSUE_WIDTH, this)
+  , Outputs(ISSUE_WIDTH, this)
+  , impl_(new Impl(this, arch, core))
{}

TensorUnit::~TensorUnit() {

@@ -94,4 +115,14 @@ void TensorUnit::tick() {

const TensorUnit::PerfStats &TensorUnit::perf_stats() const {
  return impl_->perf_stats();
}

void TensorUnit::hmma844(uint32_t wid,
                         uint32_t fmt, uint32_t step,
                         const std::vector<reg_data_t>& rs1_data,
                         const std::vector<reg_data_t>& rs2_data,
                         const std::vector<reg_data_t>& rs3_data,
                         std::vector<reg_data_t>& rd_data,
                         ExeTraceData* trace_data) {
  impl_->hmma844(wid, fmt, step, rs1_data, rs2_data, rs3_data, rd_data, trace_data);
}
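The hmma844 handler above gathers per-octet operands but stops short of the arithmetic. As a reference for the intended math, assuming HMMA844 follows the m=8, n=8, k=4 tile shape its name suggests (an assumption; the element-to-register mapping is defined by the operand-mapping code in vx_tensor.h and is not reproduced here), a scalar golden model would be:

  // Scalar reference: D(8x8) = A(8x4) * B(4x8) + C(8x8), row-major arrays.
  void mma_844_ref(const float* A, const float* B, const float* C, float* D) {
    for (int m = 0; m < 8; ++m) {
      for (int n = 0; n < 8; ++n) {
        float acc = C[m * 8 + n];
        for (int k = 0; k < 4; ++k) {
          acc += A[m * 4 + k] * B[k * 8 + n];
        }
        D[m * 8 + n] = acc;
      }
    }
  }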
@@ -22,15 +22,10 @@ class Core;

class TensorUnit : public SimObject<TensorUnit> {
public:
-  struct Config {
-    uint8_t num_ports;
-    uint8_t mac_latency;
-
-    Config()
-      : num_ports(0)
-      , mac_latency(0)
-    {}
-  };
+  struct ExeTraceData : public ITraceData {
+    using Ptr = std::shared_ptr<ExeTraceData>;
+  };

  struct PerfStats {
    uint64_t latency;

@@ -48,14 +43,21 @@ public:
  std::vector<SimPort<instr_trace_t*>> Inputs;
  std::vector<SimPort<instr_trace_t*>> Outputs;

-  TensorUnit(const SimContext &ctx, const char* name, const Config& config, Core* core);
+  TensorUnit(const SimContext &ctx, const char* name, const Arch& arch, Core* core);
  virtual ~TensorUnit();

  virtual void reset();

  virtual void tick();

  void hmma844(uint32_t wid,
               uint32_t fmt, uint32_t step,
               const std::vector<reg_data_t>& rs1_data,
               const std::vector<reg_data_t>& rs2_data,
               const std::vector<reg_data_t>& rs3_data,
               std::vector<reg_data_t>& rd_data,
               ExeTraceData* trace_data);

  const PerfStats& perf_stats() const;

private:
@@ -127,6 +127,9 @@ enum class FUType {
  LSU,
  FPU,
  SFU,
#ifdef EXT_TPU_ENABLE
  TPU,
#endif
#ifdef EXT_V_ENABLE
  VPU,
#endif

@@ -139,6 +142,9 @@ inline std::ostream &operator<<(std::ostream &os, const FUType& type) {
  case FUType::LSU: os << "LSU"; break;
  case FUType::FPU: os << "FPU"; break;
  case FUType::SFU: os << "SFU"; break;
#ifdef EXT_TPU_ENABLE
  case FUType::TPU: os << "TPU"; break;
#endif
#ifdef EXT_V_ENABLE
  case FUType::VPU: os << "VPU"; break;
#endif

@@ -286,6 +292,20 @@ inline std::ostream &operator<<(std::ostream &os, const SfuType& type) {

///////////////////////////////////////////////////////////////////////////////

enum class TpuType {
  HMMA844 = 0,
};

inline std::ostream &operator<<(std::ostream &os, const TpuType& type) {
  switch (type) {
  case TpuType::HMMA844: os << "HMMA844"; break;
  default: assert(false);
  }
  return os;
}

///////////////////////////////////////////////////////////////////////////////

enum class VpuType {
  VSET = 0,
@@ -96,10 +96,6 @@ public:

      DT(3, simobject_->name() << ": op=" << trace->vpu_type << ", " << *trace);

-     if (trace->eop && trace->fetch_stall) {
-       core_->resume(trace->wid);
-     }
-
      input.pop();
    }
  }
@@ -56,11 +56,7 @@ public:
  std::vector<SimPort<instr_trace_t*>> Inputs;
  std::vector<SimPort<instr_trace_t*>> Outputs;

-  VecUnit(const SimContext& ctx,
-          const char* name,
-          const Arch& arch,
-          Core* core);
-
+  VecUnit(const SimContext& ctx, const char* name, const Arch& arch, Core* core);
  ~VecUnit();

  void reset();
@@ -5,7 +5,7 @@
#include <hfloats.h>

#ifndef I_TYPE
-#define I_TYPE vortex::half_t
+#define I_TYPE float
#endif

#ifndef O_TYPE
@@ -1,46 +1,46 @@
-#include "common.h"
 #include <vx_spawn.h>
+#include <vx_tensor.h>
+#include "common.h"

void kernel_body(kernel_arg_t *__UNIFORM__ arg) {
  auto A = reinterpret_cast<I_TYPE *>(arg->A_addr);
  auto B = reinterpret_cast<I_TYPE *>(arg->B_addr);
  auto C = reinterpret_cast<O_TYPE *>(arg->C_addr);

-  tensor::fragment<tensor::half_t, tensor::row_major> fragA;
-  tensor::fragment<tensor::half_t, tensor::row_major> fragB;
-  tensor::fragment<float, tensor::row_major> fragC;
+  tensor::fragment<tensor::matrix_a, I_TYPE, tensor::row_major> fragA;
+  tensor::fragment<tensor::matrix_b, I_TYPE, tensor::col_major> fragB;
+  tensor::fragment<tensor::matrix_c, O_TYPE, tensor::row_major> fragC;

  // calculate tile row & column based on block index
  uint32_t tile_row = blockIdx.y * arg->tileM;
  uint32_t tile_col = blockIdx.x * arg->tileN;

  uint32_t N = arg->N;
  uint32_t K = arg->K;
  uint32_t tileK = arg->tileK;

  // Initialize accumulator tile to zero
-  tensor::fill_fragment(fragC, 0.0f);
+  tensor::fill_fragment(fragC, 0);

  for (int i = 0; i < K; i += tileK) {
    // Load A tile
    auto tileA = A + (tile_row * K + i);
-   tensor::load_matrix_sync(fragA, tileA, K);
+   tensor::load_matrix_sync<tensor::row_major>(fragA, tileA, K);

    // Load B tile
-   auto tileB = B + (i * k + tile_col);
-   tensor::load_matrix_sync(fragB, tileB, K);
+   auto tileB = B + (i * K + tile_col);
+   tensor::load_matrix_sync<tensor::row_major>(fragB, tileB, K);

    // Matrix multiply-accumulate: c += a * b
    tensor::mma_sync(fragC, fragA, fragB, fragC);
  }

  // Store the computed C tile
  auto tileC = C + (tile_row * N + tile_col);
- tensor::store_matrix_sync(tileC, fragC, N, tensor::mem_row_major);
+ tensor::store_matrix_sync<tensor::row_major>(tileC, fragC, N);
}

int main() {
  kernel_arg_t *arg = (kernel_arg_t *)csr_read(VX_CSR_MSCRATCH);
  return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg);
}
@@ -192,11 +192,15 @@ int main(int argc, char *argv[]) {

  uint64_t NT;
  RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_THREADS, &NT));
- std::cout << "GPU warp size: " << NT << std::endl;
+ if (NT < 4) {
+   std::cout << "Error: warp size must be at least 4 threads!" << std::endl;
+   return -1;
+ }
+ std::cout << "GPU warp size: " << NT << " threads" << std::endl;

  uint64_t isa_flags;
  RT_CHECK(vx_dev_caps(device, VX_CAPS_ISA_FLAGS, &isa_flags));
- uint32_t XlenB = 4 * VX_ISA_ARCH(isa_flags);
+ uint32_t XlenB = VX_ISA_ARCH(isa_flags) / 8;
  std::cout << "GPU XLEN: " << 8 * XlenB << std::endl;

  // tile format ratio

@@ -206,22 +210,22 @@ int main(int argc, char *argv[]) {
  // determine tensor tile size
  uint32_t logNT = log2(NT);
  uint32_t tileM = 4 * (1 << (logNT / 2)) * o_ratio;
- uint32_t tileN = (logNT % 2 == 0) ? tileM / 2 : tileN;
+ uint32_t tileN = (logNT % 2 == 0) ? (tileM / 2) : tileM;
  uint32_t tileK = std::min(tileM, tileN) * i_ratio;

- std::cout << "GPU tensor tileM=" << tileM << ", tileN=" << tileM << ", tileK=" << tileM << std::endl;
+ std::cout << "GPU tensor tileM=" << tileM << ", tileN=" << tileN << ", tileK=" << tileK << std::endl;

- if ((M & (tileM - 1)) != 0) {
+ if ((M % tileM) != 0) {
    std::cout << "Error: M must be a multiple of tensor tileM!" << std::endl;
    return -1;
  }

- if ((N & (tileN - 1)) != 0) {
+ if ((N % tileN) != 0) {
    std::cout << "Error: N must be a multiple of tensor tileN!" << std::endl;
    return -1;
  }

- if ((K & (tileK - 1)) != 0) {
+ if ((K % tileK) != 0) {
    std::cout << "Error: K must be a multiple of tensor tileK!" << std::endl;
    return -1;
  }

@@ -234,8 +238,8 @@ int main(int argc, char *argv[]) {
  size_t sizeB = K * N;
  size_t sizeC = M * N;

- std::cout << "input data type: " << Comparator<I_TYPE>::type_str() << std::endl;
- std::cout << "output data type: " << Comparator<O_TYPE>::type_str() << std::endl;
+ std::cout << "input data type: " << Comparator<I_TYPE>::type_str() << " (" << sizeof(I_TYPE) << " bytes)" << std::endl;
+ std::cout << "output data type: " << Comparator<O_TYPE>::type_str() << " (" << sizeof(O_TYPE) << " bytes)" << std::endl;
  std::cout << "matrix A: " << M << "x" << K << std::endl;
  std::cout << "matrix B: " << K << "x" << N << std::endl;
  std::cout << "matrix C: " << M << "x" << N << std::endl;
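For a concrete feel for the tile sizing above, here is the arithmetic worked through for NT = 32 as a standalone check (o_ratio and i_ratio are defined outside this hunk, so the values 1 and 2 below are illustrative only):

  #include <algorithm>
  #include <cstdint>

  int main() {
    uint32_t o_ratio = 1, i_ratio = 2;                       // illustrative ratios
    uint32_t logNT = 5;                                      // log2(32)
    uint32_t tileM = 4 * (1 << (logNT / 2)) * o_ratio;       // 4 * 4 * 1 = 16
    uint32_t tileN = (logNT % 2 == 0) ? (tileM / 2) : tileM; // logNT odd -> 16
    uint32_t tileK = std::min(tileM, tileN) * i_ratio;       // 16 * 2 = 32
    return (tileM == 16 && tileN == 16 && tileK == 32) ? 0 : 1;
  }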