diff --git a/addons/purchase_product_matrix/models/purchase.py b/addons/purchase_product_matrix/models/purchase.py
index 141becfa11bb5878b3a43d758c9e0fca4850c899..3505f3e3761f4472427ef955436484abfe45f41d 100644
--- a/addons/purchase_product_matrix/models/purchase.py
+++ b/addons/purchase_product_matrix/models/purchase.py
@@ -38,6 +38,7 @@ class PurchaseOrder(models.Model):
         if self.grid and self.grid_update:
             grid = json.loads(self.grid)
             product_template = self.env['product.template'].browse(grid['product_template_id'])
+            product_ids = set()
             dirty_cells = grid['changes']
             Attrib = self.env['product.template.attribute.value']
             default_po_line_vals = {}
@@ -56,7 +57,13 @@ class PurchaseOrder(models.Model):
                 old_qty = sum(order_lines.mapped('product_qty'))
                 qty = cell['qty']
                 diff = qty - old_qty
-                if diff and order_lines:
+
+                if not diff:
+                    continue
+
+                product_ids.add(product.id)
+
+                if order_lines:
                     if qty == 0:
                         if self.state in ['draft', 'sent']:
                             # Remove lines if qty was set to 0 in matrix
@@ -87,7 +94,7 @@ class PurchaseOrder(models.Model):
                             # if len(order_lines) > 1:
                             #     # Remove 1+ lines
                             #     self.order_line -= order_lines[1:]
-                elif diff:
+                else:
                     if not default_po_line_vals:
                         OrderLine = self.env['purchase.order.line']
                         default_po_line_vals = OrderLine.default_get(OrderLine._fields.keys())
@@ -100,10 +107,14 @@ class PurchaseOrder(models.Model):
                         product_qty=qty,
                         product_no_variant_attribute_value_ids=no_variant_attribute_values.ids)
                     ))
-            if new_lines:
+            if product_ids:
                 res = False
-                self.update(dict(order_line=new_lines))
-                for line in self.order_line.filtered(lambda line: line.product_template_id == product_template):
+                if new_lines:
+                    # Add new PO lines
+                    self.update(dict(order_line=new_lines))
+
+                # Recompute prices for new/modified lines:
+                for line in self.order_line.filtered(lambda line: line.product_id.id in product_ids):
                     res = line._product_id_change() or res
                     line._onchange_quantity()
                 return res