import logging
import datetime
import re
import requests
import secrets
import uuid

import stdnum
from stdnum import luhn
from stdnum.exceptions import InvalidChecksum, InvalidFormat
from stdnum.util import clean

from odoo import api, models, fields
from odoo.tools import _, LazyTranslate, hash_sign
from odoo.exceptions import ValidationError, UserError
from odoo.addons.base.models.res_partner import EU_EXTRA_VAT_CODES


# Lazy translation helper for module-level strings: entries wrapped in _lt()
# below are translated lazily rather than at import time.
_lt = LazyTranslate(__name__)
_logger = logging.getLogger(__name__)


# Reverse of EU_EXTRA_VAT_CODES: maps a VAT prefix that differs from the ISO
# country code back to that country code (presumably e.g. 'EL' -> 'GR' for
# Greece; EU_EXTRA_VAT_CODES is defined in base.models.res_partner).
EU_EXTRA_VAT_CODES_INV = {v: k for k, v in EU_EXTRA_VAT_CODES.items()}

# Example / expected VAT format per lowercase country code, appended to the
# "invalid VAT" error messages (_build_vat_error_message, _run_vat_checks).
# Plain strings are literal examples; _lt() entries contain English wording.
# NOTE(review): 'kr', 'ro' and 've' also contain English wording ("or") but
# are not wrapped in _lt(), so they are never translated.
_ref_vat = {
    'al': 'ALJ91402501L',
    'ar': '20055361682',
    'at': 'ATU12345675',
    'au': '83 914 571 673',
    'be': 'BE0477472701',
    'bg': 'BG1234567892',
    'br': _lt('either 11 digits for CPF or 14 digits for CNPJ'),
    'cr': '3101012009',
    'ch': _lt('CHE-123.456.788 TVA or CHE-123.456.788 MWST or CHE-123.456.788 IVA'),  # Swiss by Yannick Vaucher @ Camptocamp
    'cl': '76086428-5',
    'co': '213123432-1',
    'cy': 'CY10259033P',
    'cz': 'CZ12345679',
    'de': _lt('DE123456788 or 12/345/67890'),
    'dk': 'DK12345674',
    'do': _lt('1-01-85004-3 or 101850043'),
    'ec': _lt('1792060346001 or 1792060346'),
    'ee': 'EE123456780',
    'es': 'ESA12345674',
    'fi': 'FI12345671',
    'fr': 'FR23334175221',
    'gb': _lt('GB123456782 or XI123456782'),
    'gr': 'EL123456783',
    'hu': _lt('HU12345676 or 12345678-1-11 or 8071592153'),
    'hr': 'HR01234567896',  # Croatia, contributed by Milan Tribuson
    'id': '1234567890123456',
    'ie': 'IE1234567FA',
    'il': _lt('XXXXXXXXX [9 digits] and it should respect the Luhn algorithm checksum'),
    'in': "12AAAAA1234AAZA",
    'is': 'IS062199',
    'it': 'IT12345670017',
    'jp': 'T7000012050002',
    'kr': '123-45-67890 or 1234567890',
    'lt': 'LT123456715',
    'lu': 'LU12345613',
    'lv': 'LV41234567891',
    'ma': '12345678',
    'mc': 'FR53000004605',
    'mt': 'MT12345634',
    'mx': _lt('GODE561231GR8'),
    'nl': 'NL123456782B90',
    'no': 'NO123456785',
    'nz': _lt('49-098-576 or 49098576'),
    'pe': _lt('10XXXXXXXXY or 20XXXXXXXXY or 15XXXXXXXXY or 16XXXXXXXXY or 17XXXXXXXXY'),
    'ph': '123-456-789-123',
    'pl': 'PL1234567883',
    'pt': 'PT123456789',
    'ro': 'RO1234567897 or 8001011234567 or 9000123456789',
    'rs': 'RS101134702',
    'ru': '123456789047',
    'se': 'SE123456789701',
    'si': 'SI12345679',
    'sk': 'SK2022749619',
    'sm': 'SM24165',
    'th': '1234545678781',
    'tr': _lt('11111111111 (NIN) or 2222222222 (VKN)'),
    'ua': _lt('12345678 or UA12345678 (EDRPOU), 1234567890 (RNOPP) or 123456789012 (IPN)'),
    'uy': _lt("Example: '219999830019' (format: 12 digits, all numbers, valid check digit)"),
    'uz': _lt('XXXXXXXXX [9 digits]'),
    've': 'V-12345678-1, V123456781, V-12.345.678-1',
    'xi': 'XI123456782',
    'sa': _lt('310175397400003 [Fifteen digits, first and last digits should be "3"]'),
}


