diff --git a/src/wuttasync/cli/import_versions.py b/src/wuttasync/cli/import_versions.py index f1d0481..86da4c4 100644 --- a/src/wuttasync/cli/import_versions.py +++ b/src/wuttasync/cli/import_versions.py @@ -28,6 +28,7 @@ import sys import rich import typer +from typing_extensions import Annotated from wuttjamaican.cli import wutta_typer @@ -36,7 +37,14 @@ from .base import import_command, ImportCommandHandler @wutta_typer.command() @import_command -def import_versions(ctx: typer.Context, **kwargs): # pylint: disable=unused-argument +def import_versions( # pylint: disable=unused-argument + ctx: typer.Context, + comment: Annotated[ + str, + typer.Option("--comment", "-m", help="Comment to set on the transaction."), + ] = "import catch-up versions", + **kwargs, +): """ Import latest data to version tables, for Wutta DB """ diff --git a/src/wuttasync/importing/versions.py b/src/wuttasync/importing/versions.py index 53c25fa..b2fd062 100644 --- a/src/wuttasync/importing/versions.py +++ b/src/wuttasync/importing/versions.py @@ -92,6 +92,13 @@ class FromWuttaToVersions(FromWuttaHandler, ToWuttaHandler): See also :attr:`continuum_uow`. """ + continuum_comment = None + + def consume_kwargs(self, kwargs): + kwargs = super().consume_kwargs(kwargs) + self.continuum_comment = kwargs.pop("comment", None) + return kwargs + def begin_target_transaction(self): # pylint: disable=line-too-long """ @@ -106,6 +113,8 @@ class FromWuttaToVersions(FromWuttaHandler, ToWuttaHandler): :meth:`~sqlalchemy-continuum:sqlalchemy_continuum.unit_of_work.UnitOfWork.create_transaction()` and assigns that to :attr:`continuum_txn`. + It also sets the comment for the transaction, if applicable. + See also docs for parent method: :meth:`~wuttasync.importing.handlers.ToSqlalchemyHandler.begin_target_transaction()` """ @@ -116,8 +125,12 @@ class FromWuttaToVersions(FromWuttaHandler, ToWuttaHandler): self.continuum_uow = continuum.versioning_manager.unit_of_work( self.target_session ) + self.continuum_txn = self.continuum_uow.create_transaction(self.target_session) + if self.continuum_comment: + self.continuum_txn.meta = {"comment": self.continuum_comment} + def get_importer_kwargs(self, key, **kwargs): """ This modifies the new importer kwargs to add: diff --git a/tests/importing/test_versions.py b/tests/importing/test_versions.py index 2067f93..1988706 100644 --- a/tests/importing/test_versions.py +++ b/tests/importing/test_versions.py @@ -14,17 +14,40 @@ class TestFromWuttaToVersions(VersionTestCase): def make_handler(self, **kwargs): return mod.FromWuttaToVersions(self.config, **kwargs) + def test_consume_kwargs(self): + + # no comment by default + handler = self.make_handler() + kw = handler.consume_kwargs({}) + self.assertEqual(kw, {}) + self.assertIsNone(handler.continuum_comment) + + # but can provide one + handler = self.make_handler() + kw = handler.consume_kwargs({"comment": "yeehaw"}) + self.assertEqual(kw, {}) + self.assertEqual(handler.continuum_comment, "yeehaw") + def test_begin_target_transaction(self): model = self.app.model txncls = continuum.transaction_class(model.User) + # basic / defaults handler = self.make_handler() self.assertIsNone(handler.continuum_uow) self.assertIsNone(handler.continuum_txn) - handler.begin_target_transaction() self.assertIsInstance(handler.continuum_uow, continuum.UnitOfWork) self.assertIsInstance(handler.continuum_txn, txncls) + # nb. no comment + self.assertIsNone(handler.continuum_txn.meta.get("comment")) + + # with comment + handler = self.make_handler() + handler.continuum_comment = "yeehaw" + handler.begin_target_transaction() + self.assertIn("comment", handler.continuum_txn.meta) + self.assertEqual(handler.continuum_txn.meta["comment"], "yeehaw") def test_get_importer_kwargs(self): handler = self.make_handler() @@ -57,14 +80,6 @@ class TestFromWuttaToVersions(VersionTestCase): self.assertNotIn("Upgrade", importers) -class UserImporter(mod.FromWuttaToVersionBase): - - @property - def model_class(self): - model = self.app.model - return model.User - - class TestFromWuttaToVersionBase(VersionTestCase): def make_importer(self, model_class=None, **kwargs):