3
0
Fork 0

fix: prompt for continuum support in app installer

unless installer declares static preference, then do not prompt.

also, do not prompt if continuum packages are missing.
This commit is contained in:
Lance Edgar 2026-01-04 22:47:32 -06:00
parent d018d4e764
commit dd56fbcc2d
3 changed files with 255 additions and 95 deletions

View file

@ -87,6 +87,9 @@ class InstallHandler(GenericHandler): # pylint: disable=too-many-public-methods
egg_name = None
schema_installed = False
# nb. we prompt the user for this, unless attr already has value
wants_continuum = None
template_paths = ["wuttjamaican:templates/install"]
def __init__(self, config, **kwargs):
@ -185,58 +188,84 @@ class InstallHandler(GenericHandler): # pylint: disable=too-many-public-methods
This method is called by :meth:`run()` and does the following:
* call :meth:`get_dbinfo()` to get DB info from user, and test connection
* call :meth:`prompt_user_for_context()` to collect DB info etc.
* call :meth:`make_template_context()` to use when generating output
* call :meth:`make_appdir()` to create app dir with config files
* call :meth:`install_db_schema()` to (optionally) create tables in DB
"""
# prompt user for db info
dbinfo = self.get_dbinfo()
# get context for generated app files
context = self.make_template_context(dbinfo)
# prompt user / get context
context = self.prompt_user_for_context()
context = self.make_template_context(**context)
# make the appdir
self.make_appdir(context)
# install db schema if user likes
self.schema_installed = self.install_db_schema(dbinfo["dburl"])
self.schema_installed = self.install_db_schema(context["db_url"])
def get_dbinfo(self):
def prompt_user_for_context(self):
"""
Collect DB connection info from the user, and test the
connection. If connection fails, exit the install.
This is responsible for initial user prompts.
This method is normally called by :meth:`do_install_steps()`.
This happens early in the install, so this method can verify
the info, e.g. test the DB connection, but should not write
any files as the app dir may not exist yet.
:returns: Dict of DB info collected from user.
Default logic calls :meth:`get_db_url()` for the DB
connection, then may ask about Wutta-Continuum data
versioning. (The latter is skipped if the package is
missing.)
Subclass should override this method if they need different
prompting logic. The return value should always include at
least these 2 items:
* ``db_url`` - URL for the DB connection
* ``wants_continuum`` - whether data versioning should be enabled
:returns: Dict of template context
"""
dbinfo = {}
# db info
db_url = self.get_db_url()
# get db info
dbinfo["dbtype"] = self.prompt_generic("db type", "postgresql")
dbinfo["dbhost"] = self.prompt_generic("db host", "localhost")
default_port = "3306" if dbinfo["dbtype"] == "mysql" else "5432"
dbinfo["dbport"] = self.prompt_generic("db port", default_port)
dbinfo["dbname"] = self.prompt_generic("db name", self.pkg_name)
dbinfo["dbuser"] = self.prompt_generic("db user", self.pkg_name)
# continuum
if self.wants_continuum is None:
try:
import wutta_continuum
except ImportError:
self.wants_continuum = False
else:
self.wants_continuum = self.prompt_bool(
"use continuum for data versioning?", default=False
)
# get db password
dbinfo["dbpass"] = None
while not dbinfo["dbpass"]:
dbinfo["dbpass"] = self.prompt_generic("db pass", is_password=True)
return {"db_url": db_url, "wants_continuum": self.wants_continuum}
def get_db_url(self):
"""
This must return the DB engine URL.
Default logic will prompt the user for hostname, port, DB name
and credentials. It then assembles the URL from those parts.
This method will also test the DB connection. If it fails,
the install is aborted.
This method is normally called by
:meth:`prompt_user_for_context()`.
:returns: SQLAlchemy engine URL (as object or string)
"""
# get db info/url
dbinfo = self.get_dbinfo()
if "db_url" in dbinfo:
db_url = dbinfo["db_url"]
else:
db_url = self.make_db_url(dbinfo)
# test db connection
self.rprint("\n\ttesting db connection... ", end="")
dbinfo["dburl"] = self.make_db_url(
dbinfo["dbtype"],
dbinfo["dbhost"],
dbinfo["dbport"],
dbinfo["dbname"],
dbinfo["dbuser"],
dbinfo["dbpass"],
)
error = self.test_db_connection(dbinfo["dburl"])
error = self.test_db_connection(db_url)
if error:
self.rprint("[bold red]cannot connect![/bold red] ..error was:")
self.rprint(f"\n{error}")
@ -244,26 +273,42 @@ class InstallHandler(GenericHandler): # pylint: disable=too-many-public-methods
sys.exit(1)
self.rprint("[bold green]good[/bold green]")
return db_url
def get_dbinfo(self): # pylint: disable=missing-function-docstring
dbinfo = {}
# main info
dbinfo["dbtype"] = self.prompt_generic("db type", "postgresql")
dbinfo["dbhost"] = self.prompt_generic("db host", "localhost")
default_port = "3306" if dbinfo["dbtype"] == "mysql" else "5432"
dbinfo["dbport"] = self.prompt_generic("db port", default_port)
dbinfo["dbname"] = self.prompt_generic("db name", self.pkg_name)
dbinfo["dbuser"] = self.prompt_generic("db user", self.pkg_name)
# password
dbinfo["dbpass"] = None
while not dbinfo["dbpass"]:
dbinfo["dbpass"] = self.prompt_generic("db pass", is_password=True)
return dbinfo
def make_db_url(
self, dbtype, dbhost, dbport, dbname, dbuser, dbpass
): # pylint: disable=empty-docstring,too-many-arguments,too-many-positional-arguments
def make_db_url(self, dbinfo): # pylint: disable=empty-docstring
""" """
from sqlalchemy.engine import URL # pylint: disable=import-outside-toplevel
if dbtype == "mysql":
if dbinfo["dbtype"] == "mysql":
drivername = "mysql+mysqlconnector"
else:
drivername = "postgresql+psycopg2"
return URL.create(
drivername=drivername,
username=dbuser,
password=dbpass,
host=dbhost,
port=dbport,
database=dbname,
username=dbinfo["dbuser"],
password=dbinfo["dbpass"],
host=dbinfo["dbhost"],
port=dbinfo["dbport"],
database=dbinfo["dbname"],
)
def test_db_connection(self, url): # pylint: disable=empty-docstring
@ -280,7 +325,7 @@ class InstallHandler(GenericHandler): # pylint: disable=too-many-public-methods
return str(error)
return None
def make_template_context(self, dbinfo, **kwargs):
def make_template_context(self, **kwargs):
"""
This must return a dict to be used as global template context
when generating output (e.g. config) files.
@ -289,14 +334,19 @@ class InstallHandler(GenericHandler): # pylint: disable=too-many-public-methods
The ``context`` returned is then passed to
:meth:`render_mako_template()`.
:param dbinfo: Dict of DB connection info as obtained from
:meth:`get_dbinfo()`.
Note these first 2 params are not explicitly listed in the
method signature; they are required nonetheless.
:param db_url: This must be a string URL for the DB engine.
:param wants_continuum: Whether data versioning should be
enabled within the config.
:param \\**kwargs: Extra template context.
:returns: Dict for global template context.
The context dict will include:
The final context dict should include at least:
* ``envdir`` - value from :data:`python:sys.prefix`
* ``envname`` - "last" dirname from ``sys.prefix``
@ -305,13 +355,16 @@ class InstallHandler(GenericHandler): # pylint: disable=too-many-public-methods
* ``pypi_name`` - value from :attr:`pypi_name`
* ``egg_name`` - value from :attr:`egg_name`
* ``appdir`` - ``app`` folder under ``sys.prefix``
* ``db_url`` - value from ``dbinfo['dburl']``
* ``db_url`` - value from ``kwargs``
* ``wants_continuum`` - value from ``kwargs``
"""
envname = os.path.basename(sys.prefix)
appdir = os.path.join(sys.prefix, "app")
dburl = dbinfo["dburl"]
if not isinstance(dburl, str):
dburl = dburl.render_as_string(hide_password=False)
db_url = kwargs.pop("db_url")
if not isinstance(db_url, str):
db_url = db_url.render_as_string(hide_password=False)
context = {
"envdir": sys.prefix,
"envname": envname,
@ -319,8 +372,8 @@ class InstallHandler(GenericHandler): # pylint: disable=too-many-public-methods
"app_title": self.app_title,
"pypi_name": self.pypi_name,
"appdir": appdir,
"db_url": dburl,
"egg_name": self.egg_name,
"db_url": db_url,
}
context.update(kwargs)
return context

View file

@ -56,8 +56,11 @@ preferdb = true
<%def name="section_wutta_db()">
[wutta.db]
default.url = ${db_url}
## TODO
## versioning.enabled = true
% if wants_continuum:
[wutta_continuum]
enable_versioning = true
% endif
</%def>
<%def name="section_wutta_mail()">
@ -94,7 +97,11 @@ templates = wuttjamaican:templates/mail
<%def name="section_alembic()">
[alembic]
script_location = wuttjamaican.db:alembic
% if wants_continuum:
version_locations = ${pkg_name}.db:alembic/versions wutta_continuum.db:alembic/versions wuttjamaican.db:alembic/versions
% else:
version_locations = ${pkg_name}.db:alembic/versions wuttjamaican.db:alembic/versions
% endif
</%def>
<%def name="sectiongroup_logging()">

View file

@ -64,25 +64,119 @@ class TestInstallHandler(ConfigTestCase):
def test_do_install_steps(self):
handler = self.make_handler()
handler.templates = TemplateLookup(
directories=[
self.app.resource_path("wuttjamaican:templates/install"),
]
)
dbinfo = {
"dburl": f"sqlite:///{self.tempdir}/poser.sqlite",
}
db_url = f"sqlite:///{self.tempdir}/poser.sqlite"
with patch.object(handler, "get_dbinfo", return_value=dbinfo):
with patch.object(handler, "prompt_user_for_context") as prompt_user:
prompt_user.return_value = {"db_url": db_url, "wants_continuum": False}
with patch.object(handler, "make_appdir") as make_appdir:
with patch.object(handler, "install_db_schema") as install_db_schema:
with patch.object(handler, "install_db_schema") as install_schema:
# nb. just for sanity/coverage
install_db_schema.return_value = True
self.assertFalse(handler.schema_installed)
install_schema.return_value = True
handler.do_install_steps()
self.assertTrue(make_appdir.called)
prompt_user.assert_called_once()
make_appdir.assert_called_once()
install_schema.assert_called_once_with(db_url)
self.assertTrue(handler.schema_installed)
install_db_schema.assert_called_once_with(dbinfo["dburl"])
def test_prompt_user_for_context(self):
db_url = f"sqlite:///{self.tempdir}/poser.sqlite"
with patch.object(mod.InstallHandler, "get_db_url", return_value=db_url):
# should prompt for continuum by default
handler = self.make_handler()
with patch.object(handler, "prompt_bool") as prompt_bool:
prompt_bool.return_value = True
context = handler.prompt_user_for_context()
prompt_bool.assert_called_once_with(
"use continuum for data versioning?", default=False
)
self.assertEqual(context, {"db_url": db_url, "wants_continuum": True})
# should not prompt if continuum flag already true
handler = self.make_handler()
with patch.object(handler, "wants_continuum", new=True):
with patch.object(handler, "prompt_bool") as prompt_bool:
context = handler.prompt_user_for_context()
prompt_bool.assert_not_called()
self.assertEqual(
context, {"db_url": db_url, "wants_continuum": True}
)
# should not prompt if continuum flag already false
handler = self.make_handler()
with patch.object(handler, "wants_continuum", new=False):
with patch.object(handler, "prompt_bool") as prompt_bool:
context = handler.prompt_user_for_context()
prompt_bool.assert_not_called()
self.assertEqual(
context, {"db_url": db_url, "wants_continuum": False}
)
# should not prompt if continuum pkg missing...
handler = self.make_handler()
with patch("builtins.__import__", side_effect=ImportError):
with patch.object(handler, "prompt_bool") as prompt_bool:
context = handler.prompt_user_for_context()
prompt_bool.assert_not_called()
self.assertEqual(
context, {"db_url": db_url, "wants_continuum": False}
)
def test_get_db_url(self):
try:
import sqlalchemy
from wuttjamaican.db.util import SA2
except ImportError:
pytest.skip("test is not relevant without sqlalchemy")
handler = self.make_handler()
# url from dbinfo is returned, if present
dbinfo = {"db_url": "sqlite:///"}
with patch.object(handler, "get_dbinfo", return_value=dbinfo):
db_url = handler.get_db_url()
self.assertEqual(db_url, "sqlite:///")
# or url will be assembled from dbinfo parts
dbinfo = {
"dbtype": "postgresql",
"dbhost": "localhost",
"dbport": 5432,
"dbname": "poser",
"dbuser": "poser",
"dbpass": "seekrit",
}
with patch.object(handler, "get_dbinfo", return_value=dbinfo):
with patch.object(handler, "test_db_connection", return_value=None):
db_url = handler.get_db_url()
seekrit = "***" if SA2 else "seekrit"
self.assertEqual(
str(db_url),
f"postgresql+psycopg2://poser:{seekrit}@localhost:5432/poser",
)
# now we test the "test db connection" feature
dbinfo = {"db_url": "sqlite:///"}
with patch.object(handler, "get_dbinfo", return_value=dbinfo):
with patch.object(handler, "test_db_connection") as test_db_connection:
with patch.object(handler, "rprint") as rprint:
with patch.object(mod, "sys") as sys:
# pretend user gave bad dbinfo; should exit
test_db_connection.return_value = "bad dbinfo"
sys.exit.side_effect = RuntimeError
self.assertRaises(RuntimeError, handler.get_db_url)
sys.exit.assert_called_once_with(1)
# pretend user gave good dbinfo
sys.exit.reset_mock()
test_db_connection.return_value = None
db_url = handler.get_db_url()
self.assertFalse(sys.exit.called)
rprint.assert_called_with("[bold green]good[/bold green]")
self.assertEqual(str(db_url), "sqlite:///")
def test_get_dbinfo(self):
try:
@ -101,27 +195,19 @@ class TestInstallHandler(ConfigTestCase):
return "seekrit"
return default
with patch.object(mod, "sys") as sys:
with patch.object(handler, "prompt_generic", side_effect=prompt_generic):
with patch.object(handler, "test_db_connection") as test_db_connection:
with patch.object(handler, "rprint") as rprint:
# bad dbinfo
test_db_connection.return_value = "bad dbinfo"
sys.exit.side_effect = RuntimeError
self.assertRaises(RuntimeError, handler.get_dbinfo)
sys.exit.assert_called_once_with(1)
seekrit = "***" if SA2 else "seekrit"
# good dbinfo
sys.exit.reset_mock()
test_db_connection.return_value = None
dbinfo = handler.get_dbinfo()
self.assertFalse(sys.exit.called)
rprint.assert_called_with("[bold green]good[/bold green]")
self.assertEqual(
str(dbinfo["dburl"]),
f"postgresql+psycopg2://poser:{seekrit}@localhost:5432/poser",
dbinfo,
{
"dbtype": "postgresql",
"dbhost": "localhost",
"dbport": "5432",
"dbname": "poser",
"dbuser": "poser",
"dbpass": "seekrit",
},
)
def test_make_db_url(self):
@ -136,14 +222,28 @@ class TestInstallHandler(ConfigTestCase):
seekrit = "***" if SA2 else "seekrit"
url = handler.make_db_url(
"postgresql", "localhost", "5432", "poser", "poser", "seekrit"
dict(
dbtype="postgresql",
dbhost="localhost",
dbport="5432",
dbname="poser",
dbuser="poser",
dbpass="seekrit",
)
)
self.assertEqual(
str(url), f"postgresql+psycopg2://poser:{seekrit}@localhost:5432/poser"
)
url = handler.make_db_url(
"mysql", "localhost", "3306", "poser", "poser", "seekrit"
dict(
dbtype="mysql",
dbhost="localhost",
dbport="3306",
dbname="poser",
dbuser="poser",
dbpass="seekrit",
)
)
self.assertEqual(
str(url), f"mysql+mysqlconnector://poser:{seekrit}@localhost:3306/poser"
@ -172,8 +272,8 @@ class TestInstallHandler(ConfigTestCase):
handler = self.make_handler()
# can handle dburl as string
dbinfo = {"dburl": "sqlite:///poser.sqlite"}
context = handler.make_template_context(dbinfo)
db_url = "sqlite:///poser.sqlite"
context = handler.make_template_context(db_url=db_url)
self.assertEqual(context["envdir"], sys.prefix)
self.assertEqual(context["pkg_name"], "poser")
self.assertEqual(context["app_title"], "poser")
@ -188,8 +288,8 @@ class TestInstallHandler(ConfigTestCase):
pytest.skip("remainder of test is not relevant without sqlalchemy")
# but also can handle dburl as object
dbinfo = {"dburl": sa.create_engine("sqlite:///poser.sqlite").url}
context = handler.make_template_context(dbinfo)
db_url = sa.create_engine("sqlite:///poser.sqlite").url
context = handler.make_template_context(db_url=db_url)
self.assertEqual(context["envdir"], sys.prefix)
self.assertEqual(context["pkg_name"], "poser")
self.assertEqual(context["app_title"], "poser")
@ -205,8 +305,8 @@ class TestInstallHandler(ConfigTestCase):
self.app.resource_path("wuttjamaican:templates/install"),
]
)
dbinfo = {"dburl": "sqlite:///poser.sqlite"}
context = handler.make_template_context(dbinfo)
db_url = "sqlite:///poser.sqlite"
context = handler.make_template_context(db_url=db_url)
handler.make_appdir(context, appdir=self.tempdir)
wutta_conf = os.path.join(self.tempdir, "wutta.conf")
with open(wutta_conf, "rt") as f: