#!/usr/bin/env python3
# Copyright (C) 2026 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Verifies trace_processor plugin dependency graph rules.

Two rules are enforced by inspecting the JSON GN build description:

  1. Targets outside src/trace_processor/plugins/ must not depend on any
     plugin target. The only allowed dependers of plugin targets from
     non-plugin code are the leaf consumer targets in
     ALLOWED_NON_PLUGIN_DEPENDERS (the trace_processor lib/unittests
     and the root :perfetto_benchmarks executable).

  2. plugin -> plugin GN deps must be matched by a #include of the
     dependee plugin's public header in the depender plugin's sources.
     The public header is plugins/<name>/<name>.h. Including the header
     is what forces the C++ Plugin<Self, ...> template parameter to
     resolve.
"""

from __future__ import print_function

import argparse
import os
import re
import sys
from typing import Dict, List, Optional, Set, Tuple

import gn_utils

# Root of the plugin tree: each plugin lives under PLUGIN_PATH_PREFIX<name>/.
PLUGIN_PATH_PREFIX = 'src/trace_processor/plugins/'

# Non-plugin targets that are permitted to depend on plugin targets (rule 1).
# Kept as an ordered tuple because it is also joined into error messages.
ALLOWED_NON_PLUGIN_DEPENDERS = (
    '//src/trace_processor:lib',
    '//src/trace_processor:unittests',
    '//:perfetto_benchmarks',
)

# Matches a quoted #include directive and captures the included path.
# The preprocessor allows whitespace between '#' and 'include' (e.g. the
# `#  include` form some formatters emit inside #if blocks), so accept it too;
# otherwise such an include would be missed and reported as a rule-2 error.
INCLUDE_RE = re.compile(r'^\s*#\s*include\s+"([^"]+)"')


def label_dir_and_name(label: str) -> Tuple[str, str]:
  """Returns the (directory, target name) pair for a GN label.

  The toolchain suffix, if any, is dropped first. For example, //foo/bar:baz
  yields ('foo/bar', 'baz'). A label without an explicit ':name' part uses
  the last path component as its name, so //foo/bar yields ('foo/bar', 'bar').
  """
  stripped = gn_utils.label_without_toolchain(label)
  assert stripped.startswith('//'), stripped
  directory, colon, name = stripped[2:].partition(':')
  if not colon:
    # Implicit target name: GN's shorthand for //dir:<last dir component>.
    name = os.path.basename(directory)
  return directory, name


def plugin_name_for_dir(dir_path: str) -> Optional[str]:
  """Returns the name of the plugin owning |dir_path|, or None.

  src/trace_processor/plugins/<name> and any directory below it map to
  <name>. Directories outside the plugin tree map to None, and so does the
  bare prefix with nothing after it.
  """
  prefix_len = len(PLUGIN_PATH_PREFIX)
  if dir_path.startswith(PLUGIN_PATH_PREFIX) and len(dir_path) > prefix_len:
    # The first path component after the prefix is the plugin name.
    return dir_path[prefix_len:].partition('/')[0]
  return None


def header_included_in_files(header: str, files: Set[str]) -> bool:
  """Tells whether any of |files| contains `#include "<header>"`.

  Args:
    header: the include path to look for, exactly as written between quotes.
    files: file paths relative to the repository root.

  Files that cannot be opened or decoded as UTF-8 are skipped. They never
  count as including |header|.
  """
  for rel_path in files:
    path = os.path.join(gn_utils.repo_root(), rel_path)
    try:
      with open(path, 'r', encoding='utf-8') as src:
        if any(m is not None and m.group(1) == header
               for m in map(INCLUDE_RE.match, src)):
          return True
    except (OSError, UnicodeDecodeError):
      # Best effort: an unreadable file simply contributes no includes.
      pass
  return False


def _plugin_source_files(target: Dict) -> List[str]:
  """Returns repo-relative paths of the .h/.cc files a GN target lists.

  Both `sources` and `public` are consulted. Entries that are not
  source-absolute ('//'-prefixed) are ignored.
  """
  public = target.get('public', [])
  # GN's JSON description reports `"public": "*"` (a string, not a list) when
  # every header in `sources` is public. Iterating that string would produce
  # one-character entries, so treat it as "no additional files".
  if isinstance(public, str):
    public = []
  return [
      gn_utils.label_to_path(s)
      for s in list(target.get('sources', [])) + list(public)
      if s.startswith('//') and s.endswith(('.h', '.cc'))
  ]


def check(desc: Dict) -> List[str]:
  """Checks a GN build description against the plugin dependency rules.

  Args:
    desc: mapping from GN target label (possibly carrying a toolchain suffix)
      to that target's JSON description dict.

  Returns:
    Human-readable error messages: rule-1 violations first, then rule-2
    violations, each group sorted. Empty when both rules hold.
  """
  # (depender label, plugin dep label) pairs that break rule 1.
  rule1_violations: Set[Tuple[str, str]] = set()
  # (depender plugin, dependee plugin) pairs that rule 2 must verify.
  plugin_to_plugin: Set[Tuple[str, str]] = set()
  # Plugin name -> repo-relative .h/.cc files across all of its targets.
  plugin_sources: Dict[str, Set[str]] = {}

  for label, target in desc.items():
    src_dir, _ = label_dir_and_name(label)
    src_label = gn_utils.label_without_toolchain(label)
    src_plugin = plugin_name_for_dir(src_dir)

    if src_plugin is not None:
      plugin_sources.setdefault(src_plugin, set()).update(
          _plugin_source_files(target))

    for dep in target.get('deps', []):
      dep_dir, _ = label_dir_and_name(dep)
      dst_plugin = plugin_name_for_dir(dep_dir)
      if dst_plugin is None:
        continue  # Deps on non-plugin targets are not restricted here.
      if src_plugin is not None:
        # Deps within one plugin are fine; cross-plugin deps go to rule 2.
        if src_plugin != dst_plugin:
          plugin_to_plugin.add((src_plugin, dst_plugin))
      elif src_label not in ALLOWED_NON_PLUGIN_DEPENDERS:
        rule1_violations.add((src_label, gn_utils.label_without_toolchain(dep)))

  errors: List[str] = []
  allowed = ' or '.join(ALLOWED_NON_PLUGIN_DEPENDERS)
  for label, dep in sorted(rule1_violations):
    errors.append(
        f'target "{label}" depends on plugin target "{dep}". '
        f'Only plugin targets and {allowed} may depend on plugin targets.')

  for src, dst in sorted(plugin_to_plugin):
    # The dependee's public header is plugins/<dst>/<dst>.h.
    expected = f'{PLUGIN_PATH_PREFIX}{dst}/{dst}.h'
    if not header_included_in_files(expected, plugin_sources.get(src, set())):
      errors.append(
          f'plugin "{src}" GN-depends on plugin "{dst}" but no source '
          f'file in {PLUGIN_PATH_PREFIX}{src}/ includes "{expected}". '
          f'plugin->plugin deps must be matched by a Plugin<{src}, {dst}> '
          f'template parameter, which requires including the dependee '
          f"plugin's public header.")

  return errors


def main() -> int:
  """Loads the GN build description, checks it and reports violations.

  With --out, the description of that existing out directory is used.
  Otherwise a temporary one is generated. Every violation is printed to
  stderr on its own line.

  Returns:
    1 if any rule is violated, 0 otherwise.
  """
  parser = argparse.ArgumentParser(description=__doc__)
  parser.add_argument(
      '--out',
      help='use an existing GN out directory instead of creating a temp one.')
  out_dir = parser.parse_args().out

  if not out_dir:
    with gn_utils.BuildDescription('') as build:
      errors = check(build.desc)
  else:
    errors = check(gn_utils.load_build_description(out_dir))

  if errors:
    for err in errors:
      print(err, file=sys.stderr)
    return 1
  return 0


if __name__ == '__main__':
  # main() returns the process exit status: 0 on success, 1 on violations.
  raise SystemExit(main())
