diff --git a/nodes.py b/nodes.py index 0822865..5e4d7e4 100644 --- a/nodes.py +++ b/nodes.py @@ -2663,7 +2663,47 @@ class ReferenceOnlySimple3: out_mask = torch.zeros((1,mask.shape[1],mask.shape[2]), dtype=torch.float32, device="cpu") return (model_reference, {"samples": out_latent, "noise_mask": torch.cat((out_mask,out_mask, mask))}) +class SoundReactive: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "average_level": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 99999, "step": 0.01}), + "low_level": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 99999, "step": 0.01}), + "mid_level": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 99999, "step": 0.01}), + "high_level": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 99999, "step": 0.01}), + "low_range_hz": ("INT", {"default": 150, "min": 0, "max": 9999, "step": 1}), + "mid_range_hz": ("INT", {"default": 2000, "min": 0, "max": 9999, "step": 1}), + "multiplier": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 99999, "step": 0.01}), + "normalize": ("BOOLEAN", {"default": False}), + }, + } + + RETURN_TYPES = ("FLOAT","FLOAT","FLOAT","FLOAT","INT","INT","INT","INT") + RETURN_NAMES =("average_level", "low_level", "mid_level", "high_level", "average_level_int", "low_level_int", "mid_level_int", "high_level_int") + FUNCTION = "react" + CATEGORY = "KJNodes/experimental" + + def react(self, low_level, mid_level, high_level, low_range_hz, mid_range_hz, average_level, multiplier, normalize): + low_level *= multiplier + mid_level *= multiplier + high_level *= multiplier + average_level = average_level * multiplier + + if normalize: + low_level = low_level / 255 + mid_level = mid_level / 255 + high_level = high_level / 255 + average_level = average_level / 255 + + low_level_int = int(low_level) + mid_level_int = int(mid_level) + high_level_int = int(high_level) + average_level_int = int(average_level) + + + return (average_level, low_level, mid_level, high_level, average_level_int, low_level_int, mid_level_int, high_level_int) + NODE_CLASS_MAPPINGS = { "INTConstant": INTConstant, "FloatConstant": FloatConstant, @@ -2713,7 +2753,8 @@ NODE_CLASS_MAPPINGS = { "FlipSigmasAdjusted": FlipSigmasAdjusted, "InjectNoiseToLatent": InjectNoiseToLatent, "AddLabel": AddLabel, - "ReferenceOnlySimple3": ReferenceOnlySimple3 + "ReferenceOnlySimple3": ReferenceOnlySimple3, + "SoundReactive": SoundReactive } NODE_DISPLAY_NAME_MAPPINGS = { "INTConstant": "INT Constant", @@ -2763,6 +2804,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "FlipSigmasAdjusted": "FlipSigmasAdjusted", "InjectNoiseToLatent": "InjectNoiseToLatent", "AddLabel": "AddLabel", - "ReferenceOnlySimple3": "ReferenceOnlySimple3" + "ReferenceOnlySimple3": "ReferenceOnlySimple3", + "SoundReactive": "SoundReactive" } \ No newline at end of file diff --git a/web/js/jsnodes.js b/web/js/jsnodes.js index 1a1757a..e90f83e 100644 --- a/web/js/jsnodes.js +++ b/web/js/jsnodes.js @@ -28,8 +28,120 @@ app.registerExtension({ this.addInput(`conditioning_${i}`, this.cond_type) } }); - } + } + case "SoundReactive": + nodeType.prototype.onNodeCreated = function () { + let audioContext; + let microphoneStream; + let animationFrameId; + let analyser; + let dataArray; + let lowRangeHz; + let midRangeHz; + + // Function to update the widget value in real-time + const updateWidgetValueInRealTime = () => { + // Ensure analyser and dataArray are defined before using them + if (analyser && dataArray) { + analyser.getByteFrequencyData(dataArray); + + // Calculate frequency bin width (frequency resolution) + const frequencyBinWidth = audioContext.sampleRate / analyser.fftSize; + // Convert the widget values from Hz to indices + const lowRangeIndex = Math.floor(lowRangeHz / frequencyBinWidth); + const midRangeIndex = Math.floor(midRangeHz / frequencyBinWidth); + + // Define frequency ranges for low, mid, and high + const frequencyRanges = { + low: { start: 0, end: lowRangeIndex }, + mid: { start: lowRangeIndex, end: midRangeIndex }, + high: { start: midRangeIndex, end: dataArray.length } + }; + const lowRangeHzWidget = this.widgets.find(w => w.name === "low_range_hz"); + if (lowRangeHzWidget) lowRangeHz = lowRangeHzWidget.value; + + const midRangeHzWidget = this.widgets.find(w => w.name === "mid_range_hz"); + if (midRangeHzWidget) midRangeHz = midRangeHzWidget.value; + + // Function to calculate the average value for a frequency range + const calculateAverage = (start, end) => { + const sum = dataArray.slice(start, end).reduce((acc, val) => acc + val, 0); + return sum / (end - start); + }; + // Calculate the average levels for each frequency range + const lowLevel = calculateAverage(frequencyRanges.low.start, frequencyRanges.low.end); + const midLevel = calculateAverage(frequencyRanges.low.end, frequencyRanges.mid.end); // mid starts where low ends + const highLevel = calculateAverage(frequencyRanges.mid.end, frequencyRanges.high.end); // high starts where mid ends + const averageLevel = dataArray.reduce((sum, averageLevel) => sum + averageLevel, 0) / dataArray.length; + + // Update the widget values + const averageLevelWidget = this.widgets.find(w => w.name === "average_level"); + if (averageLevelWidget) averageLevelWidget.value = averageLevel; + + const lowLevelWidget = this.widgets.find(w => w.name === "low_level"); + if (lowLevelWidget) lowLevelWidget.value = lowLevel; + + const midLevelWidget = this.widgets.find(w => w.name === "mid_level"); + if (midLevelWidget) midLevelWidget.value = midLevel; + + const highLevelWidget = this.widgets.find(w => w.name === "high_level"); + if (highLevelWidget) highLevelWidget.value = highLevel; + + animationFrameId = requestAnimationFrame(updateWidgetValueInRealTime); + } + }; + + // Function to start capturing audio from the microphone + const startMicrophoneCapture = () => { + // Only create the audio context and analyser once + if (!audioContext) { + audioContext = new (window.AudioContext || window.webkitAudioContext)(); + // Access the sample rate of the audio context + console.log(`Sample rate: ${audioContext.sampleRate}Hz`); + analyser = audioContext.createAnalyser(); + analyser.fftSize = 2048; + dataArray = new Uint8Array(analyser.frequencyBinCount); + // Get the range values from widgets (assumed to be in Hz) + const lowRangeWidget = this.widgets.find(w => w.name === "low_range_hz"); + if (lowRangeWidget) lowRangeHz = lowRangeWidget.value; + + const midRangeWidget = this.widgets.find(w => w.name === "mid_range_hz"); + if (midRangeWidget) midRangeHz = midRangeWidget.value; + } + + navigator.mediaDevices.getUserMedia({ audio: true }).then(stream => { + microphoneStream = stream; + const microphone = audioContext.createMediaStreamSource(stream); + microphone.connect(analyser); + updateWidgetValueInRealTime(); + }).catch(error => { + console.error('Access to microphone was denied or an error occurred:', error); + }); + }; + + // Function to stop capturing audio from the microphone + const stopMicrophoneCapture = () => { + if (animationFrameId) { + cancelAnimationFrame(animationFrameId); + } + if (microphoneStream) { + microphoneStream.getTracks().forEach(track => track.stop()); + } + if (audioContext) { + audioContext.close(); + // Reset audioContext to ensure it can be created again when starting + audioContext = null; + } + }; + + // Add start button + this.addWidget("button", "Start mic capture", null, startMicrophoneCapture); + + // Add stop button + this.addWidget("button", "Stop mic capture", null, stopMicrophoneCapture); + }; break; + } }, }); \ No newline at end of file diff --git a/web/js/setgetnodes.js b/web/js/setgetnodes.js index bcdb17d..15a898c 100644 --- a/web/js/setgetnodes.js +++ b/web/js/setgetnodes.js @@ -313,13 +313,11 @@ app.registerExtension({ if (setter) { const slotInfo = setter.inputs[slot]; - console.log(slotInfo) const link = this.graph.links[slotInfo.link]; - console.log(link) return link; } else { const errorMessage = "No SetNode found for " + this.widgets[0].value + "(" + this.type + ")"; - alert(errorMessage); + console.log(errorMessage); throw new Error(errorMessage); } }