# Source code for nv_ingest_api.util.message_brokers.simple_message_broker.ordered_message_queue

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import threading
import heapq


class OrderedMessageQueue:
    """Thread-safe FIFO queue with in-flight tracking and ordered redelivery.

    Every pushed message is tagged with a monotonically increasing index and
    kept in a min-heap, so ``pop`` hands messages out in push order. A popped
    message is held in ``in_flight`` under a caller-supplied transaction id
    until it is either acknowledged (dropped for good) or returned (re-queued
    at its original position, ahead of anything pushed later).

    ``maxsize`` counts queued *and* in-flight messages; ``0`` means unbounded.
    The limit is advisory: ``push`` never blocks or rejects, so callers are
    expected to consult ``can_push()`` / ``full()`` before pushing.
    """

    def __init__(self, maxsize=0):
        self.queue = []  # min-heap of (index, message) tuples
        self.maxsize = maxsize
        self.next_index = 0  # monotonically increasing message index
        self.in_flight = {}  # transaction_id -> (index, message)
        self.lock = threading.Lock()
        # Both conditions share the single lock so state checks and waits are atomic.
        self.not_empty = threading.Condition(self.lock)
        self.not_full = threading.Condition(self.lock)

    def can_push(self):
        """Return True if the queue can accept another message.

        Capacity counts both queued and in-flight messages; a ``maxsize`` of 0
        means there is no limit.
        """
        with self.lock:
            return self.maxsize == 0 or (len(self.queue) + len(self.in_flight)) < self.maxsize

    def push(self, message):
        """Append ``message`` to the queue and return its sequence index.

        The capacity limit is not enforced here; see ``can_push``.
        """
        with self.lock:
            index = self.next_index
            self.next_index += 1
            # Indices are unique, so heap ordering never falls through to
            # comparing the message objects themselves.
            heapq.heappush(self.queue, (index, message))
            self.not_empty.notify()
            return index

    def pop(self, transaction_id):
        """Block until a message is available, then remove it and mark it in-flight.

        The message is recorded under ``transaction_id`` so it can later be
        passed to ``acknowledge`` or ``return_message``. Returns the message
        with the lowest index (i.e. the oldest one still queued).
        """
        with self.lock:
            while not self.queue:
                self.not_empty.wait()
            index, message = heapq.heappop(self.queue)
            # NOTE(review): a transaction_id that is already in flight is
            # overwritten here, leaving the earlier message unrecoverable.
            # Callers presumably use unique ids — confirm.
            self.in_flight[transaction_id] = (index, message)
            # No not_full notification: the message merely moved from the queue
            # to in_flight, so occupancy as counted by can_push()/full() is unchanged.
            return message

    def acknowledge(self, transaction_id):
        """Mark the in-flight message for ``transaction_id`` as processed and drop it.

        Unknown transaction ids are ignored.
        """
        with self.lock:
            if self.in_flight.pop(transaction_id, None) is not None:
                # Removing an in-flight entry is what actually frees capacity,
                # so this is where a waiter on not_full must be woken.
                self.not_full.notify()

    def return_message(self, transaction_id):
        """Put an unacknowledged in-flight message back into the queue.

        The message keeps its original index, so it is redelivered before any
        message pushed after it. Unknown transaction ids are ignored.
        """
        with self.lock:
            if transaction_id in self.in_flight:
                index, message = self.in_flight.pop(transaction_id)
                heapq.heappush(self.queue, (index, message))
                self.not_empty.notify()

    def qsize(self):
        """Return the number of queued messages (in-flight messages are not counted)."""
        with self.lock:
            return len(self.queue)

    def empty(self):
        """Return True if no messages are queued (in-flight messages are ignored)."""
        with self.lock:
            return not self.queue

    def full(self):
        """Return True if queued plus in-flight messages have reached ``maxsize``.

        Always False when ``maxsize`` is 0 (unbounded).
        """
        with self.lock:
            return self.maxsize > 0 and (len(self.queue) + len(self.in_flight)) >= self.maxsize