diff --git a/src/wuttjamaican/util.py b/src/wuttjamaican/util.py index 78a755a..51bdb03 100644 --- a/src/wuttjamaican/util.py +++ b/src/wuttjamaican/util.py @@ -38,6 +38,46 @@ log = logging.getLogger(__name__) UNSPECIFIED = object() +def get_class_hierarchy(klass, topfirst=True): + """ + Returns a list of all classes in the inheritance chain for the + given class. + + For instance:: + + class A: + pass + + class B(A): + pass + + class C(B): + pass + + get_class_hierarchy(C) + # -> [A, B, C] + + :param klass: The reference class. The list of classes returned + will include this class and all its parents. + + :param topfirst: Whether the returned list should be sorted in a + "top first" way, e.g. A) grandparent, B) parent, C) child. + This is the default but pass ``False`` to get the reverse. + """ + hierarchy = [] + + def traverse(cls): + if cls is not object: + hierarchy.append(cls) + for parent in cls.__bases__: + traverse(parent) + + traverse(klass) + if topfirst: + hierarchy.reverse() + return hierarchy + + def load_entry_points(group, ignore_errors=False): """ Load a set of ``setuptools``-style entry points. diff --git a/tests/test_util.py b/tests/test_util.py index 921e416..0f2baf4 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -6,20 +6,41 @@ from unittest.mock import patch, MagicMock import pytest -from wuttjamaican import util +from wuttjamaican import util as mod + + +class A: pass +class B(A): pass +class C(B): pass + +class TestGetClassHierarchy(TestCase): + + def test_basic(self): + + classes = mod.get_class_hierarchy(A) + self.assertEqual(classes, [A]) + + classes = mod.get_class_hierarchy(B) + self.assertEqual(classes, [A, B]) + + classes = mod.get_class_hierarchy(C) + self.assertEqual(classes, [A, B, C]) + + classes = mod.get_class_hierarchy(C, topfirst=False) + self.assertEqual(classes, [C, B, A]) class TestLoadEntryPoints(TestCase): def test_empty(self): # empty set returned for unknown group - result = util.load_entry_points('this_should_never_exist!!!!!!') + result = mod.load_entry_points('this_should_never_exist!!!!!!') self.assertEqual(result, {}) def test_basic(self): # load some entry points which should "always" be present, # even in a testing environment. basic sanity check - result = util.load_entry_points('console_scripts', ignore_errors=True) + result = mod.load_entry_points('console_scripts', ignore_errors=True) self.assertTrue(len(result) >= 1) self.assertIn('pip', result) @@ -45,7 +66,7 @@ class TestLoadEntryPoints(TestCase): # load some entry points which should "always" be present, # even in a testing environment. basic sanity check - result = util.load_entry_points('console_scripts', ignore_errors=True) + result = mod.load_entry_points('console_scripts', ignore_errors=True) self.assertTrue(len(result) >= 1) self.assertIn('pytest', result) @@ -82,7 +103,7 @@ class TestLoadEntryPoints(TestCase): # load some entry points which should "always" be present, # even in a testing environment. basic sanity check - result = util.load_entry_points('console_scripts', ignore_errors=True) + result = mod.load_entry_points('console_scripts', ignore_errors=True) self.assertTrue(len(result) >= 1) self.assertIn('pytest', result) @@ -104,7 +125,7 @@ class TestLoadEntryPoints(TestCase): with patch.dict('sys.modules', **{'importlib': importlib}): # empty set returned if errors suppressed - result = util.load_entry_points('wuttatest.thingers', ignore_errors=True) + result = mod.load_entry_points('wuttatest.thingers', ignore_errors=True) self.assertEqual(result, {}) importlib.metadata.entry_points.assert_called_once_with() entry_points.select.assert_called_once_with(group='wuttatest.thingers') @@ -114,7 +135,7 @@ class TestLoadEntryPoints(TestCase): importlib.metadata.entry_points.reset_mock() entry_points.select.reset_mock() entry_point.load.reset_mock() - self.assertRaises(NotImplementedError, util.load_entry_points, 'wuttatest.thingers') + self.assertRaises(NotImplementedError, mod.load_entry_points, 'wuttatest.thingers') importlib.metadata.entry_points.assert_called_once_with() entry_points.select.assert_called_once_with(group='wuttatest.thingers') entry_point.load.assert_called_once_with() @@ -123,96 +144,96 @@ class TestLoadEntryPoints(TestCase): class TestLoadObject(TestCase): def test_missing_spec(self): - self.assertRaises(ValueError, util.load_object, None) + self.assertRaises(ValueError, mod.load_object, None) def test_basic(self): - result = util.load_object('unittest:TestCase') + result = mod.load_object('unittest:TestCase') self.assertIs(result, TestCase) class TestMakeUUID(TestCase): def test_basic(self): - uuid = util.make_uuid() + uuid = mod.make_uuid() self.assertEqual(len(uuid), 32) class TestParseBool(TestCase): def test_null(self): - self.assertIsNone(util.parse_bool(None)) + self.assertIsNone(mod.parse_bool(None)) def test_bool(self): - self.assertTrue(util.parse_bool(True)) - self.assertFalse(util.parse_bool(False)) + self.assertTrue(mod.parse_bool(True)) + self.assertFalse(mod.parse_bool(False)) def test_string_true(self): - self.assertTrue(util.parse_bool('true')) - self.assertTrue(util.parse_bool('yes')) - self.assertTrue(util.parse_bool('y')) - self.assertTrue(util.parse_bool('on')) - self.assertTrue(util.parse_bool('1')) + self.assertTrue(mod.parse_bool('true')) + self.assertTrue(mod.parse_bool('yes')) + self.assertTrue(mod.parse_bool('y')) + self.assertTrue(mod.parse_bool('on')) + self.assertTrue(mod.parse_bool('1')) def test_string_false(self): - self.assertFalse(util.parse_bool('false')) - self.assertFalse(util.parse_bool('no')) - self.assertFalse(util.parse_bool('n')) - self.assertFalse(util.parse_bool('off')) - self.assertFalse(util.parse_bool('0')) + self.assertFalse(mod.parse_bool('false')) + self.assertFalse(mod.parse_bool('no')) + self.assertFalse(mod.parse_bool('n')) + self.assertFalse(mod.parse_bool('off')) + self.assertFalse(mod.parse_bool('0')) # nb. assume false for unrecognized input - self.assertFalse(util.parse_bool('whatever-else')) + self.assertFalse(mod.parse_bool('whatever-else')) class TestParseList(TestCase): def test_null(self): - value = util.parse_list(None) + value = mod.parse_list(None) self.assertIsInstance(value, list) self.assertEqual(len(value), 0) def test_list_instance(self): mylist = [] - value = util.parse_list(mylist) + value = mod.parse_list(mylist) self.assertIs(value, mylist) def test_single_value(self): - value = util.parse_list('foo') + value = mod.parse_list('foo') self.assertEqual(len(value), 1) self.assertEqual(value[0], 'foo') def test_single_value_padded_by_spaces(self): - value = util.parse_list(' foo ') + value = mod.parse_list(' foo ') self.assertEqual(len(value), 1) self.assertEqual(value[0], 'foo') def test_slash_is_not_a_separator(self): - value = util.parse_list('/dev/null') + value = mod.parse_list('/dev/null') self.assertEqual(len(value), 1) self.assertEqual(value[0], '/dev/null') def test_multiple_values_separated_by_whitespace(self): - value = util.parse_list('foo bar baz') + value = mod.parse_list('foo bar baz') self.assertEqual(len(value), 3) self.assertEqual(value[0], 'foo') self.assertEqual(value[1], 'bar') self.assertEqual(value[2], 'baz') def test_multiple_values_separated_by_commas(self): - value = util.parse_list('foo,bar,baz') + value = mod.parse_list('foo,bar,baz') self.assertEqual(len(value), 3) self.assertEqual(value[0], 'foo') self.assertEqual(value[1], 'bar') self.assertEqual(value[2], 'baz') def test_multiple_values_separated_by_whitespace_and_commas(self): - value = util.parse_list(' foo, bar baz') + value = mod.parse_list(' foo, bar baz') self.assertEqual(len(value), 3) self.assertEqual(value[0], 'foo') self.assertEqual(value[1], 'bar') self.assertEqual(value[2], 'baz') def test_multiple_values_separated_by_whitespace_and_commas_with_some_quoting(self): - value = util.parse_list(""" + value = mod.parse_list(""" foo "C:\\some path\\with spaces\\and, a comma", baz @@ -223,7 +244,7 @@ class TestParseList(TestCase): self.assertEqual(value[2], 'baz') def test_multiple_values_separated_by_whitespace_and_commas_with_single_quotes(self): - value = util.parse_list(""" + value = mod.parse_list(""" foo 'C:\\some path\\with spaces\\and, a comma', baz @@ -237,5 +258,5 @@ class TestParseList(TestCase): class TestMakeTitle(TestCase): def test_basic(self): - text = util.make_title('foo_bar') + text = mod.make_title('foo_bar') self.assertEqual(text, "Foo Bar")