/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/transport/smart.py

  • Committer: Robert Collins
  • Date: 2007-04-05 00:39:03 UTC
  • mto: This revision was merged to the branch mainline in revision 2401.
  • Revision ID: robertc@robertcollins.net-20070405003903-u1ys8t2lo5gs6b35
Overhaul the SmartTCPServer connect-thread logic to synchronise on startup and shutdown and notify the server if it is in accept.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""Smart-server protocol, client and server.
 
18
 
 
19
Requests are sent as a command and list of arguments, followed by optional
 
20
bulk body data.  Responses are similarly a response and list of arguments,
 
21
followed by bulk body data. ::
 
22
 
 
23
  SEP := '\001'
 
24
    Fields are separated by Ctrl-A.
 
25
  BULK_DATA := CHUNK TRAILER
 
26
    Chunks can be repeated as many times as necessary.
 
27
  CHUNK := CHUNK_LEN CHUNK_BODY
 
28
  CHUNK_LEN := DIGIT+ NEWLINE
 
29
    Gives the number of bytes in the following chunk.
 
30
  CHUNK_BODY := BYTE[chunk_len]
 
31
  TRAILER := SUCCESS_TRAILER | ERROR_TRAILER
 
32
  SUCCESS_TRAILER := 'done' NEWLINE
 
33
  ERROR_TRAILER := 
 
34
 
 
35
Paths are passed across the network.  The client needs to see a namespace that
 
36
includes any repository that might need to be referenced, and the client needs
 
37
to know about a root directory beyond which it cannot ascend.
 
38
 
 
39
Servers run over ssh will typically want to be able to access any path the user 
 
40
can access.  Public servers on the other hand (which might be over http, ssh
 
41
or tcp) will typically want to restrict access to only a particular directory 
 
42
and its children, so will want to do a software virtual root at that level.
 
43
In other words they'll want to rewrite incoming paths to be under that level
 
44
(and prevent escaping using ../ tricks.)
 
45
 
 
46
URLs that include ~ should probably be passed across to the server verbatim
 
47
and the server can expand them.  This will proably not be meaningful when 
 
48
limited to a directory?
 
49
 
 
50
At the bottom level socket, pipes, HTTP server.  For sockets, we have the idea
 
51
that you have multiple requests and get a read error because the other side did
 
52
shutdown.  For pipes we have read pipe which will have a zero read which marks
 
53
end-of-file.  For HTTP server environment there is not end-of-stream because
 
54
each request coming into the server is independent.
 
55
 
 
56
So we need a wrapper around pipes and sockets to seperate out requests from
 
57
substrate and this will give us a single model which is consist for HTTP,
 
58
sockets and pipes.
 
59
 
 
60
Server-side
 
61
-----------
 
62
 
 
63
 MEDIUM  (factory for protocol, reads bytes & pushes to protocol,
 
64
          uses protocol to detect end-of-request, sends written
 
65
          bytes to client) e.g. socket, pipe, HTTP request handler.
 
66
  ^
 
67
  | bytes.
 
68
  v
 
69
 
 
70
PROTOCOL  (serialization, deserialization)  accepts bytes for one
 
71
          request, decodes according to internal state, pushes
 
72
          structured data to handler.  accepts structured data from
 
73
          handler and encodes and writes to the medium.  factory for
 
74
          handler.
 
75
  ^
 
76
  | structured data
 
77
  v
 
78
 
 
79
HANDLER   (domain logic) accepts structured data, operates state
 
80
          machine until the request can be satisfied,
 
81
          sends structured data to the protocol.
 
82
 
 
83
 
 
84
Client-side
 
85
-----------
 
86
 
 
87
 CLIENT             domain logic, accepts domain requests, generated structured
 
88
                    data, reads structured data from responses and turns into
 
89
                    domain data.  Sends structured data to the protocol.
 
90
                    Operates state machines until the request can be delivered
 
91
                    (e.g. reading from a bundle generated in bzrlib to deliver a
 
92
                    complete request).
 
93
 
 
94
                    Possibly this should just be RemoteBzrDir, RemoteTransport,
 
95
                    ...
 
96
  ^
 
97
  | structured data
 
98
  v
 
99
 
 
100
PROTOCOL  (serialization, deserialization)  accepts structured data for one
 
101
          request, encodes and writes to the medium.  Reads bytes from the
 
102
          medium, decodes and allows the client to read structured data.
 
103
  ^
 
104
  | bytes.
 
105
  v
 
106
 
 
107
 MEDIUM  (accepts bytes from the protocol & delivers to the remote server.
 
108
          Allows the potocol to read bytes e.g. socket, pipe, HTTP request.
 
109
"""
 
110
 
 
111
 
 
112
# TODO: _translate_error should be on the client, not the transport because
 
113
#     error coding is wire protocol specific.
 
114
 
 
115
# TODO: A plain integer from query_version is too simple; should give some
 
116
# capabilities too?
 
117
 
 
118
# TODO: Server should probably catch exceptions within itself and send them
 
119
# back across the network.  (But shouldn't catch KeyboardInterrupt etc)
 
120
# Also needs to somehow report protocol errors like bad requests.  Need to
 
121
# consider how we'll handle error reporting, e.g. if we get halfway through a
 
122
# bulk transfer and then something goes wrong.
 
123
 
 
124
# TODO: Standard marker at start of request/response lines?
 
125
 
 
126
# TODO: Make each request and response self-validatable, e.g. with checksums.
 
127
#
 
128
# TODO: get/put objects could be changed to gradually read back the data as it
 
129
# comes across the network
 
130
#
 
131
# TODO: What should the server do if it hits an error and has to terminate?
 
132
#
 
133
# TODO: is it useful to allow multiple chunks in the bulk data?
 
134
#
 
135
# TODO: If we get an exception during transmission of bulk data we can't just
 
136
# emit the exception because it won't be seen.
 
137
#   John proposes:  I think it would be worthwhile to have a header on each
 
138
#   chunk, that indicates it is another chunk. Then you can send an 'error'
 
139
#   chunk as long as you finish the previous chunk.
 
140
#
 
141
# TODO: Clone method on Transport; should work up towards parent directory;
 
142
# unclear how this should be stored or communicated to the server... maybe
 
143
# just pass it on all relevant requests?
 
144
#
 
145
# TODO: Better name than clone() for changing between directories.  How about
 
146
# open_dir or change_dir or chdir?
 
147
#
 
148
# TODO: Is it really good to have the notion of current directory within the
 
149
# connection?  Perhaps all Transports should factor out a common connection
 
150
# from the thing that has the directory context?
 
151
#
 
152
# TODO: Pull more things common to sftp and ssh to a higher level.
 
153
#
 
154
# TODO: The server that manages a connection should be quite small and retain
 
155
# minimum state because each of the requests are supposed to be stateless.
 
156
# Then we can write another implementation that maps to http.
 
157
#
 
158
# TODO: What to do when a client connection is garbage collected?  Maybe just
 
159
# abruptly drop the connection?
 
160
#
 
161
# TODO: Server in some cases will need to restrict access to files outside of
 
162
# a particular root directory.  LocalTransport doesn't do anything to stop you
 
163
# ascending above the base directory, so we need to prevent paths
 
164
# containing '..' in either the server or transport layers.  (Also need to
 
165
# consider what happens if someone creates a symlink pointing outside the 
 
166
# directory tree...)
 
167
#
 
168
# TODO: Server should rebase absolute paths coming across the network to put
 
169
# them under the virtual root, if one is in use.  LocalTransport currently
 
170
# doesn't do that; if you give it an absolute path it just uses it.
 
171
 
172
# XXX: Arguments can't contain newlines or ascii; possibly we should e.g.
 
173
# urlescape them instead.  Indeed possibly this should just literally be
 
174
# http-over-ssh.
 
175
#
 
176
# FIXME: This transport, with several others, has imperfect handling of paths
 
177
# within urls.  It'd probably be better for ".." from a root to raise an error
 
178
# rather than return the same directory as we do at present.
 
179
#
 
180
# TODO: Rather than working at the Transport layer we want a Branch,
 
181
# Repository or BzrDir objects that talk to a server.
 
182
#
 
183
# TODO: Probably want some way for server commands to gradually produce body
 
184
# data rather than passing it as a string; they could perhaps pass an
 
185
# iterator-like callback that will gradually yield data; it probably needs a
 
186
# close() method that will always be closed to do any necessary cleanup.
 
187
#
 
188
# TODO: Split the actual smart server from the ssh encoding of it.
 
189
#
 
190
# TODO: Perhaps support file-level readwrite operations over the transport
 
191
# too.
 
192
#
 
193
# TODO: SmartBzrDir class, proxying all Branch etc methods across to another
 
194
# branch doing file-level operations.
 
195
#
 
196
 
 
197
from cStringIO import StringIO
 
198
import errno
 
199
import os
 
200
import socket
 
201
import sys
 
202
import tempfile
 
203
import threading
 
204
import urllib
 
205
import urlparse
 
206
 
 
207
from bzrlib import (
 
208
    bzrdir,
 
209
    errors,
 
210
    revision,
 
211
    transport,
 
212
    trace,
 
213
    urlutils,
 
214
    )
 
215
from bzrlib.bundle.serializer import write_bundle
 
216
from bzrlib.hooks import Hooks
 
217
try:
 
218
    from bzrlib.transport import ssh
 
219
except errors.ParamikoNotPresent:
 
220
    # no paramiko.  SmartSSHClientMedium will break.
 
221
    pass
 
222
 
 
223
# must do this otherwise urllib can't parse the urls properly :(
 
224
for scheme in ['ssh', 'bzr', 'bzr+loopback', 'bzr+ssh', 'bzr+http']:
 
225
    transport.register_urlparse_netloc_protocol(scheme)
 
226
del scheme
 
227
 
 
228
 
 
229
# Port 4155 is the default port for bzr://, registered with IANA.
 
230
BZR_DEFAULT_PORT = 4155
 
231
 
 
232
 
 
233
def _recv_tuple(from_file):
 
234
    req_line = from_file.readline()
 
235
    return _decode_tuple(req_line)
 
236
 
 
237
 
 
238
def _decode_tuple(req_line):
 
239
    if req_line == None or req_line == '':
 
240
        return None
 
241
    if req_line[-1] != '\n':
 
242
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
243
    return tuple(req_line[:-1].split('\x01'))
 
244
 
 
245
 
 
246
def _encode_tuple(args):
 
247
    """Encode the tuple args to a bytestream."""
 
248
    return '\x01'.join(args) + '\n'
 
249
 
 
250
 
 
251
class SmartProtocolBase(object):
 
252
    """Methods common to client and server"""
 
253
 
 
254
    # TODO: this only actually accomodates a single block; possibly should
 
255
    # support multiple chunks?
 
256
    def _encode_bulk_data(self, body):
 
257
        """Encode body as a bulk data chunk."""
 
258
        return ''.join(('%d\n' % len(body), body, 'done\n'))
 
259
 
 
260
    def _serialise_offsets(self, offsets):
 
261
        """Serialise a readv offset list."""
 
262
        txt = []
 
263
        for start, length in offsets:
 
264
            txt.append('%d,%d' % (start, length))
 
265
        return '\n'.join(txt)
 
266
        
 
267
 
 
268
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
269
    """Server-side encoding and decoding logic for smart version 1."""
 
270
    
 
271
    def __init__(self, backing_transport, write_func):
 
272
        self._backing_transport = backing_transport
 
273
        self.excess_buffer = ''
 
274
        self._finished = False
 
275
        self.in_buffer = ''
 
276
        self.has_dispatched = False
 
277
        self.request = None
 
278
        self._body_decoder = None
 
279
        self._write_func = write_func
 
280
 
 
281
    def accept_bytes(self, bytes):
 
282
        """Take bytes, and advance the internal state machine appropriately.
 
283
        
 
284
        :param bytes: must be a byte string
 
285
        """
 
286
        assert isinstance(bytes, str)
 
287
        self.in_buffer += bytes
 
288
        if not self.has_dispatched:
 
289
            if '\n' not in self.in_buffer:
 
290
                # no command line yet
 
291
                return
 
292
            self.has_dispatched = True
 
293
            try:
 
294
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
 
295
                first_line += '\n'
 
296
                req_args = _decode_tuple(first_line)
 
297
                self.request = SmartServerRequestHandler(
 
298
                    self._backing_transport)
 
299
                self.request.dispatch_command(req_args[0], req_args[1:])
 
300
                if self.request.finished_reading:
 
301
                    # trivial request
 
302
                    self.excess_buffer = self.in_buffer
 
303
                    self.in_buffer = ''
 
304
                    self._send_response(self.request.response.args,
 
305
                        self.request.response.body)
 
306
            except KeyboardInterrupt:
 
307
                raise
 
308
            except Exception, exception:
 
309
                # everything else: pass to client, flush, and quit
 
310
                self._send_response(('error', str(exception)))
 
311
                return
 
312
 
 
313
        if self.has_dispatched:
 
314
            if self._finished:
 
315
                # nothing to do.XXX: this routine should be a single state 
 
316
                # machine too.
 
317
                self.excess_buffer += self.in_buffer
 
318
                self.in_buffer = ''
 
319
                return
 
320
            if self._body_decoder is None:
 
321
                self._body_decoder = LengthPrefixedBodyDecoder()
 
322
            self._body_decoder.accept_bytes(self.in_buffer)
 
323
            self.in_buffer = self._body_decoder.unused_data
 
324
            body_data = self._body_decoder.read_pending_data()
 
325
            self.request.accept_body(body_data)
 
326
            if self._body_decoder.finished_reading:
 
327
                self.request.end_of_body()
 
328
                assert self.request.finished_reading, \
 
329
                    "no more body, request not finished"
 
330
            if self.request.response is not None:
 
331
                self._send_response(self.request.response.args,
 
332
                    self.request.response.body)
 
333
                self.excess_buffer = self.in_buffer
 
334
                self.in_buffer = ''
 
335
            else:
 
336
                assert not self.request.finished_reading, \
 
337
                    "no response and we have finished reading."
 
338
 
 
339
    def _send_response(self, args, body=None):
 
340
        """Send a smart server response down the output stream."""
 
341
        assert not self._finished, 'response already sent'
 
342
        self._finished = True
 
343
        self._write_func(_encode_tuple(args))
 
344
        if body is not None:
 
345
            assert isinstance(body, str), 'body must be a str'
 
346
            bytes = self._encode_bulk_data(body)
 
347
            self._write_func(bytes)
 
348
 
 
349
    def next_read_size(self):
 
350
        if self._finished:
 
351
            return 0
 
352
        if self._body_decoder is None:
 
353
            return 1
 
354
        else:
 
355
            return self._body_decoder.next_read_size()
 
356
 
 
357
 
 
358
class LengthPrefixedBodyDecoder(object):
 
359
    """Decodes the length-prefixed bulk data."""
 
360
    
 
361
    def __init__(self):
 
362
        self.bytes_left = None
 
363
        self.finished_reading = False
 
364
        self.unused_data = ''
 
365
        self.state_accept = self._state_accept_expecting_length
 
366
        self.state_read = self._state_read_no_data
 
367
        self._in_buffer = ''
 
368
        self._trailer_buffer = ''
 
369
    
 
370
    def accept_bytes(self, bytes):
 
371
        """Decode as much of bytes as possible.
 
372
 
 
373
        If 'bytes' contains too much data it will be appended to
 
374
        self.unused_data.
 
375
 
 
376
        finished_reading will be set when no more data is required.  Further
 
377
        data will be appended to self.unused_data.
 
378
        """
 
379
        # accept_bytes is allowed to change the state
 
380
        current_state = self.state_accept
 
381
        self.state_accept(bytes)
 
382
        while current_state != self.state_accept:
 
383
            current_state = self.state_accept
 
384
            self.state_accept('')
 
385
 
 
386
    def next_read_size(self):
 
387
        if self.bytes_left is not None:
 
388
            # Ideally we want to read all the remainder of the body and the
 
389
            # trailer in one go.
 
390
            return self.bytes_left + 5
 
391
        elif self.state_accept == self._state_accept_reading_trailer:
 
392
            # Just the trailer left
 
393
            return 5 - len(self._trailer_buffer)
 
394
        elif self.state_accept == self._state_accept_expecting_length:
 
395
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
396
            # 'done\n').
 
397
            return 6
 
398
        else:
 
399
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
400
            return 1
 
401
        
 
402
    def read_pending_data(self):
 
403
        """Return any pending data that has been decoded."""
 
404
        return self.state_read()
 
405
 
 
406
    def _state_accept_expecting_length(self, bytes):
 
407
        self._in_buffer += bytes
 
408
        pos = self._in_buffer.find('\n')
 
409
        if pos == -1:
 
410
            return
 
411
        self.bytes_left = int(self._in_buffer[:pos])
 
412
        self._in_buffer = self._in_buffer[pos+1:]
 
413
        self.bytes_left -= len(self._in_buffer)
 
414
        self.state_accept = self._state_accept_reading_body
 
415
        self.state_read = self._state_read_in_buffer
 
416
 
 
417
    def _state_accept_reading_body(self, bytes):
 
418
        self._in_buffer += bytes
 
419
        self.bytes_left -= len(bytes)
 
420
        if self.bytes_left <= 0:
 
421
            # Finished with body
 
422
            if self.bytes_left != 0:
 
423
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
 
424
                self._in_buffer = self._in_buffer[:self.bytes_left]
 
425
            self.bytes_left = None
 
426
            self.state_accept = self._state_accept_reading_trailer
 
427
        
 
428
    def _state_accept_reading_trailer(self, bytes):
 
429
        self._trailer_buffer += bytes
 
430
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
431
        # a ProtocolViolation exception?
 
432
        if self._trailer_buffer.startswith('done\n'):
 
433
            self.unused_data = self._trailer_buffer[len('done\n'):]
 
434
            self.state_accept = self._state_accept_reading_unused
 
435
            self.finished_reading = True
 
436
    
 
437
    def _state_accept_reading_unused(self, bytes):
 
438
        self.unused_data += bytes
 
439
 
 
440
    def _state_read_no_data(self):
 
441
        return ''
 
442
 
 
443
    def _state_read_in_buffer(self):
 
444
        result = self._in_buffer
 
445
        self._in_buffer = ''
 
446
        return result
 
447
 
 
448
 
 
449
class SmartServerStreamMedium(object):
 
450
    """Handles smart commands coming over a stream.
 
451
 
 
452
    The stream may be a pipe connected to sshd, or a tcp socket, or an
 
453
    in-process fifo for testing.
 
454
 
 
455
    One instance is created for each connected client; it can serve multiple
 
456
    requests in the lifetime of the connection.
 
457
 
 
458
    The server passes requests through to an underlying backing transport, 
 
459
    which will typically be a LocalTransport looking at the server's filesystem.
 
460
    """
 
461
 
 
462
    def __init__(self, backing_transport):
 
463
        """Construct new server.
 
464
 
 
465
        :param backing_transport: Transport for the directory served.
 
466
        """
 
467
        # backing_transport could be passed to serve instead of __init__
 
468
        self.backing_transport = backing_transport
 
469
        self.finished = False
 
470
 
 
471
    def serve(self):
 
472
        """Serve requests until the client disconnects."""
 
473
        # Keep a reference to stderr because the sys module's globals get set to
 
474
        # None during interpreter shutdown.
 
475
        from sys import stderr
 
476
        try:
 
477
            while not self.finished:
 
478
                protocol = SmartServerRequestProtocolOne(self.backing_transport,
 
479
                                                         self._write_out)
 
480
                self._serve_one_request(protocol)
 
481
        except Exception, e:
 
482
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
483
            raise
 
484
 
 
485
    def _serve_one_request(self, protocol):
 
486
        """Read one request from input, process, send back a response.
 
487
        
 
488
        :param protocol: a SmartServerRequestProtocol.
 
489
        """
 
490
        try:
 
491
            self._serve_one_request_unguarded(protocol)
 
492
        except KeyboardInterrupt:
 
493
            raise
 
494
        except Exception, e:
 
495
            self.terminate_due_to_error()
 
496
 
 
497
    def terminate_due_to_error(self):
 
498
        """Called when an unhandled exception from the protocol occurs."""
 
499
        raise NotImplementedError(self.terminate_due_to_error)
 
500
 
 
501
 
 
502
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
 
503
 
 
504
    def __init__(self, sock, backing_transport):
 
505
        """Constructor.
 
506
 
 
507
        :param sock: the socket the server will read from.  It will be put
 
508
            into blocking mode.
 
509
        """
 
510
        SmartServerStreamMedium.__init__(self, backing_transport)
 
511
        self.push_back = ''
 
512
        sock.setblocking(True)
 
513
        self.socket = sock
 
514
 
 
515
    def _serve_one_request_unguarded(self, protocol):
 
516
        while protocol.next_read_size():
 
517
            if self.push_back:
 
518
                protocol.accept_bytes(self.push_back)
 
519
                self.push_back = ''
 
520
            else:
 
521
                bytes = self.socket.recv(4096)
 
522
                if bytes == '':
 
523
                    self.finished = True
 
524
                    return
 
525
                protocol.accept_bytes(bytes)
 
526
        
 
527
        self.push_back = protocol.excess_buffer
 
528
    
 
529
    def terminate_due_to_error(self):
 
530
        """Called when an unhandled exception from the protocol occurs."""
 
531
        # TODO: This should log to a server log file, but no such thing
 
532
        # exists yet.  Andrew Bennetts 2006-09-29.
 
533
        self.socket.close()
 
534
        self.finished = True
 
535
 
 
536
    def _write_out(self, bytes):
 
537
        self.socket.sendall(bytes)
 
538
 
 
539
 
 
540
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
 
541
 
 
542
    def __init__(self, in_file, out_file, backing_transport):
 
543
        """Construct new server.
 
544
 
 
545
        :param in_file: Python file from which requests can be read.
 
546
        :param out_file: Python file to write responses.
 
547
        :param backing_transport: Transport for the directory served.
 
548
        """
 
549
        SmartServerStreamMedium.__init__(self, backing_transport)
 
550
        if sys.platform == 'win32':
 
551
            # force binary mode for files
 
552
            import msvcrt
 
553
            for f in (in_file, out_file):
 
554
                fileno = getattr(f, 'fileno', None)
 
555
                if fileno:
 
556
                    msvcrt.setmode(fileno(), os.O_BINARY)
 
557
        self._in = in_file
 
558
        self._out = out_file
 
559
 
 
560
    def _serve_one_request_unguarded(self, protocol):
 
561
        while True:
 
562
            bytes_to_read = protocol.next_read_size()
 
563
            if bytes_to_read == 0:
 
564
                # Finished serving this request.
 
565
                self._out.flush()
 
566
                return
 
567
            bytes = self._in.read(bytes_to_read)
 
568
            if bytes == '':
 
569
                # Connection has been closed.
 
570
                self.finished = True
 
571
                self._out.flush()
 
572
                return
 
573
            protocol.accept_bytes(bytes)
 
574
 
 
575
    def terminate_due_to_error(self):
 
576
        # TODO: This should log to a server log file, but no such thing
 
577
        # exists yet.  Andrew Bennetts 2006-09-29.
 
578
        self._out.close()
 
579
        self.finished = True
 
580
 
 
581
    def _write_out(self, bytes):
 
582
        self._out.write(bytes)
 
583
 
 
584
 
 
585
class SmartServerResponse(object):
 
586
    """Response generated by SmartServerRequestHandler."""
 
587
 
 
588
    def __init__(self, args, body=None):
 
589
        self.args = args
 
590
        self.body = body
 
591
 
 
592
# XXX: TODO: Create a SmartServerRequestHandler which will take the responsibility
 
593
# for delivering the data for a request. This could be done with as the
 
594
# StreamServer, though that would create conflation between request and response
 
595
# which may be undesirable.
 
596
 
 
597
 
 
598
class SmartServerRequestHandler(object):
 
599
    """Protocol logic for smart server.
 
600
    
 
601
    This doesn't handle serialization at all, it just processes requests and
 
602
    creates responses.
 
603
    """
 
604
 
 
605
    # IMPORTANT FOR IMPLEMENTORS: It is important that SmartServerRequestHandler
 
606
    # not contain encoding or decoding logic to allow the wire protocol to vary
 
607
    # from the object protocol: we will want to tweak the wire protocol separate
 
608
    # from the object model, and ideally we will be able to do that without
 
609
    # having a SmartServerRequestHandler subclass for each wire protocol, rather
 
610
    # just a Protocol subclass.
 
611
 
 
612
    # TODO: Better way of representing the body for commands that take it,
 
613
    # and allow it to be streamed into the server.
 
614
    
 
615
    def __init__(self, backing_transport):
 
616
        self._backing_transport = backing_transport
 
617
        self._converted_command = False
 
618
        self.finished_reading = False
 
619
        self._body_bytes = ''
 
620
        self.response = None
 
621
 
 
622
    def accept_body(self, bytes):
 
623
        """Accept body data.
 
624
 
 
625
        This should be overriden for each command that desired body data to
 
626
        handle the right format of that data. I.e. plain bytes, a bundle etc.
 
627
 
 
628
        The deserialisation into that format should be done in the Protocol
 
629
        object. Set self.desired_body_format to the format your method will
 
630
        handle.
 
631
        """
 
632
        # default fallback is to accumulate bytes.
 
633
        self._body_bytes += bytes
 
634
        
 
635
    def _end_of_body_handler(self):
 
636
        """An unimplemented end of body handler."""
 
637
        raise NotImplementedError(self._end_of_body_handler)
 
638
        
 
639
    def do_hello(self):
 
640
        """Answer a version request with my version."""
 
641
        return SmartServerResponse(('ok', '1'))
 
642
 
 
643
    def do_has(self, relpath):
 
644
        r = self._backing_transport.has(relpath) and 'yes' or 'no'
 
645
        return SmartServerResponse((r,))
 
646
 
 
647
    def do_get(self, relpath):
 
648
        backing_bytes = self._backing_transport.get_bytes(relpath)
 
649
        return SmartServerResponse(('ok',), backing_bytes)
 
650
 
 
651
    def _deserialise_optional_mode(self, mode):
 
652
        # XXX: FIXME this should be on the protocol object.
 
653
        if mode == '':
 
654
            return None
 
655
        else:
 
656
            return int(mode)
 
657
 
 
658
    def do_append(self, relpath, mode):
 
659
        self._converted_command = True
 
660
        self._relpath = relpath
 
661
        self._mode = self._deserialise_optional_mode(mode)
 
662
        self._end_of_body_handler = self._handle_do_append_end
 
663
    
 
664
    def _handle_do_append_end(self):
 
665
        old_length = self._backing_transport.append_bytes(
 
666
            self._relpath, self._body_bytes, self._mode)
 
667
        self.response = SmartServerResponse(('appended', '%d' % old_length))
 
668
 
 
669
    def do_delete(self, relpath):
 
670
        self._backing_transport.delete(relpath)
 
671
 
 
672
    def do_iter_files_recursive(self, relpath):
 
673
        transport = self._backing_transport.clone(relpath)
 
674
        filenames = transport.iter_files_recursive()
 
675
        return SmartServerResponse(('names',) + tuple(filenames))
 
676
 
 
677
    def do_list_dir(self, relpath):
 
678
        filenames = self._backing_transport.list_dir(relpath)
 
679
        return SmartServerResponse(('names',) + tuple(filenames))
 
680
 
 
681
    def do_mkdir(self, relpath, mode):
 
682
        self._backing_transport.mkdir(relpath,
 
683
                                      self._deserialise_optional_mode(mode))
 
684
 
 
685
    def do_move(self, rel_from, rel_to):
 
686
        self._backing_transport.move(rel_from, rel_to)
 
687
 
 
688
    def do_put(self, relpath, mode):
 
689
        self._converted_command = True
 
690
        self._relpath = relpath
 
691
        self._mode = self._deserialise_optional_mode(mode)
 
692
        self._end_of_body_handler = self._handle_do_put
 
693
 
 
694
    def _handle_do_put(self):
 
695
        self._backing_transport.put_bytes(self._relpath,
 
696
                self._body_bytes, self._mode)
 
697
        self.response = SmartServerResponse(('ok',))
 
698
 
 
699
    def _deserialise_offsets(self, text):
 
700
        # XXX: FIXME this should be on the protocol object.
 
701
        offsets = []
 
702
        for line in text.split('\n'):
 
703
            if not line:
 
704
                continue
 
705
            start, length = line.split(',')
 
706
            offsets.append((int(start), int(length)))
 
707
        return offsets
 
708
 
 
709
    def do_put_non_atomic(self, relpath, mode, create_parent, dir_mode):
 
710
        self._converted_command = True
 
711
        self._end_of_body_handler = self._handle_put_non_atomic
 
712
        self._relpath = relpath
 
713
        self._dir_mode = self._deserialise_optional_mode(dir_mode)
 
714
        self._mode = self._deserialise_optional_mode(mode)
 
715
        # a boolean would be nicer XXX
 
716
        self._create_parent = (create_parent == 'T')
 
717
 
 
718
    def _handle_put_non_atomic(self):
 
719
        self._backing_transport.put_bytes_non_atomic(self._relpath,
 
720
                self._body_bytes,
 
721
                mode=self._mode,
 
722
                create_parent_dir=self._create_parent,
 
723
                dir_mode=self._dir_mode)
 
724
        self.response = SmartServerResponse(('ok',))
 
725
 
 
726
    def do_readv(self, relpath):
 
727
        self._converted_command = True
 
728
        self._end_of_body_handler = self._handle_readv_offsets
 
729
        self._relpath = relpath
 
730
 
 
731
    def end_of_body(self):
 
732
        """No more body data will be received."""
 
733
        self._run_handler_code(self._end_of_body_handler, (), {})
 
734
        # cannot read after this.
 
735
        self.finished_reading = True
 
736
 
 
737
    def _handle_readv_offsets(self):
 
738
        """accept offsets for a readv request."""
 
739
        offsets = self._deserialise_offsets(self._body_bytes)
 
740
        backing_bytes = ''.join(bytes for offset, bytes in
 
741
            self._backing_transport.readv(self._relpath, offsets))
 
742
        self.response = SmartServerResponse(('readv',), backing_bytes)
 
743
        
 
744
    def do_rename(self, rel_from, rel_to):
 
745
        self._backing_transport.rename(rel_from, rel_to)
 
746
 
 
747
    def do_rmdir(self, relpath):
 
748
        self._backing_transport.rmdir(relpath)
 
749
 
 
750
    def do_stat(self, relpath):
 
751
        stat = self._backing_transport.stat(relpath)
 
752
        return SmartServerResponse(('stat', str(stat.st_size), oct(stat.st_mode)))
 
753
        
 
754
    def do_get_bundle(self, path, revision_id):
 
755
        # open transport relative to our base
 
756
        t = self._backing_transport.clone(path)
 
757
        control, extra_path = bzrdir.BzrDir.open_containing_from_transport(t)
 
758
        repo = control.open_repository()
 
759
        tmpf = tempfile.TemporaryFile()
 
760
        base_revision = revision.NULL_REVISION
 
761
        write_bundle(repo, revision_id, base_revision, tmpf)
 
762
        tmpf.seek(0)
 
763
        return SmartServerResponse((), tmpf.read())
 
764
 
 
765
    def dispatch_command(self, cmd, args):
 
766
        """Deprecated compatibility method.""" # XXX XXX
 
767
        func = getattr(self, 'do_' + cmd, None)
 
768
        if func is None:
 
769
            raise errors.SmartProtocolError("bad request %r" % (cmd,))
 
770
        self._run_handler_code(func, args, {})
 
771
 
 
772
    def _run_handler_code(self, callable, args, kwargs):
 
773
        """Run some handler specific code 'callable'.
 
774
 
 
775
        If a result is returned, it is considered to be the commands response,
 
776
        and finished_reading is set true, and its assigned to self.response.
 
777
 
 
778
        Any exceptions caught are translated and a response object created
 
779
        from them.
 
780
        """
 
781
        result = self._call_converting_errors(callable, args, kwargs)
 
782
        if result is not None:
 
783
            self.response = result
 
784
            self.finished_reading = True
 
785
        # handle unconverted commands
 
786
        if not self._converted_command:
 
787
            self.finished_reading = True
 
788
            if result is None:
 
789
                self.response = SmartServerResponse(('ok',))
 
790
 
 
791
    def _call_converting_errors(self, callable, args, kwargs):
 
792
        """Call callable converting errors to Response objects."""
 
793
        try:
 
794
            return callable(*args, **kwargs)
 
795
        except errors.NoSuchFile, e:
 
796
            return SmartServerResponse(('NoSuchFile', e.path))
 
797
        except errors.FileExists, e:
 
798
            return SmartServerResponse(('FileExists', e.path))
 
799
        except errors.DirectoryNotEmpty, e:
 
800
            return SmartServerResponse(('DirectoryNotEmpty', e.path))
 
801
        except errors.ShortReadvError, e:
 
802
            return SmartServerResponse(('ShortReadvError',
 
803
                e.path, str(e.offset), str(e.length), str(e.actual)))
 
804
        except UnicodeError, e:
 
805
            # If it is a DecodeError, than most likely we are starting
 
806
            # with a plain string
 
807
            str_or_unicode = e.object
 
808
            if isinstance(str_or_unicode, unicode):
 
809
                # XXX: UTF-8 might have \x01 (our seperator byte) in it.  We
 
810
                # should escape it somehow.
 
811
                val = 'u:' + str_or_unicode.encode('utf-8')
 
812
            else:
 
813
                val = 's:' + str_or_unicode.encode('base64')
 
814
            # This handles UnicodeEncodeError or UnicodeDecodeError
 
815
            return SmartServerResponse((e.__class__.__name__,
 
816
                    e.encoding, val, str(e.start), str(e.end), e.reason))
 
817
        except errors.TransportNotPossible, e:
 
818
            if e.msg == "readonly transport":
 
819
                return SmartServerResponse(('ReadOnlyError', ))
 
820
            else:
 
821
                raise
 
822
 
 
823
 
 
824
class SmartTCPServer(object):
 
825
    """Listens on a TCP socket and accepts connections from smart clients
 
826
 
 
827
    hooks: An instance of SmartServerHooks.
 
828
    """
 
829
 
 
830
    def __init__(self, backing_transport, host='127.0.0.1', port=0):
 
831
        """Construct a new server.
 
832
 
 
833
        To actually start it running, call either start_background_thread or
 
834
        serve.
 
835
 
 
836
        :param host: Name of the interface to listen on.
 
837
        :param port: TCP port to listen on, or 0 to allocate a transient port.
 
838
        """
 
839
        from socket import error as socket_error
 
840
        self._socket_error = socket_error
 
841
        self._server_socket = socket.socket()
 
842
        self._server_socket.bind((host, port))
 
843
        self._sockname = self._server_socket.getsockname()
 
844
        self.port = self._sockname[1]
 
845
        self._server_socket.listen(1)
 
846
        self._server_socket.settimeout(1)
 
847
        self.backing_transport = backing_transport
 
848
        self._started = threading.Event()
 
849
        self._stopped = threading.Event()
 
850
 
 
851
    def serve(self):
 
852
        # let connections timeout so that we get a chance to terminate
 
853
        # Keep a reference to the exceptions we want to catch because the socket
 
854
        # module's globals get set to None during interpreter shutdown.
 
855
        from socket import timeout as socket_timeout
 
856
        self._should_terminate = False
 
857
        for hook in SmartTCPServer.hooks['server_started']:
 
858
            hook(self.backing_transport.base, self.get_url())
 
859
        self._started.set()
 
860
        try:
 
861
            try:
 
862
                while not self._should_terminate:
 
863
                    try:
 
864
                        conn, client_addr = self._server_socket.accept()
 
865
                    except socket_timeout:
 
866
                        # just check if we're asked to stop
 
867
                        pass
 
868
                    except self._socket_error, e:
 
869
                        # if the socket is closed by stop_background_thread
 
870
                        # we might get a EBADF here, any other socket errors
 
871
                        # should get logged.
 
872
                        if e.args[0] != errno.EBADF:
 
873
                            trace.warning("listening socket error: %s", e)
 
874
                    else:
 
875
                        self.serve_conn(conn)
 
876
            except KeyboardInterrupt:
 
877
                # dont log when CTRL-C'd.
 
878
                raise
 
879
            except Exception, e:
 
880
                trace.error("Unhandled smart server error.")
 
881
                trace.log_exception_quietly()
 
882
                raise
 
883
        finally:
 
884
            self._stopped.set()
 
885
            try:
 
886
                # ensure the server socket is closed.
 
887
                self._server_socket.close()
 
888
            except self._socket_error:
 
889
                # ignore errors on close
 
890
                pass
 
891
            for hook in SmartTCPServer.hooks['server_stopped']:
 
892
                hook(self.backing_transport.base, self.get_url())
 
893
 
 
894
    def get_url(self):
 
895
        """Return the url of the server"""
 
896
        return "bzr://%s:%d/" % self._sockname
 
897
 
 
898
    def serve_conn(self, conn):
 
899
        # For WIN32, where the timeout value from the listening socket
 
900
        # propogates to the newly accepted socket.
 
901
        conn.setblocking(True)
 
902
        conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
903
        handler = SmartServerSocketStreamMedium(conn, self.backing_transport)
 
904
        connection_thread = threading.Thread(None, handler.serve, name='smart-server-child')
 
905
        connection_thread.setDaemon(True)
 
906
        connection_thread.start()
 
907
 
 
908
    def start_background_thread(self):
 
909
        self._started.clear()
 
910
        self._server_thread = threading.Thread(None,
 
911
                self.serve,
 
912
                name='server-' + self.get_url())
 
913
        self._server_thread.setDaemon(True)
 
914
        self._server_thread.start()
 
915
        self._started.wait()
 
916
 
 
917
    def stop_background_thread(self):
 
918
        self._stopped.clear()
 
919
        # tell the main loop to quit on the next iteration.
 
920
        self._should_terminate = True
 
921
        # close the socket - gives error to connections from here on in,
 
922
        # rather than a connection reset error to connections made during
 
923
        # the period between setting _should_terminate = True and 
 
924
        # the current request completing/aborting. It may also break out the
 
925
        # main loop if it was currently in accept() (on some platforms).
 
926
        try:
 
927
            self._server_socket.close()
 
928
        except self._socket_error:
 
929
            # ignore errors on close
 
930
            pass
 
931
        if not self._stopped.isSet():
 
932
            # server has not stopped (though it may be stopping)
 
933
            # its likely in accept(), so give it a connection
 
934
            temp_socket = socket.socket()
 
935
            temp_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
936
            if not temp_socket.connect_ex(self._sockname):
 
937
                # and close it immediately: we dont choose to send any requests.
 
938
                temp_socket.close()
 
939
        self._stopped.wait()
 
940
        self._server_thread.join()
 
941
 
 
942
 
 
943
class SmartServerHooks(Hooks):
 
944
    """Hooks for the smart server."""
 
945
 
 
946
    def __init__(self):
 
947
        """Create the default hooks.
 
948
 
 
949
        These are all empty initially, because by default nothing should get
 
950
        notified.
 
951
        """
 
952
        Hooks.__init__(self)
 
953
        # Introduced in 0.16:
 
954
        # invoked whenever the server starts serving a directory.
 
955
        # The api signature is (backing url, public url).
 
956
        self['server_started'] = []
 
957
        # Introduced in 0.16:
 
958
        # invoked whenever the server stops serving a directory.
 
959
        # The api signature is (backing url, public url).
 
960
        self['server_stopped'] = []
 
961
 
 
962
SmartTCPServer.hooks = SmartServerHooks()
 
963
 
 
964
 
 
965
class SmartTCPServer_for_testing(SmartTCPServer):
 
966
    """Server suitable for use by transport tests.
 
967
    
 
968
    This server is backed by the process's cwd.
 
969
    """
 
970
 
 
971
    def __init__(self):
 
972
        self._homedir = urlutils.local_path_to_url(os.getcwd())[7:]
 
973
        # The server is set up by default like for ssh access: the client
 
974
        # passes filesystem-absolute paths; therefore the server must look
 
975
        # them up relative to the root directory.  it might be better to act
 
976
        # a public server and have the server rewrite paths into the test
 
977
        # directory.
 
978
        SmartTCPServer.__init__(self,
 
979
            transport.get_transport(urlutils.local_path_to_url('/')))
 
980
        
 
981
    def get_backing_transport(self, backing_transport_server):
 
982
        """Get a backing transport from a server we are decorating."""
 
983
        return transport.get_transport(backing_transport_server.get_url())
 
984
 
 
985
    def setUp(self, backing_transport_server=None):
 
986
        """Set up server for testing"""
 
987
        from bzrlib.transport.chroot import TestingChrootServer
 
988
        if backing_transport_server is None:
 
989
            from bzrlib.transport.local import LocalURLServer
 
990
            backing_transport_server = LocalURLServer()
 
991
        self.chroot_server = TestingChrootServer()
 
992
        self.chroot_server.setUp(backing_transport_server)
 
993
        self.backing_transport = transport.get_transport(
 
994
            self.chroot_server.get_url())
 
995
        self.start_background_thread()
 
996
 
 
997
    def tearDown(self):
 
998
        self.stop_background_thread()
 
999
 
 
1000
    def get_bogus_url(self):
 
1001
        """Return a URL which will fail to connect"""
 
1002
        return 'bzr://127.0.0.1:1/'
 
1003
 
 
1004
 
 
1005
class SmartStat(object):
 
1006
 
 
1007
    def __init__(self, size, mode):
 
1008
        self.st_size = size
 
1009
        self.st_mode = mode
 
1010
 
 
1011
 
 
1012
class SmartTransport(transport.Transport):
 
1013
    """Connection to a smart server.
 
1014
 
 
1015
    The connection holds references to pipes that can be used to send requests
 
1016
    to the server.
 
1017
 
 
1018
    The connection has a notion of the current directory to which it's
 
1019
    connected; this is incorporated in filenames passed to the server.
 
1020
    
 
1021
    This supports some higher-level RPC operations and can also be treated 
 
1022
    like a Transport to do file-like operations.
 
1023
 
 
1024
    The connection can be made over a tcp socket, or (in future) an ssh pipe
 
1025
    or a series of http requests.  There are concrete subclasses for each
 
1026
    type: SmartTCPTransport, etc.
 
1027
    """
 
1028
 
 
1029
    # IMPORTANT FOR IMPLEMENTORS: SmartTransport MUST NOT be given encoding
 
1030
    # responsibilities: Put those on SmartClient or similar. This is vital for
 
1031
    # the ability to support multiple versions of the smart protocol over time:
 
1032
    # SmartTransport is an adapter from the Transport object model to the 
 
1033
    # SmartClient model, not an encoder.
 
1034
 
 
1035
    def __init__(self, url, clone_from=None, medium=None):
 
1036
        """Constructor.
 
1037
 
 
1038
        :param medium: The medium to use for this RemoteTransport. This must be
 
1039
            supplied if clone_from is None.
 
1040
        """
 
1041
        ### Technically super() here is faulty because Transport's __init__
 
1042
        ### fails to take 2 parameters, and if super were to choose a silly
 
1043
        ### initialisation order things would blow up. 
 
1044
        if not url.endswith('/'):
 
1045
            url += '/'
 
1046
        super(SmartTransport, self).__init__(url)
 
1047
        self._scheme, self._username, self._password, self._host, self._port, self._path = \
 
1048
                transport.split_url(url)
 
1049
        if clone_from is None:
 
1050
            self._medium = medium
 
1051
        else:
 
1052
            # credentials may be stripped from the base in some circumstances
 
1053
            # as yet to be clearly defined or documented, so copy them.
 
1054
            self._username = clone_from._username
 
1055
            # reuse same connection
 
1056
            self._medium = clone_from._medium
 
1057
        assert self._medium is not None
 
1058
 
 
1059
    def abspath(self, relpath):
 
1060
        """Return the full url to the given relative path.
 
1061
        
 
1062
        @param relpath: the relative path or path components
 
1063
        @type relpath: str or list
 
1064
        """
 
1065
        return self._unparse_url(self._remote_path(relpath))
 
1066
    
 
1067
    def clone(self, relative_url):
 
1068
        """Make a new SmartTransport related to me, sharing the same connection.
 
1069
 
 
1070
        This essentially opens a handle on a different remote directory.
 
1071
        """
 
1072
        if relative_url is None:
 
1073
            return SmartTransport(self.base, self)
 
1074
        else:
 
1075
            return SmartTransport(self.abspath(relative_url), self)
 
1076
 
 
1077
    def is_readonly(self):
 
1078
        """Smart server transport can do read/write file operations."""
 
1079
        return False
 
1080
                                                   
 
1081
    def get_smart_client(self):
 
1082
        return self._medium
 
1083
 
 
1084
    def get_smart_medium(self):
 
1085
        return self._medium
 
1086
                                                   
 
1087
    def _unparse_url(self, path):
 
1088
        """Return URL for a path.
 
1089
 
 
1090
        :see: SFTPUrlHandling._unparse_url
 
1091
        """
 
1092
        # TODO: Eventually it should be possible to unify this with
 
1093
        # SFTPUrlHandling._unparse_url?
 
1094
        if path == '':
 
1095
            path = '/'
 
1096
        path = urllib.quote(path)
 
1097
        netloc = urllib.quote(self._host)
 
1098
        if self._username is not None:
 
1099
            netloc = '%s@%s' % (urllib.quote(self._username), netloc)
 
1100
        if self._port is not None:
 
1101
            netloc = '%s:%d' % (netloc, self._port)
 
1102
        return urlparse.urlunparse((self._scheme, netloc, path, '', '', ''))
 
1103
 
 
1104
    def _remote_path(self, relpath):
 
1105
        """Returns the Unicode version of the absolute path for relpath."""
 
1106
        return self._combine_paths(self._path, relpath)
 
1107
 
 
1108
    def _call(self, method, *args):
 
1109
        resp = self._call2(method, *args)
 
1110
        self._translate_error(resp)
 
1111
 
 
1112
    def _call2(self, method, *args):
 
1113
        """Call a method on the remote server."""
 
1114
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1115
        protocol.call(method, *args)
 
1116
        return protocol.read_response_tuple()
 
1117
 
 
1118
    def _call_with_body_bytes(self, method, args, body):
 
1119
        """Call a method on the remote server with body bytes."""
 
1120
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1121
        protocol.call_with_body_bytes((method, ) + args, body)
 
1122
        return protocol.read_response_tuple()
 
1123
 
 
1124
    def has(self, relpath):
 
1125
        """Indicate whether a remote file of the given name exists or not.
 
1126
 
 
1127
        :see: Transport.has()
 
1128
        """
 
1129
        resp = self._call2('has', self._remote_path(relpath))
 
1130
        if resp == ('yes', ):
 
1131
            return True
 
1132
        elif resp == ('no', ):
 
1133
            return False
 
1134
        else:
 
1135
            self._translate_error(resp)
 
1136
 
 
1137
    def get(self, relpath):
 
1138
        """Return file-like object reading the contents of a remote file.
 
1139
        
 
1140
        :see: Transport.get_bytes()/get_file()
 
1141
        """
 
1142
        return StringIO(self.get_bytes(relpath))
 
1143
 
 
1144
    def get_bytes(self, relpath):
 
1145
        remote = self._remote_path(relpath)
 
1146
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1147
        protocol.call('get', remote)
 
1148
        resp = protocol.read_response_tuple(True)
 
1149
        if resp != ('ok', ):
 
1150
            protocol.cancel_read_body()
 
1151
            self._translate_error(resp, relpath)
 
1152
        return protocol.read_body_bytes()
 
1153
 
 
1154
    def _serialise_optional_mode(self, mode):
 
1155
        if mode is None:
 
1156
            return ''
 
1157
        else:
 
1158
            return '%d' % mode
 
1159
 
 
1160
    def mkdir(self, relpath, mode=None):
 
1161
        resp = self._call2('mkdir', self._remote_path(relpath),
 
1162
            self._serialise_optional_mode(mode))
 
1163
        self._translate_error(resp)
 
1164
 
 
1165
    def put_bytes(self, relpath, upload_contents, mode=None):
 
1166
        # FIXME: upload_file is probably not safe for non-ascii characters -
 
1167
        # should probably just pass all parameters as length-delimited
 
1168
        # strings?
 
1169
        resp = self._call_with_body_bytes('put',
 
1170
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
 
1171
            upload_contents)
 
1172
        self._translate_error(resp)
 
1173
 
 
1174
    def put_bytes_non_atomic(self, relpath, bytes, mode=None,
 
1175
                             create_parent_dir=False,
 
1176
                             dir_mode=None):
 
1177
        """See Transport.put_bytes_non_atomic."""
 
1178
        # FIXME: no encoding in the transport!
 
1179
        create_parent_str = 'F'
 
1180
        if create_parent_dir:
 
1181
            create_parent_str = 'T'
 
1182
 
 
1183
        resp = self._call_with_body_bytes(
 
1184
            'put_non_atomic',
 
1185
            (self._remote_path(relpath), self._serialise_optional_mode(mode),
 
1186
             create_parent_str, self._serialise_optional_mode(dir_mode)),
 
1187
            bytes)
 
1188
        self._translate_error(resp)
 
1189
 
 
1190
    def put_file(self, relpath, upload_file, mode=None):
 
1191
        # its not ideal to seek back, but currently put_non_atomic_file depends
 
1192
        # on transports not reading before failing - which is a faulty
 
1193
        # assumption I think - RBC 20060915
 
1194
        pos = upload_file.tell()
 
1195
        try:
 
1196
            return self.put_bytes(relpath, upload_file.read(), mode)
 
1197
        except:
 
1198
            upload_file.seek(pos)
 
1199
            raise
 
1200
 
 
1201
    def put_file_non_atomic(self, relpath, f, mode=None,
 
1202
                            create_parent_dir=False,
 
1203
                            dir_mode=None):
 
1204
        return self.put_bytes_non_atomic(relpath, f.read(), mode=mode,
 
1205
                                         create_parent_dir=create_parent_dir,
 
1206
                                         dir_mode=dir_mode)
 
1207
 
 
1208
    def append_file(self, relpath, from_file, mode=None):
 
1209
        return self.append_bytes(relpath, from_file.read(), mode)
 
1210
        
 
1211
    def append_bytes(self, relpath, bytes, mode=None):
 
1212
        resp = self._call_with_body_bytes(
 
1213
            'append',
 
1214
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
 
1215
            bytes)
 
1216
        if resp[0] == 'appended':
 
1217
            return int(resp[1])
 
1218
        self._translate_error(resp)
 
1219
 
 
1220
    def delete(self, relpath):
 
1221
        resp = self._call2('delete', self._remote_path(relpath))
 
1222
        self._translate_error(resp)
 
1223
 
 
1224
    def readv(self, relpath, offsets):
 
1225
        if not offsets:
 
1226
            return
 
1227
 
 
1228
        offsets = list(offsets)
 
1229
 
 
1230
        sorted_offsets = sorted(offsets)
 
1231
        # turn the list of offsets into a stack
 
1232
        offset_stack = iter(offsets)
 
1233
        cur_offset_and_size = offset_stack.next()
 
1234
        coalesced = list(self._coalesce_offsets(sorted_offsets,
 
1235
                               limit=self._max_readv_combine,
 
1236
                               fudge_factor=self._bytes_to_read_before_seek))
 
1237
 
 
1238
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1239
        protocol.call_with_body_readv_array(
 
1240
            ('readv', self._remote_path(relpath)),
 
1241
            [(c.start, c.length) for c in coalesced])
 
1242
        resp = protocol.read_response_tuple(True)
 
1243
 
 
1244
        if resp[0] != 'readv':
 
1245
            # This should raise an exception
 
1246
            protocol.cancel_read_body()
 
1247
            self._translate_error(resp)
 
1248
            return
 
1249
 
 
1250
        # FIXME: this should know how many bytes are needed, for clarity.
 
1251
        data = protocol.read_body_bytes()
 
1252
        # Cache the results, but only until they have been fulfilled
 
1253
        data_map = {}
 
1254
        for c_offset in coalesced:
 
1255
            if len(data) < c_offset.length:
 
1256
                raise errors.ShortReadvError(relpath, c_offset.start,
 
1257
                            c_offset.length, actual=len(data))
 
1258
            for suboffset, subsize in c_offset.ranges:
 
1259
                key = (c_offset.start+suboffset, subsize)
 
1260
                data_map[key] = data[suboffset:suboffset+subsize]
 
1261
            data = data[c_offset.length:]
 
1262
 
 
1263
            # Now that we've read some data, see if we can yield anything back
 
1264
            while cur_offset_and_size in data_map:
 
1265
                this_data = data_map.pop(cur_offset_and_size)
 
1266
                yield cur_offset_and_size[0], this_data
 
1267
                cur_offset_and_size = offset_stack.next()
 
1268
 
 
1269
    def rename(self, rel_from, rel_to):
 
1270
        self._call('rename',
 
1271
                   self._remote_path(rel_from),
 
1272
                   self._remote_path(rel_to))
 
1273
 
 
1274
    def move(self, rel_from, rel_to):
 
1275
        self._call('move',
 
1276
                   self._remote_path(rel_from),
 
1277
                   self._remote_path(rel_to))
 
1278
 
 
1279
    def rmdir(self, relpath):
 
1280
        resp = self._call('rmdir', self._remote_path(relpath))
 
1281
 
 
1282
    def _translate_error(self, resp, orig_path=None):
 
1283
        """Raise an exception from a response"""
 
1284
        if resp is None:
 
1285
            what = None
 
1286
        else:
 
1287
            what = resp[0]
 
1288
        if what == 'ok':
 
1289
            return
 
1290
        elif what == 'NoSuchFile':
 
1291
            if orig_path is not None:
 
1292
                error_path = orig_path
 
1293
            else:
 
1294
                error_path = resp[1]
 
1295
            raise errors.NoSuchFile(error_path)
 
1296
        elif what == 'error':
 
1297
            raise errors.SmartProtocolError(unicode(resp[1]))
 
1298
        elif what == 'FileExists':
 
1299
            raise errors.FileExists(resp[1])
 
1300
        elif what == 'DirectoryNotEmpty':
 
1301
            raise errors.DirectoryNotEmpty(resp[1])
 
1302
        elif what == 'ShortReadvError':
 
1303
            raise errors.ShortReadvError(resp[1], int(resp[2]),
 
1304
                                         int(resp[3]), int(resp[4]))
 
1305
        elif what in ('UnicodeEncodeError', 'UnicodeDecodeError'):
 
1306
            encoding = str(resp[1]) # encoding must always be a string
 
1307
            val = resp[2]
 
1308
            start = int(resp[3])
 
1309
            end = int(resp[4])
 
1310
            reason = str(resp[5]) # reason must always be a string
 
1311
            if val.startswith('u:'):
 
1312
                val = val[2:].decode('utf-8')
 
1313
            elif val.startswith('s:'):
 
1314
                val = val[2:].decode('base64')
 
1315
            if what == 'UnicodeDecodeError':
 
1316
                raise UnicodeDecodeError(encoding, val, start, end, reason)
 
1317
            elif what == 'UnicodeEncodeError':
 
1318
                raise UnicodeEncodeError(encoding, val, start, end, reason)
 
1319
        elif what == "ReadOnlyError":
 
1320
            raise errors.TransportNotPossible('readonly transport')
 
1321
        else:
 
1322
            raise errors.SmartProtocolError('unexpected smart server error: %r' % (resp,))
 
1323
 
 
1324
    def disconnect(self):
 
1325
        self._medium.disconnect()
 
1326
 
 
1327
    def delete_tree(self, relpath):
 
1328
        raise errors.TransportNotPossible('readonly transport')
 
1329
 
 
1330
    def stat(self, relpath):
 
1331
        resp = self._call2('stat', self._remote_path(relpath))
 
1332
        if resp[0] == 'stat':
 
1333
            return SmartStat(int(resp[1]), int(resp[2], 8))
 
1334
        else:
 
1335
            self._translate_error(resp)
 
1336
 
 
1337
    ## def lock_read(self, relpath):
 
1338
    ##     """Lock the given file for shared (read) access.
 
1339
    ##     :return: A lock object, which should be passed to Transport.unlock()
 
1340
    ##     """
 
1341
    ##     # The old RemoteBranch ignore lock for reading, so we will
 
1342
    ##     # continue that tradition and return a bogus lock object.
 
1343
    ##     class BogusLock(object):
 
1344
    ##         def __init__(self, path):
 
1345
    ##             self.path = path
 
1346
    ##         def unlock(self):
 
1347
    ##             pass
 
1348
    ##     return BogusLock(relpath)
 
1349
 
 
1350
    def listable(self):
 
1351
        return True
 
1352
 
 
1353
    def list_dir(self, relpath):
 
1354
        resp = self._call2('list_dir', self._remote_path(relpath))
 
1355
        if resp[0] == 'names':
 
1356
            return [name.encode('ascii') for name in resp[1:]]
 
1357
        else:
 
1358
            self._translate_error(resp)
 
1359
 
 
1360
    def iter_files_recursive(self):
 
1361
        resp = self._call2('iter_files_recursive', self._remote_path(''))
 
1362
        if resp[0] == 'names':
 
1363
            return resp[1:]
 
1364
        else:
 
1365
            self._translate_error(resp)
 
1366
 
 
1367
 
 
1368
class SmartClientMediumRequest(object):
 
1369
    """A request on a SmartClientMedium.
 
1370
 
 
1371
    Each request allows bytes to be provided to it via accept_bytes, and then
 
1372
    the response bytes to be read via read_bytes.
 
1373
 
 
1374
    For instance:
 
1375
    request.accept_bytes('123')
 
1376
    request.finished_writing()
 
1377
    result = request.read_bytes(3)
 
1378
    request.finished_reading()
 
1379
 
 
1380
    It is up to the individual SmartClientMedium whether multiple concurrent
 
1381
    requests can exist. See SmartClientMedium.get_request to obtain instances 
 
1382
    of SmartClientMediumRequest, and the concrete Medium you are using for 
 
1383
    details on concurrency and pipelining.
 
1384
    """
 
1385
 
 
1386
    def __init__(self, medium):
 
1387
        """Construct a SmartClientMediumRequest for the medium medium."""
 
1388
        self._medium = medium
 
1389
        # we track state by constants - we may want to use the same
 
1390
        # pattern as BodyReader if it gets more complex.
 
1391
        # valid states are: "writing", "reading", "done"
 
1392
        self._state = "writing"
 
1393
 
 
1394
    def accept_bytes(self, bytes):
 
1395
        """Accept bytes for inclusion in this request.
 
1396
 
 
1397
        This method may not be be called after finished_writing() has been
 
1398
        called.  It depends upon the Medium whether or not the bytes will be
 
1399
        immediately transmitted. Message based Mediums will tend to buffer the
 
1400
        bytes until finished_writing() is called.
 
1401
 
 
1402
        :param bytes: A bytestring.
 
1403
        """
 
1404
        if self._state != "writing":
 
1405
            raise errors.WritingCompleted(self)
 
1406
        self._accept_bytes(bytes)
 
1407
 
 
1408
    def _accept_bytes(self, bytes):
 
1409
        """Helper for accept_bytes.
 
1410
 
 
1411
        Accept_bytes checks the state of the request to determing if bytes
 
1412
        should be accepted. After that it hands off to _accept_bytes to do the
 
1413
        actual acceptance.
 
1414
        """
 
1415
        raise NotImplementedError(self._accept_bytes)
 
1416
 
 
1417
    def finished_reading(self):
 
1418
        """Inform the request that all desired data has been read.
 
1419
 
 
1420
        This will remove the request from the pipeline for its medium (if the
 
1421
        medium supports pipelining) and any further calls to methods on the
 
1422
        request will raise ReadingCompleted.
 
1423
        """
 
1424
        if self._state == "writing":
 
1425
            raise errors.WritingNotComplete(self)
 
1426
        if self._state != "reading":
 
1427
            raise errors.ReadingCompleted(self)
 
1428
        self._state = "done"
 
1429
        self._finished_reading()
 
1430
 
 
1431
    def _finished_reading(self):
 
1432
        """Helper for finished_reading.
 
1433
 
 
1434
        finished_reading checks the state of the request to determine if 
 
1435
        finished_reading is allowed, and if it is hands off to _finished_reading
 
1436
        to perform the action.
 
1437
        """
 
1438
        raise NotImplementedError(self._finished_reading)
 
1439
 
 
1440
    def finished_writing(self):
 
1441
        """Finish the writing phase of this request.
 
1442
 
 
1443
        This will flush all pending data for this request along the medium.
 
1444
        After calling finished_writing, you may not call accept_bytes anymore.
 
1445
        """
 
1446
        if self._state != "writing":
 
1447
            raise errors.WritingCompleted(self)
 
1448
        self._state = "reading"
 
1449
        self._finished_writing()
 
1450
 
 
1451
    def _finished_writing(self):
 
1452
        """Helper for finished_writing.
 
1453
 
 
1454
        finished_writing checks the state of the request to determine if 
 
1455
        finished_writing is allowed, and if it is hands off to _finished_writing
 
1456
        to perform the action.
 
1457
        """
 
1458
        raise NotImplementedError(self._finished_writing)
 
1459
 
 
1460
    def read_bytes(self, count):
 
1461
        """Read bytes from this requests response.
 
1462
 
 
1463
        This method will block and wait for count bytes to be read. It may not
 
1464
        be invoked until finished_writing() has been called - this is to ensure
 
1465
        a message-based approach to requests, for compatability with message
 
1466
        based mediums like HTTP.
 
1467
        """
 
1468
        if self._state == "writing":
 
1469
            raise errors.WritingNotComplete(self)
 
1470
        if self._state != "reading":
 
1471
            raise errors.ReadingCompleted(self)
 
1472
        return self._read_bytes(count)
 
1473
 
 
1474
    def _read_bytes(self, count):
 
1475
        """Helper for read_bytes.
 
1476
 
 
1477
        read_bytes checks the state of the request to determing if bytes
 
1478
        should be read. After that it hands off to _read_bytes to do the
 
1479
        actual read.
 
1480
        """
 
1481
        raise NotImplementedError(self._read_bytes)
 
1482
 
 
1483
 
 
1484
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
 
1485
    """A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
 
1486
 
 
1487
    def __init__(self, medium):
 
1488
        SmartClientMediumRequest.__init__(self, medium)
 
1489
        # check that we are safe concurrency wise. If some streams start
 
1490
        # allowing concurrent requests - i.e. via multiplexing - then this
 
1491
        # assert should be moved to SmartClientStreamMedium.get_request,
 
1492
        # and the setting/unsetting of _current_request likewise moved into
 
1493
        # that class : but its unneeded overhead for now. RBC 20060922
 
1494
        if self._medium._current_request is not None:
 
1495
            raise errors.TooManyConcurrentRequests(self._medium)
 
1496
        self._medium._current_request = self
 
1497
 
 
1498
    def _accept_bytes(self, bytes):
 
1499
        """See SmartClientMediumRequest._accept_bytes.
 
1500
        
 
1501
        This forwards to self._medium._accept_bytes because we are operating
 
1502
        on the mediums stream.
 
1503
        """
 
1504
        self._medium._accept_bytes(bytes)
 
1505
 
 
1506
    def _finished_reading(self):
 
1507
        """See SmartClientMediumRequest._finished_reading.
 
1508
 
 
1509
        This clears the _current_request on self._medium to allow a new 
 
1510
        request to be created.
 
1511
        """
 
1512
        assert self._medium._current_request is self
 
1513
        self._medium._current_request = None
 
1514
        
 
1515
    def _finished_writing(self):
 
1516
        """See SmartClientMediumRequest._finished_writing.
 
1517
 
 
1518
        This invokes self._medium._flush to ensure all bytes are transmitted.
 
1519
        """
 
1520
        self._medium._flush()
 
1521
 
 
1522
    def _read_bytes(self, count):
 
1523
        """See SmartClientMediumRequest._read_bytes.
 
1524
        
 
1525
        This forwards to self._medium._read_bytes because we are operating
 
1526
        on the mediums stream.
 
1527
        """
 
1528
        return self._medium._read_bytes(count)
 
1529
 
 
1530
 
 
1531
class SmartClientRequestProtocolOne(SmartProtocolBase):
 
1532
    """The client-side protocol for smart version 1."""
 
1533
 
 
1534
    def __init__(self, request):
 
1535
        """Construct a SmartClientRequestProtocolOne.
 
1536
 
 
1537
        :param request: A SmartClientMediumRequest to serialise onto and
 
1538
            deserialise from.
 
1539
        """
 
1540
        self._request = request
 
1541
        self._body_buffer = None
 
1542
 
 
1543
    def call(self, *args):
 
1544
        bytes = _encode_tuple(args)
 
1545
        self._request.accept_bytes(bytes)
 
1546
        self._request.finished_writing()
 
1547
 
 
1548
    def call_with_body_bytes(self, args, body):
 
1549
        """Make a remote call of args with body bytes 'body'.
 
1550
 
 
1551
        After calling this, call read_response_tuple to find the result out.
 
1552
        """
 
1553
        bytes = _encode_tuple(args)
 
1554
        self._request.accept_bytes(bytes)
 
1555
        bytes = self._encode_bulk_data(body)
 
1556
        self._request.accept_bytes(bytes)
 
1557
        self._request.finished_writing()
 
1558
 
 
1559
    def call_with_body_readv_array(self, args, body):
 
1560
        """Make a remote call with a readv array.
 
1561
 
 
1562
        The body is encoded with one line per readv offset pair. The numbers in
 
1563
        each pair are separated by a comma, and no trailing \n is emitted.
 
1564
        """
 
1565
        bytes = _encode_tuple(args)
 
1566
        self._request.accept_bytes(bytes)
 
1567
        readv_bytes = self._serialise_offsets(body)
 
1568
        bytes = self._encode_bulk_data(readv_bytes)
 
1569
        self._request.accept_bytes(bytes)
 
1570
        self._request.finished_writing()
 
1571
 
 
1572
    def cancel_read_body(self):
 
1573
        """After expecting a body, a response code may indicate one otherwise.
 
1574
 
 
1575
        This method lets the domain client inform the protocol that no body
 
1576
        will be transmitted. This is a terminal method: after calling it the
 
1577
        protocol is not able to be used further.
 
1578
        """
 
1579
        self._request.finished_reading()
 
1580
 
 
1581
    def read_response_tuple(self, expect_body=False):
 
1582
        """Read a response tuple from the wire.
 
1583
 
 
1584
        This should only be called once.
 
1585
        """
 
1586
        result = self._recv_tuple()
 
1587
        if not expect_body:
 
1588
            self._request.finished_reading()
 
1589
        return result
 
1590
 
 
1591
    def read_body_bytes(self, count=-1):
 
1592
        """Read bytes from the body, decoding into a byte stream.
 
1593
        
 
1594
        We read all bytes at once to ensure we've checked the trailer for 
 
1595
        errors, and then feed the buffer back as read_body_bytes is called.
 
1596
        """
 
1597
        if self._body_buffer is not None:
 
1598
            return self._body_buffer.read(count)
 
1599
        _body_decoder = LengthPrefixedBodyDecoder()
 
1600
 
 
1601
        while not _body_decoder.finished_reading:
 
1602
            bytes_wanted = _body_decoder.next_read_size()
 
1603
            bytes = self._request.read_bytes(bytes_wanted)
 
1604
            _body_decoder.accept_bytes(bytes)
 
1605
        self._request.finished_reading()
 
1606
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
 
1607
        # XXX: TODO check the trailer result.
 
1608
        return self._body_buffer.read(count)
 
1609
 
 
1610
    def _recv_tuple(self):
 
1611
        """Receive a tuple from the medium request."""
 
1612
        line = ''
 
1613
        while not line or line[-1] != '\n':
 
1614
            # TODO: this is inefficient - but tuples are short.
 
1615
            new_char = self._request.read_bytes(1)
 
1616
            line += new_char
 
1617
            assert new_char != '', "end of file reading from server."
 
1618
        return _decode_tuple(line)
 
1619
 
 
1620
    def query_version(self):
 
1621
        """Return protocol version number of the server."""
 
1622
        self.call('hello')
 
1623
        resp = self.read_response_tuple()
 
1624
        if resp == ('ok', '1'):
 
1625
            return 1
 
1626
        else:
 
1627
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
1628
 
 
1629
 
 
1630
class SmartClientMedium(object):
 
1631
    """Smart client is a medium for sending smart protocol requests over."""
 
1632
 
 
1633
    def disconnect(self):
 
1634
        """If this medium maintains a persistent connection, close it.
 
1635
        
 
1636
        The default implementation does nothing.
 
1637
        """
 
1638
        
 
1639
 
 
1640
class SmartClientStreamMedium(SmartClientMedium):
 
1641
    """Stream based medium common class.
 
1642
 
 
1643
    SmartClientStreamMediums operate on a stream. All subclasses use a common
 
1644
    SmartClientStreamMediumRequest for their requests, and should implement
 
1645
    _accept_bytes and _read_bytes to allow the request objects to send and
 
1646
    receive bytes.
 
1647
    """
 
1648
 
 
1649
    def __init__(self):
 
1650
        self._current_request = None
 
1651
 
 
1652
    def accept_bytes(self, bytes):
 
1653
        self._accept_bytes(bytes)
 
1654
 
 
1655
    def __del__(self):
 
1656
        """The SmartClientStreamMedium knows how to close the stream when it is
 
1657
        finished with it.
 
1658
        """
 
1659
        self.disconnect()
 
1660
 
 
1661
    def _flush(self):
 
1662
        """Flush the output stream.
 
1663
        
 
1664
        This method is used by the SmartClientStreamMediumRequest to ensure that
 
1665
        all data for a request is sent, to avoid long timeouts or deadlocks.
 
1666
        """
 
1667
        raise NotImplementedError(self._flush)
 
1668
 
 
1669
    def get_request(self):
 
1670
        """See SmartClientMedium.get_request().
 
1671
 
 
1672
        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
 
1673
        for get_request.
 
1674
        """
 
1675
        return SmartClientStreamMediumRequest(self)
 
1676
 
 
1677
    def read_bytes(self, count):
 
1678
        return self._read_bytes(count)
 
1679
 
 
1680
 
 
1681
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
 
1682
    """A client medium using simple pipes.
 
1683
    
 
1684
    This client does not manage the pipes: it assumes they will always be open.
 
1685
    """
 
1686
 
 
1687
    def __init__(self, readable_pipe, writeable_pipe):
 
1688
        SmartClientStreamMedium.__init__(self)
 
1689
        self._readable_pipe = readable_pipe
 
1690
        self._writeable_pipe = writeable_pipe
 
1691
 
 
1692
    def _accept_bytes(self, bytes):
 
1693
        """See SmartClientStreamMedium.accept_bytes."""
 
1694
        self._writeable_pipe.write(bytes)
 
1695
 
 
1696
    def _flush(self):
 
1697
        """See SmartClientStreamMedium._flush()."""
 
1698
        self._writeable_pipe.flush()
 
1699
 
 
1700
    def _read_bytes(self, count):
 
1701
        """See SmartClientStreamMedium._read_bytes."""
 
1702
        return self._readable_pipe.read(count)
 
1703
 
 
1704
 
 
1705
class SmartSSHClientMedium(SmartClientStreamMedium):
 
1706
    """A client medium using SSH."""
 
1707
    
 
1708
    def __init__(self, host, port=None, username=None, password=None,
 
1709
            vendor=None):
 
1710
        """Creates a client that will connect on the first use.
 
1711
        
 
1712
        :param vendor: An optional override for the ssh vendor to use. See
 
1713
            bzrlib.transport.ssh for details on ssh vendors.
 
1714
        """
 
1715
        SmartClientStreamMedium.__init__(self)
 
1716
        self._connected = False
 
1717
        self._host = host
 
1718
        self._password = password
 
1719
        self._port = port
 
1720
        self._username = username
 
1721
        self._read_from = None
 
1722
        self._ssh_connection = None
 
1723
        self._vendor = vendor
 
1724
        self._write_to = None
 
1725
 
 
1726
    def _accept_bytes(self, bytes):
 
1727
        """See SmartClientStreamMedium.accept_bytes."""
 
1728
        self._ensure_connection()
 
1729
        self._write_to.write(bytes)
 
1730
 
 
1731
    def disconnect(self):
 
1732
        """See SmartClientMedium.disconnect()."""
 
1733
        if not self._connected:
 
1734
            return
 
1735
        self._read_from.close()
 
1736
        self._write_to.close()
 
1737
        self._ssh_connection.close()
 
1738
        self._connected = False
 
1739
 
 
1740
    def _ensure_connection(self):
 
1741
        """Connect this medium if not already connected."""
 
1742
        if self._connected:
 
1743
            return
 
1744
        executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
 
1745
        if self._vendor is None:
 
1746
            vendor = ssh._get_ssh_vendor()
 
1747
        else:
 
1748
            vendor = self._vendor
 
1749
        self._ssh_connection = vendor.connect_ssh(self._username,
 
1750
                self._password, self._host, self._port,
 
1751
                command=[executable, 'serve', '--inet', '--directory=/',
 
1752
                         '--allow-writes'])
 
1753
        self._read_from, self._write_to = \
 
1754
            self._ssh_connection.get_filelike_channels()
 
1755
        self._connected = True
 
1756
 
 
1757
    def _flush(self):
 
1758
        """See SmartClientStreamMedium._flush()."""
 
1759
        self._write_to.flush()
 
1760
 
 
1761
    def _read_bytes(self, count):
 
1762
        """See SmartClientStreamMedium.read_bytes."""
 
1763
        if not self._connected:
 
1764
            raise errors.MediumNotConnected(self)
 
1765
        return self._read_from.read(count)
 
1766
 
 
1767
 
 
1768
class SmartTCPClientMedium(SmartClientStreamMedium):
 
1769
    """A client medium using TCP."""
 
1770
    
 
1771
    def __init__(self, host, port):
 
1772
        """Creates a client that will connect on the first use."""
 
1773
        SmartClientStreamMedium.__init__(self)
 
1774
        self._connected = False
 
1775
        self._host = host
 
1776
        self._port = port
 
1777
        self._socket = None
 
1778
 
 
1779
    def _accept_bytes(self, bytes):
 
1780
        """See SmartClientMedium.accept_bytes."""
 
1781
        self._ensure_connection()
 
1782
        self._socket.sendall(bytes)
 
1783
 
 
1784
    def disconnect(self):
 
1785
        """See SmartClientMedium.disconnect()."""
 
1786
        if not self._connected:
 
1787
            return
 
1788
        self._socket.close()
 
1789
        self._socket = None
 
1790
        self._connected = False
 
1791
 
 
1792
    def _ensure_connection(self):
 
1793
        """Connect this medium if not already connected."""
 
1794
        if self._connected:
 
1795
            return
 
1796
        self._socket = socket.socket()
 
1797
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
1798
        result = self._socket.connect_ex((self._host, int(self._port)))
 
1799
        if result:
 
1800
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
 
1801
                    (self._host, self._port, os.strerror(result)))
 
1802
        self._connected = True
 
1803
 
 
1804
    def _flush(self):
 
1805
        """See SmartClientStreamMedium._flush().
 
1806
        
 
1807
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
 
1808
        add a means to do a flush, but that can be done in the future.
 
1809
        """
 
1810
 
 
1811
    def _read_bytes(self, count):
 
1812
        """See SmartClientMedium.read_bytes."""
 
1813
        if not self._connected:
 
1814
            raise errors.MediumNotConnected(self)
 
1815
        return self._socket.recv(count)
 
1816
 
 
1817
 
 
1818
class SmartTCPTransport(SmartTransport):
 
1819
    """Connection to smart server over plain tcp.
 
1820
    
 
1821
    This is essentially just a factory to get 'RemoteTransport(url,
 
1822
        SmartTCPClientMedium).
 
