TPU updates

This commit is contained in:
tinebp 2025-02-03 12:46:09 -08:00
parent a920025582
commit fbe8538573
13 changed files with 293 additions and 64 deletions

View file

@ -66,22 +66,22 @@ constexpr unsigned ceil2(T value) {
return (sizeof(T) * 8) - count_leading_zeros<T>(value);
}
inline uint64_t bit_clr(uint64_t bits, uint32_t index) {
constexpr uint64_t bit_clr(uint64_t bits, uint32_t index) {
assert(index <= 63);
return bits & ~(1ull << index);
}
inline uint64_t bit_set(uint64_t bits, uint32_t index) {
constexpr uint64_t bit_set(uint64_t bits, uint32_t index) {
assert(index <= 63);
return bits | (1ull << index);
}
inline bool bit_get(uint64_t bits, uint32_t index) {
constexpr bool bit_get(uint64_t bits, uint32_t index) {
assert(index <= 63);
return (bits >> index) & 0x1;
}
inline uint64_t bit_clrw(uint64_t bits, uint32_t start, uint32_t end) {
constexpr uint64_t bit_clrw(uint64_t bits, uint32_t start, uint32_t end) {
assert(end >= start);
assert(end <= 63);
uint32_t shift = 63 - end;
@ -89,7 +89,7 @@ inline uint64_t bit_clrw(uint64_t bits, uint32_t start, uint32_t end) {
return bits & ~mask;
}
inline uint64_t bit_setw(uint64_t bits, uint32_t start, uint32_t end, uint64_t value) {
constexpr uint64_t bit_setw(uint64_t bits, uint32_t start, uint32_t end, uint64_t value) {
assert(end >= start);
assert(end <= 63);
uint32_t shift = 63 - end;
@ -97,14 +97,14 @@ inline uint64_t bit_setw(uint64_t bits, uint32_t start, uint32_t end, uint64_t v
return bit_clrw(bits, start, end) | dirty;
}
inline uint64_t bit_getw(uint64_t bits, uint32_t start, uint32_t end) {
constexpr uint64_t bit_getw(uint64_t bits, uint32_t start, uint32_t end) {
assert(end >= start);
assert(end <= 63);
uint32_t shift = 63 - end;
return (bits << shift) >> (shift + start);
}
inline uint64_t bit_reverse(uint64_t bits) {
constexpr uint64_t bit_reverse(uint64_t bits) {
bits = ((bits & 0xAAAAAAAAAAAAAAAA) >> 1) | ((bits & 0x5555555555555555) << 1);
bits = ((bits & 0xCCCCCCCCCCCCCCCC) >> 2) | ((bits & 0x3333333333333333) << 2);
bits = ((bits & 0xF0F0F0F0F0F0F0F0) >> 4) | ((bits & 0x0F0F0F0F0F0F0F0F) << 4);
@ -114,7 +114,7 @@ inline uint64_t bit_reverse(uint64_t bits) {
return bits;
}
inline uint64_t bit_reverse(uint64_t bits, uint32_t width) {
constexpr uint64_t bit_reverse(uint64_t bits, uint32_t width) {
assert(width <= 64);
uint64_t reversed(0);
for (uint32_t i = 0; i < width; ++i) {
@ -126,7 +126,7 @@ inline uint64_t bit_reverse(uint64_t bits, uint32_t width) {
}
template <typename T = uint32_t>
T sext(const T& word, uint32_t width) {
constexpr T sext(const T& word, uint32_t width) {
assert(width > 1);
assert(width <= (sizeof(T) * 8));
if (width == (sizeof(T) * 8))
@ -136,7 +136,7 @@ T sext(const T& word, uint32_t width) {
}
template <typename T = uint32_t>
T zext(const T& word, uint32_t width) {
constexpr T zext(const T& word, uint32_t width) {
assert(width > 1);
assert(width <= (sizeof(T) * 8));
if (width == (sizeof(T) * 8))
@ -144,3 +144,8 @@ T zext(const T& word, uint32_t width) {
T mask((static_cast<T>(1) << width) - 1);
return word & mask;
}
// Square root of a power-of-two input, computed as 2^(log2(x)/2) from the
// trailing-zero count. Exact only when x is a power of 4 (even exponent);
// for other powers of two (e.g. 8) the integer division rounds the exponent
// down, yielding 2^floor(log2(x)/2) — presumably callers pass power-of-4
// values (e.g. square thread counts); TODO confirm.
constexpr int pow2_sqrt(int x) {
assert(ispow2(x));
return 1 << (count_trailing_zeros(x) / 2);
}

View file

@ -24,9 +24,13 @@ extern "C" {
#define F32_SIGN 0x80000000
#define F64_SIGN 0x8000000000000000
// Bit-pattern wrappers between plain integer types and the softfloat wrapper
// structs (pure reinterpretation of the stored bits — no value conversion).
inline float16_t to_float16_t(uint16_t x) { return float16_t{x}; }
inline bfloat16_t to_bfloat16_t(uint16_t x) { return bfloat16_t{x}; }
inline float32_t to_float32_t(uint32_t x) { return float32_t{x}; }
inline float64_t to_float64_t(uint64_t x) { return float64_t{x}; }
inline uint16_t from_float16_t(float16_t x) { return uint16_t(x.v); }
inline uint16_t from_bfloat16_t(bfloat16_t x) { return uint16_t(x.v); }
inline uint32_t from_float32_t(float32_t x) { return uint32_t(x.v); }
inline uint64_t from_float64_t(float64_t x) { return uint64_t(x.v); }
@ -530,6 +534,34 @@ uint64_t rv_ftod(uint32_t a) {
return from_float64_t(r);
}
// FP16 -> FP32 conversion (half to single precision).
// 'frm' selects the rounding mode via rv_init(); if 'fflags' is non-null it
// receives the softfloat exception flags raised by the conversion. Statement
// order matters: rv_init() must precede the conversion, and the flags are
// sampled from softfloat global state immediately after it.
uint32_t rv_htof_s(uint16_t a, uint32_t frm, uint32_t* fflags) {
rv_init(frm);
auto r = f16_to_f32(to_float16_t(a));
if (fflags) { *fflags = softfloat_exceptionFlags; }
return from_float32_t(r);
}
// FP32 -> FP16 conversion (single to half precision).
// 'frm' selects the rounding mode via rv_init(); if 'fflags' is non-null it
// receives the softfloat exception flags raised by the conversion (e.g. on
// overflow/underflow when narrowing).
uint16_t rv_ftoh_s(uint32_t a, uint32_t frm, uint32_t* fflags) {
rv_init(frm);
auto r = f32_to_f16(to_float32_t(a));
if (fflags) { *fflags = softfloat_exceptionFlags; }
return from_float16_t(r);
}
// BF16 -> FP32 conversion (bfloat16 to single precision).
// 'frm' selects the rounding mode via rv_init(); if 'fflags' is non-null it
// receives the softfloat exception flags raised by the conversion.
uint32_t rv_btof_s(uint16_t a, uint32_t frm, uint32_t* fflags) {
rv_init(frm);
auto r = bf16_to_f32(to_bfloat16_t(a));
if (fflags) { *fflags = softfloat_exceptionFlags; }
return from_float32_t(r);
}
// FP32 -> BF16 conversion (single precision to bfloat16).
// 'frm' selects the rounding mode via rv_init(); if 'fflags' is non-null it
// receives the softfloat exception flags raised by the narrowing conversion.
uint16_t rv_ftob_s(uint32_t a, uint32_t frm, uint32_t* fflags) {
rv_init(frm);
auto r = f32_to_bf16(to_float32_t(a));
if (fflags) { *fflags = softfloat_exceptionFlags; }
return from_bfloat16_t(r);
}
#ifdef __cplusplus
}
#endif

View file

@ -1,10 +1,10 @@
// Copyright © 2019-2023
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@ -92,6 +92,12 @@ uint32_t rv_dtof(uint64_t a);
uint32_t rv_dtof_r(uint64_t a, uint32_t frm);
uint64_t rv_ftod(uint32_t a);
uint32_t rv_htof_s(uint16_t a, uint32_t frm, uint32_t* fflags);
uint16_t rv_ftoh_s(uint32_t a, uint32_t frm, uint32_t* fflags);
uint32_t rv_btof_s(uint16_t a, uint32_t frm, uint32_t* fflags);
uint16_t rv_ftob_s(uint32_t a, uint32_t frm, uint32_t* fflags);
#ifdef __cplusplus
}
#endif

View file

@ -14,6 +14,7 @@
#pragma once
#include <VX_config.h>
#include <bitmanip.h>
#ifndef RAM_PAGE_SIZE
#define RAM_PAGE_SIZE 4096
@ -40,4 +41,6 @@ inline constexpr int L2_NUM_REQS = NUM_SOCKETS * L1_MEM_PORTS;
inline constexpr int L3_NUM_REQS = NUM_CLUSTERS * L2_MEM_PORTS;
inline constexpr int PER_ISSUE_WARPS = NUM_WARPS / ISSUE_WIDTH;
inline constexpr int PER_ISSUE_WARPS = NUM_WARPS / ISSUE_WIDTH;
inline constexpr int TENSOR_TILE_SIZE = pow2_sqrt(NUM_THREADS);

View file

@ -54,7 +54,7 @@ Core::Core(const SimContext& ctx,
#ifdef EXT_TPU_ENABLE
{
snprintf(sname, 100, "%s-tpu", this->name().c_str());
tensor_unit_ = TensorUnit::Create(sname);
tensor_unit_ = TensorUnit::Create(sname, TENSOR_TILE_SIZE);
}
#endif

View file

@ -131,6 +131,12 @@ public:
return mem_coalescers_.at(idx);
}
#ifdef EXT_TPU_ENABLE
TensorUnit::Ptr& tensor_unit() {
return tensor_unit_;
}
#endif
const PerfStats& perf_stats() const {
return perf_stats_;
}

View file

@ -104,6 +104,9 @@ Emulator::Emulator(const Arch &arch, const DCRS &dcrs, Core* core)
: arch_(arch)
, dcrs_(dcrs)
, core_(core)
#ifdef EXT_TPU_ENABLE
, tensor_unit_(core->tensor_unit())
#endif
, warps_(arch.num_warps(), arch.num_threads())
, barriers_(arch.num_barriers(), 0)
, ipdom_size_(arch.num_threads()-1)

View file

@ -167,6 +167,11 @@ private:
const Arch& arch_;
const DCRS& dcrs_;
Core* core_;
#ifdef EXT_TPU_ENABLE
TensorUnit::Ptr tensor_unit_;
#endif
std::vector<warp_t> warps_;
WarpMask active_warps_;
WarpMask stalled_warps_;

View file

@ -1485,8 +1485,39 @@ void Emulator::execute(const Instr &instr, uint32_t wid, instr_trace_t *trace) {
case 0: // reserved
case 1: // reserved
std::abort();
case 2:
case 2: {
trace->fu_type = FUType::SFU;
trace->sfu_type = SfuType::MMADD;
trace->src_regs[0] = {RegType::Integer, rsrc0};
trace->src_regs[1] = {RegType::Integer, rsrc1};
trace->src_regs[2] = {RegType::Integer, rsrc2};
auto trace_data = std::make_shared<TensorUnit::TraceData>();
trace->data = trace_data;
TensorFormat from, to;
switch (func2) {
case 0: // INT8
from = TensorFormat::Int4;
to = TensorFormat::Int32;
break;
case 1: // INT16
from = TensorFormat::Int8;
to = TensorFormat::Int32;
break;
case 2: // FP16
from = TensorFormat::FP16;
to = TensorFormat::FP32;
break;
case 3: // BF16
from = TensorFormat::BF16;
to = TensorFormat::FP32;
break;
default:
std::abort();
}
tensor_unit_->mmadd(from, to, rs1_data, rs2_data, rs3_data, rd_data, trace_data);
rd_write = true;
} break;
default:
std::abort();
}

View file

@ -265,10 +265,7 @@ void SfuUnit::tick() {
}
} break;
#ifdef EXT_TPU_ENABLE
case SfuType::MMADD_U4:
case SfuType::MMADD_U8:
case SfuType::MMADD_F16:
case SfuType::MMADD_BF16: {
case SfuType::MMADD: {
if (trace->eop) {
auto trace_data = std::dynamic_pointer_cast<TensorUnit::TraceData>(trace->data);
output.push(trace, trace_data->latency + delay);

View file

@ -14,30 +14,100 @@
#include "tensor_unit.h"
#include "mem.h"
#include <VX_config.h>
#include <rvfloats.h>
#include <algorithm>
using namespace vortex;
// Reinterpret between a float and its raw 32-bit pattern (type punning via a
// union; accepted by the compilers this project targets). Renamed from the
// original misspelling 'flaot_uint32_t'.
union float_uint32_t {
  float f;
  uint32_t u;
};

// Extracts the raw bits of the element at 'index' from packed register data.
// Elements are packed little-endian within each 32-bit register:
//   Int4 -> 8 elements/register, Int8 -> 4, FP16/BF16 -> 2.
// Fix: the original shifted by (index % N) *bits* instead of element-width
// multiples and never masked, so every element other than element 0 of a
// register was read incorrectly.
inline uint32_t read_element(const std::vector<reg_data_t>& reg_data, int index, TensorFormat format) {
  switch (format) {
  case TensorFormat::Int4: {
    // 4 bits per element: shift by nibble position and mask off one nibble.
    return (reg_data.at(index / 8).u >> ((index % 8) * 4)) & 0xf;
  }
  case TensorFormat::Int8: {
    return (reg_data.at(index / 4).u >> ((index % 4) * 8)) & 0xff;
  }
  case TensorFormat::FP16:
  case TensorFormat::BF16: {
    return (reg_data.at(index / 2).u >> ((index % 2) * 16)) & 0xffff;
  }
  default:
    assert(false);
    return 0;
  }
}

// Stores a 32-bit result element into the destination register data
// (Int32/FP32 occupy one full register each, so no packing is needed).
inline void write_element(std::vector<reg_data_t>& reg_data, int index, uint32_t value, TensorFormat format) {
  switch (format) {
  case TensorFormat::Int32:
  case TensorFormat::FP32: {
    reg_data.at(index).i = value;
    break;
  }
  default:
    assert(false);
  }
}

// Converts the raw bits of a source element to float for accumulation.
inline float type_to_float(uint32_t value, TensorFormat format) {
  float_uint32_t u2f;
  switch (format) {
  case TensorFormat::Int4:
  case TensorFormat::Int8: {
    // NOTE(review): the integer is converted as an unsigned magnitude; if
    // Int4/Int8 operands are meant to be signed, they need sign-extension
    // before the int->float conversion — confirm against the ISA spec.
    u2f.u = rv_itof_s(value, 0, nullptr);
    return u2f.f;
  }
  case TensorFormat::FP16: {
    u2f.u = rv_htof_s(value, 0, nullptr);
    return u2f.f;
  }
  case TensorFormat::BF16: {
    u2f.u = rv_btof_s(value, 0, nullptr);
    return u2f.f;
  }
  default:
    assert(false);
    return 0;
  }
}

// Converts an accumulated float back to the destination element format.
inline uint32_t float_to_type(float value, TensorFormat format) {
  float_uint32_t f2u;
  f2u.f = value;
  switch (format) {
  case TensorFormat::Int32:
    return rv_ftoi_s(f2u.u, 0, nullptr);
  case TensorFormat::FP32:
    return f2u.u;
  default:
    assert(false);
    return 0;
  }
}
class TensorCore : public SimObject<TensorCore> {
public:
struct PerfStats {
uint64_t reads;
uint64_t writes;
uint64_t latency;
uint64_t stalls;
PerfStats()
: reads(0)
, writes(0)
, latency(0)
, stalls(0)
: latency(0)
{}
PerfStats& operator+=(const PerfStats& rhs) {
this->reads += rhs.reads;
this->writes += rhs.writes;
this->latency += rhs.latency;
this->stalls += rhs.stalls;
return *this;
}
};
@ -45,23 +115,56 @@ public:
SimPort<instr_trace_t*> Input;
SimPort<instr_trace_t*> Output;
TensorCore(const SimContext& ctx, const char* name);
~TensorCore();
TensorCore(const SimContext& ctx, const char* name, uint32_t tile_size)
: SimObject<TensorCore>(ctx, name)
, Input(this)
, Output(this)
, tile_size_(tile_size)
{}
void reset();
~TensorCore() {
this->reset();
}
void tick();
void reset() {
//--
}
void attach_ram(RAM* mem);
void tick() {
//--
}
void mmadd(TensorUnit::TraceData::Ptr trace_data);
void mmadd(TensorFormat from_format,
TensorFormat to_format,
const std::vector<reg_data_t>& rs1_data,
const std::vector<reg_data_t>& rs2_data,
const std::vector<reg_data_t>& rs3_data,
std::vector<reg_data_t>& rd_data,
TensorUnit::TraceData::Ptr trace_data) {
assert(rd_data.size() <= tile_size_);
trace_data->latency = 2 + tile_size_;
// matrix multiplication and accumulation
for (uint32_t i = 0; i < tile_size_; i++) {
for (uint32_t j = 0; j < tile_size_; j++) {
float sum = type_to_float(read_element(rs3_data, i * tile_size_ + j, to_format), to_format);
for (uint32_t k = 0; k < tile_size_; k++) {
auto a = type_to_float(read_element(rs1_data, i * tile_size_ + k, from_format), from_format);
auto b = type_to_float(read_element(rs2_data, k * tile_size_ + j, from_format), from_format);
sum += a * b;
}
write_element(rd_data, i * tile_size_ + j, float_to_type(sum, to_format), to_format);
}
}
}
const PerfStats& perf_stats() const;
const PerfStats& perf_stats() const {
return perf_stats_;
}
private:
class Impl;
Impl* impl_;
PerfStats perf_stats_;
uint32_t tile_size_;
};
///////////////////////////////////////////////////////////////////////////////
@ -69,9 +172,17 @@ private:
class TensorUnit::Impl {
public:
Impl(TensorUnit* simobject)
Impl(const SimContext& ctx, TensorUnit* simobject, uint32_t tile_size)
: simobject_(simobject)
, tensor_cores_(NUM_TENSOR_CORES)
, tc_sel_(0)
{
char sname[100];
for (uint32_t i = 0; i < NUM_TENSOR_CORES; i++) {
snprintf(sname, 100, "%s-core%d", simobject->name().c_str(), i);
tensor_cores_[i] = TensorCore::Create(ctx, sname, tile_size);
}
this->reset();
}
@ -82,15 +193,26 @@ public:
}
void tick() {
//--
// forward input to tensor cores
auto& input = simobject_->Input;
if (input.empty())
return;
auto trace = input.front();
auto trace_data = std::dynamic_pointer_cast<TraceData>(trace->data);
tensor_cores_.at(trace_data->tc_idx)->Input.push(trace, 1);
input.pop();
}
void mmadd(const std::vector<reg_data_t>& rs1_data,
void mmadd(TensorFormat from_format,
TensorFormat to_format,
const std::vector<reg_data_t>& rs1_data,
const std::vector<reg_data_t>& rs2_data,
const std::vector<reg_data_t>& rs3_data,
std::vector<reg_data_t>& rd_data,
TensorUnit::TraceData::Ptr& trace_data) {
//--
TensorUnit::TraceData::Ptr trace_data) {
tensor_cores_.at(tc_sel_)->mmadd(from_format, to_format, rs1_data, rs2_data, rs3_data, rd_data, trace_data);
trace_data->tc_idx = tc_sel_;
tc_sel_ = (tc_sel_ + 1) % NUM_TENSOR_CORES;
}
const PerfStats& perf_stats() const {
@ -100,16 +222,18 @@ public:
private:
TensorUnit* simobject_;
std::vector<TensorCore::Ptr> tensor_cores_;
uint32_t tc_sel_;
PerfStats perf_stats_;
};
///////////////////////////////////////////////////////////////////////////////
TensorUnit::TensorUnit(const SimContext& ctx, const char* name)
TensorUnit::TensorUnit(const SimContext& ctx, const char* name, uint32_t tile_size)
: SimObject<TensorUnit>(ctx, name)
, Input(this)
, Output(this)
, impl_(new Impl(this))
, impl_(new Impl(ctx, this, tile_size))
{}
TensorUnit::~TensorUnit() {
@ -124,12 +248,14 @@ void TensorUnit::tick() {
impl_->tick();
}
void TensorUnit::mmadd(const std::vector<reg_data_t>& rs1_data,
const std::vector<reg_data_t>& rs2_data,
const std::vector<reg_data_t>& rs3_data,
std::vector<reg_data_t>& rd_data,
TensorUnit::TraceData::Ptr& trace_data) {
impl_->mmadd(rs1_data, rs2_data, rs3_data, rd_data, trace_data);
void TensorUnit::mmadd(TensorFormat from_format,
TensorFormat to_format,
const std::vector<reg_data_t>& rs1_data,
const std::vector<reg_data_t>& rs2_data,
const std::vector<reg_data_t>& rs3_data,
std::vector<reg_data_t>& rd_data,
TensorUnit::TraceData::Ptr trace_data) {
impl_->mmadd(from_format, to_format, rs1_data, rs2_data, rs3_data, rd_data, trace_data);
}
const TensorUnit::PerfStats& TensorUnit::perf_stats() const {

View file

@ -18,38 +18,57 @@
namespace vortex {
// Element data formats understood by the tensor unit (source operand and
// accumulator/destination encodings).
// NOTE(review): the explicit values suggest they are encoding-significant
// (instruction/DCR fields) — confirm before reordering or renumbering.
enum class TensorFormat : int {
Int4 = 0,
Int8 = 1,
Int16 = 2,
Int32 = 3,
Int64 = 4,
FP16 = 5, // IEEE-754 binary16
FP32 = 6,
FP64 = 7,
BF16 = 8, // bfloat16 (8-bit exponent, 7-bit mantissa)
_MAX = 9 // format count / sentinel, not a real format
};
class TensorUnit : public SimObject<TensorUnit> {
public:
struct TraceData : public ITraceData {
using Ptr = std::shared_ptr<TraceData>;
uint32_t tc_idx;
uint32_t latency;
};
struct PerfStats {
uint64_t latency;
uint64_t stalls;
PerfStats()
: latency(0)
, stalls(0)
{}
PerfStats& operator+=(const PerfStats& rhs) {
this->latency += rhs.latency;
return *this;
}
};
SimPort<instr_trace_t*> Input;
SimPort<instr_trace_t*> Output;
TensorUnit(const SimContext& ctx, const char* name);
TensorUnit(const SimContext& ctx, const char* name, uint32_t tile_size);
~TensorUnit();
void reset();
void tick();
void mmadd(const std::vector<reg_data_t>& rs1_data,
void mmadd(TensorFormat from_format,
TensorFormat to_format,
const std::vector<reg_data_t>& rs1_data,
const std::vector<reg_data_t>& rs2_data,
const std::vector<reg_data_t>& rs3_data,
std::vector<reg_data_t>& rd_data,
TensorUnit::TraceData::Ptr& trace_data);
TensorUnit::TraceData::Ptr trace_data);
const PerfStats& perf_stats() const;

View file

@ -245,10 +245,9 @@ enum class SfuType {
CSRRW,
CSRRS,
CSRRC,
MMADD_U4,
MMADD_U8,
MMADD_F16,
MMADD_BF16
#ifdef EXT_TPU_ENABLE
MMADD,
#endif
};
inline std::ostream &operator<<(std::ostream &os, const SfuType& type) {
@ -263,10 +262,7 @@ inline std::ostream &operator<<(std::ostream &os, const SfuType& type) {
case SfuType::CSRRS: os << "CSRRS"; break;
case SfuType::CSRRC: os << "CSRRC"; break;
#ifdef EXT_TPU_ENABLE
case SfuType::MMADD_U4: os << "MMADD_U4"; break;
case SfuType::MMADD_U8: os << "MMADD_U8"; break;
case SfuType::MMADD_F16: os << "MMADD_F16"; break;
case SfuType::MMADD_BF16: os << "MMADD_BF16"; break;
case SfuType::MMADD: os << "MMADD"; break;
#endif
default: assert(false);
}