#!/bin/bash
# This example script is contributed by external user https://github.com/LydiaXiaohongLi
#
# The shebang has to be the very first line of the file; on any later line it
# is just a comment and the interpreter is chosen by whoever invokes the file.

# Abort on the first command that exits with a non-zero status.
set -e
 
# Timestamp (month-day-hour-minute) used to make the log file name unique.
DATESTR=$(date +"%m-%d-%H-%M")

# Absolute path of the directory that contains this script. Everything is
# quoted so a checkout path with spaces works, and `&&` keeps `pwd` from
# reporting the caller's directory when the `cd` fails.
CUR_DIR=$(cd "$(dirname "$0")" && pwd)

#set dataset
# $2 selects the dataset mode handed to prepare_once.sh; defaults to ONLINE.
MODE=${2-"ONLINE"} # [ONLINE, OFFLINE]
# Sourced (not executed) so variables it defines stay visible in this shell.
source "$CUR_DIR/prepare_once.sh" "$MODE"

######################################
# Change the below configurations here

# Tensor- and pipeline-model-parallel sizes.
TP=1
PP=1

# torchrun launcher settings (single node by default).
GPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6899
NNODES=1
# Node rank is taken from the environment (default 0) rather than from $1:
# $1 already selects TRAIN_ITERS below, so reading it here as well would turn
# an iteration count such as 50 into an out-of-range rank for NNODES=1.
NODE_RANK=${NODE_RANK:-0}

# Model shape.
HIDDEN_SIZE=4096
NUM_LAYERS=30
NUM_HEADS=32
SEQ_LENGTH=2048

# Batch sizes.
MICRO_BATCH_SIZE=2
GLOBAL_BATCH_SIZE=16
#train options
# $1 overrides the number of training iterations; defaults to 50.
TRAIN_ITERS=${1-50}
LR=2e-5
MIN_LR=3e-6
LR_WARMUP_STEPS=2
WEIGHT_DECAY=0.0001
GRAD_CLIP=1
 
# Option groups are plain strings that get spliced into the launch command.
# The command is word-split when it is echoed and evaluated, so only the
# sequence of tokens matters, not the amount of whitespace between them.

# BLOOM-specific flag: ALiBi position embeddings.
BLOOM_OPTIONS="--position-embedding-type alibi"

# Model shape and batch sizes.
MODEL_OPTIONS="--num-layers $NUM_LAYERS"
MODEL_OPTIONS+=" --hidden-size $HIDDEN_SIZE"
MODEL_OPTIONS+=" --num-attention-heads $NUM_HEADS"
MODEL_OPTIONS+=" --seq-length $SEQ_LENGTH"
MODEL_OPTIONS+=" --max-position-embeddings $SEQ_LENGTH"
MODEL_OPTIONS+=" --use-mcore-models"
MODEL_OPTIONS+=" --attention-softmax-in-fp32"
MODEL_OPTIONS+=" --micro-batch-size $MICRO_BATCH_SIZE"
MODEL_OPTIONS+=" --global-batch-size $GLOBAL_BATCH_SIZE"

# Optimizer and learning-rate schedule.
TRAINING_OPTIONS="--lr $LR"
TRAINING_OPTIONS+=" --lr-decay-style cosine"
TRAINING_OPTIONS+=" --min-lr $MIN_LR"
TRAINING_OPTIONS+=" --adam-beta1 0.9"
TRAINING_OPTIONS+=" --adam-beta2 0.95"
TRAINING_OPTIONS+=" --adam-eps 1e-5"
TRAINING_OPTIONS+=" --weight-decay $WEIGHT_DECAY"
TRAINING_OPTIONS+=" --clip-grad $GRAD_CLIP"
TRAINING_OPTIONS+=" --lr-warmup-iters $LR_WARMUP_STEPS"

# Arguments consumed by torchrun itself (they precede the training script).
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE"
DISTRIBUTED_ARGS+=" --nnodes $NNODES"
DISTRIBUTED_ARGS+=" --node_rank $NODE_RANK"
DISTRIBUTED_ARGS+=" --master_addr $MASTER_ADDR"
DISTRIBUTED_ARGS+=" --master_port $MASTER_PORT"

# NOTE(review): TB_ARGS is defined here but is not referenced by the launch
# command in this script - confirm whether it was meant to be included.
TB_ARGS="--transformer-impl transformer_engine"
TB_ARGS+=" --tp-comm-overlap"

# Assemble the full torchrun command line as a single string: launcher
# arguments first, then the training script and its options.
cmd="torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
        $MODEL_OPTIONS \
        $TRAINING_OPTIONS \
        $BLOOM_OPTIONS \
        --tensor-model-parallel-size $TP \
        --pipeline-model-parallel-size $PP \
        --sequence-parallel \
        --train-iters $TRAIN_ITERS \
        --data-path $LLAMA2_OPENWEBTEXT_DATASET_PATH \
        --tokenizer-type GPT2BPETokenizer \
        --vocab-file $GPT2_VOCAB_PATH \
        --merge-file $GPT2_MERGES_PATH \
        --split 949,50,1 \
        --distributed-backend cncl \
        --log-interval 1 \
        --save-interval 10000 \
        --eval-interval 1000 \
        --recompute-granularity selective \
        --eval-iters 10 \
        --bf16"

# tee cannot create missing parent directories, so make sure logs/ exists.
LOG_DIR="$CUR_DIR/logs"
mkdir -p "$LOG_DIR"

# Left unquoted on purpose: word splitting collapses the continuation-line
# padding so the command prints on one readable line.
# shellcheck disable=SC2086
echo $cmd

# Without pipefail the pipeline's status is tee's, so a failed training run
# would still make this script exit 0. Enabled only here so the sourced
# preparation script is not affected.
set -o pipefail
eval "$cmd" 2>&1 | tee "$LOG_DIR/log-bloom-7B-$DATESTR.txt"
