fix: implement deletion logic; add cli params for max changes

also add special UUID field handling for CSV -> SQLAlchemy ORM, to
normalize strings from CSV to proper UUID values so key matching works
Lance Edgar 2024-12-06 15:18:23 -06:00
parent a73896b75d
commit 328f8d9952
6 changed files with 735 additions and 115 deletions


@@ -147,15 +147,18 @@ def import_command_template(
     create: Annotated[
         bool,
-        typer.Option(help="Allow new target records to be created.")] = True,
+        typer.Option(help="Allow new target records to be created. "
+                     "See also --max-create.")] = True,
     update: Annotated[
         bool,
-        typer.Option(help="Allow existing target records to be updated.")] = True,
+        typer.Option(help="Allow existing target records to be updated. "
+                     "See also --max-update.")] = True,
     delete: Annotated[
         bool,
-        typer.Option(help="Allow existing target records to be deleted.")] = False,
+        typer.Option(help="Allow existing target records to be deleted. "
+                     "See also --max-delete.")] = False,
     fields: Annotated[
         str,
@@ -170,7 +173,27 @@ def import_command_template(
     keys: Annotated[
         str,
         typer.Option('--key', '--keys',
-                     help="List of fields to use as record key/identifier. See also --fields.")] = None,
+                     help="List of fields to use as record key/identifier. "
+                     "See also --fields.")] = None,
+    max_create: Annotated[
+        int,
+        typer.Option(help="Max number of target records to create (per model). "
+                     "See also --create.")] = None,
+    max_update: Annotated[
+        int,
+        typer.Option(help="Max number of target records to update (per model). "
+                     "See also --update.")] = None,
+    max_delete: Annotated[
+        int,
+        typer.Option(help="Max number of target records to delete (per model). "
+                     "See also --delete.")] = None,
+    max_total: Annotated[
+        int,
+        typer.Option(help="Max number of *any* target record changes which may occur (per model).")] = None,
     dry_run: Annotated[
         bool,


@@ -22,6 +22,22 @@
 ################################################################################
 """
 Data Import / Export Framework
+
+This namespace exposes the following:
+
+* :enum:`~wuttasync.importing.handlers.Orientation`
+
+And for the :term:`import handlers <import handler>`:
+
+* :class:`~wuttasync.importing.handlers.ImportHandler`
+* :class:`~wuttasync.importing.handlers.FromFileHandler`
+* :class:`~wuttasync.importing.handlers.ToSqlalchemyHandler`
+
+And for the :term:`importers <importer>`:
+
+* :class:`~wuttasync.importing.base.Importer`
+* :class:`~wuttasync.importing.base.FromFile`
+* :class:`~wuttasync.importing.base.ToSqlalchemy`
 """

 from .handlers import Orientation, ImportHandler, FromFileHandler, ToSqlalchemyHandler


@@ -26,6 +26,7 @@ Data Importer base class
 import os
 import logging
+from collections import OrderedDict

 from sqlalchemy import orm
 from sqlalchemy_utils.functions import get_primary_keys, get_columns
@@ -36,6 +37,13 @@ from wuttasync.util import data_diffs
 log = logging.getLogger(__name__)


+class ImportLimitReached(Exception):
+    """
+    Exception raised when an import/export job reaches the max number
+    of changes allowed.
+    """
+
+
 class Importer:
     """
     Base class for all data importers / exporters.
@@ -174,6 +182,11 @@ class Importer:
        :meth:`get_target_cache()`.
     """

+    max_create = None
+    max_update = None
+    max_delete = None
+    max_total = None
+
     def __init__(self, config, **kwargs):
         self.config = config
         self.app = self.config.get_app()
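
For illustration only (not part of this diff): since these are plain class
attributes, a subclass could also pin its own limits directly, instead of
relying on the new CLI options.  A minimal sketch, using a hypothetical
WidgetImporter:

    class WidgetImporter(Importer):
        # hypothetical per-model limits, hard-coded rather than passed via CLI
        max_create = 100   # at most 100 new records per run
        max_total = 500    # at most 500 changes of any kind per run
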
@@ -354,9 +367,26 @@
         Note that subclass generally should not override this method,
         but instead some of the others.

-        :param source_data: Optional sequence of normalized source
-           data.  If not specified, it is obtained from
-           :meth:`normalize_source_data()`.
+        This first calls :meth:`setup()` to prepare things as needed.
+
+        If no source data is specified, it calls
+        :meth:`normalize_source_data()` to get that.  Regardless, it
+        also calls :meth:`get_unique_data()` to discard any
+        duplicates.
+
+        If :attr:`caches_target` is set, it calls
+        :meth:`get_target_cache()` and assigns result to
+        :attr:`cached_target`.
+
+        Then depending on values for :attr:`create`, :attr:`update`
+        and :attr:`delete` it may call:
+
+        * :meth:`do_create_update()`
+        * :meth:`do_delete()`
+
+        And finally it calls :meth:`teardown()` for cleanup.
+
+        :param source_data: Sequence of normalized source data, if known.

         :param progress: Optional progress indicator factory.
@@ -366,13 +396,6 @@
         * ``created`` - list of records created on the target
         * ``updated`` - list of records updated on the target
         * ``deleted`` - list of records deleted on the target
-
-        See also these methods which this one calls:
-
-        * :meth:`setup()`
-        * :meth:`do_create_update()`
-        * :meth:`do_delete()`
-        * :meth:`teardown()`
         """
         # TODO: should add try/catch around this all? and teardown() in finally: clause?
         self.setup()
@@ -386,8 +409,9 @@
         if source_data is None:
             source_data = self.normalize_source_data(progress=progress)

-        # TODO: should exclude duplicate source records
-        # source_data, unique = self.get_unique_data(source_data)
+        # nb. prune duplicate records from source data
+        source_data, source_keys = self.get_unique_data(source_data)

         model_title = self.get_model_title()
         log.debug(f"got %s {model_title} records from source",
                   len(source_data))
@@ -402,7 +426,12 @@
         # delete target data
         if self.delete:
-            deleted = self.do_delete(source_data)
+            changes = len(created) + len(updated)
+            if self.max_total and changes >= self.max_total:
+                log.debug("max of %s total changes already reached; skipping deletions",
+                          self.max_total)
+            else:
+                deleted = self.do_delete(source_keys, changes, progress=progress)

         self.teardown()
         return created, updated, deleted
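
To illustrate (outside the diff), the max_total guard above is simple
arithmetic over the counts accumulated so far; with hypothetical numbers:

    # two creates + one update already happened, and max_total is 3,
    # so process_data() skips the deletion phase entirely
    created, updated = ['a', 'b'], ['c']
    max_total = 3
    assert len(created) + len(updated) >= max_total
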
@@ -460,6 +489,16 @@
                                                           target_data=target_data)
                 updated.append((target_object, target_data, source_data))

+                # stop if we reach max allowed
+                if self.max_update and len(updated) >= self.max_update:
+                    log.warning("max of %s *updated* records has been reached; stopping now",
+                                self.max_update)
+                    raise ImportLimitReached()
+                elif self.max_total and (len(created) + len(updated)) >= self.max_total:
+                    log.warning("max of %s *total changes* has been reached; stopping now",
+                                self.max_total)
+                    raise ImportLimitReached()
+
             elif not target_object and self.create:

                 # target object not yet present, so create it
@@ -473,23 +512,94 @@
                     #     'object': target_object,
                     #     'data': self.normalize_target_object(target_object),
                     # }

+                    # stop if we reach max allowed
+                    if self.max_create and len(created) >= self.max_create:
+                        log.warning("max of %s *created* records has been reached; stopping now",
+                                    self.max_create)
+                        raise ImportLimitReached()
+                    elif self.max_total and (len(created) + len(updated)) >= self.max_total:
+                        log.warning("max of %s *total changes* has been reached; stopping now",
+                                    self.max_total)
+                        raise ImportLimitReached()
+
                 else:
                     log.debug("did NOT create new %s for key: %s", model_title, key)

         actioning = self.actioning.capitalize()
         target_title = self.handler.get_target_title()
-        self.app.progress_loop(create_update, all_source_data, progress,
-                               message=f"{actioning} {model_title} data to {target_title}")
+        try:
+            self.app.progress_loop(create_update, all_source_data, progress,
+                                   message=f"{actioning} {model_title} data to {target_title}")
+        except ImportLimitReached:
+            pass
+
         return created, updated

-    def do_delete(self, source_data, progress=None):
+    def do_delete(self, source_keys, changes=None, progress=None):
         """
-        TODO: not yet implemented
+        Delete records from the target side as needed, per the given
+        source data.

-        :returns: List of records deleted on the target.
+        This will call :meth:`get_deletable_keys()` to discover which
+        keys existing on the target side could theoretically allow
+        being deleted.
+
+        From that set it will remove all the given source keys - since
+        such keys still exist on the source, they should not be
+        deleted from target.
+
+        If any "deletable" keys remain, their corresponding objects
+        are removed from target via :meth:`delete_target_object()`.
+
+        :param source_keys: A ``set`` of keys for all source records.
+
+           Essentially this is just the list of keys for which target
+           records should *not* be deleted - since they still exist in
+           the data source.
+
+        :param changes: Number of changes which have already been made
+           on the target side.  Used to enforce max allowed changes,
+           if applicable.
+
+        :param progress: Optional progress indicator factory.
+
+        :returns: List of target records which were deleted.
         """
-        return []
+        deleted = []
+        changes = changes or 0
+
+        # which target records are deletable?  potentially all target
+        # records may be eligible, but anything also found in source
+        # is *not* eligible.
+        deletable = self.get_deletable_keys() - source_keys
+        log.debug("found %s records to delete", len(deletable))
+
+        def delete(key, i):
+            cached = self.cached_target.pop(key)
+            obj = cached['object']
+
+            # delete target object
+            if self.delete_target_object(obj):
+                deleted.append((obj, cached['data']))
+
+                # stop if we reach max allowed
+                if self.max_delete and len(deleted) >= self.max_delete:
+                    log.warning("max of %s *deleted* records has been reached; stopping now",
+                                self.max_delete)
+                    raise ImportLimitReached()
+                elif self.max_total and (changes + len(deleted)) >= self.max_total:
+                    log.warning("max of %s *total changes* has been reached; stopping now",
+                                self.max_total)
+                    raise ImportLimitReached()
+
+        try:
+            model_title = self.get_model_title()
+            self.app.progress_loop(delete, sorted(deletable), progress,
+                                   message=f"Deleting {model_title} records")
+        except ImportLimitReached:
+            pass
+
+        return deleted

     def get_record_key(self, data):
         """
@@ -579,6 +689,49 @@
                 message=f"Reading {model_title} data from {source_title}")
         return normalized

+    def get_unique_data(self, source_data):
+        """
+        Return a copy of the given source data, with any duplicate
+        records removed.
+
+        This looks for duplicates based on the effective key fields,
+        cf. :meth:`get_keys()`.  The first record found with a given
+        key is kept; subsequent records with that key are discarded.
+
+        This is called from :meth:`process_data()` and is done largely
+        for sanity's sake, to avoid indeterminate behavior when source
+        data contains duplicates.  For instance:
+
+        Problem #1: If source contains 2 records with key 'X' it makes
+        no sense to create both records on the target side.
+
+        Problem #2: if the 2 source records have different data (apart
+        from their key) then which should target reflect?
+
+        So the main point of this method is to discard the duplicates
+        to avoid problem #1, but do it in a deterministic way so at
+        least the "choice" of which record is kept will not vary
+        across runs; hence "pseudo-resolve" problem #2.
+
+        :param source_data: Sequence of normalized source data.
+
+        :returns: A 2-tuple of ``(source_data, unique_keys)`` where:
+
+           * ``source_data`` is the final list of source data
+           * ``unique_keys`` is a :class:`python:set` of the source record keys
+        """
+        unique = OrderedDict()
+        for data in source_data:
+            key = self.get_record_key(data)
+            if key in unique:
+                log.warning("duplicate %s records detected from %s for key: %s",
+                            self.get_model_title(),
+                            self.handler.get_source_title(),
+                            key)
+            else:
+                unique[key] = data
+        return list(unique.values()), set(unique)
+
     def get_source_objects(self):
         """
         This method (if applicable) should return a sequence of "raw"
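
For illustration only (not part of this diff), the "first record wins"
behavior described above amounts to the following (sample records are
hypothetical):

    from collections import OrderedDict

    rows = [{'name': 'foo', 'value': 'bar1'},
            {'name': 'foo', 'value': 'bar2'}]   # duplicate key ('foo',)
    unique = OrderedDict()
    for row in rows:
        key = (row['name'],)
        if key not in unique:
            unique[key] = row                   # keep only the first record per key
    assert list(unique.values()) == [{'name': 'foo', 'value': 'bar1'}]
    assert set(unique) == {('foo',)}
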
@@ -754,6 +907,38 @@
                          for field in fields])
         return data

+    def get_deletable_keys(self, progress=None):
+        """
+        Return a set of record keys from the target side, which are
+        *potentially* eligible for deletion.
+
+        Inclusion in this set does not imply a given record/key
+        *should* be deleted, only that app logic (e.g. business rules)
+        does not prevent it.
+
+        Default logic here will look in the :attr:`cached_target` and
+        then call :meth:`can_delete_object()` for each record in the
+        cache.  If that call returns true for a given key, it is
+        included in the result.
+
+        :returns: The ``set`` of target record keys eligible for
+           deletion.
+        """
+        if not self.caches_target:
+            return set()
+
+        keys = set()
+
+        def check(key, i):
+            data = self.cached_target[key]['data']
+            obj = self.cached_target[key]['object']
+            if self.can_delete_object(obj, data):
+                keys.add(key)
+
+        self.app.progress_loop(check, set(self.cached_target), progress,
+                               message="Determining which objects can be deleted")
+        return keys
+
     ##############################
     # CRUD methods
     ##############################
@@ -859,6 +1044,40 @@
         return obj

+    def can_delete_object(self, obj, data=None):
+        """
+        Should return true or false indicating whether the given
+        object "can" be deleted.  Default is to return true in all
+        cases.
+
+        If you return false then the importer will know not to call
+        :meth:`delete_target_object()` even if the data sets imply
+        that it should.
+
+        :param obj: Raw object on the target side.
+
+        :param data: Normalized data dict for the target record, if
+           known.
+
+        :returns: ``True`` if object can be deleted, else ``False``.
+        """
+        return True
+
+    def delete_target_object(self, obj):
+        """
+        Delete the given raw object from the target side, and return
+        true if successful.
+
+        This is called from :meth:`do_delete()`.
+
+        Default logic for this method just returns false; subclass
+        should override if needed.
+
+        :returns: Should return ``True`` if deletion succeeds, or
+           ``False`` if deletion failed or was skipped.
+        """
+        return False
+

 class FromFile(Importer):
     """
@@ -1005,10 +1224,9 @@
         """
         Tries to fetch the object from target DB using ORM query.
         """
-        # first the default logic in case target object is cached
-        obj = super().get_target_object(key)
-        if obj:
-            return obj
+        # use default logic to fetch from cache, if applicable
+        if self.caches_target:
+            return super().get_target_object(key)

         # okay now we must fetch via query
         query = self.target_session.query(self.model_class)
@@ -1019,15 +1237,6 @@
         except orm.exc.NoResultFound:
             pass

-    def create_target_object(self, key, source_data):
-        """ """
-        with self.target_session.no_autoflush:
-            obj = super().create_target_object(key, source_data)
-            if obj:
-                # nb. add new object to target db session
-                self.target_session.add(obj)
-                return obj
-
     def get_target_objects(self, source_data=None, progress=None):
         """
         Fetches target objects via the ORM query from
@@ -1043,3 +1252,17 @@
            :meth:`get_target_objects()`.
         """
         return self.target_session.query(self.model_class)
+
+    def create_target_object(self, key, source_data):
+        """ """
+        with self.target_session.no_autoflush:
+            obj = super().create_target_object(key, source_data)
+            if obj:
+                # nb. add new object to target db session
+                self.target_session.add(obj)
+                return obj
+
+    def delete_target_object(self, obj):
+        """ """
+        self.target_session.delete(obj)
+        return True


@@ -26,11 +26,12 @@ Importing from CSV

 import csv
 import logging
+import uuid as _uuid
 from collections import OrderedDict

 from sqlalchemy_utils.functions import get_primary_keys

-from wuttjamaican.db.util import make_topo_sortkey
+from wuttjamaican.db.util import make_topo_sortkey, UUID

 from .base import FromFile
 from .handlers import FromFileHandler
@@ -138,7 +139,54 @@

 class FromCsvToSqlalchemyMixin:
     """
-    Mixin handler class for CSV -> SQLAlchemy ORM import/export.
-    """
+    Mixin class for CSV -> SQLAlchemy ORM :term:`importers <importer>`.
+
+    Meant to be used by :class:`FromCsvToSqlalchemyHandlerMixin`.
+
+    This mixin adds some logic to better handle ``uuid`` key fields
+    which are of :class:`~wuttjamaican:wuttjamaican.db.util.UUID` data
+    type (i.e. on the target side).  Namely, when reading ``uuid``
+    values as string from CSV, convert them to proper UUID instances,
+    so the key matching between source and target will behave as
+    expected.
+    """
+
+    def __init__(self, config, **kwargs):
+        super().__init__(config, **kwargs)
+
+        # nb. keep track of any key fields which use proper UUID type
+        self.uuid_keys = []
+        for field in self.get_keys():
+            attr = getattr(self.model_class, field)
+            if len(attr.prop.columns) == 1:
+                if isinstance(attr.prop.columns[0].type, UUID):
+                    self.uuid_keys.append(field)
+
+    def normalize_source_object(self, obj):
+        """ """
+        data = dict(obj)
+
+        # nb. convert to proper UUID values so key matching will work
+        # properly, where applicable
+        for key in self.uuid_keys:
+            uuid = data[key]
+            if uuid and not isinstance(uuid, _uuid.UUID):
+                data[key] = _uuid.UUID(uuid)
+
+        return data
+
+
+class FromCsvToSqlalchemyHandlerMixin:
+    """
+    Mixin class for CSV -> SQLAlchemy ORM :term:`import handlers
+    <import handler>`.
+
+    This knows how to dynamically generate :term:`importer` classes to
+    target the particular ORM involved.  Such classes will inherit
+    from :class:`FromCsvToSqlalchemyMixin`, in addition to whatever
+    :attr:`FromImporterBase` and :attr:`ToImporterBase` reference.
+
+    This all happens within :meth:`define_importers()`.
+    """

     source_key = 'csv'
     generic_source_title = "CSV"
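
For illustration only (not part of this diff), the conversion relies on the
stdlib uuid module accepting both hyphenated and bare hex strings; the values
below mirror the test data added later in this commit:

    import uuid as _uuid

    raw = '06753693d89277f08000ce71bf7ebbba'   # string as read from CSV
    assert _uuid.UUID(raw) == _uuid.UUID('06753693-d892-77f0-8000-ce71bf7ebbba')
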
@@ -201,30 +249,39 @@

         return importers

-    def make_importer_factory(self, cls, name):
+    def make_importer_factory(self, model_class, name):
         """
-        Generate and return a new importer/exporter class, targeting
-        the given data model class.
+        Generate and return a new :term:`importer` class, targeting
+        the given :term:`data model` class.

-        :param cls: A data model class.
+        The newly-created class will inherit from:

-        :param name: Optional "model name" override for the
-           importer/exporter.
+        * :class:`FromCsvToSqlalchemyMixin`
+        * :attr:`FromImporterBase`
+        * :attr:`ToImporterBase`

-        :returns: A new class, meant to process import/export
-           operations which target the given data model.  The new
-           class will inherit from both :attr:`FromImporterBase` and
-           :attr:`ToImporterBase`.
+        :param model_class: A data model class.
+
+        :param name: The "model name" for the importer/exporter.  New
+           class name will be based on this, so e.g. ``Widget`` model
+           name becomes ``WidgetImporter`` class name.
+
+        :returns: The new class, meant to process import/export
+           targeting the given data model.
         """
-        return type(f'{name}Importer', (FromCsv, self.ToImporterBase), {
-            'model_class': cls,
-            'key': list(get_primary_keys(cls)),
-        })
+        return type(f'{name}Importer',
+                    (FromCsvToSqlalchemyMixin, self.FromImporterBase, self.ToImporterBase), {
+                        'model_class': model_class,
+                        'key': list(get_primary_keys(model_class)),
+                    })


-class FromCsvToWutta(FromCsvToSqlalchemyMixin, FromFileHandler, ToWuttaHandler):
+class FromCsvToWutta(FromCsvToSqlalchemyHandlerMixin, FromFileHandler, ToWuttaHandler):
     """
     Handler for CSV -> Wutta :term:`app database` import.
+
+    This uses :class:`FromCsvToSqlalchemyHandlerMixin` for most of the
+    heavy lifting.
     """

     ToImporterBase = ToWutta
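
To illustrate (outside the diff), the generated class is roughly equivalent
to writing this by hand -- assuming FromImporterBase is FromCsv as for the
CSV handler, and a hypothetical Widget model:

    WidgetImporter = type('WidgetImporter',
                          (FromCsvToSqlalchemyMixin, FromCsv, ToWutta),
                          {'model_class': Widget,
                           'key': list(get_primary_keys(Widget))})
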


@@ -89,60 +89,237 @@ class TestImporter(DataTestCase):

     def test_process_data(self):
         model = self.app.model
-        imp = self.make_importer(model_class=model.Setting, caches_target=True)
-
-        # empty data set / just for coverage
-        with patch.object(imp, 'normalize_source_data') as normalize_source_data:
-            normalize_source_data.return_value = []
-            with patch.object(imp, 'get_target_cache') as get_target_cache:
-                get_target_cache.return_value = {}
-                result = imp.process_data()
-                self.assertEqual(result, ([], [], []))
+        imp = self.make_importer(model_class=model.Setting, caches_target=True,
+                                 delete=True)
+
+        def make_cache():
+            setting1 = model.Setting(name='foo1', value='bar1')
+            setting2 = model.Setting(name='foo2', value='bar2')
+            setting3 = model.Setting(name='foo3', value='bar3')
+            cache = {
+                ('foo1',): {
+                    'object': setting1,
+                    'data': {'name': 'foo1', 'value': 'bar1'},
+                },
+                ('foo2',): {
+                    'object': setting2,
+                    'data': {'name': 'foo2', 'value': 'bar2'},
+                },
+                ('foo3',): {
+                    'object': setting3,
+                    'data': {'name': 'foo3', 'value': 'bar3'},
+                },
+            }
+            return cache
+
+        # nb. delete always succeeds
+        with patch.object(imp, 'delete_target_object', return_value=True):
+
+            # create + update + delete all as needed
+            with patch.object(imp, 'get_target_cache', return_value=make_cache()):
+                created, updated, deleted = imp.process_data([
+                    {'name': 'foo3', 'value': 'BAR3'},
+                    {'name': 'foo4', 'value': 'BAR4'},
+                    {'name': 'foo5', 'value': 'BAR5'},
+                ])
+                self.assertEqual(len(created), 2)
+                self.assertEqual(len(updated), 1)
+                self.assertEqual(len(deleted), 2)
+
+            # same but with --max-total so delete gets skipped
+            with patch.object(imp, 'get_target_cache', return_value=make_cache()):
+                with patch.object(imp, 'max_total', new=3):
+                    created, updated, deleted = imp.process_data([
+                        {'name': 'foo3', 'value': 'BAR3'},
+                        {'name': 'foo4', 'value': 'BAR4'},
+                        {'name': 'foo5', 'value': 'BAR5'},
+                    ])
+                    self.assertEqual(len(created), 2)
+                    self.assertEqual(len(updated), 1)
+                    self.assertEqual(len(deleted), 0)
+
+            # delete all if source data empty
+            with patch.object(imp, 'get_target_cache', return_value=make_cache()):
+                created, updated, deleted = imp.process_data()
+                self.assertEqual(len(created), 0)
+                self.assertEqual(len(updated), 0)
+                self.assertEqual(len(deleted), 3)

     def test_do_create_update(self):
         model = self.app.model
-
-        # this requires a mock target cache
-        imp = self.make_importer(model_class=model.Setting, caches_target=True)
-        setting = model.Setting(name='foo', value='bar')
-        imp.cached_target = {
-            ('foo',): {
-                'object': setting,
-                'data': {'name': 'foo', 'value': 'bar'},
-            },
-        }
-
-        # will update the one record
-        result = imp.do_create_update([{'name': 'foo', 'value': 'baz'}])
-        self.assertIs(result[1][0][0], setting)
-        self.assertEqual(result, ([], [(setting,
-                                        # nb. target
-                                        {'name': 'foo', 'value': 'bar'},
-                                        # nb. source
-                                        {'name': 'foo', 'value': 'baz'})]))
-        self.assertEqual(setting.value, 'baz')
-
-        # will create a new record
-        result = imp.do_create_update([{'name': 'blah', 'value': 'zay'}])
-        self.assertIsNot(result[0][0][0], setting)
-        setting_new = result[0][0][0]
-        self.assertEqual(result, ([(setting_new,
-                                    # nb. source
-                                    {'name': 'blah', 'value': 'zay'})],
-                                  []))
-        self.assertEqual(setting_new.name, 'blah')
-        self.assertEqual(setting_new.value, 'zay')
-
-        # but what if new record is *not* created
-        with patch.object(imp, 'create_target_object', return_value=None):
-            result = imp.do_create_update([{'name': 'another', 'value': 'one'}])
-            self.assertEqual(result, ([], []))
-
-    # def test_do_delete(self):
-    #     model = self.app.model
-    #     imp = self.make_importer(model_class=model.Setting)
+        imp = self.make_importer(model_class=model.Setting, caches_target=True)
+
+        def make_cache():
+            setting1 = model.Setting(name='foo1', value='bar1')
+            setting2 = model.Setting(name='foo2', value='bar2')
+            cache = {
+                ('foo1',): {
+                    'object': setting1,
+                    'data': {'name': 'foo1', 'value': 'bar1'},
+                },
+                ('foo2',): {
+                    'object': setting2,
+                    'data': {'name': 'foo2', 'value': 'bar2'},
+                },
+            }
+            return cache
+
+        # change nothing if data matches
+        with patch.multiple(imp, create=True, cached_target=make_cache()):
+            created, updated = imp.do_create_update([
+                {'name': 'foo1', 'value': 'bar1'},
+                {'name': 'foo2', 'value': 'bar2'},
+            ])
+            self.assertEqual(len(created), 0)
+            self.assertEqual(len(updated), 0)
+
+        # update all as needed
+        with patch.multiple(imp, create=True, cached_target=make_cache()):
+            created, updated = imp.do_create_update([
+                {'name': 'foo1', 'value': 'BAR1'},
+                {'name': 'foo2', 'value': 'BAR2'},
+            ])
+            self.assertEqual(len(created), 0)
+            self.assertEqual(len(updated), 2)
+
+        # update all, with --max-update
+        with patch.multiple(imp, create=True, cached_target=make_cache(), max_update=1):
+            created, updated = imp.do_create_update([
+                {'name': 'foo1', 'value': 'BAR1'},
+                {'name': 'foo2', 'value': 'BAR2'},
+            ])
+            self.assertEqual(len(created), 0)
+            self.assertEqual(len(updated), 1)
+
+        # update all, with --max-total
+        with patch.multiple(imp, create=True, cached_target=make_cache(), max_total=1):
+            created, updated = imp.do_create_update([
+                {'name': 'foo1', 'value': 'BAR1'},
+                {'name': 'foo2', 'value': 'BAR2'},
+            ])
+            self.assertEqual(len(created), 0)
+            self.assertEqual(len(updated), 1)
+
+        # create all as needed
+        with patch.multiple(imp, create=True, cached_target=make_cache()):
+            created, updated = imp.do_create_update([
+                {'name': 'foo1', 'value': 'bar1'},
+                {'name': 'foo2', 'value': 'bar2'},
+                {'name': 'foo3', 'value': 'BAR3'},
+                {'name': 'foo4', 'value': 'BAR4'},
+            ])
+            self.assertEqual(len(created), 2)
+            self.assertEqual(len(updated), 0)
+
+        # what happens when create gets skipped
+        with patch.multiple(imp, create=True, cached_target=make_cache()):
+            with patch.object(imp, 'create_target_object', return_value=None):
+                created, updated = imp.do_create_update([
+                    {'name': 'foo1', 'value': 'bar1'},
+                    {'name': 'foo2', 'value': 'bar2'},
+                    {'name': 'foo3', 'value': 'BAR3'},
+                    {'name': 'foo4', 'value': 'BAR4'},
+                ])
+                self.assertEqual(len(created), 0)
+                self.assertEqual(len(updated), 0)
+
+        # create all, with --max-create
+        with patch.multiple(imp, create=True, cached_target=make_cache(), max_create=1):
+            created, updated = imp.do_create_update([
+                {'name': 'foo1', 'value': 'bar1'},
+                {'name': 'foo2', 'value': 'bar2'},
+                {'name': 'foo3', 'value': 'BAR3'},
+                {'name': 'foo4', 'value': 'BAR4'},
+            ])
+            self.assertEqual(len(created), 1)
+            self.assertEqual(len(updated), 0)
+
+        # create all, with --max-total
+        with patch.multiple(imp, create=True, cached_target=make_cache(), max_total=1):
+            created, updated = imp.do_create_update([
+                {'name': 'foo1', 'value': 'bar1'},
+                {'name': 'foo2', 'value': 'bar2'},
+                {'name': 'foo3', 'value': 'BAR3'},
+                {'name': 'foo4', 'value': 'BAR4'},
+            ])
+            self.assertEqual(len(created), 1)
+            self.assertEqual(len(updated), 0)
+
+        # create + update all as needed
+        with patch.multiple(imp, create=True, cached_target=make_cache()):
+            created, updated = imp.do_create_update([
+                {'name': 'foo1', 'value': 'BAR1'},
+                {'name': 'foo2', 'value': 'BAR2'},
+                {'name': 'foo3', 'value': 'BAR3'},
+                {'name': 'foo4', 'value': 'BAR4'},
+            ])
+            self.assertEqual(len(created), 2)
+            self.assertEqual(len(updated), 2)
+
+        # create + update all, with --max-total
+        with patch.multiple(imp, create=True, cached_target=make_cache(), max_total=1):
+            created, updated = imp.do_create_update([
+                {'name': 'foo1', 'value': 'BAR1'},
+                {'name': 'foo2', 'value': 'BAR2'},
+                {'name': 'foo3', 'value': 'BAR3'},
+                {'name': 'foo4', 'value': 'BAR4'},
+            ])
+            # nb. foo1 is updated first
+            self.assertEqual(len(created), 0)
+            self.assertEqual(len(updated), 1)
+
+    def test_do_delete(self):
+        model = self.app.model
+
+        # this requires a mock target cache
+        setting1 = model.Setting(name='foo1', value='bar1')
+        setting2 = model.Setting(name='foo2', value='bar2')
+        imp = self.make_importer(model_class=model.Setting, caches_target=True)
+        cache = {
+            ('foo1',): {
+                'object': setting1,
+                'data': {'name': 'foo1', 'value': 'bar1'},
+            },
+            ('foo2',): {
+                'object': setting2,
+                'data': {'name': 'foo2', 'value': 'bar2'},
+            },
+        }
+
+        with patch.object(imp, 'delete_target_object') as delete_target_object:
+
+            # delete nothing if source has same keys
+            with patch.multiple(imp, create=True, cached_target=dict(cache)):
+                source_keys = set(imp.cached_target)
+                result = imp.do_delete(source_keys)
+                self.assertFalse(delete_target_object.called)
+                self.assertEqual(result, [])
+
+            # delete both if source has no keys
+            delete_target_object.reset_mock()
+            with patch.multiple(imp, create=True, cached_target=dict(cache)):
+                source_keys = set()
+                result = imp.do_delete(source_keys)
+                self.assertEqual(delete_target_object.call_count, 2)
+                self.assertEqual(len(result), 2)
+
+            # delete just one if --max-delete was set
+            delete_target_object.reset_mock()
+            with patch.multiple(imp, create=True, cached_target=dict(cache)):
+                source_keys = set()
+                with patch.object(imp, 'max_delete', new=1):
+                    result = imp.do_delete(source_keys)
+                    self.assertEqual(delete_target_object.call_count, 1)
+                    self.assertEqual(len(result), 1)
+
+            # delete just one if --max-total was set
+            delete_target_object.reset_mock()
+            with patch.multiple(imp, create=True, cached_target=dict(cache)):
+                source_keys = set()
+                with patch.object(imp, 'max_total', new=1):
+                    result = imp.do_delete(source_keys)
+                    self.assertEqual(delete_target_object.call_count, 1)
+                    self.assertEqual(len(result), 1)

     def test_get_record_key(self):
         model = self.app.model
@@ -182,6 +359,22 @@ class TestImporter(DataTestCase):
         # nb. default normalizer returns object as-is
         self.assertIs(data[0], setting)

+    def test_get_unique_data(self):
+        model = self.app.model
+        imp = self.make_importer(model_class=model.Setting)
+
+        setting1 = model.Setting(name='foo', value='bar1')
+        setting2 = model.Setting(name='foo', value='bar2')
+
+        result = imp.get_unique_data([setting2, setting1])
+        self.assertIsInstance(result, tuple)
+        self.assertEqual(len(result), 2)
+        self.assertIsInstance(result[0], list)
+        self.assertEqual(len(result[0]), 1)
+        self.assertIs(result[0][0], setting2)   # nb. not setting1
+        self.assertIsInstance(result[1], set)
+        self.assertEqual(result[1], {('foo',)})
+
     def test_get_source_objects(self):
         model = self.app.model
         imp = self.make_importer(model_class=model.Setting)
@@ -263,6 +456,34 @@ class TestImporter(DataTestCase):
         data = imp.normalize_target_object(setting)
         self.assertEqual(data, {'name': 'foo', 'value': 'bar'})

+    def test_get_deletable_keys(self):
+        model = self.app.model
+        imp = self.make_importer(model_class=model.Setting)
+
+        # empty set by default (nb. no target cache)
+        result = imp.get_deletable_keys()
+        self.assertIsInstance(result, set)
+        self.assertEqual(result, set())
+
+        setting = model.Setting(name='foo', value='bar')
+        cache = {
+            ('foo',): {
+                'object': setting,
+                'data': {'name': 'foo', 'value': 'bar'},
+            },
+        }
+
+        with patch.multiple(imp, create=True, caches_target=True, cached_target=cache):
+
+            # all are deletable by default
+            result = imp.get_deletable_keys()
+            self.assertEqual(result, {('foo',)})
+
+            # but some maybe can't be deleted
+            with patch.object(imp, 'can_delete_object', return_value=False):
+                result = imp.get_deletable_keys()
+                self.assertEqual(result, set())
+
     def test_create_target_object(self):
         model = self.app.model
         imp = self.make_importer(model_class=model.Setting)
@@ -301,6 +522,19 @@ class TestImporter(DataTestCase):
             self.assertIs(obj, setting)
             self.assertEqual(setting.value, 'bar')

+    def test_can_delete_object(self):
+        model = self.app.model
+        imp = self.make_importer(model_class=model.Setting)
+        setting = model.Setting(name='foo')
+        self.assertTrue(imp.can_delete_object(setting))
+
+    def test_delete_target_object(self):
+        model = self.app.model
+        imp = self.make_importer(model_class=model.Setting)
+
+        setting = model.Setting(name='foo')
+        # nb. default implementation always returns false
+        self.assertFalse(imp.delete_target_object(setting))
+

 class TestFromFile(DataTestCase):
@@ -390,6 +624,20 @@ class TestToSqlalchemy(DataTestCase):
         kwargs.setdefault('handler', self.handler)
         return mod.ToSqlalchemy(self.config, **kwargs)

+    def test_get_target_objects(self):
+        model = self.app.model
+        imp = self.make_importer(model_class=model.Setting, target_session=self.session)
+
+        setting1 = model.Setting(name='foo', value='bar')
+        self.session.add(setting1)
+        setting2 = model.Setting(name='foo2', value='bar2')
+        self.session.add(setting2)
+        self.session.commit()
+
+        result = imp.get_target_objects()
+        self.assertEqual(len(result), 2)
+        self.assertEqual(set(result), {setting1, setting2})
+
     def test_get_target_object(self):
         model = self.app.model
         setting = model.Setting(name='foo', value='bar')
@@ -416,8 +664,12 @@ class TestToSqlalchemy(DataTestCase):
         self.session.add(setting2)
         self.session.commit()

-        # then we should be able to fetch that via query
-        imp.target_session = self.session
-        result = imp.get_target_object(('foo2',))
-        self.assertIsInstance(result, model.Setting)
-        self.assertIs(result, setting2)
+        # nb. disable target cache
+        with patch.multiple(imp, create=True,
+                            target_session=self.session,
+                            caches_target=False):
+            # now we should be able to fetch that via query
+            result = imp.get_target_object(('foo2',))
+            self.assertIsInstance(result, model.Setting)
+            self.assertIs(result, setting2)
@@ -438,16 +690,13 @@ class TestToSqlalchemy(DataTestCase):
         self.assertEqual(setting.value, 'bar')
         self.assertIn(setting, self.session)

-    def test_get_target_objects(self):
+    def test_delete_target_object(self):
         model = self.app.model
-        imp = self.make_importer(model_class=model.Setting, target_session=self.session)
-
-        setting1 = model.Setting(name='foo', value='bar')
-        self.session.add(setting1)
-        setting2 = model.Setting(name='foo2', value='bar2')
-        self.session.add(setting2)
-        self.session.commit()
-
-        result = imp.get_target_objects()
-        self.assertEqual(len(result), 2)
-        self.assertEqual(set(result), {setting1, setting2})
+        setting = model.Setting(name='foo', value='bar')
+        self.session.add(setting)
+        self.assertEqual(self.session.query(model.Setting).count(), 1)
+        imp = self.make_importer(model_class=model.Setting, target_session=self.session)
+        imp.delete_target_object(setting)
+        self.assertEqual(self.session.query(model.Setting).count(), 0)


@@ -1,6 +1,7 @@
 #-*- coding: utf-8; -*-

 import csv
+import uuid as _uuid
 from unittest.mock import patch

 from wuttjamaican.testing import DataTestCase
@@ -87,23 +88,74 @@ foo2,bar2
         self.assertEqual(objects[1], {'name': 'foo2', 'value': 'bar2'})


-class MockMixinHandler(mod.FromCsvToSqlalchemyMixin, ToSqlalchemyHandler):
-    ToImporterBase = ToSqlalchemy
+class MockMixinImporter(mod.FromCsvToSqlalchemyMixin, mod.FromCsv, ToSqlalchemy):
+    pass


 class TestFromCsvToSqlalchemyMixin(DataTestCase):

+    def setUp(self):
+        self.setup_db()
+        self.handler = ImportHandler(self.config)
+
+    def make_importer(self, **kwargs):
+        kwargs.setdefault('handler', self.handler)
+        return MockMixinImporter(self.config, **kwargs)
+
+    def test_constructor(self):
+        model = self.app.model
+
+        # no uuid keys
+        imp = self.make_importer(model_class=model.Setting)
+        self.assertEqual(imp.uuid_keys, [])
+
+        # typical
+        # nb. as of now Upgrade is the only table using proper UUID
+        imp = self.make_importer(model_class=model.Upgrade)
+        self.assertEqual(imp.uuid_keys, ['uuid'])
+
+    def test_normalize_source_object(self):
+        model = self.app.model
+
+        # no uuid keys
+        imp = self.make_importer(model_class=model.Setting)
+        result = imp.normalize_source_object({'name': 'foo', 'value': 'bar'})
+        self.assertEqual(result, {'name': 'foo', 'value': 'bar'})
+
+        # source has proper UUID
+        # nb. as of now Upgrade is the only table using proper UUID
+        imp = self.make_importer(model_class=model.Upgrade, fields=['uuid', 'description'])
+        result = imp.normalize_source_object({'uuid': _uuid.UUID('06753693-d892-77f0-8000-ce71bf7ebbba'),
+                                              'description': 'testing'})
+        self.assertEqual(result, {'uuid': _uuid.UUID('06753693-d892-77f0-8000-ce71bf7ebbba'),
+                                  'description': 'testing'})
+
+        # source has string uuid
+        # nb. as of now Upgrade is the only table using proper UUID
+        imp = self.make_importer(model_class=model.Upgrade, fields=['uuid', 'description'])
+        result = imp.normalize_source_object({'uuid': '06753693d89277f08000ce71bf7ebbba',
+                                              'description': 'testing'})
+        self.assertEqual(result, {'uuid': _uuid.UUID('06753693-d892-77f0-8000-ce71bf7ebbba'),
+                                  'description': 'testing'})
+
+
+class MockMixinHandler(mod.FromCsvToSqlalchemyHandlerMixin, ToSqlalchemyHandler):
+    ToImporterBase = ToSqlalchemy
+
+
+class TestFromCsvToSqlalchemyHandlerMixin(DataTestCase):
+
     def make_handler(self, **kwargs):
         return MockMixinHandler(self.config, **kwargs)

     def test_get_target_model(self):
-        with patch.object(mod.FromCsvToSqlalchemyMixin, 'define_importers', return_value={}):
+        with patch.object(mod.FromCsvToSqlalchemyHandlerMixin, 'define_importers', return_value={}):
             handler = self.make_handler()
             self.assertRaises(NotImplementedError, handler.get_target_model)

     def test_define_importers(self):
         model = self.app.model
-        with patch.object(mod.FromCsvToSqlalchemyMixin, 'get_target_model', return_value=model):
+        with patch.object(mod.FromCsvToSqlalchemyHandlerMixin, 'get_target_model', return_value=model):
             handler = self.make_handler()
             importers = handler.define_importers()
             self.assertIn('Setting', importers)
@@ -115,7 +167,7 @@ class TestFromCsvToSqlalchemyMixin(DataTestCase):

     def test_make_importer_factory(self):
         model = self.app.model
-        with patch.object(mod.FromCsvToSqlalchemyMixin, 'define_importers', return_value={}):
+        with patch.object(mod.FromCsvToSqlalchemyHandlerMixin, 'define_importers', return_value={}):
             handler = self.make_handler()
             factory = handler.make_importer_factory(model.Setting, 'Setting')
             self.assertTrue(issubclass(factory, mod.FromCsv))