"""Fetch the package.json blobs listed in zstd-compressed map files.

Each map under ./maps is a zstd-compressed CSV listing content SHA1s. For
every SHA1, the corresponding blob is downloaded from the Software Heritage
S3 mirror, decompressed, and its dependencies/devDependencies are written to
./results/<chunk_idx>/<line_offset>.json.
"""
import concurrent.futures
import json
import logging
import subprocess
import sys
from pathlib import Path
from typing import List

import requests
from tqdm import tqdm

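# Errors from individual downloads are logged to log.txt so they do not
# interleave with the tqdm progress bars on the console.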
logging.basicConfig(
    filename='log.txt',
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s'
)
logger = logging.getLogger()

EXPECTED_NUM_LINES = 40_051_216  # Line count of a map file, calculated by a separate script
SEGMENT_SIZE = 10_000  # Number of SHA1s accumulated before each download pass


def download_package_json(download_url: str) -> dict:
    """Download a zstd-compressed package.json from the given URL and return
    its dependencies and devDependencies (empty dicts on any failure)."""
    try:
        res = requests.get(download_url, timeout=60)  # timeout guards against stalled connections
        res.raise_for_status()
        # Blobs are stored zstd-compressed; decompress via the zstd CLI.
        proc = subprocess.run(['zstd', '-d'], input=res.content,
                              stdout=subprocess.PIPE, check=True)
        package_json = json.loads(proc.stdout)
        return {
            "dependencies": package_json.get('dependencies', {}),
            "devDependencies": package_json.get('devDependencies', {}),
        }
    except Exception as e:
        logger.error(f"Error downloading or parsing package.json at {download_url}: {e}")
        return {"dependencies": {}, "devDependencies": {}}


def process_segment(segment: List[str], chunk_idx: int, line_counter: int) -> int:
    """Download package.json for every SHA1 in the segment and write the
    results to ./results/<chunk_idx>/<line_counter>.json. Returns the line
    counter advanced past this segment."""
    results = []

    # Fan out downloads in batches of 100 SHA1s, up to 20 concurrent requests.
    SEGMENT_BATCH_SIZE = 100
    batches = [segment[i:i + SEGMENT_BATCH_SIZE] for i in range(0, len(segment), SEGMENT_BATCH_SIZE)]
    for batch in tqdm(batches):
        with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
            future_to_sha1 = {
                executor.submit(
                    download_package_json,
                    f"https://softwareheritage.s3.amazonaws.com/content/{item}",
                ): item
                for item in batch
            }
            for future in concurrent.futures.as_completed(future_to_sha1):
                sha1 = future_to_sha1[future]
                try:
                    data = future.result()
                except Exception as exc:
                    # download_package_json swallows its own errors, so this
                    # only catches unexpected executor failures.
                    logger.error(f"{sha1} generated an exception: {exc}")
                else:
                    results.append(data)

    # Write results to file
    chunk_dir = Path(f"./results/{chunk_idx}")
    chunk_dir.mkdir(parents=True, exist_ok=True)
    output_file = chunk_dir / f"{line_counter}.json"
    with output_file.open('w') as f:
        json.dump(results, f)

    return line_counter + len(segment)


def process_map(file_path: str, chunk_idx: int) -> int:
    """Stream a zstd-compressed map file and download the package.json for
    every listed SHA1, resuming from the saved progress offset if one exists.
    Returns the total number of lines processed."""
    chunk_offset = 0

    # Read progress file to resume processing if it exists
    progress_file = Path(f"./progress/{chunk_idx}")
    if progress_file.exists():
        chunk_offset = int(progress_file.read_text().strip())

    # Open zstdcat process
    with subprocess.Popen(['zstdcat', file_path], stdout=subprocess.PIPE) as p:

        # Skip header line
        next(p.stdout)

        # Skip lines that were already processed in a previous run
        if chunk_offset > 0:
            print(f"Skipping {chunk_offset} already-processed lines")

            for _ in range(chunk_offset):
                next(p.stdout)

        # Read line by line, accumulating SHA1s into fixed-size segments.
        # Appending before the flush check fixes a bug where the line read on
        # the iteration that flushed a full segment was silently dropped.
        segment = []
        for line in tqdm(p.stdout, total=EXPECTED_NUM_LINES, initial=chunk_offset):
            # The SHA1 occupies bytes 41..81 of each row (the layout of the
            # preceding 41 bytes is assumed from the original slicing).
            sha1 = line[41:81].decode('utf-8')
            segment.append(sha1)

            if len(segment) >= SEGMENT_SIZE:
                chunk_offset = process_segment(segment, chunk_idx, chunk_offset)
                segment = []

                # Save progress after each completed segment
                progress_file.write_text(str(chunk_offset))

        # Process last segment
        if segment:
            chunk_offset = process_segment(segment, chunk_idx, chunk_offset)

        # Save final progress
        progress_file.write_text(str(chunk_offset))

    return chunk_offset


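# Usage (script name is illustrative):
#   python fetch_package_jsons.py <chunk_idx>
# where <chunk_idx> indexes the sorted list of map files under ./maps.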
if __name__ == "__main__":
    file_paths = sorted(Path('./maps').glob('*'))
    print("Number of maps:", len(file_paths))

    Path("./results").mkdir(parents=True, exist_ok=True)
    Path("./progress").mkdir(parents=True, exist_ok=True)

    chunk_idx = int(sys.argv[1])  # index of the map file to process
    map_file = file_paths[chunk_idx]

    num_content_from_map = process_map(str(map_file), chunk_idx)
    print(f"Number of package.json entries processed from map {map_file}: {num_content_from_map}")