Stateful Environment

This tutorial focuses on the Resources Server implementation for environments that maintain state across tool calls within an episode. The rest of the workflow (task data preparation, agent/model configuration, rollout collection, and training) follows the same steps as the single-step tutorial. What changes here is the addition of per-episode session state, tracked through session middleware.


What You’ll Build

A counter environment: the user asks the agent to increment a counter by a given amount and report the final value. The environment seeds the counter with an initial value, and the agent must call increment_counter and then get_counter_value to produce the correct result. As in the multi-step tutorial, the episode is single-turn (one user message, multiple tool calls). The key difference is that the counter value lives on the server as per-session state that persists across tool calls within the episode, keyed by the session ID stored under SESSION_ID_KEY.

Episode Flow

Goal (what the agent is learning)
  - Learn multi-step stateful tool usage: perform actions that change environment state and then read/verify the final state.

Inputs
  - seed input: initial_count (e.g., 3)
  - ground truth for grading: expected_count (e.g., 5)

Flow (state is stored per session_id inside the ResourcesServer)
  1) POST ResourcesServer /seed_session {"initial_count": 3}
     - stores session_id_to_counter[session_id] = 3
  2) POST ModelServer /v1/responses -> function_call: increment_counter({"count": 2})
  3) POST ResourcesServer /increment_counter {"count": 2}
     - counter becomes 5 for this session_id
  4) POST ModelServer /v1/responses -> function_call: get_counter_value({})
  5) POST ResourcesServer /get_counter_value {}
     - returns {"count": 5}
  6) POST ResourcesServer /verify {"expected_count": 5, ...}
     - reward = 1.0 if the stored counter == expected_count, else 0.0
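The six steps above can be traced without any HTTP at all. The sketch below is a plain-Python stand-in, not the NeMo Gym API: each "endpoint" is an ordinary function that receives the session ID explicitly (in the real server the middleware supplies it), and the session ID value is hypothetical.

```python
# Plain-Python trace of the episode flow. No FastAPI or NeMo Gym involved;
# each function mimics one endpoint and takes the session ID explicitly.
from typing import Dict

session_id_to_counter: Dict[str, int] = {}


def seed_session(session_id: str, initial_count: int) -> None:
    # Step 1: (re)initialize this episode's counter.
    session_id_to_counter[session_id] = initial_count


def increment_counter(session_id: str, count: int) -> dict:
    # Step 3: mutate state for this session only.
    current = session_id_to_counter.setdefault(session_id, 0)
    session_id_to_counter[session_id] = current + count
    return {"success": True}


def get_counter_value(session_id: str) -> dict:
    # Step 5: read the state back.
    return {"count": session_id_to_counter.setdefault(session_id, 0)}


def verify(session_id: str, expected_count: int) -> float:
    # Step 6: the reward depends only on the stored state.
    if session_id not in session_id_to_counter:
        return 0.0
    return float(session_id_to_counter[session_id] == expected_count)


sid = "episode-1"  # hypothetical session ID
seed_session(sid, 3)
increment_counter(sid, 2)
print(get_counter_value(sid))  # {'count': 5}
print(verify(sid, 5))          # 1.0
```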

Implementation

File (simplified from resources_servers/example_session_state_mgmt/app.py, using the improved seed_session pattern described in the note under Key Pattern):

# simplified
from typing import Dict

from fastapi import FastAPI, Request
from pydantic import BaseModel, Field

from nemo_gym.base_resources_server import (
    BaseResourcesServerConfig,
    BaseSeedSessionRequest,
    BaseSeedSessionResponse,
    BaseVerifyRequest,
    BaseVerifyResponse,
    SimpleResourcesServer,
)
from nemo_gym.server_utils import SESSION_ID_KEY  # key under which the middleware stores the session ID


class StatefulCounterResourcesServerConfig(BaseResourcesServerConfig):
    pass


# Custom seed request to initialize state
class StatefulCounterSeedSessionRequest(BaseSeedSessionRequest):
    initial_count: int


class IncrementCounterRequest(BaseModel):
    count: int


class IncrementCounterResponse(BaseModel):
    success: bool


class GetCounterValueResponse(BaseModel):
    count: int


class StatefulCounterVerifyRequest(BaseVerifyRequest):
    expected_count: int


class StatefulCounterResourcesServer(SimpleResourcesServer):
    config: StatefulCounterResourcesServerConfig

    # Session state storage - maps session_id -> state
    session_id_to_counter: Dict[str, int] = Field(default_factory=dict)

    def setup_webserver(self) -> FastAPI:
        app = super().setup_webserver()
        app.post("/increment_counter")(self.increment_counter)
        app.post("/get_counter_value")(self.get_counter_value)
        return app

    # Initialize session state
    async def seed_session(
        self,
        request: Request,  # Must include Request to access session
        body: StatefulCounterSeedSessionRequest
    ) -> BaseSeedSessionResponse:
        session_id = request.session[SESSION_ID_KEY]  # Get unique session ID
        self.session_id_to_counter[session_id] = body.initial_count
        return BaseSeedSessionResponse()

    # Stateful tool - modifies session state
    async def increment_counter(
        self,
        request: Request,
        body: IncrementCounterRequest
    ) -> IncrementCounterResponse:
        session_id = request.session[SESSION_ID_KEY]
        counter = self.session_id_to_counter.setdefault(session_id, 0)
        counter += body.count
        self.session_id_to_counter[session_id] = counter
        return IncrementCounterResponse(success=True)

    # Read-only tool
    async def get_counter_value(self, request: Request) -> GetCounterValueResponse:
        session_id = request.session[SESSION_ID_KEY]
        counter = self.session_id_to_counter.setdefault(session_id, 0)
        return GetCounterValueResponse(count=counter)

    # Verify against expected final state
    async def verify(
        self,
        request: Request,
        body: StatefulCounterVerifyRequest
    ) -> BaseVerifyResponse:
        session_id = request.session[SESSION_ID_KEY]

        reward = 0.0
        if session_id in self.session_id_to_counter:
            counter = self.session_id_to_counter[session_id]
            reward = float(body.expected_count == counter)

        return BaseVerifyResponse(**body.model_dump(), reward=reward)


if __name__ == "__main__":
    StatefulCounterResourcesServer.run_webserver()

Key Pattern

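Key per-episode state by the session ID that the framework's session middleware stores in request.session under SESSION_ID_KEY (imported from nemo_gym.server_utils). The middleware is set up automatically in SimpleResourcesServer.setup_webserver() via self.setup_session_middleware(app), so a subclass that calls super().setup_webserver(), as the example does, inherits it.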

To access session state:

  1. Add request: Request as a parameter to any endpoint method

  2. Read the session ID with request.session[SESSION_ID_KEY]

  3. Store state in an instance-level dictionary keyed by session ID
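The three steps above can be sketched with the standard library alone. This is an illustration of the pattern, not NeMo Gym's middleware: ensure_session_id is a hypothetical stand-in that assigns an ID to a per-client session dict once and reuses it on later calls, and the string value given to SESSION_ID_KEY here is an assumption (real code should import the constant from nemo_gym.server_utils). Two interleaved episodes stay isolated because every read and write is keyed by their own session ID.

```python
import uuid
from typing import Dict

# Hypothetical value; real code imports SESSION_ID_KEY from nemo_gym.server_utils.
SESSION_ID_KEY = "session_id"


def ensure_session_id(session: dict) -> str:
    # Hypothetical stand-in for the session middleware: assign an ID once,
    # then return the same ID on every later request from this client.
    if SESSION_ID_KEY not in session:
        session[SESSION_ID_KEY] = uuid.uuid4().hex
    return session[SESSION_ID_KEY]


class CounterStore:
    def __init__(self) -> None:
        # Instance-level dictionary keyed by session ID (step 3).
        self.session_id_to_counter: Dict[str, int] = {}

    def seed(self, session: dict, initial_count: int) -> None:
        self.session_id_to_counter[ensure_session_id(session)] = initial_count

    def increment(self, session: dict, count: int) -> None:
        sid = ensure_session_id(session)
        self.session_id_to_counter[sid] = self.session_id_to_counter.setdefault(sid, 0) + count

    def get(self, session: dict) -> int:
        return self.session_id_to_counter.setdefault(ensure_session_id(session), 0)


store = CounterStore()
episode_a: dict = {}  # each dict plays the role of one client's session
episode_b: dict = {}

store.seed(episode_a, 3)
store.seed(episode_b, 10)
store.increment(episode_a, 2)  # interleaved calls from two episodes...
store.increment(episode_b, 1)  # ...do not interfere with each other
print(store.get(episode_a), store.get(episode_b))  # 5 11
```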

Note

In seed_session, use direct assignment (self.session_id_to_counter[session_id] = body.initial_count) rather than setdefault. setdefault leaves an existing entry untouched, so re-seeding a session that already exists would be silently ignored; this causes subtle bugs when the same session ID is reused across episodes. The current example_session_state_mgmt implementation still uses setdefault in seed_session; the direct assignment shown here is the preferred pattern.

In tool methods such as increment_counter and get_counter_value, setdefault is appropriate: it falls back to 0 if the session was never seeded.
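The difference between the two seeding styles comes down to how dict.setdefault behaves when the key already exists. The sketch below uses a plain dict and a hypothetical session ID that a second episode reuses after the first one left the counter at 5.

```python
from typing import Dict

counters: Dict[str, int] = {}
sid = "reused-session"  # hypothetical ID that a later episode reuses

# Episode 1: seeded with 3, then incremented by 2.
counters[sid] = 3
counters[sid] += 2

# Episode 2 tries to re-seed with 3 via setdefault: the key exists,
# so the stale value 5 is kept and returned.
stale = counters.setdefault(sid, 3)
print(stale)  # 5

# Direct assignment actually resets the counter for the new episode.
counters[sid] = 3
print(counters[sid])  # 3
```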


Rollout Transcript

[Episode start]

Agent -> ResourcesServer: POST /seed_session {"initial_count": 3}
  (ResourcesServer stores session_id_to_counter[session_id] = 3)

User: "Increment the counter by 2, then tell me the current value."

Agent -> ModelServer: POST /v1/responses (tools: increment_counter, get_counter_value)
Model calls tool:
  function_call: increment_counter({"count": 2})

Agent -> ResourcesServer: POST /increment_counter {"count": 2}
ResourcesServer -> Agent:
  {"success": true}
  (counter is now 5 for this session_id)

Agent -> ModelServer: POST /v1/responses
Model calls tool:
  function_call: get_counter_value({})

Agent -> ResourcesServer: POST /get_counter_value {}
ResourcesServer -> Agent:
  {"count": 5}

[Episode end -> grading]

Agent -> ResourcesServer: POST /verify {"expected_count": 5, ...}
ResourcesServer:
  - reads counter for this session_id
  - reward = 1.0 if counter == expected_count else 0.0
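Because verify reads the counter stored on the server rather than the text of the model's final answer, the reward in this example depends only on the state the tool calls actually produced. The sketch below mirrors the verify logic from the implementation as a plain function; the session names and final states are hypothetical.

```python
from typing import Dict


def verify(session_id_to_counter: Dict[str, int], session_id: str, expected_count: int) -> float:
    # Mirrors verify() above: 0.0 unless this session has stored state
    # and that state equals the expected count.
    if session_id not in session_id_to_counter:
        return 0.0
    return float(session_id_to_counter[session_id] == expected_count)


# Hypothetical final states: both sessions were seeded with 3.
state = {
    "correct": 5,          # model called increment_counter({"count": 2})
    "wrong_increment": 4,  # model called increment_counter({"count": 1})
}
print(verify(state, "correct", 5))          # 1.0
print(verify(state, "wrong_increment", 5))  # 0.0
print(verify(state, "never_seeded", 5))     # 0.0
```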
