# -*- coding: utf-8; -*-
# Rattail -- Retail Software Framework
# Copyright © 2010-2021 Lance Edgar
# This file is part of Rattail.
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
# details.
# You should have received a copy of the GNU General Public License along with
# Rattail. If not, see <http://www.gnu.org/licenses/>.
"""
Fabric Library for PostgreSQL
"""
import os
import re
from rattail_fabric2 import apt, append, contains, sed, uncomment
def install(c):
    """
    Install the PostgreSQL database service

    Installs the ``postgresql`` and ``libpq-dev`` packages by way of
    ``apt.install()``.

    :param c: Connection for the target system; passed as-is to
       ``apt.install()``.
    """
    apt.install(c, 'postgresql', 'libpq-dev')
def get_version(c):
    """
    Fetch the version of PostgreSQL running on the target system

    Runs ``psql --version`` on the target (with ``warn=True`` so that a
    failing command does not raise) and parses the first two components of
    the reported version.

    :param c: Connection for the target system.

    :returns: The version as a float, e.g. ``12.4`` or ``9.6``; or ``None``
       if the command failed or its output was not recognized.  Note that
       since the value is parsed as a float, a version like "10.10" comes
       back as ``10.1``.
    """
    result = c.run('psql --version', warn=True)
    if result.failed:
        return None

    # any third version component (e.g. the "20" in "9.6.20") is ignored
    match = re.match(r'^psql \(PostgreSQL\) (\d+\.\d+)(?:\.\d+)?', result.stdout.strip())
    if not match:
        return None

    return float(match.group(1))
def set_listen_addresses(c, *addresses):
    """
    Overwrite the `listen_addresses` config setting in `postgresql.conf`.

    The config file is assumed to live at
    ``/etc/postgresql/<version>/main/postgresql.conf`` on the target system.
    Nothing is changed if the file already contains the desired setting.

    :param c: Connection for the target system.

    :param addresses: One or more addresses; these are joined with commas
       to form the new setting value.

    :raises RuntimeError: If the PostgreSQL version cannot be determined,
       since the config file path depends on it.
    """
    version = get_version(c)
    if version is None:
        raise RuntimeError("Could not determine the PostgreSQL version "
                           "on the target system")

    # as of PostgreSQL 10 the "major" version is a single number, and that
    # is what the config folder is named after (e.g. 12.4 => "12"); older
    # releases keep both components (e.g. "9.6")
    if version >= 10:
        version = int(version)

    path = '/etc/postgresql/{}/main/postgresql.conf'.format(version)
    addresses = ','.join(addresses)

    # TODO: this does not seem to work, so cannot do this..?  need to
    # figure it out, since it's the whole reason i added this function
    # addresses = addresses.replace('*', '\\\\*')

    setting = "listen_addresses = '{}'".format(addresses)
    if not contains(c, path, setting):

        # NOTE(review): `use_sudo=True` is assumed here, to match how
        # `add_hba_entry()` edits files in this same folder -- confirm

        # the setting is commented out by default; enable it first...
        uncomment(c, path, r'^#listen_addresses = .*', use_sudo=True)

        # ...then replace whatever value it has with the new one
        sed(c, path, r'listen_addresses = .*', setting, use_sudo=True)
def add_hba_entry(c, entry):
    """
    Add an entry to the `pg_hba.conf` file.

    The file is assumed to live at
    ``/etc/postgresql/<version>/main/pg_hba.conf`` on the target system.
    The entry is appended only if the file does not already contain it.

    :param c: Connection for the target system.

    :param entry: Full text of the line to be added.

    :raises RuntimeError: If the PostgreSQL version cannot be determined,
       since the config file path depends on it.
    """
    version = get_version(c)
    if version is None:
        raise RuntimeError("Could not determine the PostgreSQL version "
                           "on the target system")

    # as of PostgreSQL 10 the "major" version is a single number, and that
    # is what the config folder is named after (e.g. 12.4 => "12"); older
    # releases keep both components (e.g. "9.6")
    if version >= 10:
        version = int(version)

    path = '/etc/postgresql/{}/main/pg_hba.conf'.format(version)
    if not contains(c, path, entry, use_sudo=True):
        append(c, path, entry, use_sudo=True)
