1
# Copyright (C) 2006, 2007 Canonical Ltd
 
 
3
# This program is free software; you can redistribute it and/or modify
 
 
4
# it under the terms of the GNU General Public License as published by
 
 
5
# the Free Software Foundation; either version 2 of the License, or
 
 
6
# (at your option) any later version.
 
 
8
# This program is distributed in the hope that it will be useful,
 
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
 
11
# GNU General Public License for more details.
 
 
13
# You should have received a copy of the GNU General Public License
 
 
14
# along with this program; if not, write to the Free Software
 
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
 
17
"""Wire-level encoding and decoding of requests and responses for the smart
 
 
22
from cStringIO import StringIO
 
 
24
from bzrlib import errors
 
 
25
from bzrlib.smart import request
 
 
28
# Protocol version strings.  These are sent as prefixes of bzr requests and
 
 
29
# responses to identify the protocol version being used. (There are no version
 
 
30
# one strings because that version doesn't send any).
 
 
31
REQUEST_VERSION_TWO = 'bzr request 2\n'
 
 
32
RESPONSE_VERSION_TWO = 'bzr response 2\n'
 
 
35
def _recv_tuple(from_file):
 
 
36
    req_line = from_file.readline()
 
 
37
    return _decode_tuple(req_line)
 
 
40
def _decode_tuple(req_line):
 
 
41
    if req_line == None or req_line == '':
 
 
43
    if req_line[-1] != '\n':
 
 
44
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
 
45
    return tuple(req_line[:-1].split('\x01'))
 
 
48
def _encode_tuple(args):
 
 
49
    """Encode the tuple args to a bytestream."""
 
 
50
    return '\x01'.join(args) + '\n'
 
 
53
class SmartProtocolBase(object):
 
 
54
    """Methods common to client and server"""
 
 
56
    # TODO: this only actually accomodates a single block; possibly should
 
 
57
    # support multiple chunks?
 
 
58
    def _encode_bulk_data(self, body):
 
 
59
        """Encode body as a bulk data chunk."""
 
 
60
        return ''.join(('%d\n' % len(body), body, 'done\n'))
 
 
62
    def _serialise_offsets(self, offsets):
 
 
63
        """Serialise a readv offset list."""
 
 
65
        for start, length in offsets:
 
 
66
            txt.append('%d,%d' % (start, length))
 
 
70
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
 
71
    """Server-side encoding and decoding logic for smart version 1."""
 
 
73
    def __init__(self, backing_transport, write_func):
 
 
74
        self._backing_transport = backing_transport
 
 
75
        self.excess_buffer = ''
 
 
76
        self._finished = False
 
 
78
        self.has_dispatched = False
 
 
80
        self._body_decoder = None
 
 
81
        self._write_func = write_func
 
 
83
    def accept_bytes(self, bytes):
 
 
84
        """Take bytes, and advance the internal state machine appropriately.
 
 
86
        :param bytes: must be a byte string
 
 
88
        assert isinstance(bytes, str)
 
 
89
        self.in_buffer += bytes
 
 
90
        if not self.has_dispatched:
 
 
91
            if '\n' not in self.in_buffer:
 
 
94
            self.has_dispatched = True
 
 
96
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
 
 
98
                req_args = _decode_tuple(first_line)
 
 
99
                self.request = request.SmartServerRequestHandler(
 
 
100
                    self._backing_transport, commands=request.request_handlers)
 
 
101
                self.request.dispatch_command(req_args[0], req_args[1:])
 
 
102
                if self.request.finished_reading:
 
 
104
                    self.excess_buffer = self.in_buffer
 
 
106
                    self._send_response(self.request.response.args,
 
 
107
                        self.request.response.body)
 
 
108
            except KeyboardInterrupt:
 
 
110
            except Exception, exception:
 
 
111
                # everything else: pass to client, flush, and quit
 
 
112
                self._send_response(('error', str(exception)))
 
 
115
        if self.has_dispatched:
 
 
117
                # nothing to do.XXX: this routine should be a single state 
 
 
119
                self.excess_buffer += self.in_buffer
 
 
122
            if self._body_decoder is None:
 
 
123
                self._body_decoder = LengthPrefixedBodyDecoder()
 
 
124
            self._body_decoder.accept_bytes(self.in_buffer)
 
 
125
            self.in_buffer = self._body_decoder.unused_data
 
 
126
            body_data = self._body_decoder.read_pending_data()
 
 
127
            self.request.accept_body(body_data)
 
 
128
            if self._body_decoder.finished_reading:
 
 
129
                self.request.end_of_body()
 
 
130
                assert self.request.finished_reading, \
 
 
131
                    "no more body, request not finished"
 
 
132
            if self.request.response is not None:
 
 
133
                self._send_response(self.request.response.args,
 
 
134
                    self.request.response.body)
 
 
