3
0
Fork 0

fix: format all code with black

and from now on should not deviate from that...
This commit is contained in:
Lance Edgar 2025-08-30 21:25:44 -05:00
parent 49f9a0228b
commit a6bb538ce9
59 changed files with 2762 additions and 2131 deletions

View file

@ -8,43 +8,46 @@
from importlib.metadata import version as get_version from importlib.metadata import version as get_version
project = 'WuttJamaican' project = "WuttJamaican"
copyright = '2023-2024, Lance Edgar' copyright = "2023-2024, Lance Edgar"
author = 'Lance Edgar' author = "Lance Edgar"
release = get_version('WuttJamaican') release = get_version("WuttJamaican")
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [ extensions = [
'sphinx.ext.autodoc', "sphinx.ext.autodoc",
'sphinx.ext.intersphinx', "sphinx.ext.intersphinx",
'sphinxcontrib.programoutput', "sphinxcontrib.programoutput",
'sphinx.ext.viewcode', "sphinx.ext.viewcode",
'sphinx.ext.todo', "sphinx.ext.todo",
'enum_tools.autoenum', "enum_tools.autoenum",
] ]
templates_path = ['_templates'] templates_path = ["_templates"]
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
intersphinx_mapping = { intersphinx_mapping = {
'alembic': ('https://alembic.sqlalchemy.org/en/latest/', None), "alembic": ("https://alembic.sqlalchemy.org/en/latest/", None),
'humanize': ('https://humanize.readthedocs.io/en/stable/', None), "humanize": ("https://humanize.readthedocs.io/en/stable/", None),
'mako': ('https://docs.makotemplates.org/en/latest/', None), "mako": ("https://docs.makotemplates.org/en/latest/", None),
'packaging': ('https://packaging.python.org/en/latest/', None), "packaging": ("https://packaging.python.org/en/latest/", None),
'python': ('https://docs.python.org/3/', None), "python": ("https://docs.python.org/3/", None),
'python-configuration': ('https://python-configuration.readthedocs.io/en/latest/', None), "python-configuration": (
'rattail': ('https://docs.wuttaproject.org/rattail/', None), "https://python-configuration.readthedocs.io/en/latest/",
'rattail-manual': ('https://docs.wuttaproject.org/rattail-manual/', None), None,
'rich': ('https://rich.readthedocs.io/en/latest/', None), ),
'sqlalchemy': ('http://docs.sqlalchemy.org/en/latest/', None), "rattail": ("https://docs.wuttaproject.org/rattail/", None),
'wutta-continuum': ('https://docs.wuttaproject.org/wutta-continuum/', None), "rattail-manual": ("https://docs.wuttaproject.org/rattail-manual/", None),
"rich": ("https://rich.readthedocs.io/en/latest/", None),
"sqlalchemy": ("http://docs.sqlalchemy.org/en/latest/", None),
"wutta-continuum": ("https://docs.wuttaproject.org/wutta-continuum/", None),
} }
# -- Options for HTML output ------------------------------------------------- # -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'furo' html_theme = "furo"
html_static_path = ['_static'] html_static_path = ["_static"]

View file

@ -41,7 +41,7 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
db = ["SQLAlchemy", "alembic", "alembic-postgresql-enum", "passlib"] db = ["SQLAlchemy", "alembic", "alembic-postgresql-enum", "passlib"]
docs = ["Sphinx", "sphinxcontrib-programoutput", "enum-tools[sphinx]", "furo"] docs = ["Sphinx", "sphinxcontrib-programoutput", "enum-tools[sphinx]", "furo"]
tests = ["pylint", "pytest-cov", "tox"] tests = ["black", "pylint", "pytest", "pytest-cov", "tox"]
[project.scripts] [project.scripts]

View file

@ -6,4 +6,4 @@ Package Version
from importlib.metadata import version from importlib.metadata import version
__version__ = version('WuttJamaican') __version__ = version("WuttJamaican")

View file

@ -33,15 +33,23 @@ import warnings
import humanize import humanize
from wuttjamaican.util import (load_entry_points, load_object, from wuttjamaican.util import (
make_title, make_full_name, make_uuid, make_true_uuid, load_entry_points,
progress_loop, resource_path, simple_error) load_object,
make_title,
make_full_name,
make_uuid,
make_true_uuid,
progress_loop,
resource_path,
simple_error,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class AppHandler: # pylint: disable=too-many-public-methods class AppHandler: # pylint: disable=too-many-public-methods
""" """
Base class and default implementation for top-level :term:`app Base class and default implementation for top-level :term:`app
handler`. handler`.
@ -84,16 +92,17 @@ class AppHandler: # pylint: disable=too-many-public-methods
Dictionary of :class:`AppProvider` instances, as returned by Dictionary of :class:`AppProvider` instances, as returned by
:meth:`get_all_providers()`. :meth:`get_all_providers()`.
""" """
default_app_title = "WuttJamaican" default_app_title = "WuttJamaican"
default_model_spec = 'wuttjamaican.db.model' default_model_spec = "wuttjamaican.db.model"
default_enum_spec = 'wuttjamaican.enum' default_enum_spec = "wuttjamaican.enum"
default_auth_handler_spec = 'wuttjamaican.auth:AuthHandler' default_auth_handler_spec = "wuttjamaican.auth:AuthHandler"
default_db_handler_spec = 'wuttjamaican.db.handler:DatabaseHandler' default_db_handler_spec = "wuttjamaican.db.handler:DatabaseHandler"
default_email_handler_spec = 'wuttjamaican.email:EmailHandler' default_email_handler_spec = "wuttjamaican.email:EmailHandler"
default_install_handler_spec = 'wuttjamaican.install:InstallHandler' default_install_handler_spec = "wuttjamaican.install:InstallHandler"
default_people_handler_spec = 'wuttjamaican.people:PeopleHandler' default_people_handler_spec = "wuttjamaican.people:PeopleHandler"
default_problem_handler_spec = 'wuttjamaican.problems:ProblemHandler' default_problem_handler_spec = "wuttjamaican.problems:ProblemHandler"
default_report_handler_spec = 'wuttjamaican.reports:ReportHandler' default_report_handler_spec = "wuttjamaican.reports:ReportHandler"
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
@ -126,13 +135,13 @@ class AppHandler: # pylint: disable=too-many-public-methods
providers. providers.
""" """
if name == 'model': if name == "model":
return self.get_model() return self.get_model()
if name == 'enum': if name == "enum":
return self.get_enum() return self.get_enum()
if name == 'providers': if name == "providers":
self.providers = self.get_all_providers() self.providers = self.get_all_providers()
return self.providers return self.providers
@ -164,7 +173,7 @@ class AppHandler: # pylint: disable=too-many-public-methods
""" """
# nb. must use 'wutta' and not self.appname prefix here, or # nb. must use 'wutta' and not self.appname prefix here, or
# else we can't find all providers with custom appname # else we can't find all providers with custom appname
providers = load_entry_points('wutta.app.providers') providers = load_entry_points("wutta.app.providers")
for key in list(providers): for key in list(providers):
providers[key] = providers[key](self.config) providers[key] = providers[key](self.config)
return providers return providers
@ -178,8 +187,9 @@ class AppHandler: # pylint: disable=too-many-public-methods
:returns: Title for the app. :returns: Title for the app.
""" """
return self.config.get(f'{self.appname}.app_title', return self.config.get(
default=default or self.default_app_title) f"{self.appname}.app_title", default=default or self.default_app_title
)
def get_node_title(self, default=None): def get_node_title(self, default=None):
""" """
@ -193,7 +203,7 @@ class AppHandler: # pylint: disable=too-many-public-methods
:returns: Title for the local app node. :returns: Title for the local app node.
""" """
title = self.config.get(f'{self.appname}.node_title') title = self.config.get(f"{self.appname}.node_title")
if title: if title:
return title return title
return self.get_title(default=default) return self.get_title(default=default)
@ -217,8 +227,9 @@ class AppHandler: # pylint: disable=too-many-public-methods
[wutta] [wutta]
node_type = warehouse node_type = warehouse
""" """
return self.config.get(f'{self.appname}.node_type', default=default, return self.config.get(
usedb=False) f"{self.appname}.node_type", default=default, usedb=False
)
def get_distribution(self, obj=None): def get_distribution(self, obj=None):
""" """
@ -259,20 +270,20 @@ class AppHandler: # pylint: disable=too-many-public-methods
app_dist = My-Poser-Dist app_dist = My-Poser-Dist
""" """
if obj is None: if obj is None:
dist = self.config.get(f'{self.appname}.app_dist') dist = self.config.get(f"{self.appname}.app_dist")
if dist: if dist:
return dist return dist
# TODO: do we need a config setting for app_package ? # TODO: do we need a config setting for app_package ?
#modpath = self.config.get(f'{self.appname}.app_package') # modpath = self.config.get(f'{self.appname}.app_package')
modpath = None modpath = None
if not modpath: if not modpath:
modpath = type(obj if obj is not None else self).__module__ modpath = type(obj if obj is not None else self).__module__
pkgname = modpath.split('.')[0] pkgname = modpath.split(".")[0]
try: try:
from importlib.metadata import packages_distributions from importlib.metadata import packages_distributions
except ImportError: # python < 3.10 except ImportError: # python < 3.10
from importlib_metadata import packages_distributions from importlib_metadata import packages_distributions
pkgmap = packages_distributions() pkgmap = packages_distributions()
@ -281,7 +292,7 @@ class AppHandler: # pylint: disable=too-many-public-methods
return dist return dist
# fall back to configured dist, if obj lookup failed # fall back to configured dist, if obj lookup failed
return self.config.get(f'{self.appname}.app_dist') return self.config.get(f"{self.appname}.app_dist")
def get_version(self, dist=None, obj=None): def get_version(self, dist=None, obj=None):
""" """
@ -320,10 +331,12 @@ class AppHandler: # pylint: disable=too-many-public-methods
config.setdefault('wutta.model_spec', 'poser.db.model') config.setdefault('wutta.model_spec', 'poser.db.model')
""" """
if 'model' not in self.__dict__: if "model" not in self.__dict__:
spec = self.config.get(f'{self.appname}.model_spec', spec = self.config.get(
usedb=False, f"{self.appname}.model_spec",
default=self.default_model_spec) usedb=False,
default=self.default_model_spec,
)
self.model = importlib.import_module(spec) self.model = importlib.import_module(spec)
return self.model return self.model
@ -344,10 +357,10 @@ class AppHandler: # pylint: disable=too-many-public-methods
config.setdefault('wutta.enum_spec', 'poser.enum') config.setdefault('wutta.enum_spec', 'poser.enum')
""" """
if 'enum' not in self.__dict__: if "enum" not in self.__dict__:
spec = self.config.get(f'{self.appname}.enum_spec', spec = self.config.get(
usedb=False, f"{self.appname}.enum_spec", usedb=False, default=self.default_enum_spec
default=self.default_enum_spec) )
self.enum = importlib.import_module(spec) self.enum = importlib.import_module(spec)
return self.enum return self.enum
@ -389,17 +402,17 @@ class AppHandler: # pylint: disable=too-many-public-methods
app.get_appdir('data') # => /srv/envs/poser/app/data app.get_appdir('data') # => /srv/envs/poser/app/data
""" """
configured_only = kwargs.pop('configured_only', False) configured_only = kwargs.pop("configured_only", False)
create = kwargs.pop('create', False) create = kwargs.pop("create", False)
# maybe specify default path # maybe specify default path
if not configured_only: if not configured_only:
path = os.path.join(sys.prefix, 'app') path = os.path.join(sys.prefix, "app")
kwargs.setdefault('default', path) kwargs.setdefault("default", path)
# get configured path # get configured path
kwargs.setdefault('usedb', False) kwargs.setdefault("usedb", False)
path = self.config.get(f'{self.appname}.appdir', **kwargs) path = self.config.get(f"{self.appname}.appdir", **kwargs)
# add any subpath info # add any subpath info
if path and args: if path and args:
@ -436,7 +449,7 @@ class AppHandler: # pylint: disable=too-many-public-methods
os.makedirs(appdir) os.makedirs(appdir)
if not subfolders: if not subfolders:
subfolders = ['cache', 'data', 'log', 'work'] subfolders = ["cache", "data", "log", "work"]
for name in subfolders: for name in subfolders:
path = os.path.join(appdir, name) path = os.path.join(appdir, name)
@ -444,10 +457,10 @@ class AppHandler: # pylint: disable=too-many-public-methods
os.mkdir(path) os.mkdir(path)
def render_mako_template( def render_mako_template(
self, self,
template, template,
context, context,
output_path=None, output_path=None,
): ):
""" """
Convenience method to render a Mako template. Convenience method to render a Mako template.
@ -464,7 +477,7 @@ class AppHandler: # pylint: disable=too-many-public-methods
""" """
output = template.render(**context) output = template.render(**context)
if output_path: if output_path:
with open(output_path, 'wt', encoding='utf_8') as f: with open(output_path, "wt", encoding="utf_8") as f:
f.write(output) f.write(output)
return output return output
@ -586,12 +599,12 @@ class AppHandler: # pylint: disable=too-many-public-methods
""" """
from .db import short_session from .db import short_session
if 'factory' not in kwargs and 'config' not in kwargs: if "factory" not in kwargs and "config" not in kwargs:
kwargs['factory'] = self.make_session kwargs["factory"] = self.make_session
return short_session(**kwargs) return short_session(**kwargs)
def get_setting(self, session, name, **kwargs): # pylint: disable=unused-argument def get_setting(self, session, name, **kwargs): # pylint: disable=unused-argument
""" """
Get a :term:`config setting` value from the DB. Get a :term:`config setting` value from the DB.
@ -617,13 +630,8 @@ class AppHandler: # pylint: disable=too-many-public-methods
return get_setting(session, name) return get_setting(session, name)
def save_setting( def save_setting(
self, self, session, name, value, force_create=False, **kwargs
session, ): # pylint: disable=unused-argument
name,
value,
force_create=False,
**kwargs
): # pylint: disable=unused-argument
""" """
Save a :term:`config setting` value to the DB. Save a :term:`config setting` value to the DB.
@ -664,7 +672,9 @@ class AppHandler: # pylint: disable=too-many-public-methods
# set value # set value
setting.value = value setting.value = value
def delete_setting(self, session, name, **kwargs): # pylint: disable=unused-argument def delete_setting(
self, session, name, **kwargs
): # pylint: disable=unused-argument
""" """
Delete a :term:`config setting` from the DB. Delete a :term:`config setting` from the DB.
@ -692,7 +702,7 @@ class AppHandler: # pylint: disable=too-many-public-methods
:doc:`wutta-continuum:narr/install`. :doc:`wutta-continuum:narr/install`.
""" """
for provider in self.providers.values(): for provider in self.providers.values():
if hasattr(provider, 'continuum_is_enabled'): if hasattr(provider, "continuum_is_enabled"):
return provider.continuum_is_enabled() return provider.continuum_is_enabled()
return False return False
@ -710,7 +720,7 @@ class AppHandler: # pylint: disable=too-many-public-methods
:returns: Display string for the value. :returns: Display string for the value.
""" """
if value is None: if value is None:
return '' return ""
return "Yes" if value else "No" return "Yes" if value else "No"
@ -727,7 +737,7 @@ class AppHandler: # pylint: disable=too-many-public-methods
:returns: Display string for the value. :returns: Display string for the value.
""" """
if value is None: if value is None:
return '' return ""
if value < 0: if value < 0:
fmt = f"(${{:0,.{scale}f}})" fmt = f"(${{:0,.{scale}f}})"
@ -736,13 +746,13 @@ class AppHandler: # pylint: disable=too-many-public-methods
fmt = f"${{:0,.{scale}f}}" fmt = f"${{:0,.{scale}f}}"
return fmt.format(value) return fmt.format(value)
display_format_date = '%Y-%m-%d' display_format_date = "%Y-%m-%d"
""" """
Format string to use when displaying :class:`python:datetime.date` Format string to use when displaying :class:`python:datetime.date`
objects. See also :meth:`render_date()`. objects. See also :meth:`render_date()`.
""" """
display_format_datetime = '%Y-%m-%d %H:%M%z' display_format_datetime = "%Y-%m-%d %H:%M%z"
""" """
Format string to use when displaying Format string to use when displaying
:class:`python:datetime.datetime` objects. See also :class:`python:datetime.datetime` objects. See also
@ -800,9 +810,9 @@ class AppHandler: # pylint: disable=too-many-public-methods
""" """
if value is None: if value is None:
return "" return ""
fmt = f'{{:0.{decimals}f}} %' fmt = f"{{:0.{decimals}f}} %"
if value < 0: if value < 0:
return f'({fmt.format(-value)})' return f"({fmt.format(-value)})"
return fmt.format(value) return fmt.format(value)
def render_quantity(self, value, empty_zero=False): def render_quantity(self, value, empty_zero=False):
@ -819,13 +829,13 @@ class AppHandler: # pylint: disable=too-many-public-methods
:returns: Display string for the quantity. :returns: Display string for the quantity.
""" """
if value is None: if value is None:
return '' return ""
if int(value) == value: if int(value) == value:
value = int(value) value = int(value)
if empty_zero and value == 0: if empty_zero and value == 0:
return '' return ""
return str(value) return str(value)
return str(value).rstrip('0') return str(value).rstrip("0")
def render_time_ago(self, value): def render_time_ago(self, value):
""" """
@ -852,12 +862,13 @@ class AppHandler: # pylint: disable=too-many-public-methods
:rtype: :class:`~wuttjamaican.auth.AuthHandler` :rtype: :class:`~wuttjamaican.auth.AuthHandler`
""" """
if 'auth' not in self.handlers: if "auth" not in self.handlers:
spec = self.config.get(f'{self.appname}.auth.handler', spec = self.config.get(
default=self.default_auth_handler_spec) f"{self.appname}.auth.handler", default=self.default_auth_handler_spec
)
factory = self.load_object(spec) factory = self.load_object(spec)
self.handlers['auth'] = factory(self.config, **kwargs) self.handlers["auth"] = factory(self.config, **kwargs)
return self.handlers['auth'] return self.handlers["auth"]
def get_batch_handler(self, key, default=None, **kwargs): def get_batch_handler(self, key, default=None, **kwargs):
""" """
@ -872,10 +883,11 @@ class AppHandler: # pylint: disable=too-many-public-methods
for the requested type. If no spec can be determined, a for the requested type. If no spec can be determined, a
``KeyError`` is raised. ``KeyError`` is raised.
""" """
spec = self.config.get(f'{self.appname}.batch.{key}.handler.spec', spec = self.config.get(
default=default) f"{self.appname}.batch.{key}.handler.spec", default=default
)
if not spec: if not spec:
spec = self.config.get(f'{self.appname}.batch.{key}.handler.default_spec') spec = self.config.get(f"{self.appname}.batch.{key}.handler.default_spec")
if not spec: if not spec:
raise KeyError(f"handler spec not found for batch key: {key}") raise KeyError(f"handler spec not found for batch key: {key}")
factory = self.load_object(spec) factory = self.load_object(spec)
@ -924,13 +936,15 @@ class AppHandler: # pylint: disable=too-many-public-methods
handlers.extend(default) handlers.extend(default)
# configured default, if applicable # configured default, if applicable
default = self.config.get(f'{self.config.appname}.batch.{key}.handler.default_spec') default = self.config.get(
f"{self.config.appname}.batch.{key}.handler.default_spec"
)
if default and default not in handlers: if default and default not in handlers:
handlers.append(default) handlers.append(default)
# registered via entry points # registered via entry points
registered = [] registered = []
for handler in load_entry_points(f'{self.appname}.batch.{key}').values(): for handler in load_entry_points(f"{self.appname}.batch.{key}").values():
spec = handler.get_spec() spec = handler.get_spec()
if spec not in handlers: if spec not in handlers:
registered.append(spec) registered.append(spec)
@ -946,12 +960,13 @@ class AppHandler: # pylint: disable=too-many-public-methods
:rtype: :class:`~wuttjamaican.db.handler.DatabaseHandler` :rtype: :class:`~wuttjamaican.db.handler.DatabaseHandler`
""" """
if 'db' not in self.handlers: if "db" not in self.handlers:
spec = self.config.get(f'{self.appname}.db.handler', spec = self.config.get(
default=self.default_db_handler_spec) f"{self.appname}.db.handler", default=self.default_db_handler_spec
)
factory = self.load_object(spec) factory = self.load_object(spec)
self.handlers['db'] = factory(self.config, **kwargs) self.handlers["db"] = factory(self.config, **kwargs)
return self.handlers['db'] return self.handlers["db"]
def get_email_handler(self, **kwargs): def get_email_handler(self, **kwargs):
""" """
@ -961,12 +976,13 @@ class AppHandler: # pylint: disable=too-many-public-methods
:rtype: :class:`~wuttjamaican.email.EmailHandler` :rtype: :class:`~wuttjamaican.email.EmailHandler`
""" """
if 'email' not in self.handlers: if "email" not in self.handlers:
spec = self.config.get(f'{self.appname}.email.handler', spec = self.config.get(
default=self.default_email_handler_spec) f"{self.appname}.email.handler", default=self.default_email_handler_spec
)
factory = self.load_object(spec) factory = self.load_object(spec)
self.handlers['email'] = factory(self.config, **kwargs) self.handlers["email"] = factory(self.config, **kwargs)
return self.handlers['email'] return self.handlers["email"]
def get_install_handler(self, **kwargs): def get_install_handler(self, **kwargs):
""" """
@ -974,12 +990,14 @@ class AppHandler: # pylint: disable=too-many-public-methods
:rtype: :class:`~wuttjamaican.install.handler.InstallHandler` :rtype: :class:`~wuttjamaican.install.handler.InstallHandler`
""" """
if 'install' not in self.handlers: if "install" not in self.handlers:
spec = self.config.get(f'{self.appname}.install.handler', spec = self.config.get(
default=self.default_install_handler_spec) f"{self.appname}.install.handler",
default=self.default_install_handler_spec,
)
factory = self.load_object(spec) factory = self.load_object(spec)
self.handlers['install'] = factory(self.config, **kwargs) self.handlers["install"] = factory(self.config, **kwargs)
return self.handlers['install'] return self.handlers["install"]
def get_people_handler(self, **kwargs): def get_people_handler(self, **kwargs):
""" """
@ -987,12 +1005,14 @@ class AppHandler: # pylint: disable=too-many-public-methods
:rtype: :class:`~wuttjamaican.people.PeopleHandler` :rtype: :class:`~wuttjamaican.people.PeopleHandler`
""" """
if 'people' not in self.handlers: if "people" not in self.handlers:
spec = self.config.get(f'{self.appname}.people.handler', spec = self.config.get(
default=self.default_people_handler_spec) f"{self.appname}.people.handler",
default=self.default_people_handler_spec,
)
factory = self.load_object(spec) factory = self.load_object(spec)
self.handlers['people'] = factory(self.config, **kwargs) self.handlers["people"] = factory(self.config, **kwargs)
return self.handlers['people'] return self.handlers["people"]
def get_problem_handler(self, **kwargs): def get_problem_handler(self, **kwargs):
""" """
@ -1000,13 +1020,15 @@ class AppHandler: # pylint: disable=too-many-public-methods
:rtype: :class:`~wuttjamaican.problems.ProblemHandler` :rtype: :class:`~wuttjamaican.problems.ProblemHandler`
""" """
if 'problems' not in self.handlers: if "problems" not in self.handlers:
spec = self.config.get(f'{self.appname}.problems.handler', spec = self.config.get(
default=self.default_problem_handler_spec) f"{self.appname}.problems.handler",
default=self.default_problem_handler_spec,
)
log.debug("problem_handler spec is: %s", spec) log.debug("problem_handler spec is: %s", spec)
factory = self.load_object(spec) factory = self.load_object(spec)
self.handlers['problems'] = factory(self.config, **kwargs) self.handlers["problems"] = factory(self.config, **kwargs)
return self.handlers['problems'] return self.handlers["problems"]
def get_report_handler(self, **kwargs): def get_report_handler(self, **kwargs):
""" """
@ -1014,12 +1036,14 @@ class AppHandler: # pylint: disable=too-many-public-methods
:rtype: :class:`~wuttjamaican.reports.ReportHandler` :rtype: :class:`~wuttjamaican.reports.ReportHandler`
""" """
if 'reports' not in self.handlers: if "reports" not in self.handlers:
spec = self.config.get(f'{self.appname}.reports.handler_spec', spec = self.config.get(
default=self.default_report_handler_spec) f"{self.appname}.reports.handler_spec",
default=self.default_report_handler_spec,
)
factory = self.load_object(spec) factory = self.load_object(spec)
self.handlers['reports'] = factory(self.config, **kwargs) self.handlers["reports"] = factory(self.config, **kwargs)
return self.handlers['reports'] return self.handlers["reports"]
############################## ##############################
# convenience delegators # convenience delegators
@ -1046,7 +1070,7 @@ class AppHandler: # pylint: disable=too-many-public-methods
self.get_email_handler().send_email(*args, **kwargs) self.get_email_handler().send_email(*args, **kwargs)
class AppProvider: # pylint: disable=too-few-public-methods class AppProvider: # pylint: disable=too-few-public-methods
""" """
Base class for :term:`app providers<app provider>`. Base class for :term:`app providers<app provider>`.
@ -1092,9 +1116,12 @@ class AppProvider: # pylint: disable=too-few-public-methods
def __init__(self, config): def __init__(self, config):
if isinstance(config, AppHandler): if isinstance(config, AppHandler):
warnings.warn("passing app handler to app provider is deprecated; " warnings.warn(
"must pass config object instead", "passing app handler to app provider is deprecated; "
DeprecationWarning, stacklevel=2) "must pass config object instead",
DeprecationWarning,
stacklevel=2,
)
config = config.config config = config.config
self.config = config self.config = config
@ -1140,7 +1167,7 @@ class GenericHandler:
""" """
Returns the class :term:`spec` string for the handler. Returns the class :term:`spec` string for the handler.
""" """
return f'{cls.__module__}:{cls.__name__}' return f"{cls.__module__}:{cls.__name__}"
def get_provider_modules(self, module_type): def get_provider_modules(self, module_type):
""" """
@ -1163,7 +1190,7 @@ class GenericHandler:
if module_type not in self.modules: if module_type not in self.modules:
self.modules[module_type] = [] self.modules[module_type] = []
for provider in self.app.providers.values(): for provider in self.app.providers.values():
name = f'{module_type}_modules' name = f"{module_type}_modules"
if hasattr(provider, name): if hasattr(provider, name):
modules = getattr(provider, name) modules = getattr(provider, name)
if modules: if modules:

View file

@ -35,14 +35,13 @@ from wuttjamaican.app import GenericHandler
# nb. this only works if passlib is installed (part of 'db' extra) # nb. this only works if passlib is installed (part of 'db' extra)
try: try:
from passlib.context import CryptContext from passlib.context import CryptContext
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
pass pass
else: else:
password_context = CryptContext(schemes=['bcrypt']) password_context = CryptContext(schemes=["bcrypt"])
class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
""" """
Base class and default implementation for the :term:`auth Base class and default implementation for the :term:`auth
handler`. handler`.
@ -114,9 +113,11 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
model = self.app.model model = self.app.model
try: try:
token = session.query(model.UserAPIToken)\ token = (
.filter(model.UserAPIToken.token_string == token)\ session.query(model.UserAPIToken)
.one() .filter(model.UserAPIToken.token_string == token)
.one()
)
except orm.exc.NoResultFound: except orm.exc.NoResultFound:
pass pass
else: else:
@ -168,7 +169,7 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
if role: if role:
return role return role
else: # assuming it is a string else: # assuming it is a string
# try to match on Role.uuid # try to match on Role.uuid
try: try:
@ -179,15 +180,12 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
pass pass
# try to match on Role.name # try to match on Role.name
role = session.query(model.Role)\ role = session.query(model.Role).filter_by(name=key).first()
.filter_by(name=key)\
.first()
if role: if role:
return role return role
# try settings; if value then recurse # try settings; if value then recurse
key = self.config.get(f'{self.appname}.role.{key}', key = self.config.get(f"{self.appname}.role.{key}", session=session)
session=session)
if key: if key:
return self.get_role(session, key) return self.get_role(session, key)
return None return None
@ -247,9 +245,9 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
pass pass
# try to match on User.username # try to match on User.username
user = session.query(model.User)\ user = (
.filter(model.User.username == obj)\ session.query(model.User).filter(model.User.username == obj).first()
.first() )
if user: if user:
return user return user
@ -297,8 +295,8 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
""" """
model = self.app.model model = self.app.model
if session and 'username' not in kwargs: if session and "username" not in kwargs:
kwargs['username'] = self.make_unique_username(session, **kwargs) kwargs["username"] = self.make_unique_username(session, **kwargs)
user = model.User(**kwargs) user = model.User(**kwargs)
if session: if session:
@ -320,7 +318,9 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
session = self.app.get_session(user) session = self.app.get_session(user)
session.delete(user) session.delete(user)
def make_preferred_username(self, session, **kwargs): # pylint: disable=unused-argument def make_preferred_username(
self, session, **kwargs
): # pylint: disable=unused-argument
""" """
Generate a "preferred" username, using data from ``kwargs`` as Generate a "preferred" username, using data from ``kwargs`` as
hints. hints.
@ -347,18 +347,18 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
:returns: Generated username as string. :returns: Generated username as string.
""" """
person = kwargs.get('person') person = kwargs.get("person")
if person: if person:
first = (person.first_name or '').strip().lower() first = (person.first_name or "").strip().lower()
last = (person.last_name or '').strip().lower() last = (person.last_name or "").strip().lower()
if first and last: if first and last:
return f'{first}.{last}' return f"{first}.{last}"
if first: if first:
return first return first
if last: if last:
return last return last
return 'newuser' return "newuser"
def make_unique_username(self, session, **kwargs): def make_unique_username(self, session, **kwargs):
""" """
@ -398,9 +398,11 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
# check for unique username # check for unique username
counter = 1 counter = 1
while True: while True:
users = session.query(model.User)\ users = (
.filter(model.User.username == username)\ session.query(model.User)
.count() .filter(model.User.username == username)
.count()
)
if not users: if not users:
break break
username = f"{original_username}{counter:02d}" username = f"{original_username}{counter:02d}"
@ -426,22 +428,25 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
""" """
Returns the special "Administrator" role. Returns the special "Administrator" role.
""" """
return self._special_role(session, _uuid.UUID('d937fa8a965611dfa0dd001143047286'), return self._special_role(
"Administrator") session, _uuid.UUID("d937fa8a965611dfa0dd001143047286"), "Administrator"
)
def get_role_anonymous(self, session): def get_role_anonymous(self, session):
""" """
Returns the special "Anonymous" (aka. "Guest") role. Returns the special "Anonymous" (aka. "Guest") role.
""" """
return self._special_role(session, _uuid.UUID('f8a27c98965a11dfaff7001143047286'), return self._special_role(
"Anonymous") session, _uuid.UUID("f8a27c98965a11dfaff7001143047286"), "Anonymous"
)
def get_role_authenticated(self, session): def get_role_authenticated(self, session):
""" """
Returns the special "Authenticated" role. Returns the special "Authenticated" role.
""" """
return self._special_role(session, _uuid.UUID('b765a9cc331a11e6ac2a3ca9f40bc550'), return self._special_role(
"Authenticated") session, _uuid.UUID("b765a9cc331a11e6ac2a3ca9f40bc550"), "Authenticated"
)
def user_is_admin(self, user): def user_is_admin(self, user):
""" """
@ -457,9 +462,9 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
return False return False
def get_permissions(self, session, principal, def get_permissions(
include_anonymous=True, self, session, principal, include_anonymous=True, include_authenticated=True
include_authenticated=True): ):
""" """
Return a set of permission names, which represents all Return a set of permission names, which represents all
permissions effectively granted to the given user or role. permissions effectively granted to the given user or role.
@ -483,10 +488,8 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
""" """
# we will use any `roles` attribute which may be present. in # we will use any `roles` attribute which may be present. in
# practice we would be assuming a User in this case # practice we would be assuming a User in this case
if hasattr(principal, 'roles'): if hasattr(principal, "roles"):
roles = [role roles = [role for role in principal.roles if self._role_is_pertinent(role)]
for role in principal.roles
if self._role_is_pertinent(role)]
# here our User assumption gets a little more explicit # here our User assumption gets a little more explicit
if include_authenticated: if include_authenticated:
@ -507,14 +510,19 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
# build the permissions cache # build the permissions cache
cache = set() cache = set()
for role in roles: for role in roles:
if hasattr(role, 'permissions'): if hasattr(role, "permissions"):
cache.update(role.permissions) cache.update(role.permissions)
return cache return cache
def has_permission(self, session, principal, permission, def has_permission(
include_anonymous=True, self,
include_authenticated=True): session,
principal,
permission,
include_anonymous=True,
include_authenticated=True,
):
""" """
Check if the given user or role has been granted the given Check if the given user or role has been granted the given
permission. permission.
@ -551,9 +559,12 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
:returns: Boolean indicating if the permission is granted. :returns: Boolean indicating if the permission is granted.
""" """
perms = self.get_permissions(session, principal, perms = self.get_permissions(
include_anonymous=include_anonymous, session,
include_authenticated=include_authenticated) principal,
include_anonymous=include_anonymous,
include_authenticated=include_authenticated,
)
return permission in perms return permission in perms
def grant_permission(self, role, permission): def grant_permission(self, role, permission):
@ -608,9 +619,7 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
token_string = self.generate_api_token_string() token_string = self.generate_api_token_string()
# persist token in DB # persist token in DB
token = model.UserAPIToken( token = model.UserAPIToken(description=description, token_string=token_string)
description=description,
token_string=token_string)
user.api_tokens.append(token) user.api_tokens.append(token)
session.add(token) session.add(token)
@ -642,7 +651,7 @@ class AuthHandler(GenericHandler): # pylint: disable=too-many-public-methods
# internal methods # internal methods
############################## ##############################
def _role_is_pertinent(self, role): # pylint: disable=unused-argument def _role_is_pertinent(self, role): # pylint: disable=unused-argument
""" """
Check the role to ensure it is "pertinent" for the current app. Check the role to ensure it is "pertinent" for the current app.

View file

@ -31,7 +31,7 @@ import shutil
from wuttjamaican.app import GenericHandler from wuttjamaican.app import GenericHandler
class BatchHandler(GenericHandler): # pylint: disable=too-many-public-methods class BatchHandler(GenericHandler): # pylint: disable=too-many-public-methods
""" """
Base class and *partial* default implementation for :term:`batch Base class and *partial* default implementation for :term:`batch
handlers <batch handler>`. handlers <batch handler>`.
@ -59,8 +59,10 @@ class BatchHandler(GenericHandler): # pylint: disable=too-many-public-methods
Subclass must define this; default is not implemented. Subclass must define this; default is not implemented.
""" """
raise NotImplementedError("You must set the 'model_class' attribute " raise NotImplementedError(
f"for class '{self.__class__.__name__}'") "You must set the 'model_class' attribute "
f"for class '{self.__class__.__name__}'"
)
@property @property
def batch_type(self): def batch_type(self):
@ -99,8 +101,8 @@ class BatchHandler(GenericHandler): # pylint: disable=too-many-public-methods
:returns: New batch; instance of :attr:`model_class`. :returns: New batch; instance of :attr:`model_class`.
""" """
# generate new ID unless caller specifies # generate new ID unless caller specifies
if 'id' not in kwargs: if "id" not in kwargs:
kwargs['id'] = self.consume_batch_id(session) kwargs["id"] = self.consume_batch_id(session)
# make batch # make batch
batch = self.model_class(**kwargs) batch = self.model_class(**kwargs)
@ -121,9 +123,9 @@ class BatchHandler(GenericHandler): # pylint: disable=too-many-public-methods
:returns: Batch ID as integer, or zero-padded 8-char string. :returns: Batch ID as integer, or zero-padded 8-char string.
""" """
db = self.app.get_db_handler() db = self.app.get_db_handler()
batch_id = db.next_counter_value(session, 'batch_id') batch_id = db.next_counter_value(session, "batch_id")
if as_str: if as_str:
return f'{batch_id:08d}' return f"{batch_id:08d}"
return batch_id return batch_id
def init_batch(self, batch, session=None, progress=None, **kwargs): def init_batch(self, batch, session=None, progress=None, **kwargs):
@ -178,10 +180,10 @@ class BatchHandler(GenericHandler): # pylint: disable=too-many-public-methods
:returns: Path to root data dir for handler's batch type. :returns: Path to root data dir for handler's batch type.
""" """
# get root storage path # get root storage path
rootdir = self.config.get(f'{self.config.appname}.batch.storage_path') rootdir = self.config.get(f"{self.config.appname}.batch.storage_path")
if not rootdir: if not rootdir:
appdir = self.app.get_appdir() appdir = self.app.get_appdir()
rootdir = os.path.join(appdir, 'data', 'batch') rootdir = os.path.join(appdir, "data", "batch")
# get path for this batch type # get path for this batch type
path = os.path.join(rootdir, self.batch_type) path = os.path.join(rootdir, self.batch_type)
@ -204,7 +206,7 @@ class BatchHandler(GenericHandler): # pylint: disable=too-many-public-methods
return path return path
def should_populate(self, batch): # pylint: disable=unused-argument def should_populate(self, batch): # pylint: disable=unused-argument
""" """
Must return true or false, indicating whether the given batch Must return true or false, indicating whether the given batch
should be populated from initial data source(s). should be populated from initial data source(s).
@ -348,7 +350,9 @@ class BatchHandler(GenericHandler): # pylint: disable=too-many-public-methods
* :attr:`~wuttjamaican.db.model.batch.BatchMixin.status_text` * :attr:`~wuttjamaican.db.model.batch.BatchMixin.status_text`
""" """
def why_not_execute(self, batch, user=None, **kwargs): # pylint: disable=unused-argument def why_not_execute(
self, batch, user=None, **kwargs
): # pylint: disable=unused-argument
""" """
Returns text indicating the reason (if any) that a given batch Returns text indicating the reason (if any) that a given batch
should *not* be executed. should *not* be executed.
@ -468,16 +472,22 @@ class BatchHandler(GenericHandler): # pylint: disable=too-many-public-methods
if batch.executed: if batch.executed:
raise ValueError(f"batch has already been executed: {batch}") raise ValueError(f"batch has already been executed: {batch}")
reason = self.why_not_execute(batch, user=user, **kwargs) # pylint: disable=assignment-from-none reason = self.why_not_execute( # pylint: disable=assignment-from-none
batch, user=user, **kwargs
)
if reason: if reason:
raise RuntimeError(f"batch execution not allowed: {reason}") raise RuntimeError(f"batch execution not allowed: {reason}")
result = self.execute(batch, user=user, progress=progress, **kwargs) # pylint: disable=assignment-from-none result = self.execute( # pylint: disable=assignment-from-none
batch, user=user, progress=progress, **kwargs
)
batch.executed = datetime.datetime.now() batch.executed = datetime.datetime.now()
batch.executed_by = user batch.executed_by = user
return result return result
def execute(self, batch, user=None, progress=None, **kwargs): # pylint: disable=unused-argument def execute(
self, batch, user=None, progress=None, **kwargs
): # pylint: disable=unused-argument
""" """
Execute the given batch. Execute the given batch.
@ -505,7 +515,9 @@ class BatchHandler(GenericHandler): # pylint: disable=too-many-public-methods
""" """
return None return None
def do_delete(self, batch, user, dry_run=False, progress=None, **kwargs): # pylint: disable=unused-argument def do_delete(
self, batch, user, dry_run=False, progress=None, **kwargs
): # pylint: disable=unused-argument
""" """
Delete the given batch entirely. Delete the given batch entirely.

View file

@ -40,4 +40,5 @@ from . import problems
# discover more commands, installed via other packages # discover more commands, installed via other packages
from .base import typer_eager_imports from .base import typer_eager_imports
typer_eager_imports(wutta_typer) typer_eager_imports(wutta_typer)

View file

@ -59,18 +59,21 @@ def make_cli_config(ctx: typer.Context):
:returns: :class:`~wuttjamaican.conf.WuttaConfig` instance :returns: :class:`~wuttjamaican.conf.WuttaConfig` instance
""" """
logging.basicConfig() logging.basicConfig()
return make_config(files=ctx.params.get('config_paths') or None) return make_config(files=ctx.params.get("config_paths") or None)
def typer_callback( def typer_callback(
ctx: typer.Context, ctx: typer.Context,
config_paths: Annotated[
config_paths: Annotated[ Optional[List[Path]],
Optional[List[Path]], typer.Option(
typer.Option('--config', '-c', "--config",
exists=True, "-c",
help="Config path (may be specified more than once)")] = None, exists=True,
): # pylint: disable=unused-argument help="Config path (may be specified more than once)",
),
] = None,
): # pylint: disable=unused-argument
""" """
Generic callback for use with top-level commands. This adds some Generic callback for use with top-level commands. This adds some
top-level args: top-level args:
@ -85,8 +88,7 @@ def typer_callback(
ctx.wutta_config = make_cli_config(ctx) ctx.wutta_config = make_cli_config(ctx)
def typer_eager_imports( def typer_eager_imports(group: [typer.Typer, str]):
group: [typer.Typer, str]):
""" """
Eagerly import all modules which are registered as having Eagerly import all modules which are registered as having
:term:`subcommands <subcommand>` belonging to the given group :term:`subcommands <subcommand>` belonging to the given group
@ -119,7 +121,7 @@ def typer_eager_imports(
""" """
if isinstance(group, typer.Typer): if isinstance(group, typer.Typer):
group = group.info.name group = group.info.name
load_entry_points(f'{group}.typer_imports') load_entry_points(f"{group}.typer_imports")
def make_typer(**kwargs): def make_typer(**kwargs):
@ -136,11 +138,8 @@ def make_typer(**kwargs):
:returns: ``typer.Typer`` instance :returns: ``typer.Typer`` instance
""" """
kwargs.setdefault('callback', typer_callback) kwargs.setdefault("callback", typer_callback)
return typer.Typer(**kwargs) return typer.Typer(**kwargs)
wutta_typer = make_typer( wutta_typer = make_typer(name="wutta", help="Wutta Software Framework")
name='wutta',
help="Wutta Software Framework"
)

View file

@ -35,13 +35,16 @@ from .base import wutta_typer
@wutta_typer.command() @wutta_typer.command()
def make_appdir( def make_appdir(
ctx: typer.Context, ctx: typer.Context,
appdir_path: Annotated[ appdir_path: Annotated[
Path, Path,
typer.Option('--path', typer.Option(
help="Path to desired app dir; default is (usually) " "--path",
"`app` in the root of virtual environment.")] = None, help="Path to desired app dir; default is (usually) "
): # pylint: disable=unused-argument "`app` in the root of virtual environment.",
),
] = None,
): # pylint: disable=unused-argument
""" """
Make the app dir for virtual environment Make the app dir for virtual environment
@ -49,6 +52,6 @@ def make_appdir(
""" """
config = ctx.parent.wutta_config config = ctx.parent.wutta_config
app = config.get_app() app = config.get_app()
appdir = ctx.params['appdir_path'] or app.get_appdir() appdir = ctx.params["appdir_path"] or app.get_appdir()
app.make_appdir(appdir) app.make_appdir(appdir)
sys.stdout.write(f"established appdir: {appdir}\n") sys.stdout.write(f"established appdir: {appdir}\n")

View file

@ -33,7 +33,7 @@ from .base import wutta_typer
@wutta_typer.command() @wutta_typer.command()
def make_uuid( def make_uuid(
ctx: typer.Context, ctx: typer.Context,
): ):
""" """
Generate a new UUID Generate a new UUID

View file

@ -36,25 +36,34 @@ from .base import wutta_typer
@wutta_typer.command() @wutta_typer.command()
def problems( def problems(
ctx: typer.Context, ctx: typer.Context,
systems: Annotated[
systems: Annotated[ List[str],
List[str], typer.Option(
typer.Option('--system', '-s', "--system",
help="System for which to perform checks; can be specified more " "-s",
"than once. If not specified, all systems are assumed.")] = None, help="System for which to perform checks; can be specified more "
"than once. If not specified, all systems are assumed.",
problems: Annotated[ # pylint: disable=redefined-outer-name ),
List[str], ] = None,
typer.Option('--problem', '-p', problems: Annotated[ # pylint: disable=redefined-outer-name
help="Identify a particular problem check; can be specified " List[str],
"more than once. If not specified, all checks are assumed.")] = None, typer.Option(
"--problem",
list_checks: Annotated[ "-p",
bool, help="Identify a particular problem check; can be specified "
typer.Option('--list', '-l', "more than once. If not specified, all checks are assumed.",
help="List available problem checks; optionally filtered " ),
"per --system and --problem")] = False, ] = None,
list_checks: Annotated[
bool,
typer.Option(
"--list",
"-l",
help="List available problem checks; optionally filtered "
"per --system and --problem",
),
] = False,
): ):
""" """
Find and report on problems with the data or system. Find and report on problems with the data or system.
@ -65,9 +74,11 @@ def problems(
# try to warn user if unknown system is specified; but otherwise ignore # try to warn user if unknown system is specified; but otherwise ignore
supported = handler.get_supported_systems() supported = handler.get_supported_systems()
for key in (systems or []): for key in systems or []:
if key not in supported: if key not in supported:
rich.print(f"\n[bold yellow]No problem reports exist for system: {key}[/bold yellow]") rich.print(
f"\n[bold yellow]No problem reports exist for system: {key}[/bold yellow]"
)
checks = handler.filter_problem_checks(systems=systems, problems=problems) checks = handler.filter_problem_checks(systems=systems, problems=problems)

View file

@ -34,16 +34,20 @@ import tempfile
import config as configuration import config as configuration
from wuttjamaican.util import (load_entry_points, load_object, from wuttjamaican.util import (
parse_bool, parse_list, load_entry_points,
UNSPECIFIED) load_object,
parse_bool,
parse_list,
UNSPECIFIED,
)
from wuttjamaican.exc import ConfigurationError from wuttjamaican.exc import ConfigurationError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class WuttaConfig: # pylint: disable=too-many-instance-attributes class WuttaConfig: # pylint: disable=too-many-instance-attributes
""" """
Configuration class for Wutta Framework Configuration class for Wutta Framework
@ -184,17 +188,18 @@ class WuttaConfig: # pylint: disable=too-many-instance-attributes
See also :ref:`where-config-settings-come-from`. See also :ref:`where-config-settings-come-from`.
""" """
default_app_handler_spec = 'wuttjamaican.app:AppHandler'
default_engine_maker_spec = 'wuttjamaican.db.conf:make_engine_from_config' default_app_handler_spec = "wuttjamaican.app:AppHandler"
default_engine_maker_spec = "wuttjamaican.db.conf:make_engine_from_config"
def __init__( def __init__(
self, self,
files=None, files=None,
defaults=None, defaults=None,
appname='wutta', appname="wutta",
usedb=None, usedb=None,
preferdb=None, preferdb=None,
configure_logging=None, configure_logging=None,
): ):
self.appname = appname self.appname = appname
configs = [] configs = []
@ -213,36 +218,41 @@ class WuttaConfig: # pylint: disable=too-many-instance-attributes
# establish logging # establish logging
if configure_logging is None: if configure_logging is None:
configure_logging = self.get_bool(f'{self.appname}.config.configure_logging', configure_logging = self.get_bool(
default=False, usedb=False) f"{self.appname}.config.configure_logging", default=False, usedb=False
)
if configure_logging: if configure_logging:
self._configure_logging() self._configure_logging()
# usedb flag # usedb flag
self.usedb = usedb self.usedb = usedb
if self.usedb is None: if self.usedb is None:
self.usedb = self.get_bool(f'{self.appname}.config.usedb', self.usedb = self.get_bool(
default=False, usedb=False) f"{self.appname}.config.usedb", default=False, usedb=False
)
# preferdb flag # preferdb flag
self.preferdb = preferdb self.preferdb = preferdb
if self.usedb and self.preferdb is None: if self.usedb and self.preferdb is None:
self.preferdb = self.get_bool(f'{self.appname}.config.preferdb', self.preferdb = self.get_bool(
default=False, usedb=False) f"{self.appname}.config.preferdb", default=False, usedb=False
)
# configure main app DB if applicable, or disable usedb flag # configure main app DB if applicable, or disable usedb flag
try: try:
from wuttjamaican.db import Session, get_engines from wuttjamaican.db import Session, get_engines
except ImportError: except ImportError:
if self.usedb: if self.usedb:
log.warning("config created with `usedb = True`, but can't import " log.warning(
"DB module(s), so setting `usedb = False` instead", "config created with `usedb = True`, but can't import "
exc_info=True) "DB module(s), so setting `usedb = False` instead",
exc_info=True,
)
self.usedb = False self.usedb = False
self.preferdb = False self.preferdb = False
else: else:
self.appdb_engines = get_engines(self, f'{self.appname}.db') self.appdb_engines = get_engines(self, f"{self.appname}.db")
self.appdb_engine = self.appdb_engines.get('default') self.appdb_engine = self.appdb_engines.get("default")
Session.configure(bind=self.appdb_engine) Session.configure(bind=self.appdb_engine)
log.debug("config files read: %s", self.files_read) log.debug("config files read: %s", self.files_read)
@ -256,7 +266,7 @@ class WuttaConfig: # pylint: disable=too-many-instance-attributes
# try to load config with standard parser, and default vars # try to load config with standard parser, and default vars
here = os.path.dirname(path) here = os.path.dirname(path)
config = configparser.ConfigParser(defaults={'here': here, '__file__': path}) config = configparser.ConfigParser(defaults={"here": here, "__file__": path})
if not config.read(path): if not config.read(path):
if require: if require:
raise FileNotFoundError(f"could not read required config file: {path}") raise FileNotFoundError(f"could not read required config file: {path}")
@ -267,14 +277,14 @@ class WuttaConfig: # pylint: disable=too-many-instance-attributes
for section in config.sections(): for section in config.sections():
temp_config.add_section(section) temp_config.add_section(section)
# nb. must interpolate most values but *not* for logging formatters # nb. must interpolate most values but *not* for logging formatters
raw = section.startswith('formatter_') raw = section.startswith("formatter_")
for option in config.options(section): for option in config.options(section):
temp_config.set(section, option, config.get(section, option, raw=raw)) temp_config.set(section, option, config.get(section, option, raw=raw))
# re-write as temp file with "final" values # re-write as temp file with "final" values
fd, temp_path = tempfile.mkstemp(suffix='.ini') fd, temp_path = tempfile.mkstemp(suffix=".ini")
os.close(fd) os.close(fd)
with open(temp_path, 'wt', encoding='utf_8') as f: with open(temp_path, "wt", encoding="utf_8") as f:
temp_config.write(f) temp_config.write(f)
# and finally, load that into our main config # and finally, load that into our main config
@ -284,13 +294,13 @@ class WuttaConfig: # pylint: disable=too-many-instance-attributes
os.remove(temp_path) os.remove(temp_path)
# bring in any "required" files # bring in any "required" files
requires = config.get(f'{self.appname}.config.require') requires = config.get(f"{self.appname}.config.require")
if requires: if requires:
for p in self.parse_list(requires): for p in self.parse_list(requires):
self._load_ini_configs(p, configs, require=True) self._load_ini_configs(p, configs, require=True)
# bring in any "included" files # bring in any "included" files
includes = config.get(f'{self.appname}.config.include') includes = config.get(f"{self.appname}.config.include")
if includes: if includes:
for p in self.parse_list(includes): for p in self.parse_list(includes):
self._load_ini_configs(p, configs, require=False) self._load_ini_configs(p, configs, require=False)
@ -304,10 +314,7 @@ class WuttaConfig: # pylint: disable=too-many-instance-attributes
""" """
return self.files_read return self.files_read
def setdefault( def setdefault(self, key, value):
self,
key,
value):
""" """
Establish a default config value for the given key. Establish a default config value for the given key.
@ -327,16 +334,16 @@ class WuttaConfig: # pylint: disable=too-many-instance-attributes
return self.get(key, usedb=False) return self.get(key, usedb=False)
def get( def get(
self, self,
key, key,
default=UNSPECIFIED, default=UNSPECIFIED,
require=False, require=False,
ignore_ambiguous=False, ignore_ambiguous=False,
message=None, message=None,
usedb=None, usedb=None,
preferdb=None, preferdb=None,
session=None, session=None,
**kwargs **kwargs,
): ):
""" """
Retrieve a string value from config. Retrieve a string value from config.
@ -489,7 +496,7 @@ class WuttaConfig: # pylint: disable=too-many-instance-attributes
config.require('foo') config.require('foo')
""" """
kwargs['require'] = True kwargs["require"] = True
return self.get(*args, **kwargs) return self.get(*args, **kwargs)
def get_bool(self, *args, **kwargs): def get_bool(self, *args, **kwargs):
@ -613,9 +620,9 @@ class WuttaConfig: # pylint: disable=too-many-instance-attributes
parser.set(section, option, value) parser.set(section, option, value)
# write INI file and return path # write INI file and return path
fd, path = tempfile.mkstemp(suffix='.conf') fd, path = tempfile.mkstemp(suffix=".conf")
os.close(fd) os.close(fd)
with open(path, 'wt', encoding='utf_8') as f: with open(path, "wt", encoding="utf_8") as f:
parser.write(f) parser.write(f)
return path return path
@ -626,9 +633,12 @@ class WuttaConfig: # pylint: disable=too-many-instance-attributes
See also :doc:`/narr/handlers/app`. See also :doc:`/narr/handlers/app`.
""" """
if not hasattr(self, '_app'): if not hasattr(self, "_app"):
spec = self.get(f'{self.appname}.app.handler', usedb=False, spec = self.get(
default=self.default_app_handler_spec) f"{self.appname}.app.handler",
usedb=False,
default=self.default_app_handler_spec,
)
factory = load_object(spec) factory = load_object(spec)
self._app = factory(self) self._app = factory(self)
return self._app return self._app
@ -656,13 +666,14 @@ class WuttaConfig: # pylint: disable=too-many-instance-attributes
[wutta] [wutta]
production = true production = true
""" """
return self.get_bool(f'{self.appname}.production', default=False) return self.get_bool(f"{self.appname}.production", default=False)
class WuttaConfigExtension: class WuttaConfigExtension:
""" """
Base class for all :term:`config extensions <config extension>`. Base class for all :term:`config extensions <config extension>`.
""" """
key = None key = None
def __repr__(self): def __repr__(self):
@ -696,7 +707,7 @@ def generic_default_files(appname):
:returns: List of default file paths. :returns: List of default file paths.
""" """
if sys.platform == 'win32': if sys.platform == "win32":
# use pywin32 to fetch official defaults # use pywin32 to fetch official defaults
try: try:
from win32com.shell import shell, shellcon from win32com.shell import shell, shellcon
@ -705,42 +716,49 @@ def generic_default_files(appname):
return [ return [
# e.g. C:\..?? TODO: what is the user-specific path on win32? # e.g. C:\..?? TODO: what is the user-specific path on win32?
os.path.join(shell.SHGetSpecialFolderPath( os.path.join(
0, shellcon.CSIDL_APPDATA), appname, f'{appname}.conf'), shell.SHGetSpecialFolderPath(0, shellcon.CSIDL_APPDATA),
os.path.join(shell.SHGetSpecialFolderPath( appname,
0, shellcon.CSIDL_APPDATA), f'{appname}.conf'), f"{appname}.conf",
),
os.path.join(
shell.SHGetSpecialFolderPath(0, shellcon.CSIDL_APPDATA),
f"{appname}.conf",
),
# e.g. C:\ProgramData\wutta\wutta.conf # e.g. C:\ProgramData\wutta\wutta.conf
os.path.join(shell.SHGetSpecialFolderPath( os.path.join(
0, shellcon.CSIDL_COMMON_APPDATA), appname, f'{appname}.conf'), shell.SHGetSpecialFolderPath(0, shellcon.CSIDL_COMMON_APPDATA),
os.path.join(shell.SHGetSpecialFolderPath( appname,
0, shellcon.CSIDL_COMMON_APPDATA), f'{appname}.conf'), f"{appname}.conf",
),
os.path.join(
shell.SHGetSpecialFolderPath(0, shellcon.CSIDL_COMMON_APPDATA),
f"{appname}.conf",
),
] ]
# default paths for *nix # default paths for *nix
return [ return [
f'{sys.prefix}/app/{appname}.conf', f"{sys.prefix}/app/{appname}.conf",
os.path.expanduser(f"~/.{appname}/{appname}.conf"),
os.path.expanduser(f'~/.{appname}/{appname}.conf'), os.path.expanduser(f"~/.{appname}.conf"),
os.path.expanduser(f'~/.{appname}.conf'), f"/usr/local/etc/{appname}/{appname}.conf",
f"/usr/local/etc/{appname}.conf",
f'/usr/local/etc/{appname}/{appname}.conf', f"/etc/{appname}/{appname}.conf",
f'/usr/local/etc/{appname}.conf', f"/etc/{appname}.conf",
f'/etc/{appname}/{appname}.conf',
f'/etc/{appname}.conf',
] ]
def get_config_paths( def get_config_paths(
files=None, files=None,
plus_files=None, plus_files=None,
appname='wutta', appname="wutta",
env_files_name=None, env_files_name=None,
env_plus_files_name=None, env_plus_files_name=None,
env=None, env=None,
default_files=None, default_files=None,
winsvc=None): winsvc=None,
):
""" """
This function determines which files should ultimately be provided This function determines which files should ultimately be provided
to the config constructor. It is normally called by to the config constructor. It is normally called by
@ -856,7 +874,7 @@ def get_config_paths(
# first identify any "primary" config files # first identify any "primary" config files
if files is None: if files is None:
if not env_files_name: if not env_files_name:
env_files_name = f'{appname.upper()}_CONFIG_FILES' env_files_name = f"{appname.upper()}_CONFIG_FILES"
files = env.get(env_files_name) files = env.get(env_files_name)
if files is not None: if files is not None:
@ -869,8 +887,7 @@ def get_config_paths(
files = [default_files] files = [default_files]
else: else:
files = list(default_files) files = list(default_files)
files = [path for path in files files = [path for path in files if os.path.exists(path)]
if os.path.exists(path)]
else: else:
files = [] files = []
@ -886,7 +903,7 @@ def get_config_paths(
# then identify any "plus" (config tweak) files # then identify any "plus" (config tweak) files
if plus_files is None: if plus_files is None:
if not env_plus_files_name: if not env_plus_files_name:
env_plus_files_name = f'{appname.upper()}_CONFIG_PLUS_FILES' env_plus_files_name = f"{appname.upper()}_CONFIG_PLUS_FILES"
plus_files = env.get(env_plus_files_name) plus_files = env.get(env_plus_files_name)
if plus_files is not None: if plus_files is not None:
@ -911,9 +928,9 @@ def get_config_paths(
if winsvc: if winsvc:
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read(files) config.read(files)
section = f'{appname}.config' section = f"{appname}.config"
if config.has_section(section): if config.has_section(section):
option = f'winsvc.{winsvc}' option = f"winsvc.{winsvc}"
if config.has_option(section, option): if config.has_option(section, option):
# replace file paths with whatever config value says # replace file paths with whatever config value says
files = parse_list(config.get(section, option)) files = parse_list(config.get(section, option))
@ -922,20 +939,21 @@ def get_config_paths(
def make_config( def make_config(
files=None, files=None,
plus_files=None, plus_files=None,
appname='wutta', appname="wutta",
env_files_name=None, env_files_name=None,
env_plus_files_name=None, env_plus_files_name=None,
env=None, env=None,
default_files=None, default_files=None,
winsvc=None, winsvc=None,
usedb=None, usedb=None,
preferdb=None, preferdb=None,
factory=None, factory=None,
extend=True, extend=True,
extension_entry_points=None, extension_entry_points=None,
**kwargs): **kwargs,
):
""" """
Make a new config (usually :class:`WuttaConfig`) object, Make a new config (usually :class:`WuttaConfig`) object,
initialized per the given parameters and (usually) further initialized per the given parameters and (usually) further
@ -992,19 +1010,18 @@ def make_config(
env_plus_files_name=env_plus_files_name, env_plus_files_name=env_plus_files_name,
env=env, env=env,
default_files=default_files, default_files=default_files,
winsvc=winsvc) winsvc=winsvc,
)
# make config object # make config object
if not factory: if not factory:
factory = WuttaConfig factory = WuttaConfig
config = factory(files, appname=appname, config = factory(files, appname=appname, usedb=usedb, preferdb=preferdb, **kwargs)
usedb=usedb, preferdb=preferdb,
**kwargs)
# maybe extend config object # maybe extend config object
if extend: if extend:
if not extension_entry_points: if not extension_entry_points:
extension_entry_points = f'{appname}.config.extensions' extension_entry_points = f"{appname}.config.extensions"
# apply all registered extensions # apply all registered extensions
# TODO: maybe let config disable some extensions? # TODO: maybe let config disable some extensions?
@ -1104,4 +1121,4 @@ class WuttaConfigProfile:
profile = TelemetryProfile("default") profile = TelemetryProfile("default")
url = profile.get_str("submit_url") url = profile.get_str("submit_url")
""" """
return self.config.get(f'{self.section}.{self.key}.{option}', **kwargs) return self.config.get(f"{self.section}.{self.key}.{option}", **kwargs)

View file

@ -11,8 +11,7 @@ from wuttjamaican.conf import make_config
alembic_config = context.config alembic_config = context.config
# this is the wutta-based config # this is the wutta-based config
wutta_config = make_config(alembic_config.config_file_name, wutta_config = make_config(alembic_config.config_file_name, usedb=False)
usedb=False)
# add your model's MetaData object here # add your model's MetaData object here
# for 'autogenerate' support # for 'autogenerate' support
@ -54,9 +53,7 @@ def run_migrations_online() -> None:
connectable = wutta_config.appdb_engine connectable = wutta_config.appdb_engine
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(connection=connection, target_metadata=target_metadata)
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()

View file

@ -5,6 +5,7 @@ Revises: d686f7abe3e0
Create Date: 2024-07-14 15:14:30.552682 Create Date: 2024-07-14 15:14:30.552682
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op from alembic import op
@ -13,8 +14,8 @@ import wuttjamaican.db.util
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '3abcc44f7f91' revision: str = "3abcc44f7f91"
down_revision: Union[str, None] = 'd686f7abe3e0' down_revision: Union[str, None] = "d686f7abe3e0"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@ -22,25 +23,30 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# person # person
op.create_table('person', op.create_table(
sa.Column('uuid', wuttjamaican.db.util.UUID(), nullable=False), "person",
sa.Column('full_name', sa.String(length=100), nullable=False), sa.Column("uuid", wuttjamaican.db.util.UUID(), nullable=False),
sa.Column('first_name', sa.String(length=50), nullable=True), sa.Column("full_name", sa.String(length=100), nullable=False),
sa.Column('middle_name', sa.String(length=50), nullable=True), sa.Column("first_name", sa.String(length=50), nullable=True),
sa.Column('last_name', sa.String(length=50), nullable=True), sa.Column("middle_name", sa.String(length=50), nullable=True),
sa.PrimaryKeyConstraint('uuid', name=op.f('pk_person')) sa.Column("last_name", sa.String(length=50), nullable=True),
) sa.PrimaryKeyConstraint("uuid", name=op.f("pk_person")),
)
# user # user
op.add_column('user', sa.Column('person_uuid', wuttjamaican.db.util.UUID(), nullable=True)) op.add_column(
op.create_foreign_key(op.f('fk_user_person_uuid_person'), 'user', 'person', ['person_uuid'], ['uuid']) "user", sa.Column("person_uuid", wuttjamaican.db.util.UUID(), nullable=True)
)
op.create_foreign_key(
op.f("fk_user_person_uuid_person"), "user", "person", ["person_uuid"], ["uuid"]
)
def downgrade() -> None: def downgrade() -> None:
# user # user
op.drop_constraint(op.f('fk_user_person_uuid_person'), 'user', type_='foreignkey') op.drop_constraint(op.f("fk_user_person_uuid_person"), "user", type_="foreignkey")
op.drop_column('user', 'person_uuid') op.drop_column("user", "person_uuid")
# person # person
op.drop_table('person') op.drop_table("person")

View file

@ -5,6 +5,7 @@ Revises: ebd75b9feaa7
Create Date: 2024-11-24 16:52:36.773657 Create Date: 2024-11-24 16:52:36.773657
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op from alembic import op
@ -12,8 +13,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '6bf900765500' revision: str = "6bf900765500"
down_revision: Union[str, None] = 'ebd75b9feaa7' down_revision: Union[str, None] = "ebd75b9feaa7"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@ -21,10 +22,10 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# user # user
op.add_column('user', sa.Column('prevent_edit', sa.Boolean(), nullable=True)) op.add_column("user", sa.Column("prevent_edit", sa.Boolean(), nullable=True))
def downgrade() -> None: def downgrade() -> None:
# user # user
op.drop_column('user', 'prevent_edit') op.drop_column("user", "prevent_edit")

View file

@ -5,6 +5,7 @@ Revises: fc3a3bcaa069
Create Date: 2024-07-14 13:27:22.703093 Create Date: 2024-07-14 13:27:22.703093
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op from alembic import op
@ -13,8 +14,8 @@ import wuttjamaican.db.util
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'd686f7abe3e0' revision: str = "d686f7abe3e0"
down_revision: Union[str, None] = 'fc3a3bcaa069' down_revision: Union[str, None] = "fc3a3bcaa069"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@ -22,53 +23,63 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# role # role
op.create_table('role', op.create_table(
sa.Column('uuid', wuttjamaican.db.util.UUID(), nullable=False), "role",
sa.Column('name', sa.String(length=100), nullable=False), sa.Column("uuid", wuttjamaican.db.util.UUID(), nullable=False),
sa.Column('notes', sa.Text(), nullable=True), sa.Column("name", sa.String(length=100), nullable=False),
sa.PrimaryKeyConstraint('uuid'), sa.Column("notes", sa.Text(), nullable=True),
sa.UniqueConstraint('name', name=op.f('uq_role_name')) sa.PrimaryKeyConstraint("uuid"),
) sa.UniqueConstraint("name", name=op.f("uq_role_name")),
)
# user # user
op.create_table('user', op.create_table(
sa.Column('uuid', wuttjamaican.db.util.UUID(), nullable=False), "user",
sa.Column('username', sa.String(length=25), nullable=False), sa.Column("uuid", wuttjamaican.db.util.UUID(), nullable=False),
sa.Column('password', sa.String(length=60), nullable=True), sa.Column("username", sa.String(length=25), nullable=False),
sa.Column('active', sa.Boolean(), nullable=False), sa.Column("password", sa.String(length=60), nullable=True),
sa.PrimaryKeyConstraint('uuid'), sa.Column("active", sa.Boolean(), nullable=False),
sa.UniqueConstraint('username', name=op.f('uq_user_username')) sa.PrimaryKeyConstraint("uuid"),
) sa.UniqueConstraint("username", name=op.f("uq_user_username")),
)
# permission # permission
op.create_table('permission', op.create_table(
sa.Column('role_uuid', wuttjamaican.db.util.UUID(), nullable=False), "permission",
sa.Column('permission', sa.String(length=254), nullable=False), sa.Column("role_uuid", wuttjamaican.db.util.UUID(), nullable=False),
sa.ForeignKeyConstraint(['role_uuid'], ['role.uuid'], name=op.f('fk_permission_role_uuid_role')), sa.Column("permission", sa.String(length=254), nullable=False),
sa.PrimaryKeyConstraint('role_uuid', 'permission') sa.ForeignKeyConstraint(
) ["role_uuid"], ["role.uuid"], name=op.f("fk_permission_role_uuid_role")
),
sa.PrimaryKeyConstraint("role_uuid", "permission"),
)
# user_x_role # user_x_role
op.create_table('user_x_role', op.create_table(
sa.Column('uuid', wuttjamaican.db.util.UUID(), nullable=False), "user_x_role",
sa.Column('user_uuid', wuttjamaican.db.util.UUID(), nullable=False), sa.Column("uuid", wuttjamaican.db.util.UUID(), nullable=False),
sa.Column('role_uuid', wuttjamaican.db.util.UUID(), nullable=False), sa.Column("user_uuid", wuttjamaican.db.util.UUID(), nullable=False),
sa.ForeignKeyConstraint(['role_uuid'], ['role.uuid'], name=op.f('fk_user_x_role_role_uuid_role')), sa.Column("role_uuid", wuttjamaican.db.util.UUID(), nullable=False),
sa.ForeignKeyConstraint(['user_uuid'], ['user.uuid'], name=op.f('fk_user_x_role_user_uuid_user')), sa.ForeignKeyConstraint(
sa.PrimaryKeyConstraint('uuid') ["role_uuid"], ["role.uuid"], name=op.f("fk_user_x_role_role_uuid_role")
) ),
sa.ForeignKeyConstraint(
["user_uuid"], ["user.uuid"], name=op.f("fk_user_x_role_user_uuid_user")
),
sa.PrimaryKeyConstraint("uuid"),
)
def downgrade() -> None: def downgrade() -> None:
# user_x_role # user_x_role
op.drop_table('user_x_role') op.drop_table("user_x_role")
# permission # permission
op.drop_table('permission') op.drop_table("permission")
# user # user
op.drop_table('user') op.drop_table("user")
# role # role
op.drop_table('role') op.drop_table("role")

View file

@ -5,6 +5,7 @@ Revises: 3abcc44f7f91
Create Date: 2024-08-24 09:42:21.199679 Create Date: 2024-08-24 09:42:21.199679
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op from alembic import op
@ -13,8 +14,8 @@ from sqlalchemy.dialects import postgresql
import wuttjamaican.db.util import wuttjamaican.db.util
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'ebd75b9feaa7' revision: str = "ebd75b9feaa7"
down_revision: Union[str, None] = '3abcc44f7f91' down_revision: Union[str, None] = "3abcc44f7f91"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@ -22,26 +23,50 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# upgrade # upgrade
sa.Enum('PENDING', 'EXECUTING', 'SUCCESS', 'FAILURE', name='upgradestatus').create(op.get_bind()) sa.Enum("PENDING", "EXECUTING", "SUCCESS", "FAILURE", name="upgradestatus").create(
op.create_table('upgrade', op.get_bind()
sa.Column('uuid', wuttjamaican.db.util.UUID(), nullable=False), )
sa.Column('created', sa.DateTime(timezone=True), nullable=False), op.create_table(
sa.Column('created_by_uuid', wuttjamaican.db.util.UUID(), nullable=False), "upgrade",
sa.Column('description', sa.String(length=255), nullable=False), sa.Column("uuid", wuttjamaican.db.util.UUID(), nullable=False),
sa.Column('notes', sa.Text(), nullable=True), sa.Column("created", sa.DateTime(timezone=True), nullable=False),
sa.Column('executing', sa.Boolean(), nullable=False), sa.Column("created_by_uuid", wuttjamaican.db.util.UUID(), nullable=False),
sa.Column('status', postgresql.ENUM('PENDING', 'EXECUTING', 'SUCCESS', 'FAILURE', name='upgradestatus', create_type=False), nullable=False), sa.Column("description", sa.String(length=255), nullable=False),
sa.Column('executed', sa.DateTime(timezone=True), nullable=True), sa.Column("notes", sa.Text(), nullable=True),
sa.Column('executed_by_uuid', wuttjamaican.db.util.UUID(), nullable=True), sa.Column("executing", sa.Boolean(), nullable=False),
sa.Column('exit_code', sa.Integer(), nullable=True), sa.Column(
sa.ForeignKeyConstraint(['created_by_uuid'], ['user.uuid'], name=op.f('fk_upgrade_created_by_uuid_user')), "status",
sa.ForeignKeyConstraint(['executed_by_uuid'], ['user.uuid'], name=op.f('fk_upgrade_executed_by_uuid_user')), postgresql.ENUM(
sa.PrimaryKeyConstraint('uuid', name=op.f('pk_upgrade')) "PENDING",
) "EXECUTING",
"SUCCESS",
"FAILURE",
name="upgradestatus",
create_type=False,
),
nullable=False,
),
sa.Column("executed", sa.DateTime(timezone=True), nullable=True),
sa.Column("executed_by_uuid", wuttjamaican.db.util.UUID(), nullable=True),
sa.Column("exit_code", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["created_by_uuid"],
["user.uuid"],
name=op.f("fk_upgrade_created_by_uuid_user"),
),
sa.ForeignKeyConstraint(
["executed_by_uuid"],
["user.uuid"],
name=op.f("fk_upgrade_executed_by_uuid_user"),
),
sa.PrimaryKeyConstraint("uuid", name=op.f("pk_upgrade")),
)
def downgrade() -> None: def downgrade() -> None:
# upgrade # upgrade
op.drop_table('upgrade') op.drop_table("upgrade")
sa.Enum('PENDING', 'EXECUTING', 'SUCCESS', 'FAILURE', name='upgradestatus').drop(op.get_bind()) sa.Enum("PENDING", "EXECUTING", "SUCCESS", "FAILURE", name="upgradestatus").drop(
op.get_bind()
)

View file

@ -5,6 +5,7 @@ Revises: 6bf900765500
Create Date: 2025-08-08 08:58:19.376105 Create Date: 2025-08-08 08:58:19.376105
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op from alembic import op
@ -13,8 +14,8 @@ import wuttjamaican.db.util
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'efdcb2c75034' revision: str = "efdcb2c75034"
down_revision: Union[str, None] = '6bf900765500' down_revision: Union[str, None] = "6bf900765500"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@ -22,18 +23,21 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# user_api_token # user_api_token
op.create_table('user_api_token', op.create_table(
sa.Column('uuid', wuttjamaican.db.util.UUID(), nullable=False), "user_api_token",
sa.Column('user_uuid', wuttjamaican.db.util.UUID(), nullable=False), sa.Column("uuid", wuttjamaican.db.util.UUID(), nullable=False),
sa.Column('description', sa.String(length=255), nullable=False), sa.Column("user_uuid", wuttjamaican.db.util.UUID(), nullable=False),
sa.Column('token_string', sa.String(length=255), nullable=False), sa.Column("description", sa.String(length=255), nullable=False),
sa.Column('created', sa.DateTime(timezone=True), nullable=False), sa.Column("token_string", sa.String(length=255), nullable=False),
sa.ForeignKeyConstraint(['user_uuid'], ['user.uuid'], name=op.f('fk_user_api_token_user_uuid_user')), sa.Column("created", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('uuid', name=op.f('pk_user_api_token')) sa.ForeignKeyConstraint(
) ["user_uuid"], ["user.uuid"], name=op.f("fk_user_api_token_user_uuid_user")
),
sa.PrimaryKeyConstraint("uuid", name=op.f("pk_user_api_token")),
)
def downgrade() -> None: def downgrade() -> None:
# user_api_token # user_api_token
op.drop_table('user_api_token') op.drop_table("user_api_token")

View file

@ -1,10 +1,11 @@
"""init with settings table """init with settings table
Revision ID: fc3a3bcaa069 Revision ID: fc3a3bcaa069
Revises: Revises:
Create Date: 2024-07-10 20:33:41.273952 Create Date: 2024-07-10 20:33:41.273952
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op from alembic import op
@ -12,23 +13,24 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'fc3a3bcaa069' revision: str = "fc3a3bcaa069"
down_revision: Union[str, None] = None down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = ('wutta',) branch_labels: Union[str, Sequence[str], None] = ("wutta",)
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# setting # setting
op.create_table('setting', op.create_table(
sa.Column('name', sa.String(length=255), nullable=False), "setting",
sa.Column('value', sa.Text(), nullable=True), sa.Column("name", sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint('name') sa.Column("value", sa.Text(), nullable=True),
) sa.PrimaryKeyConstraint("name"),
)
def downgrade() -> None: def downgrade() -> None:
# setting # setting
op.drop_table('setting') op.drop_table("setting")

View file

@ -62,11 +62,11 @@ def get_engines(config, prefix):
:returns: A dictionary of SQLAlchemy engines, with keys matching :returns: A dictionary of SQLAlchemy engines, with keys matching
those found in config. those found in config.
""" """
keys = config.get(f'{prefix}.keys', usedb=False) keys = config.get(f"{prefix}.keys", usedb=False)
if keys: if keys:
keys = parse_list(keys) keys = parse_list(keys)
else: else:
keys = ['default'] keys = ["default"]
make_engine = config.get_engine_maker() make_engine = config.get_engine_maker()
@ -75,11 +75,11 @@ def get_engines(config, prefix):
for key in keys: for key in keys:
key = key.strip() key = key.strip()
try: try:
engines[key] = make_engine(cfg, prefix=f'{key}.') engines[key] = make_engine(cfg, prefix=f"{key}.")
except KeyError: except KeyError:
if key == 'default': if key == "default":
try: try:
engines[key] = make_engine(cfg, prefix='sqlalchemy.') engines[key] = make_engine(cfg, prefix="sqlalchemy.")
except KeyError: except KeyError:
pass pass
return engines return engines
@ -99,13 +99,10 @@ def get_setting(session, name):
:returns: Setting value as string, or ``None``. :returns: Setting value as string, or ``None``.
""" """
sql = sa.text("select value from setting where name = :name") sql = sa.text("select value from setting where name = :name")
return session.execute(sql, params={'name': name}).scalar() return session.execute(sql, params={"name": name}).scalar()
def make_engine_from_config( def make_engine_from_config(config_dict, prefix="sqlalchemy.", **kwargs):
config_dict,
prefix='sqlalchemy.',
**kwargs):
""" """
Construct a new DB engine from configuration dict. Construct a new DB engine from configuration dict.
@ -141,14 +138,14 @@ def make_engine_from_config(
config_dict = dict(config_dict) config_dict = dict(config_dict)
# convert 'poolclass' arg to actual class # convert 'poolclass' arg to actual class
key = f'{prefix}poolclass' key = f"{prefix}poolclass"
if key in config_dict and 'poolclass' not in kwargs: if key in config_dict and "poolclass" not in kwargs:
kwargs['poolclass'] = load_object(config_dict.pop(key)) kwargs["poolclass"] = load_object(config_dict.pop(key))
# convert 'pool_pre_ping' arg to boolean # convert 'pool_pre_ping' arg to boolean
key = f'{prefix}pool_pre_ping' key = f"{prefix}pool_pre_ping"
if key in config_dict and 'pool_pre_ping' not in kwargs: if key in config_dict and "pool_pre_ping" not in kwargs:
kwargs['pool_pre_ping'] = parse_bool(config_dict.pop(key)) kwargs["pool_pre_ping"] = parse_bool(config_dict.pop(key))
engine = sa.engine_from_config(config_dict, prefix, **kwargs) engine = sa.engine_from_config(config_dict, prefix, **kwargs)

View file

@ -34,7 +34,7 @@ class DatabaseHandler(GenericHandler):
Base class and default implementation for the :term:`db handler`. Base class and default implementation for the :term:`db handler`.
""" """
def get_dialect(self, bind): # pylint: disable=empty-docstring def get_dialect(self, bind): # pylint: disable=empty-docstring
""" """ """ """
return bind.url.get_dialect().name return bind.url.get_dialect().name
@ -59,7 +59,7 @@ class DatabaseHandler(GenericHandler):
dialect = self.get_dialect(session.bind) dialect = self.get_dialect(session.bind)
# postgres uses "true" native sequence # postgres uses "true" native sequence
if dialect == 'postgresql': if dialect == "postgresql":
sql = f"create sequence if not exists {key}_seq" sql = f"create sequence if not exists {key}_seq"
session.execute(sa.text(sql)) session.execute(sa.text(sql))
sql = f"select nextval('{key}_seq')" sql = f"select nextval('{key}_seq')"
@ -69,8 +69,11 @@ class DatabaseHandler(GenericHandler):
# otherwise use "magic" workaround # otherwise use "magic" workaround
engine = session.bind engine = session.bind
metadata = sa.MetaData() metadata = sa.MetaData()
table = sa.Table(f'_counter_{key}', metadata, table = sa.Table(
sa.Column('value', sa.Integer(), primary_key=True)) f"_counter_{key}",
metadata,
sa.Column("value", sa.Integer(), primary_key=True),
)
table.create(engine, checkfirst=True) table.create(engine, checkfirst=True)
with engine.begin() as cxn: with engine.begin() as cxn:
result = cxn.execute(table.insert()) result = cxn.execute(table.insert())

View file

@ -49,7 +49,7 @@ from wuttjamaican.db.util import uuid_column, uuid_fk_column
from wuttjamaican.db.model.base import Base from wuttjamaican.db.model.base import Base
class Role(Base): # pylint: disable=too-few-public-methods class Role(Base): # pylint: disable=too-few-public-methods
""" """
Represents an authentication role within the system; used for Represents an authentication role within the system; used for
permission management. permission management.
@ -67,51 +67,65 @@ class Role(Base): # pylint: disable=too-few-public-methods
See also :attr:`user_refs`. See also :attr:`user_refs`.
""" """
__tablename__ = 'role'
__tablename__ = "role"
__versioned__ = {} __versioned__ = {}
uuid = uuid_column() uuid = uuid_column()
name = sa.Column(sa.String(length=100), nullable=False, unique=True, doc=""" name = sa.Column(
sa.String(length=100),
nullable=False,
unique=True,
doc="""
Name for the role. Each role must have a name, which must be Name for the role. Each role must have a name, which must be
unique. unique.
""") """,
)
notes = sa.Column(sa.Text(), nullable=True, doc=""" notes = sa.Column(
sa.Text(),
nullable=True,
doc="""
Arbitrary notes for the role. Arbitrary notes for the role.
""") """,
)
permission_refs = orm.relationship( permission_refs = orm.relationship(
'Permission', "Permission",
back_populates='role', back_populates="role",
cascade='all, delete-orphan', cascade="all, delete-orphan",
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
List of :class:`Permission` references for the role. List of :class:`Permission` references for the role.
See also :attr:`permissions`. See also :attr:`permissions`.
""") """,
)
permissions = association_proxy( permissions = association_proxy(
'permission_refs', 'permission', "permission_refs",
"permission",
creator=lambda p: Permission(permission=p), creator=lambda p: Permission(permission=p),
# TODO # TODO
# getset_factory=getset_factory, # getset_factory=getset_factory,
) )
user_refs = orm.relationship( user_refs = orm.relationship(
'UserRole', "UserRole",
back_populates='role', back_populates="role",
cascade='all, delete-orphan', cascade="all, delete-orphan",
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
List of :class:`UserRole` instances belonging to the role. List of :class:`UserRole` instances belonging to the role.
See also :attr:`users`. See also :attr:`users`.
""") """,
)
users = association_proxy( users = association_proxy(
'user_refs', 'user', "user_refs",
"user",
creator=lambda u: UserRole(user=u), creator=lambda u: UserRole(user=u),
# TODO # TODO
# getset_factory=getset_factory, # getset_factory=getset_factory,
@ -121,32 +135,38 @@ class Role(Base): # pylint: disable=too-few-public-methods
return self.name or "" return self.name or ""
class Permission(Base): # pylint: disable=too-few-public-methods class Permission(Base): # pylint: disable=too-few-public-methods
""" """
Represents a permission granted to a role. Represents a permission granted to a role.
""" """
__tablename__ = 'permission'
__tablename__ = "permission"
__versioned__ = {} __versioned__ = {}
role_uuid = uuid_fk_column('role.uuid', primary_key=True, nullable=False) role_uuid = uuid_fk_column("role.uuid", primary_key=True, nullable=False)
role = orm.relationship( role = orm.relationship(
Role, Role,
back_populates='permission_refs', back_populates="permission_refs",
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
Reference to the :class:`Role` for which the permission is Reference to the :class:`Role` for which the permission is
granted. granted.
""") """,
)
permission = sa.Column(sa.String(length=254), primary_key=True, doc=""" permission = sa.Column(
sa.String(length=254),
primary_key=True,
doc="""
Key (name) of the permission which is granted. Key (name) of the permission which is granted.
""") """,
)
def __str__(self): def __str__(self):
return self.permission or "" return self.permission or ""
class User(Base): # pylint: disable=too-few-public-methods class User(Base): # pylint: disable=too-few-public-methods
""" """
Represents a user of the system. Represents a user of the system.
@ -159,70 +179,93 @@ class User(Base): # pylint: disable=too-few-public-methods
See also :attr:`role_refs`. See also :attr:`role_refs`.
""" """
__tablename__ = 'user'
__tablename__ = "user"
__versioned__ = {} __versioned__ = {}
uuid = uuid_column() uuid = uuid_column()
username = sa.Column(sa.String(length=25), nullable=False, unique=True, doc=""" username = sa.Column(
sa.String(length=25),
nullable=False,
unique=True,
doc="""
Account username. This is required and must be unique. Account username. This is required and must be unique.
""") """,
)
password = sa.Column(sa.String(length=60), nullable=True, doc=""" password = sa.Column(
sa.String(length=60),
nullable=True,
doc="""
Hashed password for login. (The raw password is not stored.) Hashed password for login. (The raw password is not stored.)
""") """,
)
person_uuid = uuid_fk_column('person.uuid', nullable=True) person_uuid = uuid_fk_column("person.uuid", nullable=True)
person = orm.relationship( person = orm.relationship(
'Person', "Person",
# TODO: seems like this is not needed? # TODO: seems like this is not needed?
# uselist=False, # uselist=False,
back_populates='users', back_populates="users",
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
Reference to the :class:`~wuttjamaican.db.model.base.Person` Reference to the :class:`~wuttjamaican.db.model.base.Person`
whose user account this is. whose user account this is.
""") """,
)
active = sa.Column(sa.Boolean(), nullable=False, default=True, doc=""" active = sa.Column(
sa.Boolean(),
nullable=False,
default=True,
doc="""
Flag indicating whether the user account is "active" - it is Flag indicating whether the user account is "active" - it is
``True`` by default. ``True`` by default.
The default auth logic will prevent login for "inactive" user accounts. The default auth logic will prevent login for "inactive" user accounts.
""") """,
)
prevent_edit = sa.Column(sa.Boolean(), nullable=True, doc=""" prevent_edit = sa.Column(
sa.Boolean(),
nullable=True,
doc="""
If set, this user account can only be edited by root. User cannot If set, this user account can only be edited by root. User cannot
change their own password. change their own password.
""") """,
)
role_refs = orm.relationship( role_refs = orm.relationship(
'UserRole', "UserRole",
back_populates='user', back_populates="user",
cascade='all, delete-orphan', cascade="all, delete-orphan",
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
List of :class:`UserRole` instances belonging to the user. List of :class:`UserRole` instances belonging to the user.
See also :attr:`roles`. See also :attr:`roles`.
""") """,
)
roles = association_proxy( roles = association_proxy(
'role_refs', 'role', "role_refs",
"role",
creator=lambda r: UserRole(role=r), creator=lambda r: UserRole(role=r),
# TODO # TODO
# getset_factory=getset_factory, # getset_factory=getset_factory,
) )
api_tokens = orm.relationship( api_tokens = orm.relationship(
'UserAPIToken', "UserAPIToken",
back_populates='user', back_populates="user",
order_by='UserAPIToken.created', order_by="UserAPIToken.created",
cascade='all, delete-orphan', cascade="all, delete-orphan",
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
List of :class:`UserAPIToken` instances belonging to the user. List of :class:`UserAPIToken` instances belonging to the user.
""") """,
)
def __str__(self): def __str__(self):
if self.person: if self.person:
@ -232,59 +275,72 @@ class User(Base): # pylint: disable=too-few-public-methods
return self.username or "" return self.username or ""
class UserRole(Base): # pylint: disable=too-few-public-methods class UserRole(Base): # pylint: disable=too-few-public-methods
""" """
Represents the association between a user and a role; i.e. the Represents the association between a user and a role; i.e. the
user "belongs" or "is assigned" to the role. user "belongs" or "is assigned" to the role.
""" """
__tablename__ = 'user_x_role'
__tablename__ = "user_x_role"
__versioned__ = {} __versioned__ = {}
uuid = uuid_column() uuid = uuid_column()
user_uuid = uuid_fk_column('user.uuid', nullable=False) user_uuid = uuid_fk_column("user.uuid", nullable=False)
user = orm.relationship( user = orm.relationship(
User, User,
back_populates='role_refs', back_populates="role_refs",
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
Reference to the :class:`User` involved. Reference to the :class:`User` involved.
""") """,
)
role_uuid = uuid_fk_column('role.uuid', nullable=False) role_uuid = uuid_fk_column("role.uuid", nullable=False)
role = orm.relationship( role = orm.relationship(
Role, Role,
back_populates='user_refs', back_populates="user_refs",
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
Reference to the :class:`Role` involved. Reference to the :class:`Role` involved.
""") """,
)
class UserAPIToken(Base): # pylint: disable=too-few-public-methods class UserAPIToken(Base): # pylint: disable=too-few-public-methods
""" """
User authentication token for use with HTTP API User authentication token for use with HTTP API
""" """
__tablename__ = 'user_api_token'
__tablename__ = "user_api_token"
uuid = uuid_column() uuid = uuid_column()
user_uuid = uuid_fk_column('user.uuid', nullable=False) user_uuid = uuid_fk_column("user.uuid", nullable=False)
user = orm.relationship( user = orm.relationship(
User, User,
back_populates='api_tokens', back_populates="api_tokens",
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
Reference to the :class:`User` whose token this is. Reference to the :class:`User` whose token this is.
""") """,
)
description = sa.Column(sa.String(length=255), nullable=False, doc=""" description = sa.Column(
sa.String(length=255),
nullable=False,
doc="""
Description of the token. Description of the token.
""") """,
)
token_string = sa.Column(sa.String(length=255), nullable=False, doc=""" token_string = sa.Column(
sa.String(length=255),
nullable=False,
doc="""
Raw token string, to be used by API clients. Raw token string, to be used by API clients.
""") """,
)
created = sa.Column( created = sa.Column(
sa.DateTime(timezone=True), sa.DateTime(timezone=True),
@ -292,7 +348,8 @@ class UserAPIToken(Base): # pylint: disable=too-few-public-methods
default=datetime.datetime.now, default=datetime.datetime.now,
doc=""" doc="""
Date/time when the token was created. Date/time when the token was created.
""") """,
)
def __str__(self): def __str__(self):
return self.description or "" return self.description or ""

View file

@ -39,7 +39,7 @@ from sqlalchemy.ext.associationproxy import association_proxy
from wuttjamaican.db.util import naming_convention, ModelBase, uuid_column from wuttjamaican.db.util import naming_convention, ModelBase, uuid_column
class WuttaModelBase(ModelBase): # pylint: disable=too-few-public-methods class WuttaModelBase(ModelBase): # pylint: disable=too-few-public-methods
""" """
Base class for data models, from which :class:`Base` inherits. Base class for data models, from which :class:`Base` inherits.
@ -113,8 +113,8 @@ class WuttaModelBase(ModelBase): # pylint: disable=too-few-public-methods
print(user.favorite_color) print(user.favorite_color)
""" """
proxy = association_proxy( proxy = association_proxy(
extension, proxy_name or name, extension, proxy_name or name, creator=lambda value: cls(**{name: value})
creator=lambda value: cls(**{name: value})) )
setattr(main_class, name, proxy) setattr(main_class, name, proxy)
@ -123,19 +123,29 @@ metadata = sa.MetaData(naming_convention=naming_convention)
Base = orm.declarative_base(metadata=metadata, cls=WuttaModelBase) Base = orm.declarative_base(metadata=metadata, cls=WuttaModelBase)
class Setting(Base): # pylint: disable=too-few-public-methods class Setting(Base): # pylint: disable=too-few-public-methods
""" """
Represents a :term:`config setting`. Represents a :term:`config setting`.
""" """
__tablename__ = 'setting'
name = sa.Column(sa.String(length=255), primary_key=True, nullable=False, doc=""" __tablename__ = "setting"
name = sa.Column(
sa.String(length=255),
primary_key=True,
nullable=False,
doc="""
Unique name for the setting. Unique name for the setting.
""") """,
)
value = sa.Column(sa.Text(), nullable=True, doc=""" value = sa.Column(
sa.Text(),
nullable=True,
doc="""
String value for the setting. String value for the setting.
""") """,
)
def __str__(self): def __str__(self):
return self.name or "" return self.name or ""
@ -153,36 +163,54 @@ class Person(Base):
But this table could also be used as a basis for a Customer or But this table could also be used as a basis for a Customer or
Employee relationship etc. Employee relationship etc.
""" """
__tablename__ = 'person'
__tablename__ = "person"
__versioned__ = {} __versioned__ = {}
uuid = uuid_column() uuid = uuid_column()
full_name = sa.Column(sa.String(length=100), nullable=False, doc=""" full_name = sa.Column(
sa.String(length=100),
nullable=False,
doc="""
Full name for the person. Note that this is *required*. Full name for the person. Note that this is *required*.
""") """,
)
first_name = sa.Column(sa.String(length=50), nullable=True, doc=""" first_name = sa.Column(
sa.String(length=50),
nullable=True,
doc="""
The person's first name. The person's first name.
""") """,
)
middle_name = sa.Column(sa.String(length=50), nullable=True, doc=""" middle_name = sa.Column(
sa.String(length=50),
nullable=True,
doc="""
The person's middle name or initial. The person's middle name or initial.
""") """,
)
last_name = sa.Column(sa.String(length=50), nullable=True, doc=""" last_name = sa.Column(
sa.String(length=50),
nullable=True,
doc="""
The person's last name. The person's last name.
""") """,
)
users = orm.relationship( users = orm.relationship(
'User', "User",
back_populates='person', back_populates="person",
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
List of :class:`~wuttjamaican.db.model.auth.User` accounts for List of :class:`~wuttjamaican.db.model.auth.User` accounts for
the person. Typically there is only one user account per the person. Typically there is only one user account per
person, but technically multiple are supported. person, but technically multiple are supported.
""") """,
)
def __str__(self): def __str__(self):
return self.full_name or "" return self.full_name or ""

View file

@ -186,7 +186,7 @@ class BatchMixin:
""" """
@declared_attr @declared_attr
def __table_args__(cls): # pylint: disable=no-self-argument def __table_args__(cls): # pylint: disable=no-self-argument
return cls.__default_table_args__() return cls.__default_table_args__()
@classmethod @classmethod
@ -196,12 +196,12 @@ class BatchMixin:
@classmethod @classmethod
def __batch_table_args__(cls): def __batch_table_args__(cls):
return ( return (
sa.ForeignKeyConstraint(['created_by_uuid'], ['user.uuid']), sa.ForeignKeyConstraint(["created_by_uuid"], ["user.uuid"]),
sa.ForeignKeyConstraint(['executed_by_uuid'], ['user.uuid']), sa.ForeignKeyConstraint(["executed_by_uuid"], ["user.uuid"]),
) )
@declared_attr @declared_attr
def batch_type(cls): # pylint: disable=empty-docstring,no-self-argument def batch_type(cls): # pylint: disable=empty-docstring,no-self-argument
""" """ """ """
return cls.__tablename__ return cls.__tablename__
@ -212,42 +212,44 @@ class BatchMixin:
notes = sa.Column(sa.Text(), nullable=True) notes = sa.Column(sa.Text(), nullable=True)
row_count = sa.Column(sa.Integer(), nullable=True, default=0) row_count = sa.Column(sa.Integer(), nullable=True, default=0)
STATUS_INCOMPLETE = 1 STATUS_INCOMPLETE = 1
STATUS_EXECUTABLE = 2 STATUS_EXECUTABLE = 2
STATUS = { STATUS = {
STATUS_INCOMPLETE : "incomplete", STATUS_INCOMPLETE: "incomplete",
STATUS_EXECUTABLE : "executable", STATUS_EXECUTABLE: "executable",
} }
status_code = sa.Column(sa.Integer(), nullable=True) status_code = sa.Column(sa.Integer(), nullable=True)
status_text = sa.Column(sa.String(length=255), nullable=True) status_text = sa.Column(sa.String(length=255), nullable=True)
created = sa.Column(sa.DateTime(timezone=True), nullable=False, created = sa.Column(
default=datetime.datetime.now) sa.DateTime(timezone=True), nullable=False, default=datetime.datetime.now
)
created_by_uuid = sa.Column(UUID(), nullable=False) created_by_uuid = sa.Column(UUID(), nullable=False)
@declared_attr @declared_attr
def created_by(cls): # pylint: disable=empty-docstring,no-self-argument def created_by(cls): # pylint: disable=empty-docstring,no-self-argument
""" """ """ """
return orm.relationship( return orm.relationship(
User, User,
primaryjoin=lambda: User.uuid == cls.created_by_uuid, primaryjoin=lambda: User.uuid == cls.created_by_uuid,
foreign_keys=lambda: [cls.created_by_uuid], foreign_keys=lambda: [cls.created_by_uuid],
cascade_backrefs=False) cascade_backrefs=False,
)
executed = sa.Column(sa.DateTime(timezone=True), nullable=True) executed = sa.Column(sa.DateTime(timezone=True), nullable=True)
executed_by_uuid = sa.Column(UUID(), nullable=True) executed_by_uuid = sa.Column(UUID(), nullable=True)
@declared_attr @declared_attr
def executed_by(cls): # pylint: disable=empty-docstring,no-self-argument def executed_by(cls): # pylint: disable=empty-docstring,no-self-argument
""" """ """ """
return orm.relationship( return orm.relationship(
User, User,
primaryjoin=lambda: User.uuid == cls.executed_by_uuid, primaryjoin=lambda: User.uuid == cls.executed_by_uuid,
foreign_keys=lambda: [cls.executed_by_uuid], foreign_keys=lambda: [cls.executed_by_uuid],
cascade_backrefs=False) cascade_backrefs=False,
)
def __repr__(self): def __repr__(self):
cls = self.__class__.__name__ cls = self.__class__.__name__
@ -266,11 +268,11 @@ class BatchMixin:
print(batch.id_str) # => '00000042' print(batch.id_str) # => '00000042'
""" """
if self.id: if self.id:
return f'{self.id:08d}' return f"{self.id:08d}"
return None return None
class BatchRowMixin: # pylint: disable=too-few-public-methods class BatchRowMixin: # pylint: disable=too-few-public-methods
""" """
Mixin base class for :term:`data models <data model>` which Mixin base class for :term:`data models <data model>` which
represent a :term:`batch row`. represent a :term:`batch row`.
@ -381,7 +383,7 @@ class BatchRowMixin: # pylint: disable=too-few-public-methods
uuid = uuid_column() uuid = uuid_column()
@declared_attr @declared_attr
def __table_args__(cls): # pylint: disable=no-self-argument def __table_args__(cls): # pylint: disable=no-self-argument
return cls.__default_table_args__() return cls.__default_table_args__()
@classmethod @classmethod
@ -391,14 +393,12 @@ class BatchRowMixin: # pylint: disable=too-few-public-methods
@classmethod @classmethod
def __batchrow_table_args__(cls): def __batchrow_table_args__(cls):
batch_table = cls.__batch_class__.__tablename__ batch_table = cls.__batch_class__.__tablename__
return ( return (sa.ForeignKeyConstraint(["batch_uuid"], [f"{batch_table}.uuid"]),)
sa.ForeignKeyConstraint(['batch_uuid'], [f'{batch_table}.uuid']),
)
batch_uuid = sa.Column(UUID(), nullable=False) batch_uuid = sa.Column(UUID(), nullable=False)
@declared_attr @declared_attr
def batch(cls): # pylint: disable=empty-docstring,no-self-argument def batch(cls): # pylint: disable=empty-docstring,no-self-argument
""" """ """ """
batch_class = cls.__batch_class__ batch_class = cls.__batch_class__
row_class = cls row_class = cls
@ -409,16 +409,16 @@ class BatchRowMixin: # pylint: disable=too-few-public-methods
batch_class.rows = orm.relationship( batch_class.rows = orm.relationship(
row_class, row_class,
order_by=lambda: row_class.sequence, order_by=lambda: row_class.sequence,
collection_class=ordering_list('sequence', count_from=1), collection_class=ordering_list("sequence", count_from=1),
cascade='all, delete-orphan', cascade="all, delete-orphan",
cascade_backrefs=False, cascade_backrefs=False,
back_populates='batch') back_populates="batch",
)
# now, here's the `BatchRow.batch` # now, here's the `BatchRow.batch`
return orm.relationship( return orm.relationship(
batch_class, batch_class, back_populates="rows", cascade_backrefs=False
back_populates='rows', )
cascade_backrefs=False)
sequence = sa.Column(sa.Integer(), nullable=False) sequence = sa.Column(sa.Integer(), nullable=False)
@ -427,6 +427,9 @@ class BatchRowMixin: # pylint: disable=too-few-public-methods
status_code = sa.Column(sa.Integer(), nullable=True) status_code = sa.Column(sa.Integer(), nullable=True)
status_text = sa.Column(sa.String(length=255), nullable=True) status_text = sa.Column(sa.String(length=255), nullable=True)
modified = sa.Column(sa.DateTime(timezone=True), nullable=True, modified = sa.Column(
default=datetime.datetime.now, sa.DateTime(timezone=True),
onupdate=datetime.datetime.now) nullable=True,
default=datetime.datetime.now,
onupdate=datetime.datetime.now,
)

View file

@ -35,63 +35,95 @@ from wuttjamaican.util import make_true_uuid
from wuttjamaican.db.model.base import Base from wuttjamaican.db.model.base import Base
class Upgrade(Base): # pylint: disable=too-few-public-methods class Upgrade(Base): # pylint: disable=too-few-public-methods
""" """
Represents an app upgrade. Represents an app upgrade.
""" """
__tablename__ = 'upgrade'
__tablename__ = "upgrade"
uuid = uuid_column(UUID(), default=make_true_uuid) uuid = uuid_column(UUID(), default=make_true_uuid)
created = sa.Column(sa.DateTime(timezone=True), nullable=False, created = sa.Column(
default=datetime.datetime.now, doc=""" sa.DateTime(timezone=True),
nullable=False,
default=datetime.datetime.now,
doc="""
When the upgrade record was created. When the upgrade record was created.
""") """,
)
created_by_uuid = uuid_fk_column('user.uuid', nullable=False) created_by_uuid = uuid_fk_column("user.uuid", nullable=False)
created_by = orm.relationship( created_by = orm.relationship(
'User', "User",
foreign_keys=[created_by_uuid], foreign_keys=[created_by_uuid],
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
:class:`~wuttjamaican.db.model.auth.User` who created the :class:`~wuttjamaican.db.model.auth.User` who created the
upgrade record. upgrade record.
""") """,
)
description = sa.Column(sa.String(length=255), nullable=False, doc=""" description = sa.Column(
sa.String(length=255),
nullable=False,
doc="""
Basic (identifying) description for the upgrade. Basic (identifying) description for the upgrade.
""") """,
)
notes = sa.Column(sa.Text(), nullable=True, doc=""" notes = sa.Column(
sa.Text(),
nullable=True,
doc="""
Notes for the upgrade. Notes for the upgrade.
""") """,
)
executing = sa.Column(sa.Boolean(), nullable=False, default=False, doc=""" executing = sa.Column(
sa.Boolean(),
nullable=False,
default=False,
doc="""
Whether or not the upgrade is currently being performed. Whether or not the upgrade is currently being performed.
""") """,
)
status = sa.Column(sa.Enum(UpgradeStatus), nullable=False, doc=""" status = sa.Column(
sa.Enum(UpgradeStatus),
nullable=False,
doc="""
Current status for the upgrade. This field uses an enum, Current status for the upgrade. This field uses an enum,
:class:`~wuttjamaican.enum.UpgradeStatus`. :class:`~wuttjamaican.enum.UpgradeStatus`.
""") """,
)
executed = sa.Column(sa.DateTime(timezone=True), nullable=True, doc=""" executed = sa.Column(
sa.DateTime(timezone=True),
nullable=True,
doc="""
When the upgrade was executed. When the upgrade was executed.
""") """,
)
executed_by_uuid = uuid_fk_column('user.uuid', nullable=True) executed_by_uuid = uuid_fk_column("user.uuid", nullable=True)
executed_by = orm.relationship( executed_by = orm.relationship(
'User', "User",
foreign_keys=[executed_by_uuid], foreign_keys=[executed_by_uuid],
cascade_backrefs=False, cascade_backrefs=False,
doc=""" doc="""
:class:`~wuttjamaican.db.model.auth.User` who executed the :class:`~wuttjamaican.db.model.auth.User` who executed the
upgrade. upgrade.
""") """,
)
exit_code = sa.Column(sa.Integer(), nullable=True, doc=""" exit_code = sa.Column(
sa.Integer(),
nullable=True,
doc="""
Exit code for the upgrade execution process, if applicable. Exit code for the upgrade execution process, if applicable.
""") """,
)
def __str__(self): def __str__(self):
return str(self.description or "") return str(self.description or "")

View file

@ -38,7 +38,7 @@ from sqlalchemy import orm
Session = orm.sessionmaker() Session = orm.sessionmaker()
class short_session: # pylint: disable=invalid-name class short_session: # pylint: disable=invalid-name
""" """
Context manager for a short-lived database session. Context manager for a short-lived database session.

View file

@ -38,28 +38,27 @@ from wuttjamaican.util import make_true_uuid
# nb. this convention comes from upstream docs # nb. this convention comes from upstream docs
# https://docs.sqlalchemy.org/en/14/core/constraints.html#constraint-naming-conventions # https://docs.sqlalchemy.org/en/14/core/constraints.html#constraint-naming-conventions
naming_convention = { naming_convention = {
'ix': 'ix_%(column_0_label)s', "ix": "ix_%(column_0_label)s",
'uq': 'uq_%(table_name)s_%(column_0_name)s', "uq": "uq_%(table_name)s_%(column_0_name)s",
'ck': 'ck_%(table_name)s_%(constraint_name)s', "ck": "ck_%(table_name)s_%(constraint_name)s",
'fk': 'fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s', "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
'pk': 'pk_%(table_name)s', "pk": "pk_%(table_name)s",
} }
SA2 = True SA2 = True
if Version(version('SQLAlchemy')) < Version('2'): # pragma: no cover if Version(version("SQLAlchemy")) < Version("2"): # pragma: no cover
SA2 = False SA2 = False
class ModelBase: # pylint: disable=empty-docstring class ModelBase: # pylint: disable=empty-docstring
""" """ """ """
def __iter__(self): def __iter__(self):
# nb. we override this to allow for `dict(self)` # nb. we override this to allow for `dict(self)`
state = sa.inspect(self) state = sa.inspect(self)
fields = [attr.key for attr in state.attrs] fields = [attr.key for attr in state.attrs]
return iter([(field, getattr(self, field)) return iter([(field, getattr(self, field)) for field in fields])
for field in fields])
def __getitem__(self, key): def __getitem__(self, key):
# nb. we override this to allow for `x = self['field']` # nb. we override this to allow for `x = self['field']`
@ -69,7 +68,9 @@ class ModelBase: # pylint: disable=empty-docstring
raise KeyError(f"model instance has no attr with key: {key}") raise KeyError(f"model instance has no attr with key: {key}")
class UUID(sa.types.TypeDecorator): # pylint: disable=abstract-method,too-many-ancestors class UUID(
sa.types.TypeDecorator
): # pylint: disable=abstract-method,too-many-ancestors
""" """
Platform-independent UUID type. Platform-independent UUID type.
@ -80,17 +81,18 @@ class UUID(sa.types.TypeDecorator): # pylint: disable=abstract-method,too-many-a
documentation documentation
<https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type>`_. <https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type>`_.
""" """
impl = sa.CHAR impl = sa.CHAR
cache_ok = True cache_ok = True
""" """ # nb. suppress sphinx autodoc for cache_ok """ """ # nb. suppress sphinx autodoc for cache_ok
def load_dialect_impl(self, dialect): # pylint: disable=empty-docstring def load_dialect_impl(self, dialect): # pylint: disable=empty-docstring
""" """ """ """
if dialect.name == "postgresql": if dialect.name == "postgresql":
return dialect.type_descriptor(PGUUID()) return dialect.type_descriptor(PGUUID())
return dialect.type_descriptor(sa.CHAR(32)) return dialect.type_descriptor(sa.CHAR(32))
def process_bind_param(self, value, dialect): # pylint: disable=empty-docstring def process_bind_param(self, value, dialect): # pylint: disable=empty-docstring
""" """ """ """
if value is None: if value is None:
return value return value
@ -104,7 +106,9 @@ class UUID(sa.types.TypeDecorator): # pylint: disable=abstract-method,too-many-a
# hexstring # hexstring
return f"{value.int:032x}" return f"{value.int:032x}"
def process_result_value(self, value, dialect): # pylint: disable=unused-argument,empty-docstring def process_result_value(
self, value, dialect
): # pylint: disable=unused-argument,empty-docstring
""" """ """ """
if value is None: if value is None:
return value return value
@ -119,9 +123,9 @@ def uuid_column(*args, **kwargs):
""" """
if not args: if not args:
args = (UUID(),) args = (UUID(),)
kwargs.setdefault('primary_key', True) kwargs.setdefault("primary_key", True)
kwargs.setdefault('nullable', False) kwargs.setdefault("nullable", False)
kwargs.setdefault('default', make_true_uuid) kwargs.setdefault("default", make_true_uuid)
return sa.Column(*args, **kwargs) return sa.Column(*args, **kwargs)
@ -149,8 +153,7 @@ def make_topo_sortkey(model):
containing model classes. containing model classes.
""" """
metadata = model.Base.metadata metadata = model.Base.metadata
tables = {table.name: i tables = {table.name: i for i, table in enumerate(metadata.sorted_tables, 1)}
for i, table in enumerate(metadata.sorted_tables, 1)}
def sortkey(name): def sortkey(name):
cls = getattr(model, name) cls = getattr(model, name)

View file

@ -40,7 +40,7 @@ from wuttjamaican.util import resource_path
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class EmailSetting: # pylint: disable=too-few-public-methods class EmailSetting: # pylint: disable=too-few-public-methods
""" """
Base class for all :term:`email settings <email setting>`. Base class for all :term:`email settings <email setting>`.
@ -93,6 +93,7 @@ class EmailSetting: # pylint: disable=too-few-public-methods
rendered with the email context. But in most cases that rendered with the email context. But in most cases that
feature can be ignored, and this will be a simple string. feature can be ignored, and this will be a simple string.
""" """
default_subject = None default_subject = None
def __init__(self, config): def __init__(self, config):
@ -109,7 +110,7 @@ class EmailSetting: # pylint: disable=too-few-public-methods
return {} return {}
class Message: # pylint: disable=too-many-instance-attributes class Message: # pylint: disable=too-many-instance-attributes
""" """
Represents an email message to be sent. Represents an email message to be sent.
@ -172,17 +173,17 @@ class Message: # pylint: disable=too-many-instance-attributes
""" """
def __init__( def __init__(
self, self,
key=None, key=None,
sender=None, sender=None,
subject=None, subject=None,
to=None, to=None,
cc=None, cc=None,
bcc=None, bcc=None,
replyto=None, replyto=None,
txt_body=None, txt_body=None,
html_body=None, html_body=None,
attachments=None, attachments=None,
): ):
self.key = key self.key = key
self.sender = sender self.sender = sender
@ -195,7 +196,7 @@ class Message: # pylint: disable=too-many-instance-attributes
self.html_body = html_body self.html_body = html_body
self.attachments = attachments or [] self.attachments = attachments or []
def get_recips(self, value): # pylint: disable=empty-docstring def get_recips(self, value): # pylint: disable=empty-docstring
""" """ """ """
if value: if value:
if isinstance(value, str): if isinstance(value, str):
@ -216,15 +217,15 @@ class Message: # pylint: disable=too-many-instance-attributes
msg = None msg = None
if self.txt_body and self.html_body: if self.txt_body and self.html_body:
txt = MIMEText(self.txt_body, _charset='utf_8') txt = MIMEText(self.txt_body, _charset="utf_8")
html = MIMEText(self.html_body, _subtype='html', _charset='utf_8') html = MIMEText(self.html_body, _subtype="html", _charset="utf_8")
msg = MIMEMultipart(_subtype='alternative', _subparts=[txt, html]) msg = MIMEMultipart(_subtype="alternative", _subparts=[txt, html])
elif self.txt_body: elif self.txt_body:
msg = MIMEText(self.txt_body, _charset='utf_8') msg = MIMEText(self.txt_body, _charset="utf_8")
elif self.html_body: elif self.html_body:
msg = MIMEText(self.html_body, 'html', _charset='utf_8') msg = MIMEText(self.html_body, "html", _charset="utf_8")
if not msg: if not msg:
raise ValueError("message has no body parts") raise ValueError("message has no body parts")
@ -232,27 +233,29 @@ class Message: # pylint: disable=too-many-instance-attributes
if self.attachments: if self.attachments:
for attachment in self.attachments: for attachment in self.attachments:
if isinstance(attachment, str): if isinstance(attachment, str):
raise ValueError("must specify valid MIME attachments; this class cannot " raise ValueError(
"auto-create them from file path etc.") "must specify valid MIME attachments; this class cannot "
msg = MIMEMultipart(_subtype='mixed', _subparts=[msg] + self.attachments) "auto-create them from file path etc."
)
msg = MIMEMultipart(_subtype="mixed", _subparts=[msg] + self.attachments)
msg['Subject'] = self.subject msg["Subject"] = self.subject
msg['From'] = self.sender msg["From"] = self.sender
for addr in self.to: for addr in self.to:
msg['To'] = addr msg["To"] = addr
for addr in self.cc: for addr in self.cc:
msg['Cc'] = addr msg["Cc"] = addr
for addr in self.bcc: for addr in self.bcc:
msg['Bcc'] = addr msg["Bcc"] = addr
if self.replyto: if self.replyto:
msg.add_header('Reply-To', self.replyto) msg.add_header("Reply-To", self.replyto)
return msg.as_string() return msg.as_string()
class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
""" """
Base class and default implementation for the :term:`email Base class and default implementation for the :term:`email
handler`. handler`.
@ -272,13 +275,13 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# prefer configured list of template lookup paths, if set # prefer configured list of template lookup paths, if set
templates = self.config.get_list(f'{self.config.appname}.email.templates') templates = self.config.get_list(f"{self.config.appname}.email.templates")
if not templates: if not templates:
# otherwise use all available paths, from app providers # otherwise use all available paths, from app providers
available = [] available = []
for provider in self.app.providers.values(): for provider in self.app.providers.values():
if hasattr(provider, 'email_templates'): if hasattr(provider, "email_templates"):
templates = provider.email_templates templates = provider.email_templates
if isinstance(templates, str): if isinstance(templates, str):
templates = [templates] templates = [templates]
@ -292,10 +295,12 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
# will use these lookups from now on # will use these lookups from now on
self.txt_templates = TemplateLookup(directories=templates) self.txt_templates = TemplateLookup(directories=templates)
self.html_templates = TemplateLookup(directories=templates, self.html_templates = TemplateLookup(
# nb. escape HTML special chars directories=templates,
# TODO: sounds great but i forget why? # nb. escape HTML special chars
default_filters=['h']) # TODO: sounds great but i forget why?
default_filters=["h"],
)
def get_email_modules(self): def get_email_modules(self):
""" """
@ -309,7 +314,7 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
:meth:`~wuttjamaican.app.GenericHandler.get_provider_modules()` :meth:`~wuttjamaican.app.GenericHandler.get_provider_modules()`
under the hood, for ``email`` module type. under the hood, for ``email`` module type.
""" """
return self.get_provider_modules('email') return self.get_provider_modules("email")
def get_email_settings(self): def get_email_settings(self):
""" """
@ -319,14 +324,16 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
This calls :meth:`get_email_modules()` and for each module, it This calls :meth:`get_email_modules()` and for each module, it
discovers all the email settings it contains. discovers all the email settings it contains.
""" """
if not hasattr(self, '_email_settings'): if not hasattr(self, "_email_settings"):
self._email_settings = {} self._email_settings = {}
for module in self.get_email_modules(): for module in self.get_email_modules():
for name in dir(module): for name in dir(module):
obj = getattr(module, name) obj = getattr(module, name)
if (isinstance(obj, type) if (
isinstance(obj, type)
and obj is not EmailSetting and obj is not EmailSetting
and issubclass(obj, EmailSetting)): and issubclass(obj, EmailSetting)
):
self._email_settings[obj.__name__] = obj self._email_settings[obj.__name__] = obj
return self._email_settings return self._email_settings
@ -400,21 +407,23 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
* :meth:`get_auto_html_body()` * :meth:`get_auto_html_body()`
""" """
context = context or {} context = context or {}
kwargs['key'] = key kwargs["key"] = key
if 'sender' not in kwargs: if "sender" not in kwargs:
kwargs['sender'] = self.get_auto_sender(key) kwargs["sender"] = self.get_auto_sender(key)
if 'subject' not in kwargs: if "subject" not in kwargs:
kwargs['subject'] = self.get_auto_subject(key, context, default=default_subject) kwargs["subject"] = self.get_auto_subject(
if 'to' not in kwargs: key, context, default=default_subject
kwargs['to'] = self.get_auto_to(key) )
if 'cc' not in kwargs: if "to" not in kwargs:
kwargs['cc'] = self.get_auto_cc(key) kwargs["to"] = self.get_auto_to(key)
if 'bcc' not in kwargs: if "cc" not in kwargs:
kwargs['bcc'] = self.get_auto_bcc(key) kwargs["cc"] = self.get_auto_cc(key)
if 'txt_body' not in kwargs: if "bcc" not in kwargs:
kwargs['txt_body'] = self.get_auto_txt_body(key, context) kwargs["bcc"] = self.get_auto_bcc(key)
if 'html_body' not in kwargs: if "txt_body" not in kwargs:
kwargs['html_body'] = self.get_auto_html_body(key, context) kwargs["txt_body"] = self.get_auto_txt_body(key, context)
if "html_body" not in kwargs:
kwargs["html_body"] = self.get_auto_html_body(key, context)
return self.make_message(**kwargs) return self.make_message(**kwargs)
def get_auto_sender(self, key): def get_auto_sender(self, key):
@ -424,13 +433,14 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
message, as determined by config. message, as determined by config.
""" """
# prefer configured sender specific to key # prefer configured sender specific to key
sender = self.config.get(f'{self.config.appname}.email.{key}.sender') sender = self.config.get(f"{self.config.appname}.email.{key}.sender")
if sender: if sender:
return sender return sender
# fall back to global default # fall back to global default
return self.config.get(f'{self.config.appname}.email.default.sender', return self.config.get(
default='root@localhost') f"{self.config.appname}.email.default.sender", default="root@localhost"
)
def get_auto_replyto(self, key): def get_auto_replyto(self, key):
""" """
@ -438,14 +448,16 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
address for a message, as determined by config. address for a message, as determined by config.
""" """
# prefer configured replyto specific to key # prefer configured replyto specific to key
replyto = self.config.get(f'{self.config.appname}.email.{key}.replyto') replyto = self.config.get(f"{self.config.appname}.email.{key}.replyto")
if replyto: if replyto:
return replyto return replyto
# fall back to global default, if present # fall back to global default, if present
return self.config.get(f'{self.config.appname}.email.default.replyto') return self.config.get(f"{self.config.appname}.email.default.replyto")
def get_auto_subject(self, key, context=None, rendered=True, setting=None, default=None): def get_auto_subject(
self, key, context=None, rendered=True, setting=None, default=None
):
""" """
Returns automatic :attr:`~wuttjamaican.email.Message.subject` Returns automatic :attr:`~wuttjamaican.email.Message.subject`
line for a message, as determined by config. line for a message, as determined by config.
@ -501,7 +513,7 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
:returns: Final subject template, as raw text. :returns: Final subject template, as raw text.
""" """
# prefer configured subject specific to key # prefer configured subject specific to key
template = self.config.get(f'{self.config.appname}.email.{key}.subject') template = self.config.get(f"{self.config.appname}.email.{key}.subject")
if template: if template:
return template return template
@ -516,44 +528,47 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
return setting.default_subject return setting.default_subject
# fall back to global default # fall back to global default
return self.config.get(f'{self.config.appname}.email.default.subject', return self.config.get(
default=self.universal_subject) f"{self.config.appname}.email.default.subject",
default=self.universal_subject,
)
def get_auto_to(self, key): def get_auto_to(self, key):
""" """
Returns automatic :attr:`~wuttjamaican.email.Message.to` Returns automatic :attr:`~wuttjamaican.email.Message.to`
recipient address(es) for a message, as determined by config. recipient address(es) for a message, as determined by config.
""" """
return self.get_auto_recips(key, 'to') return self.get_auto_recips(key, "to")
def get_auto_cc(self, key): def get_auto_cc(self, key):
""" """
Returns automatic :attr:`~wuttjamaican.email.Message.cc` Returns automatic :attr:`~wuttjamaican.email.Message.cc`
recipient address(es) for a message, as determined by config. recipient address(es) for a message, as determined by config.
""" """
return self.get_auto_recips(key, 'cc') return self.get_auto_recips(key, "cc")
def get_auto_bcc(self, key): def get_auto_bcc(self, key):
""" """
Returns automatic :attr:`~wuttjamaican.email.Message.bcc` Returns automatic :attr:`~wuttjamaican.email.Message.bcc`
recipient address(es) for a message, as determined by config. recipient address(es) for a message, as determined by config.
""" """
return self.get_auto_recips(key, 'bcc') return self.get_auto_recips(key, "bcc")
def get_auto_recips(self, key, typ): # pylint: disable=empty-docstring def get_auto_recips(self, key, typ): # pylint: disable=empty-docstring
""" """ """ """
typ = typ.lower() typ = typ.lower()
if typ not in ('to', 'cc', 'bcc'): if typ not in ("to", "cc", "bcc"):
raise ValueError("requested type not supported") raise ValueError("requested type not supported")
# prefer configured recips specific to key # prefer configured recips specific to key
recips = self.config.get_list(f'{self.config.appname}.email.{key}.{typ}') recips = self.config.get_list(f"{self.config.appname}.email.{key}.{typ}")
if recips: if recips:
return recips return recips
# fall back to global default # fall back to global default
return self.config.get_list(f'{self.config.appname}.email.default.{typ}', return self.config.get_list(
default=[]) f"{self.config.appname}.email.default.{typ}", default=[]
)
def get_auto_txt_body(self, key, context=None): def get_auto_txt_body(self, key, context=None):
""" """
@ -561,7 +576,7 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
content for a message, as determined by config. This renders content for a message, as determined by config. This renders
a template with the given context. a template with the given context.
""" """
template = self.get_auto_body_template(key, 'txt') template = self.get_auto_body_template(key, "txt")
if template: if template:
context = context or {} context = context or {}
return template.render(**context) return template.render(**context)
@ -574,24 +589,24 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
message, as determined by config. This renders a template message, as determined by config. This renders a template
with the given context. with the given context.
""" """
template = self.get_auto_body_template(key, 'html') template = self.get_auto_body_template(key, "html")
if template: if template:
context = context or {} context = context or {}
return template.render(**context) return template.render(**context)
return None return None
def get_auto_body_template(self, key, mode): # pylint: disable=empty-docstring def get_auto_body_template(self, key, mode): # pylint: disable=empty-docstring
""" """ """ """
mode = mode.lower() mode = mode.lower()
if mode == 'txt': if mode == "txt":
templates = self.txt_templates templates = self.txt_templates
elif mode == 'html': elif mode == "html":
templates = self.html_templates templates = self.html_templates
else: else:
raise ValueError("requested mode not supported") raise ValueError("requested mode not supported")
try: try:
return templates.get_template(f'{key}.{mode}.mako') return templates.get_template(f"{key}.{mode}.mako")
except TopLevelLookupException: except TopLevelLookupException:
pass pass
return None return None
@ -604,7 +619,7 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
:returns: Notes as string if found; otherwise ``None``. :returns: Notes as string if found; otherwise ``None``.
""" """
return self.config.get(f'{self.config.appname}.email.{key}.notes') return self.config.get(f"{self.config.appname}.email.{key}.notes")
def is_enabled(self, key): def is_enabled(self, key):
""" """
@ -646,8 +661,8 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
:returns: True if this email type is enabled, otherwise false. :returns: True if this email type is enabled, otherwise false.
""" """
for k in set([key, 'default']): for k in set([key, "default"]):
enabled = self.config.get_bool(f'{self.config.appname}.email.{k}.enabled') enabled = self.config.get_bool(f"{self.config.appname}.email.{k}.enabled")
if enabled is not None: if enabled is not None:
return enabled return enabled
return True return True
@ -702,9 +717,11 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
message = message.as_string() message = message.as_string()
# get smtp info # get smtp info
server = self.config.get(f'{self.config.appname}.mail.smtp.server', default='localhost') server = self.config.get(
username = self.config.get(f'{self.config.appname}.mail.smtp.username') f"{self.config.appname}.mail.smtp.server", default="localhost"
password = self.config.get(f'{self.config.appname}.mail.smtp.password') )
username = self.config.get(f"{self.config.appname}.mail.smtp.username")
password = self.config.get(f"{self.config.appname}.mail.smtp.password")
# make sure sending is enabled # make sure sending is enabled
log.debug("sending email from %s; to %s", sender, recips) log.debug("sending email from %s; to %s", sender, recips)
@ -735,10 +752,13 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
Note that it is OFF by default. Note that it is OFF by default.
""" """
return self.config.get_bool(f'{self.config.appname}.mail.send_emails', return self.config.get_bool(
default=False) f"{self.config.appname}.mail.send_emails", default=False
)
def send_email(self, key=None, context=None, message=None, sender=None, recips=None, **kwargs): def send_email(
self, key=None, context=None, message=None, sender=None, recips=None, **kwargs
):
""" """
Send an email message. Send an email message.
@ -807,11 +827,13 @@ class EmailHandler(GenericHandler): # pylint: disable=too-many-public-methods
# auto-create message from key + context # auto-create message from key + context
if sender: if sender:
kwargs['sender'] = sender kwargs["sender"] = sender
message = self.make_auto_message(key, context or {}, **kwargs) message = self.make_auto_message(key, context or {}, **kwargs)
if not (message.txt_body or message.html_body): if not (message.txt_body or message.html_body):
raise RuntimeError(f"message (type: {key}) has no body - " raise RuntimeError(
"perhaps template file not found?") f"message (type: {key}) has no body - "
"perhaps template file not found?"
)
if not (message.txt_body or message.html_body): if not (message.txt_body or message.html_body):
if key: if key:

View file

@ -32,7 +32,8 @@ class UpgradeStatus(Enum):
Enum values for Enum values for
:attr:`wuttjamaican.db.model.upgrades.Upgrade.status`. :attr:`wuttjamaican.db.model.upgrades.Upgrade.status`.
""" """
PENDING = 'pending'
EXECUTING = 'executing' PENDING = "pending"
SUCCESS = 'success' EXECUTING = "executing"
FAILURE = 'failure' SUCCESS = "success"
FAILURE = "failure"

View file

@ -80,7 +80,8 @@ class InstallHandler(GenericHandler):
Egg name for the app. If not specified one will be guessed. Egg name for the app. If not specified one will be guessed.
""" """
pkg_name = 'poser'
pkg_name = "poser"
app_title = None app_title = None
pypi_name = None pypi_name = None
egg_name = None egg_name = None
@ -97,7 +98,7 @@ class InstallHandler(GenericHandler):
if not self.pypi_name: if not self.pypi_name:
self.pypi_name = self.app_title self.pypi_name = self.app_title
if not self.egg_name: if not self.egg_name:
self.egg_name = self.pypi_name.replace('-', '_') self.egg_name = self.pypi_name.replace("-", "_")
def run(self): def run(self):
""" """
@ -118,11 +119,13 @@ class InstallHandler(GenericHandler):
self.require_prompt_toolkit() self.require_prompt_toolkit()
paths = [ paths = [
self.app.resource_path('wuttjamaican:templates/install'), self.app.resource_path("wuttjamaican:templates/install"),
] ]
try: try:
paths.insert(0, self.app.resource_path(f'{self.pkg_name}:templates/install')) paths.insert(
0, self.app.resource_path(f"{self.pkg_name}:templates/install")
)
except (TypeError, ModuleNotFoundError): except (TypeError, ModuleNotFoundError):
pass pass
@ -143,8 +146,10 @@ class InstallHandler(GenericHandler):
""" """
self.rprint(f"\n\t[blue]Welcome to {self.app.get_title()}![/blue]") self.rprint(f"\n\t[blue]Welcome to {self.app.get_title()}![/blue]")
self.rprint("\n\tThis tool will install and configure the app.") self.rprint("\n\tThis tool will install and configure the app.")
self.rprint("\n\t[italic]NB. You should already have created " self.rprint(
"the database in PostgreSQL or MySQL.[/italic]") "\n\t[italic]NB. You should already have created "
"the database in PostgreSQL or MySQL.[/italic]"
)
# shall we continue? # shall we continue?
if not self.prompt_bool("continue?", True): if not self.prompt_bool("continue?", True):
@ -159,7 +164,7 @@ class InstallHandler(GenericHandler):
This is normally called by :meth:`run()`. This is normally called by :meth:`run()`.
""" """
# appdir must not yet exist # appdir must not yet exist
appdir = os.path.join(sys.prefix, 'app') appdir = os.path.join(sys.prefix, "app")
if os.path.exists(appdir): if os.path.exists(appdir):
self.rprint(f"\n\t[bold red]appdir already exists:[/bold red] {appdir}\n") self.rprint(f"\n\t[bold red]appdir already exists:[/bold red] {appdir}\n")
sys.exit(2) sys.exit(2)
@ -185,7 +190,7 @@ class InstallHandler(GenericHandler):
self.make_appdir(context) self.make_appdir(context)
# install db schema if user likes # install db schema if user likes
self.schema_installed = self.install_db_schema(dbinfo['dburl']) self.schema_installed = self.install_db_schema(dbinfo["dburl"])
def get_dbinfo(self): def get_dbinfo(self):
""" """
@ -199,27 +204,29 @@ class InstallHandler(GenericHandler):
dbinfo = {} dbinfo = {}
# get db info # get db info
dbinfo['dbtype'] = self.prompt_generic('db type', 'postgresql') dbinfo["dbtype"] = self.prompt_generic("db type", "postgresql")
dbinfo['dbhost'] = self.prompt_generic('db host', 'localhost') dbinfo["dbhost"] = self.prompt_generic("db host", "localhost")
default_port = '3306' if dbinfo['dbtype'] == 'mysql' else '5432' default_port = "3306" if dbinfo["dbtype"] == "mysql" else "5432"
dbinfo['dbport'] = self.prompt_generic('db port', default_port) dbinfo["dbport"] = self.prompt_generic("db port", default_port)
dbinfo['dbname'] = self.prompt_generic('db name', self.pkg_name) dbinfo["dbname"] = self.prompt_generic("db name", self.pkg_name)
dbinfo['dbuser'] = self.prompt_generic('db user', self.pkg_name) dbinfo["dbuser"] = self.prompt_generic("db user", self.pkg_name)
# get db password # get db password
dbinfo['dbpass'] = None dbinfo["dbpass"] = None
while not dbinfo['dbpass']: while not dbinfo["dbpass"]:
dbinfo['dbpass'] = self.prompt_generic('db pass', is_password=True) dbinfo["dbpass"] = self.prompt_generic("db pass", is_password=True)
# test db connection # test db connection
self.rprint("\n\ttesting db connection... ", end='') self.rprint("\n\ttesting db connection... ", end="")
dbinfo['dburl'] = self.make_db_url(dbinfo['dbtype'], dbinfo["dburl"] = self.make_db_url(
dbinfo['dbhost'], dbinfo["dbtype"],
dbinfo['dbport'], dbinfo["dbhost"],
dbinfo['dbname'], dbinfo["dbport"],
dbinfo['dbuser'], dbinfo["dbname"],
dbinfo['dbpass']) dbinfo["dbuser"],
error = self.test_db_connection(dbinfo['dburl']) dbinfo["dbpass"],
)
error = self.test_db_connection(dbinfo["dburl"])
if error: if error:
self.rprint("[bold red]cannot connect![/bold red] ..error was:") self.rprint("[bold red]cannot connect![/bold red] ..error was:")
self.rprint(f"\n{error}") self.rprint(f"\n{error}")
@ -229,23 +236,27 @@ class InstallHandler(GenericHandler):
return dbinfo return dbinfo
def make_db_url(self, dbtype, dbhost, dbport, dbname, dbuser, dbpass): # pylint: disable=empty-docstring def make_db_url(
self, dbtype, dbhost, dbport, dbname, dbuser, dbpass
): # pylint: disable=empty-docstring
""" """ """ """
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
if dbtype == 'mysql': if dbtype == "mysql":
drivername = 'mysql+mysqlconnector' drivername = "mysql+mysqlconnector"
else: else:
drivername = 'postgresql+psycopg2' drivername = "postgresql+psycopg2"
return URL.create(drivername=drivername, return URL.create(
username=dbuser, drivername=drivername,
password=dbpass, username=dbuser,
host=dbhost, password=dbpass,
port=dbport, host=dbhost,
database=dbname) port=dbport,
database=dbname,
)
def test_db_connection(self, url): # pylint: disable=empty-docstring def test_db_connection(self, url): # pylint: disable=empty-docstring
""" """ """ """
import sqlalchemy as sa import sqlalchemy as sa
@ -254,8 +265,8 @@ class InstallHandler(GenericHandler):
# check for random table; does not matter if it exists, we # check for random table; does not matter if it exists, we
# just need to test interaction and this is a neutral way # just need to test interaction and this is a neutral way
try: try:
sa.inspect(engine).has_table('whatever') sa.inspect(engine).has_table("whatever")
except Exception as error: # pylint: disable=broad-exception-caught except Exception as error: # pylint: disable=broad-exception-caught
return str(error) return str(error)
return None return None
@ -287,16 +298,16 @@ class InstallHandler(GenericHandler):
* ``db_url`` - value from ``dbinfo['dburl']`` * ``db_url`` - value from ``dbinfo['dburl']``
""" """
envname = os.path.basename(sys.prefix) envname = os.path.basename(sys.prefix)
appdir = os.path.join(sys.prefix, 'app') appdir = os.path.join(sys.prefix, "app")
context = { context = {
'envdir': sys.prefix, "envdir": sys.prefix,
'envname': envname, "envname": envname,
'pkg_name': self.pkg_name, "pkg_name": self.pkg_name,
'app_title': self.app_title, "app_title": self.app_title,
'pypi_name': self.pypi_name, "pypi_name": self.pypi_name,
'appdir': appdir, "appdir": appdir,
'db_url': dbinfo['dburl'], "db_url": dbinfo["dburl"],
'egg_name': self.egg_name, "egg_name": self.egg_name,
} }
context.update(kwargs) context.update(kwargs)
return context return context
@ -335,38 +346,36 @@ class InstallHandler(GenericHandler):
# but then we also generate some files... # but then we also generate some files...
# wutta.conf # wutta.conf
self.make_config_file('wutta.conf.mako', self.make_config_file(
os.path.join(appdir, 'wutta.conf'), "wutta.conf.mako", os.path.join(appdir, "wutta.conf"), **context
**context) )
# web.conf # web.conf
web_context = dict(context) web_context = dict(context)
web_context.setdefault('beaker_key', context.get('pkg_name', 'poser')) web_context.setdefault("beaker_key", context.get("pkg_name", "poser"))
web_context.setdefault('beaker_secret', 'TODO_YOU_SHOULD_CHANGE_THIS') web_context.setdefault("beaker_secret", "TODO_YOU_SHOULD_CHANGE_THIS")
web_context.setdefault('pyramid_host', '0.0.0.0') web_context.setdefault("pyramid_host", "0.0.0.0")
web_context.setdefault('pyramid_port', '9080') web_context.setdefault("pyramid_port", "9080")
self.make_config_file('web.conf.mako', self.make_config_file(
os.path.join(appdir, 'web.conf'), "web.conf.mako", os.path.join(appdir, "web.conf"), **web_context
**web_context) )
# upgrade.sh # upgrade.sh
template = self.templates.get_template('upgrade.sh.mako') template = self.templates.get_template("upgrade.sh.mako")
output_path = os.path.join(appdir, 'upgrade.sh') output_path = os.path.join(appdir, "upgrade.sh")
self.render_mako_template(template, context, self.render_mako_template(template, context, output_path=output_path)
output_path=output_path) os.chmod(
os.chmod(output_path, stat.S_IRWXU output_path,
| stat.S_IRGRP stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH,
| stat.S_IXGRP )
| stat.S_IROTH
| stat.S_IXOTH)
self.rprint(f"\n\tappdir created at: [bold green]{appdir}[/bold green]") self.rprint(f"\n\tappdir created at: [bold green]{appdir}[/bold green]")
def render_mako_template( def render_mako_template(
self, self,
template, template,
context, context,
output_path=None, output_path=None,
): ):
""" """
Convenience wrapper around Convenience wrapper around
@ -384,8 +393,7 @@ class InstallHandler(GenericHandler):
if isinstance(template, str): if isinstance(template, str):
template = self.templates.get_template(template) template = self.templates.get_template(template)
return self.app.render_mako_template(template, context, return self.app.render_mako_template(template, context, output_path=output_path)
output_path=output_path)
def make_config_file(self, template, output_path, **kwargs): def make_config_file(self, template, output_path, **kwargs):
""" """
@ -413,14 +421,13 @@ class InstallHandler(GenericHandler):
Once it does that it calls :meth:`render_mako_template()`. Once it does that it calls :meth:`render_mako_template()`.
""" """
context = { context = {
'app_title': self.app.get_title(), "app_title": self.app.get_title(),
'appdir': self.app.get_appdir(), "appdir": self.app.get_appdir(),
'db_url': 'postresql://user:pass@localhost/poser', "db_url": "postresql://user:pass@localhost/poser",
'os': os, "os": os,
} }
context.update(kwargs) context.update(kwargs)
self.render_mako_template(template, context, self.render_mako_template(template, context, output_path=output_path)
output_path=output_path)
return output_path return output_path
def install_db_schema(self, db_url, appdir=None): def install_db_schema(self, db_url, appdir=None):
@ -444,13 +451,19 @@ class InstallHandler(GenericHandler):
# install db schema # install db schema
appdir = appdir or self.app.get_appdir() appdir = appdir or self.app.get_appdir()
cmd = [os.path.join(sys.prefix, 'bin', 'alembic'), cmd = [
'-c', os.path.join(appdir, 'wutta.conf'), os.path.join(sys.prefix, "bin", "alembic"),
'upgrade', 'heads'] "-c",
os.path.join(appdir, "wutta.conf"),
"upgrade",
"heads",
]
subprocess.check_call(cmd) subprocess.check_call(cmd)
self.rprint("\n\tdb schema installed to: " self.rprint(
f"[bold green]{obfuscate_url_pw(db_url)}[/bold green]") "\n\tdb schema installed to: "
f"[bold green]{obfuscate_url_pw(db_url)}[/bold green]"
)
return True return True
def show_goodbye(self): def show_goodbye(self):
@ -472,19 +485,22 @@ class InstallHandler(GenericHandler):
# console utility functions # console utility functions
############################## ##############################
def require_prompt_toolkit(self, answer=None): # pylint: disable=empty-docstring def require_prompt_toolkit(self, answer=None): # pylint: disable=empty-docstring
""" """ """ """
try: try:
import prompt_toolkit # pylint: disable=unused-import import prompt_toolkit # pylint: disable=unused-import
except ImportError: except ImportError:
value = answer or input("\nprompt_toolkit is not installed. shall i install it? [Yn] ") value = answer or input(
"\nprompt_toolkit is not installed. shall i install it? [Yn] "
)
value = value.strip() value = value.strip()
if value and not self.config.parse_bool(value): if value and not self.config.parse_bool(value):
sys.stderr.write("prompt_toolkit is required; aborting\n") sys.stderr.write("prompt_toolkit is required; aborting\n")
sys.exit(1) sys.exit(1)
subprocess.check_call([sys.executable, '-m', 'pip', subprocess.check_call(
'install', 'prompt_toolkit']) [sys.executable, "-m", "pip", "install", "prompt_toolkit"]
)
# nb. this should now succeed # nb. this should now succeed
import prompt_toolkit import prompt_toolkit
@ -495,23 +511,25 @@ class InstallHandler(GenericHandler):
""" """
rich.print(*args, **kwargs) rich.print(*args, **kwargs)
def get_prompt_style(self): # pylint: disable=empty-docstring def get_prompt_style(self): # pylint: disable=empty-docstring
""" """ """ """
from prompt_toolkit.styles import Style from prompt_toolkit.styles import Style
# message formatting styles # message formatting styles
return Style.from_dict({ return Style.from_dict(
'': '', {
'bold': 'bold', "": "",
}) "bold": "bold",
}
)
def prompt_generic( def prompt_generic(
self, self,
info, info,
default=None, default=None,
is_password=False, is_password=False,
is_bool=False, is_bool=False,
required=False, required=False,
): ):
""" """
Prompt the user to get their input. Prompt the user to get their input.
@ -540,39 +558,42 @@ class InstallHandler(GenericHandler):
# build prompt message # build prompt message
message = [ message = [
('', '\n'), ("", "\n"),
('class:bold', info), ("class:bold", info),
] ]
if default is not None: if default is not None:
if is_bool: if is_bool:
message.append(('', f' [{"Y" if default else "N"}]: ')) message.append(("", f' [{"Y" if default else "N"}]: '))
else: else:
message.append(('', f' [{default}]: ')) message.append(("", f" [{default}]: "))
else: else:
message.append(('', ': ')) message.append(("", ": "))
# prompt user for input # prompt user for input
style = self.get_prompt_style() style = self.get_prompt_style()
try: try:
text = prompt(message, style=style, is_password=is_password) text = prompt(message, style=style, is_password=is_password)
except (KeyboardInterrupt, EOFError): except (KeyboardInterrupt, EOFError):
self.rprint("\n\t[bold yellow]operation canceled by user[/bold yellow]\n", self.rprint(
file=sys.stderr) "\n\t[bold yellow]operation canceled by user[/bold yellow]\n",
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
if is_bool: if is_bool:
if text == '': if text == "":
return default return default
if text.upper() == 'Y': if text.upper() == "Y":
return True return True
if text.upper() == 'N': if text.upper() == "N":
return False return False
self.rprint("\n\t[bold yellow]ambiguous, please try again[/bold yellow]\n") self.rprint("\n\t[bold yellow]ambiguous, please try again[/bold yellow]\n")
return self.prompt_generic(info, default, is_bool=True) return self.prompt_generic(info, default, is_bool=True)
if required and not text and not default: if required and not text and not default:
return self.prompt_generic(info, default, is_password=is_password, return self.prompt_generic(
required=True) info, default, is_password=is_password, required=True
)
return text or default return text or default

View file

@ -54,12 +54,14 @@ class PeopleHandler(GenericHandler):
""" """
model = self.app.model model = self.app.model
if 'full_name' not in kwargs: if "full_name" not in kwargs:
full_name = self.app.make_full_name(kwargs.get('first_name'), full_name = self.app.make_full_name(
kwargs.get('middle_name'), kwargs.get("first_name"),
kwargs.get('last_name')) kwargs.get("middle_name"),
kwargs.get("last_name"),
)
if full_name: if full_name:
kwargs['full_name'] = full_name kwargs["full_name"] = full_name
return model.Person(**kwargs) return model.Person(**kwargs)

View file

@ -100,7 +100,7 @@ class ProblemCheck:
""" """
return [] return []
def get_email_context(self, problems, **kwargs): # pylint: disable=unused-argument def get_email_context(self, problems, **kwargs): # pylint: disable=unused-argument
""" """
This can be used to add extra context for a specific check's This can be used to add extra context for a specific check's
report email template. report email template.
@ -149,15 +149,18 @@ class ProblemHandler(GenericHandler):
:returns: List of :class:`ProblemCheck` classes. :returns: List of :class:`ProblemCheck` classes.
""" """
checks = [] checks = []
modules = self.config.get_list(f'{self.config.appname}.problems.modules', modules = self.config.get_list(
default=['wuttjamaican.problems']) f"{self.config.appname}.problems.modules", default=["wuttjamaican.problems"]
)
for module_path in modules: for module_path in modules:
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
for name in dir(module): for name in dir(module):
obj = getattr(module, name) obj = getattr(module, name)
if (isinstance(obj, type) and if (
issubclass(obj, ProblemCheck) and isinstance(obj, type)
obj is not ProblemCheck): and issubclass(obj, ProblemCheck)
and obj is not ProblemCheck
):
checks.append(obj) checks.append(obj)
return checks return checks
@ -224,8 +227,8 @@ class ProblemHandler(GenericHandler):
:returns: ``True`` if enabled; ``False`` if not. :returns: ``True`` if enabled; ``False`` if not.
""" """
key = f'{check.system_key}.{check.problem_key}' key = f"{check.system_key}.{check.problem_key}"
enabled = self.config.get_bool(f'{self.config.appname}.problems.{key}.enabled') enabled = self.config.get_bool(f"{self.config.appname}.problems.{key}.enabled")
if enabled is not None: if enabled is not None:
return enabled return enabled
return True return True
@ -243,8 +246,10 @@ class ProblemHandler(GenericHandler):
:returns: ``True`` if check should run; ``False`` if not. :returns: ``True`` if check should run; ``False`` if not.
""" """
key = f'{check.system_key}.{check.problem_key}' key = f"{check.system_key}.{check.problem_key}"
enabled = self.config.get_bool(f'{self.config.appname}.problems.{key}.day{weekday}') enabled = self.config.get_bool(
f"{self.config.appname}.problems.{key}.day{weekday}"
)
if enabled is not None: if enabled is not None:
return enabled return enabled
return True return True
@ -302,7 +307,7 @@ class ProblemHandler(GenericHandler):
:param force: If true, run the check regardless of whether it :param force: If true, run the check regardless of whether it
is configured to run. is configured to run.
""" """
key = f'{check.system_key}.{check.problem_key}' key = f"{check.system_key}.{check.problem_key}"
log.info("running problem check: %s", key) log.info("running problem check: %s", key)
if not self.is_enabled(check): if not self.is_enabled(check):
@ -312,8 +317,11 @@ class ProblemHandler(GenericHandler):
weekday = datetime.date.today().weekday() weekday = datetime.date.today().weekday()
if not self.should_run_for_weekday(check, weekday): if not self.should_run_for_weekday(check, weekday):
log.debug("problem check is not scheduled for %s: %s", log.debug(
calendar.day_name[weekday], key) "problem check is not scheduled for %s: %s",
calendar.day_name[weekday],
key,
)
if not force: if not force:
return None return None
@ -355,7 +363,7 @@ class ProblemHandler(GenericHandler):
:returns: Config key for problem report email message. :returns: Config key for problem report email message.
""" """
return f'{check.system_key}_problems_{check.problem_key}' return f"{check.system_key}_problems_{check.problem_key}"
def send_problem_report(self, check, problems): def send_problem_report(self, check, problems):
""" """
@ -377,18 +385,20 @@ class ProblemHandler(GenericHandler):
""" """
context = self.get_global_email_context() context = self.get_global_email_context()
context = self.get_check_email_context(check, problems, **context) context = self.get_check_email_context(check, problems, **context)
context.update({ context.update(
'config': self.config, {
'app': self.app, "config": self.config,
'check': check, "app": self.app,
'problems': problems, "check": check,
}) "problems": problems,
}
)
email_key = self.get_email_key(check) email_key = self.get_email_key(check)
attachments = check.make_email_attachments(context) attachments = check.make_email_attachments(context)
self.app.send_email(email_key, context, self.app.send_email(
default_subject=check.title, email_key, context, default_subject=check.title, attachments=attachments
attachments=attachments) )
def get_global_email_context(self, **kwargs): def get_global_email_context(self, **kwargs):
""" """
@ -413,6 +423,6 @@ class ProblemHandler(GenericHandler):
:returns: Context dict for email template. :returns: Context dict for email template.
""" """
kwargs['system_title'] = self.get_system_title(check.system_key) kwargs["system_title"] = self.get_system_title(check.system_key)
kwargs = check.get_email_context(problems, **kwargs) kwargs = check.get_email_context(problems, **kwargs)
return kwargs return kwargs

View file

@ -98,16 +98,20 @@ class ConsoleProgress(ProgressBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args) super().__init__(*args)
self.stderr = kwargs.get('stderr', sys.stderr) self.stderr = kwargs.get("stderr", sys.stderr)
self.stderr.write(f"\n{self.message}...\n") self.stderr.write(f"\n{self.message}...\n")
self.bar = Bar(message='', max=self.maximum, width=70, # pylint: disable=disallowed-name self.bar = Bar( # pylint: disable=disallowed-name
suffix='%(index)d/%(max)d %(percent)d%% ETA %(eta)ds') message="",
max=self.maximum,
width=70,
suffix="%(index)d/%(max)d %(percent)d%% ETA %(eta)ds",
)
def update(self, value): # pylint: disable=empty-docstring def update(self, value): # pylint: disable=empty-docstring
""" """ """ """
self.bar.next() self.bar.next()
def finish(self): # pylint: disable=empty-docstring def finish(self): # pylint: disable=empty-docstring
""" """ """ """
self.bar.finish() self.bar.finish()

View file

@ -39,6 +39,7 @@ class Report:
This is the common display title for the report. This is the common display title for the report.
""" """
report_title = "Untitled Report" report_title = "Untitled Report"
def __init__(self, config): def __init__(self, config):
@ -146,7 +147,7 @@ class ReportHandler(GenericHandler):
:meth:`~wuttjamaican.app.GenericHandler.get_provider_modules()` :meth:`~wuttjamaican.app.GenericHandler.get_provider_modules()`
under the hood, for ``report`` module type. under the hood, for ``report`` module type.
""" """
return self.get_provider_modules('report') return self.get_provider_modules("report")
def get_reports(self): def get_reports(self):
""" """
@ -156,14 +157,16 @@ class ReportHandler(GenericHandler):
This calls :meth:`get_report_modules()` and for each module, This calls :meth:`get_report_modules()` and for each module,
it discovers all the reports it contains. it discovers all the reports it contains.
""" """
if not hasattr(self, '_reports'): if not hasattr(self, "_reports"):
self._reports = {} self._reports = {}
for module in self.get_report_modules(): for module in self.get_report_modules():
for name in dir(module): for name in dir(module):
obj = getattr(module, name) obj = getattr(module, name)
if (isinstance(obj, type) if (
isinstance(obj, type)
and obj is not Report and obj is not Report
and issubclass(obj, Report)): and issubclass(obj, Report)
):
self._reports[obj.report_key] = obj self._reports[obj.report_key] = obj
return self._reports return self._reports
@ -214,6 +217,6 @@ class ReportHandler(GenericHandler):
""" """
data = report.make_data(params or {}, progress=progress, **kwargs) data = report.make_data(params or {}, progress=progress, **kwargs)
if not isinstance(data, dict): if not isinstance(data, dict):
data = {'data': data} data = {"data": data}
data.setdefault('output_title', report.report_title) data.setdefault("output_title", report.report_title)
return data return data

View file

@ -53,7 +53,7 @@ class FileTestCase(TestCase):
class. class.
""" """
def setUp(self): # pylint: disable=empty-docstring def setUp(self): # pylint: disable=empty-docstring
""" """ """ """
self.setup_files() self.setup_files()
@ -63,14 +63,17 @@ class FileTestCase(TestCase):
""" """
self.tempdir = tempfile.mkdtemp() self.tempdir = tempfile.mkdtemp()
def setup_file_config(self): # pragma: no cover; pylint: disable=empty-docstring def setup_file_config(self): # pragma: no cover; pylint: disable=empty-docstring
""" """ """ """
warnings.warn("FileTestCase.setup_file_config() is deprecated; " warnings.warn(
"please use setup_files() instead", "FileTestCase.setup_file_config() is deprecated; "
DeprecationWarning, stacklevel=2) "please use setup_files() instead",
DeprecationWarning,
stacklevel=2,
)
self.setup_files() self.setup_files()
def tearDown(self): # pylint: disable=empty-docstring def tearDown(self): # pylint: disable=empty-docstring
""" """ """ """
self.teardown_files() self.teardown_files()
@ -80,11 +83,14 @@ class FileTestCase(TestCase):
""" """
shutil.rmtree(self.tempdir) shutil.rmtree(self.tempdir)
def teardown_file_config(self): # pragma: no cover; pylint: disable=empty-docstring def teardown_file_config(self): # pragma: no cover; pylint: disable=empty-docstring
""" """ """ """
warnings.warn("FileTestCase.teardown_file_config() is deprecated; " warnings.warn(
"please use teardown_files() instead", "FileTestCase.teardown_file_config() is deprecated; "
DeprecationWarning, stacklevel=2) "please use teardown_files() instead",
DeprecationWarning,
stacklevel=2,
)
self.teardown_files() self.teardown_files()
def write_file(self, filename, content): def write_file(self, filename, content):
@ -95,15 +101,18 @@ class FileTestCase(TestCase):
myconf = self.write_file('my.conf', '<file contents>') myconf = self.write_file('my.conf', '<file contents>')
""" """
path = os.path.join(self.tempdir, filename) path = os.path.join(self.tempdir, filename)
with open(path, 'wt', encoding='utf_8') as f: with open(path, "wt", encoding="utf_8") as f:
f.write(content) f.write(content)
return path return path
def mkdir(self, dirname): # pylint: disable=unused-argument,empty-docstring def mkdir(self, dirname): # pylint: disable=unused-argument,empty-docstring
""" """ """ """
warnings.warn("FileTestCase.mkdir() is deprecated; " warnings.warn(
"please use FileTestCase.mkdtemp() instead", "FileTestCase.mkdir() is deprecated; "
DeprecationWarning, stacklevel=2) "please use FileTestCase.mkdtemp() instead",
DeprecationWarning,
stacklevel=2,
)
return self.mkdtemp() return self.mkdtemp()
def mkdtemp(self): def mkdtemp(self):
@ -143,7 +152,7 @@ class ConfigTestCase(FileTestCase):
methods for this class. methods for this class.
""" """
def setUp(self): # pylint: disable=empty-docstring def setUp(self): # pylint: disable=empty-docstring
""" """ """ """
self.setup_config() self.setup_config()
@ -155,7 +164,7 @@ class ConfigTestCase(FileTestCase):
self.config = self.make_config() self.config = self.make_config()
self.app = self.config.get_app() self.app = self.config.get_app()
def tearDown(self): # pylint: disable=empty-docstring def tearDown(self): # pylint: disable=empty-docstring
""" """ """ """
self.teardown_config() self.teardown_config()
@ -165,7 +174,7 @@ class ConfigTestCase(FileTestCase):
""" """
self.teardown_files() self.teardown_files()
def make_config(self, **kwargs): # pylint: disable=empty-docstring def make_config(self, **kwargs): # pylint: disable=empty-docstring
""" """ """ """
return WuttaConfig(**kwargs) return WuttaConfig(**kwargs)
@ -203,7 +212,7 @@ class DataTestCase(FileTestCase):
teardown methods, as this class handles that automatically. teardown methods, as this class handles that automatically.
""" """
def setUp(self): # pylint: disable=empty-docstring def setUp(self): # pylint: disable=empty-docstring
""" """ """ """
self.setup_db() self.setup_db()
@ -212,9 +221,11 @@ class DataTestCase(FileTestCase):
Perform config/app/db setup operations for the test. Perform config/app/db setup operations for the test.
""" """
self.setup_files() self.setup_files()
self.config = self.make_config(defaults={ self.config = self.make_config(
'wutta.db.default.url': 'sqlite://', defaults={
}) "wutta.db.default.url": "sqlite://",
}
)
self.app = self.config.get_app() self.app = self.config.get_app()
# init db # init db
@ -222,7 +233,7 @@ class DataTestCase(FileTestCase):
model.Base.metadata.create_all(bind=self.config.appdb_engine) model.Base.metadata.create_all(bind=self.config.appdb_engine)
self.session = self.app.make_session() self.session = self.app.make_session()
def tearDown(self): # pylint: disable=empty-docstring def tearDown(self): # pylint: disable=empty-docstring
""" """ """ """
self.teardown_db() self.teardown_db()
@ -232,6 +243,6 @@ class DataTestCase(FileTestCase):
""" """
self.teardown_files() self.teardown_files()
def make_config(self, **kwargs): # pylint: disable=empty-docstring def make_config(self, **kwargs): # pylint: disable=empty-docstring
""" """ """ """
return WuttaConfig(**kwargs) return WuttaConfig(**kwargs)

View file

@ -106,7 +106,7 @@ def load_entry_points(group, ignore_errors=False):
import importlib_metadata import importlib_metadata
eps = importlib_metadata.entry_points() eps = importlib_metadata.entry_points()
if not hasattr(eps, 'select'): if not hasattr(eps, "select"):
# python < 3.10 # python < 3.10
eps = eps.get(group, []) eps = eps.get(group, [])
else: else:
@ -115,11 +115,10 @@ def load_entry_points(group, ignore_errors=False):
for entry_point in eps: for entry_point in eps:
try: try:
ep = entry_point.load() ep = entry_point.load()
except Exception: # pylint: disable=broad-exception-caught except Exception: # pylint: disable=broad-exception-caught
if not ignore_errors: if not ignore_errors:
raise raise
log.warning("failed to load entry point: %s", entry_point, log.warning("failed to load entry point: %s", entry_point, exc_info=True)
exc_info=True)
else: else:
entry_points[entry_point.name] = ep entry_points[entry_point.name] = ep
@ -151,7 +150,7 @@ def load_object(spec):
if not spec: if not spec:
raise ValueError("no object spec provided") raise ValueError("no object spec provided")
module_path, name = spec.split(':') module_path, name = spec.split(":")
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
return getattr(module, name) return getattr(module, name)
@ -165,10 +164,10 @@ def make_title(text):
make_title('foo_bar') # => 'Foo Bar' make_title('foo_bar') # => 'Foo Bar'
""" """
text = text.replace('_', ' ') text = text.replace("_", " ")
text = text.replace('-', ' ') text = text.replace("-", " ")
words = text.split() words = text.split()
return ' '.join([x.capitalize() for x in words]) return " ".join([x.capitalize() for x in words])
def make_full_name(*parts): def make_full_name(*parts):
@ -185,10 +184,9 @@ def make_full_name(*parts):
make_full_name('First', '', 'Last', 'Suffix') make_full_name('First', '', 'Last', 'Suffix')
# => "First Last Suffix" # => "First Last Suffix"
""" """
parts = [(part or '').strip() parts = [(part or "").strip() for part in parts]
for part in parts]
parts = [part for part in parts if part] parts = [part for part in parts if part]
return ' '.join(parts) return " ".join(parts)
def make_true_uuid(): def make_true_uuid():
@ -238,7 +236,7 @@ def parse_bool(value):
return None return None
if isinstance(value, bool): if isinstance(value, bool):
return value return value
if str(value).lower() in ('true', 'yes', 'y', 'on', '1'): if str(value).lower() in ("true", "yes", "y", "on", "1"):
return True return True
return False return False
@ -253,7 +251,7 @@ def parse_list(value):
if isinstance(value, list): if isinstance(value, list):
return value return value
parser = shlex.shlex(value) parser = shlex.shlex(value)
parser.whitespace += ',' parser.whitespace += ","
parser.whitespace_split = True parser.whitespace_split = True
values = list(parser) values = list(parser)
for i, val in enumerate(values): for i, val in enumerate(values):
@ -344,15 +342,15 @@ def resource_path(path):
:returns: Absolute file path to the resource. :returns: Absolute file path to the resource.
""" """
if not os.path.isabs(path) and ':' in path: if not os.path.isabs(path) and ":" in path:
try: try:
# nb. these were added in python 3.9 # nb. these were added in python 3.9
from importlib.resources import files, as_file from importlib.resources import files, as_file
except ImportError: # python < 3.9 except ImportError: # python < 3.9
from importlib_resources import files, as_file from importlib_resources import files, as_file
package, filename = path.split(':') package, filename = path.split(":")
ref = files(package) / filename ref = files(package) / filename
with as_file(ref) as p: with as_file(ref) as p:
return str(p) return str(p)

View file

@ -15,14 +15,14 @@ def release(c, skip_tests=False):
Release a new version of WuttJamaican Release a new version of WuttJamaican
""" """
if not skip_tests: if not skip_tests:
c.run('pytest') c.run("pytest")
# rebuild local tar.gz file for distribution # rebuild local tar.gz file for distribution
if os.path.exists('dist'): if os.path.exists("dist"):
shutil.rmtree('dist') shutil.rmtree("dist")
if os.path.exists('WuttJamaican.egg-info'): if os.path.exists("WuttJamaican.egg-info"):
shutil.rmtree('WuttJamaican.egg-info') shutil.rmtree("WuttJamaican.egg-info")
c.run('python -m build --sdist') c.run("python -m build --sdist")
# upload to PyPI # upload to PyPI
c.run('twine upload dist/*') c.run("twine upload dist/*")

View file

@ -11,13 +11,13 @@ from wuttjamaican.cli import base as mod
here = os.path.dirname(__file__) here = os.path.dirname(__file__)
example_conf = os.path.join(here, 'example.conf') example_conf = os.path.join(here, "example.conf")
class TestMakeCliConfig(TestCase): class TestMakeCliConfig(TestCase):
def test_basic(self): def test_basic(self):
ctx = MagicMock(params={'config_paths': [example_conf]}) ctx = MagicMock(params={"config_paths": [example_conf]})
config = mod.make_cli_config(ctx) config = mod.make_cli_config(ctx)
self.assertIsInstance(config, WuttaConfig) self.assertIsInstance(config, WuttaConfig)
self.assertEqual(config.files_read, [example_conf]) self.assertEqual(config.files_read, [example_conf])
@ -26,7 +26,7 @@ class TestMakeCliConfig(TestCase):
class TestTyperCallback(TestCase): class TestTyperCallback(TestCase):
def test_basic(self): def test_basic(self):
ctx = MagicMock(params={'config_paths': [example_conf]}) ctx = MagicMock(params={"config_paths": [example_conf]})
mod.typer_callback(ctx) mod.typer_callback(ctx)
self.assertIsInstance(ctx.wutta_config, WuttaConfig) self.assertIsInstance(ctx.wutta_config, WuttaConfig)
self.assertEqual(ctx.wutta_config.files_read, [example_conf]) self.assertEqual(ctx.wutta_config.files_read, [example_conf])
@ -35,10 +35,10 @@ class TestTyperCallback(TestCase):
class TestTyperEagerImports(TestCase): class TestTyperEagerImports(TestCase):
def test_basic(self): def test_basic(self):
typr = mod.make_typer(name='foobreezy') typr = mod.make_typer(name="foobreezy")
with patch.object(mod, 'load_entry_points') as load_entry_points: with patch.object(mod, "load_entry_points") as load_entry_points:
mod.typer_eager_imports(typr) mod.typer_eager_imports(typr)
load_entry_points.assert_called_once_with('foobreezy.typer_imports') load_entry_points.assert_called_once_with("foobreezy.typer_imports")
class TestMakeTyper(TestCase): class TestMakeTyper(TestCase):

View file

@ -9,17 +9,16 @@ from wuttjamaican.app import AppHandler
here = os.path.dirname(__file__) here = os.path.dirname(__file__)
example_conf = os.path.join(here, 'example.conf') example_conf = os.path.join(here, "example.conf")
class TestMakeAppdir(ConfigTestCase): class TestMakeAppdir(ConfigTestCase):
def test_basic(self): def test_basic(self):
appdir = os.path.join(self.tempdir, 'app') appdir = os.path.join(self.tempdir, "app")
ctx = MagicMock(params={'config_paths': [example_conf], ctx = MagicMock(params={"config_paths": [example_conf], "appdir_path": appdir})
'appdir_path': appdir})
ctx.parent.wutta_config = self.config ctx.parent.wutta_config = self.config
with patch.object(AppHandler, 'make_appdir') as make_appdir: with patch.object(AppHandler, "make_appdir") as make_appdir:
mod.make_appdir(ctx) mod.make_appdir(ctx)
make_appdir.assert_called_once_with(appdir) make_appdir.assert_called_once_with(appdir)

View file

@ -8,13 +8,13 @@ from wuttjamaican.cli import make_uuid as mod
here = os.path.dirname(__file__) here = os.path.dirname(__file__)
example_conf = os.path.join(here, 'example.conf') example_conf = os.path.join(here, "example.conf")
class TestMakeUuid(TestCase): class TestMakeUuid(TestCase):
def test_basic(self): def test_basic(self):
ctx = MagicMock(params={'config_paths': [example_conf]}) ctx = MagicMock(params={"config_paths": [example_conf]})
with patch.object(mod, 'sys') as sys: with patch.object(mod, "sys") as sys:
mod.make_uuid(ctx) mod.make_uuid(ctx)
sys.stdout.write.assert_called_once() sys.stdout.write.assert_called_once()

View file

@ -8,8 +8,8 @@ from wuttjamaican.problems import ProblemHandler, ProblemCheck
class FakeCheck(ProblemCheck): class FakeCheck(ProblemCheck):
system_key = 'wuttatest' system_key = "wuttatest"
problem_key = 'fake_check' problem_key = "fake_check"
title = "Fake problem check" title = "Fake problem check"
@ -20,18 +20,28 @@ class TestProblems(ConfigTestCase):
ctx.parent.wutta_config = self.config ctx.parent.wutta_config = self.config
# nb. avoid printing to console # nb. avoid printing to console
with patch.object(mod.rich, 'print') as rich_print: with patch.object(mod.rich, "print") as rich_print:
# nb. use fake check # nb. use fake check
with patch.object(ProblemHandler, 'get_all_problem_checks', return_value=[FakeCheck]): with patch.object(
ProblemHandler, "get_all_problem_checks", return_value=[FakeCheck]
):
with patch.object(ProblemHandler, 'run_problem_checks') as run_problem_checks: with patch.object(
ProblemHandler, "run_problem_checks"
) as run_problem_checks:
# list problem checks # list problem checks
orig_organize = ProblemHandler.organize_problem_checks orig_organize = ProblemHandler.organize_problem_checks
def mock_organize(checks): def mock_organize(checks):
return orig_organize(None, checks) return orig_organize(None, checks)
with patch.object(ProblemHandler, 'organize_problem_checks', side_effect=mock_organize) as organize_problem_checks:
with patch.object(
ProblemHandler,
"organize_problem_checks",
side_effect=mock_organize,
) as organize_problem_checks:
mod.problems(ctx, list_checks=True) mod.problems(ctx, list_checks=True)
organize_problem_checks.assert_called_once_with([FakeCheck]) organize_problem_checks.assert_called_once_with([FakeCheck])
run_problem_checks.assert_not_called() run_problem_checks.assert_not_called()
@ -41,10 +51,13 @@ class TestProblems(ConfigTestCase):
# nb. just --list for convenience # nb. just --list for convenience
# note that since we also specify invalid --system, no checks will # note that since we also specify invalid --system, no checks will
# match and hence nothing significant will be printed to stdout # match and hence nothing significant will be printed to stdout
mod.problems(ctx, list_checks=True, systems=['craziness']) mod.problems(ctx, list_checks=True, systems=["craziness"])
rich_print.assert_called_once() rich_print.assert_called_once()
self.assertEqual(len(rich_print.call_args.args), 1) self.assertEqual(len(rich_print.call_args.args), 1)
self.assertIn("No problem reports exist for system", rich_print.call_args.args[0]) self.assertIn(
"No problem reports exist for system",
rich_print.call_args.args[0],
)
self.assertEqual(len(rich_print.call_args.kwargs), 0) self.assertEqual(len(rich_print.call_args.kwargs), 0)
run_problem_checks.assert_not_called() run_problem_checks.assert_not_called()

View file

@ -18,22 +18,20 @@ else:
role.name = "Managers" role.name = "Managers"
self.assertEqual(str(role), "Managers") self.assertEqual(str(role), "Managers")
class TestPermission(TestCase): class TestPermission(TestCase):
def test_basic(self): def test_basic(self):
perm = model.Permission() perm = model.Permission()
self.assertEqual(str(perm), "") self.assertEqual(str(perm), "")
perm.permission = 'users.create' perm.permission = "users.create"
self.assertEqual(str(perm), "users.create") self.assertEqual(str(perm), "users.create")
class TestUser(TestCase): class TestUser(TestCase):
def test_str(self): def test_str(self):
user = model.User() user = model.User()
self.assertEqual(str(user), "") self.assertEqual(str(user), "")
user.username = 'barney' user.username = "barney"
self.assertEqual(str(user), "barney") self.assertEqual(str(user), "barney")
def test_str_with_person(self): def test_str_with_person(self):
@ -44,7 +42,6 @@ else:
user.person = person user.person = person
self.assertEqual(str(user), "Barney Rubble") self.assertEqual(str(user), "Barney Rubble")
class TestUserAPIToken(TestCase): class TestUserAPIToken(TestCase):
def test_str(self): def test_str(self):

View file

@ -11,35 +11,32 @@ except ImportError:
pass pass
else: else:
class MockUser(mod.Base): class MockUser(mod.Base):
__tablename__ = 'mock_user' __tablename__ = "mock_user"
uuid = mod.uuid_column(sa.ForeignKey('user.uuid'), default=False) uuid = mod.uuid_column(sa.ForeignKey("user.uuid"), default=False)
user = orm.relationship( user = orm.relationship(
User, User,
backref=orm.backref('_mock', uselist=False, cascade='all, delete-orphan')) backref=orm.backref("_mock", uselist=False, cascade="all, delete-orphan"),
)
favorite_color = sa.Column(sa.String(length=100), nullable=False) favorite_color = sa.Column(sa.String(length=100), nullable=False)
class TestWuttaModelBase(TestCase): class TestWuttaModelBase(TestCase):
def test_make_proxy(self): def test_make_proxy(self):
self.assertFalse(hasattr(User, 'favorite_color')) self.assertFalse(hasattr(User, "favorite_color"))
MockUser.make_proxy(User, '_mock', 'favorite_color') MockUser.make_proxy(User, "_mock", "favorite_color")
self.assertTrue(hasattr(User, 'favorite_color')) self.assertTrue(hasattr(User, "favorite_color"))
user = User(favorite_color='green') user = User(favorite_color="green")
self.assertEqual(user.favorite_color, 'green') self.assertEqual(user.favorite_color, "green")
class TestSetting(TestCase): class TestSetting(TestCase):
def test_basic(self): def test_basic(self):
setting = mod.Setting() setting = mod.Setting()
self.assertEqual(str(setting), "") self.assertEqual(str(setting), "")
setting.name = 'foo' setting.name = "foo"
self.assertEqual(str(setting), "foo") self.assertEqual(str(setting), "foo")
class TestPerson(TestCase): class TestPerson(TestCase):
def test_basic(self): def test_basic(self):

View file

@ -17,40 +17,46 @@ else:
def test_basic(self): def test_basic(self):
class MyBatch(mod.BatchMixin, model.Base): class MyBatch(mod.BatchMixin, model.Base):
__tablename__ = 'testing_mybatch' __tablename__ = "testing_mybatch"
model.Base.metadata.create_all(bind=self.session.bind) model.Base.metadata.create_all(bind=self.session.bind)
metadata = sa.MetaData() metadata = sa.MetaData()
metadata.reflect(self.session.bind) metadata.reflect(self.session.bind)
self.assertIn('testing_mybatch', metadata.tables) self.assertIn("testing_mybatch", metadata.tables)
batch = MyBatch(id=42, uuid=_uuid.UUID('0675cdac-ffc9-7690-8000-6023de1c8cfd')) batch = MyBatch(
self.assertEqual(repr(batch), "MyBatch(uuid=UUID('0675cdac-ffc9-7690-8000-6023de1c8cfd'))") id=42, uuid=_uuid.UUID("0675cdac-ffc9-7690-8000-6023de1c8cfd")
)
self.assertEqual(
repr(batch),
"MyBatch(uuid=UUID('0675cdac-ffc9-7690-8000-6023de1c8cfd'))",
)
self.assertEqual(str(batch), "00000042") self.assertEqual(str(batch), "00000042")
self.assertEqual(batch.id_str, "00000042") self.assertEqual(batch.id_str, "00000042")
batch2 = MyBatch() batch2 = MyBatch()
self.assertIsNone(batch2.id_str) self.assertIsNone(batch2.id_str)
class TestBatchRowMixin(DataTestCase): class TestBatchRowMixin(DataTestCase):
def test_basic(self): def test_basic(self):
class MyBatch2(mod.BatchMixin, model.Base): class MyBatch2(mod.BatchMixin, model.Base):
__tablename__ = 'testing_mybatch2' __tablename__ = "testing_mybatch2"
class MyBatchRow2(mod.BatchRowMixin, model.Base): class MyBatchRow2(mod.BatchRowMixin, model.Base):
__tablename__ = 'testing_mybatch_row2' __tablename__ = "testing_mybatch_row2"
__batch_class__ = MyBatch2 __batch_class__ = MyBatch2
model.Base.metadata.create_all(bind=self.session.bind) model.Base.metadata.create_all(bind=self.session.bind)
metadata = sa.MetaData() metadata = sa.MetaData()
metadata.reflect(self.session.bind) metadata.reflect(self.session.bind)
self.assertIn('testing_mybatch2', metadata.tables) self.assertIn("testing_mybatch2", metadata.tables)
self.assertIn('testing_mybatch_row2', metadata.tables) self.assertIn("testing_mybatch_row2", metadata.tables)
# nb. this gives coverage but doesn't really test much # nb. this gives coverage but doesn't really test much
batch = MyBatch2(id=42, uuid=_uuid.UUID('0675cdac-ffc9-7690-8000-6023de1c8cfd')) batch = MyBatch2(
id=42, uuid=_uuid.UUID("0675cdac-ffc9-7690-8000-6023de1c8cfd")
)
row = MyBatchRow2() row = MyBatchRow2()
batch.rows.append(row) batch.rows.append(row)

View file

@ -27,109 +27,130 @@ else:
def write_file(self, filename, content): def write_file(self, filename, content):
path = os.path.join(self.tempdir, filename) path = os.path.join(self.tempdir, filename)
with open(path, 'wt') as f: with open(path, "wt") as f:
f.write(content) f.write(content)
return path return path
def test_no_default(self): def test_no_default(self):
myfile = self.write_file('my.conf', '') myfile = self.write_file("my.conf", "")
config = WuttaConfig([myfile]) config = WuttaConfig([myfile])
self.assertEqual(conf.get_engines(config, 'wuttadb'), {}) self.assertEqual(conf.get_engines(config, "wuttadb"), {})
def test_default(self): def test_default(self):
myfile = self.write_file('my.conf', """\ myfile = self.write_file(
"my.conf",
"""\
[wuttadb] [wuttadb]
default.url = sqlite:// default.url = sqlite://
""") """,
)
config = WuttaConfig([myfile]) config = WuttaConfig([myfile])
result = conf.get_engines(config, 'wuttadb') result = conf.get_engines(config, "wuttadb")
self.assertEqual(len(result), 1) self.assertEqual(len(result), 1)
self.assertIn('default', result) self.assertIn("default", result)
engine = result['default'] engine = result["default"]
self.assertEqual(engine.dialect.name, 'sqlite') self.assertEqual(engine.dialect.name, "sqlite")
def test_default_fallback(self): def test_default_fallback(self):
myfile = self.write_file('my.conf', """\ myfile = self.write_file(
"my.conf",
"""\
[wuttadb] [wuttadb]
sqlalchemy.url = sqlite:// sqlalchemy.url = sqlite://
""") """,
)
config = WuttaConfig([myfile]) config = WuttaConfig([myfile])
result = conf.get_engines(config, 'wuttadb') result = conf.get_engines(config, "wuttadb")
self.assertEqual(len(result), 1) self.assertEqual(len(result), 1)
self.assertIn('default', result) self.assertIn("default", result)
engine = result['default'] engine = result["default"]
self.assertEqual(engine.dialect.name, 'sqlite') self.assertEqual(engine.dialect.name, "sqlite")
def test_other(self): def test_other(self):
myfile = self.write_file('my.conf', """\ myfile = self.write_file(
"my.conf",
"""\
[otherdb] [otherdb]
keys = first, second keys = first, second
first.url = sqlite:// first.url = sqlite://
second.url = sqlite:// second.url = sqlite://
""") """,
)
config = WuttaConfig([myfile]) config = WuttaConfig([myfile])
result = conf.get_engines(config, 'otherdb') result = conf.get_engines(config, "otherdb")
self.assertEqual(len(result), 2) self.assertEqual(len(result), 2)
self.assertIn('first', result) self.assertIn("first", result)
self.assertIn('second', result) self.assertIn("second", result)
class TestGetSetting(TestCase): class TestGetSetting(TestCase):
def setUp(self): def setUp(self):
Session = orm.sessionmaker() Session = orm.sessionmaker()
engine = sa.create_engine('sqlite://') engine = sa.create_engine("sqlite://")
self.session = Session(bind=engine) self.session = Session(bind=engine)
self.session.execute(sa.text(""" self.session.execute(
sa.text(
"""
create table setting ( create table setting (
name varchar(255) primary key, name varchar(255) primary key,
value text value text
); );
""")) """
)
)
def tearDown(self): def tearDown(self):
self.session.close() self.session.close()
def test_basic_value(self): def test_basic_value(self):
self.session.execute(sa.text("insert into setting values ('foo', 'bar');")) self.session.execute(sa.text("insert into setting values ('foo', 'bar');"))
value = conf.get_setting(self.session, 'foo') value = conf.get_setting(self.session, "foo")
self.assertEqual(value, 'bar') self.assertEqual(value, "bar")
def test_missing_value(self): def test_missing_value(self):
value = conf.get_setting(self.session, 'foo') value = conf.get_setting(self.session, "foo")
self.assertIsNone(value) self.assertIsNone(value)
class TestMakeEngineFromConfig(TestCase): class TestMakeEngineFromConfig(TestCase):
def test_basic(self): def test_basic(self):
engine = conf.make_engine_from_config({ engine = conf.make_engine_from_config(
'sqlalchemy.url': 'sqlite://', {
}) "sqlalchemy.url": "sqlite://",
}
)
self.assertIsInstance(engine, Engine) self.assertIsInstance(engine, Engine)
def test_poolclass(self): def test_poolclass(self):
engine = conf.make_engine_from_config({ engine = conf.make_engine_from_config(
'sqlalchemy.url': 'sqlite://', {
}) "sqlalchemy.url": "sqlite://",
}
)
self.assertNotIsInstance(engine.pool, NullPool) self.assertNotIsInstance(engine.pool, NullPool)
engine = conf.make_engine_from_config({ engine = conf.make_engine_from_config(
'sqlalchemy.url': 'sqlite://', {
'sqlalchemy.poolclass': 'sqlalchemy.pool:NullPool', "sqlalchemy.url": "sqlite://",
}) "sqlalchemy.poolclass": "sqlalchemy.pool:NullPool",
}
)
self.assertIsInstance(engine.pool, NullPool) self.assertIsInstance(engine.pool, NullPool)
def test_pool_pre_ping(self): def test_pool_pre_ping(self):
engine = conf.make_engine_from_config({ engine = conf.make_engine_from_config(
'sqlalchemy.url': 'sqlite://', {
}) "sqlalchemy.url": "sqlite://",
}
)
self.assertFalse(engine.pool._pre_ping) self.assertFalse(engine.pool._pre_ping)
engine = conf.make_engine_from_config({ engine = conf.make_engine_from_config(
'sqlalchemy.url': 'sqlite://', {
'sqlalchemy.pool_pre_ping': 'true', "sqlalchemy.url": "sqlite://",
}) "sqlalchemy.pool_pre_ping": "true",
}
)
self.assertTrue(engine.pool._pre_ping) self.assertTrue(engine.pool._pre_ping)

View file

@ -22,20 +22,20 @@ else:
# counter table should not exist yet # counter table should not exist yet
metadata = sa.MetaData() metadata = sa.MetaData()
metadata.reflect(self.session.bind) metadata.reflect(self.session.bind)
self.assertNotIn('_counter_testing', metadata.tables) self.assertNotIn("_counter_testing", metadata.tables)
# using sqlite as backend, should make table for counter # using sqlite as backend, should make table for counter
value = handler.next_counter_value(self.session, 'testing') value = handler.next_counter_value(self.session, "testing")
self.assertEqual(value, 1) self.assertEqual(value, 1)
# counter table should exist now # counter table should exist now
metadata.reflect(self.session.bind) metadata.reflect(self.session.bind)
self.assertIn('_counter_testing', metadata.tables) self.assertIn("_counter_testing", metadata.tables)
# counter increments okay # counter increments okay
value = handler.next_counter_value(self.session, 'testing') value = handler.next_counter_value(self.session, "testing")
self.assertEqual(value, 2) self.assertEqual(value, 2)
value = handler.next_counter_value(self.session, 'testing') value = handler.next_counter_value(self.session, "testing")
self.assertEqual(value, 3) self.assertEqual(value, 3)
def test_next_counter_value_postgres(self): def test_next_counter_value_postgres(self):
@ -44,20 +44,20 @@ else:
# counter table should not exist # counter table should not exist
metadata = sa.MetaData() metadata = sa.MetaData()
metadata.reflect(self.session.bind) metadata.reflect(self.session.bind)
self.assertNotIn('_counter_testing', metadata.tables) self.assertNotIn("_counter_testing", metadata.tables)
# nb. we have to pretty much mock this out, can't really # nb. we have to pretty much mock this out, can't really
# test true sequence behavior for postgres since tests are # test true sequence behavior for postgres since tests are
# using sqlite backend. # using sqlite backend.
# using postgres as backend, should use "sequence" # using postgres as backend, should use "sequence"
with patch.object(handler, 'get_dialect', return_value='postgresql'): with patch.object(handler, "get_dialect", return_value="postgresql"):
with patch.object(self.session, 'execute') as execute: with patch.object(self.session, "execute") as execute:
execute.return_value.scalar.return_value = 1 execute.return_value.scalar.return_value = 1
value = handler.next_counter_value(self.session, 'testing') value = handler.next_counter_value(self.session, "testing")
self.assertEqual(value, 1) self.assertEqual(value, 1)
execute.return_value.scalar.assert_called_once_with() execute.return_value.scalar.assert_called_once_with()
# counter table should still not exist # counter table should still not exist
metadata.reflect(self.session.bind) metadata.reflect(self.session.bind)
self.assertNotIn('_counter_testing', metadata.tables) self.assertNotIn("_counter_testing", metadata.tables)

View file

@ -16,18 +16,16 @@ except ImportError:
pass pass
else: else:
class TestModelBase(TestCase): class TestModelBase(TestCase):
def test_dict_behavior(self): def test_dict_behavior(self):
setting = Setting() setting = Setting()
self.assertEqual(list(iter(setting)), [('name', None), ('value', None)]) self.assertEqual(list(iter(setting)), [("name", None), ("value", None)])
self.assertIsNone(setting.name) self.assertIsNone(setting.name)
self.assertIsNone(setting['name']) self.assertIsNone(setting["name"])
setting.name = 'foo' setting.name = "foo"
self.assertEqual(setting['name'], 'foo') self.assertEqual(setting["name"], "foo")
self.assertRaises(KeyError, lambda: setting['notfound']) self.assertRaises(KeyError, lambda: setting["notfound"])
class TestUUID(TestCase): class TestUUID(TestCase):
@ -39,14 +37,14 @@ else:
# coverage at least.. # coverage at least..
# postgres # postgres
dialect.name = 'postgresql' dialect.name = "postgresql"
dialect.type_descriptor.return_value = 42 dialect.type_descriptor.return_value = 42
result = typ.load_dialect_impl(dialect) result = typ.load_dialect_impl(dialect)
self.assertTrue(dialect.type_descriptor.called) self.assertTrue(dialect.type_descriptor.called)
self.assertEqual(result, 42) self.assertEqual(result, 42)
# other # other
dialect.name = 'mysql' dialect.name = "mysql"
dialect.type_descriptor.return_value = 43 dialect.type_descriptor.return_value = 43
dialect.type_descriptor.reset_mock() dialect.type_descriptor.reset_mock()
result = typ.load_dialect_impl(dialect) result = typ.load_dialect_impl(dialect)
@ -56,7 +54,7 @@ else:
def test_process_bind_param_postgres(self): def test_process_bind_param_postgres(self):
typ = mod.UUID() typ = mod.UUID()
dialect = MagicMock() dialect = MagicMock()
dialect.name = 'postgresql' dialect.name = "postgresql"
# null # null
result = typ.process_bind_param(None, dialect) result = typ.process_bind_param(None, dialect)
@ -75,7 +73,7 @@ else:
def test_process_bind_param_other(self): def test_process_bind_param_other(self):
typ = mod.UUID() typ = mod.UUID()
dialect = MagicMock() dialect = MagicMock()
dialect.name = 'mysql' dialect.name = "mysql"
# null # null
result = typ.process_bind_param(None, dialect) result = typ.process_bind_param(None, dialect)
@ -110,7 +108,6 @@ else:
result = typ.process_result_value(uuid_true, dialect) result = typ.process_result_value(uuid_true, dialect)
self.assertIs(result, uuid_true) self.assertIs(result, uuid_true)
class TestUUIDColumn(TestCase): class TestUUIDColumn(TestCase):
def test_basic(self): def test_basic(self):
@ -118,24 +115,22 @@ else:
self.assertIsInstance(column, sa.Column) self.assertIsInstance(column, sa.Column)
self.assertIsInstance(column.type, mod.UUID) self.assertIsInstance(column.type, mod.UUID)
class TestUUIDFKColumn(TestCase): class TestUUIDFKColumn(TestCase):
def test_basic(self): def test_basic(self):
column = mod.uuid_fk_column('foo.bar') column = mod.uuid_fk_column("foo.bar")
self.assertIsInstance(column, sa.Column) self.assertIsInstance(column, sa.Column)
self.assertIsInstance(column.type, mod.UUID) self.assertIsInstance(column.type, mod.UUID)
class TestMakeTopoSortkey(DataTestCase): class TestMakeTopoSortkey(DataTestCase):
def test_basic(self): def test_basic(self):
model = self.app.model model = self.app.model
sortkey = mod.make_topo_sortkey(model) sortkey = mod.make_topo_sortkey(model)
original = ['User', 'Person', 'UserRole', 'Role'] original = ["User", "Person", "UserRole", "Role"]
# models are sorted so dependants come later # models are sorted so dependants come later
result = sorted(original, key=sortkey) result = sorted(original, key=sortkey)
self.assertTrue(result.index('Role') < result.index('UserRole')) self.assertTrue(result.index("Role") < result.index("UserRole"))
self.assertTrue(result.index('User') < result.index('UserRole')) self.assertTrue(result.index("User") < result.index("UserRole"))
self.assertTrue(result.index('Person') < result.index('User')) self.assertTrue(result.index("Person") < result.index("User"))

View file

@ -26,6 +26,7 @@ from wuttjamaican.batch import BatchHandler
class MockBatchHandler(BatchHandler): class MockBatchHandler(BatchHandler):
pass pass
class AnotherBatchHandler(BatchHandler): class AnotherBatchHandler(BatchHandler):
pass pass
@ -34,14 +35,14 @@ class TestAppHandler(FileTestCase):
def setUp(self): def setUp(self):
self.setup_files() self.setup_files()
self.config = WuttaConfig(appname='wuttatest') self.config = WuttaConfig(appname="wuttatest")
self.app = mod.AppHandler(self.config) self.app = mod.AppHandler(self.config)
self.config.app = self.app self.config.app = self.app
def test_init(self): def test_init(self):
self.assertIs(self.app.config, self.config) self.assertIs(self.app.config, self.config)
self.assertEqual(self.app.handlers, {}) self.assertEqual(self.app.handlers, {})
self.assertEqual(self.app.appname, 'wuttatest') self.assertEqual(self.app.appname, "wuttatest")
def test_get_enum(self): def test_get_enum(self):
self.assertIs(self.app.get_enum(), wuttjamaican.enum) self.assertIs(self.app.get_enum(), wuttjamaican.enum)
@ -50,48 +51,50 @@ class TestAppHandler(FileTestCase):
# just confirm the method works on a basic level; the # just confirm the method works on a basic level; the
# underlying function is tested elsewhere # underlying function is tested elsewhere
obj = self.app.load_object('wuttjamaican.util:UNSPECIFIED') obj = self.app.load_object("wuttjamaican.util:UNSPECIFIED")
self.assertIs(obj, UNSPECIFIED) self.assertIs(obj, UNSPECIFIED)
def test_get_appdir(self): def test_get_appdir(self):
mockdir = self.mkdir('mockdir') mockdir = self.mkdir("mockdir")
# default appdir # default appdir
with patch.object(sys, 'prefix', new=mockdir): with patch.object(sys, "prefix", new=mockdir):
# default is returned by default # default is returned by default
appdir = self.app.get_appdir() appdir = self.app.get_appdir()
self.assertEqual(appdir, os.path.join(mockdir, 'app')) self.assertEqual(appdir, os.path.join(mockdir, "app"))
# but not if caller wants config only # but not if caller wants config only
appdir = self.app.get_appdir(configured_only=True) appdir = self.app.get_appdir(configured_only=True)
self.assertIsNone(appdir) self.assertIsNone(appdir)
# also, cannot create if appdir path not known # also, cannot create if appdir path not known
self.assertRaises(ValueError, self.app.get_appdir, configured_only=True, create=True) self.assertRaises(
ValueError, self.app.get_appdir, configured_only=True, create=True
)
# configured appdir # configured appdir
self.config.setdefault('wuttatest.appdir', mockdir) self.config.setdefault("wuttatest.appdir", mockdir)
appdir = self.app.get_appdir() appdir = self.app.get_appdir()
self.assertEqual(appdir, mockdir) self.assertEqual(appdir, mockdir)
# appdir w/ subpath # appdir w/ subpath
appdir = self.app.get_appdir('foo', 'bar') appdir = self.app.get_appdir("foo", "bar")
self.assertEqual(appdir, os.path.join(mockdir, 'foo', 'bar')) self.assertEqual(appdir, os.path.join(mockdir, "foo", "bar"))
# subpath is created # subpath is created
self.assertEqual(len(os.listdir(mockdir)), 0) self.assertEqual(len(os.listdir(mockdir)), 0)
appdir = self.app.get_appdir('foo', 'bar', create=True) appdir = self.app.get_appdir("foo", "bar", create=True)
self.assertEqual(appdir, os.path.join(mockdir, 'foo', 'bar')) self.assertEqual(appdir, os.path.join(mockdir, "foo", "bar"))
self.assertEqual(os.listdir(mockdir), ['foo']) self.assertEqual(os.listdir(mockdir), ["foo"])
self.assertEqual(os.listdir(os.path.join(mockdir, 'foo')), ['bar']) self.assertEqual(os.listdir(os.path.join(mockdir, "foo")), ["bar"])
def test_make_appdir(self): def test_make_appdir(self):
# appdir is created, and 3 subfolders added by default # appdir is created, and 3 subfolders added by default
tempdir = tempfile.mkdtemp() tempdir = tempfile.mkdtemp()
appdir = os.path.join(tempdir, 'app') appdir = os.path.join(tempdir, "app")
self.assertFalse(os.path.exists(appdir)) self.assertFalse(os.path.exists(appdir))
self.app.make_appdir(appdir) self.app.make_appdir(appdir)
self.assertTrue(os.path.exists(appdir)) self.assertTrue(os.path.exists(appdir))
@ -107,23 +110,30 @@ class TestAppHandler(FileTestCase):
shutil.rmtree(tempdir) shutil.rmtree(tempdir)
def test_render_mako_template(self): def test_render_mako_template(self):
output_conf = self.write_file('output.conf', '') output_conf = self.write_file("output.conf", "")
template = Template("""\ template = Template(
"""\
[wutta] [wutta]
app_title = WuttaTest app_title = WuttaTest
""") """
)
output = self.app.render_mako_template(template, {}, output_path=output_conf) output = self.app.render_mako_template(template, {}, output_path=output_conf)
self.assertEqual(output, """\ self.assertEqual(
output,
"""\
[wutta] [wutta]
app_title = WuttaTest app_title = WuttaTest
""") """,
)
with open(output_conf, 'rt') as f: with open(output_conf, "rt") as f:
self.assertEqual(f.read(), output) self.assertEqual(f.read(), output)
def test_resource_path(self): def test_resource_path(self):
result = self.app.resource_path('wuttjamaican:templates') result = self.app.resource_path("wuttjamaican:templates")
self.assertEqual(result, os.path.join(os.path.dirname(mod.__file__), 'templates')) self.assertEqual(
result, os.path.join(os.path.dirname(mod.__file__), "templates")
)
def test_make_session(self): def test_make_session(self):
try: try:
@ -138,11 +148,12 @@ app_title = WuttaTest
short_session = MagicMock() short_session = MagicMock()
mockdb = MagicMock(short_session=short_session) mockdb = MagicMock(short_session=short_session)
with patch.dict('sys.modules', **{'wuttjamaican.db': mockdb}): with patch.dict("sys.modules", **{"wuttjamaican.db": mockdb}):
with self.app.short_session(foo='bar') as s: with self.app.short_session(foo="bar") as s:
short_session.assert_called_once_with( short_session.assert_called_once_with(
foo='bar', factory=self.app.make_session) foo="bar", factory=self.app.make_session
)
def test_get_setting(self): def test_get_setting(self):
try: try:
@ -152,22 +163,26 @@ app_title = WuttaTest
pytest.skip("test is not relevant without sqlalchemy") pytest.skip("test is not relevant without sqlalchemy")
Session = orm.sessionmaker() Session = orm.sessionmaker()
engine = sa.create_engine('sqlite://') engine = sa.create_engine("sqlite://")
session = Session(bind=engine) session = Session(bind=engine)
session.execute(sa.text(""" session.execute(
sa.text(
"""
create table setting ( create table setting (
name varchar(255) primary key, name varchar(255) primary key,
value text value text
); );
""")) """
)
)
session.commit() session.commit()
value = self.app.get_setting(session, 'foo') value = self.app.get_setting(session, "foo")
self.assertIsNone(value) self.assertIsNone(value)
session.execute(sa.text("insert into setting values ('foo', 'bar');")) session.execute(sa.text("insert into setting values ('foo', 'bar');"))
value = self.app.get_setting(session, 'foo') value = self.app.get_setting(session, "foo")
self.assertEqual(value, 'bar') self.assertEqual(value, "bar")
def test_save_setting(self): def test_save_setting(self):
try: try:
@ -177,25 +192,29 @@ app_title = WuttaTest
pytest.skip("test is not relevant without sqlalchemy") pytest.skip("test is not relevant without sqlalchemy")
Session = orm.sessionmaker() Session = orm.sessionmaker()
engine = sa.create_engine('sqlite://') engine = sa.create_engine("sqlite://")
session = Session(bind=engine) session = Session(bind=engine)
session.execute(sa.text(""" session.execute(
sa.text(
"""
create table setting ( create table setting (
name varchar(255) primary key, name varchar(255) primary key,
value text value text
); );
""")) """
)
)
session.commit() session.commit()
# value null by default # value null by default
value = self.app.get_setting(session, 'foo') value = self.app.get_setting(session, "foo")
self.assertIsNone(value) self.assertIsNone(value)
# unless we save a value # unless we save a value
self.app.save_setting(session, 'foo', '1') self.app.save_setting(session, "foo", "1")
session.commit() session.commit()
value = self.app.get_setting(session, 'foo') value = self.app.get_setting(session, "foo")
self.assertEqual(value, '1') self.assertEqual(value, "1")
def test_delete_setting(self): def test_delete_setting(self):
try: try:
@ -205,43 +224,48 @@ app_title = WuttaTest
pytest.skip("test is not relevant without sqlalchemy") pytest.skip("test is not relevant without sqlalchemy")
Session = orm.sessionmaker() Session = orm.sessionmaker()
engine = sa.create_engine('sqlite://') engine = sa.create_engine("sqlite://")
session = Session(bind=engine) session = Session(bind=engine)
session.execute(sa.text(""" session.execute(
sa.text(
"""
create table setting ( create table setting (
name varchar(255) primary key, name varchar(255) primary key,
value text value text
); );
""")) """
)
)
session.commit() session.commit()
# value null by default # value null by default
value = self.app.get_setting(session, 'foo') value = self.app.get_setting(session, "foo")
self.assertIsNone(value) self.assertIsNone(value)
# unless we save a value # unless we save a value
self.app.save_setting(session, 'foo', '1') self.app.save_setting(session, "foo", "1")
session.commit() session.commit()
value = self.app.get_setting(session, 'foo') value = self.app.get_setting(session, "foo")
self.assertEqual(value, '1') self.assertEqual(value, "1")
# but then if we delete it, should be null again # but then if we delete it, should be null again
self.app.delete_setting(session, 'foo') self.app.delete_setting(session, "foo")
session.commit() session.commit()
value = self.app.get_setting(session, 'foo') value = self.app.get_setting(session, "foo")
self.assertIsNone(value) self.assertIsNone(value)
def test_continuum_is_enabled(self): def test_continuum_is_enabled(self):
# false by default # false by default
with patch.object(self.app, 'providers', new={}): with patch.object(self.app, "providers", new={}):
self.assertFalse(self.app.continuum_is_enabled()) self.assertFalse(self.app.continuum_is_enabled())
# but "any" provider technically could enable it... # but "any" provider technically could enable it...
class MockProvider: class MockProvider:
def continuum_is_enabled(self): def continuum_is_enabled(self):
return True return True
with patch.object(self.app, 'providers', new={'mock': MockProvider()}):
with patch.object(self.app, "providers", new={"mock": MockProvider()}):
self.assertTrue(self.app.continuum_is_enabled()) self.assertTrue(self.app.continuum_is_enabled())
def test_model(self): def test_model(self):
@ -250,7 +274,7 @@ app_title = WuttaTest
except ImportError: except ImportError:
pytest.skip("test not relevant without sqlalchemy") pytest.skip("test not relevant without sqlalchemy")
self.assertNotIn('model', self.app.__dict__) self.assertNotIn("model", self.app.__dict__)
self.assertIs(self.app.model, model) self.assertIs(self.app.model, model)
def test_get_model(self): def test_get_model(self):
@ -262,20 +286,20 @@ app_title = WuttaTest
self.assertIs(self.app.get_model(), model) self.assertIs(self.app.get_model(), model)
def test_get_title(self): def test_get_title(self):
self.assertEqual(self.app.get_title(), 'WuttJamaican') self.assertEqual(self.app.get_title(), "WuttJamaican")
def test_get_node_title(self): def test_get_node_title(self):
# default # default
self.assertEqual(self.app.get_node_title(), 'WuttJamaican') self.assertEqual(self.app.get_node_title(), "WuttJamaican")
# will fallback to app title # will fallback to app title
self.config.setdefault('wuttatest.app_title', "WuttaTest") self.config.setdefault("wuttatest.app_title", "WuttaTest")
self.assertEqual(self.app.get_node_title(), 'WuttaTest') self.assertEqual(self.app.get_node_title(), "WuttaTest")
# will read from config # will read from config
self.config.setdefault('wuttatest.node_title', "WuttaNode") self.config.setdefault("wuttatest.node_title", "WuttaNode")
self.assertEqual(self.app.get_node_title(), 'WuttaNode') self.assertEqual(self.app.get_node_title(), "WuttaNode")
def test_get_node_type(self): def test_get_node_type(self):
@ -283,8 +307,8 @@ app_title = WuttaTest
self.assertIsNone(self.app.get_node_type()) self.assertIsNone(self.app.get_node_type())
# will read from config # will read from config
self.config.setdefault('wuttatest.node_type', 'warehouse') self.config.setdefault("wuttatest.node_type", "warehouse")
self.assertEqual(self.app.get_node_type(), 'warehouse') self.assertEqual(self.app.get_node_type(), "warehouse")
def test_get_distribution(self): def test_get_distribution(self):
@ -296,16 +320,16 @@ app_title = WuttaTest
# works with "non-native" objects # works with "non-native" objects
query = Query({}) query = Query({})
dist = self.app.get_distribution(query) dist = self.app.get_distribution(query)
self.assertEqual(dist, 'SQLAlchemy') self.assertEqual(dist, "SQLAlchemy")
# can override dist via config # can override dist via config
self.config.setdefault('wuttatest.app_dist', 'importlib_metadata') self.config.setdefault("wuttatest.app_dist", "importlib_metadata")
dist = self.app.get_distribution() dist = self.app.get_distribution()
self.assertEqual(dist, 'importlib_metadata') self.assertEqual(dist, "importlib_metadata")
# but the provided object takes precedence # but the provided object takes precedence
dist = self.app.get_distribution(query) dist = self.app.get_distribution(query)
self.assertEqual(dist, 'SQLAlchemy') self.assertEqual(dist, "SQLAlchemy")
def test_get_distribution_pre_python_3_10(self): def test_get_distribution_pre_python_3_10(self):
@ -318,30 +342,32 @@ app_title = WuttaTest
importlib_metadata = MagicMock() importlib_metadata = MagicMock()
importlib_metadata.packages_distributions = MagicMock( importlib_metadata.packages_distributions = MagicMock(
return_value={ return_value={
'wuttjamaican': ['WuttJamaican'], "wuttjamaican": ["WuttJamaican"],
'config': ['python-configuration'], "config": ["python-configuration"],
}) }
)
orig_import = __import__ orig_import = __import__
def mock_import(name, *args, **kwargs): def mock_import(name, *args, **kwargs):
if name == 'importlib.metadata': if name == "importlib.metadata":
raise ImportError raise ImportError
if name == 'importlib_metadata': if name == "importlib_metadata":
return importlib_metadata return importlib_metadata
return orig_import(name, *args, **kwargs) return orig_import(name, *args, **kwargs)
with patch('builtins.__import__', side_effect=mock_import): with patch("builtins.__import__", side_effect=mock_import):
# default should always be WuttJamaican (right..?) # default should always be WuttJamaican (right..?)
dist = self.app.get_distribution() dist = self.app.get_distribution()
self.assertEqual(dist, 'WuttJamaican') self.assertEqual(dist, "WuttJamaican")
# also works with "non-native" objects # also works with "non-native" objects
from config import Configuration from config import Configuration
config = Configuration({}) config = Configuration({})
dist = self.app.get_distribution(config) dist = self.app.get_distribution(config)
self.assertEqual(dist, 'python-configuration') self.assertEqual(dist, "python-configuration")
# hacky sort of test, just in case we can't deduce the # hacky sort of test, just in case we can't deduce the
# package dist based on the obj - easy enough since we # package dist based on the obj - easy enough since we
@ -350,17 +376,17 @@ app_title = WuttaTest
self.assertIsNone(dist) self.assertIsNone(dist)
# can override dist via config # can override dist via config
self.config.setdefault('wuttatest.app_dist', 'importlib_metadata') self.config.setdefault("wuttatest.app_dist", "importlib_metadata")
dist = self.app.get_distribution() dist = self.app.get_distribution()
self.assertEqual(dist, 'importlib_metadata') self.assertEqual(dist, "importlib_metadata")
# but the provided object takes precedence # but the provided object takes precedence
dist = self.app.get_distribution(config) dist = self.app.get_distribution(config)
self.assertEqual(dist, 'python-configuration') self.assertEqual(dist, "python-configuration")
# hacky test again, this time config override should win # hacky test again, this time config override should win
dist = self.app.get_distribution(42) dist = self.app.get_distribution(42)
self.assertEqual(dist, 'importlib_metadata') self.assertEqual(dist, "importlib_metadata")
def test_get_version(self): def test_get_version(self):
from importlib.metadata import version from importlib.metadata import version
@ -373,31 +399,31 @@ app_title = WuttaTest
# works with "non-native" objects # works with "non-native" objects
query = Query({}) query = Query({})
ver = self.app.get_version(obj=query) ver = self.app.get_version(obj=query)
self.assertEqual(ver, version('SQLAlchemy')) self.assertEqual(ver, version("SQLAlchemy"))
# random object will not yield a dist nor version # random object will not yield a dist nor version
ver = self.app.get_version(obj=42) ver = self.app.get_version(obj=42)
self.assertIsNone(ver) self.assertIsNone(ver)
# can override dist via config # can override dist via config
self.config.setdefault('wuttatest.app_dist', 'python-configuration') self.config.setdefault("wuttatest.app_dist", "python-configuration")
ver = self.app.get_version() ver = self.app.get_version()
self.assertEqual(ver, version('python-configuration')) self.assertEqual(ver, version("python-configuration"))
# but the provided object takes precedence # but the provided object takes precedence
ver = self.app.get_version(obj=query) ver = self.app.get_version(obj=query)
self.assertEqual(ver, version('SQLAlchemy')) self.assertEqual(ver, version("SQLAlchemy"))
# can also specify the dist # can also specify the dist
ver = self.app.get_version(dist='passlib') ver = self.app.get_version(dist="passlib")
self.assertEqual(ver, version('passlib')) self.assertEqual(ver, version("passlib"))
def test_make_title(self): def test_make_title(self):
text = self.app.make_title('foo_bar') text = self.app.make_title("foo_bar")
self.assertEqual(text, "Foo Bar") self.assertEqual(text, "Foo Bar")
def test_make_full_name(self): def test_make_full_name(self):
name = self.app.make_full_name('Fred', '', 'Flintstone', '') name = self.app.make_full_name("Fred", "", "Flintstone", "")
self.assertEqual(name, "Fred Flintstone") self.assertEqual(name, "Fred Flintstone")
def test_make_uuid(self): def test_make_uuid(self):
@ -414,12 +440,10 @@ app_title = WuttaTest
pass pass
# with progress # with progress
self.app.progress_loop(act, [1, 2, 3], ProgressBase, self.app.progress_loop(act, [1, 2, 3], ProgressBase, message="whatever")
message="whatever")
# without progress # without progress
self.app.progress_loop(act, [1, 2, 3], None, self.app.progress_loop(act, [1, 2, 3], None, message="whatever")
message="whatever")
def test_get_session(self): def test_get_session(self):
try: try:
@ -433,7 +457,7 @@ app_title = WuttaTest
self.assertIsNone(self.app.get_session(user)) self.assertIsNone(self.app.get_session(user))
Session = orm.sessionmaker() Session = orm.sessionmaker()
engine = sa.create_engine('sqlite://') engine = sa.create_engine("sqlite://")
mysession = Session(bind=engine) mysession = Session(bind=engine)
mysession.add(user) mysession.add(user)
session = self.app.get_session(user) session = self.app.get_session(user)
@ -453,39 +477,39 @@ app_title = WuttaTest
def test_render_currency(self): def test_render_currency(self):
# null # null
self.assertEqual(self.app.render_currency(None), '') self.assertEqual(self.app.render_currency(None), "")
# basic decimal example # basic decimal example
value = decimal.Decimal('42.00') value = decimal.Decimal("42.00")
self.assertEqual(self.app.render_currency(value), '$42.00') self.assertEqual(self.app.render_currency(value), "$42.00")
# basic float example # basic float example
value = 42.00 value = 42.00
self.assertEqual(self.app.render_currency(value), '$42.00') self.assertEqual(self.app.render_currency(value), "$42.00")
# decimal places will be rounded # decimal places will be rounded
value = decimal.Decimal('42.12345') value = decimal.Decimal("42.12345")
self.assertEqual(self.app.render_currency(value), '$42.12') self.assertEqual(self.app.render_currency(value), "$42.12")
# but we can declare the scale # but we can declare the scale
value = decimal.Decimal('42.12345') value = decimal.Decimal("42.12345")
self.assertEqual(self.app.render_currency(value, scale=4), '$42.1234') self.assertEqual(self.app.render_currency(value, scale=4), "$42.1234")
# negative numbers get parens # negative numbers get parens
value = decimal.Decimal('-42.42') value = decimal.Decimal("-42.42")
self.assertEqual(self.app.render_currency(value), '($42.42)') self.assertEqual(self.app.render_currency(value), "($42.42)")
def test_render_date(self): def test_render_date(self):
self.assertEqual(self.app.render_date(None), '') self.assertEqual(self.app.render_date(None), "")
dt = datetime.date(2024, 12, 11) dt = datetime.date(2024, 12, 11)
self.assertEqual(self.app.render_date(dt), '2024-12-11') self.assertEqual(self.app.render_date(dt), "2024-12-11")
def test_render_datetime(self): def test_render_datetime(self):
self.assertEqual(self.app.render_datetime(None), '') self.assertEqual(self.app.render_datetime(None), "")
dt = datetime.datetime(2024, 12, 11, 8, 30, tzinfo=datetime.timezone.utc) dt = datetime.datetime(2024, 12, 11, 8, 30, tzinfo=datetime.timezone.utc)
self.assertEqual(self.app.render_datetime(dt), '2024-12-11 08:30+0000') self.assertEqual(self.app.render_datetime(dt), "2024-12-11 08:30+0000")
def test_render_error(self): def test_render_error(self):
@ -509,15 +533,15 @@ app_title = WuttaTest
self.assertEqual(self.app.render_percent(None), "") self.assertEqual(self.app.render_percent(None), "")
# typical # typical
self.assertEqual(self.app.render_percent(12.3419), '12.34 %') self.assertEqual(self.app.render_percent(12.3419), "12.34 %")
# more decimal places # more decimal places
self.assertEqual(self.app.render_percent(12.3419, decimals=3), '12.342 %') self.assertEqual(self.app.render_percent(12.3419, decimals=3), "12.342 %")
self.assertEqual(self.app.render_percent(12.3419, decimals=4), '12.3419 %') self.assertEqual(self.app.render_percent(12.3419, decimals=4), "12.3419 %")
# negative # negative
self.assertEqual(self.app.render_percent(-12.3419), '(12.34 %)') self.assertEqual(self.app.render_percent(-12.3419), "(12.34 %)")
self.assertEqual(self.app.render_percent(-12.3419, decimals=3), '(12.342 %)') self.assertEqual(self.app.render_percent(-12.3419, decimals=3), "(12.342 %)")
def test_render_quantity(self): def test_render_quantity(self):
@ -525,11 +549,11 @@ app_title = WuttaTest
self.assertEqual(self.app.render_quantity(None), "") self.assertEqual(self.app.render_quantity(None), "")
# integer decimals become integers # integer decimals become integers
value = decimal.Decimal('1.000') value = decimal.Decimal("1.000")
self.assertEqual(self.app.render_quantity(value), "1") self.assertEqual(self.app.render_quantity(value), "1")
# but decimal places are preserved # but decimal places are preserved
value = decimal.Decimal('1.234') value = decimal.Decimal("1.234")
self.assertEqual(self.app.render_quantity(value), "1.234") self.assertEqual(self.app.render_quantity(value), "1.234")
# zero can be empty string # zero can be empty string
@ -537,20 +561,20 @@ app_title = WuttaTest
self.assertEqual(self.app.render_quantity(0, empty_zero=True), "") self.assertEqual(self.app.render_quantity(0, empty_zero=True), "")
def test_render_time_ago(self): def test_render_time_ago(self):
with patch.object(mod, 'humanize') as humanize: with patch.object(mod, "humanize") as humanize:
humanize.naturaltime.return_value = 'now' humanize.naturaltime.return_value = "now"
now = datetime.datetime.now() now = datetime.datetime.now()
result = self.app.render_time_ago(now) result = self.app.render_time_ago(now)
self.assertEqual(result, 'now') self.assertEqual(result, "now")
humanize.naturaltime.assert_called_once_with(now) humanize.naturaltime.assert_called_once_with(now)
def test_get_person(self): def test_get_person(self):
people = self.app.get_people_handler() people = self.app.get_people_handler()
with patch.object(people, 'get_person') as get_person: with patch.object(people, "get_person") as get_person:
get_person.return_value = 'foo' get_person.return_value = "foo"
person = self.app.get_person('bar') person = self.app.get_person("bar")
get_person.assert_called_once_with('bar') get_person.assert_called_once_with("bar")
self.assertEqual(person, 'foo') self.assertEqual(person, "foo")
def test_get_auth_handler(self): def test_get_auth_handler(self):
from wuttjamaican.auth import AuthHandler from wuttjamaican.auth import AuthHandler
@ -561,55 +585,80 @@ app_title = WuttaTest
def test_get_batch_handler(self): def test_get_batch_handler(self):
# error if handler not found # error if handler not found
self.assertRaises(KeyError, self.app.get_batch_handler, 'CannotFindMe!') self.assertRaises(KeyError, self.app.get_batch_handler, "CannotFindMe!")
# caller can specify default # caller can specify default
handler = self.app.get_batch_handler('foo', default='wuttjamaican.batch:BatchHandler') handler = self.app.get_batch_handler(
"foo", default="wuttjamaican.batch:BatchHandler"
)
self.assertIsInstance(handler, BatchHandler) self.assertIsInstance(handler, BatchHandler)
# default can be configured # default can be configured
self.config.setdefault('wuttatest.batch.foo.handler.default_spec', self.config.setdefault(
'wuttjamaican.batch:BatchHandler') "wuttatest.batch.foo.handler.default_spec",
handler = self.app.get_batch_handler('foo') "wuttjamaican.batch:BatchHandler",
)
handler = self.app.get_batch_handler("foo")
self.assertIsInstance(handler, BatchHandler) self.assertIsInstance(handler, BatchHandler)
# preference can be configured # preference can be configured
self.config.setdefault('wuttatest.batch.foo.handler.spec', self.config.setdefault(
'tests.test_app:MockBatchHandler') "wuttatest.batch.foo.handler.spec", "tests.test_app:MockBatchHandler"
handler = self.app.get_batch_handler('foo') )
handler = self.app.get_batch_handler("foo")
self.assertIsInstance(handler, MockBatchHandler) self.assertIsInstance(handler, MockBatchHandler)
def test_get_batch_handler_specs(self): def test_get_batch_handler_specs(self):
# empty by default # empty by default
specs = self.app.get_batch_handler_specs('foo') specs = self.app.get_batch_handler_specs("foo")
self.assertEqual(specs, []) self.assertEqual(specs, [])
# caller can specify default as string # caller can specify default as string
specs = self.app.get_batch_handler_specs('foo', default='wuttjamaican.batch:BatchHandler') specs = self.app.get_batch_handler_specs(
self.assertEqual(specs, ['wuttjamaican.batch:BatchHandler']) "foo", default="wuttjamaican.batch:BatchHandler"
)
self.assertEqual(specs, ["wuttjamaican.batch:BatchHandler"])
# caller can specify default as list # caller can specify default as list
specs = self.app.get_batch_handler_specs('foo', default=['wuttjamaican.batch:BatchHandler', specs = self.app.get_batch_handler_specs(
'tests.test_app:MockBatchHandler']) "foo",
self.assertEqual(specs, ['wuttjamaican.batch:BatchHandler', default=[
'tests.test_app:MockBatchHandler']) "wuttjamaican.batch:BatchHandler",
"tests.test_app:MockBatchHandler",
],
)
self.assertEqual(
specs,
["wuttjamaican.batch:BatchHandler", "tests.test_app:MockBatchHandler"],
)
# default can be configured # default can be configured
self.config.setdefault('wuttatest.batch.foo.handler.default_spec', self.config.setdefault(
'wuttjamaican.batch:BatchHandler') "wuttatest.batch.foo.handler.default_spec",
specs = self.app.get_batch_handler_specs('foo') "wuttjamaican.batch:BatchHandler",
self.assertEqual(specs, ['wuttjamaican.batch:BatchHandler']) )
specs = self.app.get_batch_handler_specs("foo")
self.assertEqual(specs, ["wuttjamaican.batch:BatchHandler"])
# the rest come from entry points # the rest come from entry points
with patch.object(mod, 'load_entry_points', return_value={ with patch.object(
'mock': MockBatchHandler, mod,
'another': AnotherBatchHandler, "load_entry_points",
}): return_value={
specs = self.app.get_batch_handler_specs('foo') "mock": MockBatchHandler,
self.assertEqual(specs, ['wuttjamaican.batch:BatchHandler', "another": AnotherBatchHandler,
'tests.test_app:AnotherBatchHandler', },
'tests.test_app:MockBatchHandler']) ):
specs = self.app.get_batch_handler_specs("foo")
self.assertEqual(
specs,
[
"wuttjamaican.batch:BatchHandler",
"tests.test_app:AnotherBatchHandler",
"tests.test_app:MockBatchHandler",
],
)
def test_get_db_handler(self): def test_get_db_handler(self):
try: try:
@ -653,15 +702,15 @@ app_title = WuttaTest
def test_send_email(self): def test_send_email(self):
from wuttjamaican.email import EmailHandler from wuttjamaican.email import EmailHandler
with patch.object(EmailHandler, 'send_email') as send_email: with patch.object(EmailHandler, "send_email") as send_email:
self.app.send_email('foo') self.app.send_email("foo")
send_email.assert_called_once_with('foo') send_email.assert_called_once_with("foo")
class TestAppProvider(TestCase): class TestAppProvider(TestCase):
def setUp(self): def setUp(self):
self.config = WuttaConfig(appname='wuttatest') self.config = WuttaConfig(appname="wuttatest")
self.app = mod.AppHandler(self.config) self.app = mod.AppHandler(self.config)
self.config._app = self.app self.config._app = self.app
@ -671,11 +720,11 @@ class TestAppProvider(TestCase):
provider = mod.AppProvider(self.config) provider = mod.AppProvider(self.config)
self.assertIs(provider.config, self.config) self.assertIs(provider.config, self.config)
self.assertIs(provider.app, self.app) self.assertIs(provider.app, self.app)
self.assertEqual(provider.appname, 'wuttatest') self.assertEqual(provider.appname, "wuttatest")
# but can pass app handler instead # but can pass app handler instead
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
provider = mod.AppProvider(self.app) provider = mod.AppProvider(self.app)
self.assertIs(provider.config, self.config) self.assertIs(provider.config, self.config)
self.assertIs(provider.app, self.app) self.assertIs(provider.app, self.app)
@ -686,17 +735,17 @@ class TestAppProvider(TestCase):
pass pass
# nb. we specify *classes* here # nb. we specify *classes* here
fake_providers = {'fake': FakeProvider} fake_providers = {"fake": FakeProvider}
with patch('wuttjamaican.app.load_entry_points') as load_entry_points: with patch("wuttjamaican.app.load_entry_points") as load_entry_points:
load_entry_points.return_value = fake_providers load_entry_points.return_value = fake_providers
# sanity check, we get *instances* back from this # sanity check, we get *instances* back from this
providers = self.app.get_all_providers() providers = self.app.get_all_providers()
load_entry_points.assert_called_once_with('wutta.app.providers') load_entry_points.assert_called_once_with("wutta.app.providers")
self.assertEqual(len(providers), 1) self.assertEqual(len(providers), 1)
self.assertIn('fake', providers) self.assertIn("fake", providers)
self.assertIsInstance(providers['fake'], FakeProvider) self.assertIsInstance(providers["fake"], FakeProvider)
def test_hasattr(self): def test_hasattr(self):
@ -704,15 +753,15 @@ class TestAppProvider(TestCase):
def fake_foo(self): def fake_foo(self):
pass pass
self.app.providers = {'fake': FakeProvider(self.config)} self.app.providers = {"fake": FakeProvider(self.config)}
self.assertTrue(hasattr(self.app, 'fake_foo')) self.assertTrue(hasattr(self.app, "fake_foo"))
self.assertFalse(hasattr(self.app, 'fake_method_does_not_exist')) self.assertFalse(hasattr(self.app, "fake_method_does_not_exist"))
def test_getattr(self): def test_getattr(self):
# enum # enum
self.assertNotIn('enum', self.app.__dict__) self.assertNotIn("enum", self.app.__dict__)
self.assertIs(self.app.enum, wuttjamaican.enum) self.assertIs(self.app.enum, wuttjamaican.enum)
# now we test that providers are loaded... # now we test that providers are loaded...
@ -722,12 +771,12 @@ class TestAppProvider(TestCase):
return 42 return 42
# nb. using instances here # nb. using instances here
fake_providers = {'fake': FakeProvider(self.config)} fake_providers = {"fake": FakeProvider(self.config)}
with patch.object(self.app, 'get_all_providers') as get_all_providers: with patch.object(self.app, "get_all_providers") as get_all_providers:
get_all_providers.return_value = fake_providers get_all_providers.return_value = fake_providers
self.assertNotIn('providers', self.app.__dict__) self.assertNotIn("providers", self.app.__dict__)
self.assertIs(self.app.providers, fake_providers) self.assertIs(self.app.providers, fake_providers)
get_all_providers.assert_called_once_with() get_all_providers.assert_called_once_with()
@ -738,27 +787,27 @@ class TestAppProvider(TestCase):
pytest.skip("test not relevant without sqlalchemy") pytest.skip("test not relevant without sqlalchemy")
# model # model
self.assertNotIn('model', self.app.__dict__) self.assertNotIn("model", self.app.__dict__)
self.assertIs(self.app.model, wuttjamaican.db.model) self.assertIs(self.app.model, wuttjamaican.db.model)
def test_getattr_providers(self): def test_getattr_providers(self):
# collection of providers is loaded on demand # collection of providers is loaded on demand
self.assertNotIn('providers', self.app.__dict__) self.assertNotIn("providers", self.app.__dict__)
self.assertIsNotNone(self.app.providers) self.assertIsNotNone(self.app.providers)
# custom attr does not exist yet # custom attr does not exist yet
self.assertRaises(AttributeError, getattr, self.app, 'foo_value') self.assertRaises(AttributeError, getattr, self.app, "foo_value")
# but provider can supply the attr # but provider can supply the attr
self.app.providers['mytest'] = MagicMock(foo_value='bar') self.app.providers["mytest"] = MagicMock(foo_value="bar")
self.assertEqual(self.app.foo_value, 'bar') self.assertEqual(self.app.foo_value, "bar")
class TestGenericHandler(ConfigTestCase): class TestGenericHandler(ConfigTestCase):
def make_config(self, **kw): def make_config(self, **kw):
kw.setdefault('appname', 'wuttatest') kw.setdefault("appname", "wuttatest")
return super().make_config(**kw) return super().make_config(**kw)
def make_handler(self, **kwargs): def make_handler(self, **kwargs):
@ -768,34 +817,36 @@ class TestGenericHandler(ConfigTestCase):
handler = mod.GenericHandler(self.config) handler = mod.GenericHandler(self.config)
self.assertIs(handler.config, self.config) self.assertIs(handler.config, self.config)
self.assertIs(handler.app, self.app) self.assertIs(handler.app, self.app)
self.assertEqual(handler.appname, 'wuttatest') self.assertEqual(handler.appname, "wuttatest")
def test_get_spec(self): def test_get_spec(self):
self.assertEqual(mod.GenericHandler.get_spec(), 'wuttjamaican.app:GenericHandler') self.assertEqual(
mod.GenericHandler.get_spec(), "wuttjamaican.app:GenericHandler"
)
def test_get_provider_modules(self): def test_get_provider_modules(self):
# no providers, no email modules # no providers, no email modules
with patch.object(self.app, 'providers', new={}): with patch.object(self.app, "providers", new={}):
handler = self.make_handler() handler = self.make_handler()
self.assertEqual(handler.get_provider_modules('email'), []) self.assertEqual(handler.get_provider_modules("email"), [])
# provider may specify modules as list # provider may specify modules as list
providers = { providers = {
'wuttatest': MagicMock(email_modules=['wuttjamaican.app']), "wuttatest": MagicMock(email_modules=["wuttjamaican.app"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
modules = handler.get_provider_modules('email') modules = handler.get_provider_modules("email")
self.assertEqual(len(modules), 1) self.assertEqual(len(modules), 1)
self.assertIs(modules[0], mod) self.assertIs(modules[0], mod)
# provider may specify modules as string # provider may specify modules as string
providers = { providers = {
'wuttatest': MagicMock(email_modules='wuttjamaican.app'), "wuttatest": MagicMock(email_modules="wuttjamaican.app"),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
modules = handler.get_provider_modules('email') modules = handler.get_provider_modules("email")
self.assertEqual(len(modules), 1) self.assertEqual(len(modules), 1)
self.assertIs(modules[0], mod) self.assertIs(modules[0], mod)

View file

@ -11,7 +11,6 @@ except ImportError:
pass pass
else: else:
class TestAuthHandler(TestCase): class TestAuthHandler(TestCase):
def setUp(self): def setUp(self):
@ -19,7 +18,7 @@ else:
self.app = self.config.get_app() self.app = self.config.get_app()
self.handler = self.make_handler() self.handler = self.make_handler()
self.engine = sa.create_engine('sqlite://') self.engine = sa.create_engine("sqlite://")
self.app.model.Base.metadata.create_all(bind=self.engine) self.app.model.Base.metadata.create_all(bind=self.engine)
self.session = self.make_session() self.session = self.make_session()
@ -35,37 +34,37 @@ else:
def test_authenticate_user(self): def test_authenticate_user(self):
model = self.app.model model = self.app.model
barney = model.User(username='barney') barney = model.User(username="barney")
self.handler.set_user_password(barney, 'goodpass') self.handler.set_user_password(barney, "goodpass")
self.session.add(barney) self.session.add(barney)
self.session.commit() self.session.commit()
# login ok # login ok
user = self.handler.authenticate_user(self.session, 'barney', 'goodpass') user = self.handler.authenticate_user(self.session, "barney", "goodpass")
self.assertIs(user, barney) self.assertIs(user, barney)
# can also pass user instead of username # can also pass user instead of username
user = self.handler.authenticate_user(self.session, barney, 'goodpass') user = self.handler.authenticate_user(self.session, barney, "goodpass")
self.assertIs(user, barney) self.assertIs(user, barney)
# bad password # bad password
user = self.handler.authenticate_user(self.session, 'barney', 'BADPASS') user = self.handler.authenticate_user(self.session, "barney", "BADPASS")
self.assertIsNone(user) self.assertIsNone(user)
# bad username # bad username
user = self.handler.authenticate_user(self.session, 'NOBODY', 'goodpass') user = self.handler.authenticate_user(self.session, "NOBODY", "goodpass")
self.assertIsNone(user) self.assertIsNone(user)
# inactive user # inactive user
user = self.handler.authenticate_user(self.session, 'barney', 'goodpass') user = self.handler.authenticate_user(self.session, "barney", "goodpass")
self.assertIs(user, barney) self.assertIs(user, barney)
barney.active = False barney.active = False
user = self.handler.authenticate_user(self.session, 'barney', 'goodpass') user = self.handler.authenticate_user(self.session, "barney", "goodpass")
self.assertIsNone(user) self.assertIsNone(user)
def test_authenticate_user_token(self): def test_authenticate_user_token(self):
model = self.app.model model = self.app.model
barney = model.User(username='barney') barney = model.User(username="barney")
self.session.add(barney) self.session.add(barney)
token = self.handler.add_api_token(barney, "test token") token = self.handler.add_api_token(barney, "test token")
self.session.commit() self.session.commit()
@ -73,32 +72,38 @@ else:
user = self.handler.authenticate_user_token(self.session, None) user = self.handler.authenticate_user_token(self.session, None)
self.assertIsNone(user) self.assertIsNone(user)
user = self.handler.authenticate_user_token(self.session, token.token_string) user = self.handler.authenticate_user_token(
self.session, token.token_string
)
self.assertIs(user, barney) self.assertIs(user, barney)
barney.active = False barney.active = False
self.session.flush() self.session.flush()
user = self.handler.authenticate_user_token(self.session, token.token_string) user = self.handler.authenticate_user_token(
self.session, token.token_string
)
self.assertIsNone(user) self.assertIsNone(user)
barney.active = True barney.active = True
self.session.flush() self.session.flush()
user = self.handler.authenticate_user_token(self.session, token.token_string) user = self.handler.authenticate_user_token(
self.session, token.token_string
)
self.assertIs(user, barney) self.assertIs(user, barney)
user = self.handler.authenticate_user_token(self.session, 'bad-token') user = self.handler.authenticate_user_token(self.session, "bad-token")
self.assertIsNone(user) self.assertIsNone(user)
def test_check_user_password(self): def test_check_user_password(self):
model = self.app.model model = self.app.model
barney = model.User(username='barney') barney = model.User(username="barney")
self.handler.set_user_password(barney, 'goodpass') self.handler.set_user_password(barney, "goodpass")
self.session.add(barney) self.session.add(barney)
self.session.commit() self.session.commit()
# basics # basics
self.assertTrue(self.handler.check_user_password(barney, 'goodpass')) self.assertTrue(self.handler.check_user_password(barney, "goodpass"))
self.assertFalse(self.handler.check_user_password(barney, 'BADPASS')) self.assertFalse(self.handler.check_user_password(barney, "BADPASS"))
def test_get_role(self): def test_get_role(self):
model = self.app.model model = self.app.model
@ -120,17 +125,17 @@ else:
# key may be represented within a setting # key may be represented within a setting
self.config.usedb = True self.config.usedb = True
role = self.handler.get_role(self.session, 'mykey') role = self.handler.get_role(self.session, "mykey")
self.assertIsNone(role) self.assertIsNone(role)
setting = model.Setting(name='wutta.role.mykey', value=myrole.uuid.hex) setting = model.Setting(name="wutta.role.mykey", value=myrole.uuid.hex)
self.session.add(setting) self.session.add(setting)
self.session.commit() self.session.commit()
role = self.handler.get_role(self.session, 'mykey') role = self.handler.get_role(self.session, "mykey")
self.assertIs(role, myrole) self.assertIs(role, myrole)
def test_get_user(self): def test_get_user(self):
model = self.app.model model = self.app.model
myuser = model.User(username='myuser') myuser = model.User(username="myuser")
self.session.add(myuser) self.session.add(myuser)
self.session.commit() self.session.commit()
@ -155,7 +160,7 @@ else:
self.assertIs(user, myuser) self.assertIs(user, myuser)
# find user from person # find user from person
myperson = model.Person(full_name='My Name') myperson = model.Person(full_name="My Name")
self.session.add(myperson) self.session.add(myperson)
user.person = myperson user.person = myperson
self.session.commit() self.session.commit()
@ -173,11 +178,11 @@ else:
self.assertIsNone(person.full_name) self.assertIsNone(person.full_name)
self.assertNotIn(person, self.session) self.assertNotIn(person, self.session)
person = handler.make_person(first_name='Barney', last_name='Rubble') person = handler.make_person(first_name="Barney", last_name="Rubble")
self.assertIsInstance(person, model.Person) self.assertIsInstance(person, model.Person)
self.assertEqual(person.first_name, 'Barney') self.assertEqual(person.first_name, "Barney")
self.assertEqual(person.last_name, 'Rubble') self.assertEqual(person.last_name, "Rubble")
self.assertEqual(person.full_name, 'Barney Rubble') self.assertEqual(person.full_name, "Barney Rubble")
self.assertNotIn(person, self.session) self.assertNotIn(person, self.session)
def test_make_user(self): def test_make_user(self):
@ -197,13 +202,13 @@ else:
# default username # default username
# nb. this behavior requires a session # nb. this behavior requires a session
user = self.handler.make_user(session=self.session) user = self.handler.make_user(session=self.session)
self.assertEqual(user.username, 'newuser') self.assertEqual(user.username, "newuser")
def test_delete_user(self): def test_delete_user(self):
model = self.app.model model = self.app.model
# basics # basics
myuser = model.User(username='myuser') myuser = model.User(username="myuser")
self.session.add(myuser) self.session.add(myuser)
self.session.commit() self.session.commit()
user = self.session.query(model.User).one() user = self.session.query(model.User).one()
@ -217,67 +222,67 @@ else:
# default # default
name = self.handler.make_preferred_username(self.session) name = self.handler.make_preferred_username(self.session)
self.assertEqual(name, 'newuser') self.assertEqual(name, "newuser")
# person/first+last # person/first+last
person = model.Person(first_name='Barney', last_name='Rubble') person = model.Person(first_name="Barney", last_name="Rubble")
name = self.handler.make_preferred_username(self.session, person=person) name = self.handler.make_preferred_username(self.session, person=person)
self.assertEqual(name, 'barney.rubble') self.assertEqual(name, "barney.rubble")
# person/first # person/first
person = model.Person(first_name='Barney') person = model.Person(first_name="Barney")
name = self.handler.make_preferred_username(self.session, person=person) name = self.handler.make_preferred_username(self.session, person=person)
self.assertEqual(name, 'barney') self.assertEqual(name, "barney")
# person/last # person/last
person = model.Person(last_name='Rubble') person = model.Person(last_name="Rubble")
name = self.handler.make_preferred_username(self.session, person=person) name = self.handler.make_preferred_username(self.session, person=person)
self.assertEqual(name, 'rubble') self.assertEqual(name, "rubble")
def test_make_unique_username(self): def test_make_unique_username(self):
model = self.app.model model = self.app.model
# default # default
name = self.handler.make_unique_username(self.session) name = self.handler.make_unique_username(self.session)
self.assertEqual(name, 'newuser') self.assertEqual(name, "newuser")
user = model.User(username=name) user = model.User(username=name)
self.session.add(user) self.session.add(user)
self.session.commit() self.session.commit()
# counter invoked if name exists # counter invoked if name exists
name = self.handler.make_unique_username(self.session) name = self.handler.make_unique_username(self.session)
self.assertEqual(name, 'newuser01') self.assertEqual(name, "newuser01")
user = model.User(username=name) user = model.User(username=name)
self.session.add(user) self.session.add(user)
self.session.commit() self.session.commit()
# starts by getting preferred name # starts by getting preferred name
person = model.Person(first_name='Barney', last_name='Rubble') person = model.Person(first_name="Barney", last_name="Rubble")
name = self.handler.make_unique_username(self.session, person=person) name = self.handler.make_unique_username(self.session, person=person)
self.assertEqual(name, 'barney.rubble') self.assertEqual(name, "barney.rubble")
user = model.User(username=name) user = model.User(username=name)
self.session.add(user) self.session.add(user)
self.session.commit() self.session.commit()
# counter invoked if name exists # counter invoked if name exists
name = self.handler.make_unique_username(self.session, person=person) name = self.handler.make_unique_username(self.session, person=person)
self.assertEqual(name, 'barney.rubble01') self.assertEqual(name, "barney.rubble01")
def test_set_user_password(self): def test_set_user_password(self):
model = self.app.model model = self.app.model
myuser = model.User(username='myuser') myuser = model.User(username="myuser")
self.session.add(myuser) self.session.add(myuser)
# basics # basics
self.assertIsNone(myuser.password) self.assertIsNone(myuser.password)
self.handler.set_user_password(myuser, 'goodpass') self.handler.set_user_password(myuser, "goodpass")
self.session.commit() self.session.commit()
self.assertIsNotNone(myuser.password) self.assertIsNotNone(myuser.password)
# nb. password is hashed # nb. password is hashed
self.assertNotEqual(myuser.password, 'goodpass') self.assertNotEqual(myuser.password, "goodpass")
# confirm login works with new password # confirm login works with new password
user = self.handler.authenticate_user(self.session, 'myuser', 'goodpass') user = self.handler.authenticate_user(self.session, "myuser", "goodpass")
self.assertIs(user, myuser) self.assertIs(user, myuser)
def test_get_role_administrator(self): def test_get_role_administrator(self):
@ -337,15 +342,15 @@ else:
self.assertEqual(len(perms), 0) self.assertEqual(len(perms), 0)
# role perms # role perms
myrole = model.Role(name='My Role') myrole = model.Role(name="My Role")
self.session.add(myrole) self.session.add(myrole)
self.handler.grant_permission(myrole, 'foo') self.handler.grant_permission(myrole, "foo")
self.session.commit() self.session.commit()
perms = self.handler.get_permissions(self.session, myrole) perms = self.handler.get_permissions(self.session, myrole)
self.assertEqual(perms, {'foo'}) self.assertEqual(perms, {"foo"})
# user perms # user perms
myuser = model.User(username='myuser') myuser = model.User(username="myuser")
self.session.add(myuser) self.session.add(myuser)
self.session.commit() self.session.commit()
perms = self.handler.get_permissions(self.session, myuser) perms = self.handler.get_permissions(self.session, myuser)
@ -353,7 +358,7 @@ else:
myuser.roles.append(myrole) myuser.roles.append(myrole)
self.session.commit() self.session.commit()
perms = self.handler.get_permissions(self.session, myuser) perms = self.handler.get_permissions(self.session, myuser)
self.assertEqual(perms, {'foo'}) self.assertEqual(perms, {"foo"})
# invalid principal # invalid principal
perms = self.handler.get_permissions(self.session, RuntimeError) perms = self.handler.get_permissions(self.session, RuntimeError)
@ -368,39 +373,41 @@ else:
# false default for role # false default for role
role = model.Role() role = model.Role()
self.assertFalse(self.handler.has_permission(self.session, role, 'foo')) self.assertFalse(self.handler.has_permission(self.session, role, "foo"))
# empty default for user # empty default for user
user = model.User() user = model.User()
self.assertFalse(self.handler.has_permission(self.session, user, 'foo')) self.assertFalse(self.handler.has_permission(self.session, user, "foo"))
# role perms # role perms
myrole = model.Role(name='My Role') myrole = model.Role(name="My Role")
self.session.add(myrole) self.session.add(myrole)
self.session.commit() self.session.commit()
self.assertFalse(self.handler.has_permission(self.session, myrole, 'foo')) self.assertFalse(self.handler.has_permission(self.session, myrole, "foo"))
self.handler.grant_permission(myrole, 'foo') self.handler.grant_permission(myrole, "foo")
self.session.commit() self.session.commit()
self.assertTrue(self.handler.has_permission(self.session, myrole, 'foo')) self.assertTrue(self.handler.has_permission(self.session, myrole, "foo"))
# user perms # user perms
myuser = model.User(username='myuser') myuser = model.User(username="myuser")
self.session.add(myuser) self.session.add(myuser)
self.session.commit() self.session.commit()
self.assertFalse(self.handler.has_permission(self.session, myuser, 'foo')) self.assertFalse(self.handler.has_permission(self.session, myuser, "foo"))
myuser.roles.append(myrole) myuser.roles.append(myrole)
self.session.commit() self.session.commit()
self.assertTrue(self.handler.has_permission(self.session, myuser, 'foo')) self.assertTrue(self.handler.has_permission(self.session, myuser, "foo"))
# invalid principal # invalid principal
self.assertFalse(self.handler.has_permission(self.session, RuntimeError, 'foo')) self.assertFalse(
self.handler.has_permission(self.session, RuntimeError, "foo")
)
# missing principal # missing principal
self.assertFalse(self.handler.has_permission(self.session, None, 'foo')) self.assertFalse(self.handler.has_permission(self.session, None, "foo"))
def test_grant_permission(self): def test_grant_permission(self):
model = self.app.model model = self.app.model
myrole = model.Role(name='My Role') myrole = model.Role(name="My Role")
self.session.add(myrole) self.session.add(myrole)
self.session.commit() self.session.commit()
@ -408,38 +415,38 @@ else:
self.assertEqual(self.session.query(model.Permission).count(), 0) self.assertEqual(self.session.query(model.Permission).count(), 0)
# grant one perm, and confirm # grant one perm, and confirm
self.handler.grant_permission(myrole, 'foo') self.handler.grant_permission(myrole, "foo")
self.session.commit() self.session.commit()
self.assertEqual(self.session.query(model.Permission).count(), 1) self.assertEqual(self.session.query(model.Permission).count(), 1)
perm = self.session.query(model.Permission).one() perm = self.session.query(model.Permission).one()
self.assertIs(perm.role, myrole) self.assertIs(perm.role, myrole)
self.assertEqual(perm.permission, 'foo') self.assertEqual(perm.permission, "foo")
# grant same perm again, confirm just one exists # grant same perm again, confirm just one exists
self.handler.grant_permission(myrole, 'foo') self.handler.grant_permission(myrole, "foo")
self.session.commit() self.session.commit()
self.assertEqual(self.session.query(model.Permission).count(), 1) self.assertEqual(self.session.query(model.Permission).count(), 1)
perm = self.session.query(model.Permission).one() perm = self.session.query(model.Permission).one()
self.assertIs(perm.role, myrole) self.assertIs(perm.role, myrole)
self.assertEqual(perm.permission, 'foo') self.assertEqual(perm.permission, "foo")
def test_revoke_permission(self): def test_revoke_permission(self):
model = self.app.model model = self.app.model
myrole = model.Role(name='My Role') myrole = model.Role(name="My Role")
self.session.add(myrole) self.session.add(myrole)
self.handler.grant_permission(myrole, 'foo') self.handler.grant_permission(myrole, "foo")
self.session.commit() self.session.commit()
# just the one perm # just the one perm
self.assertEqual(self.session.query(model.Permission).count(), 1) self.assertEqual(self.session.query(model.Permission).count(), 1)
# revoke it, then confirm # revoke it, then confirm
self.handler.revoke_permission(myrole, 'foo') self.handler.revoke_permission(myrole, "foo")
self.session.commit() self.session.commit()
self.assertEqual(self.session.query(model.Permission).count(), 0) self.assertEqual(self.session.query(model.Permission).count(), 0)
# revoke again, confirm # revoke again, confirm
self.handler.revoke_permission(myrole, 'foo') self.handler.revoke_permission(myrole, "foo")
self.session.commit() self.session.commit()
self.assertEqual(self.session.query(model.Permission).count(), 0) self.assertEqual(self.session.query(model.Permission).count(), 0)
@ -450,7 +457,7 @@ else:
def test_add_api_token(self): def test_add_api_token(self):
model = self.app.model model = self.app.model
barney = model.User(username='barney') barney = model.User(username="barney")
self.session.add(barney) self.session.add(barney)
token = self.handler.add_api_token(barney, "test token") token = self.handler.add_api_token(barney, "test token")
@ -461,7 +468,7 @@ else:
def test_delete_api_token(self): def test_delete_api_token(self):
model = self.app.model model = self.app.model
barney = model.User(username='barney') barney = model.User(username="barney")
self.session.add(barney) self.session.add(barney)
token = self.handler.add_api_token(barney, "test token") token = self.handler.add_api_token(barney, "test token")
self.session.commit() self.session.commit()

View file

@ -14,10 +14,10 @@ except ImportError:
else: else:
class MockBatch(model.BatchMixin, model.Base): class MockBatch(model.BatchMixin, model.Base):
__tablename__ = 'testing_batch_mock' __tablename__ = "testing_batch_mock"
class MockBatchRow(model.BatchRowMixin, model.Base): class MockBatchRow(model.BatchRowMixin, model.Base):
__tablename__ = 'testing_batch_mock_row' __tablename__ = "testing_batch_mock_row"
__batch_class__ = MockBatch __batch_class__ = MockBatch
class MockBatchHandler(mod.BatchHandler): class MockBatchHandler(mod.BatchHandler):
@ -30,12 +30,12 @@ else:
def test_model_class(self): def test_model_class(self):
handler = mod.BatchHandler(self.config) handler = mod.BatchHandler(self.config)
self.assertRaises(NotImplementedError, getattr, handler, 'model_class') self.assertRaises(NotImplementedError, getattr, handler, "model_class")
def test_batch_type(self): def test_batch_type(self):
with patch.object(mod.BatchHandler, 'model_class', new=MockBatch): with patch.object(mod.BatchHandler, "model_class", new=MockBatch):
handler = mod.BatchHandler(self.config) handler = mod.BatchHandler(self.config)
self.assertEqual(handler.batch_type, 'testing_batch_mock') self.assertEqual(handler.batch_type, "testing_batch_mock")
def test_make_batch(self): def test_make_batch(self):
handler = self.make_handler() handler = self.make_handler()
@ -50,25 +50,30 @@ else:
self.assertEqual(second, first + 1) self.assertEqual(second, first + 1)
third = handler.consume_batch_id(self.session, as_str=True) third = handler.consume_batch_id(self.session, as_str=True)
self.assertEqual(third, f'{first + 2:08d}') self.assertEqual(third, f"{first + 2:08d}")
def test_get_data_path(self): def test_get_data_path(self):
model = self.app.model model = self.app.model
user = model.User(username='barney') user = model.User(username="barney")
self.session.add(user) self.session.add(user)
with patch.object(mod.BatchHandler, 'model_class', new=MockBatch): with patch.object(mod.BatchHandler, "model_class", new=MockBatch):
handler = self.make_handler() handler = self.make_handler()
# root storage (default) # root storage (default)
with patch.object(self.app, 'get_appdir', return_value=self.tempdir): with patch.object(self.app, "get_appdir", return_value=self.tempdir):
path = handler.get_data_path() path = handler.get_data_path()
self.assertEqual(path, os.path.join(self.tempdir, 'data', 'batch', 'testing_batch_mock')) self.assertEqual(
path,
os.path.join(
self.tempdir, "data", "batch", "testing_batch_mock"
),
)
# root storage (configured) # root storage (configured)
self.config.setdefault('wutta.batch.storage_path', self.tempdir) self.config.setdefault("wutta.batch.storage_path", self.tempdir)
path = handler.get_data_path() path = handler.get_data_path()
self.assertEqual(path, os.path.join(self.tempdir, 'testing_batch_mock')) self.assertEqual(path, os.path.join(self.tempdir, "testing_batch_mock"))
batch = handler.make_batch(self.session, created_by=user) batch = handler.make_batch(self.session, created_by=user)
self.session.add(batch) self.session.add(batch)
@ -78,11 +83,18 @@ else:
path = handler.get_data_path(batch) path = handler.get_data_path(batch)
uuid = batch.uuid.hex uuid = batch.uuid.hex
final = os.path.join(uuid[-2:], uuid[:-2]) final = os.path.join(uuid[-2:], uuid[:-2])
self.assertEqual(path, os.path.join(self.tempdir, 'testing_batch_mock', final)) self.assertEqual(
path, os.path.join(self.tempdir, "testing_batch_mock", final)
)
# with filename # with filename
path = handler.get_data_path(batch, 'input.csv') path = handler.get_data_path(batch, "input.csv")
self.assertEqual(path, os.path.join(self.tempdir, 'testing_batch_mock', final, 'input.csv')) self.assertEqual(
path,
os.path.join(
self.tempdir, "testing_batch_mock", final, "input.csv"
),
)
# makedirs # makedirs
path = handler.get_data_path(batch) path = handler.get_data_path(batch)
@ -118,7 +130,7 @@ else:
def test_remove_row(self): def test_remove_row(self):
model = self.app.model model = self.app.model
handler = self.make_handler() handler = self.make_handler()
user = model.User(username='barney') user = model.User(username="barney")
self.session.add(user) self.session.add(user)
batch = handler.make_batch(self.session, created_by=user) batch = handler.make_batch(self.session, created_by=user)
self.session.add(batch) self.session.add(batch)
@ -134,7 +146,7 @@ else:
model = self.app.model model = self.app.model
handler = self.make_handler() handler = self.make_handler()
user = model.User(username='barney') user = model.User(username="barney")
self.session.add(user) self.session.add(user)
batch = handler.make_batch(self.session, created_by=user) batch = handler.make_batch(self.session, created_by=user)
self.session.add(batch) self.session.add(batch)
@ -152,7 +164,7 @@ else:
def test_do_execute(self): def test_do_execute(self):
model = self.app.model model = self.app.model
user = model.User(username='barney') user = model.User(username="barney")
self.session.add(user) self.session.add(user)
handler = self.make_handler() handler = self.make_handler()
@ -161,7 +173,7 @@ else:
self.session.flush() self.session.flush()
# error if execution not allowed # error if execution not allowed
with patch.object(handler, 'why_not_execute', return_value="bad batch"): with patch.object(handler, "why_not_execute", return_value="bad batch"):
self.assertRaises(RuntimeError, handler.do_execute, batch, user) self.assertRaises(RuntimeError, handler.do_execute, batch, user)
# nb. coverage only; tests nothing # nb. coverage only; tests nothing
@ -178,7 +190,7 @@ else:
model = self.app.model model = self.app.model
handler = self.make_handler() handler = self.make_handler()
user = model.User(username='barney') user = model.User(username="barney")
self.session.add(user) self.session.add(user)
# simple delete # simple delete
@ -201,13 +213,13 @@ else:
self.assertEqual(self.session.query(MockBatch).count(), 0) self.assertEqual(self.session.query(MockBatch).count(), 0)
# delete w/ files # delete w/ files
self.config.setdefault('wutta.batch.storage_path', self.tempdir) self.config.setdefault("wutta.batch.storage_path", self.tempdir)
batch = handler.make_batch(self.session, created_by=user) batch = handler.make_batch(self.session, created_by=user)
self.session.add(batch) self.session.add(batch)
self.session.flush() self.session.flush()
path = handler.get_data_path(batch, 'data.txt', makedirs=True) path = handler.get_data_path(batch, "data.txt", makedirs=True)
with open(path, 'wt') as f: with open(path, "wt") as f:
f.write('foo=bar') f.write("foo=bar")
self.assertEqual(self.session.query(MockBatch).count(), 1) self.assertEqual(self.session.query(MockBatch).count(), 1)
path = handler.get_data_path(batch) path = handler.get_data_path(batch)
self.assertTrue(os.path.exists(path)) self.assertTrue(os.path.exists(path))
@ -216,13 +228,13 @@ else:
self.assertFalse(os.path.exists(path)) self.assertFalse(os.path.exists(path))
# delete w/ files (dry-run) # delete w/ files (dry-run)
self.config.setdefault('wutta.batch.storage_path', self.tempdir) self.config.setdefault("wutta.batch.storage_path", self.tempdir)
batch = handler.make_batch(self.session, created_by=user) batch = handler.make_batch(self.session, created_by=user)
self.session.add(batch) self.session.add(batch)
self.session.flush() self.session.flush()
path = handler.get_data_path(batch, 'data.txt', makedirs=True) path = handler.get_data_path(batch, "data.txt", makedirs=True)
with open(path, 'wt') as f: with open(path, "wt") as f:
f.write('foo=bar') f.write("foo=bar")
self.assertEqual(self.session.query(MockBatch).count(), 1) self.assertEqual(self.session.query(MockBatch).count(), 1)
path = handler.get_data_path(batch) path = handler.get_data_path(batch)
self.assertTrue(os.path.exists(path)) self.assertTrue(os.path.exists(path))

File diff suppressed because it is too large Load diff

View file

@ -18,7 +18,7 @@ class TestEmailSetting(ConfigTestCase):
setting = mod.EmailSetting(self.config) setting = mod.EmailSetting(self.config)
self.assertIs(setting.config, self.config) self.assertIs(setting.config, self.config)
self.assertIs(setting.app, self.app) self.assertIs(setting.app, self.app)
self.assertEqual(setting.key, 'EmailSetting') self.assertEqual(setting.key, "EmailSetting")
def test_sample_data(self): def test_sample_data(self):
setting = mod.EmailSetting(self.config) setting = mod.EmailSetting(self.config)
@ -34,23 +34,23 @@ class TestMessage(FileTestCase):
msg = self.make_message() msg = self.make_message()
# set as list # set as list
recips = msg.get_recips(['sally@example.com']) recips = msg.get_recips(["sally@example.com"])
self.assertEqual(recips, ['sally@example.com']) self.assertEqual(recips, ["sally@example.com"])
# set as tuple # set as tuple
recips = msg.get_recips(('barney@example.com',)) recips = msg.get_recips(("barney@example.com",))
self.assertEqual(recips, ['barney@example.com']) self.assertEqual(recips, ["barney@example.com"])
# set as string # set as string
recips = msg.get_recips('wilma@example.com') recips = msg.get_recips("wilma@example.com")
self.assertEqual(recips, ['wilma@example.com']) self.assertEqual(recips, ["wilma@example.com"])
# set as null # set as null
recips = msg.get_recips(None) recips = msg.get_recips(None)
self.assertEqual(recips, []) self.assertEqual(recips, [])
# otherwise error # otherwise error
self.assertRaises(ValueError, msg.get_recips, {'foo': 'foo@example.com'}) self.assertRaises(ValueError, msg.get_recips, {"foo": "foo@example.com"})
def test_as_string(self): def test_as_string(self):
@ -59,38 +59,44 @@ class TestMessage(FileTestCase):
self.assertRaises(ValueError, msg.as_string) self.assertRaises(ValueError, msg.as_string)
# txt body # txt body
msg = self.make_message(sender='bob@example.com', msg = self.make_message(sender="bob@example.com", txt_body="hello world")
txt_body="hello world")
complete = msg.as_string() complete = msg.as_string()
self.assertIn('From: bob@example.com', complete) self.assertIn("From: bob@example.com", complete)
# html body # html body
msg = self.make_message(sender='bob@example.com', msg = self.make_message(
html_body="<p>hello world</p>") sender="bob@example.com", html_body="<p>hello world</p>"
)
complete = msg.as_string() complete = msg.as_string()
self.assertIn('From: bob@example.com', complete) self.assertIn("From: bob@example.com", complete)
# txt + html body # txt + html body
msg = self.make_message(sender='bob@example.com', msg = self.make_message(
txt_body="hello world", sender="bob@example.com",
html_body="<p>hello world</p>") txt_body="hello world",
html_body="<p>hello world</p>",
)
complete = msg.as_string() complete = msg.as_string()
self.assertIn('From: bob@example.com', complete) self.assertIn("From: bob@example.com", complete)
# html + attachment # html + attachment
csv_part = MIMEText("foo,bar\n1,2", 'csv', 'utf_8') csv_part = MIMEText("foo,bar\n1,2", "csv", "utf_8")
msg = self.make_message(sender='bob@example.com', msg = self.make_message(
html_body="<p>hello world</p>", sender="bob@example.com",
attachments=[csv_part]) html_body="<p>hello world</p>",
attachments=[csv_part],
)
complete = msg.as_string() complete = msg.as_string()
self.assertIn('Content-Type: multipart/mixed; boundary=', complete) self.assertIn("Content-Type: multipart/mixed; boundary=", complete)
self.assertIn('Content-Type: text/csv; charset="utf_8"', complete) self.assertIn('Content-Type: text/csv; charset="utf_8"', complete)
# error if improper attachment # error if improper attachment
csv_path = self.write_file('data.csv', "foo,bar\n1,2") csv_path = self.write_file("data.csv", "foo,bar\n1,2")
msg = self.make_message(sender='bob@example.com', msg = self.make_message(
html_body="<p>hello world</p>", sender="bob@example.com",
attachments=[csv_path]) html_body="<p>hello world</p>",
attachments=[csv_path],
)
self.assertRaises(ValueError, msg.as_string) self.assertRaises(ValueError, msg.as_string)
try: try:
msg.as_string() msg.as_string()
@ -98,27 +104,30 @@ class TestMessage(FileTestCase):
self.assertIn("must specify valid MIME attachments", str(err)) self.assertIn("must specify valid MIME attachments", str(err))
# everything # everything
msg = self.make_message(sender='bob@example.com', msg = self.make_message(
subject='meeting follow-up', sender="bob@example.com",
to='sally@example.com', subject="meeting follow-up",
cc='marketing@example.com', to="sally@example.com",
bcc='bob@example.com', cc="marketing@example.com",
replyto='sales@example.com', bcc="bob@example.com",
txt_body="hello world", replyto="sales@example.com",
html_body="<p>hello world</p>") txt_body="hello world",
html_body="<p>hello world</p>",
)
complete = msg.as_string() complete = msg.as_string()
self.assertIn('From: bob@example.com', complete) self.assertIn("From: bob@example.com", complete)
self.assertIn('Subject: meeting follow-up', complete) self.assertIn("Subject: meeting follow-up", complete)
self.assertIn('To: sally@example.com', complete) self.assertIn("To: sally@example.com", complete)
self.assertIn('Cc: marketing@example.com', complete) self.assertIn("Cc: marketing@example.com", complete)
self.assertIn('Bcc: bob@example.com', complete) self.assertIn("Bcc: bob@example.com", complete)
self.assertIn('Reply-To: sales@example.com', complete) self.assertIn("Reply-To: sales@example.com", complete)
class mock_foo(mod.EmailSetting): class mock_foo(mod.EmailSetting):
default_subject = "MOCK FOO!" default_subject = "MOCK FOO!"
def sample_data(self): def sample_data(self):
return {'foo': 'mock'} return {"foo": "mock"}
class TestEmailHandler(ConfigTestCase): class TestEmailHandler(ConfigTestCase):
@ -129,43 +138,43 @@ class TestEmailHandler(ConfigTestCase):
def test_constructor_lookups(self): def test_constructor_lookups(self):
# empty lookup paths by default, if no providers # empty lookup paths by default, if no providers
with patch.object(self.app, 'providers', new={}): with patch.object(self.app, "providers", new={}):
handler = self.make_handler() handler = self.make_handler()
self.assertEqual(handler.txt_templates.directories, []) self.assertEqual(handler.txt_templates.directories, [])
self.assertEqual(handler.html_templates.directories, []) self.assertEqual(handler.html_templates.directories, [])
# provider may specify paths as list # provider may specify paths as list
providers = { providers = {
'wuttatest': MagicMock(email_templates=['wuttjamaican:email-templates']), "wuttatest": MagicMock(email_templates=["wuttjamaican:email-templates"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
path = resource_path('wuttjamaican:email-templates') path = resource_path("wuttjamaican:email-templates")
self.assertEqual(handler.txt_templates.directories, [path]) self.assertEqual(handler.txt_templates.directories, [path])
self.assertEqual(handler.html_templates.directories, [path]) self.assertEqual(handler.html_templates.directories, [path])
# provider may specify paths as string # provider may specify paths as string
providers = { providers = {
'wuttatest': MagicMock(email_templates='wuttjamaican:email-templates'), "wuttatest": MagicMock(email_templates="wuttjamaican:email-templates"),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
path = resource_path('wuttjamaican:email-templates') path = resource_path("wuttjamaican:email-templates")
self.assertEqual(handler.txt_templates.directories, [path]) self.assertEqual(handler.txt_templates.directories, [path])
self.assertEqual(handler.html_templates.directories, [path]) self.assertEqual(handler.html_templates.directories, [path])
def test_get_email_modules(self): def test_get_email_modules(self):
# no providers, no email modules # no providers, no email modules
with patch.object(self.app, 'providers', new={}): with patch.object(self.app, "providers", new={}):
handler = self.make_handler() handler = self.make_handler()
self.assertEqual(handler.get_email_modules(), []) self.assertEqual(handler.get_email_modules(), [])
# provider may specify modules as list # provider may specify modules as list
providers = { providers = {
'wuttatest': MagicMock(email_modules=['wuttjamaican.email']), "wuttatest": MagicMock(email_modules=["wuttjamaican.email"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
modules = handler.get_email_modules() modules = handler.get_email_modules()
self.assertEqual(len(modules), 1) self.assertEqual(len(modules), 1)
@ -173,9 +182,9 @@ class TestEmailHandler(ConfigTestCase):
# provider may specify modules as string # provider may specify modules as string
providers = { providers = {
'wuttatest': MagicMock(email_modules='wuttjamaican.email'), "wuttatest": MagicMock(email_modules="wuttjamaican.email"),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
modules = handler.get_email_modules() modules = handler.get_email_modules()
self.assertEqual(len(modules), 1) self.assertEqual(len(modules), 1)
@ -184,36 +193,36 @@ class TestEmailHandler(ConfigTestCase):
def test_get_email_settings(self): def test_get_email_settings(self):
# no providers, no email settings # no providers, no email settings
with patch.object(self.app, 'providers', new={}): with patch.object(self.app, "providers", new={}):
handler = self.make_handler() handler = self.make_handler()
self.assertEqual(handler.get_email_settings(), {}) self.assertEqual(handler.get_email_settings(), {})
# provider may define email settings (via modules) # provider may define email settings (via modules)
providers = { providers = {
'wuttatest': MagicMock(email_modules=['tests.test_email']), "wuttatest": MagicMock(email_modules=["tests.test_email"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
settings = handler.get_email_settings() settings = handler.get_email_settings()
self.assertEqual(len(settings), 1) self.assertEqual(len(settings), 1)
self.assertIn('mock_foo', settings) self.assertIn("mock_foo", settings)
def test_get_email_setting(self): def test_get_email_setting(self):
providers = { providers = {
'wuttatest': MagicMock(email_modules=['tests.test_email']), "wuttatest": MagicMock(email_modules=["tests.test_email"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
# as instance # as instance
setting = handler.get_email_setting('mock_foo') setting = handler.get_email_setting("mock_foo")
self.assertIsInstance(setting, mod.EmailSetting) self.assertIsInstance(setting, mod.EmailSetting)
self.assertIsInstance(setting, mock_foo) self.assertIsInstance(setting, mock_foo)
# as class # as class
setting = handler.get_email_setting('mock_foo', instance=False) setting = handler.get_email_setting("mock_foo", instance=False)
self.assertTrue(issubclass(setting, mod.EmailSetting)) self.assertTrue(issubclass(setting, mod.EmailSetting))
self.assertIs(setting, mock_foo) self.assertIs(setting, mock_foo)
@ -229,10 +238,10 @@ class TestEmailHandler(ConfigTestCase):
# self.assertRaises(ConfigurationError, handler.make_auto_message, 'foo') # self.assertRaises(ConfigurationError, handler.make_auto_message, 'foo')
# message is empty by default # message is empty by default
msg = handler.make_auto_message('foo') msg = handler.make_auto_message("foo")
self.assertIsInstance(msg, mod.Message) self.assertIsInstance(msg, mod.Message)
self.assertEqual(msg.key, 'foo') self.assertEqual(msg.key, "foo")
self.assertEqual(msg.sender, 'root@localhost') self.assertEqual(msg.sender, "root@localhost")
self.assertEqual(msg.subject, "Automated message") self.assertEqual(msg.subject, "Automated message")
self.assertEqual(msg.to, []) self.assertEqual(msg.to, [])
self.assertEqual(msg.cc, []) self.assertEqual(msg.cc, [])
@ -242,14 +251,14 @@ class TestEmailHandler(ConfigTestCase):
self.assertIsNone(msg.html_body) self.assertIsNone(msg.html_body)
# override defaults # override defaults
self.config.setdefault('wutta.email.default.sender', 'bob@example.com') self.config.setdefault("wutta.email.default.sender", "bob@example.com")
self.config.setdefault('wutta.email.default.subject', 'Attention required') self.config.setdefault("wutta.email.default.subject", "Attention required")
# message is empty by default # message is empty by default
msg = handler.make_auto_message('foo') msg = handler.make_auto_message("foo")
self.assertIsInstance(msg, mod.Message) self.assertIsInstance(msg, mod.Message)
self.assertEqual(msg.key, 'foo') self.assertEqual(msg.key, "foo")
self.assertEqual(msg.sender, 'bob@example.com') self.assertEqual(msg.sender, "bob@example.com")
self.assertEqual(msg.subject, "Attention required") self.assertEqual(msg.subject, "Attention required")
self.assertEqual(msg.to, []) self.assertEqual(msg.to, [])
self.assertEqual(msg.cc, []) self.assertEqual(msg.cc, [])
@ -260,15 +269,15 @@ class TestEmailHandler(ConfigTestCase):
# but if there is a proper email profile configured for key, # but if there is a proper email profile configured for key,
# then we should get back a more complete message # then we should get back a more complete message
self.config.setdefault('wutta.email.test_foo.subject', "hello foo") self.config.setdefault("wutta.email.test_foo.subject", "hello foo")
self.config.setdefault('wutta.email.test_foo.to', 'sally@example.com') self.config.setdefault("wutta.email.test_foo.to", "sally@example.com")
self.config.setdefault('wutta.email.templates', 'tests:email-templates') self.config.setdefault("wutta.email.templates", "tests:email-templates")
handler = self.make_handler() handler = self.make_handler()
msg = handler.make_auto_message('test_foo') msg = handler.make_auto_message("test_foo")
self.assertEqual(msg.key, 'test_foo') self.assertEqual(msg.key, "test_foo")
self.assertEqual(msg.sender, 'bob@example.com') self.assertEqual(msg.sender, "bob@example.com")
self.assertEqual(msg.subject, "hello foo") self.assertEqual(msg.subject, "hello foo")
self.assertEqual(msg.to, ['sally@example.com']) self.assertEqual(msg.to, ["sally@example.com"])
self.assertEqual(msg.cc, []) self.assertEqual(msg.cc, [])
self.assertEqual(msg.bcc, []) self.assertEqual(msg.bcc, [])
self.assertIsNone(msg.replyto) self.assertIsNone(msg.replyto)
@ -279,160 +288,162 @@ class TestEmailHandler(ConfigTestCase):
# kwarg at all; others get skipped if kwarg is empty # kwarg at all; others get skipped if kwarg is empty
# sender # sender
with patch.object(handler, 'get_auto_sender') as get_auto_sender: with patch.object(handler, "get_auto_sender") as get_auto_sender:
msg = handler.make_auto_message('foo', sender=None) msg = handler.make_auto_message("foo", sender=None)
get_auto_sender.assert_not_called() get_auto_sender.assert_not_called()
msg = handler.make_auto_message('foo') msg = handler.make_auto_message("foo")
get_auto_sender.assert_called_once_with('foo') get_auto_sender.assert_called_once_with("foo")
# subject # subject
with patch.object(handler, 'get_auto_subject') as get_auto_subject: with patch.object(handler, "get_auto_subject") as get_auto_subject:
msg = handler.make_auto_message('foo', subject=None) msg = handler.make_auto_message("foo", subject=None)
get_auto_subject.assert_not_called() get_auto_subject.assert_not_called()
msg = handler.make_auto_message('foo') msg = handler.make_auto_message("foo")
get_auto_subject.assert_called_once_with('foo', {}, default=None) get_auto_subject.assert_called_once_with("foo", {}, default=None)
# to # to
with patch.object(handler, 'get_auto_to') as get_auto_to: with patch.object(handler, "get_auto_to") as get_auto_to:
msg = handler.make_auto_message('foo', to=None) msg = handler.make_auto_message("foo", to=None)
get_auto_to.assert_not_called() get_auto_to.assert_not_called()
get_auto_to.return_value = None get_auto_to.return_value = None
msg = handler.make_auto_message('foo') msg = handler.make_auto_message("foo")
get_auto_to.assert_called_once_with('foo') get_auto_to.assert_called_once_with("foo")
# cc # cc
with patch.object(handler, 'get_auto_cc') as get_auto_cc: with patch.object(handler, "get_auto_cc") as get_auto_cc:
msg = handler.make_auto_message('foo', cc=None) msg = handler.make_auto_message("foo", cc=None)
get_auto_cc.assert_not_called() get_auto_cc.assert_not_called()
get_auto_cc.return_value = None get_auto_cc.return_value = None
msg = handler.make_auto_message('foo') msg = handler.make_auto_message("foo")
get_auto_cc.assert_called_once_with('foo') get_auto_cc.assert_called_once_with("foo")
# bcc # bcc
with patch.object(handler, 'get_auto_bcc') as get_auto_bcc: with patch.object(handler, "get_auto_bcc") as get_auto_bcc:
msg = handler.make_auto_message('foo', bcc=None) msg = handler.make_auto_message("foo", bcc=None)
get_auto_bcc.assert_not_called() get_auto_bcc.assert_not_called()
get_auto_bcc.return_value = None get_auto_bcc.return_value = None
msg = handler.make_auto_message('foo') msg = handler.make_auto_message("foo")
get_auto_bcc.assert_called_once_with('foo') get_auto_bcc.assert_called_once_with("foo")
# txt_body # txt_body
with patch.object(handler, 'get_auto_txt_body') as get_auto_txt_body: with patch.object(handler, "get_auto_txt_body") as get_auto_txt_body:
msg = handler.make_auto_message('foo', txt_body=None) msg = handler.make_auto_message("foo", txt_body=None)
get_auto_txt_body.assert_not_called() get_auto_txt_body.assert_not_called()
msg = handler.make_auto_message('foo') msg = handler.make_auto_message("foo")
get_auto_txt_body.assert_called_once_with('foo', {}) get_auto_txt_body.assert_called_once_with("foo", {})
# html_body # html_body
with patch.object(handler, 'get_auto_html_body') as get_auto_html_body: with patch.object(handler, "get_auto_html_body") as get_auto_html_body:
msg = handler.make_auto_message('foo', html_body=None) msg = handler.make_auto_message("foo", html_body=None)
get_auto_html_body.assert_not_called() get_auto_html_body.assert_not_called()
msg = handler.make_auto_message('foo') msg = handler.make_auto_message("foo")
get_auto_html_body.assert_called_once_with('foo', {}) get_auto_html_body.assert_called_once_with("foo", {})
def test_get_auto_sender(self): def test_get_auto_sender(self):
handler = self.make_handler() handler = self.make_handler()
# basic global default # basic global default
self.assertEqual(handler.get_auto_sender('foo'), 'root@localhost') self.assertEqual(handler.get_auto_sender("foo"), "root@localhost")
# can set global default # can set global default
self.config.setdefault('wutta.email.default.sender', 'bob@example.com') self.config.setdefault("wutta.email.default.sender", "bob@example.com")
self.assertEqual(handler.get_auto_sender('foo'), 'bob@example.com') self.assertEqual(handler.get_auto_sender("foo"), "bob@example.com")
# can set for key # can set for key
self.config.setdefault('wutta.email.foo.sender', 'sally@example.com') self.config.setdefault("wutta.email.foo.sender", "sally@example.com")
self.assertEqual(handler.get_auto_sender('foo'), 'sally@example.com') self.assertEqual(handler.get_auto_sender("foo"), "sally@example.com")
def test_get_auto_replyto(self): def test_get_auto_replyto(self):
handler = self.make_handler() handler = self.make_handler()
# null by default # null by default
self.assertIsNone(handler.get_auto_replyto('foo')) self.assertIsNone(handler.get_auto_replyto("foo"))
# can set global default # can set global default
self.config.setdefault('wutta.email.default.replyto', 'george@example.com') self.config.setdefault("wutta.email.default.replyto", "george@example.com")
self.assertEqual(handler.get_auto_replyto('foo'), 'george@example.com') self.assertEqual(handler.get_auto_replyto("foo"), "george@example.com")
# can set for key # can set for key
self.config.setdefault('wutta.email.foo.replyto', 'kathy@example.com') self.config.setdefault("wutta.email.foo.replyto", "kathy@example.com")
self.assertEqual(handler.get_auto_replyto('foo'), 'kathy@example.com') self.assertEqual(handler.get_auto_replyto("foo"), "kathy@example.com")
def test_get_auto_subject_template(self): def test_get_auto_subject_template(self):
handler = self.make_handler() handler = self.make_handler()
# global default # global default
template = handler.get_auto_subject_template('foo') template = handler.get_auto_subject_template("foo")
self.assertEqual(template, "Automated message") self.assertEqual(template, "Automated message")
# can configure alternate global default # can configure alternate global default
self.config.setdefault('wutta.email.default.subject', "Wutta Message") self.config.setdefault("wutta.email.default.subject", "Wutta Message")
template = handler.get_auto_subject_template('foo') template = handler.get_auto_subject_template("foo")
self.assertEqual(template, "Wutta Message") self.assertEqual(template, "Wutta Message")
# can configure just for key # can configure just for key
self.config.setdefault('wutta.email.foo.subject', "Foo Message") self.config.setdefault("wutta.email.foo.subject", "Foo Message")
template = handler.get_auto_subject_template('foo') template = handler.get_auto_subject_template("foo")
self.assertEqual(template, "Foo Message") self.assertEqual(template, "Foo Message")
# EmailSetting can provide default subject # EmailSetting can provide default subject
providers = { providers = {
'wuttatest': MagicMock(email_modules=['tests.test_email']), "wuttatest": MagicMock(email_modules=["tests.test_email"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
template = handler.get_auto_subject_template('mock_foo') template = handler.get_auto_subject_template("mock_foo")
self.assertEqual(template, "MOCK FOO!") self.assertEqual(template, "MOCK FOO!")
# caller can provide default subject # caller can provide default subject
template = handler.get_auto_subject_template('mock_foo', default="whatever is clever") template = handler.get_auto_subject_template(
"mock_foo", default="whatever is clever"
)
self.assertEqual(template, "whatever is clever") self.assertEqual(template, "whatever is clever")
def test_get_auto_subject(self): def test_get_auto_subject(self):
handler = self.make_handler() handler = self.make_handler()
# global default # global default
subject = handler.get_auto_subject('foo') subject = handler.get_auto_subject("foo")
self.assertEqual(subject, "Automated message") self.assertEqual(subject, "Automated message")
# can configure alternate global default # can configure alternate global default
self.config.setdefault('wutta.email.default.subject', "Wutta Message") self.config.setdefault("wutta.email.default.subject", "Wutta Message")
subject = handler.get_auto_subject('foo') subject = handler.get_auto_subject("foo")
self.assertEqual(subject, "Wutta Message") self.assertEqual(subject, "Wutta Message")
# caller can provide default subject # caller can provide default subject
subject = handler.get_auto_subject('foo', default="whatever is clever") subject = handler.get_auto_subject("foo", default="whatever is clever")
self.assertEqual(subject, "whatever is clever") self.assertEqual(subject, "whatever is clever")
# can configure just for key # can configure just for key
self.config.setdefault('wutta.email.foo.subject', "Foo Message") self.config.setdefault("wutta.email.foo.subject", "Foo Message")
subject = handler.get_auto_subject('foo') subject = handler.get_auto_subject("foo")
self.assertEqual(subject, "Foo Message") self.assertEqual(subject, "Foo Message")
# proper template is rendered # proper template is rendered
self.config.setdefault('wutta.email.bar.subject', "${foo} Message") self.config.setdefault("wutta.email.bar.subject", "${foo} Message")
subject = handler.get_auto_subject('bar', {'foo': "FOO"}) subject = handler.get_auto_subject("bar", {"foo": "FOO"})
self.assertEqual(subject, "FOO Message") self.assertEqual(subject, "FOO Message")
# unless we ask it not to # unless we ask it not to
subject = handler.get_auto_subject('bar', {'foo': "FOO"}, rendered=False) subject = handler.get_auto_subject("bar", {"foo": "FOO"}, rendered=False)
self.assertEqual(subject, "${foo} Message") self.assertEqual(subject, "${foo} Message")
def test_get_auto_recips(self): def test_get_auto_recips(self):
handler = self.make_handler() handler = self.make_handler()
# error if bad type requested # error if bad type requested
self.assertRaises(ValueError, handler.get_auto_recips, 'foo', 'doesnotexist') self.assertRaises(ValueError, handler.get_auto_recips, "foo", "doesnotexist")
# can configure global default # can configure global default
self.config.setdefault('wutta.email.default.to', 'admin@example.com') self.config.setdefault("wutta.email.default.to", "admin@example.com")
recips = handler.get_auto_recips('foo', 'to') recips = handler.get_auto_recips("foo", "to")
self.assertEqual(recips, ['admin@example.com']) self.assertEqual(recips, ["admin@example.com"])
# can configure just for key # can configure just for key
self.config.setdefault('wutta.email.foo.to', 'bob@example.com') self.config.setdefault("wutta.email.foo.to", "bob@example.com")
recips = handler.get_auto_recips('foo', 'to') recips = handler.get_auto_recips("foo", "to")
self.assertEqual(recips, ['bob@example.com']) self.assertEqual(recips, ["bob@example.com"])
def test_get_auto_body_template(self): def test_get_auto_body_template(self):
from mako.template import Template from mako.template import Template
@ -440,88 +451,88 @@ class TestEmailHandler(ConfigTestCase):
handler = self.make_handler() handler = self.make_handler()
# error if bad request # error if bad request
self.assertRaises(ValueError, handler.get_auto_body_template, 'foo', 'BADTYPE') self.assertRaises(ValueError, handler.get_auto_body_template, "foo", "BADTYPE")
# empty by default # empty by default
template = handler.get_auto_body_template('foo', 'txt') template = handler.get_auto_body_template("foo", "txt")
self.assertIsNone(template) self.assertIsNone(template)
# but returns a template if it exists # but returns a template if it exists
providers = { providers = {
'wuttatest': MagicMock(email_templates=['tests:email-templates']), "wuttatest": MagicMock(email_templates=["tests:email-templates"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
template = handler.get_auto_body_template('test_foo', 'txt') template = handler.get_auto_body_template("test_foo", "txt")
self.assertIsInstance(template, Template) self.assertIsInstance(template, Template)
self.assertEqual(template.uri, 'test_foo.txt.mako') self.assertEqual(template.uri, "test_foo.txt.mako")
def test_get_auto_txt_body(self): def test_get_auto_txt_body(self):
handler = self.make_handler() handler = self.make_handler()
# empty by default # empty by default
body = handler.get_auto_txt_body('some-random-email') body = handler.get_auto_txt_body("some-random-email")
self.assertIsNone(body) self.assertIsNone(body)
# but returns body if template exists # but returns body if template exists
providers = { providers = {
'wuttatest': MagicMock(email_templates=['tests:email-templates']), "wuttatest": MagicMock(email_templates=["tests:email-templates"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
body = handler.get_auto_txt_body('test_foo') body = handler.get_auto_txt_body("test_foo")
self.assertEqual(body, 'hello from foo txt template\n') self.assertEqual(body, "hello from foo txt template\n")
def test_get_auto_html_body(self): def test_get_auto_html_body(self):
handler = self.make_handler() handler = self.make_handler()
# empty by default # empty by default
body = handler.get_auto_html_body('some-random-email') body = handler.get_auto_html_body("some-random-email")
self.assertIsNone(body) self.assertIsNone(body)
# but returns body if template exists # but returns body if template exists
providers = { providers = {
'wuttatest': MagicMock(email_templates=['tests:email-templates']), "wuttatest": MagicMock(email_templates=["tests:email-templates"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
body = handler.get_auto_html_body('test_foo') body = handler.get_auto_html_body("test_foo")
self.assertEqual(body, '<p>hello from foo html template</p>\n') self.assertEqual(body, "<p>hello from foo html template</p>\n")
def test_get_notes(self): def test_get_notes(self):
handler = self.make_handler() handler = self.make_handler()
# null by default # null by default
self.assertIsNone(handler.get_notes('foo')) self.assertIsNone(handler.get_notes("foo"))
# configured notes # configured notes
self.config.setdefault('wutta.email.foo.notes', 'hello world') self.config.setdefault("wutta.email.foo.notes", "hello world")
self.assertEqual(handler.get_notes('foo'), 'hello world') self.assertEqual(handler.get_notes("foo"), "hello world")
def test_is_enabled(self): def test_is_enabled(self):
handler = self.make_handler() handler = self.make_handler()
# enabled by default # enabled by default
self.assertTrue(handler.is_enabled('default')) self.assertTrue(handler.is_enabled("default"))
self.assertTrue(handler.is_enabled('foo')) self.assertTrue(handler.is_enabled("foo"))
# specific type disabled # specific type disabled
self.config.setdefault('wutta.email.foo.enabled', 'false') self.config.setdefault("wutta.email.foo.enabled", "false")
self.assertFalse(handler.is_enabled('foo')) self.assertFalse(handler.is_enabled("foo"))
# default is disabled # default is disabled
self.assertTrue(handler.is_enabled('bar')) self.assertTrue(handler.is_enabled("bar"))
self.config.setdefault('wutta.email.default.enabled', 'false') self.config.setdefault("wutta.email.default.enabled", "false")
self.assertFalse(handler.is_enabled('bar')) self.assertFalse(handler.is_enabled("bar"))
def test_deliver_message(self): def test_deliver_message(self):
handler = self.make_handler() handler = self.make_handler()
msg = handler.make_message(sender='bob@example.com', to='sally@example.com') msg = handler.make_message(sender="bob@example.com", to="sally@example.com")
with patch.object(msg, 'as_string', return_value='msg-str'): with patch.object(msg, "as_string", return_value="msg-str"):
# no smtp session since sending email is disabled by default # no smtp session since sending email is disabled by default
with patch.object(mod, 'smtplib') as smtplib: with patch.object(mod, "smtplib") as smtplib:
session = MagicMock() session = MagicMock()
smtplib.SMTP.return_value = session smtplib.SMTP.return_value = session
handler.deliver_message(msg) handler.deliver_message(msg)
@ -530,85 +541,99 @@ class TestEmailHandler(ConfigTestCase):
session.sendmail.assert_not_called() session.sendmail.assert_not_called()
# now let's enable sending # now let's enable sending
self.config.setdefault('wutta.mail.send_emails', 'true') self.config.setdefault("wutta.mail.send_emails", "true")
# smtp login not attempted by default # smtp login not attempted by default
with patch.object(mod, 'smtplib') as smtplib: with patch.object(mod, "smtplib") as smtplib:
session = MagicMock() session = MagicMock()
smtplib.SMTP.return_value = session smtplib.SMTP.return_value = session
handler.deliver_message(msg) handler.deliver_message(msg)
smtplib.SMTP.assert_called_once_with('localhost') smtplib.SMTP.assert_called_once_with("localhost")
session.login.assert_not_called() session.login.assert_not_called()
session.sendmail.assert_called_once_with('bob@example.com', {'sally@example.com'}, 'msg-str') session.sendmail.assert_called_once_with(
"bob@example.com", {"sally@example.com"}, "msg-str"
)
# but login attempted if config has credentials # but login attempted if config has credentials
self.config.setdefault('wutta.mail.smtp.username', 'bob') self.config.setdefault("wutta.mail.smtp.username", "bob")
self.config.setdefault('wutta.mail.smtp.password', 'seekrit') self.config.setdefault("wutta.mail.smtp.password", "seekrit")
with patch.object(mod, 'smtplib') as smtplib: with patch.object(mod, "smtplib") as smtplib:
session = MagicMock() session = MagicMock()
smtplib.SMTP.return_value = session smtplib.SMTP.return_value = session
handler.deliver_message(msg) handler.deliver_message(msg)
smtplib.SMTP.assert_called_once_with('localhost') smtplib.SMTP.assert_called_once_with("localhost")
session.login.assert_called_once_with('bob', 'seekrit') session.login.assert_called_once_with("bob", "seekrit")
session.sendmail.assert_called_once_with('bob@example.com', {'sally@example.com'}, 'msg-str') session.sendmail.assert_called_once_with(
"bob@example.com", {"sally@example.com"}, "msg-str"
)
# error if no sender # error if no sender
msg = handler.make_message(to='sally@example.com') msg = handler.make_message(to="sally@example.com")
self.assertRaises(ValueError, handler.deliver_message, msg) self.assertRaises(ValueError, handler.deliver_message, msg)
# error if no recips # error if no recips
msg = handler.make_message(sender='bob@example.com') msg = handler.make_message(sender="bob@example.com")
self.assertRaises(ValueError, handler.deliver_message, msg) self.assertRaises(ValueError, handler.deliver_message, msg)
# can set recips as list # can set recips as list
msg = handler.make_message(sender='bob@example.com') msg = handler.make_message(sender="bob@example.com")
with patch.object(msg, 'as_string', return_value='msg-str'): with patch.object(msg, "as_string", return_value="msg-str"):
with patch.object(mod, 'smtplib') as smtplib: with patch.object(mod, "smtplib") as smtplib:
session = MagicMock() session = MagicMock()
smtplib.SMTP.return_value = session smtplib.SMTP.return_value = session
handler.deliver_message(msg, recips=['sally@example.com']) handler.deliver_message(msg, recips=["sally@example.com"])
smtplib.SMTP.assert_called_once_with('localhost') smtplib.SMTP.assert_called_once_with("localhost")
session.sendmail.assert_called_once_with('bob@example.com', {'sally@example.com'}, 'msg-str') session.sendmail.assert_called_once_with(
"bob@example.com", {"sally@example.com"}, "msg-str"
)
# can set recips as string # can set recips as string
msg = handler.make_message(sender='bob@example.com') msg = handler.make_message(sender="bob@example.com")
with patch.object(msg, 'as_string', return_value='msg-str'): with patch.object(msg, "as_string", return_value="msg-str"):
with patch.object(mod, 'smtplib') as smtplib: with patch.object(mod, "smtplib") as smtplib:
session = MagicMock() session = MagicMock()
smtplib.SMTP.return_value = session smtplib.SMTP.return_value = session
handler.deliver_message(msg, recips='sally@example.com') handler.deliver_message(msg, recips="sally@example.com")
smtplib.SMTP.assert_called_once_with('localhost') smtplib.SMTP.assert_called_once_with("localhost")
session.sendmail.assert_called_once_with('bob@example.com', {'sally@example.com'}, 'msg-str') session.sendmail.assert_called_once_with(
"bob@example.com", {"sally@example.com"}, "msg-str"
)
# can set recips via to # can set recips via to
msg = handler.make_message(sender='bob@example.com', to='sally@example.com') msg = handler.make_message(sender="bob@example.com", to="sally@example.com")
with patch.object(msg, 'as_string', return_value='msg-str'): with patch.object(msg, "as_string", return_value="msg-str"):
with patch.object(mod, 'smtplib') as smtplib: with patch.object(mod, "smtplib") as smtplib:
session = MagicMock() session = MagicMock()
smtplib.SMTP.return_value = session smtplib.SMTP.return_value = session
handler.deliver_message(msg) handler.deliver_message(msg)
smtplib.SMTP.assert_called_once_with('localhost') smtplib.SMTP.assert_called_once_with("localhost")
session.sendmail.assert_called_once_with('bob@example.com', {'sally@example.com'}, 'msg-str') session.sendmail.assert_called_once_with(
"bob@example.com", {"sally@example.com"}, "msg-str"
)
# can set recips via cc # can set recips via cc
msg = handler.make_message(sender='bob@example.com', cc='sally@example.com') msg = handler.make_message(sender="bob@example.com", cc="sally@example.com")
with patch.object(msg, 'as_string', return_value='msg-str'): with patch.object(msg, "as_string", return_value="msg-str"):
with patch.object(mod, 'smtplib') as smtplib: with patch.object(mod, "smtplib") as smtplib:
session = MagicMock() session = MagicMock()
smtplib.SMTP.return_value = session smtplib.SMTP.return_value = session
handler.deliver_message(msg) handler.deliver_message(msg)
smtplib.SMTP.assert_called_once_with('localhost') smtplib.SMTP.assert_called_once_with("localhost")
session.sendmail.assert_called_once_with('bob@example.com', {'sally@example.com'}, 'msg-str') session.sendmail.assert_called_once_with(
"bob@example.com", {"sally@example.com"}, "msg-str"
)
# can set recips via bcc # can set recips via bcc
msg = handler.make_message(sender='bob@example.com', bcc='sally@example.com') msg = handler.make_message(sender="bob@example.com", bcc="sally@example.com")
with patch.object(msg, 'as_string', return_value='msg-str'): with patch.object(msg, "as_string", return_value="msg-str"):
with patch.object(mod, 'smtplib') as smtplib: with patch.object(mod, "smtplib") as smtplib:
session = MagicMock() session = MagicMock()
smtplib.SMTP.return_value = session smtplib.SMTP.return_value = session
handler.deliver_message(msg) handler.deliver_message(msg)
smtplib.SMTP.assert_called_once_with('localhost') smtplib.SMTP.assert_called_once_with("localhost")
session.sendmail.assert_called_once_with('bob@example.com', {'sally@example.com'}, 'msg-str') session.sendmail.assert_called_once_with(
"bob@example.com", {"sally@example.com"}, "msg-str"
)
def test_sending_is_enabled(self): def test_sending_is_enabled(self):
handler = self.make_handler() handler = self.make_handler()
@ -617,12 +642,12 @@ class TestEmailHandler(ConfigTestCase):
self.assertFalse(handler.sending_is_enabled()) self.assertFalse(handler.sending_is_enabled())
# but can be turned on # but can be turned on
self.config.setdefault('wutta.mail.send_emails', 'true') self.config.setdefault("wutta.mail.send_emails", "true")
self.assertTrue(handler.sending_is_enabled()) self.assertTrue(handler.sending_is_enabled())
def test_send_email(self): def test_send_email(self):
handler = self.make_handler() handler = self.make_handler()
with patch.object(handler, 'deliver_message') as deliver_message: with patch.object(handler, "deliver_message") as deliver_message:
# specify message w/ no body # specify message w/ no body
msg = handler.make_message() msg = handler.make_message()
@ -631,7 +656,7 @@ class TestEmailHandler(ConfigTestCase):
# again, but also specify key # again, but also specify key
msg = handler.make_message() msg = handler.make_message()
self.assertRaises(ValueError, handler.send_email, 'foo', message=msg) self.assertRaises(ValueError, handler.send_email, "foo", message=msg)
self.assertFalse(deliver_message.called) self.assertFalse(deliver_message.called)
# specify complete message # specify complete message
@ -643,7 +668,7 @@ class TestEmailHandler(ConfigTestCase):
# again, but also specify key # again, but also specify key
deliver_message.reset_mock() deliver_message.reset_mock()
msg = handler.make_message(txt_body="hello world") msg = handler.make_message(txt_body="hello world")
handler.send_email('foo', message=msg) handler.send_email("foo", message=msg)
deliver_message.assert_called_once_with(msg, recips=None) deliver_message.assert_called_once_with(msg, recips=None)
# no key, no message # no key, no message
@ -652,25 +677,27 @@ class TestEmailHandler(ConfigTestCase):
# auto-create message w/ no template # auto-create message w/ no template
deliver_message.reset_mock() deliver_message.reset_mock()
self.assertRaises(RuntimeError, handler.send_email, 'foo', sender='foo@example.com') self.assertRaises(
RuntimeError, handler.send_email, "foo", sender="foo@example.com"
)
self.assertFalse(deliver_message.called) self.assertFalse(deliver_message.called)
# auto create w/ body # auto create w/ body
deliver_message.reset_mock() deliver_message.reset_mock()
handler.send_email('foo', sender='foo@example.com', txt_body="hello world") handler.send_email("foo", sender="foo@example.com", txt_body="hello world")
self.assertTrue(deliver_message.called) self.assertTrue(deliver_message.called)
# type is disabled # type is disabled
deliver_message.reset_mock() deliver_message.reset_mock()
self.config.setdefault('wutta.email.foo.enabled', False) self.config.setdefault("wutta.email.foo.enabled", False)
handler.send_email('foo', sender='foo@example.com', txt_body="hello world") handler.send_email("foo", sender="foo@example.com", txt_body="hello world")
self.assertFalse(deliver_message.called) self.assertFalse(deliver_message.called)
# default is disabled # default is disabled
deliver_message.reset_mock() deliver_message.reset_mock()
handler.send_email('bar', sender='bar@example.com', txt_body="hello world") handler.send_email("bar", sender="bar@example.com", txt_body="hello world")
self.assertTrue(deliver_message.called) self.assertTrue(deliver_message.called)
deliver_message.reset_mock() deliver_message.reset_mock()
self.config.setdefault('wutta.email.default.enabled', False) self.config.setdefault("wutta.email.default.enabled", False)
handler.send_email('bar', sender='bar@example.com', txt_body="hello world") handler.send_email("bar", sender="bar@example.com", txt_body="hello world")
self.assertFalse(deliver_message.called) self.assertFalse(deliver_message.called)

View file

@ -18,16 +18,16 @@ class TestInstallHandler(ConfigTestCase):
def test_constructor(self): def test_constructor(self):
handler = self.make_handler() handler = self.make_handler()
self.assertEqual(handler.pkg_name, 'poser') self.assertEqual(handler.pkg_name, "poser")
self.assertEqual(handler.app_title, 'poser') self.assertEqual(handler.app_title, "poser")
self.assertEqual(handler.pypi_name, 'poser') self.assertEqual(handler.pypi_name, "poser")
self.assertEqual(handler.egg_name, 'poser') self.assertEqual(handler.egg_name, "poser")
def test_run(self): def test_run(self):
handler = self.make_handler() handler = self.make_handler()
with patch.object(handler, 'show_welcome') as show_welcome: with patch.object(handler, "show_welcome") as show_welcome:
with patch.object(handler, 'sanity_check') as sanity_check: with patch.object(handler, "sanity_check") as sanity_check:
with patch.object(handler, 'do_install_steps') as do_install_steps: with patch.object(handler, "do_install_steps") as do_install_steps:
handler.run() handler.run()
show_welcome.assert_called_once_with() show_welcome.assert_called_once_with()
sanity_check.assert_called_once_with() sanity_check.assert_called_once_with()
@ -35,9 +35,9 @@ class TestInstallHandler(ConfigTestCase):
def test_show_welcome(self): def test_show_welcome(self):
handler = self.make_handler() handler = self.make_handler()
with patch.object(mod, 'sys') as sys: with patch.object(mod, "sys") as sys:
with patch.object(handler, 'rprint') as rprint: with patch.object(handler, "rprint") as rprint:
with patch.object(handler, 'prompt_bool') as prompt_bool: with patch.object(handler, "prompt_bool") as prompt_bool:
# user continues # user continues
prompt_bool.return_value = True prompt_bool.return_value = True
@ -51,9 +51,9 @@ class TestInstallHandler(ConfigTestCase):
def test_sanity_check(self): def test_sanity_check(self):
handler = self.make_handler() handler = self.make_handler()
with patch.object(mod, 'sys') as sys: with patch.object(mod, "sys") as sys:
with patch.object(mod, 'os') as os: with patch.object(mod, "os") as os:
with patch.object(handler, 'rprint') as rprint: with patch.object(handler, "rprint") as rprint:
# pretend appdir does not exist # pretend appdir does not exist
os.path.exists.return_value = False os.path.exists.return_value = False
@ -67,24 +67,26 @@ class TestInstallHandler(ConfigTestCase):
def test_do_install_steps(self): def test_do_install_steps(self):
handler = self.make_handler() handler = self.make_handler()
handler.templates = TemplateLookup(directories=[ handler.templates = TemplateLookup(
self.app.resource_path('wuttjamaican:templates/install'), directories=[
]) self.app.resource_path("wuttjamaican:templates/install"),
]
)
dbinfo = { dbinfo = {
'dburl': f'sqlite:///{self.tempdir}/poser.sqlite', "dburl": f"sqlite:///{self.tempdir}/poser.sqlite",
} }
with patch.object(handler, 'get_dbinfo', return_value=dbinfo): with patch.object(handler, "get_dbinfo", return_value=dbinfo):
with patch.object(handler, 'make_appdir') as make_appdir: with patch.object(handler, "make_appdir") as make_appdir:
with patch.object(handler, 'install_db_schema') as install_db_schema: with patch.object(handler, "install_db_schema") as install_db_schema:
# nb. just for sanity/coverage # nb. just for sanity/coverage
install_db_schema.return_value = True install_db_schema.return_value = True
self.assertFalse(hasattr(handler, 'schema_installed')) self.assertFalse(hasattr(handler, "schema_installed"))
handler.do_install_steps() handler.do_install_steps()
self.assertTrue(make_appdir.called) self.assertTrue(make_appdir.called)
self.assertTrue(handler.schema_installed) self.assertTrue(handler.schema_installed)
install_db_schema.assert_called_once_with(dbinfo['dburl']) install_db_schema.assert_called_once_with(dbinfo["dburl"])
def test_get_dbinfo(self): def test_get_dbinfo(self):
try: try:
@ -97,16 +99,16 @@ class TestInstallHandler(ConfigTestCase):
handler = self.make_handler() handler = self.make_handler()
def prompt_generic(info, default=None, is_password=False): def prompt_generic(info, default=None, is_password=False):
if info in ('db name', 'db user'): if info in ("db name", "db user"):
return 'poser' return "poser"
if is_password: if is_password:
return 'seekrit' return "seekrit"
return default return default
with patch.object(mod, 'sys') as sys: with patch.object(mod, "sys") as sys:
with patch.object(handler, 'prompt_generic', side_effect=prompt_generic): with patch.object(handler, "prompt_generic", side_effect=prompt_generic):
with patch.object(handler, 'test_db_connection') as test_db_connection: with patch.object(handler, "test_db_connection") as test_db_connection:
with patch.object(handler, 'rprint') as rprint: with patch.object(handler, "rprint") as rprint:
# bad dbinfo # bad dbinfo
test_db_connection.return_value = "bad dbinfo" test_db_connection.return_value = "bad dbinfo"
@ -114,7 +116,7 @@ class TestInstallHandler(ConfigTestCase):
self.assertRaises(RuntimeError, handler.get_dbinfo) self.assertRaises(RuntimeError, handler.get_dbinfo)
sys.exit.assert_called_once_with(1) sys.exit.assert_called_once_with(1)
seekrit = '***' if SA2 else 'seekrit' seekrit = "***" if SA2 else "seekrit"
# good dbinfo # good dbinfo
sys.exit.reset_mock() sys.exit.reset_mock()
@ -122,8 +124,10 @@ class TestInstallHandler(ConfigTestCase):
dbinfo = handler.get_dbinfo() dbinfo = handler.get_dbinfo()
self.assertFalse(sys.exit.called) self.assertFalse(sys.exit.called)
rprint.assert_called_with("[bold green]good[/bold green]") rprint.assert_called_with("[bold green]good[/bold green]")
self.assertEqual(str(dbinfo['dburl']), self.assertEqual(
f'postgresql+psycopg2://poser:{seekrit}@localhost:5432/poser') str(dbinfo["dburl"]),
f"postgresql+psycopg2://poser:{seekrit}@localhost:5432/poser",
)
def test_make_db_url(self): def test_make_db_url(self):
try: try:
@ -134,13 +138,21 @@ class TestInstallHandler(ConfigTestCase):
from wuttjamaican.db.util import SA2 from wuttjamaican.db.util import SA2
handler = self.make_handler() handler = self.make_handler()
seekrit = '***' if SA2 else 'seekrit' seekrit = "***" if SA2 else "seekrit"
url = handler.make_db_url('postgresql', 'localhost', '5432', 'poser', 'poser', 'seekrit') url = handler.make_db_url(
self.assertEqual(str(url), f'postgresql+psycopg2://poser:{seekrit}@localhost:5432/poser') "postgresql", "localhost", "5432", "poser", "poser", "seekrit"
)
self.assertEqual(
str(url), f"postgresql+psycopg2://poser:{seekrit}@localhost:5432/poser"
)
url = handler.make_db_url('mysql', 'localhost', '3306', 'poser', 'poser', 'seekrit') url = handler.make_db_url(
self.assertEqual(str(url), f'mysql+mysqlconnector://poser:{seekrit}@localhost:3306/poser') "mysql", "localhost", "3306", "poser", "poser", "seekrit"
)
self.assertEqual(
str(url), f"mysql+mysqlconnector://poser:{seekrit}@localhost:3306/poser"
)
def test_test_db_connection(self): def test_test_db_connection(self):
try: try:
@ -151,11 +163,11 @@ class TestInstallHandler(ConfigTestCase):
handler = self.make_handler() handler = self.make_handler()
# db does not exist # db does not exist
result = handler.test_db_connection('sqlite:///bad/url/should/not/exist') result = handler.test_db_connection("sqlite:///bad/url/should/not/exist")
self.assertIn('unable to open database file', result) self.assertIn("unable to open database file", result)
# db is setup # db is setup
url = f'sqlite:///{self.tempdir}/db.sqlite' url = f"sqlite:///{self.tempdir}/db.sqlite"
engine = sa.create_engine(url) engine = sa.create_engine(url)
with engine.begin() as cxn: with engine.begin() as cxn:
cxn.execute(sa.text("create table whatever (id int primary key);")) cxn.execute(sa.text("create table whatever (id int primary key);"))
@ -163,27 +175,29 @@ class TestInstallHandler(ConfigTestCase):
def test_make_template_context(self): def test_make_template_context(self):
handler = self.make_handler() handler = self.make_handler()
dbinfo = {'dburl': 'sqlite:///poser.sqlite'} dbinfo = {"dburl": "sqlite:///poser.sqlite"}
context = handler.make_template_context(dbinfo) context = handler.make_template_context(dbinfo)
self.assertEqual(context['envdir'], sys.prefix) self.assertEqual(context["envdir"], sys.prefix)
self.assertEqual(context['pkg_name'], 'poser') self.assertEqual(context["pkg_name"], "poser")
self.assertEqual(context['app_title'], 'poser') self.assertEqual(context["app_title"], "poser")
self.assertEqual(context['pypi_name'], 'poser') self.assertEqual(context["pypi_name"], "poser")
self.assertEqual(context['egg_name'], 'poser') self.assertEqual(context["egg_name"], "poser")
self.assertEqual(context['appdir'], os.path.join(sys.prefix, 'app')) self.assertEqual(context["appdir"], os.path.join(sys.prefix, "app"))
self.assertEqual(context['db_url'], 'sqlite:///poser.sqlite') self.assertEqual(context["db_url"], "sqlite:///poser.sqlite")
def test_make_appdir(self): def test_make_appdir(self):
handler = self.make_handler() handler = self.make_handler()
handler.templates = TemplateLookup(directories=[ handler.templates = TemplateLookup(
self.app.resource_path('wuttjamaican:templates/install'), directories=[
]) self.app.resource_path("wuttjamaican:templates/install"),
dbinfo = {'dburl': 'sqlite:///poser.sqlite'} ]
)
dbinfo = {"dburl": "sqlite:///poser.sqlite"}
context = handler.make_template_context(dbinfo) context = handler.make_template_context(dbinfo)
handler.make_appdir(context, appdir=self.tempdir) handler.make_appdir(context, appdir=self.tempdir)
wutta_conf = os.path.join(self.tempdir, 'wutta.conf') wutta_conf = os.path.join(self.tempdir, "wutta.conf")
with open(wutta_conf, 'rt') as f: with open(wutta_conf, "rt") as f:
self.assertIn('default.url = sqlite:///poser.sqlite', f.read()) self.assertIn("default.url = sqlite:///poser.sqlite", f.read())
def test_install_db_schema(self): def test_install_db_schema(self):
try: try:
@ -192,89 +206,105 @@ class TestInstallHandler(ConfigTestCase):
pytest.skip("test is not relevant without sqlalchemy") pytest.skip("test is not relevant without sqlalchemy")
handler = self.make_handler() handler = self.make_handler()
db_url = f'sqlite:///{self.tempdir}/poser.sqlite' db_url = f"sqlite:///{self.tempdir}/poser.sqlite"
wutta_conf = self.write_file('wutta.conf', f""" wutta_conf = self.write_file(
"wutta.conf",
f"""
[wutta.db] [wutta.db]
default.url = {db_url} default.url = {db_url}
""") """,
)
# convert to proper URL object # convert to proper URL object
db_url = sa.create_engine(db_url).url db_url = sa.create_engine(db_url).url
with patch.object(mod, 'subprocess') as subprocess: with patch.object(mod, "subprocess") as subprocess:
# user declines offer to install schema # user declines offer to install schema
with patch.object(handler, 'prompt_bool', return_value=False): with patch.object(handler, "prompt_bool", return_value=False):
self.assertFalse(handler.install_db_schema(db_url, appdir=self.tempdir)) self.assertFalse(handler.install_db_schema(db_url, appdir=self.tempdir))
# user agrees to install schema # user agrees to install schema
with patch.object(handler, 'prompt_bool', return_value=True): with patch.object(handler, "prompt_bool", return_value=True):
self.assertTrue(handler.install_db_schema(db_url, appdir=self.tempdir)) self.assertTrue(handler.install_db_schema(db_url, appdir=self.tempdir))
subprocess.check_call.assert_called_once_with([ subprocess.check_call.assert_called_once_with(
os.path.join(sys.prefix, 'bin', 'alembic'), [
'-c', wutta_conf, 'upgrade', 'heads']) os.path.join(sys.prefix, "bin", "alembic"),
"-c",
wutta_conf,
"upgrade",
"heads",
]
)
def test_show_goodbye(self): def test_show_goodbye(self):
handler = self.make_handler() handler = self.make_handler()
with patch.object(handler, 'rprint') as rprint: with patch.object(handler, "rprint") as rprint:
handler.schema_installed = True handler.schema_installed = True
handler.show_goodbye() handler.show_goodbye()
rprint.assert_any_call("\n\t[bold green]initial setup is complete![/bold green]") rprint.assert_any_call(
"\n\t[bold green]initial setup is complete![/bold green]"
)
rprint.assert_any_call("\t[blue]bin/wutta -c app/web.conf webapp -r[/blue]") rprint.assert_any_call("\t[blue]bin/wutta -c app/web.conf webapp -r[/blue]")
def test_require_prompt_toolkit_installed(self): def test_require_prompt_toolkit_installed(self):
# nb. this assumes we *do* have prompt_toolkit installed # nb. this assumes we *do* have prompt_toolkit installed
handler = self.make_handler() handler = self.make_handler()
with patch.object(mod, 'subprocess') as subprocess: with patch.object(mod, "subprocess") as subprocess:
handler.require_prompt_toolkit(answer='Y') handler.require_prompt_toolkit(answer="Y")
self.assertFalse(subprocess.check_call.called) self.assertFalse(subprocess.check_call.called)
def test_require_prompt_toolkit_missing(self): def test_require_prompt_toolkit_missing(self):
handler = self.make_handler() handler = self.make_handler()
orig_import = __import__ orig_import = __import__
stuff = {'attempts': 0} stuff = {"attempts": 0}
def mock_import(name, globals=None, locals=None, fromlist=(), level=0): def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'prompt_toolkit': if name == "prompt_toolkit":
# nb. pretend this is not installed # nb. pretend this is not installed
raise ImportError raise ImportError
return orig_import(name, globals, locals, fromlist, level) return orig_import(name, globals, locals, fromlist, level)
# prompt_toolkit not installed, and user declines offer to install # prompt_toolkit not installed, and user declines offer to install
with patch('builtins.__import__', side_effect=mock_import): with patch("builtins.__import__", side_effect=mock_import):
with patch.object(mod, 'subprocess') as subprocess: with patch.object(mod, "subprocess") as subprocess:
with patch.object(mod, 'sys') as sys: with patch.object(mod, "sys") as sys:
sys.exit.side_effect = RuntimeError sys.exit.side_effect = RuntimeError
self.assertRaises(RuntimeError, handler.require_prompt_toolkit, answer='N') self.assertRaises(
RuntimeError, handler.require_prompt_toolkit, answer="N"
)
self.assertFalse(subprocess.check_call.called) self.assertFalse(subprocess.check_call.called)
sys.stderr.write.assert_called_once_with("prompt_toolkit is required; aborting\n") sys.stderr.write.assert_called_once_with(
"prompt_toolkit is required; aborting\n"
)
sys.exit.assert_called_once_with(1) sys.exit.assert_called_once_with(1)
def test_require_prompt_toolkit_missing_then_installed(self): def test_require_prompt_toolkit_missing_then_installed(self):
handler = self.make_handler() handler = self.make_handler()
orig_import = __import__ orig_import = __import__
stuff = {'attempts': 0} stuff = {"attempts": 0}
def mock_import(name, globals=None, locals=None, fromlist=(), level=0): def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'prompt_toolkit': if name == "prompt_toolkit":
stuff['attempts'] += 1 stuff["attempts"] += 1
if stuff['attempts'] == 1: if stuff["attempts"] == 1:
# nb. pretend this is not installed # nb. pretend this is not installed
raise ImportError raise ImportError
return orig_import('prompt_toolkit') return orig_import("prompt_toolkit")
return orig_import(name, globals, locals, fromlist, level) return orig_import(name, globals, locals, fromlist, level)
# prompt_toolkit not installed, and user declines offer to install # prompt_toolkit not installed, and user declines offer to install
with patch('builtins.__import__', side_effect=mock_import): with patch("builtins.__import__", side_effect=mock_import):
with patch.object(mod, 'subprocess') as subprocess: with patch.object(mod, "subprocess") as subprocess:
with patch.object(mod, 'sys') as sys: with patch.object(mod, "sys") as sys:
sys.executable = 'python' sys.executable = "python"
handler.require_prompt_toolkit(answer='Y') handler.require_prompt_toolkit(answer="Y")
subprocess.check_call.assert_called_once_with(['python', '-m', 'pip', subprocess.check_call.assert_called_once_with(
'install', 'prompt_toolkit']) ["python", "-m", "pip", "install", "prompt_toolkit"]
)
self.assertFalse(sys.exit.called) self.assertFalse(sys.exit.called)
self.assertEqual(stuff['attempts'], 2) self.assertEqual(stuff["attempts"], 2)
def test_prompt_generic(self): def test_prompt_generic(self):
handler = self.make_handler() handler = self.make_handler()
@ -283,86 +313,94 @@ default.url = {db_url}
mock_prompt = MagicMock() mock_prompt = MagicMock()
def mock_import(name, globals=None, locals=None, fromlist=(), level=0): def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'prompt_toolkit': if name == "prompt_toolkit":
if fromlist == ('prompt',): if fromlist == ("prompt",):
return MagicMock(prompt=mock_prompt) return MagicMock(prompt=mock_prompt)
return orig_import(name, globals, locals, fromlist, level) return orig_import(name, globals, locals, fromlist, level)
with patch('builtins.__import__', side_effect=mock_import): with patch("builtins.__import__", side_effect=mock_import):
with patch.object(handler, 'get_prompt_style', return_value=style): with patch.object(handler, "get_prompt_style", return_value=style):
with patch.object(handler, 'rprint') as rprint: with patch.object(handler, "rprint") as rprint:
# no input or default value # no input or default value
mock_prompt.return_value = '' mock_prompt.return_value = ""
result = handler.prompt_generic('foo') result = handler.prompt_generic("foo")
self.assertIsNone(result) self.assertIsNone(result)
mock_prompt.assert_called_once_with([('', '\n'), mock_prompt.assert_called_once_with(
('class:bold', 'foo'), [("", "\n"), ("class:bold", "foo"), ("", ": ")],
('', ': ')], style=style,
style=style, is_password=False) is_password=False,
)
# fallback to default value # fallback to default value
mock_prompt.reset_mock() mock_prompt.reset_mock()
mock_prompt.return_value = '' mock_prompt.return_value = ""
result = handler.prompt_generic('foo', default='baz') result = handler.prompt_generic("foo", default="baz")
self.assertEqual(result, 'baz') self.assertEqual(result, "baz")
mock_prompt.assert_called_once_with([('', '\n'), mock_prompt.assert_called_once_with(
('class:bold', 'foo'), [("", "\n"), ("class:bold", "foo"), ("", " [baz]: ")],
('', ' [baz]: ')], style=style,
style=style, is_password=False) is_password=False,
)
# text input value # text input value
mock_prompt.reset_mock() mock_prompt.reset_mock()
mock_prompt.return_value = 'bar' mock_prompt.return_value = "bar"
result = handler.prompt_generic('foo') result = handler.prompt_generic("foo")
self.assertEqual(result, 'bar') self.assertEqual(result, "bar")
mock_prompt.assert_called_once_with([('', '\n'), mock_prompt.assert_called_once_with(
('class:bold', 'foo'), [("", "\n"), ("class:bold", "foo"), ("", ": ")],
('', ': ')], style=style,
style=style, is_password=False) is_password=False,
)
# bool value (no default; true input) # bool value (no default; true input)
mock_prompt.reset_mock() mock_prompt.reset_mock()
mock_prompt.return_value = 'Y' mock_prompt.return_value = "Y"
result = handler.prompt_generic('foo', is_bool=True) result = handler.prompt_generic("foo", is_bool=True)
self.assertTrue(result) self.assertTrue(result)
mock_prompt.assert_called_once_with([('', '\n'), mock_prompt.assert_called_once_with(
('class:bold', 'foo'), [("", "\n"), ("class:bold", "foo"), ("", ": ")],
('', ': ')], style=style,
style=style, is_password=False) is_password=False,
)
# bool value (no default; false input) # bool value (no default; false input)
mock_prompt.reset_mock() mock_prompt.reset_mock()
mock_prompt.return_value = 'N' mock_prompt.return_value = "N"
result = handler.prompt_generic('foo', is_bool=True) result = handler.prompt_generic("foo", is_bool=True)
self.assertFalse(result) self.assertFalse(result)
mock_prompt.assert_called_once_with([('', '\n'), mock_prompt.assert_called_once_with(
('class:bold', 'foo'), [("", "\n"), ("class:bold", "foo"), ("", ": ")],
('', ': ')], style=style,
style=style, is_password=False) is_password=False,
)
# bool value (default; no input) # bool value (default; no input)
mock_prompt.reset_mock() mock_prompt.reset_mock()
mock_prompt.return_value = '' mock_prompt.return_value = ""
result = handler.prompt_generic('foo', is_bool=True, default=True) result = handler.prompt_generic("foo", is_bool=True, default=True)
self.assertTrue(result) self.assertTrue(result)
mock_prompt.assert_called_once_with([('', '\n'), mock_prompt.assert_called_once_with(
('class:bold', 'foo'), [("", "\n"), ("class:bold", "foo"), ("", " [Y]: ")],
('', ' [Y]: ')], style=style,
style=style, is_password=False) is_password=False,
)
# bool value (bad input) # bool value (bad input)
mock_prompt.reset_mock() mock_prompt.reset_mock()
counter = {'attempts': 0} counter = {"attempts": 0}
def omg(*args, **kwargs): def omg(*args, **kwargs):
counter['attempts'] += 1 counter["attempts"] += 1
if counter['attempts'] == 1: if counter["attempts"] == 1:
# nb. bad input first time we ask # nb. bad input first time we ask
return 'doesnotmakesense' return "doesnotmakesense"
# nb. but good input after that # nb. but good input after that
return 'N' return "N"
mock_prompt.side_effect = omg mock_prompt.side_effect = omg
result = handler.prompt_generic('foo', is_bool=True) result = handler.prompt_generic("foo", is_bool=True)
self.assertFalse(result) self.assertFalse(result)
# nb. user was prompted twice # nb. user was prompted twice
self.assertEqual(mock_prompt.call_count, 2) self.assertEqual(mock_prompt.call_count, 2)
@ -370,32 +408,34 @@ default.url = {db_url}
# Ctrl+C # Ctrl+C
mock_prompt.reset_mock() mock_prompt.reset_mock()
mock_prompt.side_effect = KeyboardInterrupt mock_prompt.side_effect = KeyboardInterrupt
with patch.object(mod, 'sys') as sys: with patch.object(mod, "sys") as sys:
sys.exit.side_effect = RuntimeError sys.exit.side_effect = RuntimeError
self.assertRaises(RuntimeError, handler.prompt_generic, 'foo') self.assertRaises(RuntimeError, handler.prompt_generic, "foo")
sys.exit.assert_called_once_with(1) sys.exit.assert_called_once_with(1)
# Ctrl+D # Ctrl+D
mock_prompt.reset_mock() mock_prompt.reset_mock()
mock_prompt.side_effect = EOFError mock_prompt.side_effect = EOFError
with patch.object(mod, 'sys') as sys: with patch.object(mod, "sys") as sys:
sys.exit.side_effect = RuntimeError sys.exit.side_effect = RuntimeError
self.assertRaises(RuntimeError, handler.prompt_generic, 'foo') self.assertRaises(RuntimeError, handler.prompt_generic, "foo")
sys.exit.assert_called_once_with(1) sys.exit.assert_called_once_with(1)
# missing required value # missing required value
mock_prompt.reset_mock() mock_prompt.reset_mock()
counter = {'attempts': 0} counter = {"attempts": 0}
def omg(*args, **kwargs): def omg(*args, **kwargs):
counter['attempts'] += 1 counter["attempts"] += 1
if counter['attempts'] == 1: if counter["attempts"] == 1:
# nb. no input first time we ask # nb. no input first time we ask
return '' return ""
# nb. but good input after that # nb. but good input after that
return 'bar' return "bar"
mock_prompt.side_effect = omg mock_prompt.side_effect = omg
result = handler.prompt_generic('foo', required=True) result = handler.prompt_generic("foo", required=True)
self.assertEqual(result, 'bar') self.assertEqual(result, "bar")
# nb. user was prompted twice # nb. user was prompted twice
self.assertEqual(mock_prompt.call_count, 2) self.assertEqual(mock_prompt.call_count, 2)
@ -405,47 +445,49 @@ default.url = {db_url}
mock_prompt = MagicMock() mock_prompt = MagicMock()
def mock_import(name, globals=None, locals=None, fromlist=(), level=0): def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'prompt_toolkit': if name == "prompt_toolkit":
if fromlist == ('prompt',): if fromlist == ("prompt",):
return MagicMock(prompt=mock_prompt) return MagicMock(prompt=mock_prompt)
return orig_import(name, globals, locals, fromlist, level) return orig_import(name, globals, locals, fromlist, level)
with patch('builtins.__import__', side_effect=mock_import): with patch("builtins.__import__", side_effect=mock_import):
with patch.object(handler, 'rprint') as rprint: with patch.object(handler, "rprint") as rprint:
# no default; true input # no default; true input
mock_prompt.reset_mock() mock_prompt.reset_mock()
mock_prompt.return_value = 'Y' mock_prompt.return_value = "Y"
result = handler.prompt_bool('foo') result = handler.prompt_bool("foo")
self.assertTrue(result) self.assertTrue(result)
mock_prompt.assert_called_once() mock_prompt.assert_called_once()
# no default; false input # no default; false input
mock_prompt.reset_mock() mock_prompt.reset_mock()
mock_prompt.return_value = 'N' mock_prompt.return_value = "N"
result = handler.prompt_bool('foo') result = handler.prompt_bool("foo")
self.assertFalse(result) self.assertFalse(result)
mock_prompt.assert_called_once() mock_prompt.assert_called_once()
# default; no input # default; no input
mock_prompt.reset_mock() mock_prompt.reset_mock()
mock_prompt.return_value = '' mock_prompt.return_value = ""
result = handler.prompt_bool('foo', default=True) result = handler.prompt_bool("foo", default=True)
self.assertTrue(result) self.assertTrue(result)
mock_prompt.assert_called_once() mock_prompt.assert_called_once()
# bad input # bad input
mock_prompt.reset_mock() mock_prompt.reset_mock()
counter = {'attempts': 0} counter = {"attempts": 0}
def omg(*args, **kwargs): def omg(*args, **kwargs):
counter['attempts'] += 1 counter["attempts"] += 1
if counter['attempts'] == 1: if counter["attempts"] == 1:
# nb. bad input first time we ask # nb. bad input first time we ask
return 'doesnotmakesense' return "doesnotmakesense"
# nb. but good input after that # nb. but good input after that
return 'N' return "N"
mock_prompt.side_effect = omg mock_prompt.side_effect = omg
result = handler.prompt_bool('foo') result = handler.prompt_bool("foo")
self.assertFalse(result) self.assertFalse(result)
# nb. user was prompted twice # nb. user was prompted twice
self.assertEqual(mock_prompt.call_count, 2) self.assertEqual(mock_prompt.call_count, 2)

View file

@ -9,7 +9,6 @@ except ImportError:
pass pass
else: else:
class TestPeopleHandler(DataTestCase): class TestPeopleHandler(DataTestCase):
def make_handler(self): def make_handler(self):
@ -17,7 +16,7 @@ else:
def test_get_person(self): def test_get_person(self):
model = self.app.model model = self.app.model
myperson = model.Person(full_name='Barny Rubble') myperson = model.Person(full_name="Barny Rubble")
self.session.add(myperson) self.session.add(myperson)
self.session.commit() self.session.commit()
handler = self.make_handler() handler = self.make_handler()
@ -31,7 +30,7 @@ else:
self.assertIs(person, myperson) self.assertIs(person, myperson)
# find person from user # find person from user
myuser = model.User(username='barney', person=myperson) myuser = model.User(username="barney", person=myperson)
self.session.add(myuser) self.session.add(myuser)
self.session.commit() self.session.commit()
person = handler.get_person(myuser) person = handler.get_person(myuser)
@ -48,9 +47,9 @@ else:
self.assertIsNone(person.full_name) self.assertIsNone(person.full_name)
self.assertNotIn(person, self.session) self.assertNotIn(person, self.session)
person = handler.make_person(first_name='Barney', last_name='Rubble') person = handler.make_person(first_name="Barney", last_name="Rubble")
self.assertIsInstance(person, model.Person) self.assertIsInstance(person, model.Person)
self.assertEqual(person.first_name, 'Barney') self.assertEqual(person.first_name, "Barney")
self.assertEqual(person.last_name, 'Rubble') self.assertEqual(person.last_name, "Rubble")
self.assertEqual(person.full_name, 'Barney Rubble') self.assertEqual(person.full_name, "Barney Rubble")
self.assertNotIn(person, self.session) self.assertNotIn(person, self.session)

View file

@ -14,15 +14,15 @@ class TestProblemCheck(ConfigTestCase):
def test_system_key(self): def test_system_key(self):
check = self.make_check() check = self.make_check()
self.assertRaises(AttributeError, getattr, check, 'system_key') self.assertRaises(AttributeError, getattr, check, "system_key")
def test_problem_key(self): def test_problem_key(self):
check = self.make_check() check = self.make_check()
self.assertRaises(AttributeError, getattr, check, 'problem_key') self.assertRaises(AttributeError, getattr, check, "problem_key")
def test_title(self): def test_title(self):
check = self.make_check() check = self.make_check()
self.assertRaises(AttributeError, getattr, check, 'title') self.assertRaises(AttributeError, getattr, check, "title")
def test_find_problems(self): def test_find_problems(self):
check = self.make_check() check = self.make_check()
@ -44,8 +44,8 @@ class TestProblemCheck(ConfigTestCase):
class FakeProblemCheck(mod.ProblemCheck): class FakeProblemCheck(mod.ProblemCheck):
system_key = 'wuttatest' system_key = "wuttatest"
problem_key = 'fake_check' problem_key = "fake_check"
title = "Fake problem check" title = "Fake problem check"
# def find_problems(self): # def find_problems(self):
@ -69,7 +69,7 @@ class TestProblemHandler(ConfigTestCase):
self.assertEqual(len(checks), 0) self.assertEqual(len(checks), 0)
# but let's configure our fake check # but let's configure our fake check
self.config.setdefault('wutta.problems.modules', 'tests.test_problems') self.config.setdefault("wutta.problems.modules", "tests.test_problems")
checks = self.handler.get_all_problem_checks() checks = self.handler.get_all_problem_checks()
self.assertIsInstance(checks, list) self.assertIsInstance(checks, list)
self.assertEqual(len(checks), 1) self.assertEqual(len(checks), 1)
@ -82,27 +82,31 @@ class TestProblemHandler(ConfigTestCase):
self.assertEqual(len(checks), 0) self.assertEqual(len(checks), 0)
# but let's configure our fake check # but let's configure our fake check
self.config.setdefault('wutta.problems.modules', 'tests.test_problems') self.config.setdefault("wutta.problems.modules", "tests.test_problems")
checks = self.handler.filter_problem_checks() checks = self.handler.filter_problem_checks()
self.assertIsInstance(checks, list) self.assertIsInstance(checks, list)
self.assertEqual(len(checks), 1) self.assertEqual(len(checks), 1)
# filter by system_key # filter by system_key
checks = self.handler.filter_problem_checks(systems=['wuttatest']) checks = self.handler.filter_problem_checks(systems=["wuttatest"])
self.assertEqual(len(checks), 1) self.assertEqual(len(checks), 1)
checks = self.handler.filter_problem_checks(systems=['something_else']) checks = self.handler.filter_problem_checks(systems=["something_else"])
self.assertEqual(len(checks), 0) self.assertEqual(len(checks), 0)
# filter by problem_key # filter by problem_key
checks = self.handler.filter_problem_checks(problems=['fake_check']) checks = self.handler.filter_problem_checks(problems=["fake_check"])
self.assertEqual(len(checks), 1) self.assertEqual(len(checks), 1)
checks = self.handler.filter_problem_checks(problems=['something_else']) checks = self.handler.filter_problem_checks(problems=["something_else"])
self.assertEqual(len(checks), 0) self.assertEqual(len(checks), 0)
# filter by both # filter by both
checks = self.handler.filter_problem_checks(systems=['wuttatest'], problems=['fake_check']) checks = self.handler.filter_problem_checks(
systems=["wuttatest"], problems=["fake_check"]
)
self.assertEqual(len(checks), 1) self.assertEqual(len(checks), 1)
checks = self.handler.filter_problem_checks(systems=['wuttatest'], problems=['bad_check']) checks = self.handler.filter_problem_checks(
systems=["wuttatest"], problems=["bad_check"]
)
self.assertEqual(len(checks), 0) self.assertEqual(len(checks), 0)
def test_get_supported_systems(self): def test_get_supported_systems(self):
@ -113,14 +117,14 @@ class TestProblemHandler(ConfigTestCase):
self.assertEqual(len(systems), 0) self.assertEqual(len(systems), 0)
# but let's configure our fake check # but let's configure our fake check
self.config.setdefault('wutta.problems.modules', 'tests.test_problems') self.config.setdefault("wutta.problems.modules", "tests.test_problems")
systems = self.handler.get_supported_systems() systems = self.handler.get_supported_systems()
self.assertIsInstance(systems, list) self.assertIsInstance(systems, list)
self.assertEqual(systems, ['wuttatest']) self.assertEqual(systems, ["wuttatest"])
def test_get_system_title(self): def test_get_system_title(self):
title = self.handler.get_system_title('wutta') title = self.handler.get_system_title("wutta")
self.assertEqual(title, 'wutta') self.assertEqual(title, "wutta")
def test_is_enabled(self): def test_is_enabled(self):
check = FakeProblemCheck(self.config) check = FakeProblemCheck(self.config)
@ -129,7 +133,7 @@ class TestProblemHandler(ConfigTestCase):
self.assertTrue(self.handler.is_enabled(check)) self.assertTrue(self.handler.is_enabled(check))
# config can disable # config can disable
self.config.setdefault('wutta.problems.wuttatest.fake_check.enabled', 'false') self.config.setdefault("wutta.problems.wuttatest.fake_check.enabled", "false")
self.assertFalse(self.handler.is_enabled(check)) self.assertFalse(self.handler.is_enabled(check))
def test_should_run_for_weekday(self): def test_should_run_for_weekday(self):
@ -140,8 +144,8 @@ class TestProblemHandler(ConfigTestCase):
self.assertTrue(self.handler.should_run_for_weekday(check, weekday)) self.assertTrue(self.handler.should_run_for_weekday(check, weekday))
# config can disable, e.g. for weekends # config can disable, e.g. for weekends
self.config.setdefault('wutta.problems.wuttatest.fake_check.day5', 'false') self.config.setdefault("wutta.problems.wuttatest.fake_check.day5", "false")
self.config.setdefault('wutta.problems.wuttatest.fake_check.day6', 'false') self.config.setdefault("wutta.problems.wuttatest.fake_check.day6", "false")
for weekday in range(5): for weekday in range(5):
self.assertTrue(self.handler.should_run_for_weekday(check, weekday)) self.assertTrue(self.handler.should_run_for_weekday(check, weekday))
for weekday in (5, 6): for weekday in (5, 6):
@ -152,10 +156,10 @@ class TestProblemHandler(ConfigTestCase):
organized = self.handler.organize_problem_checks(checks) organized = self.handler.organize_problem_checks(checks)
self.assertIsInstance(organized, dict) self.assertIsInstance(organized, dict)
self.assertEqual(list(organized), ['wuttatest']) self.assertEqual(list(organized), ["wuttatest"])
self.assertIsInstance(organized['wuttatest'], dict) self.assertIsInstance(organized["wuttatest"], dict)
self.assertEqual(list(organized['wuttatest']), ['fake_check']) self.assertEqual(list(organized["wuttatest"]), ["fake_check"])
self.assertIs(organized['wuttatest']['fake_check'], FakeProblemCheck) self.assertIs(organized["wuttatest"]["fake_check"], FakeProblemCheck)
def test_find_problems(self): def test_find_problems(self):
check = FakeProblemCheck(self.config) check = FakeProblemCheck(self.config)
@ -165,7 +169,7 @@ class TestProblemHandler(ConfigTestCase):
def test_get_email_key(self): def test_get_email_key(self):
check = FakeProblemCheck(self.config) check = FakeProblemCheck(self.config)
key = self.handler.get_email_key(check) key = self.handler.get_email_key(check)
self.assertEqual(key, 'wuttatest_problems_fake_check') self.assertEqual(key, "wuttatest_problems_fake_check")
def test_get_global_email_context(self): def test_get_global_email_context(self):
context = self.handler.get_global_email_context() context = self.handler.get_global_email_context()
@ -175,44 +179,53 @@ class TestProblemHandler(ConfigTestCase):
check = FakeProblemCheck(self.config) check = FakeProblemCheck(self.config)
problems = [] problems = []
context = self.handler.get_check_email_context(check, problems) context = self.handler.get_check_email_context(check, problems)
self.assertEqual(context, {'system_title': 'wuttatest'}) self.assertEqual(context, {"system_title": "wuttatest"})
def test_send_problem_report(self): def test_send_problem_report(self):
check = FakeProblemCheck(self.config) check = FakeProblemCheck(self.config)
problems = [] problems = []
with patch.object(self.app, 'send_email') as send_email: with patch.object(self.app, "send_email") as send_email:
self.handler.send_problem_report(check, problems) self.handler.send_problem_report(check, problems)
send_email.assert_called_once_with('wuttatest_problems_fake_check', { send_email.assert_called_once_with(
'system_title': 'wuttatest', "wuttatest_problems_fake_check",
'config': self.config, {
'app': self.app, "system_title": "wuttatest",
'check': check, "config": self.config,
'problems': problems, "app": self.app,
}, default_subject="Fake problem check", attachments=None) "check": check,
"problems": problems,
},
default_subject="Fake problem check",
attachments=None,
)
def test_run_problem_check(self): def test_run_problem_check(self):
with patch.object(FakeProblemCheck, 'find_problems') as find_problems: with patch.object(FakeProblemCheck, "find_problems") as find_problems:
with patch.object(self.handler, 'send_problem_report') as send_problem_report: with patch.object(
self.handler, "send_problem_report"
) as send_problem_report:
# check runs by default # check runs by default
find_problems.return_value = [{'foo': 'bar'}] find_problems.return_value = [{"foo": "bar"}]
problems = self.handler.run_problem_check(FakeProblemCheck) problems = self.handler.run_problem_check(FakeProblemCheck)
self.assertEqual(problems, [{'foo': 'bar'}]) self.assertEqual(problems, [{"foo": "bar"}])
find_problems.assert_called_once_with() find_problems.assert_called_once_with()
send_problem_report.assert_called_once() send_problem_report.assert_called_once()
# does not run if generally disabled # does not run if generally disabled
find_problems.reset_mock() find_problems.reset_mock()
send_problem_report.reset_mock() send_problem_report.reset_mock()
with patch.object(self.handler, 'is_enabled', return_value=False): with patch.object(self.handler, "is_enabled", return_value=False):
problems = self.handler.run_problem_check(FakeProblemCheck) problems = self.handler.run_problem_check(FakeProblemCheck)
self.assertIsNone(problems) self.assertIsNone(problems)
find_problems.assert_not_called() find_problems.assert_not_called()
send_problem_report.assert_not_called() send_problem_report.assert_not_called()
# unless caller gives force flag # unless caller gives force flag
problems = self.handler.run_problem_check(FakeProblemCheck, force=True) problems = self.handler.run_problem_check(
self.assertEqual(problems, [{'foo': 'bar'}]) FakeProblemCheck, force=True
)
self.assertEqual(problems, [{"foo": "bar"}])
find_problems.assert_called_once_with() find_problems.assert_called_once_with()
send_problem_report.assert_called_once() send_problem_report.assert_called_once()
@ -220,7 +233,9 @@ class TestProblemHandler(ConfigTestCase):
find_problems.reset_mock() find_problems.reset_mock()
send_problem_report.reset_mock() send_problem_report.reset_mock()
weekday = datetime.date.today().weekday() weekday = datetime.date.today().weekday()
self.config.setdefault(f'wutta.problems.wuttatest.fake_check.day{weekday}', 'false') self.config.setdefault(
f"wutta.problems.wuttatest.fake_check.day{weekday}", "false"
)
problems = self.handler.run_problem_check(FakeProblemCheck) problems = self.handler.run_problem_check(FakeProblemCheck)
self.assertIsNone(problems) self.assertIsNone(problems)
find_problems.assert_not_called() find_problems.assert_not_called()
@ -228,16 +243,18 @@ class TestProblemHandler(ConfigTestCase):
# unless caller gives force flag # unless caller gives force flag
problems = self.handler.run_problem_check(FakeProblemCheck, force=True) problems = self.handler.run_problem_check(FakeProblemCheck, force=True)
self.assertEqual(problems, [{'foo': 'bar'}]) self.assertEqual(problems, [{"foo": "bar"}])
find_problems.assert_called_once_with() find_problems.assert_called_once_with()
send_problem_report.assert_called_once() send_problem_report.assert_called_once()
def test_run_problem_checks(self): def test_run_problem_checks(self):
with patch.object(FakeProblemCheck, 'find_problems') as find_problems: with patch.object(FakeProblemCheck, "find_problems") as find_problems:
with patch.object(self.handler, 'send_problem_report') as send_problem_report: with patch.object(
self.handler, "send_problem_report"
) as send_problem_report:
# check runs by default # check runs by default
find_problems.return_value = [{'foo': 'bar'}] find_problems.return_value = [{"foo": "bar"}]
self.handler.run_problem_checks([FakeProblemCheck]) self.handler.run_problem_checks([FakeProblemCheck])
find_problems.assert_called_once_with() find_problems.assert_called_once_with()
send_problem_report.assert_called_once() send_problem_report.assert_called_once()
@ -245,7 +262,7 @@ class TestProblemHandler(ConfigTestCase):
# does not run if generally disabled # does not run if generally disabled
find_problems.reset_mock() find_problems.reset_mock()
send_problem_report.reset_mock() send_problem_report.reset_mock()
with patch.object(self.handler, 'is_enabled', return_value=False): with patch.object(self.handler, "is_enabled", return_value=False):
self.handler.run_problem_checks([FakeProblemCheck]) self.handler.run_problem_checks([FakeProblemCheck])
find_problems.assert_not_called() find_problems.assert_not_called()
send_problem_report.assert_not_called() send_problem_report.assert_not_called()
@ -259,7 +276,9 @@ class TestProblemHandler(ConfigTestCase):
find_problems.reset_mock() find_problems.reset_mock()
send_problem_report.reset_mock() send_problem_report.reset_mock()
weekday = datetime.date.today().weekday() weekday = datetime.date.today().weekday()
self.config.setdefault(f'wutta.problems.wuttatest.fake_check.day{weekday}', 'false') self.config.setdefault(
f"wutta.problems.wuttatest.fake_check.day{weekday}", "false"
)
self.handler.run_problem_checks([FakeProblemCheck]) self.handler.run_problem_checks([FakeProblemCheck])
find_problems.assert_not_called() find_problems.assert_not_called()
send_problem_report.assert_not_called() send_problem_report.assert_not_called()

View file

@ -10,7 +10,7 @@ class TestProgressBase(TestCase):
def test_basic(self): def test_basic(self):
# sanity / coverage check # sanity / coverage check
prog = mod.ProgressBase('testing', 2) prog = mod.ProgressBase("testing", 2)
prog.update(1) prog.update(1)
prog.update(2) prog.update(2)
prog.finish() prog.finish()
@ -21,7 +21,7 @@ class TestConsoleProgress(TestCase):
def test_basic(self): def test_basic(self):
# sanity / coverage check # sanity / coverage check
prog = mod.ConsoleProgress('testing', 2) prog = mod.ConsoleProgress("testing", 2)
prog.update(1) prog.update(1)
prog.update(2) prog.update(2)
prog.finish() prog.finish()

View file

@ -7,12 +7,12 @@ from wuttjamaican.testing import ConfigTestCase
class MockFooReport(mod.Report): class MockFooReport(mod.Report):
report_key = 'mock_foo' report_key = "mock_foo"
report_title = "MOCK Report" report_title = "MOCK Report"
def make_data(self, params, **kwargs): def make_data(self, params, **kwargs):
return [ return [
{'foo': 'bar'}, {"foo": "bar"},
] ]
@ -35,15 +35,15 @@ class TestReportHandler(ConfigTestCase):
def test_get_report_modules(self): def test_get_report_modules(self):
# no providers, no report modules # no providers, no report modules
with patch.object(self.app, 'providers', new={}): with patch.object(self.app, "providers", new={}):
handler = self.make_handler() handler = self.make_handler()
self.assertEqual(handler.get_report_modules(), []) self.assertEqual(handler.get_report_modules(), [])
# provider may specify modules as list # provider may specify modules as list
providers = { providers = {
'wuttatest': MagicMock(report_modules=['wuttjamaican.reports']), "wuttatest": MagicMock(report_modules=["wuttjamaican.reports"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
modules = handler.get_report_modules() modules = handler.get_report_modules()
self.assertEqual(len(modules), 1) self.assertEqual(len(modules), 1)
@ -51,9 +51,9 @@ class TestReportHandler(ConfigTestCase):
# provider may specify modules as string # provider may specify modules as string
providers = { providers = {
'wuttatest': MagicMock(report_modules='wuttjamaican.reports'), "wuttatest": MagicMock(report_modules="wuttjamaican.reports"),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
modules = handler.get_report_modules() modules = handler.get_report_modules()
self.assertEqual(len(modules), 1) self.assertEqual(len(modules), 1)
@ -62,54 +62,54 @@ class TestReportHandler(ConfigTestCase):
def test_get_reports(self): def test_get_reports(self):
# no providers, no reports # no providers, no reports
with patch.object(self.app, 'providers', new={}): with patch.object(self.app, "providers", new={}):
handler = self.make_handler() handler = self.make_handler()
self.assertEqual(handler.get_reports(), {}) self.assertEqual(handler.get_reports(), {})
# provider may define reports (via modules) # provider may define reports (via modules)
providers = { providers = {
'wuttatest': MagicMock(report_modules=['tests.test_reports']), "wuttatest": MagicMock(report_modules=["tests.test_reports"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
reports = handler.get_reports() reports = handler.get_reports()
self.assertEqual(len(reports), 1) self.assertEqual(len(reports), 1)
self.assertIn('mock_foo', reports) self.assertIn("mock_foo", reports)
def test_get_report(self): def test_get_report(self):
providers = { providers = {
'wuttatest': MagicMock(report_modules=['tests.test_reports']), "wuttatest": MagicMock(report_modules=["tests.test_reports"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
# as instance # as instance
report = handler.get_report('mock_foo') report = handler.get_report("mock_foo")
self.assertIsInstance(report, mod.Report) self.assertIsInstance(report, mod.Report)
self.assertIsInstance(report, MockFooReport) self.assertIsInstance(report, MockFooReport)
# as class # as class
report = handler.get_report('mock_foo', instance=False) report = handler.get_report("mock_foo", instance=False)
self.assertTrue(issubclass(report, mod.Report)) self.assertTrue(issubclass(report, mod.Report))
self.assertIs(report, MockFooReport) self.assertIs(report, MockFooReport)
# not found # not found
report = handler.get_report('unknown') report = handler.get_report("unknown")
self.assertIsNone(report) self.assertIsNone(report)
def test_make_report_data(self): def test_make_report_data(self):
providers = { providers = {
'wuttatest': MagicMock(report_modules=['tests.test_reports']), "wuttatest": MagicMock(report_modules=["tests.test_reports"]),
} }
with patch.object(self.app, 'providers', new=providers): with patch.object(self.app, "providers", new=providers):
handler = self.make_handler() handler = self.make_handler()
report = handler.get_report('mock_foo') report = handler.get_report("mock_foo")
data = handler.make_report_data(report) data = handler.make_report_data(report)
self.assertEqual(len(data), 2) self.assertEqual(len(data), 2)
self.assertIn('output_title', data) self.assertIn("output_title", data)
self.assertEqual(data['output_title'], "MOCK Report") self.assertEqual(data["output_title"], "MOCK Report")
self.assertIn('data', data) self.assertIn("data", data)
self.assertEqual(data['data'], [{'foo': 'bar'}]) self.assertEqual(data["data"], [{"foo": "bar"}])

View file

@ -10,9 +10,17 @@ from wuttjamaican import util as mod
from wuttjamaican.progress import ProgressBase from wuttjamaican.progress import ProgressBase
class A: pass class A:
class B(A): pass pass
class C(B): pass
class B(A):
pass
class C(B):
pass
class TestGetClassHierarchy(TestCase): class TestGetClassHierarchy(TestCase):
@ -35,15 +43,15 @@ class TestLoadEntryPoints(TestCase):
def test_empty(self): def test_empty(self):
# empty set returned for unknown group # empty set returned for unknown group
result = mod.load_entry_points('this_should_never_exist!!!!!!') result = mod.load_entry_points("this_should_never_exist!!!!!!")
self.assertEqual(result, {}) self.assertEqual(result, {})
def test_basic(self): def test_basic(self):
# load some entry points which should "always" be present, # load some entry points which should "always" be present,
# even in a testing environment. basic sanity check # even in a testing environment. basic sanity check
result = mod.load_entry_points('console_scripts', ignore_errors=True) result = mod.load_entry_points("console_scripts", ignore_errors=True)
self.assertTrue(len(result) >= 1) self.assertTrue(len(result) >= 1)
self.assertIn('pip', result) self.assertIn("pip", result)
def test_basic_pre_python_3_10(self): def test_basic_pre_python_3_10(self):
@ -54,6 +62,7 @@ class TestLoadEntryPoints(TestCase):
pytest.skip("this test is not relevant before python 3.10") pytest.skip("this test is not relevant before python 3.10")
import importlib.metadata import importlib.metadata
real_entry_points = importlib.metadata.entry_points() real_entry_points = importlib.metadata.entry_points()
class FakeEntryPoints(dict): class FakeEntryPoints(dict):
@ -63,13 +72,13 @@ class TestLoadEntryPoints(TestCase):
importlib = MagicMock() importlib = MagicMock()
importlib.metadata.entry_points.return_value = FakeEntryPoints() importlib.metadata.entry_points.return_value = FakeEntryPoints()
with patch.dict('sys.modules', **{'importlib': importlib}): with patch.dict("sys.modules", **{"importlib": importlib}):
# load some entry points which should "always" be present, # load some entry points which should "always" be present,
# even in a testing environment. basic sanity check # even in a testing environment. basic sanity check
result = mod.load_entry_points('console_scripts', ignore_errors=True) result = mod.load_entry_points("console_scripts", ignore_errors=True)
self.assertTrue(len(result) >= 1) self.assertTrue(len(result) >= 1)
self.assertIn('pytest', result) self.assertIn("pytest", result)
def test_basic_pre_python_3_8(self): def test_basic_pre_python_3_8(self):
@ -80,11 +89,12 @@ class TestLoadEntryPoints(TestCase):
pytest.skip("this test is not relevant before python 3.8") pytest.skip("this test is not relevant before python 3.8")
from importlib.metadata import entry_points from importlib.metadata import entry_points
real_entry_points = entry_points() real_entry_points = entry_points()
class FakeEntryPoints(dict): class FakeEntryPoints(dict):
def get(self, group, default): def get(self, group, default):
if hasattr(real_entry_points, 'select'): if hasattr(real_entry_points, "select"):
return real_entry_points.select(group=group) return real_entry_points.select(group=group)
return real_entry_points.get(group, []) return real_entry_points.get(group, [])
@ -94,19 +104,19 @@ class TestLoadEntryPoints(TestCase):
orig_import = __import__ orig_import = __import__
def mock_import(name, *args, **kwargs): def mock_import(name, *args, **kwargs):
if name == 'importlib.metadata': if name == "importlib.metadata":
raise ImportError raise ImportError
if name == 'importlib_metadata': if name == "importlib_metadata":
return importlib_metadata return importlib_metadata
return orig_import(name, *args, **kwargs) return orig_import(name, *args, **kwargs)
with patch('builtins.__import__', side_effect=mock_import): with patch("builtins.__import__", side_effect=mock_import):
# load some entry points which should "always" be present, # load some entry points which should "always" be present,
# even in a testing environment. basic sanity check # even in a testing environment. basic sanity check
result = mod.load_entry_points('console_scripts', ignore_errors=True) result = mod.load_entry_points("console_scripts", ignore_errors=True)
self.assertTrue(len(result) >= 1) self.assertTrue(len(result) >= 1)
self.assertIn('pytest', result) self.assertIn("pytest", result)
def test_error(self): def test_error(self):
@ -123,22 +133,24 @@ class TestLoadEntryPoints(TestCase):
importlib = MagicMock() importlib = MagicMock()
importlib.metadata.entry_points.return_value = entry_points importlib.metadata.entry_points.return_value = entry_points
with patch.dict('sys.modules', **{'importlib': importlib}): with patch.dict("sys.modules", **{"importlib": importlib}):
# empty set returned if errors suppressed # empty set returned if errors suppressed
result = mod.load_entry_points('wuttatest.thingers', ignore_errors=True) result = mod.load_entry_points("wuttatest.thingers", ignore_errors=True)
self.assertEqual(result, {}) self.assertEqual(result, {})
importlib.metadata.entry_points.assert_called_once_with() importlib.metadata.entry_points.assert_called_once_with()
entry_points.select.assert_called_once_with(group='wuttatest.thingers') entry_points.select.assert_called_once_with(group="wuttatest.thingers")
entry_point.load.assert_called_once_with() entry_point.load.assert_called_once_with()
# error is raised, if not suppressed # error is raised, if not suppressed
importlib.metadata.entry_points.reset_mock() importlib.metadata.entry_points.reset_mock()
entry_points.select.reset_mock() entry_points.select.reset_mock()
entry_point.load.reset_mock() entry_point.load.reset_mock()
self.assertRaises(NotImplementedError, mod.load_entry_points, 'wuttatest.thingers') self.assertRaises(
NotImplementedError, mod.load_entry_points, "wuttatest.thingers"
)
importlib.metadata.entry_points.assert_called_once_with() importlib.metadata.entry_points.assert_called_once_with()
entry_points.select.assert_called_once_with(group='wuttatest.thingers') entry_points.select.assert_called_once_with(group="wuttatest.thingers")
entry_point.load.assert_called_once_with() entry_point.load.assert_called_once_with()
@ -148,7 +160,7 @@ class TestLoadObject(TestCase):
self.assertRaises(ValueError, mod.load_object, None) self.assertRaises(ValueError, mod.load_object, None)
def test_basic(self): def test_basic(self):
result = mod.load_object('unittest:TestCase') result = mod.load_object("unittest:TestCase")
self.assertIs(result, TestCase) self.assertIs(result, TestCase)
@ -169,20 +181,20 @@ class TestParseBool(TestCase):
self.assertFalse(mod.parse_bool(False)) self.assertFalse(mod.parse_bool(False))
def test_string_true(self): def test_string_true(self):
self.assertTrue(mod.parse_bool('true')) self.assertTrue(mod.parse_bool("true"))
self.assertTrue(mod.parse_bool('yes')) self.assertTrue(mod.parse_bool("yes"))
self.assertTrue(mod.parse_bool('y')) self.assertTrue(mod.parse_bool("y"))
self.assertTrue(mod.parse_bool('on')) self.assertTrue(mod.parse_bool("on"))
self.assertTrue(mod.parse_bool('1')) self.assertTrue(mod.parse_bool("1"))
def test_string_false(self): def test_string_false(self):
self.assertFalse(mod.parse_bool('false')) self.assertFalse(mod.parse_bool("false"))
self.assertFalse(mod.parse_bool('no')) self.assertFalse(mod.parse_bool("no"))
self.assertFalse(mod.parse_bool('n')) self.assertFalse(mod.parse_bool("n"))
self.assertFalse(mod.parse_bool('off')) self.assertFalse(mod.parse_bool("off"))
self.assertFalse(mod.parse_bool('0')) self.assertFalse(mod.parse_bool("0"))
# nb. assume false for unrecognized input # nb. assume false for unrecognized input
self.assertFalse(mod.parse_bool('whatever-else')) self.assertFalse(mod.parse_bool("whatever-else"))
class TestParseList(TestCase): class TestParseList(TestCase):
@ -198,76 +210,82 @@ class TestParseList(TestCase):
self.assertIs(value, mylist) self.assertIs(value, mylist)
def test_single_value(self): def test_single_value(self):
value = mod.parse_list('foo') value = mod.parse_list("foo")
self.assertEqual(len(value), 1) self.assertEqual(len(value), 1)
self.assertEqual(value[0], 'foo') self.assertEqual(value[0], "foo")
def test_single_value_padded_by_spaces(self): def test_single_value_padded_by_spaces(self):
value = mod.parse_list(' foo ') value = mod.parse_list(" foo ")
self.assertEqual(len(value), 1) self.assertEqual(len(value), 1)
self.assertEqual(value[0], 'foo') self.assertEqual(value[0], "foo")
def test_slash_is_not_a_separator(self): def test_slash_is_not_a_separator(self):
value = mod.parse_list('/dev/null') value = mod.parse_list("/dev/null")
self.assertEqual(len(value), 1) self.assertEqual(len(value), 1)
self.assertEqual(value[0], '/dev/null') self.assertEqual(value[0], "/dev/null")
def test_multiple_values_separated_by_whitespace(self): def test_multiple_values_separated_by_whitespace(self):
value = mod.parse_list('foo bar baz') value = mod.parse_list("foo bar baz")
self.assertEqual(len(value), 3) self.assertEqual(len(value), 3)
self.assertEqual(value[0], 'foo') self.assertEqual(value[0], "foo")
self.assertEqual(value[1], 'bar') self.assertEqual(value[1], "bar")
self.assertEqual(value[2], 'baz') self.assertEqual(value[2], "baz")
def test_multiple_values_separated_by_commas(self): def test_multiple_values_separated_by_commas(self):
value = mod.parse_list('foo,bar,baz') value = mod.parse_list("foo,bar,baz")
self.assertEqual(len(value), 3) self.assertEqual(len(value), 3)
self.assertEqual(value[0], 'foo') self.assertEqual(value[0], "foo")
self.assertEqual(value[1], 'bar') self.assertEqual(value[1], "bar")
self.assertEqual(value[2], 'baz') self.assertEqual(value[2], "baz")
def test_multiple_values_separated_by_whitespace_and_commas(self): def test_multiple_values_separated_by_whitespace_and_commas(self):
value = mod.parse_list(' foo, bar baz') value = mod.parse_list(" foo, bar baz")
self.assertEqual(len(value), 3) self.assertEqual(len(value), 3)
self.assertEqual(value[0], 'foo') self.assertEqual(value[0], "foo")
self.assertEqual(value[1], 'bar') self.assertEqual(value[1], "bar")
self.assertEqual(value[2], 'baz') self.assertEqual(value[2], "baz")
def test_multiple_values_separated_by_whitespace_and_commas_with_some_quoting(self): def test_multiple_values_separated_by_whitespace_and_commas_with_some_quoting(self):
value = mod.parse_list(""" value = mod.parse_list(
"""
foo foo
"C:\\some path\\with spaces\\and, a comma", "C:\\some path\\with spaces\\and, a comma",
baz baz
""") """
)
self.assertEqual(len(value), 3) self.assertEqual(len(value), 3)
self.assertEqual(value[0], 'foo') self.assertEqual(value[0], "foo")
self.assertEqual(value[1], 'C:\\some path\\with spaces\\and, a comma') self.assertEqual(value[1], "C:\\some path\\with spaces\\and, a comma")
self.assertEqual(value[2], 'baz') self.assertEqual(value[2], "baz")
def test_multiple_values_separated_by_whitespace_and_commas_with_single_quotes(self): def test_multiple_values_separated_by_whitespace_and_commas_with_single_quotes(
value = mod.parse_list(""" self,
):
value = mod.parse_list(
"""
foo foo
'C:\\some path\\with spaces\\and, a comma', 'C:\\some path\\with spaces\\and, a comma',
baz baz
""") """
)
self.assertEqual(len(value), 3) self.assertEqual(len(value), 3)
self.assertEqual(value[0], 'foo') self.assertEqual(value[0], "foo")
self.assertEqual(value[1], 'C:\\some path\\with spaces\\and, a comma') self.assertEqual(value[1], "C:\\some path\\with spaces\\and, a comma")
self.assertEqual(value[2], 'baz') self.assertEqual(value[2], "baz")
class TestMakeTitle(TestCase): class TestMakeTitle(TestCase):
def test_basic(self): def test_basic(self):
text = mod.make_title('foo_bar') text = mod.make_title("foo_bar")
self.assertEqual(text, "Foo Bar") self.assertEqual(text, "Foo Bar")
class TestMakeFullName(TestCase): class TestMakeFullName(TestCase):
def test_basic(self): def test_basic(self):
name = mod.make_full_name('Fred', '', 'Flintstone', '') name = mod.make_full_name("Fred", "", "Flintstone", "")
self.assertEqual(name, 'Fred Flintstone') self.assertEqual(name, "Fred Flintstone")
class TestProgressLoop(TestCase): class TestProgressLoop(TestCase):
@ -278,12 +296,10 @@ class TestProgressLoop(TestCase):
pass pass
# with progress # with progress
mod.progress_loop(act, [1, 2, 3], ProgressBase, mod.progress_loop(act, [1, 2, 3], ProgressBase, message="whatever")
message="whatever")
# without progress # without progress
mod.progress_loop(act, [1, 2, 3], None, mod.progress_loop(act, [1, 2, 3], None, message="whatever")
message="whatever")
class TestResourcePath(TestCase): class TestResourcePath(TestCase):
@ -291,11 +307,13 @@ class TestResourcePath(TestCase):
def test_basic(self): def test_basic(self):
# package spec is resolved to path # package spec is resolved to path
path = mod.resource_path('wuttjamaican:util.py') path = mod.resource_path("wuttjamaican:util.py")
self.assertTrue(path.endswith('wuttjamaican/util.py')) self.assertTrue(path.endswith("wuttjamaican/util.py"))
# absolute path returned as-is # absolute path returned as-is
self.assertEqual(mod.resource_path('/tmp/doesnotexist.txt'), '/tmp/doesnotexist.txt') self.assertEqual(
mod.resource_path("/tmp/doesnotexist.txt"), "/tmp/doesnotexist.txt"
)
def test_basic_pre_python_3_9(self): def test_basic_pre_python_3_9(self):
@ -310,20 +328,22 @@ class TestResourcePath(TestCase):
orig_import = __import__ orig_import = __import__
def mock_import(name, globals=None, locals=None, fromlist=(), level=0): def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'importlib.resources': if name == "importlib.resources":
raise ImportError raise ImportError
if name == 'importlib_resources': if name == "importlib_resources":
return MagicMock(files=files, as_file=as_file) return MagicMock(files=files, as_file=as_file)
return orig_import(name, globals, locals, fromlist, level) return orig_import(name, globals, locals, fromlist, level)
with patch('builtins.__import__', side_effect=mock_import): with patch("builtins.__import__", side_effect=mock_import):
# package spec is resolved to path # package spec is resolved to path
path = mod.resource_path('wuttjamaican:util.py') path = mod.resource_path("wuttjamaican:util.py")
self.assertTrue(path.endswith('wuttjamaican/util.py')) self.assertTrue(path.endswith("wuttjamaican/util.py"))
# absolute path returned as-is # absolute path returned as-is
self.assertEqual(mod.resource_path('/tmp/doesnotexist.txt'), '/tmp/doesnotexist.txt') self.assertEqual(
mod.resource_path("/tmp/doesnotexist.txt"), "/tmp/doesnotexist.txt"
)
class TestSimpleError(TestCase): class TestSimpleError(TestCase):

View file

@ -10,6 +10,12 @@ commands = pytest {posargs}
[testenv:nox] [testenv:nox]
extras = tests extras = tests
[testenv:black]
basepython = python3.11
extras = tests
deps =
commands = black --check .
[testenv:pylint] [testenv:pylint]
basepython = python3.11 basepython = python3.11
extras = db,tests extras = db,tests