fix: format all code with black

and from now on should not deviate from that...
This commit is contained in:
Lance Edgar 2025-08-31 12:47:46 -05:00
parent 18f7fa6c51
commit 8be1b66c9e
10 changed files with 334 additions and 258 deletions

View file

@ -8,33 +8,33 @@
from importlib.metadata import version as get_version from importlib.metadata import version as get_version
project = 'WuttaTell' project = "WuttaTell"
copyright = '2025, Lance Edgar' copyright = "2025, Lance Edgar"
author = 'Lance Edgar' author = "Lance Edgar"
release = get_version('WuttaTell') release = get_version("WuttaTell")
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [ extensions = [
'sphinx.ext.autodoc', "sphinx.ext.autodoc",
'sphinx.ext.intersphinx', "sphinx.ext.intersphinx",
'sphinx.ext.viewcode', "sphinx.ext.viewcode",
'sphinx.ext.todo', "sphinx.ext.todo",
'sphinxcontrib.programoutput', "sphinxcontrib.programoutput",
] ]
templates_path = ['_templates'] templates_path = ["_templates"]
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
intersphinx_mapping = { intersphinx_mapping = {
'requests': ('https://requests.readthedocs.io/en/latest/', None), "requests": ("https://requests.readthedocs.io/en/latest/", None),
'wuttjamaican': ('https://docs.wuttaproject.org/wuttjamaican/', None), "wuttjamaican": ("https://docs.wuttaproject.org/wuttjamaican/", None),
} }
# -- Options for HTML output ------------------------------------------------- # -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'furo' html_theme = "furo"
html_static_path = ['_static'] html_static_path = ["_static"]

View file

@ -3,4 +3,4 @@
from importlib.metadata import version from importlib.metadata import version
__version__ = version('WuttaTell') __version__ = version("WuttaTell")

View file

@ -41,9 +41,11 @@ class WuttaTellAppProvider(AppProvider):
:rtype: :class:`~wuttatell.telemetry.TelemetryHandler` :rtype: :class:`~wuttatell.telemetry.TelemetryHandler`
""" """
if not hasattr(self, 'telemetry_handler'): if not hasattr(self, "telemetry_handler"):
spec = self.config.get(f'{self.appname}.telemetry.handler', spec = self.config.get(
default='wuttatell.telemetry:TelemetryHandler') f"{self.appname}.telemetry.handler",
default="wuttatell.telemetry:TelemetryHandler",
)
factory = self.app.load_object(spec) factory = self.app.load_object(spec)
self.telemetry_handler = factory(self.config, **kwargs) self.telemetry_handler = factory(self.config, **kwargs)
return self.telemetry_handler return self.telemetry_handler

View file

