# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Cambricon CORPORATION. All rights reserved.
from contextlib import nullcontext
import os
import math
import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional

from megatron import core
from megatron.training import get_timers, get_args, get_num_microbatches
from megatron.legacy.model.module import MegatronModule
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.legacy.model.enums import AttnMaskType, LayerType, AttnType
from megatron.legacy.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.legacy.model.fused_bias_gelu import bias_gelu_impl
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding, apply_rotary_pos_emb
from megatron.legacy.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_norm
from megatron.core.tensor_parallel import (
    gather_from_sequence_parallel_region_to_moe,
    reduce_scatter_to_sequence_parallel_region_from_moe,
    get_cuda_rng_tracker,
    get_data_parallel_rng_tracker_name
)
from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_and_expert_parallel_group
from megatron.core.jit import jit_fuser

try:
    from einops import rearrange
except ImportError:
    rearrange = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    try:
        # Newer flash-attn releases expose the same entry point under the
        # name flash_attn_varlen_func; alias it to the legacy name.
        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
    except ImportError:
        # flash-attn is optional: bind None (like the `rearrange` fallback)
        # so importing this module still succeeds without it.
        flash_attn_unpadded_func = None

from megatron.legacy.model.transformer import DropPath, SwitchMLP, ParallelMLP, CoreAttention, ParallelAttention, ParallelTransformerLayer, NoopTransformerLayer, ParallelTransformer
from megatron.legacy.model.transformer import _get_num_layers, _get_layer_type
from megatron.addons.function_wrapper import FUNCTION_WRAPPER

from apex.contrib.fused_bias_dropout.fused_bias_dropout import get_bias_dropout_add


# New feature: route the training-mode bias+dropout+add through apex's op.
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.legacy.model.transformer.bias_dropout_add_fused_train')
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: Optional[torch.Tensor],
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """Apply bias-add, dropout (training mode) and residual-add to ``x``.

    With a bias, this delegates to apex's ``get_bias_dropout_add(True, True)``
    with ``(x, bias)`` as the input pair. ``bias`` is ``Optional`` per the
    signature, and ``None.contiguous()`` would raise, so a missing bias falls
    back to an unfused ``residual + dropout(x, p=prob)``.

    Args:
        x: activation to which bias and dropout are applied.
        bias: optional bias tensor; ``None`` skips the bias-add.
        residual: tensor added to the dropped-out activation.
        prob: dropout probability.

    Returns:
        The combined tensor.
    """
    if bias is None:
        return residual + F.dropout(x, p=prob, training=True)
    return get_bias_dropout_add(True, True)((x, bias.contiguous()), residual, prob)

# New feature: GeGLU activation for the legacy ParallelMLP.
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.legacy.model.transformer.ParallelMLP.__init__')
def ParallelMLP_init(self, config, is_expert=False):
    """Replacement ``ParallelMLP.__init__`` that adds a GeGLU activation.

    Selects the activation from the global args (precedence: openai_gelu,
    onnx_safe, swiglu, geglu, squared_relu, silu, else GeLU with optional
    bias-GeLU fusion). It then builds the column-parallel h -> ffn
    projection and the row-parallel ffn -> h projection.

    Args:
        config: transformer config; reads add_bias_linear, ffn_hidden_size,
            gated_linear_unit, hidden_size, init_method and
            output_layer_init_method.
        is_expert: forwarded to both parallel linear layers.
    """
    super(ParallelMLP, self).__init__()
    args = get_args()

    self.add_bias = config.add_bias_linear

    self.bias_gelu_fusion = False
    self.activation_func = None
    self.swiglu = args.swiglu
    self.geglu = args.geglu
    # True when the selected activation splits its input in two halves.
    gated_activation = False

    if args.openai_gelu:
        self.activation_func = openai_gelu
    elif args.onnx_safe:
        self.activation_func = erf_gelu
    elif args.swiglu:
        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]
        self.activation_func = swiglu
        gated_activation = True
    elif args.geglu:
        def geglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.gelu(x[0]) * x[1]
        self.activation_func = geglu
        gated_activation = True
    elif args.squared_relu:
        def squared_relu(x):
            return torch.pow(F.relu(x), 2)
        self.activation_func = squared_relu
    elif args.silu:
        self.activation_func = F.silu
    else:
        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu

    # A gated activation (SwiGLU / GeGLU, https://arxiv.org/pdf/2002.05202.pdf)
    # consumes two halves of the first projection and emits one. That
    # projection must therefore be twice the config.ffn_hidden_size that
    # dense_4h_to_h expects as input. config.gated_linear_unit alone is not
    # enough, because nothing visible here ties it to args.geglu.
    ffn_hidden_size = config.ffn_hidden_size
    if config.gated_linear_unit or gated_activation:
        ffn_hidden_size *= 2

    # Project to 4h (2x wider when the activation is gated).
    self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
        config.hidden_size,
        ffn_hidden_size,
        config=config,
        init_method=config.init_method,
        bias=self.add_bias,
        gather_output=False,
        skip_bias_add=True,
        is_expert=is_expert,
    )

    # Project back to h.
    self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
        config.ffn_hidden_size,
        config.hidden_size,
        config=config,
        init_method=config.output_layer_init_method,
        bias=self.add_bias,
        skip_bias_add=True,
        input_is_parallel=True,
        is_expert=is_expert,
    )

