Assign extra import handler kwargs before loading the importers

e.g. for CORE importers which can target different DB and therefore
the set of importers must vary by DB
This commit is contained in:
Lance Edgar 2024-05-08 13:35:06 -05:00
parent 033a0c27ec
commit 329d051d76
2 changed files with 24 additions and 18 deletions

View file

@ -2,7 +2,7 @@
################################################################################ ################################################################################
# #
# Rattail -- Retail Software Framework # Rattail -- Retail Software Framework
# Copyright © 2010-2023 Lance Edgar # Copyright © 2010-2024 Lance Edgar
# #
# This file is part of Rattail. # This file is part of Rattail.
# #
@ -35,9 +35,7 @@ import humanize
import markupsafe import markupsafe
import sqlalchemy as sa import sqlalchemy as sa
from rattail.core import get_uuid from rattail.util import get_object_spec
from rattail.time import make_utc
from rattail.util import get_object_spec, progress_loop
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -96,11 +94,15 @@ class ImportHandler(object):
self.app = self.config.get_app() self.app = self.config.get_app()
self.enum = self.config.get_enum() self.enum = self.config.get_enum()
self.model = self.config.get_model() self.model = self.config.get_model()
self.importers = self.get_importers()
self.extra_importer_kwargs = kwargs.pop('extra_importer_kwargs', {}) # nb. must assign extra attrs before we get_importers() since
# attrs may need to affect that behavior
for key, value in kwargs.items(): for key, value in kwargs.items():
setattr(self, key, value) setattr(self, key, value)
self.importers = self.get_importers()
self.extra_importer_kwargs = kwargs.pop('extra_importer_kwargs', {})
@classmethod @classmethod
def get_key(cls): def get_key(cls):
return 'to_{}.from_{}.{}'.format(cls.local_key, cls.host_key, cls.direction) return 'to_{}.from_{}.{}'.format(cls.local_key, cls.host_key, cls.direction)
@ -213,7 +215,7 @@ class ImportHandler(object):
importer._handler_key = key importer._handler_key = key
batch = model.ImporterBatch() batch = model.ImporterBatch()
batch.uuid = get_uuid() batch.uuid = self.app.make_uuid()
batch.created_by = user batch.created_by = user
batch.batch_handler_spec = handler_spec batch.batch_handler_spec = handler_spec
batch.import_handler_spec = get_object_spec(self) batch.import_handler_spec = get_object_spec(self)
@ -290,7 +292,7 @@ class ImportHandler(object):
return row_table return row_table
def populate_row_table(self, session, importer, batch, row_table): def populate_row_table(self, session, importer, batch, row_table):
importer.now = make_utc(tzinfo=True) importer.now = self.app.make_utc(tzinfo=True)
importer.setup() importer.setup()
# obtain host data # obtain host data
@ -323,8 +325,8 @@ class ImportHandler(object):
status_code=self.enum.IMPORTER_BATCH_ROW_STATUS_DELETE) status_code=self.enum.IMPORTER_BATCH_ROW_STATUS_DELETE)
batch.rowcount += 1 batch.rowcount += 1
progress_loop(delete, sorted(deleting), self.progress, self.app.progress_loop(delete, sorted(deleting), self.progress,
message="Deleting {} data".format(importer.model_name)) message=f"Deleting {importer.model_name} data")
def _populate_create_update(self, session, importer, batch, row_table, data): def _populate_create_update(self, session, importer, batch, row_table, data):
@ -358,12 +360,12 @@ class ImportHandler(object):
status_code=status_code, status_text=status_text) status_code=status_code, status_text=status_text)
batch.rowcount += 1 batch.rowcount += 1
progress_loop(record, data, self.progress, self.app.progress_loop(record, data, self.progress,
message="Populating batch for {}".format(importer._handler_key)) message=f"Populating batch for {importer._handler_key}")
def make_batch_row(self, session, importer, row_table, sequence, host_data, local_data, status_code=None, status_text=None): def make_batch_row(self, session, importer, row_table, sequence, host_data, local_data, status_code=None, status_text=None):
values = { values = {
'uuid': get_uuid(), 'uuid': self.app.make_uuid(),
'sequence': sequence, 'sequence': sequence,
'object_str': '', 'object_str': '',
'status_code': status_code, 'status_code': status_code,
@ -405,7 +407,7 @@ class ImportHandler(object):
attribute after this method completes. This would be a dictionary attribute after this method completes. This would be a dictionary
whose keys are model names and values are the importer instances. whose keys are model names and values are the importer instances.
""" """
self.import_began = make_utc(tzinfo=True) self.import_began = self.app.make_utc(tzinfo=True)
retain_used_importers = kwargs.pop('retain_used_importers', False) retain_used_importers = kwargs.pop('retain_used_importers', False)
if 'dry_run' in kwargs: if 'dry_run' in kwargs:
self.dry_run = kwargs['dry_run'] self.dry_run = kwargs['dry_run']
@ -532,7 +534,7 @@ class ImportHandler(object):
if not self.warnings: if not self.warnings:
return return
now = make_utc(tzinfo=True) now = self.app.make_utc(tzinfo=True)
data = { data = {
'local_title': self.local_title, 'local_title': self.local_title,
'host_title': self.host_title, 'host_title': self.host_title,
@ -571,7 +573,7 @@ class BulkImportHandler(ImportHandler):
Import all data for the given importer/model keys. Import all data for the given importer/model keys.
""" """
# TODO: still need to refactor much of this so can share with parent class # TODO: still need to refactor much of this so can share with parent class
self.import_began = make_utc(tzinfo=True) self.import_began = self.app.make_utc(tzinfo=True)
if 'dry_run' in kwargs: if 'dry_run' in kwargs:
self.dry_run = kwargs['dry_run'] self.dry_run = kwargs['dry_run']
self.progress = kwargs.pop('progress', getattr(self, 'progress', None)) self.progress = kwargs.pop('progress', getattr(self, 'progress', None))

View file

@ -8,7 +8,6 @@ import pytz
from sqlalchemy import orm from sqlalchemy import orm
from mock import patch, Mock from mock import patch, Mock
from rattail.db import Session
from rattail.importing import handlers, Importer from rattail.importing import handlers, Importer
from rattail.config import make_config from rattail.config import make_config
from .. import RattailTestCase from .. import RattailTestCase
@ -165,11 +164,13 @@ class ImportHandlerBattery(ImporterTester):
handler.rollback_local_transaction() handler.rollback_local_transaction()
def test_import_data(self): def test_import_data(self):
self.config = self.make_config()
handler = self.make_handler() handler = self.make_handler()
result = handler.import_data() result = handler.import_data()
self.assertEqual(result, {}) self.assertEqual(result, {})
def test_import_data_dry_run(self): def test_import_data_dry_run(self):
self.config = self.make_config()
# as init kwarg # as init kwarg
handler = self.make_handler(dry_run=True) handler = self.make_handler(dry_run=True)
@ -224,6 +225,7 @@ class ImportHandlerBattery(ImporterTester):
process.assert_called_once_with({'Foo': ([1], [2], [3])}) process.assert_called_once_with({'Foo': ([1], [2], [3])})
def test_import_data_commit_host_partial(self): def test_import_data_commit_host_partial(self):
self.config = self.make_config()
handler = self.make_handler() handler = self.make_handler()
importer = Mock() importer = Mock()
importer.import_data.side_effect = ValueError importer.import_data.side_effect = ValueError
@ -244,6 +246,7 @@ class ImportHandlerBattery(ImporterTester):
class BulkImportHandlerBattery(ImportHandlerBattery): class BulkImportHandlerBattery(ImportHandlerBattery):
def test_import_data_invalid_model(self): def test_import_data_invalid_model(self):
self.config = self.make_config()
handler = self.make_handler() handler = self.make_handler()
importer = Mock() importer = Mock()
importer.import_data.return_value = 0 importer.import_data.return_value = 0
@ -262,6 +265,7 @@ class BulkImportHandlerBattery(ImportHandlerBattery):
self.assertFalse(importer.called) self.assertFalse(importer.called)
def test_import_data_with_changes(self): def test_import_data_with_changes(self):
self.config = self.make_config()
handler = self.make_handler() handler = self.make_handler()
importer = Mock() importer = Mock()
FooImporter = Mock(return_value=importer) FooImporter = Mock(return_value=importer)
@ -387,7 +391,7 @@ class TestImportHandlerImportData(ImporterTester, unittest.TestCase):
self.result = [], [], [] self.result = [], [], []
def test_invalid_importer_key_is_ignored(self): def test_invalid_importer_key_is_ignored(self):
handler = handlers.ImportHandler() handler = handlers.ImportHandler(self.config)
self.assertNotIn('InvalidKey', handler.importers) self.assertNotIn('InvalidKey', handler.importers)
self.assertEqual(handler.import_data('InvalidKey'), {}) self.assertEqual(handler.import_data('InvalidKey'), {})