diff --git a/src/wuttasync/importing/base.py b/src/wuttasync/importing/base.py index e8aa523..629ead6 100644 --- a/src/wuttasync/importing/base.py +++ b/src/wuttasync/importing/base.py @@ -29,6 +29,7 @@ import os import logging from collections import OrderedDict +import sqlalchemy as sa from sqlalchemy import orm from sqlalchemy_utils.functions import get_primary_keys, get_columns @@ -188,7 +189,10 @@ class Importer: # pylint: disable=too-many-instance-attributes,too-many-public- max_delete = None max_total = None - def __init__(self, config, handler=None, model_class=None, **kwargs): + handler = None + model_class = None + + def __init__(self, config, **kwargs): self.config = config self.app = self.config.get_app() @@ -202,8 +206,6 @@ class Importer: # pylint: disable=too-many-instance-attributes,too-many-public- "delete", kwargs.pop("allow_delete", self.allow_delete) ) - self.handler = handler - self.model_class = model_class self.__dict__.update(kwargs) self.fields = self.get_fields() @@ -272,7 +274,10 @@ class Importer: # pylint: disable=too-many-instance-attributes,too-many-public- if hasattr(self, "simple_fields"): return self.simple_fields - fields = get_columns(self.model_class) + try: + fields = get_columns(self.model_class) + except sa.exc.NoInspectionAvailable: + return [] return list(fields.keys()) def get_supported_fields(self): @@ -1008,9 +1013,7 @@ class Importer: # pylint: disable=too-many-instance-attributes,too-many-public- return None obj = self.make_empty_object(key) - if obj: - return self.update_target_object(obj, source_data) - return None + return self.update_target_object(obj, source_data) def make_empty_object(self, key): """ @@ -1039,7 +1042,9 @@ class Importer: # pylint: disable=too-many-instance-attributes,too-many-public- Default logic will make a new instance of :attr:`model_class`. """ - return self.model_class() + if callable(self.model_class): + return self.model_class() # pylint: disable=not-callable + raise AttributeError("model_class is not callable!") def update_target_object(self, obj, source_data, target_data=None): """ diff --git a/tests/importing/test_base.py b/tests/importing/test_base.py index 2ec9164..08c37a2 100644 --- a/tests/importing/test_base.py +++ b/tests/importing/test_base.py @@ -406,11 +406,18 @@ class TestImporter(DataTestCase): def test_normalize_source_object_all(self): model = self.app.model imp = self.make_importer(model_class=model.Setting) + + # normal setting = model.Setting() result = imp.normalize_source_object_all(setting) self.assertEqual(len(result), 1) self.assertIs(result[0], setting) + # unwanted (normalized is None) + with patch.object(imp, "normalize_source_object", return_value=None): + result = imp.normalize_source_object_all(setting) + self.assertIsNone(result) + def test_normalize_source_object(self): model = self.app.model imp = self.make_importer(model_class=model.Setting) @@ -532,10 +539,16 @@ class TestImporter(DataTestCase): def test_make_object(self): model = self.app.model + + # normal imp = self.make_importer(model_class=model.Setting) obj = imp.make_object() self.assertIsInstance(obj, model.Setting) + # no model_class + imp = self.make_importer() + self.assertRaises(AttributeError, imp.make_object) + def test_update_target_object(self): model = self.app.model imp = self.make_importer(model_class=model.Setting) @@ -707,13 +720,20 @@ class TestToSqlalchemy(DataTestCase): imp = self.make_importer(model_class=model.Setting, target_session=self.session) setting = model.Setting(name="foo", value="bar") - # new object is added to session + # normal; new object is added to session setting = imp.create_target_object(("foo",), {"name": "foo", "value": "bar"}) self.assertIsInstance(setting, model.Setting) self.assertEqual(setting.name, "foo") self.assertEqual(setting.value, "bar") self.assertIn(setting, self.session) + # unwanted; parent class does not create the object + with patch.object(mod.Importer, "create_target_object", return_value=None): + setting = imp.create_target_object( + ("foo",), {"name": "foo", "value": "bar"} + ) + self.assertIsNone(setting) + def test_delete_target_object(self): model = self.app.model diff --git a/tests/importing/test_csv.py b/tests/importing/test_csv.py index acd5f8e..8544d63 100644 --- a/tests/importing/test_csv.py +++ b/tests/importing/test_csv.py @@ -82,8 +82,8 @@ foo2,bar2 imp.input_file_path = self.data_path imp.open_input_file() imp.close_input_file() - self.assertFalse(hasattr(imp, "input_reader")) - self.assertFalse(hasattr(imp, "input_file")) + self.assertIsNone(imp.input_reader) + self.assertIsNone(imp.input_file) def test_get_source_objects(self): model = self.app.model