#!/usr/bin/env python3

import datetime
import sys
import os
import re
import json
import shutil
from optparse import Option, OptionGroup, OptionParser, OptionValueError, SUPPRESS_USAGE

from gppylib.gpparseopts import OptParser, OptChecker
from gppylib.commands.gp import ModifyConfSetting, SegmentStart, SegmentStop
from gppylib.commands.pg import PgBaseBackup
from gppylib.db import dbconn
from gppylib.commands import unix
from gppylib.commands.pg import DbStatus
from gppylib import gplog
from gppylib.commands.base import *
from gppylib.gparray import Segment
from time import sleep
from gppylib.operations.buildMirrorSegments import gDatabaseDirectories, gDatabaseFiles

description = ("""

Configure segment directories for installation into a pre-existing GPDB array.

Used by at least gpexpand, gprecoverseg, and gpaddmirrors

""")

DEFAULT_BATCH_SIZE=8
EXECNAME = os.path.split(__file__)[-1]

class ValidationException(Exception):
    """
    Signals that a segment directory failed validation.

    The message is available both through the normal exception arguments
    and through getMessage().
    """

    def __init__(self, msg):
        super().__init__(msg)
        # Kept separately so getMessage() returns exactly what was passed in.
        self.__msg = msg

    def getMessage(self):
        """Return the message this exception was constructed with."""
        return self.__msg

class ConfExpSegCmd(Command):
    """
    Command that validates and, unless running in validation-only mode,
    configures a single segment data directory.

    For a new segment (newseg) the work done by run() is:
      * validate the data directory (skipped when forceoverwrite is set);
      * for a non-primary, populate the directory with pg_basebackup from
        the segment identified by syncWithSegmentHostname/Port;
      * for a primary, create the directory, optionally extract the template
        tar file, restore tablespaces, create the log directory and update
        the port / gp_contentid / gp_dbid settings.
    Finally (primary or not-new segments) a segment that still has a
    postmaster.pid file is stopped on a best-effort basis.

    NOTE(review): several inner handlers store a CommandResult and re-raise;
    the outermost handler in run() then stores its own failure result built
    from str(e), so the inner result is the one that gets replaced.
    """

    def __init__(self, name, cmdstr, datadir, port, dbid, contentid, newseg, tarfile, useLighterDirectoryValidation, isPrimary,
                 syncWithSegmentHostname, syncWithSegmentPort, verbose, validationOnly, forceoverwrite, replicationSlotName, logfileDirectory, progressFile):
        """
        @param name, cmdstr passed through unchanged to Command.__init__
        @param datadir the segment data directory to validate/configure
        @param port segment port; converted with int()
        @param dbid, contentid identifiers written into the segment configuration
        @param newseg if True the directory is set up as a new segment
        @param tarfile optional template tar file extracted into a new primary's datadir
        @param useLighterDirectoryValidation if True then we don't require an empty directory; we just require that
                                             database stuff is not there.
        @param isPrimary if False the segment is created with pg_basebackup
        @param syncWithSegmentHostname, syncWithSegmentPort source of the pg_basebackup
        @param verbose if True, failures in run() are logged with a traceback
        @param validationOnly if True then we validate directories and other simple things only; we don't
                              actually configure the segment
        @param forceoverwrite if True directory validation is skipped and pg_basebackup is asked to overwrite
        @param replicationSlotName replication slot name handed to pg_basebackup
        @param logfileDirectory stored on the instance; not otherwise used by this class
        @param progressFile pg_basebackup progress file; when empty a temporary one is used and removed on success
        """
        self.datadir = datadir
        self.port = int(port)
        self.dbid = dbid
        self.contentid = contentid
        self.newseg = newseg
        self.tarfile = tarfile
        self.verbose = verbose
        self.useLighterDirectoryValidation = useLighterDirectoryValidation
        self.isPrimary = isPrimary
        self.syncWithSegmentHostname = syncWithSegmentHostname
        self.syncWithSegmentPort = syncWithSegmentPort
        self.replicationSlotName = replicationSlotName
        self.logfileDirectory = logfileDirectory
        self.progressFile = progressFile
        self.validationOnly = validationOnly
        self.forceoverwrite = forceoverwrite
        Command.__init__(self, name, cmdstr)

    def validatePath(self, path):
        """
        Raises ValidationException when a validation problem is detected

        If the parent of path does not exist, path itself is created first.
        A path that does not exist is considered valid.  Otherwise, with
        lighter validation the directory must not contain any of the entries
        named in gDatabaseDirectories (except "log") or gDatabaseFiles; with
        full validation it must be empty.
        """

        if not os.path.exists(os.path.dirname(path)):
            self.makeOrUpdatePathAsNeeded(path)

        if not os.path.exists(path):
            return

        if self.useLighterDirectoryValidation:

            # validate that we don't contain database directories or files
            for toCheck in gDatabaseDirectories:
                if toCheck != "log": # it's okay to have log -- to save old logs around to look at
                    if os.path.exists(os.path.join(path, toCheck)):
                        raise ValidationException("Segment directory '%s' contains directory %s but should not!" %
                                    (path, toCheck))
            for toCheck in gDatabaseFiles:
                if os.path.exists(os.path.join(path, toCheck)):
                    raise ValidationException("Segment directory '%s' contains file %s but should not!" %
                                              (path, toCheck))
        else:
            # it better be empty
            if os.listdir(path):
                raise ValidationException("Segment directory '%s' exists but is not empty!" % path)

    def makeOrUpdatePathAsNeeded(self, path):
        """
        Ensure path exists with mode 0700: chmod it when it already exists,
        otherwise create it (including missing parents).
        """
        if os.path.exists(path):
            os.chmod(path, 0o700)
        else:
            os.makedirs(path, 0o700)

    def run(self):
        """
        Execute the validation/configuration and record the outcome with
        set_results().  Exceptions are not propagated to the caller: any
        failure is converted into a failing CommandResult whose stderr is
        str(exception).
        """
        try:
            if self.newseg:
                # make directories, extract template update postgresql.conf
                logger.info("Validate data directories for new segment")
                self._validateDataDirectory()

                if self.validationOnly:
                    self.set_results(CommandResult(0, b'', b'', True, False))
                    return

                if not self.isPrimary:
                    # A mirror is fully populated by pg_basebackup; none of
                    # the primary-only steps below apply.
                    self._createMirrorFromPrimary()
                    return

                self._createPrimaryDataDirectory()
                self._restoreTablespaces()
                self._ensureLogDirectory()
                self._updateSegmentConfFiles()

            # We might need to stop the segment if the last setup failed past this point
            self._stopSegmentIfRunning()

        except Exception as e:
            self.set_results(CommandResult(1, b'', str(e).encode(), True, False))
            if self.verbose:
                logger.exception(e)
            return
        else:
            self.set_results(CommandResult(0, b'', b'', True, False))

    def _validateDataDirectory(self):
        """Validate self.datadir unless forceoverwrite is set; store a failure result and re-raise on error."""
        try:
            if not self.forceoverwrite:
                self.validatePath(self.datadir)
        except ValidationException as e:
            msg = "for segment with port %s: %s" % (self.port, e.getMessage())
            self.set_results(CommandResult(1, b'', msg.encode(), True, False))
            raise
        except Exception as e:
            self.set_results(CommandResult(1, b'', str(e).encode(), True, False))
            raise

    def _buildBaseBackupCmd(self, createSlot, forceOverwrite):
        """Build the PgBaseBackup command that copies the sync source into self.datadir."""
        return PgBaseBackup(target_datadir=self.datadir,
                            source_host=self.syncWithSegmentHostname,
                            source_port=str(self.syncWithSegmentPort),
                            create_slot=createSlot,
                            replication_slot_name=self.replicationSlotName,
                            forceoverwrite=forceOverwrite,
                            target_gp_dbid=self.dbid,
                            logfile=self.progressFile)

    def _createMirrorFromPrimary(self):
        """
        Populate self.datadir with pg_basebackup, retrying once with slot
        creation when the first attempt fails, and store a success result.
        """
        # If the caller does not specify a pg_basebackup
        # progress file, then create a temporary one and handle
        # its deletion upon success.
        shouldDeleteProgressFile = False
        if not self.progressFile:
            shouldDeleteProgressFile = True
            self.progressFile = '%s/pg_basebackup.%s.dbid%s.out' % (gplog.get_logger_dir(),
                                                                    datetime.datetime.today().strftime('%Y%m%d_%H%M%S'),
                                                                    self.dbid)

        # Create a mirror based on the primary
        cmd = self._buildBaseBackupCmd(createSlot=False, forceOverwrite=self.forceoverwrite)
        try:
            logger.info("Running pg_basebackup with progress output temporarily in %s" % self.progressFile)
            cmd.run(validateAfter=True)
        except Exception:
            #  If the cluster never has mirrors, cmd will fail
            #  quickly because the internal slot doesn't exist.
            #  Re-run with `create_slot`.
            #  GPDB_12_MERGE_FIXME could we check it before? or let
            #  pg_basebackup create slot if not exists.
            cmd = self._buildBaseBackupCmd(createSlot=True, forceOverwrite=True)
            try:
                logger.info("Re-running pg_basebackup, creating the slot this time")
                cmd.run(validateAfter=True)
            except Exception as e:
                self.set_results(CommandResult(1, b'', str(e).encode(), True, False))
                raise

        # Success bookkeeping lives outside the try above on purpose: a
        # failure to remove the temporary progress file must not be mistaken
        # for a failed backup and trigger a second, overwriting backup.
        self.set_results(CommandResult(0, b'', b'', True, False))
        if shouldDeleteProgressFile:
            os.remove(self.progressFile)

        logger.info("Successfully ran pg_basebackup: %s" % cmd.cmdStr)

    def _createPrimaryDataDirectory(self):
        """Create (or re-permission) self.datadir and extract the template tar file into it when one was given."""
        logger.info("Create or update data directories for new segment")
        try:
            self.makeOrUpdatePathAsNeeded(self.datadir)
        except Exception as e:
            self.set_results(CommandResult(1,b'',str(e).encode(),True,False))
            raise

        if self.tarfile:
            logger.info("extract tar file to new segment")
            extractTarCmd = unix.ExtractTar('extract tar file to new segment', self.tarfile, self.datadir)
            try:
                logger.debug('Extracting tar file %s to %s' % (self.tarfile, self.datadir))
                extractTarCmd.run(validateAfter=True)
            except Exception:
                self.set_results(extractTarCmd.get_results())
                logger.error(extractTarCmd.get_results())
                raise

    def _restoreTablespaces(self):
        """
        Restore pg_tblspc from <datadir>/pg_tblspc/newTableSpaceInfo.json when
        that file exists; otherwise do nothing.
        """
        tblspc_dir = os.path.join(self.datadir, "pg_tblspc")
        tblspc_config_json = os.path.join(tblspc_dir, "newTableSpaceInfo.json")
        if not os.path.exists(tblspc_config_json):
            return

        # If the tablespace json config file exists,
        # we have to restore the tablespace files as
        # the json config file.
        #             json :: {
        #                         "names" : [ name::string ],
        #                         "oids"  : [ oid::string ],
        #                         "details" : {
        #                                          dbid::string : [ location::string ]
        #                                     }
        #                      }
        # newTableSpaceInfo[names] and newTableSpaceInfo[oids] are tablespace infos
        # that are in the same order.
        # newTableSpaceInfo[dbid] is a list of locations in the same order of oids.
        # Only "oids" and this dbid's locations are needed here.
        with open(tblspc_config_json) as f:
            tblspc_config = json.load(f)
        oids = tblspc_config["oids"]
        locations = tblspc_config["details"][str(self.dbid)]

        dumps_dir = os.path.join(tblspc_dir, "dumps")
        for oid, location in zip(oids, locations):
            target = os.path.join(location, str(self.dbid))
            link = os.path.join(tblspc_dir, oid)
            # Move the first entry found in the oid's dump directory to the
            # tablespace location, then replace pg_tblspc/<oid> with a
            # symlink to that location.
            fn = os.listdir(os.path.join(dumps_dir, oid))[0]
            shutil.move(os.path.join(dumps_dir, oid, fn), target)
            os.unlink(link)
            os.symlink(target, link)

        shutil.rmtree(dumps_dir)
        os.remove(tblspc_config_json)

    def _ensureLogDirectory(self):
        """Create log for new segment in case it doesn't exist."""
        try:
            self.makeOrUpdatePathAsNeeded(os.path.join(self.datadir, "log"))
        except Exception as e:
            self.set_results(CommandResult(1,b'',str(e).encode(),True,False))
            raise

    def _updateConfSetting(self, confFileName, settingName, value):
        """Run ModifyConfSetting for one numeric setting in <datadir>/<confFileName>; store its results and re-raise on failure."""
        confPath = self.datadir + '/' + confFileName
        modifyConfCmd = ModifyConfSetting('Updating %s' % confPath,
                                          confPath,
                                          settingName, value, optType='number')
        try:
            modifyConfCmd.run(validateAfter=True)
        except Exception:
            self.set_results(modifyConfCmd.get_results())
            raise

    def _updateSegmentConfFiles(self):
        """Write port and gp_contentid to postgresql.conf and gp_dbid to internal.auto.conf."""
        logger.info("Updating %s/postgresql.conf" % self.datadir)
        self._updateConfSetting('postgresql.conf', 'port', self.port)
        self._updateConfSetting('postgresql.conf', 'gp_contentid', self.contentid)
        self._updateConfSetting('internal.auto.conf', 'gp_dbid', self.dbid)

    def _stopSegmentIfRunning(self):
        """Stop the segment when <datadir>/postmaster.pid exists; a failed stop is logged and otherwise ignored."""
        if os.path.exists('%s/postmaster.pid' % self.datadir):
            logger.info('%s/postmaster.pid exists.  Stopping segment' % self.datadir)
            stopCmd = SegmentStop('Stop new segment', self.datadir)
            try:
                stopCmd.run(validateAfter=True)
            except Exception as e:
                # Deliberately best-effort: a failed stop must not fail the
                # configuration, but it should not vanish without a trace.
                logger.warning('Failed to stop segment in %s: %s' % (self.datadir, e))

