#!/bin/bash
#
# Wrapper: starts one NVIDIA MPS control daemon per GPU, points each MPI
# rank at the pipe directory of its assigned GPU, runs the given command,
# then shuts the daemons down.
# Usage:  mpirun ... ./this_script <program> [args...]

# Number of GPUs on the node (one MPS daemon is started per GPU).
ngpus=4

# When launched outside of Open MPI (no mpirun), the OMPI_* variables are
# unset; fall back to a single-rank configuration so the script still works.
if [ -z "${OMPI_COMM_WORLD_LOCAL_SIZE:-}" ]; then
  OMPI_COMM_WORLD_LOCAL_SIZE=1
  OMPI_COMM_WORLD_LOCAL_RANK=0
  # The original script left OMPI_COMM_WORLD_RANK unset here, which made the
  # `[ $OMPI_COMM_WORLD_RANK = 0 ]` tests below a syntax error without mpirun.
  OMPI_COMM_WORLD_RANK=0
fi

#---------------------------------------------
# start MPS: one control daemon per GPU, each with its own private pipe and
# log directory under /dev/shm/$USER.  Only the first rank on the node
# (local rank 0) starts the daemons; the other local ranks share them.
#---------------------------------------------
if [ "${OMPI_COMM_WORLD_LOCAL_RANK:-0}" -eq 0 ]; then
  if [ "${OMPI_COMM_WORLD_RANK:-0}" -eq 0 ]; then
    echo starting mps ...
  fi
  for ((i = 0; i < ngpus; i++)); do
    # ${USER:?} aborts instead of running `rm -rf /dev/shm//…` if USER is unset.
    rm -rf "/dev/shm/${USER:?}/mps_$i"
    rm -rf "/dev/shm/${USER:?}/mps_log_$i"
    mkdir -p "/dev/shm/${USER}/mps_$i"
    mkdir -p "/dev/shm/${USER}/mps_log_$i"
    # Bind this daemon to GPU $i; clients later select a daemon (and hence a
    # GPU) purely through its pipe directory.
    export CUDA_VISIBLE_DEVICES=$i
    export CUDA_MPS_PIPE_DIRECTORY="/dev/shm/${USER}/mps_$i"
    export CUDA_MPS_LOG_DIRECTORY="/dev/shm/${USER}/mps_log_$i"
    /usr/bin/nvidia-cuda-mps-control -d
    sleep 1   # give the daemon a moment to create its pipe
  done
fi

#---------------------------------------------
# set CUDA_MPS_PIPE_DIRECTORY per MPI rank
#---------------------------------------------
# Map this rank to a GPU: device = floor(ngpus * local_rank / local_size),
# which spreads the node-local ranks evenly across the GPUs.  Defaults guard
# against unset variables (e.g. when launched without mpirun); the original
# `let` chain was a syntax error in that case.  A plain assignment replaces
# `printf -v` so ${USER} can never be misread as a printf format directive.
mydevice=$(( ${ngpus:-4} * ${OMPI_COMM_WORLD_LOCAL_RANK:-0} / ${OMPI_COMM_WORLD_LOCAL_SIZE:-1} ))
myfile="/dev/shm/${USER}/mps_${mydevice}"

# Talk to the daemon that owns this GPU via its pipe directory, and make sure
# the CUDA_VISIBLE_DEVICES value exported by the start-up loop above does not
# leak into the wrapped program.
export CUDA_MPS_PIPE_DIRECTORY=$myfile
unset CUDA_VISIBLE_DEVICES

#---------------------------------------------
# run the program
#---------------------------------------------
"$@"
rc=$?   # remember the wrapped program's exit status; it is the script's result

#---------------------------------------------
# stop MPS (first rank on the node only, mirroring the start-up loop)
#---------------------------------------------
if [ "${OMPI_COMM_WORLD_LOCAL_RANK:-0}" -eq 0 ]; then
  if [ "${OMPI_COMM_WORLD_RANK:-0}" -eq 0 ]; then
    echo stopping mps ...
  fi
  for ((i = 0; i < ngpus; i++)); do
    # Each daemon is addressed through its own pipe directory.
    export CUDA_MPS_PIPE_DIRECTORY="/dev/shm/${USER}/mps_$i"
    echo "quit" | /usr/bin/nvidia-cuda-mps-control
    sleep 1   # let the daemon shut down before removing its directories
    # ${USER:?} aborts instead of running `rm -rf /dev/shm//…` if USER is unset.
    rm -rf "/dev/shm/${USER:?}/mps_$i"
    rm -rf "/dev/shm/${USER:?}/mps_log_$i"
  done
fi

# Propagate the wrapped program's exit status; the original script instead
# returned the status of the last clean-up command.
exit "$rc"
