diff --git a/setup.py b/setup.py index eb6ec34..c3b5602 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,7 @@ specific version of this database, that version being 7.5.1. packages = find_packages(), install_requires = [ - 'SQLAlchemy', + 'SQLAlchemy>0.5.2', ], entry_points = { diff --git a/sqlbase7_sa/_version.py b/sqlbase7_sa/_version.py index 729541a..5bb4149 100644 --- a/sqlbase7_sa/_version.py +++ b/sqlbase7_sa/_version.py @@ -23,4 +23,4 @@ ################################################################################ -__version__ = '0.1b1' +__version__ = '0.1b5' diff --git a/sqlbase7_sa/sqlbase7.py b/sqlbase7_sa/sqlbase7.py index 93c8ef3..690395c 100644 --- a/sqlbase7_sa/sqlbase7.py +++ b/sqlbase7_sa/sqlbase7.py @@ -26,6 +26,7 @@ from sqlalchemy.engine.default import DefaultDialect from sqlalchemy import types, and_ from sqlalchemy.sql.expression import Join +from sqlalchemy.sql import visitors, operators, ClauseElement import sqlalchemy @@ -56,7 +57,12 @@ class SQLBase7Compiler(CompilerBase): return self.process(join.left, **kwargs) + ", " + self.process(join.right, **kwargs) def visit_select(self, select, **kwargs): - froms = select._get_display_froms() + if self.stack and 'from' in self.stack[-1]: + existingfroms = self.stack[-1]['from'] + else: + existingfroms = None + + froms = select._get_display_froms(existingfroms) whereclause = self._get_join_whereclause(froms) if whereclause is not None: select = select.where(whereclause) @@ -71,7 +77,16 @@ class SQLBase7Compiler(CompilerBase): clauses = [] def visit_join(join): - clauses.append(join.onclause) + if join.isouter: + def visit_binary(binary): + if binary.operator == operators.eq: + if binary.left.table is join.right: + binary.left = _OuterJoinColumn(binary.left) + elif binary.right.table is join.right: + binary.right = _OuterJoinColumn(binary.right) + clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary})) + else: + clauses.append(join.onclause) for j in join.left, join.right: if isinstance(j, Join): visit_join(j) @@ -84,6 +99,16 @@ class SQLBase7Compiler(CompilerBase): return and_(*clauses) return None + def visit_outer_join_column(self, vc): + return self.process(vc.column) + "(+)" + + +class _OuterJoinColumn(ClauseElement): + __visit_name__ = 'outer_join_column' + + def __init__(self, column): + self.column = column + class SQLBase7Dialect(DefaultDialect):