diff --git a/kernel/include/vx_intrinsics.h b/kernel/include/vx_intrinsics.h index 6000065e9..72f0e6254 100644 --- a/kernel/include/vx_intrinsics.h +++ b/kernel/include/vx_intrinsics.h @@ -36,6 +36,25 @@ extern "C" { #define RISCV_CUSTOM2 0x5B #define RISCV_CUSTOM3 0x7B +#define RISCV_INSN_R(opcode7, funct3, funct7, rd, rs1, rs2) ( \ + ((funct7 & 0x7F) << 25) | \ + ((rs2 & 0x1F) << 20) | \ + ((rs1 & 0x1F) << 15) | \ + ((funct3 & 0x7) << 12) | \ + ((rd & 0x1F) << 7) | \ + (opcode7 & 0x7F) \ +) + +#define RISCV_INSN_R4(opcode7, funct3, funct2, rd, rs1, rs2, rs3) ( \ + ((rs3 & 0x1F) << 27) | \ + ((funct2 & 0x3) << 25) | \ + ((rs2 & 0x1F) << 20) | \ + ((rs1 & 0x1F) << 15) | \ + ((funct3 & 0x7) << 12) | \ + ((rd & 0x1F) << 7) | \ + (opcode7 & 0x7F) \ +) + #define csr_read(csr) ({ \ size_t __r; \ __asm__ __volatile__ ("csrr %0, %1" : "=r" (__r) : "i" (csr) : "memory"); \ @@ -221,6 +240,233 @@ inline void vx_fence() { __asm__ volatile ("fence iorw, iorw"); } +typedef float mf32x8_t __attribute__((vector_size(8*4))); // 8 x f32 registers + +#define MAKE_VX_WSETM_F32(f0, f1, f2, f3, f4, f5, f6, f7) \ + mf32x8_t ret; \ + register float fd0 __asm__(f0); \ + register float fd1 __asm__(f1); \ + register float fd2 __asm__(f2); \ + register float fd3 __asm__(f3); \ + register float fd4 __asm__(f4); \ + register float fd5 __asm__(f5); \ + register float fd6 __asm__(f6); \ + register float fd7 __asm__(f7); \ + __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd0): "r"(value)); \ + __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd1): "r"(value)); \ + __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd2): "r"(value)); \ + __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd3): "r"(value)); \ + __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd4): "r"(value)); \ + __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd5): "r"(value)); \ + __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd6): "r"(value)); \ + __asm__ volatile("fmv.w.x %0, %1" : "=f"(fd7): "r"(value)); \ + ret = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; \ + return ret + +__attribute__((always_inline)) mf32x8_t vx_wsetm_a_f32(size_t value) { + MAKE_VX_WSETM_F32("f8", "f9", "f10", "f11", "f12", "f13", "f14", "f15"); +} + +__attribute__((always_inline)) mf32x8_t vx_wsetm_b_f32(size_t value) { + MAKE_VX_WSETM_F32("f24", "f25", "f26", "f27", "f28", "f29", "f30", "f31"); +} + +__attribute__((always_inline)) mf32x8_t vx_wsetm_c_f32(size_t value) { + MAKE_VX_WSETM_F32("f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7"); +} + +__attribute__((always_inline)) mf32x8_t vx_wsetm_d_f32(size_t value) { + MAKE_VX_WSETM_F32("f16", "f17", "f18", "f19", "f20", "f21", "f22", "f23"); +} + +#define MAKE_VX_WLDM_D_F32(f0, f1, f2, f3, f4, f5, f6, f7) \ + mf32x8_t ret; \ + auto base = (const float*)src + row * ldm; \ + register float fd0 __asm__(f0); \ + register float fd1 __asm__(f1); \ + register float fd2 __asm__(f2); \ + register float fd3 __asm__(f3); \ + register float fd4 __asm__(f4); \ + register float fd5 __asm__(f5); \ + register float fd6 __asm__(f6); \ + register float fd7 __asm__(f7); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd0) : "m"(base[0])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd1) : "m"(base[1])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd2) : "m"(base[2])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd3) : "m"(base[3])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd4) : "m"(base[4])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd5) : "m"(base[5])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd6) : "m"(base[6])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd7) : "m"(base[7])); \ + ret = {fd0, fd1, fd2, 
fd3, fd4, fd5, fd6, fd7}; \ + return ret + +#define MAKE_VX_WLDM_T_F32(f0, f1, f2, f3, f4, f5, f6, f7) \ + mf32x8_t ret; \ + auto base = (const float*)src + col; \ + register float fd0 __asm__(f0); \ + register float fd1 __asm__(f1); \ + register float fd2 __asm__(f2); \ + register float fd3 __asm__(f3); \ + register float fd4 __asm__(f4); \ + register float fd5 __asm__(f5); \ + register float fd6 __asm__(f6); \ + register float fd7 __asm__(f7); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd0) : "m"(base[0 * ldm])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd1) : "m"(base[1 * ldm])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd2) : "m"(base[2 * ldm])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd3) : "m"(base[3 * ldm])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd4) : "m"(base[4 * ldm])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd5) : "m"(base[5 * ldm])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd6) : "m"(base[6 * ldm])); \ + __asm__ volatile ("flw %0, %1" : "=f"(fd7) : "m"(base[7 * ldm])); \ + ret = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; \ + return ret + +__attribute__((always_inline)) mf32x8_t vx_wldm_ad_f32(const void* src, int row, size_t ldm) { + MAKE_VX_WLDM_D_F32("f8", "f9", "f10", "f11", "f12", "f13", "f14", "f15"); +} + +__attribute__((always_inline)) mf32x8_t vx_wldm_at_f32(const void* src, int col, size_t ldm) { + MAKE_VX_WLDM_T_F32("f8", "f9", "f10", "f11", "f12", "f13", "f14", "f15"); +} + +__attribute__((always_inline)) mf32x8_t vx_wldm_bd_f32(const void* src, int row, size_t ldm) { + MAKE_VX_WLDM_D_F32("f24", "f25", "f26", "f27", "f28", "f29", "f30", "f31"); +} + +__attribute__((always_inline)) mf32x8_t vx_wldm_bt_f32(const void* src, int col, size_t ldm) { + MAKE_VX_WLDM_T_F32("f24", "f25", "f26", "f27", "f28", "f29", "f30", "f31"); +} + +__attribute__((always_inline)) void vx_wstm_f32(void* dst, const mf32x8_t& src, int row, int col, size_t ldm) { + mf32x8_t ret; + auto base = (float*)dst + row * ldm + col; + auto base_2row = base + 2 * ldm; + __asm__ volatile("fsw %0, %1" ::"f"(src[0]), "m"(base[0])); + __asm__ volatile("fsw %0, %1" ::"f"(src[1]), "m"(base[1])); + __asm__ volatile("fsw %0, %1" ::"f"(src[2]), "m"(base_2row[0])); + __asm__ volatile("fsw %0, %1" ::"f"(src[3]), "m"(base_2row[1])); + __asm__ volatile("fsw %0, %1" ::"f"(src[4]), "m"(base[4])); + __asm__ volatile("fsw %0, %1" ::"f"(src[5]), "m"(base[5])); + __asm__ volatile("fsw %0, %1" ::"f"(src[6]), "m"(base_2row[4])); + __asm__ volatile("fsw %0, %1" ::"f"(src[7]), "m"(base_2row[5])); +} + +#define MAKE_VX_HMMA_844_D_F32_STEP(fmt, step, rd_lo, rd_hi, rs1, rs2, rs3_lo, rs3_hi) \ + __asm__ volatile (".word %1" : "=r"(rd_lo) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 0, 1)), "r"(rs1), "r"(rs2), "r"(rs3_lo)); \ + __asm__ volatile (".word %1" : "=r"(rd_hi) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 1, 1)), "r"(rs1), "r"(rs2), "r"(rs3_hi)) + +#define MAKE_VX_HMMA_844_D_F32(fmt) \ + mf32x8_t ret; \ + register float fd0 __asm__("f16"); \ + register float fd1 __asm__("f17"); \ + register float fd2 __asm__("f18"); \ + register float fd3 __asm__("f19"); \ + register float fd4 __asm__("f20"); \ + register float fd5 __asm__("f21"); \ + register float fd6 __asm__("f22"); \ + register float fd7 __asm__("f23"); \ + register float fa0 __asm__("f8") = a[0]; \ + register float fa1 __asm__("f9" ) = a[1]; \ + register float fa2 __asm__("f10") = a[2]; \ + register float fa3 __asm__("f11") = a[3]; \ + register float fa4 __asm__("f12") = a[4]; \ + register float fa5 __asm__("f13") = a[5]; \ + register 
float fa6 __asm__("f14") = a[6]; \ + register float fa7 __asm__("f15") = a[7]; \ + register float fb0 __asm__("f24") = b[0]; \ + register float fb1 __asm__("f25") = b[1]; \ + register float fb2 __asm__("f26") = b[2]; \ + register float fb3 __asm__("f27") = b[3]; \ + register float fb4 __asm__("f28") = b[4]; \ + register float fb5 __asm__("f29") = b[5]; \ + register float fb6 __asm__("f30") = b[6]; \ + register float fb7 __asm__("f31") = b[7]; \ + register float fc0 __asm__("f0") = c[0]; \ + register float fc1 __asm__("f1") = c[1]; \ + register float fc2 __asm__("f2") = c[2]; \ + register float fc3 __asm__("f3") = c[3]; \ + register float fc4 __asm__("f4") = c[4]; \ + register float fc5 __asm__("f5") = c[5]; \ + register float fc6 __asm__("f6") = c[6]; \ + register float fc7 __asm__("f7") = c[7]; \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 0, fd0, fd1, fa0, fb0, fc0, fc1); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 1, fd2, fd3, fa0, fb0, fc2, fc3); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 2, fd4, fd5, fa0, fb0, fc4, fc5); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 3, fd6, fd7, fa0, fb0, fc6, fc7); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 4, fd0, fd1, fa1, fb1, fc0, fc1); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 5, fd2, fd3, fa1, fb1, fc2, fc3); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 6, fd4, fd5, fa1, fb1, fc4, fc5); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 7, fd6, fd7, fa1, fb1, fc6, fc7); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 8, fd0, fd1, fa2, fb2, fc0, fc1); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 9, fd2, fd3, fa2, fb2, fc2, fc3); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 10, fd4, fd5, fa2, fb2, fc4, fc5); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 11, fd6, fd7, fa2, fb2, fc6, fc7); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 12, fd0, fd1, fa3, fb3, fc0, fc1); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 13, fd2, fd3, fa3, fb3, fc2, fc3); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 14, fd4, fd5, fa3, fb3, fc4, fc5); \ + MAKE_VX_HMMA_844_D_F32_STEP(fmt, 15, fd6, fd7, fa3, fb3, fc6, fc7); \ + ret = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; \ + return ret + +#define MAKE_VX_HMMA_844_C_F32_STEP(fmt, step, rd_lo, rd_hi, rs1, rs2) \ + __asm__ volatile (".word %1" : "=r"(rd_lo) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 0, 0)), "r"(rs1), "r"(rs2), "r"(rd_lo)); \ + __asm__ volatile (".word %1" : "=r"(rd_hi) : "i"(RISCV_INSN_R(RISCV_CUSTOM0, 0, 2, fmt, step * 2 + 1, 0)), "r"(rs1), "r"(rs2), "r"(rd_hi)) + +#define MAKE_VX_HMMA_844_C_F32(fmt) \ + mf32x8_t ret; \ + register float fa0 __asm__("f8") = a[0]; \ + register float fa1 __asm__("f9" ) = a[1]; \ + register float fa2 __asm__("f10") = a[2]; \ + register float fa3 __asm__("f11") = a[3]; \ + register float fa4 __asm__("f12") = a[4]; \ + register float fa5 __asm__("f13") = a[5]; \ + register float fa6 __asm__("f14") = a[6]; \ + register float fa7 __asm__("f15") = a[7]; \ + register float fb0 __asm__("f24") = b[0]; \ + register float fb1 __asm__("f25") = b[1]; \ + register float fb2 __asm__("f26") = b[2]; \ + register float fb3 __asm__("f27") = b[3]; \ + register float fb4 __asm__("f28") = b[4]; \ + register float fb5 __asm__("f29") = b[5]; \ + register float fb6 __asm__("f30") = b[6]; \ + register float fb7 __asm__("f31") = b[7]; \ + register float fc0 __asm__("f0") = c[0]; \ + register float fc1 __asm__("f1") = c[1]; \ + register float fc2 __asm__("f2") = c[2]; \ + register float fc3 __asm__("f3") = c[3]; \ + register float fc4 __asm__("f4") = c[4]; \ + register float fc5 __asm__("f5") = c[5]; \ + register float fc6 __asm__("f6") = c[6]; \ + register float fc7 __asm__("f7") = c[7]; \ + 
MAKE_VX_HMMA_844_C_F32_STEP(fmt, 0, fc0, fc1, fa0, fb0); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 1, fc2, fc3, fa0, fb0); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 2, fc4, fc5, fa0, fb0); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 3, fc6, fc7, fa0, fb0); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 4, fc0, fc1, fa1, fb1); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 5, fc2, fc3, fa1, fb1); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 6, fc4, fc5, fa1, fb1); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 7, fc6, fc7, fa1, fb1); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 8, fc0, fc1, fa2, fb2); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 9, fc2, fc3, fa2, fb2); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 10, fc4, fc5, fa2, fb2); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 11, fc6, fc7, fa2, fb2); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 12, fc0, fc1, fa3, fb3); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 13, fc2, fc3, fa3, fb3); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 14, fc4, fc5, fa3, fb3); \ + MAKE_VX_HMMA_844_C_F32_STEP(fmt, 15, fc6, fc7, fa3, fb3); \ + ret = {fc0, fc1, fc2, fc3, fc4, fc5, fc6, fc7}; \ + return ret + +__attribute__((always_inline)) mf32x8_t vx_hmma_844_c_f16_f32(const mf32x8_t& a, const mf32x8_t& b, const mf32x8_t& c) { + MAKE_VX_HMMA_844_C_F32(0); +} + +__attribute__((always_inline)) mf32x8_t vx_hmma_844_d_f16_f32(const mf32x8_t& a, const mf32x8_t& b, const mf32x8_t& c) { + MAKE_VX_HMMA_844_D_F32(0); +} + #ifdef __cplusplus } #endif diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index f6edcbfba..e68206c63 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -16,48 +16,181 @@ #include #include +#include +#include -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef __cplusplus -} +#ifndef NUM_LANES +#define NUM_LANES 8 #endif namespace tensor { -enum frag_layout_t { row_major, col_major }; -enum mem_layout_t { mem_row_major, mem_col_major }; +enum frag_use_t { matrix_d, matrix_a, matrix_b, matrix_c }; +enum layout_t { row_major, col_major }; -template +template struct fragment { - typedef T DType; - static const frag_layout_t Layout = L; - typedef T VType __attribute__((vector_size(8 * sizeof(void*)))); - VType data; + typedef T Type; + static const frag_use_t Use = U; + static const layout_t Layout = L; + mf32x8_t data; }; -template -void fill_fragment(Frag &frag, size_t value) { - // empty skeleton +__attribute__((always_inline)) void map_operand_ab_32lanes(int tid, int &row, int &col) { + int tg = tid / 4; + + // A (row major) + // Figure 7(a) in paper + // row 0~ 3: threadgroups 0 and 2 + // row 4~ 7: threadgroups 4 and 6 + // row 8~11: threadgroups 1 and 3 + // row 12~15: threadgroups 5 and 7 + row = tid % 4; + row += (tg * 8) % 16; + row += (tg / 4) * 4; + + // B (column major) + // NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the + // corrected mapping: + // col 0~ 3: threadgroups 0 and 1 + // col 4~ 7: threadgroups 4 and 5 + // col 8~11: threadgroups 2 and 3 + // col 12~15: threadgroups 6 and 7 + col = tid % 4; + col += ((tg % 4) / 2) * 8; + col += (tg / 4) * 4; +} + +__attribute__((always_inline)) void map_operand_ab_8lanes(int tid, int &row, int &col) { + int tg = tid / 4; + + // A (row major) + // row 0~ 3: threadgroup 0 + // row 4~ 7: threadgroup 1 + row = tid % 4; + row += tg * 4; + + // B (column major) + // col 0~ 3: threadgroup 0 + // col 4~ 7: threadgroup 1 + col = tid % 4; + col += tg * 4; +} + +__attribute__((always_inline)) void map_operand_c_32lanes(int tid, int &row, int &col) { + int tg = tid / 4; + + // Figure 7(b), left + col = ((tg % 4) / 2) * 8; + row = (tg * 8) % 16; + row 
+= (tg / 4) * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +__attribute__((always_inline)) void map_operand_c_8lanes(int tid, int &row, int &col) { + int tg = tid / 4; + + // Figure 7(b), left + col = 0; + row = tg * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +__attribute__((always_inline)) void map_operand_ab(int tid, int &row, int &col) { + if constexpr (NUM_LANES == 32) { + map_operand_ab_32lanes(tid, row, col); + } else if constexpr (NUM_LANES == 8) { + map_operand_ab_8lanes(tid, row, col); + } else { + static_assert(NUM_LANES == 32 || NUM_LANES == 8, "NUM_LANES must be 8 or 32"); + } +} + +__attribute__((always_inline)) void map_operand_c(int tid, int &row, int &col) { + if constexpr (NUM_LANES == 32) { + map_operand_c_32lanes(tid, row, col); + } else if constexpr (NUM_LANES == 8) { + map_operand_c_8lanes(tid, row, col); + } else { + static_assert(NUM_LANES == 32 || NUM_LANES == 8, "NUM_LANES must be 8 or 32"); + } } template -void load_matrix_sync(Frag &frag, const void *ptr, size_t ld) { - // empty skeleton +__attribute__((always_inline)) void fill_fragment(Frag &dst, size_t value) { + if constexpr (Frag::Use == matrix_d) { + dst.data = vx_wsetm_d_f32(value); + } else if constexpr (Frag::Use == matrix_a) { + dst.data = vx_wsetm_a_f32(value); + } else if constexpr (Frag::Use == matrix_b) { + dst.data = vx_wsetm_b_f32(value); + } else if constexpr (Frag::Use == matrix_c) { + dst.data = vx_wsetm_c_f32(value); + } +} + +template +__attribute__((always_inline)) void load_matrix_sync(Frag &dst, const void *src, size_t ldm) { + int row, col; + int tid = vx_thread_id(); + map_operand_ab(tid, row, col); + if constexpr (Frag::Use == matrix_a) { + if constexpr (Frag::Layout == mem_layout) { + dst.data = vx_wldm_ad_f32(src, row, ldm); + } else { + dst.data = vx_wldm_at_f32(src, col, ldm); + } + } else if constexpr (Frag::Use == matrix_b) { + if constexpr (Frag::Layout == mem_layout) { + dst.data = vx_wldm_bd_f32(src, row, ldm); + } else { + dst.data = vx_wldm_bt_f32(src, col, ldm); + } + } else { + static_assert(false, "Only matrix_a and matrix_b are supported!"); + } +} + +template +__attribute__((always_inline)) void store_matrix_sync(void *dst, const Frag &src, size_t ldm) { + static_assert(Frag::Layout == mem_layout, "fragment layout should match memory!"); + int row, col; + int tid = vx_thread_id(); + map_operand_c(tid, row, col); + if constexpr (Frag::Use == matrix_c) { + vx_wstm_f32(dst, src.data, row, col, ldm); + } else if constexpr (Frag::Use == matrix_d) { + vx_wstm_f32(dst, src.data, row, col, ldm); + } else { + static_assert(false, "Only matrix_c or matrix_d are supported!"); + } } -// Perform the matrix multiply-accumulate: D = A * B + C template -void mma_sync(FragD &D, const FragA &A, const FragB &B, const FragC &C) { - // empty skeleton -} +__attribute__((always_inline)) void mma_sync(FragD &D, const FragA &A, const FragB &B, const FragC &C) { + static_assert(FragA::Use == matrix_a, "A must be matrix_a"); + static_assert(FragB::Use == matrix_b, "B must be matrix_b"); + static_assert(FragC::Use == matrix_c, "C must be matrix_c"); + static_assert(FragD::Use == matrix_d || FragD::Use == matrix_c, "D must be matrix_d or matrix_c"); + static_assert(std::is_same_v, "A and B must have the same type"); + static_assert(std::is_same_v, "C and D must have the same type"); -// Store a fragment result back to global memory -template -void store_matrix_sync(void *ptr, const Frag &frag, size_t ld,
mem_layout_t layout) { - // empty skeleton + if constexpr (std::is_same_v + && std::is_same_v) { + if constexpr (FragD::Use == matrix_d) { + D.data = vx_hmma_844_d_f16_f32(A.data, B.data, C.data); + } else { + D.data = vx_hmma_844_c_f16_f32(A.data, B.data, C.data); + } + } else { + static_assert(false, "Unsupported type!"); + } } } // namespace wmma diff --git a/sim/common/hfloats.h b/sim/common/hfloats.h new file mode 100644 index 000000000..8f18b5e4c --- /dev/null +++ b/sim/common/hfloats.h @@ -0,0 +1,109 @@ +// Copyright © 2019-2023 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +// A minimal IEEE 754 half-precision (16-bit) float implementation +// Provides conversion to/from 32-bit float and basic arithmetic operators. + +namespace vortex { + +struct half_t { + uint16_t bits; + + half_t() = default; + // Construct from float + half_t(float f) { bits = float_to_half(f); } + // Convert to float + operator float() const { return half_to_float(bits); } + + // Arithmetic operators + friend half_t operator+(half_t a, half_t b) { return half_t((float)a + (float)b); } + friend half_t operator-(half_t a, half_t b) { return half_t((float)a - (float)b); } + friend half_t operator*(half_t a, half_t b) { return half_t((float)a * (float)b); } + friend half_t operator/(half_t a, half_t b) { return half_t((float)a / (float)b); } + +private: + static uint16_t float_to_half(float f) { + uint32_t x; + std::memcpy(&x, &f, sizeof(x)); + uint32_t sign = (x >> 16) & 0x8000; + uint32_t mant = x & 0x007FFFFF; + uint32_t exp = x & 0x7F800000; + + if (exp >= 0x47800000) { + // Inf or NaN + if (mant && exp == 0x7F800000) { + // NaN: preserve some payload + return static_cast(sign | 0x0200); + } + // Infinity + return static_cast(sign | 0x7C00); + } + if (exp <= 0x38000000) { + // Subnormal or zero + if (exp < 0x33000000) { + // Too small: underflows to zero + return static_cast(sign); + } + // Subnormal + mant |= 0x00800000; + int shift = 113 - (exp >> 23); + mant = (mant >> shift) + ((mant >> (shift - 1)) & 1); + return static_cast(sign | (mant & 0x03FF)); + } + // Normalized number + uint16_t h_exp = static_cast(((exp - 0x38000000) >> 13) & 0x7C00); + uint16_t h_mant = static_cast(mant >> 13); + return static_cast(sign | h_exp | h_mant); + } + + static float half_to_float(uint16_t h) { + uint32_t sign = (h & 0x8000) << 16; + uint32_t exp = h & 0x7C00; + uint32_t mant = h & 0x03FF; + uint32_t f; + + if (exp == 0x7C00) { + // Inf or NaN + f = sign | 0x7F800000 | (mant << 13); + } else if (exp != 0) { + // Normalized + uint32_t e = ((exp >> 10) + 112) << 23; + f = sign | e | (mant << 13); + } else if (mant != 0) { + // Subnormal + mant <<= 1; + int e = -1; + while (!(mant & 0x0400)) { + mant <<= 1; + e--; + } + mant &= 0x03FF; + uint32_t f_e = static_cast(e + 1 + 127) << 23; + f = sign | f_e | (mant << 13); + } else { + // Zero + f = sign; + } + float result; + std::memcpy(&result, &f, sizeof(result)); + return result; + } +}; + +} // namespace vortex \ No 
newline at end of file diff --git a/sim/simx/Makefile b/sim/simx/Makefile index d18310832..f94fd13f7 100644 --- a/sim/simx/Makefile +++ b/sim/simx/Makefile @@ -24,6 +24,11 @@ SRCS += $(SRC_DIR)/execute.cpp $(SRC_DIR)/func_unit.cpp SRCS += $(SRC_DIR)/cache_sim.cpp $(SRC_DIR)/mem_sim.cpp $(SRC_DIR)/local_mem.cpp $(SRC_DIR)/mem_coalescer.cpp SRCS += $(SRC_DIR)/dcrs.cpp $(SRC_DIR)/types.cpp +# Add TPU extension sources +ifneq ($(findstring -DEXT_TPU_ENABLE, $(CONFIGS)),) + SRCS += $(SRC_DIR)/tensor_unit.cpp +endif + # Add V extension sources ifneq ($(findstring -DEXT_V_ENABLE, $(CONFIGS)),) SRCS += $(SRC_DIR)/voperands.cpp diff --git a/sim/simx/core.cpp b/sim/simx/core.cpp index d8b03408c..0d06dcf40 100644 --- a/sim/simx/core.cpp +++ b/sim/simx/core.cpp @@ -39,6 +39,9 @@ Core::Core(const SimContext& ctx, , core_id_(core_id) , socket_(socket) , arch_(arch) +#ifdef EXT_TPU_ENABLE + , tensor_unit_(TensorUnit::Create("tpu", arch, this)) +#endif #ifdef EXT_V_ENABLE , vec_unit_(VecUnit::Create("vpu", arch, this)) #endif @@ -136,6 +139,9 @@ Core::Core(const SimContext& ctx, dispatchers_.at((int)FUType::FPU) = SimPlatform::instance().create_object(this, 2, NUM_FPU_BLOCKS, NUM_FPU_LANES); dispatchers_.at((int)FUType::LSU) = SimPlatform::instance().create_object(this, 2, NUM_LSU_BLOCKS, NUM_LSU_LANES); dispatchers_.at((int)FUType::SFU) = SimPlatform::instance().create_object(this, 2, NUM_SFU_BLOCKS, NUM_SFU_LANES); +#ifdef EXT_TPU_ENABLE + dispatchers_.at((int)FUType::TPU) = SimPlatform::instance().create_object(this, 2, NUM_VPU_BLOCKS, NUM_VPU_LANES); +#endif #ifdef EXT_V_ENABLE dispatchers_.at((int)FUType::VPU) = SimPlatform::instance().create_object(this, 2, NUM_VPU_BLOCKS, NUM_VPU_LANES); #endif @@ -145,6 +151,9 @@ Core::Core(const SimContext& ctx, func_units_.at((int)FUType::FPU) = SimPlatform::instance().create_object(this); func_units_.at((int)FUType::LSU) = SimPlatform::instance().create_object(this); func_units_.at((int)FUType::SFU) = SimPlatform::instance().create_object(this); +#ifdef EXT_TPU_ENABLE + func_units_.at((int)FUType::TPU) = SimPlatform::instance().create_object(this); +#endif #ifdef EXT_V_ENABLE func_units_.at((int)FUType::VPU) = SimPlatform::instance().create_object(this); #endif @@ -341,6 +350,9 @@ void Core::issue() { default: assert(false); } } break; + #ifdef EXT_TPU_ENABLE + case FUType::TPU: ++perf_stats_.scrb_tpu; break; + #endif #ifdef EXT_V_ENABLE case FUType::VPU: ++perf_stats_.scrb_vpu; break; #endif diff --git a/sim/simx/core.h b/sim/simx/core.h index 430026569..322f90885 100644 --- a/sim/simx/core.h +++ b/sim/simx/core.h @@ -59,6 +59,9 @@ public: uint64_t scrb_sfu; uint64_t scrb_csrs; uint64_t scrb_wctl; + #ifdef EXT_TPU_ENABLE + uint64_t scrb_tpu; + #endif #ifdef EXT_V_ENABLE uint64_t vinstrs; uint64_t scrb_vpu; @@ -83,6 +86,9 @@ public: , scrb_sfu(0) , scrb_csrs(0) , scrb_wctl(0) + #ifdef EXT_TPU_ENABLE + , scrb_tpu(0) + #endif #ifdef EXT_V_ENABLE , vinstrs(0) , scrb_vpu(0) @@ -155,6 +161,12 @@ public: return emulator_.dcache_write(data, addr, size); } +#ifdef EXT_TPU_ENABLE + TensorUnit::Ptr& tensor_unit() { + return tensor_unit_; + } +#endif + #ifdef EXT_V_ENABLE VecUnit::Ptr& vec_unit() { return vec_unit_; @@ -182,6 +194,10 @@ private: Socket* socket_; const Arch& arch_; +#ifdef EXT_TPU_ENABLE + TensorUnit::Ptr tensor_unit_; +#endif + #ifdef EXT_V_ENABLE VecUnit::Ptr vec_unit_; #endif diff --git a/sim/simx/decode.cpp b/sim/simx/decode.cpp index 371bde6e0..b2840e16f 100644 --- a/sim/simx/decode.cpp +++ b/sim/simx/decode.cpp @@ -55,12 +55,12 @@ static const 
std::unordered_map sc_instTable = { }; static const char* op_string(const Instr &instr) { - auto opcode = instr.getOpcode(); + auto opcode = instr.getOpcode(); auto funct2 = instr.getFunct2(); auto funct3 = instr.getFunct3(); auto funct7 = instr.getFunct7(); - auto rd = instr.getDestReg(); - auto rs1 = instr.getSrcReg(1); + auto rd = instr.getDestReg(); + auto rs1 = instr.getSrcReg(1); auto imm = instr.getImm(); switch (opcode) { @@ -386,23 +386,29 @@ static const char* op_string(const Instr &instr) { default: std::abort(); } + case 1: + switch (funct3) { + case 0: // gfx reserved + std::abort(); + default: + std::abort(); + } + #ifdef EXT_TPU_ENABLE + case 2: + switch (funct3) { + case 0: return "HMMA844"; + default: + std::abort(); + } + #endif default: std::abort(); } case Opcode::EXT2: switch(funct3) { - case 0: // reserved - case 1: // reserved + case 0: // gfx reserved + case 1: // gfx reserved std::abort(); - case 2: - switch (funct2) { - case 0: return "MMADD.u4_i32"; - case 1: return "MMADD.u8_i32"; - case 2: return "MMADD.f16_f32"; - case 3: return "MMADD.bf16_f32"; - default: - std::abort(); - } default: std::abort(); } @@ -468,7 +474,7 @@ std::ostream &operator<<(std::ostream &os, const Instr &instr) { } } -std::shared_ptr Emulator::decode(uint32_t code) const { +Instr::Ptr Emulator::decode(uint32_t code) const { auto instr = std::allocate_shared(instr_pool_); auto op = Opcode((code >> shift_opcode) & mask_opcode); instr->setOpcode(op); @@ -574,6 +580,31 @@ std::shared_ptr Emulator::decode(uint32_t code) const { std::abort(); } break; + #ifdef EXT_TPU_ENABLE + case 2: { + switch (funct3) { + case 0: { // HMMA844 + uint32_t fmt = rd; + uint32_t steps = rs1 >> 1; + uint32_t step = steps % 4; + uint32_t set = steps / 4; + uint32_t rd_pair = rs1 & 0x1; + uint32_t use_d = rs2; + uint32_t base_rd = (use_d ? 
16 : 0) + (step * 2 + rd_pair); // C/D + uint32_t base_rs1 = 8 + set; // A + uint32_t base_rs2 = 24 + set; // B + uint32_t base_rs3 = 0 + step; // C + instr->setImm((fmt << 2) + step); // fmt + step + instr->setDestReg(base_rd, RegType::Float); + instr->addSrcReg(base_rs1, RegType::Float); + instr->addSrcReg(base_rs2, RegType::Float); + instr->addSrcReg(base_rs3, RegType::Float); + } break; + default: + std::abort(); + } + } break; + #endif default: std::abort(); } diff --git a/sim/simx/emulator.cpp b/sim/simx/emulator.cpp index 4b78832d2..0ec1f67af 100644 --- a/sim/simx/emulator.cpp +++ b/sim/simx/emulator.cpp @@ -79,6 +79,9 @@ Emulator::Emulator(const Arch &arch, const DCRS &dcrs, Core* core) , warps_(arch.num_warps(), arch.num_threads()) , barriers_(arch.num_barriers(), 0) , ipdom_size_(arch.num_threads()-1) + #ifdef EXT_TPU_ENABLE + , tensor_unit_(core->tensor_unit()) + #endif #ifdef EXT_V_ENABLE , vec_unit_(core->vec_unit()) #endif @@ -162,7 +165,7 @@ instr_trace_t* Emulator::step() { if (scheduled_warp == -1) return nullptr; - // suspend warp until decode + // get scheduled warp auto& warp = warps_.at(scheduled_warp); assert(warp.tmask.any()); diff --git a/sim/simx/emulator.h b/sim/simx/emulator.h index b4ead5bf6..4d7914578 100644 --- a/sim/simx/emulator.h +++ b/sim/simx/emulator.h @@ -19,6 +19,10 @@ #include #include #include "types.h" +#include "instr.h" +#ifdef EXT_TPU_ENABLE +#include "tensor_unit.h" +#endif #ifdef EXT_V_ENABLE #include "vec_unit.h" #endif @@ -105,7 +109,7 @@ public: private: - std::shared_ptr decode(uint32_t code) const; + Instr::Ptr decode(uint32_t code) const; void execute(const Instr &instr, uint32_t wid, instr_trace_t *trace); @@ -146,6 +150,10 @@ private: Word csr_mscratch_; wspawn_t wspawn_; +#ifdef EXT_TPU_ENABLE + TensorUnit::Ptr tensor_unit_; +#endif + #ifdef EXT_V_ENABLE VecUnit::Ptr vec_unit_; #endif diff --git a/sim/simx/execute.cpp b/sim/simx/execute.cpp index 7eb52e187..353fbbc7e 100644 --- a/sim/simx/execute.cpp +++ b/sim/simx/execute.cpp @@ -1412,6 +1412,24 @@ void Emulator::execute(const Instr &instr, uint32_t wid, instr_trace_t *trace) { std::abort(); } } break; + #ifdef EXT_TPU_ENABLE + case 2: { + switch (funct3) { + case 0: { // HMMA844 + trace->fu_type = FUType::TPU; + trace->tpu_type = TpuType::HMMA844; + auto trace_data = std::make_shared(); + trace->data = trace_data; + uint32_t fmt = immsrc >> 2; + uint32_t step = immsrc & 0x3; + tensor_unit_->hmma844(wid, fmt, step, rs1_data, rs2_data, rs3_data, rd_data, trace_data.get()); + rd_write = true; + } break; + default: + std::abort(); + } + } break; + #endif default: std::abort(); } diff --git a/sim/simx/func_unit.cpp b/sim/simx/func_unit.cpp index ec9bf29c4..f66b10ba3 100644 --- a/sim/simx/func_unit.cpp +++ b/sim/simx/func_unit.cpp @@ -317,6 +317,25 @@ void SfuUnit::tick() { /////////////////////////////////////////////////////////////////////////////// +#ifdef EXT_TPU_ENABLE + +TpuUnit::TpuUnit(const SimContext& ctx, Core* core) + : FuncUnit(ctx, core, "tpu-unit") +{ + // bind tensor unit + for (uint32_t iw = 0; iw < ISSUE_WIDTH; ++iw) { + this->Inputs.at(iw).bind(&core_->tensor_unit()->Inputs.at(iw)); + core_->tensor_unit()->Outputs.at(iw).bind(&this->Outputs.at(iw)); + } +} + +void TpuUnit::tick() { + // use tensor_unit +} +#endif + +/////////////////////////////////////////////////////////////////////////////// + #ifdef EXT_V_ENABLE VpuUnit::VpuUnit(const SimContext& ctx, Core* core) diff --git a/sim/simx/func_unit.h b/sim/simx/func_unit.h index b380545cc..99d47038e 100644 --- 
a/sim/simx/func_unit.h +++ b/sim/simx/func_unit.h @@ -110,6 +110,19 @@ public: /////////////////////////////////////////////////////////////////////////////// +#ifdef EXT_TPU_ENABLE + +class TpuUnit : public FuncUnit { +public: + TpuUnit(const SimContext& ctx, Core*); + + void tick() override; +}; + +#endif + +/////////////////////////////////////////////////////////////////////////////// + #ifdef EXT_V_ENABLE class VpuUnit : public FuncUnit { diff --git a/sim/simx/instr.h b/sim/simx/instr.h index 43e024126..6c1cc98d3 100644 --- a/sim/simx/instr.h +++ b/sim/simx/instr.h @@ -125,6 +125,8 @@ enum VectorAttrMask { class Instr { public: + using Ptr = std::shared_ptr; + Instr() : opcode_(Opcode::NONE) , num_rsrcs_(0) diff --git a/sim/simx/instr_trace.h b/sim/simx/instr_trace.h index f40800933..727faa5bf 100644 --- a/sim/simx/instr_trace.h +++ b/sim/simx/instr_trace.h @@ -72,6 +72,9 @@ public: AluType alu_type; FpuType fpu_type; SfuType sfu_type; + #ifdef EXT_TPU_ENABLE + TpuType tpu_type; + #endif #ifdef EXT_V_ENABLE VpuType vpu_type; #endif diff --git a/sim/simx/tensor_unit.cpp b/sim/simx/tensor_unit.cpp index b72016c9d..b72d16ad1 100644 --- a/sim/simx/tensor_unit.cpp +++ b/sim/simx/tensor_unit.cpp @@ -13,40 +13,20 @@ // limitations under the License. #include "tensor_unit.h" +#include "core.h" using namespace vortex; -template -class FMAD : public SimObject> { -public: - SimPort Input; - SimPort Output; - - FMAD(const SimContext &ctx, const char* name) - : SimObject>(ctx, name) - , Input(this) - , Output(this) - {} - - virtual ~FMAD() {} - - void reset() { - //-- - } - - void tick() { - //-- - } -}; - class TensorUnit::Impl { public: - Impl(TensorUnit* simobject, const Config& config, Core* core) + Impl(TensorUnit* simobject, const Arch& arch, Core* core) : simobject_(simobject) - , config_(config) , core_(core) + , arch_(arch) , perf_stats_() - {} + { + //-- + } ~Impl() { // Destructor logic if needed @@ -57,7 +37,48 @@ public: } void tick() { - // Implement the tick logic here + for (uint32_t iw = 0; iw < ISSUE_WIDTH; ++iw) { + auto& input = simobject_->Inputs.at(iw); + if (input.empty()) + return; + auto trace = input.front(); + int delay = 0; + switch (trace->tpu_type) { + case TpuType::HMMA844: + delay = 4; + break; + default: + std::abort(); + } + simobject_->Outputs.at(iw).push(trace, 2 + delay); + DT(3, simobject_->name() << ": op=" << trace->tpu_type << ", " << *trace); + input.pop(); + } + } + + void hmma844(uint32_t wid, + uint32_t fmt, uint32_t step, + const std::vector& rs1_data, + const std::vector& rs2_data, + const std::vector& rs3_data, + std::vector& rd_data, + ExeTraceData* trace_data) { + uint32_t num_octects = arch_.num_threads() / 8; + uint32_t threadgroup_lane_offset = 4 * num_octects; + for (uint32_t i = 0; i < num_octects; ++i) { + std::vector octet_A(8); + std::vector octet_B(8); + std::vector octet_C(8); + std::vector octet_D(8); + + for (uint32_t j = 0; j < 8; ++j) { + octet_A[j] = rs1_data[i * 8 + j]; + octet_B[j] = rs2_data[i * 8 + j]; + octet_C[j] = rs3_data[i * 8 + j]; + octet_D[j] = rd_data[i * 8 + j]; + } + } + } const PerfStats& perf_stats() const { @@ -66,18 +87,18 @@ public: private: TensorUnit* simobject_; - Config config_; Core* core_; + Arch arch_; PerfStats perf_stats_; }; /////////////////////////////////////////////////////////////////////////////// -TensorUnit::TensorUnit(const SimContext &ctx, const char* name, const Config& config, Core* core) +TensorUnit::TensorUnit(const SimContext &ctx, const char* name, const Arch& arch, Core* core) : 
SimObject(ctx, name) - , Inputs(config.num_ports, this) - , Outputs(config.num_ports, this) - , impl_(new Impl(this, config, core)) + , Inputs(ISSUE_WIDTH, this) + , Outputs(ISSUE_WIDTH, this) + , impl_(new Impl(this, arch, core)) {} TensorUnit::~TensorUnit() { @@ -94,4 +115,14 @@ void TensorUnit::tick() { const TensorUnit::PerfStats &TensorUnit::perf_stats() const { return impl_->perf_stats(); +} + +void TensorUnit::hmma844(uint32_t wid, + uint32_t fmt, uint32_t step, + const std::vector& rs1_data, + const std::vector& rs2_data, + const std::vector& rs3_data, + std::vector& rd_data, + ExeTraceData* trace_data) { + impl_->hmma844(wid, fmt, step, rs1_data, rs2_data, rs3_data, rd_data, trace_data); } \ No newline at end of file diff --git a/sim/simx/tensor_unit.h b/sim/simx/tensor_unit.h index eaa84615f..befe662fd 100644 --- a/sim/simx/tensor_unit.h +++ b/sim/simx/tensor_unit.h @@ -22,15 +22,10 @@ class Core; class TensorUnit : public SimObject { public: - struct Config { - uint8_t num_ports; - uint8_t mac_latency; - Config() - : num_ports(0) - , mac_latency(0) - {} - }; + struct ExeTraceData : public ITraceData { + using Ptr = std::shared_ptr; + }; struct PerfStats { uint64_t latency; @@ -48,14 +43,21 @@ public: std::vector> Inputs; std::vector> Outputs; - TensorUnit(const SimContext &ctx, const char* name, const Config& config, Core* core); - + TensorUnit(const SimContext &ctx, const char* name, const Arch& arch, Core* core); virtual ~TensorUnit(); virtual void reset(); virtual void tick(); + void hmma844(uint32_t wid, + uint32_t fmt, uint32_t step, + const std::vector& rs1_data, + const std::vector& rs2_data, + const std::vector& rs3_data, + std::vector& rd_data, + ExeTraceData* trace_data); + const PerfStats& perf_stats() const; private: diff --git a/sim/simx/types.h b/sim/simx/types.h index a791c4293..4cc1208f0 100644 --- a/sim/simx/types.h +++ b/sim/simx/types.h @@ -127,6 +127,9 @@ enum class FUType { LSU, FPU, SFU, +#ifdef EXT_TPU_ENABLE + TPU, +#endif #ifdef EXT_V_ENABLE VPU, #endif @@ -139,6 +142,9 @@ inline std::ostream &operator<<(std::ostream &os, const FUType& type) { case FUType::LSU: os << "LSU"; break; case FUType::FPU: os << "FPU"; break; case FUType::SFU: os << "SFU"; break; +#ifdef EXT_TPU_ENABLE + case FUType::TPU: os << "TPU"; break; +#endif #ifdef EXT_V_ENABLE case FUType::VPU: os << "VPU"; break; #endif @@ -286,6 +292,20 @@ inline std::ostream &operator<<(std::ostream &os, const SfuType& type) { /////////////////////////////////////////////////////////////////////////////// +enum class TpuType { + HMMA844 = 0, +}; + +inline std::ostream &operator<<(std::ostream &os, const TpuType& type) { + switch (type) { + case TpuType::HMMA844: os << "HMMA844"; break; + default: assert(false); + } + return os; +} + +/////////////////////////////////////////////////////////////////////////////// + enum class VpuType { VSET = 0, diff --git a/sim/simx/vec_unit.cpp b/sim/simx/vec_unit.cpp index 9727fbbe2..e8f8fa0a4 100644 --- a/sim/simx/vec_unit.cpp +++ b/sim/simx/vec_unit.cpp @@ -96,10 +96,6 @@ public: DT(3, simobject_->name() << ": op=" << trace->vpu_type << ", " << *trace); - if (trace->eop && trace->fetch_stall) { - core_->resume(trace->wid); - } - input.pop(); } } diff --git a/sim/simx/vec_unit.h b/sim/simx/vec_unit.h index 7d719dfe9..dffa0c348 100644 --- a/sim/simx/vec_unit.h +++ b/sim/simx/vec_unit.h @@ -56,11 +56,7 @@ public: std::vector> Inputs; std::vector> Outputs; - VecUnit(const SimContext& ctx, - const char* name, - const Arch& arch, - Core* core); - + VecUnit(const 
SimContext& ctx, const char* name, const Arch& arch, Core* core); ~VecUnit(); void reset(); diff --git a/tests/regression/sgemm_tpu/common.h b/tests/regression/sgemm_tpu/common.h index b755970fc..720c0ac32 100644 --- a/tests/regression/sgemm_tpu/common.h +++ b/tests/regression/sgemm_tpu/common.h @@ -5,7 +5,7 @@ #include #ifndef I_TYPE -#define I_TYPE vortex::half_t +#define I_TYPE float #endif #ifndef O_TYPE diff --git a/tests/regression/sgemm_tpu/kernel.cpp b/tests/regression/sgemm_tpu/kernel.cpp index fde56e0da..719943ff5 100644 --- a/tests/regression/sgemm_tpu/kernel.cpp +++ b/tests/regression/sgemm_tpu/kernel.cpp @@ -1,46 +1,46 @@ +#include "common.h" #include #include -#include "common.h" -void kernel_body(kernel_arg_t* __UNIFORM__ arg) { - auto A = reinterpret_cast(arg->A_addr); - auto B = reinterpret_cast(arg->B_addr); - auto C = reinterpret_cast(arg->C_addr); +void kernel_body(kernel_arg_t *__UNIFORM__ arg) { + auto A = reinterpret_cast(arg->A_addr); + auto B = reinterpret_cast(arg->B_addr); + auto C = reinterpret_cast(arg->C_addr); - tensor::fragment fragA; - tensor::fragment fragB; - tensor::fragment fragC; + tensor::fragment fragA; + tensor::fragment fragB; + tensor::fragment fragC; - // calculate tile row & column based on block index - uint32_t tile_row = blockIdx.y * arg->tileM; - uint32_t tile_col = blockIdx.x * arg->tileN; + // calculate tile row & column based on block index + uint32_t tile_row = blockIdx.y * arg->tileM; + uint32_t tile_col = blockIdx.x * arg->tileN; - uint32_t N = arg->N; - uint32_t K = arg->K; - uint32_t tileK = arg->tileK; + uint32_t N = arg->N; + uint32_t K = arg->K; + uint32_t tileK = arg->tileK; - // Initialize accumulator tile to zero - tensor::fill_fragment(fragC, 0.0f); + // Initialize accumulator tile to zero + tensor::fill_fragment(fragC, 0); - for (int i = 0; i < K; i += tileK) { - // Load A tile - auto tileA = A + (tile_row * K + i); - tensor::load_matrix_sync(fragA, tileA, K); + for (int i = 0; i < K; i += tileK) { + // Load A tile + auto tileA = A + (tile_row * K + i); + tensor::load_matrix_sync(fragA, tileA, K); - // Load B tile - auto tileB = B + (i * k + tile_col); - tensor::load_matrix_sync(fragB, tileB, K); + // Load B tile + auto tileB = B + (i * K + tile_col); + tensor::load_matrix_sync(fragB, tileB, K); - // Matrix multiply-accumulate: c += a * b - tensor::mma_sync(fragC, fragA, fragB, fragC); - } + // Matrix multiply-accumulate: c += a * b + tensor::mma_sync(fragC, fragA, fragB, fragC); + } - // Store the computed C tile - auto tileC = C + (tile_row * N + tile_col); - tensor::store_matrix_sync(tileC, fragC, N, tensor::mem_row_major); + // Store the computed C tile + auto tileC = C + (tile_row * N + tile_col); + tensor::store_matrix_sync(tileC, fragC, N); } int main() { - kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH); - return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg); + kernel_arg_t *arg = (kernel_arg_t *)csr_read(VX_CSR_MSCRATCH); + return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg); } diff --git a/tests/regression/sgemm_tpu/main.cpp b/tests/regression/sgemm_tpu/main.cpp index a9356be62..5af8fe436 100644 --- a/tests/regression/sgemm_tpu/main.cpp +++ b/tests/regression/sgemm_tpu/main.cpp @@ -192,11 +192,15 @@ int main(int argc, char *argv[]) { uint64_t NT; RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_THREADS, &NT)); - std::cout << "GPU warp size: " << NT << std::endl; + if (NT < 4) { + std::cout << "Error: warp size must be at least 
4 threads!" << std::endl; + return -1; + } + std::cout << "GPU warp size: " << NT << " threads" << std::endl; uint64_t isa_flags; RT_CHECK(vx_dev_caps(device, VX_CAPS_ISA_FLAGS, &isa_flags)); - uint32_t XlenB = 4 * VX_ISA_ARCH(isa_flags); + uint32_t XlenB = VX_ISA_ARCH(isa_flags) / 8; std::cout << "GPU XLEN: " << 8 * XlenB << std::endl; // tile format ratio @@ -206,22 +210,22 @@ int main(int argc, char *argv[]) { // determine tensor tile size uint32_t logNT = log2(NT); uint32_t tileM = 4 * (1 << (logNT / 2)) * o_ratio; - uint32_t tileN = (logNT % 2 == 0) ? tileM / 2 : tileN; + uint32_t tileN = (logNT % 2 == 0) ? (tileM / 2) : tileM; uint32_t tileK = std::min(tileM, tileN) * i_ratio; - std::cout << "GPU tensor tileM=" << tileM << ", tileN=" << tileM << ", tileK=" << tileM << std::endl; + std::cout << "GPU tensor tileM=" << tileM << ", tileN=" << tileN << ", tileK=" << tileK << std::endl; - if ((M & (tileM - 1)) != 0) { + if ((M % tileM) != 0) { std::cout << "Error: M must be a multiple of tensor tileM!" << std::endl; return -1; } - if ((N & (tileN - 1)) != 0) { + if ((N % tileN) != 0) { std::cout << "Error: M must be a multiple of tensor tileN!" << std::endl; return -1; } - if ((K & (tileK - 1)) != 0) { + if ((K % tileK) != 0) { std::cout << "Error: M must be a multiple of tensor tileK!" << std::endl; return -1; } @@ -234,8 +238,8 @@ int main(int argc, char *argv[]) { size_t sizeB = K * N; size_t sizeC = M * N; - std::cout << "input data type: " << Comparator::type_str() << std::endl; - std::cout << "output data type: " << Comparator::type_str() << std::endl; + std::cout << "input data type: " << Comparator::type_str() << " (" << sizeof(I_TYPE) << " bytes)" << std::endl; + std::cout << "output data type: " << Comparator::type_str() << " (" << sizeof(O_TYPE) << " bytes)" << std::endl; std::cout << "matrix A: " << M << "x" << K << std::endl; std::cout << "matrix B: " << K << "x" << N << std::endl; std::cout << "matrix C: " << M << "x" << N << std::endl;
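
As a quick host-side check of the vortex::half_t type introduced above in sim/common/hfloats.h, the following standalone snippet round-trips a few values through the 16-bit encoding and exercises the arithmetic operators. This is an illustrative sketch only, not part of the patch; the file name, build command, and tolerance are assumptions.

// half_check.cpp -- illustrative sanity check for vortex::half_t (assumed file name)
// build example (assumption): g++ -std=c++17 -I sim/common half_check.cpp -o half_check
#include <cmath>
#include <cstdio>
#include "hfloats.h"

int main() {
  // round-trip a few representative values through the 16-bit encoding
  const float samples[] = {0.0f, 1.0f, -2.5f, 0.333333f, 65504.0f /* largest normal fp16 value */};
  for (float f : samples) {
    vortex::half_t h(f);                 // float -> half (float_to_half)
    float back = (float)h;               // half -> float (half_to_float)
    std::printf("%g -> 0x%04x -> %g\n", f, (unsigned)h.bits, back);
  }
  // small multiply-accumulate using the half_t operators, compared against float
  vortex::half_t a(1.5f), b(2.25f), c(0.5f);
  float d_half = (float)(a * b + c);
  float d_ref  = 1.5f * 2.25f + 0.5f;
  std::printf("fp16 MAC: %g (ref %g)\n", d_half, d_ref);
  return (std::fabs(d_half - d_ref) < 1e-2f) ? 0 : 1;  // tolerance is an arbitrary assumption
}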