#!/usr/bin/python
import numpy as np
import argparse
# Command-line interface
parser = argparse.ArgumentParser()

# Required positional arguments, registered in the order they must be given
for name, kind, text in [
        ('L', int, 'length input for N=2(l^3) atom configuration'),
        ('T', float, 'temperature of run [K]'),
        ('MCS', int, 'total Monte Carlo steps per site to perform')]:
    parser.add_argument(name, type=kind, help=text)

# Optional arguments: input files
parser.add_argument(
    '--energies',
    metavar='EFILE',
    default='../MoNbTaW_tetEs',
    # The run of spaces below is part of the existing help text; kept verbatim.
    help=('file containing lookup table of energies for all'
          '        256 decorations of tetrahedra; default: `%(default)s`'))
parser.add_argument(
    '--infile',
    metavar='INFILE',
    help=('input configuration file of atoms, in proper format; '
          'if none, random equiatomic configuration is generated'))

# Optional arguments: output files and directories
parser.add_argument(
    '--outfile',
    metavar='OUTFILE',
    default='MCout',
    help='name of output configuration file; default: `%(default)s`')
parser.add_argument(
    '-w',
    action='store_true',
    help="write to a histogram file corresponding to T, e.g. 'Histo-E-p_T3000.dat'")
parser.add_argument(
    '--histdir',
    metavar='HISTDIR',
    default='../histogram_data',
    help='directory for histogram data files; default: `%(default)s`')
parser.add_argument(
    '--corrdir',
    metavar='CORRDIR',
    default='../correlations_data',
    help='directory for tetrahedra correlation data files; default: `%(default)s`')

# Parse sys.argv once; every later section reads its settings from `args`
args = parser.parse_args()

kB = 8.617333262e-5 # Boltzmann constant (eV / K)
# The eight (+-1,+-1,+-1) displacements; the main loop draws one of these to
# pick the swap partner of a randomly chosen site.
NNdirs = [(-1,1,1),(-1,-1,1),(-1,-1,-1),(-1,1,-1),(1,1,1),(1,-1,1),(1,-1,-1),(1,1,-1)]

# Fail early with a usage message instead of a ZeroDivisionError (T == 0),
# a negative beta (T < 0), or an empty lattice (L <= 0) further down.
if args.L < 1:
    parser.error('L must be a positive integer')
if not args.T > 0:
    parser.error('T must be a positive temperature in K')

L = 2*args.L            # edge length of the index grid holding the sites
N = 2*(args.L**3)       # number of atoms (occupied grid points)
beta = 1./(kB * args.T) # inverse temperature (1 / eV)
Steps = N * args.MCS    # total number of swap attempts

# Occupied grid points: those whose three indices are all even or all odd,
# listed in lexicographic (l, r, c) order.
allCoords = [
    (l, r, c)
    for l in range(L)
    for r in range(L)
    for c in range(L)
    if l % 2 == r % 2 == c % 2
]

# J maps a tetrahedron decoration (a, b, c, d) to its tabulated energy
# divided by 24.  Each non-blank line of the energies file holds four integer
# species labels followed by one float energy.
# NOTE(review): the 1/24 factor is taken over from the original code; its
# physical meaning is not documented here - confirm before changing it.
J = {}
with open(args.energies, 'r') as TetrahedronEnergies:
    for line in TetrahedronEnergies:
        fields = line.split()
        if not fields:
            # Tolerate blank lines (e.g. a trailing newline at end of file)
            continue
        decoration = (int(fields[0]), int(fields[1]), int(fields[2]), int(fields[3]))
        J[decoration] = float(fields[4])/24.0

def getTetrahedra():
    """Return the set of all tetrahedra of the periodic lattice.

    Every tetrahedron is a 4-tuple of (l, r, c) sites in which position k
    holds the site assigned to sublattice k.  The origin site takes
    sublattice (l+r+c) % 4; the three displaced sites take the origin's
    sublattice plus 2, 3 and 1 (mod 4) respectively.  Coordinates wrap
    modulo the global grid size L.  Duplicates generated from different
    origin sites collapse because the result is a set.
    """
    # Each row: displacements to the (+2), (+3) and (+1) sublattice sites.
    displacements = (
        ((2,0,0), (1,1,1), (1,1,-1)),
        ((2,0,0), (1,1,1), (1,-1,1)),
        ((0,2,0), (1,1,1), (1,1,-1)),
        ((0,2,0), (1,1,1), (-1,1,1)),
        ((0,0,2), (1,1,1), (-1,1,1)),
        ((0,0,2), (1,1,1), (1,-1,1)),
    )

    def shifted(origin, delta):
        # Periodic translation of `origin` by `delta`
        return ((origin[0] + delta[0]) % L,
                (origin[1] + delta[1]) % L,
                (origin[2] + delta[2]) % L)

    found = set()
    for (l0, r0, c0) in allCoords:
        origin = (l0, r0, c0)
        base = (l0 + r0 + c0) % 4
        for plus2, plus3, plus1 in displacements:
            slots = [0, 0, 0, 0]
            slots[base] = origin
            slots[(base + 2) % 4] = shifted(origin, plus2)
            slots[(base + 3) % 4] = shifted(origin, plus3)
            slots[(base + 1) % 4] = shifted(origin, plus1)
            found.add(tuple(slots))
    return found

