/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: John Arbash Meinel
  • Date: 2009-12-18 16:39:21 UTC
  • mto: This revision was merged to the branch mainline in revision 4934.
  • Revision ID: john@arbash-meinel.com-20091218163921-tcltjarx4pxxm08y
Basic implementation of logging bytes transferred when bzr exits.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006, 2007, 2008, 2009 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""Wire-level encoding and decoding of requests and responses for the smart
 
18
client and server.
 
19
"""
 
20
 
 
21
import collections
 
22
from cStringIO import StringIO
 
23
import struct
 
24
import sys
 
25
import threading
 
26
import time
 
27
 
 
28
import bzrlib
 
29
from bzrlib import (
 
30
    debug,
 
31
    errors,
 
32
    osutils,
 
33
    )
 
34
from bzrlib.smart import message, request
 
35
from bzrlib.trace import log_exception_quietly, mutter
 
36
from bzrlib.bencode import bdecode_as_tuple, bencode
 
37
 
 
38
 
 
39
# Protocol version strings.  These are sent as prefixes of bzr requests and
 
40
# responses to identify the protocol version being used. (There are no version
 
41
# one strings because that version doesn't send any).
 
42
REQUEST_VERSION_TWO = 'bzr request 2\n'
 
43
RESPONSE_VERSION_TWO = 'bzr response 2\n'
 
44
 
 
45
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
 
46
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
 
47
 
 
48
 
 
49
def _recv_tuple(from_file):
 
50
    req_line = from_file.readline()
 
51
    return _decode_tuple(req_line)
 
52
 
 
53
 
 
54
def _decode_tuple(req_line):
 
55
    if req_line is None or req_line == '':
 
56
        return None
 
57
    if req_line[-1] != '\n':
 
58
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
59
    return tuple(req_line[:-1].split('\x01'))
 
60
 
 
61
 
 
62
def _encode_tuple(args):
 
63
    """Encode the tuple args to a bytestream."""
 
64
    return '\x01'.join(args) + '\n'
 
65
 
 
66
 
 
67
class Requester(object):
 
68
    """Abstract base class for an object that can issue requests on a smart
 
69
    medium.
 
70
    """
 
71
 
 
72
    def call(self, *args):
 
73
        """Make a remote call.
 
74
 
 
75
        :param args: the arguments of this call.
 
76
        """
 
77
        raise NotImplementedError(self.call)
 
78
 
 
79
    def call_with_body_bytes(self, args, body):
 
80
        """Make a remote call with a body.
 
81
 
 
82
        :param args: the arguments of this call.
 
83
        :type body: str
 
84
        :param body: the body to send with the request.
 
85
        """
 
86
        raise NotImplementedError(self.call_with_body_bytes)
 
87
 
 
88
    def call_with_body_readv_array(self, args, body):
 
89
        """Make a remote call with a readv array.
 
90
 
 
91
        :param args: the arguments of this call.
 
92
        :type body: iterable of (start, length) tuples.
 
93
        :param body: the readv ranges to send with this request.
 
94
        """
 
95
        raise NotImplementedError(self.call_with_body_readv_array)
 
96
 
 
97
    def set_headers(self, headers):
 
98
        raise NotImplementedError(self.set_headers)
 
99
 
 
100
 
 
101
class SmartProtocolBase(object):
 
102
    """Methods common to client and server"""
 
103
 
 
104
    # TODO: this only actually accomodates a single block; possibly should
 
105
    # support multiple chunks?
 
106
    def _encode_bulk_data(self, body):
 
107
        """Encode body as a bulk data chunk."""
 
108
        return ''.join(('%d\n' % len(body), body, 'done\n'))
 
109
 
 
110
    def _serialise_offsets(self, offsets):
 
111
        """Serialise a readv offset list."""
 
112
        txt = []
 
113
        for start, length in offsets:
 
114
            txt.append('%d,%d' % (start, length))
 
115
        return '\n'.join(txt)
 
116
 
 
117
 
 
118
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
119
    """Server-side encoding and decoding logic for smart version 1."""
 
120
 
 
121
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
122
            jail_root=None):
 
123
        self._backing_transport = backing_transport
 
124
        self._root_client_path = root_client_path
 
125
        self._jail_root = jail_root
 
126
        self.unused_data = ''
 
127
        self._finished = False
 
128
        self.in_buffer = ''
 
129
        self._has_dispatched = False
 
130
        self.request = None
 
131
        self._body_decoder = None
 
132
        self._write_func = write_func
 
133
 
 
134
    def accept_bytes(self, bytes):
 
135
        """Take bytes, and advance the internal state machine appropriately.
 
136
 
 
137
        :param bytes: must be a byte string
 
138
        """
 
139
        if not isinstance(bytes, str):
 
140
            raise ValueError(bytes)
 
141
        self.in_buffer += bytes
 
142
        if not self._has_dispatched:
 
143
            if '\n' not in self.in_buffer:
 
144
                # no command line yet
 
145
                return
 
146
            self._has_dispatched = True
 
147
            try:
 
148
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
 
149
                first_line += '\n'
 
150
                req_args = _decode_tuple(first_line)
 
151
                self.request = request.SmartServerRequestHandler(
 
152
                    self._backing_transport, commands=request.request_handlers,
 
153
                    root_client_path=self._root_client_path,
 
154
                    jail_root=self._jail_root)
 
155
                self.request.args_received(req_args)
 
156
                if self.request.finished_reading:
 
157
                    # trivial request
 
158
                    self.unused_data = self.in_buffer
 
159
                    self.in_buffer = ''
 
160
                    self._send_response(self.request.response)
 
161
            except KeyboardInterrupt:
 
162
                raise
 
163
            except errors.UnknownSmartMethod, err:
 
164
                protocol_error = errors.SmartProtocolError(
 
165
                    "bad request %r" % (err.verb,))
 
166
                failure = request.FailedSmartServerResponse(
 
167
                    ('error', str(protocol_error)))
 
168
                self._send_response(failure)
 
169
                return
 
170
            except Exception, exception:
 
171
                # everything else: pass to client, flush, and quit
 
172
                log_exception_quietly()
 
173
                self._send_response(request.FailedSmartServerResponse(
 
174
                    ('error', str(exception))))
 
175
                return
 
176
 
 
177
        if self._has_dispatched:
 
178
            if self._finished:
 
179
                # nothing to do.XXX: this routine should be a single state
 
180
                # machine too.
 
181
                self.unused_data += self.in_buffer
 
182
                self.in_buffer = ''
 
183
                return
 
184
            if self._body_decoder is None:
 
185
                self._body_decoder = LengthPrefixedBodyDecoder()
 
186
            self._body_decoder.accept_bytes(self.in_buffer)
 
187
            self.in_buffer = self._body_decoder.unused_data
 
188
            body_data = self._body_decoder.read_pending_data()
 
189
            self.request.accept_body(body_data)
 
190
            if self._body_decoder.finished_reading:
 
191
                self.request.end_of_body()
 
192
                if not self.request.finished_reading:
 
193
                    raise AssertionError("no more body, request not finished")
 
194
            if self.request.response is not None:
 
195
                self._send_response(self.request.response)
 
196
                self.unused_data = self.in_buffer
 
197
                self.in_buffer = ''
 
198
            else:
 
199
                if self.request.finished_reading:
 
200
                    raise AssertionError(
 
201
                        "no response and we have finished reading.")
 
202
 
 
203
    def _send_response(self, response):
 
204
        """Send a smart server response down the output stream."""
 
205
        if self._finished:
 
206
            raise AssertionError('response already sent')
 
207
        args = response.args
 
208
        body = response.body
 
209
        self._finished = True
 
210
        self._write_protocol_version()
 
211
        self._write_success_or_failure_prefix(response)
 
212
        self._write_func(_encode_tuple(args))
 
213
        if body is not None:
 
214
            if not isinstance(body, str):
 
215
                raise ValueError(body)
 
216
            bytes = self._encode_bulk_data(body)
 
217
            self._write_func(bytes)
 
218
 
 
219
    def _write_protocol_version(self):
 
220
        """Write any prefixes this protocol requires.
 
