From 328f8d99529c2f9fbdb1bf4119dcf896721f45b2 Mon Sep 17 00:00:00 2001
From: Lance Edgar
Date: Fri, 6 Dec 2024 15:18:23 -0600
Subject: [PATCH] fix: implement deletion logic; add cli params for max changes

also add special UUID field handling for CSV -> SQLAlchemy ORM, to
normalize string from CSV to proper UUID so key matching works
---
 src/wuttasync/cli/base.py           |  31 ++-
 src/wuttasync/importing/__init__.py |  16 ++
 src/wuttasync/importing/base.py     | 287 +++++++++++++++++++---
 src/wuttasync/importing/csv.py      |  89 +++++--
 tests/importing/test_base.py        | 365 +++++++++++++++++++++++-----
 tests/importing/test_csv.py         |  62 ++++-
 6 files changed, 735 insertions(+), 115 deletions(-)

diff --git a/src/wuttasync/cli/base.py b/src/wuttasync/cli/base.py
index 0a0ed2c..6368c6a 100644
--- a/src/wuttasync/cli/base.py
+++ b/src/wuttasync/cli/base.py
@@ -147,15 +147,18 @@ def import_command_template(
 
     create: Annotated[
         bool,
-        typer.Option(help="Allow new target records to be created.")] = True,
+        typer.Option(help="Allow new target records to be created. "
+                     "See also --max-create.")] = True,
 
     update: Annotated[
         bool,
-        typer.Option(help="Allow existing target records to be updated.")] = True,
+        typer.Option(help="Allow existing target records to be updated. "
+                     "See also --max-update.")] = True,
 
     delete: Annotated[
         bool,
-        typer.Option(help="Allow existing target records to be deleted.")] = False,
+        typer.Option(help="Allow existing target records to be deleted. "
+                     "See also --max-delete.")] = False,
 
     fields: Annotated[
         str,
@@ -170,7 +173,27 @@ def import_command_template(
     keys: Annotated[
         str,
         typer.Option('--key', '--keys',
-                     help="List of fields to use as record key/identifier. See also --fields.")] = None,
+                     help="List of fields to use as record key/identifier. "
+                     "See also --fields.")] = None,
+
+    max_create: Annotated[
+        int,
+        typer.Option(help="Max number of target records to create (per model). "
+                     "See also --create.")] = None,
+
+    max_update: Annotated[
+        int,
+        typer.Option(help="Max number of target records to update (per model). "
+                     "See also --update.")] = None,
+
+    max_delete: Annotated[
+        int,
+        typer.Option(help="Max number of target records to delete (per model). "
+                     "See also --delete.")] = None,
+
+    max_total: Annotated[
+        int,
+        typer.Option(help="Max number of *any* target record changes which may occur (per model).")] = None,
 
     dry_run: Annotated[
         bool,
diff --git a/src/wuttasync/importing/__init__.py b/src/wuttasync/importing/__init__.py
index b43d5d3..eadd3b6 100644
--- a/src/wuttasync/importing/__init__.py
+++ b/src/wuttasync/importing/__init__.py
@@ -22,6 +22,22 @@
 ################################################################################
 """
 Data Import / Export Framework
+
+This namespace exposes the following:
+
+* :enum:`~wuttasync.importing.handlers.Orientation`
+
+And for the :term:`import handlers <import handler>`:
+
+* :class:`~wuttasync.importing.handlers.ImportHandler`
+* :class:`~wuttasync.importing.handlers.FromFileHandler`
+* :class:`~wuttasync.importing.handlers.ToSqlalchemyHandler`
+
+And for the :term:`importers <importer>`:
+
+* :class:`~wuttasync.importing.base.Importer`
+* :class:`~wuttasync.importing.base.FromFile`
+* :class:`~wuttasync.importing.base.ToSqlalchemy`
 """
 
 from .handlers import Orientation, ImportHandler, FromFileHandler, ToSqlalchemyHandler
diff --git a/src/wuttasync/importing/base.py b/src/wuttasync/importing/base.py
index f9c4bb2..59017d7 100644
--- a/src/wuttasync/importing/base.py
+++ b/src/wuttasync/importing/base.py
@@ -26,6 +26,7 @@ Data Importer base class
 
 import os
 import logging
+from collections import OrderedDict
 
 from sqlalchemy import orm
 from sqlalchemy_utils.functions import get_primary_keys, get_columns
@@ -36,6 +37,13 @@ from wuttasync.util import data_diffs
 
 log = logging.getLogger(__name__)
 
+class ImportLimitReached(Exception):
+    """
+    Exception raised when an import/export job reaches the max number
+    of changes allowed.
+    """
+
+
 class Importer:
     """
     Base class for all data importers / exporters.
@@ -174,6 +182,11 @@ class Importer:
         :meth:`get_target_cache()`.
     """
 
+    max_create = None
+    max_update = None
+    max_delete = None
+    max_total = None
+
     def __init__(self, config, **kwargs):
         self.config = config
         self.app = self.config.get_app()
@@ -354,9 +367,26 @@
         Note that subclasses generally should not override this method,
         but instead some of the others.
 
-        :param source_data: Optional sequence of normalized source
-          data. If not specified, it is obtained from
-          :meth:`normalize_source_data()`.
+        This first calls :meth:`setup()` to prepare things as needed.
+
+        If no source data is specified, it calls
+        :meth:`normalize_source_data()` to get that. Regardless, it
+        also calls :meth:`get_unique_data()` to discard any
+        duplicates.
+
+        If :attr:`caches_target` is set, it calls
+        :meth:`get_target_cache()` and assigns result to
+        :attr:`cached_target`.
+
+        Then depending on values for :attr:`create`, :attr:`update`
+        and :attr:`delete` it may call:
+
+        * :meth:`do_create_update()`
+        * :meth:`do_delete()`
+
+        And finally it calls :meth:`teardown()` for cleanup.
+
+        :param source_data: Sequence of normalized source data, if known.
 
         :param progress: Optional progress indicator factory.
 
@@ -366,13 +396,6 @@
 
         * ``created`` - list of records created on the target
         * ``updated`` - list of records updated on the target
         * ``deleted`` - list of records deleted on the target
-
-        See also these methods which this one calls:
-
-        * :meth:`setup()`
-        * :meth:`do_create_update()`
-        * :meth:`do_delete()`
-        * :meth:`teardown()`
         """
         # TODO: should add try/catch around this all? and teardown() in finally: clause?
self.setup() @@ -386,8 +409,9 @@ class Importer: if source_data is None: source_data = self.normalize_source_data(progress=progress) - # TODO: should exclude duplicate source records - # source_data, unique = self.get_unique_data(source_data) + # nb. prune duplicate records from source data + source_data, source_keys = self.get_unique_data(source_data) + model_title = self.get_model_title() log.debug(f"got %s {model_title} records from source", len(source_data)) @@ -402,7 +426,12 @@ class Importer: # delete target data if self.delete: - deleted = self.do_delete(source_data) + changes = len(created) + len(updated) + if self.max_total and changes >= self.max_total: + log.debug("max of %s total changes already reached; skipping deletions", + self.max_total) + else: + deleted = self.do_delete(source_keys, changes, progress=progress) self.teardown() return created, updated, deleted @@ -460,6 +489,16 @@ class Importer: target_data=target_data) updated.append((target_object, target_data, source_data)) + # stop if we reach max allowed + if self.max_update and len(updated) >= self.max_update: + log.warning("max of %s *updated* records has been reached; stopping now", + self.max_update) + raise ImportLimitReached() + elif self.max_total and (len(created) + len(updated)) >= self.max_total: + log.warning("max of %s *total changes* has been reached; stopping now", + self.max_total) + raise ImportLimitReached() + elif not target_object and self.create: # target object not yet present, so create it @@ -473,23 +512,94 @@ class Importer: # 'object': target_object, # 'data': self.normalize_target_object(target_object), # } + + # stop if we reach max allowed + if self.max_create and len(created) >= self.max_create: + log.warning("max of %s *created* records has been reached; stopping now", + self.max_create) + raise ImportLimitReached() + elif self.max_total and (len(created) + len(updated)) >= self.max_total: + log.warning("max of %s *total changes* has been reached; stopping now", + self.max_total) + raise ImportLimitReached() + else: log.debug("did NOT create new %s for key: %s", model_title, key) actioning = self.actioning.capitalize() target_title = self.handler.get_target_title() - self.app.progress_loop(create_update, all_source_data, progress, - message=f"{actioning} {model_title} data to {target_title}") + try: + self.app.progress_loop(create_update, all_source_data, progress, + message=f"{actioning} {model_title} data to {target_title}") + except ImportLimitReached: + pass return created, updated - def do_delete(self, source_data, progress=None): + def do_delete(self, source_keys, changes=None, progress=None): """ - TODO: not yet implemented + Delete records from the target side as needed, per the given + source data. - :returns: List of records deleted on the target. + This will call :meth:`get_deletable_keys()` to discover which + keys existing on the target side could theoretically allow + being deleted. + + From that set it will remove all the given source keys - since + such keys still exist on the source, they should not be + deleted from target. + + If any "deletable" keys remain, their corresponding objects + are removed from target via :meth:`delete_target_object()`. + + :param source_keys: A ``set`` of keys for all source records. + Essentially this is just the list of keys for which target + records should *not* be deleted - since they still exist in + the data source. + + :param changes: Number of changes which have already been made + on the target side. 
Used to enforce max allowed changes, + if applicable. + + :param progress: Optional progress indicator factory. + + :returns: List of target records which were deleted. """ - return [] + deleted = [] + changes = changes or 0 + + # which target records are deletable? potentially all target + # records may be eligible, but anything also found in source + # is *not* eligible. + deletable = self.get_deletable_keys() - source_keys + log.debug("found %s records to delete", len(deletable)) + + def delete(key, i): + cached = self.cached_target.pop(key) + obj = cached['object'] + + # delete target object + if self.delete_target_object(obj): + deleted.append((obj, cached['data'])) + + # stop if we reach max allowed + if self.max_delete and len(deleted) >= self.max_delete: + log.warning("max of %s *deleted* records has been reached; stopping now", + self.max_delete) + raise ImportLimitReached() + elif self.max_total and (changes + len(deleted)) >= self.max_total: + log.warning("max of %s *total changes* has been reached; stopping now", + self.max_total) + raise ImportLimitReached() + + try: + model_title = self.get_model_title() + self.app.progress_loop(delete, sorted(deletable), progress, + message=f"Deleting {model_title} records") + except ImportLimitReached: + pass + + return deleted def get_record_key(self, data): """ @@ -579,6 +689,49 @@ class Importer: message=f"Reading {model_title} data from {source_title}") return normalized + def get_unique_data(self, source_data): + """ + Return a copy of the given source data, with any duplicate + records removed. + + This looks for duplicates based on the effective key fields, + cf. :meth:`get_keys()`. The first record found with a given + key is kept; subsequent records with that key are discarded. + + This is called from :meth:`process_data()` and is done largely + for sanity's sake, to avoid indeterminate behavior when source + data contains duplicates. For instance: + + Problem #1: If source contains 2 records with key 'X' it makes + no sense to create both records on the target side. + + Problem #2: if the 2 source records have different data (apart + from their key) then which should target reflect? + + So the main point of this method is to discard the duplicates + to avoid problem #1, but do it in a deterministic way so at + least the "choice" of which record is kept will not vary + across runs; hence "pseudo-resolve" problem #2. + + :param source_data: Sequence of normalized source data. + + :returns: A 2-tuple of ``(source_data, unique_keys)`` where: + + * ``source_data`` is the final list of source data + * ``unique_keys`` is a :class:`python:set` of the source record keys + """ + unique = OrderedDict() + for data in source_data: + key = self.get_record_key(data) + if key in unique: + log.warning("duplicate %s records detected from %s for key: %s", + self.get_model_title(), + self.handler.get_source_title(), + key) + else: + unique[key] = data + return list(unique.values()), set(unique) + def get_source_objects(self): """ This method (if applicable) should return a sequence of "raw" @@ -754,6 +907,38 @@ class Importer: for field in fields]) return data + def get_deletable_keys(self, progress=None): + """ + Return a set of record keys from the target side, which are + *potentially* eligible for deletion. + + Inclusion in this set does not imply a given record/key + *should* be deleted, only that app logic (e.g. business rules) + does not prevent it. 
+ + Default logic here will look in the :attr:`cached_target` and + then call :meth:`can_delete_object()` for each record in the + cache. If that call returns true for a given key, it is + included in the result. + + :returns: The ``set`` of target record keys eligible for + deletion. + """ + if not self.caches_target: + return set() + + keys = set() + + def check(key, i): + data = self.cached_target[key]['data'] + obj = self.cached_target[key]['object'] + if self.can_delete_object(obj, data): + keys.add(key) + + self.app.progress_loop(check, set(self.cached_target), progress, + message="Determining which objects can be deleted") + return keys + ############################## # CRUD methods ############################## @@ -859,6 +1044,40 @@ class Importer: return obj + def can_delete_object(self, obj, data=None): + """ + Should return true or false indicating whether the given + object "can" be deleted. Default is to return true in all + cases. + + If you return false then the importer will know not to call + :meth:`delete_target_object()` even if the data sets imply + that it should. + + :param obj: Raw object on the target side. + + :param data: Normalized data dict for the target record, if + known. + + :returns: ``True`` if object can be deleted, else ``False``. + """ + return True + + def delete_target_object(self, obj): + """ + Delete the given raw object from the target side, and return + true if successful. + + This is called from :meth:`do_delete()`. + + Default logic for this method just returns false; subclass + should override if needed. + + :returns: Should return ``True`` if deletion succeeds, or + ``False`` if deletion failed or was skipped. + """ + return False + class FromFile(Importer): """ @@ -1005,10 +1224,9 @@ class ToSqlalchemy(Importer): """ Tries to fetch the object from target DB using ORM query. """ - # first the default logic in case target object is cached - obj = super().get_target_object(key) - if obj: - return obj + # use default logic to fetch from cache, if applicable + if self.caches_target: + return super().get_target_object(key) # okay now we must fetch via query query = self.target_session.query(self.model_class) @@ -1019,15 +1237,6 @@ class ToSqlalchemy(Importer): except orm.exc.NoResultFound: pass - def create_target_object(self, key, source_data): - """ """ - with self.target_session.no_autoflush: - obj = super().create_target_object(key, source_data) - if obj: - # nb. add new object to target db session - self.target_session.add(obj) - return obj - def get_target_objects(self, source_data=None, progress=None): """ Fetches target objects via the ORM query from @@ -1043,3 +1252,17 @@ class ToSqlalchemy(Importer): :meth:`get_target_objects()`. """ return self.target_session.query(self.model_class) + + def create_target_object(self, key, source_data): + """ """ + with self.target_session.no_autoflush: + obj = super().create_target_object(key, source_data) + if obj: + # nb. 
add new object to target db session + self.target_session.add(obj) + return obj + + def delete_target_object(self, obj): + """ """ + self.target_session.delete(obj) + return True diff --git a/src/wuttasync/importing/csv.py b/src/wuttasync/importing/csv.py index 1c62818..e1937b5 100644 --- a/src/wuttasync/importing/csv.py +++ b/src/wuttasync/importing/csv.py @@ -26,11 +26,12 @@ Importing from CSV import csv import logging +import uuid as _uuid from collections import OrderedDict from sqlalchemy_utils.functions import get_primary_keys -from wuttjamaican.db.util import make_topo_sortkey +from wuttjamaican.db.util import make_topo_sortkey, UUID from .base import FromFile from .handlers import FromFileHandler @@ -138,7 +139,54 @@ class FromCsv(FromFile): class FromCsvToSqlalchemyMixin: """ - Mixin handler class for CSV → SQLAlchemy ORM import/export. + Mixin class for CSV → SQLAlchemy ORM :term:`importers `. + + Meant to be used by :class:`FromCsvToSqlalchemyHandlerMixin`. + + This mixin adds some logic to better handle ``uuid`` key fields + which are of :class:`~wuttjamaican:wuttjamaican.db.util.UUID` data + type (i.e. on the target side). Namely, when reading ``uuid`` + values as string from CSV, convert them to proper UUID instances, + so the key matching between source and target will behave as + expected. + """ + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + + # nb. keep track of any key fields which use proper UUID type + self.uuid_keys = [] + for field in self.get_keys(): + attr = getattr(self.model_class, field) + if len(attr.prop.columns) == 1: + if isinstance(attr.prop.columns[0].type, UUID): + self.uuid_keys.append(field) + + def normalize_source_object(self, obj): + """ """ + data = dict(obj) + + # nb. convert to proper UUID values so key matching will work + # properly, where applicable + for key in self.uuid_keys: + uuid = data[key] + if uuid and not isinstance(uuid, _uuid.UUID): + data[key] = _uuid.UUID(uuid) + + return data + + +class FromCsvToSqlalchemyHandlerMixin: + """ + Mixin class for CSV → SQLAlchemy ORM :term:`import handlers + `. + + This knows how to dynamically generate :term:`importer` classes to + target the particular ORM involved. Such classes will inherit + from :class:`FromCsvToSqlalchemyMixin`, in addition to whatever + :attr:`FromImporterBase` and :attr:`ToImporterBase` reference. + + This all happens within :meth:`define_importers()`. """ source_key = 'csv' generic_source_title = "CSV" @@ -201,30 +249,39 @@ class FromCsvToSqlalchemyMixin: return importers - def make_importer_factory(self, cls, name): + def make_importer_factory(self, model_class, name): """ - Generate and return a new importer/exporter class, targeting - the given data model class. + Generate and return a new :term:`importer` class, targeting + the given :term:`data model` class. - :param cls: A data model class. + The newly-created class will inherit from: - :param name: Optional "model name" override for the - importer/exporter. + * :class:`FromCsvToSqlalchemyMixin` + * :attr:`FromImporterBase` + * :attr:`ToImporterBase` - :returns: A new class, meant to process import/export - operations which target the given data model. The new - class will inherit from both :attr:`FromImporterBase` and - :attr:`ToImporterBase`. + :param model_class: A data model class. + + :param name: The "model name" for the importer/exporter. New + class name will be based on this, so e.g. ``Widget`` model + name becomes ``WidgetImporter`` class name. 
+ + :returns: The new class, meant to process import/export + targeting the given data model. """ - return type(f'{name}Importer', (FromCsv, self.ToImporterBase), { - 'model_class': cls, - 'key': list(get_primary_keys(cls)), + return type(f'{name}Importer', + (FromCsvToSqlalchemyMixin, self.FromImporterBase, self.ToImporterBase), { + 'model_class': model_class, + 'key': list(get_primary_keys(model_class)), }) -class FromCsvToWutta(FromCsvToSqlalchemyMixin, FromFileHandler, ToWuttaHandler): +class FromCsvToWutta(FromCsvToSqlalchemyHandlerMixin, FromFileHandler, ToWuttaHandler): """ Handler for CSV → Wutta :term:`app database` import. + + This uses :class:`FromCsvToSqlalchemyHandlerMixin` for most of the + heavy lifting. """ ToImporterBase = ToWutta diff --git a/tests/importing/test_base.py b/tests/importing/test_base.py index ee8cb48..ff2ca6e 100644 --- a/tests/importing/test_base.py +++ b/tests/importing/test_base.py @@ -89,60 +89,237 @@ class TestImporter(DataTestCase): def test_process_data(self): model = self.app.model - imp = self.make_importer(model_class=model.Setting, caches_target=True) + imp = self.make_importer(model_class=model.Setting, caches_target=True, + delete=True) - # empty data set / just for coverage - with patch.object(imp, 'normalize_source_data') as normalize_source_data: - normalize_source_data.return_value = [] + def make_cache(): + setting1 = model.Setting(name='foo1', value='bar1') + setting2 = model.Setting(name='foo2', value='bar2') + setting3 = model.Setting(name='foo3', value='bar3') + cache = { + ('foo1',): { + 'object': setting1, + 'data': {'name': 'foo1', 'value': 'bar1'}, + }, + ('foo2',): { + 'object': setting2, + 'data': {'name': 'foo2', 'value': 'bar2'}, + }, + ('foo3',): { + 'object': setting3, + 'data': {'name': 'foo3', 'value': 'bar3'}, + }, + } + return cache - with patch.object(imp, 'get_target_cache') as get_target_cache: - get_target_cache.return_value = {} + # nb. 
delete always succeeds + with patch.object(imp, 'delete_target_object', return_value=True): - result = imp.process_data() - self.assertEqual(result, ([], [], [])) + # create + update + delete all as needed + with patch.object(imp, 'get_target_cache', return_value=make_cache()): + created, updated, deleted = imp.process_data([ + {'name': 'foo3', 'value': 'BAR3'}, + {'name': 'foo4', 'value': 'BAR4'}, + {'name': 'foo5', 'value': 'BAR5'}, + ]) + self.assertEqual(len(created), 2) + self.assertEqual(len(updated), 1) + self.assertEqual(len(deleted), 2) + + # same but with --max-total so delete gets skipped + with patch.object(imp, 'get_target_cache', return_value=make_cache()): + with patch.object(imp, 'max_total', new=3): + created, updated, deleted = imp.process_data([ + {'name': 'foo3', 'value': 'BAR3'}, + {'name': 'foo4', 'value': 'BAR4'}, + {'name': 'foo5', 'value': 'BAR5'}, + ]) + self.assertEqual(len(created), 2) + self.assertEqual(len(updated), 1) + self.assertEqual(len(deleted), 0) + + # delete all if source data empty + with patch.object(imp, 'get_target_cache', return_value=make_cache()): + created, updated, deleted = imp.process_data() + self.assertEqual(len(created), 0) + self.assertEqual(len(updated), 0) + self.assertEqual(len(deleted), 3) def test_do_create_update(self): model = self.app.model + imp = self.make_importer(model_class=model.Setting, caches_target=True) + + def make_cache(): + setting1 = model.Setting(name='foo1', value='bar1') + setting2 = model.Setting(name='foo2', value='bar2') + cache = { + ('foo1',): { + 'object': setting1, + 'data': {'name': 'foo1', 'value': 'bar1'}, + }, + ('foo2',): { + 'object': setting2, + 'data': {'name': 'foo2', 'value': 'bar2'}, + }, + } + return cache + + # change nothing if data matches + with patch.multiple(imp, create=True, cached_target=make_cache()): + created, updated = imp.do_create_update([ + {'name': 'foo1', 'value': 'bar1'}, + {'name': 'foo2', 'value': 'bar2'}, + ]) + self.assertEqual(len(created), 0) + self.assertEqual(len(updated), 0) + + # update all as needed + with patch.multiple(imp, create=True, cached_target=make_cache()): + created, updated = imp.do_create_update([ + {'name': 'foo1', 'value': 'BAR1'}, + {'name': 'foo2', 'value': 'BAR2'}, + ]) + self.assertEqual(len(created), 0) + self.assertEqual(len(updated), 2) + + # update all, with --max-update + with patch.multiple(imp, create=True, cached_target=make_cache(), max_update=1): + created, updated = imp.do_create_update([ + {'name': 'foo1', 'value': 'BAR1'}, + {'name': 'foo2', 'value': 'BAR2'}, + ]) + self.assertEqual(len(created), 0) + self.assertEqual(len(updated), 1) + + # update all, with --max-total + with patch.multiple(imp, create=True, cached_target=make_cache(), max_total=1): + created, updated = imp.do_create_update([ + {'name': 'foo1', 'value': 'BAR1'}, + {'name': 'foo2', 'value': 'BAR2'}, + ]) + self.assertEqual(len(created), 0) + self.assertEqual(len(updated), 1) + + # create all as needed + with patch.multiple(imp, create=True, cached_target=make_cache()): + created, updated = imp.do_create_update([ + {'name': 'foo1', 'value': 'bar1'}, + {'name': 'foo2', 'value': 'bar2'}, + {'name': 'foo3', 'value': 'BAR3'}, + {'name': 'foo4', 'value': 'BAR4'}, + ]) + self.assertEqual(len(created), 2) + self.assertEqual(len(updated), 0) + + # what happens when create gets skipped + with patch.multiple(imp, create=True, cached_target=make_cache()): + with patch.object(imp, 'create_target_object', return_value=None): + created, updated = imp.do_create_update([ + {'name': 
'foo1', 'value': 'bar1'}, + {'name': 'foo2', 'value': 'bar2'}, + {'name': 'foo3', 'value': 'BAR3'}, + {'name': 'foo4', 'value': 'BAR4'}, + ]) + self.assertEqual(len(created), 0) + self.assertEqual(len(updated), 0) + + # create all, with --max-create + with patch.multiple(imp, create=True, cached_target=make_cache(), max_create=1): + created, updated = imp.do_create_update([ + {'name': 'foo1', 'value': 'bar1'}, + {'name': 'foo2', 'value': 'bar2'}, + {'name': 'foo3', 'value': 'BAR3'}, + {'name': 'foo4', 'value': 'BAR4'}, + ]) + self.assertEqual(len(created), 1) + self.assertEqual(len(updated), 0) + + # create all, with --max-total + with patch.multiple(imp, create=True, cached_target=make_cache(), max_total=1): + created, updated = imp.do_create_update([ + {'name': 'foo1', 'value': 'bar1'}, + {'name': 'foo2', 'value': 'bar2'}, + {'name': 'foo3', 'value': 'BAR3'}, + {'name': 'foo4', 'value': 'BAR4'}, + ]) + self.assertEqual(len(created), 1) + self.assertEqual(len(updated), 0) + + # create + update all as needed + with patch.multiple(imp, create=True, cached_target=make_cache()): + created, updated = imp.do_create_update([ + {'name': 'foo1', 'value': 'BAR1'}, + {'name': 'foo2', 'value': 'BAR2'}, + {'name': 'foo3', 'value': 'BAR3'}, + {'name': 'foo4', 'value': 'BAR4'}, + ]) + self.assertEqual(len(created), 2) + self.assertEqual(len(updated), 2) + + # create + update all, with --max-total + with patch.multiple(imp, create=True, cached_target=make_cache(), max_total=1): + created, updated = imp.do_create_update([ + {'name': 'foo1', 'value': 'BAR1'}, + {'name': 'foo2', 'value': 'BAR2'}, + {'name': 'foo3', 'value': 'BAR3'}, + {'name': 'foo4', 'value': 'BAR4'}, + ]) + # nb. foo1 is updated first + self.assertEqual(len(created), 0) + self.assertEqual(len(updated), 1) + + def test_do_delete(self): + model = self.app.model # this requires a mock target cache + setting1 = model.Setting(name='foo1', value='bar1') + setting2 = model.Setting(name='foo2', value='bar2') imp = self.make_importer(model_class=model.Setting, caches_target=True) - setting = model.Setting(name='foo', value='bar') - imp.cached_target = { - ('foo',): { - 'object': setting, - 'data': {'name': 'foo', 'value': 'bar'}, + cache = { + ('foo1',): { + 'object': setting1, + 'data': {'name': 'foo1', 'value': 'bar1'}, + }, + ('foo2',): { + 'object': setting2, + 'data': {'name': 'foo2', 'value': 'bar2'}, }, } - # will update the one record - result = imp.do_create_update([{'name': 'foo', 'value': 'baz'}]) - self.assertIs(result[1][0][0], setting) - self.assertEqual(result, ([], [(setting, - # nb. target - {'name': 'foo', 'value': 'bar'}, - # nb. source - {'name': 'foo', 'value': 'baz'})])) - self.assertEqual(setting.value, 'baz') + with patch.object(imp, 'delete_target_object') as delete_target_object: - # will create a new record - result = imp.do_create_update([{'name': 'blah', 'value': 'zay'}]) - self.assertIsNot(result[0][0][0], setting) - setting_new = result[0][0][0] - self.assertEqual(result, ([(setting_new, - # nb. 
source - {'name': 'blah', 'value': 'zay'})], - [])) - self.assertEqual(setting_new.name, 'blah') - self.assertEqual(setting_new.value, 'zay') + # delete nothing if source has same keys + with patch.multiple(imp, create=True, cached_target=dict(cache)): + source_keys = set(imp.cached_target) + result = imp.do_delete(source_keys) + self.assertFalse(delete_target_object.called) + self.assertEqual(result, []) - # but what if new record is *not* created - with patch.object(imp, 'create_target_object', return_value=None): - result = imp.do_create_update([{'name': 'another', 'value': 'one'}]) - self.assertEqual(result, ([], [])) + # delete both if source has no keys + delete_target_object.reset_mock() + with patch.multiple(imp, create=True, cached_target=dict(cache)): + source_keys = set() + result = imp.do_delete(source_keys) + self.assertEqual(delete_target_object.call_count, 2) + self.assertEqual(len(result), 2) - # def test_do_delete(self): - # model = self.app.model - # imp = self.make_importer(model_class=model.Setting) + # delete just one if --max-delete was set + delete_target_object.reset_mock() + with patch.multiple(imp, create=True, cached_target=dict(cache)): + source_keys = set() + with patch.object(imp, 'max_delete', new=1): + result = imp.do_delete(source_keys) + self.assertEqual(delete_target_object.call_count, 1) + self.assertEqual(len(result), 1) + + # delete just one if --max-total was set + delete_target_object.reset_mock() + with patch.multiple(imp, create=True, cached_target=dict(cache)): + source_keys = set() + with patch.object(imp, 'max_total', new=1): + result = imp.do_delete(source_keys) + self.assertEqual(delete_target_object.call_count, 1) + self.assertEqual(len(result), 1) def test_get_record_key(self): model = self.app.model @@ -182,6 +359,22 @@ class TestImporter(DataTestCase): # nb. default normalizer returns object as-is self.assertIs(data[0], setting) + def test_get_unique_data(self): + model = self.app.model + imp = self.make_importer(model_class=model.Setting) + + setting1 = model.Setting(name='foo', value='bar1') + setting2 = model.Setting(name='foo', value='bar2') + + result = imp.get_unique_data([setting2, setting1]) + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + self.assertIsInstance(result[0], list) + self.assertEqual(len(result[0]), 1) + self.assertIs(result[0][0], setting2) # nb. not setting1 + self.assertIsInstance(result[1], set) + self.assertEqual(result[1], {('foo',)}) + def test_get_source_objects(self): model = self.app.model imp = self.make_importer(model_class=model.Setting) @@ -263,6 +456,34 @@ class TestImporter(DataTestCase): data = imp.normalize_target_object(setting) self.assertEqual(data, {'name': 'foo', 'value': 'bar'}) + def test_get_deletable_keys(self): + model = self.app.model + imp = self.make_importer(model_class=model.Setting) + + # empty set by default (nb. 
no target cache) + result = imp.get_deletable_keys() + self.assertIsInstance(result, set) + self.assertEqual(result, set()) + + setting = model.Setting(name='foo', value='bar') + cache = { + ('foo',): { + 'object': setting, + 'data': {'name': 'foo', 'value': 'bar'}, + }, + } + + with patch.multiple(imp, create=True, caches_target=True, cached_target=cache): + + # all are deletable by default + result = imp.get_deletable_keys() + self.assertEqual(result, {('foo',)}) + + # but some maybe can't be deleted + with patch.object(imp, 'can_delete_object', return_value=False): + result = imp.get_deletable_keys() + self.assertEqual(result, set()) + def test_create_target_object(self): model = self.app.model imp = self.make_importer(model_class=model.Setting) @@ -301,6 +522,19 @@ class TestImporter(DataTestCase): self.assertIs(obj, setting) self.assertEqual(setting.value, 'bar') + def test_can_delete_object(self): + model = self.app.model + imp = self.make_importer(model_class=model.Setting) + setting = model.Setting(name='foo') + self.assertTrue(imp.can_delete_object(setting)) + + def test_delete_target_object(self): + model = self.app.model + imp = self.make_importer(model_class=model.Setting) + setting = model.Setting(name='foo') + # nb. default implementation always returns false + self.assertFalse(imp.delete_target_object(setting)) + class TestFromFile(DataTestCase): @@ -390,6 +624,20 @@ class TestToSqlalchemy(DataTestCase): kwargs.setdefault('handler', self.handler) return mod.ToSqlalchemy(self.config, **kwargs) + def test_get_target_objects(self): + model = self.app.model + imp = self.make_importer(model_class=model.Setting, target_session=self.session) + + setting1 = model.Setting(name='foo', value='bar') + self.session.add(setting1) + setting2 = model.Setting(name='foo2', value='bar2') + self.session.add(setting2) + self.session.commit() + + result = imp.get_target_objects() + self.assertEqual(len(result), 2) + self.assertEqual(set(result), {setting1, setting2}) + def test_get_target_object(self): model = self.app.model setting = model.Setting(name='foo', value='bar') @@ -416,15 +664,19 @@ class TestToSqlalchemy(DataTestCase): self.session.add(setting2) self.session.commit() - # then we should be able to fetch that via query - imp.target_session = self.session - result = imp.get_target_object(('foo2',)) - self.assertIsInstance(result, model.Setting) - self.assertIs(result, setting2) + # nb. 
disable target cache + with patch.multiple(imp, create=True, + target_session=self.session, + caches_target=False): - # but sometimes it will not be found - result = imp.get_target_object(('foo3',)) - self.assertIsNone(result) + # now we should be able to fetch that via query + result = imp.get_target_object(('foo2',)) + self.assertIsInstance(result, model.Setting) + self.assertIs(result, setting2) + + # but sometimes it will not be found + result = imp.get_target_object(('foo3',)) + self.assertIsNone(result) def test_create_target_object(self): model = self.app.model @@ -438,16 +690,13 @@ class TestToSqlalchemy(DataTestCase): self.assertEqual(setting.value, 'bar') self.assertIn(setting, self.session) - def test_get_target_objects(self): + def test_delete_target_object(self): model = self.app.model + + setting = model.Setting(name='foo', value='bar') + self.session.add(setting) + + self.assertEqual(self.session.query(model.Setting).count(), 1) imp = self.make_importer(model_class=model.Setting, target_session=self.session) - - setting1 = model.Setting(name='foo', value='bar') - self.session.add(setting1) - setting2 = model.Setting(name='foo2', value='bar2') - self.session.add(setting2) - self.session.commit() - - result = imp.get_target_objects() - self.assertEqual(len(result), 2) - self.assertEqual(set(result), {setting1, setting2}) + imp.delete_target_object(setting) + self.assertEqual(self.session.query(model.Setting).count(), 0) diff --git a/tests/importing/test_csv.py b/tests/importing/test_csv.py index 683215e..dc65e54 100644 --- a/tests/importing/test_csv.py +++ b/tests/importing/test_csv.py @@ -1,6 +1,7 @@ #-*- coding: utf-8; -*- import csv +import uuid as _uuid from unittest.mock import patch from wuttjamaican.testing import DataTestCase @@ -87,23 +88,74 @@ foo2,bar2 self.assertEqual(objects[1], {'name': 'foo2', 'value': 'bar2'}) -class MockMixinHandler(mod.FromCsvToSqlalchemyMixin, ToSqlalchemyHandler): - ToImporterBase = ToSqlalchemy +class MockMixinImporter(mod.FromCsvToSqlalchemyMixin, mod.FromCsv, ToSqlalchemy): + pass class TestFromCsvToSqlalchemyMixin(DataTestCase): + def setUp(self): + self.setup_db() + self.handler = ImportHandler(self.config) + + def make_importer(self, **kwargs): + kwargs.setdefault('handler', self.handler) + return MockMixinImporter(self.config, **kwargs) + + def test_constructor(self): + model = self.app.model + + # no uuid keys + imp = self.make_importer(model_class=model.Setting) + self.assertEqual(imp.uuid_keys, []) + + # typical + # nb. as of now Upgrade is the only table using proper UUID + imp = self.make_importer(model_class=model.Upgrade) + self.assertEqual(imp.uuid_keys, ['uuid']) + + def test_normalize_source_object(self): + model = self.app.model + + # no uuid keys + imp = self.make_importer(model_class=model.Setting) + result = imp.normalize_source_object({'name': 'foo', 'value': 'bar'}) + self.assertEqual(result, {'name': 'foo', 'value': 'bar'}) + + # source has proper UUID + # nb. as of now Upgrade is the only table using proper UUID + imp = self.make_importer(model_class=model.Upgrade, fields=['uuid', 'description']) + result = imp.normalize_source_object({'uuid': _uuid.UUID('06753693-d892-77f0-8000-ce71bf7ebbba'), + 'description': 'testing'}) + self.assertEqual(result, {'uuid': _uuid.UUID('06753693-d892-77f0-8000-ce71bf7ebbba'), + 'description': 'testing'}) + + # source has string uuid + # nb. 
as of now Upgrade is the only table using proper UUID + imp = self.make_importer(model_class=model.Upgrade, fields=['uuid', 'description']) + result = imp.normalize_source_object({'uuid': '06753693d89277f08000ce71bf7ebbba', + 'description': 'testing'}) + self.assertEqual(result, {'uuid': _uuid.UUID('06753693-d892-77f0-8000-ce71bf7ebbba'), + 'description': 'testing'}) + + +class MockMixinHandler(mod.FromCsvToSqlalchemyHandlerMixin, ToSqlalchemyHandler): + ToImporterBase = ToSqlalchemy + + +class TestFromCsvToSqlalchemyHandlerMixin(DataTestCase): + def make_handler(self, **kwargs): return MockMixinHandler(self.config, **kwargs) def test_get_target_model(self): - with patch.object(mod.FromCsvToSqlalchemyMixin, 'define_importers', return_value={}): + with patch.object(mod.FromCsvToSqlalchemyHandlerMixin, 'define_importers', return_value={}): handler = self.make_handler() self.assertRaises(NotImplementedError, handler.get_target_model) def test_define_importers(self): model = self.app.model - with patch.object(mod.FromCsvToSqlalchemyMixin, 'get_target_model', return_value=model): + with patch.object(mod.FromCsvToSqlalchemyHandlerMixin, 'get_target_model', return_value=model): handler = self.make_handler() importers = handler.define_importers() self.assertIn('Setting', importers) @@ -115,7 +167,7 @@ class TestFromCsvToSqlalchemyMixin(DataTestCase): def test_make_importer_factory(self): model = self.app.model - with patch.object(mod.FromCsvToSqlalchemyMixin, 'define_importers', return_value={}): + with patch.object(mod.FromCsvToSqlalchemyHandlerMixin, 'define_importers', return_value={}): handler = self.make_handler() factory = handler.make_importer_factory(model.Setting, 'Setting') self.assertTrue(issubclass(factory, mod.FromCsv))
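A few notes on how the pieces above fit together. The new --max-*
options are defined on import_command_template(), so every
import/export command built from that template signature inherits
them; typer derives each flag name from its parameter name
(max_create becomes --max-create, and so on). A hypothetical
invocation, assuming a command named `wutta import-csv` built from
this template:

    # cap each change type separately, plus the overall total;
    # --dry-run still rolls everything back at the end
    wutta import-csv --delete \
        --max-create 10 --max-update 50 --max-delete 5 \
        --max-total 100 --dry-run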
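The deletion flow composes as: process_data() -> do_delete(source_keys,
changes) -> get_deletable_keys() minus source_keys ->
delete_target_object() per remaining key, with can_delete_object()
acting as the veto point for business rules. A minimal sketch of a
concrete importer using the new hooks; the Widget model and its
"active" rule are made up for illustration, and this assumes delete
may be set as a class attribute just like the new max_* caps:

    from wuttasync.importing import ToSqlalchemy

    class WidgetImporter(ToSqlalchemy):
        model_class = Widget   # nb. hypothetical model class
        delete = True          # nb. opt in to deletion logic

        def can_delete_object(self, obj, data=None):
            # nb. hypothetical business rule: never purge active widgets
            return not obj.active

    # nb. no delete_target_object() override is needed here, since
    # ToSqlalchemy.delete_target_object() already removes the object
    # via target_session.delete() and returns True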
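Note how --max-total is enforced across both phases:
do_create_update() raises ImportLimitReached once len(created) +
len(updated) hits the cap, and process_data() then runs the delete
phase only if any budget remains, passing the running change count
into do_delete(). Mirroring the new test_process_data case, with a
hypothetical importer imp whose max_total is 3, fed the same
three-row source data as the test:

    created, updated, deleted = imp.process_data(rows)
    # 2 creates + 1 update exhaust the budget of 3,
    # so the delete phase is skipped entirely
    assert len(created) + len(updated) == 3
    assert deleted == []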
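get_unique_data() is deliberately first-one-wins and order-preserving
(an OrderedDict keyed by record key), which is what makes the
duplicate handling deterministic across runs. Assuming an importer
imp keyed on ('name',):

    rows = [
        {'name': 'foo', 'value': 'bar1'},
        {'name': 'foo', 'value': 'bar2'},  # nb. dupe key; logged, discarded
        {'name': 'baz', 'value': 'bar3'},
    ]
    data, keys = imp.get_unique_data(rows)
    assert data == [rows[0], rows[2]]
    assert keys == {('foo',), ('baz',)}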
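Finally, the csv.py change guards against a quiet failure mode:
csv.DictReader yields every field as str, while a cached target key
built from a proper UUID column holds uuid.UUID values, so without
normalization no source key would ever match; every record would look
new, and with --delete enabled the whole target could be considered
deletable. The mismatch in isolation:

    import uuid

    key_from_csv = ('06753693d89277f08000ce71bf7ebbba',)
    key_on_target = (uuid.UUID('06753693-d892-77f0-8000-ce71bf7ebbba'),)

    assert key_from_csv != key_on_target                   # str vs. UUID
    assert (uuid.UUID(key_from_csv[0]),) == key_on_target  # normalized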