diff --git a/addons/account/models/account_invoice.py b/addons/account/models/account_invoice.py
index aff495072ce55185821c747eafe22f730d56c067..008c52777adc3a5d7ffd6bc6c882e96e537e9f59 100644
--- a/addons/account/models/account_invoice.py
+++ b/addons/account/models/account_invoice.py
@@ -998,6 +998,21 @@ class AccountInvoice(models.Model):
             result.append((0, 0, values))
         return result
 
+    @api.model
+    def _refund_tax_lines_account_change(self, lines, taxes_to_change):
+        # Let's change the account on tax lines when
+        # @param {list} lines: a list of orm commands
+        # @param {dict} taxes_to_change
+        #   key: tax ID, value: refund account
+
+        if not taxes_to_change:
+            return lines
+
+        for line in lines:
+            if isinstance(line[2], dict) and line[2]['tax_id'] in taxes_to_change:
+                line[2]['account_id'] = taxes_to_change[line[2]['tax_id']]
+        return lines
+
     def _get_refund_common_fields(self):
         return ['partner_id', 'payment_term_id', 'account_id', 'currency_id', 'journal_id']
 
@@ -1039,7 +1054,12 @@ class AccountInvoice(models.Model):
         values['invoice_line_ids'] = self._refund_cleanup_lines(invoice.invoice_line_ids)
 
         tax_lines = invoice.tax_line_ids
-        values['tax_line_ids'] = self._refund_cleanup_lines(tax_lines)
+        taxes_to_change = {
+            line.tax_id.id: line.tax_id.refund_account_id.id
+            for line in tax_lines.filtered(lambda l: l.tax_id.refund_account_id != l.tax_id.account_id)
+        }
+        cleaned_tax_lines = self._refund_cleanup_lines(tax_lines)
+        values['tax_line_ids'] = self._refund_tax_lines_account_change(cleaned_tax_lines, taxes_to_change)
 
         if journal_id:
             journal = self.env['account.journal'].browse(journal_id)
diff --git a/addons/account/tests/test_account_customer_invoice.py b/addons/account/tests/test_account_customer_invoice.py
index 12a164cd3cb2755127d3371d70bd1482a92fb20c..6c6c6a253faed52ccf8586ff5e532d4631f26737 100644
--- a/addons/account/tests/test_account_customer_invoice.py
+++ b/addons/account/tests/test_account_customer_invoice.py
@@ -170,3 +170,59 @@ class TestAccountCustomerInvoice(AccountTestUsers):
         ))
 
         self.assertEquals(invoice.amount_untaxed, sum([x.base for x in invoice.tax_line_ids]))
+
+    def test_customer_invoice_tax_refund(self):
+        company = self.env.user.company_id
+        tax_account = self.env['account.account'].create({
+            'name': 'TAX',
+            'code': 'TAX',
+            'user_type_id': self.env.ref('account.data_account_type_current_assets').id,
+            'company_id': company.id,
+        })
+
+        tax_refund_account = self.env['account.account'].create({
+            'name': 'TAX_REFUND',
+            'code': 'TAX_R',
+            'user_type_id': self.env.ref('account.data_account_type_current_assets').id,
+            'company_id': company.id,
+        })
+
+        journalrec = self.env['account.journal'].search([('type', '=', 'sale')])[0]
+        partner3 = self.env.ref('base.res_partner_3')
+        account_id = self.env['account.account'].search([('user_type_id', '=', self.env.ref('account.data_account_type_revenue').id)], limit=1).id
+
+        tax = self.env['account.tax'].create({
+            'name': 'Tax 15.0',
+            'amount': 15.0,
+            'amount_type': 'percent',
+            'type_tax_use': 'sale',
+            'account_id': tax_account.id,
+            'refund_account_id': tax_refund_account.id
+        })
+
+        invoice_line_data = [
+            (0, 0,
+                {
+                    'product_id': self.env.ref('product.product_product_1').id,
+                    'quantity': 40.0,
+                    'account_id': account_id,
+                    'name': 'product test 1',
+                    'discount': 10.00,
+                    'price_unit': 2.27,
+                    'invoice_line_tax_ids': [(6, 0, [tax.id])],
+                }
+             )]
+
+        invoice = self.env['account.invoice'].create(dict(
+            name="Test Customer Invoice",
+            reference_type="none",
+            journal_id=journalrec.id,
+            partner_id=partner3.id,
+            invoice_line_ids=invoice_line_data
+        ))
+
+        invoice.action_invoice_open()
+
+        refund = invoice.refund()
+        self.assertEqual(invoice.tax_line_ids.mapped('account_id'), tax_account)
+        self.assertEqual(refund.tax_line_ids.mapped('account_id'), tax_refund_account)