End-to-end script: fetch an example dataset, upload it, and run prediction or clustering with the Woodwide SDK
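
To run it end to end, assuming the script below is saved as woodwide_example.py (the filename is arbitrary) and uv is installed:

    uv run woodwide_example.py -k YOUR_API_KEY
    uv run woodwide_example.py -k YOUR_API_KEY --clustering -o results.json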

# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "ucimlrepo",
#   "pandas",
#   "woodwide",
# ]
# ///
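# The block above is inline script metadata (PEP 723); runners such as
# uv read it and install the listed dependencies automatically.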
import argparse
import os
import sys
import time
import json
import pandas as pd
from ucimlrepo import fetch_ucirepo
from woodwide import WoodWide

# Defaults
DEFAULT_BASE_URL = "https://beta.woodwide.ai/"
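# UCI ML Repository dataset ID 2 is Adult (Census Income): the prediction
# task is whether a person's annual income exceeds $50K.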
UCI_DATASET_ID = 2


def setup_args():
    parser = argparse.ArgumentParser(
        description="Test Woodwide API with UCI Adult dataset"
    )
    parser.add_argument("-k", "--api-key", required=True, help="Woodwide API Key")
    parser.add_argument(
        "-m", "--model-name", default="uci_adult_model", help="Name for the model"
    )
    parser.add_argument(
        "-d", "--dataset-name", default="uci_adult", help="Name for the dataset"
    )
    parser.add_argument(
        "-o", "--output-file", help="File path to save inference results"
    )
    parser.add_argument(
        "--base-url", default=DEFAULT_BASE_URL, help="Base URL for API"
    )
    parser.add_argument(
        "-c", "--clustering", action="store_true", help="Run clustering instead of prediction"
    )
    return parser.parse_args()


def fetch_and_prepare_data():
    print(f"Fetching UCI dataset ID={UCI_DATASET_ID}...")
    dataset = fetch_ucirepo(id=UCI_DATASET_ID)

    X = dataset.data.features
    y = dataset.data.targets

    # Combine features and targets
    df = pd.concat([X, y], axis=1)

    # Simple train/test split (80/20)
    train_df = df.sample(frac=0.8, random_state=42)
    test_df = df.drop(train_df.index)

    # Save to temporary CSV files
    train_path = "uci_train.csv"
    test_path = "uci_test.csv"

    train_df.to_csv(train_path, index=False)
    test_df.to_csv(test_path, index=False)

    label_column = y.columns[0]
    print(f"Data prepared. Label column: '{label_column}'")
    print(f"Train shape: {train_df.shape}, Test shape: {test_df.shape}")

    return train_path, test_path, label_column


def upload_dataset(client, file_path, name):
    print(f"Uploading {file_path} as '{name}'...")
    start_time = time.time()

    with open(file_path, "rb") as f:
        # The upload endpoint accepts an open binary file object;
        # overwrite=True replaces any existing dataset with the same name.
        dataset = client.api.datasets.upload(
            file=f,
            name=name,
            overwrite=True
        )

    elapsed = time.time() - start_time
    print(f"Upload took {elapsed:.2f}s")

    # Assuming the SDK returns a Pydantic model, the new dataset's ID is
    # available directly as an attribute.
    dataset_id = dataset.id
    print(f"Dataset Uploaded. ID: {dataset_id}\n")
    return dataset_id


def train_model(client, dataset_name, model_name, label_column, is_clustering=False):
    if is_clustering:
        print(f"Training Clustering Model '{model_name}' using dataset '{dataset_name}'...")
        endpoint = "/api/models/clustering/train"
        data = {
            "model_name": model_name,
            "overwrite": "true",
        }
    else:
        print(f"Training Prediction Model '{model_name}' using dataset '{dataset_name}'...")
        endpoint = "/api/models/prediction/train"
        data = {
            "model_name": model_name,
            "label_column": label_column,
            "overwrite": "true",
        }

    start_time = time.time()

    # Call this endpoint through the SDK's underlying HTTP client; the typed
    # wrapper currently has issues here. The dataset name goes in the query
    # string and the remaining options are sent as form fields.
    response = client._client.post(
        endpoint,
        params={"dataset_name": dataset_name},
        data=data,
        headers=client.auth_headers,
    )

    elapsed = time.time() - start_time
    print(f"Request took {elapsed:.2f}s")

    if response.status_code != 200:
        print(f"Error starting training: {response.status_code}")
        print(response.text)
        sys.exit(1)

    response_json = response.json()
    model_id = response_json.get("id")
    if not model_id:
        print("Error: No Model ID returned")
        print(response_json)
        sys.exit(1)

    print(f"Model Training Started. ID: {model_id}\n")
    return model_id


def wait_for_training(client, model_id):
    print(f"Waiting for Model Training to Complete (ID: {model_id})...")
    start_time = time.time()
    timeout = 3000  # give training up to 50 minutes before giving up

    while True:
        model = client.api.models.retrieve(model_id)
        training_status = model.training_status
        # Compute elapsed time up front so it is defined even when the very
        # first poll reports COMPLETE and we break out immediately.
        elapsed = time.time() - start_time

        if training_status == "COMPLETE":
            print("Training Complete.")
            print(model)
            break
        elif training_status == "FAILED":
            print("Error: Model Training Failed.")
            print(model)
            sys.exit(1)

        if elapsed >= timeout:
            print(f"Error: Training Timed Out after {timeout} seconds.")
            sys.exit(1)

        print(f"Status: {training_status}. Waiting...")
        time.sleep(2)

    print(f"Success: Took {elapsed:.2f} seconds to train model.\n")


def run_inference(client, model_id, test_dataset_id, output_file, is_clustering=False):
    print(
        f"Running Inference on Model {model_id} with Test Dataset ID {test_dataset_id}..."
    )
    start_time = time.time()

    if is_clustering:
        result = client.api.models.clustering.infer(
            model_id=model_id,
            dataset_id=test_dataset_id
        )
    else:
        result = client.api.models.prediction.infer(
            model_id=model_id,
            dataset_id=test_dataset_id
        )

    elapsed = time.time() - start_time
    print(f"Inference took {elapsed:.2f}s")

    # Assuming result is a dict or list, or a Pydantic model we can dump
    try:
        # If it's a Pydantic model
        if hasattr(result, "model_dump_json"):
            formatted_result = result.model_dump_json(indent=2)
        # If it's a dict
        elif isinstance(result, (dict, list)):
            formatted_result = json.dumps(result, indent=2)
        else:
            formatted_result = str(result)
    except Exception:
        formatted_result = str(result)

    if output_file:
        with open(output_file, "w") as f:
            f.write(formatted_result)
        print(f"Inference results saved to {output_file}")
    else:
        print("Inference Response:")
        print(formatted_result)
    print("")


def main():
    args = setup_args()

    # Initialize SDK Client
    client = WoodWide(
        api_key=args.api_key,
        base_url=args.base_url
    )
    # 1. Fetch Data
    train_path, test_path, label_column = fetch_and_prepare_data()

    try:
        # 2. Upload Train
        train_dataset_id = upload_dataset(
            client, train_path, args.dataset_name
        )
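        # train_dataset_id is logged for reference only; the training
        # endpoint below addresses the dataset by name, not by ID.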

        # 3. Upload Test
        test_dataset_name = f"{args.dataset_name}_test"
        test_dataset_id = upload_dataset(
            client, test_path, test_dataset_name
        )

        # 4. Train Model
        model_id = train_model(
            client,
            args.dataset_name,
            args.model_name,
            label_column,
            is_clustering=args.clustering,
        )

        # 5. Wait for Training
        wait_for_training(client, model_id)

        # 6. Run Inference
        run_inference(
            client, model_id, test_dataset_id, args.output_file, is_clustering=args.clustering
        )

    finally:
        # Cleanup temp files
        if os.path.exists(train_path):
            os.remove(train_path)
        if os.path.exists(test_path):
            os.remove(test_path)


if __name__ == "__main__":
    main()
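
By default the script trains a predictor on the Adult income label column; pass -c/--clustering to run the clustering flow instead, and -o to write the inference output to a file rather than stdout.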