Compare commits

...

3 commits

Author SHA1 Message Date
Lance Edgar 12daf6a1e3 feat: add basic postgres module for db setup 2024-11-20 12:18:58 -06:00
Lance Edgar 3c75194c26 feat: add ssh module with cache_host_key() function 2024-11-20 11:09:28 -06:00
Lance Edgar c41d364e03 feat: add util.mako_renderer() function 2024-11-20 10:29:31 -06:00
13 changed files with 492 additions and 2 deletions

View file

@ -0,0 +1,6 @@
``wuttamess.postgres``
======================
.. automodule:: wuttamess.postgres
:members:

View file

@ -0,0 +1,6 @@
``wuttamess.ssh``
=================
.. automodule:: wuttamess.ssh
:members:

View file

@ -32,5 +32,7 @@ project.
api/wuttamess api/wuttamess
api/wuttamess.apt api/wuttamess.apt
api/wuttamess.postfix api/wuttamess.postfix
api/wuttamess.postgres
api/wuttamess.ssh
api/wuttamess.sync api/wuttamess.sync
api/wuttamess.util api/wuttamess.util

View file

@ -52,12 +52,15 @@ merely a personal convention. You can define tasks however you need::
""" """
from fabric import task from fabric import task
from wuttamess import apt, sync from wuttamess import apt, sync, util
# nb. this is used below, for file sync # nb. this is used below, for file sync
root = sync.make_root('files') root = sync.make_root('files')
# nb. this is for global mako template context etc.
env = {'machine_is_live': False}
@task @task
def bootstrap_all(c): def bootstrap_all(c):
@ -74,11 +77,13 @@ merely a personal convention. You can define tasks however you need::
""" """
Bootstrap the base system Bootstrap the base system
""" """
renderers = {'mako': util.mako_renderer(c, env)}
apt.dist_upgrade(c) apt.dist_upgrade(c)
# postfix # postfix
apt.install(c, 'postfix') apt.install(c, 'postfix')
if sync.check_isync(c, root, 'etc/postfix'): if sync.check_isync(c, root, 'etc/postfix', renderers=renderers):
c.run('systemctl restart postfix') c.run('systemctl restart postfix')

View file

@ -31,6 +31,8 @@ requires-python = ">= 3.8"
dependencies = [ dependencies = [
"fabric", "fabric",
"fabsync", "fabsync",
"mako",
"typing_extensions",
] ]

154
src/wuttamess/postgres.py Normal file
View file

