diff --git a/pyproject.toml b/pyproject.toml index c625694..a00b12b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ [project.optional-dependencies] -db = ["SQLAlchemy<2", "alembic", "alembic-postgresql-enum", "passlib"] +db = ["SQLAlchemy", "alembic", "alembic-postgresql-enum", "passlib"] docs = ["Sphinx", "sphinxcontrib-programoutput", "enum-tools[sphinx]", "furo"] tests = ["pytest-cov", "tox"] diff --git a/src/wuttjamaican/db/handler.py b/src/wuttjamaican/db/handler.py index 7c745d8..849f954 100644 --- a/src/wuttjamaican/db/handler.py +++ b/src/wuttjamaican/db/handler.py @@ -34,6 +34,10 @@ class DatabaseHandler(GenericHandler): Base class and default implementation for the :term:`db handler`. """ + def get_dialect(self, bind): + """ """ + return bind.url.get_dialect().name + def next_counter_value(self, session, key): """ Return the next counter value for the given key. @@ -52,7 +56,7 @@ class DatabaseHandler(GenericHandler): :returns: Next value as integer. """ - dialect = session.bind.url.get_dialect().name + dialect = self.get_dialect(session.bind) # postgres uses "true" native sequence if dialect == 'postgresql': diff --git a/src/wuttjamaican/db/util.py b/src/wuttjamaican/db/util.py index bc06e22..ff4b69b 100644 --- a/src/wuttjamaican/db/util.py +++ b/src/wuttjamaican/db/util.py @@ -25,6 +25,8 @@ Database Utilities """ import uuid as _uuid +from importlib.metadata import version +from packaging.version import Version import sqlalchemy as sa from sqlalchemy import orm @@ -44,6 +46,11 @@ naming_convention = { } +SA2 = True +if Version(version('SQLAlchemy')) < Version('2'): # pragma: no cover + SA2 = False + + class ModelBase: """ """ diff --git a/tests/db/test_handler.py b/tests/db/test_handler.py index e28e813..bee05e5 100644 --- a/tests/db/test_handler.py +++ b/tests/db/test_handler.py @@ -51,8 +51,7 @@ else: # using sqlite backend. # using postgres as backend, should use "sequence" - with patch.object(self.session.bind.url, 'get_dialect') as get_dialect: - get_dialect.return_value.name = 'postgresql' + with patch.object(handler, 'get_dialect', return_value='postgresql'): with patch.object(self.session, 'execute') as execute: execute.return_value.scalar.return_value = 1 value = handler.next_counter_value(self.session, 'testing') diff --git a/tests/test_install.py b/tests/test_install.py index ba410b1..2bc2ddb 100644 --- a/tests/test_install.py +++ b/tests/test_install.py @@ -92,6 +92,8 @@ class TestInstallHandler(ConfigTestCase): except ImportError: pytest.skip("test is not relevant without sqlalchemy") + from wuttjamaican.db.util import SA2 + handler = self.make_handler() def prompt_generic(info, default=None, is_password=False): @@ -112,6 +114,8 @@ class TestInstallHandler(ConfigTestCase): self.assertRaises(RuntimeError, handler.get_dbinfo) sys.exit.assert_called_once_with(1) + seekrit = '***' if SA2 else 'seekrit' + # good dbinfo sys.exit.reset_mock() test_db_connection.return_value = None @@ -119,7 +123,7 @@ class TestInstallHandler(ConfigTestCase): self.assertFalse(sys.exit.called) rprint.assert_called_with("[bold green]good[/bold green]") self.assertEqual(str(dbinfo['dburl']), - 'postgresql+psycopg2://poser:seekrit@localhost:5432/poser') + f'postgresql+psycopg2://poser:{seekrit}@localhost:5432/poser') def test_make_db_url(self): try: @@ -127,13 +131,16 @@ class TestInstallHandler(ConfigTestCase): except ImportError: pytest.skip("test is not relevant without sqlalchemy") + from wuttjamaican.db.util import SA2 + handler = self.make_handler() + seekrit = '***' if SA2 else 'seekrit' url = handler.make_db_url('postgresql', 'localhost', '5432', 'poser', 'poser', 'seekrit') - self.assertEqual(str(url), 'postgresql+psycopg2://poser:seekrit@localhost:5432/poser') + self.assertEqual(str(url), f'postgresql+psycopg2://poser:{seekrit}@localhost:5432/poser') url = handler.make_db_url('mysql', 'localhost', '3306', 'poser', 'poser', 'seekrit') - self.assertEqual(str(url), 'mysql+mysqlconnector://poser:seekrit@localhost:3306/poser') + self.assertEqual(str(url), f'mysql+mysqlconnector://poser:{seekrit}@localhost:3306/poser') def test_test_db_connection(self): try: