import torch
import numpy as np
import torch_mlu
import torch.nn.functional as F
from megatron import addons
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl

def assert_allclose(dtype, l1, l2):
    """Assert that two outputs (or two sequences of outputs) are elementwise close.

    Args:
        dtype: dtype used only to choose tolerances: ``torch.float32`` uses
            rtol=atol=1e-3, any other dtype uses rtol=atol=3e-3.
        l1, l2: either single tensors or equal-length sequences of tensors.
            Sequence entries may be ``None``; a pair where both are ``None``
            is skipped.

    Raises:
        AssertionError: if the sequences differ in length, exactly one entry of
            a pair is ``None``, the paired tensors differ in shape, or their
            values are not close. For a value mismatch, the message reports the
            flat index of the largest absolute difference.
    """
    # Callers pass bare tensors. Wrap them so each tensor is compared as a whole
    # instead of being iterated row by row. Row-by-row iteration would make the
    # reported location an index within a single row.
    if isinstance(l1, torch.Tensor):
        l1 = [l1]
    if isinstance(l2, torch.Tensor):
        l2 = [l2]
    assert len(l1) == len(l2), "Unequal number of outputs."

    if dtype == torch.float32:
        rtol = atol = 1e-3
    else:
        rtol = atol = 3e-3

    for input_data, other_data in zip(l1, l2):
        if input_data is None and other_data is None:
            continue
        if input_data is None or other_data is None:
            raise AssertionError("Only one of the paired outputs is None.")
        # torch.allclose broadcasts, so a shape mismatch could otherwise pass.
        if input_data.shape != other_data.shape:
            raise AssertionError(
                f"Shape mismatch: {tuple(input_data.shape)} vs {tuple(other_data.shape)}."
            )
        if not torch.allclose(input_data, other_data, rtol=rtol, atol=atol, equal_nan=True):
            diff = torch.abs(input_data - other_data).flatten()
            m = torch.argmax(diff)
            msg = (
                "Outputs not close enough. "
                f"Location of the maximum difference: {m.item()} "
                f"with {input_data.flatten()[m].item()} vs {other_data.flatten()[m].item()} "
                f"(diff {diff[m].item()})."
            )
            raise AssertionError(msg)

def geglu(y):
    """Tanh-approximated GeGLU.

    Splits ``y`` into two halves along the last dimension. Returns
    ``gelu_tanh(first_half) * second_half``, so the result's last dimension is
    half of ``y``'s. The last dimension is expected to be even.
    """
    gate, value = torch.chunk(y, 2, -1)
    # 0.79788456 ~= sqrt(2/pi); 0.044715 is the cubic term of the tanh approximation.
    tanh_arg = 0.79788456 * gate * (1 + 0.044715 * gate * gate)
    gelu_gate = gate * 0.5 * (1.0 + torch.tanh(tanh_arg))
    return gelu_gate * value

def bias_geglu(bias, y):
    """Add ``bias`` to ``y`` (with broadcasting) and apply ``geglu`` to the sum.

    The two tensors are only added together, so swapping the arguments gives
    the same result.
    """
    biased = y + bias
    return geglu(biased)

# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
def geglu_back(g, y):
    """Backward pass of the tanh-approximated ``geglu``.

    Args:
        g: upstream gradient, shaped like ``geglu(y)``.
        y: the forward input; it is split in half along the last dimension.

    Returns:
        The gradient w.r.t. ``y``: the gate-half gradient concatenated with the
        value-half gradient along the last dimension.
    """
    gate, value = torch.chunk(y, 2, -1)
    t = torch.tanh(0.79788456 * gate * (1 + 0.044715 * gate * gate))
    # Derivative of 0.5 * x * (1 + tanh(u(x))), where
    # u'(x) = 0.79788456 + 0.1070322243 * x^2.
    sech_sq = 1 - t * t
    d_gelu = 0.5 * gate * (sech_sq * (0.79788456 + 0.1070322243 * gate * gate)) + 0.5 * (1 + t)
    gelu_val = gate * 0.5 * (1.0 + t)
    grad_gate = (g * value) * d_gelu
    grad_value = g * gelu_val
    return torch.cat((grad_gate, grad_value), -1)

def bias_geglu_back(g, y, bias):
    """Gradient of ``geglu(y + bias)`` w.r.t. the biased input, given upstream ``g``."""
    return geglu_back(g, y + bias)


