/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/bzr/smart/protocol.py

  • Committer: Martin
  • Date: 2018-11-16 19:10:17 UTC
  • mto: This revision was merged to the branch mainline in revision 7177.
  • Revision ID: gzlist@googlemail.com-20181116191017-kyedz1qck0ovon3h
Remove lazy_regexp reset in bt.test_source

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006-2010 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""Wire-level encoding and decoding of requests and responses for the smart
 
18
client and server.
 
19
"""
 
20
 
 
21
from __future__ import absolute_import
 
22
 
 
23
import collections
 
24
import struct
 
25
import sys
 
26
try:
 
27
    import _thread
 
28
except ImportError:
 
29
    import thread as _thread
 
30
import time
 
31
 
 
32
import breezy
 
33
from ... import (
 
34
    debug,
 
35
    errors,
 
36
    osutils,
 
37
    )
 
38
from ...sixish import (
 
39
    BytesIO,
 
40
    reraise,
 
41
)
 
42
from . import message, request
 
43
from ...sixish import text_type
 
44
from ...trace import log_exception_quietly, mutter
 
45
from ...bencode import bdecode_as_tuple, bencode
 
46
 
 
47
 
 
48
# Protocol version strings.  These are sent as prefixes of bzr requests and
 
49
# responses to identify the protocol version being used. (There are no version
 
50
# one strings because that version doesn't send any).
 
51
REQUEST_VERSION_TWO = b'bzr request 2\n'
 
52
RESPONSE_VERSION_TWO = b'bzr response 2\n'
 
53
 
 
54
MESSAGE_VERSION_THREE = b'bzr message 3 (bzr 1.6)\n'
 
55
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
 
56
 
 
57
 
 
58
def _recv_tuple(from_file):
 
59
    req_line = from_file.readline()
 
60
    return _decode_tuple(req_line)
 
61
 
 
62
 
 
63
def _decode_tuple(req_line):
 
64
    if req_line is None or req_line == b'':
 
65
        return None
 
66
    if not req_line.endswith(b'\n'):
 
67
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
68
    return tuple(req_line[:-1].split(b'\x01'))
 
69
 
 
70
 
 
71
def _encode_tuple(args):
 
72
    """Encode the tuple args to a bytestream."""
 
73
    for arg in args:
 
74
        if isinstance(arg, text_type):
 
75
            raise TypeError(args)
 
76
    return b'\x01'.join(args) + b'\n'
 
77
 
 
78
 
 
79
class Requester(object):
 
80
    """Abstract base class for an object that can issue requests on a smart
 
81
    medium.
 
82
    """
 
83
 
 
84
    def call(self, *args):
 
85
        """Make a remote call.
 
86
 
 
87
        :param args: the arguments of this call.
 
88
        """
 
89
        raise NotImplementedError(self.call)
 
90
 
 
91
    def call_with_body_bytes(self, args, body):
 
92
        """Make a remote call with a body.
 
93
 
 
94
        :param args: the arguments of this call.
 
95
        :type body: str
 
96
        :param body: the body to send with the request.
 
97
        """
 
98
        raise NotImplementedError(self.call_with_body_bytes)
 
99
 
 
100
    def call_with_body_readv_array(self, args, body):
 
101
        """Make a remote call with a readv array.
 
102
 
 
103
        :param args: the arguments of this call.
 
104
        :type body: iterable of (start, length) tuples.
 
105
        :param body: the readv ranges to send with this request.
 
106
        """
 
107
        raise NotImplementedError(self.call_with_body_readv_array)
 
108
 
 
109
    def set_headers(self, headers):
 
110
        raise NotImplementedError(self.set_headers)
 
111
 
 
112
 
 
113
class SmartProtocolBase(object):
 
114
    """Methods common to client and server"""
 
115
 
 
116
    # TODO: this only actually accomodates a single block; possibly should
 
117
    # support multiple chunks?
 
118
    def _encode_bulk_data(self, body):
 
119
        """Encode body as a bulk data chunk."""
 
120
        return b''.join((b'%d\n' % len(body), body, b'done\n'))
 
121
 
 
122
    def _serialise_offsets(self, offsets):
 
123
        """Serialise a readv offset list."""
 
124
        txt = []
 
125
        for start, length in offsets:
 
126
            txt.append(b'%d,%d' % (start, length))
 
127
        return b'\n'.join(txt)
 
128
 
 
129
 
 
130
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
131
    """Server-side encoding and decoding logic for smart version 1."""
 
132
 
 
133
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
134
                 jail_root=None):
 
135
        self._backing_transport = backing_transport
 
136
        self._root_client_path = root_client_path
 
137
        self._jail_root = jail_root
 
138
        self.unused_data = b''
 
139
        self._finished = False
 
140
        self.in_buffer = b''
 
141
        self._has_dispatched = False
 
142
        self.request = None
 
143
        self._body_decoder = None
 
144
        self._write_func = write_func
 
145
 
 
146
    def accept_bytes(self, data):
 
147
        """Take bytes, and advance the internal state machine appropriately.
 
148
 
 
149
        :param data: must be a byte string
 
150
        """
 
151
        if not isinstance(data, bytes):
 
152
            raise ValueError(data)
 
153
        self.in_buffer += data
 
154
        if not self._has_dispatched:
 
155
            if b'\n' not in self.in_buffer:
 
156
                # no command line yet
 
157
                return
 
158
            self._has_dispatched = True
 
159
            try:
 
160
                first_line, self.in_buffer = self.in_buffer.split(b'\n', 1)
 
161
                first_line += b'\n'
 
162
                req_args = _decode_tuple(first_line)
 
163
                self.request = request.SmartServerRequestHandler(
 
164
                    self._backing_transport, commands=request.request_handlers,
 
165
                    root_client_path=self._root_client_path,
 
166
                    jail_root=self._jail_root)
 
167
                self.request.args_received(req_args)
 
168
                if self.request.finished_reading:
 
169
                    # trivial request
 
170
                    self.unused_data = self.in_buffer
 
171
                    self.in_buffer = b''
 
172
                    self._send_response(self.request.response)
 
173
            except KeyboardInterrupt:
 
174
                raise
 
175
            except errors.UnknownSmartMethod as err:
 
176
                protocol_error = errors.SmartProtocolError(
 
177
                    "bad request '%s'" % (err.verb.decode('ascii'),))
 
178
                failure = request.FailedSmartServerResponse(
 
179
                    (b'error', str(protocol_error).encode('utf-8')))
 
180
                self._send_response(failure)
 
181
                return
 
182
            except Exception as exception:
 
183
                # everything else: pass to client, flush, and quit
 
184
                log_exception_quietly()
 
185
                self._send_response(request.FailedSmartServerResponse(
 
186
                    (b'error', str(exception).encode('utf-8'))))
 
187
                return
 
188
 
 
189
        if self._has_dispatched:
 
190
            if self._finished:
 
191
                # nothing to do.XXX: this routine should be a single state
 
192
                # machine too.
 
193
                self.unused_data += self.in_buffer
 
194
                self.in_buffer = b''
 
195
                return
 
196
            if self._body_decoder is None:
 
197
                self._body_decoder = LengthPrefixedBodyDecoder()
 
198
            self._body_decoder.accept_bytes(self.in_buffer)
 
199
            self.in_buffer = self._body_decoder.unused_data
 
200
            body_data = self._body_decoder.read_pending_data()
 
201
            self.request.accept_body(body_data)
 
202
            if self._body_decoder.finished_reading:
 
203
                self.request.end_of_body()
 
204
                if not self.request.finished_reading:
 
205
                    raise AssertionError("no more body, request not finished")
 
206
            if self.request.response is not None:
 
207
                self._send_response(self.request.response)
 
208
                self.unused_data = self.in_buffer
 
209
                self.in_buffer = b''
 
210
            else:
 
211
                if self.request.finished_reading:
 
212
                    raise AssertionError(
 
213
                        "no response and we have finished reading.")
 
214
 
 
215
    def _send_response(self, response):
 
216
        """Send a smart server response down the output stream."""
 
217
        if self._finished:
 
218
            raise AssertionError('response already sent')
 
219
        args = response.args
 
220
        body = response.body
 
221
        self._finished = True
 
222
        self._write_protocol_version()
 
223
        self._write_success_or_failure_prefix(response)
 
224
        self._write_func(_encode_tuple(args))
 
225
        if body is not None:
 
226
            if not isinstance(body, bytes):
 
227
                raise ValueError(body)
 
228
            data = self._encode_bulk_data(body)
 
229
            self._write_func(data)
 
230
 
 
231
    def _write_protocol_version(self):
 
232
        """Write any prefixes this protocol requires.
 
233
 
 
234
        Version one doesn't send protocol versions.
 
235
        """
 
236
 
 
237
    def _write_success_or_failure_prefix(self, response):
 
238
        """Write the protocol specific success/failure prefix.
 
239
 
 
240
        For SmartServerRequestProtocolOne this is omitted but we
 
241
        call is_successful to ensure that the response is valid.
 
242
        """
 
243
        response.is_successful()
 
244
 
 
245
    def next_read_size(self):
 
246
        if self._finished:
 
247
            return 0
 
248
        if self._body_decoder is None:
 
249
            return 1
 
250
        else:
 
251
            return self._body_decoder.next_read_size()
 
252
 
 
253
 
 
254
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
 
255
    r"""Version two of the server side of the smart protocol.
 
256
 
 
257
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
 
258
    """
 
259
 
 
260
    response_marker = RESPONSE_VERSION_TWO
 
261
    request_marker = REQUEST_VERSION_TWO
 
262
 
 
263
    def _write_success_or_failure_prefix(self, response):
 
264
        """Write the protocol specific success/failure prefix."""
 
265
        if response.is_successful():
 
266
            self._write_func(b'success\n')
 
267
        else:
 
268
            self._write_func(b'failed\n')
 
269
 
 
270
    def _write_protocol_version(self):
 
271
        r"""Write any prefixes this protocol requires.
 
272
 
 
273
        Version two sends the value of RESPONSE_VERSION_TWO.
 
274
        """
 
275
        self._write_func(self.response_marker)
 
276
 
 
277
    def _send_response(self, response):
 
278
        """Send a smart server response down the output stream."""
 
279
        if (self._finished):
 
280
            raise AssertionError('response already sent')
 
281
        self._finished = True
 
282
        self._write_protocol_version()
 
283
        self._write_success_or_failure_prefix(response)
 
284
        self._write_func(_encode_tuple(response.args))
 
285
        if response.body is not None:
 
286
            if not isinstance(response.body, bytes):
 
287
                raise AssertionError('body must be bytes')
 
288
            if not (response.body_stream is None):
 
289
                raise AssertionError(
 
290
                    'body_stream and body cannot both be set')
 
291
            data = self._encode_bulk_data(response.body)
 
292
            self._write_func(data)
 
293
        elif response.body_stream is not None:
 
294
            _send_stream(response.body_stream, self._write_func)
 
295
 
 
296
 
 
297
def _send_stream(stream, write_func):
 
298
    write_func(b'chunked\n')
 
299
    _send_chunks(stream, write_func)
 
300
    write_func(b'END\n')
 
301
 
 
302
 
 
303
def _send_chunks(stream, write_func):
 
304
    for chunk in stream:
 
305
        if isinstance(chunk, bytes):
 
306
            data = ("%x\n" % len(chunk)).encode('ascii') + chunk
 
307
            write_func(data)
 
308
        elif isinstance(chunk, request.FailedSmartServerResponse):
 
309
            write_func(b'ERR\n')
 
310
            _send_chunks(chunk.args, write_func)
 
311
            return
 
312
        else:
 
313
            raise errors.BzrError(
 
314
                'Chunks must be str or FailedSmartServerResponse, got %r'
 
315
                % chunk)
 
316
 
 
317
 
 
318
class _NeedMoreBytes(Exception):
 
319
    """Raise this inside a _StatefulDecoder to stop decoding until more bytes
 
320
    have been received.
 
321
    """
 
322
 
 
323
    def __init__(self, count=None):
 
324
        """Constructor.
 
325
 
 
326
        :param count: the total number of bytes needed by the current state.
 
327
            May be None if the number of bytes needed is unknown.
 
328
        """
 
329
        self.count = count
 
330
 
 
331
 
 
332
class _StatefulDecoder(object):
 
333
    """Base class for writing state machines to decode byte streams.
 
334
 
 
335
    Subclasses should provide a self.state_accept attribute that accepts bytes
 
336
    and, if appropriate, updates self.state_accept to a different function.
 
337
    accept_bytes will call state_accept as often as necessary to make sure the
 
338
    state machine has progressed as far as possible before it returns.
 
339
 
 
340
    See ProtocolThreeDecoder for an example subclass.
 
341
    """
 
342
 
 
343
    def __init__(self):
 
344
        self.finished_reading = False
 
345
        self._in_buffer_list = []
 
346
        self._in_buffer_len = 0
 
347
        self.unused_data = b''
 
348
        self.bytes_left = None
 
349
        self._number_needed_bytes = None
 
350
 
 
351
    def _get_in_buffer(self):
 
352
        if len(self._in_buffer_list) == 1:
 
353
            return self._in_buffer_list[0]
 
354
        in_buffer = b''.join(self._in_buffer_list)
 
355
        if len(in_buffer) != self._in_buffer_len:
 
356
            raise AssertionError(
 
357
                "Length of buffer did not match expected value: %s != %s"
 
358
                % self._in_buffer_len, len(in_buffer))
 
359
        self._in_buffer_list = [in_buffer]
 
360
        return in_buffer
 
361
 
 
362
    def _get_in_bytes(self, count):
 
363
        """Grab X bytes from the input_buffer.
 
364
 
 
365
        Callers should have already checked that self._in_buffer_len is >
 
366
        count. Note, this does not consume the bytes from the buffer. The
 
367
        caller will still need to call _get_in_buffer() and then
 
368
        _set_in_buffer() if they actually need to consume the bytes.
 
369
        """
 
370
        # check if we can yield the bytes from just the first entry in our list
 
371
        if len(self._in_buffer_list) == 0:
 
372
            raise AssertionError('Callers must be sure we have buffered bytes'
 
373
                                 ' before calling _get_in_bytes')
 
374
        if len(self._in_buffer_list[0]) > count:
 
375
            return self._in_buffer_list[0][:count]
 
376
        # We can't yield it from the first buffer, so collapse all buffers, and
 
377
        # yield it from that
 
378
        in_buf = self._get_in_buffer()
 
379
        return in_buf[:count]
 
380
 
 
381
    def _set_in_buffer(self, new_buf):
 
382
        if new_buf is not None:
 
383
            if not isinstance(new_buf, bytes):
 
384
                raise TypeError(new_buf)
 
385
            self._in_buffer_list = [new_buf]
 
386
            self._in_buffer_len = len(new_buf)
 
387
        else:
 
388
            self._in_buffer_list = []
 
389
            self._in_buffer_len = 0
 
390
 
 
391
    def accept_bytes(self, new_buf):
 
392
        """Decode as much of bytes as possible.
 
393
 
 
394
        If 'new_buf' contains too much data it will be appended to
 
395
        self.unused_data.
 
396
 
 
397
        finished_reading will be set when no more data is required.  Further
 
398
        data will be appended to self.unused_data.
 
399
        """
 
400
        if not isinstance(new_buf, bytes):
 
401
            raise TypeError(new_buf)
 
402
        # accept_bytes is allowed to change the state
 
403
        self._number_needed_bytes = None
 
404
        # lsprof puts a very large amount of time on this specific call for
 
405
        # large readv arrays
 
406
        self._in_buffer_list.append(new_buf)
 
407
        self._in_buffer_len += len(new_buf)
 
408
        try:
 
409
            # Run the function for the current state.
 
410
            current_state = self.state_accept
 
411
            self.state_accept()
 
412
            while current_state != self.state_accept:
 
413
                # The current state has changed.  Run the function for the new
 
414
                # current state, so that it can:
 
415
                #   - decode any unconsumed bytes left in a buffer, and
 
416
                #   - signal how many more bytes are expected (via raising
 
417
                #     _NeedMoreBytes).
 
418
                current_state = self.state_accept
 
419
                self.state_accept()
 
420
        except _NeedMoreBytes as e:
 
421
            self._number_needed_bytes = e.count
 
422
 
 
423
 
 
424
class ChunkedBodyDecoder(_StatefulDecoder):
 
425
    """Decoder for chunked body data.
 
426
 
 
427
    This is very similar the HTTP's chunked encoding.  See the description of
 
428
    streamed body data in `doc/developers/network-protocol.txt` for details.
 
429
    """
 
430
 
 
431
    def __init__(self):
 
432
        _StatefulDecoder.__init__(self)
 
433
        self.state_accept = self._state_accept_expecting_header
 
434
        self.chunk_in_progress = None
 
435
        self.chunks = collections.deque()
 
436
        self.error = False
 
437
        self.error_in_progress = None
 
438
 
 
439
    def next_read_size(self):
 
440
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
 
441
        # end-of-body marker is 4 bytes: 'END\n'.
 
442
        if self.state_accept == self._state_accept_reading_chunk:
 
443
            # We're expecting more chunk content.  So we're expecting at least
 
444
            # the rest of this chunk plus an END chunk.
 
445
            return self.bytes_left + 4
 
446
        elif self.state_accept == self._state_accept_expecting_length:
 
447
            if self._in_buffer_len == 0:
 
448
                # We're expecting a chunk length.  There's at least two bytes
 
449
                # left: a digit plus '\n'.
 
450
                return 2
 
451
            else:
 
452
                # We're in the middle of reading a chunk length.  So there's at
 
453
                # least one byte left, the '\n' that terminates the length.
 
454
                return 1
 
455
        elif self.state_accept == self._state_accept_reading_unused:
 
456
            return 1
 
457
        elif self.state_accept == self._state_accept_expecting_header:
 
458
            return max(0, len('chunked\n') - self._in_buffer_len)
 
459
        else:
 
460
            raise AssertionError("Impossible state: %r" % (self.state_accept,))
 
461
 
 
462
    def read_next_chunk(self):
 
463
        try:
 
464
            return self.chunks.popleft()
 
465
        except IndexError:
 
466
            return None
 
467
 
 
468
    def _extract_line(self):
 
469
        in_buf = self._get_in_buffer()
 
470
        pos = in_buf.find(b'\n')
 
471
        if pos == -1:
 
472
            # We haven't read a complete line yet, so request more bytes before
 
473
            # we continue.
 
474
            raise _NeedMoreBytes(1)
 
475
        line = in_buf[:pos]
 
476
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
 
477
        self._set_in_buffer(in_buf[pos + 1:])
 
478
        return line
 
479
 
 
480
    def _finished(self):
 
481
        self.unused_data = self._get_in_buffer()
 
482
        self._in_buffer_list = []
 
483
        self._in_buffer_len = 0
 
484
        self.state_accept = self._state_accept_reading_unused
 
485
        if self.error:
 
486
            error_args = tuple(self.error_in_progress)
 
487
            self.chunks.append(request.FailedSmartServerResponse(error_args))
 
488
            self.error_in_progress = None
 
489
        self.finished_reading = True
 
490
 
 
491
    def _state_accept_expecting_header(self):
 
492
        prefix = self._extract_line()
 
493
        if prefix == b'chunked':
 
494
            self.state_accept = self._state_accept_expecting_length
 
495
        else:
 
496
            raise errors.SmartProtocolError(
 
497
                'Bad chunked body header: "%s"' % (prefix,))
 
498
 
 
499
    def _state_accept_expecting_length(self):
 
500
        prefix = self._extract_line()
 
501
        if prefix == b'ERR':
 
502
            self.error = True
 
503
            self.error_in_progress = []
 
504
            self._state_accept_expecting_length()
 
505
            return
 
506
        elif prefix == b'END':
 
507
            # We've read the end-of-body marker.
 
508
            # Any further bytes are unused data, including the bytes left in
 
509
            # the _in_buffer.
 
510
            self._finished()
 
511
            return
 
512
        else:
 
513
            self.bytes_left = int(prefix, 16)
 
514
            self.chunk_in_progress = b''
 
515
            self.state_accept = self._state_accept_reading_chunk
 
516
 
 
517
    def _state_accept_reading_chunk(self):
 
518
        in_buf = self._get_in_buffer()
 
519
        in_buffer_len = len(in_buf)
 
520
        self.chunk_in_progress += in_buf[:self.bytes_left]
 
521
        self._set_in_buffer(in_buf[self.bytes_left:])
 
522
        self.bytes_left -= in_buffer_len
 
523
        if self.bytes_left <= 0:
 
524
            # Finished with chunk
 
525
            self.bytes_left = None
 
526
            if self.error:
 
527
                self.error_in_progress.append(self.chunk_in_progress)
 
528
            else:
 
529
                self.chunks.append(self.chunk_in_progress)
 
530
            self.chunk_in_progress = None
 
531
            self.state_accept = self._state_accept_expecting_length
 
532
 
 
533
    def _state_accept_reading_unused(self):
 
534
        self.unused_data += self._get_in_buffer()
 
535
        self._in_buffer_list = []
 
536
 
 
537
 
 
538
class LengthPrefixedBodyDecoder(_StatefulDecoder):
 
539
    """Decodes the length-prefixed bulk data."""
 
540
 
 
541
    def __init__(self):
 
542
        _StatefulDecoder.__init__(self)
 
543
        self.state_accept = self._state_accept_expecting_length
 
544
        self.state_read = self._state_read_no_data
 
545
        self._body = b''
 
546
        self._trailer_buffer = b''
 
547
 
 
548
    def next_read_size(self):
 
549
        if self.bytes_left is not None:
 
550
            # Ideally we want to read all the remainder of the body and the
 
551
            # trailer in one go.
 
552
            return self.bytes_left + 5
 
553
        elif self.state_accept == self._state_accept_reading_trailer:
 
554
            # Just the trailer left
 
555
            return 5 - len(self._trailer_buffer)
 
556
        elif self.state_accept == self._state_accept_expecting_length:
 
557
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
558
            # 'done\n').
 
559
            return 6
 
560
        else:
 
561
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
562
            return 1
 
563
 
 
564
    def read_pending_data(self):
 
565
        """Return any pending data that has been decoded."""
 
566
        return self.state_read()
 
567
 
 
568
    def _state_accept_expecting_length(self):
 
569
        in_buf = self._get_in_buffer()
 
570
        pos = in_buf.find(b'\n')
 
571
        if pos == -1:
 
572
            return
 
573
        self.bytes_left = int(in_buf[:pos])
 
574
        self._set_in_buffer(in_buf[pos + 1:])
 
575
        self.state_accept = self._state_accept_reading_body
 
576
        self.state_read = self._state_read_body_buffer
 
577
 
 
578
    def _state_accept_reading_body(self):
 
579
        in_buf = self._get_in_buffer()
 
580
        self._body += in_buf
 
581
        self.bytes_left -= len(in_buf)
 
582
        self._set_in_buffer(None)
 
583
        if self.bytes_left <= 0:
 
584
            # Finished with body
 
585
            if self.bytes_left != 0:
 
586
                self._trailer_buffer = self._body[self.bytes_left:]
 
587
                self._body = self._body[:self.bytes_left]
 
588
            self.bytes_left = None
 
589
            self.state_accept = self._state_accept_reading_trailer
 
590
 
 
591
    def _state_accept_reading_trailer(self):
 
592
        self._trailer_buffer += self._get_in_buffer()
 
593
        self._set_in_buffer(None)
 
594
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
595
        # a ProtocolViolation exception?
 
596
        if self._trailer_buffer.startswith(b'done\n'):
 
597
            self.unused_data = self._trailer_buffer[len(b'done\n'):]
 
598
            self.state_accept = self._state_accept_reading_unused
 
599
            self.finished_reading = True
 
600
 
 
601
    def _state_accept_reading_unused(self):
 
602
        self.unused_data += self._get_in_buffer()
 
603
        self._set_in_buffer(None)
 
604
 
 
605
    def _state_read_no_data(self):
 
606
        return b''
 
607
 
 
608
    def _state_read_body_buffer(self):
 
609
        result = self._body
 
610
        self._body = b''
 
611
        return result
 
612
 
 
613
 
 
614
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
 
615
                                    message.ResponseHandler):
 
616
    """The client-side protocol for smart version 1."""
 
617
 
 
618
    def __init__(self, request):
 
619
        """Construct a SmartClientRequestProtocolOne.
 
620
 
 
621
        :param request: A SmartClientMediumRequest to serialise onto and
 
622
            deserialise from.
 
623
        """
 
624
        self._request = request
 
625
        self._body_buffer = None
 
626
        self._request_start_time = None
 
627
        self._last_verb = None
 
628
        self._headers = None
 
629
 
 
630
    def set_headers(self, headers):
 
631
        self._headers = dict(headers)
 
632
 
 
633
    def call(self, *args):
 
634
        if 'hpss' in debug.debug_flags:
 
635
            mutter('hpss call:   %s', repr(args)[1:-1])
 
636
            if getattr(self._request._medium, 'base', None) is not None:
 
637
                mutter('             (to %s)', self._request._medium.base)
 
638
            self._request_start_time = osutils.timer_func()
 
639
        self._write_args(args)
 
640
        self._request.finished_writing()
 
641
        self._last_verb = args[0]
 
642
 
 
643
    def call_with_body_bytes(self, args, body):
 
644
        """Make a remote call of args with body bytes 'body'.
 
645
 
 
646
        After calling this, call read_response_tuple to find the result out.
 
647
        """
 
648
        if 'hpss' in debug.debug_flags:
 
649
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
650
            if getattr(self._request._medium, '_path', None) is not None:
 
651
                mutter('                  (to %s)',
 
652
                       self._request._medium._path)
 
653
            mutter('              %d bytes', len(body))
 
654
            self._request_start_time = osutils.timer_func()
 
655
            if 'hpssdetail' in debug.debug_flags:
 
656
                mutter('hpss body content: %s', body)
 
657
        self._write_args(args)
 
658
        bytes = self._encode_bulk_data(body)
 
659
        self._request.accept_bytes(bytes)
 
660
        self._request.finished_writing()
 
661
        self._last_verb = args[0]
 
662
 
 
663
    def call_with_body_readv_array(self, args, body):
 
664
        """Make a remote call with a readv array.
 
665
 
 
666
        The body is encoded with one line per readv offset pair. The numbers in
 
667
        each pair are separated by a comma, and no trailing \\n is emitted.
 
668
        """
 
669
        if 'hpss' in debug.debug_flags:
 
670
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
671
            if getattr(self._request._medium, '_path', None) is not None:
 
672
                mutter('                  (to %s)',
 
673
                       self._request._medium._path)
 
674
            self._request_start_time = osutils.timer_func()
 
675
        self._write_args(args)
 
676
        readv_bytes = self._serialise_offsets(body)
 
677
        bytes = self._encode_bulk_data(readv_bytes)
 
678
        self._request.accept_bytes(bytes)
 
679
        self._request.finished_writing()
 
680
        if 'hpss' in debug.debug_flags:
 
681
            mutter('              %d bytes in readv request', len(readv_bytes))
 
682
        self._last_verb = args[0]
 
683
 
 
684
    def call_with_body_stream(self, args, stream):
 
685
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
686
        # assume that a v1/v2 server doesn't support whatever method we're
 
687
        # trying to call with a body stream.
 
688
        self._request.finished_writing()
 
689
        self._request.finished_reading()
 
690
        raise errors.UnknownSmartMethod(args[0])
 
691
 
 
692
    def cancel_read_body(self):
 
693
        """After expecting a body, a response code may indicate one otherwise.
 
694
 
 
695
        This method lets the domain client inform the protocol that no body
 
696
        will be transmitted. This is a terminal method: after calling it the
 
697
        protocol is not able to be used further.
 
698
        """
 
699
        self._request.finished_reading()
 
700
 
 
701
    def _read_response_tuple(self):
 
702
        result = self._recv_tuple()
 
703
        if 'hpss' in debug.debug_flags:
 
704
            if self._request_start_time is not None:
 
705
                mutter('   result:   %6.3fs  %s',
 
706
                       osutils.timer_func() - self._request_start_time,
 
707
                       repr(result)[1:-1])
 
708
                self._request_start_time = None
 
709
            else:
 
710
                mutter('   result:   %s', repr(result)[1:-1])
 
711
        return result
 
712
 
 
713
    def read_response_tuple(self, expect_body=False):
 
714
        """Read a response tuple from the wire.
 
715
 
 
716
        This should only be called once.
 
717
        """
 
718
        result = self._read_response_tuple()
 
719
        self._response_is_unknown_method(result)
 
720
        self._raise_args_if_error(result)
 
721
        if not expect_body:
 
722
            self._request.finished_reading()
 
723
        return result
 
724
 
 
725
    def _raise_args_if_error(self, result_tuple):
 
726
        # Later protocol versions have an explicit flag in the protocol to say
 
727
        # if an error response is "failed" or not.  In version 1 we don't have
 
728
        # that luxury.  So here is a complete list of errors that can be
 
729
        # returned in response to existing version 1 smart requests.  Responses
 
730
        # starting with these codes are always "failed" responses.
 
731
        v1_error_codes = [
 
732
            b'norepository',
 
733
            b'NoSuchFile',
 
734
            b'FileExists',
 
735
            b'DirectoryNotEmpty',
 
736
            b'ShortReadvError',
 
737
            b'UnicodeEncodeError',
 
738
            b'UnicodeDecodeError',
 
739
            b'ReadOnlyError',
 
740
            b'nobranch',
 
741
            b'NoSuchRevision',
 
742
            b'nosuchrevision',
 
743
            b'LockContention',
 
744
            b'UnlockableTransport',
 
745
            b'LockFailed',
 
746
            b'TokenMismatch',
 
747
            b'ReadError',
 
748
            b'PermissionDenied',
 
749
            ]
 
750
        if result_tuple[0] in v1_error_codes:
 
751
            self._request.finished_reading()
 
752
            raise errors.ErrorFromSmartServer(result_tuple)
 
753
 
 
754
    def _response_is_unknown_method(self, result_tuple):
 
755
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
 
756
        method' response to the request.
 
757
 
 
758
        :param response: The response from a smart client call_expecting_body
 
759
            call.
 
760
        :param verb: The verb used in that call.
 
761
        :raises: UnexpectedSmartServerResponse
 
762
        """
 
763
        if (result_tuple == (b'error', b"Generic bzr smart protocol error: "
 
764
                             b"bad request '" + self._last_verb + b"'")
 
765
            or result_tuple == (b'error', b"Generic bzr smart protocol error: "
 
766
                                b"bad request u'%s'" % self._last_verb)):
 
767
            # The response will have no body, so we've finished reading.
 
768
            self._request.finished_reading()
 
769
            raise errors.UnknownSmartMethod(self._last_verb)
 
770
 
 
771
    def read_body_bytes(self, count=-1):
 
772
        """Read bytes from the body, decoding into a byte stream.
 
773
 
 
774
        We read all bytes at once to ensure we've checked the trailer for
 
775
        errors, and then feed the buffer back as read_body_bytes is called.
 
776
        """
 
777
        if self._body_buffer is not None:
 
778
            return self._body_buffer.read(count)
 
779
        _body_decoder = LengthPrefixedBodyDecoder()
 
780
 
 
781
        while not _body_decoder.finished_reading:
 
782
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
783
            if bytes == b'':
 
784
                # end of file encountered reading from server
 
785
                raise errors.ConnectionReset(
 
786
                    "Connection lost while reading response body.")
 
787
            _body_decoder.accept_bytes(bytes)
 
788
        self._request.finished_reading()
 
789
        self._body_buffer = BytesIO(_body_decoder.read_pending_data())
 
790
        # XXX: TODO check the trailer result.
 
791
        if 'hpss' in debug.debug_flags:
 
792
            mutter('              %d body bytes read',
 
793
                   len(self._body_buffer.getvalue()))
 
794
        return self._body_buffer.read(count)
 
795
 
 
796
    def _recv_tuple(self):
 
797
        """Receive a tuple from the medium request."""
 
798
        return _decode_tuple(self._request.read_line())
 
799
 
 
800
    def query_version(self):
 
801
        """Return protocol version number of the server."""
 
802
        self.call(b'hello')
 
803
        resp = self.read_response_tuple()
 
804
        if resp == (b'ok', b'1'):
 
805
            return 1
 
806
        elif resp == (b'ok', b'2'):
 
807
            return 2
 
808
        else:
 
809
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
810
 
 
811
    def _write_args(self, args):
 
812
        self._write_protocol_version()
 
813
        bytes = _encode_tuple(args)
 
814
        self._request.accept_bytes(bytes)
 
815
 
 
816
    def _write_protocol_version(self):
 
817
        """Write any prefixes this protocol requires.
 
818
 
 
819
        Version one doesn't send protocol versions.
 
820
        """
 
821
 
 
822
 
 
823
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
 
824
    """Version two of the client side of the smart protocol.
 
825
 
 
826
    This prefixes the request with the value of REQUEST_VERSION_TWO.
 
827
    """
 
828
 
 
829
    response_marker = RESPONSE_VERSION_TWO
 
830
    request_marker = REQUEST_VERSION_TWO
 
831
 
 
832
    def read_response_tuple(self, expect_body=False):
 
833
        """Read a response tuple from the wire.
 
834
 
 
835
        This should only be called once.
 
836
        """
 
837
        version = self._request.read_line()
 
838
        if version != self.response_marker:
 
839
            self._request.finished_reading()
 
840
            raise errors.UnexpectedProtocolVersionMarker(version)
 
841
        response_status = self._request.read_line()
 
842
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
 
843
        self._response_is_unknown_method(result)
 
844
        if response_status == b'success\n':
 
845
            self.response_status = True
 
846
            if not expect_body:
 
847
                self._request.finished_reading()
 
848
            return result
 
849
        elif response_status == b'failed\n':
 
850
            self.response_status = False
 
851
            self._request.finished_reading()
 
852
            raise errors.ErrorFromSmartServer(result)
 
853
        else:
 
854
            raise errors.SmartProtocolError(
 
855
                'bad protocol status %r' % response_status)
 
856
 
 
857
    def _write_protocol_version(self):
 
858
        """Write any prefixes this protocol requires.
 
859
 
 
860
        Version two sends the value of REQUEST_VERSION_TWO.
 
861
        """
 
862
        self._request.accept_bytes(self.request_marker)
 
863
 
 
864
    def read_streamed_body(self):
 
865
        """Read bytes from the body, decoding into a byte stream.
 
866
        """
 
867
        # Read no more than 64k at a time so that we don't risk error 10055 (no
 
868
        # buffer space available) on Windows.
 
869
        _body_decoder = ChunkedBodyDecoder()
 
870
        while not _body_decoder.finished_reading:
 
871
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
872
            if bytes == b'':
 
873
                # end of file encountered reading from server
 
874
                raise errors.ConnectionReset(
 
875
                    "Connection lost while reading streamed body.")
 
876
            _body_decoder.accept_bytes(bytes)
 
877
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
 
878
                if 'hpss' in debug.debug_flags and isinstance(body_bytes, str):
 
879
                    mutter('              %d byte chunk read',
 
880
                           len(body_bytes))
 
881
                yield body_bytes
 
882
        self._request.finished_reading()
 
883
 
 
884
 
 
885
def build_server_protocol_three(backing_transport, write_func,
 
886
                                root_client_path, jail_root=None):
 
887
    request_handler = request.SmartServerRequestHandler(
 
888
        backing_transport, commands=request.request_handlers,
 
889
        root_client_path=root_client_path, jail_root=jail_root)
 
890
    responder = ProtocolThreeResponder(write_func)
 
891
    message_handler = message.ConventionalRequestHandler(
 
892
        request_handler, responder)
 
893
    return ProtocolThreeDecoder(message_handler)
 
894
 
 
895
 
 
896
class ProtocolThreeDecoder(_StatefulDecoder):
 
897
 
 
898
    response_marker = RESPONSE_VERSION_THREE
 
899
    request_marker = REQUEST_VERSION_THREE
 
900
 
 
901
    def __init__(self, message_handler, expect_version_marker=False):
 
902
        _StatefulDecoder.__init__(self)
 
903
        self._has_dispatched = False
 
904
        # Initial state
 
905
        if expect_version_marker:
 
906
            self.state_accept = self._state_accept_expecting_protocol_version
 
907
            # We're expecting at least the protocol version marker + some
 
908
            # headers.
 
909
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
 
910
        else:
 
911
            self.state_accept = self._state_accept_expecting_headers
 
912
            self._number_needed_bytes = 4
 
913
        self.decoding_failed = False
 
914
        self.request_handler = self.message_handler = message_handler
 
915
 
 
916
    def accept_bytes(self, bytes):
 
917
        self._number_needed_bytes = None
 
918
        try:
 
919
            _StatefulDecoder.accept_bytes(self, bytes)
 
920
        except KeyboardInterrupt:
 
921
            raise
 
922
        except errors.SmartMessageHandlerError as exception:
 
923
            # We do *not* set self.decoding_failed here.  The message handler
 
924
            # has raised an error, but the decoder is still able to parse bytes
 
925
            # and determine when this message ends.
 
926
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
927
                log_exception_quietly()
 
928
            self.message_handler.protocol_error(exception.exc_value)
 
929
            # The state machine is ready to continue decoding, but the
 
930
            # exception has interrupted the loop that runs the state machine.
 
931
            # So we call accept_bytes again to restart it.
 
932
            self.accept_bytes(b'')
 
933
        except Exception as exception:
 
934
            # The decoder itself has raised an exception.  We cannot continue
 
935
            # decoding.
 
936
            self.decoding_failed = True
 
937
            if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
 
938
                # This happens during normal operation when the client tries a
 
939
                # protocol version the server doesn't understand, so no need to
 
940
                # log a traceback every time.
 
941
                # Note that this can only happen when
 
942
                # expect_version_marker=True, which is only the case on the
 
943
                # client side.
 
944
                pass
 
945
            else:
 
946
                log_exception_quietly()
 
947
            self.message_handler.protocol_error(exception)
 
948
 
 
949
    def _extract_length_prefixed_bytes(self):
 
950
        if self._in_buffer_len < 4:
 
951
            # A length prefix by itself is 4 bytes, and we don't even have that
 
952
            # many yet.
 
953
            raise _NeedMoreBytes(4)
 
954
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
 
955
        end_of_bytes = 4 + length
 
956
        if self._in_buffer_len < end_of_bytes:
 
957
            # We haven't yet read as many bytes as the length-prefix says there
 
958
            # are.
 
959
            raise _NeedMoreBytes(end_of_bytes)
 
960
        # Extract the bytes from the buffer.
 
961
        in_buf = self._get_in_buffer()
 
962
        bytes = in_buf[4:end_of_bytes]
 
963
        self._set_in_buffer(in_buf[end_of_bytes:])
 
964
        return bytes
 
965
 
 
966
    def _extract_prefixed_bencoded_data(self):
 
967
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
968
        try:
 
969
            decoded = bdecode_as_tuple(prefixed_bytes)
 
970
        except ValueError:
 
971
            raise errors.SmartProtocolError(
 
972
                'Bytes %r not bencoded' % (prefixed_bytes,))
 
973
        return decoded
 
974
 
 
975
    def _extract_single_byte(self):
 
976
        if self._in_buffer_len == 0:
 
977
            # The buffer is empty
 
978
            raise _NeedMoreBytes(1)
 
979
        in_buf = self._get_in_buffer()
 
980
        one_byte = in_buf[0:1]
 
981
        self._set_in_buffer(in_buf[1:])
 
982
        return one_byte
 
983
 
 
984
    def _state_accept_expecting_protocol_version(self):
 
985
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
 
986
        in_buf = self._get_in_buffer()
 
987
        if needed_bytes > 0:
 
988
            # We don't have enough bytes to check if the protocol version
 
989
            # marker is right.  But we can check if it is already wrong by
 
990
            # checking that the start of MESSAGE_VERSION_THREE matches what
 
991
            # we've read so far.
 
992
            # [In fact, if the remote end isn't bzr we might never receive
 
993
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
 
994
            # are wrong then we should just raise immediately rather than
 
995
            # stall.]
 
996
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
 
997
                # We have enough bytes to know the protocol version is wrong
 
998
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
999
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
 
1000
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
 
1001
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
1002
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
 
1003
        self.state_accept = self._state_accept_expecting_headers
 
1004
 
 
1005
    def _state_accept_expecting_headers(self):
 
1006
        decoded = self._extract_prefixed_bencoded_data()
 
1007
        if not isinstance(decoded, dict):
 
1008
            raise errors.SmartProtocolError(
 
1009
                'Header object %r is not a dict' % (decoded,))
 
1010
        self.state_accept = self._state_accept_expecting_message_part
 
1011
        try:
 
1012
            self.message_handler.headers_received(decoded)
 
1013
        except:
 
1014
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1015
 
 
1016
    def _state_accept_expecting_message_part(self):
 
1017
        message_part_kind = self._extract_single_byte()
 
1018
        if message_part_kind == b'o':
 
1019
            self.state_accept = self._state_accept_expecting_one_byte
 
1020
        elif message_part_kind == b's':
 
1021
            self.state_accept = self._state_accept_expecting_structure
 
1022
        elif message_part_kind == b'b':
 
1023
            self.state_accept = self._state_accept_expecting_bytes
 
1024
        elif message_part_kind == b'e':
 
1025
            self.done()
 
1026
        else:
 
1027
            raise errors.SmartProtocolError(
 
1028
                'Bad message kind byte: %r' % (message_part_kind,))
 
1029
 
 
1030
    def _state_accept_expecting_one_byte(self):
 
1031
        byte = self._extract_single_byte()
 
1032
        self.state_accept = self._state_accept_expecting_message_part
 
1033
        try:
 
1034
            self.message_handler.byte_part_received(byte)
 
1035
        except:
 
1036
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1037
 
 
1038
    def _state_accept_expecting_bytes(self):
 
1039
        # XXX: this should not buffer whole message part, but instead deliver
 
1040
        # the bytes as they arrive.
 
1041
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
1042
        self.state_accept = self._state_accept_expecting_message_part
 
1043
        try:
 
1044
            self.message_handler.bytes_part_received(prefixed_bytes)
 
1045
        except:
 
1046
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1047
 
 
1048
    def _state_accept_expecting_structure(self):
 
1049
        structure = self._extract_prefixed_bencoded_data()
 
1050
        self.state_accept = self._state_accept_expecting_message_part
 
1051
        try:
 
1052
            self.message_handler.structure_part_received(structure)
 
1053
        except:
 
1054
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1055
 
 
1056
    def done(self):
 
1057
        self.unused_data = self._get_in_buffer()
 
1058
        self._set_in_buffer(None)
 
1059
        self.state_accept = self._state_accept_reading_unused
 
1060
        try:
 
1061
            self.message_handler.end_received()
 
1062
        except:
 
1063
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1064
 
 
1065
    def _state_accept_reading_unused(self):
 
1066
        self.unused_data += self._get_in_buffer()
 
1067
        self._set_in_buffer(None)
 
1068
 
 
1069
    def next_read_size(self):
 
1070
        if self.state_accept == self._state_accept_reading_unused:
 
1071
            return 0
 
1072
        elif self.decoding_failed:
 
1073
            # An exception occured while processing this message, probably from
 
1074
            # self.message_handler.  We're not sure that this state machine is
 
1075
            # in a consistent state, so just signal that we're done (i.e. give
 
1076
            # up).
 
1077
            return 0
 
1078
        else:
 
1079
            if self._number_needed_bytes is not None:
 
1080
                return self._number_needed_bytes - self._in_buffer_len
 
1081
            else:
 
1082
                raise AssertionError("don't know how many bytes are expected!")
 
1083
 
 
1084
 
 
1085
class _ProtocolThreeEncoder(object):
 
1086
 
 
1087
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1088
    BUFFER_SIZE = 1024 * 1024  # 1 MiB buffer before flushing
 
1089
 
 
1090
    def __init__(self, write_func):
 
1091
        self._buf = []
 
1092
        self._buf_len = 0
 
1093
        self._real_write_func = write_func
 
1094
 
 
1095
    def _write_func(self, bytes):
 
1096
        # TODO: Another possibility would be to turn this into an async model.
 
1097
        #       Where we let another thread know that we have some bytes if
 
1098
        #       they want it, but we don't actually block for it
 
1099
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1100
        #       we might just push out smaller bits at a time?
 
1101
        self._buf.append(bytes)
 
1102
        self._buf_len += len(bytes)
 
1103
        if self._buf_len > self.BUFFER_SIZE:
 
1104
            self.flush()
 
1105
 
 
1106
    def flush(self):
 
1107
        if self._buf:
 
1108
            self._real_write_func(b''.join(self._buf))
 
1109
            del self._buf[:]
 
1110
            self._buf_len = 0
 
1111
 
 
1112
    def _serialise_offsets(self, offsets):
 
1113
        """Serialise a readv offset list."""
 
1114
        txt = []
 
1115
        for start, length in offsets:
 
1116
            txt.append(b'%d,%d' % (start, length))
 
1117
        return b'\n'.join(txt)
 
1118
 
 
1119
    def _write_protocol_version(self):
 
1120
        self._write_func(MESSAGE_VERSION_THREE)
 
1121
 
 
1122
    def _write_prefixed_bencode(self, structure):
 
1123
        bytes = bencode(structure)
 
1124
        self._write_func(struct.pack('!L', len(bytes)))
 
1125
        self._write_func(bytes)
 
1126
 
 
1127
    def _write_headers(self, headers):
 
1128
        self._write_prefixed_bencode(headers)
 
1129
 
 
1130
    def _write_structure(self, args):
 
1131
        self._write_func(b's')
 
1132
        utf8_args = []
 
1133
        for arg in args:
 
1134
            if isinstance(arg, text_type):
 
1135
                utf8_args.append(arg.encode('utf8'))
 
1136
            else:
 
1137
                utf8_args.append(arg)
 
1138
        self._write_prefixed_bencode(utf8_args)
 
1139
 
 
1140
    def _write_end(self):
 
1141
        self._write_func(b'e')
 
1142
        self.flush()
 
1143
 
 
1144
    def _write_prefixed_body(self, bytes):
 
1145
        self._write_func(b'b')
 
1146
        self._write_func(struct.pack('!L', len(bytes)))
 
1147
        self._write_func(bytes)
 
1148
 
 
1149
    def _write_chunked_body_start(self):
 
1150
        self._write_func(b'oC')
 
1151
 
 
1152
    def _write_error_status(self):
 
1153
        self._write_func(b'oE')
 
1154
 
 
1155
    def _write_success_status(self):
 
1156
        self._write_func(b'oS')
 
1157
 
 
1158
 
 
1159
class ProtocolThreeResponder(_ProtocolThreeEncoder):
 
1160
 
 
1161
    def __init__(self, write_func):
 
1162
        _ProtocolThreeEncoder.__init__(self, write_func)
 
1163
        self.response_sent = False
 
1164
        self._headers = {
 
1165
            b'Software version': breezy.__version__.encode('utf-8')}
 
1166
        if 'hpss' in debug.debug_flags:
 
1167
            self._thread_id = _thread.get_ident()
 
1168
            self._response_start_time = None
 
1169
 
 
1170
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1171
        if self._response_start_time is None:
 
1172
            self._response_start_time = osutils.timer_func()
 
1173
        if include_time:
 
1174
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1175
        else:
 
1176
            t = ''
 
1177
        if extra_bytes is None:
 
1178
            extra = ''
 
1179
        else:
 
1180
            extra = ' ' + repr(extra_bytes[:40])
 
1181
            if len(extra) > 33:
 
1182
                extra = extra[:29] + extra[-1] + '...'
 
1183
        mutter('%12s: [%s] %s%s%s'
 
1184
               % (action, self._thread_id, t, message, extra))
 
1185
 
 
1186
    def send_error(self, exception):
 
1187
        if self.response_sent:
 
1188
            raise AssertionError(
 
1189
                "send_error(%s) called, but response already sent."
 
1190
                % (exception,))
 
1191
        if isinstance(exception, errors.UnknownSmartMethod):
 
1192
            failure = request.FailedSmartServerResponse(
 
1193
                (b'UnknownMethod', exception.verb))
 
1194
            self.send_response(failure)
 
1195
            return
 
1196
        if 'hpss' in debug.debug_flags:
 
1197
            self._trace('error', str(exception))
 
1198
        self.response_sent = True
 
1199
        self._write_protocol_version()
 
1200
        self._write_headers(self._headers)
 
1201
        self._write_error_status()
 
1202
        self._write_structure(
 
1203
            (b'error', str(exception).encode('utf-8', 'replace')))
 
1204
        self._write_end()
 
1205
 
 
1206
    def send_response(self, response):
 
1207
        if self.response_sent:
 
1208
            raise AssertionError(
 
1209
                "send_response(%r) called, but response already sent."
 
1210
                % (response,))
 
1211
        self.response_sent = True
 
1212
        self._write_protocol_version()
 
1213
        self._write_headers(self._headers)
 
1214
        if response.is_successful():
 
1215
            self._write_success_status()
 
1216
        else:
 
1217
            self._write_error_status()
 
1218
        if 'hpss' in debug.debug_flags:
 
1219
            self._trace('response', repr(response.args))
 
1220
        self._write_structure(response.args)
 
1221
        if response.body is not None:
 
1222
            self._write_prefixed_body(response.body)
 
1223
            if 'hpss' in debug.debug_flags:
 
1224
                self._trace('body', '%d bytes' % (len(response.body),),
 
1225
                            response.body, include_time=True)
 
