# Copyright (C) 2006, 2007 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

"""Wire-level encoding and decoding of requests and responses for the smart
client and server.
"""
22
from cStringIO import StringIO
24
from bzrlib import errors
25
from bzrlib.smart import request
28
def _recv_tuple(from_file):
    """Read a single line from from_file and decode it as a request tuple.

    :param from_file: a file-like object providing readline().
    :return: the result of passing that line to _decode_tuple.
    """
    return _decode_tuple(from_file.readline())


def _decode_tuple(req_line):
34
if req_line == None or req_line == '':
36
if req_line[-1] != '\n':
37
raise errors.SmartProtocolError("request %r not terminated" % req_line)
38
return tuple(req_line[:-1].split('\x01'))
41
def _encode_tuple(args):
42
"""Encode the tuple args to a bytestream."""
43
return '\x01'.join(args) + '\n'
46
class SmartProtocolBase(object):
    """Methods common to client and server"""

    # TODO: this only actually accommodates a single block; possibly should
    # support multiple chunks?
    def _encode_bulk_data(self, body):
        r"""Encode body as a bulk data chunk.

        The chunk is the decimal length of body followed by '\n', then body
        itself, then the trailer 'done\n'.
        """
        return ''.join(('%d\n' % len(body), body, 'done\n'))

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list.

        :param offsets: an iterable of (start, length) integer pairs.
        :return: one 'start,length' line per pair, joined by newlines, with
            no trailing newline.
        """
        return '\n'.join(['%d,%d' % (start, length)
                          for start, length in offsets])


class SmartServerRequestProtocolOne(SmartProtocolBase):
    """Server-side encoding and decoding logic for smart version 1."""

    def __init__(self, backing_transport, write_func):
        self._backing_transport = backing_transport
        # Bytes received after the request (and any body) was fully handled.
        self.excess_buffer = ''
        # Set by _send_response; a second response is rejected.
        self._finished = False
        # Bytes received but not yet consumed.
        self.in_buffer = ''
        self.has_dispatched = False
        self.request = None
        # Created on demand once the request line has been dispatched and the
        # request still wants body bytes.
        self._body_decoder = None
        self._write_func = write_func

    def accept_bytes(self, bytes):
        """Take bytes, and advance the internal state machine appropriately.

        The first newline-terminated line is decoded as the request tuple and
        dispatched. Later bytes are fed through a LengthPrefixedBodyDecoder
        to the request as its body. Once a response has been sent, further
        bytes are accumulated in self.excess_buffer.

        :param bytes: must be a byte string
        """
        assert isinstance(bytes, str)
        self.in_buffer += bytes
        if not self.has_dispatched:
            if '\n' not in self.in_buffer:
                # The request line is incomplete; wait for more bytes.
                return
            self.has_dispatched = True
            try:
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
                # split() dropped the terminator, which _decode_tuple requires.
                first_line += '\n'
                req_args = _decode_tuple(first_line)
                self.request = request.SmartServerRequestHandler(
                    self._backing_transport, commands=request.request_handlers)
                self.request.dispatch_command(req_args[0], req_args[1:])
                if self.request.finished_reading:
                    # Trivial request: no body is expected, so any remaining
                    # bytes are kept aside as excess.
                    self.excess_buffer = self.in_buffer
                    self.in_buffer = ''
                    self._send_response(self.request.response.args,
                        self.request.response.body)
            except KeyboardInterrupt:
                raise
            except Exception:
                # everything else: pass to client, flush, and quit.
                # sys.exc_info() avoids "except X, e" (Python 2 only) and
                # "except X as e" (Python 2.6+), so this parses everywhere.
                import sys
                exception = sys.exc_info()[1]
                self._send_response(('error', str(exception)))
                return

        # The request line has been dispatched by this point.
        if self._finished:
            # nothing to do. XXX: this routine should be a single state
            # machine too.
            self.excess_buffer += self.in_buffer
            self.in_buffer = ''
            return
        if self._body_decoder is None:
            self._body_decoder = LengthPrefixedBodyDecoder()
        self._body_decoder.accept_bytes(self.in_buffer)
        self.in_buffer = self._body_decoder.unused_data
        body_data = self._body_decoder.read_pending_data()
        self.request.accept_body(body_data)
        if self._body_decoder.finished_reading:
            self.request.end_of_body()
            assert self.request.finished_reading, \
                "no more body, request not finished"
        if self.request.response is not None:
            self._send_response(self.request.response.args,
                self.request.response.body)
            self.excess_buffer = self.in_buffer
            self.in_buffer = ''
        else:
            assert not self.request.finished_reading, \
                "no response and we have finished reading."

    def _send_response(self, args, body=None):
        """Send a smart server response down the output stream.

        :param args: the response tuple, written as one encoded line.
        :param body: optional string written after the tuple as a
            length-prefixed bulk data chunk.
        """
        assert not self._finished, 'response already sent'
        self._finished = True
        self._write_protocol_version()
        self._write_func(_encode_tuple(args))
        if body is not None:
            assert isinstance(body, str), 'body must be a str'
            self._write_func(self._encode_bulk_data(body))

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version one doesn't send protocol versions.
        """

    def next_read_size(self):
        """Return how many bytes the caller should read next.

        0 once a response has been sent, 1 before body decoding has started,
        otherwise the body decoder's own hint.
        """
        if self._finished:
            return 0
        if self._body_decoder is None:
            return 1
        else:
            return self._body_decoder.next_read_size()


