1
# Copyright (C) 2006-2011 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
"""The 'medium' layer for the smart servers and clients.
19
"Medium" here is the noun meaning "a means of transmission", not the adjective
20
for "the quality between big and small."
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
over SSH), and pass them to and from the protocol logic. See the overview in
24
breezy/transport/smart/__init__.py.
27
from __future__ import absolute_import
36
from ...lazy_import import lazy_import
37
lazy_import(globals(), """
50
from breezy.i18n import gettext
51
from breezy.bzr.smart import client, protocol, request, signals, vfs
52
from breezy.transport import ssh
59
# Throughout this module buffer size parameters are either limited to be at
60
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
61
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
62
# from non-sockets as well.
63
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
66
class HpssVfsRequestNotAllowed(errors.BzrError):
    """A VFS request was attempted while VFS requests are not allowed.

    :ivar method: name of the request method that was encountered.
    :ivar arguments: arguments that accompanied the request.
    """

    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
            "%(method)s, %(arguments)s.")

    def __init__(self, method, arguments):
        # Both names are interpolated by _fmt, so both must be stored on the
        # instance.
        self.method = method
        self.arguments = arguments
76
def _get_protocol_factory_for_bytes(bytes):
    """Determine the right protocol factory for 'bytes'.

    This will return an appropriate protocol factory depending on the version
    of the protocol being used, as determined by inspecting the given bytes.
    The bytes should have at least one newline byte (i.e. be a whole line),
    otherwise it's possible that a request will be incorrectly identified as
    version 1.

    Typical use would be::

        factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
        server_protocol = factory(transport, write_func, root_client_path)
        server_protocol.accept_bytes(unused_bytes)

    :param bytes: a str of bytes of the start of the request.
    :returns: 2-tuple of (protocol_factory, unused_bytes).  protocol_factory is
        a callable that takes three args: transport, write_func,
        root_client_path.  unused_bytes are any bytes that were not part of a
        protocol version marker.
    """
    if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
        protocol_factory = protocol.build_server_protocol_three
        bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
    elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
        protocol_factory = protocol.SmartServerRequestProtocolTwo
        bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
    else:
        # No recognised version marker: fall back to protocol one and hand
        # every byte back to the caller untouched.
        protocol_factory = protocol.SmartServerRequestProtocolOne
    return protocol_factory, bytes
108
def _get_line(read_bytes_func):
109
"""Read bytes using read_bytes_func until a newline byte.
111
This isn't particularly efficient, so should only be used when the
112
expected size of the line is quite short.
114
:returns: a tuple of two strs: (line, excess)
118
while newline_pos == -1:
119
new_bytes = read_bytes_func(1)
122
# Ran out of bytes before receiving a complete line.
124
newline_pos = bytes.find(b'\n')
125
line = bytes[:newline_pos + 1]
126
excess = bytes[newline_pos + 1:]
130
class SmartMedium(object):
    """Base class for smart protocol media, both client- and server-side."""

    def __init__(self):
        # Bytes that were read but belong to a later message, or None.
        self._push_back_buffer = None

    def _push_back(self, data):
        """Return unused bytes to the medium, because they belong to the next
        request(s).

        This sets the _push_back_buffer to the given bytes.

        :param data: the unused bytes; must be a bytes object.
        :raises TypeError: if data is not bytes.
        :raises AssertionError: if bytes are already pushed back.
        """
        if not isinstance(data, bytes):
            raise TypeError(data)
        if self._push_back_buffer is not None:
            raise AssertionError(
                "_push_back called when self._push_back_buffer is %r"
                % (self._push_back_buffer,))
        if data == b'':
            # An empty buffer could be confused with EOF (see
            # _get_push_back_buffer), so never store one.
            return
        self._push_back_buffer = data

    def _get_push_back_buffer(self):
        """Take and return the pushed-back bytes, clearing the buffer."""
        if self._push_back_buffer == b'':
            raise AssertionError(
                '%s._push_back_buffer should never be the empty string, '
                'which can be confused with EOF' % (self,))
        bytes = self._push_back_buffer
        self._push_back_buffer = None
        return bytes

    def read_bytes(self, desired_count):
        """Read some bytes from this medium.

        :returns: some bytes, possibly more or less than the number requested
            in 'desired_count' depending on the medium.
        """
        if self._push_back_buffer is not None:
            # Pushed-back bytes are returned whole, regardless of
            # desired_count.
            return self._get_push_back_buffer()
        bytes_to_read = min(desired_count, _MAX_READ_SIZE)
        return self._read_bytes(bytes_to_read)

    def _read_bytes(self, count):
        """Read up to count bytes from the underlying channel (abstract)."""
        raise NotImplementedError(self._read_bytes)

    def _get_line(self):
        """Read bytes from this request's response until a newline byte.

        This isn't particularly efficient, so should only be used when the
        expected size of the line is quite short.

        :returns: a string of bytes ending in a newline (byte 0x0A).
        """
        # The bare name resolves to the module-level _get_line helper, not to
        # this method.
        line, excess = _get_line(self.read_bytes)
        self._push_back(excess)
        return line

    def _report_activity(self, bytes, direction):
        """Notify that this medium has activity.

        Implementations should call this from all methods that actually do IO.
        Be careful that it's not called twice, if one method is implemented on
        top of another.

        :param bytes: Number of bytes read or written.
        :param direction: 'read' or 'write' or None.
        """
        ui.ui_factory.report_transport_activity(self, bytes, direction)
200
_bad_file_descriptor = (errno.EBADF,)
201
if sys.platform == 'win32':
202
# Given on Windows if you pass a closed socket to select.select. Probably
203
# also given if you pass a file handle to select.
205
_bad_file_descriptor += (WSAENOTSOCK,)
208
class SmartServerStreamMedium(SmartMedium):
209
"""Handles smart commands coming over a stream.
211
The stream may be a pipe connected to sshd, or a tcp socket, or an
212
in-process fifo for testing.
214
One instance is created for each connected client; it can serve multiple
215
requests in the lifetime of the connection.
217
The server passes requests through to an underlying backing transport,
218
which will typically be a LocalTransport looking at the server's filesystem.
220
:ivar _push_back_buffer: a str of bytes that have been read from the stream
221
but not used yet, or None if there are no buffered bytes. Subclasses
222
should make sure to exhaust this buffer before reading more bytes from
223
the stream. See also the _push_back method.
228
def __init__(self, backing_transport, root_client_path='/', timeout=None):
229
"""Construct new server.
231
:param backing_transport: Transport for the directory served.
233
# backing_transport could be passed to serve instead of __init__
234
self.backing_transport = backing_transport
235
self.root_client_path = root_client_path
236
self.finished = False
238
raise AssertionError('You must supply a timeout.')
239
self._client_timeout = timeout
240
self._client_poll_timeout = min(timeout / 10.0, 1.0)
241
SmartMedium.__init__(self)
244
"""Serve requests until the client disconnects."""
245
# Keep a reference to stderr because the sys module's globals get set to
246
# None during interpreter shutdown.
247
from sys import stderr
249
while not self.finished:
250
server_protocol = self._build_protocol()
251
self._serve_one_request(server_protocol)
252
except errors.ConnectionTimeout as e:
253
trace.note('%s' % (e,))
254
trace.log_exception_quietly()
255
self._disconnect_client()
256
# We reported it, no reason to make a big fuss.
258
except Exception as e:
259
stderr.write("%s terminating on exception %s\n" % (self, e))
261
self._disconnect_client()
263
def _stop_gracefully(self):
264
"""When we finish this message, stop looking for more."""
265
trace.mutter('Stopping %s' % (self,))
268
def _disconnect_client(self):
269
"""Close the current connection. We stopped due to a timeout/etc."""
270
# The default implementation is a no-op, because that is all we used to
271
# do when disconnecting from a client. I suppose we never had the
272
# *server* initiate a disconnect, before
274
def _wait_for_bytes_with_timeout(self, timeout_seconds):
275
"""Wait for more bytes to be read, but timeout if none available.
277
This allows us to detect idle connections, and stop trying to read from
278
them, without setting the socket itself to non-blocking. This also
279
allows us to specify when we watch for idle timeouts.
281
:return: Did we timeout? (True if we timed out, False if there is data
284
raise NotImplementedError(self._wait_for_bytes_with_timeout)
286
def _build_protocol(self):
287
"""Identifies the version of the incoming request, and returns
288
a protocol object that can interpret it.
290
If more bytes than the version prefix of the request are read, they will
291
be fed into the protocol before it is returned.
293
:returns: a SmartServerRequestProtocol.
295
self._wait_for_bytes_with_timeout(self._client_timeout)
297
# We're stopping, so don't try to do any more work
299
bytes = self._get_line()
300
protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
301
protocol = protocol_factory(
302
self.backing_transport, self._write_out, self.root_client_path)
303
protocol.accept_bytes(unused_bytes)
306
def _wait_on_descriptor(self, fd, timeout_seconds):
307
"""select() on a file descriptor, waiting for nonblocking read()
309
This will raise a ConnectionTimeout exception if we do not get a
310
readable handle before timeout_seconds.
313
t_end = self._timer() + timeout_seconds
314
poll_timeout = min(timeout_seconds, self._client_poll_timeout)
316
while not rs and not xs and self._timer() < t_end:
320
rs, _, xs = select.select([fd], [], [fd], poll_timeout)
321
except (select.error, socket.error) as e:
322
err = getattr(e, 'errno', None)
323
if err is None and getattr(e, 'args', None) is not None:
324
# select.error doesn't have 'errno', it just has args[0]
326
if err in _bad_file_descriptor:
327
return # Not a socket indicates read() will fail
328
elif err == errno.EINTR:
329
# Interrupted, keep looping.
333
return # Socket may already be closed
336
raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
337
% (timeout_seconds,))
339
def _serve_one_request(self, protocol):
340
"""Read one request from input, process, send back a response.
342
:param protocol: a SmartServerRequestProtocol.
347
self._serve_one_request_unguarded(protocol)
348
except KeyboardInterrupt:
350
except Exception as e:
351
self.terminate_due_to_error()
353
def terminate_due_to_error(self):
354
"""Called when an unhandled exception from the protocol occurs."""
355
raise NotImplementedError(self.terminate_due_to_error)
357
def _read_bytes(self, desired_count):
358
"""Get some bytes from the medium.
360
:param desired_count: number of bytes we want to read.
362
raise NotImplementedError(self._read_bytes)
365
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
367
def __init__(self, sock, backing_transport, root_client_path='/',
371
:param sock: the socket the server will read from. It will be put
374
SmartServerStreamMedium.__init__(
375
self, backing_transport, root_client_path=root_client_path,
377
sock.setblocking(True)
379
# Get the getpeername now, as we might be closed later when we care.
381
self._client_info = sock.getpeername()
383
self._client_info = '<unknown>'
386
return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
389
return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
392
def _serve_one_request_unguarded(self, protocol):
393
while protocol.next_read_size():
394
# We can safely try to read large chunks. If there is less data
395
# than MAX_SOCKET_CHUNK ready, the socket will just return a
396
# short read immediately rather than block.
397
bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
401
protocol.accept_bytes(bytes)
403
self._push_back(protocol.unused_data)
405
def _disconnect_client(self):
406
"""Close the current connection. We stopped due to a timeout/etc."""
409
def _wait_for_bytes_with_timeout(self, timeout_seconds):
410
"""Wait for more bytes to be read, but timeout if none available.
412
This allows us to detect idle connections, and stop trying to read from
413
them, without setting the socket itself to non-blocking. This also
414
allows us to specify when we watch for idle timeouts.
416
:return: None, this will raise ConnectionTimeout if we time out before
419
return self._wait_on_descriptor(self.socket, timeout_seconds)
421
def _read_bytes(self, desired_count):
422
return osutils.read_bytes_from_socket(
423
self.socket, self._report_activity)
425
def terminate_due_to_error(self):
426
# TODO: This should log to a server log file, but no such thing
427
# exists yet. Andrew Bennetts 2006-09-29.
431
def _write_out(self, bytes):
432
tstart = osutils.timer_func()
433
osutils.send_all(self.socket, bytes, self._report_activity)
434
if 'hpss' in debug.debug_flags:
435
thread_id = thread.get_ident()
436
trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
437
% ('wrote', thread_id, len(bytes),
438
osutils.timer_func() - tstart))
441
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
443
def __init__(self, in_file, out_file, backing_transport, timeout=None):
444
"""Construct new server.
446
:param in_file: Python file from which requests can be read.
447
:param out_file: Python file to write responses.
448
:param backing_transport: Transport for the directory served.
450
SmartServerStreamMedium.__init__(self, backing_transport,
452
if sys.platform == 'win32':
453
# force binary mode for files
455
for f in (in_file, out_file):
456
fileno = getattr(f, 'fileno', None)
458
msvcrt.setmode(fileno(), os.O_BINARY)
463
"""See SmartServerStreamMedium.serve"""
464
# This is the regular serve, except it adds signal trapping for soft
466
stop_gracefully = self._stop_gracefully
467
signals.register_on_hangup(id(self), stop_gracefully)
469
return super(SmartServerPipeStreamMedium, self).serve()
471
signals.unregister_on_hangup(id(self))
473
def _serve_one_request_unguarded(self, protocol):
475
# We need to be careful not to read past the end of the current
476
# request, or else the read from the pipe will block, so we use
477
# protocol.next_read_size().
478
bytes_to_read = protocol.next_read_size()
479
if bytes_to_read == 0:
480
# Finished serving this request.
483
bytes = self.read_bytes(bytes_to_read)
485
# Connection has been closed.
489
protocol.accept_bytes(bytes)
491
def _disconnect_client(self):
496
def _wait_for_bytes_with_timeout(self, timeout_seconds):
497
"""Wait for more bytes to be read, but timeout if none available.
499
This allows us to detect idle connections, and stop trying to read from
500
them, without setting the socket itself to non-blocking. This also
501
allows us to specify when we watch for idle timeouts.
503
:return: None, this will raise ConnectionTimeout if we time out before
506
if (getattr(self._in, 'fileno', None) is None
507
or sys.platform == 'win32'):
508
# You can't select() file descriptors on Windows.
511
return self._wait_on_descriptor(self._in, timeout_seconds)
512
except io.UnsupportedOperation:
515
def _read_bytes(self, desired_count):
516
return self._in.read(desired_count)
518
def terminate_due_to_error(self):
519
# TODO: This should log to a server log file, but no such thing
520
# exists yet. Andrew Bennetts 2006-09-29.
524
def _write_out(self, bytes):
525
self._out.write(bytes)
528
class SmartClientMediumRequest(object):
    """A request on a SmartClientMedium.

    Each request allows bytes to be provided to it via accept_bytes, and then
    the response bytes to be read via read_bytes.

    For instance::

        request.accept_bytes(b'123')
        request.finished_writing()
        result = request.read_bytes(3)
        request.finished_reading()

    It is up to the individual SmartClientMedium whether multiple concurrent
    requests can exist. See SmartClientMedium.get_request to obtain instances
    of SmartClientMediumRequest, and the concrete Medium you are using for
    details on concurrency and pipelining.
    """

    def __init__(self, medium):
        """Construct a SmartClientMediumRequest for the medium medium."""
        self._medium = medium
        # we track state by constants - we may want to use the same
        # pattern as BodyReader if it gets more complex.
        # valid states are: "writing", "reading", "done"
        self._state = "writing"

    def accept_bytes(self, bytes):
        """Accept bytes for inclusion in this request.

        This method may not be called after finished_writing() has been
        called.  It depends upon the Medium whether or not the bytes will be
        immediately transmitted. Message based Mediums will tend to buffer the
        bytes until finished_writing() is called.

        :param bytes: A bytestring.
        :raises errors.WritingCompleted: if the writing phase is over.
        """
        if self._state != "writing":
            raise errors.WritingCompleted(self)
        self._accept_bytes(bytes)

    def _accept_bytes(self, bytes):
        """Helper for accept_bytes.

        Accept_bytes checks the state of the request to determine if bytes
        should be accepted. After that it hands off to _accept_bytes to do the
        actual acceptance.
        """
        raise NotImplementedError(self._accept_bytes)

    def finished_reading(self):
        """Inform the request that all desired data has been read.

        This will remove the request from the pipeline for its medium (if the
        medium supports pipelining) and any further calls to methods on the
        request will raise ReadingCompleted.

        :raises errors.WritingNotComplete: if still in the writing phase.
        :raises errors.ReadingCompleted: if reading was already finished.
        """
        if self._state == "writing":
            raise errors.WritingNotComplete(self)
        if self._state != "reading":
            raise errors.ReadingCompleted(self)
        # Enter the terminal state before delegating, so later calls on this
        # request are rejected by the state checks.
        self._state = "done"
        self._finished_reading()

    def _finished_reading(self):
        """Helper for finished_reading.

        finished_reading checks the state of the request to determine if
        finished_reading is allowed, and if it is hands off to _finished_reading
        to perform the action.
        """
        raise NotImplementedError(self._finished_reading)

    def finished_writing(self):
        """Finish the writing phase of this request.

        This will flush all pending data for this request along the medium.
        After calling finished_writing, you may not call accept_bytes anymore.

        :raises errors.WritingCompleted: if writing was already finished.
        """
        if self._state != "writing":
            raise errors.WritingCompleted(self)
        self._state = "reading"
        self._finished_writing()

    def _finished_writing(self):
        """Helper for finished_writing.

        finished_writing checks the state of the request to determine if
        finished_writing is allowed, and if it is hands off to _finished_writing
        to perform the action.
        """
        raise NotImplementedError(self._finished_writing)

    def read_bytes(self, count):
        """Read bytes from this request's response.

        This method will block and wait for count bytes to be read. It may not
        be invoked until finished_writing() has been called - this is to ensure
        a message-based approach to requests, for compatibility with message
        based mediums like HTTP.

        :raises errors.WritingNotComplete: if still in the writing phase.
        :raises errors.ReadingCompleted: if reading was already finished.
        """
        if self._state == "writing":
            raise errors.WritingNotComplete(self)
        if self._state != "reading":
            raise errors.ReadingCompleted(self)
        return self._read_bytes(count)

    def _read_bytes(self, count):
        """Helper for SmartClientMediumRequest.read_bytes.

        read_bytes checks the state of the request to determine if bytes
        should be read. After that it hands off to _read_bytes to do the
        actual read.

        By default this forwards to self._medium.read_bytes because we are
        operating on the medium's stream.
        """
        return self._medium.read_bytes(count)

    def read_line(self):
        """Read one newline-terminated line of the response.

        :returns: the line, including its trailing newline.
        :raises errors.ConnectionReset: if the data read does not end with a
            newline, which is treated as the stream having ended early.
        """
        line = self._read_line()
        if not line.endswith(b'\n'):
            # end of file encountered reading from server
            raise errors.ConnectionReset(
                "Unexpected end of message. Please check connectivity "
                "and permissions, and report a bug if problems persist.")
        return line

    def _read_line(self):
        """Helper for SmartClientMediumRequest.read_line.

        By default this forwards to self._medium._get_line because we are
        operating on the medium's stream.
        """
        return self._medium._get_line()
664
class _VfsRefuser(object):
    """An object that refuses all VFS requests.

    On construction it installs check_vfs as a 'call' hook on
    client._SmartClient.
    """

    def __init__(self):
        client._SmartClient.hooks.install_named_hook(
            'call', self.check_vfs, 'vfs refuser')

    def check_vfs(self, params):
        """Raise HpssVfsRequestNotAllowed if params names a VFS request.

        :param params: hook parameters; only .method and .args are read here.
        """
        try:
            request_method = request.request_handlers.get(params.method)
        except KeyError:
            # A method we don't know about doesn't count as a VFS method.
            return
        if issubclass(request_method, vfs.VfsRequest):
            raise HpssVfsRequestNotAllowed(params.method, params.args)
683
class _DebugCounter(object):
    """An object that counts the HPSS calls made to each client medium.

    When a medium is garbage-collected, or failing that when
    breezy.global_state exits, the total number of calls made on that medium
    are reported via trace.note.
    """

    def __init__(self):
        # Maps medium -> dict(count, vfs_count, medium_repr); the keys are
        # held weakly.
        self.counts = weakref.WeakKeyDictionary()
        client._SmartClient.hooks.install_named_hook(
            'call', self.increment_call_count, 'hpss call counter')
        breezy.get_global_state().cleanups.add_cleanup(self.flush_all)

    def track(self, medium):
        """Start tracking calls made to a medium.

        This only keeps a weakref to the medium, so shouldn't affect the
        medium's lifetime.
        """
        medium_repr = repr(medium)
        # Add this medium to the WeakKeyDictionary
        self.counts[medium] = dict(count=0, vfs_count=0,
                                   medium_repr=medium_repr)
        # Weakref callbacks are fired in reverse order of their association
        # with the referenced object.  So we add a weakref *after* adding to
        # the WeakKeyDict so that we can report the value from it before the
        # entry is removed by the WeakKeyDict's own callback.
        # NOTE(review): this weakref is only bound to a local name -- confirm
        # the callback can still fire once track() has returned.
        ref = weakref.ref(medium, self.done)

    def increment_call_count(self, params):
        """'call' hook: count one call, and one VFS call where applicable."""
        # Increment the count in the WeakKeyDictionary
        value = self.counts[params.medium]
        value['count'] += 1
        try:
            request_method = request.request_handlers.get(params.method)
        except KeyError:
            # A method we don't know about doesn't count as a VFS method.
            return
        if issubclass(request_method, vfs.VfsRequest):
            value['vfs_count'] += 1

    def done(self, ref):
        """Report the totals recorded under key 'ref', then reset them."""
        value = self.counts[ref]
        count, vfs_count, medium_repr = (
            value['count'], value['vfs_count'], value['medium_repr'])
        # In case this callback is invoked for the same ref twice (by the
        # weakref callback and by the atexit function), set the call count back
        # to 0 so this item won't be reported twice.
        value['count'] = 0
        value['vfs_count'] = 0
        if count != 0:
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
                count, vfs_count, medium_repr))

    def flush_all(self):
        """Report every medium that is still being tracked."""
        # Iterate over a snapshot of the keys rather than the live
        # WeakKeyDictionary.
        for ref in list(self.counts.keys()):
            self.done(ref)
743
_debug_counter = None
747
class SmartClientMedium(SmartMedium):
748
"""Smart client is a medium for sending smart protocol requests over."""
750
def __init__(self, base):
751
super(SmartClientMedium, self).__init__()
753
self._protocol_version_error = None
754
self._protocol_version = None
755
self._done_hello = False
756
# Be optimistic: we assume the remote end can accept new remote
757
# requests until we get an error saying otherwise.
758
# _remote_version_is_before tracks the bzr version the remote side
759
# can be based on what we've seen so far.
760
self._remote_version_is_before = None
761
# Install debug hook function if debug flag is set.
762
if 'hpss' in debug.debug_flags:
763
global _debug_counter
764
if _debug_counter is None:
765
_debug_counter = _DebugCounter()
766
_debug_counter.track(self)
767
if 'hpss_client_no_vfs' in debug.debug_flags:
769
if _vfs_refuser is None:
770
_vfs_refuser = _VfsRefuser()
772
def _is_remote_before(self, version_tuple):
773
"""Is it possible the remote side supports RPCs for a given version?
777
needed_version = (1, 2)
778
if medium._is_remote_before(needed_version):
779
fallback_to_pre_1_2_rpc()
783
except UnknownSmartMethod:
784
medium._remember_remote_is_before(needed_version)
785
fallback_to_pre_1_2_rpc()
787
:seealso: _remember_remote_is_before
789
if self._remote_version_is_before is None:
790
# So far, the remote side seems to support everything
792
return version_tuple >= self._remote_version_is_before
794
def _remember_remote_is_before(self, version_tuple):
795
"""Tell this medium that the remote side is older than the given version.
797
:seealso: _is_remote_before
799
if (self._remote_version_is_before is not None and
800
version_tuple > self._remote_version_is_before):
801
# We have been told that the remote side is older than some version
802
# which is newer than a previously supplied older-than version.
803
# This indicates that some smart verb call is not guarded
804
# appropriately (it should simply not have been tried).
806
"_remember_remote_is_before(%r) called, but "
807
"_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before)
808
if 'hpss' in debug.debug_flags:
809
ui.ui_factory.show_warning(
810
"_remember_remote_is_before(%r) called, but "
811
"_remember_remote_is_before(%r) was called previously."
812
% (version_tuple, self._remote_version_is_before))
814
self._remote_version_is_before = version_tuple
816
def protocol_version(self):
817
"""Find out if 'hello' smart request works."""
818
if self._protocol_version_error is not None:
819
raise self._protocol_version_error
820
if not self._done_hello:
822
medium_request = self.get_request()
823
# Send a 'hello' request in protocol version one, for maximum
824
# backwards compatibility.
825
client_protocol = protocol.SmartClientRequestProtocolOne(
827
client_protocol.query_version()
828
self._done_hello = True
829
except errors.SmartProtocolError as e:
830
# Cache the error, just like we would cache a successful
832
self._protocol_version_error = e
836
def should_probe(self):
837
"""Should RemoteBzrDirFormat.probe_transport send a smart request on
840
Some transports are unambiguously smart-only; there's no need to check
841
if the transport is able to carry smart requests, because that's all
842
it is for. In those cases, this method should return False.
844
But some HTTP transports can sometimes fail to carry smart requests,
845
but still be usable for accessing remote bzrdirs via plain file
846
accesses. So for those transports, their media should return True here
847
so that RemoteBzrDirFormat can determine if it is appropriate for that
852
def disconnect(self):
853
"""If this medium maintains a persistent connection, close it.
855
The default implementation does nothing.
858
def remote_path_from_transport(self, transport):
859
"""Convert transport into a path suitable for using in a request.
861
Note that the resulting remote path doesn't encode the host name or
862
anything but path, so it is only safe to use it in requests sent over
863
the medium from the matching transport.
865
medium_base = urlutils.join(self.base, '/')
866
rel_url = urlutils.relative_url(medium_base, transport.base)
867
return urlutils.unquote(rel_url)
870
class SmartClientStreamMedium(SmartClientMedium):
871
"""Stream based medium common class.
873
SmartClientStreamMediums operate on a stream. All subclasses use a common
874
SmartClientStreamMediumRequest for their requests, and should implement
875
_accept_bytes and _read_bytes to allow the request objects to send and
879
def __init__(self, base):
880
SmartClientMedium.__init__(self, base)
881
self._current_request = None
883
def accept_bytes(self, bytes):
884
self._accept_bytes(bytes)
887
"""The SmartClientStreamMedium knows how to close the stream when it is
893
"""Flush the output stream.
895
This method is used by the SmartClientStreamMediumRequest to ensure that
896
all data for a request is sent, to avoid long timeouts or deadlocks.
898
raise NotImplementedError(self._flush)
900
def get_request(self):
901
"""See SmartClientMedium.get_request().
903
SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
906
return SmartClientStreamMediumRequest(self)
909
"""We have been disconnected, reset current state.
911
This resets things like _current_request and connected state.
914
self._current_request = None
917
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
    """A client medium using simple pipes.

    This client does not manage the pipes: it assumes they will always be open.
    """

    def __init__(self, readable_pipe, writeable_pipe, base):
        SmartClientStreamMedium.__init__(self, base)
        self._readable_pipe = readable_pipe
        self._writeable_pipe = writeable_pipe

    def _accept_bytes(self, data):
        """See SmartClientStreamMedium.accept_bytes."""
        try:
            self._writeable_pipe.write(data)
        except IOError as e:
            # EINVAL and EPIPE are reported as a connection reset; any other
            # error propagates unchanged.
            if e.errno in (errno.EINVAL, errno.EPIPE):
                raise errors.ConnectionReset(
                    "Error trying to write to subprocess", e)
            raise
        self._report_activity(len(data), 'write')

    def _flush(self):
        """See SmartClientStreamMedium._flush()."""
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
        #       However, testing shows that even when the child process is
        #       gone, this doesn't error.
        self._writeable_pipe.flush()

    def _read_bytes(self, count):
        """See SmartClientStreamMedium._read_bytes."""
        bytes_to_read = min(count, _MAX_READ_SIZE)
        data = self._readable_pipe.read(bytes_to_read)
        self._report_activity(len(data), 'read')
        return data
954
class SSHParams(object):
    """A set of parameters for starting a remote bzr via SSH.

    :ivar host: host to connect to.
    :ivar port: port to connect to, or None.
    :ivar username: user name to use, or None.
    :ivar password: password to use, or None.
    :ivar bzr_remote_path: the bzr command to run on the remote side.
    """

    def __init__(self, host, port=None, username=None, password=None,
                 bzr_remote_path='bzr'):
        # host and port are read back by SmartSSHClientMedium (its repr and
        # _ensure_connection), so they must be stored like the rest.
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.bzr_remote_path = bzr_remote_path
966
class SmartSSHClientMedium(SmartClientStreamMedium):
967
"""A client medium using SSH.
969
It delegates IO to a SmartSimplePipesClientMedium or
970
SmartClientAlreadyConnectedSocketMedium (depending on platform).
973
def __init__(self, base, ssh_params, vendor=None):
974
"""Creates a client that will connect on the first use.
976
:param ssh_params: A SSHParams instance.
977
:param vendor: An optional override for the ssh vendor to use. See
978
breezy.transport.ssh for details on ssh vendors.
980
self._real_medium = None
981
self._ssh_params = ssh_params
982
# for the benefit of progress making a short description of this
984
self._scheme = 'bzr+ssh'
985
# SmartClientStreamMedium stores the repr of this object in its
986
# _DebugCounter so we have to store all the values used in our repr
987
# method before calling the super init.
988
SmartClientStreamMedium.__init__(self, base)
989
self._vendor = vendor
990
self._ssh_connection = None
993
if self._ssh_params.port is None:
996
maybe_port = ':%s' % self._ssh_params.port
997
if self._ssh_params.username is None:
1000
maybe_user = '%s@' % self._ssh_params.username
1001
return "%s(%s://%s%s%s/)" % (
1002
self.__class__.__name__,
1005
self._ssh_params.host,
1008
def _accept_bytes(self, bytes):
1009
"""See SmartClientStreamMedium.accept_bytes."""
1010
self._ensure_connection()
1011
self._real_medium.accept_bytes(bytes)
1013
def disconnect(self):
1014
"""See SmartClientMedium.disconnect()."""
1015
if self._real_medium is not None:
1016
self._real_medium.disconnect()
1017
self._real_medium = None
1018
if self._ssh_connection is not None:
1019
self._ssh_connection.close()
1020
self._ssh_connection = None
1022
def _ensure_connection(self):
1023
"""Connect this medium if not already connected."""
1024
if self._real_medium is not None:
1026
if self._vendor is None:
1027
vendor = ssh._get_ssh_vendor()
1029
vendor = self._vendor
1030
self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1031
self._ssh_params.password, self._ssh_params.host,
1032
self._ssh_params.port,
1033
command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
1034
'--directory=/', '--allow-writes'])
1035
io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1036
if io_kind == 'socket':
1037
self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1038
self.base, io_object)
1039
elif io_kind == 'pipes':
1040
read_from, write_to = io_object
1041
self._real_medium = SmartSimplePipesClientMedium(
1042
read_from, write_to, self.base)
1044
raise AssertionError(
1045
"Unexpected io_kind %r from %r"
1046
% (io_kind, self._ssh_connection))
1047
for hook in transport.Transport.hooks["post_connect"]:
1051
"""See SmartClientStreamMedium._flush()."""
1052
self._real_medium._flush()
1054
def _read_bytes(self, count):
1055
"""See SmartClientStreamMedium.read_bytes."""
1056
if self._real_medium is None:
1057
raise errors.MediumNotConnected(self)
1058
return self._real_medium.read_bytes(count)
1061
# Port 4155 is the default port for bzr://, registered with IANA.
1062
# NOTE(review): presumably the default network interface for bzr:// servers,
# with None meaning "not specified" -- no user of this name is visible in this
# part of the file, so confirm against callers.
BZR_DEFAULT_INTERFACE = None
1063
# Used by SmartTCPClientMedium._ensure_connection when no port was given.
BZR_DEFAULT_PORT = 4155
1066
class SmartClientSocketMedium(SmartClientStreamMedium):
    """A client medium using a socket.

    This class isn't usable directly.  Use one of its subclasses instead.
    Subclasses provide ``_ensure_connection`` and are responsible for
    setting ``_socket`` and ``_connected``.
    """

    def __init__(self, base):
        """Create an unconnected socket medium.

        :param base: passed through to SmartClientStreamMedium.__init__.
        """
        SmartClientStreamMedium.__init__(self, base)
        # No socket until a subclass connects; subclasses test this for None.
        self._socket = None
        self._connected = False

    def _accept_bytes(self, bytes):
        """See SmartClientMedium.accept_bytes.

        Connects if necessary, then sends all of ``bytes`` on the socket.
        """
        self._ensure_connection()
        osutils.send_all(self._socket, bytes, self._report_activity)

    def _ensure_connection(self):
        """Connect this medium if not already connected.

        Must be implemented by subclasses.
        """
        raise NotImplementedError(self._ensure_connection)

    def _flush(self):
        """See SmartClientStreamMedium._flush().

        For sockets we do no flushing. For TCP sockets we may want to turn off
        TCP_NODELAY and add a means to do a flush, but that can be done in the
        future.
        """

    def _read_bytes(self, count):
        """See SmartClientMedium.read_bytes.

        ``count`` is not passed to the socket read helper; the helper
        decides how much to read.

        :raises errors.MediumNotConnected: if the medium is not connected.
        """
        if not self._connected:
            raise errors.MediumNotConnected(self)
        return osutils.read_bytes_from_socket(
            self._socket, self._report_activity)

    def disconnect(self):
        """See SmartClientMedium.disconnect().

        Closes the socket and marks the medium as not connected.  Does
        nothing when the medium is not connected.
        """
        if not self._connected:
            return
        self._socket.close()
        # Drop the closed socket so a later connection attempt starts clean.
        self._socket = None
        self._connected = False
1110
class SmartTCPClientMedium(SmartClientSocketMedium):
    """A client medium that creates a TCP connection."""

    def __init__(self, host, port, base):
        """Creates a client that will connect on the first use.

        :param host: host name or address, resolved with
            ``socket.getaddrinfo`` when connecting.
        :param port: port to connect to, or None to use BZR_DEFAULT_PORT.
            Converted with ``int()`` when connecting.
        :param base: passed through to SmartClientSocketMedium.__init__.
        """
        SmartClientSocketMedium.__init__(self, base)
        self._host = host
        self._port = port

    def _ensure_connection(self):
        """Connect this medium if not already connected.

        Each address returned by ``socket.getaddrinfo`` is tried in turn
        until one connects.  TCP_NODELAY is set on the socket, and the
        "post_connect" transport hooks are run once connected.

        :raises errors.ConnectionError: if the host cannot be looked up, or
            if no address could be connected to.
        """
        if self._connected:
            return
        if self._port is None:
            port = BZR_DEFAULT_PORT
        else:
            port = int(self._port)
        try:
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
                                           socket.SOCK_STREAM, 0, 0)
        except socket.gaierror as lookup_err:
            (err_num, err_msg) = lookup_err.args
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
                                         (self._host, port, err_msg))
        # Initialize err in case there are no addresses returned:
        last_err = socket.error("no address found for %s" % self._host)
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
            try:
                self._socket = socket.socket(family, socktype, proto)
                self._socket.setsockopt(socket.IPPROTO_TCP,
                                        socket.TCP_NODELAY, 1)
                self._socket.connect(sockaddr)
            except socket.error as err:
                if self._socket is not None:
                    self._socket.close()
                self._socket = None
                last_err = err
                continue
            break
        if self._socket is None:
            # socket errors either have a (string) or (errno, string) as their
            # args.  ``args`` is always a tuple, so pick the message by length
            # rather than indexing [1] unconditionally, which raised
            # IndexError for single-argument errors such as the "no address
            # found" default above.
            err_args = last_err.args
            if len(err_args) >= 2:
                err_msg = err_args[1]
            elif err_args:
                err_msg = err_args[0]
            else:
                err_msg = str(last_err)
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
                                         (self._host, port, err_msg))
        self._connected = True
        for hook in transport.Transport.hooks["post_connect"]:
            # NOTE(review): only the loop header is visible in this chunk;
            # the hook is assumed to be called with the medium -- confirm.
            hook(self)
1163
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
    """A client medium for an already connected socket.

    Note that this class will assume it "owns" the socket, so it will close it
    when its disconnect method is called.
    """

    def __init__(self, base, sock):
        """Wrap an already connected socket.

        :param base: passed through to SmartClientSocketMedium.__init__.
        :param sock: the connected socket this medium takes ownership of.
        """
        SmartClientSocketMedium.__init__(self, base)
        self._socket = sock
        self._connected = True

    def _ensure_connection(self):
        """Do nothing: the socket was already connected on construction."""
        # Already connected, by definition!  So nothing to do.
        pass
1180
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1181
"""A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
1183
def __init__(self, medium):
    """Create a request and register it as the medium's current request.

    :param medium: passed through to SmartClientMediumRequest.__init__.
    :raises errors.TooManyConcurrentRequests: if the medium already has a
        current request.
    """
    SmartClientMediumRequest.__init__(self, medium)
    # check that we are safe concurrency wise. If some streams start
    # allowing concurrent requests - i.e. via multiplexing - then this
    # assert should be moved to SmartClientStreamMedium.get_request,
    # and the setting/unsetting of _current_request likewise moved into
    # that class : but its unneeded overhead for now. RBC 20060922
    owner = self._medium
    if owner._current_request is not None:
        raise errors.TooManyConcurrentRequests(owner)
    owner._current_request = self
1194
def _accept_bytes(self, bytes):
1195
"""See SmartClientMediumRequest._accept_bytes.
1197
This forwards to self._medium._accept_bytes because we are operating
1198
on the mediums stream.
1200
self._medium._accept_bytes(bytes)
1202
def _finished_reading(self):
1203
"""See SmartClientMediumRequest._finished_reading.
1205
This clears the _current_request on self._medium to allow a new
1206
request to be created.
1208
if self._medium._current_request is not self:
1209
raise AssertionError()
1210
self._medium._current_request = None
1212
def _finished_writing(self):
1213
"""See SmartClientMediumRequest._finished_writing.
1215
This invokes self._medium._flush to ensure all bytes are transmitted.
1217
self._medium._flush()