#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2013 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

import re
import sys

from winswitch.util.simple_logger import Logger
#Useful for py2app/py2exe silent crashes (import fails and it just exits)
DEBUG_IMPORTS = "--debug-imports" in sys.argv
logger = Logger("ssh_command_monitor")
def debug_import(msg):
	""" Logs 'msg' (the name of the next import) only when the program was started with --debug-imports """
	if not DEBUG_IMPORTS:
		return
	logger.slog(msg)

debug_import("consts")
from winswitch.consts import NOTIFY_ERROR, NOTIFY_AUTH_ERROR
debug_import("globals")
from winswitch.globals import WIN32
debug_import("net_util")
from winswitch.net.net_util import get_port_mapper
debug_import("common")
from winswitch.util.common import escape_newlines, hash_text
debug_import("process_util")
from winswitch.util.process_util import LineProcessProtocolWrapper
debug_import("tunnel_monitor")
from winswitch.util.tunnel_monitor import QUERY
debug_import("config")
from winswitch.util.config import modify_server_config


PORTFORWARD_HEALTH_CHECK = 30			#how long between tests on port forwards
PORTMONITOR_HEALTH_CHECK = 30			#how long between tests on port monitor

DEBUG = False
TUNNEL_LOCALHOST = False				#force localhost to use an ssh tunnel (only useful for testing)
ALWAYS_FORWARD_X = True					#the ssh connection used for connecting to the server does not need X-forwarding
										#but since we use connection sharing... it is best to have it (or we will need a new connection for display forwarding)
#useful for testing
IGNORE_REMOTE_ICONS = False
IGNORE_LOCAL_ICONS = False


#ssh client messages we look for (they are used with re.match, so they must match from the start of the line).
#The key type is matched as any single word (RSA, DSA, ECDSA, ED25519, ...) rather than only "?SA".
WARN_RE = re.compile(r"No \S+ host key is known for (.*) and you have requested strict checking\.")
#group(1) is the fingerprint: either the plain hex form ("12:ab:...") or a prefixed form ("SHA256:<base64>", "MD5:<hex>"),
#the trailing full stop is not part of it.
KEY_RE = re.compile(r"^\S+ key fingerprint is ([A-Za-z0-9:+/=]*)")
HOST_RE = re.compile(r"^The authenticity of host '(.*)' can't be established.")

#lines printed by the ssh client when the host key does not match the one recorded in known_hosts.
#Literal punctuation is escaped (parentheses, full stops) and the variable parts (key type, known_hosts path)
#are wildcards, so that the patterns match for any user and any key type.
KEY_MISMATCH_RES = [re.compile(r"@\s*WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!.*"),
					re.compile(r"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"),
					re.compile(r"IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!"),
					re.compile(r"Someone could be eavesdropping on you right now \(man-in-the-middle attack\)!"),
					re.compile(r"It is also possible that (the .*|a) host key has just been changed\."),
					re.compile(r"The fingerprint for the .* key sent by the remote host is"),
					re.compile(r"Please contact your system administrator\."),
					re.compile(r"Add correct host key in .* to get rid of this message\."),
					re.compile(r"Offending .*key in .*"),
					re.compile(r".* host key for .* has changed and you have requested strict checking\."),
					re.compile(r"Host key verification failed\.")
					]

#module-level port mapper obtained once from net_util.get_port_mapper() at import time;
#NOTE(review): not referenced in this part of the file - presumably used by subclasses or other modules, confirm
port_mapper = get_port_mapper()

