Let just one "task" handle collect/transmit of progress for websocket
first client to connect, will cause task to start; subsequent clients are just added to running set, for broadcast messaging
This commit is contained in:
parent
e93063a344
commit
0a113611e8
|
@ -77,7 +77,7 @@ class WebsocketView(object):
|
|||
model = self.model
|
||||
|
||||
# load the user's web session
|
||||
user_session = await self.get_user_session(scope)
|
||||
user_session = self.get_user_session(scope)
|
||||
if user_session:
|
||||
|
||||
# determine user uuid
|
||||
|
@ -91,7 +91,7 @@ class WebsocketView(object):
|
|||
# load user proper
|
||||
return session.query(model.User).get(user_uuid)
|
||||
|
||||
async def get_user_session(self, scope):
|
||||
def get_user_session(self, scope):
|
||||
settings = self.registry.settings
|
||||
beaker_key = settings['beaker.session.key']
|
||||
beaker_secret = settings['beaker.session.secret']
|
||||
|
|
|
@ -33,35 +33,92 @@ from tailbone.views.asgi import WebsocketView
|
|||
from tailbone.progress import get_basic_session
|
||||
|
||||
|
||||
class UpgradeWS(WebsocketView):
|
||||
class UpgradeExecutionProgressWS(WebsocketView):
|
||||
|
||||
async def execution_progress(self, scope, receive, send):
|
||||
rattail_config = self.registry['rattail_config']
|
||||
# keep track of all "global" state for this socket
|
||||
global_state = {
|
||||
'upgrades': {},
|
||||
}
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
app = self.get_rattail_app()
|
||||
|
||||
# is user allowed to see this?
|
||||
if not await self.authorize(scope, receive, send, 'upgrades.execute'):
|
||||
return
|
||||
|
||||
# this tracks when client disconnects
|
||||
state = {'disconnected': False}
|
||||
# keep track of client state
|
||||
client_state = {
|
||||
'uuid': app.make_uuid(),
|
||||
'disconnected': False,
|
||||
'scope': scope,
|
||||
'receive': receive,
|
||||
'send': send,
|
||||
}
|
||||
|
||||
# parse upgrade uuid from query string
|
||||
query = scope['query_string'].decode('utf_8')
|
||||
query = parse_qs(query)
|
||||
uuid = query['uuid'][0]
|
||||
|
||||
# first client to request progress for this upgrade, must
|
||||
# start a task to manage the collect/transmit logic for
|
||||
# progress data, on behalf of this and/or any future clients
|
||||
started_task = None
|
||||
if uuid not in self.global_state['upgrades']:
|
||||
|
||||
# this upgrade is new to us; establish state and add first client
|
||||
upgrade_state = self.global_state['upgrades'][uuid] = {
|
||||
'clients': {client_state['uuid']: client_state},
|
||||
}
|
||||
|
||||
# start task for transmit of progress data to all clients
|
||||
started_task = asyncio.create_task(self.manage_progress(uuid))
|
||||
|
||||
else:
|
||||
|
||||
# progress task is already running, just add new client
|
||||
upgrade_state = self.global_state['upgrades'][uuid]
|
||||
upgrade_state['clients'][client_state['uuid']] = client_state
|
||||
|
||||
async def wait_for_disconnect():
|
||||
message = await receive()
|
||||
if message['type'] == 'websocket.disconnect':
|
||||
state['disconnected'] = True
|
||||
client_state['disconnected'] = True
|
||||
|
||||
# watch for client disconnect, while we do other things
|
||||
# wait forever, until client disconnects
|
||||
asyncio.create_task(wait_for_disconnect())
|
||||
while not client_state['disconnected']:
|
||||
|
||||
query = scope['query_string'].decode('utf_8')
|
||||
query = parse_qs(query)
|
||||
uuid = query['uuid'][0]
|
||||
# can stop if upgrade has completed
|
||||
if uuid not in self.global_state['upgrades']:
|
||||
break
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# remove client from global set, if upgrade still running
|
||||
if client_state['disconnected']:
|
||||
upgrade_state = self.global_state['upgrades'].get(uuid)
|
||||
if upgrade_state:
|
||||
del upgrade_state['clients'][client_state['uuid']]
|
||||
|
||||
# must continue to wait for other clients, if this client was
|
||||
# the first to request progress
|
||||
if started_task:
|
||||
await started_task
|
||||
|
||||
async def manage_progress(self, uuid):
|
||||
"""
|
||||
Task which handles collect / transmit of progress data, for
|
||||
sake of all attached clients.
|
||||
"""
|
||||
progress_session_id = 'upgrades.{}.execution_progress'.format(uuid)
|
||||
progress_session = get_basic_session(rattail_config,
|
||||
progress_session = get_basic_session(self.rattail_config,
|
||||
id=progress_session_id)
|
||||
|
||||
# do the rest forever, until client disconnects
|
||||
while not state['disconnected']:
|
||||
upgrade_state = self.global_state['upgrades'][uuid]
|
||||
clients = upgrade_state['clients']
|
||||
while clients:
|
||||
|
||||
# load latest progress data
|
||||
progress_session.load()
|
||||
|
@ -69,26 +126,30 @@ class UpgradeWS(WebsocketView):
|
|||
# when upgrade progress is complete...
|
||||
if progress_session.get('complete'):
|
||||
|
||||
# maybe set success flash msg
|
||||
# maybe set success flash msg (for all clients)
|
||||
msg = progress_session.get('success_msg')
|
||||
if msg:
|
||||
user_session = await self.get_user_session(scope)
|
||||
user_session.flash(msg)
|
||||
user_session.persist()
|
||||
for client in clients.values():
|
||||
user_session = self.get_user_session(client['scope'])
|
||||
user_session.flash(msg)
|
||||
user_session.persist()
|
||||
|
||||
# tell client progress is complete
|
||||
await send({'type': 'websocket.send',
|
||||
'subtype': 'upgrades.execute_progress',
|
||||
'text': json.dumps({'complete': True})})
|
||||
# tell clients progress is complete
|
||||
for client in clients.values():
|
||||
await client['send']({
|
||||
'type': 'websocket.send',
|
||||
'subtype': 'upgrades.execute_progress',
|
||||
'text': json.dumps({'complete': True})})
|
||||
|
||||
# this websocket is done
|
||||
# this websocket is done, so remove all clients
|
||||
clients.clear()
|
||||
break
|
||||
|
||||
# we will send this data down to client
|
||||
data = dict(progress_session)
|
||||
data = {}
|
||||
|
||||
# maybe add more lines from command output
|
||||
path = rattail_config.upgrade_filepath(uuid, filename='stdout.log')
|
||||
path = self.rattail_config.upgrade_filepath(uuid, filename='stdout.log')
|
||||
offset = progress_session.get('stdout.offset', 0)
|
||||
if os.path.exists(path):
|
||||
size = os.path.getsize(path) - offset
|
||||
|
@ -100,31 +161,33 @@ class UpgradeWS(WebsocketView):
|
|||
progress_session['stdout.offset'] = offset + size
|
||||
progress_session.save()
|
||||
|
||||
# send data to client
|
||||
await send({'type': 'websocket.send',
|
||||
'subtype': 'upgrades.execute_progress',
|
||||
'text': json.dumps(data)})
|
||||
# send data to clients
|
||||
for client in clients.values():
|
||||
await client['send']({
|
||||
'type': 'websocket.send',
|
||||
'subtype': 'upgrades.execute_progress',
|
||||
'text': json.dumps(data)})
|
||||
|
||||
# pause for 1 second
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# no more clients, no more reason to track this upgrade
|
||||
del self.global_state['upgrades'][uuid]
|
||||
|
||||
@classmethod
|
||||
def defaults(cls, config):
|
||||
cls._defaults(config)
|
||||
|
||||
@classmethod
|
||||
def _defaults(cls, config):
|
||||
|
||||
# execution progress
|
||||
config.add_tailbone_websocket('upgrades.execution_progress',
|
||||
cls, attr='execution_progress')
|
||||
config.add_tailbone_websocket('upgrades.execution_progress', cls)
|
||||
|
||||
|
||||
def defaults(config, **kwargs):
|
||||
base = globals()
|
||||
|
||||
UpgradeWS = kwargs.get('UpgradeWS', base['UpgradeWS'])
|
||||
UpgradeWS.defaults(config)
|
||||
UpgradeExecutionProgressWS = kwargs.get('UpgradeExecutionProgressWS', base['UpgradeExecutionProgressWS'])
|
||||
UpgradeExecutionProgressWS.defaults(config)
|
||||
|
||||
|
||||
def includeme(config):
|
||||
|
|
Loading…
Reference in a new issue