Stateful Environment#
This tutorial focuses on the Resources Server implementation for environments that maintain state across tool calls within an episode. The full workflow — task data preparation, agent/model configuration, rollout collection, and training — follows the same steps as the single-step tutorial. What changes here is the addition of per-episode session state via middleware.
What You’ll Build#
A counter environment where the user asks the agent to increment a counter by a certain amount and report the final value. The environment seeds the counter with an initial value, and the agent must call increment_counter and get_counter_value in sequence to produce the correct result. Like the multi-step tutorial, this is single-turn (one user message, multiple tool calls). The key difference is that the counter value lives in server-side session state, keyed by the session ID stored under SESSION_ID_KEY, and persists across tool calls within the episode.
Episode Flow#
Goal (what the agent is learning)
- Learn multi-step stateful tool usage: perform actions that change environment state and then read/verify the final state.
Inputs
- seed input: initial_count (e.g., 3)
- ground truth for grading: expected_count (e.g., 5)
Flow (state is stored per session_id inside the ResourcesServer)
1) POST ResourcesServer /seed_session {"initial_count": 3}
- stores session_id_to_counter[session_id] = 3
2) POST ModelServer /v1/responses -> function_call: increment_counter({"count": 2})
3) POST ResourcesServer /increment_counter {"count": 2}
- counter becomes 5 for this session_id
4) POST ModelServer /v1/responses -> function_call: get_counter_value({})
5) POST ResourcesServer /get_counter_value {}
- returns {"count": 5}
6) POST ResourcesServer /verify {"expected_count": 5, ...}
- reward = 1.0 iff stored counter == expected_count
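The flow above can be traced without any servers by modeling the Resources Server's state as a plain dictionary. The sketch below is illustrative only: the function names mirror the endpoints, but the in-process calls stand in for the HTTP requests, and the session ID "episode-1" is a hypothetical stand-in for the ID the middleware would assign.

```python
from typing import Dict

# Stand-in for the server's per-session storage (session_id -> counter).
session_id_to_counter: Dict[str, int] = {}


def seed_session(session_id: str, initial_count: int) -> None:
    # Step 1: seed the counter for this episode.
    session_id_to_counter[session_id] = initial_count


def increment_counter(session_id: str, count: int) -> dict:
    # Step 3: mutate only the state that belongs to this session.
    counter = session_id_to_counter.setdefault(session_id, 0) + count
    session_id_to_counter[session_id] = counter
    return {"success": True}


def get_counter_value(session_id: str) -> dict:
    # Step 5: read the state back.
    return {"count": session_id_to_counter.setdefault(session_id, 0)}


def verify(session_id: str, expected_count: int) -> float:
    # Step 6: reward is 1.0 only when the stored counter matches.
    if session_id not in session_id_to_counter:
        return 0.0
    return float(session_id_to_counter[session_id] == expected_count)


session_id = "episode-1"  # hypothetical; the real ID comes from the middleware
seed_session(session_id, initial_count=3)
increment_counter(session_id, count=2)
observed = get_counter_value(session_id)
reward = verify(session_id, expected_count=5)
print(observed, reward)
```

Steps 2 and 4 (the model deciding which tool to call) are omitted here, since they do not touch the stored state.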
Implementation#
File (simplified from resources_servers/example_session_state_mgmt/app.py, with improved seed_session pattern):
# simplified
from typing import Dict

from fastapi import FastAPI, Request
from pydantic import BaseModel, Field

from nemo_gym.base_resources_server import (
    BaseResourcesServerConfig,
    BaseSeedSessionRequest,
    BaseSeedSessionResponse,
    BaseVerifyRequest,
    BaseVerifyResponse,
    SimpleResourcesServer,
)
from nemo_gym.server_utils import SESSION_ID_KEY  # Critical import!


class StatefulCounterResourcesServerConfig(BaseResourcesServerConfig):
    pass


# Custom seed request to initialize state
class StatefulCounterSeedSessionRequest(BaseSeedSessionRequest):
    initial_count: int


class IncrementCounterRequest(BaseModel):
    count: int


class IncrementCounterResponse(BaseModel):
    success: bool


class GetCounterValueResponse(BaseModel):
    count: int


class StatefulCounterVerifyRequest(BaseVerifyRequest):
    expected_count: int


class StatefulCounterResourcesServer(SimpleResourcesServer):
    config: StatefulCounterResourcesServerConfig

    # Session state storage - maps session_id -> state
    session_id_to_counter: Dict[str, int] = Field(default_factory=dict)

    def setup_webserver(self) -> FastAPI:
        app = super().setup_webserver()
        app.post("/increment_counter")(self.increment_counter)
        app.post("/get_counter_value")(self.get_counter_value)
        return app

    # Initialize session state
    async def seed_session(
        self,
        request: Request,  # Must include Request to access session
        body: StatefulCounterSeedSessionRequest,
    ) -> BaseSeedSessionResponse:
        session_id = request.session[SESSION_ID_KEY]  # Get unique session ID
        self.session_id_to_counter[session_id] = body.initial_count
        return BaseSeedSessionResponse()

    # Stateful tool - modifies session state
    async def increment_counter(
        self,
        request: Request,
        body: IncrementCounterRequest,
    ) -> IncrementCounterResponse:
        session_id = request.session[SESSION_ID_KEY]
        counter = self.session_id_to_counter.setdefault(session_id, 0)
        counter += body.count
        self.session_id_to_counter[session_id] = counter
        return IncrementCounterResponse(success=True)

    # Read-only tool
    async def get_counter_value(self, request: Request) -> GetCounterValueResponse:
        session_id = request.session[SESSION_ID_KEY]
        counter = self.session_id_to_counter.setdefault(session_id, 0)
        return GetCounterValueResponse(count=counter)

    # Verify against expected final state
    async def verify(
        self,
        request: Request,
        body: StatefulCounterVerifyRequest,
    ) -> BaseVerifyResponse:
        session_id = request.session[SESSION_ID_KEY]
        reward = 0.0
        if session_id in self.session_id_to_counter:
            counter = self.session_id_to_counter[session_id]
            reward = float(body.expected_count == counter)
        return BaseVerifyResponse(**body.model_dump(), reward=reward)


if __name__ == "__main__":
    StatefulCounterResourcesServer.run_webserver()
Key Pattern#
Use SESSION_ID_KEY from the request session middleware to maintain per-episode state. The session ID is automatically assigned by the framework’s middleware (set up in SimpleResourcesServer.setup_webserver() via self.setup_session_middleware(app)).
To access session state:
1) Add request: Request as a parameter to any endpoint method.
2) Read the session ID with request.session[SESSION_ID_KEY].
3) Store state in an instance-level dictionary keyed by session ID.
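To see why keying by session ID keeps concurrent episodes apart, here is a minimal sketch that needs no web framework. Everything in it is an assumption made for illustration: make_request is a hypothetical stand-in for a FastAPI Request whose session was populated by the middleware, CounterState is not part of nemo_gym, and the string assigned to SESSION_ID_KEY is a placeholder for the real constant in nemo_gym.server_utils.

```python
from types import SimpleNamespace
from typing import Dict

# Placeholder value: the real constant is imported from nemo_gym.server_utils.
SESSION_ID_KEY = "session_id"


def make_request(session_id: str) -> SimpleNamespace:
    # Hypothetical stand-in for a Request carrying a middleware-assigned session.
    return SimpleNamespace(session={SESSION_ID_KEY: session_id})


class CounterState:
    """Instance-level dictionary keyed by session ID (illustrative only)."""

    def __init__(self) -> None:
        self.session_id_to_counter: Dict[str, int] = {}

    def seed(self, request: SimpleNamespace, initial_count: int) -> None:
        session_id = request.session[SESSION_ID_KEY]
        self.session_id_to_counter[session_id] = initial_count

    def increment(self, request: SimpleNamespace, count: int) -> int:
        session_id = request.session[SESSION_ID_KEY]
        new_value = self.session_id_to_counter.setdefault(session_id, 0) + count
        self.session_id_to_counter[session_id] = new_value
        return new_value


state = CounterState()
episode_a = make_request("episode-a")
episode_b = make_request("episode-b")

# Interleave calls from two episodes; each one only touches its own entry.
state.seed(episode_a, initial_count=3)
state.seed(episode_b, initial_count=10)
state.increment(episode_a, count=2)
state.increment(episode_b, count=1)

print(state.session_id_to_counter)
```

Because every method looks up the session ID first, a single server instance can serve many rollouts at once without their counters colliding.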
Note
In seed_session, use direct assignment (self.session_id_to_counter[session_id] = body.initial_count) rather than setdefault. setdefault silently ignores a re-seed when the session already exists, which can cause subtle bugs if the same session ID is reused across episodes. The current example_session_state_mgmt implementation still uses setdefault in seed_session; the direct assignment shown here is the preferred pattern.
In tool methods like increment_counter and get_counter_value, setdefault is appropriate — it provides a safe fallback of 0 if the session was somehow not initialized.
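The difference between the two seeding styles shows up with nothing more than a dictionary. The sketch below assumes a hypothetical session ID that is reused for a second episode; the values 3, 2, and 10 are arbitrary.

```python
from typing import Dict

session_id = "reused-session"  # hypothetical ID reused across two episodes

# Seeding with setdefault: the second seed is silently ignored.
with_setdefault: Dict[str, int] = {}
with_setdefault.setdefault(session_id, 3)   # episode 1 seeds 3
with_setdefault[session_id] += 2            # episode 1 increments to 5
with_setdefault.setdefault(session_id, 10)  # episode 2 tries to seed 10

# Seeding with direct assignment: the second seed overwrites stale state.
with_assignment: Dict[str, int] = {}
with_assignment[session_id] = 3
with_assignment[session_id] += 2
with_assignment[session_id] = 10

print(with_setdefault[session_id])  # stale value left over from episode 1
print(with_assignment[session_id])  # fresh value for episode 2
```

With setdefault, the second episode would start from the first episode's final value, so any expected_count computed from the new seed would no longer match.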
Rollout Transcript#
[Episode start]
Agent -> ResourcesServer: POST /seed_session {"initial_count": 3}
(ResourcesServer stores session_id_to_counter[session_id] = 3)
User: "Increment the counter by 2, then tell me the current value."
Agent -> ModelServer: POST /v1/responses (tools: increment_counter, get_counter_value)
Model calls tool:
function_call: increment_counter({"count": 2})
Agent -> ResourcesServer: POST /increment_counter {"count": 2}
ResourcesServer -> Agent:
{"success": true}
(counter is now 5 for this session_id)
Agent -> ModelServer: POST /v1/responses
Model calls tool:
function_call: get_counter_value({})
Agent -> ResourcesServer: POST /get_counter_value {}
ResourcesServer -> Agent:
{"count": 5}
[Episode end -> grading]
Agent -> ResourcesServer: POST /verify {"expected_count": 5, ...}
ResourcesServer:
- reads counter for this session_id
- reward = 1.0 if counter == expected_count else 0.0
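Grading depends on the stored counter, not on the text the model produces, so an episode that skips increment_counter fails even if the model states the right number. The sketch below restates the verify logic as a standalone function; grade and the session IDs are hypothetical names chosen for illustration.

```python
from typing import Dict


def grade(
    session_id_to_counter: Dict[str, int],
    session_id: str,
    expected_count: int,
) -> float:
    # Mirrors verify: unknown sessions earn 0.0, otherwise exact match.
    if session_id not in session_id_to_counter:
        return 0.0
    return float(session_id_to_counter[session_id] == expected_count)


# Final stored state for three hypothetical episodes, all seeded with 3
# and all expected to end at 5.
store = {
    "called-both-tools": 5,   # increment_counter(2) was called
    "skipped-increment": 3,   # model answered without calling the tool
}

rewards = {
    "called-both-tools": grade(store, "called-both-tools", 5),
    "skipped-increment": grade(store, "skipped-increment", 5),
    "never-seeded": grade(store, "never-seeded", 5),
}
print(rewards)
```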