diff --git a/rattail/importing/model.py b/rattail/importing/model.py index f0589508..af651c53 100644 --- a/rattail/importing/model.py +++ b/rattail/importing/model.py @@ -573,6 +573,13 @@ class CustomerImporter(ToRattail): return data + def get_group(self, group_id): + if hasattr(self, 'groups'): + return self.groups.get(group_id) + return self.session.query(model.CustomerGroup)\ + .filter(model.CustomerGroup.id == group_id)\ + .first() + def update_object(self, customer, data, local_data=None): customer = super(CustomerImporter, self).update_object(customer, data, local_data) @@ -628,13 +635,14 @@ class CustomerImporter(ToRattail): if 'group_id' in self.fields: group_id = data['group_id'] if group_id: - group = self.groups.get(group_id) + group = self.get_group(group_id) if not group: group = model.CustomerGroup() group.id = group_id group.name = "(auto-created)" self.session.add(group) - self.groups[group.id] = group + if hasattr(self, 'groups'): + self.groups[group.id] = group if group in customer.groups: if group is not customer.groups[0]: customer.groups.remove(group) @@ -647,13 +655,14 @@ class CustomerImporter(ToRattail): if 'group_id_2' in self.fields: group_id = data['group_id_2'] if group_id: - group = self.groups.get(group_id) + group = self.get_group(group_id) if not group: group = model.CustomerGroup() group.id = group_id group.name = "(auto-created)" self.session.add(group) - self.groups[group.id] = group + if hasattr(self, 'groups'): + self.groups[group.id] = group if group in customer.groups: if len(customer.groups) > 1: if group is not customer.groups[1]: