/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to dulwich/client.py

  • Committer: Jelmer Vernooij
  • Date: 2009-01-14 18:24:38 UTC
  • mto: (0.222.3 dulwich)
  • mto: This revision was merged to the branch mainline in revision 6960.
  • Revision ID: jelmer@samba.org-20090114182438-c0tn5eczyupi4ztn
Fix download url, add version number.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# client.py -- Implementation of the client side git protocols
 
2
# Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
 
3
#
 
4
# This program is free software; you can redistribute it and/or
 
5
# modify it under the terms of the GNU General Public License
 
6
# as published by the Free Software Foundation; either version 2
 
7
# or (at your option) a later version of the License.
 
8
#
 
9
# This program is distributed in the hope that it will be useful,
 
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
12
# GNU General Public License for more details.
 
13
#
 
14
# You should have received a copy of the GNU General Public License
 
15
# along with this program; if not, write to the Free Software
 
16
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 
17
# MA  02110-1301, USA.
 
18
 
 
19
import select
 
20
import socket
 
21
from dulwich.protocol import Protocol, TCP_GIT_PORT, extract_capabilities
 
22
 
 
23
class SimpleFetchGraphWalker(object):
    """Walk local history to produce "have" candidates for fetch negotiation.

    Starting from the local heads, next() hands out commits one at a time
    and queues their parents; ack() prunes a commit the server reported as
    common, together with every ancestor reachable through commits whose
    parents have already been looked up.
    """

    def __init__(self, local_heads, get_parents):
        """Create a walker.

        :param local_heads: Iterable of commit shas to start walking from.
        :param get_parents: Callable taking a commit sha and returning the
            list of its parent shas.
        """
        self.heads = set(local_heads)
        self.get_parents = get_parents
        # Maps every commit already returned by next() to its parents.
        self.parents = {}

    def ack(self, ref):
        """Mark ``ref`` as common with the server.

        Removes ``ref`` and all of its ancestors known through
        ``self.parents`` from the pending heads so they are not offered.
        Implemented iteratively with a visited set: long histories cannot
        exhaust the recursion limit and shared ancestors (merges) are only
        processed once.
        """
        pending = [ref]
        visited = set()
        while pending:
            sha = pending.pop()
            if sha in visited:
                continue
            visited.add(sha)
            self.heads.discard(sha)
            pending.extend(self.parents.get(sha, ()))

    def next(self):
        """Return the next commit sha to offer, or None when exhausted."""
        if not self.heads:
            return None
        ret = self.heads.pop()
        ps = self.get_parents(ret)
        self.parents[ret] = ps
        # Skip parents that were already handed out; otherwise a commit
        # reachable from more than one head would be offered (and looked
        # up) more than once.
        self.heads.update(p for p in ps if p not in self.parents)
        return ret
 
