wuttatell/tests/test_client.py
Lance Edgar 0f42626cb1 feat: add native API client for common submission use case
This should work for a simple API endpoint which requires an auth token.
If more is needed, you'll have to override the telemetry handler.
2025-08-10 15:13:20 -05:00

203 lines
7.8 KiB
Python

# -*- coding: utf-8; -*-
import json
import threading
import time
from http import HTTPStatus
from http.server import HTTPServer, BaseHTTPRequestHandler
import requests
from urllib3.util.retry import Retry
from wuttjamaican.testing import ConfigTestCase
from wuttatell import client as mod
class TestSimpleAPIClient(ConfigTestCase):
    """Tests for ``SimpleAPIClient``, using a local one-shot fake HTTP server."""

    def make_client(self, **kw):
        """Return a new client for the test config; ``kw`` is passed through."""
        return mod.SimpleAPIClient(self.config, **kw)

    def make_local_client(self):
        """Return a client pointed at the fake server, using token ``1234``."""
        return self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}',
                                token='1234', ssl_verify=False)

    def start_fake_server(self, timeout=5):
        """
        Start the one-shot fake HTTP server in a background thread.

        Blocks until ``SERVER['running']`` is set, failing the test instead
        of spinning forever if the thread dies or ``timeout`` seconds pass.
        A cleanup is registered which waits for the server thread to exit,
        so the next test never observes this server's stale ``running``
        flag and races its shutdown.
        """
        # nb. daemon thread, so a server stuck in handle_request() can never
        # keep the test process alive
        thread = threading.Thread(target=start_server, daemon=True)
        thread.start()
        self.addCleanup(self._stop_fake_server, thread)

        deadline = time.monotonic() + timeout
        while not SERVER['running']:
            if not thread.is_alive():
                # e.g. the port is already in use
                self.fail("fake http server thread exited before it was running")
            if time.monotonic() > deadline:
                self.fail("timed out waiting for fake http server to start")
            time.sleep(0.02)

    def _stop_fake_server(self, thread):
        """
        Cleanup: wait for the fake server thread to finish.

        The server handles exactly one request.  If the test never sent one
        (e.g. it failed early), send a throwaway request so the thread can
        return from ``handle_request()`` instead of blocking forever.
        """
        thread.join(timeout=1)
        if thread.is_alive():
            try:
                requests.get(f'http://127.0.0.1:{SERVER["port"]}/telemetry', timeout=5)
            except requests.RequestException:
                pass  # best-effort only; the thread is a daemon regardless
            thread.join(timeout=5)

    def assert_echoed_token(self, result, token='1234'):
        """Assert the fake server received ``Authorization: Bearer <token>``."""
        self.assertIn('headers', result)
        self.assertIn('Authorization', result['headers'])
        self.assertEqual(result['headers']['Authorization'], f'Bearer {token}')

    def test_constructor(self):

        # caller specifies params
        client = self.make_client(base_url='https://example.com/api/',
                                  token='XYZPDQ12345',
                                  ssl_verify=False,
                                  max_retries=5)
        self.assertEqual(client.base_url, 'https://example.com/api') # no trailing slash
        self.assertEqual(client.token, 'XYZPDQ12345')
        self.assertFalse(client.ssl_verify)
        self.assertEqual(client.max_retries, 5)
        self.assertIsNone(client.session)

        # now with some defaults
        client = self.make_client(base_url='https://example.com/api/',
                                  token='XYZPDQ12345')
        self.assertEqual(client.base_url, 'https://example.com/api') # no trailing slash
        self.assertEqual(client.token, 'XYZPDQ12345')
        self.assertTrue(client.ssl_verify)
        self.assertIsNone(client.max_retries)
        self.assertIsNone(client.session)

        # now from config
        self.config.setdefault('wutta.api.base_url', 'https://another.com/api/')
        self.config.setdefault('wutta.api.token', '9843243q4')
        self.config.setdefault('wutta.api.ssl_verify', 'false')
        self.config.setdefault('wutta.api.max_retries', '4')
        client = self.make_client()
        self.assertEqual(client.base_url, 'https://another.com/api') # no trailing slash
        self.assertEqual(client.token, '9843243q4')
        self.assertFalse(client.ssl_verify)
        self.assertEqual(client.max_retries, 4)
        self.assertIsNone(client.session)

    def test_init_session(self):

        # client begins with no session
        client = self.make_client(base_url='https://example.com/api', token='1234')
        self.assertIsNone(client.session)

        # session is created here
        client.init_session()
        self.assertIsInstance(client.session, requests.Session)
        self.assertTrue(client.session.verify)
        # nb. one assertion per adapter gives a useful failure message
        for adapter in client.session.adapters.values():
            self.assertEqual(adapter.max_retries.total, 0)
        self.assertIn('Authorization', client.session.headers)
        self.assertEqual(client.session.headers['Authorization'], 'Bearer 1234')

        # session is never re-created
        orig_session = client.session
        client.init_session()
        self.assertIs(client.session, orig_session)

        # new client/session with no ssl_verify
        client = self.make_client(base_url='https://example.com/api', token='1234', ssl_verify=False)
        client.init_session()
        self.assertFalse(client.session.verify)

        # new client/session with max_retries
        client = self.make_client(base_url='https://example.com/api', token='1234', max_retries=5)
        client.init_session()
        self.assertEqual(client.session.adapters['https://example.com/api'].max_retries.total, 5)

    def test_make_request_get(self):
        self.start_fake_server()

        # server returns our headers
        client = self.make_local_client()
        result = client.make_request('GET', '/telemetry').json()
        self.assert_echoed_token(result)
        self.assertNotIn('payload', result)

    def test_make_request_post(self):
        self.start_fake_server()

        # server returns our headers + payload
        client = self.make_local_client()
        response = client.make_request('POST', '/telemetry', data={'os': {'name': 'debian'}})
        result = response.json()
        self.assert_echoed_token(result)
        self.assertIn('payload', result)
        self.assertEqual(json.loads(result['payload']), {'os': {'name': 'debian'}})

    def test_make_request_unsupported(self):
        self.start_fake_server()

        # e.g. DELETE is not implemented
        client = self.make_local_client()
        self.assertRaises(NotImplementedError, client.make_request, 'DELETE', '/telemetry')

        # nb. issue valid request to stop the server promptly (otherwise
        # the cleanup does it, but only after a short wait)
        client.make_request('GET', '/telemetry')

    def test_get(self):
        self.start_fake_server()

        # server returns our headers
        client = self.make_local_client()
        result = client.get('/telemetry').json()
        self.assert_echoed_token(result)
        self.assertNotIn('payload', result)

    def test_post(self):
        self.start_fake_server()

        # server returns our headers + payload
        client = self.make_local_client()
        result = client.post('/telemetry', data={'os': {'name': 'debian'}}).json()
        self.assert_echoed_token(result)
        self.assertIn('payload', result)
        self.assertEqual(json.loads(result['payload']), {'os': {'name': 'debian'}})
