diff --git a/sqlbase7_sa/base.py b/sqlbase7_sa/base.py index ec35e9d..d0faa21 100644 --- a/sqlbase7_sa/base.py +++ b/sqlbase7_sa/base.py @@ -24,145 +24,51 @@ from sqlalchemy.engine.default import DefaultDialect -from sqlalchemy import types +from sqlalchemy import types, and_ from sqlalchemy.sql.compiler import SQLCompiler - -# CSB7: These are only needed for the visit_select() compiler method override below. -from sqlalchemy.sql import expression as sql -from sqlalchemy import util +from sqlalchemy.sql.expression import Join class SQLBase7Compiler(SQLCompiler): - # I _really_ hate to override this stuff, as it's pretty well full of things - # I haven't yet taken the time or brain sterroids to understand, but nonetheless - # it does seem to accomplish the goal of changing the way JOINs happen for good - # ol' Centure SQLBase. I've marked relevant comments below with the CSB7 tag. + # Most of the code below was copied from the Oracle dialect. Thanks to Michael Bayer + # for pointing that out. Oh, and for writing SQLAlchemy; that was pretty cool. - def visit_join(self, join, asfrom=False, **kwargs): - # CSB7: Same as base method, only "JOIN" is replaced with "," - # and "ON" is replaced with "WHERE". - return (self.process(join.left, asfrom=True, **kwargs) + \ - ", " + \ - self.process(join.right, asfrom=True, **kwargs) + \ - " WHERE " + \ - self.process(join.onclause, **kwargs) - ) + def visit_join(self, join, **kwargs): + kwargs['asfrom'] = True + return self.process(join.left, **kwargs) + ", " + self.process(join.right, **kwargs) - def visit_select(self, select, asfrom=False, parens=True, - iswrapper=False, fromhints=None, - compound_index=1, **kwargs): - - entry = self.stack and self.stack[-1] or {} - - existingfroms = entry.get('from', None) - - froms = select._get_display_froms(existingfroms) - - correlate_froms = set(sql._from_objects(*froms)) - - # TODO: might want to propagate existing froms for select(select(select)) - # where innermost select should correlate to outermost - # if existingfroms: - # correlate_froms = correlate_froms.union(existingfroms) - - self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper}) - - if compound_index==1 and not entry or entry.get('iswrapper', False): - column_clause_args = {'result_map':self.result_map} - else: - column_clause_args = {} - - # the actual list of columns to print in the SELECT column list. - inner_columns = [ - c for c in [ - self.process( - self.label_select_column(select, co, asfrom=asfrom), - within_columns_clause=True, - **column_clause_args) - for co in util.unique_list(select.inner_columns) - ] - if c is not None - ] - - text = "SELECT " # we're off to a good start ! - - if select._hints: - byfrom = dict([ - (from_, hinttext % {'name':self.process(from_, ashint=True)}) - for (from_, dialect), hinttext in - select._hints.iteritems() - if dialect in ('*', self.dialect.name) - ]) - hint_text = self.get_select_hint_text(byfrom) - if hint_text: - text += hint_text + " " - - if select._prefixes: - text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " " - text += self.get_select_precolumns(select) - text += ', '.join(inner_columns) - - # CSB7: If a JOIN is supposed to happen, it will result in the visit_join() - # method overriding the normal behavior and adding a WHERE clause instead of - # the typical ON clause. If so, we'll need to know that so we avoid adding yet - # another WHERE clause for whatever conditions we've got. - join_with_where = False - - if froms: - text += " \nFROM " + def visit_select(self, select, **kwargs): + froms = select._get_display_froms() + whereclause = self._get_join_whereclause(froms) + if whereclause is not None: + select = select.where(whereclause) - if select._hints: - # TODO: CSB7: Obviously this will need to go away but I'd like to know when - # this code gets called so I can examine the situation more closely... - assert False - text += ', '.join([self.process(f, - asfrom=True, fromhints=byfrom, - **kwargs) - for f in froms]) - else: - # CSB7: Here's where the visit_join() call will add the WHERE clause... - join_with_where = True - text += ', '.join([self.process(f, - asfrom=True, **kwargs) - for f in froms]) - else: - text += self.default_from() + kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) + return SQLCompiler.visit_select(self, select, **kwargs) - if select._whereclause is not None: - t = self.process(select._whereclause, **kwargs) - if t: - # CSB7: And here's where we avoid adding a second WHERE clause in the - # event that we already have one... - if join_with_where: - text += " AND (%s)" % t - else: - text += " \nWHERE " + t + def _get_join_whereclause(self, froms): + clauses = [] - if select._group_by_clause.clauses: - group_by = self.process(select._group_by_clause, **kwargs) - if group_by: - text += " GROUP BY " + group_by + def visit_join(join): + clauses.append(join.onclause) + for j in join.left, join.right: + if isinstance(j, Join): + visit_join(j) + + for f in froms: + if isinstance(f, Join): + visit_join(f) + + return and_(*clauses) - if select._having is not None: - t = self.process(select._having, **kwargs) - if t: - text += " \nHAVING " + t + def visit_ilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '@lower(%s) LIKE @lower(%s)' % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') - if select._order_by_clause.clauses: - text += self.order_by_clause(select, **kwargs) - if select._limit is not None or select._offset is not None: - text += self.limit_clause(select) - if select.for_update: - text += self.for_update_clause(select) - - self.stack.pop(-1) - - if asfrom and parens: - return "(" + text + ")" - else: - return text - class SQLBase7Dialect(DefaultDialect): @@ -170,7 +76,6 @@ class SQLBase7Dialect(DefaultDialect): max_identifier_length = 18 - # CSB7: Needed to override JOIN clauses. statement_compiler = SQLBase7Compiler type_map = {