ca32e99bdc3f31656853b1497989e648bfc83553
[teamcity/git-plugin.git] / git-agent / src / jetbrains / buildServer / buildTriggers / vcs / git / agent / JSchClient.java
1 /*
2  * Copyright 2000-2018 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package jetbrains.buildServer.buildTriggers.vcs.git.agent;
18
19 import com.jcraft.jsch.ChannelExec;
20 import com.jcraft.jsch.JSch;
21 import com.jcraft.jsch.Logger;
22 import com.jcraft.jsch.Session;
23 import jetbrains.buildServer.buildTriggers.vcs.git.GitUtils;
24 import org.jetbrains.annotations.NotNull;
25 import org.jetbrains.annotations.Nullable;
26 import org.jetbrains.git4idea.ssh.GitSSHHandler;
27
28 import javax.security.auth.callback.Callback;
29 import javax.security.auth.callback.CallbackHandler;
30 import javax.security.auth.callback.UnsupportedCallbackException;
31 import java.io.File;
32 import java.io.InputStream;
33 import java.security.Security;
34 import java.text.SimpleDateFormat;
35 import java.util.ArrayList;
36 import java.util.Arrays;
37 import java.util.Date;
38 import java.util.List;
39
40 public class JSchClient {
41
42   private final static int BUF_SIZE = 32 * 1024;
43
44   private final String myHost;
45   private final String myUsername;
46   private final Integer myPort;
47   private final String myCommand;
48   private final Logger myLogger;
49
50   private JSchClient(@NotNull String host,
51                      @Nullable String username,
52                      @Nullable Integer port,
53                      @NotNull String command,
54                      @NotNull Logger logger) {
55     myHost = host;
56     myUsername = username;
57     myPort = port;
58     myCommand = command;
59     myLogger = logger;
60   }
61
62
63   public static void main(String... args) {
64     boolean debug = Boolean.parseBoolean(System.getenv(GitSSHHandler.TEAMCITY_DEBUG_SSH));
65     Logger logger = debug ? new StdErrLogger() : new InMemoryLogger(Logger.INFO);
66     try {
67       JSchClient ssh = createClient(logger, args);
68       ssh.run();
69     } catch (Throwable t) {
70       if (logger instanceof InMemoryLogger) {
71         ((InMemoryLogger)logger).printLog();
72       }
73       System.err.println(t.getMessage());
74       if (t instanceof NullPointerException || debug)
75         t.printStackTrace();
76       System.exit(1);
77     }
78   }
79
80
81   private static JSchClient createClient(@NotNull Logger logger, String[] args) {
82     if (args.length != 2 && args.length != 4) {
83       System.err.println("Invalid arguments " + Arrays.asList(args));
84       System.exit(1);
85     }
86
87     int i = 0;
88     Integer port = null;
89     //noinspection HardCodedStringLiteral
90     if ("-p".equals(args[i])) {
91       i++;
92       port = Integer.parseInt(args[i++]);
93     }
94     String host = args[i++];
95     String user;
96     int atIndex = host.lastIndexOf('@');
97     if (atIndex == -1) {
98       user = null;
99     }
100     else {
101       user = host.substring(0, atIndex);
102       host = host.substring(atIndex + 1);
103     }
104     String command = args[i];
105     return new JSchClient(host, user, port, command, logger);
106   }
107
108
109   public void run() throws Exception {
110     ChannelExec channel = null;
111     Session session = null;
112     try {
113       JSch.setLogger(myLogger);
114       JSch jsch = new JSch();
115       String privateKeyPath = System.getenv(GitSSHHandler.TEAMCITY_PRIVATE_KEY_PATH);
116       if (privateKeyPath != null) {
117         jsch.addIdentity(privateKeyPath, System.getenv(GitSSHHandler.TEAMCITY_PASSPHRASE));
118       } else {
119         String userHome = System.getProperty("user.home");
120         if (userHome != null) {
121           File homeDir = new File(userHome);
122           File ssh = new File(homeDir, ".ssh");
123           File rsa = new File(ssh, "id_rsa");
124           if (rsa.isFile()) {
125             jsch.addIdentity(rsa.getAbsolutePath());
126           }
127           File dsa = new File(ssh, "id_dsa");
128           if (dsa.isFile()) {
129             jsch.addIdentity(dsa.getAbsolutePath());
130           }
131         }
132       }
133       session = jsch.getSession(myUsername, myHost, myPort != null ? myPort : 22);
134
135       String teamCityVersion = System.getenv(GitSSHHandler.TEAMCITY_VERSION);
136       if (teamCityVersion != null) {
137         session.setClientVersion(GitUtils.getSshClientVersion(session.getClientVersion(), teamCityVersion));
138       }
139
140       if (Boolean.parseBoolean(System.getenv(GitSSHHandler.SSH_IGNORE_KNOWN_HOSTS_ENV))) {
141         session.setConfig("StrictHostKeyChecking", "no");
142       } else {
143         String userHome = System.getProperty("user.home");
144         if (userHome != null) {
145           File homeDir = new File(userHome);
146           File ssh = new File(homeDir, ".ssh");
147           File knownHosts = new File(ssh, "known_hosts");
148           if (knownHosts.isFile()) {
149             try {
150               jsch.setKnownHosts(knownHosts.getAbsolutePath());
151             } catch (Exception e) {
152               myLogger.log(Logger.WARN, "Failed to configure known hosts: '" + e.toString() + "'");
153             }
154           }
155         }
156       }
157
158       String authMethods = System.getenv(GitSSHHandler.TEAMCITY_SSH_PREFERRED_AUTH_METHODS);
159       if (authMethods != null && authMethods.length() > 0)
160         session.setConfig("PreferredAuthentications", authMethods);
161
162       EmptySecurityCallbackHandler.install();
163
164       session.connect();
165
166       channel = (ChannelExec) session.openChannel("exec");
167       channel.setPty(false);
168       channel.setCommand(myCommand);
169       channel.setInputStream(System.in);
170       channel.setErrStream(System.err);
171       channel.connect();
172
173       InputStream input = channel.getInputStream();
174       byte[] buffer = new byte[BUF_SIZE];
175       int count;
176       while ((count = input.read(buffer)) != -1) {
177         System.out.write(buffer, 0, count);
178       }
179     } finally {
180       if (channel != null)
181         channel.disconnect();
182       if (session != null)
183         session.disconnect();
184     }
185   }
186
187
188   private static class StdErrLogger implements Logger {
189     private final SimpleDateFormat myDateFormat = new SimpleDateFormat("[HH:mm:ss.SSS]");
190     @Override
191     public boolean isEnabled(final int level) {
192       return true;
193     }
194
195     @Override
196     public void log(final int level, final String message) {
197       System.err.print(getTimestamp());
198       System.err.print(" ");
199       System.err.print(getLevel(level));
200       System.err.print(" ");
201       System.err.println(message);
202     }
203
204     @NotNull
205     private String getTimestamp() {
206       synchronized (myDateFormat) {
207         return myDateFormat.format(new Date());
208       }
209     }
210   }
211
212
213   private static class InMemoryLogger implements Logger {
214     private final int myMinLogLevel;
215     private final List<LogEntry> myLogEntries;
216     InMemoryLogger(int minLogLevel) {
217       myMinLogLevel = minLogLevel;
218       myLogEntries = new ArrayList<LogEntry>();
219     }
220
221     @Override
222     public boolean isEnabled(final int level) {
223       return level >= myMinLogLevel;
224     }
225
226     @Override
227     public void log(final int level, final String message) {
228       if (isEnabled(level)) {
229         synchronized (myLogEntries) {
230           myLogEntries.add(new LogEntry(System.currentTimeMillis(), level, message));
231         }
232       }
233     }
234
235     void printLog() {
236       SimpleDateFormat dateFormat = new SimpleDateFormat("[HH:mm:ss.SSS]");
237       synchronized (myLogEntries) {
238         for (LogEntry entry : myLogEntries) {
239           System.err.print(dateFormat.format(new Date(entry.myTimestamp)));
240           System.err.print(" ");
241           System.err.print(getLevel(entry.myLogLevel));
242           System.err.print(" ");
243           System.err.println(entry.myMessage);
244         }
245       }
246     }
247
248     private static class LogEntry {
249       private final long myTimestamp;
250       private final int myLogLevel;
251       private final String myMessage;
252       LogEntry(long timestamp, int logLevel, @NotNull String message) {
253         myTimestamp = timestamp;
254         myLogLevel = logLevel;
255         myMessage = message;
256       }
257     }
258   }
259
260
261   @NotNull
262   private static String getLevel(int level) {
263     switch (level) {
264       case Logger.DEBUG:
265         return "DEBUG";
266       case Logger.INFO:
267         return "INFO";
268       case Logger.WARN:
269         return "WARN";
270       case Logger.ERROR:
271         return "ERROR";
272       case Logger.FATAL:
273         return "FATAL";
274       default:
275         return "UNKNOWN";
276     }
277   }
278
279
280   // Doesn't provide any credentials, used instead the default handler from jdk
281   // which reads credentials them from stdin.
282   public static class EmptySecurityCallbackHandler implements CallbackHandler {
283     @Override
284     public void handle(final Callback[] callbacks) throws UnsupportedCallbackException {
285       if (callbacks.length > 0) {
286         throw new UnsupportedCallbackException(callbacks[0], "Unsupported callback");
287       }
288     }
289
290     static void install() {
291       Security.setProperty("auth.login.defaultCallbackHandler", EmptySecurityCallbackHandler.class.getName());
292     }
293   }
294 }