Remove all references to old importer frameworks

We now have a winner; these served their purpose but are no longer wanted.
Lance Edgar 2017-05-25 14:31:00 -05:00
parent bb8193ad48
commit 3b9752a1b2
16 changed files with 6 additions and 5758 deletions


@@ -1,8 +1,8 @@
-# -*- coding: utf-8 -*-
+# -*- coding: utf-8; -*-
################################################################################
#
# Rattail -- Retail Software Framework
-# Copyright © 2010-2016 Lance Edgar
+# Copyright © 2010-2017 Lance Edgar
#
# This file is part of Rattail.
#
@@ -26,5 +26,5 @@ Console Commands
from __future__ import unicode_literals, absolute_import
-from .core import main, Command, Subcommand, OldImportSubcommand, NewImportSubcommand, Dump, date_argument
+from .core import main, Command, Subcommand, Dump, date_argument
from .importing import ImportSubcommand
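
With the old base classes gone, new import commands build on the surviving ImportSubcommand instead. A minimal sketch follows; the rattail_foo module and FromFooToRattail handler are hypothetical, and the get_handler_factory() override name is assumed to carry over from the removed classes:

from __future__ import unicode_literals, absolute_import

from rattail.commands.importing import ImportSubcommand


class ImportFoo(ImportSubcommand):
    """
    Import data from the (hypothetical) Foo system
    """
    name = 'import-foo'
    description = __doc__.strip()

    def get_handler_factory(self, **kwargs):
        # hypothetical handler class for the Foo -> Rattail direction
        from rattail_foo.importing import FromFooToRattail
        return FromFooToRattail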


@@ -749,251 +749,6 @@ class FileMonitorCommand(Subcommand):
service.delayed_auto_start_service(name)
class OldImportSubcommand(Subcommand):
"""
Base class for subcommands which use the data importing system.
"""
supports_versioning = True
def add_parser_args(self, parser):
handler = self.get_handler(quiet=True)
if self.supports_versioning:
parser.add_argument('--no-versioning', action='store_true',
help="Disables versioning during the import. This is "
"intended to be useful e.g. during initial import, where "
"the process can be quite slow even without the overhead "
"of versioning.")
parser.add_argument('--warnings', '-W', action='store_true',
help="Whether to log warnings if any data model "
"writes occur. Intended to help stay in sync "
"with an external data source.")
parser.add_argument('--max-updates', type=int,
help="Maximum number of record updates (or additions) which, if "
"reached, should cause the importer to stop early. Note that the "
"updates which have completed will be committed unless a dry run "
"is in effect.")
parser.add_argument('--dry-run', action='store_true',
help="Go through the motions and allow logging to occur, "
"but do not actually commit the transaction at the end.")
parser.add_argument('models', nargs='*', metavar='MODEL',
help="Which models to import. If none are specified, all models will "
"be imported. Or, specify only those you wish to import. Supported "
"models are: {0}".format(', '.join(handler.get_importer_keys())))
def run(self, args):
log.info("begin {0} for data model(s): {1}".format(
self.name, ', '.join(args.models or ["ALL"])))
Session = self.parent.db_session_factory
if self.supports_versioning:
if args.no_versioning:
from rattail.db.continuum import disable_versioning
disable_versioning()
session = Session(continuum_user=self.continuum_user)
else:
session = Session()
self.import_data(args, session)
if args.dry_run:
session.rollback()
log.info("dry run, so transaction was rolled back")
else:
session.commit()
log.info("transaction was committed")
session.close()
def get_handler_factory(self, quiet=False):
"""
This method must return a factory, which will in turn generate a
handler instance to be used by the command. Note that you *must*
override this method.
"""
raise NotImplementedError
def get_handler(self, **kwargs):
"""
Returns a handler instance to be used by the command.
"""
factory = self.get_handler_factory(quiet=kwargs.pop('quiet', False))
return factory(getattr(self, 'config', None), **kwargs)
@property
def continuum_user(self):
"""
Info needed to assign the Continuum user for the database session.
"""
def import_data(self, args, session):
"""
Perform a data import, with the given arguments and database session.
"""
handler = self.get_handler(session=session)
models = args.models or handler.get_importer_keys()
updates = handler.import_data(models, max_updates=args.max_updates,
progress=self.progress)
if args.warnings and updates:
handler.process_warnings(updates, command=self, models=models, dry_run=args.dry_run,
render_record=self.get_record_renderer(),
progress=self.progress)
def get_record_renderer(self):
"""
Get the record renderer for email notifications. Note that config may
override the default.
"""
spec = self.config.get('{0}.{1}'.format(self.parent.name, self.name), 'record_renderer',
default='rattail.db.importing:RecordRenderer')
return load_object(spec)(self.config)
class NewImportSubcommand(Subcommand):
"""
Base class for subcommands which use the (new) data importing system.
"""
def get_handler_factory(self, args=None):
"""
This method must return a factory, which will in turn generate a
handler instance to be used by the command. Note that you *must*
override this method.
"""
raise NotImplementedError
def get_handler(self, args=None, **kwargs):
"""
Returns a handler instance to be used by the command.
"""
factory = self.get_handler_factory(args)
kwargs = self.get_handler_kwargs(args, **kwargs)
kwargs['command'] = self
return factory(getattr(self, 'config', None), **kwargs)
def get_handler_kwargs(self, args, **kwargs):
"""
Return a dict of kwargs to be passed to the handler factory.
"""
return kwargs
def add_parser_args(self, parser):
handler = self.get_handler()
# model names (aka importer keys)
parser.add_argument('models', nargs='*', metavar='MODEL',
help="Which data models to import. If you specify any, then only data "
"for those models will be imported. If you do not specify any, then all "
"*default* models will be imported. Supported models are: ({})".format(
', '.join(handler.get_importer_keys())))
# start/end date
parser.add_argument('--start-date', type=date_argument,
help="Optional (inclusive) starting point for date range, by which host "
"data should be filtered. Only used by certain importers.")
parser.add_argument('--end-date', type=date_argument,
help="Optional (inclusive) ending point for date range, by which host "
"data should be filtered. Only used by certain importers.")
# allow create?
parser.add_argument('--create', action='store_true', default=True,
help="Allow new records to be created during the import.")
parser.add_argument('--no-create', action='store_false', dest='create',
help="Do not allow new records to be created during the import.")
parser.add_argument('--max-create', type=int, metavar='COUNT',
help="Maximum number of records which may be created, after which a "
"given import task should stop. Note that this applies on a per-model "
"basis and not overall.")
# allow update?
parser.add_argument('--update', action='store_true', default=True,
help="Allow existing records to be updated during the import.")
parser.add_argument('--no-update', action='store_false', dest='update',
help="Do not allow existing records to be updated during the import.")
parser.add_argument('--max-update', type=int, metavar='COUNT',
help="Maximum number of records which may be updated, after which a "
"given import task should stop. Note that this applies on a per-model "
"basis and not overall.")
# allow delete?
parser.add_argument('--delete', action='store_true', default=False,
help="Allow records to be deleted during the import.")
parser.add_argument('--no-delete', action='store_false', dest='delete',
help="Do not allow records to be deleted during the import.")
parser.add_argument('--max-delete', type=int, metavar='COUNT',
help="Maximum number of records which may be deleted, after which a "
"given import task should stop. Note that this applies on a per-model "
"basis and not overall.")
# max total changes, per model
parser.add_argument('--max-total', type=int, metavar='COUNT',
help="Maximum number of *any* record changes which may occur, after which "
"a given import task should stop. Note that this applies on a per-model "
"basis and not overall.")
# treat changes as warnings?
parser.add_argument('--warnings', '-W', action='store_true',
help="Set this flag if you expect a \"clean\" import, and wish for any "
"changes which do occur to be processed further and/or specially. The "
"behavior of this flag is ultimately up to the import handler, but the "
"default is to send an email notification.")
# dry run?
parser.add_argument('--dry-run', action='store_true',
help="Go through the full motions and allow logging etc. to "
"occur, but rollback (abort) the transaction at the end.")
def run(self, args):
log.info("begin `{} {}` for data models: {}".format(
self.parent.name, self.name, ', '.join(args.models or ["(ALL)"])))
Session = self.parent.db_session_factory
session = Session()
self.import_data(args, session)
if args.dry_run:
session.rollback()
log.info("dry run, so transaction was rolled back")
else:
session.commit()
log.info("transaction was committed")
session.close()
def import_data(self, args, session):
"""
Perform a data import, with the given arguments and database session.
"""
handler = self.get_handler(args=args, session=session, progress=self.progress)
models = args.models or handler.get_default_keys()
log.debug("using handler: {}".format(handler))
log.debug("importing models: {}".format(models))
log.debug("args are: {}".format(args))
handler.import_data(models, args)
class ImportCSV(OldImportSubcommand):
"""
Import data from a CSV file
"""
name = 'import-csv'
description = __doc__.strip()
def add_parser_args(self, parser):
super(ImportCSV, self).add_parser_args(parser)
parser.add_argument('importer',
help="Spec string for importer class which should handle the import.")
parser.add_argument('csv_path',
help="Path to the data file which will be imported.")
def import_data(self, args, session):
from rattail.db.importing.providers.csv import make_provider
provider = make_provider(self.config, session, args.importer, data_path=args.csv_path)
data = provider.get_data(progress=self.progress)
affected = provider.importer.import_data(data, provider.supported_fields, 'uuid',
progress=self.progress)
log.info("added or updated {0} {1} records".format(affected, provider.model_name))
class LoadHostDataCommand(Subcommand):
"""
Loads data from the Rattail host database, if one is configured.


@@ -1,8 +1,8 @@
-# -*- coding: utf-8 -*-
+# -*- coding: utf-8; -*-
################################################################################
#
# Rattail -- Retail Software Framework
-# Copyright © 2010-2016 Lance Edgar
+# Copyright © 2010-2017 Lance Edgar
#
# This file is part of Rattail.
#
@@ -27,4 +27,4 @@ DataSync Daemon
from __future__ import unicode_literals, absolute_import
from .watchers import DataSyncWatcher
-from .consumers import DataSyncConsumer, DataSyncImportConsumer, NewDataSyncImportConsumer
+from .consumers import DataSyncConsumer, NewDataSyncImportConsumer


@@ -28,152 +28,9 @@ from __future__ import unicode_literals, absolute_import
from rattail import importing
from rattail.config import parse_list
from rattail.db.newimporting import ImportHandler
from rattail.util import load_object
class DataSyncConsumer(object):
"""
Base class for all DataSync consumers.
"""
def __init__(self, config, key, dbkey=None):
self.config = config
self.key = key
self.dbkey = dbkey
def setup(self):
"""
This method is called when the consumer thread is first started.
"""
def begin_transaction(self):
"""
Called just before the consumer is asked to process changes, possibly
via multiple batches.
"""
def process_changes(self, session, changes):
"""
Process (consume) a batch of changes.
"""
def rollback_transaction(self):
"""
Called when any batch of changes failed to process.
"""
def commit_transaction(self):
"""
Called just after the consumer has successfully finished processing
changes, possibly via multiple batches.
"""
class DataSyncImportConsumer(DataSyncConsumer):
"""
Base class for DataSync consumer which is able to leverage a (set of)
importer(s) to do the heavy lifting.
.. note::
This assumes "old-style" importers based on
``rattail.db.newimporting.Importer``.
"""
def __init__(self, *args, **kwargs):
super(DataSyncImportConsumer, self).__init__(*args, **kwargs)
self.importers = self.get_importers()
def get_importers(self):
"""
You must override this to return a dict of importer *instances*, keyed
by what you expect the corresponding ``DataSyncChange.payload_type`` to
be, coming from the "host" system, whatever that is.
"""
raise NotImplementedError
def get_importers_from_handler(self, handler, default_only=True):
if not isinstance(handler, ImportHandler):
handler = handler(config=self.config)
factories = handler.get_importers()
if default_only:
keys = handler.get_default_keys()
else:
keys = factories.keys()
importers = {}
for key in keys:
importers[key] = factories[key](config=self.config)
return importers
def process_changes(self, session, changes):
"""
Process all changes, leveraging importer(s) as much as possible.
"""
# Update all importers with current Rattail session.
for importer in self.importers.itervalues():
importer.session = session
for change in changes:
self.invoke_importer(session, change)
def invoke_importer(self, session, change):
"""
For the given change, invoke the default importer behavior, if one is
available.
"""
importer = self.importers.get(change.payload_type)
if importer:
if change.deletion:
self.process_deletion(session, importer, change)
else:
return self.process_change(session, importer, change)
def process_change(self, session, importer, change=None, host_object=None, host_data=None):
"""
Invoke the importer to process the given change / host record.
"""
if host_data is None:
if host_object is None:
host_object = self.get_host_record(session, change)
if host_object is None:
return
host_data = importer.normalize_source_record(host_object)
if host_data is None:
return
key = importer.get_key(host_data)
local_object = importer.get_instance(key)
if local_object:
local_data = importer.normalize_instance(local_object)
if importer.data_diffs(local_data, host_data):
local_object = importer.update_instance(local_object, host_data, local_data)
return local_object
else:
return importer.create_instance(key, host_data)
def process_deletion(self, session, importer, change):
"""
Attempt to invoke the importer, to delete a local record according to
the change involved.
"""
key = self.get_deletion_key(session, change)
local_object = importer.get_instance(key)
if local_object:
return importer.delete_instance(local_object)
return False
def get_deletion_key(self, session, change):
return (change.payload_key,)
def get_host_record(self, session, change):
"""
You must override this, to return a host record from the given
``DataSyncChange`` instance. Note that the host record need *not* be
normalized, as that will be done by the importer. (This is effectively
the only part of the processing which is not handled by the importer.)
"""
raise NotImplementedError
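
A concrete consumer then supplied only the importer set and the host-record lookup. A sketch; the FooImportHandler is hypothetical, and mapping payload_type directly to a rattail.db.model class is an assumption for illustration:

from rattail.datasync.consumers import DataSyncImportConsumer
from rattail.db import model


class FooConsumer(DataSyncImportConsumer):
    """
    Consume datasync changes via (hypothetical) Foo importers.
    """

    def get_importers(self):
        # importer instances, keyed by expected DataSyncChange.payload_type
        from rattail_foo.importing import FooImportHandler
        return self.get_importers_from_handler(FooImportHandler)

    def get_host_record(self, session, change):
        # return the raw host object; the importer normalizes it later
        cls = getattr(model, change.payload_type)
        return session.query(cls).get(change.payload_key)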
class NewDataSyncImportConsumer(DataSyncConsumer):
"""
Base class for DataSync consumer which is able to leverage a (set of)


@@ -1,30 +0,0 @@
# -*- coding: utf-8 -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2015 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Data Importing
"""
from .core import Importer, make_importer, RecordRenderer
from . import models
from .providers import DataProvider, QueryProvider
from .handlers import ImportHandler


@@ -1,406 +0,0 @@
# -*- coding: utf-8 -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2015 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Core Importer Stuff
"""
from __future__ import unicode_literals, absolute_import
import logging
from sqlalchemy.orm.exc import NoResultFound
from rattail.db import model
from rattail.core import Object
from rattail.util import load_object
from rattail.db.cache import cache_model
log = logging.getLogger(__name__)
def make_importer(config, session, spec):
"""
Create an importer instance according to the given spec. For now, see the
source code for more details.
"""
importer = None
if '.' not in spec and ':' not in spec:
from rattail.db.importing import models
if hasattr(models, spec):
importer = getattr(models, spec)
elif hasattr(models, '{0}Importer'.format(spec)):
importer = getattr(models, '{0}Importer'.format(spec))
else:
importer = load_object(spec)
if importer:
return importer(config, session)
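
Per the resolution rules above, a bare name is looked up in rattail.db.importing.models (with or without the Importer suffix), while any spec containing a dot or colon goes through load_object(). Usage sketch, assuming config and session are already in hand:

from rattail.db.importing import make_importer

# bare name: resolves to models.ProductImporter
importer = make_importer(config, session, 'Product')

# full spec string: resolved via load_object()
importer = make_importer(config, session, 'myapp.importing:CustomerImporter')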
class Importer(Object):
"""
Base class for model importers.
"""
supported_fields = []
normalizer_class = None
cached_data = None
complex_fields = []
"""
Sequence of field names which are considered complex and therefore require
custom logic provided by the derived class, etc.
"""
def __init__(self, config, session, **kwargs):
self.config = config
self.session = session
super(Importer, self).__init__(**kwargs)
@property
def model_module(self):
"""
Reference to a module which contains all available / necessary data
models. By default this is ``rattail.db.model``.
"""
return model
@property
def model_class(self):
return getattr(model, self.__class__.__name__[:-8])
@property
def model_name(self):
return self.model_class.__name__
@property
def simple_fields(self):
return self.supported_fields
def import_data(self, records, fields, key, count=None, max_updates=None, progress=None):
"""
Import some data.
"""
if count is None:
count = len(records)
if count == 0:
return [], []
self.fields = fields
self.key = key
if isinstance(key, basestring):
self.key = (key,)
self.progress = progress
self.normalizer = self.normalizer_class() if self.normalizer_class else None
self.setup()
self.cache_data(progress)
# Normalize to remove duplicate source records. This is more for the
# sake of sanity since duplicates typically lead to a ping-pong effect
# where an update-less import is impossible.
normalized = {}
for src_data in records:
key = self.get_key(src_data)
if key in normalized:
log.warning("duplicate records from {0}:{1} for key: {2}".format(
self.provider.__class__.__module__, self.provider.__class__.__name__, repr(key)))
normalized[key] = src_data
records = []
for key in sorted(normalized):
records.append(normalized[key])
prog = None
if progress:
prog = progress("Importing {0} data".format(self.model_name), count)
created = []
updated = []
affected = 0
keys_seen = set()
for i, src_data in enumerate(records, 1):
key = self.get_key(src_data)
if key in keys_seen:
log.warning("duplicate records from {0}:{1} for key: {2}".format(
self.provider.__class__.__module__, self.provider.__class__.__name__, repr(key)))
else:
keys_seen.add(key)
self.normalize_record(src_data)
dirty = False
inst_data = self.get_instance_data(src_data)
if inst_data:
if self.data_differs(inst_data, src_data):
instance = self.get_instance(src_data)
self.update_instance(instance, src_data, inst_data)
updated.append(instance)
dirty = True
else:
instance = self.new_instance(src_data)
assert instance, "Failed to create new model instance for data: {0}".format(repr(src_data))
self.update_instance(instance, src_data)
self.session.add(instance)
self.session.flush()
log.debug("created new {} {}: {}".format(self.model_name, key, instance))
created.append(instance)
dirty = True
if dirty:
self.session.flush()
affected += 1
if max_updates and affected >= max_updates:
log.warning("max of {0} updates has been reached; bailing early".format(max_updates))
break
if prog:
prog.update(i)
if prog:
prog.destroy()
return created, updated
def setup(self):
"""
Perform any setup necessary, e.g. cache lookups for existing data.
"""
def cache_query_options(self):
"""
Return a list of options to apply to the cache query, if needed.
"""
def cache_model(self, model_class, key, **kwargs):
"""
Convenience method for caching a model.
"""
kwargs.setdefault('progress', self.progress)
return cache_model(self.session, model_class, key=key, **kwargs)
def get_cache_key(self, instance, normalized):
"""
Get the primary model cache key for a given instance/data object.
"""
return tuple(normalized['data'].get(k) for k in self.key)
def cache_data(self, progress):
"""
Cache all existing model instances as normalized data.
"""
self.cached_data = self.cache_model(self.model_class, self.get_cache_key,
query_options=self.cache_query_options(),
normalizer=self.normalize_cache)
def normalize_cache(self, instance):
"""
Normalizer for cache data. This adds the instance to the cache in
addition to its normalized data. This is so that if lots of updates
are required, we don't have to constantly fetch them.
"""
return {'instance': instance, 'data': self.normalize_instance(instance)}
def data_differs(self, inst_data, src_data):
"""
Compare source record data to instance data to determine if there is a
net change.
"""
for field in self.fields:
if src_data[field] != inst_data[field]:
log.debug("field {0} differed for instance data: {1}, source data: {2}".format(
field, repr(inst_data), repr(src_data)))
return True
return False
def string_or_null(self, data, *fields):
"""
For each field specified, ensure the data value is a non-empty string,
or ``None``.
"""
for field in fields:
if field in data:
value = data[field]
value = value.strip() if value else None
data[field] = value or None
def int_or_null(self, data, *fields):
"""
For each field specified, ensure the data value is a non-zero integer,
or ``None``.
"""
for field in fields:
if field in data:
value = data[field]
value = int(value) if value else None
data[field] = value or None
def prioritize_2(self, data, field):
"""
Prioritize the data values for the pair of fields implied by the given
fieldname. I.e., if only one non-empty value is present, make sure
it's in the first slot.
"""
field2 = '{0}_2'.format(field)
if field in data and field2 in data:
if data[field2] and not data[field]:
data[field], data[field2] = data[field2], None
def normalize_record(self, data):
"""
Normalize the source data record, if necessary.
"""
def get_key(self, data):
"""
Return the key value for the given source data record.
"""
return tuple(data.get(k) for k in self.key)
def get_instance(self, data):
"""
Fetch an instance from our database which corresponds to the source
data, if possible; otherwise return ``None``.
"""
key = self.get_key(data)
if not key:
log.warning("source {0} has no {1}: {2}".format(
self.model_name, self.key, repr(data)))
return None
if self.cached_data is not None:
data = self.cached_data.get(key)
return data['instance'] if data else None
q = self.session.query(self.model_class)
for i, k in enumerate(self.key):
q = q.filter(getattr(self.model_class, k) == key[i])
try:
instance = q.one()
except NoResultFound:
return None
else:
return instance
def get_instance_data(self, data):
"""
Return a normalized data record for the model instance corresponding to
the source data record, or ``None``.
"""
key = self.get_key(data)
if not key:
log.warning("source {0} has no {1}: {2}".format(
self.model_name, self.key, repr(data)))
return None
if self.cached_data is not None:
data = self.cached_data.get(key)
return data['data'] if data else None
instance = self.get_instance(data)
if instance:
return self.normalize_instance(instance)
def normalize_instance(self, instance):
"""
Normalize a model instance.
"""
if self.normalizer:
return self.normalizer.normalize(instance)
data = {}
for field in self.simple_fields:
if field in self.fields:
data[field] = getattr(instance, field)
return data
def new_instance(self, data):
"""
Return a new model instance to correspond to the source data record.
"""
kwargs = {}
key = self.get_key(data)
for i, k in enumerate(self.key):
if k in self.simple_fields:
kwargs[k] = key[i]
return self.model_class(**kwargs)
def update_instance(self, instance, data, inst_data=None):
"""
Update the given model instance with the given data.
"""
for field in self.simple_fields:
if field in data:
if not inst_data or inst_data[field] != data[field]:
setattr(instance, field, data[field])
class RecordRenderer(object):
"""
Record renderer for email notifications sent from data import jobs.
"""
def __init__(self, config):
self.config = config
def __call__(self, record):
return self.render(record)
def render(self, record):
"""
Render the given record. Default is to attempt a type-specific render method, falling back to a (possibly hyperlinked) label.
"""
key = record.__class__.__name__.lower()
renderer = getattr(self, 'render_{0}'.format(key), None)
if renderer:
return renderer(record)
label = self.get_label(record)
url = self.get_url(record)
if url:
return '<a href="{0}">{1}</a>'.format(url, label)
return label
def get_label(self, record):
key = record.__class__.__name__.lower()
label = getattr(self, 'label_{0}'.format(key), self.label)
return label(record)
def label(self, record):
return unicode(record)
def get_url(self, record):
"""
Fetch / generate a URL for the given data record. You should *not*
override this method, but override :meth:`url()` instead.
"""
key = record.__class__.__name__.lower()
url = getattr(self, 'url_{0}'.format(key), self.url)
return url(record)
def url(self, record):
"""
Fetch / generate a URL for the given data record.
"""
url = self.config.get('tailbone', 'url')
if url:
url = url.rstrip('/')
name = '{0}s'.format(record.__class__.__name__.lower())
if name == 'persons': # FIXME, obviously this is a hack
name = 'people'
url = '{0}/{1}/{{uuid}}'.format(url, name)
return url.format(uuid=record.uuid)
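
Since rendering dispatches on the lowercased class name, customizing output for one model means defining render_<name>, label_<name>, or url_<name> methods, then pointing the record_renderer config setting at your class. A sketch:

from rattail.db.importing import RecordRenderer


class MyRenderer(RecordRenderer):

    def label_product(self, product):
        # custom label text for Product records
        return '{0} {1}'.format(product.upc, product.description)

    def render_user(self, user):
        # fully custom markup for User records
        return '<strong>{0}</strong>'.format(user.username)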


@@ -1,135 +0,0 @@
# -*- coding: utf-8 -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2016 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Import Handlers
"""
from __future__ import unicode_literals, absolute_import
import logging
from rattail.util import OrderedDict
from rattail.mail import send_email
log = logging.getLogger(__name__)
class ImportHandler(object):
"""
Base class for all import handlers.
"""
def __init__(self, config=None, session=None):
self.config = config
self.session = session
self.importers = self.get_importers()
def get_importers(self):
"""
Returns a dict of all available importers, where the values are
importer factories. All subclasses will want to override this. Note
that if you return an ``OrderedDict`` instance, you can affect the
ordering of keys in the command line help system, etc.
"""
return {}
def get_importer_keys(self):
"""
Returns a list of keys corresponding to the available importers.
"""
return list(self.importers.iterkeys())
def get_importer(self, key):
"""
Returns an importer instance corresponding to the given key.
"""
return self.importers[key](self.config, self.session,
**self.get_importer_kwargs(key))
def get_importer_kwargs(self, key):
"""
Return a dict of kwargs to be used when constructing an importer with
the given key.
"""
return {}
def import_data(self, keys, max_updates=None, progress=None):
"""
Import all data for the given importer keys.
"""
self.before_import()
updates = OrderedDict()
for key in keys:
provider = self.get_importer(key)
if not provider:
log.warning("unknown importer; skipping: {0}".format(repr(key)))
continue
data = provider.get_data(progress=progress)
created, updated = provider.importer.import_data(
data, provider.supported_fields, provider.key,
max_updates=max_updates, progress=progress)
if hasattr(provider, 'process_deletions'):
deleted = provider.process_deletions(data, progress=progress)
else:
deleted = 0
log.info("added {0}, updated {1}, deleted {2} {3} records".format(
len(created), len(updated), deleted, key))
if created or updated or deleted:
updates[key] = created, updated, deleted
self.after_import()
return updates
def before_import(self):
return
def after_import(self):
return
def process_warnings(self, updates, command=None, **kwargs):
"""
If an import was run with "warnings" enabled, and work was effectively
done, then this method is called to process the updates. The assumption
is that a warning email will be sent with the details, but you can do
anything you like if you override this.
"""
data = kwargs
data['updates'] = updates
if command:
data['command'] = '{} {}'.format(command.parent.name, command.name)
else:
data['command'] = None
if command:
key = '{}_{}_updates'.format(command.parent.name, command.name)
key = key.replace('-', '_')
else:
key = 'rattail_import_updates'
send_email(self.config, key, fallback_key='rattail_import_updates', data=data)
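
Subclasses mostly just override get_importers(); returning an OrderedDict preserves key ordering in the command-line help, per the docstring above. A sketch using the ProductImporter referenced elsewhere in this commit:

from rattail.util import OrderedDict
from rattail.db.importing import ImportHandler, models


class FooImportHandler(ImportHandler):
    """
    Handles imports from the (hypothetical) Foo system.
    """

    def get_importers(self):
        importers = OrderedDict()
        importers['Product'] = models.ProductImporter
        return importers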

File diff suppressed because it is too large.


@@ -1,164 +0,0 @@
# -*- coding: utf-8 -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2016 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Rattail Data Normalization
"""
from __future__ import unicode_literals, absolute_import
class Normalizer(object):
"""
Base class for data normalizers.
"""
def normalize(self, instance):
raise NotImplementedError
class UserNormalizer(Normalizer):
"""
Normalizer for user data.
"""
# Must set this to the administrator Role instance.
admin = None
def normalize(self, user):
return {
'uuid': user.uuid,
'username': user.username,
'password': user.password,
'salt': user.salt,
'person_uuid': user.person_uuid,
'active': user.active,
'admin': self.admin in user.roles,
}
class DepartmentNormalizer(Normalizer):
"""
Normalizer for department data.
"""
def normalize(self, department):
return {
'uuid': department.uuid,
'number': department.number,
'name': department.name,
}
class EmployeeNormalizer(Normalizer):
"""
Normalizer for employee data.
"""
def normalize(self, employee):
person = employee.person
customer = person.customers[0] if person.customers else None
data = {
'uuid': employee.uuid,
'id': employee.id,
'person_uuid': person.uuid,
'customer_id': customer.id if customer else None,
'status': employee.status,
'first_name': person.first_name,
'last_name': person.last_name,
'display_name': employee.display_name,
'person_display_name': person.display_name,
}
data['phone_number'] = None
for phone in employee.phones:
if phone.type == 'Home':
data['phone_number'] = phone.number
break
data['phone_number_2'] = None
first = False
for phone in employee.phones:
if phone.type == 'Home':
if first:
data['phone_number_2'] = phone.number
break
first = True
email = employee.email
data['email_address'] = email.address if email else None
return data
class EmployeeStoreNormalizer(Normalizer):
"""
Normalizer for employee_x_store data.
"""
def normalize(self, emp_store):
return {
'uuid': emp_store.uuid,
'employee_uuid': emp_store.employee_uuid,
'store_uuid': emp_store.store_uuid,
}
class EmployeeDepartmentNormalizer(Normalizer):
"""
Normalizer for employee_x_department data.
"""
def normalize(self, emp_dept):
return {
'uuid': emp_dept.uuid,
'employee_uuid': emp_dept.employee_uuid,
'department_uuid': emp_dept.department_uuid,
}
class MessageNormalizer(Normalizer):
"""
Normalizer for message data.
"""
def normalize(self, message):
return {
'uuid': message.uuid,
'sender_uuid': message.sender_uuid,
'subject': message.subject,
'body': message.body,
'sent': message.sent,
}
class MessageRecipientNormalizer(Normalizer):
"""
Normalizer for message recipient data.
"""
def normalize(self, recip):
return {
'uuid': recip.uuid,
'message_uuid': recip.message_uuid,
'recipient_uuid': recip.recipient_uuid,
'status': recip.status,
}
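
Usage sketch: normalizers are plain classes producing comparison-ready dicts, and per the comment above, UserNormalizer first needs its admin attribute set to the administrator Role instance. The administrator_role helper shown is an assumption:

from rattail.db import Session, model
from rattail.db.auth import administrator_role  # assumed helper

session = Session()

normalizer = UserNormalizer()
# per the class comment: must be the administrator Role instance
normalizer.admin = administrator_role(session)

user = session.query(model.User).first()
data = normalizer.normalize(user)  # plain dict, ready for diffing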


@@ -1,27 +0,0 @@
# -*- coding: utf-8 -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2015 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Import Data Providers
"""
from .core import DataProvider, QueryProvider


@@ -1,182 +0,0 @@
# -*- coding: utf-8 -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2015 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Import Data Providers
"""
from __future__ import unicode_literals
import datetime
from rattail.core import Object
from rattail.db.cache import cache_model
class DataProvider(Object):
"""
Base class for import data providers.
"""
importer_class = None
normalize_progress_message = "Normalizing source data"
progress = None
def __init__(self, config, session, importer=None, importer_kwargs={}, **kwargs):
self.config = config
self.session = session
if importer is None:
self.importer = self.importer_class(config, session, **importer_kwargs)
else:
self.importer = importer
super(DataProvider, self).__init__(**kwargs)
self.importer.provider = self
@property
def model_class(self):
return self.importer.model_class
@property
def key(self):
"""
Key by which records should be matched between the source data and
Rattail.
"""
raise NotImplementedError("Please define the `key` for your data provider.")
@property
def model_name(self):
return self.model_class.__name__
def cache_model(self, model_class, key, **kwargs):
"""
Convenience method for caching a model.
"""
kwargs.setdefault('progress', self.progress)
return cache_model(self.session, model_class, key=key, **kwargs)
def setup(self):
"""
Perform any setup necessary, e.g. cache lookups for existing data.
"""
def get_data(self, progress=None, normalize_progress_message=None):
"""
Return the full set of normalized data which is to be imported.
"""
self.now = datetime.datetime.utcnow()
self.progress = progress
self.setup()
source_data = self.get_source_data(progress=progress)
data = self.normalize_source_data(source_data, progress=progress)
self.teardown()
return data
def teardown(self):
"""
Perform any cleanup necessary, after the main data run.
"""
def get_source_data(self, progress=None):
"""
Return the data which is to be imported.
"""
return []
def normalize_source_data(self, source_data, progress=None):
"""
Return a normalized version of the full set of source data.
"""
data = []
count = len(source_data)
if count == 0:
return data
prog = None
if progress:
prog = progress(self.normalize_progress_message, count)
for i, record in enumerate(source_data, 1):
record = self.normalize(record)
if record:
data.append(record)
if prog:
prog.update(i)
if prog:
prog.destroy()
return data
def normalize(self, data):
"""
Normalize a source data record. Generally this is where the provider
may massage the record in any way necessary, so that its values are
more "native" and can be used for direct comparison with, and
assignment to, the target model instance.
Note that if you override this, your method must return the data to be
imported. If your method returns ``None`` then that particular record
would be skipped and not imported.
"""
return data
def int_(self, value):
"""
Coerce ``value`` to an integer, or return ``None`` if that can't be
done cleanly.
"""
try:
return int(value)
except (TypeError, ValueError):
return None
class QueryDataProxy(object):
"""
Simple proxy to wrap a SQLAlchemy query and make it sort of behave like a
normal sequence, as much as needed to make a ``DataProvider`` happy.
"""
def __init__(self, query):
self.query = query
def __len__(self):
return self.query.count()
def __iter__(self):
return iter(self.query)
class QueryProvider(DataProvider):
"""
Data provider whose data source is a SQLAlchemy query. Note that this
needn't be a Rattail database query; any database will work as long as a
SQLAlchemy query is behind it.
"""
def query(self):
"""
Return the query which will define the data set.
"""
raise NotImplementedError
def get_source_data(self, progress=None):
"""
Return the data which is to be imported.
"""
return QueryDataProxy(self.query())
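
A concrete provider pairs an importer with a query; QueryDataProxy supplies len() and iteration, so any SQLAlchemy query will do. A sketch using the ProductImporter from this same (removed) framework; the field list is illustrative only:

from rattail.db import model
from rattail.db.importing import models
from rattail.db.importing.providers import QueryProvider


class ProductQueryProvider(QueryProvider):
    """
    Feeds Product records from the local Rattail database (illustrative).
    """
    importer_class = models.ProductImporter
    key = 'uuid'
    supported_fields = ['uuid', 'upc', 'description']

    def query(self):
        return self.session.query(model.Product)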


@@ -1,177 +0,0 @@
# -*- coding: utf-8 -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2015 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
CSV Import Data Providers
"""
from __future__ import unicode_literals
import datetime
import logging
from decimal import Decimal
from .core import DataProvider
from rattail.db import model
from rattail.gpc import GPC
from rattail.db.importing import models
from rattail.util import load_object
from rattail.csvutil import UnicodeDictReader
from rattail.db.util import maxlen
from rattail.time import localtime, make_utc
log = logging.getLogger(__name__)
def make_provider(config, session, spec, **kwargs):
"""
Create a provider instance according to the given spec. For now, see the
source code for more details.
"""
provider = None
if '.' not in spec and ':' not in spec:
from rattail.db.importing.providers import csv
if hasattr(csv, spec):
provider = getattr(csv, spec)
elif hasattr(csv, '{0}Provider'.format(spec)):
provider = getattr(csv, '{0}Provider'.format(spec))
else:
provider = load_object(spec)
if provider:
return provider(config, session, **kwargs)
class CsvProvider(DataProvider):
"""
Base class for CSV data providers.
"""
time_format = '%Y-%m-%d %H:%M:%S'
def get_source_data(self, progress=None):
with open(self.data_path, 'rb') as f:
reader = UnicodeDictReader(f)
return list(reader)
def make_utc(self, time):
if time is None:
return None
return make_utc(localtime(self.config, time))
def make_time(self, value):
if not value:
return None
time = datetime.datetime.strptime(value, self.time_format)
return self.make_utc(time)
class ProductProvider(CsvProvider):
"""
CSV product data provider.
"""
importer_class = models.ProductImporter
supported_fields = [
'uuid',
'upc',
'description',
'size',
'department_uuid',
'subdepartment_uuid',
'category_uuid',
'brand_uuid',
'regular_price',
'sale_price',
'sale_starts',
'sale_ends',
]
maxlen_description = maxlen(model.Product.description)
maxlen_size = maxlen(model.Product.size)
def normalize(self, data):
if 'upc' in data:
upc = data['upc']
data['upc'] = GPC(upc) if upc else None
# Warn about truncation until Rattail schema is addressed.
if 'description' in data:
description = data['description'] or ''
if self.maxlen_description and len(description) > self.maxlen_description:
log.warning("product description is more than {} chars and will be truncated: {}".format(
self.maxlen_description, repr(description)))
description = description[:self.maxlen_description]
data['description'] = description
# Warn about truncation until Rattail schema is addressed.
if 'size' in data:
size = data['size'] or ''
if self.maxlen_size and len(size) > self.maxlen_size:
log.warning("product size is more than {} chars and will be truncated: {}".format(
self.maxlen_size, repr(size)))
size = size[:self.maxlen_size]
data['size'] = size
if 'department_uuid' in data:
data['department_uuid'] = data['department_uuid'] or None
if 'subdepartment_uuid' in data:
data['subdepartment_uuid'] = data['subdepartment_uuid'] or None
if 'category_uuid' in data:
data['category_uuid'] = data['category_uuid'] or None
if 'brand_uuid' in data:
data['brand_uuid'] = data['brand_uuid'] or None
if 'regular_price' in data:
price = data['regular_price']
data['regular_price'] = Decimal(price) if price else None
# Determine if sale price is currently active; if it is not then we
# will declare None for all sale fields.
if 'sale_starts' in data:
data['sale_starts'] = self.make_time(data['sale_starts'])
if 'sale_ends' in data:
data['sale_ends'] = self.make_time(data['sale_ends'])
if 'sale_price' in data:
price = data['sale_price']
data['sale_price'] = Decimal(price) if price else None
if data['sale_price']:
sale_starts = data.get('sale_starts')
sale_ends = data.get('sale_ends')
active = False
if sale_starts and sale_ends:
if sale_starts <= self.now <= sale_ends:
active = True
elif sale_starts:
if sale_starts <= self.now:
active = True
elif sale_ends:
if self.now <= sale_ends:
active = True
else:
active = True
if not active:
data['sale_price'] = None
data['sale_starts'] = None
data['sale_ends'] = None
else:
data['sale_starts'] = None
data['sale_ends'] = None
return data
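
Tying it together, the import-csv command shown earlier boiled down to roughly the following; 'Product' resolves to ProductProvider per make_provider() above, and the CSV header row is expected to match the supported field names. config is assumed to be in hand:

from rattail.db import Session
from rattail.db.importing.providers.csv import make_provider

session = Session()
provider = make_provider(config, session, 'Product', data_path='products.csv')
data = provider.get_data()
created, updated = provider.importer.import_data(
    data, provider.supported_fields, 'uuid')
session.commit()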


@@ -1,31 +0,0 @@
# -*- coding: utf-8 -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2016 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
(New) Data Importing Framework
"""
from __future__ import unicode_literals, absolute_import
from .importers import Importer, QueryImporter, SQLAlchemyImporter, BulkPostgreSQLImporter
from .handlers import ImportHandler, SQLAlchemyImportHandler, BulkPostgreSQLImportHandler
from . import model


@@ -1,323 +0,0 @@
# -*- coding: utf-8 -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2016 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Import Handlers
"""
from __future__ import unicode_literals, absolute_import
import sys
import datetime
import logging
import humanize
from rattail.util import OrderedDict
from rattail.mail import send_email
log = logging.getLogger(__name__)
class ImportHandler(object):
"""
Base class for all import handlers.
"""
local_title = "Rattail"
host_title = "Host/Other"
session = None
progress = None
dry_run = False
def __init__(self, config=None, **kwargs):
self.config = config
self.importers = self.get_importers()
for key, value in kwargs.iteritems():
setattr(self, key, value)
def get_importers(self):
"""
Returns a dict of all available importers, where the values are
importer factories. All subclasses will want to override this. Note
that if you return an ``OrderedDict`` instance, you can affect the
ordering of keys in the command line help system, etc.
"""
return {}
def get_importer_keys(self):
"""
Returns a list of keys corresponding to the available importers.
"""
return list(self.importers.iterkeys())
def get_default_keys(self):
"""
Returns a list of keys corresponding to the default importers.
Override this if you wish certain importers to be excluded by default,
e.g. when first testing them out etc.
"""
return self.get_importer_keys()
def get_importer(self, key):
"""
Returns an importer instance corresponding to the given key.
"""
kwargs = self.get_importer_kwargs(key)
kwargs['config'] = self.config
kwargs['session'] = self.session
importer = self.importers[key](**kwargs)
importer.handler = self
return importer
def get_importer_kwargs(self, key):
"""
Return a dict of kwargs to be used when constructing an importer with
the given key.
"""
kwargs = {}
if hasattr(self, 'host_session'):
kwargs['host_session'] = self.host_session
return kwargs
def import_data(self, keys, args):
"""
Import all data for the given importer keys.
"""
self.now = datetime.datetime.utcnow()
self.dry_run = args.dry_run
self.begin_transaction()
self.setup()
changes = OrderedDict()
for key in keys:
importer = self.get_importer(key)
if not importer:
log.warning("skipping unknown importer: {}".format(key))
continue
created, updated, deleted = importer.import_data(args, progress=self.progress)
changed = bool(created or updated or deleted)
logger = log.warning if changed and args.warnings else log.info
logger("{} -> {}: added {}, updated {}, deleted {} {} records".format(
self.host_title, self.local_title, len(created), len(updated), len(deleted), key))
if changed:
changes[key] = created, updated, deleted
if changes:
self.process_changes(changes, args)
if self.dry_run:
self.rollback_transaction()
else:
self.commit_transaction()
self.teardown()
return changes
def begin_transaction(self):
self.begin_host_transaction()
self.begin_local_transaction()
def begin_host_transaction(self):
if hasattr(self, 'make_host_session'):
self.host_session = self.make_host_session()
def begin_local_transaction(self):
pass
def setup(self):
"""
Perform any setup necessary, prior to running the import task(s).
"""
def rollback_transaction(self):
self.rollback_host_transaction()
self.rollback_local_transaction()
def rollback_host_transaction(self):
if hasattr(self, 'host_session'):
self.host_session.rollback()
self.host_session.close()
self.host_session = None
def rollback_local_transaction(self):
pass
def commit_transaction(self):
self.commit_host_transaction()
self.commit_local_transaction()
def commit_host_transaction(self):
if hasattr(self, 'host_session'):
self.host_session.commit()
self.host_session.close()
self.host_session = None
def commit_local_transaction(self):
pass
def teardown(self):
"""
Perform any cleanup necessary, after running the import task(s).
"""
def process_changes(self, changes, args):
"""
This method is called any time changes occur, regardless of whether the
import is running in "warnings" mode. Default implementation however
is to do nothing unless warnings mode is in effect, in which case an
email notification will be sent.
"""
# TODO: This whole thing needs a re-write...but for now, waiting until
# the old importer has really gone away, so we can share its email
# template instead of bothering with something more complicated.
if not args.warnings:
return
data = {
'local_title': self.local_title,
'host_title': self.host_title,
'argv': sys.argv,
'runtime': humanize.naturaldelta(datetime.datetime.utcnow() - self.now),
'changes': changes,
'dry_run': args.dry_run,
'render_record': RecordRenderer(self.config),
'max_display': 15,
}
command = getattr(self, 'command', None)
if command:
data['command'] = '{} {}'.format(command.parent.name, command.name)
else:
data['command'] = None
if command:
key = '{}_{}_updates'.format(command.parent.name, command.name)
key = key.replace('-', '_')
else:
key = 'rattail_import_updates'
send_email(self.config, key, fallback_key='rattail_import_updates', data=data)
class SQLAlchemyImportHandler(ImportHandler):
"""
Handler for imports for which the host data source is represented by a
SQLAlchemy engine and ORM.
"""
host_session = None
def make_host_session(self):
raise NotImplementedError
class BulkPostgreSQLImportHandler(ImportHandler):
"""
Handler for bulk imports which target PostgreSQL on the local side.
"""
def import_data(self, keys, args):
"""
Import all data for the given importer keys.
"""
self.now = datetime.datetime.utcnow()
self.dry_run = args.dry_run
self.begin_transaction()
self.setup()
for key in keys:
importer = self.get_importer(key)
if not importer:
log.warning("skipping unknown importer: {}".format(key))
continue
created = importer.import_data(args, progress=self.progress)
log.info("{} -> {}: added {}, updated 0, deleted 0 {} records".format(
self.host_title, self.local_title, created, key))
if self.dry_run:
self.rollback_transaction()
else:
self.commit_transaction()
self.teardown()
class RecordRenderer(object):
"""
Record renderer for email notifications sent from data import jobs.
"""
def __init__(self, config):
self.config = config
def __call__(self, record):
return self.render(record)
def render(self, record):
"""
Render the given record.
"""
key = record.__class__.__name__.lower()
renderer = getattr(self, 'render_{}'.format(key), None)
if renderer:
return renderer(record)
label = self.get_label(record)
url = self.get_url(record)
if url:
return '<a href="{}">{}</a>'.format(url, label)
return label
def get_label(self, record):
key = record.__class__.__name__.lower()
label = getattr(self, 'label_{}'.format(key), self.label)
return label(record)
def label(self, record):
return unicode(record)
def get_url(self, record):
"""
Fetch / generate a URL for the given data record. You should *not*
override this method, but override :meth:`url()` instead.
"""
key = record.__class__.__name__.lower()
url = getattr(self, 'url_{}'.format(key), self.url)
return url(record)
def url(self, record):
"""
Fetch / generate a URL for the given data record.
"""
if hasattr(record, 'uuid'):
url = self.config.get('tailbone', 'url')
if url:
url = url.rstrip('/')
name = '{}s'.format(record.__class__.__name__.lower())
if name == 'persons': # FIXME, obviously this is a hack
name = 'people'
url = '{}/{}/{{uuid}}'.format(url, name)
return url.format(uuid=record.uuid)
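
A host-side handler then only has to produce the host session. A sketch; the connection URL is hypothetical and would normally come from self.config:

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from rattail.db.newimporting import SQLAlchemyImportHandler


class FooHostImportHandler(SQLAlchemyImportHandler):
    """
    Handles imports from the (hypothetical) Foo database.
    """
    host_title = "Foo"

    def make_host_session(self):
        # hypothetical URL; read from self.config in practice
        engine = create_engine('postgresql://localhost/foo')
        return sessionmaker(bind=engine)()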


@@ -1,640 +0,0 @@
# -*- coding: utf-8 -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2016 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Data Importers
"""
from __future__ import unicode_literals, absolute_import
import datetime
import logging
from sqlalchemy import orm
from sqlalchemy.orm.exc import NoResultFound
from rattail.db import cache
from rattail.time import make_utc
log = logging.getLogger(__name__)
class Importer(object):
"""
Base class for all data importers.
"""
key = 'uuid'
cached_instances = None
allow_create = True
allow_update = True
allow_delete = True
dry_run = False
def __init__(self, config=None, session=None, fields=None, key=None, **kwargs):
self.config = config
self.session = session
self.fields = fields or self.supported_fields
if key:
self.key = key
if isinstance(self.key, basestring):
self.key = (self.key,)
for key, value in kwargs.iteritems():
setattr(self, key, value)
@property
def model_class(self):
"""
This should return a reference to the model class which the importer
"targets" so to speak.
"""
raise NotImplementedError
@property
def model_name(self):
"""
Returns the string 'name' of the model class which the importer targets.
"""
return self.model_class.__name__
@property
def model_mapper(self):
"""
This should return the SQLAlchemy mapper for the model class.
"""
return orm.class_mapper(self.model_class)
@property
def model_table(self):
"""
Returns the underlying table used by the primary local data model class.
"""
tables = self.model_mapper.tables
assert len(tables) == 1
return tables[0]
@property
def simple_fields(self):
"""
The list of field names which may be considered "simple" and therefore
treated as such, i.e. with basic getattr/setattr calls. Note that this
only applies to the local / target side, it has no effect on the
upstream / foreign side.
"""
return list(self.model_mapper.columns.keys())
@property
def supported_fields(self):
"""
The list of field names which are supported in general by the importer.
Note that this only applies to the local / target side, it has no
effect on the upstream / foreign side.
"""
return self.simple_fields
@property
def normalize_progress_message(self):
return "Reading {} data from {}".format(self.model_name, self.handler.host_title)
def setup(self):
"""
Perform any setup necessary, e.g. cache lookups for existing data.
"""
def teardown(self):
"""
Perform any cleanup after import, if necessary.
"""
def _setup(self, args, progress):
self.now = datetime.datetime.utcnow()
self.allow_create = self.allow_create and args.create
self.allow_update = self.allow_update and args.update
self.allow_delete = self.allow_delete and args.delete
self.dry_run = args.dry_run
self.args = args
self.progress = progress
self.setup()
def import_data(self, args, progress=None):
"""
Import some data! This is the core body of logic for that, regardless
of where data is coming from or where it's headed. Note that this
method handles deletions as well as adds/updates.
"""
self._setup(args, progress)
# note, separate list objects; chained assignment would alias a single list
created, updated, deleted = [], [], []
data = self.normalize_source_data()
self.cached_instances = self.cache_instance_data(data)
# Normalize source data set in order to prune duplicate keys. This is
# for the sake of sanity since duplicates typically lead to a ping-pong
# effect, where a "clean" (change-less) import is impossible.
unique = {}
for record in data:
key = self.get_key(record)
if key in unique:
log.warning("duplicate records detected from {} for key: {}".format(
self.handler.host_title, key))
unique[key] = record
data = []
for key in sorted(unique):
data.append(unique[key])
if self.allow_create or self.allow_update:
created, updated = self._import_create_update(data, args)
if self.allow_delete:
changes = len(created) + len(updated)
if args.max_total and changes >= args.max_total:
log.warning("max of {} total changes already reached; skipping deletions".format(args.max_total))
else:
deleted = self._import_delete(data, args, host_keys=set(unique), changes=changes)
self.teardown()
return created, updated, deleted
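# Hedged caller sketch, not from the original source: an import handler
# would typically consume the three result lists like so; the ``importer``
# and ``args`` names here are assumptions for illustration.
#
#     created, updated, deleted = importer.import_data(args, progress)
#     log.info("changes: %s created, %s updated, %s deleted",
#              len(created), len(updated), len(deleted))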
def _import_create_update(self, data, args):
"""
Import the given data; create and/or update records as needed and
according to the args provided.
"""
created = []
updated = []
count = len(data)
if not count:
return created, updated
prog = None
if self.progress:
prog = self.progress("Importing {} data".format(self.model_name), count)
keys_seen = set()
for i, source_data in enumerate(data, 1):
# Get what should be the unique key for the current 'host' data
# record, but warn if it turns out not to be unique. Note that we
# will still wind up importing both records, however.
key = self.get_key(source_data)
if key in keys_seen:
log.warning("duplicate records from {}:{} for key: {}".format(
self.__class__.__module__, self.__class__.__name__, key))
else:
keys_seen.add(key)
# Fetch local instance, using key from host record.
instance = self.get_instance(key)
# If we have a local instance, but its data differs from host, update it.
if instance and self.allow_update:
instance_data = self.normalize_instance(instance)
diffs = self.data_diffs(instance_data, source_data)
if diffs:
log.debug("fields '{}' differed for local data: {}, host data: {}".format(
','.join(diffs), instance_data, source_data))
instance = self.update_instance(instance, source_data, instance_data)
updated.append((instance, instance_data, source_data))
if args.max_update and len(updated) >= args.max_update:
log.warning("max of {} *updated* records has been reached; stopping now".format(args.max_update))
break
if args.max_total and (len(created) + len(updated)) >= args.max_total:
log.warning("max of {} *total changes* has been reached; stopping now".format(args.max_total))
break
# If we did not yet have a local instance, create it using host data.
elif not instance and self.allow_create:
instance = self.create_instance(key, source_data)
log.debug("created new {} {}: {}".format(self.model_name, key, instance))
created.append((instance, source_data))
if self.cached_instances is not None:
self.cached_instances[key] = {'instance': instance, 'data': self.normalize_instance(instance)}
if args.max_create and len(created) >= args.max_create:
log.warning("max of {} *created* records has been reached; stopping now".format(args.max_create))
break
if args.max_total and (len(created) + len(updated)) >= args.max_total:
log.warning("max of {} *total changes* has been reached; stopping now".format(args.max_total))
break
if prog:
prog.update(i)
if prog:
prog.destroy()
return created, updated
def get_deletion_keys(self):
"""
Return a set of keys from the *local* data set which are eligible for
deletion. By default this will be all keys from the local (cached)
data set.
"""
return set(self.cached_instances)
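# Hedged override sketch for a hypothetical subclass: deletion candidates
# may be restricted to a subset of the cache.  The ``source_system``
# attribute on cached instances is an assumption for illustration.
#
#     def get_deletion_keys(self):
#         return set(key for key, cached in self.cached_instances.iteritems()
#                    if cached['instance'].source_system == 'acme')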
def _import_delete(self, data, args, host_keys=None, changes=0):
"""
Import deletions for the given data set.
"""
if host_keys is None:
host_keys = set([self.get_key(rec) for rec in data])
deleted = []
deleting = self.get_deletion_keys() - host_keys
count = len(deleting)
log.debug("found {} instances to delete".format(count))
if count:
prog = None
if self.progress:
prog = self.progress("Deleting {} data".format(self.model_name), count)
for i, key in enumerate(sorted(deleting), 1):
instance = self.cached_instances.pop(key)['instance']
if self.delete_instance(instance):
deleted.append((instance, self.normalize_instance(instance)))
if args.max_delete and len(deleted) >= args.max_delete:
log.warning("max of {} *deleted* records has been reached; stopping now".format(args.max_delete))
break
if args.max_total and (changes + len(deleted)) >= args.max_total:
log.warning("max of {} *total changes* has been reached; stopping now".format(args.max_total))
break
if prog:
prog.update(i)
if prog:
prog.destroy()
return deleted
def delete_instance(self, instance):
"""
Process a deletion for the given instance. The default implementation
really does delete the instance from the local session, so you must
override this if you need something else to happen.
This method must return a boolean indicating whether or not the
deletion was performed. This implies for example that you may simply
do nothing, and return ``False``, to effectively disable deletion
altogether for an importer.
"""
self.session.delete(instance)
self.session.flush()
self.session.expunge(instance)
return True
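# Hedged override sketch for a hypothetical subclass: a "soft delete" which
# marks the record inactive instead of removing it.  The ``active`` flag is
# an assumption here; returning True still counts the record as deleted.
#
#     def delete_instance(self, instance):
#         instance.active = False
#         self.session.flush()
#         return True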
def get_source_data(self):
"""
Return the "raw" (as-is, not normalized) data which is to be imported.
This may be any sequence-like object, which has a ``len()`` value and
responds to iteration etc. The objects contained within it may be of
any type, no assumptions are made there. (That is the job of the
:meth:`normalize_source_data()` method.)
"""
return []
def normalize_source_data(self):
"""
Return a normalized version of the full set of source data. Note that
this calls :meth:`get_source_data()` to obtain the initial data set,
and then normalizes each record. The normalization process may filter
out some records from the set, in which case the return value will be
smaller than the original data set.
"""
source_data = self.get_source_data()
normalized = []
count = len(source_data)
if count == 0:
return normalized
prog = None
if self.progress:
prog = self.progress(self.normalize_progress_message, count)
for i, data in enumerate(source_data, 1):
data = self.normalize_source_record(data)
if data:
normalized.append(data)
if prog:
prog.update(i)
if prog:
prog.destroy()
return normalized
def get_key(self, data):
"""
Return the key value for the given data record.
"""
return tuple(data[k] for k in self.key)
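# Worked example, illustrative only: with self.key = ('number',),
#
#     self.get_key({'number': 42, 'name': 'Produce'})    # => (42,)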
def int_(self, value):
"""
Coerce ``value`` to an integer, or return ``None`` if that can't be
done cleanly.
"""
try:
return int(value)
except (TypeError, ValueError):
return None
def prioritize_2(self, data, field):
"""
Prioritize the data values for the pair of fields implied by the given
fieldname. I.e., if only one non-empty value is present, make sure
it's in the first slot.
"""
field2 = '{}_2'.format(field)
if field in data and field2 in data:
if data[field2] and not data[field]:
data[field], data[field2] = data[field2], None
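# Worked example, illustrative only: given a phone / phone_2 field pair,
#
#     data = {'phone': None, 'phone_2': '555-1234'}
#     self.prioritize_2(data, 'phone')
#     # data is now {'phone': '555-1234', 'phone_2': None}
#
# i.e. the lone non-empty value is shifted into the primary slot.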
def normalize_source_record(self, record):
"""
Normalize a source data record. Generally this is where the importer
may massage the record in any way necessary, so that its values are
more "native" and can be used for direct comparison with, and
assignment to, the target model instance.
Note that if you override this, your method must return the data to be
imported. If your method returns ``None`` then that particular record
would be skipped and not imported.
"""
return record
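# Hedged override sketch for a hypothetical subclass: returning None causes
# a record to be skipped entirely.  The 'upc' field is an assumption here.
#
#     def normalize_source_record(self, record):
#         if not record.get('upc'):
#             return None     # no key value, skip this record
#         record['upc'] = self.int_(record['upc'])
#         return record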
def cache_model(self, model, **kwargs):
"""
Convenience method which invokes :func:`rattail.db.cache.cache_model()`
with the given model and keyword arguments. It will provide the
``session`` and ``progress`` parameters by default, setting them to the
importer's attributes of the same names.
"""
session = kwargs.pop('session', self.session)
kwargs.setdefault('progress', self.progress)
return cache.cache_model(session, model, **kwargs)
def cache_instance_data(self, data=None):
"""
Cache all existing model instances as normalized data.
"""
return cache.cache_model(self.session, self.model_class,
key=self.get_cache_key,
omit_duplicates=True,
query_options=self.cache_query_options(),
normalizer=self.normalize_cache_instance,
progress=self.progress)
def cache_query_options(self):
"""
Return a list of options to apply to the cache query, if needed.
"""
def get_cache_key(self, instance, normalized):
"""
Get the primary model cache key for a given instance/data object.
"""
return tuple(normalized['data'].get(k) for k in self.key)
def normalize_cache_instance(self, instance):
"""
Normalizer for cache data. This adds the instance to the cache in
addition to its normalized data. This is so that if lots of updates
are required, we don't have to constantly fetch them.
"""
return {'instance': instance, 'data': self.normalize_instance(instance)}
def get_instance(self, key):
"""
Must return the local object corresponding to the given key, or None.
Default behavior here will be to check the cache if one is in effect,
otherwise return the value from :meth:`get_single_instance()`.
"""
if self.cached_instances is not None:
data = self.cached_instances.get(key)
return data['instance'] if data else None
return self.get_single_instance(key)
def get_single_instance(self, key):
"""
Must return the local object corresponding to the given key, or None.
This method should not consult the cache; that is handled within the
:meth:`get_instance()` method.
"""
query = self.session.query(self.model_class)
for i, k in enumerate(self.key):
query = query.filter(getattr(self.model_class, k) == key[i])
try:
return query.one()
except NoResultFound:
pass
def normalize_instance(self, instance):
"""
Normalize a model instance.
"""
data = {}
for field in self.simple_fields:
if field in self.fields:
data[field] = getattr(instance, field)
return data
def newval(self, data, field, value):
"""
Assign a "new" field value to the given data record. In other words
don't try to be smart about not overwriting it if the existing data
already matches etc. However the main point of this is to skip fields
which are not included in the current task.
"""
if field in self.fields:
data[field] = value
def data_diffs(self, local_data, host_data):
"""
Find all (relevant) fields which differ between the model and host data
for a given record.
"""
diffs = []
for field in self.fields:
if local_data[field] != host_data[field]:
diffs.append(field)
return diffs
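# Worked example, illustrative only: with self.fields = ['uuid', 'name'],
#
#     local = {'uuid': 'abc', 'name': 'Apple'}
#     host = {'uuid': 'abc', 'name': 'Apples'}
#     self.data_diffs(local, host)    # => ['name']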
def create_instance(self, key, data):
instance = self.new_instance(key)
if instance:
instance = self.update_instance(instance, data)
if instance:
self.session.add(instance)
return instance
def new_instance(self, key):
"""
Return a new model instance to correspond to the given key.
"""
instance = self.model_class()
for i, k in enumerate(self.key):
if hasattr(instance, k):
setattr(instance, k, key[i])
return instance
def update_instance(self, instance, data, instance_data=None):
"""
Update the given model instance with the given data.
"""
for field in self.simple_fields:
if field in self.fields:
if not instance_data or instance_data[field] != data[field]:
setattr(instance, field, data[field])
return instance
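# A minimal concrete importer sketch, not part of the original module; it
# assumes the rattail ``Department`` model with ``number`` and ``name``
# columns, and pretends the host system hands us simple dicts.
class ExampleDepartmentImporter(Importer):
    """
    Hypothetical importer which targets the Department model.
    """
    key = 'number'

    @property
    def model_class(self):
        from rattail.db import model
        return model.Department

    @property
    def supported_fields(self):
        return ['number', 'name']

    def get_source_data(self):
        # a real importer would query the host system here
        return [{'number': 1, 'name': 'Produce'},
                {'number': 2, 'name': 'Grocery'}]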
class QueryDataProxy(object):
"""
Simple proxy to wrap a SQLAlchemy (or Django) query and make it sort of
behave like a normal sequence, as much as needed to make an importer happy.
"""
def __init__(self, query):
self.query = query
def __len__(self):
return self.query.count()
def __iter__(self):
return iter(self.query)
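# Self-contained demonstration, not part of the original module: the proxy
# makes any query-like object (anything with count() and iteration) usable
# where the importer machinery expects a sequence.
def _demo_query_data_proxy():
    class FakeQuery(object):
        def __init__(self, rows):
            self.rows = rows
        def count(self):
            return len(self.rows)
        def __iter__(self):
            return iter(self.rows)

    proxy = QueryDataProxy(FakeQuery(['a', 'b']))
    assert len(proxy) == 2
    assert list(proxy) == ['a', 'b']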
class QueryImporter(Importer):
"""
Base class for importers whose raw external data source is a SQLAlchemy (or
Django) query.
"""
def query(self):
"""
Must return the primary query which will define the data set.
"""
raise NotImplementedError
def get_source_data(self, progress=None):
return QueryDataProxy(self.query())
class SQLAlchemyImporter(QueryImporter):
"""
Base class for importers whose external data source is a SQLAlchemy query.
"""
host_session = None
@property
def host_model_class(self):
"""
For default behavior, set this to a model class to be used in
generating the host (source) data query.
"""
raise NotImplementedError
def query(self):
"""
Must return the primary query which will define the data set. Default
behavior is to leverage :attr:`host_session` and generate a query for
the class defined by :attr:`host_model_class`.
"""
return self.host_session.query(self.host_model_class)
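# Hedged subclass sketch, not part of the original module: a SQLAlchemy
# importer needs only a host session plus a host model class.  The
# ``otherapp`` module and ``HostProduct`` class here are hypothetical.
class ExampleHostProductImporter(SQLAlchemyImporter):

    @property
    def host_model_class(self):
        from otherapp.db import model as host_model    # hypothetical import
        return host_model.HostProduct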
class BulkPostgreSQLImporter(Importer):
"""
Base class for bulk data importers which target PostgreSQL on the local side.
"""
def import_data(self, args, progress=None):
self._setup(args, progress)
self.open_data_buffers()
data = self.normalize_source_data()
created = self._import_create(data, args)
self.teardown()
return created
def open_data_buffers(self):
    # note, ``data_path`` is not defined by this class; subclasses must provide it
    self.data_buffer = open(self.data_path, 'wb')
def teardown(self):
self.data_buffer.close()
def _import_create(self, data, args):
count = len(data)
if not count:
return 0
created = count
prog = None
if self.progress:
prog = self.progress("Importing {} data".format(self.model_name), count)
for i, source_data in enumerate(data, 1):
key = self.get_key(source_data)
self.create_instance(key, source_data)
if args.max_create and i >= args.max_create:
log.warning("max of {} *created* records has been reached; stopping now".format(args.max_create))
created = i
break
if prog:
prog.update(i)
if prog:
prog.destroy()
self.commit_create()
return created
def commit_create(self):
log.info("copying {} data from buffer to PostgreSQL".format(self.model_name))
self.seek_data_buffers()
cursor = self.session.connection().connection.cursor()
cursor.copy_from(self.data_buffer, self.model_table.name, columns=self.fields)
log.debug("PostgreSQL data copy completed")
def seek_data_buffers(self):
self.data_buffer.close()
self.data_buffer = open(self.data_path, 'rb')
def create_instance(self, key, data):
self.prep_data_for_postgres(data)
self.data_buffer.write('{}\n'.format('\t'.join([data[field] for field in self.fields])))
def prep_data_for_postgres(self, data):
for key, value in data.iteritems():
if value is None:
value = '\\N'
elif value is True:
value = 't'
elif value is False:
value = 'f'
elif isinstance(value, datetime.datetime):
value = make_utc(value)
elif isinstance(value, basestring):
value = value.replace('\\', '\\\\')
value = value.replace('\r\n', '\n')
value = value.replace('\r', '\\r')
value = value.replace('\n', '\\n')
data[key] = unicode(value)
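# Illustrative check, not part of the original module: a simplified mimic of
# prep_data_for_postgres() above (omitting the datetime and carriage-return
# handling) to show the COPY text format rules: None becomes \N, booleans
# become t/f, and backslashes/newlines are escaped to keep one record per line.
def _demo_prep_data_for_postgres():
    data = {'name': 'Bread\nAisle', 'active': True, 'notes': None}
    for key, value in data.items():
        if value is None:
            value = '\\N'
        elif value is True:
            value = 't'
        elif value is False:
            value = 'f'
        elif isinstance(value, basestring):
            value = value.replace('\\', '\\\\').replace('\n', '\\n')
        data[key] = unicode(value)
    assert data == {'name': u'Bread\\nAisle', 'active': u't', 'notes': u'\\N'}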

File diff suppressed because it is too large