TPU updates

Mirror of https://github.com/vortexgpgpu/vortex.git (synced 2025-04-24 05:47:35 -04:00)

commit fbe8538573
parent a920025582

13 changed files with 293 additions and 64 deletions
@@ -66,22 +66,22 @@ constexpr unsigned ceil2(T value) {
   return (sizeof(T) * 8) - count_leading_zeros<T>(value);
 }
 
-inline uint64_t bit_clr(uint64_t bits, uint32_t index) {
+constexpr uint64_t bit_clr(uint64_t bits, uint32_t index) {
   assert(index <= 63);
   return bits & ~(1ull << index);
 }
 
-inline uint64_t bit_set(uint64_t bits, uint32_t index) {
+constexpr uint64_t bit_set(uint64_t bits, uint32_t index) {
   assert(index <= 63);
   return bits | (1ull << index);
 }
 
-inline bool bit_get(uint64_t bits, uint32_t index) {
+constexpr bool bit_get(uint64_t bits, uint32_t index) {
   assert(index <= 63);
   return (bits >> index) & 0x1;
 }
 
-inline uint64_t bit_clrw(uint64_t bits, uint32_t start, uint32_t end) {
+constexpr uint64_t bit_clrw(uint64_t bits, uint32_t start, uint32_t end) {
   assert(end >= start);
   assert(end <= 63);
   uint32_t shift = 63 - end;

@@ -89,7 +89,7 @@ inline uint64_t bit_clrw(uint64_t bits, uint32_t start, uint32_t end) {
   return bits & ~mask;
 }
 
-inline uint64_t bit_setw(uint64_t bits, uint32_t start, uint32_t end, uint64_t value) {
+constexpr uint64_t bit_setw(uint64_t bits, uint32_t start, uint32_t end, uint64_t value) {
   assert(end >= start);
   assert(end <= 63);
   uint32_t shift = 63 - end;

@@ -97,14 +97,14 @@ inline uint64_t bit_setw(uint64_t bits, uint32_t start, uint32_t end, uint64_t v
   return bit_clrw(bits, start, end) | dirty;
 }
 
-inline uint64_t bit_getw(uint64_t bits, uint32_t start, uint32_t end) {
+constexpr uint64_t bit_getw(uint64_t bits, uint32_t start, uint32_t end) {
   assert(end >= start);
   assert(end <= 63);
   uint32_t shift = 63 - end;
   return (bits << shift) >> (shift + start);
 }
 
-inline uint64_t bit_reverse(uint64_t bits) {
+constexpr uint64_t bit_reverse(uint64_t bits) {
   bits = ((bits & 0xAAAAAAAAAAAAAAAA) >> 1) | ((bits & 0x5555555555555555) << 1);
   bits = ((bits & 0xCCCCCCCCCCCCCCCC) >> 2) | ((bits & 0x3333333333333333) << 2);
   bits = ((bits & 0xF0F0F0F0F0F0F0F0) >> 4) | ((bits & 0x0F0F0F0F0F0F0F0F) << 4);

@@ -114,7 +114,7 @@ inline uint64_t bit_reverse(uint64_t bits) {
   return bits;
 }
 
-inline uint64_t bit_reverse(uint64_t bits, uint32_t width) {
+constexpr uint64_t bit_reverse(uint64_t bits, uint32_t width) {
   assert(width <= 64);
   uint64_t reversed(0);
   for (uint32_t i = 0; i < width; ++i) {

@@ -126,7 +126,7 @@ inline uint64_t bit_reverse(uint64_t bits, uint32_t width) {
 }
 
 template <typename T = uint32_t>
-T sext(const T& word, uint32_t width) {
+constexpr T sext(const T& word, uint32_t width) {
   assert(width > 1);
   assert(width <= (sizeof(T) * 8));
   if (width == (sizeof(T) * 8))

@@ -136,7 +136,7 @@ T sext(const T& word, uint32_t width) {
 }
 
 template <typename T = uint32_t>
-T zext(const T& word, uint32_t width) {
+constexpr T zext(const T& word, uint32_t width) {
   assert(width > 1);
   assert(width <= (sizeof(T) * 8));
   if (width == (sizeof(T) * 8))

@@ -144,3 +144,8 @@ T zext(const T& word, uint32_t width) {
   T mask((static_cast<T>(1) << width) - 1);
   return word & mask;
 }
+
+constexpr int pow2_sqrt(int x) {
+  assert(ispow2(x));
+  return 1 << (count_trailing_zeros(x) / 2);
+}
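Note: the switch from inline to constexpr above lets these bit helpers be evaluated at compile time, which is what later allows a constant such as TENSOR_TILE_SIZE to be derived from NUM_THREADS through pow2_sqrt. Below is a minimal, self-contained sketch of that idea (C++14 or later). The bit_getw and pow2_sqrt bodies are copied from this diff; the ispow2 and count_trailing_zeros stand-ins and the static_assert values are illustrative assumptions, not part of the commit.

#include <cassert>
#include <cstdint>

// Copied from the diff: extract bit field [start, end] (inclusive) of a 64-bit word.
constexpr uint64_t bit_getw(uint64_t bits, uint32_t start, uint32_t end) {
  assert(end >= start);
  assert(end <= 63);
  uint32_t shift = 63 - end;
  return (bits << shift) >> (shift + start);
}

// Stand-ins for helpers defined elsewhere in the header (assumed semantics).
constexpr bool ispow2(int x) { return x && ((x & (x - 1)) == 0); }
constexpr int count_trailing_zeros(int x) { return __builtin_ctz(x); } // GCC/Clang builtin

// Copied from the diff: integer square root of a power-of-two value.
constexpr int pow2_sqrt(int x) {
  assert(ispow2(x));
  return 1 << (count_trailing_zeros(x) / 2);
}

int main() {
  // Because the helpers are constexpr, the results can feed compile-time checks and constants:
  static_assert(bit_getw(0xDEADBEEFull, 8, 15) == 0xBE, "byte extraction");
  static_assert(pow2_sqrt(16) == 4, "a 16-element power-of-two maps to a 4x4 square tile");
  return 0;
}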
@@ -24,9 +24,13 @@ extern "C" {
 #define F32_SIGN 0x80000000
 #define F64_SIGN 0x8000000000000000
 
+inline float16_t to_float16_t(uint16_t x) { return float16_t{x}; }
+inline bfloat16_t to_bfloat16_t(uint16_t x) { return bfloat16_t{x}; }
 inline float32_t to_float32_t(uint32_t x) { return float32_t{x}; }
 inline float64_t to_float64_t(uint64_t x) { return float64_t{x}; }
 
+inline uint16_t from_float16_t(float16_t x) { return uint16_t(x.v); }
+inline uint16_t from_bfloat16_t(bfloat16_t x) { return uint16_t(x.v); }
 inline uint32_t from_float32_t(float32_t x) { return uint32_t(x.v); }
 inline uint64_t from_float64_t(float64_t x) { return uint64_t(x.v); }
 

@@ -530,6 +534,34 @@ uint64_t rv_ftod(uint32_t a) {
   return from_float64_t(r);
 }
 
+uint32_t rv_htof_s(uint16_t a, uint32_t frm, uint32_t* fflags) {
+  rv_init(frm);
+  auto r = f16_to_f32(to_float16_t(a));
+  if (fflags) { *fflags = softfloat_exceptionFlags; }
+  return from_float32_t(r);
+}
+
+uint16_t rv_ftoh_s(uint32_t a, uint32_t frm, uint32_t* fflags) {
+  rv_init(frm);
+  auto r = f32_to_f16(to_float32_t(a));
+  if (fflags) { *fflags = softfloat_exceptionFlags; }
+  return from_float16_t(r);
+}
+
+uint32_t rv_btof_s(uint16_t a, uint32_t frm, uint32_t* fflags) {
+  rv_init(frm);
+  auto r = bf16_to_f32(to_bfloat16_t(a));
+  if (fflags) { *fflags = softfloat_exceptionFlags; }
+  return from_float32_t(r);
+}
+
+uint16_t rv_ftob_s(uint32_t a, uint32_t frm, uint32_t* fflags) {
+  rv_init(frm);
+  auto r = f32_to_bf16(to_float32_t(a));
+  if (fflags) { *fflags = softfloat_exceptionFlags; }
+  return from_bfloat16_t(r);
+}
+
 #ifdef __cplusplus
 }
 #endif

@@ -1,10 +1,10 @@
 // Copyright © 2019-2023
-//
+//
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 // http://www.apache.org/licenses/LICENSE-2.0
-//
+//
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@@ -92,6 +92,12 @@ uint32_t rv_dtof(uint64_t a);
 uint32_t rv_dtof_r(uint64_t a, uint32_t frm);
 uint64_t rv_ftod(uint32_t a);
 
+uint32_t rv_htof_s(uint16_t a, uint32_t frm, uint32_t* fflags);
+uint16_t rv_ftoh_s(uint32_t a, uint32_t frm, uint32_t* fflags);
+
+uint32_t rv_btof_s(uint16_t a, uint32_t frm, uint32_t* fflags);
+uint16_t rv_ftob_s(uint32_t a, uint32_t frm, uint32_t* fflags);
+
 #ifdef __cplusplus
 }
 #endif
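For readers unfamiliar with the new conversions: rv_htof_s/rv_ftoh_s widen and narrow IEEE half precision, and rv_btof_s/rv_ftob_s do the same for bfloat16, going through SoftFloat so that the rounding mode (frm) and exception flags (fflags) are honored. The sketch below shows only the underlying bit-level relationship for the bfloat16 case (bfloat16 is the top 16 bits of a binary32). It is an illustration, not the SoftFloat path the commit actually uses, and it ignores NaN payloads and non-default rounding modes.

#include <cstdint>
#include <cstdio>
#include <cstring>

// bfloat16 is the upper 16 bits of an IEEE-754 binary32 value.
inline float bf16_bits_to_f32(uint16_t b) {
  uint32_t u = uint32_t(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Narrowing with round-to-nearest-even applied to the 16 discarded bits.
inline uint16_t f32_to_bf16_bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  uint32_t lsb = (u >> 16) & 1; // tie-break toward even
  u += 0x7FFFu + lsb;
  return uint16_t(u >> 16);
}

int main() {
  uint16_t b = f32_to_bf16_bits(3.14159f);
  std::printf("bf16(pi) bits = 0x%04x, back to f32 = %f\n", b, bf16_bits_to_f32(b));
}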
@@ -14,6 +14,7 @@
 #pragma once
 
 #include <VX_config.h>
+#include <bitmanip.h>
 
 #ifndef RAM_PAGE_SIZE
 #define RAM_PAGE_SIZE 4096

@@ -40,4 +41,6 @@ inline constexpr int L2_NUM_REQS = NUM_SOCKETS * L1_MEM_PORTS;
 
 inline constexpr int L3_NUM_REQS = NUM_CLUSTERS * L2_MEM_PORTS;
 
-inline constexpr int PER_ISSUE_WARPS = NUM_WARPS / ISSUE_WIDTH;
+inline constexpr int PER_ISSUE_WARPS = NUM_WARPS / ISSUE_WIDTH;
+
+inline constexpr int TENSOR_TILE_SIZE = pow2_sqrt(NUM_THREADS);
@@ -54,7 +54,7 @@ Core::Core(const SimContext& ctx,
 #ifdef EXT_TPU_ENABLE
   {
     snprintf(sname, 100, "%s-tpu", this->name().c_str());
-    tensor_unit_ = TensorUnit::Create(sname);
+    tensor_unit_ = TensorUnit::Create(sname, TENSOR_TILE_SIZE);
   }
 #endif
 

@@ -131,6 +131,12 @@ public:
     return mem_coalescers_.at(idx);
   }
 
+#ifdef EXT_TPU_ENABLE
+  TensorUnit::Ptr& tensor_unit() {
+    return tensor_unit_;
+  }
+#endif
+
   const PerfStats& perf_stats() const {
     return perf_stats_;
   }

@@ -104,6 +104,9 @@ Emulator::Emulator(const Arch &arch, const DCRS &dcrs, Core* core)
   : arch_(arch)
   , dcrs_(dcrs)
   , core_(core)
+#ifdef EXT_TPU_ENABLE
+  , tensor_unit_(core->tensor_unit())
+#endif
   , warps_(arch.num_warps(), arch.num_threads())
   , barriers_(arch.num_barriers(), 0)
   , ipdom_size_(arch.num_threads()-1)

@@ -167,6 +167,11 @@ private:
   const Arch& arch_;
   const DCRS& dcrs_;
   Core* core_;
+
+#ifdef EXT_TPU_ENABLE
+  TensorUnit::Ptr tensor_unit_;
+#endif
+
   std::vector<warp_t> warps_;
   WarpMask active_warps_;
   WarpMask stalled_warps_;
@@ -1485,8 +1485,39 @@ void Emulator::execute(const Instr &instr, uint32_t wid, instr_trace_t *trace) {
     case 0: // reserved
     case 1: // reserved
      std::abort();
-    case 2:
+    case 2: {
+      trace->fu_type = FUType::SFU;
+      trace->sfu_type = SfuType::MMADD;
+      trace->src_regs[0] = {RegType::Integer, rsrc0};
+      trace->src_regs[1] = {RegType::Integer, rsrc1};
+      trace->src_regs[2] = {RegType::Integer, rsrc2};
+      auto trace_data = std::make_shared<TensorUnit::TraceData>();
+      trace->data = trace_data;
+
+      TensorFormat from, to;
+      switch (func2) {
+      case 0: // INT8
+        from = TensorFormat::Int4;
+        to = TensorFormat::Int32;
+        break;
+      case 1: // INT16
+        from = TensorFormat::Int8;
+        to = TensorFormat::Int32;
+        break;
+      case 2: // FP16
+        from = TensorFormat::FP16;
+        to = TensorFormat::FP32;
+        break;
+      case 3: // BF16
+        from = TensorFormat::BF16;
+        to = TensorFormat::FP32;
+        break;
+      default:
+        std::abort();
+      }
+      tensor_unit_->mmadd(from, to, rs1_data, rs2_data, rs3_data, rd_data, trace_data);
+      rd_write = true;
+    } break;
     default:
       std::abort();
     }
@@ -265,10 +265,7 @@ void SfuUnit::tick() {
        }
      } break;
 #ifdef EXT_TPU_ENABLE
-    case SfuType::MMADD_U4:
-    case SfuType::MMADD_U8:
-    case SfuType::MMADD_F16:
-    case SfuType::MMADD_BF16: {
+    case SfuType::MMADD: {
       if (trace->eop) {
         auto trace_data = std::dynamic_pointer_cast<TensorUnit::TraceData>(trace->data);
         output.push(trace, trace_data->latency + delay);
@@ -14,30 +14,100 @@
 #include "tensor_unit.h"
 #include "mem.h"
 #include <VX_config.h>
+#include <rvfloats.h>
 #include <algorithm>
 
 using namespace vortex;
 
+union flaot_uint32_t {
+  float f;
+  uint32_t u;
+};
+
+inline uint32_t read_element(const std::vector<reg_data_t>& reg_data, int index, TensorFormat format) {
+  switch (format) {
+  case TensorFormat::Int4: {
+    return reg_data.at(index / 8).u >> (index % 8);
+  }
+  case TensorFormat::Int8: {
+    return reg_data.at(index / 4).u >> (index % 4);
+  }
+  case TensorFormat::FP16: {
+    return reg_data.at(index / 2).u >> (index % 2);
+  }
+  case TensorFormat::BF16: {
+    return reg_data.at(index / 2).u >> (index % 2);
+  }
+  default: assert(false);
+  }
+}
+
+inline void write_element(std::vector<reg_data_t>& reg_data, int index, uint32_t value, TensorFormat format) {
+  switch (format) {
+  case TensorFormat::Int32:
+  case TensorFormat::FP32: {
+    reg_data.at(index).i = value;
+    break;
+  }
+  default: assert(false);
+  }
+}
+
+inline float type_to_float(uint32_t value, TensorFormat format) {
+  switch (format) {
+  case TensorFormat::Int4: {
+    flaot_uint32_t u2f;
+    u2f.u = rv_itof_s(value, 0, nullptr);
+    return u2f.f;
+  }
+  case TensorFormat::Int8: {
+    flaot_uint32_t u2f;
+    u2f.u = rv_itof_s(value, 0, nullptr);
+    return u2f.f;
+  }
+  case TensorFormat::FP16: {
+    flaot_uint32_t u2f;
+    u2f.u = rv_htof_s(value, 0, nullptr);
+    return u2f.f;
+  }
+  case TensorFormat::BF16: {
+    flaot_uint32_t u2f;
+    u2f.u = rv_btof_s(value, 0, nullptr);
+    return u2f.f;
+  }
+  default: assert(false);
+  }
+  return 0;
+}
+
+inline uint32_t float_to_type(float value, TensorFormat format) {
+  switch (format) {
+  case TensorFormat::Int32: {
+    flaot_uint32_t f2u;
+    f2u.f = value;
+    return rv_ftoi_s(f2u.u, 0, nullptr);
+  }
+  case TensorFormat::FP32: {
+    flaot_uint32_t f2u;
+    f2u.f = value;
+    return f2u.u;
+  }
+  default: assert(false);
+  }
+  return 0;
+}
+
 class TensorCore : public SimObject<TensorCore> {
 public:
   struct PerfStats {
     uint64_t reads;
     uint64_t writes;
     uint64_t latency;
     uint64_t stalls;
 
     PerfStats()
-      : reads(0)
-      , writes(0)
-      , latency(0)
-      , stalls(0)
+      : latency(0)
     {}
 
     PerfStats& operator+=(const PerfStats& rhs) {
       this->reads += rhs.reads;
       this->writes += rhs.writes;
       this->latency += rhs.latency;
       this->stalls += rhs.stalls;
       return *this;
     }
   };

@@ -45,23 +115,56 @@ public:
   SimPort<instr_trace_t*> Input;
   SimPort<instr_trace_t*> Output;
 
-  TensorCore(const SimContext& ctx, const char* name);
-  ~TensorCore();
+  TensorCore(const SimContext& ctx, const char* name, uint32_t tile_size)
+    : SimObject<TensorCore>(ctx, name)
+    , Input(this)
+    , Output(this)
+    , tile_size_(tile_size)
+  {}
 
-  void reset();
+  ~TensorCore() {
+    this->reset();
+  }
 
-  void tick();
+  void reset() {
+    //--
+  }
 
-  void attach_ram(RAM* mem);
+  void tick() {
+    //--
+  }
 
-  void mmadd(TensorUnit::TraceData::Ptr trace_data);
+  void mmadd(TensorFormat from_format,
+             TensorFormat to_format,
+             const std::vector<reg_data_t>& rs1_data,
+             const std::vector<reg_data_t>& rs2_data,
+             const std::vector<reg_data_t>& rs3_data,
+             std::vector<reg_data_t>& rd_data,
+             TensorUnit::TraceData::Ptr trace_data) {
+    assert(rd_data.size() <= tile_size_);
+    trace_data->latency = 2 + tile_size_;
+    // matrix multiplication and accumulation
+    for (uint32_t i = 0; i < tile_size_; i++) {
+      for (uint32_t j = 0; j < tile_size_; j++) {
+        float sum = type_to_float(read_element(rs3_data, i * tile_size_ + j, to_format), to_format);
+        for (uint32_t k = 0; k < tile_size_; k++) {
+          auto a = type_to_float(read_element(rs1_data, i * tile_size_ + k, from_format), from_format);
+          auto b = type_to_float(read_element(rs2_data, k * tile_size_ + j, from_format), from_format);
+          sum += a * b;
+        }
+        write_element(rd_data, i * tile_size_ + j, float_to_type(sum, to_format), to_format);
+      }
+    }
+  }
 
-  const PerfStats& perf_stats() const;
+  const PerfStats& perf_stats() const {
+    return perf_stats_;
+  }
 
 private:
 
-  class Impl;
-  Impl* impl_;
+  PerfStats perf_stats_;
+  uint32_t tile_size_;
 };
 
 ///////////////////////////////////////////////////////////////////////////////

@@ -69,9 +172,17 @@ private:
 class TensorUnit::Impl {
 public:
 
-  Impl(TensorUnit* simobject)
+  Impl(const SimContext& ctx, TensorUnit* simobject, uint32_t tile_size)
     : simobject_(simobject)
+    , tensor_cores_(NUM_TENSOR_CORES)
+    , tc_sel_(0)
   {
+    char sname[100];
+    for (uint32_t i = 0; i < NUM_TENSOR_CORES; i++) {
+      snprintf(sname, 100, "%s-core%d", simobject->name().c_str(), i);
+      tensor_cores_[i] = TensorCore::Create(ctx, sname, tile_size);
+    }
+
     this->reset();
   }

@@ -82,15 +193,26 @@ public:
   }
 
   void tick() {
-    //--
+    // forward input to tensor cores
+    auto& input = simobject_->Input;
+    if (input.empty())
+      return;
+    auto trace = input.front();
+    auto trace_data = std::dynamic_pointer_cast<TraceData>(trace->data);
+    tensor_cores_.at(trace_data->tc_idx)->Input.push(trace, 1);
+    input.pop();
   }
 
-  void mmadd(const std::vector<reg_data_t>& rs1_data,
+  void mmadd(TensorFormat from_format,
+             TensorFormat to_format,
+             const std::vector<reg_data_t>& rs1_data,
              const std::vector<reg_data_t>& rs2_data,
              const std::vector<reg_data_t>& rs3_data,
              std::vector<reg_data_t>& rd_data,
-             TensorUnit::TraceData::Ptr& trace_data) {
-    //--
+             TensorUnit::TraceData::Ptr trace_data) {
+    tensor_cores_.at(tc_sel_)->mmadd(from_format, to_format, rs1_data, rs2_data, rs3_data, rd_data, trace_data);
+    trace_data->tc_idx = tc_sel_;
+    tc_sel_ = (tc_sel_ + 1) % NUM_TENSOR_CORES;
   }
 
   const PerfStats& perf_stats() const {

@@ -100,16 +222,18 @@ public:
 private:
 
   TensorUnit* simobject_;
+  std::vector<TensorCore::Ptr> tensor_cores_;
+  uint32_t tc_sel_;
   PerfStats perf_stats_;
 };
 
 ///////////////////////////////////////////////////////////////////////////////
 
-TensorUnit::TensorUnit(const SimContext& ctx, const char* name)
+TensorUnit::TensorUnit(const SimContext& ctx, const char* name, uint32_t tile_size)
   : SimObject<TensorUnit>(ctx, name)
   , Input(this)
   , Output(this)
-  , impl_(new Impl(this))
+  , impl_(new Impl(ctx, this, tile_size))
 {}
 
 TensorUnit::~TensorUnit() {

@@ -124,12 +248,14 @@ void TensorUnit::tick() {
   impl_->tick();
 }
 
-void TensorUnit::mmadd(const std::vector<reg_data_t>& rs1_data,
-                       const std::vector<reg_data_t>& rs2_data,
-                       const std::vector<reg_data_t>& rs3_data,
-                       std::vector<reg_data_t>& rd_data,
-                       TensorUnit::TraceData::Ptr& trace_data) {
-  impl_->mmadd(rs1_data, rs2_data, rs3_data, rd_data, trace_data);
+void TensorUnit::mmadd(TensorFormat from_format,
+                       TensorFormat to_format,
+                       const std::vector<reg_data_t>& rs1_data,
+                       const std::vector<reg_data_t>& rs2_data,
+                       const std::vector<reg_data_t>& rs3_data,
+                       std::vector<reg_data_t>& rd_data,
+                       TensorUnit::TraceData::Ptr trace_data) {
+  impl_->mmadd(from_format, to_format, rs1_data, rs2_data, rs3_data, rd_data, trace_data);
 }
 
 const TensorUnit::PerfStats& TensorUnit::perf_stats() const {
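Note: the TensorCore::mmadd loop above computes, per tile, D = A x B + C, reading the A/B operands in the source format and accumulating in float. The standalone sketch below reproduces just that arithmetic on plain float arrays so the indexing is easy to follow; the 4x4 tile size and the test values are illustrative assumptions (in the commit the tile size comes from pow2_sqrt(NUM_THREADS)), and the format packing and conversion helpers from the diff (read_element, type_to_float, ...) are deliberately left out.

#include <array>
#include <cstdio>

constexpr int TILE = 4; // assumed tile size for this sketch

// D = A * B + C over TILE x TILE tiles, mirroring the i/j/k loop nest in TensorCore::mmadd.
void tile_mmadd(const std::array<float, TILE * TILE>& A,
                const std::array<float, TILE * TILE>& B,
                const std::array<float, TILE * TILE>& C,
                std::array<float, TILE * TILE>& D) {
  for (int i = 0; i < TILE; ++i) {
    for (int j = 0; j < TILE; ++j) {
      float sum = C[i * TILE + j];            // accumulator seeded from the C operand (rs3)
      for (int k = 0; k < TILE; ++k) {
        sum += A[i * TILE + k] * B[k * TILE + j];
      }
      D[i * TILE + j] = sum;                  // written back to the destination tile (rd)
    }
  }
}

int main() {
  std::array<float, TILE * TILE> A{}, B{}, C{}, D{};
  for (int i = 0; i < TILE; ++i) {
    A[i * TILE + i] = 1.0f;                   // A = identity
    for (int j = 0; j < TILE; ++j) {
      B[i * TILE + j] = float(i * TILE + j);
      C[i * TILE + j] = 0.5f;
    }
  }
  tile_mmadd(A, B, C, D);                     // expect D = B + 0.5
  std::printf("D[1][2] = %f\n", D[1 * TILE + 2]); // prints 6.5
}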
@@ -18,38 +18,57 @@
 
 namespace vortex {
 
+enum class TensorFormat : int {
+  Int4 = 0,
+  Int8 = 1,
+  Int16 = 2,
+  Int32 = 3,
+  Int64 = 4,
+  FP16 = 5,
+  FP32 = 6,
+  FP64 = 7,
+  BF16 = 8,
+  _MAX = 9
+};
+
 class TensorUnit : public SimObject<TensorUnit> {
 public:
   struct TraceData : public ITraceData {
     using Ptr = std::shared_ptr<TraceData>;
+    uint32_t tc_idx;
     uint32_t latency;
   };
 
   struct PerfStats {
     uint64_t latency;
     uint64_t stalls;
 
     PerfStats()
       : latency(0)
       , stalls(0)
     {}
 
     PerfStats& operator+=(const PerfStats& rhs) {
       this->latency += rhs.latency;
       return *this;
     }
   };
 
   SimPort<instr_trace_t*> Input;
   SimPort<instr_trace_t*> Output;
 
-  TensorUnit(const SimContext& ctx, const char* name);
+  TensorUnit(const SimContext& ctx, const char* name, uint32_t tile_size);
   ~TensorUnit();
 
   void reset();
 
   void tick();
 
-  void mmadd(const std::vector<reg_data_t>& rs1_data,
+  void mmadd(TensorFormat from_format,
+             TensorFormat to_format,
+             const std::vector<reg_data_t>& rs1_data,
              const std::vector<reg_data_t>& rs2_data,
              const std::vector<reg_data_t>& rs3_data,
              std::vector<reg_data_t>& rd_data,
-             TensorUnit::TraceData::Ptr& trace_data);
+             TensorUnit::TraceData::Ptr trace_data);
 
   const PerfStats& perf_stats() const;
 
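Note: the read_element helper in tensor_unit.cpp packs several low-precision elements into each 32-bit register lane (8 x Int4, 4 x Int8, or 2 x FP16/BF16 per word), which is why it indexes reg_data by index / 8, / 4, or / 2. The sketch below shows the conventional layout math for that packing, written independently of the simulator types; the local Fmt enum and the shift-and-mask extraction are illustrative assumptions, not the committed helper (which shifts by the raw slot index and defers masking to the later conversion step).

#include <cstdint>
#include <cstdio>

enum class Fmt { Int4, Int8, FP16, BF16 }; // local stand-in for the source formats above

constexpr uint32_t element_bits(Fmt f) {
  switch (f) {
  case Fmt::Int4: return 4;
  case Fmt::Int8: return 8;
  case Fmt::FP16:
  case Fmt::BF16: return 16;
  }
  return 0;
}

// Conventional extraction of element `index` from an array of packed 32-bit words.
inline uint32_t extract_element(const uint32_t* words, uint32_t index, Fmt f) {
  uint32_t bits = element_bits(f);
  uint32_t per_word = 32 / bits;              // 8, 4, or 2 elements per register word
  uint32_t word = words[index / per_word];
  uint32_t slot = index % per_word;
  uint32_t mask = (bits == 32) ? ~0u : ((1u << bits) - 1);
  return (word >> (slot * bits)) & mask;
}

int main() {
  uint32_t words[1] = {0x87654321}; // eight Int4 nibbles: 1,2,3,... from slot 0 upward
  for (uint32_t i = 0; i < 8; ++i) {
    std::printf("%x ", extract_element(words, i, Fmt::Int4));
  }
  std::printf("\n"); // prints: 1 2 3 4 5 6 7 8
}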
@@ -245,10 +245,9 @@ enum class SfuType {
   CSRRW,
   CSRRS,
   CSRRC,
-  MMADD_U4,
-  MMADD_U8,
-  MMADD_F16,
-  MMADD_BF16
+#ifdef EXT_TPU_ENABLE
+  MMADD,
+#endif
 };
 
 inline std::ostream &operator<<(std::ostream &os, const SfuType& type) {

@@ -263,10 +262,7 @@ inline std::ostream &operator<<(std::ostream &os, const SfuType& type) {
   case SfuType::CSRRS: os << "CSRRS"; break;
   case SfuType::CSRRC: os << "CSRRC"; break;
 #ifdef EXT_TPU_ENABLE
-  case SfuType::MMADD_U4: os << "MMADD_U4"; break;
-  case SfuType::MMADD_U8: os << "MMADD_U8"; break;
-  case SfuType::MMADD_F16: os << "MMADD_F16"; break;
-  case SfuType::MMADD_BF16: os << "MMADD_BF16"; break;
+  case SfuType::MMADD: os << "MMADD"; break;
 #endif
   default: assert(false);
   }