#!/usr/bin/python

# SPDX-License-Identifier: GPL-2.0
#
# Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
#    Author: Knut Omang <knut.omang@oracle.com>
#
# A script to generate code and header definitions for
# module and kernel symbols not directly exposed to KTF.
# (See the documentation for KTF for details)
#

import os, sys, re, shutil, string

def usage():
    """Print the command line synopsis and terminate with exit status 0."""
    print("Usage: resolve symbolfile outputfile")
    # sys.exit() rather than the bare exit() helper: the latter is injected
    # by the site module and is missing when python runs with -S.
    sys.exit(0)

class ResolveError(Exception):
    """Error raised while interpreting the symbol file.

    The message given to the constructor is kept in the `value` attribute,
    prefixed with "resolve error: ".
    """
    def __init__(self, value):
        prefix = "resolve error: "
        self.value = prefix + value

    def __str__(self):
        # The string form is the repr() of the message, quotes included
        return "%r" % (self.value,)

class FuncInfo:
    """Plain record holding what is known about one function symbol.

    Attributes:
      symbol   the first constructor argument (`sym`)
      re_sym   the second constructor argument
      re_decl  the third constructor argument
    """
    def __init__(self, sym, re_sym, re_decl):
        self.symbol, self.re_sym, self.re_decl = sym, re_sym, re_decl

    def __repr__(self):
        # Only the symbol and its declaration are shown, not re_sym
        return "FuncInfo: [{0}: {1}]".format(self.symbol, self.re_decl)

class Module:
    """The symbols requested from one module and the declarations found.

    Every instance appends itself to the file-level `all_modules` list when
    it is constructed.  Symbols are collected per header file (the header
    that is current when AddSymbol() is called); ParseDefs() then scans
    those headers and fills `func_def` with one FuncInfo per prototype it
    finds.  The write_* methods emit the generated C text.
    """
    def __init__(self, name, header):
        """Create a module called `name` with `header` as current header."""
        self.cur_header = header
        self.prefix = "Z"  # Joined with "_" in front of each symbol name
        self.name = name
        self.symbols = {}  # Input:  headerfile -> symbol list
        self.func_def = []  # FuncInfo list
        self.debug = False  # When True, log() writes to stderr
        all_modules.append(self)

    def log(self, str):
        """Write `str` and a newline to stderr, but only in debug mode."""
        if self.debug:
            sys.stderr.write(str + "\n")

    def SetHeader(self, header):
        """Make `header` the header that subsequent symbols are filed under."""
        self.cur_header = header

    def AddSymbol(self, sym):
        """File `sym` under the current header, creating its list if needed."""
        self.symbols.setdefault(self.cur_header, []).append(sym)

    # Open along include path:
    def Open(self, filename):
        """Open `filename` for reading from the first usable include dir.

        Each entry of the file-level `includepath` list is tried in order.
        Returns the open file object, or None (after writing a message to
        stderr) if the file could not be opened from any entry.
        """
        # A missing header name (symbols listed before any header was set)
        # cannot be opened anywhere; report it like any other failure.
        if filename is not None:
            for p in includepath:
                f = os.path.join(p, filename)
                self.log(" -- trying " + f)
                try:
                    return open(f, 'r')
                except (IOError, OSError):
                    continue
        sys.stderr.write(" ** unable to open \"%s\"\n" % filename)
        return None

    # Parse the relevant header files for function defs:
    def ParseDefs(self):
        """Scan every registered header for prototypes of its symbols.

        For each prototype found, a FuncInfo holding the symbol, its
        prefixed name and a function pointer declaration built from the
        prototype is appended to `func_def`.  A header that cannot be
        opened is skipped; the remaining headers are still processed.
        A warning is printed when the number of prototypes found in a
        header differs from the number of symbols listed for it.
        """
        # Return type part of a prototype: a leading keyword/type name
        # followed by anything that ends in "*" or whitespace.
        types = r"(extern|u8|u16|u32|u64|int|long|size_t|off_t|loff_t|void|struct|union\s)(.*[\*\s]"
        for hf, symbols in self.symbols.items():
            self.log(" ** Parsing %s:" % hf)
            header = self.Open(hf)
            if header is None:
                # Open() has already reported the problem; carry on with
                # the other headers instead of abandoning them all.
                continue
            try:
                content = header.read()
            finally:
                header.close()
            # Then one of the wanted symbols and a parenthesized argument
            # list terminated by ";" at the end of a line.
            funor = ")(" + "|".join(symbols) + r")(\([^\)]+\);)$"
            pattern = re.compile(types + funor, re.MULTILINE)
            s_count = 0
            for m in pattern.finditer(content):
                sym = m.group(3)
                re_name = "_".join([self.prefix, sym])
                # Same prototype, with the name turned into a pointer member
                re_decl = "%s%s(*%s)%s" % (m.group(1), m.group(2), re_name, m.group(4))
                self.func_def.append(FuncInfo(sym, re_name, re_decl))
                s_count += 1

            if s_count != len(symbols):
                print(" ** Warning: File %s: Found %d definitions from %d symbols!" % \
                    (hf, s_count, len(symbols)))
                print(" ** - please check/fix output manually!")

    # Output functions:
    def write_funcs(self, file):
        """Write one tab-indented declaration line per found function."""
        for fi in self.func_def:
            file.write("\t%s\n"% fi.re_decl)

    def write_defines(self, file):
        """Write a #define mapping each symbol to its ktf_syms member."""
        for fi in self.func_def:
            file.write("#define %s ktf_syms.%s\n" % (fi.symbol, fi.re_sym))

    def write_resolve_calls(self, file):
        """Write one ktf_resolve_symbol(module, symbol) call per function."""
        for fi in self.func_def:
            file.write("\tktf_resolve_symbol(%s, %s);\n" % (self.name, fi.symbol))

