Let just one "task" handle collect/transmit of progress for websocket

first client to connect, will cause task to start; subsequent clients
are just added to running set, for broadcast messaging
This commit is contained in:
Lance Edgar 2022-08-20 21:19:20 -05:00
parent e93063a344
commit 0a113611e8
2 changed files with 99 additions and 36 deletions

View file

@ -77,7 +77,7 @@ class WebsocketView(object):
model = self.model
# load the user's web session
user_session = await self.get_user_session(scope)
user_session = self.get_user_session(scope)
if user_session:
# determine user uuid
@ -91,7 +91,7 @@ class WebsocketView(object):
# load user proper
return session.query(model.User).get(user_uuid)
async def get_user_session(self, scope):
def get_user_session(self, scope):
settings = self.registry.settings
beaker_key = settings['beaker.session.key']
beaker_secret = settings['beaker.session.secret']

View file

@ -33,35 +33,92 @@ from tailbone.views.asgi import WebsocketView
from tailbone.progress import get_basic_session
class UpgradeWS(WebsocketView):
class UpgradeExecutionProgressWS(WebsocketView):
async def execution_progress(self, scope, receive, send):
rattail_config = self.registry['rattail_config']
# keep track of all "global" state for this socket
global_state = {
'upgrades': {},
}
async def __call__(self, scope, receive, send):
app = self.get_rattail_app()
# is user allowed to see this?
if not await self.authorize(scope, receive, send, 'upgrades.execute'):
return
# this tracks when client disconnects
state = {'disconnected': False}
# keep track of client state
client_state = {
'uuid': app.make_uuid(),
'disconnected': False,
'scope': scope,
'receive': receive,
'send': send,
}
# parse upgrade uuid from query string
query = scope['query_string'].decode('utf_8')
query = parse_qs(query)
uuid = query['uuid'][0]
# first client to request progress for this upgrade, must
# start a task to manage the collect/transmit logic for
# progress data, on behalf of this and/or any future clients
started_task = None
if uuid not in self.global_state['upgrades']:
# this upgrade is new to us; establish state and add first client
upgrade_state = self.global_state['upgrades'][uuid] = {
'clients': {client_state['uuid']: client_state},
}
# start task for transmit of progress data to all clients
started_task = asyncio.create_task(self.manage_progress(uuid))
else:
# progress task is already running, just add new client
upgrade_state = self.global_state['upgrades'][uuid]
upgrade_state['clients'][client_state['uuid']] = client_state
async def wait_for_disconnect():
message = await receive()
if message['type'] == 'websocket.disconnect':
state['disconnected'] = True
client_state['disconnected'] = True
# watch for client disconnect, while we do other things
# wait forever, until client disconnects
asyncio.create_task(wait_for_disconnect())
while not client_state['disconnected']:
query = scope['query_string'].decode('utf_8')
query = parse_qs(query)
uuid = query['uuid'][0]
# can stop if upgrade has completed
if uuid not in self.global_state['upgrades']:
break
await asyncio.sleep(0.1)
# remove client from global set, if upgrade still running
if client_state['disconnected']:
upgrade_state = self.global_state['upgrades'].get(uuid)
if upgrade_state:
del upgrade_state['clients'][client_state['uuid']]
# must continue to wait for other clients, if this client was
# the first to request progress
if started_task:
await started_task
async def manage_progress(self, uuid):
"""
Task which handles collect / transmit of progress data, for
sake of all attached clients.
"""
progress_session_id = 'upgrades.{}.execution_progress'.format(uuid)
progress_session = get_basic_session(rattail_config,
progress_session = get_basic_session(self.rattail_config,
id=progress_session_id)
# do the rest forever, until client disconnects
while not state['disconnected']:
upgrade_state = self.global_state['upgrades'][uuid]
clients = upgrade_state['clients']
while clients:
# load latest progress data
progress_session.load()
@ -69,26 +126,30 @@ class UpgradeWS(WebsocketView):
# when upgrade progress is complete...
if progress_session.get('complete'):
# maybe set success flash msg
# maybe set success flash msg (for all clients)
msg = progress_session.get('success_msg')
if msg:
user_session = await self.get_user_session(scope)
for client in clients.values():
user_session = self.get_user_session(client['scope'])
user_session.flash(msg)
user_session.persist()
# tell client progress is complete
await send({'type': 'websocket.send',
# tell clients progress is complete
for client in clients.values():
await client['send']({
'type': 'websocket.send',
'subtype': 'upgrades.execute_progress',
'text': json.dumps({'complete': True})})
# this websocket is done
# this websocket is done, so remove all clients
clients.clear()
break
# we will send this data down to client
data = dict(progress_session)
data = {}
# maybe add more lines from command output
path = rattail_config.upgrade_filepath(uuid, filename='stdout.log')
path = self.rattail_config.upgrade_filepath(uuid, filename='stdout.log')
offset = progress_session.get('stdout.offset', 0)
if os.path.exists(path):
size = os.path.getsize(path) - offset
@ -100,31 +161,33 @@ class UpgradeWS(WebsocketView):
progress_session['stdout.offset'] = offset + size
progress_session.save()
# send data to client
await send({'type': 'websocket.send',
# send data to clients
for client in clients.values():
await client['send']({
'type': 'websocket.send',
'subtype': 'upgrades.execute_progress',
'text': json.dumps(data)})
# pause for 1 second
await asyncio.sleep(1)
# no more clients, no more reason to track this upgrade
del self.global_state['upgrades'][uuid]
@classmethod
def defaults(cls, config):
cls._defaults(config)
@classmethod
def _defaults(cls, config):
# execution progress
config.add_tailbone_websocket('upgrades.execution_progress',
cls, attr='execution_progress')
config.add_tailbone_websocket('upgrades.execution_progress', cls)
def defaults(config, **kwargs):
base = globals()
UpgradeWS = kwargs.get('UpgradeWS', base['UpgradeWS'])
UpgradeWS.defaults(config)
UpgradeExecutionProgressWS = kwargs.get('UpgradeExecutionProgressWS', base['UpgradeExecutionProgressWS'])
UpgradeExecutionProgressWS.defaults(config)
def includeme(config):