diff --git a/partner_contact_address_default/models/res_partner.py b/partner_contact_address_default/models/res_partner.py index a37f74bfe..3160f1507 100644 --- a/partner_contact_address_default/models/res_partner.py +++ b/partner_contact_address_default/models/res_partner.py @@ -17,14 +17,34 @@ class ResPartner(models.Model): ) def get_address_default_type(self): + """This will be the extension method for other contact types""" return ["delivery", "invoice"] def address_get(self, adr_pref=None): + """Force the delivery or invoice addresses. It will try to default + to the one set in the commercial partner if any""" res = super().address_get(adr_pref) - default_address_type_list = self.get_address_default_type() + adr_pref = adr_pref or [] + default_address_type_list = { + x for x in adr_pref if x in self.get_address_default_type() + } for partner in self: for addr_type in default_address_type_list: - default_address_id = partner["partner_{}_id".format(addr_type)] + default_address_id = ( + partner["partner_{}_id".format(addr_type)] + or partner.commercial_partner_id["partner_{}_id".format(addr_type)] + ) if default_address_id: res[addr_type] = default_address_id.id return res + + def write(self, vals): + """We want to prevent archived contacts as default addresses""" + if vals.get("active") is False: + self.search([("partner_delivery_id", "in", self.ids)]).write( + {"partner_delivery_id": False} + ) + self.search([("partner_invoice_id", "in", self.ids)]).write( + {"partner_invoice_id": False} + ) + return super().write(vals) diff --git a/partner_contact_address_default/tests/test_partner_contact_address_default.py b/partner_contact_address_default/tests/test_partner_contact_address_default.py index 1e99df16a..464718d5c 100644 --- a/partner_contact_address_default/tests/test_partner_contact_address_default.py +++ b/partner_contact_address_default/tests/test_partner_contact_address_default.py @@ -30,12 +30,12 @@ class TestPartnerContactAddressDefault(common.TransactionCase): def test_contact_address_default(self): self.partner.partner_delivery_id = self.partner self.partner.partner_invoice_id = self.partner - res = self.partner.address_get() + res = self.partner.address_get(["delivery", "invoice"]) self.assertEqual(res["delivery"], self.partner.id) self.assertEqual(res["invoice"], self.partner.id) self.partner_child_delivery2.partner_delivery_id = self.partner_child_delivery2 self.partner_child_delivery2.partner_invoice_id = self.partner_child_delivery2 - res = self.partner_child_delivery2.address_get() + res = self.partner_child_delivery2.address_get(["delivery", "invoice"]) self.assertEqual(res["delivery"], self.partner_child_delivery2.id) self.assertEqual(res["invoice"], self.partner_child_delivery2.id) diff --git a/partner_contact_address_default/views/res_partner_views.xml b/partner_contact_address_default/views/res_partner_views.xml index 5330fdb07..25f864a5f 100644 --- a/partner_contact_address_default/views/res_partner_views.xml +++ b/partner_contact_address_default/views/res_partner_views.xml @@ -9,13 +9,15 @@ @@ -23,7 +25,7 @@ @@ -35,18 +37,16 @@ - - - - - - + + + +