Add timeout to ssh client
[teamcity/git-plugin.git] / git-agent / src / jetbrains / buildServer / buildTriggers / vcs / git / agent / JSchClient.java
1 /*
2  * Copyright 2000-2018 JetBrains s.r.o.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package jetbrains.buildServer.buildTriggers.vcs.git.agent;
18
19 import com.jcraft.jsch.ChannelExec;
20 import com.jcraft.jsch.JSch;
21 import com.jcraft.jsch.Logger;
22 import com.jcraft.jsch.Session;
23 import jetbrains.buildServer.buildTriggers.vcs.git.GitUtils;
24 import org.jetbrains.annotations.NotNull;
25 import org.jetbrains.annotations.Nullable;
26 import org.jetbrains.git4idea.ssh.GitSSHHandler;
27
28 import javax.security.auth.callback.Callback;
29 import javax.security.auth.callback.CallbackHandler;
30 import javax.security.auth.callback.UnsupportedCallbackException;
31 import java.io.File;
32 import java.io.IOException;
33 import java.io.InputStream;
34 import java.security.Security;
35 import java.text.SimpleDateFormat;
36 import java.util.*;
37 import java.util.concurrent.TimeUnit;
38 import java.util.concurrent.atomic.AtomicLong;
39
40 public class JSchClient {
41
42   private final static int BUF_SIZE = 32 * 1024;
43
44   private final String myHost;
45   private final String myUsername;
46   private final Integer myPort;
47   private final String myCommand;
48   private final Logger myLogger;
49
50   private JSchClient(@NotNull String host,
51                      @Nullable String username,
52                      @Nullable Integer port,
53                      @NotNull String command,
54                      @NotNull Logger logger) {
55     myHost = host;
56     myUsername = username;
57     myPort = port;
58     myCommand = command;
59     myLogger = logger;
60   }
61
62
63   public static void main(String... args) {
64     boolean debug = Boolean.parseBoolean(System.getenv(GitSSHHandler.TEAMCITY_DEBUG_SSH));
65     Logger logger = debug ? new StdErrLogger() : new InMemoryLogger(Logger.INFO);
66     try {
67       JSchClient ssh = createClient(logger, args);
68       ssh.run();
69     } catch (Throwable t) {
70       if (logger instanceof InMemoryLogger) {
71         ((InMemoryLogger)logger).printLog();
72       }
73       System.err.println(t.getMessage());
74       if (t instanceof NullPointerException || debug)
75         t.printStackTrace();
76       System.exit(1);
77     }
78   }
79
80
81   private static JSchClient createClient(@NotNull Logger logger, String[] args) {
82     if (args.length != 2 && args.length != 4) {
83       System.err.println("Invalid arguments " + Arrays.asList(args));
84       System.exit(1);
85     }
86
87     int i = 0;
88     Integer port = null;
89     //noinspection HardCodedStringLiteral
90     if ("-p".equals(args[i])) {
91       i++;
92       port = Integer.parseInt(args[i++]);
93     }
94     String host = args[i++];
95     String user;
96     int atIndex = host.lastIndexOf('@');
97     if (atIndex == -1) {
98       user = null;
99     }
100     else {
101       user = host.substring(0, atIndex);
102       host = host.substring(atIndex + 1);
103     }
104     String command = args[i];
105     return new JSchClient(host, user, port, command, logger);
106   }
107
108
109   public void run() throws Exception {
110     ChannelExec channel = null;
111     Session session = null;
112     try {
113       JSch.setLogger(myLogger);
114       JSch jsch = new JSch();
115       String privateKeyPath = System.getenv(GitSSHHandler.TEAMCITY_PRIVATE_KEY_PATH);
116       if (privateKeyPath != null) {
117         jsch.addIdentity(privateKeyPath, System.getenv(GitSSHHandler.TEAMCITY_PASSPHRASE));
118       } else {
119         String userHome = System.getProperty("user.home");
120         if (userHome != null) {
121           File homeDir = new File(userHome);
122           File ssh = new File(homeDir, ".ssh");
123           File rsa = new File(ssh, "id_rsa");
124           if (rsa.isFile()) {
125             jsch.addIdentity(rsa.getAbsolutePath());
126           }
127           File dsa = new File(ssh, "id_dsa");
128           if (dsa.isFile()) {
129             jsch.addIdentity(dsa.getAbsolutePath());
130           }
131         }
132       }
133       session = jsch.getSession(myUsername, myHost, myPort != null ? myPort : 22);
134
135       String teamCityVersion = System.getenv(GitSSHHandler.TEAMCITY_VERSION);
136       if (teamCityVersion != null) {
137         session.setClientVersion(GitUtils.getSshClientVersion(session.getClientVersion(), teamCityVersion));
138       }
139
140       if (Boolean.parseBoolean(System.getenv(GitSSHHandler.SSH_IGNORE_KNOWN_HOSTS_ENV))) {
141         session.setConfig("StrictHostKeyChecking", "no");
142       } else {
143         String userHome = System.getProperty("user.home");
144         if (userHome != null) {
145           File homeDir = new File(userHome);
146           File ssh = new File(homeDir, ".ssh");
147           File knownHosts = new File(ssh, "known_hosts");
148           if (knownHosts.isFile()) {
149             try {
150               jsch.setKnownHosts(knownHosts.getAbsolutePath());
151             } catch (Exception e) {
152               myLogger.log(Logger.WARN, "Failed to configure known hosts: '" + e.toString() + "'");
153             }
154           }
155         }
156       }
157
158       String authMethods = System.getenv(GitSSHHandler.TEAMCITY_SSH_PREFERRED_AUTH_METHODS);
159       if (authMethods != null && authMethods.length() > 0)
160         session.setConfig("PreferredAuthentications", authMethods);
161
162       EmptySecurityCallbackHandler.install();
163
164       session.connect();
165
166       channel = (ChannelExec) session.openChannel("exec");
167       channel.setPty(false);
168       channel.setCommand(myCommand);
169       channel.setInputStream(System.in);
170       channel.setErrStream(System.err);
171       InputStream input = channel.getInputStream();
172       Integer timeoutSeconds = getTimeoutSeconds();
173       if (timeoutSeconds != null) {
174         channel.connect(timeoutSeconds * 1000);
175       } else {
176         channel.connect();
177       }
178
179
180       if (!channel.isConnected()) {
181         throw new IOException("Connection failed");
182       }
183
184       Copy copyThread = new Copy(channel, input);
185       if (timeoutSeconds != null) {
186         new Timer(copyThread, timeoutSeconds * 1000).start();
187       }
188       copyThread.start();
189       copyThread.join();
190       copyThread.rethrowError();
191     } finally {
192       if (channel != null)
193         channel.disconnect();
194       if (session != null)
195         session.disconnect();
196     }
197   }
198
199
200   @Nullable
201   private Integer getTimeoutSeconds() {
202     String timeout = System.getenv(GitSSHHandler.TEAMCITY_SSH_IDLE_TIMEOUT_SECONDS);
203     if (timeout == null)
204       return null;
205     try {
206       return Integer.parseInt(timeout);
207     } catch (NumberFormatException e) {
208       myLogger.log(Logger.WARN, "Failed to parse idle timeout: '" + timeout + "'");
209       return null;
210     }
211   }
212
213
214   private class Timer extends Thread {
215     private final long myThresholdNanos;
216     private volatile Copy myCopyThread;
217     Timer(@NotNull Copy copyThread, long timeoutSeconds) {
218       myCopyThread = copyThread;
219       myThresholdNanos = TimeUnit.SECONDS.toNanos(timeoutSeconds);
220       setDaemon(true);
221       setName("Timer");
222     }
223
224     @Override
225     public void run() {
226       boolean logged = false;
227       long sleepInterval = Math.min(TimeUnit.SECONDS.toMillis(10), TimeUnit.NANOSECONDS.toMillis(myThresholdNanos));
228       //noinspection InfiniteLoopStatement: it is a daemon thread and doesn't prevent process from termination
229       while (true) {
230         if (System.nanoTime() - myCopyThread.getTimestamp() > myThresholdNanos) {
231           if (!logged) {
232             myLogger.log(Logger.ERROR, String.format("Timeout error: no activity for %s seconds", TimeUnit.NANOSECONDS.toSeconds(myThresholdNanos)));
233             logged = true;
234           }
235           myCopyThread.interrupt();
236         } else {
237           try {
238             Thread.sleep(sleepInterval);
239           } catch (Exception e) {
240             //ignore
241           }
242         }
243       }
244     }
245   }
246
247
248   private class Copy extends Thread {
249     private final ChannelExec myChannel;
250     private final InputStream myInput;
251     private final AtomicLong myTimestamp = new AtomicLong(System.nanoTime());
252     private volatile Exception myError;
253     Copy(@NotNull ChannelExec channel, @NotNull InputStream input) {
254       myChannel = channel;
255       myInput = input;
256       setName("Copy");
257     }
258
259     @Override
260     public void run() {
261       byte[] buffer = new byte[BUF_SIZE];
262       int count;
263       try {
264         while (myChannel.isConnected() && !myChannel.isClosed() && (count = myInput.read(buffer)) != -1) {
265           System.out.write(buffer, 0, count);
266           myTimestamp.set(System.nanoTime());
267           if (System.out.checkError()) {
268             myLogger.log(Logger.ERROR, "Error while writing to stdout");
269             throw new IOException("Error while writing to stdout");
270           }
271         }
272       } catch (Exception e) {
273         myError = e;
274       }
275     }
276
277     long getTimestamp() {
278       return myTimestamp.get();
279     }
280
281     void rethrowError() throws Exception {
282       if (myError != null)
283         throw myError;
284     }
285   }
286
287
288   private static class StdErrLogger implements Logger {
289     private final SimpleDateFormat myDateFormat = new SimpleDateFormat("[HH:mm:ss.SSS]");
290     @Override
291     public boolean isEnabled(final int level) {
292       return true;
293     }
294
295     @Override
296     public void log(final int level, final String message) {
297       System.err.print(getTimestamp());
298       System.err.print(" ");
299       System.err.print(getLevel(level));
300       System.err.print(" ");
301       System.err.println(message);
302     }
303
304     @NotNull
305     private String getTimestamp() {
306       synchronized (myDateFormat) {
307         return myDateFormat.format(new Date());
308       }
309     }
310   }
311
312
313   private static class InMemoryLogger implements Logger {
314     private final int myMinLogLevel;
315     private final List<LogEntry> myLogEntries;
316     InMemoryLogger(int minLogLevel) {
317       myMinLogLevel = minLogLevel;
318       myLogEntries = new ArrayList<LogEntry>();
319     }
320
321     @Override
322     public boolean isEnabled(final int level) {
323       return level >= myMinLogLevel;
324     }
325
326     @Override
327     public void log(final int level, final String message) {
328       if (isEnabled(level)) {
329         synchronized (myLogEntries) {
330           myLogEntries.add(new LogEntry(System.currentTimeMillis(), level, message));
331         }
332       }
333     }
334
335     void printLog() {
336       SimpleDateFormat dateFormat = new SimpleDateFormat("[HH:mm:ss.SSS]");
337       synchronized (myLogEntries) {
338         for (LogEntry entry : myLogEntries) {
339           System.err.print(dateFormat.format(new Date(entry.myTimestamp)));
340           System.err.print(" ");
341           System.err.print(getLevel(entry.myLogLevel));
342           System.err.print(" ");
343           System.err.println(entry.myMessage);
344         }
345       }
346     }
347
348     private static class LogEntry {
349       private final long myTimestamp;
350       private final int myLogLevel;
351       private final String myMessage;
352       LogEntry(long timestamp, int logLevel, @NotNull String message) {
353         myTimestamp = timestamp;
354         myLogLevel = logLevel;
355         myMessage = message;
356       }
357     }
358   }
359
360
361   @NotNull
362   private static String getLevel(int level) {
363     switch (level) {
364       case Logger.DEBUG:
365         return "DEBUG";
366       case Logger.INFO:
367         return "INFO";
368       case Logger.WARN:
369         return "WARN";
370       case Logger.ERROR:
371         return "ERROR";
372       case Logger.FATAL:
373         return "FATAL";
374       default:
375         return "UNKNOWN";
376     }
377   }
378
379
380   // Doesn't provide any credentials, used instead the default handler from jdk
381   // which reads credentials them from stdin.
382   public static class EmptySecurityCallbackHandler implements CallbackHandler {
383     @Override
384     public void handle(final Callback[] callbacks) throws UnsupportedCallbackException {
385       if (callbacks.length > 0) {
386         throw new UnsupportedCallbackException(callbacks[0], "Unsupported callback");
387       }
388     }
389
390     static void install() {
391       Security.setProperty("auth.login.defaultCallbackHandler", EmptySecurityCallbackHandler.class.getName());
392     }
393   }
394 }