1226
        elif response.body_stream is not None:
 
1227
            count = num_bytes = 0
 
1228
            first_chunk = None
 
1229
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1230
                count += 1
 
1231
                if exc_info is not None:
 
1232
                    self._write_error_status()
 
1233
                    error_struct = request._translate_error(exc_info[1])
 
1234
                    self._write_structure(error_struct)
 
1235
                    break
 
1236
                else:
 
1237
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1238
                        self._write_error_status()
 
1239
                        self._write_structure(chunk.args)
 
1240
                        break
 
1241
                    num_bytes += len(chunk)
 
1242
                    if first_chunk is None:
 
1243
                        first_chunk = chunk
 
1244
                    self._write_prefixed_body(chunk)
 
1245
                    self.flush()
 
1246
                    if 'hpssdetail' in debug.debug_flags:
 
1247
                        # Not worth timing separately, as _write_func is
 
1248
                        # actually buffered
 
1249
                        self._trace('body chunk',
 
1250
                                    '%d bytes' % (len(chunk),),
 
1251
                                    chunk, suppress_time=True)
 
1252
            if 'hpss' in debug.debug_flags:
 
1253
                self._trace('body stream',
 
1254
                            '%d bytes %d chunks' % (num_bytes, count),
 
1255
                            first_chunk)
 
1256
        self._write_end()
 
1257
        if 'hpss' in debug.debug_flags:
 
1258
            self._trace('response end', '', include_time=True)
 
1259
 
 
1260
 
 
1261
def _iter_with_errors(iterable):
 
1262
    """Handle errors from iterable.next().
 
1263
 
 
1264
    Use like::
 
1265
 
 
1266
        for exc_info, value in _iter_with_errors(iterable):
 
1267
            ...
 
1268
 
 
1269
    This is a safer alternative to::
 
1270
 
 
1271
        try:
 
1272
            for value in iterable:
 
1273
               ...
 
1274
        except:
 
1275
            ...
 
1276
 
 
1277
    Because the latter will catch errors from the for-loop body, not just
 
1278
    iterable.next()
 
1279
 
 
1280
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1281
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1282
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1283
    will not be itercepted.
 
1284
    """
 
1285
    iterator = iter(iterable)
 
1286
    while True:
 
1287
        try:
 
1288
            yield None, next(iterator)
 
1289
        except StopIteration:
 
1290
            return
 
1291
        except (KeyboardInterrupt, SystemExit):
 
1292
            raise
 
1293
        except Exception:
 
1294
            mutter('_iter_with_errors caught error')
 
1295
            log_exception_quietly()
 
1296
            yield sys.exc_info(), None
 
1297
            return
 
1298
 
 
1299
 
 
1300
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
 
1301
 
 
1302
    def __init__(self, medium_request):
 
1303
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
 
1304
        self._medium_request = medium_request
 
1305
        self._headers = {}
 
1306
        self.body_stream_started = None
 
1307
 
 
1308
    def set_headers(self, headers):
 
1309
        self._headers = headers.copy()
 
1310
 
 
1311
    def call(self, *args):
 
1312
        if 'hpss' in debug.debug_flags:
 
1313
            mutter('hpss call:   %s', repr(args)[1:-1])
 
1314
            base = getattr(self._medium_request._medium, 'base', None)
 
1315
            if base is not None:
 
1316
                mutter('             (to %s)', base)
 
1317
            self._request_start_time = osutils.timer_func()
 
1318
        self._write_protocol_version()
 
1319
        self._write_headers(self._headers)
 
1320
        self._write_structure(args)
 
1321
        self._write_end()
 
1322
        self._medium_request.finished_writing()
 
1323
 
 
1324
    def call_with_body_bytes(self, args, body):
 
1325
        """Make a remote call of args with body bytes 'body'.
 
1326
 
 
1327
        After calling this, call read_response_tuple to find the result out.
 
1328
        """
 
1329
        if 'hpss' in debug.debug_flags:
 
1330
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
1331
            path = getattr(self._medium_request._medium, '_path', None)
 
1332
            if path is not None:
 
1333
                mutter('                  (to %s)', path)
 
1334
            mutter('              %d bytes', len(body))
 
1335
            self._request_start_time = osutils.timer_func()
 
1336
        self._write_protocol_version()
 
1337
        self._write_headers(self._headers)
 
1338
        self._write_structure(args)
 
1339
        self._write_prefixed_body(body)
 
1340
        self._write_end()
 
1341
        self._medium_request.finished_writing()
 
1342
 
 
1343
    def call_with_body_readv_array(self, args, body):
 
1344
        """Make a remote call with a readv array.
 
1345
 
 
1346
        The body is encoded with one line per readv offset pair. The numbers in
 
1347
        each pair are separated by a comma, and no trailing \\n is emitted.
 
1348
        """
 
1349
        if 'hpss' in debug.debug_flags:
 
1350
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
1351
            path = getattr(self._medium_request._medium, '_path', None)
 
1352
            if path is not None:
 
1353
                mutter('                  (to %s)', path)
 
1354
            self._request_start_time = osutils.timer_func()
 
1355
        self._write_protocol_version()
 
1356
        self._write_headers(self._headers)
 
1357
        self._write_structure(args)
 
1358
        readv_bytes = self._serialise_offsets(body)
 
1359
        if 'hpss' in debug.debug_flags:
 
1360
            mutter('              %d bytes in readv request', len(readv_bytes))
 
1361
        self._write_prefixed_body(readv_bytes)
 
1362
        self._write_end()
 
1363
        self._medium_request.finished_writing()
 
1364
 
 
1365
    def call_with_body_stream(self, args, stream):
 
1366
        if 'hpss' in debug.debug_flags:
 
1367
            mutter('hpss call w/body stream: %r', args)
 
1368
            path = getattr(self._medium_request._medium, '_path', None)
 
1369
            if path is not None:
 
1370
                mutter('                  (to %s)', path)
 
1371
            self._request_start_time = osutils.timer_func()
 
1372
        self.body_stream_started = False
 
1373
        self._write_protocol_version()
 
1374
        self._write_headers(self._headers)
 
1375
        self._write_structure(args)
 
1376
        # TODO: notice if the server has sent an early error reply before we
 
1377
        #       have finished sending the stream.  We would notice at the end
 
1378
        #       anyway, but if the medium can deliver it early then it's good
 
1379
        #       to short-circuit the whole request...
 
1380
        # Provoke any ConnectionReset failures before we start the body stream.
 
1381
        self.flush()
 
1382
        self.body_stream_started = True
 
1383
        for exc_info, part in _iter_with_errors(stream):
 
1384
            if exc_info is not None:
 
1385
                # Iterating the stream failed.  Cleanly abort the request.
 
1386
                self._write_error_status()
 
1387
                # Currently the client unconditionally sends ('error',) as the
 
1388
                # error args.
 
1389
                self._write_structure((b'error',))
 
1390
                self._write_end()
 
1391
                self._medium_request.finished_writing()
 
1392
                try:
 
1393
                    reraise(*exc_info)
 
1394
                finally:
 
1395
                    del exc_info
 
1396
            else:
 
1397
                self._write_prefixed_body(part)
 
1398
                self.flush()
 
1399
        self._write_end()
 
1400
        self._medium_request.finished_writing()