24
24
bzrlib/transport/smart/__init__.py.
31
from bzrlib import errors
32
from bzrlib.smart.protocol import (
34
SmartServerRequestProtocolOne,
35
SmartServerRequestProtocolTwo,
33
from bzrlib.lazy_import import lazy_import
34
lazy_import(globals(), """
39
from bzrlib.transport import ssh
40
except errors.ParamikoNotPresent:
41
# no paramiko. SmartSSHClientMedium will break.
45
class SmartServerStreamMedium(object):
46
from bzrlib.smart import client, protocol, request, vfs
47
from bzrlib.transport import ssh
51
# We must not read any more than 64k at a time so we don't risk "no buffer
52
# space available" errors on some platforms. Windows in particular is likely
53
# to give error 10053 or 10055 if we read more than 64k from a socket.
54
_MAX_READ_SIZE = 64 * 1024
57
def _get_protocol_factory_for_bytes(bytes):
58
"""Determine the right protocol factory for 'bytes'.
60
This will return an appropriate protocol factory depending on the version
61
of the protocol being used, as determined by inspecting the given bytes.
62
The bytes should have at least one newline byte (i.e. be a whole line),
63
otherwise it's possible that a request will be incorrectly identified as
66
Typical use would be::
68
factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
69
server_protocol = factory(transport, write_func, root_client_path)
70
server_protocol.accept_bytes(unused_bytes)
72
:param bytes: a str of bytes of the start of the request.
73
:returns: 2-tuple of (protocol_factory, unused_bytes). protocol_factory is
74
a callable that takes three args: transport, write_func,
75
root_client_path. unused_bytes are any bytes that were not part of a
76
protocol version marker.
78
if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
79
protocol_factory = protocol.build_server_protocol_three
80
bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
81
elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
82
protocol_factory = protocol.SmartServerRequestProtocolTwo
83
bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
85
protocol_factory = protocol.SmartServerRequestProtocolOne
86
return protocol_factory, bytes
89
def _get_line(read_bytes_func):
90
"""Read bytes using read_bytes_func until a newline byte.
92
This isn't particularly efficient, so should only be used when the
93
expected size of the line is quite short.
95
:returns: a tuple of two strs: (line, excess)
99
while newline_pos == -1:
100
new_bytes = read_bytes_func(1)
103
# Ran out of bytes before receiving a complete line.
105
newline_pos = bytes.find('\n')
106
line = bytes[:newline_pos+1]
107
excess = bytes[newline_pos+1:]
111
class SmartMedium(object):
112
"""Base class for smart protocol media, both client- and server-side."""
115
self._push_back_buffer = None
117
def _push_back(self, bytes):
118
"""Return unused bytes to the medium, because they belong to the next
121
This sets the _push_back_buffer to the given bytes.
123
if self._push_back_buffer is not None:
124
raise AssertionError(
125
"_push_back called when self._push_back_buffer is %r"
126
% (self._push_back_buffer,))
129
self._push_back_buffer = bytes
131
def _get_push_back_buffer(self):
132
if self._push_back_buffer == '':
133
raise AssertionError(
134
'%s._push_back_buffer should never be the empty string, '
135
'which can be confused with EOF' % (self,))
136
bytes = self._push_back_buffer
137
self._push_back_buffer = None
140
def read_bytes(self, desired_count):
141
"""Read some bytes from this medium.
143
:returns: some bytes, possibly more or less than the number requested
144
in 'desired_count' depending on the medium.
146
if self._push_back_buffer is not None:
147
return self._get_push_back_buffer()
148
bytes_to_read = min(desired_count, _MAX_READ_SIZE)
149
return self._read_bytes(bytes_to_read)
151
def _read_bytes(self, count):
152
raise NotImplementedError(self._read_bytes)
155
"""Read bytes from this request's response until a newline byte.
157
This isn't particularly efficient, so should only be used when the
158
expected size of the line is quite short.
160
:returns: a string of bytes ending in a newline (byte 0x0A).
162
line, excess = _get_line(self.read_bytes)
163
self._push_back(excess)
166
def _report_activity(self, bytes, direction):
167
"""Notify that this medium has activity.
169
Implementations should call this from all methods that actually do IO.
170
Be careful that it's not called twice, if one method is implemented on
173
:param bytes: Number of bytes read or written.
174
:param direction: 'read' or 'write' or None.
176
ui.ui_factory.report_transport_activity(self, bytes, direction)
179
class SmartServerStreamMedium(SmartMedium):
46
180
"""Handles smart commands coming over a stream.
48
182
The stream may be a pipe connected to sshd, or a tcp socket, or an
113
250
"""Called when an unhandled exception from the protocol occurs."""
114
251
raise NotImplementedError(self.terminate_due_to_error)
116
def _get_bytes(self, desired_count):
253
def _read_bytes(self, desired_count):
117
254
"""Get some bytes from the medium.
119
256
:param desired_count: number of bytes we want to read.
121
raise NotImplementedError(self._get_bytes)
124
"""Read bytes from this request's response until a newline byte.
126
This isn't particularly efficient, so should only be used when the
127
expected size of the line is quite short.
129
:returns: a string of bytes ending in a newline (byte 0x0A).
131
# XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
133
while not line or line[-1] != '\n':
134
new_char = self._get_bytes(1)
137
# Ran out of bytes before receiving a complete line.
258
raise NotImplementedError(self._read_bytes)
142
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
144
def __init__(self, sock, backing_transport):
263
def __init__(self, sock, backing_transport, root_client_path='/'):
147
266
:param sock: the socket the server will read from. It will be put
148
267
into blocking mode.
150
SmartServerStreamMedium.__init__(self, backing_transport)
269
SmartServerStreamMedium.__init__(
270
self, backing_transport, root_client_path=root_client_path)
152
271
sock.setblocking(True)
153
272
self.socket = sock
155
274
def _serve_one_request_unguarded(self, protocol):
156
275
while protocol.next_read_size():
158
protocol.accept_bytes(self.push_back)
161
bytes = self._get_bytes(4096)
165
protocol.accept_bytes(bytes)
167
self.push_back = protocol.excess_buffer
169
def _get_bytes(self, desired_count):
276
# We can safely try to read large chunks. If there is less data
277
# than _MAX_READ_SIZE ready, the socket will just return a short
278
# read immediately rather than block.
279
bytes = self.read_bytes(_MAX_READ_SIZE)
283
protocol.accept_bytes(bytes)
285
self._push_back(protocol.unused_data)
287
def _read_bytes(self, desired_count):
170
288
# We ignore the desired_count because on sockets it's more efficient to
172
return self.socket.recv(4096)
289
# read large chunks (of _MAX_READ_SIZE bytes) at a time.
290
bytes = osutils.until_no_eintr(self.socket.recv, _MAX_READ_SIZE)
291
self._report_activity(len(bytes), 'read')
174
294
def terminate_due_to_error(self):
175
"""Called when an unhandled exception from the protocol occurs."""
176
295
# TODO: This should log to a server log file, but no such thing
177
296
# exists yet. Andrew Bennetts 2006-09-29.
178
297
self.socket.close()
179
298
self.finished = True
181
300
def _write_out(self, bytes):
182
self.socket.sendall(bytes)
301
osutils.send_all(self.socket, bytes, self._report_activity)
185
304
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
337
459
return self._read_bytes(count)
339
461
def _read_bytes(self, count):
340
"""Helper for read_bytes.
462
"""Helper for SmartClientMediumRequest.read_bytes.
342
464
read_bytes checks the state of the request to determine if bytes
343
465
should be read. After that it hands off to _read_bytes to do the
468
By default this forwards to self._medium.read_bytes because we are
469
operating on the medium's stream.
346
raise NotImplementedError(self._read_bytes)
471
return self._medium.read_bytes(count)
348
473
def read_line(self):
349
"""Read bytes from this request's response until a newline byte.
351
This isn't particularly efficient, so should only be used when the
352
expected size of the line is quite short.
354
:returns: a string of bytes ending in a newline (byte 0x0A).
356
# XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
358
while not line or line[-1] != '\n':
359
new_char = self.read_bytes(1)
362
raise errors.SmartProtocolError(
363
'unexpected end of file reading from server')
474
line = self._read_line()
475
if not line.endswith('\n'):
476
# end of file encountered reading from server
477
raise errors.ConnectionReset(
478
"please check connectivity and permissions")
367
class SmartClientMedium(object):
481
def _read_line(self):
482
"""Helper for SmartClientMediumRequest.read_line.
484
By default this forwards to self._medium._get_line because we are
485
operating on the medium's stream.
487
return self._medium._get_line()
490
class _DebugCounter(object):
491
"""An object that counts the HPSS calls made to each client medium.
493
When a medium is garbage-collected, or failing that when atexit functions
494
are run, the total number of calls made on that medium are reported via
499
self.counts = weakref.WeakKeyDictionary()
500
client._SmartClient.hooks.install_named_hook(
501
'call', self.increment_call_count, 'hpss call counter')
502
atexit.register(self.flush_all)
504
def track(self, medium):
505
"""Start tracking calls made to a medium.
507
This only keeps a weakref to the medium, so shouldn't affect the
510
medium_repr = repr(medium)
511
# Add this medium to the WeakKeyDictionary
512
self.counts[medium] = dict(count=0, vfs_count=0,
513
medium_repr=medium_repr)
514
# Weakref callbacks are fired in reverse order of their association
515
# with the referenced object. So we add a weakref *after* adding to
516
# the WeakKeyDict so that we can report the value from it before the
517
# entry is removed by the WeakKeyDict's own callback.
518
ref = weakref.ref(medium, self.done)
520
def increment_call_count(self, params):
521
# Increment the count in the WeakKeyDictionary
522
value = self.counts[params.medium]
524
request_method = request.request_handlers.get(params.method)
525
if issubclass(request_method, vfs.VfsRequest):
526
value['vfs_count'] += 1
529
value = self.counts[ref]
530
count, vfs_count, medium_repr = (
531
value['count'], value['vfs_count'], value['medium_repr'])
532
# In case this callback is invoked for the same ref twice (by the
533
# weakref callback and by the atexit function), set the call count back
534
# to 0 so this item won't be reported twice.
536
value['vfs_count'] = 0
538
trace.note('HPSS calls: %d (%d vfs) %s',
539
count, vfs_count, medium_repr)
542
for ref in list(self.counts.keys()):
545
_debug_counter = None
548
class SmartClientMedium(SmartMedium):
368
549
"""Smart client is a medium for sending smart protocol requests over."""
551
def __init__(self, base):
552
super(SmartClientMedium, self).__init__()
554
self._protocol_version_error = None
555
self._protocol_version = None
556
self._done_hello = False
557
# Be optimistic: we assume the remote end can accept new remote
558
# requests until we get an error saying otherwise.
559
# _remote_version_is_before is the bzr version the remote side is known
559
# to be older than, based on what we've seen so far (None: no limit yet).
561
self._remote_version_is_before = None
562
# Install debug hook function if debug flag is set.
563
if 'hpss' in debug.debug_flags:
564
global _debug_counter
565
if _debug_counter is None:
566
_debug_counter = _DebugCounter()
567
_debug_counter.track(self)
569
def _is_remote_before(self, version_tuple):
570
"""Is it possible the remote side supports RPCs for a given version?
574
needed_version = (1, 2)
575
if medium._is_remote_before(needed_version):
576
fallback_to_pre_1_2_rpc()
580
except UnknownSmartMethod:
581
medium._remember_remote_is_before(needed_version)
582
fallback_to_pre_1_2_rpc()
584
:seealso: _remember_remote_is_before
586
if self._remote_version_is_before is None:
587
# So far, the remote side seems to support everything
589
return version_tuple >= self._remote_version_is_before
591
def _remember_remote_is_before(self, version_tuple):
592
"""Tell this medium that the remote side is older than the given version.
594
:seealso: _is_remote_before
596
if (self._remote_version_is_before is not None and
597
version_tuple > self._remote_version_is_before):
598
# We have been told that the remote side is older than some version
599
# which is newer than a previously supplied older-than version.
600
# This indicates that some smart verb call is not guarded
601
# appropriately (it should simply not have been tried).
602
raise AssertionError(
603
"_remember_remote_is_before(%r) called, but "
604
"_remember_remote_is_before(%r) was called previously."
605
% (version_tuple, self._remote_version_is_before))
606
self._remote_version_is_before = version_tuple
608
def protocol_version(self):
609
"""Find out if 'hello' smart request works."""
610
if self._protocol_version_error is not None:
611
raise self._protocol_version_error
612
if not self._done_hello:
614
medium_request = self.get_request()
615
# Send a 'hello' request in protocol version one, for maximum
616
# backwards compatibility.
617
client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
618
client_protocol.query_version()
619
self._done_hello = True
620
except errors.SmartProtocolError, e:
621
# Cache the error, just like we would cache a successful
623
self._protocol_version_error = e
627
def should_probe(self):
628
"""Should RemoteBzrDirFormat.probe_transport send a smart request on
631
Some transports are unambiguously smart-only; there's no need to check
632
if the transport is able to carry smart requests, because that's all
633
it is for. In those cases, this method should return False.
635
But some HTTP transports can sometimes fail to carry smart requests,
636
but still be usable for accessing remote bzrdirs via plain file
637
accesses. So for those transports, their media should return True here
638
so that RemoteBzrDirFormat can determine if it is appropriate for that
370
643
def disconnect(self):
371
644
"""If this medium maintains a persistent connection, close it.
373
646
The default implementation does nothing.
649
def remote_path_from_transport(self, transport):
650
"""Convert transport into a path suitable for using in a request.
652
Note that the resulting remote path doesn't encode the host name or
653
anything but path, so it is only safe to use it in requests sent over
654
the medium from the matching transport.
656
medium_base = urlutils.join(self.base, '/')
657
rel_url = urlutils.relative_url(medium_base, transport.base)
658
return urllib.unquote(rel_url)
377
661
class SmartClientStreamMedium(SmartClientMedium):
378
662
"""Stream based medium common class.
437
720
def _read_bytes(self, count):
438
721
"""See SmartClientStreamMedium._read_bytes."""
439
return self._readable_pipe.read(count)
722
bytes = self._readable_pipe.read(count)
723
self._report_activity(len(bytes), 'read')
442
727
class SmartSSHClientMedium(SmartClientStreamMedium):
443
728
"""A client medium using SSH."""
445
730
def __init__(self, host, port=None, username=None, password=None,
731
base=None, vendor=None, bzr_remote_path=None):
447
732
"""Creates a client that will connect on the first use.
449
734
:param vendor: An optional override for the ssh vendor to use. See
450
735
bzrlib.transport.ssh for details on ssh vendors.
452
SmartClientStreamMedium.__init__(self)
453
737
self._connected = False
454
738
self._host = host
455
739
self._password = password
456
740
self._port = port
457
741
self._username = username
742
# SmartClientStreamMedium stores the repr of this object in its
743
# _DebugCounter so we have to store all the values used in our repr
744
# method before calling the super init.
745
SmartClientStreamMedium.__init__(self, base)
458
746
self._read_from = None
459
747
self._ssh_connection = None
460
748
self._vendor = vendor
461
749
self._write_to = None
750
self._bzr_remote_path = bzr_remote_path
751
# for the benefit of progress making a short description of this
753
self._scheme = 'bzr+ssh'
756
return "%s(connected=%r, username=%r, host=%r, port=%r)" % (
757
self.__class__.__name__,
463
763
def _accept_bytes(self, bytes):
464
764
"""See SmartClientStreamMedium.accept_bytes."""
465
765
self._ensure_connection()
466
766
self._write_to.write(bytes)
767
self._report_activity(len(bytes), 'write')
468
769
def disconnect(self):
469
770
"""See SmartClientMedium.disconnect()."""
530
838
"""Connect this medium if not already connected."""
531
839
if self._connected:
533
self._socket = socket.socket()
534
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
535
result = self._socket.connect_ex((self._host, int(self._port)))
841
if self._port is None:
842
port = BZR_DEFAULT_PORT
844
port = int(self._port)
846
sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
847
socket.SOCK_STREAM, 0, 0)
848
except socket.gaierror, (err_num, err_msg):
849
raise errors.ConnectionError("failed to lookup %s:%d: %s" %
850
(self._host, port, err_msg))
851
# Initialize err in case there are no addresses returned:
852
err = socket.error("no address found for %s" % self._host)
853
for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
855
self._socket = socket.socket(family, socktype, proto)
856
self._socket.setsockopt(socket.IPPROTO_TCP,
857
socket.TCP_NODELAY, 1)
858
self._socket.connect(sockaddr)
859
except socket.error, err:
860
if self._socket is not None:
865
if self._socket is None:
866
# socket errors either have a (string) or (errno, string) as their
868
if type(err.args) is str:
871
err_msg = err.args[1]
537
872
raise errors.ConnectionError("failed to connect to %s:%d: %s" %
538
(self._host, self._port, os.strerror(result)))
873
(self._host, port, err_msg))
539
874
self._connected = True
541
876
def _flush(self):
542
877
"""See SmartClientStreamMedium._flush().
544
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
879
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
545
880
add a means to do a flush, but that can be done in the future.