# Build the tetrahedron set once at start-up; Ham, getZ and the per-site
# Tetrahedra lookup all iterate over it.
allTetrahedra=getTetrahedra()

# Initial configuration s0: an (L, L, L) integer array.  It is read from
# --infile when given; otherwise a random equiatomic configuration is built
# with species labels 1..4 on the occupied sites and 0 everywhere else.
if args.infile is not None:
    s0 = np.loadtxt(args.infile, dtype=int).reshape((L,L,L))
else:
    if N % 4 != 0:
        # An equal share of four species needs N divisible by 4; without this
        # check the fill loop below would run out of labels with an IndexError.
        raise ValueError('random equiatomic configuration needs N divisible '
                         'by 4, got N=%d' % N)
    s0 = np.zeros((L,L,L),dtype=int)
    # Floor division keeps the repeat count an int on Python 2 and Python 3
    pool = (N // 4) * [1,2,3,4]
    species = np.random.permutation(pool)
    for n in range(N):
        (l,r,c) = allCoords[n]
        s0[l,r,c] = species[n]
# Working copy that the Monte Carlo loop mutates; s0 keeps the starting state
sig = np.copy(s0)

# Tetrahedra : dictionary whose keys are the lattice sites and whose values
# are frozensets of all tetrahedra containing that site.
#
# Built in a single pass over allTetrahedra instead of rescanning the whole
# set once per site.  Each per-site list is filled in the iteration order of
# allTetrahedra, so the resulting frozensets match a per-site scan.
tetsBySite = {}
for site in allCoords:
    tetsBySite[site] = []
for tet in allTetrahedra:
    for site in tet:
        if site in tetsBySite:
            tetsBySite[site].append(tet)

Tetrahedra = {}
for site in allCoords:
    Tetrahedra[site] = frozenset(tetsBySite[site])

def Ham(sig):
    """Return the total energy of configuration `sig`.

    Adds up J[...] over every tetrahedron in allTetrahedra.  The lookup key
    takes the species at tetrahedron positions 0, 2, 1, 3, in that order.
    """
    total = 0.
    for (site0, site1, site2, site3) in allTetrahedra:
        key = (sig[site0], sig[site2], sig[site1], sig[site3])
        total += J[key]
    return total

def newDeco(sig,tet,I0,I1):
    """Return the decoration of `tet` as it would be after swapping I0 and I1.

    `sig` is not modified.  The tetrahedron's sites are first placed in slots
    by (l+r+c) % 4.  The swap is then emulated on the slot list: if both
    sites belong to the tetrahedron their slots are exchanged, otherwise the
    one that belongs is replaced by the other.  The result is the species
    read from `sig` at slots 0, 2, 1, 3 (the key order used for J).

    Raises ValueError if neither I0 nor I1 is a site of `tet`.
    """
    slots = [0] * 4
    for (x, y, z) in tet:
        slots[(x + y + z) % 4] = (x, y, z)

    hasFirst = I0 in slots
    hasSecond = I1 in slots
    if hasFirst and hasSecond:
        first = slots.index(I0)
        second = slots.index(I1)
        slots[first], slots[second] = slots[second], slots[first]
    elif hasFirst:
        slots[slots.index(I0)] = I1
    else:
        slots[slots.index(I1)] = I0

    return tuple(sig[slots[k]] for k in (0, 2, 1, 3))

def dE(sig, site0, site1):
    """Return the energy change of swapping the species at two sites.

    Parameters:
        sig   -- configuration, indexable as sig[l, r, c]
        site0 -- (l0, r0, c0) coordinates of the first site
        site1 -- (l1, r1, c1) coordinates of the second site

    Returns 0. immediately when both sites hold the same species.  Otherwise
    only the tetrahedra containing at least one of the two sites are summed,
    before and after the swap; `sig` itself is left untouched.

    The two sites are plain parameters unpacked in the body, since tuple
    parameters in the signature are not valid Python 3 syntax (PEP 3113).
    Positional calls such as dE(sig, (l0,r0,c0), (l1,r1,c1)) work unchanged.
    """
    (l0, r0, c0) = site0
    (l1, r1, c1) = site1
    sig0 = sig[l0,r0,c0]
    sig1 = sig[l1,r1,c1]
    if sig0 == sig1:
        return 0.
    E_initial = 0.
    E_final = 0.
    # UNION between the two swap sites' tetrahedra: shared ones count once
    changedTetrahedra = Tetrahedra[(l0,r0,c0)].union(Tetrahedra[(l1,r1,c1)])
    for tet0 in changedTetrahedra:
        E_initial += J[(sig[tet0[0]],sig[tet0[2]],sig[tet0[1]],sig[tet0[3]])]
        E_final += J[newDeco(sig,tet0,(l0,r0,c0),(l1,r1,c1))]
    return E_final - E_initial

def getNeven(sig):
    """Count each species on the sites whose coordinates are all even.

    Returns a length-4 float array; entry k is the number of all-even sites
    in allCoords holding species label k+1.
    """
    tally = np.zeros(4)
    for (x, y, z) in allCoords: # N loops
        if x % 2 or y % 2 or z % 2:
            continue  # at least one odd coordinate: not an all-even site
        tally[sig[x, y, z] - 1] += 1.
    return tally

def getXs(sig):
    """Return a 4x4 array of species counts per sublattice, scaled by 4/N.

    Rows are indexed by sublattice (l+r+c) % 4 and columns by species
    label minus one.
    """
    counts = np.zeros((4,4))
    for site in allCoords:
        (x, y, z) = site
        sublattice = (x + y + z) % 4
        counts[sublattice][sig[x, y, z] - 1] += 1.
    return counts * (4./N)

def getZ(sig,SLa,SLb,SLg,SLd):
    """Return a (4,4,4,4) histogram of tetrahedron decorations.

    For every tetrahedron the species (label minus one) found at tetrahedron
    positions SLa, SLb, SLg and SLd index one cell, which is incremented.
    The result holds raw counts; no normalization is applied.
    """
    histogram = np.zeros((4,4,4,4))
    positions = (SLa, SLb, SLg, SLd)
    for tet in allTetrahedra:
        cell = tuple(sig[tet[k]] - 1 for k in positions)
        histogram[cell] += 1.
    return histogram

E0 = Ham(sig)  # energy of the starting configuration
E = E0         # running energy, updated on every accepted swap

# Main loop: Steps swap attempts between a random site and one of its eight
# (+-1,+-1,+-1) neighbours.  When Steps <= 0 the range is empty and the loop
# body never runs, so no explicit guard is needed.
for step in range(1,Steps+1):
    d=np.random.choice([0,1]) # Randomly select even or odd
    l0=2*np.random.randint(args.L)+d # Random l
    r0=2*np.random.randint(args.L)+d # Random r
    c0=2*np.random.randint(args.L)+d # Random c
    # Randomly select adjacent swap site:
    (dl,dr,dc)=NNdirs[np.random.randint(8)]
    # Coordinate of the swap site is (l1,r1,c1), wrapped periodically:
    (l1,r1,c1)=((l0+dl)%L, (r0+dr)%L, (c0+dc)%L)
    dE_01=dE(sig,(l0,r0,c0),(l1,r1,c1))
    # Since beta is required to be positive,
    # dE <= 0 ==> Guaranteed swap
    # dE > 0  ==> Swap probability given by exp(-beta*dE)
    Boltz=1. if dE_01 <= 0 else np.exp(-beta*dE_01)
    # A random number is drawn on every step, even when Boltz is 1
    r=np.random.random()
    if (r < Boltz):
        # Identical species: swapping would change nothing, so skip it
        if sig[l0,r0,c0] != sig[l1,r1,c1]:
            sig[l0,r0,c0],sig[l1,r1,c1]=sig[l1,r1,c1],sig[l0,r0,c0]
            E+=dE_01

# Scalar p: 4/N times the absolute difference between the counts of
# species 3 and species 1 on the all-even sites.
Neven = getNeven(sig)
evenImbalance = abs(Neven[2] - Neven[0])
p = 4.0*float(evenImbalance)/N

# Assign sublattices according to concentration of Mo (alpha highest, beta NNNs),
# then others on Ta (gamma highest, delta NNNs).
Xs = getXs(sig)
# alpha: sublattice with the largest species-1 fraction (ties go to the
# larger index); beta is the sublattice two steps away.
SLa = max((Xs[k, 0], k) for k in range(4))[1]
SLb = (SLa + 2) % 4
# gamma: whichever of the two remaining sublattices has the larger
# species-3 fraction; delta is the sublattice two steps away from gamma.
remaining = [(SLa + 1) % 4, (SLa + 3) % 4]
SLg = max((Xs[k, 2], k) for k in remaining)[1]
SLd = (SLg + 2) % 4

# Decoration histogram taken in (alpha, beta, gamma, delta) order
ZABGD = getZ(sig, SLa, SLb, SLg, SLd)

# Write the final configuration: one text block per 2-D plane along the first
# axis of sig, each followed by a blank line.  The `with` statement closes the
# file; the loop variable is named so that it does not shadow builtin `slice`.
with open(args.outfile, 'w') as OutFile:
    for plane in sig:
        np.savetxt(OutFile, plane, fmt='%1d')
        OutFile.write('\n')

# Write data to files (only with -w).  Both files are opened in append mode,
# so each run adds one line; the `with` statements close them.
if args.w:
    # One line per run: final energy E and the value p
    with open(args.histdir+'/Histo-E-p_T%04d.dat' % args.T, 'a') as HistoFile:
        HistoFile.write('%12.12f %1.12f\n' % (E,p))
    # One line per run: every entry of the flattened ZABGD array as an
    # integer, space separated.  Joining over the array keeps the field count
    # tied to the array size instead of a hand-typed list of %d fields.
    with open(args.corrdir+'/ZABGD_T%04d' % args.T, 'a') as CorrsFile:
        CorrsFile.write(' '.join('%d' % count for count in ZABGD.flatten()) + '\n')