135
                self.excess_buffer = self.in_buffer
 
 
138
                assert not self.request.finished_reading, \
 
 
139
                    "no response and we have finished reading."
 
 
141
    def _send_response(self, args, body=None):
 
 
142
        """Send a smart server response down the output stream."""
 
 
143
        assert not self._finished, 'response already sent'
 
 
144
        self._finished = True
 
 
145
        self._write_protocol_version()
 
 
146
        self._write_func(_encode_tuple(args))
 
 
148
            assert isinstance(body, str), 'body must be a str'
 
 
149
            bytes = self._encode_bulk_data(body)
 
 
150
            self._write_func(bytes)
 
 
152
    def _write_protocol_version(self):
 
 
153
        """Write any prefixes this protocol requires.
 
 
155
        Version one doesn't send protocol versions.
 
 
158
    def next_read_size(self):
 
 
161
        if self._body_decoder is None:
 
 
164
            return self._body_decoder.next_read_size()
 
 
167
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
 
 
168
    r"""Version two of the server side of the smart protocol.
 
 
170
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
 
 
173
    def _write_protocol_version(self):
 
 
174
        r"""Write any prefixes this protocol requires.
 
 
176
        Version two sends the value of RESPONSE_VERSION_TWO.
 
 
178
        self._write_func(RESPONSE_VERSION_TWO)
 
 
181
class LengthPrefixedBodyDecoder(object):
 
 
182
    """Decodes the length-prefixed bulk data."""
 
 
185
        self.bytes_left = None
 
 
186
        self.finished_reading = False
 
 
187
        self.unused_data = ''
 
 
188
        self.state_accept = self._state_accept_expecting_length
 
 
189
        self.state_read = self._state_read_no_data
 
 
191
        self._trailer_buffer = ''
 
 
193
    def accept_bytes(self, bytes):
 
 
194
        """Decode as much of bytes as possible.
 
 
196
        If 'bytes' contains too much data it will be appended to
 
 
199
        finished_reading will be set when no more data is required.  Further
 
 
200
        data will be appended to self.unused_data.
 
 
202
        # accept_bytes is allowed to change the state
 
 
203
        current_state = self.state_accept
 
 
204
        self.state_accept(bytes)
 
 
205
        while current_state != self.state_accept:
 
 
206
            current_state = self.state_accept
 
 
207
            self.state_accept('')
 
 
209
    def next_read_size(self):
 
 
210
        if self.bytes_left is not None:
 
 
211
            # Ideally we want to read all the remainder of the body and the
 
 
213
            return self.bytes_left + 5
 
 
214
        elif self.state_accept == self._state_accept_reading_trailer:
 
 
215
            # Just the trailer left
 
 
216
            return 5 - len(self._trailer_buffer)
 
 
217
        elif self.state_accept == self._state_accept_expecting_length:
 
 
218
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
 
222
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
 
225
    def read_pending_data(self):
 
 
226
        """Return any pending data that has been decoded."""
 
 
227
        return self.state_read()
 
 
229
    def _state_accept_expecting_length(self, bytes):
 
 
230
        self._in_buffer += bytes
 
 
231
        pos = self._in_buffer.find('\n')
 
 
234
        self.bytes_left = int(self._in_buffer[:pos])
 
 
235
        self._in_buffer = self._in_buffer[pos+1:]
 
 
236
        self.bytes_left -= len(self._in_buffer)
 
 
237
        self.state_accept = self._state_accept_reading_body
 
 
238
        self.state_read = self._state_read_in_buffer
 
 
240
    def _state_accept_reading_body(self, bytes):
 
 
241
        self._in_buffer += bytes
 
 
242
        self.bytes_left -= len(bytes)
 
 
243
        if self.bytes_left <= 0:
 
 
245
            if self.bytes_left != 0:
 
 
246
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
 
 
247
                self._in_buffer = self._in_buffer[:self.bytes_left]
 
 
248
            self.bytes_left = None
 
 
249
            self.state_accept = self._state_accept_reading_trailer
 
 
251
    def _state_accept_reading_trailer(self, bytes):
 
 
252
        self._trailer_buffer += bytes
 
 
253
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
 
254
        # a ProtocolViolation exception?
 
 
255
        if self._trailer_buffer.startswith('done\n'):
 
 
256
            self.unused_data = self._trailer_buffer[len('done\n'):]
 
 
257
            self.state_accept = self._state_accept_reading_unused
 
 
258
            self.finished_reading = True
 
 
260
    def _state_accept_reading_unused(self, bytes):
 
 
261
        self.unused_data += bytes
 
 
263
    def _state_read_no_data(self):
 
 
266
    def _state_read_in_buffer(self):
 
 
267
        result = self._in_buffer
 
 
272
class SmartClientRequestProtocolOne(SmartProtocolBase):
 
 
273
    """The client-side protocol for smart version 1."""
 
 
275
    def __init__(self, request):
 
 
276
        """Construct a SmartClientRequestProtocolOne.
 
 
278
        :param request: A SmartClientMediumRequest to serialise onto and
 
 
281
        self._request = request
 
 
282
        self._body_buffer = None
 
 
284
    def call(self, *args):
 
 
285
        self._write_args(args)
 
 
286
        self._request.finished_writing()
 
 
288
    def call_with_body_bytes(self, args, body):
 
 
289
        """Make a remote call of args with body bytes 'body'.
 
 
291
        After calling this, call read_response_tuple to find the result out.
 
 
293
        self._write_args(args)
 
 
294
        bytes = self._encode_bulk_data(body)
 
 
295
        self._request.accept_bytes(bytes)
 
 
296
        self._request.finished_writing()
 
 
298
    def call_with_body_readv_array(self, args, body):
 
 
299
        """Make a remote call with a readv array.
 
 
301
        The body is encoded with one line per readv offset pair. The numbers in
 
 
302
        each pair are separated by a comma, and no trailing \n is emitted.
 
 
304
        self._write_args(args)
 
 
305
        readv_bytes = self._serialise_offsets(body)
 
 
306
        bytes = self._encode_bulk_data(readv_bytes)
 
 
307
        self._request.accept_bytes(bytes)
 
 
308
        self._request.finished_writing()
 
 
310
    def cancel_read_body(self):
 
 
311
        """After expecting a body, a response code may indicate one otherwise.
 
 
313
        This method lets the domain client inform the protocol that no body
 
 
314
        will be transmitted. This is a terminal method: after calling it the
 
 
315
        protocol is not able to be used further.
 
 
317
        self._request.finished_reading()
 
 
319
    def read_response_tuple(self, expect_body=False):
 
 
320
        """Read a response tuple from the wire.
 
 
322
        This should only be called once.
 
 
324
        result = self._recv_tuple()
 
 
326
            self._request.finished_reading()
 
 
329
    def read_body_bytes(self, count=-1):
 
 
330
        """Read bytes from the body, decoding into a byte stream.
 
 
332
        We read all bytes at once to ensure we've checked the trailer for 
 
 
333
        errors, and then feed the buffer back as read_body_bytes is called.
 
 
335
        if self._body_buffer is not None:
 
 
336
            return self._body_buffer.read(count)
 
 
337
        _body_decoder = LengthPrefixedBodyDecoder()
 
 
339
        while not _body_decoder.finished_reading:
 
 
340
            bytes_wanted = _body_decoder.next_read_size()
 
 
341
            bytes = self._request.read_bytes(bytes_wanted)
 
 
342
            _body_decoder.accept_bytes(bytes)
 
 
343
        self._request.finished_reading()
 
 
344
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
 
 
345
        # XXX: TODO check the trailer result.
 
 
346
        return self._body_buffer.read(count)
 
 
348
    def _recv_tuple(self):
 
 
349
        """Receive a tuple from the medium request."""
 
 
351
        while not line or line[-1] != '\n':
 
 
352
            # TODO: this is inefficient - but tuples are short.
 
 
353
            new_char = self._request.read_bytes(1)
 
 
355
            assert new_char != '', "end of file reading from server."
 
 
356
        return _decode_tuple(line)
 
 
358
    def query_version(self):
 
 
359
        """Return protocol version number of the server."""
 
 
361
        resp = self.read_response_tuple()
 
 
362
        if resp == ('ok', '1'):
 
 
364
        elif resp == ('ok', '2'):
 
 
367
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
 
369
    def _write_args(self, args):
 
 
370
        self._write_protocol_version()
 
 
371
        bytes = _encode_tuple(args)
 
 
372
        self._request.accept_bytes(bytes)
 
 
374
    def _write_protocol_version(self):
 
 
375
        """Write any prefixes this protocol requires.
 
 
377
        Version one doesn't send protocol versions.
 
 
381
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
 
 
382
    """Version two of the client side of the smart protocol.
 
 
384
    This prefixes the request with the value of REQUEST_VERSION_TWO.
 
 
387
    def read_response_tuple(self, expect_body=False):
 
 
388
        """Read a response tuple from the wire.
 
 
390
        This should only be called once.
 
 
392
        version = self._request.read_line()
 
 
393
        if version != RESPONSE_VERSION_TWO:
 
 
394
            raise errors.SmartProtocolError('bad protocol marker %r' % version)
 
 
395
        return SmartClientRequestProtocolOne.read_response_tuple(self, expect_body)
 
 
397
    def _write_protocol_version(self):
 
 
398
        r"""Write any prefixes this protocol requires.
 
 
400
        Version two sends the value of REQUEST_VERSION_TWO.
 
 
402
        self._request.accept_bytes(REQUEST_VERSION_TWO)