# Command line state, filled in by the scan of sys.argv below:
usage_h = False     # Becomes True when -h is given
my_argv = []        # Arguments that are not options, in their given order
includepath = [""]  # Header search dirs; -I options are appended to this

#sys.stderr.write("%s\n" % (sys.argv))
for arg in sys.argv[1:]:
    if arg == "-h":
        usage_h = True
    else:
        incl = re.match(r"-I([^\s]+)", arg)
        if incl is not None:
            includepath.append(incl.group(1))
        else:
            genopt = re.match(r"-([^\s]+)", arg)
            # Silently ignore other cc options as we accept cc-flags-y;
            # only what does not look like an option is kept.
            if genopt is None:
                my_argv.append(arg)

# Main program:
#
# Reads the symbol file, has each module look up its symbols in the named
# header files, then writes the generated C header to the output file.
# Exits with status 1 if either file cannot be opened or if the symbol
# file contains an unknown directive.

if len(my_argv) != 2 or usage_h:
    usage()

symfile = my_argv[0]
outputfile = my_argv[1]

all_modules = []
module = Module("kernel", None)  # Default, at the top of the file is the main kernel symbols
header = None  # A header directive is required before any symbols

try:
    with open(symfile, 'r') as symfh:
        for line in symfh:
            # Directive lines look like "#module <name>" or "#header <file>"
            match = re.match(r"^#(\w+) ([\w\.]+)\s*$", line)
            if match is not None:
                cmd = match.group(1)
                value = match.group(2)
                if cmd == "module":
                    # A new module starts out with the most recent header
                    module = Module(value, header)
                elif cmd == "header":
                    header = value
                    module.SetHeader(header)
                else:
                    raise ResolveError("While parsing %s: Unknown directive \"%s\"" % (symfile, cmd))
                continue
            # Any other line contributes its first word as a symbol
            match = re.match(r"\s*(\w+)\s*", line)
            if match is not None:
                module.AddSymbol(match.group(1))
except ResolveError as e:
    print(e.value)
    sys.exit(1)
except (IOError, OSError):
    # Only I/O failures are reported this way; any other exception is a
    # programming error and is allowed to surface with its traceback.
    print("Unable to open config file \"%s\"" % symfile)
    sys.exit(1)

for m in all_modules:
    m.ParseDefs()

try:
    output = open(outputfile, "w")
except (IOError, OSError):
    print("Failed to open output header file \"%s\"" % outputfile)
    # Without an output file there is nothing more to do
    sys.exit(1)

with output:
    output.write('#ifndef _KTF_RESOLVE_H\n#define _KTF_RESOLVE_H\n')
    output.write('#include "ktf.h"\n\nstruct ktf_syms {\n')

    # One function pointer member per resolved symbol
    for m in all_modules:
        m.write_funcs(output)

    output.write('};\n\n')

    # Redirect each symbol name to its struct member
    for m in all_modules:
        m.write_defines(output)

    output.write('\n\nextern struct ktf_syms ktf_syms;\n')
    output.write('\nint ktf_resolve_symbols(void);\n#ifndef KTF_CLIENT\n')
    output.write('struct ktf_syms ktf_syms;\n\n')

    output.write('\nint ktf_resolve_symbols(void)\n{\n')

    for m in all_modules:
        m.write_resolve_calls(output)

    output.write('\treturn 0;\n}\n\n#endif /* !KTF_CLIENT */\n#endif /* _KTF_RESOLVE_H */ \n')
