Add BulkImporter
and BulkImportHandler
base classes
This commit is contained in:
parent
10040a8c3b
commit
7a2ef35518
|
@ -26,9 +26,9 @@ Data Importing Framework
|
||||||
|
|
||||||
from __future__ import unicode_literals, absolute_import
|
from __future__ import unicode_literals, absolute_import
|
||||||
|
|
||||||
from .importers import Importer, FromQuery
|
from .importers import Importer, FromQuery, BulkImporter
|
||||||
from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy
|
from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy
|
||||||
from .postgresql import BulkToPostgreSQL
|
from .postgresql import BulkToPostgreSQL
|
||||||
from .handlers import ImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler, BulkToPostgreSQLHandler
|
from .handlers import ImportHandler, BulkImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler
|
||||||
from .rattail import FromRattailHandler, ToRattailHandler
|
from .rattail import FromRattailHandler, ToRattailHandler
|
||||||
from . import model
|
from . import model
|
||||||
|
|
|
@ -229,6 +229,55 @@ class ImportHandler(object):
|
||||||
log.info("warning email was sent for {} -> {} import".format(self.host_title, self.local_title))
|
log.info("warning email was sent for {} -> {} import".format(self.host_title, self.local_title))
|
||||||
|
|
||||||
|
|
||||||
|
class BulkImportHandler(ImportHandler):
|
||||||
|
"""
|
||||||
|
Base class for bulk import handlers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def import_data(self, *keys, **kwargs):
|
||||||
|
"""
|
||||||
|
Import all data for the given importer/model keys.
|
||||||
|
"""
|
||||||
|
# TODO: still need to refactor much of this so can share with parent class
|
||||||
|
self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True)
|
||||||
|
if 'dry_run' in kwargs:
|
||||||
|
self.dry_run = kwargs['dry_run']
|
||||||
|
self.progress = kwargs.pop('progress', getattr(self, 'progress', None))
|
||||||
|
self.warnings = kwargs.pop('warnings', False)
|
||||||
|
kwargs.update({'dry_run': self.dry_run,
|
||||||
|
'progress': self.progress})
|
||||||
|
self.setup()
|
||||||
|
self.begin_transaction()
|
||||||
|
changes = OrderedDict()
|
||||||
|
|
||||||
|
try:
|
||||||
|
for key in keys:
|
||||||
|
importer = self.get_importer(key, **kwargs)
|
||||||
|
if not importer:
|
||||||
|
log.warning("skipping unknown importer: {}".format(key))
|
||||||
|
continue
|
||||||
|
|
||||||
|
created = importer.import_data()
|
||||||
|
log.info("{} -> {}: added {}, updated 0, deleted 0 {} records".format(
|
||||||
|
self.host_title, self.local_title, created, key))
|
||||||
|
if created:
|
||||||
|
changes[key] = created
|
||||||
|
except:
|
||||||
|
if self.commit_host_partial and not self.dry_run:
|
||||||
|
log.warning("{host} -> {local}: committing partial transaction on host {host} (despite error)".format(
|
||||||
|
host=self.host_title, local=self.local_title))
|
||||||
|
self.commit_host_transaction()
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
if self.dry_run:
|
||||||
|
self.rollback_transaction()
|
||||||
|
else:
|
||||||
|
self.commit_transaction()
|
||||||
|
|
||||||
|
self.teardown()
|
||||||
|
return changes
|
||||||
|
|
||||||
|
|
||||||
class FromSQLAlchemyHandler(ImportHandler):
|
class FromSQLAlchemyHandler(ImportHandler):
|
||||||
"""
|
"""
|
||||||
Handler for imports for which the host data source is represented by a
|
Handler for imports for which the host data source is represented by a
|
||||||
|
@ -292,42 +341,7 @@ class ToSQLAlchemyHandler(ImportHandler):
|
||||||
self.session = None
|
self.session = None
|
||||||
|
|
||||||
|
|
||||||
class BulkToPostgreSQLHandler(ToSQLAlchemyHandler):
|
class BulkToPostgreSQLHandler(BulkImportHandler):
|
||||||
"""
|
"""
|
||||||
Handler for bulk imports which target PostgreSQL on the local side.
|
Handler for bulk imports which target PostgreSQL on the local side.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def import_data(self, *keys, **kwargs):
|
|
||||||
"""
|
|
||||||
Import all data for the given importer/model keys.
|
|
||||||
"""
|
|
||||||
# TODO: still need to refactor much of this so can share with parent class
|
|
||||||
self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True)
|
|
||||||
if 'dry_run' in kwargs:
|
|
||||||
self.dry_run = kwargs['dry_run']
|
|
||||||
self.progress = kwargs.pop('progress', getattr(self, 'progress', None))
|
|
||||||
self.warnings = kwargs.pop('warnings', False)
|
|
||||||
kwargs.update({'dry_run': self.dry_run,
|
|
||||||
'progress': self.progress})
|
|
||||||
self.setup()
|
|
||||||
self.begin_transaction()
|
|
||||||
changes = OrderedDict()
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
importer = self.get_importer(key, **kwargs)
|
|
||||||
if not importer:
|
|
||||||
log.warning("skipping unknown importer: {}".format(key))
|
|
||||||
continue
|
|
||||||
|
|
||||||
created = importer.import_data()
|
|
||||||
log.info("{} -> {}: added {}, updated 0, deleted 0 {} records".format(
|
|
||||||
self.host_title, self.local_title, created, key))
|
|
||||||
if created:
|
|
||||||
changes[key] = created
|
|
||||||
|
|
||||||
if self.dry_run:
|
|
||||||
self.rollback_transaction()
|
|
||||||
else:
|
|
||||||
self.commit_transaction()
|
|
||||||
self.teardown()
|
|
||||||
return changes
|
|
||||||
|
|
|
@ -456,3 +456,54 @@ class FromQuery(Importer):
|
||||||
Returns (raw) query results as a sequence.
|
Returns (raw) query results as a sequence.
|
||||||
"""
|
"""
|
||||||
return QuerySequence(self.query())
|
return QuerySequence(self.query())
|
||||||
|
|
||||||
|
|
||||||
|
class BulkImporter(Importer):
|
||||||
|
"""
|
||||||
|
Base class for bulk data importers which target PostgreSQL on the local side.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def import_data(self, host_data=None, now=None, **kwargs):
|
||||||
|
self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
|
||||||
|
if kwargs:
|
||||||
|
self._setup(**kwargs)
|
||||||
|
self.setup()
|
||||||
|
if host_data is None:
|
||||||
|
host_data = self.normalize_host_data()
|
||||||
|
created = self._import_create(host_data)
|
||||||
|
self.teardown()
|
||||||
|
return created
|
||||||
|
|
||||||
|
def _import_create(self, data):
|
||||||
|
count = len(data)
|
||||||
|
if not count:
|
||||||
|
return 0
|
||||||
|
created = count
|
||||||
|
|
||||||
|
prog = None
|
||||||
|
if self.progress:
|
||||||
|
prog = self.progress("Importing {} data".format(self.model_name), count)
|
||||||
|
|
||||||
|
for i, host_data in enumerate(data, 1):
|
||||||
|
|
||||||
|
key = self.get_key(host_data)
|
||||||
|
self.create_object(key, host_data)
|
||||||
|
if self.max_create and i >= self.max_create:
|
||||||
|
log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
|
||||||
|
created = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if prog:
|
||||||
|
prog.update(i)
|
||||||
|
if prog:
|
||||||
|
prog.destroy()
|
||||||
|
|
||||||
|
self.flush_create()
|
||||||
|
return created
|
||||||
|
|
||||||
|
def flush_create(self):
|
||||||
|
"""
|
||||||
|
Perform any final steps to "flush" the created data here. Note that
|
||||||
|
the importer's handler is still responsible for actually committing
|
||||||
|
changes to the local system, if applicable.
|
||||||
|
"""
|
||||||
|
|
|
@ -30,14 +30,14 @@ import os
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from rattail.importing.sqlalchemy import ToSQLAlchemy
|
from rattail.importing import BulkImporter, ToSQLAlchemy
|
||||||
from rattail.time import make_utc
|
from rattail.time import make_utc
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BulkToPostgreSQL(ToSQLAlchemy):
|
class BulkToPostgreSQL(BulkImporter, ToSQLAlchemy):
|
||||||
"""
|
"""
|
||||||
Base class for bulk data importers which target PostgreSQL on the local side.
|
Base class for bulk data importers which target PostgreSQL on the local side.
|
||||||
"""
|
"""
|
||||||
|
@ -55,44 +55,6 @@ class BulkToPostgreSQL(ToSQLAlchemy):
|
||||||
os.remove(self.data_path)
|
os.remove(self.data_path)
|
||||||
self.data_buffer = None
|
self.data_buffer = None
|
||||||
|
|
||||||
def import_data(self, host_data=None, now=None, **kwargs):
|
|
||||||
self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
|
|
||||||
if kwargs:
|
|
||||||
self._setup(**kwargs)
|
|
||||||
self.setup()
|
|
||||||
if host_data is None:
|
|
||||||
host_data = self.normalize_host_data()
|
|
||||||
created = self._import_create(host_data)
|
|
||||||
self.teardown()
|
|
||||||
return created
|
|
||||||
|
|
||||||
def _import_create(self, data):
|
|
||||||
count = len(data)
|
|
||||||
if not count:
|
|
||||||
return 0
|
|
||||||
created = count
|
|
||||||
|
|
||||||
prog = None
|
|
||||||
if self.progress:
|
|
||||||
prog = self.progress("Importing {} data".format(self.model_name), count)
|
|
||||||
|
|
||||||
for i, host_data in enumerate(data, 1):
|
|
||||||
|
|
||||||
key = self.get_key(host_data)
|
|
||||||
self.create_object(key, host_data)
|
|
||||||
if self.max_create and i >= self.max_create:
|
|
||||||
log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
|
|
||||||
created = i
|
|
||||||
break
|
|
||||||
|
|
||||||
if prog:
|
|
||||||
prog.update(i)
|
|
||||||
if prog:
|
|
||||||
prog.destroy()
|
|
||||||
|
|
||||||
self.commit_create()
|
|
||||||
return created
|
|
||||||
|
|
||||||
def create_object(self, key, data):
|
def create_object(self, key, data):
|
||||||
data = self.prep_data_for_postgres(data)
|
data = self.prep_data_for_postgres(data)
|
||||||
self.data_buffer.write('{}\n'.format('\t'.join([data[field] for field in self.fields])).encode('utf-8'))
|
self.data_buffer.write('{}\n'.format('\t'.join([data[field] for field in self.fields])).encode('utf-8'))
|
||||||
|
@ -121,7 +83,7 @@ class BulkToPostgreSQL(ToSQLAlchemy):
|
||||||
|
|
||||||
return unicode(value)
|
return unicode(value)
|
||||||
|
|
||||||
def commit_create(self):
|
def flush_create(self):
|
||||||
log.info("copying {} data from buffer to PostgreSQL".format(self.model_name))
|
log.info("copying {} data from buffer to PostgreSQL".format(self.model_name))
|
||||||
self.data_buffer.close()
|
self.data_buffer.close()
|
||||||
self.data_buffer = open(self.data_path, 'rb')
|
self.data_buffer = open(self.data_path, 'rb')
|
||||||
|
|
|
@ -31,7 +31,7 @@ from rattail.util import OrderedDict
|
||||||
from rattail.importing.rattail import FromRattailToRattail, FromRattail
|
from rattail.importing.rattail import FromRattailToRattail, FromRattail
|
||||||
|
|
||||||
|
|
||||||
class BulkFromRattailToRattail(FromRattailToRattail, importing.BulkToPostgreSQLHandler):
|
class BulkFromRattailToRattail(FromRattailToRattail, importing.BulkImportHandler):
|
||||||
"""
|
"""
|
||||||
Handler for Rattail -> Rattail bulk data import.
|
Handler for Rattail -> Rattail bulk data import.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -19,12 +19,12 @@ from rattail.tests.importing.test_importers import MockImporter
|
||||||
from rattail.tests.importing.test_postgresql import MockBulkImporter
|
from rattail.tests.importing.test_postgresql import MockBulkImporter
|
||||||
|
|
||||||
|
|
||||||
class TestImportHandler(unittest.TestCase):
|
class ImportHandlerBattery(ImporterTester):
|
||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
|
|
||||||
# vanilla
|
# vanilla
|
||||||
handler = handlers.ImportHandler()
|
handler = self.handler_class()
|
||||||
self.assertEqual(handler.importers, {})
|
self.assertEqual(handler.importers, {})
|
||||||
self.assertEqual(handler.get_importers(), {})
|
self.assertEqual(handler.get_importers(), {})
|
||||||
self.assertEqual(handler.get_importer_keys(), [])
|
self.assertEqual(handler.get_importer_keys(), [])
|
||||||
|
@ -32,34 +32,34 @@ class TestImportHandler(unittest.TestCase):
|
||||||
self.assertFalse(handler.commit_host_partial)
|
self.assertFalse(handler.commit_host_partial)
|
||||||
|
|
||||||
# with config
|
# with config
|
||||||
handler = handlers.ImportHandler()
|
handler = self.handler_class()
|
||||||
self.assertIsNone(handler.config)
|
self.assertIsNone(handler.config)
|
||||||
config = RattailConfig()
|
config = RattailConfig()
|
||||||
handler = handlers.ImportHandler(config=config)
|
handler = self.handler_class(config=config)
|
||||||
self.assertIs(handler.config, config)
|
self.assertIs(handler.config, config)
|
||||||
|
|
||||||
# dry run
|
# dry run
|
||||||
handler = handlers.ImportHandler()
|
handler = self.handler_class()
|
||||||
self.assertFalse(handler.dry_run)
|
self.assertFalse(handler.dry_run)
|
||||||
handler = handlers.ImportHandler(dry_run=True)
|
handler = self.handler_class(dry_run=True)
|
||||||
self.assertTrue(handler.dry_run)
|
self.assertTrue(handler.dry_run)
|
||||||
|
|
||||||
# extra kwarg
|
# extra kwarg
|
||||||
handler = handlers.ImportHandler()
|
handler = self.handler_class()
|
||||||
self.assertRaises(AttributeError, getattr, handler, 'foo')
|
self.assertRaises(AttributeError, getattr, handler, 'foo')
|
||||||
handler = handlers.ImportHandler(foo='bar')
|
handler = self.handler_class(foo='bar')
|
||||||
self.assertEqual(handler.foo, 'bar')
|
self.assertEqual(handler.foo, 'bar')
|
||||||
|
|
||||||
def test_get_importer(self):
|
def test_get_importer(self):
|
||||||
get_importers = Mock(return_value={'foo': Importer})
|
get_importers = Mock(return_value={'foo': Importer})
|
||||||
|
|
||||||
# no importers
|
# no importers
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
self.assertIsNone(handler.get_importer('foo'))
|
self.assertIsNone(handler.get_importer('foo'))
|
||||||
|
|
||||||
# no config
|
# no config
|
||||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.handler_class()
|
||||||
importer = handler.get_importer('foo')
|
importer = handler.get_importer('foo')
|
||||||
self.assertIs(type(importer), Importer)
|
self.assertIs(type(importer), Importer)
|
||||||
self.assertIsNone(importer.config)
|
self.assertIsNone(importer.config)
|
||||||
|
@ -67,26 +67,26 @@ class TestImportHandler(unittest.TestCase):
|
||||||
|
|
||||||
# with config
|
# with config
|
||||||
config = RattailConfig()
|
config = RattailConfig()
|
||||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||||
handler = handlers.ImportHandler(config=config)
|
handler = self.handler_class(config=config)
|
||||||
importer = handler.get_importer('foo')
|
importer = handler.get_importer('foo')
|
||||||
self.assertIs(type(importer), Importer)
|
self.assertIs(type(importer), Importer)
|
||||||
self.assertIs(importer.config, config)
|
self.assertIs(importer.config, config)
|
||||||
self.assertIs(importer.handler, handler)
|
self.assertIs(importer.handler, handler)
|
||||||
|
|
||||||
# dry run
|
# dry run
|
||||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.handler_class()
|
||||||
importer = handler.get_importer('foo')
|
importer = handler.get_importer('foo')
|
||||||
self.assertFalse(importer.dry_run)
|
self.assertFalse(importer.dry_run)
|
||||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||||
handler = handlers.ImportHandler(dry_run=True)
|
handler = self.handler_class(dry_run=True)
|
||||||
importer = handler.get_importer('foo')
|
importer = handler.get_importer('foo')
|
||||||
self.assertTrue(handler.dry_run)
|
self.assertTrue(handler.dry_run)
|
||||||
|
|
||||||
# host title
|
# host title
|
||||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.handler_class()
|
||||||
importer = handler.get_importer('foo')
|
importer = handler.get_importer('foo')
|
||||||
self.assertIsNone(importer.host_system_title)
|
self.assertIsNone(importer.host_system_title)
|
||||||
handler.host_title = "Foo"
|
handler.host_title = "Foo"
|
||||||
|
@ -94,8 +94,8 @@ class TestImportHandler(unittest.TestCase):
|
||||||
self.assertEqual(importer.host_system_title, "Foo")
|
self.assertEqual(importer.host_system_title, "Foo")
|
||||||
|
|
||||||
# extra kwarg
|
# extra kwarg
|
||||||
with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
|
with patch.object(self.handler_class, 'get_importers', get_importers):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.handler_class()
|
||||||
importer = handler.get_importer('foo')
|
importer = handler.get_importer('foo')
|
||||||
self.assertRaises(AttributeError, getattr, importer, 'bar')
|
self.assertRaises(AttributeError, getattr, importer, 'bar')
|
||||||
importer = handler.get_importer('foo', bar='baz')
|
importer = handler.get_importer('foo', bar='baz')
|
||||||
|
@ -104,15 +104,15 @@ class TestImportHandler(unittest.TestCase):
|
||||||
def test_get_importer_kwargs(self):
|
def test_get_importer_kwargs(self):
|
||||||
|
|
||||||
# empty by default
|
# empty by default
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
self.assertEqual(handler.get_importer_kwargs('foo'), {})
|
self.assertEqual(handler.get_importer_kwargs('foo'), {})
|
||||||
|
|
||||||
# extra kwargs are preserved
|
# extra kwargs are preserved
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
self.assertEqual(handler.get_importer_kwargs('foo', bar='baz'), {'bar': 'baz'})
|
self.assertEqual(handler.get_importer_kwargs('foo', bar='baz'), {'bar': 'baz'})
|
||||||
|
|
||||||
def test_begin_transaction(self):
|
def test_begin_transaction(self):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
with patch.object(handler, 'begin_host_transaction') as begin_host:
|
with patch.object(handler, 'begin_host_transaction') as begin_host:
|
||||||
with patch.object(handler, 'begin_local_transaction') as begin_local:
|
with patch.object(handler, 'begin_local_transaction') as begin_local:
|
||||||
handler.begin_transaction()
|
handler.begin_transaction()
|
||||||
|
@ -120,15 +120,15 @@ class TestImportHandler(unittest.TestCase):
|
||||||
begin_local.assert_called_once_with()
|
begin_local.assert_called_once_with()
|
||||||
|
|
||||||
def test_begin_host_transaction(self):
|
def test_begin_host_transaction(self):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
handler.begin_host_transaction()
|
handler.begin_host_transaction()
|
||||||
|
|
||||||
def test_begin_local_transaction(self):
|
def test_begin_local_transaction(self):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
handler.begin_local_transaction()
|
handler.begin_local_transaction()
|
||||||
|
|
||||||
def test_commit_transaction(self):
|
def test_commit_transaction(self):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
with patch.object(handler, 'commit_host_transaction') as commit_host:
|
with patch.object(handler, 'commit_host_transaction') as commit_host:
|
||||||
with patch.object(handler, 'commit_local_transaction') as commit_local:
|
with patch.object(handler, 'commit_local_transaction') as commit_local:
|
||||||
handler.commit_transaction()
|
handler.commit_transaction()
|
||||||
|
@ -136,15 +136,15 @@ class TestImportHandler(unittest.TestCase):
|
||||||
commit_local.assert_called_once_with()
|
commit_local.assert_called_once_with()
|
||||||
|
|
||||||
def test_commit_host_transaction(self):
|
def test_commit_host_transaction(self):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
handler.commit_host_transaction()
|
handler.commit_host_transaction()
|
||||||
|
|
||||||
def test_commit_local_transaction(self):
|
def test_commit_local_transaction(self):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
handler.commit_local_transaction()
|
handler.commit_local_transaction()
|
||||||
|
|
||||||
def test_rollback_transaction(self):
|
def test_rollback_transaction(self):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
with patch.object(handler, 'rollback_host_transaction') as rollback_host:
|
with patch.object(handler, 'rollback_host_transaction') as rollback_host:
|
||||||
with patch.object(handler, 'rollback_local_transaction') as rollback_local:
|
with patch.object(handler, 'rollback_local_transaction') as rollback_local:
|
||||||
handler.rollback_transaction()
|
handler.rollback_transaction()
|
||||||
|
@ -152,24 +152,22 @@ class TestImportHandler(unittest.TestCase):
|
||||||
rollback_local.assert_called_once_with()
|
rollback_local.assert_called_once_with()
|
||||||
|
|
||||||
def test_rollback_host_transaction(self):
|
def test_rollback_host_transaction(self):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
handler.rollback_host_transaction()
|
handler.rollback_host_transaction()
|
||||||
|
|
||||||
def test_rollback_local_transaction(self):
|
def test_rollback_local_transaction(self):
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
handler.rollback_local_transaction()
|
handler.rollback_local_transaction()
|
||||||
|
|
||||||
def test_import_data(self):
|
def test_import_data(self):
|
||||||
|
handler = self.make_handler()
|
||||||
# normal
|
|
||||||
handler = handlers.ImportHandler()
|
|
||||||
result = handler.import_data()
|
result = handler.import_data()
|
||||||
self.assertEqual(result, {})
|
self.assertEqual(result, {})
|
||||||
|
|
||||||
def test_import_data_dry_run(self):
|
def test_import_data_dry_run(self):
|
||||||
|
|
||||||
# as init kwarg
|
# as init kwarg
|
||||||
handler = handlers.ImportHandler(dry_run=True)
|
handler = self.make_handler(dry_run=True)
|
||||||
with patch.object(handler, 'commit_transaction') as commit:
|
with patch.object(handler, 'commit_transaction') as commit:
|
||||||
with patch.object(handler, 'rollback_transaction') as rollback:
|
with patch.object(handler, 'rollback_transaction') as rollback:
|
||||||
handler.import_data()
|
handler.import_data()
|
||||||
|
@ -178,7 +176,7 @@ class TestImportHandler(unittest.TestCase):
|
||||||
self.assertTrue(handler.dry_run)
|
self.assertTrue(handler.dry_run)
|
||||||
|
|
||||||
# as import kwarg
|
# as import kwarg
|
||||||
handler = handlers.ImportHandler()
|
handler = self.make_handler()
|
||||||
with patch.object(handler, 'commit_transaction') as commit:
|
with patch.object(handler, 'commit_transaction') as commit:
|
||||||
with patch.object(handler, 'rollback_transaction') as rollback:
|
with patch.object(handler, 'rollback_transaction') as rollback:
|
||||||
handler.import_data(dry_run=True)
|
handler.import_data(dry_run=True)
|
||||||
|
@ -187,11 +185,10 @@ class TestImportHandler(unittest.TestCase):
|
||||||
self.assertTrue(handler.dry_run)
|
self.assertTrue(handler.dry_run)
|
||||||
|
|
||||||
def test_import_data_invalid_model(self):
|
def test_import_data_invalid_model(self):
|
||||||
|
handler = self.make_handler()
|
||||||
importer = Mock()
|
importer = Mock()
|
||||||
importer.import_data.return_value = [], [], []
|
importer.import_data.return_value = [], [], []
|
||||||
FooImporter = Mock(return_value=importer)
|
FooImporter = Mock(return_value=importer)
|
||||||
|
|
||||||
handler = handlers.ImportHandler()
|
|
||||||
handler.importers = {'Foo': FooImporter}
|
handler.importers = {'Foo': FooImporter}
|
||||||
|
|
||||||
handler.import_data('Foo')
|
handler.import_data('Foo')
|
||||||
|
@ -206,10 +203,9 @@ class TestImportHandler(unittest.TestCase):
|
||||||
self.assertFalse(importer.called)
|
self.assertFalse(importer.called)
|
||||||
|
|
||||||
def test_import_data_with_changes(self):
|
def test_import_data_with_changes(self):
|
||||||
|
handler = self.make_handler()
|
||||||
importer = Mock()
|
importer = Mock()
|
||||||
FooImporter = Mock(return_value=importer)
|
FooImporter = Mock(return_value=importer)
|
||||||
|
|
||||||
handler = handlers.ImportHandler()
|
|
||||||
handler.importers = {'Foo': FooImporter}
|
handler.importers = {'Foo': FooImporter}
|
||||||
|
|
||||||
importer.import_data.return_value = [], [], []
|
importer.import_data.return_value = [], [], []
|
||||||
|
@ -223,11 +219,10 @@ class TestImportHandler(unittest.TestCase):
|
||||||
process.assert_called_once_with({'Foo': ([1], [2], [3])})
|
process.assert_called_once_with({'Foo': ([1], [2], [3])})
|
||||||
|
|
||||||
def test_import_data_commit_host_partial(self):
|
def test_import_data_commit_host_partial(self):
|
||||||
|
handler = self.make_handler()
|
||||||
importer = Mock()
|
importer = Mock()
|
||||||
importer.import_data.side_effect = ValueError
|
importer.import_data.side_effect = ValueError
|
||||||
FooImporter = Mock(return_value=importer)
|
FooImporter = Mock(return_value=importer)
|
||||||
|
|
||||||
handler = handlers.ImportHandler()
|
|
||||||
handler.importers = {'Foo': FooImporter}
|
handler.importers = {'Foo': FooImporter}
|
||||||
|
|
||||||
handler.commit_host_partial = False
|
handler.commit_host_partial = False
|
||||||
|
@ -240,6 +235,47 @@ class TestImportHandler(unittest.TestCase):
|
||||||
self.assertRaises(ValueError, handler.import_data, 'Foo')
|
self.assertRaises(ValueError, handler.import_data, 'Foo')
|
||||||
commit.assert_called_once_with()
|
commit.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
|
class BulkImportHandlerBattery(ImportHandlerBattery):
|
||||||
|
|
||||||
|
def test_import_data_invalid_model(self):
|
||||||
|
handler = self.make_handler()
|
||||||
|
importer = Mock()
|
||||||
|
importer.import_data.return_value = 0
|
||||||
|
FooImporter = Mock(return_value=importer)
|
||||||
|
handler.importers = {'Foo': FooImporter}
|
||||||
|
|
||||||
|
handler.import_data('Foo')
|
||||||
|
self.assertEqual(FooImporter.call_count, 1)
|
||||||
|
importer.import_data.assert_called_once_with()
|
||||||
|
|
||||||
|
FooImporter.reset_mock()
|
||||||
|
importer.reset_mock()
|
||||||
|
|
||||||
|
handler.import_data('Missing')
|
||||||
|
self.assertFalse(FooImporter.called)
|
||||||
|
self.assertFalse(importer.called)
|
||||||
|
|
||||||
|
def test_import_data_with_changes(self):
|
||||||
|
handler = self.make_handler()
|
||||||
|
importer = Mock()
|
||||||
|
FooImporter = Mock(return_value=importer)
|
||||||
|
handler.importers = {'Foo': FooImporter}
|
||||||
|
|
||||||
|
importer.import_data.return_value = 0
|
||||||
|
with patch.object(handler, 'process_changes') as process:
|
||||||
|
handler.import_data('Foo')
|
||||||
|
self.assertFalse(process.called)
|
||||||
|
|
||||||
|
importer.import_data.return_value = 3
|
||||||
|
with patch.object(handler, 'process_changes') as process:
|
||||||
|
handler.import_data('Foo')
|
||||||
|
self.assertFalse(process.called)
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportHandler(unittest.TestCase, ImportHandlerBattery):
|
||||||
|
handler_class = handlers.ImportHandler
|
||||||
|
|
||||||
@patch('rattail.importing.handlers.send_email')
|
@patch('rattail.importing.handlers.send_email')
|
||||||
def test_process_changes_sends_email(self, send_email):
|
def test_process_changes_sends_email(self, send_email):
|
||||||
handler = handlers.ImportHandler()
|
handler = handlers.ImportHandler()
|
||||||
|
@ -265,6 +301,10 @@ class TestImportHandler(unittest.TestCase):
|
||||||
self.assertEqual(send_email.call_count, 1)
|
self.assertEqual(send_email.call_count, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBulkImportHandler(unittest.TestCase, BulkImportHandlerBattery):
|
||||||
|
handler_class = handlers.BulkImportHandler
|
||||||
|
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
# fake import handler, tested mostly for basic coverage
|
# fake import handler, tested mostly for basic coverage
|
||||||
######################################################################
|
######################################################################
|
||||||
|
@ -566,7 +606,7 @@ class MockBulkImportHandler(handlers.BulkToPostgreSQLHandler):
|
||||||
return Session()
|
return Session()
|
||||||
|
|
||||||
|
|
||||||
class TestBulkImportHandler(RattailTestCase, ImporterTester):
|
class TestBulkImportHandlerOld(RattailTestCase, ImporterTester):
|
||||||
|
|
||||||
importer_class = MockBulkImporter
|
importer_class = MockBulkImporter
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import unicode_literals, absolute_import
|
||||||
|
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
from mock import Mock, patch
|
from mock import Mock, patch, call
|
||||||
|
|
||||||
from rattail.db import model
|
from rattail.db import model
|
||||||
from rattail.db.util import QuerySequence
|
from rattail.db.util import QuerySequence
|
||||||
|
@ -13,6 +13,49 @@ from rattail.tests import NullProgress, RattailTestCase
|
||||||
from rattail.tests.importing import ImporterTester
|
from rattail.tests.importing import ImporterTester
|
||||||
|
|
||||||
|
|
||||||
|
class ImporterBattery(ImporterTester):
|
||||||
|
"""
|
||||||
|
Battery of tests which can hopefully be ran for any non-bulk importer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_import_data_empty(self):
|
||||||
|
importer = self.make_importer()
|
||||||
|
result = importer.import_data()
|
||||||
|
self.assertEqual(result, {})
|
||||||
|
|
||||||
|
def test_import_data_dry_run(self):
|
||||||
|
importer = self.make_importer()
|
||||||
|
self.assertFalse(importer.dry_run)
|
||||||
|
importer.import_data(dry_run=True)
|
||||||
|
self.assertTrue(importer.dry_run)
|
||||||
|
|
||||||
|
def test_import_data_create(self):
|
||||||
|
importer = self.make_importer()
|
||||||
|
with patch.object(importer, 'get_key', lambda k: k):
|
||||||
|
with patch.object(importer, 'create_object') as create:
|
||||||
|
importer.import_data(host_data=[1, 2, 3])
|
||||||
|
self.assertEqual(create.call_args_list, [
|
||||||
|
call(1, 1), call(2, 2), call(3, 3)])
|
||||||
|
|
||||||
|
def test_import_data_max_create(self):
|
||||||
|
importer = self.make_importer()
|
||||||
|
with patch.object(importer, 'get_key', lambda k: k):
|
||||||
|
with patch.object(importer, 'create_object') as create:
|
||||||
|
importer.import_data(host_data=[1, 2, 3], max_create=1)
|
||||||
|
self.assertEqual(create.call_args_list, [call(1, 1)])
|
||||||
|
|
||||||
|
|
||||||
|
class BulkImporterBattery(ImporterBattery):
|
||||||
|
"""
|
||||||
|
Battery of tests which can hopefully be ran for any bulk importer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_import_data_empty(self):
|
||||||
|
importer = self.make_importer()
|
||||||
|
result = importer.import_data()
|
||||||
|
self.assertEqual(result, 0)
|
||||||
|
|
||||||
|
|
||||||
class TestImporter(TestCase):
|
class TestImporter(TestCase):
|
||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
|
@ -164,6 +207,11 @@ class TestFromQuery(RattailTestCase):
|
||||||
self.assertIsInstance(objects, QuerySequence)
|
self.assertIsInstance(objects, QuerySequence)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBulkImporter(TestCase, BulkImporterBattery):
|
||||||
|
importer_class = importers.BulkImporter
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
# fake importer class, tested mostly for basic coverage
|
# fake importer class, tested mostly for basic coverage
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
Loading…
Reference in a new issue