
#    This file is part of InEXEQ.
#
#    InEXEQ is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <https://www.gnu.org/licenses/>.
# 
# Copyright (C) Maciej Bartkowiak, 2014-2018, 2020

# Module docstring (kept verbatim as a changelog). The ASE remark in the
# first paragraph is historical: the second paragraph records that the ASE
# dependency was removed; only numpy and the local `geometry` helpers are
# imported below.
__doc__ = """
This module defines a crystal sample and its orientation
in the context of the instrument. It then produces the
reciprocal-space lattice.
Depends on ASE at the moment, although this should be
fairly easy to fix.
Maciej Bartkowiak, 28 Mar 2014

No more ASE dependency. M.B. 21 Aug 2014
"""

import numpy as np
# import scipy.interpolate as scint
from geometry import *

class Sample:
    """
    Crystal sample described by its direct-lattice constants and angles.

    On construction the metric tensor G, the reciprocal metric tensor
    G* = inv(G), the reciprocal lattice parameters and the Busing-Levy B
    matrix (columns = reciprocal basis vectors in a Cartesian frame) are
    computed. Reciprocal quantities carry no 2*pi factor.

    `orient` then builds the orientation matrix U and the combined matrix
    `themat` = U.B, which maps Miller indices (hkl) onto the lab frame
    defined in `makeU`.
    """
    def __init__(self, constants = np.array(3*[4.0]), angles = np.array(3*[90.0])):
        """
        constants -- lattice constants a, b, c.
        angles    -- lattice angles alpha, beta, gamma in degrees
                     (alpha between b and c, beta between a and c,
                     gamma between a and b, as used in makeG).
        """
        # Copy, so the shared default array (or a caller-owned array) is never
        # aliased by self.abc and later in-place edits cannot leak across samples.
        self.abc = np.array(constants)
        self.angles = np.radians(angles)  # stored internally in radians
        self.atoms = []
        self.G = []
        self.Gstar = []
        self.B = []
        self.Binv = []
        self.U = []
        self.Uinv = []
        self.repang = None  # reciprocal angles alpha*, beta*, gamma* (radians)
        self.repvec = None  # reciprocal lengths a*, b*, c*
        self.u, self.v, self.w = None, None, None
        self.rotmat = None
        self.step = None
        self.reciprocal_vectors()
        self.themat = []

    def makeG(self):
        """Build the direct-lattice metric tensor G (G[i, j] = a_i . a_j)."""
        a = np.asarray(self.abc, dtype=float)
        cos_alpha, cos_beta, cos_gamma = np.cos(self.angles)
        cosines = np.array([[1.0, cos_gamma, cos_beta],
                            [cos_gamma, 1.0, cos_alpha],
                            [cos_beta, cos_alpha, 1.0]])
        self.G = np.outer(a, a) * cosines

    def makeGstar(self):
        """Build the reciprocal metric tensor G* as the inverse of G."""
        self.Gstar = np.linalg.inv(self.G)

    def makeRepvec(self):
        """Derive reciprocal lengths (repvec) and angles (repang, radians) from G*."""
        gs = self.Gstar
        vecs = np.sqrt(np.diag(gs))

        def _angle(i, j):
            # Clip so floating-point rounding cannot push the cosine past +-1
            # (which would make arccos return NaN).
            return np.arccos(np.clip(gs[i, j] / vecs[i] / vecs[j], -1.0, 1.0))

        self.repvec = vecs
        self.repang = np.array([_angle(1, 2), _angle(0, 2), _angle(0, 1)])

    def reciprocal_vectors(self):
        """Recompute G, G*, the reciprocal parameters and B from abc/angles."""
        self.makeG()
        self.makeGstar()
        self.makeRepvec()
        self.makeB()

    def makeB(self):
        """
        Build the Busing-Levy B matrix and its inverse.

        B maps hkl onto a Cartesian frame with x along a*, y in the a*-b*
        plane and z along the direct c axis; its columns are a*, b*, c*.
        """
        bstar = self.repvec
        alpha_s, beta_s, gamma_s = self.repang
        # B[1, 2] needs the DIRECT-lattice alpha, not alpha*. Using alpha*
        # is only harmless when alpha = alpha* = 90 deg; for triclinic cells it
        # breaks B.T @ B == G*.
        alpha = self.angles[0]
        b = np.zeros((3,3))
        b[0,0] = bstar[0]
        b[0,1] = bstar[1] * np.cos(gamma_s)
        b[0,2] = bstar[2] * np.cos(beta_s)
        b[1,1] = bstar[1] * np.sin(gamma_s)
        b[1,2] = - bstar[2] * np.sin(beta_s) * np.cos(alpha)
        b[2,2] = 1 / self.abc[2]
        self.B = b
        self.Binv = np.linalg.inv(b)

    def _scattering_frame(self, u, v):
        """
        Return the Cartesian triad (bu, bv, bw) built from hkl vectors u, v.

        bu is B.u, bw is normal to the (u, v) plane and bv is re-derived as
        bw x bu. The result is orthonormal, assuming geometry.normalise returns
        unit vectors. Returns None when u or v is (near) zero or u, v are
        collinear.
        """
        bu = normalise(np.dot(self.B, u))
        bv = normalise(np.dot(self.B, v))
        if (length(bu) < 1e-6) or (length(bv) < 1e-6):
            return None
        bw = normalise(np.cross(bu, bv))
        if length(bw) < 1e-6:
            return None
        bv = normalise(np.cross(bw, bu))
        return bu, bv, bw

    def makeU(self, u = np.array([1.0, 0.0, 0.0]), v = np.array([0.0, 1.0, 0.0])):
        """
        Build the orientation matrix U from two hkl vectors.

        U sends the Cartesian triad (bu, bv, bw) of _scattering_frame to the
        lab directions (0,0,-1), (-1,0,0) and (0,1,0) respectively.
        Returns True on success, or False (self.U untouched) when u, v do not
        define a plane.
        """
        frame = self._scattering_frame(u, v)
        if frame is None:
            return False
        lab = np.zeros((3,3))
        lab[0,1] = -1.0
        lab[1,2] = 1.0
        lab[2,0] = -1.0
        tau = np.linalg.inv(np.column_stack(frame))
        self.U = np.dot(lab, tau)
        return True

    def orient(self, u = np.array([1.0, 0.0, 0.0]), v = np.array([0.0, 1.0, 0.0]), rot = 0.0,
                     goniometer = (0.0, 0.0, 0.0)):
        """
        Creates a rotation matrix based on the u, v vectors.
        It is a prerequisite for outputting the oriented reciprocal lattice.

        u, v       -- hkl vectors spanning the scattering plane.
        rot        -- sample rotation, combined with omega as (rot - omega).
        goniometer -- (phi, chi, omega) angles. phi rotates about u, chi about
                      the in-plane perpendicular to u, and omega about the
                      plane normal (v x u). Rotations are applied in the order
                      omega, chi, phi, each carrying the remaining axes.
                      Angle units are whatever geometry.arb_rotation expects.
        Sets self.u / self.v (rotated, normalised hkl), self.U, self.rotmat
        and self.themat = U.B. Returns True on success, False otherwise.
        """
        frame = self._scattering_frame(u, v)
        if frame is None:
            return False
        bu, bv, bw = frame
        phi, chi, omega = goniometer[0], goniometer[1], goniometer[2]
        # All rotations act on Cartesian vectors (B.hkl). The axes must
        # therefore be Cartesian too; building them from hkl components is
        # only equivalent for cubic cells. -bw equals cross(bv, bu), i.e. v x u.
        axis_phi, axis_chi, axis_omega = bu.copy(), bv.copy(), -bw
        cart_u, cart_v = bu.copy(), bv.copy()
        # 1) omega combined with the sample rotation; carries the chi/phi axes.
        rmat = arb_rotation(axis_omega, rot-omega)
        axis_phi, axis_chi = np.dot(rmat, axis_phi), np.dot(rmat, axis_chi)
        cart_u, cart_v = np.dot(rmat, cart_u), np.dot(rmat, cart_v)
        # 2) chi about the (moved) chi axis; carries the phi axis.
        rmat = arb_rotation(axis_chi, chi)
        axis_phi = np.dot(rmat, axis_phi)
        cart_u, cart_v = np.dot(rmat, cart_u), np.dot(rmat, cart_v)
        # 3) phi about the (moved) phi axis.
        rmat = arb_rotation(axis_phi, phi)
        cart_u, cart_v = np.dot(rmat, cart_u), np.dot(rmat, cart_v)
        # Back to normalised hkl components.
        self.u = normalise(np.dot(self.Binv, cart_u))
        self.v = normalise(np.dot(self.Binv, cart_v))
        if not self.makeU(self.u, self.v):
            # Do not build themat from a stale (or still empty) U.
            return False
        self.rotmat = self.U
        self.themat = np.dot(self.rotmat, self.B)
        return True

    def description(self):
        """
        Return a two-line summary: the unit cell (constants, then angles in
        degrees with LaTeX degree marks), and the current u and v vectors
        (normalised, rounded to 4 decimals). Requires orient() to have run.
        """
        # Snap near-zero components to exactly 0.0 so they do not print as
        # tiny numbers or -0.0.
        newu = np.where(np.abs(self.u) < 1e-6, 0.0, self.u)
        newv = np.where(np.abs(self.v) < 1e-6, 0.0, self.v)
        deg = r"$^{\circ}$"  # raw string: "\c" is not a valid escape sequence
        text = ("Unit cell: " + ", ".join(str(x) for x in self.abc) + "  "
                + (deg + ", ").join(str(round(x,2)) for x in np.degrees(self.angles))
                + deg + "\n")
        text += (" u = " + ",".join(str(round(x,4)) for x in normalise(newu))
                 + ", v = " + ",".join(str(round(x,4)) for x in normalise(newv)))
        return text
