# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Cambricon CORPORATION. All rights reserved.
import torch
from typing import Optional, Tuple
from apex.contrib.fused_bias_dropout.fused_bias_dropout import get_bias_dropout_add
from megatron.addons.function_wrapper import FUNCTION_WRAPPER


# New feature: register an apex-backed implementation of the bias_dropout_add_fused_train op.
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.core.fusions.fused_bias_dropout.bias_dropout_add_fused_train')
def bias_dropout_add_fused_train(
    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
) -> torch.Tensor:
    """Run the fused bias + dropout + residual-add kernel provided by apex.

    The function is registered with ``FUNCTION_WRAPPER`` under the name
    ``megatron.core.fusions.fused_bias_dropout.bias_dropout_add_fused_train``.
    Presumably this makes it stand in for Megatron-Core's own implementation;
    confirm against ``megatron.addons.function_wrapper``.

    Args:
        x_with_bias: ``(x, bias)`` pair forwarded unchanged to the apex
            callable. ``bias`` may be ``None`` according to the annotation.
        residual: Tensor forwarded to the apex callable as the residual input.
        prob: Dropout probability forwarded to the apex callable.

    Returns:
        Whatever tensor the apex callable returns. NOTE(review): by analogy
        with Megatron-Core this is presumably ``residual + dropout(x + bias, prob)``.
        Confirm against apex.
    """
    # NOTE(review): the two positional flags presumably select the training
    # and fused code paths of apex's factory. TODO confirm against apex.
    fused_op = get_bias_dropout_add(True, True)
    return fused_op(x_with_bias, residual, prob)