@ -0,0 +1,154 @@
# -*- coding: utf-8; -*-
################################################################################
#
# WuttaMess -- Fabric Automation Helpers
# Copyright © 2024 Lance Edgar
#
# This file is part of Wutta Framework.
#
# Wutta Framework is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# Wutta Framework is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# Wutta Framework. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
PostgreSQL DB utilities
"""
def sql(c, sql, database='', port=None, **kwargs):
"""
Execute some SQL as the ``postgres`` user.
:param c: Fabric connection.
:param sql: SQL string to execute.
:param database: Name of the database on which to execute the SQL.
If not specified, default ``postgres`` is assumed.
:param port: Optional port for PostgreSQL; default is 5432.
"""
port = f' --port={port}' if port else ''
return c.sudo(f'psql{port} --tuples-only --no-align --command="{sql}" {database}',
user='postgres', **kwargs)
def user_exists(c, name, port=None):
"""
Determine if a given PostgreSQL user exists.
:param c: Fabric connection.
:param name: Username to check for.
:param port: Optional port for PostgreSQL; default is 5432.
:returns: ``True`` if user exists, else ``False``.
"""
user = sql(c, f"SELECT rolname FROM pg_roles WHERE rolname = '{name}'", port=port).stdout.strip()
return bool(user)
def create_user(c, name, password=None, port=None, checkfirst=True):
"""
Create a PostgreSQL user account.
:param c: Fabric connection.
:param name: Username to create.
:param password: Optional password for the new user. If set, will
call :func:`set_user_password()`.
:param port: Optional port for PostgreSQL; default is 5432.
:param checkfirst: If true (the default), first check if user
exists and skip creating if already present. If false, then
try to create user with no check.
"""
if not checkfirst or not user_exists(c, name, port=port):
portarg = f' --port={port}' if port else ''
c.sudo(f'createuser{portarg} --no-createrole --no-superuser {name}',
user='postgres')
if password:
set_user_password(c, name, password, port=port)
def set_user_password(c, name, password, port=None):
"""
Set the password for a PostgreSQL user account.
:param c: Fabric connection.
:param name: Username whose password is to be set.
:param password: Password for the new user.
:param port: Optional port for PostgreSQL; default is 5432.
"""
sql(c, f"ALTER USER \\\"{name}\\\" PASSWORD '{password}';", port=port, hide=True, echo=False)
def db_exists(c, name, port=None):
"""
Determine if a given PostgreSQL database exists.
:param c: Fabric connection.
:param name: Name of the database to check for.
:param port: Optional port for PostgreSQL; default is 5432.
:returns: ``True`` if database exists, else ``False``.
"""
db = sql(c, f"SELECT datname FROM pg_database WHERE datname = '{name}'", port=port).stdout.strip()
return db == name
def create_db(c, name, owner=None, port=None, checkfirst=True):
"""
Create a PostgreSQL database.
:param c: Fabric connection.
:param name: Name of the database to create.
:param owner: Optional role name to set as owner for the database.
:param port: Optional port for PostgreSQL; default is 5432.
:param checkfirst: If true (the default), first check if DB exists
and skip creating if already present. If false, then try to
create DB with no check.
"""
if not checkfirst or not db_exists(c, name, port=port):
port = f' --port={port}' if port else ''
owner = f' --owner={owner}' if owner else ''
c.sudo(f'createdb{port}{owner} {name}',
user='postgres')
def drop_db(c, name, checkfirst=True):
"""
Drop a PostgreSQL database.
:param c: Fabric connection.
:param name: Name of the database to drop.
:param checkfirst: If true (the default), first check if DB exists
and skip dropping if not present. If false, then try to drop
DB with no check.
"""
if not checkfirst or db_exists(c, name):
c.sudo(f'dropdb {name}', user='postgres')

75
src/wuttamess/ssh.py Normal file
View file

@ -0,0 +1,75 @@
# -*- coding: utf-8; -*-
################################################################################
#
# WuttaMess -- Fabric Automation Helpers
# Copyright © 2024 Lance Edgar
#
# This file is part of Wutta Framework.
#
# Wutta Framework is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# Wutta Framework is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# Wutta Framework. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
SSH Utilities
"""
def cache_host_key(c, host, port=None, user=None):
"""
Cache the SSH host key for the given host, for the given user.
:param c: Fabric connection.
:param host: Name or IP of the host whose key should be cached.
Note that you can specify a username along with the hostname if
needed, e.g. any of these works:
* ``1.2.3.4``
* ``foo@1.2.3.4``
* ``example.com``
* ``foo@example.com``
:param port: Optional SSH port for the ``host``; default is 22.
:param user: User on the fabric target whose SSH key cache should
be updated to include the given ``host``.
"""
port = f'-p {port} ' if port else ''
# first try to run a basic command over ssh
cmd = f'ssh {port}{host} whoami'
if user and user != 'root':
result = c.sudo(cmd, user=user, warn=True)
else:
result = c.run(cmd, warn=True)
# no need to update cache if command worked okay
if not result.failed:
return
# basic command failed, but in some cases that is simply b/c
# normal commands are not allowed, although the ssh connection
# itself was established okay. so here we check for that.
if "Disallowed command" in result.stderr:
return
# okay then we now think that the ssh connection itself
# was not made, which presumably means we *do* need to
# cache the host key, so try that now
cmd = f'ssh -o StrictHostKeyChecking=no {port}{host} whoami'
if user and user != 'root':
c.sudo(cmd, user=user, warn=True)
else:
c.run(cmd, warn=True)

View file