@ -37,18 +37,24 @@ log = logging.getLogger(__name__)
@wutta_typer.command() @wutta_typer.command()
def tell( def tell(
ctx: typer.Context, ctx: typer.Context,
profile: Annotated[ profile: Annotated[
str, str,
typer.Option('--profile', '-p', typer.Option(
help="Profile (type) of telemetry data to collect. " "--profile",
"This also determines where/how data is submitted. " "-p",
"If not specified, default profile is assumed.")] = None, help="Profile (type) of telemetry data to collect. "
dry_run: Annotated[ "This also determines where/how data is submitted. "
bool, "If not specified, default profile is assumed.",
typer.Option('--dry-run', ),
help="Go through all the motions but do not submit " ] = None,
"the data to server.")] = False, dry_run: Annotated[
bool,
typer.Option(
"--dry-run",
help="Go through all the motions but do not submit " "the data to server.",
),
] = False,
): ):
""" """
Collect and submit telemetry data Collect and submit telemetry data
@ -58,7 +64,7 @@ def tell(
telemetry = app.get_telemetry_handler() telemetry = app.get_telemetry_handler()
data = telemetry.collect_all_data(profile=profile) data = telemetry.collect_all_data(profile=profile)
log.info("data collected for: %s", ', '.join(sorted(data))) log.info("data collected for: %s", ", ".join(sorted(data)))
log.debug("%s", data) log.debug("%s", data)
if dry_run: if dry_run:

View file

@ -71,23 +71,30 @@ class SimpleAPIClient:
API requests. API requests.
""" """
def __init__(self, config, base_url=None, token=None, ssl_verify=None, max_retries=None): def __init__(
self, config, base_url=None, token=None, ssl_verify=None, max_retries=None
):
self.config = config self.config = config
self.base_url = base_url or self.config.require(f'{self.config.appname}.api.base_url') self.base_url = base_url or self.config.require(
self.base_url = self.base_url.rstrip('/') f"{self.config.appname}.api.base_url"
self.token = token or self.config.require(f'{self.config.appname}.api.token') )
self.base_url = self.base_url.rstrip("/")
self.token = token or self.config.require(f"{self.config.appname}.api.token")
if max_retries is not None: if max_retries is not None:
self.max_retries = max_retries self.max_retries = max_retries
else: else:
self.max_retries = self.config.get_int(f'{self.config.appname}.api.max_retries') self.max_retries = self.config.get_int(
f"{self.config.appname}.api.max_retries"
)
if ssl_verify is not None: if ssl_verify is not None:
self.ssl_verify = ssl_verify self.ssl_verify = ssl_verify
else: else:
self.ssl_verify = self.config.get_bool(f'{self.config.appname}.api.ssl_verify', self.ssl_verify = self.config.get_bool(
default=True) f"{self.config.appname}.api.ssl_verify", default=True
)
self.session = None self.session = None
@ -122,14 +129,18 @@ class SimpleAPIClient:
# without it, can get error response: # without it, can get error response:
# 400 Client Error: Bad CSRF Origin for url # 400 Client Error: Bad CSRF Origin for url
parts = urlparse(self.base_url) parts = urlparse(self.base_url)
self.session.headers.update({ self.session.headers.update(
'Origin': f'{parts.scheme}://{parts.netloc}', {
}) "Origin": f"{parts.scheme}://{parts.netloc}",
}
)
# authenticate via token only (for now?) # authenticate via token only (for now?)
self.session.headers.update({ self.session.headers.update(
'Authorization': f'Bearer {self.token}', {
}) "Authorization": f"Bearer {self.token}",
}
)
def make_request(self, request_method, api_method, params=None, data=None): def make_request(self, request_method, api_method, params=None, data=None):
""" """
@ -153,13 +164,12 @@ class SimpleAPIClient:
:rtype: :class:`requests:requests.Response` instance. :rtype: :class:`requests:requests.Response` instance.
""" """
self.init_session() self.init_session()
api_method = api_method.lstrip('/') api_method = api_method.lstrip("/")
url = f'{self.base_url}/{api_method}' url = f"{self.base_url}/{api_method}"
if request_method == 'GET': if request_method == "GET":
response = self.session.get(url, params=params) response = self.session.get(url, params=params)
elif request_method == 'POST': elif request_method == "POST":
response = self.session.post(url, params=params, response = self.session.post(url, params=params, data=json.dumps(data))
data=json.dumps(data))
else: else:
raise NotImplementedError(f"unsupported request method: {request_method}") raise NotImplementedError(f"unsupported request method: {request_method}")
response.raise_for_status() response.raise_for_status()
@ -180,7 +190,7 @@ class SimpleAPIClient:
:rtype: :class:`requests:requests.Response` instance. :rtype: :class:`requests:requests.Response` instance.
""" """
return self.make_request('GET', api_method, params=params) return self.make_request("GET", api_method, params=params)
def post(self, api_method, **kwargs): def post(self, api_method, **kwargs):
""" """
@ -194,4 +204,4 @@ class SimpleAPIClient:
:rtype: :class:`requests:requests.Response` instance. :rtype: :class:`requests:requests.Response` instance.
""" """
return self.make_request('POST', api_method, **kwargs) return self.make_request("POST", api_method, **kwargs)

View file

@ -49,7 +49,7 @@ class TelemetryHandler(GenericHandler):
if isinstance(profile, TelemetryProfile): if isinstance(profile, TelemetryProfile):
return profile return profile
return TelemetryProfile(self.config, profile or 'default') return TelemetryProfile(self.config, profile or "default")
def collect_all_data(self, profile=None): def collect_all_data(self, profile=None):
""" """
@ -76,7 +76,7 @@ class TelemetryHandler(GenericHandler):
profile = self.get_profile(profile) profile = self.get_profile(profile)
for key in profile.collect_keys: for key in profile.collect_keys:
collector = getattr(self, f'collect_data_{key}') collector = getattr(self, f"collect_data_{key}")
data[key] = collector(profile=profile) data[key] = collector(profile=profile)
self.normalize_errors(data) self.normalize_errors(data)
@ -87,11 +87,11 @@ class TelemetryHandler(GenericHandler):
all_errors = [] all_errors = []
for key, value in data.items(): for key, value in data.items():
if value: if value:
errors = value.pop('errors', None) errors = value.pop("errors", None)
if errors: if errors:
all_errors.extend(errors) all_errors.extend(errors)
if all_errors: if all_errors:
data['errors'] = all_errors data["errors"] = all_errors
def collect_data_os(self, profile, **kwargs): def collect_data_os(self, profile, **kwargs):
""" """
@ -119,40 +119,40 @@ class TelemetryHandler(GenericHandler):
errors = [] errors = []
# release # release
release_path = kwargs.get('release_path', '/etc/os-release') release_path = kwargs.get("release_path", "/etc/os-release")
try: try:
with open(release_path, 'rt') as f: with open(release_path, "rt") as f:
output = f.read() output = f.read()
except: except:
errors.append(f"Failed to read {release_path}") errors.append(f"Failed to read {release_path}")
else: else:
release = {} release = {}
pattern = re.compile(r'^([^=]+)=(.*)$') pattern = re.compile(r"^([^=]+)=(.*)$")
for line in output.strip().split('\n'): for line in output.strip().split("\n"):
if match := pattern.match(line): if match := pattern.match(line):
key, val = match.groups() key, val = match.groups()
if val.startswith('"') and val.endswith('"'): if val.startswith('"') and val.endswith('"'):
val = val.strip('"') val = val.strip('"')
release[key] = val release[key] = val
try: try:
data['release_id'] = release['ID'] data["release_id"] = release["ID"]
data['release_version'] = release['VERSION_ID'] data["release_version"] = release["VERSION_ID"]
data['release_full'] = release['PRETTY_NAME'] data["release_full"] = release["PRETTY_NAME"]
except KeyError: except KeyError:
errors.append(f"Failed to parse {release_path}") errors.append(f"Failed to parse {release_path}")
# timezone # timezone
timezone_path = kwargs.get('timezone_path', '/etc/timezone') timezone_path = kwargs.get("timezone_path", "/etc/timezone")
try: try:
with open(timezone_path, 'rt') as f: with open(timezone_path, "rt") as f:
output = f.read() output = f.read()
except: except:
errors.append(f"Failed to read {timezone_path}") errors.append(f"Failed to read {timezone_path}")
else: else:
data['timezone'] = output.strip() data["timezone"] = output.strip()
if errors: if errors:
data['errors'] = errors data["errors"] = errors
return data return data
def collect_data_python(self, profile): def collect_data_python(self, profile):
@ -191,31 +191,32 @@ class TelemetryHandler(GenericHandler):
errors = [] errors = []
# envroot determines python executable # envroot determines python executable
envroot = profile.get_str('collect.python.envroot') envroot = profile.get_str("collect.python.envroot")
if envroot: if envroot:
data['envroot'] = envroot data["envroot"] = envroot
python = os.path.join(envroot, 'bin/python') python = os.path.join(envroot, "bin/python")
else: else:
python = profile.get_str('collect.python.executable', python = profile.get_str(
default='/usr/bin/python3') "collect.python.executable", default="/usr/bin/python3"
)
# python version # python version
data['executable'] = python data["executable"] = python
try: try:
output = subprocess.check_output([python, '--version']) output = subprocess.check_output([python, "--version"])
except (subprocess.CalledProcessError, FileNotFoundError) as err: except (subprocess.CalledProcessError, FileNotFoundError) as err:
errors.append("Failed to execute `python --version`") errors.append("Failed to execute `python --version`")
errors.append(str(err)) errors.append(str(err))
else: else:
output = output.decode('utf_8').strip() output = output.decode("utf_8").strip()
data['release_full'] = output data["release_full"] = output
if match := re.match(r'^Python (\d+\.\d+\.\d+)', output): if match := re.match(r"^Python (\d+\.\d+\.\d+)", output):
data['release_version'] = match.group(1) data["release_version"] = match.group(1)
else: else:
errors.append("Failed to parse Python version") errors.append("Failed to parse Python version")
if errors: if errors:
data['errors'] = errors data["errors"] = errors
return data return data
def submit_all_data(self, profile=None, data=None): def submit_all_data(self, profile=None, data=None):
@ -269,6 +270,6 @@ class TelemetryProfile(WuttaConfigProfile):
def load(self): def load(self):
""" """ """ """
keys = self.get_str('collect.keys', default='os,python') keys = self.get_str("collect.keys", default="os,python")
self.collect_keys = self.config.parse_list(keys) self.collect_keys = self.config.parse_list(keys)
self.submit_url = self.get_str('submit.url') self.submit_url = self.get_str("submit.url")

View file

@ -15,10 +15,10 @@ def release(c, skip_tests=False):
Release a new version of WuttaTell Release a new version of WuttaTell
""" """
if not skip_tests: if not skip_tests:
c.run('pytest') c.run("pytest")
if os.path.exists('dist'): if os.path.exists("dist"):
shutil.rmtree('dist') shutil.rmtree("dist")
c.run('python -m build --sdist') c.run("python -m build --sdist")
c.run('twine upload dist/*') c.run("twine upload dist/*")

View file

@ -14,10 +14,10 @@ class TestTell(ConfigTestCase):
ctx = Mock() ctx = Mock()
ctx.parent = Mock() ctx.parent = Mock()
ctx.parent.wutta_config = self.config ctx.parent.wutta_config = self.config
with patch.object(TelemetryHandler, 'submit_all_data') as submit_all_data: with patch.object(TelemetryHandler, "submit_all_data") as submit_all_data:
# dry run # dry run
with patch.object(TelemetryHandler, 'collect_all_data') as collect_all_data: with patch.object(TelemetryHandler, "collect_all_data") as collect_all_data:
mod.tell(ctx, dry_run=True) mod.tell(ctx, dry_run=True)
collect_all_data.assert_called_once_with(profile=None) collect_all_data.assert_called_once_with(profile=None)
submit_all_data.assert_not_called() submit_all_data.assert_not_called()

View file

@ -21,33 +21,42 @@ class TestSimpleAPIClient(ConfigTestCase):
def test_constructor(self): def test_constructor(self):
# caller specifies params # caller specifies params
client = self.make_client(base_url='https://example.com/api/', client = self.make_client(
token='XYZPDQ12345', base_url="https://example.com/api/",
ssl_verify=False, token="XYZPDQ12345",
max_retries=5) ssl_verify=False,
self.assertEqual(client.base_url, 'https://example.com/api') # no trailing slash max_retries=5,
self.assertEqual(client.token, 'XYZPDQ12345') )
self.assertEqual(
client.base_url, "https://example.com/api"
) # no trailing slash
self.assertEqual(client.token, "XYZPDQ12345")
self.assertFalse(client.ssl_verify) self.assertFalse(client.ssl_verify)
self.assertEqual(client.max_retries, 5) self.assertEqual(client.max_retries, 5)
self.assertIsNone(client.session) self.assertIsNone(client.session)
# now with some defaults # now with some defaults
client = self.make_client(base_url='https://example.com/api/', client = self.make_client(
token='XYZPDQ12345') base_url="https://example.com/api/", token="XYZPDQ12345"
self.assertEqual(client.base_url, 'https://example.com/api') # no trailing slash )
self.assertEqual(client.token, 'XYZPDQ12345') self.assertEqual(
client.base_url, "https://example.com/api"
) # no trailing slash
self.assertEqual(client.token, "XYZPDQ12345")
self.assertTrue(client.ssl_verify) self.assertTrue(client.ssl_verify)
self.assertIsNone(client.max_retries) self.assertIsNone(client.max_retries)
self.assertIsNone(client.session) self.assertIsNone(client.session)
# now from config # now from config
self.config.setdefault('wutta.api.base_url', 'https://another.com/api/') self.config.setdefault("wutta.api.base_url", "https://another.com/api/")
self.config.setdefault('wutta.api.token', '9843243q4') self.config.setdefault("wutta.api.token", "9843243q4")
self.config.setdefault('wutta.api.ssl_verify', 'false') self.config.setdefault("wutta.api.ssl_verify", "false")
self.config.setdefault('wutta.api.max_retries', '4') self.config.setdefault("wutta.api.max_retries", "4")
client = self.make_client() client = self.make_client()
self.assertEqual(client.base_url, 'https://another.com/api') # no trailing slash self.assertEqual(
self.assertEqual(client.token, '9843243q4') client.base_url, "https://another.com/api"
) # no trailing slash
self.assertEqual(client.token, "9843243q4")
self.assertFalse(client.ssl_verify) self.assertFalse(client.ssl_verify)
self.assertEqual(client.max_retries, 4) self.assertEqual(client.max_retries, 4)
self.assertIsNone(client.session) self.assertIsNone(client.session)
@ -55,16 +64,18 @@ class TestSimpleAPIClient(ConfigTestCase):
def test_init_session(self): def test_init_session(self):
# client begins with no session # client begins with no session
client = self.make_client(base_url='https://example.com/api', token='1234') client = self.make_client(base_url="https://example.com/api", token="1234")
self.assertIsNone(client.session) self.assertIsNone(client.session)
# session is created here # session is created here
client.init_session() client.init_session()
self.assertIsInstance(client.session, requests.Session) self.assertIsInstance(client.session, requests.Session)
self.assertTrue(client.session.verify) self.assertTrue(client.session.verify)
self.assertTrue(all([a.max_retries.total == 0 for a in client.session.adapters.values()])) self.assertTrue(
self.assertIn('Authorization', client.session.headers) all([a.max_retries.total == 0 for a in client.session.adapters.values()])
self.assertEqual(client.session.headers['Authorization'], 'Bearer 1234') )
self.assertIn("Authorization", client.session.headers)
self.assertEqual(client.session.headers["Authorization"], "Bearer 1234")
# session is never re-created # session is never re-created
orig_session = client.session orig_session = client.session
@ -72,94 +83,124 @@ class TestSimpleAPIClient(ConfigTestCase):
self.assertIs(client.session, orig_session) self.assertIs(client.session, orig_session)
# new client/session with no ssl_verify # new client/session with no ssl_verify
client = self.make_client(base_url='https://example.com/api', token='1234', ssl_verify=False) client = self.make_client(
base_url="https://example.com/api", token="1234", ssl_verify=False
)
client.init_session() client.init_session()
self.assertFalse(client.session.verify) self.assertFalse(client.session.verify)
# new client/session with max_retries # new client/session with max_retries
client = self.make_client(base_url='https://example.com/api', token='1234', max_retries=5) client = self.make_client(
base_url="https://example.com/api", token="1234", max_retries=5
)
client.init_session() client.init_session()
self.assertEqual(client.session.adapters['https://example.com/api'].max_retries.total, 5) self.assertEqual(
client.session.adapters["https://example.com/api"].max_retries.total, 5
)
def test_make_request_get(self): def test_make_request_get(self):
# start server # start server
threading.Thread(target=start_server).start() threading.Thread(target=start_server).start()
while not SERVER['running']: while not SERVER["running"]:
time.sleep(0.02) time.sleep(0.02)
# server returns our headers # server returns our headers
client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False) client = self.make_client(
response = client.make_request('GET', '/telemetry') base_url=f'http://127.0.0.1:{SERVER["port"]}',
token="1234",
ssl_verify=False,
)
response = client.make_request("GET", "/telemetry")
result = response.json() result = response.json()
self.assertIn('headers', result) self.assertIn("headers", result)
self.assertIn('Authorization', result['headers']) self.assertIn("Authorization", result["headers"])
self.assertEqual(result['headers']['Authorization'], 'Bearer 1234') self.assertEqual(result["headers"]["Authorization"], "Bearer 1234")
self.assertNotIn('payload', result) self.assertNotIn("payload", result)
def test_make_request_post(self): def test_make_request_post(self):
# start server # start server
threading.Thread(target=start_server).start() threading.Thread(target=start_server).start()
while not SERVER['running']: while not SERVER["running"]:
time.sleep(0.02) time.sleep(0.02)
# server returns our headers + payload # server returns our headers + payload
client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False) client = self.make_client(
response = client.make_request('POST', '/telemetry', data={'os': {'name': 'debian'}}) base_url=f'http://127.0.0.1:{SERVER["port"]}',
token="1234",
ssl_verify=False,
)
response = client.make_request(
"POST", "/telemetry", data={"os": {"name": "debian"}}
)
result = response.json() result = response.json()
self.assertIn('headers', result) self.assertIn("headers", result)
self.assertIn('Authorization', result['headers']) self.assertIn("Authorization", result["headers"])
self.assertEqual(result['headers']['Authorization'], 'Bearer 1234') self.assertEqual(result["headers"]["Authorization"], "Bearer 1234")
self.assertIn('payload', result) self.assertIn("payload", result)
self.assertEqual(json.loads(result['payload']), {'os': {'name': 'debian'}}) self.assertEqual(json.loads(result["payload"]), {"os": {"name": "debian"}})
def test_make_request_unsupported(self): def test_make_request_unsupported(self):
# start server # start server
threading.Thread(target=start_server).start() threading.Thread(target=start_server).start()
while not SERVER['running']: while not SERVER["running"]:
time.sleep(0.02) time.sleep(0.02)
# e.g. DELETE is not implemented # e.g. DELETE is not implemented
client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False) client = self.make_client(
self.assertRaises(NotImplementedError, client.make_request, 'DELETE', '/telemetry') base_url=f'http://127.0.0.1:{SERVER["port"]}',
token="1234",
ssl_verify=False,
)
self.assertRaises(
NotImplementedError, client.make_request, "DELETE", "/telemetry"
)
# nb. issue valid request to stop the server # nb. issue valid request to stop the server
client.make_request('GET', '/telemetry') client.make_request("GET", "/telemetry")
def test_get(self): def test_get(self):
# start server # start server
threading.Thread(target=start_server).start() threading.Thread(target=start_server).start()
while not SERVER['running']: while not SERVER["running"]:
time.sleep(0.02) time.sleep(0.02)
# server returns our headers # server returns our headers
client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False) client = self.make_client(
response = client.get('/telemetry') base_url=f'http://127.0.0.1:{SERVER["port"]}',
token="1234",
ssl_verify=False,
)
response = client.get("/telemetry")
result = response.json() result = response.json()
self.assertIn('headers', result) self.assertIn("headers", result)
self.assertIn('Authorization', result['headers']) self.assertIn("Authorization", result["headers"])
self.assertEqual(result['headers']['Authorization'], 'Bearer 1234') self.assertEqual(result["headers"]["Authorization"], "Bearer 1234")
self.assertNotIn('payload', result) self.assertNotIn("payload", result)
def test_post(self): def test_post(self):
# start server # start server
threading.Thread(target=start_server).start() threading.Thread(target=start_server).start()
while not SERVER['running']: while not SERVER["running"]:
time.sleep(0.02) time.sleep(0.02)
# server returns our headers + payload # server returns our headers + payload
client = self.make_client(base_url=f'http://127.0.0.1:{SERVER["port"]}', token='1234', ssl_verify=False) client = self.make_client(
response = client.post('/telemetry', data={'os': {'name': 'debian'}}) base_url=f'http://127.0.0.1:{SERVER["port"]}',
token="1234",
ssl_verify=False,
)
response = client.post("/telemetry", data={"os": {"name": "debian"}})
result = response.json() result = response.json()
self.assertIn('headers', result) self.assertIn("headers", result)
self.assertIn('Authorization', result['headers']) self.assertIn("Authorization", result["headers"])
self.assertEqual(result['headers']['Authorization'], 'Bearer 1234') self.assertEqual(result["headers"]["Authorization"], "Bearer 1234")
self.assertIn('payload', result) self.assertIn("payload", result)
self.assertEqual(json.loads(result['payload']), {'os': {'name': 'debian'}}) self.assertEqual(json.loads(result["payload"]), {"os": {"name": "debian"}})
class FakeRequestHandler(BaseHTTPRequestHandler): class FakeRequestHandler(BaseHTTPRequestHandler):
@ -167,37 +208,38 @@ class FakeRequestHandler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):
headers = dict([(k, v) for k, v in self.headers.items()]) headers = dict([(k, v) for k, v in self.headers.items()])
result = {'headers': headers} result = {"headers": headers}
result = json.dumps(result).encode('utf_8') result = json.dumps(result).encode("utf_8")
self.send_response(HTTPStatus.OK) self.send_response(HTTPStatus.OK)
self.send_header("Content-Type", 'text/json') self.send_header("Content-Type", "text/json")
self.send_header("Content-Length", str(len(result))) self.send_header("Content-Length", str(len(result)))
self.end_headers() self.end_headers()
self.wfile.write(result) self.wfile.write(result)
def do_POST(self): def do_POST(self):
headers = dict([(k, v) for k, v in self.headers.items()]) headers = dict([(k, v) for k, v in self.headers.items()])
length = int(self.headers.get('Content-Length')) length = int(self.headers.get("Content-Length"))
payload = self.rfile.read(length).decode('utf_8') payload = self.rfile.read(length).decode("utf_8")
result = {'headers': headers, 'payload': payload} result = {"headers": headers, "payload": payload}
result = json.dumps(result).encode('utf_8') result = json.dumps(result).encode("utf_8")
self.send_response(HTTPStatus.OK) self.send_response(HTTPStatus.OK)
self.send_header("Content-Type", 'text/json') self.send_header("Content-Type", "text/json")
self.send_header("Content-Length", str(len(result))) self.send_header("Content-Length", str(len(result)))
self.end_headers() self.end_headers()
self.wfile.write(result) self.wfile.write(result)
SERVER = {'running': False, 'port': 7314} SERVER = {"running": False, "port": 7314}
def start_server(): def start_server():
if SERVER['running']: if SERVER["running"]:
raise RuntimeError("http server is already running") raise RuntimeError("http server is already running")
with HTTPServer(('127.0.0.1', SERVER['port']), FakeRequestHandler) as httpd: with HTTPServer(("127.0.0.1", SERVER["port"]), FakeRequestHandler) as httpd:
SERVER['running'] = True SERVER["running"] = True
httpd.handle_request() httpd.handle_request()
SERVER['running'] = False SERVER["running"] = False

View file

@ -19,193 +19,208 @@ class TestTelemetryHandler(ConfigTestCase):
def test_get_profile(self): def test_get_profile(self):
# default # default
default = self.handler.get_profile('default') default = self.handler.get_profile("default")
self.assertIsInstance(default, mod.TelemetryProfile) self.assertIsInstance(default, mod.TelemetryProfile)
self.assertEqual(default.key, 'default') self.assertEqual(default.key, "default")
# same profile is returned # same profile is returned
profile = self.handler.get_profile(default) profile = self.handler.get_profile(default)
self.assertIs(profile, default) self.assertIs(profile, default)
def test_collect_data_os(self): def test_collect_data_os(self):
profile = self.handler.get_profile('default') profile = self.handler.get_profile("default")
# typical / working scenario # typical / working scenario
data = self.handler.collect_data_os(profile) data = self.handler.collect_data_os(profile)
self.assertIsInstance(data, dict) self.assertIsInstance(data, dict)
self.assertIn('release_id', data) self.assertIn("release_id", data)
self.assertIn('release_version', data) self.assertIn("release_version", data)
self.assertIn('release_full', data) self.assertIn("release_full", data)
self.assertIn('timezone', data) self.assertIn("timezone", data)
self.assertNotIn('errors', data) self.assertNotIn("errors", data)
# unreadable release path # unreadable release path
data = self.handler.collect_data_os(profile, release_path='/a/path/which/does/not/exist') data = self.handler.collect_data_os(
profile, release_path="/a/path/which/does/not/exist"
)
self.assertIsInstance(data, dict) self.assertIsInstance(data, dict)
self.assertNotIn('release_id', data) self.assertNotIn("release_id", data)
self.assertNotIn('release_version', data) self.assertNotIn("release_version", data)
self.assertNotIn('release_full', data) self.assertNotIn("release_full", data)
self.assertIn('timezone', data) self.assertIn("timezone", data)
self.assertIn('errors', data) self.assertIn("errors", data)
self.assertEqual(data['errors'], [ self.assertEqual(
"Failed to read /a/path/which/does/not/exist" data["errors"], ["Failed to read /a/path/which/does/not/exist"]
]) )
# unparsable release path # unparsable release path
path = self.write_file('release', "bad-content") path = self.write_file("release", "bad-content")
data = self.handler.collect_data_os(profile, release_path=path) data = self.handler.collect_data_os(profile, release_path=path)
self.assertIsInstance(data, dict) self.assertIsInstance(data, dict)
self.assertNotIn('release_id', data) self.assertNotIn("release_id", data)
self.assertNotIn('release_version', data) self.assertNotIn("release_version", data)
self.assertNotIn('release_full', data) self.assertNotIn("release_full", data)
self.assertIn('timezone', data) self.assertIn("timezone", data)
self.assertIn('errors', data) self.assertIn("errors", data)
self.assertEqual(data['errors'], [ self.assertEqual(data["errors"], [f"Failed to parse {path}"])
f"Failed to parse {path}"
])
# unreadable timezone path # unreadable timezone path
data = self.handler.collect_data_os(profile, timezone_path='/a/path/which/does/not/exist') data = self.handler.collect_data_os(
profile, timezone_path="/a/path/which/does/not/exist"
)
self.assertIsInstance(data, dict) self.assertIsInstance(data, dict)
self.assertIn('release_id', data) self.assertIn("release_id", data)
self.assertIn('release_version', data) self.assertIn("release_version", data)
self.assertIn('release_full', data) self.assertIn("release_full", data)
self.assertNotIn('timezone', data) self.assertNotIn("timezone", data)
self.assertIn('errors', data) self.assertIn("errors", data)
self.assertEqual(data['errors'], [ self.assertEqual(
"Failed to read /a/path/which/does/not/exist" data["errors"], ["Failed to read /a/path/which/does/not/exist"]
]) )
def test_collect_data_python(self): def test_collect_data_python(self):
profile = self.handler.get_profile('default') profile = self.handler.get_profile("default")
# typical / working (system-wide) scenario # typical / working (system-wide) scenario
data = self.handler.collect_data_python(profile) data = self.handler.collect_data_python(profile)
self.assertIsInstance(data, dict) self.assertIsInstance(data, dict)
self.assertNotIn('envroot', data) self.assertNotIn("envroot", data)
self.assertIn('executable', data) self.assertIn("executable", data)
self.assertIn('release_full', data) self.assertIn("release_full", data)
self.assertIn('release_version', data) self.assertIn("release_version", data)
self.assertNotIn('errors', data) self.assertNotIn("errors", data)
# missing executable # missing executable
with patch.dict(self.config.defaults, {'wutta.telemetry.default.collect.python.executable': '/bad/path'}): with patch.dict(
self.config.defaults,
{"wutta.telemetry.default.collect.python.executable": "/bad/path"},
):
data = self.handler.collect_data_python(profile) data = self.handler.collect_data_python(profile)
self.assertIsInstance(data, dict) self.assertIsInstance(data, dict)
self.assertNotIn('envroot', data) self.assertNotIn("envroot", data)
self.assertIn('executable', data) self.assertIn("executable", data)
self.assertNotIn('release_full', data) self.assertNotIn("release_full", data)
self.assertNotIn('release_version', data) self.assertNotIn("release_version", data)
self.assertIn('errors', data) self.assertIn("errors", data)
self.assertEqual(data['errors'][0], "Failed to execute `python --version`") self.assertEqual(data["errors"][0], "Failed to execute `python --version`")
# unparsable executable output # unparsable executable output
with patch.object(mod, 'subprocess') as subprocess: with patch.object(mod, "subprocess") as subprocess:
subprocess.check_output.return_value = 'bad output'.encode('utf_8') subprocess.check_output.return_value = "bad output".encode("utf_8")
data = self.handler.collect_data_python(profile) data = self.handler.collect_data_python(profile)
self.assertIsInstance(data, dict) self.assertIsInstance(data, dict)
self.assertNotIn('envroot', data) self.assertNotIn("envroot", data)
self.assertIn('executable', data) self.assertIn("executable", data)
self.assertIn('release_full', data) self.assertIn("release_full", data)
self.assertNotIn('release_version', data) self.assertNotIn("release_version", data)
self.assertIn('errors', data) self.assertIn("errors", data)
self.assertEqual(data['errors'], [ self.assertEqual(
"Failed to parse Python version", data["errors"],
]) [
"Failed to parse Python version",
],
)
# typical / working (virtual environment) scenario # typical / working (virtual environment) scenario
self.config.setdefault('wutta.telemetry.default.collect.python.envroot', '/srv/envs/poser') self.config.setdefault(
"wutta.telemetry.default.collect.python.envroot", "/srv/envs/poser"
)
data = self.handler.collect_data_python(profile) data = self.handler.collect_data_python(profile)
self.assertIsInstance(data, dict) self.assertIsInstance(data, dict)
self.assertIn('executable', data) self.assertIn("executable", data)
self.assertEqual(data['executable'], '/srv/envs/poser/bin/python') self.assertEqual(data["executable"], "/srv/envs/poser/bin/python")
self.assertNotIn('release_full', data) self.assertNotIn("release_full", data)
self.assertNotIn('release_version', data) self.assertNotIn("release_version", data)
self.assertIn('errors', data) self.assertIn("errors", data)
self.assertEqual(data['errors'][0], "Failed to execute `python --version`") self.assertEqual(data["errors"][0], "Failed to execute `python --version`")
def test_normalize_errors(self): def test_normalize_errors(self):
data = { data = {
'os': { "os": {
'timezone': 'America/Chicago', "timezone": "America/Chicago",
'errors': [ "errors": [
"Failed to read /etc/os-release", "Failed to read /etc/os-release",
], ],
}, },
'python': { "python": {
'executable': '/usr/bin/python3', "executable": "/usr/bin/python3",
'errors': [ "errors": [
"Failed to run `python --version`", "Failed to run `python --version`",
], ],
}, },
} }
self.handler.normalize_errors(data) self.handler.normalize_errors(data)
self.assertIn('os', data) self.assertIn("os", data)
self.assertIn('python', data) self.assertIn("python", data)
self.assertIn('errors', data) self.assertIn("errors", data)
self.assertEqual(data['errors'], [ self.assertEqual(
"Failed to read /etc/os-release", data["errors"],
"Failed to run `python --version`", [
]) "Failed to read /etc/os-release",
"Failed to run `python --version`",
],
)
def test_collect_all_data(self): def test_collect_all_data(self):
# typical / working scenario # typical / working scenario
data = self.handler.collect_all_data() data = self.handler.collect_all_data()
self.assertIsInstance(data, dict) self.assertIsInstance(data, dict)
self.assertIn('os', data) self.assertIn("os", data)
self.assertIn('python', data) self.assertIn("python", data)
self.assertNotIn('errors', data) self.assertNotIn("errors", data)
def test_submit_all_data(self): def test_submit_all_data(self):
profile = self.handler.get_profile('default') profile = self.handler.get_profile("default")
profile.submit_url = '/testing' profile.submit_url = "/testing"
with patch.object(mod, 'SimpleAPIClient') as SimpleAPIClient: with patch.object(mod, "SimpleAPIClient") as SimpleAPIClient:
client = MagicMock() client = MagicMock()
SimpleAPIClient.return_value = client SimpleAPIClient.return_value = client
# collecting all data # collecting all data
with patch.object(self.handler, 'collect_all_data') as collect_all_data: with patch.object(self.handler, "collect_all_data") as collect_all_data:
collect_all_data.return_value = [] collect_all_data.return_value = []
self.handler.submit_all_data(profile) self.handler.submit_all_data(profile)
collect_all_data.assert_called_once_with(profile) collect_all_data.assert_called_once_with(profile)
client.post.assert_called_once_with('/testing', data=[]) client.post.assert_called_once_with("/testing", data=[])
# use data from caller # use data from caller
client.post.reset_mock() client.post.reset_mock()
self.handler.submit_all_data(profile, data=['foo']) self.handler.submit_all_data(profile, data=["foo"])
client.post.assert_called_once_with('/testing', data=['foo']) client.post.assert_called_once_with("/testing", data=["foo"])
class TestTelemetryProfile(ConfigTestCase): class TestTelemetryProfile(ConfigTestCase):
def make_profile(self, key='default'): def make_profile(self, key="default"):
return mod.TelemetryProfile(self.config, key) return mod.TelemetryProfile(self.config, key)
def test_section(self): def test_section(self):
# default # default
profile = self.make_profile() profile = self.make_profile()
self.assertEqual(profile.section, 'wutta.telemetry') self.assertEqual(profile.section, "wutta.telemetry")
# custom appname # custom appname
with patch.object(self.config, 'appname', new='wuttatest'): with patch.object(self.config, "appname", new="wuttatest"):
profile = self.make_profile() profile = self.make_profile()
self.assertEqual(profile.section, 'wuttatest.telemetry') self.assertEqual(profile.section, "wuttatest.telemetry")
def test_load(self): def test_load(self):
# defaults # defaults
profile = self.make_profile() profile = self.make_profile()
self.assertEqual(profile.collect_keys, ['os', 'python']) self.assertEqual(profile.collect_keys, ["os", "python"])
self.assertIsNone(profile.submit_url) self.assertIsNone(profile.submit_url)
# configured # configured
self.config.setdefault('wutta.telemetry.default.collect.keys', 'os,network,python') self.config.setdefault(
self.config.setdefault('wutta.telemetry.default.submit.url', '/nodes/telemetry') "wutta.telemetry.default.collect.keys", "os,network,python"
)
self.config.setdefault("wutta.telemetry.default.submit.url", "/nodes/telemetry")
profile = self.make_profile() profile = self.make_profile()
self.assertEqual(profile.collect_keys, ['os', 'network', 'python']) self.assertEqual(profile.collect_keys, ["os", "network", "python"])
self.assertEqual(profile.submit_url, '/nodes/telemetry') self.assertEqual(profile.submit_url, "/nodes/telemetry")