1
# Copyright (C) 2006 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17
from cStringIO import StringIO
19
from bzrlib import errors
22
def _recv_tuple(from_file):
    """Read one line from from_file and decode it as a protocol tuple.

    :param from_file: an object providing ``readline()``.
    :return: whatever ``_decode_tuple`` returns for the line that was read.
    """
    return _decode_tuple(from_file.readline())
27
def _decode_tuple(req_line):
    """Decode one newline-terminated line into a tuple of its fields.

    The trailing newline is stripped and the remainder is split on the
    0x01 separator byte, which is the inverse of ``_encode_tuple``.

    :param req_line: the line to decode, including its trailing newline;
        may be None or the empty string.
    :return: a tuple of the fields, or None when req_line is None or empty.
    :raises errors.SmartProtocolError: if req_line does not end with a
        newline.
    """
    if req_line is None or req_line == '':
        # NOTE(review): the body of this guard was not visible in the
        # reviewed text; returning None for "nothing was read" is assumed.
        # Confirm against callers.
        return None
    if req_line[-1] != '\n':
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
    return tuple(req_line[:-1].split('\x01'))
35
def _encode_tuple(args):
    """Encode the tuple args to a bytestream.

    The fields are joined with the 0x01 separator byte and the result is
    terminated with a single newline.

    :param args: a sequence of strings.
    :return: the encoded line as a string.
    """
    return '\x01'.join(args) + '\n'
40
class SmartProtocolBase(object):
41
"""Methods common to client and server"""
43
# TODO: this only actually accomodates a single block; possibly should
44
# support multiple chunks?
45
def _encode_bulk_data(self, body):
46
"""Encode body as a bulk data chunk."""
47
return ''.join(('%d\n' % len(body), body, 'done\n'))
49
def _serialise_offsets(self, offsets):
50
"""Serialise a readv offset list."""
52
for start, length in offsets:
53
txt.append('%d,%d' % (start, length))
57
class SmartServerRequestProtocolOne(SmartProtocolBase):
    """Server-side encoding and decoding logic for smart version 1."""

    def __init__(self, backing_transport, write_func):
        """Create a protocol instance for one request.

        :param backing_transport: passed to the SmartServerRequestHandler
            that is created once the request line has arrived.
        :param write_func: callable taking one string; every piece of the
            encoded response is handed to it.
        """
        self._backing_transport = backing_transport
        # Bytes received that do not belong to this request.
        self.excess_buffer = ''
        # True once a response has been written.
        self._finished = False
        # Bytes received but not yet consumed.
        # NOTE(review): this initialisation (and `request` below) was not
        # visible in the reviewed text; accept_bytes needs in_buffer to exist.
        self.in_buffer = ''
        # True once the request line has been seen and dispatched.
        self.has_dispatched = False
        self.request = None
        self._body_decoder = None
        self._write_func = write_func

    def accept_bytes(self, bytes):
        """Take bytes, and advance the internal state machine appropriately.

        Until a complete first line has arrived the bytes are only buffered.
        The first line is then decoded and dispatched to a request handler.
        If the handler still wants a body, this and later calls feed the
        remaining bytes through a LengthPrefixedBodyDecoder. As soon as the
        handler has a response it is written out, and any bytes left over
        are moved to excess_buffer.

        :param bytes: must be a byte string
        """
        import sys
        from bzrlib.transport.smart import request
        assert isinstance(bytes, str)
        self.in_buffer += bytes
        if not self.has_dispatched:
            if '\n' not in self.in_buffer:
                # NOTE(review): body of this guard was not visible; waiting
                # for the rest of the request line is assumed.
                return
            self.has_dispatched = True
            try:
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
                # split() consumed the terminator that _decode_tuple insists
                # on, so put it back before decoding.
                first_line += '\n'
                req_args = _decode_tuple(first_line)
                self.request = request.SmartServerRequestHandler(
                    self._backing_transport)
                self.request.dispatch_command(req_args[0], req_args[1:])
                if self.request.finished_reading:
                    # The request has no body: respond straight away and
                    # keep whatever followed the request line as excess.
                    self.excess_buffer = self.in_buffer
                    self.in_buffer = ''
                    self._send_response(self.request.response.args,
                                        self.request.response.body)
            except KeyboardInterrupt:
                raise
            except Exception:
                # everything else: pass to client, flush, and quit
                # sys.exc_info() is used instead of binding the exception in
                # the except clause so that this parses on every Python
                # version.
                self._send_response(('error', str(sys.exc_info()[1])))
                return

        if self.has_dispatched:
            if self._finished:
                # Nothing to do: the response has already been sent, so any
                # further bytes are excess.
                # XXX: this routine should be a single state machine.
                self.excess_buffer += self.in_buffer
                self.in_buffer = ''
                return
            if self._body_decoder is None:
                self._body_decoder = LengthPrefixedBodyDecoder()
            self._body_decoder.accept_bytes(self.in_buffer)
            self.in_buffer = self._body_decoder.unused_data
            body_data = self._body_decoder.read_pending_data()
            self.request.accept_body(body_data)
            if self._body_decoder.finished_reading:
                self.request.end_of_body()
                assert self.request.finished_reading, \
                    "no more body, request not finished"
            if self.request.response is not None:
                self._send_response(self.request.response.args,
                                    self.request.response.body)
                self.excess_buffer = self.in_buffer
                self.in_buffer = ''
            else:
                assert not self.request.finished_reading, \
                    "no response and we have finished reading."

    def _send_response(self, args, body=None):
        """Send a smart server response down the output stream.

        :param args: sequence of strings forming the response tuple.
        :param body: optional string; when given it is written after the
            tuple as a bulk data chunk.
        """
        assert not self._finished, 'response already sent'
        self._finished = True
        self._write_func(_encode_tuple(args))
        if body is not None:
            assert isinstance(body, str), 'body must be a str'
            self._write_func(self._encode_bulk_data(body))

    def next_read_size(self):
        """Return how many bytes the caller should try to read next."""
        # NOTE(review): only the `_body_decoder is None` test and the
        # delegation to the decoder were visible in the reviewed text. The
        # `_finished` check and the literal sizes 0 and 1 are reconstructed;
        # confirm.
        if self._finished:
            return 0
        if self._body_decoder is None:
            return 1
        return self._body_decoder.next_read_size()
148
class LengthPrefixedBodyDecoder(object):
    """Decodes the length-prefixed bulk data.

    The input is a decimal length on a line of its own, that many bytes of
    body, and then the trailer line 'done'. The decoder is a small state
    machine: state_accept is the method that consumes incoming bytes in the
    current state, and state_read is the method that hands decoded body
    data back to the caller.
    """

    def __init__(self):
        # Body bytes still expected; None while the length is unknown and
        # again once the whole body has been received.
        self.bytes_left = None
        self.finished_reading = False
        # Bytes received after the trailer.
        self.unused_data = ''
        self.state_accept = self._state_accept_expecting_length
        self.state_read = self._state_read_no_data
        # NOTE(review): this initialisation was not visible in the reviewed
        # text; the state methods need _in_buffer to exist.
        self._in_buffer = ''
        self._trailer_buffer = ''

    def accept_bytes(self, bytes):
        """Decode as much of bytes as possible.

        If 'bytes' contains too much data it will be appended to
        self.unused_data.

        finished_reading will be set when no more data is required. Further
        data will be appended to self.unused_data.
        """
        # accept_bytes is allowed to change the state
        current_state = self.state_accept
        self.state_accept(bytes)
        # Each time the state changed, give the new state a chance to
        # process what the previous one left buffered.
        while current_state != self.state_accept:
            current_state = self.state_accept
            self.state_accept('')

    def next_read_size(self):
        """Return how many bytes the caller should try to read next."""
        if self.bytes_left is not None:
            # Ideally we want to read all the remainder of the body and the
            # trailer (5 bytes) in one go.
            return self.bytes_left + 5
        elif self.state_accept == self._state_accept_reading_trailer:
            # Just the trailer left
            return 5 - len(self._trailer_buffer)
        elif self.state_accept == self._state_accept_expecting_length:
            # There's still at least 6 bytes left (the newline that ends the
            # length, plus the 5-byte trailer).
            return 6
        else:
            # Reading excess data. Either way, 1 byte at a time is fine.
            return 1

    def read_pending_data(self):
        """Return any pending data that has been decoded."""
        return self.state_read()

    def _state_accept_expecting_length(self, bytes):
        """Buffer bytes until the length line is complete, then parse it."""
        self._in_buffer += bytes
        pos = self._in_buffer.find('\n')
        if pos == -1:
            # NOTE(review): this guard was not visible in the reviewed text;
            # without it a partial length line would be parsed too early.
            return
        self.bytes_left = int(self._in_buffer[:pos])
        self._in_buffer = self._in_buffer[pos+1:]
        # Anything after the length line already counts towards the body.
        # This may go negative; the reading-body state sorts that out.
        self.bytes_left -= len(self._in_buffer)
        self.state_accept = self._state_accept_reading_body
        self.state_read = self._state_read_in_buffer

    def _state_accept_reading_body(self, bytes):
        """Buffer body bytes; when the body is complete move to the trailer."""
        self._in_buffer += bytes
        self.bytes_left -= len(bytes)
        if self.bytes_left <= 0:
            # The body is complete. A negative bytes_left is the number of
            # surplus bytes at the end of the buffer, which belong to the
            # trailer (and beyond). The zero case is skipped because
            # slicing with [:0] would discard the whole body.
            if self.bytes_left != 0:
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
                self._in_buffer = self._in_buffer[:self.bytes_left]
            self.bytes_left = None
            self.state_accept = self._state_accept_reading_trailer

    def _state_accept_reading_trailer(self, bytes):
        """Buffer bytes until the trailer line has been seen."""
        self._trailer_buffer += bytes
        # TODO: what if the trailer does not match "done\n"? Should this raise
        # a ProtocolViolation exception?
        if self._trailer_buffer.startswith('done\n'):
            self.unused_data = self._trailer_buffer[len('done\n'):]
            self.state_accept = self._state_accept_reading_unused
            self.finished_reading = True

    def _state_accept_reading_unused(self, bytes):
        """Collect everything received after the trailer."""
        self.unused_data += bytes

    def _state_read_no_data(self):
        """Reader used before the length is known: there is no body yet."""
        # NOTE(review): the return statement was not visible in the
        # reviewed text; an empty string is assumed.
        return ''

    def _state_read_in_buffer(self):
        """Hand over the body bytes buffered since the previous read."""
        result = self._in_buffer
        # NOTE(review): the lines after the assignment above were not
        # visible in the reviewed text. Clearing the buffer is assumed, so
        # that the same bytes are not returned twice.
        self._in_buffer = ''
        return result
