/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Andrew Bennetts
  • Date: 2009-10-21 11:13:40 UTC
  • mto: This revision was merged to the branch mainline in revision 4762.
  • Revision ID: andrew.bennetts@canonical.com-20091021111340-w7x4d5yf83qwjncc
Add test that WSGI glue allows request handlers to access paths above that request's. backing transport, so long as it is within the WSGI app's backing transport.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
breezy/transport/smart/__init__.py.
 
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
 
from __future__ import absolute_import
28
 
 
29
27
import errno
30
 
import io
31
28
import os
 
29
import socket
32
30
import sys
33
 
import time
 
31
import urllib
34
32
 
35
 
import breezy
36
 
from ...lazy_import import lazy_import
 
33
from bzrlib.lazy_import import lazy_import
37
34
lazy_import(globals(), """
38
 
import select
39
 
import socket
40
 
import thread
 
35
import atexit
41
36
import weakref
42
 
 
43
 
from breezy import (
 
37
from bzrlib import (
44
38
    debug,
 
39
    errors,
 
40
    symbol_versioning,
45
41
    trace,
46
 
    transport,
47
42
    ui,
48
43
    urlutils,
49
44
    )
50
 
from breezy.i18n import gettext
51
 
from breezy.bzr.smart import client, protocol, request, signals, vfs
52
 
from breezy.transport import ssh
 
45
from bzrlib.smart import client, protocol, request, vfs
 
46
from bzrlib.transport import ssh
53
47
""")
54
 
from ... import (
55
 
    errors,
56
 
    osutils,
57
 
    )
58
 
 
59
 
# Throughout this module buffer size parameters are either limited to be at
60
 
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
61
 
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
62
 
# from non-sockets as well.
63
 
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
64
 
 
65
 
 
66
 
class HpssVfsRequestNotAllowed(errors.BzrError):
67
 
 
68
 
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
69
 
            "%(method)s, %(arguments)s.")
70
 
 
71
 
    def __init__(self, method, arguments):
72
 
        self.method = method
73
 
        self.arguments = arguments
 
48
#usually already imported, and getting IllegalScoperReplacer on it here.
 
49
from bzrlib import osutils
 
50
 
 
51
# We must not read any more than 64k at a time so we don't risk "no buffer
 
52
# space available" errors on some platforms.  Windows in particular is likely
 
53
# to give error 10053 or 10055 if we read more than 64k from a socket.
 
54
_MAX_READ_SIZE = 64 * 1024
74
55
 
75
56
 
76
57
def _get_protocol_factory_for_bytes(bytes):
114
95
    :returns: a tuple of two strs: (line, excess)
115
96
    """
116
97
    newline_pos = -1
117
 
    bytes = b''
 
98
    bytes = ''
118
99
    while newline_pos == -1:
119
100
        new_bytes = read_bytes_func(1)
120
101
        bytes += new_bytes
121
 
        if new_bytes == b'':
 
102
        if new_bytes == '':
122
103
            # Ran out of bytes before receiving a complete line.
123
 
            return bytes, b''
124
 
        newline_pos = bytes.find(b'\n')
 
104
            return bytes, ''
 
105
        newline_pos = bytes.find('\n')
125
106
    line = bytes[:newline_pos+1]
126
107
    excess = bytes[newline_pos+1:]
127
108
    return line, excess
133
114
    def __init__(self):
134
115
        self._push_back_buffer = None
135
116
 
136
 
    def _push_back(self, data):
 
117
    def _push_back(self, bytes):
137
118
        """Return unused bytes to the medium, because they belong to the next
138
119
        request(s).
139
120
 
140
121
        This sets the _push_back_buffer to the given bytes.
141
122
        """
142
 
        if not isinstance(data, bytes):
143
 
            raise TypeError(data)
144
123
        if self._push_back_buffer is not None:
145
124
            raise AssertionError(
146
125
                "_push_back called when self._push_back_buffer is %r"
147
126
                % (self._push_back_buffer,))
148
 
        if data == b'':
 
127
        if bytes == '':
149
128
            return
150
 
        self._push_back_buffer = data
 
129
        self._push_back_buffer = bytes
151
130
 
152
131
    def _get_push_back_buffer(self):
153
 
        if self._push_back_buffer == b'':
 
132
        if self._push_back_buffer == '':
154
133
            raise AssertionError(
155
134
                '%s._push_back_buffer should never be the empty string, '
156
135
                'which can be confused with EOF' % (self,))
197
176
        ui.ui_factory.report_transport_activity(self, bytes, direction)
198
177
 
199
178
 
200
 
_bad_file_descriptor = (errno.EBADF,)
201
 
if sys.platform == 'win32':
202
 
    # Given on Windows if you pass a closed socket to select.select. Probably
203
 
    # also given if you pass a file handle to select.
204
 
    WSAENOTSOCK = 10038
205
 
    _bad_file_descriptor += (WSAENOTSOCK,)
206
 
 
207
 
 
208
179
class SmartServerStreamMedium(SmartMedium):
209
180
    """Handles smart commands coming over a stream.
210
181
 
223
194
        the stream.  See also the _push_back method.
224
195
    """
225
196
 
226
 
    _timer = time.time
227
 
 
228
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
197
    def __init__(self, backing_transport, root_client_path='/'):
229
198
        """Construct new server.
230
199
 
231
200
        :param backing_transport: Transport for the directory served.
234
203
        self.backing_transport = backing_transport
235
204
        self.root_client_path = root_client_path
236
205
        self.finished = False
237
 
        if timeout is None:
238
 
            raise AssertionError('You must supply a timeout.')
239
 
        self._client_timeout = timeout
240
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
241
206
        SmartMedium.__init__(self)
242
207
 
243
208
    def serve(self):
249
214
            while not self.finished:
250
215
                server_protocol = self._build_protocol()
251
216
                self._serve_one_request(server_protocol)
252
 
        except errors.ConnectionTimeout as e:
253
 
            trace.note('%s' % (e,))
254
 
            trace.log_exception_quietly()
255
 
            self._disconnect_client()
256
 
            # We reported it, no reason to make a big fuss.
257
 
            return
258
 
        except Exception as e:
 
217
        except Exception, e:
259
218
            stderr.write("%s terminating on exception %s\n" % (self, e))
260
219
            raise
261
 
        self._disconnect_client()
262
 
 
263
 
    def _stop_gracefully(self):
264
 
        """When we finish this message, stop looking for more."""
265
 
        trace.mutter('Stopping %s' % (self,))
266
 
        self.finished = True
267
 
 
268
 
    def _disconnect_client(self):
269
 
        """Close the current connection. We stopped due to a timeout/etc."""
270
 
        # The default implementation is a no-op, because that is all we used to
271
 
        # do when disconnecting from a client. I suppose we never had the
272
 
        # *server* initiate a disconnect, before
273
 
 
274
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
275
 
        """Wait for more bytes to be read, but timeout if none available.
276
 
 
277
 
        This allows us to detect idle connections, and stop trying to read from
278
 
        them, without setting the socket itself to non-blocking. This also
279
 
        allows us to specify when we watch for idle timeouts.
280
 
 
281
 
        :return: Did we timeout? (True if we timed out, False if there is data
282
 
            to be read)
283
 
        """
284
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
285
220
 
286
221
    def _build_protocol(self):
287
222
        """Identifies the version of the incoming request, and returns an
292
227
 
293
228
        :returns: a SmartServerRequestProtocol.
294
229
        """
295
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
296
 
        if self.finished:
297
 
            # We're stopping, so don't try to do any more work
298
 
            return None
299
230
        bytes = self._get_line()
300
231
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
301
232
        protocol = protocol_factory(
303
234
        protocol.accept_bytes(unused_bytes)
304
235
        return protocol
305
236
 
306
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
307
 
        """select() on a file descriptor, waiting for nonblocking read()
308
 
 
309
 
        This will raise a ConnectionTimeout exception if we do not get a
310
 
        readable handle before timeout_seconds.
311
 
        :return: None
312
 
        """
313
 
        t_end = self._timer() + timeout_seconds
314
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
315
 
        rs = xs = None
316
 
        while not rs and not xs and self._timer() < t_end:
317
 
            if self.finished:
318
 
                return
319
 
            try:
320
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
321
 
            except (select.error, socket.error) as e:
322
 
                err = getattr(e, 'errno', None)
323
 
                if err is None and getattr(e, 'args', None) is not None:
324
 
                    # select.error doesn't have 'errno', it just has args[0]
325
 
                    err = e.args[0]
326
 
                if err in _bad_file_descriptor:
327
 
                    return # Not a socket indicates read() will fail
328
 
                elif err == errno.EINTR:
329
 
                    # Interrupted, keep looping.
330
 
                    continue
331
 
                raise
332
 
            except ValueError:
333
 
                return  # Socket may already be closed
334
 
        if rs or xs:
335
 
            return
336
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
337
 
                                       % (timeout_seconds,))
338
 
 
339
237
    def _serve_one_request(self, protocol):
340
238
        """Read one request from input, process, send back a response.
341
239
 
342
240
        :param protocol: a SmartServerRequestProtocol.
343
241
        """
344
 
        if protocol is None:
345
 
            return
346
242
        try:
347
243
            self._serve_one_request_unguarded(protocol)
348
244
        except KeyboardInterrupt:
349
245
            raise
350
 
        except Exception as e:
 
246
        except Exception, e:
351
247
            self.terminate_due_to_error()
352
248
 
353
249
    def terminate_due_to_error(self):
364
260
 
365
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
366
262
 
367
 
    def __init__(self, sock, backing_transport, root_client_path='/',
368
 
                 timeout=None):
 
263
    def __init__(self, sock, backing_transport, root_client_path='/'):
369
264
        """Constructor.
370
265
 
371
266
        :param sock: the socket the server will read from.  It will be put
372
267
            into blocking mode.
373
268
        """
374
269
        SmartServerStreamMedium.__init__(
375
 
            self, backing_transport, root_client_path=root_client_path,
376
 
            timeout=timeout)
 
270
            self, backing_transport, root_client_path=root_client_path)
377
271
        sock.setblocking(True)
378
272
        self.socket = sock
379
 
        # Get the getpeername now, as we might be closed later when we care.
380
 
        try:
381
 
            self._client_info = sock.getpeername()
382
 
        except socket.error:
383
 
            self._client_info = '<unknown>'
384
 
 
385
 
    def __str__(self):
386
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
387
 
 
388
 
    def __repr__(self):
389
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
390
 
            self._client_info)
391
273
 
392
274
    def _serve_one_request_unguarded(self, protocol):
393
275
        while protocol.next_read_size():
394
276
            # We can safely try to read large chunks.  If there is less data
395
 
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
396
 
            # short read immediately rather than block.
397
 
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
398
 
            if bytes == b'':
 
277
            # than _MAX_READ_SIZE ready, the socket wil just return a short
 
278
            # read immediately rather than block.
 
279
            bytes = self.read_bytes(_MAX_READ_SIZE)
 
280
            if bytes == '':
399
281
                self.finished = True
400
282
                return
401
283
            protocol.accept_bytes(bytes)
402
284
 
403
285
        self._push_back(protocol.unused_data)
404
286
 
405
 
    def _disconnect_client(self):
406
 
        """Close the current connection. We stopped due to a timeout/etc."""
407
 
        self.socket.close()
408
 
 
409
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
410
 
        """Wait for more bytes to be read, but timeout if none available.
411
 
 
412
 
        This allows us to detect idle connections, and stop trying to read from
413
 
        them, without setting the socket itself to non-blocking. This also
414
 
        allows us to specify when we watch for idle timeouts.
415
 
 
416
 
        :return: None, this will raise ConnectionTimeout if we time out before
417
 
            data is available.
418
 
        """
419
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
420
 
 
421
287
    def _read_bytes(self, desired_count):
422
 
        return osutils.read_bytes_from_socket(
423
 
            self.socket, self._report_activity)
 
288
        return _read_bytes_from_socket(
 
289
            self.socket.recv, desired_count, self._report_activity)
424
290
 
425
291
    def terminate_due_to_error(self):
426
292
        # TODO: This should log to a server log file, but no such thing
427
293
        # exists yet.  Andrew Bennetts 2006-09-29.
428
 
        self.socket.close()
 
294
        osutils.until_no_eintr(self.socket.close)
429
295
        self.finished = True
430
296
 
431
297
    def _write_out(self, bytes):
432
 
        tstart = osutils.timer_func()
433
298
        osutils.send_all(self.socket, bytes, self._report_activity)
434
 
        if 'hpss' in debug.debug_flags:
435
 
            thread_id = thread.get_ident()
436
 
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
437
 
                         % ('wrote', thread_id, len(bytes),
438
 
                            osutils.timer_func() - tstart))
439
299
 
440
300
 
441
301
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
442
302
 
443
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
303
    def __init__(self, in_file, out_file, backing_transport):
444
304
        """Construct new server.
445
305
 
446
306
        :param in_file: Python file from which requests can be read.
447
307
        :param out_file: Python file to write responses.
448
308
        :param backing_transport: Transport for the directory served.
449
309
        """
450
 
        SmartServerStreamMedium.__init__(self, backing_transport,
451
 
            timeout=timeout)
 
310
        SmartServerStreamMedium.__init__(self, backing_transport)
452
311
        if sys.platform == 'win32':
453
312
            # force binary mode for files
454
313
            import msvcrt
459
318
        self._in = in_file
460
319
        self._out = out_file
461
320
 
462
 
    def serve(self):
463
 
        """See SmartServerStreamMedium.serve"""
464
 
        # This is the regular serve, except it adds signal trapping for soft
465
 
        # shutdown.
466
 
        stop_gracefully = self._stop_gracefully
467
 
        signals.register_on_hangup(id(self), stop_gracefully)
468
 
        try:
469
 
            return super(SmartServerPipeStreamMedium, self).serve()
470
 
        finally:
471
 
            signals.unregister_on_hangup(id(self))
472
 
 
473
321
    def _serve_one_request_unguarded(self, protocol):
474
322
        while True:
475
323
            # We need to be careful not to read past the end of the current
478
326
            bytes_to_read = protocol.next_read_size()
479
327
            if bytes_to_read == 0:
480
328
                # Finished serving this request.
481
 
                self._out.flush()
 
329
                osutils.until_no_eintr(self._out.flush)
482
330
                return
483
331
            bytes = self.read_bytes(bytes_to_read)
484
 
            if bytes == b'':
 
332
            if bytes == '':
485
333
                # Connection has been closed.
486
334
                self.finished = True
487
 
                self._out.flush()
 
335
                osutils.until_no_eintr(self._out.flush)
488
336
                return
489
337
            protocol.accept_bytes(bytes)
490
338
 
491
 
    def _disconnect_client(self):
492
 
        self._in.close()
493
 
        self._out.flush()
494
 
        self._out.close()
495
 
 
496
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
497
 
        """Wait for more bytes to be read, but timeout if none available.
498
 
 
499
 
        This allows us to detect idle connections, and stop trying to read from
500
 
        them, without setting the socket itself to non-blocking. This also
501
 
        allows us to specify when we watch for idle timeouts.
502
 
 
503
 
        :return: None, this will raise ConnectionTimeout if we time out before
504
 
            data is available.
505
 
        """
506
 
        if (getattr(self._in, 'fileno', None) is None
507
 
            or sys.platform == 'win32'):
508
 
            # You can't select() file descriptors on Windows.
509
 
            return
510
 
        try:
511
 
            return self._wait_on_descriptor(self._in, timeout_seconds)
512
 
        except io.UnsupportedOperation:
513
 
            return
514
 
 
515
339
    def _read_bytes(self, desired_count):
516
 
        return self._in.read(desired_count)
 
340
        return osutils.until_no_eintr(self._in.read, desired_count)
517
341
 
518
342
    def terminate_due_to_error(self):
519
343
        # TODO: This should log to a server log file, but no such thing
520
344
        # exists yet.  Andrew Bennetts 2006-09-29.
521
 
        self._out.close()
 
345
        osutils.until_no_eintr(self._out.close)
522
346
        self.finished = True
523
347
 
524
348
    def _write_out(self, bytes):
525
 
        self._out.write(bytes)
 
349
        osutils.until_no_eintr(self._out.write, bytes)
526
350
 
527
351
 
528
352
class SmartClientMediumRequest(object):
645
469
 
646
470
    def read_line(self):
647
471
        line = self._read_line()
648
 
        if not line.endswith(b'\n'):
 
472
        if not line.endswith('\n'):
649
473
            # end of file encountered reading from server
650
474
            raise errors.ConnectionReset(
651
475
                "Unexpected end of message. Please check connectivity "
661
485
        return self._medium._get_line()
662
486
 
663
487
 
664
 
class _VfsRefuser(object):
665
 
    """An object that refuses all VFS requests.
666
 
 
667
 
    """
668
 
 
669
 
    def __init__(self):
670
 
        client._SmartClient.hooks.install_named_hook(
671
 
            'call', self.check_vfs, 'vfs refuser')
672
 
 
673
 
    def check_vfs(self, params):
674
 
        try:
675
 
            request_method = request.request_handlers.get(params.method)
676
 
        except KeyError:
677
 
            # A method we don't know about doesn't count as a VFS method.
678
 
            return
679
 
        if issubclass(request_method, vfs.VfsRequest):
680
 
            raise HpssVfsRequestNotAllowed(params.method, params.args)
681
 
 
682
 
 
683
488
class _DebugCounter(object):
684
489
    """An object that counts the HPSS calls made to each client medium.
685
490
 
686
 
    When a medium is garbage-collected, or failing that when
687
 
    breezy.global_state exits, the total number of calls made on that medium
688
 
    are reported via trace.note.
 
491
    When a medium is garbage-collected, or failing that when atexit functions
 
492
    are run, the total number of calls made on that medium are reported via
 
493
    trace.note.
689
494
    """
690
495
 
691
496
    def __init__(self):
692
497
        self.counts = weakref.WeakKeyDictionary()
693
498
        client._SmartClient.hooks.install_named_hook(
694
499
            'call', self.increment_call_count, 'hpss call counter')
695
 
        breezy.get_global_state().cleanups.add_cleanup(self.flush_all)
 
500
        atexit.register(self.flush_all)
696
501
 
697
502
    def track(self, medium):
698
503
        """Start tracking calls made to a medium.
732
537
        value['count'] = 0
733
538
        value['vfs_count'] = 0
734
539
        if count != 0:
735
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
736
 
                       count, vfs_count, medium_repr))
 
540
            trace.note('HPSS calls: %d (%d vfs) %s',
 
541
                       count, vfs_count, medium_repr)
737
542
 
738
543
    def flush_all(self):
739
544
        for ref in list(self.counts.keys()):
740
545
            self.done(ref)
741
546
 
742
547
_debug_counter = None
743
 
_vfs_refuser = None
744
548
 
745
549
 
746
550
class SmartClientMedium(SmartMedium):
763
567
            if _debug_counter is None:
764
568
                _debug_counter = _DebugCounter()
765
569
            _debug_counter.track(self)
766
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
767
 
            global _vfs_refuser
768
 
            if _vfs_refuser is None:
769
 
                _vfs_refuser = _VfsRefuser()
770
570
 
771
571
    def _is_remote_before(self, version_tuple):
772
572
        """Is it possible the remote side supports RPCs for a given version?
801
601
            # which is newer than a previously supplied older-than version.
802
602
            # This indicates that some smart verb call is not guarded
803
603
            # appropriately (it should simply not have been tried).
804
 
            trace.mutter(
 
604
            raise AssertionError(
805
605
                "_remember_remote_is_before(%r) called, but "
806
606
                "_remember_remote_is_before(%r) was called previously."
807
 
                , version_tuple, self._remote_version_is_before)
808
 
            if 'hpss' in debug.debug_flags:
809
 
                ui.ui_factory.show_warning(
810
 
                    "_remember_remote_is_before(%r) called, but "
811
 
                    "_remember_remote_is_before(%r) was called previously."
812
 
                    % (version_tuple, self._remote_version_is_before))
813
 
            return
 
607
                % (version_tuple, self._remote_version_is_before))
814
608
        self._remote_version_is_before = version_tuple
815
609
 
816
610
    def protocol_version(self):
825
619
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
826
620
                client_protocol.query_version()
827
621
                self._done_hello = True
828
 
            except errors.SmartProtocolError as e:
 
622
            except errors.SmartProtocolError, e:
829
623
                # Cache the error, just like we would cache a successful
830
624
                # result.
831
625
                self._protocol_version_error = e
863
657
        """
864
658
        medium_base = urlutils.join(self.base, '/')
865
659
        rel_url = urlutils.relative_url(medium_base, transport.base)
866
 
        return urlutils.unquote(rel_url)
 
660
        return urllib.unquote(rel_url)
867
661
 
868
662
 
869
663
class SmartClientStreamMedium(SmartClientMedium):
904
698
        """
905
699
        return SmartClientStreamMediumRequest(self)
906
700
 
907
 
    def reset(self):
908
 
        """We have been disconnected, reset current state.
909
 
 
910
 
        This resets things like _current_request and connected state.
911
 
        """
912
 
        self.disconnect()
913
 
        self._current_request = None
914
 
 
915
701
 
916
702
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
917
703
    """A client medium using simple pipes.
924
710
        self._readable_pipe = readable_pipe
925
711
        self._writeable_pipe = writeable_pipe
926
712
 
927
 
    def _accept_bytes(self, data):
 
713
    def _accept_bytes(self, bytes):
928
714
        """See SmartClientStreamMedium.accept_bytes."""
929
 
        try:
930
 
            self._writeable_pipe.write(data)
931
 
        except IOError as e:
932
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
933
 
                raise errors.ConnectionReset(
934
 
                    "Error trying to write to subprocess", e)
935
 
            raise
936
 
        self._report_activity(len(data), 'write')
 
715
        osutils.until_no_eintr(self._writeable_pipe.write, bytes)
 
716
        self._report_activity(len(bytes), 'write')
937
717
 
938
718
    def _flush(self):
939
719
        """See SmartClientStreamMedium._flush()."""
940
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
941
 
        #       However, testing shows that even when the child process is
942
 
        #       gone, this doesn't error.
943
 
        self._writeable_pipe.flush()
 
720
        osutils.until_no_eintr(self._writeable_pipe.flush)
944
721
 
945
722
    def _read_bytes(self, count):
946
723
        """See SmartClientStreamMedium._read_bytes."""
947
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
948
 
        data = self._readable_pipe.read(bytes_to_read)
949
 
        self._report_activity(len(data), 'read')
950
 
        return data
951
 
 
952
 
 
953
 
class SSHParams(object):
954
 
    """A set of parameters for starting a remote bzr via SSH."""
 
724
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
 
725
        self._report_activity(len(bytes), 'read')
 
726
        return bytes
 
727
 
 
728
 
 
729
class SmartSSHClientMedium(SmartClientStreamMedium):
 
730
    """A client medium using SSH."""
955
731
 
956
732
    def __init__(self, host, port=None, username=None, password=None,
957
 
            bzr_remote_path='bzr'):
958
 
        self.host = host
959
 
        self.port = port
960
 
        self.username = username
961
 
        self.password = password
962
 
        self.bzr_remote_path = bzr_remote_path
963
 
 
964
 
 
965
 
class SmartSSHClientMedium(SmartClientStreamMedium):
966
 
    """A client medium using SSH.
967
 
 
968
 
    It delegates IO to a SmartSimplePipesClientMedium or
969
 
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
970
 
    """
971
 
 
972
 
    def __init__(self, base, ssh_params, vendor=None):
 
733
            base=None, vendor=None, bzr_remote_path=None):
973
734
        """Creates a client that will connect on the first use.
974
735
 
975
 
        :param ssh_params: A SSHParams instance.
976
736
        :param vendor: An optional override for the ssh vendor to use. See
977
 
            breezy.transport.ssh for details on ssh vendors.
 
737
            bzrlib.transport.ssh for details on ssh vendors.
978
738
        """
979
 
        self._real_medium = None
980
 
        self._ssh_params = ssh_params
981
 
        # for the benefit of progress making a short description of this
982
 
        # transport
983
 
        self._scheme = 'bzr+ssh'
 
739
        self._connected = False
 
740
        self._host = host
 
741
        self._password = password
 
742
        self._port = port
 
743
        self._username = username
984
744
        # SmartClientStreamMedium stores the repr of this object in its
985
745
        # _DebugCounter so we have to store all the values used in our repr
986
746
        # method before calling the super init.
987
747
        SmartClientStreamMedium.__init__(self, base)
 
748
        self._read_from = None
 
749
        self._ssh_connection = None
988
750
        self._vendor = vendor
989
 
        self._ssh_connection = None
 
751
        self._write_to = None
 
752
        self._bzr_remote_path = bzr_remote_path
 
753
        # for the benefit of progress making a short description of this
 
754
        # transport
 
755
        self._scheme = 'bzr+ssh'
990
756
 
991
757
    def __repr__(self):
992
 
        if self._ssh_params.port is None:
993
 
            maybe_port = ''
994
 
        else:
995
 
            maybe_port = ':%s' % self._ssh_params.port
996
 
        if self._ssh_params.username is None:
997
 
            maybe_user = ''
998
 
        else:
999
 
            maybe_user = '%s@' % self._ssh_params.username
1000
 
        return "%s(%s://%s%s%s/)" % (
 
758
        return "%s(connected=%r, username=%r, host=%r, port=%r)" % (
1001
759
            self.__class__.__name__,
1002
 
            self._scheme,
1003
 
            maybe_user,
1004
 
            self._ssh_params.host,
1005
 
            maybe_port)
 
760
            self._connected,
 
761
            self._username,
 
762
            self._host,
 
763
            self._port)
1006
764
 
1007
765
    def _accept_bytes(self, bytes):
1008
766
        """See SmartClientStreamMedium.accept_bytes."""
1009
767
        self._ensure_connection()
1010
 
        self._real_medium.accept_bytes(bytes)
 
768
        osutils.until_no_eintr(self._write_to.write, bytes)
 
769
        self._report_activity(len(bytes), 'write')
1011
770
 
1012
771
    def disconnect(self):
1013
772
        """See SmartClientMedium.disconnect()."""
1014
 
        if self._real_medium is not None:
1015
 
            self._real_medium.disconnect()
1016
 
            self._real_medium = None
1017
 
        if self._ssh_connection is not None:
1018
 
            self._ssh_connection.close()
1019
 
            self._ssh_connection = None
 
773
        if not self._connected:
 
774
            return
 
775
        osutils.until_no_eintr(self._read_from.close)
 
776
        osutils.until_no_eintr(self._write_to.close)
 
777
        self._ssh_connection.close()
 
778
        self._connected = False
1020
779
 
1021
780
    def _ensure_connection(self):
1022
781
        """Connect this medium if not already connected."""
1023
 
        if self._real_medium is not None:
 
782
        if self._connected:
1024
783
            return
1025
784
        if self._vendor is None:
1026
785
            vendor = ssh._get_ssh_vendor()
1027
786
        else:
1028
787
            vendor = self._vendor
1029
 
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1030
 
                self._ssh_params.password, self._ssh_params.host,
1031
 
                self._ssh_params.port,
1032
 
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
 
788
        self._ssh_connection = vendor.connect_ssh(self._username,
 
789
                self._password, self._host, self._port,
 
790
                command=[self._bzr_remote_path, 'serve', '--inet',
1033
791
                         '--directory=/', '--allow-writes'])
1034
 
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1035
 
        if io_kind == 'socket':
1036
 
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1037
 
                self.base, io_object)
1038
 
        elif io_kind == 'pipes':
1039
 
            read_from, write_to = io_object
1040
 
            self._real_medium = SmartSimplePipesClientMedium(
1041
 
                read_from, write_to, self.base)
1042
 
        else:
1043
 
            raise AssertionError(
1044
 
                "Unexpected io_kind %r from %r"
1045
 
                % (io_kind, self._ssh_connection))
1046
 
        for hook in transport.Transport.hooks["post_connect"]:
1047
 
            hook(self)
 
792
        self._read_from, self._write_to = \
 
793
            self._ssh_connection.get_filelike_channels()
 
794
        self._connected = True
1048
795
 
1049
796
    def _flush(self):
1050
797
        """See SmartClientStreamMedium._flush()."""
1051
 
        self._real_medium._flush()
 
798
        self._write_to.flush()
1052
799
 
1053
800
    def _read_bytes(self, count):
1054
801
        """See SmartClientStreamMedium.read_bytes."""
1055
 
        if self._real_medium is None:
 
802
        if not self._connected:
1056
803
            raise errors.MediumNotConnected(self)
1057
 
        return self._real_medium.read_bytes(count)
 
804
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
805
        bytes = osutils.until_no_eintr(self._read_from.read, bytes_to_read)
 
806
        self._report_activity(len(bytes), 'read')
 
807
        return bytes
1058
808
 
1059
809
 
1060
810
# Port 4155 is the default port for bzr://, registered with IANA.
1062
812
BZR_DEFAULT_PORT = 4155
1063
813
 
1064
814
 
1065
 
class SmartClientSocketMedium(SmartClientStreamMedium):
1066
 
    """A client medium using a socket.
1067
 
 
1068
 
    This class isn't usable directly.  Use one of its subclasses instead.
1069
 
    """
1070
 
 
1071
 
    def __init__(self, base):
 
815
class SmartTCPClientMedium(SmartClientStreamMedium):
 
816
    """A client medium using TCP."""
 
817
 
 
818
    def __init__(self, host, port, base):
 
819
        """Creates a client that will connect on the first use."""
1072
820
        SmartClientStreamMedium.__init__(self, base)
 
821
        self._connected = False
 
822
        self._host = host
 
823
        self._port = port
1073
824
        self._socket = None
1074
 
        self._connected = False
1075
825
 
1076
826
    def _accept_bytes(self, bytes):
1077
827
        """See SmartClientMedium.accept_bytes."""
1078
828
        self._ensure_connection()
1079
829
        osutils.send_all(self._socket, bytes, self._report_activity)
1080
830
 
1081
 
    def _ensure_connection(self):
1082
 
        """Connect this medium if not already connected."""
1083
 
        raise NotImplementedError(self._ensure_connection)
1084
 
 
1085
 
    def _flush(self):
1086
 
        """See SmartClientStreamMedium._flush().
1087
 
 
1088
 
        For sockets we do no flushing. For TCP sockets we may want to turn off
1089
 
        TCP_NODELAY and add a means to do a flush, but that can be done in the
1090
 
        future.
1091
 
        """
1092
 
 
1093
 
    def _read_bytes(self, count):
1094
 
        """See SmartClientMedium.read_bytes."""
1095
 
        if not self._connected:
1096
 
            raise errors.MediumNotConnected(self)
1097
 
        return osutils.read_bytes_from_socket(
1098
 
            self._socket, self._report_activity)
1099
 
 
1100
831
    def disconnect(self):
1101
832
        """See SmartClientMedium.disconnect()."""
1102
833
        if not self._connected:
1103
834
            return
1104
 
        self._socket.close()
 
835
        osutils.until_no_eintr(self._socket.close)
1105
836
        self._socket = None
1106
837
        self._connected = False
1107
838
 
1108
 
 
1109
 
class SmartTCPClientMedium(SmartClientSocketMedium):
1110
 
    """A client medium that creates a TCP connection."""
1111
 
 
1112
 
    def __init__(self, host, port, base):
1113
 
        """Creates a client that will connect on the first use."""
1114
 
        SmartClientSocketMedium.__init__(self, base)
1115
 
        self._host = host
1116
 
        self._port = port
1117
 
 
1118
839
    def _ensure_connection(self):
1119
840
        """Connect this medium if not already connected."""
1120
841
        if self._connected:
1126
847
        try:
1127
848
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1128
849
                socket.SOCK_STREAM, 0, 0)
1129
 
        except socket.gaierror as xxx_todo_changeme:
1130
 
            (err_num, err_msg) = xxx_todo_changeme.args
 
850
        except socket.gaierror, (err_num, err_msg):
1131
851
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1132
852
                    (self._host, port, err_msg))
1133
853
        # Initialize err in case there are no addresses returned:
1134
 
        last_err = socket.error("no address found for %s" % self._host)
 
854
        err = socket.error("no address found for %s" % self._host)
1135
855
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1136
856
            try:
1137
857
                self._socket = socket.socket(family, socktype, proto)
1138
858
                self._socket.setsockopt(socket.IPPROTO_TCP,
1139
859
                                        socket.TCP_NODELAY, 1)
1140
860
                self._socket.connect(sockaddr)
1141
 
            except socket.error as err:
 
861
            except socket.error, err:
1142
862
                if self._socket is not None:
1143
863
                    self._socket.close()
1144
864
                self._socket = None
1145
 
                last_err = err
1146
865
                continue
1147
866
            break
1148
867
        if self._socket is None:
1149
868
            # socket errors either have a (string) or (errno, string) as their
1150
869
            # args.
1151
 
            if isinstance(last_err.args, str):
1152
 
                err_msg = last_err.args
 
870
            if type(err.args) is str:
 
871
                err_msg = err.args
1153
872
            else:
1154
 
                err_msg = last_err.args[1]
 
873
                err_msg = err.args[1]
1155
874
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1156
875
                    (self._host, port, err_msg))
1157
876
        self._connected = True
1158
 
        for hook in transport.Transport.hooks["post_connect"]:
1159
 
            hook(self)
1160
 
 
1161
 
 
1162
 
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1163
 
    """A client medium for an already connected socket.
1164
 
    
1165
 
    Note that this class will assume it "owns" the socket, so it will close it
1166
 
    when its disconnect method is called.
1167
 
    """
1168
 
 
1169
 
    def __init__(self, base, sock):
1170
 
        SmartClientSocketMedium.__init__(self, base)
1171
 
        self._socket = sock
1172
 
        self._connected = True
1173
 
 
1174
 
    def _ensure_connection(self):
1175
 
        # Already connected, by definition!  So nothing to do.
1176
 
        pass
 
877
 
 
878
    def _flush(self):
 
879
        """See SmartClientStreamMedium._flush().
 
880
 
 
881
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
 
882
        add a means to do a flush, but that can be done in the future.
 
883
        """
 
884
 
 
885
    def _read_bytes(self, count):
 
886
        """See SmartClientMedium.read_bytes."""
 
887
        if not self._connected:
 
888
            raise errors.MediumNotConnected(self)
 
889
        return _read_bytes_from_socket(
 
890
            self._socket.recv, count, self._report_activity)
1177
891
 
1178
892
 
1179
893
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1214
928
        This invokes self._medium._flush to ensure all bytes are transmitted.
1215
929
        """
1216
930
        self._medium._flush()
 
931
 
 
932
 
 
933
def _read_bytes_from_socket(sock, desired_count, report_activity):
 
934
    # We ignore the desired_count because on sockets it's more efficient to
 
935
    # read large chunks (of _MAX_READ_SIZE bytes) at a time.
 
936
    try:
 
937
        bytes = osutils.until_no_eintr(sock, _MAX_READ_SIZE)
 
938
    except socket.error, e:
 
939
        if len(e.args) and e.args[0] in (errno.ECONNRESET, 10054):
 
940
            # The connection was closed by the other side.  Callers expect an
 
941
            # empty string to signal end-of-stream.
 
942
            bytes = ''
 
943
        else:
 
944
            raise
 
945
    else:
 
946
        report_activity(len(bytes), 'read')
 
947
    return bytes
 
948