import os
from PIL import Image
from ollama import Client
import cv2
import numpy as np
import difflib  # <--- 新增這個庫，用於模糊比對

# -----------------------
# 參數設定
# -----------------------
IMAGE_PATH = "pdf_cache2/holiday.png"
PREPROCESSED_PATH = "pdf_cache2/holiday_preprocessed.png"
OLLAMA_SERVER = "http://140.116.240.181:45015/"
STAGE1_MODEL = "devstral-small-2:latest"
#STAGE2_MODEL = "gpt-oss:120b"
#STAGE2_MODEL = "llama3.3:70b-instruct-q5_K_M"
STAGE2_MODEL = "mistral-small:24b-instruct-2501-q4_K_M"
STREAM_MODE = False
client = Client(host=OLLAMA_SERVER, headers={'Content-Type': 'application/json'})

# -----------------------
# Stage 0: 前處理 (OpenCV 加強版)
# -----------------------
def preprocess_image(image_path, save_path, upscale=1):
    print(f"[Stage 0] 正在處理: {image_path}")

    # 1. 使用 OpenCV 讀取圖片
    img = cv2.imread(image_path)
    
    if img is None:
        print(f"❌ 錯誤：找不到圖片 {image_path}")
        return None

    # 2. 放大 (Upscaling) - 建議放大 2~3 倍
    height, width = img.shape[:2]
    img = cv2.resize(img, (width * upscale, height * upscale), interpolation=cv2.INTER_CUBIC)

    # 3. 轉灰階 (Grayscale)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # 4. 自適應二值化 (Adaptive Thresholding)
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 15, 10
    )

    # 5. 存檔
    cv2.imwrite(save_path, binary)
    
    print(f"   ✅ 影像增強完成 (放大+二值化)，存於: {save_path}")
    return save_path

# -----------------------
# Helper: 智慧滑動視窗切圖 (Smart Chunking)
# -----------------------
def split_image_into_chunks(image_path, chunk_height=800, overlap=400):
    """
    智慧切片：自動偵測行距進行切割，並塗白邊緣防止殘字。
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"❌ 錯誤：無法讀取圖片 {image_path}")
        return []

    height, width = img.shape[:2]
    
    # 用於計算投影的二值化影像
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    ret, binary = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY_INV)

    chunks = []
    y = 0
    chunk_id = 0
    
    base_dir = os.path.dirname(image_path)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    
    print(f"   ✂️ 啟動智慧切片 (目標高度:{chunk_height}, 重疊:{overlap})...")

    while y < height:
        target_cut = min(y + chunk_height, height)
        
        if target_cut == height:
            real_cut = height
        else:
            # 尋找最佳切點 (Safe Cut)
            search_range = 60
            start_search = max(y + int(chunk_height * 0.5), target_cut - search_range)
            end_search = min(target_cut + search_range, height)
            
            if start_search < end_search:
                row_sums = np.sum(binary[start_search:end_search], axis=1)
                min_idx = np.argmin(row_sums)
                real_cut = start_search + min_idx
            else:
                real_cut = target_cut 

        # 執行切割
        chunk = img[y:real_cut, 0:width].copy()
        
        # 邊緣塗白 (Edge Masking)
        mask_height = 20
        if chunk_id > 0 and chunk.shape[0] > mask_height:
            chunk[0:mask_height, :] = [255, 255, 255] # 上緣塗白
            
        if real_cut < height and chunk.shape[0] > mask_height:
            chunk[-mask_height:, :] = [255, 255, 255] # 下緣塗白

        # 存檔
        chunk_filename = os.path.join(base_dir, f"{base_name}_chunk_{chunk_id}.png")
        cv2.imwrite(chunk_filename, chunk)
        chunks.append(chunk_filename)
        
        if real_cut == height:
            break
            
        # 移動視窗 (Next Step)
        y = real_cut - overlap
        chunk_id += 1
        
    print(f"      -> 智慧切片完成，共 {len(chunks)} 張")
    return chunks

# -----------------------
# Helper: 模糊接龍 (Fuzzy Stitching) - 關鍵修正版
# -----------------------
def merge_text_overlap(prev_text, curr_text, min_overlap=20):
    """
    使用 difflib 進行模糊比對，容許 OCR 的小誤差。
    不會誤刪內容，只會剔除高度重疊的部分。
    """
    if not prev_text:
        return curr_text
        
    # 只比較接縫處 (前一段的末尾 vs 新一段的開頭)
    # 增加檢查範圍到 500 字，確保能抓到大段落重疊
    check_len = min(len(prev_text), len(curr_text), 500)
    
    suffix = prev_text[-check_len:]
    prefix = curr_text[:check_len]
    
    # 使用 SequenceMatcher 尋找最長且相似的重疊區塊
    s = difflib.SequenceMatcher(None, suffix, prefix)
    match = s.find_longest_match(0, len(suffix), 0, len(prefix))
    
    # 判定標準：
    # 1. 重疊長度要夠長 (避免誤判常用詞)
    # 2. 必須是從 prefix 的開頭開始 (match.b == 0)，代表新文字的頭部確實跟舊文字重疊
    if match.size >= min_overlap and match.b == 0:
        # print(f"   [接龍] 發現重疊 {match.size} 字，已剔除。")
        # 保留 prev_text，並加上 curr_text 扣掉重疊後的部分
        return prev_text + curr_text[match.size:]
    else:
        # 沒發現足夠的重疊，直接換行接上
        return prev_text + "\n" + curr_text

# -----------------------
# Stage 1: OCR 使用 devstral-small-2
# -----------------------
def stage1_ocr(image_path, stream=True):
    prompt = """
    你是一個 OCR 引擎。
    [任務] 請將圖片中的文字「逐字轉錄」出來。
    [嚴格禁止] 1. 禁止輸出 Markdown 表格。 2. 禁止省略。 3. 禁止摘要。
    [輸出格式] 請以「條列式」或「純文字」逐行輸出。
    """
    chunks = split_image_into_chunks(image_path)
    
    # ✅ 修改點：使用變數來進行文字接龍
    full_text_merged = "" 

    for i, chunk in enumerate(chunks):
        print(f"[Stage 1] 正在辨識切片 {i+1}/{len(chunks)} ...", end=" ", flush=True)

        try:
            response_generator = client.chat(
                model=STAGE1_MODEL,
                options={
                    'num_ctx': 8192,
                    'temperature': 0.1,
                    'repeat_penalty': 1.3,
                    'top_k': 40,
                    'num_predict': 512
                },
                messages=[{'role': 'user', 'content': prompt, 'images': [chunk]}],
                stream=stream
            )

            chunk_buffer = ""
            if stream:
                for segment in response_generator:
                    content = segment['message']['content']
                    print(content, end="", flush=True)
                    chunk_buffer += content
                print("\n")
            else:
                chunk_buffer = response_generator['message']['content']

            full_text_merged = merge_text_overlap(full_text_merged, chunk_buffer)

        except Exception as e:
            print(f"\n   ❌ 錯誤：{e}")
        
        finally:
             if os.path.exists(chunk):
                 os.remove(chunk)

    print("[Stage 1] OCR 全部完成")
    return full_text_merged

# -----------------------
# Stage 2: 結構化 & 修正 (V10 邏輯推論 Prompt)
# -----------------------
def stage2_llm(raw_text):
    print(f"\n[Stage 2] 啟動 {STAGE2_MODEL} 進行邏輯歸納與解構 (V10)...")

    prompt = f"""
    你是一位「高階文件重構專家」。
    任務：將 OCR 識別出的破碎文字，重組為**語意完整**、**邏輯通順**的專業文件。

    【原始 OCR 文字】
    \"\"\"{raw_text}\"\"\"

    【通用重構協議 (Universal Refactoring Protocol)】
    請依序執行以下邏輯。注意：以下提供的範例僅供參考其**思考模式**，請將此邏輯應用於整份文件的所有內容：

    1. **全域去重 (Global Cleanup)**:
       - 掃描全文，若發現重複段落，只保留語意最清晰的一份。
       - 將簡體字轉換為標準繁體中文。

    2. **階層還原與標題上浮 (Hierarchy Restoration & Header Hoisting) - [關鍵修正]**:
       - **現象**：OCR 常先讀取「細則 (1. 2. ...)」，最後才讀取「總標題」。
       - **嚴格禁令**：**絕對禁止**將條列式編號 (如 1., 2. 或 A, B) 直接當作「大標題」。它們必須是內文的縮排項目。
       - **行動 (標題找回)**：
         - 當你看到一組編號條件 (如 1. 未住院... 2. 住院...)，請往該段落的**最底部**或**中間**尋找一個「孤立的名詞」（例如："普通傷病假"、"硬體規格"）。
         - 請務必將該名詞**搬移**到這些編號條件的**正上方**，作為該區塊的唯一大標題。
       - **行動 (解構)**：若「家庭照顧假」的內文黏在「事假」後面，請依語意強制切開，並獨立成一個章節。

    3. **表格語意敘述化 (Semantic Narration)**:
       - **行動**：將表格或條列式的「條件」與「數值」，改寫為**完整的敘述句**。
       - **造句邏輯**：建立「若符合 [條件]，則執行 [結果]」的句型。
       - *[範例]*：原資料 "1年以上, 7日" -> 改寫為 "工作年資滿 1 年以上者，給予休假 7 日。"

    4. **語境邏輯推論 (Contextual Inference) - 核心能力**:
       - 當遇到 OCR 亂碼或語意不通時，請使用以下**策略**進行修復：
       
       - **策略 A：利用定義反推 (Reverse Deduction)**
         - *[思考模式]*：若亂碼後方有「括號說明」或「定義」，請利用該定義推導出前方應有的正確術語。
         - *[範例]*：亂碼 + "(含原位癌)" -> 推導為 "惡性腫瘤" (因為腫瘤包含原位癌)。
         - *[範例]*：亂碼 + "(含CPU與RAM)" -> 推導為 "主機規格"。

       - **策略 B：字形與角色推論 (Shape & Agent Inference)**
         - *[思考模式]*：若主詞是無意義字，請依據文件類型推斷合理的角色。
         - *[範例]*：在勞工規章中看到 "昌工"(Chang-Gong) -> 推論為 "員工"(Yuan-Gong)。
         - *[範例]*：在租賃契約中看到 "房東" 寫成 "房束" -> 推論為 "房東"。

       - **策略 C：合理性檢查 (Plausibility Check)**
         - *[思考模式]*：若句子語意違反常理或目標，請修正為符合邏輯的表達。
         - *[範例]*：寫成 "避免恢復" (負面) -> 修正為 "以免妨礙恢復" (正面)。

    【嚴格輸出規範】
    1. **版面分隔**: 在每個獨立的**「大標題/章節」**之間，務必**空一行**。
    2. **純淨輸出**: 
       - **絕對禁止**使用 Markdown 代碼區塊 (如 ``` )。
       - **絕對禁止**任何開場白。
       - 直接從內容的第一個字開始輸出。

    【開始執行】
    """

    response = client.chat(
        model=STAGE2_MODEL,
        options={
            'num_ctx': 16384, 
            'temperature': 0.1
        },
        messages=[{'role':'user','content':prompt}]
    )
    
    structured_text = response["message"]["content"]
    
    # 雙重保險：清洗 Markdown 與 前後空白
    structured_text = structured_text.replace("```plaintext", "")
    structured_text = structured_text.replace("```", "")
    structured_text = structured_text.strip()

    print("[Stage 2] 文字整理完成")
    print(structured_text)
    return structured_text

# -----------------------
# 主程式流程
# -----------------------
def main():
    # Stage0: 前處理 (建議放大 3 倍以提升 devstral 辨識率)
    preprocessed_img = preprocess_image(IMAGE_PATH, PREPROCESSED_PATH, upscale=3)
    
    # Stage1: OCR (已包含接龍去重)
    raw_text = stage1_ocr(preprocessed_img, stream=STREAM_MODE)
    print("\n[合併後的原始文字]:\n", raw_text)
    print("==================")

    # Stage2: LLM整理 (使用 V10 邏輯推論)
    structured_text = stage2_llm(raw_text)
    
    # 輸出 (選用)
    """
    output_path = "pdf_cache2/holiday_final.txt"
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(structured_text)
    
    print(f"[完成] 最終文字存檔於: {output_path}")
    """

if __name__ == "__main__":
    main()
