diff --git a/partner_contact_address_default/models/res_partner.py b/partner_contact_address_default/models/res_partner.py
index a37f74bfe..3160f1507 100644
--- a/partner_contact_address_default/models/res_partner.py
+++ b/partner_contact_address_default/models/res_partner.py
@@ -17,14 +17,34 @@ class ResPartner(models.Model):
)
def get_address_default_type(self):
+ """This will be the extension method for other contact types"""
return ["delivery", "invoice"]
def address_get(self, adr_pref=None):
+ """Force the delivery or invoice addresses. It will try to default
+ to the one set in the commercial partner if any"""
res = super().address_get(adr_pref)
- default_address_type_list = self.get_address_default_type()
+ adr_pref = adr_pref or []
+ default_address_type_list = {
+ x for x in adr_pref if x in self.get_address_default_type()
+ }
for partner in self:
for addr_type in default_address_type_list:
- default_address_id = partner["partner_{}_id".format(addr_type)]
+ default_address_id = (
+ partner["partner_{}_id".format(addr_type)]
+ or partner.commercial_partner_id["partner_{}_id".format(addr_type)]
+ )
if default_address_id:
res[addr_type] = default_address_id.id
return res
+
+ def write(self, vals):
+ """We want to prevent archived contacts as default addresses"""
+ if vals.get("active") is False:
+ self.search([("partner_delivery_id", "in", self.ids)]).write(
+ {"partner_delivery_id": False}
+ )
+ self.search([("partner_invoice_id", "in", self.ids)]).write(
+ {"partner_invoice_id": False}
+ )
+ return super().write(vals)
diff --git a/partner_contact_address_default/tests/test_partner_contact_address_default.py b/partner_contact_address_default/tests/test_partner_contact_address_default.py
index 1e99df16a..464718d5c 100644
--- a/partner_contact_address_default/tests/test_partner_contact_address_default.py
+++ b/partner_contact_address_default/tests/test_partner_contact_address_default.py
@@ -30,12 +30,12 @@ class TestPartnerContactAddressDefault(common.TransactionCase):
def test_contact_address_default(self):
self.partner.partner_delivery_id = self.partner
self.partner.partner_invoice_id = self.partner
- res = self.partner.address_get()
+ res = self.partner.address_get(["delivery", "invoice"])
self.assertEqual(res["delivery"], self.partner.id)
self.assertEqual(res["invoice"], self.partner.id)
self.partner_child_delivery2.partner_delivery_id = self.partner_child_delivery2
self.partner_child_delivery2.partner_invoice_id = self.partner_child_delivery2
- res = self.partner_child_delivery2.address_get()
+ res = self.partner_child_delivery2.address_get(["delivery", "invoice"])
self.assertEqual(res["delivery"], self.partner_child_delivery2.id)
self.assertEqual(res["invoice"], self.partner_child_delivery2.id)
diff --git a/partner_contact_address_default/views/res_partner_views.xml b/partner_contact_address_default/views/res_partner_views.xml
index 5330fdb07..25f864a5f 100644
--- a/partner_contact_address_default/views/res_partner_views.xml
+++ b/partner_contact_address_default/views/res_partner_views.xml
@@ -9,13 +9,15 @@
@@ -23,7 +25,7 @@
@@ -35,18 +37,16 @@
-
-
-
-
-
-
+
+
+
+