def getOidDirLists(oidDirs):
    """
    Split a flat sequence of alternating oid, dir entries into two lists.

    Returns (oids, dirs), where oids holds the entries at even positions and
    dirs the entries at odd positions.  An odd number of entries raises
    IndexError, because the final oid has no matching dir.
    """
    oids, dirs = [], []
    for pos in range(0, len(oidDirs), 2):
        oids.append(oidDirs[pos])
        dirs.append(oidDirs[pos + 1])
    return oids, dirs


def parseargs():
    """
    Parse and validate the command line.

    Returns (options, args, seg_info) where seg_info is a list with one
    entry per comma-separated record of --confinfo.  Each record is a
    colon-separated list of at least nine fields:

        datadir : port : isPrimary : directoryValidation : dbid : contentid :
        primaryHostName : primarySegmentPort : progressFile

    port, dbid and contentid are stored as int; the remaining fields are left
    as strings.  options.batch_size is clamped to the range
    [1, number of records] and options.replicationSlotName is set to the
    fixed internal slot name.

    Raises Exception with a descriptive message for a missing --confinfo,
    conflicting mode flags, or a malformed record.

    NOTE(review): this function logs through the module-level name `logger`,
    which the script assigns only after parseargs() returns; confirm another
    binding (e.g. from the star import) exists at call time.
    """
    parser = OptParser(option_class=OptChecker,
                description=' '.join(description.split()),
                version='%prog version $Revision: $')
    parser.set_usage('%prog is a utility script used by gpexpand, gprecoverseg, and gpaddmirrors and is not intended to be run separately.')
    parser.remove_option('-h')

    parser.add_option('-v','--verbose', action='store_true', help='debug output.')
    parser.add_option('-c', '--confinfo', type='string')
    parser.add_option('-t', '--tarfile', type='string')
    parser.add_option('-n', '--newsegments', action='store_true')
    parser.add_option('-b', '--batch-size', type='int', default=DEFAULT_BATCH_SIZE, metavar='<batch_size>')
    parser.add_option("-V", "--validation-only", dest="validationOnly", action='store_true', default=False)
    parser.add_option("-W", "--write-gpid-file-only", dest="writeGpidFileOnly", action='store_true', default=False)
    parser.add_option('-f', '--force-overwrite', dest='forceoverwrite', action='store_true', default=False)
    parser.add_option('-l', '--log-dir', dest="logfileDirectory", type="string")

    parser.set_defaults(verbose=False, filters=[], slice=(None, None))

    # Parse the command line arguments
    (options, args) = parser.parse_args()

    options.replicationSlotName = 'internal_wal_replication_slot'

    if not options.confinfo:
        raise Exception('Missing --confinfo argument.')

    if options.batch_size <= 0:
        # Logger.warn() is a deprecated alias that newer Pythons removed.
        logger.warning('batch_size was less than one.  Setting to 1.')
        options.batch_size = 1

    if options.validationOnly and options.writeGpidFileOnly:
        raise Exception('Only one of --validation-only and --write-gpid-file-only can be specified')

    # Indexes 0-5 are validated below; indexes 6-8 (primary host, primary
    # port, progress file) are read unconditionally by the script body, so a
    # shorter record can never be processed successfully.
    requiredFieldCount = 9

    seg_info = []
    for line in options.confinfo.split(','):
        conf_vals = line.split(':')
        if len(conf_vals) < requiredFieldCount:
            raise Exception('Invalid configuration value: %s' % conf_vals)
        if conf_vals[0] == '':
            raise Exception('Missing data directory in: %s' % conf_vals)
        try:
            # Always normalise to int so every record carries the same type.
            conf_vals[1] = int(conf_vals[1])
        except ValueError:
            raise Exception('Invalid port in: %s' % conf_vals)
        if conf_vals[2] not in ('true', 'false'):
            raise Exception('Invalid isPrimary option in: %s' % conf_vals)
        if conf_vals[3] not in ('true', 'false'):
            raise Exception('Invalid directory validation option in: %s' % conf_vals)
        try:
            conf_vals[4] = int(conf_vals[4])
        except ValueError:
            raise Exception('Invalid dbid option in: %s' % conf_vals)
        try:
            conf_vals[5] = int(conf_vals[5])
        except ValueError:
            raise Exception('Invalid contentid option in: %s' % conf_vals)
        seg_info.append(conf_vals)

    seg_info_len = len(seg_info)
    if seg_info_len == 0:
        raise Exception('No segment configuration values found in --confinfo argument')
    elif seg_info_len < options.batch_size:
        # no need to have more threads than segments
        options.batch_size = seg_info_len

    return options, args, seg_info

