import torch
import numpy as np
import torch_mlu
import torch.nn.functional as F
from megatron import addons
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from tests.unit_tests.test_utilities import skip_370

def assert_allclose(dtype, l1, l2):
    """Assert that two equal-length sequences of tensors are element-wise close.

    The tolerance is chosen from ``dtype``: rtol = atol = 1e-3 for
    ``torch.float32``, and rtol = atol = 3e-3 for any other dtype. NaNs compare
    equal. Pairs where both entries are ``None`` are skipped.

    Args:
        dtype: dtype used only to select the tolerance.
        l1: sequence of tensors (or ``None`` entries).
        l2: sequence of tensors (or ``None`` entries), same length as ``l1``.

    Raises:
        AssertionError: if the lengths differ, if exactly one entry of a pair
            is ``None``, or if a pair is not close. For a mismatch, the message
            reports the flattened index, both values and the absolute
            difference at the largest absolute difference.
    """
    assert len(l1) == len(l2), "Unequal number of outputs."
    # The tolerance depends only on dtype, so select it once, not per pair.
    if dtype in [torch.float32]:
        rtol = atol = 1e-3
    else:
        rtol = atol = 3e-3
    for input_data, other_data in zip(l1, l2):
        if input_data is None and other_data is None:
            continue
        if input_data is None or other_data is None:
            # torch.allclose would raise a TypeError here; report it as a test failure instead.
            raise AssertionError("Exactly one of a pair of outputs is None.")
        if torch.allclose(input_data, other_data, rtol=rtol, atol=atol, equal_nan=True):
            continue
        diff = torch.abs(input_data - other_data).flatten()
        m = torch.argmax(diff)
        msg = (
            "Outputs not close enough. "
            f"Location of the maximum difference: {m.item()} "
            f"with {input_data.flatten()[m].item()} vs {other_data.flatten()[m].item()} "
            f"(diff {diff[m].item()})."
        )
        raise AssertionError(msg)

def swiglu(y):
    """SwiGLU activation.

    Splits ``y`` into two chunks along the last dimension and returns
    ``silu(first_chunk) * second_chunk``.
    """
    gate, value = y.chunk(2, dim=-1)
    activated = F.silu(gate)
    return activated * value

def bias_swiglu(y, bias):
    """Add ``bias`` to ``y`` (out of place) and apply ``swiglu`` to the sum."""
    return swiglu(y + bias)

def swiglu_back(g, y):
    """Gradient of ``swiglu(y)`` w.r.t. ``y`` given the upstream gradient ``g``.

    With ``y`` split on the last dim into ``a`` and ``b`` (swiglu = silu(a) * b):
      d/da = g * sigmoid(a) * (1 + a * (1 - sigmoid(a))) * b
      d/db = g * silu(a)
    The two partial gradients are concatenated back along the last dim, so the
    result has the same last-dim layout as ``y``.
    """
    a, b = y.chunk(2, dim=-1)
    sig_a = torch.sigmoid(a)
    grad_a = g * sig_a * (1 + a * (1 - sig_a)) * b
    grad_b = g * F.silu(a)
    return torch.cat([grad_a, grad_b], dim=-1)

def bias_swiglu_back(g, y, bias):
    """Gradient of ``bias_swiglu(y, bias)`` w.r.t. the pre-activation ``y + bias``.

    Because the bias is added elementwise, this is ``swiglu_back`` evaluated at
    ``y + bias``.
    """
    return swiglu_back(g, y + bias)

class BiasSwiGLUFunction(torch.autograd.Function):
    """Reference autograd op computing ``swiglu(input + bias)``.

    The caller (``local_bias_swiglu_impl``) uses ``SwiGLUFunction`` instead when
    there is no bias, so ``bias`` is always a tensor here.
    """

    @staticmethod
    def forward(ctx, input, bias):
        out = bias_swiglu(input, bias)
        # Save the raw (pre-bias) input; backward re-adds the bias itself.
        ctx.save_for_backward(input, bias)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, saved_bias = ctx.saved_tensors
        grad = bias_swiglu_back(grad_output, saved_input, saved_bias)
        # The sum input + bias has an identity Jacobian w.r.t. each operand, so
        # one tensor serves as both gradients. NOTE: the bias gradient is
        # returned at input's shape; this relies on PyTorch autograd
        # sum-reducing a gradient whose shape is expandable from bias's shape.
        return grad, grad

class SwiGLUFunction(torch.autograd.Function):
    """Reference autograd op computing ``swiglu(input)`` (no bias)."""

    @staticmethod
    def forward(ctx, input):
        out = swiglu(input)
        ctx.save_for_backward(input)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors is a tuple; exactly one tensor was saved in forward.
        (saved_input,) = ctx.saved_tensors
        return swiglu_back(grad_output, saved_input)

def local_bias_swiglu_impl(input, bias):
    """Non-fused reference counterpart of ``bias_swiglu_impl``.

    Views a 2-D or 3-D ``input`` as 2-D over its last dimension, then applies
    ``BiasSwiGLUFunction`` (when ``bias`` is given) or ``SwiGLUFunction``
    (when ``bias`` is None). For 3-D input, the first two dimensions are
    restored on the result.

    Raises:
        AssertionError: if ``input`` is not 2-D or 3-D.
    """
    shape = input.shape
    rank = len(shape)
    assert rank in [2, 3]
    flat = input.view(-1, shape[-1])
    if bias is None:
        out = SwiGLUFunction.apply(flat)
    else:
        out = BiasSwiGLUFunction.apply(flat, bias)
    if rank == 2:
        return out
    return out.view(shape[0], shape[1], -1)

@skip_370
def test_swiglu():
    """Fused ``bias_swiglu_impl`` without bias must match the local reference.

    Forward outputs must be bit-identical. Gradients w.r.t. the input are
    compared with ``assert_allclose`` at bfloat16 tolerance.
    """
    np.random.seed(0)
    torch.manual_seed(0)
    # Enable grad after moving to MLU so ``a`` is a leaf and backward() fills a.grad.
    a = torch.randn(10, 10, dtype=torch.bfloat16).mlu().requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)

    fused_out = bias_swiglu_impl(a, None)
    expected_out = local_bias_swiglu_impl(a_ref, None)
    assert torch.equal(fused_out, expected_out)

    fused_out.sum().backward()
    expected_out.sum().backward()
    # Compare the input gradients. The output gradients are all-ones from .sum(),
    # so comparing them would not exercise either backward implementation.
    assert_allclose(torch.bfloat16, [a.grad], [a_ref.grad])

@skip_370
def test_bias_swiglu():
    """Fused ``bias_swiglu_impl`` with bias must match the local reference.

    Forward outputs must be bit-identical. Gradients w.r.t. both the input and
    the bias are compared with ``assert_allclose`` at bfloat16 tolerance.
    """
    np.random.seed(0)
    torch.manual_seed(0)
    # Enable grad after moving to MLU so the tensors are leaves and backward()
    # fills their .grad directly.
    a = torch.randn(10, 10, dtype=torch.bfloat16).mlu().requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    bias = torch.randn(10, dtype=torch.bfloat16).mlu().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    fused_out = bias_swiglu_impl(a, bias)
    expected_out = local_bias_swiglu_impl(a_ref, bias_ref)
    assert torch.equal(fused_out, expected_out)

    fused_out.sum().backward()
    expected_out.sum().backward()
    # Compare the input and bias gradients. The output gradients are all-ones
    # from .sum(), so comparing them would not exercise either backward.
    assert_allclose(torch.bfloat16, [a.grad, bias.grad], [a_ref.grad, bias_ref.grad])