def restart(c):
    """
    Restart the PostgreSQL database service

    :param c: Connection for the target system; the ``systemctl`` command
       is run via its ``sudo()`` method.
    """
    c.sudo('systemctl restart postgresql.service')
def reload_(c):
    """
    Reload config for the PostgreSQL database service

    :param c: Connection for the target system; the ``systemctl`` command
       is run via its ``sudo()`` method.
    """
    c.sudo('systemctl reload postgresql.service')


# TODO: deprecate / remove this
# (older alias for `reload_`, kept so existing callers continue to work)
reload = reload_
def sql(c, sql, database='', port=None, **kwargs):
    """
    Execute some SQL as the 'postgres' user.

    The statement is run via ``psql --tuples-only --no-align``, so the
    output contains only the bare result values.

    :param c: Connection for the target system.

    :param sql: SQL statement to execute.  It is placed within double
       quotes on the command line, so any double quotes it contains must
       already be escaped by the caller.

    :param database: Optional name of the database to connect to.

    :param port: Optional port on which the PostgreSQL server listens.

    :param \**kwargs: Any remaining kwargs are passed to ``c.sudo()``.

    :returns: Whatever ``c.sudo()`` returns for the command.
    """
    port_arg = '--port={}'.format(port) if port else ''
    cmd = 'psql {port} --tuples-only --no-align --command="{sql}" {database}'.format(
        port=port_arg,
        sql=sql,
        database=database)
    return c.sudo(cmd, user='postgres', **kwargs)
def script(c, path, database='', port=None, user=None, password=None):
    """
    Execute a SQL script.  By default this will run as 'postgres' user, but can
    use PGPASSWORD authentication if necessary.

    :param c: Connection for the target system.

    :param path: Path to the SQL script file, on the target system.

    :param database: Optional name of the database to connect to.

    :param port: Optional port on which the PostgreSQL server listens.

    :param user: PostgreSQL username.  Only used if ``password`` is also
       given.

    :param password: Password for the PostgreSQL user.  Only used if
       ``user`` is also given.

    :returns: Whatever ``c.sudo()`` returns for the command.
    """
    port = '--port={}'.format(port) if port else ''

    if user and password:
        kw = dict(pw=password, user=user, port=port, path=path, db=database)
        cmd = " PGPASSWORD='{pw}' psql --host=localhost {port} --username='{user}' --file='{path}' {db}".format(**kw)
        # NOTE(review): `echo=False` is assumed here, so that the command
        # (which contains the password) is not echoed -- confirm
        return c.sudo(cmd, echo=False)

    # run as postgres
    kw = dict(port=port, path=path, db=database)
    cmd = "psql {port} --file='{path}' {db}".format(**kw)
    return c.sudo(cmd, user='postgres')
def user_exists(c, name, port=None):
    """
    Determine if a given PostgreSQL user exists.

    :param c: Connection for the target system.

    :param name: Name of the PostgreSQL user (role) to look for.

    :param port: Optional port on which the PostgreSQL server listens.

    :returns: ``True`` if the ``pg_roles`` query produced any output,
       otherwise ``False``.
    """
    query = "SELECT rolname FROM pg_roles WHERE rolname = '{0}'".format(name)
    user = sql(c, query, port=port).stdout.strip()
    return bool(user)
def create_user(c, name, password=None, port=None, checkfirst=True, createdb=False):
    """
    Create a PostgreSQL user account.

    The account is created via the ``createuser`` command, always with the
    ``--no-createrole`` and ``--no-superuser`` flags.

    :param c: Connection for the target system.

    :param name: Name for the new user.

    :param password: Optional password for the user.  This is only applied
       when the user is actually created by this call.

    :param port: Optional port on which the PostgreSQL server listens.

    :param checkfirst: If true (the default), first check whether the user
       exists, and do nothing if so.  If false, the user is created
       unconditionally.

    :param createdb: Whether the new user should be allowed to create
       databases.
    """
    if checkfirst and user_exists(c, name, port=port):
        return

    cmd = 'createuser {port} {createdb} --no-createrole --no-superuser {name}'.format(
        port='--port={}'.format(port) if port else '',
        createdb='--{}createdb'.format('' if createdb else 'no-'),
        name=name)
    c.sudo(cmd, user='postgres')

    if password:
        set_user_password(c, name, password, port=port)
