diff --git a/sqlbase7/base.py b/sqlbase7/base.py index a6c9b9d..054a479 100644 --- a/sqlbase7/base.py +++ b/sqlbase7/base.py @@ -23,14 +23,156 @@ ################################################################################ -from sqlalchemy.engine import default +from sqlalchemy.engine.default import DefaultDialect from sqlalchemy import types +from sqlalchemy.sql.compiler import SQLCompiler + +# CSB7: These are only needed for the visit_select() compiler method override below. +from sqlalchemy.sql import expression as sql +from sqlalchemy import util -class SQLBase7Dialect(default.DefaultDialect): +class SQLBase7Compiler(SQLCompiler): + + # I _really_ hate to override this stuff, as it's pretty well full of things + # I haven't yet taken the time or brain sterroids to understand, but nonetheless + # it does seem to accomplish the goal of changing the way JOINs happen for good + # ol' Centure SQLBase. I've marked relevant comments below with the CSB7 tag. + + def visit_join(self, join, asfrom=False, **kwargs): + # CSB7: Same as base method, only "JOIN" is replaced with "," + # and "ON" is replaced with "WHERE". + return (self.process(join.left, asfrom=True, **kwargs) + \ + ", " + \ + self.process(join.right, asfrom=True, **kwargs) + \ + " WHERE " + \ + self.process(join.onclause, **kwargs) + ) + + def visit_select(self, select, asfrom=False, parens=True, + iswrapper=False, fromhints=None, + compound_index=1, **kwargs): + + entry = self.stack and self.stack[-1] or {} + + existingfroms = entry.get('from', None) + + froms = select._get_display_froms(existingfroms) + + correlate_froms = set(sql._from_objects(*froms)) + + # TODO: might want to propagate existing froms for select(select(select)) + # where innermost select should correlate to outermost + # if existingfroms: + # correlate_froms = correlate_froms.union(existingfroms) + + self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper}) + + if compound_index==1 and not entry or entry.get('iswrapper', False): + column_clause_args = {'result_map':self.result_map} + else: + column_clause_args = {} + + # the actual list of columns to print in the SELECT column list. + inner_columns = [ + c for c in [ + self.process( + self.label_select_column(select, co, asfrom=asfrom), + within_columns_clause=True, + **column_clause_args) + for co in util.unique_list(select.inner_columns) + ] + if c is not None + ] + + text = "SELECT " # we're off to a good start ! + + if select._hints: + byfrom = dict([ + (from_, hinttext % {'name':self.process(from_, ashint=True)}) + for (from_, dialect), hinttext in + select._hints.iteritems() + if dialect in ('*', self.dialect.name) + ]) + hint_text = self.get_select_hint_text(byfrom) + if hint_text: + text += hint_text + " " + + if select._prefixes: + text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " " + text += self.get_select_precolumns(select) + text += ', '.join(inner_columns) + + # CSB7: If a JOIN is supposed to happen, it will result in the visit_join() + # method overriding the normal behavior and adding a WHERE clause instead of + # the typical ON clause. If so, we'll need to know that so we avoid adding yet + # another WHERE clause for whatever conditions we've got. + join_with_where = False + + if froms: + text += " \nFROM " + + if select._hints: + # TODO: CSB7: Obviously this will need to go away but I'd like to know when + # this code gets called so I can examine the situation more closely... + assert False + text += ', '.join([self.process(f, + asfrom=True, fromhints=byfrom, + **kwargs) + for f in froms]) + else: + # CSB7: Here's where the visit_join() call will add the WHERE clause... + join_with_where = True + text += ', '.join([self.process(f, + asfrom=True, **kwargs) + for f in froms]) + else: + text += self.default_from() + + if select._whereclause is not None: + t = self.process(select._whereclause, **kwargs) + if t: + # CSB7: And here's where we avoid adding a second WHERE clause in the + # event that we already have one... + if join_with_where: + text += " AND (%s)" % t + else: + text += " \nWHERE " + t + + if select._group_by_clause.clauses: + group_by = self.process(select._group_by_clause, **kwargs) + if group_by: + text += " GROUP BY " + group_by + + if select._having is not None: + t = self.process(select._having, **kwargs) + if t: + text += " \nHAVING " + t + + if select._order_by_clause.clauses: + text += self.order_by_clause(select, **kwargs) + if select._limit is not None or select._offset is not None: + text += self.limit_clause(select) + if select.for_update: + text += self.for_update_clause(select) + + self.stack.pop(-1) + + if asfrom and parens: + return "(" + text + ")" + else: + return text + + +class SQLBase7Dialect(DefaultDialect): name = 'sqlbase7' + max_identifier_length = 18 + + # CSB7: Needed to override JOIN clauses. + statement_compiler = SQLBase7Compiler + type_map = { 'CHAR' : types.CHAR, 'DATE' : types.DATE, @@ -45,19 +187,29 @@ class SQLBase7Dialect(default.DefaultDialect): def _check_unicode_returns(self, connection): return False - def get_table_names(self, schema=None, connection=None): - cursor = schema.connection.cursor() + def get_table_names(self, connection, schema=None, **kw): + if schema is None: + schema = '' + else: + schema = '%s.' % schema + + cursor = connection.connection.cursor() table_names = [row.NAME for row in cursor.execute( - "SELECT NAME FROM SYSADM.SYSTABLES WHERE REMARKS IS NOT NULL" + "SELECT NAME FROM %sSYSTABLES WHERE REMARKS IS NOT NULL" % schema )] cursor.close() return table_names def get_columns(self, connection, table_name, schema=None, **kw): + if schema is None: + schema = '' + else: + schema = '%s.' % schema + cursor = connection.connection.cursor() columns = [] - for row in cursor.execute("SELECT NAME,COLTYPE,NULLS FROM SYSADM.SYSCOLUMNS WHERE TBNAME = '%s'" % table_name): + for row in cursor.execute("SELECT NAME,COLTYPE,NULLS FROM %sSYSCOLUMNS WHERE TBNAME = '%s'" % (schema, table_name)): columns.append({ 'name' : row.NAME, @@ -71,9 +223,14 @@ class SQLBase7Dialect(default.DefaultDialect): return columns def get_primary_keys(self, connection, table_name, schema=None, **kw): + if schema is None: + schema = '' + else: + schema = '%s.' % schema + cursor = connection.connection.cursor() primary_keys = [row.COLNAME for row in cursor.execute( - "SELECT COLNAME FROM SYSADM.SYSPKCONSTRAINTS WHERE NAME = '%s' ORDER BY PKCOLSEQNUM" % table_name + "SELECT COLNAME FROM %sSYSPKCONSTRAINTS WHERE NAME = '%s' ORDER BY PKCOLSEQNUM" % (schema, table_name) )] cursor.close() return primary_keys