Updated the SQLCompiler subclass to use a simplified version of the Oracle dialect's method for handling JOIN clauses. Also overrode the 'ilike' operation to use '@lower' function calls.

This commit is contained in:
Lance Edgar 2010-04-21 15:55:30 -05:00
parent b8493f6341
commit 39ddf36b59

View file

@ -24,145 +24,51 @@
from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy import types from sqlalchemy import types, and_
from sqlalchemy.sql.compiler import SQLCompiler from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.expression import Join
# CSB7: These are only needed for the visit_select() compiler method override below.
from sqlalchemy.sql import expression as sql
from sqlalchemy import util
class SQLBase7Compiler(SQLCompiler): class SQLBase7Compiler(SQLCompiler):
# I _really_ hate to override this stuff, as it's pretty well full of things # Most of the code below was copied from the Oracle dialect. Thanks to Michael Bayer
# I haven't yet taken the time or brain sterroids to understand, but nonetheless # for pointing that out. Oh, and for writing SQLAlchemy; that was pretty cool.
# it does seem to accomplish the goal of changing the way JOINs happen for good
# ol' Centure SQLBase. I've marked relevant comments below with the CSB7 tag.
def visit_join(self, join, asfrom=False, **kwargs): def visit_join(self, join, **kwargs):
# CSB7: Same as base method, only "JOIN" is replaced with "," kwargs['asfrom'] = True
# and "ON" is replaced with "WHERE". return self.process(join.left, **kwargs) + ", " + self.process(join.right, **kwargs)
return (self.process(join.left, asfrom=True, **kwargs) + \
", " + \
self.process(join.right, asfrom=True, **kwargs) + \
" WHERE " + \
self.process(join.onclause, **kwargs)
)
def visit_select(self, select, asfrom=False, parens=True, def visit_select(self, select, **kwargs):
iswrapper=False, fromhints=None, froms = select._get_display_froms()
compound_index=1, **kwargs): whereclause = self._get_join_whereclause(froms)
if whereclause is not None:
entry = self.stack and self.stack[-1] or {} select = select.where(whereclause)
existingfroms = entry.get('from', None)
froms = select._get_display_froms(existingfroms)
correlate_froms = set(sql._from_objects(*froms))
# TODO: might want to propagate existing froms for select(select(select))
# where innermost select should correlate to outermost
# if existingfroms:
# correlate_froms = correlate_froms.union(existingfroms)
self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper})
if compound_index==1 and not entry or entry.get('iswrapper', False):
column_clause_args = {'result_map':self.result_map}
else:
column_clause_args = {}
# the actual list of columns to print in the SELECT column list.
inner_columns = [
c for c in [
self.process(
self.label_select_column(select, co, asfrom=asfrom),
within_columns_clause=True,
**column_clause_args)
for co in util.unique_list(select.inner_columns)
]
if c is not None
]
text = "SELECT " # we're off to a good start !
if select._hints:
byfrom = dict([
(from_, hinttext % {'name':self.process(from_, ashint=True)})
for (from_, dialect), hinttext in
select._hints.iteritems()
if dialect in ('*', self.dialect.name)
])
hint_text = self.get_select_hint_text(byfrom)
if hint_text:
text += hint_text + " "
if select._prefixes:
text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " "
text += self.get_select_precolumns(select)
text += ', '.join(inner_columns)
# CSB7: If a JOIN is supposed to happen, it will result in the visit_join()
# method overriding the normal behavior and adding a WHERE clause instead of
# the typical ON clause. If so, we'll need to know that so we avoid adding yet
# another WHERE clause for whatever conditions we've got.
join_with_where = False
if froms:
text += " \nFROM "
if select._hints: kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
# TODO: CSB7: Obviously this will need to go away but I'd like to know when return SQLCompiler.visit_select(self, select, **kwargs)
# this code gets called so I can examine the situation more closely...
assert False
text += ', '.join([self.process(f,
asfrom=True, fromhints=byfrom,
**kwargs)
for f in froms])
else:
# CSB7: Here's where the visit_join() call will add the WHERE clause...
join_with_where = True
text += ', '.join([self.process(f,
asfrom=True, **kwargs)
for f in froms])
else:
text += self.default_from()
if select._whereclause is not None: def _get_join_whereclause(self, froms):
t = self.process(select._whereclause, **kwargs) clauses = []
if t:
# CSB7: And here's where we avoid adding a second WHERE clause in the
# event that we already have one...
if join_with_where:
text += " AND (%s)" % t
else:
text += " \nWHERE " + t
if select._group_by_clause.clauses: def visit_join(join):
group_by = self.process(select._group_by_clause, **kwargs) clauses.append(join.onclause)
if group_by: for j in join.left, join.right:
text += " GROUP BY " + group_by if isinstance(j, Join):
visit_join(j)
for f in froms:
if isinstance(f, Join):
visit_join(f)
return and_(*clauses)
if select._having is not None: def visit_ilike_op(self, binary, **kw):
t = self.process(select._having, **kwargs) escape = binary.modifiers.get("escape", None)
if t: return '@lower(%s) LIKE @lower(%s)' % (
text += " \nHAVING " + t self.process(binary.left, **kw),
self.process(binary.right, **kw)) \
+ (escape and ' ESCAPE \'%s\'' % escape or '')
if select._order_by_clause.clauses:
text += self.order_by_clause(select, **kwargs)
if select._limit is not None or select._offset is not None:
text += self.limit_clause(select)
if select.for_update:
text += self.for_update_clause(select)
self.stack.pop(-1)
if asfrom and parens:
return "(" + text + ")"
else:
return text
class SQLBase7Dialect(DefaultDialect): class SQLBase7Dialect(DefaultDialect):
@ -170,7 +76,6 @@ class SQLBase7Dialect(DefaultDialect):
max_identifier_length = 18 max_identifier_length = 18
# CSB7: Needed to override JOIN clauses.
statement_compiler = SQLBase7Compiler statement_compiler = SQLBase7Compiler
type_map = { type_map = {