Add BulkImporter
and BulkImportHandler
base classes
This commit is contained in:
parent
10040a8c3b
commit
7a2ef35518
|
@ -26,9 +26,9 @@ Data Importing Framework
|
|||
|
||||
from __future__ import unicode_literals, absolute_import
|
||||
|
||||
from .importers import Importer, FromQuery
|
||||
from .importers import Importer, FromQuery, BulkImporter
|
||||
from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy
|
||||
from .postgresql import BulkToPostgreSQL
|
||||
from .handlers import ImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler, BulkToPostgreSQLHandler
|
||||
from .handlers import ImportHandler, BulkImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler
|
||||
from .rattail import FromRattailHandler, ToRattailHandler
|
||||
from . import model
|
||||
|
|
|
@ -229,6 +229,55 @@ class ImportHandler(object):
|
|||
log.info("warning email was sent for {} -> {} import".format(self.host_title, self.local_title))
|
||||
|
||||
|
||||
class BulkImportHandler(ImportHandler):
|
||||
"""
|
||||
Base class for bulk import handlers.
|
||||
"""
|
||||
|
||||
def import_data(self, *keys, **kwargs):
|
||||
"""
|
||||
Import all data for the given importer/model keys.
|
||||
"""
|
||||
# TODO: still need to refactor much of this so can share with parent class
|
||||
self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True)
|
||||
if 'dry_run' in kwargs:
|
||||
self.dry_run = kwargs['dry_run']
|
||||
self.progress = kwargs.pop('progress', getattr(self, 'progress', None))
|
||||
self.warnings = kwargs.pop('warnings', False)
|
||||
kwargs.update({'dry_run': self.dry_run,
|
||||
'progress': self.progress})
|
||||
self.setup()
|
||||
self.begin_transaction()
|
||||
changes = OrderedDict()
|
||||
|
||||
try:
|
||||
for key in keys:
|
||||
importer = self.get_importer(key, **kwargs)
|
||||
if not importer:
|
||||
log.warning("skipping unknown importer: {}".format(key))
|
||||
continue
|
||||
|
||||
created = importer.import_data()
|
||||
log.info("{} -> {}: added {}, updated 0, deleted 0 {} records".format(
|
||||
self.host_title, self.local_title, created, key))
|
||||
if created:
|
||||
changes[key] = created
|
||||
except:
|
||||
if self.commit_host_partial and not self.dry_run:
|
||||
log.warning("{host} -> {local}: committing partial transaction on host {host} (despite error)".format(
|
||||
host=self.host_title, local=self.local_title))
|
||||
self.commit_host_transaction()
|
||||
raise
|
||||
else:
|
||||
if self.dry_run:
|
||||
self.rollback_transaction()
|
||||
else:
|
||||
self.commit_transaction()
|
||||
|
||||
self.teardown()
|
||||
return changes
|
||||
|
||||
|
||||
class FromSQLAlchemyHandler(ImportHandler):
|
||||
"""
|
||||
Handler for imports for which the host data source is represented by a
|
||||
|
@ -292,42 +341,7 @@ class ToSQLAlchemyHandler(ImportHandler):
|
|||
self.session = None
|
||||
|
||||
|
||||
class BulkToPostgreSQLHandler(ToSQLAlchemyHandler):
|
||||
class BulkToPostgreSQLHandler(BulkImportHandler):
|
||||
"""
|
||||
Handler for bulk imports which target PostgreSQL on the local side.
|
||||
"""
|
||||
|
||||
def import_data(self, *keys, **kwargs):
|
||||
"""
|
||||
Import all data for the given importer/model keys.
|
||||
"""
|
||||
# TODO: still need to refactor much of this so can share with parent class
|
||||
self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True)
|
||||
if 'dry_run' in kwargs:
|
||||
self.dry_run = kwargs['dry_run']
|
||||
self.progress = kwargs.pop('progress', getattr(self, 'progress', None))
|
||||
self.warnings = kwargs.pop('warnings', False)
|
||||
kwargs.update({'dry_run': self.dry_run,
|
||||
'progress': self.progress})
|
||||
self.setup()
|
||||
self.begin_transaction()
|
||||
changes = OrderedDict()
|
||||
|
||||
for key in keys:
|
||||
importer = self.get_importer(key, **kwargs)
|
||||
if not importer:
|
||||
log.warning("skipping unknown importer: {}".format(key))
|
||||
continue
|
||||
|
||||
created = importer.import_data()
|
||||
log.info("{} -> {}: added {}, updated 0, deleted 0 {} records".format(
|
||||
self.host_title, self.local_title, created, key))
|
||||
if created:
|
||||
changes[key] = created
|
||||
|
||||
if self.dry_run:
|
||||
self.rollback_transaction()
|
||||
else:
|
||||
self.commit_transaction()
|
||||
self.teardown()
|
||||
return changes
|
||||
|
|
|
@ -456,3 +456,54 @@ class FromQuery(Importer):
|
|||
Returns (raw) query results as a sequence.
|
||||
"""
|
||||
return QuerySequence(self.query())
|
||||
|
||||
|
||||
class BulkImporter(Importer):
|
||||
"""
|
||||
Base class for bulk data importers which target PostgreSQL on the local side.
|
||||
"""
|
||||
|
||||
def import_data(self, host_data=None, now=None, **kwargs):
|
||||
self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
|
||||
if kwargs:
|
||||
self._setup(**kwargs)
|
||||
self.setup()
|
||||
if host_data is None:
|
||||
host_data = self.normalize_host_data()
|
||||
created = self._import_create(host_data)
|
||||
self.teardown()
|
||||
return created
|
||||
|
||||
def _import_create(self, data):
|
||||
count = len(data)
|
||||
if not count:
|
||||
return 0
|
||||
created = count
|
||||
|
||||
prog = None
|
||||
if self.progress:
|
||||
prog = self.progress("Importing {} data".format(self.model_name), count)
|
||||
|
||||
for i, host_data in enumerate(data, 1):
|
||||
|
||||
key = self.get_key(host_data)
|
||||
self.create_object(key, host_data)
|
||||
if self.max_create and i >= self.max_create:
|
||||
log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
|
||||
created = i
|
||||
break
|
||||
|
||||
if prog:
|
||||
prog.update(i)
|
||||
if prog:
|
||||
prog.destroy()
|
||||
|
||||
self.flush_create()
|
||||
return created
|
||||
|
||||
def flush_create(self):
|
||||
"""
|
||||
Perform any final steps to "flush" the created data here. Note that
|
||||
the importer's handler is still responsible for actually committing
|
||||
changes to the local system, if applicable.
|
||||
"""
|
||||
|
|
|
@ -30,14 +30,14 @@ import os
|
|||
import datetime
|
||||
import logging
|
||||
|
||||
from rattail.importing.sqlalchemy import ToSQLAlchemy
|
||||
from rattail.importing import BulkImporter, ToSQLAlchemy
|
||||
from rattail.time import make_utc
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BulkToPostgreSQL(ToSQLAlchemy):
|
||||
class BulkToPostgreSQL(BulkImporter, ToSQLAlchemy):
|
||||
"""
|
||||
Base class for bulk data importers which target PostgreSQL on the local side.
|
||||
"""
|
||||
|
@ -55,44 +55,6 @@ class BulkToPostgreSQL(ToSQLAlchemy):
|
|||
os.remove(self.data_path)
|
||||
self.data_buffer = None
|
||||
|
||||
def import_data(self, host_data=None, now=None, **kwargs):
|
||||
self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
|
||||
if kwargs:
|
||||
self._setup(**kwargs)
|
||||
self.setup()
|
||||
if host_data is None:
|
||||
host_data = self.normalize_host_data()
|
||||
created = self._import_create(host_data)
|
||||
self.teardown()
|
||||
return created
|
||||
|
||||
def _import_create(self, data):
|
||||
count = len(data)
|
||||
if not count:
|
||||
return 0
|
||||
created = count
|
||||
|
||||
prog = None
|
||||
if self.progress:
|
||||
prog = self.progress("Importing {} data".format(self.model_name), count)
|
||||
|
||||
for i, host_data in enumerate(data, 1):
|
||||
|
||||
key = self.get_key(host_data)
|
||||
self.create_object(key, host_data)
|
||||
if self.max_create and i >= self.max_create:
|
||||
log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
|
||||
created = i
|
||||
break
|
||||
|
||||
if prog:
|
||||
prog.update(i)
|
||||
if prog:
|
||||
prog.destroy()
|
||||
|
||||
self.commit_create()
|
||||
return created
|
||||
|
||||
def create_object(self, key, data):
|
||||
data = self.prep_data_for_postgres(data)
|
||||
self.data_buffer.write('{}\n'.format('\t'.join([data[field] for field in self.fields])).encode('utf-8'))
|
||||
|
@ -121,7 +83,7 @@ class BulkToPostgreSQL(ToSQLAlchemy):
|
|||
|
||||
return unicode(value)
|
||||
|
||||
def commit_create(self):
|
||||
def flush_create(self):
|
||||
log.info("copying {} data from buffer to PostgreSQL".format(self.model_name))
|
||||
self.data_buffer.close()
|
||||
self.data_buffer = open(self.data_path, 'rb')
|
||||
|
|
|
@ -31,7 +31,7 @@ from rattail.util import OrderedDict
|
|||
from rattail.importing.rattail import FromRattailToRattail, FromRattail
|
||||
|
||||
|
||||
class BulkFromRattailToRattail(FromRattailToRattail, importing.BulkToPostgreSQLHandler):
|
||||
class BulkFromRattailToRattail(FromRattailToRattail, importing.BulkImportHandler):
|
||||
"""
|
||||
Handler for Rattail -> Rattail bulk data import.
|
||||
"""
|
||||
|
|
|
@ -19,12 +19,12 @@ from rattail.tests.importing.test_importers import MockImporter
|
|||
from rattail.tests.importing.test_postgresql import MockBulkImporter
|
||||
|
||||
|
||||
class TestImportHandler(unittest.TestCase):
|
||||
class ImportHandlerBattery(ImporterTester):
|
||||
|
||||
def test_init(self):
|
||||
|
||||
# vanilla
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.handler_class()
|
||||
self.assertEqual(handler.importers, {})
|
||||
self.assertEqual(handler.get_importers(), {})
|
||||
self.assertEqual(handler.get_importer_keys(), [])
|
||||
|
@ -32,34 +32,34 @@ class TestImportHandler(unittest.TestCase):
|
|||
self.assertFalse(handler.commit_host_partial)
|
||||
|
||||
# with config
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.handler_class()
|
||||
self.assertIsNone(handler.config)
|
||||
config = RattailConfig()
|
||||
handler = handlers.ImportHandler(config=config)
|
||||
handler = self.handler_class(config=config)
|
||||
self.assertIs(handler.config, config)
|
||||
|
||||
# dry run
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.handler_class()
|
||||
self.assertFalse(handler.dry_run)
|
||||
handler = handlers.ImportHandler(dry_run=True)
|
||||
handler = self.handler_class(dry_run=True)
|
||||
self.assertTrue(handler.dry_run)
|
||||
|
||||
# extra kwarg
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.handler_class()
|
||||
self.assertRaises(AttributeError, getattr, handler, 'foo')
|
||||
handler = handlers.ImportHandler(foo='bar')
|
||||
handler = self.handler_class(foo='bar')
|
||||
self.assertEqual(handler.foo, 'bar')
|
||||
|
||||
def test_get_importer(self):
|
||||
get_importers = Mock(return_value={'foo': Importer})
|
||||
|
||||
# no importers
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
self.assertIsNone(handler.get_importer('foo'))
|
||||
|
||||
# no config
|
||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
||||
handler = handlers.ImportHandler()
|
||||
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||
handler = self.handler_class()
|
||||
importer = handler.get_importer('foo')
|
||||
self.assertIs(type(importer), Importer)
|
||||
self.assertIsNone(importer.config)
|
||||
|
@ -67,26 +67,26 @@ class TestImportHandler(unittest.TestCase):
|
|||
|
||||
# with config
|
||||
config = RattailConfig()
|
||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
||||
handler = handlers.ImportHandler(config=config)
|
||||
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||
handler = self.handler_class(config=config)
|
||||
importer = handler.get_importer('foo')
|
||||
self.assertIs(type(importer), Importer)
|
||||
self.assertIs(importer.config, config)
|
||||
self.assertIs(importer.handler, handler)
|
||||
|
||||
# dry run
|
||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
||||
handler = handlers.ImportHandler()
|
||||
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||
handler = self.handler_class()
|
||||
importer = handler.get_importer('foo')
|
||||
self.assertFalse(importer.dry_run)
|
||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
||||
handler = handlers.ImportHandler(dry_run=True)
|
||||
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||
handler = self.handler_class(dry_run=True)
|
||||
importer = handler.get_importer('foo')
|
||||
self.assertTrue(handler.dry_run)
|
||||
|
||||
# host title
|
||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
||||
handler = handlers.ImportHandler()
|
||||
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||
handler = self.handler_class()
|
||||
importer = handler.get_importer('foo')
|
||||
self.assertIsNone(importer.host_system_title)
|
||||
handler.host_title = "Foo"
|
||||
|
@ -94,8 +94,8 @@ class TestImportHandler(unittest.TestCase):
|
|||
self.assertEqual(importer.host_system_title, "Foo")
|
||||
|
||||
# extra kwarg
|
||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
||||
handler = handlers.ImportHandler()
|
||||
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||
handler = self.handler_class()
|
||||
importer = handler.get_importer('foo')
|
||||
self.assertRaises(AttributeError, getattr, importer, 'bar')
|
||||
importer = handler.get_importer('foo', bar='baz')
|
||||
|
@ -104,15 +104,15 @@ class TestImportHandler(unittest.TestCase):
|
|||
def test_get_importer_kwargs(self):
|
||||
|
||||
# empty by default
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
self.assertEqual(handler.get_importer_kwargs('foo'), {})
|
||||
|
||||
# extra kwargs are preserved
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
self.assertEqual(handler.get_importer_kwargs('foo', bar='baz'), {'bar': 'baz'})
|
||||
|
||||
def test_begin_transaction(self):
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
with patch.object(handler, 'begin_host_transaction') as begin_host:
|
||||
with patch.object(handler, 'begin_local_transaction') as begin_local:
|
||||
handler.begin_transaction()
|
||||
|
@ -120,15 +120,15 @@ class TestImportHandler(unittest.TestCase):
|
|||
begin_local.assert_called_once_with()
|
||||
|
||||
def test_begin_host_transaction(self):
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
handler.begin_host_transaction()
|
||||
|
||||
def test_begin_local_transaction(self):
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
handler.begin_local_transaction()
|
||||
|
||||
def test_commit_transaction(self):
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
with patch.object(handler, 'commit_host_transaction') as commit_host:
|
||||
with patch.object(handler, 'commit_local_transaction') as commit_local:
|
||||
handler.commit_transaction()
|
||||
|
@ -136,15 +136,15 @@ class TestImportHandler(unittest.TestCase):
|
|||
commit_local.assert_called_once_with()
|
||||
|
||||
def test_commit_host_transaction(self):
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
handler.commit_host_transaction()
|
||||
|
||||
def test_commit_local_transaction(self):
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
handler.commit_local_transaction()
|
||||
|
||||
def test_rollback_transaction(self):
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
with patch.object(handler, 'rollback_host_transaction') as rollback_host:
|
||||
with patch.object(handler, 'rollback_local_transaction') as rollback_local:
|
||||
handler.rollback_transaction()
|
||||
|
@ -152,24 +152,22 @@ class TestImportHandler(unittest.TestCase):
|
|||
rollback_local.assert_called_once_with()
|
||||
|
||||
def test_rollback_host_transaction(self):
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
handler.rollback_host_transaction()
|
||||
|
||||
def test_rollback_local_transaction(self):
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
handler.rollback_local_transaction()
|
||||
|
||||
def test_import_data(self):
|
||||
|
||||
# normal
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
result = handler.import_data()
|
||||
self.assertEqual(result, {})
|
||||
|
||||
def test_import_data_dry_run(self):
|
||||
|
||||
# as init kwarg
|
||||
handler = handlers.ImportHandler(dry_run=True)
|
||||
handler = self.make_handler(dry_run=True)
|
||||
with patch.object(handler, 'commit_transaction') as commit:
|
||||
with patch.object(handler, 'rollback_transaction') as rollback:
|
||||
handler.import_data()
|
||||
|
@ -178,7 +176,7 @@ class TestImportHandler(unittest.TestCase):
|
|||
self.assertTrue(handler.dry_run)
|
||||
|
||||
# as import kwarg
|
||||
handler = handlers.ImportHandler()
|
||||
handler = self.make_handler()
|
||||
with patch.object(handler, 'commit_transaction') as commit:
|
||||
with patch.object(handler, 'rollback_transaction') as rollback:
|
||||
handler.import_data(dry_run=True)
|
||||
|
@ -187,11 +185,10 @@ class TestImportHandler(unittest.TestCase):
|
|||
self.assertTrue(handler.dry_run)
|
||||
|
||||
def test_import_data_invalid_model(self):
|
||||
handler = self.make_handler()
|
||||
importer = Mock()
|
||||
importer.import_data.return_value = [], [], []
|
||||
FooImporter = Mock(return_value=importer)
|
||||
|
||||
handler = handlers.ImportHandler()
|
||||
handler.importers = {'Foo': FooImporter}
|
||||
|
||||
handler.import_data('Foo')
|
||||
|
@ -206,10 +203,9 @@ class TestImportHandler(unittest.TestCase):
|
|||
self.assertFalse(importer.called)
|
||||
|
||||
def test_import_data_with_changes(self):
|
||||
handler = self.make_handler()
|
||||
importer = Mock()
|
||||
FooImporter = Mock(return_value=importer)
|
||||
|
||||
handler = handlers.ImportHandler()
|
||||
handler.importers = {'Foo': FooImporter}
|
||||
|
||||
importer.import_data.return_value = [], [], []
|
||||
|
@ -223,11 +219,10 @@ class TestImportHandler(unittest.TestCase):
|
|||
process.assert_called_once_with({'Foo': ([1], [2], [3])})
|
||||
|
||||
def test_import_data_commit_host_partial(self):
|
||||
handler = self.make_handler()
|
||||
importer = Mock()
|
||||
importer.import_data.side_effect = ValueError
|
||||
FooImporter = Mock(return_value=importer)
|
||||
|
||||
handler = handlers.ImportHandler()
|
||||
handler.importers = {'Foo': FooImporter}
|
||||
|
||||
handler.commit_host_partial = False
|
||||
|
@ -240,6 +235,47 @@ class TestImportHandler(unittest.TestCase):
|
|||
self.assertRaises(ValueError, handler.import_data, 'Foo')
|
||||
commit.assert_called_once_with()
|
||||
|
||||
|
||||
class BulkImportHandlerBattery(ImportHandlerBattery):
|
||||
|
||||
def test_import_data_invalid_model(self):
|
||||
handler = self.make_handler()
|
||||
importer = Mock()
|
||||
importer.import_data.return_value = 0
|
||||
FooImporter = Mock(return_value=importer)
|
||||
handler.importers = {'Foo': FooImporter}
|
||||
|
||||
handler.import_data('Foo')
|
||||
self.assertEqual(FooImporter.call_count, 1)
|
||||
importer.import_data.assert_called_once_with()
|
||||
|
||||
FooImporter.reset_mock()
|
||||
importer.reset_mock()
|
||||
|
||||
handler.import_data('Missing')
|
||||
self.assertFalse(FooImporter.called)
|
||||
self.assertFalse(importer.called)
|
||||
|
||||
def test_import_data_with_changes(self):
|
||||
handler = self.make_handler()
|
||||
importer = Mock()
|
||||
FooImporter = Mock(return_value=importer)
|
||||
handler.importers = {'Foo': FooImporter}
|
||||
|
||||
importer.import_data.return_value = 0
|
||||
with patch.object(handler, 'process_changes') as process:
|
||||
handler.import_data('Foo')
|
||||
self.assertFalse(process.called)
|
||||
|
||||
importer.import_data.return_value = 3
|
||||
with patch.object(handler, 'process_changes') as process:
|
||||
handler.import_data('Foo')
|
||||
self.assertFalse(process.called)
|
||||
|
||||
|
||||
class TestImportHandler(unittest.TestCase, ImportHandlerBattery):
|
||||
handler_class = handlers.ImportHandler
|
||||
|
||||
@patch('rattail.importing.handlers.send_email')
|
||||
def test_process_changes_sends_email(self, send_email):
|
||||
handler = handlers.ImportHandler()
|
||||
|
@ -265,6 +301,10 @@ class TestImportHandler(unittest.TestCase):
|
|||
self.assertEqual(send_email.call_count, 1)
|
||||
|
||||
|
||||
class TestBulkImportHandler(unittest.TestCase, BulkImportHandlerBattery):
|
||||
handler_class = handlers.BulkImportHandler
|
||||
|
||||
|
||||
######################################################################
|
||||
# fake import handler, tested mostly for basic coverage
|
||||
######################################################################
|
||||
|
@ -566,7 +606,7 @@ class MockBulkImportHandler(handlers.BulkToPostgreSQLHandler):
|
|||
return Session()
|
||||
|
||||
|
||||
class TestBulkImportHandler(RattailTestCase, ImporterTester):
|
||||
class TestBulkImportHandlerOld(RattailTestCase, ImporterTester):
|
||||
|
||||
importer_class = MockBulkImporter
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import unicode_literals, absolute_import
|
|||
|
||||
from unittest import TestCase
|
||||
|
||||
from mock import Mock, patch
|
||||
from mock import Mock, patch, call
|
||||
|
||||
from rattail.db import model
|
||||
from rattail.db.util import QuerySequence
|
||||
|
@ -13,6 +13,49 @@ from rattail.tests import NullProgress, RattailTestCase
|
|||
from rattail.tests.importing import ImporterTester
|
||||
|
||||
|
||||
class ImporterBattery(ImporterTester):
|
||||
"""
|
||||
Battery of tests which can hopefully be ran for any non-bulk importer.
|
||||
"""
|
||||
|
||||
def test_import_data_empty(self):
|
||||
importer = self.make_importer()
|
||||
result = importer.import_data()
|
||||
self.assertEqual(result, {})
|
||||
|
||||
def test_import_data_dry_run(self):
|
||||
importer = self.make_importer()
|
||||
self.assertFalse(importer.dry_run)
|
||||
importer.import_data(dry_run=True)
|
||||
self.assertTrue(importer.dry_run)
|
||||
|
||||
def test_import_data_create(self):
|
||||
importer = self.make_importer()
|
||||
with patch.object(importer, 'get_key', lambda k: k):
|
||||
with patch.object(importer, 'create_object') as create:
|
||||
importer.import_data(host_data=[1, 2, 3])
|
||||
self.assertEqual(create.call_args_list, [
|
||||
call(1, 1), call(2, 2), call(3, 3)])
|
||||
|
||||
def test_import_data_max_create(self):
|
||||
importer = self.make_importer()
|
||||
with patch.object(importer, 'get_key', lambda k: k):
|
||||
with patch.object(importer, 'create_object') as create:
|
||||
importer.import_data(host_data=[1, 2, 3], max_create=1)
|
||||
self.assertEqual(create.call_args_list, [call(1, 1)])
|
||||
|
||||
|
||||
class BulkImporterBattery(ImporterBattery):
|
||||
"""
|
||||
Battery of tests which can hopefully be ran for any bulk importer.
|
||||
"""
|
||||
|
||||
def test_import_data_empty(self):
|
||||
importer = self.make_importer()
|
||||
result = importer.import_data()
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
|
||||
class TestImporter(TestCase):
|
||||
|
||||
def test_init(self):
|
||||
|
@ -164,6 +207,11 @@ class TestFromQuery(RattailTestCase):
|
|||
self.assertIsInstance(objects, QuerySequence)
|
||||
|
||||
|
||||
class TestBulkImporter(TestCase, BulkImporterBattery):
|
||||
importer_class = importers.BulkImporter
|
||||
|
||||
|
||||
|
||||
######################################################################
|
||||
# fake importer class, tested mostly for basic coverage
|
||||
######################################################################
|
||||
|
|
Loading…
Reference in a new issue