feat: add native API client for common submission use case
this should work for simple API endpoint which requires an auth token. if more is needed, you'll have to override telemetry handler
This commit is contained in:
parent
e0dd704247
commit
0f42626cb1
7 changed files with 437 additions and 5 deletions
203
tests/test_client.py
Normal file
203
tests/test_client.py
Normal file
|
@ -0,0 +1,203 @@
|
|||
# -*- coding: utf-8; -*-
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
|
||||
import requests
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from wuttjamaican.testing import ConfigTestCase
|
||||
from wuttatell import client as mod
|
||||
|
||||
|
||||
class TestSimpleAPIClient(ConfigTestCase):
|
||||
|
||||
def make_client(self, **kw):
|
||||
return mod.SimpleAPIClient(self.config, **kw)
|
||||
|
||||
def test_constructor(self):
|
||||
|
||||
# caller specifies params
|
||||
client = self.make_client(base_url='https://example.com/api/',
|
||||
token='XYZPDQ12345',
|
||||
ssl_verify=False,
|
||||
max_retries=5)
|
||||
self.assertEqual(client.base_url, 'https://example.com/api') # no trailing slash
|
||||
self.assertEqual(client.token, 'XYZPDQ12345')
|
||||
self.assertFalse(client.ssl_verify)
|
||||
self.assertEqual(client.max_retries, 5)
|
||||
self.assertIsNone(client.session)
|
||||
|
||||
# now with some defaults
|
||||
client = self.make_client(base_url='https://example.com/api/',
|
||||
token='XYZPDQ12345')
|
||||
self.assertEqual(client.base_url, 'https://example.com/api') # no trailing slash
|
||||
self.assertEqual(client.token, 'XYZPDQ12345')
|
||||
self.assertTrue(client.ssl_verify)
|
||||
self.assertIsNone(client.max_retries)
|
||||
self.assertIsNone(client.session)
|
||||
|
||||
# now from config
|
||||
self.config.setdefault('wutta.api.base_url', 'https://another.com/api/')
|
||||
self.config.setdefault('wutta.api.token', '9843243q4')
|
||||
self.config.setdefault('wutta.api.ssl_verify', 'false')
|
||||
self.config.setdefault('wutta.api.max_retries', '4')
|
||||
client = self.make_client()
|
||||
self.assertEqual(client.base_url, 'https://another.com/api') # no trailing slash
|
||||
self.assertEqual(client.token, '9843243q4')
|
||||
self.assertFalse(client.ssl_verify)
|
||||
self.assertEqual(client.max_retries, 4)
|
||||
self.assertIsNone(client.session)
|
||||
|
||||
def test_init_session(self):
|
||||
|
||||
# client begins with no session
|
||||
client = self.make_client(base_url='https://example.com/api', token='1234')
|
||||
self.assertIsNone(client.session)
|
||||
|
||||
# session is created here
|
||||
client.init_session()
|
||||
self.assertIsInstance(client.session, requests.Session)
|
||||
self.assertTrue(client.session.verify)
|
||||
self.assertTrue(all([a.max_retries.total == 0 for a in client.session.adapters.values()]))
|
||||
self.assertIn('Authorization', client.session.headers)
|
||||
self.assertEqual(client.session.headers['Authorization'], 'Bearer 1234')
|
||||
|
||||
# session is never re-created
|
||||
orig_session = client.session
|
||||
client.init_session()
|
||||
self.assertIs(client.session, orig_session)
|
||||
|
||||
# new client/session with no ssl_verify
|
||||
client = self.make_client(base_url='https://example.com/api', token='1234', ssl_verify=False)
|
||||
client.init_session()
|
||||
self.assertFalse(client.session.verify)
|
||||
|
||||
# new client/session with max_retries
|
||||
client = self.make_client(base_url='https://example.com/api', token='1234', max_retries=5)
|
||||
client.init_session()
|
||||
self.assertEqual(client.session.adapters['https://example.com/api'].max_retries.total, 5)
|
||||
|
||||
def test_make_request_get(self):
|
||||
|
||||
# start server
|
||||
threading.Thread(target=start_server).start()
|
||||
while not SERVER['running']:
|
||||
time.sleep(0.02)
|
||||
|
||||
# server returns our headers
|
||||
client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False)
|
||||
response = client.make_request('GET', '/telemetry')
|
||||
result = response.json()
|
||||
self.assertIn('headers', result)
|
||||
self.assertIn('Authorization', result['headers'])
|
||||
self.assertEqual(result['headers']['Authorization'], 'Bearer 1234')
|
||||
self.assertNotIn('payload', result)
|
||||
|
||||
def test_make_request_post(self):
|
||||
|
||||
# start server
|
||||
threading.Thread(target=start_server).start()
|
||||
while not SERVER['running']:
|
||||
time.sleep(0.02)
|
||||
|
||||
# server returns our headers + payload
|
||||
client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False)
|
||||
response = client.make_request('POST', '/telemetry', data={'os': {'name': 'debian'}})
|
||||
result = response.json()
|
||||
self.assertIn('headers', result)
|
||||
self.assertIn('Authorization', result['headers'])
|
||||
self.assertEqual(result['headers']['Authorization'], 'Bearer 1234')
|
||||
self.assertIn('payload', result)
|
||||
self.assertEqual(json.loads(result['payload']), {'os': {'name': 'debian'}})
|
||||
|
||||
def test_make_request_unsupported(self):
|
||||
|
||||
# start server
|
||||
threading.Thread(target=start_server).start()
|
||||
while not SERVER['running']:
|
||||
time.sleep(0.02)
|
||||
|
||||
# e.g. DELETE is not implemented
|
||||
client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False)
|
||||
self.assertRaises(NotImplementedError, client.make_request, 'DELETE', '/telemetry')
|
||||
|
||||
# nb. issue valid request to stop the server
|
||||
client.make_request('GET', '/telemetry')
|
||||
|
||||
def test_get(self):
|
||||
|
||||
# start server
|
||||
threading.Thread(target=start_server).start()
|
||||
while not SERVER['running']:
|
||||
time.sleep(0.02)
|
||||
|
||||
# server returns our headers
|
||||
client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False)
|
||||
response = client.get('/telemetry')
|
||||
result = response.json()
|
||||
self.assertIn('headers', result)
|
||||
self.assertIn('Authorization', result['headers'])
|
||||
self.assertEqual(result['headers']['Authorization'], 'Bearer 1234')
|
||||
self.assertNotIn('payload', result)
|
||||
|
||||
def test_post(self):
|
||||
|
||||
# start server
|
||||
threading.Thread(target=start_server).start()
|
||||
while not SERVER['running']:
|
||||
time.sleep(0.02)
|
||||
|
||||
# server returns our headers + payload
|
||||
client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False)
|
||||
response = client.post('/telemetry', data={'os': {'name': 'debian'}})
|
||||
result = response.json()
|
||||
self.assertIn('headers', result)
|
||||
self.assertIn('Authorization', result['headers'])
|
||||
self.assertEqual(result['headers']['Authorization'], 'Bearer 1234')
|
||||
self.assertIn('payload', result)
|
||||
self.assertEqual(json.loads(result['payload']), {'os': {'name': 'debian'}})
|
||||
|
||||
|
||||
class FakeRequestHandler(BaseHTTPRequestHandler):
|
||||
""" """
|
||||
|
||||
def do_GET(self):
|
||||
headers = dict([(k, v) for k, v in self.headers.items()])
|
||||
result = {'headers': headers}
|
||||
result = json.dumps(result).encode('utf_8')
|
||||
|
||||
self.send_response(HTTPStatus.OK)
|
||||
self.send_header("Content-Type", 'text/json')
|
||||
self.send_header("Content-Length", str(len(result)))
|
||||
self.end_headers()
|
||||
self.wfile.write(result)
|
||||
|
||||
def do_POST(self):
|
||||
headers = dict([(k, v) for k, v in self.headers.items()])
|
||||
length = int(self.headers.get('Content-Length'))
|
||||
payload = self.rfile.read(length).decode('utf_8')
|
||||
result = {'headers': headers, 'payload': payload}
|
||||
result = json.dumps(result).encode('utf_8')
|
||||
|
||||
self.send_response(HTTPStatus.OK)
|
||||
self.send_header("Content-Type", 'text/json')
|
||||
self.send_header("Content-Length", str(len(result)))
|
||||
self.end_headers()
|
||||
self.wfile.write(result)
|
||||
|
||||
|
||||
SERVER = {'running': False, 'port': 7314}
|
||||
|
||||
def start_server():
|
||||
if SERVER['running']:
|
||||
raise RuntimeError("http server is already running")
|
||||
|
||||
with HTTPServer(('127.0.0.1', SERVER['port']), FakeRequestHandler) as httpd:
|
||||
SERVER['running'] = True
|
||||
httpd.handle_request()
|
||||
|
||||
SERVER['running'] = False
|
|
@ -1,6 +1,6 @@
|
|||
# -*- coding: utf-8; -*-
|
||||
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from wuttjamaican.testing import ConfigTestCase
|
||||
|
||||
|
@ -160,9 +160,24 @@ class TestTelemetryHandler(ConfigTestCase):
|
|||
self.assertNotIn('errors', data)
|
||||
|
||||
def test_submit_all_data(self):
|
||||
profile = self.handler.get_profile('default')
|
||||
profile.submit_url = '/testing'
|
||||
|
||||
# not (yet?) implemented
|
||||
self.assertRaises(NotImplementedError, self.handler.submit_all_data)
|
||||
with patch.object(mod, 'SimpleAPIClient') as SimpleAPIClient:
|
||||
client = MagicMock()
|
||||
SimpleAPIClient.return_value = client
|
||||
|
||||
# collecting all data
|
||||
with patch.object(self.handler, 'collect_all_data') as collect_all_data:
|
||||
collect_all_data.return_value = []
|
||||
self.handler.submit_all_data(profile)
|
||||
collect_all_data.assert_called_once_with(profile)
|
||||
client.post.assert_called_once_with('/testing', data=[])
|
||||
|
||||
# use data from caller
|
||||
client.post.reset_mock()
|
||||
self.handler.submit_all_data(profile, data=['foo'])
|
||||
client.post.assert_called_once_with('/testing', data=['foo'])
|
||||
|
||||
|
||||
class TestTelemetryProfile(ConfigTestCase):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue