# Part of Odoo. See LICENSE file for full copyright and licensing details.

import operator as py_operator
from collections.abc import Iterable
from re import findall as regex_findall, split as regex_split
from collections import defaultdict

from odoo import _, api, fields, models
from odoo.exceptions import UserError, ValidationError
from odoo.fields import Domain

# Python equivalents of the comparison operators accepted by
# ``StockLot._search_product_qty``. Every callable is invoked as
# ``func(left, right)``, i.e. ``func(quantity, value)``.
PY_OPERATORS = {
    '<': py_operator.lt,
    '>': py_operator.gt,
    '<=': py_operator.le,
    '>=': py_operator.ge,
    '=': py_operator.eq,
    '!=': py_operator.ne,
    # ``operator.contains`` takes ``(container, item)``, hence the argument swap.
    'in': lambda item, collection: py_operator.contains(collection, item),
    'not in': lambda item, collection: not py_operator.contains(collection, item),
}


class StockLot(models.Model):
    """Lot or serial number attached to a tracked, storable product."""
    _name = 'stock.lot'
    _inherit = ['mail.thread', 'mail.activity.mixin']
    _description = 'Lot/Serial'
    _check_company_auto = True
    _order = 'name, id'

    @api.model
    def default_get(self, fields):
        """Return default values while ignoring any ``default_company_id`` in the context.

        ``company_id`` is a computed field (see ``_compute_company_id``); dropping the
        context default lets that computation always run.
        """
        context = dict(self.env.context)
        # We always want the company_id to be computed, regardless of where it's been created.
        context.pop('default_company_id', False)
        return super(StockLot, self.with_context(context)).default_get(fields)

    def _read_group_location_id(self, locations, domain):
        """``group_expand`` handler for ``location_id``.

        :param locations: ``stock.location`` recordset of the groups found so far
        :param domain: search domain of the grouped view (not used here)
        :return: all customer/supplier locations plus the ``lot_stock_id`` of every
            warehouse returned by an empty search
        """
        partner_locations = locations.search([('usage', 'in', ('customer', 'supplier'))])
        return partner_locations + locations.warehouse_id.search([]).lot_stock_id

    name = fields.Char('Lot/Serial Number', required=True, compute='_compute_name', store=True, readonly=False, help="Unique Lot/Serial Number", index='trigram', precompute=True)
    ref = fields.Char('Internal Reference', help="Internal reference number in case it differs from the manufacturer's lot/serial number")
    product_id = fields.Many2one(
        'product.product', 'Product', index=True,
        domain=("[('tracking', '!=', 'none'), ('is_storable', '=', True)] +"
            " ([('product_tmpl_id', '=', context['default_product_tmpl_id'])] if context.get('default_product_tmpl_id') else [])"),
        required=True, check_company=True, tracking=True)
    product_uom_id = fields.Many2one(
        'uom.uom', 'Unit',
        related='product_id.uom_id')
    quant_ids = fields.One2many('stock.quant', 'lot_id', 'Quants', readonly=True)
    product_qty = fields.Float('On Hand Quantity', compute='_product_qty', search='_search_product_qty')
    note = fields.Html(string='Description')
    display_complete = fields.Boolean(compute='_compute_display_complete')
    company_id = fields.Many2one('res.company', 'Company', index=True, store=True, readonly=False, compute='_compute_company_id')
    delivery_ids = fields.Many2many('stock.picking', compute='_compute_delivery_ids', string='Transfers')
    delivery_count = fields.Integer('Delivery order count', compute='_compute_delivery_ids')
    partner_ids = fields.Many2many('res.partner', compute='_compute_partner_ids', search='_search_partner_ids')
    lot_properties = fields.Properties('Properties', definition='product_id.lot_properties_definition', copy=True)
    location_id = fields.Many2one(
        'stock.location', 'Location', compute='_compute_single_location', store=True, readonly=False,
        inverse='_set_single_location', domain="[('usage', '!=', 'view')]", group_expand='_read_group_location_id')

    @api.depends('product_id')
    def _compute_name(self):
        """Give nameless lots the next number of their product's lot sequence (or False)."""
        for lot in self:
            if not lot.name:
                lot.name = lot.product_id.lot_sequence_id.next_by_id() if lot.product_id.lot_sequence_id else False

    @api.model
    def generate_lot_names(self, first_lot, count):
        """Generate ``count`` consecutive lot names starting from ``first_lot``.

        The last run of digits in ``first_lot`` is incremented, keeping its zero
        padding; text before and after it is kept. If ``first_lot`` has no digit,
        ``"0"`` is appended first.

        :param str first_lot: the first name of the series
        :param int count: how many names to generate
        :return: list of ``{'lot_name': str}`` dicts, ``first_lot``'s number first
        """
        # We look if the first lot contains at least one digit.
        caught_initial_number = regex_findall(r"\d+", first_lot)
        if not caught_initial_number:
            return self.generate_lot_names(first_lot + "0", count)
        # We base the series on the last number found in the base lot.
        initial_number = caught_initial_number[-1]
        padding = len(initial_number)
        # We split the lot name to get the prefix and suffix.
        splitted = regex_split(initial_number, first_lot)
        # initial_number could appear several times, e.g. BAV023B00001S00001
        prefix = initial_number.join(splitted[:-1])
        suffix = splitted[-1]
        initial_number = int(initial_number)

        return [{
            'lot_name': '%s%s%s' % (prefix, str(initial_number + i).zfill(padding), suffix),
        } for i in range(0, count)]

    @api.model
    def _get_next_serial(self, company, product):
        """Return the next serial number to be attributed to the product.

        It is derived from the most recently created lot (highest id) of ``product``
        belonging to ``company`` or to no company.

        :return: the next lot name, or False if the product is not tracked or has no lot yet
        """
        if product.tracking != "none":
            last_serial = self.env['stock.lot'].search(
                ['|', ('company_id', '=', company.id), ('company_id', '=', False), ('product_id', '=', product.id)],
                limit=1, order='id DESC')
            if last_serial:
                return self.env['stock.lot'].generate_lot_names(last_serial.name, 2)[1]['lot_name']
        return False

    @api.constrains('name', 'product_id', 'company_id')
    def _check_unique_lot(self):
        """Ensure (name, product) is unique per company, counting company-less lots everywhere.

        A lot without company conflicts with a same-named lot of the same product in
        any company. Two lots in two different companies do not conflict.

        :raises ValidationError: listing every duplicated product / lot name pair
        """
        domain = [('product_id', 'in', self.product_id.ids),
                  ('name', 'in', self.mapped('name'))]
        groupby = ['company_id', 'product_id', 'name']
        lots = self
        if any(not lot.company_id for lot in self):
            # We need to check across other companies to not have duplicates between 'no-company' and a company.
            lots = self.sudo()
        records = lots.with_context(skip_preprocess_gs1=True)._read_group(domain, groupby, ['__count'], order='company_id DESC')
        # Collect the 'no-company' counts in a first pass so the check below does
        # not depend on where the database sorts NULL companies in `records`.
        cross_lots = {
            (product, name): count
            for company, product, name, count in records
            if not company
        }
        # dict used as an insertion-ordered set: same line may come from several companies.
        error_message_lines = {}
        for company, product, name, count in records:
            # For company-specific lots, we check that there is no duplicate with 'no-company' lots, but NOT between specific-company ones.
            if (company and (cross_lots.get((product, name), 0) + count) > 1) or count > 1:
                error_message_lines[_(" - Product: %(product)s, Lot/Serial Number: %(lot)s", product=product.display_name, lot=name)] = None
        if error_message_lines:
            raise ValidationError(
                _(
                    "The combination of lot/serial number and product must be unique within a company including when no company is defined.\nThe following combinations contain duplicates:\n%(error_lines)s",
                    error_lines="\n".join(error_message_lines),
                ),
            )

    def _check_create(self):
        """Forbid lot creation from a picking whose operation type disallows it.

        :raises UserError: if ``active_picking_id`` in the context points to a picking
            whose ``picking_type_id.use_create_lots`` is not set
        """
        active_picking_id = self.env.context.get('active_picking_id', False)
        if active_picking_id:
            picking_id = self.env['stock.picking'].browse(active_picking_id)
            if picking_id and not picking_id.picking_type_id.use_create_lots:
                raise UserError(_('You are not allowed to create a lot or serial number with this operation type. To change this, go on the operation type and tick the box "Create New Lots/Serial Numbers".'))

    @api.depends('product_id.company_id')
    def _compute_company_id(self):
        """Use the product's company, or the current company when it is a child of the
        product's company and the product's company is not among the allowed companies.
        """
        for lot in self:
            if self.env.company in lot.product_id.company_id.all_child_ids and lot.product_id.company_id not in self.env.companies:
                lot.company_id = self.env.company
            else:
                lot.company_id = lot.product_id.company_id

    @api.depends('name')
    def _compute_display_complete(self):
        """ Defines if we want to display all fields in the stock.lot form view.
        It will if the record exists (`id` set) or if it is requested through the
        `display_complete` context key.
        This compute depends on field `name` because as it has always a default value, it'll be
        always triggered.
        """
        for prod_lot in self:
            # Coerce: `id` is an int (or a falsy NewId) and the context value may be None.
            prod_lot.display_complete = bool(prod_lot.id or self.env.context.get('display_complete'))

    def _compute_delivery_ids(self):
        """Set the outgoing pickings of each lot (see ``_find_delivery_ids_by_lot_iterative``) and their count."""
        delivery_ids_by_lot = self._find_delivery_ids_by_lot_iterative()
        for lot in self:
            lot.delivery_ids = delivery_ids_by_lot.get(lot.id, [])
            lot.delivery_count = len(lot.delivery_ids)

    def _compute_partner_ids(self):
        """Set the partners of each lot's outgoing pickings, most recently done picking first."""
        delivery_ids_by_lot = self._find_delivery_ids_by_lot_iterative()
        for lot in self:
            if delivery_ids_by_lot.get(lot.id, []):
                lot.partner_ids = self.env['stock.picking'].browse(delivery_ids_by_lot[lot.id]).sorted(key='date_done', reverse=True).partner_id
            else:
                lot.partner_ids = False

    @api.depends('quant_ids', 'quant_ids.quantity')
    def _compute_single_location(self):
        """Set the location holding all positive quants of the lot, or False if there are zero or several."""
        for lot in self:
            quants = lot.quant_ids.filtered(lambda q: q.quantity > 0)
            lot.location_id = quants.location_id if len(quants.location_id) == 1 else False

    def _set_single_location(self):
        """Inverse of ``location_id``: move each lot's positive quants to its new location.

        Each lot is handled on its own so that writing the same location on several
        lots (e.g. a list mass edit) works even when those lots currently sit in
        different locations.

        :raises UserError: if a lot's positive quants are spread over several locations
        """
        for lot in self:
            quants = lot.quant_ids.filtered(lambda q: q.quantity > 0)
            if len(quants.location_id) == 1:
                # Unpack only if the package also holds other quants than this lot's.
                unpack = len(quants.package_id.quant_ids) > 1
                quants.move_quants(location_dest_id=lot.location_id, message=_("Lot/Serial Number Relocated"), unpack=unpack)
            elif len(quants.location_id) > 1:
                raise UserError(_('You can only move a lot/serial to a new location if it exists in a single location.'))

    @api.model_create_multi
    def create(self, vals_list):
        """Create lots after ``_check_create``, without subscribing the creator as follower.

        The products involved (including the context's ``default_product_id``) are
        exposed to ``_check_create`` through the ``lot_product_ids`` context key.
        """
        lot_product_ids = {val.get('product_id') for val in vals_list} | {self.env.context.get('default_product_id')}
        self.with_context(lot_product_ids=lot_product_ids)._check_create()
        return super(StockLot, self.with_context(mail_create_nosubscribe=True)).create(vals_list)

    def write(self, vals):
        """Write ``vals`` after checking company and product changes are consistent with stock.

        :raises UserError: if the new company differs from the company of a lot's current
            location, or if the product changes while move lines of another product already
            reference one of the lots
        """
        if 'company_id' in vals:
            for lot in self:
                if lot.location_id.company_id and vals['company_id'] and lot.location_id.company_id.id != vals['company_id']:
                    raise UserError(_("You cannot change the company of a lot/serial number currently in a location belonging to another company."))
        if 'product_id' in vals and any(vals['product_id'] != lot.product_id.id for lot in self):
            move_lines = self.env['stock.move.line'].search([('lot_id', 'in', self.ids), ('product_id', '!=', vals['product_id'])])
            if move_lines:
                raise UserError(_(
                    'You are not allowed to change the product linked to a serial or lot number '
                    'if some stock moves have already been created with that number. '
                    'This would lead to inconsistencies in your stock.'
                ))
        return super().write(vals)

    def copy_data(self, default=None):
        """Return copy values, prefixing names with "(copy of)" unless ``default`` sets a name."""
        default = dict(default or {})
        vals_list = super().copy_data(default=default)
        if 'name' not in default:
            for lot, vals in zip(self, vals_list):
                vals['name'] = _("(copy of) %s", lot.name)
        return vals_list

    @api.depends('quant_ids', 'quant_ids.quantity')
    @api.depends_context('owner_id', 'package_id', 'to_date', 'location', 'warehouse_id', 'allowed_company_ids')
    def _product_qty(self):
        """Compute each lot's quantity in the locations selected by the context.

        Quants are summed over ``product.product._get_domain_locations()`` and
        optionally filtered by the ``owner_id`` / ``package_id`` context keys. If
        ``to_date`` is in the past, done move lines dated after it are reverted:
        incoming quantities are subtracted, outgoing ones added back.
        """
        domain_quant_loc, domain_move_in_loc, domain_move_out_loc = self.env['product.product'].with_context(skip_in_progress=True)._get_domain_locations()
        owner_id = self.env.context.get('owner_id')
        package_id = self.env.context.get('package_id')
        to_date = fields.Datetime.to_datetime(self.env.context.get('to_date'))
        dates_in_the_past = to_date and to_date < fields.Datetime.now()
        domain_quant = Domain([('lot_id', 'in', self.ids)]) & domain_quant_loc
        if owner_id is not None:
            domain_quant &= Domain([('owner_id', '=', owner_id)])
            domain_move_in_loc &= Domain([('owner_id', '=', owner_id)])
            domain_move_out_loc &= Domain([('owner_id', '=', owner_id)])
        if package_id is not None:
            domain_quant &= Domain([('package_id', '=', package_id)])
        quant_qty_by_lot = dict(self.env['stock.quant']._read_group(domain_quant, ['lot_id'], ['quantity:sum']))
        if not dates_in_the_past:
            for lot in self:
                lot.product_qty = quant_qty_by_lot.get(lot, 0.0)
        else:
            # If the date is in the past, we need to adjust the quantity on hand with the moves that happened after that date.
            domain_lot_done = Domain([('lot_id', 'in', self.ids), ('state', '=', 'done'), ('move_id.date', '>', to_date)])
            move_in_qty_by_lot = dict(self.env['stock.move.line']._read_group(domain_move_in_loc & domain_lot_done, ['lot_id'], ['quantity_product_uom:sum']))
            move_out_qty_by_lot = dict(self.env['stock.move.line']._read_group(domain_move_out_loc & domain_lot_done, ['lot_id'], ['quantity_product_uom:sum']))
            for lot in self:
                lot.product_qty = quant_qty_by_lot.get(lot, 0.0) - move_in_qty_by_lot.get(lot, 0.0) + move_out_qty_by_lot.get(lot, 0.0)

    def _search_product_qty(self, operator, value):
        """Search method for ``product_qty``.

        Quantities are summed per lot over quants in internal locations and in
        transit locations of the allowed companies. Lots without a non-zero sum are
        treated as having a quantity of 0.

        :param str operator: a key of ``PY_OPERATORS``; any other operator is not handled
        :param value: a number, or an iterable of numbers for ``in`` / ``not in``
        :return: a domain on ``id``, or ``NotImplemented`` for unsupported operators
        """
        op = PY_OPERATORS.get(operator)
        if not op:
            return NotImplemented
        if isinstance(value, Iterable) and not isinstance(value, str):
            value = {float(v) for v in value}
        else:
            value = float(value)
        domain = [
            ('lot_id', '!=', False),
            '|', ('location_id.usage', '=', 'internal'),
            '&', ('location_id.usage', '=', 'transit'), ('location_id.company_id', 'in', self.env.companies.ids)
        ]
        lots_w_qty = self.env['stock.quant']._read_group(domain=domain, groupby=['lot_id'], aggregates=['quantity:sum'], having=[('quantity:sum', '!=', 0)])
        ids = []
        lot_ids_w_qty = []
        for lot, quantity_sum in lots_w_qty:
            lot_id = lot.id
            lot_ids_w_qty.append(lot_id)
            if op(quantity_sum, value):
                ids.append(lot_id)

        # check if we need include zero values in result
        include_zero = op(0.0, value)
        if include_zero:
            return ['|', ('id', 'in', ids), ('id', 'not in', lot_ids_w_qty)]
        return [('id', 'in', ids)]

    def _search_partner_ids(self, operator, value):
        """ Return a domain on lots directly delivered to the searched partners, i.e. not
        lots/SNs that are consumed within a MO. This means this search is NOT symmetric with the
        partner_ids field within the form view since it uses different logic that isn't efficient
        enough for this search due to it being usable within the list view.

        Negative operators and non-iterable values return ``NotImplemented``. Searching
        ``in [False]`` returns the lots that were never sent to any partner.
        """
        if operator in Domain.NEGATIVE_OPERATORS or not isinstance(value, (Iterable)):
            return NotImplemented
        is_no_partner = operator == 'in' and list(value) == [False]
        domain = Domain([
            ('lot_id', '!=', False),
            ('state', '=', 'done'),
        ])
        if is_no_partner:
            # reverse the search, get all lots sent to partner so we can return all lots NOT sent
            domain &= Domain('picking_partner_id', 'not in', value)
        else:
            domain &= Domain.OR([
                Domain('picking_partner_id', operator, value),
                Domain('move_partner_id', operator, value),
            ])
        domain &= Domain(self._get_outgoing_domain())
        move_lines = self.env['stock.move.line'].search(domain)

        if is_no_partner:
            return [('id', 'not in', move_lines.lot_id.ids)]
        return [('id', 'in', move_lines.lot_id.ids)]

    def action_lot_open_quants(self):
        """Open the quant view filtered on this lot, in inventory mode for stock managers."""
        self = self.with_context(search_default_lot_id=self.id, create=False)
        if self.env.user.has_group('stock.group_stock_manager'):
            self = self.with_context(inventory_mode=True)
        return self.env['stock.quant'].action_view_quants()

    def action_lot_open_transfers(self):
        """Open the lot's delivery pickings: a form view if there is exactly one, a list otherwise."""
        self.ensure_one()

        action = {
            'res_model': 'stock.picking',
            'type': 'ir.actions.act_window'
        }
        if len(self.delivery_ids) == 1:
            action.update({
                'view_mode': 'form',
                'res_id': self.delivery_ids[0].id
            })
        else:
            action.update({
                'name': _("Delivery orders of %s", self.display_name),
                'domain': [('id', 'in', self.delivery_ids.ids)],
                'view_mode': 'list,form'
            })
        return action

    @api.model
    def _get_outgoing_domain(self):
        """Return the ``stock.move.line`` domain of outgoing lines or lines that produced other lines."""
        return [
            '|',
            '|', ('picking_code', '=', 'outgoing'), ('move_id.picking_code', '=', 'outgoing'),
            ('produce_line_ids', '!=', False),
        ]

    def _find_delivery_ids_by_lot(self, lot_path=None, delivery_by_lot=None):
        """Recursively map lot ids to the ids of their outgoing pickings.

        Lines without ``produce_line_ids`` contribute their picking directly; lines that
        produced other lots contribute the deliveries of those lots, found recursively.

        :param set lot_path: ids of lots already being processed, used to stop cycles
        :param dict delivery_by_lot: accumulator shared across recursive calls
        :return: ``delivery_by_lot``, mapping lot id to a list of ``stock.picking`` ids
        """
        if lot_path is None:
            lot_path = set()
        domain = Domain([
            ('lot_id', 'in', self.ids),
            ('state', '=', 'done'),
        ]) & Domain(self._get_outgoing_domain())
        move_lines = self.env['stock.move.line'].search(domain)
        moves_by_lot = {
            lot_id: {'producing_lines': set(), 'barren_lines': set()}
            for lot_id in move_lines.lot_id.ids
        }
        for line in move_lines:
            if line.produce_line_ids:
                moves_by_lot[line.lot_id.id]['producing_lines'].add(line.id)
            else:
                moves_by_lot[line.lot_id.id]['barren_lines'].add(line.id)
        if delivery_by_lot is None:
            delivery_by_lot = dict()
        for lot in self:
            delivery_ids = set()

            if moves_by_lot.get(lot.id):
                producing_move_lines = self.env['stock.move.line'].browse(moves_by_lot[lot.id]['producing_lines'])
                barren_move_lines = self.env['stock.move.line'].browse(moves_by_lot[lot.id]['barren_lines'])

                if producing_move_lines:
                    lot_path.add(lot.id)
                    next_lots = producing_move_lines.produce_line_ids.lot_id.filtered(lambda l: l.id not in lot_path)
                    next_lots_ids = set(next_lots.ids)
                    # If some producing lots are in lot_path, it means that they have been previously processed.
                    # Their results are therefore already in delivery_by_lot and we add them to delivery_ids directly.
                    delivery_ids.update(*(delivery_by_lot.get(lot_id, []) for lot_id in (producing_move_lines.produce_line_ids.lot_id - next_lots).ids))

                    for lot_id, delivery_ids_set in next_lots._find_delivery_ids_by_lot(lot_path=lot_path, delivery_by_lot=delivery_by_lot).items():
                        if lot_id in next_lots_ids:
                            delivery_ids.update(delivery_ids_set)
                delivery_ids.update(barren_move_lines.picking_id.ids)

            delivery_by_lot[lot.id] = list(delivery_ids)
        return delivery_by_lot

    def _find_delivery_ids_by_lot_iterative(self):
        """ Retrieve all delivery IDs (outgoing pickings) linked to the lots
            in self and to all the lots reached when traversing the produce lines.
            :return: A dictionary whose keys are the IDs of every traversed 'stock.lot'
                      (those of self plus the lots reached through produce lines) and
                      whose values are lists of associated 'stock.picking' IDs.
            :rtype: dict
        """

        all_lot_ids = set(self.ids)
        barren_lines = defaultdict(set)
        parent_map = defaultdict(set)

        # Prefetch the lines linked to lots and split them between producing lines
        # and barren lines (lines that have `produce_line_ids` and lines that don't
        # have them respectively) and build the map of the parents of each lot (so we
        # can browse the tree from the leaves to the root and propagate the pickings)
        queue = list(self.ids)
        while queue:
            domain = Domain([
                ('lot_id', 'in', queue),
                ('state', '=', 'done'),
            ]) & Domain(self._get_outgoing_domain())

            queue = []
            move_lines = self.env['stock.move.line'].search(domain)
            for line in move_lines:
                lot_id = line.lot_id.id

                produce_line_lot_ids = line.produce_line_ids.lot_id.ids
                if produce_line_lot_ids:
                    for child_lot_id in produce_line_lot_ids:
                        parent_map[child_lot_id].add(lot_id)
                else:
                    barren_lines[lot_id].add(line.id)

                # Only enqueue lots never seen before, which also stops cycles.
                next_lots = set(produce_line_lot_ids) - all_lot_ids
                all_lot_ids.update(next_lots)
                queue.extend(next_lots)

        # Initialize delivery_by_lot with barren lines (i.e. the leaves of the lot tree)
        lots_to_propagate = set()
        delivery_by_lot = {lot_id: set() for lot_id in all_lot_ids}
        for lot_id, barren_line_ids in barren_lines.items():
            if barren_line_ids:
                barren_move_lines = self.env['stock.move.line'].browse(barren_line_ids)
                delivery_by_lot[lot_id].update(barren_move_lines.picking_id.ids)
                lots_to_propagate.add(lot_id)

        # Propagate the deliveries from the children to their parent lots.
        # This loop processes lots whose delivery sets have just been updated,
        # ensuring the new results are merged upward through the parent graph until
        # all deliveries are propagated
        while lots_to_propagate:
            lot_id = lots_to_propagate.pop()

            for parent_id in parent_map.get(lot_id, []):
                new_deliveries = delivery_by_lot[lot_id] - delivery_by_lot[parent_id]
                if new_deliveries:
                    delivery_by_lot[parent_id].update(new_deliveries)
                    lots_to_propagate.add(parent_id)

        return {lot_id: list(delivery_ids) for lot_id, delivery_ids in delivery_by_lot.items()}