class FakeRequestHandler(BaseHTTPRequestHandler):
    """
    Minimal HTTP handler which echoes each request back as JSON.

    The response body is a JSON object with a ``headers`` key (the request
    headers as a dict) and, for POST only, a ``payload`` key holding the
    raw request body decoded as UTF-8.
    """

    def do_GET(self):
        """Respond with the request headers."""
        self._send_json({'headers': dict(self.headers.items())})

    def do_POST(self):
        """Respond with the request headers and the raw request body."""
        headers = dict(self.headers.items())
        # nb. treat a missing Content-Length as an empty body, rather than
        # crashing on int(None)
        length = int(self.headers.get('Content-Length', 0))
        payload = self.rfile.read(length).decode('utf_8')
        self._send_json({'headers': headers, 'payload': payload})

    def _send_json(self, result):
        """Serialize ``result`` as UTF-8 JSON and send it with a 200 status."""
        body = json.dumps(result).encode('utf_8')
        self.send_response(HTTPStatus.OK)
        self.send_header("Content-Type", 'application/json')
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_request(self, code='-', size='-'):
        """Skip the per-request access log line (test noise); errors still log."""
# Shared state for the one-shot fake server.  Tests poll ``running`` to know
# when the server socket is bound and listening on ``port``.
SERVER = {'running': False, 'port': 7314}


def start_server():
    """
    Serve exactly one request on ``127.0.0.1:SERVER['port']``, then return.

    ``SERVER['running']`` is set once the server socket is bound and
    listening.  It is always cleared on the way out, even if serving raises.
    Otherwise a stale flag would make every later call refuse to start.

    :raises RuntimeError: if ``SERVER['running']`` is already set.
    """
    if SERVER['running']:
        raise RuntimeError("http server is already running")
    with HTTPServer(('127.0.0.1', SERVER['port']), FakeRequestHandler) as httpd:
        SERVER['running'] = True
        try:
            httpd.handle_request()
        finally:
            SERVER['running'] = False