diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..d4edb79 --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ + +.PHONY: test + +test: + python3 -m pytest test/ + diff --git a/test/test_io.py b/test/test_io.py new file mode 100644 index 0000000..1382e94 --- /dev/null +++ b/test/test_io.py @@ -0,0 +1,41 @@ +from io import StringIO +from util.io import readvalue + +def test_io_readvalue_int(): + fin = StringIO("10\n") + fout = StringIO() + + result = readvalue("Enter an Integer ", int, fin=fin, fout=fout) + + assert result == 10 + +def test_io_readvalue_msg(): + fin = StringIO("10\n") + fout = StringIO() + readvalue("Enter an Integer ", int, fin=fin, fout=fout) + fout.seek(0, 0) + + result = fout.read() + + assert result == "Enter an Integer " + +def test_io_readvalue_float(): + fin = StringIO("10.1\n") + fout = StringIO() + + result = readvalue("Enter a Float ", float, fin=fin, fout=fout) + + assert result == 10.1 + +def test_io_readvalue_int_fail(): + fin = StringIO("10.1\n10\n") + fout = StringIO() + readvalue("Enter an Integer ", int, fin=fin, fout=fout) + fout.seek(0, 0) + + result = fout.read() + + assert result == ("Enter an Integer " + r"invalid literal for int() with base 10: '10.1'" + "\nEnter an Integer ") + diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/io.py b/util/io.py new file mode 100644 index 0000000..9722be9 --- /dev/null +++ b/util/io.py @@ -0,0 +1,32 @@ +#!/usr/bin/python3 + +""" +Input/Output utilities. +""" + +import sys + + +def readvalue(msg, type_, emsg=None, fin=sys.stdin, fout=sys.stdout): + """ + Basically uses ``input(msg)`` until it has been able to convert + the input to ``type_``. Prints the error message ``emsg`` + if supplied, else the type error message. + + Reads from ``fin`` and writes to ``fout``. They default to ``sys.stdin`` + and ``sys.stdout``. + """ + x = None + while(x is None): + print(msg, file=fout, flush=True, end="") + x_raw = fin.readline() + try: + # Remove the trailing \n. + x = type_(x_raw[:-1]) + except ValueError as e: + if(emsg): + print(emsg, file=fout, flush=True) + else: + print(str(e), file=fout, flush=True) + + return x