class ResPartner(models.Model):
    _inherit = 'res.partner'

    # Stored result of the VIES check for the current VAT: computed from `vat`
    # (_compute_vies_valid) but still user-editable (readonly=False) and tracked.
    vies_valid = fields.Boolean(
        string="Intra-Community Valid",
        compute='_compute_vies_valid', store=True, readonly=False,
        tracking=True,
        help='European VAT numbers are automatically checked on the VIES database.',
    )
    # Field representing whether vies_valid is relevant for selecting a fiscal position on this partner
    perform_vies_validation = fields.Boolean(compute='_compute_perform_vies_validation')
    # We put on inverse because a compute with a dependency to itself is not well managed in the ORM (it should be triggered first)
    # Writing either field runs _inverse_vat, i.e. the VAT check.
    country_id = fields.Many2one(inverse="_inverse_vat", store=True)
    vat = fields.Char(inverse="_inverse_vat", store=True)

    @api.model
    def _run_vat_checks(self, country, vat, partner_name='', validation='error'):
        """ OVERRIDE

        Format ``vat`` for ``country`` and, unless disabled, validate it.

        :param country: res.country record the number is checked against.
        :param vat: the VAT number as entered.
        :param partner_name: only used to build the error message.
        :param validation: ``'error'`` raises a ValidationError for an invalid
            number; ``'setnull'`` (or any other truthy value) returns ``''``
            instead; a falsy value, like the ``no_vat_validation`` context key,
            skips the validity check and only formats the number.
        :return: ``(vat, country_code)``: the formatted number (re-prefixed
            with the stripped country prefix, or ``''`` when invalid and not
            raising) and the code it was checked against, or ``False`` when no
            check applied.
        :raise ValidationError: when ``validation == 'error'`` and the number
            is invalid.
        """
        if not country or not vat:
            return vat, False
        # '/' is the explicit "no VAT" marker; any other single character is
        # handled according to `validation`.
        if len(vat) == 1:
            if vat == '/' or not validation:
                return vat, False
            if validation == 'setnull':
                return '', False
            if validation == 'error':
                raise ValidationError(_("To explicitly indicate no (valid) VAT, use '/' instead. "))
        vat_prefix, vat_number = self._split_vat(vat)

        if vat_prefix == 'EU' and country not in self.env.ref('base.europe').country_ids:
            # Foreign companies that trade with non-enterprises in the EU
            # may have a VATIN starting with "EU" instead of a country code.
            return vat, False

        # prefixed_country: country prefix stripped from the number, re-added
        #   to the returned value.
        # do_eu_check: the prefix designates an EU_PREFIX country while the
        #   partner's own country is not in that group; if the number fails for
        #   the partner's country it is re-checked for the prefix's country.
        do_eu_check = False
        prefixed_country = ''
        eu_prefix_country_group = self.env['res.country.group'].search([('code', '=', 'EU_PREFIX')], limit=1)
        # Map prefixes that differ from the ISO code back to the country code.
        country_code = EU_EXTRA_VAT_CODES_INV.get(vat_prefix, vat_prefix)
        if country_code in eu_prefix_country_group.country_ids.mapped('code'):
            if 'EU_PREFIX' in country.country_group_codes and vat_prefix:
                vat = vat_number
                prefixed_country = vat_prefix
            else:
                do_eu_check = True

        # NOTE(review): with an 'EL' prefix, code_to_check is 'EL' rather than
        # 'GR'; confirm that a check exists for that code (no check_vat_el is
        # visible here).
        code_to_check = prefixed_country or country.code
        vat = self._format_vat_number(code_to_check, vat)

        # Greek VAT numbers are prefixed with 'EL' rather than the ISO code 'GR'.
        if prefixed_country == 'GR':
            prefixed_country = 'EL'

        vat_to_return = prefixed_country + vat

        # The context key 'no_vat_validation' allows you to store/set a VAT number without doing validations.
        # This is for API pushes from external platforms where you have no control over VAT numbers.
        if not validation or self.env.context.get('no_vat_validation'):
            return vat_to_return, code_to_check

        # Avoid validating double prefix like BEBE0477472701
        double_prefix = prefixed_country and vat_to_return.startswith(prefixed_country + prefixed_country)
        if not self._check_vat_number(code_to_check, vat) or double_prefix:
            partner_label = _("partner [%s]", partner_name)
            if do_eu_check:
                # Retry as a VAT of the country designated by the prefix; if that
                # fails too, append that country's expected format to the error.
                # NOTE(review): _ref_vat[...] raises KeyError if the country has
                # no entry in _ref_vat.
                try:
                    return self._run_vat_checks(self.env['res.country'].search([('code', '=', country_code)], limit=1), vat_prefix + vat_number, partner_name, validation)
                except ValidationError:
                    msg = self._build_vat_error_message(code_to_check, vat, partner_label)
                    raise ValidationError(msg + "\n\n" + _('If you are trying to input a European number, this is the expected format: ') + _ref_vat[country_code.lower()])
            if validation == 'error':
                msg = self._build_vat_error_message(code_to_check, vat, partner_label)
                raise ValidationError(msg)
            else:
                return '', code_to_check
        return vat_to_return, code_to_check

    def _inverse_vat(self):
        """Inverse of both ``vat`` and ``country_id``: run the VAT check when either is written."""
        self._check_vat()

    @api.onchange('vat', 'country_id')
    def _onchange_vat(self):
        """Re-run the VAT check from the form view, with ``validation=False``."""
        # NOTE(review): _check_vat is defined outside this file; presumably
        # validation=False only normalizes the number without raising — confirm.
        self._check_vat(validation=False)

    @api.depends_context('company')
    @api.depends('vat')
    def _compute_perform_vies_validation(self):
        """ Determine whether to show VIES validity on the current VAT number.

        It is relevant only when the current company enables VIES checks
        (``vat_check_vies``) and the partner's VAT does not start with the
        company's fiscal country code.
        """
        # The current company is the same for every record: read it once.
        company = self.env.company
        company_code = company.account_fiscal_country_id.code
        check_vies = company.vat_check_vies
        for partner in self:
            vat = partner.vat
            # NOTE(review): the VAT prefix is compared to the ISO code as-is, so a
            # prefix that differs from it (see EU_EXTRA_VAT_CODES_INV, e.g. 'EL'
            # for 'GR') counts as foreign — confirm this is intended.
            partner.perform_vies_validation = bool(
                vat
                and vat[:2].upper() != company_code
                and check_vies
            )

    @api.depends('vat')
    def _compute_vies_valid(self):
        """ Check the VAT number with VIES, if enabled.

        When no company has ``vat_check_vies`` enabled, every partner is marked
        invalid without contacting IAP. Otherwise:

        - a partner without VAT is invalid;
        - a partner sharing its parent's VAT copies the parent's result;
        - any other partner is checked through IAP.
        """
        companies_checking = self.env['res.company'].sudo().search_count([('vat_check_vies', '=', True)])
        if not companies_checking:
            self.vies_valid = False
            return

        for partner in self:
            if not partner.vat:
                partner.vies_valid = False
            elif partner.parent_id and partner.parent_id.vat == partner.vat:
                # Same number as the parent: reuse its result instead of querying IAP again.
                partner.vies_valid = partner.parent_id.vies_valid
            else:
                iap_status = partner._check_vies_iap()
                partner._update_vies_status(iap_status)

    def _split_vat(self, vat):
        vat_prefix, vat_number = vat[:2].upper(), vat[2:].replace(' ', '')
        if not vat_prefix.isalpha():
            return '', vat
        return vat_prefix, vat_number

    @api.model
    def _get_iap_vies_credentials(self):
        """
        Return a couple (identifier, token) that is going to identify this db to IAP such that only
        this one can request updates on a previously asked VIES check.
        If they exist, we simply return them. If they don't, we create them in another cursor to
        avoid the current transaction to be rolled back after the record has been created on IAP.

        :return: tuple ``(identifier, token)`` of strings, read from or stored in the
            ``iap_vies.client_identifier`` / ``iap_vies.client_token`` system parameters;
            ``("dummy_identifier", "dummy_token")`` when the
            ``base_vat.vies_iap_check_update`` record does not exist.
        """
        # No existing cron = no way for db to pull updates, thus no need to bother IAP
        if not self.env.ref('base_vat.vies_iap_check_update', raise_if_not_found=False):
            return "dummy_identifier", "dummy_token"  # ignored by IAP, same as neutralized

        IrConfigParam = self.env['ir.config_parameter'].sudo()
        identifier = IrConfigParam.get_param('iap_vies.client_identifier')
        token = IrConfigParam.get_param('iap_vies.client_token')
        if identifier and token:
            return identifier, token

        # NOTE(review): two concurrent first calls may each generate and store
        # different credentials (last write wins) — confirm this is acceptable.
        identifier = str(uuid.uuid4())
        token = secrets.token_urlsafe()
        # Separate cursor so the parameters are persisted independently of the
        # current transaction (presumably committed when the block exits — confirm).
        with self.env.registry.cursor() as new_cursor:
            IrConfigParamNewCursor = self.env(cr=new_cursor)['ir.config_parameter'].sudo()
            IrConfigParamNewCursor.set_param('iap_vies.client_identifier', identifier)
            IrConfigParamNewCursor.set_param('iap_vies.client_token', token)

        return identifier, token

    @api.model
    def _get_iap_vies_endpoint(self):
        """ Return the base URL of the IAP VIES service.

        The default is the test server when the ``base.module_base_vat`` record
        has ``demo`` set, and the production server otherwise. The
        ``iap_vies.endpoint`` system parameter may override it, but only with
        one of those two URLs.

        :raise UserError: if the configured endpoint is not one of the allowed URLs.
        """
        prod = 'https://vies.api.odoo.com'
        test = 'https://vies.test.odoo.com'
        fallback = prod
        if self.env.ref('base.module_base_vat').demo:
            fallback = test
        endpoint = self.env['ir.config_parameter'].sudo().get_param('iap_vies.endpoint', fallback)
        if endpoint in (prod, test):
            return endpoint
        raise UserError(_('Invalid IAP VIES endpoint'))

    def _check_vies_iap(self):
        """Called when VAT is manually edited

        Ask the IAP VIES service about this partner's VAT. A webhook URL and a
        signed token (valid 24 hours) are sent along, presumably so IAP can push
        a later update for pending checks.

        :return: the ``status`` string returned by IAP, or ``"fault"`` when the
            request fails, the answer is not valid JSON, or it carries no status.
        """
        self.ensure_one()
        endpoint = self._get_iap_vies_endpoint()
        client_identifier, client_token = self._get_iap_vies_credentials()
        try:
            req = requests.post(
                endpoint + '/api/vies/1/check_validity',
                data={
                    "vat": self.vat,
                    "db_uuid": self.env['ir.config_parameter'].sudo().get_param('database.uuid'),
                    "client_identifier": client_identifier,
                    "client_token": client_token,
                    "webhook_url": self.get_base_url() + '/base_vat/1/webhook_update_vies',
                    "webhook_token": hash_sign(self.sudo().env, "vies_check", self.vat, expiration_hours=24),  # See BaseVatWebhookController
                },
                timeout=20,
            )
            req.raise_for_status()
            # Parse inside the try: a non-JSON body raises ValueError (requests'
            # JSONDecodeError is also a RequestException in requests >= 2.27).
            # Letting it escape would break the compute that calls this method.
            resp = req.json()
        except (requests.exceptions.RequestException, ValueError):
            _logger.exception("VIES check: call to IAP failed")
            return "fault"
        # Guard against valid JSON that is not an object (list, null, ...).
        if not isinstance(resp, dict) or not resp.get("status"):
            _logger.error("VIES check: no status returned. Response: %s", resp)
            return "fault"
        return resp["status"]

    @api.model
    def _cron_check_vies_iap(self):
        """Called by cron to check if IAP has any update on a previously requested VAT that was pending

        The answer is expected to map each VAT number to its new status. Every
        partner having exactly that VAT gets the status applied. Request errors
        and malformed answers are logged and the run is skipped.
        """
        endpoint = self._get_iap_vies_endpoint()
        client_identifier, client_token = self._get_iap_vies_credentials()
        try:
            req = requests.post(
                endpoint + '/api/vies/1/check_update',
                data={
                    "db_uuid": self.env['ir.config_parameter'].sudo().get_param('database.uuid'),
                    "client_identifier": client_identifier,
                    "client_token": client_token,
                },
                timeout=10,
            )
            req.raise_for_status()
            # Parse inside the try: a non-JSON body raises ValueError (requests'
            # JSONDecodeError is also a RequestException in requests >= 2.27).
            resp = req.json()
        except (requests.exceptions.RequestException, ValueError):
            _logger.exception("Error while contacting IAP VIES")
            return
        _logger.info("IAP VIES check response: %s", resp)
        if not isinstance(resp, dict):
            _logger.error("IAP VIES check: unexpected response, expected a mapping: %s", resp)
            return
        for company_vat, company_status in resp.items():
            partners = self.search([("vat", "=", company_vat)])
            partners._update_vies_status(company_status)

    def _update_vies_status(self, status):
        """ Apply a VIES ``status`` to the partners in ``self``.

        ``vies_valid`` becomes True only for ``"valid"``. For the known statuses
        (pending, fault, valid, unassigned), a chatter note is logged on every
        partner that has an ``_origin`` id. Any other status logs nothing.
        """
        self.vies_valid = status == "valid"
        _logger.info("VIES status updated to %s for partner ids: %s", status, self.ids)
        if status == "pending":
            msg = _("The VIES check is pending. The status will be updated soon.")
        elif status == "fault":
            msg = _("The VIES check failed. Please check the Tax ID manually.")
        elif status in ("valid", "unassigned"):
            msg = _("The Intra-Community validity has been updated.")
        else:
            return
        bodies = {}
        for partner in self:
            origin_id = partner._origin.id
            if origin_id:
                bodies[origin_id] = msg
        self._message_log_batch(bodies=bodies)

    @api.model
    def _check_vat_number(self, country_code, vat_number):
        ''' Low-level method directly calling stdnum or our own specific method. '''
        check_func_name = 'check_vat_' + country_code.lower()
        check_func = getattr(self, check_func_name, None) or getattr(stdnum.util.get_cc_module(country_code, 'vat'), 'is_valid', None)
        return check_func(vat_number) if check_func else True

    @api.model
    def _build_vat_error_message(self, country_code, wrong_vat, record_label):
        """ Build the user-facing message for an invalid VAT number.

        :param country_code: code the number was checked against; may be empty.
        :param wrong_vat: the rejected number, echoed in the message.
        :param record_label: label of the record (e.g. "partner [name]"). A label
            containing 'False' is treated as belonging to an unnamed record
            (public user), and is omitted.
        :return: the message, prefixed with a newline. When ``_ref_vat`` knows
            the country, it ends with the expected format.
        """
        # OVERRIDE account
        if self.env.context.get('company_id'):
            company = self.env['res.company'].browse(self.env.context['company_id'])
        else:
            company = self.env.company

        vat_label = _("VAT")
        if country_code and company.country_id and country_code == company.country_id.code and company.country_id.vat_label:
            vat_label = company.country_id.vat_label

        # country_code may be empty (see the guard above): calling .lower() on it
        # would crash, so only look up an expected format when a code is given.
        expected_format = _ref_vat.get(country_code.lower()) if country_code else None
        expected_note = ""
        if expected_format:
            expected_note = ' \n' + _(
                'Note: the expected format is %(expected_format)s',
                expected_format=expected_format,
            )

        # Catch use case where the record label is about the public user (name: False)
        if 'False' not in record_label:
            return '\n' + _(
                'The %(vat_label)s number [%(wrong_vat)s] for %(record_label)s does not seem to be valid. %(expected_note)s',
                vat_label=vat_label,
                wrong_vat=wrong_vat,
                record_label=record_label,
                expected_note=expected_note
            )
        else:
            return '\n' + _(
                'The %(vat_label)s number [%(wrong_vat)s] does not seem to be valid. %(expected_note)s',
                vat_label=vat_label,
                wrong_vat=wrong_vat,
                expected_note=expected_note,
            )

    _check_vat_al_re = re.compile(r'^[JKLM][0-9]{8}[A-Z]$')

    def check_vat_al(self, vat):
        """Check Albania VAT number.

        The number is compacted with python-stdnum's Albanian VAT module and must
        then be 10 characters: a letter J-M, eight digits and a final letter.
        """
        compact = stdnum.util.get_cc_module('al', 'vat').compact
        number = compact(vat)
        if len(number) != 10:
            return False
        return self._check_vat_al_re.match(number)

    def check_vat_jp(self, vat):
        """Check Japan VAT number: an optional leading 'T' is dropped before python-stdnum validates it."""
        number = vat[1:] if vat and vat[0] == 'T' else vat
        return stdnum.util.get_cc_module('jp', 'vat').is_valid(number)

    # Romanian natural-person identifiers, used by check_vat_ro.
    _check_tin1_ro_natural_persons = re.compile(r'[1-9]\d{2}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\d{6}')
    _check_tin2_ro_natural_persons = re.compile(r'9000\d{9}')

    def check_vat_do(self, vat):
        """Check Dominican Republic number: valid either as a VAT number or as a cédula (python-stdnum)."""
        checkers = (
            stdnum.util.get_cc_module("do", "vat").is_valid,
            stdnum.util.get_cc_module("do", "cedula").is_valid,
        )
        return any(is_valid(vat) for is_valid in checkers)

    def check_vat_ro(self, vat):
        """
            Check Romanian VAT number that can be for example 'RO1234567897 or 'xyyzzaabbxxxx' or '9000xxxxxxxx'.

            - For xyyzzaabbxxxx, 'x' can be any number, 'y' is the two last digit of a year (in the range 00…99),
              'a' is a month, b is a day of the month, the number 8 and 9 are Country or district code
              (For those twos digits, we decided to let some flexibility  to avoid complexifying the regex and also
              for maintainability)
            - 9000xxxxxxxx, start with 9000 and then is filled by number In the range 0...9

            Also stdum also checks the CUI or CIF (Romanian company identifier). So a number like '123456897' will pass.
        """
        tin1 = self._check_tin1_ro_natural_persons.match(vat)
        if tin1:
            return True
        tin2 = self._check_tin2_ro_natural_persons.match(vat)
        if tin2:
            return True
        # Check the vat number
        return stdnum.util.get_cc_module('ro', 'vat').is_valid(vat)

    def check_vat_gr(self, vat):
        """ Allows some custom test VAT number to be valid to allow testing Greece EDI. """
        greece_test_vats = ('047747270', '047747210', '047747220', '117747270', '127747270')
        if vat in greece_test_vats:
            return True
        return stdnum.util.get_cc_module('gr', 'vat').is_valid(vat)

    # Our EDI provider Infile has designated this range of testing VATs for our customers.
    __check_vat_gt_testing_infile = re.compile(r'98[0-9]{10}K')

    def check_vat_gt(self, vat):
        """
        Allow some custom Guatemala NIT numbers to pass the test to be used for testing the Guatemalan EDI.
        """
        guatemalan_test_vats = ('11201220K', '11201350K')
        if vat in guatemalan_test_vats or self.__check_vat_gt_testing_infile.match(vat):
            return True
        return stdnum.util.get_cc_module('gt', 'vat').is_valid(vat)

    _check_tin_hu_individual_re = re.compile(r'^8\d{9}$')
    _check_tin_hu_companies_re = re.compile(r'^\d{8}-?[1-5]-?\d{2}$')
    _check_tin_hu_european_re = re.compile(r'^\d{8}$')

    def check_vat_hu(self, vat):
        """
            Check Hungary VAT number that can be for example 'HU12345676 or 'xxxxxxxx-y-zz' or '8xxxxxxxxy'

            - For xxxxxxxx-y-zz, 'x' can be any number, 'y' is a number between 1 and 5 depending on the person and the 'zz'
              is used for region code.
            - 8xxxxxxxxy, Tin number for individual, it has to start with an 8 and finish with the check digit
            - In case of EU format it will be the first 8 digits of the full VAT
        """
        companies = self._check_tin_hu_companies_re.match(vat)
        if companies:
            return True
        individual = self._check_tin_hu_individual_re.match(vat)
        if individual:
            return True
        european = self._check_tin_hu_european_re.match(vat)
        if european:
            return True
        # Check the vat number
        return stdnum.util.get_cc_module('hu', 'vat').is_valid(vat)

    _check_vat_ch_re = re.compile(r'E([0-9]{9}|-[0-9]{3}\.[0-9]{3}\.[0-9]{3})( )?(MWST|TVA|IVA)$')

    def check_vat_ch(self, vat):
        '''
        Check Switzerland VAT number.
        '''
        # A new VAT number format in Switzerland has been introduced between 2011 and 2013
        # https://www.estv.admin.ch/estv/fr/home/mehrwertsteuer/fachinformationen/steuerpflicht/unternehmens-identifikationsnummer--uid-.html
        # The old format "TVA 123456" is not valid since 2014
        # Accepted format are: (spaces are ignored)
        #     CHE#########MWST
        #     CHE#########TVA
        #     CHE#########IVA
        #     CHE-###.###.### MWST
        #     CHE-###.###.### TVA
        #     CHE-###.###.### IVA
        #
        # /!\ The english abbreviation VAT is not valid /!\

        match = self._check_vat_ch_re.match(vat)
        if match:
            # For new TVA numbers, the last digit is a MOD11 checksum digit build with weighting pattern: 5,4,3,2,7,6,5,4
            num = [s for s in match.group(1) if s.isdigit()]        # get the digits only
            factor = (5, 4, 3, 2, 7, 6, 5, 4)
            csum = sum([int(num[i]) * factor[i] for i in range(8)])
            check = (11 - (csum % 11)) % 11
            return check == int(num[8])
        return False


    def is_valid_ruc_ec(self, vat):
        if len(vat) in (10, 13) and vat.isdecimal():
            return True
        return False

    def check_vat_ec(self, vat):
        vat = clean(vat, ' -.').upper().strip()
        return self.is_valid_ruc_ec(vat)

    def _ie_check_char(self, vat):
        vat = vat.zfill(8)
        extra = 0
        if vat[7] not in ' W':
            if vat[7].isalpha():
                extra = 9 * (ord(vat[7]) - 64)
            else:
                # invalid
                return -1
        checksum = extra + sum((8-i) * int(x) for i, x in enumerate(vat[:7]))
        return 'WABCDEFGHIJKLMNOPQRSTUV'[checksum % 23]

    # TODO: remove in master
    def check_vat_ie(self, vat):
        """Check Ireland VAT number by delegating to python-stdnum's Irish VAT module."""
        return stdnum.util.get_cc_module('ie', 'vat').is_valid(vat)

    # Mexican VAT verification, contributed by Vauxoo
    # and Panos Christeas <p_christ@hol.gr>
    _check_vat_mx_re = re.compile(r"(?P<primeras>[A-Za-z\xd1\xf1&]{3,4})"
                                   r"[ \-_]?"
                                   r"(?P<ano>[0-9]{2})(?P<mes>[01][0-9])(?P<dia>[0-3][0-9])"
                                   r"[ \-_]?"
                                   r"(?P<code>[A-Za-z0-9&\xd1\xf1]{3})")

    def check_vat_mx(self, vat):
        ''' Mexican VAT verification

        Verificar RFC México
        '''
        m = self._check_vat_mx_re.fullmatch(vat)
        if not m:
            #No valid format
            return False
        ano = int(m['ano'])
        if ano > 30:
            ano = 1900 + ano
        else:
            ano = 2000 + ano
        try:
            datetime.date(ano, int(m['mes']), int(m['dia']))
        except ValueError:
            return False

        # Valid format and valid date
        return True

    # Norway VAT validation, contributed by Rolv Råen (adEgo) <rora@adego.no>
    # Support for MVA suffix contributed by Bringsvor Consulting AS (bringsvor@bringsvor.com)
    def check_vat_no(self, vat):
        """
        Check Norway VAT number. See http://www.brreg.no/english/coordination/number.html

        Accepts a 9-digit organization number, optionally followed by the 'MVA'
        suffix (any case). The 9th digit is a MOD11 check digit with weights
        3, 2, 7, 6, 5, 4, 3, 2. A computed check value of 10 is never valid.

        :return: True if valid, False otherwise (malformed input never raises).
        """
        if len(vat) == 12 and vat.upper().endswith('MVA'):
            vat = vat[:-3]  # Strictly speaking we should enforce the suffix MVA but...

        # isdecimal() rather than int(vat): int() also accepts a sign, surrounding
        # whitespace and underscores, which then made the per-digit int() raise.
        if len(vat) != 9 or not vat.isdecimal():
            return False

        weights = (3, 2, 7, 6, 5, 4, 3, 2)
        total = sum(weight * int(digit) for weight, digit in zip(weights, vat))
        check = 11 - (total % 11)
        if check == 11:
            check = 0
        if check == 10:
            # 10 is not a valid check digit for an organization number
            return False
        return check == int(vat[8])

    # Peruvian VAT validation, contributed by Vauxoo
    def check_vat_pe(self, vat):
        """Check Peru VAT number: 11 digits, the last one being a MOD11 check digit.

        The first ten digits are weighted 5, 4, 3, 2, 7, 6, 5, 4, 3, 2. The check
        digit is ``11 - total % 11``, where 10 maps to 0 and 11 maps to 1.
        """
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as '²' that int() cannot convert, which made the checksum raise.
        if len(vat) != 11 or not vat.isdecimal():
            return False
        weights = (5, 4, 3, 2, 7, 6, 5, 4, 3, 2)
        dig_check = 11 - (sum(weight * int(digit) for weight, digit in zip(weights, vat)) % 11)
        if dig_check == 10:
            dig_check = 0
        elif dig_check == 11:
            dig_check = 1
        return int(vat[10]) == dig_check

    # Philippines TIN (+ branch code) validation
    _check_vat_ph_re = re.compile(r"\d{3}-\d{3}-\d{3}(-\d{3,5})?$")

    def check_vat_ph(self, vat):
        return len(vat) >= 11 and len(vat) <= 17 and self._check_vat_ph_re.match(vat)

    def check_vat_ru(self, vat):
        '''
        Check Russia VAT number (INN).
        Method originally copied from vatnumber 1.2 lib https://code.google.com/archive/p/vatnumber/

        - 10 digits: the last digit is (weighted sum of the first 9) % 11 % 10.
        - 12 digits: the 11th and 12th digits are computed the same way over the
          first 10 and first 11 digits respectively.

        The final "% 10" also applies to 12-digit numbers: a weighted sum
        congruent to 10 modulo 11 gives the check digit 0.
        '''
        # isdecimal() rather than int(vat): int() also accepts a sign, surrounding
        # whitespace and underscores, which then made the per-digit int() raise.
        if len(vat) not in (10, 12) or not vat.isdecimal():
            return False

        def check_digit(weights):
            # zip() stops at the shortest input, so only len(weights) digits are used.
            return sum(weight * int(digit) for weight, digit in zip(weights, vat)) % 11 % 10

        if len(vat) == 10:
            return check_digit((2, 4, 10, 3, 5, 9, 4, 6, 8)) == int(vat[9])
        return (
            check_digit((7, 2, 4, 10, 3, 5, 9, 4, 6, 8)) == int(vat[10])
            and check_digit((3, 7, 2, 4, 10, 3, 5, 9, 4, 6, 8)) == int(vat[11])
        )

    # VAT validation in Serbia
    def check_vat_rs(self, vat):
        """Check Serbia VAT number with python-stdnum, ignoring a leading 'RS' prefix."""
        number = vat[2:] if vat.startswith('RS') else vat
        return stdnum.util.get_cc_module('rs', 'vat').is_valid(number)

    # VAT validation in Turkey
    def check_vat_tr(self, vat):
        """Accept Turkey number valid either as a T.C. Kimlik number (NIN) or as a VKN (python-stdnum)."""
        for module_name in ('tckimlik', 'vkn'):
            if stdnum.util.get_cc_module('tr', module_name).is_valid(vat):
                return True
        return False

    _check_vat_sa_re = re.compile(r"^3[0-9]{13}3$")

    # Saudi Arabia TIN validation
    def check_vat_sa(self, vat):
        """
            Check company VAT TIN according to ZATCA specifications: The VAT number should start and begin with a '3'
            and be 15 digits long
        """
        return self._check_vat_sa_re.match(vat) or False

    def check_vat_ua(self, vat):
        """Accept Ukraine numbers of 8, 10 or 12 characters, not counting an optional 'UA' prefix."""
        number = vat.removeprefix('UA')
        return len(number) in (8, 10, 12)

    def check_vat_uy(self, vat):
        """ Taken from python-stdnum's master branch, as the release doesn't handle RUT numbers starting with 22.
        origin https://github.com/arthurdejong/python-stdnum/blob/master/stdnum/uy/rut.py
        FIXME Can be removed when python-stdnum does a new release.

        After removing spaces, dashes and an optional 'UY' prefix, the number must be
        12 digits: a 01-22 prefix, six digits that are not all zero, '001', and a
        MOD11 check digit.
        """
        number = clean(vat, ' -').upper().strip().removeprefix('UY')

        if not (number.isdigit() and len(number) == 12):  # InvalidFormat / InvalidLength
            return False
        if not ('01' <= number[:2] <= '22' and number[2:8] != '000000' and number[8:11] == '001'):  # InvalidComponent
            return False

        # Invalid Check Digit
        weights = (4, 3, 2, 9, 8, 7, 6, 5, 4, 3, 2)
        total = sum(weight * int(digit) for weight, digit in zip(weights, number))
        return number[-1] == str(-total % 11)

    def check_vat_uz(self, vat):
        """Accept Uzbekistan numbers made of exactly nine digits."""
        return vat.isdigit() and len(vat) == 9

    def check_vat_ve(self, vat):
        """Check Venezuela VAT number: kind letter, 8-digit identifier, check digit.

        The identifier is written plain, as '-XXXXXXXX-' or as '-XX.XXX.XXX-'
        (dots only together with dashes). The check digit is a MOD11 over the
        kind value (weight 4) and the identifier digits (weights 3, 2, 7, 6, 5, 4, 3, 2).
        """
        # https://tin-check.com/en/venezuela/
        # https://techdocs.broadcom.com/us/en/symantec-security-software/information-security/data-loss-prevention/15-7/About-content-packs/What-s-included-in-Content-Pack-2021-02/Updated-data-identifiers-in-Content-Pack-2021-02/venezuela-national-identification-number-v115451096-d327e108002-CP2021-02.html
        # Sources last visited on 2022-12-09

        # VAT format: (kind - 1 letter)(identifier number - 8-digit number)(check digit - 1 digit)
        vat_regex = re.compile(r"""
            ([vecjpg])                          # group 1 - kind
            (
                (?P<optional_1>-)?                      # optional '-' (1)
                [0-9]{2}
                (?(optional_1)(?P<optional_2>[.])?)     # optional '.' (2) only if (1)
                [0-9]{3}
                (?(optional_2)[.])                      # mandatory '.' if (2)
                [0-9]{3}
                (?(optional_1)-)                        # mandatory '-' if (1)
            )                                   # group 2 - identifier number
            ([0-9]{1})                          # group X - check digit
        """, re.VERBOSE | re.IGNORECASE)

        found = vat_regex.fullmatch(vat)
        if found is None:
            return False

        groups = found.groups()
        kind = groups[0].lower()
        identifier_digits = groups[1].replace("-", "").replace(".", "")
        expected_digit = int(groups[-1])

        # v: Venezuela citizenship, e: foreigner, c / j: township/communal council or legal entity,
        # p: passport, anything else (only 'g' can match): government.
        kind_digit = {'v': 1, 'e': 2, 'c': 3, 'j': 3, 'p': 4}.get(kind, 5)

        # === Checksum validation ===
        weights = (3, 2, 7, 6, 5, 4, 3, 2)
        checksum = kind_digit * 4 + sum(int(digit) * weight for digit, weight in zip(identifier_digits, weights))
        computed_digit = 11 - checksum % 11
        return expected_digit == (0 if computed_digit > 9 else computed_digit)

    def check_vat_in(self, vat):
        """Check India GSTIN: 15 characters matching one of the known GSTIN layouts."""
        #reference from https://www.gstzen.in/a/format-of-a-gst-number-gstin.html
        if not vat or len(vat) != 15:
            return False
        all_gstin_re = [
            r'[0-9]{2}[a-zA-Z]{5}[0-9]{4}[a-zA-Z]{1}[1-9A-Za-z]{1}[Zz1-9A-Ja-j]{1}[0-9a-zA-Z]{1}', # Normal, Composite, Casual GSTIN
            r'[0-9]{4}[A-Z]{3}[0-9]{5}[UO]{1}[N][A-Z0-9]{1}', #UN/ON Body GSTIN
            r'[0-9]{4}[A-Z]{3}[0-9]{5}[A-Z]{3}',  # Revised NRI GSTIN
            r'[0-9]{4}[a-zA-Z]{3}[0-9]{5}[N][R][0-9a-zA-Z]{1}', #NRI GSTIN
            r'[0-9]{2}[a-zA-Z]{4}[a-zA-Z0-9]{1}[0-9]{4}[a-zA-Z]{1}[1-9A-Za-z]{1}[DK]{1}[0-9a-zA-Z]{1}', #TDS GSTIN
            r'[0-9]{2}[a-zA-Z]{5}[0-9]{4}[a-zA-Z]{1}[1-9A-Za-z]{1}[C]{1}[0-9a-zA-Z]{1}' #TCS GSTIN
        ]
        for pattern in all_gstin_re:
            if re.match(pattern, vat):
                return True
        return False

    def check_vat_br(self, vat):
        """Accept Brazil number valid either as a CPF or as a CNPJ (python-stdnum)."""
        cpf_module = stdnum.get_cc_module('br', 'cpf')
        cnpj_module = stdnum.get_cc_module('br', 'cnpj')
        return cpf_module.is_valid(vat) or cnpj_module.is_valid(vat)

    _check_vat_cr_re = re.compile(r'^(?:[1-9]\d{8}|\d{10}|[1-9]\d{10,11})$')

    def check_vat_cr(self, vat):
        # CÉDULA FÍSICA: 9 digits
        # CÉDULA JURÍDICA: 10 digits
        # CÉDULA DIMEX: 11 or 12 digits
        # CÉDULA NITE: 10 digits

        return self._check_vat_cr_re.match(vat) or False

    __check_vat_vn_re = re.compile(r'^\d{10}(?:-?\d{3})?$|^\d{12}$')
    __check_vat_vn_companies_re = re.compile(r'^\d{10}(?:-?\d{3})?$')

    def check_vat_vn(self, vat):
        """Format-only validator for Vietnamese tax identifiers.

        Accepted shapes (surrounding whitespace is ignored):
        - 10 digits, enterprise tax ID, e.g. 0101243150
        - 10 digits plus a 3-digit branch suffix, with or without a dash,
          e.g. 0101243150-001
        - 12 digits, personal / citizen ID (CCCD), e.g. 079123456789,
          used as tax ID for individuals from July 1st, 2025

        Only the shape is checked, there is no checksum. This helper exists
        because stdnum.vn.mst.validate() only handles the 10- and 13-digit
        forms and rejects the 12-digit CCCD; it can be dropped once stdnum
        supports that format.

        :param str vat: the number to check
        :return: True if ``vat`` has one of the accepted shapes
        :rtype: bool
        """
        candidate = vat.strip()
        found = self.__check_vat_vn_re.match(candidate)
        return found is not None

    def format_vat_eu(self, vat):
        """Return ``vat`` untouched.

        Foreign companies that trade with non-enterprises in the EU may hold a
        VATIN starting with "EU" instead of a country code. Because this method
        exists, ``_format_vat_number`` picks it for the 'eu' code instead of
        any stdnum ``compact`` function, so such numbers are kept as entered.
        """
        unchanged = vat
        return unchanged

    def format_vat_ch(self, vat):
        """Format a Swiss VAT number with stdnum's ``ch.vat`` formatter.

        stdnum works on the prefixed number, so 'CH' is prepended before
        formatting and the first two characters are dropped from the result.
        """
        ch_vat = stdnum.util.get_cc_module('ch', 'vat')
        formatted = ch_vat.format('CH' + vat)
        return formatted[2:]

    def format_vat_cl(self, vat):
        """Normalise a Chilean VAT number so it always carries a dash.

        Dots, a literal (upper-case) 'CL', spaces and dashes are removed in
        that order, the result is upper-cased, and when more than two
        characters remain a '-' is re-inserted before the last one.
        """
        compact = vat
        for token in ('.', 'CL', ' ', '-'):
            compact = compact.replace(token, '')
        compact = compact.upper()
        if len(compact) <= 2:
            return compact
        return compact[:-1] + '-' + compact[-1]

    def format_vat_co(self, vat):
        """Format a Colombian VAT number as digits followed by '-<last char>'.

        stdnum's ``co.vat`` formatter is applied first; its dots and dashes are
        then stripped, and a single '-' is placed before the last character when
        more than two characters remain.
        """
        co_vat = stdnum.util.get_cc_module('co', 'vat')
        digits = co_vat.format(vat)
        for separator in ('.', '-'):
            digits = digits.replace(separator, '')
        if len(digits) <= 2:
            return digits
        return digits[:-1] + '-' + digits[-1]

    def format_vat_vn(self, vat):
        """Format Vietnamese enterprise tax IDs with stdnum; keep others as-is.

        Only numbers with the enterprise shape (10 digits, optionally followed
        by a 3-digit branch suffix) go through stdnum's ``vn.vat`` formatter.
        Anything else, such as the 12-digit personal ID, is returned unchanged.
        """
        vn_vat_format = stdnum.util.get_cc_module('vn', 'vat').format
        if not self.__check_vat_vn_companies_re.match(vat):
            return vat
        return vn_vat_format(vat)

    def format_vat_hu(self, vat):
        """Compact a Hungarian VAT number and, for company tax numbers, re-add dashes.

        The dashes are needed for the EDI, and the separated parts make the
        number easier to read. The result has the form ``XXXXXXXX-X-XX``.
        Numbers that do not match ``_check_tin_hu_companies_re`` are returned
        in their compacted form.
        """
        compact_hu = stdnum.util.get_cc_module('hu', 'vat').compact
        vat = compact_hu(vat)
        if not self._check_tin_hu_companies_re.match(vat):
            return vat
        return f'{vat[:8]}-{vat[8]}-{vat[9]}{vat[10]}'

    def check_vat_id(self, vat):
        """ Temporary Indonesian VAT validation to support the new format
        introduced in January 2024.

        After removing spaces, dashes and dots, the number must be made of 15
        or 16 ASCII digits:

        - 15 digits (old numbers): the first 9 digits must pass the Luhn
          checksum, i.e. the 9th digit is the check digit.
        - 16 digits starting with '0': an old 15-digit number left-padded with
          a 0; the same Luhn check is applied after dropping that 0.
        - 16 digits not starting with '0': newly created numbers carry no
          checksum, so only the length/digit format is verified.

        :param str vat: the number to check
        :return: True if the number is valid, False otherwise
        :rtype: bool
        """
        vat = clean(vat, ' -.').strip()

        # ``isdecimal`` alone also accepts non-ASCII decimal digits (e.g.
        # Arabic-Indic), which are not valid characters in this number.
        if len(vat) not in (15, 16) or not (vat.isascii() and vat.isdecimal()):
            return False

        # Newly created 16-digit numbers have no checksum.
        if len(vat) == 16 and vat[0] != '0':
            return True

        # Old numbers: Luhn over the first 9 significant digits, skipping the
        # padding 0 of the 16-digit form.
        luhn_part = vat[0:9] if len(vat) == 15 else vat[1:10]
        try:
            luhn.validate(luhn_part)
        except (InvalidFormat, InvalidChecksum):
            return False

        return True

    def check_vat_th(self, vat):
        """Validate a Thai VAT number using stdnum's ``th.tin`` validator."""
        thai_tin = stdnum.util.get_cc_module('th', 'tin')
        return thai_tin.is_valid(vat)

    def check_vat_de(self, vat):
        """Validate a German number with stdnum's ``de.vat`` or ``de.stnr`` validator.

        The number is accepted if either validator accepts it; ``de.vat`` is
        tried first.
        """
        validators = tuple(
            stdnum.util.get_cc_module("de", module_name).is_valid
            for module_name in ("vat", "stnr")
        )
        return any(is_valid(vat) for is_valid in validators)

    def check_vat_il(self, vat):
        """Validate an Israeli VAT number using stdnum's ``il.idnr`` validator."""
        return stdnum.util.get_cc_module('il', 'idnr').is_valid(vat)

    def check_vat_ma(self, vat):
        """Check the format of a Moroccan VAT number.

        :param str vat: the number to check
        :return: True if ``vat`` is made of exactly 8 ASCII digits
        :rtype: bool
        """
        # ``str.isdigit`` alone also accepts non-ASCII digits (Arabic-Indic,
        # superscripts, ...), which must not be stored as a tax number.
        return len(vat) == 8 and vat.isascii() and vat.isdigit()

    def format_vat_sm(self, vat):
        """Normalise a San Marino VAT number with stdnum's ``compact``.

        Despite the method name, stdnum's ``compact`` (not ``format``) is used.
        The 'SM' prefix is prepended for stdnum and the first two characters
        are dropped from its output.
        """
        sm_compact = stdnum.util.get_cc_module('sm', 'vat').compact
        compacted = sm_compact('SM' + vat)
        return compacted[2:]

    def check_vat_tw(self, vat):
        """
        Since Feb. 2025, due to the imminent exhaustion of the UBN numbers, the validation logic was changed from using
        a division by 10 for the final check to using a division by 5, making numbers that were previously invalid now
        valid.

        The stdnum implementation of the VAT validation is not up to date with this latest update, so we implement our
        own validation to support these new valid UBNs.

        :param str vat: the number to check
        :return: True if ``vat`` (once compacted) is an 8-digit UBN with a valid checksum
        :rtype: bool
        """
        vat = stdnum.util.get_cc_module("tw", "vat").compact(vat)
        # The length is fixed, and the checksum below needs ASCII digits only: a non-digit character would make
        # ``int()`` raise ValueError instead of rejecting the number, and non-ASCII digits are not valid in a UBN.
        if len(vat) != 8 or not (vat.isascii() and vat.isdigit()):
            return False

        logic_multiplier = [1, 2, 1, 2, 1, 2, 4, 1]  # This multiplier is set by the official validation logic.
        # Multiply each of the 8 digits of the VAT number by the corresponding digit of the logic multiplier.
        # For the next steps, we will need to sum the results.
        # For a two-digit product like 20, you would add its digits (2 + 0) to the total sum, so we convert the
        # products to strings in order to make it easier later on.
        products = [str(a * int(b)) for a, b in zip(logic_multiplier, vat)]
        if vat[6] != '7':
            # If the 7th number is not 7, we simply sum everything and check that the result is divisible by 5.
            checksum = sum(int(d) for d in ''.join(products))
            return checksum % 5 == 0
        # If the 7th number is 7, we calculate two sums:
        # z1: the total sum where the 7th position's contribution is taken as 1.
        # z2: the total sum where the 7th position's contribution is taken as 0.
        # The VAT number is valid if either z1 or z2 (or both) is evenly divisible by 5.
        base_checksum = sum(int(d) for d in ''.join(products[0:6] + products[7:]))
        return (base_checksum + 1) % 5 == 0 or base_checksum % 5 == 0

    @api.model
    def _format_vat_number(self, country_code, vat):
        """Format ``vat`` for ``country_code`` using the most specific formatter available.

        A ``format_vat_<country code>`` method on this model (including ones
        added by localization modules) takes priority. Otherwise stdnum's
        ``compact`` for that country is used when it exists. With neither,
        ``vat`` is returned unchanged.
        """
        stdnum_module = stdnum.util.get_cc_module(country_code, 'vat')
        stdnum_compact = getattr(stdnum_module, 'compact', None)
        own_formatter = getattr(self, 'format_vat_' + country_code.lower(), None)
        formatter = own_formatter or stdnum_compact
        return formatter(vat) if formatter else vat

    @api.model
    def _convert_hu_local_to_eu_vat(self, local_vat):
        """Derive an EU VAT number from a Hungarian company tax number.

        :param str local_vat: the local Hungarian tax number
        :return: 'HU' followed by the first 8 characters of ``local_vat`` when it
            matches ``_check_tin_hu_companies_re``, otherwise False
        """
        if not self._check_tin_hu_companies_re.match(local_vat):
            return False
        return 'HU' + local_vat[:8]

    def _get_vat_required_valid(self, company=None):
        """Extend the parent requirement with the VIES check where it applies.

        The VIES result (``vies_valid``) is also required when all of these hold:
        ``company`` has a country, VIES validation is enabled for this partner
        in that company's context, and either the company's country belongs to
        the 'EU' country group or the partner's country has a foreign fiscal
        position. In every other case the parent's result is returned as-is.
        """
        # OVERRIDE
        required_valid = super()._get_vat_required_valid(company=company)
        vies_applies = (
            company and company.country_id and self.with_company(company).perform_vies_validation
            and (
                'EU' in company.country_id.country_group_codes
                or (self.country_id and self.country_id.has_foreign_fiscal_position)
            )
        )
        if not vies_applies:
            return required_valid
        return required_valid and self.vies_valid

    @api.model_create_multi
    def create(self, vals_list):
        """Create partners; when importing a file, drop the pending ``vies_valid`` recomputation.

        Presumably this avoids running VIES checks for every imported row — TODO confirm.
        """
        partners = super().create(vals_list)
        importing = self.env.context.get('import_file')
        if importing:
            partners.env.remove_to_compute(self._fields['vies_valid'], partners)
        return partners

    def write(self, vals):
        """Write ``vals``; when importing a file, drop the pending ``vies_valid`` recomputation."""
        result = super().write(vals)
        if not self.env.context.get('import_file'):
            return result
        self.env.remove_to_compute(self._fields['vies_valid'], self)
        return result

    def _create_contact_parent_company(self):
        """Create the parent company and carry over a truthy ``vies_valid`` from this contact.

        The parent's own pending ``vies_valid`` recomputation is removed first,
        so the copied value is kept.
        """
        parent = super()._create_contact_parent_company()
        if not (parent and self.vies_valid):
            return parent
        parent.env.remove_to_compute(self._fields['vies_valid'], parent)
        parent.vies_valid = self.vies_valid
        return parent
