Source code for thriftpool.rpc.transport

from __future__ import absolute_import

import logging
import cPickle as pickle
from uuid import UUID, uuid4
from struct import Struct
from cStringIO import StringIO

from gaffer.events import EventEmitter

from thriftworker.utils.decorators import cached_property
from thriftworker.utils.loop import in_loop

logger = logging.getLogger(__name__)


class Proto(object):
    """Mixin that produces the *encode* and *decode* helpers.

    Each packet is represented by a simple structure:

    +----------+--------------------+-----------------+
    | Length   | Request id (UUID)  | Pickled object  |
    +==========+====================+=================+
    | 4 bytes  | 32 bytes           | undefined       |
    +----------+--------------------+-----------------+

    *Length* counts the bytes that follow the length field, i.e. the
    request id plus the pickled object. The request id travels as the
    32 hexadecimal characters of the UUID.

    Validation failures raise :class:`AssertionError`, as the original
    ``assert`` statements did, but the checks are explicit so they are
    still enforced when Python runs with ``-O``.
    """

    @cached_property
    def _length_struct(self):
        """Return the struct used to pack and unpack the length prefix."""
        # '=' selects standard sizes with native byte order, so the prefix
        # is exactly 4 bytes as the packet layout above requires.
        return Struct('=I')

    def _prepend_length(self, data):
        """Return *data* prefixed with its packed length."""
        return self._length_struct.pack(len(data)) + data

    def _encode(self, request_id, obj):
        """Convert the given request to a packet.

        :param request_id: :class:`uuid.UUID` identifying the request.
        :param obj: any picklable object to send.
        :returns: length prefix + request id hex + pickled *obj*.
        :raises AssertionError: if *request_id* is not a UUID.
        """
        if not isinstance(request_id, UUID):
            raise AssertionError('wrong request id given')
        # The hex form is pure ASCII; encoding it keeps the whole packet a
        # byte string regardless of how the interpreter represents text.
        encoded_id = request_id.hex.encode('ascii')
        return self._prepend_length(encoded_id + pickle.dumps(obj))

    def _split_length(self, data):
        """Return a tuple of (payload length, remaining raw data).

        :raises AssertionError: if *data* is shorter than the length
            prefix or the decoded length is zero.
        """
        prefix_size = self._length_struct.size
        if len(data) < prefix_size:
            raise AssertionError('given buffer to small (length)')
        encoded_length, left_data = data[:prefix_size], data[prefix_size:]
        length = self._length_struct.unpack(encoded_length)[0]
        if length <= 0:
            raise AssertionError('wrong length provided')
        return (length, left_data)

    def _split_request_id(self, data):
        """Return a tuple of (request id as UUID, remaining raw data).

        :raises AssertionError: if *data* holds fewer than 32 bytes.
        """
        if len(data) < 32:
            raise AssertionError('given buffer to small (request_id)')
        request_id, left_data = data[:32], data[32:]
        return (UUID(hex=request_id.decode('ascii'), version=4), left_data)

    def _decode_body(self, data):
        """Unpickle the given data and return the resulting object."""
        # SECURITY: unpickling can execute arbitrary code. This is only
        # safe when the peer writing to the channel is trusted.
        return pickle.loads(data)
class Receiver(Proto):
    """Receive requests from a channel and publish them on the emitter.

    Incoming chunks are accumulated until a complete packet (length
    prefix followed by that many payload bytes) is available. Each
    complete packet is split into request id and unpickled object and
    published as a ``"received"`` event.

    Chunk boundaries are arbitrary: a chunk may hold part of a length
    prefix, part of a payload, or several packets at once.
    """

    WAIT_LENGTH = 1
    WAIT_PAYLOAD = 2

    # Size of the length prefix; matches what ``_split_length`` consumes.
    _LENGTH_SIZE = 4

    def __init__(self, stream, emitter):
        self.stream = stream
        self.emitter = emitter
        # Bytes read from the stream that have not been consumed yet.
        self._buf = bytearray()
        self._state = self.WAIT_LENGTH
        # Payload size of the packet being assembled (0 while waiting for
        # the length prefix).
        self._length = 0
        super(Receiver, self).__init__()

    def _on_read(self, evtype, info):
        """Consume a chunk of stream data and dispatch complete packets.

        :param evtype: event type passed by the stream (unused).
        :param info: mapping whose ``'data'`` item holds the bytes read.
        """
        self._buf += info['data']

        # Loop rather than recurse so that many small packets in a single
        # chunk do not grow the call stack.
        while True:
            if self._state == self.WAIT_LENGTH:
                if len(self._buf) < self._LENGTH_SIZE:
                    # Length prefix is still incomplete; wait for more.
                    return
                header = bytes(self._buf[:self._LENGTH_SIZE])
                del self._buf[:self._LENGTH_SIZE]
                self._length, _ = self._split_length(header)
                self._state = self.WAIT_PAYLOAD

            if len(self._buf) < self._length:
                # Payload is still incomplete; wait for more.
                return

            payload = bytes(self._buf[:self._length])
            del self._buf[:self._length]
            # Reset framing state *before* dispatching, so an exception
            # raised while decoding or publishing leaves the receiver
            # positioned at the start of the next packet.
            self._length = 0
            self._state = self.WAIT_LENGTH
            self._on_received(payload)

    def _on_received(self, data):
        """Decode one complete payload and publish it as ``"received"``."""
        request_id, data = self._split_request_id(data)
        obj = self._decode_body(data)
        self.emitter.publish("received", request_id=request_id, obj=obj)

    def start(self):
        """Subscribe to the stream so reads reach :meth:`_on_read`."""
        self.stream.subscribe(self._on_read)

    def stop(self):
        """Unsubscribe :meth:`_on_read` from the stream."""
        self.stream.unsubscribe(self._on_read)
class Transmitter(Proto):
    """Write requests to a channel and route replies to callbacks."""

    def __init__(self, stream, emitter):
        self.stream = stream
        self.emitter = emitter
        # Maps request id -> callback awaiting the matching reply.
        self._requests = {}
        super(Transmitter, self).__init__()

    def write(self, obj, callback=None, request_id=None):
        """Encode *obj* and write it to the stream.

        :param obj: object to send.
        :param callback: optional callable invoked with the reply object
            once a ``'received'`` event carries the same request id.
        :param request_id: id to tag the packet with; a fresh ``uuid4``
            is generated when none (or a falsy value) is given.
        """
        if not request_id:
            request_id = uuid4()
        if callback is not None:
            self._requests[request_id] = callback
        packet = self._encode(request_id, obj)
        self.stream.write(packet)

    def _on_received(self, evtype, request_id, obj):
        """Invoke and forget the callback registered for *request_id*."""
        pending = self._requests.pop(request_id, None)
        if pending is None:
            return
        try:
            pending(obj)
        except Exception as exc:
            # A failing callback is logged, not propagated.
            logger.exception(exc)

    def start(self):
        """Start listening for ``'received'`` events on the emitter."""
        self.emitter.subscribe('received', self._on_received)

    def stop(self):
        """Stop listening for ``'received'`` events on the emitter."""
        self.emitter.unsubscribe('received', self._on_received)
class Transport(object):
    """Exchange messages on a channel.

    Couples a receiver on the *incoming* stream and a transmitter on the
    *outgoing* stream through one shared event emitter.
    """

    # Component classes, overridable by subclasses.
    Receiver = Receiver
    Transmitter = Transmitter

    def __init__(self, loop, incoming, outgoing):
        self.loop = loop
        emitter = EventEmitter(loop)
        self._emitter = emitter
        self._receiver = self.Receiver(incoming, emitter)
        self._transmitter = self.Transmitter(outgoing, emitter)

    def write(self, obj, callback=None, request_id=None):
        """Send *obj* through the transmitter."""
        transmitter = self._transmitter
        transmitter.write(obj, callback, request_id)

    def subscribe(self, listener):
        """Register *listener* for ``'received'`` events."""
        self._emitter.subscribe('received', listener)

    def unsubscribe(self, listener):
        """Remove *listener* from ``'received'`` events."""
        self._emitter.unsubscribe('received', listener)

    def start(self):
        """Start the receiver, then the transmitter."""
        for component in (self._receiver, self._transmitter):
            component.start()

    def stop(self):
        """Stop the transmitter, then the receiver, then close the emitter."""
        for component in (self._transmitter, self._receiver):
            component.stop()
        self._emitter.close()
class Consumer(Transport):
    """Pull commands from the channel, execute them and return the result."""

    def __init__(self, loop, incoming, outgoing, handler):
        self.handler = handler
        super(Consumer, self).__init__(loop, incoming, outgoing)

    def _on_incoming(self, evtype, request_id, obj):
        """Run the requested handler method and write back its outcome.

        *obj* is unpacked as ``(method_name, args, kwargs)``. If looking up
        or calling the method raises, the exception is logged and the
        exception instance itself is written back as the result.
        """
        method_name, args, kwargs = obj
        template = 'Execute method {0!r} with args {1!r} and kwargs {2!r}'
        logger.debug(template.format(method_name, args, kwargs))
        try:
            method = getattr(self.handler, method_name)
            result = method(*args, **kwargs)
        except Exception as exc:
            logger.exception(exc)
            result = exc
        # Reuse the incoming request id for the reply.
        self.write(result, request_id=request_id)

    def start(self):
        """Start the transport, then begin handling incoming commands."""
        super(Consumer, self).start()
        self.subscribe(self._on_incoming)

    def stop(self):
        """Stop handling incoming commands, then stop the transport."""
        self.unsubscribe(self._on_incoming)
        super(Consumer, self).stop()
class Producer(Transport):
    """Push commands to a consumer."""

    def __init__(self, loop, incoming, outgoing, process):
        super(Producer, self).__init__(loop, incoming, outgoing)
        self.process = process

    @in_loop
    def apply(self, method_name, callback=None, args=None, kwargs=None):
        """Enqueue a new remote procedure call.

        :param method_name: name of the remote method; sent as ``str``.
        :param callback: optional callable passed on to :meth:`write`.
        :param args: positional arguments; an empty list when omitted.
        :param kwargs: keyword arguments; an empty dict when omitted.
        """
        # NOTE(review): ``in_loop`` presumably defers this call onto the
        # event loop -- confirm against thriftworker.utils.loop.
        assert self.process.active, 'process not active'
        call = (str(method_name), args or [], kwargs or {})
        self.write(call, callback)

Project Versions

This Page