def set_user_password(c, name, password, port=None):
    """
    Set the password for a PostgreSQL user account

    :param c: Connection for the target system.

    :param name: Name of the PostgreSQL user.

    :param password: New password for the user.

    :param port: Optional port on which the PostgreSQL server listens.
    """
    # the user name is wrapped in escaped double quotes, since `sql()` puts
    # the whole statement within double quotes on the command line
    statement = "ALTER USER \\\"{}\\\" PASSWORD '{}';".format(name, password)

    # the statement contains the password, so do not echo or show output
    sql(c, statement, port=port, hide=True, echo=False)
def db_exists(c, name, port=None):
    """
    Determine if a given PostgreSQL database exists.

    :param c: Connection for the target system.

    :param name: Name of the database to look for.

    :param port: Optional port on which the PostgreSQL server listens.

    :returns: ``True`` if the ``pg_database`` query output matches the
       given name exactly, otherwise ``False``.
    """
    query = "SELECT datname FROM pg_database WHERE datname = '{0}'".format(name)
    db = sql(c, query, port=port).stdout.strip()
    return db == name
def create_db(c, name, owner=None, port=None, checkfirst=True):
    """
    Create a PostgreSQL database.

    The database is created via the ``createdb`` command, run as the
    'postgres' user.

    :param c: Connection for the target system.

    :param name: Name for the new database.

    :param owner: Optional name of the user who is to own the database.

    :param port: Optional port on which the PostgreSQL server listens.

    :param checkfirst: If true (the default), first check whether the
       database exists, and do nothing if so.  If false, the database is
       created unconditionally.
    """
    if checkfirst and db_exists(c, name, port=port):
        return

    cmd = 'createdb {port} {owner} {name}'.format(
        port='--port={}'.format(port) if port else '',
        owner='--owner={}'.format(owner) if owner else '',
        name=name)
    c.sudo(cmd, user='postgres')
def create_schema(c, name, dbname, owner='rattail', port=None):
    """
    Create a schema within a PostgreSQL database.

    Uses ``create schema if not exists``, so nothing happens if the schema
    is already present.

    :param c: Connection for the target system.

    :param name: Name for the new schema.

    :param dbname: Name of the database in which to create the schema.

    :param owner: Name of the user who is to own the schema.

    :param port: Optional port on which the PostgreSQL server listens.
    """
    statement = "create schema if not exists {} authorization {}".format(name, owner)
    sql(c, statement, database=dbname, port=port)
def drop_db(c, name, checkfirst=True, port=None):
    """
    Drop a PostgreSQL database.

    The database is dropped via the ``dropdb`` command, run as the
    'postgres' user.

    :param c: Connection for the target system.

    :param name: Name of the database to drop.

    :param checkfirst: If true (the default), first check whether the
       database exists, and do nothing if not.  If false, the drop is
       attempted unconditionally.

    :param port: Optional port on which the PostgreSQL server listens.
    """
    if checkfirst and not db_exists(c, name, port=port):
        return

    if port:
        cmd = 'dropdb --port={} {}'.format(port, name)
    else:
        cmd = 'dropdb {}'.format(name)
    c.sudo(cmd, user='postgres')
