minor update

This commit is contained in:
Blaise Tine 2024-03-26 16:31:24 -07:00
parent 8ab4c53e27
commit 70f2f58ac9
3 changed files with 24 additions and 26 deletions

View file

@ -10,6 +10,7 @@
typedef struct {
uint32_t num_tasks;
uint32_t size;
uint32_t log2_size;
uint64_t A_addr;
uint64_t B_addr;
uint64_t C_addr;

View file

@ -7,20 +7,15 @@ inline char is_log2(uint32_t x) {
return ((x & (x-1)) == 0);
}
// Returns floor(log2(x)), i.e. the bit index of the most significant set bit
// of x. For an exact power of two this is the shift amount that is equivalent
// to dividing by x.
//
// __builtin_clz has an undefined result when its argument is zero, so x == 0
// is handled explicitly and yields 0 rather than reaching the builtin.
inline uint32_t log2_fast(uint32_t x) {
  if (x == 0) {
    return 0;
  }
  // 31 is the highest bit index of a 32-bit word; subtracting the number of
  // leading zero bits leaves the index of the top set bit.
  return 31 - __builtin_clz(x);
}
void kernel_body(uint32_t task_id, kernel_arg_t* __UNIFORM__ arg) {
auto A = reinterpret_cast<TYPE*>(arg->A_addr);
auto B = reinterpret_cast<TYPE*>(arg->B_addr);
auto C = reinterpret_cast<TYPE*>(arg->C_addr);
auto size = arg->size;
auto size = arg->size;
uint32_t row, col;
if (is_log2(size)) {
uint32_t log_size = log2_fast(size);
row = task_id >> log_size;
row = task_id >> arg->log2_size;
col = task_id & (size-1);
} else {
row = task_id / size;
@ -31,6 +26,7 @@ void kernel_body(uint32_t task_id, kernel_arg_t* __UNIFORM__ arg) {
for (int e = 0; e < size; ++e) {
sum += A[row * size + e] * B[e * size + col];
}
C[row * size + col] = sum;
}

View file

@ -4,6 +4,7 @@
#include <vector>
#include <chrono>
#include <vortex.h>
#include <cmath>
#include "common.h"
#define FLOAT_ULP 6
@ -138,7 +139,6 @@ int main(int argc, char *argv[]) {
std::cout << "data type: " << Comparator<TYPE>::type_str() << std::endl;
std::cout << "matrix size: " << size << "x" << size << std::endl;
std::cout << "buffer size: " << buf_size << " bytes" << std::endl;
// upload program
std::cout << "upload program" << std::endl;
@ -152,10 +152,11 @@ int main(int argc, char *argv[]) {
kernel_arg.num_tasks = num_points;
kernel_arg.size = size;
kernel_arg.log2_size = log2(size);
std::cout << "dev_src0=0x" << std::hex << kernel_arg.A_addr << std::endl;
std::cout << "dev_src1=0x" << std::hex << kernel_arg.B_addr << std::endl;
std::cout << "dev_dst=0x" << std::hex << kernel_arg.C_addr << std::endl;
std::cout << "dev_argA=0x" << std::hex << kernel_arg.A_addr << std::endl;
std::cout << "dev_argB=0x" << std::hex << kernel_arg.B_addr << std::endl;
std::cout << "dev_argC=0x" << std::hex << kernel_arg.C_addr << std::endl;
// allocate staging buffer
std::cout << "allocate staging buffer" << std::endl;
@ -168,40 +169,40 @@ int main(int argc, char *argv[]) {
RT_CHECK(vx_copy_to_dev(device, KERNEL_ARG_DEV_MEM_ADDR, staging_buf.data(), sizeof(kernel_arg_t)));
// generate source data
std::vector<TYPE> src_A(num_points);
std::vector<TYPE> src_B(num_points);
std::vector<TYPE> refs(num_points);
std::vector<TYPE> h_A(num_points);
std::vector<TYPE> h_B(num_points);
std::vector<TYPE> h_C(num_points);
for (uint32_t i = 0; i < num_points; ++i) {
auto a = static_cast<float>(std::rand()) / RAND_MAX;
auto b = static_cast<float>(std::rand()) / RAND_MAX;
src_A[i] = static_cast<TYPE>(a * size);
src_B[i] = static_cast<TYPE>(b * size);
h_A[i] = static_cast<TYPE>(a * size);
h_B[i] = static_cast<TYPE>(b * size);
}
matmul_cpu(refs.data(), src_A.data(), src_B.data(), size, size);
matmul_cpu(h_C.data(), h_A.data(), h_B.data(), size, size);
// upload source buffer0
// upload matrix A buffer
{
std::cout << "upload source buffer0" << std::endl;
std::cout << "upload matrix A buffer" << std::endl;
auto buf_ptr = (TYPE*)staging_buf.data();
for (uint32_t i = 0; i < num_points; ++i) {
buf_ptr[i] = src_A[i];
buf_ptr[i] = h_A[i];
}
RT_CHECK(vx_copy_to_dev(device, kernel_arg.A_addr, staging_buf.data(), buf_size));
}
// upload source buffer1
// upload matrix B buffer
{
std::cout << "upload source buffer1" << std::endl;
std::cout << "upload matrix B buffer" << std::endl;
auto buf_ptr = (TYPE*)staging_buf.data();
for (uint32_t i = 0; i < num_points; ++i) {
buf_ptr[i] = src_B[i];
buf_ptr[i] = h_B[i];
}
RT_CHECK(vx_copy_to_dev(device, kernel_arg.B_addr, staging_buf.data(), buf_size));
}
// clear destination buffer
std::cout << "clear destination buffer" << std::endl;
memset(staging_buf.data(), 0, num_points * sizeof(TYPE));
memset(staging_buf.data(), 0, buf_size);
RT_CHECK(vx_copy_to_dev(device, kernel_arg.C_addr, staging_buf.data(), buf_size));
auto time_start = std::chrono::high_resolution_clock::now();
@ -227,8 +228,8 @@ int main(int argc, char *argv[]) {
{
int errors = 0;
auto buf_ptr = (TYPE*)staging_buf.data();
for (uint32_t i = 0; i < refs.size(); ++i) {
auto ref = refs[i];
for (uint32_t i = 0; i < h_C.size(); ++i) {
auto ref = h_C[i];
auto cur = buf_ptr[i];
if (!Comparator<TYPE>::compare(cur, ref, i, errors)) {
++errors;