Add AppHandler.next_counter_value() magic
				
					
				
			this is now used by `BatchHandler.consume_batch_id()` and hopefully is able to auto-magically create a dedicated counter table if the underlying db engine is not postgres. at least that part seems to work for tests using sqlite
This commit is contained in:
		
							parent
							
								
									4de258d09b
								
							
						
					
					
						commit
						e4277d80fb
					
				
					 7 changed files with 179 additions and 3 deletions
				
			
		| 
						 | 
					@ -15,6 +15,8 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   .. autoattribute:: batch_model_class
 | 
					   .. autoattribute:: batch_model_class
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   .. automethod:: consume_batch_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   .. automethod:: make_batch
 | 
					   .. automethod:: make_batch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   .. automethod:: make_basic_batch
 | 
					   .. automethod:: make_basic_batch
 | 
				
			||||||
| 
						 | 
					@ -113,6 +115,8 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   .. automethod:: remove_row
 | 
					   .. automethod:: remove_row
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   .. automethod:: get_effective_rows
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   .. automethod:: executable
 | 
					   .. automethod:: executable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   .. automethod:: why_not_execute
 | 
					   .. automethod:: why_not_execute
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -30,6 +30,7 @@ import os
 | 
				
			||||||
# import re
 | 
					# import re
 | 
				
			||||||
import tempfile
 | 
					import tempfile
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import six
 | 
					import six
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -41,6 +42,9 @@ from rattail.config import parse_list
 | 
				
			||||||
from rattail.core import get_uuid
 | 
					from rattail.core import get_uuid
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					log = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AppHandler(object):
 | 
					class AppHandler(object):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Base class and default implementation for top-level Rattail app handler.
 | 
					    Base class and default implementation for top-level Rattail app handler.
 | 
				
			||||||