def dump_db(c, name, port=None, exclude_tables=None):
    """
    Dump a database to file, on the server represented by ``c`` param.

    This function returns the name of the DB dump file.  The name will not have
    a path component as it's assumed to be in the home folder of the connection
    user.

    :param c: Connection for the server on which the database lives.

    :param name: Name of the database to dump.

    :param port: Optional port on which the PostgreSQL server listens.

    :param exclude_tables: Optional value for the ``--exclude-table-data``
       option of ``pg_dump``.

    :returns: Name of the (gzipped) dump file, i.e. ``<name>.sql.gz``.
    """
    sql_name = '{}.sql'.format(name)

    # create the file up front, writable by anyone, since the dump itself
    # runs as the 'postgres' user
    c.run('touch {}'.format(sql_name))
    c.run('chmod 0666 {}'.format(sql_name))

    cmd = 'pg_dump {port} {exclude_tables} --file={name}.sql {name}'.format(
        port='--port={}'.format(port) if port else '',
        exclude_tables='--exclude-table-data={}'.format(exclude_tables) if exclude_tables else '',
        name=name)
    c.sudo(cmd, user='postgres')

    c.run('gzip --force {}'.format(sql_name))
    return '{}.gz'.format(sql_name)
def download_db(c, name, destination=None, port=None, exclude_tables=None):
    """
    Download a database from the server represented by ``c`` param.

    The database is first dumped on the server via :func:`dump_db()`, then
    the dump file is fetched and finally removed from the server.

    :param c: Connection for the server on which the database lives.

    :param name: Name of the database to download.

    :param destination: Local path to which the dump file should be saved.
       Defaults to ``./<name>.sql.gz``.

    :param port: Optional port on which the PostgreSQL server listens.

    :param exclude_tables: Optional value passed along to
       :func:`dump_db()`.
    """
    if destination is None:
        destination = './{}.sql.gz'.format(name)

    dumpfile = dump_db(c, name, port=port, exclude_tables=exclude_tables)
    c.get(dumpfile, destination)

    # dump file has been fetched, no need to leave it on the server
    c.run('rm {}'.format(dumpfile))
def clone_db(c, name, owner, download, user='rattail', force=False, workdir=None):
    """
    Clone a database from a (presumably live) server

    :param c: Connection for the target server, i.e. where the clone is to
       be created.

    :param name: Name of the database.

    :param owner: Username of the user who is to own the database.

    :param download: Callable which is invoked as ``download(c, filename,
       user=user)`` where ``filename`` is ``<name>.sql.gz``.  It is
       presumably responsible for placing that dump file in the local
       working folder -- confirm against callers.

    :param user: Passed as-is to the ``download`` callable.

    :param force: Whether the target database should be forcibly dropped, if it
       exists already.

    :param workdir: Optional local folder to use as the working directory
       while the dump file is obtained.  The previous working directory is
       restored afterward.

    :raises RuntimeError: If the database already exists and ``force`` is
       not set.
    """
    if db_exists(c, name):
        if not force:
            raise RuntimeError("Database '{}' already exists!".format(name))
        drop_db(c, name, checkfirst=False)

    create_db(c, name, owner=owner, checkfirst=False)

    dumpfile = '{}.sql.gz'.format(name)

    # upload database dump to target server
    curdir = os.getcwd() if workdir else None
    if workdir:
        os.chdir(workdir)
    try:
        download(c, dumpfile, user=user)
        # NOTE(review): the upload itself is assumed to be a plain `put()`
        # of the local dump file -- confirm
        c.put(dumpfile)
        c.local('rm {}'.format(dumpfile))
    finally:
        # always go back to where we started, even if something failed
        if workdir:
            os.chdir(curdir)

    # restore database on target server
    restore_db(c, name, dumpfile)
def restore_db(c, name, path):
    """
    Restore a database from a dump file.

    The dump file is unzipped, fed to ``psql`` and then removed.

    :param c: Connection for the target server.

    :param name: Name of the database to restore.

    :param path: Path to the DB dump file, which should end in ``.sql.gz``

    :raises ValueError: If ``path`` does not end in ``.sql.gz``.
    """
    if not path.endswith('.sql.gz'):
        raise ValueError("Path to dump file must end in `.sql.gz`")

    c.run('gunzip --force {}'.format(path))

    # gunzip leaves behind the same path, minus the ".gz" suffix
    sql_path = path[:-3]

    # NOTE(review): running as the 'postgres' user is assumed here, to
    # match the other `psql` invocations in this module -- confirm
    c.sudo('psql --echo-errors --file={} {}'.format(sql_path, name),
           user='postgres')

    c.run('rm {}'.format(sql_path))