Skip to content
64 changes: 43 additions & 21 deletions dapr/ext/grpc/_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,30 @@ def __init__(self):
self._registered_topics: List[appcallback_v1.TopicSubscription] = []
self._registered_bindings: List[str] = []

self._route_map: Dict[Tuple[str, str], TopicSubscribeCallable] = {}
self._validation_disabled_pubsubs: Dict[str, List[TopicSubscribeCallable]] = {}

def _get_topic_callback(
self, pubsub_name: str, topic: str, path: str
) -> Optional[TopicSubscribeCallable]:
pubsub_topic = pubsub_name + DELIMITER + topic + DELIMITER + path
if pubsub_topic in self._topic_map:
return self._topic_map[pubsub_topic]

if (pubsub_name, path) in self._route_map:
return self._route_map[(pubsub_name, path)]

if path == '':
if (pubsub_name, topic) in self._route_map:
return self._route_map[(pubsub_name, topic)]

if pubsub_name in self._validation_disabled_pubsubs:
callbacks = self._validation_disabled_pubsubs[pubsub_name]
if len(callbacks) == 1:
return callbacks[0]

Comment on lines +94 to +105
return None

def register_method(self, method: str, cb: InvokeMethodCallable) -> None:
"""Registers method for service invocation."""
if method in self._invoke_method_map:
Expand All @@ -98,17 +122,21 @@ def register_topic(
disable_topic_validation: Optional[bool] = False,
) -> None:
"""Registers topic subscription for pubsub."""
if not disable_topic_validation:
topic_key = pubsub_name + DELIMITER + topic
else:
topic_key = pubsub_name
topic_key = pubsub_name + DELIMITER + topic
pubsub_topic = topic_key + DELIMITER
if rule is not None:
path = getattr(cb, '__name__', rule.match)
pubsub_topic = pubsub_topic + path
if pubsub_topic in self._topic_map:
raise ValueError(f'{topic} is already registered with {pubsub_name}')
self._topic_map[pubsub_topic] = cb
routing_path = path if rule is not None else topic
self._route_map[(pubsub_name, routing_path)] = cb

if disable_topic_validation:
if pubsub_name not in self._validation_disabled_pubsubs:
self._validation_disabled_pubsubs[pubsub_name] = []
self._validation_disabled_pubsubs[pubsub_name].append(cb)

registered_topic = self._registered_topics_map.get(topic_key)
sub: appcallback_v1.TopicSubscription = appcallback_v1.TopicSubscription()
Expand All @@ -122,6 +150,10 @@ def register_topic(
)
if dead_letter_topic:
sub.dead_letter_topic = dead_letter_topic

if disable_topic_validation and rule is None:
sub.routes.default = topic

registered_topic = _RegisteredSubscription(sub, rules)
self._registered_topics_map[topic_key] = registered_topic
self._registered_topics.append(sub)
Expand Down Expand Up @@ -196,15 +228,10 @@ def ListTopicSubscriptions(self, request, context):

def OnTopicEvent(self, request: TopicEventRequest, context):
"""Subscribes events from Pubsub."""
pubsub_topic = request.pubsub_name + DELIMITER + request.topic + DELIMITER + request.path
no_validation_key = request.pubsub_name + DELIMITER + request.path

if pubsub_topic not in self._topic_map:
if no_validation_key in self._topic_map:
pubsub_topic = no_validation_key
else:
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
raise NotImplementedError(f'topic {request.topic} is not implemented!')
cb = self._get_topic_callback(request.pubsub_name, request.topic, request.path)
if cb is None:
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
raise NotImplementedError(f'topic {request.topic} is not implemented!')

customdata: Struct = request.extensions
extensions = dict()
Expand All @@ -222,7 +249,7 @@ def OnTopicEvent(self, request: TopicEventRequest, context):
event.SetSubject(request.topic)
event.SetExtensions(extensions)

response = self._topic_map[pubsub_topic](event)
response = cb(event)
if isinstance(response, TopicEventResponse):
return appcallback_v1.TopicEventResponse(status=response.status.value)
return empty_pb2.Empty()
Expand Down Expand Up @@ -292,15 +319,10 @@ def _handle_bulk_topic_event(
self, request: TopicEventBulkRequest, context
) -> Optional[TopicEventBulkResponse]:
"""Process bulk topic event request - routes each entry to the appropriate topic handler."""
topic_key = request.pubsub_name + DELIMITER + request.topic + DELIMITER + request.path
no_validation_key = request.pubsub_name + DELIMITER + request.path

if topic_key not in self._topic_map and no_validation_key not in self._topic_map:
cb = self._get_topic_callback(request.pubsub_name, request.topic, request.path)
if cb is None:
return None # we don't have a handler

handler_key = topic_key if topic_key in self._topic_map else no_validation_key
cb = self._topic_map[handler_key] # callback

statuses = []
for entry in request.entries:
entry_id = entry.entry_id
Expand Down
34 changes: 34 additions & 0 deletions tests/ext/grpc/test_servicier.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,40 @@ def test_non_registered_topic(self):
self.fake_context,
)

def test_multiple_wildcard_subscriptions(self):
self._servicer.register_topic(
'pubsub_multi_wildcard',
'orders/+/items',
self._topic1_method,
None,
disable_topic_validation=True,
)
self._servicer.register_topic(
'pubsub_multi_wildcard',
'inventory/#',
self._topic2_method,
None,
disable_topic_validation=True,
)

self._servicer.OnTopicEvent(
appcallback_v1.TopicEventRequest(
pubsub_name='pubsub_multi_wildcard', topic='orders/123/items', path='orders/+/items'
),
self.fake_context,
)
self._topic1_method.assert_called_once()

self._servicer.OnTopicEvent(
appcallback_v1.TopicEventRequest(
pubsub_name='pubsub_multi_wildcard',
topic='inventory/warehouse/aisle4',
path='inventory/#',
),
self.fake_context,
)
self._topic2_method.assert_called_once()


class BulkTopicEventTests(unittest.TestCase):
def setUp(self):
Expand Down
Loading