fix: add basic data type coercion for CSV -> SQLAlchemy import

this should support common scenarios; may need to be more flexible if
customizations are needed but we'll see
This commit is contained in:
Lance Edgar 2025-10-20 16:21:21 -05:00
parent 8c3948ff33
commit c38cd2c179
2 changed files with 416 additions and 35 deletions

View file

@ -25,13 +25,17 @@ Importing from CSV
""" """
import csv import csv
import datetime
import decimal
import logging import logging
import uuid as _uuid import uuid as _uuid
from collections import OrderedDict from collections import OrderedDict
import sqlalchemy as sa
from sqlalchemy_utils.functions import get_primary_keys from sqlalchemy_utils.functions import get_primary_keys
from wuttjamaican.db.util import make_topo_sortkey, UUID from wuttjamaican.db.util import make_topo_sortkey, UUID
from wuttjamaican.util import parse_bool
from .base import FromFile from .base import FromFile
from .handlers import FromFileHandler from .handlers import FromFileHandler
@ -144,38 +148,48 @@ class FromCsvToSqlalchemyMixin: # pylint: disable=too-few-public-methods
""" """
Mixin class for CSV SQLAlchemy ORM :term:`importers <importer>`. Mixin class for CSV SQLAlchemy ORM :term:`importers <importer>`.
Meant to be used by :class:`FromCsvToSqlalchemyHandlerMixin`. Such importers are generated automatically by
:class:`FromCsvToSqlalchemyHandlerMixin`, so you won't typically
reference this mixin class directly.
This mixin adds some logic to better handle ``uuid`` key fields This mixin adds data type coercion for each field value read from
which are of :class:`~wuttjamaican:wuttjamaican.db.util.UUID` data the CSV file; see :meth:`normalize_source_object()`.
type (i.e. on the target side). Namely, when reading ``uuid``
values as string from CSV, convert them to proper UUID instances, .. attribute:: coercers
so the key matching between source and target will behave as
expected. Dict of coercer functions, keyed by field name. This is an
empty dict by default; however typical usage does not require
you to set it, as it's auto-provided from
:func:`make_coercers()`.
Each coercer function should accept a single value, and return
the coerced value, e.g.::
def coerce_int(val):
return int(val)
""" """
def __init__(self, config, **kwargs): coercers = {}
super().__init__(config, **kwargs)
# nb. keep track of any key fields which use proper UUID type def normalize_source_object(self, obj):
self.uuid_keys = [] """
for field in self.get_keys(): Normalize a source record from CSV input file. See also the
attr = getattr(self.model_class, field) parent docs for
if len(attr.prop.columns) == 1: :meth:`wuttasync.importing.base.Importer.normalize_source_object()`.
if isinstance(attr.prop.columns[0].type, UUID):
self.uuid_keys.append(field)
def normalize_source_object(self, obj): # pylint: disable=empty-docstring This will invoke the appropriate coercer function for each
""" """ field, according to :attr:`coercers`.
data = dict(obj)
# nb. convert to proper UUID values so key matching will work :param obj: Raw data record (dict) from CSV reader.
# properly, where applicable
for key in self.uuid_keys:
uuid = data[key]
if uuid and not isinstance(uuid, _uuid.UUID):
data[key] = _uuid.UUID(uuid)
:returns: Final data dict for the record.
"""
data = {}
for field in self.fields:
value = obj[field]
if field in self.coercers:
value = self.coercers[field](value)
data[field] = value
return data return data
@ -267,6 +281,9 @@ class FromCsvToSqlalchemyHandlerMixin:
* :attr:`FromImporterBase` * :attr:`FromImporterBase`
* :attr:`ToImporterBase` * :attr:`ToImporterBase`
And :attr:`~FromCsvToSqlalchemyMixin.coercers` will be set on
the class, to the result of :func:`make_coercers()`.
:param model_class: A data model class. :param model_class: A data model class.
:param name: The "model name" for the importer/exporter. New :param name: The "model name" for the importer/exporter. New
@ -282,6 +299,7 @@ class FromCsvToSqlalchemyHandlerMixin:
{ {
"model_class": model_class, "model_class": model_class,
"key": list(get_primary_keys(model_class)), "key": list(get_primary_keys(model_class)),
"coercers": make_coercers(model_class),
}, },
) )
@ -299,3 +317,149 @@ class FromCsvToWutta(FromCsvToSqlalchemyHandlerMixin, FromFileHandler, ToWuttaHa
def get_target_model(self): # pylint: disable=empty-docstring def get_target_model(self): # pylint: disable=empty-docstring
""" """ """ """
return self.app.model return self.app.model
##############################
# coercion utilities
##############################
def make_coercers(model_class):
"""
Returns a dict of coercer functions for use by
:meth:`~FromCsvToSqlalchemyMixin.normalize_source_object()`.
This is called automatically by
:meth:`~FromCsvToSqlalchemyHandlerMixin.make_importer_factory()`,
in which case the result is assigned to
:attr:`~FromCsvToSqlalchemyMixin.coercers` on the importer class.
It will iterate over all mapped fields, and call
:func:`make_coercer()` for each.
:param model_class: SQLAlchemy mapped class, e.g.
:class:`wuttjamaican:wuttjamaican.db.model.base.Person`.
:returns: Dict of coercer functions, keyed by field name.
"""
mapper = sa.inspect(model_class)
fields = list(mapper.columns.keys())
coercers = {}
for field in fields:
attr = getattr(model_class, field)
coercers[field] = make_coercer(attr)
return coercers
def make_coercer(attr): # pylint: disable=too-many-return-statements
"""
Returns a coercer function suitable for use by
:meth:`~FromCsvToSqlalchemyMixin.normalize_source_object()`.
This is typically called from :func:`make_coercers()`. The
resulting function will coerce values to the data type defined by
the given attribute, e.g.::
def coerce_int(val):
return int(val)
:param attr: SQLAlchemy mapped attribute, e.g.
:attr:`wuttjamaican:wuttjamaican.db.model.upgrades.Upgrade.exit_code`.
:returns: Coercer function based on mapped attribute data type.
"""
assert len(attr.prop.columns) == 1
column = attr.prop.columns[0]
# UUID
if isinstance(attr.type, UUID):
return coerce_uuid
# Boolean
if isinstance(attr.type, sa.Boolean):
if column.nullable:
return coerce_boolean_nullable
return coerce_boolean
# DateTime
if isinstance(attr.type, sa.DateTime) or (
hasattr(attr.type, "impl") and isinstance(attr.type.impl, sa.DateTime)
):
return coerce_datetime
# Float
# nb. check this before decimal, since Numeric inherits from Float
if isinstance(attr.type, sa.Float):
return coerce_float
# Decimal
if isinstance(attr.type, sa.Numeric):
return coerce_decimal
# Integer
if isinstance(attr.type, sa.Integer):
return coerce_integer
# String
if isinstance(attr.type, sa.String):
if column.nullable:
return coerce_string_nullable
# do not coerce
return coerce_noop
def coerce_boolean(value): # pylint: disable=missing-function-docstring
return parse_bool(value)
def coerce_boolean_nullable(value): # pylint: disable=missing-function-docstring
if value == "":
return None
return coerce_boolean(value)
def coerce_datetime(value): # pylint: disable=missing-function-docstring
if value == "":
return None
try:
return datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
except ValueError:
return datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")
def coerce_decimal(value): # pylint: disable=missing-function-docstring
if value == "":
return None
return decimal.Decimal(value)
def coerce_float(value): # pylint: disable=missing-function-docstring
if value == "":
return None
return float(value)
def coerce_integer(value): # pylint: disable=missing-function-docstring
if value == "":
return None
return int(value)
def coerce_noop(value): # pylint: disable=missing-function-docstring
return value
def coerce_string_nullable(value): # pylint: disable=missing-function-docstring
if value == "":
return None
return value
def coerce_uuid(value): # pylint: disable=missing-function-docstring
if value == "":
return None
return _uuid.UUID(value)

View file

@ -1,9 +1,15 @@
# -*- coding: utf-8; -*- # -*- coding: utf-8; -*-
import csv import csv
import datetime
import decimal
import uuid as _uuid import uuid as _uuid
from unittest import TestCase
from unittest.mock import patch from unittest.mock import patch
import sqlalchemy as sa
from sqlalchemy import orm
from wuttjamaican.testing import DataTestCase from wuttjamaican.testing import DataTestCase
from wuttasync.importing import ( from wuttasync.importing import (
@ -115,14 +121,15 @@ class TestFromCsvToSqlalchemyMixin(DataTestCase):
def test_constructor(self): def test_constructor(self):
model = self.app.model model = self.app.model
# no uuid keys # no coercers
imp = self.make_importer(model_class=model.Setting) imp = self.make_importer(model_class=model.Setting)
self.assertEqual(imp.uuid_keys, []) self.assertEqual(imp.coercers, {})
# typical # typical
# nb. as of now Upgrade is the only table using proper UUID imp = self.make_importer(
imp = self.make_importer(model_class=model.Upgrade) model_class=model.Upgrade, coercers=mod.make_coercers(model.Setting)
self.assertEqual(imp.uuid_keys, ["uuid"]) )
self.assertEqual(len(imp.coercers), 2)
def test_normalize_source_object(self): def test_normalize_source_object(self):
model = self.app.model model = self.app.model
@ -133,13 +140,14 @@ class TestFromCsvToSqlalchemyMixin(DataTestCase):
self.assertEqual(result, {"name": "foo", "value": "bar"}) self.assertEqual(result, {"name": "foo", "value": "bar"})
# source has proper UUID # source has proper UUID
# nb. as of now Upgrade is the only table using proper UUID
imp = self.make_importer( imp = self.make_importer(
model_class=model.Upgrade, fields=["uuid", "description"] model_class=model.Upgrade,
fields=["uuid", "description"],
coercers=mod.make_coercers(model.Upgrade),
) )
result = imp.normalize_source_object( result = imp.normalize_source_object(
{ {
"uuid": _uuid.UUID("06753693-d892-77f0-8000-ce71bf7ebbba"), "uuid": "06753693-d892-77f0-8000-ce71bf7ebbba",
"description": "testing", "description": "testing",
} }
) )
@ -152,9 +160,10 @@ class TestFromCsvToSqlalchemyMixin(DataTestCase):
) )
# source has string uuid # source has string uuid
# nb. as of now Upgrade is the only table using proper UUID
imp = self.make_importer( imp = self.make_importer(
model_class=model.Upgrade, fields=["uuid", "description"] model_class=model.Upgrade,
fields=["uuid", "description"],
coercers=mod.make_coercers(model.Upgrade),
) )
result = imp.normalize_source_object( result = imp.normalize_source_object(
{"uuid": "06753693d89277f08000ce71bf7ebbba", "description": "testing"} {"uuid": "06753693d89277f08000ce71bf7ebbba", "description": "testing"}
@ -167,6 +176,33 @@ class TestFromCsvToSqlalchemyMixin(DataTestCase):
}, },
) )
# source has boolean true/false
imp = self.make_importer(
model_class=model.Upgrade,
fields=["uuid", "executing"],
coercers=mod.make_coercers(model.Upgrade),
)
result = imp.normalize_source_object(
{"uuid": "06753693d89277f08000ce71bf7ebbba", "executing": "True"}
)
self.assertEqual(
result,
{
"uuid": _uuid.UUID("06753693-d892-77f0-8000-ce71bf7ebbba"),
"executing": True,
},
)
result = imp.normalize_source_object(
{"uuid": "06753693d89277f08000ce71bf7ebbba", "executing": "false"}
)
self.assertEqual(
result,
{
"uuid": _uuid.UUID("06753693-d892-77f0-8000-ce71bf7ebbba"),
"executing": False,
},
)
class MockMixinHandler(mod.FromCsvToSqlalchemyHandlerMixin, ToSqlalchemyHandler): class MockMixinHandler(mod.FromCsvToSqlalchemyHandlerMixin, ToSqlalchemyHandler):
ToImporterBase = ToSqlalchemy ToImporterBase = ToSqlalchemy
@ -207,6 +243,7 @@ class TestFromCsvToSqlalchemyHandlerMixin(DataTestCase):
factory = handler.make_importer_factory(model.Setting, "Setting") factory = handler.make_importer_factory(model.Setting, "Setting")
self.assertTrue(issubclass(factory, mod.FromCsv)) self.assertTrue(issubclass(factory, mod.FromCsv))
self.assertTrue(issubclass(factory, ToSqlalchemy)) self.assertTrue(issubclass(factory, ToSqlalchemy))
self.assertTrue(isinstance(factory.coercers, dict))
class TestFromCsvToWutta(DataTestCase): class TestFromCsvToWutta(DataTestCase):
@ -217,3 +254,183 @@ class TestFromCsvToWutta(DataTestCase):
def test_get_target_model(self): def test_get_target_model(self):
handler = self.make_handler() handler = self.make_handler()
self.assertIs(handler.get_target_model(), self.app.model) self.assertIs(handler.get_target_model(), self.app.model)
Base = orm.declarative_base()
class Example(Base):
__tablename__ = "example"
id = sa.Column(sa.Integer(), primary_key=True, nullable=False)
optional_id = sa.Column(sa.Integer(), nullable=True)
name = sa.Column(sa.String(length=100), nullable=False)
optional_name = sa.Column(sa.String(length=100), nullable=True)
flag = sa.Column(sa.Boolean(), nullable=False)
optional_flag = sa.Column(sa.Boolean(), nullable=True)
dt = sa.Column(sa.DateTime(), nullable=False)
optional_dt = sa.Column(sa.DateTime(), nullable=True)
dec = sa.Column(sa.Numeric(scale=8, precision=2), nullable=False)
optional_dec = sa.Column(sa.Numeric(scale=8, precision=2), nullable=True)
flt = sa.Column(sa.Float(), nullable=False)
optional_flt = sa.Column(sa.Float(), nullable=True)
class TestMakeCoercers(TestCase):
def test_basic(self):
coercers = mod.make_coercers(Example)
self.assertEqual(len(coercers), 12)
self.assertIs(coercers["id"], mod.coerce_integer)
self.assertIs(coercers["optional_id"], mod.coerce_integer)
self.assertIs(coercers["name"], mod.coerce_noop)
self.assertIs(coercers["optional_name"], mod.coerce_string_nullable)
self.assertIs(coercers["flag"], mod.coerce_boolean)
self.assertIs(coercers["optional_flag"], mod.coerce_boolean_nullable)
self.assertIs(coercers["dt"], mod.coerce_datetime)
self.assertIs(coercers["optional_dt"], mod.coerce_datetime)
self.assertIs(coercers["dec"], mod.coerce_decimal)
self.assertIs(coercers["optional_dec"], mod.coerce_decimal)
self.assertIs(coercers["flt"], mod.coerce_float)
self.assertIs(coercers["optional_flt"], mod.coerce_float)
class TestMakeCoercer(TestCase):
def test_basic(self):
func = mod.make_coercer(Example.id)
self.assertIs(func, mod.coerce_integer)
func = mod.make_coercer(Example.optional_id)
self.assertIs(func, mod.coerce_integer)
func = mod.make_coercer(Example.name)
self.assertIs(func, mod.coerce_noop)
func = mod.make_coercer(Example.optional_name)
self.assertIs(func, mod.coerce_string_nullable)
func = mod.make_coercer(Example.flag)
self.assertIs(func, mod.coerce_boolean)
func = mod.make_coercer(Example.optional_flag)
self.assertIs(func, mod.coerce_boolean_nullable)
func = mod.make_coercer(Example.dt)
self.assertIs(func, mod.coerce_datetime)
func = mod.make_coercer(Example.optional_dt)
self.assertIs(func, mod.coerce_datetime)
func = mod.make_coercer(Example.dec)
self.assertIs(func, mod.coerce_decimal)
func = mod.make_coercer(Example.optional_dec)
self.assertIs(func, mod.coerce_decimal)
func = mod.make_coercer(Example.flt)
self.assertIs(func, mod.coerce_float)
func = mod.make_coercer(Example.optional_flt)
self.assertIs(func, mod.coerce_float)
class TestCoercers(TestCase):
def test_coerce_boolean(self):
self.assertTrue(mod.coerce_boolean("true"))
self.assertTrue(mod.coerce_boolean("1"))
self.assertTrue(mod.coerce_boolean("yes"))
self.assertFalse(mod.coerce_boolean("false"))
self.assertFalse(mod.coerce_boolean("0"))
self.assertFalse(mod.coerce_boolean("no"))
self.assertFalse(mod.coerce_boolean(""))
def test_coerce_boolean_nullable(self):
self.assertTrue(mod.coerce_boolean_nullable("true"))
self.assertTrue(mod.coerce_boolean_nullable("1"))
self.assertTrue(mod.coerce_boolean_nullable("yes"))
self.assertFalse(mod.coerce_boolean_nullable("false"))
self.assertFalse(mod.coerce_boolean_nullable("0"))
self.assertFalse(mod.coerce_boolean_nullable("no"))
self.assertIsNone(mod.coerce_boolean_nullable(""))
def test_coerce_datetime(self):
self.assertIsNone(mod.coerce_datetime(""))
value = mod.coerce_datetime("2025-10-19 20:56:00")
self.assertIsInstance(value, datetime.datetime)
self.assertEqual(value, datetime.datetime(2025, 10, 19, 20, 56))
value = mod.coerce_datetime("2025-10-19 20:56:00.1234")
self.assertIsInstance(value, datetime.datetime)
self.assertEqual(value, datetime.datetime(2025, 10, 19, 20, 56, 0, 123400))
self.assertRaises(ValueError, mod.coerce_datetime, "XXX")
def test_coerce_decimal(self):
self.assertIsNone(mod.coerce_decimal(""))
value = mod.coerce_decimal("42")
self.assertIsInstance(value, decimal.Decimal)
self.assertEqual(value, decimal.Decimal("42.0"))
self.assertEqual(value, 42)
value = mod.coerce_decimal("42.0")
self.assertIsInstance(value, decimal.Decimal)
self.assertEqual(value, decimal.Decimal("42.0"))
self.assertEqual(value, 42)
self.assertRaises(decimal.InvalidOperation, mod.coerce_decimal, "XXX")
def test_coerce_float(self):
self.assertEqual(mod.coerce_float("42"), 42.0)
self.assertEqual(mod.coerce_float("42.0"), 42.0)
self.assertIsNone(mod.coerce_float(""))
self.assertRaises(ValueError, mod.coerce_float, "XXX")
def test_coerce_integer(self):
self.assertEqual(mod.coerce_integer("42"), 42)
self.assertRaises(ValueError, mod.coerce_integer, "42.0")
self.assertIsNone(mod.coerce_integer(""))
self.assertRaises(ValueError, mod.coerce_integer, "XXX")
def test_coerce_noop(self):
self.assertEqual(mod.coerce_noop(""), "")
self.assertEqual(mod.coerce_noop("42"), "42")
self.assertEqual(mod.coerce_noop("XXX"), "XXX")
def test_coerce_string_nullable(self):
self.assertIsNone(mod.coerce_string_nullable(""))
self.assertEqual(mod.coerce_string_nullable("42"), "42")
self.assertEqual(mod.coerce_string_nullable("XXX"), "XXX")
def test_coerce_uuid(self):
self.assertIsNone(mod.coerce_uuid(""))
uuid = mod.coerce_uuid("06753693d89277f08000ce71bf7ebbba")
self.assertIsInstance(uuid, _uuid.UUID)
self.assertEqual(uuid, _uuid.UUID("06753693d89277f08000ce71bf7ebbba"))
self.assertEqual(uuid.hex, "06753693d89277f08000ce71bf7ebbba")
uuid = mod.coerce_uuid("06753693-d892-77f0-8000-ce71bf7ebbba")
self.assertIsInstance(uuid, _uuid.UUID)
self.assertEqual(uuid, _uuid.UUID("06753693-d892-77f0-8000-ce71bf7ebbba"))
self.assertEqual(str(uuid), "06753693-d892-77f0-8000-ce71bf7ebbba")
self.assertEqual(uuid.hex, "06753693d89277f08000ce71bf7ebbba")