From 10875bf8807ada94928657650112d2e68a48a7a1 Mon Sep 17 00:00:00 2001 From: Lance Edgar Date: Fri, 27 Feb 2026 09:57:53 -0600 Subject: [PATCH] fix: add `joins` param for `model_transaction_query()` --- src/wutta_continuum/util.py | 76 +++++++++++++++++++++++++++++++++++-- tests/test_util.py | 25 +++++++++++- 2 files changed, 97 insertions(+), 4 deletions(-) diff --git a/src/wutta_continuum/util.py b/src/wutta_continuum/util.py index 4ca64ec..d213e2c 100644 --- a/src/wutta_continuum/util.py +++ b/src/wutta_continuum/util.py @@ -53,7 +53,7 @@ def render_operation_type(operation_type): return OPERATION_TYPES[operation_type] -def model_transaction_query(instance, session=None, model_class=None): +def model_transaction_query(instance, session=None, model_class=None, joins=None): """ Make a query capable of finding all SQLAlchemy-Continuum ``transaction`` records associated with the given model instance. @@ -66,8 +66,35 @@ def model_transaction_query(instance, session=None, model_class=None): :param model_class: Optional :term:`data model` class to query. If not specified, will be obtained from the ``instance``. + :param joins: Optional sequence of "join info tuples" - see + further explanation below. + :returns: SQLAlchemy query object. Note that it will *not* have an ``ORDER BY`` clause yet. + + The default logic looks for any "version" records for the given + instance, and returns the associated transactions. But sometimes + you need to look for more than one type of version record: + + If e.g. a core table provides common fields but a custom table + adds extension fields for a record, you will want to find + transactions involving *either* version table when showing a + record's total history. + + This is accomplished here via the ``joins`` param. If specified, + each item in the ``joins`` sequence must be a 3-tuple:: + + (related_class, related_attr, instance_attr) + + The meaning of those is as follows; assuming ``User`` is the main + instance model class: + + * ``related_class`` - model class which is "related" to the + instance model class, e.g. ``UserExtension`` + * ``related_attr`` - attribute name on the related class which + serves as foreign key to the instance, e.g. ``"user_uuid"`` + * ``instance_attr`` - attribute name on the main instance class + which serves as primary key, e.g. ``"uuid"`` """ if not session: session = orm.object_session(instance) @@ -77,9 +104,52 @@ def model_transaction_query(instance, session=None, model_class=None): txncls = continuum.transaction_class(model_class) vercls = continuum.version_class(model_class) - query = session.query(txncls).join( + # basic query is for the *transaction* table + query = session.query(txncls) + + # we'll do inner *or* outer join on main version table below + join_args = ( vercls, - sa.and_(vercls.uuid == instance.uuid, vercls.transaction_id == txncls.id), + sa.and_( + vercls.uuid == instance.uuid, + vercls.transaction_id == txncls.id, + ), ) + if joins: + + # we must *outer* join on main version table, since we will + # also be joining on other version tables + query = query.outerjoin(*join_args) + + # we'll collect "filter conditions" for use below... + conditions = [vercls.uuid != None] + + # add join/filter for each requested by caller + for child_class, foreign_attr, primary_attr in joins: + child_vercls = continuum.version_class(child_class) + foreign_attr = getattr(child_vercls, foreign_attr) + query = query.outerjoin( + child_vercls, + sa.and_( + child_vercls.transaction_id == txncls.id, + foreign_attr == getattr(instance, primary_attr), + ), + ) + + # and add the filter condition for use below... + conditions.append(foreign_attr != None) + + # at this point we have *outer* joined on *all* version tables + # involved, but that means basically "all transactions" will + # match! so we add explicit filter to make sure at least one + # of them is related for a transaction to match + query = query.filter(sa.or_(*conditions)) + + else: + + # no joins were specified, so we can just do *inner* join on + # the main version table and call it good + query = query.join(*join_args) + return query diff --git a/tests/test_util.py b/tests/test_util.py index 944c861..158e02b 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -24,7 +24,7 @@ class TestRenderOperationType(TestCase): class TestModelTransactionQuery(VersionTestCase): - def test_basic(self): + def test_inner_join(self): model = self.app.model user = model.User(username="fred") @@ -38,3 +38,26 @@ class TestModelTransactionQuery(VersionTestCase): UserVersion = continuum.version_class(model.User) version = self.session.query(UserVersion).one() self.assertIs(version.transaction, txn) + + def test_outer_joins(self): + model = self.app.model + + person = model.Person(full_name="Fred Flintstone") + self.session.add(person) + user = model.User(username="fred", person=person) + self.session.add(user) + self.session.commit() + + query = mod.model_transaction_query( + user, joins=[(model.Person, "uuid", "person_uuid")] + ) + self.assertEqual(query.count(), 1) + txn = query.one() + + vercls_user = continuum.version_class(model.User) + version = self.session.query(vercls_user).one() + self.assertIs(version.transaction, txn) + + vercls_person = continuum.version_class(model.Person) + version = self.session.query(vercls_person).one() + self.assertIs(version.transaction, txn)