#! /bin/bash
set -e

# Absolute path of the directory holding this script, so that sibling files
# (prepare_once.sh, logs/) resolve no matter where the script is invoked from.
# `&&` makes a failed cd yield an empty result instead of the caller's cwd, and
# stdout of cd is discarded because cd prints the target when CDPATH is in use.
CUR_DIR=$(cd -- "$(dirname -- "$0")" >/dev/null && pwd)

# Month-day-hour-minute stamp used to give each run's log file a unique name.
DATESTR=$(date +"%m-%d-%H-%M")

#######################################
# Print the value part of every `key=value` line that matches a pattern.
# Arguments:
#   $1 - grep pattern (basic regular expression); it is matched anywhere in
#        the line, not only against the key.
#   $2 - path of the properties file; when omitted or empty, lines are read
#        from standard input.
# Outputs:
#   For each matching line, everything after the first '=' on stdout
#   (values that themselves contain '=' are kept intact).
# Returns:
#   Exit status of the pipeline (that of `cut` unless pipefail is enabled).
#######################################
get_prop() {
    local pattern=$1
    local file=${2-}
    # ${file:+"$file"} expands to one quoted word, or to nothing at all when no
    # file was given so that grep keeps reading stdin.
    grep -- "$pattern" ${file:+"$file"} | cut -d'=' -f2-
}


# Dataset setup: the second positional argument selects the mode, and
# prepare_once.sh is sourced (not executed) with that mode as its argument so
# anything it defines stays in this shell.
# NOTE(review): prepare_once.sh presumably provides
# LLAMA2_OPENWEBTEXT_DATASET_PATH and LLAMA2_TOKENIZER_MODEL, which are used
# further down but never assigned in this file - confirm.
MODE=${2:-ONLINE} # [ONLINE, OFFLINE]; ONLINE when the argument is omitted or empty
source "$CUR_DIR/prepare_once.sh" "$MODE"

# ---- torchrun launcher settings ----
NNODES=1                # number of nodes taking part in the job
NODE_RANK=0             # rank of this node
GPUS_PER_NODE=8         # processes started on this node
MASTER_ADDR=localhost   # rendezvous address
MASTER_PORT=29600       # rendezvous port

# ---- model-parallel layout ----
TP_SIZE=8               # tensor-parallel degree
PP_SIZE=1               # pipeline-parallel degree

# ---- network shape and batching ----
NUM_ATTN_HEADS=64
HIDDEN_SIZE=8192
LAYERS=10
MAX_POSITION_EMBEDDINGS=4096
SEQ_LEN=4096
GLOBAL_BATCH_SIZE=8
MICRO_BATCH_SIZE=8

# ---- run length ----
# The first positional argument sets the number of training iterations;
# 50 is used when it is not supplied.
TRAIN_ITERS="${1-50}"

# Arguments handed to torchrun itself (process count, node topology and the
# rendezvous endpoint). Built one flag at a time; the string is word-split
# when the final command line is assembled.
DDP_OPTIONS="--nproc_per_node $GPUS_PER_NODE"
DDP_OPTIONS+=" --nnodes $NNODES"
DDP_OPTIONS+=" --node_rank $NODE_RANK"
DDP_OPTIONS+=" --master_addr $MASTER_ADDR"
DDP_OPTIONS+=" --master_port $MASTER_PORT"

# Parallelism flags for the training script: the `cncl` distributed backend,
# sequence parallelism, and the tensor-/pipeline-parallel degrees.
# PP_SIZE is forwarded explicitly so that changing it actually takes effect;
# it falls back to 1 (no pipeline parallelism) if the variable is not set.
DISTRIBUTED_OPTIONS="--distributed-backend cncl \
                     --sequence-parallel \
                     --tensor-model-parallel-size $TP_SIZE \
                     --pipeline-model-parallel-size ${PP_SIZE:-1}"

# Activation-recomputation flag passed to the training script.
RECOMPUTE_OPTIONS='--recompute-granularity selective'

# Model shape and batching flags. printf repeats the '%s ' format for every
# argument, producing a single space-separated string (with a trailing space
# that disappears when the string is word-split later).
printf -v MODEL_OPTIONS '%s ' \
    --num-layers "$LAYERS" \
    --hidden-size "$HIDDEN_SIZE" \
    --num-attention-heads "$NUM_ATTN_HEADS" \
    --micro-batch-size "$MICRO_BATCH_SIZE" \
    --global-batch-size "$GLOBAL_BATCH_SIZE" \
    --use-mcore-models \
    --attention-softmax-in-fp32 \
    --seq-length "$SEQ_LEN" \
    --max-position-embeddings "$MAX_POSITION_EMBEDDINGS"


# Training-loop flags, grouped by concern. The string is word-split when the
# final command line is assembled, so only the token sequence matters.

# Precision and run length.
TRAIN_OPTIONS="--bf16"
TRAIN_OPTIONS+=" --train-iters $TRAIN_ITERS"
TRAIN_OPTIONS+=" --lr-decay-iters 320000"

# Dataset location and the train/validation/test split.
TRAIN_OPTIONS+=" --data-path $LLAMA2_OPENWEBTEXT_DATASET_PATH"
TRAIN_OPTIONS+=" --split 949,50,1"

# Learning-rate schedule and regularisation.
TRAIN_OPTIONS+=" --lr 0.00015"
TRAIN_OPTIONS+=" --lr-decay-style cosine"
TRAIN_OPTIONS+=" --min-lr 1.0e-5"
TRAIN_OPTIONS+=" --weight-decay 1e-2"
TRAIN_OPTIONS+=" --clip-grad 1.0"
TRAIN_OPTIONS+=" --lr-warmup-fraction .01"

# Logging, checkpoint and evaluation cadence.
TRAIN_OPTIONS+=" --log-interval 1"
TRAIN_OPTIONS+=" --save-interval 10000"
TRAIN_OPTIONS+=" --eval-interval 1000"
TRAIN_OPTIONS+=" --eval-iters 1"

# Llama-2 style architecture flags: RMSNorm, SwiGLU feed-forward of width
# 28672, no dropout, the Llama2 tokenizer, rotary position embeddings, untied
# input/output embeddings and grouped-query attention with 8 query groups.
# printf repeats '%s ' per argument, yielding one space-separated string.
printf -v LLAMA_OPTS '%s ' \
    --normalization RMSNorm \
    --ffn-hidden-size 28672 \
    --attention-dropout 0 \
    --hidden-dropout 0 \
    --swiglu \
    --tokenizer-type Llama2Tokenizer \
    --tokenizer-model "$LLAMA2_TOKENIZER_MODEL" \
    --use-rotary-position-embeddings \
    --untie-embeddings-and-output-weights \
    --group-query-attention \
    --num-query-groups 8

# Optional features: flash attention and the distributed optimizer.
printf -v OPT_ARGS '%s ' \
    --use-flash-attn \
    --use-distributed-optimizer


# Assemble the complete launch line: torchrun with its launcher flags, then
# the training script followed by every option group defined above.
cmd="torchrun $DDP_OPTIONS \
    pretrain_gpt.py \
    $OPT_ARGS \
    $LLAMA_OPTS \
    $DISTRIBUTED_OPTIONS \
    $MODEL_OPTIONS \
    $TRAIN_OPTIONS \
    $RECOMPUTE_OPTIONS"

# tee cannot open the log file if the directory is missing, so create it first.
mkdir -p -- "$CUR_DIR/logs"

# Show the command before running it. Left unquoted on purpose so the runs of
# indentation whitespace inside the option strings collapse to single spaces.
# shellcheck disable=SC2086
echo $cmd

# Without pipefail the pipeline's exit status is tee's, so a crashed training
# run would still make this script exit 0 despite `set -e`.
set -o pipefail

# Run the command, mirroring stdout and stderr to a timestamped log file.
eval "$cmd" 2>&1 | tee "$CUR_DIR/logs/log-llama2-GQA-$DATESTR.txt"