221
 
 
222
        Version one doesn't send protocol versions.
 
223
        """
 
224
 
 
225
    def _write_success_or_failure_prefix(self, response):
 
226
        """Write the protocol specific success/failure prefix.
 
227
 
 
228
        For SmartServerRequestProtocolOne this is omitted but we
 
229
        call is_successful to ensure that the response is valid.
 
230
        """
 
231
        response.is_successful()
 
232
 
 
233
    def next_read_size(self):
 
234
        if self._finished:
 
235
            return 0
 
236
        if self._body_decoder is None:
 
237
            return 1
 
238
        else:
 
239
            return self._body_decoder.next_read_size()
 
240
 
 
241
 
 
242
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
 
243
    r"""Version two of the server side of the smart protocol.
 
244
 
 
245
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
 
246
    """
 
247
 
 
248
    response_marker = RESPONSE_VERSION_TWO
 
249
    request_marker = REQUEST_VERSION_TWO
 
250
 
 
251
    def _write_success_or_failure_prefix(self, response):
 
252
        """Write the protocol specific success/failure prefix."""
 
253
        if response.is_successful():
 
254
            self._write_func('success\n')
 
255
        else:
 
256
            self._write_func('failed\n')
 
257
 
 
258
    def _write_protocol_version(self):
 
259
        r"""Write any prefixes this protocol requires.
 
260
 
 
261
        Version two sends the value of RESPONSE_VERSION_TWO.
 
262
        """
 
263
        self._write_func(self.response_marker)
 
264
 
 
265
    def _send_response(self, response):
 
266
        """Send a smart server response down the output stream."""
 
267
        if (self._finished):
 
268
            raise AssertionError('response already sent')
 
269
        self._finished = True
 
270
        self._write_protocol_version()
 
271
        self._write_success_or_failure_prefix(response)
 
272
        self._write_func(_encode_tuple(response.args))
 
273
        if response.body is not None:
 
274
            if not isinstance(response.body, str):
 
275
                raise AssertionError('body must be a str')
 
276
            if not (response.body_stream is None):
 
277
                raise AssertionError(
 
278
                    'body_stream and body cannot both be set')
 
279
            bytes = self._encode_bulk_data(response.body)
 
280
            self._write_func(bytes)
 
281
        elif response.body_stream is not None:
 
282
            _send_stream(response.body_stream, self._write_func)
 
283
 
 
284
 
 
285
def _send_stream(stream, write_func):
 
286
    write_func('chunked\n')
 
287
    _send_chunks(stream, write_func)
 
288
    write_func('END\n')
 
289
 
 
290
 
 
291
def _send_chunks(stream, write_func):
 
292
    for chunk in stream:
 
293
        if isinstance(chunk, str):
 
294
            bytes = "%x\n%s" % (len(chunk), chunk)
 
295
            write_func(bytes)
 
296
        elif isinstance(chunk, request.FailedSmartServerResponse):
 
297
            write_func('ERR\n')
 
298
            _send_chunks(chunk.args, write_func)
 
299
            return
 
300
        else:
 
301
            raise errors.BzrError(
 
302
                'Chunks must be str or FailedSmartServerResponse, got %r'
 
303
                % chunk)
 
304
 
 
305
 
 
306
class _NeedMoreBytes(Exception):
 
307
    """Raise this inside a _StatefulDecoder to stop decoding until more bytes
 
308
    have been received.
 
309
    """
 
310
 
 
311
    def __init__(self, count=None):
 
312
        """Constructor.
 
313
 
 
314
        :param count: the total number of bytes needed by the current state.
 
315
            May be None if the number of bytes needed is unknown.
 
316
        """
 
317
        self.count = count
 
318
 
 
319
 
 
320
class _StatefulDecoder(object):
 
321
    """Base class for writing state machines to decode byte streams.
 
322
 
 
323
    Subclasses should provide a self.state_accept attribute that accepts bytes
 
324
    and, if appropriate, updates self.state_accept to a different function.
 
325
    accept_bytes will call state_accept as often as necessary to make sure the
 
326
    state machine has progressed as far as possible before it returns.
 
327
 
 
328
    See ProtocolThreeDecoder for an example subclass.
 
329
    """
 
330
 
 
331
    def __init__(self):
 
332
        self.finished_reading = False
 
333
        self._in_buffer_list = []
 
334
        self._in_buffer_len = 0
 
335
        self.unused_data = ''
 
336
        self.bytes_left = None
 
337
        self._number_needed_bytes = None
 
338
 
 
339
    def _get_in_buffer(self):
 
340
        if len(self._in_buffer_list) == 1:
 
341
            return self._in_buffer_list[0]
 
342
        in_buffer = ''.join(self._in_buffer_list)
 
343
        if len(in_buffer) != self._in_buffer_len:
 
344
            raise AssertionError(
 
345
                "Length of buffer did not match expected value: %s != %s"
 
346
                % self._in_buffer_len, len(in_buffer))
 
347
        self._in_buffer_list = [in_buffer]
 
348
        return in_buffer
 
349
 
 
350
    def _get_in_bytes(self, count):
 
351
        """Grab X bytes from the input_buffer.
 
352
 
 
353
        Callers should have already checked that self._in_buffer_len is >
 
354
        count. Note, this does not consume the bytes from the buffer. The
 
355
        caller will still need to call _get_in_buffer() and then
 
356
        _set_in_buffer() if they actually need to consume the bytes.
 
357
        """
 
358
        # check if we can yield the bytes from just the first entry in our list
 
359
        if len(self._in_buffer_list) == 0:
 
360
            raise AssertionError('Callers must be sure we have buffered bytes'
 
361
                ' before calling _get_in_bytes')
 
362
        if len(self._in_buffer_list[0]) > count:
 
363
            return self._in_buffer_list[0][:count]
 
364
        # We can't yield it from the first buffer, so collapse all buffers, and
 
365
        # yield it from that
 
366
        in_buf = self._get_in_buffer()
 
367
        return in_buf[:count]
 
368
 
 
369
    def _set_in_buffer(self, new_buf):
 
370
        if new_buf is not None:
 
371
            self._in_buffer_list = [new_buf]
 
372
            self._in_buffer_len = len(new_buf)
 
373
        else:
 
374
            self._in_buffer_list = []
 
375
            self._in_buffer_len = 0
 
376
 
 
377
    def accept_bytes(self, bytes):
 
378
        """Decode as much of bytes as possible.
 
