fix: fix a couple more edge cases around oauth2 token refresh

This commit is contained in:
Lance Edgar 2026-02-06 17:42:47 -06:00
parent b546c9e97d
commit b161109d65
3 changed files with 54 additions and 4 deletions

40
src/wuttafarm/web/util.py Normal file
View file

@ -0,0 +1,40 @@
# -*- coding: utf-8; -*-
################################################################################
#
# WuttaFarm --Web app to integrate with and extend farmOS
# Copyright © 2026 Lance Edgar
#
# This file is part of WuttaFarm.
#
# WuttaFarm is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# WuttaFarm is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# WuttaFarm. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Misc. utilities for web app
"""
def save_farmos_oauth2_token(request, token):
"""
Common logic for saving the given OAuth2 token within the user
session. This function is called from 2 places:
* :meth:`wuttafarm.web.views.auth.farmos_oauth_callback()`
* :meth:`wuttafarm.web.views.farmos.master.FarmOSMasterView.get_farmos_client()`
"""
# nb. we pretend the token expires 1 minute early, to avoid edge
# cases around token refresh
token["expires_at"] -= 60
# save token to user session
request.session["farmos.oauth2.token"] = token

View file

@ -30,6 +30,8 @@ from wuttaweb.views import auth as base
from wuttaweb.auth import login_user from wuttaweb.auth import login_user
from wuttaweb.db import Session from wuttaweb.db import Session
from wuttafarm.web.util import save_farmos_oauth2_token
class AuthView(base.AuthView): class AuthView(base.AuthView):
""" """
@ -91,7 +93,13 @@ class AuthView(base.AuthView):
return self.redirect(self.request.route_url("login")) return self.redirect(self.request.route_url("login"))
# save token in user session # save token in user session
self.request.session["farmos.oauth2.token"] = token save_farmos_oauth2_token(self.request, token)
# nb. must give a *copy* of the token to farmOS client, since
# it will mutate it in-place and we don't want that to happen
# for our original copy in the user session. (otherwise the
# auto-refresh will not work correctly for subsequent calls.)
token = dict(token)
# get (or create) native app user # get (or create) native app user
farmos_client = self.app.get_farmos_client(token=token) farmos_client = self.app.get_farmos_client(token=token)

View file

@ -25,6 +25,8 @@ Base class for farmOS master views
from wuttaweb.views import MasterView from wuttaweb.views import MasterView
from wuttafarm.web.util import save_farmos_oauth2_token
class FarmOSMasterView(MasterView): class FarmOSMasterView(MasterView):
""" """
@ -52,15 +54,15 @@ class FarmOSMasterView(MasterView):
if not token: if not token:
raise self.forbidden() raise self.forbidden()
def token_updater(token):
self.request.session["farmos.oauth2.token"] = token
# nb. must give a *copy* of the token to farmOS client, since # nb. must give a *copy* of the token to farmOS client, since
# it will mutate it in-place and we don't want that to happen # it will mutate it in-place and we don't want that to happen
# for our original copy in the user session. (otherwise the # for our original copy in the user session. (otherwise the
# auto-refresh will not work correctly for subsequent calls.) # auto-refresh will not work correctly for subsequent calls.)
token = dict(token) token = dict(token)
def token_updater(token):
save_farmos_oauth2_token(self.request, token)
return self.app.get_farmos_client(token=token, token_updater=token_updater) return self.app.get_farmos_client(token=token, token_updater=token_updater)
def get_template_context(self, context): def get_template_context(self, context):