fix: add joins param for model_transaction_query()

This commit is contained in:
Lance Edgar 2026-02-27 09:57:53 -06:00
parent d95a230848
commit 10875bf880
2 changed files with 97 additions and 4 deletions

View file

@ -53,7 +53,7 @@ def render_operation_type(operation_type):
return OPERATION_TYPES[operation_type]
def model_transaction_query(instance, session=None, model_class=None):
def model_transaction_query(instance, session=None, model_class=None, joins=None):
"""
Make a query capable of finding all SQLAlchemy-Continuum
``transaction`` records associated with the given model instance.
@ -66,8 +66,35 @@ def model_transaction_query(instance, session=None, model_class=None):
:param model_class: Optional :term:`data model` class to query.
If not specified, will be obtained from the ``instance``.
:param joins: Optional sequence of "join info tuples" - see
further explanation below.
:returns: SQLAlchemy query object. Note that it will *not* have an
``ORDER BY`` clause yet.
The default logic looks for any "version" records for the given
instance, and returns the associated transactions. But sometimes
you need to look for more than one type of version record:
If e.g. a core table provides common fields but a custom table
adds extension fields for a record, you will want to find
transactions involving *either* version table when showing a
record's total history.
This is accomplished here via the ``joins`` param. If specified,
each item in the ``joins`` sequence must be a 3-tuple::
(related_class, related_attr, instance_attr)
The meaning of those is as follows; assuming ``User`` is the main
instance model class:
* ``related_class`` - model class which is "related" to the
instance model class, e.g. ``UserExtension``
* ``related_attr`` - attribute name on the related class which
serves as foreign key to the instance, e.g. ``"user_uuid"``
* ``instance_attr`` - attribute name on the main instance class
which serves as primary key, e.g. ``"uuid"``
"""
if not session:
session = orm.object_session(instance)
@ -77,9 +104,52 @@ def model_transaction_query(instance, session=None, model_class=None):
txncls = continuum.transaction_class(model_class)
vercls = continuum.version_class(model_class)
query = session.query(txncls).join(
# basic query is for the *transaction* table
query = session.query(txncls)
# we'll do inner *or* outer join on main version table below
join_args = (
vercls,
sa.and_(vercls.uuid == instance.uuid, vercls.transaction_id == txncls.id),
sa.and_(
vercls.uuid == instance.uuid,
vercls.transaction_id == txncls.id,
),
)
if joins:
# we must *outer* join on main version table, since we will
# also be joining on other version tables
query = query.outerjoin(*join_args)
# we'll collect "filter conditions" for use below...
conditions = [vercls.uuid != None]
# add join/filter for each requested by caller
for child_class, foreign_attr, primary_attr in joins:
child_vercls = continuum.version_class(child_class)
foreign_attr = getattr(child_vercls, foreign_attr)
query = query.outerjoin(
child_vercls,
sa.and_(
child_vercls.transaction_id == txncls.id,
foreign_attr == getattr(instance, primary_attr),
),
)
# and add the filter condition for use below...
conditions.append(foreign_attr != None)
# at this point we have *outer* joined on *all* version tables
# involved, but that means basically "all transactions" will
# match! so we add explicit filter to make sure at least one
# of them is related for a transaction to match
query = query.filter(sa.or_(*conditions))
else:
# no joins were specified, so we can just do *inner* join on
# the main version table and call it good
query = query.join(*join_args)
return query

View file

@ -24,7 +24,7 @@ class TestRenderOperationType(TestCase):
class TestModelTransactionQuery(VersionTestCase):
def test_basic(self):
def test_inner_join(self):
model = self.app.model
user = model.User(username="fred")
@ -38,3 +38,26 @@ class TestModelTransactionQuery(VersionTestCase):
UserVersion = continuum.version_class(model.User)
version = self.session.query(UserVersion).one()
self.assertIs(version.transaction, txn)
def test_outer_joins(self):
model = self.app.model
person = model.Person(full_name="Fred Flintstone")
self.session.add(person)
user = model.User(username="fred", person=person)
self.session.add(user)
self.session.commit()
query = mod.model_transaction_query(
user, joins=[(model.Person, "uuid", "person_uuid")]
)
self.assertEqual(query.count(), 1)
txn = query.one()
vercls_user = continuum.version_class(model.User)
version = self.session.query(vercls_user).one()
self.assertIs(version.transaction, txn)
vercls_person = continuum.version_class(model.Person)
version = self.session.query(vercls_person).one()
self.assertIs(version.transaction, txn)