1823
    """
 
1824
 
 
1825
    def __init__(self, url):
 
1826
        _scheme, _username, _password, _host, _port, _path = \
 
1827
            transport.split_url(url)
 
1828
        if _port is None:
 
1829
            _port = BZR_DEFAULT_PORT
 
1830
        else:
 
1831
            try:
 
1832
                _port = int(_port)
 
1833
            except (ValueError, TypeError), e:
 
1834
                raise errors.InvalidURL(
 
1835
                    path=url, extra="invalid port %s" % _port)
 
1836
        medium = SmartTCPClientMedium(_host, _port)
 
1837
        super(SmartTCPTransport, self).__init__(url, medium=medium)
 
1838
 
 
1839
 
 
1840
class SmartSSHTransport(SmartTransport):
 
1841
    """Connection to smart server over SSH.
 
1842
 
 
1843
    This is essentially just a factory to get 'RemoteTransport(url,
 
1844
        SmartSSHClientMedium).
 
1845
    """
 
1846
 
 
1847
    def __init__(self, url):
 
1848
        _scheme, _username, _password, _host, _port, _path = \
 
1849
            transport.split_url(url)
 
1850
        try:
 
1851
            if _port is not None:
 
1852
                _port = int(_port)
 
1853
        except (ValueError, TypeError), e:
 
1854
            raise errors.InvalidURL(path=url, extra="invalid port %s" % 
 
1855
                _port)
 
1856
        medium = SmartSSHClientMedium(_host, _port, _username, _password)
 
1857
        super(SmartSSHTransport, self).__init__(url, medium=medium)
 
1858
 
 
1859
 
 
1860
class SmartHTTPTransport(SmartTransport):
 
1861
    """Just a way to connect between a bzr+http:// url and http://.
 
