# Part of Odoo. See LICENSE file for full copyright and licensing details.
import base64

from datetime import datetime
from freezegun import freeze_time
from lxml import etree
from pytz import timezone
from odoo import Command

from odoo.exceptions import ValidationError, UserError
from odoo.tests import tagged
from odoo.tools import misc
from odoo.addons.l10n_sa_edi.tests.common import TestSaEdiCommon


@tagged('post_install_l10n', '-at_install', 'post_install')
class TestEdiZatca(TestSaEdiCommon):
    # """Test ZATCA EDI compliance for Saudi Arabia."""

    def _test_document_generation(self, test_file_path, expected_xpath, freeze_time_at, additional_xpath='', document_type=False, move=False, move_data=False):
        """
        Common helper to test document generation against expected XML.

        The expected XML is read from ``test_file_path`` and patched with
        ``expected_xpath`` and then, when given, ``additional_xpath``. The move
        under test is either ``move`` or a new document created from
        ``document_type`` and ``move_data``. The move is posted if needed, its
        ZATCA XML is generated, ``self.remove_ubl_extensions_xpath`` is applied
        on it and the result is compared with the expected tree. Everything
        runs with the clock frozen at ``freeze_time_at``.

        :param test_file_path: path of the expected XML, opened with ``misc.file_open``
        :param expected_xpath: xpath spec applied on the expected tree
        :param freeze_time_at: value handed to ``freeze_time``
        :param additional_xpath: optional xpath spec applied on the expected tree after ``expected_xpath``
        :param document_type: ``'invoice'``, ``'credit_note'`` or ``'debit_note'``
        :param move: existing move to test; takes precedence over ``document_type``/``move_data``
        :param move_data: keyword arguments forwarded to the creation helper of ``document_type``
        :raises ValidationError: if neither ``move`` nor a known ``document_type`` together with ``move_data`` is given
        """
        with freeze_time(freeze_time_at):
            # Load expected XML; the context manager closes the file handle
            # as soon as its content has been read.
            with misc.file_open(test_file_path, 'rb') as expected_file:
                expected_xml = expected_file.read()
            expected_tree = self.get_xml_tree_from_string(expected_xml)
            expected_tree = self.with_applied_xpath(expected_tree, expected_xpath)

            creation_handlers = {
                "invoice": self._create_test_invoice,
                "credit_note": self._create_credit_note,
                "debit_note": self._create_debit_note,
            }

            if additional_xpath:
                expected_tree = self.with_applied_xpath(expected_tree, additional_xpath)

            if move:
                final_move = move
            elif move_data and document_type in creation_handlers:
                final_move = creation_handlers[document_type](**move_data)
            else:
                raise ValidationError("Either move or document_type + move_data need to be given")

            # Generate ZATCA XML
            if final_move.state != 'posted':
                final_move.action_post()

            final_move._l10n_sa_generate_unsigned_data()
            generated_file = self.env['account.edi.format']._l10n_sa_generate_zatca_template(final_move)
            current_tree = self.get_xml_tree_from_string(generated_file)
            current_tree = self.with_applied_xpath(current_tree, self.remove_ubl_extensions_xpath)

            # Assert
            self.assertXmlTreeEqual(current_tree, expected_tree)

    def testCreditNoteSimplified(self):
        """Compare the simplified credit note XML with compliance/simplified/credit.xml."""
        burger_line = {
            'product_id': self.product_burger.id,
            'price_unit': self.product_burger.standard_price,
            'quantity': 3,
            'tax_ids': self.tax_15.ids,
        }
        credit_note_vals = {
            'name': 'INV/2023/00034',
            'invoice_date': '2023-03-10',
            'invoice_date_due': '2023-03-10',
            'partner_id': self.partner_sa_simplified,
            'invoice_line_ids': [burger_line],
        }
        frozen_at = datetime(2023, 3, 10, 14, 59, 38, tzinfo=timezone('Etc/GMT-3'))

        self._test_document_generation(
            test_file_path='l10n_sa_edi/tests/compliance/simplified/credit.xml',
            expected_xpath=self.credit_note_applied_xpath,
            freeze_time_at=frozen_at,
            document_type='credit_note',
            move_data=credit_note_vals,
        )

    def testCreditNoteStandard(self):
        """Compare the standard credit note XML with compliance/standard/credit.xml."""
        product_line = {
            'product_id': self.product_a.id,
            'price_unit': self.product_a.standard_price,
            'tax_ids': self.tax_15.ids,
        }
        credit_note_vals = {
            'name': 'INV/2022/00014',
            'invoice_date': '2022-09-05',
            'invoice_date_due': '2022-09-22',
            'partner_id': self.partner_sa,
            'invoice_line_ids': [product_line],
        }

        # Replace the first AdditionalDocumentReference UUID of the expected
        # tree with the ___ignore___ placeholder.
        ignore_uuid_xpath = '''
            <xpath expr="(//*[local-name()='AdditionalDocumentReference']/*[local-name()='UUID'])[1]" position="replace">
                <cbc:UUID xmlns:cbc="urn:oasis:names:specification:ubl:schema:xsd:CommonBasicComponents-2">___ignore___</cbc:UUID>
            </xpath>
        '''
        frozen_at = datetime(2022, 9, 5, 9, 39, 15, tzinfo=timezone('Etc/GMT-3'))

        self._test_document_generation(
            test_file_path='l10n_sa_edi/tests/compliance/standard/credit.xml',
            expected_xpath=self.credit_note_applied_xpath,
            freeze_time_at=frozen_at,
            additional_xpath=ignore_uuid_xpath,
            document_type='credit_note',
            move_data=credit_note_vals,
        )

    def testDebitNoteSimplified(self):
        """Compare the simplified debit note XML with compliance/simplified/debit.xml."""
        burger_line = {
            'product_id': self.product_burger.id,
            'price_unit': self.product_burger.standard_price,
            'quantity': 2,
            'tax_ids': self.tax_15.ids,
        }
        debit_note_vals = {
            'name': 'INV/2023/00034',
            'invoice_date': '2023-03-10',
            'invoice_date_due': '2023-03-10',
            'partner_id': self.partner_sa_simplified,
            'invoice_line_ids': [burger_line],
        }
        frozen_at = datetime(2023, 3, 10, 15, 1, 46, tzinfo=timezone('Etc/GMT-3'))

        self._test_document_generation(
            test_file_path='l10n_sa_edi/tests/compliance/simplified/debit.xml',
            expected_xpath=self.debit_note_applied_xpath,
            freeze_time_at=frozen_at,
            document_type='debit_note',
            move_data=debit_note_vals,
        )

    def testDebitNoteStandard(self):
        """Compare the standard debit note XML with compliance/standard/debit.xml."""
        product_line = {
            'product_id': self.product_b.id,
            'price_unit': self.product_b.standard_price,
            'tax_ids': self.tax_15.ids,
        }
        debit_note_vals = {
            'name': 'INV/2022/00001',
            'invoice_date': '2022-09-05',
            'invoice_date_due': '2022-09-22',
            'partner_id': self.partner_sa,
            'invoice_line_ids': [product_line],
        }

        # Replace the first AdditionalDocumentReference UUID of the expected
        # tree with the ___ignore___ placeholder.
        ignore_uuid_xpath = '''
            <xpath expr="(//*[local-name()='AdditionalDocumentReference']/*[local-name()='UUID'])[1]" position="replace">
                <cbc:UUID xmlns:cbc="urn:oasis:names:specification:ubl:schema:xsd:CommonBasicComponents-2">___ignore___</cbc:UUID>
            </xpath>
        '''
        frozen_at = datetime(2022, 9, 5, 9, 45, 27, tzinfo=timezone('Etc/GMT-3'))

        self._test_document_generation(
            test_file_path='l10n_sa_edi/tests/compliance/standard/debit.xml',
            expected_xpath=self.debit_note_applied_xpath,
            freeze_time_at=frozen_at,
            additional_xpath=ignore_uuid_xpath,
            document_type='debit_note',
            move_data=debit_note_vals,
        )

    def testInvoiceSimplified(self):
        """Compare the simplified invoice XML with compliance/simplified/invoice.xml."""
        burger_line = {
            'product_id': self.product_burger.id,
            'price_unit': self.product_burger.standard_price,
            'quantity': 3,
            'tax_ids': self.tax_15.ids,
        }
        invoice_vals = {
            'name': 'INV/2023/00034',
            'invoice_date': '2023-03-10',
            'invoice_date_due': '2023-03-10',
            'partner_id': self.partner_sa_simplified,
            'invoice_line_ids': [burger_line],
        }
        frozen_at = datetime(2023, 3, 10, 14, 56, 55, tzinfo=timezone('Etc/GMT-3'))

        self._test_document_generation(
            test_file_path='l10n_sa_edi/tests/compliance/simplified/invoice.xml',
            expected_xpath=self.invoice_applied_xpath,
            freeze_time_at=frozen_at,
            document_type='invoice',
            move_data=invoice_vals,
        )

    def testInvoiceStandard(self):
        """Compare the standard invoice XML with compliance/standard/invoice.xml."""
        product_line = {
            'product_id': self.product_a.id,
            'price_unit': self.product_a.standard_price,
            'tax_ids': self.tax_15.ids,
        }
        invoice_vals = {
            'name': 'INV/2022/00014',
            'invoice_date': '2022-09-05',
            'invoice_date_due': '2022-09-22',
            'partner_id': self.partner_sa,
            'invoice_line_ids': [product_line],
        }
        frozen_at = datetime(2022, 9, 5, 8, 20, 2, tzinfo=timezone('Etc/GMT-3'))

        self._test_document_generation(
            test_file_path='l10n_sa_edi/tests/compliance/standard/invoice.xml',
            expected_xpath=self.invoice_applied_xpath,
            freeze_time_at=frozen_at,
            document_type='invoice',
            move_data=invoice_vals,
        )

    def testInvoiceWithDownpayment(self):
        """Test invoice generation with downpayment scenarios.

        A sale order is confirmed, then a fixed-amount down payment invoice and
        a final invoice are created from it through the
        ``sale.advance.payment.inv`` wizard. The generated XML of both invoices,
        and of a credit note reversing each of them, is compared with the
        matching file under ``l10n_sa_edi/tests/test_files``.

        The test is skipped when the ``sale`` module is not installed.
        """
        if 'sale' not in self.env["ir.module.module"]._installed():
            self.skipTest("Sale module is not installed")
        # Add the salesman group to the current user
        # (presumably required for the non-sudo sale order calls below - confirm).
        self.env.user.group_ids += self.env.ref('sales_team.group_sale_salesman')

        # Same frozen time is used to build the documents and to generate the XML
        freeze = datetime(2022, 9, 5, 8, 20, 2, tzinfo=timezone('Etc/GMT-3'))

        # Pricelist in SAR used by the sale order
        saudi_pricelist = self.env['product.pricelist'].create({
            'name': 'SAR',
            'currency_id': self.env.ref('base.SAR').id
        })
        with freeze_time(freeze):
            # The order is created in sudo, then switched back to the regular environment
            sale_order = self.env['sale.order'].sudo().create({
                'partner_id': self.partner_sa.id,
                'pricelist_id': saudi_pricelist.id,
                'order_line': [
                    Command.create({
                        'product_id': self.product_a.id,
                        'price_unit': 1000,
                        'product_uom_qty': 1,
                        'tax_ids': [Command.set(self.tax_15.ids)],
                    })
                ]
            }).sudo(False)
            sale_order.action_confirm()

            # Context for wizards
            context = {
                'active_model': 'sale.order',
                'active_ids': [sale_order.id],
                'active_id': sale_order.id,
                'default_journal_id': self.customer_invoice_journal.id,
            }

            # Create downpayment invoice
            downpayment_wizard = self.env['sale.advance.payment.inv'].with_context(context).sudo().create({
                'advance_payment_method': 'fixed',
                'fixed_amount': 115,
            })
            downpayment = downpayment_wizard._create_invoices(sale_order)
            downpayment.invoice_date_due = '2022-09-22'

            # Create final invoice
            final_wizard = self.env['sale.advance.payment.inv'].with_context(context).sudo().create({})
            final = final_wizard._create_invoices(sale_order)
            # Force a fixed label on the down payment line of the final invoice
            final.invoice_line_ids.filtered('is_downpayment').name = 'Down Payment'
            final.invoice_date_due = '2022-09-22'

        # Test invoices
        # The expected files get an InstructionNote holding the sale order name,
        # inserted after PaymentMeans/InstructionID.
        additional_xpath = f'''
            <xpath expr="(//*[local-name()='PaymentMeans']/*[local-name()='InstructionID'])" position="after">
                <cbc:InstructionNote xmlns:cbc="urn:oasis:names:specification:ubl:schema:xsd:CommonBasicComponents-2">{sale_order.name}</cbc:InstructionNote>
            </xpath>
        '''
        for move, test_file in [
            (downpayment, "downpayment_invoice"),
            (final, "final_invoice")
        ]:
            with self.subTest(move=move, test_file=test_file):
                self._test_document_generation(
                    test_file_path=f'l10n_sa_edi/tests/test_files/{test_file}.xml',
                    expected_xpath=self.invoice_applied_xpath,
                    additional_xpath=additional_xpath,
                    freeze_time_at=freeze,
                    move=move,
                )

        # Test credit notes
        for move, test_file in [
            (downpayment, "downpayment_credit_note"),
            (final, "final_credit_note")
        ]:
            with self.subTest(move=move, test_file=test_file):
                # Create refund through the reversal wizard
                wiz_context = {
                    'active_model': 'account.move',
                    'active_ids': [move.id],
                    'default_journal_id': move.journal_id.id,
                }
                refund_wizard = self.env['account.move.reversal'].with_context(wiz_context).create({
                    'l10n_sa_reason': 'BR-KSA-17-reason-5',
                    'date': '2022-09-05',
                })
                # reverse_moves() returns an action; its 'res_id' is browsed as the credit note
                refund_invoice = self.env['account.move'].browse(refund_wizard.reverse_moves()['res_id'])
                refund_invoice.invoice_date_due = '2022-09-22'
                # No additional xpath is passed for the credit notes
                self._test_document_generation(
                    test_file_path=f'l10n_sa_edi/tests/test_files/{test_file}.xml',
                    expected_xpath=self.credit_note_applied_xpath,
                    freeze_time_at=freeze,
                    move=refund_invoice,
                )

    def testInvoiceWithRetention(self):
        """Test invoice generation when a retention tax is set on the line.

        The generated XML is compared with the same expected file as the
        standard invoice test (compliance/standard/invoice.xml).
        """
        retention_tax = self.env['account.tax'].create({
            'l10n_sa_is_retention': True,
            'name': 'Retention Tax',
            'amount_type': 'percent',
            'amount': -10.0,
        })

        # The line carries both the 15% tax and the retention tax
        line_with_retention = {
            'product_id': self.product_a.id,
            'price_unit': self.product_a.standard_price,
            'tax_ids': self.tax_15.ids + retention_tax.ids,
        }
        invoice_vals = {
            'name': 'INV/2022/00014',
            'invoice_date': '2022-09-05',
            'invoice_date_due': '2022-09-22',
            'partner_id': self.partner_sa,
            'invoice_line_ids': [line_with_retention],
        }
        frozen_at = datetime(2022, 9, 5, 8, 20, 2, tzinfo=timezone('Etc/GMT-3'))

        self._test_document_generation(
            test_file_path='l10n_sa_edi/tests/compliance/standard/invoice.xml',
            expected_xpath=self.invoice_applied_xpath,
            freeze_time_at=frozen_at,
            document_type='invoice',
            move_data=invoice_vals,
        )

    def testCompanyOnSimplifiedInvoiceQR(self):
        """Check that the seller name in the XML equals the seller name in the QR code.

        The invoice is created on ``self.sa_branch`` for the simplified partner.
        """
        invoice = self._create_test_invoice(
            name='INV/2025/00012',
            invoice_date='2025-07-05',
            invoice_date_due='2025-07-12',
            company_id=self.sa_branch,
            partner_id=self.partner_sa_simplified,
            invoice_line_ids=[{
                'product_id': self.product_a.id,
                'price_unit': self.product_a.standard_price,
                'tax_ids': self.tax_15.ids,
            }],
        )
        invoice.action_post()

        # Seller name as written in the generated XML
        xml_tree = etree.fromstring(
            self.env['account.edi.format']._l10n_sa_generate_zatca_template(invoice)
        )
        ubl_namespaces = self.env['account.edi.xml.ubl_21.zatca']._l10n_sa_get_namespaces()
        supplier_name_nodes = xml_tree.xpath(
            "//cac:AccountingSupplierParty/cac:Party/cac:PartyName/cbc:Name",
            namespaces=ubl_namespaces,
        )
        xml_company_name = supplier_name_nodes[0].text.strip()

        # Seller name as encoded in the QR code.
        # Each entry is: Tag (1 byte) - Length (1 byte) - Value; the value of
        # the first entry starts at offset 2.
        invoice._l10n_sa_generate_unsigned_data()
        qr_bytes = base64.b64decode(invoice.l10n_sa_qr_code_str)
        value_start = 2
        value_length = qr_bytes[1]
        qr_company_name = qr_bytes[value_start:value_start + value_length].decode()

        self.assertEqual(xml_company_name, qr_company_name, "Seller name on the xml does not match the seller name on the QR code")

    def test_company_missing_country_on_standard_invoice(self):
        """Posting an invoice line without tax must raise a UserError even if the company has no country."""
        # A dedicated company is used so that clearing its country cannot
        # affect the other tests.
        company_vals = self._get_company_vals({"name": "SA Company (Minus Country)"})
        company_without_country = self._create_company(**company_vals)

        sale_journal = self.env['account.journal'].search(
            [
                ('company_id', '=', company_without_country.id),
                ('type', '=', 'sale'),
            ],
            limit=1,
        )
        sale_journal._l10n_sa_load_edi_demo_data()

        company_without_country.country_id = False

        # The line has no tax: posting is expected to fail with a user error,
        # regardless of the blank country.
        untaxed_line = {
            'product_id': self.product_a.id,
            'price_unit': self.product_a.standard_price,
            'tax_ids': False,
        }
        invoice = self._create_test_invoice(
            name='INV/2022/00014',
            invoice_date='2022-09-05',
            invoice_date_due='2022-09-22',
            company_id=company_without_country,
            partner_id=self.partner_sa,
            invoice_line_ids=[untaxed_line],
        )
        with self.assertRaises(UserError):
            invoice.action_post()

    def test_zatca_xml_price_amount_precision(self):
        """
        Check that PriceAmount is written with 10 decimals, as needed for ZATCA
        validation BR-KSA-EN16931-11, on a tax-included line.
        """
        self.tax_15.price_include_override = 'tax_included'

        invoice = self._create_test_invoice(
            name='INV/2025/00013',
            invoice_date='2025-01-15',
            invoice_date_due='2025-01-15',
            partner_id=self.partner_sa,
            invoice_line_ids=[{
                'product_id': self.product_a.id,
                'price_unit': 200.0,
                'quantity': 7,
                'tax_ids': self.tax_15.ids,
            }],
        )
        invoice.action_post()

        # Parse the generated XML
        xml_tree = etree.fromstring(
            self.env['account.edi.format']._l10n_sa_generate_zatca_template(invoice)
        )

        # Read the PriceAmount of the invoice line
        ubl_namespaces = self.env['account.edi.xml.ubl_21.zatca']._l10n_sa_get_namespaces()
        price_nodes = xml_tree.xpath(
            "//cac:InvoiceLine/cac:Price/cbc:PriceAmount",
            namespaces=ubl_namespaces,
        )
        self.assertTrue(price_nodes, "PriceAmount node not found in XML")
        self.assertEqual(price_nodes[0].text, '173.9128571429')

    def test_zatca_xml_line_rounding_amount_consistency(self):
        """Test that LineExtensionAmount + TaxAmount = RoundingAmount for each invoice line.

        The invoice uses a tax-included 15% tax on two lines; the three amounts
        are read from every ``cac:InvoiceLine`` of the generated XML.
        """
        self.tax_15.price_include_override = 'tax_included'
        invoice = self._create_test_invoice(
            name='INV/2022/00001',
            invoice_date='2022-09-05',
            invoice_date_due='2022-09-22',
            partner_id=self.partner_sa,
            invoice_line_ids=[
                {
                    'product_id': self.product_a.id,
                    'price_unit': 18.0,
                    'tax_ids': self.tax_15.ids,
                },
                {
                    'product_id': self.product_b.id,
                    'price_unit': 14.0,
                    'tax_ids': self.tax_15.ids,
                }
            ]
        )
        invoice.action_post()
        xml_content = self.env['account.edi.format']._l10n_sa_generate_zatca_template(invoice)
        xml_root = etree.fromstring(xml_content)
        namespaces = self.env['account.edi.xml.ubl_21.zatca']._l10n_sa_get_namespaces()

        invoice_lines = xml_root.xpath('//cac:InvoiceLine', namespaces=namespaces)
        # Without this guard the loop below would pass vacuously if the XML
        # contained no invoice line at all.
        self.assertTrue(invoice_lines, "InvoiceLine node not found in XML")

        for line in invoice_lines:
            line_ext = float(line.xpath('cbc:LineExtensionAmount/text()', namespaces=namespaces)[0])
            tax_amt = float(line.xpath('cac:TaxTotal/cbc:TaxAmount/text()', namespaces=namespaces)[0])
            rounding_amt = float(line.xpath('cac:TaxTotal/cbc:RoundingAmount/text()', namespaces=namespaces)[0])
            # The amounts are parsed as binary floats, so their sum may differ
            # from the expected value by a representation error; compare at the
            # cent instead of requiring exact float equality.
            # NOTE(review): assumes these amounts carry 2 decimals, like the
            # PayableAmount '115.00' checked elsewhere in this class - confirm.
            self.assertAlmostEqual(line_ext + tax_amt, rounding_amt, places=2,
                msg=f"LineExtensionAmount ({line_ext}) + TaxAmount ({tax_amt}) != RoundingAmount ({rounding_amt})")

    def test_csr_generation_compliant_company(self):
        """CSR generation must succeed and return a non-empty value for a company with short field values."""
        company = self.env['res.company'].create({
            'name': 'Valid Company Name',
            'vat': '300000000000003',
            'street': 'Short Street Name',
            'city': 'Riyadh',
            'zip': '12345',
            'country_id': self.saudi_arabia.id,
            'state_id': self.riyadh.id,
            'l10n_sa_api_mode': 'sandbox',
            'currency_id': self.env.ref('base.SAR').id,
        })
        industry = self.env['res.partner.industry'].create({
            'name': 'Technology',
        })
        company.partner_id.industry_id = industry
        private_key = self.env['certificate.key'].sudo()._generate_ec_private_key(
            company, name='Test private key'
        )
        company.l10n_sa_private_key_id = private_key
        journal = self.env['account.journal'].create({
            'name': 'Sales',
            'code': 'SAL',
            'type': 'sale',
            'company_id': company.id,
        })

        try:
            csr_string = self.env['certificate.certificate'].sudo()._l10n_sa_get_csr_str(journal)
        except UserError as e:
            # Turn the unexpected user error into an explicit test failure
            self.fail(f"Compliant company should not raise error: {e}")
        else:
            self.assertTrue(csr_string, "a Valid CSR should not be empty")

    def test_csr_generation_non_compliant_company(self):
        """CSR generation must raise a UserError naming every field whose value is 70 characters long."""
        oversized = 70  # length used for every text value in this test

        state = self.env['res.country.state'].create({
            'name': "D" * oversized,
            'code': 'LST',
            'country_id': self.saudi_arabia.id,
        })
        industry = self.env['res.partner.industry'].create({
            'name': "E" * oversized,
        })

        company = self.env['res.company'].create({
            'name': "A" * oversized,
            'vat': '333333333333333',
            'street': "B" * oversized,
            'city': "C" * oversized,
            'zip': '12345',
            'country_id': self.saudi_arabia.id,
            'state_id': state.id,
            'l10n_sa_api_mode': 'sandbox',
            'currency_id': self.env.ref('base.SAR').id,
        })
        company.partner_id.industry_id = industry
        private_key = self.env['certificate.key'].sudo()._generate_ec_private_key(
            company, name='Test private key'
        )
        company.l10n_sa_private_key_id = private_key
        journal = self.env['account.journal'].create({
            'name': "F" * oversized,
            'code': 'NC',
            'type': 'sale',
            'company_id': company.id,
        })

        with self.assertRaises(UserError) as raised:
            self.env['certificate.certificate'].sudo()._l10n_sa_get_csr_str(journal)

        # Every label below must appear in the raised error message
        message = str(raised.exception)
        for label in (
            "Company Name",
            "Common Name",
            "Street",
            "Locality Name",
            "State/Province Name",
            "Partner Industry Name",
        ):
            self.assertIn(label, message, f"Error message should contain '{label}'")

    def test_otp_validation_without_company_street(self):
        """Validating the OTP without a company street must record a CSR error on the journal."""
        self.company.street = False

        Journal = self.env['account.journal']
        sale_journal = Journal.search(
            [
                *Journal._check_company_domain(self.company),
                ('type', '=', 'sale'),
            ],
            limit=1,
        )
        otp_wizard = self.env['l10n_sa_edi.otp.wizard'].create({
            'journal_id': sale_journal.id,
            'l10n_sa_otp': '123456',
        })

        # No error is stored before the validation...
        self.assertFalse(sale_journal.l10n_sa_csr_errors)

        otp_wizard.validate()

        # ...and the missing street is reported afterwards
        self.assertTrue(sale_journal.l10n_sa_csr_errors)
        expected_errors = f'<p>Please set the following on {self.company.name}: Street</p>'
        self.assertEqual(str(sale_journal.l10n_sa_csr_errors), expected_errors)

    def test_child_company_api_mode_change_does_not_reset_parent_journal(self):
        """Switching the API mode of ``self.sa_branch`` resets its own journal but not ``self.customer_invoice_journal``."""
        parent_journal = self.customer_invoice_journal
        parent_journal._l10n_sa_load_edi_demo_data()
        self.assertTrue(parent_journal.l10n_sa_production_csid_json)

        branch_journal = self.env['account.journal'].create({
            'name': 'Child Sales Journal',
            'code': 'CSAL',
            'type': 'sale',
            'company_id': self.sa_branch.id,
        })
        branch_journal._l10n_sa_load_edi_demo_data()
        self.assertTrue(branch_journal.l10n_sa_production_csid_json)

        # Only the branch changes its API mode
        self.sa_branch.l10n_sa_api_mode = 'preprod'

        self.assertFalse(
            branch_journal.l10n_sa_production_csid_json,
            "Child journal should be reset after API mode change",
        )
        self.assertTrue(
            parent_journal.l10n_sa_production_csid_json,
            "Parent journal must not be reset when child company API mode changes",
        )

    def test_invoice_cash_rounding_payable_amount(self):
        """Check PayableAmount in the XML of an invoice using an 'add_invoice_line' cash rounding."""
        # Dedicated copies of the default revenue/expense accounts for the rounding
        profit_account = self.company_data['default_account_revenue'].copy()
        loss_account = self.company_data['default_account_expense'].copy()
        cash_rounding = self.env['account.cash.rounding'].create({
            'name': 'add_invoice_line',
            'rounding': 1.00,
            'strategy': 'add_invoice_line',
            'profit_account_id': profit_account.id,
            'loss_account_id': loss_account.id,
            'rounding_method': 'UP',
        })

        invoice = self._create_test_invoice(
            name='INV/2022/00014',
            invoice_date='2022-09-05',
            invoice_date_due='2022-09-22',
            partner_id=self.partner_sa,
            invoice_cash_rounding_id=cash_rounding.id,
            invoice_line_ids=[{
                'product_id': self.product_a.id,
                'price_unit': 99.55,
                'tax_ids': self.tax_15.ids,
            }],
        )
        invoice.action_post()

        xml_tree = etree.fromstring(
            self.env['account.edi.format']._l10n_sa_generate_zatca_template(invoice)
        )
        ubl_namespaces = self.env['account.edi.xml.ubl_21.zatca']._l10n_sa_get_namespaces()
        payable_nodes = xml_tree.xpath("//cbc:PayableAmount", namespaces=ubl_namespaces)
        payable_amount = payable_nodes[0].text.strip()
        self.assertEqual(payable_amount, '115.00')
