fix: refactor some more for tests + pylint
This commit is contained in:
parent
e494bdd2b9
commit
8c3948ff33
3 changed files with 36 additions and 11 deletions
|
@ -29,6 +29,7 @@ import os
|
|||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import orm
|
||||
from sqlalchemy_utils.functions import get_primary_keys, get_columns
|
||||
|
||||
|
@ -188,7 +189,10 @@ class Importer: # pylint: disable=too-many-instance-attributes,too-many-public-
|
|||
max_delete = None
|
||||
max_total = None
|
||||
|
||||
def __init__(self, config, handler=None, model_class=None, **kwargs):
|
||||
handler = None
|
||||
model_class = None
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
self.config = config
|
||||
self.app = self.config.get_app()
|
||||
|
||||
|
@ -202,8 +206,6 @@ class Importer: # pylint: disable=too-many-instance-attributes,too-many-public-
|
|||
"delete", kwargs.pop("allow_delete", self.allow_delete)
|
||||
)
|
||||
|
||||
self.handler = handler
|
||||
self.model_class = model_class
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
self.fields = self.get_fields()
|
||||
|
@ -272,7 +274,10 @@ class Importer: # pylint: disable=too-many-instance-attributes,too-many-public-
|
|||
if hasattr(self, "simple_fields"):
|
||||
return self.simple_fields
|
||||
|
||||
try:
|
||||
fields = get_columns(self.model_class)
|
||||
except sa.exc.NoInspectionAvailable:
|
||||
return []
|
||||
return list(fields.keys())
|
||||
|
||||
def get_supported_fields(self):
|
||||
|
@ -1008,9 +1013,7 @@ class Importer: # pylint: disable=too-many-instance-attributes,too-many-public-
|
|||
return None
|
||||
|
||||
obj = self.make_empty_object(key)
|
||||
if obj:
|
||||
return self.update_target_object(obj, source_data)
|
||||
return None
|
||||
|
||||
def make_empty_object(self, key):
|
||||
"""
|
||||
|
@ -1039,7 +1042,9 @@ class Importer: # pylint: disable=too-many-instance-attributes,too-many-public-
|
|||
|
||||
Default logic will make a new instance of :attr:`model_class`.
|
||||
"""
|
||||
return self.model_class()
|
||||
if callable(self.model_class):
|
||||
return self.model_class() # pylint: disable=not-callable
|
||||
raise AttributeError("model_class is not callable!")
|
||||
|
||||
def update_target_object(self, obj, source_data, target_data=None):
|
||||
"""
|
||||
|
|
|
@ -406,11 +406,18 @@ class TestImporter(DataTestCase):
|
|||
def test_normalize_source_object_all(self):
|
||||
model = self.app.model
|
||||
imp = self.make_importer(model_class=model.Setting)
|
||||
|
||||
# normal
|
||||
setting = model.Setting()
|
||||
result = imp.normalize_source_object_all(setting)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIs(result[0], setting)
|
||||
|
||||
# unwanted (normalized is None)
|
||||
with patch.object(imp, "normalize_source_object", return_value=None):
|
||||
result = imp.normalize_source_object_all(setting)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_normalize_source_object(self):
|
||||
model = self.app.model
|
||||
imp = self.make_importer(model_class=model.Setting)
|
||||
|
@ -532,10 +539,16 @@ class TestImporter(DataTestCase):
|
|||
|
||||
def test_make_object(self):
|
||||
model = self.app.model
|
||||
|
||||
# normal
|
||||
imp = self.make_importer(model_class=model.Setting)
|
||||
obj = imp.make_object()
|
||||
self.assertIsInstance(obj, model.Setting)
|
||||
|
||||
# no model_class
|
||||
imp = self.make_importer()
|
||||
self.assertRaises(AttributeError, imp.make_object)
|
||||
|
||||
def test_update_target_object(self):
|
||||
model = self.app.model
|
||||
imp = self.make_importer(model_class=model.Setting)
|
||||
|
@ -707,13 +720,20 @@ class TestToSqlalchemy(DataTestCase):
|
|||
imp = self.make_importer(model_class=model.Setting, target_session=self.session)
|
||||
setting = model.Setting(name="foo", value="bar")
|
||||
|
||||
# new object is added to session
|
||||
# normal; new object is added to session
|
||||
setting = imp.create_target_object(("foo",), {"name": "foo", "value": "bar"})
|
||||
self.assertIsInstance(setting, model.Setting)
|
||||
self.assertEqual(setting.name, "foo")
|
||||
self.assertEqual(setting.value, "bar")
|
||||
self.assertIn(setting, self.session)
|
||||
|
||||
# unwanted; parent class does not create the object
|
||||
with patch.object(mod.Importer, "create_target_object", return_value=None):
|
||||
setting = imp.create_target_object(
|
||||
("foo",), {"name": "foo", "value": "bar"}
|
||||
)
|
||||
self.assertIsNone(setting)
|
||||
|
||||
def test_delete_target_object(self):
|
||||
model = self.app.model
|
||||
|
||||
|
|
|
@ -82,8 +82,8 @@ foo2,bar2
|
|||
imp.input_file_path = self.data_path
|
||||
imp.open_input_file()
|
||||
imp.close_input_file()
|
||||
self.assertFalse(hasattr(imp, "input_reader"))
|
||||
self.assertFalse(hasattr(imp, "input_file"))
|
||||
self.assertIsNone(imp.input_reader)
|
||||
self.assertIsNone(imp.input_file)
|
||||
|
||||
def test_get_source_objects(self):
|
||||
model = self.app.model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue