summary | refs | log | tree | commit | diff
path: root/tools/mpy-tool.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/mpy-tool.py')
-rwxr-xr-x tools/mpy-tool.py 296
1 files changed, 280 insertions, 16 deletions
diff --git a/tools/mpy-tool.py b/tools/mpy-tool.py
index 8fab1c969..67e2cbf15 100755
--- a/tools/mpy-tool.py
+++ b/tools/mpy-tool.py
@@ -24,6 +24,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
+import io
import struct
import sys
from binascii import hexlify
@@ -302,6 +303,25 @@ class Opcode:
MP_BC_POP_JUMP_IF_TRUE,
MP_BC_POP_JUMP_IF_FALSE,
)
+ # Opcodes for which the disassembler emits a "link" annotation: the decoded
+ # argument is added to the instruction address and size (ip + arg + sz) to
+ # form the label the argument text points at.
+ ALL_OFFSET = (
+ MP_BC_UNWIND_JUMP,
+ MP_BC_JUMP,
+ MP_BC_POP_JUMP_IF_TRUE,
+ MP_BC_POP_JUMP_IF_FALSE,
+ MP_BC_JUMP_IF_TRUE_OR_POP,
+ MP_BC_JUMP_IF_FALSE_OR_POP,
+ MP_BC_SETUP_WITH,
+ MP_BC_SETUP_EXCEPT,
+ MP_BC_SETUP_FINALLY,
+ MP_BC_POP_EXCEPT_JUMP,
+ MP_BC_FOR_ITER,
+ )
+ # Opcodes whose argument the disassembler uses as an index into the raw
+ # code's `children` list, so the argument text can link to that child.
+ ALL_WITH_CHILD = (
+ MP_BC_MAKE_FUNCTION,
+ MP_BC_MAKE_FUNCTION_DEFARGS,
+ MP_BC_MAKE_CLOSURE,
+ MP_BC_MAKE_CLOSURE_DEFARGS,
+ )
# Create a dict mapping opcode value to opcode name.
mapping = ["unknown" for _ in range(256)]
@@ -896,7 +916,7 @@ class RawCode(object):
self.escaped_name = unique_escaped_name
def disassemble_children(self):
- print(" children:", [rc.simple_name.str for rc in self.children])
+ self.print_children_annotated()
for rc in self.children:
rc.disassemble()
@@ -985,6 +1005,75 @@ class RawCode(object):
raw_code_count += 1
raw_code_content += 4 * 4
+    @staticmethod
+    def decode_lineinfo(line_info: memoryview) -> "tuple[int, int, memoryview]":
+        """Decode the entry at the front of a line-info table.
+
+        Returns `(bytecode_increment, line_increment, remaining)`, where
+        `remaining` is `line_info` with the decoded entry sliced off.
+        """
+        first = line_info[0]
+        if first & 0x80:
+            # Two-byte form 0b1LLLBBBB 0bLLLLLLLL: the top three line bits live
+            # in the first byte, the low eight bits in the second byte.
+            line_high = (first << 4) & 0x700
+            return first & 0xF, line_high | line_info[1], line_info[2:]
+        # One-byte form 0b0LLBBBBB.
+        return first & 0x1F, first >> 5, line_info[1:]
+
+    def get_source_annotation(self, ip: int, file=None) -> dict:
+        """Return a `{"file": ..., "line": ...}` dict for the instruction at `ip`.
+
+        The line is found by walking the line-info table stored in `fun_data`
+        between `offset_line_info` and `offset_opcodes`, starting from line 1.
+        If a needed attribute is missing while slicing the table, the line is
+        reported as None.
+        """
+        remaining = ip - self.offset_opcodes
+        try:
+            table = memoryview(self.fun_data)[self.offset_line_info : self.offset_opcodes]
+        except AttributeError:
+            return {"file": file, "line": None}
+
+        line = 1
+        while table:
+            step_bytes, step_lines, table = self.decode_lineinfo(table)
+            if remaining < step_bytes:
+                # This entry reaches past the requested offset; stop here.
+                break
+            remaining -= step_bytes
+            line += step_lines
+
+        return {"file": file, "line": line}
+
+    def get_label(self, ip: "int | None" = None, child_num: "int | None" = None) -> str:
+        """Build a label name derived from this raw code's `escaped_name`.
+
+        With `ip` the label is `<name>.<ip>`, with `child_num` it is
+        `<name>.child<n>`, and with neither it is the bare name. The two
+        arguments are not meant to be combined.
+        """
+        if ip is None:
+            if child_num is None:
+                return "%s" % self.escaped_name
+            return "%s.child%d" % (self.escaped_name, child_num)
+
+        assert child_num is None
+        return "%s.%d" % (self.escaped_name, ip)
+
+    def print_children_annotated(self) -> None:
+        """
+        Print the ` children: [...]` line listing each child's simple name, and attach
+        a label annotation giving the 1-based column range of every name in that line.
+
+        `print` is called with `annotations` and `labels` keyword arguments, so it must
+        have been replaced by a shim that accepts them.
+        """
+
+        line = io.StringIO()
+        line.write(" children: [")
+        line_labels = ["%s.children" % self.escaped_name]
+        links = []
+        for index, child in enumerate(self.children):
+            if index:
+                line.write(", ")
+            # Columns are 1-based; the end column is one past the last character.
+            first_col = line.tell() + 1
+            line.write(child.simple_name.str)
+            last_col = line.tell() + 1
+            own_label = self.get_label(child_num=index)
+            line_labels.append(own_label)
+            links.append(
+                {
+                    "name": own_label,
+                    "target": child.get_label(),
+                    "range": {
+                        "startCol": first_col,
+                        "endCol": last_col,
+                    },
+                },
+            )
+        line.write("]")
+
+        print(line.getvalue(), annotations={"labels": links}, labels=line_labels)
+
class RawCodeBytecode(RawCode):
def __init__(self, parent_name, qstr_table, obj_table, fun_data):
@@ -993,9 +1082,58 @@ class RawCodeBytecode(RawCode):
parent_name, qstr_table, fun_data, 0, MP_CODE_BYTECODE
)
+    def get_opcode_annotations_labels(
+        self, opcode: int, ip: int, arg: int, sz: int, arg_pos: int, arg_len: int
+    ) -> "tuple[dict, list[str]]":
+        """Build the annotation dict and label list for one disassembled instruction.
+
+        Every instruction gets a "source" and a "disassembly" annotation and is
+        labelled with `get_label(ip)`. Opcodes in `Opcode.ALL_OFFSET` also get a
+        "link" entry and a label link to `ip + arg + sz`. Opcodes in
+        `Opcode.ALL_WITH_CHILD` get a label link to the child at index `arg`, or
+        to the `<name>.children` label when that index is out of range.
+        `arg_pos` and `arg_len` locate the argument text within the printed line.
+        """
+        annotations = {
+            "source": self.get_source_annotation(ip),
+            "disassembly": Opcode.mapping[opcode],
+        }
+        own_label = self.get_label(ip)
+
+        def arg_link(target: str) -> "list[dict]":
+            # One label link covering the argument text (1-based columns).
+            return [
+                {
+                    "name": own_label,
+                    "target": target,
+                    "range": {
+                        "startCol": arg_pos + 1,
+                        "endCol": arg_pos + arg_len + 1,
+                    },
+                },
+            ]
+
+        if opcode in Opcode.ALL_OFFSET:
+            destination = ip + arg + sz
+            annotations["link"] = {
+                "offset": arg_pos,
+                "length": arg_len,
+                "to": destination,
+            }
+            annotations["labels"] = arg_link(self.get_label(destination))
+
+        elif opcode in Opcode.ALL_WITH_CHILD:
+            try:
+                child = self.children[arg]
+            except IndexError:
+                # link out-of-range child to the child array itself
+                target = "%s.children" % self.escaped_name
+            else:
+                # link resolvable child to the actual child
+                target = child.get_label()
+            annotations["labels"] = arg_link(target)
+
+        return annotations, [own_label]
+
def disassemble(self):
bc = self.fun_data
- print("simple_name:", self.simple_name.str)
+ print("simple_name:", self.simple_name.str, labels=[self.get_label()])
print(" raw bytecode:", len(bc), hexlify_to_str(bc))
print(" prelude:", self.prelude_signature)
print(" args:", [self.qstr_table[i].str for i in self.names[1:]])
@@ -1011,9 +1149,22 @@ class RawCodeBytecode(RawCode):
pass
else:
arg = ""
- print(
- " %-11s %s %s" % (hexlify_to_str(bc[ip : ip + sz]), Opcode.mapping[bc[ip]], arg)
+
+ pre_arg_part = " %-11s %s" % (
+ hexlify_to_str(bc[ip : ip + sz]),
+ Opcode.mapping[bc[ip]],
+ )
+ arg_part = "%s" % arg
+ annotations, labels = self.get_opcode_annotations_labels(
+ opcode=bc[ip],
+ ip=ip,
+ arg=arg,
+ sz=sz,
+ arg_pos=len(pre_arg_part) + 1,
+ arg_len=len(arg_part),
)
+
+ print(pre_arg_part, arg_part, annotations=annotations, labels=labels)
ip += sz
self.disassemble_children()
@@ -1114,7 +1265,7 @@ class RawCodeNative(RawCode):
def disassemble(self):
fun_data = self.fun_data
- print("simple_name:", self.simple_name.str)
+ print("simple_name:", self.simple_name.str, labels=[self.get_label()])
print(
" raw data:",
len(fun_data),
@@ -1833,6 +1984,100 @@ def extract_segments(compiled_modules, basename, kinds_arg):
output.write(source.read(segment.end - segment.start))
+class PrintShim:
+    """Base class for interposing extra functionality onto the global `print` function.
+
+    Used as a context manager: entering rebinds the module-global `print` to this
+    object and remembers the previous value in `wrapped_print`; exiting restores
+    it and then runs the `on_exit` hook. Calls made through the shim are forwarded
+    to the wrapped print unless a subclass overrides `__call__`.
+    """
+
+    def __init__(self):
+        # The print callable that was active before this shim was entered,
+        # or None while the shim is not installed.
+        self.wrapped_print = None
+
+    def __enter__(self):
+        global print
+
+        if self.wrapped_print is not None:
+            # A shim instance cannot be nested inside itself.
+            raise RecursionError("PrintShim is already active")
+
+        self.wrapped_print = print
+        print = self
+
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        global print
+
+        if self.wrapped_print is None:
+            return
+
+        print = self.wrapped_print
+        try:
+            # Run the hook while `wrapped_print` is still set: hooks such as
+            # PrintJson.on_exit flush a pending line by calling the shim, which
+            # forwards to `wrapped_print` and would otherwise call None.
+            self.on_exit()
+        finally:
+            self.wrapped_print = None
+
+    def on_exit(self):
+        """Hook run once when the shim is uninstalled; does nothing by default."""
+        pass
+
+    def __call__(self, *a, **k):
+        return self.wrapped_print(*a, **k)
+
+
+class PrintIgnoreExtraArgs(PrintShim):
+    """Shim that drops the `annotations` and `labels` keyword arguments and prints the rest."""
+
+    def __call__(self, *a, annotations: dict = {}, labels: "list[str]" = (), **k):
+        # `annotations` and `labels` are named only so they are captured here
+        # and never reach the wrapped print.
+        forward = super().__call__
+        return forward(*a, **k)
+
+
+class PrintJson(PrintShim):
+    """Collect printed lines into godbolt-compatible JSON instead of emitting plain text.
+
+    Each completed line becomes an entry in `asm["asm"]`, merged with the
+    `annotations` passed to that print call; each name in `labels` is recorded
+    in `asm["labelDefinitions"]` with the 1-based number of that line. The
+    document is written to `fp` when the shim exits.
+    """
+
+    def __init__(self, fp=sys.stdout, language_id: str = "mpy"):
+        super().__init__()
+        self.fp = fp
+        self.asm = {
+            "asm": [],
+            "labelDefinitions": {},
+            "languageId": language_id,
+        }
+        # Number of completed lines recorded so far.
+        self.line_number: int = 0
+        # Text of the line currently being assembled, or None between lines.
+        self.buf: "io.StringIO | None" = None
+
+    def on_exit(self):
+        import json
+
+        if self.buf is not None:
+            # Flush the last partial line. This goes through the wrapped
+            # print, so it only works while `wrapped_print` is still set.
+            self.__call__()
+
+        json.dump(self.asm, self.fp)
+
+    def __call__(self, *a, annotations: dict = {}, labels: "list[str]" = (), **k):
+        forward = super().__call__
+
+        # Prints directed to an explicit output bypass the JSON collection.
+        if "file" in k:
+            return forward(*a, **k)
+
+        if self.buf is None:
+            self.buf = io.StringIO()
+
+        # Echo the text to stderr as well as collecting it.
+        forward(*a, file=sys.stderr, **k)
+
+        if "end" in k:
+            # An explicit `end` marks a partial line: keep buffering so the
+            # pieces are collected into a single AsmResultLine.
+            return forward(*a, file=self.buf, **k)
+
+        result = forward(*a, file=self.buf, end="", **k)
+        text = self.buf.getvalue()
+        self.buf = None
+
+        entry = {"text": text}
+        entry.update(annotations)
+        self.asm["asm"].append(entry)
+
+        self.line_number += 1
+        definitions = self.asm["labelDefinitions"]
+        for label in labels:
+            definitions[label] = self.line_number
+
+        return result
+
+
def main(args=None):
global global_qstrs
@@ -1847,6 +2092,12 @@ def main(args=None):
)
cmd_parser.add_argument("-f", "--freeze", action="store_true", help="freeze files")
cmd_parser.add_argument(
+ "-j",
+ "--json",
+ action="store_true",
+ help="output hexdump, disassembly, and frozen code as JSON with extra metadata",
+ )
+ cmd_parser.add_argument(
"--merge", action="store_true", help="merge multiple .mpy files into one"
)
cmd_parser.add_argument(
@@ -1913,20 +2164,33 @@ def main(args=None):
print(er, file=sys.stderr)
sys.exit(1)
- if args.hexdump:
- hexdump_mpy(compiled_modules)
+ if args.json:
+ if args.freeze:
+ print_shim = PrintJson(sys.stdout, language_id="c")
+ elif args.hexdump:
+ print_shim = PrintJson(sys.stdout, language_id="stderr")
+ elif args.disassemble:
+ print_shim = PrintJson(sys.stdout, language_id="mpy")
+ else:
+ print_shim = PrintJson(sys.stdout)
+ else:
+ print_shim = PrintIgnoreExtraArgs()
- if args.disassemble:
+ with print_shim:
if args.hexdump:
- print()
- disassemble_mpy(compiled_modules)
+ hexdump_mpy(compiled_modules)
- if args.freeze:
- try:
- freeze_mpy(firmware_qstr_idents, compiled_modules)
- except FreezeError as er:
- print(er, file=sys.stderr)
- sys.exit(1)
+ if args.disassemble:
+ if args.hexdump:
+ print()
+ disassemble_mpy(compiled_modules)
+
+ if args.freeze:
+ try:
+ freeze_mpy(firmware_qstr_idents, compiled_modules)
+ except FreezeError as er:
+ print(er, file=sys.stderr)
+ sys.exit(1)
if args.merge:
merge_mpy(compiled_modules, args.output)