summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--py/builtinimport.c20
-rw-r--r--tests/cpydiff/core_import_prereg.py18
-rw-r--r--tests/import/broken/pkg2_and_zerodiv.py2
-rw-r--r--tests/import/broken/zerodiv.py1
-rw-r--r--tests/import/circular/main.py4
-rw-r--r--tests/import/circular/sub.py3
-rw-r--r--tests/import/import_broken.py32
-rw-r--r--tests/import/import_circular.py1
8 files changed, 62 insertions, 19 deletions
diff --git a/py/builtinimport.c b/py/builtinimport.c
index de4ea17f3..8827be612 100644
--- a/py/builtinimport.c
+++ b/py/builtinimport.c
@@ -346,6 +346,17 @@ STATIC void evaluate_relative_import(mp_int_t level, const char **module_name, s
*module_name_len = new_module_name_len;
}
+typedef struct _nlr_jump_callback_node_unregister_module_t {
+ nlr_jump_callback_node_t callback;
+ qstr name;
+} nlr_jump_callback_node_unregister_module_t;
+
+STATIC void unregister_module_from_nlr_jump_callback(void *ctx_in) {
+ nlr_jump_callback_node_unregister_module_t *ctx = ctx_in;
+ mp_map_t *mp_loaded_modules_map = &MP_STATE_VM(mp_loaded_modules_dict).map;
+ mp_map_lookup(mp_loaded_modules_map, MP_OBJ_NEW_QSTR(ctx->name), MP_MAP_LOOKUP_REMOVE_IF_FOUND);
+}
+
// Load a module at the specified absolute path, possibly as a submodule of the given outer module.
// full_mod_name: The full absolute path up to this level (e.g. "foo.bar.baz").
// level_mod_name: The final component of the path (e.g. "baz").
@@ -467,8 +478,13 @@ STATIC mp_obj_t process_import_at_level(qstr full_mod_name, qstr level_mod_name,
// Module was found on the filesystem/frozen, try and load it.
DEBUG_printf("Found path to load: %.*s\n", (int)vstr_len(&path), vstr_str(&path));
- // Prepare for loading from the filesystem. Create a new shell module.
+ // Prepare for loading from the filesystem. Create a new shell module
+ // and register it in sys.modules. Also make sure we remove it if
+ // there is any problem below.
module_obj = mp_obj_new_module(full_mod_name);
+ nlr_jump_callback_node_unregister_module_t ctx;
+ ctx.name = full_mod_name;
+ nlr_push_jump_callback(&ctx.callback, unregister_module_from_nlr_jump_callback);
#if MICROPY_MODULE_OVERRIDE_MAIN_IMPORT
// If this module is being loaded via -m on unix, then
@@ -526,6 +542,8 @@ STATIC mp_obj_t process_import_at_level(qstr full_mod_name, qstr level_mod_name,
mp_store_attr(outer_module_obj, level_mod_name, module_obj);
}
+ nlr_pop_jump_callback(false);
+
return module_obj;
}
diff --git a/tests/cpydiff/core_import_prereg.py b/tests/cpydiff/core_import_prereg.py
deleted file mode 100644
index 3ce2340c6..000000000
--- a/tests/cpydiff/core_import_prereg.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""
-categories: Core,import
-description: Failed to load modules are still registered as loaded
-cause: To make module handling more efficient, it's not wrapped with exception handling.
-workaround: Test modules before production use; during development, use ``del sys.modules["name"]``, or just soft or hard reset the board.
-"""
-import sys
-
-try:
- from modules import foo
-except NameError as e:
- print(e)
-try:
- from modules import foo
-
- print("Should not get here")
-except NameError as e:
- print(e)
diff --git a/tests/import/broken/pkg2_and_zerodiv.py b/tests/import/broken/pkg2_and_zerodiv.py
new file mode 100644
index 000000000..3580628ff
--- /dev/null
+++ b/tests/import/broken/pkg2_and_zerodiv.py
@@ -0,0 +1,2 @@
+import pkg2
+import broken.zerodiv
diff --git a/tests/import/broken/zerodiv.py b/tests/import/broken/zerodiv.py
new file mode 100644
index 000000000..72dca4d5e
--- /dev/null
+++ b/tests/import/broken/zerodiv.py
@@ -0,0 +1 @@
+1 / 0
diff --git a/tests/import/circular/main.py b/tests/import/circular/main.py
new file mode 100644
index 000000000..5d63d507c
--- /dev/null
+++ b/tests/import/circular/main.py
@@ -0,0 +1,4 @@
+x = 1
+import circular.sub
+
+print(circular.sub.y)
diff --git a/tests/import/circular/sub.py b/tests/import/circular/sub.py
new file mode 100644
index 000000000..50d7afe07
--- /dev/null
+++ b/tests/import/circular/sub.py
@@ -0,0 +1,3 @@
+from circular.main import x
+
+y = x + 20
diff --git a/tests/import/import_broken.py b/tests/import/import_broken.py
new file mode 100644
index 000000000..3c7cf4a49
--- /dev/null
+++ b/tests/import/import_broken.py
@@ -0,0 +1,32 @@
+import sys, pkg
+
+# Modules we import are usually added to sys.modules.
+print("pkg" in sys.modules)
+
+try:
+ from broken.zerodiv import x
+except Exception as e:
+ print(e.__class__.__name__)
+
+# The broken module we tried to import should not be in sys.modules.
+print("broken.zerodiv" in sys.modules)
+
+# If we try to import the module again, the code should
+# run again and we should get the same error.
+try:
+ from broken.zerodiv import x
+except Exception as e:
+ print(e.__class__.__name__)
+
+# Import a module that successfully imports some other modules
+# before importing the problematic module.
+try:
+ import broken.pkg2_and_zerodiv
+except ZeroDivisionError:
+ pass
+
+print("pkg2" in sys.modules)
+print("pkg2.mod1" in sys.modules)
+print("pkg2.mod2" in sys.modules)
+print("broken.zerodiv" in sys.modules)
+print("broken.pkg2_and_zerodiv" in sys.modules)
diff --git a/tests/import/import_circular.py b/tests/import/import_circular.py
new file mode 100644
index 000000000..388efdd13
--- /dev/null
+++ b/tests/import/import_circular.py
@@ -0,0 +1 @@
+import circular.main