
aditya17varma/golfdb

 
 


GolfDB: A Video Database for Golf Swing Sequencing

The code in this repository is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.

Introduction

GolfDB is a high-quality video dataset created for general recognition applications in the sport of golf, and specifically for the task of golf swing sequencing.

This repo contains a simple PyTorch implementation of the SwingNet baseline model presented in the paper. The provided pre-trained model was trained on split 1 without any data augmentation and achieved an average PCE (Percentage of Correct Events) of 71.5%. The 76.1% PCE reported in the paper is attributed to data augmentation, including horizontal flipping and affine transformations.

If you use this repo please cite the GolfDB paper:

@InProceedings{McNally_2019_CVPR_Workshops,
author = {McNally, William and Vats, Kanav and Pinto, Tyler and Dulhanty, Chris and McPhee, John and Wong, Alexander},
title = {GolfDB: A Video Database for Golf Swing Sequencing},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
month = {June},
year = {2019}
}

Dependencies

Getting Started

Run generate_splits.py to convert the .mat dataset file to a dataframe and generate the 4 splits.

Train

  • I have provided the preprocessed video clips for a frame size of 160x160 (download here). Place 'videos_160' in the data directory. If you wish to use a different input configuration, you must download the YouTube videos (URLs are provided in the dataset) and preprocess them yourself. I have provided preprocess_videos.py to help with that.

  • Download the MobileNetV2 pretrained weights from this repository and place 'mobilenet_v2.pth.tar' in the root directory.

  • Run train.py
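If you preprocess the videos yourself, each frame has to end up as a square input. One common approach is to scale the longer side to the target size and pad the shorter side, so the golfer is not stretched. This approach is an assumption here and is not taken from preprocess_videos.py. The hypothetical helper `letterbox_params` only computes the resized dimensions and the padding. The actual resize and border calls (for example with OpenCV) are omitted.

```python
def letterbox_params(width: int, height: int, dim: int = 160):
    """Return (new_w, new_h, top, bottom, left, right) that fit a frame into a
    dim x dim square without distorting its aspect ratio (hypothetical helper)."""
    if width <= 0 or height <= 0:
        raise ValueError("frame dimensions must be positive")
    scale = dim / max(width, height)  # the longer side becomes exactly `dim`
    new_w = round(width * scale)
    new_h = round(height * scale)
    # Split the leftover pixels between opposite sides; an odd pixel goes
    # to the bottom/right.
    delta_w, delta_h = dim - new_w, dim - new_h
    top, left = delta_h // 2, delta_w // 2
    return new_w, new_h, top, delta_h - top, left, delta_w - left


print(letterbox_params(1280, 720))  # (160, 90, 35, 35, 0, 0)
```

A landscape 1280x720 frame becomes 160x90 with 35 rows of padding above and below.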

Optimized training (speed + accuracy)

train.py now includes:

  • Mixed precision training (AMP) and gradient scaling
  • cuDNN benchmark mode and the "high" float32 matmul precision setting on CUDA (which permits faster TF32 matmuls)
  • DataLoader optimizations (pin_memory, persistent_workers, prefetch_factor)
  • OneCycleLR scheduler + gradient clipping
  • Periodic validation and automatic best-checkpoint saving (models/swingnet_best.pth.tar)
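The code below re-derives PyTorch's default two-phase cosine one-cycle curve in pure Python. It shows what the OneCycleLR schedule does with the defaults listed below (GOLFDB_LR=0.001, GOLFDB_ONECYCLE_PCT_START=0.1, 2000 steps). It assumes that GOLFDB_LR is passed as `max_lr` and that train.py keeps `div_factor=25`, `final_div_factor=1e4`, and `three_phase=False` at their PyTorch defaults. This README does not state those settings. Momentum cycling is ignored.

```python
import math


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.1,
                 div_factor=25.0, final_div_factor=1e4):
    """Learning rate at `step`, mirroring the default (cosine, two-phase)
    behavior of torch.optim.lr_scheduler.OneCycleLR."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # last step of the warmup phase
    last_step = total_steps - 1

    def cos_anneal(start, end, pct):
        # Cosine interpolation from `start` (pct=0) to `end` (pct=1).
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    return cos_anneal(max_lr, min_lr, (step - warmup_end) / (last_step - warmup_end))
```

Under these assumptions, the learning rate starts at 0.001 / 25 = 4e-05, peaks at 0.001 at step 199, and ends near 4e-09 at step 1999.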

Environment knobs (all optional):

  • GOLFDB_ITERATIONS (default: 2000 optimizer steps)
  • GOLFDB_SAVE_INTERVAL (default: 100)
  • GOLFDB_SEQ_LENGTH (default: 64)
  • GOLFDB_BATCH_SIZE (default: 22)
  • GOLFDB_NUM_WORKERS (default: 6)
  • GOLFDB_USE_AMP (default: 1)
  • GOLFDB_USE_COMPILE (default: 0)
  • GOLFDB_PIN_MEMORY (default: 1)
  • GOLFDB_PERSISTENT_WORKERS (default: 1)
  • GOLFDB_PREFETCH_FACTOR (default: 4)
  • GOLFDB_FREEZE_LAYERS (default: 0)
  • GOLFDB_LR (default: 0.001)
  • GOLFDB_ONECYCLE_PCT_START (default: 0.1)
  • GOLFDB_GRAD_ACCUM_STEPS (default: 1)
  • GOLFDB_LOSS (default: ce, options: ce, focal)
  • GOLFDB_FOCAL_GAMMA (default: 2.0)
  • GOLFDB_LABEL_SMOOTHING (default: 0.0, CE only)
  • GOLFDB_USE_AUGMENT (default: 1)
  • GOLFDB_AUG_FLIP_P (default: 0.5)
  • GOLFDB_AUG_ROTATE (default: 8.0)
  • GOLFDB_AUG_TRANSLATE (default: 0.05)
  • GOLFDB_AUG_SCALE_MIN (default: 0.95)
  • GOLFDB_AUG_SCALE_MAX (default: 1.05)
  • GOLFDB_AUG_BRIGHTNESS (default: 0.12)
  • GOLFDB_AUG_CONTRAST (default: 0.12)
  • GOLFDB_MAX_GRAD_NORM (default: 1.0)
  • GOLFDB_LOG_EVERY (default: 10)
  • GOLFDB_EVAL_INTERVAL (default: 100)
  • GOLFDB_EVAL_NUM_WORKERS (default: min(max(num_workers, 1), 4))
  • GOLFDB_EVAL_DISP (default: 0)
  • GOLFDB_DATALOADER_TIMEOUT_S (default: 60, forced to 0 when GOLFDB_NUM_WORKERS=0)
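The knobs above are plain environment variables. Below is a minimal sketch of reading a subset of them with typed defaults. It includes the two derived defaults listed: eval workers and the DataLoader timeout. The helper names `env_int`, `env_float`, `env_flag`, and `load_config` are hypothetical, and train.py's actual parsing may differ.

```python
import os


def env_int(name, default):
    return int(os.environ.get(name, default))


def env_float(name, default):
    return float(os.environ.get(name, default))


def env_flag(name, default):
    # "1", "true", "yes" (any case) mean enabled; anything else means disabled.
    return os.environ.get(name, str(default)).strip().lower() in {"1", "true", "yes"}


def load_config():
    cfg = {
        "iterations": env_int("GOLFDB_ITERATIONS", 2000),
        "seq_length": env_int("GOLFDB_SEQ_LENGTH", 64),
        "batch_size": env_int("GOLFDB_BATCH_SIZE", 22),
        "num_workers": env_int("GOLFDB_NUM_WORKERS", 6),
        "use_amp": env_flag("GOLFDB_USE_AMP", 1),
        "lr": env_float("GOLFDB_LR", 0.001),
        "loss": os.environ.get("GOLFDB_LOSS", "ce"),
    }
    if cfg["loss"] not in {"ce", "focal"}:
        raise ValueError(f"GOLFDB_LOSS must be 'ce' or 'focal', got {cfg['loss']!r}")
    # Derived default: min(max(num_workers, 1), 4) unless set explicitly.
    cfg["eval_num_workers"] = env_int(
        "GOLFDB_EVAL_NUM_WORKERS", min(max(cfg["num_workers"], 1), 4))
    # The timeout applies to collecting batches from worker processes,
    # so it is forced to 0 when no workers are used.
    timeout = env_int("GOLFDB_DATALOADER_TIMEOUT_S", 60)
    cfg["timeout_s"] = 0 if cfg["num_workers"] == 0 else timeout
    return cfg
```

For example, `GOLFDB_NUM_WORKERS=0` yields `eval_num_workers == 1` and `timeout_s == 0`.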

Sample commands:

High-performance training on GPU:

GOLFDB_NUM_WORKERS=8 \
GOLFDB_USE_AMP=1 \
GOLFDB_USE_COMPILE=1 \
GOLFDB_PIN_MEMORY=1 \
GOLFDB_PERSISTENT_WORKERS=1 \
GOLFDB_PREFETCH_FACTOR=4 \
GOLFDB_FREEZE_LAYERS=0 \
GOLFDB_LR=0.001 \
GOLFDB_EVAL_INTERVAL=100 \
GOLFDB_EVAL_NUM_WORKERS=4 \
GOLFDB_LOG_EVERY=20 \
python train.py

Accuracy-first preset (stronger regularization + longer schedule):

GOLFDB_ITERATIONS=6000 \
GOLFDB_SEQ_LENGTH=96 \
GOLFDB_BATCH_SIZE=12 \
GOLFDB_GRAD_ACCUM_STEPS=2 \
GOLFDB_NUM_WORKERS=8 \
GOLFDB_USE_AUGMENT=1 \
GOLFDB_LOSS=focal \
GOLFDB_FOCAL_GAMMA=2.0 \
GOLFDB_LR=0.0008 \
GOLFDB_USE_AMP=1 \
GOLFDB_USE_COMPILE=1 \
GOLFDB_EVAL_INTERVAL=200 \
GOLFDB_EVAL_NUM_WORKERS=4 \
python train.py
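With gradient accumulation, each optimizer step sees batch size times accumulation steps clips. The arithmetic below compares the two presets. It assumes that each clip contributes GOLFDB_SEQ_LENGTH frames. It also assumes the high-performance preset keeps the default batch size (22), sequence length (64), and accumulation (1).

```python
def per_step_budget(batch_size, accum_steps, seq_length):
    """Clips and frames consumed per optimizer step."""
    clips = batch_size * accum_steps
    return clips, clips * seq_length


print(per_step_budget(22, 1, 64))  # (22, 1408)  high-performance preset
print(per_step_budget(12, 2, 96))  # (24, 2304)  accuracy-first preset
```

A single forward pass of the accuracy-first preset holds 12 * 96 = 1152 frames, fewer than the 1408 of the defaults. This is how it affords a longer sequence while still taking larger effective steps.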

Debug data loading / path issues:

GOLFDB_DEBUG_TRAIN=1 \
GOLFDB_DEBUG_DATALOADER=1 \
GOLFDB_NUM_WORKERS=0 \
python train.py

Optimization rationale (interview notes)

The project optimizations were grouped into three buckets: throughput, generalization, and training stability.

Throughput (faster training):

  • AMP (torch.amp.autocast + GradScaler) reduces GPU memory and improves tensor-core utilization.
  • torch.compile can fuse graph segments and reduce Python overhead.
  • Data pipeline tuning (num_workers, pin_memory, persistent_workers, prefetch_factor) keeps the GPU fed.
  • Logging less often (GOLFDB_LOG_EVERY) reduces host-side synchronization, since reading a loss value back to the CPU waits for the GPU to finish.

Generalization (better validation PCE):

  • Sequence-consistent augmentation (RandomAugment) applies one geometric/photometric transform to all frames in a clip, preserving temporal coherence.
  • Configurable objective (ce or focal): focal loss down-weights well-classified examples, which helps with class imbalance and emphasizes hard examples.
  • Optional label smoothing for CE improves calibration and reduces over-confidence.
  • Longer schedules + best-checkpoint-by-PCE selection avoid relying on final-step weights.
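In sequence-consistent augmentation, random parameters are drawn once per clip, not once per frame. The sketch below shows that idea on frames represented as nested lists of grayscale values in [0, 1]. It is not the repo's RandomAugment. It covers only the flip and brightness/contrast parts, using the GOLFDB_AUG_FLIP_P, GOLFDB_AUG_BRIGHTNESS, and GOLFDB_AUG_CONTRAST defaults. How those parameters are sampled here (uniform ranges, contrast around mid-grey) is an assumption.

```python
import random


def sample_clip_params(rng, flip_p=0.5, brightness=0.12, contrast=0.12):
    """Draw ONE set of augmentation parameters for a whole clip."""
    return {
        "flip": rng.random() < flip_p,
        "brightness": rng.uniform(-brightness, brightness),  # additive shift
        "contrast": rng.uniform(1 - contrast, 1 + contrast),  # gain around 0.5
    }


def apply_to_frame(frame, p):
    """Apply fixed parameters to one frame given as rows of floats in [0, 1]."""
    rows = [row[::-1] for row in frame] if p["flip"] else [list(row) for row in frame]
    # Scale contrast around mid-grey, shift brightness, then clamp to [0, 1].
    return [[min(1.0, max(0.0, (v - 0.5) * p["contrast"] + 0.5 + p["brightness"]))
             for v in row] for row in rows]


def augment_clip(clip, rng):
    """Augment every frame of a clip with the same sampled parameters."""
    params = sample_clip_params(rng)
    return [apply_to_frame(frame, params) for frame in clip], params
```

If a fresh flip decision were drawn per frame, the golfer would jump between left- and right-handed from frame to frame. Clip-level sampling avoids exactly that temporal incoherence.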

Training stability (fewer bad runs):

  • OneCycleLR warms the learning rate up over the first GOLFDB_ONECYCLE_PCT_START fraction of steps, then anneals it toward zero for the rest of the run.
  • Gradient clipping prevents unstable updates in the CNN+LSTM stack.
  • Gradient accumulation enables larger effective batch size without exceeding VRAM.
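  • Explicit checkpoint loading with weights_only=False: PyTorch 2.6 changed the torch.load default to weights_only=True, which refuses to unpickle anything beyond tensors, primitive types, and allowlisted classes. Only load trusted checkpoints this way, since unpickling can execute code.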
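The clipping and accumulation bullets interact. Gradients are accumulated over GOLFDB_GRAD_ACCUM_STEPS micro-batches, then clipped to GOLFDB_MAX_GRAD_NORM once before the optimizer step. The pure-Python sketch below mimics the global-norm rule of torch.nn.utils.clip_grad_norm_ on a flat list of numbers, and the helper names are hypothetical. With AMP, PyTorch's documented pattern is to call scaler.unscale_(optimizer) before clipping, so the threshold applies to the true gradients rather than the scaled ones.

```python
import math


def accumulate(micro_batch_grads, accum_steps):
    """Sum per-micro-batch gradients, each divided by accum_steps. This equals
    what backpropagating loss / accum_steps on every micro-batch produces."""
    total = [0.0] * len(micro_batch_grads[0])
    for grads in micro_batch_grads:
        for i, g in enumerate(grads):
            total[i] += g / accum_steps
    return total


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale gradients in place so their combined L2 norm is at most max_norm.
    Returns the norm measured before clipping."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (total_norm + eps))  # never scales up
    for i in range(len(grads)):
        grads[i] *= coef
    return total_norm


grads = accumulate([[3.0, 0.0], [5.0, 8.0]], accum_steps=2)  # [4.0, 4.0]
norm = clip_by_global_norm(grads, max_norm=1.0)  # about 5.657 before clipping
```

After clipping, both entries are about 0.707, so the update direction is preserved while its length is capped.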

Tradeoffs to mention:

  • More augmentation usually improves robustness, but too much can distort event timing cues.
  • Focal loss can improve minority-event learning, but may underperform CE on well-balanced subsets.
  • torch.compile improves speed on many setups, but can increase startup/compile time.
  • Higher seq_length adds temporal context but increases memory and runtime per step.
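The two objectives can be compared on a single example. The sketch below is not the repo's implementation. It works on one vector of logits, omits class weights, and omits the alpha balancing term of the original focal-loss formulation. Label smoothing follows the convention used by torch.nn.CrossEntropyLoss: the target puts (1 - eps) on the true class plus eps/K on every class.

```python
import math


def log_softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def cross_entropy(logits, target, label_smoothing=0.0):
    """CE against a smoothed target distribution over K classes."""
    logp = log_softmax(logits)
    nll = -logp[target]
    smooth = -sum(logp) / len(logits)  # CE against a uniform target
    return (1 - label_smoothing) * nll + label_smoothing * smooth


def focal_loss(logits, target, gamma=2.0):
    """CE scaled by (1 - p_t)**gamma; gamma=0 recovers plain CE."""
    logp_t = log_softmax(logits)[target]
    p_t = math.exp(logp_t)
    return -((1 - p_t) ** gamma) * logp_t
```

For a confidently correct prediction, (1 - p_t)**gamma is close to 0, so focal loss contributes little for easy classes. This is also why it can underperform CE when the classes are already well balanced.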

Common tuning order:

  1. Stabilize pipeline (num_workers, AMP, eval interval, best-checkpoint logic).
  2. Tune optimization (LR, OneCycle shape, accumulation, clipping).
  3. Tune objective (ce + smoothing vs focal).
  4. Tune augmentation strength and sequence length.

Evaluate

  • Train your own model by following the steps above or download the pre-trained weights here. Create a 'models' directory if it does not already exist and place 'swingnet_1800.pth.tar' in it.

  • Run eval.py. If using the pre-trained weights provided, the PCE should be 0.715.

Evaluate the best checkpoint from training:

python eval.py

eval.py uses models/swingnet_best.pth.tar by default when present. Override with:

GOLFDB_EVAL_CKPT=models/swingnet_1800.pth.tar python eval.py
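The checkpoint selection described above can be sketched as a small resolver. It checks the environment override first, then the best checkpoint if present. The final fallback to swingnet_1800.pth.tar is an assumption: this README only says the best checkpoint is used "when present". The function name is hypothetical. The converter's GOLFDB_CONVERT_CKPT override is described as working the same way.

```python
import os


def resolve_eval_checkpoint(models_dir="models"):
    """Pick the checkpoint to evaluate (hypothetical helper)."""
    override = os.environ.get("GOLFDB_EVAL_CKPT")
    if override:
        return override
    best = os.path.join(models_dir, "swingnet_best.pth.tar")
    if os.path.exists(best):
        return best
    # Assumed fallback: the released pre-trained weights.
    return os.path.join(models_dir, "swingnet_1800.pth.tar")
```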

Test your own video

  • Follow the steps above to download the pre-trained weights. Then, in the terminal, run: python3 test_video.py -p test_video.mp4

  • Note: This code requires the input video to be cropped and trimmed so that it contains a single golf swing. I used online video cropping and trimming tools for my golf swing video. See test_video.mp4 for reference.
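A trimmed swing can still be longer than one input sequence, so inference runs window by window. The sketch below shows that bookkeeping. It assumes the model emits 9 per-frame class probabilities (8 events plus a no-event class) and that the windows are non-overlapping chunks of seq_length frames. `predict_window` is a hypothetical stand-in for the SwingNet forward pass plus softmax.

```python
def windows(num_frames, seq_length):
    """Non-overlapping [start, end) frame ranges covering the whole clip."""
    return [(s, min(s + seq_length, num_frames)) for s in range(0, num_frames, seq_length)]


def detect_events(num_frames, seq_length, predict_window, num_events=8):
    """Stitch per-frame probabilities from every window, then return the frame
    with the highest probability for each event (the no-event class is ignored)."""
    probs = []
    for start, end in windows(num_frames, seq_length):
        probs.extend(predict_window(start, end))  # one row of class probs per frame
    return [max(range(len(probs)), key=lambda f: probs[f][e]) for e in range(num_events)]
```

Picking the arg-max frame per event guarantees exactly one prediction for each of the 8 events, even if the model never ranks an event above the no-event class.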

iOS / CoreML deployment notes

If you convert models for iOS using Scripts/convert_model.py:

  • The converter now prefers models/swingnet_best.pth.tar automatically.
  • You can override the checkpoint path with:
GOLFDB_CONVERT_CKPT=/absolute/path/to/checkpoint.pth.tar python ../../Scripts/convert_model.py
  • Conversion loading is compatible with PyTorch 2.6+ checkpoint defaults.

For on-device inference, preprocessing must exactly match training:

  • Mean: [0.485, 0.456, 0.406]
  • Std: [0.229, 0.224, 0.225]
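You can check an on-device pipeline by computing what a known pixel should become. The sketch assumes these ImageNet statistics are applied per RGB channel (not BGR) after scaling 8-bit values to [0, 1]. That is the standard convention, but it is an assumption about this pipeline. The sketch also derives an equivalent per-channel form, value * scale_c + bias_c, on raw 0 to 255 values. coremltools' ImageType takes a single scale shared by all channels, so with three different std values the per-channel division may need to happen inside the converted model.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_rgb(pixel):
    """Normalize one 8-bit RGB pixel: (v / 255 - mean) / std per channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(pixel, MEAN, STD))


def scale_and_bias():
    """Per-channel (scale, bias) with v * scale + bias == normalized value."""
    return [(1.0 / (255.0 * s), -m / s) for m, s in zip(MEAN, STD)]
```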

Good luck!

About

GolfDB is a video database for Golf Swing Sequencing, which involves detecting 8 golf swing events in trimmed golf swing videos. This repo demos the baseline model, SwingNet.
