diff --git a/tailbone/views/asgi/__init__.py b/tailbone/views/asgi/__init__.py index 01649f97..68300a44 100644 --- a/tailbone/views/asgi/__init__.py +++ b/tailbone/views/asgi/__init__.py @@ -77,7 +77,7 @@ class WebsocketView(object): model = self.model # load the user's web session - user_session = await self.get_user_session(scope) + user_session = self.get_user_session(scope) if user_session: # determine user uuid @@ -91,7 +91,7 @@ class WebsocketView(object): # load user proper return session.query(model.User).get(user_uuid) - async def get_user_session(self, scope): + def get_user_session(self, scope): settings = self.registry.settings beaker_key = settings['beaker.session.key'] beaker_secret = settings['beaker.session.secret'] diff --git a/tailbone/views/asgi/upgrades.py b/tailbone/views/asgi/upgrades.py index fc066326..f06fc7d3 100644 --- a/tailbone/views/asgi/upgrades.py +++ b/tailbone/views/asgi/upgrades.py @@ -33,35 +33,92 @@ from tailbone.views.asgi import WebsocketView from tailbone.progress import get_basic_session -class UpgradeWS(WebsocketView): +class UpgradeExecutionProgressWS(WebsocketView): - async def execution_progress(self, scope, receive, send): - rattail_config = self.registry['rattail_config'] + # keep track of all "global" state for this socket + global_state = { + 'upgrades': {}, + } + + async def __call__(self, scope, receive, send): + app = self.get_rattail_app() # is user allowed to see this? if not await self.authorize(scope, receive, send, 'upgrades.execute'): return - # this tracks when client disconnects - state = {'disconnected': False} + # keep track of client state + client_state = { + 'uuid': app.make_uuid(), + 'disconnected': False, + 'scope': scope, + 'receive': receive, + 'send': send, + } + + # parse upgrade uuid from query string + query = scope['query_string'].decode('utf_8') + query = parse_qs(query) + uuid = query['uuid'][0] + + # first client to request progress for this upgrade, must + # start a task to manage the collect/transmit logic for + # progress data, on behalf of this and/or any future clients + started_task = None + if uuid not in self.global_state['upgrades']: + + # this upgrade is new to us; establish state and add first client + upgrade_state = self.global_state['upgrades'][uuid] = { + 'clients': {client_state['uuid']: client_state}, + } + + # start task for transmit of progress data to all clients + started_task = asyncio.create_task(self.manage_progress(uuid)) + + else: + + # progress task is already running, just add new client + upgrade_state = self.global_state['upgrades'][uuid] + upgrade_state['clients'][client_state['uuid']] = client_state async def wait_for_disconnect(): message = await receive() if message['type'] == 'websocket.disconnect': - state['disconnected'] = True + client_state['disconnected'] = True - # watch for client disconnect, while we do other things + # wait forever, until client disconnects asyncio.create_task(wait_for_disconnect()) + while not client_state['disconnected']: - query = scope['query_string'].decode('utf_8') - query = parse_qs(query) - uuid = query['uuid'][0] + # can stop if upgrade has completed + if uuid not in self.global_state['upgrades']: + break + + await asyncio.sleep(0.1) + + # remove client from global set, if upgrade still running + if client_state['disconnected']: + upgrade_state = self.global_state['upgrades'].get(uuid) + if upgrade_state: + del upgrade_state['clients'][client_state['uuid']] + + # must continue to wait for other clients, if this client was + # the first to request progress + if started_task: + await started_task + + async def manage_progress(self, uuid): + """ + Task which handles collect / transmit of progress data, for + sake of all attached clients. + """ progress_session_id = 'upgrades.{}.execution_progress'.format(uuid) - progress_session = get_basic_session(rattail_config, + progress_session = get_basic_session(self.rattail_config, id=progress_session_id) - # do the rest forever, until client disconnects - while not state['disconnected']: + upgrade_state = self.global_state['upgrades'][uuid] + clients = upgrade_state['clients'] + while clients: # load latest progress data progress_session.load() @@ -69,26 +126,30 @@ class UpgradeWS(WebsocketView): # when upgrade progress is complete... if progress_session.get('complete'): - # maybe set success flash msg + # maybe set success flash msg (for all clients) msg = progress_session.get('success_msg') if msg: - user_session = await self.get_user_session(scope) - user_session.flash(msg) - user_session.persist() + for client in clients.values(): + user_session = self.get_user_session(client['scope']) + user_session.flash(msg) + user_session.persist() - # tell client progress is complete - await send({'type': 'websocket.send', - 'subtype': 'upgrades.execute_progress', - 'text': json.dumps({'complete': True})}) + # tell clients progress is complete + for client in clients.values(): + await client['send']({ + 'type': 'websocket.send', + 'subtype': 'upgrades.execute_progress', + 'text': json.dumps({'complete': True})}) - # this websocket is done + # this websocket is done, so remove all clients + clients.clear() break # we will send this data down to client - data = dict(progress_session) + data = {} # maybe add more lines from command output - path = rattail_config.upgrade_filepath(uuid, filename='stdout.log') + path = self.rattail_config.upgrade_filepath(uuid, filename='stdout.log') offset = progress_session.get('stdout.offset', 0) if os.path.exists(path): size = os.path.getsize(path) - offset @@ -100,31 +161,33 @@ class UpgradeWS(WebsocketView): progress_session['stdout.offset'] = offset + size progress_session.save() - # send data to client - await send({'type': 'websocket.send', - 'subtype': 'upgrades.execute_progress', - 'text': json.dumps(data)}) + # send data to clients + for client in clients.values(): + await client['send']({ + 'type': 'websocket.send', + 'subtype': 'upgrades.execute_progress', + 'text': json.dumps(data)}) # pause for 1 second await asyncio.sleep(1) + # no more clients, no more reason to track this upgrade + del self.global_state['upgrades'][uuid] + @classmethod def defaults(cls, config): cls._defaults(config) @classmethod def _defaults(cls, config): - - # execution progress - config.add_tailbone_websocket('upgrades.execution_progress', - cls, attr='execution_progress') + config.add_tailbone_websocket('upgrades.execution_progress', cls) def defaults(config, **kwargs): base = globals() - UpgradeWS = kwargs.get('UpgradeWS', base['UpgradeWS']) - UpgradeWS.defaults(config) + UpgradeExecutionProgressWS = kwargs.get('UpgradeExecutionProgressWS', base['UpgradeExecutionProgressWS']) + UpgradeExecutionProgressWS.defaults(config) def includeme(config):