summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--garcon.py55
1 files changed, 32 insertions, 23 deletions
diff --git a/garcon.py b/garcon.py
index 040b045..e10f85e 100644
--- a/garcon.py
+++ b/garcon.py
@@ -1,11 +1,12 @@
 from sys import stdin, stdout, stderr
+import asyncio
 import datetime
 import hashlib
 import socket
 import ssl
+import traceback
 import urllib.parse
 import uuid
-import traceback
 
 # based on:
 # https://tildegit.org/solderpunk/gemini-demo-1/src/branch/master/gemini-demo.py
@@ -13,6 +14,7 @@ import traceback
 # TODO ciphers etc
 
 SIZE_LIMIT = 4 * 1024 * 1024 # 4MB seems reasonable
+TIME_LIMIT = 45 # seconds for each request
 
 outf = stdout.buffer
 
@@ -47,25 +49,32 @@ def warcinfo():
 	outf.write(payload)
 	outf.write(b'\r\n\r\n')
 
-def request_raw(host, port, url):
+async def request_raw(host, port, url):
 	assert '\n' not in url
 
-	s = socket.create_connection((host, port))
 	context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
 	context.check_hostname = False
 	context.verify_mode = ssl.CERT_NONE
-	s = context.wrap_socket(s, server_hostname=host)
-
-	s.sendall((url + '\r\n').encode("UTF-8"))
-	peername = s.getpeername()
-	cert = s.getpeercert(True)
-
-	fp = s.makefile("rb")
-	payload = fp.read(SIZE_LIMIT)
-	truncated = fp.read() != b''
-	print(truncated)
-	fp.close()
-	s.close()
+	reader, writer = await asyncio.open_connection(host, port, ssl=context)
+
+	writer.write((url + '\r\n').encode("UTF-8"))
+	peername = writer.transport.get_extra_info('peername')
+	cert     = writer.transport.get_extra_info('ssl_object').getpeercert(True)
+
+	truncated = None
+	payload = bytearray()
+	try:
+		async with asyncio.timeout(TIME_LIMIT):
+			while len(payload) < SIZE_LIMIT:
+				res = await reader.read(SIZE_LIMIT - len(payload))
+				if res == b'': break
+				payload += res
+		if (await reader.read(1)) != b'':
+			truncated = 'length'
+	except TimeoutError:
+		truncated = 'time'
+
+	writer.close()
 
 	# warctools doesn't like WARC/1.1
 	outf.write(b'WARC/1.0\r\n')
@@ -81,28 +90,28 @@ def request_raw(host, port, url):
 	header("WARC-IP-Address", peername[0])
 	header("WARC-Target-URI", url)
 	header("Content-Type", "application/gemini; msgtype=response") # as in mozz-archiver
-	if trunacted:
-		header("WARC-Truncated", "length")
+	if truncated:
+		header("WARC-Truncated", truncated)
 
 	# my extensions
 	header("X-Server-Fingerprint", 'sha256:' + hashlib.sha256(cert).hexdigest())
 
 	outf.write(b'\r\n')
-#outf.write(payload)
+	outf.write(payload)
 	outf.write(b'\r\n\r\n')
 
 	# TODO check for close_notify
 	return payload
 
-def request_url(url):
+async def request_url(url):
 	p = urllib.parse.urlparse(url)
 	assert p.scheme == 'gemini'
-	return request_raw(p.hostname, p.port or 1965, url)
+	return await request_raw(p.hostname, p.port or 1965, url)
 
-def request_url_loop(url):
+async def request_url_loop(url):
 	# i only allow 3 redirects, so detecting loops isn't really necessary
 	for _ in range(3):
-		res = request_url(url)
+		res = await request_url(url)
 		header = res.split(b'\r\n')[0]
 		if 2 + 1 + 1024 < len(header): break
 		if len(header) > 0 and header[0] == ord('3'):
@@ -116,7 +125,7 @@ if __name__ == '__main__':
 	outf.flush()
 	for line in stdin:
 		try:
-			request_url_loop(line.rstrip('\r\n').rstrip('\n'))
+			asyncio.run(request_url_loop(line.rstrip('\r\n').rstrip('\n')))
 			outf.flush()
 		except:
 			print(traceback.format_exc(), file=stderr)