diff --git a/src/wuttaweb/util.py b/src/wuttaweb/util.py index 0697f03..f634c00 100644 --- a/src/wuttaweb/util.py +++ b/src/wuttaweb/util.py @@ -32,6 +32,7 @@ import uuid as _uuid import warnings import sqlalchemy as sa +from sqlalchemy import orm import colander from webhelpers2.html import HTML, tags @@ -478,24 +479,36 @@ def render_csrf_token(request, name='_csrf'): return HTML.tag('div', tags.hidden(name, value=token, id=None), style='display:none;') -def get_model_fields(config, model_class=None): +def get_model_fields(config, model_class, include_fk=False): """ Convenience function to return a list of field names for the given - model class. + :term:`data model` class. This logic only supports SQLAlchemy mapped classes and will use that to determine the field listing if applicable. Otherwise this returns ``None``. - """ - if not model_class: - return + :param config: App :term:`config object`. + + :param model_class: Data model class. + + :param include_fk: Whether to include foreign key column names in + the result. They are excluded by default, since the + relationship names are also included and generally preferred. + + :returns: List of field names, or ``None`` if it could not be + determined. + """ try: mapper = sa.inspect(model_class) except sa.exc.NoInspectionAvailable: return - fields = [prop.key for prop in mapper.iterate_properties] + if include_fk: + fields = [prop.key for prop in mapper.iterate_properties] + else: + fields = [prop.key for prop in mapper.iterate_properties + if not prop_is_fk(mapper, prop)] # nb. we never want the continuum 'versions' prop app = config.get_app() @@ -505,6 +518,20 @@ def get_model_fields(config, model_class=None): return fields +def prop_is_fk(mapper, prop): + """ """ + if not isinstance(prop, orm.ColumnProperty): + return False + + prop_columns = [col.name for col in prop.columns] + for rel in mapper.relationships: + rel_columns = [col.name for col in rel.local_columns] + if rel_columns == prop_columns: + return True + + return False + + def make_json_safe(value, key=None, warn=True): """ Convert a Python value as needed, to ensure it is compatible with diff --git a/tests/test_util.py b/tests/test_util.py index 6946d65..b8c7ba7 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -11,6 +11,8 @@ from fanstatic import Library, Resource from pyramid import testing from wuttjamaican.conf import WuttaConfig +from wuttjamaican.testing import ConfigTestCase + from wuttaweb import util as mod @@ -463,14 +465,10 @@ class TestGetFormData(TestCase): self.assertEqual(data, {'foo2': 'baz'}) -class TestGetModelFields(TestCase): - - def setUp(self): - self.config = WuttaConfig() - self.app = self.config.get_app() +class TestGetModelFields(ConfigTestCase): def test_empty_model_class(self): - fields = mod.get_model_fields(self.config) + fields = mod.get_model_fields(self.config, None) self.assertIsNone(fields) def test_unknown_model_class(self): @@ -482,6 +480,19 @@ class TestGetModelFields(TestCase): fields = mod.get_model_fields(self.config, model.Setting) self.assertEqual(fields, ['name', 'value']) + def test_include_fk(self): + model = self.app.model + + # fk excluded by default + fields = mod.get_model_fields(self.config, model.User) + self.assertNotIn('person_uuid', fields) + self.assertIn('person', fields) + + # fk can be included + fields = mod.get_model_fields(self.config, model.User, include_fk=True) + self.assertIn('person_uuid', fields) + self.assertIn('person', fields) + def test_avoid_versions(self): model = self.app.model