# New feature: optional ALiBi bias in the (non-flash) core attention.
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.legacy.model.transformer.CoreAttention.forward')
def CoreAttention_forward(self, query_layer, key_layer,
                          value_layer, attention_mask, alibi=None):
    """Replacement ``CoreAttention.forward`` with an optional ALiBi bias.

    Args:
        query_layer: [sq, b, np, hn].
        key_layer: [sk, b, np, hn].
        value_layer: [sk, b, np, hn].
        attention_mask: mask forwarded to ``self.scale_mask_softmax``.
        alibi: optional bias sliced as ``alibi[..., :sq, :sk]`` whose dim 1
            must equal np. When given, it seeds ``torch.baddbmm`` in place of
            the zero-weighted scratch buffer. It is scaled by
            ``1 / self.layer_number`` when ``apply_query_key_layer_scaling``
            is set.

    Returns:
        Context layer shaped [sq, b, hp], hp == self.hidden_size_per_partition.
    """
    # ===================================
    # Raw attention scores. [b, np, s, s]
    # ===================================
    # [b, np, sq, sk]
    output_size = (query_layer.size(1),
                   query_layer.size(2),
                   query_layer.size(0),
                   key_layer.size(0))

    # [sq, b, np, hn] -> [sq, b * np, hn]
    query_layer = query_layer.reshape(output_size[2],
                                      output_size[0] * output_size[1], -1)
    # [sk, b, np, hn] -> [sk, b * np, hn]
    key_layer = key_layer.view(output_size[3],
                               output_size[0] * output_size[1], -1)

    if alibi is None:
        # Preallocated input tensor: [b * np, sq, sk]. beta=0 below means its
        # contents are ignored; it only supplies the output storage.
        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
            (output_size[0]*output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))
    else:
        assert output_size[1] == alibi.size(1), \
            'number of query heads does not match ALiBi tensor.'
        # Broadcast the per-head bias over the batch, then flatten to
        # [b * np, *, sk] so it can serve as the baddbmm input term.
        matmul_result = alibi[..., :output_size[2], :output_size[3]].expand(
            output_size[0], -1, -1, -1)
        matmul_result = matmul_result.reshape(
            output_size[0] * output_size[1], -1, output_size[3])

        if self.apply_query_key_layer_scaling:
            beta = 1.0 / self.layer_number
        else:
            beta = 1.0

        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=beta, alpha=(1.0 / self.norm_factor))

    # change view to [b, np, sq, sk]
    attention_scores = matmul_result.view(*output_size)

    # ===========================
    # Attention probs and dropout
    # ===========================

    # attention scores and attention mask [b, np, sq, sk]
    attention_probs = self.scale_mask_softmax(attention_scores,
                                              attention_mask)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    if not self.sequence_parallel:
        with tensor_parallel.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)
    else:
        attention_probs = self.attention_dropout(attention_probs)

    # =========================
    # Context layer. [sq, b, hp]
    # =========================

    # value_layer -> context layer.
    # [sk, b, np, hn] --> [b, np, sq, hn]

    # context layer shape: [b, np, sq, hn]
    output_size = (value_layer.size(1),
                   value_layer.size(2),
                   query_layer.size(0),
                   value_layer.size(3))

    # change view [sk, b * np, hn]
    value_layer = value_layer.view(value_layer.size(0),
                                   output_size[0] * output_size[1], -1)

    # change view [b * np, sq, sk]
    attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                           output_size[2], -1)

    # matmul: [b * np, sq, hn]
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

    # change view [b, np, sq, hn]
    context_layer = context_layer.view(*output_size)

    # [b, np, sq, hn] --> [sq, b, np, hn]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

    # [sq, b, np, hn] --> [sq, b, hp]
    new_context_layer_shape = context_layer.size()[:-2] + \
        (self.hidden_size_per_partition,)
    context_layer = context_layer.view(*new_context_layer_shape)

    return context_layer

# New feature: carry the ALiBi bias through activation-checkpointed attention.
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.legacy.model.transformer.ParallelAttention._checkpointed_attention_forward')
def ParallelAttention_checkpointed_attention_forward(self, query_layer, key_layer,
                                                     value_layer, attention_mask,
                                                     rotary_pos_emb=None,
                                                     alibi=None):
    """Run ``self.core_attention`` under ``tensor_parallel.checkpoint``.

    The rotary embeddings (split into q/k parts) are passed to the
    checkpoint alongside the other tensors but are not consumed by the
    recomputed function. ``alibi`` is forwarded to core attention.

    Returns:
        Whatever ``self.core_attention`` returns.
    """
    if rotary_pos_emb is None:
        q_pos_emb = k_pos_emb = None
    else:
        q_pos_emb, k_pos_emb = rotary_pos_emb

    def _recompute_core_attention(*saved):
        # `saved` mirrors the positional args handed to checkpoint below:
        # (query, key, value, mask, q_pos_emb, k_pos_emb, alibi).
        query, key, value, mask = saved[:4]
        return self.core_attention(query, key, value, mask, alibi=saved[6])

    return tensor_parallel.checkpoint(
        _recompute_core_attention,
        False,
        query_layer, key_layer, value_layer, attention_mask,
        q_pos_emb, k_pos_emb, alibi,
    )

