# serv/sw/decodegen.py
# 2022-01-14 23:36:20 +01:00
#
# 494 lines
# 18 KiB
# Python

from sympy.logic import SOPform
from sympy import symbols
from functools import partial, reduce
from itertools import product, combinations
import networkx as nx
import z3
# Verilog module template shared by every decoder writer in this file.
# It is filled in with str.format(): {ports} receives the comma-separated
# list of generated control-signal ports and {body} receives the generated
# module body. The MDU/CSR related ports are fixed and declared here.
HEADER = """module serv_auto_decode
#(parameter [0:0] MDU = 1'b0,
parameter [0:0] CSR = 1'b0)
(
input wire i_clk,
//Input
input wire i_en,
input wire i_imm30,
input wire [2:0] i_funct3,
input wire [4:0] i_opcode,
//MDU/Ext/CSR
input wire i_imm25,
output reg [2:0] o_ext_funct3,
output reg o_mdu_op,
input wire i_op20,
input wire i_op21,
input wire i_op22,
input wire i_op26,
output reg o_e_op,
output reg o_ebreak,
output reg o_ctrl_mret,
output reg o_csr_en,
output reg o_csr_addr1,
output reg o_csr_addr0,
output reg o_csr_mstatus_en,
output reg o_csr_mie_en,
output reg o_csr_mcause_en,
output reg o_csr_source1,
output reg o_csr_source0,
output reg o_csr_d_sel,
output reg o_csr_imm_en,
output reg o_rd_csr_en,
//Output
{ports});
{body}
endmodule
"""
def printmap(ctrlmap):
    """Render the control map as a text table.

    The result starts with five legend rows (the op mnemonics written
    vertically), followed by one ``name |bitmap|`` row per signal and a
    final ``N signals`` line. Signal names are left-aligned and padded to
    the longest name in ``ctrlmap``.

    Raises ValueError (from max()) when ``ctrlmap`` is empty.
    """
    name_width = max(len(name) for name in ctrlmap)
    # The legend is shifted right so it lines up with the bitmaps, which
    # start after the padded name plus the " |" separator.
    indent = ' ' * (name_width + 2)
    legend = (
        "lajjbbbbbblllllsssassxoasssassssxssoa\n",
        "uuaaenlglgbhwbhbhwdllornlrrdulllorrrn\n",
        "iillqetete uu dttridlladblttrla d\n",
        " p r uu iiii iiii u \n",
        " c u \n",
    )
    pieces = [indent + row for row in legend]
    pieces.extend(
        f"{name:<{name_width}} |{bits}|\n" for name, bits in ctrlmap.items()
    )
    pieces.append(f"{len(ctrlmap)} signals\n")
    return ''.join(pieces)
def merge(d, dst, src):
    """Fold the bitmap stored under ``src`` into the one under ``dst``.

    Every ' ' column of ``d[dst]`` takes the character found in the same
    column of ``d[src]``; all other columns keep their value. A '1' facing
    a '0' (in either direction) raises Exception and leaves ``d``
    untouched. On success ``d[dst]`` is replaced by the combined bitmap
    and the ``src`` entry is removed from ``d``.
    """
    opposite = {'0': '1', '1': '0'}
    source = d[src]
    combined = []
    for pos, current in enumerate(d[dst]):
        if current == ' ':
            # Unconstrained column: inherit whatever the source says.
            current = source[pos]
        elif current in opposite and source[pos] == opposite[current]:
            raise Exception
        combined.append(current)
    d[dst] = ''.join(combined)
    d.pop(src)
