from PIL import Image, ImageDraw

import cairosvg           # pip install cairosvg

import os

from io import BytesIO


import math


# ---- Uitgebreide dataset (ISO-2 code → aantal vermeldingen) ----

orig_landen_data = {

    # Europa

    'nl': 2886, 'al': 4,    'am': 14,   'az': 22,   'be': 95,

    'ba': 8,    'bg': 19,   'cy': 12,   'dk': 17,   'de': 99,

    'ee': 18,   'fi': 2,    'fr': 37,   'ge': 9,    'gr': 56,

    'hu': 38,   'ie': 10,   'is': 4,    'it': 36,   'xk': 6,

    'hr': 7,    'lv': 3,    'lt': 9,    'lu': 7,    'mt': 8,

    'md': 11,   'mc': 1,    'no': 3,    'ua': 150,  'at': 15,

    'pl': 54,   'pt': 4,    'ro': 18,   'ru': 105,  'rs': 12,

    'sk': 2,    'es': 18,   'cz': 2,    'tr': 218,  'gb': 22,

    'by': 13,   'se': 10,   'ch': 7,

    # Afrika

    'dz': 2,    'ao': 6,    'bj': 1,    'bf': 3,    'bi': 18,

    'cf': 1,    'dj': 1,    'eg': 55,   'er': 13,   'et': 25,

    'ga': 1,    'gm': 4,    'gh': 6,    'gn': 11,   'ci': 3,

    'cv': 1,    'cm': 5,    'ke': 20,   'lr': 3,    'ly': 37,

    'mg': 1,    'mw': 6,    'ma': 65,   'mr': 4,    'mu': 1,

    'mz': 20,   'ne': 14,   'ng': 49,   'ug': 6,    'rw': 16,

    'sn': 4,    'sl': 2,    'sd': 39,   'so': 39,   'tz': 6,

    'tg': 1,    'tn': 10,   'eh': 10,   'zm': 2,    'zw': 5,

    'za': 5,

    # Noord-Amerika & Cariben

    'aw': 51,   'bs': 1,    'bm': 2,    'ca': 20,   'cr': 3,

    'cu': 16,   'dm': 2,    'do': 2,    'gt': 8,    'ht': 9,

    'hn': 4,    'jm': 1,    'mx': 8,    'ni': 5,    'pa': 9,

    'us': 53,

    # Zuid-Amerika

    'ar': 6,    'bo': 2,    'br': 19,   'cl': 5,    'co': 14,

    'ec': 1,    'an': 3,    'py': 1,    'pe': 3,    'sr': 36,

    've': 28,   'gs': 28,

    # Azië & Midden-Oosten

    'af': 91,   'bh': 21,   'bd': 21,   'bt': 1,    'bn': 6,

    'kh': 9,    'cn': 129,  'ph': 7,    'hk': 7,    'in': 81,

    'id': 50,   'iq': 96,   'ir': 133,  'il': 307,  'jp': 24,

    'ye': 38,   'jo': 3,    'kg': 1,    'kw': 8,    'la': 1,

    'lb': 12,   'my': 1,    'mn': 2,    'mm': 27,   'np': 7,

    'kp': 20,   'om': 16,   'tl': 1,    'pk': 46,   'ps': 21,

    'qa': 36,   'ru': 106,  'sa': 33,   'sg': 3,    'lk': 9,

    'sy': 148,  'tj': 1,    'tw': 13,   'th': 14,   'tm': 1,

    'ae': 7,    'vn': 9,    'kr': 9,

    # Oceanië

    'au': 10,   'nz': 2,    'vu': 1

}


# ---- Gebruik volledige aantallen per land ----

landen_data = orig_landen_data.copy()


# ---- Grid-parameters ----

cell_size    = 100       # pixels per cel

margin       = 2        # whitespace rondom vlag en stok

flag_size    = cell_size - 2 * margin

columns      = 120       # vlaggen per rij

stick_length = 80       # stoklengte in px

shear_x      = -0.25    # 3D-skew factor


# ---- Pad naar map met SVG vlaggen ----

flags_dir = r"C:\Users\ReMarkt\Desktop\country-flags-main\svg"


# ---- Bouw lijst van ISO2 codes ----

grid_data = []

for iso2, count in sorted(landen_data.items(), key=lambda x: -x[1]):

    grid_data.extend([iso2] * count)


# ---- Bereken rijen en canvasgrootte ----

rows      = math.ceil(len(grid_data) / columns)

img_width = columns * cell_size

img_height= rows * cell_size + stick_length


# ---- Maak achtergrond (off-white) ----

img = Image.new('RGBA', (img_width, img_height), (255, 255, 224, 255))

draw = ImageDraw.Draw(img)


# ---- Teken vlaggen, schuine stokjes en schaduwen ----

for idx, iso2 in enumerate(grid_data):

    row = idx // columns

    col = idx % columns

    x0  = col * cell_size + margin

    y0  = row * cell_size + margin


    svg_path = os.path.join(flags_dir, f"{iso2}.svg")

    if not os.path.exists(svg_path):

        print(f"⚠️ Vlag ontbreekt: {svg_path}")

        continue


    # SVG → PNG → RGBA

    png_data = cairosvg.svg2png(

        url=svg_path,

        output_width=flag_size,

        output_height=flag_size

    )

    flag_img = Image.open(BytesIO(png_data)).convert('RGBA')

    bbox = flag_img.getbbox()

    if bbox:

        flag_img = flag_img.crop(bbox)

    flag_img = flag_img.resize((flag_size, flag_size), Image.LANCZOS)

    skewed = flag_img.transform(

        (flag_size, flag_size),

        Image.AFFINE,

        (1, shear_x, 0, 0, 1, 0),

        resample=Image.BICUBIC,

        fillcolor=(0, 0, 0, 0)

    )

    img.paste(skewed, (x0, y0), skewed)

    x1, y1 = x0, y0

    dy = flag_size + stick_length

    dx = -shear_x * dy

    x2 = int(x1 + dx)

    y2 = y1 + dy

    draw.line([(x1, y1), (x2, y2)], fill='gray', width=2)

    sh_w, sh_h = 8, 3

    shadow = Image.new('RGBA', (sh_w, sh_h), (0, 0, 0, 80))

    img.paste(shadow, (x2-sh_w//2, y2-sh_h//2), shadow)


# ---- Opslaan ----

output_path = 'grid_flags_stronger_skew.png'

img.convert('RGB').save(output_path)

print(f"✅ Afbeelding opgeslagen als {output_path}")
back