# Source code for aiogremlin.driver.protocol
import asyncio
import base64
import collections
import logging
try:
import ujson as json
except ImportError:
import json
from gremlin_python.driver import protocol, request, serializer
__author__ = 'David M. Brown (davebshow@gmail.com)'

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Record built for each response item handed to a result set: the response
# status code, the item payload, and the status message text.
Message = collections.namedtuple(
    typename="Message",
    field_names=("status_code", "data", "message"),
)
class GremlinServerWSProtocol(protocol.AbstractBaseProtocol):
    """Implementation of the Gremlin Server Websocket protocol.

    Serializes outgoing requests onto the transport and dispatches each
    incoming response to the result set registered under its request id.
    """

    def __init__(self, message_serializer, username='', password=''):
        """Configure the serializer and the credentials used for auth.

        :param message_serializer: serializer instance, or a serializer
            class, which is instantiated with no arguments.
        :param str username: user name sent when the server replies 407.
        :param str password: password sent when the server replies 407.
        """
        if isinstance(message_serializer, type):
            message_serializer = message_serializer()
        self._message_serializer = message_serializer
        self._username = username
        self._password = password
        # Declared up front so the attribute always exists; the real
        # transport is supplied later through connection_made().
        self._transport = None

    def connection_made(self, transport):
        """Remember the transport that write() will send messages on."""
        self._transport = transport

    async def write(self, request_id, request_message):
        """Serialize a request and send it over the transport.

        :param request_id: identifier passed to the serializer together
            with the request.
        :param request_message: the request to serialize and send.
        """
        serialized = self._message_serializer.serialize_message(
            request_id, request_message)
        pending = self._transport.write(serialized)
        # The transport's write() may be synchronous or a coroutine
        # function; only await when it actually produced a coroutine.
        if asyncio.iscoroutine(pending):
            await pending

    def _authentication_request(self):
        """Build the 'authentication' request carrying the credentials."""
        # NUL-delimited "\x00<username>\x00<password>", base64 encoded;
        # this matches the layout of the SASL PLAIN mechanism.
        credentials = b''.join([
            b'\x00', self._username.encode('utf-8'),
            b'\x00', self._password.encode('utf-8')])
        return request.RequestMessage(
            'traversal', 'authentication',
            {'sasl': base64.b64encode(credentials).decode()})

    async def data_received(self, data, results_dict):
        """Decode one server response and route it to its result set.

        :param bytes data: UTF-8 encoded JSON response frame.
        :param results_dict: mapping of request id to the result set
            awaiting that request's responses. Each result set must
            expose ``aggregate_to`` and ``queue_result()``.

        Responses whose request id is not present in ``results_dict``
        are discarded (and logged at debug level).
        """
        payload = json.loads(data.decode('utf-8'))
        message = self._message_serializer.deserialize_message(payload)
        request_id = message['requestId']
        status_code = message['status']['code']
        result_data = message['result']['data']
        status_message = message['status']['message']

        if request_id not in results_dict:
            logger.debug(
                "Discarding response (status %s) for unknown request id %s",
                status_code, request_id)
            return

        result_set = results_dict[request_id]
        result_set.aggregate_to = message['result']['meta'].get(
            'aggregateTo', 'list')

        if status_code == 407:
            # Server demands authentication: answer on the same request
            # id and leave the result set open for the real response.
            await self.write(request_id, self._authentication_request())
        elif status_code == 204:
            # No content: nothing to deliver. NOTE(review): None looks
            # like the end-of-stream marker for the result set - confirm.
            result_set.queue_result(None)
        else:
            if result_data:
                for item in result_data:
                    item = self._message_serializer.deserialize_message(item)
                    result_set.queue_result(
                        Message(status_code, item, status_message))
            else:
                # Empty/absent payload: still forward the status code
                # and message so the consumer can inspect them.
                result_set.queue_result(
                    Message(status_code, result_data, status_message))
            # 206 is a partial response, so more frames will follow for
            # this request; any other status closes the stream.
            if status_code != 206:
                result_set.queue_result(None)