Stateful Environment

This tutorial focuses on the Resources Server implementation for environments that maintain state across tool calls within an episode. The full workflow (task data preparation, agent/model configuration, rollout collection, and training) follows the same steps as the single-step tutorial. What changes here is the addition of per-episode session state via middleware.

What You’ll Build

A counter environment where the user asks the agent to increment a counter by a certain amount and report the final value. The environment seeds the counter with an initial value, and the agent must call increment_counter and then get_counter_value to produce the correct result. Like the multi-step tutorial, this is single-turn (one user message, multiple tool calls). The key difference is that the counter value lives in server-side session state that persists across tool calls within the episode, keyed by the session ID read via SESSION_ID_KEY.

Episode Flow

Goal (what the agent is learning)
- Learn multi-step stateful tool usage: perform actions that change environment state and then read/verify the final state.
Inputs
- seed input: initial_count (e.g., 3)
- ground truth for grading: expected_count (e.g., 5)
Flow (state is stored per session_id inside the ResourcesServer)
1) POST ResourcesServer /seed_session {"initial_count": 3}
- stores session_id_to_counter[session_id] = 3
2) POST ModelServer /v1/responses -> function_call: increment_counter({"count": 2})
3) POST ResourcesServer /increment_counter {"count": 2}
- counter becomes 5 for this session_id
4) POST ModelServer /v1/responses -> function_call: get_counter_value({})
5) POST ResourcesServer /get_counter_value {}
- returns {"count": 5}
6) POST ResourcesServer /verify {"expected_count": 5, ...}
- reward = 1.0 iff stored counter == expected_count
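
The flow above can be traced without any servers by replacing the HTTP endpoints with plain functions over a dictionary. This is only a minimal sketch: the function names mirror the endpoints, but the explicit session_id argument and the "episode-1" value are placeholders for what the session middleware would supply in the real server.

```python
from typing import Dict

# Hypothetical in-memory stand-in for the Resources Server's per-session state.
session_id_to_counter: Dict[str, int] = {}


def seed_session(session_id: str, initial_count: int) -> None:
    # Step 1: store the initial counter value for this session.
    session_id_to_counter[session_id] = initial_count


def increment_counter(session_id: str, count: int) -> dict:
    # Step 3: mutate the state that belongs to this session only.
    current = session_id_to_counter.setdefault(session_id, 0)
    session_id_to_counter[session_id] = current + count
    return {"success": True}


def get_counter_value(session_id: str) -> dict:
    # Step 5: read the state back without changing it.
    return {"count": session_id_to_counter.setdefault(session_id, 0)}


def verify(session_id: str, expected_count: int) -> float:
    # Step 6: reward is 1.0 only if the stored counter matches the ground truth.
    if session_id not in session_id_to_counter:
        return 0.0
    return float(session_id_to_counter[session_id] == expected_count)


sid = "episode-1"  # placeholder; the real ID comes from the session middleware
seed_session(sid, 3)
increment_result = increment_counter(sid, 2)
value_result = get_counter_value(sid)
reward = verify(sid, 5)
print(increment_result, value_result, reward)
```

Steps 2 and 4 (the model deciding which tool to call) are omitted here, since they do not touch the environment's state.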

Implementation

File (simplified from resources_servers/example_session_state_mgmt/app.py, with improved seed_session pattern):

# simplified
from typing import Dict

from fastapi import FastAPI, Request
from pydantic import BaseModel, Field

from nemo_gym.base_resources_server import (
    BaseResourcesServerConfig,
    BaseSeedSessionRequest,
    BaseSeedSessionResponse,
    BaseVerifyRequest,
    BaseVerifyResponse,
    SimpleResourcesServer,
)
from nemo_gym.server_utils import SESSION_ID_KEY  # Critical import!

class StatefulCounterResourcesServerConfig(BaseResourcesServerConfig):
    pass

# Custom seed request to initialize state
class StatefulCounterSeedSessionRequest(BaseSeedSessionRequest):
    initial_count: int

class IncrementCounterRequest(BaseModel):
    count: int

class IncrementCounterResponse(BaseModel):
    success: bool

class GetCounterValueResponse(BaseModel):
    count: int

class StatefulCounterVerifyRequest(BaseVerifyRequest):
    expected_count: int

class StatefulCounterResourcesServer(SimpleResourcesServer):
    config: StatefulCounterResourcesServerConfig

    # Session state storage - maps session_id -> state
    session_id_to_counter: Dict[str, int] = Field(default_factory=dict)

    def setup_webserver(self) -> FastAPI:
        app = super().setup_webserver()
        app.post("/increment_counter")(self.increment_counter)
        app.post("/get_counter_value")(self.get_counter_value)
        return app

    # Initialize session state
    async def seed_session(
        self,
        request: Request,  # Must include Request to access session
        body: StatefulCounterSeedSessionRequest,
    ) -> BaseSeedSessionResponse:
        session_id = request.session[SESSION_ID_KEY]  # Get unique session ID
        self.session_id_to_counter[session_id] = body.initial_count
        return BaseSeedSessionResponse()

    # Stateful tool - modifies session state
    async def increment_counter(
        self,
        request: Request,
        body: IncrementCounterRequest,
    ) -> IncrementCounterResponse:
        session_id = request.session[SESSION_ID_KEY]
        counter = self.session_id_to_counter.setdefault(session_id, 0)
        counter += body.count
        self.session_id_to_counter[session_id] = counter
        return IncrementCounterResponse(success=True)

    # Read-only tool
    async def get_counter_value(self, request: Request) -> GetCounterValueResponse:
        session_id = request.session[SESSION_ID_KEY]
        counter = self.session_id_to_counter.setdefault(session_id, 0)
        return GetCounterValueResponse(count=counter)

    # Verify against expected final state
    async def verify(
        self,
        request: Request,
        body: StatefulCounterVerifyRequest,
    ) -> BaseVerifyResponse:
        session_id = request.session[SESSION_ID_KEY]

        reward = 0.0
        if session_id in self.session_id_to_counter:
            counter = self.session_id_to_counter[session_id]
            reward = float(body.expected_count == counter)

        return BaseVerifyResponse(**body.model_dump(), reward=reward)

if __name__ == "__main__":
    StatefulCounterResourcesServer.run_webserver()

Key Pattern

Use SESSION_ID_KEY from the request session middleware to maintain per-episode state. The session ID is automatically assigned by the framework’s middleware (set up in SimpleResourcesServer.setup_webserver() via self.setup_session_middleware(app)).

To access session state:

  1. Add request: Request as a parameter to any endpoint method
  2. Read the session ID with request.session[SESSION_ID_KEY]
  3. Store state in an instance-level dictionary keyed by session ID
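
To see why the dictionary is keyed by session ID, it may help to interleave calls from two episodes against one shared store. The sketch below uses a hypothetical CounterState class and made-up session IDs in place of the real server and middleware; it only illustrates the keying pattern, not the framework's API.

```python
from typing import Dict


class CounterState:
    """Hypothetical stand-in for the server's instance-level state dictionary."""

    def __init__(self) -> None:
        self.session_id_to_counter: Dict[str, int] = {}

    def seed(self, session_id: str, initial_count: int) -> None:
        self.session_id_to_counter[session_id] = initial_count

    def increment(self, session_id: str, count: int) -> None:
        current = self.session_id_to_counter.setdefault(session_id, 0)
        self.session_id_to_counter[session_id] = current + count

    def get(self, session_id: str) -> int:
        return self.session_id_to_counter.setdefault(session_id, 0)


state = CounterState()

# Two episodes share one server instance; their calls arrive interleaved.
state.seed("session-a", 3)
state.seed("session-b", 10)
state.increment("session-a", 2)
state.increment("session-b", 1)
state.increment("session-a", 4)

value_a = state.get("session-a")  # 3 + 2 + 4
value_b = state.get("session-b")  # 10 + 1
print(value_a, value_b)
```

Because every read and write goes through the session ID, neither episode observes the other's increments even though both use the same dictionary.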

In seed_session, use direct assignment (self.session_id_to_counter[session_id] = body.initial_count) rather than setdefault. Using setdefault would silently ignore re-seed attempts if the session already exists, which can cause subtle bugs when the same session ID is reused across episodes. Note: the current example_session_state_mgmt implementation still uses setdefault in seed_session — the direct assignment shown here is the preferred pattern.
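
The difference between the two seeding styles shows up only when a session ID already has an entry. A small sketch with plain dictionaries (the helper names and the "reused-session" ID are illustrative, not part of the framework):

```python
from typing import Dict


def seed_with_setdefault(store: Dict[str, int], session_id: str, initial_count: int) -> None:
    # Only writes when the key is absent; an existing value is left untouched.
    store.setdefault(session_id, initial_count)


def seed_with_assignment(store: Dict[str, int], session_id: str, initial_count: int) -> None:
    # Always overwrites, so a re-seed resets the counter.
    store[session_id] = initial_count


# Both stores start with leftover state from an earlier episode.
stale = {"reused-session": 5}
fresh = {"reused-session": 5}

seed_with_setdefault(stale, "reused-session", 3)
seed_with_assignment(fresh, "reused-session", 3)

print(stale["reused-session"])  # 5: the re-seed was silently ignored
print(fresh["reused-session"])  # 3: the re-seed took effect
```

With the setdefault variant, an episode that expects to start from 3 would actually start from 5, and verification against expected_count would fail for reasons unrelated to the agent's behavior.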

In tool methods like increment_counter and get_counter_value, setdefault is appropriate — it provides a safe fallback of 0 if the session was somehow not initialized.


Rollout Transcript

[Episode start]
Agent -> ResourcesServer: POST /seed_session {"initial_count": 3}
(ResourcesServer stores session_id_to_counter[session_id] = 3)
User: "Increment the counter by 2, then tell me the current value."
Agent -> ModelServer: POST /v1/responses (tools: increment_counter, get_counter_value)
Model calls tool:
function_call: increment_counter({"count": 2})
Agent -> ResourcesServer: POST /increment_counter {"count": 2}
ResourcesServer -> Agent:
{"success": true}
(counter is now 5 for this session_id)
Agent -> ModelServer: POST /v1/responses
Model calls tool:
function_call: get_counter_value({})
Agent -> ResourcesServer: POST /get_counter_value {}
ResourcesServer -> Agent:
{"count": 5}
[Episode end -> grading]
Agent -> ResourcesServer: POST /verify {"expected_count": 5, ...}
ResourcesServer:
- reads counter for this session_id
- reward = 1.0 if counter == expected_count else 0.0