#!/usr/bin/env python
# Copyright (C) 2019 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This tool uses a collection of BUILD.gn files and build targets to generate
# an "amalgamated" C++ header and source file pair which compiles to an
# equivalent program. The tool also outputs the necessary compiler and linker
# flags needed to compile the resulting source code.

from __future__ import print_function
import argparse
import os
import re
import shutil
import subprocess
import sys
import tempfile

import gn_utils

# Default targets to include in the result when none are given on the command
# line. Each label is listed once; repeating a label only causes redundant
# work.
default_targets = [
    '//:libperfetto_client_experimental',
    '//include/perfetto/protozero:protozero',
    '//protos/perfetto/common:zero',
    '//protos/perfetto/config:zero',
    '//protos/perfetto/trace:zero',
    '//protos/perfetto/trace/gpu:zero',
]

# GN arguments used to prepare the output directory, unless different ones
# are supplied on the command line.
gn_args = 'is_debug=false'

# Regex selecting the compiler flags that are kept; all others are dropped.
cflag_whitelist = (
    r'^-('
    r'W.*|fno-exceptions|fPIC|std.*|fvisibility.*'
    r')$')

# Regex selecting the linker flags that are kept. As written the pattern only
# matches a bare '-', so in practice no linker flag passes the filter.
ldflag_whitelist = r'^-()$'

# Regex selecting the libraries that are dropped.
lib_blacklist = r'^(c|gcc_eh)$'

# Regex selecting the preprocessor macros that are kept.
define_whitelist = (
    r'^('
    r'PERFETTO.*|GOOGLE_PROTOBUF.*'
    r')$')

# Regex selecting the #include paths that are removed from the generated
# source instead of being expanded.
includes_to_remove = r'^(gtest).*$'

# Compiler flags that are always part of the amalgamated build.
default_cflags = [
    # Header files get expanded into the generated source file, which can
    # leave some constants unused.
    '-Wno-unused-const-variable',
]

# Build flags to satisfy a protobuf (lite or full) dependency.
protobuf_cflags = [
    # These paths refer to the local copy of protobuf in buildtools. A real
    # user of the amalgamated result has to point at an installed copy of the
    # exact same protobuf version that was used to generate the amalgamated
    # build.
    '-isystembuildtools/protobuf/src',
    '-Lbuildtools/protobuf/src/.libs',
    # Warnings which need to be disabled for protobuf.
    '-Wno-missing-prototypes',
    '-Wno-missing-variable-declarations',
    '-Wno-sign-conversion',
    '-Wno-unknown-pragmas',
    '-Wno-unused-macros',
]

# Maps GN dependencies to system libraries. Anything listed here is not built
# statically; it is added as a dependency of the amalgamated project instead.
system_library_map = {
    '//buildtools:protobuf_full': dict(
        libs=['protobuf'],
        cflags=protobuf_cflags),
    '//buildtools:protobuf_lite': dict(
        libs=['protobuf-lite'],
        cflags=protobuf_cflags),
    '//buildtools:protoc_lib': dict(libs=['protoc']),
}

# ----------------------------------------------------------------------------
# End of configuration.
# ----------------------------------------------------------------------------

# File name of this script; it is embedded in the generated files.
tool_name = os.path.split(__file__)[1]

# License header plus a "generated by" notice placed at the top of both
# generated files. The trailing empty entry produces the final newline.
preamble = '\n'.join([
    '// Copyright (C) 2019 The Android Open Source Project',
    '//',
    '// Licensed under the Apache License, Version 2.0 (the "License");',
    '// you may not use this file except in compliance with the License.',
    '// You may obtain a copy of the License at',
    '//',
    '//      http://www.apache.org/licenses/LICENSE-2.0',
    '//',
    '// Unless required by applicable law or agreed to in writing, software',
    '// distributed under the License is distributed on an "AS IS" BASIS,',
    '// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or '
    'implied.',
    '// See the License for the specific language governing permissions and',
    '// limitations under the License.',
    '//',
    '// This file is automatically generated by %s. Do not edit.' % tool_name,
    '',
])


def apply_blacklist(blacklist, items):
    """Return the entries of |items| which do not match |blacklist|.

    |blacklist| is a regular expression applied with re.match(), i.e. it is
    matched from the start of each entry. The input order is preserved.
    """
    kept = []
    for entry in items:
        if re.match(blacklist, entry):
            continue
        kept.append(entry)
    return kept


def apply_whitelist(whitelist, items):
    """Return the entries of |items| which match |whitelist|.

    |whitelist| is a regular expression applied with re.match(), i.e. it is
    matched from the start of each entry. The input order is preserved.
    """
    kept = []
    for entry in items:
        if re.match(whitelist, entry):
            kept.append(entry)
    return kept


class Error(Exception):
    """Exception type raised by this tool for the failures it reports."""


class DependencyNode(object):
    """A single GN build target together with the targets it depends on."""

    def __init__(self, target_name):
        self.target_name = target_name
        # Direct dependencies only; a set, so re-adding a node is a no-op.
        self.dependencies = set()

    def add_dependency(self, target_node):
        """Record |target_node| as a direct dependency of this node."""
        self.dependencies.add(target_node)

    def iterate_depth_first(self):
        """Yield all transitive dependencies, then this node itself.

        Direct dependencies are visited in target name order. A node that is
        reachable along several paths is yielded once per path. Nodes with an
        empty or None target name are not yielded.
        """
        children = list(self.dependencies)
        children.sort(key=lambda child: child.target_name)
        for child in children:
            for descendant in child.iterate_depth_first():
                yield descendant
        if self.target_name:
            yield self


class DependencyTree(object):
    """GN build target dependencies, reachable from a nameless root node."""

    def __init__(self):
        # One node per target name; the root is stored under the key None.
        self.target_to_node_map = {}
        self.root = self._get_or_create_node(None)

    def _get_or_create_node(self, target_name):
        """Return the node for |target_name|, creating it on first use."""
        try:
            return self.target_to_node_map[target_name]
        except KeyError:
            created = DependencyNode(target_name)
            self.target_to_node_map[target_name] = created
            return created

    def add_dependency(self, from_target, to_target):
        """Record that |from_target| depends on |to_target|.

        Passing None as |from_target| attaches |to_target| to the root.
        """
        dependent = self._get_or_create_node(from_target)
        dependency = self._get_or_create_node(to_target)
        assert dependent is not dependency
        dependent.add_dependency(dependency)

    def iterate_depth_first(self):
        """Yield the nodes below the root, dependencies before dependents."""
        for visited in self.root.iterate_depth_first():
            yield visited


class AmalgamatedProject(object):
    """In-memory representation of an amalgamated source/header pair.

    Typical usage is to call add_target() for every GN target which should be
    part of the result, then generate() to expand the source files and
    finally save() to write the header/source pair to disk.
    """

    def __init__(self, desc, source_deps):
        """Constructor.

        Args:
            desc: JSON build description, looked up by GN target name.
            source_deps: A map of (source file, [dependency header]) which is
                used to detect which header files are included by each source
                file.
        """
        self.desc = desc
        self.source_deps = source_deps
        self.header = []  # Lines of the generated header file.
        self.source = []  # Lines of the generated source file.
        # Note that we don't support multi-arg flags.
        self.cflags = set(default_cflags)
        self.ldflags = set()
        self.defines = set()
        self.libs = set()
        self._dependency_tree = DependencyTree()
        self._processed_sources = set()
        self._processed_headers = set()
        self._processed_source_headers = set()  # Header files included from .cc
        self._include_re = re.compile(r'#include "(.*)"')

    def add_target(self, target_name):
        """Include |target_name| in the amalgamated result."""
        self._dependency_tree.add_dependency(None, target_name)
        self._add_target_dependencies(target_name)
        self._add_target_flags(target_name)
        self._add_target_headers(target_name)

    def _iterate_dep_edges(self, target_name):
        """Yields (target, dependency) pairs reachable from |target_name|.

        The edges below a dependency are yielded before the edge leading to
        that dependency. System libraries and 'action' targets are skipped
        and not descended into.
        """
        target = self.desc[target_name]
        for dep in target.get('deps', []):
            # Ignore system libraries since they will be added as build-time
            # dependencies.
            if dep in system_library_map:
                continue
            # Don't descend into build action dependencies.
            if self.desc[dep]['type'] == 'action':
                continue
            for sub_target, sub_dep in self._iterate_dep_edges(dep):
                yield sub_target, sub_dep
            yield target_name, dep

    def _iterate_target_and_deps(self, target_name):
        """Yields |target_name| followed by its transitive dependencies.

        A dependency which is reachable along several paths is yielded once
        per path.
        """
        yield target_name
        for _, dep in self._iterate_dep_edges(target_name):
            yield dep

    def _add_target_dependencies(self, target_name):
        """Records the dependency edges and system libraries of a target."""
        for target, dep in self._iterate_dep_edges(target_name):
            self._dependency_tree.add_dependency(target, dep)

        def process_dep(dep):
            # Returns True after recording the build settings of |dep| if it
            # is a system library; returns None otherwise.
            if dep in system_library_map:
                self.libs.update(system_library_map[dep].get('libs', []))
                self.cflags.update(system_library_map[dep].get('cflags', []))
                self.defines.update(system_library_map[dep].get('defines', []))
                return True

        def walk_all_deps(target_name):
            target = self.desc[target_name]
            for dep in target.get('deps', []):
                # NOTE(review): returning here stops scanning the remaining
                # deps of this target as soon as one system library has been
                # found. Confirm whether `continue` was intended.
                if process_dep(dep):
                    return
                walk_all_deps(dep)
        walk_all_deps(target_name)

    def _filter_cflags(self, cflags):
        """Returns the whitelisted flags from |cflags| as a list.

        Two-part switches are merged into a single value before filtering.
        """
        # Since we want to deduplicate flags, combine two-part switches (e.g.,
        # "-foo bar") into one value ("-foobar") so we can store the result as
        # a set.
        result = []
        for flag in cflags:
            # A leading value without a preceding switch has nothing to be
            # merged into, so it is kept as its own entry instead of raising
            # an IndexError.
            if flag.startswith('-') or not result:
                result.append(flag)
            else:
                result[-1] += flag
        return apply_whitelist(cflag_whitelist, result)

    def _add_target_flags(self, target_name):
        """Collects the filtered build flags of a target and its deps."""
        for name in self._iterate_target_and_deps(target_name):
            target = self.desc[name]
            self.cflags.update(self._filter_cflags(target.get('cflags', [])))
            self.cflags.update(self._filter_cflags(target.get('cflags_cc', [])))
            self.ldflags.update(
                apply_whitelist(ldflag_whitelist, target.get('ldflags', [])))
            self.libs.update(
                apply_blacklist(lib_blacklist, target.get('libs', [])))
            self.defines.update(
                apply_whitelist(define_whitelist, target.get('defines', [])))

    def _add_target_headers(self, target_name):
        """Expands the .h files listed in a target's sources."""
        target = self.desc[target_name]
        if 'sources' not in target:
            return
        headers = [gn_utils.label_to_path(s)
                   for s in target['sources'] if s.endswith('.h')]
        for header in headers:
            self._add_header(target_name, header)

    def _get_include_dirs(self, target_name):
        """Returns the set of include dirs of a target and its deps."""
        include_dirs = set()
        for name in self._iterate_target_and_deps(target_name):
            target = self.desc[name]
            if 'include_dirs' in target:
                include_dirs.update(
                    [gn_utils.label_to_path(d) for d in target['include_dirs']])
        return include_dirs

    def _add_source_included_header(
            self, include_dirs, allowed_files, header_name):
        """Expands a header included from a .cc file into the source output.

        Each header name is expanded at most once. A header which resolves to
        a file that is not in |allowed_files| is silently skipped.

        Raises:
            Error: if the header isn't found in any of |include_dirs|.
        """
        if header_name in self._processed_source_headers:
            return
        self._processed_source_headers.add(header_name)
        for include_dir in include_dirs:
            full_path = os.path.join(include_dir, header_name)
            if os.path.exists(full_path):
                if full_path not in allowed_files:
                    return
                with open(full_path) as f:
                    self.source.append(
                        '// %s begin header: %s' % (tool_name, full_path))
                    self.source.extend(
                        self._process_source_includes(
                            include_dirs, allowed_files, f))
                return
        msg = 'Looked in %s' % ', '.join('"%s"' % d for d in include_dirs)
        raise Error('Header file %s not found. %s' % (header_name, msg))

    def _add_source(self, target_name, source_name):
        """Expands a .cc file (and the headers it includes) into the output.

        Each source file is expanded at most once.

        Raises:
            Error: if the source file or one of its headers can't be found.
        """
        if source_name in self._processed_sources:
            return
        self._processed_sources.add(source_name)
        # Check for the file before consulting |source_deps| so that a
        # missing file is reported as such rather than as a KeyError.
        if not os.path.exists(source_name):
            raise Error('Source file %s not found' % source_name)
        include_dirs = self._get_include_dirs(target_name)
        deps = self.source_deps[source_name]
        with open(source_name) as f:
            self.source.append(
                '// %s begin source: %s' % (tool_name, source_name))
            try:
                self.source.extend(self._patch_source(source_name,
                    self._process_source_includes(include_dirs, deps, f)))
            except Error as e:
                # str(e) rather than e.message, which Python 3 doesn't have.
                raise Error(
                    'Failed adding source %s: %s' % (source_name, str(e)))

    def _add_header_included_header(self, include_dirs, header_name):
        """Expands a header included from another header into the header.

        Each header name is expanded at most once.

        Raises:
            Error: if the header isn't found in any of |include_dirs|.
        """
        if header_name in self._processed_headers:
            return
        self._processed_headers.add(header_name)
        for include_dir in include_dirs:
            full_path = os.path.join(include_dir, header_name)
            if os.path.exists(full_path):
                with open(full_path) as f:
                    self.header.append(
                        '// %s begin header: %s' % (tool_name, full_path))
                    self.header.extend(
                        self._process_header_includes(include_dirs, f))
                return
        msg = 'Looked in %s' % ', '.join('"%s"' % d for d in include_dirs)
        raise Error('Header file %s not found. %s' % (header_name, msg))

    def _add_header(self, target_name, header_name):
        """Expands a header file listed in a target into the header output.

        Each header is expanded at most once.

        Raises:
            Error: if the header or one of its includes can't be found.
        """
        if header_name in self._processed_headers:
            return
        self._processed_headers.add(header_name)
        include_dirs = self._get_include_dirs(target_name)
        if not os.path.exists(header_name):
            raise Error('Header file %s not found' % header_name)
        with open(header_name) as f:
            self.header.append(
                '// %s begin header: %s' % (tool_name, header_name))
            try:
                self.header.extend(
                    self._process_header_includes(include_dirs, f))
            except Error as e:
                # str(e) rather than e.message, which Python 3 doesn't have.
                raise Error(
                    'Failed adding header %s: %s' % (header_name, str(e)))

    def _patch_source(self, source_name, lines):
        """Returns |lines| with per-file symbol renames applied.

        Args:
            source_name: Path of the source file the lines come from.
            lines: Iterable of source lines.
        """
        result = []
        # Prefix derived from the file name without its last extension; every
        # character other than a-z is replaced by an underscore.
        namespace = re.sub(r'[^a-z]', '_',
            os.path.splitext(os.path.basename(source_name))[0])
        for line in lines:
            # Protobuf generates an identical anonymous function into each
            # message description. Give every occurrence a prefix derived from
            # the source file name to avoid duplicate symbol definitions.
            line = line.replace('MergeFromFail', '%s_MergeFromFail' % namespace)
            result.append(line)
        return result

    def _process_source_includes(self, include_dirs, allowed_files, file):
        """Returns the lines of |file| with #include "..." lines replaced.

        Every quoted include is turned into a comment. Unless it matches
        |includes_to_remove|, the included header is also expanded into the
        source output (before the returned lines are added by the caller).
        """
        result = []
        for line in file:
            line = line.rstrip('\n')
            m = self._include_re.match(line)
            if not m:
                result.append(line)
                continue
            if re.match(includes_to_remove, m.group(1)):
                result.append('// %s removed: %s' % (tool_name, line))
            else:
                result.append('// %s expanded: %s' % (tool_name, line))
                self._add_source_included_header(
                    include_dirs, allowed_files, m.group(1))
        return result

    def _process_header_includes(self, include_dirs, file):
        """Returns the lines of |file| with #include "..." lines replaced.

        Every quoted include is turned into a comment. Unless it matches
        |includes_to_remove|, the included header is also expanded into the
        header output (before the returned lines are added by the caller).
        """
        result = []
        for line in file:
            line = line.rstrip('\n')
            m = self._include_re.match(line)
            if not m:
                result.append(line)
                continue
            if re.match(includes_to_remove, m.group(1)):
                result.append('// %s removed: %s' % (tool_name, line))
            else:
                result.append('// %s expanded: %s' % (tool_name, line))
                self._add_header_included_header(include_dirs, m.group(1))
        return result

    def generate(self):
        """Prepares the output for this amalgamated project.

        Call save() to persist the result.
        """

        source_files = []
        for node in self._dependency_tree.iterate_depth_first():
            target = self.desc[node.target_name]
            if 'sources' not in target:
                continue
            sources = [(node.target_name, gn_utils.label_to_path(s))
                        for s in target['sources'] if s.endswith('.cc')]
            source_files.extend(sources)
        for target_name, source_name in source_files:
            self._add_source(target_name, source_name)

    def _get_nice_path(self, prefix, format):
        """Returns |format| % basename of |prefix|, in a relative directory.

        The directory part of |prefix| is converted with os.path.relpath(),
        i.e. made relative to the current working directory.
        """
        basename = os.path.basename(prefix)
        return os.path.join(
            os.path.relpath(os.path.dirname(prefix)), format % basename)

    def save(self, output_prefix):
        """Save the generated header and source file pair.

        Returns a message describing the output with build instructions.
        """
        header_file = self._get_nice_path(output_prefix, '%s.h')
        source_file = self._get_nice_path(output_prefix, '%s.cc')
        with open(header_file, 'w') as f:
            f.write('\n'.join([preamble] + self.header + ['\n']))
        with open(source_file, 'w') as f:
            include_stmt = '#include "%s"' % os.path.basename(header_file)
            f.write('\n'.join([preamble, include_stmt] + self.source + ['\n']))
        build_cmd = self.get_build_command(output_prefix)

        # The settings are stored in sets; sort them so that the summary is
        # stable between runs and agrees with the example build command.
        return """Amalgamated project written to %s and %s.

Build settings:
 - cflags:    %s
 - ldflags:   %s
 - libs:      %s
 - defines:   %s

Example build command:

%s
""" % (header_file, source_file, ' '.join(sorted(self.cflags)),
       ' '.join(sorted(self.ldflags)), ' '.join(sorted(self.libs)),
       ' '.join(sorted(self.defines)), ' '.join(build_cmd))

    def get_build_command(self, output_prefix):
        """Returns an example command line for building the output source."""
        source = self._get_nice_path(output_prefix, '%s.cc')
        library = self._get_nice_path(output_prefix, 'lib%s.so')
        build_cmd = ['clang++', source, '-o', library, '-shared'] + \
            sorted(self.cflags) + sorted(self.ldflags)
        for lib in sorted(self.libs):
            build_cmd.append('-l%s' % lib)
        for define in sorted(self.defines):
            build_cmd.append('-D%s' % define)
        return build_cmd



def create_amalgamated_project_for_targets(desc, targets, source_deps):
    """Return a generated AmalgamatedProject covering all of |targets|.

    Args:
        desc: JSON build description handed to AmalgamatedProject.
        targets: Iterable of GN target names to add to the project.
        source_deps: Source dependency map handed to AmalgamatedProject.
    """
    amalgamated = AmalgamatedProject(desc, source_deps)
    for gn_target in targets:
        amalgamated.add_target(gn_target)
    amalgamated.generate()
    return amalgamated


def main():
    """Command line entry point.

    Prepares a GN output directory, builds the requested targets, writes the
    amalgamated header/source pair and optionally compiles it. The GN output
    directory is removed afterwards unless --keep is given; with --check the
    generated files are written to a temporary directory which is removed as
    well.
    """
    parser = argparse.ArgumentParser(
        description='Generate an amalgamated header/source pair from a GN '
                    'build description.')
    parser.add_argument(
        '--output',
        help='Base name of files to create. A .cc/.h extension will be added',
        default=os.path.join(gn_utils.repo_root(), 'perfetto'))
    parser.add_argument(
        '--gn_args', help='GN arguments used to prepare the output directory',
        default=gn_args)
    parser.add_argument(
        '--keep', help='Don\'t delete the GN output directory at exit',
        action='store_true')
    parser.add_argument(
        '--build', help='Also compile the generated files',
        action='store_true')
    parser.add_argument(
        '--check', help='Don\'t keep the generated files',
        action='store_true')
    parser.add_argument('--quiet', help='Only report errors',
        action='store_true')
    parser.add_argument(
        'targets',
        nargs=argparse.REMAINDER,
        help='Targets to include in the output (e.g., "//:libperfetto")')
    args = parser.parse_args()
    targets = args.targets or default_targets

    output = args.output
    if args.check:
        output = os.path.join(tempfile.mkdtemp(), 'perfetto_amalgamated')

    # Stays None until the GN output directory is known, so that the cleanup
    # below doesn't raise a NameError (hiding the original exception) when
    # preparing the directory fails.
    out = None
    try:
        if not args.quiet:
            print('Building project...')
        out = gn_utils.prepare_out_directory(
            args.gn_args, 'tmp.gen_amalgamated')
        desc = gn_utils.load_build_description(out)
        # We need to build everything first so that the necessary header
        # dependencies get generated.
        gn_utils.build_targets(out, targets)
        source_deps = gn_utils.compute_source_dependencies(out)
        project = create_amalgamated_project_for_targets(
            desc, targets, source_deps)
        result = project.save(output)
        if not args.quiet:
            print(result)
        if args.build:
            if not args.quiet:
                sys.stdout.write('Building amalgamated project...')
                sys.stdout.flush()
            subprocess.check_call(project.get_build_command(output))
            if not args.quiet:
                print('done')
    finally:
        if out is not None and not args.keep:
            shutil.rmtree(out)
        if args.check:
            shutil.rmtree(os.path.dirname(output))

if __name__ == '__main__':
    sys.exit(main())
