diff --git a/docs/api/wuttatell.client.rst b/docs/api/wuttatell.client.rst new file mode 100644 index 0000000..7a3ef40 --- /dev/null +++ b/docs/api/wuttatell.client.rst @@ -0,0 +1,6 @@ + +``wuttatell.client`` +==================== + +.. automodule:: wuttatell.client + :members: diff --git a/docs/conf.py b/docs/conf.py index 54e91d8..d65a6ba 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,6 +28,7 @@ templates_path = ['_templates'] exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] intersphinx_mapping = { + 'requests': ('https://requests.readthedocs.io/en/latest/', None), 'wuttjamaican': ('https://docs.wuttaproject.org/wuttjamaican/', None), } diff --git a/docs/index.rst b/docs/index.rst index 0fa4d2a..ccecc89 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,4 +25,5 @@ project. api/wuttatell.app api/wuttatell.cli api/wuttatell.cli.tell + api/wuttatell.client api/wuttatell.telemetry diff --git a/src/wuttatell/client.py b/src/wuttatell/client.py new file mode 100644 index 0000000..1a77496 --- /dev/null +++ b/src/wuttatell/client.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8; -*- +################################################################################ +# +# WuttaTell -- Telemetry submission for Wutta Framework +# Copyright © 2025 Lance Edgar +# +# This file is part of Wutta Framework. +# +# Wutta Framework is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the Free +# Software Foundation, either version 3 of the License, or (at your option) any +# later version. +# +# Wutta Framework is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for +# more details. +# +# You should have received a copy of the GNU General Public License along with +# Wutta Framework. If not, see . +# +################################################################################ +""" +Simple API Client +""" + +import json +from urllib.parse import urlparse + +import requests + + +class SimpleAPIClient: + """ + Simple client for "typical" API service. + + This basically assumes telemetry can be submitted to a single API + endpoint, and the request should contain an auth token. + + :param config: App :term:`config object`. + + :param base_url: Base URL of the API. + + :param token: Auth token for the API. + + :param ssl_verify: Whether the SSL cert presented by the server + should be verified. This is effectively true by default, but + may be disabled for testing with self-signed certs etc. + + :param max_retries: Maximum number of retries each connection + should attempt. This value is ultimately given to the + :class:`~requests:requests.adapters.HTTPAdapter` instance. + + Most params may be omitted, if config specifies instead: + + .. code-block:: ini + + [wutta.api] + base_url = https://my.example.com/api + token = XYZPDQ12345 + ssl_verify = false + max_retries = 5 + + Upon instantiation, :attr:`session` will be ``None`` until the + first request is made. (Technically when :meth:`init_session()` + first happens.) + + .. attribute:: session + + :class:`requests:requests.Session` instance being used to make + API requests. + """ + + def __init__(self, config, base_url=None, token=None, ssl_verify=None, max_retries=None): + self.config = config + + self.base_url = base_url or self.config.require(f'{self.config.appname}.api.base_url') + self.base_url = self.base_url.rstrip('/') + self.token = token or self.config.require(f'{self.config.appname}.api.token') + + if max_retries is not None: + self.max_retries = max_retries + else: + self.max_retries = self.config.get_int(f'{self.config.appname}.api.max_retries') + + if ssl_verify is not None: + self.ssl_verify = ssl_verify + else: + self.ssl_verify = self.config.get_bool(f'{self.config.appname}.api.ssl_verify', + default=True) + + self.session = None + + def init_session(self): + """ + Initialize the HTTP session with the API. + + This method is invoked as part of :meth:`make_request()`. + + It first checks :attr:`session` and will skip if already initialized. + + For initialization, it establishes a new + :class:`requests:requests.Session` instance, and modifies it + as needed per config. + """ + if self.session: + return + + self.session = requests.Session() + + # maybe *disable* SSL cert verification + # (should only be used for testing e.g. w/ self-signed certs) + if not self.ssl_verify: + self.session.verify = False + + # maybe set max retries, e.g. for flaky connections + if self.max_retries is not None: + adapter = requests.adapters.HTTPAdapter(max_retries=self.max_retries) + self.session.mount(self.base_url, adapter) + + # TODO: is this a good idea, or hacky security risk..? + # without it, can get error response: + # 400 Client Error: Bad CSRF Origin for url + parts = urlparse(self.base_url) + self.session.headers.update({ + 'Origin': f'{parts.scheme}://{parts.netloc}', + }) + + # authenticate via token only (for now?) + self.session.headers.update({ + 'Authorization': f'Bearer {self.token}', + }) + + def make_request(self, request_method, api_method, params=None, data=None): + """ + Make a request to the API, and return the response. + + This first calls :meth:`init_session()` to establish the + session if needed. + + :param request_method: HTTP request method; for now only + ``'GET'`` and ``'POST'`` are supported. + + :param api_method: API method endpoint to use, + e.g. ``'/my/telemetry'`` + + :param params: Dict of query string params for the request, if + applicable. + + :param data: Payload data for the request, if applicable. + Should be JSON-serializable, e.g. a list or dict. + + :rtype: :class:`requests:requests.Response` instance. + """ + self.init_session() + api_method = api_method.lstrip('/') + url = f'{self.base_url}/{api_method}' + if request_method == 'GET': + response = self.session.get(url, params=params) + elif request_method == 'POST': + response = self.session.post(url, params=params, + data=json.dumps(data)) + else: + raise NotImplementedError(f"unsupported request method: {request_method}") + response.raise_for_status() + return response + + def get(self, api_method, params=None): + """ + Perform a GET request for the given API method, and return the + response. + + This calls :meth:`make_request()` for the heavy lifting. + + :param api_method: API method endpoint to use, + e.g. ``'/my/telemetry'`` + + :param params: Dict of query string params for the request, if + applicable. + + :rtype: :class:`requests:requests.Response` instance. + """ + return self.make_request('GET', api_method, params=params) + + def post(self, api_method, **kwargs): + """ + Perform a POST request for the given API method, and return + the response. + + This calls :meth:`make_request()` for the heavy lifting. + + :param api_method: API method endpoint to use, + e.g. ``'/my/telemetry'`` + + :rtype: :class:`requests:requests.Response` instance. + """ + return self.make_request('POST', api_method, **kwargs) diff --git a/src/wuttatell/telemetry.py b/src/wuttatell/telemetry.py index 2194652..28b6b82 100644 --- a/src/wuttatell/telemetry.py +++ b/src/wuttatell/telemetry.py @@ -31,6 +31,8 @@ import subprocess from wuttjamaican.app import GenericHandler from wuttjamaican.conf import WuttaConfigProfile +from wuttatell.client import SimpleAPIClient + class TelemetryHandler(GenericHandler): """ @@ -220,14 +222,21 @@ class TelemetryHandler(GenericHandler): """ Submit telemetry data to the configured collection service. - Default logic is not implemented; subclass must override. + Default logic will use + :class:`~wuttatell.client.SimpleAPIClient` and submit all + collected data to the configured API endpoint. :param profile: :class:`TelemetryProfile` instance. :param data: Data dict as obtained by :meth:`collect_all_data()`. """ - raise NotImplementedError + profile = self.get_profile(profile) + if data is None: + data = self.collect_all_data(profile) + + client = SimpleAPIClient(self.config) + client.post(profile.submit_url, data=data) class TelemetryProfile(WuttaConfigProfile): diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..95691ee --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8; -*- + +import json +import threading +import time +from http import HTTPStatus +from http.server import HTTPServer, BaseHTTPRequestHandler + +import requests +from urllib3.util.retry import Retry + +from wuttjamaican.testing import ConfigTestCase +from wuttatell import client as mod + + +class TestSimpleAPIClient(ConfigTestCase): + + def make_client(self, **kw): + return mod.SimpleAPIClient(self.config, **kw) + + def test_constructor(self): + + # caller specifies params + client = self.make_client(base_url='https://example.com/api/', + token='XYZPDQ12345', + ssl_verify=False, + max_retries=5) + self.assertEqual(client.base_url, 'https://example.com/api') # no trailing slash + self.assertEqual(client.token, 'XYZPDQ12345') + self.assertFalse(client.ssl_verify) + self.assertEqual(client.max_retries, 5) + self.assertIsNone(client.session) + + # now with some defaults + client = self.make_client(base_url='https://example.com/api/', + token='XYZPDQ12345') + self.assertEqual(client.base_url, 'https://example.com/api') # no trailing slash + self.assertEqual(client.token, 'XYZPDQ12345') + self.assertTrue(client.ssl_verify) + self.assertIsNone(client.max_retries) + self.assertIsNone(client.session) + + # now from config + self.config.setdefault('wutta.api.base_url', 'https://another.com/api/') + self.config.setdefault('wutta.api.token', '9843243q4') + self.config.setdefault('wutta.api.ssl_verify', 'false') + self.config.setdefault('wutta.api.max_retries', '4') + client = self.make_client() + self.assertEqual(client.base_url, 'https://another.com/api') # no trailing slash + self.assertEqual(client.token, '9843243q4') + self.assertFalse(client.ssl_verify) + self.assertEqual(client.max_retries, 4) + self.assertIsNone(client.session) + + def test_init_session(self): + + # client begins with no session + client = self.make_client(base_url='https://example.com/api', token='1234') + self.assertIsNone(client.session) + + # session is created here + client.init_session() + self.assertIsInstance(client.session, requests.Session) + self.assertTrue(client.session.verify) + self.assertTrue(all([a.max_retries.total == 0 for a in client.session.adapters.values()])) + self.assertIn('Authorization', client.session.headers) + self.assertEqual(client.session.headers['Authorization'], 'Bearer 1234') + + # session is never re-created + orig_session = client.session + client.init_session() + self.assertIs(client.session, orig_session) + + # new client/session with no ssl_verify + client = self.make_client(base_url='https://example.com/api', token='1234', ssl_verify=False) + client.init_session() + self.assertFalse(client.session.verify) + + # new client/session with max_retries + client = self.make_client(base_url='https://example.com/api', token='1234', max_retries=5) + client.init_session() + self.assertEqual(client.session.adapters['https://example.com/api'].max_retries.total, 5) + + def test_make_request_get(self): + + # start server + threading.Thread(target=start_server).start() + while not SERVER['running']: + time.sleep(0.02) + + # server returns our headers + client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False) + response = client.make_request('GET', '/telemetry') + result = response.json() + self.assertIn('headers', result) + self.assertIn('Authorization', result['headers']) + self.assertEqual(result['headers']['Authorization'], 'Bearer 1234') + self.assertNotIn('payload', result) + + def test_make_request_post(self): + + # start server + threading.Thread(target=start_server).start() + while not SERVER['running']: + time.sleep(0.02) + + # server returns our headers + payload + client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False) + response = client.make_request('POST', '/telemetry', data={'os': {'name': 'debian'}}) + result = response.json() + self.assertIn('headers', result) + self.assertIn('Authorization', result['headers']) + self.assertEqual(result['headers']['Authorization'], 'Bearer 1234') + self.assertIn('payload', result) + self.assertEqual(json.loads(result['payload']), {'os': {'name': 'debian'}}) + + def test_make_request_unsupported(self): + + # start server + threading.Thread(target=start_server).start() + while not SERVER['running']: + time.sleep(0.02) + + # e.g. DELETE is not implemented + client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False) + self.assertRaises(NotImplementedError, client.make_request, 'DELETE', '/telemetry') + + # nb. issue valid request to stop the server + client.make_request('GET', '/telemetry') + + def test_get(self): + + # start server + threading.Thread(target=start_server).start() + while not SERVER['running']: + time.sleep(0.02) + + # server returns our headers + client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False) + response = client.get('/telemetry') + result = response.json() + self.assertIn('headers', result) + self.assertIn('Authorization', result['headers']) + self.assertEqual(result['headers']['Authorization'], 'Bearer 1234') + self.assertNotIn('payload', result) + + def test_post(self): + + # start server + threading.Thread(target=start_server).start() + while not SERVER['running']: + time.sleep(0.02) + + # server returns our headers + payload + client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False) + response = client.post('/telemetry', data={'os': {'name': 'debian'}}) + result = response.json() + self.assertIn('headers', result) + self.assertIn('Authorization', result['headers']) + self.assertEqual(result['headers']['Authorization'], 'Bearer 1234') + self.assertIn('payload', result) + self.assertEqual(json.loads(result['payload']), {'os': {'name': 'debian'}}) + + +class FakeRequestHandler(BaseHTTPRequestHandler): + """ """ + + def do_GET(self): + headers = dict([(k, v) for k, v in self.headers.items()]) + result = {'headers': headers} + result = json.dumps(result).encode('utf_8') + + self.send_response(HTTPStatus.OK) + self.send_header("Content-Type", 'text/json') + self.send_header("Content-Length", str(len(result))) + self.end_headers() + self.wfile.write(result) + + def do_POST(self): + headers = dict([(k, v) for k, v in self.headers.items()]) + length = int(self.headers.get('Content-Length')) + payload = self.rfile.read(length).decode('utf_8') + result = {'headers': headers, 'payload': payload} + result = json.dumps(result).encode('utf_8') + + self.send_response(HTTPStatus.OK) + self.send_header("Content-Type", 'text/json') + self.send_header("Content-Length", str(len(result))) + self.end_headers() + self.wfile.write(result) + + +SERVER = {'running': False, 'port': 7314} + +def start_server(): + if SERVER['running']: + raise RuntimeError("http server is already running") + + with HTTPServer(('127.0.0.1', SERVER['port']), FakeRequestHandler) as httpd: + SERVER['running'] = True + httpd.handle_request() + + SERVER['running'] = False diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index ae90b97..fb40b23 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -1,6 +1,6 @@ # -*- coding: utf-8; -*- -from unittest.mock import patch +from unittest.mock import patch, MagicMock from wuttjamaican.testing import ConfigTestCase @@ -160,9 +160,24 @@ class TestTelemetryHandler(ConfigTestCase): self.assertNotIn('errors', data) def test_submit_all_data(self): + profile = self.handler.get_profile('default') + profile.submit_url = '/testing' - # not (yet?) implemented - self.assertRaises(NotImplementedError, self.handler.submit_all_data) + with patch.object(mod, 'SimpleAPIClient') as SimpleAPIClient: + client = MagicMock() + SimpleAPIClient.return_value = client + + # collecting all data + with patch.object(self.handler, 'collect_all_data') as collect_all_data: + collect_all_data.return_value = [] + self.handler.submit_all_data(profile) + collect_all_data.assert_called_once_with(profile) + client.post.assert_called_once_with('/testing', data=[]) + + # use data from caller + client.post.reset_mock() + self.handler.submit_all_data(profile, data=['foo']) + client.post.assert_called_once_with('/testing', data=['foo']) class TestTelemetryProfile(ConfigTestCase):