# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "ucimlrepo",
#     "pandas",
#     "requests",
# ]
# ///
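# The header above uses PEP 723 inline script metadata, so the script can be run with a
# PEP 723-aware runner; for example (filename and key below are placeholders):
#   uv run heart_disease_demo.py --api-key YOUR_WOODWIDE_API_KEY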
import argparse
import os
import sys
import time
import json
import pandas as pd
import requests
from ucimlrepo import fetch_ucirepo
# Defaults
DEFAULT_BASE_URL = "https://beta.woodwide.ai/"
UCI_DATASET_ID = 45 # Heart Disease dataset


def setup_args():
    parser = argparse.ArgumentParser(
        description="Test Woodwide API with UCI Heart Disease dataset for Clustering and Anomaly Detection"
    )
    parser.add_argument("-k", "--api-key", required=True, help="Woodwide API Key")
    parser.add_argument(
        "-m", "--model-name", default="heart_disease", help="Base name for the models"
    )
    parser.add_argument(
        "-d", "--dataset-name", default="heart_disease", help="Name for the dataset"
    )
    parser.add_argument(
        "-o", "--output-file", default="results.json", help="File path to save results"
    )
    parser.add_argument(
        "--base-url", default=DEFAULT_BASE_URL, help="Base URL for API"
    )
    return parser.parse_args()


def fetch_and_prepare_data():
    print(f"Fetching UCI dataset ID={UCI_DATASET_ID} (Heart Disease)...")
    dataset = fetch_ucirepo(id=UCI_DATASET_ID)
    X = dataset.data.features
    y = dataset.data.targets

    # The Heart Disease dataset is small, so no truncation for upload size limits is
    # needed; truncate here (e.g. X = X[:N]; y = y[:N]) if you swap in a larger dataset.

    # Combine features and targets
    df = pd.concat([X, y], axis=1)

    # Save to a temporary CSV file
    dataset_path = "heart_disease.csv"
    df.to_csv(dataset_path, index=False)
    print(f"Data prepared and saved to {dataset_path}. Shape: {df.shape}")
    return dataset_path
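# Note: the CSV written by fetch_and_prepare_data() contains the 13 UCI feature
# columns (age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak,
# slope, ca, thal) plus the 'num' target column.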


def main():
    args = setup_args()
    base_url = args.base_url.rstrip("/")
    headers = {
        "Authorization": f"Bearer {args.api_key}",
        "accept": "application/json",
    }

    # 1. Fetch Data
    dataset_path = fetch_and_prepare_data()

    try:
        # 2. Upload Dataset
        print(f"Uploading {dataset_path} as '{args.dataset_name}'...")
        start_time = time.time()
        with open(dataset_path, "rb") as f:
            files = {"file": (os.path.basename(dataset_path), f, "text/csv")}
            data = {"name": args.dataset_name, "overwrite": "true"}
            response = requests.post(f"{base_url}/api/datasets", headers=headers, files=files, data=data)
        if response.status_code != 200:
            print(f"Error uploading dataset: {response.status_code}\n{response.text}")
            sys.exit(1)
        dataset_id = response.json().get("id")
        print(f"Upload took {time.time() - start_time:.2f}s. ID: {dataset_id}\n")

        # 3. Train Clustering Model
        clustering_model_name = f"{args.model_name}_clustering"
        print(f"Step 1: Training Clustering Model '{clustering_model_name}' to identify patient groups...")
        start_time = time.time()
        response = requests.post(
            f"{base_url}/api/models/clustering/train",
            params={"dataset_name": args.dataset_name},
            data={"model_name": clustering_model_name, "overwrite": "true"},
            headers=headers,
        )
        if response.status_code != 200:
            print(f"Error starting clustering training: {response.status_code}\n{response.text}")
            sys.exit(1)
        clustering_model_id = response.json().get("id")
        print(f"Clustering training started. ID: {clustering_model_id}")

        # Wait for clustering training to finish (poll every 2s, give up after `timeout` seconds)
        timeout = 3000
        while True:
            response = requests.get(f"{base_url}/api/models/{clustering_model_id}", headers=headers)
            status = response.json().get("training_status")
            if status == "COMPLETE":
                print(f"Clustering complete! (Took {time.time() - start_time:.2f}s)\n")
                break
            elif status == "FAILED":
                print(f"Error: Clustering failed.\n{response.text}")
                sys.exit(1)
            if time.time() - start_time > timeout:
                print(f"Error: Clustering training timed out after {timeout}s.")
                sys.exit(1)
            time.sleep(2)
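        # (The loop above keeps polling every 2 seconds for any training_status other
        # than COMPLETE or FAILED.)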

        # 4. Run Clustering Inference to get cluster assignments
        print("Running clustering inference to segment patients...")
        response = requests.post(
            f"{base_url}/api/models/clustering/{clustering_model_id}/infer",
            params={"dataset_id": dataset_id},
            headers=headers,
            stream=True,
        )
        if response.status_code != 200:
            print(f"Error running clustering inference: {response.status_code}\n{response.text}")
            sys.exit(1)

        # Read the streamed response into a single buffer so the full JSON payload
        # is available before decoding.
        clusters_raw = b""
        for chunk in response.iter_content(chunk_size=None):
            if chunk:
                clusters_raw += chunk
        try:
            clusters = json.loads(clusters_raw)
        except json.JSONDecodeError as e:
            print(f"Error decoding clustering JSON: {e}")
            print(f"Raw response start: {clusters_raw[:100]}")
            sys.exit(1)

        # Extract cluster descriptions and per-row labels if available
        cluster_descriptions = {}
        if isinstance(clusters, dict):
            cluster_descriptions = clusters.get("cluster_descriptions", {})
            # When the response is a dict, the per-row labels live under 'cluster_label'
            cluster_labels = clusters.get("cluster_label", {})
            # Convert to a list (ordered by row index) for the rest of the script
            if cluster_labels:
                # Assuming cluster_labels is a dict of {index: label}
                sorted_indices = sorted([int(k) for k in cluster_labels.keys()])
                clusters_list = [cluster_labels[str(i)] for i in sorted_indices]
            else:
                clusters_list = clusters
        else:
            clusters_list = clusters
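        # For reference, the clustering inference response parsed above is assumed to
        # look roughly like this (values illustrative):
        # {
        #     "cluster_label": {"0": 2, "1": 0, ...},
        #     "cluster_descriptions": {"0": "...", "1": "...", ...}
        # }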

        # 5. Filter for the largest cluster for targeted anomaly detection
        # Count occurrences of each cluster label
        if isinstance(clusters_list, list):
            from collections import Counter

            cluster_counts = Counter(clusters_list)
            # Pick the cluster with the most patients
            target_cluster = max(cluster_counts, key=cluster_counts.get)
            print(f"Largest cluster identified: Cluster {target_cluster} with {cluster_counts[target_cluster]} patients.")
        else:
            target_cluster = 0
            print(f"Warning: Could not determine cluster counts. Defaulting to Cluster {target_cluster}.")

        target_cluster_desc = cluster_descriptions.get(str(target_cluster), "No description available.")
        print(f"Filtering patients in Cluster {target_cluster} for targeted anomaly detection...")
        print(f"Cluster Description: {target_cluster_desc}")

        # Reload the original data so it can be filtered by cluster
        df = pd.read_csv(dataset_path)
        # 'clusters_list' is assumed to hold one label per row of df, in row order
        if isinstance(clusters_list, list) and len(clusters_list) == len(df):
            df['cluster'] = clusters_list
            cluster_df = df[df['cluster'] == target_cluster].drop(columns=['cluster'])
            if cluster_df.empty:
                print(f"Warning: Cluster {target_cluster} is empty. Using full dataset instead.")
                cluster_dataset_path = dataset_path
                cluster_dataset_name = args.dataset_name
            else:
                cluster_dataset_path = f"heart_disease_cluster_{target_cluster}.csv"
                cluster_df.to_csv(cluster_dataset_path, index=False)
                cluster_dataset_name = f"{args.dataset_name}_cluster_{target_cluster}"
                print(f"Uploading Cluster {target_cluster} data ({len(cluster_df)} patients)...")
                with open(cluster_dataset_path, "rb") as f:
                    files = {"file": (os.path.basename(cluster_dataset_path), f, "text/csv")}
                    data = {"name": cluster_dataset_name, "overwrite": "true"}
                    response = requests.post(f"{base_url}/api/datasets", headers=headers, files=files, data=data)
                if response.status_code != 200:
                    print(f"Error uploading cluster dataset: {response.status_code}\n{response.text}")
                    sys.exit(1)
        else:
            print("Warning: Could not map clusters to rows. Using full dataset for anomaly detection.")
            cluster_dataset_path = dataset_path
            cluster_dataset_name = args.dataset_name
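        # The per-cluster CSV written above (if any) is temporary; it is removed in the
        # finally block at the end of main().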

        # 6. Train Anomaly Detection Model on the selected cluster
        anomaly_model_name = f"{args.model_name}_cluster_{target_cluster}_anomaly"
        print(f"Step 2: Training Anomaly Detection Model '{anomaly_model_name}' for Cluster {target_cluster}...")
        start_time = time.time()
        response = requests.post(
            f"{base_url}/api/models/anomaly/train",
            params={"dataset_name": cluster_dataset_name},
            data={"model_name": anomaly_model_name, "overwrite": "true"},
            headers=headers,
        )
        if response.status_code != 200:
            print(f"Error starting anomaly training: {response.status_code}\n{response.text}")
            sys.exit(1)
        anomaly_model_id = response.json().get("id")
        print(f"Anomaly detection training started. ID: {anomaly_model_id}")

        # Wait for anomaly detection training (same polling and timeout as above)
        while True:
            response = requests.get(f"{base_url}/api/models/{anomaly_model_id}", headers=headers)
            status = response.json().get("training_status")
            if status == "COMPLETE":
                print(f"Anomaly detection complete! (Took {time.time() - start_time:.2f}s)\n")
                break
            elif status == "FAILED":
                print(f"Error: Anomaly detection failed.\n{response.text}")
                sys.exit(1)
            if time.time() - start_time > timeout:
                print(f"Error: Anomaly detection training timed out after {timeout}s.")
                sys.exit(1)
            time.sleep(2)

        # 7. Run Anomaly Inference
        print(f"Running anomaly detection inference on Cluster {target_cluster}...")
        start_time = time.time()

        # Look up the dataset ID for the cluster dataset by name
        response = requests.get(f"{base_url}/api/datasets", headers=headers)
        datasets = response.json()
        cluster_dataset_id = next((d['id'] for d in datasets if d['name'] == cluster_dataset_name), None)
        if cluster_dataset_id is None:
            print(f"Error: Could not find dataset '{cluster_dataset_name}' on the server.")
            sys.exit(1)

        response = requests.post(
            f"{base_url}/api/models/anomaly/{anomaly_model_id}/infer",
            params={"dataset_id": cluster_dataset_id},
            headers=headers,
            stream=True,
        )
        if response.status_code != 200:
            print(f"Error running anomaly inference: {response.status_code}\n{response.text}")
            sys.exit(1)

        # Read the streamed response into a single buffer; iter_content with
        # chunk_size=None yields the raw bytes as they arrive.
        anomalies_raw = b""
        for chunk in response.iter_content(chunk_size=None):
            if chunk:
                anomalies_raw += chunk
        try:
            anomalies = json.loads(anomalies_raw)
        except json.JSONDecodeError as e:
            print(f"Error decoding anomaly JSON: {e}")
            print(f"Raw response: {anomalies_raw.decode('utf-8', errors='replace')}")
            sys.exit(1)
        print(f"Anomaly detection took {time.time() - start_time:.2f}s")

        # 8. Combine and Save Results
        # The anomaly endpoint returns a list of anomalous row IDs
        anomalous_ids = anomalies.get("anomalous_ids", [])
        final_results = {
            "target_cluster": target_cluster,
            "target_cluster_description": target_cluster_desc,
            "clustering_model_id": clustering_model_id,
            "anomaly_model_id": anomaly_model_id,
            "cluster_size": len(cluster_df) if 'cluster_df' in locals() else "unknown",
            "anomalous_ids": anomalous_ids,
        }

        # 9. Extract details for anomalous patients
        anomalous_details = []
        if anomalous_ids and 'cluster_df' in locals():
            # The IDs returned by the API are integer row indices into the uploaded CSV
            try:
                # Relevant columns for the Heart Disease dataset
                relevant_cols = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal', 'num']
                # Keep only the columns that actually exist in the dataframe
                display_cols = [c for c in relevant_cols if c in cluster_df.columns]
                for aid in anomalous_ids[:5]:  # Get details for the first 5
                    idx = int(aid)
                    if idx < len(cluster_df):
                        # Coerce numpy scalars to native Python types so json.dumps can serialize them
                        patient_data = {
                            k: (v.item() if hasattr(v, "item") else v)
                            for k, v in cluster_df.iloc[idx][display_cols].to_dict().items()
                        }
                        patient_data['id'] = aid
                        anomalous_details.append(patient_data)
            except (ValueError, IndexError):
                pass
        final_results["anomalous_details_sample"] = anomalous_details

        formatted_result = json.dumps(final_results, indent=2)
        with open(args.output_file, "w") as f:
            f.write(formatted_result)
        print(f"Results saved to {args.output_file}")

        # Print a short summary explaining the results
        print("\n" + "=" * 120)
        print(f"ANALYSIS SUMMARY FOR CLUSTER {target_cluster}")
        print(f"Cluster {target_cluster} Description: {target_cluster_desc}")
        print("=" * 120)
        print(f"We identified a group of {final_results['cluster_size']} patients with similar clinical profiles.")
        print(f"Within this specific group, we found {len(anomalous_ids)} anomalous cases.")
        if anomalous_details:
            print("\nSample of Anomalous Patients (Full Clinical Metrics):")
            # Build a table-like header with the main columns
            cols_to_print = ['id', 'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'num']
            header = " | ".join([f"{c.upper():<7}" for c in cols_to_print])
            print(header)
            print("-" * len(header))
            for p in anomalous_details:
                row = []
                for c in cols_to_print:
                    val = p.get(c, '')
                    if isinstance(val, float):
                        row.append(f"{val:<7.1f}")
                    else:
                        row.append(f"{str(val):<7}")
                print(" | ".join(row))
            if len(anomalous_ids) > 5:
                print(f"\n... and {len(anomalous_ids) - 5} more anomalous cases.")
        elif anomalous_ids:
            print(f"\nFound {len(anomalous_ids)} anomalous patient IDs: {', '.join(str(a) for a in anomalous_ids[:10])}")
        print("=" * 120 + "\n")
    finally:
        # Clean up the temporary per-cluster CSV (the original dataset file is kept)
        if 'cluster_dataset_path' in locals() and os.path.exists(cluster_dataset_path) and cluster_dataset_path != dataset_path:
            os.remove(cluster_dataset_path)


if __name__ == "__main__":
    main()