#!/bin/bash
# Runs this test suite: installs requirements.txt, then executes each test
# module with pytest (single process) or `torchrun ... -m pytest`
# (multi-process).
#
# Environment:
#   NPROC  processes per node passed to torchrun (default: 8).
set -euo pipefail

# Absolute directory containing this script. `&&` ensures a failed `cd` is not
# masked by `pwd` printing the caller's directory instead.
CUR_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)

export MLU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"

# Prepend the directory two levels up to the Python import path. The
# ${VAR:+...} form avoids appending an empty entry (treated as the current
# directory) when the variable was previously unset or empty.
export PYTHONPATH="$CUR_DIR/../..${PYTHONPATH:+:$PYTHONPATH}"

# Append the torch_mlu native library directory to the dynamic linker path,
# again without creating an empty (current-directory) entry.
torch_mlu_lib=/torch/venv3/pytorch/lib/python3.10/site-packages/torch_mlu/csrc/lib
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+$LD_LIBRARY_PATH:}$torch_mlu_lib"

pip install -r "$CUR_DIR/requirements.txt" -i https://pypi.tuna.tsinghua.edu.cn/simple

# Processes per node for every torchrun invocation below; override with NPROC.
nproc=${NPROC:-8}

# Top-level test modules. Resolve them against $CUR_DIR, as every other section
# does, so the script works no matter which directory it is launched from.
pushd "$CUR_DIR"
    pytest test_basic.py
    pytest test_imports.py
    pytest test_utils.py
    torchrun --nproc_per_node "$nproc" -m pytest test_parallel_state.py
    torchrun --nproc_per_node "$nproc" -m pytest test_training.py
popd

# Dataset tests, each launched under torchrun with $nproc processes.
pushd "$CUR_DIR/data"
    torchrun --nproc_per_node "$nproc" -m pytest test_builder.py
    torchrun --nproc_per_node "$nproc" -m pytest test_mock_gpt_dataset.py
    torchrun --nproc_per_node "$nproc" -m pytest test_multimodal_dataset.py

    # Disabled: these tests need to download a dataset from GitHub, so their
    # results are unstable and depend on network speed.
    #torchrun --nproc_per_node "$nproc" -m pytest test_preprocess_mmdata.py
    #torchrun --nproc_per_node "$nproc" -m pytest test_preprocess_data.py
popd

# Fusion tests; run as a single process with plain pytest.
pushd "$CUR_DIR/fusions"
    pytest test_torch_softmax.py
popd

# Model tests, each launched under torchrun with $nproc processes. Kept as one
# command per line so individual tests can be commented out, as done elsewhere.
pushd "$CUR_DIR/models"
    torchrun --nproc_per_node "$nproc" -m pytest test_base_embedding.py
    torchrun --nproc_per_node "$nproc" -m pytest test_bert_model.py
    torchrun --nproc_per_node "$nproc" -m pytest test_clip_vit_model.py
    torchrun --nproc_per_node "$nproc" -m pytest test_gpt_model.py
    torchrun --nproc_per_node "$nproc" -m pytest test_llava_model.py
    torchrun --nproc_per_node "$nproc" -m pytest test_multimodal_projector.py
    torchrun --nproc_per_node "$nproc" -m pytest test_t5_model.py
popd


# Pipeline-parallel schedule tests under torchrun with $nproc processes.
pushd "$CUR_DIR/pipeline_parallel"
    torchrun --nproc_per_node "$nproc" -m pytest test_schedules.py
popd

# Tensor-parallel tests, each launched under torchrun with $nproc processes.
pushd "$CUR_DIR/tensor_parallel"
    torchrun --nproc_per_node "$nproc" -m pytest test_cross_entropy.py
    torchrun --nproc_per_node "$nproc" -m pytest test_data.py
    torchrun --nproc_per_node "$nproc" -m pytest test_initialization.py
    torchrun --nproc_per_node "$nproc" -m pytest test_random.py
    torchrun --nproc_per_node "$nproc" -m pytest test_tensor_parallel_utils.py

    # Disabled: not enough QPs (queue pairs) are available to run this test.
    #torchrun --nproc_per_node "$nproc" -m pytest test_mappings.py
popd

# Transformer tests (including MoE tests under moe/), each launched under
# torchrun with $nproc processes.
pushd "$CUR_DIR/transformer"
    torchrun --nproc_per_node "$nproc" -m pytest test_attention.py
    torchrun --nproc_per_node "$nproc" -m pytest test_mlp.py
    torchrun --nproc_per_node "$nproc" -m pytest test_module.py

    # Retro attention runs with NVTE_FUSED_ATTN=0 and NVTE_FLASH_ATTN=0
    # (presumably disabling the fused/flash attention backends). The variables
    # are scoped to this single command rather than exported and later unset.
    # As a result, any values the caller exported stay intact for the tests
    # that follow.
    NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 \
        torchrun --nproc_per_node "$nproc" -m pytest test_retro_attention.py

    torchrun --nproc_per_node "$nproc" -m pytest test_spec_customization.py
    torchrun --nproc_per_node "$nproc" -m pytest test_transformer_block.py
    torchrun --nproc_per_node "$nproc" -m pytest test_transformer_layer.py

    torchrun --nproc_per_node "$nproc" -m pytest moe/test_grouped_mlp.py
    torchrun --nproc_per_node "$nproc" -m pytest moe/test_routers.py
    torchrun --nproc_per_node "$nproc" -m pytest moe/test_sequential_mlp.py
    torchrun --nproc_per_node "$nproc" -m pytest moe/test_token_dispatcher.py
popd

# MLU-specific activation tests under torchrun with $nproc processes.
pushd "$CUR_DIR/mlu"
    torchrun --nproc_per_node "$nproc" -m pytest test_swiglu.py
    torchrun --nproc_per_node "$nproc" -m pytest test_geglu.py
popd