def merge_signals(ctrlmap):
    """Merge control signals whose bitmaps never disagree.

    Two signals conflict when some column holds a non-space character in
    both bitmaps and the two characters differ. The signals are coloured
    with z3 so that conflicting signals get different colours while the
    number of colours is minimised; all signals sharing a colour are then
    folded into the first signal seen with that colour.

    Args:
        ctrlmap: dict mapping signal name to bitmap string. It is modified
            in place: merge() rewrites the surviving entries and removes
            the merged ones.

    Returns:
        A tuple ``(ctrlmap, merged_signals)`` where ``ctrlmap`` is the same
        dict object that was passed in and ``merged_signals`` maps each
        surviving signal name to the list of names merged into it.

    Raises:
        Exception: if the z3 check does not report ``sat``.
    """
    #Merge control signals and keep track of which signals that have been combined
    # We build a graph of signals as nodes and merge conflicts as edges. We use z3 to find
    # an optimal coloring of the graph. All nodes of the same color will have no conflicts
    # and can be merged
    solver = z3.Optimize()
    node_colors = {}
    g = nx.Graph()
    color_count = z3.Int('color_count')
    # Create a node for every signal
    for sig in ctrlmap:
        g.add_node(sig)
        node_colors[sig] = node_color = z3.Int('color_' + sig)
        # Each colour is an integer in [0, color_count)
        solver.add(node_color >= 0, node_color < color_count)
    # Conflicting signals may not get the same color
    for sig_i, sig_j in combinations(ctrlmap, 2):
        # zip() stops at the shorter bitmap, so only the common columns
        # are compared.
        collide = any(
            i != j and ' ' not in (i, j)
            for i, j in zip(ctrlmap[sig_i], ctrlmap[sig_j])
        )
        if collide:
            g.add_edge(sig_i, sig_j)
            solver.add(node_colors[sig_i] != node_colors[sig_j])
    # We use networkx to find the largest clique. All nodes in that clique will have to get
    # distinct colors. Since the numbering of colors is arbitrary, we can without loss of
    # generality decide a fixed numbering of the colors of that clique. This kind of
    # symmetry breaking is essential for performance here.
    for i, sig in enumerate(max(nx.find_cliques(g), key=len)):
        solver.add(node_colors[sig] == i)
    solver.minimize(color_count)
    if solver.check() != z3.sat:
        raise Exception("optmization failed") # Shouldn't happen
    model = solver.model()
    print(f"Found coloring using {model[color_count]} colors")
    # merged_signals: surviving signal -> names folded into it
    merged_signals = {}
    # merge_by_color: colour value from the model -> first signal seen with it
    merge_by_color = {}
    # Iterate over a snapshot of the keys because merge() pops entries
    # from ctrlmap while we loop.
    for signal in list(ctrlmap):
        color = model[node_colors[signal]]
        if color in merge_by_color:
            other_signal = merge_by_color[color]
            merge(ctrlmap, other_signal, signal)
            merged_signals.setdefault(other_signal, []).append(signal)
        else:
            merge_by_color[color] = signal
    if merged_signals:
        for k,v in merged_signals.items():
            print(f"Merged {', '.join(v)} into {k}")
    return (ctrlmap, merged_signals)
def map2signals(ctrlmap):
    """Translate every bitmap into the op names it marks true and false.

    Args:
        ctrlmap: dict mapping signal name to bitmap string. Column ``i`` of
            a bitmap refers to the ``i``-th key of the module-level ``ops``
            table.

    Returns:
        A dict mapping each signal name to a tuple ``(t, f)`` where ``t``
        lists the ops whose column is '1' and ``f`` the ops whose column
        is '0'. Any other character leaves the op in neither list. An
        empty ``ctrlmap`` yields an empty dict.
    """
    # The result dict must be created once, outside the loop; creating it
    # per signal would discard everything but the last signal.
    ctrl_signals = {}
    for k, v in ctrlmap.items():
        t = []
        f = []
        for i, op in enumerate(ops):
            #Only rv32i for now
            if i > 36:
                break
            if v[i] == '1':
                t.append(op)
            elif v[i] == '0':
                f.append(op)
        ctrl_signals[k] = (t, f)
    return ctrl_signals
def minterms(s):
    """Expand a bit pattern into the list of integers it matches.

    ``s`` is read most-significant bit first. Each 'x' stands for both 0
    and 1; every other character is converted with int(). The values are
    returned in the order itertools.product generates them.
    """
    def to_int(bits):
        # Fold the bit tuple into a number, MSB first.
        return reduce(lambda acc, bit: 2 * acc + bit, bits)

    choices = [(0, 1) if ch == 'x' else (int(ch),) for ch in s]
    return [to_int(bits) for bits in product(*choices)]