1862
    
 
1863
    This connection operates slightly differently than the SmartSSHTransport.
 
1864
    It uses a plain http:// transport underneath, which defines what remote
 
1865
    .bzr/smart URL we are connected to. From there, all paths that are sent are
 
1866
    sent as relative paths, this way, the remote side can properly
 
1867
    de-reference them, since it is likely doing rewrite rules to translate an
 
1868
    HTTP path into a local path.
 
1869
    """
 
1870
 
 
1871
    def __init__(self, url, http_transport=None):
 
1872
        assert url.startswith('bzr+http://')
 
1873
 
 
1874
        if http_transport is None:
 
1875
            http_url = url[len('bzr+'):]
 
1876
            self._http_transport = transport.get_transport(http_url)
 
1877
        else:
 
1878
            self._http_transport = http_transport
 
1879
        http_medium = self._http_transport.get_smart_medium()
 
1880
        super(SmartHTTPTransport, self).__init__(url, medium=http_medium)
 
1881
 
 
1882
    def _remote_path(self, relpath):
 
1883
        """After connecting HTTP Transport only deals in relative URLs."""
 
1884
        if relpath == '.':
 
1885
            return ''
 
1886
        else:
 
1887
            return relpath
 
1888
 
 
1889
    def abspath(self, relpath):
 
1890
        """Return the full url to the given relative path.
 
1891
        
 
1892
        :param relpath: the relative path or path components
 
1893
        :type relpath: str or list
 
1894
        """
 
1895
        return self._unparse_url(self._combine_paths(self._path, relpath))
 
1896
 
 
1897
    def clone(self, relative_url):
 
1898
        """Make a new SmartHTTPTransport related to me.
 
1899
 
 
1900
        This is re-implemented rather than using the default
 
1901
        SmartTransport.clone() because we must be careful about the underlying
 
1902
        http transport.
 
1903
        """
 
1904
        if relative_url:
 
1905
            abs_url = self.abspath(relative_url)
 
1906
        else:
 
1907
            abs_url = self.base
 
1908
        # By cloning the underlying http_transport, we are able to share the
 
1909
        # connection.
 
1910
        new_transport = self._http_transport.clone(relative_url)
 
1911
        return SmartHTTPTransport(abs_url, http_transport=new_transport)
 
1912
 
 
1913
 
 
1914
def get_test_permutations():
 
1915
    """Return (transport, server) permutations for testing."""
 
1916
    ### We may need a little more test framework support to construct an
 
1917
    ### appropriate RemoteTransport in the future.
 
1918
    return [(SmartTCPTransport, SmartTCPServer_for_testing)]