# AI Fetch Interface

This document describes the interface specification for neural network models that control the Fetch mobile manipulator robot in ManiSkill/MS-HAB environments. Models following this specification can be used as drop-in replacements for the default control policies.

***

## Quick Reference

| Component          | Specification                                    |
| ------------------ | ------------------------------------------------ |
| **Input**          | Observation (state)                              |
| Input Shape        | `[batch_size, state_dim]` or `[state_dim,]`      |
| Input Data Type    | `float32`                                        |
| Input Range        | Typically `[-inf, inf]` or normalized `[-1, 1]`  |
| **Neural Network** |                                                  |
| Architecture       | MLP / CNN / Transformer / etc.                   |
| Weights Format     | PyTorch (`.pt`) / ONNX (`.onnx`) / Keras (`.h5`) |
| **Output**         | Action Vector                                    |
| Output Shape       | `[batch_size, 13]` or `[13,]`                    |
| Output Data Type   | `float32`                                        |
| Output Range       | `[-1.0, 1.0]` (normalized)                       |

## Overview

The Fetch robot is a mobile manipulator with:

* **7-DOF arm** (shoulder, elbow, wrist joints)
* **2-finger gripper** (mimic-controlled)
* **3-DOF body** (head pan/tilt, torso lift)
* **Mobile base** (2D translation)

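The control interface uses **delta position control** (`pd_joint_delta_pos` mode), where actions specify incremental changes to joint positions rather than absolute targets. The gripper and base are exceptions: per the action tables in this document, the gripper takes a position target and the base takes velocities.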

***

## Model Interface

### Input: Observation Space

**Format:** `Dict[str, Array]` or flattened `Array`

**Observation Mode:** `state` (default) or `rgbd`

#### State Observation (Recommended)

When `obs_mode="state"`, the observation format depends on the environment:

**Format 1: Flattened Tensor** (most common, e.g., `ReplicaCAD_SceneManipulation-v1`)

```python
observation = torch.Tensor[float32]  # Shape: (batch_size, state_dim) or (state_dim,)
# Contains: flattened robot state + task state
# Example: Shape (1, 30) for ReplicaCAD_SceneManipulation-v1
```

**Format 2: Dictionary** (some environments)

```python
observation = {
    "agent": {
        # Robot state (shape: [state_dim])
        "qpos": Array[float32],      # Joint positions (all joints)
        "qvel": Array[float32],       # Joint velocities (all joints)
        # ... other agent-specific state
    },
    "extra": {
        # Task-specific state (object positions, etc.)
        # Shape and content depend on the environment
    },
    "sensor_param": {
        # Camera/sensor parameters
    },
    "sensor_data": {
        # Camera images (if rgbd mode)
    }
}
```

**State Dimensions by Environment:**

* `ReplicaCAD_SceneManipulation-v1`: **30** dimensions (flattened tensor)
* `PickCube-v1`: Typically **40-50** dimensions
* `SequentialTask-v0`: Typically **50-100** dimensions (varies by task)

**Data Type:** `float32`

**Normalization:** Observations may be normalized to `[-1, 1]` range depending on the training setup.

**Model Input Handling:** Models should handle both formats for maximum compatibility:

```python
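import torch

def _flatten_obs(obs):
    """Recursively concatenate the tensors of a (possibly nested) dict along the last dim."""
    if isinstance(obs, dict):
        parts = [_flatten_obs(v) for v in obs.values()]
        parts = [p for p in parts if p.numel() > 0]  # Skip empty groups
        return torch.cat(parts, dim=-1) if parts else torch.empty(0)
    return torch.as_tensor(obs, dtype=torch.float32)

def preprocess_observation(obs):
    """Handle both tensor and dictionary observations."""
    if isinstance(obs, dict):
        # obs['agent'] is itself a dict (qpos, qvel, ...), so flatten it recursively
        state = _flatten_obs(obs['agent'] if 'agent' in obs else obs)
    else:
        # Already a tensor
        state = obs

    # Ensure correct shape
    if state.dim() == 1:
        state = state.unsqueeze(0)  # Add batch dimension if needed

    return state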
```

#### RGBD Observation (Alternative)

When `obs_mode="rgbd"`, the observation includes camera images:

```python
observation = {
    "agent": {...},  # Same as above
    "sensor_data": {
        "rgb": Array[uint8],      # Shape: [H, W, 3] or [N, H, W, 3]
        "depth": Array[float32],   # Shape: [H, W] or [N, H, W]
        # Camera-specific keys
    },
    # ... other keys
}
```

**Note:** For model compatibility, state observations are recommended as they are more compact and environment-agnostic.
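
As a rough illustration of how the RGBD arrays above might be prepared for a CNN backbone, the sketch below scales RGB to `[0, 1]`, appends depth as a fourth channel, and moves the channels first. The function name `rgbd_to_chw` and the `max_depth` clipping value are assumptions made for this example, not part of ManiSkill.

```python
import numpy as np

def rgbd_to_chw(rgb, depth, max_depth=10.0):
    """Stack RGB and depth into a float32 array of shape [..., 4, H, W].

    rgb:   uint8 array, shape [H, W, 3] or [N, H, W, 3]
    depth: float32 array, shape [H, W] or [N, H, W]
    max_depth is a hypothetical clipping distance used to scale depth to [0, 1].
    """
    rgb = np.asarray(rgb, dtype=np.float32) / 255.0
    depth = np.clip(np.asarray(depth, dtype=np.float32), 0.0, max_depth) / max_depth
    rgbd = np.concatenate([rgb, depth[..., None]], axis=-1)  # [..., H, W, 4]
    return np.moveaxis(rgbd, -1, -3)  # Channels-first: [..., 4, H, W]

batch = rgbd_to_chw(
    np.zeros((2, 128, 128, 3), dtype=np.uint8),
    np.ones((2, 128, 128), dtype=np.float32),
)
print(batch.shape, batch.dtype)
```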

***

### Output: Action Space

**Format:** `Array[float32]`

**Shape:** `(13,)` - Fixed dimension

**Range:** `[-1.0, 1.0]` (normalized)

**Action Vector Breakdown:**

```python
action = [
    a0,   # [0]  Arm: shoulder_pan_joint      (delta position)
    a1,   # [1]  Arm: shoulder_lift_joint       (delta position)
    a2,   # [2]  Arm: upperarm_roll_joint     (delta position)
    a3,   # [3]  Arm: elbow_flex_joint         (delta position)
    a4,   # [4]  Arm: forearm_roll_joint        (delta position)
    a5,   # [5]  Arm: wrist_flex_joint          (delta position)
    a6,   # [6]  Arm: wrist_roll_joint         (delta position)
    a7,   # [7]  Gripper: l_gripper_finger_joint (position, mimic-controlled)
    a8,   # [8]  Body: head_pan_joint           (delta position)
    a9,   # [9]  Body: head_tilt_joint          (delta position)
    a10,  # [10] Body: torso_lift_joint         (delta position)
    a11,  # [11] Base: root_x_axis_joint       (velocity: left/right)
    a12,  # [12] Base: root_y_axis_joint       (velocity: forward/back)
]
```
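
For code that needs to address the groups individually, the layout above can be captured as named slices. This is a minimal sketch; `ACTION_SLICES` and `split_action` are illustrative names, not part of ManiSkill.

```python
import numpy as np

# Index ranges taken from the action vector breakdown above
ACTION_SLICES = {
    "arm": slice(0, 7),
    "gripper": slice(7, 8),
    "body": slice(8, 11),
    "base": slice(11, 13),
}

def split_action(action):
    """Split a [..., 13] action array into its named components."""
    action = np.asarray(action, dtype=np.float32)
    if action.shape[-1] != 13:
        raise ValueError(f"Expected last dimension 13, got shape {action.shape}")
    return {name: action[..., idx] for name, idx in ACTION_SLICES.items()}

parts = split_action(np.linspace(-1.0, 1.0, 13))
print({name: part.shape for name, part in parts.items()})
```

Because the slices index the last axis, the same helper works for a single action and for a batch of shape `(batch_size, 13)`.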

#### Detailed Action Components

| Index     | Joint                    | Description                       | Control Type     | Range (normalized) | Physical Range   |
| --------- | ------------------------ | --------------------------------- | ---------------- | ------------------ | ---------------- |
| **0-6**   | **Arm**                  |                                   |                  |                    |                  |
| 0         | `shoulder_pan_joint`     | Shoulder pan (left/right)         | Delta position   | \[-1, 1]           | \[-0.1, 0.1] rad |
| 1         | `shoulder_lift_joint`    | Shoulder lift (up/down)           | Delta position   | \[-1, 1]           | \[-0.1, 0.1] rad |
| 2         | `upperarm_roll_joint`    | Upper arm roll                    | Delta position   | \[-1, 1]           | \[-0.1, 0.1] rad |
| 3         | `elbow_flex_joint`       | Elbow flexion                     | Delta position   | \[-1, 1]           | \[-0.1, 0.1] rad |
| 4         | `forearm_roll_joint`     | Forearm roll                      | Delta position   | \[-1, 1]           | \[-0.1, 0.1] rad |
| 5         | `wrist_flex_joint`       | Wrist flexion                     | Delta position   | \[-1, 1]           | \[-0.1, 0.1] rad |
| 6         | `wrist_roll_joint`       | Wrist roll                        | Delta position   | \[-1, 1]           | \[-0.1, 0.1] rad |
| **7**     | **Gripper**              |                                   |                  |                    |                  |
| 7         | `l_gripper_finger_joint` | Gripper open/close                | Position (mimic) | \[-1, 1]           | \[-0.01, 0.05] m |
| **8-10**  | **Body**                 |                                   |                  |                    |                  |
| 8         | `head_pan_joint`         | Head pan (left/right)             | Delta position   | \[-1, 1]           | \[-0.1, 0.1] rad |
| 9         | `head_tilt_joint`        | Head tilt (up/down)               | Delta position   | \[-1, 1]           | \[-0.1, 0.1] rad |
| 10        | `torso_lift_joint`       | Torso lift                        | Delta position   | \[-1, 1]           | \[-0.1, 0.1] m   |
| **11-12** | **Base**                 |                                   |                  |                    |                  |
| 11        | `root_x_axis_joint`      | Base translation X (left/right)   | Velocity         | \[-1, 1]           | \[-1.0, 1.0] m/s |
| 12        | `root_y_axis_joint`      | Base translation Y (forward/back) | Velocity         | \[-1, 1]           | \[-1.0, 1.0] m/s |

**Note:** The gripper uses a mimic controller where `r_gripper_finger_joint` automatically mirrors `l_gripper_finger_joint`.
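
To make the relationship between the two range columns concrete, the sketch below maps a normalized action to physical units with a per-dimension linear rescale from `[-1, 1]` to the listed `[low, high]`. The bounds are copied from the table; the linear mapping itself is an assumption about how the controllers interpret normalized actions, so treat it as illustrative rather than as the simulator's exact behavior.

```python
import numpy as np

# Physical bounds per action index, copied from the table above
LOW = np.array([-0.1] * 7 + [-0.01] + [-0.1] * 3 + [-1.0] * 2, dtype=np.float32)
HIGH = np.array([0.1] * 7 + [0.05] + [0.1] * 3 + [1.0] * 2, dtype=np.float32)

def denormalize_action(action):
    """Linearly map a [..., 13] action from [-1, 1] to [LOW, HIGH]."""
    action = np.clip(np.asarray(action, dtype=np.float32), -1.0, 1.0)
    return LOW + (action + 1.0) * 0.5 * (HIGH - LOW)

physical = denormalize_action(np.ones(13))
print(physical[0], physical[7], physical[12])
```

Under this mapping, a gripper action of `0` corresponds to the midpoint of its range (0.02 m) rather than to "no change", unlike the delta-position joints, where `0` means no movement.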

***

## Model Architecture Requirements

### Input Processing

The model should accept:

1. **State observations** (recommended):
   * Input shape: `(batch_size, state_dim)` where `state_dim` is:
     * **30** for `ReplicaCAD_SceneManipulation-v1`
     * **40-50** for `PickCube-v1`
     * **50-100** for other environments (varies)
   * Or: Dictionary with `agent` key containing state vector
   * Data type: `float32`
2. **RGBD observations** (optional):
   * Input shape: `(batch_size, H, W, C)` for images
   * May require CNN backbone (ResNet, etc.)
   * Data type: `uint8` for RGB, `float32` for depth

### Output Processing

The model must output:

* **Shape:** `(batch_size, 13)` or `(13,)` for single inference
* **Data type:** `float32`
* **Range:** `[-1.0, 1.0]` (normalized actions)
* **Activation:** `tanh` is commonly used for the final layer

### Example Architectures

#### PyTorch MLP (Minimal)

```python
import torch
import torch.nn as nn

class FetchControlModel(nn.Module):
    def __init__(self, state_dim=30):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 13),
            nn.Tanh()  # Ensures output in [-1, 1]
        )
    
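    def forward(self, obs):
        # Handle both dict and tensor observations
        if isinstance(obs, dict):
            # obs['agent'] is a dict of tensors (qpos, qvel, ...); join them on the last dim.
            # Without an 'agent' key, the dict is assumed to map directly to tensors.
            agent = obs['agent'] if 'agent' in obs else obs
            state = torch.cat(list(agent.values()), dim=-1)
        else:
            state = obs

        # Add a batch dimension for a single observation and remove it again on output
        unbatched = state.dim() == 1
        if unbatched:
            state = state.unsqueeze(0)

        action = self.net(state)
        return action.squeeze(0) if unbatched else action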
```

#### Keras/TensorFlow (Includes Architecture)

```python
import tensorflow as tf
from tensorflow import keras

def create_fetch_model(state_dim=30):
    """Create a Keras model for Fetch control."""
    model = keras.Sequential([
        keras.layers.Dense(256, activation='relu', input_shape=(state_dim,)),
        keras.layers.Dense(256, activation='relu'),
        keras.layers.Dense(256, activation='relu'),
        keras.layers.Dense(13, activation='tanh')  # Output in [-1, 1]
    ])
    return model

# Save model (includes architecture + weights)
model = create_fetch_model()
model.save('fetch_policy.h5')  # Can be loaded directly without architecture code
```

**Advantage of Keras:** The `.h5` file contains both architecture and weights, so it can be loaded without providing the architecture definition:

```python
# Load Keras model (no need to define architecture)
model = keras.models.load_model('fetch_policy.h5')
```

***

## Model Weights Format

### Supported Formats

Models should be saved in one of the following formats:

1. **PyTorch** (`.pt` or `.pth`):

   ```python
   torch.save({
       'model_state_dict': model.state_dict(),
       'metadata': {...},  # Optional: model metadata (the loaders in this document read this key)
   }, 'model.pt')
   ```
2. **ONNX** (`.onnx`):
   * Standardized format, framework-agnostic
   * Can be exported from PyTorch, TensorFlow, etc. and run with an ONNX-compatible runtime such as ONNX Runtime
3. **TensorFlow/Keras** (`.h5` or SavedModel):
   * **Key Advantage:** Keras models include architecture + weights in a single file
   * Can be loaded directly without defining architecture separately
   * Convenient for URL-based distribution: download the file, then load it from the local path
   * Example: `model = keras.models.load_model(model_path)` - no architecture code needed!
4. **HuggingFace Model Hub**:
   * Models can be hosted and loaded via URL
   * Example: `path = huggingface_hub.hf_hub_download('user/fetch-policy', 'model.pt')`, then load the downloaded file with the matching framework

### Model Metadata

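The model file or repository should include metadata for proper loading and usage. The block below is a schema sketch rather than literal JSON: `|` separates allowed values and `#` introduces a comment.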

```json
{
    "model_type": "fetch_control",
    "version": "1.0",
    "architecture": "mlp" | "cnn" | "transformer" | ...,
    "input_type": "state" | "rgbd",
    "input_dim": 30,  # For ReplicaCAD_SceneManipulation-v1, or shape for images
    "output_dim": 13,
    "control_mode": "pd_joint_delta_pos",
    "normalization": {
        "obs_mean": [0.0, 0.0, ...],  # Mean for each observation dimension
        "obs_std": [1.0, 1.0, ...]    # Std for each observation dimension
    },
    "training_config": {
        "env_id": "PickCube-v1" | "ReplicaCAD_SceneManipulation-v1" | ...,
        "algorithm": "SAC" | "PPO" | "BC" | "ACT" | ...,
        "seed": 42,
        "total_timesteps": 1000000
    },
    "framework": "pytorch" | "tensorflow" | "onnx",
    "device": "cuda" | "cpu"
}
```

**Metadata Location:**

* For PyTorch: Include in checkpoint dict or separate `config.json`
* For Keras: Stored in model file or `config.json`
* For HuggingFace: In repository root as `config.json`
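
For the separate-file option, a concrete `config.json` might be written and read back as follows. This is a sketch using only the standard library; the field values are placeholders, and `save_metadata`, `load_metadata`, and the `REQUIRED_KEYS` set are hypothetical names chosen for illustration.

```python
import json
import tempfile
from pathlib import Path

# Keys this sketch treats as mandatory (an assumption, not a fixed standard)
REQUIRED_KEYS = {"model_type", "input_type", "input_dim", "output_dim", "control_mode"}

def save_metadata(metadata: dict, path: Path) -> None:
    """Write the metadata dict as valid JSON."""
    path.write_text(json.dumps(metadata, indent=4))

def load_metadata(path: Path) -> dict:
    """Read metadata and fail early if a required key is absent."""
    metadata = json.loads(path.read_text())
    missing = REQUIRED_KEYS - metadata.keys()
    if missing:
        raise ValueError(f"config.json is missing keys: {sorted(missing)}")
    return metadata

example = {
    "model_type": "fetch_control",
    "version": "1.0",
    "architecture": "mlp",
    "input_type": "state",
    "input_dim": 30,
    "output_dim": 13,
    "control_mode": "pd_joint_delta_pos",
    "framework": "pytorch",
}

config_path = Path(tempfile.mkdtemp()) / "config.json"
save_metadata(example, config_path)
loaded = load_metadata(config_path)
print(loaded["input_dim"], loaded["output_dim"])
```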

***

## Model Loading Interface

### URL-Based Loading

Models should be loadable via URL or local path:

```python
import torch

def load_fetch_model(model_url: str, device: str = "cuda"):
    """
    Load a Fetch control model from URL or local path.
    
    Args:
        model_url: URL (http/https) or local file path
        device: Device to load model on ("cuda" or "cpu")
    
    Returns:
        model: Loaded model ready for inference
        metadata: Model metadata dict
    """
    # Implementation depends on model format
    # Example for PyTorch:
    if model_url.startswith(("http://", "https://")):
        # Download from URL (download_model is a user-provided helper)
        model_path = download_model(model_url)
    else:
        model_path = model_url
    
    # Load model (create_model_from_checkpoint is a user-provided helper
    # that rebuilds the architecture and loads the state dict)
    checkpoint = torch.load(model_path, map_location=device)
    model = create_model_from_checkpoint(checkpoint)
    model.eval()
    return model, checkpoint.get('metadata', {})
```

### Example Usage

```python
import torch
from fetch_model_loader import load_fetch_model

# Load model from URL
model, metadata = load_fetch_model(
    "https://huggingface.co/user/fetch-policy/resolve/main/model.pt"
)

# Or from local path
model, metadata = load_fetch_model("./checkpoints/fetch_policy.pt")

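# Inference (env is an environment created as shown under "Environment Compatibility")
obs, _ = env.reset()  # reset() returns (observation, info)
with torch.no_grad():
    action = model(obs)  # Last dimension: 13
    action = action.clamp(-1.0, 1.0)  # Ensure valid range

env.step(action)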
```

***

## Environment Compatibility

### Required Environment Settings

```python
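import gymnasium as gym
import mani_skill.envs  # Registers the ManiSkill environments with gymnasium

env = gym.make(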
    env_id,  # e.g., "PickCube-v1", "SequentialTask-v0"
    num_envs=1,
    obs_mode="state",  # or "rgbd" if model supports it
    render_mode="rgb_array",
    sim_backend="gpu",  # or "cpu"
    robot_uids="fetch",
    control_mode="pd_joint_delta_pos",  # Must match!
)
```

### Observation Preprocessing

If the model was trained with normalized observations:

```python
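import torch

def preprocess_obs(obs, metadata):
    """Normalize observation using model's statistics."""
    if isinstance(obs, dict):
        # obs['agent'] is a dict of tensors (qpos, qvel, ...); join them on the last dim
        state = torch.cat(list(obs['agent'].values()), dim=-1)
    else:
        state = obs
    
    if 'normalization' in metadata:
        # The statistics are stored as plain lists, so convert them to tensors first
        norm = metadata['normalization']
        mean = torch.as_tensor(norm['obs_mean'], dtype=state.dtype, device=state.device)
        std = torch.as_tensor(norm['obs_std'], dtype=state.dtype, device=state.device)
        state = (state - mean) / (std + 1e-8)
    
    return state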
```
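
The `obs_mean` and `obs_std` values are usually computed from observations collected during training. A minimal sketch of that step, assuming the observations are already stacked into a `[num_samples, state_dim]` array (the function name and the random stand-in data are illustrative):

```python
import numpy as np

def compute_obs_stats(observations):
    """Return per-dimension mean/std as plain lists, ready for the metadata file."""
    observations = np.asarray(observations, dtype=np.float64)
    return {
        "obs_mean": observations.mean(axis=0).tolist(),
        "obs_std": observations.std(axis=0).tolist(),
    }

rng = np.random.default_rng(0)
collected = rng.normal(loc=2.0, scale=3.0, size=(1000, 30))  # Stand-in for real rollouts
stats = compute_obs_stats(collected)

# Normalizing the same data gives roughly zero mean and unit std per dimension
normalized = (collected - np.array(stats["obs_mean"])) / (np.array(stats["obs_std"]) + 1e-8)
print(len(stats["obs_mean"]), len(stats["obs_std"]))
```

Storing the statistics as lists keeps them JSON-serializable, which matches the `normalization` entry in the metadata.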

***

## Testing Model Compatibility

### Validation Checklist

Before using a model, verify:

* [ ] **Input shape matches:** Model accepts observation of correct dimension
* [ ] **Output shape is (13,):** Model outputs 13-dimensional action vector
* [ ] **Output range is \[-1, 1]:** Actions are properly normalized
* [ ] **Control mode matches:** Model trained with `pd_joint_delta_pos`
* [ ] **Observation mode matches:** Model trained with `state` or `rgbd` as specified
* [ ] **Framework compatibility:** Model can be loaded in your environment
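
Several of these checks can be automated from the metadata before any model code runs. The sketch below compares a metadata dict against the settings used for the environment; `check_metadata` and its argument names are hypothetical.

```python
def check_metadata(metadata, state_dim, obs_mode="state",
                   control_mode="pd_joint_delta_pos"):
    """Return a list of human-readable problems; an empty list means all checks pass."""
    problems = []
    if metadata.get("output_dim") != 13:
        problems.append(f"output_dim is {metadata.get('output_dim')}, expected 13")
    if metadata.get("control_mode") != control_mode:
        problems.append(
            f"control_mode is {metadata.get('control_mode')!r}, expected {control_mode!r}"
        )
    if metadata.get("input_type") != obs_mode:
        problems.append(f"input_type is {metadata.get('input_type')!r}, expected {obs_mode!r}")
    # The input dimension is only compared for flat state observations
    if obs_mode == "state" and metadata.get("input_dim") != state_dim:
        problems.append(f"input_dim is {metadata.get('input_dim')}, expected {state_dim}")
    return problems

good = {
    "input_type": "state",
    "input_dim": 30,
    "output_dim": 13,
    "control_mode": "pd_joint_delta_pos",
}
print(check_metadata(good, state_dim=30))
print(check_metadata({**good, "control_mode": "pd_joint_pos"}, state_dim=30))
```

Returning a list instead of raising on the first mismatch lets a caller report every incompatibility at once.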

### Test Script

```python
import torch

def test_model_compatibility(model, env):
    """Test if model is compatible with environment."""
    obs, _ = env.reset()
    
    # Test single observation
    with torch.no_grad():
        action = model(obs)
    
    # Validate output (a leading batch dimension such as (num_envs, 13) is allowed)
    assert action.shape[-1] == 13, f"Expected last dimension 13, got shape {tuple(action.shape)}"
    assert action.min() >= -1.0 and action.max() <= 1.0, \
        f"Actions must be in [-1, 1], got range [{action.min()}, {action.max()}]"
    
    # Test step
    obs, reward, terminated, truncated, info = env.step(action)
    print("✅ Model is compatible!")
    return True
```

***

## Example Model Repositories

### Format for HuggingFace

```
fetch-policy/
├── model.pt              # PyTorch weights
├── config.json           # Model metadata
├── README.md             # Documentation
└── requirements.txt      # Dependencies
```

### Format for Local Storage

```
models/
├── fetch_policy_v1.pt
├── fetch_policy_v1.json  # Metadata
└── fetch_policy_v1_README.md
```

***

## Model Loading from URL

### Supported URL Formats

Models can be loaded from:

1. **HuggingFace Model Hub:**

   ```python
   model_url = "https://huggingface.co/user/fetch-policy/resolve/main/model.pt"
   ```
2. **Direct HTTP/HTTPS URLs:**

   ```python
   model_url = "https://example.com/models/fetch_policy.pt"
   ```
3. **Local file paths:**

   ```python
   model_url = "./checkpoints/fetch_policy.pt"
   model_url = "/path/to/model.pt"
   ```
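
One way to tell these cases apart and derive a stable cache location is sketched below using only the standard library. The function names, the cache directory, and the hashing scheme are assumptions for illustration, and the actual download step is left out.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

def is_remote(model_url: str) -> bool:
    """True for http(s) URLs, False for local paths."""
    return urlparse(model_url).scheme in ("http", "https")

def cache_path_for(model_url: str, cache_dir: Optional[Path] = None) -> Path:
    """Map a remote URL to its own cache file so different models do not overwrite each other."""
    cache_dir = cache_dir or Path(tempfile.gettempdir()) / "fetch_models"
    digest = hashlib.sha256(model_url.encode()).hexdigest()[:16]
    suffix = Path(urlparse(model_url).path).suffix or ".pt"
    return cache_dir / f"{digest}{suffix}"

for source in [
    "https://huggingface.co/user/fetch-policy/resolve/main/model.pt",
    "https://example.com/models/fetch_policy.pt",
    "./checkpoints/fetch_policy.pt",
]:
    target = cache_path_for(source) if is_remote(source) else Path(source)
    print(is_remote(source), target)
```

Keying the cache file on a hash of the URL avoids the situation where two different remote models are written to the same temporary filename.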

### Implementation Example

```python
import torch
import urllib.request
from pathlib import Path

def load_model_from_url(model_url: str, device: str = "cuda"):
    """
    Load a Fetch control model from URL or local path.
    
    Supports:
    - HuggingFace model URLs
    - Direct HTTP/HTTPS URLs
    - Local file paths
    
    Args:
        model_url: URL or local path to model file
        device: Device to load model on ("cuda" or "cpu")
    
    Returns:
        model: Loaded model ready for inference
        metadata: Model metadata dict (if available)
    """
    # Download if URL
    if model_url.startswith("http"):
        print(f"Downloading model from {model_url}...")
        model_path = Path("/tmp/fetch_model.pt")
        urllib.request.urlretrieve(model_url, model_path)
    else:
        model_path = Path(model_url)
    
    if not model_path.exists():
        raise FileNotFoundError(f"Model not found: {model_url}")
    
    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)
    
    # Extract model state and metadata
    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
        metadata = checkpoint.get('metadata', {})
    else:
        state_dict = checkpoint
        metadata = {}
    
    # Create model architecture (user must provide)
    # This is framework-specific
    model = create_model_architecture(metadata)  # User implements this
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    
    return model, metadata
```

### Keras/TensorFlow Models

For Keras models (which include architecture + weights in a single file):

```python
import tensorflow as tf
from tensorflow import keras
import urllib.request
from pathlib import Path

def load_keras_model_from_url(model_url: str):
    """
    Load Keras model from URL.
    
    Keras models (.h5 or SavedModel) include both architecture and weights,
    so no need to specify architecture separately - this is the key advantage!
    
    Args:
        model_url: URL or local path to .h5 or SavedModel directory
    
    Returns:
        model: Loaded Keras model ready for inference
    """
    if model_url.startswith("http"):
        # Download using Keras utility or manual download
        print(f"Downloading model from {model_url}...")
        model_path = tf.keras.utils.get_file(
            "fetch_model.h5",
            model_url,
            cache_subdir="models",
            cache_dir="/tmp"
        )
    else:
        model_path = model_url
    
    if not Path(model_path).exists():
        raise FileNotFoundError(f"Model not found: {model_url}")
    
    # Keras automatically loads architecture + weights
    # This is why Keras is convenient - no need to define architecture!
    model = keras.models.load_model(model_path)
    
    print(f"✅ Model loaded from {model_path}")
    print(f"   Input shape: {model.input_shape}")
    print(f"   Output shape: {model.output_shape}")
    
    return model

# Usage
model = load_keras_model_from_url("https://example.com/models/fetch_policy.h5")
# Or from HuggingFace (if they support direct .h5 links)
model = load_keras_model_from_url("https://huggingface.co/user/fetch-policy/resolve/main/model.h5")
```

**Key Advantage of Keras:** The model file is self-contained - it includes:

* Architecture definition (layers, connections)
* Weights (trained parameters)
* Optimizer state (optional)
* Training configuration (optional)

This means you can load a Keras model with just:

```python
model = keras.models.load_model(model_path)  # Local path; download remote files first
```

No need to define the architecture separately!

***

## Summary

| Component             | Specification                                                                          |
| --------------------- | -------------------------------------------------------------------------------------- |
| **Input**             | Observation dict or flattened array, shape `(batch_size, state_dim)` or `(state_dim,)` |
| **Input Dimension**   | Typically 30-100 (varies by environment)                                               |
| **Output**            | Action array, shape `(13,)`, range `[-1.0, 1.0]`, dtype `float32`                      |
| **Control Mode**      | `pd_joint_delta_pos` (required)                                                        |
| **Robot**             | Fetch mobile manipulator                                                               |
| **Action Components** | 7 arm joints + 1 gripper + 3 body joints + 2 base velocities                           |
| **Model Format**      | PyTorch (`.pt`), ONNX (`.onnx`), or Keras (`.h5`)                                      |
| **Loading**           | URL (HTTP/HTTPS/HuggingFace) or local path supported                                   |
| **Architecture**      | User-defined (MLP, CNN, Transformer, etc.)                                             |

***

## Action Vector Visualization

**Action Vector (13 dimensions)**

| Index     | Component        | Joint Name               | Control Type | Description                           |
| --------- | ---------------- | ------------------------ | ------------ | ------------------------------------- |
| **0-6**   | **Arm (7 DOF)**  |                          |              |                                       |
| 0         | Shoulder Pan     | `shoulder_pan_joint`     | Δpos         | Shoulder rotation (left/right)        |
| 1         | Shoulder Lift    | `shoulder_lift_joint`    | Δpos         | Shoulder elevation (up/down)          |
| 2         | Upper Arm Roll   | `upperarm_roll_joint`    | Δpos         | Upper arm rotation                    |
| 3         | Elbow Flex       | `elbow_flex_joint`       | Δpos         | Elbow flexion                         |
| 4         | Forearm Roll     | `forearm_roll_joint`     | Δpos         | Forearm rotation                      |
| 5         | Wrist Flex       | `wrist_flex_joint`       | Δpos         | Wrist flexion                         |
| 6         | Wrist Roll       | `wrist_roll_joint`       | Δpos         | Wrist rotation                        |
| **7**     | **Gripper**      |                          |              |                                       |
| 7         | Gripper          | `l_gripper_finger_joint` | pos          | Gripper open/close (mimic-controlled) |
| **8-10**  | **Body (3 DOF)** |                          |              |                                       |
| 8         | Head Pan         | `head_pan_joint`         | Δpos         | Head rotation (left/right)            |
| 9         | Head Tilt        | `head_tilt_joint`        | Δpos         | Head tilt (up/down)                   |
| 10        | Torso Lift       | `torso_lift_joint`       | Δpos         | Torso vertical movement               |
| **11-12** | **Base (2 DOF)** |                          |              |                                       |
| 11        | Base X           | `root_x_axis_joint`      | vel          | Base translation X (left/right)       |
| 12        | Base Y           | `root_y_axis_joint`      | vel          | Base translation Y (forward/back)     |

**Legend:**

* **Δpos** = Delta position (incremental change from current position)
* **pos** = Absolute position target
* **vel** = Velocity control

***

## References

* [ManiSkill Documentation](https://maniskill.readthedocs.io/)
* [MS-HAB Repository](https://github.com/...)
* Fetch Robot URDF: Included in ManiSkill assets
* Control Modes: See `mani_skill/agents/robots/fetch/fetch.py`