def map2minterms(bitmap):
    """Collect the minterms a bitmap requires to be true and to be false.

    Column ``i`` of ``bitmap`` refers to the ``i``-th entry of the
    module-level ``ops`` table; only columns 0..36 are considered.

    Returns:
        A tuple ``(true_terms, false_terms)`` of integer lists, built by
        expanding the encoding of every op marked '1' or '0' respectively
        with minterms(). Ops marked with any other character contribute
        to neither list.
    """
    true_terms = []
    false_terms = []
    for column, name in enumerate(ops):
        #Only rv32i for now
        if column > 36:
            break
        flag = bitmap[column]
        if flag == '1':
            true_terms.extend(minterms(ops[name]))
        elif flag == '0':
            false_terms.extend(minterms(ops[name]))
    return (true_terms, false_terms)
def write_post_reg_logic_decoder(ctrlmap, merged_signals):
    """Write 'serv_post_reg_decode.v': logic decoder with registered outputs.

    Each control signal is emitted as a minimised sum-of-products over the
    nine input bits (i_imm30, i_funct3[2:0], i_opcode[4:0]) and assigned
    inside an ``always @(posedge i_clk)`` block. Signals listed in
    ``merged_signals`` get extra ``output wire`` ports that are
    continuously assigned from the signal they were merged into.

    Args:
        ctrlmap: dict mapping signal name to bitmap string.
        merged_signals: dict mapping a signal name to the list of alias
            names merged into it.
    """
    input_names = (
        'i_imm30',
        'i_funct3[2]',
        'i_funct3[1]',
        'i_funct3[0]',
        'i_opcode[4]',
        'i_opcode[3]',
        'i_opcode[2]',
        'i_opcode[1]',
        'i_opcode[0]',
    )
    inputs = list(symbols(' '.join(input_names)))
    port_decls = []
    # The control map is embedded at the top of the body as // comments.
    commented_map = '\n'.join(
        '//' + row for row in printmap(ctrlmap).split('\n')
    )
    # `clocked` holds everything up to and including the registered
    # assignments; `trailer` closes the block and carries alias assigns.
    clocked = [
        commented_map,
        "\nalways @(posedge i_clk)\n",
        " if (i_en) begin\n",
    ]
    trailer = [" end\n\n"]
    for name, bitmap in ctrlmap.items():
        #Find all conditions signals must be true and false
        required, forbidden = map2minterms(bitmap)
        #Use Quine-McCluskey to minimize the logic expressions needed for each
        #control signal. Don't cares are the ones that are neither listed as
        #true or false
        dontcare = set(range(2**9))-set(required)-set(forbidden)
        expr = SOPform(inputs, required, dontcare)
        port_decls.append(f"output reg o_{name}")
        #Output final control signal expression
        clocked.append(f" o_{name} <= {expr};\n")
        for alias in merged_signals.get(name, ()):
            port_decls.append(f"output wire o_{alias}")
            trailer.append(f" assign o_{alias} = o_{name};\n")
    #Some extra signals
    clocked.extend((
        " //MDU/CSR/Ext\n",
        " o_mdu_op <= MDU & (i_opcode == 5'b01100) & i_imm25;\n",
        " o_ext_funct3 <= i_funct3;\n",
        " o_ebreak <= i_op20;\n",
        " o_rd_csr_en <= i_opcode[4] & i_opcode[2] & (|i_funct3);\n",
        " o_ctrl_mret <= i_opcode[4] & i_opcode[2] & !(|i_funct3) & i_op21;\n",
        " o_e_op <= i_opcode[4] & i_opcode[2] & !(|i_funct3) & !i_op21;\n",
        " o_csr_en <= i_op20 | (i_op26 & !i_op21);\n",
        " o_csr_mstatus_en <= !i_op26 & !i_op22;\n",
        " o_csr_mie_en <= !i_op26 & i_op22 & !i_op20;\n",
        " o_csr_mcause_en <= i_op21 & !i_op20;\n",
        " o_csr_source1 <= i_funct3[1];\n",
        " o_csr_source0 <= i_funct3[0];\n",
        " o_csr_d_sel <= i_funct3[2];\n",
        " o_csr_imm_en <= i_opcode[4] & i_opcode[2] & i_funct3[2];\n",
        " o_csr_addr1 <= i_op26 & i_op20;\n",
        " o_csr_addr0 <= !i_op26 | i_op21;\n",
    ))
    with open('serv_post_reg_decode.v', 'w') as out:
        out.write(HEADER.format(
            ports=',\n '.join(port_decls),
            body=''.join(clocked) + ''.join(trailer) + '\n',
        ))
