diff --git a/src/ui/widgets/graph.py b/src/ui/widgets/graph.py index 013916b..4907652 100644 --- a/src/ui/widgets/graph.py +++ b/src/ui/widgets/graph.py @@ -10,6 +10,7 @@ from PySide6.QtGui import QPainter, QPen, QColor from PySide6.QtCharts import QChart, QChartView, QLineSeries, QValueAxis, QCategoryAxis from src import LOG_DIR +from src.backend.semester import Semester log = loguru.logger log.remove() @@ -37,7 +38,7 @@ class DataQtGraph(QtWidgets.QWidget): def __init__( self, title: str, - data: dict, + data: dict[str, Union[list[str], dict[str, list[int]]]], generateMissing: bool, y_label: str, x_rotation: int = 90, @@ -53,16 +54,46 @@ class DataQtGraph(QtWidgets.QWidget): lst = [] if generateMissing: + s_start = Semester.from_string(data["x"][0]) + s_end = Semester.from_string(data["x"][-1]) + # generate all semesters from start to end + missing_semesters = Semester.generate_missing(s_start, s_end) x_data = data["x"] y_data = data["y"] if not isinstance(y_data, list): for key in y_data: - data = {"x": x_data, "y": y_data[key]} - data = self.generateMissingSemesters(data) + data = {"x": [], "y": []} data["y-label"] = key + for semester in missing_semesters: + if semester not in x_data: + data["x"].append(semester) + data["y"].append(0) + data["x"].append(semester) + # get the index of the semester in x_data + if semester in x_data: + index = x_data.index(semester) + print("index:", index) + print(key, y_data[key]) + data["y"].append(y_data[key][index]) lst.append(data) + # for key in y_data: + # data = {"x": x_data, "y": y_data[key]} + # data = self.generateMissingSemesters(data) + # data["y-label"] = key else: - data = self.generateMissingSemesters(data) + # data = self.generateMissingSemesters(data) + data = {"x": [], "y": []} + for semester in missing_semesters: + # if semester not in x_data, set y to 0 + if semester not in x_data: + data["x"].append(semester) + data["y"].append(0) + data["x"].append(semester) + # get the index of the semester in x_data + if semester in x_data: + index = x_data.index(semester) + data["y"].append(y_data[index]) + data["y-label"] = y_label lst.append(data) else: @@ -76,9 +107,6 @@ class DataQtGraph(QtWidgets.QWidget): else: lst.append(data) x_data = lst[0]["x"] # - xdict = dict(enumerate(x_data)) - - print("xdict:", xdict) self.chart.createDefaultAxes() for entry in lst: @@ -88,6 +116,16 @@ class DataQtGraph(QtWidgets.QWidget): # entryseries.append(entry["x"].index(x_val), y_val) entryseries.setName(entry["y-label"] if "y-label" in entry else y_label) + entryseries.setPen( + QPen( + QColor( + random.randint(0, 255), + random.randint(0, 255), + random.randint(0, 255), + ), + 2, + ) + ) self.chart.addSeries(entryseries) @@ -106,7 +144,7 @@ class DataQtGraph(QtWidgets.QWidget): # str() # ) self.chart.legend().setVisible(True) - self.chart.legend().setAlignment(QtCore.Qt.AlignmentFlag.AlignBottom) + self.chart.legend().setAlignment(QtCore.Qt.AlignmentFlag.AlignTop) # set legend labels self.chart.setAxisY(QValueAxis(self.chart), entryseries)