379
 
 
380
        If 'bytes' contains too much data it will be appended to
 
381
        self.unused_data.
 
382
 
 
383
        finished_reading will be set when no more data is required.  Further
 
384
        data will be appended to self.unused_data.
 
385
        """
 
386
        # accept_bytes is allowed to change the state
 
387
        self._number_needed_bytes = None
 
388
        # lsprof puts a very large amount of time on this specific call for
 
389
        # large readv arrays
 
390
        self._in_buffer_list.append(bytes)
 
391
        self._in_buffer_len += len(bytes)
 
392
        try:
 
393
            # Run the function for the current state.
 
394
            current_state = self.state_accept
 
395
            self.state_accept()
 
396
            while current_state != self.state_accept:
 
397
                # The current state has changed.  Run the function for the new
 
398
                # current state, so that it can:
 
399
                #   - decode any unconsumed bytes left in a buffer, and
 
400
                #   - signal how many more bytes are expected (via raising
 
401
                #     _NeedMoreBytes).
 
402
                current_state = self.state_accept
 
403
                self.state_accept()
 
404
        except _NeedMoreBytes, e:
 
405
            self._number_needed_bytes = e.count
 
406
 
 
407
 
 
408
class ChunkedBodyDecoder(_StatefulDecoder):
 
409
    """Decoder for chunked body data.
 
410
 
 
411
    This is very similar the HTTP's chunked encoding.  See the description of
 
412
    streamed body data in `doc/developers/network-protocol.txt` for details.
 
413
    """
 
414
 
 
415
    def __init__(self):
 
416
        _StatefulDecoder.__init__(self)
 
417
        self.state_accept = self._state_accept_expecting_header
 
418
        self.chunk_in_progress = None
 
419
        self.chunks = collections.deque()
 
420
        self.error = False
 
421
        self.error_in_progress = None
 
422
 
 
423
    def next_read_size(self):
 
424
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
 
425
        # end-of-body marker is 4 bytes: 'END\n'.
 
426
        if self.state_accept == self._state_accept_reading_chunk:
 
427
            # We're expecting more chunk content.  So we're expecting at least
 
428
            # the rest of this chunk plus an END chunk.
 
429
            return self.bytes_left + 4
 
430
        elif self.state_accept == self._state_accept_expecting_length:
 
431
            if self._in_buffer_len == 0:
 
432
                # We're expecting a chunk length.  There's at least two bytes
 
433
                # left: a digit plus '\n'.
 
434
                return 2
 
435
            else:
 
436
                # We're in the middle of reading a chunk length.  So there's at
 
437
                # least one byte left, the '\n' that terminates the length.
 
438
                return 1
 
439
        elif self.state_accept == self._state_accept_reading_unused:
 
440
            return 1
 
441
        elif self.state_accept == self._state_accept_expecting_header:
 
442
            return max(0, len('chunked\n') - self._in_buffer_len)
 
443
        else:
 
444
            raise AssertionError("Impossible state: %r" % (self.state_accept,))
 
445
 
 
446
    def read_next_chunk(self):
 
447
        try:
 
448
            return self.chunks.popleft()
 
449
        except IndexError:
 
450
            return None
 
451
 
 
452
    def _extract_line(self):
 
453
        in_buf = self._get_in_buffer()
 
454
        pos = in_buf.find('\n')
 
455
        if pos == -1:
 
456
            # We haven't read a complete line yet, so request more bytes before
 
457
            # we continue.
 
458
            raise _NeedMoreBytes(1)
 
459
        line = in_buf[:pos]
 
460
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
 
461
        self._set_in_buffer(in_buf[pos+1:])
 
462
        return line
 
463
 
 
464
    def _finished(self):
 
465
        self.unused_data = self._get_in_buffer()
 
466
        self._in_buffer_list = []
 
467
        self._in_buffer_len = 0
 
468
        self.state_accept = self._state_accept_reading_unused
 
469
        if self.error:
 
470
            error_args = tuple(self.error_in_progress)
 
471
            self.chunks.append(request.FailedSmartServerResponse(error_args))
 
472
            self.error_in_progress = None
 
473
        self.finished_reading = True
 
474
 
 
475
    def _state_accept_expecting_header(self):
 
476
        prefix = self._extract_line()
 
477
        if prefix == 'chunked':
 
478
            self.state_accept = self._state_accept_expecting_length
 
479
        else:
 
480
            raise errors.SmartProtocolError(
 
481
                'Bad chunked body header: "%s"' % (prefix,))
 
482
 
 
483
    def _state_accept_expecting_length(self):
 
484
        prefix = self._extract_line()
 
485
        if prefix == 'ERR':
 
486
            self.error = True
 
487
            self.error_in_progress = []
 
488
            self._state_accept_expecting_length()
 
489
            return
 
490
        elif prefix == 'END':
 
491
            # We've read the end-of-body marker.
 
492
            # Any further bytes are unused data, including the bytes left in
 
493
            # the _in_buffer.
 
494
            self._finished()
 
495
            return
 
496
        else:
 
497
            self.bytes_left = int(prefix, 16)
 
498
            self.chunk_in_progress = ''
 
499
            self.state_accept = self._state_accept_reading_chunk
 
500
 
 
501
    def _state_accept_reading_chunk(self):
 
502
        in_buf = self._get_in_buffer()
 
503
        in_buffer_len = len(in_buf)
 
504
        self.chunk_in_progress += in_buf[:self.bytes_left]
 
505
        self._set_in_buffer(in_buf[self.bytes_left:])
 
506
        self.bytes_left -= in_buffer_len
 
507
        if self.bytes_left <= 0:
 
508
            # Finished with chunk
 
509
            self.bytes_left = None
 
510
            if self.error:
 
511
                self.error_in_progress.append(self.chunk_in_progress)
 
512
            else:
 
513
                self.chunks.append(self.chunk_in_progress)
 
514
            self.chunk_in_progress = None
 
515
            self.state_accept = self._state_accept_expecting_length
 
516
 
 
517
    def _state_accept_reading_unused(self):
 
518
        self.unused_data += self._get_in_buffer()
 
519
        self._in_buffer_list = []
 
520
 
 
521
 
 
522
class LengthPrefixedBodyDecoder(_StatefulDecoder):
 
523
    """Decodes the length-prefixed bulk data."""
 
524
 
 
525
    def __init__(self):
 
526
        _StatefulDecoder.__init__(self)
 
527
        self.state_accept = self._state_accept_expecting_length
 
528
        self.state_read = self._state_read_no_data
 
529
        self._body = ''
 
530
        self._trailer_buffer = ''
 
531
 
 
532
    def next_read_size(self):
 
533
        if self.bytes_left is not None:
 
534
            # Ideally we want to read all the remainder of the body and the
 
535
            # trailer in one go.
 
536
            return self.bytes_left + 5
 
537
        elif self.state_accept == self._state_accept_reading_trailer:
 
538
            # Just the trailer left
 
539
            return 5 - len(self._trailer_buffer)
 
540
        elif self.state_accept == self._state_accept_expecting_length:
 
541
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
542
            # 'done\n').
 
543
            return 6
 
544
        else:
 
545
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
546
            return 1
 
547
 
 
548
    def read_pending_data(self):
 
549
        """Return any pending data that has been decoded."""
 
550
        return self.state_read()
 
551
 
 
552
    def _state_accept_expecting_length(self):
 
553
        in_buf = self._get_in_buffer()
 
554
        pos = in_buf.find('\n')
 
555
        if pos == -1:
 
556
            return
 
557
        self.bytes_left = int(in_buf[:pos])
 
558
        self._set_in_buffer(in_buf[pos+1:])
 
559
        self.state_accept = self._state_accept_reading_body
 
560
        self.state_read = self._state_read_body_buffer
 
561
 
 
562
    def _state_accept_reading_body(self):
 
563
        in_buf = self._get_in_buffer()
 
564
        self._body += in_buf
 
565
        self.bytes_left -= len(in_buf)
 
566
        self._set_in_buffer(None)
 
567
        if self.bytes_left <= 0:
 
568
            # Finished with body
 
569
            if self.bytes_left != 0:
 
570
                self._trailer_buffer = self._body[self.bytes_left:]
 
571
                self._body = self._body[:self.bytes_left]
 
572
            self.bytes_left = None
 
573
            self.state_accept = self._state_accept_reading_trailer
 
574
 
 
575
    def _state_accept_reading_trailer(self):
 
576
        self._trailer_buffer += self._get_in_buffer()
 
577
        self._set_in_buffer(None)
 
578
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
579
        # a ProtocolViolation exception?
 
580
        if self._trailer_buffer.startswith('done\n'):
 
581
            self.unused_data = self._trailer_buffer[len('done\n'):]
 
582
            self.state_accept = self._state_accept_reading_unused
 
583
            self.finished_reading = True
 
584
 
 
585
    def _state_accept_reading_unused(self):
 
586
        self.unused_data += self._get_in_buffer()
 
587
        self._set_in_buffer(None)
 
588
 
 
589
    def _state_read_no_data(self):
 
590
        return ''
 
591
 
 
592
    def _state_read_body_buffer(self):
 
593
        result = self._body
 
594
        self._body = ''
 
595
        return result
 
596
 
 
597
 
 
598
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
 
599
                                    message.ResponseHandler):
 
600
    """The client-side protocol for smart version 1."""
 
601
 
 
602
    def __init__(self, request):
 
603
        """Construct a SmartClientRequestProtocolOne.
 