| 
						 | 
					@ -170,6 +174,29 @@ class AppHandler(object):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        return load_object(spec)
 | 
					        return load_object(spec)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def next_counter_value(self, session, key, **kwargs):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Return the next counter value for the given key.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param session: Current session for Rattail DB.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param key: Unique key indicating the counter for which the
 | 
				
			||||||
 | 
					           next value should be fetched.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :returns: Next value as integer.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        dialect = session.bind.url.get_dialect().name
 | 
				
			||||||
 | 
					        if dialect != 'postgresql':
 | 
				
			||||||
 | 
					            log.debug("non-postgresql database detected; will use workaround")
 | 
				
			||||||
 | 
					            from rattail.db.util import CounterMagic
 | 
				
			||||||
 | 
					            magic = CounterMagic(self.config)
 | 
				
			||||||
 | 
					            return magic.next_value(session, key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # normal (uses postgresql sequence)
 | 
				
			||||||
 | 
					        sql = "select nextval('{}_seq')".format(key)
 | 
				
			||||||
 | 
					        value = session.execute(sql).scalar()
 | 
				
			||||||
 | 
					        return value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_active_stores(self, session, **kwargs):
 | 
					    def get_active_stores(self, session, **kwargs):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Returns the list of "active" stores.  A store is considered
 | 
					        Returns the list of "active" stores.  A store is considered
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,7 +2,7 @@
 | 
				
			||||||
################################################################################
 | 
					################################################################################
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
#  Rattail -- Retail Software Framework
 | 
					#  Rattail -- Retail Software Framework
 | 
				
			||||||
#  Copyright © 2010-2021 Lance Edgar
 | 
					#  Copyright © 2010-2022 Lance Edgar
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
#  This file is part of Rattail.
 | 
					#  This file is part of Rattail.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
| 
						 | 
					@ -119,6 +119,22 @@ class BatchHandler(object):
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def consume_batch_id(self, session, as_str=False):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Consumes a new batch ID from the generator, and returns it.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param session: Current session for Rattail DB.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param as_str: Flag indicating whether the return value should be a
 | 
				
			||||||
 | 
					           string, as opposed to the default of integer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :returns: Batch ID as integer, or zero-padded string of 8 chars.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        batch_id = self.app.next_counter_value(session, 'batch_id')
 | 
				
			||||||
 | 
					        if as_str:
 | 
				
			||||||
 | 
					            return '{:08d}'.format(batch_id)
 | 
				
			||||||
 | 
					        return batch_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def make_basic_batch(self, session, user=None, progress=None, **kwargs):
 | 
					    def make_basic_batch(self, session, user=None, progress=None, **kwargs):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Make a new "basic" batch, with no customization beyond what is provided
 | 
					        Make a new "basic" batch, with no customization beyond what is provided
 | 
				
			||||||
| 
						 | 
					@ -712,6 +728,23 @@ class BatchHandler(object):
 | 
				
			||||||
        batch.executed_by = user
 | 
					        batch.executed_by = user
 | 
				
			||||||
        return result
 | 
					        return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_effective_rows(self, batch):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Should return the set of rows from the given batch which are
 | 
				
			||||||
 | 
					        considered "effective" - i.e. when the batch is executed,
 | 
				
			||||||
 | 
					        these rows should be processed whereas the remainder should
 | 
				
			||||||
 | 
					        not.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param batch: A
 | 
				
			||||||
 | 
					           :class:`~rattail.db.model.batch.vendorcatalog.VendorCatalogBatch`
 | 
				
			||||||
 | 
					           instance.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :returns: List of
 | 
				
			||||||
 | 
					           :class:`~rattail.db.model.batch.vendorcatalog.VendorCatalogBatchRow`
 | 
				
			||||||
 | 
					           instances.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return batch.active_rows()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def execute(self, batch, progress=None, **kwargs):
 | 
					    def execute(self, batch, progress=None, **kwargs):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Execute the given batch, according to the given kwargs.  This is really
 | 
					        Execute the given batch, according to the given kwargs.  This is really
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -55,6 +55,16 @@ class VendorCatalogHandler(BatchHandler):
 | 
				
			||||||
    case_cost_diff_threshold = None
 | 
					    case_cost_diff_threshold = None
 | 
				
			||||||
    unit_cost_diff_threshold = None
 | 
					    unit_cost_diff_threshold = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def allow_future(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Returns boolean indicating whether "future" cost changes
 | 
				
			||||||
 | 
					        should be allowed.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :returns: ``True`` if future cost changes allowed; else ``False``.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return self.config.getbool('rattail.batch', 'vendor_catalog.allow_future',
 | 
				
			||||||
 | 
					                                   default=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def should_populate(self, batch):
 | 
					    def should_populate(self, batch):
 | 
				
			||||||
        # all vendor catalogs must come from data file
 | 
					        # all vendor catalogs must come from data file
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
| 
						 | 
					@ -118,7 +128,8 @@ class VendorCatalogHandler(BatchHandler):
 | 
				
			||||||
                                                   require=True)
 | 
					                                                   require=True)
 | 
				
			||||||
        parser.session = session
 | 
					        parser.session = session
 | 
				
			||||||
        parser.vendor = batch.vendor
 | 
					        parser.vendor = batch.vendor
 | 
				
			||||||
        batch.effective = parser.parse_effective_date(path)
 | 
					        if not batch.effective:
 | 
				
			||||||
 | 
					            batch.effective = parser.parse_effective_date(path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def append(row, i):
 | 
					        def append(row, i):
 | 
				
			||||||
            self.add_row(batch, row)
 | 
					            self.add_row(batch, row)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,7 +2,7 @@
 | 
				
			||||||
################################################################################
 | 
					################################################################################
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
#  Rattail -- Retail Software Framework
 | 
					#  Rattail -- Retail Software Framework
 | 
				
			||||||
#  Copyright © 2010-2018 Lance Edgar
 | 
					#  Copyright © 2010-2022 Lance Edgar
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
#  This file is part of Rattail.
 | 
					#  This file is part of Rattail.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
| 
						 | 
					@ -30,6 +30,7 @@ import re
 | 
				
			||||||
import pprint
 | 
					import pprint
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import sqlalchemy as sa
 | 
				
			||||||
from sqlalchemy import orm
 | 
					from sqlalchemy import orm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO: Deprecate/remove these imports.
 | 
					# TODO: Deprecate/remove these imports.
 | 
				
			||||||
| 
						 | 
					@ -39,6 +40,27 @@ from rattail.db.config import engine_from_config, get_engines, get_default_engin
 | 
				
			||||||
log = logging.getLogger(__name__)
 | 
					log = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CounterMagic(object):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Provides magic counter values, to simulate PostgreSQL sequence.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, config):
 | 
				
			||||||
 | 
					        self.config = config
 | 
				
			||||||
 | 
					        self.metadata = sa.MetaData()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def next_value(self, session, key):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Increment and return the next counter value for given key.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        engine = session.bind
 | 
				
			||||||
 | 
					        table = sa.Table('counter_{}'.format(key), self.metadata,
 | 
				
			||||||
 | 
					                         sa.Column('value', sa.Integer(), primary_key=True))
 | 
				
			||||||
 | 
					        table.create(engine, checkfirst=True)
 | 
				
			||||||
 | 
					        result = engine.execute(table.insert())
 | 
				
			||||||
 | 
					        return result.lastrowid
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class QuerySequence(object):
 | 
					class QuerySequence(object):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Simple wrapper for a SQLAlchemy (or Django, or other?) query, to make it
 | 
					    Simple wrapper for a SQLAlchemy (or Django, or other?) query, to make it
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										67
									
								
								tests/batch/test_handlers.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								tests/batch/test_handlers.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,67 @@
 | 
				
			||||||
 | 
					# -*- coding: utf-8; -*-
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from __future__ import unicode_literals, absolute_import
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					from unittest import TestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import sqlalchemy as sa
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from rattail.batch import handlers as mod
 | 
				
			||||||
 | 
					from rattail.config import make_config
 | 
				
			||||||
 | 
					from rattail.db import Session
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestBatchHandler(TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.config = self.make_config()
 | 
				
			||||||
 | 
					        self.handler = self.make_handler()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def make_config(self):
 | 
				
			||||||
 | 
					        return make_config([], extend=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def make_handler(self):
 | 
				
			||||||
 | 
					        return mod.BatchHandler(self.config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_consume_batch_id(self):
 | 
				
			||||||
 | 
					        engine = sa.create_engine('sqlite://')
 | 
				
			||||||
 | 
					        model = self.config.get_model()
 | 
				
			||||||
 | 
					        model.Base.metadata.create_all(bind=engine)
 | 
				
			||||||
 | 
					        session = Session(bind=engine)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # first id is 1
 | 
				
			||||||
 | 
					        result = self.handler.consume_batch_id(session)
 | 
				
			||||||
 | 
					        self.assertEqual(result, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # second is 2; test string version
 | 
				
			||||||
 | 
					        result = self.handler.consume_batch_id(session, as_str=True)
 | 
				
			||||||
 | 
					        self.assertEqual(result, '00000002')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_get_effective_rows(self):
 | 
				
			||||||
 | 
					        engine = sa.create_engine('sqlite://')
 | 
				
			||||||
 | 
					        model = self.config.get_model()
 | 
				
			||||||
 | 
					        model.Base.metadata.create_all(bind=engine)
 | 
				
			||||||
 | 
					        session = Session(bind=engine)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # make batch w/ 3 rows
 | 
				
			||||||
 | 
					        user = model.User(username='patty')
 | 
				
			||||||
 | 
					        batch = model.NewProductBatch(id=1, created_by=user)
 | 
				
			||||||
 | 
					        batch.data_rows.append(model.NewProductBatchRow())
 | 
				
			||||||
 | 
					        batch.data_rows.append(model.NewProductBatchRow())
 | 
				
			||||||
 | 
					        batch.data_rows.append(model.NewProductBatchRow())
 | 
				
			||||||
 | 
					        self.assertEqual(len(batch.data_rows), 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # all rows should be effective by default
 | 
				
			||||||
 | 
					        result = self.handler.get_effective_rows(batch)
 | 
				
			||||||
 | 
					        self.assertEqual(len(result), 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # unless we mark one as "removed"
 | 
				
			||||||
 | 
					        batch.data_rows[1].removed = True
 | 
				
			||||||
 | 
					        result = self.handler.get_effective_rows(batch)
 | 
				
			||||||
 | 
					        self.assertEqual(len(result), 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # or if we delete one
 | 
				
			||||||
 | 
					        batch.data_rows.pop(-1)
 | 
				
			||||||
 | 
					        result = self.handler.get_effective_rows(batch)
 | 
				
			||||||
 | 
					        self.assertEqual(len(result), 1)
 | 
				
			||||||
| 
						 | 
					@ -28,6 +28,18 @@ class TestProductBatchHandler(TestCase):
 | 
				
			||||||
    def make_handler(self):
 | 
					    def make_handler(self):
 | 
				
			||||||
        return mod.VendorCatalogHandler(self.config)
 | 
					        return mod.VendorCatalogHandler(self.config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_allow_future(self):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # off by default
 | 
				
			||||||
 | 
					        result = self.handler.allow_future()
 | 
				
			||||||
 | 
					        self.assertFalse(result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # but can be enabled via config
 | 
				
			||||||
 | 
					        self.config.setdefault('rattail.batch', 'vendor_catalog.allow_future',
 | 
				
			||||||
 | 
					                               'true')
 | 
				
			||||||
 | 
					        result = self.handler.allow_future()
 | 
				
			||||||
 | 
					        self.assertTrue(result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_populate_from_file(self):
 | 
					    def test_populate_from_file(self):
 | 
				
			||||||
        engine = sa.create_engine('sqlite://')
 | 
					        engine = sa.create_engine('sqlite://')
 | 
				
			||||||
        model = self.config.get_model()
 | 
					        model = self.config.get_model()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue