diff options
| -rwxr-xr-x | tools/mpy-tool.py | 296 |
1 files changed, 280 insertions, 16 deletions
diff --git a/tools/mpy-tool.py b/tools/mpy-tool.py index 8fab1c969..67e2cbf15 100755 --- a/tools/mpy-tool.py +++ b/tools/mpy-tool.py @@ -24,6 +24,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +import io import struct import sys from binascii import hexlify @@ -302,6 +303,25 @@ class Opcode: MP_BC_POP_JUMP_IF_TRUE, MP_BC_POP_JUMP_IF_FALSE, ) + ALL_OFFSET = ( + MP_BC_UNWIND_JUMP, + MP_BC_JUMP, + MP_BC_POP_JUMP_IF_TRUE, + MP_BC_POP_JUMP_IF_FALSE, + MP_BC_JUMP_IF_TRUE_OR_POP, + MP_BC_JUMP_IF_FALSE_OR_POP, + MP_BC_SETUP_WITH, + MP_BC_SETUP_EXCEPT, + MP_BC_SETUP_FINALLY, + MP_BC_POP_EXCEPT_JUMP, + MP_BC_FOR_ITER, + ) + ALL_WITH_CHILD = ( + MP_BC_MAKE_FUNCTION, + MP_BC_MAKE_FUNCTION_DEFARGS, + MP_BC_MAKE_CLOSURE, + MP_BC_MAKE_CLOSURE_DEFARGS, + ) # Create a dict mapping opcode value to opcode name. mapping = ["unknown" for _ in range(256)] @@ -896,7 +916,7 @@ class RawCode(object): self.escaped_name = unique_escaped_name def disassemble_children(self): - print(" children:", [rc.simple_name.str for rc in self.children]) + self.print_children_annotated() for rc in self.children: rc.disassemble() @@ -985,6 +1005,75 @@ class RawCode(object): raw_code_count += 1 raw_code_content += 4 * 4 + @staticmethod + def decode_lineinfo(line_info: memoryview) -> "tuple[int, int, memoryview]": + c = line_info[0] + if (c & 0x80) == 0: + # 0b0LLBBBBB encoding + return (c & 0x1F), (c >> 5), line_info[1:] + else: + # 0b1LLLBBBB 0bLLLLLLLL encoding (l's LSB in second byte) + return (c & 0xF), (((c << 4) & 0x700) | line_info[1]), line_info[2:] + + def get_source_annotation(self, ip: int, file=None) -> dict: + bc_offset = ip - self.offset_opcodes + try: + line_info = memoryview(self.fun_data)[self.offset_line_info : self.offset_opcodes] + except AttributeError: + return {"file": file, "line": None} + + source_line = 1 + while line_info: + bc_increment, line_increment, line_info = self.decode_lineinfo(line_info) + if bc_offset >= bc_increment: + bc_offset -= bc_increment + source_line += line_increment + else: + break + + return {"file": file, "line": source_line} + + def get_label(self, ip: "int | None" = None, child_num: "int | None" = None) -> str: + if ip is not None: + assert child_num is None + return "%s.%d" % (self.escaped_name, ip) + elif child_num is not None: + return "%s.child%d" % (self.escaped_name, child_num) + else: + return "%s" % self.escaped_name + + def print_children_annotated(self) -> None: + """ + Equivalent to `print(" children:", [child.simple_name.str for child in self.children])`, + but also includes json markers for the start and end of each one's name in that line. + """ + + labels = ["%s.children" % self.escaped_name] + annotation_labels = [] + output = io.StringIO() + output.write(" children: [") + sep = ", " + for i, child in enumerate(self.children): + if i != 0: + output.write(sep) + start_col = output.tell() + 1 + output.write(child.simple_name.str) + end_col = output.tell() + 1 + labels.append(self.get_label(child_num=i)) + annotation_labels.append( + { + "name": self.get_label(child_num=i), + "target": child.get_label(), + "range": { + "startCol": start_col, + "endCol": end_col, + }, + }, + ) + output.write("]") + + print(output.getvalue(), annotations={"labels": annotation_labels}, labels=labels) + class RawCodeBytecode(RawCode): def __init__(self, parent_name, qstr_table, obj_table, fun_data): @@ -993,9 +1082,58 @@ class RawCodeBytecode(RawCode): parent_name, qstr_table, fun_data, 0, MP_CODE_BYTECODE ) + def get_opcode_annotations_labels( + self, opcode: int, ip: int, arg: int, sz: int, arg_pos: int, arg_len: int + ) -> "tuple[dict, list[str]]": + annotations = { + "source": self.get_source_annotation(ip), + "disassembly": Opcode.mapping[opcode], + } + labels = [self.get_label(ip)] + + if opcode in Opcode.ALL_OFFSET: + annotations["link"] = { + "offset": arg_pos, + "length": arg_len, + "to": ip + arg + sz, + } + annotations["labels"] = [ + { + "name": self.get_label(ip), + "target": self.get_label(ip + arg + sz), + "range": { + "startCol": arg_pos + 1, + "endCol": arg_pos + arg_len + 1, + }, + }, + ] + + elif opcode in Opcode.ALL_WITH_CHILD: + try: + child = self.children[arg] + except IndexError: + # link out-of-range child to the child array itself + target = "%s.children" % self.escaped_name + else: + # link resolvable child to the actual child + target = child.get_label() + + annotations["labels"] = [ + { + "name": self.get_label(ip), + "target": target, + "range": { + "startCol": arg_pos + 1, + "endCol": arg_pos + arg_len + 1, + }, + }, + ] + + return annotations, labels + def disassemble(self): bc = self.fun_data - print("simple_name:", self.simple_name.str) + print("simple_name:", self.simple_name.str, labels=[self.get_label()]) print(" raw bytecode:", len(bc), hexlify_to_str(bc)) print(" prelude:", self.prelude_signature) print(" args:", [self.qstr_table[i].str for i in self.names[1:]]) @@ -1011,9 +1149,22 @@ class RawCodeBytecode(RawCode): pass else: arg = "" - print( - " %-11s %s %s" % (hexlify_to_str(bc[ip : ip + sz]), Opcode.mapping[bc[ip]], arg) + + pre_arg_part = " %-11s %s" % ( + hexlify_to_str(bc[ip : ip + sz]), + Opcode.mapping[bc[ip]], + ) + arg_part = "%s" % arg + annotations, labels = self.get_opcode_annotations_labels( + opcode=bc[ip], + ip=ip, + arg=arg, + sz=sz, + arg_pos=len(pre_arg_part) + 1, + arg_len=len(arg_part), ) + + print(pre_arg_part, arg_part, annotations=annotations, labels=labels) ip += sz self.disassemble_children() @@ -1114,7 +1265,7 @@ class RawCodeNative(RawCode): def disassemble(self): fun_data = self.fun_data - print("simple_name:", self.simple_name.str) + print("simple_name:", self.simple_name.str, labels=[self.get_label()]) print( " raw data:", len(fun_data), @@ -1833,6 +1984,100 @@ def extract_segments(compiled_modules, basename, kinds_arg): output.write(source.read(segment.end - segment.start)) +class PrintShim: + """Base class for interposing extra functionality onto the global `print` method.""" + + def __init__(self): + self.wrapped_print = None + + def __enter__(self): + global print + + if self.wrapped_print is not None: + raise RecursionError + + self.wrapped_print = print + print = self + + return self + + def __exit__(self, exc_type, exc_value, traceback): + global print + + if self.wrapped_print is None: + return + + print = self.wrapped_print + self.wrapped_print = None + + self.on_exit() + + def on_exit(self): + pass + + def __call__(self, *a, **k): + return self.wrapped_print(*a, **k) + + +class PrintIgnoreExtraArgs(PrintShim): + """Just strip the `annotations` and `labels` kwargs and pass down to the underlying print.""" + + def __call__(self, *a, annotations: dict = {}, labels: "list[str]" = (), **k): + return super().__call__(*a, **k) + + +class PrintJson(PrintShim): + """Output lines as godbolt-compatible JSON with extra annotation info from `annotations` and `labels`, rather than plain text.""" + + def __init__(self, fp=sys.stdout, language_id: str = "mpy"): + super().__init__() + self.fp = fp + self.asm = { + "asm": [], + "labelDefinitions": {}, + "languageId": language_id, + } + self.line_number: int = 0 + self.buf: "io.StringIO | None" = None + + def on_exit(self): + import json + + if self.buf is not None: + # flush last partial line + self.__call__() + + json.dump(self.asm, self.fp) + + def __call__(self, *a, annotations: dict = {}, labels: "list[str]" = (), **k): + # ignore prints directed to an explicit output + if "file" in k: + return super().__call__(*a, **k) + + if self.buf is None: + self.buf = io.StringIO() + + super().__call__(*a, file=sys.stderr, **k) + + if "end" in k: + # buffer partial-line prints to collect into a single AsmResultLine + return super().__call__(*a, file=self.buf, **k) + else: + retval = super().__call__(*a, file=self.buf, end="", **k) + output = self.buf.getvalue() + self.buf = None + + asm_line = {"text": output} + asm_line.update(annotations) + self.asm["asm"].append(asm_line) + + self.line_number += 1 + for label in labels: + self.asm["labelDefinitions"][label] = self.line_number + + return retval + + def main(args=None): global global_qstrs @@ -1847,6 +2092,12 @@ def main(args=None): ) cmd_parser.add_argument("-f", "--freeze", action="store_true", help="freeze files") cmd_parser.add_argument( + "-j", + "--json", + action="store_true", + help="output hexdump, disassembly, and frozen code as JSON with extra metadata", + ) + cmd_parser.add_argument( "--merge", action="store_true", help="merge multiple .mpy files into one" ) cmd_parser.add_argument( @@ -1913,20 +2164,33 @@ def main(args=None): print(er, file=sys.stderr) sys.exit(1) - if args.hexdump: - hexdump_mpy(compiled_modules) + if args.json: + if args.freeze: + print_shim = PrintJson(sys.stdout, language_id="c") + elif args.hexdump: + print_shim = PrintJson(sys.stdout, language_id="stderr") + elif args.disassemble: + print_shim = PrintJson(sys.stdout, language_id="mpy") + else: + print_shim = PrintJson(sys.stdout) + else: + print_shim = PrintIgnoreExtraArgs() - if args.disassemble: + with print_shim: if args.hexdump: - print() - disassemble_mpy(compiled_modules) + hexdump_mpy(compiled_modules) - if args.freeze: - try: - freeze_mpy(firmware_qstr_idents, compiled_modules) - except FreezeError as er: - print(er, file=sys.stderr) - sys.exit(1) + if args.disassemble: + if args.hexdump: + print() + disassemble_mpy(compiled_modules) + + if args.freeze: + try: + freeze_mpy(firmware_qstr_idents, compiled_modules) + except FreezeError as er: + print(er, file=sys.stderr) + sys.exit(1) if args.merge: merge_mpy(compiled_modules, args.output) |