# Script body: parse the arguments, run one ConfExpSegCmd per segment record
# on a worker pool, and exit 0 on success or 1 on any failure.

# Bound before the try so the finally clause can test it even when
# parseargs() or the logging setup raises before the pool is created.
pool = None

try:
    (options, args, seg_info) = parseargs()

    logger = gplog.setup_tool_logging(EXECNAME, unix.getLocalHostname(), unix.getUserName(),
                                      logdir=options.logfileDirectory)

    if options.verbose:
        gplog.enable_verbose_logging()

    logger.info("Starting gpconfigurenewsegment with args: %s" % ' '.join(sys.argv[1:]))

    pool = WorkerPool(numWorkers=options.batch_size)

    for seg in seg_info:
        dataDir = seg[0]
        port = seg[1]
        isPrimary = seg[2] == "true"
        directoryValidationLevel = seg[3] == "true"
        dbid = int(seg[4])
        contentid = int(seg[5])

        # These variables should not be used if it's a primary
        # they will be junk data passed through the config.
        primaryHostName = seg[6]
        primarySegmentPort = int(seg[7])
        progressFile = seg[8]

        cmd = ConfExpSegCmd( name = 'Configure segment directory'
                           , cmdstr = ' '.join(sys.argv)
                           , datadir = dataDir
                           , port = port
                           , dbid = dbid
                           , contentid = contentid
                           , newseg = options.newsegments
                           , tarfile = options.tarfile
                           , useLighterDirectoryValidation = directoryValidationLevel
                           , isPrimary = isPrimary
                           , syncWithSegmentHostname = primaryHostName
                           , syncWithSegmentPort = primarySegmentPort
                           , verbose = options.verbose
                           , validationOnly = options.validationOnly
                           , forceoverwrite = options.forceoverwrite
                           , replicationSlotName = options.replicationSlotName
                           , logfileDirectory = options.logfileDirectory
                           , progressFile= progressFile
                           )
        pool.addCommand(cmd)

    pool.join()

    if options.validationOnly:
        # In validation-only mode every failure is reported, one per line,
        # on stderr rather than stopping at the first one.
        errors = []
        for item in pool.getCompletedItems():
            if not item.get_results().wasSuccessful():
                errors.append(str(item.get_results().stderr).replace("\n", " "))

        if errors:
            print("\n".join(errors), file=sys.stderr)
            sys.exit(1)
        else:
            sys.exit(0)
    else:
        try:
            pool.check_results()
        except Exception as e:
            if options.verbose:
                logger.exception(e)
            logger.error(e)
            print(e, file=sys.stderr)
            sys.exit(1)

    sys.exit(0)

except Exception as msg:
    # sys.exit() raises SystemExit, which is not an Exception subclass, so
    # the exits above pass straight through to the finally clause.
    # NOTE(review): `logger` is assigned inside the try; confirm it is also
    # bound some other way for failures that happen before that assignment.
    logger.error(msg)
    print(msg, file=sys.stderr)
    sys.exit(1)

finally:
    if pool is not None:
        pool.haltWork()