604
 
 
605
        :param request: A SmartClientMediumRequest to serialise onto and
 
606
            deserialise from.
 
607
        """
 
608
        self._request = request
 
609
        self._body_buffer = None
 
610
        self._request_start_time = None
 
611
        self._last_verb = None
 
612
        self._headers = None
 
613
 
 
614
    def set_headers(self, headers):
 
615
        self._headers = dict(headers)
 
616
 
 
617
    def call(self, *args):
 
618
        if 'hpss' in debug.debug_flags:
 
619
            mutter('hpss call:   %s', repr(args)[1:-1])
 
620
            if getattr(self._request._medium, 'base', None) is not None:
 
621
                mutter('             (to %s)', self._request._medium.base)
 
622
            self._request_start_time = osutils.timer_func()
 
623
        self._write_args(args)
 
624
        self._request.finished_writing()
 
625
        self._last_verb = args[0]
 
626
 
 
627
    def call_with_body_bytes(self, args, body):
 
628
        """Make a remote call of args with body bytes 'body'.
 
629
 
 
630
        After calling this, call read_response_tuple to find the result out.
 
631
        """
 
632
        if 'hpss' in debug.debug_flags:
 
633
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
634
            if getattr(self._request._medium, '_path', None) is not None:
 
635
                mutter('                  (to %s)', self._request._medium._path)
 
636
            mutter('              %d bytes', len(body))
 
637
            self._request_start_time = osutils.timer_func()
 
638
            if 'hpssdetail' in debug.debug_flags:
 
639
                mutter('hpss body content: %s', body)
 
640
        self._write_args(args)
 
641
        bytes = self._encode_bulk_data(body)
 
642
        self._request.accept_bytes(bytes)
 
643
        self._request.finished_writing()
 
644
        self._last_verb = args[0]
 
645
 
 
646
    def call_with_body_readv_array(self, args, body):
 
647
        """Make a remote call with a readv array.
 
648
 
 
649
        The body is encoded with one line per readv offset pair. The numbers in
 
650
        each pair are separated by a comma, and no trailing \n is emitted.
 
651
        """
 
652
        if 'hpss' in debug.debug_flags:
 
653
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
654
            if getattr(self._request._medium, '_path', None) is not None:
 
655
                mutter('                  (to %s)', self._request._medium._path)
 
656
            self._request_start_time = osutils.timer_func()
 
657
        self._write_args(args)
 
658
        readv_bytes = self._serialise_offsets(body)
 
659
        bytes = self._encode_bulk_data(readv_bytes)
 
660
        self._request.accept_bytes(bytes)
 
661
        self._request.finished_writing()
 
662
        if 'hpss' in debug.debug_flags:
 
663
            mutter('              %d bytes in readv request', len(readv_bytes))
 
664
        self._last_verb = args[0]
 
665
 
 
666
    def call_with_body_stream(self, args, stream):
 
667
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
668
        # assume that a v1/v2 server doesn't support whatever method we're
 
669
        # trying to call with a body stream.
 
670
        self._request.finished_writing()
 
671
        self._request.finished_reading()
 
672
        raise errors.UnknownSmartMethod(args[0])
 
673
 
 
674
    def cancel_read_body(self):
 
675
        """After expecting a body, a response code may indicate one otherwise.
 
676
 
 
677
        This method lets the domain client inform the protocol that no body
 
678
        will be transmitted. This is a terminal method: after calling it the
 
679
        protocol is not able to be used further.
 
680
        """
 
681
        self._request.finished_reading()
 
682
 
 
683
    def _read_response_tuple(self):
 
684
        result = self._recv_tuple()
 
685
        if 'hpss' in debug.debug_flags:
 
686
            if self._request_start_time is not None:
 
687
                mutter('   result:   %6.3fs  %s',
 
688
                       osutils.timer_func() - self._request_start_time,
 
689
                       repr(result)[1:-1])
 
690
                self._request_start_time = None
 
691
            else:
 
692
                mutter('   result:   %s', repr(result)[1:-1])
 
693
        return result
 
694
 
 
695
    def read_response_tuple(self, expect_body=False):
 
696
        """Read a response tuple from the wire.
 
697
 
 
698
        This should only be called once.
 
