ViGoRL: Visually Grounded Reinforcement Learning

This repository contains the official code implementation for the paper:

Grounded Reinforcement Learning for Visual Reasoning
Gabriel Sarch, Snigdha Saha, Naitik Khandelwal, Ayush Jain, Michael J. Tarr, Aviral Kumar, Katerina Fragkiadaki

Paper | Datasets | Models | Project Page


Overview

Visually-Grounded Reinforcement Learning (ViGoRL) integrates vision-language models (VLMs) with reinforcement learning (RL) to produce spatially grounded reasoning steps explicitly anchored to visual regions. By using a novel multi-turn RL formulation, ViGoRL dynamically zooms into relevant image areas to improve visual attention, grounding, and reasoning capabilities across various visual reasoning tasks.


Features

  • Supported Models

    • Qwen2/Qwen2.5-VL
    • Easily supports additional models compatible with vLLM
  • Supported Algorithms

    • Asynchronous, distributed Monte Carlo Tree Search (MCTS) over VLM thought steps with visual grounding, plus tree-search linearization
    • Supervised Fine-Tuning (SFT)
    • Group Relative Policy Optimization (GRPO) reinforcement learning
    • Multi-turn GRPO with visual feedback loops
    • Async, distributed VLM evaluation
  • Supported Tasks and Benchmarks

    • Spatial Reasoning: SAT-2, BLINK, RoboSpatial
    • Visual Search: V*Bench
    • Web Grounding and Interaction: OS-ATLAS, ScreenSpot (V2 and Pro), VisualWebArena
    • Any vision-text dataset converted to the repository's expected format
  • Key Features

    • Spatial grounding of textual reasoning steps to visual coordinates (see the coordinate-parsing sketch below)
    • Dynamic, zoom-in visual feedback for multi-turn reasoning
    • Efficient, async inference via vLLM and multiprocessing
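
To make the grounding concrete, the sketch below pulls pixel coordinates out of a generated reasoning trace. The "(x, y)" pattern and the sample trace are illustrative assumptions only; the exact output format is defined by the checkpoints and the demo scripts.

import re

def extract_points(reasoning_text):
    # Collect all "(x, y)" pixel coordinates mentioned in a reasoning trace.
    # Assumes grounded steps embed coordinates as "(x, y)"; adjust the pattern
    # to whatever format your checkpoint actually produces.
    return [(int(x), int(y)) for x, y in re.findall(r"\((\d+),\s*(\d+)\)", reasoning_text)]

# Hypothetical trace, for illustration only.
trace = "The blue lamp is at (412, 318); the shelf above it is near (405, 256)."
print(extract_points(trace))  # [(412, 318), (405, 256)]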

Requirements

  • Python >= 3.9
  • transformers >= 4.51.0
  • flash-attn >= 2.4.3
  • vllm >= 0.8.3

Installation

Clone Repository

git clone https://github.com/Gabesarch/grounded-rl.git
cd grounded-rl

Environment Setup

conda create -n grounded-rl python=3.10
conda activate grounded-rl
pip install uv
uv pip install -e .
uv pip install flash-attn --no-build-isolation
uv pip install "deepspeed<=0.16.9"

# RL dependencies
cd src/trainer/rl
uv pip install -e .
cd ../../..

# SFT dependencies
cd src/trainer/offline
uv pip install -e ".[torch,metrics]"

Data Setup

Set the dataset path environment variable before running scripts:

export DATA_ROOT="/path/to/your/data/root"
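
Downstream scripts resolve dataset paths relative to this variable. A minimal sketch of reading it from Python (the directory layout itself is created by download_data.py below):

import os
from pathlib import Path

# Fail fast if DATA_ROOT is not set before launching training or evaluation scripts.
data_root = Path(os.environ["DATA_ROOT"])
assert data_root.is_dir(), f"DATA_ROOT does not exist: {data_root}"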

Downloading Datasets

Run the provided script to download and extract data:

python download_data.py

Datasets include:

  • Spatial reasoning tasks (SAT-2, BLINK)
  • Visual search tasks (V*Bench)
  • Web grounding and action tasks (OS-ATLAS, ScreenSpot, VisualWebArena)

Model Checkpoints

Pretrained ViGoRL checkpoints are available via Hugging Face:

Task              | 3B Model                           | 7B Model
Visual Search     | ViGoRL Multiturn, MCTS-SFT         | ViGoRL Multiturn, MCTS-SFT
Web Grounding     | ViGoRL, ViGoRL Multiturn, MCTS-SFT | ViGoRL
Web Action        | ViGoRL                             | ViGoRL
Spatial Reasoning | ViGoRL, MCTS-SFT                   | ViGoRL, MCTS-SFT

Demo Scripts

We provide demo scripts to run ViGoRL checkpoints interactively on a single image. Two versions are available:

  • demo/demo_singleturn.py: Run and visualize our single-turn variant, which generates a single output containing multiple grounded reasoning steps for a given image and query.
  • demo/demo_multiturn.py: Run and visualize our multi-turn variant, which iteratively generates reasoning steps and dynamically feeds cropped image regions back to the model until it reaches a final answer.

Single-Turn Demo

Usage Example:

# Spatial reasoning example
python demo/demo_singleturn.py \
    --model gsarch/ViGoRL-7b-Spatial \
    --image demo/examples/LivingRoom.jpg \
    --query "What is above the blue lamp?"

# Web grounding example
python demo/demo_singleturn.py \
    --model gsarch/ViGoRL-7b-Web-Grounding \
    --image demo/examples/APnews.png \
    --query "Description: check sports news"

# Web action example
python demo/demo_singleturn.py \
    --model gsarch/ViGoRL-7b-Web-Action \
    --image demo/examples/osclass_page.png \
    --query "OBJECTIVE: Identify the insect in the picture. Leave a comment with the title \"Questions\" and text containing the insect's identity, with the purpose of confirming with the seller.\n\nPREVIOUS ACTIONS: \n1. type  [INPUT] [] [Questions]\n\nIMAGE:\n"

Multi-Turn Demo

Usage Example:

# Visual search example
python demo/demo_multiturn.py \
    --model gsarch/ViGoRL-Multiturn-7b-Visual-Search \
    --image demo/examples/man.jpg \
    --query "What color is the man's shirt?" \
    --crop_offset 182 \
    --crop_size 672

# Web grounding example
python demo/demo_multiturn.py \
    --model gsarch/ViGoRL-Multiturn-3b-Web-Grounding \
    --image demo/examples/APnews.png \
    --query "Description: check sports news" \
    --draw_dot \
    --crop_offset 100 \
    --crop_size 512
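
Under the hood, each turn crops a window around the point the model grounds on and feeds it back as a new observation. A minimal sketch of that cropping step, assuming the grounded point is available as (x, y) pixel coordinates and --crop_size is the side length of the window (see demo/demo_multiturn.py for the exact handling of --crop_offset and the message format):

from PIL import Image

def crop_around_point(image, x, y, crop_size=672):
    # Return a crop_size x crop_size window roughly centered on (x, y), clamped to the image bounds.
    half = crop_size // 2
    left = max(0, min(x - half, image.width - crop_size))
    top = max(0, min(y - half, image.height - crop_size))
    return image.crop((left, top, left + crop_size, top + crop_size))

# Hypothetical usage: zoom into a point the model grounded on.
full = Image.open("demo/examples/man.jpg")
zoomed = crop_around_point(full, x=350, y=420, crop_size=672)
zoomed.save("zoomed_region.png")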

Repository Structure

grounded-rl
├── data/                     # Data and rollouts
├── scripts/
│   ├── evaluation/           # Evaluation scripts
│   └── mcts/                 # MCTS-related scripts
├── src/
│   ├── vlmsearch/
│   │   ├── arguments.py
│   │   └── tree_search/
│   │       ├── mcts_search.py
│   │       └── single_path_rollouts.py
│   └── trainer/
│       ├── offline/          # SFT (Llama-Factory based)
│       └── rl/               # GRPO (EasyR1 based)
└── download_data.py

Training Pipeline

Step 1: Run MCTS to Generate Grounded Traces

bash scripts/mcts/run_mcts_qwen72b.sh
  • Specify dataset and vLLM port at the top of the script.

Step 2: Linearize MCTS Trees

Single-turn:

python scripts/mcts/build_reasoning_chains_from_mcts.py

Multi-turn:

python scripts/mcts/build_reasoning_chains_from_mcts_multiturn.py
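
Conceptually, linearization enumerates root-to-leaf paths of each search tree to obtain flat reasoning chains for SFT. The sketch below shows the idea on a generic nested node; the actual node schema and the value-based filtering are defined by the scripts above, so treat the field names as placeholders:

def linearize(node, prefix=None):
    # Enumerate root-to-leaf thought sequences from a nested tree node.
    # Placeholder schema: each node has a "text" step and a "children" list;
    # the real scripts additionally score and filter paths by node value.
    prefix = (prefix or []) + [node["text"]]
    if not node.get("children"):
        return [prefix]
    chains = []
    for child in node["children"]:
        chains.extend(linearize(child, prefix))
    return chains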

Step 3: Supervised Fine-Tuning (SFT)

Add reasoning chains to src/trainer/offline/data/dataset_info.json and run:

cd src/trainer/offline
bash examples/train_qwen2_5_vl_sft.sh
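
Registering the chains can be done by hand or with a short script like the one below. The entry follows LLaMA-Factory's dataset_info.json conventions; the dataset key, file name, and column mapping are placeholders to adapt to the fields your linearized chains actually contain:

import json

# Placeholder dataset key and file name for the chains produced in Step 2.
entry = {
    "vigorl_sft_chains": {
        "file_name": "vigorl_sft_chains.json",
        "formatting": "sharegpt",
        "columns": {"messages": "messages", "images": "images"},
    }
}

path = "src/trainer/offline/data/dataset_info.json"
with open(path) as f:
    info = json.load(f)
info.update(entry)
with open(path, "w") as f:
    json.dump(info, f, indent=2)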

Step 4: Reinforcement Learning (GRPO)

Run GRPO on top of the SFT model:

cd src/trainer/rl
bash examples/run_vigorl.sh

Checkpoints are sharded; convert to HF format using:

python src/trainer/rl/scripts/model_merger.py
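
A quick sanity check that the merged output is a standard Hugging Face checkpoint (the directory path is whatever you passed to the merger script):

from transformers import AutoConfig

merged_dir = "/path/to/merged/checkpoint"  # placeholder path
config = AutoConfig.from_pretrained(merged_dir)
print(config.model_type)  # should report a Qwen2.5-VL model type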

Evaluation

Configure the MODEL and SYSTEM_PROMPT variables at the top of each script, then run:

  • Multi-turn visual search & ScreenSpot:
bash scripts/evaluation/eval_multiturn.sh
  • Spatial reasoning (BLINK, SAT-2):
bash scripts/evaluation/eval_spatial.sh
  • Web grounding (ScreenSpot-Pro):
bash scripts/evaluation/eval_web_grounding.sh

FAQ

Q: I’m getting the error KeyError: 'visual.patch_embed.proj.weight'. What should I do?

This error typically indicates a version mismatch between the model and the installed transformers library. To resolve it, pin your transformers version to:

pip install transformers==4.51.0

Q: I'm getting the error ImportError: .../flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: .... What should I do?

This indicates that the installed flash-attn build is incompatible with your environment. To resolve it, reinstall and compile an alternate flash-attn version:

pip install flash-attn==2.5.1.post1 --no-build-isolation
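
After reinstalling, a quick import check from the same environment confirms the extension now loads:

import flash_attn
import torch

# If the undefined-symbol error is resolved, this import succeeds.
print(flash_attn.__version__, torch.__version__, torch.version.cuda)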

Citation

@article{sarch2025vigorl,
    title={Grounded Reinforcement Learning for Visual Reasoning},
    author={Sarch, Gabriel and Saha, Snigdha and Khandelwal, Naitik and Jain, Ayush and Tarr, Michael J and Kumar, Aviral and Fragkiadaki, Katerina},
    year={2025}
}

Acknowledgements

This project builds on LLaMA-Factory and EasyR1 to support visually grounded RL. We thank the authors of both projects for providing such high-performance training frameworks.

License

The code is released under the MIT License.
