The code in this repository is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
GolfDB is a high-quality video dataset created for general recognition applications in the sport of golf, and specifically for the task of golf swing sequencing.
This repo contains a simple PyTorch implementation of the SwingNet baseline model presented in the paper. The model was trained on split 1 without any data augmentation and achieved an average PCE of 71.5% (the 76.1% PCE reported in the paper is credited to data augmentation, including horizontal flipping and affine transformations).
If you use this repo please cite the GolfDB paper:
@InProceedings{McNally_2019_CVPR_Workshops,
author = {McNally, William and Vats, Kanav and Pinto, Tyler and Dulhanty, Chris and McPhee, John and Wong, Alexander},
title = {GolfDB: A Video Database for Golf Swing Sequencing},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
month = {June},
year = {2019}
}
Run generate_splits.py to convert the .mat dataset file to a dataframe and generate the 4 splits.
- I have provided the preprocessed video clips for a frame size of 160x160 (download here). Place 'videos_160' in the data directory. If you wish to use a different input configuration, you must download the YouTube videos (URLs provided in the dataset) and preprocess them yourself. I have provided preprocess_videos.py to help with that.
- Download the MobileNetV2 pretrained weights from this repository and place 'mobilenet_v2.pth.tar' in the root directory.
- Run train.py
train.py now includes:
- Mixed precision training (AMP) and gradient scaling
- cuDNN benchmark + high matmul precision on CUDA
- DataLoader optimizations (pin_memory, persistent_workers, prefetch_factor)
- OneCycleLR scheduler + gradient clipping
- Periodic validation and automatic best-checkpoint saving (models/swingnet_best.pth.tar)
Environment knobs (all optional):
- GOLFDB_ITERATIONS (default: 2000 optimizer steps)
- GOLFDB_SAVE_INTERVAL (default: 100)
- GOLFDB_SEQ_LENGTH (default: 64)
- GOLFDB_BATCH_SIZE (default: 22)
- GOLFDB_NUM_WORKERS (default: 6)
- GOLFDB_USE_AMP (default: 1)
- GOLFDB_USE_COMPILE (default: 0)
- GOLFDB_PIN_MEMORY (default: 1)
- GOLFDB_PERSISTENT_WORKERS (default: 1)
- GOLFDB_PREFETCH_FACTOR (default: 4)
- GOLFDB_FREEZE_LAYERS (default: 0)
- GOLFDB_LR (default: 0.001)
- GOLFDB_ONECYCLE_PCT_START (default: 0.1)
- GOLFDB_GRAD_ACCUM_STEPS (default: 1)
- GOLFDB_LOSS (default: ce, options: ce, focal)
- GOLFDB_FOCAL_GAMMA (default: 2.0)
- GOLFDB_LABEL_SMOOTHING (default: 0.0, CE only)
- GOLFDB_USE_AUGMENT (default: 1)
- GOLFDB_AUG_FLIP_P (default: 0.5)
- GOLFDB_AUG_ROTATE (default: 8.0)
- GOLFDB_AUG_TRANSLATE (default: 0.05)
- GOLFDB_AUG_SCALE_MIN (default: 0.95)
- GOLFDB_AUG_SCALE_MAX (default: 1.05)
- GOLFDB_AUG_BRIGHTNESS (default: 0.12)
- GOLFDB_AUG_CONTRAST (default: 0.12)
- GOLFDB_MAX_GRAD_NORM (default: 1.0)
- GOLFDB_LOG_EVERY (default: 10)
- GOLFDB_EVAL_INTERVAL (default: 100)
- GOLFDB_EVAL_NUM_WORKERS (default: min(max(num_workers, 1), 4))
- GOLFDB_EVAL_DISP (default: 0)
- GOLFDB_DATALOADER_TIMEOUT_S (default: 60, forced to 0 when GOLFDB_NUM_WORKERS=0)
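To make the knob semantics concrete, here is a minimal sketch of how a handful of them could be read with their documented defaults. The helper names (env_int, env_float, env_flag, load_config) are hypothetical and not taken from train.py, and the set of strings accepted as "true" is an assumption; only the variable names and defaults come from the list above.

```python
import os

def env_int(name, default):
    """Read an integer knob; fall back to `default` when unset or blank."""
    raw = os.environ.get(name, "").strip()
    return int(raw) if raw else default

def env_float(name, default):
    """Read a float knob; fall back to `default` when unset or blank."""
    raw = os.environ.get(name, "").strip()
    return float(raw) if raw else default

def env_flag(name, default):
    """Read a 0/1 style knob. Accepted truthy spellings are an assumption."""
    raw = os.environ.get(name, "").strip().lower()
    if not raw:
        return default
    return raw in {"1", "true", "yes", "on"}

def load_config():
    """Collect a subset of the documented knobs into a plain dict."""
    num_workers = env_int("GOLFDB_NUM_WORKERS", 6)
    return {
        "iterations": env_int("GOLFDB_ITERATIONS", 2000),
        "seq_length": env_int("GOLFDB_SEQ_LENGTH", 64),
        "batch_size": env_int("GOLFDB_BATCH_SIZE", 22),
        "num_workers": num_workers,
        "use_amp": env_flag("GOLFDB_USE_AMP", True),
        "lr": env_float("GOLFDB_LR", 0.001),
        # Documented default: min(max(num_workers, 1), 4)
        "eval_num_workers": env_int(
            "GOLFDB_EVAL_NUM_WORKERS", min(max(num_workers, 1), 4)
        ),
        # Documented rule: the timeout is forced to 0 when there are no workers.
        "dataloader_timeout_s": (
            0 if num_workers == 0 else env_int("GOLFDB_DATALOADER_TIMEOUT_S", 60)
        ),
    }
```

With nothing set, load_config() returns the defaults listed above; exporting GOLFDB_NUM_WORKERS=0 also drops the dataloader timeout to 0.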
Sample commands:
High-performance training on GPU:
GOLFDB_NUM_WORKERS=8 \
GOLFDB_USE_AMP=1 \
GOLFDB_USE_COMPILE=1 \
GOLFDB_PIN_MEMORY=1 \
GOLFDB_PERSISTENT_WORKERS=1 \
GOLFDB_PREFETCH_FACTOR=4 \
GOLFDB_FREEZE_LAYERS=0 \
GOLFDB_LR=0.001 \
GOLFDB_EVAL_INTERVAL=100 \
GOLFDB_EVAL_NUM_WORKERS=4 \
GOLFDB_LOG_EVERY=20 \
python train.py
Accuracy-first preset (stronger regularization + longer schedule):
GOLFDB_ITERATIONS=6000 \
GOLFDB_SEQ_LENGTH=96 \
GOLFDB_BATCH_SIZE=12 \
GOLFDB_GRAD_ACCUM_STEPS=2 \
GOLFDB_NUM_WORKERS=8 \
GOLFDB_USE_AUGMENT=1 \
GOLFDB_LOSS=focal \
GOLFDB_FOCAL_GAMMA=2.0 \
GOLFDB_LR=0.0008 \
GOLFDB_USE_AMP=1 \
GOLFDB_USE_COMPILE=1 \
GOLFDB_EVAL_INTERVAL=200 \
GOLFDB_EVAL_NUM_WORKERS=4 \
python train.py
Debug data loading / path issues:
GOLFDB_DEBUG_TRAIN=1 \
GOLFDB_DEBUG_DATALOADER=1 \
GOLFDB_NUM_WORKERS=0 \
python train.py
The project optimizations were grouped into three buckets: throughput, generalization, and training stability.
Throughput (faster training):
- AMP (torch.amp.autocast + GradScaler) reduces GPU memory use and improves tensor-core utilization.
- torch.compile can fuse graph segments and reduce Python overhead.
- Data pipeline tuning (num_workers, pin_memory, persistent_workers, prefetch_factor) keeps the GPU fed.
- Reduced logging frequency lowers host-side synchronization overhead.
Generalization (better validation PCE):
- Sequence-consistent augmentation (RandomAugment) applies one geometric/photometric transform to all frames in a clip, preserving temporal coherence.
- Configurable objective (ce or focal) handles class imbalance and hard-example emphasis.
- Optional label smoothing for CE improves calibration and reduces over-confidence.
- Longer schedules + best-checkpoint-by-PCE selection avoid relying on final-step weights.
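To illustrate the ce vs focal choice, here is a minimal pure-Python sketch of both losses for a single frame. It uses the standard focal loss form, -(1 - p_t)^gamma * log(p_t), with gamma matching the GOLFDB_FOCAL_GAMMA default; the function names are hypothetical, and the loss in train.py may add class weights or reduce over the batch differently.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Plain CE for one sample: -log(p_t)."""
    return -math.log(softmax(logits)[target])

def focal_loss(logits, target, gamma=2.0):
    """Focal loss for one sample: CE scaled by (1 - p_t) ** gamma."""
    p_t = softmax(logits)[target]
    return ((1.0 - p_t) ** gamma) * -math.log(p_t)

# An easy sample (confident, correct) and a hard one (target scored low).
easy_logits, hard_logits, target = [4.0, 0.0, 0.0], [0.0, 2.0, 0.0], 0

# Fraction of the CE loss that survives the focal scaling for each sample.
easy_ratio = focal_loss(easy_logits, target) / cross_entropy(easy_logits, target)
hard_ratio = focal_loss(hard_logits, target) / cross_entropy(hard_logits, target)
```

The easy sample keeps a much smaller share of its CE loss than the hard one (easy_ratio is far below hard_ratio), which is the hard-example emphasis mentioned above. With gamma=0 the focal loss reduces to plain CE.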
Training stability (fewer bad runs):
- OneCycleLR provides warmup + annealing behavior with strong empirical convergence.
- Gradient clipping prevents unstable updates in the CNN+LSTM stack.
- Gradient accumulation enables larger effective batch size without exceeding VRAM.
- Explicit checkpoint loading with weights_only=False keeps existing checkpoints loadable under PyTorch 2.6, where torch.load defaults to weights_only=True.
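For intuition about the warmup + annealing shape, here is a rough pure-Python sketch of a two-phase cosine one-cycle schedule. The div_factor and final_div_factor values mirror PyTorch's documented OneCycleLR defaults; whether train.py overrides them is not stated here, so treat them as assumptions. pct_start uses the GOLFDB_ONECYCLE_PCT_START default, and the function name is hypothetical.

```python
import math

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.1,
                 div_factor=25.0, final_div_factor=1e4):
    """Learning rate at `step` (0-based) for a cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor          # assumed starting LR
    min_lr = initial_lr / final_div_factor    # assumed final LR
    warmup_end = pct_start * total_steps - 1  # last step of the warmup phase
    if warmup_end <= 0 or warmup_end >= total_steps - 1:
        raise ValueError("pct_start * total_steps must leave room for both phases")
    if step <= warmup_end:
        start, end = initial_lr, max_lr
        pct = step / warmup_end
    else:
        start, end = max_lr, min_lr
        pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
    # Cosine interpolation from `start` (pct=0) to `end` (pct=1).
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

# Default run: 2000 optimizer steps with a peak LR of 0.001.
schedule = [one_cycle_lr(s, 2000, 0.001) for s in range(2000)]
```

Under these assumptions the rate climbs from max_lr / 25 to max_lr over the first 10% of steps, then decays along a cosine curve for the remainder of the run.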
Tradeoffs to mention:
- More augmentation usually improves robustness, but too much can distort event timing cues.
- Focal loss can improve minority-event learning, but may underperform CE on well-balanced subsets.
- torch.compile improves speed on many setups, but can increase startup/compile time.
- Higher seq_length adds temporal context but increases memory and runtime per step.
Common tuning order:
- Stabilize pipeline (num_workers, AMP, eval interval, best-checkpoint logic).
- Tune optimization (LR, OneCycle shape, accumulation, clipping).
- Tune objective (ce + smoothing vs focal).
- Tune augmentation strength and sequence length.
- Train your own model by following the steps above or download the pre-trained weights here. Create a 'models' directory if it does not already exist and place 'swingnet_1800.pth.tar' in this directory.
- Run eval.py. If using the pre-trained weights provided, the PCE should be 0.715.
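PCE is not defined in this README. The sketch below assumes it is the fraction of swing events whose predicted frame lands within a tolerance of the labeled frame. The function name, the fixed tolerance of 1 frame, and the frame indices are made up for illustration; eval.py may compute the tolerance differently (for example, per video).

```python
def percent_correct_events(true_frames, pred_frames, tolerance=1):
    """Fraction of events predicted within `tolerance` frames of the label."""
    if len(true_frames) != len(pred_frames):
        raise ValueError("expected one predicted frame per labeled event")
    if not true_frames:
        raise ValueError("need at least one event")
    hits = [abs(t - p) <= tolerance for t, p in zip(true_frames, pred_frames)]
    return sum(hits) / len(hits)

# Illustrative frame indices for one clip (not real GolfDB data).
labeled = [10, 25, 40, 52, 60, 66, 71, 90]
predicted = [10, 26, 43, 52, 59, 66, 75, 90]
pce = percent_correct_events(labeled, predicted)  # 6 of 8 within 1 frame -> 0.75
```

Under this reading, a reported PCE of 0.715 would mean about 71.5% of events across the evaluation split were localized within the tolerance.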
Evaluate the best checkpoint from training:
python eval.py
eval.py uses models/swingnet_best.pth.tar by default when present. Override with:
GOLFDB_EVAL_CKPT=models/swingnet_1800.pth.tar python eval.py
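The checkpoint lookup described above can be sketched roughly as follows. The function name and the injectable `exists` argument are hypothetical, and falling back to swingnet_1800.pth.tar when no best checkpoint exists is an assumption; the README only says that the best checkpoint is used when present and that GOLFDB_EVAL_CKPT overrides it.

```python
import os

BEST_CKPT = "models/swingnet_best.pth.tar"
FALLBACK_CKPT = "models/swingnet_1800.pth.tar"  # assumed fallback, not confirmed

def resolve_eval_checkpoint(environ=None, exists=os.path.exists):
    """Pick the checkpoint path: explicit override, then best, then fallback."""
    environ = os.environ if environ is None else environ
    override = environ.get("GOLFDB_EVAL_CKPT", "").strip()
    if override:
        return override      # an explicit override always wins
    if exists(BEST_CKPT):
        return BEST_CKPT     # preferred when training has produced it
    return FALLBACK_CKPT     # assumption: fall back to the released weights
```

Passing `exists` as a parameter is only there so the lookup can be exercised without touching the filesystem.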
- Follow the steps above to download the pre-trained weights. Then, in the terminal:
python3 test_video.py -p test_video.mp4
- Note: This code requires the sample video to be cropped and cut to bound a single golf swing. I used online video cropping and cutting tools for my golf swing video. See test_video.mp4 for reference.
If you convert models for iOS using Scripts/convert_model.py:
- The converter now prefers models/swingnet_best.pth.tar automatically.
- You can override the checkpoint path with:
GOLFDB_CONVERT_CKPT=/absolute/path/to/checkpoint.pth.tar python ../../Scripts/convert_model.py
- Conversion loading is compatible with PyTorch 2.6+ checkpoint defaults.
For on-device inference, preprocessing must exactly match training:
- Mean: [0.485, 0.456, 0.406]
- Std: [0.229, 0.224, 0.225]
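As a reference for porting the preprocessing, here is a minimal per-pixel sketch. It assumes 8-bit RGB input that is first scaled to [0, 1] and that the three values are in R, G, B order (the usual ImageNet convention); confirm both against the training pipeline before relying on it. The function name is hypothetical.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb):
    """Map one 8-bit (R, G, B) pixel to the normalized floats fed to the model."""
    if len(rgb) != 3:
        raise ValueError("expected exactly three channel values")
    # Assumption: scale to [0, 1] first, then subtract mean and divide by std.
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, MEAN, STD))

black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
```

A quick sanity check on device is to feed a solid black frame and compare the resulting tensor values with what the Python pipeline produces for the same frame.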
Good luck!