mirror of
https://github.com/olofk/serv.git
synced 2025-06-28 17:22:58 -04:00
494 lines
18 KiB
Python
494 lines
18 KiB
Python
from sympy.logic import SOPform
|
|
from sympy import symbols
|
|
from functools import partial, reduce
|
|
from itertools import product, combinations
|
|
import networkx as nx
|
|
import z3
|
|
|
|
HEADER = """module serv_auto_decode
|
|
#(parameter [0:0] MDU = 1'b0,
|
|
parameter [0:0] CSR = 1'b0)
|
|
(
|
|
input wire i_clk,
|
|
//Input
|
|
input wire i_en,
|
|
input wire i_imm30,
|
|
input wire [2:0] i_funct3,
|
|
input wire [4:0] i_opcode,
|
|
//MDU/Ext/CSR
|
|
input wire i_imm25,
|
|
output reg [2:0] o_ext_funct3,
|
|
output reg o_mdu_op,
|
|
input wire i_op20,
|
|
input wire i_op21,
|
|
input wire i_op22,
|
|
input wire i_op26,
|
|
output reg o_e_op,
|
|
output reg o_ebreak,
|
|
output reg o_ctrl_mret,
|
|
output reg o_csr_en,
|
|
output reg o_csr_addr1,
|
|
output reg o_csr_addr0,
|
|
output reg o_csr_mstatus_en,
|
|
output reg o_csr_mie_en,
|
|
output reg o_csr_mcause_en,
|
|
output reg o_csr_source1,
|
|
output reg o_csr_source0,
|
|
output reg o_csr_d_sel,
|
|
output reg o_csr_imm_en,
|
|
output reg o_rd_csr_en,
|
|
//Output
|
|
{ports});
|
|
|
|
{body}
|
|
endmodule
|
|
"""
|
|
|
|
def printmap(ctrlmap):
|
|
s = ""
|
|
l = max([len(x) for x in ctrlmap])
|
|
s+=' '*(l+2)+"lajjbbbbbblllllsssassxoasssassssxssoa\n"
|
|
s+=' '*(l+2)+"uuaaenlglgbhwbhbhwdllornlrrdulllorrrn\n"
|
|
s+=' '*(l+2)+"iillqetete uu dttridlladblttrla d\n"
|
|
s+=' '*(l+2)+" p r uu iiii iiii u \n"
|
|
s+=' '*(l+2)+" c u \n"
|
|
|
|
for k,v in ctrlmap.items():
|
|
s += f"{k:<{l}} |{v}|\n"
|
|
s += f"{len(ctrlmap)} signals\n"
|
|
return s
|
|
|
|
def merge(d, dst, src):
|
|
l = list(d[dst])
|
|
for i in range(len(l)):
|
|
if l[i] == ' ':
|
|
l[i] = d[src][i]
|
|
elif l[i] == '1' and d[src][i] == '0':
|
|
raise Exception
|
|
elif l[i] == '0' and d[src][i] == '1':
|
|
raise Exception
|
|
d[dst] = ''.join(l)
|
|
d.pop(src)
|
|
|
|
def merge_signals(ctrlmap):
|
|
#Merge control signals and keep track of which signals that have been combined
|
|
|
|
# We build a graph of signals as nodes and merge conflicts as edges. We use z3 to find
|
|
# an optimal coloring of the graph. All nodes of the same color will have no conflicts
|
|
# and can be merged
|
|
|
|
solver = z3.Optimize()
|
|
|
|
node_colors = {}
|
|
|
|
g = nx.Graph()
|
|
|
|
color_count = z3.Int('color_count')
|
|
|
|
# Create a node for every signal
|
|
for sig in ctrlmap:
|
|
g.add_node(sig)
|
|
node_colors[sig] = node_color = z3.Int('color_' + sig)
|
|
|
|
solver.add(node_color >= 0, node_color < color_count)
|
|
|
|
# Conflicting signals may not get the same color
|
|
for sig_i, sig_j in combinations(ctrlmap, 2):
|
|
collide = any(
|
|
i != j and ' ' not in (i, j)
|
|
for i, j in zip(ctrlmap[sig_i], ctrlmap[sig_j])
|
|
)
|
|
if collide:
|
|
g.add_edge(sig_i, sig_j)
|
|
solver.add(node_colors[sig_i] != node_colors[sig_j])
|
|
|
|
# We use networkx to find the largest clique. All nodes in that clique will have to get
|
|
# distinct colors. Since the numbering of colors is arbitrary, we can without loss of
|
|
# generality decide a fixed numbering of the colors of that clique. This kind of
|
|
# symmetry breaking is essential for performance here.
|
|
for i, sig in enumerate(max(nx.find_cliques(g), key=len)):
|
|
solver.add(node_colors[sig] == i)
|
|
|
|
solver.minimize(color_count)
|
|
|
|
if solver.check() != z3.sat:
|
|
raise Exception("optmization failed") # Shouldn't happen
|
|
|
|
model = solver.model()
|
|
|
|
print(f"Found coloring using {model[color_count]} colors")
|
|
|
|
merged_signals = {}
|
|
merge_by_color = {}
|
|
|
|
for signal in list(ctrlmap):
|
|
color = model[node_colors[signal]]
|
|
if color in merge_by_color:
|
|
other_signal = merge_by_color[color]
|
|
merge(ctrlmap, other_signal, signal)
|
|
merged_signals.setdefault(other_signal, []).append(signal)
|
|
else:
|
|
merge_by_color[color] = signal
|
|
|
|
if merged_signals:
|
|
for k,v in merged_signals.items():
|
|
print(f"Merged {', '.join(v)} into {k}")
|
|
return (ctrlmap, merged_signals)
|
|
|
|
def map2signals(ctrlmap):
|
|
for k,v in ctrlmap.items():
|
|
ctrl_signals = {}
|
|
t = []
|
|
f = []
|
|
for i,op in enumerate(ops):
|
|
#Only rv32i for now
|
|
if i > 36:
|
|
continue
|
|
|
|
if v[i] == '1':
|
|
t.append(op)
|
|
elif v[i] == '0':
|
|
f.append(op)
|
|
ctrl_signals[k] = (t,f)
|
|
return ctrl_signals
|
|
|
|
def minterms(s):
|
|
return list(map(partial(reduce, lambda x, y: 2*x + y), product(*([0, 1] if z == 'x' else [int(z)] for z in s))))
|
|
|
|
def map2minterms(bitmap):
|
|
m = []
|
|
falsies = []
|
|
for i,op in enumerate(ops):
|
|
#Only rv32i for now
|
|
if i > 36:
|
|
continue
|
|
|
|
if bitmap[i] == '1':
|
|
m += minterms(ops[op])
|
|
elif bitmap[i] == '0':
|
|
falsies += minterms(ops[op])
|
|
return (m, falsies)
|
|
|
|
def write_post_reg_logic_decoder(ctrlmap, merged_signals):
|
|
signames = [
|
|
'i_imm30',
|
|
'i_funct3[2]',
|
|
'i_funct3[1]',
|
|
'i_funct3[0]',
|
|
'i_opcode[4]',
|
|
'i_opcode[3]',
|
|
'i_opcode[2]',
|
|
'i_opcode[1]',
|
|
'i_opcode[0]',
|
|
]
|
|
syms = [*symbols(' '.join(signames))]
|
|
ports = []
|
|
body = '\n'.join('//'+x for x in printmap(ctrlmap).split('\n'))
|
|
body += "\nalways @(posedge i_clk)\n"
|
|
body += " if (i_en) begin\n"
|
|
body2 =" end\n\n"
|
|
for sig, bitmap in ctrlmap.items():
|
|
#Find all conditions signals must be true and false
|
|
(t, f) = map2minterms(bitmap)
|
|
|
|
#Use Quine-McCluskey to minimize the logic expressions needed for each
|
|
#control signal. Don't cares are the ones that are neither listed as
|
|
#true or false
|
|
dc = set(range(2**9))-set(t)-set(f)
|
|
s = SOPform(syms, t, dc)
|
|
|
|
ports.append(f"output reg o_{sig}")
|
|
|
|
#Output final control signal expression
|
|
body += f" o_{sig} <= {s};\n"
|
|
if sig in merged_signals:
|
|
for alias in merged_signals[sig]:
|
|
ports.append(f"output wire o_{alias}")
|
|
body2 += f" assign o_{alias} = o_{sig};\n"
|
|
|
|
#Some extra signals
|
|
body += " //MDU/CSR/Ext\n"
|
|
body += " o_mdu_op <= MDU & (i_opcode == 5'b01100) & i_imm25;\n"
|
|
body += " o_ext_funct3 <= i_funct3;\n"
|
|
body += " o_ebreak <= i_op20;\n"
|
|
body += " o_rd_csr_en <= i_opcode[4] & i_opcode[2] & (|i_funct3);\n"
|
|
body += " o_ctrl_mret <= i_opcode[4] & i_opcode[2] & !(|i_funct3) & i_op21;\n"
|
|
body += " o_e_op <= i_opcode[4] & i_opcode[2] & !(|i_funct3) & !i_op21;\n"
|
|
body += " o_csr_en <= i_op20 | (i_op26 & !i_op21);\n"
|
|
body += " o_csr_mstatus_en <= !i_op26 & !i_op22;\n"
|
|
body += " o_csr_mie_en <= !i_op26 & i_op22 & !i_op20;\n"
|
|
body += " o_csr_mcause_en <= i_op21 & !i_op20;\n"
|
|
body += " o_csr_source1 <= i_funct3[1];\n"
|
|
body += " o_csr_source0 <= i_funct3[0];\n"
|
|
body += " o_csr_d_sel <= i_funct3[2];\n"
|
|
body += " o_csr_imm_en <= i_opcode[4] & i_opcode[2] & i_funct3[2];\n"
|
|
body += " o_csr_addr1 <= i_op26 & i_op20;\n"
|
|
body += " o_csr_addr0 <= !i_op26 | i_op21;\n"
|
|
|
|
with open('serv_post_reg_decode.v', 'w') as f:
|
|
f.write(HEADER.format(ports=',\n '.join(ports), body=body+body2+'\n'))
|
|
|
|
def write_pre_reg_logic_decoder(ctrlmap, merged_signals):
|
|
signames = [
|
|
'imm30',
|
|
'funct3[2]',
|
|
'funct3[1]',
|
|
'funct3[0]',
|
|
'opcode[4]',
|
|
'opcode[3]',
|
|
'opcode[2]',
|
|
'opcode[1]',
|
|
'opcode[0]',
|
|
]
|
|
syms = [*symbols(' '.join(signames))]
|
|
ports = []
|
|
body = """ reg imm30;
|
|
reg [2:0] funct3;
|
|
reg [4:0] opcode;
|
|
|
|
always @(posedge i_clk)
|
|
if (i_en) begin
|
|
imm30 <= i_imm30;
|
|
funct3 <= i_funct3;
|
|
opcode <= i_opcode;
|
|
end
|
|
"""
|
|
|
|
for sig, bitmap in ctrlmap.items():
|
|
#Find all conditions signals must be true and false
|
|
(t, f) = map2minterms(bitmap)
|
|
|
|
#Use Quine-McCluskey to minimize the logic expressions needed for each
|
|
#control signal. Don't cares are the ones that are neither listed as
|
|
#true or false
|
|
dc = set(range(2**9))-set(t)-set(f)
|
|
s = SOPform(syms, t, dc)
|
|
|
|
ports.append(f"output wire o_{sig}")
|
|
|
|
#Output final control signal expression
|
|
body += f" assign o_{sig} = {s};\n"
|
|
if sig in merged_signals:
|
|
for alias in merged_signals[sig]:
|
|
ports.append(f"output wire o_{alias}")
|
|
body += f" assign o_{alias} = o_{sig};"
|
|
|
|
#Some extra signals
|
|
body += "\n"
|
|
body += " //MDU/CSR/Ext\n"
|
|
body += "always @(posedge i_clk)\n"
|
|
body += " if (i_en) begin\n"
|
|
body += " imm25 <= i_imm25;\n"
|
|
body += " funct3 <= i_funct3;\n"
|
|
body += " opcode <= i_opcode;\n"
|
|
body += " op20 <= i_op20;\n"
|
|
body += " op21 <= i_op21;\n"
|
|
body += " op22 <= i_op22;\n"
|
|
body += " op26 <= i_op26;\n"
|
|
body += " end\n"
|
|
body += " assign o_mdu_op = MDU & (opcode == 5'b01100) & imm25;\n"
|
|
body += " assign o_ext_funct3 = funct3;\n"
|
|
body += " assign o_ebreak = op20;\n"
|
|
body += " assign o_rd_csr_en = opcode[4] & opcode[2] & (|funct3);\n"
|
|
body += " assign o_ctrl_mret = opcode[4] & opcode[2] & !(|funct3) & op21;\n"
|
|
body += " assign o_e_op = opcode[4] & opcode[2] & !(|funct3) & !op21;\n"
|
|
body += " assign o_csr_en = op20 | (op26 & !op21);\n"
|
|
body += " assign o_csr_mstatus_en = !op26 & !op22;\n"
|
|
body += " assign o_csr_mie_en = !op26 & op22 & !op20;\n"
|
|
body += " assign o_csr_mcause_en = op21 & !op20;\n"
|
|
body += " assign o_csr_source1 = funct3[1];\n"
|
|
body += " assign o_csr_source0 = funct3[0];\n"
|
|
body += " assign o_csr_d_sel = funct3[2];\n"
|
|
body += " assign o_csr_imm_en = opcode[4] & opcode[2] & funct3[2];\n"
|
|
body += " assign o_csr_addr1 = op26 & op20;\n"
|
|
body += " assign o_csr_addr0 = !op26 | op21;\n"
|
|
|
|
with open('serv_pre_reg_decode.v', 'w') as f:
|
|
f.write(HEADER.format(ports=',\n '.join(ports), body=body))
|
|
|
|
def write_mem_decoder(ctrlmap, merged_signals):
|
|
ports = []
|
|
mem = [0]*512
|
|
|
|
width = len(ctrlmap)
|
|
body = """ (* ram_style = "block" *) reg [{msb}:0] mem [0:511];
|
|
reg [{msb}:0] d;
|
|
|
|
initial begin
|
|
{mem} end
|
|
|
|
always @(posedge i_clk)
|
|
if (i_en)
|
|
d <= mem[{{i_imm30,i_funct3,i_opcode}}];
|
|
|
|
"""
|
|
|
|
s = ""
|
|
for i, (sig, bitmap) in enumerate(ctrlmap.items()):
|
|
#Find all conditions signals must be true
|
|
#Rest can be zero
|
|
(t, _) = map2minterms(bitmap)
|
|
for x in t:
|
|
mem[x] += 2**i
|
|
body += f" assign o_{sig} = d[{i}];\n"
|
|
|
|
ports.append(f"output wire o_{sig}")
|
|
|
|
if sig in merged_signals:
|
|
for alias in merged_signals[sig]:
|
|
ports.append(f"output wire o_{alias}")
|
|
body += f" assign o_{alias} = o_{sig};\n"
|
|
|
|
for i, m in enumerate(mem):
|
|
s += f" mem[{i}] = {width}'h{m:0{(width+3)//4}x};\n"
|
|
|
|
body += "\nalways @(posedge i_clk) begin\n"
|
|
body += "if (i_en) begin\n"
|
|
body += " //MDU/CSR/Ext\n"
|
|
body += " o_mdu_op <= MDU & (i_opcode == 5'b01100) & i_imm25;\n"
|
|
body += " o_ext_funct3 <= MDU ? i_funct3 : 3'b000;\n"
|
|
body += " o_ebreak <= CSR & (i_op20);\n"
|
|
body += " o_rd_csr_en <= CSR & (i_opcode[4] & i_opcode[2] & (|i_funct3));\n"
|
|
body += " o_ctrl_mret <= CSR & (i_opcode[4] & i_opcode[2] & !(|i_funct3) & i_op21);\n"
|
|
body += " o_e_op <= CSR & (i_opcode[4] & i_opcode[2] & !(|i_funct3) & !i_op21);\n"
|
|
body += " o_csr_en <= CSR & (i_op20 | (i_op26 & !i_op21));\n"
|
|
body += " o_csr_mstatus_en <= CSR & (!i_op26 & !i_op22);\n"
|
|
body += " o_csr_mie_en <= CSR & (!i_op26 & i_op22 & !i_op20);\n"
|
|
body += " o_csr_mcause_en <= CSR & ( i_op21 & !i_op20);\n"
|
|
body += " o_csr_source1 <= CSR & (i_funct3[1]);\n"
|
|
body += " o_csr_source0 <= CSR & (i_funct3[0]);\n"
|
|
body += " o_csr_d_sel <= CSR & (i_funct3[2]);\n"
|
|
body += " o_csr_imm_en <= CSR & (i_opcode[4] & i_opcode[2] & i_funct3[2]);\n"
|
|
body += " o_csr_addr1 <= CSR & (i_op26 & i_op20);\n"
|
|
body += " o_csr_addr0 <= CSR & (!i_op26 | i_op21);\n"
|
|
body += "end\n"
|
|
body += "end\n"
|
|
|
|
with open('serv_mem_decode.v', 'w') as f:
|
|
f.write(HEADER.format(ports=',\n '.join(ports), body=body.format(msb=width-1, mem=s)))
|
|
|
|
|
|
#imm30, funct3, opcode
|
|
ops = {
|
|
'lui' : 'x' + 'xxx' + '01101',
|
|
'auipc' : 'x' + 'xxx' + '00101',
|
|
'jal' : 'x' + 'xxx' + '11011',
|
|
'jalr' : 'x' + 'xxx' + '11001',#funct3 = 000?
|
|
'beq' : 'x' + '000' + '11000',
|
|
'bne' : 'x' + '001' + '11000',
|
|
'blt' : 'x' + '100' + '11000',
|
|
'bge' : 'x' + '101' + '11000',
|
|
'bltu' : 'x' + '110' + '11000',
|
|
'bgeu' : 'x' + '111' + '11000',
|
|
'lb' : 'x' + '000' + '00000',
|
|
'lh' : 'x' + '001' + '00000',
|
|
'lw' : 'x' + '010' + '00000',
|
|
'lbu' : 'x' + '100' + '00000',
|
|
'lhu' : 'x' + '101' + '00000',
|
|
'sb' : 'x' + '000' + '01000',
|
|
'sh' : 'x' + '001' + '01000',
|
|
'sw' : 'x' + '010' + '01000',
|
|
'addi' : 'x' + '000' + '00100',
|
|
'slti' : 'x' + '010' + '00100',
|
|
'sltiu' : 'x' + '011' + '00100',
|
|
'xori' : 'x' + '100' + '00100',
|
|
'ori' : 'x' + '110' + '00100',
|
|
'andi' : 'x' + '111' + '00100',
|
|
'slli' : '0' + '001' + '00100',
|
|
'srli' : '0' + '101' + '00100',
|
|
'srai' : '1' + '101' + '00100',
|
|
'add' : '0' + '000' + '01100',
|
|
'sub' : '1' + '000' + '01100',
|
|
'sll' : '0' + '001' + '01100',
|
|
'slt' : '0' + '010' + '01100',
|
|
'sltu' : '0' + '011' + '01100',
|
|
'xor' : '0' + '100' + '01100',
|
|
'srl' : '0' + '101' + '01100',
|
|
'sra' : '1' + '101' + '01100',
|
|
'or' : '0' + '110' + '01100',
|
|
'and' : '0' + '111' + '01100',
|
|
'fence' : 'x' + 'xxx' + '00011',#funct3=000?
|
|
'ecall' : 'x' + '000' + '11100',#ebreak same but op20=1
|
|
'csrrw' : 'x' + '001' + '11100',
|
|
'csrrs' : 'x' + '010' + '11100',
|
|
'csrrc' : 'x' + '011' + '11100',
|
|
'csrrwi': 'x' + '101' + '11100',
|
|
'csrrsi': 'x' + '110' + '11100',
|
|
'csrrci': 'x' + '111' + '11100',
|
|
}
|
|
|
|
###################################
|
|
###################################
|
|
###################################
|
|
|
|
|
|
#Map of all required true/false conditions for each op.
|
|
#This should ideally be created automatically from riscv-formal runs
|
|
#TODO: Extend with optional ISA extensions (M, Zicsr, Zifencei..)
|
|
|
|
#ebreak = ecall with op20=1
|
|
ctrlmap = \
|
|
{
|
|
#UUJRBBBBBBIIIIISSSIIIIIIIIIRRRRRRRRRR
|
|
#lajjbbbbbblllllsssassxoasssassssxssoa
|
|
#uuaaenlglgbhwbhbhwdllornlrrdulllorrrn
|
|
#iillqetete uu dttridlladblttrla d
|
|
# p r uu iiii iiii u
|
|
# c u
|
|
'branch_op' : ' 1111111100000000 00 000 000 00 ',
|
|
'slt_or_branch' : ' 1111111100000000 11 000 011 00 ',
|
|
'op_b_source' : ' 11111111000001110000000001111111111',
|
|
'immdec_ctrl0' : ' 0 11111100000111000000000 ',
|
|
'immdec_ctrl1' : '000 11111111111111111 ',
|
|
'immdec_ctrl2' : '000 11111100000000000000000 ',
|
|
'immdec_ctrl3' : '001 ',
|
|
'immdec_en0' : '0000111111000001110000000000000000000',
|
|
'immdec_en1' : '1110000000000000000000000000000000000',
|
|
'immdec_en2' : '1111000000111110001111111110000000000',
|
|
'immdec_en3' : '1110111111111111111111111110000000000',
|
|
'bne_or_bge' : ' 010101 ',
|
|
'sh_right' : ' 011 0 11 ',
|
|
'cond_branch' : ' 00111111 ',
|
|
'shift_op' : ' 0000000000000000 00 111 100 11 ',
|
|
'two_stage_op' : '0011111111111111110110001110011101100',
|
|
'rd_alu_en' : '0000 00000 1111111111111111111',
|
|
'rd_mem_en' : '0000 11111 0000000000000000000',
|
|
'dbus_en' : ' 0000000011111111 00 000 000 00 ',
|
|
'bufreg_rs1_en' : ' 0100000011111111 111 1 11 ',
|
|
'bufreg_imm_en' : ' 1111111111111111 000 0 00 ',
|
|
'bufreg_clr_lsb' : ' 1011111100000000 000 0 00 ',
|
|
'bufreg_sh_signed': ' 01 01 ',
|
|
'ctrl_jal_or_jalr': '0011 00000 0000000000000000000',
|
|
'ctrl_utype' : '110000000000000 0000000000000000000',
|
|
'ctrl_pc_rel' : '0110111111 ',
|
|
'rd_op' : '1111000000111110001111111111111111111',
|
|
'alu_sub' : ' 111111 011 01 11 ',
|
|
'alu_bool_op1' : ' 011000 0 00011',
|
|
'alu_bool_op0' : ' 001111 1 01101',
|
|
'alu_cmp_eq' : ' 110000 00 00 ',
|
|
'alu_cmp_sig' : ' 1100 10 10 ',
|
|
'alu_rd_sel0' : ' 1000000001100000000',
|
|
'alu_rd_sel1' : ' 0110000000001100000',
|
|
'alu_rd_sel2' : ' 0001110000000010011',
|
|
'mem_signed' : ' 11 00 ',
|
|
'mem_word' : ' 00100001 ',
|
|
'mem_half' : ' 01001010 ',
|
|
'mem_cmd' : ' 00000111 ',
|
|
# 'mtval_pc' : ' 1111111100000000 ',
|
|
}
|
|
|
|
print(printmap(ctrlmap))
|
|
|
|
print("\nMerging control signals")
|
|
(ctrlmap, merged_signals) = merge_signals(ctrlmap)
|
|
print(printmap(ctrlmap))
|
|
|
|
#Create the various decoders
|
|
print("Creating mem decoder")
|
|
write_mem_decoder(ctrlmap, merged_signals)
|
|
|
|
#print("Writing post-registered logic decoder")
|
|
#write_post_reg_logic_decoder(ctrlmap, merged_signals)
|
|
#
|
|
#print("Writing pre-registered logic decoder")
|
|
#write_pre_reg_logic_decoder(ctrlmap, merged_signals)
|