diff --git a/src/ui/widgets/graph.py b/src/ui/widgets/graph.py index 5e51699..db32f07 100644 --- a/src/ui/widgets/graph.py +++ b/src/ui/widgets/graph.py @@ -1,51 +1,161 @@ -import matplotlib - -matplotlib.use("QtAgg") - -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar -from matplotlib.figure import Figure -from PyQt6 import QtCore, QtWidgets - - -class graph(FigureCanvas): - def __init__(self, parent=None, width=5, height=4, dpi=100): - fig = Figure(figsize=(width, height), dpi=dpi) - self.axes = fig.add_subplot(111) - self.canvas = FigureCanvas(fig) - super().__init__(fig) - - -# display graph with multiple lines -class GraphWidget(QtWidgets.QWidget): - def __init__(self, parent=None, data=None, legend_labels=None): - super().__init__() - self.graph = graph(self, width=7, height=4.5, dpi=100) - self.graph.axes.plot(data["x"], data["y"], "r", data["x"], data["y2"], "b") - self.toolbar = NavigationToolbar(self.graph, self) - # set legend - self.graph.axes.legend(legend_labels, loc="upper left") - # set x-axis text to be slanted - self.graph.axes.set_xticklabels(data["x"], rotation=45, ha="right") - # set the layout - layout = QtWidgets.QVBoxLayout() - layout.addWidget(self.toolbar) - layout.addWidget(self.graph) - self.setLayout(layout) - - # def set_area(self, top=1, bottom= 0.07, left=0.1, right=0.994,hspace=0.2,wspace=0.2): - # self.graph. - - -if __name__ == "__main__": - import sys - - app = QtWidgets.QApplication(sys.argv) - data = { - "x": ["A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10"], - "y": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - "y2": [10, 9, 8, 7, 6, 5, 4, 3, 2, 1], - } - widget = GraphWidget(data=data, legend_labels=["+", "-"]) - widget.show() - sys.exit(app.exec()) +import random +from typing import Union + +import pyqtgraph as pg +from PyQt6 import QtWidgets + + +def mergedicts(d1, d2): + res = {} + d1_data = list(d1.items()) + d2_data = list(d2.items()) + for i in range(len(d1)): + # _data is a list of tuples + d1_data_slice = d1_data[i] + d2_data_slice = d2_data[i] + # convert the tuples to dicts + d1_dict = dict([d1_data_slice]) + d2_dict = dict([d2_data_slice]) + # merge the dicts + res.update(d1_dict) + res.update(d2_dict) + return res + + +class DataGraph(QtWidgets.QWidget): + def __init__( + self, + title, + data=Union[dict[list, list] | dict[list[dict[str, list]]]], + generateMissing=False, + label=None, + ): + super().__init__() + + lst = [] + if generateMissing: + x_data = data["x"] + y_data = data["y"] + if not isinstance(y_data, list): + for key in y_data: + data = {"x": x_data, "y": y_data[key]} + data = self.generateMissingSemesters(data) + data["y-label"] = key + lst.append(data) + else: + data = self.generateMissingSemesters(data) + lst.append(data) + + else: + x_data = data["x"] + y_data = data["y"] + if not isinstance(y_data, list): + for key in y_data: + data = {"x": x_data, "y": y_data[key]} + data["y-label"] = key + lst.append(data) + else: + lst.append(data) + x_data = lst[0]["x"] # + xdict = dict(enumerate(x_data)) + stringaxis_x = pg.AxisItem(orientation="bottom") + stringaxis_x.setTicks([xdict.items()]) + graph = pg.PlotWidget(axisItems={"bottom": stringaxis_x}) + graph.addLegend() + + colors = ["b", "r", "c", "m", "y", "k", "w"] + symbols = [ + "o", + "s", + "t", + "d", + "+", + "t1", + "t2", + "t3", + "p", + "h", + "star", + "x", + "arrow_up", + "arrow_down", + "arrow_left", + "arrow_right", + ] + color_index = 0 + index = 0 + + for data in lst: + symbol = symbols[random.randint(0, len(symbols) - 1)] + if color_index >= len(colors): + color_index = 0 + # iterate over the list, use y-data and y-label to plot the graph + y_data = data["y"] + label = data["y-label"] if "y-label" in data else label + + pen = pg.mkPen(color=colors[color_index], width=2) + if isinstance(y_data, list): + graph.plot( + list(xdict.keys()), y_data, pen=pen, symbol=symbol, name=label + ) + color_index += 1 + index += 1 + else: + pass + graph.setBackground("#d3d3d3") + graph.setTitle(title) + layout = QtWidgets.QVBoxLayout() + layout.addWidget(graph) + self.setLayout(layout) + + def generateMissingSemesters(self, data: dict[list]): + # join the data into a single dict with x values as key and y values as value + tmp_data = dict(zip(data["x"], data["y"])) + # split into dicts based on SoSe and WiSe + SoSe_data = {k: v for k, v in tmp_data.items() if "SoSe" in k} + WiSe_data = {k: v for k, v in tmp_data.items() if "WiSe" in k} + SoSe_years = [int(sose.split("SoSe")[1]) for sose in SoSe_data] + WiSe_years = [int(wise.split("WiSe")[1].split("/")[0]) for wise in WiSe_data] + years = SoSe_years + WiSe_years + years = [ + year for year in range(min(list(set(years))), max(list(set(years))) + 1) + ] + years.sort() + for year in years: + SoSe_year = f"SoSe{year}" + WiSe_year = f"WiSe{year}/{year+1}" + if SoSe_year not in SoSe_data.keys(): + SoSe_data[SoSe_year] = 0 + if WiSe_year not in WiSe_data.keys(): + WiSe_data[WiSe_year] = 0 + + # sort WiSe_data to have same order as SoSe_data + WiSe_data = dict(sorted(WiSe_data.items(), key=lambda x: x[0])) + SoSe_data = dict(sorted(SoSe_data.items(), key=lambda x: x[0])) + data = mergedicts(SoSe_data, WiSe_data) + # split the data back into x and y + data = {"x": list(data.keys()), "y": list(data.values())} + return data + + +if __name__ == "__main__": + import sys + + app = QtWidgets.QApplication(sys.argv) + data_1 = { + "x": ["SoSe 10", "WiSe 10/11", "SoSe 11", "SoSe 14"], + "y": { + "Added": [1, 2, 3, 4], + "Deleted": [4, 3, 2, 1], + }, + } + data_2 = { + "x": ["SoSe 10"], + "y": [2], + } + graph_data = {"x": ["SoSe 24"], "y": [1]} + widget = DataGraph( + "ELSA Apparate pro Semester", data_2, True, "Anzahl der Apparate" + ) + widget.show() + sys.exit(app.exec())