diff --git a/docs/api/wuttamess.ssh.rst b/docs/api/wuttamess.ssh.rst new file mode 100644 index 0000000..1810230 --- /dev/null +++ b/docs/api/wuttamess.ssh.rst @@ -0,0 +1,6 @@ + +``wuttamess.ssh`` +================= + +.. automodule:: wuttamess.ssh + :members: diff --git a/docs/index.rst b/docs/index.rst index 78fc2d0..a719218 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,5 +32,6 @@ project. api/wuttamess api/wuttamess.apt api/wuttamess.postfix + api/wuttamess.ssh api/wuttamess.sync api/wuttamess.util diff --git a/src/wuttamess/ssh.py b/src/wuttamess/ssh.py new file mode 100644 index 0000000..87a7540 --- /dev/null +++ b/src/wuttamess/ssh.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8; -*- +################################################################################ +# +# WuttaMess -- Fabric Automation Helpers +# Copyright © 2024 Lance Edgar +# +# This file is part of Wutta Framework. +# +# Wutta Framework is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the Free +# Software Foundation, either version 3 of the License, or (at your option) any +# later version. +# +# Wutta Framework is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for +# more details. +# +# You should have received a copy of the GNU General Public License along with +# Wutta Framework. If not, see . +# +################################################################################ +""" +SSH Utilities +""" + + +def cache_host_key(c, host, port=None, user=None): + """ + Cache the SSH host key for the given host, for the given user. + + :param c: Fabric connection. + + :param host: Name or IP of the host whose key should be cached. + + Note that you can specify a username along with the hostname if + needed, e.g. any of these works: + + * ``1.2.3.4`` + * ``foo@1.2.3.4`` + * ``example.com`` + * ``foo@example.com`` + + :param port: Optional SSH port for the ``host``; default is 22. + + :param user: User on the fabric target whose SSH key cache should + be updated to include the given ``host``. + """ + port = f'-p {port} ' if port else '' + + # first try to run a basic command over ssh + cmd = f'ssh {port}{host} whoami' + if user and user != 'root': + result = c.sudo(cmd, user=user, warn=True) + else: + result = c.run(cmd, warn=True) + + # no need to update cache if command worked okay + if not result.failed: + return + + # basic command failed, but in some cases that is simply b/c + # normal commands are not allowed, although the ssh connection + # itself was established okay. so here we check for that. + if "Disallowed command" in result.stderr: + return + + # okay then we now think that the ssh connection itself + # was not made, which presumably means we *do* need to + # cache the host key, so try that now + cmd = f'ssh -o StrictHostKeyChecking=no {port}{host} whoami' + if user and user != 'root': + c.sudo(cmd, user=user, warn=True) + else: + c.run(cmd, warn=True) diff --git a/tests/test_ssh.py b/tests/test_ssh.py new file mode 100644 index 0000000..93b1cdc --- /dev/null +++ b/tests/test_ssh.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8; -*- + +from unittest import TestCase +from unittest.mock import MagicMock, call + +from wuttamess import ssh as mod + + +class TestCacheHostKey(TestCase): + + def test_root_already_cached(self): + c = MagicMock() + + # assume the first command runs okay + c.run.return_value.failed = False + mod.cache_host_key(c, 'example.com') + c.run.assert_called_once_with('ssh example.com whoami', warn=True) + + def test_root_commands_not_allowed(self): + c = MagicMock() + + # assume the first command fails b/c "disallowed" + c.run.return_value.failed = True + c.run.return_value.stderr = "Disallowed command" + mod.cache_host_key(c, 'example.com') + c.run.assert_called_once_with('ssh example.com whoami', warn=True) + + def test_root_cache_key(self): + c = MagicMock() + + # first command fails; second command caches host key + c.run.return_value.failed = True + mod.cache_host_key(c, 'example.com') + c.run.assert_has_calls([call('ssh example.com whoami', warn=True)]) + c.run.assert_called_with('ssh -o StrictHostKeyChecking=no example.com whoami', warn=True) + + def test_user_already_cached(self): + c = MagicMock() + + # assume the first command runs okay + c.sudo.return_value.failed = False + mod.cache_host_key(c, 'example.com', user='foo') + c.sudo.assert_called_once_with('ssh example.com whoami', user='foo', warn=True) + + def test_user_commands_not_allowed(self): + c = MagicMock() + + # assume the first command fails b/c "disallowed" + c.sudo.return_value.failed = True + c.sudo.return_value.stderr = "Disallowed command" + mod.cache_host_key(c, 'example.com', user='foo') + c.sudo.assert_called_once_with('ssh example.com whoami', user='foo', warn=True) + + def test_user_cache_key(self): + c = MagicMock() + + # first command fails; second command caches host key + c.sudo.return_value.failed = True + mod.cache_host_key(c, 'example.com', user='foo') + c.sudo.assert_has_calls([call('ssh example.com whoami', user='foo', warn=True)]) + c.sudo.assert_called_with('ssh -o StrictHostKeyChecking=no example.com whoami', + user='foo', warn=True)