def write_pre_reg_logic_decoder(ctrlmap, merged_signals):
    """Write 'serv_pre_reg_decode.v': inputs are registered, outputs are logic.

    The instruction bits are captured in registers on the clock edge and
    every control signal is emitted as a continuous ``assign`` of a
    minimised sum-of-products over those registers. Signals listed in
    ``merged_signals`` get extra ``output wire`` ports assigned from the
    signal they were merged into.

    Args:
        ctrlmap: dict mapping signal name to bitmap string.
        merged_signals: dict mapping a signal name to the list of alias
            names merged into it.
    """
    signames = [
        'imm30',
        'funct3[2]',
        'funct3[1]',
        'funct3[0]',
        'opcode[4]',
        'opcode[3]',
        'opcode[2]',
        'opcode[1]',
        'opcode[0]',
    ]
    syms = [*symbols(' '.join(signames))]
    ports = []
    body = """ reg imm30;
reg [2:0] funct3;
reg [4:0] opcode;
always @(posedge i_clk)
if (i_en) begin
imm30 <= i_imm30;
funct3 <= i_funct3;
opcode <= i_opcode;
end
"""
    for sig, bitmap in ctrlmap.items():
        #Find all conditions signals must be true and false
        (t, f) = map2minterms(bitmap)
        #Use Quine-McCluskey to minimize the logic expressions needed for each
        #control signal. Don't cares are the ones that are neither listed as
        #true or false
        dc = set(range(2**9))-set(t)-set(f)
        s = SOPform(syms, t, dc)
        ports.append(f"output wire o_{sig}")
        #Output final control signal expression
        body += f" assign o_{sig} = {s};\n"
        for alias in merged_signals.get(sig, ()):
            ports.append(f"output wire o_{alias}")
            # Terminate the line so the next statement does not end up
            # glued to this one (the other writers do the same).
            body += f" assign o_{alias} = o_{sig};\n"
    #Some extra signals
    # NOTE(review): imm25, op20, op21, op22 and op26 are assigned below but
    # this function never emits a `reg` declaration for them, and funct3 /
    # opcode are assigned in two always blocks — confirm the generated
    # Verilog elaborates.
    # NOTE(review): HEADER declares the o_mdu_op/o_csr_* ports as
    # `output reg`, while they are driven with `assign` here — verify this
    # variant against the shared header before enabling it.
    body += "".join((
        "\n",
        " //MDU/CSR/Ext\n",
        "always @(posedge i_clk)\n",
        " if (i_en) begin\n",
        " imm25 <= i_imm25;\n",
        " funct3 <= i_funct3;\n",
        " opcode <= i_opcode;\n",
        " op20 <= i_op20;\n",
        " op21 <= i_op21;\n",
        " op22 <= i_op22;\n",
        " op26 <= i_op26;\n",
        " end\n",
        " assign o_mdu_op = MDU & (opcode == 5'b01100) & imm25;\n",
        " assign o_ext_funct3 = funct3;\n",
        " assign o_ebreak = op20;\n",
        " assign o_rd_csr_en = opcode[4] & opcode[2] & (|funct3);\n",
        " assign o_ctrl_mret = opcode[4] & opcode[2] & !(|funct3) & op21;\n",
        " assign o_e_op = opcode[4] & opcode[2] & !(|funct3) & !op21;\n",
        " assign o_csr_en = op20 | (op26 & !op21);\n",
        " assign o_csr_mstatus_en = !op26 & !op22;\n",
        " assign o_csr_mie_en = !op26 & op22 & !op20;\n",
        " assign o_csr_mcause_en = op21 & !op20;\n",
        " assign o_csr_source1 = funct3[1];\n",
        " assign o_csr_source0 = funct3[0];\n",
        " assign o_csr_d_sel = funct3[2];\n",
        " assign o_csr_imm_en = opcode[4] & opcode[2] & funct3[2];\n",
        " assign o_csr_addr1 = op26 & op20;\n",
        " assign o_csr_addr0 = !op26 | op21;\n",
    ))
    with open('serv_pre_reg_decode.v', 'w') as out:
        out.write(HEADER.format(ports=',\n '.join(ports), body=body))
