import argparse
import os
import random
import shlex
import shutil
import string
import subprocess

import numpy as np
import Bio

def get_args(argv=None):
    """Parse the command-line arguments of the refinement driver.

    Args:
        argv: optional list of argument strings; when None (the default)
            argparse reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with attributes MAP, MODEL, Method, Reso,
        OutPath, Ncpu, Gscale and RO_PATH.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("MAP", type=str, help="EM map")
    parser.add_argument("MODEL", type=str, help="Model(PDB format)")
    # Restrict to the methods main() actually implements; anything else
    # used to be accepted and then silently ignored.
    parser.add_argument('--Method', type=str, dest='Method', default='MDFF',
                        choices=['MDFF', 'relax'], help='Refinement Method')
    # EM map resolutions are fractional (e.g. 3.5 A); type=int rejected them.
    parser.add_argument('--Resolution', type=float, dest='Reso', default=5.0, help='Map Resolution')
    parser.add_argument('--OutPath', type=str, dest='OutPath', default='./Refined_data', help='Out put path')
    parser.add_argument('--Ncpu', type=int, dest='Ncpu', default=8, help='Number of CPU cores')
    parser.add_argument('--Gscale', type=float, dest='Gscale', default=0.5, help='Gscale value for MDFF')
    parser.add_argument('--RosettaPath', type=str, dest='RO_PATH', default='~/bin/',
                        help='Rosetta root directory (relax.sh uses its source/bin and database subdirectories)')
    return parser.parse_args(argv)

def main():
    """Command-line entry point.

    Validates the input files, the chosen method and (for ``relax``) the
    Rosetta install, then creates the output directory and prepares either
    an MDFF run (PSF, MDFF setup, printed NAMD command) or a Rosetta relax
    run (XML + shell script).

    Returns:
        0 on every path; errors are reported only via the ``##`` messages.
    """
    args = get_args()
    MAP = args.MAP
    MODEL = args.MODEL
    OutPath = args.OutPath
    # argparse does not expand '~' and the default is '~/bin/'; without
    # this, os.path.isdir() below would always reject the default.
    RO_PATH = os.path.expanduser(args.RO_PATH)

    if not os.path.isfile(MAP):
        print('##Missing MAP file:', MAP)
        return 0
    if not os.path.isfile(MODEL):
        print('##Missing MODEL file:', MODEL)
        return 0
    if os.path.isdir(OutPath):
        print('##Please remove dir:', OutPath)
        return 0
    if args.Method not in ('MDFF', 'relax'):
        print('##Unknown refinement method:', args.Method)
        return 0
    # Validate the Rosetta path *before* creating OutPath; otherwise a bad
    # path leaves an empty OutPath behind that blocks the next attempt.
    if args.Method == 'relax' and not os.path.isdir(RO_PATH):
        print('##Can not find ', RO_PATH)
        print('##Please specify ROSETTA PATH by --RosettaPath')
        return 0

    os.makedirs(OutPath, exist_ok=True)

    if args.Method == 'MDFF':
        print('##Start MDFF')
        MakePsf(MODEL, OutPath)
        MDFF_SetUp(MODEL, MAP, OutPath, args.Gscale)
        NAMD_Run(MODEL, OutPath, args.Ncpu)
    else:
        RosettaRelax(MODEL, MAP, OutPath, args.Reso, RO_PATH)
    return 0
  


def MakePsf(file,path):
    """Generate a PSF for the model with VMD's autopsf plugin.

    The model is copied into *path* under its own file name (skipped when a
    file of that name already exists there). A VMD Tcl script
    ``make_psf.txt`` is written that cd's into *path*, loads the copy and
    runs ``autopsf -mol 0``. VMD is then launched in text mode on that
    script. VMD's stdout is captured and discarded, and nothing is returned.

    Args:
        file: path to the input model file.
        path: working directory; assumed to exist already.
    """
    model_name = file.rsplit('/', 1)[-1]
    local_copy = f'{path}/{model_name}'
    if not os.path.isfile(local_copy):
        subprocess.run(['cp', file, local_copy], stdout=subprocess.PIPE, encoding='utf-8')

    script_path = f'{path}/make_psf.txt'
    tcl_commands = [
        f'cd {path}',
        'package require autopsf',
        f'mol new {model_name}',
        'autopsf -mol 0',
        'exit',
    ]
    with open(script_path, "w") as handle:
        handle.write('\n'.join(tcl_commands) + '\n')

    vmd_call = ['vmd', '-dispdev', 'text', '-e', script_path]
    subprocess.run(vmd_call, stdout=subprocess.PIPE, encoding='utf-8')

def MDFF_SetUp(pdb_file,map_file,path,gscale):
    """Write and run the VMD script that sets up the MDFF simulation.

    The map is copied to ``<path>/MAP.mrc`` (skipped if already present).
    A Tcl script ``run.txt`` is written and run with VMD in text mode. The
    script does the following:

    * converts the map to ``MAP.dx`` with ``mdff griddx``;
    * loads ``<stem>_autopsf.psf`` / ``<stem>_autopsf.pdb`` (presumably
      produced by the earlier autopsf step -- confirm);
    * writes the grid PDB and the cis-peptide / chirality restraint files;
    * calls ``mdff setup`` with output prefix ``<stem>``, using *gscale*,
      10000 minimization steps and 1000000 MD steps.

    ``<stem>`` is the model's file name truncated at its first ``.pdb``.
    VMD's stdout is captured and discarded, and nothing is returned.
    """
    stem = pdb_file.rsplit('/', 1)[-1].partition('.pdb')[0]

    map_copy = f'{path}/MAP.mrc'
    if not os.path.isfile(map_copy):
        subprocess.run(['cp', map_file, map_copy], stdout=subprocess.PIPE, encoding='utf-8')

    tcl_commands = [
        f'cd {path}',
        'package require mdff',
        'mdff griddx -i MAP.mrc -o MAP.dx',
        'package require ssrestraints',
        'package require cispeptide',
        'package require chirality',
        f'mol new {stem}_autopsf.psf',
        f'mol addfile {stem}_autopsf.pdb',
        f'mdff gridpdb -psf {stem}_autopsf.psf -pdb {stem}_autopsf.pdb -o {stem}-grid.pdb',
        f'cispeptide restrain -o {stem}-cispeptide.txt',
        f'chirality restrain -o {stem}-chirality.txt',
        (f'mdff setup -o {stem} -psf {stem}_autopsf.psf -pdb {stem}_autopsf.pdb'
         f' -griddx MAP.dx -gridpdb {stem}-grid.pdb'
         f' -extrab {{{stem}-cispeptide.txt {stem}-chirality.txt}}'
         f' -gscale {gscale:f} -minsteps 10000 -numsteps 1000000'),
        'exit',
    ]
    script_path = f'{path}/run.txt'
    # No trailing newline after the final 'exit'.
    with open(script_path, "w") as handle:
        handle.write('\n'.join(tcl_commands))

    subprocess.run(['vmd', '-dispdev', 'text', '-e', script_path], stdout=subprocess.PIPE, encoding='utf-8')
    
def NAMD_Run(pdb_file,path,ncpu):
    """Print the NAMD command for the first MDFF step (does not run it).

    Looks for ``<path>/<stem>-step1.namd``, where ``<stem>`` is the model's
    file name truncated at its first ``.pdb``.

    Returns:
        0 (after printing an error line) when that file is missing;
        otherwise None, after printing the cd + charmrun instructions.
    """
    stem = pdb_file.rsplit('/', 1)[-1].partition('.pdb')[0]
    config = f'{path}/{stem}-step1.namd'
    if not os.path.isfile(config):
        print('Error in ', config)
        return 0

    instructions = (
        '#Please execute the following command!!',
        f'cd  {path}',
        f'./charmrun ++local +p{ncpu:d} namd2 {stem}-step1.namd',
    )
    print('\n'.join(instructions))
    
    
def RosettaRelax(pdb_file,map_file,path,reso,ROSETTA_PATH):
    """Prepare (but do not run) a density-guided Rosetta FastRelax job.

    Copies the model into *path* under its own file name and the map to
    ``<path>/MAP.mrc``; each copy is skipped when the target already exists.
    Writes the RosettaScripts protocol ``relax.xml`` and a driver script
    ``relax.sh``, then prints the commands the user should run.

    Args:
        pdb_file: input model file.
        map_file: EM density map; copied to MAP.mrc, the name relax.xml loads.
        path: existing output directory.
        reso: map resolution, passed to ``-edensity::mapreso``.
        ROSETTA_PATH: Rosetta root; relax.sh uses its ``source/bin`` and
            ``database`` subdirectories. A leading ``~`` is expanded.

    Raises:
        OSError: if copying the model or map fails.
    """
    base_file = os.path.basename(pdb_file)

    # Copy inputs next to the generated scripts (relax.sh runs from *path*).
    for src, dst in ((pdb_file, os.path.join(path, base_file)),
                     (map_file, os.path.join(path, 'MAP.mrc'))):
        if not os.path.isfile(dst):
            shutil.copy(src, dst)

    XML='''
    <ROSETTASCRIPTS>
   <SCOREFXNS>
      <ScoreFunction name="dens" weights="beta_cart">
         <Reweight scoretype="elec_dens_fast" weight="35.0"/>
         <Set scale_sc_dens_byres="R:0.76,K:0.76,E:0.76,D:0.76,M:0.76,C:0.81,Q:0.81,H:0.81,N:0.81,T:0.81,S:0.81,Y:0.88,W:0.88,A:0.88,F:0.88,P:0.88,I:0.88,L:0.88,V:0.88"/>
      </ScoreFunction>
   </SCOREFXNS>

   <MOVERS>
       <SetupForDensityScoring name="setupdens"/>
       <LoadDensityMap name="loaddens" mapfile="MAP.mrc"/>
       <FastRelax name="relaxcart" scorefxn="dens" repeats="2" cartesian="1"/>
   </MOVERS>

   <PROTOCOLS>
      <Add mover="setupdens"/>
      <Add mover="loaddens"/>
      <Add mover="relaxcart"/>
   </PROTOCOLS>
   <OUTPUT scorefxn="dens"/>

</ROSETTASCRIPTS>
    '''

    with open(os.path.join(path, 'relax.xml'), "w") as out:
        out.write(XML)

    # Expand '~' here (it would not be expanded inside quotes) and
    # shell-quote user-supplied values so paths with spaces survive.
    # Every line continuation must be a literal '\\' in this non-raw string;
    # a bare '\' + newline is swallowed by Python and joins shell lines.
    BASH='''#!/bin/bash

ROSETTA3={root}

"$ROSETTA3"/source/bin/rosetta_scripts.static.linuxgccrelease \\
 -database "$ROSETTA3"/database/ \\
 -in::file::s {model}  \\
 -parser::protocol relax.xml \\
 -ignore_unrecognized_res \\
 -edensity::mapreso {reso:f} \\
 -edensity::cryoem_scatterers \\
 -crystal_refine \\
 -beta \\
 -out::suffix _relax \\
 -default_max_cycles 200
'''.format(root=shlex.quote(os.path.expanduser(ROSETTA_PATH)),
           model=shlex.quote(base_file),
           reso=reso)

    with open(os.path.join(path, 'relax.sh'), "w") as out:
        out.write(BASH)

    print('#Please execute the following command!!')
    print('cd ',path)
    print('bash ./relax.sh')

    
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()