# Tensorboard Image Extractor Copyright (C) 2021 Otthorn
# License: GNU GPL v3 or later

import argparse
import io
import sys

import tensorboard.compat.proto.event_pb2 as event_pb2
from PIL import Image
from tqdm import tqdm


def read_event(data):
    """
    Read one event record from the datastream.

    A record is an 8-byte little-endian payload length, 4 skipped bytes
    (the record format's length checksum), the payload, and 4 more skipped
    bytes (the payload checksum). Checksums are not verified.

    Returns ``(remaining, event_str)``: the data with this record removed,
    and the serialized event payload. Slicing a ``memoryview`` yields
    zero-copy views; slicing ``bytes`` yields copies.
    """
    length = int.from_bytes(data[:8], "little")
    event_str = data[12 : 12 + length]
    # Skip the payload and its trailing 4-byte checksum.
    data = data[12 + length + 4 :]
    return data, event_str


def read_file(input_path):
    """
    Read a whole file and return its contents as bytes.

    Prints an error to stderr and exits with status 1 if the path does not
    exist or is a directory.
    """
    try:
        with open(input_path, "rb") as f:
            return f.read()
    except (FileNotFoundError, IsADirectoryError):
        print(f"Input file {input_path} is not a valid path.", file=sys.stderr)
        sys.exit(1)


def decode_image(img):
    """Decode an image summary's encoded bytes into a PIL image."""
    return Image.open(io.BytesIO(img.encoded_image_string))


def _sanitize_tag(tag):
    """Make a summary tag safe to embed in a file name."""
    return tag.replace("/", "_").replace(" ", "_")


def _save_gif(img_list, args):
    """Write the collected frames to ``args.output`` as one animated image."""
    if not img_list:
        print("No matching images found, nothing to save.")
        return
    print("[DEBUG] saving animated gif")
    first, *rest = img_list
    save_kwargs = {
        "save_all": True,
        # Only the frames after the first: the first one is the image
        # ``save`` is called on, so appending it again would duplicate it.
        "append_images": rest,
        "duration": args.second_per_frame,
    }
    # ``--do-not-loop`` is a store_false flag, so args.do_not_loop is True
    # unless the flag was given, i.e. True means "loop". Pillow loops forever
    # with loop=0 and does not loop when the option is omitted.
    if args.do_not_loop:
        save_kwargs["loop"] = 0
    first.save(args.output, **save_kwargs)


def main(args):
    """
    Extract image summaries from a tensorboard event file.

    Without ``args.gif`` each image is saved as ``img_<tag>_<step>.png`` in
    the current directory. With ``args.gif`` the images are collected in file
    order and written to ``args.output`` as an animated image. When
    ``args.tag`` is set, only images with exactly that tag are processed.
    """
    # A memoryview lets read_event drop each record without copying the rest
    # of the file, avoiding quadratic work on large event files.
    data = memoryview(read_file(args.input))
    img_list = []

    with tqdm(total=len(data)) as pbar:
        while data:
            remaining, event_str = read_event(data)
            pbar.update(len(data) - len(remaining))
            data = remaining

            event = event_pb2.Event()
            event.ParseFromString(bytes(event_str))
            if not event.HasField("summary"):
                continue

            for value in event.summary.value:
                if not value.HasField("image"):
                    continue
                tag = value.tag
                # If args.tag is None process everything, else process
                # only the given tag (compared before sanitizing).
                if args.tag is not None and args.tag != tag:
                    continue
                img_d = decode_image(value.image)
                tag = _sanitize_tag(tag)
                if args.gif:
                    # Keep the frame for the animated output.
                    img_list.append(img_d)
                else:
                    print(f"Saving as: img_{tag}_{event.step}.png")
                    img_d.save(f"img_{tag}_{event.step}.png", format="png")

    if args.gif:
        _save_gif(img_list, args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tensorboard image dumper and gif creator"
    )
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        help="Input file, must be a tensorboard event file",
        required=True,
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        help="Output file for the gif, must have a .gif extension",
    )
    parser.add_argument(
        "--gif",
        default=False,
        action="store_true",
        help="Save the output as an animated gif",
    )
    parser.add_argument(
        "--do-not-loop",
        default=True,
        action="store_false",
        help="Prevent the gif from looping",
    )
    parser.add_argument(
        "--second-per-frame",
        "-spf",
        type=int,
        default=60,
        help="Time between each frame (in milliseconds)",
    )
    parser.add_argument(
        "--tag",
        "-t",
        type=str,
        help="Select a single tag for the output",
    )
    args = parser.parse_args()
    if args.gif and not args.output:
        parser.error("--output is required when --gif is set")
    main(args)