# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Cambricon CORPORATION. All rights reserved.

import torch

from megatron.core.parallel_state import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.addons.function_wrapper import FUNCTION_WRAPPER

@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward')
def _VocabParallelCrossEntropy_forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
    """Forward pass of cross entropy over vocabulary-partitioned logits.

    Registered through ``FUNCTION_WRAPPER`` under the name of Megatron's
    ``_VocabParallelCrossEntropy.forward`` (presumably to replace it --
    confirm against ``FUNCTION_WRAPPER.register_wrapper``).

    NOTE: ``vocab_parallel_logits`` is overwritten IN PLACE. The max is
    subtracted from it, it is exponentiated into itself, and it is finally
    divided by the global sum of exponentials, so after this call the
    caller's tensor holds the softmax values rather than the raw logits.

    Args:
        ctx: Context object; receives ``label_smoothing`` and ``vocab_size``
            attributes and the tensors passed to ``save_for_backward``.
        vocab_parallel_logits: Logits whose last dimension is this rank's
            slice of the vocabulary. Modified in place (see note above).
        target: Integer tensor of vocabulary ids in the global (unpartitioned)
            index space. Assumed to have the shape of the logits without the
            last dimension -- TODO confirm against callers.
        label_smoothing: Smoothing factor; must satisfy
            ``0.0 < label_smoothing < 1.0`` when non-zero.

    Returns:
        The per-element loss ``log(sum(exp(logits))) - logits[target]``,
        optionally blended with the label-smoothing term.

    Raises:
        AssertionError: If ``label_smoothing`` is positive but not below 1.0.
    """
    # The same process group is used by all three collectives below.
    tp_group = get_tensor_model_parallel_group()

    # Maximum value along vocab dimension across all GPUs.
    logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
    torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=tp_group)
    # Subtract the maximum value (in place, to avoid a second logits-sized buffer).
    vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

    # Get the partition's vocab indices.
    get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
    partition_vocab_size = vocab_parallel_logits.size()[-1]
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)

    # Create a mask of valid vocab ids (True means it needs to be masked,
    # i.e. the target id lives on another rank's partition).
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    # The subtraction already allocates a fresh tensor, so `target` itself is
    # never modified and no explicit clone is required.
    masked_target = target - vocab_start_index
    masked_target[target_mask] = 0

    # Get predicted-logits = logits[target].
    # For simplicity, we convert logits to a 2-D tensor with size
    # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
    logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
    masked_target_1d = masked_target.view(-1)
    arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
    predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
    predicted_logits_1d = predicted_logits_1d.clone().contiguous()
    predicted_logits = predicted_logits_1d.view_as(target)
    # Entries whose target is owned by another rank contribute zero to the sum.
    predicted_logits[target_mask] = 0.0
    # All reduce is needed to get the chunks from other GPUs.
    torch.distributed.all_reduce(
        predicted_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group
    )

    # Sum of exponential of logits along vocab dimension across all GPUs.
    # `exp_logits` aliases `vocab_parallel_logits`; the exponential is written
    # back into the same storage.
    exp_logits = vocab_parallel_logits
    torch.exp(vocab_parallel_logits, out=exp_logits)
    sum_exp_logits = exp_logits.sum(dim=-1)
    torch.distributed.all_reduce(
        sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group
    )

    # Loss = log(sum(exp(logits))) - predicted-logit.
    loss = torch.log(sum_exp_logits) - predicted_logits

    # Normalize and optionally smooth logits.
    exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

    # This is the size of the local partition (last dim of the local tensor).
    vocab_size = exp_logits.size(-1)
    if label_smoothing > 0:
        # We'd like to assign 1 / (K - 1) probability mass to every index that
        # is not the ground truth.
        #   = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
        #   = (1 - alpha) * y_gt + (alpha / (K - 1)) * sum_{i != gt} y_i
        #   = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * sum_{i != gt} y_i
        #   = ((K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * sum_{i} y_i
        #   = (1 - (alpha * K) / (K - 1)) * y_gt + ((alpha * K) / (K - 1)) * sum_{i} y_i / K
        # From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
        assert 1.0 > label_smoothing > 0.0
        smoothing = label_smoothing * vocab_size / (vocab_size - 1)

        # Exp logits at this point are normalized probabilities. So we can just
        # take the log to get log-probs.
        log_probs = torch.log(exp_logits)
        # NOTE(review): the mean runs over the local partition only and is not
        # all-reduced, and K above is the partition size -- confirm this is the
        # intended behaviour when the tensor-parallel world size is > 1.
        mean_log_probs = log_probs.mean(dim=-1)
        loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs

    ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size

    # Store softmax, target-mask and masked-target for backward pass.
    ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

    return loss