699
        """
 
700
        result = self._read_response_tuple()
 
701
        self._response_is_unknown_method(result)
 
702
        self._raise_args_if_error(result)
 
703
        if not expect_body:
 
704
            self._request.finished_reading()
 
705
        return result
 
706
 
 
707
    def _raise_args_if_error(self, result_tuple):
 
708
        # Later protocol versions have an explicit flag in the protocol to say
 
709
        # if an error response is "failed" or not.  In version 1 we don't have
 
710
        # that luxury.  So here is a complete list of errors that can be
 
711
        # returned in response to existing version 1 smart requests.  Responses
 
712
        # starting with these codes are always "failed" responses.
 
713
        v1_error_codes = [
 
714
            'norepository',
 
715
            'NoSuchFile',
 
716
            'FileExists',
 
717
            'DirectoryNotEmpty',
 
718
            'ShortReadvError',
 
719
            'UnicodeEncodeError',
 
720
            'UnicodeDecodeError',
 
721
            'ReadOnlyError',
 
722
            'nobranch',
 
723
            'NoSuchRevision',
 
724
            'nosuchrevision',
 
725
            'LockContention',
 
726
            'UnlockableTransport',
 
727
            'LockFailed',
 
728
            'TokenMismatch',
 
729
            'ReadError',
 
730
            'PermissionDenied',
 
731
            ]
 
732
        if result_tuple[0] in v1_error_codes:
 
733
            self._request.finished_reading()
 
734
            raise errors.ErrorFromSmartServer(result_tuple)
 
735
 
 
736
    def _response_is_unknown_method(self, result_tuple):
 
737
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
 
738
        method' response to the request.
 
739
 
 
740
        :param response: The response from a smart client call_expecting_body
 
741
            call.
 
742
        :param verb: The verb used in that call.
 
743
        :raises: UnexpectedSmartServerResponse
 
744
        """
 
745
        if (result_tuple == ('error', "Generic bzr smart protocol error: "
 
746
                "bad request '%s'" % self._last_verb) or
 
747
              result_tuple == ('error', "Generic bzr smart protocol error: "
 
748
                "bad request u'%s'" % self._last_verb)):
 
749
            # The response will have no body, so we've finished reading.
 
750
            self._request.finished_reading()
 
751
            raise errors.UnknownSmartMethod(self._last_verb)
 
752
 
 
753
    def read_body_bytes(self, count=-1):
 
754
        """Read bytes from the body, decoding into a byte stream.
 
755
 
 
756
        We read all bytes at once to ensure we've checked the trailer for
 
757
        errors, and then feed the buffer back as read_body_bytes is called.
 
758
        """
 
759
        if self._body_buffer is not None:
 
760
            return self._body_buffer.read(count)
 
761
        _body_decoder = LengthPrefixedBodyDecoder()
 
762
 
 
763
        while not _body_decoder.finished_reading:
 
764
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
765
            if bytes == '':
 
766
                # end of file encountered reading from server
 
767
                raise errors.ConnectionReset(
 
768
                    "Connection lost while reading response body.")
 
769
            _body_decoder.accept_bytes(bytes)
 
770
        self._request.finished_reading()
 
771
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
 
772
        # XXX: TODO check the trailer result.
 
773
        if 'hpss' in debug.debug_flags:
 
774
            mutter('              %d body bytes read',
 
775
                   len(self._body_buffer.getvalue()))
 
776
        return self._body_buffer.read(count)
 
777
 
 
778
    def _recv_tuple(self):
 
779
        """Receive a tuple from the medium request."""
 
780
        return _decode_tuple(self._request.read_line())
 
781
 
 
782
    def query_version(self):
 
783
        """Return protocol version number of the server."""
 
784
        self.call('hello')
 
785
        resp = self.read_response_tuple()
 
786
        if resp == ('ok', '1'):
 
787
            return 1
 
788
        elif resp == ('ok', '2'):
 
789
            return 2
 
790
        else:
 
791
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
792
 
 
793
    def _write_args(self, args):
 
794
        self._write_protocol_version()
 
795
        bytes = _encode_tuple(args)
 
796
        self._request.accept_bytes(bytes)
 
797
 
 
798
    def _write_protocol_version(self):
 
799
        """Write any prefixes this protocol requires.
 
800
 
 
801
        Version one doesn't send protocol versions.
 
802
        """
 
803
 
 
804
 
 
805
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
 
806
    """Version two of the client side of the smart protocol.
 
807
 
 
808
    This prefixes the request with the value of REQUEST_VERSION_TWO.
 
809
    """
 
810
 
 
811
    response_marker = RESPONSE_VERSION_TWO
 
812
    request_marker = REQUEST_VERSION_TWO
 
813
 
 
814
    def read_response_tuple(self, expect_body=False):
 
815
        """Read a response tuple from the wire.
 
816
 
 
817
        This should only be called once.
 
818
        """
 
819
        version = self._request.read_line()
 
820
        if version != self.response_marker:
 
821
            self._request.finished_reading()
 
822
            raise errors.UnexpectedProtocolVersionMarker(version)
 
823
        response_status = self._request.read_line()
 
824
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
 
825
        self._response_is_unknown_method(result)
 
826
        if response_status == 'success\n':
 
827
            self.response_status = True
 
828
            if not expect_body:
 
829
                self._request.finished_reading()
 
830
            return result
 
831
        elif response_status == 'failed\n':
 
832
            self.response_status = False
 
833
            self._request.finished_reading()
 
834
            raise errors.ErrorFromSmartServer(result)
 
835
        else:
 
836
            raise errors.SmartProtocolError(
 
837
                'bad protocol status %r' % response_status)
 
838
 
 
839
    def _write_protocol_version(self):
 
840
        """Write any prefixes this protocol requires.
 
841
 
 
842
        Version two sends the value of REQUEST_VERSION_TWO.
 
843
        """
 
844
        self._request.accept_bytes(self.request_marker)
 
845
 
 
846
    def read_streamed_body(self):
 
847
        """Read bytes from the body, decoding into a byte stream.
 
848
        """
 
849
        # Read no more than 64k at a time so that we don't risk error 10055 (no
 
850
        # buffer space available) on Windows.
 
851
        _body_decoder = ChunkedBodyDecoder()
 
852
        while not _body_decoder.finished_reading:
 
853
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
854
            if bytes == '':
 
855
                # end of file encountered reading from server
 
856
                raise errors.ConnectionReset(
 
857
                    "Connection lost while reading streamed body.")
 
858
            _body_decoder.accept_bytes(bytes)
 
859
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
 
860
                if 'hpss' in debug.debug_flags and type(body_bytes) is str:
 
861
                    mutter('              %d byte chunk read',
 
862
                           len(body_bytes))
 
863
                yield body_bytes
 
864
        self._request.finished_reading()
 
865
 
 
866
 
 
867
def build_server_protocol_three(backing_transport, write_func,
 
868
                                root_client_path, jail_root=None):
 
869
    request_handler = request.SmartServerRequestHandler(
 
870
        backing_transport, commands=request.request_handlers,
 
871
        root_client_path=root_client_path, jail_root=jail_root)
 
872
    responder = ProtocolThreeResponder(write_func)
 
873
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
 
874
    return ProtocolThreeDecoder(message_handler)
 
875
 
 
876
 
 
877
class ProtocolThreeDecoder(_StatefulDecoder):
 
878
 
 
879
    response_marker = RESPONSE_VERSION_THREE
 
880
    request_marker = REQUEST_VERSION_THREE
 
881
 
 
882
    def __init__(self, message_handler, expect_version_marker=False):
 
883
        _StatefulDecoder.__init__(self)
 
884
        self._has_dispatched = False
 
885
        # Initial state
 
886
        if expect_version_marker:
 
887
            self.state_accept = self._state_accept_expecting_protocol_version
 
888
            # We're expecting at least the protocol version marker + some
 
889
            # headers.
 
890
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
 
891
        else:
 
892
            self.state_accept = self._state_accept_expecting_headers
 
893
            self._number_needed_bytes = 4
 
894
        self.decoding_failed = False
 
895
        self.request_handler = self.message_handler = message_handler
 
896
 
 
897
    def accept_bytes(self, bytes):
 
898
        self._number_needed_bytes = None
 
899
        try:
 
900
            _StatefulDecoder.accept_bytes(self, bytes)
 
901
        except KeyboardInterrupt:
 
902
            raise
 
903
        except errors.SmartMessageHandlerError, exception:
 
904
            # We do *not* set self.decoding_failed here.  The message handler
 
905
            # has raised an error, but the decoder is still able to parse bytes
 
906
            # and determine when this message ends.
 
907
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
908
                log_exception_quietly()
 
909
            self.message_handler.protocol_error(exception.exc_value)
 
910
            # The state machine is ready to continue decoding, but the
 
911
            # exception has interrupted the loop that runs the state machine.
 
912
            # So we call accept_bytes again to restart it.
 
913
            self.accept_bytes('')
 
914
        except Exception, exception:
 
915
            # The decoder itself has raised an exception.  We cannot continue
 
916
            # decoding.
 
917
            self.decoding_failed = True
 
918
            if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
 
919
                # This happens during normal operation when the client tries a
 
920
                # protocol version the server doesn't understand, so no need to
 
921
                # log a traceback every time.
 
922
                # Note that this can only happen when
 
923
                # expect_version_marker=True, which is only the case on the
 
924
                # client side.
 
925
                pass
 
926
            else:
 
927
                log_exception_quietly()
 
928
            self.message_handler.protocol_error(exception)
 
929
 
 
930
    def _extract_length_prefixed_bytes(self):
 
931
        if self._in_buffer_len < 4:
 
932
            # A length prefix by itself is 4 bytes, and we don't even have that
 
933
            # many yet.
 
934
            raise _NeedMoreBytes(4)
 
935
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
 
936
        end_of_bytes = 4 + length
 
937
        if self._in_buffer_len < end_of_bytes:
 
938
            # We haven't yet read as many bytes as the length-prefix says there
 
939
            # are.
 
940
            raise _NeedMoreBytes(end_of_bytes)
 
941
        # Extract the bytes from the buffer.
 
942
        in_buf = self._get_in_buffer()
 
943
        bytes = in_buf[4:end_of_bytes]
 
944
        self._set_in_buffer(in_buf[end_of_bytes:])
 
945
        return bytes
 
946
 
 
947
    def _extract_prefixed_bencoded_data(self):
 
948
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
949
        try:
 
950
            decoded = bdecode_as_tuple(prefixed_bytes)
 
951
        except ValueError:
 
952
            raise errors.SmartProtocolError(
 
953
                'Bytes %r not bencoded' % (prefixed_bytes,))
 
954
        return decoded
 
955
 
 
956
    def _extract_single_byte(self):
 
957
        if self._in_buffer_len == 0:
 
958
            # The buffer is empty
 
959
            raise _NeedMoreBytes(1)
 
960
        in_buf = self._get_in_buffer()
 
961
        one_byte = in_buf[0]
 
962
        self._set_in_buffer(in_buf[1:])
 
963
        return one_byte
 
964
 
 
965
    def _state_accept_expecting_protocol_version(self):
 
966
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
 
967
        in_buf = self._get_in_buffer()
 
968
        if needed_bytes > 0:
 
969
            # We don't have enough bytes to check if the protocol version
 
970
            # marker is right.  But we can check if it is already wrong by
 
971
            # checking that the start of MESSAGE_VERSION_THREE matches what
 
972
            # we've read so far.
 
973
            # [In fact, if the remote end isn't bzr we might never receive
 
974
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
 
975
            # are wrong then we should just raise immediately rather than
 
976
            # stall.]
 
977
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
 
978
                # We have enough bytes to know the protocol version is wrong
 
979
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
980
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
 
981
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
 
982
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
983
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
 
984
        self.state_accept = self._state_accept_expecting_headers
 
985
 
 
986
    def _state_accept_expecting_headers(self):
 
987
        decoded = self._extract_prefixed_bencoded_data()
 
988
        if type(decoded) is not dict:
 
989
            raise errors.SmartProtocolError(
 
990
                'Header object %r is not a dict' % (decoded,))
 
991
        self.state_accept = self._state_accept_expecting_message_part
 
992
        try:
 
993
            self.message_handler.headers_received(decoded)
 
994
        except:
 
995
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
996
 
 
997
    def _state_accept_expecting_message_part(self):
 
998
        message_part_kind = self._extract_single_byte()
 
999
        if message_part_kind == 'o':
 
1000
            self.state_accept = self._state_accept_expecting_one_byte
 
1001
        elif message_part_kind == 's':
 
1002
            self.state_accept = self._state_accept_expecting_structure
 
1003
        elif message_part_kind == 'b':
 
1004
            self.state_accept = self._state_accept_expecting_bytes
 
1005
        elif message_part_kind == 'e':
 
1006
            self.done()
 
1007
        else:
 
1008
            raise errors.SmartProtocolError(
 
1009
                'Bad message kind byte: %r' % (message_part_kind,))
 
1010
 
 
1011
    def _state_accept_expecting_one_byte(self):
 
1012
        byte = self._extract_single_byte()
 
1013
        self.state_accept = self._state_accept_expecting_message_part
 
1014
        try:
 
1015
            self.message_handler.byte_part_received(byte)
 
1016
        except:
 
1017
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1018
 
 
1019
    def _state_accept_expecting_bytes(self):
 
1020
        # XXX: this should not buffer whole message part, but instead deliver
 
1021
        # the bytes as they arrive.
 
1022
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
1023
        self.state_accept = self._state_accept_expecting_message_part
 
1024
        try:
 
1025
            self.message_handler.bytes_part_received(prefixed_bytes)
 
1026
        except:
 
1027
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1028
 
 
1029
    def _state_accept_expecting_structure(self):
 
1030
        structure = self._extract_prefixed_bencoded_data()
 
1031
        self.state_accept = self._state_accept_expecting_message_part
 
1032
        try:
 
1033
            self.message_handler.structure_part_received(structure)
 
