diff --git a/src/wuttasync/importing/csv.py b/src/wuttasync/importing/csv.py index 1d6946d..ab0bf21 100644 --- a/src/wuttasync/importing/csv.py +++ b/src/wuttasync/importing/csv.py @@ -25,13 +25,17 @@ Importing from CSV """ import csv +import datetime +import decimal import logging import uuid as _uuid from collections import OrderedDict +import sqlalchemy as sa from sqlalchemy_utils.functions import get_primary_keys from wuttjamaican.db.util import make_topo_sortkey, UUID +from wuttjamaican.util import parse_bool from .base import FromFile from .handlers import FromFileHandler @@ -144,38 +148,48 @@ class FromCsvToSqlalchemyMixin: # pylint: disable=too-few-public-methods """ Mixin class for CSV → SQLAlchemy ORM :term:`importers <importer>`. - Meant to be used by :class:`FromCsvToSqlalchemyHandlerMixin`. + Such importers are generated automatically by + :class:`FromCsvToSqlalchemyHandlerMixin`, so you won't typically + reference this mixin class directly. - This mixin adds some logic to better handle ``uuid`` key fields - which are of :class:`~wuttjamaican:wuttjamaican.db.util.UUID` data - type (i.e. on the target side). Namely, when reading ``uuid`` - values as string from CSV, convert them to proper UUID instances, - so the key matching between source and target will behave as - expected. + This mixin adds data type coercion for each field value read from + the CSV file; see :meth:`normalize_source_object()`. + + .. attribute:: coercers + + Dict of coercer functions, keyed by field name. This is an + empty dict by default; however typical usage does not require + you to set it, as it's auto-provided from + :func:`make_coercers()`. + + Each coercer function should accept a single value, and return + the coerced value, e.g.:: + + def coerce_int(val): + return int(val) """ - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + coercers = {} - # nb. 
keep track of any key fields which use proper UUID type - self.uuid_keys = [] - for field in self.get_keys(): - attr = getattr(self.model_class, field) - if len(attr.prop.columns) == 1: - if isinstance(attr.prop.columns[0].type, UUID): - self.uuid_keys.append(field) + def normalize_source_object(self, obj): + """ + Normalize a source record from CSV input file. See also the + parent docs for + :meth:`wuttasync.importing.base.Importer.normalize_source_object()`. - def normalize_source_object(self, obj): # pylint: disable=empty-docstring - """ """ - data = dict(obj) + This will invoke the appropriate coercer function for each + field, according to :attr:`coercers`. - # nb. convert to proper UUID values so key matching will work - # properly, where applicable - for key in self.uuid_keys: - uuid = data[key] - if uuid and not isinstance(uuid, _uuid.UUID): - data[key] = _uuid.UUID(uuid) + :param obj: Raw data record (dict) from CSV reader. + :returns: Final data dict for the record. + """ + data = {} + for field in self.fields: + value = obj[field] + if field in self.coercers: + value = self.coercers[field](value) + data[field] = value return data @@ -267,6 +281,9 @@ class FromCsvToSqlalchemyHandlerMixin: * :attr:`FromImporterBase` * :attr:`ToImporterBase` + And :attr:`~FromCsvToSqlalchemyMixin.coercers` will be set on + the class, to the result of :func:`make_coercers()`. + :param model_class: A data model class. :param name: The "model name" for the importer/exporter. 
New @@ -282,6 +299,7 @@ class FromCsvToSqlalchemyHandlerMixin: { "model_class": model_class, "key": list(get_primary_keys(model_class)), + "coercers": make_coercers(model_class), }, ) @@ -299,3 +317,149 @@ class FromCsvToWutta(FromCsvToSqlalchemyHandlerMixin, FromFileHandler, ToWuttaHa def get_target_model(self): # pylint: disable=empty-docstring """ """ return self.app.model + + +############################## +# coercion utilities +############################## + + +def make_coercers(model_class): + """ + Returns a dict of coercer functions for use by + :meth:`~FromCsvToSqlalchemyMixin.normalize_source_object()`. + + This is called automatically by + :meth:`~FromCsvToSqlalchemyHandlerMixin.make_importer_factory()`, + in which case the result is assigned to + :attr:`~FromCsvToSqlalchemyMixin.coercers` on the importer class. + + It will iterate over all mapped fields, and call + :func:`make_coercer()` for each. + + :param model_class: SQLAlchemy mapped class, e.g. + :class:`wuttjamaican:wuttjamaican.db.model.base.Person`. + + :returns: Dict of coercer functions, keyed by field name. + """ + mapper = sa.inspect(model_class) + fields = list(mapper.columns.keys()) + + coercers = {} + for field in fields: + attr = getattr(model_class, field) + coercers[field] = make_coercer(attr) + + return coercers + + +def make_coercer(attr): # pylint: disable=too-many-return-statements + """ + Returns a coercer function suitable for use by + :meth:`~FromCsvToSqlalchemyMixin.normalize_source_object()`. + + This is typically called from :func:`make_coercers()`. The + resulting function will coerce values to the data type defined by + the given attribute, e.g.:: + + def coerce_int(val): + return int(val) + + :param attr: SQLAlchemy mapped attribute, e.g. + :attr:`wuttjamaican:wuttjamaican.db.model.upgrades.Upgrade.exit_code`. + + :returns: Coercer function based on mapped attribute data type. 
+ """ + assert len(attr.prop.columns) == 1 + column = attr.prop.columns[0] + + # UUID + if isinstance(attr.type, UUID): + return coerce_uuid + + # Boolean + if isinstance(attr.type, sa.Boolean): + if column.nullable: + return coerce_boolean_nullable + return coerce_boolean + + # DateTime + if isinstance(attr.type, sa.DateTime) or ( + hasattr(attr.type, "impl") and isinstance(attr.type.impl, sa.DateTime) + ): + return coerce_datetime + + # Float + # nb. check this before decimal, since Float inherits from Numeric + if isinstance(attr.type, sa.Float): + return coerce_float + + # Decimal + if isinstance(attr.type, sa.Numeric): + return coerce_decimal + + # Integer + if isinstance(attr.type, sa.Integer): + return coerce_integer + + # String + if isinstance(attr.type, sa.String): + if column.nullable: + return coerce_string_nullable + + # do not coerce + return coerce_noop + + +def coerce_boolean(value): # pylint: disable=missing-function-docstring + return parse_bool(value) + + +def coerce_boolean_nullable(value): # pylint: disable=missing-function-docstring + if value == "": + return None + return coerce_boolean(value) + + +def coerce_datetime(value): # pylint: disable=missing-function-docstring + if value == "": + return None + + try: + return datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S") + except ValueError: + return datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f") + + +def coerce_decimal(value): # pylint: disable=missing-function-docstring + if value == "": + return None + return decimal.Decimal(value) + + +def coerce_float(value): # pylint: disable=missing-function-docstring + if value == "": + return None + return float(value) + + +def coerce_integer(value): # pylint: disable=missing-function-docstring + if value == "": + return None + return int(value) + + +def coerce_noop(value): # pylint: disable=missing-function-docstring + return value + + +def coerce_string_nullable(value): # pylint: disable=missing-function-docstring + if value == "": + 
return None + return value + + +def coerce_uuid(value): # pylint: disable=missing-function-docstring + if value == "": + return None + return _uuid.UUID(value) diff --git a/tests/importing/test_csv.py b/tests/importing/test_csv.py index 8544d63..b3f0fad 100644 --- a/tests/importing/test_csv.py +++ b/tests/importing/test_csv.py @@ -1,9 +1,15 @@ # -*- coding: utf-8; -*- import csv +import datetime +import decimal import uuid as _uuid +from unittest import TestCase from unittest.mock import patch +import sqlalchemy as sa +from sqlalchemy import orm + from wuttjamaican.testing import DataTestCase from wuttasync.importing import ( @@ -115,14 +121,15 @@ class TestFromCsvToSqlalchemyMixin(DataTestCase): def test_constructor(self): model = self.app.model - # no uuid keys + # no coercers imp = self.make_importer(model_class=model.Setting) - self.assertEqual(imp.uuid_keys, []) + self.assertEqual(imp.coercers, {}) # typical - # nb. as of now Upgrade is the only table using proper UUID - imp = self.make_importer(model_class=model.Upgrade) - self.assertEqual(imp.uuid_keys, ["uuid"]) + imp = self.make_importer( + model_class=model.Upgrade, coercers=mod.make_coercers(model.Setting) + ) + self.assertEqual(len(imp.coercers), 2) def test_normalize_source_object(self): model = self.app.model @@ -133,13 +140,14 @@ class TestFromCsvToSqlalchemyMixin(DataTestCase): self.assertEqual(result, {"name": "foo", "value": "bar"}) # source has proper UUID - # nb. as of now Upgrade is the only table using proper UUID imp = self.make_importer( - model_class=model.Upgrade, fields=["uuid", "description"] + model_class=model.Upgrade, + fields=["uuid", "description"], + coercers=mod.make_coercers(model.Upgrade), ) result = imp.normalize_source_object( { - "uuid": _uuid.UUID("06753693-d892-77f0-8000-ce71bf7ebbba"), + "uuid": "06753693-d892-77f0-8000-ce71bf7ebbba", "description": "testing", } ) @@ -152,9 +160,10 @@ class TestFromCsvToSqlalchemyMixin(DataTestCase): ) # source has string uuid - # nb. 
as of now Upgrade is the only table using proper UUID imp = self.make_importer( - model_class=model.Upgrade, fields=["uuid", "description"] + model_class=model.Upgrade, + fields=["uuid", "description"], + coercers=mod.make_coercers(model.Upgrade), ) result = imp.normalize_source_object( {"uuid": "06753693d89277f08000ce71bf7ebbba", "description": "testing"} @@ -167,6 +176,33 @@ class TestFromCsvToSqlalchemyMixin(DataTestCase): }, ) + # source has boolean true/false + imp = self.make_importer( + model_class=model.Upgrade, + fields=["uuid", "executing"], + coercers=mod.make_coercers(model.Upgrade), + ) + result = imp.normalize_source_object( + {"uuid": "06753693d89277f08000ce71bf7ebbba", "executing": "True"} + ) + self.assertEqual( + result, + { + "uuid": _uuid.UUID("06753693-d892-77f0-8000-ce71bf7ebbba"), + "executing": True, + }, + ) + result = imp.normalize_source_object( + {"uuid": "06753693d89277f08000ce71bf7ebbba", "executing": "false"} + ) + self.assertEqual( + result, + { + "uuid": _uuid.UUID("06753693-d892-77f0-8000-ce71bf7ebbba"), + "executing": False, + }, + ) + class MockMixinHandler(mod.FromCsvToSqlalchemyHandlerMixin, ToSqlalchemyHandler): ToImporterBase = ToSqlalchemy @@ -207,6 +243,7 @@ class TestFromCsvToSqlalchemyHandlerMixin(DataTestCase): factory = handler.make_importer_factory(model.Setting, "Setting") self.assertTrue(issubclass(factory, mod.FromCsv)) self.assertTrue(issubclass(factory, ToSqlalchemy)) + self.assertTrue(isinstance(factory.coercers, dict)) class TestFromCsvToWutta(DataTestCase): @@ -217,3 +254,183 @@ class TestFromCsvToWutta(DataTestCase): def test_get_target_model(self): handler = self.make_handler() self.assertIs(handler.get_target_model(), self.app.model) + + +Base = orm.declarative_base() + + +class Example(Base): + __tablename__ = "example" + + id = sa.Column(sa.Integer(), primary_key=True, nullable=False) + optional_id = sa.Column(sa.Integer(), nullable=True) + + name = sa.Column(sa.String(length=100), nullable=False) + 
optional_name = sa.Column(sa.String(length=100), nullable=True) + + flag = sa.Column(sa.Boolean(), nullable=False) + optional_flag = sa.Column(sa.Boolean(), nullable=True) + + dt = sa.Column(sa.DateTime(), nullable=False) + optional_dt = sa.Column(sa.DateTime(), nullable=True) + + dec = sa.Column(sa.Numeric(scale=8, precision=2), nullable=False) + optional_dec = sa.Column(sa.Numeric(scale=8, precision=2), nullable=True) + + flt = sa.Column(sa.Float(), nullable=False) + optional_flt = sa.Column(sa.Float(), nullable=True) + + +class TestMakeCoercers(TestCase): + + def test_basic(self): + coercers = mod.make_coercers(Example) + self.assertEqual(len(coercers), 12) + + self.assertIs(coercers["id"], mod.coerce_integer) + self.assertIs(coercers["optional_id"], mod.coerce_integer) + self.assertIs(coercers["name"], mod.coerce_noop) + self.assertIs(coercers["optional_name"], mod.coerce_string_nullable) + self.assertIs(coercers["flag"], mod.coerce_boolean) + self.assertIs(coercers["optional_flag"], mod.coerce_boolean_nullable) + self.assertIs(coercers["dt"], mod.coerce_datetime) + self.assertIs(coercers["optional_dt"], mod.coerce_datetime) + self.assertIs(coercers["dec"], mod.coerce_decimal) + self.assertIs(coercers["optional_dec"], mod.coerce_decimal) + self.assertIs(coercers["flt"], mod.coerce_float) + self.assertIs(coercers["optional_flt"], mod.coerce_float) + + +class TestMakeCoercer(TestCase): + + def test_basic(self): + func = mod.make_coercer(Example.id) + self.assertIs(func, mod.coerce_integer) + + func = mod.make_coercer(Example.optional_id) + self.assertIs(func, mod.coerce_integer) + + func = mod.make_coercer(Example.name) + self.assertIs(func, mod.coerce_noop) + + func = mod.make_coercer(Example.optional_name) + self.assertIs(func, mod.coerce_string_nullable) + + func = mod.make_coercer(Example.flag) + self.assertIs(func, mod.coerce_boolean) + + func = mod.make_coercer(Example.optional_flag) + self.assertIs(func, mod.coerce_boolean_nullable) + + func = 
mod.make_coercer(Example.dt) + self.assertIs(func, mod.coerce_datetime) + + func = mod.make_coercer(Example.optional_dt) + self.assertIs(func, mod.coerce_datetime) + + func = mod.make_coercer(Example.dec) + self.assertIs(func, mod.coerce_decimal) + + func = mod.make_coercer(Example.optional_dec) + self.assertIs(func, mod.coerce_decimal) + + func = mod.make_coercer(Example.flt) + self.assertIs(func, mod.coerce_float) + + func = mod.make_coercer(Example.optional_flt) + self.assertIs(func, mod.coerce_float) + + +class TestCoercers(TestCase): + + def test_coerce_boolean(self): + self.assertTrue(mod.coerce_boolean("true")) + self.assertTrue(mod.coerce_boolean("1")) + self.assertTrue(mod.coerce_boolean("yes")) + + self.assertFalse(mod.coerce_boolean("false")) + self.assertFalse(mod.coerce_boolean("0")) + self.assertFalse(mod.coerce_boolean("no")) + + self.assertFalse(mod.coerce_boolean("")) + + def test_coerce_boolean_nullable(self): + self.assertTrue(mod.coerce_boolean_nullable("true")) + self.assertTrue(mod.coerce_boolean_nullable("1")) + self.assertTrue(mod.coerce_boolean_nullable("yes")) + + self.assertFalse(mod.coerce_boolean_nullable("false")) + self.assertFalse(mod.coerce_boolean_nullable("0")) + self.assertFalse(mod.coerce_boolean_nullable("no")) + + self.assertIsNone(mod.coerce_boolean_nullable("")) + + def test_coerce_datetime(self): + self.assertIsNone(mod.coerce_datetime("")) + + value = mod.coerce_datetime("2025-10-19 20:56:00") + self.assertIsInstance(value, datetime.datetime) + self.assertEqual(value, datetime.datetime(2025, 10, 19, 20, 56)) + + value = mod.coerce_datetime("2025-10-19 20:56:00.1234") + self.assertIsInstance(value, datetime.datetime) + self.assertEqual(value, datetime.datetime(2025, 10, 19, 20, 56, 0, 123400)) + + self.assertRaises(ValueError, mod.coerce_datetime, "XXX") + + def test_coerce_decimal(self): + self.assertIsNone(mod.coerce_decimal("")) + + value = mod.coerce_decimal("42") + self.assertIsInstance(value, decimal.Decimal) + 
self.assertEqual(value, decimal.Decimal("42.0")) + self.assertEqual(value, 42) + + value = mod.coerce_decimal("42.0") + self.assertIsInstance(value, decimal.Decimal) + self.assertEqual(value, decimal.Decimal("42.0")) + self.assertEqual(value, 42) + + self.assertRaises(decimal.InvalidOperation, mod.coerce_decimal, "XXX") + + def test_coerce_float(self): + self.assertEqual(mod.coerce_float("42"), 42.0) + self.assertEqual(mod.coerce_float("42.0"), 42.0) + + self.assertIsNone(mod.coerce_float("")) + + self.assertRaises(ValueError, mod.coerce_float, "XXX") + + def test_coerce_integer(self): + self.assertEqual(mod.coerce_integer("42"), 42) + self.assertRaises(ValueError, mod.coerce_integer, "42.0") + + self.assertIsNone(mod.coerce_integer("")) + + self.assertRaises(ValueError, mod.coerce_integer, "XXX") + + def test_coerce_noop(self): + self.assertEqual(mod.coerce_noop(""), "") + + self.assertEqual(mod.coerce_noop("42"), "42") + self.assertEqual(mod.coerce_noop("XXX"), "XXX") + + def test_coerce_string_nullable(self): + self.assertIsNone(mod.coerce_string_nullable("")) + + self.assertEqual(mod.coerce_string_nullable("42"), "42") + self.assertEqual(mod.coerce_string_nullable("XXX"), "XXX") + + def test_coerce_uuid(self): + self.assertIsNone(mod.coerce_uuid("")) + + uuid = mod.coerce_uuid("06753693d89277f08000ce71bf7ebbba") + self.assertIsInstance(uuid, _uuid.UUID) + self.assertEqual(uuid, _uuid.UUID("06753693d89277f08000ce71bf7ebbba")) + self.assertEqual(uuid.hex, "06753693d89277f08000ce71bf7ebbba") + + uuid = mod.coerce_uuid("06753693-d892-77f0-8000-ce71bf7ebbba") + self.assertIsInstance(uuid, _uuid.UUID) + self.assertEqual(uuid, _uuid.UUID("06753693-d892-77f0-8000-ce71bf7ebbba")) + self.assertEqual(str(uuid), "06753693-d892-77f0-8000-ce71bf7ebbba") + self.assertEqual(uuid.hex, "06753693d89277f08000ce71bf7ebbba")