diff --git a/src/wuttjamaican/db/model/base.py b/src/wuttjamaican/db/model/base.py index 4a82f5a..81c41ab 100644 --- a/src/wuttjamaican/db/model/base.py +++ b/src/wuttjamaican/db/model/base.py @@ -46,7 +46,25 @@ naming_convention = { metadata = sa.MetaData(naming_convention=naming_convention) -Base = orm.declarative_base(metadata=metadata) + +class ModelBase: + """ """ + + def __iter__(self): + # nb. we override this to allow for `dict(self)` + state = sa.inspect(self) + fields = [attr.key for attr in state.attrs] + return iter([(field, getattr(self, field)) + for field in fields]) + + def __getitem__(self, key): + # nb. we override this to allow for `x = self['field']` + state = sa.inspect(self) + if hasattr(state.attrs, key): + return getattr(self, key) + + +Base = orm.declarative_base(metadata=metadata, cls=ModelBase) def uuid_column(*args, **kwargs): diff --git a/tests/db/model/test_base.py b/tests/db/model/test_base.py index c77dfa2..1d55436 100644 --- a/tests/db/model/test_base.py +++ b/tests/db/model/test_base.py @@ -10,6 +10,15 @@ except ImportError: pass else: + class TestModelBase(TestCase): + + def test_dict_behavior(self): + setting = model.Setting() + self.assertEqual(list(iter(setting)), [('name', None), ('value', None)]) + self.assertIsNone(setting['name']) + setting.name = 'foo' + self.assertEqual(setting['name'], 'foo') + class TestUUIDColumn(TestCase): def test_basic(self):