class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
    r"""Version two of the server side of the smart protocol.

    This prefixes responses with the protocol version: "2\n".
    """

    def _write_protocol_version(self):
        r"""Write any prefixes this protocol requires.

        Version two sends "2\n".
        """
        self._write_func('2\n')


class LengthPrefixedBodyDecoder(object):
    r"""Decodes the length-prefixed bulk data.

    The input is a decimal length terminated by '\n', that many body bytes,
    then the trailer 'done\n'. Decoding is a small state machine:
    state_accept handles incoming bytes and state_read serves
    read_pending_data; the accept handlers switch states as they progress.
    """

    def __init__(self):
        # Body bytes still expected; None when not in the body state.
        self.bytes_left = None
        self.finished_reading = False
        # Bytes received after the trailer.
        self.unused_data = ''
        self.state_accept = self._state_accept_expecting_length
        self.state_read = self._state_read_no_data
        # Holds the length prefix while it is being read, then decoded body
        # bytes not yet returned by read_pending_data.
        self._in_buffer = ''
        self._trailer_buffer = ''

    def accept_bytes(self, bytes):
        """Decode as much of bytes as possible.

        If 'bytes' contains too much data it will be appended to
        self.unused_data.

        finished_reading will be set when no more data is required. Further
        data will be appended to self.unused_data.
        """
        # accept_bytes is allowed to change the state; when it does, run the
        # new handler with no extra input so that data buffered by the
        # previous state is processed, until the state stops changing.
        current_state = self.state_accept
        self.state_accept(bytes)
        while current_state != self.state_accept:
            current_state = self.state_accept
            self.state_accept('')

    def next_read_size(self):
        """Return a hint for how many bytes to read next."""
        if self.bytes_left is not None:
            # Ideally we want to read all the remainder of the body and the
            # trailer (5 bytes: 'done\n') in one go.
            return self.bytes_left + 5
        elif self.state_accept == self._state_accept_reading_trailer:
            # Just the trailer left.
            # NOTE(review): this is <= 0 if a malformed trailer of 5+ bytes
            # was received (see the TODO in _state_accept_reading_trailer).
            return 5 - len(self._trailer_buffer)
        elif self.state_accept == self._state_accept_expecting_length:
            # There's still at least 6 bytes left ('\n' to end the length, plus
            # 'done\n').
            return 6
        else:
            # Reading excess data. Either way, 1 byte at a time is fine.
            return 1

    def read_pending_data(self):
        """Return any pending data that has been decoded."""
        return self.state_read()

    def _state_accept_expecting_length(self, bytes):
        self._in_buffer += bytes
        pos = self._in_buffer.find('\n')
        if pos == -1:
            # The length line is not complete yet.
            return
        self.bytes_left = int(self._in_buffer[:pos])
        self._in_buffer = self._in_buffer[pos+1:]
        # Whatever followed the length line is already body (and possibly
        # trailer) data.
        self.bytes_left -= len(self._in_buffer)
        self.state_accept = self._state_accept_reading_body
        self.state_read = self._state_read_in_buffer

    def _state_accept_reading_body(self, bytes):
        self._in_buffer += bytes
        self.bytes_left -= len(bytes)
        if self.bytes_left <= 0:
            # Finished with body. A negative bytes_left means bytes past the
            # end of the body were buffered; move them to the trailer buffer.
            if self.bytes_left != 0:
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
                self._in_buffer = self._in_buffer[:self.bytes_left]
            self.bytes_left = None
            self.state_accept = self._state_accept_reading_trailer

    def _state_accept_reading_trailer(self, bytes):
        self._trailer_buffer += bytes
        # TODO: what if the trailer does not match "done\n"? Should this raise
        # a ProtocolViolation exception?
        if self._trailer_buffer.startswith('done\n'):
            self.unused_data = self._trailer_buffer[len('done\n'):]
            self.state_accept = self._state_accept_reading_unused
            self.finished_reading = True

    def _state_accept_reading_unused(self, bytes):
        self.unused_data += bytes

    def _state_read_no_data(self):
        return ''

    def _state_read_in_buffer(self):
        # Hand back the decoded body bytes and forget them.
        result = self._in_buffer
        self._in_buffer = ''
        return result


class SmartClientRequestProtocolOne(SmartProtocolBase):
    """The client-side protocol for smart version 1."""

    def __init__(self, request):
        """Construct a SmartClientRequestProtocolOne.

        :param request: A SmartClientMediumRequest to serialise onto and
            deserialise from.
        """
        self._request = request
        # Decoded response body, filled in by the first read_body_bytes call.
        self._body_buffer = None

    def call(self, *args):
        """Send a request consisting of args, with no body."""
        self._write_args(args)
        self._request.finished_writing()

    def call_with_body_bytes(self, args, body):
        """Make a remote call of args with body bytes 'body'.

        After calling this, call read_response_tuple to find the result out.
        """
        self._write_args(args)
        encoded_body = self._encode_bulk_data(body)
        self._request.accept_bytes(encoded_body)
        self._request.finished_writing()

    def call_with_body_readv_array(self, args, body):
        r"""Make a remote call with a readv array.

        The body is encoded with one line per readv offset pair. The numbers in
        each pair are separated by a comma, and no trailing \n is emitted.
        """
        self._write_args(args)
        readv_bytes = self._serialise_offsets(body)
        encoded_body = self._encode_bulk_data(readv_bytes)
        self._request.accept_bytes(encoded_body)
        self._request.finished_writing()

    def cancel_read_body(self):
        """After expecting a body, a response code may indicate one otherwise.

        This method lets the domain client inform the protocol that no body
        will be transmitted. This is a terminal method: after calling it the
        protocol is not able to be used further.
        """
        self._request.finished_reading()

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        :param expect_body: if True the request is left open so the body can
            be read with read_body_bytes (or abandoned via cancel_read_body);
            otherwise reading is finished here.
        """
        result = self._recv_tuple()
        if not expect_body:
            self._request.finished_reading()
        return result

    def read_body_bytes(self, count=-1):
        """Read bytes from the body, decoding into a byte stream.

        We read all bytes at once to ensure we've checked the trailer for
        errors, and then feed the buffer back as read_body_bytes is called.

        :param count: maximum number of bytes to return; the default of -1
            returns everything not yet returned.
        """
        if self._body_buffer is not None:
            return self._body_buffer.read(count)
        _body_decoder = LengthPrefixedBodyDecoder()
        # Let the decoder size each read so we take only the body and trailer
        # from the medium.
        while not _body_decoder.finished_reading:
            bytes_wanted = _body_decoder.next_read_size()
            chunk = self._request.read_bytes(bytes_wanted)
            if chunk == '':
                # An empty read means end of file (as in _recv_tuple);
                # without this check the loop would never terminate.
                raise AssertionError("end of file reading body from server.")
            _body_decoder.accept_bytes(chunk)
        self._request.finished_reading()
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
        # XXX: TODO check the trailer result.
        return self._body_buffer.read(count)

    def _recv_tuple(self):
        """Receive a tuple from the medium request."""
        line = ''
        while not line or line[-1] != '\n':
            # TODO: this is inefficient - but tuples are short.
            new_char = self._request.read_bytes(1)
            if new_char == '':
                # An explicit raise (same type and message as the former
                # assert) still fires under python -O, where an assert would
                # be stripped and this loop would spin forever at EOF.
                raise AssertionError("end of file reading from server.")
            line += new_char
        return _decode_tuple(line)

    def query_version(self):
        """Return protocol version number of the server."""
        self.call('hello')
        resp = self.read_response_tuple()
        if resp == ('ok', '1'):
            return 1
        elif resp == ('ok', '2'):
            return 2
        else:
            raise errors.SmartProtocolError("bad response %r" % (resp,))

    def _write_args(self, args):
        """Write the protocol prefix (if any) and the encoded args tuple."""
        self._write_protocol_version()
        encoded_args = _encode_tuple(args)
        self._request.accept_bytes(encoded_args)

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version one doesn't send protocol versions.
        """


class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
    r"""Version two of the client side of the smart protocol.

    This prefixes the request with the protocol version: "2\n".
    """

    _version_string = '2\n'

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        The two-byte version marker is read and checked first; the tuple is
        then read exactly as in version one.

        :raises errors.SmartProtocolError: if the marker is not the version
            two string.
        """
        version = self._request.read_bytes(2)
        if version != SmartClientRequestProtocolTwo._version_string:
            raise errors.SmartProtocolError('bad protocol marker %r' % version)
        return SmartClientRequestProtocolOne.read_response_tuple(self, expect_body)

    def _write_protocol_version(self):
        r"""Write any prefixes this protocol requires.

        Version two sends "2\n".
        """
        self._request.accept_bytes(SmartClientRequestProtocolTwo._version_string)