@ -24,9 +24,46 @@
Misc. Utilities Misc. Utilities
""" """
from pathlib import Path
from typing_extensions import Any, Mapping
from mako.template import Template
def exists(c, path): def exists(c, path):
""" """
Returns ``True`` if given path exists on the host, otherwise ``False``. Returns ``True`` if given path exists on the host, otherwise ``False``.
""" """
return not c.run(f'test -e {path}', warn=True).failed return not c.run(f'test -e {path}', warn=True).failed
def mako_renderer(c, env={}):
"""
This returns a *function* suitable for use as a ``fabsync`` file
renderer. The function assumes the file is a Mako template.
:param c: Fabric connection.
:param env: Environment dictionary to be used as Mako template
context.
Typical usage is something like::
from fabric import task
from wuttamess import sync, util
root = sync.make_root('files')
env = {}
@task
def foo(c):
# define possible renderers for fabsync
renderers = {'mako': util.mako_renderer(c, env)}
sync.check_isync(c, root, 'etc/postfix', renderers=renderers)
"""
def render(path: Path, vars: Mapping[str, Any], **kwargs) -> bytes:
return Template(filename=str(path)).render(**env)
return render

View file

@ -0,0 +1,4 @@
[files."baz"]
renderer = 'mako'
tags = ['baz']

1
tests/files/bar/baz Normal file
View file

@ -0,0 +1 @@
machine_is_live = ${machine_is_live}

124
tests/test_postgres.py Normal file
View file

@ -0,0 +1,124 @@
# -*- coding: utf-8; -*-
from unittest import TestCase
from unittest.mock import MagicMock, patch
from wuttamess import postgres as mod
class TestSql(TestCase):
def test_basic(self):
c = MagicMock()
mod.sql(c, "select @@version")
c.sudo.assert_called_once_with('psql --tuples-only --no-align --command="select @@version" ',
user='postgres')
class TestUserExists(TestCase):
def test_user_exists(self):
c = MagicMock()
with patch.object(mod, 'sql') as sql:
sql.return_value.stdout = 'foo'
self.assertTrue(mod.user_exists(c, 'foo'))
sql.assert_called_once_with(c, "SELECT rolname FROM pg_roles WHERE rolname = 'foo'", port=None)
def test_user_does_not_exist(self):
c = MagicMock()
with patch.object(mod, 'sql') as sql:
sql.return_value.stdout = ''
self.assertFalse(mod.user_exists(c, 'foo'))
sql.assert_called_once_with(c, "SELECT rolname FROM pg_roles WHERE rolname = 'foo'", port=None)
class TestCreateUser(TestCase):
def test_basic(self):
c = MagicMock()
with patch.object(mod, 'set_user_password') as set_user_password:
mod.create_user(c, 'foo', checkfirst=False)
c.sudo.assert_called_once_with('createuser --no-createrole --no-superuser foo',
user='postgres')
set_user_password.assert_not_called()
def test_user_exists(self):
c = MagicMock()
with patch.object(mod, 'user_exists') as user_exists:
user_exists.return_value = True
mod.create_user(c, 'foo')
user_exists.assert_called_once_with(c, 'foo', port=None)
c.sudo.assert_not_called()
def test_with_password(self):
c = MagicMock()
with patch.object(mod, 'set_user_password') as set_user_password:
mod.create_user(c, 'foo', 'foopass', checkfirst=False)
c.sudo.assert_called_once_with('createuser --no-createrole --no-superuser foo',
user='postgres')
set_user_password.assert_called_once_with(c, 'foo', 'foopass', port=None)
class TestSetUserPassword(TestCase):
def test_basic(self):
c = MagicMock()
with patch.object(mod, 'sql') as sql:
mod.set_user_password(c, 'foo', 'foopass')
sql.assert_called_once_with(c, "ALTER USER \\\"foo\\\" PASSWORD 'foopass';",
port=None, hide=True, echo=False)
class TestDbExists(TestCase):
def test_db_exists(self):
c = MagicMock()
with patch.object(mod, 'sql') as sql:
sql.return_value.stdout = 'foo'
self.assertTrue(mod.db_exists(c, 'foo'))
sql.assert_called_once_with(c, "SELECT datname FROM pg_database WHERE datname = 'foo'", port=None)
def test_db_does_not_exist(self):
c = MagicMock()
with patch.object(mod, 'sql') as sql:
sql.return_value.stdout = ''
self.assertFalse(mod.db_exists(c, 'foo'))
sql.assert_called_once_with(c, "SELECT datname FROM pg_database WHERE datname = 'foo'", port=None)
class TestCreateDb(TestCase):
def test_basic(self):
c = MagicMock()
mod.create_db(c, 'foo', checkfirst=False)
c.sudo.assert_called_once_with('createdb foo', user='postgres')
def test_db_exists(self):
c = MagicMock()
with patch.object(mod, 'db_exists') as db_exists:
db_exists.return_value = True
mod.create_db(c, 'foo')
db_exists.assert_called_once_with(c, 'foo', port=None)
c.sudo.assert_not_called()
class TestDropDb(TestCase):
def test_basic(self):
c = MagicMock()
mod.drop_db(c, 'foo', checkfirst=False)
c.sudo.assert_called_once_with('dropdb foo', user='postgres')
def test_db_does_not_exist(self):
c = MagicMock()
with patch.object(mod, 'db_exists') as db_exists:
db_exists.return_value = False
mod.drop_db(c, 'foo')
db_exists.assert_called_once_with(c, 'foo')
c.sudo.assert_not_called()

