This project focuses on the automated classification of Caenorhabditis elegans (C. elegans) nematodes based on their longevity phenotypes. Specifically, we aim to distinguish between worms treated with Terbinafine (Terbinafine+) and untreated control worms (Terbinafine-).
Terbinafine helps extend the lifespan of C. elegans. By analyzing movement trajectories and derived features, we investigate whether machine learning models can accurately predict the treatment group from behavioral data.
A key objective of this project was to adopt a featureless approach as much as possible. Instead of relying heavily on handcrafted biological metrics, we prioritized methods (such as ROCKET, Tail-MIL, and CNNs) that learn directly from raw data representations, minimizing the bias introduced by manual feature selection.
The raw data consists of movement trajectories tracked over the worms' lifespans. Since the raw tracking data can be noisy and inconsistent, a rigorous preprocessing pipeline was applied to ensure data quality and relevance. The steps were applied in the following order:
Objective: Remove non-biological noise and tracking errors.
- Drop First Row: The first row of many files contained inconsistent timestamps or initialization artifacts. It was systematically removed.
- Death Clipping: Frames recorded after the annotated frame of death were removed to focus analysis on the living phase.
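For illustration, these two cleaning steps might look like the following sketch (the `Frame` column name and the per-worm death-frame annotation are assumptions, not necessarily the project's actual schema):

```python
import pandas as pd

def clean_track(df: pd.DataFrame, death_frame: int) -> pd.DataFrame:
    """Drop the initialization-artifact first row, then clip frames
    recorded after the annotated frame of death."""
    df = df.iloc[1:]                         # drop inconsistent first row
    return df[df["Frame"] <= death_frame]    # keep only the living phase
```
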
Objective: Fix tracking "jumps" where the camera lost the worm or swapped identity, causing unrealistic displacements.
- Displacement Thresholding: We removed frames where the frame-to-frame displacement exceeded a biological plausibility threshold (e.g., > 16 pixels/frame), since such jumps are caused by tracking errors rather than locomotion.
- Coordinate Reconstruction: When gaps or jumps were removed, the worm's trajectory was stitched back together (cumulative summation of valid displacements) to recreate a continuous, biologically plausible path.
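The two repair steps above can be sketched as follows (a minimal sketch assuming the tracking data exposes `X`/`Y` columns and using the 16 px/frame threshold mentioned above; the actual implementation lives in `scripts/preprocess.py`):

```python
import numpy as np
import pandas as pd

def repair_trajectory(df: pd.DataFrame, max_disp: float = 16.0) -> pd.DataFrame:
    """Zero out implausible jumps, then re-integrate the remaining
    displacements into a continuous path (cumulative summation)."""
    dx = df["X"].diff().fillna(0.0)
    dy = df["Y"].diff().fillna(0.0)
    disp = np.hypot(dx, dy)
    valid = disp <= max_disp                       # flag tracking jumps
    # Keep only plausible displacements and stitch the path back together
    x = df["X"].iloc[0] + dx.where(valid, 0.0).cumsum()
    y = df["Y"].iloc[0] + dy.where(valid, 0.0).cumsum()
    out = df.copy()
    out["X"], out["Y"] = x, y
    return out
```
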
Objective: Handle the variable lifespan of worms.
- Tracks were divided into fixed-length segments (e.g., 900 frames). This standardizes the input for models that process sequential data and allows us to analyze behavior at different life stages.
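Segmentation into fixed-length windows can be sketched as (assuming non-overlapping segments with the trailing remainder dropped; the project may handle the remainder differently):

```python
import numpy as np

def segment_track(coords: np.ndarray, seg_len: int = 900) -> list:
    """Split a (T, 2) coordinate array into non-overlapping
    fixed-length segments; the remainder shorter than seg_len is dropped."""
    n_segments = len(coords) // seg_len
    return [coords[i * seg_len:(i + 1) * seg_len] for i in range(n_segments)]
```
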
Objective: Ensure feature consistency after trajectory repair.
- Speed Recomputation: Since coordinate reconstruction modifies the path, we recalculate the instantaneous speed (`ComputedSpeed`) from the new coordinates to ensure it matches the visual trajectory.
- Turning Rate & Noise Filtering: We calculate the instantaneous turning rate (change in heading). To prevent sensor noise from being interpreted as movement, we force the turning rate to 0 whenever the worm is stationary (speed < 0.05).
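A sketch of this recomputation (using the 0.05 stationarity threshold stated above; heading is wrapped into (-π, π], which is our choice of convention):

```python
import numpy as np

def speed_and_turning(x: np.ndarray, y: np.ndarray, speed_floor: float = 0.05):
    """Recompute per-frame speed and heading change from (repaired) coordinates."""
    dx, dy = np.diff(x), np.diff(y)
    speed = np.hypot(dx, dy)                       # pixels per frame
    heading = np.arctan2(dy, dx)
    turn = np.diff(heading)
    turn = np.arctan2(np.sin(turn), np.cos(turn))  # wrap into (-pi, pi]
    turn = np.concatenate([[0.0], turn])
    # Suppress noise: no turning when effectively stationary
    turn[speed < speed_floor] = 0.0
    return speed, turn
```
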
For the Convolutional Neural Network (CNN) approach, we treated the trajectory classification as a computer vision problem. Instead of scalar features, we generated visual representations of the worm's movement.
Windowing Strategy: Although the tabular models use full 900-frame segments, for the CNN we further slice these segments into smaller clips of 300 frames (with a stride of 150).
- This allows us to filter out clips containing `NaN`s (gaps) without discarding the entire segment.
- It focuses the network on shorter, more detailed movement patterns.
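The windowing and NaN filtering can be sketched as (using the 300-frame clips and stride of 150 described above):

```python
import numpy as np

def make_clips(segment: np.ndarray, clip_len: int = 300, stride: int = 150) -> list:
    """Slice a (900, 2) segment into overlapping clips, skipping any
    clip that contains NaNs (tracking gaps)."""
    clips = []
    for start in range(0, len(segment) - clip_len + 1, stride):
        clip = segment[start:start + clip_len]
        if not np.isnan(clip).any():               # drop gappy clips only
            clips.append(clip)
    return clips
```
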
We then converted these time-series coordinates into RGB images, encoding a different aspect of the movement in each channel:
- Red Channel (Path): Binary occupancy map. Indicates where the worm has been.
- Green Channel (Time): Gradient from 0 to 255. Encodes when the worm was at a specific position (fading from dark to bright). This preserves temporal order in a static image.
- Blue Channel (Speed): Intensity mapped to instantaneous speed. Brighter pixels indicate faster movement at that location.
This encoding may allow the CNN (e.g., ResNet) to learn complex patterns like "slowing down while turning" or "looping behavior" that scalar features might miss.
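One way to rasterize this three-channel encoding (a sketch only: the image size, normalization, and the assumption that `speed` is aligned per-frame with the coordinates are ours, not necessarily the project's):

```python
import numpy as np

def trajectory_to_image(x, y, speed, size: int = 64) -> np.ndarray:
    """Render a trajectory as a (size, size, 3) uint8 image:
    R = path occupancy, G = time gradient, B = instantaneous speed."""
    img = np.zeros((size, size, 3), dtype=np.uint8)
    # Normalize coordinates into pixel indices
    px = np.clip((x - x.min()) / (np.ptp(x) + 1e-9) * (size - 1), 0, size - 1).astype(int)
    py = np.clip((y - y.min()) / (np.ptp(y) + 1e-9) * (size - 1), 0, size - 1).astype(int)
    t = np.linspace(0, 255, len(x)).astype(np.uint8)
    s = np.clip(speed / (speed.max() + 1e-9) * 255, 0, 255).astype(np.uint8)
    img[py, px, 0] = 255      # red: where the worm has been
    img[py, px, 1] = t        # green: when it was there (later frames overwrite)
    img[py, px, 2] = s        # blue: how fast it was moving there
    return img
```
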
One of the most critical aspects of our methodology was ensuring zero data leakage between training and validation sets. Since we segmented the lifespan of each worm into multiple data points:
- A simple random split would put segments of the same worm in both the training and validation sets.
- The model would then learn to recognize the specific worm's movement style rather than the treatment effect, leading to massively inflated performance metrics.
Solution: We implemented Stratified Group K-Fold Cross-Validation, encapsulated in our custom module utils/train_utils/fold_utils.py.
- Group: We grouped data by `WormID`. All segments from a single worm are forced into the same fold (either all in train or all in validation).
- Stratified: We ensured that the ratio of Treated vs. Control worms remains balanced across folds.
This rigorous validation strategy ensures that our reported metrics reflect the model's ability to generalize to new, unseen worms.
We implemented two distinct pipelines to robustly evaluate different modeling approaches.
This pipeline handles traditional Machine Learning models and Time Series classifiers.
- Models Supported: Logistic Regression, Random Forest, SVM, XGBoost, and Time Series models like ROCKET and Tail-MIL.
- Architecture: The pipeline is designed to be model-agnostic. All models inherit from a
BaseModelabstract class, ensuring a consistent interface for data loading (load_data) and execution (run_fold). - Data Augmentation: To improve model robustness (avoid overfitting on the dataset's characteristics), we implemented a
UnifiedCElegansAugmentedDatasetthat expands the training data with the following transforms:- Rotations: Random rotations on the X and Y axis.
- Translation: Random X/Y offsets.
- Scaling: Random scaling (0.8x to 1.2x) of all variates.
- Workflow:
- Loads the unified dataset.
- Initializes models and loads their specific data requirements.
- Performs Stratified Group K-Fold Cross-Validation (all segments of one worm stay in the same fold to prevent leakage).
- Trains, validates, and reports metrics (Accuracy, F1, Precision, Recall).
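The augmentation transforms listed above (rotation, translation, scaling) could be sketched as a single random similarity transform; the translation range is an assumption, and the real `UnifiedCElegansAugmentedDataset` may compose the transforms differently:

```python
import numpy as np

rng = np.random.default_rng(0)

def augment_track(coords: np.ndarray, rng=rng) -> np.ndarray:
    """Apply one random rotation, translation, and scaling to a (T, 2) track."""
    theta = rng.uniform(0, 2 * np.pi)
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
    scale = rng.uniform(0.8, 1.2)                  # matches the 0.8x-1.2x range
    offset = rng.uniform(-10, 10, size=2)          # assumed offset range
    return coords @ rot.T * scale + offset
```
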
A dedicated pipeline for Deep Learning models processing the image datasets.
- Models Supported: ResNet18, ResNet50, DenseNet121.
- Workflow:
- Custom Dataset Class (`CElegansCNNDataset`): Loads images and extracts labels/worm IDs.
- Augmentation: Applies random rotations, flips, and normalization to improve generalization.
- Training Loop: Runs a PyTorch training loop with Stratified Group K-Fold.
- Comparison: Automatically plots and compares the performance of different architectures.
.
├── cnn_dataset/ # Generated dataset for CNNs
├── data/ # Raw data and summary files
├── models/ # Model definitions
│ ├── base.py # Abstract base class for models
│ ├── model_cnn.py # CNN factory (ResNet, DenseNet)
│ ├── model_lr.py # Logistic Regression wrapper
│ └── ... # Other model wrappers
├── scripts/ # Execution scripts
│ ├── cnn_pipeline.py # CNN training pipeline
│ ├── main_pipeline.py # Main pipeline for tabular/time-series models
│ ├── preprocess.py # Data cleaning and reconstruction
│ └── extract_features.py # Feature extraction for tabular models
├── utils/ # Utility functions
│ ├── train_utils/ # Datasets and Stratified Group K-Fold logic
│ └── plot_utils/ # Plotting functions
└── requirements.txt # Python dependencies
- Clone the repository.
- Install the required dependencies using pip:
    pip install -r requirements.txt

Before running any analysis, the raw data must be cleaned, reconstructed, and prepared.
Standard Cleaning & Reconstruction:
    python scripts/preprocess.py

This script reads from `data/`, cleans the trajectories, and outputs validation-ready CSVs to `preprocessed_data/`.
Generating CNN Images: To generate the dataset for the CNN pipeline (images):
    python scripts/preprocess.py --generate-images --cnn-output-dir "cnn_dataset/"

Feature Extraction (for tabular models): After standard preprocessing, extract scalar features (speed, tortuosity, etc.):
    python scripts/extract_features.py

Train and evaluate tabular models (Logistic Regression, Random Forest, etc.) or Time Series models (ROCKET).
    python scripts/main_pipeline.py --pytorch_dir "preprocessed_data/"

Options:
- `--plot`: Generate plots of the results.
- `--augmented_data`/`-a`: Use the augmented dataset. The number given after the argument specifies the number of transformations per worm; if the option is present without a number, 5 augmentations per worm are produced.
- `--prod`: Run in production mode (saves the best model).
- `-o`: Specify the output JSON filename for results.
Train and compare CNN models (ResNet18, ResNet50, DenseNet).
    python scripts/cnn_pipeline.py --data_dir "cnn_dataset"

Configuration:
- You can modify the `models_config` dictionary inside `scripts/cnn_pipeline.py` to change architectures, batch sizes, or learning rates.
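For orientation, such a configuration dictionary might look like the fragment below; the exact keys and values are hypothetical — consult `scripts/cnn_pipeline.py` for the real ones:

```python
# Hypothetical shape of the models_config dictionary; the real keys live
# in scripts/cnn_pipeline.py and may differ.
models_config = {
    "resnet18":    {"batch_size": 32, "lr": 1e-4, "epochs": 20},
    "resnet50":    {"batch_size": 16, "lr": 1e-4, "epochs": 20},
    "densenet121": {"batch_size": 16, "lr": 5e-5, "epochs": 20},
}
```
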
- `scripts/preprocess.py`: Implementation of cleaning and trajectory reconstruction logic.
- `scripts/main_pipeline.py`: Orchestrator for tabular and time-series models.
- `scripts/cnn_pipeline.py`: Orchestrator for CNN models.
- `models/`: Directory containing model definitions (ResNet factory, LogisticRegression wrapper, etc.).
- `utils/train_utils/dataset.py`: Unified data loading logic.