1034
        except:
 
1035
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1036
 
 
1037
    def done(self):
 
1038
        self.unused_data = self._get_in_buffer()
 
1039
        self._set_in_buffer(None)
 
1040
        self.state_accept = self._state_accept_reading_unused
 
1041
        try:
 
1042
            self.message_handler.end_received()
 
1043
        except:
 
1044
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1045
 
 
1046
    def _state_accept_reading_unused(self):
 
1047
        self.unused_data += self._get_in_buffer()
 
1048
        self._set_in_buffer(None)
 
1049
 
 
1050
    def next_read_size(self):
 
1051
        if self.state_accept == self._state_accept_reading_unused:
 
1052
            return 0
 
1053
        elif self.decoding_failed:
 
1054
            # An exception occured while processing this message, probably from
 
1055
            # self.message_handler.  We're not sure that this state machine is
 
1056
            # in a consistent state, so just signal that we're done (i.e. give
 
1057
            # up).
 
1058
            return 0
 
1059
        else:
 
1060
            if self._number_needed_bytes is not None:
 
1061
                return self._number_needed_bytes - self._in_buffer_len
 
1062
            else:
 
1063
                raise AssertionError("don't know how many bytes are expected!")
 
1064
 
 
1065
 
 
1066
class _ProtocolThreeEncoder(object):
 
1067
 
 
1068
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1069
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
 
1070
 
 
1071
    def __init__(self, write_func):
 
1072
        self._buf = []
 
1073
        self._buf_len = 0
 
1074
        self._real_write_func = write_func
 
1075
 
 
1076
    def _write_func(self, bytes):
 
1077
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
 
1078
        #       for total number of bytes to write, rather than buffer based on
 
1079
        #       the number of write() calls
 
1080
        # TODO: Another possibility would be to turn this into an async model.
 
1081
        #       Where we let another thread know that we have some bytes if
 
1082
        #       they want it, but we don't actually block for it
 
1083
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1084
        #       we might just push out smaller bits at a time?
 
1085
        self._buf.append(bytes)
 
1086
        self._buf_len += len(bytes)
 
1087
        if self._buf_len > self.BUFFER_SIZE:
 
1088
            self.flush()
 
1089
 
 
1090
    def flush(self):
 
1091
        if self._buf:
 
1092
            self._real_write_func(''.join(self._buf))
 
1093
            del self._buf[:]
 
1094
            self._buf_len = 0
 
1095
 
 
1096
    def _serialise_offsets(self, offsets):
 
1097
        """Serialise a readv offset list."""
 
1098
        txt = []
 
1099
        for start, length in offsets:
 
1100
            txt.append('%d,%d' % (start, length))
 
1101
        return '\n'.join(txt)
 
1102
 
 
1103
    def _write_protocol_version(self):
 
1104
        self._write_func(MESSAGE_VERSION_THREE)
 
1105
 
 
1106
    def _write_prefixed_bencode(self, structure):
 
1107
        bytes = bencode(structure)
 
1108
        self._write_func(struct.pack('!L', len(bytes)))
 
1109
        self._write_func(bytes)
 
1110
 
 
1111
    def _write_headers(self, headers):
 
1112
        self._write_prefixed_bencode(headers)
 
1113
 
 
1114
    def _write_structure(self, args):
 
1115
        self._write_func('s')
 
1116
        utf8_args = []
 
1117
        for arg in args:
 
1118
            if type(arg) is unicode:
 
1119
                utf8_args.append(arg.encode('utf8'))
 
1120
            else:
 
1121
                utf8_args.append(arg)
 
1122
        self._write_prefixed_bencode(utf8_args)
 
1123
 
 
1124
    def _write_end(self):
 
1125
        self._write_func('e')
 
1126
        self.flush()
 
1127
 
 
1128
    def _write_prefixed_body(self, bytes):
 
1129
        self._write_func('b')
 
1130
        self._write_func(struct.pack('!L', len(bytes)))
 
1131
        self._write_func(bytes)
 
1132
 
 
1133
    def _write_chunked_body_start(self):
 
1134
        self._write_func('oC')
 
1135
 
 
1136
    def _write_error_status(self):
 
1137
        self._write_func('oE')
 
1138
 
 
1139
    def _write_success_status(self):
 
1140
        self._write_func('oS')
 
1141
 
 
1142
 
 
1143
class ProtocolThreeResponder(_ProtocolThreeEncoder):
 
1144
 
 
1145
    def __init__(self, write_func):
 
1146
        _ProtocolThreeEncoder.__init__(self, write_func)
 
1147
        self.response_sent = False
 
1148
        self._headers = {'Software version': bzrlib.__version__}
 
1149
        if 'hpss' in debug.debug_flags:
 
1150
            self._thread_id = threading.currentThread().get_ident()
 
1151
            self._response_start_time = None
 
1152
 
 
1153
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1154
        if self._response_start_time is None:
 
1155
            self._response_start_time = osutils.timer_func()
 
1156
        if include_time:
 
1157
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1158
        else:
 
1159
            t = ''
 
1160
        if extra_bytes is None:
 
1161
            extra = ''
 
1162
        else:
 
1163
            extra = ' ' + repr(extra_bytes[:40])
 
1164
            if len(extra) > 33:
 
1165
                extra = extra[:29] + extra[-1] + '...'
 
1166
        mutter('%12s: [%s] %s%s%s'
 
1167
               % (action, self._thread_id, t, message, extra))
 
1168
 
 
1169
    def send_error(self, exception):
 
1170
        if self.response_sent:
 
1171
            raise AssertionError(
 
1172
                "send_error(%s) called, but response already sent."
 
1173
                % (exception,))
 
1174
        if isinstance(exception, errors.UnknownSmartMethod):
 
1175
            failure = request.FailedSmartServerResponse(
 
1176
                ('UnknownMethod', exception.verb))
 
1177
            self.send_response(failure)
 
1178
            return
 
1179
        if 'hpss' in debug.debug_flags:
 
1180
            self._trace('error', str(exception))
 
1181
        self.response_sent = True
 
1182
        self._write_protocol_version()
 
1183
        self._write_headers(self._headers)
 
1184
        self._write_error_status()
 
1185
        self._write_structure(('error', str(exception)))
 
1186
        self._write_end()
 
1187
 
 
1188
    def send_response(self, response):
 
1189
        if self.response_sent:
 
1190
            raise AssertionError(
 
1191
                "send_response(%r) called, but response already sent."
 
1192
                % (response,))
 
1193
        self.response_sent = True
 
1194
        self._write_protocol_version()
 
1195
        self._write_headers(self._headers)
 
1196
        if response.is_successful():
 
1197
            self._write_success_status()
 
1198
        else:
 
1199
            self._write_error_status()
 
1200
        if 'hpss' in debug.debug_flags:
 
1201
            self._trace('response', repr(response.args))
 
1202
        self._write_structure(response.args)
 
1203
        if response.body is not None:
 
1204
            self._write_prefixed_body(response.body)
 
1205
            if 'hpss' in debug.debug_flags:
 