def write_mem_decoder(ctrlmap, merged_signals):
    """Write 'serv_mem_decode.v': a 512-entry lookup-table decoder.

    The table is addressed by ``{i_imm30,i_funct3,i_opcode}``. Bit ``i`` of
    each word holds the value of the ``i``-th signal of ``ctrlmap`` for
    that address: it is set for every minterm the signal requires to be
    true and left at zero everywhere else. Signals listed in
    ``merged_signals`` get extra ``output wire`` ports assigned from the
    signal they were merged into.

    Args:
        ctrlmap: dict mapping signal name to bitmap string.
        merged_signals: dict mapping a signal name to the list of alias
            names merged into it.
    """
    ports = []
    mem = [0]*512
    width = len(ctrlmap)
    # This template is passed through str.format() at the end, hence the
    # doubled braces around the Verilog concatenation.
    body = """ (* ram_style = "block" *) reg [{msb}:0] mem [0:511];
reg [{msb}:0] d;
initial begin
{mem} end
always @(posedge i_clk)
if (i_en)
d <= mem[{{i_imm30,i_funct3,i_opcode}}];
"""
    for i, (sig, bitmap) in enumerate(ctrlmap.items()):
        #Find all conditions signals must be true
        #Rest can be zero
        (t, _) = map2minterms(bitmap)
        for x in t:
            # OR the bit in rather than adding it: if the same minterm
            # shows up twice for a signal, an addition would carry into
            # the neighbouring signal's bit.
            mem[x] |= 1 << i
        body += f" assign o_{sig} = d[{i}];\n"
        ports.append(f"output wire o_{sig}")
        for alias in merged_signals.get(sig, ()):
            ports.append(f"output wire o_{alias}")
            body += f" assign o_{alias} = o_{sig};\n"
    # One hex digit per started group of four signal bits.
    hexdigits = (width+3)//4
    s = "".join(
        f" mem[{i}] = {width}'h{m:0{hexdigits}x};\n"
        for i, m in enumerate(mem)
    )
    body += "\nalways @(posedge i_clk) begin\n"
    body += "if (i_en) begin\n"
    body += " //MDU/CSR/Ext\n"
    body += " o_mdu_op <= MDU & (i_opcode == 5'b01100) & i_imm25;\n"
    body += " o_ext_funct3 <= MDU ? i_funct3 : 3'b000;\n"
    body += " o_ebreak <= CSR & (i_op20);\n"
    body += " o_rd_csr_en <= CSR & (i_opcode[4] & i_opcode[2] & (|i_funct3));\n"
    body += " o_ctrl_mret <= CSR & (i_opcode[4] & i_opcode[2] & !(|i_funct3) & i_op21);\n"
    body += " o_e_op <= CSR & (i_opcode[4] & i_opcode[2] & !(|i_funct3) & !i_op21);\n"
    body += " o_csr_en <= CSR & (i_op20 | (i_op26 & !i_op21));\n"
    body += " o_csr_mstatus_en <= CSR & (!i_op26 & !i_op22);\n"
    body += " o_csr_mie_en <= CSR & (!i_op26 & i_op22 & !i_op20);\n"
    body += " o_csr_mcause_en <= CSR & ( i_op21 & !i_op20);\n"
    body += " o_csr_source1 <= CSR & (i_funct3[1]);\n"
    body += " o_csr_source0 <= CSR & (i_funct3[0]);\n"
    body += " o_csr_d_sel <= CSR & (i_funct3[2]);\n"
    body += " o_csr_imm_en <= CSR & (i_opcode[4] & i_opcode[2] & i_funct3[2]);\n"
    body += " o_csr_addr1 <= CSR & (i_op26 & i_op20);\n"
    body += " o_csr_addr0 <= CSR & (!i_op26 | i_op21);\n"
    body += "end\n"
    body += "end\n"
    with open('serv_mem_decode.v', 'w') as out:
        out.write(HEADER.format(ports=',\n '.join(ports), body=body.format(msb=width-1, mem=s)))
