import operator import os import csv import operator from enum import Enum, auto from typing import List, Set, ClassVar, Any from dataclasses import dataclass, field from ansi.color import fg from cxxheaderparser.parser import CxxParser # 'Fixing' complaints about typedefs CxxParser._fundamentals.discard("wchar_t") from cxxheaderparser.types import ( EnumDecl, Field, ForwardDecl, FriendDecl, Function, Method, Typedef, UsingAlias, UsingDecl, Variable, Pointer, Type, PQName, NameSpecifier, FundamentalSpecifier, Parameter, Array, Value, Token, FunctionType, ) from cxxheaderparser.parserstate import ( State, EmptyBlockState, ClassBlockState, ExternBlockState, NamespaceBlockState, ) @dataclass(frozen=True) class ApiEntryFunction: name: str returns: str params: str csv_type: ClassVar[str] = "Function" def dictify(self): return dict(name=self.name, type=self.returns, params=self.params) @dataclass(frozen=True) class ApiEntryVariable: name: str var_type: str csv_type: ClassVar[str] = "Variable" def dictify(self): return dict(name=self.name, type=self.var_type, params=None) @dataclass(frozen=True) class ApiHeader: name: str csv_type: ClassVar[str] = "Header" def dictify(self): return dict(name=self.name, type=None, params=None) @dataclass class ApiEntries: # These are sets, to avoid creating duplicates when we have multiple # declarations with same signature functions: Set[ApiEntryFunction] = field(default_factory=set) variables: Set[ApiEntryVariable] = field(default_factory=set) headers: Set[ApiHeader] = field(default_factory=set) class SymbolManager: def __init__(self): self.api = ApiEntries() self.name_hashes = set() # Calculate hash of name and raise exception if it already is in the set def _name_check(self, name: str): name_hash = gnu_sym_hash(name) if name_hash in self.name_hashes: raise Exception(f"Hash collision on {name}") self.name_hashes.add(name_hash) def add_function(self, function_def: ApiEntryFunction): if function_def in self.api.functions: return self._name_check(function_def.name) self.api.functions.add(function_def) def add_variable(self, variable_def: ApiEntryVariable): if variable_def in self.api.variables: return self._name_check(variable_def.name) self.api.variables.add(variable_def) def add_header(self, header: str): self.api.headers.add(ApiHeader(header)) def gnu_sym_hash(name: str): h = 0x1505 for c in name: h = (h << 5) + h + ord(c) return str(hex(h))[-8:] class SdkCollector: def __init__(self): self.symbol_manager = SymbolManager() def add_header_to_sdk(self, header: str): self.symbol_manager.add_header(header) def process_source_file_for_sdk(self, file_path: str): visitor = SdkCxxVisitor(self.symbol_manager) with open(file_path, "rt") as f: content = f.read() parser = CxxParser(file_path, content, visitor, None) parser.parse() def get_api(self): return self.symbol_manager.api def stringify_array_dimension(size_descr): if not size_descr: return "" return stringify_descr(size_descr) def stringify_array_descr(type_descr): assert isinstance(type_descr, Array) return ( stringify_descr(type_descr.array_of), stringify_array_dimension(type_descr.size), ) def stringify_descr(type_descr): if isinstance(type_descr, (NameSpecifier, FundamentalSpecifier)): return type_descr.name elif isinstance(type_descr, PQName): return "::".join(map(stringify_descr, type_descr.segments)) elif isinstance(type_descr, Pointer): # Hack if isinstance(type_descr.ptr_to, FunctionType): return stringify_descr(type_descr.ptr_to) return f"{stringify_descr(type_descr.ptr_to)}*" elif isinstance(type_descr, Type): return ( f"{'const ' if type_descr.const else ''}" f"{'volatile ' if type_descr.volatile else ''}" f"{stringify_descr(type_descr.typename)}" ) elif isinstance(type_descr, Parameter): return stringify_descr(type_descr.type) elif isinstance(type_descr, Array): # Hack for 2d arrays if isinstance(type_descr.array_of, Array): argtype, dimension = stringify_array_descr(type_descr.array_of) return ( f"{argtype}[{stringify_array_dimension(type_descr.size)}][{dimension}]" ) return f"{stringify_descr(type_descr.array_of)}[{stringify_array_dimension(type_descr.size)}]" elif isinstance(type_descr, Value): return " ".join(map(stringify_descr, type_descr.tokens)) elif isinstance(type_descr, FunctionType): return f"{stringify_descr(type_descr.return_type)} (*)({', '.join(map(stringify_descr, type_descr.parameters))})" elif isinstance(type_descr, Token): return type_descr.value elif type_descr is None: return "" else: raise Exception("unsupported type_descr: %s" % type_descr) class SdkCxxVisitor: def __init__(self, symbol_manager: SymbolManager): self.api = symbol_manager def on_variable(self, state: State, v: Variable) -> None: if not v.extern: return self.api.add_variable( ApiEntryVariable( stringify_descr(v.name), stringify_descr(v.type), ) ) def on_function(self, state: State, fn: Function) -> None: if fn.inline or fn.has_body: return self.api.add_function( ApiEntryFunction( stringify_descr(fn.name), stringify_descr(fn.return_type), ", ".join(map(stringify_descr, fn.parameters)) + (", ..." if fn.vararg else ""), ) ) def on_define(self, state: State, content: str) -> None: pass def on_pragma(self, state: State, content: str) -> None: pass def on_include(self, state: State, filename: str) -> None: pass def on_empty_block_start(self, state: EmptyBlockState) -> None: pass def on_empty_block_end(self, state: EmptyBlockState) -> None: pass def on_extern_block_start(self, state: ExternBlockState) -> None: pass def on_extern_block_end(self, state: ExternBlockState) -> None: pass def on_namespace_start(self, state: NamespaceBlockState) -> None: pass def on_namespace_end(self, state: NamespaceBlockState) -> None: pass def on_forward_decl(self, state: State, fdecl: ForwardDecl) -> None: pass def on_typedef(self, state: State, typedef: Typedef) -> None: pass def on_using_namespace(self, state: State, namespace: List[str]) -> None: pass def on_using_alias(self, state: State, using: UsingAlias) -> None: pass def on_using_declaration(self, state: State, using: UsingDecl) -> None: pass def on_enum(self, state: State, enum: EnumDecl) -> None: pass def on_class_start(self, state: ClassBlockState) -> None: pass def on_class_field(self, state: State, f: Field) -> None: pass def on_class_method(self, state: ClassBlockState, method: Method) -> None: pass def on_class_friend(self, state: ClassBlockState, friend: FriendDecl) -> None: pass def on_class_end(self, state: ClassBlockState) -> None: pass @dataclass(frozen=True) class SdkVersion: major: int = 0 minor: int = 0 csv_type: ClassVar[str] = "Version" def __str__(self) -> str: return f"{self.major}.{self.minor}" def as_int(self) -> int: return ((self.major & 0xFFFF) << 16) | (self.minor & 0xFFFF) @staticmethod def from_str(s: str) -> "SdkVersion": major, minor = s.split(".") return SdkVersion(int(major), int(minor)) def dictify(self) -> dict: return dict(name=str(self), type=None, params=None) class VersionBump(Enum): NONE = auto() MAJOR = auto() MINOR = auto() class ApiEntryState(Enum): PENDING = "?" APPROVED = "+" DISABLED = "-" # Special value for API version entry so users have less incentive to edit it VERSION_PENDING = "v" # Class that stores all known API entries, both enabled and disabled. # Also keeps track of API versioning # Allows comparison and update from newly-generated API class SdkCache: CSV_FIELD_NAMES = ("entry", "status", "name", "type", "params") def __init__(self, cache_file: str, load_version_only=False): self.cache_file_name = cache_file self.version = SdkVersion(0, 0) self.sdk = ApiEntries() self.disabled_entries = set() self.new_entries = set() self.loaded_dirty_version = False self.version_action = VersionBump.NONE self._load_version_only = load_version_only self.load_cache() def is_buildable(self) -> bool: return ( self.version != SdkVersion(0, 0) and self.version_action == VersionBump.NONE and not self._have_pending_entries() ) def _filter_enabled(self, sdk_entries): return sorted( filter(lambda e: e not in self.disabled_entries, sdk_entries), key=operator.attrgetter("name"), ) def get_valid_names(self): syms = set(map(lambda e: e.name, self.get_functions())) syms.update(map(lambda e: e.name, self.get_variables())) return syms def get_functions(self): return self._filter_enabled(self.sdk.functions) def get_variables(self): return self._filter_enabled(self.sdk.variables) def get_headers(self): return self._filter_enabled(self.sdk.headers) def _get_entry_status(self, entry) -> str: if entry in self.disabled_entries: return ApiEntryState.DISABLED elif entry in self.new_entries: if isinstance(entry, SdkVersion): return ApiEntryState.VERSION_PENDING return ApiEntryState.PENDING else: return ApiEntryState.APPROVED def _format_entry(self, obj): obj_dict = obj.dictify() obj_dict.update( dict( entry=obj.csv_type, status=self._get_entry_status(obj).value, ) ) return obj_dict def save(self) -> None: if self._load_version_only: raise Exception("Only SDK version was loaded, cannot save") if self.version_action == VersionBump.MINOR: self.version = SdkVersion(self.version.major, self.version.minor + 1) elif self.version_action == VersionBump.MAJOR: self.version = SdkVersion(self.version.major + 1, 0) if self._have_pending_entries(): self.new_entries.add(self.version) print( fg.red( f"API version is still WIP: {self.version}. Review the changes and re-run command." ) ) print(f"CSV file entries to mark up:") print( fg.yellow( "\n".join( map( str, filter( lambda e: not isinstance(e, SdkVersion), self.new_entries, ), ) ) ) ) else: print(fg.green(f"API version {self.version} is up to date")) regenerate_csv = ( self.loaded_dirty_version or self._have_pending_entries() or self.version_action != VersionBump.NONE ) if regenerate_csv: str_cache_entries = [self.version] name_getter = operator.attrgetter("name") str_cache_entries.extend(sorted(self.sdk.headers, key=name_getter)) str_cache_entries.extend(sorted(self.sdk.functions, key=name_getter)) str_cache_entries.extend(sorted(self.sdk.variables, key=name_getter)) with open(self.cache_file_name, "wt", newline="") as f: writer = csv.DictWriter(f, fieldnames=SdkCache.CSV_FIELD_NAMES) writer.writeheader() for entry in str_cache_entries: writer.writerow(self._format_entry(entry)) def _process_entry(self, entry_dict: dict) -> None: entry_class = entry_dict["entry"] entry_status = entry_dict["status"] entry_name = entry_dict["name"] entry = None if entry_class == SdkVersion.csv_type: self.version = SdkVersion.from_str(entry_name) if entry_status == ApiEntryState.VERSION_PENDING.value: self.loaded_dirty_version = True elif entry_class == ApiHeader.csv_type: self.sdk.headers.add(entry := ApiHeader(entry_name)) elif entry_class == ApiEntryFunction.csv_type: self.sdk.functions.add( entry := ApiEntryFunction( entry_name, entry_dict["type"], entry_dict["params"], ) ) elif entry_class == ApiEntryVariable.csv_type: self.sdk.variables.add( entry := ApiEntryVariable(entry_name, entry_dict["type"]) ) else: print(entry_dict) raise Exception("Unknown entry type: %s" % entry_class) if entry is None: return if entry_status == ApiEntryState.DISABLED.value: self.disabled_entries.add(entry) elif entry_status == ApiEntryState.PENDING.value: self.new_entries.add(entry) def load_cache(self) -> None: if not os.path.exists(self.cache_file_name): raise Exception( f"Cannot load symbol cache '{self.cache_file_name}'! File does not exist" ) with open(self.cache_file_name, "rt") as f: reader = csv.DictReader(f) for row in reader: self._process_entry(row) if self._load_version_only and row.get("entry") == SdkVersion.csv_type: break def _have_pending_entries(self) -> bool: return any( filter( lambda e: not isinstance(e, SdkVersion), self.new_entries, ) ) def sync_sets( self, known_set: Set[Any], new_set: Set[Any], update_version: bool = True ): new_entries = new_set - known_set if new_entries: print(f"New: {new_entries}") known_set |= new_entries self.new_entries |= new_entries if update_version and self.version_action == VersionBump.NONE: self.version_action = VersionBump.MINOR removed_entries = known_set - new_set if removed_entries: print(f"Removed: {removed_entries}") known_set -= removed_entries # If any of removed entries was a part of active API, that's a major bump if update_version and any( filter( lambda e: e not in self.disabled_entries and e not in self.new_entries, removed_entries, ) ): self.version_action = VersionBump.MAJOR self.disabled_entries -= removed_entries self.new_entries -= removed_entries def validate_api(self, api: ApiEntries) -> None: self.sync_sets(self.sdk.headers, api.headers, False) self.sync_sets(self.sdk.functions, api.functions) self.sync_sets(self.sdk.variables, api.variables)