1206
                self._trace('body', '%d bytes' % (len(response.body),),
 
1207
                            response.body, include_time=True)
 
1208
        elif response.body_stream is not None:
 
1209
            count = num_bytes = 0
 
1210
            first_chunk = None
 
1211
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1212
                count += 1
 
1213
                if exc_info is not None:
 
1214
                    self._write_error_status()
 
1215
                    error_struct = request._translate_error(exc_info[1])
 
1216
                    self._write_structure(error_struct)
 
1217
                    break
 
1218
                else:
 
1219
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1220
                        self._write_error_status()
 
1221
                        self._write_structure(chunk.args)
 
1222
                        break
 
1223
                    num_bytes += len(chunk)
 
1224
                    if first_chunk is None:
 
1225
                        first_chunk = chunk
 
1226
                    self._write_prefixed_body(chunk)
 
1227
                    if 'hpssdetail' in debug.debug_flags:
 
1228
                        # Not worth timing separately, as _write_func is
 
1229
                        # actually buffered
 
1230
                        self._trace('body chunk',
 
1231
                                    '%d bytes' % (len(chunk),),
 
1232
                                    chunk, suppress_time=True)
 
1233
            if 'hpss' in debug.debug_flags:
 
1234
                self._trace('body stream',
 
1235
                            '%d bytes %d chunks' % (num_bytes, count),
 
1236
                            first_chunk)
 
1237
        self._write_end()
 
1238
        if 'hpss' in debug.debug_flags:
 
1239
            self._trace('response end', '', include_time=True)
 
1240
 
 
1241
 
 
1242
def _iter_with_errors(iterable):
 
1243
    """Handle errors from iterable.next().
 
1244
 
 
1245
    Use like::
 
1246
 
 
1247
        for exc_info, value in _iter_with_errors(iterable):
 
1248
            ...
 
1249
 
 
1250
    This is a safer alternative to::
 
1251
 
 
1252
        try:
 
1253
            for value in iterable:
 
1254
               ...
 
1255
        except:
 
1256
            ...
 
1257
 
 
1258
    Because the latter will catch errors from the for-loop body, not just
 
1259
    iterable.next()
 
1260
 
 
1261
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1262
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1263
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1264
    will not be itercepted.
 
1265
    """
 
1266
    iterator = iter(iterable)
 
1267
    while True:
 
1268
        try:
 
1269
            yield None, iterator.next()
 
1270
        except StopIteration:
 
1271
            return
 
1272
        except (KeyboardInterrupt, SystemExit):
 
1273
            raise
 
1274
        except Exception:
 
1275
            mutter('_iter_with_errors caught error')
 
1276
            log_exception_quietly()
 
1277
            yield sys.exc_info(), None
 
1278
            return
 
1279
 
 
1280
 
 
1281
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
 
1282
 
 
1283
    def __init__(self, medium_request):
 
1284
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
 
1285
        self._medium_request = medium_request
 
1286
        self._headers = {}
 
1287
 
 
1288
    def set_headers(self, headers):
 
1289
        self._headers = headers.copy()
 
1290
 
 
1291
    def call(self, *args):
 
1292
        if 'hpss' in debug.debug_flags:
 
1293
            mutter('hpss call:   %s', repr(args)[1:-1])
 
1294
            base = getattr(self._medium_request._medium, 'base', None)
 
1295
            if base is not None:
 
1296
                mutter('             (to %s)', base)
 
1297
            self._request_start_time = osutils.timer_func()
 
1298
        self._write_protocol_version()
 
1299
        self._write_headers(self._headers)
 
1300
        self._write_structure(args)
 
1301
        self._write_end()
 
1302
        self._medium_request.finished_writing()
 
1303
 
 
1304
    def call_with_body_bytes(self, args, body):
 
1305
        """Make a remote call of args with body bytes 'body'.
 
1306
 
 
1307
        After calling this, call read_response_tuple to find the result out.
 
1308
        """
 
1309
        if 'hpss' in debug.debug_flags:
 
1310
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
1311
            path = getattr(self._medium_request._medium, '_path', None)
 
1312
            if path is not None:
 
1313
                mutter('                  (to %s)', path)
 
1314
            mutter('              %d bytes', len(body))
 
1315
            self._request_start_time = osutils.timer_func()
 
1316
        self._write_protocol_version()
 
1317
        self._write_headers(self._headers)
 
1318
        self._write_structure(args)
 
1319
        self._write_prefixed_body(body)
 
1320
        self._write_end()
 
1321
        self._medium_request.finished_writing()
 
1322
 
 
1323
    def call_with_body_readv_array(self, args, body):
 
1324
        """Make a remote call with a readv array.
 
1325
 
 
1326
        The body is encoded with one line per readv offset pair. The numbers in
 
1327
        each pair are separated by a comma, and no trailing \n is emitted.
 
1328
        """
 
1329
        if 'hpss' in debug.debug_flags:
 
1330
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
1331
            path = getattr(self._medium_request._medium, '_path', None)
 
1332
            if path is not None:
 
1333
                mutter('                  (to %s)', path)
 
1334
            self._request_start_time = osutils.timer_func()
 
1335
        self._write_protocol_version()
 
1336
        self._write_headers(self._headers)
 
1337
        self._write_structure(args)
 
1338
        readv_bytes = self._serialise_offsets(body)
 
1339
        if 'hpss' in debug.debug_flags:
 
1340
            mutter('              %d bytes in readv request', len(readv_bytes))
 
1341
        self._write_prefixed_body(readv_bytes)
 
1342
        self._write_end()
 
1343
        self._medium_request.finished_writing()
 
1344
 
 
1345
    def call_with_body_stream(self, args, stream):
 
1346
        if 'hpss' in debug.debug_flags:
 
1347
            mutter('hpss call w/body stream: %r', args)
 
1348
            path = getattr(self._medium_request._medium, '_path', None)
 
1349
            if path is not None:
 
1350
                mutter('                  (to %s)', path)
 
1351
            self._request_start_time = osutils.timer_func()
 
1352
        self._write_protocol_version()
 
1353
        self._write_headers(self._headers)
 
1354
        self._write_structure(args)
 
1355
        # TODO: notice if the server has sent an early error reply before we
 
1356
        #       have finished sending the stream.  We would notice at the end
 
1357
        #       anyway, but if the medium can deliver it early then it's good
 
1358
        #       to short-circuit the whole request...
 
1359
        for exc_info, part in _iter_with_errors(stream):
 
1360
            if exc_info is not None:
 
1361
                # Iterating the stream failed.  Cleanly abort the request.
 
1362
                self._write_error_status()
 
1363
                # Currently the client unconditionally sends ('error',) as the
 
1364
                # error args.
 
1365
                self._write_structure(('error',))
 
1366
                self._write_end()
 
1367
                self._medium_request.finished_writing()
 
1368
                raise exc_info[0], exc_info[1], exc_info[2]
 
1369
            else:
 
1370
                self._write_prefixed_body(part)
 
1371
                self.flush()
 
1372
        self._write_end()
 
1373
        self._medium_request.finished_writing()
 
1374