45
 
 
46
 
 
47
class GitClient(object):
    """Git smart server client.

    Implements the client side of ref advertisement and pack negotiation on
    top of a dulwich Protocol object. Transport-specific subclasses (such as
    TCPGitClient) send the service request before delegating to
    send_pack()/fetch_pack() here.
    """

    def __init__(self, fileno, read, write):
        """Create a client over an established connection.

        :param fileno: File descriptor of the connection; polled with
            select() during fetch negotiation.
        :param read: Callable the Protocol reads data with.
        :param write: Callable the Protocol writes data with.
        """
        self.proto = Protocol(read, write)
        self.fileno = fileno

    def capabilities(self):
        """Return the space-separated capabilities this client requests."""
        return "multi_ack side-band-64k thin-pack ofs-delta"

    @staticmethod
    def _parse_ack(pkt):
        """Parse an ``ACK <sha> [<status>]`` packet line.

        :param pkt: Packet line, possibly with a trailing newline, or None.
        :return: Tuple of (sha, status), where status is the third word or
            None when absent; None if the line is not an ACK with a sha.
        """
        if not pkt:
            return None
        parts = pkt.rstrip("\n").split(" ")
        if len(parts) < 2 or parts[0] != "ACK":
            return None
        if len(parts) > 2:
            status = parts[2]
        else:
            status = None
        return parts[1], status

    def read_refs(self):
        """Read the ref advertisement sent by the server.

        The first advertised line carries the server capabilities after the
        ref name; they are split off with extract_capabilities(). A
        placeholder ``capabilities^{}`` entry with an all-zero sha is not
        included in the result.

        :return: Tuple of (dict mapping ref name to sha, server capabilities
            or None if no lines were advertised).
        """
        server_capabilities = None
        refs = {}
        for pkt in self.proto.read_pkt_seq():
            (sha, ref) = pkt.rstrip("\n").split(" ", 1)
            if server_capabilities is None:
                (ref, server_capabilities) = extract_capabilities(ref)
            if not (ref == "capabilities^{}" and sha == "0" * 40):
                refs[ref] = sha
        return refs, server_capabilities

    def send_pack(self, path):
        """Upload changes to a git smart server.

        Incomplete: the list of refs to update is not computed yet (see the
        FIXME), so currently this only consumes the ref advertisement and
        sends a flush packet.

        :param path: Repository path. Not used here; subclasses send it with
            the service request.
        """
        # The advertisement must be consumed even though it is unused so far.
        refs, server_capabilities = self.read_refs()
        # Entries are (old_sha, new_sha, ref_name) tuples.
        changed_refs = [] # FIXME
        if not changed_refs:
            self.proto.write_pkt_line(None)
            return
        want = []
        have = []
        for i, (old_sha, new_sha, ref_name) in enumerate(changed_refs):
            line = "%s %s %s" % (old_sha, new_sha, ref_name)
            if i == 0:
                # Capabilities ride on the first update line only, after a
                # NUL byte.
                line += "\0" + self.capabilities()
            self.proto.write_pkt_line(line)
            want.append(new_sha)
            if old_sha != "0" * 40:
                have.append(old_sha)
        self.proto.write_pkt_line(None)
        # FIXME: pack generation is implementation specific, e.g.:
        #   shas = generate_pack_contents(want, have, None)
        #   write_pack_data(self.write, shas, len(shas))

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
        """Retrieve a pack from a git smart server.

        :param path: Repository path. Not used here; subclasses send it with
            the service request.
        :param determine_wants: Callback taking the dict of advertised refs
            (ref name -> sha) and returning the list of commits to fetch.
        :param graph_walker: Object with next() and ack().
        :param pack_data: Callback called for each bit of data in the pack
            (sideband channel 1).
        :param progress: Callback for progress reports (strings, sideband
            channel 2).
        :raise AssertionError: On an unexpected ACK during negotiation or an
            unknown sideband channel.
        """
        (refs, server_capabilities) = self.read_refs()
        wants = determine_wants(refs)
        if not wants:
            self.proto.write_pkt_line(None)
            return
        self.proto.write_pkt_line("want %s %s\n" % (wants[0], self.capabilities()))
        for want in wants[1:]:
            self.proto.write_pkt_line("want %s\n" % want)
        self.proto.write_pkt_line(None)

        # Offer "have" lines, polling without blocking for ACKs already sent.
        # NOTE(review): if the read callable is buffered (TCPGitClient uses a
        # buffered makefile), data sitting in that buffer is invisible to
        # select(), so some ACKs may only be handled after "done".
        have = graph_walker.next()
        while have:
            self.proto.write_pkt_line("have %s\n" % have)
            if select.select([self.fileno], [], [], 0)[0]:
                pkt = self.proto.read_pkt_line()
                ack = self._parse_ack(pkt)
                if ack is not None:
                    graph_walker.ack(ack[0])
                    if ack[1] != "continue":
                        raise AssertionError(
                            "Expected 'ACK <sha> continue' during negotiation, got %r" % (pkt,))
            have = graph_walker.next()
        self.proto.write_pkt_line("done\n")

        # Drain remaining acknowledgements; anything other than
        # "ACK <sha> continue" (e.g. a final ACK or NAK) ends this phase.
        pkt = self.proto.read_pkt_line()
        while pkt:
            ack = self._parse_ack(pkt)
            if ack is None:
                break
            graph_walker.ack(ack[0])
            if ack[1] != "continue":
                break
            pkt = self.proto.read_pkt_line()

        # Demultiplex the sideband stream: the first byte selects the channel.
        for pkt in self.proto.read_pkt_seq():
            channel = ord(pkt[0])
            pkt = pkt[1:]
            if channel == 1:
                pack_data(pkt)
            elif channel == 2:
                progress(pkt)
            else:
                raise AssertionError("Invalid sideband channel %d" % channel)
 
136
 
 
137
 
 
138
class TCPGitClient(GitClient):
    """Git client talking to a git daemon over a TCP connection."""

    def __init__(self, host, port=TCP_GIT_PORT):
        """Connect to ``host``:``port``.

        :param host: Host name or address of the git daemon; also sent as
            the ``host=`` argument of service requests.
        :param port: TCP port to connect to (defaults to TCP_GIT_PORT).
        :raise socket.error: If the connection cannot be established. The
            socket is closed before the error propagates.
        """
        self._socket = socket.socket(type=socket.SOCK_STREAM)
        try:
            self._socket.connect((host, port))
        except Exception:
            # Do not leak the descriptor when the connection fails.
            self._socket.close()
            raise
        self.rfile = self._socket.makefile('rb', -1)
        self.wfile = self._socket.makefile('wb', 0)
        self.host = host
        super(TCPGitClient, self).__init__(self._socket.fileno(), self.rfile.read, self.wfile.write)

    def close(self):
        """Close the file wrappers and the underlying socket."""
        self.rfile.close()
        self.wfile.close()
        self._socket.close()

    def send_pack(self, path):
        """Request git-receive-pack for ``path``, then run the base upload."""
        self.proto.send_cmd("git-receive-pack", path, "host=%s" % self.host)
        return super(TCPGitClient, self).send_pack(path)

    def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
        """Request git-upload-pack for ``path``, then run the base fetch.

        See GitClient.fetch_pack for the meaning of the callback arguments.
        """
        self.proto.send_cmd("git-upload-pack", path, "host=%s" % self.host)
        return super(TCPGitClient, self).fetch_pack(path, determine_wants, graph_walker, pack_data, progress)