class BiasGeGLUFunction(torch.autograd.Function):
    """Autograd reference for ``geglu(input + bias)``.

    ``bias`` is required here. ``local_bias_geglu_impl`` routes ``bias=None``
    to ``GeGLUFunction`` instead.
    """

    @staticmethod
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        # bias_geglu's parameters are named (bias, y), but it only adds them,
        # so passing (input, bias) positionally is harmless.
        return bias_geglu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, saved_bias = ctx.saved_tensors
        grad = bias_geglu_back(grad_output, saved_input, saved_bias)
        # The same input-shaped gradient is returned for both arguments. For a
        # broadcast bias, the autograd engine sum-reduces it to bias's shape.
        return grad, grad


class GeGLUFunction(torch.autograd.Function):
    """Autograd reference for ``geglu(input)`` with no bias term."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return geglu(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Exactly one tensor was saved in forward.
        (saved_input,) = ctx.saved_tensors
        return geglu_back(grad_output, saved_input)

def local_bias_geglu_impl(input, bias):
    """Pure-PyTorch reference mirroring the fused ``bias_geglu_impl``.

    Accepts a 2-D or 3-D ``input`` and flattens its leading dimensions with
    ``view``, so the input must be viewable. It then applies
    ``BiasGeGLUFunction`` (or ``GeGLUFunction`` when ``bias`` is None) and
    restores the leading dimensions. GeGLU halves the last dimension.
    """
    ori_shape = input.shape
    assert len(ori_shape) in (2, 3)
    flat = input.view(-1, ori_shape[-1])
    if bias is None:
        output = GeGLUFunction.apply(flat)
    else:
        output = BiasGeGLUFunction.apply(flat, bias)

    if len(ori_shape) == 2:
        return output
    return output.view(ori_shape[0], ori_shape[1], -1)

# Per JIRA CNNLEXTRAS-3406, the precision of this op on MLU is higher than on GPU for the bf16 dtype,
# so bf16 is not tested.
def test_geglu():
    """Compare the fused ``bias_geglu_impl`` (bias=None) with the local reference.

    Checks both the forward output and the gradient w.r.t. the input.
    """
    np.random.seed(0)
    torch.manual_seed(0)
    # Move to the device first, then enable grad, so ``a`` is a leaf tensor
    # whose ``.grad`` is populated by backward().
    a = torch.randn(10, 10, dtype=torch.float32).mlu().requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)

    fused_out = bias_geglu_impl(a, None)
    expected_out = local_bias_geglu_impl(a_ref, None)
    assert_allclose(torch.float32, fused_out, expected_out)

    # Check gradients w.r.t. the input. The outputs' own .grad would just be
    # the upstream gradient and would not exercise the GeGLU backward. A
    # non-uniform upstream gradient also exercises the multiplication by g.
    grad_out = torch.randn_like(expected_out)
    fused_out.backward(grad_out)
    expected_out.backward(grad_out)
    assert_allclose(torch.float32, a.grad, a_ref.grad)

# Per JIRA CNNLEXTRAS-3406, the precision of this op on MLU is higher than on GPU for the bf16 dtype,
# so bf16 is not tested.
def test_bias_geglu():
    """Compare the fused ``bias_geglu_impl`` with a bias against the local reference.

    Checks the forward output and the gradients w.r.t. both the input and the
    bias.
    """
    np.random.seed(0)
    torch.manual_seed(0)
    # Move to the device first, then enable grad, so ``a`` and ``bias`` are
    # leaf tensors whose ``.grad`` is populated by backward().
    a = torch.randn(10, 10, dtype=torch.float32).mlu().requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    bias = torch.randn(10, dtype=torch.float32).mlu().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    fused_out = bias_geglu_impl(a, bias)
    expected_out = local_bias_geglu_impl(a_ref, bias_ref)
    assert_allclose(torch.float32, fused_out, expected_out)

    # Check gradients w.r.t. the inputs. The outputs' own .grad would just be
    # the upstream gradient and would not exercise the GeGLU backward. A
    # non-uniform upstream gradient also exercises the multiplication by g.
    grad_out = torch.randn_like(expected_out)
    fused_out.backward(grad_out)
    expected_out.backward(grad_out)
    assert_allclose(torch.float32, a.grad, a_ref.grad)
    assert_allclose(torch.float32, bias.grad, bias_ref.grad)