239
class SmartServerResponse(object):
    """Response generated by SmartServerRequestHandler."""

    def __init__(self, args, body=None):
        """Store the parts of a response.

        :param args: the response tuple.
        :param body: optional response body; None when there is none.
        """
        # NOTE(review): the constructor body was not visible in the reviewed
        # text. These are the two attributes the server protocol reads from
        # a response object.
        self.args = args
        self.body = body
246
# XXX: TODO: Create a SmartServerRequestHandler which will take the responsibility
247
# for delivering the data for a request. This could be done with as the
248
# StreamServer, though that would create conflation between request and response
249
# which may be undesirable.
252
class SmartClientRequestProtocolOne(SmartProtocolBase):
    """The client-side protocol for smart version 1."""

    def __init__(self, request):
        """Construct a SmartClientRequestProtocolOne.

        :param request: A SmartClientMediumRequest to serialise onto and
            read responses from.
        """
        self._request = request
        # Holds the decoded response body once read_body_bytes has run.
        self._body_buffer = None

    def call(self, *args):
        """Make a remote call of args with no body.

        After calling this, call read_response_tuple to find the result out.
        """
        self._request.accept_bytes(_encode_tuple(args))
        self._request.finished_writing()

    def call_with_body_bytes(self, args, body):
        """Make a remote call of args with body bytes 'body'.

        After calling this, call read_response_tuple to find the result out.
        """
        self._request.accept_bytes(_encode_tuple(args))
        self._request.accept_bytes(self._encode_bulk_data(body))
        self._request.finished_writing()

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        The body is encoded with one line per readv offset pair. The numbers in
        each pair are separated by a comma, and no trailing newline is emitted.
        """
        self._request.accept_bytes(_encode_tuple(args))
        readv_bytes = self._serialise_offsets(body)
        self._request.accept_bytes(self._encode_bulk_data(readv_bytes))
        self._request.finished_writing()

    def cancel_read_body(self):
        """After expecting a body, a response code may indicate one otherwise.

        This method lets the domain client inform the protocol that no body
        will be transmitted. This is a terminal method: after calling it the
        protocol is not able to be used further.
        """
        self._request.finished_reading()

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        :param expect_body: when true the request is left open so that the
            body can be read afterwards (or cancelled with
            cancel_read_body).
        :return: the decoded response tuple.
        """
        result = self._recv_tuple()
        # NOTE(review): the `expect_body` test and the return statement were
        # not visible in the reviewed text; they are reconstructed from the
        # parameter and from query_version using the return value.
        if not expect_body:
            self._request.finished_reading()
        return result

    def read_body_bytes(self, count=-1):
        """Read bytes from the body, decoding into a byte stream.

        We read all bytes at once to ensure we've checked the trailer for
        errors, and then feed the buffer back as read_body_bytes is called.

        :param count: passed to the ``read`` method of the buffered body.
        """
        if self._body_buffer is not None:
            return self._body_buffer.read(count)
        decoder = LengthPrefixedBodyDecoder()
        while not decoder.finished_reading:
            chunk = self._request.read_bytes(decoder.next_read_size())
            decoder.accept_bytes(chunk)
        self._request.finished_reading()
        self._body_buffer = StringIO(decoder.read_pending_data())
        # XXX: TODO check the trailer result.
        return self._body_buffer.read(count)

    def _recv_tuple(self):
        """Receive a tuple from the medium request."""
        # NOTE(review): the initialisation of `line` and the statement that
        # appends each byte to it were not visible in the reviewed text.
        line = ''
        while not line or line[-1] != '\n':
            # TODO: this is inefficient - but tuples are short.
            new_char = self._request.read_bytes(1)
            line += new_char
            assert new_char != '', "end of file reading from server."
        return _decode_tuple(line)

    def query_version(self):
        """Return protocol version number of the server.

        :raises errors.SmartProtocolError: if the response is not
            ('ok', '1').
        """
        # NOTE(review): the statement that sends the request and the value
        # returned on success were not visible in the reviewed text; the
        # 'hello' verb and the integer 1 are assumptions to confirm.
        self.call('hello')
        resp = self.read_response_tuple()
        if resp == ('ok', '1'):
            return 1
        raise errors.SmartProtocolError("bad response %r" % (resp,))