/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: John Ferlito
  • Date: 2009-09-02 04:31:45 UTC
  • mto: (4665.7.1 serve-init)
  • mto: This revision was merged to the branch mainline in revision 4913.
  • Revision ID: johnf@inodes.org-20090902043145-gxdsfw03ilcwbyn5
Add a debian init script for bzr --serve

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
breezy/transport/smart/__init__.py.
 
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
27
import errno
28
 
import io
29
28
import os
 
29
import socket
30
30
import sys
31
 
import time
32
 
 
33
 
try:
34
 
    import _thread
35
 
except ImportError:
36
 
    import thread as _thread
37
 
 
38
 
import breezy
39
 
from ...lazy_import import lazy_import
 
31
import urllib
 
32
 
 
33
from bzrlib.lazy_import import lazy_import
40
34
lazy_import(globals(), """
41
 
import select
42
 
import socket
 
35
import atexit
43
36
import weakref
44
 
 
45
 
from breezy import (
 
37
from bzrlib import (
46
38
    debug,
 
39
    errors,
 
40
    symbol_versioning,
47
41
    trace,
48
 
    transport,
49
42
    ui,
50
43
    urlutils,
51
44
    )
52
 
from breezy.i18n import gettext
53
 
from breezy.bzr.smart import client, protocol, request, signals, vfs
54
 
from breezy.transport import ssh
 
45
from bzrlib.smart import client, protocol, request, vfs
 
46
from bzrlib.transport import ssh
55
47
""")
56
 
from ... import (
57
 
    errors,
58
 
    osutils,
59
 
    )
60
 
 
61
 
# Throughout this module buffer size parameters are either limited to be at
62
 
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
63
 
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
64
 
# from non-sockets as well.
65
 
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
66
 
 
67
 
 
68
 
class HpssVfsRequestNotAllowed(errors.BzrError):
69
 
 
70
 
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
71
 
            "%(method)s, %(arguments)s.")
72
 
 
73
 
    def __init__(self, method, arguments):
74
 
        self.method = method
75
 
        self.arguments = arguments
 
48
#usually already imported, and getting IllegalScoperReplacer on it here.
 
49
from bzrlib import osutils
 
50
 
 
51
# We must not read any more than 64k at a time so we don't risk "no buffer
 
52
# space available" errors on some platforms.  Windows in particular is likely
 
53
# to give error 10053 or 10055 if we read more than 64k from a socket.
 
54
_MAX_READ_SIZE = 64 * 1024
76
55
 
77
56
 
78
57
def _get_protocol_factory_for_bytes(bytes):
116
95
    :returns: a tuple of two strs: (line, excess)
117
96
    """
118
97
    newline_pos = -1
119
 
    bytes = b''
 
98
    bytes = ''
120
99
    while newline_pos == -1:
121
100
        new_bytes = read_bytes_func(1)
122
101
        bytes += new_bytes
123
 
        if new_bytes == b'':
 
102
        if new_bytes == '':
124
103
            # Ran out of bytes before receiving a complete line.
125
 
            return bytes, b''
126
 
        newline_pos = bytes.find(b'\n')
127
 
    line = bytes[:newline_pos + 1]
128
 
    excess = bytes[newline_pos + 1:]
 
104
            return bytes, ''
 
105
        newline_pos = bytes.find('\n')
 
106
    line = bytes[:newline_pos+1]
 
107
    excess = bytes[newline_pos+1:]
129
108
    return line, excess
130
109
 
131
110
 
135
114
    def __init__(self):
136
115
        self._push_back_buffer = None
137
116
 
138
 
    def _push_back(self, data):
 
117
    def _push_back(self, bytes):
139
118
        """Return unused bytes to the medium, because they belong to the next
140
119
        request(s).
141
120
 
142
121
        This sets the _push_back_buffer to the given bytes.
143
122
        """
144
 
        if not isinstance(data, bytes):
145
 
            raise TypeError(data)
146
123
        if self._push_back_buffer is not None:
147
124
            raise AssertionError(
148
125
                "_push_back called when self._push_back_buffer is %r"
149
126
                % (self._push_back_buffer,))
150
 
        if data == b'':
 
127
        if bytes == '':
151
128
            return
152
 
        self._push_back_buffer = data
 
129
        self._push_back_buffer = bytes
153
130
 
154
131
    def _get_push_back_buffer(self):
155
 
        if self._push_back_buffer == b'':
 
132
        if self._push_back_buffer == '':
156
133
            raise AssertionError(
157
134
                '%s._push_back_buffer should never be the empty string, '
158
135
                'which can be confused with EOF' % (self,))
199
176
        ui.ui_factory.report_transport_activity(self, bytes, direction)
200
177
 
201
178
 
202
 
_bad_file_descriptor = (errno.EBADF,)
203
 
if sys.platform == 'win32':
204
 
    # Given on Windows if you pass a closed socket to select.select. Probably
205
 
    # also given if you pass a file handle to select.
206
 
    WSAENOTSOCK = 10038
207
 
    _bad_file_descriptor += (WSAENOTSOCK,)
208
 
 
209
 
 
210
179
class SmartServerStreamMedium(SmartMedium):
211
180
    """Handles smart commands coming over a stream.
212
181
 
225
194
        the stream.  See also the _push_back method.
226
195
    """
227
196
 
228
 
    _timer = time.time
229
 
 
230
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
197
    def __init__(self, backing_transport, root_client_path='/'):
231
198
        """Construct new server.
232
199
 
233
200
        :param backing_transport: Transport for the directory served.
236
203
        self.backing_transport = backing_transport
237
204
        self.root_client_path = root_client_path
238
205
        self.finished = False
239
 
        if timeout is None:
240
 
            raise AssertionError('You must supply a timeout.')
241
 
        self._client_timeout = timeout
242
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
243
206
        SmartMedium.__init__(self)
244
207
 
245
208
    def serve(self):
251
214
            while not self.finished:
252
215
                server_protocol = self._build_protocol()
253
216
                self._serve_one_request(server_protocol)
254
 
        except errors.ConnectionTimeout as e:
255
 
            trace.note('%s' % (e,))
256
 
            trace.log_exception_quietly()
257
 
            self._disconnect_client()
258
 
            # We reported it, no reason to make a big fuss.
259
 
            return
260
 
        except Exception as e:
 
217
        except Exception, e:
261
218
            stderr.write("%s terminating on exception %s\n" % (self, e))
262
219
            raise
263
 
        self._disconnect_client()
264
 
 
265
 
    def _stop_gracefully(self):
266
 
        """When we finish this message, stop looking for more."""
267
 
        trace.mutter('Stopping %s' % (self,))
268
 
        self.finished = True
269
 
 
270
 
    def _disconnect_client(self):
271
 
        """Close the current connection. We stopped due to a timeout/etc."""
272
 
        # The default implementation is a no-op, because that is all we used to
273
 
        # do when disconnecting from a client. I suppose we never had the
274
 
        # *server* initiate a disconnect, before
275
 
 
276
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
277
 
        """Wait for more bytes to be read, but timeout if none available.
278
 
 
279
 
        This allows us to detect idle connections, and stop trying to read from
280
 
        them, without setting the socket itself to non-blocking. This also
281
 
        allows us to specify when we watch for idle timeouts.
282
 
 
283
 
        :return: Did we timeout? (True if we timed out, False if there is data
284
 
            to be read)
285
 
        """
286
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
287
220
 
288
221
    def _build_protocol(self):
289
222
        """Identifies the version of the incoming request, and returns an
294
227
 
295
228
        :returns: a SmartServerRequestProtocol.
296
229
        """
297
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
298
 
        if self.finished:
299
 
            # We're stopping, so don't try to do any more work
300
 
            return None
301
230
        bytes = self._get_line()
302
231
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
303
232
        protocol = protocol_factory(
305
234
        protocol.accept_bytes(unused_bytes)
306
235
        return protocol
307
236
 
308
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
309
 
        """select() on a file descriptor, waiting for nonblocking read()
310
 
 
311
 
        This will raise a ConnectionTimeout exception if we do not get a
312
 
        readable handle before timeout_seconds.
313
 
        :return: None
314
 
        """
315
 
        t_end = self._timer() + timeout_seconds
316
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
317
 
        rs = xs = None
318
 
        while not rs and not xs and self._timer() < t_end:
319
 
            if self.finished:
320
 
                return
321
 
            try:
322
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
323
 
            except (select.error, socket.error) as e:
324
 
                err = getattr(e, 'errno', None)
325
 
                if err is None and getattr(e, 'args', None) is not None:
326
 
                    # select.error doesn't have 'errno', it just has args[0]
327
 
                    err = e.args[0]
328
 
                if err in _bad_file_descriptor:
329
 
                    return  # Not a socket indicates read() will fail
330
 
                elif err == errno.EINTR:
331
 
                    # Interrupted, keep looping.
332
 
                    continue
333
 
                raise
334
 
            except ValueError:
335
 
                return  # Socket may already be closed
336
 
        if rs or xs:
337
 
            return
338
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
339
 
                                       % (timeout_seconds,))
340
 
 
341
237
    def _serve_one_request(self, protocol):
342
238
        """Read one request from input, process, send back a response.
343
239
 
344
240
        :param protocol: a SmartServerRequestProtocol.
345
241
        """
346
 
        if protocol is None:
347
 
            return
348
242
        try:
349
243
            self._serve_one_request_unguarded(protocol)
350
244
        except KeyboardInterrupt:
351
245
            raise
352
 
        except Exception as e:
 
246
        except Exception, e:
353
247
            self.terminate_due_to_error()
354
248
 
355
249
    def terminate_due_to_error(self):
366
260
 
367
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
368
262
 
369
 
    def __init__(self, sock, backing_transport, root_client_path='/',
370
 
                 timeout=None):
 
263
    def __init__(self, sock, backing_transport, root_client_path='/'):
371
264
        """Constructor.
372
265
 
373
266
        :param sock: the socket the server will read from.  It will be put
374
267
            into blocking mode.
375
268
        """
376
269
        SmartServerStreamMedium.__init__(
377
 
            self, backing_transport, root_client_path=root_client_path,
378
 
            timeout=timeout)
 
270
            self, backing_transport, root_client_path=root_client_path)
379
271
        sock.setblocking(True)
380
272
        self.socket = sock
381
 
        # Get the getpeername now, as we might be closed later when we care.
382
 
        try:
383
 
            self._client_info = sock.getpeername()
384
 
        except socket.error:
385
 
            self._client_info = '<unknown>'
386
 
 
387
 
    def __str__(self):
388
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
389
 
 
390
 
    def __repr__(self):
391
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
392
 
                                     self._client_info)
393
273
 
394
274
    def _serve_one_request_unguarded(self, protocol):
395
275
        while protocol.next_read_size():
396
276
            # We can safely try to read large chunks.  If there is less data
397
 
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
398
 
            # short read immediately rather than block.
399
 
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
400
 
            if bytes == b'':
 
277
            # than _MAX_READ_SIZE ready, the socket wil just return a short
 
278
            # read immediately rather than block.
 
279
            bytes = self.read_bytes(_MAX_READ_SIZE)
 
280
            if bytes == '':
401
281
                self.finished = True
402
282
                return
403
283
            protocol.accept_bytes(bytes)
404
284
 
405
285
        self._push_back(protocol.unused_data)
406
286
 
407
 
    def _disconnect_client(self):
408
 
        """Close the current connection. We stopped due to a timeout/etc."""
409
 
        self.socket.close()
410
 
 
411
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
412
 
        """Wait for more bytes to be read, but timeout if none available.
413
 
 
414
 
        This allows us to detect idle connections, and stop trying to read from
415
 
        them, without setting the socket itself to non-blocking. This also
416
 
        allows us to specify when we watch for idle timeouts.
417
 
 
418
 
        :return: None, this will raise ConnectionTimeout if we time out before
419
 
            data is available.
420
 
        """
421
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
422
 
 
423
287
    def _read_bytes(self, desired_count):
424
 
        return osutils.read_bytes_from_socket(
425
 
            self.socket, self._report_activity)
 
288
        return _read_bytes_from_socket(
 
289
            self.socket.recv, desired_count, self._report_activity)
426
290
 
427
291
    def terminate_due_to_error(self):
428
292
        # TODO: This should log to a server log file, but no such thing
431
295
        self.finished = True
432
296
 
433
297
    def _write_out(self, bytes):
434
 
        tstart = osutils.perf_counter()
435
298
        osutils.send_all(self.socket, bytes, self._report_activity)
436
 
        if 'hpss' in debug.debug_flags:
437
 
            thread_id = _thread.get_ident()
438
 
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
439
 
                         % ('wrote', thread_id, len(bytes),
440
 
                            osutils.perf_counter() - tstart))
441
299
 
442
300
 
443
301
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
444
302
 
445
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
303
    def __init__(self, in_file, out_file, backing_transport):
446
304
        """Construct new server.
447
305
 
448
306
        :param in_file: Python file from which requests can be read.
449
307
        :param out_file: Python file to write responses.
450
308
        :param backing_transport: Transport for the directory served.
451
309
        """
452
 
        SmartServerStreamMedium.__init__(self, backing_transport,
453
 
                                         timeout=timeout)
 
310
        SmartServerStreamMedium.__init__(self, backing_transport)
454
311
        if sys.platform == 'win32':
455
312
            # force binary mode for files
456
313
            import msvcrt
461
318
        self._in = in_file
462
319
        self._out = out_file
463
320
 
464
 
    def serve(self):
465
 
        """See SmartServerStreamMedium.serve"""
466
 
        # This is the regular serve, except it adds signal trapping for soft
467
 
        # shutdown.
468
 
        stop_gracefully = self._stop_gracefully
469
 
        signals.register_on_hangup(id(self), stop_gracefully)
470
 
        try:
471
 
            return super(SmartServerPipeStreamMedium, self).serve()
472
 
        finally:
473
 
            signals.unregister_on_hangup(id(self))
474
 
 
475
321
    def _serve_one_request_unguarded(self, protocol):
476
322
        while True:
477
323
            # We need to be careful not to read past the end of the current
483
329
                self._out.flush()
484
330
                return
485
331
            bytes = self.read_bytes(bytes_to_read)
486
 
            if bytes == b'':
 
332
            if bytes == '':
487
333
                # Connection has been closed.
488
334
                self.finished = True
489
335
                self._out.flush()
490
336
                return
491
337
            protocol.accept_bytes(bytes)
492
338
 
493
 
    def _disconnect_client(self):
494
 
        self._in.close()
495
 
        self._out.flush()
496
 
        self._out.close()
497
 
 
498
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
499
 
        """Wait for more bytes to be read, but timeout if none available.
500
 
 
501
 
        This allows us to detect idle connections, and stop trying to read from
502
 
        them, without setting the socket itself to non-blocking. This also
503
 
        allows us to specify when we watch for idle timeouts.
504
 
 
505
 
        :return: None, this will raise ConnectionTimeout if we time out before
506
 
            data is available.
507
 
        """
508
 
        if (getattr(self._in, 'fileno', None) is None
509
 
                or sys.platform == 'win32'):
510
 
            # You can't select() file descriptors on Windows.
511
 
            return
512
 
        try:
513
 
            return self._wait_on_descriptor(self._in, timeout_seconds)
514
 
        except io.UnsupportedOperation:
515
 
            return
516
 
 
517
339
    def _read_bytes(self, desired_count):
518
340
        return self._in.read(desired_count)
519
341
 
647
469
 
648
470
    def read_line(self):
649
471
        line = self._read_line()
650
 
        if not line.endswith(b'\n'):
 
472
        if not line.endswith('\n'):
651
473
            # end of file encountered reading from server
652
474
            raise errors.ConnectionReset(
653
475
                "Unexpected end of message. Please check connectivity "
663
485
        return self._medium._get_line()
664
486
 
665
487
 
666
 
class _VfsRefuser(object):
667
 
    """An object that refuses all VFS requests.
668
 
 
669
 
    """
670
 
 
671
 
    def __init__(self):
672
 
        client._SmartClient.hooks.install_named_hook(
673
 
            'call', self.check_vfs, 'vfs refuser')
674
 
 
675
 
    def check_vfs(self, params):
676
 
        try:
677
 
            request_method = request.request_handlers.get(params.method)
678
 
        except KeyError:
679
 
            # A method we don't know about doesn't count as a VFS method.
680
 
            return
681
 
        if issubclass(request_method, vfs.VfsRequest):
682
 
            raise HpssVfsRequestNotAllowed(params.method, params.args)
683
 
 
684
 
 
685
488
class _DebugCounter(object):
686
489
    """An object that counts the HPSS calls made to each client medium.
687
490
 
688
 
    When a medium is garbage-collected, or failing that when
689
 
    breezy.global_state exits, the total number of calls made on that medium
690
 
    are reported via trace.note.
 
491
    When a medium is garbage-collected, or failing that when atexit functions
 
492
    are run, the total number of calls made on that medium are reported via
 
493
    trace.note.
691
494
    """
692
495
 
693
496
    def __init__(self):
694
497
        self.counts = weakref.WeakKeyDictionary()
695
498
        client._SmartClient.hooks.install_named_hook(
696
499
            'call', self.increment_call_count, 'hpss call counter')
697
 
        breezy.get_global_state().exit_stack.callback(self.flush_all)
 
500
        atexit.register(self.flush_all)
698
501
 
699
502
    def track(self, medium):
700
503
        """Start tracking calls made to a medium.
734
537
        value['count'] = 0
735
538
        value['vfs_count'] = 0
736
539
        if count != 0:
737
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
738
 
                       count, vfs_count, medium_repr))
 
540
            trace.note('HPSS calls: %d (%d vfs) %s',
 
541
                       count, vfs_count, medium_repr)
739
542
 
740
543
    def flush_all(self):
741
544
        for ref in list(self.counts.keys()):
742
545
            self.done(ref)
743
546
 
744
 
 
745
547
_debug_counter = None
746
 
_vfs_refuser = None
747
548
 
748
549
 
749
550
class SmartClientMedium(SmartMedium):
766
567
            if _debug_counter is None:
767
568
                _debug_counter = _DebugCounter()
768
569
            _debug_counter.track(self)
769
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
770
 
            global _vfs_refuser
771
 
            if _vfs_refuser is None:
772
 
                _vfs_refuser = _VfsRefuser()
773
570
 
774
571
    def _is_remote_before(self, version_tuple):
775
572
        """Is it possible the remote side supports RPCs for a given version?
799
596
        :seealso: _is_remote_before
800
597
        """
801
598
        if (self._remote_version_is_before is not None and
802
 
                version_tuple > self._remote_version_is_before):
 
599
            version_tuple > self._remote_version_is_before):
803
600
            # We have been told that the remote side is older than some version
804
601
            # which is newer than a previously supplied older-than version.
805
602
            # This indicates that some smart verb call is not guarded
806
603
            # appropriately (it should simply not have been tried).
807
 
            trace.mutter(
 
604
            raise AssertionError(
808
605
                "_remember_remote_is_before(%r) called, but "
809
 
                "_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before)
810
 
            if 'hpss' in debug.debug_flags:
811
 
                ui.ui_factory.show_warning(
812
 
                    "_remember_remote_is_before(%r) called, but "
813
 
                    "_remember_remote_is_before(%r) was called previously."
814
 
                    % (version_tuple, self._remote_version_is_before))
815
 
            return
 
606
                "_remember_remote_is_before(%r) was called previously."
 
607
                % (version_tuple, self._remote_version_is_before))
816
608
        self._remote_version_is_before = version_tuple
817
609
 
818
610
    def protocol_version(self):
824
616
                medium_request = self.get_request()
825
617
                # Send a 'hello' request in protocol version one, for maximum
826
618
                # backwards compatibility.
827
 
                client_protocol = protocol.SmartClientRequestProtocolOne(
828
 
                    medium_request)
 
619
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
829
620
                client_protocol.query_version()
830
621
                self._done_hello = True
831
 
            except errors.SmartProtocolError as e:
 
622
            except errors.SmartProtocolError, e:
832
623
                # Cache the error, just like we would cache a successful
833
624
                # result.
834
625
                self._protocol_version_error = e
866
657
        """
867
658
        medium_base = urlutils.join(self.base, '/')
868
659
        rel_url = urlutils.relative_url(medium_base, transport.base)
869
 
        return urlutils.unquote(rel_url)
 
660
        return urllib.unquote(rel_url)
870
661
 
871
662
 
872
663
class SmartClientStreamMedium(SmartClientMedium):
907
698
        """
908
699
        return SmartClientStreamMediumRequest(self)
909
700
 
910
 
    def reset(self):
911
 
        """We have been disconnected, reset current state.
912
 
 
913
 
        This resets things like _current_request and connected state.
914
 
        """
915
 
        self.disconnect()
916
 
        self._current_request = None
917
 
 
918
701
 
919
702
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
920
703
    """A client medium using simple pipes.
927
710
        self._readable_pipe = readable_pipe
928
711
        self._writeable_pipe = writeable_pipe
929
712
 
930
 
    def _accept_bytes(self, data):
 
713
    def _accept_bytes(self, bytes):
931
714
        """See SmartClientStreamMedium.accept_bytes."""
932
 
        try:
933
 
            self._writeable_pipe.write(data)
934
 
        except IOError as e:
935
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
936
 
                raise errors.ConnectionReset(
937
 
                    "Error trying to write to subprocess", e)
938
 
            raise
939
 
        self._report_activity(len(data), 'write')
 
715
        self._writeable_pipe.write(bytes)
 
716
        self._report_activity(len(bytes), 'write')
940
717
 
941
718
    def _flush(self):
942
719
        """See SmartClientStreamMedium._flush()."""
943
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
944
 
        #       However, testing shows that even when the child process is
945
 
        #       gone, this doesn't error.
946
720
        self._writeable_pipe.flush()
947
721
 
948
722
    def _read_bytes(self, count):
949
723
        """See SmartClientStreamMedium._read_bytes."""
950
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
951
 
        data = self._readable_pipe.read(bytes_to_read)
952
 
        self._report_activity(len(data), 'read')
953
 
        return data
954
 
 
955
 
 
956
 
class SSHParams(object):
957
 
    """A set of parameters for starting a remote bzr via SSH."""
 
724
        bytes = self._readable_pipe.read(count)
 
725
        self._report_activity(len(bytes), 'read')
 
726
        return bytes
 
727
 
 
728
 
 
729
class SmartSSHClientMedium(SmartClientStreamMedium):
 
730
    """A client medium using SSH."""
958
731
 
959
732
    def __init__(self, host, port=None, username=None, password=None,
960
 
                 bzr_remote_path='bzr'):
961
 
        self.host = host
962
 
        self.port = port
963
 
        self.username = username
964
 
        self.password = password
965
 
        self.bzr_remote_path = bzr_remote_path
966
 
 
967
 
 
968
 
class SmartSSHClientMedium(SmartClientStreamMedium):
969
 
    """A client medium using SSH.
970
 
 
971
 
    It delegates IO to a SmartSimplePipesClientMedium or
972
 
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
973
 
    """
974
 
 
975
 
    def __init__(self, base, ssh_params, vendor=None):
 
733
            base=None, vendor=None, bzr_remote_path=None):
976
734
        """Creates a client that will connect on the first use.
977
735
 
978
 
        :param ssh_params: A SSHParams instance.
979
736
        :param vendor: An optional override for the ssh vendor to use. See
980
 
            breezy.transport.ssh for details on ssh vendors.
 
737
            bzrlib.transport.ssh for details on ssh vendors.
981
738
        """
982
 
        self._real_medium = None
983
 
        self._ssh_params = ssh_params
984
 
        # for the benefit of progress making a short description of this
985
 
        # transport
986
 
        self._scheme = 'bzr+ssh'
 
739
        self._connected = False
 
740
        self._host = host
 
741
        self._password = password
 
742
        self._port = port
 
743
        self._username = username
987
744
        # SmartClientStreamMedium stores the repr of this object in its
988
745
        # _DebugCounter so we have to store all the values used in our repr
989
746
        # method before calling the super init.
990
747
        SmartClientStreamMedium.__init__(self, base)
 
748
        self._read_from = None
 
749
        self._ssh_connection = None
991
750
        self._vendor = vendor
992
 
        self._ssh_connection = None
 
751
        self._write_to = None
 
752
        self._bzr_remote_path = bzr_remote_path
 
753
        # for the benefit of progress making a short description of this
 
754
        # transport
 
755
        self._scheme = 'bzr+ssh'
993
756
 
994
757
    def __repr__(self):
995
 
        if self._ssh_params.port is None:
996
 
            maybe_port = ''
997
 
        else:
998
 
            maybe_port = ':%s' % self._ssh_params.port
999
 
        if self._ssh_params.username is None:
1000
 
            maybe_user = ''
1001
 
        else:
1002
 
            maybe_user = '%s@' % self._ssh_params.username
1003
 
        return "%s(%s://%s%s%s/)" % (
 
758
        return "%s(connected=%r, username=%r, host=%r, port=%r)" % (
1004
759
            self.__class__.__name__,
1005
 
            self._scheme,
1006
 
            maybe_user,
1007
 
            self._ssh_params.host,
1008
 
            maybe_port)
 
760
            self._connected,
 
761
            self._username,
 
762
            self._host,
 
763
            self._port)
1009
764
 
1010
765
    def _accept_bytes(self, bytes):
1011
766
        """See SmartClientStreamMedium.accept_bytes."""
1012
767
        self._ensure_connection()
1013
 
        self._real_medium.accept_bytes(bytes)
 
768
        self._write_to.write(bytes)
 
769
        self._report_activity(len(bytes), 'write')
1014
770
 
1015
771
    def disconnect(self):
1016
772
        """See SmartClientMedium.disconnect()."""
1017
 
        if self._real_medium is not None:
1018
 
            self._real_medium.disconnect()
1019
 
            self._real_medium = None
1020
 
        if self._ssh_connection is not None:
1021
 
            self._ssh_connection.close()
1022
 
            self._ssh_connection = None
 
773
        if not self._connected:
 
774
            return
 
775
        self._read_from.close()
 
776
        self._write_to.close()
 
777
        self._ssh_connection.close()
 
778
        self._connected = False
1023
779
 
1024
780
    def _ensure_connection(self):
1025
781
        """Connect this medium if not already connected."""
1026
 
        if self._real_medium is not None:
 
782
        if self._connected:
1027
783
            return
1028
784
        if self._vendor is None:
1029
785
            vendor = ssh._get_ssh_vendor()
1030
786
        else:
1031
787
            vendor = self._vendor
1032
 
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1033
 
                                                  self._ssh_params.password, self._ssh_params.host,
1034
 
                                                  self._ssh_params.port,
1035
 
                                                  command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
1036
 
                                                           '--directory=/', '--allow-writes'])
1037
 
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1038
 
        if io_kind == 'socket':
1039
 
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1040
 
                self.base, io_object)
1041
 
        elif io_kind == 'pipes':
1042
 
            read_from, write_to = io_object
1043
 
            self._real_medium = SmartSimplePipesClientMedium(
1044
 
                read_from, write_to, self.base)
1045
 
        else:
1046
 
            raise AssertionError(
1047
 
                "Unexpected io_kind %r from %r"
1048
 
                % (io_kind, self._ssh_connection))
1049
 
        for hook in transport.Transport.hooks["post_connect"]:
1050
 
            hook(self)
 
788
        self._ssh_connection = vendor.connect_ssh(self._username,
 
789
                self._password, self._host, self._port,
 
790
                command=[self._bzr_remote_path, 'serve', '--inet',
 
791
                         '--directory=/', '--allow-writes'])
 
792
        self._read_from, self._write_to = \
 
793
            self._ssh_connection.get_filelike_channels()
 
794
        self._connected = True
1051
795
 
1052
796
    def _flush(self):
1053
797
        """See SmartClientStreamMedium._flush()."""
1054
 
        self._real_medium._flush()
 
798
        self._write_to.flush()
1055
799
 
1056
800
    def _read_bytes(self, count):
1057
801
        """See SmartClientStreamMedium.read_bytes."""
1058
 
        if self._real_medium is None:
 
802
        if not self._connected:
1059
803
            raise errors.MediumNotConnected(self)
1060
 
        return self._real_medium.read_bytes(count)
 
804
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
805
        bytes = self._read_from.read(bytes_to_read)
 
806
        self._report_activity(len(bytes), 'read')
 
807
        return bytes
1061
808
 
1062
809
 
1063
810
# Port 4155 is the default port for bzr://, registered with IANA.
1065
812
BZR_DEFAULT_PORT = 4155
1066
813
 
1067
814
 
1068
 
class SmartClientSocketMedium(SmartClientStreamMedium):
1069
 
    """A client medium using a socket.
1070
 
 
1071
 
    This class isn't usable directly.  Use one of its subclasses instead.
1072
 
    """
1073
 
 
1074
 
    def __init__(self, base):
 
815
class SmartTCPClientMedium(SmartClientStreamMedium):
 
816
    """A client medium using TCP."""
 
817
 
 
818
    def __init__(self, host, port, base):
 
819
        """Creates a client that will connect on the first use."""
1075
820
        SmartClientStreamMedium.__init__(self, base)
 
821
        self._connected = False
 
822
        self._host = host
 
823
        self._port = port
1076
824
        self._socket = None
1077
 
        self._connected = False
1078
825
 
1079
826
    def _accept_bytes(self, bytes):
1080
827
        """See SmartClientMedium.accept_bytes."""
1081
828
        self._ensure_connection()
1082
829
        osutils.send_all(self._socket, bytes, self._report_activity)
1083
830
 
1084
 
    def _ensure_connection(self):
1085
 
        """Connect this medium if not already connected."""
1086
 
        raise NotImplementedError(self._ensure_connection)
1087
 
 
1088
 
    def _flush(self):
1089
 
        """See SmartClientStreamMedium._flush().
1090
 
 
1091
 
        For sockets we do no flushing. For TCP sockets we may want to turn off
1092
 
        TCP_NODELAY and add a means to do a flush, but that can be done in the
1093
 
        future.
1094
 
        """
1095
 
 
1096
 
    def _read_bytes(self, count):
1097
 
        """See SmartClientMedium.read_bytes."""
1098
 
        if not self._connected:
1099
 
            raise errors.MediumNotConnected(self)
1100
 
        return osutils.read_bytes_from_socket(
1101
 
            self._socket, self._report_activity)
1102
 
 
1103
831
    def disconnect(self):
1104
832
        """See SmartClientMedium.disconnect()."""
1105
833
        if not self._connected:
1108
836
        self._socket = None
1109
837
        self._connected = False
1110
838
 
1111
 
 
1112
 
class SmartTCPClientMedium(SmartClientSocketMedium):
1113
 
    """A client medium that creates a TCP connection."""
1114
 
 
1115
 
    def __init__(self, host, port, base):
1116
 
        """Creates a client that will connect on the first use."""
1117
 
        SmartClientSocketMedium.__init__(self, base)
1118
 
        self._host = host
1119
 
        self._port = port
1120
 
 
1121
839
    def _ensure_connection(self):
1122
840
        """Connect this medium if not already connected."""
1123
841
        if self._connected:
1128
846
            port = int(self._port)
1129
847
        try:
1130
848
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1131
 
                                           socket.SOCK_STREAM, 0, 0)
1132
 
        except socket.gaierror as xxx_todo_changeme:
1133
 
            (err_num, err_msg) = xxx_todo_changeme.args
 
849
                socket.SOCK_STREAM, 0, 0)
 
850
        except socket.gaierror, (err_num, err_msg):
1134
851
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1135
 
                                         (self._host, port, err_msg))
 
852
                    (self._host, port, err_msg))
1136
853
        # Initialize err in case there are no addresses returned:
1137
 
        last_err = socket.error("no address found for %s" % self._host)
 
854
        err = socket.error("no address found for %s" % self._host)
1138
855
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1139
856
            try:
1140
857
                self._socket = socket.socket(family, socktype, proto)
1141
858
                self._socket.setsockopt(socket.IPPROTO_TCP,
1142
859
                                        socket.TCP_NODELAY, 1)
1143
860
                self._socket.connect(sockaddr)
1144
 
            except socket.error as err:
 
861
            except socket.error, err:
1145
862
                if self._socket is not None:
1146
863
                    self._socket.close()
1147
864
                self._socket = None
1148
 
                last_err = err
1149
865
                continue
1150
866
            break
1151
867
        if self._socket is None:
1152
868
            # socket errors either have a (string) or (errno, string) as their
1153
869
            # args.
1154
 
            if isinstance(last_err.args, str):
1155
 
                err_msg = last_err.args
 
870
            if type(err.args) is str:
 
871
                err_msg = err.args
1156
872
            else:
1157
 
                err_msg = last_err.args[1]
 
873
                err_msg = err.args[1]
1158
874
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1159
 
                                         (self._host, port, err_msg))
1160
 
        self._connected = True
1161
 
        for hook in transport.Transport.hooks["post_connect"]:
1162
 
            hook(self)
1163
 
 
1164
 
 
1165
 
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1166
 
    """A client medium for an already connected socket.
1167
 
 
1168
 
    Note that this class will assume it "owns" the socket, so it will close it
1169
 
    when its disconnect method is called.
1170
 
    """
1171
 
 
1172
 
    def __init__(self, base, sock):
1173
 
        SmartClientSocketMedium.__init__(self, base)
1174
 
        self._socket = sock
1175
 
        self._connected = True
1176
 
 
1177
 
    def _ensure_connection(self):
1178
 
        # Already connected, by definition!  So nothing to do.
1179
 
        pass
 
875
                    (self._host, port, err_msg))
 
876
        self._connected = True
 
877
 
 
878
    def _flush(self):
 
879
        """See SmartClientStreamMedium._flush().
 
880
 
 
881
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
 
882
        add a means to do a flush, but that can be done in the future.
 
883
        """
 
884
 
 
885
    def _read_bytes(self, count):
 
886
        """See SmartClientMedium.read_bytes."""
 
887
        if not self._connected:
 
888
            raise errors.MediumNotConnected(self)
 
889
        return _read_bytes_from_socket(
 
890
            self._socket.recv, count, self._report_activity)
1180
891
 
1181
892
 
1182
893
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1217
928
        This invokes self._medium._flush to ensure all bytes are transmitted.
1218
929
        """
1219
930
        self._medium._flush()
 
931
 
 
932
 
 
933
def _read_bytes_from_socket(sock, desired_count, report_activity):
 
934
    # We ignore the desired_count because on sockets it's more efficient to
 
935
    # read large chunks (of _MAX_READ_SIZE bytes) at a time.
 
936
    try:
 
937
        bytes = osutils.until_no_eintr(sock, _MAX_READ_SIZE)
 
938
    except socket.error, e:
 
939
        if len(e.args) and e.args[0] in (errno.ECONNRESET, 10054):
 
940
            # The connection was closed by the other side.  Callers expect an
 
941
            # empty string to signal end-of-stream.
 
942
            bytes = ''
 
943
        else:
 
944
            raise
 
945
    else:
 
946
        report_activity(len(bytes), 'read')
 
947
    return bytes
 
948