# Copyright (C) 2021 Sumeet Kulkarni
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""
Regression tests for the hybrid evolution (calc_tilts_at_infty_hybrid_evolve) in lalsimulation.tilts_at_infinity
Data generated by <https://git.ligo.org/waveforms/reviews/spin-tilt-angles-at-infinity/-/blob/master/generate_hybrid_evol_data.py>
Based on test_prec_avg_evol.py
"""

import os
import sys
import pytest
import numpy as np

from lalsimulation.tilts_at_infinity import calc_tilts_at_infty_hybrid_evolve

import sys

if sys.version_info < (3,):
    # This module is Python 3 only. Exit status 77 is presumably the
    # automake "test skipped" code -- NOTE(review): confirm with the harness.
    import warnings
    warnings.warn("this test module does not support python2")
    sys.exit(77)

# True only during the TEST phase of a conda build. The v1 regression test is
# very slow in other parts of the CI, so it is restricted to this phase.
CONDA_BUILD_TEST = os.environ.get("CONDA_BUILD_STATE") == "TEST"

# -- regression data for tilts at infinity ---------------------

# Format: (m1, m2, chi1, chi2, tilt1, tilt2, phi12, fref, tilt1_inf, tilt2_inf, version)
# Units: m1, m2 in kg (detector frame); chi1, chi2 dimensionless; angles in radians; fref in Hz.
# There are 10 completely random cases followed by 2 stable aligned-spin cases (the last two rows,
# with tilts of exactly 0 or pi), generated using generate_hybrid_evol_data.py (see the link in the
# module docstring). For the aligned-spin rows the expected tilts at infinity equal the input tilts.
# Expected tilts at infinity in this list are for version 'v1' of the calculation.
test_data_inf_v1 = [
    (1.3862687342652575e+32, 1.5853186050191907e+31, 0.8768912154180827, 0.9635416612042661, 2.8861591668037119, 2.7423707262813442, 4.7502537251642867, 8.0000000000000000, 2.8861523354649581, 2.7426263439639471, 'v1'),
    (4.0380177255695994e+31, 2.1111685497317552e+31, 0.9442047756726544, 0.2197148251155545, 2.7060072810080551, 0.8920951236808333, 1.7330264974887994, 14.0000000000000000, 2.7084617950997880, 0.8811507171110192, 'v1'),
    (1.4778236544770486e+32, 2.6197742077777032e+31, 0.4650532384488123, 0.4135203147241133, 2.5477872046486589, 1.3374887745402186, 5.8300235171959054, 15.0000000000000000, 2.5310633997758987, 1.4020814948117246, 'v1'),
    (1.4863232012364162e+32, 8.7793565490275959e+31, 0.5893323731604954, 0.2055055162882128, 0.6903364245776367, 1.7543770650607788, 5.0236304887214525, 11.0000000000000000, 0.6586043676131283, 1.8560196545892511, 'v1'),
    (1.7178573021561749e+32, 1.4944497411052174e+32, 0.2957920607635095, 0.4434151999944888, 2.5141791285526289, 1.7220573548170679, 3.6800935313971785, 19.0000000000000000, 2.7685235682785141, 1.6282719485244728, 'v1'),
    (1.9425515209586304e+32, 1.5834862257761549e+32, 0.2296658494450905, 0.6534173243439886, 1.3270660171515838, 2.8939960781679810, 0.3639705457070962, 13.0000000000000000, 1.4488714781917917, 2.7310319713359963, 'v1'),
    (7.6790541604603600e+31, 1.9132537018679275e+31, 0.1199773081324423, 0.9386480945327058, 0.5035407043069747, 1.7892511812390561, 1.0055563261907063, 19.0000000000000000, 0.4706275128556943, 1.7976334437674295, 'v1'),
    (1.9825472857867221e+32, 1.4367644552485066e+32, 0.4312785756276836, 0.5477442096539639, 1.1515192569867874, 1.3370105009905482, 5.6711539828043662, 19.0000000000000000, 0.7276894817434479, 1.7152906933033634, 'v1'),
    (1.4345358396311136e+31, 1.1150061575178080e+31, 0.6275566268091129, 0.1660169861331117, 0.9178548738199007, 2.8069175353915532, 2.2215576693959620, 13.0000000000000000, 0.9581073572917457, 2.4754114721784592, 'v1'),
    (1.2211378968321533e+32, 6.0296417691783719e+31, 0.7112954381378587, 0.9045511111160647, 0.0351734285267389, 2.7472710892536298, 4.7276274412397177, 8.0000000000000000, 0.1626080752314657, 2.6968435922833787, 'v1'),
    (1.6081881116676779e+32, 8.5515110662266332e+31, 0.4498334498283801, 0.4689712388898268, 3.1415926535897931, 0.0000000000000000, 3.9268299199695771, 17.0000000000000000, 3.1415926535897931, 0.0000000000000000, 'v1'),
    (9.4223304028886096e+31, 2.8099837478253088e+31, 0.5863953401030325, 0.9151619021800407, 3.1415926535897931, 3.1415926535897931, 1.0463785798894087, 5.0000000000000000, 3.1415926535897931, 3.1415926535897931, 'v1')
]

# Same format, units and 12 input configurations as test_data_inf_v1 (10 random cases followed by
# 2 stable aligned-spin cases). The expected tilts at infinity are computed with version 'v2' of the
# calculation. They differ slightly from the v1 values for the random cases, and the aligned-spin
# rows are identical to v1.
test_data_inf_v2 = [
    (1.3862687342652575e+32, 1.5853186050191907e+31, 0.8768912154180827, 0.9635416612042661, 2.8861591668037119, 2.7423707262813442, 4.7502537251642867, 8.0000000000000000, 2.8861523834312006, 2.7426209139511317, 'v2'),
    (4.0380177255695994e+31, 2.1111685497317552e+31, 0.9442047756726544, 0.2197148251155545, 2.7060072810080551, 0.8920951236808333, 1.7330264974887994, 14.0000000000000000, 2.7084614334650650, 0.8811589982618666, 'v2'),
    (1.4778236544770486e+32, 2.6197742077777032e+31, 0.4650532384488123, 0.4135203147241133, 2.5477872046486589, 1.3374887745402186, 5.8300235171959054, 15.0000000000000000, 2.5310634517965811, 1.4020833745755934, 'v2'),
    (1.4863232012364162e+32, 8.7793565490275959e+31, 0.5893323731604954, 0.2055055162882128, 0.6903364245776367, 1.7543770650607788, 5.0236304887214525, 11.0000000000000000, 0.6586045552673946, 1.8560206324931039, 'v2'),
    (1.7178573021561749e+32, 1.4944497411052174e+32, 0.2957920607635095, 0.4434151999944888, 2.5141791285526289, 1.7220573548170679, 3.6800935313971785, 19.0000000000000000, 2.7687000685659777, 1.6282203844190197, 'v2'),
    (1.9425515209586304e+32, 1.5834862257761549e+32, 0.2296658494450905, 0.6534173243439886, 1.3270660171515838, 2.8939960781679810, 0.3639705457070962, 13.0000000000000000, 1.4488159977052626, 2.7310840337180102, 'v2'),
    (7.6790541604603600e+31, 1.9132537018679275e+31, 0.1199773081324423, 0.9386480945327058, 0.5035407043069747, 1.7892511812390561, 1.0055563261907063, 19.0000000000000000, 0.4706276744704514, 1.7976330588548823, 'v2'),
    (1.9825472857867221e+32, 1.4367644552485066e+32, 0.4312785756276836, 0.5477442096539639, 1.1515192569867874, 1.3370105009905482, 5.6711539828043662, 19.0000000000000000, 0.7277234016736369, 1.7152658660415474, 'v2'),
    (1.4345358396311136e+31, 1.1150061575178080e+31, 0.6275566268091129, 0.1660169861331117, 0.9178548738199007, 2.8069175353915532, 2.2215576693959620, 13.0000000000000000, 0.9581139400168959, 2.4753988230147987, 'v2'),
    (1.2211378968321533e+32, 6.0296417691783719e+31, 0.7112954381378587, 0.9045511111160647, 0.0351734285267389, 2.7472710892536298, 4.7276274412397177, 8.0000000000000000, 0.1626087948291107, 2.6968398428505553, 'v2'),
    (1.6081881116676779e+32, 8.5515110662266332e+31, 0.4498334498283801, 0.4689712388898268, 3.1415926535897931, 0.0000000000000000, 3.9268299199695771, 17.0000000000000000, 3.1415926535897931, 0.0000000000000000, 'v2'),
    (9.4223304028886096e+31, 2.8099837478253088e+31, 0.5863953401030325, 0.9151619021800407, 3.1415926535897931, 3.1415926535897931, 1.0463785798894087, 5.0000000000000000, 3.1415926535897931, 3.1415926535897931, 'v2')
]

# -- test functions ---------------------

@pytest.mark.skipif(not CONDA_BUILD_TEST, reason="This test only runs quickly in conda")
@pytest.mark.parametrize("m1, m2, chi1, chi2, tilt1, tilt2, phi12, fref, tilt1_inf, tilt2_inf, version", test_data_inf_v1)
def test_inf_v1(m1, m2, chi1, chi2, tilt1, tilt2, phi12, fref, tilt1_inf, tilt2_inf, version):
    """
    Check calc_tilts_at_infty_hybrid_evolve against the stored v1 tilts at infinity.

    Skipped unless running in the TEST phase of a conda build, because the v1
    calculation is very slow elsewhere in CI.

    Parameters (one row of test_data_inf_v1):
    m1, m2: Detector-frame component masses, in kg
    chi1, chi2: Dimensionless spin magnitudes
    tilt1, tilt2: Spin tilt angles with respect to the orbital angular momentum at fref
    phi12: Angle between the in-plane spin components at fref
    fref: Reference frequency, in Hz
    tilt1_inf, tilt2_inf: Expected tilt angles at infinity
    version: Calculation version forwarded to calc_tilts_at_infty_hybrid_evolve
    """
    # Relative tolerance only. atol keeps numpy's default of 0, so any expected
    # value of exactly 0.0 (aligned-spin rows) must be reproduced exactly.
    tolerance = 1.e-6

    evolved = calc_tilts_at_infty_hybrid_evolve(
        m1, m2, chi1, chi2, tilt1, tilt2, phi12, fref, version=version)

    np.testing.assert_allclose(
        [evolved['tilt1_inf'], evolved['tilt2_inf']],
        [tilt1_inf, tilt2_inf],
        rtol=tolerance,
        err_msg="Check of tilts at infinity failed.")

@pytest.mark.parametrize("m1, m2, chi1, chi2, tilt1, tilt2, phi12, fref, tilt1_inf, tilt2_inf, version", test_data_inf_v2)
def test_inf_v2(m1, m2, chi1, chi2, tilt1, tilt2, phi12, fref, tilt1_inf, tilt2_inf, version):
    """
    Check calc_tilts_at_infty_hybrid_evolve against the stored v2 tilts at infinity.

    Unlike the v1 test, this one runs unconditionally.

    Parameters (one row of test_data_inf_v2):
    m1, m2: Detector-frame component masses, in kg
    chi1, chi2: Dimensionless spin magnitudes
    tilt1, tilt2: Spin tilt angles with respect to the orbital angular momentum at fref
    phi12: Angle between the in-plane spin components at fref
    fref: Reference frequency, in Hz
    tilt1_inf, tilt2_inf: Expected tilt angles at infinity
    version: Calculation version forwarded to calc_tilts_at_infty_hybrid_evolve
    """
    # Relative tolerance only. atol keeps numpy's default of 0, so any expected
    # value of exactly 0.0 (aligned-spin rows) must be reproduced exactly.
    tolerance = 1.e-6

    evolved = calc_tilts_at_infty_hybrid_evolve(
        m1, m2, chi1, chi2, tilt1, tilt2, phi12, fref, version=version)

    np.testing.assert_allclose(
        [evolved['tilt1_inf'], evolved['tilt2_inf']],
        [tilt1_inf, tilt2_inf],
        rtol=tolerance,
        err_msg="Check of tilts at infinity failed.")

# -- run the tests ------------------------------
if __name__ == '__main__':
    # Forward any command-line arguments to pytest. With none, run verbosely,
    # report skip reasons and write a JUnit XML report.
    cli_args = sys.argv[1:]
    if not cli_args:
        cli_args = ["-v", "-rs", "--junit-xml=junit-hybrid_evol.xml"]
    sys.exit(pytest.main(args=[__file__] + cli_args))

