#!/bin/bash
# Stop at the first command that returns a non-zero status.
set -o errexit

# Set up the dataset/environment for this run.
#
# Positional arguments read here:
#   $2  MODE  forwarded to prepare_once.sh; "ONLINE" when omitted.
#
# CUR_DIR is the absolute directory containing this script. Using `cd ... && pwd`
# makes the substitution fail (and, under `set -e`, abort the script) when the
# directory cannot be entered, instead of silently yielding the caller's
# working directory. All expansions are quoted so paths with spaces survive.
CUR_DIR=$(cd -- "$(dirname -- "$0")" && pwd)
MODE=${2-"ONLINE"} # [ONLINE, OFFLINE]
# NOTE(review): prepare_once.sh is not visible here. Later lines read
# LLAMA2_TOKENIZER_MODEL, LLAMA2_OPENWEBTEXT_DATASET_PATH and DATESTR, none of
# which this script assigns, so it presumably defines them -- confirm.
source "$CUR_DIR/prepare_once.sh" "$MODE"


# ---- torchrun topology: a single node running eight local processes ----
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=29600

# ---- model-parallel sizes (pipeline / tensor / expert) ----
PP=1
TP=8
EP=1

# ---- batch sizes and run length ----
# The trailing "#4" / "#32" notes are kept from the original; presumably they
# record the values used for a larger run -- TODO confirm.
MB=2     #4
GB=16    #32
# First positional argument overrides the number of training iterations.
TRAIN_STEPS=${1-50}

# Launcher options handed to torchrun, one flag/value pair per line.
DISTRIBUTED_ARGS=(
    --nproc_per_node "$GPUS_PER_NODE"
    --nnodes "$NNODES"
    --node_rank "$NODE_RANK"
    --master_addr "$MASTER_ADDR"
    --master_port "$MASTER_PORT"
)

# Model-definition flags for pretrain_gpt.py, appended in groups.
# The element order matches the original single-literal definition.
MODEL_ARGS=()

# Implementation switches.
MODEL_ARGS+=(--use-mcore-models --disable-bias-linear)

# Sequence length and position range.
MODEL_ARGS+=(--seq-length 32768 --max-position-embeddings 32768)

# Layer count and layer dimensions. The "#32" note is kept from the original;
# presumably it records the full-size layer count -- TODO confirm.
MODEL_ARGS+=(
    --num-layers 1  #32
    --hidden-size 4096
    --ffn-hidden-size 14336
    --num-attention-heads 32
)

# Initialisation and dropout.
MODEL_ARGS+=(
    --init-method-std 0.01
    --attention-dropout 0.0
    --hidden-dropout 0.0
)

# Normalisation, position embedding, activation and attention variants.
MODEL_ARGS+=(
    --normalization RMSNorm
    --position-embedding-type rope
    --swiglu
    --untie-embeddings-and-output-weights
    --group-query-attention
    --num-query-groups 8
)

# Disabled fusions / embeddings and vocabulary padding.
MODEL_ARGS+=(
    --no-masked-softmax-fusion
    --no-position-embedding
    --no-rope-fusion
    --make-vocab-size-divisible-by 100
)

# Mixture-of-experts flags. The expert-parallel size is taken from $EP,
# which is assigned earlier in this script.
MOE_ARGS=(
    --num-experts 8
    --expert-model-parallel-size "$EP"

    # Router configuration.
    # Load-balancing options: aux_loss, sinkhorn, None. Default is aux_loss.
    --moe-router-load-balancing-type aux_loss
    --moe-router-topk 2
    --moe-aux-loss-coeff 1e-2

    --moe-grouped-gemm
)

# Tokenizer and dataset flags.
# LLAMA2_TOKENIZER_MODEL and LLAMA2_OPENWEBTEXT_DATASET_PATH are not assigned
# in this script; presumably the sourced prepare_once.sh provides them --
# TODO confirm.
# Both expansions are intentionally left unquoted, exactly as before, so an
# empty value contributes no argument and a value holding several
# whitespace-separated entries expands to several arguments.
DATA_ARGS=(--tokenizer-type Llama2Tokenizer)
DATA_ARGS+=(--tokenizer-model $LLAMA2_TOKENIZER_MODEL)
DATA_ARGS+=(--data-path $LLAMA2_OPENWEBTEXT_DATASET_PATH)
DATA_ARGS+=(--split 99990,8,2)

# Optimisation and schedule flags. Batch sizes and the iteration count come
# from $MB, $GB and $TRAIN_STEPS, which are assigned earlier in this script.
TRAINING_ARGS=(
    # Batch geometry.
    --micro-batch-size "$MB"
    --global-batch-size "$GB"

    # Learning-rate schedule.
    --lr 1e-4
    --train-iters "$TRAIN_STEPS"
    --lr-decay-iters 320000
    --lr-decay-style cosine
    --min-lr 1.0e-5

    # Regularisation, warmup and gradient clipping.
    --weight-decay 0.1
    --lr-warmup-iters 500
    --clip-grad 1.0

    # Precision and attention kernel.
    --bf16
    --use-flash-attn
)

# Parallelism flags. Tensor and pipeline sizes come from $TP and $PP, which
# are assigned earlier in this script.
MODEL_PARALLEL_ARGS=(--tensor-model-parallel-size "$TP")
MODEL_PARALLEL_ARGS+=(--pipeline-model-parallel-size "$PP")
MODEL_PARALLEL_ARGS+=(--sequence-parallel --use-distributed-optimizer)

# Logging, checkpoint-save, evaluation and checkpoint-load flags.
LOGGING_ARGS=(
    --log-interval 1
    --save-interval 10000
    --eval-interval 1000
    --eval-iters 10
    --no-load-optim
    --no-load-rng
)

# Append Weights & Biases flags only when WANDB_API_KEY is set and non-empty.
# WANDB_PROJECT / WANDB_NAME override the defaults below. The expansions are
# quoted so a project or experiment name containing whitespace or glob
# characters stays a single array element.
if [[ -n "${WANDB_API_KEY:-}" ]]; then
    LOGGING_ARGS+=(
        --wandb-project "${WANDB_PROJECT:-Mixtral-Finetuning}"
        --wandb-exp-name "${WANDB_NAME:-Mixtral_8x7B}"
    )
fi

# Assemble the launch command as an array so every argument keeps its exact
# boundaries. The previous string + `eval` form re-parsed the flattened text,
# which re-split values and let the shell interpret any metacharacters in them.
# pretrain_gpt.py is resolved relative to the current working directory.
CMD=(
    torchrun "${DISTRIBUTED_ARGS[@]}" pretrain_gpt.py
    "${MODEL_ARGS[@]}"
    "${MOE_ARGS[@]}"
    "${DATA_ARGS[@]}"
    "${TRAINING_ARGS[@]}"
    "${MODEL_PARALLEL_ARGS[@]}"
    "${LOGGING_ARGS[@]}"
)

# Print the command line (space-joined) before running it.
printf '%s\n' "${CMD[*]}"

# Make sure the log directory exists; tee cannot open the log file otherwise.
LOG_DIR="$CUR_DIR/logs"
mkdir -p -- "$LOG_DIR"

# Without pipefail the pipeline's exit status is tee's, so a failed training
# run would still make this script exit 0.
set -o pipefail

# NOTE(review): DATESTR is not assigned in this script; presumably it is set
# by the sourced prepare_once.sh -- confirm.
"${CMD[@]}" 2>&1 | tee "$LOG_DIR/log-mixtral-mini-$DATESTR.txt"
