# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Cambricon CORPORATION. All rights reserved.

import torch
from transformer_engine.pytorch.module.layernorm_mlp import _act_func
from megatron.addons.function_wrapper import FUNCTION_WRAPPER


# New feature: fused-op forward for bias_geglu.
@FUNCTION_WRAPPER.register_wrapper(
    func_name='megatron.core.fusions.fused_bias_geglu.BiasGeGLUFunction.forward',
)
def BiasGeGLUFunction_forward(ctx, input, bias):
    """Forward wrapper registered for ``BiasGeGLUFunction.forward``.

    Stashes ``input`` and ``bias`` on ``ctx`` for the backward pass, then
    delegates to the first callable of the ``'geglu'`` entry returned by
    ``_act_func``.

    Args:
        ctx: Autograd context; receives both tensors via
            ``save_for_backward``.
        input: Activation input, forwarded unchanged as the first argument.
        bias: Bias, forwarded unchanged as the second argument.

    Returns:
        Whatever the geglu forward callable returns for ``(input, bias)``.
    """
    ctx.save_for_backward(input, bias)
    # Index 0 is presumably the forward kernel of the (forward, backward, ...)
    # entry -- confirm against the installed _act_func implementation.
    geglu_fwd = _act_func('geglu')[0]
    return geglu_fwd(input, bias)

# New feature: fused-op backward for bias_geglu.
@FUNCTION_WRAPPER.register_wrapper(
    func_name='megatron.core.fusions.fused_bias_geglu.BiasGeGLUFunction.backward',
)
def BiasGeGLUFunction_backward(ctx, grad_output):
    """Backward wrapper registered for ``BiasGeGLUFunction.backward``.

    Unpacks the two tensors stored on ``ctx`` and delegates to the second
    callable of the ``'geglu'`` entry returned by ``_act_func``.

    Args:
        ctx: Autograd context whose ``saved_tensors`` holds exactly two
            entries (unpacked here as input and bias).
        grad_output: Incoming gradient, forwarded as the first argument.

    Returns:
        Whatever the geglu backward callable returns for
        ``(grad_output, saved_input, saved_bias)``.
    """
    saved_input, saved_bias = ctx.saved_tensors
    # Index 1 is presumably the backward kernel -- confirm against the
    # installed _act_func implementation.
    geglu_bwd = _act_func('geglu')[1]
    # NOTE(review): autograd expects one gradient per forward input (here
    # input and bias); verify the delegated callable returns a matching pair.
    return geglu_bwd(grad_output, saved_input, saved_bias)

# New feature: fused-op forward for geglu.
@FUNCTION_WRAPPER.register_wrapper(
    func_name='megatron.core.fusions.fused_bias_geglu.GeGLUFunction.forward',
)
def GeGLUFunction_forward(ctx, input):
    """Forward wrapper registered for ``GeGLUFunction.forward``.

    Stores ``input`` on ``ctx`` for the backward pass, then delegates to the
    first callable of the ``'geglu'`` entry returned by ``_act_func``.

    Args:
        ctx: Autograd context; receives ``input`` via ``save_for_backward``.
        input: Activation input, forwarded unchanged as the only argument.

    Returns:
        Whatever the geglu forward callable returns for ``input``.
    """
    ctx.save_for_backward(input)
    # Index 0 is presumably the forward kernel -- confirm against the
    # installed _act_func implementation.
    geglu_fwd = _act_func('geglu')[0]
    return geglu_fwd(input)

# New feature: fused-op backward for geglu.
@FUNCTION_WRAPPER.register_wrapper(
    func_name='megatron.core.fusions.fused_bias_geglu.GeGLUFunction.backward',
)
def GeGLUFunction_backward(ctx, grad_output):
    """Backward wrapper registered for ``GeGLUFunction.backward``.

    Reads the saved tensors from ``ctx`` and delegates to the second callable
    of the ``'geglu'`` entry returned by ``_act_func``, passing the first
    saved tensor.

    Args:
        ctx: Autograd context; element 0 of ``saved_tensors`` is used.
        grad_output: Incoming gradient, forwarded as the first argument.

    Returns:
        Whatever the geglu backward callable returns for
        ``(grad_output, saved[0])``.
    """
    saved = ctx.saved_tensors
    # Index 1 is presumably the backward kernel -- confirm against the
    # installed _act_func implementation.
    geglu_bwd = _act_func('geglu')[1]
    return geglu_bwd(grad_output, saved[0])