# New feature: accept an optional ALiBi bias and forward it to core attention.
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.legacy.model.transformer.ParallelAttention.forward')
def ParallelAttention_forward(self, hidden_states, attention_mask,
                              encoder_output=None, inference_params=None,
                              rotary_pos_emb=None, alibi=None):
        """Replacement ``ParallelAttention.forward`` that threads ``alibi``
        through to the non-flash core-attention paths.

        Args:
            hidden_states: input activations, [sq, b, h].
            attention_mask: mask passed unchanged to core attention.
            encoder_output: key/value source when ``self.attention_type`` is
                not ``AttnType.self_attn`` (cross attention).
            inference_params: when truthy, keys/values are written into and
                read back from a per-layer KV cache stored in
                ``inference_params.key_value_memory_dict``.
            rotary_pos_emb: a single tensor (shared by q and k) or a
                ``(q_pos_emb, k_pos_emb)`` tuple.
            alibi: optional ALiBi bias; passed to core attention only when
                flash attention is disabled. The flash path does not use it.

        Returns:
            ``(output, bias)`` as returned by ``self.dense``.
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)

                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
                is_first_step = True
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================
        if self.attention_type == AttnType.self_attn:

            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_query_groups_per_partition,
                (
                    (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
                    * self.hidden_size_per_attention_head
                ),
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query_layer,
            key_layer,
            value_layer) = torch.split(
                mixed_x_layer,
                [
                    (
                        self.num_attention_heads_per_partition // self.num_query_groups_per_partition
                        * self.hidden_size_per_attention_head
                    ),
                    self.hidden_size_per_attention_head,
                    self.hidden_size_per_attention_head
                ],
                dim=3)

            # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
            query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
            value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        # Normalize rotary_pos_emb to a (q_pos_emb, k_pos_emb) pair; a single
        # tensor is shared by q and k (self attention).
        if rotary_pos_emb is not None:
            if isinstance(rotary_pos_emb, tuple):
                # Already a pair; this self-assignment is a no-op.
                rotary_pos_emb = rotary_pos_emb
            else:
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params:
            # Write the new keys/values into the cache at the current batch
            # and sequence offsets, then attend over everything cached so far.
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]


            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence)
                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
        if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
            key_layer = key_layer.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim = 2
            )
            value_layer = value_layer.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim = 2
            )

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb,self.config)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb,self.config)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        if not self.use_flash_attn:
            # Non-flash paths accept the ALiBi bias.
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask, alibi=alibi)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask, alibi=alibi)
        else:
            # Flash path: [s, b, ...] -> [b, s, ...]; alibi is not passed here.
            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            if not self.sequence_parallel:
                with tensor_parallel.get_cuda_rng_tracker().fork():
                    context_layer = self.core_attention_flash(q, k, v)
            else:
                context_layer = self.core_attention_flash(q, k, v)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias

@staticmethod
def _build_alibi_tensor(max_seq_len, num_attention_heads, batch_size):
    """Build this tensor-parallel rank's slice of the ALiBi bias.

    Copied from bigscience-workshop/Megatron-DeepSpeed. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742

    Args:
        max_seq_len: number of positions covered by the bias.
        num_attention_heads: total heads across all tensor-parallel ranks;
            must divide evenly by the TP world size for the per-rank reshape.
        batch_size: unused.

    Returns:
        Tensor shaped
        (1, num_attention_heads_per_partition, 1, max_seq_len), in torch's
        default floating dtype, where entry [0, h, 0, j] is slope_h * j.
    """

    def power_of_two_slopes(count):
        # Geometric sequence first, first**2, ... with first = 2**(-8/count).
        first = (2 ** (-2 ** -(math.log2(count) - 3)))
        return [first * first ** k for k in range(count)]

    def head_slopes(count):
        if math.log2(count).is_integer():
            return power_of_two_slopes(count)
        base = 2 ** math.floor(math.log2(count))
        # Remaining heads take every other slope of the next power of two.
        extra = power_of_two_slopes(2 * base)[0::2][:count - base]
        return power_of_two_slopes(base) + extra

    slopes = torch.Tensor(head_slopes(num_attention_heads))
    positions = torch.arange(max_seq_len)
    # (heads, 1, 1) * (1, 1, seq) -> (heads, 1, seq)
    bias = slopes[:, None, None] * positions[None, None, :]

    # Split the head axis into (tp_world_size, heads_per_rank) and keep this
    # rank's contiguous group of heads.
    tp_world_size = mpu.get_tensor_model_parallel_world_size()
    tp_rank = mpu.get_tensor_model_parallel_rank()
    per_rank = bias.reshape((tp_world_size, -1, *bias.shape[1:]))
    return per_rank[tp_rank].unsqueeze(0)

# New feature: build a per-layer ALiBi bias when position_embedding_type == 'alibi'.
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.legacy.model.transformer.ParallelTransformerLayer.__init__')
def ParallelTransformerLayer_init(self, config,
                                  layer_number, layer_type=LayerType.encoder,
                                  self_attn_mask_type=AttnMaskType.padding,
                                  drop_path_rate=0.):
    """Replacement ``ParallelTransformerLayer.__init__`` that adds ALiBi.

    Builds the input/post-attention norms, self attention, cross attention
    for decoder/retro layer types, and the MLP (SwitchMLP when
    ``args.num_experts`` is set). When
    ``args.position_embedding_type == 'alibi'`` it also stores a
    precomputed bias in ``self.alibi``; otherwise ``self.alibi`` is None.

    Args:
        config: transformer config.
        layer_number: layer index, passed through to the attention modules.
        layer_type: ``LayerType`` controlling cross-attention / retriever
            construction.
        self_attn_mask_type: mask type for self attention.
        drop_path_rate: stochastic-depth rate; ``DropPath`` is created only
            when it is > 0.
    """
    args = get_args()

    super(ParallelTransformerLayer, self).__init__()
    self.layer_number = layer_number
    self.layer_type = layer_type

    self.apply_residual_connection_post_norm \
        = config.apply_residual_connection_post_layernorm

    self.bf16 = config.bf16
    self.fp32_residual_connection = config.fp32_residual_connection

    # Normalize the input data.
    self.input_norm = get_norm(config)

    # Self attention.
    self.self_attention = ParallelAttention(
        config,
        layer_number,
        attention_type=AttnType.self_attn,
        attn_mask_type=self_attn_mask_type)
    self.hidden_dropout = config.hidden_dropout
    self.bias_dropout_fusion = config.bias_dropout_fusion
    self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

    # Normalize the attention output
    self.post_attention_norm = get_norm(config)

    # Cross attention.
    if self.layer_type in (LayerType.decoder,
                           LayerType.retro_decoder,
                           LayerType.retro_decoder_with_retriever,
                           LayerType.retro_encoder):
        self.inter_attention = ParallelAttention(
            config,
            layer_number,
            attention_type=AttnType.cross_attn)
        # Normalize the attention output.
        self.post_inter_attention_norm = get_norm(config)

    # MLP
    if args.num_experts is not None:
        self.mlp = SwitchMLP(config)
    else:
        self.mlp = ParallelMLP(config)

    # ALiBi
    if args.position_embedding_type == 'alibi':
        assert not args.use_flash_attn, \
            'ALiBi does not work with FlashAttention currently'
        # NOTE(review): `_build_alibi_tensor` is resolved through `self`;
        # no visible code attaches it to ParallelTransformerLayer. Confirm it
        # is wired up elsewhere. The result is stored as a plain attribute,
        # not a registered buffer.
        self.alibi = self._build_alibi_tensor(
            args.seq_length,
            args.num_attention_heads,
            args.micro_batch_size,
        ).to(torch.cuda.current_device())
        # Cast to half precision when the model parameters are fp16/bf16;
        # otherwise keep the builder's dtype.
        if args.params_dtype is torch.float16:
            self.alibi = self.alibi.to(torch.float16)
        elif args.params_dtype is torch.bfloat16:
            self.alibi = self.alibi.to(torch.bfloat16)
    else:
        self.alibi = None

    # Set bias+dropout+add fusion grad_enable execution handler.
    # torch >= 1.10 uses a no-op context; older versions force enable_grad.
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
    self.bias_dropout_add_exec_handler = \
            nullcontext if use_nvfuser else torch.enable_grad

    if args.retro_add_retriever:
        self.retro_num_neighbors = args.retro_num_neighbors
        self.retro_chunk_length = args.retro_chunk_length
        self.retro_retrieved_length = \
            args.retro_num_retrieved_chunks * args.retro_chunk_length

    # Retriever (bi-directional transformer with cross attention)
    if layer_type == LayerType.retro_decoder_with_retriever:
        self.retriever = ParallelTransformer(
            config=config,
            model_type=ModelType.retro_encoder,
            self_attn_mask_type=AttnMaskType.padding,
            pre_process=True,
            post_process=False,
        )
        self._retriever_key = 'retriever'
    else:
        self.retriever = None

# New feature: pass the GeGLU activation through to TransformerEngine layers.
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.legacy.model.transformer.ParallelTransformer.__init__')
def ParallelTransformer_init(self, config,
                             model_type, layer_type=LayerType.encoder,
                             self_attn_mask_type=AttnMaskType.padding,
                             post_norm=True,
                             pre_process=True,
                             post_process=True,
                             drop_path_rate=0.0):
    """Replacement ``ParallelTransformer.__init__`` with GeGLU support for TE.

    Configures recompute, TransformerEngine/FP8 settings and the
    pipeline/virtual-pipeline layer offset. It then builds this stage's
    layers: local ``ParallelTransformerLayer`` modules or
    ``transformer_engine.pytorch.TransformerLayer`` modules, whose
    ``activation`` is swiglu/geglu/gelu on TE >= 0.10. It also creates the
    final norm when ``post_process and post_norm``.

    Args:
        config: transformer config.
        model_type: ``ModelType``; retro encoder/decoder adjust layer types
            and dropout.
        layer_type: ``LayerType`` of the stack (encoder or decoder).
        self_attn_mask_type: mask type forwarded to each layer.
        post_norm: whether to build the final norm (with ``post_process``).
        pre_process: stored on ``self``.
        post_process: stored on ``self``; gates the final norm.
        drop_path_rate: maximum stochastic-depth rate, spread linearly over
            ``config.num_layers``.
    """
    super(ParallelTransformer, self).__init__()
    args = get_args()

    self.layer_type = layer_type
    self.model_type = model_type
    self.bf16 = config.bf16
    self.fp32_residual_connection = config.fp32_residual_connection
    self.post_norm = post_norm
    self.pre_process = pre_process
    self.post_process = post_process
    self.input_tensor = None
    self.drop_path_rate = drop_path_rate
    self.transformer_impl = args.transformer_impl
    self.retro_add_retriever = args.retro_add_retriever

    # Store activation checkpointing flag.
    self.recompute_granularity = config.recompute_granularity
    self.recompute_method = config.recompute_method
    self.recompute_num_layers = config.recompute_num_layers
    self.distribute_saved_activations = \
        config.distribute_saved_activations and not config.sequence_parallel

    self.sequence_parallel = config.sequence_parallel

    # Transformer Engine Init.
    self.transformer_engine_v_0_10 = False
    self.transformer_engine_v_0_11 = False
    self.transformer_engine_v_0_8 = False
    if self.transformer_impl == 'transformer_engine':
        # Bind the module in this file's globals so build_layer can use it.
        global transformer_engine
        import transformer_engine
        from importlib.metadata import version
        from pkg_resources import packaging

        te_version = packaging.version.Version(version("transformer-engine"))
        if te_version >= packaging.version.Version("0.8.0"):
            self.transformer_engine_v_0_8 = True
        if te_version >= packaging.version.Version("0.10.0"):
            self.transformer_engine_v_0_10 = True
        if te_version >= packaging.version.Version("0.11.0"):
            self.transformer_engine_v_0_11 = True

        del version, packaging

        assert not args.squared_relu, "TransformerEngine does not support squared relu activation."

    self.use_fp8 = args.fp8 is not None
    self.fp8_recipe = None
    self.fp8_group = None
    if self.use_fp8:
        assert args.transformer_impl == 'transformer_engine', \
            'transformer-engine required for fp8 training and inference'
        self.fp8_group = mpu.get_amax_reduction_group()
        if args.fp8 == "e4m3":
            fp8_format = transformer_engine.common.recipe.Format.E4M3
        elif args.fp8 == "hybrid":
            fp8_format = transformer_engine.common.recipe.Format.HYBRID
        else:
            raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
        self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=args.fp8_margin,
            interval=args.fp8_interval,
            fp8_format=fp8_format,
            amax_history_len=args.fp8_amax_history_len,
            amax_compute_algo=args.fp8_amax_compute_algo,
            override_linear_precision=(False, False, not args.fp8_wgrad),
        )

    self.num_microbatches_in_previous_step = -1
    self.microbatch_count = 0
    self.checkpoint_core_attention = config.recompute_granularity == 'selective'

    # Number of layers.
    self.num_layers = _get_num_layers(args, model_type,
                                      layer_type==LayerType.decoder)

    self.drop_path_rates = [
        rate.item() for rate in
        torch.linspace(0, self.drop_path_rate, config.num_layers)]

    self.retro_layer_numbers = None
    if model_type == ModelType.retro_decoder:
        retro_layer_start = 6 if config.num_layers <= 15 else 9
        self.retro_layer_numbers = \
            np.arange(retro_layer_start, args.num_layers + 1, 3).tolist()
    if model_type == ModelType.retro_encoder:
        self.retro_layer_numbers = [1]

    # Transformer layers.
    if args.retro_add_retriever:
        assert self.recompute_granularity != 'full', \
            "Full recompute not supported for Retro."
        assert args.transformer_impl == 'local', \
            "Transformer engine does not support Retro layers."

    def build_layer(layer_number):
        # layer_number is 1-based and global across pipeline stages.
        if args.transformer_impl == 'local':
            current_layer_type = _get_layer_type(
                model_type, layer_type, self.retro_layer_numbers,
                layer_number)
            return ParallelTransformerLayer(
                config,
                layer_number,
                layer_type=current_layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])
        else:
            # Keyword arguments gated on the installed TE version.
            extra_transformer_engine_kwargs = {}
            if self.transformer_engine_v_0_8:
                extra_transformer_engine_kwargs["bias"] = args.add_bias_linear
            if self.transformer_engine_v_0_10:
                # TE >= 0.10 takes the activation by name; swiglu wins over geglu.
                if args.swiglu:
                    activation_mode = "swiglu"
                elif args.geglu:
                    activation_mode = "geglu"
                else:
                    activation_mode = "gelu"
                extra_transformer_engine_kwargs["activation"] = activation_mode
            if self.transformer_engine_v_0_11:
                extra_transformer_engine_kwargs["normalization"] = args.normalization
            assert config.attention_softmax_in_fp32, "TransformerEngine only supports softmax compute in FP32."
            assert (
                (bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and args.fp16) == config.apply_query_key_layer_scaling
            ), "Unsupported config for apply_query_key_layer_scaling in TransformerEngine."
            if args.num_query_groups > 1:
                extra_transformer_engine_kwargs["num_gqa_groups"] = args.num_query_groups
            return transformer_engine.pytorch.TransformerLayer(
                config.hidden_size,
                config.ffn_hidden_size,
                config.num_attention_heads,
                layernorm_epsilon=config.layernorm_epsilon,
                hidden_dropout=config.hidden_dropout,
                attention_dropout=config.attention_dropout,
                init_method=config.init_method,
                output_layer_init_method=config.output_layer_init_method,
                layer_number=layer_number,
                kv_channels=config.kv_channels,
                self_attn_mask_type=self_attn_mask_type.name,
                tp_group=mpu.get_tensor_model_parallel_group(),
                get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
                fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
                seq_length=args.seq_length,
                micro_batch_size=args.micro_batch_size,
                sequence_parallel=config.sequence_parallel,
                params_dtype=config.params_dtype,
                apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
                output_layernorm=False,
                layer_type="encoder",
                drop_path_rate=self.drop_path_rates[layer_number - 1],
                set_parallel_mode=True,
                fuse_qkv_params=True,
                ub_tp_comm_overlap=args.tp_comm_overlap,
                **extra_transformer_engine_kwargs)

    if config.virtual_pipeline_model_parallel_size is not None:
        assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
            'num_layers_per_stage must be divisible by ' \
            'virtual_pipeline_model_parallel_size'
        assert args.model_type != ModelType.encoder_and_decoder
        # Number of layers in each model chunk is the number of layers in the stage,
        # divided by the number of model chunks in a stage.
        self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
        # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
        # layers to stages like (each list is a model chunk):
        # Stage 0: [0]  [2]  [4]  [6]
        # Stage 1: [1]  [3]  [5]  [7]
        # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
        # layers to stages like (each list is a model chunk):
        # Stage 0: [0, 1]  [4, 5]
        # Stage 1: [2, 3]  [6, 7]
        offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
            config.num_layers // config.virtual_pipeline_model_parallel_size) + \
            (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
    else:
        # Each stage gets a contiguous set of layers.
        if args.model_type == ModelType.encoder_and_decoder and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            pipeline_rank = mpu.get_pipeline_model_parallel_rank()
            if layer_type == LayerType.encoder:
                offset = pipeline_rank * self.num_layers
            else:
                num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
        else:
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

    if self.num_layers == 0:
        # When a standalone embedding stage is used (e.g.,
        # args.standalone_embedding_stage == True), virtual pipeline ranks
        # on pipeline rank 0 will have zero transformer layers assigned to
        # them. This results in the model's input and output tensors to be
        # the same, which will cause failure for certain output tensor
        # optimizations (e.g., pipeline output deallocation). To remedy
        # this, we assign a 'no-op' layer on these ranks, which will
        # disconnect the input tensor from the output tensor.
        self.num_layers = 1
        self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
    else:
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        # Update dropout rate for Retro encoder.
        if model_type == ModelType.retro_encoder:
            for layer in self.layers:
                if layer.self_attention.use_flash_attn:
                    # dropout_p holds a probability, consistent with the
                    # non-flash branch, which sets attention_dropout.p to the
                    # same float. A torch.nn.Dropout module is not a valid
                    # probability here.
                    layer.self_attention.core_attention_flash.dropout_p = \
                        args.retro_encoder_attention_dropout
                else:
                    layer.self_attention.core_attention.attention_dropout.p =\
                        args.retro_encoder_attention_dropout
                layer.hidden_dropout = args.retro_encoder_hidden_dropout

    if self.post_process and self.post_norm:
        # Final layer norm before output.
        self.final_norm = get_norm(config)