class SSHCommandMonitor(LineProcessProtocolWrapper):
	"""
	Abstract superclass for interacting with the input and output of an SSH process.

	It watches the output of the ssh client for password prompts, host key
	confirmation prompts, warnings, authentication failures, host key mismatch
	messages and read timeouts, and reacts to them: answering the prompt,
	asking the user via the 'ask' callback, informing the user via the 'notify'
	callback, or stopping the process.
	Every other line is passed on to do_handle(line) (the default implementation
	writes the line back via respond() after a 5 second delay).
	"""

	def __init__(self, server, name, server_timeout, notify=None, ask=None):
		"""
		server: the server we connect to (this class reads its username, host, port, password and ID)
		name: name of this monitor, used in log messages
		server_timeout: base read timeout, see respond()
		notify: optional callback used for showing notifications to the user
		ask: optional callback used for asking the user questions (password, host key)
		"""
		LineProcessProtocolWrapper.__init__(self)
		self.server = server
		self.name = name
		self.notify = notify
		self.ask = ask
		self.server_timeout = server_timeout
		self.win32_detached = False				#no need to do the win32 process hackery for plink
		self.log_full_command = not WIN32		#would end up logging the SSH password!

		self.password_sent = False				#have we sent the password yet? (if so, a new password request means it's incorrect)
		self.connected = False					#set to True when we know we have connected to the other end (received valid data)
		self.usePTY = not WIN32					#we'll need a pty to interact with ssh on *nix
		self.DEBUG = DEBUG
		self.warned = False
		self.merge_stderr = True

	def _notify(self, title, message, notification_type):
		"""
		Forwards a notification to the 'notify' callback.
		'notify' is optional (it defaults to None), so when it is not set the notification is only logged.
		"""
		if self.notify:
			self.notify(title, message, notification_type=notification_type)
		else:
			self.slog("no notification handler set, dropping notification '%s'" % title, message)

	def exit(self):
		self.connected = False
		LineProcessProtocolWrapper.exit(self)

	def exec_error(self):
		"""
		Logs that the ssh command failed to start and notifies the user.
		NOTE(review): presumably invoked by the superclass when spawning the process fails - confirm.
		"""
		self.serror("failed to start '%s' with command '%s'" % (self.name, self.command))
		self._notify("Failed to launch", "Tunnel error", notification_type=NOTIFY_ERROR)

	def is_ssh_warning(self, data):
		""" Returns a true value for ssh client warnings ("Warning: ..." or "ControlSocket /...") """
		return data and (data.startswith("Warning: ") or data.startswith("ControlSocket /"))

	def is_ssh_error(self, data):
		""" Returns a true value when ssh reports that it has no RSA host key for the server """
		return data and data.startswith("No RSA host key is known for")

	def is_password_request(self, data):
		""" Returns True for a password prompt: "Password: " or "<username>@...'s password: " """
		if data=="Password: ":
			return	True
		if data and data.startswith("%s@" % self.server.username) and data.endswith("'s password: "):
			return	True
		return	False

	def confirm_password(self, data):
		"""
		Answers an SSH password prompt.
		The first prompt is answered with the stored server password (when there is one),
		any further prompt means that password was rejected, so the user is asked for a new one.
		"""
		text = "Please enter your account password for:\n%s@%s:%s" % (self.server.username, self.server.host, self.server.port)
		if self.server.password:		#also copes with a None password
			if not self.password_sent:
				self.slog("received SSH password request, sending back existing password", data)
				self.respond(self.server.password, False)		#logit=False: never log the password
				self.password_sent = True
				return
			self.sdebug("received another password request, not using existing password", data)
			text = "Password failed, try again.\n%s" % text
		if self.ask:
			def set_new_password(password):
				self.slog(None, hash_text(password))
				self.server.set_password(password)
				self.respond(password, False)
				self.server.touch()
				modify_server_config(self.server, ["encrypted_password"])
			self.ask("Confirm Password", text, lambda : self.stop("User cancelled"), set_new_password, True,
					UUID="%s-SSH-PASS" % self.server.ID)
		else:
			self.serror("UI question handler not set! cannot request the user to enter the password!", data)
			self.respond("")


	def is_authentication_failed(self, data):
		""" Returns True when ssh reports that the login was refused """
		if data=="Permission denied, please try again.":
			return	True
		if data and data.startswith("Received disconnect from ") and data.endswith("Too many authentication failures for %s" % self.server.username):
			return	True
		return	False

	def is_confirm_key_check(self, data, with_response):
		"""
		Detects the host key confirmation prompt.
		with_response: when True, also accept the prompt followed by text (ie: once the answer has been echoed),
		otherwise only the bare prompt matches.
		"""
		if not data:
			return	False
		q = "Are you sure you want to continue connecting (yes/no)? "
		if with_response:
			return	data.startswith(q)
		return	data==q

	def is_key_check(self, data):
		""" Returns a true value for the host key related lines matched by KEY_RE, HOST_RE or WARN_RE """
		return	data and (KEY_RE.match(data) or HOST_RE.match(data) or WARN_RE.match(data))

	def confirm_key(self):
		"""
		This fires when we detect self.is_confirm_key_check()
		We must send back "yes" or "no" to the ssh client process.
		The two lines preceding the prompt must be the host line (HOST_RE) followed by the key line (KEY_RE),
		otherwise we answer "no" without asking the user.
		"""
		l = len(self.previous_lines)
		if l<2:
			self.respond("no")
			return
		key_line = self.previous_lines[l-1]
		host_line = self.previous_lines[l-2]
		self.sdebug("rsa_line=%s, host_line=%s" % (key_line, host_line))
		key_match = KEY_RE.match(key_line)
		host_match = HOST_RE.match(host_line)
		if key_match is None or host_match is None:
			self.respond("no")
			return
		key = key_match.group(1)
		host = host_match.group(1)
		self.slog("ask for confirmation of key=%s and host=%s" % (key, host))
		if self.ask:
			text = "%s\n%s\nAre you sure you want to continue connecting?" % (host_line, key_line)
			self.ask("Confirm Host Key", text, self._hostkey_unconfirmed, lambda : self.respond("yes"),
					UUID="%s-HOSTKEY" % self.server.ID)
		else:
			self.serror("UI question handler not set! cannot request confirmation from the user!")
			self.respond("no")

	def _hostkey_unconfirmed(self):
		""" The user rejected the host key: answer "no" and stop """
		self.respond("no")
		self.stop("User rejected host key")


	def dataReceived(self, data):
		"""
		Override so we can detect the password request line and key confirmation line.
		These do not have a carriage return - so cannot be detected via handle(line).
		"""
		LineProcessProtocolWrapper.dataReceived(self, data)
		if self.DEBUG:
			self.sdebug("buffer=%s" % escape_newlines(self.buffer), escape_newlines(data))
		if self.is_password_request(self.buffer):
			self.confirm_password(data)
		if self.is_confirm_key_check(self.buffer, False):
			self.confirm_key()

	def handle(self, line):
		"""
		We deal with all the SSH related messages here,
		any messages not handled here will be passed on to do_handle(line)
		"""
		if self.DEBUG:
			self.sdebug(None, line)
		if self.is_ssh_warning(line):
			self.slog("received SSH warning", line)
			return
		if line.startswith('Xlib:  extension "RANDR" missing on display'):
			return				#ignore warning
		if self.is_password_request(line):
			return				#this is the password request handled above being echoed (we only process it here after we respond and send CR) - just ignore it
		if self.is_key_check(line) or self.is_confirm_key_check(line, True):
			return
		#NOTE(review): is_key_check() above also matches WARN_RE ("No ... host key is known for ... strict checking."),
		#so this branch only sees "No RSA host key is known for" lines that do not match WARN_RE - confirm this ordering is intended.
		if self.is_ssh_error(line):
			self.serror("SSH error", line)
			self.server.invalid_login = True		#record the login failure against this server
			self._notify("SSH Error",
					"Login failed on '%s':\n%s." % (self.server.get_display_name(), line),
					notification_type=NOTIFY_AUTH_ERROR)
			self.stop("SSH error")
			return
		if self.is_authentication_failed(line):
			self.serror("authentication failed", line)
			self.server.invalid_login = True		#record the login failure against this server
			self._notify("SSH Authentication Failed",
						"Login failed on '%s',\nplease check your username and password." % self.server.get_display_name(),
						notification_type=NOTIFY_AUTH_ERROR)
			self.stop("SSH authentication failed")
			return
		if line.startswith("Killed by signal "):
			self.serror("process: %s" % self.process, line)
			return
		if line.startswith("@@@@@@@@@@") and self.line_count<2:
			self.slog("ssh warning/error message follows", line)
			return
		for mismatch_re in KEY_MISMATCH_RES:
			if mismatch_re.match(line):
				if not self.warned:
					#only notify the user once, the mismatch message spans many lines
					self.serror("Key mismatch error", line)
					self._notify("SSH Host Key Error",
								"The SSH Host key for '%s' does not match the key recorded!\nEither update your SSH known hosts or disable host key checking." % self.server.get_display_name(),
								notification_type=NOTIFY_AUTH_ERROR)
					self.warned = True
				else:
					self.sdebug("Key mismatch error", line)
				return
		if line==QUERY:
			# ignoring our own echo
			return
		last_line = None
		if self.previous_lines:
			last_line = self.previous_lines[len(self.previous_lines)-1]
		if not line and last_line and (self.is_ssh_warning(last_line) or self.is_key_check(last_line) or self.is_confirm_key_check(last_line, True) or self.is_password_request(last_line)):
			# ignoring empty line following a warning or key check
			return
		if line=="Timeout, server not responding.":
			self.stop("Server timeout")
			return
		self.connected = True			#we received a line which is not an ssh message
		self.do_handle(line)


	def do_handle(self, line):
		""" Default handler for non-ssh lines: write the line back via respond() after 5 seconds """
		self.callLater(5, self.respond, line)

	def respond(self, line=QUERY, logit=True):
		"""
		Writes 'line' followed by a newline to the ssh process,
		then schedules check_read_timeout() to verify that we keep receiving data.
		The timeout is server_timeout, but never less than 5, and doubled while fewer than 2 lines have been received.
		logit: when False the line is not shown in debug logs (used for passwords).
		"""
		if self.DEBUG:
			if logit:
				self.sdebug(None, line, logit)
			else:
				self.sdebug(None, "???", logit)
		self.transport.write("%s\n" % line)
		current = self.line_count
		timeout = self.server_timeout
		if timeout<5:
			timeout = 5
		if current<2:
			timeout *= 2		#double timeout for initial lines (doubles the clamped value)
		self.callLater(timeout, self.check_read_timeout, timeout, current)		#ensure we keep receiving

	def check_read_timeout(self, timeout, prev_line_count):
		"""
		We schedule this method to fire to ensure that we are still receiving data from the other end regularly.
		If line_count has not changed since respond() scheduled this check,
		the user is notified and the process is stopped.
		timeout: the delay that was used (only used in the messages)
		prev_line_count: the value of line_count when the check was scheduled
		"""
		if self.terminated or self.stopping:
			return
		if prev_line_count != self.line_count:
			return				#we received data in the meantime
		msg = "(%s as %s)" % (self.server.get_display_name(), self.server.username)
		if prev_line_count==0:
			self._notify("SSH Tunnel Timeout",
					msg+"\nFailed to connect or receive data after waiting %d seconds" % timeout,
					notification_type=NOTIFY_ERROR)
		else:
			self._notify("SSH Tunnel Timeout",
					msg+"\nConnection failed, waited %d seconds, after %d successful exchanges" % (timeout, prev_line_count),
					notification_type=NOTIFY_ERROR)
		if not self.stopping:
			self.stop("read timeout, timeout=%s, command=%s" % (self.server_timeout, self.command))