# Instruction encodings, keyed by mnemonic. Each value is a 9-character
# pattern, most-significant bit first, built from three parts:
#imm30, funct3, opcode
# An 'x' is a don't-care: minterms() expands it to both 0 and 1.
# The insertion order matters: column i of every bitmap in ctrlmap refers
# to the i-th entry of this table, and map2signals()/map2minterms() only
# look at entries 0..36 (lui .. and).
ops = {
    'lui' : 'x' + 'xxx' + '01101',
    'auipc' : 'x' + 'xxx' + '00101',
    'jal' : 'x' + 'xxx' + '11011',
    'jalr' : 'x' + 'xxx' + '11001',#funct3 = 000?
    'beq' : 'x' + '000' + '11000',
    'bne' : 'x' + '001' + '11000',
    'blt' : 'x' + '100' + '11000',
    'bge' : 'x' + '101' + '11000',
    'bltu' : 'x' + '110' + '11000',
    'bgeu' : 'x' + '111' + '11000',
    'lb' : 'x' + '000' + '00000',
    'lh' : 'x' + '001' + '00000',
    'lw' : 'x' + '010' + '00000',
    'lbu' : 'x' + '100' + '00000',
    'lhu' : 'x' + '101' + '00000',
    'sb' : 'x' + '000' + '01000',
    'sh' : 'x' + '001' + '01000',
    'sw' : 'x' + '010' + '01000',
    'addi' : 'x' + '000' + '00100',
    'slti' : 'x' + '010' + '00100',
    'sltiu' : 'x' + '011' + '00100',
    'xori' : 'x' + '100' + '00100',
    'ori' : 'x' + '110' + '00100',
    'andi' : 'x' + '111' + '00100',
    'slli' : '0' + '001' + '00100',
    'srli' : '0' + '101' + '00100',
    'srai' : '1' + '101' + '00100',
    'add' : '0' + '000' + '01100',
    'sub' : '1' + '000' + '01100',
    'sll' : '0' + '001' + '01100',
    'slt' : '0' + '010' + '01100',
    'sltu' : '0' + '011' + '01100',
    'xor' : '0' + '100' + '01100',
    'srl' : '0' + '101' + '01100',
    'sra' : '1' + '101' + '01100',
    'or' : '0' + '110' + '01100',
    'and' : '0' + '111' + '01100',
    # Entries from here on have index > 36 and are skipped by
    # map2signals()/map2minterms().
    'fence' : 'x' + 'xxx' + '00011',#funct3=000?
    'ecall' : 'x' + '000' + '11100',#ebreak same but op20=1
    'csrrw' : 'x' + '001' + '11100',
    'csrrs' : 'x' + '010' + '11100',
    'csrrc' : 'x' + '011' + '11100',
    'csrrwi': 'x' + '101' + '11100',
    'csrrsi': 'x' + '110' + '11100',
    'csrrci': 'x' + '111' + '11100',
}
###################################
###################################
###################################
#Map of all required true/false conditions for each op.
#This should ideally be created automatically from riscv-formal runs
#TODO: Extend with optional ISA extensions (M, Zicsr, Zifencei..)
#ebreak = ecall with op20=1
#
# Each value is a bitmap with one column per op, in the order of the ops
# table: '1' means the signal must be true for that op, '0' that it must be
# false, and ' ' that the value is unconstrained (merge() may fill such a
# column in from another signal).
# NOTE(review): map2minterms() indexes every bitmap at columns 0..36, so each
# row needs at least 37 characters. Several rows as they appear here are
# shorter (e.g. 'immdec_ctrl3'), which looks like runs of spaces were
# collapsed somewhere — verify the rows and the column legend below against
# the upstream file before running the generator.
ctrlmap = \
{
    #UUJRBBBBBBIIIIISSSIIIIIIIIIRRRRRRRRRR
    #lajjbbbbbblllllsssassxoasssassssxssoa
    #uuaaenlglgbhwbhbhwdllornlrrdulllorrrn
    #iillqetete uu dttridlladblttrla d
    # p r uu iiii iiii u
    # c u
    'branch_op' : ' 1111111100000000 00 000 000 00 ',
    'slt_or_branch' : ' 1111111100000000 11 000 011 00 ',
    'op_b_source' : ' 11111111000001110000000001111111111',
    'immdec_ctrl0' : ' 0 11111100000111000000000 ',
    'immdec_ctrl1' : '000 11111111111111111 ',
    'immdec_ctrl2' : '000 11111100000000000000000 ',
    'immdec_ctrl3' : '001 ',
    'immdec_en0' : '0000111111000001110000000000000000000',
    'immdec_en1' : '1110000000000000000000000000000000000',
    'immdec_en2' : '1111000000111110001111111110000000000',
    'immdec_en3' : '1110111111111111111111111110000000000',
    'bne_or_bge' : ' 010101 ',
    'sh_right' : ' 011 0 11 ',
    'cond_branch' : ' 00111111 ',
    'shift_op' : ' 0000000000000000 00 111 100 11 ',
    'two_stage_op' : '0011111111111111110110001110011101100',
    'rd_alu_en' : '0000 00000 1111111111111111111',
    'rd_mem_en' : '0000 11111 0000000000000000000',
    'dbus_en' : ' 0000000011111111 00 000 000 00 ',
    'bufreg_rs1_en' : ' 0100000011111111 111 1 11 ',
    'bufreg_imm_en' : ' 1111111111111111 000 0 00 ',
    'bufreg_clr_lsb' : ' 1011111100000000 000 0 00 ',
    'bufreg_sh_signed': ' 01 01 ',
    'ctrl_jal_or_jalr': '0011 00000 0000000000000000000',
    'ctrl_utype' : '110000000000000 0000000000000000000',
    'ctrl_pc_rel' : '0110111111 ',
    'rd_op' : '1111000000111110001111111111111111111',
    'alu_sub' : ' 111111 011 01 11 ',
    'alu_bool_op1' : ' 011000 0 00011',
    'alu_bool_op0' : ' 001111 1 01101',
    'alu_cmp_eq' : ' 110000 00 00 ',
    'alu_cmp_sig' : ' 1100 10 10 ',
    'alu_rd_sel0' : ' 1000000001100000000',
    'alu_rd_sel1' : ' 0110000000001100000',
    'alu_rd_sel2' : ' 0001110000000010011',
    'mem_signed' : ' 11 00 ',
    'mem_word' : ' 00100001 ',
    'mem_half' : ' 01001010 ',
    'mem_cmd' : ' 00000111 ',
    # 'mtval_pc' : ' 1111111100000000 ',
}
# Script body, executed on import/run: print the control map as written,
# merge the signals that never conflict, print the reduced map, and then
# generate the decoder.
print(printmap(ctrlmap))
print("\nMerging control signals")
# merge_signals() modifies ctrlmap in place and also returns it.
(ctrlmap, merged_signals) = merge_signals(ctrlmap)
print(printmap(ctrlmap))
#Create the various decoders
print("Creating mem decoder")
write_mem_decoder(ctrlmap, merged_signals)
# The two logic-based decoder variants below are currently disabled.
#print("Writing post-registered logic decoder")
#write_post_reg_logic_decoder(ctrlmap, merged_signals)
#
#print("Writing pre-registered logic decoder")
#write_pre_reg_logic_decoder(ctrlmap, merged_signals)