Let just one "task" handle collect/transmit of progress for websocket
first client to connect, will cause task to start; subsequent clients are just added to running set, for broadcast messaging
This commit is contained in:
parent
e93063a344
commit
0a113611e8
|
@ -77,7 +77,7 @@ class WebsocketView(object):
|
||||||
model = self.model
|
model = self.model
|
||||||
|
|
||||||
# load the user's web session
|
# load the user's web session
|
||||||
user_session = await self.get_user_session(scope)
|
user_session = self.get_user_session(scope)
|
||||||
if user_session:
|
if user_session:
|
||||||
|
|
||||||
# determine user uuid
|
# determine user uuid
|
||||||
|
@ -91,7 +91,7 @@ class WebsocketView(object):
|
||||||
# load user proper
|
# load user proper
|
||||||
return session.query(model.User).get(user_uuid)
|
return session.query(model.User).get(user_uuid)
|
||||||
|
|
||||||
async def get_user_session(self, scope):
|
def get_user_session(self, scope):
|
||||||
settings = self.registry.settings
|
settings = self.registry.settings
|
||||||
beaker_key = settings['beaker.session.key']
|
beaker_key = settings['beaker.session.key']
|
||||||
beaker_secret = settings['beaker.session.secret']
|
beaker_secret = settings['beaker.session.secret']
|
||||||
|
|
|
@ -33,35 +33,92 @@ from tailbone.views.asgi import WebsocketView
|
||||||
from tailbone.progress import get_basic_session
|
from tailbone.progress import get_basic_session
|
||||||
|
|
||||||
|
|
||||||
class UpgradeWS(WebsocketView):
|
class UpgradeExecutionProgressWS(WebsocketView):
|
||||||
|
|
||||||
async def execution_progress(self, scope, receive, send):
|
# keep track of all "global" state for this socket
|
||||||
rattail_config = self.registry['rattail_config']
|
global_state = {
|
||||||
|
'upgrades': {},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
app = self.get_rattail_app()
|
||||||
|
|
||||||
# is user allowed to see this?
|
# is user allowed to see this?
|
||||||
if not await self.authorize(scope, receive, send, 'upgrades.execute'):
|
if not await self.authorize(scope, receive, send, 'upgrades.execute'):
|
||||||
return
|
return
|
||||||
|
|
||||||
# this tracks when client disconnects
|
# keep track of client state
|
||||||
state = {'disconnected': False}
|
client_state = {
|
||||||
|
'uuid': app.make_uuid(),
|
||||||
|
'disconnected': False,
|
||||||
|
'scope': scope,
|
||||||
|
'receive': receive,
|
||||||
|
'send': send,
|
||||||
|
}
|
||||||
|
|
||||||
|
# parse upgrade uuid from query string
|
||||||
|
query = scope['query_string'].decode('utf_8')
|
||||||
|
query = parse_qs(query)
|
||||||
|
uuid = query['uuid'][0]
|
||||||
|
|
||||||
|
# first client to request progress for this upgrade, must
|
||||||
|
# start a task to manage the collect/transmit logic for
|
||||||
|
# progress data, on behalf of this and/or any future clients
|
||||||
|
started_task = None
|
||||||
|
if uuid not in self.global_state['upgrades']:
|
||||||
|
|
||||||
|
# this upgrade is new to us; establish state and add first client
|
||||||
|
upgrade_state = self.global_state['upgrades'][uuid] = {
|
||||||
|
'clients': {client_state['uuid']: client_state},
|
||||||
|
}
|
||||||
|
|
||||||
|
# start task for transmit of progress data to all clients
|
||||||
|
started_task = asyncio.create_task(self.manage_progress(uuid))
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
# progress task is already running, just add new client
|
||||||
|
upgrade_state = self.global_state['upgrades'][uuid]
|
||||||
|
upgrade_state['clients'][client_state['uuid']] = client_state
|
||||||
|
|
||||||
async def wait_for_disconnect():
|
async def wait_for_disconnect():
|
||||||
message = await receive()
|
message = await receive()
|
||||||
if message['type'] == 'websocket.disconnect':
|
if message['type'] == 'websocket.disconnect':
|
||||||
state['disconnected'] = True
|
client_state['disconnected'] = True
|
||||||
|
|
||||||
# watch for client disconnect, while we do other things
|
# wait forever, until client disconnects
|
||||||
asyncio.create_task(wait_for_disconnect())
|
asyncio.create_task(wait_for_disconnect())
|
||||||
|
while not client_state['disconnected']:
|
||||||
|
|
||||||
query = scope['query_string'].decode('utf_8')
|
# can stop if upgrade has completed
|
||||||
query = parse_qs(query)
|
if uuid not in self.global_state['upgrades']:
|
||||||
uuid = query['uuid'][0]
|
break
|
||||||
|
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# remove client from global set, if upgrade still running
|
||||||
|
if client_state['disconnected']:
|
||||||
|
upgrade_state = self.global_state['upgrades'].get(uuid)
|
||||||
|
if upgrade_state:
|
||||||
|
del upgrade_state['clients'][client_state['uuid']]
|
||||||
|
|
||||||
|
# must continue to wait for other clients, if this client was
|
||||||
|
# the first to request progress
|
||||||
|
if started_task:
|
||||||
|
await started_task
|
||||||
|
|
||||||
|
async def manage_progress(self, uuid):
|
||||||
|
"""
|
||||||
|
Task which handles collect / transmit of progress data, for
|
||||||
|
sake of all attached clients.
|
||||||
|
"""
|
||||||
progress_session_id = 'upgrades.{}.execution_progress'.format(uuid)
|
progress_session_id = 'upgrades.{}.execution_progress'.format(uuid)
|
||||||
progress_session = get_basic_session(rattail_config,
|
progress_session = get_basic_session(self.rattail_config,
|
||||||
id=progress_session_id)
|
id=progress_session_id)
|
||||||
|
|
||||||
# do the rest forever, until client disconnects
|
upgrade_state = self.global_state['upgrades'][uuid]
|
||||||
while not state['disconnected']:
|
clients = upgrade_state['clients']
|
||||||
|
while clients:
|
||||||
|
|
||||||
# load latest progress data
|
# load latest progress data
|
||||||
progress_session.load()
|
progress_session.load()
|
||||||
|
@ -69,26 +126,30 @@ class UpgradeWS(WebsocketView):
|
||||||
# when upgrade progress is complete...
|
# when upgrade progress is complete...
|
||||||
if progress_session.get('complete'):
|
if progress_session.get('complete'):
|
||||||
|
|
||||||
# maybe set success flash msg
|
# maybe set success flash msg (for all clients)
|
||||||
msg = progress_session.get('success_msg')
|
msg = progress_session.get('success_msg')
|
||||||
if msg:
|
if msg:
|
||||||
user_session = await self.get_user_session(scope)
|
for client in clients.values():
|
||||||
|
user_session = self.get_user_session(client['scope'])
|
||||||
user_session.flash(msg)
|
user_session.flash(msg)
|
||||||
user_session.persist()
|
user_session.persist()
|
||||||
|
|
||||||
# tell client progress is complete
|
# tell clients progress is complete
|
||||||
await send({'type': 'websocket.send',
|
for client in clients.values():
|
||||||
|
await client['send']({
|
||||||
|
'type': 'websocket.send',
|
||||||
'subtype': 'upgrades.execute_progress',
|
'subtype': 'upgrades.execute_progress',
|
||||||
'text': json.dumps({'complete': True})})
|
'text': json.dumps({'complete': True})})
|
||||||
|
|
||||||
# this websocket is done
|
# this websocket is done, so remove all clients
|
||||||
|
clients.clear()
|
||||||
break
|
break
|
||||||
|
|
||||||
# we will send this data down to client
|
# we will send this data down to client
|
||||||
data = dict(progress_session)
|
data = {}
|
||||||
|
|
||||||
# maybe add more lines from command output
|
# maybe add more lines from command output
|
||||||
path = rattail_config.upgrade_filepath(uuid, filename='stdout.log')
|
path = self.rattail_config.upgrade_filepath(uuid, filename='stdout.log')
|
||||||
offset = progress_session.get('stdout.offset', 0)
|
offset = progress_session.get('stdout.offset', 0)
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
size = os.path.getsize(path) - offset
|
size = os.path.getsize(path) - offset
|
||||||
|
@ -100,31 +161,33 @@ class UpgradeWS(WebsocketView):
|
||||||
progress_session['stdout.offset'] = offset + size
|
progress_session['stdout.offset'] = offset + size
|
||||||
progress_session.save()
|
progress_session.save()
|
||||||
|
|
||||||
# send data to client
|
# send data to clients
|
||||||
await send({'type': 'websocket.send',
|
for client in clients.values():
|
||||||
|
await client['send']({
|
||||||
|
'type': 'websocket.send',
|
||||||
'subtype': 'upgrades.execute_progress',
|
'subtype': 'upgrades.execute_progress',
|
||||||
'text': json.dumps(data)})
|
'text': json.dumps(data)})
|
||||||
|
|
||||||
# pause for 1 second
|
# pause for 1 second
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# no more clients, no more reason to track this upgrade
|
||||||
|
del self.global_state['upgrades'][uuid]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def defaults(cls, config):
|
def defaults(cls, config):
|
||||||
cls._defaults(config)
|
cls._defaults(config)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _defaults(cls, config):
|
def _defaults(cls, config):
|
||||||
|
config.add_tailbone_websocket('upgrades.execution_progress', cls)
|
||||||
# execution progress
|
|
||||||
config.add_tailbone_websocket('upgrades.execution_progress',
|
|
||||||
cls, attr='execution_progress')
|
|
||||||
|
|
||||||
|
|
||||||
def defaults(config, **kwargs):
|
def defaults(config, **kwargs):
|
||||||
base = globals()
|
base = globals()
|
||||||
|
|
||||||
UpgradeWS = kwargs.get('UpgradeWS', base['UpgradeWS'])
|
UpgradeExecutionProgressWS = kwargs.get('UpgradeExecutionProgressWS', base['UpgradeExecutionProgressWS'])
|
||||||
UpgradeWS.defaults(config)
|
UpgradeExecutionProgressWS.defaults(config)
|
||||||
|
|
||||||
|
|
||||||
def includeme(config):
|
def includeme(config):
|
||||||
|
|
Loading…
Reference in a new issue