diff --git a/partner_address_version/__init__.py b/partner_address_version/__init__.py index 2800aae86..ebae2035b 100644 --- a/partner_address_version/__init__.py +++ b/partner_address_version/__init__.py @@ -2,3 +2,4 @@ # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl). from . import models +from . import wizards diff --git a/partner_address_version/__manifest__.py b/partner_address_version/__manifest__.py index 6af2021d5..a034e290c 100644 --- a/partner_address_version/__manifest__.py +++ b/partner_address_version/__manifest__.py @@ -8,5 +8,7 @@ "category": "CRM", "license": "AGPL-3", "installable": True, - "depends": ["base"], + "depends": [ + "base", + ], } diff --git a/partner_address_version/models/res_partner.py b/partner_address_version/models/res_partner.py index b9bae0298..9d5aa518c 100644 --- a/partner_address_version/models/res_partner.py +++ b/partner_address_version/models/res_partner.py @@ -12,30 +12,11 @@ class ResPartner(models.Model): version_hash = fields.Char(readonly=True, copy=False) date_version = fields.Datetime(string="Date version", readonly=True) - @api.multi - def get_address_version(self): - """ - Get a versioned partner corresponding to the partner fields. - If no versioned partner exists, create a new one. - """ - self.ensure_one() - version_hash = self.get_version_hash() - versioned_partner = self.with_context(active_test=False).search( - [("version_hash", "=", version_hash)] - ) - if versioned_partner: - return versioned_partner - default = { - "active": False, - "version_hash": version_hash, - "parent_id": self.parent_id and self.parent_id.id or self.id, - "date_version": fields.Datetime.now(), - "name": self.name, - } - versioned_partner = self.copy(default=default) - return versioned_partner - def get_version_fields(self): + # deprecated uses _version_fields instead + return self._version_fields() + + def _version_fields(self): return [ "name", "street", @@ -47,24 +28,70 @@ class ResPartner(models.Model): ] def get_version_hash(self): - version_fields = self.get_version_fields() + # deprecated uses _version_hash instead + return self._version_hash() + + def _version_hash(self): + version_fields = self._version_fields() version = OrderedDict() for field in version_fields: if field == "parent_id": - parent_id = self.parent_id and self.parent_id.id or self.id + parent_id = self.parent_id.id if self.parent_id else self.id version[field] = parent_id elif self[field]: version[field] = self[field] version_hash = hashlib.md5(str(version).encode("utf-8")).hexdigest() return version_hash + def _version_impacted_tables(self): + """ + :return: + - list of tables to update in case of address versioning + """ + return [] + + def _version_exclude_keys(self): + """ + :return: + - dict: + key = table name + value = list of columns to ignore in case of address + versioning + """ + return {} + + def _version_need(self): + """ + This method is supposed to be overriden to determine when + an address versioning is needed or not + :return: True if versioning is required else False + """ + return False + + def _version_apply(self): + self.ensure_one() + if self._version_need(): + # the address is used, create a new version and + # update related tables + version_p = self._version_create() + partner_wizard = self.env[ + "base.partner.merge.automatic.wizard" + ].with_context(address_version=True) + partner_wizard._update_foreign_keys(self, version_p) + return False + @api.multi def write(self, vals): - version_fields = self.get_version_fields() - has_written_versioned_fields = any( - (f in version_fields) for f in vals.keys() - ) + version_fields = self._version_fields() + has_written_versioned_fields = any((f in version_fields) for f in vals.keys()) for partner in self: + if ( + not partner.version_hash + and not vals.get("version_hash", False) + and has_written_versioned_fields + ): + partner._version_apply() + if partner.version_hash and has_written_versioned_fields: raise exceptions.UserError( _( @@ -74,3 +101,13 @@ class ResPartner(models.Model): % (version_fields, partner.name) ) return super(ResPartner, self).write(vals) + + def _version_create(self): + version_hash = self._version_hash() + default = { + "active": False, + "version_hash": version_hash, + "parent_id": self.parent_id.id if self.parent_id else self.id, + "date_version": fields.Datetime.now(), + } + return self.copy(default=default) diff --git a/partner_address_version/readme/CONTRIBUTORS.rst b/partner_address_version/readme/CONTRIBUTORS.rst index da15bbc3f..a433be1a9 100644 --- a/partner_address_version/readme/CONTRIBUTORS.rst +++ b/partner_address_version/readme/CONTRIBUTORS.rst @@ -1,2 +1,3 @@ * Benoît Guillot * Kevin Khao +* Cédric Pigeon diff --git a/partner_address_version/tests/test_address_version.py b/partner_address_version/tests/test_address_version.py index 5b974870f..3d90d5e1f 100644 --- a/partner_address_version/tests/test_address_version.py +++ b/partner_address_version/tests/test_address_version.py @@ -29,36 +29,26 @@ class TestAddressVersion(SavepointCase): cls.partner_vals.update({"parent_id": cls.partner.id}) def test_hash(self): - test_hash = hashlib.md5( - str(self.partner_vals).encode("utf-8") - ).hexdigest() - self.assertEqual(test_hash, self.partner.get_version_hash()) + test_hash = hashlib.md5(str(self.partner_vals).encode("utf-8")).hexdigest() + self.assertEqual(test_hash, self.partner._version_hash()) def test_create_version_partner(self): - new_partner = self.partner.get_address_version() + new_partner = self.partner._version_create() self.assertEqual(new_partner.active, False) self.assertNotEqual(new_partner.id, self.partner.id) self.assertEqual(new_partner.parent_id.id, self.partner.id) - def test_get_version_hash(self): - self.partner.version_hash = self.partner.get_version_hash() - self.partner.active = False - version_partner = self.partner.get_address_version() - self.assertEqual(version_partner.id, self.partner.id) - def test_write_versioned_partner(self): - new_partner = self.partner.get_address_version() + new_partner = self.partner._version_create() with self.assertRaises(UserError): new_partner.street = "New street" def test_same_address_different_parent(self): - new_partner = self.partner.get_address_version() - new_partner_2 = self.partner_2.get_address_version() - for field in self.partner.get_version_fields(): + new_partner = self.partner._version_create() + new_partner_2 = self.partner_2._version_create() + for field in self.partner._version_fields(): if field == "parent_id": continue self.assertEqual(new_partner[field], new_partner_2[field]) self.assertNotEqual(new_partner.id, new_partner_2.id) - self.assertNotEqual( - new_partner.version_hash, new_partner_2.version_hash - ) + self.assertNotEqual(new_partner.version_hash, new_partner_2.version_hash) diff --git a/partner_address_version/wizards/__init__.py b/partner_address_version/wizards/__init__.py new file mode 100644 index 000000000..e3fc7010c --- /dev/null +++ b/partner_address_version/wizards/__init__.py @@ -0,0 +1 @@ +from . import base_partner_merge diff --git a/partner_address_version/wizards/base_partner_merge.py b/partner_address_version/wizards/base_partner_merge.py new file mode 100644 index 000000000..4c535c3ce --- /dev/null +++ b/partner_address_version/wizards/base_partner_merge.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 ACSONE SA/NV () +# License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl.html). + +from odoo import models + + +class MergePartnerAutomatic(models.TransientModel): + _inherit = "base.partner.merge.automatic.wizard" + + def _get_fk_on(self, table): + foreign_keys = super(MergePartnerAutomatic, self)._get_fk_on(table) + if table == "res_partner" and self.env.context.get("address_version"): + models = self.env["res.partner"]._version_impacted_tables() + limited_fk = [] + for fk in foreign_keys: + if fk[0] in models: + ignore_col_dict = self.env["res.partner"]._version_exclude_keys() + ignore_col = ignore_col_dict.get(fk[0], False) + if ignore_col and fk[1] in ignore_col: + continue + limited_fk.append(fk) + return limited_fk + return foreign_keys