import os
from PIL import Image
from ollama import Client
import cv2
import numpy as np

# -----------------------
# 參數設定
# -----------------------
IMAGE_PATH = "pdf_cache2/holiday.png"
PREPROCESSED_PATH = "pdf_cache2/holiday_preprocessed.png"
OLLAMA_SERVER = "http://140.116.240.181:45015/"
STAGE1_MODEL = "devstral-small-2:latest"
#STAGE2_MODEL = "gpt-oss:120b"
#STAGE2_MODEL = "llama3.3:70b-instruct-q5_K_M"
#STAGE2_MODEL = "gemma2:27b-instruct-fp16"
#STAGE2_MODEL = "command-r:35b-08-2024-q5_K_M"
STAGE2_MODEL = "mistral-small:24b-instruct-2501-q4_K_M"
STREAM_MODE = False
client = Client(host=OLLAMA_SERVER, headers={'Content-Type': 'application/json'})

# -----------------------
# Stage 0: 前處理
# -----------------------
# -----------------------
# Stage 0: 前處理 (OpenCV 加強版)
# -----------------------
def preprocess_image(image_path, save_path, upscale=3):
    print(f"[Stage 0] 正在處理: {image_path}")

    # 1. 使用 OpenCV 讀取圖片 (比 PIL 更適合做像素處理)
    # 讀進來就是 numpy array
    img = cv2.imread(image_path)
    
    if img is None:
        print(f"錯誤：找不到圖片 {image_path}")
        return None

    # 2. 放大 (Upscaling)
    # 使用 INTER_CUBIC (雙三次插值) 對文字邊緣更平滑
    height, width = img.shape[:2]
    img = cv2.resize(img, (width * upscale, height * upscale), interpolation=cv2.INTER_CUBIC)

    # 3. 轉灰階 (Grayscale)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # 4. 【關鍵步驟】自適應二值化 (Adaptive Thresholding)
    # 這行是 PIL 做不到的。它會自動把背景變全白，文字變全黑。
    # 參數解釋:
    # 255: 最大值 (白色)
    # ADAPTIVE_THRESH_GAUSSIAN_C: 根據鄰域加權計算閾值 (抗陰影)
    # THRESH_BINARY: 黑白二值化
    # 15: 鄰域大小 (Block Size)，看多大範圍來決定黑白 (奇數)
    # 10: 常數 C，微調閾值 (越小雜訊越多，越大字越細)
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 15, 10
    )

    # 5. (選用) 形態學操作 - 稍微加粗文字 (Dilation)
    # 如果字體太細 (如細明體)，這招很有效
    # kernel = np.ones((2, 2), np.uint8)
    # binary = cv2.erode(binary, kernel, iterations=1) # 腐蝕黑色背景=膨脹白色文字 (OpenCV邏輯)
    # 注意：若是黑字白底，要讓字變粗其實是用 erode (腐蝕白色)

    # 6. 存檔
    cv2.imwrite(save_path, binary)
    
    print(f"影像增強完成 (放大+二值化)，存於: {save_path}")
    return save_path

# -----------------------
# Helper: 智慧滑動視窗切圖 (Smart Chunking + Masking)
# -----------------------
def split_image_into_chunks(image_path, chunk_height=800, overlap=400):
    """
    智慧切片 V2.0：
    1. 自動偵測行與行之間的空白處進行切割，避免腰斬文字。
    2. 自動將切片上下邊緣塗白，防止 OCR 讀到殘字。
    """
    # 讀取圖片
    img = cv2.imread(image_path)
    if img is None:
        print(f"錯誤：無法讀取圖片 {image_path}")
        return []

    height, width = img.shape[:2]
    
    # 準備二值化影像用於計算投影 (文字變白，背景變黑)
    # 這樣我們計算每一行的總和時，數值最小的地方就是空白行
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    ret, binary = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY_INV)

    chunks = []
    y = 0
    chunk_id = 0
    
    # 路徑設定
    base_dir = os.path.dirname(image_path)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    
    print(f"啟動智慧切片 (目標高度:{chunk_height}, 重疊:{overlap})...")

    while y < height:
        # 1. 設定預計切點 (Hard Cut)
        target_cut = min(y + chunk_height, height)
        
        # 如果已經到底了，就直接切
        if target_cut == height:
            real_cut = height
        else:
            # 2. 【核心】尋找最佳切點 (Safe Cut)
            # 在 target_cut 上下 60 pixel 的範圍內，尋找「像素總和最小」的那一行 (即行距)
            search_range = 60
            
            # 確保搜尋範圍不超出邊界，且至少往下走一點
            start_search = max(y + int(chunk_height * 0.5), target_cut - search_range)
            end_search = min(target_cut + search_range, height)
            
            if start_search < end_search:
                # 計算水平投影 (Horizontal Projection)
                # axis=1 代表將每一橫列的像素值加總
                row_sums = np.sum(binary[start_search:end_search], axis=1)
                
                # 找到數值最小的 index (代表該行最乾淨)
                min_idx = np.argmin(row_sums)
                real_cut = start_search + min_idx
            else:
                real_cut = target_cut # 沒得選，只好硬切

        # 3. 執行切割
        chunk = img[y:real_cut, 0:width].copy() # copy 出來以免修改到原圖
        
        # 4. 【核心】邊緣塗白 (Edge Masking)
        # 把切口處最上面和最下面的 20px 塗成純白，防止 OCR 看到上一行或下一行的「半截殘字」
        mask_height = 20
        
        # 如果不是第一張，塗白頂部
        if chunk_id > 0 and chunk.shape[0] > mask_height:
            chunk[0:mask_height, :] = [255, 255, 255] # BGR 白色
            
        # 如果不是最後一張，塗白底部
        if real_cut < height and chunk.shape[0] > mask_height:
            chunk[-mask_height:, :] = [255, 255, 255] # BGR 白色

        # 5. 存檔
        chunk_filename = os.path.join(base_dir, f"{base_name}_chunk_{chunk_id}.png")
        cv2.imwrite(chunk_filename, chunk)
        chunks.append(chunk_filename)
        
        # 6. 判斷是否結束
        if real_cut == height:
            break
            
        # 7. 移動視窗 (Next Step)
        # 下一張的起點 = 當前切點 - 重疊區
        # 確保接縫處的文字，在下一張會完整出現在視窗中央
        y = real_cut - overlap
        chunk_id += 1
        
    print(f"智慧切片完成，共 {len(chunks)} 張")
    return chunks

# -----------------------
# Stage 1: OCR 使用 devstral-small-2
# -----------------------
def stage1_ocr(image_path, stream=True):
    # 這裡用 ollama 直接呼叫 devstral-small-2
    prompt = """
    你是一個 OCR 引擎。
    
    [任務]
    請將圖片中的文字「逐字轉錄」出來。

    [嚴格禁止]
    1. **禁止輸出 Markdown 表格**：不要畫 |---| 這種線。
    2. **禁止省略**：不要只輸出標題。
    3. **禁止摘要**。

    [輸出格式]
    請以「條列式」或「純文字」逐行輸出內容。如果原本是表格，請用空白鍵分隔欄位即可。
    """
    chunks = split_image_into_chunks(image_path)
    all_text_parts = []

    for i, chunk in enumerate(chunks):
        print(f"[Stage 1] 正在辨識切片 {i+1}/{len(chunks)} ...", end=" ", flush=True)

        try:
            # 呼叫 Ollama，根據 stream 參數決定模式
            response_generator = client.chat(
                model=STAGE1_MODEL,
                options={
                    'num_ctx': 8192,
                    'temperature': 0.1,    # 稍微降低，讓它專注於精確度
                    'repeat_penalty': 1.3, # 懲罰重複
                    'top_k': 40,
                    'num_predict': 512
                },
                messages=[{'role': 'user', 'content': prompt, 'images': [chunk]}],
                stream=stream  # 這裡動態傳入 True/False
            )

            chunk_buffer = ""
            
            # --- 模式 A: 串流模式 (Debug) ---
            if stream:
                # 需用迴圈讀取 Generator
                for segment in response_generator:
                    content = segment['message']['content']
                    print(content, end="", flush=True) # 即時印出
                    chunk_buffer += content
                print("\n") # 換行
                
            # --- 模式 B: 非串流模式 (Production) ---
            else:
                # 直接取得完整結果
                chunk_buffer = response_generator['message']['content']
                #print(chunk_buffer)
                #print("完成 (非串流模式)") # 簡單回報即可
            
            # =========================================================
            # 針對整塊 chunk_buffer 進行鬼打牆檢查
            # =========================================================
            
            # 1. 檢查是否發生「逗號分隔的無限重複」 (例如: 雇主支付，雇主支付，雇主支付...)
            if chunk_buffer.count("，") > 10 or chunk_buffer.count(",") > 10:
                mid = len(chunk_buffer) // 2
                # 簡單檢查：如果後半段跟前半段高度相似 (完全重複)
                # 這裡用 strip() 避免空白差異導致比對失敗
                if chunk_buffer[:mid].strip() == chunk_buffer[mid:].strip():
                    print("[自動修正] 偵測到嚴重重複，強制砍半內容。")
                    chunk_buffer = chunk_buffer[:mid]

            # 2. (進階建議) 如果不是完美砍半，而是像 "ABC, ABC, ABC" 這種，可以用這個更強的邏輯：
            # 如果字串很長，且最後 100 個字跟前面某段完全一樣，就截斷
            if len(chunk_buffer) > 200:
                tail = chunk_buffer[-50:] # 取最後 50 個字
                # 如果這最後 50 個字，在前面也出現過，且位置很近
                first_pos = chunk_buffer.find(tail)
                if first_pos != -1 and first_pos < len(chunk_buffer) - 50:
                    # 再次確認這不是巧合
                    print("[自動修正] 偵測到迴圈重複，截斷多餘部分。")
                    chunk_buffer = chunk_buffer[:first_pos + len(tail)]

            # 加入即時去重邏輯 (簡單版)
            # 防止單一切片內出現大量重複行
            unique_lines = []
            prev_line = None
            for line in chunk_buffer.split('\n'):
                clean_line = line.strip()
                if clean_line and clean_line != prev_line:
                    unique_lines.append(line)
                    prev_line = clean_line
            
            final_text = "\n".join(unique_lines)
            all_text_parts.append(final_text)

        except Exception as e:
            print(f"\n   ❌ 錯誤：{e}")
        
        finally:
             if os.path.exists(chunk):
                 os.remove(chunk)

    print("[Stage 1] OCR 全部完成")
    return "\n".join(all_text_parts)

# -----------------------
# Stage 2: 結構化 & 修正使用 GPT-OSS
# -----------------------
def stage2_llm(raw_text):
    prompt = f"""
    你是一位「高階文件重構專家」。
    任務：將 OCR (光學字元辨識) 產生的破碎文字，重組為結構嚴謹、邏輯通順的純文字文件。

    【原始 OCR 文字】
    \"\"\"{raw_text}\"\"\"

    【通用重構協議 (Universal Refactoring Protocol)】
    請依序執行以下邏輯，**不依賴特定文件類型**，而是依據「文件結構原理」：

    1. **結構歸位 (Structural Realignment)**:
       - **現象**：OCR 常因排版問題，先讀取到「條列內容」(List Items)，最後才讀取到該段落的「大標題」(Header)。
       - **行動**：請分析全文。若發現有「孤兒內容」（如 1. ... 2. ...）漂浮在某個「標題」的**上方**，且語意屬於該標題。
       - **執行**：務必將該「標題」**搬移**到這些內容的**最上方**，建立正確的階層關係。

    2. **視覺與語境除錯 (Visual & Contextual Denoising)**:
       - **原理**：OCR 錯誤通常源於「字形相似」但「語意不通」。
       - **行動**：遇到不通順的詞彙時，請執行「視覺聯想」與「邏輯校正」。
         - *形近字修復*：若看到「辯理」(Bian) 但語境是行政流程，修正為「辦理」(Ban)；若看到「昌工/雷工」，修正為「員工」。
         - *語意修復*：若看到「配偶女性產檢」，依常識修正為「配偶妊娠產檢」。
         - *邏輯修復*：若看到「避免...恢復」(負面語意)，依常理修正為「以免妨礙...恢復」(正面語意)。

    3. **表格平面化 (Table Linearization)**:
       - 將表格結構改寫為通順的敘述句。
       - 確保「條件」(Condition) 與「結果」(Result) 緊密相連，不可斷裂。
       - 刪除重複出現的頁首、頁尾或文件名稱。

    【嚴格輸出規範 (Strict Output Rules)】
    1. **版面分隔**: 在每個獨立的**「標題」、「章節」、「條款」或「語意區塊」**之間，務必**空一行**。
    2. **數值完整**: 嚴格保留所有數值條件（如天數、金額、百分比），不得摘要。
    3. **純淨輸出**: 
       - **絕對禁止**使用 Markdown 代碼區塊 (如 ```plaintext 或 ```)。
       - **絕對禁止**輸出任何開場白 (如 "好的"、"修正如下") 或結語。
       - 直接從內容的第一個字開始輸出。

    【開始執行】
    """
    """
    prompt = f
    你是一位「高階文件重構專家」。
    你的任務是將 OCR 產生的混亂文字，重組為結構嚴謹、邏輯通順的規範文件。

    【原始 OCR 文字】
    \"\"\"{raw_text}\"\"\"

    【通用重構協議 (Universal Refactoring Protocol)】
    請依序執行以下邏輯，**不依賴特定關鍵字**，而是依據文件結構原理：

    1. **標題上浮邏輯 (Header Hoisting) - 解決結構錯置**:
       - **現象**：OCR 常因為排版問題，先讀取到「細則內容」，最後才讀到該段落的「大標題」。
       - **行動**：請檢視整份文件。若發現有「孤兒段落」（如：未住院...、住院...）漂浮在某個「大標題」（如：普通傷病假）的**上方**，且內容屬於該標題。
       - **執行**：請務必將該「大標題」**搬移**到這些細則的**最上方**，成為統攝性的章節。

    2. **主體合理性檢查 (Agent Consistency)**:
       - **原理**：勞動法規的主詞通常是「員工」、「受僱者」。
       - **行動**：若看到「昌工」、「雷工」等無意義詞彙，請依據常識修正為「**員工**」。
       - **行動**：若看到「婦女」，為求用語一致，建議統一修復為「**女性員工**」。

    3. **醫療與語意去噪 (Semantic Denoising)**:
       - **原理**：OCR 遇到複雜筆畫會產生亂碼。
       - **行動 (醫療)**：若在「疾病/治療」的語境下看到亂碼（如：候症癥症、疑意、肺結），且後方括號有（含原位癌），請利用醫學常識推斷，修正為標準術語「**惡性腫瘤**」或「**癌症**」。
       - **行動 (邏輯)**：若句子語意違反常理（如：避免產後恢復），請修正為正面語意（如：**以免妨礙**產後恢復）。

    4. **表格平面化與歸位**:
       - 將表格內容改寫為通順語句。
       - 確保「工資給付規定」緊跟在該假別的最後。
       - 刪除重複的標題或頁首/頁尾（如重複出現的「國立成功大學...」）。

    【輸出規範】
    1. 每個假別（標題）之間務必**空一行**。
    2. 嚴格保留數值條件，不得摘要。
    3. 直接輸出整理後的內容。

    【開始執行】
    """

    response = client.chat(
        model=STAGE2_MODEL,
        options={'num_ctx':131072,'temperature':0.0},
        messages=[{'role':'user','content':prompt}]
    )
    structured_text = response["message"]["content"]
    print("[Stage 2] 文字整理完成")
    print(structured_text)
    return structured_text

# -----------------------
# 主程式流程
# -----------------------
def main():
    
    # Stage0: 前處理
    preprocessed_img = preprocess_image(IMAGE_PATH, PREPROCESSED_PATH)
    
    # Stage1: OCR
    raw_text = stage1_ocr(preprocessed_img, stream=STREAM_MODE)
    print(raw_text)
    print("==================")

    # Stage2: LLM整理
    structured_text = stage2_llm(raw_text)
    
    # 輸出
    """
    output_path = "pdf_cache2/holiday_final.txt"
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(structured_text)
    
    print(f"[完成] 最終文字存檔於: {output_path}")
    """
if __name__ == "__main__":
    main()
