#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
#
#  SQLBase7_SA -- SQLAlchemy driver/dialect for Centura SQLBase v7
#  Copyright © 2010 Lance Edgar
#
#  This file is part of SQLBase7_SA.
#
#  SQLBase7_SA is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  SQLBase7_SA is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with SQLBase7_SA.  If not, see <http://www.gnu.org/licenses/>.
#
################################################################################

from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy import types
from sqlalchemy.sql.compiler import SQLCompiler

# CSB7: These are needed for the visit_select() compiler method override below;
# ``util`` is additionally used to emit warnings during column reflection.
from sqlalchemy.sql import expression as sql
from sqlalchemy import util


class SQLBase7Compiler(SQLCompiler):
    """Statement compiler that renders joins as ``FROM a, b WHERE <on>``.

    ``visit_join()`` and ``visit_select()`` cooperate: while a SELECT is being
    compiled, every join found in its FROM list hands its ON criterion to the
    enclosing ``visit_select()``, which then emits exactly one WHERE clause
    containing all join criteria followed by the statement's own WHERE
    criterion.
    """

    # I _really_ hate to override this stuff, as it's pretty well full of things
    # I haven't yet taken the time or brain steroids to understand, but nonetheless
    # it does seem to accomplish the goal of changing the way JOINs happen for good
    # ol' Centura SQLBase.  I've marked relevant comments below with the CSB7 tag.

    # CSB7: Key under which visit_select() keeps, in its compiler stack entry,
    # the list of join criteria collected by visit_join().
    _join_criteria_key = 'csb7_join_criteria'

    def visit_join(self, join, asfrom=False, **kwargs):
        """Render ``join`` as a comma-separated pair of FROM elements.

        When called while ``visit_select()`` is rendering a FROM list, the ON
        criterion is *not* rendered here.  It is queued on the enclosing
        SELECT's stack entry so that nested/chained joins end up as
        ``a, b, c WHERE (on1) AND (on2)`` rather than
        ``a, b WHERE on1, c WHERE on2``.

        When there is no enclosing SELECT entry to collect the criterion, the
        join is rendered as ``left, right WHERE on``.

        Note that ``join.isouter`` is not consulted; every join is rendered
        the same way.
        """
        entry = self.stack and self.stack[-1] or {}
        pending = entry.get(self._join_criteria_key)

        left = self.process(join.left, asfrom=True, **kwargs)
        right = self.process(join.right, asfrom=True, **kwargs)

        if pending is None:
            # CSB7: Same as base method, only "JOIN" is replaced with ","
            # and "ON" is replaced with "WHERE".
            return (left + ", " + right +
                    " WHERE " + self.process(join.onclause, **kwargs))

        # CSB7: Defer compiling the ON criterion until the whole FROM list has
        # been rendered.  Queuing the element (rather than its text) keeps the
        # order in which expressions are compiled identical to the order in
        # which they appear in the final statement, which matters for
        # positional bind parameters.
        pending.append((join.onclause, kwargs))
        return left + ", " + right

    def visit_select(self, select, asfrom=False, parens=True, iswrapper=False,
                     fromhints=None, compound_index=1, **kwargs):
        """Compile a SELECT, folding join criteria into its WHERE clause.

        This mirrors the base compiler's method except for the WHERE handling:
        criteria queued by ``visit_join()`` while the FROM list was rendered
        are emitted first (each parenthesized), followed by the statement's own
        WHERE criterion, all combined with ``AND`` under a single ``WHERE``.

        Returns the SQL text, wrapped in parentheses when ``asfrom`` and
        ``parens`` are both true.
        """
        entry = self.stack and self.stack[-1] or {}

        existingfroms = entry.get('from', None)

        froms = select._get_display_froms(existingfroms)

        correlate_froms = set(sql._from_objects(*froms))

        # TODO: might want to propagate existing froms for select(select(select))
        # where innermost select should correlate to outermost
        # if existingfroms:
        #     correlate_froms = correlate_froms.union(existingfroms)

        # CSB7: visit_join() appends (onclause, kwargs) pairs to this list for
        # every join it renders while this SELECT's entry is on top of the stack.
        join_criteria = []

        self.stack.append({
            'from': correlate_froms,
            'iswrapper': iswrapper,
            self._join_criteria_key: join_criteria,
        })

        if compound_index==1 and not entry or entry.get('iswrapper', False):
            column_clause_args = {'result_map':self.result_map}
        else:
            column_clause_args = {}

        # the actual list of columns to print in the SELECT column list.
        inner_columns = [
            c for c in [
                self.process(
                    self.label_select_column(select, co, asfrom=asfrom),
                    within_columns_clause=True,
                    **column_clause_args)
                for co in util.unique_list(select.inner_columns)
            ]
            if c is not None
        ]

        text = "SELECT "  # we're off to a good start !

        if select._hints:
            byfrom = dict([
                (from_, hinttext % {'name':self.process(from_, ashint=True)})
                for (from_, dialect), hinttext in
                select._hints.iteritems()
                if dialect in ('*', self.dialect.name)
            ])
            hint_text = self.get_select_hint_text(byfrom)
            if hint_text:
                text += hint_text + " "

        if select._prefixes:
            text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " "
        text += self.get_select_precolumns(select)
        text += ', '.join(inner_columns)

        if froms:
            text += " \nFROM "
            if select._hints:
                # TODO: CSB7: Obviously this will need to go away but I'd like to know when
                # this code gets called so I can examine the situation more closely...
                assert False
                text += ', '.join([self.process(f, asfrom=True, fromhints=byfrom, **kwargs)
                                   for f in froms])
            else:
                # CSB7: Any join in here ends up in visit_join(), which queues
                # its ON criterion in ``join_criteria`` instead of rendering it.
                text += ', '.join([self.process(f, asfrom=True, **kwargs)
                                   for f in froms])
        else:
            text += self.default_from()

        # CSB7: Build the one-and-only WHERE clause.  Join criteria come first;
        # each is parenthesized so that an OR inside an ON criterion cannot
        # change the meaning of the combined expression.
        conditions = [
            "(%s)" % self.process(onclause, **join_kwargs)
            for onclause, join_kwargs in join_criteria
        ]

        if select._whereclause is not None:
            t = self.process(select._whereclause, **kwargs)
            if t:
                if conditions:
                    conditions.append("(%s)" % t)
                else:
                    # No join criteria: render exactly as the base compiler does.
                    conditions.append(t)

        if conditions:
            text += " \nWHERE " + " AND ".join(conditions)

        if select._group_by_clause.clauses:
            group_by = self.process(select._group_by_clause, **kwargs)
            if group_by:
                text += " GROUP BY " + group_by

        if select._having is not None:
            t = self.process(select._having, **kwargs)
            if t:
                text += " \nHAVING " + t

        if select._order_by_clause.clauses:
            text += self.order_by_clause(select, **kwargs)
        if select._limit is not None or select._offset is not None:
            text += self.limit_clause(select)
        if select.for_update:
            text += self.for_update_clause(select)

        self.stack.pop(-1)

        if asfrom and parens:
            return "(" + text + ")"
        else:
            return text


