# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Cambricon CORPORATION. All rights reserved.

import torch
from transformer_engine.pytorch.module.layernorm_mlp import _act_func
from megatron.addons.function_wrapper import FUNCTION_WRAPPER


# New feature: fused op for the bias_swiglu forward pass.
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.core.fusions.fused_bias_swiglu.BiasSwiGLUFunction.forward')
def BiasSwiGLUFunction_forward(ctx, input, bias):
    """Run the fused SwiGLU forward op on ``input`` and ``bias``.

    Registered through ``FUNCTION_WRAPPER`` under the dotted name of
    Megatron's ``BiasSwiGLUFunction.forward``.

    Args:
        ctx: Context object; ``input`` and ``bias`` are stored on it with
            ``save_for_backward`` before the fused op is invoked.
        input: First argument passed to the fused op.
        bias: Second argument passed to the fused op.

    Returns:
        The result of calling element 0 of ``_act_func('swiglu')`` with
        ``(input, bias)``.
    """
    ctx.save_for_backward(input, bias)
    # Element 0 of the 'swiglu' entry is the callable used for the forward pass.
    fused_forward = _act_func('swiglu')[0]
    return fused_forward(input, bias)

# New feature: fused op for the bias_swiglu backward pass.
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.core.fusions.fused_bias_swiglu.BiasSwiGLUFunction.backward')
def BiasSwiGLUFunction_backward(ctx, grad_output):
    """Run the fused SwiGLU backward op using the tensors saved on ``ctx``.

    Registered through ``FUNCTION_WRAPPER`` under the dotted name of
    Megatron's ``BiasSwiGLUFunction.backward``.

    Args:
        ctx: Context object whose ``saved_tensors`` unpacks into exactly two
            values (the input and the bias).
        grad_output: Incoming gradient, passed as the first argument to the
            fused op.

    Returns:
        The result of calling element 1 of ``_act_func('swiglu')`` with
        ``(grad_output, input, bias)``, returned unchanged.
    """
    saved_input, saved_bias = ctx.saved_tensors
    # Element 1 of the 'swiglu' entry is the callable used for the backward pass.
    fused_backward = _act_func('swiglu')[1]
    # NOTE(review): the matching forward takes two tensors (input, bias); if this
    # is installed as an autograd backward, confirm the fused op returns one
    # gradient per forward input.
    return fused_backward(grad_output, saved_input, saved_bias)

# New feature: fused op for the swiglu forward pass (no bias).
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.core.fusions.fused_bias_swiglu.SwiGLUFunction.forward')
def SwiGLUFunction_forward(ctx, input):
    """Run the fused SwiGLU forward op on ``input`` alone.

    Registered through ``FUNCTION_WRAPPER`` under the dotted name of
    Megatron's ``SwiGLUFunction.forward``.

    Args:
        ctx: Context object; ``input`` is stored on it with
            ``save_for_backward`` before the fused op is invoked.
        input: Sole argument passed to the fused op.

    Returns:
        The result of calling element 0 of ``_act_func('swiglu')`` with
        ``input``.
    """
    ctx.save_for_backward(input)
    # Element 0 of the 'swiglu' entry is the callable used for the forward pass.
    fused_forward = _act_func('swiglu')[0]
    return fused_forward(input)

# New feature: fused op for the swiglu backward pass (no bias).
@FUNCTION_WRAPPER.register_wrapper(func_name='megatron.core.fusions.fused_bias_swiglu.SwiGLUFunction.backward')
def SwiGLUFunction_backward(ctx, grad_output):
    """Run the fused SwiGLU backward op using the tensor saved on ``ctx``.

    Registered through ``FUNCTION_WRAPPER`` under the dotted name of
    Megatron's ``SwiGLUFunction.backward``.

    Args:
        ctx: Context object; the first entry of its ``saved_tensors`` is used
            as the forward input.
        grad_output: Incoming gradient, passed as the first argument to the
            fused op.

    Returns:
        The result of calling element 1 of ``_act_func('swiglu')`` with
        ``(grad_output, saved_input)``, returned unchanged.
    """
    saved = ctx.saved_tensors
    # Element 1 of the 'swiglu' entry is the callable used for the backward pass.
    fused_backward = _act_func('swiglu')[1]
    # saved_tensors is a container; only its first entry is consumed here.
    return fused_backward(grad_output, saved[0])
