From 650dc8ff4dc15ae6bcd240c1e3d5c6914684d13f Mon Sep 17 00:00:00 2001 From: Lance Edgar Date: Sun, 29 Jun 2025 19:38:29 -0500 Subject: [PATCH] feat: remove version cap for SQLAlchemy (allow 1.x or 2.x) hoping this does not break things terribly, but it needs to be done regardless so will just have to pick up pieces if so --- pyproject.toml | 2 +- src/wuttjamaican/db/handler.py | 6 +++++- src/wuttjamaican/db/util.py | 7 +++++++ tests/db/test_handler.py | 3 +-- tests/test_install.py | 13 ++++++++++--- 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c625694..a00b12b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ [project.optional-dependencies] -db = ["SQLAlchemy<2", "alembic", "alembic-postgresql-enum", "passlib"] +db = ["SQLAlchemy", "alembic", "alembic-postgresql-enum", "passlib"] docs = ["Sphinx", "sphinxcontrib-programoutput", "enum-tools[sphinx]", "furo"] tests = ["pytest-cov", "tox"] diff --git a/src/wuttjamaican/db/handler.py b/src/wuttjamaican/db/handler.py index 7c745d8..849f954 100644 --- a/src/wuttjamaican/db/handler.py +++ b/src/wuttjamaican/db/handler.py @@ -34,6 +34,10 @@ class DatabaseHandler(GenericHandler): Base class and default implementation for the :term:`db handler`. """ + def get_dialect(self, bind): + """ """ + return bind.url.get_dialect().name + def next_counter_value(self, session, key): """ Return the next counter value for the given key. @@ -52,7 +56,7 @@ class DatabaseHandler(GenericHandler): :returns: Next value as integer. """ - dialect = session.bind.url.get_dialect().name + dialect = self.get_dialect(session.bind) # postgres uses "true" native sequence if dialect == 'postgresql': diff --git a/src/wuttjamaican/db/util.py b/src/wuttjamaican/db/util.py index bc06e22..ff4b69b 100644 --- a/src/wuttjamaican/db/util.py +++ b/src/wuttjamaican/db/util.py @@ -25,6 +25,8 @@ Database Utilities """ import uuid as _uuid +from importlib.metadata import version +from packaging.version import Version import sqlalchemy as sa from sqlalchemy import orm @@ -44,6 +46,11 @@ naming_convention = { } +SA2 = True +if Version(version('SQLAlchemy')) < Version('2'): # pragma: no cover + SA2 = False + + class ModelBase: """ """ diff --git a/tests/db/test_handler.py b/tests/db/test_handler.py index e28e813..bee05e5 100644 --- a/tests/db/test_handler.py +++ b/tests/db/test_handler.py @@ -51,8 +51,7 @@ else: # using sqlite backend. # using postgres as backend, should use "sequence" - with patch.object(self.session.bind.url, 'get_dialect') as get_dialect: - get_dialect.return_value.name = 'postgresql' + with patch.object(handler, 'get_dialect', return_value='postgresql'): with patch.object(self.session, 'execute') as execute: execute.return_value.scalar.return_value = 1 value = handler.next_counter_value(self.session, 'testing') diff --git a/tests/test_install.py b/tests/test_install.py index ba410b1..2bc2ddb 100644 --- a/tests/test_install.py +++ b/tests/test_install.py @@ -92,6 +92,8 @@ class TestInstallHandler(ConfigTestCase): except ImportError: pytest.skip("test is not relevant without sqlalchemy") + from wuttjamaican.db.util import SA2 + handler = self.make_handler() def prompt_generic(info, default=None, is_password=False): @@ -112,6 +114,8 @@ class TestInstallHandler(ConfigTestCase): self.assertRaises(RuntimeError, handler.get_dbinfo) sys.exit.assert_called_once_with(1) + seekrit = '***' if SA2 else 'seekrit' + # good dbinfo sys.exit.reset_mock() test_db_connection.return_value = None @@ -119,7 +123,7 @@ class TestInstallHandler(ConfigTestCase): self.assertFalse(sys.exit.called) rprint.assert_called_with("[bold green]good[/bold green]") self.assertEqual(str(dbinfo['dburl']), - 'postgresql+psycopg2://poser:seekrit@localhost:5432/poser') + f'postgresql+psycopg2://poser:{seekrit}@localhost:5432/poser') def test_make_db_url(self): try: @@ -127,13 +131,16 @@ class TestInstallHandler(ConfigTestCase): except ImportError: pytest.skip("test is not relevant without sqlalchemy") + from wuttjamaican.db.util import SA2 + handler = self.make_handler() + seekrit = '***' if SA2 else 'seekrit' url = handler.make_db_url('postgresql', 'localhost', '5432', 'poser', 'poser', 'seekrit') - self.assertEqual(str(url), 'postgresql+psycopg2://poser:seekrit@localhost:5432/poser') + self.assertEqual(str(url), f'postgresql+psycopg2://poser:{seekrit}@localhost:5432/poser') url = handler.make_db_url('mysql', 'localhost', '3306', 'poser', 'poser', 'seekrit') - self.assertEqual(str(url), 'mysql+mysqlconnector://poser:seekrit@localhost:3306/poser') + self.assertEqual(str(url), f'mysql+mysqlconnector://poser:{seekrit}@localhost:3306/poser') def test_test_db_connection(self): try: