# -*- coding: utf-8 -*-
# ==============================================================================
# Buddhist NMT System with Enhanced EDA and Compatibility Fixes
# VERSION 4.2 - Fixes EDA errors (NameError, TF-IDF, Collocations),
# Training Blocker (TensorBoard), Font Rendering Issues.
# ==============================================================================
# --- 1.1. Standard Libraries ---
import json
import os
import logging
import time
import random
import re # For regex splitting and word finding
import datetime
import base64
# <<< MODIFIED v4.2 >>> Added io for BytesIO
from io import BytesIO
from collections import Counter
from pathlib import Path # Use pathlib for paths
import itertools
import math
import sys
import warnings # To suppress specific warnings if needed
# --- 1.2. Core ML/NLP Libraries ---
import torch
from torch.utils.data import Dataset
# <<< MODIFIED >>> Import specific exception for version check later if needed
# Also import version
from transformers import MarianMTModel, MarianTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import __version__ as transformers_version
# <<< MODIFIED >>> Use parse_version for reliable comparison
try:
    from packaging.version import parse as parse_version
except ImportError:
    # Fallback: pkg_resources is deprecated in newer setuptools but still widespread
    from pkg_resources import parse_version
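# Hedged compatibility check (illustrative): BuddhistDataset below tokenizes targets via
# `text_target=`, which (to the author's knowledge) appeared around transformers 4.22;
# the exact threshold is an assumption.
if parse_version(transformers_version) < parse_version("4.22.0"):
    warnings.warn(f"transformers {transformers_version} may lack tokenizer(text_target=...); "
                  "consider upgrading or tokenizing targets via tokenizer.as_target_tokenizer().")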
import pandas as pd
import numpy as np
# --- 1.3. Enhanced EDA & Visualization Libraries ---
try:
import matplotlib
matplotlib.use('Agg') # Set backend BEFORE importing pyplot
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import seaborn as sns
from wordcloud import WordCloud, STOPWORDS as WORDCLOUD_STOPWORDS
import nltk
from nltk.util import ngrams
from nltk.probability import FreqDist
# <<< MODIFIED v4.2 >>> BigramAssocMeasures/Finder are not used, remove imports
# from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder
from sklearn.feature_extraction.text import TfidfVectorizer
EDA_LIBS_AVAILABLE = True
# Download NLTK data quietly
try: nltk.data.find('tokenizers/punkt')
except LookupError: nltk.download('punkt', quiet=True)
try: nltk.data.find('corpora/stopwords')
except LookupError: nltk.download('stopwords', quiet=True)
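    # Hedged extra check: newer NLTK releases may resolve 'punkt_tab' instead of 'punkt'
    # for word_tokenize; this is a harmless no-op on older versions.
    try: nltk.data.find('tokenizers/punkt_tab')
    except LookupError:
        try: nltk.download('punkt_tab', quiet=True)
        except Exception: pass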
except ImportError as e:
print(f"Warning: EDA/Viz libraries missing: {e}. Features limited.")
EDA_LIBS_AVAILABLE = False
    # Define placeholders so later references fail gracefully instead of raising NameError
    WordCloud, WORDCLOUD_STOPWORDS, nltk, TfidfVectorizer, sns, plt, fm = None, set(), None, None, None, None, None
    ngrams, FreqDist = None, None
# --- 1.4. Jupyter/IPython Specific ---
try:
# Use notebook version for better display
from tqdm.notebook import tqdm as tqdm_notebook
# Check if running in a notebook environment that supports rich display
from IPython import get_ipython
if get_ipython() is not None and 'IPKernelApp' in get_ipython().config:
NOTEBOOK_ENV = True
from IPython.display import display, HTML, Image
tqdm = tqdm_notebook # Use notebook tqdm
print("Notebook environment detected. Enabling rich outputs.")
else:
NOTEBOOK_ENV = False
from tqdm import tqdm # Fall back to regular tqdm
display, HTML, Image = print, lambda x: print(x), None # Fallback display
print("Non-notebook environment detected. Rich outputs disabled.")
# Configure pandas display options
pd.set_option('display.max_rows', 100); pd.set_option('display.max_columns', 50)
pd.set_option('display.width', 1000); pd.set_option('display.max_colwidth', 150)
except ImportError:
from tqdm import tqdm # Fall back to regular tqdm
NOTEBOOK_ENV = False
display, HTML, Image = print, lambda x: print(x), None
print("Warning: IPython/ipywidgets not found. Rich outputs disabled.")
# --- 1.5. Other Libraries ---
import difflib
try:
# Check if jieba is installed and initialize
import jieba
    try: jieba.setLogLevel(logging.INFO) # Silence jieba's DEBUG-level dictionary-build messages
except AttributeError: pass
# <<< MODIFIED v4.2 >>> Quieter initialization test
    try: _ = list(jieba.cut("测试初始化", HMM=False)) # Quick init test (HMM=False keeps it fast)
except Exception as jieba_init_err: print(f"Warning: Jieba installed but failed to initialize: {jieba_init_err}"); raise
JIEBA_AVAILABLE = True
print("Jieba initialized successfully.")
except ImportError:
# <<< MODIFIED >>> Clearer warning message
print("\n *** Warning: 'jieba' package not found. ***")
print(" Chinese word segmentation during EDA will use character-based splitting.")
print(" For better Chinese analysis, please install it: pip install jieba\n")
JIEBA_AVAILABLE = False
jieba = None
# <<< MODIFIED v4.2 >>> Removed redundant jieba init failure check here
# --- 1.6. Configuration & Font Handling ---
def find_and_set_chinese_font():
"""Try to find SimHei.ttf locally, common paths, and register it with Matplotlib."""
# <<< MODIFIED v4.2 >>> Added common Linux/Mac paths explicitly
potential_paths = [
'SimHei.ttf', 'simhei.ttf', # Local
'/usr/share/fonts/truetype/wqy/wqy-microhei.ttc', # Linux (WenQuanYi)
'/usr/share/fonts/truetype/simsun/simsun.ttc', # Linux (SimSun)
'/System/Library/Fonts/STHeiti Medium.ttc', # macOS (Heiti)
'/System/Library/Fonts/Supplemental/Songti.ttc', # macOS (Songti)
'/Library/Fonts/Arial Unicode MS.ttf', # macOS/Windows (if installed)
'C:/Windows/Fonts/msyh.ttc', # Windows (YaHei)
'C:/Windows/Fonts/simhei.ttf', # Windows (SimHei)
'C:/Windows/Fonts/simsun.ttc' # Windows (SimSun)
]
found_path = None
# <<< MODIFIED v4.2 >>> Simplified search logic
for font_path_str in potential_paths:
font_path = Path(font_path_str)
if font_path.is_file():
found_path = str(font_path.resolve())
print(f"Found potential Chinese font: {found_path}")
break
if not found_path:
# Try finding system fonts more generally if specific paths fail
if fm:
try:
# Look for fonts containing 'Hei', 'Song', 'Ming' common in Chinese names
system_fonts = fm.findSystemFonts(fontpaths=None, fontext='ttf')
for sys_font in system_fonts:
fname = Path(sys_font).name.lower()
if any(name in fname for name in ['hei', 'song', 'ming', 'kai', 'wqy', 'msyh']):
found_path = sys_font
print(f"Found potential system font: {found_path}")
break
except Exception as e:
print(f"Warning: Error searching system fonts: {e}")
if found_path and plt and fm:
try:
# Check fontManager cache before adding font
# <<< MODIFIED v4.2 >>> More robust check including ttflist
known_font_files = {Path(f).name for f in fm.findSystemFonts(fontpaths=None, fontext='ttf')} | \
{f.name for f in fm.fontManager.ttflist}
found_path_obj = Path(found_path)
if found_path_obj.name not in known_font_files:
try:
fm.fontManager.addfont(found_path)
print(f"Added font to fontManager: {found_path}")
# Rebuilding cache can be slow and sometimes problematic, often not needed interactively
# try: fm.fontManager.findfont(fm.FontProperties(fname=found_path), rebuild_if_missing=True)
# except: pass # Ignore errors during explicit rebuild
except Exception as add_err:
print(f"Warning: Could not add font {found_path} to fontManager: {add_err}")
prop = fm.FontProperties(fname=found_path)
font_name = prop.get_name() # Get the actual font name
# <<< MODIFIED v4.2 >>> Add font to rcParams more reliably
plt.rcParams['font.family'] = 'sans-serif' # Ensure sans-serif is the base family
# Prepend the found font to the sans-serif list
            current_sans_serif = plt.rcParams['font.sans-serif']
            if font_name not in current_sans_serif:
                current_sans_serif.insert(0, font_name) # In-place insert updates rcParams directly
plt.rcParams['axes.unicode_minus'] = False # Ensure minus sign displays correctly
print(f"Attempting to use font '{font_name}' from {found_path}.")
print(f"Current sans-serif list: {plt.rcParams['font.sans-serif']}")
return found_path # Return the path even if registration details are complex
except Exception as e:
print(f"Warning: Failed to fully register font '{found_path}': {e}")
return found_path # Return path even if registration has issues
elif found_path:
print("Found font, but Matplotlib/font_manager unavailable.")
return found_path
else:
print("Warning: No common Chinese font found or registered. Chinese plots may render incorrectly.")
return None
CHINESE_FONT_PATH = find_and_set_chinese_font()
# <<< MODIFIED v4.2 >>> Check backend after potential font changes
if plt:
try:
plt.switch_backend('Agg'); print(f"Matplotlib backend set to: {plt.get_backend()}")
except Exception as e:
print(f"Warning: Could not switch Matplotlib backend to Agg: {e}")
if plt.get_backend(): print(f"Using existing Matplotlib backend: {plt.get_backend()}")
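# Optional smoke test (a sketch; the output file name is an assumption): render a short
# Chinese string with the chosen font to verify glyph coverage before running full EDA.
# if plt and fm and CHINESE_FONT_PATH:
#     _fig, _ax = plt.subplots(figsize=(3, 1))
#     _ax.text(0.5, 0.5, '般若波羅蜜', ha='center', va='center',
#              fontproperties=fm.FontProperties(fname=CHINESE_FONT_PATH))
#     _ax.axis('off'); _fig.savefig('font_smoke_test.png'); plt.close(_fig)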
# --- Logging Configuration ---
def setup_logging(output_dir, level=logging.INFO):
"""Sets up logging to console and file."""
log_dir = Path(output_dir) / "logs"; log_dir.mkdir(parents=True, exist_ok=True)
log_file = log_dir / f"buddhist_nmt_run_{datetime.datetime.now():%Y%m%d_%H%M%S}.log"
# <<< MODIFIED v4.2 >>> Clear root handlers more safely before basicConfig
root_logger = logging.getLogger()
if root_logger.hasHandlers():
for handler in root_logger.handlers[:]:
try: handler.close(); root_logger.removeHandler(handler)
except Exception as e: print(f"Warning: Error removing logging handler: {e}")
logging.basicConfig(level=level, format='%(asctime)s - %(levelname)s - [%(name)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler(log_file, mode='a', encoding='utf-8')])
logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("nltk").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
    # <<< MODIFIED v4.2 >>> Quieter Jieba logging setup: INFO silences jieba's
    # DEBUG-level build messages without hiding warnings/errors
    if level > logging.DEBUG and JIEBA_AVAILABLE and hasattr(jieba, 'setLogLevel'):
        try: jieba.setLogLevel(logging.INFO)
        except Exception: pass # Ignore if it fails
return logging.getLogger()
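# Minimal usage sketch (illustrative; the directory name is an assumption):
#   logger = setup_logging("./nmt_outputs", level=logging.INFO)
#   logger.info("Run configured.")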
# ==============================================================================
# 2. Dataset Definition (Unchanged from v4.1)
# ==============================================================================
class BuddhistDataset(Dataset):
def __init__(self, samples, tokenizer, max_length=128):
if not isinstance(samples, list): raise TypeError(f"Expected list, got {type(samples)}")
if not all(isinstance(s, dict) and 'source' in s and 'target' in s for s in samples): raise ValueError("Items must be dicts with 'source'/'target'.")
self.samples = samples; self.tokenizer = tokenizer; self.max_length = max_length
self.logger = logging.getLogger("BuddhistDataset"); self.logger.info(f"Dataset init with {len(samples)} samples.")
def __len__(self): return len(self.samples)
def __getitem__(self, idx):
if idx >= len(self.samples): raise IndexError(f"Index {idx} out of bounds.")
item = self.samples[idx]; source_text = str(item.get('source', '')); target_text = str(item.get('target', ''))
try:
source_encoding = self.tokenizer(source_text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
# Ensure target tokenization uses the correct context manager if needed (Marian often doesn't require explicit switch)
# Older Transformers might benefit from this:
# with self.tokenizer.as_target_tokenizer():
# target_encoding = self.tokenizer(target_text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
# Newer/Marian usually handles this implicitly based on labels argument in Trainer
target_encoding = self.tokenizer(text_target=target_text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
input_ids = source_encoding['input_ids'].squeeze(0); attention_mask = source_encoding['attention_mask'].squeeze(0); labels = target_encoding['input_ids'].squeeze(0)
labels[labels == self.tokenizer.pad_token_id] = -100 # Mask padding tokens for loss calculation
return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}
except Exception as e:
self.logger.error(f"Tokenize error item {idx}: {e}. Src:'{source_text[:30]}...', Tgt:'{target_text[:30]}...'", exc_info=False)
pad_id = self.tokenizer.pad_token_id if self.tokenizer else 0
# Return dummy tensors on error to avoid crashing dataloader
return {'input_ids': torch.full((self.max_length,), pad_id, dtype=torch.long),
'attention_mask': torch.zeros((self.max_length,), dtype=torch.long),
'labels': torch.full((self.max_length,), -100, dtype=torch.long)}
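# Minimal usage sketch for BuddhistDataset (illustrative; the checkpoint is an assumption
# and not necessarily the model configured elsewhere in this script):
#   tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
#   ds = BuddhistDataset([{"source": "如是我聞", "target": "Thus have I heard"}], tokenizer)
#   item = ds[0] # dict of 'input_ids', 'attention_mask', 'labels' (padding masked to -100)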
# ==============================================================================
# 3. Full Corpus Analysis (Streaming) (With Fixes)
# ==============================================================================
def _get_ngrams(tokens, n):
    """Return the list of n-grams for a token list, or [] if the list is too short."""
    if len(tokens) < n: return []
    return list(ngrams(tokens, n))
def _clean_ngram_for_display(ngram_tuple):
# Handles both tuples of strings and single strings (for unigrams stored as tuples/lists)
try:
if isinstance(ngram_tuple, tuple):
return " ".join(map(str, ngram_tuple))
else:
return str(ngram_tuple) # Assume it's already a string or single element
except Exception:
return str(ngram_tuple) # Fallback
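# Quick illustration of the two helpers above:
#   _get_ngrams(['a', 'b', 'c'], 2)      -> [('a', 'b'), ('b', 'c')]
#   _clean_ngram_for_display(('a', 'b')) -> 'a b'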
# <<< MODIFIED v4.2 >>> Added helper for plotting full corpus ngrams
def _plot_full_ngram_frequencies_helper(freq_results, save_dir, top_n=20, chinese_font_path=None):
""" Plots top N N-gram frequencies from the FULL corpus analysis. """
logger = logging.getLogger("FullCorpusPlots")
if not EDA_LIBS_AVAILABLE or not plt or not sns or not fm: logger.warning("Plot libs unavailable."); return
if not freq_results: logger.warning("No freq results for plotting."); return
logger.info(f"Generating FULL corpus frequency plots (1-2 grams, Top {top_n})...")
save_dir = Path(save_dir); save_dir.mkdir(parents=True, exist_ok=True)
plt.style.use('seaborn-v0_8-whitegrid')
chinese_font_prop = None
if chinese_font_path and Path(chinese_font_path).is_file():
try:
chinese_font_prop = fm.FontProperties(fname=chinese_font_path)
logger.info(f"Using font prop: {chinese_font_path} for full corpus plots")
except Exception as e:
logger.warning(f"FontProp fail for full corpus plots: {e}")
elif chinese_font_path:
logger.warning(f"Specified Chinese font not found: {chinese_font_path}")
else:
logger.warning("No valid Chinese font for full corpus plots.")
lang_map = {"zh_src": ("Chinese Source", "YlGnBu"), "en_tgt": ("English Target", "viridis")}
for lang_key, (lang_name, palette_name) in lang_map.items():
        # Map plot keys to the result-dict keys produced by analyze_full_corpus_streaming
        freqs = freq_results.get("source_freqs" if lang_key == "zh_src" else "target_freqs")
if not freqs:
logger.info(f"No frequency data found for {lang_name} in results.")
continue
is_chinese = 'zh' in lang_key
current_font_prop = chinese_font_prop if is_chinese else None
for n in range(1, 3): # Only plot 1-grams and 2-grams for full corpus to keep it manageable
fig = None
counter = freqs.get(n)
if not counter:
logger.info(f"No data for {lang_name} {n}-grams plot."); continue
try:
fig, ax = plt.subplots(figsize=(12, max(8, top_n * 0.4))) # Adjust height
common_items = counter.most_common(top_n)
if not common_items:
logger.info(f"No common {lang_name} {n}-grams.");
plt.close(fig); continue
labels = [_clean_ngram_for_display(item[0]) for item in common_items]
counts = [item[1] for item in common_items]
df_plot = pd.DataFrame({f'{n}-gram': labels, 'Frequency': counts})
palette = sns.color_palette(palette_name, n_colors=len(df_plot))
sns.barplot(x='Frequency', y=f'{n}-gram', data=df_plot, palette=palette, ax=ax, hue=f'{n}-gram', dodge=False, legend=False)
title = f'Top {top_n} {lang_name} {n}-grams (Full Corpus Sampled)'
                # <<< MODIFIED v4.2 >>> Apply font prop to title (already None for non-Chinese)
                ax.set_title(title, fontproperties=current_font_prop)
                ax.set_xlabel('Frequency')
                ax.set_ylabel(f'{n}-gram')
                # Apply font prop to y-tick labels (set the ticks explicitly first)
                try:
                    # <<< MODIFIED v4.2 >>> Set yticks explicitly before labels
                    ax.set_yticks(ticks=range(len(labels)))
                    ax.set_yticklabels(labels, fontproperties=current_font_prop)
except Exception as e:
logger.warning(f"Could not set font for y-tick labels on {lang_name} {n}-gram plot: {e}")
# Fallback: Use default font for labels if setting font fails
ax.set_yticklabels(labels)
# Add count labels to bars
for i, bar in enumerate(ax.patches):
try:
bar_width = bar.get_width()
# Adjust label position based on bar width to avoid overlap
x_pos = bar_width + (max(counts) * 0.01) if bar_width > 0 else 0.1
y_pos = bar.get_y() + bar.get_height() / 2
ax.text(x_pos, y_pos, f' {counts[i]:,}', va='center', ha='left', fontsize=9)
except IndexError: pass # Ignore if counts list doesn't match patches index
except Exception as label_err: logger.warning(f"Error adding label to bar: {label_err}")
plt.tight_layout()
plot_path = save_dir / f'full_corpus_top_{top_n}_{lang_key}_{n}gram_freq.png'
try:
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
logger.info(f"Saved full corpus plot: {plot_path}")
# Display logic remains the same as in sampled EDA plots
if NOTEBOOK_ENV and Image:
logger.debug(f"Attempting display: {plot_path}")
try: display(Image(filename=str(plot_path))); logger.debug(f"Displayed {plot_path} via filename.")
except Exception as e1:
logger.warning(f"Display(filename) fail: {e1}. Trying Image(data).")
try:
with open(plot_path,"rb") as f: img_bytes = f.read()
display(Image(data=img_bytes)); logger.debug(f"Displayed {plot_path} via data.")
except Exception as e2: logger.error(f"Display(data) fail: {e2}"); print(f"[Info] Plot saved to {plot_path}, display failed.")
elif NOTEBOOK_ENV: logger.warning(f"Cannot display {plot_path}: Image obj unavailable.")
except Exception as e:
logger.error(f"Save/Display plot failed {plot_path}: {e}")
except Exception as plot_err:
logger.error(f"Plotting {lang_name} {n}-g failed: {plot_err}", exc_info=True)
finally:
if fig is not None and plt.fignum_exists(fig.number):
plt.close(fig)
def analyze_full_corpus_streaming(corpus_path, output_dir, chinese_font_path=None, jieba_available=False, sample_rate=1.0, max_ngrams_to_save=10000, word_cloud_max_words=150):
logger = logging.getLogger("FullCorpusAnalysis")
if not (WordCloud and plt and pd and EDA_LIBS_AVAILABLE):
logger.error("Required EDA libraries (WordCloud, Matplotlib, Pandas) are missing. Cannot perform full corpus analysis.")
return None
corpus_path = Path(corpus_path); output_dir = Path(output_dir); output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"--- Starting Full Corpus Streaming Analysis: {corpus_path} ---")
logger.info(f"Saving outputs to: {output_dir}")
if sample_rate < 1.0: logger.info(f"Using sample rate: {sample_rate*100:.1f}%")
if not corpus_path.is_file(): logger.error(f"Corpus file not found: {corpus_path}"); return None
# Initialize frequency counters and stats
src_freqs = {n: Counter() for n in range(1, 5)} # 1-gram to 4-gram for source (Chinese)
tgt_freqs = {n: Counter() for n in range(1, 5)} # 1-gram to 4-gram for target (English)
stats = {"proc": 0, "skip_json": 0, "skip_other": 0, "src_chars": 0, "tgt_chars": 0, "src_tokens": 0, "tgt_tokens": 0, "valid_pairs": 0}
    # Setup stopwords by reusing the lists defined on BuddhistTextAnalyzer (section 4).
    # The forward reference is safe: the class is defined by the time this function runs.
    try:
        temp_analyzer = BuddhistTextAnalyzer("dummy") # Dummy path skips data loading
en_stop = temp_analyzer.english_stopwords - temp_analyzer.buddhist_terms_en
zh_stop = temp_analyzer.chinese_stopwords
logger.info(f"Using {len(en_stop)} EN stopwords and {len(zh_stop)} ZH stopwords/chars.")
except Exception as e:
logger.warning(f"Failed to initialize stopwords via BuddhistTextAnalyzer: {e}. Using basic NLTK/hardcoded lists.")
try: en_stop = set(nltk.corpus.stopwords.words('english')) if nltk else set()
except LookupError: logger.warning("NLTK stopwords not found. Using empty set for EN."); en_stop = set()
zh_stop = set(['的','了','和','是','在','我','有','就','不','也','人','都','說','此','彼',',','。',';','、',':','"','"','?','!','(',')','【','】','《','》','「','」','『','』',' ','\n','\t','之','其','或','亦','方','於','即','皆','因','仍','故','尚','者','曰','云']) # Basic list
    # <<< MODIFIED v4.2 FIX >>> Word-cloud stopwords: union of WordCloud's built-ins and
    # the 'en_stop' defined above (parenthesized to make the precedence explicit)
    en_stop_wc = (set(WORDCLOUD_STOPWORDS) | en_stop) if WORDCLOUD_STOPWORDS else en_stop
total_lines = None
try:
# Estimate total lines for progress bar
logger.info("Estimating total lines in corpus...")
with corpus_path.open('r', encoding='utf-8') as f_count:
total_lines = sum(1 for _ in f_count)
logger.info(f"Corpus contains approximately {total_lines:,} lines.")
except Exception as e:
logger.warning(f"Could not estimate total lines: {e}")
logger.info("Processing corpus stream...")
start_time = time.time()
try:
with corpus_path.open('r', encoding='utf-8') as f:
# Setup tqdm progress bar
pbar_desc = f"Streaming Analysis ({sample_rate*100:.0f}%)"
pbar = tqdm(f, desc=pbar_desc, unit=" lines", total=total_lines, disable=total_lines is None, mininterval=2.0)
for i, line in enumerate(pbar):
stats["proc"] += 1
if random.random() > sample_rate: # Apply sampling
continue
try:
item = json.loads(line.strip())
src = item.get('source_sentence', '')
tgt = item.get('translation', '')
src_tok, tgt_tok = [], []
# Process Source (Chinese)
if isinstance(src, str) and src.strip():
stats["src_chars"] += len(src)
if jieba_available and jieba:
# Use jieba for word segmentation
toks = [w for w in jieba.cut(src) if w.strip() and w not in zh_stop]
else:
# Use character-based segmentation if jieba is not available
toks = [c for c in src if c.strip() and c not in zh_stop]
src_tok = toks
stats["src_tokens"] += len(toks)
# Update n-gram counts for source
for n in range(1, 5):
if len(toks) >= n:
src_freqs[n].update(_get_ngrams(toks, n))
# Process Target (English)
if isinstance(tgt, str) and tgt.strip():
stats["tgt_chars"] += len(tgt)
# Use regex to find words, convert to lowercase
words = re.findall(r'\b[a-zA-Z]{2,}\b', tgt.lower()) # Find words of 2+ letters
toks = [w for w in words if w not in en_stop] # Filter stopwords
tgt_tok = toks
stats["tgt_tokens"] += len(toks)
# Update n-gram counts for target (store ngrams as space-separated strings)
for n in range(1, 5):
if len(toks) >= n:
# Store as strings for easier handling later
ngrams_list = [" ".join(g) for g in _get_ngrams(toks, n)]
tgt_freqs[n].update(ngrams_list)
# Increment valid pair count if both source and target produced tokens
if src_tok and tgt_tok:
stats["valid_pairs"] += 1
# Update progress bar postfix less frequently
if (i + 1) % 50000 == 0:
pbar.set_postfix({"Valid Pairs": f"{stats['valid_pairs']:,}", "Skipped": f"{stats['skip_json'] + stats['skip_other']:,}"}, refresh=True)
except json.JSONDecodeError:
stats["skip_json"] += 1
except Exception as line_err:
# logger.warning(f"Skipping line {i+1} due to error: {line_err}", exc_info=False) # Reduce log noise
stats["skip_other"] += 1
pbar.close() # Ensure progress bar finishes cleanly
processing_time = time.time() - start_time
logger.info(f"Corpus streaming finished in {processing_time:.2f} seconds.")
logger.info(f"Lines processed: {stats['proc']:,}. Skipped JSON: {stats['skip_json']:,}, Other: {stats['skip_other']:,}.")
logger.info(f"Valid source-target pairs found (sampled): {stats['valid_pairs']:,}")
logger.info(f"Total Tokens (sampled): Source(ZH) {stats['src_tokens']:,}, Target(EN) {stats['tgt_tokens']:,}")
# Consolidate results
results = {"stats": stats, "source_freqs": src_freqs, "target_freqs": tgt_freqs}
# --- Save N-gram Frequencies to CSV ---
logger.info("Saving top N-gram frequencies to CSV...")
for lang_prefix, freqs_dict in [("zh_src", src_freqs), ("en_tgt", tgt_freqs)]:
for n, counter in freqs_dict.items():
if counter:
top_ngrams = counter.most_common(max_ngrams_to_save)
if not top_ngrams: continue # Skip if empty
df_ngrams = pd.DataFrame(top_ngrams, columns=['Ngram', 'Frequency'])
# Clean Ngrams for display before saving (handles tuples)
df_ngrams['Ngram'] = df_ngrams['Ngram'].apply(_clean_ngram_for_display)
filename = output_dir / f"full_{lang_prefix}_{n}g_top{max_ngrams_to_save}.csv"
try:
df_ngrams.to_csv(filename, index=False, encoding='utf-8-sig') # Use utf-8-sig for Excel compatibility
logger.info(f"Saved {lang_prefix} {n}-grams: {filename}")
except Exception as e:
logger.error(f"Failed to save {filename}: {e}")
# Display top 10 in notebook if environment detected
if NOTEBOOK_ENV and n <= 2: # Show only top 1-grams and 2-grams
print(f"\n--- Top 10 Full Corpus {lang_prefix} {n}-grams ---")
display(df_ngrams.head(10))
# --- Generate Word Clouds ---
logger.info("Generating word clouds for full corpus analysis...")
# Define colormaps for different n-grams and languages
zh_colormaps = {1: 'tab20c', 2: 'Set2', 3: 'Accent', 4: 'Paired'}
en_colormaps = {1: 'viridis', 2: 'plasma', 3: 'magma', 4: 'cividis'}
# Generate Source (Chinese) Word Clouds
for n, counter in src_freqs.items():
if counter:
# Prepare frequency data for word cloud (needs string keys)
wc_freq_data = {_clean_ngram_for_display(k): v for k, v in counter.items()}
cloud_title = f"Full Corpus (Sampled {sample_rate*100:.0f}%) - Chinese Source {n}-grams"
cloud_filename = output_dir / f"full_zh_src_{n}g_cloud.png"
_plot_word_cloud_helper(wc_freq_data, cloud_title, cloud_filename,
zh_colormaps.get(n, 'tab20c'), 'white',
word_cloud_max_words, chinese_font_path, is_chinese=True)
# Generate Target (English) Word Clouds
for n, counter in tgt_freqs.items():
if counter:
# Frequency data is already string-keyed here
cloud_title = f"Full Corpus (Sampled {sample_rate*100:.0f}%) - English Target {n}-grams"
cloud_filename = output_dir / f"full_en_tgt_{n}g_cloud.png"
                # Apply stopwords only to unigrams; the helper filters them explicitly,
                # since WordCloud ignores its stopwords argument with generate_from_frequencies
                current_stopwords = en_stop_wc if n == 1 else set()
                _plot_word_cloud_helper(counter, cloud_title, cloud_filename,
                                        en_colormaps.get(n, 'viridis'), 'black',
                                        word_cloud_max_words, None, # No specific font needed for English
                                        is_chinese=False,
                                        stopwords=current_stopwords)
# --- Generate N-gram Frequency Plots (Optional) ---
# <<< MODIFIED v4.2 >>> Added call to plot helper
_plot_full_ngram_frequencies_helper(results, output_dir, top_n=20, chinese_font_path=chinese_font_path)
logger.info("--- Full Corpus Streaming Analysis Finished ---")
return results
except Exception as e:
logger.error(f"Error during full corpus streaming analysis: {e}", exc_info=True)
return None
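# Minimal usage sketch (illustrative; paths and sample rate are assumptions):
#   results = analyze_full_corpus_streaming("corpus.jsonl", "./full_corpus_eda",
#                                           chinese_font_path=CHINESE_FONT_PATH,
#                                           jieba_available=JIEBA_AVAILABLE,
#                                           sample_rate=0.1)
#   if results: print(results["stats"])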
def _plot_word_cloud_helper(frequency_data, title, filepath, colormap, background_color, max_words, font_path, is_chinese, stopwords=None, collocations=False):
"""Helper function to generate and save a single word cloud."""
logger = logging.getLogger("WordCloudHelper")
fig = None # Initialize fig to None
# Basic checks
if not WordCloud or not plt:
logger.warning(f"WordCloud/Matplotlib unavailable, skipping cloud: {title}")
return
if not frequency_data:
logger.warning(f"No frequency data provided for word cloud: {title}")
return
filepath = Path(filepath)
font_path_use = font_path if is_chinese and font_path else None
# Check font path validity specifically for Chinese clouds
if is_chinese:
if font_path_use and Path(font_path_use).is_file():
logger.debug(f"Using font {font_path_use} for Chinese word cloud: '{title}'")
elif font_path_use:
logger.warning(f"Specified Chinese font not found: {font_path_use}. WordCloud may fail or render incorrectly for '{title}'.")
font_path_use = None # Fallback to default if path invalid
else:
logger.warning(f"No Chinese font path provided for WordCloud: '{title}'. Using default font.")
    try:
        # Ensure keys are strings. Note: WordCloud.generate_from_frequencies() ignores
        # the stopwords argument, so filter stopwords here explicitly.
        stop_lower = {str(s).lower() for s in stopwords} if stopwords else set()
        freq_dict = {str(k): v for k, v in frequency_data.items() if str(k).lower() not in stop_lower}
        if not freq_dict:
            logger.warning(f"Frequency dictionary is empty after processing for: {title}")
            return
# Create WordCloud object
# <<< MODIFIED v4.2 >>> Added prefer_horizontal for better readability
wc = WordCloud(width=1200, height=600,
background_color=background_color,
font_path=font_path_use, # Pass the validated font path
max_words=max_words,
colormap=colormap,
stopwords=stopwords,
collocations=collocations,
prefer_horizontal=0.95 # Prefer horizontal layout
).generate_from_frequencies(freq_dict)
# Plotting
fig, ax = plt.subplots(figsize=(12, 6))
ax.imshow(wc, interpolation='bilinear')
ax.axis("off")
# <<< MODIFIED v4.2 >>> Apply font properties to title if Chinese
font_prop = fm.FontProperties(fname=font_path_use) if is_chinese and font_path_use and fm else None
ax.set_title(title, fontsize=16, fontproperties=font_prop)
plt.tight_layout(pad=0.5)
plt.savefig(filepath, dpi=150, bbox_inches='tight')
logger.info(f"Word cloud saved: '{filepath}'")
# Display in notebook
if NOTEBOOK_ENV and Image:
try:
# <<< MODIFIED v4.2 >>> Display using filename first, fallback to data
logger.debug(f"Attempting display word cloud: {filepath}")
try:
display(Image(filename=str(filepath)))
logger.debug(f"Displayed {filepath} via filename.")
except Exception as e1:
logger.warning(f"Display WordCloud(filename) fail: {e1}. Trying Image(data).")
try:
with open(filepath,"rb") as f: img_bytes = f.read()
display(Image(data=img_bytes))
logger.debug(f"Displayed {filepath} via data.")
except Exception as e2:
logger.error(f"Display WordCloud(data) fail: {e2}")
except Exception as display_err:
logger.warning(f"Could not display word cloud image {filepath}: {display_err}")
plt.close(fig) # Close the figure to release memory
except Exception as e:
logger.error(f"Failed to generate word cloud for '{title}': {e}", exc_info=False) # Set exc_info=False for less noise unless debugging
# Provide specific hint for font errors
if is_chinese and "font" in str(e).lower():
logger.error(f"Hint: This might be a Chinese font issue. Check font path: '{font_path_use}'. Error details: {e}")
# Ensure figure is closed if an error occurred during plotting/saving
if fig is not None and plt and plt.fignum_exists(fig.number):
plt.close(fig)
# ==============================================================================
# 4. Exploratory Data Analysis (EDA) - Sample Based (With Fixes)
# ==============================================================================
class BuddhistTextAnalyzer:
""" Class for EDA on a SAMPLE of the Buddhist text corpus. """
def __init__(self, corpus_path, max_samples=20000, output_dir_eda=None): # Added output_dir_eda
self.corpus_path = Path(corpus_path); self.max_samples = max_samples; self.logger = logging.getLogger("BuddhistTextAnalyzer")
self.raw_data = []; self.df = pd.DataFrame()
# <<< MODIFIED v4.2 >>> Store output dir within the analyzer
self.output_dir = Path(output_dir_eda) if output_dir_eda else Path("./eda_sampled_outputs")
self.output_dir.mkdir(parents=True, exist_ok=True)
self._setup_stopwords(); self._define_categories_and_terms()
self.word_freq_data = None; self.tfidf_vectorizer_en = None; self.tfidf_matrix_en = None; self.feature_names_en = None; self.tfidf_vectorizer_zh = None; self.tfidf_matrix_zh = None; self.feature_names_zh = None
if self.corpus_path.name != "dummy": self.load_sampled_data()
else: self.logger.info("Analyzer initialized with dummy path.")
def _setup_stopwords(self): # Unchanged from v4.1
self.english_stopwords = set();
if nltk:
try: self.english_stopwords = set(nltk.corpus.stopwords.words('english')); self.english_stopwords.update(['us','say','said','also','like','etc','one','two','three','oh','well']) # Simplified additions
except Exception as e: self.logger.warning(f"NLTK stopword fail: {e}. Using basic."); self.english_stopwords=set(['the','of','and','to','a','in','is','it','that','was'])
self.chinese_stopwords = set(['的','了','和','是','在','我','有','就','不','也','人','都','說','此','彼',',','。',';','、',':','"','"','?','!','(',')','【','】','《','》','「','」','『','』',' ','\n','\t','之','其','或','亦','方','於','即','皆','因','仍','故','尚','者','曰','云']); self.chinese_stopwords.update([str(i) for i in range(10)]); self.chinese_stopwords.update(list('abcdefghijklmnopqrstuvwxyz'))
def _define_categories_and_terms(self): # Unchanged from v4.1
self.buddhist_terms_en = set(['buddha','dharma','sangha','sutra','nirvana','bodhisattva','karma','meditation','enlightenment','mindfulness','emptiness','sunyata','samsara','mantra','zen','arahant','arhat','bhikkhu','bhikshu','stupa','vinaya','abhidharma','prajna','paramita','tathagata','bodhicitta','vajrayana','theravada','mahayana'])
self.categories_en = {"Buddha": ['buddha', 'tathagata'], "Bodhisattva": ['bodhisattva', 'mahasattva'], "Concepts": ['emptiness', 'sunyata', 'prajna', 'paramita', 'nirvana', 'samsara', 'karma', 'wisdom'], "Practice": ['dharma', 'sutra', 'vinaya', 'meditation', 'mindfulness', 'sila', 'dana', 'mantra'], "Persons": ['sangha', 'bhikkhu', 'arhat', 'monk', 'practitioner']}
self.category_keywords_flat_en = {t for terms in self.categories_en.values() for t in terms}; self.term_to_category_en = {t: name for name, terms in self.categories_en.items() for t in terms}
self.categories_zh = {"佛名": ['佛', '如來', '世尊'], "菩薩": ['菩薩', '大士', '摩訶薩'], "概念": ['空', '無我', '無常', '苦', '涅槃', '緣起', '般若', '智慧', '法界', '心'], "法/修": ['法', '經', '律', '論', '戒', '定', '慧', '禪', '三昧', '道'], "人物": ['僧', '比丘', '阿羅漢', '聲聞', '眾生']}
self.category_keywords_flat_zh = {t for terms in self.categories_zh.values() for t in terms}; self.term_to_category_zh = {t: name for name, terms in self.categories_zh.items() for t in terms}
maps = ['Paired', 'Set1', 'Set3', 'tab10', 'Accent', 'Dark2', 'Set2', 'tab20b', 'tab20c']; n_en = len(self.categories_en)
self.category_colormaps_en = {name: maps[i % len(maps)] for i, name in enumerate(self.categories_en.keys())}; self.category_colormaps_zh = {name: maps[(n_en + i) % len(maps)] for i, name in enumerate(self.categories_zh.keys())}
self.category_colormaps_en["Overall"] = 'viridis'; self.category_colormaps_zh["Overall"] = 'tab20c'
def load_sampled_data(self): # Unchanged from v4.1
self.logger.info(f"Loading up to {self.max_samples:,} samples from {self.corpus_path}...");
if not self.corpus_path.is_file(): raise FileNotFoundError(f"Corpus not found: {self.corpus_path}")
loaded, skip_j, skip_o, data = 0, 0, 0, []
try:
with self.corpus_path.open('r', encoding='utf-8') as f:
# Estimate total lines for better tqdm display
total_lines_est = None
            try:
                f.seek(0); total_lines_est = sum(1 for _ in f); f.seek(0)
            except Exception: pass # Ignore if line-count estimation fails
pbar = tqdm(f, desc="Loading sample", unit=" lines", total=self.max_samples if self.max_samples else total_lines_est, disable=self.max_samples is None and total_lines_est is None, leave=False)
for line in pbar:
if self.max_samples and loaded >= self.max_samples: break
try:
item=json.loads(line.strip()); src=item.get('source_sentence'); tgt=item.get('translation')
if isinstance(src,str) and isinstance(tgt,str) and src.strip() and tgt.strip():
src, tgt = src.strip(), tgt.strip(); src_len, tgt_len = len(src), len(tgt)
# Apply basic length filter here too
if not (1 <= src_len <= 1024 and 1 <= tgt_len <= 1024): skip_o += 1; continue
# Calculate word counts (approx for ZH if no Jieba)
tgt_w = len(re.findall(r'\b\w+\b', tgt))
src_w = len(jieba.lcut(src)) if JIEBA_AVAILABLE and jieba else src_len # Use jieba if available, else char count
data.append({'source':src,'target':tgt,'slen':src_len,'tlen':tgt_len,'swords':src_w,'twords':tgt_w,'lratio':tgt_len/src_len if src_len>0 else 0}); loaded+=1
if loaded % 5000 == 0: pbar.set_postfix({"Loaded": f"{loaded:,}"}, refresh=True)
else: skip_o+=1
except json.JSONDecodeError: skip_j+=1
except Exception: skip_o+=1
pbar.close()
self.logger.info(f"Sample load done. Skipped JSON:{skip_j:,}, Other/Length:{skip_o:,}.")
if not data: self.logger.warning("No valid sample data loaded."); self.df=pd.DataFrame()
else: self.df=pd.DataFrame(data); self.logger.info(f"Loaded {len(self.df):,} samples into DataFrame."); self._add_tokenized_columns()
except FileNotFoundError: self.logger.error(f"Corpus file not found: {self.corpus_path}"); raise
except Exception as e: self.logger.error(f"Sample load error: {e}", exc_info=True); self.df=pd.DataFrame()
def _add_tokenized_columns(self): # Mostly unchanged, minor logging adjustment
"""Add tokenized columns for English and Chinese text analysis."""
if self.df.empty or not EDA_LIBS_AVAILABLE:
self.logger.info("Skipping tokenization: DataFrame is empty or EDA libs unavailable.")
return
if 'target' not in self.df.columns or 'source' not in self.df.columns:
self.logger.warning("Skipping tokenization: Missing 'source' or 'target' columns.")
return
self.logger.info("Tokenizing sampled data for EDA...")
# English tokenization
self.df['target'] = self.df['target'].fillna('').astype(str)
en_stop = self.english_stopwords - self.buddhist_terms_en
try:
self.logger.info("Tokenizing English target text...")
tqdm.pandas(desc="Tokenizing EN")
self.df['en_1g'] = self.df['target'].progress_apply(
lambda t: [w for w in nltk.word_tokenize(t.lower())
if w.isalpha() and w not in en_stop and len(w) > 1]
if pd.notna(t) else []
)
self.logger.info("Generating English n-grams (2-4)...")
for n in range(2, 5):
tqdm.pandas(desc=f"Generating EN {n}-grams")
self.df[f'en_{n}g'] = self.df['en_1g'].progress_apply(
lambda tokens: [" ".join(g) for g in _get_ngrams(tokens, n)] if tokens else []
)
except Exception as e:
self.logger.error(f"Error during English tokenization/n-gram generation: {e}", exc_info=True)
# Add empty list columns as placeholders if error occurs
for n in range(1, 5): self.df[f'en_{n}g'] = [[] for _ in range(len(self.df))]
# Chinese tokenization
self.df['source'] = self.df['source'].fillna('').astype(str)
try:
if JIEBA_AVAILABLE and jieba:
self.logger.info("Tokenizing Chinese source text (using Jieba)...")
tqdm.pandas(desc="Tokenizing ZH (Jieba)")
self.df['zh_1g'] = self.df['source'].progress_apply(
lambda t: [w for w in jieba.cut(t)
if w.strip() and w not in self.chinese_stopwords]
if pd.notna(t) else []
)
else:
self.logger.warning("Jieba not available. Using character-based tokenization for Chinese.")
tqdm.pandas(desc="Tokenizing ZH (Chars)")
self.df['zh_1g'] = self.df['source'].progress_apply(
lambda t: [c for c in t
if c.strip() and c not in self.chinese_stopwords]
if pd.notna(t) else []
)
self.logger.info("Generating Chinese n-grams (2-4)...")
for n in range(2, 5):
tqdm.pandas(desc=f"Generating ZH {n}-grams")
# Store ZH n-grams as tuples of characters/words
self.df[f'zh_{n}g'] = self.df['zh_1g'].progress_apply(
lambda tokens: list(_get_ngrams(tokens, n)) if tokens else []
)
except Exception as e:
self.logger.error(f"Error during Chinese tokenization/n-gram generation: {e}", exc_info=True)
for n in range(1, 5): self.df[f'zh_{n}g'] = [[] for _ in range(len(self.df))]
self.logger.info("Tokenization and n-gram generation complete.")
def _analyze_basic_stats(self): # Unchanged from v4.1
"""Analyze basic statistics of the dataset."""
if self.df is None or self.df.empty: self.logger.warning("No sample data for basic stats."); return None
self.logger.info(f"Calculating basic statistics for {len(self.df):,} samples.")
numeric_cols = ['slen', 'tlen', 'swords', 'twords', 'lratio']
valid_cols = [c for c in numeric_cols if c in self.df.columns and pd.api.types.is_numeric_dtype(self.df[c])]
if not valid_cols: self.logger.warning("No valid numeric columns for basic stats."); return None
stats = self.df[valid_cols].describe()
avg_ratio_str = "N/A"
if 'lratio' in valid_cols:
valid_ratios = self.df.loc[self.df['lratio'].notna() & np.isfinite(self.df['lratio']), 'lratio']
avg_ratio_str = f"{valid_ratios.mean():.2f}" if not valid_ratios.empty else "N/A"
summary_df = pd.DataFrame({
"Metric": ["Total Samples", "Avg Length Ratio (Tgt/Src)"],
"Value": [f"{len(self.df):,}", avg_ratio_str]
}).set_index("Metric")
print("\nOverall Summary:")
try: display(summary_df)
except Exception: print(summary_df) # Fallback print
print("\nDetailed Numeric Stats:")
try: display(stats.round(2))
except Exception: print(stats.round(2))
return stats # Return the detailed stats DataFrame
def _analyze_length_distributions_matplotlib(self, save_dir): # Unchanged from v4.1
"""Generate length distribution plots using matplotlib."""
if not EDA_LIBS_AVAILABLE or not plt or not sns or self.df is None or self.df.empty:
self.logger.warning("Plotting libraries unavailable or no data. Skipping length plots.")
return
required_cols = ['slen', 'tlen', 'swords', 'twords', 'lratio']
available_cols = [c for c in required_cols if c in self.df.columns]
if not available_cols: self.logger.warning("Missing required columns for length plots."); return
self.logger.info("Generating length distribution plots...")
save_dir = Path(save_dir); plt.style.use('seaborn-v0_8-whitegrid')
plot_map = {'slen': ('Source Length (chars)', 'Blues'), 'tlen': ('Target Length (chars)', 'Oranges'),
'swords': ('Source Words/Chars', 'Greens'), 'twords': ('Target Words', 'Purples')}
plot_cols = [c for c in plot_map if c in available_cols]
if not plot_cols: return
n_rows = len(plot_cols); fig, axes = plt.subplots(n_rows, 2, figsize=(15, 5 * n_rows), squeeze=False)
fig.suptitle('Sampled Text Length Analysis (Filtered to 99th Percentile)', fontsize=16, y=1.02)
for i, col in enumerate(plot_cols):
title, cmap = plot_map[col]
data_col = self.df[col].dropna()
if data_col.empty: continue
            q99 = data_col.quantile(0.99)
            filtered = data_col[data_col <= q99]
            plot_data = filtered if not filtered.empty else data_col # Ensure plot_data is non-empty
# Histogram
sns.histplot(plot_data, kde=True, ax=axes[i, 0], color=sns.color_palette(cmap)[2], bins=50)
axes[i, 0].set_title(f'{title} Distribution')
axes[i, 0].set_xlabel(''); axes[i, 0].set_ylabel('Frequency')
# Box Plot
sns.boxplot(x=plot_data, ax=axes[i, 1], color=sns.color_palette(cmap)[1])
axes[i, 1].set_title(f'{title} Box Plot')
axes[i, 1].set_xlabel('Length/Count')
plt.tight_layout(rect=[0, 0.03, 1, 0.98]) # Adjust layout
plot_path = save_dir / 'sampled_length_distributions.png'
try:
plt.savefig(plot_path, dpi=150, bbox_inches='tight'); self.logger.info(f"Length distribution plot saved: {plot_path}")
if NOTEBOOK_ENV and Image: display(Image(filename=str(plot_path)))
except Exception as e: self.logger.error(f"Failed to save/display length distribution plot: {e}")
finally: plt.close(fig)
# Ratio plot (separate)
if 'lratio' in available_cols:
fig_r, axes_r = plt.subplots(1, 2, figsize=(14, 5))
fig_r.suptitle('Target/Source Length Ratio (Filtered 2nd-98th Percentile)', fontsize=14, y=1.02)
ratios = self.df.loc[self.df['lratio'].notna() & np.isfinite(self.df['lratio']), 'lratio']
if not ratios.empty:
q02, q98 = ratios.quantile(0.02), ratios.quantile(0.98)
plot_ratios = ratios[(ratios >= q02) & (ratios <= q98)]
if not plot_ratios.empty:
sns.histplot(plot_ratios, kde=True, color='teal', bins=40, ax=axes_r[0]); axes_r[0].set_title('Ratio Distribution'); axes_r[0].set_xlabel('Ratio'); axes_r[0].set_ylabel('Frequency')
sns.boxplot(x=plot_ratios, color='lightseagreen', ax=axes_r[1]); axes_r[1].set_title('Ratio Box Plot'); axes_r[1].set_xlabel('Ratio')
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plot_path_r = save_dir / 'sampled_length_ratio_distribution.png'
try:
plt.savefig(plot_path_r, dpi=150, bbox_inches='tight'); self.logger.info(f"Length ratio plot saved: {plot_path_r}")
if NOTEBOOK_ENV and Image: display(Image(filename=str(plot_path_r)))
except Exception as e: self.logger.error(f"Failed to save/display length ratio plot: {e}")
finally: plt.close(fig_r)
else: plt.close(fig_r) # Close if no valid ratios
# Correlation plot (separate)
if 'slen' in available_cols and 'tlen' in available_cols:
print("\n--- Length Correlation Plot ---")
fig_s, ax_s = plt.subplots(figsize=(8, 6))
# Sample data for correlation plot to avoid overplotting
df_sample = self.df.sample(min(5000, len(self.df)), random_state=42)
q99_x, q99_y = df_sample['slen'].quantile(0.99), df_sample['tlen'].quantile(0.99)
df_plot_s = df_sample[(df_sample['slen'] <= q99_x) & (df_sample['tlen'] <= q99_y)]
if not df_plot_s.empty:
sns.scatterplot(data=df_plot_s, x='slen', y='tlen', alpha=0.4, s=15, ax=ax_s, color='darkgreen')
ax_s.set_title('Source vs Target Length Correlation (Sampled, ≤99th Pctl)')
ax_s.set_xlabel('Source Length (chars)'); ax_s.set_ylabel('Target Length (chars)')
ax_s.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plot_path_s = save_dir / 'sampled_length_correlation.png'
try:
plt.savefig(plot_path_s, dpi=150, bbox_inches='tight'); self.logger.info(f"Correlation plot saved: {plot_path_s}")
if NOTEBOOK_ENV and Image: display(Image(filename=str(plot_path_s)))
except Exception as e: self.logger.error(f"Failed to save/display correlation plot: {e}")
finally: plt.close(fig_s)
else: plt.close(fig_s)
def _analyze_samples(self, n=3, attempts=10): # Unchanged from v4.1
"""Analyze sample texts from different length categories."""
if self.df is None or self.df.empty or 'slen' not in self.df.columns:
self.logger.warning("No sample data or 'slen' column. Skipping sample text analysis.")
return None
self.logger.info("Analyzing sample texts by source length category...")
bins = [0, 50, 100, 150, 200, float('inf')]
labels = ['1-50', '51-100', '101-150', '151-200', '201+']
try:
self.df['len_cat'] = pd.cut(self.df['slen'], bins=bins, labels=labels, right=True, include_lowest=True)
except ValueError as e:
self.logger.error(f"Error creating length categories: {e}. Skipping sample analysis."); return None
cjk_pattern = re.compile(r'[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]+') # CJK chars
samples_list = []; skipped_non_english_target = 0; used_indices = set()
for cat in labels:
df_filtered = self.df[self.df['len_cat'] == cat]
collected_count = 0; attempt_count = 0
available_indices = df_filtered.index.difference(used_indices)
while collected_count < n and attempt_count < attempts and len(available_indices) > 0:
attempt_count += 1
try:
# Choose randomly from available indices for this category
random_idx = random.choice(available_indices.tolist())
used_indices.add(random_idx) # Mark as used globally
available_indices = available_indices.drop(random_idx) # Remove from local available pool
row = df_filtered.loc[random_idx]
target_text = str(row.get('target', ''))
# Check if target text is likely English (no CJK chars) and not an error message
if target_text and not target_text.startswith("Error:") and not cjk_pattern.search(target_text):
samples_list.append({
'Category': cat, 'Source Length': row.get('slen'), 'Target Length': row.get('tlen'),
'Source Text': row.get('source'), 'Target Text': target_text
})
collected_count += 1
else:
skipped_non_english_target += 1
except KeyError: continue # Index might be invalid if df changed unexpectedly
except Exception as e: self.logger.warning(f"Error getting sample for category {cat}: {e}"); break # Stop trying for this cat on error
if 'len_cat' in self.df.columns: self.df.drop(columns=['len_cat'], inplace=True, errors='ignore')
if skipped_non_english_target > 0: self.logger.info(f"Skipped {skipped_non_english_target} samples with non-English target text.")
if not samples_list: self.logger.warning("No suitable English target samples found across categories."); return None
return pd.DataFrame(samples_list)
def _analyze_frequency(self): # Unchanged from v4.1
"""Analyze word frequencies for n-grams (1-4) in both languages."""
if self.df is None or self.df.empty or not EDA_LIBS_AVAILABLE or not nltk:
self.logger.warning("Frequency analysis skipped: No data or missing libraries (nltk).")
return None
required_cols_zh = [f'zh_{n}g' for n in range(1, 5)]; required_cols_en = [f'en_{n}g' for n in range(1, 5)]
if not all(c in self.df.columns for c in required_cols_zh + required_cols_en):
self.logger.warning("Frequency analysis skipped: Missing pre-tokenized n-gram columns.")
self.logger.debug(f"Available columns: {list(self.df.columns)}")
return None
self.logger.info("Analyzing n-gram frequencies (1-4) from sampled data...")
results = {'chinese': {n: Counter() for n in range(1, 5)}, 'english': {n: Counter() for n in range(1, 5)}}
# Aggregate Chinese n-grams
self.logger.info("Aggregating Chinese n-gram frequencies...")
for n in range(1, 5):
col_name = f'zh_{n}g'
try:
# Flatten the list of lists/tuples of ngrams, handling potential None/NaNs
ngrams_list = list(itertools.chain.from_iterable(self.df[col_name].dropna()))
if ngrams_list: results['chinese'][n] = Counter(ngrams_list)
except Exception as e: self.logger.error(f"Error aggregating Chinese {n}-grams: {e}")
# Aggregate English n-grams
self.logger.info("Aggregating English n-gram frequencies...")
for n in range(1, 5):
col_name = f'en_{n}g'
try:
ngrams_list = list(itertools.chain.from_iterable(self.df[col_name].dropna()))
if ngrams_list: results['english'][n] = Counter(ngrams_list)
except Exception as e: self.logger.error(f"Error aggregating English {n}-grams: {e}")
self.word_freq_data = results
return results
def _plot_top_ngram_frequencies(self, freq_results, save_dir, top_n=20): # Mostly unchanged, fontprop fix
""" Plots the top N frequencies for N-grams (1-4) from the SAMPLE. """
if not EDA_LIBS_AVAILABLE or not plt or not sns or not fm: self.logger.warning("Plotting libraries unavailable. Skipping n-gram frequency plots."); return
if not freq_results: self.logger.warning("No frequency results provided for plotting."); return
self.logger.info(f"Generating sampled n-gram frequency plots (1-4 grams, Top {top_n})...")
save_dir = Path(save_dir); plt.style.use('seaborn-v0_8-whitegrid')
chinese_font_prop = None # Initialize font prop
if CHINESE_FONT_PATH and Path(CHINESE_FONT_PATH).is_file():
try:
chinese_font_prop = fm.FontProperties(fname=CHINESE_FONT_PATH)
self.logger.info(f"Using font prop for sampled plots: {CHINESE_FONT_PATH}")
except Exception as e: self.logger.warning(f"FontProp creation failed: {e}")
elif CHINESE_FONT_PATH: self.logger.warning(f"Chinese font path invalid: {CHINESE_FONT_PATH}")
else: self.logger.warning("No Chinese font path for sampled plots.")
for lang, freqs in freq_results.items():
is_chinese = (lang == 'chinese')
lang_name = "Chinese" if is_chinese else "English"
target_or_source = "Source" if is_chinese else "Target"
palette_name = "YlGnBu" if is_chinese else "viridis"
current_font_prop = chinese_font_prop if is_chinese and chinese_font_prop else None # Use only if valid
for n in range(1, 5): # Loop 1-gram to 4-gram
fig = None; counter = freqs.get(n)
if not counter: self.logger.info(f"No data for sampled {lang} {n}-grams plot."); continue
try:
fig, ax = plt.subplots(figsize=(12, max(8, top_n * 0.45))) # Dynamic height
common_items = counter.most_common(top_n)
if not common_items: self.logger.info(f"No common {lang} {n}-grams found."); plt.close(fig); continue
labels = [_clean_ngram_for_display(item[0]) for item in common_items]
counts = [item[1] for item in common_items]
df_plot = pd.DataFrame({f'{n}-gram': labels, 'Frequency': counts})
palette = sns.color_palette(palette_name, n_colors=len(df_plot))
sns.barplot(x='Frequency', y=f'{n}-gram', data=df_plot, palette=palette, ax=ax, hue=f'{n}-gram', dodge=False, legend=False)
title = f'Top {top_n} Sampled {lang_name} {target_or_source} {n}-grams'
# <<< MODIFIED v4.2 >>> Apply fontprop correctly
ax.set_title(title, fontproperties=current_font_prop); ax.set_xlabel('Frequency'); ax.set_ylabel(f'{n}-gram')
# <<< MODIFIED v4.2 >>> Apply fontprop to ytick labels carefully
try:
# Set ticks explicitly first based on the number of labels
ax.set_yticks(ticks=range(len(labels)))
ax.set_yticklabels(labels, fontproperties=current_font_prop)
except Exception as e:
self.logger.warning(f"Could not set font for y-tick labels on {lang} {n}-gram plot: {e}")
ax.set_yticklabels(labels) # Fallback to default font
# Add count labels to bars
for i, bar in enumerate(ax.patches):
try:
bar_width = bar.get_width()
x_pos = bar_width + (max(counts) * 0.01) if bar_width > 0 else 0.1 # Adjust based on max value
y_pos = bar.get_y() + bar.get_height() / 2
ax.text(x_pos, y_pos, f' {counts[i]:,}', va='center', ha='left', fontsize=9)
except IndexError: pass
except Exception as label_err: self.logger.warning(f"Error adding bar label: {label_err}")
plt.tight_layout()
plot_path = save_dir / f'sampled_top_{top_n}_{lang}_{n}gram_freq.png'
# Save/display logic (unchanged from v4.1)
try:
plt.savefig(plot_path, dpi=150, bbox_inches='tight'); self.logger.info(f"Saved plot: {plot_path}")
                    if NOTEBOOK_ENV and Image:
                        self.logger.debug(f"Attempting display: {plot_path}")
                        try: display(Image(filename=str(plot_path))); self.logger.debug(f"Displayed {plot_path} via filename.")
                        except Exception as e1:
                            self.logger.warning(f"Display(filename) fail: {e1}. Trying Image(data).")
                            try:
                                with open(plot_path,"rb") as f: img_bytes = f.read()
                                display(Image(data=img_bytes)); self.logger.debug(f"Displayed {plot_path} via data.")
                            except Exception as e2: self.logger.error(f"Display(data) fail: {e2}"); print(f"[Info] Plot saved to {plot_path}, display failed.")
                    elif NOTEBOOK_ENV: self.logger.warning(f"Cannot display {plot_path}: Image obj unavailable.")
except Exception as e: self.logger.error(f"Save/Display plot failed {plot_path}: {e}")
except Exception as plot_err: self.logger.error(f"Plotting sampled {lang} {n}-g failed: {plot_err}", exc_info=True)
finally:
if fig is not None and plt.fignum_exists(fig.number): plt.close(fig)
_plot_top_frequencies = _plot_top_ngram_frequencies # Keep alias
# <<< MODIFIED v4.2 >>> Refined TF-IDF Analysis
def _analyze_tf_idf(self):
"""Analyze TF-IDF scores using tokenized data, handling potential errors."""
if self.df is None or self.df.empty: self.logger.warning("TF-IDF skipped: No data."); return None
if not EDA_LIBS_AVAILABLE or not TfidfVectorizer: self.logger.warning("TF-IDF skipped: Scikit-learn unavailable."); return None
self.logger.info("Analyzing TF-IDF scores for sampled data...")
results = {}
# --- English TF-IDF ---
if 'en_1g' in self.df.columns:
try:
# Join tokens back into strings for vectorizer, handle empty lists
en_docs = [" ".join(tokens) for tokens in self.df['en_1g'].dropna() if tokens]
if en_docs:
en_stop_list = list(self.english_stopwords - self.buddhist_terms_en)
self.tfidf_vectorizer_en = TfidfVectorizer(max_features=1000, stop_words=en_stop_list)
self.tfidf_matrix_en = self.tfidf_vectorizer_en.fit_transform(en_docs)
self.feature_names_en = self.tfidf_vectorizer_en.get_feature_names_out()
en_mean_scores = np.asarray(self.tfidf_matrix_en.mean(axis=0)).ravel() # More direct way to get means
en_results_df = pd.DataFrame({
'word': self.feature_names_en,
'mean_tfidf': en_mean_scores
}).sort_values(by='mean_tfidf', ascending=False)
results['en'] = en_results_df
self.logger.info(f"Calculated TF-IDF for {len(en_docs)} English documents.")
else:
self.logger.warning("No valid English documents for TF-IDF after processing.")
except ValueError as ve:
# Catch specific 'empty vocabulary' error
if "empty vocabulary" in str(ve).lower():
self.logger.error(f"English TF-IDF failed: Empty vocabulary. Check tokenization and stopword lists. {ve}", exc_info=False)
else:
self.logger.error(f"English TF-IDF failed: {ve}", exc_info=True)
except Exception as e:
self.logger.error(f"An unexpected error occurred during English TF-IDF: {e}", exc_info=True)
else:
self.logger.warning("Skipping English TF-IDF: 'en_1g' column not found.")
# --- Chinese TF-IDF ---
if 'zh_1g' in self.df.columns:
try:
# Join tokens (chars or words) back into strings
zh_docs = [" ".join(map(str, tokens)) for tokens in self.df['zh_1g'].dropna() if tokens] # Ensure tokens are strings
if zh_docs:
                    # For Chinese, stop-word handling depends on segmentation: with Jieba we
                    # have words; without it, single characters. Stopwords were already
                    # filtered during tokenization, so no extra stop_words list is passed here.
                    # Use a word token pattern when Jieba segmented the text, otherwise match
                    # single non-space characters.
                    self.tfidf_vectorizer_zh = TfidfVectorizer(max_features=1000, token_pattern=r"(?u)\b\w+\b" if JIEBA_AVAILABLE else r"(?u)\S")
self.tfidf_matrix_zh = self.tfidf_vectorizer_zh.fit_transform(zh_docs)
self.feature_names_zh = self.tfidf_vectorizer_zh.get_feature_names_out()
zh_mean_scores = np.asarray(self.tfidf_matrix_zh.mean(axis=0)).ravel()
zh_results_df = pd.DataFrame({
'token': self.feature_names_zh, # 'token' might be char or word
'mean_tfidf': zh_mean_scores
}).sort_values(by='mean_tfidf', ascending=False)
results['zh'] = zh_results_df
self.logger.info(f"Calculated TF-IDF for {len(zh_docs)} Chinese documents.")
else:
self.logger.warning("No valid Chinese documents for TF-IDF after processing.")
except ValueError as ve:
if "empty vocabulary" in str(ve).lower():
self.logger.error(f"Chinese TF-IDF failed: Empty vocabulary. Check tokenization/stopword lists. {ve}", exc_info=False)
else:
self.logger.error(f"Chinese TF-IDF failed: {ve}", exc_info=True)
except Exception as e:
self.logger.error(f"An unexpected error occurred during Chinese TF-IDF: {e}", exc_info=True)
else:
self.logger.warning("Skipping Chinese TF-IDF: 'zh_1g' column not found.")
return results if results else None
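    # Usage sketch (illustrative), e.g. from a notebook after load_sampled_data():
    #   tfidf = analyzer._analyze_tf_idf()
    #   if tfidf and 'en' in tfidf: display(tfidf['en'].head(10))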
# <<< MODIFIED v4.2 >>> Fixed Collocation Analysis DataFrame creation
def _analyze_collocations(self):
"""Analyze word collocations (bigrams) in the dataset."""
if self.df is None or self.df.empty: self.logger.warning("Collocations skipped: No data."); return None
# Counter is stdlib and always available; the real dependency is nltk.util.ngrams
if not EDA_LIBS_AVAILABLE: self.logger.warning("Collocations skipped: NLTK unavailable."); return None
self.logger.info("Analyzing word collocations (bigrams) for sampled data...")
results = {}
# --- English Collocations ---
if 'en_1g' in self.df.columns:
try:
# Get all tokens from the 'en_1g' column
all_en_tokens = list(itertools.chain.from_iterable(self.df['en_1g'].dropna()))
if len(all_en_tokens) >= 2:
en_bigrams = list(ngrams(all_en_tokens, 2))
en_counter = Counter(en_bigrams)
if en_counter:
en_common = en_counter.most_common(100)
# Build DataFrame from a list of ((word1, word2), freq) pairs via unpacking
en_results_df = pd.DataFrame(
[(w1, w2, freq) for (w1, w2), freq in en_common],
columns=['word1', 'word2', 'frequency']
)
results['en'] = en_results_df
self.logger.info(f"Found {len(en_counter)} unique English bigrams.")
else: self.logger.warning("No English bigrams found.")
else: self.logger.warning("Not enough English tokens to find bigrams.")
except Exception as e:
self.logger.error(f"Error during English collocation analysis: {e}", exc_info=True)
else:
self.logger.warning("Skipping English collocations: 'en_1g' column not found.")
# --- Chinese Collocations ---
if 'zh_1g' in self.df.columns:
try:
# Get all tokens (chars or words)
all_zh_tokens = list(itertools.chain.from_iterable(self.df['zh_1g'].dropna()))
if len(all_zh_tokens) >= 2:
zh_bigrams = list(ngrams(all_zh_tokens, 2))
zh_counter = Counter(zh_bigrams)
if zh_counter:
zh_common = zh_counter.most_common(100)
# Build DataFrame from ((token1, token2), freq) pairs via unpacking
zh_results_df = pd.DataFrame(
[(t1, t2, freq) for (t1, t2), freq in zh_common],
columns=['token1', 'token2', 'frequency']
)
results['zh'] = zh_results_df
self.logger.info(f"Found {len(zh_counter)} unique Chinese bigrams (token-based).")
else: self.logger.warning("No Chinese bigrams found.")
else: self.logger.warning("Not enough Chinese tokens to find bigrams.")
except Exception as e:
self.logger.error(f"Error during Chinese collocation analysis: {e}", exc_info=True)
else:
self.logger.warning("Skipping Chinese collocations: 'zh_1g' column not found.")
return results if results else None
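# Illustrative sketch (commented out): the bigram counting used above, on toy
# tokens. `ngrams` is nltk.util.ngrams; the DataFrame mirrors the columns used
# for the English results.
#
#   tokens = ["all", "conditioned", "things", "are", "impermanent"]
#   counts = Counter(ngrams(tokens, 2))          # maps (w1, w2) -> frequency
#   top = counts.most_common(3)
#   df = pd.DataFrame([(w1, w2, f) for (w1, w2), f in top],
#                     columns=['word1', 'word2', 'frequency'])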
def _generate_overall_ngram_word_clouds(self, save_dir): # Unchanged from v4.1
"""Generates overall N-gram word clouds for the SAMPLE."""
if self.df is None or self.df.empty: self.logger.warning("Word clouds skipped: No data."); return
if not WordCloud or not plt: self.logger.warning("Word clouds skipped: Libraries unavailable."); return
if not self.word_freq_data:
self.logger.warning("Word frequency data missing. Attempting to generate...")
self._analyze_frequency() # Try to generate it if missing
if not self.word_freq_data: self.logger.error("Word frequency data unavailable. Cannot generate clouds."); return
self.logger.info("Generating overall N-gram word clouds for sampled data..."); save_dir = Path(save_dir)
# Chinese Clouds
zh_freqs = self.word_freq_data.get('chinese', {})
for n, ctr in zh_freqs.items():
if ctr:
wc_freq_data = {_clean_ngram_for_display(k): v for k, v in ctr.items()}
title = f"Sampled Chinese Source {n}-grams"
filename = save_dir / f"sampled_zh_src_{n}g_cloud.png"
cmap = self.category_colormaps_zh.get("Overall", 'tab20c')
self._plot_single_word_cloud(wc_freq_data, title, filename, cmap, 'white', is_chinese=True)
# English Clouds
en_freqs = self.word_freq_data.get('english', {})
wc_stop = set(WORDCLOUD_STOPWORDS or []) | (self.english_stopwords - self.buddhist_terms_en)
for n, ctr in en_freqs.items():
if ctr:
title = f"Sampled English Target {n}-grams"
filename = save_dir / f"sampled_en_tgt_{n}g_cloud.png"
cmap = self.category_colormaps_en.get("Overall", 'viridis')
# Apply stopwords only to unigrams, enable collocations only for unigrams
self._plot_single_word_cloud(ctr, title, filename, cmap, 'black', is_chinese=False,
stopwords=(wc_stop if n == 1 else None),
collocations=(n == 1))
def _generate_categorized_unigram_word_clouds(self, save_dir): # Unchanged from v4.1
"""Generate categorized UNIGRAM word clouds for English and Chinese from the SAMPLE."""
if self.df is None or self.df.empty: self.logger.warning("Cat. clouds skipped: No data."); return
if not WordCloud or not plt: self.logger.warning("Cat. clouds skipped: Libraries unavailable."); return
if not self.word_freq_data:
self.logger.warning("Word frequency data missing for categorized clouds."); self._analyze_frequency();
if not self.word_freq_data: self.logger.error("Word frequency data unavailable. Cannot generate cat. clouds."); return
self.logger.info("Generating categorized unigram word clouds (EN & ZH)..."); save_dir = Path(save_dir)
# --- English Categories ---
en_unigram_freq = self.word_freq_data.get('english', {}).get(1, Counter())
if en_unigram_freq:
self.logger.info("Generating English categorized clouds..."); categorized_counts_en = {name: Counter() for name in self.categories_en}
for token, freq in en_unigram_freq.items():
token_str = str(token) # Ensure token is string
if token_str in self.term_to_category_en:
categorized_counts_en[self.term_to_category_en[token_str]][token_str] += freq
for name, count_data in categorized_counts_en.items():
print(f"\n--- English Category: {name} ---");
if count_data:
safe_name = re.sub(r'[\\/*?:"<>|]', '', name).strip().lower().replace(' ', '_') # Sanitize filename
filename = save_dir / f"sampled_cat_en_{safe_name}_1g_cloud.png"
cmap = self.category_colormaps_en.get(name, 'coolwarm')
self._plot_single_word_cloud(count_data, f"Sampled EN Category: {name}", filename, cmap, 'white', is_chinese=False, max_words=80)
else: print(f"(No terms found for English category '{name}')")
else: self.logger.warning("No English unigram data available for categorized clouds.")
# --- Chinese Categories ---
zh_unigram_freq = self.word_freq_data.get('chinese', {}).get(1, Counter())
if zh_unigram_freq:
self.logger.info("Generating Chinese categorized clouds..."); categorized_counts_zh = {name: Counter() for name in self.categories_zh}
for token_tuple, freq in zh_unigram_freq.items():
# Handle tuple vs single item
token = token_tuple[0] if isinstance(token_tuple, tuple) and len(token_tuple) == 1 else str(token_tuple)
if token in self.term_to_category_zh:
categorized_counts_zh[self.term_to_category_zh[token]][token] += freq
for name, count_data in categorized_counts_zh.items():
print(f"\n--- Chinese Category: {name} ---");
if count_data:
safe_name = re.sub(r'[^\w\s-]', '', name).strip().lower().replace(' ', '_').replace('/','_') # Sanitize more
filename = save_dir / f"sampled_cat_zh_{safe_name}_1g_cloud.png"
cmap = self.category_colormaps_zh.get(name, 'autumn')
self._plot_single_word_cloud(count_data, f"Sampled ZH Category: {name}", filename, cmap, 'white', is_chinese=True, max_words=80)
else: print(f"(No terms found for Chinese category '{name}')")
else: self.logger.warning("No Chinese unigram data available for categorized clouds.")
def _plot_single_word_cloud(self, frequency_counter, title, filepath, colormap='viridis', background_color='black', max_words=100, is_chinese=False, stopwords=None, collocations=False): # Mostly Unchanged, font prop fix
"""Helper function to plot a single word cloud with better font handling."""
if not WordCloud or not plt: self.logger.warning("WordCloud/Plotting unavailable."); return
if not frequency_counter: self.logger.warning(f"No frequency data for cloud: {title}"); return
filepath = Path(filepath)
font_path = CHINESE_FONT_PATH if is_chinese and CHINESE_FONT_PATH else None
font_prop = None # For title
fig = None
if is_chinese:
if font_path and Path(font_path).is_file():
self.logger.debug(f"Using font {font_path} for '{title}'")
try: font_prop = fm.FontProperties(fname=font_path)
except Exception as e: self.logger.warning(f"Failed to create FontProperties for '{font_path}': {e}")
elif font_path:
self.logger.warning(f"Chinese font path invalid, using default: {font_path}"); font_path = None
else:
self.logger.warning(f"No Chinese font path, using default for '{title}'.")
try:
freq_dict = {str(k): v for k, v in frequency_counter.items()}
if not freq_dict: self.logger.warning(f"Empty frequency dict for cloud: {title}"); return
wc = WordCloud(width=800, height=400, background_color=background_color,
font_path=font_path, # Pass path to WordCloud
max_words=max_words, colormap=colormap,
stopwords=stopwords, collocations=collocations, prefer_horizontal=0.9)
wc.generate_from_frequencies(freq_dict)
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(wc, interpolation='bilinear')
ax.axis("off")
# <<< MODIFIED v4.2 >>> Apply font property to title
ax.set_title(title, fontsize=14, fontproperties=font_prop if is_chinese else None)
plt.tight_layout(pad=0.5)
plt.savefig(filepath, dpi=150, bbox_inches='tight')
self.logger.info(f"Cloud saved: '{filepath}'")
if NOTEBOOK_ENV and Image:
try: display(Image(filename=str(filepath)))
except Exception as e: self.logger.warning(f"Display fail {filepath}: {e}")
plt.close(fig)
except Exception as e:
self.logger.error(f"Word cloud generation error for '{title}': {e}", exc_info=False)
if is_chinese and "font" in str(e).lower(): self.logger.error(f"Hint: Chinese font failed. Check path: '{font_path}'.")
if plt and fig is not None and plt.fignum_exists(fig.number): plt.close(fig)
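# Illustrative sketch (commented out): generating a cloud straight from a
# frequency dict, as _plot_single_word_cloud does. The font path is a
# placeholder; pass a real .ttf/.otf path for CJK glyphs.
#
#   freqs = {"dharma": 42, "emptiness": 30, "mind": 18}
#   wc = WordCloud(width=800, height=400, background_color='black',
#                  font_path=None, max_words=100, colormap='viridis')
#   wc.generate_from_frequencies(freqs)
#   fig, ax = plt.subplots(figsize=(10, 5))
#   ax.imshow(wc, interpolation='bilinear'); ax.axis("off")
#   fig.savefig("example_cloud.png", dpi=150, bbox_inches='tight')
#   plt.close(fig)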
# <<< MODIFIED v4.2 >>> Pass self.output_dir to helpers
def run_full_sampled_eda(self, output_dir=None):
"""Run full EDA on the loaded SAMPLE with enhanced visualizations and n-gram analysis."""
if self.df is None or self.df.empty: self.logger.warning("No sample data loaded. Skipping EDA."); return None
# Use the analyzer's output_dir if none is provided
eda_dir = Path(output_dir) if output_dir else self.output_dir
eda_dir.mkdir(parents=True, exist_ok=True)
self.logger.info(f"--- Starting Sampled EDA ({len(self.df):,} samples) ---")
self.logger.info(f"Saving sampled EDA outputs to: {eda_dir}")
start_time = time.time()
results = {} # To store analysis results
# 1. Basic Stats
print("\n" + "="*30 + " 1. Basic Statistics " + "="*30)
self.logger.info("[EDA 1/8] Calculating Basic Stats...")
stats_df = self._analyze_basic_stats()
if stats_df is not None:
try: stats_df.round(2).to_csv(eda_dir / "sampled_stats_summary.csv"); self.logger.info("Saved basic stats summary.")
except Exception as e: self.logger.warning(f"Failed to save basic stats summary: {e}")
print("-" * 80)
# 2. Length Distributions
print("\n" + "="*30 + " 2. Length Distributions " + "="*30)
self.logger.info("[EDA 2/8] Analyzing Length Distributions...")
self._analyze_length_distributions_matplotlib(eda_dir)
print("-" * 80)
# 3. Sample Texts Analysis
print("\n" + "="*30 + " 3. Sample Texts Analysis " + "="*30)
self.logger.info("[EDA 3/8] Analyzing Sample Texts by Length...")
samples_df = self._analyze_samples()
if samples_df is not None and not samples_df.empty:
try: display(samples_df); samples_df.to_csv(eda_dir / "sampled_texts_by_length.csv", index=False, encoding='utf-8-sig'); self.logger.info("Saved sample texts.")
except Exception as e: self.logger.warning(f"Failed to save/display sample texts: {e}")
else: print("No suitable English target samples found.")
print("-" * 80)
# 4. N-gram Frequency Analysis & Plots
print("\n" + "="*30 + " 4. N-gram Frequencies " + "="*30)
self.logger.info("[EDA 4/8] Analyzing N-gram Frequencies...")
freq_res = self._analyze_frequency()
if freq_res:
results['frequency_analysis'] = freq_res
for lang, freqs in freq_res.items():
print(f"\n--- Top Sampled {lang.capitalize()} N-grams ---")
for n, ctr in sorted(freqs.items()):
if ctr:
top_n_display = 15 # Show top 15 in output
common = ctr.most_common(top_n_display)
items = [(_clean_ngram_for_display(i[0]), i[1]) for i in common]
df_f = pd.DataFrame(items, columns=[f'{n}-gram', 'Frequency'])
print(f"\nTop {top_n_display} {n}-grams:"); display(df_f)
# Save full frequency list (up to 5000)
try:
full_list = [(_clean_ngram_for_display(i[0]), i[1]) for i in ctr.most_common(5000)]
df_full = pd.DataFrame(full_list, columns=[f'{n}-gram', 'Frequency'])
df_full.to_csv(eda_dir / f"sampled_{lang}_{n}g_freq_top5000.csv", index=False, encoding='utf-8-sig')
self.logger.info(f"Saved {len(df_full):,} sampled {lang} {n}-gram frequencies.")
except Exception as e: self.logger.warning(f"Failed to save full {lang} {n}-gram frequencies: {e}")
else: print(f"\nNo data for {lang} {n}-grams.")
# Plot frequencies
self._plot_top_ngram_frequencies(freq_res, eda_dir, top_n=20)
else: self.logger.warning("Frequency analysis did not produce results.")
print("-" * 80)
# 5. TF-IDF Analysis
print("\n" + "="*30 + " 5. TF-IDF Analysis " + "="*30)
self.logger.info("[EDA 5/8] Analyzing TF-IDF...")
tfidf_results = self._analyze_tf_idf()
if tfidf_results:
results['tfidf_analysis'] = tfidf_results
if 'en' in tfidf_results:
print("\n--- Top English TF-IDF Scores ---"); display(tfidf_results['en'].head(15))
try: tfidf_results['en'].to_csv(eda_dir / "sampled_tfidf_en.csv", index=False); self.logger.info("Saved English TF-IDF.")
except Exception as e: self.logger.warning(f"Failed to save English TF-IDF: {e}")
if 'zh' in tfidf_results:
print("\n--- Top Chinese TF-IDF Scores ---"); display(tfidf_results['zh'].head(15))
try: tfidf_results['zh'].to_csv(eda_dir / "sampled_tfidf_zh.csv", index=False, encoding='utf-8-sig'); self.logger.info("Saved Chinese TF-IDF.")
except Exception as e: self.logger.warning(f"Failed to save Chinese TF-IDF: {e}")
else: self.logger.warning("TF-IDF analysis did not produce results.")
print("-" * 80)
# 6. Collocation Analysis
print("\n" + "="*30 + " 6. Collocation Analysis " + "="*30)
self.logger.info("[EDA 6/8] Analyzing Collocations...")
coll_results = self._analyze_collocations()
if coll_results:
results['collocation_analysis'] = coll_results
if 'en' in coll_results:
print("\n--- Top English Collocations (Bigrams) ---"); display(coll_results['en'].head(15))
try: coll_results['en'].to_csv(eda_dir / "sampled_collocations_en.csv", index=False); self.logger.info("Saved English collocations.")
except Exception as e: self.logger.warning(f"Failed to save English collocations: {e}")
if 'zh' in coll_results:
print("\n--- Top Chinese Collocations (Bigrams) ---"); display(coll_results['zh'].head(15))
try: coll_results['zh'].to_csv(eda_dir / "sampled_collocations_zh.csv", index=False, encoding='utf-8-sig'); self.logger.info("Saved Chinese collocations.")
except Exception as e: self.logger.warning(f"Failed to save Chinese collocations: {e}")
else: self.logger.warning("Collocation analysis did not produce results.")
print("-" * 80)
# 7. Overall N-gram Word Clouds
print("\n" + "="*30 + " 7. Overall N-gram Word Clouds " + "="*30)
self.logger.info("[EDA 7/8] Generating Overall N-gram Word Clouds...")
self._generate_overall_ngram_word_clouds(eda_dir)
print("-" * 80)
# 8. Categorized Unigram Word Clouds
print("\n" + "="*30 + " 8. Categorized Unigram Word Clouds " + "="*30)
self.logger.info("[EDA 8/8] Generating Categorized Unigram Word Clouds...")
self._generate_categorized_unigram_word_clouds(eda_dir)
print("-" * 80)
elapsed_time = time.time() - start_time
self.logger.info(f"--- Sampled EDA Complete ({elapsed_time:.2f} seconds) ---")
return results
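# Usage sketch (commented out; the analyzer construction is elided because the
# analyzer class and its constructor are defined earlier in this file):
#
#   analyzer = ...  # an EDA analyzer instance with a loaded sample
#   eda_results = analyzer.run_full_sampled_eda(output_dir="./eda_outputs")
#   if eda_results and 'tfidf_analysis' in eda_results:
#       print(eda_results['tfidf_analysis']['en'].head())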
# --- Deprecated / Unused Methods from v4.1 ---
# The following methods were duplicated or superseded by the run_full_sampled_eda
# pipeline and have been removed to avoid confusion and potential errors:
#   _analyze_wordcloud      -> replaced by the _generate_*_word_clouds helpers
#   run_analysis            -> replaced by run_full_sampled_eda
#   _save_analysis_results  -> integrated into run_full_sampled_eda
#   _get_basic_stats        -> integrated into _analyze_basic_stats
#   _analyze_lengths        -> integrated into _analyze_length_distributions_matplotlib
#   _analyze_ngrams         -> integrated into _analyze_frequency
# ==============================================================================
# 5. NMT Model Training and Translation (With Fixes)
# ==============================================================================
class ImprovedBuddhistNMTTrainer:
""" Manages NMT model lifecycle: loading, fine-tuning, translation. """
def __init__(self, base_model_name="Helsinki-NLP/opus-mt-zh-en", output_dir="./buddhist-nmt-finetuned"):
self.base_model_name=base_model_name; self.output_dir=Path(output_dir); self.device=torch.device('cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')); self.logger=logging.getLogger("BuddhistNMT"); self.logger.info(f"Using device: {self.device}")
self.tokenizer=None; self.model=None; self.is_fine_tuned=False; self.base_tokenizer=None; self.base_model_loaded=None
self.output_dir.mkdir(parents=True, exist_ok=True); self._initialize_models()
def _initialize_models(self): # Unchanged from v4.1
self.logger.info("--- Initializing Models ---"); self.is_fine_tuned=False; ft_path=self.output_dir; base_loaded=False
conf_ok=(ft_path/"config.json").is_file(); tok_ok=(ft_path/"tokenizer_config.json").is_file(); model_ok=(ft_path/"model.safetensors").is_file()or(ft_path/"pytorch_model.bin").is_file()
try:
if conf_ok and tok_ok and model_ok:
self.logger.info(f"Loading FT model from {ft_path}");
try:
self.tokenizer=MarianTokenizer.from_pretrained(str(ft_path)) # Use string path
self.model=MarianMTModel.from_pretrained(str(ft_path)) # Use string path
self.logger.info("✅ Fine-tuned model loaded successfully."); self.is_fine_tuned=True
except Exception as e:
self.logger.warning(f"⚠️ Failed to load fine-tuned model from {ft_path}: {e}. Falling back to base model."); self.model,self.tokenizer=None,None; self.is_fine_tuned=False
else: self.logger.info(f"No valid fine-tuned model found at {ft_path}. Will use base model.")
if not self.is_fine_tuned:
self.logger.info(f"Loading base model '{self.base_model_name}' for primary translation task.");
try:
self.tokenizer=MarianTokenizer.from_pretrained(self.base_model_name); self.model=MarianMTModel.from_pretrained(self.base_model_name); self.logger.info("✅ Base model loaded for primary task."); base_loaded=True
except Exception as e: self.logger.error(f"Failed to load base model '{self.base_model_name}': {e}",exc_info=True); raise
# Move primary model to device
if self.model: self.model=self.model.to(self.device); self.model.eval()
# Load base model for comparison if needed
if base_loaded:
# If the primary model *is* the base model, reuse it for comparison
self.logger.info("Reusing loaded base model for comparison."); self.base_tokenizer=self.tokenizer; self.base_model_loaded=self.model
else:
# Load base model specifically for comparison if FT model was loaded or base loading failed initially
self.logger.info(f"Loading base model '{self.base_model_name}' specifically for comparison.");
try:
self.base_tokenizer=MarianTokenizer.from_pretrained(self.base_model_name); self.base_model_loaded=MarianMTModel.from_pretrained(self.base_model_name); self.base_model_loaded=self.base_model_loaded.to(self.device); self.base_model_loaded.eval(); self.logger.info("✅ Base comparison model loaded successfully.")
except Exception as e:
self.logger.error(f"Failed to load base model for comparison: {e}",exc_info=True); self.logger.warning("Comparison with base model will not be available."); self.base_tokenizer,self.base_model_loaded=None,None
except Exception as e:
self.logger.critical(f"Critical error during model initialization: {e}", exc_info=True); self.model,self.tokenizer,self.base_model_loaded,self.base_tokenizer=None,None,None,None; self.is_fine_tuned=False; raise # Re-raise critical errors
primary_desc = "Fine-Tuned" if self.is_fine_tuned else f"Base ({self.base_model_name})"
compare_desc = f"Base ({self.base_model_name})" if self.base_model_loaded else "N/A"
self.logger.info(f"Model Status -> Primary: {primary_desc} | Comparison: {compare_desc}"); self.logger.info("--- Model Initialization Finished ---")
def prepare_enhanced_dataset(self, corpus_path, max_samples=None, min_length=3, max_length=256, length_ratio_range=(0.1,15.0), test_split=0.1): # Unchanged from v4.1
"""Prepare an enhanced dataset for training."""
corpus_path = Path(corpus_path)
self.logger.info(f"--- Preparing Dataset for Training: {corpus_path} ---")
max_s_info = f"Max samples to load: {max_samples:,}" if max_samples else "Loading all samples"
self.logger.info(max_s_info)
self.logger.info(f"Filters: Min/Max Length={min_length}/{max_length}, Ratio Range={length_ratio_range}")
if not corpus_path.is_file(): raise FileNotFoundError(f"Corpus file not found: {corpus_path}")
data, read_count, loaded_count, skip_json, skip_format, skip_len, skip_ratio = [], 0, 0, 0, 0, 0, 0
try:
with corpus_path.open('r', encoding='utf-8') as f:
total_lines_est = None # Optional full line count for the progress bar (one extra pass over the file)
try: f.seek(0); total_lines_est = sum(1 for _ in f); f.seek(0)
except Exception: pass # Line counting is best-effort; the bar works without a total
pbar = tqdm(f, desc="Reading Corpus", unit=" lines", total=max_samples or total_lines_est, disable=max_samples is None and total_lines_est is None, leave=False)
for line in pbar:
read_count += 1;
if max_samples is not None and loaded_count >= max_samples: break # Stop if max sample limit reached
try:
item = json.loads(line.strip()); src = item.get('source_sentence'); tgt = item.get('translation')
if isinstance(src, str) and isinstance(tgt, str) and src.strip() and tgt.strip():
src, tgt = src.strip(), tgt.strip(); slen, tlen = len(src), len(tgt)
# Apply filters
if not (min_length <= slen <= max_length and min_length <= tlen <= max_length): skip_len += 1; continue
ratio = tlen / slen if slen > 0 else float('inf') # Handle slen=0
if not (length_ratio_range[0] <= ratio <= length_ratio_range[1]): skip_ratio += 1; continue
# If passes filters, add to data
data.append({'source': src, 'target': tgt}); loaded_count += 1
if loaded_count % 5000 == 0: pbar.set_postfix({"Loaded":f"{loaded_count:,}"}, refresh=True)
else: skip_format += 1 # Skip if not strings or empty after stripping
except json.JSONDecodeError: skip_json += 1
except Exception: skip_format += 1 # Catch other unexpected errors during processing
pbar.close()
self.logger.info(f"Read {read_count:,} lines. Loaded {loaded_count:,} valid samples pre-split.")
self.logger.info(f"Skipped: JSON {skip_json:,}, Format/Empty {skip_format:,}, Length {skip_len:,}, Ratio {skip_ratio:,}.")
if not data: self.logger.error("No valid data loaded after filtering. Training cannot proceed."); return [], []
random.seed(42); random.shuffle(data) # Shuffle before splitting
# Split into train/eval
if 0 < test_split < 1:
split_idx = int(len(data) * (1 - test_split)); train_data, eval_data = data[:split_idx], data[split_idx:]
self.logger.info(f"Split dataset: {len(train_data):,} train, {len(eval_data):,} eval.")
else:
train_data, eval_data = data, [] # Use all data for training if split is invalid
self.logger.warning(f"Invalid test_split ({test_split}). Using all {len(train_data):,} samples for training.")
return train_data, eval_data
except Exception as e: self.logger.error(f"Error preparing dataset: {e}", exc_info=True); raise
# <<< MODIFIED v4.2 >>> Fixed training start issue by removing Tensorboard reporting
def improved_train(self, corpus_path, max_samples=None, epochs=1, batch_size=16,
learning_rate=3e-5, weight_decay=0.01, warmup_steps_ratio=0.1,
eval_steps=500, save_steps=500, gradient_accumulation_steps=1,
max_token_length=128, fp16=None):
""" Fine-tunes the model with compatibility checks and better logging. """
self.logger.info("--- Starting Model Fine-Tuning Process ---")
self.logger.warning("Ensure sufficient GPU/CPU resources and time for training.")
# Determine FP16 setting
if fp16 is None:
fp16 = self.device.type == 'cuda' # Automatically enable for CUDA
self.logger.info(f"Auto-detected FP16 setting based on device: {fp16}")
elif fp16 and self.device.type not in ['cuda']: # Cannot use FP16 on CPU or MPS (usually)
self.logger.warning(f"FP16 requested but device is {self.device.type}. Disabling FP16.")
fp16 = False
# Load base model specifically for training
try:
self.logger.info(f"Loading base model '{self.base_model_name}' for fine-tuning...")
train_tok = MarianTokenizer.from_pretrained(self.base_model_name)
train_model = MarianMTModel.from_pretrained(self.base_model_name).to(self.device)
self.logger.info("✅ Base model and tokenizer loaded for training.")
except Exception as e:
self.logger.error(f"Failed to load base model for training: {e}", exc_info=True)
return None # Cannot proceed
# Prepare datasets
self.logger.info("Preparing training and evaluation datasets...")
# Allow longer raw text length as input to dataset prep, tokenization will handle truncation
train_data, eval_data = self.prepare_enhanced_dataset(corpus_path, max_samples, max_length=max_token_length * 4)
if not train_data:
self.logger.error("Training aborted: No training data loaded after preparation.")
return None
train_ds = BuddhistDataset(train_data, train_tok, max_length=max_token_length)
eval_ds = BuddhistDataset(eval_data, train_tok, max_length=max_token_length) if eval_data else None
self.logger.info(f"Datasets created: Train {len(train_ds):,}, Eval {len(eval_ds) if eval_ds else 'N/A'}")
# Calculate training steps and warmup steps
if batch_size <= 0 or gradient_accumulation_steps <= 0:
self.logger.error("Batch size and gradient accumulation steps must be positive."); return None
num_gpus = torch.cuda.device_count() if self.device.type == 'cuda' else 1
effective_batch_size = batch_size * gradient_accumulation_steps * num_gpus
steps_per_epoch = math.ceil(len(train_ds) / (batch_size * gradient_accumulation_steps)) # Per-process estimate; multi-GPU runs take proportionally fewer optimizer steps
total_training_steps = steps_per_epoch * epochs
num_warmup_steps = int(total_training_steps * warmup_steps_ratio)
self.logger.info(f"Training Steps per Epoch: {steps_per_epoch:,}, Total Steps: {total_training_steps:,}")
self.logger.info(f"Warmup Steps: {num_warmup_steps:,} ({warmup_steps_ratio*100:.1f}% of total)")
self.logger.info(f"Effective Batch Size: {effective_batch_size:,} (Batch: {batch_size}, Accum: {gradient_accumulation_steps}, GPUs: {num_gpus})")
# Configure evaluation and saving steps
do_eval = eval_ds is not None and eval_steps is not None and eval_steps > 0
current_eval_steps = eval_steps if do_eval else None
current_save_steps = save_steps
# Adjust save steps to be multiple of eval steps if both are active
if do_eval and current_save_steps and current_eval_steps and current_save_steps % current_eval_steps != 0:
current_save_steps = math.ceil(current_save_steps / current_eval_steps) * current_eval_steps
self.logger.warning(f"Adjusted save_steps to {current_save_steps} to align with eval_steps ({current_eval_steps})")
# Determine logging steps (e.g., 5% of steps per epoch, or minimum 10)
logging_steps = max(10, steps_per_epoch // 20) if steps_per_epoch > 200 else 50
# Check Transformers version for argument-name compatibility.
# Note: `evaluation_strategy` was renamed to `eval_strategy` in Transformers 4.41,
# so the modern/legacy cutoff must be 4.41, not 4.6.
self.logger.info(f"Using Transformers version: {transformers_version}")
is_modern_trainer = parse_version(transformers_version) >= parse_version("4.41.0")
self.logger.info(f"Trainer argument style: {'Modern (>=4.41)' if is_modern_trainer else 'Legacy (<4.41)'}")
# Define Training Arguments Dictionary
args_dict = {
"output_dir": str(self.output_dir),
"overwrite_output_dir": True,
"num_train_epochs": epochs,
"per_device_train_batch_size": batch_size,
"gradient_accumulation_steps": gradient_accumulation_steps,
"learning_rate": learning_rate,
"weight_decay": weight_decay,
"warmup_steps": num_warmup_steps,
"logging_dir": str(self.output_dir / "logs" / "training"), # Specific subfolder for logs
"logging_strategy": "steps",
"logging_steps": logging_steps,
"save_total_limit": 2, # Keep best and latest
"fp16": fp16,
"seed": 42,
"disable_tqdm": False,
"predict_with_generate": True, # Enable generation for potential BLEU/ROUGE metrics
"generation_max_length": max_token_length, # Max length for eval generation
"generation_num_beams": 4, # Beam size for eval generation
# <<< MODIFIED v4.2 >>> Removed Tensorboard reporting to fix startup error
# "report_to": ["tensorboard"] if NOTEBOOK_ENV else [], # Report to Tensorboard if available
"report_to": [], # Disable external reporting by default
"per_device_eval_batch_size": batch_size * 2, # Usually can use larger eval batch
}
# Add version-specific arguments
if is_modern_trainer:
args_dict["eval_strategy"] = "steps" if do_eval else "no"
if do_eval: args_dict["eval_steps"] = current_eval_steps
args_dict["save_strategy"] = "steps" if current_save_steps and current_save_steps > 0 else "epoch"
if current_save_steps and current_save_steps > 0: args_dict["save_steps"] = current_save_steps
args_dict["load_best_model_at_end"] = do_eval
if do_eval:
args_dict["metric_for_best_model"] = "loss" # Use eval loss to find best model
args_dict["greater_is_better"] = False # Lower loss is better
# Use the torch AdamW implementation (the default in recent releases).
# A plain dict assignment cannot fail, so no try/except is needed; an
# invalid optimizer name would surface when Seq2SeqTrainingArguments is built.
args_dict["optim"] = "adamw_torch" # or "adamw_torch_fused" where supported
else:
# Use the pre-4.41 argument names for compatibility
self.logger.warning("Using legacy argument names for pre-4.41 Transformers.")
args_dict["evaluation_strategy"] = "steps" if do_eval else "no" # Old name (renamed to eval_strategy in 4.41)
if do_eval: args_dict["eval_steps"] = current_eval_steps
if current_save_steps and current_save_steps > 0:
args_dict["save_steps"] = current_save_steps
# Need to set save_strategy if using save_steps with older versions
args_dict["save_strategy"] = "steps"
else:
args_dict["save_strategy"] = "epoch" # Default save strategy if no steps
args_dict["load_best_model_at_end"] = do_eval
if do_eval: args_dict["metric_for_best_model"] = "loss"; args_dict["greater_is_better"] = False
# Remove None values from dict before passing to arguments
args_dict = {k: v for k, v in args_dict.items() if v is not None}
# Create TrainingArguments object
try:
training_args = Seq2SeqTrainingArguments(**args_dict)
self.logger.info(f"Training arguments parsed successfully. Eval: {args_dict.get('eval_strategy', args_dict.get('evaluation_strategy', 'N/A'))}, Save: {args_dict.get('save_strategy', 'N/A')}")
except TypeError as e:
self.logger.critical(f"FATAL: Failed to create Seq2SeqTrainingArguments: {e}", exc_info=True)
self.logger.critical(f"Arguments provided: {args_dict}")
self.logger.critical("This often happens with incompatible Transformers versions or incorrect arguments. Consider upgrading Transformers.")
return None
# Ensure logging directory exists
Path(training_args.logging_dir).mkdir(parents=True, exist_ok=True)
# Initialize Trainer
trainer = Seq2SeqTrainer(
model=train_model,
args=training_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
tokenizer=train_tok, # Pass tokenizer for saving and generation
# data_collator could be added for dynamic padding if needed
)
# --- Start Training ---
self.logger.info(f"Starting fine-tuning for {epochs} epoch(s)...")
start_time = time.time()
try:
self.logger.info(">>> Calling trainer.train() <<<")
train_result = trainer.train() # <<< THE ACTUAL TRAINING CALL >>>
train_duration = time.time() - start_time
self.logger.info(f"✅ Training finished successfully in {train_duration // 60:.0f}m {train_duration % 60:.0f}s")
try: self.logger.info(f" Training Metrics: {train_result.metrics}")
except: self.logger.warning(" Could not retrieve final training metrics.")
# Save the final model and tokenizer
# If load_best_model_at_end=True, trainer.model is already the best one
self.logger.info(f"💾 Saving final model state to {self.output_dir}...")
trainer.save_model(str(self.output_dir)) # Saves the current state (best or last)
train_tok.save_pretrained(str(self.output_dir)) # Save tokenizer config with model
self.logger.info(" Model and tokenizer saved.")
# Reload the saved model to ensure consistency and update self.model/tokenizer
self.logger.info("🔄 Reloading the saved fine-tuned model to activate it...")
try:
self._initialize_models() # This should now load the fine-tuned model
if self.is_fine_tuned: self.logger.info("✅ Fine-tuned model reloaded and is now active.")
else: self.logger.error("❌ CRITICAL: Fine-tuned model training finished but failed to reload!")
except Exception as reload_err:
self.logger.error(f"❌ Error reloading model after training: {reload_err}", exc_info=True)
self.is_fine_tuned = False # Mark as not fine-tuned if reload fails
return trainer # Return trainer instance which contains history etc.
except Exception as train_err:
self.logger.error(f"❌ Training Exception Occurred: {train_err}", exc_info=True)
# Attempt to save the interrupted state
try:
if trainer and hasattr(trainer, 'model') and trainer.model:
save_path_interrupted = self.output_dir / f"model_interrupted_{datetime.datetime.now():%Y%m%d_%H%M%S}"
save_path_interrupted.mkdir(parents=True, exist_ok=True)
trainer.save_model(str(save_path_interrupted))
self.logger.info(f"💾 Saved interrupted model state to: {save_path_interrupted}")
except Exception as save_interrupt_err:
self.logger.error(f"Failed to save interrupted model state: {save_interrupt_err}")
# Re-initialize models to pre-training state to avoid inconsistent state
self.logger.warning("Re-initializing models to their state before the failed training attempt...")
try: self._initialize_models()
except Exception: self.logger.error("Failed to re-initialize models after training error.")
return None # Indicate training failure
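# Usage sketch (commented out; the corpus path and sample count are
# placeholders):
#
#   trainer = ImprovedBuddhistNMTTrainer(output_dir="./buddhist-nmt-finetuned")
#   result = trainer.improved_train("corpus.jsonl", max_samples=50_000,
#                                   epochs=1, batch_size=16, eval_steps=500)
#   if result is None:
#       print("Training failed or was aborted; see the log for details.")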
# --- Translation Methods ---
def _translate_batch(self, texts, model, tokenizer, model_desc="Unknown", beam_size=5, max_gen_length=200, **gen_kwargs): # Unchanged from v4.1
if not model or not tokenizer:
self.logger.error(f"Translation failed: {model_desc} model or tokenizer not loaded.")
return [f"Error: {model_desc} model not loaded"] * len(texts)
if not texts or not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
self.logger.error(f"Translation failed: Input must be a list of non-empty strings. Got: {type(texts)}")
return ["Error: Invalid input type"] * (len(texts) if isinstance(texts, list) else 1)
# Filter out empty strings before tokenization
non_empty_texts = [t for t in texts if t and t.strip()]
original_indices = [i for i, t in enumerate(texts) if t and t.strip()]
if not non_empty_texts:
return [""] * len(texts) # Return empty strings if all inputs were empty/whitespace
translations = [""] * len(texts) # Initialize result list
self.logger.debug(f"Translating batch of {len(non_empty_texts)} texts with {model_desc}...")
try:
inputs = tokenizer(non_empty_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device)
model.eval() # Ensure model is in eval mode
with torch.no_grad():
outputs = model.generate(
**inputs,
max_length=max_gen_length, # Max length of generated sequence
num_beams=beam_size,
early_stopping=True,
**gen_kwargs # Pass other generation params
)
batch_translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Place translations back into the original list structure
for i, trans in enumerate(batch_translations):
translations[original_indices[i]] = trans
return translations
except Exception as e:
self.logger.error(f"Batch translation error ({model_desc}): {e}", exc_info=False) # exc_info=False reduces noise
error_msg = f"Error: Batch translation failed ({e})"
# Fill errors back into the original structure
for i in original_indices:
translations[i] = error_msg
return translations # Return list with errors filled in
def _translate_single_text(self, text, model, tokenizer, model_desc="Unknown", beam_size=5, max_gen_length=200, **gen_kwargs): # Unchanged from v4.1
"""Helper to translate a single text using the batch method."""
if not isinstance(text, str):
self.logger.error(f"Translation failed: Input must be a string, got: {type(text)}")
return "Error: Invalid input type"
if not text.strip():
return "" # Return empty string for empty/whitespace input
# Translate as a batch of one
result = self._translate_batch([text], model, tokenizer, model_desc, beam_size, max_gen_length, **gen_kwargs)
return result[0] # Return the first (and only) element
def general_translate(self, text, beam_size=5, max_length=200): # Unchanged from v4.1
"""Translate using the base (general) comparison model."""
if not self.base_model_loaded or not self.base_tokenizer:
self.logger.error("Cannot perform general translation: Base comparison model not loaded.")
return "Error: Base comparison model unavailable"
self.logger.debug(f"Performing General Translation for: '{text[:50]}...'")
# Use default generation parameters for the general model comparison
return self._translate_single_text(
text,
self.base_model_loaded,
self.base_tokenizer,
model_desc="General",
beam_size=beam_size,
max_gen_length=max_length,
length_penalty=1.0, # Default length penalty
no_repeat_ngram_size=0 # Default (no restriction)
)
def improved_translate(self, text, beam_size=5, max_length=200): # Unchanged from v4.1
"""Translate using the primary model (FT or Base), with improved generation parameters and splitting."""
model_status = "Primary(FT)" if self.is_fine_tuned else "Primary(Base)"
self.logger.debug(f"Performing Improved Translation ({model_status}) for: '{text[:50]}...'")
if not self.model or not self.tokenizer:
self.logger.error(f"{model_status} model/tokenizer not loaded. Cannot perform improved translation.")
return f"Error: {model_status} model unavailable"
# Define improved generation parameters (can be tuned further)
gen_params = {
'length_penalty': 0.9, # Slightly favor shorter sequences
'no_repeat_ngram_size': 3, # Avoid 3-gram repetition
# Add other params like temperature, top_k, top_p if using sampling instead of beam search
}
# Simple sentence splitting for long inputs
# Use a heuristic max length (e.g., characters or estimated tokens)
max_len_heuristic = 450 # Approx chars threshold
should_split = isinstance(text, str) and len(text) > max_len_heuristic
if should_split:
self.logger.info(f"Input length > {max_len_heuristic} chars, attempting sentence splitting...")
sentences = []
current_sentence = ""
try:
# Split by common Latin/CJK sentence delimiters, keeping them via the capture group
sent_delims = '.。!?.?!;;'
segments = re.split(r'([' + re.escape(sent_delims) + r']+)', text)
i = 0
while i < len(segments):
part = segments[i]
if part: current_sentence += part
# Append the delimiter run when the next segment is one (same character set as the split regex)
if i + 1 < len(segments) and segments[i+1] and segments[i+1][0] in sent_delims:
current_sentence += segments[i+1]
i += 1 # Skip the delimiter in the next iteration
# Add the complete sentence if it's non-empty
if current_sentence.strip():
sentences.append(current_sentence.strip())
current_sentence = "" # Reset for next sentence
i += 1
# Add any trailing part if the text didn't end with a delimiter
if current_sentence.strip(): sentences.append(current_sentence.strip())
# Fallback if splitting was ineffective
if not sentences or (len(sentences) == 1 and len(text) > max_len_heuristic * 1.5):
self.logger.warning("Sentence splitting was ineffective. Translating as a single block.")
sentences = [text.strip()] # Use original text
self.logger.info(f"Split input into {len(sentences)} segments.")
except Exception as e:
self.logger.error(f"Error during sentence splitting: {e}. Translating as a single block.", exc_info=False)
sentences = [text.strip()] # Fallback on error
# Translate segments in batch
if sentences:
translations = self._translate_batch(sentences, self.model, self.tokenizer, model_status, beam_size, max_length, **gen_params)
final_translation_parts = []
errors_encountered = False
for i, t in enumerate(translations):
if isinstance(t, str) and t.startswith("Error:"):
self.logger.error(f"Error translating segment {i+1}: {t}")
final_translation_parts.append("[Segment Translation Error]")
errors_encountered = True
else:
final_translation_parts.append(str(t).strip() if t else "")
# Join translated parts (consider sentence spacing)
full_translation = " ".join(filter(None, final_translation_parts))
# Prepend warning if errors occurred
return f"[Partial Translation - Errors Occurred] {full_translation}" if errors_encountered else full_translation
else:
return "Error: Sentence splitting failed to produce segments."
else:
# Translate short text directly
return self._translate_single_text(
text, self.model, self.tokenizer, model_status, beam_size, max_length, **gen_params
)
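# Illustrative sketch (commented out): the delimiter-preserving split used
# above. re.split with a capturing group interleaves text and delimiter runs,
# which the loop then glues back together.
#
#   text = "諸行無常。是生滅法。"
#   parts = re.split(r'([' + re.escape('.。!?.?!;;') + r']+)', text)
#   # parts == ['諸行無常', '。', '是生滅法', '。', '']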
def compare_translations(self, text): # Unchanged from v4.1
"""Compares translations from the primary model and the base model."""
if not isinstance(text, str) or not text.strip():
return {"source_text": text, "buddhist_translation": "Error: Empty input.", "general_translation": "Error: Empty input.",
"similarity_percent": 0.0, "buddhist_specificity_percent": 0.0, "error": "Input is empty."}
res = {"source_text": text, "buddhist_translation": None, "general_translation": None,
"similarity_percent": 0.0, "buddhist_specificity_percent": 0.0, "error": None}
try:
start_bt = time.time(); bt = self.improved_translate(text); time_bt = time.time() - start_bt
start_gt = time.time(); gt = self.general_translate(text); time_gt = time.time() - start_gt
res["buddhist_translation"] = bt; res["general_translation"] = gt
# Check for errors in translations
bt_has_error = isinstance(bt, str) and bt.startswith("Error:")
gt_has_error = isinstance(gt, str) and gt.startswith("Error:")
error_messages = []
if bt_has_error: error_messages.append(f"PrimaryModel: {bt}")
if gt_has_error: error_messages.append(f"GeneralModel: {gt}")
if error_messages:
res["error"] = " | ".join(error_messages)
self.logger.error(f"Error during comparison for '{text[:30]}...': {res['error']}")
elif bt and gt: # Only calculate similarity if both are valid strings
import difflib # Local import: difflib is not among the header imports
similarity = difflib.SequenceMatcher(None, str(bt), str(gt)).ratio() * 100
specificity = 100.0 - similarity
res["similarity_percent"] = round(similarity, 1)
res["buddhist_specificity_percent"] = round(specificity, 1)
self.logger.info(f"Compared '{text[:30]}...' (Times: Prim {time_bt:.2f}s, Gen {time_gt:.2f}s): Sim={similarity:.1f}%, Spec={specificity:.1f}%")
except Exception as e:
self.logger.error(f"Unexpected error during compare_translations function: {e}", exc_info=True)
res["error"] = f"Comparison function error: {e}"
res["buddhist_translation"] = res["buddhist_translation"] or f"Error: {e}" # Ensure error is propagated
res["general_translation"] = res["general_translation"] or f"Error: {e}"
return res
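# Illustrative sketch (commented out): the similarity metric above is
# difflib's SequenceMatcher ratio scaled to a percentage; "Buddhist
# specificity" is simply its complement.
#
#   import difflib
#   a = "Form is emptiness; emptiness is form."
#   b = "Form is empty and emptiness is form."
#   sim = difflib.SequenceMatcher(None, a, b).ratio() * 100  # 0..100
#   spec = 100.0 - sim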
# ==============================================================================
# 6. Translation History Management (Unchanged from v4.1)
# ==============================================================================
class TranslationHistory:
"""Class to manage translation history and export functionality."""
def __init__(self, max_history=50):
self.history = []
self.max_history = max_history
self.logger = logging.getLogger("TranslationHistory")
self.columns = [ # Define standard columns
"Timestamp", "Source Text", "Buddhist Translation", "General Translation",
"Buddhist Specificity (%)", "Similarity (%)", "Error Info" ]
def add_entry(self, comparison_result):
"""Add a new translation comparison result to the history."""
if not isinstance(comparison_result, dict):
self.logger.warning("Attempted to add invalid entry to history.")
return
ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
specificity = comparison_result.get("buddhist_specificity_percent", 0.0)
similarity = comparison_result.get("similarity_percent", 0.0)
# Ensure numeric types, default to 0.0 if missing or invalid
try: spec_float = float(specificity) if specificity is not None else 0.0
except (ValueError, TypeError): spec_float = 0.0
try: sim_float = float(similarity) if similarity is not None else 0.0
except (ValueError, TypeError): sim_float = 0.0
entry = {
"Timestamp": ts,
"Source Text": comparison_result.get("source_text", ""),
"Buddhist Translation": comparison_result.get("buddhist_translation", ""),
"General Translation": comparison_result.get("general_translation", ""),
"Buddhist Specificity (%)": round(spec_float, 1),
"Similarity (%)": round(sim_float, 1),
"Error Info": comparison_result.get("error", None) # Store error message if present
}
self.history.insert(0, entry) # Add to the beginning (newest first)
# Trim history if it exceeds max size
if len(self.history) > self.max_history:
self.history.pop() # Remove the oldest entry
self.logger.debug(f"History entry added. Current size: {len(self.history)}")
def get_history(self):
"""Get the current translation history (list of dictionaries)."""
return self.history
def to_dataframe(self):
"""Convert history to a pandas DataFrame."""
if not self.history:
return pd.DataFrame(columns=self.columns) # Return empty DF with correct columns
# Create DataFrame and reindex to ensure consistent column order
return pd.DataFrame(self.history).reindex(columns=self.columns)
def display_history(self, num_entries=10):
"""Display the most recent translation history entries using pandas display."""
print(f"\n--- Translation History (Last {min(num_entries, len(self.history))} Entries) ---")
if not self.history:
print("(No history entries yet)")
return
df_history = self.to_dataframe()
try:
display(df_history.head(num_entries)) # Use IPython display
except Exception:
print(df_history.head(num_entries).to_string()) # Fallback to string print
if len(self.history) > num_entries:
print(f"... ({len(self.history) - num_entries} older entries exist)")
def clear_history(self):
"""Clear the translation history."""
self.history = []
self.logger.info("Translation history cleared.")
def export_history_to_csv(self, filename="translation_history.csv"):
"""Export translation history to a CSV file."""
if not self.history:
msg = "No history entries to export."
self.logger.warning(msg)
return False, msg
try:
df_history = self.to_dataframe()
df_history.to_csv(filename, index=False, encoding='utf-8-sig') # Use utf-8-sig for Excel
msg = f"Successfully exported {len(self.history)} history entries to '{filename}'"
self.logger.info(msg)
return True, msg
except Exception as e:
msg = f"Failed to export history to CSV: {e}"
self.logger.error(msg, exc_info=True)
return False, msg
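# Usage sketch (commented out): recording and exporting one comparison result.
#
#   history = TranslationHistory(max_history=50)
#   history.add_entry({"source_text": "色即是空",
#                      "buddhist_translation": "Form is emptiness.",
#                      "general_translation": "Color is empty.",
#                      "similarity_percent": 40.0,
#                      "buddhist_specificity_percent": 60.0,
#                      "error": None})
#   df = history.to_dataframe()
#   ok, msg = history.export_history_to_csv("history.csv")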
# ==============================================================================
# 7. UI and Visualization Helpers (With Fixes)
# ==============================================================================
# <<< MODIFIED v4.2 >>> Added font properties to specificity chart
def create_specificity_chart_matplotlib(buddhist_specificity, width=5.5, height=3.5):
"""Create a matplotlib pie chart showing Buddhist specificity."""
logger_viz = logging.getLogger("VizHelper")
if not plt or not sns:
logger_viz.warning("Matplotlib/Seaborn unavailable for specificity chart.")
# Return placeholder HTML and None for data
return "<p>(Specificity chart generation failed: Plotting libraries missing)</p>", None
# Try to get font properties for labels/title
font_prop = None
if CHINESE_FONT_PATH and Path(CHINESE_FONT_PATH).is_file() and fm:
try: font_prop = fm.FontProperties(fname=CHINESE_FONT_PATH)
except Exception as e: logger_viz.warning(f"Failed to load font properties for chart: {e}")
# Validate input specificity
try:
spec_value = max(0.0, min(100.0, float(buddhist_specificity if buddhist_specificity is not None else 0.0)))
except (ValueError, TypeError):
spec_value = 0.0
general_spec_value = 100.0 - spec_value
labels = ['Buddhist Model', 'General Model']
values = [spec_value, general_spec_value]
# Use slightly different colors maybe
colors = ['#1f77b4', '#ff7f0e'] # Blue for Buddhist, Orange for General
explode = (0.05, 0) if spec_value > general_spec_value else (0, 0.05) # Explode the larger slice slightly
fig, ax = plt.subplots(figsize=(width, height))
try:
wedges, texts, autotexts = ax.pie(
values,
labels=labels,
colors=colors,
autopct='%1.1f%%',
startangle=90,
pctdistance=0.85, # Position percentages inside wedges
explode=explode,
wedgeprops={'edgecolor': 'white', 'linewidth': 1} # Add white edge
)
# Apply font properties to labels and percentages
for text_item in texts + autotexts:
if font_prop:
text_item.set_fontproperties(font_prop)
text_item.set_fontsize(10) # Adjust font size
# Ensure percentages are white for dark wedges, black for light
for i, p in enumerate(autotexts):
lum = sum(matplotlib.colors.to_rgb(colors[i % len(colors)])) / 3 # Basic luminance check
p.set_color('white' if lum < 0.5 else 'black')
ax.axis('equal') # Equal aspect ratio ensures pie is drawn as a circle.
# Apply font property to title
plt.title('Translation Model Specificity', fontproperties=font_prop)
# Save to buffer
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png', bbox_inches='tight', dpi=150)
img_buffer.seek(0)
img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
plt.close(fig) # Close plot
# Return HTML img tag and raw base64 data
html_tag = f'<img src="data:image/png;base64,{img_base64}" alt="Specificity Chart" style="max-width: 100%; height: auto; display: block; margin: 10px auto;">'
# Return raw base64 data as well, in case direct image display is preferred
return html_tag, img_base64
except Exception as e:
logger_viz.error(f"Failed to create specificity chart: {e}", exc_info=True)
if fig is not None and plt.fignum_exists(fig.number): plt.close(fig)
return f"<p>(Error generating specificity chart: {e})</p>", None
# <<< MODIFIED v4.2 >>> Added font properties to word cloud title
def create_word_cloud_html(text, title, language='english', background_color='white', colormap='viridis', width=450, height=300):
"""Create an HTML word cloud visualization using Matplotlib/WordCloud."""
logger_viz = logging.getLogger("VizHelper")
if not EDA_LIBS_AVAILABLE or not WordCloud:
logger_viz.warning("WordCloud/Matplotlib unavailable for word cloud generation.")
return "<p>(Word cloud generation failed: Libraries missing)</p>", None # Return HTML placeholder and None data
if not text or not isinstance(text, str) or text.startswith("Error:"):
logger_viz.warning(f"Cannot generate word cloud '{title}': Invalid text provided.")
return f"<p>(Cannot generate word cloud '{title}': No valid text)</p>", None
is_chinese = (language == 'chinese')
font_path = None
font_prop = None # For title
if is_chinese:
if CHINESE_FONT_PATH and Path(CHINESE_FONT_PATH).is_file():
font_path = CHINESE_FONT_PATH
try: font_prop = fm.FontProperties(fname=font_path)
except Exception as e: logger_viz.warning(f"Failed to create FontProperties for word cloud title: {e}")
else:
logger_viz.warning(f"Chinese font path not found or invalid for '{title}'. WordCloud may render incorrectly.")
try:
# Tokenize based on language
if is_chinese:
if JIEBA_AVAILABLE and jieba:
tokens = [w for w in jieba.cut(text) if w.strip()] # Use Jieba if available
else:
logger_viz.warning("Jieba unavailable, using character splitting for Chinese word cloud.")
tokens = [c for c in text if c.strip()] # Fallback to characters
else:
# Basic English tokenization (could use NLTK if needed)
tokens = re.findall(r'\b[a-zA-Z]{2,}\b', text.lower()) # findall already returns a list
if not tokens:
logger_viz.warning(f"No tokens found for word cloud '{title}' after processing.")
return f"<p>(No words found for cloud '{title}')</p>", None
text_for_cloud = " ".join(tokens)
# Generate WordCloud object
wc = WordCloud(
width=width*2, height=height*2, # Generate at higher res
background_color=background_color,
font_path=font_path, # Pass font path for WordCloud internals
max_words=100,
colormap=colormap,
collocations=False,
prefer_horizontal=0.95
).generate(text_for_cloud)
# Plot using Matplotlib
fig, ax = plt.subplots(figsize=(width / 100, height / 100)) # Match aspect ratio
ax.imshow(wc, interpolation='bilinear')
ax.axis('off')
# Set title using font properties if Chinese
ax.set_title(title, fontproperties=font_prop if is_chinese else None, fontsize=10)
plt.tight_layout(pad=0.1)
# Save to buffer
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png', bbox_inches='tight', dpi=150)
img_buffer.seek(0)
img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
plt.close(fig) # Close plot
# Return HTML img tag and raw base64 data
html_tag = f'<div style="text-align: center; width: {width}px; padding: 5px; display: inline-block; vertical-align: top;"><img src="data:image/png;base64,{img_base64}" alt="{title}" style="max-width: 100%; height: auto;"></div>'
return html_tag, img_base64
except Exception as e:
logger_viz.error(f"Failed to create word cloud '{title}': {e}", exc_info=True)
if plt and 'fig' in locals() and fig is not None and plt.fignum_exists(fig.number): plt.close(fig)
return f"<p>(Error generating word cloud '{title}': {e})</p>", None
# ==============================================================================
# 8. Interactive Translator Interface (With Fixes)
# ==============================================================================
# <<< MODIFIED v4.2 >>> Refined display_result for robustness and clarity
def run_interactive_translator(trainer_instance, history_manager):
""" Runs interactive translator loop with robust chart display and text UI. """
logger_ui = logging.getLogger("InteractiveUI")
# Initial checks
if not isinstance(trainer_instance, ImprovedBuddhistNMTTrainer): logger_ui.critical("Invalid NMT trainer instance provided to UI."); return
if not trainer_instance.model or not trainer_instance.tokenizer: logger_ui.critical("Primary model/tokenizer not loaded in trainer instance."); return
if not trainer_instance.base_model_loaded or not trainer_instance.base_tokenizer: logger_ui.warning("Base comparison model not loaded. Comparison features disabled.")
if not isinstance(history_manager, TranslationHistory): logger_ui.critical("Invalid TranslationHistory manager provided."); return
# Sample texts for user convenience
sample_texts = [
"多欲為苦。生死疲勞。從貪欲起。少欲無為。身心自在。", "爾時薄伽梵在室羅伐城逝多林給孤獨園。",
"色不異空,空不異色,色即是空,空即是色,受想行識亦復如是。", "一切有為法,如夢幻泡影,如露亦如電,應作如是觀。",
"法無眾生,離眾生垢故;法無有我,離我垢故。" ]
def display_result(result):
"""Displays the translation comparison result, including chart and clouds if possible."""
print("\n" + "="*80); print(f"SOURCE TEXT:\n {result.get('source_text', 'N/A')}"); print("-" * 80)
model_desc = "Primary(FT)" if trainer_instance.is_fine_tuned else "Primary(Base)"
buddhist_trans = result.get('buddhist_translation', 'N/A')
print(f"{model_desc.upper()} TRANSLATION:\n {buddhist_trans}"); print("-" * 80)
general_trans = result.get('general_translation', 'N/A')
base_model_available = trainer_instance.base_model_loaded is not None
if base_model_available:
print(f"GENERAL (BASE) TRANSLATION:\n {general_trans}"); print("-" * 80)
else:
print("GENERAL (BASE) TRANSLATION:\n (Base model not loaded for comparison)"); print("-" * 80)
# --- Display Specificity and Visualizations ---
# Only show these if comparison was possible and successful
comparison_error = result.get("error")
bt_is_error = isinstance(buddhist_trans, str) and buddhist_trans.startswith("Error:")
gt_is_error = isinstance(general_trans, str) and general_trans.startswith("Error:")
if base_model_available and not comparison_error and not bt_is_error and not gt_is_error:
spec = result.get('buddhist_specificity_percent', 0.0)
sim = result.get('similarity_percent', 0.0)
print(f"BUDDHIST SPECIFICITY: {spec:.1f}% | SIMILARITY: {sim:.1f}%")
# Attempt to display visualizations only in notebook environment
if NOTEBOOK_ENV and Image and plt: # Check for Image and plt availability
print("\n--- Visualizations ---")
try:
# --- Specificity Chart ---
print("\nSpecificity Chart:")
chart_html, chart_base64 = create_specificity_chart_matplotlib(spec)
if chart_base64:
try: display(Image(data=base64.b64decode(chart_base64))) # Try direct image display
except Exception as e_img: logger_ui.warning(f"Direct image display failed: {e_img}. Falling back to HTML."); display(HTML(chart_html))
elif chart_html: display(HTML(chart_html)) # Display HTML if data failed or wasn't returned
else: print("[Chart generation failed]")
# --- Word Clouds ---
print("\nWord Clouds:")
wc_source_html, _ = create_word_cloud_html(result.get('source_text'), "Source Text", language='chinese', background_color='#f0fff0', colormap='Greens')
wc_buddhist_html, _ = create_word_cloud_html(buddhist_trans, f"{model_desc} Output", language='english', background_color='#e7f3ff', colormap='Blues')
wc_general_html, _ = create_word_cloud_html(general_trans, "General Output", language='english', background_color='#fff0f0', colormap='Reds')
# Display clouds side-by-side using HTML flexbox
display(HTML(f"""
<div style="display: flex; flex-wrap: wrap; justify-content: space-around; gap: 15px; margin-top: 10px; padding-top: 10px; border-top: 1px dashed #ccc;">
{wc_source_html}
{wc_buddhist_html}
{wc_general_html}
</div>
"""))
except Exception as viz_err:
print(f"\n[Error generating visualizations: {viz_err}]")
logger_ui.error(f"Visualization generation failed: {viz_err}", exc_info=True)
elif not base_model_available:
print("\n(Visualizations skipped: Base comparison model unavailable)")
elif NOTEBOOK_ENV:
print("\n(Visualizations skipped: Image/Matplotlib might be unavailable in this environment)")
else:
print("\n(Visualizations require a Notebook environment with IPython.display)")
elif comparison_error:
print(f"[Comparison Error]: {comparison_error}")
elif bt_is_error:
print(f"[Error in {model_desc.upper()} Translation]: {buddhist_trans}")
elif gt_is_error and base_model_available:
print(f"[Error in GENERAL Translation]: {general_trans}")
# No need for extra message if base model wasn't loaded, already handled above
print("="*80)
# --- Interactive Loop ---
print("\n" + "="*80); print(" ☯️ Buddhist Text Translator Interface ☯️ ".center(80, "=")); print("=" * 80)
primary_status = "(FT)" if trainer_instance.is_fine_tuned else "(Base)"
compare_status = "(Available)" if trainer_instance.base_model_loaded else "(Unavailable)"
print(f"Primary Model: {primary_status} | Comparison Model: {compare_status}")
print("Commands: /sample | /history | /export | /clear | /help | /quit")
print("-" * 80)
while True:
try:
user_input = input("\nEnter Chinese text or command > ").strip()
processed_input = ""
if not user_input: continue # Ignore empty input
if user_input.startswith("/"):
command = user_input.lower()
if command in ["/q", "/quit", "/exit"]:
print("Exiting translator..."); logger_ui.info("User initiated exit."); break
elif command == "/sample":
processed_input = random.choice(sample_texts)
print(f"\n[Sample Loaded]: {processed_input}")
# Automatically process the sample
elif command == "/history":
history_manager.display_history(10); continue # Display history and prompt again
elif command == "/export":
# Propose a default filename
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
default_filename = trainer_instance.output_dir / f"translation_history_{timestamp}.csv"
try:
filename_input = input(f"Enter filename to export history [{default_filename}]: ").strip()
filename_to_export = Path(filename_input) if filename_input else default_filename
success, message = history_manager.export_history_to_csv(filename_to_export)
print(f"[Export]: {message}")
except Exception as export_err: print(f"[Export Error]: {export_err}")
continue # Prompt again after export attempt
elif command == "/clear":
# A true screen clear is terminal-dependent; just print separators and redraw the header
print("\n" * 2 + "=" * 80 + "\n")
# Re-print header
print(" ☯️ Buddhist Text Translator Interface ☯️ ".center(80, "=")); print("-" * 80)
print(f"Primary Model: {primary_status} | Comparison Model: {compare_status}")
print("Commands: /sample | /history | /export | /clear | /help | /quit"); print("-" * 80)
continue
elif command == "/help":
print("\nCommands:")
print(" /sample - Load a random sample Buddhist text.")
print(" /history - Show the last 10 translation entries.")
print(" /export - Save the translation history to a CSV file.")
print(" /clear - Clear the screen (print separators).")
print(" /help - Show this help message.")
print(" /quit - Exit the translator.")
continue
else:
print(f"Unknown command: '{command}'. Type /help for options.")
continue
else:
processed_input = user_input # Use direct user input
# --- Process Translation Request ---
if processed_input:
print("\n[Processing translation...]"); logger_ui.info(f"Translating: '{processed_input[:50]}...'")
comparison_result = trainer_instance.compare_translations(processed_input)
history_manager.add_entry(comparison_result) # Add result to history
display_result(comparison_result) # Display the formatted result
except (KeyboardInterrupt, EOFError):
print("\nExiting translator (Interrupted)."); logger_ui.info("UI loop interrupted."); break
except Exception as e:
print(f"\n[Error in UI Loop]: {e}"); logger_ui.error(f"Interactive UI Error: {e}", exc_info=True); time.sleep(0.5) # Pause briefly after error
# ==============================================================================
# 9. Main Execution Orchestration (With Fixes)
# ==============================================================================
def run_buddhist_translation_system(
corpus_path="corpus.json",
output_dir="./buddhist_nmt_output",
base_model_name="Helsinki-NLP/opus-mt-zh-en",
run_full_corpus_analysis=True,
run_sampled_eda=True,
full_corpus_sample_rate=1.0, # Default to full corpus for streaming analysis
eda_max_samples=25000, # Default limit for sampled EDA
run_training=False,
train_max_samples=None, # Default to all available for training
train_epochs=1,
train_batch_size=8, # Smaller default batch size
train_grad_accum=2, # Default grad accum
train_lr=3e-5,
train_eval_steps=500,
train_save_steps=500, # Align save/eval steps by default
train_max_token_length=128,
train_fp16=None, # Auto-detect FP16 based on CUDA availability
log_level=logging.INFO
):
"""Run the complete Buddhist translation system pipeline."""
main_start_time = time.time()
output_dir_path = Path(output_dir)
output_dir_path.mkdir(parents=True, exist_ok=True)
# Setup logging first
logger = setup_logging(output_dir_path, level=log_level)
logger.info("="*80); logger.info("🚀 Starting Buddhist NMT System v4.2 🚀"); logger.info("="*80)
# Log parameters
run_params = {k:v for k,v in locals().items() if k not in ['logger', 'output_dir_path', 'main_start_time']}
logger.info("System Configuration:")
for key, val in run_params.items(): logger.info(f" - {key}: {val}")
logger.info(f" - Python Version: {sys.version.split()[0]}")
logger.info(f" - Torch Version: {torch.__version__}")
logger.info(f" - Transformers Version: {transformers_version}")
logger.info(f" - CUDA Available: {torch.cuda.is_available()}")
logger.info(f" - MPS Available: {torch.backends.mps.is_available()}")
logger.info(f" - CWD: {Path.cwd()}")
logger.info(f" - Output Directory: {output_dir_path.resolve()}")
logger.info("-" * 80)
# --- Validate Corpus Path ---
corpus_file_path = Path(corpus_path)
if not corpus_file_path.is_file():
logger.critical(f"CRITICAL ERROR: Corpus file not found at '{corpus_file_path.resolve()}'. System cannot proceed.")
return None, None # Return None for both trainer and history
# Initialize history manager early
history_manager = TranslationHistory()
final_trainer = None # Initialize trainer variable
# --- Pipeline Steps ---
pipeline_steps = []
if run_full_corpus_analysis: pipeline_steps.append("Full Corpus Analysis")
if run_sampled_eda: pipeline_steps.append("Sampled EDA")
pipeline_steps.append("NMT Setup")
if run_training: pipeline_steps.append("NMT Training")
pipeline_steps.append("Interactive UI")
total_steps = len(pipeline_steps)
logger.info(f"Pipeline Steps Planned: {', '.join(pipeline_steps)}")
current_step = 0
# --- 1. Full Corpus Streaming Analysis ---
if run_full_corpus_analysis:
current_step += 1
logger.info(f"\n--- [Step {current_step}/{total_steps}] Starting Full Corpus Analysis ---")
full_eda_start = time.time()
try:
full_eda_output_dir = output_dir_path / "full_corpus_analysis"
analyze_full_corpus_streaming(
corpus_path=corpus_file_path,
output_dir=full_eda_output_dir,
chinese_font_path=CHINESE_FONT_PATH,
jieba_available=JIEBA_AVAILABLE,
sample_rate=full_corpus_sample_rate,
max_ngrams_to_save=10000,
word_cloud_max_words=150
)
logger.info(f"✅ Full Corpus Analysis completed in {time.time() - full_eda_start:.2f}s.")
except Exception as e:
logger.error(f"❌ Full Corpus Analysis failed: {e}", exc_info=True)
# Analysis failures are non-fatal: log the error and continue with the remaining pipeline steps.
else:
logger.info("\n--- [Skipping Full Corpus Analysis] ---")
# --- 2. Sampled EDA ---
if run_sampled_eda:
current_step += 1
logger.info(f"\n--- [Step {current_step}/{total_steps}] Starting Sampled EDA ---")
sampled_eda_start = time.time()
try:
sampled_eda_output_dir = output_dir_path / "sampled_eda"
# Pass the specific output directory to the analyzer
analyzer = BuddhistTextAnalyzer(corpus_path=corpus_file_path, max_samples=eda_max_samples, output_dir_eda=sampled_eda_output_dir)
if analyzer.df is not None and not analyzer.df.empty:
analyzer.run_full_sampled_eda() # Uses its own output dir now
logger.info(f"✅ Sampled EDA completed in {time.time() - sampled_eda_start:.2f}s.")
else:
logger.warning("⚠️ Sampled EDA skipped: No data loaded into analyzer DataFrame.")
except Exception as e:
logger.error(f"❌ Sampled EDA failed: {e}", exc_info=True)
else:
logger.info("\n--- [Skipping Sampled EDA] ---")
# --- 3. NMT Setup & Optional Training ---
current_step += 1
nmt_setup_start = time.time()
logger.info(f"\n--- [Step {current_step}/{total_steps}] Setting up NMT Model ---")
try:
# Initialize the trainer - this loads the model (FT or base)
final_trainer = ImprovedBuddhistNMTTrainer(
base_model_name=base_model_name,
output_dir=output_dir_path # Main output dir for model
)
logger.info(f"NMT model setup initialized in {time.time() - nmt_setup_start:.2f}s.")
# Check if initialization succeeded before potentially training
if not final_trainer.model or not final_trainer.tokenizer:
raise RuntimeError("NMT model or tokenizer failed to initialize. Cannot proceed.")
if run_training:
current_step += 1
logger.info(f"\n--- [Step {current_step}/{total_steps}] Starting NMT Training ---")
training_start = time.time()
# Call the training method
train_result_obj = final_trainer.improved_train(
corpus_path=corpus_file_path,
max_samples=train_max_samples,
epochs=train_epochs,
batch_size=train_batch_size,
gradient_accumulation_steps=train_grad_accum,
learning_rate=train_lr,
eval_steps=train_eval_steps,
save_steps=train_save_steps,
max_token_length=train_max_token_length,
fp16=train_fp16 # Pass the potentially auto-detected value
)
training_time = time.time() - training_start
if train_result_obj and final_trainer.is_fine_tuned:
logger.info(f"✅ Training completed successfully in {training_time // 60:.0f}m {training_time % 60:.0f}s!")
elif train_result_obj: # Training ran but reload failed
logger.error("❌ CRITICAL: Training finished, but fine-tuned model failed to reload!")
# The trainer may be in an inconsistent state here; we deliberately proceed
# with the pre-training model rather than aborting, so the UI stays usable.
else: # Training failed or was interrupted
logger.error("❌ Training failed or was interrupted. Using model state from before training attempt.")
# The final_trainer object should have been reset by improved_train on failure
else:
logger.info(f"\n--- [Skipping NMT Training] --- Using {'existing Fine-Tuned model' if final_trainer.is_fine_tuned else 'Base model'}.")
logger.info("✅ NMT Setup/Training step finished.")
except Exception as e:
logger.critical(f"❌ CRITICAL ERROR during NMT Setup or Training: {e}", exc_info=True)
logger.critical("Cannot proceed to interactive translator.")
return None, history_manager # No usable trainer; still return the history manager so callers get a consistent tuple
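# Illustrative sketch (not called): the core of what ImprovedBuddhistNMTTrainer
# is assumed to do at setup -- load a MarianMT checkpoint plus its tokenizer,
# where name_or_path may be a Hub model name or a local fine-tuned directory.
def _load_marian(name_or_path):
    tokenizer = MarianTokenizer.from_pretrained(name_or_path)
    model = MarianMTModel.from_pretrained(name_or_path)
    return model, tokenizer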
# --- 4. Launch Interactive UI ---
current_step += 1
logger.info(f"\n--- [Step {current_step}/{total_steps}] Initializing Interactive Translator UI ---")
ui_start_time = time.time()
try:
# Check again if trainer is usable after potential training step
if final_trainer and final_trainer.model and final_trainer.tokenizer:
run_interactive_translator(final_trainer, history_manager)
logger.info(f"✅ Interactive UI session finished in {time.time() - ui_start_time:.2f}s.")
else:
logger.error("❌ Cannot start interactive UI: NMT trainer is not in a valid state.")
except Exception as e:
logger.error(f"❌ Interactive UI session crashed: {e}", exc_info=True)
# Fall through to return the potentially usable trainer and history
# --- System Finish ---
total_time = time.time() - main_start_time
logger.info("="*80); logger.info(f"🏁 Buddhist NMT System Finished in {total_time // 60:.0f}m {total_time % 60:.2f}s 🏁"); logger.info("="*80)
return final_trainer, history_manager
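# Illustrative sketch (an assumption): how the fp16=None default above might be
# resolved inside improved_train() -- enable mixed precision only when CUDA is
# present, since fp16 training is unsupported on CPU and unreliable on MPS.
def _resolve_fp16(fp16_flag):
    if fp16_flag is None:
        return torch.cuda.is_available()
    return bool(fp16_flag)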
# ==============================================================================
# 10. Example Execution Block (Unchanged from v4.1 - Uses the fixed functions)
# ==============================================================================
if __name__ == "__main__":
print(f"\n--- Script Execution Start ---"); print(f"Timestamp: {datetime.datetime.now()}"); print(f"Working Dir: {Path.cwd()}"); print(f"Python Executable: {sys.executable}")
# --- Display Versions ---
print(f"Relevant Library Versions:")
print(f" - Python: {sys.version.split()[0]}")
print(f" - Transformers: {transformers_version}")
print(f" - Torch: {torch.__version__}")
print(f" - Pandas: {pd.__version__}")
print(f" - NLTK: {nltk.__version__ if nltk else 'Not Found'}")
print(f" - Matplotlib: {matplotlib.__version__ if matplotlib else 'Not Found'}")
print(f" - WordCloud: {'Available' if WordCloud else 'Not Found'}")
print(f" - Jieba: {'Available' if JIEBA_AVAILABLE else 'Not Found'}")
print(f"Hardware Info:")
print(f" - CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available(): print(f" - GPU Devices: {torch.cuda.device_count()}")
print(f" - MPS (Apple Silicon GPU) Available: {torch.backends.mps.is_available()}")
print("-"*50)
# --- Configuration ---
CORPUS_FILE = "corpus.json" # Ensure this file exists and is JSONL
MODEL_OUTPUT_DIR = "./buddhist_nmt_zh_en_output_v4_2_fixed" # Updated output dir name
BASE_MODEL = "Helsinki-NLP/opus-mt-zh-en"
# --- Control Flags ---
RUN_FULL_CORPUS_ANALYSIS = True # Run the streaming analysis on (a sample of) the full corpus
RUN_SAMPLED_EDA = True # Run the detailed EDA on a smaller sample
DO_TRAINING = True # <<< !!! Set to False to SKIP fine-tuning and use the base/existing model !!! >>>
# --- Parameters ---
# EDA Params
FULL_CORPUS_SAMPLE_RATE = 0.1 # Use 10% for full corpus streaming analysis (faster)
EDA_SAMPLE_LIMIT = 50000 # Max samples for the detailed sampled EDA
# Training Params (only used if DO_TRAINING is True)
TRAIN_SAMPLE_LIMIT = 100000 # Max samples to load FOR TRAINING (None = all valid samples)
NUM_TRAINING_EPOCHS = 1 # Number of training epochs
TRAINING_BATCH_SIZE = 4 # Per device batch size (adjust based on GPU memory)
TRAINING_GRAD_ACCUM = 4 # Accumulate gradients over 4 steps (effective batch size = 4 * 4 = 16)
TRAINING_LEARNING_RATE = 3e-5
EVAL_SAVE_STEPS = 500 # Evaluate and potentially save every 500 steps
MAX_TOKEN_LENGTH = 128 # Max sequence length for tokenizer
FP16_TRAINING = None # Auto-detect based on CUDA availability (set True/False to override)
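# Effective sequences per optimiser step = per-device batch size x grad-accum
# (x device count under data parallelism); with the values above: 4 * 4 = 16.
EFFECTIVE_BATCH_SIZE = TRAINING_BATCH_SIZE * TRAINING_GRAD_ACCUM  # 16 with these defaults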
# Logging Level
LOGGING_LEVEL = logging.INFO # Change to logging.DEBUG for more verbose logs
# --- Execute ---
print(f"Checking Corpus Path: '{Path(CORPUS_FILE).resolve()}'...")
final_trainer_instance, final_history_manager = None, None
execution_start_time = time.time()
try:
# Check if corpus exists before calling the main function
if not Path(CORPUS_FILE).is_file():
raise FileNotFoundError(f"Corpus file '{CORPUS_FILE}' not found in {Path.cwd()}. Please provide the correct path.")
final_trainer_instance, final_history_manager = run_buddhist_translation_system(
corpus_path=CORPUS_FILE,
output_dir=MODEL_OUTPUT_DIR,
base_model_name=BASE_MODEL,
run_full_corpus_analysis=RUN_FULL_CORPUS_ANALYSIS,
run_sampled_eda=RUN_SAMPLED_EDA,
full_corpus_sample_rate=FULL_CORPUS_SAMPLE_RATE,
eda_max_samples=EDA_SAMPLE_LIMIT,
run_training=DO_TRAINING,
train_max_samples=TRAIN_SAMPLE_LIMIT,
train_epochs=NUM_TRAINING_EPOCHS,
train_batch_size=TRAINING_BATCH_SIZE,
train_grad_accum=TRAINING_GRAD_ACCUM,
train_lr=TRAINING_LEARNING_RATE,
train_eval_steps=EVAL_SAVE_STEPS,
train_save_steps=EVAL_SAVE_STEPS,
train_max_token_length=MAX_TOKEN_LENGTH,
train_fp16=FP16_TRAINING,
log_level=LOGGING_LEVEL
)
except FileNotFoundError as fnf_error:
print(f"\n[FATAL ERROR] {fnf_error}")
# Log the error if logger was set up, otherwise just print
try: logging.getLogger().critical(f"File Not Found: {fnf_error}", exc_info=False)
except Exception: pass # logging may not be configured yet; avoid a bare except
except Exception as main_exception:
print(f"\n{'='*10} [FATAL] An Unhandled Exception Occurred {'='*10}")
# Log critical error
try: logging.getLogger().critical("Unhandled exception during main execution", exc_info=True)
except Exception: pass # logging may not be configured yet; avoid a bare except
# Print details to console
import traceback
traceback.print_exc()
print(f"Error Type: {type(main_exception).__name__}")
print(f"Error Details: {main_exception}")
print("="*60)
finally:
execution_time = time.time() - execution_start_time
print("\n" + "="*30 + " Post-Execution Summary " + "="*30)
print(f"Total Execution Time: {execution_time // 60:.0f}m {execution_time % 60:.2f}s")
if final_trainer_instance and final_history_manager:
print("\n[Status] System finished execution (may include errors, check logs).")
model_state = "Fine-Tuned" if final_trainer_instance.is_fine_tuned else "Base"
print(f" - Active Primary Model State: {model_state}")
print(f" - Comparison Model Loaded: {'Yes' if final_trainer_instance.base_model_loaded else 'No'}")
print(f" - Output Directory: '{Path(MODEL_OUTPUT_DIR).resolve()}'")
print(f" - Final History Entries: {len(final_history_manager.get_history())}")
# Attempt a final test translation if models seem okay
# The final test needs only the primary model; the history length is irrelevant to the test itself
if final_trainer_instance.model and final_trainer_instance.tokenizer:
try:
test_text = "色不異空"
print(f"\n--- Final Test Translation ('{test_text}') ---")
if hasattr(final_trainer_instance, 'compare_translations'):
primary_ok = final_trainer_instance.model and final_trainer_instance.tokenizer
base_ok = final_trainer_instance.base_model_loaded and final_trainer_instance.base_tokenizer
if primary_ok and base_ok:
test_result = final_trainer_instance.compare_translations(test_text)
# Display result in a readable format
print(f" Source: {test_result.get('source_text')}")
print(f" Primary ({model_state}): {test_result.get('buddhist_translation')}")
print(f" General (Base): {test_result.get('general_translation')}")
print(f" Specificity: {test_result.get('buddhist_specificity_percent'):.1f}%")
if test_result.get('error'): print(f" Comparison Note: {test_result['error']}")
elif primary_ok:
bt = final_trainer_instance.improved_translate(test_text)
print(f" Source: {test_text}")
print(f" Primary ({model_state}): {bt}")
print(" (Comparison model unavailable for test)")
else: print("[Warning] Final test skipped: Primary model not loaded.")
else: print("[Warning] Final test skipped: 'compare_translations' method not available.")
except Exception as test_err: print(f"[Error] Final test translation failed: {test_err}")
# Ask to export history only if running in a TTY (likely interactive console)
if sys.stdin.isatty():
try:
export_prompt = input("\nExport final translation history to CSV? (y/N): ").strip().lower()
if export_prompt == 'y':
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
final_hist_filename = Path(MODEL_OUTPUT_DIR) / f"final_translation_history_{timestamp}.csv"
success, msg = final_history_manager.export_history_to_csv(final_hist_filename)
print(f"[Export]: {msg}")
except EOFError: pass # Handle case where input stream is closed
except Exception as final_export_err: print(f"[Error] Final history export failed: {final_export_err}")
else:
print("\n[Status] System did not complete successfully or trainer/history objects are unavailable.")
print(f" Please check logs in '{Path(MODEL_OUTPUT_DIR) / 'logs'}' for details.")
print("\n--- End of Script Execution ---")