class SQLBase7Dialect(DefaultDialect):
    """SQLAlchemy dialect for Centura SQLBase v7.

    Reflection is implemented by querying the ``SYSTABLES``, ``SYSCOLUMNS`` and
    ``SYSPKCONSTRAINTS`` catalog tables through a raw DBAPI cursor.  Result rows
    are read by attribute name (``row.NAME`` etc.), so the DBAPI in use must
    return rows supporting that access style.

    NOTE(review): ``table_name`` and ``schema`` are interpolated directly into
    the catalog queries; callers must pass trusted identifiers only.
    """

    name = 'sqlbase7'

    max_identifier_length = 18

    # CSB7: Needed to override JOIN clauses.
    statement_compiler = SQLBase7Compiler

    # Maps the COLTYPE value read from SYSCOLUMNS to a SQLAlchemy type.
    type_map = {
        'CHAR'      : types.CHAR,
        'DATE'      : types.DATE,
        'DECIMAL'   : types.DECIMAL,
        'FLOAT'     : types.FLOAT,
        'SMALLINT'  : types.SMALLINT,
        'TIME'      : types.TIME,
        'TIMESTMP'  : types.TIMESTAMP,
        'VARCHAR'   : types.VARCHAR,
    }

    def _check_unicode_returns(self, connection):
        """Always report ``False`` without probing the connection."""
        return False

    def _schema_prefix(self, schema):
        """Return ``'<schema>.'`` for qualifying a catalog table, or ``''``."""
        if schema is None:
            return ''
        return '%s.' % schema

    def get_table_names(self, connection, schema=None, **kw):
        """Return the ``NAME`` of every ``SYSTABLES`` row with non-NULL ``REMARKS``.

        ``schema``, when given, qualifies the ``SYSTABLES`` catalog table.
        """
        prefix = self._schema_prefix(schema)
        cursor = connection.connection.cursor()
        try:
            table_names = [row.NAME for row in cursor.execute(
                "SELECT NAME FROM %sSYSTABLES WHERE REMARKS IS NOT NULL" % prefix
            )]
        finally:
            # Release the cursor even if the query or row iteration fails.
            cursor.close()
        return table_names

    def get_columns(self, connection, table_name, schema=None, **kw):
        """Return column description dicts for ``table_name``.

        Each dict has the keys ``name``, ``type``, ``nullable``, ``default``
        (always ``None``) and ``autoincrement`` (always ``False``).  A column
        whose ``COLTYPE`` is not present in ``type_map`` is reported with
        ``types.NullType`` and a warning, rather than aborting reflection.
        """
        prefix = self._schema_prefix(schema)
        cursor = connection.connection.cursor()
        try:
            columns = []
            for row in cursor.execute(
                    "SELECT NAME,COLTYPE,NULLS FROM %sSYSCOLUMNS WHERE TBNAME = '%s'"
                    % (prefix, table_name)):
                try:
                    coltype = self.type_map[row.COLTYPE]
                except KeyError:
                    util.warn("Did not recognize type '%s' of column '%s'"
                              % (row.COLTYPE, row.NAME))
                    coltype = types.NullType
                columns.append({
                    'name'          : row.NAME,
                    'type'          : coltype,
                    'nullable'      : row.NULLS == 'Y',
                    'default'       : None,
                    'autoincrement' : False,
                })
        finally:
            # Release the cursor even if the query or row iteration fails.
            cursor.close()
        return columns

    def get_primary_keys(self, connection, table_name, schema=None, **kw):
        """Return primary key column names for ``table_name``.

        Names are read from ``SYSPKCONSTRAINTS`` and ordered by ``PKCOLSEQNUM``.
        """
        prefix = self._schema_prefix(schema)
        cursor = connection.connection.cursor()
        try:
            primary_keys = [row.COLNAME for row in cursor.execute(
                "SELECT COLNAME FROM %sSYSPKCONSTRAINTS WHERE NAME = '%s' ORDER BY PKCOLSEQNUM"
                % (prefix, table_name)
            )]
        finally:
            # Release the cursor even if the query or row iteration fails.
            cursor.close()
        return primary_keys

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """Foreign key reflection is not implemented; always returns ``[]``."""
        return []

    def get_indexes(self, connection, table_name, schema=None, **kw):
        """Index reflection is not implemented; always returns ``[]``."""
        return []