# -*- coding: utf-8 -*-
# Part of Odoo. See LICENSE file for full copyright and licensing details.

import threading
from concurrent.futures import ThreadPoolExecutor

import psycopg2.errors

from odoo import api
from odoo.modules.registry import Registry
from odoo.tests.common import get_db_name, tagged, BaseCase
from odoo.tools import mute_logger


@tagged('-standard', '-at_install', 'post_install')
class TestOnboardingConcurrency(BaseCase):
    """Check concurrent creation of the progress record of an onboarding.

    The tests work on a single onboarding (``is_per_company`` is False)
    created once for the whole class. Every database access goes through
    its own cursor obtained from the registry, so that each worker thread
    runs in a separate transaction.
    """

    @classmethod
    def setUpClass(cls):
        """Load the registry and create the onboarding used by the tests."""
        super().setUpClass()
        cls.registry = Registry(get_db_name())
        # Register the cleanup before creating anything, so that it is
        # already in place once the record exists.
        cls.addClassCleanup(cls.cleanUpClass)

        onboarding_values = {
            'name': 'Test Onboarding Concurrent',
            'is_per_company': False,
            'route_name': 'onboarding_concurrent',
        }
        with cls.registry.cursor() as cr:
            env = api.Environment(cr, api.SUPERUSER_ID, {})
            onboarding = env['onboarding.onboarding'].create([onboarding_values])
            cls.onboarding_id = onboarding.id

    @classmethod
    def cleanUpClass(cls):
        """Delete the test onboarding and any progress record linked to it."""
        with cls.registry.cursor() as cr:
            env = api.Environment(cr, api.SUPERUSER_ID, {})
            env['onboarding.onboarding'].browse(cls.onboarding_id).unlink()
            remaining_progress = env['onboarding.progress'].search([
                ('onboarding_id', '=', cls.onboarding_id),
            ])
            remaining_progress.unlink()

    # NOTE(review): the expected UniqueViolation is presumably reported by
    # the 'odoo.sql_db' logger, hence the mute -- confirm.
    @mute_logger('odoo.sql_db')
    def test_concurrent_create_progress(self):
        """Two simultaneous creations yield one record and one UniqueViolation.

        Both workers first check, in their own transaction, that no progress
        record exists, then wait for each other on a barrier before calling
        ``_create_progress()``.
        """
        sync_point = threading.Barrier(2)

        def attempt_creation():
            """Try to create the progress; return True on UniqueViolation."""
            hit_unique_violation = False
            with self.registry.cursor() as cr:
                env = api.Environment(cr, api.SUPERUSER_ID, {})
                onboarding = env['onboarding.onboarding'].search([
                    ('id', '=', self.onboarding_id),
                ])
                # Nothing has been created yet from this transaction's view.
                existing_progress = env['onboarding.progress'].search([
                    ('onboarding_id', '=', self.onboarding_id),
                ])
                self.assertFalse(existing_progress)
                # Hold here until the other worker has done the same check,
                # so that both creations are attempted together.
                sync_point.wait(timeout=2)
                try:
                    onboarding._create_progress()
                except psycopg2.errors.UniqueViolation:
                    hit_unique_violation = True

            return hit_unique_violation

        with ThreadPoolExecutor(max_workers=2) as executor:
            futures = [executor.submit(attempt_creation) for _ in range(2)]
            violations = [future.result(timeout=3) for future in futures]

        with self.registry.cursor() as cr:
            env = api.Environment(cr, api.SUPERUSER_ID, {})
            created_progress = env['onboarding.progress'].search([
                ('onboarding_id', '=', self.onboarding_id),
            ])
            self.assertEqual(
                len(created_progress),
                1,
                "Exactly one thread should have been able to create a record."
            )

        # Booleans add up as integers: the sum is the number of workers
        # that ran into the unique constraint.
        self.assertEqual(
            sum(violations),
            1,
            "Exactly one thread should have raised a UniqueViolation error even though "
            "there was no progress record at the start of its transaction."
        )