62
tests/test_ssh.py Normal file
View file

@ -0,0 +1,62 @@
# -*- coding: utf-8; -*-
from unittest import TestCase
from unittest.mock import MagicMock, call
from wuttamess import ssh as mod
class TestCacheHostKey(TestCase):
def test_root_already_cached(self):
c = MagicMock()
# assume the first command runs okay
c.run.return_value.failed = False
mod.cache_host_key(c, 'example.com')
c.run.assert_called_once_with('ssh example.com whoami', warn=True)
def test_root_commands_not_allowed(self):
c = MagicMock()
# assume the first command fails b/c "disallowed"
c.run.return_value.failed = True
c.run.return_value.stderr = "Disallowed command"
mod.cache_host_key(c, 'example.com')
c.run.assert_called_once_with('ssh example.com whoami', warn=True)
def test_root_cache_key(self):
c = MagicMock()
# first command fails; second command caches host key
c.run.return_value.failed = True
mod.cache_host_key(c, 'example.com')
c.run.assert_has_calls([call('ssh example.com whoami', warn=True)])
c.run.assert_called_with('ssh -o StrictHostKeyChecking=no example.com whoami', warn=True)
def test_user_already_cached(self):
c = MagicMock()
# assume the first command runs okay
c.sudo.return_value.failed = False
mod.cache_host_key(c, 'example.com', user='foo')
c.sudo.assert_called_once_with('ssh example.com whoami', user='foo', warn=True)
def test_user_commands_not_allowed(self):
c = MagicMock()
# assume the first command fails b/c "disallowed"
c.sudo.return_value.failed = True
c.sudo.return_value.stderr = "Disallowed command"
mod.cache_host_key(c, 'example.com', user='foo')
c.sudo.assert_called_once_with('ssh example.com whoami', user='foo', warn=True)
def test_user_cache_key(self):
c = MagicMock()
# first command fails; second command caches host key
c.sudo.return_value.failed = True
mod.cache_host_key(c, 'example.com', user='foo')
c.sudo.assert_has_calls([call('ssh example.com whoami', user='foo', warn=True)])
c.sudo.assert_called_with('ssh -o StrictHostKeyChecking=no example.com whoami',
user='foo', warn=True)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8; -*- # -*- coding: utf-8; -*-
import os
from unittest import TestCase from unittest import TestCase
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -12,3 +13,14 @@ class TestExists(TestCase):
c = MagicMock() c = MagicMock()
mod.exists(c, '/foo') mod.exists(c, '/foo')
c.run.assert_called_once_with('test -e /foo', warn=True) c.run.assert_called_once_with('test -e /foo', warn=True)
class TestMakoRenderer(TestCase):
def test_basic(self):
c = MagicMock()
renderer = mod.mako_renderer(c, env={'machine_is_live': True})
here = os.path.dirname(__file__)
path = os.path.join(here, 'files', 'bar', 'baz')
rendered = renderer(path, vars={})
self.assertEqual(rendered, 'machine_is_live = True')