tailbone/tailbone/api/master.py
Lance Edgar 46501b7caa Use sqlalchemy-filters package for REST API collection_get
just sorting and pagination so far though, no actual filters yet
2018-11-19 23:56:42 -06:00

122 lines
4.3 KiB
Python

# -*- coding: utf-8; -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2018 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Tailbone Web API - Master View
"""
from __future__ import unicode_literals, absolute_import
from sqlalchemy_filters import apply_sort, apply_pagination
from tailbone.api import APIView, api
from tailbone.db import Session
class APIMasterView(APIView):
    """
    Base class for data model REST API views.

    Subclasses must set ``model_class``.  They may also set
    ``normalized_model_name``, ``object_key`` and/or ``collection_key`` to
    override the names which are otherwise derived from the model class.

    NOTE(review): the ``normalize()`` and ``notfound()`` methods used below
    are not defined here; presumably they come from ``APIView`` or a
    subclass -- confirm.
    """

    @property
    def Session(self):
        """
        The ``tailbone.db.Session`` object used for all queries in this view.
        """
        return Session

    @classmethod
    def get_model_class(cls):
        """
        Return the data model class for this view.

        :raises NotImplementedError: if the subclass did not set
           ``model_class``.
        """
        if hasattr(cls, 'model_class'):
            return cls.model_class
        raise NotImplementedError("must set `model_class` for {}".format(cls.__name__))

    @classmethod
    def get_normalized_model_name(cls):
        """
        Return ``normalized_model_name`` if set, otherwise the model class
        name in lower case.
        """
        if hasattr(cls, 'normalized_model_name'):
            return cls.normalized_model_name
        return cls.get_model_class().__name__.lower()

    @classmethod
    def get_object_key(cls):
        """
        Return the key under which a single record is returned; defaults to
        the normalized model name.
        """
        if hasattr(cls, 'object_key'):
            return cls.object_key
        return cls.get_normalized_model_name()

    @classmethod
    def get_collection_key(cls):
        """
        Return the key under which a list of records is returned; defaults
        to the object key with an ``s`` appended.
        """
        if hasattr(cls, 'collection_key'):
            return cls.collection_key
        return '{}s'.format(cls.get_object_key())

    def _collection_get(self):
        """
        Return a dict containing a (possibly sorted / paginated) collection
        of normalized model records, stored under
        :meth:`get_collection_key()`.

        Recognized request params:

        * ``sort`` -- ``"<field>|<direction>"``; a direction of ``desc``
          sorts descending, anything else (or no direction) sorts ascending.
        * ``page`` and ``per_page`` -- digit strings; these are only honored
          when ``sort`` is also given.

        When pagination is applied, the dict also contains the vuetable-2
        pagination keys: ``total``, ``per_page``, ``current_page``,
        ``last_page``, ``from`` and ``to``.
        """
        model_class = self.get_model_class()
        query = self.Session.query(model_class)
        context = {}

        # TODO: should vuetable (etc.) be sending us valid sort_spec directly?
        sort = self.request.params.get('sort')
        if sort:
            # partition() never raises, unlike unpacking split('|') into two
            # names, which fails unless there is exactly one '|' present
            sortkey, _, sortdir = sort.partition('|')
            if sortdir != 'desc':
                sortdir = 'asc'
            sort_spec = [
                {
                    'field': sortkey,
                    'direction': sortdir,
                },
            ]
            query = apply_sort(query, sort_spec)

            # NOTE: we only paginate results if sorting is in effect, since
            # otherwise the record sequence is non-deterministic
            page = self.request.params.get('page')
            per_page = self.request.params.get('per_page')
            # either param may be absent (None), so check before isdigit()
            if page and per_page and page.isdigit() and per_page.isdigit():
                query, pagination = apply_pagination(query,
                                                     page_number=int(page),
                                                     page_size=int(per_page))

                # these pagination values are based on 'vuetable-2'
                # https://www.vuetable.com/guide/pagination.html#how-the-pagination-component-works
                offset = pagination.page_size * (pagination.page_number - 1)
                context['total'] = pagination.total_results
                context['per_page'] = pagination.page_size
                context['current_page'] = pagination.page_number
                context['last_page'] = pagination.num_pages
                context['from'] = offset + 1
                # the last page may hold fewer than page_size records
                context['to'] = min(offset + pagination.page_size,
                                    pagination.total_results)

        context[self.get_collection_key()] = [self.normalize(obj) for obj in query]
        return context

    def _get(self):
        """
        Return a dict holding the single normalized record whose primary key
        matches the ``uuid`` route match value, stored under
        :meth:`get_object_key()`.

        Raises whatever ``self.notfound()`` returns (presumably an HTTP 404
        error -- confirm) when no such record exists.
        """
        uuid = self.request.matchdict['uuid']
        obj = self.Session.query(self.get_model_class()).get(uuid)
        # Query.get() returns None when no row has the given primary key
        if obj is None:
            raise self.notfound()
        return {self.get_object_key(): self.normalize(obj)}